In [1]:
!pip install -q medmnist

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for fire (setup.py) ... [?25l[?25hdone


In [2]:
import medmnist
from medmnist import INFO, Evaluator
import numpy as np
import torch
from torch.utils.data import Dataset, Subset
from pathlib import Path
import pandas as pd
import random



def load_medmnist_from_npz(data_flag):
    
    data_path = Path('/kaggle/input/tensor-reloaded-multi-task-med-mnist/data') / f'{data_flag}.npz'
    data = np.load(data_path)

    info = INFO[data_flag]
    task = info['task']
    n_channels = info['n_channels']
    n_classes = len(info['label'])

    class NPZDataset(Dataset):
        def __init__(self, images, labels=None):
            self.images = images
            self.labels = labels

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            image = self.images[idx]
            if self.labels is not None:
                return torch.tensor(image), torch.tensor(self.labels[idx])
            return torch.tensor(image)

    
        
    train_dataset = NPZDataset(data['train_images'], data.get('train_labels'))
    val_dataset = NPZDataset(data['val_images'], data.get('val_labels'))
    test_dataset = NPZDataset(data['test_images'])

    return train_dataset, val_dataset, test_dataset, info


DATASETS = [
    'pathmnist',
    'dermamnist',
    'octmnist',
    'pneumoniamnist',
    'retinamnist',
    'breastmnist',
    'bloodmnist',
    'tissuemnist',
    'organamnist',
    'organcmnist',
    'organsmnist'
]

def load_all_datasets():
    datasets = {}
    for data_flag in DATASETS:
        print(f"Loading {data_flag}...")
        train, val, test, info = load_medmnist_from_npz(data_flag)
        datasets[data_flag] = {
            'train': train,
            'val': val,
            'test': test,
            'info': info
        }
    return datasets

def print_dataset_info(datasets):
    
    total_test = 0
    for data_flag, data in datasets.items():
        info = data['info']
        print(f"\n{data_flag}:")
        print(f"Task: {info['task']}")
        print(f"Classes: {len(info['label'])}")
        print(f"Train size: {len(data['train'])} {''}")
        print(f"Val size: {len(data['val'])}")
        print(f"Test size: {len(data['test'])}")
        total_test += len(data['test'])
    print(f"\nTotal test samples: {total_test}")
    

def calculate_class_weights(datasets):
    weights = {}

    for data_flag in DATASETS:
        
        labels = datasets[data_flag]['train'].labels
        if isinstance(labels, torch.Tensor):
            labels = labels.numpy()

        
        num_classes = len(datasets[data_flag]['info']['label'])

        
        class_counts = np.bincount(labels.flatten(), minlength=num_classes)

        
        total = class_counts.sum()
        raw_weights = total / (class_counts + 1e-6)

        
        normalized_weights = raw_weights / raw_weights.mean()

        
        weights[data_flag] = torch.FloatTensor(normalized_weights)

    return weights


datasets = load_all_datasets()
print_dataset_info(datasets)
class_weights = calculate_class_weights(datasets)

Loading pathmnist...
Loading dermamnist...
Loading octmnist...
Loading pneumoniamnist...
Loading retinamnist...
Loading breastmnist...
Loading bloodmnist...
Loading tissuemnist...
Loading organamnist...
Loading organcmnist...
Loading organsmnist...

pathmnist:
Task: multi-class
Classes: 9
Train size: 89996 
Val size: 10004
Test size: 7180

dermamnist:
Task: multi-class
Classes: 7
Train size: 7007 
Val size: 1003
Test size: 2005

octmnist:
Task: multi-class
Classes: 4
Train size: 97477 
Val size: 10832
Test size: 1000

pneumoniamnist:
Task: binary-class
Classes: 2
Train size: 4708 
Val size: 524
Test size: 624

retinamnist:
Task: ordinal-regression
Classes: 5
Train size: 1080 
Val size: 120
Test size: 400

breastmnist:
Task: binary-class
Classes: 2
Train size: 546 
Val size: 78
Test size: 156

bloodmnist:
Task: multi-class
Classes: 8
Train size: 11959 
Val size: 1712
Test size: 3421

tissuemnist:
Task: multi-class
Classes: 8
Train size: 165466 
Val size: 23640
Test size: 47280

organamn

In [3]:
import timm
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torchvision.transforms as transforms

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.GELU(),
            nn.Linear(dim, dim)
        )
        self.act = nn.GELU()
    
    def forward(self, x):
        return self.act(x + self.block(x))

class MedMNISTMultiTaskModel(nn.Module):
    def __init__(self, backbone_name='convnext_tiny', pretrained=True):
        super().__init__()

        self.task_outputs = {
            'pathmnist': 9,
            'dermamnist': 7,
            'octmnist': 4,
            'pneumoniamnist': 2,
            'retinamnist': 5,
            'breastmnist': 2,
            'bloodmnist': 8,
            'tissuemnist': 8,
            'organamnist': 11,
            'organcmnist': 11,
            'organsmnist': 11
        }

        
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            num_classes=0,
            drop_path_rate=0.1  
        )

        
        self.backbone.stem[0] = nn.Conv2d(
            3, 96, kernel_size=3, stride=1, padding=1
        )
        
        
        feat_dim = self.backbone.num_features  

       
        self.heads = nn.ModuleDict()
        for task, num_classes in self.task_outputs.items():
            self.heads[task] = nn.Sequential(
                
                nn.LayerNorm(feat_dim),
                
               
                nn.Sequential(
                    nn.Linear(feat_dim, feat_dim),
                    nn.GELU(),
                    nn.Dropout(0.2),
                    
                    nn.Sequential(
                        nn.Linear(feat_dim, feat_dim // 4),
                        nn.GELU(),
                        nn.Linear(feat_dim // 4, feat_dim),
                        nn.Sigmoid()
                    )
                ),
                
                
                nn.Sequential(
                    nn.LayerNorm(feat_dim),
                    nn.Linear(feat_dim, feat_dim * 4),
                    nn.GELU(),
                    nn.Dropout(0.2),
                    nn.Linear(feat_dim * 4, feat_dim),
                ),
                
                
                nn.Sequential(
                    nn.LayerNorm(feat_dim),
                    nn.Linear(feat_dim, num_classes)
                )
            )

    def forward(self, x, task_ids=None):
        
        features = self.backbone(x)

        if task_ids is not None:
            
            task_groups = {}
            for i, task_id in enumerate(task_ids):
                task_name = DATASETS[task_id]
                if task_name not in task_groups:
                    task_groups[task_name] = {'indices': [], 'features': []}
                task_groups[task_name]['indices'].append(i)
                task_groups[task_name]['features'].append(features[i:i+1])

           
            outputs = torch.zeros(len(task_ids), max(self.task_outputs.values())).to(features.device)
            for task_name, group in task_groups.items():
                task_features = torch.cat(group['features'], dim=0)
                task_outputs = self.heads[task_name](task_features)
                for idx, output in zip(group['indices'], task_outputs):
                    outputs[idx, :self.task_outputs[task_name]] = output

            return outputs
        else:
            return {task: head(features) for task, head in self.heads.items()}


def train_step(model, batch, optimizer, criterion):
    images, labels, task_ids = batch

    
    outputs = model(images, task_ids)

    
    loss = criterion(outputs, labels)

    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


def validate(model, val_loader, criterion):
    model.eval()
    val_loss = 0
    task_metrics = {task: [] for task in DATASETS}

    with torch.no_grad():
        for images, labels, task_ids in val_loader:
            outputs = model(images, task_ids)

           
            for task_id, label, output in zip(task_ids, labels, outputs):
                task_name = DATASETS[task_id]
                pred = output.argmax(dim=1)
                task_metrics[task_name].append(
                    (pred == label).float().mean().item()
                )

    
    metrics = {
        task: np.mean(scores)
        for task, scores in task_metrics.items()
    }

    return metrics


class MedMNISTMultiDataset(Dataset):
    def __init__(self, datasets, split='train', transform=None):
        self.datasets = datasets
        self.split = split
        self.transform = transform
        
       
        self.dataset_indices = []
        for dataset_idx, (name, dataset_dict) in enumerate(datasets.items()):
            dataset = dataset_dict[split]
            n_samples = len(dataset)
            self.dataset_indices.extend([(dataset_idx, i) for i in range(n_samples)])
    
    def __len__(self):
        return len(self.dataset_indices)
    
    def __getitem__(self, idx):
        dataset_idx, sample_idx = self.dataset_indices[idx]
        dataset_name = DATASETS[dataset_idx]
        dataset = self.datasets[dataset_name][self.split]
        
       
        data = dataset[sample_idx]
        if isinstance(data, tuple):
            image, label = data
        else:
            image = data
            label = torch.tensor(-1)  
        
        
        if isinstance(image, np.ndarray):
            image = torch.from_numpy(image)
        image = image.float()
        
       
        if image.ndim == 2:
            image = image.unsqueeze(0)
        elif image.ndim == 3 and image.shape[-1] in [1, 3]:
            image = image.permute(2, 0, 1)
        
        
        if image.max() > 1.0:
            image = image / 255.0
        
        
        if image.size(0) == 1:
            image = image.repeat(3, 1, 1)
        
        
        if self.transform:
            image = self.transform(image)
        
        return image, label, torch.tensor(dataset_idx, dtype=torch.long)

In [4]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.nn.utils import clip_grad_norm_
from sklearn.metrics import f1_score
import numpy as np
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast

import torch
import torch.nn.functional as F
from collections import defaultdict

class Trainer:
    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        batch_size=32,
        num_epochs=10,
        lr=1e-4,
        weight_decay=0.01,
        device='cuda',
        wandb_logging=False
    ):
        self.model = model.to(device)
        self.device = device
        self.num_epochs = num_epochs
        self.wandb_logging = wandb_logging
        self.scaler = GradScaler()

        
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

        self.val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        
        self.class_weights = calculate_class_weights(datasets)

        
        self.optimizer = AdamW(
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )

        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=lr,
            epochs=num_epochs,
            steps_per_epoch=len(self.train_loader)
        )
        self.criterion = torch.nn.CrossEntropyLoss()


    def train_epoch(self):
        self.model.train()
        total_loss = 0
        task_predictions = {task: {'preds': [], 'targets': []} for task in DATASETS}

        pbar = tqdm(self.train_loader, desc='Training')
        for batch in pbar:
            images, labels, task_ids = [x.to(self.device) for x in batch]

            with autocast(enabled=True):
                outputs = self.model(images, task_ids)
                batch_losses = []

                for i, (output, label, task_id) in enumerate(zip(outputs, labels, task_ids)):
                    task_name = DATASETS[task_id]
                    num_classes = self.model.task_outputs[task_name]
                    task_output = output[:num_classes].unsqueeze(0)
                    task_label = label.view(-1)

                    
                    class_weight = self.class_weights[task_name].to(self.device)

                    
                    loss_fn = torch.nn.CrossEntropyLoss(weight=class_weight)
                    batch_losses.append(loss_fn(task_output, task_label))

                loss = torch.stack(batch_losses).mean()

            
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()

           
            total_loss += loss.item()
            pbar.set_postfix({'loss': total_loss / (pbar.n + 1)})

        
        task_f1_scores = {}
        for task in DATASETS:
            preds = task_predictions[task]['preds']
            targets = task_predictions[task]['targets']
            if len(preds) > 0:  
                task_f1_scores[task] = f1_score(
                    targets,
                    preds,
                    average='macro'
                )
            else:
                task_f1_scores[task] = 0.0

        return total_loss / len(self.train_loader), task_f1_scores


    def forward(self, x, task_id=None):
        
        features = self.backbone(x)

        if task_id is not None:
            
            outputs = []
            for i, tid in enumerate(task_id):
                task_name = DATASETS[tid]
                outputs.append(self.heads[task_name](features[i:i+1]).squeeze(0))
            return outputs
        else:
            
            return {task: head(features) for task, head in self.heads.items()}


    @torch.no_grad()
    def validate(self):
        self.model.eval()
        total_loss = 0
        task_predictions = {task: {'preds': [], 'targets': []} for task in DATASETS}

        for batch in tqdm(self.val_loader, desc='Validating'):
            images, labels, task_ids = [x.to(self.device) for x in batch]

           
            labels = labels.view(-1).long()

           
            outputs = self.model(images, task_ids)

            
            batch_losses = []
            for task_name in set(DATASETS[tid.item()] for tid in task_ids):
                
                task_mask = torch.tensor([DATASETS[tid.item()] == task_name for tid in task_ids], device=self.device)
                if not task_mask.any():
                    continue

               
                task_outputs = outputs[task_mask]
                task_labels = labels[task_mask]

                
                n_classes = self.model.task_outputs[task_name]
                task_loss = self.criterion(task_outputs[:, :n_classes], task_labels)
                batch_losses.append(task_loss)

            
            loss = torch.stack(batch_losses).mean()
            total_loss += loss.item()

            
            for i, (task_id, label) in enumerate(zip(task_ids, labels)):
                task_name = DATASETS[task_id.item()]
                n_classes = self.model.task_outputs[task_name]
                pred = outputs[i, :n_classes].argmax(dim=0).cpu()
                task_predictions[task_name]['preds'].append(pred.item())
                task_predictions[task_name]['targets'].append(label.cpu().item())

       
        task_f1_scores = {}
        for task in DATASETS:
            preds = task_predictions[task]['preds']
            targets = task_predictions[task]['targets']
            if len(preds) > 0:  # Skip empty tasks
                task_f1_scores[task] = f1_score(
                    targets,
                    preds,
                    average='macro'
                )

        return total_loss / len(self.val_loader), task_f1_scores

    def train(self):
        best_f1 = 0
        for epoch in range(self.num_epochs):
            print(f"\nEpoch {epoch+1}/{self.num_epochs}")

            
            train_loss, train_f1_scores = self.train_epoch()
            
            train_f1_values = list(train_f1_scores.values())
            train_f1_mean = len(train_f1_values) / sum(1/f1 if f1 > 0 else 1e+6 for f1 in train_f1_values)

            
            val_loss, val_f1_scores = self.validate()
           
            val_f1_values = list(val_f1_scores.values())
            val_f1_mean = len(val_f1_values) / sum(1/f1 if f1 > 0 else 0 for f1 in val_f1_values)

            
            print(f"Train Loss: {train_loss:.4f}, Train F1 (Harmonic): {train_f1_mean:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Val F1 (Harmonic): {val_f1_mean:.4f}")

            
            if val_f1_mean > best_f1:
                best_f1 = val_f1_mean
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'best_f1': best_f1,
                }, 'best_model.pth')


