In [2]:
!unzip -q '/content/drive/MyDrive/split_5-fold.zip' -d '/content/drive/MyDrive/braintumor'

In [4]:
# ===============================
# ✅ SETUP: Install Dependencies
# ===============================
!pip install -q timm albumentations scikit-learn

# ===============================
# ✅ Mount Google Drive
# ===============================
from google.colab import drive
drive.mount('/content/drive')

# ===============================
# ✅ Imports
# ===============================
import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, datasets
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.preprocessing import label_binarize
from sklearn.utils.class_weight import compute_class_weight
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
from timm import create_model
import matplotlib.pyplot as plt
import seaborn as sns

# ===============================
# ✅ Cross-Attention Module
# ===============================
class CrossAttention(nn.Module):
    def __init__(self, dim_q, dim_kv, dim_out):
        super().__init__()
        self.query_proj = nn.Linear(dim_q, dim_out)
        self.key_proj = nn.Linear(dim_kv, dim_out)
        self.value_proj = nn.Linear(dim_kv, dim_out)
        self.scale = dim_out ** -0.5
        self.dropout = nn.Dropout(0.2)

    def forward(self, x1, x2):
        if x2.dim() == 3:
            x2 = x2.mean(dim=1)
        q = self.query_proj(x1).unsqueeze(1)
        k = self.key_proj(x2).unsqueeze(1)
        v = self.value_proj(x2).unsqueeze(1)
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        return x1 + self.dropout(torch.matmul(attn, v).squeeze(1))

# ===============================
# ✅ Feature Extractors
# ===============================
def create_resnet():
    resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    for param in list(resnet.parameters())[:6]:
        param.requires_grad = False
    return nn.Sequential(*list(resnet.children())[:-1], nn.Flatten())

def create_swin():
    swin = create_model('swin_tiny_patch4_window7_224', pretrained=True)
    swin.head = nn.Identity()
    class SwinWrapper(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model
        def forward(self, x):
            return self.model.forward_features(x).mean(dim=1)
    return SwinWrapper(swin)

# ===============================
# ✅ Main Hybrid Model
# ===============================
class TumorClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = create_resnet()
        self.swin = create_swin()
        self.norm_r = nn.LayerNorm(512)
        self.norm_s = nn.LayerNorm(768)
        self.cross_attn = CrossAttention(512, 768, 512)
        self.classifier = nn.Sequential(
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 3)
        )

    def forward(self, x):
        res_feat = self.norm_r(self.resnet(x))
        swin_feat = self.norm_s(self.swin(x))
        fused = self.cross_attn(res_feat, swin_feat)
        return self.classifier(fused)

# ===============================
# ✅ Albumentations Dataset
# ===============================
class AlbumentationsDataset(Dataset):
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform
    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        img = np.array(img)
        img = self.transform(image=img)['image']
        return img, label
    def __len__(self):
        return len(self.dataset)

def get_transform(train=True):
    return A.Compose([
        A.Resize(224, 224),
        A.HorizontalFlip(p=0.5) if train else A.NoOp(),
        A.RandomBrightnessContrast(p=0.2) if train else A.NoOp(),
        A.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
        ToTensorV2()
    ])

