In [None]:
!pip -q install wandb
!wandb login

# Data import

In [None]:
import os
from PIL import Image
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ConvNextImageProcessor, ConvNextForImageClassification
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import copy
from copy import deepcopy
try:
  import wandb
except ImportError:
  pass

In [None]:
!pip install kaggle
!mkdir /.kaggle
!mv kaggle.json /.kaggle
!mv /.kaggle /root/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
!kaggle competitions download -c dlmi-lymphocytosis-classification

In [None]:
!unzip /content/dlmi-lymphocytosis-classification.zip

In [None]:
!wget -P /content/ -O test_df_features.csv https://www.dropbox.com/scl/fi/xq4n2smw8l72ia02brxhh/test_df_features.csv?rlkey=zxe6i34f6cu0oq0vxy77ixlfl&dl=0
!wget -P /content/ -O train_df_features.csv https://www.dropbox.com/scl/fi/vh72tl6skb8ncx3c22qx5/train_df_features.csv?rlkey=lc5ijt3y01x2g76qzb94z0odp&dl=0

# Creation of the dataset

In [None]:
class PatientDataset(Dataset):
    """
    Per-patient dataset built from a clinical CSV, a feature CSV and one image folder per patient.

    For each patient get a dict with:
    - 'images': list of the patient's images (PIL images until a transform converts them),
    - 'age': age in years, derived from the year of birth
    - 'lymph_count': lymph_count
    - 'gender': 1 for 'M', 0 otherwise - can be useful
    - 'features': float32 tensor parsed from the serialized feature column
    - 'label': label - useful during training, useless during test
    - 'id': id - useless during training, useful during test for submission

    Columns of the merged frame are read by position:
    0 = ID (also the image sub-folder name), 1 = label, 2 = gender,
    3 = date of birth, 4 = lymphocyte count, 5 = serialized features.

    Args:
        csv_file: path of the clinical CSV (must contain 'ID' and 'GENDER').
        root_dir: directory holding one sub-folder of images per patient ID.
        feat_file: path of the feature CSV, merged on 'ID'.
        transform: optional callable applied to the sample dict.
        ids: optional list of row positions restricting the dataset to a subset.
    """
    def __init__(self, csv_file, root_dir, feat_file, transform=None, ids = None):
        self.data_frame = pd.read_csv(csv_file)
        self.data_frame['GENDER'] = self.data_frame['GENDER'].replace(['f'],['F']) # Some samples were labeled as 'f' instead of 'F'
        self.df_feats = pd.read_csv(feat_file)
        self.data_frame = pd.merge(self.data_frame, self.df_feats, on='ID')
        self.root_dir = root_dir
        self.transform = transform
        self.ids = ids

    def __len__(self):
        """Number of patients: size of the subset when `ids` is given, else of the whole frame."""
        if self.ids is not None:
          return len(self.ids)
        else:
          return len(self.data_frame)

    def __getitem__(self, idx):
        """Build the sample dict of the idx-th patient (idx is remapped through `ids` when set)."""
        import ast  # local import: only needed to parse the serialized feature column

        if torch.is_tensor(idx):
            idx = idx.tolist()
        if self.ids is not None:
          # `ids` are used positionally (iloc) on the merged frame.
          idx = self.ids[idx]

        img_dir = os.path.join(self.root_dir, self.data_frame.iloc[idx, 0])
        # Sorted so the image order does not depend on the file system.
        images = [Image.open(os.path.join(img_dir, img_name)) for img_name in sorted(os.listdir(img_dir))]

        id = self.data_frame.iloc[idx, 0]
        label = self.data_frame.iloc[idx, 1]
        gender = 1 if self.data_frame.iloc[idx, 2] == 'M' else 0
        age = self.calculate_age(self.data_frame.iloc[idx, 3])
        lymph_count = self.data_frame.iloc[idx, 4]
        # literal_eval only accepts Python literals, so CSV content can never execute code (unlike eval).
        features = ast.literal_eval(self.data_frame.iloc[idx, 5])
        features = torch.tensor(features, dtype=torch.float32)

        sample = {'images': images, 'age': age, 'lymph_count': lymph_count, 'gender': gender, 'features': features, 'label': label, 'id': id}

        if self.transform:
            sample = self.transform(sample)

        return sample

    def calculate_age(self, dob, reference_year=2024):
        """Return `reference_year` minus the birth year, read from the last four characters of `dob`."""
        year = int(dob[-4:])
        return reference_year-year

class TestTransform(object):
  """Evaluation-time preprocessing: every image becomes a normalized tensor, no augmentation."""
  def __call__(self, sample):
    # Copy exactly the expected fields; a missing one raises KeyError before any image work.
    field_names = ('images', 'age', 'lymph_count', 'gender', 'features', 'label', 'id')
    processed = {field: sample[field] for field in field_names}

    # Tensor conversion followed by channel-wise normalization (ImageNet statistics).
    to_normalized_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    processed['images'] = [to_normalized_tensor(picture) for picture in processed['images']]
    return processed

class TrainTransform(object):
  """
  Training-time preprocessing: random flips, tensor conversion and normalization.

  The normalization matches TestTransform so that training and evaluation images
  share the same input distribution. (A second, bare ``ToTensor`` pipeline used to
  overwrite this one, silently disabling both augmentation and normalization.)
  """
  def __call__(self, sample):
    images, age, lymph_count, gender, features, label, id = sample['images'], sample['age'], sample['lymph_count'], sample['gender'], sample['features'], sample['label'], sample['id']

    transform = transforms.Compose([
                                    transforms.RandomHorizontalFlip(),
                                    transforms.RandomVerticalFlip(),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                                  ])

    # Each image draws its own random flips.
    images = [transform(img) for img in images]

    return {'images': images, 'age': age, 'lymph_count': lymph_count, 'gender': gender, 'features': features, 'label': label, 'id': id}

In [None]:
# Held-out test patients with evaluation-time preprocessing only.
testset = PatientDataset(
    csv_file='/content/testset/testset_data.csv',
    root_dir='/content/testset/',
    feat_file='/content/test_df_features.csv',
    transform=transforms.Compose([TestTransform()]),
)

# Default DataLoader settings: one patient per batch, original order.
testloader = DataLoader(testset)

In [None]:
def k_fold_unbalanced(dataframe, nb_folds=5, seed=None):
  """
  Split the patients into `nb_folds` disjoint, stratified folds.

  Each fold receives (up to rounding) the same share of label-0 and label-1
  patients as the full dataframe, i.e. the folds keep the original class imbalance.
  For every fold two datasets over the SAME ids are built: one with the training
  transform (used when the fold is part of the training set) and one with the
  evaluation transform (used when the fold is the held-out one).

  Args:
    dataframe: frame with at least the 'ID' and 'LABEL' (0/1) columns.
    nb_folds: number of folds.
    seed: optional seed for NumPy's global RNG, for reproducible folds.

  Returns:
    (train_datasets_list, val_datasets_list, n_list) where n_list[k] is the pair
    (number of label-0 ids, number of label-1 ids) of fold k.
  """
  if seed is not None:
    np.random.seed(seed)

  ids_per_label = dataframe.groupby('LABEL').nunique()['ID']
  n_0 = ids_per_label[0]
  n_1 = ids_per_label[1]
  train_datasets_list = list()
  val_datasets_list = list()
  n_list = list()
  # Pick ids depending on sizes and proportions
  shuffled_0 = dataframe.index[dataframe['LABEL']==0].tolist()
  np.random.shuffle(shuffled_0)
  shuffled_1 = dataframe.index[dataframe['LABEL']==1].tolist()
  np.random.shuffle(shuffled_1)
  fold_size_0 = n_0 // nb_folds
  nb_remained_0 = n_0 % nb_folds
  start_ind_0 = 0
  fold_size_1 = n_1 // nb_folds
  nb_remained_1 = n_1 % nb_folds
  start_ind_1 = 0

  for _ in range(nb_folds):
    # The first `n % nb_folds` folds absorb one extra sample of each class.
    if nb_remained_0 > 0:
      i_fold_size_0 = fold_size_0 + 1
      nb_remained_0 -= 1
    else:
      i_fold_size_0 = fold_size_0

    if nb_remained_1 > 0:
      i_fold_size_1 = fold_size_1 + 1
      nb_remained_1 -= 1
    else:
      i_fold_size_1 = fold_size_1

    end_ind_0 = start_ind_0 + i_fold_size_0
    end_ind_1 = start_ind_1 + i_fold_size_1
    fold_0 = shuffled_0[start_ind_0: end_ind_0]
    fold_1 = shuffled_1[start_ind_1: end_ind_1]
    # Move the windows forward so that the next fold is disjoint from this one.
    start_ind_0 = end_ind_0
    start_ind_1 = end_ind_1

    fold_ids = fold_0 + fold_1
    np.random.shuffle(fold_ids)

    fold_dataframe = dataframe.loc[fold_ids]
    fold_ids_per_label = fold_dataframe.groupby('LABEL').nunique()['ID']
    n_0_fold = fold_ids_per_label[0]
    n_1_fold = fold_ids_per_label[1]

    train_dataset = PatientDataset(csv_file='trainset/trainset_true.csv',
                                   root_dir='trainset/',
                                   transform=transforms.Compose([
                                       TrainTransform(),
                                   ]),
                                   feat_file='train_df_features.csv',
                                   ids=fold_ids)

    val_dataset = PatientDataset(csv_file='trainset/trainset_true.csv',
                                 root_dir='trainset/',
                                 transform=transforms.Compose([
                                     TestTransform(),
                                 ]),
                                 feat_file='train_df_features.csv',
                                 ids=fold_ids)

    train_datasets_list.append(train_dataset)
    val_datasets_list.append(val_dataset)
    n_list.append((n_0_fold, n_1_fold))

  return train_datasets_list, val_datasets_list, n_list

