# This is an INFERENCE using the model combine the method of MLP, CNN and Transformer

Thanks PICEKL for his baseline, which help me a lot.

The main idea of this method is to obtain multiple features from multiple dimensions. We simply divide it into four types: **image information** with spatial characteristics, **meta information** that is independent of each other, **climate information** and **satellite information** with time characteristics.

In [49]:
import os
import torch
from tqdm import tqdm
import numpy as np
import pandas as pd
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from PIL import Image
from einops import repeat
from pathlib import Path
import time
from models2 import ViT, MLP, ResNet18

In [50]:
new_dir = Path("output") / time.strftime('%Y-%m-%d_%H%M', time.localtime())
new_dir.mkdir(parents=True, exist_ok=True)

In [51]:
num_classes = 11255
num_epochs = 10
seed = 113

When reading data, different **fusion** and **normalization** will be performed for different types of data.

In [52]:
torch.manual_seed(seed)

<torch._C.Generator at 0x7afe6442bd10>

In [53]:
class TrainDataset(Dataset):
    def __init__(self, metadata, subset, transform=None):
        self.subset = subset
        self.transform = transform
        self.metadata = metadata

        labels = self.metadata[['surveyId' ,'speciesId']].astype(int).copy()
        self.label_dict = labels.groupby('surveyId')['speciesId'].apply(list).to_dict()

        self.metadata = self.metadata.drop_duplicates(subset="surveyId").reset_index(drop=True)
        self.metadata.fillna(0,inplace=True)
        self.metadata.replace({float('-inf'): 0}, inplace=True)
        self.metadata_data = self.Norm(self.metadata.iloc[:,:5])

        self.merge_key = 'surveyId'
        self.climate_average = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Climate/Average 1981-2010/GLC24-PA-train-bioclimatic.csv")
        self.climate_monthly = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Climate/Monthly/GLC24-PA-train-bioclimatic_monthly.csv")
        self.climate = pd.merge(self.climate_average, self.climate_monthly, on=self.merge_key)
        self.climate.fillna(self.climate.mean(),inplace=True)
        self.climate_data = self.Norm_all(self.climate)

        self.landsat_b = pd.read_csv("data/geolifeclef-2024/PA-train-landsat_time_series/GLC24-PA-train-landsat_time_series-blue.csv")
        self.landsat_b.fillna(self.landsat_b.mean(),inplace=True)
        self.landsat_g = pd.read_csv("data/geolifeclef-2024/PA-train-landsat_time_series/GLC24-PA-train-landsat_time_series-green.csv")
        self.landsat_g.fillna(self.landsat_g.mean(),inplace=True)
        self.landsat_r = pd.read_csv("data/geolifeclef-2024/PA-train-landsat_time_series/GLC24-PA-train-landsat_time_series-red.csv")
        self.landsat_r.fillna(self.landsat_r.mean(),inplace=True)
        self.landsat_n = pd.read_csv("data/geolifeclef-2024/PA-train-landsat_time_series/GLC24-PA-train-landsat_time_series-nir.csv")
        self.landsat_n.fillna(self.landsat_n.mean(),inplace=True)
        self.landsat_s1 = pd.read_csv("data/geolifeclef-2024/PA-train-landsat_time_series/GLC24-PA-train-landsat_time_series-swir1.csv")
        self.landsat_s1.fillna(self.landsat_s1.mean(),inplace=True)
        self.landsat_s2 = pd.read_csv("data/geolifeclef-2024/PA-train-landsat_time_series/GLC24-PA-train-landsat_time_series-swir2.csv")
        self.landsat_s2.fillna(self.landsat_s2.mean(),inplace=True)
        self.landsat_data = torch.cat([self.Norm_all(self.landsat_b),self.Norm_all(self.landsat_g),self.Norm_all(self.landsat_r),self.Norm_all(self.landsat_n),self.Norm_all(self.landsat_s1),self.Norm_all(self.landsat_s2)],axis=1)

        self.elevation = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Elevation/GLC24-PA-train-elevation.csv")
        self.elevation[self.elevation<0]=0
        self.elevation.fillna(self.elevation.mean(),inplace=True)
        self.elevation_data = self.Norm(self.elevation)

        self.human_footprint = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Human Footprint/GLC24-PA-train-human_footprint.csv")
        self.human_footprint[self.human_footprint<0]=0
        self.human_footprint.fillna(self.human_footprint.mean(),inplace=True)
        self.human_footprint_data = self.Norm(self.human_footprint)

        self.landcover = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/LandCover/GLC24-PA-train-landcover.csv")
        self.landcover[self.landcover<0]=0
        self.landcover.fillna(self.landcover.mean(),inplace=True)
        self.landcover_data = self.Norm(self.landcover)

        self.soilgrids = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/SoilGrids/GLC24-PA-train-soilgrids.csv")
        self.soilgrids[self.soilgrids<0]=0
        self.soilgrids.fillna(self.soilgrids.mean(),inplace=True)
        self.soilgrids_data = self.Norm(self.soilgrids)

        self.metadata_data = torch.cat((self.metadata_data, self.elevation_data, self.human_footprint_data, self.landcover_data, self.soilgrids_data), dim=1)

    def Norm(self,df):
        output=torch.from_numpy(df.iloc[:,1:].values).float()
        return (output-output.mean(dim=0))/output.std(dim=0)

    def Norm_all(self,df):
        output=torch.from_numpy(df.iloc[:,1:].values).float()
        return (output-output.mean())/output.std()

    def patch_rgb_path(self,survey_id):
        path = "data/geolifeclef-2024/PA_Train_SatellitePatches_RGB/pa_train_patches_rgb"
        for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
            path = os.path.join(path, d)
        path = os.path.join(path, f"{survey_id}.jpeg")
        return path

    def patch_nir_path(self,survey_id):
        path = "data/geolifeclef-2024/PA_Train_SatellitePatches_NIR/pa_train_patches_nir"
        for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
            path = os.path.join(path, d)
        path = os.path.join(path, f"{survey_id}.jpeg")
        return path

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):

        survey_id = self.metadata.surveyId[idx]

        image_path = self.patch_rgb_path(survey_id)
        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)
        image = image.unsqueeze(0)
        image_nir_path = self.patch_nir_path(survey_id)
        nir_image = Image.open(image_nir_path).convert("L")
        nir_image = self.transform(nir_image)
        nir_image = nir_image.unsqueeze(0)
        image_data = torch.cat([image,nir_image],dim=1)
        image_data = torch.squeeze(image_data)
        sample=[self.metadata_data[idx,:],image_data,self.landsat_data[idx,:],self.climate_data[idx,:]]
        species_ids = self.label_dict[survey_id]  # Get list of species IDs for the survey ID
        label = torch.zeros(num_classes)  # Initialize label tensor
        for species_id in species_ids:
            label[species_id] = 1  # Set the corresponding class index to 1 for each species ID
        count = len(species_ids)
        return sample, survey_id, label, count#, species_ids

