In [None]:
!pip install kagglehub timm --quiet

In [None]:
import kagglehub
from pathlib import Path

# Kaggle dataset handle for the brain tumor MRI images.
DATASET_HANDLE = "masoudnickparvar/brain-tumor-mri-dataset"

path = Path(kagglehub.dataset_download(DATASET_HANDLE))
print("Dataset path:", path)
print("Top-level folders:", sorted(entry.name for entry in path.iterdir()))

# Later cells read images from path/"Training" and path/"Testing"; fail fast with a
# clear message if the downloaded layout is different.
missing_dirs = [name for name in ("Training", "Testing") if not (path / name).is_dir()]
if missing_dirs:
    raise FileNotFoundError(f"Expected sub-folders missing from {path}: {missing_dirs}")

In [None]:
from PIL import Image
import matplotlib.pyplot as plt

# Load one sample image to verify the dataset is readable.
sample_image_path = path / "Training" / "glioma" / "Tr-gl_0010.jpg"
with Image.open(sample_image_path) as img:
    img.load()  # read the pixels now so the file handle can be closed on exit

fig, ax = plt.subplots()
# A single-channel image would otherwise be drawn with matplotlib's default colour
# map; `cmap` is ignored for RGB images.
ax.imshow(img, cmap="gray")
ax.set_title("Sample Image")
ax.axis("off")
plt.show()

# Informational only: ImageFolder's default loader converts images to RGB anyway.
print("Image mode:", img.mode)

In [None]:
import os

import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Subset, random_split

# ImageNet statistics: the ViT backbone is pretrained on ImageNet, so inputs are
# normalized the same way before fine-tuning on the MRI dataset.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
SEED = 42
BATCH_SIZE = 16
# Cap DataLoader workers at the number of available CPUs (Colab often has only 2).
NUM_WORKERS = min(4, os.cpu_count() or 1)

# Training-time preprocessing: random augmentations help the model generalize and
# reduce overfitting. Only the training split should see these.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed input size; ViTs like vit_tiny_patch16_224 typically use 224x224
    transforms.RandomHorizontalFlip(),  # random left-right flips
    transforms.RandomRotation(15),  # random rotation of up to 15 degrees
    transforms.RandomAffine(degrees=10, translate=(0.05, 0.05)),  # further small rotation plus a shift of up to 5%
    transforms.ColorJitter(brightness=0.1, contrast=0.1),  # random brightness/contrast tweaks
    transforms.ToTensor(),  # PIL image -> float tensor
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation-time preprocessing: deterministic, so validation/test metrics are not
# computed on randomly distorted images and are reproducible from run to run.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_dir = path / "Training"
test_dir = path / "Testing"

# Two views of the same training folder: augmented (for the train split) and
# deterministic (for the validation split). ImageFolder lists files in sorted order,
# so both views share the same indices.
train_dataset = ImageFolder(train_dir, transform=transform)
train_eval_dataset = ImageFolder(train_dir, transform=eval_transform)
test_dataset = ImageFolder(test_dir, transform=eval_transform)
assert train_dataset.samples == train_eval_dataset.samples, "training views disagree"
assert train_dataset.classes == test_dataset.classes, "train/test class folders differ"

# Split training set into 90% train, 10% validation with a fixed seed so the split is reproducible.
train_size = int(0.9 * len(train_dataset))
validation_size = len(train_dataset) - train_size
split_generator = torch.Generator().manual_seed(SEED)
shuffled_indices = torch.randperm(len(train_dataset), generator=split_generator).tolist()
train_subset = Subset(train_dataset, shuffled_indices[:train_size])
validation_subset = Subset(train_eval_dataset, shuffled_indices[train_size:])

