In [None]:
# !pip install flwr
# !pip install monai
# !pip install kornia
# !pip install nibabel
# !pip install scikit-image

## README

This notebook trains federated learning (FL) models on the original (bench) data and on the GS-transformed data so that their performance can be compared. The trained models are saved for later use in that performance comparison and in the model-stealing attacks that target the trained model parameters or gradients.

If you don't want to run this notebook, pre-trained models are provided in the GitHub repository under the ```pretrained_models``` folder.

If you want to run this notebook, first:
1. Run the ```1_Load_and_preprocess_Data``` and ```2_GS_Transform_data``` notebooks and save the original and transformed data in your repository.
2. Uncomment and run the pip install commands in the cell above.
3. Modify the load and save paths (see the `Modify DATA_PATHS here!!` cell) to match the folder hierarchy of your workspace.

In [None]:
# Mount Google Drive (Colab only) so the /content/drive/MyDrive/... data paths used below resolve.
# The original line read `1from google...`, which is a SyntaxError.
from google.colab import drive
drive.mount('/content/drive')

## Imports and Functions

In [None]:
import os
import sys
import gc
import pickle
import logging
from typing import Tuple, List, Dict, Union, Optional
from collections import defaultdict
from dataclasses import dataclass
import threading
import time

# gRPC settings picked up by Flower's transport layer: 4,000,000,000-byte send/receive
# message caps and gzip as the default compression algorithm.
os.environ.update({
    "GRPC_MAX_RECEIVE_MESSAGE_LENGTH": "4000000000",
    "GRPC_MAX_SEND_MESSAGE_LENGTH": "4000000000",
    "GRPC_DEFAULT_COMPRESSION_ALGORITHM": "gzip",
})

# Root logger at INFO; every record carries time, level, and source file:line.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ----------------------------
# PyTorch and Federated Learning
# ----------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms as T
from torch.amp import autocast, GradScaler # For mixed precision training
# Flower for federated learning coordination
import flwr as fl

# ----------------------------
# Medical Imaging and Inversion Attack Utilities
# ----------------------------
# MONAI: for medical image processing and transformations
import monai

# Kornia: for differentiable computer vision operations (e.g., edge detection)
import kornia
import kornia.enhance
import kornia.losses
# nibabel: for reading/writing medical image formats (e.g., NIfTI)
import nibabel as nib

# ----------------------------
# Data Processing, Evaluation, and Visualization
# ----------------------------
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    auc,
    classification_report,
    confusion_matrix,
    f1_score,
    roc_curve,
)
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# scikit-image: for image quality metrics (e.g., SSIM)
from skimage.metrics import structural_similarity as ssim


