In [1]:
import os
import torch
import tqdm
import numpy as np
import pandas as pd
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.metrics import precision_recall_fscore_support
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from einops import rearrange, repeat

  warn(


In [2]:
# Hyperparameters (shared by the cells below)
learning_rate = 2.5e-4       # AdamW step size
num_epochs = 1               # passes over the training loader
positive_weigh_factor = 1.0  # weight of the positive term in the BCE loss
INITIAL_SEED = 113           # seed constant for reproducibility
test_batch_size = 64         # rebound in the test DataLoader cell
num_classes = 11_255         # max 11255

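When the data is read, each data type is **fused** and **normalized** in its own way: the per-survey tabular features (metadata, elevation, human footprint, land cover and soil) are standardized column by column and concatenated into one vector, while the Landsat and climate tables are each standardized with a single global mean and standard deviation.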

In [3]:
class CustomDataset(Dataset):
    """Multimodal dataset built from the GeoLifeCLEF 2024 PA *test* files.

    Each item is ``([tabular, image, landsat, climate], survey_id)`` where
    ``tabular`` concatenates the standardized metadata, elevation, human
    footprint, land-cover and soil features, ``image`` stacks the RGB and NIR
    patches along the channel axis, and ``landsat`` / ``climate`` are the
    globally standardized time-series rows of that survey.
    """

    def __init__(self, metadata, subset, transform=None):
        """Load and normalize every tabular modality.

        Args:
            metadata: DataFrame with a ``surveyId`` column. Its first five
                columns are fed to ``Norm`` (which drops the first of them).
            subset: Stored on the instance; not otherwise used by this class.
            transform: Callable applied to each PIL patch. It must return a
                (C, H, W) tensor. When None, ``transforms.ToTensor()`` is used.
        """
        self.subset = subset
        self.transform = transform
        self.merge_key = 'surveyId'

        # One row per survey. drop_duplicates returns a new frame, so the
        # caller's DataFrame is not modified by the in-place calls below.
        self.metadata = metadata.drop_duplicates(subset="surveyId").reset_index(drop=True)
        self.metadata.fillna(0,inplace=True)
        self.metadata.replace({float('-inf'): 0}, inplace=True)
        # NOTE(review): Norm drops the first column of whatever it is given. For
        # the CSV tables that is presumably surveyId; here it drops the first
        # metadata column -- confirm this is intended.
        self.metadata_data = self.Norm(self.metadata.iloc[:,:5])

        env_dir = "data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters"
        landsat_prefix = "data/geolifeclef-2024/PA-test-landsat_time_series/GLC24-PA-test-landsat_time_series"

        self.climate_average = pd.read_csv(f"{env_dir}/Climate/Average 1981-2010/GLC24-PA-test-bioclimatic.csv")
        self.climate_monthly = pd.read_csv(f"{env_dir}/Climate/Monthly/GLC24-PA-test-bioclimatic_monthly.csv")
        self.climate = self._prepare(pd.merge(self.climate_average, self.climate_monthly, on=self.merge_key))
        self.climate_data = self.Norm_all(self.climate)

        self.landsat_b = self._prepare(pd.read_csv(f"{landsat_prefix}-blue.csv"))
        self.landsat_g = self._prepare(pd.read_csv(f"{landsat_prefix}-green.csv"))
        self.landsat_r = self._prepare(pd.read_csv(f"{landsat_prefix}-red.csv"))
        self.landsat_n = self._prepare(pd.read_csv(f"{landsat_prefix}-nir.csv"))
        self.landsat_s1 = self._prepare(pd.read_csv(f"{landsat_prefix}-swir1.csv"))
        self.landsat_s2 = self._prepare(pd.read_csv(f"{landsat_prefix}-swir2.csv"))
        bands = (self.landsat_b, self.landsat_g, self.landsat_r,
                 self.landsat_n, self.landsat_s1, self.landsat_s2)
        # Each band is standardized on its own, then the bands sit side by side.
        self.landsat_data = torch.cat([self.Norm_all(band) for band in bands], dim=1)

        self.elevation = self._prepare(pd.read_csv(f"{env_dir}/Elevation/GLC24-PA-test-elevation.csv"), clamp_negative=True)
        self.elevation_data = self.Norm(self.elevation)

        self.human_footprint = self._prepare(pd.read_csv(f"{env_dir}/Human Footprint/GLC24-PA-test-human_footprint.csv"), clamp_negative=True)
        self.human_footprint_data = self.Norm(self.human_footprint)

        self.landcover = self._prepare(pd.read_csv(f"{env_dir}/LandCover/GLC24-PA-test-landcover.csv"), clamp_negative=True)
        self.landcover_data = self.Norm(self.landcover)

        self.soilgrids = self._prepare(pd.read_csv(f"{env_dir}/SoilGrids/GLC24-PA-test-soilgrids.csv"), clamp_negative=True)
        self.soilgrids_data = self.Norm(self.soilgrids)

        self.metadata_data = torch.cat((self.metadata_data, self.elevation_data, self.human_footprint_data, self.landcover_data, self.soilgrids_data), dim=1)

    def _align(self, df):
        """Return ``df`` with its rows in the same survey order as ``self.metadata``.

        ``__getitem__`` reads every table by row position, so a table whose
        rows are ordered differently from the metadata would silently feed the
        features of another survey. Tables without the key column are returned
        unchanged. Surveys missing from ``df`` become NaN rows.
        """
        key = self.merge_key
        if key not in df.columns:
            return df
        wanted = self.metadata[key]
        if df[key].reset_index(drop=True).equals(wanted.reset_index(drop=True)):
            return df  # already row-aligned; nothing to do
        aligned = df.drop_duplicates(subset=key).set_index(key).reindex(wanted.values)
        aligned.index.name = key
        # Restore the original column order so positional slicing is unchanged.
        return aligned.reset_index()[list(df.columns)].copy()

    def _prepare(self, df, clamp_negative=False):
        """Align ``df`` to the metadata, optionally zero negatives, then mean-fill NaNs."""
        df = self._align(df)
        if clamp_negative:
            df[df<0]=0
        df.fillna(df.mean(),inplace=True)
        return df

    def Norm(self,df):
        """Standardize each column of ``df`` except the first one.

        Columns whose standard deviation is zero or undefined are only centred,
        so a constant column yields zeros instead of NaN.
        """
        output=torch.from_numpy(df.iloc[:,1:].values).float()
        std = output.std(dim=0)
        std = torch.where(std > 0, std, torch.ones_like(std))
        return (output-output.mean(dim=0))/std

    def Norm_all(self,df):
        """Standardize all values of ``df`` (first column excluded) with one global mean/std.

        A zero or undefined standard deviation is replaced by 1 to avoid NaN.
        """
        output=torch.from_numpy(df.iloc[:,1:].values).float()
        std = output.std()
        if not std > 0:
            std = torch.ones_like(std)
        return (output-output.mean())/std

    def patch_rgb_path(self,survey_id):
        """Return the RGB patch path: ``<root>/<last 2 digits>/<previous 2 digits>/<id>.jpeg``."""
        path = "data/geolifeclef-2024/PA_Test_SatellitePatches_RGB/pa_test_patches_rgb"
        for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
            path = os.path.join(path, d)
        path = os.path.join(path, f"{survey_id}.jpeg")
        return path

    def patch_nir_path(self,survey_id):
        """Return the NIR patch path: ``<root>/<last 2 digits>/<previous 2 digits>/<id>.jpeg``."""
        path = "data/geolifeclef-2024/PA_Test_SatellitePatches_NIR/pa_test_patches_nir"
        for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
            path = os.path.join(path, d)
        path = os.path.join(path, f"{survey_id}.jpeg")
        return path

    def _load_patch(self, path, mode):
        """Open the image at ``path``, convert it to ``mode`` and apply the transform."""
        # The context manager closes the file handle; convert() returns an
        # independent image that stays valid afterwards.
        with Image.open(path) as img:
            patch = img.convert(mode)
        transform = self.transform if self.transform is not None else transforms.ToTensor()
        return transform(patch)

    def __len__(self):
        """Number of distinct surveys."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return ``([tabular, image, landsat, climate], survey_id)`` for row ``idx``."""
        survey_id = self.metadata.surveyId[idx]

        image = self._load_patch(self.patch_rgb_path(survey_id), "RGB")
        nir_image = self._load_patch(self.patch_nir_path(survey_id), "L")
        # RGB channels followed by the single NIR channel.
        image_data = torch.cat([image, nir_image], dim=0)

        sample=[self.metadata_data[idx,:],image_data,self.landsat_data[idx,:],self.climate_data[idx,:]]
        return sample, survey_id

In [4]:
class TestDataset(CustomDataset):
    """Test-split dataset; a thin wrapper that defers entirely to CustomDataset."""

    def __init__(self, metadata, subset, transform=None):
        # No extra state: the parent constructor does all of the loading.
        super().__init__(metadata, subset, transform)

    def __getitem__(self, idx):
        # Hand back the parent's (sample, survey_id) pair unchanged.
        item = super().__getitem__(idx)
        features, identifier = item
        return features, identifier

In [5]:
class TrainDataset(CustomDataset):
    """Multimodal dataset built from the GeoLifeCLEF 2024 PA *train* files.

    Each item is ``([tabular, image, landsat, climate], survey_id, label)``
    where ``label`` is a multi-hot vector of length ``num_classes`` marking
    every species recorded for the survey.
    """

    def __init__(self, metadata, subset, transform=None):
        """Load and normalize every tabular modality and build the label lookup.

        Args:
            metadata: DataFrame with ``surveyId`` and ``speciesId`` columns,
                one row per (survey, species) occurrence. Its first five
                columns are fed to ``Norm`` (which drops the first of them).
            subset: Stored on the instance; not otherwise used by this class.
            transform: Callable applied to each PIL patch. It must return a
                (C, H, W) tensor. When None, ``transforms.ToTensor()`` is used.
        """
        self.subset = subset
        self.transform = transform
        self.merge_key = 'surveyId'

        # Build the survey -> species lookup from the FULL occurrence table.
        # Doing it after drop_duplicates (as before) leaves a single species
        # per survey and discards the multi-label targets. Rows without a
        # species id are skipped instead of being mapped to species 0.
        labelled = metadata.dropna(subset=['speciesId'])
        self.label_dict = (
            labelled['speciesId'].astype(int)
            .groupby(labelled['surveyId'])
            .apply(list)
            .to_dict()
        )

        # One row per survey. drop_duplicates returns a new frame, so the
        # caller's DataFrame is not modified by the in-place calls below.
        self.metadata = metadata.drop_duplicates(subset="surveyId").reset_index(drop=True)
        self.metadata.fillna(0,inplace=True)
        self.metadata.replace({float('-inf'): 0}, inplace=True)
        # NOTE(review): Norm drops the first column of whatever it is given. For
        # the CSV tables that is presumably surveyId; here it drops the first
        # metadata column -- confirm this is intended.
        self.metadata_data = self.Norm(self.metadata.iloc[:,:5])

        env_dir = "data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters"
        landsat_prefix = "data/geolifeclef-2024/PA-train-landsat_time_series/GLC24-PA-train-landsat_time_series"

        self.climate_average = pd.read_csv(f"{env_dir}/Climate/Average 1981-2010/GLC24-PA-train-bioclimatic.csv")
        self.climate_monthly = pd.read_csv(f"{env_dir}/Climate/Monthly/GLC24-PA-train-bioclimatic_monthly.csv")
        self.climate = self._prepare(pd.merge(self.climate_average, self.climate_monthly, on=self.merge_key))
        self.climate_data = self.Norm_all(self.climate)

        self.landsat_b = self._prepare(pd.read_csv(f"{landsat_prefix}-blue.csv"))
        self.landsat_g = self._prepare(pd.read_csv(f"{landsat_prefix}-green.csv"))
        self.landsat_r = self._prepare(pd.read_csv(f"{landsat_prefix}-red.csv"))
        self.landsat_n = self._prepare(pd.read_csv(f"{landsat_prefix}-nir.csv"))
        self.landsat_s1 = self._prepare(pd.read_csv(f"{landsat_prefix}-swir1.csv"))
        self.landsat_s2 = self._prepare(pd.read_csv(f"{landsat_prefix}-swir2.csv"))
        bands = (self.landsat_b, self.landsat_g, self.landsat_r,
                 self.landsat_n, self.landsat_s1, self.landsat_s2)
        # Each band is standardized on its own, then the bands sit side by side.
        self.landsat_data = torch.cat([self.Norm_all(band) for band in bands], dim=1)

        self.elevation = self._prepare(pd.read_csv(f"{env_dir}/Elevation/GLC24-PA-train-elevation.csv"), clamp_negative=True)
        self.elevation_data = self.Norm(self.elevation)

        self.human_footprint = self._prepare(pd.read_csv(f"{env_dir}/Human Footprint/GLC24-PA-train-human_footprint.csv"), clamp_negative=True)
        self.human_footprint_data = self.Norm(self.human_footprint)

        self.landcover = self._prepare(pd.read_csv(f"{env_dir}/LandCover/GLC24-PA-train-landcover.csv"), clamp_negative=True)
        self.landcover_data = self.Norm(self.landcover)

        self.soilgrids = self._prepare(pd.read_csv(f"{env_dir}/SoilGrids/GLC24-PA-train-soilgrids.csv"), clamp_negative=True)
        self.soilgrids_data = self.Norm(self.soilgrids)

        self.metadata_data = torch.cat((self.metadata_data, self.elevation_data, self.human_footprint_data, self.landcover_data, self.soilgrids_data), dim=1)
        self.metadata['speciesId'] = self.metadata['speciesId'].astype(int)

    def _align(self, df):
        """Return ``df`` with its rows in the same survey order as ``self.metadata``.

        ``__getitem__`` reads every table by row position, so a table whose
        rows are ordered differently from the deduplicated metadata would
        silently feed the features of another survey. Tables without the key
        column are returned unchanged. Surveys missing from ``df`` become NaN rows.
        """
        key = self.merge_key
        if key not in df.columns:
            return df
        wanted = self.metadata[key]
        if df[key].reset_index(drop=True).equals(wanted.reset_index(drop=True)):
            return df  # already row-aligned; nothing to do
        aligned = df.drop_duplicates(subset=key).set_index(key).reindex(wanted.values)
        aligned.index.name = key
        # Restore the original column order so positional slicing is unchanged.
        return aligned.reset_index()[list(df.columns)].copy()

    def _prepare(self, df, clamp_negative=False):
        """Align ``df`` to the metadata, optionally zero negatives, then mean-fill NaNs."""
        df = self._align(df)
        if clamp_negative:
            df[df<0]=0
        df.fillna(df.mean(),inplace=True)
        return df

    def Norm(self,df):
        """Standardize each column of ``df`` except the first one.

        Columns whose standard deviation is zero or undefined are only centred,
        so a constant column yields zeros instead of NaN.
        """
        output=torch.from_numpy(df.iloc[:,1:].values).float()
        std = output.std(dim=0)
        std = torch.where(std > 0, std, torch.ones_like(std))
        return (output-output.mean(dim=0))/std

    def Norm_all(self,df):
        """Standardize all values of ``df`` (first column excluded) with one global mean/std.

        A zero or undefined standard deviation is replaced by 1 to avoid NaN.
        """
        output=torch.from_numpy(df.iloc[:,1:].values).float()
        std = output.std()
        if not std > 0:
            std = torch.ones_like(std)
        return (output-output.mean())/std

    def patch_rgb_path(self,survey_id):
        """Return the RGB patch path: ``<root>/<last 2 digits>/<previous 2 digits>/<id>.jpeg``."""
        path = "data/geolifeclef-2024/PA_Train_SatellitePatches_RGB/pa_train_patches_rgb"
        for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
            path = os.path.join(path, d)
        path = os.path.join(path, f"{survey_id}.jpeg")
        return path

    def patch_nir_path(self,survey_id):
        """Return the NIR patch path: ``<root>/<last 2 digits>/<previous 2 digits>/<id>.jpeg``."""
        path = "data/geolifeclef-2024/PA_Train_SatellitePatches_NIR/pa_train_patches_nir"
        for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
            path = os.path.join(path, d)
        path = os.path.join(path, f"{survey_id}.jpeg")
        return path

    def _load_patch(self, path, mode):
        """Open the image at ``path``, convert it to ``mode`` and apply the transform."""
        # The context manager closes the file handle; convert() returns an
        # independent image that stays valid afterwards.
        with Image.open(path) as img:
            patch = img.convert(mode)
        transform = self.transform if self.transform is not None else transforms.ToTensor()
        return transform(patch)

    def __len__(self):
        """Number of distinct surveys."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return ``([tabular, image, landsat, climate], survey_id, label)`` for row ``idx``."""
        survey_id = self.metadata.surveyId[idx]

        label = torch.zeros(num_classes)  # multi-hot target
        for species_id in self.label_dict.get(survey_id, []):
            # Ids outside [0, num_classes) are skipped so that running with a
            # reduced num_classes does not raise IndexError (or wrap around
            # for negative ids).
            if 0 <= species_id < num_classes:
                label[species_id] = 1

        image = self._load_patch(self.patch_rgb_path(survey_id), "RGB")
        nir_image = self._load_patch(self.patch_nir_path(survey_id), "L")
        # RGB channels followed by the single NIR channel.
        image_data = torch.cat([image, nir_image], dim=0)

        sample=[self.metadata_data[idx,:],image_data,self.landsat_data[idx,:],self.climate_data[idx,:]]
        return sample, survey_id, label

In [6]:
# Train Dataset and DataLoader
train_batch_size = 64
train_transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

# Load Training metadata
train_metadata_path = "data/geolifeclef-2024/GLC24_PA_metadata_train.csv"
train_metadata = pd.read_csv(train_metadata_path)
train_dataset = TrainDataset(train_metadata, subset="train", transform=train_transform)
# Shuffle the training set every epoch so batches are not tied to file order;
# the seeded generator keeps the shuffling order repeatable across runs.
train_loader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=1,
    generator=torch.Generator().manual_seed(INITIAL_SEED),
)

In [7]:
# Test Dataset and DataLoader
# Rebinds the test_batch_size from the hyperparameter cell (64) to 1.
test_batch_size = 1
test_transform = transforms.Compose(
    [transforms.Resize((128,128)), transforms.ToTensor()]
)

# Load Testing metadata
test_metadata_path = "data/geolifeclef-2024/GLC24_PA_metadata_test.csv"
test_metadata = pd.read_csv(test_metadata_path)
test_dataset = TestDataset(metadata=test_metadata, subset="test", transform=test_transform)
# Keep file order so predictions line up with the survey ids.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=test_batch_size,
    num_workers=1,
    shuffle=False,
)

The following MLP embeds a flat feature vector into a fixed-size feature. One instance is used for each non-image input, and one more for the fused feature.

In [8]:
class Embedding(nn.Module):
    """Two-layer MLP: Linear -> tanh -> LayerNorm -> Linear.

    Maps the last dimension from ``dim`` to ``out_dim`` through a hidden layer
    of width ``5 * out_dim``.
    """

    def __init__(self, dim, out_dim):
        super().__init__()
        hidden_dim = out_dim*5
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        # The normalization is applied to the activated hidden features.
        hidden = self.norm(F.tanh(self.fc1(x)))
        return self.fc2(hidden)

The following cells implement the building blocks of the ViT, which correspond to the encoder part of a **Transformer**.

In [9]:
class Multihead_self_attention(nn.Module):
    """Pre-norm multi-head self-attention over a (batch, length, dim) sequence.

    The input is layer-normalized, projected to queries/keys/values for
    ``heads`` heads of width ``head_dim``, mixed with scaled dot-product
    attention and projected back to ``dim``. No residual is added here.
    """

    def __init__(self, heads, head_dim, dim):
        super().__init__()
        self.head_dim = head_dim
        self.heads = heads
        self.inner_dim = self.heads*self.head_dim
        self.scale = self.head_dim**-0.5  # 1/sqrt(head_dim)
        self.to_qkv = nn.Linear(dim, self.inner_dim*3)
        self.to_output = nn.Linear(self.inner_dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        normed = self.norm(x)
        # One projection yields q, k and v; each is split into heads.
        queries, keys, values = (
            rearrange(part, 'b l (h dim) -> b h l dim', dim=self.head_dim)
            for part in self.to_qkv(normed).chunk(3, dim=-1)
        )
        scores = (queries @ keys.transpose(-1, -2)) * self.scale
        weights = self.softmax(scores)
        mixed = weights @ values  # (B,H,L,dim)
        merged = rearrange(mixed, 'b h l dim -> b l (h dim)')
        return self.to_output(merged)

In [10]:
class FeedForward(nn.Module):
    """Pre-norm position-wise MLP: LayerNorm -> Linear -> GELU -> Linear.

    Expands the last dimension from ``dim`` to ``mlp_dim`` and back. No
    residual is added here.
    """

    def __init__(self, dim, mlp_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        hidden = self.fc1(self.norm(x))
        return self.fc2(F.gelu(hidden))

In [11]:
class Transformer_block(nn.Module):
    """One encoder block: self-attention and feed-forward, each with a residual connection."""

    def __init__(self, dim, heads, head_dim, mlp_dim):
        super().__init__()
        self.MHA = Multihead_self_attention(heads=heads, head_dim=head_dim, dim=dim)
        self.FeedForward = FeedForward(dim=dim, mlp_dim=mlp_dim)

    def forward(self, x):
        # Residual around attention, then residual around the MLP.
        attended = self.MHA(x) + x
        return self.FeedForward(attended) + attended

In [12]:
class ViT(nn.Module):
    """Single transformer block followed by a classification head on the first token.

    ``forward`` returns raw logits; ``self.sigmoid`` is defined but not applied.
    """

    def __init__(self, dim, heads, head_dim, mlp_dim, num_class):
        super().__init__()
        self.transformer = Transformer_block(dim=dim, heads=heads, head_dim=head_dim, mlp_dim=mlp_dim)
        self.MLP_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_class))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        encoded = self.transformer(x)
        # Only the token at position 0 (the CLS token) feeds the head.
        return self.MLP_head(encoded[:, 0])

The following is the **CNN** used to extract features from the image patches.

In [13]:
class ResNet18(nn.Module):
    """torchvision ResNet-18 adapted to 4-channel inputs, plus LayerNorm and a linear projection.

    The stem is replaced by a 3x3, stride-1 convolution on 4 channels and the
    max-pool is disabled. The backbone's 1000-dimensional output is
    layer-normalized and projected to ``num_classes`` features.
    """

    def __init__(self, num_classes):
        super().__init__()

        backbone = models.resnet18(weights=None)  # randomly initialised
        backbone.conv1 = nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()
        self.resnet18 = backbone
        self.ln = nn.LayerNorm(1000)
        self.fc1 = nn.Linear(1000, num_classes)

    def forward(self, x):
        features = self.resnet18(x)
        return self.fc1(self.ln(features))

This is the **final multi-modal model**. Unlike a standard ViT, its input tokens are the features extracted from each modality rather than image patches.

For the inputs that carry time information (the Landsat and climate series), I add a learnable positional term before embedding them.

In addition, I add one more token that is a fusion of the four modality features.

In [14]:
class MultiModal(nn.Module):
    """Multi-modal classifier that feeds one token per modality to a transformer.

    ``forward`` expects ``x = [meta, image, landsat, climate]`` with feature
    widths 31, (4-channel image), 504 and 931, as fixed by the layers below.
    Each modality is embedded into 200 features. A fifth token is the embedding
    of the four concatenated modality features. Together with a learnable CLS
    token they form a 6-token sequence. The output is the raw logits
    (``num_classes`` per sample) computed from the CLS token.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.cls = nn.Parameter(torch.randn(1, 1, 200))
        self.meta = Embedding(31,200)
        self.resnet18 = ResNet18(200)
        self.landsat = Embedding(504,200)
        # Learnable offsets added to the raw time-series inputs.
        self.position_landsat = nn.Parameter(torch.randn(1, 504))
        self.climate = Embedding(931,200)
        self.position_climate = nn.Parameter(torch.randn(1, 931))
        self.emb = Embedding(800,200)
        self.position_combine = nn.Parameter(torch.randn(1, 800))
        self.vit = ViT(200, 2, 200, 400, num_classes)
        # One learnable position per token (CLS + 4 modalities + fused).
        self.position = nn.Parameter(torch.randn(1, 6, 200))

    def forward(self, x):
        batch = x[0].size(0)
        # self.cls is a parameter and already lives on the module's device, so
        # no module-level `device` variable is needed. expand() broadcasts it
        # over the batch without copying.
        cls_token = self.cls.expand(batch, -1, -1)

        meta = self.meta(x[0])
        img = self.resnet18(x[1])
        landsat = self.landsat(x[2]+self.position_landsat)
        climate = self.climate(x[3]+self.position_climate)

        # Fused token: embedding of the four modality features side by side.
        combine = torch.cat((meta, img, landsat, climate), dim=1)
        fused = self.emb(combine+self.position_combine)

        # (batch, 5, 200) feature tokens, preceded by the CLS token.
        features = torch.stack((meta, img, landsat, climate, fused), dim=1)
        token = torch.cat((cls_token, features), dim=1)
        return self.vit(token+self.position)

In [15]:
# Seed torch's global RNG before any parameters are created so that the
# initial weights are repeatable (INITIAL_SEED was otherwise never applied).
torch.manual_seed(INITIAL_SEED)

# Check if cuda is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("DEVICE = CUDA")

model = MultiModal(num_classes).to(device)
# Uncomment to resume from saved weights instead of training from scratch:
# model.load_state_dict(torch.load("models/multimodal-vit/Model.pth", map_location=device))

In [16]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# `verbose` is deprecated on PyTorch LR schedulers, so it is no longer passed;
# the training loop prints scheduler.state_dict() after every epoch instead.
# T_max=25 is the number of scheduler steps (epochs) from the initial lr to its minimum.
scheduler = CosineAnnealingLR(optimizer, T_max=25)



In [17]:
print(f"Training for {num_epochs} epochs started.")

# Built once instead of once per batch. pos_weight only scales the loss term of
# positive targets, so a constant per-class weight equal to
# positive_weigh_factor gives the same loss as the former per-batch
# `labels * positive_weigh_factor` for 0/1 labels.
criterion = torch.nn.BCEWithLogitsLoss(
    pos_weight=torch.full((num_classes,), positive_weigh_factor, device=device)
)

for epoch in range(num_epochs):
    model.train()

    # survey_id is yielded by the dataset but is not needed for the update.
    for batch_idx, (sample, survey_id, labels) in enumerate(train_loader):
        samples = [tensor.to(device) for tensor in sample]
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(samples)  # raw logits

        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Log every 278th batch.
        if batch_idx % 278 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item()}")

    # The learning-rate schedule advances once per epoch.
    scheduler.step()
    print("Scheduler:",scheduler.state_dict())

# Save the trained model
model.eval()
torch.save(model.state_dict(), "multimodal-model.pth")

Training for 10 epochs started.


In [None]:
# Put the model in inference mode here too, so this cell gives the same result
# whether the model was just trained or was loaded from a checkpoint.
model.eval()

with torch.no_grad():
    surveys = []      # survey ids, in loader order
    top_indices = []  # one array of per-species probabilities per survey
    for data, surveyID in tqdm.tqdm(test_loader, total=len(test_loader)):

        data = [tensor.to(device) for tensor in data]

        outputs = model(data)
        predictions = torch.sigmoid(outputs).cpu().numpy()
        # predictions has one row per survey in the batch; extend() adds each
        # row separately, so any test batch size works (not only 1).
        top_indices.extend(predictions)
        surveys.extend(surveyID.cpu().numpy())

100%|██████████| 4716/4716 [01:50<00:00, 42.64it/s]


In [None]:
# with torch.no_grad():
#     surveys = []
#     top_indices = []
#     for data, surveyID in tqdm.tqdm(test_loader, total=len(test_loader)):

#         data = [tensor.to(device) for tensor in data]

#         outputs = model(data)
#         predictions = torch.sigmoid(outputs).cpu().numpy()
#         predictions = np.squeeze(predictions)
#         prediction = np.argwhere(predictions>=0.95).flatten()
#         top_indices.append(prediction)
#         surveys.extend(surveyID.cpu().numpy())

In [None]:
# One row per test survey (loader order), one "speciesId_<k>" column per class.
predictions_frame = pd.DataFrame(top_indices).add_prefix("speciesId_")
predictions_frame.to_pickle("vit.pkl")

In [None]:
# data_concatenated = [' '.join(map(str, row)) for row in top_indices]

# pd.DataFrame(
#     {'surveyId': surveys,
#      'predictions': data_concatenated,
#     }).to_csv("submission3.csv", index = False)