# Loaders feed the model mini-batches; pinned memory only helps when copying to a GPU.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=torch.cuda.is_available())
train_loader = DataLoader(train_subset, shuffle=True, **loader_kwargs)
validation_loader = DataLoader(validation_subset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
!pip install pytorch-lightning --quiet


In [None]:
"""Please note this is based on official pytorch lightning guidelines for building Lightning modules:
 https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
and adapted using the timm library's pretrained ViT ("vit_tiny_patch16_224")
https://github.com/rwightman/pytorch-image-models
label smoothing and logging methods follow best practices outlined in the Lightning docs and timm examples.
"""

import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F

class ViTLightningModel(pl.LightningModule):
    """Lightning wrapper around timm's pretrained ``vit_tiny_patch16_224`` classifier.

    Args:
        num_classes: Number of output classes for the classification head created
            by ``timm.create_model``.
        lr: Learning rate for the AdamW optimizer. Stored in ``self.hparams`` by
            ``save_hyperparameters`` (and therefore in Lightning checkpoints).

    Logged metrics: ``train_loss``/``train_acc``, ``val_loss``/``val_acc`` and
    ``test_loss``/``test_acc``.
    """

    def __init__(self, num_classes=4, lr=3e-4):
        super().__init__()
        self.save_hyperparameters()
        self.model = timm.create_model("vit_tiny_patch16_224", pretrained=True, num_classes=num_classes)
        # Label smoothing mixes the one-hot targets with a uniform distribution,
        # discouraging over-confident logits.
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    def forward(self, x):
        """Return the raw class logits produced by the timm model."""
        return self.model(x)

    def _shared_step(self, batch, stage, prog_bar):
        """Compute loss and accuracy for one ``(x, y)`` batch and log them as
        ``{stage}_loss`` / ``{stage}_acc``. Returns the loss tensor."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        self.log(f"{stage}_loss", loss, prog_bar=prog_bar)
        self.log(f"{stage}_acc", acc, prog_bar=prog_bar)
        return loss

    def training_step(self, batch, batch_idx):
        # The returned loss is what Lightning backpropagates.
        return self._shared_step(batch, "train", prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val", prog_bar=True)

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test", prog_bar=False)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

In [None]:
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

# Seed Python/NumPy/torch (and DataLoader workers) so the randomly initialised
# classification head, batch order and augmentations are reproducible.
pl.seed_everything(42, workers=True)

model = ViTLightningModel(num_classes=4, lr=3e-4)

# Stop once val_loss has not improved for 3 consecutive epochs.
early_stop = EarlyStopping(monitor="val_loss", patience=3, mode="min")
# Keep the checkpoint with the lowest val_loss. EarlyStopping on its own leaves the
# model holding the weights of the final epoch, which is `patience` epochs past the best.
best_checkpoint = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)
logger = CSVLogger("lightning_logs_v2", name="vit_model")

trainer = pl.Trainer(
    max_epochs=30,
    accelerator="auto",
    devices="auto",
    # fp16 mixed precision is a CUDA feature (Lightning rejects "16-mixed" on CPU);
    # fall back to full precision when no GPU is present.
    precision="16-mixed" if torch.cuda.is_available() else "32-true",
    callbacks=[early_stop, best_checkpoint],
    logger=logger,
)

In [None]:
trainer.fit(model, train_loader, validation_loader)

# Reload the checkpoint tracked by the trainer's (first) ModelCheckpoint callback so
# the evaluation cells below use those weights instead of whatever the last epoch
# left in memory. With a ModelCheckpoint monitoring val_loss this is the best epoch;
# with Lightning's default checkpoint callback it is the most recent checkpoint.
checkpoint_cb = trainer.checkpoint_callback
best_model_path = checkpoint_cb.best_model_path if checkpoint_cb is not None else ""
if best_model_path:
    print("Loading weights from:", best_model_path)
    model = ViTLightningModel.load_from_checkpoint(best_model_path)

def evaluate_model(model, test_loader, class_names=None, device=None):
    """Run ``model`` over ``test_loader`` and collect true labels and argmax predictions.

    Args:
        model: A ``torch.nn.Module`` returning class logits; the class axis is
            assumed to be dim 1 (it is reduced with ``argmax(dim=1)``).
        test_loader: Iterable of ``(inputs, labels)`` batches, e.g. a ``DataLoader``.
        class_names: Not used in the computation; accepted so existing calls that
            pass it keep working.
        device: Device to run inference on. ``None`` (default) uses ``"cuda"`` when a
            GPU is available and ``"cpu"`` otherwise, so the function also works on
            CPU-only machines.

    Returns:
        ``(all_labels, all_preds)``: two lists of Python ints, in loader order.

    Side effects: ``model`` is switched to eval mode and moved to ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)
    all_preds, all_labels = [], []

    with torch.no_grad():
        for x, y in test_loader:
            logits = model(x.to(device))
            preds = logits.argmax(dim=1)
            all_preds.extend(preds.cpu().tolist())
            # Labels are only needed for bookkeeping, so they never go to the device.
            all_labels.extend(y.cpu().tolist())
    return all_labels, all_preds

In [None]:
#evaluate on test set

# ImageFolder assigns label indices in sorted folder-name order, so derive the display
# names from the dataset itself. A hard-coded list can silently disagree with that
# order (the previous one put "Pituitary" before "No Tumor", while the folder "notumor"
# sorts before "pituitary"), which mislabels every per-class plot below.
CLASS_DISPLAY_NAMES = {
    "glioma": "Glioma",
    "meningioma": "Meningioma",
    "notumor": "No Tumor",
    "pituitary": "Pituitary",
}
class_names = [CLASS_DISPLAY_NAMES.get(folder, folder) for folder in test_dataset.classes]
y_true, y_pred = evaluate_model(model, test_loader, class_names)