train_transforms = transforms.Compose([
    
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(size=28, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.RandomHorizontalFlip(),

    
    transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.1),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),

    
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3)
    ], p=0.1),

    
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

val_transforms = transforms.Compose([
    
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


train_dataset = MedMNISTMultiDataset(
    datasets,
    split='train',
    transform=train_transforms
)

val_dataset = MedMNISTMultiDataset(
    datasets,
    split='val',
    transform=val_transforms
)

test_dataset = MedMNISTMultiDataset(
    datasets,
    split='test',
    transform=val_transforms
)


model = MedMNISTMultiTaskModel(backbone_name='convnext_tiny.in12k_ft_in1k', pretrained=True)
trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size=512,
    num_epochs=33,
    lr=1e-4,
    device='cuda',
    weight_decay=0.05
)


trainer.train()

model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

  self.scaler = GradScaler()



Epoch 1/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:40<00:00,  1.23it/s, loss=1.24]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 1.2389, Train F1 (Harmonic): 0.0000
Val Loss: 0.8936, Val F1 (Harmonic): 0.4548

Epoch 2/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:41<00:00,  1.22it/s, loss=0.912]
Validating: 100%|██████████| 116/116 [01:03<00:00,  1.82it/s]


Train Loss: 0.9124, Train F1 (Harmonic): 0.0000
Val Loss: 0.7479, Val F1 (Harmonic): 0.5250

