<h1>Enhancing Brain Disease Diagnosis</h1>

<h2>Federated Learning for Multi-Center Medical Imaging</h2>

<h3>AI for Trustworthy Decision Making</h3>

- Poață Andrei-Cătălin

- Vulpe Ștefan

- Vișan Ionuț

*GitHub: https://github.com/stefanvulpe-dev/brain-disease-diagnosis*

# **Setup**

In [None]:
# Mount Google Drive so the project artifacts stored there are reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Copy the data archive, client package and pickled metadata from Drive into /content.
# NOTE(review): later cells read cleaned_df.pkl, labels_list.pkl and Preprocessed-Data from
# "/content/drive/MyDrive/PKG - MU-Glioma-Post/...", not from these /content copies —
# confirm which location is intended.
!cp /content/drive/MyDrive/AITDM_Stuff/Preprocessed-Data.zip /content/Preprocessed-Data.zip
!cp /content/drive/MyDrive/AITDM_Stuff/client.zip /content/client.zip
!cp /content/drive/MyDrive/AITDM_Stuff/cleaned_df.pkl /content/cleaned_df.pkl
!cp /content/drive/MyDrive/AITDM_Stuff/labels_list.pkl /content/labels_list.pkl

Mounted at /content/drive


In [None]:
# Install segmentation_models_pytorch (U-Net models used below) and torchio.
# NOTE(review): versions are unpinned; the captured output resolved
# segmentation_models_pytorch 0.5.0 and torchio 0.21.1 — consider pinning for reproducibility.
!pip install segmentation_models_pytorch
!pip install torchio

Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl (154 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: segmentation_models_pytorch
Successfully installed segmentation_models_pytorch-0.5.0
Collecting torchio
  Downloading torchio-0.21.1-py3-none-any.whl.metadata (52 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.0/53.0 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting deprecated>=1.2 (from torchio)
  Downloading deprecated-1.3.1-py2.py3-none-any.whl.metadata (5.9 kB)
Collecting simpleitk!=2.0.*,!=2.1.1.1,>=1.3 (from torchio)
  Downloading simpleitk-2.5.3-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.4 kB)
Downloading torchio-0.21.1-py3-none-any.whl (194 kB)
[2K   [90m━━━━━━━━━━━━━━━

In [None]:
# Extract client.zip into /content/client (-q: quiet, -o: overwrite existing files without prompting).
!unzip -qo /content/client.zip -d /content/client

In [None]:
# Extract the preprocessed imaging data into /content/Preprocessed-Data (-q quiet, -o overwrite).
!unzip -qo /content/Preprocessed-Data.zip -d /content/Preprocessed-Data

In [None]:
# Install the Flower federated-learning framework with its simulation extras.
# NOTE(review): the captured output reports a pip dependency-resolver error — check for conflicts.
!pip install -q "flwr[simulation]"

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.4/71.4 MB[0m [31m35.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m98.2/98.2 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.2/4.2 MB[0m [31m132.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m323.3/323.3 kB[0m [31m33.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m105.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m251.7/251.7 kB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.4/47.4 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m727.1/727.1 kB[0m [31m59.8 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver d

# **Main Code**

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import nibabel as nib
import numpy as np
import os
import pickle
class GliomaDataset(Dataset):
    """Per-patient glioma dataset pairing imaging arrays with clinical metadata.

    Each item is a dict containing:

    * ``mri``, ``regions``, ``tumor`` -- arrays loaded from
      ``<data_root>/<Patient_ID>/<Patient_ID>_{mri,regions,tumor}.npy``
      (passed through ``transform`` when one is given);
    * integer codes for the columns in ``categorical_cols``
      (``code_maps[col]`` maps each code back to its category);
    * the raw age, mutation flags, free-text fields and the report;
    * ``target_regions`` -- a multi-hot float vector over ``labels``, built
      from the row's "Top 5 Regions".

    Args:
        metadata_df_path: Path to a pickled pandas DataFrame, one row per patient.
        labels_path: Path to a pickled list of region names (the target vocabulary).
        transform: Optional callable applied separately to the mri, regions and
            tumor arrays.
        data_root: Directory holding one sub-directory per patient. The default is
            the original Google Drive location, so existing calls are unaffected.
    """

    def __init__(self, metadata_df_path, labels_path, transform=None,
                 data_root="/content/drive/MyDrive/PKG - MU-Glioma-Post/Preprocessed-Data"):
        # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
        with open(metadata_df_path, 'rb') as f:
            metadata_df = pickle.load(f)

        # PatientID_0191 is excluded explicitly (reason not recorded here). `.copy()` makes the
        # filtered frame independent, so adding the `_code` columns below does not write into
        # a slice of the unpickled frame (pandas SettingWithCopyWarning).
        self.metadata_df = metadata_df[metadata_df['Patient_ID'] != 'PatientID_0191'].copy()

        with open(labels_path, 'rb') as f:
            self.labels = pickle.load(f)

        self.data_root = data_root
        self.transform = transform

        self.useful_columns = [
            'Patient_ID', 'Sex at Birth', 'Race', 'Age at diagnosis',
            'Primary Diagnosis', 'H3-3A mutation', 'PTEN mutation',
            'CDKN2A/B deletion', 'TP53 alteration', 'Other mutations/alterations',
            'Previous Brain Tumor', 'Type of previous brain tumor',
            'Report', 'Age Range', 'Top 5 Regions'
        ]

        self.categorical_cols = [
            'Sex at Birth', 'Race', 'Primary Diagnosis',
            'Previous Brain Tumor', 'Type of previous brain tumor', 'Age Range'
        ]

        # Integer-encode each categorical column into "<col>_code" (pandas assigns -1 to NaN)
        # and keep a code -> category lookup per column in code_maps.
        self.code_maps = {}
        for col in self.categorical_cols:
            cat = pd.Categorical(self.metadata_df[col])
            self.metadata_df[col + '_code'] = cat.codes
            self.code_maps[col] = dict(enumerate(cat.categories))

    def __len__(self):
        return len(self.metadata_df)

    def _array_path(self, patient_id, kind):
        """Return ``<data_root>/<patient_id>/<patient_id>_<kind>.npy``."""
        return os.path.join(self.data_root, patient_id, f'{patient_id}_{kind}.npy')

    def __getitem__(self, idx):
        row = self.metadata_df.iloc[idx]
        patient_id = row['Patient_ID']

        # Multi-hot target over self.labels; list.index raises ValueError for an unknown region.
        target_regions = np.zeros(len(self.labels))
        for region in row['Top 5 Regions']:
            target_regions[self.labels.index(region)] = 1

        mri = np.load(self._array_path(patient_id, 'mri'))
        regions = np.load(self._array_path(patient_id, 'regions'))
        tumor = np.load(self._array_path(patient_id, 'tumor'))

        if self.transform:
            # NOTE(review): each array gets an independent call, so a *random* spatial
            # transform would misalign the MRI, the region map and the tumor mask.
            mri = self.transform(mri)
            regions = self.transform(regions)
            tumor = self.transform(tumor)

        return {
            'mri': mri,
            'regions': regions,
            'tumor': tumor,
            'sex': row['Sex at Birth_code'],
            'race': row['Race_code'],
            'age': row['Age at diagnosis'],
            'primary_diagnosis': row['Primary Diagnosis_code'],
            'h3_3a_mutation': row['H3-3A mutation'],
            'pten_mutation': row['PTEN mutation'],
            'CDKN2A_B_deletion': row['CDKN2A/B deletion'],
            'TP53_alteration': row['TP53 alteration'],
            'other_mutations_alterations': row['Other mutations/alterations'],
            'previous_brain_tumor': row['Previous Brain Tumor_code'],
            'type_of_previous_brain_tumor': row['Type of previous brain tumor_code'],
            'report': row['Report'],
            'age_range': row['Age Range_code'],
            'target_regions': target_regions,
        }




def glioma_collate_fn(batch):
    """Collate a list of ``GliomaDataset`` items into one batched dict.

    * ``mri``, ``regions``, ``tumor``: stacked along a new leading batch axis
      (dtype preserved). Items may hold numpy arrays or tensors (e.g. after a transform).
    * Categorical codes (sex, race, primary_diagnosis, previous_brain_tumor,
      type_of_previous_brain_tumor, age_range) -> ``torch.long``.
    * Age and mutation flags -> ``torch.float`` (float32).
    * ``target_regions`` -> float32 tensor of shape ``[B, num_labels]``.
    * ``other_mutations_alterations`` and ``report`` stay Python lists.

    Returns:
        Dict with the same keys (and key order) as each input item.
    """
    def stack_images(key):
        # as_tensor avoids the extra copy (and UserWarning) that torch.tensor() makes for
        # inputs that are already tensors; torch.stack then builds a fresh batch tensor.
        return torch.stack([torch.as_tensor(item[key]) for item in batch])

    def scalars(key, dtype):
        return torch.tensor([item[key] for item in batch], dtype=dtype)

    def as_list(key):
        return [item[key] for item in batch]

    return {
        'mri': stack_images('mri'),
        'regions': stack_images('regions'),
        'tumor': stack_images('tumor'),
        'sex': scalars('sex', torch.long),
        'race': scalars('race', torch.long),
        'age': scalars('age', torch.float),
        'primary_diagnosis': scalars('primary_diagnosis', torch.long),
        'h3_3a_mutation': scalars('h3_3a_mutation', torch.float),
        'pten_mutation': scalars('pten_mutation', torch.float),
        'CDKN2A_B_deletion': scalars('CDKN2A_B_deletion', torch.float),
        'TP53_alteration': scalars('TP53_alteration', torch.float),
        'other_mutations_alterations': as_list('other_mutations_alterations'),
        'previous_brain_tumor': scalars('previous_brain_tumor', torch.long),
        'type_of_previous_brain_tumor': scalars('type_of_previous_brain_tumor', torch.long),
        'report': as_list('report'),
        'age_range': scalars('age_range', torch.long),
        'target_regions': torch.from_numpy(np.array(as_list('target_regions'))).float(),
    }

KeyboardInterrupt: 

In [None]:
from torch.utils.data import random_split, DataLoader

# Fixed seed for the 80/20 patient split. The BERT and U-Net checkpoints saved to Drive are
# reloaded in later cells and combined with `train_dataset`; an unseeded split would differ on
# every fresh run and leak those models' training patients into `test_dataset`.
SPLIT_SEED = 42

dataset = GliomaDataset("/content/drive/MyDrive/PKG - MU-Glioma-Post/cleaned_df.pkl", "/content/drive/MyDrive/PKG - MU-Glioma-Post/labels_list.pkl")

test_size = int(0.2 * len(dataset))
train_size = len(dataset) - test_size

train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=glioma_collate_fn)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=glioma_collate_fn)

print(f"Train loader length = {len(train_loader)}")
print(f"Test loader length = {len(test_loader)}")

# Smoke-test one batch per loader. Print tensor shapes / list lengths rather than
# dumping whole image tensors and reports into the notebook.
for split_name, loader in (("Train", train_loader), ("Test", test_loader)):
    batch = next(iter(loader))
    summary = {key: tuple(value.shape) if torch.is_tensor(value) else len(value)
               for key, value in batch.items()}
    print(f"{split_name} batch:", summary)

In [None]:
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt
from tqdm import tqdm

class BioBERTMultiLabelClassifier(nn.Module):
    """Multi-label classification head on a pretrained Hugging Face encoder.

    The encoder's ``pooler_output`` passes through dropout (p=0.3) and a linear
    layer that produces one logit per label. No sigmoid is applied here; pair
    the output with ``nn.BCEWithLogitsLoss`` or ``torch.sigmoid``.

    Args:
        model_name: Hugging Face model id or path passed to ``AutoModel.from_pretrained``.
        num_labels: Number of output logits.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(self.dropout(encoded.pooler_output))

# Tokenizer for Bio_ClinicalBERT, the same checkpoint the report classifier is built on.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Report tokenization helper.
def tokenize_reports(reports, tokenizer, max_len=512):
    """Tokenize a batch of report strings into padded PyTorch tensors.

    Sequences are padded to the longest report in the batch and truncated to
    ``max_len`` tokens.

    Args:
        reports: List of report strings.
        tokenizer: Hugging Face tokenizer (any callable with the same keyword interface).
        max_len: Maximum sequence length after truncation.

    Returns:
        Whatever ``tokenizer`` returns, typically a dict-like with ``input_ids`` and
        ``attention_mask`` tensors.
    """
    options = dict(padding=True, truncation=True, max_length=max_len, return_tensors="pt")
    return tokenizer(reports, **options)


import matplotlib.pyplot as plt
import os

def plot_and_save_history(history, save_dir="/content/drive/MyDrive/PKG - MU-Glioma-Post/Graphs"):
    """Plot train/test curves from a BioBERT training history and save them as PNGs.

    Writes ``loss_evolution_biobert.png``, ``f1_evolution_biobert.png`` and, when both
    ``train_acc`` and ``test_acc`` exist, ``accuracy_evolution_biobert.png`` into
    ``save_dir`` (created if missing). Each figure is also shown inline.

    Args:
        history: Dict of per-epoch lists under ``train_loss``, ``test_loss``,
            ``train_f1``, ``test_f1`` and optionally ``train_acc`` / ``test_acc``.
        save_dir: Output directory for the PNG files.
    """
    os.makedirs(save_dir, exist_ok=True)
    epochs_range = range(1, len(history['train_loss']) + 1)

    def plot_train_test(key, curve_label, ylabel, filename):
        # One figure per metric: train vs. test curve over epochs.
        fig, ax = plt.subplots()
        ax.plot(epochs_range, history[f'train_{key}'], label=f'Train {curve_label}')
        ax.plot(epochs_range, history[f'test_{key}'], label=f'Test {curve_label}')
        ax.set_title(f"Bioclinical BERT {curve_label} Evolution")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.legend()
        fig.savefig(os.path.join(save_dir, filename))
        plt.show()
        plt.close(fig)

    plot_train_test('loss', 'Loss', 'Loss', "loss_evolution_biobert.png")
    plot_train_test('f1', 'F1', 'F1 Score', "f1_evolution_biobert.png")

    # Accuracy is optional: plot it only when both series are present in history.
    if 'train_acc' in history and 'test_acc' in history:
        plot_train_test('acc', 'Accuracy', 'Accuracy', "accuracy_evolution_biobert.png")
    else:
        print("Accuracy metrics not found in history, skipping accuracy plot.")


In [None]:
from sklearn.metrics import f1_score, roc_auc_score, classification_report
import numpy as np
import os
import warnings
from transformers import get_linear_schedule_with_warmup
from sklearn.exceptions import UndefinedMetricWarning


# Silence sklearn's UndefinedMetricWarning, raised when a per-class metric is ill-defined
# (e.g. a region never predicted in an epoch); otherwise it floods the training log.
warnings.filterwarnings("ignore", category=UndefinedMetricWarning)

def _per_class_auc(y_true, y_proba):
    """Return one ROC-AUC per label column of a multi-label problem.

    Columns whose ground truth holds a single class (all 0 or all 1) get NaN, because
    ROC-AUC is undefined there. The remaining columns are still scored.
    """
    aucs = []
    for j in range(y_true.shape[1]):
        if np.unique(y_true[:, j]).size < 2:
            aucs.append(float('nan'))
        else:
            aucs.append(float(roc_auc_score(y_true[:, j], y_proba[:, j])))
    return aucs


def _biobert_logits(model, batch):
    """Tokenize ``batch['report']``, run ``model`` and return ``(logits, targets)`` on ``device``."""
    labels = batch['target_regions'].to(device)
    encodings = tokenize_reports(batch['report'], tokenizer)
    logits = model(
        input_ids=encodings['input_ids'].to(device),
        attention_mask=encodings['attention_mask'].to(device),
    )
    return logits, labels


def train_eval_biobert(model, train_loader, test_loader, label_names=None, epochs=5, lr=2e-5, save_path="best_model.pt"):
    """Fine-tune a multi-label report classifier and evaluate it after every epoch.

    Training setup:
    * loss: BCE-with-logits;
    * optimizer: AdamW;
    * schedule: linear, with the first 10% of steps as warm-up;
    * predictions: thresholded at 0.5.

    After each epoch the test loader is scored with macro-F1, element-wise accuracy,
    per-class F1 and per-class ROC-AUC. The state dict is saved to ``save_path`` (its
    directory is created if needed) whenever test macro-F1 improves.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)`` returning logits.
        train_loader / test_loader: Yield batches with 'report' and 'target_regions'.
        label_names: Optional names used when printing per-class scores.
        epochs: Number of training epochs.
        lr: AdamW learning rate.
        save_path: Where the best checkpoint is written.

    Returns:
        History dict of per-epoch lists: 'train_loss', 'train_f1', 'train_acc',
        'test_loss', 'test_f1', 'test_acc', 'test_f1_per_class', 'test_auc_per_class'.
    """
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    total_steps = len(train_loader) * epochs
    warmup_steps = int(0.1 * total_steps)  # first 10% of the steps are linear warm-up
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

    model.to(device)
    # Start at -inf, not 0, so the first epoch is always checkpointed even when macro-F1 is 0.
    # Otherwise no file is written and later loads of save_path fail.
    best_f1 = float('-inf')
    history = {
        'train_loss': [], 'train_f1': [], 'train_acc': [],
        'test_loss': [], 'test_f1': [], 'test_acc': [],
        'test_f1_per_class': [], 'test_auc_per_class': []
    }

    for epoch in range(epochs):
        # ---- training ----
        model.train()
        train_loss = 0
        y_true_train, y_pred_train = [], []

        for batch in tqdm(train_loader, desc=f"Train Epoch {epoch+1}"):
            optimizer.zero_grad()
            outputs, labels = _biobert_logits(model, batch)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_loss += loss.item()
            y_true_train.append(labels.cpu())
            y_pred_train.append((torch.sigmoid(outputs) > 0.5).int().cpu())

        y_true_train = torch.cat(y_true_train).numpy()
        y_pred_train = torch.cat(y_pred_train).numpy()

        f1_train = f1_score(y_true_train, y_pred_train, average='macro')
        # Element-wise accuracy over every example and every class.
        acc_train = (y_true_train == y_pred_train).mean()

        history['train_loss'].append(train_loss / len(train_loader))
        history['train_f1'].append(f1_train)
        history['train_acc'].append(acc_train)

        # ---- evaluation ----
        # NOTE(review): the test split also drives checkpoint selection below, so the best
        # test F1 is an optimistic estimate; a separate validation split would avoid that.
        model.eval()
        test_loss = 0
        y_true_test, y_pred_test, y_proba_test = [], [], []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Test"):
                outputs, labels = _biobert_logits(model, batch)
                test_loss += criterion(outputs, labels).item()

                probs = torch.sigmoid(outputs).cpu()
                y_true_test.append(labels.cpu())
                y_pred_test.append((probs > 0.5).int())
                y_proba_test.append(probs)

        y_true_test = torch.cat(y_true_test).numpy()
        y_pred_test = torch.cat(y_pred_test).numpy()
        y_proba_test = torch.cat(y_proba_test).numpy()

        f1_test = f1_score(y_true_test, y_pred_test, average='macro')
        acc_test = (y_true_test == y_pred_test).mean()
        f1_per_class = f1_score(y_true_test, y_pred_test, average=None).tolist()
        # Scored column by column: a label with a single class in the test split yields NaN
        # for that label only, instead of NaN for every label.
        auc_per_class = _per_class_auc(y_true_test, y_proba_test)

        history['test_loss'].append(test_loss / len(test_loader))
        history['test_f1'].append(f1_test)
        history['test_acc'].append(acc_test)
        history['test_f1_per_class'].append(f1_per_class)
        history['test_auc_per_class'].append(auc_per_class)

        print(f"\nEpoch {epoch+1}:")
        print(f"Train Loss: {history['train_loss'][-1]:.4f}")
        print(f"Test Loss:  {history['test_loss'][-1]:.4f}")
        print(f"Train F1 (macro): {f1_train:.4f}")
        print(f"Train Accuracy:   {acc_train:.4f}")
        print(f"Test F1 (macro):  {f1_test:.4f}")
        print(f"Test Accuracy:    {acc_test:.4f}")
        print("F1 per class:")
        if label_names:
            for i, name in enumerate(label_names):
                print(f"  {name}: {f1_per_class[i]:.4f} | AUC: {auc_per_class[i]:.4f}")
        else:
            for i, f1c in enumerate(f1_per_class):
                print(f"  Class {i}: F1={f1c:.4f}, AUC={auc_per_class[i]:.4f}")

        if f1_test > best_f1:
            best_f1 = f1_test
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), save_path)
            print(f"✅ Saved new best model with F1={f1_test:.4f} to {save_path}")

    return history





# Fine-tune Bio_ClinicalBERT to predict the multi-hot "Top 5 Regions" target from report text.
label_names = dataset.labels  # region names, one per output logit
model = BioBERTMultiLabelClassifier("emilyalsentzer/Bio_ClinicalBERT", num_labels=len(label_names))

history = train_eval_biobert(
    model,
    train_loader,
    test_loader,
    label_names=label_names,
    epochs=20,
    save_path="/content/drive/MyDrive/PKG - MU-Glioma-Post/Models/BioClinicalBert/best.pt",
)

plot_and_save_history(history)

In [None]:
import segmentation_models_pytorch as smp
import torch.nn as nn

class UNetTumorSegmentation(nn.Module):
    """smp U-Net that segments the tumor from a 2-channel (MRI, regions) input.

    ``forward`` concatenates ``mri`` and ``regions`` along dim 1 (the channel
    axis) and returns raw one-channel logits with no activation. These suit
    ``BCEWithLogitsLoss`` or a from-logits Dice loss.

    Args:
        encoder_name: smp encoder backbone name.
        pretrained: If True, initialise the encoder with ImageNet weights.
    """

    def __init__(self, encoder_name='resnet34', pretrained=True):
        super().__init__()
        encoder_weights = 'imagenet' if pretrained else None
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=2,     # channel 0: MRI, channel 1: regions
            classes=1,
            activation=None,   # keep raw logits
        )

    def forward(self, mri, regions):
        stacked = torch.cat((mri, regions), dim=1)
        return self.model(stacked)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score
import numpy as np

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# U-Net with 2 input channels (MRI + regions) and 1 output channel (tumor logits).
unet_config = dict(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=2,
    classes=1,
)
model = smp.Unet(**unet_config).to(device)

# Loss, optimizer and scheduler.
# Binary Dice loss computed from raw logits (from_logits=True).
criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# LR decays linearly from 1e-3 toward 0 over 50 scheduler steps. The training loop steps it
# once per epoch, so total_iters should stay in sync with num_epochs there.
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.0, total_iters=50)

def calc_metrics(y_true, y_pred):
    """Compute Dice, IoU and pixel accuracy for a pair of binary masks.

    Args:
        y_true: Ground-truth mask (any shape); non-zero / True means foreground.
        y_pred: Predicted mask with the same number of elements.

    Returns:
        Tuple ``(dice, iou, acc)`` of floats over all elements. If both masks are
        empty, Dice and IoU are 1.0 (perfect agreement) rather than 0.
    """
    t = np.asarray(y_true, dtype=bool).ravel()
    p = np.asarray(y_pred, dtype=bool).ravel()

    intersection = np.count_nonzero(t & p)
    total = np.count_nonzero(t) + np.count_nonzero(p)
    union = total - intersection  # |A u B| = |A| + |B| - |A n B|

    dice = 2.0 * intersection / total if total else 1.0
    iou = intersection / union if union else 1.0
    acc = float(np.mean(t == p))
    return dice, iou, acc

def train_one_epoch(model, dataloader, optimizer, criterion):
    """Run one optimisation pass of the U-Net over ``dataloader``.

    Each batch's 'mri' and 'regions' get a channel axis (``unsqueeze(1)``) and are
    stacked into a 2-channel input. 'tumor' is the one-channel target.

    Returns:
        ``(loss, dice, iou, acc)``, each averaged over batches (not weighted by batch size).
    """
    model.train()
    loss_sum = dice_sum = iou_sum = acc_sum = 0
    for batch in dataloader:
        optimizer.zero_grad()

        # Stack MRI and regions on the channel axis: two [B,1,H,W] tensors -> [B,2,H,W].
        mri = batch['mri'].float().unsqueeze(1).to(device)
        regions = batch['regions'].float().unsqueeze(1).to(device)
        inputs = torch.cat([mri, regions], dim=1)
        target = batch['tumor'].float().unsqueeze(1).to(device)

        logits = model(inputs)
        loss = criterion(logits, target)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()

        # Logits -> sigmoid -> threshold at 0.5 for the metrics.
        pred_mask = torch.sigmoid(logits).detach().cpu().numpy() > 0.5
        true_mask = target.detach().cpu().numpy() > 0.5
        dice, iou, acc = calc_metrics(true_mask, pred_mask)
        dice_sum += dice
        iou_sum += iou
        acc_sum += acc

    num_batches = len(dataloader)
    return loss_sum / num_batches, dice_sum / num_batches, iou_sum / num_batches, acc_sum / num_batches

def eval_one_epoch(model, dataloader, criterion):
    """Evaluate the U-Net on ``dataloader`` with gradients disabled.

    Returns:
        ``(loss, dice, iou, acc, last_batch_imgs)``. The four metrics are averaged over
        batches (not weighted by batch size). ``last_batch_imgs`` holds numpy copies of the
        final batch for plotting: 'mri', 'regions' and 'y_true' as [B,1,H,W] float arrays
        and 'y_pred' as a boolean mask.
    """
    model.eval()
    total_loss, total_dice, total_iou, total_acc = 0, 0, 0, 0
    # Only the final batch is returned, so keep a reference during the loop and
    # convert it to numpy once afterwards instead of copying every batch.
    last_batch = None
    with torch.no_grad():
        for batch in dataloader:
            x1 = batch['mri'].float().unsqueeze(1).to(device)
            x2 = batch['regions'].float().unsqueeze(1).to(device)
            x = torch.cat([x1, x2], dim=1)
            y = batch['tumor'].float().unsqueeze(1).to(device)

            preds = model(x)
            loss = criterion(preds, y)
            total_loss += loss.item()

            preds_mask = torch.sigmoid(preds).cpu().numpy() > 0.5
            y_true = y.cpu().numpy()
            dice, iou, acc = calc_metrics(y_true > 0.5, preds_mask)
            total_dice += dice
            total_iou += iou
            total_acc += acc

            last_batch = (x1, x2, y_true, preds_mask)

    last_batch_imgs = None
    if last_batch is not None:
        mri_t, regions_t, y_true, preds_mask = last_batch
        last_batch_imgs = {
            'mri': mri_t.cpu().numpy(),
            'regions': regions_t.cpu().numpy(),
            'y_true': y_true,
            'y_pred': preds_mask,
        }

    n = len(dataloader)
    return total_loss / n, total_dice / n, total_iou / n, total_acc / n, last_batch_imgs

def plot_prediction(imgs, save_path="/content/drive/MyDrive/PKG - MU-Glioma-Post/Graphs/sample_prediction_UNet.png"):
    """Show MRI, regions, ground-truth and predicted mask for the first sample of a batch.

    Args:
        imgs: Dict of numpy arrays under 'mri', 'regions', 'y_true' and 'y_pred', indexed
            as [B, C, H, W]; sample 0, channel 0 of each is plotted.
        save_path: PNG output path (defaults to the original U-Net location). Its
            directory is created if missing.
    """
    import os

    panels = (
        ('mri', "MRI"),
        ('regions', "Regions"),
        ('y_true', "Ground Truth Mask"),
        ('y_pred', "Predicted Mask"),
    )
    fig, axs = plt.subplots(1, len(panels), figsize=(20, 5))
    for ax, (key, title) in zip(axs, panels):
        ax.imshow(imgs[key][0, 0], cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(save_path)
    plt.show()
    plt.close(fig)

def plot_metrics(train_metrics, val_metrics, save_path):
    """Plot train vs. validation loss, Dice, IoU and accuracy curves in a 2x2 grid.

    Args:
        train_metrics: Dict of per-epoch lists under 'loss', 'dice', 'iou' and 'acc'.
        val_metrics: Same structure for the validation pass.
        save_path: PNG output path; its directory is created if missing.
    """
    import os

    epochs = range(1, len(train_metrics['loss']) + 1)
    panels = (
        ('loss', 'Loss', 'Loss'),
        ('dice', 'Dice', 'Dice Score'),
        ('iou', 'IoU', 'IoU'),
        ('acc', 'Accuracy', 'Accuracy'),
    )

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    for ax, (key, label, title) in zip(axes.flat, panels):
        ax.plot(epochs, train_metrics[key], label=f'Train {label}')
        ax.plot(epochs, val_metrics[key], label=f'Val {label}')
        ax.set_xlabel('Epoch')
        ax.set_title(title)
        ax.legend()
    fig.tight_layout()

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(save_path)
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

import os

num_epochs = 50
# Start below any achievable IoU so at least one checkpoint is always written;
# the CLIP cell later loads this file unconditionally.
best_val_iou = float('-inf')
UNET_CHECKPOINT = "/content/drive/MyDrive/PKG - MU-Glioma-Post/Models/UNet/best_unet_model.pth"
os.makedirs(os.path.dirname(UNET_CHECKPOINT), exist_ok=True)

train_metrics = {'loss': [], 'dice': [], 'iou': [], 'acc': []}
val_metrics = {'loss': [], 'dice': [], 'iou': [], 'acc': []}
last_val_batch_imgs = None  # last validation batch of the final epoch, for display

for epoch in range(1, num_epochs + 1):
    train_loss, train_dice, train_iou, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_dice, val_iou, val_acc, val_imgs = eval_one_epoch(model, test_loader, criterion)
    scheduler.step()

    for key, value in zip(('loss', 'dice', 'iou', 'acc'), (train_loss, train_dice, train_iou, train_acc)):
        train_metrics[key].append(value)
    for key, value in zip(('loss', 'dice', 'iou', 'acc'), (val_loss, val_dice, val_iou, val_acc)):
        val_metrics[key].append(value)

    print(f"Epoch {epoch}/{num_epochs}")
    print(f" Train   - Loss: {train_loss:.4f}, Dice: {train_dice:.4f}, IoU: {train_iou:.4f}, Acc: {train_acc:.4f}")
    print(f" Valid   - Loss: {val_loss:.4f}, Dice: {val_dice:.4f}, IoU: {val_iou:.4f}, Acc: {val_acc:.4f}")

    # NOTE(review): test_loader doubles as the validation set for checkpoint selection,
    # so the best IoU reported here is optimistically biased.
    if val_iou > best_val_iou:
        best_val_iou = val_iou
        torch.save(model.state_dict(), UNET_CHECKPOINT)
        print("  Saved best model!")

    # Overwritten every epoch; after the loop it holds the final epoch's last batch.
    last_val_batch_imgs = val_imgs

# Plot prediction vs. ground truth once, for the final epoch's last validation batch.
if last_val_batch_imgs is not None:
    plot_prediction(last_val_batch_imgs)

# Save the training-curve figure.
plot_metrics(train_metrics, val_metrics, "/content/drive/MyDrive/PKG - MU-Glioma-Post/Graphs/training_metrics_UNet.png")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Text branch: the fine-tuned Bio_ClinicalBERT classifier checkpoint.
bert_model_name = "emilyalsentzer/Bio_ClinicalBERT"
bert_num_labels = len(dataset.labels)
bert_model = BioBERTMultiLabelClassifier(bert_model_name, bert_num_labels)
# map_location="cpu" lets checkpoints saved from a GPU session load on any runtime.
# The joint model is moved to `device` afterwards.
bert_state_dict = torch.load("/content/drive/MyDrive/PKG - MU-Glioma-Post/Models/BioClinicalBert/best.pt", map_location="cpu")
bert_model.load_state_dict(bert_state_dict)
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

# Image branch: same U-Net architecture as the segmentation cell. encoder_weights=None
# because the trained weights come from the checkpoint.
unet_model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=2,
    classes=1,
)
unet_state_dict = torch.load("/content/drive/MyDrive/PKG - MU-Glioma-Post/Models/UNet/best_unet_model.pth", map_location="cpu")
unet_model.load_state_dict(unet_state_dict)

# The CLIP loss contrasts each report against the *other* images in its batch. With a batch
# of 1 the logits are 1x1, cross-entropy is identically 0 and the contrastive term learns
# nothing. Use at least 2 (lower toward 2 if GPU memory is tight), and drop a trailing
# size-1 batch for the same reason.
CLIP_BATCH_SIZE = 8
train_loader_v2 = DataLoader(train_dataset, batch_size=CLIP_BATCH_SIZE, shuffle=True,
                             collate_fn=glioma_collate_fn, drop_last=True)

class ClipModel(nn.Module):
    """Joint model: CLIP-style text/image alignment plus tumor segmentation.

    How the outputs are built:
    * Text embeddings: the BERT encoder's ``pooler_output``.
    * Image embeddings: the deepest U-Net encoder feature map, averaged over space.
    * Both embeddings are linearly projected to ``embed_dim`` and L2-normalised.
    * The wrapped U-Net also produces the segmentation logits.

    ``combined_loss`` weights the contrastive and segmentation losses with learned
    log standard deviations, in the style of Kendall et al. (2018).

    Args:
        bert_model: Module exposing ``.bert`` (e.g. BioBERTMultiLabelClassifier).
        bert_tokenizer: Tokenizer matching ``bert_model``.
        unet_model: smp U-Net taking a 2-channel (MRI, regions) input.
        embed_dim: Size of the shared embedding space.
    """

    def __init__(self, bert_model, bert_tokenizer, unet_model, embed_dim=512):
        super().__init__()
        self.bert_model = bert_model
        self.unet_model = unet_model
        self.bert_tokenizer = bert_tokenizer
        # 768 = BERT-base hidden size; 512 = channels of the deepest resnet34 encoder stage.
        self.text_projection = nn.Linear(768, embed_dim)
        self.image_projection = nn.Linear(512, embed_dim)

        # Learned log standard deviations that weight the two loss terms.
        self.log_sigma_clip = nn.Parameter(torch.tensor(0.0))
        self.log_sigma_seg = nn.Parameter(torch.tensor(0.0))

    def forward(self, text, mri, regions):
        """Return ``(text_embeddings, image_embeddings, segmentation_logits)``.

        Args:
            text: List of report strings, one per sample.
            mri, regions: Batched tensors without a channel axis; ``unsqueeze(1)`` adds it.
        """
        # Follow the module's own device instead of a global, so inputs land wherever
        # the model currently lives.
        model_device = self.text_projection.weight.device

        encodings = tokenize_reports(text, self.bert_tokenizer)
        input_ids = encodings['input_ids'].to(model_device)
        attention_mask = encodings['attention_mask'].to(model_device)

        pooled_output = self.bert_model.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        # F.normalize clamps the norm (eps) so a zero vector cannot produce NaNs.
        text_embeddings = F.normalize(self.text_projection(pooled_output), dim=-1)

        x = torch.cat([mri.float().unsqueeze(1), regions.float().unsqueeze(1)], dim=1).to(model_device)
        # Global-average-pool the deepest encoder feature map into one vector per image.
        # NOTE(review): the encoder runs a second time inside self.unet_model(x) below.
        image_features = self.unet_model.encoder(x)[-1].mean(dim=[2, 3])
        image_embeddings = F.normalize(self.image_projection(image_features), dim=-1)

        segmentation_result = self.unet_model(x)
        return text_embeddings, image_embeddings, segmentation_result

    def clip_contrastive_loss(self, text_embeddings, image_embeddings, temperature=0.07):
        """Symmetric InfoNCE loss between matched text/image pairs of a batch.

        Row i of each input is a matching pair; every other row in the batch is a
        negative. Cosine similarities (inputs are expected to be L2-normalised) are
        divided by ``temperature``, as in CLIP. With a single pair the loss is
        identically 0, so batches need at least 2 samples.
        """
        # Divide by the temperature: 1/0.07 ~= 14.3 sharpens the softmax. Multiplying by
        # exp(temperature) ~= 1.07 would leave the logits almost flat.
        logits = text_embeddings @ image_embeddings.T / temperature
        targets = torch.arange(logits.shape[0], device=logits.device)
        loss_text_to_image = F.cross_entropy(logits, targets)
        loss_image_to_text = F.cross_entropy(logits.T, targets)
        return (loss_text_to_image + loss_image_to_text) / 2

    def segmentation_loss(self, segmentation_result, target, criterion):
        """Apply ``criterion`` to the logits and the target mask (a channel axis is added)."""
        target = target.float().unsqueeze(1).to(segmentation_result.device)
        return criterion(segmentation_result, target)

    def combined_loss(self, text_embeddings, image_embeddings, segmentation_result, target, criterion, temperature=0.07):
        """Combine the two losses with learned uncertainty weights.

        loss = L_clip / (2*sigma_clip^2) + L_seg / (2*sigma_seg^2) + log sigma_clip + log sigma_seg,
        with sigma = exp(log_sigma_*). The log terms stop the learned weights from
        collapsing to 0.
        """
        clip_loss = self.clip_contrastive_loss(text_embeddings, image_embeddings, temperature)
        seg_loss = self.segmentation_loss(segmentation_result, target, criterion)
        loss = (
            (1.0 / (2.0 * torch.exp(self.log_sigma_clip) ** 2)) * clip_loss +
            (1.0 / (2.0 * torch.exp(self.log_sigma_seg) ** 2)) * seg_loss +
            self.log_sigma_clip + self.log_sigma_seg
        )
        return loss


# Joint model; moving it to `device` also moves the wrapped BERT and U-Net submodules.
clipModel = ClipModel(
    bert_model=bert_model,
    bert_tokenizer=bert_tokenizer,
    unet_model=unet_model,
).to(device)



In [None]:
def calc_metrics(y_true, y_pred):
    """Compute Dice, IoU and pixel accuracy for a pair of binary masks.

    Args:
        y_true: Ground-truth mask (any shape); non-zero / True means foreground.
        y_pred: Predicted mask with the same number of elements.

    Returns:
        Tuple ``(dice, iou, acc)`` of floats over all elements. If both masks are
        empty, Dice and IoU are 1.0 (perfect agreement) rather than 0.
    """
    t = np.asarray(y_true, dtype=bool).ravel()
    p = np.asarray(y_pred, dtype=bool).ravel()

    intersection = np.count_nonzero(t & p)
    total = np.count_nonzero(t) + np.count_nonzero(p)
    union = total - intersection  # |A u B| = |A| + |B| - |A n B|

    dice = 2.0 * intersection / total if total else 1.0
    iou = intersection / union if union else 1.0
    acc = float(np.mean(t == p))
    return dice, iou, acc


def plot_prediction(imgs, save_path="/content/drive/MyDrive/PKG - MU-Glioma-Post/Graphs/sample_prediction_CLIP.png"):
    """Show MRI, regions, ground-truth and predicted mask for the first sample of a batch.

    Args:
        imgs: Dict of numpy arrays under 'mri', 'regions', 'y_true' and 'y_pred', indexed
            as [B, C, H, W]; sample 0, channel 0 of each is plotted.
        save_path: PNG output path (defaults to the CLIP location). Its directory is
            created if missing.
    """
    import os

    panels = (
        ('mri', "MRI"),
        ('regions', "Regions"),
        ('y_true', "Ground Truth Mask"),
        ('y_pred', "Predicted Mask"),
    )
    fig, axs = plt.subplots(1, len(panels), figsize=(20, 5))
    for ax, (key, title) in zip(axs, panels):
        ax.imshow(imgs[key][0, 0], cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(save_path)
    plt.show()
    plt.close(fig)

def plot_metrics(train_metrics, val_metrics, save_path):
    """Plot train vs. validation loss, Dice, IoU and accuracy curves in a 2x2 grid.

    Args:
        train_metrics: Dict of per-epoch lists under 'loss', 'dice', 'iou' and 'acc'.
        val_metrics: Same structure for the validation pass.
        save_path: PNG output path; its directory is created if missing.
    """
    import os

    epochs = range(1, len(train_metrics['loss']) + 1)
    panels = (
        ('loss', 'Loss', 'Loss'),
        ('dice', 'Dice', 'Dice Score'),
        ('iou', 'IoU', 'IoU'),
        ('acc', 'Accuracy', 'Accuracy'),
    )

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    for ax, (key, label, title) in zip(axes.flat, panels):
        ax.plot(epochs, train_metrics[key], label=f'Train {label}')
        ax.plot(epochs, val_metrics[key], label=f'Val {label}')
        ax.set_xlabel('Epoch')
        ax.set_title(title)
        ax.legend()
    fig.tight_layout()

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(save_path)
    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)

def train_one_epoch_clip(model, dataloader, optimizer, criterion):
    """Run one training epoch of the CLIP-based segmentation model.

    For every batch the report text, the MRI and the atlas regions are passed to
    the model as separate inputs; the loss from ``model.combined_loss`` is
    back-propagated and the optimizer stepped.

    Args:
        model: module called as ``model(text, mri, regions)`` returning
            ``(text_embeddings, image_embeddings, segmentation_logits)`` and
            exposing ``combined_loss(...)``.
        dataloader: iterable of dicts with keys 'mri', 'regions', 'tumor', 'report'.
        optimizer: optimizer over the model's parameters.
        criterion: segmentation loss forwarded to ``model.combined_loss``.

    Returns:
        (loss, dice, iou, acc), each averaged over batches.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss, total_dice, total_iou, total_acc = 0.0, 0.0, 0.0, 0.0
    n_batches = 0
    for batch in dataloader:
        optimizer.zero_grad()

        # MRI and regions are given to the model as separate inputs (not concatenated).
        mri = batch['mri']
        regions = batch['regions']
        tumor = batch['tumor']
        text = batch['report']

        text_embeddings, image_embeddings, segmentation_result = model(text, mri, regions)
        loss = model.combined_loss(text_embeddings, image_embeddings, segmentation_result, tumor, criterion)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Logits -> sigmoid -> binary mask at threshold 0.5
        pred_mask = torch.sigmoid(segmentation_result).detach().cpu().numpy() > 0.5
        # Metrics run on CPU, so the ground truth goes straight to CPU.
        # unsqueeze(1) adds a channel axis -- presumably to match [B, 1, H, W] logits; TODO confirm.
        true_mask = tumor.float().unsqueeze(1).detach().cpu().numpy() > 0.5

        dice, iou, acc = calc_metrics(true_mask, pred_mask)
        total_dice += dice
        total_iou += iou
        total_acc += acc
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_one_epoch_clip: dataloader yielded no batches")
    return (total_loss / n_batches, total_dice / n_batches,
            total_iou / n_batches, total_acc / n_batches)


def eval_one_epoch_clip(model, dataloader, criterion):
    """Evaluate the CLIP-based segmentation model over one pass of ``dataloader``.

    Runs with the model in eval mode and under ``torch.no_grad()``.

    Args:
        model: module called as ``model(text, mri, regions)`` returning
            ``(text_embeddings, image_embeddings, segmentation_logits)`` and
            exposing ``combined_loss(...)``.
        dataloader: iterable of dicts with keys 'mri', 'regions', 'tumor', 'report'.
        criterion: segmentation loss forwarded to ``model.combined_loss``.

    Returns:
        (loss, dice, iou, acc, last_batch_imgs). The four metrics are averaged over
        batches. ``last_batch_imgs`` holds numpy arrays from the final batch, for
        plotting: 'mri', 'regions' and 'y_true' (each with a channel axis inserted
        at dim 1) and 'y_pred' (boolean mask).

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss, total_dice, total_iou, total_acc = 0.0, 0.0, 0.0, 0.0
    n_batches = 0
    last_batch = None  # (mri, regions, tumor, pred_mask) of the most recent batch
    with torch.no_grad():
        for batch in dataloader:
            mri = batch['mri']
            regions = batch['regions']
            tumor = batch['tumor']
            text = batch['report']

            text_embeddings, image_embeddings, segmentation_result = model(text, mri, regions)
            loss = model.combined_loss(text_embeddings, image_embeddings, segmentation_result, tumor, criterion)
            total_loss += loss.item()

            # Logits -> sigmoid -> binary mask at threshold 0.5
            pred_mask = torch.sigmoid(segmentation_result).cpu().numpy() > 0.5
            # Metrics run on CPU, so the ground truth goes straight to CPU.
            true_mask = tumor.float().unsqueeze(1).cpu().numpy() > 0.5

            dice, iou, acc = calc_metrics(true_mask, pred_mask)
            total_dice += dice
            total_iou += iou
            total_acc += acc
            n_batches += 1

            # Keep references only; the numpy conversion for plotting is done once, after the loop.
            last_batch = (mri, regions, tumor, pred_mask)

        if n_batches == 0:
            raise ValueError("eval_one_epoch_clip: dataloader yielded no batches")

        mri, regions, tumor, pred_mask = last_batch
        last_batch_imgs = {
            'mri': mri.float().unsqueeze(1).cpu().numpy(),
            'regions': regions.float().unsqueeze(1).cpu().numpy(),
            'y_true': tumor.float().unsqueeze(1).cpu().numpy(),
            'y_pred': pred_mask,
        }

    return (total_loss / n_batches, total_dice / n_batches, total_iou / n_batches,
            total_acc / n_batches, last_batch_imgs)

import os

model = clipModel

# Single source of truth for the schedule length: LinearLR decays the LR linearly to 0
# over exactly `num_epochs` scheduler steps.
num_epochs = 50
PREVIEW_LAST_N_EPOCHS = 5  # show validation predictions for the final N epochs
BEST_MODEL_PATH = "/content/drive/MyDrive/PKG - MU-Glioma-Post/Models/CLIP_Based/best_clip_model.pth"
METRICS_PLOT_PATH = "/content/drive/MyDrive/PKG - MU-Glioma-Post/Graphs/training_metrics_CLIP.png"

criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.0, total_iters=num_epochs)

best_val_iou = 0
train_metrics = {'loss': [], 'dice': [], 'iou': [], 'acc': []}
val_metrics = {'loss': [], 'dice': [], 'iou': [], 'acc': []}
last_val_batch_imgs = None  # images from the most recent validation pass

# torch.save / savefig do not create missing directories.
os.makedirs(os.path.dirname(BEST_MODEL_PATH), exist_ok=True)
os.makedirs(os.path.dirname(METRICS_PLOT_PATH), exist_ok=True)

for epoch in range(1, num_epochs + 1):
    train_loss, train_dice, train_iou, train_acc = train_one_epoch_clip(model, train_loader, optimizer, criterion)
    # NOTE(review): the checkpoint below is selected on `test_loader`; if a separate
    # validation split exists, use it here so the test set stays untouched.
    val_loss, val_dice, val_iou, val_acc, val_imgs = eval_one_epoch_clip(model, test_loader, criterion)
    scheduler.step()

    for key, value in zip(('loss', 'dice', 'iou', 'acc'), (train_loss, train_dice, train_iou, train_acc)):
        train_metrics[key].append(value)
    for key, value in zip(('loss', 'dice', 'iou', 'acc'), (val_loss, val_dice, val_iou, val_acc)):
        val_metrics[key].append(value)

    print(f"Epoch {epoch}/{num_epochs}")
    print(f" Train   - Loss: {train_loss:.4f}, Dice: {train_dice:.4f}, IoU: {train_iou:.4f}, Acc: {train_acc:.4f}")
    print(f" Valid   - Loss: {val_loss:.4f}, Dice: {val_dice:.4f}, IoU: {val_iou:.4f}, Acc: {val_acc:.4f}")

    if val_iou > best_val_iou:
        best_val_iou = val_iou
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print("  Saved best model!")

    last_val_batch_imgs = val_imgs
    # Preview exactly the last PREVIEW_LAST_N_EPOCHS epochs; the final epoch is among them,
    # so its prediction is not plotted a second time after the loop.
    if epoch > num_epochs - PREVIEW_LAST_N_EPOCHS:
        plot_prediction(val_imgs)

# With previews disabled, still show prediction vs. ground truth for the last batch once.
if PREVIEW_LAST_N_EPOCHS <= 0 and last_val_batch_imgs is not None:
    plot_prediction(last_val_batch_imgs)

# Save the metric curves
plot_metrics(train_metrics, val_metrics, METRICS_PLOT_PATH)

**<h2>Libraries</h2>**

In [None]:
!pip install segmentation_models_pytorch

# **M1 - Data Preprocessing**

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

# Patient shown in this sanity-check figure
SAMPLE_PID = "PatientID_0036"
# The setup cell extracts the archive to /content/Preprocessed-Data/<pid>/<pid>_*.npy,
# the same per-patient layout the partitioning code reads. Files uploaded directly
# to /content are still picked up as a fallback.
SAMPLE_SEARCH_DIRS = [os.path.join("/content/Preprocessed-Data", SAMPLE_PID), "/content"]


def load_sample_array(kind, pid=SAMPLE_PID, search_dirs=SAMPLE_SEARCH_DIRS):
    """Load ``<pid>_<kind>.npy`` from the first directory in ``search_dirs`` that contains it.

    Raises:
        FileNotFoundError: if no directory in ``search_dirs`` has the file.
    """
    filename = f"{pid}_{kind}.npy"
    for directory in search_dirs:
        path = os.path.join(directory, filename)
        if os.path.isfile(path):
            return np.load(path)
    raise FileNotFoundError(f"{filename} not found in any of {search_dirs}")


# Load all files
mri = load_sample_array("mri")
tumor = load_sample_array("tumor")
atlas = load_sample_array("regions")

# Print shapes
print("MRI shape:", mri.shape)
print("Tumor shape:", tumor.shape)
print("Atlas shape:", atlas.shape)

# Plot them side by side
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

axs[0].imshow(mri, cmap="gray")
axs[0].set_title("MRI")

axs[1].imshow(tumor, cmap="hot")
axs[1].set_title("Tumor mask")

atlas_img = axs[2].imshow(atlas, cmap="nipy_spectral")
axs[2].set_title("Atlas regions")
fig.colorbar(atlas_img, ax=axs[2])

for ax in axs:
    ax.axis("off")

plt.show()

***MRI (left image)***

- This is the actual brain scan.

- It shows the anatomical structures and tissue appearance.

- The model uses this image as the input.

--------------------------------------------
***Tumor Mask (middle image)***

This is the ground-truth annotation.

- Yellow = whole tumor region

- Red = core or high-confidence tumor area

- Black = background

The model is trained to predict this mask. The dataset loader binarizes it, so every value > 0.5 counts as tumor.

--------------------------------------------
***Atlas Regions (right image)***

- This is a brain anatomical atlas aligned to the MRI.

- Each color corresponds to a different anatomical region.

- It gives the model context about where in the brain each pixel is located.


In [None]:
!unzip -q /content/Preprocessed-Data.zip -d /content/Preprocessed-Data

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

DATA_ROOT = "/content/Preprocessed-Data"
USE_ATLAS = True
# Same exclusion as the partitioning script (exclude_ids=["PatientID_0191"]), so the plots
# describe the patients that are actually split. NOTE(review): that script also keeps only
# patients listed in cleaned_df.pkl whose MRI/regions files exist; not mirrored here.
EXCLUDE_IDS = {"PatientID_0191"}
N_SIZE_BINS = 3  # quantile bins over tumor area (tertiles), as in the partitioning


def compute_patient_meta(pid, data_root=DATA_ROOT, use_atlas=USE_ATLAS):
    """Return a metadata row for one patient folder, or None if its tumor mask is missing.

    Row keys:
        pid: the patient ID.
        has_tumor: 1 if any mask value > 0.5, else 0.
        area: number of mask values > 0.5.
        dom: atlas label covering the most tumor voxels; -1 when there is no atlas or no tumor.
    """
    patient_dir = os.path.join(data_root, pid)
    tumor_path = os.path.join(patient_dir, f"{pid}_tumor.npy")
    atlas_path = os.path.join(patient_dir, f"{pid}_regions.npy")
    if not os.path.isfile(tumor_path):
        return None

    mask = np.load(tumor_path).astype(np.float32) > 0.5
    area = float(mask.sum())

    dom_region = -1
    if use_atlas and os.path.isfile(atlas_path) and mask.any():
        atlas_labels = np.load(atlas_path).astype(np.int32)
        vals, counts = np.unique(atlas_labels[mask], return_counts=True)
        if len(vals) > 0:
            dom_region = int(vals[np.argmax(counts)])

    return {"pid": pid, "has_tumor": 1 if area >= 1 else 0, "area": area, "dom": dom_region}


def quantile_size_bins(areas, n_bins=N_SIZE_BINS):
    """Assign each area a quantile bin in [0, n_bins); everything is bin 0 if the quantiles collapse."""
    qs = np.unique(np.quantile(areas, np.linspace(0, 1, n_bins + 1)))
    if len(qs) > 2:
        return np.digitize(areas, qs[1:-1], right=True).astype(int)
    return np.zeros_like(areas, dtype=int)


# 1. Iterate over all patient folders in DATA_ROOT
rows = []
for pid in sorted(os.listdir(DATA_ROOT)):
    if pid in EXCLUDE_IDS or not os.path.isdir(os.path.join(DATA_ROOT, pid)):
        continue
    try:
        row = compute_patient_meta(pid)
    except Exception as e:
        print(f"[WARN] Failed for {pid}: {e}")
        continue
    if row is not None:  # None: tumor mask missing
        rows.append(row)

# 2. Build DataFrame with metadata
meta = pd.DataFrame(rows)
if meta.empty:
    raise RuntimeError(f"No patient metadata could be built from {DATA_ROOT}")
print("Metadata head:")
print(meta.head())
print("\nColumns:", meta.columns.tolist())
print("\nNumber of patients:", len(meta))

# 3. Tumor size bins, computed as in the partitioning (patients without a tumor stay in bin 0)
meta["size_bin"] = 0
mask_has_tumor = meta["has_tumor"] == 1
if mask_has_tumor.any():
    meta.loc[mask_has_tumor, "size_bin"] = quantile_size_bins(meta.loc[mask_has_tumor, "area"].values)

print("\nSize bin value counts:")
print(meta["size_bin"].value_counts().sort_index())

# 4. Plot the distributions used in data partitioning
fig, axs = plt.subplots(2, 2, figsize=(12, 8))

# Tumor size bins
meta["size_bin"].value_counts().sort_index().plot(kind="bar", ax=axs[0, 0])
axs[0, 0].set_title("Tumor size bins")
axs[0, 0].set_xlabel("size_bin")
axs[0, 0].set_ylabel("Number of patients")

# Has tumor (0/1)
meta["has_tumor"].value_counts().sort_index().plot(kind="bar", ax=axs[0, 1])
axs[0, 1].set_title("Has tumor")
axs[0, 1].set_xlabel("has_tumor (0 = no, 1 = yes)")
axs[0, 1].set_ylabel("Number of patients")

# Tumor area histogram
meta["area"].hist(bins=30, ax=axs[1, 0])
axs[1, 0].set_title("Tumor area (voxel count)")
axs[1, 0].set_xlabel("area")
axs[1, 0].set_ylabel("Number of patients")

# Dominant atlas regions (top 10; -1 = no atlas / no tumor)
meta["dom"].value_counts().head(10).plot(kind="bar", ax=axs[1, 1])
axs[1, 1].set_title("Top 10 dominant atlas regions")
axs[1, 1].set_xlabel("region label")
axs[1, 1].set_ylabel("Number of patients")

plt.tight_layout()
plt.show()

***Tumor size bins***

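- These bins split the tumor voxel count into three groups using tertiles, so each bin holds roughly one third of the tumor-bearing patients. Patients without a tumor are also placed in bin 0.

- When the chosen stratification level includes the size bin, small, medium, and large tumors stay balanced across client splits and train/validation splits.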

------------------------------------
***Tumor area (voxel count)***

- This histogram shows the raw distribution of tumor sizes in the dataset.

- It highlights the variability in tumor burden, from small lesions to very large ones.

------------------------------------
***Top dominant atlas regions***

- For each patient, the dominant region is the atlas label covering the largest number of tumor voxels. It is -1 when no atlas is available.

- This plot shows which regions occur most frequently and helps guide stratification, since some regions are common while others are rare.


In [None]:
import pickle
import pandas as pd
import matplotlib.pyplot as plt

# Load the cleaned dataframe (a locally produced file; never pickle.load untrusted input)
with open("cleaned_df.pkl", "rb") as f:
    df = pickle.load(f)

# Print all column names
print("Columns in cleaned_df:")
for c in df.columns:
    print(" -", c)

# Column names used for plotting
col_sex  = "Sex at Birth"
col_race = "Race"
col_age  = "Age at diagnosis"
col_diag = "Primary Diagnosis"

AGE_BIN_EDGES = [0, 20, 40, 60, 80, 120]


def plot_category_pie(counts, ax, title):
    """Draw a value_counts()-style Series as a pie chart with percentage labels on ``ax``."""
    counts.plot(kind="pie", autopct='%1.1f%%', ax=ax, ylabel="")
    ax.set_title(title)


fig, axs = plt.subplots(2, 2, figsize=(14, 10))

plot_category_pie(df[col_sex].value_counts(dropna=False), axs[0, 0], "Sex at Birth")
plot_category_pie(df[col_race].value_counts(dropna=False), axs[0, 1], "Race")
plot_category_pie(df[col_diag].value_counts(dropna=False), axs[1, 0], "Primary Diagnosis")

# include_lowest keeps an age of exactly 0 in the first bin.
# dropna=False keeps missing ages visible, like the three pies above.
age_bins = pd.cut(df[col_age], bins=AGE_BIN_EDGES, include_lowest=True)
plot_category_pie(age_bins.value_counts(dropna=False).sort_index(), axs[1, 1], "Age at Diagnosis (binned)")

plt.tight_layout()
plt.show()

**<h1>Data Partitioning</h1>**

<h2>One client, all data</h2>

In [None]:
%%writefile data_prep_split.py
import os, json, pickle, random
import numpy as np
import pandas as pd
from typing import List, Dict, Tuple
from sklearn.model_selection import StratifiedShuffleSplit

import torch
from torch.utils.data import Dataset, DataLoader

# Data locations and output directory.
# NOTE(review): the 3-client cell further down writes the same data_prep_split.py and uses
# the same CLIENT_DIR, so whichever cell runs last wins.
OUT_BASE = "AITDM"
CLIENT_DIR = os.path.join(OUT_BASE, "client")
os.makedirs(CLIENT_DIR, exist_ok=True)
DATA_ROOT = "Preprocessed-Data"          # one sub-folder per patient with the .npy files
METADATA_DF_PATH = "cleaned_df.pkl"      # pickled metadata with a "Patient_ID" column

# Split and loader settings
USE_ATLAS = True            # also load the atlas-region map as a second input channel
N_CLIENTS = 1
VAL_FRAC_PER_CLIENT = 0.2   # fraction of the client's patients held out for validation
BATCH_SIZE = 8
NUM_WORKERS = 2
SEED = 42


def set_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device) with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)


set_seed(SEED)

class ImageOnlyGliomaDataset(Dataset):
    """Dataset that loads MRI, tumor mask and optional atlas for each patient.

    Files are read from ``<data_root>/<pid>/<pid>_mri.npy`` and ``<pid>_tumor.npy``.
    When ``use_atlas`` is set, ``<pid>_regions.npy`` is read as well.

    A patient is kept only if it appears in the metadata DataFrame (column
    ``Patient_ID``), is not excluded, and has all required files. Kept IDs are
    sorted.

    Each item is a dict with:
        patient_id: the patient ID.
        mri: min-max normalized to [0, 1].
        tumor: binarized at 0.5, float32.
        regions: min-max normalized (only with the atlas).
    If a ``transform`` was given, the dict is passed through it before being returned.
    """

    def __init__(self, metadata_df_path, data_root, use_atlas=True, exclude_ids=None, transform=None):
        """
        Args:
            metadata_df_path: path to a pickled pandas DataFrame with a ``Patient_ID`` column.
                pickle can execute arbitrary code on load -- only use trusted files.
            data_root: directory with one sub-folder per patient.
            use_atlas: also require and load the atlas-region map.
            exclude_ids: patient IDs to drop; ``None`` means ``["PatientID_0191"]``.
            transform: optional callable applied to each sample dict in ``__getitem__``.
        """
        with open(metadata_df_path, "rb") as f:
            df = pickle.load(f)

        if exclude_ids is None:
            exclude_ids = ["PatientID_0191"]

        # Filter out excluded patient IDs
        self.df = df[~df["Patient_ID"].isin(exclude_ids)].reset_index(drop=True)
        self.data_root = data_root
        self.use_atlas = use_atlas
        self.transform = transform

        # Keep only patients for which all required .npy files exist
        self.patient_ids = [
            pid for pid in sorted(self.df["Patient_ID"].tolist())
            if all(os.path.isfile(p) for p in self._required_paths(pid))
        ]

    def _required_paths(self, pid):
        """Return the .npy paths that must exist for ``pid`` (regions only when the atlas is used)."""
        base = os.path.join(self.data_root, pid)
        kinds = ["mri", "tumor"] + (["regions"] if self.use_atlas else [])
        return [os.path.join(base, f"{pid}_{kind}.npy") for kind in kinds]

    def __len__(self):
        return len(self.patient_ids)

    @staticmethod
    def _minmax(x):
        """Min-max normalize ``x`` to [0, 1] as float32; a constant array maps to zeros."""
        x = x.astype(np.float32)
        mn, mx = np.min(x), np.max(x)
        return (x - mn) / (mx - mn) if mx > mn else np.zeros_like(x, dtype=np.float32)

    def __getitem__(self, idx):
        """Load one patient sample (MRI, tumor mask, optional atlas) and apply ``transform``."""
        pid = self.patient_ids[idx]
        base = os.path.join(self.data_root, pid)

        mri = np.load(os.path.join(base, f"{pid}_mri.npy")).astype(np.float32)
        tumor = np.load(os.path.join(base, f"{pid}_tumor.npy")).astype(np.float32)

        # Normalize MRI and binarize tumor mask
        sample = {
            "patient_id": pid,
            "mri": self._minmax(mri),
            "tumor": (tumor > 0.5).astype(np.float32),
        }

        # Optionally load and normalize atlas regions
        if self.use_atlas:
            regions = np.load(os.path.join(base, f"{pid}_regions.npy")).astype(np.float32)
            sample["regions"] = self._minmax(regions)

        # Apply the optional per-sample transform given to the constructor
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

def image_only_collate_fn(batch, use_atlas=True):
    """
    Collate sample dicts into a batch dict:
    x: [B, C, H, W] float (MRI channel, plus the atlas channel when ``use_atlas``),
    y: [B, 1, H, W] float tumor mask, pid: list of patient IDs in batch order.
    """
    def _stack_channel(key):
        # Stack one per-sample array across the batch and add a channel axis.
        return torch.stack([torch.tensor(item[key]) for item in batch]).unsqueeze(1)

    channels = [_stack_channel("mri").float()]
    if use_atlas:
        channels.append(_stack_channel("regions").float())
    x = channels[0] if len(channels) == 1 else torch.cat(channels, dim=1)

    return {
        "x": x,
        "y": _stack_channel("tumor").float(),
        "pid": [item["patient_id"] for item in batch],
    }

def patient_meta(pid: str, data_root=None, use_atlas=None) -> Tuple[int, float, int]:
    """
    Compute simple metadata for one patient:
    - has_tumor: 1 if at least one mask value is > 0.5, else 0
    - area: number of mask values > 0.5 (as float)
    - dom_region: atlas label covering the most tumor voxels, or -1 when the atlas
      is disabled/missing or the mask is empty

    Args:
        pid: patient ID; files are read from ``<data_root>/<pid>/``.
        data_root: root directory; ``None`` uses the module-level ``DATA_ROOT``
            (looked up at call time).
        use_atlas: whether to compute the dominant region; ``None`` uses ``USE_ATLAS``.
    """
    if data_root is None:
        data_root = DATA_ROOT
    if use_atlas is None:
        use_atlas = USE_ATLAS

    base = os.path.join(data_root, pid)

    tumor = np.load(os.path.join(base, f"{pid}_tumor.npy")).astype(np.float32)
    mask = tumor > 0.5
    area = float(mask.sum())
    has_tumor = 1 if area >= 1 else 0

    dom_region = -1
    reg_path = os.path.join(base, f"{pid}_regions.npy")

    # If atlas is present and tumor exists, find region with most tumor voxels
    if use_atlas and os.path.isfile(reg_path) and mask.any():
        regs = np.load(reg_path).astype(np.int32)
        vals, counts = np.unique(regs[mask], return_counts=True)
        if len(vals) > 0:
            dom_region = int(vals[np.argmax(counts)])

    return has_tumor, area, dom_region

def build_meta_for(dataset: ImageOnlyGliomaDataset) -> pd.DataFrame:
    """
    Build a DataFrame with metadata for each patient in ``dataset.patient_ids``:
    - pid
    - has_tumor
    - area
    - dom (dominant region)
    - size_bin (tertile of tumor area among tumor-bearing patients; 0 otherwise)
    - strat_label (full stratification label "has_tumor_dom_sizebin")

    Patients whose metadata cannot be computed are skipped with a warning. The
    result always has these columns, even when no patient could be processed.
    """
    rows = []
    for pid in dataset.patient_ids:
        try:
            ht, area, dom = patient_meta(pid)
        except Exception as e:
            # Best-effort, but visible: a silent skip would make the split cover
            # fewer patients than the dataset without any trace.
            print(f"[WARN] Skipping {pid} in metadata: {e}")
            continue
        rows.append({"pid": pid, "has_tumor": ht, "area": area, "dom": dom})

    # Explicit columns keep the code below valid when `rows` is empty
    meta = pd.DataFrame(rows, columns=["pid", "has_tumor", "area", "dom"])

    # Default size bin = 0 (no tumor or very small)
    meta["size_bin"] = 0
    m = meta["has_tumor"] == 1

    # For patients with tumor, compute quantile-based size bins
    if m.any():
        areas = meta.loc[m, "area"].values
        qs = np.quantile(areas, np.linspace(0, 1, 4))
        qs = np.unique(qs)
        if len(qs) > 2:
            bins = np.digitize(areas, qs[1:-1], right=True)
        else:
            bins = np.zeros_like(areas, dtype=int)
        meta.loc[m, "size_bin"] = bins.astype(int)

    # Full stratification label combining several attributes
    meta["strat_label"] = [
        f"{int(ht)}_{int(dom)}_{int(sb)}"
        for ht, dom, sb in zip(meta["has_tumor"], meta["dom"], meta["size_bin"])
    ]
    return meta

def build_label(meta_df: pd.DataFrame, level: str) -> np.ndarray:
    """
    Build one stratification label per row at the requested granularity:
    - "full": has_tumor + dom + size_bin
    - "ht_dom": has_tumor + dom
    - "ht_size": has_tumor + size_bin
    - "ht": has_tumor only
    Multi-part labels join the integer values with "_" (e.g. "1_61_2").

    Raises:
        ValueError: for an unknown ``level`` (the level is the exception argument).
    """
    if level == "ht":
        return meta_df["has_tumor"].astype(str).values

    columns_for_level = {
        "full": ("has_tumor", "dom", "size_bin"),
        "ht_dom": ("has_tumor", "dom"),
        "ht_size": ("has_tumor", "size_bin"),
    }
    if level not in columns_for_level:
        raise ValueError(level)

    column_values = [meta_df[name] for name in columns_for_level[level]]
    return np.array([
        "_".join(str(int(value)) for value in row)
        for row in zip(*column_values)
    ])

def pick_strat_labels_for_client(meta_client_df: pd.DataFrame, min_per_class: int = 2):
    """
    Pick the most detailed stratification level whose every class has at least
    ``min_per_class`` patients. Levels are tried from finest to coarsest:
    "full" -> "ht_dom" -> "ht_size" -> "ht".

    Returns:
        ``(labels, level)`` for the first level that qualifies, or ``(None, "none")``
        when none does.
    """
    def _every_class_large_enough(labels):
        class_sizes = pd.Series(labels).value_counts()
        return (class_sizes >= min_per_class).all()

    for level in ("full", "ht_dom", "ht_size", "ht"):
        labels = build_label(meta_client_df, level)
        if not _every_class_large_enough(labels):
            continue
        print(f"[INFO] Using client train/val stratification level: {level}")
        return labels, level

    print("[WARN] Falling back to NON-stratified train/val.")
    return None, "none"

# Load full dataset and build metadata
dataset = ImageOnlyGliomaDataset(METADATA_DF_PATH, DATA_ROOT, use_atlas=USE_ATLAS, exclude_ids=["PatientID_0191"])
meta = build_meta_for(dataset)

clients: List[Dict[str, List[str]]] = []
client_metas = []

# In this script we only create a single "client_0" with a stratified train/val split
meta_c = meta.reset_index(drop=True)
Xc = meta_c["pid"].values

y_client, tv_level = pick_strat_labels_for_client(meta_c, min_per_class=2)

train_pids, val_pids = None, None
if tv_level != "none":
    # Stratified train/validation split. Besides ">= 2 samples per class" (checked above),
    # StratifiedShuffleSplit needs the train and the test part to each be at least as large
    # as the number of classes, and raises ValueError otherwise -> fall back to random.
    sss = StratifiedShuffleSplit(n_splits=1, test_size=VAL_FRAC_PER_CLIENT, random_state=SEED)
    try:
        (tr_idx, va_idx), = sss.split(Xc, y_client)
        train_pids = Xc[tr_idx].tolist()
        val_pids = Xc[va_idx].tolist()
    except ValueError as e:
        print(f"[WARN] Stratified train/val split failed ({e}); falling back to a random split.")
        tv_level = "none"

if train_pids is None:
    # Fallback: random split without stratification
    rng = np.random.RandomState(SEED)
    perm = rng.permutation(len(Xc))
    split_at = int((1.0 - VAL_FRAC_PER_CLIENT) * len(Xc))
    train_pids = Xc[perm[:split_at]].tolist()
    val_pids = Xc[perm[split_at:]].tolist()

# Save client_0 train/val patient ID lists
cdir = os.path.join(CLIENT_DIR, "client_0")
os.makedirs(cdir, exist_ok=True)
with open(os.path.join(cdir, "train_pids.json"), "w") as f:
    json.dump(train_pids, f, indent=2)
with open(os.path.join(cdir, "val_pids.json"), "w") as f:
    json.dump(val_pids, f, indent=2)

clients.append({"train": train_pids, "val": val_pids})
client_metas.append(meta_c)

# Save basic manifest of the split configuration
manifest = {
    "seed": SEED,
    "use_atlas": USE_ATLAS,
    "n_clients": N_CLIENTS,
    "val_frac_per_client": VAL_FRAC_PER_CLIENT,
    "batch_size": BATCH_SIZE,
    "num_workers": NUM_WORKERS,
    "kfold_level": "single_client"
}
with open(os.path.join(CLIENT_DIR, "manifest.json"), "w") as f:
    json.dump(manifest, f, indent=2)

class SubsetByPIDs(Dataset):
    """Restrict a base dataset to the patients whose IDs appear in ``pid_list``.

    The subset follows the order of ``pid_list``. IDs missing from
    ``full_dataset.patient_ids`` are ignored.
    """

    def __init__(self, full_dataset: ImageOnlyGliomaDataset, pid_list: List[str]):
        self.ds = full_dataset
        position_of = {}
        for position, patient_id in enumerate(self.ds.patient_ids):
            position_of[patient_id] = position
        # Base-dataset indices of the requested patients that actually exist
        self.indices = []
        for patient_id in pid_list:
            if patient_id in position_of:
                self.indices.append(position_of[patient_id])

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        base_index = self.indices[i]
        return self.ds[base_index]

def make_loader(ds, shuffle, batch_size=None, num_workers=None):
    """Create a DataLoader over ``ds`` that batches with ``image_only_collate_fn``.

    Args:
        ds: map-style dataset yielding sample dicts.
        shuffle: reshuffle every epoch (the generator is seeded with ``SEED``).
        batch_size: defaults to ``BATCH_SIZE`` when ``None``.
        num_workers: defaults to ``NUM_WORKERS`` when ``None``.
    """
    from functools import partial

    if batch_size is None:
        batch_size = BATCH_SIZE
    if num_workers is None:
        num_workers = NUM_WORKERS

    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinning only helps host->GPU copies; without CUDA it merely emits a warning.
        pin_memory=torch.cuda.is_available(),
        # functools.partial, unlike a lambda, can be pickled, which worker processes
        # require under the "spawn" start method.
        collate_fn=partial(image_only_collate_fn, use_atlas=USE_ATLAS),
        generator=torch.Generator().manual_seed(SEED),
    )

# Build (train, val) preview loaders per client for sanity checks
preview_loaders = [
    (
        make_loader(SubsetByPIDs(dataset, split["train"]), shuffle=True),
        make_loader(SubsetByPIDs(dataset, split["val"]), shuffle=False),
    )
    for split in clients
]

def summarize(meta_df: pd.DataFrame, name: str) -> Dict:
    """
    Return a compact summary of a metadata subset: its ``name``, patient count ``n``,
    counts per ``has_tumor`` value and per ``size_bin``, and the five most frequent
    dominant regions.
    """
    return {
        "name": name,
        "n": int(len(meta_df)),
        "has_tumor_counts": meta_df["has_tumor"].value_counts().to_dict(),
        "size_bin_counts": meta_df["size_bin"].value_counts().to_dict(),
        "dom_region_top5": meta_df["dom"].value_counts().head(5).to_dict(),
    }

# Summary over every patient that has metadata
summary_all = summarize(meta, "ALL")
print("\n=== Global summary ===")
print(summary_all)

# Per-client train/val summaries (a single client in this script)
per_client_summaries = []
for client_idx, split in enumerate(clients):
    client_meta = client_metas[client_idx]
    in_train = client_meta["pid"].isin(split["train"])
    in_val = client_meta["pid"].isin(split["val"])
    s_client = {
        "client": "client",
        "train": summarize(client_meta[in_train], "client_train"),
        "val": summarize(client_meta[in_val], "client_val"),
    }
    per_client_summaries.append(s_client)
    print("\n=== Client ===")
    print(s_client)

# Persist all summaries next to the split files
summary_path = os.path.join(CLIENT_DIR, "summary.json")
with open(summary_path, "w") as f:
    json.dump({"global": summary_all, "per_client": per_client_summaries}, f, indent=2)

print(f"\nSaved client split and summaries to: {CLIENT_DIR}")

**Data Distribution**

In [None]:
# Reference copy of a split summary in the format written to AITDM/client/summary.json
# by the single-client script above ("global" + "per_client" entries). This is a static
# literal for reading only: re-running the split does not update it.
{
  "global": {
    "name": "ALL",
    "n": 202,
    "has_tumor_counts": {
      "1": 202
    },
    "size_bin_counts": {
      "0": 68,
      "1": 67,
      "2": 67
    },
    "dom_region_top5": {
      "61": 85,
      "50": 81,
      "0": 10,
      "10": 4,
      "22": 3
    }
  },
  "per_client": [
    {
      "client": "client",
      "train": {
        "name": "client_train",
        "n": 161,
        "has_tumor_counts": {
          "1": 161
        },
        "size_bin_counts": {
          "2": 54,
          "0": 54,
          "1": 53
        },
        "dom_region_top5": {
          "61": 68,
          "50": 64,
          "0": 9,
          "10": 3,
          "15": 2
        }
      },
      "val": {
        "name": "client_val",
        "n": 41,
        "has_tumor_counts": {
          "1": 41
        },
        "size_bin_counts": {
          "1": 14,
          "0": 14,
          "2": 13
        },
        "dom_region_top5": {
          "61": 17,
          "50": 17,
          "31": 1,
          "53": 1,
          "10": 1
        }
      }
    }
  ]
}

<h2>Same distribution across 3 clients</h2>

In [None]:
%%writefile data_prep_split.py
import os, json, pickle, random
import numpy as np
import pandas as pd
from typing import List, Dict, Tuple
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

import torch
from torch.utils.data import Dataset, DataLoader

# Output directory for the client splits
OUT_BASE = "AITDM"
CLIENT_DIR = os.path.join(OUT_BASE, "client")
os.makedirs(CLIENT_DIR, exist_ok=True)

# Inputs
DATA_ROOT = "Preprocessed-Data"          # one sub-folder per patient with the .npy files
METADATA_DF_PATH = "cleaned_df.pkl"      # pickled metadata with a "Patient_ID" column

# Split and loader settings
USE_ATLAS = True
N_CLIENTS = 3               # number of clients = number of stratified folds
VAL_FRAC_PER_CLIENT = 0.2   # fraction of each client's patients held out for validation
BATCH_SIZE = 8
NUM_WORKERS = 2
SEED = 42


def set_seed(seed=42):
    """Seed ``random``, NumPy and PyTorch (CPU and every CUDA device) for reproducibility."""
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        apply_seed(seed)


set_seed(SEED)

# Dataset that loads MRI, tumor mask and optional atlas for each patient
class ImageOnlyGliomaDataset(Dataset):
    """Per-patient dataset of MRI, tumor mask and (optionally) atlas regions.

    Files are read from ``<data_root>/<pid>/<pid>_mri.npy`` and ``<pid>_tumor.npy``.
    When ``use_atlas`` is set, ``<pid>_regions.npy`` is read as well.

    A patient is kept only if it appears in the metadata DataFrame (column
    ``Patient_ID``), is not excluded, and has all required files. Kept IDs are
    sorted.

    Each item is a dict with:
        patient_id: the patient ID.
        mri: min-max normalized to [0, 1].
        tumor: binarized at 0.5, float32.
        regions: min-max normalized (only with the atlas).
    If a ``transform`` was given, the dict is passed through it before being returned.
    """

    def __init__(self, metadata_df_path, data_root, use_atlas=True, exclude_ids=None, transform=None):
        """
        Args:
            metadata_df_path: path to a pickled pandas DataFrame with a ``Patient_ID`` column.
                pickle can execute arbitrary code on load -- only use trusted files.
            data_root: directory with one sub-folder per patient.
            use_atlas: also require and load the atlas-region map.
            exclude_ids: patient IDs to drop; ``None`` means ``["PatientID_0191"]``.
            transform: optional callable applied to each sample dict in ``__getitem__``.
        """
        with open(metadata_df_path, "rb") as f:
            df = pickle.load(f)
        if exclude_ids is None:
            exclude_ids = ["PatientID_0191"]
        self.df = df[~df["Patient_ID"].isin(exclude_ids)].reset_index(drop=True)
        self.data_root = data_root
        self.use_atlas = use_atlas
        self.transform = transform

        # Collect patient IDs that have all required .npy files
        self.patient_ids = [
            pid for pid in sorted(self.df["Patient_ID"].tolist())
            if all(os.path.isfile(p) for p in self._required_paths(pid))
        ]

    def _required_paths(self, pid):
        """Return the .npy paths that must exist for ``pid`` (regions only when the atlas is used)."""
        base = os.path.join(self.data_root, pid)
        kinds = ["mri", "tumor"] + (["regions"] if self.use_atlas else [])
        return [os.path.join(base, f"{pid}_{kind}.npy") for kind in kinds]

    def __len__(self):
        return len(self.patient_ids)

    # Simple min-max normalization
    @staticmethod
    def _minmax(x):
        """Min-max normalize ``x`` to [0, 1] as float32; a constant array maps to zeros."""
        x = x.astype(np.float32)
        mn, mx = np.min(x), np.max(x)
        return (x - mn) / (mx - mn) if mx > mn else np.zeros_like(x, dtype=np.float32)

    # Load a single patient sample
    def __getitem__(self, idx):
        """Load one patient sample (MRI, tumor mask, optional atlas) and apply ``transform``."""
        pid = self.patient_ids[idx]
        base = os.path.join(self.data_root, pid)
        mri = np.load(os.path.join(base, f"{pid}_mri.npy")).astype(np.float32)
        tumor = np.load(os.path.join(base, f"{pid}_tumor.npy")).astype(np.float32)
        sample = {
            "patient_id": pid,
            "mri": self._minmax(mri),
            "tumor": (tumor > 0.5).astype(np.float32),
        }
        if self.use_atlas:
            regions = np.load(os.path.join(base, f"{pid}_regions.npy")).astype(np.float32)
            sample["regions"] = self._minmax(regions)
        # Apply the optional per-sample transform given to the constructor
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

# Collate sample dicts into batched tensors plus the list of patient IDs
def image_only_collate_fn(batch, use_atlas=True):
    """Return {"x": [B, C, H, W] float, "y": [B, 1, H, W] float, "pid": [str, ...]}.

    C is 1 (MRI) or 2 (MRI + atlas regions) depending on ``use_atlas``.
    """
    def _stack_channel(key):
        # Stack one per-sample array across the batch and add a channel axis.
        return torch.stack([torch.tensor(item[key]) for item in batch]).unsqueeze(1)

    channels = [_stack_channel("mri").float()]
    if use_atlas:
        channels.append(_stack_channel("regions").float())
    x = channels[0] if len(channels) == 1 else torch.cat(channels, dim=1)

    return {
        "x": x,
        "y": _stack_channel("tumor").float(),
        "pid": [item["patient_id"] for item in batch],
    }

# Compute basic metadata for a patient (tumor presence, area, dominant atlas region)
def patient_meta(pid: str, data_root=None, use_atlas=None) -> Tuple[int, float, int]:
    """Return ``(has_tumor, area, dom_region)`` for one patient.

    Returns:
        has_tumor: 1 if at least one mask value is > 0.5, else 0.
        area: number of mask values > 0.5, as float.
        dom_region: atlas label covering the most tumor voxels; -1 when the atlas is
            disabled or missing, or the mask is empty.

    Args:
        pid: patient ID; files are read from ``<data_root>/<pid>/``.
        data_root: root directory; ``None`` uses the module-level ``DATA_ROOT``
            (looked up at call time).
        use_atlas: whether to compute the dominant region; ``None`` uses ``USE_ATLAS``.
    """
    if data_root is None:
        data_root = DATA_ROOT
    if use_atlas is None:
        use_atlas = USE_ATLAS

    base = os.path.join(data_root, pid)
    tumor = np.load(os.path.join(base, f"{pid}_tumor.npy")).astype(np.float32)
    mask = tumor > 0.5
    area = float(mask.sum())
    has_tumor = 1 if area >= 1 else 0

    dom_region = -1
    reg_path = os.path.join(base, f"{pid}_regions.npy")
    if use_atlas and os.path.isfile(reg_path) and mask.any():
        regs = np.load(reg_path).astype(np.int32)
        vals, counts = np.unique(regs[mask], return_counts=True)
        if len(vals) > 0:
            dom_region = int(vals[np.argmax(counts)])
    return has_tumor, area, dom_region

# Build a metadata DataFrame for all patients in the dataset
def build_meta_for(dataset: ImageOnlyGliomaDataset) -> pd.DataFrame:
    """Return per-patient metadata for ``dataset.patient_ids``.

    Columns:
        pid, has_tumor, area, dom: as returned by ``patient_meta``.
        size_bin: tertile of area among tumor-bearing patients; 0 otherwise.
        strat_label: "has_tumor_dom_sizebin".

    Patients whose metadata cannot be computed are skipped with a warning. The
    columns are present even when no patient could be processed.
    """
    rows = []
    for pid in dataset.patient_ids:
        try:
            ht, area, dom = patient_meta(pid)
        except Exception as e:
            # Best-effort, but visible: a silent skip would shrink the splits without trace.
            print(f"[WARN] Skipping {pid} in metadata: {e}")
            continue
        rows.append({"pid": pid, "has_tumor": ht, "area": area, "dom": dom})

    # Explicit columns keep the code below valid when `rows` is empty
    meta = pd.DataFrame(rows, columns=["pid", "has_tumor", "area", "dom"])

    # Compute tumor size bins for patients with tumor
    meta["size_bin"] = 0
    m = meta["has_tumor"] == 1
    if m.any():
        areas = meta.loc[m, "area"].values
        qs = np.quantile(areas, np.linspace(0, 1, 4))
        qs = np.unique(qs)
        bins = np.digitize(areas, qs[1:-1], right=True) if len(qs) > 2 else np.zeros_like(areas, dtype=int)
        meta.loc[m, "size_bin"] = bins.astype(int)

    # Full stratification label combining several attributes
    meta["strat_label"] = [
        f"{int(ht)}_{int(dom)}_{int(sb)}"
        for ht, dom, sb in zip(meta["has_tumor"], meta["dom"], meta["size_bin"])
    ]
    return meta

# Build a stratification label with different levels of granularity
def build_label(meta_df: pd.DataFrame, level: str) -> np.ndarray:
    """Return one stratification label per row of ``meta_df``.

    Levels: "full" (has_tumor_dom_sizebin), "ht_dom", "ht_size", "ht" (has_tumor only).
    Multi-part labels join integer values with "_".

    Raises:
        ValueError: for an unknown ``level`` (the level is the exception argument).
    """
    if level == "ht":
        return meta_df["has_tumor"].astype(str).values

    columns_for_level = {
        "full": ("has_tumor", "dom", "size_bin"),
        "ht_dom": ("has_tumor", "dom"),
        "ht_size": ("has_tumor", "size_bin"),
    }
    if level not in columns_for_level:
        raise ValueError(level)

    column_values = [meta_df[name] for name in columns_for_level[level]]
    return np.array([
        "_".join(str(int(value)) for value in row)
        for row in zip(*column_values)
    ])

# Choose the stratification labels level for K-Fold client split
def pick_strat_labels_for_kfold(meta_df: pd.DataFrame, n_splits: int):
    """Pick the finest stratification level whose every class has >= ``n_splits`` patients.

    Levels are tried in the order "full" -> "ht_dom" -> "ht_size" -> "ht". A warning
    listing the undersized classes is printed for each rejected level.

    Returns:
        ``(labels, level)``, or ``(None, "none")`` if no level qualifies.
    """
    for level in ("full", "ht_dom", "ht_size", "ht"):
        labels = build_label(meta_df, level)
        class_sizes = pd.Series(labels).value_counts()
        too_small = class_sizes[class_sizes < n_splits]
        if too_small.empty:
            print(f"[INFO] Using K-Fold stratification level: {level}")
            return labels, level
        print(f"[WARN] Level '{level}' has rare classes (<{n_splits}): {too_small.to_dict()} -> trying coarser...")

    print("[WARN] Falling back to NON-stratified K-Fold (insufficient counts at all levels).")
    return None, "none"

# Choose the stratification labels level for the train/val split inside each client
def pick_strat_labels_for_client(meta_client_df: pd.DataFrame, min_per_class: int = 2):
    """Pick the finest level whose every class has >= ``min_per_class`` patients in this client.

    Levels are tried in the order "full" -> "ht_dom" -> "ht_size" -> "ht". A warning
    listing the undersized classes is printed for each rejected level.

    Returns:
        ``(labels, level)``, or ``(None, "none")`` if no level qualifies.
    """
    for level in ("full", "ht_dom", "ht_size", "ht"):
        labels = build_label(meta_client_df, level)
        class_sizes = pd.Series(labels).value_counts()
        too_small = class_sizes[class_sizes < min_per_class]
        if too_small.empty:
            print(f"[INFO] Using client train/val stratification level: {level}")
            return labels, level
        print(f"[WARN] Client level '{level}' rare classes (<{min_per_class}): {too_small.to_dict()} -> trying coarser...")

    print("[WARN] Falling back to NON-stratified train/val (insufficient counts at all levels).")
    return None, "none"

# Load dataset and metadata for all patients
dataset = ImageOnlyGliomaDataset(METADATA_DF_PATH, DATA_ROOT, use_atlas=USE_ATLAS, exclude_ids=["PatientID_0191"])
meta = build_meta_for(dataset)

# Prepare global patient IDs and K-Fold labels
X_all = meta["pid"].values
y_all_full, kfold_level = pick_strat_labels_for_kfold(meta, n_splits=N_CLIENTS)

clients: List[Dict[str, List[str]]] = []
client_metas = []


def _save_client_split(split_idx: int, train_pids: List[str], val_pids: List[str], meta_c: pd.DataFrame) -> None:
    """Write one client's train/val patient IDs under CLIENT_DIR/client_<idx> and record them in memory."""
    cdir = os.path.join(CLIENT_DIR, f"client_{split_idx}")
    os.makedirs(cdir, exist_ok=True)
    with open(os.path.join(cdir, "train_pids.json"), "w") as f:
        json.dump(train_pids, f, indent=2)
    with open(os.path.join(cdir, "val_pids.json"), "w") as f:
        json.dump(val_pids, f, indent=2)
    clients.append({"train": train_pids, "val": val_pids})
    client_metas.append(meta_c)


# Create client splits and per-client train/val splits
if kfold_level != "none":
    # Stratified K-Fold: the held-out indices of each fold become one client's patients.
    skf = StratifiedKFold(n_splits=N_CLIENTS, shuffle=True, random_state=SEED)
    for split_idx, (_, idx) in enumerate(skf.split(X_all, y_all_full)):
        pids_client = X_all[idx]
        meta_c = meta[meta["pid"].isin(pids_client)].reset_index(drop=True)

        y_client, tv_level = pick_strat_labels_for_client(meta_c, min_per_class=2)
        Xc = meta_c["pid"].values

        if tv_level != "none":
            sss = StratifiedShuffleSplit(n_splits=1, test_size=VAL_FRAC_PER_CLIENT, random_state=SEED)
            (tr_idx, va_idx), = sss.split(Xc, y_client)
            train_pids = Xc[tr_idx].tolist()
            val_pids = Xc[va_idx].tolist()
        else:
            rng = np.random.RandomState(SEED)
            perm = rng.permutation(len(Xc))
            split_at = int((1.0 - VAL_FRAC_PER_CLIENT) * len(Xc))
            train_pids = Xc[perm[:split_at]].tolist()
            val_pids = Xc[perm[split_at:]].tolist()

        _save_client_split(split_idx, train_pids, val_pids, meta_c)
else:
    print(f"[INFO] Non-stratified {N_CLIENTS}-way split (deterministic) for clients.")
    rng = np.random.RandomState(SEED)
    perm = rng.permutation(len(X_all))
    # Equal-sized chunks; the last client absorbs the remainder.
    sizes = [len(X_all) // N_CLIENTS] * N_CLIENTS
    sizes[-1] += len(X_all) - sum(sizes)
    start = 0
    for split_idx, sz in enumerate(sizes):
        pids_client = X_all[perm[start:start + sz]]
        start += sz

        meta_c = meta[meta["pid"].isin(pids_client)].reset_index(drop=True)
        Xc = meta_c["pid"].values

        # Same RNG as the global permutation, so draw order matters for reproducibility.
        perm_c = rng.permutation(len(Xc))
        split_at = int((1.0 - VAL_FRAC_PER_CLIENT) * len(Xc))
        train_pids = Xc[perm_c[:split_at]].tolist()
        val_pids = Xc[perm_c[split_at:]].tolist()

        _save_client_split(split_idx, train_pids, val_pids, meta_c)

# Persist the split configuration next to the client folders so a run can be traced.
manifest = dict(
    seed=SEED,
    use_atlas=USE_ATLAS,
    n_clients=N_CLIENTS,
    val_frac_per_client=VAL_FRAC_PER_CLIENT,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    kfold_level=kfold_level,
)
manifest_path = os.path.join(CLIENT_DIR, "manifest.json")
with open(manifest_path, "w") as fh:
    json.dump(manifest, fh, indent=2)

# Dataset wrapper that restricts to a given list of patient IDs
class SubsetByPIDs(Dataset):
    """View over ``full_dataset`` exposing only the requested patient IDs.

    IDs that the full dataset does not contain are skipped silently. The
    remaining IDs keep their order from ``pid_list``.
    """

    def __init__(self, full_dataset: ImageOnlyGliomaDataset, pid_list: List[str]):
        self.ds = full_dataset
        index_of = dict((pid, pos) for pos, pid in enumerate(full_dataset.patient_ids))
        self.indices = []
        for pid in pid_list:
            if pid in index_of:
                self.indices.append(index_of[pid])

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        position = self.indices[i]
        return self.ds[position]

# Helper to create a DataLoader with the custom collate function
def make_loader(ds, shuffle):
    """Wrap ``ds`` in a DataLoader that uses the notebook-wide batch and worker settings.

    Batches are assembled by ``image_only_collate_fn``. The shuffling generator
    is seeded with ``SEED``, so the batch order is reproducible.
    """
    def collate(batch):
        return image_only_collate_fn(batch, use_atlas=USE_ATLAS)

    seeded_gen = torch.Generator()
    seeded_gen.manual_seed(SEED)
    loader_kwargs = dict(
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        collate_fn=collate,
        generator=seeded_gen,
    )
    return DataLoader(ds, **loader_kwargs)

# Optional preview loaders for quick sanity checks: one (train, val) pair per client
preview_loaders = [
    (
        make_loader(SubsetByPIDs(dataset, split["train"]), shuffle=True),
        make_loader(SubsetByPIDs(dataset, split["val"]), shuffle=False),
    )
    for split in clients
]

# Summarize metadata distribution for a subset
def summarize(meta_df: pd.DataFrame, name: str) -> Dict:
    """Return a JSON-friendly summary of a metadata subset.

    Keys, in order: ``name``, ``n`` (row count), ``has_tumor_counts``,
    ``size_bin_counts`` and ``dom_region_top5`` (the five most frequent
    dominant regions). Each count map comes from ``value_counts``.
    """
    def counts(column, top=None):
        vc = meta_df[column].value_counts()
        if top is not None:
            vc = vc.head(top)
        return vc.to_dict()

    return {
        "name": name,
        "n": int(len(meta_df)),
        "has_tumor_counts": counts("has_tumor"),
        "size_bin_counts": counts("size_bin"),
        "dom_region_top5": counts("dom", top=5),
    }

# Global metadata summary
summary_all = summarize(meta, "ALL")
print("\n=== Global summary ===")
print(summary_all)

# Per-client train/val metadata summaries
per_client_summaries = []
for cid, cl in enumerate(clients):
    meta_c = client_metas[cid]
    tr = meta_c[meta_c["pid"].isin(cl["train"])]
    va = meta_c[meta_c["pid"].isin(cl["val"])]
    s_client = {
        "client": cid,
        "train": summarize(tr, f"client_{cid}_train"),
        "val": summarize(va, f"client_{cid}_val"),
    }
    per_client_summaries.append(s_client)
    print(f"\n=== Client {cid} ===")
    print(s_client)

# Save all summaries to disk
with open(os.path.join(CLIENT_DIR, "summary.json"), "w") as f:
    json.dump({"global": summary_all, "per_client": per_client_summaries}, f, indent=2)

# Report the number of client splits actually written instead of a hard-coded count.
print(f"\nSaved {len(clients)} client splits and summaries to: {CLIENT_DIR}")

**Data Distribution**

In [None]:
{
  "global": {
    "name": "ALL",
    "n": 202,
    "has_tumor_counts": {
      "1": 202
    },
    "size_bin_counts": {
      "0": 68,
      "1": 67,
      "2": 67
    },
    "dom_region_top5": {
      "61": 85,
      "50": 81,
      "0": 10,
      "10": 4,
      "22": 3
    }
  },
  "per_client": [
    {
      "client": 0,
      "train": {
        "name": "client_0_train",
        "n": 54,
        "has_tumor_counts": {
          "1": 54
        },
        "size_bin_counts": {
          "1": 18,
          "2": 18,
          "0": 18
        },
        "dom_region_top5": {
          "61": 29,
          "50": 15,
          "0": 4,
          "7": 1,
          "17": 1
        }
      },
      "val": {
        "name": "client_0_val",
        "n": 14,
        "has_tumor_counts": {
          "1": 14
        },
        "size_bin_counts": {
          "0": 5,
          "1": 5,
          "2": 4
        },
        "dom_region_top5": {
          "50": 9,
          "53": 1,
          "10": 1,
          "61": 1,
          "15": 1
        }
      }
    },
    {
      "client": 1,
      "train": {
        "name": "client_1_train",
        "n": 53,
        "has_tumor_counts": {
          "1": 53
        },
        "size_bin_counts": {
          "1": 18,
          "2": 18,
          "0": 17
        },
        "dom_region_top5": {
          "50": 22,
          "61": 19,
          "0": 3,
          "22": 2,
          "10": 1
        }
      },
      "val": {
        "name": "client_1_val",
        "n": 14,
        "has_tumor_counts": {
          "1": 14
        },
        "size_bin_counts": {
          "2": 5,
          "0": 5,
          "1": 4
        },
        "dom_region_top5": {
          "50": 6,
          "61": 6,
          "10": 1,
          "22": 1
        }
      }
    },
    {
      "client": 2,
      "train": {
        "name": "client_2_train",
        "n": 53,
        "has_tumor_counts": {
          "1": 53
        },
        "size_bin_counts": {
          "2": 18,
          "0": 18,
          "1": 17
        },
        "dom_region_top5": {
          "61": 25,
          "50": 22,
          "23": 1,
          "31": 1,
          "62": 1
        }
      },
      "val": {
        "name": "client_2_val",
        "n": 14,
        "has_tumor_counts": {
          "1": 14
        },
        "size_bin_counts": {
          "0": 5,
          "1": 5,
          "2": 4
        },
        "dom_region_top5": {
          "50": 7,
          "61": 5,
          "10": 1,
          "0": 1
        }
      }
    }
  ]
}

<h2>Non-IID split across 3 clients</h2>

In [None]:
%%writefile data_prep_split.py
import os, json, pickle, random
import numpy as np
import pandas as pd
from collections import defaultdict
from typing import List, Dict, Tuple

import torch
from torch.utils.data import Dataset, DataLoader

# --- Paths -------------------------------------------------------------------
DATA_ROOT = "Preprocessed-Data"        # one <pid>/ folder per patient with *_mri/_tumor/_regions.npy
METADATA_DF_PATH = "cleaned_df.pkl"    # pickled DataFrame with a "Patient_ID" column
OUT_BASE = "AITDM"
CLIENT_DIR = os.path.join(OUT_BASE, "client")

# --- Split / loader settings -------------------------------------------------
USE_ATLAS = True
N_CLIENTS = 3
VAL_FRAC_PER_CLIENT = 0.2
BATCH_SIZE = 8
NUM_WORKERS = 2
SEED = 42

# Concentration of the symmetric Dirichlet used to allocate each size bin to clients:
# smaller values give more skewed (more non-IID) per-client mixes.
DIRICHLET_ALPHA = 0.3

os.makedirs(CLIENT_DIR, exist_ok=True)

# Ensure reproducibility
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA devices)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

set_seed(seed=SEED)  # seed Python, NumPy and torch RNGs once at script start

# Dataset class that loads MRI, tumor masks, and optional atlas data
class ImageOnlyGliomaDataset(Dataset):
    """Per-patient dataset of MRI arrays, binary tumor masks and optional atlas regions.

    Args:
        metadata_df_path: pickled DataFrame with a ``Patient_ID`` column.
            NOTE: read with ``pickle.load``; only use trusted files.
        data_root: directory holding one ``<pid>/`` folder per patient with
            ``<pid>_mri.npy``, ``<pid>_tumor.npy`` and, if ``use_atlas``,
            ``<pid>_regions.npy``.
        use_atlas: also require and load the regions array.
        exclude_ids: patient IDs to drop; ``None`` means ``["PatientID_0191"]``.
        transform: optional callable applied to each sample dict in
            ``__getitem__``; ``None`` means no transform.
    """

    def __init__(self, metadata_df_path, data_root, use_atlas=True, exclude_ids=None, transform=None):
        with open(metadata_df_path, "rb") as f:
            df = pickle.load(f)
        if exclude_ids is None:
            exclude_ids = ["PatientID_0191"]

        self.df = df[~df["Patient_ID"].isin(exclude_ids)].reset_index(drop=True)
        self.data_root = data_root
        self.use_atlas = use_atlas
        self.transform = transform

        # Keep (in sorted order) only patients whose required .npy files all exist.
        required = ("mri", "tumor", "regions") if self.use_atlas else ("mri", "tumor")
        self.patient_ids = [
            pid
            for pid in sorted(self.df["Patient_ID"].tolist())
            if all(os.path.isfile(os.path.join(self.data_root, pid, f"{pid}_{kind}.npy")) for kind in required)
        ]

    def __len__(self):
        return len(self.patient_ids)

    # Min–max normalization
    @staticmethod
    def _minmax(x):
        """Scale ``x`` to [0, 1] as float32; a constant array maps to all zeros."""
        x = x.astype(np.float32)
        mn, mx = np.min(x), np.max(x)
        return (x - mn) / (mx - mn) if mx > mn else np.zeros_like(x, dtype=np.float32)

    # Load a single patient's MRI + masks
    def __getitem__(self, idx):
        """Return ``{"patient_id", "mri", "tumor"[, "regions"]}`` for patient ``idx``.

        MRI (and regions) are min-max scaled. The tumor mask is binarized at 0.5.
        ``self.transform`` is applied last, when set.
        """
        pid = self.patient_ids[idx]
        base = os.path.join(self.data_root, pid)

        mri = np.load(os.path.join(base, f"{pid}_mri.npy")).astype(np.float32)
        tumor = np.load(os.path.join(base, f"{pid}_tumor.npy")).astype(np.float32)

        mri = self._minmax(mri)
        tumor = (tumor > 0.5).astype(np.float32)

        sample = {"patient_id": pid, "mri": mri, "tumor": tumor}

        if self.use_atlas:
            regions = np.load(os.path.join(base, f"{pid}_regions.npy")).astype(np.float32)
            # NOTE(review): region values are label IDs (patient_meta treats them as such),
            # yet they are min-max scaled like intensities. Confirm this encoding is intended.
            regions = self._minmax(regions)
            sample["regions"] = regions

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

# Collate function used in DataLoaders
def image_only_collate_fn(batch, use_atlas=True):
    """Stack a list of sample dicts into one batch dict.

    Returns ``x`` (float tensor: the MRI channel, plus the regions channel when
    ``use_atlas``), ``y`` (float tumor mask with a channel axis inserted at
    dim 1) and ``pid`` (patient IDs in batch order).
    """
    def stacked(key):
        # Stack per-sample arrays, then insert the channel axis at dim 1.
        return torch.stack([torch.tensor(item[key]) for item in batch]).unsqueeze(1)

    channels = [stacked("mri").float()]
    if use_atlas:
        channels.append(stacked("regions").float())
    x = torch.cat(channels, dim=1) if len(channels) > 1 else channels[0]

    return {
        "x": x,
        "y": stacked("tumor").float(),
        "pid": [item["patient_id"] for item in batch],
    }

# Compute tumor presence, area and dominant atlas region
def patient_meta(pid: str, data_root=None, use_atlas=None) -> Tuple[int, float, int]:
    """Return ``(has_tumor, tumor_area, dominant_region)`` for one patient.

    Args:
        pid: patient folder name and file prefix under ``data_root``.
        data_root: root directory; ``None`` uses the module-level ``DATA_ROOT``.
        use_atlas: whether to look up the dominant region; ``None`` uses ``USE_ATLAS``.

    Returns:
        has_tumor: 1 if any mask voxel is > 0.5, else 0.
        tumor_area: number of mask voxels > 0.5, as float.
        dominant_region: most frequent region label under the tumor mask
            (ties resolve to the smallest label), or -1 when the atlas is off,
            the regions file is missing, or the mask is empty.
    """
    if data_root is None:
        data_root = DATA_ROOT
    if use_atlas is None:
        use_atlas = USE_ATLAS

    base = os.path.join(data_root, pid)
    mask = np.load(os.path.join(base, f"{pid}_tumor.npy")).astype(np.float32) > 0.5

    area = float(mask.sum())
    has_tumor = int(area >= 1)

    dom_region = -1
    reg_path = os.path.join(base, f"{pid}_regions.npy")
    if use_atlas and os.path.isfile(reg_path) and mask.any():
        regs = np.load(reg_path).astype(np.int32)
        # mask.any() guarantees at least one value, so argmax is always defined.
        vals, counts = np.unique(regs[mask], return_counts=True)
        dom_region = int(vals[np.argmax(counts)])

    return has_tumor, area, dom_region

# Build metadata for all patients
def build_meta_for(dataset: "ImageOnlyGliomaDataset") -> pd.DataFrame:
    """Build the per-patient metadata table used for client splitting.

    Columns: ``pid``, ``has_tumor``, ``area``, ``dom`` (from ``patient_meta``)
    and ``size_bin``. ``size_bin`` is the tumor-area tertile (0..2, fewer bins
    if edges coincide) among tumor-bearing patients, and 0 otherwise. Patients
    whose files fail to load are skipped with a printed warning.
    """
    rows = []
    for pid in dataset.patient_ids:
        try:
            ht, area, dom = patient_meta(pid)
        except Exception as exc:  # best-effort: skip unreadable patients, but report them
            print(f"[WARN] Skipping {pid} in metadata: {type(exc).__name__}: {exc}")
            continue
        rows.append({"pid": pid, "has_tumor": ht, "area": area, "dom": dom})

    # Explicit columns keep the schema even when no patient could be read.
    meta = pd.DataFrame(rows, columns=["pid", "has_tumor", "area", "dom"])

    # Create tumor size bins from the area tertiles of tumor-bearing patients.
    meta["size_bin"] = 0
    m = meta["has_tumor"] == 1
    if m.any():
        areas = meta.loc[m, "area"].values
        qs = np.quantile(areas, np.linspace(0, 1, 4))
        qs = np.unique(qs)  # collapse duplicate edges (many identical areas)
        bins = np.digitize(areas, qs[1:-1], right=True) if len(qs) > 2 else np.zeros_like(areas, int)
        meta.loc[m, "size_bin"] = bins.astype(int)

    return meta

# Load dataset + metadata
dataset = ImageOnlyGliomaDataset(METADATA_DF_PATH, DATA_ROOT, use_atlas=USE_ATLAS, exclude_ids=["PatientID_0191"])
meta = build_meta_for(dataset)
# One RNG drives the whole allocation. The split is determined by the exact sequence of
# draws below (permutation, dirichlet, then randint per class), so keep statement order.
rng = np.random.RandomState(SEED)

# Group patients by class (tumor size bin)
# Classes are inserted (and later visited) in order of first appearance in `meta`.
class_to_pids = defaultdict(list)
for _, row in meta.iterrows():
    label = row["size_bin"]
    class_to_pids[label].append(row["pid"])

# Allocate patients to clients using Dirichlet sampling:
# for each size bin, draw client shares ~ Dirichlet(alpha, ..., alpha) and hand the
# shuffled patients of that bin out in contiguous chunks.
client_pid_sets = [[] for _ in range(N_CLIENTS)]

for label, pids in class_to_pids.items():
    # permutation returns a shuffled ndarray copy, so elements become NumPy string scalars
    pids = rng.permutation(pids)
    proportions = rng.dirichlet([DIRICHLET_ALPHA] * N_CLIENTS)
    # Flooring can only under-count, so the loop below just tops up the total.
    counts = (proportions * len(pids)).astype(int)

    # Adjust counts to cover all samples
    # Leftover patients go one at a time to uniformly random clients.
    while counts.sum() < len(pids):
        counts[rng.randint(0, N_CLIENTS)] += 1

    start = 0
    for cid, count in enumerate(counts):
        subset = pids[start:start + count]
        client_pid_sets[cid].extend(subset)
        start += count

# Create clients + train/val splits
clients = []
client_metas = []

for cid in range(N_CLIENTS):
    Xc = np.array(client_pid_sets[cid])

    # Per-client RNG (SEED + cid) so each client's shuffle is reproducible on its own.
    rng_c = np.random.RandomState(SEED + cid)
    perm_c = rng_c.permutation(len(Xc))

    # A skewed Dirichlet draw can leave a client with 0 or 1 patients. Then the train
    # and/or val split below is empty, so report it explicitly instead of staying silent.
    if len(Xc) < 2:
        print(
            f"[WARN] client_{cid} has only {len(Xc)} patient(s); its train and/or val split "
            f"will be empty. Consider a larger DIRICHLET_ALPHA or a different SEED."
        )

    # With at least 2 patients, keep at least one in train and one in val.
    split_at = max(1, int((1.0 - VAL_FRAC_PER_CLIENT) * len(Xc)))
    split_at = min(split_at, len(Xc) - 1)

    train_pids = Xc[perm_c[:split_at]].tolist()
    val_pids   = Xc[perm_c[split_at:]].tolist()

    meta_c = meta[meta["pid"].isin(Xc)].reset_index(drop=True)

    cdir = os.path.join(CLIENT_DIR, f"client_{cid}")
    os.makedirs(cdir, exist_ok=True)

    with open(os.path.join(cdir, "train_pids.json"), "w") as f:
        json.dump(train_pids, f, indent=2)

    with open(os.path.join(cdir, "val_pids.json"), "w") as f:
        json.dump(val_pids, f, indent=2)

    clients.append({"train": train_pids, "val": val_pids})
    client_metas.append(meta_c)

print(f"[INFO] Created {N_CLIENTS} non-IID clients using Dirichlet split (alpha={DIRICHLET_ALPHA}).")

# Record how the non-IID split was produced, next to the client folders.
manifest = dict(
    seed=SEED,
    use_atlas=USE_ATLAS,
    n_clients=N_CLIENTS,
    val_frac_per_client=VAL_FRAC_PER_CLIENT,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    split_type="non-IID_Dirichlet",
    dirichlet_alpha=DIRICHLET_ALPHA,
)

manifest_path = os.path.join(CLIENT_DIR, "manifest.json")
with open(manifest_path, "w") as fh:
    json.dump(manifest, fh, indent=2)

# Dataset wrapper for selecting specific patient IDs
class SubsetByPIDs(Dataset):
    """Expose only the patients listed in ``pid_list`` from ``full_dataset``.

    Unknown IDs are dropped silently. Known IDs keep their ``pid_list`` order.
    """

    def __init__(self, full_dataset: ImageOnlyGliomaDataset, pid_list: List[str]):
        self.ds = full_dataset
        lookup = {pid: pos for pos, pid in enumerate(full_dataset.patient_ids)}
        known = [pid for pid in pid_list if pid in lookup]
        self.indices = list(map(lookup.__getitem__, known))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        underlying_idx = self.indices[i]
        return self.ds[underlying_idx]

# Helper to build DataLoaders
def make_loader(ds, shuffle):
    """Return a DataLoader over ``ds`` with the script-wide settings and a SEED-seeded generator."""
    gen = torch.Generator()
    gen.manual_seed(SEED)

    def collate(samples):
        return image_only_collate_fn(samples, use_atlas=USE_ATLAS)

    return DataLoader(
        ds,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        collate_fn=collate,
        generator=gen,
    )

# Build (train, val) preview loaders per client for quick inspection
preview_loaders = []
for split in clients:
    train_view = SubsetByPIDs(dataset, split["train"])
    val_view = SubsetByPIDs(dataset, split["val"])
    preview_loaders.append(
        (make_loader(train_view, shuffle=True), make_loader(val_view, shuffle=False))
    )

# Produce metadata summaries
def summarize(meta_df: pd.DataFrame, name: str) -> Dict:
    """Summarize a metadata subset as a JSON-friendly dict.

    Keys, in order: ``name``, ``n``, ``has_tumor_counts``,
    ``size_bin_counts`` and ``dom_region_top5`` (the five most frequent
    dominant regions).
    """
    summary = {"name": name, "n": int(len(meta_df))}
    specs = (
        ("has_tumor_counts", "has_tumor", None),
        ("size_bin_counts", "size_bin", None),
        ("dom_region_top5", "dom", 5),
    )
    for key, column, top in specs:
        vc = meta_df[column].value_counts()
        summary[key] = (vc if top is None else vc.head(top)).to_dict()
    return summary

# Global and per-client (train/val) metadata summaries, printed and saved to summary.json.
summary_all = summarize(meta, "ALL")
print("\n=== Global summary ===")
print(summary_all)

per_client_summaries = []
for cid, (cl, meta_c) in enumerate(zip(clients, client_metas)):
    in_train = meta_c["pid"].isin(cl["train"])
    in_val = meta_c["pid"].isin(cl["val"])
    s_client = dict(
        client=cid,
        train=summarize(meta_c[in_train], f"client_{cid}_train"),
        val=summarize(meta_c[in_val], f"client_{cid}_val"),
    )
    per_client_summaries.append(s_client)
    print(f"\n=== Client {cid} ===")
    print(s_client)

summary_path = os.path.join(CLIENT_DIR, "summary.json")
with open(summary_path, "w") as f:
    json.dump({"global": summary_all, "per_client": per_client_summaries}, f, indent=2)

print(f"\nSaved {N_CLIENTS} non-IID client splits and summaries to: {CLIENT_DIR}")

**Data Distribution**

In [None]:
{
  "global": {
    "name": "ALL",
    "n": 202,
    "has_tumor_counts": {
      "1": 202
    },
    "size_bin_counts": {
      "0": 68,
      "1": 67,
      "2": 67
    },
    "dom_region_top5": {
      "61": 85,
      "50": 81,
      "0": 10,
      "10": 4,
      "22": 3
    }
  },
  "per_client": [
    {
      "client": 0,
      "train": {
        "name": "client_0_train",
        "n": 35,
        "has_tumor_counts": {
          "1": 35
        },
        "size_bin_counts": {
          "1": 32,
          "0": 2,
          "2": 1
        },
        "dom_region_top5": {
          "50": 18,
          "61": 13,
          "7": 1,
          "62": 1,
          "10": 1
        }
      },
      "val": {
        "name": "client_0_val",
        "n": 9,
        "has_tumor_counts": {
          "1": 9
        },
        "size_bin_counts": {
          "1": 8,
          "0": 1
        },
        "dom_region_top5": {
          "50": 4,
          "61": 3,
          "10": 1,
          "58": 1
        }
      }
    },
    {
      "client": 1,
      "train": {
        "name": "client_1_train",
        "n": 96,
        "has_tumor_counts": {
          "1": 96
        },
        "size_bin_counts": {
          "2": 49,
          "0": 27,
          "1": 20
        },
        "dom_region_top5": {
          "61": 47,
          "50": 34,
          "0": 5,
          "1": 1,
          "23": 1
        }
      },
      "val": {
        "name": "client_1_val",
        "n": 25,
        "has_tumor_counts": {
          "1": 25
        },
        "size_bin_counts": {
          "2": 17,
          "0": 4,
          "1": 4
        },
        "dom_region_top5": {
          "61": 13,
          "50": 11,
          "0": 1
        }
      }
    },
    {
      "client": 2,
      "train": {
        "name": "client_2_train",
        "n": 29,
        "has_tumor_counts": {
          "1": 29
        },
        "size_bin_counts": {
          "0": 26,
          "1": 3
        },
        "dom_region_top5": {
          "50": 12,
          "61": 8,
          "22": 2,
          "0": 2,
          "31": 1
        }
      },
      "val": {
        "name": "client_2_val",
        "n": 8,
        "has_tumor_counts": {
          "1": 8
        },
        "size_bin_counts": {
          "0": 8
        },
        "dom_region_top5": {
          "50": 2,
          "61": 1,
          "10": 1,
          "45": 1,
          "22": 1
        }
      }
    }
  ]
}

# **M1 - Baseline Model**

<h2>Train a separate model for each client</h2>

In [None]:
!unzip -qo /content/client.zip -d /content/client

In [None]:
!unzip -qo /content/Preprocessed-Data.zip -d /content/Preprocessed-Data

In [None]:
import os, json, pickle, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from typing import List
from sklearn.metrics import jaccard_score
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp

# Paths and I/O config
DATA_ROOT = "/content/Preprocessed-Data"
METADATA_DF_PATH = "cleaned_df.pkl"
OUT_BASE = "AITDM"
# client.zip is extracted to /content/client, so use that absolute path directly.
# Joining it with OUT_BASE would be a no-op: os.path.join discards every component
# before an absolute one.
CLIENT_DIR = "/content/client"

# Experiment config
USE_ATLAS = True   # stack the atlas regions map as a second input channel
CLIENT_ID = 0      # which client_<id> split folder to train on
BATCH_SIZE = 8
LR = 1e-3
EPOCHS = 50
NUM_WORKERS = 2
SEED = 42

# Model / encoder config
ENCODER_NAME = "resnet34"
ENCODER_WEIGHTS = "imagenet"

# Output directories
OUT_MODELS_DIR = os.path.join(OUT_BASE, "Models", "UNet_ImageOnly")
OUT_GRAPHS_DIR = os.path.join(OUT_BASE, "Graphs")
os.makedirs(OUT_MODELS_DIR, exist_ok=True)
os.makedirs(OUT_GRAPHS_DIR, exist_ok=True)

def set_seed(seed: int = 42):
    """Seed every RNG (random, NumPy, torch CPU/CUDA) and request deterministic kernels.

    cuDNN autotuning is turned off, and torch deterministic algorithms are
    enabled in warn-only mode. Ops lacking a deterministic implementation
    therefore warn instead of raising.
    """
    random.seed(seed)
    np.random.seed(seed)
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        torch_seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    torch.use_deterministic_algorithms(True, warn_only=True)

def worker_init_fn(worker_id):
    """Seed ``random`` and NumPy inside a DataLoader worker (PyTorch reproducibility recipe).

    PyTorch already sets each worker's torch seed to ``base_seed + worker_id``.
    ``base_seed`` is drawn from the loader's seeded generator, so deriving the
    NumPy/``random`` seed from ``torch.initial_seed()`` is reproducible across
    runs. It also differs per worker and per epoch. A constant such as
    ``SEED + worker_id`` would replay the same stream every epoch, because
    non-persistent workers are re-created each epoch.

    Args:
        worker_id: index supplied by DataLoader. It is already folded into
            ``torch.initial_seed()``.
    """
    worker_seed = torch.initial_seed() % 2**32  # np.random.seed needs a value < 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

set_seed(SEED)  # seed all RNGs once, before data loading and model creation
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class ImageOnlyGliomaDataset(Dataset):
    """Dataset that loads MRI, tumor mask, and optional atlas regions per patient.

    Args:
        metadata_df_path: pickled DataFrame with a ``Patient_ID`` column.
            NOTE: read with ``pickle.load``; only use trusted files.
        data_root: directory with one ``<pid>/`` folder per patient containing
            ``<pid>_mri.npy``, ``<pid>_tumor.npy`` and, if ``use_atlas``,
            ``<pid>_regions.npy``.
        use_atlas: also require and load the regions array.
        exclude_ids: patient IDs to drop; ``None`` means ``["PatientID_0191"]``.
        transform: optional callable applied to every sample dict returned by
            ``__getitem__``; ``None`` means no transform.
    """

    def __init__(self, metadata_df_path, data_root, use_atlas=True, exclude_ids=None, transform=None):
        # Load metadata dataframe
        with open(metadata_df_path, "rb") as f:
            df = pickle.load(f)

        # Optionally exclude some patient IDs
        if exclude_ids is None:
            exclude_ids = ["PatientID_0191"]
        self.df = df[~df["Patient_ID"].isin(exclude_ids)].reset_index(drop=True)

        self.data_root = data_root
        self.use_atlas = use_atlas
        self.transform = transform

        # Collect (sorted) patient IDs that have all required files
        required = ("mri", "tumor", "regions") if self.use_atlas else ("mri", "tumor")
        self.patient_ids = [
            pid
            for pid in sorted(self.df["Patient_ID"].tolist())
            if all(os.path.isfile(os.path.join(self.data_root, pid, f"{pid}_{kind}.npy")) for kind in required)
        ]

    def __len__(self):
        return len(self.patient_ids)

    @staticmethod
    def _minmax(x):
        """Simple min-max normalization to [0, 1] (float32); constant input maps to zeros."""
        x = x.astype(np.float32)
        mn, mx = np.min(x), np.max(x)
        return (x - mn) / (mx - mn) if mx > mn else np.zeros_like(x, dtype=np.float32)

    def __getitem__(self, idx):
        """Load and return one sample (dict) for a patient, after the optional transform."""
        pid = self.patient_ids[idx]
        base = os.path.join(self.data_root, pid)

        # Load MRI and tumor mask
        mri = np.load(os.path.join(base, f"{pid}_mri.npy")).astype(np.float32)
        tumor = np.load(os.path.join(base, f"{pid}_tumor.npy")).astype(np.float32)

        # Normalize MRI and binarize tumor mask
        mri = self._minmax(mri)
        tumor = (tumor > 0.5).astype(np.float32)

        sample = {"patient_id": pid, "mri": mri, "tumor": tumor}

        # Optionally load regions/atlas
        if self.use_atlas:
            regions = np.load(os.path.join(base, f"{pid}_regions.npy")).astype(np.float32)
            # NOTE(review): regions appear to be integer label IDs, yet they are min-max
            # scaled like intensities. Confirm this encoding is intended.
            regions = self._minmax(regions)
            sample["regions"] = regions

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

def image_only_collate_fn(batch, use_atlas=True):
    """Custom collate: stack MRI (+ optional regions) and tumor into tensors.

    Returns ``{"x": float (B, C, ...), "y": float (B, 1, ...), "pid": [ids]}``.
    C is 2 when ``use_atlas`` (MRI, regions) and 1 otherwise.
    """
    mri_list, tumor_list, region_list, pid_list = [], [], [], []
    for item in batch:
        mri_list.append(torch.tensor(item["mri"]))
        tumor_list.append(torch.tensor(item["tumor"]))
        if use_atlas:
            region_list.append(torch.tensor(item["regions"]))
        pid_list.append(item["patient_id"])

    x = torch.stack(mri_list).unsqueeze(1).float()
    if use_atlas:
        # MRI and atlas regions become two input channels
        x = torch.cat([x, torch.stack(region_list).unsqueeze(1).float()], dim=1)
    y = torch.stack(tumor_list).unsqueeze(1).float()

    return {"x": x, "y": y, "pid": pid_list}

class SubsetByPIDs(Dataset):
    """Wrap a dataset but keep only a subset of patient IDs.

    IDs not present in the wrapped dataset are ignored. The others keep the
    order given in ``pid_list``.
    """

    def __init__(self, full_dataset: ImageOnlyGliomaDataset, pid_list: List[str]):
        self.ds = full_dataset
        position_by_pid = {}
        for position, pid in enumerate(full_dataset.patient_ids):
            position_by_pid[pid] = position
        self.indices = [position_by_pid[pid] for pid in pid_list if pid in position_by_pid]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.ds[self.indices[i]]

def _read_pid_list(path):
    """Load a JSON list of patient IDs."""
    with open(path, "r") as fh:
        return json.load(fh)


def _client_loader(ds, shuffle):
    """DataLoader with this client's shared batch, worker and seeding settings."""
    return DataLoader(
        ds,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        collate_fn=lambda b: image_only_collate_fn(b, use_atlas=USE_ATLAS),
        generator=torch.Generator().manual_seed(SEED),
        worker_init_fn=worker_init_fn,
    )


# Load client-specific train/val patient IDs
cdir = os.path.join(CLIENT_DIR, f"client_{CLIENT_ID}")
train_pids = _read_pid_list(os.path.join(cdir, "train_pids.json"))
val_pids = _read_pid_list(os.path.join(cdir, "val_pids.json"))

# Build the full dataset once, then restrict it to this client's patients
full_ds = ImageOnlyGliomaDataset(
    METADATA_DF_PATH,
    DATA_ROOT,
    use_atlas=USE_ATLAS,
    exclude_ids=["PatientID_0191"],
)
train_dataset = SubsetByPIDs(full_ds, train_pids)
val_dataset = SubsetByPIDs(full_ds, val_pids)

train_loader = _client_loader(train_dataset, shuffle=True)
val_loader = _client_loader(val_dataset, shuffle=False)

print(f"Loaded client_{CLIENT_ID}: train patients={len(train_dataset)}, val patients={len(val_dataset)}")

# UNet (segmentation_models_pytorch): one MRI channel, plus one more when the atlas is stacked.
in_channels = 1 + (1 if USE_ATLAS else 0)
unet_kwargs = dict(
    encoder_name=ENCODER_NAME,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=in_channels,
    classes=1,
)
model = smp.Unet(**unet_kwargs).to(device)

# Binary Dice loss on raw logits, AdamW, and a LinearLR schedule that decays
# the LR factor from 1.0 to 0.0 over EPOCHS scheduler steps.
criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
optimizer = optim.AdamW(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.0, total_iters=EPOCHS)
# Gradient scaling is only enabled when running on CUDA.
scaler = torch.amp.GradScaler("cuda", enabled=device.type == "cuda")

def calc_metrics(y_true, y_pred):
    """Compute Dice, IoU and pixel accuracy for binary (0/1) masks.

    Args:
        y_true: ground-truth mask, any shape (flattened, cast to uint8).
        y_pred: predicted mask with the same number of elements.

    Returns:
        ``(dice, iou, acc)`` as Python floats. When both masks are empty (no
        positive pixel at all), Dice and IoU are 1.0: predicting "nothing" is
        exactly right there, and scoring it 0 would penalize a correct result.
    """
    y_true = np.asarray(y_true).astype(np.uint8).reshape(-1)
    y_pred = np.asarray(y_pred).astype(np.uint8).reshape(-1)

    inter = float((y_true * y_pred).sum())
    total = float(y_true.sum()) + float(y_pred.sum())

    if total == 0:
        dice, iou = 1.0, 1.0
    else:
        dice = 2.0 * inter / total
        # For 0/1 masks the union is |T| + |P| - |T ∩ P|, which is > 0 here.
        iou = inter / (total - inter)

    acc = float((y_true == y_pred).mean())
    return float(dice), float(iou), acc

def plot_prediction(sample_imgs, save_path):
    """Plot MRI (+ optional regions), ground truth and prediction for the first sample.

    ``sample_imgs`` maps ``"mri"``, ``"y_true"``, ``"y_pred"`` (and optionally
    ``"regions"``) to arrays indexed ``[sample, channel, ...]``. Element
    ``[0, 0]`` of each is drawn in grayscale. The figure is saved to
    ``save_path``, shown, then closed.
    """
    panels = [("mri", "MRI")]
    if "regions" in sample_imgs:
        panels.append(("regions", "Regions"))
    panels += [("y_true", "Ground Truth"), ("y_pred", "Predicted")]

    # 5 inches of width per panel: (20, 5) with regions, (15, 5) without.
    fig, axs = plt.subplots(1, len(panels), figsize=(5 * len(panels), 5))
    for ax, (key, title) in zip(axs, panels):
        ax.imshow(sample_imgs[key][0, 0], cmap="gray")
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close()

def plot_metrics(history, save_path):
    """Plot train vs. validation curves (loss, Dice, IoU, accuracy) in a 2x2 grid.

    ``history`` must hold equal-length per-epoch lists under ``train_<m>`` and
    ``val_<m>`` for m in loss, dice, iou, acc. The figure is saved to
    ``save_path``, shown, then closed.
    """
    epochs = range(1, len(history["train_loss"]) + 1)
    # (history key suffix, legend label suffix, subplot title)
    panels = [
        ("loss", "Loss", "Loss"),
        ("dice", "Dice", "Dice"),
        ("iou", "IoU", "IoU"),
        ("acc", "Acc", "Accuracy"),
    ]

    plt.figure(figsize=(16, 12))
    for position, (key, short, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(epochs, history[f"train_{key}"], label=f"Train {short}")
        plt.plot(epochs, history[f"val_{key}"], label=f"Val {short}")
        plt.legend()
        plt.title(title)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close()

def train_one_epoch(model, dataloader, optimizer, criterion, scaler):
    """Run one training epoch (autocast + grad scaling on CUDA) over ``dataloader``.

    Returns:
        ``(loss, dice, iou, acc)``, each the mean over batches of the per-batch
        value (not weighted by batch size).

    Raises:
        ValueError: if ``dataloader`` has no batches (e.g. a client with an
            empty train split), rather than failing with ZeroDivisionError.
    """
    n = len(dataloader)
    if n == 0:
        raise ValueError("train_one_epoch: dataloader is empty (no training samples for this client).")

    model.train()
    total_loss = total_dice = total_iou = total_acc = 0.0

    for batch in dataloader:
        x = batch["x"].to(device)
        y = batch["y"].to(device)

        optimizer.zero_grad(set_to_none=True)

        # Mixed precision forward pass
        with torch.amp.autocast("cuda", enabled=(device.type == "cuda")):
            preds = model(x)
            loss = criterion(preds, y)

        # Backward with gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += float(loss.item())

        # Threshold probabilities at 0.5 to get binary masks, then score them
        preds_prob = (torch.sigmoid(preds).detach().cpu().numpy() > 0.5).astype(np.uint8)
        y_np = (y.detach().cpu().numpy() > 0.5).astype(np.uint8)
        d, i, a = calc_metrics(y_np, preds_prob)

        total_dice += d
        total_iou += i
        total_acc += a

    return total_loss / n, total_dice / n, total_iou / n, total_acc / n

@torch.no_grad()
def eval_one_epoch(model, dataloader, criterion, keep_last_batch=True):
    """Evaluate ``model`` over ``dataloader`` without gradients.

    Args:
        model: segmentation network producing logits.
        dataloader: yields dicts with "x" (inputs) and "y" (target masks).
        criterion: loss callable ``criterion(logits, targets)``.
        keep_last_batch: if True, also return the last batch (inputs, targets,
            binarized predictions) as NumPy arrays for visualization.

    Returns:
        ``(loss, dice, iou, acc, last_batch_imgs)``. Metrics are averaged over
        batches (zeros for an empty loader). ``last_batch_imgs`` is None when
        ``keep_last_batch`` is False or the loader is empty. Otherwise it is a
        dict with "mri", "y_true", "y_pred" and, for 2-channel inputs, "regions".
    """
    model.eval()
    total_loss = total_dice = total_iou = total_acc = 0.0
    n_batches = 0
    last = None  # (x, y, preds_bin) of the most recent batch

    for batch in dataloader:
        x = batch["x"].to(device)
        y = batch["y"].to(device)

        preds = model(x)
        loss = criterion(preds, y)
        total_loss += float(loss.item())

        preds_bin = (torch.sigmoid(preds).cpu().numpy() > 0.5).astype(np.uint8)
        y_np = (y.cpu().numpy() > 0.5).astype(np.uint8)
        d, i, a = calc_metrics(y_np, preds_bin)

        total_dice += d
        total_iou += i
        total_acc += a
        n_batches += 1

        if keep_last_batch:
            # Keep references only; the host copy is done once after the loop
            last = (x, y, preds_bin)

    last_batch_imgs = None
    if last is not None:
        x_last, y_last, preds_last = last
        last_batch_imgs = {
            "mri": x_last[:, 0:1].cpu().numpy(),
            "y_true": y_last.cpu().numpy(),
            "y_pred": preds_last,
        }
        if x_last.shape[1] == 2:
            last_batch_imgs["regions"] = x_last[:, 1:2].cpu().numpy()

    n = max(n_batches, 1)
    return total_loss / n, total_dice / n, total_iou / n, total_acc / n, last_batch_imgs

def infer_and_visualize_best(
    model,
    val_dataset,
    use_atlas: bool,
    out_dir: str,
    client_id: int,
    best_ckpt_path: str,
    k_samples: int = 3,
    threshold: float = 0.5,
):
    """Load the best checkpoint, predict on a few validation samples, and save figures.

    Seeds ``random``, NumPy and torch with ``SEED`` (a global side effect) so
    the same validation indices are picked on every run. It saves one figure
    per sample via ``plot_prediction`` and a combined grid with the columns
    MRI | [Regions] | Ground Truth | Prediction.

    Args:
        model: network that ``best_ckpt_path`` is loaded into (in place).
        val_dataset: indexable dataset returning dicts with "mri" and "tumor",
            plus optional "regions" and "patient_id".
        use_atlas: if True, feed "regions" as a second input channel when present.
        out_dir: directory for the PNG outputs (created if missing).
        client_id: used only in the output file names.
        best_ckpt_path: saved ``state_dict``; returns early if the file is missing.
        k_samples: number of validation samples to visualize.
        threshold: probability threshold used to binarize predictions.
    """
    import random

    # Make sampling deterministic
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    os.makedirs(out_dir, exist_ok=True)

    if not os.path.isfile(best_ckpt_path):
        print(f"[Best Model] Missing checkpoint: {best_ckpt_path}\n")
        return
    model.load_state_dict(torch.load(best_ckpt_path, map_location=device))
    model.eval()
    print(f"[Best Model] Loaded: {best_ckpt_path}\n")

    k = min(k_samples, len(val_dataset))
    if k == 0:
        print("[Best Model] Empty val dataset.\n")
        return

    # Deterministic random subset of validation indices
    idxs = random.sample(range(len(val_dataset)), k)

    def _has_regions(sample):
        """True if this sample contributes an atlas channel."""
        return use_atlas and ("regions" in sample)

    def _predict_one(sample):
        """Run the model on one sample; return the batched input array and binarized prediction."""
        mri = torch.tensor(sample["mri"]).unsqueeze(0).unsqueeze(0).float().to(device)
        if _has_regions(sample):
            regs = torch.tensor(sample["regions"]).unsqueeze(0).unsqueeze(0).float().to(device)
            x = torch.cat([mri, regs], dim=1)
        else:
            x = mri

        with torch.no_grad():
            prob = torch.sigmoid(model(x)).cpu().numpy()
        return x.cpu().numpy(), (prob > threshold).astype(np.uint8)

    # Run inference once per selected sample; reused by both figure types below
    results = []
    for idx in idxs:
        sample = val_dataset[idx]
        x_np, pred_bin = _predict_one(sample)
        results.append((idx, sample, x_np, pred_bin))

    # Save individual sample figures
    saved_paths = []
    for i, (idx, sample, x_np, pred_bin) in enumerate(results, 1):
        pid = sample.get("patient_id", f"val_{idx}")
        imgs = {
            "mri": x_np[:, 0:1],
            "y_true": np.expand_dims(
                np.expand_dims(sample["tumor"], 0), 0
            ).astype(np.float32),
            "y_pred": pred_bin.astype(np.float32),
        }
        if _has_regions(sample):
            imgs["regions"] = x_np[:, 1:2]

        out_path = os.path.join(
            out_dir, f"best_val_sample_{i}_client{client_id}_{pid}.png"
        )
        plot_prediction(imgs, out_path)
        saved_paths.append(out_path)

    print("[Best Model] Saved individual figures:")
    for p in saved_paths:
        print(" -", p)
    print()

    # Save grid figure: one row per sample
    cols = 4 if use_atlas else 3
    # squeeze=False keeps `axs` 2-D even when k == 1
    fig, axs = plt.subplots(k, cols, figsize=(5 * cols, 4 * k), squeeze=False)

    for row, (_, sample, x_np, pred_bin) in enumerate(results):
        panels = [(x_np[0, 0], "MRI")]
        if _has_regions(sample):
            panels.append((x_np[0, 1], "Regions"))
        panels.append((sample["tumor"], "Ground Truth"))
        panels.append((pred_bin[0, 0], f"Predicted (τ={threshold})"))

        for col, (img, title) in enumerate(panels):
            axs[row, col].imshow(img, cmap="gray")
            axs[row, col].set_title(title)
            axs[row, col].axis("off")
        # Blank any leftover column (atlas enabled but this sample has no regions)
        for col in range(len(panels), cols):
            axs[row, col].axis("off")

    fig.tight_layout()
    grid_path = os.path.join(out_dir, f"best_model_val_grid_client{client_id}.png")
    fig.savefig(grid_path, bbox_inches="tight")
    plt.show()
    plt.close(fig)
    print(f"[Best Model] Saved grid -> {grid_path}\n")

# History containers for training curves (one list per metric, one entry per epoch)
history = {
    "train_loss": [],
    "train_dice": [],
    "train_iou": [],
    "train_acc": [],
    "val_loss": [],
    "val_dice": [],
    "val_iou": [],
    "val_acc": [],
}

# Tracking best validation metrics.
# NOTE: best_val_iou is not updated below; checkpoint selection uses Dice and loss only.
best_val_iou = 0.0
best_val_dice = -float("inf")
best_val_loss = float("inf")
best_path = os.path.join(OUT_MODELS_DIR, f"best_unet_client{CLIENT_ID}.pth")

log_rows = []

# Main training loop
for epoch in range(1, EPOCHS + 1):
    trL, trD, trI, trA = train_one_epoch(
        model, train_loader, optimizer, criterion, scaler
    )
    # The last-batch images are discarded here, so don't collect them
    vaL, vaD, vaI, vaA, _ = eval_one_epoch(
        model, val_loader, criterion, keep_last_batch=False
    )

    # Scheduler is stepped once per epoch
    scheduler.step()

    # Single source for this epoch's metrics (feeds both history and the CSV log)
    epoch_metrics = {
        "train_loss": trL,
        "train_dice": trD,
        "train_iou": trI,
        "train_acc": trA,
        "val_loss": vaL,
        "val_dice": vaD,
        "val_iou": vaI,
        "val_acc": vaA,
    }
    for key, value in epoch_metrics.items():
        history[key].append(value)

    print(
        f"[Epoch {epoch:03d}/{EPOCHS}] "
        f"Train — Loss {trL:.4f} | Dice {trD:.4f} | IoU {trI:.4f} | Accuracy {trA:.4f} || "
        f"Val — Loss {vaL:.4f} | Dice {vaD:.4f} | IoU {vaI:.4f} | Accuracy {vaA:.4f}\n"
    )

    # Save checkpoint only if Dice improves AND loss decreases vs the current best
    saved_ckpt = False
    if (vaD > best_val_dice) and (vaL < best_val_loss):
        best_val_dice = vaD
        best_val_loss = vaL
        torch.save(model.state_dict(), best_path)
        saved_ckpt = True
        print(
            f"Saved best model (Val Dice↑ {best_val_dice:.4f} & Val Loss↓ {best_val_loss:.4f}) -> {best_path}\n"
        )

    # Log row for CSV (same column order as before: epoch, metrics, saved_ckpt, marker)
    log_rows.append(
        {
            "epoch": epoch,
            **epoch_metrics,
            "saved_ckpt": saved_ckpt,
            "marker": "X" if saved_ckpt else "",
        }
    )

# Save metrics as CSV
metrics_csv = os.path.join(OUT_GRAPHS_DIR, f"metrics_client{CLIENT_ID}.csv")
pd.DataFrame(log_rows).to_csv(metrics_csv, index=False)
print(f"[Log] Wrote per-epoch metrics CSV -> {metrics_csv}\n")

# Plot training curves
plot_metrics(history, os.path.join(OUT_GRAPHS_DIR, f"training_curves_client{CLIENT_ID}.png"))

# Run inference with best model and visualize a few validation samples
infer_and_visualize_best(
    model=model,
    val_dataset=val_dataset,
    use_atlas=USE_ATLAS,
    out_dir=OUT_GRAPHS_DIR,
    client_id=CLIENT_ID,
    best_ckpt_path=best_path,
    k_samples=3,
    threshold=0.5,
)

# **M1 - Federated Learning Setup**

<h2>Federated Experiment</h2>

**with Flower**

In [None]:
%%writefile seg_data.py
import os, pickle, numpy as np, torch
from typing import List
from torch.utils.data import Dataset, DataLoader

# Global paths and configuration (module-level defaults for the dataset and collate helpers below)
DATA_ROOT = "/content/Preprocessed-Data"  # root folder with one sub-directory per patient ID
METADATA_DF_PATH = "cleaned_df.pkl"  # pickled metadata DataFrame; must have a "Patient_ID" column
USE_ATLAS = True  # if True, load "<pid>_regions.npy" and stack it as a second input channel
EXCLUDE_IDS = ["PatientID_0191"]  # patients dropped when the dataset gets exclude_ids=None

# Dataset that loads MRI, tumor mask and optional atlas for each patient
class ImageOnlyGliomaDataset(Dataset):
    """Per-patient dataset of pre-processed ``.npy`` arrays.

    Each patient ``pid`` is read from ``<data_root>/<pid>/`` as ``<pid>_mri.npy``
    and ``<pid>_tumor.npy``, plus ``<pid>_regions.npy`` when ``use_atlas`` is
    True. Patients missing any required file are skipped at construction time.

    Args:
        metadata_df_path: pickled DataFrame with a ``Patient_ID`` column.
            NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.
        data_root: directory containing one sub-directory per patient.
        use_atlas: also load the atlas "regions" map.
        exclude_ids: patient IDs to drop; ``None`` means ``EXCLUDE_IDS``.
        transform: optional callable ``sample -> sample`` applied in ``__getitem__``.
    """

    def __init__(
        self,
        metadata_df_path=METADATA_DF_PATH,
        data_root=DATA_ROOT,
        use_atlas=USE_ATLAS,
        exclude_ids=None,
        transform=None,
    ):
        with open(metadata_df_path, "rb") as f:
            df = pickle.load(f)

        # Optionally exclude specific patients
        if exclude_ids is None:
            exclude_ids = EXCLUDE_IDS

        # Keep only non-excluded patient rows
        self.df = df[~df["Patient_ID"].isin(exclude_ids)].reset_index(drop=True)
        self.data_root = data_root
        self.use_atlas = use_atlas
        self.transform = transform

        # Keep only patients for which every required .npy file exists
        self.patient_ids = [
            pid
            for pid in sorted(self.df["Patient_ID"].tolist())
            if all(os.path.isfile(p) for p in self._paths(pid).values())
        ]

    def _paths(self, pid):
        """Return the expected .npy paths for ``pid``, keyed "mri", "tumor" (+ "regions")."""
        base = os.path.join(self.data_root, pid)
        paths = {
            "mri": os.path.join(base, f"{pid}_mri.npy"),
            "tumor": os.path.join(base, f"{pid}_tumor.npy"),
        }
        if self.use_atlas:
            paths["regions"] = os.path.join(base, f"{pid}_regions.npy")
        return paths

    def __len__(self):
        return len(self.patient_ids)

    # Simple min–max normalization
    @staticmethod
    def _minmax(x):
        """Scale ``x`` to [0, 1] as float32; a constant array maps to zeros."""
        x = x.astype(np.float32)
        mn, mx = np.min(x), np.max(x)
        return (x - mn) / (mx - mn) if mx > mn else np.zeros_like(x, dtype=np.float32)

    # Load a single sample: MRI, tumor mask, and optional atlas
    def __getitem__(self, idx):
        """Return a dict with a min-max normalized "mri" and a binarized (>0.5) "tumor".

        The dict also carries "patient_id" and, with the atlas, a normalized
        "regions" array. ``transform`` is applied last when it is set.
        """
        pid = self.patient_ids[idx]
        paths = self._paths(pid)

        mri = self._minmax(np.load(paths["mri"]).astype(np.float32))
        tumor = (np.load(paths["tumor"]).astype(np.float32) > 0.5).astype(np.float32)

        sample = {"patient_id": pid, "mri": mri, "tumor": tumor}

        if self.use_atlas:
            sample["regions"] = self._minmax(np.load(paths["regions"]).astype(np.float32))

        if self.transform is not None:
            sample = self.transform(sample)
        return sample

# Collate function to build batched tensors and patient ID list
def image_only_collate_fn(batch, use_atlas=USE_ATLAS):
    """Turn a list of sample dicts into one batch dict.

    Returns:
        {"x": float tensor with channels MRI (+ regions if ``use_atlas``),
         "y": float tumor masks with a channel axis,
         "pid": list of patient IDs in batch order}
    """
    def _stacked(key):
        # Stack along a new batch axis, then insert a channel axis at dim 1
        return torch.stack([torch.tensor(item[key]) for item in batch]).unsqueeze(1)

    mri = _stacked("mri")
    y = _stacked("tumor")
    x = torch.cat([mri.float(), _stacked("regions").float()], dim=1) if use_atlas else mri.float()

    pids = [item["patient_id"] for item in batch]
    return {"x": x, "y": y.float(), "pid": pids}

# Dataset wrapper that restricts to a subset of patient IDs
class SubsetByPIDs(Dataset):
    """View of ``full_dataset`` restricted to the patients listed in ``pid_list``.

    Items follow the order of ``pid_list``. IDs that are not in
    ``full_dataset.patient_ids`` are skipped without error.
    """

    def __init__(self, full_dataset: ImageOnlyGliomaDataset, pid_list: List[str]):
        self.ds = full_dataset
        position = {}
        for pos, pid in enumerate(full_dataset.patient_ids):
            position[pid] = pos
        self.indices = []
        for pid in pid_list:
            if pid in position:
                self.indices.append(position[pid])

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        underlying_index = self.indices[i]
        return self.ds[underlying_index]

# Compute Dice, IoU and accuracy for binary masks
def calc_metrics(y_true, y_pred):
    """Compute Dice, IoU and pixel accuracy between two binary masks.

    Inputs are cast to ``uint8`` and flattened, so they should hold 0/1 values
    and have the same number of elements (any shape).

    Returns:
        ``(dice, iou, acc)`` as Python floats. When neither mask has a positive
        pixel the masks agree perfectly, so Dice and IoU are 1.0. An
        epsilon-only denominator would report 0.0 in that case. Empty inputs
        give accuracy 1.0 rather than NaN.
    """
    y_true = np.asarray(y_true).astype(np.uint8).reshape(-1)
    y_pred = np.asarray(y_pred).astype(np.uint8).reshape(-1)

    inter = float((y_true & y_pred).sum())
    total = float(y_true.sum()) + float(y_pred.sum())

    if total == 0:
        dice = iou = 1.0
    else:
        dice = 2.0 * inter / total
        iou = inter / (total - inter)  # |A ∪ B| = |A| + |B| - |A ∩ B| > 0 here

    acc = float((y_true == y_pred).mean()) if y_true.size else 1.0
    return float(dice), float(iou), float(acc)

In [None]:
%%writefile fl_client.py
import argparse, os, json, numpy as np, torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import flwr as fl
import copy
import segmentation_models_pytorch as smp
from seg_data import (
    ImageOnlyGliomaDataset,
    SubsetByPIDs,
    image_only_collate_fn,
    calc_metrics,
    DATA_ROOT,
    METADATA_DF_PATH,
    USE_ATLAS,
)

# Base directory for client data splits (read as client_<cid>/train_pids.json and val_pids.json)
CLIENT_DIR = "/content/client"

# Data/loading config
BATCH_SIZE = 8
NUM_WORKERS = 2  # DataLoader worker processes per loader
SEED = 42  # seeds the DataLoader shuffling generator in get_loaders

# Device and model config
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ENCODER_NAME = "resnet34"  # encoder backbone name passed to smp.Unet
ENCODER_WEIGHTS = "imagenet"  # pretrained encoder weights passed to smp.Unet

# Local checkpoint directory (per client best models), relative to the working directory
CKPT_DIR = os.path.join("AITDM", "checkpoints")
os.makedirs(CKPT_DIR, exist_ok=True)


def get_model():
    """Build the client's single-class ``smp.Unet`` and place it on ``DEVICE``.

    The model takes 2 input channels (MRI + atlas regions) when ``USE_ATLAS``
    is set, otherwise 1 (MRI only).
    """
    return smp.Unet(
        encoder_name=ENCODER_NAME,
        encoder_weights=ENCODER_WEIGHTS,
        in_channels=(2 if USE_ATLAS else 1),
        classes=1,
    ).to(DEVICE)


def get_loaders(cid: int):
    """Build the train/val DataLoaders for client ``cid``.

    Reads ``<CLIENT_DIR>/client_<cid>/train_pids.json`` and ``val_pids.json``
    (lists of patient IDs) and restricts the full dataset to those patients.

    Returns:
        ``(train_loader, val_loader, n_train, n_val)``.
    """
    import functools

    full = ImageOnlyGliomaDataset(
        METADATA_DF_PATH, DATA_ROOT, use_atlas=USE_ATLAS, exclude_ids=["PatientID_0191"]
    )

    client_path = os.path.join(CLIENT_DIR, f"client_{cid}")

    def _load_pids(filename):
        """Load one JSON list of patient IDs from this client's split folder."""
        with open(os.path.join(client_path, filename)) as f:
            return json.load(f)

    ds_tr = SubsetByPIDs(full, _load_pids("train_pids.json"))
    ds_va = SubsetByPIDs(full, _load_pids("val_pids.json"))

    # functools.partial is picklable (a lambda is not), so the collate function
    # also works with "spawn"/"forkserver" worker start methods.
    collate = functools.partial(image_only_collate_fn, use_atlas=USE_ATLAS)
    # Pinned host memory only helps host->GPU copies; skip it on CPU-only runs.
    pin = DEVICE.type == "cuda"

    ld_tr = DataLoader(
        ds_tr,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=pin,
        collate_fn=collate,
        generator=torch.Generator().manual_seed(SEED),
    )

    # Separate seeded generator so iterating the validation loader does not
    # advance the training loader's shuffle RNG.
    ld_va = DataLoader(
        ds_va,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=pin,
        collate_fn=collate,
        generator=torch.Generator().manual_seed(SEED),
    )

    return ld_tr, ld_va, len(ds_tr), len(ds_va)


def get_parameters(model):
    """Return every ``state_dict`` entry (parameters and buffers) as NumPy arrays, in order.

    This is the list-of-ndarrays format Flower exchanges with the server.
    """
    state = model.state_dict()
    return [tensor.detach().cpu().numpy() for tensor in state.values()]


def set_parameters(model, params):
    """Load a list of NumPy arrays (Flower format) into ``model``'s ``state_dict``.

    Arrays are matched to ``state_dict`` entries by position, which is the order
    ``get_parameters`` produces.

    Raises:
        ValueError: if ``params`` does not contain exactly one array per entry.
            A plain ``zip`` would silently leave trailing entries at their old
            values.
    """
    sd = model.state_dict()
    params = list(params)
    if len(params) != len(sd):
        raise ValueError(
            f"set_parameters: expected {len(sd)} arrays, got {len(params)}"
        )
    for k, v in zip(sd.keys(), params):
        sd[k] = torch.tensor(v)
    model.load_state_dict(sd, strict=True)


# Component losses for the hybrid criterion; both take raw logits
bce = nn.BCEWithLogitsLoss()
dice_loss = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)


def criterion(pred, y):
    """Hybrid segmentation loss: equal-weight (0.5 / 0.5) sum of BCE and Dice.

    ``pred`` holds logits; ``y`` is the float target mask. Presumably they have
    matching shapes (required by BCEWithLogitsLoss).
    """
    bce_term = bce(pred, y)
    dice_term = dice_loss(pred, y)
    return 0.5 * bce_term + 0.5 * dice_term


def maybe_save_best(cid, val_loss, val_dice, best_epoch, rnd, model):
    """Persist ``model`` as client ``cid``'s best local checkpoint if it improves.

    The previous best is read from ``<CKPT_DIR>/client_<cid>_best.json``. A
    model counts as better only if it has BOTH a lower validation loss AND a
    higher validation Dice. On improvement the ``state_dict`` goes to
    ``client_<cid>_best.pt`` and the metadata JSON is rewritten. Both files are
    written to a temp file first and then moved into place, so a crash
    mid-write cannot leave a truncated file.

    Returns:
        True if a new checkpoint was saved, else False.
    """
    import warnings

    best_json = os.path.join(CKPT_DIR, f"client_{cid}_best.json")
    best_pt = os.path.join(CKPT_DIR, f"client_{cid}_best.pt")

    # Default previous best values
    default_prev = {"val_loss": float("inf"), "val_dice": -1.0}
    prev = default_prev
    if os.path.isfile(best_json):
        try:
            with open(best_json, "r") as f:
                prev = json.load(f)
        except (OSError, ValueError) as exc:
            # Unreadable/corrupt metadata: treat as "no previous best", but say so
            warnings.warn(f"[maybe_save_best] ignoring unreadable {best_json}: {exc}")
            prev = default_prev
    if not isinstance(prev, dict):
        warnings.warn(f"[maybe_save_best] unexpected content in {best_json}; ignoring it")
        prev = default_prev

    # Check if current model improved both loss and Dice
    improved = (val_loss < prev.get("val_loss", float("inf"))) and (
        val_dice > prev.get("val_dice", -1.0)
    )
    if not improved:
        return False

    tmp_pt = best_pt + ".tmp"
    torch.save(model.state_dict(), tmp_pt)
    os.replace(tmp_pt, best_pt)

    tmp_json = best_json + ".tmp"
    with open(tmp_json, "w") as f:
        json.dump(
            {
                "round": int(rnd),
                "epoch": int(best_epoch),
                "val_loss": float(val_loss),
                "val_dice": float(val_dice),
            },
            f,
        )
    os.replace(tmp_json, best_json)
    return True


class SegClient(fl.client.NumPyClient):
    """Flower NumPyClient for federated glioma segmentation.

    Each client owns a local UNet (``get_model``) and the train/val loaders
    for its own patient split (``get_loaders``).
    """

    def __init__(self, cid: int):
        self.cid = cid
        self.model = get_model()
        self.train_loader, self.val_loader, self.ntr, self.nva = get_loaders(cid)

    def get_parameters(self, config):
        """Return current local model parameters (``config`` is unused)."""
        return get_parameters(self.model)

    def _train_epoch(self, opt, scaler):
        """Run one AMP training epoch over the local train loader.

        Returns:
            Batch-averaged ``(loss, dice, iou, acc)``; zeros if the loader is empty.
        """
        self.model.train()
        tot_loss = tot_d = tot_i = tot_a = 0.0
        nb = 0

        for batch in self.train_loader:
            x = batch["x"].to(DEVICE)
            y = batch["y"].to(DEVICE)
            opt.zero_grad(set_to_none=True)

            # Forward pass with mixed precision (disabled when not on CUDA)
            with torch.amp.autocast("cuda", enabled=(DEVICE.type == "cuda")):
                pred = self.model(x)
                loss = criterion(pred, y)

            # Backward pass and optimizer step with gradient scaling
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            # Metrics on binarized (0.5) predictions for this batch
            with torch.no_grad():
                y_hat = (torch.sigmoid(pred).detach().cpu().numpy() > 0.5).astype(np.uint8)
                y_np = (y.detach().cpu().numpy() > 0.5).astype(np.uint8)
                d, i, a = calc_metrics(y_np, y_hat)

            tot_loss += float(loss.item())
            tot_d += d
            tot_i += i
            tot_a += a
            nb += 1

        nb = max(nb, 1)
        return tot_loss / nb, tot_d / nb, tot_i / nb, tot_a / nb

    def _evaluate_val(self):
        """Evaluate the current model on the local validation loader (no grad).

        Returns:
            Batch-averaged ``(loss, dice, iou, acc)``; zeros if the loader is empty.
        """
        self.model.eval()
        tot_loss = tot_d = tot_i = tot_a = 0.0
        nb = 0

        with torch.no_grad():
            for batch in self.val_loader:
                x = batch["x"].to(DEVICE)
                y = batch["y"].to(DEVICE)
                pred = self.model(x)

                loss = float(criterion(pred, y).item())
                y_hat = (torch.sigmoid(pred).cpu().numpy() > 0.5).astype(np.uint8)
                y_np = (y.cpu().numpy() > 0.5).astype(np.uint8)
                d, i, a = calc_metrics(y_np, y_hat)

                tot_loss += loss
                tot_d += d
                tot_i += i
                tot_a += a
                nb += 1

        nb = max(nb, 1)
        return tot_loss / nb, tot_d / nb, tot_i / nb, tot_a / nb

    def fit(self, parameters, config):
        """Train locally starting from the global ``parameters``.

        Config keys read:
            ``local_epochs`` (default 1), ``lr`` (default 1e-3), and ``round``
            (default 0, used only for checkpoint metadata).

        A fresh AdamW optimizer and AMP scaler are created every call. After
        each epoch the model is validated. An epoch becomes the new best only
        if it lowers the val loss AND raises the val Dice relative to the
        current best. Those best weights, which are not necessarily the last
        epoch's, are returned.

        Returns:
            ``(parameters, num_train_examples, metrics)``. The metrics hold the
            client id, the best epoch and its val loss/Dice, and a JSON string
            of the per-epoch logs.
        """
        # Load global parameters
        set_parameters(self.model, parameters)

        # Read fit configuration from server
        epochs = int(config.get("local_epochs", 1))
        lr = float(config.get("lr", 1e-3))
        rnd = int(config.get("round", 0))

        # Optimizer and AMP scaler
        opt = optim.AdamW(self.model.parameters(), lr=lr)
        scaler = torch.amp.GradScaler("cuda", enabled=(DEVICE.type == "cuda"))

        # Track best validation performance
        best_state = None
        best_val_loss = float("inf")
        best_val_dice = -1.0
        best_epoch_idx = -1

        # Per-epoch logs (sent to the server as JSON)
        epoch_logs = []

        for epoch_idx in range(1, epochs + 1):
            tr_loss, tr_dice, tr_iou, tr_acc = self._train_epoch(opt, scaler)
            va_loss, va_dice, va_iou, va_acc = self._evaluate_val()

            epoch_logs.append(
                {
                    "epoch": int(epoch_idx),
                    "train_loss": float(tr_loss),
                    "train_dice": float(tr_dice),
                    "train_iou": float(tr_iou),
                    "train_acc": float(tr_acc),
                    "val_loss": float(va_loss),
                    "val_dice": float(va_dice),
                    "val_iou": float(va_iou),
                    "val_acc": float(va_acc),
                }
            )

            # Update best model based on validation metrics
            if (va_loss < best_val_loss) and (va_dice > best_val_dice):
                best_val_loss = va_loss
                best_val_dice = va_dice
                best_state = copy.deepcopy(self.model.state_dict())
                best_epoch_idx = epoch_idx

        # Load best local model state if available
        if best_state is not None:
            self.model.load_state_dict(best_state)

        # Mark best epoch inside logs
        for ep in epoch_logs:
            ep["best_epoch"] = (ep["epoch"] == best_epoch_idx)

        # Metrics sent back to server
        train_metrics = {
            "cid": int(self.cid),
            "best_epoch": int(best_epoch_idx),
            "best_val_loss": float(best_val_loss),
            "best_val_dice": float(best_val_dice),
            "per_epoch": json.dumps(epoch_logs),
        }

        # Save best local checkpoint
        maybe_save_best(self.cid, best_val_loss, best_val_dice, best_epoch_idx, rnd, self.model)

        # Return updated parameters, number of train examples, and metrics
        return get_parameters(self.model), self.ntr, train_metrics

    def evaluate(self, parameters, config):
        """Evaluate the global ``parameters`` on the local validation set.

        Returns:
            ``(loss, num_val_examples, metrics)`` where metrics has the keys
            "loss", "dice", "iou", "acc" and "cid".
        """
        # Load global parameters
        set_parameters(self.model, parameters)

        loss, dice, iou, acc = self._evaluate_val()
        metrics = {
            "loss": loss,
            "dice": dice,
            "iou": iou,
            "acc": acc,
            "cid": int(self.cid),
        }

        # Flower expects (loss, num_examples, metrics)
        return metrics["loss"], self.nva, metrics


if __name__ == "__main__":
    # Standalone client entry point (for non-simulation setups), e.g.:
    #   python fl_client.py --cid 0 --server <host>:8080
    parser = argparse.ArgumentParser(description="Flower glioma segmentation client")
    parser.add_argument("--cid", type=int, required=True, help="client index (selects client_<cid> split)")
    parser.add_argument("--server", default="0.0.0.0:8080", help="Flower server address host:port")
    args = parser.parse_args()

    print(f"[SegClient {args.cid}] device={DEVICE}, cuda={torch.cuda.is_available()}")

    # start_numpy_client is deprecated; start_client with a NumPyClient
    # converted via .to_client() is the supported equivalent.
    fl.client.start_client(
        server_address=args.server,
        client=SegClient(args.cid).to_client(),
    )

In [None]:
%%writefile fl_sim_colab.py
import os
import csv
import json
import torch
import flwr as fl
from flwr.common import FitIns
from fl_client import SegClient

# Base directory for outputs (relative to the current working directory)
BASE_DIR = "AITDM"
# Directory where per-client metrics CSVs (metrics_client_<cid>.csv) are written during aggregation
METRICS_DIR = os.path.join(BASE_DIR, "metrics")
os.makedirs(METRICS_DIR, exist_ok=True)


def client_fn(cid: str):
    """Simulation factory: build the client for Flower's string id ``cid``.

    ``SegClient`` expects an integer id, so ``cid`` is converted. The
    NumPyClient is then wrapped via ``.to_client()``.
    """
    numpy_client = SegClient(int(cid))
    return numpy_client.to_client()


def ensure_csv(path: str, header: list[str]):
    """Create ``path`` containing only ``header`` if the file does not exist yet.

    Missing parent directories are created. An existing file is left untouched.
    """
    if os.path.isfile(path):
        return
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a real parent dir
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", newline="") as f:
        csv.writer(f).writerow(header)


def append_row(path: str, row: list):
    """Append ``row`` as a single CSV record to the end of ``path`` (created if absent)."""
    with open(path, mode="a", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(row)


class PerClientLoggingFedAvg(fl.server.strategy.FedAvg):
    """FedAvg that also writes each client's per-epoch local metrics to CSV.

    Each client gets one file, ``METRICS_DIR/metrics_client_<cid>.csv``, with
    one row per local epoch per round. Rows come from the JSON "per_epoch"
    list that each client returns in its fit metrics. The client's reported
    best epoch is marked with "x".
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # CSV header for logged metrics
        self.header = [
            "round",
            "epoch",
            "train_loss",
            "train_dice",
            "train_iou",
            "train_acc",
            "val_loss",
            "val_dice",
            "val_iou",
            "val_acc",
            "best_epoch",
        ]

    def configure_fit(self, server_round, parameters, client_manager):
        """Add "round" (the server round number) to every client's fit config."""
        items = super().configure_fit(server_round, parameters, client_manager)
        out = []
        for it in items:
            # Expected: (client, FitIns) pairs; a bare FitIns is tolerated defensively
            if isinstance(it, tuple):
                client, fitins = it
            else:
                client, fitins = None, it

            # Copy the config instead of mutating it (the FitIns may be shared)
            cfg = dict(fitins.config)
            cfg["round"] = server_round
            new_fitins = FitIns(fitins.parameters, cfg)

            out.append((client, new_fitins) if client is not None else new_fitins)
        return out

    @staticmethod
    def _is_best(epoch, best_epoch):
        """True if ``epoch`` parses as an int equal to ``best_epoch``; never raises."""
        try:
            return int(epoch) == best_epoch
        except (TypeError, ValueError):
            return False

    def aggregate_fit(self, rnd, results, failures):
        """Run FedAvg aggregation, then append each client's per-epoch rows to its CSV.

        Returns the result of ``FedAvg.aggregate_fit`` unchanged. Malformed
        client logs are skipped rather than failing the round.
        """
        agg = super().aggregate_fit(rnd, results, failures)

        for client_proxy, fit_res in results:
            m = fit_res.metrics or {}
            cid = str(m.get("cid", client_proxy.cid))

            client_csv = os.path.join(METRICS_DIR, f"metrics_client_{cid}.csv")
            ensure_csv(client_csv, self.header)

            best_epoch = int(m.get("best_epoch", -1))

            # Parse per-epoch metrics sent from the client
            try:
                per_epoch = json.loads(m.get("per_epoch", "[]"))
            except (TypeError, ValueError):
                per_epoch = []
            if not isinstance(per_epoch, list):
                per_epoch = []

            # Write one row per local epoch
            for ep in per_epoch:
                if not isinstance(ep, dict):
                    continue
                epoch = ep.get("epoch", "")
                row = [
                    rnd,
                    epoch,
                    ep.get("train_loss", ""),
                    ep.get("train_dice", ""),
                    ep.get("train_iou", ""),
                    ep.get("train_acc", ""),
                    ep.get("val_loss", ""),
                    ep.get("val_dice", ""),
                    ep.get("val_iou", ""),
                    ep.get("val_acc", ""),
                    "x" if self._is_best(epoch, best_epoch) else "",
                ]
                append_row(client_csv, row)

        return agg


# Federated learning run configuration (single place to change the setup)
NUM_CLIENTS = 3
NUM_ROUNDS = 5
LOCAL_EPOCHS = 5
LOCAL_LR = 1e-3


def weighted_average(metrics):
    """Aggregate client evaluate metrics as an example-weighted mean.

    Args:
        metrics: list of ``(num_examples, metrics_dict)`` pairs, as Flower
            passes to ``evaluate_metrics_aggregation_fn``.

    Returns:
        Dict of the weighted means of the keys reported by every client. The
        "cid" identifier is excluded. An empty dict is returned when there are
        no examples.
    """
    total = sum(n for n, _ in metrics)
    if total == 0:
        return {}
    shared_keys = set.intersection(*(set(m) for _, m in metrics)) - {"cid"}
    return {
        key: sum(n * float(m[key]) for n, m in metrics) / total
        for key in sorted(shared_keys)
    }


strategy = PerClientLoggingFedAvg(
    fraction_fit=1.0,
    fraction_evaluate=1.0,
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
    on_fit_config_fn=lambda rnd: {"local_epochs": LOCAL_EPOCHS, "lr": LOCAL_LR},
    evaluate_metrics_aggregation_fn=weighted_average,
)

# Resource configuration: with a GPU each client requests a whole device,
# so on a single-GPU machine clients run one at a time.
use_gpu = torch.cuda.is_available()
client_resources = {"num_cpus": 1, "num_gpus": 1.0 if use_gpu else 0.0}

# Start Flower simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
    client_resources=client_resources,
    ray_init_args={"include_dashboard": False},
)

In [None]:
!python fl_sim_colab.py

# **M2 - CLIP Model**

**We first train a UNet for tumor segmentation and a BioClinicalBERT model for region prediction from the reports. We then train a multimodal CLIP-like model that aligns text and image embeddings while preserving segmentation quality through a combined loss.**

**<h2>UNet on Images</h2>**

In [None]:
import os, json, pickle, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from typing import List
from sklearn.metrics import jaccard_score
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp

# Paths and I/O config
DATA_ROOT = "/content/Preprocessed-Data"
METADATA_DF_PATH = "cleaned_df.pkl"
OUT_BASE = "AITDM"
# Absolute path to the per-client split folders (client_<id>/train_pids.json, val_pids.json).
# It does not live under OUT_BASE; os.path.join(OUT_BASE, "/content/client")
# would discard OUT_BASE anyway because the second part is absolute.
CLIENT_DIR = "/content/client"

# Experiment config
USE_ATLAS = True  # stack the atlas regions map as a second input channel
CLIENT_ID = 0  # which client's split to train on
BATCH_SIZE = 8
LR = 1e-3
EPOCHS = 50
NUM_WORKERS = 2
SEED = 42

# Model / encoder config
ENCODER_NAME = "resnet34"
ENCODER_WEIGHTS = "imagenet"

# Output directories (relative to the working directory)
OUT_MODELS_DIR = os.path.join(OUT_BASE, "Models", "UNet_ImageOnly")
OUT_GRAPHS_DIR = os.path.join(OUT_BASE, "Graphs")
os.makedirs(OUT_MODELS_DIR, exist_ok=True)
os.makedirs(OUT_GRAPHS_DIR, exist_ok=True)

def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and request deterministic kernels.

    It also sets ``CUBLAS_WORKSPACE_CONFIG`` if it is not already set. PyTorch
    requires this variable for deterministic cuBLAS ops under
    ``torch.use_deterministic_algorithms(True)``. It may have no effect if
    CUDA/cuBLAS was already initialized earlier in the process.
    ``warn_only=True`` makes nondeterministic ops warn instead of raising.
    """
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True, warn_only=True)

def worker_init_fn(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    The seed comes from ``torch.initial_seed()``. PyTorch sets it per worker
    to ``base_seed + worker_id``, where ``base_seed`` is drawn from the loader's
    generator. The result is reproducible when that generator is seeded, yet
    differs across epochs; a fixed ``SEED + worker_id`` would replay the same
    random streams every epoch. ``worker_id`` itself is unused because it is
    already folded into ``torch.initial_seed()``.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Seed all RNGs once for this experiment, before any model or DataLoader is created
set_seed(SEED)
# Compute device for the image-only UNet: CUDA if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ImageOnlyGliomaDataset(Dataset):
    """Dataset that loads MRI, tumor mask, and optional atlas regions per patient.

    Each patient ``pid`` is read from ``<data_root>/<pid>/`` as ``<pid>_mri.npy``
    and ``<pid>_tumor.npy``, plus ``<pid>_regions.npy`` when ``use_atlas`` is
    True. Patients missing any required file are skipped at construction time.

    Args:
        metadata_df_path: pickled DataFrame with a ``Patient_ID`` column.
            NOTE: ``pickle.load`` can execute arbitrary code; only load trusted files.
        data_root: directory containing one sub-directory per patient.
        use_atlas: also load the atlas "regions" map.
        exclude_ids: patient IDs to drop; ``None`` means ``["PatientID_0191"]``.
        transform: optional callable ``sample -> sample`` applied in ``__getitem__``.
    """

    def __init__(self, metadata_df_path, data_root, use_atlas=True, exclude_ids=None, transform=None):
        # Load metadata dataframe
        with open(metadata_df_path, "rb") as f:
            df = pickle.load(f)

        # Optionally exclude some patient IDs
        if exclude_ids is None:
            exclude_ids = ["PatientID_0191"]
        self.df = df[~df["Patient_ID"].isin(exclude_ids)].reset_index(drop=True)

        self.data_root = data_root
        self.use_atlas = use_atlas
        self.transform = transform

        # Keep only patients for which every required .npy file exists
        self.patient_ids = [
            pid
            for pid in sorted(self.df["Patient_ID"].tolist())
            if all(os.path.isfile(p) for p in self._paths(pid).values())
        ]

    def _paths(self, pid):
        """Return the expected .npy paths for ``pid``, keyed "mri", "tumor" (+ "regions")."""
        base = os.path.join(self.data_root, pid)
        paths = {
            "mri": os.path.join(base, f"{pid}_mri.npy"),
            "tumor": os.path.join(base, f"{pid}_tumor.npy"),
        }
        if self.use_atlas:
            paths["regions"] = os.path.join(base, f"{pid}_regions.npy")
        return paths

    def __len__(self):
        return len(self.patient_ids)

    @staticmethod
    def _minmax(x):
        """Simple min-max normalization to [0, 1]; a constant array maps to zeros."""
        x = x.astype(np.float32)
        mn, mx = np.min(x), np.max(x)
        return (x - mn) / (mx - mn) if mx > mn else np.zeros_like(x, dtype=np.float32)

    def __getitem__(self, idx):
        """Return one patient's sample dict.

        The dict has "patient_id", a min-max normalized "mri", a binarized
        (>0.5) "tumor" mask and, with the atlas, a normalized "regions" map.
        ``transform`` is applied last when it is set.
        """
        pid = self.patient_ids[idx]
        paths = self._paths(pid)

        mri = self._minmax(np.load(paths["mri"]).astype(np.float32))
        tumor = (np.load(paths["tumor"]).astype(np.float32) > 0.5).astype(np.float32)

        sample = {"patient_id": pid, "mri": mri, "tumor": tumor}

        if self.use_atlas:
            sample["regions"] = self._minmax(np.load(paths["regions"]).astype(np.float32))

        if self.transform is not None:
            sample = self.transform(sample)
        return sample

def image_only_collate_fn(batch, use_atlas=True):
    """Collate per-patient sample dicts into a model-ready batch.

    Each field is stacked across samples and a channel axis is inserted at
    dim 1. With ``use_atlas`` the MRI and ``regions`` arrays are concatenated
    as two input channels; otherwise the input is the MRI channel alone.

    Returns:
        dict with ``x`` (float32 inputs), ``y`` (float32 tumor masks) and
        ``pid`` (list of patient IDs in batch order).
    """

    def _stacked(key):
        # Stack one field across the batch, then add the channel axis.
        return torch.stack([torch.tensor(sample[key]) for sample in batch]).unsqueeze(1)

    mri = _stacked("mri")
    tumor = _stacked("tumor")

    if not use_atlas:
        inputs = mri.float()
    else:
        inputs = torch.cat([mri.float(), _stacked("regions").float()], dim=1)

    return {
        "x": inputs,
        "y": tumor.float(),
        "pid": [sample["patient_id"] for sample in batch],
    }

class SubsetByPIDs(Dataset):
    """View over ``full_dataset`` restricted to the patients listed in ``pid_list``.

    Items are served in ``pid_list`` order; IDs that the full dataset does not
    contain (e.g. filtered out for missing files) are skipped silently.
    """

    def __init__(self, full_dataset: ImageOnlyGliomaDataset, pid_list: List[str]):
        self.ds = full_dataset
        # patient ID -> position in the full dataset (last occurrence wins)
        position = dict(zip(self.ds.patient_ids, range(len(self.ds.patient_ids))))
        self.indices = []
        for pid in pid_list:
            found = position.get(pid)
            if found is not None:
                self.indices.append(found)

    def __len__(self):
        return self.indices.__len__()

    def __getitem__(self, i):
        full_index = self.indices[i]
        return self.ds[full_index]

# ---- Client split: patient IDs assigned to this federated client ----
cdir = os.path.join(CLIENT_DIR, "client_" + str(CLIENT_ID))
with open(os.path.join(cdir, "train_pids.json")) as f:
    train_pids = json.loads(f.read())
with open(os.path.join(cdir, "val_pids.json")) as f:
    val_pids = json.loads(f.read())

# ---- Full image dataset, then per-client train/val views over it ----
full_ds = ImageOnlyGliomaDataset(
    METADATA_DF_PATH,
    DATA_ROOT,
    exclude_ids=["PatientID_0191"],
    use_atlas=USE_ATLAS,
)
train_dataset, val_dataset = (SubsetByPIDs(full_ds, pids) for pids in (train_pids, val_pids))

# ---- Data loaders ----
from functools import partial

# Bind USE_ATLAS now rather than through a lambda: a lambda reads the global
# lazily at every batch and cannot be pickled for workers started with the
# "spawn" method, while a partial of a module-level function can.
_collate_fn = partial(image_only_collate_fn, use_atlas=USE_ATLAS)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=_collate_fn,
    generator=torch.Generator().manual_seed(SEED),
    worker_init_fn=worker_init_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=_collate_fn,
    # Own generator so validation iterations never touch the training RNG stream.
    generator=torch.Generator().manual_seed(SEED),
    worker_init_fn=worker_init_fn,
)

print(f"Loaded client_{CLIENT_ID}: train patients={len(train_dataset)}, val patients={len(val_dataset)}")

# ---- U-Net segmentation model (segmentation_models_pytorch) ----
# Inputs: the MRI alone, or MRI + atlas regions stacked as a second channel.
if USE_ATLAS:
    in_channels = 2
else:
    in_channels = 1

model = smp.Unet(
    encoder_name=ENCODER_NAME,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=in_channels,
    classes=1,  # one logit map for binary tumor segmentation
)
model = model.to(device)

# ---- Loss, optimizer, LR schedule and AMP scaler ----
criterion = smp.losses.DiceLoss(mode=smp.losses.BINARY_MODE, from_logits=True)
optimizer = optim.AdamW(params=model.parameters(), lr=LR)
# LR decays linearly from LR to 0 over EPOCHS scheduler steps.
scheduler = optim.lr_scheduler.LinearLR(
    optimizer,
    total_iters=EPOCHS,
    start_factor=1.0,
    end_factor=0.0,
)
# Gradient scaling is enabled only when training on CUDA.
scaler = torch.amp.GradScaler(device="cuda", enabled=device.type == "cuda")

def calc_metrics(y_true, y_pred):
    """Compute Dice, IoU (Jaccard) and pixel accuracy for binary masks.

    Args:
        y_true: ground-truth mask array (expected to already be 0/1).
        y_pred: predicted mask array (expected to already be 0/1), same size.

    Returns:
        (dice, iou, acc) as Python floats. When both masks are empty, Dice and
        IoU are 0.0, matching sklearn's ``jaccard_score`` with zero_division=0.
    """
    # Cast to uint8 like before, then to int64 so the sums can never overflow.
    t = np.asarray(y_true).astype(np.uint8).reshape(-1).astype(np.int64)
    p = np.asarray(y_pred).astype(np.uint8).reshape(-1).astype(np.int64)

    inter = int((t * p).sum())
    t_sum = int(t.sum())
    p_sum = int(p.sum())

    dice = (2.0 * inter) / (t_sum + p_sum + 1e-8)

    # IoU = TP / (TP + FP + FN), computed directly. This avoids a per-batch sklearn call
    # (with its UndefinedMetricWarning on empty masks) and a blanket
    # `except Exception` that silently switched implementations.
    union = t_sum + p_sum - inter
    iou = inter / union if union > 0 else 0.0

    acc = (t == p).mean()
    return float(dice), float(iou), float(acc)

def plot_prediction(sample_imgs, save_path):
    """Save and show a side-by-side figure for the first sample of a batch.

    ``sample_imgs`` maps ``"mri"``, ``"y_true"``, ``"y_pred"`` and optionally
    ``"regions"`` to arrays indexed ``[batch, channel, ...]``; element
    ``[0, 0]`` of each is drawn. Panel order is
    MRI | Regions (if present) | Ground Truth | Predicted.
    """
    panels = [("mri", "MRI")]
    if "regions" in sample_imgs:
        panels.append(("regions", "Regions"))
    panels += [("y_true", "Ground Truth"), ("y_pred", "Predicted")]

    n_panels = len(panels)
    fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))
    for ax, (key, title) in zip(axs, panels):
        ax.imshow(sample_imgs[key][0, 0], cmap="gray")
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close()

def plot_metrics(history, save_path):
    """Draw a 2x2 grid of train-vs-val curves (loss, Dice, IoU, accuracy), save and show it.

    ``history`` must hold per-epoch lists under ``train_<m>`` / ``val_<m>``
    for m in loss, dice, iou, acc.
    """
    epochs = range(1, len(history["train_loss"]) + 1)
    # (history key suffix, legend label suffix, panel title)
    panels = [
        ("loss", "Loss", "Loss"),
        ("dice", "Dice", "Dice"),
        ("iou", "IoU", "IoU"),
        ("acc", "Acc", "Accuracy"),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    for ax, (key, label, title) in zip(axes.flat, panels):
        ax.plot(epochs, history[f"train_{key}"], label=f"Train {label}")
        ax.plot(epochs, history[f"val_{key}"], label=f"Val {label}")
        ax.legend()
        ax.set_title(title)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close()

def train_one_epoch(model, dataloader, optimizer, criterion, scaler):
    """Run one (mixed-precision on CUDA) training epoch.

    Each batch is a dict with ``x`` (inputs) and ``y`` (target masks). Metrics
    are computed on masks thresholded at 0.5.

    Returns:
        (loss, dice, iou, acc): means over batches, with every batch weighted
        equally regardless of size. An empty dataloader yields all zeros
        instead of raising ZeroDivisionError.
    """
    model.train()
    use_amp = device.type == "cuda"
    total_loss = total_dice = total_iou = total_acc = 0.0

    for batch in dataloader:
        x = batch["x"].to(device)
        y = batch["y"].to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda", enabled=use_amp):
            preds = model(x)
            loss = criterion(preds, y)

        # Scaled backward; scaler.step skips the update if gradients overflowed.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += float(loss.item())

        # Binary masks for the metrics (0.5 threshold on probabilities / targets).
        pred_mask = (torch.sigmoid(preds).detach().cpu().numpy() > 0.5).astype(np.uint8)
        true_mask = (y.detach().cpu().numpy() > 0.5).astype(np.uint8)
        d, i, a = calc_metrics(true_mask, pred_mask)

        total_dice += d
        total_iou += i
        total_acc += a

    n = max(1, len(dataloader))
    return total_loss / n, total_dice / n, total_iou / n, total_acc / n

@torch.no_grad()
def eval_one_epoch(model, dataloader, criterion, keep_last_batch=True):
    """Evaluate ``model`` over ``dataloader`` without gradient tracking.

    Returns:
        (loss, dice, iou, acc, last_batch_imgs): batch-averaged loss and
        metrics (all 0.0 for an empty loader), plus, when ``keep_last_batch``
        is set and at least one batch ran, a dict of NumPy arrays for the final
        batch: ``mri`` (input channel 0), ``y_true`` (targets), ``y_pred``
        (binary, threshold 0.5) and, for two-channel inputs, ``regions``
        (input channel 1). Otherwise ``last_batch_imgs`` is None.
    """
    model.eval()
    total_loss = total_dice = total_iou = total_acc = 0.0
    last = None

    for batch in dataloader:
        x_host = batch["x"]
        x = x_host.to(device)
        y = batch["y"].to(device)

        preds = model(x)
        loss = criterion(preds, y)
        total_loss += float(loss.item())

        pred_mask = (torch.sigmoid(preds).cpu().numpy() > 0.5).astype(np.uint8)
        y_np = y.cpu().numpy()
        d, i, a = calc_metrics((y_np > 0.5).astype(np.uint8), pred_mask)

        total_dice += d
        total_iou += i
        total_acc += a

        if keep_last_batch:
            # Keep host-side references only; converting inputs for every
            # batch just to discard all but the last wastes device->host copies.
            last = (x_host, y_np, pred_mask)

    last_batch_imgs = None
    if last is not None:
        x_last, y_last, pred_last = last
        x_last = x_last.cpu()
        last_batch_imgs = {"mri": x_last[:, 0:1].numpy(), "y_true": y_last, "y_pred": pred_last}
        if x_last.shape[1] == 2:
            last_batch_imgs["regions"] = x_last[:, 1:2].numpy()

    n = max(1, len(dataloader))
    return total_loss / n, total_dice / n, total_iou / n, total_acc / n, last_batch_imgs

def infer_and_visualize_best(
    model,
    val_dataset,
    use_atlas: bool,
    out_dir: str,
    client_id: int,
    best_ckpt_path: str,
    k_samples: int = 3,
    threshold: float = 0.5,
):
    """Load the best checkpoint, predict on a few validation samples and save figures.

    Args:
        model: segmentation network; its weights are replaced from ``best_ckpt_path``.
        val_dataset: indexable dataset of dicts with ``mri``, ``tumor`` and
            optionally ``regions`` / ``patient_id``.
        use_atlas: feed ``regions`` as a second input channel when a sample has it.
        out_dir: output directory (created if needed) for per-sample PNGs and a grid PNG.
        client_id: used only in output file names.
        best_ckpt_path: path to a ``state_dict`` written with ``torch.save``.
        k_samples: number of validation samples to draw (capped at dataset size).
        threshold: probability cut-off used to binarize the sigmoid output.

    Returns:
        None. Prints a message and returns early if the checkpoint is missing
        or there is nothing to sample.

    Side effects:
        Re-seeds the global ``random``/NumPy/PyTorch RNGs with ``SEED`` so the
        sampled indices are reproducible.
    """
    import random

    # Make sampling deterministic
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    os.makedirs(out_dir, exist_ok=True)

    if not os.path.isfile(best_ckpt_path):
        print(f"[Best Model] Missing checkpoint: {best_ckpt_path}\n")
        return
    model.load_state_dict(torch.load(best_ckpt_path, map_location=device))
    model.eval()
    print(f"[Best Model] Loaded: {best_ckpt_path}\n")

    k = min(k_samples, len(val_dataset))
    if k == 0:
        print("[Best Model] Empty val dataset.\n")
        return

    # Deterministic random subset of validation indices
    idxs = random.sample(range(len(val_dataset)), k)

    def _predict_one(sample):
        """Return the model input (NumPy, batch and channel axes added) and the binarized prediction."""
        mri = torch.tensor(sample["mri"]).unsqueeze(0).unsqueeze(0).float().to(device)

        if use_atlas and ("regions" in sample):
            regs = torch.tensor(sample["regions"]).unsqueeze(0).unsqueeze(0).float().to(device)
            x = torch.cat([mri, regs], dim=1)
        else:
            x = mri

        with torch.no_grad():
            prob = torch.sigmoid(model(x)).cpu().numpy()
        pred_bin = (prob > threshold).astype(np.uint8)
        return x.cpu().numpy(), pred_bin

    # Run inference once per sample; both figure types below reuse the results.
    predictions = []
    for idx in idxs:
        sample = val_dataset[idx]
        x_np, pred_bin = _predict_one(sample)
        predictions.append((idx, sample, x_np, pred_bin))

    # ---- Individual per-sample figures ----
    saved_paths = []
    for i, (idx, sample, x_np, pred_bin) in enumerate(predictions, 1):
        pid = sample.get("patient_id", f"val_{idx}")
        imgs = {
            "mri": x_np[:, 0:1],
            "y_true": np.expand_dims(np.expand_dims(sample["tumor"], 0), 0).astype(np.float32),
            "y_pred": pred_bin.astype(np.float32),
        }
        if use_atlas and ("regions" in sample):
            imgs["regions"] = x_np[:, 1:2]

        out_path = os.path.join(out_dir, f"best_val_sample_{i}_client{client_id}_{pid}.png")
        plot_prediction(imgs, out_path)
        saved_paths.append(out_path)

    print("[Best Model] Saved individual figures:")
    for p in saved_paths:
        print(" -", p)
    print()

    # ---- Grid figure: one row per sample ----
    cols = 4 if use_atlas else 3
    # squeeze=False keeps `axs` 2-D even when k == 1.
    fig, axs = plt.subplots(k, cols, figsize=(5 * cols, 4 * k), squeeze=False)

    for row, (idx, sample, x_np, pred_bin) in enumerate(predictions):
        panels = [(x_np[0, 0], "MRI")]
        if use_atlas and ("regions" in sample):
            panels.append((x_np[0, 1], "Regions"))
        panels.append((sample["tumor"], "Ground Truth"))
        # Title reflects the threshold actually used (was hard-coded to 0.5).
        panels.append((pred_bin[0, 0], f"Predicted (τ={threshold})"))

        for col, (img, title) in enumerate(panels):
            axs[row, col].imshow(img, cmap="gray")
            axs[row, col].set_title(title)
            axs[row, col].axis("off")
        # Blank out a column left unused (atlas enabled but sample lacks regions).
        for col in range(len(panels), cols):
            axs[row, col].axis("off")

    fig.tight_layout()
    grid_path = os.path.join(out_dir, f"best_model_val_grid_client{client_id}.png")
    fig.savefig(grid_path, bbox_inches="tight")
    plt.show()
    plt.close(fig)
    print(f"[Best Model] Saved grid -> {grid_path}\n")

# ---- Per-epoch curves (train_/val_ x loss, dice, iou, acc), used by plot_metrics ----
history = {
    f"{split}_{metric}": []
    for split in ("train", "val")
    for metric in ("loss", "dice", "iou", "acc")
}

# ---- Best-so-far validation trackers for checkpointing ----
best_val_iou = 0.0  # NOTE(review): not updated by the training loop below
best_val_dice, best_val_loss = -float("inf"), float("inf")
best_path = os.path.join(OUT_MODELS_DIR, "best_unet_client" + str(CLIENT_ID) + ".pth")

# One dict per epoch, written to CSV after training.
log_rows = []

# Main training loop: train, validate, decay the LR once per epoch, record
# metrics and checkpoint whenever validation improves.
for epoch in range(1, EPOCHS + 1):
    trL, trD, trI, trA = train_one_epoch(
        model, train_loader, optimizer, criterion, scaler
    )
    # The last-batch visualization returned by eval_one_epoch is discarded
    # here, so don't request it (collecting it costs extra GPU->CPU copies).
    vaL, vaD, vaI, vaA, _ = eval_one_epoch(
        model, val_loader, criterion, keep_last_batch=False
    )

    # LinearLR was configured with total_iters=EPOCHS: one step per epoch.
    scheduler.step()

    epoch_metrics = {
        "train_loss": trL,
        "train_dice": trD,
        "train_iou": trI,
        "train_acc": trA,
        "val_loss": vaL,
        "val_dice": vaD,
        "val_iou": vaI,
        "val_acc": vaA,
    }
    for key, value in epoch_metrics.items():
        history[key].append(value)

    print(
        f"[Epoch {epoch:03d}/{EPOCHS}] "
        f"Train — Loss {trL:.4f} | Dice {trD:.4f} | IoU {trI:.4f} | Accuracy {trA:.4f} || "
        f"Val — Loss {vaL:.4f} | Dice {vaD:.4f} | IoU {vaI:.4f} | Accuracy {vaA:.4f}\n"
    )

    # Checkpoint only when Dice improves AND loss decreases vs. the best so far.
    saved_ckpt = False
    if (vaD > best_val_dice) and (vaL < best_val_loss):
        best_val_dice = vaD
        best_val_loss = vaL
        torch.save(model.state_dict(), best_path)
        saved_ckpt = True
        print(
            f"Saved best model (Val Dice↑ {best_val_dice:.4f} & Val Loss↓ {best_val_loss:.4f}) -> {best_path}\n"
        )

    # CSV row: epoch, the eight metrics (same order as before), checkpoint flag.
    log_rows.append(
        {
            "epoch": epoch,
            **epoch_metrics,
            "saved_ckpt": saved_ckpt,
            "marker": "X" if saved_ckpt else "",
        }
    )

# ---- Persist per-epoch metrics and training curves ----
metrics_csv = os.path.join(OUT_GRAPHS_DIR, f"metrics_client{CLIENT_ID}.csv")
pd.DataFrame(log_rows).to_csv(path_or_buf=metrics_csv, index=False)
print(f"[Log] Wrote per-epoch metrics CSV -> {metrics_csv}\n")

plot_metrics(
    history,
    os.path.join(OUT_GRAPHS_DIR, f"training_curves_client{CLIENT_ID}.png"),
)

# ---- Qualitative check: best checkpoint on a few validation samples ----
infer_and_visualize_best(
    model=model,
    val_dataset=val_dataset,
    use_atlas=USE_ATLAS,
    best_ckpt_path=best_path,
    out_dir=OUT_GRAPHS_DIR,
    client_id=CLIENT_ID,
    threshold=0.5,
    k_samples=3,
)

**<h2>BioClinicalBERT on reports</h2>**

In [None]:
import os
import json
import pickle
import random
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import warnings
from sklearn.metrics import f1_score, roc_auc_score
from sklearn.exceptions import UndefinedMetricWarning
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup

# Silence sklearn's UndefinedMetricWarning noise during metric computation.
warnings.simplefilter("ignore", category=UndefinedMetricWarning)

# ---- Reproducibility ----
SEED = 42
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here only affects Python subprocesses launched later, not this kernel.
# CUBLAS_WORKSPACE_CONFIG presumably has no effect if cuBLAS was already
# initialised by an earlier cell in this session — confirm if relying on it.
os.environ.update(PYTHONHASHSEED=str(SEED), CUBLAS_WORKSPACE_CONFIG=":4096:8")

# ---- Paths ----
METADATA_DF_PATH = "/content/cleaned_df.pkl"
LABELS_PATH = "/content/labels_list.pkl"
OUT_BASE = "/content/AITDM"
CLIENT_BASE_DIR = "/content/client"

# ---- Training hyper-parameters ----
CLIENT_ID = 0
BATCH_SIZE = 8
LR = 2e-5
EPOCHS = 20
NUM_WORKERS = 2
MAX_LEN = 512       # tokenizer truncation length
WARMUP_RATIO = 0.1  # fraction of total scheduler steps used for LR warm-up

# ---- Model ----
BERT_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
DROPOUT = 0.3

# ---- Labels / loss ----
TOPK_PRED = 5                   # each report is assigned its top-k labels
USE_POS_WEIGHT = True           # weight positives in BCE by the neg/pos ratio
REMOVE_BACKGROUND_LABEL = True  # drop the "background" label from the label list

# ---- Output directories ----
OUT_MODELS_DIR = os.path.join(OUT_BASE, "Models", "BioClinicalBert")
OUT_GRAPHS_DIR = os.path.join(OUT_BASE, "Graphs")
for _out_dir in (OUT_MODELS_DIR, OUT_GRAPHS_DIR):
    os.makedirs(_out_dir, exist_ok=True)


def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) and request deterministic kernels."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    # warn_only=True: ops lacking a deterministic implementation warn instead of raising.
    torch.use_deterministic_algorithms(True, warn_only=True)


def worker_init_fn(worker_id: int) -> None:
    """Seed Python's and NumPy's global RNGs inside a DataLoader worker with SEED + worker_id."""
    worker_seed = worker_id + SEED
    random.seed(worker_seed)
    np.random.seed(worker_seed)


# Seed everything once and pick the compute device (GPU if PyTorch sees one).
set_seed(SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class ReportsGliomaDataset(Dataset):
    """Multi-label dataset: report text -> set of region labels.

    Each item is a dict with ``patient_id``, ``report`` (str) and ``target``
    (float32 multi-hot vector aligned with ``self.labels``), built from the
    metadata columns ``Patient_ID``, ``Report`` and ``Top 5 Regions``.
    """

    def __init__(
        self,
        metadata_df_path: str,
        labels_path: str,
        exclude_ids: Optional[List[str]] = None,
        remove_background: bool = False,
    ):
        """
        Args:
            metadata_df_path: pickled DataFrame with the columns named above.
            labels_path: pickled iterable of label names (defines target order).
            exclude_ids: patient IDs to drop; defaults to ``["PatientID_0191"]``.
            remove_background: drop any label equal to "background"
                (case-insensitive, surrounding whitespace ignored).
        """
        # NOTE: pickle.load can execute arbitrary code — only load trusted files.
        with open(metadata_df_path, "rb") as f:
            df = pickle.load(f)

        with open(labels_path, "rb") as f:
            labels = list(pickle.load(f))

        if remove_background:
            labels = [lab for lab in labels if str(lab).strip().lower() != "background"]

        self.labels: List[str] = labels
        self.label_to_idx: Dict[str, int] = {lab: i for i, lab in enumerate(self.labels)}

        if exclude_ids is None:
            exclude_ids = ["PatientID_0191"]

        df = df[~df["Patient_ID"].isin(exclude_ids)].reset_index(drop=True)
        # Keep only rows that have both a report and region annotations.
        df = df[df["Report"].notna() & df["Top 5 Regions"].notna()].reset_index(drop=True)

        self.df = df
        self.patient_ids = df["Patient_ID"].tolist()

    def __len__(self) -> int:
        return len(self.patient_ids)

    def _make_multilabel_target(self, regions_list) -> np.ndarray:
        """Multi-hot vector for ``regions_list``; unknown region names are ignored.

        Accepts lists, tuples, sets, NumPy arrays and pandas Series. Anything
        else (e.g. a bare string or a scalar) yields an all-zero target.
        """
        y = np.zeros(len(self.labels), dtype=np.float32)
        # Pickled frames may store the regions as arrays/Series; previously
        # these silently produced an all-zero target.
        if isinstance(regions_list, (np.ndarray, pd.Series)):
            regions_list = np.asarray(regions_list).ravel().tolist()
        if not isinstance(regions_list, (list, tuple, set, frozenset)):
            return y
        for reg in regions_list:
            idx = self.label_to_idx.get(reg, None)
            if idx is not None:
                y[idx] = 1.0
        return y

    def __getitem__(self, idx: int) -> Dict[str, object]:
        row = self.df.iloc[idx]
        pid = row["Patient_ID"]
        report = str(row["Report"])
        top5 = row["Top 5 Regions"]
        target = self._make_multilabel_target(top5)
        return {"patient_id": pid, "report": report, "target": target}


class SubsetByPIDs(Dataset):
    """Dataset view exposing only the patients listed in ``pid_list``.

    Items come back in ``pid_list`` order; IDs absent from
    ``full_dataset.patient_ids`` are dropped without warning.
    """

    def __init__(self, full_dataset: ReportsGliomaDataset, pid_list: List[str]):
        self.ds = full_dataset
        lookup = {}
        for position, pid in enumerate(self.ds.patient_ids):
            lookup[pid] = position
        self.indices = [lookup[pid] for pid in pid_list if pid in lookup]

    def __len__(self) -> int:
        n = len(self.indices)
        return n

    def __getitem__(self, i: int) -> Dict[str, object]:
        source_idx = self.indices[i]
        return self.ds[source_idx]


def reports_collate_fn(batch: List[Dict[str, object]]) -> Dict[str, object]:
    """Collate report samples: keep texts and IDs as lists, stack targets into a float tensor.

    Tokenization is deferred to the training loop, so ``report`` stays raw text.
    """
    reports, target_rows, pids = [], [], []
    for sample in batch:
        reports.append(sample["report"])
        target_rows.append(sample["target"])
        pids.append(sample["patient_id"])
    targets = torch.from_numpy(np.stack(target_rows)).float()
    return {"report": reports, "target": targets, "pid": pids}


class BioBERTMultiLabelClassifier(nn.Module):
    """Pretrained transformer encoder + dropout + linear head producing one logit per label.

    ``forward`` returns raw logits; apply a sigmoid to obtain per-label probabilities.
    """

    def __init__(self, model_name: str, num_labels: int, dropout: float = 0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the encoder's pooled output when it provides one; otherwise fall
        # back to the hidden state of the first ([CLS]) token.
        features = encoded.pooler_output
        if features is None:
            features = encoded.last_hidden_state[:, 0]
        return self.classifier(self.dropout(features))


# Shared tokenizer matching BERT_MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)


def tokenize_reports(reports: List[str], max_len: int = 512) -> Dict[str, torch.Tensor]:
    """Tokenize a batch of report texts into PyTorch tensors.

    Sequences are padded to the longest in the batch and truncated to ``max_len`` tokens.
    """
    encoding_options = dict(
        return_tensors="pt",
        max_length=max_len,
        truncation=True,
        padding=True,
    )
    return tokenizer(reports, **encoding_options)


def compute_pos_weight(train_dataset: Dataset, num_labels: int) -> torch.Tensor:
    """Per-label positive-class weights (neg / pos) for ``nn.BCEWithLogitsLoss``.

    Args:
        train_dataset: indexable dataset whose items carry a multi-hot ``"target"`` vector.
        num_labels: number of labels; shapes the fallback for an empty dataset
            and is checked against the target width.

    Returns:
        float32 tensor of shape ``(num_labels,)`` with each weight clipped to
        [1, 100]. An empty dataset yields all-ones weights (no re-weighting)
        instead of crashing in ``np.stack``.

    Raises:
        ValueError: if the targets are not 2-D with ``num_labels`` columns.
    """
    n = len(train_dataset)
    if n == 0:
        return torch.ones(num_labels, dtype=torch.float32)

    ys = np.stack([train_dataset[i]["target"] for i in range(n)]).astype(np.float32)
    if ys.ndim != 2 or ys.shape[1] != num_labels:
        raise ValueError(f"Expected targets of width {num_labels}, got array of shape {ys.shape}.")

    pos = ys.sum(axis=0)
    neg = ys.shape[0] - pos
    # Small epsilon avoids 0/0; labels never positive get the maximum weight after clipping.
    pw = (neg + 1e-6) / (pos + 1e-6)
    pw = np.clip(pw, 1.0, 100.0)
    return torch.tensor(pw, dtype=torch.float32)


def preds_topk(probs: np.ndarray, k: int) -> np.ndarray:
    """Binary int32 mask marking the ``k`` highest-probability labels of each row.

    ``k`` is clamped to [1, number of labels].
    """
    n_labels = probs.shape[1]
    k_eff = max(1, min(k, n_labels))
    ranked = np.argsort(-probs, axis=1)
    mask = np.zeros_like(probs, dtype=np.int32)
    np.put_along_axis(mask, ranked[:, :k_eff], 1, axis=1)
    return mask


def precision_recall_at_k(y_true: np.ndarray, probs: np.ndarray, k: int) -> Tuple[float, float]:
    """Sample-averaged precision@k and recall@k for multi-label predictions.

    The top-k labels of each row (by probability) count as predicted positives.
    ``k`` is clamped to [1, num_labels], exactly as ``preds_topk`` does, so
    precision divides by the number of labels actually predicted (previously
    an over-large ``k`` deflated precision). Rows without any true label
    contribute recall 0.

    Returns:
        (precision, recall) as floats; (0.0, 0.0) for an empty batch.
    """
    if probs.shape[0] == 0:
        return 0.0, 0.0
    k_eff = max(1, min(k, probs.shape[1]))
    pred = preds_topk(probs, k_eff)
    tp = (pred * y_true).sum(axis=1)
    prec = (tp / (k_eff + 1e-8)).mean()
    true_pos = y_true.sum(axis=1)
    rec = (tp / (true_pos + 1e-8)).mean()
    return float(prec), float(rec)


def safe_auc_per_class(y_true: np.ndarray, y_score: np.ndarray) -> np.ndarray:
    """Per-label ROC-AUC as float32; NaN where AUC is undefined (single class) or scoring fails."""
    aucs = np.full(y_true.shape[1], np.nan, dtype=np.float32)
    for label_idx in range(y_true.shape[1]):
        column = y_true[:, label_idx]
        # AUC needs both positives and negatives in the column.
        if column.min() != column.max():
            try:
                aucs[label_idx] = float(roc_auc_score(column, y_score[:, label_idx]))
            except Exception:
                aucs[label_idx] = np.nan
    return aucs


def plot_history(history: Dict[str, List[float]], save_dir: str, client_id: int, topk: int) -> None:
    """Save (into ``save_dir``) and show one train-vs-val line plot per tracked curve.

    Curves: loss, macro-F1@k, micro-F1@k and exact match@k.
    """
    epochs = range(1, len(history["train_loss"]) + 1)
    # (history key suffix, title, file name, y-axis label)
    curves = [
        ("loss", "BioClinicalBERT Loss", f"biobert_loss_client{client_id}.png", "Loss"),
        ("f1_macro_topk", f"F1 Macro (Top-{topk})", f"biobert_f1macro_topk_client{client_id}.png", "F1"),
        ("f1_micro_topk", f"F1 Micro (Top-{topk})", f"biobert_f1micro_topk_client{client_id}.png", "F1"),
        ("exact_match_topk", f"Exact Match (Top-{topk})", f"biobert_exactmatch_topk_client{client_id}.png", "Exact Match"),
    ]

    for key, title, fname, ylabel in curves:
        fig, ax = plt.subplots()
        ax.plot(epochs, history[f"train_{key}"], label="Train")
        ax.plot(epochs, history[f"val_{key}"], label="Val")
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.legend()
        fig.savefig(os.path.join(save_dir, fname), bbox_inches="tight")
        plt.show()
        plt.close(fig)


def train_eval_biobert(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    label_names: List[str],
    train_dataset: Dataset,
    epochs: int = 20,
    lr: float = 2e-5,
    warmup_ratio: float = 0.1,
    max_len: int = 512,
    client_id: int = 0,
    topk: int = 5,
    use_pos_weight: bool = True,
) -> Tuple[Dict[str, List[float]], str]:
    """Fine-tune a multi-label report classifier and evaluate it after every epoch.

    Training uses AdamW with a linear warm-up/decay schedule stepped per batch,
    BCE-with-logits loss (optionally with per-label ``pos_weight`` computed
    from ``train_dataset``) and CUDA AMP when running on a GPU. The checkpoint
    with the best validation macro-F1 of the top-``topk`` predictions is saved.

    Args:
        model: classifier taking ``input_ids``/``attention_mask`` and returning logits.
        train_loader / val_loader: yield dicts with ``report`` (texts) and ``target`` (multi-hot).
        label_names: label names, used for sizing and per-class printouts.
        train_dataset: dataset used only to compute ``pos_weight``.
        epochs, lr, warmup_ratio, max_len, topk, use_pos_weight: training settings.
        client_id: used in output file names.

    Returns:
        (history, best_path): per-epoch metric lists keyed ``train_*`` / ``val_*``,
        and the best checkpoint path (written after the first epoch).

    Raises:
        ValueError: if either loader yields no batches.

    Side effects:
        Writes the checkpoint to ``OUT_MODELS_DIR`` and a per-epoch CSV plus
        plots to ``OUT_GRAPHS_DIR``.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_eval_biobert needs non-empty train and validation loaders.")

    num_labels = len(label_names)

    pos_weight = compute_pos_weight(train_dataset, num_labels).to(device) if use_pos_weight else None
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    optimizer = optim.AdamW(model.parameters(), lr=lr)

    # The scheduler is stepped once per optimizer step (per batch).
    total_steps = len(train_loader) * epochs
    warmup_steps = int(warmup_ratio * total_steps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
    )

    use_amp = device.type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    model.to(device)

    best_val_f1 = -1.0
    best_path = os.path.join(OUT_MODELS_DIR, f"best_biobert_client{client_id}.pt")

    metric_names = (
        "f1_macro_topk",
        "f1_micro_topk",
        "f1_samples_topk",
        "exact_match_topk",
        "p_at_k",
        "r_at_k",
        "f1_macro_thr",
    )
    # Key order: train_loss, val_loss, then train_/val_ pairs per metric
    # (this order is also the CSV column order).
    history: Dict[str, List[float]] = {"train_loss": [], "val_loss": []}
    for name in metric_names:
        history[f"train_{name}"] = []
        history[f"val_{name}"] = []

    log_rows: List[Dict[str, object]] = []

    def _metrics(y_true: np.ndarray, probs: np.ndarray) -> Dict[str, float]:
        """Top-k based F1/exact-match/P@k/R@k plus macro-F1 at a 0.5 threshold."""
        pred_topk = preds_topk(probs, topk)
        f1_macro = f1_score(y_true, pred_topk, average="macro", zero_division=0)
        f1_micro = f1_score(y_true, pred_topk, average="micro", zero_division=0)
        f1_samples = f1_score(y_true, pred_topk, average="samples", zero_division=0)
        exact_match = float((y_true == pred_topk).all(axis=1).mean())
        p_at_k, r_at_k = precision_recall_at_k(y_true, probs, topk)
        pred_thr = (probs > 0.5).astype(np.int32)
        f1_macro_thr = f1_score(y_true, pred_thr, average="macro", zero_division=0)
        return {
            "f1_macro_topk": float(f1_macro),
            "f1_micro_topk": float(f1_micro),
            "f1_samples_topk": float(f1_samples),
            "exact_match_topk": float(exact_match),
            "p_at_k": float(p_at_k),
            "r_at_k": float(r_at_k),
            "f1_macro_thr": float(f1_macro_thr),
        }

    def _run_epoch(loader: DataLoader, train: bool) -> Tuple[float, np.ndarray, np.ndarray]:
        """One pass over ``loader`` (weights updated only when ``train``).

        Returns the mean per-batch loss, int32 targets and sigmoid probabilities.
        """
        if train:
            model.train()
        else:
            model.eval()
        total_loss = 0.0
        y_true_parts: List[np.ndarray] = []
        y_prob_parts: List[np.ndarray] = []

        # Gradients are tracked, and AMP used, only while training.
        with torch.set_grad_enabled(train):
            for batch in loader:
                labels = batch["target"].to(device, non_blocking=True)

                enc = tokenize_reports(batch["report"], max_len=max_len)
                input_ids = enc["input_ids"].to(device, non_blocking=True)
                attention_mask = enc["attention_mask"].to(device, non_blocking=True)

                if train:
                    optimizer.zero_grad(set_to_none=True)

                with torch.amp.autocast("cuda", enabled=train and use_amp):
                    logits = model(input_ids=input_ids, attention_mask=attention_mask)
                    loss = criterion(logits, labels)

                if train:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()
                    scheduler.step()

                total_loss += float(loss.item())
                y_prob_parts.append(torch.sigmoid(logits).detach().cpu().numpy())
                y_true_parts.append(labels.detach().cpu().numpy())

        mean_loss = total_loss / max(1, len(loader))
        y_true = np.concatenate(y_true_parts, axis=0).astype(np.int32)
        y_prob = np.concatenate(y_prob_parts, axis=0)
        return mean_loss, y_true, y_prob

    def _record(split: str, mean_loss: float, metrics: Dict[str, float]) -> None:
        """Append one epoch's loss and metrics to ``history`` under ``<split>_*``."""
        history[f"{split}_loss"].append(mean_loss)
        for name in metric_names:
            history[f"{split}_{name}"].append(metrics[name])

    for epoch in range(1, epochs + 1):
        # ---- Train ----
        train_loss, y_true_train, y_prob_train = _run_epoch(train_loader, train=True)
        trm = _metrics(y_true_train, y_prob_train)
        _record("train", train_loss, trm)

        # ---- Validation ----
        val_loss, y_true_val, y_prob_val = _run_epoch(val_loader, train=False)
        vam = _metrics(y_true_val, y_prob_val)
        _record("val", val_loss, vam)

        pred_topk_val = preds_topk(y_prob_val, topk)
        f1_per_class = f1_score(y_true_val, pred_topk_val, average=None, zero_division=0)
        auc_per_class = safe_auc_per_class(y_true_val, y_prob_val)

        print(f"[Epoch {epoch:03d}/{epochs}]")
        print(
            f"Train — Loss {history['train_loss'][-1]:.4f} | "
            f"F1macro@{topk} {trm['f1_macro_topk']:.4f} | F1micro@{topk} {trm['f1_micro_topk']:.4f} | "
            f"P@{topk} {trm['p_at_k']:.4f} | R@{topk} {trm['r_at_k']:.4f} | "
            f"Exact {trm['exact_match_topk']:.4f} | F1macro(thr0.5) {trm['f1_macro_thr']:.4f}"
        )
        print(
            f"Val   — Loss {history['val_loss'][-1]:.4f} | "
            f"F1macro@{topk} {vam['f1_macro_topk']:.4f} | F1micro@{topk} {vam['f1_micro_topk']:.4f} | "
            f"P@{topk} {vam['p_at_k']:.4f} | R@{topk} {vam['r_at_k']:.4f} | "
            f"Exact {vam['exact_match_topk']:.4f} | F1macro(thr0.5) {vam['f1_macro_thr']:.4f}"
        )

        valid_auc_mask = ~np.isnan(auc_per_class)
        valid_idx = np.where(valid_auc_mask)[0]
        print(f"AUC valid for {int(valid_auc_mask.sum())}/{len(auc_per_class)} classes in validation.")
        print("Val per-class F1/AUC (first 5 VALID):")
        for j in valid_idx[:5]:
            print(f"  {label_names[j]}: F1={float(f1_per_class[j]):.4f}, AUC={float(auc_per_class[j]):.4f}")
        if len(valid_idx) == 0:
            print("  (No valid AUC classes in this validation split.)")
        print()

        saved_ckpt = False
        if vam["f1_macro_topk"] > best_val_f1:
            best_val_f1 = float(vam["f1_macro_topk"])
            torch.save(model.state_dict(), best_path)
            saved_ckpt = True
            print(f"[Checkpoint] Saved best BioBERT (Val F1macro@{topk}={best_val_f1:.4f}) -> {best_path}\n")

        log_rows.append(
            {
                "epoch": epoch,
                **{key: values[-1] for key, values in history.items()},
                "saved_ckpt": saved_ckpt,
                "marker": "X" if saved_ckpt else "",
            }
        )

    metrics_csv = os.path.join(OUT_GRAPHS_DIR, f"biobert_metrics_client{client_id}.csv")
    pd.DataFrame(log_rows).to_csv(metrics_csv, index=False)
    print(f"[Log] Wrote BioBERT per-epoch metrics CSV -> {metrics_csv}\n")

    plot_history(history, OUT_GRAPHS_DIR, client_id, topk)
    return history, best_path


if __name__ == "__main__":
    # ---- Client split: patient IDs assigned to this client ----
    cdir = os.path.join(CLIENT_BASE_DIR, f"client_{CLIENT_ID}")
    with open(os.path.join(cdir, "train_pids.json"), "r") as f:
        train_pids = json.load(f)
    with open(os.path.join(cdir, "val_pids.json"), "r") as f:
        val_pids = json.load(f)

    # ---- Datasets ----
    full_ds = ReportsGliomaDataset(
        METADATA_DF_PATH,
        LABELS_PATH,
        exclude_ids=["PatientID_0191"],
        remove_background=REMOVE_BACKGROUND_LABEL,
    )
    label_names = full_ds.labels

    train_dataset = SubsetByPIDs(full_ds, train_pids)
    val_dataset = SubsetByPIDs(full_ds, val_pids)

    # ---- DataLoaders ----
    # One seeded generator per loader (as the U-Net cell does). A DataLoader
    # draws a base seed from its generator at the start of every iteration,
    # so a shared generator would let each validation pass shift the next
    # epoch's training shuffle order.
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        collate_fn=reports_collate_fn,
        worker_init_fn=worker_init_fn,
        generator=torch.Generator().manual_seed(SEED),
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        collate_fn=reports_collate_fn,
        worker_init_fn=worker_init_fn,
        generator=torch.Generator().manual_seed(SEED),
    )

    print(
        f"Loaded client_{CLIENT_ID}: "
        f"train patients={len(train_dataset)}, val patients={len(val_dataset)}, num_labels={len(label_names)}"
    )

    # ---- Model + training ----
    model = BioBERTMultiLabelClassifier(
        BERT_MODEL_NAME,
        num_labels=len(label_names),
        dropout=DROPOUT,
    )

    history, best_model_path = train_eval_biobert(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        label_names=label_names,
        train_dataset=train_dataset,
        epochs=EPOCHS,
        lr=LR,
        warmup_ratio=WARMUP_RATIO,
        max_len=MAX_LEN,
        client_id=CLIENT_ID,
        topk=TOPK_PRED,
        use_pos_weight=USE_POS_WEIGHT,
    )

    print(f"[Done] Best BioBERT model saved at: {best_model_path}")

**<h2>CLIP-like model</h2>**

In [None]:
import os
import pickle
import random
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup

import segmentation_models_pytorch as smp
from sklearn.metrics import jaccard_score


# ------------------------
# Global config
# ------------------------
SEED = 42  # global seed: passed to set_seed() below and offset per worker in worker_init_fn()
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# ------------------------
# Reproducibility helpers
# ------------------------
def set_seed(seed: int = 42) -> None:
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) and request determinism.

    cuDNN autotuning is switched off and deterministic algorithms are requested;
    ``warn_only=True`` makes PyTorch warn (rather than raise) for ops that have no
    deterministic implementation.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False
    torch.use_deterministic_algorithms(True, warn_only=True)


def worker_init_fn(worker_id: int) -> None:
    """DataLoader worker hook: seed ``random`` and NumPy with ``SEED + worker_id``.

    The seed depends only on the global SEED and the worker index, so every
    epoch's workers start from the same RNG state.
    """
    worker_seed = SEED + worker_id
    random.seed(worker_seed)
    np.random.seed(worker_seed)


# Seed everything once when this cell runs.
set_seed(SEED)


# ------------------------
# Small utilities
# ------------------------
def minmax01(x: np.ndarray) -> np.ndarray:
    """Linearly rescale ``x`` to [0, 1] as float32.

    A constant array (max <= min) maps to all zeros; NaNs propagate.
    """
    arr = x.astype(np.float32)
    lo, hi = float(arr.min()), float(arr.max())
    return np.zeros_like(arr, dtype=np.float32) if hi <= lo else (arr - lo) / (hi - lo)


def find_first_existing(paths: List[str]) -> str:
    """Return the first entry of ``paths`` that exists on disk.

    Falsy entries (``None`` / ``""``) are skipped, so optional candidates may be
    left blank.

    Raises:
        FileNotFoundError: if no candidate exists; the message lists every
            candidate that was tried, one per line.
    """
    for candidate in paths:
        if candidate and os.path.exists(candidate):
            return candidate
    # str() each entry: a None placeholder would otherwise make str.join raise
    # TypeError and hide the FileNotFoundError callers expect.
    tried = "\n".join(str(p) for p in paths)
    raise FileNotFoundError("None of these paths exist:\n" + tried)


def ensure_dir(p: str) -> None:
    """Create directory ``p`` (including parents) if it does not exist yet.

    An empty path is a no-op: ``os.path.dirname`` returns ``""`` for a bare
    filename (meaning "the current directory", which always exists), whereas
    ``os.makedirs("")`` would raise FileNotFoundError.
    """
    if p:
        os.makedirs(p, exist_ok=True)


# ------------------------
# Dataset: loads MRI + atlas regions + tumor mask + report text
# Also builds a multi-label target from "Top 5 Regions"
# ------------------------
class GliomaDataset(Dataset):
    """Patient-level dataset joining imaging arrays with the free-text report.

    Each item holds the min-max-normalised MRI and atlas-region arrays, the
    binarised tumor mask, the report text, a multi-hot vector over ``labels``
    built from the "Top 5 Regions" column, and a set of clinical metadata
    fields (categorical columns as integer codes; see ``code_maps``).

    Rows are dropped when the patient is in ``exclude_ids`` (default
    ``["PatientID_0191"]``), when "Report" or "Top 5 Regions" is missing, or when
    any of ``<data_root>/<pid>/<pid>_{mri,regions,tumor}.npy`` is absent.
    """

    def __init__(
        self,
        metadata_df_path: str,
        labels_path: str,
        data_root: str,
        exclude_ids: Optional[List[str]] = None,
        remove_background: bool = False,
    ):
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        with open(metadata_df_path, "rb") as fh:
            meta = pickle.load(fh)

        excluded = ["PatientID_0191"] if exclude_ids is None else exclude_ids
        meta = meta[~meta["Patient_ID"].isin(excluded)].reset_index(drop=True)
        has_text_and_regions = meta["Report"].notna() & meta["Top 5 Regions"].notna()
        meta = meta[has_text_and_regions].reset_index(drop=True)

        with open(labels_path, "rb") as fh:
            label_list = list(pickle.load(fh))
        if remove_background:
            label_list = [lab for lab in label_list if str(lab).strip().lower() != "background"]

        self.labels = label_list
        self.label_to_idx = {name: pos for pos, name in enumerate(self.labels)}

        self.data_root = data_root
        self.df = meta

        # Integer-encode selected categorical columns into "<col>_code" columns.
        # Codes are assigned before the file-existence filter below.
        self.categorical_cols = [
            "Sex at Birth",
            "Race",
            "Primary Diagnosis",
            "Previous Brain Tumor",
            "Type of previous brain tumor",
            "Age Range",
        ]
        self.code_maps: Dict[str, Dict[int, Any]] = {}
        for column in self.categorical_cols:
            categorical = pd.Categorical(self.df[column])
            self.df[column + "_code"] = categorical.codes.astype(np.int64)
            self.code_maps[column] = dict(enumerate(categorical.categories))

        def _has_all_arrays(pid: str) -> bool:
            # True when the MRI, regions and tumor .npy files are all present.
            folder = os.path.join(self.data_root, pid)
            return all(
                os.path.isfile(os.path.join(folder, f"{pid}_{kind}.npy"))
                for kind in ("mri", "regions", "tumor")
            )

        self.patient_ids: List[str] = [pid for pid in self.df["Patient_ID"].tolist() if _has_all_arrays(pid)]
        self.df = self.df[self.df["Patient_ID"].isin(self.patient_ids)].reset_index(drop=True)

    def __len__(self) -> int:
        return self.df.shape[0]

    def _make_target_regions(self, regions_list: Any) -> np.ndarray:
        """Multi-hot float32 vector over ``self.labels``.

        Unknown region names are ignored; anything that is not a list/tuple
        yields an all-zero vector.
        """
        target = np.zeros(len(self.labels), dtype=np.float32)
        if isinstance(regions_list, (list, tuple)):
            hits = [self.label_to_idx[reg] for reg in regions_list if reg in self.label_to_idx]
            target[hits] = 1.0
        return target

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        record = self.df.iloc[idx]
        pid = record["Patient_ID"]
        folder = os.path.join(self.data_root, pid)

        def _load(kind: str) -> np.ndarray:
            return np.load(os.path.join(folder, f"{pid}_{kind}.npy")).astype(np.float32)

        # Inputs are min-max normalised; the mask is binarised at 0.5.
        mri = minmax01(_load("mri"))
        regions = minmax01(_load("regions"))
        tumor = (_load("tumor") > 0.5).astype(np.float32)

        return {
            "patient_id": pid,
            "mri": mri,
            "regions": regions,
            "tumor": tumor,
            "report": str(record["Report"]),
            "target_regions": self._make_target_regions(record["Top 5 Regions"]),
            # Extra metadata (categorical columns via their integer codes)
            "sex": int(record["Sex at Birth_code"]),
            "race": int(record["Race_code"]),
            "age": float(record["Age at diagnosis"]),
            "primary_diagnosis": int(record["Primary Diagnosis_code"]),
            "h3_3a_mutation": float(record["H3-3A mutation"]),
            "pten_mutation": float(record["PTEN mutation"]),
            "CDKN2A_B_deletion": float(record["CDKN2A/B deletion"]),
            "TP53_alteration": float(record["TP53 alteration"]),
            "other_mutations_alterations": record["Other mutations/alterations"],
            "previous_brain_tumor": int(record["Previous Brain Tumor_code"]),
            "type_of_previous_brain_tumor": int(record["Type of previous brain tumor_code"]),
            "age_range": int(record["Age Range_code"]),
        }


def glioma_collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate GliomaDataset items into a batch dict.

    Image arrays become float tensors with a channel axis inserted at dim 1
    (B, 1, ...); ``target_regions`` becomes a (B, num_labels) float tensor;
    patient IDs and report texts stay plain Python lists.
    """
    def _stack(key: str) -> torch.Tensor:
        return torch.from_numpy(np.stack([item[key] for item in batch])).float()

    return {
        "patient_id": [item["patient_id"] for item in batch],
        "report": [item["report"] for item in batch],
        "mri": _stack("mri").unsqueeze(1),
        "regions": _stack("regions").unsqueeze(1),
        "tumor": _stack("tumor").unsqueeze(1),
        "target_regions": _stack("target_regions"),
    }


# ------------------------
# Text encoder: BioClinicalBERT + linear head (head not used later)
# ------------------------
class BioBERTMultiLabelClassifier(nn.Module):
    """Pretrained BERT encoder followed by dropout and a linear multi-label head.

    ``forward`` returns raw logits of shape (batch, num_labels). The sentence
    vector is the model's pooler output, or the first ([CLS]) token of the last
    hidden state when the backbone provides no pooler.
    """

    def __init__(self, model_name: str, num_labels: int, dropout: float = 0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sentence_vec = (
            outputs.last_hidden_state[:, 0] if outputs.pooler_output is None else outputs.pooler_output
        )
        return self.classifier(self.dropout(sentence_vec))


def tokenize_reports(tokenizer: AutoTokenizer, reports: List[str], max_len: int = 512) -> Dict[str, torch.Tensor]:
    """Tokenize a batch of report strings for a BERT model.

    Pads across the batch, truncates to ``max_len`` tokens and returns PyTorch
    tensors (``input_ids``, ``attention_mask``, ...).
    """
    options = {
        "padding": True,
        "truncation": True,
        "max_length": max_len,
        "return_tensors": "pt",
    }
    return tokenizer(reports, **options)


def dice_iou_acc(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[float, float, float]:
    """Dice, IoU and pixel accuracy between two binary segmentation masks.

    Both inputs are cast to uint8 and flattened, so they should already be 0/1
    (callers threshold probabilities at 0.5 first). When both masks are empty,
    Dice and IoU are 0.0.

    Returns:
        ``(dice, iou, acc)`` as Python floats.

    Raises:
        ValueError: if the masks contain a different number of elements.
    """
    t = y_true.astype(np.uint8).reshape(-1) != 0
    p = y_pred.astype(np.uint8).reshape(-1) != 0
    if t.size != p.size:
        # Without this check a size-1 mask would silently broadcast.
        raise ValueError(f"dice_iou_acc: mask size mismatch ({t.size} vs {p.size})")

    inter = int(np.count_nonzero(t & p))
    n_true = int(np.count_nonzero(t))
    n_pred = int(np.count_nonzero(p))

    dice = (2.0 * inter) / (float(n_true + n_pred) + 1e-8)
    # Exact Jaccard index from the same counts (the value sklearn's binary
    # jaccard_score gives) -- no extra pixel pass and no blanket
    # `except Exception` that could hide real errors.
    union = n_true + n_pred - inter
    iou = inter / union if union else 0.0
    acc = float(np.mean(t == p))
    return float(dice), float(iou), acc


# ------------------------
# CLIP-like model: aligns text embedding with image embedding
# and also predicts segmentation mask with UNet
# ------------------------
class ClipModel(nn.Module):
    """CLIP-style multimodal model over radiology reports and MRI + atlas images.

    * Text branch: BERT backbone -> pooled/[CLS] vector -> linear projection ->
      L2-normalised embedding.
    * Image branch: deepest UNet encoder feature map -> mean over the two
      spatial dims -> linear projection -> L2-normalised embedding.
    * The same UNet also produces the tumor segmentation logits.

    ``combined_loss`` balances the symmetric contrastive loss and the
    segmentation loss with learned log-sigma uncertainty weights.
    """

    def __init__(
        self,
        bert_backbone: AutoModel,
        tokenizer: AutoTokenizer,
        unet_model: nn.Module,
        embed_dim: int = 512,
        max_len: int = 512,
    ):
        super().__init__()
        self.bert_backbone = bert_backbone
        self.tokenizer = tokenizer
        self.unet_model = unet_model
        self.max_len = max_len

        # Projection layers to a shared embedding space.
        self.text_projection = nn.Linear(self.bert_backbone.config.hidden_size, embed_dim)
        # Input width = channel count of the deepest encoder feature map. It
        # depends on the encoder architecture, so it is read from the encoder
        # (previously hard-coded to 512, which only fits some encoders).
        self.image_projection = nn.Linear(self._deepest_encoder_channels(unet_model), embed_dim)

        # CLIP temperature (stored as log(1/0.07)) and uncertainty weights for the multitask loss.
        self.logit_scale = nn.Parameter(torch.tensor(np.log(1 / 0.07), dtype=torch.float32))
        self.log_sigma_clip = nn.Parameter(torch.tensor(0.0))
        self.log_sigma_seg = nn.Parameter(torch.tensor(0.0))

    @staticmethod
    def _deepest_encoder_channels(unet_model: nn.Module, default: int = 512) -> int:
        """Channels of the last ``unet_model.encoder`` feature map, or ``default`` if not exposed."""
        encoder = getattr(unet_model, "encoder", None)
        channels = getattr(encoder, "out_channels", None)
        if isinstance(channels, (list, tuple)) and len(channels) > 0:
            return int(channels[-1])
        return default

    def encode_text(self, texts: List[str]) -> torch.Tensor:
        """Encode report strings into L2-normalised text embeddings (batch, embed_dim)."""
        enc = tokenize_reports(self.tokenizer, texts, max_len=self.max_len)
        input_ids = enc["input_ids"].to(DEVICE, non_blocking=True)
        attention_mask = enc["attention_mask"].to(DEVICE, non_blocking=True)

        out = self.bert_backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled = out.pooler_output
        if pooled is None:
            pooled = out.last_hidden_state[:, 0]  # CLS token

        t = self.text_projection(pooled)
        return F.normalize(t, dim=-1)

    def _embed_image_features(self, feats: torch.Tensor) -> torch.Tensor:
        """Average a feature map over dims 2 and 3, project it and L2-normalise.

        Assumes ``feats`` is (batch, channels, H, W).
        """
        v = feats.mean(dim=(2, 3))
        v = self.image_projection(v)
        return F.normalize(v, dim=-1)

    def encode_image(self, x: torch.Tensor) -> torch.Tensor:
        """Encode an image batch into L2-normalised embeddings (runs the UNet encoder)."""
        return self._embed_image_features(self.unet_model.encoder(x)[-1])

    def forward(
        self, texts: List[str], mri: torch.Tensor, regions: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(text_emb, image_emb, seg_logits)`` for a batch.

        MRI and atlas regions are concatenated along dim 1 as the UNet input.
        The UNet encoder runs only once per call: a temporary forward hook
        captures its output during the segmentation pass, and the image
        embedding is computed from that. Calling ``encode_image`` here would
        run the encoder a second time, doubling its cost and, in train mode,
        updating any normalisation running statistics twice per step.
        """
        x = torch.cat([mri, regions], dim=1).to(DEVICE, non_blocking=True)
        t = self.encode_text(texts)

        captured: Dict[str, Any] = {}

        def _capture_encoder_output(_module, _inputs, output):
            captured["features"] = output

        handle = self.unet_model.encoder.register_forward_hook(_capture_encoder_output)
        try:
            seg_logits = self.unet_model(x)
        finally:
            handle.remove()  # never leave the hook attached, even if the forward fails

        if "features" in captured:
            v = self._embed_image_features(captured["features"][-1])
        else:
            # The UNet's forward did not go through .encoder; use a separate pass.
            v = self.encode_image(x)
        return t, v, seg_logits

    def clip_contrastive_loss(self, t: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Symmetric CLIP cross-entropy over the (batch x batch) similarity matrix.

        Matching text/image pairs sit on the diagonal. The temperature
        ``exp(logit_scale)`` is clamped to [1e-3, 100].
        """
        logit_scale = self.logit_scale.exp().clamp(1e-3, 100.0)
        logits = (t @ v.t()) * logit_scale
        labels = torch.arange(t.size(0), device=logits.device)
        return 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels))

    def combined_loss(
        self,
        t: torch.Tensor,
        v: torch.Tensor,
        seg_logits: torch.Tensor,
        seg_target: torch.Tensor,
        seg_criterion: nn.Module,
    ) -> torch.Tensor:
        """Uncertainty-weighted multitask loss.

        Computes ``L_clip / (2 sigma_clip^2) + L_seg / (2 sigma_seg^2) + log sigma_clip + log sigma_seg``,
        where ``sigma = exp(log_sigma_*)`` are learned.
        """
        clip_loss = self.clip_contrastive_loss(t, v)
        seg_loss = seg_criterion(seg_logits, seg_target.to(seg_logits.device, non_blocking=True).float())
        return (
            (1.0 / (2.0 * torch.exp(self.log_sigma_clip) ** 2)) * clip_loss
            + (1.0 / (2.0 * torch.exp(self.log_sigma_seg) ** 2)) * seg_loss
            + self.log_sigma_clip
            + self.log_sigma_seg
        )


def load_biobert_backbone_only(bert_wrapper: BioBERTMultiLabelClassifier, ckpt_path: str) -> None:
    """Load a saved classifier state dict into ``bert_wrapper``, skipping the ``classifier.*`` head.

    Loading is non-strict, so keys missing from the checkpoint keep their
    current values (including the freshly initialised head).
    """
    backbone_state = {}
    for name, tensor in torch.load(ckpt_path, map_location="cpu").items():
        if name.startswith("classifier."):
            continue  # head size depends on num_labels; keep the current head
        backbone_state[name] = tensor
    bert_wrapper.load_state_dict(backbone_state, strict=False)


def build_unet(in_channels: int = 2, encoder_name: str = "resnet34", encoder_weights: Optional[str] = None) -> nn.Module:
    """Create a single-class ``smp.Unet`` that outputs raw logits (no final activation).

    ClipModel uses both its segmentation output and its ``.encoder`` as the
    image feature extractor.
    """
    unet_config = dict(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=1,
        activation=None,
    )
    return smp.Unet(**unet_config)


def train_one_epoch_clip(
    model: ClipModel,
    loader: DataLoader,
    optimizer: optim.Optimizer,
    seg_criterion: nn.Module,
    scaler: torch.amp.GradScaler,
    scheduler: Optional[Any] = None,
) -> Dict[str, float]:
    """Run one training epoch of the CLIP-like model (contrastive + segmentation).

    Args:
        model: ClipModel to train (switched to train mode here).
        loader: yields batches built by ``glioma_collate_fn``.
        optimizer: optimizer over ``model.parameters()``.
        seg_criterion: segmentation loss applied to logits vs. the tumor mask.
        scaler: AMP gradient scaler.
        scheduler: optional per-step LR scheduler. It is stepped once after
            every optimizer step, so warmup/decay progresses within the epoch.

    Returns:
        Batch-averaged ``loss``, ``dice``, ``iou`` and ``acc`` (zeros for an empty loader).
    """
    model.train()
    tot_loss = tot_dice = tot_iou = tot_acc = 0.0
    n = 0

    for batch in loader:
        texts = batch["report"]
        mri = batch["mri"].to(DEVICE, non_blocking=True)
        regions = batch["regions"].to(DEVICE, non_blocking=True)
        tumor = batch["tumor"].to(DEVICE, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda", enabled=(DEVICE.type == "cuda")):
            t, v, seg_logits = model(texts, mri, regions)
            loss = model.combined_loss(t, v, seg_logits, tumor, seg_criterion)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if scheduler is not None:
            scheduler.step()

        # Metrics on hard masks thresholded at 0.5.
        with torch.no_grad():
            preds = (torch.sigmoid(seg_logits).detach().cpu().numpy() > 0.5).astype(np.uint8)
            y_np = (tumor.detach().cpu().numpy() > 0.5).astype(np.uint8)
            d, i, a = dice_iou_acc(y_np, preds)

        tot_loss += float(loss.item())
        tot_dice += d
        tot_iou += i
        tot_acc += a
        n += 1

    return {
        "loss": tot_loss / max(1, n),
        "dice": tot_dice / max(1, n),
        "iou": tot_iou / max(1, n),
        "acc": tot_acc / max(1, n),
    }


@torch.no_grad()
def eval_one_epoch_clip(model: ClipModel, loader: DataLoader, seg_criterion: nn.Module) -> Dict[str, float]:
    """Evaluate the CLIP-like model without gradients (no autocast).

    Returns batch-averaged ``loss`` (the same weighted multitask loss used in
    training), ``dice``, ``iou`` and ``acc``.
    """
    model.eval()
    tot_loss = tot_dice = tot_iou = tot_acc = 0.0
    n = 0

    for batch in loader:
        texts = batch["report"]
        mri = batch["mri"].to(DEVICE, non_blocking=True)
        regions = batch["regions"].to(DEVICE, non_blocking=True)
        tumor = batch["tumor"].to(DEVICE, non_blocking=True)

        t, v, seg_logits = model(texts, mri, regions)
        loss = model.combined_loss(t, v, seg_logits, tumor, seg_criterion)

        preds = (torch.sigmoid(seg_logits).detach().cpu().numpy() > 0.5).astype(np.uint8)
        y_np = (tumor.detach().cpu().numpy() > 0.5).astype(np.uint8)
        d, i, a = dice_iou_acc(y_np, preds)

        tot_loss += float(loss.item())
        tot_dice += d
        tot_iou += i
        tot_acc += a
        n += 1

    return {
        "loss": tot_loss / max(1, n),
        "dice": tot_dice / max(1, n),
        "iou": tot_iou / max(1, n),
        "acc": tot_acc / max(1, n),
    }


def main():
    """Train the CLIP-like report/image model on a seeded 80/20 split.

    Resolves data and checkpoint paths from candidate lists. It optionally
    warm-starts the BioClinicalBERT backbone and the UNet from earlier
    checkpoints, then trains with AdamW and a per-step linear warmup/decay
    schedule. The state dict is saved whenever validation Dice rises AND
    validation loss falls.
    """
    # Candidate locations, tried in order (local Colab copy first, then Drive).
    candidates_metadata = [
        "/content/cleaned_df.pkl",
        "/content/drive/MyDrive/PKG - MU-Glioma-Post/cleaned_df.pkl",
    ]
    candidates_labels = [
        "/content/labels_list.pkl",
        "/content/drive/MyDrive/PKG - MU-Glioma-Post/labels_list.pkl",
    ]
    candidates_data_root = [
        "/content/Preprocessed-Data",
        "/content/drive/MyDrive/PKG - MU-Glioma-Post/Preprocessed-Data",
    ]

    METADATA_DF_PATH = find_first_existing(candidates_metadata)
    LABELS_PATH = find_first_existing(candidates_labels)
    DATA_ROOT = find_first_existing(candidates_data_root)

    BERT_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"

    # Optional pretrained checkpoints ("" when none of the candidates exists).
    BIOBERT_BEST_CAND = [
        "/content/AITDM/Models/BioClinicalBert/best_biobert_client0.pt",
        "/content/drive/MyDrive/PKG - MU-Glioma-Post/Models/BioClinicalBert/best.pt",
    ]
    UNET_BEST_CAND = [
        "/content/AITDM/Models/UNet_ImageOnly/best_unet_client0.pth",
        "/content/drive/MyDrive/PKG - MU-Glioma-Post/Models/UNet/best_unet_model.pth",
    ]

    BIOBERT_BEST = next((p for p in BIOBERT_BEST_CAND if os.path.isfile(p)), "")
    UNET_BEST = next((p for p in UNET_BEST_CAND if os.path.isfile(p)), "")

    # Output checkpoint for the CLIP-like model
    CLIP_SAVE = "/content/drive/MyDrive/PKG - MU-Glioma-Post/Models/CLIP_Based/best_clip_model.pth"
    ensure_dir(os.path.dirname(CLIP_SAVE))

    # Training hyperparams
    BATCH_SIZE = 4
    EPOCHS = 50
    LR = 1e-4
    MAX_LEN = 512

    print("Resolved paths:")
    print("  metadata:", METADATA_DF_PATH)
    print("  labels  :", LABELS_PATH)
    print("  data_root:", DATA_ROOT)
    print("  biobert ckpt:", BIOBERT_BEST if BIOBERT_BEST else "(not found, will use base)")
    print("  unet ckpt  :", UNET_BEST if UNET_BEST else "(not found, will use random init)")

    # Build dataset and a seeded 80/20 train/test split
    dataset = GliomaDataset(
        metadata_df_path=METADATA_DF_PATH,
        labels_path=LABELS_PATH,
        data_root=DATA_ROOT,
        exclude_ids=["PatientID_0191"],
        remove_background=False,
    )

    test_size = int(0.2 * len(dataset))
    train_size = len(dataset) - test_size
    train_dataset, test_dataset = random_split(
        dataset, [train_size, test_size], generator=torch.Generator().manual_seed(SEED)
    )

    # Dataloaders (drop_last on train keeps every contrastive batch full-size)
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        collate_fn=glioma_collate_fn,
        worker_init_fn=worker_init_fn,
        drop_last=True,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        collate_fn=glioma_collate_fn,
        worker_init_fn=worker_init_fn,
        drop_last=False,
    )

    # Tokenizer + BERT backbone. Try a full load first; on a shape/key mismatch
    # (RuntimeError) fall back to loading the backbone without the classifier head.
    tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
    bert_wrapper = BioBERTMultiLabelClassifier(BERT_MODEL_NAME, num_labels=len(dataset.labels), dropout=0.3)
    if BIOBERT_BEST:
        try:
            bert_wrapper.load_state_dict(torch.load(BIOBERT_BEST, map_location="cpu"))
        except RuntimeError:
            load_biobert_backbone_only(bert_wrapper, BIOBERT_BEST)
    bert_backbone = bert_wrapper.bert.to(DEVICE)

    # UNet. NOTE(review): the checkpoint is loaded strictly, so it must come from a
    # resnet34 UNet with 2 input channels -- confirm this for the candidates above.
    unet_model = build_unet(in_channels=2, encoder_name="resnet34", encoder_weights=None)
    if UNET_BEST:
        unet_model.load_state_dict(torch.load(UNET_BEST, map_location="cpu"))
    unet_model = unet_model.to(DEVICE)

    # Build CLIP-like multimodal model
    clip_model = ClipModel(
        bert_backbone=bert_backbone,
        tokenizer=tokenizer,
        unet_model=unet_model,
        embed_dim=512,
        max_len=MAX_LEN,
    ).to(DEVICE)

    # Loss/optimizer/scheduler (10% linear warmup, then linear decay, per step)
    seg_criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
    optimizer = optim.AdamW(clip_model.parameters(), lr=LR)

    total_steps = len(train_loader) * EPOCHS
    warmup_steps = int(0.1 * total_steps)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )

    scaler = torch.amp.GradScaler("cuda", enabled=(DEVICE.type == "cuda"))

    # Track best checkpoint by (Dice up) and (loss down)
    best_val_dice = -1.0
    best_val_loss = float("inf")

    for epoch in range(1, EPOCHS + 1):
        # The scheduler is stepped once per batch inside the training loop.
        # Stepping it len(train_loader) times *after* each epoch left the LR frozen
        # within every epoch. Because linear warmup starts at 0, the whole first
        # epoch then trained with LR = 0.
        tr = train_one_epoch_clip(clip_model, train_loader, optimizer, seg_criterion, scaler, scheduler=scheduler)
        va = eval_one_epoch_clip(clip_model, test_loader, seg_criterion)

        print(
            f"[Epoch {epoch:03d}/{EPOCHS}] "
            f"Train loss {tr['loss']:.4f} dice {tr['dice']:.4f} iou {tr['iou']:.4f} acc {tr['acc']:.4f} || "
            f"Val loss {va['loss']:.4f} dice {va['dice']:.4f} iou {va['iou']:.4f} acc {va['acc']:.4f}"
        )

        # Save best model (both criteria must improve at once)
        if (va["dice"] > best_val_dice) and (va["loss"] < best_val_loss):
            best_val_dice = va["dice"]
            best_val_loss = va["loss"]
            torch.save(clip_model.state_dict(), CLIP_SAVE)
            print(f"Saved best model -> {CLIP_SAVE} (val dice {best_val_dice:.4f}, val loss {best_val_loss:.4f})")


# Entry point. In an interactive Jupyter/Colab kernel __name__ is normally
# "__main__", so executing this cell starts training immediately.
if __name__ == "__main__":
    main()

# **M2 - Ensemble on Images only**

In [None]:
import segmentation_models_pytorch as smp
# Show every encoder name smp can build (handy when choosing ENCODER_NAME below).
print([*smp.encoders.get_encoder_names()])

**<h2>UNet - "resnet50"</h2>**

In [None]:
import os, json, pickle, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from typing import List
from sklearn.metrics import jaccard_score
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp

# Paths and I/O config
# Paths and I/O config
DATA_ROOT = "/content/Preprocessed-Data"
METADATA_DF_PATH = "cleaned_df.pkl"  # relative path: resolved against the current working directory
OUT_BASE = "AITDM"
# Absolute on purpose. The former os.path.join(OUT_BASE, "/content/client")
# already evaluated to exactly this, because os.path.join drops every component
# before an absolute one. OUT_BASE was therefore silently ignored.
CLIENT_DIR = "/content/client"

# Experiment config
USE_ATLAS = True  # feed the atlas-regions map as a second input channel
CLIENT_ID = 0
BATCH_SIZE = 8
LR = 1e-3
EPOCHS = 50
NUM_WORKERS = 2
SEED = 42

# Model / encoder config
ENCODER_NAME = "resnet50"
ENCODER_WEIGHTS = "imagenet"

# Output directories (created up front)
OUT_MODELS_DIR = os.path.join(OUT_BASE, "Models", "UNet_ImageOnly")
OUT_GRAPHS_DIR = os.path.join(OUT_BASE, "Graphs")
os.makedirs(OUT_MODELS_DIR, exist_ok=True)
os.makedirs(OUT_GRAPHS_DIR, exist_ok=True)

def sanitize(s: str) -> str:
    """Make ``s`` filename-friendly: "/" becomes "-" and spaces become "_"."""
    return str(s).translate(str.maketrans({"/": "-", " ": "_"}))

# Run identifier built from the encoder, its pretrained weights and the input mode.
encoder_tag = sanitize(ENCODER_NAME)
weights_tag = "none" if ENCODER_WEIGHTS is None else sanitize(ENCODER_WEIGHTS)
atlas_tag = "img" if not USE_ATLAS else "atlas"

run_tag = "_".join(("unet", encoder_tag, weights_tag, atlas_tag))

def set_seed(seed: int = 42):
    """Seed ``random``, NumPy and PyTorch (CPU + all GPUs) and turn on deterministic kernels.

    cuDNN benchmarking is disabled; ``warn_only=True`` means ops without a
    deterministic implementation only emit a warning.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all]
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True, warn_only=True)

def worker_init_fn(worker_id):
    """DataLoader worker hook: seed NumPy and ``random`` with ``SEED + worker_id``."""
    seed_for_worker = SEED + worker_id
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(seed_for_worker)

set_seed(SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class ImageOnlyGliomaDataset(Dataset):
    """Patient-level dataset of MRI + tumor mask, plus atlas regions when ``use_atlas``.

    Arrays are read from ``<data_root>/<pid>/<pid>_{mri,tumor,regions}.npy``.
    Patients missing any required file are skipped, and ``patient_ids`` is sorted.
    """

    def __init__(self, metadata_df_path, data_root, use_atlas=True, exclude_ids=None, transform=None):
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        with open(metadata_df_path, "rb") as fh:
            meta = pickle.load(fh)

        skip_ids = ["PatientID_0191"] if exclude_ids is None else exclude_ids
        self.df = meta[~meta["Patient_ID"].isin(skip_ids)].reset_index(drop=True)

        self.data_root = data_root
        self.use_atlas = use_atlas
        self.transform = transform  # stored, but not applied in __getitem__

        required = ("mri", "tumor", "regions") if self.use_atlas else ("mri", "tumor")
        self.patient_ids = [
            pid
            for pid in sorted(self.df["Patient_ID"].tolist())
            if all(os.path.isfile(os.path.join(self.data_root, pid, f"{pid}_{kind}.npy")) for kind in required)
        ]

    def __len__(self):
        return len(self.patient_ids)

    @staticmethod
    def _minmax(x):
        """Rescale to [0, 1] as float32; arrays with max <= min (or NaN extrema) become zeros."""
        arr = x.astype(np.float32)
        lo, hi = np.min(arr), np.max(arr)
        if not hi > lo:
            return np.zeros_like(arr, dtype=np.float32)
        return (arr - lo) / (hi - lo)

    def __getitem__(self, idx):
        """Return ``{"patient_id", "mri", "tumor"[, "regions"]}`` for one patient."""
        pid = self.patient_ids[idx]
        folder = os.path.join(self.data_root, pid)

        def _read(kind):
            return np.load(os.path.join(folder, f"{pid}_{kind}.npy")).astype(np.float32)

        # MRI is min-max normalised; the tumor mask is binarised at 0.5.
        sample = {
            "patient_id": pid,
            "mri": self._minmax(_read("mri")),
            "tumor": (_read("tumor") > 0.5).astype(np.float32),
        }
        if self.use_atlas:
            sample["regions"] = self._minmax(_read("regions"))
        return sample

def image_only_collate_fn(batch, use_atlas=True):
    """Stack samples into ``{"x", "y", "pid"}``.

    ``x`` is MRI (and, when ``use_atlas``, regions as a second channel) with shape
    (B, C, ...); ``y`` is the float tumor mask (B, 1, ...); ``pid`` lists patient IDs.
    """
    def _stack(key):
        return torch.stack([torch.tensor(item[key]) for item in batch]).unsqueeze(1)

    mri = _stack("mri").float()
    x = torch.cat([mri, _stack("regions").float()], dim=1) if use_atlas else mri
    return {"x": x, "y": _stack("tumor").float(), "pid": [item["patient_id"] for item in batch]}

class SubsetByPIDs(Dataset):
    """View of a patient-level dataset restricted to ``pid_list``.

    Items are served in ``pid_list`` order; IDs not present in
    ``full_dataset.patient_ids`` are silently skipped.
    """

    def __init__(self, full_dataset: ImageOnlyGliomaDataset, pid_list: List[str]):
        self.ds = full_dataset
        known_ids = full_dataset.patient_ids
        lookup = dict(zip(known_ids, range(len(known_ids))))
        self.indices = []
        for pid in pid_list:
            if pid in lookup:
                self.indices.append(lookup[pid])

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        position = self.indices[i]
        return self.ds[position]

# Load client-specific train/val patient IDs
# Read this client's train/val patient-ID lists from <CLIENT_DIR>/client_<CLIENT_ID>/.
# Presumably JSON lists of Patient_ID strings -- TODO confirm against the split writer.
cdir = os.path.join(CLIENT_DIR, f"client_{CLIENT_ID}")
with open(os.path.join(cdir, "train_pids.json"), "r") as f:
    train_pids = json.load(f)
with open(os.path.join(cdir, "val_pids.json"), "r") as f:
    val_pids = json.load(f)

# Build full dataset and then client subsets. IDs listed in the JSON files but
# absent from full_ds (excluded, or missing .npy files) are dropped by SubsetByPIDs.
full_ds = ImageOnlyGliomaDataset(
    METADATA_DF_PATH,
    DATA_ROOT,
    use_atlas=USE_ATLAS,
    exclude_ids=["PatientID_0191"],
)
train_dataset = SubsetByPIDs(full_ds, train_pids)
val_dataset = SubsetByPIDs(full_ds, val_pids)

# Data loaders
# Options shared by both loaders. Each loader still gets its own freshly seeded
# generator, and only the training loader shuffles.
_loader_common = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=lambda b: image_only_collate_fn(b, use_atlas=USE_ATLAS),
    worker_init_fn=worker_init_fn,
)
train_loader = DataLoader(
    train_dataset, shuffle=True, generator=torch.Generator().manual_seed(SEED), **_loader_common
)
val_loader = DataLoader(
    val_dataset, shuffle=False, generator=torch.Generator().manual_seed(SEED), **_loader_common
)

print(f"Loaded client_{CLIENT_ID}: train patients={len(train_dataset)}, val patients={len(val_dataset)}")

# UNet model (from segmentation_models_pytorch)
# Input channels: the MRI, plus the atlas regions when USE_ATLAS.
in_channels = 1 + (1 if USE_ATLAS else 0)
model = smp.Unet(
    in_channels=in_channels,
    classes=1,
    encoder_name=ENCODER_NAME,
    encoder_weights=ENCODER_WEIGHTS,
).to(device)

# Training objects:
# - Dice loss on raw logits
# - AdamW
# - linear LR decay from LR to 0 over EPOCHS calls to scheduler.step()
# - an AMP grad scaler that is only enabled on CUDA
criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
optimizer = optim.AdamW(model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.LinearLR(optimizer, total_iters=EPOCHS, start_factor=1.0, end_factor=0.0)
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))

def calc_metrics(y_true, y_pred):
    """Compute Dice, IoU and pixel accuracy between two binary masks.

    Inputs are cast to uint8 and flattened, so they should already be 0/1
    (the training loop thresholds probabilities at 0.5 first). When both masks
    are empty, Dice and IoU are 0.0.

    Returns:
        ``(dice, iou, acc)`` as Python floats.

    Raises:
        ValueError: if the two masks contain a different number of elements.
    """
    t = y_true.astype(np.uint8).reshape(-1) != 0
    p = y_pred.astype(np.uint8).reshape(-1) != 0
    if t.size != p.size:
        # Without this check a size-1 mask would silently broadcast.
        raise ValueError(f"calc_metrics: mask size mismatch ({t.size} vs {p.size})")

    inter = int(np.count_nonzero(t & p))
    n_true = int(np.count_nonzero(t))
    n_pred = int(np.count_nonzero(p))

    dice = (2.0 * inter) / (n_true + n_pred + 1e-8)

    # Exact IoU from the same counts (the value sklearn's binary jaccard_score
    # gives), without a second pixel pass or a blanket `except Exception`
    # that would hide real errors.
    union = n_true + n_pred - inter
    iou = inter / union if union else 0.0

    acc = np.mean(t == p)
    return float(dice), float(iou), float(acc)

def plot_prediction(sample_imgs, save_path):
    """Show and save MRI, optional atlas regions, ground truth and prediction side by side.

    Only the first item / first channel (``[0, 0]``) of each array is drawn;
    the figure is 5 inches wide per panel.
    """
    panels = [(sample_imgs["mri"], "MRI")]
    if "regions" in sample_imgs:
        panels.append((sample_imgs["regions"], "Regions"))
    panels += [(sample_imgs["y_true"], "Ground Truth"), (sample_imgs["y_pred"], "Predicted")]

    fig, axs = plt.subplots(1, len(panels), figsize=(5 * len(panels), 5))
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image[0, 0], cmap="gray")
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close()

def plot_metrics(history, save_path):
    """Plot train vs. validation curves (loss, Dice, IoU, accuracy) in a 2x2 grid, then show and save it.

    ``history`` must hold ``train_<m>`` / ``val_<m>`` lists for m in loss, dice, iou, acc.
    """
    epochs = range(1, len(history["train_loss"]) + 1)
    # (history key suffix, legend label suffix, subplot title)
    panels = [
        ("loss", "Loss", "Loss"),
        ("dice", "Dice", "Dice"),
        ("iou", "IoU", "IoU"),
        ("acc", "Acc", "Accuracy"),
    ]

    plt.figure(figsize=(16, 12))
    for position, (key, label, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(epochs, history[f"train_{key}"], label=f"Train {label}")
        plt.plot(epochs, history[f"val_{key}"], label=f"Val {label}")
        plt.legend()
        plt.title(title)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close()

def train_one_epoch(model, dataloader, optimizer, criterion, scaler):
    """Run one optimisation pass over ``dataloader`` (autocast on CUDA).

    Each batch dict must provide ``"x"`` (inputs) and ``"y"`` (targets); both
    are moved to the notebook-level ``device``. Metrics come from
    ``calc_metrics`` on sigmoid outputs binarised at 0.5 and are averaged over
    batches (each batch weighs the same regardless of its size).

    Returns:
        ``(loss, dice, iou, acc)`` per-batch means.
    """
    model.train()
    sums = [0.0, 0.0, 0.0, 0.0]  # loss, dice, iou, acc
    amp_enabled = device.type == "cuda"

    for batch in dataloader:
        inputs = batch["x"].to(device)
        targets = batch["y"].to(device)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=amp_enabled):
            logits = model(inputs)
            loss = criterion(logits, targets)

        # Standard torch.amp GradScaler sequence: scaled backward, step, update.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        pred_mask = (torch.sigmoid(logits).detach().cpu().numpy() > 0.5).astype(np.uint8)
        true_mask = (targets.detach().cpu().numpy() > 0.5).astype(np.uint8)
        batch_stats = (float(loss.item()), *calc_metrics(true_mask, pred_mask))
        sums = [acc + val for acc, val in zip(sums, batch_stats)]

    n_batches = len(dataloader)
    return tuple(acc / n_batches for acc in sums)

@torch.no_grad()
def eval_one_epoch(model, dataloader, criterion, keep_last_batch=True):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Each batch dict must provide ``"x"`` and ``"y"``; they are moved to the
    notebook-level ``device``. Metrics come from ``calc_metrics`` on sigmoid
    outputs binarised at 0.5 and are averaged over batches.

    Args:
        model: Segmentation network returning logits.
        dataloader: Iterable of batch dicts.
        criterion: Loss called as ``criterion(logits, y)``.
        keep_last_batch: If true, also return NumPy copies of the final batch
            for visualisation.

    Returns:
        ``(loss, dice, iou, acc, last_batch_imgs)``. ``last_batch_imgs`` is
        ``None`` unless ``keep_last_batch`` is set; then it is a dict with
        ``"mri"`` (input channel 0), ``"y_true"``, ``"y_pred"`` (uint8 mask)
        and, for 2-channel inputs, ``"regions"`` (input channel 1).
    """
    model.eval()
    total_loss = total_dice = total_iou = total_acc = 0.0
    last = None

    for batch in dataloader:
        x = batch["x"].to(device)
        y = batch["y"].to(device)

        preds = model(x)
        loss = criterion(preds, y)
        total_loss += float(loss.item())

        preds_prob = (torch.sigmoid(preds).cpu().numpy() > 0.5).astype(np.uint8)
        y_host = y.cpu().numpy()
        y_np = (y_host > 0.5).astype(np.uint8)
        d, i, a = calc_metrics(y_np, preds_prob)

        total_dice += d
        total_iou += i
        total_acc += a

        if keep_last_batch:
            # Only keep references; the input is converted once after the loop
            # instead of being copied to host memory for every batch.
            last = (batch["x"], y_host, preds_prob)

    last_batch_imgs = None
    if last is not None:
        x_last, y_last, pred_last = last
        last_batch_imgs = {
            "mri": x_last[:, 0:1].cpu().numpy(),
            "y_true": y_last,
            "y_pred": pred_last,
        }
        if x_last.shape[1] == 2:
            last_batch_imgs["regions"] = x_last[:, 1:2].cpu().numpy()

    n = len(dataloader)
    return total_loss / n, total_dice / n, total_iou / n, total_acc / n, last_batch_imgs

def infer_and_visualize_best(
    model,
    val_dataset,
    use_atlas: bool,
    out_dir: str,
    client_id: int,
    best_ckpt_path: str,
    k_samples: int = 3,
    threshold: float = 0.5,
):
    """Load the best checkpoint, predict on a few validation samples, and save figures.

    Reseeds Python/NumPy/torch RNGs with the notebook-level ``SEED`` so the
    same validation indices are drawn every run, loads ``best_ckpt_path`` into
    ``model`` and writes:

    * one figure per sampled patient via ``plot_prediction``;
    * one grid figure with one row per sampled patient.

    Every file name includes the notebook-level ``run_tag`` so runs of
    different models/encoders sharing ``out_dir`` do not overwrite each other.

    Args:
        model: Segmentation network returning logits.
        val_dataset: Indexable dataset of dicts with ``"mri"``, ``"tumor"``,
            optionally ``"regions"`` and ``"patient_id"``.
        use_atlas: Feed ``"regions"`` (when present) as a second input channel.
        out_dir: Output directory for figures (created if missing).
        client_id: Federated client id, used in file names.
        best_ckpt_path: ``state_dict`` file saved with ``torch.save``.
        k_samples: Number of validation samples to visualise.
        threshold: Probability cut-off for binarising the sigmoid output.

    Returns:
        None. Prints a message and returns early if the checkpoint is missing
        or the validation set is empty.
    """
    import random

    # Reseed every RNG so the sampled indices are reproducible.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    os.makedirs(out_dir, exist_ok=True)

    if not os.path.isfile(best_ckpt_path):
        print(f"[Best Model] Missing checkpoint: {best_ckpt_path}\n")
        return
    model.load_state_dict(torch.load(best_ckpt_path, map_location=device))
    model.eval()
    print(f"[Best Model] Loaded: {best_ckpt_path}\n")

    k = min(k_samples, len(val_dataset))
    if k == 0:
        print("[Best Model] Empty val dataset.\n")
        return

    # Deterministic random subset of validation indices
    idxs = random.sample(range(len(val_dataset)), k)

    def _predict_one(sample):
        """Run the model on one sample; return (input batch, binarised prediction) as NumPy."""
        mri = torch.tensor(sample["mri"]).unsqueeze(0).unsqueeze(0).float().to(device)

        if use_atlas and ("regions" in sample):
            regs = torch.tensor(sample["regions"]).unsqueeze(0).unsqueeze(0).float().to(device)
            x = torch.cat([mri, regs], dim=1)
        else:
            x = mri

        with torch.no_grad():
            prob = torch.sigmoid(model(x)).cpu().numpy()
        return x.cpu().numpy(), (prob > threshold).astype(np.uint8)

    # Load and predict each sampled patient once; both figure types reuse this.
    results = []
    for idx in idxs:
        sample = val_dataset[idx]
        x_np, pred_bin = _predict_one(sample)
        results.append((idx, sample, x_np, pred_bin))

    # Individual per-sample figures
    saved_paths = []
    for i, (idx, sample, x_np, pred_bin) in enumerate(results, 1):
        pid = sample.get("patient_id", f"val_{idx}")
        imgs = {
            "mri": x_np[:, 0:1],
            "y_true": np.expand_dims(
                np.expand_dims(sample["tumor"], 0), 0
            ).astype(np.float32),
            "y_pred": pred_bin.astype(np.float32),
        }
        if x_np.shape[1] > 1:  # the regions channel was fed to the model
            imgs["regions"] = x_np[:, 1:2]

        out_path = os.path.join(
            out_dir, f"best_val_sample_{i}_{run_tag}_client{client_id}_{pid}.png"
        )
        plot_prediction(imgs, out_path)
        saved_paths.append(out_path)

    print("[Best Model] Saved individual figures:")
    for p in saved_paths:
        print(" -", p)
    print()

    # Grid figure: one row per sample. Columns follow what was actually fed to
    # the model, so there is no empty "Regions" slot when no sample had one.
    has_regions = any(x_np.shape[1] > 1 for _, _, x_np, _ in results)
    cols = 4 if has_regions else 3
    fig, axs = plt.subplots(k, cols, figsize=(5 * cols, 4 * k), squeeze=False)

    for row, (_, sample, x_np, pred_bin) in enumerate(results):
        panels = [(x_np[0, 0], "MRI")]
        if x_np.shape[1] > 1:
            panels.append((x_np[0, 1], "Regions"))
        panels.append((sample["tumor"], "Ground Truth"))
        panels.append((pred_bin[0, 0], f"Predicted (τ={threshold})"))

        for col, ax in enumerate(axs[row]):
            if col < len(panels):
                img, title = panels[col]
                ax.imshow(img, cmap="gray")
                ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    grid_path = os.path.join(out_dir, f"best_model_val_grid_{run_tag}_client{client_id}.png")
    plt.savefig(grid_path, bbox_inches="tight")
    plt.show()
    plt.close(fig)
    print(f"[Best Model] Saved grid -> {grid_path}\n")

# Per-epoch metric histories used for the training-curve figure
history = {
    f"{split}_{metric}": []
    for split in ("train", "val")
    for metric in ("loss", "dice", "iou", "acc")
}

# Best-so-far validation scores used for checkpoint selection
best_val_iou = 0.0  # NOTE(review): not used below; kept in case later cells read it
best_val_dice = -float("inf")
best_val_loss = float("inf")
best_path = os.path.join(OUT_MODELS_DIR, f"best_{run_tag}_client{CLIENT_ID}.pth")

log_rows = []  # one dict per epoch, written to CSV after training

# Main training loop
for epoch in range(1, EPOCHS + 1):
    trL, trD, trI, trA = train_one_epoch(
        model, train_loader, optimizer, criterion, scaler
    )
    # The last-batch images are discarded here, so don't make eval collect them.
    vaL, vaD, vaI, vaA, _ = eval_one_epoch(
        model, val_loader, criterion, keep_last_batch=False
    )

    # LR schedule advances once per epoch
    scheduler.step()

    epoch_metrics = {
        "train_loss": trL,
        "train_dice": trD,
        "train_iou": trI,
        "train_acc": trA,
        "val_loss": vaL,
        "val_dice": vaD,
        "val_iou": vaI,
        "val_acc": vaA,
    }
    for key, value in epoch_metrics.items():
        history[key].append(value)

    print(
        f"[Epoch {epoch:03d}/{EPOCHS}] "
        f"Train — Loss {trL:.4f} | Dice {trD:.4f} | IoU {trI:.4f} | Accuracy {trA:.4f} || "
        f"Val — Loss {vaL:.4f} | Dice {vaD:.4f} | IoU {vaI:.4f} | Accuracy {vaA:.4f}\n"
    )

    # Checkpoint only when val Dice improves AND val loss decreases relative to
    # the last saved checkpoint.
    saved_ckpt = (vaD > best_val_dice) and (vaL < best_val_loss)
    if saved_ckpt:
        best_val_dice = vaD
        best_val_loss = vaL
        torch.save(model.state_dict(), best_path)
        print(
            f"Saved best model (Val Dice↑ {best_val_dice:.4f} & Val Loss↓ {best_val_loss:.4f}) -> {best_path}\n"
        )

    # Row for the per-epoch CSV log
    log_rows.append(
        {
            "epoch": epoch,
            **epoch_metrics,
            "saved_ckpt": saved_ckpt,
            "marker": "X" if saved_ckpt else "",
        }
    )

# Save metrics as CSV
metrics_csv = os.path.join(OUT_GRAPHS_DIR, f"metrics_{run_tag}_client{CLIENT_ID}.csv")
pd.DataFrame(log_rows).to_csv(metrics_csv, index=False)
print(f"[Log] Wrote per-epoch metrics CSV -> {metrics_csv}\n")

# Plot training curves (tagged)
curves_path = os.path.join(OUT_GRAPHS_DIR, f"training_curves_{run_tag}_client{CLIENT_ID}.png")
plot_metrics(history, curves_path)

# Run inference with best model and visualize a few validation samples
infer_and_visualize_best(
    model=model,
    val_dataset=val_dataset,
    use_atlas=USE_ATLAS,
    out_dir=OUT_GRAPHS_DIR,
    client_id=CLIENT_ID,
    best_ckpt_path=best_path,
    k_samples=3,
    threshold=0.5,
)

<h2>UNet - "mit_b3"</h2>

In [None]:
import os, json, pickle, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from typing import List
from sklearn.metrics import jaccard_score
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp

# Paths and I/O config
DATA_ROOT = "/content/Preprocessed-Data"
METADATA_DF_PATH = "cleaned_df.pkl"
OUT_BASE = "AITDM"
# Absolute path on purpose: the setup cell unzips the client splits to
# /content/client, which is NOT under OUT_BASE. (Joining OUT_BASE with an
# absolute path would discard OUT_BASE anyway, which only obscures this.)
CLIENT_DIR = "/content/client"

# Experiment config
USE_ATLAS = True  # feed the atlas regions map as a second input channel
CLIENT_ID = 0
BATCH_SIZE = 8
LR = 1e-3
EPOCHS = 50
NUM_WORKERS = 2
SEED = 42

# Model / encoder config
ENCODER_NAME = "mit_b3"
ENCODER_WEIGHTS = "imagenet"

# Output directories
OUT_MODELS_DIR = os.path.join(OUT_BASE, "Models", "UNet_ImageOnly")
OUT_GRAPHS_DIR = os.path.join(OUT_BASE, "Graphs")
os.makedirs(OUT_MODELS_DIR, exist_ok=True)
os.makedirs(OUT_GRAPHS_DIR, exist_ok=True)

def sanitize(s: str) -> str:
    """Make ``s`` safe for use in a file name ("/" -> "-", " " -> "_")."""
    return str(s).replace("/", "-").replace(" ", "_")

# Tag embedded in every output file name so different runs don't collide
encoder_tag = sanitize(ENCODER_NAME)
weights_tag = sanitize(ENCODER_WEIGHTS) if ENCODER_WEIGHTS is not None else "none"
atlas_tag = "atlas" if USE_ATLAS else "img"

run_tag = f"unet_{encoder_tag}_{weights_tag}_{atlas_tag}"

def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also forces deterministic cuDNN kernels and disables the cuDNN autotuner.
    ``torch.use_deterministic_algorithms`` is explicitly left off, so ops that
    have no deterministic implementation still run instead of raising.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    torch.use_deterministic_algorithms(False)

def worker_init_fn(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    PyTorch seeds each worker's torch RNG with ``base_seed + worker_id``,
    where ``base_seed`` is drawn from the DataLoader's ``generator`` every time
    an iterator is created. Deriving the NumPy/Python seeds from
    ``torch.initial_seed()`` has two effects. Runs stay reproducible, because
    the loaders are built with a seeded generator. Each worker also gets a
    fresh stream every epoch, so random augmentations are not replayed
    identically each epoch (a fixed ``SEED + worker_id`` would replay them).

    Args:
        worker_id: Worker index. It is unused because PyTorch already folds
            it into ``torch.initial_seed()``.
    """
    worker_seed = torch.initial_seed() % 2**32  # NumPy seeds must fit in 32 bits
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Seed every RNG up front, then pick the GPU when one is available.
set_seed(seed=SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class ImageOnlyGliomaDataset(Dataset):
    """Per-patient dataset of a min-max-normalised MRI, a binary tumor mask and, optionally, atlas regions.

    Files are read from ``<data_root>/<pid>/<pid>_{mri,tumor[,regions]}.npy``.
    Patients missing any required file are skipped.

    Args:
        metadata_df_path: Path to a pickled ``pandas.DataFrame`` with a
            ``"Patient_ID"`` column.
        data_root: Root directory of the per-patient folders.
        use_atlas: Also require and load ``<pid>_regions.npy``.
        exclude_ids: Patient IDs to drop. ``None`` means the default
            ``["PatientID_0191"]``; pass ``[]`` to keep everyone.
        transform: Optional callable applied to every sample dict in
            ``__getitem__``; it receives the dict and must return a dict.
    """

    def __init__(self, metadata_df_path, data_root, use_atlas=True, exclude_ids=None, transform=None):
        # NOTE: pickle.load can execute arbitrary code; only load trusted metadata files.
        with open(metadata_df_path, "rb") as f:
            df = pickle.load(f)

        if exclude_ids is None:
            exclude_ids = ["PatientID_0191"]
        self.df = df[~df["Patient_ID"].isin(exclude_ids)].reset_index(drop=True)

        self.data_root = data_root
        self.use_atlas = use_atlas
        self.transform = transform

        # Keep only patients for which every required .npy file exists
        required = ("mri", "tumor", "regions") if use_atlas else ("mri", "tumor")
        self.patient_ids = [
            pid
            for pid in sorted(self.df["Patient_ID"].tolist())
            if all(os.path.isfile(self._path(pid, kind)) for kind in required)
        ]

    def __len__(self):
        return len(self.patient_ids)

    def _path(self, pid, kind):
        """Return ``<data_root>/<pid>/<pid>_<kind>.npy``."""
        return os.path.join(self.data_root, pid, f"{pid}_{kind}.npy")

    @staticmethod
    def _minmax(x):
        """Simple min-max normalization to [0, 1] (all zeros for a constant array)."""
        x = x.astype(np.float32)
        mn, mx = np.min(x), np.max(x)
        return (x - mn) / (mx - mn) if mx > mn else np.zeros_like(x, dtype=np.float32)

    def __getitem__(self, idx):
        """Load one patient as a dict: ``patient_id``, ``mri``, ``tumor`` (+ ``regions``)."""
        pid = self.patient_ids[idx]

        mri = np.load(self._path(pid, "mri")).astype(np.float32)
        tumor = np.load(self._path(pid, "tumor")).astype(np.float32)

        # Normalize MRI to [0, 1] and binarize the tumor mask at 0.5
        sample = {
            "patient_id": pid,
            "mri": self._minmax(mri),
            "tumor": (tumor > 0.5).astype(np.float32),
        }

        if self.use_atlas:
            regions = np.load(self._path(pid, "regions")).astype(np.float32)
            sample["regions"] = self._minmax(regions)

        # Optional user-supplied transform (e.g. augmentation) on the whole sample
        if self.transform is not None:
            sample = self.transform(sample)

        return sample

def image_only_collate_fn(batch, use_atlas=True):
    """Collate a list of per-patient sample dicts into one model-ready batch.

    Returns a dict with:
      * ``"x"``: float32 tensor of stacked inputs with a channel axis at dim 1
        (MRI, plus regions as a second channel when ``use_atlas``);
      * ``"y"``: float32 tumor masks with a channel axis at dim 1;
      * ``"pid"``: patient IDs in batch order.
    For 2-D per-sample arrays, ``x`` is presumably ``(B, C, H, W)``.
    """
    def _stack_with_channel(key):
        # Stack one array per sample and insert a singleton channel axis.
        return torch.stack([torch.tensor(item[key]) for item in batch]).unsqueeze(1)

    input_channels = [_stack_with_channel("mri").float()]
    if use_atlas:
        input_channels.append(_stack_with_channel("regions").float())

    return {
        "x": torch.cat(input_channels, dim=1),
        "y": _stack_with_channel("tumor").float(),
        "pid": [item["patient_id"] for item in batch],
    }

class SubsetByPIDs(Dataset):
    """View of a patient-level dataset restricted to, and ordered by, a list of patient IDs.

    ``full_dataset`` must expose ``patient_ids``, a list index-aligned with
    ``full_dataset[i]``. IDs in ``pid_list`` that the dataset does not contain
    are skipped. This covers patients dropped for missing files or explicitly
    excluded. Skipped IDs are kept in ``self.missing_pids`` and reported with
    a warning instead of vanishing silently.
    """

    def __init__(self, full_dataset: "ImageOnlyGliomaDataset", pid_list: List[str]):
        import warnings

        self.ds = full_dataset
        # Map patient IDs to indices in the wrapped dataset
        pid_to_idx = {pid: i for i, pid in enumerate(self.ds.patient_ids)}
        self.indices = [pid_to_idx[p] for p in pid_list if p in pid_to_idx]
        self.missing_pids = [p for p in pid_list if p not in pid_to_idx]

        if self.missing_pids:
            warnings.warn(
                f"SubsetByPIDs: skipped {len(self.missing_pids)} of {len(pid_list)} "
                f"requested patient IDs not present in the dataset "
                f"(first few: {self.missing_pids[:5]})",
                stacklevel=2,
            )

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        return self.ds[self.indices[i]]

import functools

# Load client-specific train/val patient IDs
cdir = os.path.join(CLIENT_DIR, f"client_{CLIENT_ID}")
with open(os.path.join(cdir, "train_pids.json"), "r") as f:
    train_pids = json.load(f)
with open(os.path.join(cdir, "val_pids.json"), "r") as f:
    val_pids = json.load(f)

# Build full dataset and then client subsets
full_ds = ImageOnlyGliomaDataset(
    METADATA_DF_PATH,
    DATA_ROOT,
    use_atlas=USE_ATLAS,
    exclude_ids=["PatientID_0191"],
)
train_dataset = SubsetByPIDs(full_ds, train_pids)
val_dataset = SubsetByPIDs(full_ds, val_pids)

# Data loaders.
# functools.partial freezes the current USE_ATLAS, so the collate function always
# matches the dataset built above even if a later cell reassigns USE_ATLAS (a
# lambda would re-read the global on every batch). Unlike a lambda, a partial is
# also picklable.
collate_fn_bound = functools.partial(image_only_collate_fn, use_atlas=USE_ATLAS)
common_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=collate_fn_bound,
    worker_init_fn=worker_init_fn,
)
train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
    **common_loader_kwargs,
)
val_loader = DataLoader(
    val_dataset,
    shuffle=False,
    generator=torch.Generator().manual_seed(SEED),
    **common_loader_kwargs,
)

print(f"Loaded client_{CLIENT_ID}: train patients={len(train_dataset)}, val patients={len(val_dataset)}")

# UNet from segmentation_models_pytorch: one input channel for the MRI, plus a
# second one for the atlas regions when USE_ATLAS is set; a single output logit map.
in_channels = 2 if USE_ATLAS else 1
unet_kwargs = dict(
    encoder_name=ENCODER_NAME,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=in_channels,
    classes=1,
)
model = smp.Unet(**unet_kwargs).to(device)

# Soft Dice loss on raw logits, AdamW, a per-epoch linear LR decay from LR
# to 0 over EPOCHS scheduler steps, and an AMP GradScaler enabled only on CUDA.
criterion = smp.losses.DiceLoss(mode=smp.losses.BINARY_MODE, from_logits=True)
optimizer = optim.AdamW(params=model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1.0,
    end_factor=0.0,
    total_iters=EPOCHS,
)
scaler = torch.amp.GradScaler("cuda", enabled=device.type == "cuda")

def calc_metrics(y_true, y_pred):
    """Compute Dice, IoU (Jaccard) and pixel accuracy for binary masks.

    Both inputs are cast to ``uint8`` and flattened, so they should be
    already-binarised {0, 1} arrays of equal size (any shape).

    IoU is computed directly from the intersection and union. This gives the
    same value as ``sklearn.metrics.jaccard_score(average="binary")`` on
    binary masks. It avoids sklearn's per-call label validation/sorting over
    every pixel and the UndefinedMetricWarning it raises on empty masks.

    Returns:
        ``(dice, iou, acc)`` as Python floats. When both masks are empty, Dice
        and IoU are 0.0 and accuracy is 1.0.
    """
    y_true = np.asarray(y_true).astype(np.uint8).reshape(-1)
    y_pred = np.asarray(y_pred).astype(np.uint8).reshape(-1)

    # Python ints: uint8 sums are widened by NumPy, converting avoids any
    # unsigned-arithmetic surprises in the subtraction below.
    inter = int((y_true * y_pred).sum())
    true_sum = int(y_true.sum())
    pred_sum = int(y_pred.sum())

    dice = (2.0 * inter) / (true_sum + pred_sum + 1e-8)

    union = true_sum + pred_sum - inter
    iou = inter / union if union > 0 else 0.0

    acc = (y_true == y_pred).mean()
    return float(dice), float(iou), float(acc)

def plot_prediction(sample_imgs, save_path):
    """Plot MRI (+ regions when present), ground truth and prediction side by side.

    Each entry of ``sample_imgs`` is indexed as ``[0, 0]``, i.e. the first
    item and first channel of a batch-shaped array. The figure is saved to
    ``save_path``, shown, and then closed.
    """
    panels = [("mri", "MRI")]
    if "regions" in sample_imgs:
        panels.append(("regions", "Regions"))
    panels.extend([("y_true", "Ground Truth"), ("y_pred", "Predicted")])

    n_panels = len(panels)
    fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))
    for ax, (key, title) in zip(axs, panels):
        ax.imshow(sample_imgs[key][0, 0], cmap="gray")
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close()

def plot_metrics(history, save_path):
    """Draw a 2x2 grid of train-vs-val curves (loss, Dice, IoU, accuracy) and save it.

    ``history`` maps keys ``"{train,val}_{loss,dice,iou,acc}"`` to per-epoch
    lists; the x axis is the 1-based epoch index taken from ``train_loss``.
    The figure is written to ``save_path``, shown, and then closed.
    """
    n_epochs = len(history["train_loss"])
    xs = range(1, n_epochs + 1)

    # (history key suffix, legend label suffix, subplot title), in subplot order
    panels = (
        ("loss", "Loss", "Loss"),
        ("dice", "Dice", "Dice"),
        ("iou", "IoU", "IoU"),
        ("acc", "Acc", "Accuracy"),
    )

    plt.figure(figsize=(16, 12))
    for position, (key, label, title) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        for split in ("Train", "Val"):
            plt.plot(xs, history[f"{split.lower()}_{key}"], label=f"{split} {label}")
        plt.legend()
        plt.title(title)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close()

def train_one_epoch(model, dataloader, optimizer, criterion, scaler):
    """Run one optimisation pass over ``dataloader`` (autocast on CUDA).

    Each batch dict must provide ``"x"`` (inputs) and ``"y"`` (targets); both
    are moved to the notebook-level ``device``. Metrics come from
    ``calc_metrics`` on sigmoid outputs binarised at 0.5 and are averaged over
    batches (each batch weighs the same regardless of its size).

    Returns:
        ``(loss, dice, iou, acc)`` per-batch means.
    """
    model.train()
    sums = [0.0, 0.0, 0.0, 0.0]  # loss, dice, iou, acc
    amp_enabled = device.type == "cuda"

    for batch in dataloader:
        inputs = batch["x"].to(device)
        targets = batch["y"].to(device)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=amp_enabled):
            logits = model(inputs)
            loss = criterion(logits, targets)

        # Standard torch.amp GradScaler sequence: scaled backward, step, update.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        pred_mask = (torch.sigmoid(logits).detach().cpu().numpy() > 0.5).astype(np.uint8)
        true_mask = (targets.detach().cpu().numpy() > 0.5).astype(np.uint8)
        batch_stats = (float(loss.item()), *calc_metrics(true_mask, pred_mask))
        sums = [acc + val for acc, val in zip(sums, batch_stats)]

    n_batches = len(dataloader)
    return tuple(acc / n_batches for acc in sums)

@torch.no_grad()
def eval_one_epoch(model, dataloader, criterion, keep_last_batch=True):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Each batch dict must provide ``"x"`` and ``"y"``; they are moved to the
    notebook-level ``device``. Metrics come from ``calc_metrics`` on sigmoid
    outputs binarised at 0.5 and are averaged over batches.

    Args:
        model: Segmentation network returning logits.
        dataloader: Iterable of batch dicts.
        criterion: Loss called as ``criterion(logits, y)``.
        keep_last_batch: If true, also return NumPy copies of the final batch
            for visualisation.

    Returns:
        ``(loss, dice, iou, acc, last_batch_imgs)``. ``last_batch_imgs`` is
        ``None`` unless ``keep_last_batch`` is set; then it is a dict with
        ``"mri"`` (input channel 0), ``"y_true"``, ``"y_pred"`` (uint8 mask)
        and, for 2-channel inputs, ``"regions"`` (input channel 1).
    """
    model.eval()
    total_loss = total_dice = total_iou = total_acc = 0.0
    last = None

    for batch in dataloader:
        x = batch["x"].to(device)
        y = batch["y"].to(device)

        preds = model(x)
        loss = criterion(preds, y)
        total_loss += float(loss.item())

        preds_prob = (torch.sigmoid(preds).cpu().numpy() > 0.5).astype(np.uint8)
        y_host = y.cpu().numpy()
        y_np = (y_host > 0.5).astype(np.uint8)
        d, i, a = calc_metrics(y_np, preds_prob)

        total_dice += d
        total_iou += i
        total_acc += a

        if keep_last_batch:
            # Only keep references; the input is converted once after the loop
            # instead of being copied to host memory for every batch.
            last = (batch["x"], y_host, preds_prob)

    last_batch_imgs = None
    if last is not None:
        x_last, y_last, pred_last = last
        last_batch_imgs = {
            "mri": x_last[:, 0:1].cpu().numpy(),
            "y_true": y_last,
            "y_pred": pred_last,
        }
        if x_last.shape[1] == 2:
            last_batch_imgs["regions"] = x_last[:, 1:2].cpu().numpy()

    n = len(dataloader)
    return total_loss / n, total_dice / n, total_iou / n, total_acc / n, last_batch_imgs

def infer_and_visualize_best(
    model,
    val_dataset,
    use_atlas: bool,
    out_dir: str,
    client_id: int,
    best_ckpt_path: str,
    k_samples: int = 3,
    threshold: float = 0.5,
):
    """Load the best checkpoint, predict on a few validation samples, and save figures.

    Reseeds Python/NumPy/torch RNGs with the notebook-level ``SEED`` so the
    same validation indices are drawn every run, loads ``best_ckpt_path`` into
    ``model`` and writes:

    * one figure per sampled patient via ``plot_prediction``;
    * one grid figure with one row per sampled patient.

    Every file name includes the notebook-level ``run_tag`` so runs of
    different models/encoders sharing ``out_dir`` do not overwrite each other.

    Args:
        model: Segmentation network returning logits.
        val_dataset: Indexable dataset of dicts with ``"mri"``, ``"tumor"``,
            optionally ``"regions"`` and ``"patient_id"``.
        use_atlas: Feed ``"regions"`` (when present) as a second input channel.
        out_dir: Output directory for figures (created if missing).
        client_id: Federated client id, used in file names.
        best_ckpt_path: ``state_dict`` file saved with ``torch.save``.
        k_samples: Number of validation samples to visualise.
        threshold: Probability cut-off for binarising the sigmoid output.

    Returns:
        None. Prints a message and returns early if the checkpoint is missing
        or the validation set is empty.
    """
    import random

    # Reseed every RNG so the sampled indices are reproducible.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    os.makedirs(out_dir, exist_ok=True)

    if not os.path.isfile(best_ckpt_path):
        print(f"[Best Model] Missing checkpoint: {best_ckpt_path}\n")
        return
    model.load_state_dict(torch.load(best_ckpt_path, map_location=device))
    model.eval()
    print(f"[Best Model] Loaded: {best_ckpt_path}\n")

    k = min(k_samples, len(val_dataset))
    if k == 0:
        print("[Best Model] Empty val dataset.\n")
        return

    # Deterministic random subset of validation indices
    idxs = random.sample(range(len(val_dataset)), k)

    def _predict_one(sample):
        """Run the model on one sample; return (input batch, binarised prediction) as NumPy."""
        mri = torch.tensor(sample["mri"]).unsqueeze(0).unsqueeze(0).float().to(device)

        if use_atlas and ("regions" in sample):
            regs = torch.tensor(sample["regions"]).unsqueeze(0).unsqueeze(0).float().to(device)
            x = torch.cat([mri, regs], dim=1)
        else:
            x = mri

        with torch.no_grad():
            prob = torch.sigmoid(model(x)).cpu().numpy()
        return x.cpu().numpy(), (prob > threshold).astype(np.uint8)

    # Load and predict each sampled patient once; both figure types reuse this.
    results = []
    for idx in idxs:
        sample = val_dataset[idx]
        x_np, pred_bin = _predict_one(sample)
        results.append((idx, sample, x_np, pred_bin))

    # Individual per-sample figures
    saved_paths = []
    for i, (idx, sample, x_np, pred_bin) in enumerate(results, 1):
        pid = sample.get("patient_id", f"val_{idx}")
        imgs = {
            "mri": x_np[:, 0:1],
            "y_true": np.expand_dims(
                np.expand_dims(sample["tumor"], 0), 0
            ).astype(np.float32),
            "y_pred": pred_bin.astype(np.float32),
        }
        if x_np.shape[1] > 1:  # the regions channel was fed to the model
            imgs["regions"] = x_np[:, 1:2]

        out_path = os.path.join(
            out_dir, f"best_val_sample_{i}_{run_tag}_client{client_id}_{pid}.png"
        )
        plot_prediction(imgs, out_path)
        saved_paths.append(out_path)

    print("[Best Model] Saved individual figures:")
    for p in saved_paths:
        print(" -", p)
    print()

    # Grid figure: one row per sample. Columns follow what was actually fed to
    # the model, so there is no empty "Regions" slot when no sample had one.
    has_regions = any(x_np.shape[1] > 1 for _, _, x_np, _ in results)
    cols = 4 if has_regions else 3
    fig, axs = plt.subplots(k, cols, figsize=(5 * cols, 4 * k), squeeze=False)

    for row, (_, sample, x_np, pred_bin) in enumerate(results):
        panels = [(x_np[0, 0], "MRI")]
        if x_np.shape[1] > 1:
            panels.append((x_np[0, 1], "Regions"))
        panels.append((sample["tumor"], "Ground Truth"))
        panels.append((pred_bin[0, 0], f"Predicted (τ={threshold})"))

        for col, ax in enumerate(axs[row]):
            if col < len(panels):
                img, title = panels[col]
                ax.imshow(img, cmap="gray")
                ax.set_title(title)
            ax.axis("off")

    plt.tight_layout()
    grid_path = os.path.join(out_dir, f"best_model_val_grid_{run_tag}_client{client_id}.png")
    plt.savefig(grid_path, bbox_inches="tight")
    plt.show()
    plt.close(fig)
    print(f"[Best Model] Saved grid -> {grid_path}\n")

# Per-epoch metric histories used for the training-curve figure
history = {
    f"{split}_{metric}": []
    for split in ("train", "val")
    for metric in ("loss", "dice", "iou", "acc")
}

# Best-so-far validation scores used for checkpoint selection
best_val_iou = 0.0  # NOTE(review): not used below; kept in case later cells read it
best_val_dice = -float("inf")
best_val_loss = float("inf")
best_path = os.path.join(OUT_MODELS_DIR, f"best_{run_tag}_client{CLIENT_ID}.pth")

log_rows = []  # one dict per epoch, written to CSV after training

# Main training loop
for epoch in range(1, EPOCHS + 1):
    trL, trD, trI, trA = train_one_epoch(
        model, train_loader, optimizer, criterion, scaler
    )
    # The last-batch images are discarded here, so don't make eval collect them.
    vaL, vaD, vaI, vaA, _ = eval_one_epoch(
        model, val_loader, criterion, keep_last_batch=False
    )

    # LR schedule advances once per epoch
    scheduler.step()

    epoch_metrics = {
        "train_loss": trL,
        "train_dice": trD,
        "train_iou": trI,
        "train_acc": trA,
        "val_loss": vaL,
        "val_dice": vaD,
        "val_iou": vaI,
        "val_acc": vaA,
    }
    for key, value in epoch_metrics.items():
        history[key].append(value)

    print(
        f"[Epoch {epoch:03d}/{EPOCHS}] "
        f"Train — Loss {trL:.4f} | Dice {trD:.4f} | IoU {trI:.4f} | Accuracy {trA:.4f} || "
        f"Val — Loss {vaL:.4f} | Dice {vaD:.4f} | IoU {vaI:.4f} | Accuracy {vaA:.4f}\n"
    )

    # Checkpoint only when val Dice improves AND val loss decreases relative to
    # the last saved checkpoint.
    saved_ckpt = (vaD > best_val_dice) and (vaL < best_val_loss)
    if saved_ckpt:
        best_val_dice = vaD
        best_val_loss = vaL
        torch.save(model.state_dict(), best_path)
        print(
            f"Saved best model (Val Dice↑ {best_val_dice:.4f} & Val Loss↓ {best_val_loss:.4f}) -> {best_path}\n"
        )

    # Row for the per-epoch CSV log
    log_rows.append(
        {
            "epoch": epoch,
            **epoch_metrics,
            "saved_ckpt": saved_ckpt,
            "marker": "X" if saved_ckpt else "",
        }
    )

# Save metrics as CSV
metrics_csv = os.path.join(OUT_GRAPHS_DIR, f"metrics_{run_tag}_client{CLIENT_ID}.csv")
pd.DataFrame(log_rows).to_csv(metrics_csv, index=False)
print(f"[Log] Wrote per-epoch metrics CSV -> {metrics_csv}\n")

# Plot training curves (tagged)
curves_path = os.path.join(OUT_GRAPHS_DIR, f"training_curves_{run_tag}_client{CLIENT_ID}.png")
plot_metrics(history, curves_path)

# Run inference with best model and visualize a few validation samples
infer_and_visualize_best(
    model=model,
    val_dataset=val_dataset,
    use_atlas=USE_ATLAS,
    out_dir=OUT_GRAPHS_DIR,
    client_id=CLIENT_ID,
    best_ckpt_path=best_path,
    k_samples=3,
    threshold=0.5,
)

<h2>DeepLabV3Plus - "timm-mobilenetv3_small_100"</h2>

In [None]:
import os, json, pickle, random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from typing import List
from sklearn.metrics import jaccard_score
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp

# Paths and I/O config
DATA_ROOT = "/content/Preprocessed-Data"
METADATA_DF_PATH = "cleaned_df.pkl"
OUT_BASE = "AITDM"
# Absolute path on purpose: the setup cell unzips the client splits to
# /content/client, which is NOT under OUT_BASE. (Joining OUT_BASE with an
# absolute path would discard OUT_BASE anyway, which only obscures this.)
CLIENT_DIR = "/content/client"

# Experiment config
USE_ATLAS = True  # feed the atlas regions map as a second input channel
CLIENT_ID = 0
BATCH_SIZE = 8
LR = 1e-3
EPOCHS = 50
NUM_WORKERS = 2
SEED = 42

# Model / encoder config
ENCODER_NAME = "timm-mobilenetv3_small_100"
ENCODER_WEIGHTS = "imagenet"

# Output directories
OUT_MODELS_DIR = os.path.join(OUT_BASE, "Models", "DeepLabV3Plus_ImageOnly")
OUT_GRAPHS_DIR = os.path.join(OUT_BASE, "Graphs")
os.makedirs(OUT_MODELS_DIR, exist_ok=True)
os.makedirs(OUT_GRAPHS_DIR, exist_ok=True)

def sanitize(s: str) -> str:
    """Make ``s`` safe for use in a file name ("/" -> "-", " " -> "_")."""
    return str(s).replace("/", "-").replace(" ", "_")

# Tag embedded in every output file name so different runs don't collide
encoder_tag = sanitize(ENCODER_NAME)
weights_tag = sanitize(ENCODER_WEIGHTS) if ENCODER_WEIGHTS is not None else "none"
atlas_tag = "atlas" if USE_ATLAS else "img"

run_tag = f"deeplabv3plus_{encoder_tag}_{weights_tag}_{atlas_tag}"

def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also forces deterministic cuDNN kernels and disables the cuDNN autotuner.
    ``torch.use_deterministic_algorithms`` is explicitly left off, so ops that
    have no deterministic implementation still run instead of raising.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    torch.use_deterministic_algorithms(False)

def worker_init_fn(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    PyTorch seeds each worker's torch RNG with ``base_seed + worker_id``,
    where ``base_seed`` is drawn from the DataLoader's ``generator`` every time
    an iterator is created. Deriving the NumPy/Python seeds from
    ``torch.initial_seed()`` has two effects. Runs stay reproducible when the
    loader is built with a seeded generator. Each worker also gets a fresh
    stream every epoch, so random augmentations are not replayed identically
    each epoch (a fixed ``SEED + worker_id`` would replay them).

    Args:
        worker_id: Worker index. It is unused because PyTorch already folds
            it into ``torch.initial_seed()``.
    """
    worker_seed = torch.initial_seed() % 2**32  # NumPy seeds must fit in 32 bits
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Seed every RNG up front, then pick the GPU when one is available.
set_seed(seed=SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class ImageOnlyGliomaDataset(Dataset):
    """Per-patient dataset of a min-max-normalised MRI, a binary tumor mask and, optionally, atlas regions.

    Files are read from ``<data_root>/<pid>/<pid>_{mri,tumor[,regions]}.npy``.
    Patients missing any required file are skipped.

    Args:
        metadata_df_path: Path to a pickled ``pandas.DataFrame`` with a
            ``"Patient_ID"`` column.
        data_root: Root directory of the per-patient folders.
        use_atlas: Also require and load ``<pid>_regions.npy``.
        exclude_ids: Patient IDs to drop. ``None`` means the default
            ``["PatientID_0191"]``; pass ``[]`` to keep everyone.
        transform: Optional callable applied to every sample dict in
            ``__getitem__``; it receives the dict and must return a dict.
    """

    def __init__(self, metadata_df_path, data_root, use_atlas=True, exclude_ids=None, transform=None):
        # NOTE: pickle.load can execute arbitrary code; only load trusted metadata files.
        with open(metadata_df_path, "rb") as f:
            df = pickle.load(f)

        if exclude_ids is None:
            exclude_ids = ["PatientID_0191"]
        self.df = df[~df["Patient_ID"].isin(exclude_ids)].reset_index(drop=True)

        self.data_root = data_root
        self.use_atlas = use_atlas
        self.transform = transform

        # Keep only patients for which every required .npy file exists
        required = ("mri", "tumor", "regions") if use_atlas else ("mri", "tumor")
        self.patient_ids = [
            pid
            for pid in sorted(self.df["Patient_ID"].tolist())
            if all(os.path.isfile(self._path(pid, kind)) for kind in required)
        ]

    def __len__(self):
        return len(self.patient_ids)

    def _path(self, pid, kind):
        """Return ``<data_root>/<pid>/<pid>_<kind>.npy``."""
        return os.path.join(self.data_root, pid, f"{pid}_{kind}.npy")

    @staticmethod
    def _minmax(x):
        """Simple min-max normalization to [0, 1] (all zeros for a constant array)."""
        x = x.astype(np.float32)
        mn, mx = np.min(x), np.max(x)
        return (x - mn) / (mx - mn) if mx > mn else np.zeros_like(x, dtype=np.float32)

    def __getitem__(self, idx):
        """Load one patient as a dict: ``patient_id``, ``mri``, ``tumor`` (+ ``regions``)."""
        pid = self.patient_ids[idx]

        mri = np.load(self._path(pid, "mri")).astype(np.float32)
        tumor = np.load(self._path(pid, "tumor")).astype(np.float32)

        # Normalize MRI to [0, 1] and binarize the tumor mask at 0.5
        sample = {
            "patient_id": pid,
            "mri": self._minmax(mri),
            "tumor": (tumor > 0.5).astype(np.float32),
        }

        if self.use_atlas:
            regions = np.load(self._path(pid, "regions")).astype(np.float32)
            sample["regions"] = self._minmax(regions)

        # Optional user-supplied transform (e.g. augmentation) on the whole sample
        if self.transform is not None:
            sample = self.transform(sample)

        return sample

def image_only_collate_fn(batch, use_atlas=True):
    """Collate sample dicts into a model-ready batch.

    Returns a dict with:
      - ``x``: float tensor of shape (B, C, H, W); C is 1 (MRI) or 2 (MRI + regions when ``use_atlas``).
      - ``y``: float tensor of shape (B, 1, H, W) holding the tumor masks.
      - ``pid``: list of patient IDs in batch order.
    """
    def _stack(key):
        # Stack the per-sample (H, W) arrays and add a channel axis -> (B, 1, H, W).
        return torch.stack([torch.tensor(item[key]) for item in batch]).unsqueeze(1)

    mri_t = _stack("mri")
    tumor_t = _stack("tumor")

    if not use_atlas:
        inputs = mri_t.float()
    else:
        regions_t = _stack("regions")
        inputs = torch.cat([mri_t.float(), regions_t.float()], dim=1)

    patient_ids = [item["patient_id"] for item in batch]
    return {"x": inputs, "y": tumor_t.float(), "pid": patient_ids}

class SubsetByPIDs(Dataset):
    """Restrict a dataset to the patients named in ``pid_list``.

    Items are served in ``pid_list`` order. IDs that are not present in
    ``full_dataset.patient_ids`` are silently skipped.
    """

    def __init__(self, full_dataset: ImageOnlyGliomaDataset, pid_list: List[str]):
        self.ds = full_dataset
        position_of = {}
        for position, pid in enumerate(self.ds.patient_ids):
            position_of[pid] = position
        self.indices = []
        for pid in pid_list:
            if pid in position_of:
                self.indices.append(position_of[pid])

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # Translate the subset position into the wrapped dataset's index.
        base_index = self.indices[i]
        return self.ds[base_index]

# Load client-specific train/val patient IDs
# (JSON files in <CLIENT_DIR>/client_<CLIENT_ID>/; presumably lists of patient-ID strings — confirm)
cdir = os.path.join(CLIENT_DIR, f"client_{CLIENT_ID}")
with open(os.path.join(cdir, "train_pids.json"), "r") as f:
    train_pids = json.load(f)
with open(os.path.join(cdir, "val_pids.json"), "r") as f:
    val_pids = json.load(f)

# Build full dataset and then client subsets.
# SubsetByPIDs silently drops any ID the full dataset did not keep (missing files or excluded).
full_ds = ImageOnlyGliomaDataset(
    METADATA_DF_PATH,
    DATA_ROOT,
    use_atlas=USE_ATLAS,
    exclude_ids=["PatientID_0191"],  # NOTE(review): reason for excluding this patient is not recorded here
)
train_dataset = SubsetByPIDs(full_ds, train_pids)
val_dataset = SubsetByPIDs(full_ds, val_pids)

# Data loaders.
# NOTE(review): the lambda collate_fn cannot be pickled; with NUM_WORKERS > 0 this presumably
# relies on fork-based worker start (Linux/Colab) — confirm before running elsewhere.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=lambda b: image_only_collate_fn(b, use_atlas=USE_ATLAS),
    generator=torch.Generator().manual_seed(SEED),  # seeded for reproducible shuffling
    worker_init_fn=worker_init_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # fixed order so validation metrics are comparable across epochs
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=lambda b: image_only_collate_fn(b, use_atlas=USE_ATLAS),
    generator=torch.Generator().manual_seed(SEED),
    worker_init_fn=worker_init_fn,
)

print(f"Loaded client_{CLIENT_ID}: train patients={len(train_dataset)}, val patients={len(val_dataset)}")

# DeepLabV3+ model (from segmentation_models_pytorch); 2 input channels = MRI + regions atlas, 1 = MRI only
in_channels = 2 if USE_ATLAS else 1
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER_NAME,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=in_channels,
    classes=1,  # single binary tumor-mask output (raw logits)
).to(device)

# Loss, optimizer, scheduler, and AMP scaler
# Binary Dice loss computed directly on the model's raw logits (from_logits=True).
criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
optimizer = optim.AdamW(model.parameters(), lr=LR)
# Linear decay of the learning rate from LR towards 0 over EPOCHS scheduler steps
# (the training loop calls scheduler.step() once per epoch).
scheduler = optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=0.0, total_iters=EPOCHS
)
# Gradient scaler for mixed precision; only enabled when training on CUDA.
scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda"))

def calc_metrics(y_true, y_pred):
    """Compute Dice, IoU, and pixel accuracy for two binary masks.

    Both inputs are cast to uint8 and flattened, so any shape works as long as both hold
    the same number of elements. Values are expected to be 0/1.

    Returns:
        (dice, iou, acc) as Python floats. When both masks are empty, Dice and IoU are 0.0
        (Dice via its 1e-8 epsilon, IoU explicitly), and accuracy is 1.0.
    """
    y_true = np.asarray(y_true).astype(np.uint8).reshape(-1)
    y_pred = np.asarray(y_pred).astype(np.uint8).reshape(-1)

    inter = int((y_true * y_pred).sum())
    true_sum = int(y_true.sum())
    pred_sum = int(y_pred.sum())

    dice = (2.0 * inter) / (true_sum + pred_sum + 1e-8)

    # IoU = |A ∩ B| / |A ∪ B|. For binary masks this equals sklearn's jaccard_score
    # (which returns 0.0 when the union is empty) without its per-call validation
    # overhead or the broad try/except fallback.
    union = true_sum + pred_sum - inter
    iou = inter / union if union > 0 else 0.0

    acc = (y_true == y_pred).mean()
    return float(dice), float(iou), float(acc)

def plot_prediction(sample_imgs, save_path):
    """Save and show a one-row figure: MRI, optional regions, ground truth and prediction.

    Each array in ``sample_imgs`` is indexed ``[0, 0]``, i.e. the first batch item's first
    channel is plotted. The "Regions" panel appears only when the key is present.
    """
    panels = [("mri", "MRI")]
    if "regions" in sample_imgs:
        panels.append(("regions", "Regions"))
    panels += [("y_true", "Ground Truth"), ("y_pred", "Predicted")]

    _, axes = plt.subplots(1, len(panels), figsize=(5 * len(panels), 5))
    for ax, (key, heading) in zip(axes, panels):
        ax.imshow(sample_imgs[key][0, 0], cmap="gray")
        ax.set_title(heading)
        ax.axis("off")

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close()

def plot_metrics(history, save_path):
    """Plot train/val curves for loss, Dice, IoU and accuracy in a 2x2 grid and save it.

    ``history`` must hold per-epoch lists under ``train_<m>`` / ``val_<m>`` for
    m in {loss, dice, iou, acc}; the x-axis is the 1-based epoch number.
    """
    n_epochs = len(history["train_loss"])
    epochs = range(1, n_epochs + 1)

    # (history key suffix, legend suffix, subplot title) in grid order.
    panel_specs = (
        ("loss", "Loss", "Loss"),
        ("dice", "Dice", "Dice"),
        ("iou", "IoU", "IoU"),
        ("acc", "Acc", "Accuracy"),
    )

    plt.figure(figsize=(16, 12))
    for slot, (suffix, legend_name, heading) in enumerate(panel_specs, start=1):
        plt.subplot(2, 2, slot)
        plt.plot(epochs, history[f"train_{suffix}"], label=f"Train {legend_name}")
        plt.plot(epochs, history[f"val_{suffix}"], label=f"Val {legend_name}")
        plt.legend()
        plt.title(heading)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close()

def train_one_epoch(model, dataloader, optimizer, criterion, scaler):
    """Run one optimisation pass over ``dataloader``.

    Uses CUDA autocast and the gradient ``scaler`` when the global ``device`` is CUDA.
    Dice/IoU/accuracy come from ``calc_metrics`` on masks thresholded at 0.5 and, like the
    loss, are averaged over batches (each batch weighs the same).

    Returns:
        (mean_loss, mean_dice, mean_iou, mean_acc)
    """
    model.train()
    use_amp = device.type == "cuda"
    running = [0.0, 0.0, 0.0, 0.0]  # loss, dice, iou, acc

    for batch in dataloader:
        inputs = batch["x"].to(device)
        target = batch["y"].to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda", enabled=use_amp):
            logits = model(inputs)
            loss = criterion(logits, target)

        # Scaled backward pass, optimizer step and scale update.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = float(loss.item())

        pred_mask = (torch.sigmoid(logits).detach().cpu().numpy() > 0.5).astype(np.uint8)
        true_mask = (target.detach().cpu().numpy() > 0.5).astype(np.uint8)
        dice, iou, acc = calc_metrics(true_mask, pred_mask)

        for slot, value in enumerate((batch_loss, dice, iou, acc)):
            running[slot] += value

    n_batches = len(dataloader)
    return tuple(total / n_batches for total in running)

@torch.no_grad()
def eval_one_epoch(model, dataloader, criterion, keep_last_batch=True):
    """Evaluate ``model`` over ``dataloader`` without gradient tracking.

    Dice/IoU/accuracy come from ``calc_metrics`` on masks thresholded at 0.5. They are
    averaged over batches together with the loss, so each batch weighs the same.

    Args:
        model: segmentation network returning one logit channel.
        dataloader: yields dicts with "x" (inputs) and "y" (target masks).
        criterion: loss applied to (logits, targets).
        keep_last_batch: if True, also return numpy copies of the final batch for plotting.

    Returns:
        (mean_loss, mean_dice, mean_iou, mean_acc, last_batch_imgs). ``last_batch_imgs`` is a
        dict with "mri", "y_true", "y_pred" (plus "regions" for 2-channel inputs) for the final
        batch, or None when ``keep_last_batch`` is False or the loader yielded nothing.
    """
    model.eval()
    total_loss = total_dice = total_iou = total_acc = 0.0
    last = None  # (x, y, preds_bin) of the most recent batch

    for batch in dataloader:
        x = batch["x"].to(device)
        y = batch["y"].to(device)

        preds = model(x)
        loss = criterion(preds, y)
        total_loss += float(loss.item())

        preds_bin = (torch.sigmoid(preds).cpu().numpy() > 0.5).astype(np.uint8)
        y_np = (y.cpu().numpy() > 0.5).astype(np.uint8)
        d, i, a = calc_metrics(y_np, preds_bin)

        total_dice += d
        total_iou += i
        total_acc += a

        if keep_last_batch:
            # Only keep references here; the device->host copies for visualisation are done
            # once after the loop instead of for every batch.
            last = (x, y, preds_bin)

    last_batch_imgs = None
    if last is not None:
        x_last, y_last, preds_last = last
        last_batch_imgs = {
            "mri": x_last[:, 0:1].cpu().numpy(),
            "y_true": y_last.cpu().numpy(),
            "y_pred": preds_last,
        }
        if x_last.shape[1] == 2:
            last_batch_imgs["regions"] = x_last[:, 1:2].cpu().numpy()

    n = len(dataloader)
    return total_loss / n, total_dice / n, total_iou / n, total_acc / n, last_batch_imgs

def infer_and_visualize_best(
    model,
    val_dataset,
    use_atlas: bool,
    out_dir: str,
    client_id: int,
    best_ckpt_path: str,
    k_samples: int = 3,
    threshold: float = 0.5,
):
    """Load the best checkpoint, predict on a few validation samples and save figures.

    Reseeds ``random``, NumPy and torch with the global ``SEED`` so the same samples are
    picked on every call. Writes one figure per sample (via ``plot_prediction``) and one
    grid figure into ``out_dir``; the grid file name also uses the global ``run_tag``.
    Returns None early, printing a message, if the checkpoint is missing or
    ``val_dataset`` is empty.

    Args:
        model: network whose weights are replaced by the checkpoint.
        val_dataset: indexable dataset of dicts with "mri", "tumor", optional "regions" and
            optional "patient_id".
        use_atlas: stack "regions" as a second input channel when the sample has it.
        out_dir: output directory (created if needed).
        client_id: included in the output file names.
        best_ckpt_path: path of a ``state_dict`` saved with ``torch.save``.
        k_samples: number of samples to visualise (capped at ``len(val_dataset)``).
        threshold: probability cut-off used to binarise the sigmoid output.
    """
    import random

    # Make sampling deterministic
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    os.makedirs(out_dir, exist_ok=True)

    if os.path.isfile(best_ckpt_path):
        model.load_state_dict(torch.load(best_ckpt_path, map_location=device))
        model.eval()
        print(f"[Best Model] Loaded: {best_ckpt_path}\n")
    else:
        print(f"[Best Model] Missing checkpoint: {best_ckpt_path}\n")
        return

    k = min(k_samples, len(val_dataset))
    if k == 0:
        print("[Best Model] Empty val dataset.\n")
        return

    # Deterministic random subset of validation indices
    idxs = random.sample(range(len(val_dataset)), k)

    def _predict_one(sample):
        """Run model on one sample and return inputs and binarized prediction."""
        mri = torch.tensor(sample["mri"]).unsqueeze(0).unsqueeze(0).float().to(device)

        if use_atlas and ("regions" in sample):
            regs = torch.tensor(sample["regions"]).unsqueeze(0).unsqueeze(0).float().to(device)
            x = torch.cat([mri, regs], dim=1)
        else:
            x = mri

        with torch.no_grad():
            prob = torch.sigmoid(model(x)).cpu().numpy()
            pred_bin = (prob > threshold).astype(np.uint8)

        return x.cpu().numpy(), pred_bin

    # Each selected sample is loaded and predicted once; the results feed both the
    # individual figures and the grid below.
    predictions = []  # (sample, x_np, pred_bin) per selected index
    saved_paths = []

    # Save individual sample figures
    for i, idx in enumerate(idxs, 1):
        sample = val_dataset[idx]
        pid = sample.get("patient_id", f"val_{idx}")
        x_np, pred_bin = _predict_one(sample)
        predictions.append((sample, x_np, pred_bin))

        imgs = {
            "mri": x_np[:, 0:1],
            "y_true": np.expand_dims(
                np.expand_dims(sample["tumor"], 0), 0
            ).astype(np.float32),
            "y_pred": pred_bin.astype(np.float32),
        }

        if use_atlas and ("regions" in sample):
            imgs["regions"] = x_np[:, 1:2]

        out_path = os.path.join(
            out_dir, f"best_val_sample_{i}_client{client_id}_{pid}.png"
        )
        plot_prediction(imgs, out_path)
        saved_paths.append(out_path)

    print("[Best Model] Saved individual figures:")
    for p in saved_paths:
        print(" -", p)
    print()

    # Save grid figure
    cols = 4 if use_atlas else 3
    fig, axs = plt.subplots(k, cols, figsize=(5 * cols, 4 * k))

    if k == 1:
        # plt.subplots returns a 1-D axes array for a single row; make it 2-D for axs[row, col].
        axs = np.expand_dims(axs, 0)

    for row, (sample, x_np, pred_bin) in enumerate(predictions):
        mri = x_np[0, 0]
        gt = sample["tumor"]
        col = 0

        axs[row, col].imshow(mri, cmap="gray")
        axs[row, col].set_title("MRI")
        axs[row, col].axis("off")
        col += 1

        if use_atlas and ("regions" in sample):
            regs = x_np[0, 1]
            axs[row, col].imshow(regs, cmap="gray")
            axs[row, col].set_title("Regions")
            axs[row, col].axis("off")
            col += 1

        axs[row, col].imshow(gt, cmap="gray")
        axs[row, col].set_title("Ground Truth")
        axs[row, col].axis("off")
        col += 1

        axs[row, col].imshow(pred_bin[0, 0], cmap="gray")
        # Report the threshold actually used (previously hard-coded to 0.5).
        axs[row, col].set_title(f"Predicted (τ={threshold})")
        axs[row, col].axis("off")

    plt.tight_layout()
    grid_path = os.path.join(out_dir, f"best_model_val_grid_{run_tag}_client{client_id}.png")
    plt.savefig(grid_path, bbox_inches="tight")
    plt.show()
    plt.close()
    print(f"[Best Model] Saved grid -> {grid_path}\n")

# History containers for training curves (one value per epoch; consumed by plot_metrics)
history = {
    "train_loss": [],
    "train_dice": [],
    "train_iou": [],
    "train_acc": [],
    "val_loss": [],
    "val_dice": [],
    "val_iou": [],
    "val_acc": [],
}

# Tracking best validation metrics
best_val_iou = 0.0  # NOTE(review): not read or updated by the loop below
best_val_dice = -float("inf")
best_val_loss = float("inf")
best_path = os.path.join(OUT_MODELS_DIR, f"best_{run_tag}_client{CLIENT_ID}.pth")

log_rows = []  # per-epoch dicts written to CSV after training

# Main training loop
for epoch in range(1, EPOCHS + 1):
    # Train one epoch
    trL, trD, trI, trA = train_one_epoch(
        model, train_loader, optimizer, criterion, scaler
    )
    # Validate (the returned last-batch images are discarded here)
    vaL, vaD, vaI, vaA, _ = eval_one_epoch(
        model, val_loader, criterion, keep_last_batch=True
    )

    # Update scheduler (once per epoch)
    scheduler.step()

    # Save metrics for plots
    history["train_loss"].append(trL)
    history["train_dice"].append(trD)
    history["train_iou"].append(trI)
    history["train_acc"].append(trA)

    history["val_loss"].append(vaL)
    history["val_dice"].append(vaD)
    history["val_iou"].append(vaI)
    history["val_acc"].append(vaA)

    print(
        f"[Epoch {epoch:03d}/{EPOCHS}] "
        f"Train — Loss {trL:.4f} | Dice {trD:.4f} | IoU {trI:.4f} | Accuracy {trA:.4f} || "
        f"Val — Loss {vaL:.4f} | Dice {vaD:.4f} | IoU {vaI:.4f} | Accuracy {vaA:.4f}\n"
    )

    # Save checkpoint only if Val Dice beats the best Dice AND Val loss beats the best loss
    # of the last saved epoch (both must improve at the same time).
    saved_ckpt = False
    if (vaD > best_val_dice) and (vaL < best_val_loss):
        best_val_dice = vaD
        best_val_loss = vaL
        torch.save(model.state_dict(), best_path)
        saved_ckpt = True
        print(
            f"Saved best model (Val Dice↑ {best_val_dice:.4f} & Val Loss↓ {best_val_loss:.4f}) -> {best_path}\n"
        )

    # Log row for CSV ("marker" flags the epochs whose weights were checkpointed)
    log_rows.append(
        {
            "epoch": epoch,
            "train_loss": trL,
            "train_dice": trD,
            "train_iou": trI,
            "train_acc": trA,
            "val_loss": vaL,
            "val_dice": vaD,
            "val_iou": vaI,
            "val_acc": vaA,
            "saved_ckpt": saved_ckpt,
            "marker": "X" if saved_ckpt else "",
        }
    )

# Save metrics as CSV
metrics_csv = os.path.join(OUT_GRAPHS_DIR, f"metrics_{run_tag}_client{CLIENT_ID}.csv")
pd.DataFrame(log_rows).to_csv(metrics_csv, index=False)
print(f"[Log] Wrote per-epoch metrics CSV -> {metrics_csv}\n")

# Plot training curves (tagged)
curves_path = os.path.join(OUT_GRAPHS_DIR, f"training_curves_{run_tag}_client{CLIENT_ID}.png")
plot_metrics(history, curves_path)

# Run inference with best model and visualize a few validation samples
# (this reloads best_path into `model`, replacing the last-epoch weights)
infer_and_visualize_best(
    model=model,
    val_dataset=val_dataset,
    use_atlas=USE_ATLAS,
    out_dir=OUT_GRAPHS_DIR,
    client_id=CLIENT_ID,
    best_ckpt_path=best_path,
    k_samples=3,
    threshold=0.5,
)

**<h2>Ensemble</h2>**

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from sklearn.metrics import jaccard_score

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
USE_ATLAS = True  # inputs are MRI + regions atlas (2 channels)
THRESHOLD = 0.5  # probability cut-off applied to the ensemble's sigmoid output
K_SAMPLES = 3  # number of validation patients to visualise

# All outputs live under a path relative to the current working directory.
OUT_BASE = "AITDM"
OUT_GRAPHS_DIR = os.path.join(OUT_BASE, "Graphs")
os.makedirs(OUT_GRAPHS_DIR, exist_ok=True)

# Checkpoints of the individually trained (client 0, atlas) members of the ensemble.
OUT_MODELS_UNET_DIR = os.path.join(OUT_BASE, "Models", "UNet_ImageOnly")
CKPT_RESNET50 = os.path.join(OUT_MODELS_UNET_DIR, "best_unet_resnet50_imagenet_atlas_client0.pth")
CKPT_MITB3 = os.path.join(OUT_MODELS_UNET_DIR, "best_unet_mit_b3_imagenet_atlas_client0.pth")

OUT_MODELS_DLV3P_DIR = os.path.join(OUT_BASE, "Models", "DeepLabV3Plus_ImageOnly")
CKPT_DLV3P = os.path.join(
    OUT_MODELS_DLV3P_DIR,
    "best_deeplabv3plus_timm-mobilenetv3_small_100_imagenet_atlas_client0.pth",
)

# Logit weights in MODELS order (ResNet50-UNet, MiT-B3-UNet, DeepLabV3+);
# ensemble_forward_logits_multi renormalises them to sum to 1.
ENS_WEIGHTS = [2.5 / 10, 2.5 / 10, 5 / 10]


def set_seed(seed=42):
    """Seed Python's ``random``, NumPy, torch CPU and all CUDA generators with ``seed``."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


set_seed(SEED)  # seed all RNGs once for this ensemble cell


def calc_metrics(y_true, y_pred):
    """Compute Dice, IoU, and pixel accuracy for two binary (0/1) masks of any shape.

    Returns:
        (dice, iou, acc) as Python floats. If both masks are empty, Dice and IoU are 0.0
        and accuracy is 1.0.
    """
    y_true = np.asarray(y_true).astype(np.uint8).reshape(-1)
    y_pred = np.asarray(y_pred).astype(np.uint8).reshape(-1)

    inter = int((y_true * y_pred).sum())
    true_sum = int(y_true.sum())
    pred_sum = int(y_pred.sum())

    dice = (2.0 * inter) / (true_sum + pred_sum + 1e-8)

    # Direct binary IoU: same value as sklearn's jaccard_score (0.0 on an empty union)
    # without its per-call overhead or a broad exception-swallowing fallback.
    union = true_sum + pred_sum - inter
    iou = inter / union if union > 0 else 0.0

    acc = (y_true == y_pred).mean()
    return float(dice), float(iou), float(acc)


def plot_prediction(imgs, save_path, title=None):
    """Save and show MRI, optional regions, ground truth and ensemble mask in one row.

    Each array in ``imgs`` is indexed ``[0, 0]`` (first item, first channel). The prediction
    panel's title shows the module-level ``THRESHOLD``. ``title``, when truthy, becomes the
    figure's suptitle.
    """
    layout = [("mri", "MRI")]
    if "regions" in imgs:
        layout.append(("regions", "Regions"))
    layout.append(("y_true", "Ground Truth"))
    layout.append(("y_pred", f"Ensemble (τ={THRESHOLD})"))

    n_cols = len(layout)
    fig, axs = plt.subplots(1, n_cols, figsize=(5 * n_cols, 5))

    for col_idx in range(n_cols):
        key, heading = layout[col_idx]
        axs[col_idx].imshow(imgs[key][0, 0], cmap="gray")
        axs[col_idx].set_title(heading)
        axs[col_idx].axis("off")

    if title:
        fig.suptitle(title, y=1.02)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close()


def build_unet(encoder_name, encoder_weights="imagenet", in_channels=2, classes=1):
    """Construct an ``smp.Unet`` with the given encoder (not moved to any device)."""
    unet_kwargs = dict(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
    )
    return smp.Unet(**unet_kwargs)


def build_deeplabv3p(encoder_name, encoder_weights="imagenet", in_channels=2, classes=1):
    """Construct an ``smp.DeepLabV3Plus`` with the given encoder (not moved to any device)."""
    dlv3p_kwargs = dict(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
    )
    return smp.DeepLabV3Plus(**dlv3p_kwargs)


def load_model_unet(encoder_name, ckpt_path, in_channels):
    """Build a single-class U-Net and load the weights stored at ``ckpt_path``.

    The model is moved to ``DEVICE`` and put in eval mode before being returned.

    Raises:
        FileNotFoundError: if ``ckpt_path`` is not an existing file.
    """
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Missing checkpoint: {ckpt_path}")
    m = build_unet(encoder_name, "imagenet", in_channels=in_channels, classes=1).to(DEVICE)
    # load_state_dict is strict by default, so all checkpoint keys must match the model.
    m.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
    m.eval()
    return m


def load_model_dlv3p(encoder_name, ckpt_path, in_channels):
    """Build a single-class DeepLabV3+ and load the weights stored at ``ckpt_path``.

    The model is moved to ``DEVICE`` and put in eval mode before being returned.

    Raises:
        FileNotFoundError: if ``ckpt_path`` is not an existing file.
    """
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Missing checkpoint: {ckpt_path}")
    m = build_deeplabv3p(encoder_name, "imagenet", in_channels=in_channels, classes=1).to(DEVICE)
    # load_state_dict is strict by default, so all checkpoint keys must match the model.
    m.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
    m.eval()
    return m


in_channels = 2 if USE_ATLAS else 1  # MRI (+ regions atlas)
model_r50 = load_model_unet("resnet50", CKPT_RESNET50, in_channels=in_channels)
model_mit3 = load_model_unet("mit_b3", CKPT_MITB3, in_channels=in_channels)
model_dlv3p = load_model_dlv3p("timm-mobilenetv3_small_100", CKPT_DLV3P, in_channels=in_channels)

# Order must match ENS_WEIGHTS.
MODELS = [model_r50, model_mit3, model_dlv3p]
# Same Dice loss as in training, applied to the combined ensemble logits.
criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)


@torch.no_grad()
def ensemble_forward_logits_multi(x, models, weights, target_hw=None):
    """Return the weighted combination of the models' logits for input ``x``.

    ``weights`` are normalised to sum to ~1 (a 1e-8 epsilon guards a zero sum), so only
    their ratios matter. When ``target_hw`` is given, each model's logits are bilinearly
    resized to that (H, W) before being combined.

    Raises:
        ValueError: if ``models`` is empty or the number of weights differs from the number
            of models (``zip`` would otherwise silently drop the unmatched entries).
    """
    models = list(models)
    w = np.array(weights, dtype=np.float32)
    if len(models) == 0:
        raise ValueError("ensemble_forward_logits_multi needs at least one model")
    if len(w) != len(models):
        raise ValueError(f"Got {len(models)} models but {len(w)} weights")
    w = w / (w.sum() + 1e-8)

    logits_sum = None
    for mi, wi in zip(models, w):
        li = mi(x)
        if target_hw is not None and tuple(li.shape[-2:]) != tuple(target_hw):
            li = F.interpolate(li, size=target_hw, mode="bilinear", align_corners=False)
        contrib = li * float(wi)
        logits_sum = contrib if logits_sum is None else logits_sum + contrib

    return logits_sum


@torch.no_grad()
def eval_one_epoch_ensemble(dataloader, models, weights, threshold=0.5, criterion=None, keep_last_batch=True):
    """Evaluate the weighted-logit ensemble over ``dataloader``.

    Metrics are computed per batch with ``calc_metrics`` (predictions thresholded at
    ``threshold``, targets at 0.5) and then averaged across batches.

    Returns:
        (mean_loss, mean_dice, mean_iou, mean_acc, last_batch_imgs).
        ``mean_loss`` is None when ``criterion`` is None or the loader is empty. Dice/IoU/acc
        are 0.0 for an empty loader. ``last_batch_imgs`` holds numpy copies of the final
        batch ("mri", "y_true", "y_pred", plus "regions" for 2-channel inputs), or None.
    """
    total_loss = total_dice = total_iou = total_acc = 0.0
    last = None  # (x, y, preds_bin) of the most recent batch
    n = 0

    for batch in dataloader:
        x = batch["x"].to(DEVICE)
        y = batch["y"].to(DEVICE)

        logits = ensemble_forward_logits_multi(x, models=models, weights=weights, target_hw=y.shape[-2:])

        if criterion is not None:
            loss = criterion(logits, y)
            total_loss += float(loss.item())

        preds_bin = (torch.sigmoid(logits).detach().cpu().numpy() > threshold).astype(np.uint8)
        y_np = (y.detach().cpu().numpy() > 0.5).astype(np.uint8)

        d, i, a = calc_metrics(y_np, preds_bin)
        total_dice += d
        total_iou += i
        total_acc += a
        n += 1

        if keep_last_batch:
            # Keep references only; host copies for visualisation happen once after the loop.
            last = (x, y, preds_bin)

    last_batch_imgs = None
    if last is not None:
        x_last, y_last, preds_last = last
        last_batch_imgs = {
            "mri": x_last[:, 0:1].detach().cpu().numpy(),
            "y_true": y_last.detach().cpu().numpy(),
            "y_pred": preds_last,
        }
        if x_last.shape[1] == 2:
            last_batch_imgs["regions"] = x_last[:, 1:2].detach().cpu().numpy()

    mean_loss = (total_loss / n) if (criterion is not None and n > 0) else None
    mean_dice = total_dice / max(n, 1)
    mean_iou = total_iou / max(n, 1)
    mean_acc = total_acc / max(n, 1)

    return mean_loss, mean_dice, mean_iou, mean_acc, last_batch_imgs


@torch.no_grad()
def ensemble_infer_and_visualize(val_dataset, out_dir, models, weights, k_samples=3, threshold=0.5):
    """Visualise ensemble predictions for ``k_samples`` randomly chosen validation samples.

    Reseeds all RNGs with the global ``SEED`` so the same samples are chosen on every call.
    Saves one figure per sample (title carries its Dice/IoU/accuracy) plus one grid figure
    into ``out_dir``. Each selected sample is loaded and predicted once, and the result is
    reused for both figures. Reads the module-level ``USE_ATLAS``, ``DEVICE`` and ``SEED``.
    """
    os.makedirs(out_dir, exist_ok=True)
    k = min(k_samples, len(val_dataset))
    if k == 0:
        print("Empty val_dataset.")
        return

    set_seed(SEED)
    idxs = random.sample(range(len(val_dataset)), k)

    def _predict(sample):
        """Return (x_np, y_raw, y_bin, pred_bin) as numpy arrays with batch and channel axes."""
        mri = torch.tensor(sample["mri"]).unsqueeze(0).unsqueeze(0).float()
        y = torch.tensor(sample["tumor"]).unsqueeze(0).unsqueeze(0).float()

        if USE_ATLAS and ("regions" in sample):
            reg = torch.tensor(sample["regions"]).unsqueeze(0).unsqueeze(0).float()
            x = torch.cat([mri, reg], dim=1)
        else:
            x = mri

        x = x.to(DEVICE)
        y = y.to(DEVICE)

        logits = ensemble_forward_logits_multi(x, models=models, weights=weights, target_hw=y.shape[-2:])
        pred_bin = (torch.sigmoid(logits).cpu().numpy() > threshold).astype(np.uint8)
        y_raw = y.cpu().numpy()
        y_bin = (y_raw > 0.5).astype(np.uint8)
        return x.cpu().numpy(), y_raw, y_bin, pred_bin

    per_sample = []
    cached = []  # (x_np, y_bin, pred_bin, dice) per selected sample, reused by the grid

    for i, idx in enumerate(idxs, 1):
        sample = val_dataset[idx]
        pid = sample.get("patient_id", f"val_{idx}")

        x_np, y_raw, y_bin, pred_bin = _predict(sample)
        d, iou, acc = calc_metrics(y_bin, pred_bin)
        cached.append((x_np, y_bin, pred_bin, d))

        imgs = {"mri": x_np[:, 0:1], "y_true": y_raw, "y_pred": pred_bin.astype(np.float32)}
        if USE_ATLAS and x_np.shape[1] == 2:
            imgs["regions"] = x_np[:, 1:2]

        out_path = os.path.join(out_dir, f"ensemble3_val_sample_{i}_{pid}.png")
        plot_prediction(imgs, out_path, title=f"{pid} | Dice={d:.4f} IoU={iou:.4f} Acc={acc:.4f}")
        per_sample.append({"pid": pid, "dice": d, "iou": iou, "acc": acc, "path": out_path})

    print("[Ensemble-3] Saved individual figures:")
    for r in per_sample:
        print(f" - {r['pid']}: Dice={r['dice']:.4f} IoU={r['iou']:.4f} Acc={r['acc']:.4f} -> {r['path']}")

    cols = 4 if USE_ATLAS else 3
    fig, axs = plt.subplots(k, cols, figsize=(5 * cols, 4 * k))
    if k == 1:
        # A single row comes back as a 1-D axes array; make it 2-D for axs[row, c].
        axs = np.expand_dims(axs, 0)

    for row, (x_np, y_np, pred_bin, d) in enumerate(cached):
        c = 0
        axs[row, c].imshow(x_np[0, 0], cmap="gray")
        axs[row, c].set_title("MRI")
        axs[row, c].axis("off")
        c += 1

        if USE_ATLAS and x_np.shape[1] == 2:
            axs[row, c].imshow(x_np[0, 1], cmap="gray")
            axs[row, c].set_title("Regions")
            axs[row, c].axis("off")
            c += 1

        axs[row, c].imshow(y_np[0, 0], cmap="gray")
        axs[row, c].set_title("GT")
        axs[row, c].axis("off")
        c += 1

        axs[row, c].imshow(pred_bin[0, 0], cmap="gray")
        axs[row, c].set_title(f"Ens τ={threshold}\nDice={d:.3f}")
        axs[row, c].axis("off")

    plt.tight_layout()
    wtag = "-".join([f"{w:.2f}" for w in (np.array(weights) / (np.sum(weights) + 1e-8))])
    grid_path = os.path.join(out_dir, f"ensemble3_val_grid_tau{threshold}_w{wtag}.png")
    plt.savefig(grid_path, bbox_inches="tight")
    plt.show()
    plt.close()
    print(f"[Ensemble-3] Saved grid -> {grid_path}")


# Assumes val_loader and val_dataset already exist in your notebook
# (created by the earlier training cell; this cell does not rebuild them).
ens_loss, ens_d, ens_i, ens_a, last_imgs = eval_one_epoch_ensemble(
    dataloader=val_loader,
    models=MODELS,
    weights=ENS_WEIGHTS,
    threshold=THRESHOLD,
    criterion=criterion,
    keep_last_batch=True,
)

# ens_loss is a float here because a criterion is passed above (it would be None otherwise).
print(f"[Ensemble-3 Val] Loss {ens_loss:.4f} | Dice {ens_d:.4f} | IoU {ens_i:.4f} | Acc {ens_a:.4f}")

if last_imgs is not None:
    out_path = os.path.join(OUT_GRAPHS_DIR, f"ensemble3_last_batch_tau{THRESHOLD}.png")
    plot_prediction(last_imgs, out_path, title="Ensemble-3 - last val batch")

# Per-patient figures + grid for K_SAMPLES randomly (but reproducibly) chosen validation patients.
ensemble_infer_and_visualize(
    val_dataset=val_dataset,
    out_dir=OUT_GRAPHS_DIR,
    models=MODELS,
    weights=ENS_WEIGHTS,
    k_samples=K_SAMPLES,
    threshold=THRESHOLD,
)

# **M2 - FL with ensemble**

In [None]:
%%writefile seg_data.py
import os, pickle, numpy as np, torch
from typing import List
from torch.utils.data import Dataset, DataLoader
import torchio as tio

# Global paths and configuration
DATA_ROOT = "/content/Preprocessed-Data"  # extracted from Preprocessed-Data.zip in the setup cell
METADATA_DF_PATH = "cleaned_df.pkl"  # relative to the working directory; pickled DataFrame with a Patient_ID column
USE_ATLAS = True  # also load <pid>_regions.npy as a second input channel
EXCLUDE_IDS = ["PatientID_0191"]  # NOTE(review): reason for excluding this patient is not recorded here

# Dataset that loads MRI, tumor mask and optional atlas for each patient
class ImageOnlyGliomaDataset(Dataset):
    """Per-patient glioma samples: MRI, tumor mask and (optionally) a regions/atlas map.

    Patient IDs come from the ``Patient_ID`` column of the pickled metadata DataFrame,
    minus ``exclude_ids``. Only patients whose ``<data_root>/<pid>/<pid>_mri.npy`` and
    ``<pid>_tumor.npy`` (plus ``<pid>_regions.npy`` when ``use_atlas``) all exist are kept,
    in sorted order.

    Each item is a dict with ``patient_id``, ``mri`` (min-max scaled to [0, 1]), ``tumor``
    (binarised at 0.5) and, with ``use_atlas``, ``regions`` (min-max scaled). When
    ``transform`` is set, the arrays are wrapped in a TorchIO ``Subject``, passed through
    the transform, and converted back to 2-D numpy arrays.
    """

    def __init__(
        self,
        metadata_df_path,
        data_root,
        use_atlas=False,
        exclude_ids=None,
        transform=None,
    ):
        # NOTE: pickle.load can execute arbitrary code; only load trusted metadata files.
        with open(metadata_df_path, "rb") as f:
            df = pickle.load(f)

        if exclude_ids is None:
            exclude_ids = []

        self.df = df[~df["Patient_ID"].isin(exclude_ids)].reset_index(drop=True)
        self.data_root = data_root
        self.use_atlas = use_atlas
        self.transform = transform

        # Keep only patients whose required files are all present.
        self.patient_ids = []
        for pid in sorted(self.df["Patient_ID"].tolist()):
            base = os.path.join(self.data_root, pid)
            mri_p = os.path.join(base, f"{pid}_mri.npy")
            tumor_p = os.path.join(base, f"{pid}_tumor.npy")

            is_valid = os.path.isfile(mri_p) and os.path.isfile(tumor_p)
            if self.use_atlas:
                reg_p = os.path.join(base, f"{pid}_regions.npy")
                is_valid = is_valid and os.path.isfile(reg_p)

            if is_valid:
                self.patient_ids.append(pid)

    def __len__(self):
        return len(self.patient_ids)

    @staticmethod
    def _minmax(x):
        """Min-max scale to [0, 1] as float32; constant (or NaN-min/max) input gives zeros."""
        x = x.astype(np.float32)
        mn, mx = np.min(x), np.max(x)
        if mx > mn:
            return (x - mn) / (mx - mn)
        return np.zeros_like(x, dtype=np.float32)

    def _to_torchio_format(self, arr):
        """
        Convert (H, W) -> (1, H, W, 1) for TorchIO's internal processing.
        Arrays that are not 2-D are returned unchanged.
        """
        if arr.ndim == 2:
            return arr[np.newaxis, ..., np.newaxis]
        return arr

    def __getitem__(self, idx):
        pid = self.patient_ids[idx]
        base = os.path.join(self.data_root, pid)

        mri = np.load(os.path.join(base, f"{pid}_mri.npy")).astype(np.float32)
        tumor = np.load(os.path.join(base, f"{pid}_tumor.npy")).astype(np.float32)

        regions = None
        if self.use_atlas:
            regions = np.load(os.path.join(base, f"{pid}_regions.npy")).astype(np.float32)

        # Normalise once: MRI/regions to [0, 1], tumor mask to {0, 1}.
        mri = self._minmax(mri)
        tumor = (tumor > 0.5).astype(np.float32)
        if regions is not None:
            regions = self._minmax(regions)

        if self.transform:
            subject_dict = {
                'mri': tio.ScalarImage(tensor=self._to_torchio_format(mri)),
                'tumor': tio.LabelMap(tensor=self._to_torchio_format(tumor)),
            }
            if regions is not None:
                subject_dict['regions'] = tio.ScalarImage(tensor=self._to_torchio_format(regions))

            subject = tio.Subject(subject_dict)
            subject = self.transform(subject)

            # Drop the singleton channel and depth axes to get back (H, W) arrays.
            out_mri = subject['mri'].data[0, ..., 0].numpy()
            out_tumor = subject['tumor'].data[0, ..., 0].numpy()
            sample = {
                "patient_id": pid,
                "mri": out_mri,       # (H, W)
                "tumor": out_tumor    # (H, W)
            }

            if regions is not None:
                out_regions = subject['regions'].data[0, ..., 0].numpy()
                sample["regions"] = out_regions  # (H, W)

            return sample

        sample = {"patient_id": pid, "mri": mri, "tumor": tumor}

        if self.use_atlas:
            sample["regions"] = regions

        return sample

# Collate function to build batched tensors and patient ID list
def image_only_collate_fn(batch, use_atlas=USE_ATLAS):
    """Batch sample dicts into ``{"x": (B, C, H, W), "y": (B, 1, H, W), "pid": [...]}``.

    C is 2 (MRI + regions) when ``use_atlas`` is true, else 1. Both tensors are float.
    """
    mri_batch = torch.stack([torch.tensor(s["mri"]) for s in batch]).unsqueeze(1)
    mask_batch = torch.stack([torch.tensor(s["tumor"]) for s in batch]).unsqueeze(1)

    x = mri_batch.float()
    if use_atlas:
        region_batch = torch.stack([torch.tensor(s["regions"]) for s in batch]).unsqueeze(1)
        x = torch.cat([x, region_batch.float()], dim=1)

    return {"x": x, "y": mask_batch.float(), "pid": [s["patient_id"] for s in batch]}

# Dataset wrapper that restricts to a subset of patient IDs
class SubsetByPIDs(Dataset):
    """Dataset view exposing only the patients in ``pid_list`` (in that order).

    IDs missing from ``full_dataset.patient_ids`` are skipped without error.
    """

    def __init__(self, full_dataset: ImageOnlyGliomaDataset, pid_list: List[str]):
        self.ds = full_dataset
        lookup = dict(zip(self.ds.patient_ids, range(len(self.ds.patient_ids))))
        self.indices = [lookup[pid] for pid in pid_list if pid in lookup]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        underlying = self.indices[i]
        return self.ds[underlying]

# Compute Dice, IoU and accuracy for binary masks
def calc_metrics(y_true, y_pred):
    """Return (dice, iou, accuracy) as floats for two binary masks of any shape.

    Both masks are cast to uint8 and flattened. A 1e-8 epsilon in the denominators
    makes two empty masks yield Dice = IoU = 0.
    """
    truth = y_true.astype(np.uint8).reshape(-1)
    pred = y_pred.astype(np.uint8).reshape(-1)

    overlap = np.bitwise_and(truth, pred).sum()
    truth_sum = truth.sum()
    pred_sum = pred.sum()

    dice = (2.0 * overlap) / (truth_sum + pred_sum + 1e-8)
    iou = overlap / (truth_sum + pred_sum - overlap + 1e-8)
    accuracy = (truth == pred).mean()

    return float(dice), float(iou), float(accuracy)

Overwriting seg_data.py


In [None]:
%%writefile fl_client.py
import os
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import argparse
import json
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import flwr as fl
import copy
import segmentation_models_pytorch as smp
import random
import torchio as tio
import sys
import time

from seg_data import (
    ImageOnlyGliomaDataset,
    SubsetByPIDs,
    image_only_collate_fn,
    calc_metrics,
    DATA_ROOT,
    METADATA_DF_PATH,
    USE_ATLAS,
)

# Per-client split folders (extracted from client.zip in the setup cell).
CLIENT_DIR = "/content/client"

BATCH_SIZE = 8
NUM_WORKERS = 2
SEED = 42

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Defaults used by get_model() when no architecture/encoder is specified.
DEFAULT_MODEL_NAME = "unet"
DEFAULT_ENCODER_NAME = "timm-mobilenetv3_small_100"
DEFAULT_ENCODER_WEIGHTS = "imagenet"
# Optional TorchIO transform pipeline; None means no augmentation (presumably handed to get_loaders — confirm).
TRANSFORMS = None


def _run_name(model_name: str, encoder_name: str) -> str:
    return f"{model_name}__{encoder_name}".replace("/", "-")


def seed_everything(seed: int, deterministic: bool = True) -> None:
    """Seed Python, NumPy and PyTorch RNGs and optionally force deterministic kernels.

    Args:
        seed: value applied to every RNG. It is also exported as ``PYTHONHASHSEED``,
            which only affects interpreters started afterwards.
        deterministic: when True, disable cuDNN autotuning and make PyTorch raise
            on ops that lack a deterministic implementation.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Best-effort: allow faster float32 matmuls; ignore builds without this API.
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass
    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)


def seed_worker(worker_id: int) -> None:
    """DataLoader ``worker_init_fn``: reseed NumPy and ``random`` from torch's seed."""
    derived = torch.initial_seed() + worker_id
    derived %= 2**32  # NumPy seeds must fit in 32 bits
    np.random.seed(derived)
    random.seed(derived)


def log_transfer_metrics(file_path, cid, rnd, incoming, outgoing, overhead):
    """Append one line of transfer metrics for client ``cid`` in round ``rnd``.

    Args:
        file_path: text log file; its parent directory is created if missing.
        incoming / outgoing: payload sizes in bytes (written as KB).
        overhead: seconds spent preparing the outgoing payload.
    """
    parent = os.path.dirname(file_path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent)

    line = (
        f"Client {cid} Round {rnd}: "
        + f"Incoming {incoming/1024:.2f} KB | "
        + f"Outgoing {outgoing/1024:.2f} KB | "
        + f"Overhead: {overhead:.6f}s\n"
    )
    with open(file_path, "a") as fh:
        fh.write(line)


def get_model(
    model_name=DEFAULT_MODEL_NAME,
    encoder_name=DEFAULT_ENCODER_NAME,
    encoder_weights=DEFAULT_ENCODER_WEIGHTS,
):
    """Instantiate a single-output segmentation network on ``DEVICE``.

    Args:
        model_name: ``"unet"`` or a DeepLabV3+ alias (``"deeplabv3plus"``,
            ``"deeplabv3+"``, ``"dlv3p"``); case-insensitive.
        encoder_name: smp encoder identifier.
        encoder_weights: encoder pretrained weights, forwarded to smp.

    The input has 2 channels (MRI + atlas) when ``USE_ATLAS`` is set, else 1.

    Raises:
        ValueError: for an unrecognized ``model_name``.
    """
    channels = 2 if USE_ATLAS else 1
    key = model_name.lower()
    if key == "unet":
        builder = smp.Unet
    elif key in ("deeplabv3plus", "deeplabv3+", "dlv3p"):
        builder = smp.DeepLabV3Plus
    else:
        raise ValueError(f"Unknown model_name={model_name}. Use 'unet' or 'deeplabv3plus'.")

    network = builder(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=channels,
        classes=1,
    )
    return network.to(DEVICE)


def get_loaders(cid: int, base_seed: int, transforms):
    """Build the train and validation DataLoaders for client ``cid``.

    Patient splits are read from ``CLIENT_DIR/client_<cid>/{train,val}_pids.json``.

    ``transforms`` is applied to the training split only. Validation always sees
    untransformed images, as in the ensemble evaluation loader, so that the
    per-epoch validation metrics that drive best-epoch selection are not perturbed
    by random augmentation.

    Args:
        cid: client id.
        base_seed: seed from which the two loader generators are derived.
        transforms: optional TorchIO transform for the training dataset.

    Returns:
        (train_loader, val_loader, n_train, n_val)
    """
    client_dir = os.path.join(CLIENT_DIR, f"client_{cid}")

    def _read_pids(filename):
        # JSON list of patient ids for one split.
        with open(os.path.join(client_dir, filename)) as f:
            return json.load(f)

    def _dataset(transform):
        # Full patient dataset, optionally wrapped with a TorchIO transform.
        return ImageOnlyGliomaDataset(
            METADATA_DF_PATH,
            DATA_ROOT,
            use_atlas=USE_ATLAS,
            exclude_ids=["PatientID_0191"],
            transform=transform,
        )

    def _loader(ds, shuffle, generator):
        # Same worker/seeding configuration for both splits.
        return DataLoader(
            ds,
            batch_size=BATCH_SIZE,
            shuffle=shuffle,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            collate_fn=lambda b: image_only_collate_fn(b, use_atlas=USE_ATLAS),
            worker_init_fn=seed_worker,
            generator=generator,
            persistent_workers=(NUM_WORKERS > 0),
        )

    full_tr = _dataset(transforms)
    # With no transform both splits can share one dataset object (no second metadata load).
    full_va = full_tr if transforms is None else _dataset(None)

    ds_tr = SubsetByPIDs(full_tr, _read_pids("train_pids.json"))
    ds_va = SubsetByPIDs(full_va, _read_pids("val_pids.json"))

    g_tr = torch.Generator().manual_seed(base_seed + 12345)
    g_va = torch.Generator().manual_seed(base_seed + 67890)

    ld_tr = _loader(ds_tr, shuffle=True, generator=g_tr)
    ld_va = _loader(ds_va, shuffle=False, generator=g_va)

    return ld_tr, ld_va, len(ds_tr), len(ds_va)


def get_parameters(model):
    """Snapshot every ``state_dict`` entry (parameters and buffers) as a CPU NumPy array, in key order."""
    arrays = []
    for tensor in model.state_dict().values():
        arrays.append(tensor.detach().cpu().numpy())
    return arrays


def set_parameters(model, params):
    """Load ``params`` (NumPy arrays in ``state_dict`` key order) into ``model``.

    The arrays must line up one-to-one with ``model.state_dict()`` entries, as
    produced by ``get_parameters``. Pairing them with ``zip`` alone would silently
    ignore a length mismatch and leave trailing tensors untouched; ``strict=True``
    cannot catch that because every key is still present.

    Raises:
        ValueError: if ``len(params)`` differs from the number of state-dict entries.
    """
    sd = model.state_dict()
    params = list(params)
    if len(params) != len(sd):
        raise ValueError(
            f"set_parameters: expected {len(sd)} arrays for the model state_dict, got {len(params)}"
        )
    # Assign into the returned OrderedDict (not a fresh dict) so its _metadata survives.
    for key, value in zip(sd.keys(), params):
        sd[key] = torch.tensor(value)
    model.load_state_dict(sd, strict=True)


# Loss components combined 50/50 by criterion(): BCE on logits plus smp's binary Dice loss
# (configured with from_logits=True, so both terms take raw logits).
bce = nn.BCEWithLogitsLoss()
dice_loss = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)


def criterion(pred, y):
    """Training objective: equal-weight BCE-with-logits + Dice loss on raw ``pred`` logits."""
    bce_part = bce(pred, y)
    dice_part = dice_loss(pred, y)
    return 0.5 * bce_part + 0.5 * dice_part


def maybe_save_best(run_dir, cid, val_loss, val_dice, best_epoch, rnd, model):
    """Persist ``model`` as client ``cid``'s best checkpoint if it beats the stored one.

    The stored record lives in ``<run_dir>/checkpoints/client_<cid>_best.json``.
    A new result counts as better only if it lowers the validation loss AND raises
    the validation Dice. On improvement, the state_dict goes to
    ``client_<cid>_best.pt`` and the JSON record (round, epoch, val_loss, val_dice)
    is rewritten. An unreadable JSON file is treated as "no previous best".
    """
    ckpt_dir = os.path.join(run_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)

    stem = os.path.join(ckpt_dir, f"client_{cid}_best")
    meta_path = stem + ".json"
    weights_path = stem + ".pt"

    previous = {"val_loss": float("inf"), "val_dice": -1.0}
    if os.path.isfile(meta_path):
        # Best-effort read: a corrupt record simply resets the baseline.
        try:
            with open(meta_path, "r") as fh:
                previous = json.load(fh)
        except Exception:
            pass

    better = val_loss < previous.get("val_loss", float("inf")) and val_dice > previous.get("val_dice", -1.0)
    if not better:
        return

    torch.save(model.state_dict(), weights_path)
    record = {
        "round": int(rnd),
        "epoch": int(best_epoch),
        "val_loss": float(val_loss),
        "val_dice": float(val_dice),
    }
    with open(meta_path, "w") as fh:
        json.dump(record, fh)


class SegClient(fl.client.NumPyClient):
    """Flower NumPy client that trains one segmentation model on one site's data.

    Each round, ``fit`` does the following:
      * loads the global weights;
      * trains for ``config["local_epochs"]`` epochs, validating after each;
      * keeps the epoch that improves both validation loss and Dice;
      * returns those weights.
    ``evaluate`` scores the received weights on the local validation split.

    Args:
        cid: client id. It selects the ``CLIENT_DIR/client_<cid>`` split files and
            offsets the RNG seed.
        model_name, encoder_name, encoder_weights: forwarded to ``get_model``.
    """

    def __init__(
        self,
        cid: int,
        model_name=DEFAULT_MODEL_NAME,
        encoder_name=DEFAULT_ENCODER_NAME,
        encoder_weights=DEFAULT_ENCODER_WEIGHTS,
    ):
        self.cid = int(cid)
        self.model_name = model_name
        self.encoder_name = encoder_name
        self.encoder_weights = encoder_weights

        # Per-client seed so clients do not draw identical random streams.
        self.base_seed = SEED + self.cid
        seed_everything(self.base_seed, deterministic=True)

        self.run_name = _run_name(model_name, encoder_name)
        self.run_dir = os.path.join("AITDM", self.run_name)

        self.model = get_model(model_name, encoder_name, encoder_weights)
        # Every client uses the same loader setup.
        self.train_loader, self.val_loader, self.ntr, self.nva = get_loaders(
            self.cid, self.base_seed, transforms=TRANSFORMS
        )

    def get_parameters(self, config):
        return get_parameters(self.model)

    @staticmethod
    def _binary_metrics(pred, y):
        """Dice/IoU/accuracy of ``sigmoid(pred) > 0.5`` against ``y > 0.5`` for one batch."""
        y_hat = (torch.sigmoid(pred).detach().cpu().numpy() > 0.5).astype(np.uint8)
        y_np = (y.detach().cpu().numpy() > 0.5).astype(np.uint8)
        return calc_metrics(y_np, y_hat)

    def _train_local_epoch(self, opt, scaler):
        """One pass over the training loader; returns batch-averaged (loss, dice, iou, acc)."""
        self.model.train()
        tot_loss = tot_d = tot_i = tot_a = 0.0
        nb = 0

        for batch in self.train_loader:
            x = batch["x"].to(DEVICE)
            y = batch["y"].to(DEVICE)
            opt.zero_grad(set_to_none=True)

            # Mixed precision only on CUDA; the scaler is disabled on CPU.
            with torch.amp.autocast("cuda", enabled=(DEVICE.type == "cuda")):
                pred = self.model(x)
                loss = criterion(pred, y)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            with torch.no_grad():
                d, i, a = self._binary_metrics(pred, y)

            tot_loss += float(loss.item())
            tot_d += d
            tot_i += i
            tot_a += a
            nb += 1

        nb = max(nb, 1)  # avoid division by zero on an empty loader
        return tot_loss / nb, tot_d / nb, tot_i / nb, tot_a / nb

    def _run_validation(self):
        """Evaluate the current model on the validation loader; returns batch-averaged (loss, dice, iou, acc)."""
        self.model.eval()
        tot_loss = tot_d = tot_i = tot_a = 0.0
        nb = 0

        with torch.no_grad():
            for batch in self.val_loader:
                x = batch["x"].to(DEVICE)
                y = batch["y"].to(DEVICE)
                pred = self.model(x)

                loss = float(criterion(pred, y).item())
                d, i, a = self._binary_metrics(pred, y)

                tot_loss += loss
                tot_d += d
                tot_i += i
                tot_a += a
                nb += 1

        nb = max(nb, 1)
        return tot_loss / nb, tot_d / nb, tot_i / nb, tot_a / nb

    def fit(self, parameters, config):
        """Train locally starting from ``parameters``.

        Config keys read: ``local_epochs`` (default 1), ``lr`` (default 1e-3) and
        ``round`` (default 0, used for checkpoint and log tagging).

        Returns:
            (weights of the best local epoch, number of training samples, metrics dict).
            If no epoch qualified as best, the weights after the last epoch are returned.
        """
        # --- MEASURE INCOMING ---
        incoming_size = sum(p.nbytes for p in parameters)

        set_parameters(self.model, parameters)

        epochs = int(config.get("local_epochs", 1))
        lr = float(config.get("lr", 1e-3))
        rnd = int(config.get("round", 0))

        # Fresh optimizer state every round (it is not part of the federated payload).
        opt = optim.AdamW(self.model.parameters(), lr=lr)
        scaler = torch.amp.GradScaler("cuda", enabled=(DEVICE.type == "cuda"))

        best_state = None
        best_val_loss = float("inf")
        best_val_dice = -1.0
        best_epoch_idx = -1
        epoch_logs = []

        for epoch_idx in range(1, epochs + 1):
            tr_loss, tr_dice, tr_iou, tr_acc = self._train_local_epoch(opt, scaler)
            va_loss, va_dice, va_iou, va_acc = self._run_validation()

            epoch_logs.append(
                {
                    "epoch": int(epoch_idx),
                    "train_loss": float(tr_loss),
                    "train_dice": float(tr_dice),
                    "train_iou": float(tr_iou),
                    "train_acc": float(tr_acc),
                    "val_loss": float(va_loss),
                    "val_dice": float(va_dice),
                    "val_iou": float(va_iou),
                    "val_acc": float(va_acc),
                }
            )

            # An epoch is "best" only if it improves both loss and Dice.
            if (va_loss < best_val_loss) and (va_dice > best_val_dice):
                best_val_loss = va_loss
                best_val_dice = va_dice
                best_state = copy.deepcopy(self.model.state_dict())
                best_epoch_idx = epoch_idx

        if best_state is not None:
            self.model.load_state_dict(best_state)

        for ep in epoch_logs:
            ep["best_epoch"] = (ep["epoch"] == best_epoch_idx)

        # Flower metrics must be scalars/strings, so the epoch history travels as JSON.
        train_metrics = {
            "cid": int(self.cid),
            "best_epoch": int(best_epoch_idx),
            "best_val_loss": float(best_val_loss),
            "best_val_dice": float(best_val_dice),
            "per_epoch": json.dumps(epoch_logs),
            "run_name": self.run_name,
            "model_name": self.model_name,
            "encoder_name": self.encoder_name,
        }

        maybe_save_best(self.run_dir, self.cid, best_val_loss, best_val_dice, best_epoch_idx, rnd, self.model)

        # --- PREPARE & MEASURE OUTGOING ---
        out_params = get_parameters(self.model)

        # Timed hook for payload processing; currently a pass-through, so the measured
        # overhead is ~0. NOTE(review): presumably where quantization is plugged in.
        start_overhead = time.time()
        final_params_to_send = out_params
        end_overhead = time.time()

        outgoing_size = sum(p.nbytes for p in final_params_to_send)

        print(f"Client {self.cid} Round {rnd}: Incoming {incoming_size/1024:.2f} KB | Outgoing {outgoing_size/1024:.2f} KB | Overhead: {end_overhead - start_overhead:.6f}s")
        log_file_path = os.path.join(self.run_dir, f"client_{self.cid}_transfer_log.txt")
        log_transfer_metrics(log_file_path, self.cid, rnd, incoming_size, outgoing_size, end_overhead - start_overhead)

        return final_params_to_send, self.ntr, train_metrics

    def evaluate(self, parameters, config):
        """Score ``parameters`` on the local validation split.

        Returns:
            (mean validation loss, number of validation samples, metrics dict).
        """
        set_parameters(self.model, parameters)
        loss, dice, iou, acc = self._run_validation()

        metrics = {
            "loss": loss,
            "dice": dice,
            "iou": iou,
            "acc": acc,
            "cid": int(self.cid),
            "run_name": self.run_name,
        }

        return metrics["loss"], self.nva, metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Start one federated segmentation client.")
    parser.add_argument("--cid", type=int, required=True, help="Client id (selects the client_<cid> splits).")
    parser.add_argument("--server", default="0.0.0.0:8080", help="Flower server address as host:port.")
    args = parser.parse_args()

    # start_numpy_client is deprecated in Flower; convert the NumPyClient with
    # to_client() and use the generic start_client entry point instead.
    fl.client.start_client(
        server_address=args.server,
        client=SegClient(args.cid).to_client(),
    )

Overwriting fl_client.py


In [None]:
%%writefile fl_sim_colab.py
import os
import csv
import json
import torch
import flwr as fl
import time
from flwr.common import FitIns
import torchio as tio
from fl_client import SegClient

import logging
import warnings
logging.getLogger("flwr").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

def ensure_csv(path: str, header: list[str]):
    """Create ``path`` containing only ``header`` if the file does not exist yet.

    Parent directories are created as needed. A bare filename (no directory part)
    is written in the current working directory. Passing an empty dirname to
    ``os.makedirs`` would raise ``FileNotFoundError``, hence the guard.
    """
    if os.path.isfile(path):
        return
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", newline="") as f:
        csv.writer(f).writerow(header)


def append_row(path: str, row: list):
    """Append a single CSV record to ``path`` (the file is created if missing)."""
    with open(path, "a", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(row)


def log_server_metrics(file_path, rnd, duration):
    """Append the server-side aggregation time of round ``rnd`` to ``file_path``.

    The parent directory is created when missing.
    """
    parent = os.path.dirname(file_path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent)

    line = f"Round {rnd} Aggregation Time: {duration:.4f} seconds\n"
    with open(file_path, "a") as fh:
        fh.write(line)


class PerClientLoggingFedAvg(fl.server.strategy.FedAvg):
    """FedAvg that tags fit configs with the round and logs per-client training curves.

    * ``configure_fit`` adds ``round=<server_round>`` to every client's fit config.
    * ``aggregate_fit`` does three things:
      - times the parent FedAvg aggregation;
      - appends that time to ``<metrics_dir>/server_aggregation_log.txt``;
      - writes each client's ``per_epoch`` history (a JSON string in the fit
        metrics) to ``<metrics_dir>/metrics_client_<cid>.csv``, marking the best
        epoch with ``x``.
    """

    def __init__(self, metrics_dir: str, **kwargs):
        super().__init__(**kwargs)
        self.metrics_dir = metrics_dir
        # CSV columns; the per-epoch keys must match those produced by the client.
        self.header = [
            "round","epoch","train_loss","train_dice","train_iou","train_acc",
            "val_loss","val_dice","val_iou","val_acc","best_epoch",
        ]

    def configure_fit(self, server_round, parameters, client_manager):
        """Delegate to FedAvg, then add ``round=server_round`` to each FitIns config."""
        items = super().configure_fit(server_round, parameters, client_manager)
        out = []
        for it in items:
            # Normally (ClientProxy, FitIns) pairs; a bare FitIns is tolerated too.
            if isinstance(it, tuple):
                client, fitins = it
            else:
                client, fitins = None, it

            # Copy so a config dict shared across clients is not mutated in place.
            cfg = dict(fitins.config)
            cfg["round"] = server_round
            new_fitins = FitIns(fitins.parameters, cfg)
            out.append((client, new_fitins) if client is not None else new_fitins)
        return out

    @staticmethod
    def _parse_per_epoch(raw):
        """Decode a client's ``per_epoch`` JSON; return [] when it is missing or not a list."""
        try:
            decoded = json.loads(raw)
        except (TypeError, ValueError):
            return []
        return decoded if isinstance(decoded, list) else []

    @staticmethod
    def _is_best(epoch, best_epoch):
        """True if ``epoch`` converts to an int equal to ``best_epoch``; False for blank/invalid values."""
        try:
            return int(epoch) == best_epoch
        except (TypeError, ValueError):
            return False

    def aggregate_fit(self, rnd, results, failures):
        """Run FedAvg aggregation, log its duration, and dump each client's epoch history to CSV."""
        # --- START TIMER ---
        start_time = time.time()

        agg = super().aggregate_fit(rnd, results, failures)

        # --- END TIMER ---
        end_time = time.time()
        duration = end_time - start_time

        print(f"Round {rnd} Aggregation Time: {duration:.4f} seconds")
        log_file_path = os.path.join(self.metrics_dir, "server_aggregation_log.txt")
        log_server_metrics(log_file_path, rnd, duration)

        for client_proxy, fit_res in results:
            m = fit_res.metrics or {}
            cid = str(m.get("cid", client_proxy.cid))

            client_csv = os.path.join(self.metrics_dir, f"metrics_client_{cid}.csv")
            ensure_csv(client_csv, self.header)

            best_epoch = int(m.get("best_epoch", -1))

            for ep in self._parse_per_epoch(m.get("per_epoch", "[]")):
                if not isinstance(ep, dict):
                    continue  # malformed entry; skip rather than abort aggregation
                epoch = ep.get("epoch", "")
                row = [
                    rnd,
                    epoch,
                    ep.get("train_loss", ""),
                    ep.get("train_dice", ""),
                    ep.get("train_iou", ""),
                    ep.get("train_acc", ""),
                    ep.get("val_loss", ""),
                    ep.get("val_dice", ""),
                    ep.get("val_iou", ""),
                    ep.get("val_acc", ""),
                    "x" if self._is_best(epoch, best_epoch) else "",
                ]
                append_row(client_csv, row)

        return agg


def run_one_experiment(model_name: str, encoder_name: str, num_rounds=5, local_epochs=5, lr=1e-3):
    """Simulate one federated training run (3 clients, FedAvg) for a model/encoder pair.

    Per-client CSVs and the server aggregation log are written under
    ``AITDM/<model>__<encoder>/metrics``. Each client gets one whole GPU when CUDA
    is available, otherwise CPU only.
    """
    safe_name = f"{model_name}__{encoder_name}".replace("/", "-")
    metrics_dir = os.path.join("AITDM", safe_name, "metrics")
    os.makedirs(metrics_dir, exist_ok=True)

    def client_fn(cid: str):
        return SegClient(int(cid), model_name=model_name, encoder_name=encoder_name).to_client()

    def fit_config(rnd):
        # Same local schedule every round; the strategy adds the round number itself.
        return {"local_epochs": local_epochs, "lr": lr}

    strategy = PerClientLoggingFedAvg(
        metrics_dir=metrics_dir,
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_fit_clients=3,
        min_evaluate_clients=3,
        min_available_clients=3,
        on_fit_config_fn=fit_config,
    )

    gpu_share = 1.0 if torch.cuda.is_available() else 0.0

    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=3,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
        client_resources={"num_cpus": 1, "num_gpus": gpu_share},
        ray_init_args={"include_dashboard": False},
    )


if __name__ == "__main__":
    # (architecture, encoder) pairs trained one after another with identical FL settings.
    experiments = [
        ("unet", "resnet50"),
        ("unet", "mit_b3"),
        ("deeplabv3plus", "timm-mobilenetv3_small_100"),
    ]

    for model_name, encoder_name in experiments:
        print(f"\n=== Running: {model_name} + {encoder_name} ===")
        t0 = time.time()
        run_one_experiment(model_name, encoder_name, num_rounds=15, local_epochs=10, lr=1e-3)
        elapsed = time.time() - t0
        print(f"Experiment completed in {elapsed:.2f} seconds.\n")

Overwriting fl_sim_colab.py


In [None]:
# Start from a clean slate: remove the per-client ensemble logs (client_<cid> folders)
# and every experiment's checkpoints/metrics, then run all federated experiments.
!rm -rf /content/AITDM/client_0/*
!rm -rf /content/AITDM/client_1/*
!rm -rf /content/AITDM/client_2/*
!rm -rf /content/AITDM/deeplabv3plus__timm-mobilenetv3_small_100/checkpoints/*
!rm -rf /content/AITDM/unet__resnet50/checkpoints/*
!rm -rf /content/AITDM/unet__mit_b3/checkpoints/*
!rm -rf /content/AITDM/deeplabv3plus__timm-mobilenetv3_small_100/metrics/*
!rm -rf /content/AITDM/unet__resnet50/metrics/*
!rm -rf /content/AITDM/unet__mit_b3/metrics/*
!python fl_sim_colab.py


=== Running: unet + resnet50 ===
2026-01-16 15:38:33.168056: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-16 15:38:33.185939: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768577913.207091    5124 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768577913.213619    5124 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768577913.230077    5124 computation_placer.cc:177] computation placer already registered. Pleas

In [None]:
import os, json, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import logging
from fl_client import TRANSFORMS

from seg_data import (
    ImageOnlyGliomaDataset,
    SubsetByPIDs,
    image_only_collate_fn,
    DATA_ROOT,
    METADATA_DF_PATH,
    USE_ATLAS,
    calc_metrics,
)

# Only show ERROR-level messages from timm's model-builder logger.
logging.getLogger("timm.models._builder").setLevel(logging.ERROR)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
# Probability cut-off used to binarize the ensemble's sigmoid output.
THRESHOLD = 0.5

# Same split layout as the federated clients (client_<cid>/val_pids.json).
CLIENT_DIR = "/content/client"
BATCH_SIZE = 8
NUM_WORKERS = 2

# (model_name, encoder_name) pairs whose per-client best checkpoints are ensembled.
ENSEMBLE_CFGS = [
    ("unet", "resnet50"),
    ("unet", "mit_b3"),
    ("deeplabv3plus", "timm-mobilenetv3_small_100"),
]

# Member weights derive from each member's best validation Dice:
# "power" -> dice**WEIGHT_POWER (strongly favors the best member), "linear" -> dice.
# WEIGHT_EPS is added to every raw weight so none is exactly zero.
WEIGHT_MODE = "power"
WEIGHT_POWER = 14.0
WEIGHT_EPS = 1e-6


def set_seed(seed=42):
    """Seed Python ``random``, NumPy, PyTorch CPU and all CUDA RNGs with ``seed``."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


# Seed all RNGs once when this cell runs.
set_seed(SEED)

# Same 50/50 BCE + Dice objective as in federated training, so ensemble losses are comparable.
bce = nn.BCEWithLogitsLoss()
dice_loss = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)


def criterion(logits, y):
    """Equal-weight BCE-with-logits + Dice loss evaluated on raw ``logits``."""
    bce_part = bce(logits, y)
    dice_part = dice_loss(logits, y)
    return 0.5 * bce_part + 0.5 * dice_part


def get_val_loader(cid: int):
    """Deterministic (unshuffled, untransformed) validation loader for client ``cid``.

    Returns:
        (loader, number of validation samples)
    """
    dataset = ImageOnlyGliomaDataset(
        METADATA_DF_PATH, DATA_ROOT, use_atlas=USE_ATLAS, exclude_ids=["PatientID_0191"]
    )
    split_file = os.path.join(CLIENT_DIR, f"client_{cid}", "val_pids.json")
    with open(split_file) as fh:
        val_pids = json.load(fh)

    val_ds = SubsetByPIDs(dataset, val_pids)
    loader = torch.utils.data.DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        collate_fn=lambda b: image_only_collate_fn(b, use_atlas=USE_ATLAS),
        generator=torch.Generator().manual_seed(SEED),
    )
    return loader, len(val_ds)


def run_name(model_name: str, encoder_name: str) -> str:
    """Folder name ``<model>__<encoder>`` with '/' replaced by '-'."""
    joined = f"{model_name}__{encoder_name}"
    return "-".join(joined.split("/"))


def build_model(model_name: str, encoder_name: str, encoder_weights="imagenet"):
    """Create a 1-class smp segmentation model on ``DEVICE``.

    ``model_name`` is ``"unet"`` or a DeepLabV3+ alias (case-insensitive).
    There are 2 input channels when ``USE_ATLAS`` is set, else 1.

    Raises:
        ValueError: for an unknown ``model_name``.
    """
    channels = 2 if USE_ATLAS else 1
    key = model_name.lower()
    if key == "unet":
        builder = smp.Unet
    elif key in ("deeplabv3plus", "deeplabv3+", "dlv3p"):
        builder = smp.DeepLabV3Plus
    else:
        raise ValueError(f"Unknown model_name={model_name}")

    net = builder(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=channels,
        classes=1,
    )
    return net.to(DEVICE)


def ckpt_path(model_name: str, encoder_name: str, cid: int) -> str:
    """Location of client ``cid``'s best local checkpoint for this model/encoder pair."""
    folder = os.path.join("AITDM", run_name(model_name, encoder_name), "checkpoints")
    return os.path.join(folder, f"client_{cid}_best.pt")


def best_json_path(model_name: str, encoder_name: str, cid: int) -> str:
    """Location of the JSON record describing client ``cid``'s best checkpoint."""
    folder = os.path.join("AITDM", run_name(model_name, encoder_name), "checkpoints")
    return os.path.join(folder, f"client_{cid}_best.json")


def load_model(model_name: str, encoder_name: str, cid: int):
    """Rebuild a model and load client ``cid``'s best local checkpoint; returns it in eval mode.

    The network is built with ``encoder_weights=None``. The strict
    ``load_state_dict`` below overwrites every weight, so fetching ImageNet
    weights first would only cost a download and time.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    path = ckpt_path(model_name, encoder_name, cid)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Missing checkpoint: {path}")
    # build_model already moves the model to DEVICE.
    m = build_model(model_name, encoder_name, encoder_weights=None)
    # Checkpoints are plain state_dicts, so the restricted weights-only unpickler suffices.
    sd = torch.load(path, map_location="cpu", weights_only=True)
    m.load_state_dict(sd, strict=True)
    m.eval()
    return m


def load_best_val_dice(model_name: str, encoder_name: str, cid: int) -> float:
    """Read the best validation Dice recorded for client ``cid`` (0.0 if the key is absent).

    Raises:
        FileNotFoundError: if the best-checkpoint JSON record does not exist.
    """
    record_path = best_json_path(model_name, encoder_name, cid)
    if not os.path.isfile(record_path):
        raise FileNotFoundError(f"Missing best json: {record_path}")
    with open(record_path, "r") as fh:
        record = json.load(fh)
    return float(record.get("val_dice", 0.0))


def get_client_weights(cid: int, cfgs, mode="power", power=2.0, eps=1e-6):
    """Turn each ensemble member's best validation Dice on client ``cid`` into a weight.

    Args:
        cid: client whose best-checkpoint records are read.
        cfgs: iterable of ``(model_name, encoder_name)`` pairs.
        mode: ``"linear"`` (weight ∝ dice) or ``"power"`` (weight ∝ dice**power).
            Negative Dice values are clipped to 0 first.
        power: exponent used in ``"power"`` mode.
        eps: added to every raw weight so none is exactly zero.

    Returns:
        (normalized weights as a list of floats, raw Dice values as read).

    Raises:
        ValueError: for an unknown ``mode`` (raised after the records are read).
    """
    dices = [load_best_val_dice(mn, enc, cid) for (mn, enc) in cfgs]
    clipped = np.clip(np.array(dices, dtype=np.float32), 0.0, None)

    if mode == "power":
        raw = np.power(clipped, power)
    elif mode == "linear":
        raw = clipped
    else:
        raise ValueError("mode must be 'linear' or 'power'")

    raw = raw + eps
    return (raw / raw.sum()).tolist(), dices


@torch.no_grad()
def ensemble_forward_logits(x, models, weights, target_hw=None):
    """Weighted average of the members' logits for input ``x``.

    Weights are renormalized to sum to ~1. If ``target_hw`` is given, any member
    output whose last two dims differ is bilinearly resized to it. Returns None
    when ``models`` is empty.
    """
    norm = np.asarray(weights, dtype=np.float32)
    norm = norm / (norm.sum() + 1e-8)

    combined = None
    for model, coeff in zip(models, norm):
        out = model(x)
        if target_hw is not None and out.shape[-2:] != target_hw:
            out = F.interpolate(out, size=target_hw, mode="bilinear", align_corners=False)
        term = out * float(coeff)
        combined = term if combined is None else combined + term
    return combined


@torch.no_grad()
def eval_ensemble_on_client(cid: int, threshold=0.5):
    """Evaluate the weighted ensemble of ``ENSEMBLE_CFGS`` on client ``cid``'s validation split.

    Members are that client's best local checkpoints. Weights come from their best
    validation Dice (see ``get_client_weights``). Metrics are averaged per batch.

    Returns:
        dict with cid, nva, loss, dice, iou, acc, weights and best_dices.
    """
    val_loader, nva = get_val_loader(cid)
    models = [load_model(mn, enc, cid) for (mn, enc) in ENSEMBLE_CFGS]

    weights, best_dices = get_client_weights(
        cid, ENSEMBLE_CFGS, mode=WEIGHT_MODE, power=WEIGHT_POWER, eps=WEIGHT_EPS
    )

    sums = {"loss": 0.0, "dice": 0.0, "iou": 0.0, "acc": 0.0}
    n_batches = 0

    for batch in val_loader:
        x = batch["x"].to(DEVICE)
        y = batch["y"].to(DEVICE)

        # Resize member outputs to the mask resolution before combining.
        logits = ensemble_forward_logits(x, models=models, weights=weights, target_hw=y.shape[-2:])
        batch_loss = float(criterion(logits, y).item())

        pred_mask = (torch.sigmoid(logits).cpu().numpy() > threshold).astype(np.uint8)
        true_mask = (y.cpu().numpy() > 0.5).astype(np.uint8)
        d, i, a = calc_metrics(true_mask, pred_mask)

        sums["loss"] += batch_loss
        sums["dice"] += d
        sums["iou"] += i
        sums["acc"] += a
        n_batches += 1

    n_batches = max(n_batches, 1)
    return {
        "cid": int(cid),
        "nva": int(nva),
        "loss": sums["loss"] / n_batches,
        "dice": sums["dice"] / n_batches,
        "iou": sums["iou"] / n_batches,
        "acc": sums["acc"] / n_batches,
        "weights": weights,
        "best_dices": best_dices,
    }


def log_client_metrics(file_path, r, bd, ww):
    """Print one client's ensemble validation summary and append it to ``file_path``.

    Args:
        file_path: text log; its parent directory is created if missing.
        r: result dict with keys cid, nva, loss, dice, iou, acc.
        bd: per-member best validation Dice values.
        ww: per-member ensemble weights.
    """
    parent = os.path.dirname(file_path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent)

    head = f"Client {r['cid']} (n={r['nva']}): "
    stats = f"loss={r['loss']:.4f} dice={r['dice']:.4f} iou={r['iou']:.4f} acc={r['acc']:.4f}\n"
    dice_txt = ['%.4f' % x for x in bd]
    weight_txt = ['%.3f' % x for x in ww]
    log_message = head + stats + f"  best_dice={dice_txt} -> weights={weight_txt}\n"

    print(log_message, end='')

    with open(file_path, "a") as fh:
        fh.write(log_message)

def main():
    """Evaluate the ensemble on clients 0-2, logging each to /content/AITDM/client_<cid>/val_metrics_ensemble.txt."""
    print(f"[Ensemble-3 | threshold={THRESHOLD} | weight_mode={WEIGHT_MODE} | power={WEIGHT_POWER}]\n")
    for cid in range(3):
        result = eval_ensemble_on_client(cid, threshold=THRESHOLD)
        log_path = os.path.join('/content/AITDM', f"client_{cid}", "val_metrics_ensemble.txt")
        log_client_metrics(log_path, result, result["best_dices"], result["weights"])


if __name__ == "__main__":
    main()

In [None]:
# Back up this run's AITDM outputs to Google Drive under AITDM/NORMAL_3.
!mkdir -p /content/drive/MyDrive/AITDM
!cp -rf /content/AITDM /content/drive/MyDrive/AITDM/NORMAL_3

In [None]:
# Release the Colab runtime once results have been backed up.
from google.colab import runtime
runtime.unassign()

# **M2 - FL with ensemble and quantization**

In [None]:
%%writefile seg_data.py
import os, pickle, numpy as np, torch
from typing import List
from torch.utils.data import Dataset, DataLoader
import torchio as tio

# Global paths and configuration
# Root folder with one sub-directory per patient holding <pid>_mri.npy, <pid>_tumor.npy
# and (for the atlas variant) <pid>_regions.npy.
DATA_ROOT = "/content/Preprocessed-Data"
# Pickled DataFrame with a Patient_ID column, relative to the working directory.
METADATA_DF_PATH = "cleaned_df.pkl"
# When True, the atlas "regions" map is stacked as a second input channel.
USE_ATLAS = True
# Patient ids meant to be excluded. NOTE(review): the visible callers pass this list
# literally via exclude_ids instead of referencing this constant.
EXCLUDE_IDS = ["PatientID_0191"]

# Dataset that loads MRI, tumor mask and optional atlas for each patient
class ImageOnlyGliomaDataset(Dataset):
    """Per-patient glioma samples: normalized MRI, binary tumor mask and optional atlas.

    A patient from the metadata ``Patient_ID`` column is kept only if
    ``<data_root>/<pid>/<pid>_mri.npy`` and ``<pid>_tumor.npy`` exist, plus
    ``<pid>_regions.npy`` when ``use_atlas`` is set. Patients are ordered by sorted id.

    Args:
        metadata_df_path: pickled DataFrame with a ``Patient_ID`` column.
            NOTE: loaded with ``pickle``; only use trusted files.
        data_root: directory containing one folder per patient.
        use_atlas: also load and return the ``regions`` map.
        exclude_ids: patient ids dropped before the file check.
        transform: optional TorchIO transform applied jointly to all images of a sample.
    """

    def __init__(
        self,
        metadata_df_path,
        data_root,
        use_atlas=False,
        exclude_ids=None,
        transform=None,
    ):
        with open(metadata_df_path, "rb") as f:
            df = pickle.load(f)

        if exclude_ids is None:
            exclude_ids = []

        self.df = df[~df["Patient_ID"].isin(exclude_ids)].reset_index(drop=True)
        self.data_root = data_root
        self.use_atlas = use_atlas
        self.transform = transform

        # Keep only patients whose required .npy files are all present, in sorted id order.
        self.patient_ids = [
            pid for pid in sorted(self.df["Patient_ID"].tolist()) if self._has_all_files(pid)
        ]

    def _npy_path(self, pid, kind):
        """Path of ``<data_root>/<pid>/<pid>_<kind>.npy``."""
        return os.path.join(self.data_root, pid, f"{pid}_{kind}.npy")

    def _has_all_files(self, pid):
        """True if the MRI and tumor files (and regions, when ``use_atlas``) exist for ``pid``."""
        kinds = ["mri", "tumor"] + (["regions"] if self.use_atlas else [])
        return all(os.path.isfile(self._npy_path(pid, kind)) for kind in kinds)

    def __len__(self):
        return len(self.patient_ids)

    @staticmethod
    def _minmax(x):
        """Rescale ``x`` to [0, 1] as float32; a constant array becomes all zeros."""
        x = x.astype(np.float32)
        mn, mx = np.min(x), np.max(x)
        if mx > mn:
            return (x - mn) / (mx - mn)
        return np.zeros_like(x, dtype=np.float32)

    def _to_torchio_format(self, arr):
        """
        Convert a 2-D (H, W) array to (1, H, W, 1) by adding a leading channel axis and
        a trailing depth axis, since TorchIO images are 4-D. Other ranks are returned unchanged.
        """
        if arr.ndim == 2:
            return arr[np.newaxis, ..., np.newaxis]
        return arr

    def _apply_transform(self, pid, mri, tumor, regions):
        """Run ``self.transform`` jointly on all images of one patient via a TorchIO Subject."""
        subject_dict = {
            'mri': tio.ScalarImage(tensor=self._to_torchio_format(mri)),
            'tumor': tio.LabelMap(tensor=self._to_torchio_format(tumor)),
        }
        if regions is not None:
            subject_dict['regions'] = tio.ScalarImage(tensor=self._to_torchio_format(regions))

        subject = self.transform(tio.Subject(subject_dict))

        # Drop the channel and depth axes added by _to_torchio_format
        # (assumes the transform preserves the 4-D layout).
        sample = {
            "patient_id": pid,
            "mri": subject['mri'].data[0, ..., 0].numpy(),
            "tumor": subject['tumor'].data[0, ..., 0].numpy(),
        }
        if regions is not None:
            sample["regions"] = subject['regions'].data[0, ..., 0].numpy()
        return sample

    def __getitem__(self, idx):
        pid = self.patient_ids[idx]

        # Each array is normalized exactly once: MRI/regions min-max to [0, 1],
        # tumor thresholded at 0.5 into a {0, 1} float mask.
        mri = self._minmax(np.load(self._npy_path(pid, "mri")).astype(np.float32))
        tumor = (np.load(self._npy_path(pid, "tumor")).astype(np.float32) > 0.5).astype(np.float32)

        regions = None
        if self.use_atlas:
            regions = self._minmax(np.load(self._npy_path(pid, "regions")).astype(np.float32))

        if self.transform:
            return self._apply_transform(pid, mri, tumor, regions)

        sample = {"patient_id": pid, "mri": mri, "tumor": tumor}
        if self.use_atlas:
            sample["regions"] = regions
        return sample

# Collate function to build batched tensors and patient ID list
def image_only_collate_fn(batch, use_atlas=USE_ATLAS):
    """Stack per-patient samples into model-ready tensors.

    Returns a dict with:
      * ``x``: float MRI, concatenated with the atlas along dim 1 when ``use_atlas``;
      * ``y``: float tumor masks;
      * ``pid``: list of patient ids.
    A channel axis is inserted at dim 1. This assumes each array is 2-D (H, W);
    TODO confirm for transformed samples.
    """
    def _stack(key):
        return torch.stack([torch.tensor(item[key]) for item in batch]).unsqueeze(1)

    mri = _stack("mri")
    y = _stack("tumor")

    if not use_atlas:
        x = mri.float()
    else:
        x = torch.cat([mri.float(), _stack("regions").float()], dim=1)

    return {"x": x, "y": y.float(), "pid": [item["patient_id"] for item in batch]}

# Dataset wrapper that restricts to a subset of patient IDs
class SubsetByPIDs(Dataset):
    """View of ``full_dataset`` restricted to the patients in ``pid_list``.

    Items come back in ``pid_list`` order. IDs absent from
    ``full_dataset.patient_ids`` are skipped silently.
    """

    def __init__(self, full_dataset: ImageOnlyGliomaDataset, pid_list: List[str]):
        self.ds = full_dataset
        # Map patient id -> index in the wrapped dataset (last occurrence wins).
        position = dict(zip(self.ds.patient_ids, range(len(self.ds.patient_ids))))
        self.indices = []
        for pid in pid_list:
            if pid in position:
                self.indices.append(position[pid])

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        source_idx = self.indices[i]
        return self.ds[source_idx]

# Compute Dice, IoU and accuracy for binary masks
def calc_metrics(y_true, y_pred):
    """Compute Dice, IoU and pixel accuracy between two binary masks.

    Both inputs are cast to ``uint8`` and flattened, so any shape works as long as
    both hold the same number of elements. After the cast, every non-zero value
    counts as foreground. This avoids the bitwise-AND pitfall where e.g. 2 & 1 == 0.

    If both masks are empty, the prediction is perfect, so Dice and IoU are
    reported as 1.0. The epsilon-guarded division alone would give 0.0.

    Returns:
        (dice, iou, accuracy) as Python floats.
    """
    t = np.asarray(y_true).astype(np.uint8).reshape(-1) > 0
    p = np.asarray(y_pred).astype(np.uint8).reshape(-1) > 0

    inter = int(np.count_nonzero(t & p))
    n_true = int(np.count_nonzero(t))
    n_pred = int(np.count_nonzero(p))
    acc = float((t == p).mean())

    if n_true + n_pred == 0:
        # Empty ground truth and empty prediction: perfect agreement.
        return 1.0, 1.0, acc

    dice = (2.0 * inter) / (n_true + n_pred + 1e-8)
    iou = inter / (n_true + n_pred - inter + 1e-8)
    return float(dice), float(iou), acc

In [None]:
%%writefile fl_client.py
import os
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import argparse
import json
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import flwr as fl
import copy
import segmentation_models_pytorch as smp
import random
import torchio as tio
import sys
import time

from seg_data import (
    ImageOnlyGliomaDataset,
    SubsetByPIDs,
    image_only_collate_fn,
    calc_metrics,
    DATA_ROOT,
    METADATA_DF_PATH,
    USE_ATLAS,
)

# Folder with one sub-folder per client holding train_pids.json / val_pids.json
# (read by get_loaders).
CLIENT_DIR = "/content/client"

# DataLoader settings shared by the train and validation loaders.
BATCH_SIZE = 8
NUM_WORKERS = 2
# Base seed; each client offsets it by its cid (SegClient.__init__).
SEED = 42

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Architecture used when SegClient / get_model get no explicit arguments.
DEFAULT_MODEL_NAME = "unet"
DEFAULT_ENCODER_NAME = "timm-mobilenetv3_small_100"
DEFAULT_ENCODER_WEIGHTS = "imagenet"
# Optional TorchIO transform handed to the dataset by get_loaders; None = no transform.
TRANSFORMS = None


def _run_name(model_name: str, encoder_name: str) -> str:
    return f"{model_name}__{encoder_name}".replace("/", "-")


def seed_everything(seed: int, deterministic: bool = True) -> None:
    """Seed every RNG used in training and optionally force deterministic kernels.

    Args:
        seed: applied to Python's ``random``, NumPy's global RNG and torch
            (CPU and all CUDA devices); also exported as PYTHONHASHSEED.
        deterministic: when true, disable cuDNN autotuning and call
            ``torch.use_deterministic_algorithms(True)``.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    # Best effort: any failure to set the matmul precision is ignored.
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)


def seed_worker(worker_id: int) -> None:
    """DataLoader ``worker_init_fn``: seed NumPy and ``random`` for one worker.

    The seed is torch's initial seed for the current process plus
    ``worker_id``, reduced into the 32-bit range NumPy accepts.
    """
    derived = torch.initial_seed() + worker_id
    derived %= 2**32
    np.random.seed(derived)
    random.seed(derived)


def log_transfer_metrics(file_path, cid, rnd, incoming, outgoing, overhead):
    """Append one line of per-round transfer metrics for a client to ``file_path``.

    Args:
        file_path: log file; its parent directory is created when missing.
        cid: client id.
        rnd: federated round number.
        incoming: received payload size in bytes (logged in KB).
        outgoing: sent payload size in bytes (logged in KB).
        overhead: seconds spent preparing the outgoing payload.
    """
    parent = os.path.dirname(file_path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent)

    line = (f"Client {cid} Round {rnd}: "
            f"Incoming {incoming/1024:.2f} KB | "
            f"Outgoing {outgoing/1024:.2f} KB | "
            f"Overhead: {overhead:.6f}s\n")

    with open(file_path, "a") as handle:
        handle.write(line)


def get_model(
    model_name=DEFAULT_MODEL_NAME,
    encoder_name=DEFAULT_ENCODER_NAME,
    encoder_weights=DEFAULT_ENCODER_WEIGHTS,
):
    """Build a one-class segmentation network and move it to DEVICE.

    The input has 2 channels (MRI + atlas regions) when USE_ATLAS is set and
    1 otherwise. ``model_name`` is case-insensitive.

    Raises:
        ValueError: for a model name other than unet / deeplabv3plus aliases.
    """
    architectures = {
        "unet": smp.Unet,
        "deeplabv3plus": smp.DeepLabV3Plus,
        "deeplabv3+": smp.DeepLabV3Plus,
        "dlv3p": smp.DeepLabV3Plus,
    }
    n_inputs = 2 if USE_ATLAS else 1
    key = model_name.lower()
    if key not in architectures:
        raise ValueError(f"Unknown model_name={model_name}. Use 'unet' or 'deeplabv3plus'.")

    network = architectures[key](
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=n_inputs,
        classes=1,
    )
    return network.to(DEVICE)


def get_loaders(cid: int, base_seed: int, transforms):
    """Build the train and validation DataLoaders for client ``cid``.

    The split is read from ``CLIENT_DIR/client_<cid>/train_pids.json`` and
    ``val_pids.json``. Both loaders use ``seed_worker`` and their own
    ``torch.Generator`` derived from ``base_seed``, so data order is
    reproducible per client.

    NOTE(review): ``transforms`` is attached to the single shared dataset, so it
    is applied to validation samples too. That is harmless while TRANSFORMS is
    None; confirm it is intended before enabling random augmentations.

    Returns:
        (train_loader, val_loader, n_train_samples, n_val_samples)
    """
    full = ImageOnlyGliomaDataset(
        METADATA_DF_PATH,
        DATA_ROOT,
        use_atlas=USE_ATLAS,
        exclude_ids=["PatientID_0191"],
        transform=transforms,
    )

    client_folder = os.path.join(CLIENT_DIR, f"client_{cid}")

    def _read_pids(file_name):
        with open(os.path.join(client_folder, file_name)) as fh:
            return json.load(fh)

    train_pids = _read_pids("train_pids.json")
    val_pids = _read_pids("val_pids.json")
    train_set = SubsetByPIDs(full, train_pids)
    val_set = SubsetByPIDs(full, val_pids)

    shared_kwargs = dict(
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        collate_fn=lambda b: image_only_collate_fn(b, use_atlas=USE_ATLAS),
        worker_init_fn=seed_worker,
        persistent_workers=(NUM_WORKERS > 0),
    )
    # Different fixed offsets give the two loaders distinct generator streams.
    train_loader = DataLoader(
        train_set,
        shuffle=True,
        generator=torch.Generator().manual_seed(base_seed + 12345),
        **shared_kwargs,
    )
    val_loader = DataLoader(
        val_set,
        shuffle=False,
        generator=torch.Generator().manual_seed(base_seed + 67890),
        **shared_kwargs,
    )

    return train_loader, val_loader, len(train_set), len(val_set)


def get_parameters(model):
    """Return every state_dict entry (parameters and buffers) as a CPU NumPy array, in state_dict order."""
    state = model.state_dict()
    return [tensor.detach().cpu().numpy() for tensor in state.values()]


def set_parameters(model, params):
    """Load a list of NumPy arrays into ``model`` in state_dict key order.

    Each array is converted to a float32 tensor. ``load_state_dict`` then
    copies it into the matching parameter or buffer.

    Raises:
        ValueError: if the number of arrays differs from the number of
            state_dict entries. Without this check ``zip`` would silently stop
            early and leave the remaining entries at their previous values.
    """
    sd = model.state_dict()
    params = list(params)
    if len(params) != len(sd):
        raise ValueError(
            f"set_parameters: got {len(params)} arrays for {len(sd)} state_dict entries"
        )
    for k, v in zip(sd.keys(), params):
        sd[k] = torch.tensor(v).float()
    model.load_state_dict(sd, strict=True)


# Training/validation objective: equal-weight mix of pixel-wise BCE and soft
# Dice, both computed from raw logits.
bce = nn.BCEWithLogitsLoss()
dice_loss = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)


def criterion(pred, y):
    """Return ``0.5 * BCE + 0.5 * Dice`` for logits ``pred`` against mask ``y``."""
    bce_term = bce(pred, y)
    dice_term = dice_loss(pred, y)
    return 0.5 * bce_term + 0.5 * dice_term


def maybe_save_best(run_dir, cid, val_loss, val_dice, best_epoch, rnd, model):
    """Save ``model`` as this client's best checkpoint if it beats the stored record.

    The record is ``<run_dir>/checkpoints/client_<cid>_best.json`` and the
    weights are ``client_<cid>_best.pt`` next to it. A result replaces the old
    one only when it has BOTH a lower validation loss AND a higher validation
    Dice. An unreadable record is treated as "no previous best".
    """
    ckpt_dir = os.path.join(run_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)
    record_path = os.path.join(ckpt_dir, f"client_{cid}_best.json")
    weights_path = os.path.join(ckpt_dir, f"client_{cid}_best.pt")

    previous = {"val_loss": float("inf"), "val_dice": -1.0}
    if os.path.isfile(record_path):
        try:
            with open(record_path, "r") as fh:
                previous = json.load(fh)
        except Exception:
            pass  # best effort: keep the "no previous best" defaults

    improved = (val_loss < previous.get("val_loss", float("inf"))) and (val_dice > previous.get("val_dice", -1.0))
    if not improved:
        return

    torch.save(model.state_dict(), weights_path)
    with open(record_path, "w") as fh:
        summary = {
            "round": int(rnd),
            "epoch": int(best_epoch),
            "val_loss": float(val_loss),
            "val_dice": float(val_dice),
        }
        json.dump(summary, fh)


class SegClient(fl.client.NumPyClient):
    """Flower NumPy client that trains one segmentation model on one client's split.

    Each instance does the following:
    - seeds all RNGs with ``SEED + cid``;
    - builds its model with get_model and its loaders with get_loaders;
    - writes checkpoints and transfer logs under ``AITDM/<model>__<encoder>``.

    ``fit`` returns float16 weight arrays.
    """

    def __init__(
        self,
        cid: int,
        model_name=DEFAULT_MODEL_NAME,
        encoder_name=DEFAULT_ENCODER_NAME,
        encoder_weights=DEFAULT_ENCODER_WEIGHTS,
    ):
        self.cid = int(cid)
        self.model_name = model_name
        self.encoder_name = encoder_name
        self.encoder_weights = encoder_weights

        # Offset the global seed by the client id so clients get distinct but
        # reproducible random streams.
        self.base_seed = SEED + self.cid
        seed_everything(self.base_seed, deterministic=True)

        self.run_name = _run_name(model_name, encoder_name)
        self.run_dir = os.path.join("AITDM", self.run_name)

        self.model = get_model(model_name, encoder_name, encoder_weights)
        # All clients currently share TRANSFORMS, so one call covers every cid.
        # Add a per-client choice here if a site needs its own pipeline.
        self.train_loader, self.val_loader, self.ntr, self.nva = get_loaders(
            self.cid, self.base_seed, transforms=TRANSFORMS
        )

    def get_parameters(self, config):
        """Return the current model weights as a list of NumPy arrays."""
        return get_parameters(self.model)

    def _train_one_epoch(self, opt, scaler):
        """Run one pass over the training loader (mixed precision on CUDA).

        Returns:
            (loss, dice, iou, acc) averaged over batches. Metrics use each
            batch's forward-pass predictions, i.e. before its optimizer step.
        """
        self.model.train()
        tot_loss = tot_d = tot_i = tot_a = 0.0
        nb = 0

        for batch in self.train_loader:
            x = batch["x"].to(DEVICE)
            y = batch["y"].to(DEVICE)
            opt.zero_grad(set_to_none=True)

            with torch.amp.autocast("cuda", enabled=(DEVICE.type == "cuda")):
                pred = self.model(x)
                loss = criterion(pred, y)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            with torch.no_grad():
                y_hat = (torch.sigmoid(pred).detach().cpu().numpy() > 0.5).astype(np.uint8)
                y_np = (y.detach().cpu().numpy() > 0.5).astype(np.uint8)
                d, i, a = calc_metrics(y_np, y_hat)

            tot_loss += float(loss.item())
            tot_d += d
            tot_i += i
            tot_a += a
            nb += 1

        nb = max(nb, 1)
        return tot_loss / nb, tot_d / nb, tot_i / nb, tot_a / nb

    def _validate(self):
        """Score the current model on the validation loader without gradients.

        Returns:
            (loss, dice, iou, acc) averaged over batches, not samples. An empty
            loader yields zeros because the divisor is clamped to 1.
        """
        self.model.eval()
        tot_loss = tot_d = tot_i = tot_a = 0.0
        nb = 0

        with torch.no_grad():
            for batch in self.val_loader:
                x = batch["x"].to(DEVICE)
                y = batch["y"].to(DEVICE)
                pred = self.model(x)

                loss = float(criterion(pred, y).item())
                y_hat = (torch.sigmoid(pred).cpu().numpy() > 0.5).astype(np.uint8)
                y_np = (y.cpu().numpy() > 0.5).astype(np.uint8)
                d, i, a = calc_metrics(y_np, y_hat)

                tot_loss += loss
                tot_d += d
                tot_i += i
                tot_a += a
                nb += 1

        nb = max(nb, 1)
        return tot_loss / nb, tot_d / nb, tot_i / nb, tot_a / nb

    def fit(self, parameters, config):
        """Train locally from the received weights and return float16 weights.

        Config keys read:
        - "local_epochs" (default 1);
        - "lr" (default 1e-3);
        - "round" (default 0; e.g. injected by PerClientLoggingFedAvg.configure_fit).

        The weights of the last epoch that lowered the validation loss AND
        raised the validation Dice are restored and passed to maybe_save_best.
        They are then cast to float16 and returned with the training-set size
        and a metrics dict; per-epoch logs are JSON-encoded under "per_epoch".
        """
        incoming_size = sum(p.nbytes for p in parameters)

        set_parameters(self.model, parameters)

        epochs = int(config.get("local_epochs", 1))
        lr = float(config.get("lr", 1e-3))
        rnd = int(config.get("round", 0))

        # A fresh optimizer every round: no optimizer state crosses rounds.
        opt = optim.AdamW(self.model.parameters(), lr=lr)
        scaler = torch.amp.GradScaler("cuda", enabled=(DEVICE.type == "cuda"))

        best_state = None
        best_val_loss = float("inf")
        best_val_dice = -1.0
        best_epoch_idx = -1
        epoch_logs = []

        for epoch_idx in range(1, epochs + 1):
            tr_loss, tr_dice, tr_iou, tr_acc = self._train_one_epoch(opt, scaler)
            val_loss, val_dice, val_iou, val_acc = self._validate()

            epoch_logs.append(
                {
                    "epoch": int(epoch_idx),
                    "train_loss": float(tr_loss),
                    "train_dice": float(tr_dice),
                    "train_iou": float(tr_iou),
                    "train_acc": float(tr_acc),
                    "val_loss": float(val_loss),
                    "val_dice": float(val_dice),
                    "val_iou": float(val_iou),
                    "val_acc": float(val_acc),
                }
            )

            # An epoch counts as best only if it improves BOTH loss and Dice.
            if (val_loss < best_val_loss) and (val_dice > best_val_dice):
                best_val_loss = val_loss
                best_val_dice = val_dice
                best_state = copy.deepcopy(self.model.state_dict())
                best_epoch_idx = epoch_idx

        if best_state is not None:
            self.model.load_state_dict(best_state)

        for ep in epoch_logs:
            ep["best_epoch"] = (ep["epoch"] == best_epoch_idx)

        train_metrics = {
            "cid": int(self.cid),
            "best_epoch": int(best_epoch_idx),
            "best_val_loss": float(best_val_loss),
            "best_val_dice": float(best_val_dice),
            "per_epoch": json.dumps(epoch_logs),
            "run_name": self.run_name,
            "model_name": self.model_name,
            "encoder_name": self.encoder_name,
        }

        maybe_save_best(self.run_dir, self.cid, best_val_loss, best_val_dice, best_epoch_idx, rnd, self.model)

        out_params = get_parameters(self.model)

        # Only the float16 cast (uplink quantization) is timed as "overhead".
        start_overhead = time.time()
        final_params_to_send = [p.astype(np.float16) for p in out_params]
        end_overhead = time.time()

        outgoing_size = sum(p.nbytes for p in final_params_to_send)

        print(f"Client {self.cid} Round {rnd}: Incoming {incoming_size/1024:.2f} KB | Outgoing {outgoing_size/1024:.2f} KB | Overhead: {end_overhead - start_overhead:.6f}s")
        log_file_path = os.path.join(self.run_dir, f"client_{self.cid}_transfer_log.txt")
        log_transfer_metrics(log_file_path, self.cid, rnd, incoming_size, outgoing_size, end_overhead - start_overhead)

        return final_params_to_send, self.ntr, train_metrics

    def evaluate(self, parameters, config):
        """Evaluate the received weights on this client's validation split.

        Returns:
            (loss, n_val_samples, metrics). ``metrics`` holds batch-averaged
            loss/dice/iou/acc plus the client id and run name.
        """
        set_parameters(self.model, parameters)
        loss, dice, iou, acc = self._validate()

        metrics = {
            "loss": loss,
            "dice": dice,
            "iou": iou,
            "acc": acc,
            "cid": int(self.cid),
            "run_name": self.run_name,
        }

        return metrics["loss"], self.nva, metrics


if __name__ == "__main__":
    # Standalone entry point: connect one client to an already-running Flower server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--cid", type=int, required=True, help="client id (selects client_<cid> split)")
    parser.add_argument("--server", default="0.0.0.0:8080", help="Flower server address")
    args = parser.parse_args()

    # start_numpy_client is deprecated in Flower. The documented replacement is
    # start_client with a Client obtained via NumPyClient.to_client().
    fl.client.start_client(
        server_address=args.server,
        client=SegClient(args.cid).to_client(),
    )

In [None]:
%%writefile fl_sim_colab.py
import os
import csv
import json
import torch
import flwr as fl
import time
import numpy as np
from flwr.common import FitIns, parameters_to_ndarrays, ndarrays_to_parameters
from fl_client import SegClient

import logging
import warnings
# Keep simulation output readable: only Flower errors are logged and all
# Python warnings are suppressed.
logging.getLogger("flwr").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

def ensure_csv(path: str, header: list[str]):
    """Create ``path`` containing only a ``header`` row if the file does not exist yet.

    An existing file is left untouched, so rows appended in earlier rounds are
    kept. Parent directories are created as needed; a bare file name (no
    directory part) is written to the current working directory.
    """
    if os.path.isfile(path):
        return
    parent = os.path.dirname(path)
    if parent:
        # os.makedirs("") raises FileNotFoundError, so only create real directories.
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", newline="") as f:
        csv.writer(f).writerow(header)


def append_row(path: str, row: list):
    """Append ``row`` as a single CSV record at the end of ``path``."""
    with open(path, mode="a", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(row)


def log_server_metrics(file_path, rnd, duration, result_size):
    """Append the server's aggregation time and downlink payload size for one round.

    Args:
        file_path: log file; its parent directory is created when missing.
        rnd: round number.
        duration: aggregation wall-clock time in seconds.
        result_size: size in bytes of the parameters sent back to clients;
            non-positive values are logged as 0 KB.
    """
    parent = os.path.dirname(file_path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent)

    if result_size > 0:
        size_kb = result_size / 1024.0
    else:
        size_kb = 0
    msg = f"Round {rnd}: Aggregation Time: {duration:.4f}s | Downlink Payload: {size_kb:.2f} KB\n"

    with open(file_path, "a") as out:
        out.write(msg)


class PerClientLoggingFedAvg(fl.server.strategy.FedAvg):
    """FedAvg variant with per-round logging and float16 downlink.

    On top of FedAvg it:
    - adds the round number to every fit config;
    - casts the aggregated weights to float16 for the downlink;
    - writes each client's per-epoch training metrics to a CSV.

    Args:
        metrics_dir: directory for ``metrics_client_<cid>.csv`` files and
            ``server_aggregation_log.txt``.
        **kwargs: forwarded unchanged to ``FedAvg``.
    """

    def __init__(self, metrics_dir: str, **kwargs):
        super().__init__(**kwargs)
        self.metrics_dir = metrics_dir
        # Column order of every per-client CSV; must match the rows built in aggregate_fit.
        self.header = [
            "round","epoch","train_loss","train_dice","train_iou","train_acc",
            "val_loss","val_dice","val_iou","val_acc","best_epoch",
        ]

    def configure_fit(self, server_round, parameters, client_manager):
        """Delegate to FedAvg, then add ``"round": server_round`` to each FitIns config.

        Items may be ``(client, FitIns)`` pairs or bare ``FitIns``. The shape
        of each item is preserved.
        """
        items = super().configure_fit(server_round, parameters, client_manager)
        out = []
        for it in items:
            if isinstance(it, tuple):
                client, fitins = it
            else:
                client, fitins = None, it

            cfg = dict(fitins.config)
            cfg["round"] = server_round
            new_fitins = FitIns(fitins.parameters, cfg)
            out.append((client, new_fitins) if client is not None else new_fitins)
        return out

    @staticmethod
    def _decode_per_epoch(raw) -> list:
        """Decode a client's JSON-encoded per-epoch log; malformed payloads yield []."""
        try:
            decoded = json.loads(raw)
        except (TypeError, ValueError):
            return []
        if not isinstance(decoded, list):
            return []
        return [ep for ep in decoded if isinstance(ep, dict)]

    @staticmethod
    def _is_best_epoch(epoch, best_epoch) -> bool:
        """True when ``epoch`` parses as an int equal to ``best_epoch``; False for missing or malformed values."""
        try:
            return int(epoch) == best_epoch
        except (TypeError, ValueError):
            return False

    def aggregate_fit(self, rnd, results, failures):
        """Aggregate with FedAvg, cast to float16, and log timing, size and client metrics.

        Returns:
            FedAvg's ``(parameters, metrics)``; the parameters are re-encoded
            as float16 when present.
        """
        start_time = time.time()

        agg_result = super().aggregate_fit(rnd, results, failures)

        downlink_size_bytes = 0
        if agg_result is not None:
            agg_params, agg_metrics = agg_result
            # FedAvg returns (None, {}) when nothing was aggregated (no results,
            # or failures with accept_failures=False). There is nothing to
            # quantize then, and parameters_to_ndarrays(None) would crash.
            if agg_params is not None:
                ndarrays_fp16 = [val.astype(np.float16) for val in parameters_to_ndarrays(agg_params)]
                final_params = ndarrays_to_parameters(ndarrays_fp16)
                downlink_size_bytes = sum(len(t) for t in final_params.tensors)
                agg_result = (final_params, agg_metrics)

        end_time = time.time()

        print(f"Round {rnd} Server: Aggregation {end_time - start_time:.4f}s | Downlink size: {downlink_size_bytes/1024:.2f} KB")

        log_file_path = os.path.join(self.metrics_dir, "server_aggregation_log.txt")
        log_server_metrics(log_file_path, rnd, end_time - start_time, downlink_size_bytes)

        for client_proxy, fit_res in results:
            m = fit_res.metrics or {}
            cid = str(m.get("cid", client_proxy.cid))

            client_csv = os.path.join(self.metrics_dir, f"metrics_client_{cid}.csv")
            ensure_csv(client_csv, self.header)

            best_epoch = int(m.get("best_epoch", -1))

            for ep in self._decode_per_epoch(m.get("per_epoch", "[]")):
                epoch = ep.get("epoch", "")
                row = [
                    rnd,
                    epoch,
                    ep.get("train_loss", ""),
                    ep.get("train_dice", ""),
                    ep.get("train_iou", ""),
                    ep.get("train_acc", ""),
                    ep.get("val_loss", ""),
                    ep.get("val_dice", ""),
                    ep.get("val_iou", ""),
                    ep.get("val_acc", ""),
                    "x" if self._is_best_epoch(epoch, best_epoch) else "",
                ]
                append_row(client_csv, row)

        return agg_result


def run_one_experiment(model_name: str, encoder_name: str, num_rounds=5, local_epochs=5, lr=1e-3):
    """Run one three-client Flower simulation for a single model/encoder pair.

    Server metrics go to ``AITDM/<model>__<encoder>/metrics`` (created up
    front). Every round sends ``local_epochs`` and ``lr`` to the clients via
    the fit config.
    """
    run_name = f"{model_name}__{encoder_name}".replace("/", "-")
    metrics_dir = os.path.join("AITDM", run_name, "metrics")
    os.makedirs(metrics_dir, exist_ok=True)

    def fit_config(_server_round):
        # Same local schedule in every round.
        return {"local_epochs": local_epochs, "lr": lr}

    def client_fn(cid: str):
        return SegClient(int(cid), model_name=model_name, encoder_name=encoder_name).to_client()

    n_clients = 3
    strategy = PerClientLoggingFedAvg(
        metrics_dir=metrics_dir,
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_fit_clients=n_clients,
        min_evaluate_clients=n_clients,
        min_available_clients=n_clients,
        on_fit_config_fn=fit_config,
    )

    # A whole GPU per client when one is available (presumably serializing
    # clients on a single-GPU runtime).
    gpu_share = 1.0 if torch.cuda.is_available() else 0.0
    client_resources = {"num_cpus": 1, "num_gpus": gpu_share}

    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=n_clients,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
        client_resources=client_resources,
        ray_init_args={"include_dashboard": False},
    )


if __name__ == "__main__":
    # (model, encoder) pairs trained one after another, each in its own run directory.
    experiments = [
        ("unet", "resnet50"),
        ("unet", "mit_b3"),
        ("deeplabv3plus", "timm-mobilenetv3_small_100"),
    ]

    # Shared schedule: 15 rounds x 10 local epochs at lr 1e-3.
    for model_name, encoder_name in experiments:
        print(f"\n=== Running: {model_name} + {encoder_name} ===")
        t0 = time.time()
        run_one_experiment(
            model_name,
            encoder_name,
            num_rounds=15,
            local_epochs=10,
            lr=1e-3,
        )
        run_time = time.time() - t0
        print(f"Experiment completed in {run_time:.2f} seconds.\n")

In [None]:
# Start every experiment from a clean slate.
# Per-client ensemble logs:
!rm -rf /content/AITDM/client_0/*
!rm -rf /content/AITDM/client_1/*
!rm -rf /content/AITDM/client_2/*
# Best checkpoints and per-client metric CSVs of each run:
!rm -rf /content/AITDM/deeplabv3plus__timm-mobilenetv3_small_100/checkpoints/*
!rm -rf /content/AITDM/unet__resnet50/checkpoints/*
!rm -rf /content/AITDM/unet__mit_b3/checkpoints/*
!rm -rf /content/AITDM/deeplabv3plus__timm-mobilenetv3_small_100/metrics/*
!rm -rf /content/AITDM/unet__resnet50/metrics/*
!rm -rf /content/AITDM/unet__mit_b3/metrics/*
# Client transfer logs are opened in append mode directly under each run
# directory, so stale rounds from earlier runs would otherwise mix into them.
!rm -f /content/AITDM/*/client_*_transfer_log.txt
!python fl_sim_colab.py


=== Running: unet + resnet50 ===
2026-01-16 16:21:39.174276: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-16 16:21:39.192008: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768580499.213809    3890 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768580499.220357    3890 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768580499.237091    3890 computation_placer.cc:177] computation placer already registered. Pleas

In [None]:
import os, json, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import logging
from fl_client import TRANSFORMS

from seg_data import (
    ImageOnlyGliomaDataset,
    SubsetByPIDs,
    image_only_collate_fn,
    DATA_ROOT,
    METADATA_DF_PATH,
    USE_ATLAS,
    calc_metrics,
)

# Show only errors from timm's model builder.
logging.getLogger("timm.models._builder").setLevel(logging.ERROR)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
# Cut-off applied to sigmoid(ensemble logits) to binarize predictions.
THRESHOLD = 0.5

# Client split folder and loader settings (presumably kept in sync with fl_client.py).
CLIENT_DIR = "/content/client"
BATCH_SIZE = 8
NUM_WORKERS = 2

# (model, encoder) pairs whose per-client best checkpoints are combined.
ENSEMBLE_CFGS = [
    ("unet", "resnet50"),
    ("unet", "mit_b3"),
    ("deeplabv3plus", "timm-mobilenetv3_small_100"),
]

# Per-client ensemble weights (see get_client_weights):
# - each model's best validation Dice, clipped at 0, is either raised to
#   WEIGHT_POWER ("power") or used as-is ("linear");
# - WEIGHT_EPS is added and the result is normalized.
# A large power sharpens the weights toward the best model, e.g. Dice 0.85 vs
# 0.80 becomes roughly 0.70 vs 0.30 with power 14.
WEIGHT_MODE = "power"
WEIGHT_POWER = 14.0
WEIGHT_EPS = 1e-6


def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA devices)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


set_seed(SEED)

# Same objective as the federated clients: equal-weight BCE + soft Dice on logits.
bce = nn.BCEWithLogitsLoss()
dice_loss = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)


def criterion(logits, y):
    """Return ``0.5 * BCE + 0.5 * Dice`` of ``logits`` against the mask ``y``."""
    bce_part = bce(logits, y)
    dice_part = dice_loss(logits, y)
    return 0.5 * bce_part + 0.5 * dice_part


def get_val_loader(cid: int):
    """Build the unshuffled validation loader for client ``cid``.

    The dataset is built without a ``transform`` argument, and the patient
    list comes from ``CLIENT_DIR/client_<cid>/val_pids.json``.

    Returns:
        (loader, number_of_validation_samples)
    """
    full = ImageOnlyGliomaDataset(
        METADATA_DF_PATH, DATA_ROOT, use_atlas=USE_ATLAS, exclude_ids=["PatientID_0191"]
    )
    pid_file = os.path.join(CLIENT_DIR, f"client_{cid}", "val_pids.json")
    with open(pid_file) as fh:
        val_pids = json.load(fh)

    val_set = SubsetByPIDs(full, val_pids)

    loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        collate_fn=lambda b: image_only_collate_fn(b, use_atlas=USE_ATLAS),
        generator=torch.Generator().manual_seed(SEED),
    )
    return loader, len(val_set)


def run_name(model_name: str, encoder_name: str) -> str:
    """Return the run-directory name ``<model>__<encoder>`` with "/" replaced by "-"."""
    name = f"{model_name}__{encoder_name}"
    return name.replace("/", "-")


def build_model(model_name: str, encoder_name: str, encoder_weights="imagenet"):
    """Build a one-class segmentation network on DEVICE.

    Uses 2 input channels when USE_ATLAS is set and 1 otherwise.

    Raises:
        ValueError: for a model name other than unet / deeplabv3plus aliases.
    """
    factories = {
        "unet": smp.Unet,
        "deeplabv3plus": smp.DeepLabV3Plus,
        "deeplabv3+": smp.DeepLabV3Plus,
        "dlv3p": smp.DeepLabV3Plus,
    }
    channels_in = 2 if USE_ATLAS else 1
    factory = factories.get(model_name.lower())
    if factory is None:
        raise ValueError(f"Unknown model_name={model_name}")

    net = factory(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=channels_in,
        classes=1,
    )
    return net.to(DEVICE)


def _checkpoint_file(model_name: str, encoder_name: str, filename: str) -> str:
    """Join ``AITDM/<run_name>/checkpoints/<filename>`` (relative to the working directory)."""
    return os.path.join("AITDM", run_name(model_name, encoder_name), "checkpoints", filename)


def ckpt_path(model_name: str, encoder_name: str, cid: int) -> str:
    """Path of the best weights saved by client ``cid`` for the given run."""
    return _checkpoint_file(model_name, encoder_name, f"client_{cid}_best.pt")


def best_json_path(model_name: str, encoder_name: str, cid: int) -> str:
    """Path of the JSON record describing client ``cid``'s best checkpoint."""
    return _checkpoint_file(model_name, encoder_name, f"client_{cid}_best.json")


def load_model(model_name: str, encoder_name: str, cid: int):
    """Rebuild a model, load client ``cid``'s best checkpoint into it and switch to eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    path = ckpt_path(model_name, encoder_name, cid)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Missing checkpoint: {path}")
    # The strict load below overwrites every state_dict entry, so fetching
    # pretrained encoder weights first is wasted work (and needs network access).
    # build_model already moves the model to DEVICE.
    m = build_model(model_name, encoder_name, encoder_weights=None)
    # The file holds a plain state_dict. weights_only=True refuses arbitrary
    # pickled objects.
    sd = torch.load(path, map_location="cpu", weights_only=True)
    m.load_state_dict(sd, strict=True)
    m.eval()
    return m


def load_best_val_dice(model_name: str, encoder_name: str, cid: int) -> float:
    """Read the best validation Dice recorded for client ``cid`` (0.0 if the key is absent).

    Raises:
        FileNotFoundError: if the best-checkpoint JSON does not exist.
    """
    record_file = best_json_path(model_name, encoder_name, cid)
    if os.path.isfile(record_file):
        with open(record_file, "r") as fh:
            record = json.load(fh)
        return float(record.get("val_dice", 0.0))
    raise FileNotFoundError(f"Missing best json: {record_file}")


def get_client_weights(cid: int, cfgs, mode="power", power=2.0, eps=1e-6):
    """Turn each model's best validation Dice on client ``cid`` into ensemble weights.

    Negative Dice values are clipped to 0. ``mode="linear"`` uses the clipped
    Dice directly and ``mode="power"`` raises it to ``power``. ``eps`` is added
    before normalizing so the weights stay defined when every score is 0.

    Returns:
        (weights as a list of floats summing to 1, the Dice values as read)

    Raises:
        ValueError: for an unknown ``mode``.
        FileNotFoundError: if a best-checkpoint JSON is missing.
    """
    dices = [load_best_val_dice(mn, enc, cid) for (mn, enc) in cfgs]
    clipped = np.clip(np.array(dices, dtype=np.float32), 0.0, None)

    if mode == "power":
        scores = np.power(clipped, power)
    elif mode == "linear":
        scores = clipped
    else:
        raise ValueError("mode must be 'linear' or 'power'")

    scores = scores + eps
    return (scores / scores.sum()).tolist(), dices


@torch.no_grad()
def ensemble_forward_logits(x, models, weights, target_hw=None):
    """Return the weighted sum of the models' logits for input ``x``.

    ``weights`` are renormalized to sum to 1 (float32, with a 1e-8 guard).
    When ``target_hw`` is given, logits whose last two dims differ are
    bilinearly resized to it before summing. Returns None for no models.
    """
    norm_w = np.array(weights, dtype=np.float32)
    norm_w = norm_w / (norm_w.sum() + 1e-8)

    total = None
    for model, weight in zip(models, norm_w):
        out = model(x)
        if target_hw is not None and out.shape[-2:] != target_hw:
            out = F.interpolate(out, size=target_hw, mode="bilinear", align_corners=False)
        contribution = out * float(weight)
        if total is None:
            total = contribution
        else:
            total = total + contribution
    return total


@torch.no_grad()
def eval_ensemble_on_client(cid: int, threshold=0.5):
    """Evaluate the weighted ENSEMBLE_CFGS ensemble on client ``cid``'s validation split.

    Each member is that client's own best checkpoint. Weights come from
    get_client_weights (WEIGHT_MODE / WEIGHT_POWER / WEIGHT_EPS). Predictions
    are ``sigmoid(logits) > threshold``.

    Returns:
        dict with the following keys:
        - cid and nva (validation size);
        - batch-averaged loss, dice, iou and acc;
        - the ensemble weights and the per-model best Dice values.
    """
    val_loader, nva = get_val_loader(cid)
    models = [load_model(mn, enc, cid) for (mn, enc) in ENSEMBLE_CFGS]
    weights, best_dices = get_client_weights(
        cid, ENSEMBLE_CFGS, mode=WEIGHT_MODE, power=WEIGHT_POWER, eps=WEIGHT_EPS
    )

    sum_loss = sum_dice = sum_iou = sum_acc = 0.0
    n_batches = 0
    for batch in val_loader:
        x = batch["x"].to(DEVICE)
        y = batch["y"].to(DEVICE)

        logits = ensemble_forward_logits(x, models=models, weights=weights, target_hw=y.shape[-2:])
        sum_loss += float(criterion(logits, y).item())

        preds_bin = (torch.sigmoid(logits).cpu().numpy() > threshold).astype(np.uint8)
        y_np = (y.cpu().numpy() > 0.5).astype(np.uint8)
        dice, iou, acc = calc_metrics(y_np, preds_bin)

        sum_dice += dice
        sum_iou += iou
        sum_acc += acc
        n_batches += 1

    n_batches = max(n_batches, 1)
    return {
        "cid": int(cid),
        "nva": int(nva),
        "loss": sum_loss / n_batches,
        "dice": sum_dice / n_batches,
        "iou": sum_iou / n_batches,
        "acc": sum_acc / n_batches,
        "weights": weights,
        "best_dices": best_dices,
    }


def log_client_metrics(file_path, r, bd, ww):
    """Print one client's ensemble results and append the same text to ``file_path``.

    Args:
        file_path: log file; its parent directory is created when missing.
        r: result dict with cid, nva, loss, dice, iou and acc.
        bd: per-model best validation Dice values.
        ww: per-model ensemble weights.
    """
    parent = os.path.dirname(file_path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent)

    log_message = (
        f"Client {r['cid']} (n={r['nva']}): "
        f"loss={r['loss']:.4f} dice={r['dice']:.4f} iou={r['iou']:.4f} acc={r['acc']:.4f}\n"
        f"  best_dice={['%.4f' % x for x in bd]} -> weights={['%.3f' % x for x in ww]}\n"
    )

    print(log_message, end='')
    with open(file_path, "a") as sink:
        sink.write(log_message)

def main():
    """Evaluate the ensemble on clients 0-2, printing and logging each result.

    Results are appended to ``/content/AITDM/client_<cid>/val_metrics_ensemble.txt``.
    """
    print(f"[Ensemble-3 | threshold={THRESHOLD} | weight_mode={WEIGHT_MODE} | power={WEIGHT_POWER}]\n")
    for cid in range(3):
        result = eval_ensemble_on_client(cid, threshold=THRESHOLD)
        log_path = os.path.join('/content/AITDM', f"client_{cid}", "val_metrics_ensemble.txt")
        log_client_metrics(log_path, result, result["best_dices"], result["weights"])


if __name__ == "__main__":
    main()

In [None]:
# Copy all results into a fixed Drive folder. Copying the *contents*
# ("/content/AITDM/.") into a pre-created QUANT_3 gives the same layout on
# every run. `cp -r src dst` with an existing dst would nest it as QUANT_3/AITDM.
!mkdir -p /content/drive/MyDrive/AITDM/QUANT_3
!cp -rf /content/AITDM/. /content/drive/MyDrive/AITDM/QUANT_3/

In [None]:
# Release the Colab runtime once results are on Drive.
# NOTE(review): presumably discards the VM and everything under /content,
# so run the Drive copy cell first.
from google.colab import runtime
runtime.unassign()

# **M2 - FL with ensemble and Top-K Sparsification**

In [None]:
%%writefile seg_data.py
import os, pickle, numpy as np, torch
from typing import List
from torch.utils.data import Dataset, DataLoader
import torchio as tio

# Global paths and configuration
# Root folder with one sub-folder per patient: <pid>_mri.npy, <pid>_tumor.npy
# and, when the atlas is used, <pid>_regions.npy.
DATA_ROOT = "/content/Preprocessed-Data"
# Pickled metadata DataFrame with a "Patient_ID" column (path relative to the working directory).
METADATA_DF_PATH = "cleaned_df.pkl"
# When True the regions map is loaded and used as a second input channel.
USE_ATLAS = True
# Patient IDs meant to be excluded; not referenced inside this module itself.
EXCLUDE_IDS = ["PatientID_0191"]

class ImageOnlyGliomaDataset(Dataset):
    """Per-patient dataset returning a min-max normalized MRI, a binary tumor
    mask and, optionally, a min-max normalized atlas regions map.

    Files are read from ``<data_root>/<pid>/<pid>_mri.npy``,
    ``<pid>_tumor.npy`` and, when ``use_atlas`` is true, ``<pid>_regions.npy``.
    Patients with any of these files missing are silently skipped.

    Args:
        metadata_df_path: pickled pandas DataFrame with a ``Patient_ID`` column.
            NOTE: ``pickle.load`` can execute arbitrary code, so only load
            metadata files from a trusted source.
        data_root: root directory of the per-patient folders.
        use_atlas: also load the regions map and return it under ``"regions"``.
        exclude_ids: patient IDs dropped before scanning for files.
        transform: optional TorchIO transform applied to a Subject built from
            the already-normalized arrays.
    """

    def __init__(
        self,
        metadata_df_path,
        data_root,
        use_atlas=False,
        exclude_ids=None,
        transform=None,
    ):
        with open(metadata_df_path, "rb") as f:
            df = pickle.load(f)

        if exclude_ids is None:
            exclude_ids = []

        self.df = df[~df["Patient_ID"].isin(exclude_ids)].reset_index(drop=True)
        self.data_root = data_root
        self.use_atlas = use_atlas
        self.transform = transform

        # Keep, in sorted order, only patients whose required files all exist.
        self.patient_ids = []
        for pid in sorted(self.df["Patient_ID"].tolist()):
            kinds = ["mri", "tumor", "regions"] if self.use_atlas else ["mri", "tumor"]
            if all(os.path.isfile(self._file(pid, kind)) for kind in kinds):
                self.patient_ids.append(pid)

    def __len__(self):
        return len(self.patient_ids)

    def _file(self, pid, kind):
        """Path of ``<data_root>/<pid>/<pid>_<kind>.npy``."""
        return os.path.join(self.data_root, pid, f"{pid}_{kind}.npy")

    @staticmethod
    def _minmax(x):
        """Scale ``x`` to [0, 1] as float32; a constant array becomes all zeros."""
        x = x.astype(np.float32)
        mn, mx = np.min(x), np.max(x)
        if mx > mn:
            return (x - mn) / (mx - mn)
        return np.zeros_like(x, dtype=np.float32)

    def _to_torchio_format(self, arr):
        """
        Convert (H, W) -> (1, H, W, 1) for TorchIO's 4D image format; other ranks are returned unchanged.
        """
        if arr.ndim == 2:
            return arr[np.newaxis, ..., np.newaxis]
        return arr

    def __getitem__(self, idx):
        pid = self.patient_ids[idx]

        # Normalize exactly once: intensities to [0, 1], mask to {0, 1}.
        mri = self._minmax(np.load(self._file(pid, "mri")).astype(np.float32))
        tumor = (np.load(self._file(pid, "tumor")).astype(np.float32) > 0.5).astype(np.float32)

        regions = None
        if self.use_atlas:
            regions = self._minmax(np.load(self._file(pid, "regions")).astype(np.float32))

        if self.transform:
            subject_dict = {
                'mri': tio.ScalarImage(tensor=self._to_torchio_format(mri)),
                'tumor': tio.LabelMap(tensor=self._to_torchio_format(tumor)),
            }
            if regions is not None:
                subject_dict['regions'] = tio.ScalarImage(tensor=self._to_torchio_format(regions))

            subject = self.transform(tio.Subject(subject_dict))

            # Drop the channel and depth axes added by _to_torchio_format.
            sample = {
                "patient_id": pid,
                "mri": subject['mri'].data[0, ..., 0].numpy(),
                "tumor": subject['tumor'].data[0, ..., 0].numpy(),
            }
            if regions is not None:
                sample["regions"] = subject['regions'].data[0, ..., 0].numpy()
            return sample

        sample = {"patient_id": pid, "mri": mri, "tumor": tumor}

        if self.use_atlas:
            sample["regions"] = regions

        return sample

def image_only_collate_fn(batch, use_atlas=USE_ATLAS):
    """Stack a list of dataset samples into model-ready batch tensors.

    Args:
        batch: list of sample dicts with keys "patient_id", "mri", "tumor"
            and, when ``use_atlas`` is true, "regions".
        use_atlas: when true, the regions map is concatenated after the MRI
            as a second input channel.

    Returns:
        dict with "x" (float inputs, channel axis inserted at dim 1),
        "y" (float masks, channel axis at dim 1) and "pid" (list of ids).
    """
    def _stack_with_channel(key):
        # New batch axis from stacking, then a singleton channel axis at dim 1.
        return torch.stack([torch.tensor(sample[key]) for sample in batch]).unsqueeze(1)

    mri_batch = _stack_with_channel("mri")
    mask_batch = _stack_with_channel("tumor")

    channels = [mri_batch.float()]
    if use_atlas:
        channels.append(_stack_with_channel("regions").float())
    inputs = torch.cat(channels, dim=1) if len(channels) > 1 else channels[0]

    patient_ids = [sample["patient_id"] for sample in batch]
    return {"x": inputs, "y": mask_batch.float(), "pid": patient_ids}

class SubsetByPIDs(Dataset):
    """View of an ImageOnlyGliomaDataset restricted to a list of patient IDs.

    IDs in ``pid_list`` that the underlying dataset does not contain are
    skipped silently; the remaining ones keep the order of ``pid_list``.
    """

    def __init__(self, full_dataset: ImageOnlyGliomaDataset, pid_list: List[str]):
        self.ds = full_dataset
        # Position of every patient in the full dataset, for O(1) lookups.
        position_of = {}
        for position, patient_id in enumerate(full_dataset.patient_ids):
            position_of[patient_id] = position
        self.indices = []
        for patient_id in pid_list:
            if patient_id in position_of:
                self.indices.append(position_of[patient_id])

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        # Translate the subset-local index into a full-dataset index.
        full_index = self.indices[i]
        return self.ds[full_index]

def calc_metrics(y_true, y_pred):
    """Dice, IoU and pixel accuracy of two binary mask arrays (flattened).

    Both inputs are cast to uint8. The 1e-8 terms only guard against division by
    zero, so two empty masks give dice = iou = 0.0.

    Returns:
        (dice, iou, acc) as Python floats.
    """
    truth = y_true.astype(np.uint8).ravel()
    guess = y_pred.astype(np.uint8).ravel()

    overlap = (truth & guess).sum()
    total = truth.sum() + guess.sum()

    dice = (2.0 * overlap) / (total + 1e-8)
    iou = overlap / (total - overlap + 1e-8)
    acc = np.mean(truth == guess)

    return float(dice), float(iou), float(acc)

In [None]:
%%writefile fl_client.py
import os
# Placed before `import torch` so it is in the environment before cuBLAS is used:
# torch.use_deterministic_algorithms(True) (see seed_everything) needs this cuBLAS
# workspace setting for deterministic CUDA matmuls. A pre-set value is kept.
if "CUBLAS_WORKSPACE_CONFIG" not in os.environ:
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import argparse
import json
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import flwr as fl
import copy
import segmentation_models_pytorch as smp
import random
import torchio as tio
import sys
import time

from seg_data import (
    ImageOnlyGliomaDataset,
    SubsetByPIDs,
    image_only_collate_fn,
    calc_metrics,
    DATA_ROOT,
    METADATA_DF_PATH,
    USE_ATLAS,
)

CLIENT_DIR = "/content/client"

BATCH_SIZE = 8
NUM_WORKERS = 2
SEED = 42
SPARSITY_RATE = 0.1

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

DEFAULT_MODEL_NAME = "unet"
DEFAULT_ENCODER_NAME = "timm-mobilenetv3_small_100"
DEFAULT_ENCODER_WEIGHTS = "imagenet"
TRANSFORMS = None


def _run_name(model_name: str, encoder_name: str) -> str:
    return f"{model_name}__{encoder_name}".replace("/", "-")


def seed_everything(seed: int, deterministic: bool = True) -> None:
    """Seed Python, NumPy and PyTorch RNGs; optionally force deterministic kernels.

    Args:
        seed: value given to every seeding call.
        deterministic: when true, disable cuDNN autotuning and enable
            ``torch.use_deterministic_algorithms``.
    """
    # NOTE(review): PYTHONHASHSEED only affects interpreters started after this point,
    # not the hash randomization of the current process.
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        # Best-effort setting; any failure is deliberately ignored.
        pass

    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)


def seed_worker(worker_id: int) -> None:
    """DataLoader ``worker_init_fn``: derive NumPy/``random`` seeds from torch's seed."""
    derived = (torch.initial_seed() + worker_id) % (2 ** 32)
    random.seed(derived)
    np.random.seed(derived)


def log_transfer_metrics(file_path, cid, rnd, incoming, outgoing, overhead):
    """Append one line of per-round communication metrics for a client to ``file_path``.

    Args:
        file_path: text log to append to; missing parent directories are created.
        cid: client id.
        rnd: server round number.
        incoming: bytes received from the server (the dense global weights).
        outgoing: bytes of the packed sparse update (kept values + their indices).
        overhead: seconds spent building the sparse update.
    """
    directory = os.path.dirname(file_path)
    if directory:
        # exist_ok avoids a check-then-create race when several clients log into the
        # same run directory at the same time.
        os.makedirs(directory, exist_ok=True)

    msg = (f"Client {cid} Round {rnd}: "
           f"Incoming {incoming/1024:.2f} KB | "
           f"Outgoing (Actual Sparse Packet) {outgoing/1024:.2f} KB | "
           f"Overhead: {overhead:.6f}s\n")

    with open(file_path, "a") as f:
        f.write(msg)


def get_model(
    model_name=DEFAULT_MODEL_NAME,
    encoder_name=DEFAULT_ENCODER_NAME,
    encoder_weights=DEFAULT_ENCODER_WEIGHTS,
):
    """Build a single-class smp segmentation network on DEVICE.

    ``model_name`` (case-insensitive) selects ``smp.Unet`` ("unet") or
    ``smp.DeepLabV3Plus`` ("deeplabv3plus", "deeplabv3+", "dlv3p"). The input has 2
    channels when USE_ATLAS is set, else 1.

    Raises:
        ValueError: for any other ``model_name``.
    """
    in_ch = 2 if USE_ATLAS else 1
    key = model_name.lower()

    architectures = {
        "unet": smp.Unet,
        "deeplabv3plus": smp.DeepLabV3Plus,
        "deeplabv3+": smp.DeepLabV3Plus,
        "dlv3p": smp.DeepLabV3Plus,
    }
    if key not in architectures:
        raise ValueError(f"Unknown model_name={model_name}. Use 'unet' or 'deeplabv3plus'.")

    net = architectures[key](
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_ch,
        classes=1,
    )
    return net.to(DEVICE)


def get_loaders(cid: int, base_seed: int, transforms):
    """Build the train/val DataLoaders for client ``cid``.

    Patient ids are read from ``<CLIENT_DIR>/client_<cid>/{train,val}_pids.json``.
    ``transforms`` is applied to the TRAINING split only. Validation always sees
    un-augmented images, so the val metrics used for best-epoch selection and for
    Flower's ``evaluate`` are not perturbed by random augmentation.

    Returns:
        (train_loader, val_loader, n_train, n_val)
    """
    from functools import partial

    def _full_dataset(transform):
        # Full cohort (minus the excluded patient); subsets are taken below.
        return ImageOnlyGliomaDataset(
            METADATA_DF_PATH,
            DATA_ROOT,
            use_atlas=USE_ATLAS,
            exclude_ids=["PatientID_0191"],
            transform=transform,
        )

    def _read_pids(split):
        with open(os.path.join(CLIENT_DIR, f"client_{cid}", f"{split}_pids.json")) as f:
            return json.load(f)

    full_train = _full_dataset(transforms)
    # Without augmentation the same dataset can serve both splits.
    full_val = full_train if transforms is None else _full_dataset(None)

    ds_tr = SubsetByPIDs(full_train, _read_pids("train"))
    ds_va = SubsetByPIDs(full_val, _read_pids("val"))

    # A partial of a module-level function stays picklable, unlike a lambda,
    # should DataLoader workers be started with the "spawn" method.
    collate = partial(image_only_collate_fn, use_atlas=USE_ATLAS)

    def _loader(ds, shuffle, generator):
        return DataLoader(
            ds,
            batch_size=BATCH_SIZE,
            shuffle=shuffle,
            num_workers=NUM_WORKERS,
            pin_memory=True,
            collate_fn=collate,
            worker_init_fn=seed_worker,
            generator=generator,
            persistent_workers=(NUM_WORKERS > 0),
        )

    # Separate seed-derived generators keep shuffling and worker seeding reproducible.
    ld_tr = _loader(ds_tr, True, torch.Generator().manual_seed(base_seed + 12345))
    ld_va = _loader(ds_va, False, torch.Generator().manual_seed(base_seed + 67890))

    return ld_tr, ld_va, len(ds_tr), len(ds_va)


def get_parameters(model):
    """Every ``state_dict`` entry (parameters AND buffers) as NumPy arrays, in key order."""
    state = model.state_dict()
    return list(map(lambda tensor: tensor.detach().cpu().numpy(), state.values()))


def set_parameters(model, params):
    """Load ``params`` (ordered like ``model.state_dict()`` keys) into ``model``, strictly.

    Keys beyond ``len(params)`` keep their current tensors.
    """
    state = model.state_dict()
    state.update({key: torch.tensor(value) for key, value in zip(state, params)})
    model.load_state_dict(state, strict=True)


# Training objective: equal-weight BCE-with-logits and binary Dice loss (on logits).
bce = nn.BCEWithLogitsLoss()
dice_loss = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)


def criterion(pred, y):
    """0.5 * BCE + 0.5 * Dice for raw logits ``pred`` against target mask ``y``."""
    bce_term = bce(pred, y)
    dice_term = dice_loss(pred, y)
    return 0.5 * bce_term + 0.5 * dice_term


def maybe_save_best(run_dir, cid, val_loss, val_dice, best_epoch, rnd, model):
    """Checkpoint ``model`` if it beats this client's saved best on BOTH val loss and Dice.

    Files live in ``<run_dir>/checkpoints/``: ``client_<cid>_best.pt`` (state_dict)
    and ``client_<cid>_best.json`` (round, epoch, val_loss, val_dice of those
    weights). A missing or unreadable JSON counts as "no previous best".
    """
    ckpt_dir = os.path.join(run_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)
    best_json = os.path.join(ckpt_dir, f"client_{cid}_best.json")
    best_pt = os.path.join(ckpt_dir, f"client_{cid}_best.pt")

    previous = {"val_loss": float("inf"), "val_dice": -1.0}
    if os.path.isfile(best_json):
        try:
            with open(best_json, "r") as fh:
                previous = json.load(fh)
        except Exception:
            pass  # best-effort: a corrupt record is treated as absent

    beats_both = (val_loss < previous.get("val_loss", float("inf"))) and (
        val_dice > previous.get("val_dice", -1.0)
    )
    if not beats_both:
        return

    torch.save(model.state_dict(), best_pt)
    record = {
        "round": int(rnd),
        "epoch": int(best_epoch),
        "val_loss": float(val_loss),
        "val_dice": float(val_dice),
    }
    with open(best_json, "w") as fh:
        json.dump(record, fh)


class SegClient(fl.client.NumPyClient):
    """Flower NumPy client that trains a binary tumor-segmentation model on one silo.

    ``fit`` fine-tunes the received global weights for ``local_epochs`` epochs. It
    keeps the epoch that improves both val loss and val Dice. It then uploads
    ``global + sparse(update + residual)``, where the sparse part keeps the
    ``sparsity_rate`` fraction of largest-magnitude entries of every tensor (at least
    one). The rest is carried as an error-feedback residual. The residual is also
    persisted in ``<run_dir>/checkpoints/client_<cid>_residual.npz`` because a fresh
    client object may be created for every round in Flower simulation.
    """

    def __init__(
        self,
        cid: int,
        model_name=DEFAULT_MODEL_NAME,
        encoder_name=DEFAULT_ENCODER_NAME,
        encoder_weights=DEFAULT_ENCODER_WEIGHTS,
    ):
        self.cid = int(cid)
        self.model_name = model_name
        self.encoder_name = encoder_name
        self.encoder_weights = encoder_weights

        # Per-client seed so clients do not share shuffling/augmentation streams.
        self.base_seed = SEED + self.cid
        seed_everything(self.base_seed, deterministic=True)

        self.run_name = _run_name(model_name, encoder_name)
        self.run_dir = os.path.join("AITDM", self.run_name)

        self.model = get_model(model_name, encoder_name, encoder_weights)
        self.train_loader, self.val_loader, self.ntr, self.nva = get_loaders(
            self.cid, self.base_seed, transforms=TRANSFORMS
        )

        self.residual = [np.zeros_like(p) for p in get_parameters(self.model)]
        self.sparsity_rate = SPARSITY_RATE

    # ------------------------------------------------------------------ helpers

    def _residual_path(self) -> str:
        """File holding this client's persisted error-feedback residual."""
        return os.path.join(self.run_dir, "checkpoints", f"client_{self.cid}_residual.npz")

    def _load_residual(self, rnd: int) -> None:
        """Restore the residual saved at the end of round ``rnd - 1``, if compatible.

        Files from another round or run, or with mismatched shapes, are ignored.
        In those cases the residual stays as initialised (zeros).
        """
        path = self._residual_path()
        if rnd <= 1 or not os.path.isfile(path):
            return
        try:
            with np.load(path) as data:
                if int(data["round"]) != rnd - 1 or int(data["n"]) != len(self.residual):
                    return
                loaded = [data[f"r{i}"] for i in range(len(self.residual))]
        except Exception:
            return  # best-effort: an unreadable file means starting from zeros
        if all(a.shape == b.shape for a, b in zip(loaded, self.residual)):
            self.residual = loaded

    def _save_residual(self, rnd: int) -> None:
        """Persist the current residual tagged with round ``rnd``."""
        path = self._residual_path()
        os.makedirs(os.path.dirname(path), exist_ok=True)
        arrays = {f"r{i}": r for i, r in enumerate(self.residual)}
        np.savez(path, round=np.int64(rnd), n=np.int64(len(self.residual)), **arrays)

    @staticmethod
    def _topk_mask(values, rate):
        """Boolean mask (shape of ``values``) marking its largest-magnitude entries.

        Exactly ``k = max(1, int(values.size * rate))`` entries are selected (capped
        at ``values.size``). Tensors with fewer than ``1 / rate`` entries, such as a
        1-element bias, are therefore still transmitted. Ties at the cut-off (e.g.
        many zeros) cannot inflate the count. A non-positive ``rate`` or an empty
        tensor selects nothing.
        """
        flat = np.abs(np.asarray(values).reshape(-1))
        mask = np.zeros(flat.size, dtype=bool)
        if flat.size > 0 and rate > 0:
            k = min(flat.size, max(1, int(flat.size * rate)))
            cut = flat.size - k
            mask[np.argpartition(flat, cut)[cut:]] = True
        return mask.reshape(np.shape(values))

    def _sparsify_update(self, global_parameters):
        """Top-k sparsify ``trained - global + residual`` per tensor; update the residual.

        Returns:
            (weights_to_send, packed_bytes). ``packed_bytes`` counts the kept values
            plus their flat indices, i.e. what a sparse encoding would carry.
        """
        trained_parameters = get_parameters(self.model)
        to_send = []
        packed_bytes = 0
        for i, (new_w, old_w, res) in enumerate(zip(trained_parameters, global_parameters, self.residual)):
            accumulated = (new_w - old_w) + res
            mask = self._topk_mask(accumulated, self.sparsity_rate)
            sparse = accumulated * mask
            # Whatever is not sent now is re-added next round (error feedback).
            self.residual[i] = accumulated - sparse
            packed_bytes += accumulated[mask].nbytes + np.flatnonzero(mask).nbytes
            to_send.append(old_w + sparse)
        return to_send, packed_bytes

    def _train_one_epoch(self, opt, scaler):
        """One pass over the train loader (AMP on CUDA); mean (loss, dice, iou, acc)."""
        self.model.train()
        tot_loss = tot_d = tot_i = tot_a = 0.0
        nb = 0
        for batch in self.train_loader:
            x = batch["x"].to(DEVICE)
            y = batch["y"].to(DEVICE)
            opt.zero_grad(set_to_none=True)

            with torch.amp.autocast("cuda", enabled=(DEVICE.type == "cuda")):
                pred = self.model(x)
                loss = criterion(pred, y)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            with torch.no_grad():
                y_hat = (torch.sigmoid(pred).detach().cpu().numpy() > 0.5).astype(np.uint8)
                y_np = (y.detach().cpu().numpy() > 0.5).astype(np.uint8)
                d, i, a = calc_metrics(y_np, y_hat)

            tot_loss += float(loss.item())
            tot_d += d
            tot_i += i
            tot_a += a
            nb += 1

        nb = max(nb, 1)
        return tot_loss / nb, tot_d / nb, tot_i / nb, tot_a / nb

    def _evaluate_loader(self, loader):
        """Eval-mode, no-grad pass over ``loader``; mean (loss, dice, iou, acc) per batch."""
        self.model.eval()
        tot_loss = tot_d = tot_i = tot_a = 0.0
        nb = 0
        with torch.no_grad():
            for batch in loader:
                x = batch["x"].to(DEVICE)
                y = batch["y"].to(DEVICE)
                pred = self.model(x)

                loss = float(criterion(pred, y).item())
                y_hat = (torch.sigmoid(pred).cpu().numpy() > 0.5).astype(np.uint8)
                y_np = (y.cpu().numpy() > 0.5).astype(np.uint8)
                d, i, a = calc_metrics(y_np, y_hat)

                tot_loss += loss
                tot_d += d
                tot_i += i
                tot_a += a
                nb += 1

        nb = max(nb, 1)
        return tot_loss / nb, tot_d / nb, tot_i / nb, tot_a / nb

    # ------------------------------------------------------------ Flower API

    def get_parameters(self, config):
        # Resolves to the module-level get_parameters(model) helper.
        return get_parameters(self.model)

    def fit(self, parameters, config):
        """Train locally from ``parameters`` and return sparsified weights.

        Config keys: ``local_epochs`` (default 1), ``lr`` (default 1e-3) and
        ``round`` (default 0; used for logging and residual bookkeeping).

        Returns:
            (weights, n_train, metrics) where metrics holds the best epoch, its val
            loss/Dice and a JSON string of per-epoch logs under ``"per_epoch"``.
        """
        # --- MEASURE INCOMING ---
        incoming_size = sum(p.nbytes for p in parameters)
        global_parameters = [np.copy(p) for p in parameters]
        set_parameters(self.model, parameters)

        epochs = int(config.get("local_epochs", 1))
        lr = float(config.get("lr", 1e-3))
        rnd = int(config.get("round", 0))
        self._load_residual(rnd)

        opt = optim.AdamW(self.model.parameters(), lr=lr)
        scaler = torch.amp.GradScaler("cuda", enabled=(DEVICE.type == "cuda"))

        best_state = None
        best_val_loss = float("inf")
        best_val_dice = -1.0
        best_epoch_idx = -1
        epoch_logs = []

        # --- TRAINING LOOP ---
        for epoch_idx in range(1, epochs + 1):
            tr_loss, tr_dice, tr_iou, tr_acc = self._train_one_epoch(opt, scaler)
            val_loss, val_dice, val_iou, val_acc = self._evaluate_loader(self.val_loader)

            epoch_logs.append(
                {
                    "epoch": int(epoch_idx),
                    "train_loss": float(tr_loss),
                    "train_dice": float(tr_dice),
                    "train_iou": float(tr_iou),
                    "train_acc": float(tr_acc),
                    "val_loss": float(val_loss),
                    "val_dice": float(val_dice),
                    "val_iou": float(val_iou),
                    "val_acc": float(val_acc),
                }
            )

            # An epoch only becomes "best" if it improves both loss and Dice.
            if (val_loss < best_val_loss) and (val_dice > best_val_dice):
                best_val_loss = val_loss
                best_val_dice = val_dice
                best_state = copy.deepcopy(self.model.state_dict())
                best_epoch_idx = epoch_idx

        if best_state is not None:
            self.model.load_state_dict(best_state)

        for ep in epoch_logs:
            ep["best_epoch"] = (ep["epoch"] == best_epoch_idx)

        train_metrics = {
            "cid": int(self.cid),
            "best_epoch": int(best_epoch_idx),
            "best_val_loss": float(best_val_loss),
            "best_val_dice": float(best_val_dice),
            "per_epoch": json.dumps(epoch_logs),
            "run_name": self.run_name,
            "model_name": self.model_name,
            "encoder_name": self.encoder_name,
        }

        maybe_save_best(self.run_dir, self.cid, best_val_loss, best_val_dice, best_epoch_idx, rnd, self.model)

        # --- SPARSIFICATION (timed) ---
        start_overhead = time.time()
        parameters_to_send_flower, actual_transmission_size = self._sparsify_update(global_parameters)
        overhead = time.time() - start_overhead
        self._save_residual(rnd)

        # --- LOGGING ---
        print(f"Client {self.cid} Round {rnd}: Incoming {incoming_size/1024:.2f} KB | Outgoing (Packed Sparse) {actual_transmission_size/1024:.2f} KB | Overhead: {overhead:.6f}s")

        log_file_path = os.path.join(self.run_dir, f"client_{self.cid}_transfer_log.txt")
        log_transfer_metrics(log_file_path, self.cid, rnd, incoming_size, actual_transmission_size, overhead)

        return parameters_to_send_flower, self.ntr, train_metrics

    def evaluate(self, parameters, config):
        """Score ``parameters`` on this client's validation split.

        Returns:
            (mean loss, n_val, metrics with loss/dice/iou/acc/cid/run_name)
        """
        set_parameters(self.model, parameters)
        loss, dice, iou, acc = self._evaluate_loader(self.val_loader)

        metrics = {
            "loss": loss,
            "dice": dice,
            "iou": iou,
            "acc": acc,
            "cid": int(self.cid),
            "run_name": self.run_name,
        }

        return metrics["loss"], self.nva, metrics


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cid", type=int, required=True)
    parser.add_argument("--server", default="0.0.0.0:8080")
    args = parser.parse_args()

    fl.client.start_numpy_client(
        server_address=args.server,
        client=SegClient(args.cid),
    )

In [None]:
%%writefile fl_sim_colab.py
import os
import csv
import json
import torch
import flwr as fl
import time
from flwr.common import FitIns
import torchio as tio
from fl_client import SegClient

import logging
import warnings
logging.getLogger("flwr").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

def ensure_csv(path: str, header: list[str]):
    """Create ``path`` with ``header`` as its first row unless the file already exists.

    Parent directories are created as needed. A bare filename (no directory part)
    is written in the current directory instead of crashing in ``os.makedirs("")``.
    """
    if os.path.isfile(path):
        return
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, "w", newline="") as f:
        csv.writer(f).writerow(header)


def append_row(path: str, row: list):
    """Append a single CSV record to ``path`` (the file is created if absent)."""
    with open(path, mode="a", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(row)


def log_server_metrics(file_path, rnd, duration):
    """Append the server aggregation time of round ``rnd`` to ``file_path``.

    The parent directory is created when it does not exist yet.
    """
    parent = os.path.dirname(file_path)
    needs_dir = bool(parent) and not os.path.exists(parent)
    if needs_dir:
        os.makedirs(parent)

    line = f"Round {rnd} Aggregation Time: {duration:.4f} seconds\n"
    with open(file_path, "a") as handle:
        handle.write(line)


class PerClientLoggingFedAvg(fl.server.strategy.FedAvg):
    """FedAvg that tags fit configs with the round number and logs per-client metrics.

    Every round it:
      - prints each client's uplink payload size;
      - records the aggregation time in ``<metrics_dir>/server_aggregation_log.txt``;
      - appends every client's per-epoch metrics (a JSON string under
        ``"per_epoch"`` in its fit metrics) to
        ``<metrics_dir>/metrics_client_<cid>.csv``.
    """

    def __init__(self, metrics_dir: str, **kwargs):
        super().__init__(**kwargs)
        self.metrics_dir = metrics_dir
        self.header = [
            "round","epoch","train_loss","train_dice","train_iou","train_acc",
            "val_loss","val_dice","val_iou","val_acc","best_epoch",
        ]

    def configure_fit(self, server_round, parameters, client_manager):
        """Return the parent's fit instructions with ``round`` added to each config."""
        items = super().configure_fit(server_round, parameters, client_manager)
        out = []
        for it in items:
            # Defensive: accept bare FitIns as well as (client, FitIns) pairs.
            if isinstance(it, tuple):
                client, fitins = it
            else:
                client, fitins = None, it

            cfg = dict(fitins.config)
            cfg["round"] = server_round
            new_fitins = FitIns(fitins.parameters, cfg)
            out.append((client, new_fitins) if client is not None else new_fitins)
        return out

    @staticmethod
    def _client_id(client_proxy, fit_res) -> str:
        """Client id reported in the fit metrics, falling back to the proxy's cid."""
        m = fit_res.metrics or {}
        return str(m.get("cid", client_proxy.cid))

    @staticmethod
    def _is_best_epoch(epoch, best_epoch) -> bool:
        """True when ``epoch`` parses as an int equal to ``best_epoch``."""
        try:
            return int(epoch) == best_epoch
        except (TypeError, ValueError):
            return False

    def aggregate_fit(self, rnd, results, failures):
        """Log uplink sizes and aggregation time, run FedAvg, then write per-client CSVs."""
        # Uplink payload as serialized by Flower (the weights each client returned).
        # Label by the client's own id: the order of ``results`` is not guaranteed
        # to follow client ids.
        for client_proxy, fit_res in results:
            client_bytes = sum(len(t) for t in fit_res.parameters.tensors)
            print(f"Client {self._client_id(client_proxy, fit_res)} has sent {client_bytes} bytes!")

        # --- TIMED AGGREGATION ---
        start_time = time.time()
        agg = super().aggregate_fit(rnd, results, failures)
        duration = time.time() - start_time

        print(f"Round {rnd} Aggregation Time: {duration:.4f} seconds")
        log_file_path = os.path.join(self.metrics_dir, "server_aggregation_log.txt")
        log_server_metrics(log_file_path, rnd, duration)

        for client_proxy, fit_res in results:
            m = fit_res.metrics or {}
            cid = self._client_id(client_proxy, fit_res)

            client_csv = os.path.join(self.metrics_dir, f"metrics_client_{cid}.csv")
            ensure_csv(client_csv, self.header)

            best_epoch = int(m.get("best_epoch", -1))
            per_epoch_raw = m.get("per_epoch", "[]")

            try:
                per_epoch = json.loads(per_epoch_raw)
            except (TypeError, ValueError):
                per_epoch = []

            for ep in per_epoch:
                epoch = ep.get("epoch", "")
                row = [
                    rnd,
                    epoch,
                    ep.get("train_loss", ""),
                    ep.get("train_dice", ""),
                    ep.get("train_iou", ""),
                    ep.get("train_acc", ""),
                    ep.get("val_loss", ""),
                    ep.get("val_dice", ""),
                    ep.get("val_iou", ""),
                    ep.get("val_acc", ""),
                    "x" if self._is_best_epoch(epoch, best_epoch) else "",
                ]
                append_row(client_csv, row)

        return agg


def run_one_experiment(model_name: str, encoder_name: str, num_rounds=5, local_epochs=5, lr=1e-3, num_clients=3):
    """Run one Flower simulation for a (model, encoder) pair, using every client each round.

    Metrics are written under ``AITDM/<model>__<encoder>/metrics``. Clients get the
    cids ``0 .. num_clients-1``. Presumably each needs a matching ``client_<cid>``
    split folder (the client reads its splits from there).

    Args:
        num_rounds: federated rounds.
        local_epochs: epochs each client trains per round (sent via the fit config).
        lr: client learning rate (sent via the fit config).
        num_clients: simulated clients; also the minimum required for fit/evaluate.
    """
    run_name = f"{model_name}__{encoder_name}".replace("/", "-")
    base_dir = os.path.join("AITDM", run_name)
    metrics_dir = os.path.join(base_dir, "metrics")
    os.makedirs(metrics_dir, exist_ok=True)

    def client_fn(cid: str):
        return SegClient(int(cid), model_name=model_name, encoder_name=encoder_name).to_client()

    strategy = PerClientLoggingFedAvg(
        metrics_dir=metrics_dir,
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        # Wait for all clients: every one of them trains and evaluates each round.
        min_fit_clients=num_clients,
        min_evaluate_clients=num_clients,
        min_available_clients=num_clients,
        on_fit_config_fn=lambda rnd: {"local_epochs": local_epochs, "lr": lr},
    )

    # A whole GPU per client; presumably this serializes clients on a single-GPU host.
    use_gpu = torch.cuda.is_available()
    client_resources = {"num_cpus": 1, "num_gpus": 1.0 if use_gpu else 0.0}

    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
        client_resources=client_resources,
        ray_init_args={"include_dashboard": False},
    )


if __name__ == "__main__":
    experiments = [
        ("unet", "resnet50"),
        ("unet", "mit_b3"),
        ("deeplabv3plus", "timm-mobilenetv3_small_100"),
    ]

    for model_name, encoder_name in experiments:
        print(f"\n=== Running: {model_name} + {encoder_name} ===")
        start_time = time.time()
        run_one_experiment(model_name, encoder_name, num_rounds=15, local_epochs=10, lr=1e-3)
        run_time = time.time() - start_time
        print(f"Total run time: {run_time:.2f} seconds\n")

In [None]:
!rm -rf /content/AITDM/client_0/*
!rm -rf /content/AITDM/client_1/*
!rm -rf /content/AITDM/client_2/*
!rm -rf /content/AITDM/deeplabv3plus__timm-mobilenetv3_small_100/checkpoints/*
!rm -rf /content/AITDM/unet__resnet50/checkpoints/*
!rm -rf /content/AITDM/unet__mit_b3/checkpoints/*
!rm -rf /content/AITDM/deeplabv3plus__timm-mobilenetv3_small_100/metrics/*
!rm -rf /content/AITDM/unet__resnet50/metrics/*
!rm -rf /content/AITDM/unet__mit_b3/metrics/*
!python fl_sim_colab.py

In [None]:
import os, json, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import logging
from fl_client import TRANSFORMS

from seg_data import (
    ImageOnlyGliomaDataset,
    SubsetByPIDs,
    image_only_collate_fn,
    DATA_ROOT,
    METADATA_DF_PATH,
    USE_ATLAS,
    calc_metrics,
)

logging.getLogger("timm.models._builder").setLevel(logging.ERROR)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
THRESHOLD = 0.5

CLIENT_DIR = "/content/client"
BATCH_SIZE = 8
NUM_WORKERS = 2

ENSEMBLE_CFGS = [
    ("unet", "resnet50"),
    ("unet", "mit_b3"),
    ("deeplabv3plus", "timm-mobilenetv3_small_100"),
]

WEIGHT_MODE = "power"
WEIGHT_POWER = 14.0
WEIGHT_EPS = 1e-6


def set_seed(seed=42):
    """Seed Python ``random``, NumPy, and PyTorch (CPU and all CUDA devices)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


set_seed(SEED)

# Same objective as training, used here to report the ensemble's validation loss:
# equal-weight BCE-with-logits and binary Dice (on logits).
bce = nn.BCEWithLogitsLoss()
dice_loss = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)


def criterion(logits, y):
    """0.5 * BCE + 0.5 * Dice for raw ``logits`` against target mask ``y``."""
    bce_term = bce(logits, y)
    dice_term = dice_loss(logits, y)
    return 0.5 * bce_term + 0.5 * dice_term


def get_val_loader(cid: int):
    """Un-augmented, unshuffled validation DataLoader for client ``cid``.

    Returns:
        (loader, number of validation samples found for the client's val PIDs)
    """
    dataset = ImageOnlyGliomaDataset(
        METADATA_DF_PATH, DATA_ROOT, use_atlas=USE_ATLAS, exclude_ids=["PatientID_0191"]
    )
    split_file = os.path.join(CLIENT_DIR, f"client_{cid}", "val_pids.json")
    with open(split_file) as fh:
        val_pids = json.load(fh)

    val_subset = SubsetByPIDs(dataset, val_pids)
    loader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        collate_fn=lambda b: image_only_collate_fn(b, use_atlas=USE_ATLAS),
        generator=torch.Generator().manual_seed(SEED),
    )
    return loader, len(val_subset)


def run_name(model_name: str, encoder_name: str) -> str:
    """Folder name ``<model>__<encoder>`` with ``/`` replaced by ``-``."""
    combined = f"{model_name}__{encoder_name}"
    return combined.replace("/", "-")


def build_model(model_name: str, encoder_name: str, encoder_weights="imagenet"):
    """Instantiate a single-class smp segmentation net on DEVICE.

    Input channels are 2 when USE_ATLAS (MRI + atlas regions) else 1.

    Raises:
        ValueError: if ``model_name`` is not a supported architecture alias.
    """
    in_ch = 2 if USE_ATLAS else 1
    constructors = {
        "unet": smp.Unet,
        "deeplabv3plus": smp.DeepLabV3Plus,
        "deeplabv3+": smp.DeepLabV3Plus,
        "dlv3p": smp.DeepLabV3Plus,
    }
    ctor = constructors.get(model_name.lower())
    if ctor is None:
        raise ValueError(f"Unknown model_name={model_name}")

    net = ctor(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_ch,
        classes=1,
    )
    return net.to(DEVICE)


def _best_artifact(model_name: str, encoder_name: str, cid: int, suffix: str) -> str:
    """Path of client ``cid``'s best-checkpoint artifact with the given file suffix."""
    folder = os.path.join("AITDM", run_name(model_name, encoder_name), "checkpoints")
    return os.path.join(folder, f"client_{cid}_best{suffix}")


def ckpt_path(model_name: str, encoder_name: str, cid: int) -> str:
    """Best local weights (``.pt``) saved for client ``cid`` during training."""
    return _best_artifact(model_name, encoder_name, cid, ".pt")


def best_json_path(model_name: str, encoder_name: str, cid: int) -> str:
    """JSON record (round, epoch, val_loss, val_dice) of client ``cid``'s best weights."""
    return _best_artifact(model_name, encoder_name, cid, ".json")


def load_model(model_name: str, encoder_name: str, cid: int):
    """Rebuild an architecture and load client ``cid``'s best checkpoint, in eval mode.

    Raises:
        FileNotFoundError: when the checkpoint file does not exist.
    """
    path = ckpt_path(model_name, encoder_name, cid)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Missing checkpoint: {path}")
    # No pretrained download: the strict load below overwrites every tensor anyway.
    # build_model already moves the network to DEVICE.
    m = build_model(model_name, encoder_name, encoder_weights=None)
    # The checkpoint is a plain state_dict, so the safe tensor-only loader suffices.
    sd = torch.load(path, map_location="cpu", weights_only=True)
    m.load_state_dict(sd, strict=True)
    m.eval()
    return m


def load_best_val_dice(model_name: str, encoder_name: str, cid: int) -> float:
    """Best validation Dice stored in the client's best-checkpoint JSON (0.0 if absent).

    Raises:
        FileNotFoundError: when the JSON record does not exist.
    """
    p = best_json_path(model_name, encoder_name, cid)
    if not os.path.isfile(p):
        raise FileNotFoundError(f"Missing best json: {p}")
    with open(p, "r") as fh:
        record = json.load(fh)
    dice = record.get("val_dice", 0.0)
    return float(dice)


def get_client_weights(cid: int, cfgs, mode="power", power=2.0, eps=1e-6):
    """Normalized ensemble weights for client ``cid`` from each config's best val Dice.

    ``mode="linear"``: w ~ max(dice, 0); ``mode="power"``: w ~ max(dice, 0) ** power.
    ``eps`` is added before normalizing, so all-zero Dice scores give uniform weights.

    Returns:
        (weights as a list of floats, raw Dice scores as a list)

    Raises:
        ValueError: for an unknown ``mode`` (after the scores are loaded).
    """
    dices = [load_best_val_dice(mn, enc, cid) for (mn, enc) in cfgs]
    scores = np.clip(np.array(dices, dtype=np.float32), 0.0, None)

    if mode == "power":
        scores = np.power(scores, power)
    elif mode != "linear":
        raise ValueError("mode must be 'linear' or 'power'")

    scores = scores + eps
    return (scores / scores.sum()).tolist(), dices


@torch.no_grad()
def ensemble_forward_logits(x, models, weights, target_hw=None):
    """Weighted sum of per-model logits, with ``weights`` renormalized to sum to ~1.

    When ``target_hw`` is given, each model's logits are bilinearly resized to it
    if their spatial size differs. Returns None when ``models`` is empty.
    """
    norm = np.array(weights, dtype=np.float32)
    norm = norm / (norm.sum() + 1e-8)

    combined = None
    for model, w in zip(models, norm):
        out = model(x)
        if target_hw is not None and out.shape[-2:] != target_hw:
            out = F.interpolate(out, size=target_hw, mode="bilinear", align_corners=False)
        contribution = out * float(w)
        if combined is None:
            combined = contribution
        else:
            combined = combined + contribution
    return combined


@torch.no_grad()
def eval_ensemble_on_client(cid: int, threshold=0.5):
    """Score the Dice-weighted ensemble of ENSEMBLE_CFGS on client ``cid``'s val split.

    Each config contributes the client's own best checkpoint. Weights come from the
    recorded best val Dice (WEIGHT_MODE / WEIGHT_POWER / WEIGHT_EPS). Reported
    metrics are means of per-batch values.

    Returns:
        dict with cid, nva, loss, dice, iou, acc, weights and best_dices.
    """
    val_loader, nva = get_val_loader(cid)
    members = [load_model(mn, enc, cid) for (mn, enc) in ENSEMBLE_CFGS]

    weights, best_dices = get_client_weights(
        cid, ENSEMBLE_CFGS, mode=WEIGHT_MODE, power=WEIGHT_POWER, eps=WEIGHT_EPS
    )

    totals = [0.0, 0.0, 0.0, 0.0]  # loss, dice, iou, acc
    n_batches = 0
    for batch in val_loader:
        x = batch["x"].to(DEVICE)
        y = batch["y"].to(DEVICE)

        logits = ensemble_forward_logits(x, models=members, weights=weights, target_hw=y.shape[-2:])
        batch_loss = float(criterion(logits, y).item())

        pred_mask = (torch.sigmoid(logits).cpu().numpy() > threshold).astype(np.uint8)
        true_mask = (y.cpu().numpy() > 0.5).astype(np.uint8)
        d, i, a = calc_metrics(true_mask, pred_mask)

        for k, value in enumerate((batch_loss, d, i, a)):
            totals[k] += value
        n_batches += 1

    n_batches = max(n_batches, 1)
    return {
        "cid": int(cid),
        "nva": int(nva),
        "loss": totals[0] / n_batches,
        "dice": totals[1] / n_batches,
        "iou": totals[2] / n_batches,
        "acc": totals[3] / n_batches,
        "weights": weights,
        "best_dices": best_dices,
    }


def log_client_metrics(file_path, r, bd, ww):
    """Print a two-line ensemble summary for one client and append it to ``file_path``.

    ``r`` is a result dict from eval_ensemble_on_client; ``bd`` and ``ww`` are the
    per-model best val Dice scores and ensemble weights to display.
    """
    parent = os.path.dirname(file_path)
    if parent and not os.path.exists(parent):
        os.makedirs(parent)

    dice_strs = ['%.4f' % x for x in bd]
    weight_strs = ['%.3f' % x for x in ww]
    log_message = (
        f"Client {r['cid']} (n={r['nva']}): "
        f"loss={r['loss']:.4f} dice={r['dice']:.4f} iou={r['iou']:.4f} acc={r['acc']:.4f}\n"
        f"  best_dice={dice_strs} -> weights={weight_strs}\n"
    )

    print(log_message, end='')
    with open(file_path, "a") as handle:
        handle.write(log_message)

def main():
    """Evaluate the ensemble on each of the three clients' val splits and log the results."""
    print(f"[Ensemble-3 | threshold={THRESHOLD} | weight_mode={WEIGHT_MODE} | power={WEIGHT_POWER}]\n")
    for cid in range(3):
        result = eval_ensemble_on_client(cid, threshold=THRESHOLD)
        out_path = os.path.join('/content/AITDM', f"client_{cid}", "val_metrics_ensemble.txt")
        log_client_metrics(out_path, result, result["best_dices"], result["weights"])


if __name__ == "__main__":
    main()

In [None]:
!mkdir -p /content/drive/MyDrive/AITDM
!cp -rf /content/AITDM /content/drive/MyDrive/AITDM/TOPK_3

In [None]:
from google.colab import runtime
runtime.unassign()

In [None]:
!zip -r ./experiments.zip /content/drive/MyDrive/AITDM -x "*.pt"