# Weighted Ensemble for MedMNIST Multi‑Task Challenge
This notebook implements a weighted ensemble of two pre-trained MedMNIST multi-task models. Instead of taking a simple average of their softmax outputs, we weight each model in proportion to the validation score stored in its checkpoint and average the probabilities with these weights. The final predictions are written to a submission CSV file. This approach may boost performance by giving more influence to the better-performing model.

In [1]:
# Install timm (provides the ConvNeXt backbone) and medmnist (provides the
# per-task INFO metadata) quietly, plus torchvision for the transforms import.
!pip install -q timm medmnist
!pip install torchvision

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for fire (setup.py) ... [?25l[?25hdone


In [2]:
import random
import warnings
from pathlib import Path

import medmnist
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from medmnist import INFO
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Run on the default CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using : ", device)

Using :  cuda


In [3]:
# Global list of tasks – must match your training order!
# A task's position in this list is the integer task id used by the model.
DATASETS = (
    "pathmnist dermamnist octmnist pneumoniamnist retinamnist breastmnist "
    "bloodmnist tissuemnist organamnist organcmnist organsmnist"
).split()

class ResidualBlock(nn.Module):
    """Residual MLP block computing ``GELU(x + MLP(x))``.

    The inner MLP is Linear -> LayerNorm -> GELU -> Linear, all of width
    ``dim``, so the input's last dimension must equal ``dim``.
    Not referenced by ``MedMNISTMultiTaskModel`` as defined in this notebook.
    """

    def __init__(self, dim):
        super().__init__()
        # Attribute names and layer order determine the state_dict keys
        # (block.0.*, block.1.*, block.3.*), so keep them unchanged to stay
        # compatible with checkpoints that contain this block.
        layers = [
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        ]
        self.block = nn.Sequential(*layers)
        self.act = nn.GELU()

    def forward(self, x):
        residual = self.block(x)
        return self.act(x + residual)

class MedMNISTMultiTaskModel(nn.Module):
    """Shared timm backbone with one classification head per MedMNIST task.

    Args:
        backbone_name: timm model name. The stem replacement below assumes a
            ConvNeXt-style backbone exposing ``backbone.stem[0]`` as a conv.
        pretrained: Ask timm for pretrained backbone weights.
        head_type: ``'bottleneck'`` builds LayerNorm -> Linear(d, d // 4) ->
            GELU -> Dropout -> Linear; any other value builds LayerNorm -> Linear.
        dropout_rate: Dropout probability inside bottleneck heads.
        stochastic_depth_rate: Passed to timm as ``drop_path_rate``.
    """

    def __init__(self, backbone_name='convnext_tiny', pretrained=True, head_type='bottleneck', dropout_rate=0.2, stochastic_depth_rate=0.1):
        super().__init__()

        # Number of classes per task, taken from MedMNIST INFO metadata.
        self.task_outputs = {task: len(INFO[task]['label']) for task in DATASETS}

        # Feature-extractor backbone (num_classes=0 drops timm's classifier).
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            num_classes=0,
            drop_path_rate=stochastic_depth_rate
        )
        # Adapt the stem for 28x28 RGB images with a stride-1 3x3 conv,
        # presumably to avoid the default stem's aggressive downsampling.
        # Reuse the original stem width (96 for convnext_tiny) so the rest of
        # the backbone still receives the channel count it expects.
        stem_channels = self.backbone.stem[0].out_channels
        self.backbone.stem[0] = nn.Conv2d(3, stem_channels, kernel_size=3, stride=1, padding=1)
        feat_dim = self.backbone.num_features

        # Task-specific heads. Layer order defines the state_dict keys
        # (heads.<task>.<idx>.*), so it must match the training checkpoint.
        self.heads = nn.ModuleDict()
        for task, num_classes in self.task_outputs.items():
            if head_type == 'bottleneck':
                head = nn.Sequential(
                    nn.LayerNorm(feat_dim),
                    nn.Linear(feat_dim, feat_dim // 4),
                    nn.GELU(),
                    nn.Dropout(dropout_rate),
                    nn.Linear(feat_dim // 4, num_classes)
                )
            else:
                head = nn.Sequential(
                    nn.LayerNorm(feat_dim),
                    nn.Linear(feat_dim, num_classes)
                )
            self.heads[task] = head

    def forward(self, x, task_ids=None):
        """Run the backbone once and apply the task head(s).

        Args:
            x: Image batch, assumed (B, 3, H, W) — TODO confirm for other inputs.
            task_ids: Optional length-B integer tensor of indices into DATASETS.

        Returns:
            With ``task_ids``: a float32 tensor of shape (B, max_num_classes)
            where row i holds the logits of sample i's task head in its first
            ``num_classes`` columns and zeros elsewhere.
            Without ``task_ids``: a dict mapping every task name to its
            (B, num_classes) logits.
        """
        features = self.backbone(x)  # shape: (B, feat_dim)
        if task_ids is None:
            return {task: head(features) for task, head in self.heads.items()}

        max_cls = max(self.task_outputs.values())
        outputs = torch.zeros(x.size(0), max_cls, device=x.device)
        # Evaluate each head once on all samples of its task instead of once
        # per sample; heads are row-wise (LayerNorm/Linear), so the result is
        # the same as the per-sample loop.
        for tid in torch.unique(task_ids).tolist():
            task_name = DATASETS[tid]
            num_cls = self.task_outputs[task_name]
            mask = (task_ids == tid).to(x.device)
            logits = self.heads[task_name](features[mask])
            # Under autocast the head may return float16; masked assignment
            # (index_put) requires matching dtypes.
            outputs[mask, :num_cls] = logits.to(outputs.dtype)
        return outputs

In [4]:
def load_checkpoint(checkpoint_path, model, map_location=None, strict=False):
    """
    Load a training checkpoint into ``model`` and return its validation metric.

    The checkpoint must be a dict containing ``'model_state_dict'``; the keys
    ``'epoch'``, ``'best_f1'`` and ``'best_metric'`` are optional.

    Args:
        checkpoint_path: Path to the saved ``.pth`` file.
        model: Module whose parameters are overwritten in place.
        map_location: Passed to ``torch.load``; defaults to the notebook-level
            ``device``.
        strict: Passed to ``load_state_dict``. Defaults to ``False`` for
            backward compatibility; mismatches then trigger a warning.

    Returns:
        ``(model, metric)`` where ``metric`` is ``best_f1`` if present, else
        ``best_metric``, else ``0.0``, as a Python float.
    """
    if map_location is None:
        map_location = device
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    missing, unexpected = model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
    print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', -1)}; missing keys: {missing}, unexpected keys: {unexpected}")
    if missing or unexpected:
        # With strict=False, missing parameters silently keep their fresh
        # (random) initialisation, so predictions from those layers are not
        # trained. Surface this loudly instead of only in a long print line.
        warnings.warn(
            f"Checkpoint {checkpoint_path}: {len(missing)} missing keys left "
            f"uninitialised and {len(unexpected)} unexpected keys ignored; "
            "the model architecture likely differs from the one trained.",
            stacklevel=2,
        )
    metric = checkpoint.get('best_f1', checkpoint.get('best_metric', 0.0))
    return model, float(metric)

# Checkpoints of the two ensemble members (adjust to your Kaggle input folders).
checkpoint_path1, checkpoint_path2 = (
    "/kaggle/input/improving-accuracy-with-multihead-backbone/best_model.pth",
    "/kaggle/input/fork-of-convnext-tiny-notebook8a8c996b9/best_model.pth",
)

In [5]:
class MedMNISTTestDataset(Dataset):
    """Serve the ``test_images`` array of a MedMNIST ``.npz`` file as CHW tensors.

    Images without a channel axis get one, single-channel images are
    replicated to 3 channels, and all pixel values are divided by 255 once
    at construction time.

    Args:
        npz_path: Path to an ``.npz`` file containing a ``test_images`` array.
        transform: Optional callable applied to each HWC float32 numpy image
            before the HWC -> CHW conversion; it must return an HWC array.
    """

    def __init__(self, npz_path, transform=None):
        # np.load on an .npz returns an NpzFile that keeps the file open;
        # the context manager closes it once the array has been read.
        with np.load(npz_path) as data:
            images = data['test_images']  # shape: (N, 28, 28)
        if images.ndim == 3:
            images = np.expand_dims(images, axis=-1)  # -> (N, H, W, 1)
        if images.shape[-1] == 1:
            images = np.tile(images, (1, 1, 1, 3))  # grayscale -> 3 channels
        self.images = images.astype(np.float32) / 255.0
        self.transform = transform

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        # Convert HWC to CHW
        image = torch.tensor(image).permute(2, 0, 1)
        return image
        
# Create one test DataLoader per task. shuffle=False is required: the
# inference loop numbers images (id_image_in_task) in loader order.
test_dataloaders = {}
base_path = Path("/kaggle/input/tensor-reloaded-multi-task-med-mnist/data")
# Pinned host memory only speeds up host->GPU copies; on a CPU-only runtime
# it is wasted work and makes DataLoader emit a warning.
use_pin_memory = device.type == "cuda"
for task in DATASETS:
    npz_path = base_path / f"{task}.npz"
    dataset = MedMNISTTestDataset(npz_path)
    loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=4, pin_memory=use_pin_memory)
    test_dataloaders[task] = loader
    print(f"{task}: {len(dataset)} test images, {len(loader)} batches")

pathmnist: 7180 test images, 29 batches
dermamnist: 2005 test images, 8 batches
octmnist: 1000 test images, 4 batches
pneumoniamnist: 624 test images, 3 batches
retinamnist: 400 test images, 2 batches
breastmnist: 156 test images, 1 batches
bloodmnist: 3421 test images, 14 batches
tissuemnist: 47280 test images, 185 batches
organamnist: 17778 test images, 70 batches
organcmnist: 8268 test images, 33 batches
organsmnist: 8829 test images, 35 batches


In [6]:
# Instantiate two models with the same architecture. pretrained=False: the
# checkpoints loaded below supply the backbone weights (the printed missing-key
# report lists head parameters, not backbone ones), so timm's ImageNet weights
# would only cost a ~114 MB download and fail without internet access.
model1 = MedMNISTMultiTaskModel(backbone_name='convnext_tiny', pretrained=False)
model2 = MedMNISTMultiTaskModel(backbone_name='convnext_tiny', pretrained=False)
model1.to(device)
model2.to(device)

# Load checkpoints for each model
model1, metric1 = load_checkpoint(checkpoint_path1, model1)
model2, metric2 = load_checkpoint(checkpoint_path2, model2)

# Evaluation mode disables head dropout and backbone stochastic depth.
model1.eval()
model2.eval()

# Weight each model in proportion to its stored validation metric (e.g. F1).
total_metric = metric1 + metric2
if total_metric > 0:
    w1 = metric1 / total_metric
    w2 = metric2 / total_metric
else:
    # Neither checkpoint stored a positive metric (load_checkpoint defaults to
    # 0.0). Zero weights would make avg_probs all zeros and every argmax class
    # 0, so fall back to a plain average instead.
    w1 = w2 = 0.5
print(f"Model1 weight: {w1:.4f}, Model2 weight: {w2:.4f}")

model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

  checkpoint = torch.load(checkpoint_path, map_location=device)


Loaded checkpoint from epoch 30; missing keys: ['heads.pathmnist.1.weight', 'heads.pathmnist.1.bias', 'heads.pathmnist.4.weight', 'heads.pathmnist.4.bias', 'heads.dermamnist.1.weight', 'heads.dermamnist.1.bias', 'heads.dermamnist.4.weight', 'heads.dermamnist.4.bias', 'heads.octmnist.1.weight', 'heads.octmnist.1.bias', 'heads.octmnist.4.weight', 'heads.octmnist.4.bias', 'heads.pneumoniamnist.1.weight', 'heads.pneumoniamnist.1.bias', 'heads.pneumoniamnist.4.weight', 'heads.pneumoniamnist.4.bias', 'heads.retinamnist.1.weight', 'heads.retinamnist.1.bias', 'heads.retinamnist.4.weight', 'heads.retinamnist.4.bias', 'heads.breastmnist.1.weight', 'heads.breastmnist.1.bias', 'heads.breastmnist.4.weight', 'heads.breastmnist.4.bias', 'heads.bloodmnist.1.weight', 'heads.bloodmnist.1.bias', 'heads.bloodmnist.4.weight', 'heads.bloodmnist.4.bias', 'heads.tissuemnist.1.weight', 'heads.tissuemnist.1.bias', 'heads.tissuemnist.4.weight', 'heads.tissuemnist.4.bias', 'heads.organamnist.1.weight', 'heads.org

In [7]:
submission_rows = []
global_id = 0

# Mixed precision only on CUDA: autocast('cuda') on a CPU-only runtime just
# warns and disables itself, so key it to the actual device instead.
use_amp = device.type == "cuda"
with torch.no_grad(), torch.amp.autocast(device_type=device.type, enabled=use_amp):
    for task_idx, task in enumerate(DATASETS):
        print(f"Processing task: {task}")
        loader = test_dataloaders[task]
        num_cls = model1.task_outputs[task]  # same for both models
        image_idx = 0
        for images in tqdm(loader, desc=f"Task: {task}"):
            images = images.to(device, non_blocking=True)
            # Every image in this loader belongs to the same task.
            task_ids = torch.full((images.size(0),), task_idx, dtype=torch.long, device=device)

            # Logits are padded to the widest task; keep only this task's
            # columns before the softmax so padding does not dilute it.
            probs1 = torch.softmax(model1(images, task_ids=task_ids)[:, :num_cls], dim=1)
            probs2 = torch.softmax(model2(images, task_ids=task_ids)[:, :num_cls], dim=1)

            # Weighted average of probabilities; prediction is its argmax.
            avg_probs = w1 * probs1 + w2 * probs2
            preds = avg_probs.argmax(dim=1).cpu().tolist()

            # Record predictions (loaders are unshuffled, so order = image index).
            for pred in preds:
                submission_rows.append([global_id, image_idx, task, int(pred)])
                global_id += 1
                image_idx += 1

submission_df = pd.DataFrame(submission_rows, columns=["id", "id_image_in_task", "task_name", "label"])
print("Total submission rows:", len(submission_df))
submission_df.to_csv("submission.csv", index=False)
print("Submission file saved as submission.csv")

Processing task: pathmnist


Task: pathmnist: 100%|██████████| 29/29 [00:09<00:00,  3.22it/s]


Processing task: dermamnist


Task: dermamnist: 100%|██████████| 8/8 [00:02<00:00,  3.49it/s]


Processing task: octmnist


Task: octmnist: 100%|██████████| 4/4 [00:01<00:00,  3.22it/s]


Processing task: pneumoniamnist


Task: pneumoniamnist: 100%|██████████| 3/3 [00:00<00:00,  3.54it/s]


Processing task: retinamnist


Task: retinamnist: 100%|██████████| 2/2 [00:00<00:00,  3.42it/s]


Processing task: breastmnist


Task: breastmnist: 100%|██████████| 1/1 [00:00<00:00,  3.15it/s]


Processing task: bloodmnist


Task: bloodmnist: 100%|██████████| 14/14 [00:03<00:00,  3.60it/s]


Processing task: tissuemnist


Task: tissuemnist: 100%|██████████| 185/185 [00:52<00:00,  3.54it/s]


Processing task: organamnist


Task: organamnist: 100%|██████████| 70/70 [00:20<00:00,  3.36it/s]


Processing task: organcmnist


Task: organcmnist: 100%|██████████| 33/33 [00:09<00:00,  3.37it/s]


Processing task: organsmnist


Task: organsmnist: 100%|██████████| 35/35 [00:10<00:00,  3.36it/s]


Total submission rows: 96941
Submission file saved as submission.csv


# Final Summary
In this notebook, we built a weighted ensemble of two pre‑trained MedMNIST multi‑task models:
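- **Model Definition:** We use a custom model (a ConvNeXt backbone with task-specific heads) whose architecture must match the one used during training.
- **Checkpoint Loading:** Two checkpoints are loaded with `strict=False`, and their validation metrics are used to compute ensemble weights. Inspect the printed missing/unexpected keys: with `strict=False`, any missing weights (for example, task heads) silently keep their random initialization.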
- **Test Data Preparation:** We load NPZ test files for each task and create DataLoaders.
- **Weighted Inference:** For each batch, we compute softmax probabilities from both models, weight them according to their validation performance, and average to obtain the final predictions.
- **Submission File:** The predictions are saved to a CSV file in the required format.

This weighted ensemble approach gives more influence to the better‑performing model and is expected to yield competitive results on the leaderboard. Adjust the weighting scheme and hyperparameters as needed.