## README

In this notebook, we test the performance of our pre-trained models on the bench and GS-transformed versions of the data and compare the results in terms of accuracy loss.

For this notebook, we use the pre-trained models under the `pretrained_models` directory. If you would prefer to retrain the models with the transformed data, please contact the corresponding author for the FL simulation framework.
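Because the evaluation depends on several model and data files being present on the mounted drive, it can help to confirm that every configured path exists before running anything. The sketch below is one possible check, not part of the original notebook: `find_missing_paths` and `example_config` are hypothetical names, and the helper assumes a nested dictionary whose string leaves are file paths, as in `VARIANT_CONFIGS`.

```python
import os

def find_missing_paths(config, prefix=""):
    """Return (key, path) pairs for every string leaf that is not an existing file."""
    missing = []
    for key, value in config.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse into nested groups such as "client_model_paths"
            missing.extend(find_missing_paths(value, prefix=f"{name}."))
        elif not os.path.isfile(value):
            missing.append((name, value))
    return missing

# Hypothetical example configuration; replace with your own paths
example_config = {
    "bench": {
        "global_model_path": "/no/such/dir/global_model.pth",
        "client_model_paths": {"client_1": "/no/such/dir/client_0_model.pth"},
    }
}

for name, path in find_missing_paths(example_config):
    print(f"Missing file for {name}: {path}")
```

Calling `find_missing_paths(VARIANT_CONFIGS)` after the configuration cell should return an empty list when all files are in place.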

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Imports and Functions

In [2]:
import os
import torch
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from tqdm import tqdm
import torch.nn as nn

# === CONFIGURATION ===
SAVE_PATH = "/content/drive/MyDrive/Spring 25/github_brainfl/results/accuracy"  # <-- Replace with your save directory
TEST_IMAGE_PATH = "/content/drive/MyDrive/Spring 25/github_brainfl/data/bench/test_images.pickle"  # <-- Replace
TEST_LABEL_PATH = "/content/drive/MyDrive/Spring 25/github_brainfl/data/bench/test_labels.pickle"  # <-- Replace

VARIANT_CONFIGS = {
    "bench": {
        "global_model_path": "/content/drive/MyDrive/Spring 25/github_brainfl/pretrained_models/global_model.pth",# <-- Replace
        "client_model_paths": {
            "client_1": "/content/drive/MyDrive/Spring 25/github_brainfl/pretrained_models/client_0_model.pth",# <-- Replace
            "client_2": "/content/drive/MyDrive/Spring 25/github_brainfl/pretrained_models/client_1_model.pth",# <-- Replace
            "client_3": "/content/drive/MyDrive/Spring 25/github_brainfl/pretrained_models/client_2_model.pth",# <-- Replace
        },
        "test_image_path": "/content/drive/MyDrive/Spring 25/github_brainfl/data/bench/test_images.pickle",# <-- Replace
        "test_label_path": "/content/drive/MyDrive/Spring 25/github_brainfl/data/bench/test_labels.pickle"# <-- Replace
    },
    "mask_20": {
        "global_model_path": "/content/drive/MyDrive/Spring 25/github_brainfl/pretrained_models/global_model_gs20p.pth",# <-- Replace
        "client_model_paths": {
            "client_1": "/content/drive/MyDrive/Spring 25/github_brainfl/pretrained_models/client_0_model_gs20p.pth",# <-- Replace
            "client_2": "/content/drive/MyDrive/Spring 25/github_brainfl/pretrained_models/client_1_model_gs20p.pth",# <-- Replace
            "client_3": "/content/drive/MyDrive/Spring 25/github_brainfl/pretrained_models/client_2_model_gs20p.pth",# <-- Replace
        },
        "test_image_path": "/content/drive/MyDrive/Spring 25/github_brainfl/data/gs/test_images_gs20p.pickle",# <-- Replace
        "test_label_path": "/content/drive/MyDrive/Spring 25/github_brainfl/data/bench/test_labels.pickle"# <-- Replace
    }
    # You can add more variants here to broaden your model evaluations.
    # To do so, GS-transform the original data with the desired settings, train a model,
    # save the trained models to your directory, and add their paths as a new entry.
}

# === MODEL ===
class BrainMRIClassifier(torch.nn.Module):
    def __init__(self):
        super(BrainMRIClassifier, self).__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=3, padding=1), torch.nn.ReLU(), torch.nn.MaxPool2d(2), torch.nn.BatchNorm2d(32),
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1), torch.nn.ReLU(), torch.nn.MaxPool2d(2), torch.nn.BatchNorm2d(64),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1), torch.nn.ReLU(), torch.nn.MaxPool2d(2), torch.nn.BatchNorm2d(128),
            torch.nn.Conv2d(128, 256, kernel_size=3, padding=1), torch.nn.ReLU(), torch.nn.MaxPool2d(2), torch.nn.BatchNorm2d(256),
            torch.nn.Conv2d(256, 256, kernel_size=3, padding=1), torch.nn.ReLU(), torch.nn.MaxPool2d(2), torch.nn.BatchNorm2d(256),
        )
        self.flat_features = 256 * 9 * 9
        self.classifier = torch.nn.Sequential(
            torch.nn.Flatten(), torch.nn.Dropout(0.5),
            torch.nn.Linear(self.flat_features, 512), torch.nn.ReLU(), torch.nn.Dropout(0.5),
            torch.nn.Linear(512, 256), torch.nn.ReLU(), torch.nn.Dropout(0.5),
            torch.nn.Linear(256, 4)
        )
    def forward(self, x):
        return self.classifier(self.features(x))

# === UTILS ===
def load_data(images_path, labels_path):
    """Load pickled test images and labels into a TensorDataset."""
    with open(images_path, 'rb') as f:
        images = pickle.load(f)
    with open(labels_path, 'rb') as f:
        labels = pickle.load(f)
    images = torch.tensor(np.array(images)).float()
    # Add a channel dimension when images are stored as (N, H, W)
    if images.ndim == 3:
        images = images.unsqueeze(1)
    labels = torch.tensor(np.array(labels)).long()
    return torch.utils.data.TensorDataset(images, labels)

def compute_metrics(model, loader, device):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            _, preds = torch.max(out, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average="macro")
    cm = confusion_matrix(all_labels, all_preds)
    return {"accuracy": acc, "f1_score": f1, "confusion_matrix": cm}

def create_heatmap(df, metric_prefix, output_file, title):
    """Save a heatmap of the chosen metric, with colors min-max normalized per column."""

    rows = df["Variant"]

    # Construct correct column names
    if metric_prefix in ["F1", "Acc"]:
        cols = [f"{k}_{metric_prefix}" for k in ["Global", "Client 1", "Client 2", "Client 3"]]
    else:
        raise ValueError(f"Unsupported metric_prefix: {metric_prefix}")

    data = df[cols].values
    normed_data = np.zeros_like(data)

    # Normalize per column
    for j in range(data.shape[1]):
        col = data[:, j]
        col_min, col_max = col.min(), col.max()
        if col_max > col_min:
            normed_data[:, j] = (col - col_min) / (col_max - col_min)
        else:
            normed_data[:, j] = 0.5  # neutral

    fig, ax = plt.subplots(figsize=(12, 8))

    cmap = plt.get_cmap("RdYlGn")
    im = ax.imshow(normed_data, cmap=cmap)

    # Set ticks and labels
    ax.set_xticks(np.arange(len(cols)))
    ax.set_yticks(np.arange(len(rows)))
    ax.set_xticklabels(["Global", "Client 1", "Client 2", "Client 3"])
    ax.set_yticklabels(rows)

    # Draw black grid lines
    for i in range(len(rows) + 1):
        ax.axhline(i - 0.5, color='black', linewidth=1)
    for j in range(len(cols) + 1):
        ax.axvline(j - 0.5, color='black', linewidth=1)

    # Add cell text
    for i in range(len(rows)):
        for j in range(len(cols)):
            ax.text(j, i, f"{data[i, j]:.2f}", ha="center", va="center", color="black", fontsize=11)

    ax.set_title(title, fontsize=16, pad=20)
    plt.tight_layout()
    plt.savefig(output_file, dpi=300)
    plt.close()


# === MAIN EXECUTION ===
def evaluate_and_plot():
    os.makedirs(SAVE_PATH, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rows = []

    print("\n🔍 Starting model evaluations...\n")
    for variant in tqdm(VARIANT_CONFIGS.keys(), desc="Evaluating variants"):
        paths = VARIANT_CONFIGS[variant]
        row = {"Variant": variant}

        tqdm.write(f"📁 Loading test data for variant: {variant}")
        test_dataset = load_data(paths["test_image_path"], paths["test_label_path"])
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

        tqdm.write(f"🔄 Evaluating global model for {variant}")
        model = BrainMRIClassifier().to(device)
        model.load_state_dict(torch.load(paths["global_model_path"], map_location=device))
        metrics = compute_metrics(model, test_loader, device)
        row["Global_F1"] = metrics["f1_score"]
        row["Global_Acc"] = metrics["accuracy"]

        for i in range(1, 4):
            key = f"client_{i}"
            tqdm.write(f"🔄 Evaluating {variant} - Client {i}")
            model = BrainMRIClassifier().to(device)
            model.load_state_dict(torch.load(paths["client_model_paths"][key], map_location=device))
            metrics = compute_metrics(model, test_loader, device)
            row[f"Client {i}_F1"] = metrics["f1_score"]
            row[f"Client {i}_Acc"] = metrics["accuracy"]

        rows.append(row)

    df = pd.DataFrame(rows)
    df.to_csv(os.path.join(SAVE_PATH, "multi_variant_metrics.csv"), index=False)
    print(df.head())

    tqdm.write("📊 Generating heatmaps...")
    create_heatmap(df, "F1", os.path.join(SAVE_PATH, "f1_score_comparison.png"), "F1 Score Comparison")
    create_heatmap(df, "Acc", os.path.join(SAVE_PATH, "accuracy_comparison.png"), "Accuracy Score Comparison")

    print("\n✅ All evaluations and plots successfully saved to:", SAVE_PATH)
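The classifier hard-codes `flat_features = 256 * 9 * 9`, which constrains the input resolution the model can accept. As a rough check, the sketch below traces the spatial size through the five convolution and pooling blocks, assuming the standard output-size rules: a 3×3 convolution with padding 1 and stride 1 keeps the size, and `MaxPool2d(2)` halves it with floor division. The helper name `feature_map_size` and the 288×288 example input are illustrative assumptions; the notebook does not state the image size.

```python
def feature_map_size(input_size, num_blocks=5, kernel=3, padding=1, pool=2):
    """Spatial size after num_blocks of [Conv2d(stride 1) -> MaxPool2d(pool)]."""
    size = input_size
    for _ in range(num_blocks):
        size = size + 2 * padding - kernel + 1  # convolution with stride 1
        size = size // pool                     # max pooling (floor division)
    return size

# Assumed example: a 288x288 single-channel image
side = feature_map_size(288)
print(side, 256 * side * side)  # prints: 9 20736
```

Under these rules, any side length from 288 to 319 pixels reduces to 9×9; other sizes would produce a shape mismatch at the first `Linear` layer.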

In [3]:
# Run
if __name__ == "__main__":
    evaluate_and_plot()


🔍 Starting model evaluations...

📁 Loading test data for variant: bench
🔄 Evaluating global model for bench
🔄 Evaluating bench - Client 1
🔄 Evaluating bench - Client 2
🔄 Evaluating bench - Client 3
📁 Loading test data for variant: mask_20
🔄 Evaluating global model for mask_20
🔄 Evaluating mask_20 - Client 1
🔄 Evaluating mask_20 - Client 2
🔄 Evaluating mask_20 - Client 3
Evaluating variants: 100%|██████████| 2/2 [00:52<00:00, 26.31s/it]


   Variant  Global_F1  Global_Acc  Client 1_F1  Client 1_Acc  Client 2_F1  Client 2_Acc  Client 3_F1  Client 3_Acc
0    bench   0.945754       0.948     0.897482         0.908     0.954442         0.957     0.948827         0.951
1  mask_20   0.969170       0.971     0.919536         0.926     0.957561         0.959     0.944202         0.948
📊 Generating heatmaps...

✅ All evaluations and plots successfully saved to: /content/drive/MyDrive/Spring 25/github_brainfl/results/accuracy
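Since the README frames the comparison in terms of accuracy loss, it may be useful to compute the per-model difference explicitly. The sketch below does this for the accuracy values printed in the output above; `accuracy_change` is a hypothetical helper that is not part of the notebook, and a positive value means the `mask_20` model scored higher than its `bench` counterpart.

```python
# Accuracy values copied from the DataFrame printed in the cell output
results = {
    "bench":   {"Global": 0.948, "Client 1": 0.908, "Client 2": 0.957, "Client 3": 0.951},
    "mask_20": {"Global": 0.971, "Client 1": 0.926, "Client 2": 0.959, "Client 3": 0.948},
}

def accuracy_change(results, baseline="bench", variant="mask_20"):
    """Return variant accuracy minus baseline accuracy for each model."""
    return {
        model: round(results[variant][model] - results[baseline][model], 3)
        for model in results[baseline]
    }

for model, delta in accuracy_change(results).items():
    print(f"{model}: {delta:+.3f}")
```

For this run, only Client 3 shows a small drop (0.003); the global model and the other two clients score slightly higher on the GS-transformed data than on the bench data.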