Epoch 3/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:41<00:00,  1.22it/s, loss=0.76]
Validating: 100%|██████████| 116/116 [01:03<00:00,  1.82it/s]


Train Loss: 0.7605, Train F1 (Harmonic): 0.0000
Val Loss: 0.6256, Val F1 (Harmonic): 0.6082

Epoch 4/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:41<00:00,  1.22it/s, loss=0.664]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.80it/s]


Train Loss: 0.6643, Train F1 (Harmonic): 0.0000
Val Loss: 0.5682, Val F1 (Harmonic): 0.5948

Epoch 5/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:42<00:00,  1.22it/s, loss=0.606]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.6055, Train F1 (Harmonic): 0.0000
Val Loss: 0.5740, Val F1 (Harmonic): 0.5504

Epoch 6/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:41<00:00,  1.22it/s, loss=0.564]
Validating: 100%|██████████| 116/116 [01:03<00:00,  1.81it/s]


Train Loss: 0.5636, Train F1 (Harmonic): 0.0000
Val Loss: 0.5227, Val F1 (Harmonic): 0.6484

Epoch 7/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:39<00:00,  1.23it/s, loss=0.531]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.5305, Train F1 (Harmonic): 0.0000
Val Loss: 0.5010, Val F1 (Harmonic): 0.6293