In [54]:
class TestDataset(Dataset):
    def __init__(self, metadata, subset, transform=None):
        self.subset = subset
        self.transform = transform
        self.metadata = metadata

        self.metadata = self.metadata.drop_duplicates(subset="surveyId").reset_index(drop=True)
        self.metadata.fillna(0,inplace=True)
        self.metadata.replace({float('-inf'): 0}, inplace=True)
        self.metadata_data = self.Norm(self.metadata.iloc[:,:5])

        self.merge_key = 'surveyId'
        self.climate_average = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Climate/Average 1981-2010/GLC24-PA-test-bioclimatic.csv")
        self.climate_monthly = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Climate/Monthly/GLC24-PA-test-bioclimatic_monthly.csv")
        self.climate = pd.merge(self.climate_average, self.climate_monthly, on=self.merge_key)
        self.climate.fillna(self.climate.mean(),inplace=True)
        self.climate_data = self.Norm_all(self.climate)

        self.landsat_b = pd.read_csv("data/geolifeclef-2024/PA-test-landsat_time_series/GLC24-PA-test-landsat_time_series-blue.csv")
        self.landsat_b.fillna(self.landsat_b.mean(),inplace=True)
        self.landsat_g = pd.read_csv("data/geolifeclef-2024/PA-test-landsat_time_series/GLC24-PA-test-landsat_time_series-green.csv")
        self.landsat_g.fillna(self.landsat_g.mean(),inplace=True)
        self.landsat_r = pd.read_csv("data/geolifeclef-2024/PA-test-landsat_time_series/GLC24-PA-test-landsat_time_series-red.csv")
        self.landsat_r.fillna(self.landsat_r.mean(),inplace=True)
        self.landsat_n = pd.read_csv("data/geolifeclef-2024/PA-test-landsat_time_series/GLC24-PA-test-landsat_time_series-nir.csv")
        self.landsat_n.fillna(self.landsat_n.mean(),inplace=True)
        self.landsat_s1 = pd.read_csv("data/geolifeclef-2024/PA-test-landsat_time_series/GLC24-PA-test-landsat_time_series-swir1.csv")
        self.landsat_s1.fillna(self.landsat_s1.mean(),inplace=True)
        self.landsat_s2 = pd.read_csv("data/geolifeclef-2024/PA-test-landsat_time_series/GLC24-PA-test-landsat_time_series-swir2.csv")
        self.landsat_s2.fillna(self.landsat_s2.mean(),inplace=True)
        self.landsat_data = torch.cat([self.Norm_all(self.landsat_b),self.Norm_all(self.landsat_g),self.Norm_all(self.landsat_r),self.Norm_all(self.landsat_n),self.Norm_all(self.landsat_s1),self.Norm_all(self.landsat_s2)],axis=1)

        self.elevation = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Elevation/GLC24-PA-test-elevation.csv")
        self.elevation[self.elevation<0]=0
        self.elevation.fillna(self.elevation.mean(),inplace=True)
        self.elevation_data = self.Norm(self.elevation)

        self.human_footprint = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/Human Footprint/GLC24-PA-test-human_footprint.csv")
        self.human_footprint[self.human_footprint<0]=0
        self.human_footprint.fillna(self.human_footprint.mean(),inplace=True)
        self.human_footprint_data = self.Norm(self.human_footprint)

        self.landcover = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/LandCover/GLC24-PA-test-landcover.csv")
        self.landcover[self.landcover<0]=0
        self.landcover.fillna(self.landcover.mean(),inplace=True)
        self.landcover_data = self.Norm(self.landcover)

        self.soilgrids = pd.read_csv("data/geolifeclef-2024/EnvironmentalRasters/EnvironmentalRasters/SoilGrids/GLC24-PA-test-soilgrids.csv")
        self.soilgrids[self.soilgrids<0]=0
        self.soilgrids.fillna(self.soilgrids.mean(),inplace=True)
        self.soilgrids_data = self.Norm(self.soilgrids)

        self.metadata_data = torch.cat((self.metadata_data, self.elevation_data, self.human_footprint_data, self.landcover_data, self.soilgrids_data), dim=1)

    def Norm(self,df):
        output=torch.from_numpy(df.iloc[:,1:].values).float()
        return (output-output.mean(dim=0))/output.std(dim=0)

    def Norm_all(self,df):
        output=torch.from_numpy(df.iloc[:,1:].values).float()
        return (output-output.mean())/output.std()

    def patch_rgb_path(self,survey_id):
        path = "data/geolifeclef-2024/PA_Test_SatellitePatches_RGB/pa_test_patches_rgb"
        for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
            path = os.path.join(path, d)
        path = os.path.join(path, f"{survey_id}.jpeg")
        return path

    def patch_nir_path(self,survey_id):
        path = "data/geolifeclef-2024/PA_Test_SatellitePatches_NIR/pa_test_patches_nir"
        for d in (str(survey_id)[-2:], str(survey_id)[-4:-2]):
            path = os.path.join(path, d)
        path = os.path.join(path, f"{survey_id}.jpeg")
        return path

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):

        survey_id = self.metadata.surveyId[idx]

        image_path = self.patch_rgb_path(survey_id)
        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)
        image = image.unsqueeze(0)
        image_nir_path = self.patch_nir_path(survey_id)
        nir_image = Image.open(image_nir_path).convert("L")
        nir_image = self.transform(nir_image)
        nir_image = nir_image.unsqueeze(0)
        image_data = torch.cat([image,nir_image],dim=1)
        image_data = torch.squeeze(image_data)
        sample=[self.metadata_data[idx,:],image_data,self.landsat_data[idx,:],self.climate_data[idx,:]]
        return sample, survey_id