# ===============================
# ✅ Train and Evaluate (5-Fold + EarlyStopping)
# ===============================
def train_and_evaluate(base_path, device, epochs=40, batch_size=16):
    class_names = ['1', '2', '3']
    patience = 10

    for fold in range(1, 6):
        print(f"\n==== Fold {fold} ====")
        train_dir = os.path.join(base_path, f"fold_{fold}", "train")
        test_dir = os.path.join(base_path, f"fold_{fold}", "test")

        train_data = datasets.ImageFolder(train_dir)
        test_data = datasets.ImageFolder(test_dir)

        train_loader = DataLoader(
            AlbumentationsDataset(train_data, get_transform(train=True)),
            batch_size=batch_size, shuffle=True
        )
        test_loader = DataLoader(
            AlbumentationsDataset(test_data, get_transform(train=False)),
            batch_size=batch_size, shuffle=False
        )

        model = TumorClassifier().to(device)

        labels = [s[1] for s in train_data.samples]
        class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(labels), y=labels)
        criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float).to(device))

        optimizer = optim.Adam([
            {'params': model.resnet.parameters(), 'lr': 1e-5},
            {'params': model.swin.parameters(), 'lr': 1e-5},
            {'params': model.cross_attn.parameters(), 'lr': 1e-4},
            {'params': model.classifier.parameters(), 'lr': 1e-4},
        ])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

        best_acc = 0.0
        best_model_state = None
        counter = 0

        for epoch in range(epochs):
            model.train()
            total_loss, correct, total = 0, 0, 0
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                out = model(x)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()
                total_loss += loss.item() * x.size(0)
                correct += (out.argmax(1) == y).sum().item()
                total += y.size(0)

            acc = correct / total
            print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss/total:.4f} | Accuracy: {acc:.4f}")
            scheduler.step()

            if acc > best_acc:
                best_acc = acc
                best_model_state = copy.deepcopy(model.state_dict())
                counter = 0
            else:
                counter += 1
                print(f"Patience counter: {counter}/10")
                if counter >= patience:
                    print("Early stopping triggered.")
                    break

        # Save best model
        torch.save(best_model_state, f"/content/drive/MyDrive/best_model_fold_{fold}.pt")

        # === Evaluation
        model.load_state_dict(torch.load(f"/content/drive/MyDrive/best_model_fold_{fold}.pt"))
        model.eval()
        y_true, y_pred, y_score = [], [], []

        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                out = model(x)
                probs = nn.functional.softmax(out, dim=1)
                y_true.extend(y.cpu().numpy())
                y_pred.extend(probs.argmax(1).cpu().numpy())
                y_score.extend(probs.cpu().numpy())

        print("\nClassification Report:")
        print(classification_report(y_true, y_pred, digits=4, target_names=class_names))

        cm = confusion_matrix(y_true, y_pred, normalize='true')
        plt.figure(figsize=(6, 5))
        sns.heatmap(cm, annot=True, cmap='Blues', fmt='.2f',
                    xticklabels=class_names, yticklabels=class_names)
        plt.title(f'Confusion Matrix - Fold {fold}')
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.tight_layout()
        plt.savefig(f"/content/drive/MyDrive/conf_matrix_fold_{fold}.png")
        plt.close()

        y_bin = label_binarize(y_true, classes=[0, 1, 2])
        y_score = np.array(y_score)
        plt.figure(figsize=(6, 5))
        for i in range(3):
            fpr, tpr, _ = roc_curve(y_bin[:, i], y_score[:, i])
            auc_score = auc(fpr, tpr)
            print(f"AUC Class {i}: {auc_score:.4f}")
            plt.plot(fpr, tpr, label=f"Class {i} (AUC = {auc_score:.2f})")
        plt.plot([0, 1], [0, 1], 'k--')
        plt.title(f"ROC Curve - Fold {fold}")
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"/content/drive/MyDrive/roc_curve_fold_{fold}.png")
        plt.close()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_path = "/content/drive/MyDrive/braintumor/split_5-fold"
    train_and_evaluate(base_path, device)


==== Fold 1 ====


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:02<00:00, 22.4MB/s]
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

Epoch 1/40 | Loss: 0.5531 | Accuracy: 0.7870
Epoch 2/40 | Loss: 0.2623 | Accuracy: 0.9068
Epoch 3/40 | Loss: 0.1819 | Accuracy: 0.9337
Epoch 4/40 | Loss: 0.1412 | Accuracy: 0.9497
Epoch 5/40 | Loss: 0.0957 | Accuracy: 0.9699
Epoch 6/40 | Loss: 0.0768 | Accuracy: 0.9757
Epoch 7/40 | Loss: 0.0545 | Accuracy: 0.9860
Epoch 8/40 | Loss: 0.0474 | Accuracy: 0.9883
Epoch 9/40 | Loss: 0.0431 | Accuracy: 0.9886
Epoch 10/40 | Loss: 0.0438 | Accuracy: 0.9921
Epoch 11/40 | Loss: 0.0612 | Accuracy: 0.9801
Patience counter: 1/10
Epoch 12/40 | Loss: 0.0459 | Accuracy: 0.9842
Patience counter: 2/10
Epoch 13/40 | Loss: 0.0302 | Accuracy: 0.9918
Patience counter: 3/10
Epoch 14/40 | Loss: 0.0237 | Accuracy: 0.9944
Epoch 15/40 | Loss: 0.0152 | Accuracy: 0.9971
Epoch 16/40 | Loss: 0.0130 | Accuracy: 0.9965
Patience counter: 1/10
Epoch 17/40 | Loss: 0.0168 | Accuracy: 0.9944
Patience counter: 2/10
Epoch 18/40 | Loss: 0.0074 | Accuracy: 0.9988
Epoch 19/40 | Loss: 0.0049 | Accuracy: 0.9988
Patience counter: 1/