Epoch 8/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:41<00:00,  1.22it/s, loss=0.506]
Validating: 100%|██████████| 116/116 [01:03<00:00,  1.81it/s]


Train Loss: 0.5064, Train F1 (Harmonic): 0.0000
Val Loss: 0.4829, Val F1 (Harmonic): 0.6362

Epoch 9/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:40<00:00,  1.23it/s, loss=0.484]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.4842, Train F1 (Harmonic): 0.0000
Val Loss: 0.4634, Val F1 (Harmonic): 0.6186

Epoch 10/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:42<00:00,  1.22it/s, loss=0.465]
Validating: 100%|██████████| 116/116 [01:03<00:00,  1.81it/s]


Train Loss: 0.4647, Train F1 (Harmonic): 0.0000
Val Loss: 0.4562, Val F1 (Harmonic): 0.6287

Epoch 11/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:42<00:00,  1.22it/s, loss=0.442]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.4425, Train F1 (Harmonic): 0.0000
Val Loss: 0.4639, Val F1 (Harmonic): 0.6844

Epoch 12/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:43<00:00,  1.22it/s, loss=0.424]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.4243, Train F1 (Harmonic): 0.0000
Val Loss: 0.4372, Val F1 (Harmonic): 0.5391

Epoch 13/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:43<00:00,  1.22it/s, loss=0.407]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.4075, Train F1 (Harmonic): 0.0000
Val Loss: 0.4305, Val F1 (Harmonic): 0.7307