In [None]:
## Medical Classifier
class BrainMRIClassifier(nn.Module):
    """Five-block CNN classifier for brain MRI slices.

    Each feature block is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2, 2) -> BatchNorm2d,
    so every block halves the spatial size (floor division). The flattened feature map
    feeds a three-layer MLP head with dropout 0.5.

    Args:
        in_channels: number of input image channels (default 1).
        num_classes: number of output logits (default 4).
        input_size: side length of the square input images (default 299). It only sizes
            the first Linear layer; forward() fails for inputs of any other size.

    The defaults reproduce the original hard-coded architecture exactly (same layer
    order, so the same state_dict keys and shapes), keeping saved checkpoints loadable.

    Raises:
        ValueError: if input_size is too small to survive the five 2x2 poolings (< 32).
    """

    # Output channels of the five conv blocks.
    _BLOCK_CHANNELS = (32, 64, 128, 256, 256)

    def __init__(self, in_channels=1, num_classes=4, input_size=299):
        super().__init__()
        layers = []
        prev_channels = in_channels
        for out_channels in self._BLOCK_CHANNELS:
            layers += [
                nn.Conv2d(prev_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),  # halves H and W (floor)
                nn.BatchNorm2d(out_channels),
            ]
            prev_channels = out_channels
        self.features = nn.Sequential(*layers)

        # Track the spatial size through the poolings: 299 -> 149 -> 74 -> 37 -> 18 -> 9.
        spatial = input_size
        for _ in self._BLOCK_CHANNELS:
            spatial //= 2
        if spatial < 1:
            raise ValueError(
                f"input_size={input_size} is too small for {len(self._BLOCK_CHANNELS)} "
                "2x2 pooling stages (need at least 32)"
            )
        self.flat_features = prev_channels * spatial * spatial

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(self.flat_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Map a (batch, in_channels, input_size, input_size) tensor to (batch, num_classes) logits."""
        x = self.features(x)
        return self.classifier(x)


In [None]:
## Training Loop Definition

def train_classifier_model(model, train_loader, device, epochs=1, lr=0.001):
    """Train `model` in place with Adam + cross-entropy for `epochs` passes over `train_loader`.

    A fresh Adam optimizer is created on every call, so optimizer state (moment
    estimates) does not carry over between calls.

    Args:
        model: classifier producing per-class logits along dim 1.
        train_loader: iterable of (images, labels) batches; labels are class indices.
        device: device each batch is moved to (assumes the model is already there).
        epochs: number of passes over `train_loader`; must be >= 1.
        lr: Adam learning rate (default 0.001, the previously hard-coded value).

    Returns:
        (avg_loss, accuracy) of the LAST epoch: per-sample mean loss and fraction correct.

    Raises:
        ValueError: if `epochs` < 1 or `train_loader` yields no samples.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        epoch_loss = 0.0
        correct = 0
        total = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # CrossEntropyLoss averages over the batch; re-weight so the epoch mean is per sample.
            epoch_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError("train_loader yielded no samples")
        avg_loss = epoch_loss / total
        accuracy = correct / total
        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Accuracy={accuracy:.4f}")

    return avg_loss, accuracy

def evaluate_model(model, test_loader, device):
    """Compute the per-sample mean cross-entropy loss and accuracy of `model` on `test_loader`.

    Puts the model in eval mode and runs without gradient tracking.

    Args:
        model: classifier producing per-class logits along dim 1.
        test_loader: iterable of (images, labels) batches; labels are class indices.
        device: device each batch is moved to.

    Returns:
        (avg_loss, accuracy) over all samples yielded by `test_loader`.

    Raises:
        ValueError: if `test_loader` yields no samples (previously a ZeroDivisionError).
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            # CrossEntropyLoss averages over the batch; re-weight so the mean is per sample.
            total_loss += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("test_loader yielded no samples")
    return total_loss / total, correct / total


In [None]:
## Federated Setup
import flwr as fl

class FLClient(fl.client.NumPyClient):
    """Flower NumPyClient that trains and evaluates one local model.

    Parameters are exchanged as the full state_dict, including buffers such as
    BatchNorm running statistics. They are passed as a list of arrays in state_dict order.

    Args:
        model: local torch model (already on `device`).
        train_loader: DataLoader used by fit().
        test_loader: DataLoader used by evaluate().
        device: device batches are moved to.
        epochs: local epochs per fit() call.
    """

    def __init__(self, model, train_loader, test_loader, device, epochs=1):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.epochs = epochs

    def get_parameters(self, config=None):
        """Return the model's state_dict values as CPU NumPy arrays (state_dict order).

        `config` is accepted and ignored because Flower 1.x calls `get_parameters(config)`.
        The zero-argument form still works for existing callers.
        """
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Load a list of arrays (state_dict order) into the model.

        Raises:
            ValueError: if the number of arrays does not match the model's state_dict.
        """
        keys = list(self.model.state_dict().keys())
        # zip() would silently drop surplus arrays; make a length mismatch explicit.
        if len(parameters) != len(keys):
            raise ValueError(f"Expected {len(keys)} parameter arrays, got {len(parameters)}")
        state_dict = {k: torch.tensor(v) for k, v in zip(keys, parameters)}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train locally from the received parameters.

        Returns:
            (updated parameters, number of training samples, {"loss", "accuracy"} of the last epoch)
        """
        self.set_parameters(parameters)
        print("Starting local training...")
        train_loss, train_acc = train_classifier_model(self.model, self.train_loader, self.device, epochs=self.epochs)
        return self.get_parameters(config), len(self.train_loader.dataset), {"loss": train_loss, "accuracy": train_acc}

    def evaluate(self, parameters, config):
        """Evaluate the received parameters on `test_loader`.

        Returns:
            (loss, number of test samples, {"accuracy"})
        """
        self.set_parameters(parameters)
        eval_loss, eval_acc = evaluate_model(self.model, self.test_loader, self.device)
        return float(eval_loss), len(self.test_loader.dataset), {"accuracy": eval_acc}


In [None]:
class FederatedLearningSimulator:
    """In-process FedAvg simulation (sequential clients; no Flower server involved).

    Each round, every client starts from the current global parameters and trains
    locally for `epochs_per_round` epochs, printing train/validation accuracy after
    each epoch. The resulting state_dicts are then averaged, weighted by client dataset size.

    Args:
        model: zero-argument callable returning a fresh nn.Module (e.g. the
            BrainMRIClassifier class itself, not an instance).
        client_loaders: one training DataLoader per client.
        test_loader: DataLoader used for per-epoch client validation and for evaluating
            the aggregated global model after each round.
        device: torch device models are moved to.
        num_rounds: number of federated rounds.
        epochs_per_round: local epochs per client per round.

    After run_federated_learning(), `client_models` holds the (client_idx, model) pairs
    from the final round (kept for inversion attacks).

    NOTE(review): local epochs are run via train_classifier_model(..., epochs=1) once per
    epoch. That function builds a new Adam optimizer per call, so Adam state resets every
    local epoch.
    """

    def __init__(self, model, client_loaders, test_loader, device, num_rounds=5, epochs_per_round=2):
        self.model = model
        self.client_loaders = client_loaders
        self.test_loader = test_loader
        self.device = device
        self.num_rounds = num_rounds
        self.epochs_per_round = epochs_per_round
        self.global_parameters = None
        self.client_models = []  # final-round client models, for inversion attacks

    def _build_model(self, parameters=None):
        """Instantiate a fresh model on `self.device`, optionally loading a parameter list (state_dict order)."""
        model = self.model().to(self.device)
        if parameters is not None:
            keys = list(model.state_dict().keys())
            if len(keys) != len(parameters):
                raise ValueError(f"Expected {len(keys)} parameter arrays, got {len(parameters)}")
            model.load_state_dict({k: torch.tensor(v) for k, v in zip(keys, parameters)}, strict=True)
        return model

    @staticmethod
    def _state_to_numpy(model):
        """Snapshot every state_dict entry (weights and buffers) as CPU NumPy arrays."""
        return [val.cpu().numpy() for val in model.state_dict().values()]

    def initialize_global_model(self):
        """Create a freshly initialized model and store its parameters as the global state."""
        self.global_parameters = self._state_to_numpy(self._build_model())

    def run_federated_learning(self):
        """Run `num_rounds` FedAvg rounds and return the final global parameter list."""
        self.initialize_global_model()
        self.client_models = []  # Reset client models

        for round_num in range(self.num_rounds):
            print(f"\n===== Round {round_num+1}/{self.num_rounds} =====")
            is_final_round = round_num == self.num_rounds - 1
            client_parameters = []
            client_weights = []
            round_client_models = []

            for client_idx, loader in enumerate(self.client_loaders):
                print(f"\nTraining client {client_idx+1}:")
                client_model = self._build_model(self.global_parameters)

                # Local training with per-epoch validation on the shared test set.
                for epoch in range(self.epochs_per_round):
                    _, train_acc = train_classifier_model(client_model, loader, self.device, epochs=1)
                    _, val_acc = evaluate_model(client_model, self.test_loader, self.device)
                    print(f"  Epoch {epoch+1}/{self.epochs_per_round}: Train acc={train_acc:.4f}, Val acc={val_acc:.4f}")

                # A fresh model is built per client per round and is not trained again
                # afterwards, so it can be kept directly without a defensive copy.
                if is_final_round:
                    round_client_models.append((client_idx, client_model))

                client_parameters.append(self._state_to_numpy(client_model))
                client_weights.append(len(loader.dataset))

            if is_final_round:
                self.client_models = round_client_models

            self.global_parameters = self.federated_averaging(client_parameters, client_weights)

            # Evaluate global model after aggregation
            global_model = self._build_model(self.global_parameters)
            _, eval_acc = evaluate_model(global_model, self.test_loader, self.device)
            print(f"\n>> Round {round_num+1} complete: Global model accuracy = {eval_acc:.4f}")

        return self.global_parameters

    def federated_averaging(self, client_parameters, client_weights):
        """Weighted average (FedAvg) of per-client parameter lists.

        Args:
            client_parameters: one list of arrays per client, all in the same order.
            client_weights: one non-negative weight per client (e.g. dataset size).

        Returns:
            List of averaged arrays. Each array has the dtype of the first client's
            corresponding array. Accumulation is done in float64; integer buffers (e.g.
            BatchNorm num_batches_tracked) are truncated back to their integer dtype.

        Raises:
            ValueError: on no clients, mismatched lengths, or non-positive total weight.
        """
        if not client_parameters:
            raise ValueError("federated_averaging needs at least one client")
        if len(client_parameters) != len(client_weights):
            raise ValueError("client_parameters and client_weights differ in length")
        if len({len(params) for params in client_parameters}) != 1:
            raise ValueError("clients returned different numbers of parameter arrays")
        total_weight = float(sum(client_weights))
        if total_weight <= 0:
            raise ValueError("sum of client_weights must be positive")

        averaged = []
        # zip(*...) walks the parameter lists layer by layer across all clients.
        for layer_values in zip(*client_parameters):
            weighted_sum = sum(
                np.asarray(values, dtype=np.float64) * weight
                for values, weight in zip(layer_values, client_weights)
            )
            target_dtype = np.asarray(layer_values[0]).dtype
            averaged.append(np.asarray(weighted_sum / total_weight).astype(target_dtype))
        return averaged


In [None]:
def client_fn(cid: str) -> fl.client.NumPyClient:
    """Flower client factory: build an FLClient for client id `cid`.

    `cid` is parsed as an integer index into the notebook-level `client_loaders` list.
    The notebook-level `test_loader` serves as the validation set and `device` as the
    target device; these globals must exist when this is called. Each client gets a
    fresh BrainMRIClassifier and trains for 2 local epochs per fit().
    """
    client_index = int(cid)
    local_model = BrainMRIClassifier().to(device)
    return FLClient(
        local_model,
        client_loaders[client_index],
        test_loader,
        device,
        epochs=2,
    )

## Data Loading Functions
def load_data(images_path, labels_path):
    """Unpickle and return `(images, labels)` from two pickle files.

    Security note: pickle.load can execute arbitrary code embedded in the file — only
    load pickles you produced yourself (e.g. via the preprocessing notebooks).
    """
    def _unpickle(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    return _unpickle(images_path), _unpickle(labels_path)

def prepare_dataset(images, labels):
    """Wrap raw images and labels in a TensorDataset ready for a DataLoader.

    images: array-like of images, converted to float32. If the result is 3-D, it is read
        as (N, H, W) and a channel axis is inserted to give (N, 1, H, W). Other ranks
        are passed through unchanged.
    labels: array-like of integer class indices, converted to int64 (the target dtype
        CrossEntropyLoss expects for class indices).
    """
    image_tensor = torch.tensor(np.array(images)).to(torch.float32)
    image_tensor = image_tensor.unsqueeze(1) if image_tensor.ndim == 3 else image_tensor
    label_tensor = torch.tensor(np.array(labels)).to(torch.int64)
    return torch.utils.data.TensorDataset(image_tensor, label_tensor)

def save_model_states(global_model, client_models, key, save_dir="saved_models"):
    """Write the global model's and each client model's state_dict to `save_dir`.

    Files are named `<key>_global_model.pt` and `<key>_client_<idx+1>_model.pt`
    (client indices are shown 1-based). `save_dir` is created if missing.

    Parameters:
      - global_model: the aggregated global model.
      - client_models: iterable of (client_idx, client_model) pairs, e.g.
        `fl_simulator.client_models`.
      - key: data-variant name used as the filename prefix (e.g. 'bench', 'mask_20').
      - save_dir: output directory.

    Returns:
      {"global": <global path>, "clients": {"client_<n>": <path>, ...}}
    """
    os.makedirs(save_dir, exist_ok=True)

    def _target(stem):
        return os.path.join(save_dir, f"{key}_{stem}.pt")

    global_path = _target("global_model")
    torch.save(global_model.state_dict(), global_path)

    clients = {}
    for idx, model in client_models:
        tag = f"client_{idx + 1}"
        destination = _target(f"{tag}_model")
        torch.save(model.state_dict(), destination)
        clients[tag] = destination

    return {"global": global_path, "clients": clients}

In [None]:
#####################################################
# Helper Function to Evaluate a Given Model on Test
#####################################################
def compute_metrics(model, loader, device, labels=None):
    """Evaluate `model` on `loader` and return classification metrics.

    Args:
        model: classifier whose outputs hold per-class scores along dim 1.
        loader: iterable of (images, labels) batches.
        device: device images are moved to.
        labels: optional list of class labels to report on, in order. When given, the
            confusion matrix, per-class accuracy, F1 and report use exactly these labels,
            so a class absent from this split still gets a row. If it is None (the
            default), the labels seen in y_true or y_pred are used.

    Returns:
        dict with "accuracy", macro "f1_score", "per_class_accuracy" (NaN for classes with
        no true samples), "classification_report" (dict form) and "confusion_matrix".

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, targets in loader:
            outputs = model(images.to(device))
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(targets.cpu().numpy())

    if not all_labels:
        raise ValueError("compute_metrics: loader yielded no samples")

    acc = accuracy_score(all_labels, all_preds)
    # zero_division=0 gives the same values as sklearn's default, without the warnings.
    f1 = f1_score(all_labels, all_preds, labels=labels, average='macro', zero_division=0)
    cm = confusion_matrix(all_labels, all_preds, labels=labels)
    support = cm.sum(axis=1)
    # Rows with no true samples would be 0/0; report NaN for them without a RuntimeWarning.
    per_class_acc = np.divide(cm.diagonal(), support,
                              out=np.full(len(support), np.nan), where=support > 0)
    report = classification_report(all_labels, all_preds, labels=labels,
                                   output_dict=True, zero_division=0)
    return {
        "accuracy": acc,
        "f1_score": f1,
        "per_class_accuracy": per_class_acc,
        "classification_report": report,
        "confusion_matrix": cm
    }

# Annotate bars with values
def autolabel(rects, ax=None):
    """Write each bar's height (2 decimals) just above the bar.

    Args:
        rects: iterable of bar patches (e.g. the container returned by `ax.bar`).
        ax: Axes to draw the labels on. If omitted, each bar's own Axes (`rect.axes`)
            is used instead of relying on a global `ax` variable.

    Raises:
        ValueError: if `ax` is omitted and a bar is not attached to any Axes.
    """
    for rect in rects:
        target_ax = ax if ax is not None else getattr(rect, "axes", None)
        if target_ax is None:
            raise ValueError("autolabel: bar is not attached to an Axes; pass ax= explicitly")
        height = rect.get_height()
        target_ax.annotate(f'{height:.2f}',
                           xy=(rect.get_x() + rect.get_width() / 2, height),
                           xytext=(0, 3),  # offset text by 3 points vertically
                           textcoords="offset points",
                           ha='center', va='bottom')

def compute_confidence_metrics(model, loader, device):
    """
    Computes the mean and standard deviation of the model's confidence (i.e. max softmax
    probability over dim 1) grouped by TRUE class over `loader`.

    Args:
      model: classifier whose outputs hold per-class logits along dim 1.
      loader: iterable of (images, labels) batches.
      device: device images are moved to (labels stay where they are).

    Returns:
      mean_conf: dict mapping class label (int) to mean confidence.
      std_conf: dict mapping class label (int) to (population) standard deviation.
      Only classes that occur in `loader` appear, in first-seen order.
    """
    model.eval()
    confidences = defaultdict(list)
    with torch.no_grad():
        for images, labels in loader:
            probs = torch.softmax(model(images.to(device)), dim=1)
            max_probs = probs.max(dim=1).values
            for lbl, conf in zip(torch.as_tensor(labels).tolist(), max_probs.tolist()):
                confidences[int(lbl)].append(conf)
    mean_conf = {lbl: np.mean(vals) for lbl, vals in confidences.items()}
    std_conf = {lbl: np.std(vals) for lbl, vals in confidences.items()}
    return mean_conf, std_conf


## Main Execution Block

In [None]:
#@title Modify DATA_PATHS here!!
# Root folder containing the `bench/` and `gs/` sub-folders written by the
# 1_Load_and_preprocess_Data and 2_GS_Transform_data notebooks.
DATA_ROOT = "/content/drive/MyDrive/Spring 25/github_brainfl/data"  ## <-Replace
NUM_CLIENTS = 3


def _variant_paths(image_dir, image_suffix=""):
    """Build the {"clients": [...], "test": {...}} path layout for one data variant.

    Images are read from `<DATA_ROOT>/<image_dir>/client{i}_images<image_suffix>.pickle`
    (and `client_test_images<image_suffix>.pickle` for the test split). Labels are always
    read from the bench folder, so GS variants reuse the bench labels.
    """
    bench_dir = f"{DATA_ROOT}/bench"
    images_dir = f"{DATA_ROOT}/{image_dir}"
    return {
        "clients": [
            {
                "images": f"{images_dir}/client{i}_images{image_suffix}.pickle",
                "labels": f"{bench_dir}/client{i}_labels.pickle",
            }
            for i in range(1, NUM_CLIENTS + 1)
        ],
        "test": {
            "images": f"{images_dir}/client_test_images{image_suffix}.pickle",
            "labels": f"{bench_dir}/client_test_labels.pickle",
        },
    }


MASKED_DATA_PATHS = {
    "bench": _variant_paths("bench"),
    "mask_20": _variant_paths("gs", "_gs20p"),
    ## You can add more data variants here to test different masking settings with GS,
    ## e.g. for the files client{1,2,3}_images_gs50p.pickle / client_test_images_gs50p.pickle:
    # "mask_50": _variant_paths("gs", "_gs50p"),
}



In [None]:
#####################################################
# Loop Over Masking Settings and Run Simulation
#####################################################
# For each data variant in MASKED_DATA_PATHS, this cell:
# - loads the data and runs the federated simulation;
# - evaluates the global model and the final-round client models;
# - saves the model weights;
# - records everything in `results`, which is pickled at the end of the cell.

SEED = 42             # re-applied per variant so every variant starts from the same initial weights
NUM_ROUNDS = 10       # Adjustable number of rounds
EPOCHS_PER_ROUND = 5  # Adjustable epochs per round
BATCH_SIZE = 32
RESULTS_PATH = "federated_results.pkl"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Dictionary to store evaluation metrics (and saved-model paths) for each setting
results = {}

for mask_name, paths in MASKED_DATA_PATHS.items():
    print(f"\n=== Evaluating for Masking Setting: {mask_name} ===\n")
    # Seeding before any model is built gives every variant the same initial global
    # weights, so performance differences are not due to a different random init.
    torch.manual_seed(SEED)

    # 1. Load test data
    test_images, test_labels = load_data(paths["test"]["images"], paths["test"]["labels"])
    test_dataset = prepare_dataset(test_images, test_labels)
    del test_images, test_labels  # raw copies are no longer needed once tensorized
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # 2. Load training data for every configured client
    client_datasets = [
        prepare_dataset(*load_data(client_info["images"], client_info["labels"]))
        for client_info in paths["clients"]
    ]
    client_loaders = [DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True) for ds in client_datasets]

    # 3. Create and run the federated learning simulation
    fl_simulator = FederatedLearningSimulator(
        model=BrainMRIClassifier,
        client_loaders=client_loaders,
        test_loader=test_loader,
        device=device,
        num_rounds=NUM_ROUNDS,
        epochs_per_round=EPOCHS_PER_ROUND,
    )

    print("Starting federated learning simulation...")
    global_parameters = fl_simulator.run_federated_learning()
    print("Federated training simulation completed.")

    # 4. Evaluate the final global model (metrics + per-class confidence)
    global_model = BrainMRIClassifier().to(device)
    params_dict = dict(zip(global_model.state_dict().keys(), global_parameters))
    state_dict = {k: torch.tensor(v) for k, v in params_dict.items()}
    global_model.load_state_dict(state_dict, strict=True)
    global_metrics = compute_metrics(global_model, test_loader, device)
    global_mean_conf, global_std_conf = compute_confidence_metrics(global_model, test_loader, device)

    # Evaluate each client model saved from the final round
    client_metrics = {
        f"client_{client_idx+1}": compute_metrics(client_model, test_loader, device)
        for client_idx, client_model in fl_simulator.client_models
    }

    # 5. Save the models and record metrics + saved file paths for this masking setting
    results[mask_name] = {
        "global": global_metrics,
        "clients": client_metrics,
        "confidence": {"mean": global_mean_conf, "std": global_std_conf},
        "model_paths": save_model_states(global_model, fl_simulator.client_models, mask_name),
    }

    # 6. Clear memory before the next variant
    del (global_model, global_parameters, params_dict, state_dict, fl_simulator,
         client_datasets, client_loaders, test_dataset, test_loader)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    time.sleep(1)  # optional pause between iterations

