In [1]:
!pip install -q kaggle timm h5py torchmetrics tqdm seaborn matplotlib scikit-learn torchinfo

import os
import json
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassF1Score,
    MulticlassAUROC,
    MulticlassMatthewsCorrCoef
)

from torchinfo import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.6/2.6 MB[0m [31m95.8 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.6/2.6 MB[0m [31m48.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/983.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m49.8 MB/s[0m eta [36m0:00:00[0m
[?25hUsing device: cpu


In [2]:
KAGGLE_USERNAME = "YOUR_USERNAME"
KAGGLE_KEY = "YOUR_KEY"

os.makedirs("/root/.config/kaggle", exist_ok=True)
with open("/root/.config/kaggle/kaggle.json", "w") as f:
    json.dump({"username": KAGGLE_USERNAME, "key": KAGGLE_KEY}, f)
os.chmod("/root/.config/kaggle/kaggle.json", 0o600)

from kaggle.api.kaggle_api_extended import KaggleApi

DATA_DIR = "brain_tumor_dataset"
DATASET = "ashkhagan/figshare-brain-tumor-dataset"

if not os.path.exists(DATA_DIR):
    api = KaggleApi()
    api.authenticate()
    api.dataset_download_files(DATASET, path=DATA_DIR, unzip=True)


Dataset URL: https://www.kaggle.com/datasets/ashkhagan/figshare-brain-tumor-dataset


In [3]:
mat_files = []
for root, _, files in os.walk(DATA_DIR):
    for f in files:
        if f.endswith(".mat") and "cvind" not in f.lower():
            mat_files.append(os.path.join(root, f))

paths, labels = [], []

for f in mat_files:
    with h5py.File(f, "r") as mat:
        if "label" in mat:
            label = int(mat["label"][()][0][0])
        else:
            label = int(mat["cjdata"]["label"][()][0][0])
    paths.append(f)
    labels.append(label - 1)

df = pd.DataFrame({"image_path": paths, "label": labels})
num_classes = df["label"].nunique()

train_df, test_df = train_test_split(df, stratify=df.label, test_size=0.2, random_state=42)
train_df, val_df  = train_test_split(train_df, stratify=train_df.label, test_size=0.2, random_state=42)


In [4]:
class BrainTumorDataset(Dataset):
    def __init__(self, df, transform):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        path = self.df.loc[idx, "image_path"]
        label = self.df.loc[idx, "label"]

        with h5py.File(path, "r") as mat:
            if "image" in mat:
                img = mat["image"][()]
            else:
                img = mat["cjdata"]["image"][()]

        img = np.array(img).T
        img = np.stack([img]*3, axis=-1).astype(np.uint8)
        img = self.transform(img)

        return img, label


In [5]:
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406),
                         (0.229,0.224,0.225))
])

train_loader = DataLoader(BrainTumorDataset(train_df, transform), batch_size=8, shuffle=True)
val_loader   = DataLoader(BrainTumorDataset(val_df, transform), batch_size=8)
test_loader  = DataLoader(BrainTumorDataset(test_df, transform), batch_size=8)


In [6]:
model = timm.create_model(
    "inception_v3",
    pretrained=True,
    num_classes=num_classes,
    aux_logits=True
)

model = model.to(device)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/95.5M [00:00<?, ?B/s]

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

EPOCHS = 30
patience = 5

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=EPOCHS
)


In [8]:
def forward_pass(model, x):
    outputs = model(x)

    if isinstance(outputs, tuple):
        main_out = outputs[0]
        aux_out = outputs[1]
        return main_out, aux_out
    else:
        return outputs, None


In [9]:
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassF1Score,
    MulticlassAUROC,
    MulticlassMatthewsCorrCoef
)


In [10]:
def get_metrics():
    return {
        "acc": MulticlassAccuracy(num_classes=num_classes).to(device),
        "precision": MulticlassPrecision(num_classes=num_classes, average="macro").to(device),
        "recall": MulticlassRecall(num_classes=num_classes, average="macro").to(device),
        "f1": MulticlassF1Score(num_classes=num_classes, average="macro").to(device),
        "auc": MulticlassAUROC(num_classes=num_classes).to(device),
        "mcc": MulticlassMatthewsCorrCoef(num_classes=num_classes).to(device)
    }

def multiclass_specificity_sensitivity(y_true, y_pred):
    cm = confusion_matrix(y_true, y_pred)
    spec, sens = [], []

    for i in range(num_classes):
        tp = cm[i,i]
        fn = cm[i,:].sum() - tp
        fp = cm[:,i].sum() - tp
        tn = cm.sum() - (tp+fp+fn)

        spec.append(tn/(tn+fp+1e-8))
        sens.append(tp/(tp+fn+1e-8))

    return float(np.mean(spec)), float(np.mean(sens))


In [None]:
history = {
    "loss": [], "val_loss": [],
    "accuracy": [], "val_accuracy": [],
    "precision": [], "val_precision": [],
    "recall": [], "val_recall": [],
    "f1": [], "val_f1": [],
    "specificity": [], "val_specificity": [],
    "sensitivity": [], "val_sensitivity": [],
    "mcc": [], "val_mcc": [],
    "auc": [], "val_auc": []
}

best_val_loss = float("inf")
counter = 0

for epoch in range(EPOCHS):

    # ================= TRAIN =================
    model.train()
    train_loss = 0
    train_preds, train_targets = [], []
    train_metrics = get_metrics()

    pbar = tqdm(train_loader,
                desc=f"Training Model - Epoch [{epoch+1}/{EPOCHS}]")

    for x,y in pbar:
        x,y = x.to(device), y.to(device)

        optimizer.zero_grad()

        outputs = model(x)

        # Handle tuple (Inception safety)
        if isinstance(outputs, tuple):
            main_out, aux_out = outputs
            loss = criterion(main_out, y) + 0.4 * criterion(aux_out, y)
            out = main_out
        else:
            loss = criterion(outputs, y)
            out = outputs

        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        preds = out.argmax(1)
        train_preds.append(preds.cpu())
        train_targets.append(y.cpu())

        for m in train_metrics.values():
            m.update(out, y)

    train_loss /= len(train_loader)

    train_preds = torch.cat(train_preds).numpy()
    train_targets = torch.cat(train_targets).numpy()

    train_spec, train_sens = multiclass_specificity_sensitivity(
        train_targets, train_preds
    )

    train_out = {k:v.compute().item() for k,v in train_metrics.items()}

    print(f"\n[TRAIN] loss: {train_loss:.4f}, "
          f"accuracy: {train_out['acc']:.4f}, "
          f"precision: {train_out['precision']:.4f}, "
          f"recall: {train_out['recall']:.4f}, "
          f"f1_score: {train_out['f1']:.4f}, "
          f"specificity: {train_spec:.4f}, "
          f"sensitivity: {train_sens:.4f}, "
          f"mcc: {train_out['mcc']:.4f}, "
          f"auc: {train_out['auc']:.4f}")

    # ===== STORE TRAIN HISTORY =====
    history["loss"].append(train_loss)
    history["accuracy"].append(train_out["acc"])
    history["precision"].append(train_out["precision"])
    history["recall"].append(train_out["recall"])
    history["f1"].append(train_out["f1"])
    history["specificity"].append(train_spec)
    history["sensitivity"].append(train_sens)
    history["mcc"].append(train_out["mcc"])
    history["auc"].append(train_out["auc"])

    for m in train_metrics.values():
        m.reset()

    # ================= VALID =================
    model.eval()
    val_loss = 0
    val_preds, val_targets = [], []
    val_metrics = get_metrics()

    pbar = tqdm(val_loader, desc="Validating Model")

    with torch.no_grad():
        for x,y in pbar:
            x,y = x.to(device), y.to(device)

            outputs = model(x)

            if isinstance(outputs, tuple):
                outputs = outputs[0]

            loss = criterion(outputs, y)
            val_loss += loss.item()

            preds = outputs.argmax(1)
            val_preds.append(preds.cpu())
            val_targets.append(y.cpu())

            for m in val_metrics.values():
                m.update(outputs, y)

    val_loss /= len(val_loader)

    val_preds = torch.cat(val_preds).numpy()
    val_targets = torch.cat(val_targets).numpy()

    val_spec, val_sens = multiclass_specificity_sensitivity(
        val_targets, val_preds
    )

    val_out = {k:v.compute().item() for k,v in val_metrics.items()}

    print(f"[VAL]   loss: {val_loss:.4f}, "
          f"accuracy: {val_out['acc']:.4f}, "
          f"precision: {val_out['precision']:.4f}, "
          f"recall: {val_out['recall']:.4f}, "
          f"f1_score: {val_out['f1']:.4f}, "
          f"specificity: {val_spec:.4f}, "
          f"sensitivity: {val_sens:.4f}, "
          f"mcc: {val_out['mcc']:.4f}, "
          f"auc: {val_out['auc']:.4f}")

    print(f"Current LR: {optimizer.param_groups[0]['lr']:.6e}\n")

    # ===== STORE VAL HISTORY =====
    history["val_loss"].append(val_loss)
    history["val_accuracy"].append(val_out["acc"])
    history["val_precision"].append(val_out["precision"])
    history["val_recall"].append(val_out["recall"])
    history["val_f1"].append(val_out["f1"])
    history["val_specificity"].append(val_spec)
    history["val_sensitivity"].append(val_sens)
    history["val_mcc"].append(val_out["mcc"])
    history["val_auc"].append(val_out["auc"])

    scheduler.step()

    for m in val_metrics.values():
        m.reset()

    # ===== EARLY STOPPING =====
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), "best_model.pth")
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered.")
            break