Epoch 14/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:44<00:00,  1.22it/s, loss=0.388]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.3883, Train F1 (Harmonic): 0.0000
Val Loss: 0.4346, Val F1 (Harmonic): 0.7079

Epoch 15/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:44<00:00,  1.22it/s, loss=0.372]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.3717, Train F1 (Harmonic): 0.0000
Val Loss: 0.4178, Val F1 (Harmonic): 0.7390

Epoch 16/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:44<00:00,  1.22it/s, loss=0.352]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.3523, Train F1 (Harmonic): 0.0000
Val Loss: 0.4243, Val F1 (Harmonic): 0.7371

Epoch 17/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:45<00:00,  1.22it/s, loss=0.331]
Validating: 100%|██████████| 116/116 [01:03<00:00,  1.81it/s]


Train Loss: 0.3315, Train F1 (Harmonic): 0.0000
Val Loss: 0.4224, Val F1 (Harmonic): 0.7202

Epoch 18/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:44<00:00,  1.22it/s, loss=0.31]
Validating: 100%|██████████| 116/116 [01:03<00:00,  1.81it/s]


Train Loss: 0.3099, Train F1 (Harmonic): 0.0000
Val Loss: 0.4508, Val F1 (Harmonic): 0.7273

Epoch 19/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:45<00:00,  1.22it/s, loss=0.288]
Validating: 100%|██████████| 116/116 [01:03<00:00,  1.81it/s]


