# This is an INFERENCE using the model combine the method of MLP, CNN and Transformer

Thanks PICEKL for his baseline, which help me a lot.

The main idea of this method is to obtain multiple features from multiple dimensions. We simply divide it into four types: **image information** with spatial characteristics, **meta information** that is independent of each other, **climate information** and **satellite information** with time characteristics.

In [1]:
import os
import torch
from tqdm import tqdm
import numpy as np
import pandas as pd
import torchvision.models as models
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from pathlib import Path
import time

In [2]:
new_dir = Path("output") / time.strftime('%Y-%m-%d_%H%M', time.localtime())
new_dir.mkdir(parents=True, exist_ok=True)

In [3]:
num_classes = 11255
num_epochs = 100
seed = 113

When reading data, different **fusion** and **normalization** will be performed for different types of data.

In [4]:
torch.manual_seed(seed)

<torch._C.Generator at 0x750d98237cf0>

In [5]:
class TrainDataset(Dataset):
    def __init__(self, metadata, subset, transform=None):
        self.subset = subset
        self.transform = transform
        self.metadata = metadata

        labels = self.metadata[['surveyId' ,'speciesId']].astype(int).copy()
        self.label_dict = labels.groupby('surveyId')['speciesId'].apply(list).to_dict()

        self.metadata = self.metadata.drop_duplicates(subset="surveyId").reset_index(drop=True)
        self.metadata.fillna(0,inplace=True)
        self.metadata.replace({float('-inf'): 0}, inplace=True)
        self.metadata_data = self.Norm(self.metadata.iloc[:,:5])

        self.merge_key = 'surveyId'
        self.climate_average = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Climate/Average 1981-2010/GLC24-PA-train-bioclimatic.csv")
        self.climate_monthly = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Climate/Monthly/GLC24-PA-train-bioclimatic_monthly.csv")
        self.climate = pd.merge(self.climate_average, self.climate_monthly, on=self.merge_key)
        self.climate.fillna(self.climate.mean(),inplace=True)
        self.climate_data = self.Norm_all(self.climate)

        self.landsat_b = pd.read_csv("data/geolifeclef-2024/PA-train-landsat_time_series/GLC24-PA-train-landsat_time_series-blue.csv")
        self.landsat_b.fillna(self.landsat_b.mean(),inplace=True)
        self.landsat_g = pd.read_csv("data/geolifeclef-2024/PA-train-landsat_time_series/GLC24-PA-train-landsat_time_series-green.csv")
        self.landsat_g.fillna(self.landsat_g.mean(),inplace=True)
        self.landsat_r = pd.read_csv("data/geolifeclef-2024/PA-train-landsat_time_series/GLC24-PA-train-landsat_time_series-red.csv")
        self.landsat_r.fillna(self.landsat_r.mean(),inplace=True)
        self.landsat_n = pd.read_csv("data/geolifeclef-2024/PA-train-landsat_time_series/GLC24-PA-train-landsat_time_series-nir.csv")
        self.landsat_n.fillna(self.landsat_n.mean(),inplace=True)
        self.landsat_s1 = pd.read_csv("data/geolifeclef-2024/PA-train-landsat_time_series/GLC24-PA-train-landsat_time_series-swir1.csv")
        self.landsat_s1.fillna(self.landsat_s1.mean(),inplace=True)
        self.landsat_s2 = pd.read_csv("data/geolifeclef-2024/PA-train-landsat_time_series/GLC24-PA-train-landsat_time_series-swir2.csv")
        self.landsat_s2.fillna(self.landsat_s2.mean(),inplace=True)
        self.landsat_data = torch.cat([self.Norm_all(self.landsat_b),self.Norm_all(self.landsat_g),self.Norm_all(self.landsat_r),self.Norm_all(self.landsat_n),self.Norm_all(self.landsat_s1),self.Norm_all(self.landsat_s2)],axis=1)

        self.elevation = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Elevation/GLC24-PA-train-elevation.csv")
        self.elevation[self.elevation<0]=0
        self.elevation.fillna(self.elevation.mean(),inplace=True)
        self.elevation_data = self.Norm(self.elevation)

        self.human_footprint = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Human Footprint/GLC24-PA-train-human_footprint.csv")
        self.human_footprint[self.human_footprint<0]=0
        self.human_footprint.fillna(self.human_footprint.mean(),inplace=True)
        self.human_footprint_data = self.Norm(self.human_footprint)

        self.landcover = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/LandCover/GLC24-PA-train-landcover.csv")
        self.landcover[self.landcover<0]=0
        self.landcover.fillna(self.landcover.mean(),inplace=True)
        self.landcover_data = self.Norm(self.landcover)

        self.soilgrids = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/SoilGrids/GLC24-PA-train-soilgrids.csv")
        self.soilgrids[self.soilgrids<0]=0
        self.soilgrids.fillna(self.soilgrids.mean(),inplace=True)
        self.soilgrids_data = self.Norm(self.soilgrids)

        self.metadata_data = torch.cat((self.metadata_data, self.elevation_data, self.human_footprint_data, self.landcover_data, self.soilgrids_data), dim=1)

    def Norm(self,df):
        output=torch.from_numpy(df.iloc[:,1:].values).float()
        return (output-output.mean(dim=0))/output.std(dim=0)

    def Norm_all(self,df):
        output=torch.from_numpy(df.iloc[:,1:].values).float()
        return (output-output.mean())/output.std()

    def patch_rgb_path(self,survey_id):
        path = "data/geolifeclef-2024/PA_Train_SatellitePatches_RGB/pa_train_patches_rgb"
        for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
            path = os.path.join(path, d)
        path = os.path.join(path, f"{survey_id}.jpeg")
        return path

    def patch_nir_path(self,survey_id):
        path = "data/geolifeclef-2024/PA_Train_SatellitePatches_NIR/pa_train_patches_nir"
        for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
            path = os.path.join(path, d)
        path = os.path.join(path, f"{survey_id}.jpeg")
        return path

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):

        survey_id = self.metadata.surveyId[idx]

        image_path = self.patch_rgb_path(survey_id)
        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)
        image = image.unsqueeze(0)
        image_nir_path = self.patch_nir_path(survey_id)
        nir_image = Image.open(image_nir_path).convert("L")
        nir_image = self.transform(nir_image)
        nir_image = nir_image.unsqueeze(0)
        image_data = torch.cat([image,nir_image],dim=1)
        image_data = torch.squeeze(image_data)
        sample=[self.metadata_data[idx,:],image_data,self.landsat_data[idx,:],self.climate_data[idx,:]]
        species_ids = self.label_dict[survey_id]  # Get list of species IDs for the survey ID
        label = torch.zeros(num_classes)  # Initialize label tensor
        for species_id in species_ids:
            label[species_id] = 1  # Set the corresponding class index to 1 for each species ID
        count = len(species_ids)
        return sample, survey_id, label, count#, species_ids

