In [None]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

import medmnist
from medmnist import INFO, Evaluator


In [None]:
# ---- Run configuration -------------------------------------------------
# Dataset selector (swap in 'breastmnist' to try the other dataset).
data_flag = 'tissuemnist'
download = False

# Hyper-parameters
NUM_EPOCHS = 3
BATCH_SIZE = 128
lr = 0.001

# Dataset metadata published by medmnist for the selected flag.
info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])


# ---- Input pipeline ----------------------------------------------------
def _ensure_three_channels(img):
    """Tile a single-channel image tensor to three channels.

    Tensors whose leading dimension is not 1 are returned unchanged.
    ResNet's first convolution expects 3 input channels.
    """
    if img.shape[0] == 1:
        return img.repeat(3, 1, 1)
    return img


data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_ensure_three_channels),  # 1 channel -> 3 channels
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # one entry per channel
])

In [None]:
DataClass = getattr(medmnist, info['python_class'])

import os

# ---- Locate the dataset directory ---------------------------------------
# Project root is the parent of the working directory only when the working
# directory itself is named 'notebooks'. Comparing the last path component
# (instead of a substring test on the whole path) avoids picking the wrong
# parent when some ancestor directory merely contains the word 'notebooks'.
current_dir = os.getcwd()
if os.path.basename(os.path.normpath(current_dir)) == 'notebooks':
    project_root = os.path.dirname(current_dir)
else:
    # Running from the project root (or anywhere else)
    project_root = current_dir

project_dataset_path = os.path.join(project_root, 'mnist_dataset')

# Candidate locations, in priority order:
#   1. MEDMNIST_ROOT environment variable (explicit per-machine override)
#   2. <project_root>/mnist_dataset
#   3. legacy absolute path, kept so the original setup keeps working
candidate_paths = [
    os.environ.get('MEDMNIST_ROOT'),
    project_dataset_path,
    '/Users/shreyasavant/Desktop/comp6721/project_git_speed/project/mnist_dataset',
]
custom_path = next((p for p in candidate_paths if p and os.path.isdir(p)), None)

if custom_path is None:
    if download:
        # Nothing on disk yet: create the project-relative directory so the
        # download has somewhere to land.
        custom_path = project_dataset_path
        os.makedirs(custom_path, exist_ok=True)
    else:
        # Fail here with an actionable message rather than letting the
        # dataset constructor fail on a non-existent directory.
        tried = ', '.join(repr(p) for p in candidate_paths if p)
        raise FileNotFoundError(
            f"Dataset path not found (tried: {tried}). "
            "Set the MEDMNIST_ROOT environment variable, place the data in "
            f"{project_dataset_path!r}, or set download = True."
        )

# load the data
train_dataset = DataClass(split='train', transform=data_transform, download=download, root=custom_path, size=224, mmap_mode='r')
test_dataset = DataClass(split='test', transform=data_transform, download=download, root=custom_path, size=224, mmap_mode='r')

# encapsulate data into dataloader form
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# Pull the first training sample and show the shapes of the image and label.
x, y = train_dataset[0]
print(*(item.shape for item in (x, y)))

In [None]:
# Visual sanity check: ask the dataset for a montage with length=3.
# Left as the bare final expression so the notebook displays the returned object.
train_dataset.montage(length=3)

In [None]:
from torchvision.models import resnet18

# Check for CUDA availability and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Number of channels: {n_channels}")

# Create model - ResNet18 expects 3 channels (RGB)
# Data transform converts grayscale (1 channel) to RGB (3 channels)
model = resnet18(num_classes=n_classes).to(device)

# Verify input shape matches model expectations
print(f"Model first layer expects: {model.conv1.in_channels} channels")

# Pick the loss that matches the task. The training loop treats
# 'multi-label, binary-class' outputs as independent per-label logits
# (sigmoid + float targets), which is what BCEWithLogitsLoss expects.
# Every other task is handled as single-label classification.
if task == 'multi-label, binary-class':
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

In [None]:
# train
# Make sure device is defined (from previous cell)
if 'device' not in locals():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")


def _score_batch(outputs, targets):
    """Compute loss and number of correct predictions for one batch.

    Reads the notebook-level `task` and `criterion`.

    Returns:
        (loss, n_correct, batch_size) where `loss` is the criterion's tensor
        output, `n_correct` is a Python int and `batch_size` is the number of
        samples in the batch.
    """
    if task == 'multi-label, binary-class':
        targets = targets.to(torch.float32)
        loss = criterion(outputs, targets)
        # A sample counts as correct only if every label matches.
        pred = (torch.sigmoid(outputs) > 0.5).int()
        n_correct = (pred == targets.int()).all(dim=1).sum().item()
    else:
        # Flatten to 1-D with reshape(-1) rather than squeeze(): squeeze()
        # would turn a final batch of size 1 into a 0-dim tensor and break
        # both the loss call and the size(0) lookup below.
        targets = targets.reshape(-1).long()
        loss = criterion(outputs, targets)
        pred = outputs.argmax(dim=1)
        n_correct = (pred == targets).sum().item()
    return loss, n_correct, targets.size(0)