Train Loss: 0.2881, Train F1 (Harmonic): 0.0000
Val Loss: 0.4384, Val F1 (Harmonic): 0.7586

Epoch 20/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:44<00:00,  1.22it/s, loss=0.262]
Validating: 100%|██████████| 116/116 [01:03<00:00,  1.81it/s]


Train Loss: 0.2621, Train F1 (Harmonic): 0.0000
Val Loss: 0.4519, Val F1 (Harmonic): 0.7497

Epoch 21/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:43<00:00,  1.22it/s, loss=0.24]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.80it/s]


Train Loss: 0.2402, Train F1 (Harmonic): 0.0000
Val Loss: 0.4581, Val F1 (Harmonic): 0.7385

Epoch 22/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:44<00:00,  1.22it/s, loss=0.215]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.2154, Train F1 (Harmonic): 0.0000
Val Loss: 0.4746, Val F1 (Harmonic): 0.7482

Epoch 23/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:45<00:00,  1.22it/s, loss=0.194]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.1935, Train F1 (Harmonic): 0.0000
Val Loss: 0.5274, Val F1 (Harmonic): 0.7506

Epoch 24/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:45<00:00,  1.22it/s, loss=0.172]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.1717, Train F1 (Harmonic): 0.0000
Val Loss: 0.5379, Val F1 (Harmonic): 0.7561

Epoch 25/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:46<00:00,  1.22it/s, loss=0.153]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.1531, Train F1 (Harmonic): 0.0000
Val Loss: 0.5678, Val F1 (Harmonic): 0.7530

Epoch 26/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:45<00:00,  1.22it/s, loss=0.136]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.1362, Train F1 (Harmonic): 0.0000
Val Loss: 0.5793, Val F1 (Harmonic): 0.7727

Epoch 27/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:47<00:00,  1.21it/s, loss=0.122]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.1222, Train F1 (Harmonic): 0.0000
Val Loss: 0.5981, Val F1 (Harmonic): 0.7692

Epoch 28/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:44<00:00,  1.22it/s, loss=0.111]
Validating: 100%|██████████| 116/116 [01:03<00:00,  1.82it/s]


Train Loss: 0.1107, Train F1 (Harmonic): 0.0000
Val Loss: 0.6175, Val F1 (Harmonic): 0.7698

Epoch 29/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:46<00:00,  1.22it/s, loss=0.102]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.1015, Train F1 (Harmonic): 0.0000
Val Loss: 0.6288, Val F1 (Harmonic): 0.7672

Epoch 30/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:46<00:00,  1.22it/s, loss=0.096]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.0960, Train F1 (Harmonic): 0.0000
Val Loss: 0.6476, Val F1 (Harmonic): 0.7565

Epoch 31/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:46<00:00,  1.22it/s, loss=0.0905]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]


Train Loss: 0.0905, Train F1 (Harmonic): 0.0000
Val Loss: 0.6479, Val F1 (Harmonic): 0.7594