# Save the results dictionary (evaluation metrics and model paths) to disk
with open(RESULTS_PATH, "wb") as f:
    pickle.dump(results, f)


In [None]:
#@title Visualize Model Performance w/ Heatmaps

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

if not results:
    raise RuntimeError("`results` is empty - run the federated training cell first.")

# Figure 1: F1 Score Heatmap
# ---------------------------------
dataset_variants = list(results.keys())
# Digit-string keys of a classification_report are the class labels ("0", "1", ...).
# Take the union over ALL variants so a class missing from one variant still gets a column.
class_labels = sorted(
    {
        k
        for variant in dataset_variants
        for k in results[variant]["global"]["classification_report"]
        if k.isdigit()
    },
    key=int,
)
class_labels_int = [int(x) for x in class_labels]

# Rows: dataset variant; columns: per-class F1 (NaN, drawn blank, where a variant lacks the class)
f1_matrix = np.full((len(dataset_variants), len(class_labels)), np.nan)
for i, variant in enumerate(dataset_variants):
    report = results[variant]["global"]["classification_report"]
    for j, cls in enumerate(class_labels):
        if cls in report:
            f1_matrix[i, j] = report[cls]["f1-score"]

fig, heat_ax = plt.subplots(figsize=(10, 6))
sns.heatmap(f1_matrix, annot=True, fmt=".2f", cmap="YlGnBu",
            xticklabels=class_labels_int, yticklabels=dataset_variants, ax=heat_ax)