In [55]:
# Dataset and DataLoader
train_batch_size = 64
test_batch_size = 1

transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

# Load training data
train_metadata_path = "data/geolifeclef-2024/GLC24_PA_metadata_train.csv"
train_metadata = pd.read_csv(train_metadata_path)
train_dataset = TrainDataset(train_metadata, subset="train", transform=transform)
train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=False, num_workers=1)

# Load testing data
test_metadata_path = "data/geolifeclef-2024/GLC24_PA_metadata_test.csv"
test_metadata = pd.read_csv(test_metadata_path)
test_dataset = TestDataset(test_metadata, subset="test", transform=transform)
test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=1)

This is the **final multi-modal model**. Unlike ViT, its input is features extracted from each dimension. 

For features with time information, I added additional position information. 

At the same time, I also added a feature that is a fusion of the first four features.

In [56]:
class MultiModal(nn.Module):
    def __init__(self, num_classes):
        super(MultiModal, self).__init__()
        self.cls = nn.Parameter(torch.randn(1, 1, 200))
        self.meta = MLP(31,200)
        self.resnet18 = ResNet18(200)
        self.landsat = MLP(504,200)
        self.position_landsat = nn.Parameter(torch.randn(1, 504))
        self.climate = MLP(931,200)
        self.position_climate = nn.Parameter(torch.randn(1, 931))
        self.emb = MLP(800,200)
        self.position_combine = nn.Parameter(torch.randn(1, 800))
        self.vit = ViT(200, 2, 200, 400, num_classes)
        self.position = nn.Parameter(torch.randn(1, 6, 200))

    def forward(self, x):
        batch = x[0].size(0)
        CLS = repeat(self.cls, '1 1 d -> b 1 d', b=batch).to(device)
        META = self.meta(x[0])
        IMG = self.resnet18(x[1])
        LANDSAT = self.landsat(x[2]+self.position_landsat)
        CLIMATE = self.climate(x[3]+self.position_climate)
        combine = torch.cat((META, IMG, LANDSAT, CLIMATE), dim=1)
        COMBINE = self.emb(combine+self.position_combine)
        token = torch.concat((CLS, META.unsqueeze(1)), dim=1)
        token = torch.concat((token, IMG.unsqueeze(1)), dim=1)
        token = torch.concat((token, LANDSAT.unsqueeze(1)), dim=1)
        token = torch.concat((token, CLIMATE.unsqueeze(1)), dim=1)
        token = torch.concat((token, COMBINE.unsqueeze(1)), dim=1)
        out = self.vit(token+self.position)
        return out