Training Model - Epoch [1/30]: 100%|██████████| 245/245 [09:24<00:00,  2.30s/it]



[TRAIN] loss: 0.7864, accuracy: 0.7701, precision: 0.7705, recall: 0.7701, f1_score: 0.7700, specificity: 0.8925, sensitivity: 0.7701, mcc: 0.6733, auc: 0.9097


Validating Model: 100%|██████████| 62/62 [00:50<00:00,  1.22it/s]


[VAL]   loss: 0.4984, accuracy: 0.8336, precision: 0.8257, recall: 0.8336, f1_score: 0.8292, specificity: 0.9214, sensitivity: 0.8336, mcc: 0.7563, auc: 0.9463
Current LR: 5.000000e-04



Training Model - Epoch [2/30]: 100%|██████████| 245/245 [09:08<00:00,  2.24s/it]



[TRAIN] loss: 0.4808, accuracy: 0.8559, precision: 0.8518, recall: 0.8559, f1_score: 0.8537, specificity: 0.9333, sensitivity: 0.8559, mcc: 0.7936, auc: 0.9602


Validating Model: 100%|██████████| 62/62 [00:51<00:00,  1.20it/s]


[VAL]   loss: 0.3790, accuracy: 0.8529, precision: 0.8605, recall: 0.8529, f1_score: 0.8484, specificity: 0.9235, sensitivity: 0.8529, mcc: 0.7757, auc: 0.9706
Current LR: 4.986305e-04