heat_ax.set_xlabel("Class Label")
heat_ax.set_ylabel("Dataset Variant")
heat_ax.set_title("Global Model Per-Class F1 Scores")
fig.tight_layout()
plt.show()

# Figure 2: Model Confidence per Class (Facet Grid)
# ---------------------------------
# For each class, collect the mean/std confidence of every variant (NaN if the class is absent).
n_classes = len(class_labels_int)
confidence_data = {cls: {"mean": [], "std": []} for cls in class_labels_int}
for variant in dataset_variants:
    conf = results[variant]["confidence"]["mean"]
    conf_std = results[variant]["confidence"]["std"]
    for cls in class_labels_int:
        confidence_data[cls]["mean"].append(conf.get(cls, np.nan))
        confidence_data[cls]["std"].append(conf_std.get(cls, np.nan))

# squeeze=False always returns a 2-D grid, so a single class needs no special case.
fig, axs = plt.subplots(1, n_classes, figsize=(5 * n_classes, 5), sharey=True, squeeze=False)
axs = axs[0]

x = np.arange(len(dataset_variants))
for i, cls in enumerate(class_labels_int):
    ax = axs[i]
    means = np.array(confidence_data[cls]["mean"])
    stds = np.array(confidence_data[cls]["std"])
    ax.errorbar(x, means, yerr=stds, fmt='o-', capsize=5)
    ax.set_xticks(x)
    ax.set_xticklabels(dataset_variants, rotation=45, ha="right")
    ax.set_title(f"Class {cls}")
    ax.set_xlabel("Dataset Variant")
    if i == 0:
        ax.set_ylabel("Mean Confidence (Softmax Probability)")
    ax.grid(True, linestyle='--', alpha=0.5)

fig.suptitle("Global Model Confidence per Class Across Dataset Variants", fontsize=16)
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()