In [None]:
from sklearn.metrics import accuracy_score

# Overall fraction of correct predictions from the manual evaluation pass.
acc = accuracy_score(y_true=y_true, y_pred=y_pred)
print(f"Test Accuracy: {acc:.4f}")

# Lightning's test loop returns one metrics dict per test dataloader; only one is
# passed here, so its metrics are the first (and only) entry.
test_results = trainer.test(model=model, dataloaders=test_loader, verbose=False)
single_loader_metrics = test_results[0]
test_loss = single_loader_metrics['test_loss']
print(f"Test Loss: {test_loss:.4f}")

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# Pin the label set so the matrix is always len(class_names) x len(class_names) and
# row/column i always corresponds to class_names[i]. Without `labels`, a class that
# appears in neither y_true nor y_pred would be dropped and the tick labels would shift.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Greens",
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
ax.set_title("Confusion Matrix – ViT Model")
plt.show()

In [None]:
from sklearn.metrics import precision_recall_fscore_support
import matplotlib.pyplot as plt
import numpy as np

# Per-class metrics. Passing `labels` yields exactly one entry per class name, in
# label-index order; without it a class absent from both y_true and y_pred would be
# dropped, and the arrays would no longer line up with class_names.
label_ids = list(range(len(class_names)))
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, labels=label_ids, zero_division=0
)

x = np.arange(len(class_names))
width = 0.25

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(x - width, precision, width, label='Precision', color='blue')
ax.bar(x, recall, width, label='Recall', color='green')
ax.bar(x + width, f1, width, label='F1-Score', color='red')

ax.set_xticks(x)
ax.set_xticklabels(class_names)
# Zoom into the top of the scale for readability, but lower the floor when a score
# falls below it; with a fixed 0.8 floor such a bar would be clipped out of view.
lowest_score = min(precision.min(), recall.min(), f1.min())
ax.set_ylim(max(0.0, min(0.8, lowest_score - 0.05)), 1.0)
ax.set_xlabel("Class")
ax.set_ylabel("Score")
ax.set_title("Precision, Recall, F1-Score by Class")
ax.legend()
ax.grid(True, axis="y")
fig.tight_layout()
plt.show()

In [None]:
!git clone https://github.com/richapatel93/brain-tumor-classification.git
%cd brain-tumor-classification

In [None]:
%cd /content/brain-tumor-classification

!git config --global user.email "sabel.matt@gmail.com"
!git config --global user.name "matthew-sabel"

In [None]:
!git add brain_tumor_vit_organized.ipynb
!git commit -m "Add cleaned ViT notebook with final plots and evaluation"
!git push https://[REDACTED_TOKEN]IcTRQ14ikNXqY4eyQHj32L0rnzUiZK4POClN@github.com/richapatel93/brain-tumor-classification.git main

In [None]:
!git rm brain_tumor_vit_final.ipynb
!git commit -m "Remove outdated ViT final notebook"
!git push https://[REDACTED_TOKEN]IcTRQ14ikNXqY4eyQHj32L0rnzUiZK4POClN@github.com/richapatel93/brain-tumor-classification.git


In [None]:
import nbformat

# Path to the notebook file to clean (adjust if needed).
# NOTE(review): this path is under /content, not inside the clone at
# /content/brain-tumor-classification, and this cell runs after the commit/push cells.
# Unless the file is copied into the repo separately, the committed copy is not the one
# cleaned here. Point this at the file you `git add`, and run the cell before committing.
notebook_path = "/content/brain_tumor_vit_organized.ipynb"

with open(notebook_path, 'r', encoding='utf-8') as f:
    nb = nbformat.read(f, as_version=nbformat.NO_CONVERT)

# Jupyter saves ipywidgets state under the notebook-level `metadata.widgets` key; an
# incomplete entry there is a frequent cause of GitHub's "Invalid Notebook" rendering
# error. The previous check only looked at a nested `metadata.metadata.widgets`, which
# is not where that state is stored, so remove both.
nb.metadata.pop('widgets', None)
if 'widgets' in nb.metadata.get('metadata', {}):
    del nb.metadata['metadata']['widgets']

# Remove all cell outputs and execution counts (only code cells carry these keys).
for cell in nb.cells:
    if 'outputs' in cell:
        cell['outputs'] = []
    if 'execution_count' in cell:
        cell['execution_count'] = None

with open(notebook_path, 'w', encoding='utf-8') as f:
    nbformat.write(nb, f)

print("Notebook cleaned and saved successfully.")