Training Model - Epoch [3/30]: 100%|██████████| 245/245 [09:08<00:00,  2.24s/it]



[TRAIN] loss: 0.3839, accuracy: 0.8894, precision: 0.8839, recall: 0.8894, f1_score: 0.8865, specificity: 0.9474, sensitivity: 0.8894, mcc: 0.8377, auc: 0.9752


Validating Model: 100%|██████████| 62/62 [00:51<00:00,  1.20it/s]


[VAL]   loss: 0.3001, accuracy: 0.8881, precision: 0.8768, recall: 0.8881, f1_score: 0.8792, specificity: 0.9407, sensitivity: 0.8881, mcc: 0.8192, auc: 0.9735
Current LR: 4.945369e-04



Training Model - Epoch [4/30]: 100%|██████████| 245/245 [09:03<00:00,  2.22s/it]



[TRAIN] loss: 0.2785, accuracy: 0.9193, precision: 0.9122, recall: 0.9193, f1_score: 0.9155, specificity: 0.9608, sensitivity: 0.9193, mcc: 0.8785, auc: 0.9861


Validating Model: 100%|██████████| 62/62 [00:50<00:00,  1.23it/s]


[VAL]   loss: 0.5839, accuracy: 0.8391, precision: 0.8460, recall: 0.8391, f1_score: 0.8182, specificity: 0.9209, sensitivity: 0.8391, mcc: 0.7542, auc: 0.9623
Current LR: 4.877641e-04



