In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from pathlib import Path
import random
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [2]:
!pip install -q timm medmnist

import timm
import medmnist
from medmnist import INFO, Evaluator  # INFO provides metadata for each dataset

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for fire (setup.py) ... [?25l[?25hdone


In [3]:
# Global list of tasks – ensure the order matches your training.
DATASETS = [
    'pathmnist',
    'dermamnist',
    'octmnist',
    'pneumoniamnist',
    'retinamnist',
    'breastmnist',
    'bloodmnist',
    'tissuemnist',
    'organamnist',
    'organcmnist',
    'organsmnist'
]

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.GELU(),
            nn.Linear(dim, dim)
        )
        self.act = nn.GELU()
    
    def forward(self, x):
        return self.act(x + self.block(x))

class MedMNISTMultiTaskModel(nn.Module):
    def __init__(self, backbone_name='convnext_tiny', pretrained=True):
        super().__init__()

        self.task_outputs = {
            'pathmnist': 9,
            'dermamnist': 7,
            'octmnist': 4,
            'pneumoniamnist': 2,
            'retinamnist': 5,
            'breastmnist': 2,
            'bloodmnist': 8,
            'tissuemnist': 8,
            'organamnist': 11,
            'organcmnist': 11,
            'organsmnist': 11
        }

        # Create ConvNeXt backbone using timm
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            num_classes=0,
            drop_path_rate=0.1
        )

        # Modify stem for 28x28 images (3 channels)
        self.backbone.stem[0] = nn.Conv2d(
            3, 96, kernel_size=3, stride=1, padding=1
        )
        
        feat_dim = self.backbone.num_features  # e.g., 768 for convnext_tiny

        # Create task-specific heads
        self.heads = nn.ModuleDict()
        for task, num_classes in self.task_outputs.items():
            self.heads[task] = nn.Sequential(
                nn.LayerNorm(feat_dim),
                nn.Linear(feat_dim, feat_dim),
                nn.GELU(),
                nn.Dropout(0.2),
                nn.Sequential(
                    nn.Linear(feat_dim, feat_dim // 4),
                    nn.GELU(),
                    nn.Linear(feat_dim // 4, feat_dim),
                    nn.Sigmoid()
                ),
                nn.Sequential(
                    nn.LayerNorm(feat_dim),
                    nn.Linear(feat_dim, feat_dim * 4),
                    nn.GELU(),
                    nn.Dropout(0.2),
                    nn.Linear(feat_dim * 4, feat_dim),
                ),
                nn.Sequential(
                    nn.LayerNorm(feat_dim),
                    nn.Linear(feat_dim, num_classes)
                )
            )

    def forward(self, x, task_ids=None):
        features = self.backbone(x)  # shape: (B, feat_dim)
        if task_ids is not None:
            outputs = torch.zeros(x.size(0), max(self.task_outputs.values())).to(x.device)
            for i, tid in enumerate(task_ids):
                task_name = DATASETS[tid.item()]
                num_cls = self.task_outputs[task_name]
                out = self.heads[task_name](features[i:i+1])
                outputs[i, :num_cls] = out
            return outputs
        else:
            return {task: head(features) for task, head in self.heads.items()}

# (Optional: print a summary)
# model = MedMNISTMultiTaskModel()
# print(model)

In [4]:
def load_best_model(checkpoint_path, model):
    """
    Load the best model checkpoint into the provided model.
    Use strict=False to allow for minor differences.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    missing, unexpected = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)
    best_f1 = checkpoint.get('best_f1', 0.0)
    epoch = checkpoint.get('epoch', -1)
    print(f"Loaded checkpoint from epoch {epoch} with F1: {best_f1:.4f}")
    return model, best_f1

In [5]:
class MedMNISTTestDataset(Dataset):
    def __init__(self, npz_path, transform=None):
        data = np.load(npz_path)
        self.images = data['test_images']  # shape: (N, 28, 28)
        if self.images.ndim == 3:
            self.images = np.expand_dims(self.images, axis=-1)
        if self.images.shape[-1] == 1:
            self.images = np.tile(self.images, (1, 1, 1, 3))
        self.images = self.images.astype(np.float32) / 255.0
        self.transform = transform

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        image = torch.tensor(image).permute(2, 0, 1)  # Convert HWC to CHW
        return image

In [6]:
test_dataloaders = {}
for task in DATASETS:
    npz_path = Path("/kaggle/input/tensor-reloaded-multi-task-med-mnist/data") / f"{task}.npz"
    dataset = MedMNISTTestDataset(npz_path)
    loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=4, pin_memory=True)
    test_dataloaders[task] = loader
    print(f"{task}: {len(dataset)} test images, {len(loader)} batches")

pathmnist: 7180 test images, 29 batches
dermamnist: 2005 test images, 8 batches
octmnist: 1000 test images, 4 batches
pneumoniamnist: 624 test images, 3 batches
retinamnist: 400 test images, 2 batches
breastmnist: 156 test images, 1 batches
bloodmnist: 3421 test images, 14 batches
tissuemnist: 47280 test images, 185 batches
organamnist: 17778 test images, 70 batches
organcmnist: 8268 test images, 33 batches
organsmnist: 8829 test images, 35 batches


In [7]:
checkpoint_path = "/kaggle/input/improving-accuracy-with-multihead-backbone/best_model.pth" 

model = MedMNISTMultiTaskModel(backbone_name='convnext_tiny', pretrained=True)
model.to(device)
model, best_f1 = load_best_model(checkpoint_path, model)
model.eval()

model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

  checkpoint = torch.load(checkpoint_path, map_location=device)


Missing keys: ['heads.pathmnist.1.weight', 'heads.pathmnist.1.bias', 'heads.pathmnist.4.0.weight', 'heads.pathmnist.4.0.bias', 'heads.pathmnist.4.2.weight', 'heads.pathmnist.4.2.bias', 'heads.pathmnist.5.0.weight', 'heads.pathmnist.5.0.bias', 'heads.pathmnist.5.1.weight', 'heads.pathmnist.5.1.bias', 'heads.pathmnist.5.4.weight', 'heads.pathmnist.5.4.bias', 'heads.pathmnist.6.0.weight', 'heads.pathmnist.6.0.bias', 'heads.pathmnist.6.1.weight', 'heads.pathmnist.6.1.bias', 'heads.dermamnist.1.weight', 'heads.dermamnist.1.bias', 'heads.dermamnist.4.0.weight', 'heads.dermamnist.4.0.bias', 'heads.dermamnist.4.2.weight', 'heads.dermamnist.4.2.bias', 'heads.dermamnist.5.0.weight', 'heads.dermamnist.5.0.bias', 'heads.dermamnist.5.1.weight', 'heads.dermamnist.5.1.bias', 'heads.dermamnist.5.4.weight', 'heads.dermamnist.5.4.bias', 'heads.dermamnist.6.0.weight', 'heads.dermamnist.6.0.bias', 'heads.dermamnist.6.1.weight', 'heads.dermamnist.6.1.bias', 'heads.octmnist.1.weight', 'heads.octmnist.1.bias

MedMNISTMultiTaskModel(
  (backbone): ConvNeXt(
    (stem): Sequential(
      (0): Conv2d(3, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (stages): Sequential(
      (0): ConvNeXtStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): ConvNeXtBlock(
            (conv_dw): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
            (norm): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=96, out_features=384, bias=True)
              (act): GELU()
              (drop1): Dropout(p=0.0, inplace=False)
              (norm): Identity()
              (fc2): Linear(in_features=384, out_features=96, bias=True)
              (drop2): Dropout(p=0.0, inplace=False)
            )
            (shortcut): Identity()
            (drop_path): Identity()
          )
          (1): ConvNeXtBlock

In [8]:
# # Example: unfreeze the last block of the backbone and set differential learning rates
# for name, param in model.backbone.named_parameters():
#     if "block4" in name:  # Adjust this to the appropriate layer name in your backbone
#         param.requires_grad = True

# backbone_params = [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad]
# head_params = [p for n, p in model.named_parameters() if "backbone" not in n]

# optimizer = torch.optim.AdamW([
#     {"params": backbone_params, "lr": 1e-5},
#     {"params": head_params, "lr": 1e-4}
# ], weight_decay=0.05)

In [9]:
submission_rows = []
global_id = 0

with torch.no_grad(), torch.amp.autocast('cuda'):
    for task in DATASETS:
        print(f"Processing task: {task}")
        loader = test_dataloaders[task]
        image_idx = 0
        task_idx = DATASETS.index(task)
        for images in tqdm(loader, desc=f"Task: {task}"):
            images = images.to(device, non_blocking=True)
            batch_size_current = images.size(0)
            task_ids = torch.full((batch_size_current,), task_idx, dtype=torch.long, device=device)
            outputs = model(images, task_ids=task_ids)
            num_cls = model.task_outputs[task]
            preds = outputs[:, :num_cls].argmax(dim=1).cpu().numpy()
            for pred in preds:
                submission_rows.append([global_id, image_idx, task, int(pred)])
                global_id += 1
                image_idx += 1

submission_df = pd.DataFrame(submission_rows, columns=["id", "id_image_in_task", "task_name", "label"])
print("Total submission rows:", len(submission_df))
submission_df.to_csv("submission.csv", index=False)
print("Submission file saved as submission.csv")

Processing task: pathmnist


Task: pathmnist: 100%|██████████| 29/29 [00:10<00:00,  2.66it/s]


Processing task: dermamnist


Task: dermamnist: 100%|██████████| 8/8 [00:02<00:00,  2.72it/s]


Processing task: octmnist


Task: octmnist: 100%|██████████| 4/4 [00:01<00:00,  2.61it/s]


Processing task: pneumoniamnist


Task: pneumoniamnist: 100%|██████████| 3/3 [00:01<00:00,  2.95it/s]


Processing task: retinamnist


Task: retinamnist: 100%|██████████| 2/2 [00:00<00:00,  2.84it/s]


Processing task: breastmnist


Task: breastmnist: 100%|██████████| 1/1 [00:00<00:00,  2.69it/s]


Processing task: bloodmnist


Task: bloodmnist: 100%|██████████| 14/14 [00:04<00:00,  2.89it/s]


Processing task: tissuemnist


Task: tissuemnist: 100%|██████████| 185/185 [01:04<00:00,  2.86it/s]


Processing task: organamnist


Task: organamnist: 100%|██████████| 70/70 [00:24<00:00,  2.82it/s]


Processing task: organcmnist


Task: organcmnist: 100%|██████████| 33/33 [00:11<00:00,  2.88it/s]


Processing task: organsmnist


Task: organsmnist: 100%|██████████| 35/35 [00:12<00:00,  2.85it/s]


Total submission rows: 96941
Submission file saved as submission.csv