In [6]:
class TestDataset(Dataset):
    def __init__(self, metadata, subset, transform=None):
        self.subset = subset
        self.transform = transform
        self.metadata = metadata

        self.metadata = self.metadata.drop_duplicates(subset="surveyId").reset_index(drop=True)
        self.metadata.fillna(0,inplace=True)
        self.metadata.replace({float('-inf'): 0}, inplace=True)
        self.metadata_data = self.Norm(self.metadata.iloc[:,:5])

        self.merge_key = 'surveyId'
        self.climate_average = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Climate/Average 1981-2010/GLC24-PA-test-bioclimatic.csv")
        self.climate_monthly = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Climate/Monthly/GLC24-PA-test-bioclimatic_monthly.csv")
        self.climate = pd.merge(self.climate_average, self.climate_monthly, on=self.merge_key)
        self.climate.fillna(self.climate.mean(),inplace=True)
        self.climate_data = self.Norm_all(self.climate)

        self.landsat_b = pd.read_csv("data/geolifeclef-2024/PA-test-landsat_time_series/GLC24-PA-test-landsat_time_series-blue.csv")
        self.landsat_b.fillna(self.landsat_b.mean(),inplace=True)
        self.landsat_g = pd.read_csv("data/geolifeclef-2024/PA-test-landsat_time_series/GLC24-PA-test-landsat_time_series-green.csv")
        self.landsat_g.fillna(self.landsat_g.mean(),inplace=True)
        self.landsat_r = pd.read_csv("data/geolifeclef-2024/PA-test-landsat_time_series/GLC24-PA-test-landsat_time_series-red.csv")
        self.landsat_r.fillna(self.landsat_r.mean(),inplace=True)
        self.landsat_n = pd.read_csv("data/geolifeclef-2024/PA-test-landsat_time_series/GLC24-PA-test-landsat_time_series-nir.csv")
        self.landsat_n.fillna(self.landsat_n.mean(),inplace=True)
        self.landsat_s1 = pd.read_csv("data/geolifeclef-2024/PA-test-landsat_time_series/GLC24-PA-test-landsat_time_series-swir1.csv")
        self.landsat_s1.fillna(self.landsat_s1.mean(),inplace=True)
        self.landsat_s2 = pd.read_csv("data/geolifeclef-2024/PA-test-landsat_time_series/GLC24-PA-test-landsat_time_series-swir2.csv")
        self.landsat_s2.fillna(self.landsat_s2.mean(),inplace=True)
        self.landsat_data = torch.cat([self.Norm_all(self.landsat_b),self.Norm_all(self.landsat_g),self.Norm_all(self.landsat_r),self.Norm_all(self.landsat_n),self.Norm_all(self.landsat_s1),self.Norm_all(self.landsat_s2)],axis=1)

        self.elevation = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Elevation/GLC24-PA-test-elevation.csv")
        self.elevation[self.elevation<0]=0
        self.elevation.fillna(self.elevation.mean(),inplace=True)
        self.elevation_data = self.Norm(self.elevation)

        self.human_footprint = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Human Footprint/GLC24-PA-test-human_footprint.csv")
        self.human_footprint[self.human_footprint<0]=0
        self.human_footprint.fillna(self.human_footprint.mean(),inplace=True)
        self.human_footprint_data = self.Norm(self.human_footprint)

        self.landcover = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/LandCover/GLC24-PA-test-landcover.csv")
        self.landcover[self.landcover<0]=0
        self.landcover.fillna(self.landcover.mean(),inplace=True)
        self.landcover_data = self.Norm(self.landcover)

        self.soilgrids = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/SoilGrids/GLC24-PA-test-soilgrids.csv")
        self.soilgrids[self.soilgrids<0]=0
        self.soilgrids.fillna(self.soilgrids.mean(),inplace=True)
        self.soilgrids_data = self.Norm(self.soilgrids)

        self.metadata_data = torch.cat((self.metadata_data, self.elevation_data, self.human_footprint_data, self.landcover_data, self.soilgrids_data), dim=1)

    def Norm(self,df):
        output=torch.from_numpy(df.iloc[:,1:].values).float()
        return (output-output.mean(dim=0))/output.std(dim=0)

    def Norm_all(self,df):
        output=torch.from_numpy(df.iloc[:,1:].values).float()
        return (output-output.mean())/output.std()

    def patch_rgb_path(self,survey_id):
        path = "data/geolifeclef-2024/PA_Test_SatellitePatches_RGB/pa_test_patches_rgb"
        for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
            path = os.path.join(path, d)
        path = os.path.join(path, f"{survey_id}.jpeg")
        return path

    def patch_nir_path(self,survey_id):
        path = "data/geolifeclef-2024/PA_Test_SatellitePatches_NIR/pa_test_patches_nir"
        for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
            path = os.path.join(path, d)
        path = os.path.join(path, f"{survey_id}.jpeg")
        return path

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):

        survey_id = self.metadata.surveyId[idx]

        image_path = self.patch_rgb_path(survey_id)
        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)
        image = image.unsqueeze(0)
        image_nir_path = self.patch_nir_path(survey_id)
        nir_image = Image.open(image_nir_path).convert("L")
        nir_image = self.transform(nir_image)
        nir_image = nir_image.unsqueeze(0)
        image_data = torch.cat([image,nir_image],dim=1)
        image_data = torch.squeeze(image_data)
        sample=[self.metadata_data[idx,:],image_data,self.landsat_data[idx,:],self.climate_data[idx,:]]
        return sample, survey_id