Training Model - Epoch [5/30]: 100%|██████████| 245/245 [09:01<00:00,  2.21s/it]



[TRAIN] loss: 0.2528, accuracy: 0.9291, precision: 0.9257, recall: 0.9291, f1_score: 0.9273, specificity: 0.9660, sensitivity: 0.9291, mcc: 0.8955, auc: 0.9892


Validating Model: 100%|██████████| 62/62 [00:50<00:00,  1.23it/s]


[VAL]   loss: 0.2326, accuracy: 0.8889, precision: 0.8983, recall: 0.8889, f1_score: 0.8930, specificity: 0.9470, sensitivity: 0.8889, mcc: 0.8462, auc: 0.9829
Current LR: 4.783864e-04



Training Model - Epoch [6/30]: 100%|██████████| 245/245 [09:04<00:00,  2.22s/it]



[TRAIN] loss: 0.2048, accuracy: 0.9420, precision: 0.9409, recall: 0.9420, f1_score: 0.9414, specificity: 0.9738, sensitivity: 0.9420, mcc: 0.9185, auc: 0.9929


Validating Model: 100%|██████████| 62/62 [00:50<00:00,  1.22it/s]


[VAL]   loss: 0.4176, accuracy: 0.8888, precision: 0.8807, recall: 0.8888, f1_score: 0.8667, specificity: 0.9444, sensitivity: 0.8888, mcc: 0.8211, auc: 0.9860
Current LR: 4.665064e-04



Training Model - Epoch [7/30]: 100%|██████████| 245/245 [08:53<00:00,  2.18s/it]



[TRAIN] loss: 0.1254, accuracy: 0.9656, precision: 0.9661, recall: 0.9656, f1_score: 0.9659, specificity: 0.9842, sensitivity: 0.9656, mcc: 0.9520, auc: 0.9966


Validating Model: 100%|██████████| 62/62 [00:51<00:00,  1.20it/s]


[VAL]   loss: 0.1894, accuracy: 0.9110, precision: 0.9122, recall: 0.9110, f1_score: 0.9116, specificity: 0.9570, sensitivity: 0.9110, mcc: 0.8721, auc: 0.9875
Current LR: 4.522542e-04



Training Model - Epoch [8/30]: 100%|██████████| 245/245 [09:03<00:00,  2.22s/it]



[TRAIN] loss: 0.1334, accuracy: 0.9678, precision: 0.9671, recall: 0.9678, f1_score: 0.9674, specificity: 0.9852, sensitivity: 0.9678, mcc: 0.9544, auc: 0.9966


Validating Model: 100%|██████████| 62/62 [00:50<00:00,  1.22it/s]


[VAL]   loss: 0.4659, accuracy: 0.7891, precision: 0.8072, recall: 0.7891, f1_score: 0.7949, specificity: 0.8978, sensitivity: 0.7891, mcc: 0.6974, auc: 0.9450
Current LR: 4.357862e-04



Training Model - Epoch [9/30]: 100%|██████████| 245/245 [09:02<00:00,  2.21s/it]