Epoch 32/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:49<00:00,  1.21it/s, loss=0.0871]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.80it/s]


Train Loss: 0.0871, Train F1 (Harmonic): 0.0000
Val Loss: 0.6565, Val F1 (Harmonic): 0.7547

Epoch 33/33


  with autocast(enabled=True):
Training: 100%|██████████| 859/859 [11:45<00:00,  1.22it/s, loss=0.0867]
Validating: 100%|██████████| 116/116 [01:04<00:00,  1.81it/s]

Train Loss: 0.0867, Train F1 (Harmonic): 0.0000
Val Loss: 0.6550, Val F1 (Harmonic): 0.7536





In [5]:
from PIL import Image

def load_best_model(checkpoint_path, model):
    """
    Load the best model from a checkpoint file.
    
    Args:
        checkpoint_path (str): Path to the checkpoint file.
        model (nn.Module): Model to load the weights into.
        
    Returns:
        model (nn.Module): Model with loaded weights.
        best_f1 (float): Best F1 score achieved by this model.
    """
    try:
        
        checkpoint = torch.load(checkpoint_path)
        
        
        model.load_state_dict(checkpoint['model_state_dict'])
        best_f1 = checkpoint['best_f1']
        
        print(f"Loaded checkpoint from epoch {checkpoint['epoch']} with F1: {best_f1:.4f}")
        
        return model, best_f1
        
    except FileNotFoundError:
        print(f"No checkpoint found at {checkpoint_path}")
        return model, 0.0
    except Exception as e:
        print(f"Error loading checkpoint: {str(e)}")
        return model, 0.0


In [6]:
def create_submission(model, test_dataset, device='cuda', batch_size=512):
    """Create submission file for the MedMNIST competition."""
    model.eval()
    all_predictions = []
    
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
        persistent_workers=True
    )
    
    
    task_counters = {task: 0 for task in DATASETS}
    global_id = 0
    
    with torch.no_grad(), torch.cuda.amp.autocast(): 
        for batch in tqdm(test_loader, desc='Generating predictions'):
            images, _, task_ids = batch
            images = images.to(device, non_blocking=True)
            task_ids = task_ids.cpu().numpy()
            
           
            unique_tasks = np.unique(task_ids)
            
            
            for task_idx in unique_tasks:
                task_name = DATASETS[task_idx]
                mask = task_ids == task_idx
                
                if mask.any():
                    
                    task_images = images[mask]
                    
                    
                    features = model.backbone(task_images)
                    outputs = model.heads[task_name](features)
                    preds = outputs.argmax(dim=1).cpu().numpy()
                    
                    
                    n_preds = len(preds)
                    task_start_idx = task_counters[task_name]
                    
                    
                    batch_predictions = [{
                        'id': global_id + i,
                        'label': int(pred),
                        'task_name': task_name,
                        'id_image_in_task': task_start_idx + i
                    } for i, pred in enumerate(preds)]
                    
                    all_predictions.extend(batch_predictions)
                    
                   
                    task_counters[task_name] += n_preds
                    global_id += n_preds
    
    
    df = pd.DataFrame(all_predictions)
    df = df[['id', 'label', 'task_name', 'id_image_in_task']]
    df.to_csv('submission.csv', index=False)
    
    print(f"\nSubmission saved with {len(df)} total predictions")
    return df


submission_df = create_submission(
    model=model,
    test_dataset=test_dataset,
    device='cuda',
    batch_size=256
)
submission_df.head()

  with torch.no_grad(), torch.cuda.amp.autocast():
Generating predictions: 100%|██████████| 379/379 [00:38<00:00,  9.95it/s]



Submission saved with 96941 total predictions


Unnamed: 0,id,label,task_name,id_image_in_task
0,0,8,pathmnist,0
1,1,4,pathmnist,1
2,2,4,pathmnist,2
3,3,3,pathmnist,3
4,4,4,pathmnist,4