In [57]:
# Check if cuda is available
device = torch.device("cpu")

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("DEVICE = CUDA")

model = MultiModal(num_classes).to(device)
# model.load_state_dict(torch.load("models/vit/Model.pth", map_location=device))

DEVICE = CUDA


## Train

In [58]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00025)
# logger.info("Optimizer: AdamW")
scheduler = CosineAnnealingLR(optimizer, T_max=25, verbose=True)
# logger.info("Scheduler: CosineAnnealingLR")



In [59]:
print(f"Training for {num_epochs} epochs started.")

for epoch in tqdm(range(num_epochs)):
    model.train()

    for batch_idx, (sample, survey_id, labels, count) in enumerate(train_loader):
        samples = [tensor.to(device) for tensor in sample]
        survey_id = survey_id.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # with torch.set_grad_enabled(True):
        outputs = model(samples)

        # pos_weight = labels*1.0  # All positive weights are equal to 10
        # criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        criterion = torch.nn.MultiLabelSoftMarginLoss()

        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        if batch_idx % (len(train_loader)//10) == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item()}")

    if epoch % 10 == 0:
        torch.save(model.state_dict(), new_dir / f"multimodal-epoch-{epoch}.pth")

    scheduler.step()
torch.save(model.state_dict(), new_dir / "multimodal-final.pth")

Training for 10 epochs started.


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10, Batch 0/1391, Loss: 0.7349907755851746
Epoch 1/10, Batch 139/1391, Loss: 0.02957274578511715
Epoch 1/10, Batch 278/1391, Loss: 0.014072997495532036
Epoch 1/10, Batch 417/1391, Loss: 0.010164381004869938
Epoch 1/10, Batch 556/1391, Loss: 0.008698096498847008
Epoch 1/10, Batch 695/1391, Loss: 0.00786302238702774
Epoch 1/10, Batch 834/1391, Loss: 0.008332272991538048
Epoch 1/10, Batch 973/1391, Loss: 0.007484139874577522
Epoch 1/10, Batch 1112/1391, Loss: 0.006611760705709457
Epoch 1/10, Batch 1251/1391, Loss: 0.0067728441208601


 10%|█         | 1/10 [04:22<39:18, 262.10s/it]

Epoch 1/10, Batch 1390/1391, Loss: 0.0052049607038497925
Epoch 2/10, Batch 0/1391, Loss: 0.005556697025895119
Epoch 2/10, Batch 139/1391, Loss: 0.005286019295454025
Epoch 2/10, Batch 278/1391, Loss: 0.005960514303296804
Epoch 2/10, Batch 417/1391, Loss: 0.005892543122172356
Epoch 2/10, Batch 556/1391, Loss: 0.005626040045171976
Epoch 2/10, Batch 695/1391, Loss: 0.005467662587761879
Epoch 2/10, Batch 834/1391, Loss: 0.006014578510075808
Epoch 2/10, Batch 973/1391, Loss: 0.0055175600573420525
Epoch 2/10, Batch 1112/1391, Loss: 0.004954695701599121
Epoch 2/10, Batch 1251/1391, Loss: 0.005378980189561844


 20%|██        | 2/10 [08:43<34:55, 261.91s/it]

Epoch 2/10, Batch 1390/1391, Loss: 0.004102400969713926
Epoch 3/10, Batch 0/1391, Loss: 0.0044676270335912704
Epoch 3/10, Batch 139/1391, Loss: 0.00430769007652998
Epoch 3/10, Batch 278/1391, Loss: 0.004963258281350136
Epoch 3/10, Batch 417/1391, Loss: 0.005102289840579033
Epoch 3/10, Batch 556/1391, Loss: 0.004979284945875406
Epoch 3/10, Batch 695/1391, Loss: 0.004836549982428551
Epoch 3/10, Batch 834/1391, Loss: 0.0052825817838311195
Epoch 3/10, Batch 973/1391, Loss: 0.004920740611851215
Epoch 3/10, Batch 1112/1391, Loss: 0.004498898051679134
Epoch 3/10, Batch 1251/1391, Loss: 0.00496812304481864


 30%|███       | 3/10 [13:05<30:33, 261.93s/it]

Epoch 3/10, Batch 1390/1391, Loss: 0.0037807265762239695
Epoch 4/10, Batch 0/1391, Loss: 0.00413262564688921
Epoch 4/10, Batch 139/1391, Loss: 0.00396841811016202
Epoch 4/10, Batch 278/1391, Loss: 0.004581856541335583
Epoch 4/10, Batch 417/1391, Loss: 0.0047548385336995125
Epoch 4/10, Batch 556/1391, Loss: 0.004730638116598129
Epoch 4/10, Batch 695/1391, Loss: 0.004518009722232819
Epoch 4/10, Batch 834/1391, Loss: 0.004996577277779579
Epoch 4/10, Batch 973/1391, Loss: 0.004587549716234207
Epoch 4/10, Batch 1112/1391, Loss: 0.004273323342204094
Epoch 4/10, Batch 1251/1391, Loss: 0.0047425776720047


 40%|████      | 4/10 [17:28<26:13, 262.18s/it]

Epoch 4/10, Batch 1390/1391, Loss: 0.0035886599216610193
Epoch 5/10, Batch 0/1391, Loss: 0.0039310818538069725
Epoch 5/10, Batch 139/1391, Loss: 0.003820506390184164
Epoch 5/10, Batch 278/1391, Loss: 0.004376373253762722
Epoch 5/10, Batch 417/1391, Loss: 0.0045401244424283504
Epoch 5/10, Batch 556/1391, Loss: 0.004584381356835365
Epoch 5/10, Batch 695/1391, Loss: 0.004343959502875805
Epoch 5/10, Batch 834/1391, Loss: 0.004824021831154823
Epoch 5/10, Batch 973/1391, Loss: 0.004419183358550072
Epoch 5/10, Batch 1112/1391, Loss: 0.004129873588681221
Epoch 5/10, Batch 1251/1391, Loss: 0.004615657962858677


 50%|█████     | 5/10 [21:50<21:51, 262.27s/it]

Epoch 5/10, Batch 1390/1391, Loss: 0.003461048472672701
Epoch 6/10, Batch 0/1391, Loss: 0.0037960680201649666
Epoch 6/10, Batch 139/1391, Loss: 0.0037056072615087032
Epoch 6/10, Batch 278/1391, Loss: 0.004274362698197365
Epoch 6/10, Batch 417/1391, Loss: 0.004382345825433731
Epoch 6/10, Batch 556/1391, Loss: 0.004442816600203514
Epoch 6/10, Batch 695/1391, Loss: 0.004224259406328201
Epoch 6/10, Batch 834/1391, Loss: 0.004653670359402895
Epoch 6/10, Batch 973/1391, Loss: 0.00428025983273983
Epoch 6/10, Batch 1112/1391, Loss: 0.004079033620655537
Epoch 6/10, Batch 1251/1391, Loss: 0.004521721974015236


 60%|██████    | 6/10 [26:12<17:28, 262.21s/it]

Epoch 6/10, Batch 1390/1391, Loss: 0.0034094664733856916
Epoch 7/10, Batch 0/1391, Loss: 0.00365727161988616
Epoch 7/10, Batch 139/1391, Loss: 0.0036135478876531124
Epoch 7/10, Batch 278/1391, Loss: 0.004168370272964239
Epoch 7/10, Batch 417/1391, Loss: 0.004273276310414076
Epoch 7/10, Batch 556/1391, Loss: 0.004323720466345549
Epoch 7/10, Batch 695/1391, Loss: 0.004133121110498905
Epoch 7/10, Batch 834/1391, Loss: 0.004563512280583382
Epoch 7/10, Batch 973/1391, Loss: 0.004167885985225439
Epoch 7/10, Batch 1112/1391, Loss: 0.00400905217975378
Epoch 7/10, Batch 1251/1391, Loss: 0.004478261806070805


 70%|███████   | 7/10 [30:34<13:06, 262.16s/it]

Epoch 7/10, Batch 1390/1391, Loss: 0.0033434201031923294
Epoch 8/10, Batch 0/1391, Loss: 0.003550377208739519
Epoch 8/10, Batch 139/1391, Loss: 0.0035694686230272055
Epoch 8/10, Batch 278/1391, Loss: 0.004242718685418367
Epoch 8/10, Batch 417/1391, Loss: 0.004150160122662783
Epoch 8/10, Batch 556/1391, Loss: 0.004205682780593634
Epoch 8/10, Batch 695/1391, Loss: 0.004084913060069084
Epoch 8/10, Batch 834/1391, Loss: 0.004431955050677061
Epoch 8/10, Batch 973/1391, Loss: 0.004082387313246727
Epoch 8/10, Batch 1112/1391, Loss: 0.003938659094274044
Epoch 8/10, Batch 1251/1391, Loss: 0.0043869661167263985


 80%|████████  | 8/10 [34:57<08:44, 262.13s/it]

Epoch 8/10, Batch 1390/1391, Loss: 0.00326737598516047
Epoch 9/10, Batch 0/1391, Loss: 0.003517943900078535
Epoch 9/10, Batch 139/1391, Loss: 0.003511570394039154
Epoch 9/10, Batch 278/1391, Loss: 0.003980784676969051
Epoch 9/10, Batch 417/1391, Loss: 0.004085118882358074
Epoch 9/10, Batch 556/1391, Loss: 0.004177015274763107
Epoch 9/10, Batch 695/1391, Loss: 0.004043424502015114
Epoch 9/10, Batch 834/1391, Loss: 0.004429874941706657
Epoch 9/10, Batch 973/1391, Loss: 0.003990948665887117
Epoch 9/10, Batch 1112/1391, Loss: 0.00386800873093307
Epoch 9/10, Batch 1251/1391, Loss: 0.0043358877301216125


 90%|█████████ | 9/10 [39:19<04:22, 262.14s/it]

Epoch 9/10, Batch 1390/1391, Loss: 0.0031601854134351015
Epoch 10/10, Batch 0/1391, Loss: 0.003447657683864236
Epoch 10/10, Batch 139/1391, Loss: 0.0034442697651684284
Epoch 10/10, Batch 278/1391, Loss: 0.003933025524020195
Epoch 10/10, Batch 417/1391, Loss: 0.004025811795145273
Epoch 10/10, Batch 556/1391, Loss: 0.0040794555097818375
Epoch 10/10, Batch 695/1391, Loss: 0.003927966579794884
Epoch 10/10, Batch 834/1391, Loss: 0.0043701184913516045
Epoch 10/10, Batch 973/1391, Loss: 0.003941917326301336
Epoch 10/10, Batch 1112/1391, Loss: 0.003791925497353077
Epoch 10/10, Batch 1251/1391, Loss: 0.004210730083286762


100%|██████████| 10/10 [43:41<00:00, 262.13s/it]

Epoch 10/10, Batch 1390/1391, Loss: 0.0031075396109372377





## Produce Submission

In [63]:
with torch.no_grad():
    surveys = []
    top_indices = []
    for data, surveyID in tqdm(test_loader, total=len(test_loader)):

        data = [tensor.to(device) for tensor in data]

        outputs = model(data)
        predictions = torch.sigmoid(outputs).cpu().numpy()
        predictions = np.squeeze(predictions)
        # print(predictions)
        prediction = np.argwhere(predictions>=0.25).flatten()
        top_indices.append(prediction)
        surveys.extend(surveyID.cpu().numpy())

100%|██████████| 4716/4716 [00:11<00:00, 424.03it/s]


In [64]:
data_concatenated = [' '.join(map(str, row)) for row in top_indices]

pd.DataFrame(
    {'surveyId': surveys,
     'predictions': data_concatenated,
    }).to_csv(new_dir / "submission.csv", index = False)

In [65]:
total = 0
for row in top_indices:
    total += row.shape[0]

print(f"Total number of predictions: {total} with average of {total/len(top_indices)} predictions per survey ID.")

Total number of predictions: 77425 with average of 16.41751484308736 predictions per survey ID.