[TRAIN] loss: 0.1326, accuracy: 0.9648, precision: 0.9607, recall: 0.9648, f1_score: 0.9627, specificity: 0.9835, sensitivity: 0.9648, mcc: 0.9474, auc: 0.9966


Validating Model: 100%|██████████| 62/62 [00:50<00:00,  1.22it/s]


[VAL]   loss: 0.2311, accuracy: 0.8962, precision: 0.9029, recall: 0.8962, f1_score: 0.8992, specificity: 0.9511, sensitivity: 0.8962, mcc: 0.8558, auc: 0.9830
Current LR: 4.172827e-04



Training Model - Epoch [10/30]:  60%|██████    | 148/245 [05:24<03:32,  2.19s/it]

In [None]:
# Load best model
model.load_state_dict(torch.load("best_model.pth"))
model.eval()


In [1]:
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

all_preds, all_targets = [], []
test_loss = 0

with torch.no_grad():
    for x,y in tqdm(test_loader, desc="Testing"):
        x,y = x.to(device), y.to(device)

        outputs = model(x)

        if isinstance(outputs, tuple):
            outputs = outputs[0]

        loss = criterion(outputs, y)
        test_loss += loss.item()

        preds = outputs.argmax(1)

        all_preds.append(preds.cpu())
        all_targets.append(y.cpu())

test_loss /= len(test_loader)

print("\nTest Loss:", test_loss)


NameError: name 'model' is not defined

In [None]:
all_preds = torch.cat(all_preds).numpy()
all_targets = torch.cat(all_targets).numpy()

print(classification_report(all_targets, all_preds, digits=4))


In [None]:
test_metrics = get_metrics()
test_loss = 0
all_preds, all_targets = [], []

with torch.no_grad():
    for x, y in tqdm(test_loader, desc="Testing"):
        x, y = x.to(device), y.to(device)

        out = model(x)
        loss = criterion(out, y)
        test_loss += loss.item()

        preds = out.argmax(1)

        all_preds.append(preds.cpu())
        all_targets.append(y.cpu())

        for m in test_metrics.values():
            m.update(out, y)

test_loss /= len(test_loader)

print("\n===== TEST RESULTS =====")
print(f"Test Loss: {test_loss:.4f}")

for k, v in test_metrics.items():
    print(f"{k.upper()}: {v.compute().item():.4f}")


In [None]:
all_preds = torch.cat(all_preds).numpy()
all_targets = torch.cat(all_targets).numpy()

print("\n===== CLASSIFICATION REPORT =====\n")
print(classification_report(all_targets, all_preds, digits=4))


In [None]:
cm = confusion_matrix(all_targets, all_preds)

plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()


In [None]:
def plot_metric(train, val, title, ylabel):
    epochs = range(1, len(train) + 1)
    plt.figure(figsize=(6,4))
    plt.plot(epochs, train, label="Training")
    plt.plot(epochs, val, label="Validation")
    plt.title(title)
    plt.xlabel("Epochs")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.show()


In [None]:
plot_metric(history["loss"], history["val_loss"], "Loss vs Epochs", "Loss")
plot_metric(history["accuracy"], history["val_accuracy"], "Accuracy vs Epochs", "Accuracy")
plot_metric(history["precision"], history["val_precision"], "Precision vs Epochs", "Precision")
plot_metric(history["recall"], history["val_recall"], "Recall vs Epochs", "Recall")
plot_metric(history["f1"], history["val_f1"], "F1 Score vs Epochs", "F1-score")
plot_metric(history["specificity"], history["val_specificity"], "Specificity vs Epochs", "Specificity")
plot_metric(history["sensitivity"], history["val_sensitivity"], "Sensitivity vs Epochs", "Sensitivity")
plot_metric(history["mcc"], history["val_mcc"], "MCC vs Epochs", "MCC")
plot_metric(history["auc"], history["val_auc"], "AUC vs Epochs", "AUC")