In [7]:
# Dataset and DataLoader
train_batch_size = 64
test_batch_size = 1

transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

# Load training data
train_metadata_path = "data/geolifeclef-2024/GLC24_PA_metadata_train.csv"
train_metadata = pd.read_csv(train_metadata_path)
train_dataset = TrainDataset(train_metadata, subset="train", transform=transform)
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=False, num_workers=1)

# Load testing data
test_metadata_path = "data/geolifeclef-2024/GLC24_PA_metadata_test.csv"
test_metadata = pd.read_csv(test_metadata_path)
test_dataset = TestDataset(test_metadata, subset="test", transform=transform)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=1)

This is an MLP used to extract features from generally independent information.

In [8]:
class Embedding(nn.Module):
    def __init__(self, dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, out_dim*5)
        self.fc2 = nn.Linear(out_dim*5, out_dim)
        self.norm = nn.LayerNorm(out_dim*5)

    def forward(self, x):
        x = F.tanh(self.fc1(x))
        x = self.norm(x)
        x = self.fc2(x)
        return x

The following is the part of ViT, which can also be considered as the Encoder part of **Transformer**.

In [9]:
class Multihead_self_attention(nn.Module):
    def __init__(self, heads, head_dim, dim):
        super().__init__()
        self.head_dim = head_dim
        self.heads = heads
        self.inner_dim = self.heads*self.head_dim
        self.scale = self.head_dim**-0.5
        self.to_qkv = nn.Linear(dim, self.inner_dim*3)
        self.to_output = nn.Linear(self.inner_dim, dim)
        self.norm = nn.LayerNorm(dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        Q, K, V = map(lambda t: rearrange(t, 'b l (h dim) -> b h l dim', dim=self.head_dim), qkv)
        K_T = K.transpose(-1, -2)
        att_score = Q@K_T*self.scale
        att = self.softmax(att_score)
        out = att@V   # (B,H,L,dim)
        out = rearrange(out, 'b h l dim -> b l (h dim)')
        output = self.to_output(out)
        return output

In [10]:
class FeedForward(nn.Module):
    def __init__(self, dim, mlp_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        x = F.gelu(self.fc1(x))
        x = self.fc2(x)
        return x

In [11]:
class Transformer_block(nn.Module):
    def __init__(self, dim, heads, head_dim, mlp_dim):
        super().__init__()
        self.MHA = Multihead_self_attention(heads=heads, head_dim=head_dim, dim=dim)
        self.FeedForward = FeedForward(dim=dim, mlp_dim=mlp_dim)

    def forward(self, x):
        x = self.MHA(x)+x
        x = self.FeedForward(x)+x
        return x

In [12]:
class ViT(nn.Module):
    def __init__(self, dim, heads, head_dim, mlp_dim, num_class):
        super().__init__()
        self.transformer = Transformer_block(dim=dim, heads=heads, head_dim=head_dim, mlp_dim=mlp_dim)

        self.MLP_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_class)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.transformer(x)
        CLS_token = x[:, 0, :]
        out = self.MLP_head(CLS_token)
        return out

The following is **CNN**, used to extract feature information from images.

In [13]:
class ResNet18(nn.Module):
    def __init__(self, num_classes):
        super(ResNet18, self).__init__()

        self.resnet18 = models.resnet18(weights=None)
        self.resnet18.conv1 = nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.resnet18.maxpool = nn.Identity()
        self.ln = nn.LayerNorm(1000)
        self.fc1 = nn.Linear(1000, num_classes)

    def forward(self, x):
        x = self.resnet18(x)
        x = self.ln(x)
        x = self.fc1(x)
        return x

This is the **final multi-modal model**. Unlike ViT, its input is features extracted from each dimension. 

For features with time information, I added additional position information. 

At the same time, I also added a feature that is a fusion of the first four features.

In [14]:
class MultiModal(nn.Module):
    def __init__(self, num_classes):
        super(MultiModal, self).__init__()
        self.cls = nn.Parameter(torch.randn(1, 1, 200))
        self.meta = Embedding(31,200)
        self.resnet18 = ResNet18(200)
        self.landsat = Embedding(504,200)
        self.position_landsat = nn.Parameter(torch.randn(1, 504))
        self.climate = Embedding(931,200)
        self.position_climate = nn.Parameter(torch.randn(1, 931))
        self.emb = Embedding(800,200)
        self.position_combine = nn.Parameter(torch.randn(1, 800))
        self.vit = ViT(200, 2, 200, 400, num_classes)
        self.position = nn.Parameter(torch.randn(1, 6, 200))

    def forward(self, x):
        batch = x[0].size(0)
        CLS = repeat(self.cls, '1 1 d -> b 1 d', b=batch).to(device)
        META = self.meta(x[0])
        IMG = self.resnet18(x[1])
        LANDSAT = self.landsat(x[2]+self.position_landsat)
        CLIMATE = self.climate(x[3]+self.position_climate)
        combine = torch.cat((META, IMG, LANDSAT, CLIMATE), dim=1)
        COMBINE = self.emb(combine+self.position_combine)
        token = torch.concat((CLS, META.unsqueeze(1)), dim=1)
        token = torch.concat((token, IMG.unsqueeze(1)), dim=1)
        token = torch.concat((token, LANDSAT.unsqueeze(1)), dim=1)
        token = torch.concat((token, CLIMATE.unsqueeze(1)), dim=1)
        token = torch.concat((token, COMBINE.unsqueeze(1)), dim=1)
        out = self.vit(token+self.position)
        return out

In [15]:
# Check if cuda is available
device = torch.device("cpu")

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("DEVICE = CUDA")

model = MultiModal(num_classes).to(device)
# model.load_state_dict(torch.load("models/vit/Model.pth", map_location=device))

DEVICE = CUDA


## Train

In [16]:
from torch.optim.lr_scheduler import CosineAnnealingLR
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00025)
# logger.info("Optimizer: AdamW")
scheduler = CosineAnnealingLR(optimizer, T_max=25, verbose=True)
# logger.info("Scheduler: CosineAnnealingLR")