def _run_epoch(loader, desc, train):
    """Run one full pass over `loader`.

    Args:
        loader: DataLoader yielding (inputs, targets) batches.
        desc: label shown on the tqdm progress bar.
        train: when True, gradients are computed and the optimizer is stepped;
            when False, the model is put in eval mode and gradients are off.

    Returns:
        (mean per-sample loss, accuracy in percent).
    """
    model.train(train)
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.set_grad_enabled(train):
        for inputs, targets in tqdm(loader, desc=desc):
            # Move inputs and targets to device (CPU or GPU)
            inputs = inputs.to(device)
            targets = targets.to(device)

            if train:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss, batch_correct, batch_size = _score_batch(outputs, targets)

            if train:
                loss.backward()
                optimizer.step()

            # Weight each batch by its size so a short final batch does not
            # count as much as a full one in the epoch average.
            loss_sum += loss.item() * batch_size
            n_correct += batch_correct
            n_seen += batch_size

    return loss_sum / n_seen, 100.0 * n_correct / n_seen


# Training history
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []

for epoch in range(NUM_EPOCHS):
    avg_train_loss, train_accuracy = _run_epoch(
        train_loader, f"Epoch {epoch+1}/{NUM_EPOCHS} [Train]", train=True
    )
    train_losses.append(avg_train_loss)
    train_accuracies.append(train_accuracy)

    # Evaluate on test set
    avg_test_loss, test_accuracy = _run_epoch(
        test_loader, f"Epoch {epoch+1}/{NUM_EPOCHS} [Test]", train=False
    )
    test_losses.append(avg_test_loss)
    test_accuracies.append(test_accuracy)

    # Print epoch summary
    print(f'\nEpoch {epoch+1}/{NUM_EPOCHS}:')
    print(f'  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}%')
    print(f'  Test Loss: {avg_test_loss:.4f}, Test Acc: {test_accuracy:.2f}%')
    print('-' * 50)

In [None]:
split = 'test'

# Make sure device is defined
if 'device' not in locals():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The Evaluator compares the scores against labels it loads itself, matched
# by position, so the scores must be produced in dataset order. train_loader
# shuffles, so the train split gets its own non-shuffled loader here.
if split == 'test':
    data_loader = test_loader
elif split == 'train':
    data_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=False)
else:
    raise ValueError(f"Unsupported split {split!r}; expected 'train' or 'test'")

is_multi_label = task == 'multi-label, binary-class'

# Collect per-batch results in lists and concatenate once at the end;
# concatenating onto a growing tensor every batch re-copies all prior data.
score_chunks = []
label_chunks = []

model.eval()
with torch.no_grad():
    for inputs, targets in data_loader:
        inputs = inputs.to(device)  # Use .to(device) instead of .cuda() for portability

        outputs = model(inputs)
        # Independent per-label probabilities for multi-label tasks,
        # a distribution over classes otherwise.
        if is_multi_label:
            scores = torch.sigmoid(outputs)
        else:
            scores = outputs.softmax(dim=-1)
        score_chunks.append(scores.cpu())

        # Single-label targets are flattened to 1-D with reshape(-1);
        # squeeze() would collapse a batch of size 1 to a 0-dim tensor that
        # cannot be concatenated. Multi-label targets keep their 2-D shape.
        label_chunks.append(targets if is_multi_label else targets.reshape(-1))

y_score = torch.cat(score_chunks, 0).numpy()
# y_true is not needed by the Evaluator; it is kept (as float32) for any
# later inspection of the predictions.
y_true = torch.cat(label_chunks, 0).to(torch.float32).numpy()

# Point the Evaluator at the same directory the datasets were loaded from,
# otherwise it looks for the data file under medmnist's default root.
evaluator = Evaluator(data_flag, split, size=224, root=custom_path)
# evaluate() takes only the scores; its second positional parameter is a
# save folder, not the ground-truth labels.
metrics = evaluator.evaluate(y_score)

print('%s  auc: %.3f  acc: %.3f' % (split, *metrics))


In [None]:
# Plot training curves
import matplotlib.pyplot as plt

if not train_losses or not test_losses:
    raise RuntimeError("No training history recorded - run the training cell first.")


def _epoch_axis(series):
    """1-based epoch numbers matching the length of a recorded history list."""
    return range(1, len(series) + 1)


# The x-axis is derived from what was actually recorded, not from NUM_EPOCHS,
# so the plot still works if training was interrupted part-way or NUM_EPOCHS
# was changed after the training cell ran.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(_epoch_axis(train_losses), train_losses, 'b-', label='Train Loss')
loss_ax.plot(_epoch_axis(test_losses), test_losses, 'r-', label='Test Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Test Loss')
loss_ax.legend()
loss_ax.grid(True)

acc_ax.plot(_epoch_axis(train_accuracies), train_accuracies, 'b-', label='Train Accuracy')
acc_ax.plot(_epoch_axis(test_accuracies), test_accuracies, 'r-', label='Test Accuracy')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.set_title('Training and Test Accuracy')
acc_ax.legend()
acc_ax.grid(True)

fig.tight_layout()
plt.show()

# Print final results
print('\nFinal Results:')
print(f'  Best Train Accuracy: {max(train_accuracies):.2f}%')
print(f'  Best Test Accuracy: {max(test_accuracies):.2f}%')
print(f'  Final Train Loss: {train_losses[-1]:.4f}')
print(f'  Final Test Loss: {test_losses[-1]:.4f}')
