## Simple baseline with Sentinel Image Patches — Swin-v2-t + Binary Cross Entropy [0.23555]

The occurrence of different types of organisms, whether plants or animals, is generally associated with the characteristics of the environment or ecosystem in which they live. This relationship between the presence of species and their habitat is often interdependent and can be affected by various factors, such as climate, which is another modality we provide.

To demonstrate the performance achievable with the image data, i.e., the Sentinel Image Patches, we provide a straightforward baseline based on a slightly modified Swin-v2-t trained with a Binary Cross Entropy loss. As described above, the satellite patches provide an image-like modality that captures habitats and other aspects of the locality.

Since there is considerable room to improve the performance of this baseline, we encourage you to experiment with various techniques, architectures, losses, etc.


In [1]:
import os
import torch
import tqdm
import numpy as np
import pandas as pd
import albumentations as A
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import precision_recall_fscore_support
os.chdir("../../")

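**Note:** this notebook requires the `albumentations` package. If the import cell fails with `ModuleNotFoundError: No module named 'albumentations'`, install it first (e.g. `pip install albumentations`).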

## Data description

The Sentinel Image data was acquired through the Sentinel-2 satellite program and pre-processed by [Ecodatacube](https://stac.ecodatacube.eu/) to produce raster files covering the entire European continent, projected into a single CRS. We filtered the data to extract, from each spectral band, the patch corresponding to each occurrence's location ((lon, lat) GPS coordinates) at a date matching that of the occurrence, and saved the patches as 128x128 JPEG files (RGB as 3-channel .jpeg files and NIR as single-channel .jpeg files). The images were converted from Sentinel's uint15 values to uint8 by clipping pixel values above 10000 and applying a gamma correction of 2.5.
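For illustration, the uint8 conversion can be sketched in NumPy as below. The exact formula is an assumption: we read "gamma correction of 2.5" as raising the normalized value to the power 1/2.5 (which brightens dark pixels), and the function name `sentinel_to_uint8` is hypothetical, not part of the dataset tooling.

```python
import numpy as np

def sentinel_to_uint8(raw, clip_value=10000, gamma=2.5):
    """Hypothetical sketch: map raw Sentinel-2 pixel values to uint8.

    Assumption: values are clipped to [0, clip_value], scaled to [0, 1] and
    gamma-corrected as x ** (1 / gamma); the dataset's exact formula may differ.
    """
    scaled = np.clip(np.asarray(raw, dtype=np.float64), 0, clip_value) / clip_value
    corrected = scaled ** (1.0 / gamma)
    return np.round(corrected * 255).astype(np.uint8)

band = np.array([[0, 2500], [10000, 15000]])
print(sentinel_to_uint8(band))
```

Values at or above the clipping threshold all saturate to 255, which is why very bright surfaces (snow, clouds) lose detail in the patches.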

The path of each patch file can be constructed from its survey ID with the following function:

```python
import os

def construct_patch_path(output_path, survey_id):
    """Construct the patch file path based on survey_id as './CD/AB/XXXXABCD.jpeg'"""
    path = output_path
    for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
        path = os.path.join(path, d)

    path = os.path.join(path, f"{survey_id}.jpeg")

    return path
```
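As a usage sketch (the survey ID and root directories below are illustrative), a survey ID of 1234567 resolves to `<root>/67/45/1234567.jpeg`. The NIR patch lives at the same relative path under the NIR directory, which the dataset classes derive by swapping `rgb`/`RGB` for `nir`/`NIR` in the RGB directory name; `rgb_and_nir_paths` is a hypothetical helper wrapping that logic.

```python
import os

def construct_patch_path(output_path, survey_id):
    """Construct the patch file path based on survey_id as './CD/AB/XXXXABCD.jpeg'"""
    path = output_path
    for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
        path = os.path.join(path, d)
    return os.path.join(path, f"{survey_id}.jpeg")

def rgb_and_nir_paths(rgb_root, survey_id):
    """Hypothetical helper: return (rgb_path, nir_path) for one survey."""
    nir_root = rgb_root.replace("rgb", "nir").replace("RGB", "NIR")
    return (construct_patch_path(rgb_root, survey_id),
            construct_patch_path(nir_root, survey_id))

rgb_path, nir_path = rgb_and_nir_paths(
    "PA_Train_SatellitePatches_RGB/pa_train_patches_rgb", 1234567)
print(rgb_path)
print(nir_path)
```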

**References:**
- *Traceability (lineage): The dataset was produced entirely by mosaicking and seasonally aggregating imagery from the Sentinel-2 Level-2A product (https://sentinels.copernicus.eu/web/sentinel/user-guides/sentinel-2-msi/product-types/level-2a)*
- *Ecodatacube.eu: Analysis-ready open environmental data cube for Europe (https://doi.org/10.21203/rs.3.rs-2277090/v3)*

## Prepare custom dataset loader

We need a custom `Dataset` that loads the RGB and NIR patches of each survey, stacks them into a single 4-channel image, and builds a multi-hot label vector over all species.
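`TrainDataset` groups the long-format metadata (one row per survey/species pair) by `surveyId` and turns each survey's species list into a multi-hot vector of length 11,255, using `speciesId` directly as the class index. The same idea in plain Python and NumPy, with hypothetical helper names `build_label_dict` and `multi_hot`, looks like this:

```python
from collections import defaultdict
import math
import numpy as np

def build_label_dict(survey_ids, species_ids):
    """Group species IDs by survey, skipping rows without a label (None/NaN)."""
    labels = defaultdict(list)
    for survey_id, species_id in zip(survey_ids, species_ids):
        if species_id is None or (isinstance(species_id, float) and math.isnan(species_id)):
            continue
        labels[int(survey_id)].append(int(species_id))
    return dict(labels)

def multi_hot(species, num_classes):
    """Return a float32 vector with 1.0 at every observed species index."""
    vec = np.zeros(num_classes, dtype=np.float32)
    vec[np.asarray(species, dtype=np.intp)] = 1.0
    return vec

label_dict = build_label_dict([10, 10, 11, 12], [3, 7, 3, float("nan")])
print(label_dict)  # {10: [3, 7], 11: [3]}
print(multi_hot(label_dict[10], num_classes=8))
```

Survey 12 disappears because its only row has no species label, mirroring the `dropna(subset="speciesId")` step in the dataset class.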

In [None]:
# Training-time augmentations. `additional_targets` makes Albumentations apply the
# same randomly sampled parameters to the NIR band as to the RGB image, so the
# geometric OpticalDistortion keeps both modalities pixel-aligned.
transform_albumentations = A.Compose([
    A.RandomBrightnessContrast(p=0.2),
    A.ColorJitter(p=0.2),
    A.OpticalDistortion(p=0.2)
], additional_targets={"nir": "image"})

def construct_patch_path(data_path, survey_id):
    """Construct the patch file path based on survey_id as './CD/AB/XXXXABCD.jpeg'"""
    path = data_path
    for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
        path = os.path.join(path, d)

    path = os.path.join(path, f"{survey_id}.jpeg")

    return path

class TrainDataset(Dataset):
    def __init__(self, data_dir, metadata, transform=None):
        self.transform = transform
        self.data_dir = data_dir
        # Keep only rows with a species label and cast the IDs to int
        self.metadata = metadata.dropna(subset="speciesId").reset_index(drop=True)
        self.metadata['speciesId'] = self.metadata['speciesId'].astype(int)
        # Map each surveyId to the list of species observed in that survey
        self.label_dict = self.metadata.groupby('surveyId')['speciesId'].apply(list).to_dict()
        # One row per survey: each survey becomes one training sample
        self.metadata = self.metadata.drop_duplicates(subset="surveyId").reset_index(drop=True)

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        survey_id = self.metadata.surveyId[idx]
        species_ids = self.label_dict.get(survey_id, [])  # List of species IDs for this survey
        num_classes = 11255
        label = torch.zeros(num_classes)  # Multi-hot label vector
        for species_id in species_ids:
            label[species_id] = 1  # speciesId is used directly as the class index

        rgb_sample = np.array(Image.open(construct_patch_path(self.data_dir, survey_id)))
        nir_dir = self.data_dir.replace("rgb", "nir").replace("RGB", "NIR")
        nir_sample = np.array(Image.open(construct_patch_path(nir_dir, survey_id)))

        # Augment RGB and NIR jointly so both receive identical random parameters
        augmented = transform_albumentations(image=rgb_sample, nir=nir_sample)

        # Stack into a single H x W x 4 array: R, G, B, NIR
        sample = np.concatenate((augmented["image"], augmented["nir"][..., None]), axis=2)
        if self.transform:
            sample = self.transform(sample)

        return sample, label, survey_id
    
class TestDataset(TrainDataset):
    def __init__(self, data_dir, metadata, transform=None):
        self.transform = transform
        self.data_dir = data_dir
        self.metadata = metadata

    def __getitem__(self, idx):
        survey_id = self.metadata.surveyId[idx]

        rgb_sample = np.array(Image.open(construct_patch_path(self.data_dir, survey_id)))
        nir_dir = self.data_dir.replace("rgb", "nir").replace("RGB", "NIR")
        nir_sample = np.array(Image.open(construct_patch_path(nir_dir, survey_id)))

        # No random augmentation at inference time, so predictions are deterministic
        sample = np.concatenate((rgb_sample, nir_sample[..., None]), axis=2)
        if self.transform:
            sample = self.transform(sample)

        return sample, survey_id

### Load metadata and prepare data loaders

In [None]:
# Dataset and DataLoader
batch_size = 64

transform = transforms.Compose([
    transforms.ToTensor()
])

# Load Training metadata
train_data_path = "Dataset/geolifeclef-2024/PA_Train_SatellitePatches_RGB/pa_train_patches_rgb/"
train_metadata_path = "Dataset/geolifeclef-2024/GLC24_PA_metadata_train.csv"
train_metadata = pd.read_csv(train_metadata_path)
train_dataset = TrainDataset(train_data_path, train_metadata, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Load Test metadata
test_data_path = "Dataset/geolifeclef-2024/PA_Test_SatellitePatches_RGB/pa_test_patches_rgb/"
test_metadata_path = "Dataset/geolifeclef-2024/GLC24_PA_metadata_test.csv"
test_metadata = pd.read_csv(test_metadata_path)
test_dataset = TestDataset(test_data_path, test_metadata, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
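`transforms.ToTensor()` turns each H x W x 4 `uint8` array (R, G, B, NIR) into a 4 x H x W float tensor with values in [0, 1]. A NumPy sketch of the same conversion, using a hypothetical `to_chw_float` helper, makes the shape change explicit:

```python
import numpy as np

def to_chw_float(sample_hwc):
    """Mimic ToTensor for uint8 input: HWC -> CHW and scale to [0, 1]."""
    return np.transpose(sample_hwc, (2, 0, 1)).astype(np.float32) / 255.0

rgb = np.full((128, 128, 3), 255, dtype=np.uint8)
nir = np.zeros((128, 128), dtype=np.uint8)
sample = np.concatenate((rgb, nir[..., None]), axis=2)  # H x W x 4
chw = to_chw_float(sample)
print(sample.shape, chw.shape)  # (128, 128, 4) (4, 128, 128)
```

With `batch_size = 64`, each training batch therefore has shape 64 x 4 x 128 x 128.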

## Modify pretrained Swin-v2-t model

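To use all four channels (R, G, B, and NIR), we replace the input layer (the patch-embedding convolution) so that it accepts 4 channels instead of 3, and replace the classification head so that it outputs one logit per species.
That is all :)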
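Replacing the patch embedding with a fresh `nn.Conv2d` discards the pretrained filters of the first layer. One common alternative, sketched here and not part of this baseline, keeps the pretrained RGB weights and initializes the NIR channel from them. Averaging the RGB filters for the extra channel is a heuristic assumption, and `expand_patch_embed` is a hypothetical helper name.

```python
import torch
import torch.nn as nn

def expand_patch_embed(conv: nn.Conv2d, extra_channels: int = 1) -> nn.Conv2d:
    """Return a copy of `conv` accepting conv.in_channels + extra_channels inputs.

    Pretrained filters are copied for the original channels; each extra channel
    gets the mean of the original filters (a heuristic, not the baseline's method).
    """
    new_conv = nn.Conv2d(
        conv.in_channels + extra_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        new_conv.weight[:, : conv.in_channels] = conv.weight
        mean_filter = conv.weight.mean(dim=1, keepdim=True)
        new_conv.weight[:, conv.in_channels :] = mean_filter.repeat(1, extra_channels, 1, 1)
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv
```

In the model cell, `model.features[0][0] = expand_patch_embed(model.features[0][0])` would then stand in for the fresh `nn.Conv2d(4, 96, ...)` line.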

In [None]:
# Check if cuda is available
device = torch.device("cpu")

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("DEVICE = CUDA")

# Hyperparameters
learning_rate = 0.0002
num_epochs = 10
positive_weigh_factor = 1.0
num_classes = 11255 # Number of all unique classes within the PO and PA data.

In [None]:
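# Note: the text refers to Swin-v2-t, but this line loads Swin-v2-s. Use
# models.swin_v2_t(weights="IMAGENET1K_V1") for the tiny variant; both have a
# 96-channel patch embedding and a 768-dim head, so the lines below fit either.
model = models.swin_v2_s(weights="IMAGENET1K_V1")
# Replace the 3-channel patch embedding with a 4-channel one (R, G, B, NIR)
model.features[0][0] = nn.Conv2d(4, 96, kernel_size=(4, 4), stride=(4, 4))
# Replace the ImageNet head with one logit per species
model.head = nn.Linear(in_features=768, out_features=num_classes, bias=True)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# T_max=25 exceeds num_epochs=10, so the LR only follows the first part of the cosine curve
scheduler = CosineAnnealingLR(optimizer, T_max=25)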
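The PyTorch documentation describes `CosineAnnealingLR` by the closed form $\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)$. Evaluating it for this configuration (`learning_rate = 0.0002`, `T_max=25`, and 10 scheduler steps) shows that after training the LR has only decayed to roughly 65% of its initial value. The helper name `cosine_annealing_lr` is hypothetical.

```python
import math

def cosine_annealing_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing learning rate after `epoch` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

for epoch in (0, 5, 10, 25):
    print(epoch, f"{cosine_annealing_lr(epoch, 2e-4, t_max=25):.3e}")
```

Setting `T_max=num_epochs` would instead anneal the LR to `eta_min` by the last epoch; which choice works better here is an empirical question.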

In [None]:
import random

def set_seed(seed):
    # Seed Python's built-in random number generator
    random.seed(seed)
    # Seed PyTorch (CPU and all CUDA devices)
    torch.manual_seed(seed)
    # Seed NumPy
    np.random.seed(seed)
    # Seed CUDA explicitly and make cuDNN deterministic
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Note: the model is created before this call, so the random
# initialization of its new layers is not covered by this seed.
set_seed(69)
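`set_seed` seeds the main process, but the loaders use `num_workers=4`, and the augmentations run inside those worker processes. The PyTorch reproducibility notes recommend passing a `worker_init_fn` and a seeded `generator` to the `DataLoader`; a sketch following that pattern is below. Whether Albumentations draws from NumPy's or Python's global RNG depends on its version, so full determinism of the augmentations is an assumption here.

```python
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    """Derive per-worker NumPy/`random` seeds from the seed PyTorch gives the worker."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

generator = torch.Generator()
generator.manual_seed(69)

# Toy dataset for illustration; in the notebook, pass train_dataset and num_workers=4
# (worker_init_fn only runs when num_workers > 0).
loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=4, shuffle=True,
                    num_workers=0, worker_init_fn=seed_worker, generator=generator)
```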

## Training Loop

Nothing special, just a standard PyTorch training loop.
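The loss is `BCEWithLogitsLoss` with `pos_weight = targets * positive_weigh_factor`. Per the PyTorch docs, `pos_weight` multiplies only the $y \log \sigma(x)$ term, so zeroing it on negative labels changes nothing, and a factor of 1.0 reduces to plain BCE. A NumPy re-implementation (hypothetical name `bce_with_logits`, mean reduction as in the default loss) makes this easy to check:

```python
import numpy as np

def bce_with_logits(logits, targets, pos_weight=1.0):
    """Mean BCE-with-logits loss; pos_weight scales only the positive-label term.

    l = -[p * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
    """
    x = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    log_sig = -np.logaddexp(0.0, -x)           # log(sigmoid(x)), numerically stable
    log_one_minus_sig = -np.logaddexp(0.0, x)  # log(1 - sigmoid(x))
    loss = -(pos_weight * y * log_sig + (1.0 - y) * log_one_minus_sig)
    return loss.mean()

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 10))
y = (rng.random((4, 10)) < 0.2).astype(float)
# Per-element pos_weight = y * 3 gives the same loss as a constant pos_weight of 3
print(bce_with_logits(x, y, pos_weight=y * 3.0), bce_with_logits(x, y, pos_weight=3.0))
```

With roughly 11,000 classes and only a handful of positives per survey, raising `positive_weigh_factor` above 1.0 is one knob for counteracting the dominance of negative labels.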

In [None]:
print(f"Training for {num_epochs} epochs started.")

for epoch in range(num_epochs):
    model.train()
    for batch_idx, (data, targets, _) in enumerate(train_loader):

        data = data.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(data)

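        # pos_weight only scales the loss term of positive labels, so this weights every
        # positive by positive_weigh_factor (1.0 here, i.e. plain BCE)
        pos_weight = targets*positive_weigh_factor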
        criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        if batch_idx % 278 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item()}")

    scheduler.step()
    print("Scheduler:",scheduler.state_dict())

# Save the trained model
model.eval()
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/swin-v2-with-sentinel-images.pth")

## Test Loop

Again, nothing special, just a standard inference loop that keeps the 25 highest-scoring species for each survey.
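The loop below fully sorts all 11,255 scores per survey with `np.argsort` just to keep 25 of them. A sketch of an equivalent selection with `np.argpartition`, which finds the top k in linear time and then sorts only those k (hypothetical helper `top_k_indices`):

```python
import numpy as np

def top_k_indices(scores, k=25):
    """Return, per row, the indices of the k highest scores in descending order."""
    part = np.argpartition(-scores, k - 1, axis=1)[:, :k]   # unordered top-k per row
    part_scores = np.take_along_axis(scores, part, axis=1)
    order = np.argsort(-part_scores, axis=1)                 # sort only those k
    return np.take_along_axis(part, order, axis=1)

rng = np.random.default_rng(0)
scores = rng.random((5, 11255))
print(top_k_indices(scores, 25).shape)  # (5, 25)
```

Collecting each batch's result in a list and calling `np.concatenate` once after the loop also avoids re-copying the growing array on every batch.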

In [None]:
with torch.no_grad():
    all_predictions = []
    surveys = []
    top_k_indices = None
    for data, surveyID in tqdm.tqdm(test_loader, total=len(test_loader)):

        data = data.to(device)
        
        outputs = model(data)
        predictions = torch.sigmoid(outputs).cpu().numpy()

        # Select the 25 species with the highest sigmoid scores as predictions
        top_25 = np.argsort(-predictions, axis=1)[:, :25] 
        if top_k_indices is None:
            top_k_indices = top_25
        else:
            top_k_indices = np.concatenate((top_k_indices, top_25), axis=0)

        surveys.extend(surveyID.cpu().numpy())

## Save prediction file!

In [None]:
data_concatenated = [' '.join(map(str, row)) for row in top_k_indices]

output_dir = "research/Baseline_experiments/outputs/baseline-with-sentinel-images"
os.makedirs(output_dir, exist_ok=True)
pd.DataFrame(
    {'surveyId': surveys,
     'predictions': data_concatenated,
    }).to_csv(os.path.join(output_dir, "output.csv"), index=False)
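Before submitting, it can help to re-read the file and check that it has the shape written above: a `surveyId,predictions` header and 25 distinct, space-separated species IDs per row. The checker below is a hypothetical sketch using only the standard library; it validates the format produced by this notebook, not any additional competition rules.

```python
import csv
import io

def check_submission(csv_text, k=25):
    """Hypothetical sanity check for the prediction CSV written above.

    Verifies the header and that every row lists k distinct integer species IDs
    separated by single spaces; returns the number of rows checked.
    """
    reader = csv.DictReader(io.StringIO(csv_text))
    if reader.fieldnames != ["surveyId", "predictions"]:
        raise ValueError(f"unexpected header: {reader.fieldnames}")
    n_rows = 0
    for row in reader:
        preds = row["predictions"].split(" ")
        if len(preds) != k or len(set(preds)) != k or not all(p.isdigit() for p in preds):
            raise ValueError(f"bad predictions for surveyId {row['surveyId']}")
        n_rows += 1
    return n_rows
```

For example, reading `output.csv` into a string and passing it to `check_submission` should return the number of test surveys.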