In [17]:
print(f"Training for {num_epochs} epochs started.")

for epoch in tqdm(range(num_epochs)):
    model.train()

    for batch_idx, (sample, survey_id, labels, count) in enumerate(train_loader):
        samples = [tensor.to(device) for tensor in sample]
        survey_id = survey_id.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # with torch.set_grad_enabled(True):
        outputs = model(samples)

        pos_weight = labels*1.0  # All positive weights are equal to 10
        criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        if batch_idx % (len(train_loader)//10) == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item()}")

    if epoch % 10 == 0:
        torch.save(model.state_dict(), new_dir / f"multimodal-epoch-{epoch}.pth")

    scheduler.step()
torch.save(model.state_dict(), new_dir / "multimodal-final.pth")

Training for 100 epochs started.


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1/100, Batch 0/1391, Loss: 0.7349907159805298
Epoch 1/100, Batch 139/1391, Loss: 0.029572946950793266
Epoch 1/100, Batch 278/1391, Loss: 0.014072494581341743
Epoch 1/100, Batch 417/1391, Loss: 0.010163888335227966
Epoch 1/100, Batch 556/1391, Loss: 0.00869766902178526
Epoch 1/100, Batch 695/1391, Loss: 0.007862644270062447
Epoch 1/100, Batch 834/1391, Loss: 0.008331937715411186
Epoch 1/100, Batch 973/1391, Loss: 0.007483727764338255
Epoch 1/100, Batch 1112/1391, Loss: 0.006611417513340712
Epoch 1/100, Batch 1251/1391, Loss: 0.006771528162062168


  1%|          | 1/100 [04:20<7:09:55, 260.56s/it]

Epoch 1/100, Batch 1390/1391, Loss: 0.005203625187277794
Epoch 2/100, Batch 0/1391, Loss: 0.005556628108024597
Epoch 2/100, Batch 139/1391, Loss: 0.005286202300339937
Epoch 2/100, Batch 278/1391, Loss: 0.0059546516276896
Epoch 2/100, Batch 417/1391, Loss: 0.00588897755369544
Epoch 2/100, Batch 556/1391, Loss: 0.005632961168885231
Epoch 2/100, Batch 695/1391, Loss: 0.005461188033223152
Epoch 2/100, Batch 834/1391, Loss: 0.0060033234767615795
Epoch 2/100, Batch 973/1391, Loss: 0.005516052711755037
Epoch 2/100, Batch 1112/1391, Loss: 0.0049687596037983894
Epoch 2/100, Batch 1251/1391, Loss: 0.005389644298702478


  2%|▏         | 2/100 [08:42<7:07:17, 261.60s/it]

Epoch 2/100, Batch 1390/1391, Loss: 0.004094450268894434
Epoch 3/100, Batch 0/1391, Loss: 0.0044764368794858456
Epoch 3/100, Batch 139/1391, Loss: 0.004302266053855419
Epoch 3/100, Batch 278/1391, Loss: 0.004993549082428217
Epoch 3/100, Batch 417/1391, Loss: 0.0051199281588196754
Epoch 3/100, Batch 556/1391, Loss: 0.004981357604265213
Epoch 3/100, Batch 695/1391, Loss: 0.0048621646128594875
Epoch 3/100, Batch 834/1391, Loss: 0.005269576329737902
Epoch 3/100, Batch 973/1391, Loss: 0.004949171096086502
Epoch 3/100, Batch 1112/1391, Loss: 0.004477598238736391
Epoch 3/100, Batch 1251/1391, Loss: 0.004935792647302151


  3%|▎         | 3/100 [13:05<7:03:31, 261.97s/it]

Epoch 3/100, Batch 1390/1391, Loss: 0.0037795782554894686
Epoch 4/100, Batch 0/1391, Loss: 0.004124871920794249
Epoch 4/100, Batch 139/1391, Loss: 0.003960103262215853
Epoch 4/100, Batch 278/1391, Loss: 0.004616585560142994
Epoch 4/100, Batch 417/1391, Loss: 0.004743389785289764
Epoch 4/100, Batch 556/1391, Loss: 0.004715072922408581
Epoch 4/100, Batch 695/1391, Loss: 0.004523777402937412
Epoch 4/100, Batch 834/1391, Loss: 0.004976877011358738
Epoch 4/100, Batch 973/1391, Loss: 0.00461329473182559
Epoch 4/100, Batch 1112/1391, Loss: 0.004278188105672598
Epoch 4/100, Batch 1251/1391, Loss: 0.004747818224132061


  4%|▍         | 4/100 [17:28<6:59:49, 262.39s/it]

Epoch 4/100, Batch 1390/1391, Loss: 0.0036199241876602173
Epoch 5/100, Batch 0/1391, Loss: 0.003939446993172169
Epoch 5/100, Batch 139/1391, Loss: 0.0037916458677500486
Epoch 5/100, Batch 278/1391, Loss: 0.004410835914313793
Epoch 5/100, Batch 417/1391, Loss: 0.004545021802186966
Epoch 5/100, Batch 556/1391, Loss: 0.0045617795549333096
Epoch 5/100, Batch 695/1391, Loss: 0.004355668090283871
Epoch 5/100, Batch 834/1391, Loss: 0.004785933066159487
Epoch 5/100, Batch 973/1391, Loss: 0.004430574364960194
Epoch 5/100, Batch 1112/1391, Loss: 0.004096958786249161
Epoch 5/100, Batch 1251/1391, Loss: 0.0046019237488508224


  5%|▌         | 5/100 [21:50<6:55:33, 262.46s/it]

Epoch 5/100, Batch 1390/1391, Loss: 0.0035346036311239004
Epoch 6/100, Batch 0/1391, Loss: 0.0037472706753760576
Epoch 6/100, Batch 139/1391, Loss: 0.0036977191921323538
Epoch 6/100, Batch 278/1391, Loss: 0.0042938655242323875
Epoch 6/100, Batch 417/1391, Loss: 0.004393143579363823
Epoch 6/100, Batch 556/1391, Loss: 0.004423927515745163
Epoch 6/100, Batch 695/1391, Loss: 0.0042361607775092125
Epoch 6/100, Batch 834/1391, Loss: 0.004649145528674126
Epoch 6/100, Batch 973/1391, Loss: 0.004380247090011835
Epoch 6/100, Batch 1112/1391, Loss: 0.004043597728013992
Epoch 6/100, Batch 1251/1391, Loss: 0.004457805771380663


  6%|▌         | 6/100 [26:13<6:51:15, 262.50s/it]

Epoch 6/100, Batch 1390/1391, Loss: 0.0034150693099945784
Epoch 7/100, Batch 0/1391, Loss: 0.003647738602012396
Epoch 7/100, Batch 139/1391, Loss: 0.0036405082792043686
Epoch 7/100, Batch 278/1391, Loss: 0.0042239646427333355
Epoch 7/100, Batch 417/1391, Loss: 0.0042792572639882565
Epoch 7/100, Batch 556/1391, Loss: 0.004339543636888266
Epoch 7/100, Batch 695/1391, Loss: 0.004147310741245747
Epoch 7/100, Batch 834/1391, Loss: 0.004526773001998663
Epoch 7/100, Batch 973/1391, Loss: 0.004238068591803312
Epoch 7/100, Batch 1112/1391, Loss: 0.003913720138370991
Epoch 7/100, Batch 1251/1391, Loss: 0.004436977207660675


  7%|▋         | 7/100 [30:36<6:46:54, 262.52s/it]

Epoch 7/100, Batch 1390/1391, Loss: 0.0033450121991336346
Epoch 8/100, Batch 0/1391, Loss: 0.0035453469026833773
Epoch 8/100, Batch 139/1391, Loss: 0.003624115139245987
Epoch 8/100, Batch 278/1391, Loss: 0.004147933330386877
Epoch 8/100, Batch 417/1391, Loss: 0.004192279651761055
Epoch 8/100, Batch 556/1391, Loss: 0.0042151776142418385
Epoch 8/100, Batch 695/1391, Loss: 0.004130037501454353
Epoch 8/100, Batch 834/1391, Loss: 0.004439074080437422
Epoch 8/100, Batch 973/1391, Loss: 0.00412951223552227
Epoch 8/100, Batch 1112/1391, Loss: 0.003863994497805834
Epoch 8/100, Batch 1251/1391, Loss: 0.0043493956327438354


  8%|▊         | 8/100 [34:58<6:42:33, 262.54s/it]

Epoch 8/100, Batch 1390/1391, Loss: 0.003260658122599125
Epoch 9/100, Batch 0/1391, Loss: 0.003475713077932596
Epoch 9/100, Batch 139/1391, Loss: 0.0035386730451136827
Epoch 9/100, Batch 278/1391, Loss: 0.004021951928734779
Epoch 9/100, Batch 417/1391, Loss: 0.004149140324443579
Epoch 9/100, Batch 556/1391, Loss: 0.004132898524403572
Epoch 9/100, Batch 695/1391, Loss: 0.004038461484014988
Epoch 9/100, Batch 834/1391, Loss: 0.004364408552646637
Epoch 9/100, Batch 973/1391, Loss: 0.0040654209442436695
Epoch 9/100, Batch 1112/1391, Loss: 0.0038547220174223185
Epoch 9/100, Batch 1251/1391, Loss: 0.004263853654265404


  9%|▉         | 9/100 [39:21<6:38:14, 262.57s/it]

Epoch 9/100, Batch 1390/1391, Loss: 0.0031729491893202066
Epoch 10/100, Batch 0/1391, Loss: 0.0034426318015903234
Epoch 10/100, Batch 139/1391, Loss: 0.003508793655782938
Epoch 10/100, Batch 278/1391, Loss: 0.003971154801547527
Epoch 10/100, Batch 417/1391, Loss: 0.004087402950972319
Epoch 10/100, Batch 556/1391, Loss: 0.004039234481751919
Epoch 10/100, Batch 695/1391, Loss: 0.003984333481639624
Epoch 10/100, Batch 834/1391, Loss: 0.004338400438427925
Epoch 10/100, Batch 973/1391, Loss: 0.004050201270729303
Epoch 10/100, Batch 1112/1391, Loss: 0.003766178386285901
Epoch 10/100, Batch 1251/1391, Loss: 0.004228893201798201


 10%|█         | 10/100 [43:43<6:33:52, 262.59s/it]

Epoch 10/100, Batch 1390/1391, Loss: 0.0030899227131158113
Epoch 11/100, Batch 0/1391, Loss: 0.003362937830388546
Epoch 11/100, Batch 139/1391, Loss: 0.0034604379907250404
Epoch 11/100, Batch 278/1391, Loss: 0.0038739352021366358
Epoch 11/100, Batch 417/1391, Loss: 0.004036239348351955
Epoch 11/100, Batch 556/1391, Loss: 0.00397328520193696
Epoch 11/100, Batch 695/1391, Loss: 0.003963307477533817
Epoch 11/100, Batch 834/1391, Loss: 0.004276209510862827
Epoch 11/100, Batch 973/1391, Loss: 0.003950925078243017
Epoch 11/100, Batch 1112/1391, Loss: 0.003739011473953724
Epoch 11/100, Batch 1251/1391, Loss: 0.004222269169986248


 11%|█         | 11/100 [48:06<6:29:31, 262.60s/it]

Epoch 11/100, Batch 1390/1391, Loss: 0.003049023449420929
Epoch 12/100, Batch 0/1391, Loss: 0.003368351375684142
Epoch 12/100, Batch 139/1391, Loss: 0.0034302978310734034
Epoch 12/100, Batch 278/1391, Loss: 0.003785313107073307
Epoch 12/100, Batch 417/1391, Loss: 0.003938687965273857
Epoch 12/100, Batch 556/1391, Loss: 0.003916731104254723
Epoch 12/100, Batch 695/1391, Loss: 0.00389113905839622
Epoch 12/100, Batch 834/1391, Loss: 0.004186142235994339
Epoch 12/100, Batch 973/1391, Loss: 0.003908249083906412
Epoch 12/100, Batch 1112/1391, Loss: 0.0036769784055650234
Epoch 12/100, Batch 1251/1391, Loss: 0.004078512545675039


 12%|█▏        | 12/100 [52:28<6:25:04, 262.55s/it]

Epoch 12/100, Batch 1390/1391, Loss: 0.002997560193762183
Epoch 13/100, Batch 0/1391, Loss: 0.0033331194426864386
Epoch 13/100, Batch 139/1391, Loss: 0.003402861300855875
Epoch 13/100, Batch 278/1391, Loss: 0.0037263906560838223
Epoch 13/100, Batch 417/1391, Loss: 0.003896048991009593
Epoch 13/100, Batch 556/1391, Loss: 0.003865809878334403
Epoch 13/100, Batch 695/1391, Loss: 0.003775780089199543
Epoch 13/100, Batch 834/1391, Loss: 0.004128935746848583
Epoch 13/100, Batch 973/1391, Loss: 0.0038872857112437487
Epoch 13/100, Batch 1112/1391, Loss: 0.003653877414762974
Epoch 13/100, Batch 1251/1391, Loss: 0.004028670955449343


 13%|█▎        | 13/100 [56:51<6:20:41, 262.54s/it]

Epoch 13/100, Batch 1390/1391, Loss: 0.0029009757563471794
Epoch 14/100, Batch 0/1391, Loss: 0.003278088290244341
Epoch 14/100, Batch 139/1391, Loss: 0.003376986365765333
Epoch 14/100, Batch 278/1391, Loss: 0.0036892902571707964
Epoch 14/100, Batch 417/1391, Loss: 0.003848089836537838
Epoch 14/100, Batch 556/1391, Loss: 0.003819767851382494
Epoch 14/100, Batch 695/1391, Loss: 0.0037171272560954094
Epoch 14/100, Batch 834/1391, Loss: 0.004103981424123049
Epoch 14/100, Batch 973/1391, Loss: 0.003864023834466934
Epoch 14/100, Batch 1112/1391, Loss: 0.003626142628490925
Epoch 14/100, Batch 1251/1391, Loss: 0.003987183328717947


 14%|█▍        | 14/100 [1:01:13<6:16:16, 262.52s/it]

Epoch 14/100, Batch 1390/1391, Loss: 0.0028689694590866566
Epoch 15/100, Batch 0/1391, Loss: 0.003259641584008932
Epoch 15/100, Batch 139/1391, Loss: 0.0033505037426948547
Epoch 15/100, Batch 278/1391, Loss: 0.0036287393886595964
Epoch 15/100, Batch 417/1391, Loss: 0.0037722750566899776
Epoch 15/100, Batch 556/1391, Loss: 0.003783648367971182
Epoch 15/100, Batch 695/1391, Loss: 0.003671746701002121
Epoch 15/100, Batch 834/1391, Loss: 0.004053430166095495
Epoch 15/100, Batch 973/1391, Loss: 0.00383047410286963
Epoch 15/100, Batch 1112/1391, Loss: 0.0035825553350150585
Epoch 15/100, Batch 1251/1391, Loss: 0.003939858637750149


 15%|█▌        | 15/100 [1:05:36<6:11:55, 262.54s/it]

Epoch 15/100, Batch 1390/1391, Loss: 0.00284550990909338
Epoch 16/100, Batch 0/1391, Loss: 0.0032043757382780313
Epoch 16/100, Batch 139/1391, Loss: 0.003324377816170454
Epoch 16/100, Batch 278/1391, Loss: 0.0035977463703602552
Epoch 16/100, Batch 417/1391, Loss: 0.0037078089080750942
Epoch 16/100, Batch 556/1391, Loss: 0.0037196732591837645
Epoch 16/100, Batch 695/1391, Loss: 0.0036208799574524164
Epoch 16/100, Batch 834/1391, Loss: 0.004010818433016539
Epoch 16/100, Batch 973/1391, Loss: 0.0038063968531787395
Epoch 16/100, Batch 1112/1391, Loss: 0.003527275286614895
Epoch 16/100, Batch 1251/1391, Loss: 0.003922350239008665


 16%|█▌        | 16/100 [1:09:59<6:07:36, 262.57s/it]

Epoch 16/100, Batch 1390/1391, Loss: 0.002798818051815033
Epoch 17/100, Batch 0/1391, Loss: 0.0031689044553786516
Epoch 17/100, Batch 139/1391, Loss: 0.0032997471280395985
Epoch 17/100, Batch 278/1391, Loss: 0.003566885832697153
Epoch 17/100, Batch 417/1391, Loss: 0.003676369786262512
Epoch 17/100, Batch 556/1391, Loss: 0.003653602907434106
Epoch 17/100, Batch 695/1391, Loss: 0.0035631456412374973
Epoch 17/100, Batch 834/1391, Loss: 0.003972898703068495
Epoch 17/100, Batch 973/1391, Loss: 0.0037647930439561605
Epoch 17/100, Batch 1112/1391, Loss: 0.0034953244030475616
Epoch 17/100, Batch 1251/1391, Loss: 0.003876172238960862


 17%|█▋        | 17/100 [1:14:21<6:03:15, 262.59s/it]

Epoch 17/100, Batch 1390/1391, Loss: 0.002753953682258725
Epoch 18/100, Batch 0/1391, Loss: 0.00313375610858202
Epoch 18/100, Batch 139/1391, Loss: 0.003273540874943137
Epoch 18/100, Batch 278/1391, Loss: 0.00353226438164711
Epoch 18/100, Batch 417/1391, Loss: 0.003660206450149417
Epoch 18/100, Batch 556/1391, Loss: 0.003602575743570924
Epoch 18/100, Batch 695/1391, Loss: 0.0035355279687792063
Epoch 18/100, Batch 834/1391, Loss: 0.00394072663038969
Epoch 18/100, Batch 973/1391, Loss: 0.0037234097253531218
Epoch 18/100, Batch 1112/1391, Loss: 0.0034835324622690678
Epoch 18/100, Batch 1251/1391, Loss: 0.0038301178719848394


 18%|█▊        | 18/100 [1:18:44<5:58:53, 262.61s/it]

Epoch 18/100, Batch 1390/1391, Loss: 0.002724003279581666
Epoch 19/100, Batch 0/1391, Loss: 0.0030956612899899483
Epoch 19/100, Batch 139/1391, Loss: 0.0032647359184920788
Epoch 19/100, Batch 278/1391, Loss: 0.0035039838403463364
Epoch 19/100, Batch 417/1391, Loss: 0.0036346372216939926
Epoch 19/100, Batch 556/1391, Loss: 0.0035827034153044224
Epoch 19/100, Batch 695/1391, Loss: 0.0035240361467003822


## Produce Submission

In [None]:
with torch.no_grad():
    surveys = []
    top_indices = []
    for data, surveyID in tqdm(test_loader, total=len(test_loader)):

        data = [tensor.to(device) for tensor in data]

        outputs = model(data)
        predictions = torch.sigmoid(outputs).cpu().numpy()
        predictions = np.squeeze(predictions)
        # print(predictions)
        prediction = np.argwhere(predictions>=0.25).flatten()
        top_indices.append(prediction)
        surveys.extend(surveyID.cpu().numpy())

100%|██████████| 4716/4716 [00:11<00:00, 419.52it/s]


In [None]:
data_concatenated = [' '.join(map(str, row)) for row in top_indices]

pd.DataFrame(
    {'surveyId': surveys,
     'predictions': data_concatenated,
    }).to_csv(new_dir / "submission.csv", index = False)

In [None]:
total = 0
for row in top_indices:
    total += row.shape[0]

print(f"Total number of predictions: {total} with average of {total/len(top_indices)} predictions per survey ID.")

Total number of predictions: 79292 with average of 16.81340118744699 predictions per survey ID.