In [None]:
# Load the training metadata and harmonise the gender spelling ('f' -> 'F').
dataframe = pd.read_csv('/content/trainset/trainset_true.csv')
dataframe['GENDER'] = dataframe['GENDER'].replace(['f'], ['F'])

# Quick 2-fold split.
nb_folds = 2
train_datasets_list, val_datasets_list, n_list = k_fold_unbalanced(dataframe, nb_folds=nb_folds)

# Use the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def split_train_val_balanced(dataframe, p_val=0.2, seed=None):
  """
  Split a dataset in a trainset and a valset which is balanced.

  The validation set takes the same number of label-0 and label-1 patients
  (`int(p_val * n) // 2` of each); everything else goes to the training set.

  Args:
    dataframe: frame with at least the 'ID' and 'LABEL' (0/1) columns.
    p_val: target fraction of patients placed in the validation set.
    seed: optional seed for NumPy's global RNG, for a reproducible split
      (same convention as the other split helpers of this notebook).

  Returns:
    (train_dataset, val_dataset, n_0_train, n_1_train, n_0_val, n_1_val)
  """
  if seed is not None:
    np.random.seed(seed)

  n_0 = dataframe.groupby('LABEL').nunique()['ID'][0]
  n_1 = dataframe.groupby('LABEL').nunique()['ID'][1]
  nb_val = int(p_val * (n_0 + n_1))
  # Pick ids depending on sizes and proportions
  shuffled_0 = dataframe.index[dataframe['LABEL']==0].tolist()
  np.random.shuffle(shuffled_0)
  shuffled_1 = dataframe.index[dataframe['LABEL']==1].tolist()
  np.random.shuffle(shuffled_1)

  # Same number of samples from each class in the validation set.
  val_0 = shuffled_0[: nb_val // 2]
  val_1 = shuffled_1[: nb_val // 2]
  val_ids = val_0 + val_1
  np.random.shuffle(val_ids)

  val_dataframe = dataframe.loc[val_ids]
  n_0_val = val_dataframe.groupby('LABEL').nunique()['ID'][0]
  n_1_val = val_dataframe.groupby('LABEL').nunique()['ID'][1]

  # The remaining samples of each class form the training set.
  train_0 = shuffled_0[nb_val // 2:]
  train_1 = shuffled_1[nb_val // 2:]
  train_ids = train_0 + train_1
  np.random.shuffle(train_ids)

  train_dataframe = dataframe.loc[train_ids]
  n_0_train = train_dataframe.groupby('LABEL').nunique()['ID'][0]
  n_1_train = train_dataframe.groupby('LABEL').nunique()['ID'][1]

  train_dataset = PatientDataset(csv_file='trainset/trainset_true.csv',
                                root_dir='trainset/',
                                transform=transforms.Compose([
                                    TrainTransform(),
                                ]),
                                feat_file='train_df_features.csv',
                                ids=train_ids)

  val_dataset = PatientDataset(csv_file='trainset/trainset_true.csv',
                                root_dir='trainset/',
                                transform=transforms.Compose([
                                    TestTransform(),
                                ]),
                                feat_file='train_df_features.csv',
                                ids=val_ids)

  return train_dataset, val_dataset, n_0_train, n_1_train, n_0_val, n_1_val

In [None]:
def split_train_val(dataframe, p_val=0.5, seed=None):
  """
  Plain random train/validation split (no stratification).

  A fraction `p_val` of the rows goes to the validation set, the rest to the
  training set. Returns both datasets followed by the label-0 / label-1 id
  counts of the training set, then of the validation set.
  """
  if seed is not None:
    np.random.seed(seed)

  def _ids_per_label(frame):
    # Number of distinct patient ids for each label value.
    return frame.groupby('LABEL').nunique()['ID']

  def _make_dataset(row_ids, sample_transform):
    return PatientDataset(csv_file='trainset/trainset_true.csv',
                          root_dir='trainset/',
                          transform=transforms.Compose([sample_transform]),
                          feat_file='train_df_features.csv',
                          ids=row_ids)

  overall = _ids_per_label(dataframe)
  nb_val = int(p_val * (overall[0] + overall[1]))

  # Pick ids depending on sizes and proportions
  permuted = dataframe.index.tolist()
  np.random.shuffle(permuted)

  val_ids = permuted[:nb_val]
  np.random.shuffle(val_ids)
  val_counts = _ids_per_label(dataframe.loc[val_ids])
  n_0_val = val_counts[0]
  n_1_val = val_counts[1]

  train_ids = permuted[nb_val:]
  np.random.shuffle(train_ids)
  train_counts = _ids_per_label(dataframe.loc[train_ids])
  n_0_train = train_counts[0]
  n_1_train = train_counts[1]

  train_dataset = _make_dataset(train_ids, TrainTransform())
  val_dataset = _make_dataset(val_ids, TestTransform())

  return train_dataset, val_dataset, n_0_train, n_1_train, n_0_val, n_1_val

# Models

In [None]:
from torchvision.models import resnet18, ResNet18_Weights

class ResNetOneTaskModel(nn.Module):
    """
    Single-task classifier mixing an image branch and a feature branch through a learned gate.

    The prediction is ``pi_images * y_images + (1 - pi_images) * y_feat``; both branch
    outputs and the gate value go through a sigmoid.
    """
    def __init__(self, ):
      super(ResNetOneTaskModel, self).__init__()
      # Load the feature extractor: pretrained ResNet-18 with its last child module dropped
      self.feature_extractor = resnet18(weights=ResNet18_Weights.DEFAULT)
      self.feature_extractor = torch.nn.Sequential(*list(self.feature_extractor.children())[:-1])
      self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      # Image branch: aggregated 512-d image embedding -> probability
      self.classifier_image = nn.Sequential(nn.Linear(512, 1),
                                            nn.Sigmoid())

      # Feature branch: 8-d encoded features + 2 metadata values (age, lymph_count) -> probability
      self.classifier_feats = nn.Sequential(nn.Linear(8+2, 64),
                                            nn.ReLU(),
                                            nn.Linear(64, 1),
                                            nn.Sigmoid())

      # Scalar summaries of each branch, fed to the gate
      self.image_to_scalar = nn.Linear(512, 1)

      self.feat_to_scalar = nn.Linear(8, 1)

      # Gate input: 2 metadata values + image scalar + feature scalar
      self.gate = nn.Sequential(nn.Linear(2+1+1, 1),
                                nn.Sigmoid())

      # Encoder applied to each 51-d feature vector, producing an 8-d code
      self.encoder_feats = nn.Sequential(nn.Linear(51, 32),
                                         nn.ReLU(),
                                         nn.Linear(32, 16),
                                         nn.ReLU(),
                                         nn.Linear(16, 8))

    def forward(self, images, age, lymph_count, gender, features, return_everything=False):
      """
      Return the gated prediction ``y_hat``.

      With ``return_everything=True`` the tuple
      ``(y_hat, y_images, y_feat, pi_feat, pi_images)`` is returned instead.
      ``gender`` is accepted but not used.
      """
      images_encoded = torch.stack([self.feature_extractor(image.to(self.device)).squeeze((2, 3)) for image in images], dim=0) # each encoded image has shape (batch_size, 512)
      # Average over the images of the patient
      images_agg = torch.mean(images_encoded, dim=0)
      y_images = self.classifier_image(images_agg)
      # Encode every feature vector along dim 1, then average the codes
      feats_encoded = torch.stack([self.encoder_feats(features[:, i]) for i in range(features.shape[1])], dim=0)
      feats_agg = torch.mean(feats_encoded, dim=0)
      # NOTE(review): cat on dim 0 then unsqueeze yields a (1, 2) row only when age and
      # lymph_count each hold a single value - presumably batch size 1; confirm.
      metadata = torch.cat([age, lymph_count], dim=0).unsqueeze(dim=0).float()
      all_feat = torch.cat([metadata, feats_agg], dim=1)
      y_feat = self.classifier_feats(all_feat)
      image_scalar = self.image_to_scalar(images_agg)
      feat_scalar = self.feat_to_scalar(feats_agg)
      # Gate weight of the image branch; the feature branch receives the complement
      pi_images = self.gate(torch.cat([metadata, image_scalar, feat_scalar], dim=1))
      pi_feat = 1 - pi_images
      y_hat = pi_images * y_images + pi_feat * y_feat

      if return_everything:
        return y_hat, y_images, y_feat, pi_feat, pi_images

      return y_hat

In [None]:
from torchvision.models import resnet18, ResNet18_Weights

class ResNetMultiTaskModel(nn.Module):
    """
    Multi-task variant: gated classification plus a linear regression head on the image embedding.

    The classification output is ``pi_images * y_images + (1 - pi_images) * y_data``.
    Here the gate only sees the two branch scalars (no metadata).
    """
    def __init__(self, ):
      super(ResNetMultiTaskModel, self).__init__()

      self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

      # Load the feature extractor: pretrained ResNet-18 with its last child module dropped
      self.feature_extractor = resnet18(weights=ResNet18_Weights.DEFAULT)
      self.feature_extractor = torch.nn.Sequential(*list(self.feature_extractor.children())[:-1])

      # Image branch: aggregated 512-d image embedding -> probability
      self.classifier_image = nn.Sequential(nn.Linear(512, 128),
                                            nn.ReLU(),
                                            nn.Linear(128, 1),
                                            nn.Sigmoid())

      # Data branch: 8-d encoded features + 2 metadata values (age, lymph_count) -> probability
      self.classifier_data = nn.Sequential(nn.Linear(8+2, 64),
                                           nn.ReLU(),
                                           nn.Linear(64, 1),
                                           nn.Sigmoid())

      # Auxiliary regression head on the image embedding
      self.regressor_image = nn.Linear(512, 1)

      # Scalar summaries of each branch, fed to the gate
      self.image_to_scalar = nn.Sequential(nn.Linear(512, 128),
                                           nn.ReLU(),
                                           nn.Linear(128, 1))

      self.data_to_scalar = nn.Linear(8, 1)

      # Gate input: data scalar + image scalar
      self.gate = nn.Sequential(nn.Linear(2, 1),
                                nn.Sigmoid())

      # Encoder applied to each 51-d feature vector, producing an 8-d code
      self.encoder_data = nn.Sequential(nn.Linear(51, 32),
                                        nn.ReLU(),
                                        nn.Linear(32, 16),
                                        nn.ReLU(),
                                        nn.Linear(16, 8))

    def forward(self, images, age, lymph_count, gender, features, return_everything=False):
      """
      Return ``(y_hat, reg_img)``: gated class probability and regression output.

      With ``return_everything=True`` the tuple
      ``(y_hat, y_images, y_data, pi_data, pi_images)`` is returned instead
      (the regression output is not part of it). ``gender`` is accepted but not used.
      """
      # About the images
      images_encoded = torch.stack([self.feature_extractor(image.to(self.device)).squeeze((2, 3)) for image in images], dim=0)
      # Average over the images of the patient
      images_agg = torch.mean(images_encoded, dim=0)
      y_images = self.classifier_image(images_agg)
      reg_img = self.regressor_image(images_agg)
      image_scalar = self.image_to_scalar(images_agg)

      # About the scalar data
      data_encoded = torch.stack([self.encoder_data(features[:, i]) for i in range(features.shape[1])], dim=0)
      data_agg = torch.mean(data_encoded, dim=0)
      # NOTE(review): cat on dim 0 then unsqueeze yields a (1, 2) row only when age and
      # lymph_count each hold a single value - presumably batch size 1; confirm.
      metadata = torch.cat([age, lymph_count], dim=0).unsqueeze(dim=0).float()
      all_feat = torch.cat([metadata, data_agg], dim=1)
      y_data = self.classifier_data(all_feat)
      data_scalar = self.data_to_scalar(data_agg)

      # Gate
      pi_images = self.gate(torch.cat([data_scalar, image_scalar], dim=1))
      pi_data = 1 - pi_images
      y_hat = pi_images * y_images + pi_data * y_data

      if return_everything:
        return y_hat, y_images, y_data, pi_data, pi_images

      return y_hat, reg_img

In [None]:
from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights

class ConvNextMultiTaskModel(nn.Module):
    """
    Multi-task model on a frozen ConvNeXt-Tiny backbone.

    Backbone outputs are projected to 128 dimensions by a trainable linear layer;
    classification is a gated mix of an image branch and a feature branch, and a
    linear head on the image embedding provides the regression output.
    """
    def __init__(self, ):
      super(ConvNextMultiTaskModel, self).__init__()
      # Load the feature extractor: pretrained ConvNeXt-Tiny with its last child module dropped
      self.feature_extractor = convnext_tiny(weights = ConvNeXt_Tiny_Weights.DEFAULT)
      self.feature_extractor = torch.nn.Sequential(*list(self.feature_extractor.children())[:-1])
      # Freeze feature extractor
      for param in self.feature_extractor.parameters():
            param.requires_grad = False
      self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
      # Trainable projection of the 768-d backbone output to a 128-d embedding
      self.feature_common_embedding = nn.Linear(768,128)
      # Image branch: aggregated 128-d image embedding -> probability
      self.classifier_image = nn.Sequential(nn.Linear(128, 1),
                                            nn.Sigmoid())

      # Auxiliary regression head on the image embedding
      self.regressor_image = nn.Sequential(nn.Linear(128, 1))

      # Feature branch: 8-d encoded features + 2 metadata values (age, lymph_count) -> probability
      self.classifier_feats = nn.Sequential(nn.Linear(8+2, 64),
                                            nn.ReLU(),
                                            nn.Linear(64, 1),
                                            nn.Sigmoid())

      # Scalar summaries of each branch, fed to the gate
      self.image_to_scalar = nn.Linear(128, 1)

      self.feat_to_scalar = nn.Linear(8, 1)

      # Gate input: 2 metadata values + image scalar + feature scalar
      self.gate = nn.Sequential(nn.Linear(2+1+1, 1),
                                nn.Sigmoid())

      # Encoder applied to each 51-d feature vector, producing an 8-d code
      self.encoder_feats = nn.Sequential(nn.Linear(51, 32),
                                         nn.ReLU(),
                                         nn.Linear(32, 16),
                                         nn.ReLU(),
                                         nn.Linear(16, 8))

    def forward(self, images, age, lymph_count, gender, features):
      """Return ``(y_hat, reg_img)``: gated class probability and regression output. ``gender`` is not used."""
      images_encoded = torch.stack([self.feature_common_embedding((self.feature_extractor(image.to(self.device))).squeeze((2, 3))) for image in images], dim=0) # each encoded image has shape (batch_size, 128) after the projection
      # Average over the images of the patient
      images_agg = torch.mean(images_encoded, dim=0)
      y_images = self.classifier_image(images_agg)
      reg_img = self.regressor_image(images_agg)
      # Encode every feature vector along dim 1, then average the codes
      feats_encoded = torch.stack([self.encoder_feats(features[:, i]) for i in range(features.shape[1])], dim=0)
      feats_agg = torch.mean(feats_encoded, dim=0)
      # NOTE(review): cat on dim 0 then unsqueeze yields a (1, 2) row only when age and
      # lymph_count each hold a single value - presumably batch size 1; confirm.
      metadata = torch.cat([age, lymph_count], dim=0).unsqueeze(dim=0).float()
      all_feat = torch.cat([metadata, feats_agg], dim=1)
      y_feat = self.classifier_feats(all_feat)
      image_scalar = self.image_to_scalar(images_agg)
      feat_scalar = self.feat_to_scalar(feats_agg)
      # Gate weight of the image branch; the feature branch receives the complement
      pi_images = self.gate(torch.cat([metadata, image_scalar, feat_scalar], dim=1))
      pi_feat = 1 - pi_images
      y_hat = pi_images * y_images + pi_feat * y_feat

      return y_hat, reg_img

In [None]:
!pip install timm
import timm

class LoRA(nn.Module):
    """
    Wrap a linear layer and add low-rank updates to the first and last thirds of its output.

    The wrapped layer's output is treated as three chunks of width `in_dim`
    (query, key, value when applied to an attention `qkv` projection): the query
    and value chunks receive `alpha * x @ down @ up`, the key chunk is left untouched.
    The up-projections start at zero, so the wrapper is initially an identity
    around `linear_layer`. Parameters are created on the notebook-level `device`.
    """
    def __init__(self, linear_layer, in_dim, rank=32, alpha=16):
        super().__init__()
        self.linear_layer = linear_layer
        init_scale = 1 / torch.sqrt(torch.tensor(rank).float())

        def _random_down_projection():
            return nn.Parameter((torch.randn((in_dim, rank)) * init_scale).to(device))

        def _zero_up_projection():
            return nn.Parameter(torch.zeros((rank, in_dim)).to(device))

        # Attribute names contain 'adapter' so that they can be told apart from the frozen weights.
        self.adapter_Q_downsample = _random_down_projection()
        self.adapter_Q_upsample = _zero_up_projection()
        self.adapter_V_downsample = _random_down_projection()
        self.adapter_V_upsample = _zero_up_projection()
        self.adapter_alpha = alpha

    def forward(self, x):
        delta_q = self.adapter_alpha * (x @ self.adapter_Q_downsample @ self.adapter_Q_upsample)
        delta_v = self.adapter_alpha * (x @ self.adapter_V_downsample @ self.adapter_V_upsample)
        # No update for the middle (key) chunk.
        delta_k = torch.zeros_like(delta_v)
        low_rank_update = torch.cat([delta_q, delta_k, delta_v], dim=-1)
        return self.linear_layer(x) + low_rank_update

def add_lora(model):
    """Replace, in place, the `attn.qkv` layer of every block in `model.blocks` by a LoRA wrapper."""
    for transformer_block in model.blocks:
        attention = transformer_block.attn
        wrapped_qkv = LoRA(attention.qkv, attention.qkv.in_features)
        attention.qkv = wrapped_qkv.to(device)

def freeze_model_lora(model):
    """Disable gradients for every parameter whose qualified name does not contain 'adapter'."""
    for param_name, weight in model.named_parameters():
        if 'adapter' in param_name:
            # Adapter parameters keep their current requires_grad flag.
            continue
        weight.requires_grad = False


class VitMultiTaskLoraModel(nn.Module):
    """
    Multi-task model on a timm ViT-B/16 backbone fine-tuned through LoRA adapters.

    All backbone parameters except the adapters are frozen. Classification is a gated
    mix of an image branch and a feature branch; a linear head on the image embedding
    provides the regression output.
    """
    def __init__(self, ):
      super(VitMultiTaskLoraModel, self).__init__()
      # Load the feature extractor: pretrained ViT (vit_base_patch16_224) from timm
      self.feature_extractor = timm.create_model('vit_base_patch16_224', pretrained=True).to(device)
      # add LoRa to feature extractor
      add_lora(self.feature_extractor)
      freeze_model_lora(self.feature_extractor)
      # Drop the classification head so the backbone returns its embedding
      self.feature_extractor.head = nn.Identity()
      # Image branch: aggregated 768-d image embedding -> probability
      self.classifier_image = nn.Sequential(nn.Linear(768, 1),
                                            nn.Sigmoid())

      # Auxiliary regression head on the image embedding
      self.regressor_image = nn.Sequential(nn.Linear(768, 1))

      # Feature branch: 8-d encoded features + 2 metadata values (age, lymph_count) -> probability
      self.classifier_feats = nn.Sequential(nn.Linear(8+2, 64),
                                            nn.ReLU(),
                                            nn.Linear(64, 1),
                                            nn.Sigmoid())

      # Scalar summaries of each branch, fed to the gate
      self.image_to_scalar = nn.Linear(768, 1)

      self.feat_to_scalar = nn.Linear(8, 1)

      # Gate input: 2 metadata values + image scalar + feature scalar
      self.gate = nn.Sequential(nn.Linear(2+1+1, 1),
                                nn.Sigmoid())

      # Encoder applied to each 51-d feature vector, producing an 8-d code
      self.encoder_feats = nn.Sequential(nn.Linear(51, 32),
                                         nn.ReLU(),
                                         nn.Linear(32, 16),
                                         nn.ReLU(),
                                         nn.Linear(16, 8))

    def forward(self, images, age, lymph_count, gender, features):
      """Return ``(y_hat, reg_img)``: gated class probability and regression output. ``gender`` is not used."""
      images_encoded = torch.stack([(self.feature_extractor(image.to(device))) for image in images], dim=0) # each encoded image presumably has shape (batch_size, 768), matching the 768-input heads
      # Average over the images of the patient
      images_agg = torch.mean(images_encoded, dim=0)
      y_images = self.classifier_image(images_agg)
      reg_img = self.regressor_image(images_agg)
      # Encode every feature vector along dim 1, then average the codes
      feats_encoded = torch.stack([self.encoder_feats(features[:, i]) for i in range(features.shape[1])], dim=0)
      feats_agg = torch.mean(feats_encoded, dim=0)
      # NOTE(review): cat on dim 0 then unsqueeze yields a (1, 2) row only when age and
      # lymph_count each hold a single value - presumably batch size 1; confirm.
      metadata = torch.cat([age, lymph_count], dim=0).unsqueeze(dim=0).float()
      all_feat = torch.cat([metadata, feats_agg], dim=1)
      y_feat = self.classifier_feats(all_feat)
      image_scalar = self.image_to_scalar(images_agg)
      feat_scalar = self.feat_to_scalar(feats_agg)
      # Gate weight of the image branch; the feature branch receives the complement
      pi_images = self.gate(torch.cat([metadata, image_scalar, feat_scalar], dim=1))
      pi_feat = 1 - pi_images
      y_hat = pi_images * y_images + pi_feat * y_feat

      return y_hat, reg_img

In [None]:
def weighted_binary_cross_entropy(output, target, weights=None):
  """
  Mean binary cross-entropy with optional per-class weights.

  `output` holds probabilities (clamped to [1e-7, 1 - 1e-7] before the log),
  `target` holds 0/1 labels. When given, `weights[0]` scales the negative-class
  term and `weights[1]` the positive-class term.
  """
  probs = torch.clamp(output, 1e-7, 1-1e-7)

  positive_term = target * torch.log(probs)
  negative_term = (1 - target) * torch.log(1 - probs)

  if weights is None:
    log_likelihood = positive_term + negative_term
  else:
    log_likelihood = weights[1] * positive_term + weights[0] * negative_term

  return -log_likelihood.mean()

In [None]:
def nmse_loss(output, target):
    """Normalized squared error: sum((target - output)^2) / sum(target^2)."""
    squared_error = (target - output) ** 2
    target_energy = target ** 2
    return squared_error.sum() / target_energy.sum()

# Training

## Cross validation

In [None]:
# Reload the training metadata and harmonise the gender spelling ('f' -> 'F').
dataframe = pd.read_csv('/content/trainset/trainset_true.csv')
dataframe['GENDER'] = dataframe['GENDER'].replace(['f'], ['F'])

# Folds used by the cross-validation below.
nb_folds = 5
train_datasets_list, val_datasets_list, n_list = k_fold_unbalanced(dataframe, nb_folds=nb_folds)

# Use the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def cv_multitask(train_datasets_list, val_datasets_list, n_list, name_model, device, batch_size=1, lr=1e-3, max_epochs=100, patience_early_stopping=10, patience_lr=5, lr_decay=0.1, score=0):
  """
  Cross-validate a multi-task model (classification + lymphocyte-count regression).

  For every fold a fresh model is trained on the concatenation of the other folds
  and validated on the held-out fold. The loss is the class-weighted binary
  cross-entropy plus the normalized squared error of the regression head.
  Samples are loaded one at a time; gradients are accumulated over `batch_size`
  samples before each optimizer step.

  Args:
    train_datasets_list: per-fold datasets with the training transform.
    val_datasets_list: per-fold datasets with the evaluation transform.
    n_list: per-fold (number of label-0 samples, number of label-1 samples).
    name_model: "ConvNextMultiTaskModel", "ResNetMultiTaskModel" or "VitMultiTaskLoraModel".
    device: torch device on which the model and the tensors are placed.
    batch_size: number of samples per optimizer step (gradient accumulation).
    lr: initial Adam learning rate.
    max_epochs: maximum number of epochs per fold.
    patience_early_stopping: epochs without validation-loss improvement before stopping the fold.
    patience_lr: epochs without improvement before the learning rate is multiplied by `lr_decay`.
    lr_decay: multiplicative learning-rate decay factor.
    score: initial value of the accumulated score.

  Returns:
    `score` plus the sum over folds of the balanced validation accuracy measured
    at the epoch of lowest validation loss, divided by the number of folds.

  Raises:
    ValueError: if `name_model` is not one of the supported names.
  """
  nb_folds = len(train_datasets_list)

  for fold in range(nb_folds):
    print('----------------------------------------------------------')
    print(f'Fold number {fold+1}')
    print('----------------------------------------------------------')
    if name_model == "ConvNextMultiTaskModel":
      model = ConvNextMultiTaskModel()
    elif name_model == "ResNetMultiTaskModel":
      model = ResNetMultiTaskModel()
    elif name_model == "VitMultiTaskLoraModel":
      model = VitMultiTaskLoraModel()
    else:
      raise ValueError(f"Unknown model name: {name_model!r}")
    model.to(device)
    valset = val_datasets_list[fold]
    valloader = DataLoader(valset, shuffle=True)
    n_0_val, n_1_val = n_list[fold]
    trainset = torch.utils.data.ConcatDataset([train_datasets_list[j] for j in range(nb_folds) if j != fold])
    trainloader = DataLoader(trainset, shuffle=True)
    n_0_train, n_1_train = np.array([n_list[j] for j in range(nb_folds) if j != fold]).sum(axis=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Inverse-frequency class weights: index 0 for label 0, index 1 for label 1
    n_train = n_0_train + n_1_train
    weights_train = torch.Tensor([n_train/n_0_train, n_train/n_1_train])
    n_val = n_0_val + n_1_val
    weights_val = torch.Tensor([n_val/n_0_val, n_val/n_1_val])

    best_val_loss = 1e8
    n_wt_improvement_lr = 0 # Consecutive epochs without beating the best loss since the last lr change
    n_wt_improvement_stop = 0 # Consecutive epochs without beating the best loss
    best_b_acc = 0

    for epoch in range(max_epochs):
      print(f'Epoch: {epoch+1}')

      model.train()
      train_loss = 0
      train_class_loss = 0
      train_reg_loss = 0
      train_acc = 0
      train_b_acc = 0
      loss = 0
      # `step` is deliberately distinct from the fold index of the outer loop
      for step, data in enumerate(tqdm(trainloader, desc="Train")):
        images, age, lymph_count, gender, features = data['images'], data['age'].to(device), data['lymph_count'].to(device), data['gender'].to(device), data['features'].to(device)
        y_hat, reg_hat = model(images, age, lymph_count, gender, features)
        y_hat, reg_hat = y_hat.to(device), reg_hat.to(device)
        y_true = data['label'].float().unsqueeze(0).to(device)
        loss_classification = weighted_binary_cross_entropy(y_hat, y_true, weights=weights_train)
        loss_reg = nmse_loss(reg_hat, lymph_count)
        loss += loss_classification + loss_reg
        train_loss += loss_classification.item() + loss_reg.item()
        train_class_loss += loss_classification.item()
        train_reg_loss += loss_reg.item()
        if (step+1) % batch_size == 0:
          # A full accumulation window: average and step
          loss /= batch_size
          loss.backward()
          optimizer.step()
          optimizer.zero_grad()
          loss = 0
        elif (step+1) == len(trainloader):
          # Last, incomplete window of the epoch
          loss /= (len(trainloader) % batch_size)
          loss.backward()
          optimizer.step()
          optimizer.zero_grad()
          loss = 0
        pred = torch.round(y_hat)
        correct = (pred.int() == y_true.int()).sum().item()
        # Balanced accuracy: each class contributes half, normalized by its own size
        n_class = n_0_train if y_true.int() == 0 else n_1_train
        train_b_acc += 0.5*(1/n_class)*correct
        train_acc += correct

      train_acc /= len(trainloader)
      train_loss /= len(trainloader)
      train_class_loss /= len(trainloader)
      train_reg_loss /= len(trainloader)
      train_b_acc = np.round(train_b_acc, 3)
      train_acc = np.round(train_acc, 3)
      train_loss = np.round(train_loss, 3)
      train_class_loss = np.round(train_class_loss, 3)
      train_reg_loss = np.round(train_reg_loss, 3)
      print(f'Train balanced accuracy: {train_b_acc}, Train accuracy: {train_acc}, Train loss: {train_loss}, Train classification loss: {train_class_loss}, Train regression loss: {train_reg_loss}')

      model.eval()
      with torch.no_grad():
        val_loss =  0
        val_acc = 0
        val_b_acc = 0
        val_reg_loss = 0
        val_class_loss = 0
        for data in tqdm(valloader, desc="Validation"):
          images, age, lymph_count, gender, features = data['images'], data['age'].to(device), data['lymph_count'].to(device), data['gender'].to(device), data['features'].to(device)
          y_hat, reg_hat = model(images, age, lymph_count, gender, features)
          y_hat = y_hat.to(device)
          reg_hat = reg_hat.to(device)
          # Same (1, 1) target shape as in training
          y_true = data['label'].float().unsqueeze(0).to(device)
          loss_classification = weighted_binary_cross_entropy(y_hat, y_true, weights=weights_val)
          loss_reg = nmse_loss(reg_hat, lymph_count)
          loss = loss_classification + loss_reg
          val_loss += loss.item()
          val_reg_loss += loss_reg.item()
          val_class_loss += loss_classification.item()
          pred = torch.round(y_hat)
          correct = (pred.int() == y_true.int()).sum().item()
          n_class = n_0_val if y_true.int() == 0 else n_1_val
          val_b_acc += 0.5*(1/n_class)*correct
          val_acc += correct

        val_acc /= len(valloader)
        val_loss /= len(valloader)
        val_reg_loss /= len(valloader)
        val_class_loss /= len(valloader)
        val_b_acc = np.round(val_b_acc, 3)
        val_acc = np.round(val_acc, 3)
        val_loss = np.round(val_loss, 3)
        val_class_loss = np.round(val_class_loss, 3)
        val_reg_loss = np.round(val_reg_loss, 3)
        print(f'Val balanced accuracy: {val_b_acc}, Val accuracy: {val_acc}, Val loss: {val_loss}, Val classification loss: {val_class_loss}, Val regression loss: {val_reg_loss}')

      if val_loss < best_val_loss:
        # New best epoch: reset both patience counters and remember its balanced accuracy
        n_wt_improvement_lr = 0
        n_wt_improvement_stop = 0
        best_val_loss = val_loss
        best_b_acc = val_b_acc

      else:
        n_wt_improvement_lr += 1
        n_wt_improvement_stop += 1
        if n_wt_improvement_stop == patience_early_stopping:
          break
        if n_wt_improvement_lr == patience_lr:
          n_wt_improvement_lr = 0
          for param_group in optimizer.param_groups:
            param_group['lr'] *= lr_decay
          print(f"Change of learning rate. Now lr = {optimizer.param_groups[0]['lr']}")

      print('-------------------------------------------------------------')

    score += best_b_acc

  print(f'The mean score obtained with cross-validation is {score/nb_folds}')
  return score/nb_folds

In [None]:
# Cross-validate the ResNet multi-task model on the folds built above, with default hyper-parameters.
cv_score = cv_multitask(
    train_datasets_list,
    val_datasets_list,
    n_list,
    name_model="ResNetMultiTaskModel",
    device=device,
)

In [None]:
def cv_onetask(train_datasets_list, val_datasets_list, n_list, name_model, device, batch_size=1, lr=1e-3, max_epochs=100, patience_early_stopping=10, patience_lr=5, lr_decay=0.1, score=0):
  """
  Cross-validate a single-task (classification only) model.

  For every fold a fresh model is trained on the concatenation of the training
  datasets of all the other folds and validated on that fold's validation
  dataset. The balanced validation accuracy observed at the epoch with the
  lowest validation loss is accumulated, and the mean over folds is returned.

  Args:
    train_datasets_list: one training dataset per fold.
    val_datasets_list: one validation dataset per fold.
    n_list: per-fold pairs unpacked as (n_0, n_1); they are used as the
      normalisers of the balanced accuracy and to build the loss weights
      (presumably the number of samples with label 0 / label 1 - confirm
      against the cell that builds n_list).
    name_model: name of the model to instantiate; only "ResNetOneTaskModel"
      is supported.
    device: device the model and the tensors are moved to.
    batch_size: number of loader iterations whose losses are averaged before
      one optimizer step (gradient accumulation; the loaders are created with
      the default DataLoader batch size).
    lr: initial learning rate of the Adam optimizer.
    max_epochs: maximum number of epochs per fold.
    patience_early_stopping: number of consecutive epochs without improvement
      of the validation loss after which the fold is stopped.
    patience_lr: number of consecutive epochs without improvement after which
      the learning rate is multiplied by lr_decay.
    lr_decay: multiplicative factor applied to the learning rate.
    score: initial value of the accumulated score.

  Returns:
    score / nb_folds after the best balanced accuracy of each fold was added.

  Raises:
    ValueError: if name_model is not a supported model name.
  """
  # Fail fast: an unknown name used to leave `model` unbound and crash later
  # with an UnboundLocalError on `model.to(device)`.
  if name_model != "ResNetOneTaskModel":
    raise ValueError(f"Unknown model name {name_model!r}; expected 'ResNetOneTaskModel'")

  nb_folds = len(train_datasets_list)

  for fold in range(nb_folds):
    print('----------------------------------------------------------')
    print(f'Fold number {fold+1}')
    print('----------------------------------------------------------')
    model = ResNetOneTaskModel()
    model.to(device)
    valset = val_datasets_list[fold]
    valloader = DataLoader(valset, shuffle=True)
    n_0_val, n_1_val = n_list[fold]
    trainset = torch.utils.data.ConcatDataset([train_datasets_list[j] for j in range(nb_folds) if j != fold])
    trainloader = DataLoader(trainset, shuffle=True)
    n_0_train, n_1_train = np.array([n_list[j] for j in range(nb_folds) if j != fold]).sum(axis=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Loss weights: total count divided by each of the two counts.
    n_train = n_0_train + n_1_train
    weights_train = torch.Tensor([n_train/n_0_train, n_train/n_1_train])
    n_val = n_0_val + n_1_val
    weights_val = torch.Tensor([n_val/n_0_val, n_val/n_1_val])

    best_val_loss = 1e8
    n_wt_improvement_lr = 0 # Consecutive epochs without beating the best validation loss since the last lr change
    n_wt_improvement_stop = 0 # Consecutive epochs without beating the best validation loss
    best_b_acc = 0

    for epoch in range(max_epochs):
      print(f'Epoch: {epoch+1}')

      model.train()
      train_loss =  0
      train_acc = 0
      train_b_acc = 0
      loss = 0
      # `step` (and not the fold index) counts the loader iterations.
      for step, data in enumerate(tqdm(trainloader, desc="Train")):
        images, age, lymph_count, gender, features = data['images'], data['age'].to(device), data['lymph_count'].to(device), data['gender'].to(device), data['features'].to(device)
        y_hat = model(images, age, lymph_count, gender, features).to(device)
        y_true = data['label'].float().unsqueeze(0).to(device)
        loss_classification = weighted_binary_cross_entropy(y_hat, y_true, weights=weights_train)
        loss += loss_classification
        train_loss += loss_classification.item()
        if (step+1) % batch_size ==0:
          # A full accumulation group is complete: average and update.
          loss /= batch_size
          loss.backward()
          optimizer.step()
          optimizer.zero_grad()
          loss = 0
        elif (step+1) == len(trainloader):
          # Last, incomplete group: average over the leftover iterations
          # (the remainder is non-zero here, otherwise the branch above ran).
          loss /= (len(trainloader) % batch_size)
          loss.backward()
          optimizer.step()
          optimizer.zero_grad()
          loss = 0
        pred = torch.round(y_hat)
        # Each correct prediction contributes 0.5/n_0 or 0.5/n_1 depending on its label.
        train_b_acc += 0.5*(1/n_0_train)*(pred.int() == y_true.int()).sum().item() if y_true.int() == 0 else 0.5*(1/n_1_train)*(pred.int() == y_true.int()).sum().item()
        train_acc += (pred.int() == y_true.int()).sum().item()

      train_acc /= len(trainloader)
      train_loss /= len(trainloader)
      train_b_acc = np.round(train_b_acc, 3)
      train_acc = np.round(train_acc, 3)
      train_loss = np.round(train_loss, 3)
      print(f'Train balanced accuracy: {train_b_acc}, Train accuracy: {train_acc}, Train loss: {train_loss}')

      model.eval()
      with torch.no_grad():
        val_loss =  0
        val_acc = 0
        val_b_acc = 0
        for data in tqdm(valloader, desc="Validation"):
          images, age, lymph_count, gender, features = data['images'], data['age'].to(device), data['lymph_count'].to(device), data['gender'].to(device), data['features'].to(device)
          y_hat = model(images, age, lymph_count, gender, features)
          y_hat = y_hat.to(device)
          y_true = data['label'].float().to(device)
          loss_classification = weighted_binary_cross_entropy(y_hat, y_true, weights=weights_val)
          loss = loss_classification
          val_loss += loss.item()
          pred = torch.round(y_hat)
          val_b_acc += 0.5*(1/n_0_val)*(pred.int() == y_true.int()).sum().item() if y_true.int() == 0 else 0.5*(1/n_1_val)*(pred.int() == y_true.int()).sum().item()
          val_acc += (pred.int() == y_true.int()).sum().item()

        val_acc /= len(valloader)
        val_loss /= len(valloader)
        val_b_acc = np.round(val_b_acc, 3)
        val_acc = np.round(val_acc, 3)
        val_loss = np.round(val_loss, 3)
        print(f'Val balanced accuracy: {val_b_acc}, Val accuracy: {val_acc}, Val loss: {val_loss}')

      # NOTE: val_loss was rounded to 3 decimals above, so improvements
      # smaller than 1e-3 do not count as an improvement here.
      if val_loss < best_val_loss:
        n_wt_improvement_lr = 0
        n_wt_improvement_stop = 0
        best_val_loss = val_loss
        best_b_acc = val_b_acc

      else:
        n_wt_improvement_lr += 1
        n_wt_improvement_stop += 1
        if n_wt_improvement_stop == patience_early_stopping:
          break
        if n_wt_improvement_lr == patience_lr:
          n_wt_improvement_lr = 0
          for param_group in optimizer.param_groups:
            param_group['lr'] *= lr_decay
          print(f"Change of learning rate. Now lr = {optimizer.param_groups[0]['lr']}")

      print('-------------------------------------------------------------')

    score += best_b_acc

  print(f'The mean score obtained with cross-validation is {score/nb_folds}')
  return score/nb_folds

In [None]:
# Cross-validate the single-task ResNet with the default hyper-parameters;
# the result is the mean over folds of the best balanced validation accuracy.
cv_score = cv_onetask(
  train_datasets_list,
  val_datasets_list,
  n_list,
  name_model="ResNetOneTaskModel",
  device=device,
)

## Train classically

In [None]:
# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Index of the fold held out for validation; every other fold is used for training.
i = 1

# Validation side: dataset, loader and loss weights (total / n_0, total / n_1).
valset = val_datasets_list[i]
valloader = DataLoader(valset)
n_0_val, n_1_val = n_list[i]
n_val = n_0_val + n_1_val
weights_val = torch.Tensor([n_val/n_0_val, n_val/n_1_val])

# Training side: concatenation of the remaining folds, with the matching summed counts.
trainset = torch.utils.data.ConcatDataset([train_datasets_list[j] for j in range(nb_folds) if j != i])
trainloader = DataLoader(trainset, shuffle=True)
n_0_train, n_1_train = np.sum([n_list[j] for j in range(nb_folds) if j != i], axis=0)
n_train = n_0_train + n_1_train
weights_train = torch.Tensor([n_train/n_0_train, n_train/n_1_train])

In [None]:
from tqdm import tqdm
from copy import deepcopy
from google.colab import files

def train_onetask(model, trainloader, valloader, device, n_0_train, n_1_train, n_0_val, n_1_val, batch_size=1, lr=1e-3, lr_decay=0.1, max_epochs=25, patience_lr=5, patience_early_stopping=10, wandb_log=False, wandb_name=None, wandb_config=None):
  """
  Train a single-task (classification only) model and keep its best weights.

  The model is optimised with AdamW on the weighted binary cross-entropy.
  After each epoch the model is evaluated on valloader; whenever the (rounded)
  validation loss improves, the weights are copied and saved to
  'best_model.pth'. The learning rate is multiplied by lr_decay after
  patience_lr consecutive epochs without improvement, and training stops after
  patience_early_stopping consecutive epochs without improvement.
  When a checkpoint was written during this run it is offered for download
  through google.colab.files at the end.

  Args:
    model: model called as model(images, age, lymph_count, gender, features).
    trainloader: loader over the training samples.
    valloader: loader over the validation samples.
    device: device the model and the tensors are moved to.
    n_0_train, n_1_train: counts used to weight the training loss and to
      normalise the balanced training accuracy (presumably the number of
      label-0 / label-1 training samples - confirm against the caller).
    n_0_val, n_1_val: same as above for the validation set.
    batch_size: number of loader iterations whose losses are averaged before
      one optimizer step (gradient accumulation).
    lr: initial learning rate.
    lr_decay: multiplicative factor applied to the learning rate.
    max_epochs: maximum number of epochs.
    patience_lr: epochs without improvement before the learning rate decays.
    patience_early_stopping: epochs without improvement before stopping.
    wandb_log: when True, metrics are logged to Weights & Biases.
    wandb_name: run name passed to wandb.init.
    wandb_config: config passed to wandb.init.

  Returns:
    The state dict with the lowest validation loss seen (the initial state
    dict if the validation loss never improved).

  Note:
    The balanced-accuracy update evaluates `y_true.int() == 0` as a boolean,
    which assumes one sample per loader iteration.
  """
  model.to(device)
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

  best_val_loss = 1e8
  n_wt_improvement_lr = 0 # Consecutive epochs without beating the best validation loss since the last lr change
  n_wt_improvement_stop = 0 # Consecutive epochs without beating the best validation loss
  best_dict = deepcopy(model.state_dict())

  # One path for both saving and downloading, plus a flag recording whether
  # this run actually wrote it (avoids downloading a missing or stale file).
  checkpoint_path = 'best_model.pth'
  checkpoint_saved = False

  n_train = n_0_train + n_1_train
  weights_train = torch.Tensor([n_train/n_0_train, n_train/n_1_train])

  n_val = n_0_val + n_1_val
  weights_val = torch.Tensor([n_val/n_0_val, n_val/n_1_val])

  if wandb_log:
    wandb.init(
        project="DLMI-challenge",
        name=wandb_name,
        config=wandb_config,
        entity='fous-du-wan'
    )

  for epoch in range(max_epochs):
    print(f'Epoch: {epoch+1}')

    model.train()
    train_loss =  0
    train_acc = 0
    train_b_acc = 0
    loss = 0
    for i, data in enumerate(tqdm(trainloader, desc="Train")):
      images, age, lymph_count, gender, features = data['images'], data['age'].to(device), data['lymph_count'].to(device), data['gender'].to(device), data['features'].to(device)
      y_hat = model(images, age, lymph_count, gender, features).to(device)
      y_true = data['label'].float().unsqueeze(0).to(device)
      loss_classification = weighted_binary_cross_entropy(y_hat, y_true, weights=weights_train)
      loss += loss_classification
      train_loss += loss_classification.item()
      if (i+1) % batch_size ==0:
        # A full accumulation group is complete: average and update.
        loss /= batch_size
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss = 0
      elif (i+1) == len(trainloader):
        # Last, incomplete group: average over the leftover iterations
        # (the remainder is non-zero here, otherwise the branch above ran).
        loss /= (len(trainloader) % batch_size)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss = 0
      pred = torch.round(y_hat)
      # Each correct prediction contributes 0.5/n_0 or 0.5/n_1 depending on its label.
      train_b_acc += 0.5*(1/n_0_train)*(pred.int() == y_true.int()).sum().item() if y_true.int() == 0 else 0.5*(1/n_1_train)*(pred.int() == y_true.int()).sum().item()
      train_acc += (pred.int() == y_true.int()).sum().item()

    train_acc /= len(trainloader)
    train_loss /= len(trainloader)
    train_b_acc = np.round(train_b_acc, 3)
    train_acc = np.round(train_acc, 3)
    train_loss = np.round(train_loss, 3)
    print(f'Train balanced accuracy: {train_b_acc}, Train accuracy: {train_acc}, Train loss: {train_loss}')

    model.eval()
    with torch.no_grad():
      val_loss =  0
      val_acc = 0
      val_b_acc = 0
      for data in tqdm(valloader, desc="Validation"):
        images, age, lymph_count, gender, features = data['images'], data['age'].to(device), data['lymph_count'].to(device), data['gender'].to(device), data['features'].to(device)
        y_hat = model(images, age, lymph_count, gender, features)
        y_hat = y_hat.to(device)
        y_true = data['label'].float().to(device)
        loss_classification = weighted_binary_cross_entropy(y_hat, y_true, weights=weights_val)
        loss = loss_classification
        val_loss += loss.item()
        pred = torch.round(y_hat)
        val_b_acc += 0.5*(1/n_0_val)*(pred.int() == y_true.int()).sum().item() if y_true.int() == 0 else 0.5*(1/n_1_val)*(pred.int() == y_true.int()).sum().item()
        val_acc += (pred.int() == y_true.int()).sum().item()

      val_acc /= len(valloader)
      val_loss /= len(valloader)
      val_b_acc = np.round(val_b_acc, 3)
      val_acc = np.round(val_acc, 3)
      val_loss = np.round(val_loss, 3)
      print(f'Val balanced accuracy: {val_b_acc}, Val accuracy: {val_acc}, Val loss: {val_loss}')

    if wandb_log:
      wandb.log({"val_loss": val_loss, "val_balanced_acc": val_b_acc, "train_loss": train_loss, 'train_balanced_acc': train_b_acc})

    # NOTE: val_loss was rounded to 3 decimals above, so improvements smaller
    # than 1e-3 do not count as an improvement here.
    if val_loss < best_val_loss:
      n_wt_improvement_lr = 0
      n_wt_improvement_stop = 0
      best_val_loss = val_loss
      best_dict = deepcopy(model.state_dict())
      torch.save(best_dict, checkpoint_path)
      checkpoint_saved = True

    else:
      n_wt_improvement_lr += 1
      n_wt_improvement_stop += 1
      if n_wt_improvement_stop == patience_early_stopping:
        break
      if n_wt_improvement_lr == patience_lr:
        n_wt_improvement_lr = 0
        for param_group in optimizer.param_groups:
          param_group['lr'] *= lr_decay
        print(f"Change of learning rate. Now lr = {optimizer.param_groups[0]['lr']}")

    print('-------------------------------------------------------------')

  if wandb_log:
    wandb.finish()

  # Download exactly the file written above, and only if this run wrote it.
  if checkpoint_saved:
    files.download(checkpoint_path)
  else:
    print('No checkpoint was saved during this run; skipping download.')

  return best_dict

In [None]:
# Fresh single-task ResNet, trained on the split prepared above.
model = ResNetOneTaskModel()
model.to(device)

# train_onetask returns the state dict with the lowest validation loss.
best_dict = train_onetask(
  model,
  trainloader,
  valloader,
  device,
  n_0_train=n_0_train,
  n_1_train=n_1_train,
  n_0_val=n_0_val,
  n_1_val=n_1_val,
)

In [None]:
def train_multitask(model, trainloader, valloader, device, n_0_train, n_1_train, n_0_val, n_1_val, batch_size=1, max_epochs=25, patience_lr=5, patience_early_stopping=10, lr=1e-3, lr_decay=0.1, wandb_log=False, wandb_name=None, wandb_config=None):
  """
  Train a multi-task (classification + regression) model and keep its best weights.

  The model returns a classification output and a regression output. The
  optimised loss is the sum of the weighted binary cross-entropy on the label
  and of nmse_loss between the regression output and lymph_count. After each
  epoch the model is evaluated on valloader; whenever the (rounded) total
  validation loss improves, the weights are copied and saved to
  'best_model.pth'. The learning rate is multiplied by lr_decay after
  patience_lr consecutive epochs without improvement, and training stops after
  patience_early_stopping consecutive epochs without improvement.
  When a checkpoint was written during this run it is offered for download
  through google.colab.files at the end.

  Args:
    model: model called as model(images, age, lymph_count, gender, features)
      and returning (y_hat, reg_hat).
    trainloader: loader over the training samples.
    valloader: loader over the validation samples.
    device: device the model and the tensors are moved to.
    n_0_train, n_1_train: counts used to weight the training loss and to
      normalise the balanced training accuracy (presumably the number of
      label-0 / label-1 training samples - confirm against the caller).
    n_0_val, n_1_val: same as above for the validation set.
    batch_size: number of loader iterations whose losses are averaged before
      one optimizer step (gradient accumulation).
    max_epochs: maximum number of epochs.
    patience_lr: epochs without improvement before the learning rate decays.
    patience_early_stopping: epochs without improvement before stopping.
    lr: initial learning rate.
    lr_decay: multiplicative factor applied to the learning rate.
    wandb_log: when True, metrics are logged to Weights & Biases.
    wandb_name: run name passed to wandb.init.
    wandb_config: config passed to wandb.init.

  Returns:
    The state dict with the lowest validation loss seen (the initial state
    dict if the validation loss never improved).

  Note:
    The balanced-accuracy update evaluates `y_true.int() == 0` as a boolean,
    which assumes one sample per loader iteration.
  """
  model.to(device)
  optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

  best_val_loss = 1e8
  n_wt_improvement_lr = 0 # Consecutive epochs without beating the best validation loss since the last lr change
  n_wt_improvement_stop = 0 # Consecutive epochs without beating the best validation loss
  best_dict = deepcopy(model.state_dict())

  # One path for both saving and downloading, plus a flag recording whether
  # this run actually wrote it (avoids downloading a missing or stale file).
  checkpoint_path = 'best_model.pth'
  checkpoint_saved = False

  n_train = n_0_train + n_1_train
  weights_train = torch.Tensor([n_train/n_0_train, n_train/n_1_train])

  n_val = n_0_val + n_1_val
  weights_val = torch.Tensor([n_val/n_0_val, n_val/n_1_val])

  if wandb_log:
    wandb.init(
        project="DLMI-challenge",
        name=wandb_name,
        config=wandb_config,
        entity='fous-du-wan'
    )

  for epoch in range(max_epochs):
    print(f'Epoch: {epoch+1}')

    model.train()
    train_loss = 0
    train_class_loss = 0
    train_reg_loss = 0
    train_acc = 0
    train_b_acc = 0
    loss = 0
    for i, data in enumerate(tqdm(trainloader, desc="Train")):
      images, age, lymph_count, gender, features = data['images'], data['age'].to(device), data['lymph_count'].to(device), data['gender'].to(device), data['features'].to(device)
      y_hat, reg_hat = model(images, age, lymph_count, gender, features)
      y_hat, reg_hat = y_hat.to(device), reg_hat.to(device)
      y_true = data['label'].float().unsqueeze(0).to(device)
      loss_classification = weighted_binary_cross_entropy(y_hat, y_true, weights=weights_train)
      loss_reg = nmse_loss(reg_hat, lymph_count)
      loss += loss_classification + loss_reg
      train_loss += loss_classification.item() + loss_reg.item()
      train_class_loss += loss_classification.item()
      train_reg_loss += loss_reg.item()
      if (i+1) % batch_size ==0:
        # A full accumulation group is complete: average and update.
        loss /= batch_size
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss = 0
      elif (i+1) == len(trainloader):
        # Last, incomplete group: average over the leftover iterations
        # (the remainder is non-zero here, otherwise the branch above ran).
        loss /= (len(trainloader) % batch_size)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss = 0
      pred = torch.round(y_hat)
      # Each correct prediction contributes 0.5/n_0 or 0.5/n_1 depending on its label.
      train_b_acc += 0.5*(1/n_0_train)*(pred.int() == y_true.int()).sum().item() if y_true.int() == 0 else 0.5*(1/n_1_train)*(pred.int() == y_true.int()).sum().item()
      train_acc += (pred.int() == y_true.int()).sum().item()

    train_acc /= len(trainloader)
    train_loss /= len(trainloader)
    train_class_loss /= len(trainloader)
    train_reg_loss /= len(trainloader)
    train_b_acc = np.round(train_b_acc, 3)
    train_acc = np.round(train_acc, 3)
    train_loss = np.round(train_loss, 3)
    train_class_loss = np.round(train_class_loss, 3)
    train_reg_loss = np.round(train_reg_loss, 3)
    print(f'Train balanced accuracy: {train_b_acc}, Train accuracy: {train_acc}, Train loss: {train_loss}, Train classification loss: {train_class_loss}, Train regression loss: {train_reg_loss}')

    model.eval()
    with torch.no_grad():
      val_loss =  0
      val_acc = 0
      val_b_acc = 0
      val_reg_loss = 0
      val_class_loss = 0
      for data in tqdm(valloader, desc="Validation"):
        images, age, lymph_count, gender, features = data['images'], data['age'].to(device), data['lymph_count'].to(device), data['gender'].to(device), data['features'].to(device)
        y_hat, reg_hat = model(images, age, lymph_count, gender, features)
        y_hat = y_hat.to(device)
        reg_hat = reg_hat.to(device)
        y_true = data['label'].float().to(device)
        loss_classification = weighted_binary_cross_entropy(y_hat, y_true, weights=weights_val)
        loss_reg = nmse_loss(reg_hat, lymph_count)
        loss = loss_classification + loss_reg
        val_loss += loss.item()
        val_reg_loss += loss_reg.item()
        val_class_loss += loss_classification.item()
        pred = torch.round(y_hat)
        val_b_acc += 0.5*(1/n_0_val)*(pred.int() == y_true.int()).sum().item() if y_true.int() == 0 else 0.5*(1/n_1_val)*(pred.int() == y_true.int()).sum().item()
        val_acc += (pred.int() == y_true.int()).sum().item()

      val_acc /= len(valloader)
      val_loss /= len(valloader)
      val_reg_loss /= len(valloader)
      val_class_loss /= len(valloader)
      val_b_acc = np.round(val_b_acc, 3)
      val_acc = np.round(val_acc, 3)
      val_loss = np.round(val_loss, 3)
      val_class_loss = np.round(val_class_loss, 3)
      val_reg_loss = np.round(val_reg_loss, 3)
      print(f'Val balanced accuracy: {val_b_acc}, Val accuracy: {val_acc}, Val loss: {val_loss}, Val classification loss: {val_class_loss}, Val regression loss: {val_reg_loss}')

    if wandb_log:
      wandb.log({"val_loss": val_loss, "val_balanced_acc": val_b_acc, "val_class_loss": val_class_loss, "val_reg_loss": val_reg_loss, "train_loss": train_loss, 'train_balanced_acc': train_b_acc, "train_class_loss": train_class_loss, "train_reg_loss": train_reg_loss})

    # NOTE: val_loss was rounded to 3 decimals above, so improvements smaller
    # than 1e-3 do not count as an improvement here.
    if val_loss < best_val_loss:
      n_wt_improvement_lr = 0
      n_wt_improvement_stop = 0
      best_val_loss = val_loss
      best_dict = deepcopy(model.state_dict())
      torch.save(best_dict, checkpoint_path)
      checkpoint_saved = True

    else:
      n_wt_improvement_lr += 1
      n_wt_improvement_stop += 1
      if n_wt_improvement_stop == patience_early_stopping:
        break
      if n_wt_improvement_lr == patience_lr:
        n_wt_improvement_lr = 0
        for param_group in optimizer.param_groups:
          param_group['lr'] *= lr_decay
        print(f"Change of learning rate. Now lr = {optimizer.param_groups[0]['lr']}")

    print('-------------------------------------------------------------')

  if wandb_log:
    wandb.finish()

  # Download exactly the file written above, and only if this run wrote it.
  if checkpoint_saved:
    files.download(checkpoint_path)
  else:
    print('No checkpoint was saved during this run; skipping download.')

  return best_dict

In [None]:
# Fresh multi-task ConvNext, trained on the split prepared above.
model = ConvNextMultiTaskModel()
model.to(device)

# train_multitask returns the state dict with the lowest validation loss.
best_dict = train_multitask(
  model,
  trainloader,
  valloader,
  device,
  n_0_train=n_0_train,
  n_1_train=n_1_train,
  n_0_val=n_0_val,
  n_1_val=n_1_val,
)

# Submission

In [None]:
# After training, `model` still holds the weights of the last epoch. Restore the
# weights with the lowest validation loss (returned by the training function)
# before predicting.
model.load_state_dict(best_dict)
model.eval()

test_cases = []
# Inference only: disable autograd so no computation graph is built per sample.
with torch.no_grad():
    for data in tqdm(testloader):
        images, age, lymph_count, gender, features = data['images'], data['age'].to(device), data['lymph_count'].to(device), data['gender'].to(device), data['features'].to(device)
        # The regression output is not needed for the submission.
        y_hat, reg_hat = model(images, age, lymph_count, gender, features)
        pred = torch.round(y_hat)
        test_cases.append((data['id'][0], pred.int().item()))

# One row per test patient: its id and the predicted class.
df_submission = pd.DataFrame(test_cases, columns=['Id', 'Predicted'])

df_submission.to_csv('test_transform.csv', index=False)