# Kaggle Challenge DLMI

This notebook is our model for the data challenge of out of distribution classification of histopathology patches.

In [None]:
import h5py
import torch
import random
import numpy as np
import pandas as pd
import torchmetrics
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import torch.nn as nn
from peft import LoraConfig, get_peft_model
import torchvision.models as models
import os

In [None]:
# Data files: training, validation and test datasets (HDF5 archives read with h5py below)
TRAIN_IMAGES_PATH = 'train.h5'
VAL_IMAGES_PATH = 'val.h5'
TEST_IMAGES_PATH = 'test.h5'

In [None]:
# We worked mostly on cpu (local machine) but much faster on gpu (under Kaggle gpu credits)
# Pick CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Working on {device}.')

Working on cpu.


To ensure full **reproducibility** of the experiments, we set the random seed across all libraries involved in randomness: `numpy`, `random`, `torch` and the Python environment. We also explicitly configure PyTorch for deterministic behavior and use a seeded generator for `DataLoader` operations.

In [None]:
SEED = 0 # seed for reproducibility

# seed the Python and NumPy global random number generators
random.seed(SEED)
np.random.seed(SEED)

# seed PyTorch: torch.manual_seed is the same function as torch.random.manual_seed,
# so a single call is enough for the default generator
torch.manual_seed(SEED)
# explicit CUDA seeding since the final training was done on Kaggle GPU (multi GPU compatible);
# manual_seed_all covers every visible device, including the current one
torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.deterministic = True # enforce deterministic behavior in PyTorch
torch.backends.cudnn.benchmark = False # disable benchmark optimizations that are not deterministic

# set Python hash seed
# NOTE: PYTHONHASHSEED is read when an interpreter starts, so this assignment only affects
# child processes launched from this kernel, not the hashing of the already running kernel
os.environ['PYTHONHASHSEED'] = str(SEED)

# create a torch.Generator for deterministic DataLoader shuffling
# (manual_seed returns the generator itself; assigning it avoids printing its repr as cell output)
g = torch.Generator().manual_seed(SEED)

<torch._C.Generator at 0x78c47bc5b070>

## 1. Introduction to the data
As explained in the `getting_started` baseline notebook, the dataset consists of patches of whole slide images which should be classified as either containing tumor or not. The training images come from 3 different centers (i.e. hospitals), while the validation set comes from another center and the test set from yet another center. The visual appearance of the patches differs noticeably between centers because of the slightly different staining procedures, conditions, and equipment at each hospital.

The data is stored in `.h5` files, which can be seen as a folder hierarchy with the following layout.
```
├── idx           # index of the image
│   └── img       # image in a tensor format
│   └── label     # binary label of the image
│   └── metadata  # some metadata on the images
```

The following is a visualization of how different the images look from the different centers.

In [None]:
# one empty slot per (center, label) pair; each slot later receives the first matching patch
# training file: centers 0, 3 and 4 - validation file: center 1 - labels are 0 and 1
train_images = {center: {label: None for label in (0, 1)} for center in (0, 3, 4)}
val_images = {center: {label: None for label in (0, 1)} for center in (1,)}

In [None]:
# Same logic as in the getting_started notebook: scan each file until one patch
# has been collected for every (center, label) slot of the matching dictionary
for img_data, data_path in ((train_images, TRAIN_IMAGES_PATH), (val_images, VAL_IMAGES_PATH)):
    with h5py.File(data_path, 'r') as hdf:
        for img_idx in list(hdf.keys()):
            sample = hdf.get(img_idx)
            label = int(np.array(sample.get('label'))) # binary label, 0 or 1
            center = int(np.array(sample.get('metadata'))[0]) # first metadata entry is the center id

            # keep only the first image seen for this (center, label) pair
            slots = img_data[center]
            if slots[label] is None:
                slots[label] = np.array(sample.get('img'))

            # stop scanning the file as soon as no slot is left empty
            missing = any(img is None for per_center in img_data.values() for img in per_center.values())
            if not missing:
                break

# single dictionary holding the training centers followed by the validation center
all_data = dict(train_images)
all_data.update(val_images)

In [None]:
# show one patch per (center, label): one column per center, one row per label
fig, axs = plt.subplots(2, 4, figsize=(20, 10))
center_ids = {center: idx for idx, center in enumerate(all_data.keys())}
for center, per_label in all_data.items():
    column = center_ids[center]
    for label, patch in per_label.items():
        ax = axs[label, column]
        # move the first axis to the end before display (presumably channels-first storage - TODO confirm)
        ax.imshow(np.moveaxis(patch, 0, -1).astype(np.float32))
        ax.axis('off')
        # only the top row carries the column title
        if label == 0:
            ax.set_title(f'Center {center}')
plt.show()

## 2. Baseline DINOv2 model and LoRA fine tuning

In [14]:
BATCH_SIZE = 16 # number of samples per batch, passed to the DataLoaders created below

The next two cells define the two custom PyTorch datasets given in the `getting_started` notebook, which we reuse to work with DINOv2:

- `BaselineDataset` loads raw image data from HDF5 files for training or for inference
- `PrecomputedDataset` loads pre-extracted features (e.g. from DINOv2) and their corresponding labels for training a small classifier (e.g. linear probing).

These dataset classes will be used in feature extraction and classifier training.

In [None]:
# dataset class for loading raw images and labels from HDF5 files
class BaselineDataset(Dataset):
    """Map-style dataset that reads one (image, label) pair per top-level HDF5 group.

    Args:
        dataset_path: path to the HDF5 file; every top-level key is one sample holding
            an 'img' entry and (in 'train' mode) a 'label' entry.
        preprocessing: callable applied to the float image tensor (resizing in our case).
        mode: 'train' to return the stored label; any other value returns None as label.
            NOTE: a None label cannot be batched by the default DataLoader collate function,
            so non-train mode is only suited to direct indexing or a custom collate_fn.
    """

    def __init__(self, dataset_path, preprocessing, mode):
        super().__init__()
        self.dataset_path = dataset_path # path to HDF5 file
        self.preprocessing = preprocessing # transformations to apply to the image (resizing in our case)
        self.mode = mode # 'train' or other ('test' for inference for instance)

        # read the sample ids once; the file itself is reopened on every access
        with h5py.File(self.dataset_path, 'r') as hdf:
            self.image_ids = list(hdf.keys())

    def __len__(self):
        """Return the number of samples (top-level keys) found in the file."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return (preprocessed float image tensor, label or None) for sample `idx`."""
        img_id = self.image_ids[idx]
        with h5py.File(self.dataset_path, 'r') as hdf:
            sample = hdf.get(img_id)
            # materialise the HDF5 dataset as one NumPy array before handing it to torch:
            # calling torch.tensor() directly on the h5py object is the line torch warned about
            # during training and goes through a slow generic conversion path
            img = torch.as_tensor(np.array(sample.get('img'))).float() # load the image
            label = np.array(sample.get('label')) if self.mode == 'train' else None # label the image
        return self.preprocessing(img).float(), label

In [None]:
# dataset class for loading precomputed features and corresponding labels
class PrecomputedDataset(Dataset):
    """In-memory dataset pairing precomputed feature vectors with their labels.

    Args:
        features: indexable collection of features, one entry per sample.
        labels: tensor of labels; a trailing dimension is appended so that each
            item is returned with shape (..., 1).
    """

    def __init__(self, features, labels):
        super().__init__()
        self.features = features # precomputed image features
        self.labels = labels[..., None] # trailing singleton dimension for binary classification

    def __len__(self):
        """Return the number of labelled samples."""
        return self.labels.shape[0]

    def __getitem__(self, idx):
        """Return (features, label cast to float) for sample `idx`."""
        feature, target = self.features[idx], self.labels[idx]
        return feature, target.float()

Then we load the pretrained feature extractor from DINOv2. DINOv2 is a self-supervised model for visual representation learning. It uses vision transformers (ViTs) to learn representative features without requiring a large amount of labeled data. Below we load the small model, DINOv2 ViT-S/14, since it is easier to fine-tune. DINOv2 is also known to transfer well to downstream tasks such as image classification, which is what we aim to do [[1]](https://arxiv.org/abs/2304.07193).

In [11]:
# fetch the DINOv2 ViT-S/14 backbone through torch.hub and move it to the selected device
feature_extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(device)
# evaluation mode for now; the LoRA cell below switches the wrapped model back to training mode
feature_extractor.eval()

Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


DinoVisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
    (norm): Identity()
  )
  (blocks): ModuleList(
    (0-11): 12 x NestedTensorBlock(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): MemEffAttention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
      (ls2): LayerScale()
      (drop_path2): Identity()
    )
  )
  (n

Below we fine-tune DINOv2 using **Low-Rank Adaptation** (**LoRA**) and attach a small classifier head for binary classification.

LoRA is a parameter-efficient fine-tuning method that inserts low rank adapters into specific layers of a large model. Instead of updating all weights of the DINOv2 backbone, we only train a small number of parameters, which reduces memory and computational costs. It is indeed relevant in our case where we worked mostly on CPU and with limited access to GPU.

We apply LoRA to the query/key/value (QKV) attention blocks of all 12 transformer layers in the ViT-S. A simple two-layer feedforward classifier (with 64 hidden units and a ReLU activation) is stacked on top of the frozen DINOv2 features for binary classification.

We use the AdamW optimizer and the binary cross-entropy loss (BCE) since our task is binary classification (cancer detection or not).

In [None]:
# Apply LoRA to specific attention layers in the DINOv2 transformer
lora_config = LoraConfig(
    r=4, # rank of the low rank decomposition. Chosen among [2, 4, 8, 16] to give the best test accuracy
    lora_alpha=16, # scaling factor of the adapter weights. Chosen among [8, 16, 32]
    lora_dropout=0.05,
    bias="none",
    target_modules=[f"blocks.{i}.attn.qkv" for i in range(12)] # apply LoRA to all attention QKV layers
)

# wrap the DINOv2 model with LoRA adapters so that only the LoRA parameters will be trained
# NOTE(review): this cell rebinds feature_extractor to the wrapped model, so running it twice
# wraps an already wrapped model - reload the backbone cell before re-running this one
feature_extractor = get_peft_model(feature_extractor, lora_config)
feature_extractor.print_trainable_parameters() # print the number of trainable parameters

# set the model to training mode (only LoRA params will be trained)
feature_extractor.train()

# binary classifier on top of the DINOv2 features (CLS token of dim 384)
classifier = nn.Sequential(
    nn.Linear(384, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
    nn.Sigmoid() # convert logits to probabilities (for BCE loss)
).to(device)

# AdamW optimizer: hand over only the parameters that require gradients (LoRA adapters + classifier)
# instead of every frozen backbone weight; frozen weights never receive a gradient, so the
# parameter updates are unchanged
trainable_backbone_params = [p for p in feature_extractor.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_backbone_params + list(classifier.parameters()),
    lr=1e-4
)

# Binary Cross Entropy loss (expects probabilities, hence the Sigmoid at the end of the classifier)
criterion = nn.BCELoss()

trainable params: 73,728 || all params: 22,130,304 || trainable%: 0.3332


Then, we load the training and validation datasets from HDF5 files using the `BaselineDataset` class. Images are resized to 98×98, i.e. 7×7 patches of 14×14 pixels, so that the image side is a multiple of the patch size of DINOv2 ViT-S/14. Then, we prepare a `DataLoader` for each set.

Due to the high computational cost of fine tuning large models (even with LoRA), we perform LoRA training on a $10$% random subset of the training data. This allows us to speed up training (and actually it gave us a slightly better test accuracy than when training with the full train dataset).

In [None]:
# both sets go through the same transformation: resize every image to 98x98 for DINOv2
preprocessing = transforms.Resize((98, 98))

# full training and validation datasets, read lazily from the HDF5 files
# (mode 'train' for both so that labels are returned)
train_dataset = BaselineDataset(TRAIN_IMAGES_PATH, preprocessing, 'train')
val_dataset = BaselineDataset(VAL_IMAGES_PATH, preprocessing, 'train')

# options shared by every loader of this cell: same batch size, same seeded generator
loader_kwargs = dict(batch_size=BATCH_SIZE, generator=g)

# loaders over the full sets: shuffled, complete batches only for training; fixed order for validation
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# LoRA fine-tuning only sees a random 10% of the training set
n_train = len(train_dataset)
subset_size = int(0.1 * n_train)
subset_indices = torch.randperm(n_train, generator=g)[:subset_size]
lora_subset = Subset(train_dataset, subset_indices)
lora_loader = DataLoader(lora_subset, shuffle=True, drop_last=True, **loader_kwargs)


Here is the training loop for fine-tuning the DINOv2-LoRA model. We record the training and validation loss and accuracy at each epoch, and save the model that achieves the lowest validation loss.

We perform "only" $5$ epochs first because training takes a long time (especially the validation phase, where more than $2000$ batches of images need to be evaluated), but also because the training loss drops drastically after only 1 epoch and does not decrease much afterwards, while the validation loss and accuracy do not improve much over time.

In [None]:
# initialize best validation loss for early model saving
best_val_loss = float("inf")
MODEL_SAVE_PATH = "dinov2_lora_trained.pth"
LORA_EPOCHS = 5 # number of passes over the LoRA training subset


for epoch in range(LORA_EPOCHS):

    # training loop
    classifier.train()
    feature_extractor.train()
    train_losses, train_accuracies = [], []
    for images, labels in tqdm(lora_loader, desc=f"Epoch {epoch+1} - LoRA Training"):
        images = images.to(device)
        labels = labels.to(device).float()

        # forward pass
        features = feature_extractor(images)
        if features.dim() == 3:
            features = features[:, 0, :] # extract CLS token

        preds = classifier(features).squeeze(-1)
        loss = criterion(preds, labels)

        # backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # store batch loss and accuracy (lora_loader drops the last partial batch,
        # so every batch has the same size and a plain mean over batches is exact)
        train_losses.append(loss.item())
        acc = ((preds > 0.5).float() == labels).float().mean().item()
        train_accuracies.append(acc)

    # average training metrics
    avg_train_loss = np.mean(train_losses)
    avg_train_acc = np.mean(train_accuracies)

    # validation loop
    classifier.eval()
    feature_extractor.eval()
    val_losses, val_accuracies = [], []
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images = images.to(device)
            labels = labels.to(device).float()

            # forward pass
            features = feature_extractor(images)
            if features.dim() == 3:
                features = features[:, 0, :]

            preds = classifier(features).squeeze(-1)
            loss = criterion(preds, labels)

            # val_loader keeps the last (possibly smaller) batch: repeat each batch value once per
            # sample so that the averages below are per-sample means rather than means of batch means
            n_samples = len(labels)
            val_losses.extend([loss.item()] * n_samples)
            acc = ((preds > 0.5).float() == labels).float().mean().item()
            val_accuracies.extend([acc] * n_samples)

    # average validation metrics (sample-weighted)
    avg_val_loss = np.mean(val_losses)
    avg_val_acc = np.mean(val_accuracies)

    # print the training and validation loss and accuracy for this epoch
    print(f"[LoRA Training] Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Train Acc: {avg_train_acc:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {avg_val_acc:.4f}")

    # save model if it improves on validation loss
    if avg_val_loss < best_val_loss:
        print(f'New best loss {best_val_loss:.4f} -> {avg_val_loss:.4f}')
        best_val_loss = avg_val_loss
        torch.save({
            "feature_extractor": feature_extractor.state_dict(),
            "classifier": classifier.state_dict()
        }, MODEL_SAVE_PATH)


After fine tuning the DINOv2 backbone using LoRA and training a lightweight classifier on top, we now reload the best saved checkpoint to reuse the trained models for evaluation or inference. This ensures reproducibility and avoids retraining the entire model.

After fine tuning the DINOv2+LoRA model, we load the best saved checkpoint (since the last feature_extractor obtained is not necessarily the one we are going to use for inference). We define the DINOv2+LoRA model as before but with loaded `feature_extractor` and `classifier`.

In [None]:
# load saved model checkpoint
# weights_only=True restricts unpickling to tensors and plain containers, which is all the
# checkpoint holds (two state dicts)
checkpoint = torch.load("dinov2_lora_trained.pth", map_location="cpu", weights_only=True)

# reinitialize the DINOv2 feature extractor
feature_extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(device)
# redefine the classifier architecture that was trained on top of DINOv2 features
classifier = nn.Sequential(
    nn.Linear(384, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
    nn.Sigmoid()
).to(device)

# apply LoRA to the DINOv2 model (same configuration as during training)
lora_config = LoraConfig(
    r=4,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=[f"blocks.{i}.attn.qkv" for i in range(12)],
)
feature_extractor = get_peft_model(feature_extractor, lora_config)

# the reloaded models are meant for evaluation / inference: switch them to eval mode
# so that the LoRA dropout is disabled
feature_extractor.eval()
classifier.eval()

# load the fine tuned weights into the feature extractor and classifier
feature_extractor.load_state_dict(checkpoint["feature_extractor"])
classifier.load_state_dict(checkpoint["classifier"])

## 3. Train a linear probing on top of the DINOv2+LoRA model

To speed up the training of the final classifiers, we precompute the features extracted by the DINOv2+LoRA model. The following function `precompute_features_lora` extracts features for all images in a `DataLoader`.

In [None]:
def precompute_features_lora(dataloader, model, device):
    """Run `model` over every batch of `dataloader` and collect features and labels.

    Args:
        dataloader: iterable yielding (inputs, labels) batches.
        model: feature extractor; it is switched to evaluation mode by this function.
        device: device the inputs are moved to before the forward pass.

    Returns:
        Tuple (features, labels) of tensors concatenated along the batch dimension.
        Features are moved to the CPU; for 3-dimensional model outputs only the
        first token (index 0 of dimension 1) is kept.
    """
    xs, ys = [], [] # will store extracted features xs and labels ys
    model.eval() # set the model to evaluation mode

    with torch.no_grad():
        for x, y in tqdm(dataloader):
            feats = model(x.to(device)) # the model extracts features from the image

            # if the output has a CLS token (since we work with a ViT), then keep only the CLS embedding
            if feats.dim() == 3:
                feats = feats[:, 0, :]

            xs.append(feats.cpu()) # move features to cpu
            # labels coming out of a DataLoader are already tensors: as_tensor reuses them as is,
            # whereas torch.tensor(tensor) makes a copy and emits a copy-construct UserWarning
            ys.append(torch.as_tensor(y)) # store the label

    # concatenate all batches into a single tensor of features and labels
    return torch.cat(xs), torch.cat(ys)

To efficiently train the linear probing classifier on top of the frozen DINOv2 + LoRA model, we first precompute the features from the training dataset.

In [None]:
# load the training dataset with class BaselineDataset (preprocessing = resizing)
raw_train_dataset = BaselineDataset(TRAIN_IMAGES_PATH, preprocessing, 'train')
# create a DataLoader for the training dataset
# drop_last=False: this is a one-off feature extraction, so every sample must be kept
# (drop_last=True silently discards up to BATCH_SIZE - 1 samples of the last partial batch)
raw_train_dataloader = DataLoader(raw_train_dataset, shuffle=True, batch_size=BATCH_SIZE, generator=g, drop_last=False)
# extract features using the pretrained DINOv2+LoRA model
train_features, train_labels = precompute_features_lora(raw_train_dataloader, feature_extractor, device)
# cache the (features, labels) pair on disk so the classifier can be trained without the backbone
torch.save((train_features, train_labels), 'train_features_dinov2_lora.pth')

We do the same with the validation dataset.

In [None]:
# load the validation dataset with class BaselineDataset (preprocessing = resizing)
raw_val_dataset = BaselineDataset(VAL_IMAGES_PATH, preprocessing, 'train')
# create a DataLoader for the validation dataset
# drop_last=False: validation features must cover the whole set; dropping the last partial
# batch would silently remove samples from every later validation metric
raw_val_dataloader = DataLoader(raw_val_dataset, batch_size=BATCH_SIZE, shuffle=False, generator=g, drop_last=False)
# extract features using the pretrained DINOv2+LoRA model
val_features, val_labels = precompute_features_lora(raw_val_dataloader, feature_extractor, device)
# cache the (features, labels) pair on disk next to the training features
torch.save((val_features, val_labels), 'val_features_dinov2_lora.pth')

After precomputation of the DINOv2+LoRA features, we can load them to train the linear probing classifier. We use the `PrecomputedDataset` class to create the training and validation datasets (with precomputed features) and create as before a training and validation dataloader.

In [None]:
# load precomputed feature tensors and labels from disk for both train and validation sets
# weights_only=True: the files only hold a tuple of tensors, so the restricted loader is sufficient
train_features, train_labels = torch.load('train_features_dinov2_lora.pth', weights_only=True)
val_features, val_labels = torch.load('val_features_dinov2_lora.pth', weights_only=True)

# wrap the tensors into a custom Dataset class for PyTorch compatibility
# NOTE: train_dataset / val_dataset now refer to PrecomputedDataset objects and no longer to
# the BaselineDataset objects created for LoRA fine tuning
train_dataset = PrecomputedDataset(train_features, train_labels)
val_dataset = PrecomputedDataset(val_features, val_labels)

# training: shuffled, complete batches only
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, generator=g, drop_last=True)
# validation: keep the last partial batch so that every sample is evaluated
# (the training loop weights each batch by its size, so a smaller last batch is handled correctly)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE, generator=g, drop_last=False)

We now define a simple classifier head `linear_probing`, on top of the precomputed DINOv2+LoRA features. The classifier consists of two fully connected layers with a ReLU activation (for non linearity), followed by a sigmoid for binary classification.

In [None]:
# classification head trained on the precomputed features ("linear probing")
#   input  : 384-dimensional features from DINOv2
#   hidden : 128 units (was tuned among [64, 128, 256])
#   output : single neuron with sigmoid for binary classification
probe_layers = [
    nn.Linear(384, 128),
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),
]
linear_probing = nn.Sequential(*probe_layers).to(device)

We kept almost the same training loop for training the classification head as in the `getting_started` notebook, using the Adam optimizer and BCE loss (known to perform well for binary classification, especially when the classes are balanced). We apply early stopping based on the validation loss: the model that achieves the lowest validation loss is saved. In our fine-tuning experiments, we observed that in most cases the best model was reached during the first epochs (training never actually ran for the full $100$ epochs), indicating that the model overfits the training set if we do not use early stopping.

Note: we did not tune the learning rate or the weight decay, which could also improve the final predictions if carefully chosen.

In [None]:
# Training configuration
OPTIMIZER = 'Adam' # name of the torch.optim class, looked up with getattr in the training cell
OPTIMIZER_PARAMS = {'lr': 0.001, 'weight_decay': 0.005} # keyword arguments forwarded to the optimizer
LOSS = 'BCELoss' # name of the torch.nn loss class, looked up with getattr in the training cell
METRIC = 'Accuracy' # NOTE(review): not referenced by the training cell below, which uses BinaryAccuracy directly
NUM_EPOCHS = 100 # maximum number of epochs
PATIENCE = 10 # training stops once this many epochs have passed since the best validation loss

In [None]:
from torchmetrics.classification import BinaryAccuracy

# optimizer and loss classes are resolved by name from the training configuration
optimizer_cls = getattr(torch.optim, OPTIMIZER)
optimizer = optimizer_cls(linear_probing.parameters(), **OPTIMIZER_PARAMS)
criterion = getattr(torch.nn, LOSS)()

# accuracy is tracked separately for the training and validation phases
train_metric_fn = BinaryAccuracy().to(device)
val_metric_fn = BinaryAccuracy().to(device)

# best validation loss seen so far and the epoch it was reached at (drives early stopping)
min_loss, best_epoch = float('inf'), 0

# make sure the head is in training mode and that all of its parameters are trainable
linear_probing.train()
for param in linear_probing.parameters():
    param.requires_grad = True

for epoch in range(NUM_EPOCHS):
    # ---- training phase ----
    linear_probing.train()
    train_metrics, train_losses = [], []

    for train_x, train_y in tqdm(train_dataloader, leave=False):
        train_x = train_x.to(device)
        train_y = train_y.squeeze(-1).to(device)

        optimizer.zero_grad()
        train_pred = linear_probing(train_x).squeeze(-1)
        loss = criterion(train_pred, train_y)
        loss.backward()
        optimizer.step()

        # repeat each batch value once per sample so that the epoch means are per-sample means
        n_samples = len(train_y)
        train_losses += [loss.item()] * n_samples
        train_metric = train_metric_fn(train_pred, train_y.int())
        train_metrics += [train_metric.item()] * n_samples

    print(f'Epoch train [{epoch+1}/{NUM_EPOCHS}] | Loss {np.mean(train_losses):.4f} | Metric {np.mean(train_metrics):.4f}')

    # ---- validation phase ----
    linear_probing.eval()
    val_metrics, val_losses = [], []

    for val_x, val_y in tqdm(val_dataloader, leave=False):
        val_x = val_x.to(device)
        val_y = val_y.to(device).squeeze(-1)

        with torch.no_grad():
            val_pred = linear_probing(val_x).squeeze(-1)
            loss = criterion(val_pred, val_y)

        n_samples = len(val_y)
        val_losses += [loss.item()] * n_samples
        val_metric = val_metric_fn(val_pred, val_y.int())
        val_metrics += [val_metric.item()] * n_samples

    mean_val_loss = np.mean(val_losses)
    print(f'Epoch valid [{epoch+1}/{NUM_EPOCHS}] | Loss {mean_val_loss:.4f} | Metric {np.mean(val_metrics):.4f}')

    # checkpoint the head whenever the validation loss reaches a new minimum
    if mean_val_loss < min_loss:
        print(f'New best loss {min_loss:.4f} -> {mean_val_loss:.4f}')
        min_loss, best_epoch = mean_val_loss, epoch
        torch.save(linear_probing.state_dict(), 'best_model_lora.pth')

    # early stopping: give up once PATIENCE epochs have passed since the last improvement
    if epoch - best_epoch >= PATIENCE:
        print("Early stopping triggered.")
        break

  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [1/30] | Loss 0.0745 | Metric 0.9742


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [1/30] | Loss 0.1996 | Metric 0.9237
New best loss inf -> 0.1996


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [2/30] | Loss 0.0708 | Metric 0.9757


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [2/30] | Loss 0.2121 | Metric 0.9163


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [3/30] | Loss 0.0695 | Metric 0.9763


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [3/30] | Loss 0.2057 | Metric 0.9190


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [4/30] | Loss 0.0692 | Metric 0.9766


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [4/30] | Loss 0.2165 | Metric 0.9178


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [5/30] | Loss 0.0687 | Metric 0.9765


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [5/30] | Loss 0.1874 | Metric 0.9286
New best loss 0.1996 -> 0.1874


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [6/30] | Loss 0.0690 | Metric 0.9762


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [6/30] | Loss 0.2026 | Metric 0.9184


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [7/30] | Loss 0.0687 | Metric 0.9765


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [7/30] | Loss 0.2197 | Metric 0.9238


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [8/30] | Loss 0.0687 | Metric 0.9765


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [8/30] | Loss 0.1950 | Metric 0.9251


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [9/30] | Loss 0.0689 | Metric 0.9764


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [9/30] | Loss 0.1876 | Metric 0.9263


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [10/30] | Loss 0.0687 | Metric 0.9763


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [10/30] | Loss 0.2047 | Metric 0.9206


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [11/30] | Loss 0.0690 | Metric 0.9758


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [11/30] | Loss 0.1948 | Metric 0.9257


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [12/30] | Loss 0.0686 | Metric 0.9762


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [12/30] | Loss 0.1965 | Metric 0.9258


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [13/30] | Loss 0.0689 | Metric 0.9762


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [13/30] | Loss 0.2165 | Metric 0.9153


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [14/30] | Loss 0.0690 | Metric 0.9762


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [14/30] | Loss 0.2569 | Metric 0.9070


  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch train [15/30] | Loss 0.0689 | Metric 0.9761


  0%|          | 0/2181 [00:00<?, ?it/s]

Epoch valid [15/30] | Loss 0.1991 | Metric 0.9236
Early stopping triggered.


In [None]:
# restore the weights of the best epoch (lowest validation loss) saved by the training loop
# map_location=device lets a checkpoint saved on GPU be reloaded on a CPU-only machine (and vice versa)
linear_probing.load_state_dict(torch.load('best_model_lora.pth', map_location=device, weights_only=True))
linear_probing.eval() # inference mode
linear_probing.to(device)

Sequential(
  (0): Linear(in_features=384, out_features=128, bias=True)
  (1): ReLU()
  (2): Linear(in_features=128, out_features=1, bias=True)
  (3): Sigmoid()
)

## 4. Train a classification head on top of the pretrained ResNet34

We also worked with several other CNN-based models, such as the well-known ResNet models. In particular, `resnet34` gave us the best test accuracy when ensembled with the DINOv2+LoRA model. Below is the code for training a classification head on top of the pretrained model.

First, we prepare the data loaders as for the DINOv2 model, but with a different `resnet_preprocessing`, resizing images to 224x224 since it is the expected dimension by ResNet. As before, we train with only $10$% of the training dataset, but also used $30$% of the validation dataset to accelerate the validation phase during training.

In [None]:
# ResNet-specific preprocessing: resize to 224x224
resnet_preprocessing = transforms.Resize((224, 224))

# full training and validation datasets (mode 'train' for both so that labels are returned)
raw_train_dataset_resnet = BaselineDataset(TRAIN_IMAGES_PATH, resnet_preprocessing, 'train')
raw_val_dataset_resnet = BaselineDataset(VAL_IMAGES_PATH, resnet_preprocessing, 'train')

# keep a random 10% of the training set (second part of the split)
train_total_len = len(raw_train_dataset_resnet)
train_subset_len = int(0.1 * train_total_len)
train_split_lengths = [train_total_len - train_subset_len, train_subset_len]
train_subset_resnet = random_split(raw_train_dataset_resnet, train_split_lengths, generator=g)[1]

# keep a random 30% of the validation set (first part of the split)
val_total_len = len(raw_val_dataset_resnet)
val_subset_len = int(0.3 * val_total_len)
val_split_lengths = [val_subset_len, val_total_len - val_subset_len]
val_subset_resnet = random_split(raw_val_dataset_resnet, val_split_lengths, generator=g)[0]

# loaders over the two subsets: shuffled for training, fixed order for validation,
# complete batches only in both cases
resnet_loader_kwargs = dict(batch_size=BATCH_SIZE, generator=g, drop_last=True)
train_dataloader_resnet = DataLoader(train_subset_resnet, shuffle=True, **resnet_loader_kwargs)
val_dataloader_resnet = DataLoader(val_subset_resnet, shuffle=False, **resnet_loader_kwargs)

Below we load the pretrained resnet34 and replace its last fully connected layer with a 2-layer MLP as a classification head.

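We use the AdamW optimizer and BCE loss for training. Note that the optimizer is given all the parameters of the network, so the pretrained backbone is updated together with the new classification head.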

In [None]:
# ImageNet-pretrained ResNet34; the explicit weights enum replaces the deprecated
# `pretrained=True` flag and selects the same IMAGENET1K_V1 checkpoint
resnet_model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)

# swap the final fully connected layer for a small binary classification head
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Sequential(
    nn.Linear(in_features, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
    nn.Sigmoid() # probabilities, as expected by BCELoss
)
resnet_model = resnet_model.to(device)

# NOTE: the optimizer receives all parameters of the network, so the pretrained backbone
# is fine tuned together with the new head
resnet_optimizer = torch.optim.AdamW(resnet_model.parameters(), lr=1e-4)
resnet_criterion = nn.BCELoss()

In [None]:
NUM_EPOCHS_RESNET = 10
# lowest validation loss seen so far and the epoch it was reached at
min_loss_resnet, best_epoch_resnet = float('inf'), 0

# train resnet34 and keep the weights of the epoch with the lowest validation loss
for epoch in range(NUM_EPOCHS_RESNET):
    # ---- training phase ----
    resnet_model.train()
    train_losses_resnet, train_metrics_resnet = [], []

    for images, labels in tqdm(train_dataloader_resnet, leave=False, desc=f'ResNet Epoch {epoch+1}'):
        images = images.to(device)
        labels = labels.to(device)
        targets = labels.float()

        resnet_optimizer.zero_grad()
        preds = resnet_model(images).squeeze(-1) # forward pass
        loss = resnet_criterion(preds, targets)
        loss.backward() # backpropagation
        resnet_optimizer.step()

        # per-batch loss and accuracy (probabilities thresholded at 0.5)
        train_losses_resnet.append(loss.item())
        preds_binary = (preds > 0.5).float()
        acc = (preds_binary == targets).float().mean().item()
        train_metrics_resnet.append(acc)

    print(f'ResNet Train Epoch {epoch+1}: Loss {np.mean(train_losses_resnet):.4f}, Acc {np.mean(train_metrics_resnet):.4f}')

    # ---- validation phase ----
    resnet_model.eval()
    val_losses_resnet, val_metrics_resnet = [], []

    with torch.no_grad():
        for images, labels in tqdm(val_dataloader_resnet, leave=False, desc=f'ResNet Val Epoch {epoch+1}'):
            images = images.to(device)
            labels = labels.to(device)
            targets = labels.float()

            preds = resnet_model(images).squeeze(-1) # forward pass
            loss = resnet_criterion(preds, targets)

            # per-batch loss and accuracy
            val_losses_resnet.append(loss.item())
            preds_binary = (preds > 0.5).float()
            acc = (preds_binary == targets).float().mean().item()
            val_metrics_resnet.append(acc)

    mean_val_loss_resnet = np.mean(val_losses_resnet)
    print(f'ResNet Val Epoch {epoch+1}: Loss {mean_val_loss_resnet:.4f}, Acc {np.mean(val_metrics_resnet):.4f}')

    # checkpoint the model whenever the validation loss reaches a new minimum
    if mean_val_loss_resnet < min_loss_resnet:
        print(f'New best ResNet loss: {min_loss_resnet:.4f} -> {mean_val_loss_resnet:.4f}')
        min_loss_resnet, best_epoch_resnet = mean_val_loss_resnet, epoch
        torch.save(resnet_model.state_dict(), 'best_resnet34.pth')


ResNet Epoch 1:   0%|          | 0/625 [00:00<?, ?it/s]

  img = torch.tensor(hdf.get(img_id).get('img')).float()


ResNet Train Epoch 1: Loss 0.1490, Acc 0.9481


ResNet Val Epoch 1:   0%|          | 0/654 [00:00<?, ?it/s]

ResNet Val Epoch 1: Loss 0.2502, Acc 0.8984
New best ResNet loss: inf -> 0.2502


ResNet Epoch 2:   0%|          | 0/625 [00:00<?, ?it/s]

ResNet Train Epoch 2: Loss 0.0820, Acc 0.9726


ResNet Val Epoch 2:   0%|          | 0/654 [00:00<?, ?it/s]

ResNet Val Epoch 2: Loss 0.2452, Acc 0.9100
New best ResNet loss: 0.2502 -> 0.2452


ResNet Epoch 3:   0%|          | 0/625 [00:00<?, ?it/s]

ResNet Train Epoch 3: Loss 0.0572, Acc 0.9811


ResNet Val Epoch 3:   0%|          | 0/654 [00:00<?, ?it/s]

ResNet Val Epoch 3: Loss 0.2457, Acc 0.9156


ResNet Epoch 4:   0%|          | 0/625 [00:00<?, ?it/s]

ResNet Train Epoch 4: Loss 0.0367, Acc 0.9867


ResNet Val Epoch 4:   0%|          | 0/654 [00:00<?, ?it/s]

ResNet Val Epoch 4: Loss 0.3059, Acc 0.9106


ResNet Epoch 5:   0%|          | 0/625 [00:00<?, ?it/s]

ResNet Train Epoch 5: Loss 0.0297, Acc 0.9897


ResNet Val Epoch 5:   0%|          | 0/654 [00:00<?, ?it/s]

ResNet Val Epoch 5: Loss 0.3384, Acc 0.8874


ResNet Epoch 6:   0%|          | 0/625 [00:00<?, ?it/s]

ResNet Train Epoch 6: Loss 0.0211, Acc 0.9925


ResNet Val Epoch 6:   0%|          | 0/654 [00:00<?, ?it/s]

ResNet Val Epoch 6: Loss 0.2848, Acc 0.9248


ResNet Epoch 7:   0%|          | 0/625 [00:00<?, ?it/s]

ResNet Train Epoch 7: Loss 0.0254, Acc 0.9912


ResNet Val Epoch 7:   0%|          | 0/654 [00:00<?, ?it/s]

ResNet Val Epoch 7: Loss 0.3408, Acc 0.9046


ResNet Epoch 8:   0%|          | 0/625 [00:00<?, ?it/s]

ResNet Train Epoch 8: Loss 0.0199, Acc 0.9939


ResNet Val Epoch 8:   0%|          | 0/654 [00:00<?, ?it/s]

ResNet Val Epoch 8: Loss 0.2173, Acc 0.9370
New best ResNet loss: 0.2452 -> 0.2173


ResNet Epoch 9:   0%|          | 0/625 [00:00<?, ?it/s]

ResNet Train Epoch 9: Loss 0.0187, Acc 0.9949


ResNet Val Epoch 9:   0%|          | 0/654 [00:00<?, ?it/s]

ResNet Val Epoch 9: Loss 0.7923, Acc 0.8088


ResNet Epoch 10:   0%|          | 0/625 [00:00<?, ?it/s]

ResNet Train Epoch 10: Loss 0.0141, Acc 0.9946


ResNet Val Epoch 10:   0%|          | 0/654 [00:00<?, ?it/s]

ResNet Val Epoch 10: Loss 0.3592, Acc 0.9081


We now load the best checkpoint (the one with the lowest validation loss) of the ResNet-34 model with its trained classification head, and switch the model to evaluation mode to perform inference.

In [None]:
# Restore the weights saved by the training loop when the validation loss improved.
# - map_location=device remaps the stored tensors onto the device in use, so a
#   checkpoint written during a GPU run can also be loaded on a CPU-only machine.
# - weights_only=True restricts unpickling to tensors and plain containers, which is
#   all a state_dict contains.
best_resnet_state = torch.load('best_resnet34.pth', map_location=device, weights_only=True)
resnet_model.load_state_dict(best_resnet_state)

# evaluation mode for inference
resnet_model.eval()

  resnet_model.load_state_dict(torch.load('/kaggle/working/best_resnet34.pth'))


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

## 5. Final Prediction

We now generate predictions on the test set using three configurations:

- DINOv2+LoRA + linear probing
- ResNet34 + trained classification head
- Ensemble: the average prediction of the two models

For the two individual models, we save both the binary predictions and the raw scores to CSV files; for the ensemble, we save only the binary predictions. These files can be used later for evaluation, ensemble analysis or correlation heatmaps.

In [28]:
# gather the key of every entry stored at the root of the test file
with h5py.File(TEST_IMAGES_PATH, 'r') as test_file:
    test_ids = [patch_key for patch_key in test_file.keys()]

In [None]:
# Accumulators for the output tables: a 'Pred' table (score thresholded at 0.5) for
# each of the three configurations, and a 'Raw' table (unthresholded score) for each
# of the two individual models.
solutions_dinov2_lora = {'ID': [], 'Pred': []}
solutions_dinov2_lora_raw = {'ID': [], 'Raw': []}
solutions_resnet34 = {'ID': [], 'Pred': []}
solutions_resnet34_raw = {'ID': [], 'Raw': []}
solutions_dinov2_lora_resnet34 = {'ID': [], 'Pred': []}

# Inference loop. No gradients are needed here, so the forward passes run under
# torch.no_grad() (as in the validation loops) instead of building an autograd graph
# for every patch and detaching afterwards.
# NOTE(review): resnet_model is switched to eval mode in an earlier cell; confirm that
# feature_extractor and linear_probing are in eval mode as well before running this cell.
with h5py.File(TEST_IMAGES_PATH, 'r') as hdf, torch.no_grad():
    for test_id in tqdm(test_ids):
        patch_id = int(test_id)

        # read the patch from disk once; torch.tensor copies the array, so each model's
        # preprocessing still receives its own independent tensor
        img_array = np.array(hdf.get(test_id).get('img'))

        # DINOv2+LoRA prediction
        img_dinov2_lora = preprocessing(torch.tensor(img_array)).unsqueeze(0).float()
        pred_dinov2_lora = linear_probing(feature_extractor(img_dinov2_lora.to(device))).cpu()
        score_dinov2_lora = pred_dinov2_lora.item()

        solutions_dinov2_lora['ID'].append(patch_id)
        solutions_dinov2_lora['Pred'].append(int(score_dinov2_lora > 0.5))
        solutions_dinov2_lora_raw['ID'].append(patch_id)
        solutions_dinov2_lora_raw['Raw'].append(score_dinov2_lora)

        # ResNet34 prediction
        img_resnet = resnet_preprocessing(torch.tensor(img_array)).unsqueeze(0).float()
        pred_resnet = resnet_model(img_resnet.to(device)).cpu()
        score_resnet = pred_resnet.item()

        solutions_resnet34['ID'].append(patch_id)
        solutions_resnet34['Pred'].append(int(score_resnet > 0.5))
        solutions_resnet34_raw['ID'].append(patch_id)
        solutions_resnet34_raw['Raw'].append(score_resnet)

        # Ensemble prediction (average of the two models)
        pred_avg = (score_dinov2_lora + score_resnet) / 2.0
        solutions_dinov2_lora_resnet34['ID'].append(patch_id)
        solutions_dinov2_lora_resnet34['Pred'].append(int(pred_avg > 0.5))

In [31]:
def _as_id_indexed_frame(records):
    """Return `records` as a DataFrame indexed by 'ID'.

    `records` is either the dict of lists filled by the inference loop or, when this
    cell is executed again, the DataFrame produced by a previous run. In the second
    case 'ID' is already the index and no longer a column, so calling
    set_index('ID') again would raise KeyError; the frame is then returned as is.
    """
    frame = pd.DataFrame(records)
    if 'ID' in frame.columns:
        frame = frame.set_index('ID')
    return frame


# Each name is rebound to its ID-indexed DataFrame and written to its own CSV file.
solutions_dinov2_lora = _as_id_indexed_frame(solutions_dinov2_lora)
solutions_dinov2_lora.to_csv('prediction_dinov2_lora.csv')

solutions_dinov2_lora_raw = _as_id_indexed_frame(solutions_dinov2_lora_raw)
solutions_dinov2_lora_raw.to_csv('raw_prediction_dinov2_lora.csv')

solutions_resnet34 = _as_id_indexed_frame(solutions_resnet34)
solutions_resnet34.to_csv('prediction_resnet34.csv')

solutions_resnet34_raw = _as_id_indexed_frame(solutions_resnet34_raw)
solutions_resnet34_raw.to_csv('raw_prediction_resnet34.csv')

solutions_dinov2_lora_resnet34 = _as_id_indexed_frame(solutions_dinov2_lora_resnet34)
solutions_dinov2_lora_resnet34.to_csv('prediction_dinov2_lora_resnet34.csv')