In [None]:
import warnings
# Suppress tqdm warning about IProgress (ipywidgets) - progress bars will still work
warnings.filterwarnings('ignore', category=UserWarning, module='tqdm')

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

import medmnist
from medmnist import INFO, Evaluator

# Note: transformers library is imported in the model setup cell
# to keep imports organized by functionality
# 
# Note: If you want enhanced progress bars in Jupyter, install ipywidgets:
#   pip install ipywidgets
#   jupyter nbextension enable --py widgetsnbextension  # for classic notebook
#   or for JupyterLab: jupyter labextension install @jupyter-widgets/jupyterlab-manager


In [None]:
# Preprocessing for ImageNet-pretrained backbones: they expect 3-channel input
# normalised with the ImageNet statistics, while TissueMNIST is single-channel.
def _ensure_three_channels(image_tensor):
    """Repeat a 1-channel tensor three times along dim 0; return anything else unchanged."""
    if image_tensor.shape[0] == 1:
        return image_tensor.repeat(3, 1, 1)
    return image_tensor


data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_ensure_three_channels),  # 1 channel -> 3 channels
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])

# Dataset selection
data_flag = 'tissuemnist'
# data_flag = 'breastmnist'
download = False

# Dataset metadata looked up from medmnist's INFO table
info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

# Training hyperparameters - shared by the CNN and the Swin Transformer
NUM_EPOCHS = 3
BATCH_SIZE = 128
lr = 0.001  # Learning rate for transfer learning
USE_PRETRAINED = True  # Use ImageNet pretrained weights for both models

# Model saving configuration
SAVE_MODELS = True  # Set to True to save models during and after training
MODEL_SAVE_DIR = 'saved_models'  # Directory to save models
SAVE_BEST_ONLY = True  # Save only the best model (based on test accuracy) during training

In [None]:
DataClass = getattr(medmnist, info['python_class'])

import os

# Locate the project root. Only treat the working directory as the `notebooks`
# folder when its own name is exactly that; a substring test on the full path
# would also match unrelated ancestors such as `/home/me/notebooks_backup/...`.
current_dir = os.getcwd()
if os.path.basename(current_dir) == 'notebooks':
    project_root = os.path.dirname(current_dir)
else:
    # Running from the project root
    project_root = current_dir

# Candidate dataset directories, in priority order:
#   1. MEDMNIST_DATA_DIR environment variable (explicit per-machine override)
#   2. <project_root>/mnist_dataset
#   3. legacy machine-specific location (kept so existing setups keep working;
#      prefer the environment variable instead of editing this path)
default_dataset_dir = os.path.join(project_root, 'mnist_dataset')
dataset_dir_candidates = [
    os.environ.get('MEDMNIST_DATA_DIR'),
    default_dataset_dir,
    '/Users/shreyasavant/Desktop/comp6721/project_git_speed/project/mnist_dataset',
]
custom_path = next(
    (candidate for candidate in dataset_dir_candidates if candidate and os.path.exists(candidate)),
    None,
)
if custom_path is None:
    # Nothing exists: fall back to the project-relative location so the message
    # points at a path that makes sense on this machine.
    custom_path = default_dataset_dir
    print(f"Warning: Dataset path not found. Using: {custom_path}")
    print("Please set the MEDMNIST_DATA_DIR environment variable or update the custom_path variable if needed.")

# load the data (both splits share every option except `split`)
dataset_options = dict(transform=data_transform, download=download, root=custom_path, size=224, mmap_mode='r')
train_dataset = DataClass(split='train', **dataset_options)
test_dataset = DataClass(split='test', **dataset_options)

# encapsulate data into dataloader form; only the training loader shuffles
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [None]:
# Sanity check: fetch the first training sample (after `data_transform` has been
# applied by the dataset) and print the shapes of its two components.
x, y = train_dataset[0]

print(x.shape, y.shape)

In [None]:
# Visual preview of the training split via the dataset's own `montage` helper.
train_dataset.montage(length=3)

In [None]:
# MODEL SETUP: CNN (ResNet18) and Swin Transformer
# Both models start from ImageNet pretrained weights and share one training recipe.

from torchvision.models import resnet18, ResNet18_Weights
from transformers import SwinModel

# Prefer CUDA when it is available, otherwise run on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("\n".join([
    f"Using device: {device}",
    f"Number of channels: {n_channels}",
    f"Number of classes: {n_classes}",
    f"Using pretrained weights: {USE_PRETRAINED}",
]))

# CNN Model: ResNet18, optionally initialised from the IMAGENET1K_V1 weights

if USE_PRETRAINED:
    # New torchvision API - pretrained weights are selected through a Weights enum
    weights = ResNet18_Weights.IMAGENET1K_V1
    cnn_model = resnet18(weights=weights)
else:
    cnn_model = resnet18(weights=None)

# Swap the ImageNet classification head for one sized to this dataset,
# then move the complete model to the target device.
num_features = cnn_model.fc.in_features
cnn_model.fc = nn.Linear(num_features, n_classes)
cnn_model = cnn_model.to(device)

if USE_PRETRAINED:
    print("âœ“ CNN: Loaded ResNet18 with ImageNet pretrained weights (IMAGENET1K_V1)")
    print(f"  Replaced final layer: {num_features} -> {n_classes} classes")
else:
    print("âœ“ CNN: ResNet18 initialized from scratch")

print(f"  CNN first layer expects: {cnn_model.conv1.in_channels} channels")


In [None]:
# MODEL SAVING UTILITIES
import os
from datetime import datetime

def setup_model_save_dir(base_dir=MODEL_SAVE_DIR):
    """Return `base_dir`; when SAVE_MODELS is enabled, create it first and report its absolute path."""
    if not SAVE_MODELS:
        return base_dir

    os.makedirs(base_dir, exist_ok=True)
    absolute_dir = os.path.abspath(base_dir)
    print(f"âœ“ Model save directory: {absolute_dir}")
    return base_dir

def save_model_checkpoint(model, model_name, epoch, test_accuracy, save_dir, is_best=False):
    """Write a training checkpoint for `model` into `save_dir`.

    The checkpoint is always written to ``<name>_latest.pt``; when `is_best` is
    true it is additionally written to ``<name>_best.pt``. ``<name>`` is
    `model_name` lower-cased with spaces turned into underscores and
    parentheses removed.

    Args:
        model: Module whose ``state_dict()`` is stored.
        model_name: Human-readable name; stored in the checkpoint and used to
            derive the file name.
        epoch: Epoch number recorded in the checkpoint.
        test_accuracy: Accuracy (in percent) recorded in the checkpoint.
        save_dir: Target directory; created if it does not exist yet.
        is_best: Whether to also write the ``_best`` file.

    Returns:
        None when SAVE_MODELS is disabled, otherwise the path of the best
        checkpoint (if `is_best`) or of the latest checkpoint.
    """
    if not SAVE_MODELS:
        return None

    # torch.save does not create missing directories, so make sure the target
    # exists even when the caller never ran setup_model_save_dir().
    os.makedirs(save_dir, exist_ok=True)

    # Create model name without spaces for filename
    model_name_clean = model_name.replace(' ', '_').replace('(', '').replace(')', '').lower()

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'test_accuracy': test_accuracy,
        'model_name': model_name
    }

    # Always save latest checkpoint
    latest_path = os.path.join(save_dir, f'{model_name_clean}_latest.pt')
    torch.save(checkpoint, latest_path)

    # Save best checkpoint if this is the best model
    if is_best:
        best_path = os.path.join(save_dir, f'{model_name_clean}_best.pt')
        torch.save(checkpoint, best_path)
        print(f"  ðŸ’¾ Saved best model checkpoint: {best_path} (Test Acc: {test_accuracy:.2f}%)")
        return best_path

    return latest_path

def load_model_checkpoint(model, checkpoint_path, device):
    """Load the weights stored in `checkpoint_path` into `model` and return the checkpoint dict.

    Works for every checkpoint flavour this notebook writes: training
    checkpoints (which carry 'epoch' and 'test_accuracy') as well as the final
    and quick-save files (which have no 'epoch' and may store
    'final_test_accuracy' instead). Missing metadata is reported as 'N/A'.

    Args:
        model: Module that receives the stored ``model_state_dict``.
        checkpoint_path: Path of the ``.pt`` file to read.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The loaded checkpoint dictionary.

    Raises:
        KeyError: If the file contains no 'model_state_dict' entry.
    """
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"âœ“ Loaded model from {checkpoint_path}")

    epoch = checkpoint.get('epoch', 'N/A')
    accuracy = checkpoint.get('test_accuracy', checkpoint.get('final_test_accuracy', 'N/A'))
    print(f"  Epoch: {epoch}, Test Accuracy: {accuracy}")
    return checkpoint


In [None]:
# Swin Transformer Model: Swin Transformer with ImageNet pretrained weights

class SwinTransformerClassifier(nn.Module):
    """Hugging Face `SwinModel` backbone followed by a dropout + linear classification head."""

    def __init__(self, num_classes, model_name="microsoft/swin-base-patch4-window7-224"):
        super().__init__()
        self.swin = self._build_backbone(model_name)
        feature_dim = self.swin.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(feature_dim, num_classes),
        )

    @staticmethod
    def _build_backbone(model_name):
        """Create the backbone: pretrained weights when USE_PRETRAINED, else config-only init."""
        if not USE_PRETRAINED:
            # Only the architecture config is fetched; the weights stay randomly initialised
            from transformers import SwinConfig
            backbone = SwinModel(SwinConfig.from_pretrained(model_name))
            print(f"âœ“ Swin Transformer: Initialized {model_name} from scratch")
            return backbone

        backbone = SwinModel.from_pretrained(model_name)
        print(f"âœ“ Swin Transformer: Loaded {model_name} with ImageNet pretrained weights")
        return backbone

    def forward(self, x):
        """Return class logits for the image batch `x`."""
        backbone_out = self.swin(pixel_values=x)
        # Use the backbone's pooled output when it provides one; otherwise
        # average the last hidden state over dim 1.
        features = getattr(backbone_out, 'pooler_output', None)
        if features is None:
            features = backbone_out.last_hidden_state.mean(dim=1)
        return self.classifier(features)

swin_model = SwinTransformerClassifier(num_classes=n_classes).to(device)
print(f"  Swin Transformer hidden size: {swin_model.swin.config.hidden_size}")

# Loss function - shared by both models and chosen from the dataset's task type.
# train_model() feeds float multi-hot targets and thresholds sigmoid outputs for
# 'multi-label, binary-class' datasets, which requires a per-label binary loss;
# every other task uses integer class targets with cross-entropy.
if task == 'multi-label, binary-class':
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.CrossEntropyLoss()

# Use same optimizer settings for both models (SGD with momentum)
# This ensures fair comparison between CNN and Swin Transformer
cnn_optimizer = optim.SGD(cnn_model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
swin_optimizer = optim.SGD(swin_model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)

print(f"\nâœ“ Both models use SGD optimizer with lr={lr}, momentum=0.9, weight_decay=1e-4")

In [None]:
# UNIFIED TRAINING FUNCTION - Same approach for both CNN and Swin Transformer

def _run_epoch(model, loader, device, task, criterion, desc, optimizer=None):
    """Run one full pass over `loader` and return ``(mean_batch_loss, accuracy_percent)``.

    With an `optimizer` the pass trains the model (one optimizer step per
    batch); without one it evaluates the model with gradients disabled.
    """
    is_training = optimizer is not None
    model.train(is_training)  # train(False) is the same as eval()

    loss_sum = 0.0
    n_batches = 0
    n_correct = 0
    n_samples = 0

    with torch.set_grad_enabled(is_training):
        for inputs, targets in tqdm(loader, desc=desc):
            inputs = inputs.to(device)
            targets = targets.to(device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(inputs)

            if task == 'multi-label, binary-class':
                targets = targets.to(torch.float32)
                loss = criterion(outputs, targets)
                pred = (torch.sigmoid(outputs) > 0.5).int()
                # A sample is correct only when every one of its labels matches
                n_correct += (pred == targets.int()).all(dim=1).sum().item()
            else:
                # reshape(-1) rather than squeeze(): squeeze() turns a batch of
                # one sample into a 0-dim tensor, which breaks both the loss
                # call and targets.size(0) below.
                targets = targets.reshape(-1).long()
                loss = criterion(outputs, targets)
                _, pred = torch.max(outputs, 1)
                n_correct += (pred == targets).sum().item()

            if is_training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            n_batches += 1
            n_samples += targets.size(0)

    if n_samples == 0:
        raise ValueError(f"{desc}: the data loader produced no samples")

    return loss_sum / n_batches, 100.0 * n_correct / n_samples


def train_model(model, optimizer, model_name, train_loader, test_loader, num_epochs, device, task, criterion,
                checkpoint_dir=None):
    """
    Unified training function that works for both CNN and Swin Transformer models.
    Uses the same training approach for fair comparison.

    Each epoch runs one training pass over `train_loader` followed by one
    evaluation pass over `test_loader`. When the notebook-level SAVE_MODELS
    flag is on, checkpoints are written through save_model_checkpoint(): on
    every epoch if SAVE_BEST_ONLY is off, otherwise only on epochs that improve
    the best test accuracy seen so far.

    NOTE(review): the "best" epoch is selected on `test_loader`, i.e. the same
    data that is reported as test performance - consider a validation split.

    Args:
        model: Module to train; expected to already live on `device`.
        optimizer: Optimizer built over `model`'s parameters.
        model_name: Label used in progress output and checkpoint file names.
        train_loader / test_loader: Iterables yielding ``(inputs, targets)`` batches.
        num_epochs: Number of epochs to run.
        device: Device that each batch is moved to.
        task: Task string; 'multi-label, binary-class' selects the multi-label path.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        checkpoint_dir: Optional checkpoint directory; defaults to the one
            returned by setup_model_save_dir().

    Returns:
        dict with per-epoch lists 'train_losses', 'train_accuracies',
        'test_losses', 'test_accuracies', plus 'best_test_accuracy' and
        'best_epoch' (1-based; both None when `num_epochs` is 0).
    """
    train_losses = []
    train_accuracies = []
    test_losses = []
    test_accuracies = []
    best_test_accuracy = None
    best_epoch = None

    if SAVE_MODELS and checkpoint_dir is None:
        checkpoint_dir = setup_model_save_dir()

    print(f"\n{'='*60}")
    print(f"Training {model_name}")
    print(f"{'='*60}")

    for epoch in range(num_epochs):
        epoch_tag = f"{model_name} Epoch {epoch+1}/{num_epochs}"

        avg_train_loss, train_accuracy = _run_epoch(
            model, train_loader, device, task, criterion, f"{epoch_tag} [Train]", optimizer=optimizer
        )
        train_losses.append(avg_train_loss)
        train_accuracies.append(train_accuracy)

        avg_test_loss, test_accuracy = _run_epoch(
            model, test_loader, device, task, criterion, f"{epoch_tag} [Test]"
        )
        test_losses.append(avg_test_loss)
        test_accuracies.append(test_accuracy)

        # Print epoch summary
        print(f'\n{model_name} - Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}%')
        print(f'  Test Loss: {avg_test_loss:.4f}, Test Acc: {test_accuracy:.2f}%')
        print('-' * 50)

        # Track the best epoch and checkpoint according to the save settings
        is_best = best_test_accuracy is None or test_accuracy > best_test_accuracy
        if is_best:
            best_test_accuracy = test_accuracy
            best_epoch = epoch + 1
        if SAVE_MODELS and (is_best or not SAVE_BEST_ONLY):
            save_model_checkpoint(model, model_name, epoch + 1, test_accuracy, checkpoint_dir, is_best=is_best)

    return {
        'train_losses': train_losses,
        'train_accuracies': train_accuracies,
        'test_losses': test_losses,
        'test_accuracies': test_accuracies,
        'best_test_accuracy': best_test_accuracy,
        'best_epoch': best_epoch
    }


# TRAIN CNN MODEL

# Auto-detect the device if no earlier cell has defined one in this namespace
if 'device' not in globals():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

cnn_history = train_model(
    cnn_model,
    cnn_optimizer,
    "CNN (ResNet18)",
    train_loader,
    test_loader,
    num_epochs=NUM_EPOCHS,
    device=device,
    task=task,
    criterion=criterion,
)

In [None]:
# SAVE FINAL MODELS (Optional)
#
# Saves the final weights plus training history of every model that has been
# trained so far. A model whose history variable does not exist yet (for example
# Swin Transformer when this cell runs before its training cell) is skipped with
# a notice instead of raising NameError; re-run the cell once it is trained.

if SAVE_MODELS:
    model_save_dir = setup_model_save_dir()
    print(f"\n{'='*60}")
    print("Saving Final Models")
    print(f"{'='*60}")

    final_hyperparameters = {
        'num_epochs': NUM_EPOCHS,
        'batch_size': BATCH_SIZE,
        'learning_rate': lr,
        'use_pretrained': USE_PRETRAINED
    }

    # (label in messages, stored model_name, file stem, model variable, history variable)
    final_model_specs = [
        ('CNN', 'CNN (ResNet18)', 'cnn_resnet18', 'cnn_model', 'cnn_history'),
        ('Swin Transformer', 'Swin Transformer', 'swin_transformer', 'swin_model', 'swin_history'),
    ]
    notebook_vars = globals()

    for final_label, stored_name, file_stem, model_var, history_var in final_model_specs:
        final_model = notebook_vars.get(model_var)
        final_history = notebook_vars.get(history_var)
        if final_model is None or not final_history or not final_history.get('test_accuracies'):
            print(f"  Skipping {final_label}: `{history_var}` is not available yet - "
                  "train the model, then re-run this cell.")
            continue

        epoch_test_accuracies = final_history['test_accuracies']
        best_accuracy = max(epoch_test_accuracies)
        final_path = os.path.join(model_save_dir, f'{file_stem}_final.pt')
        torch.save({
            'model_state_dict': final_model.state_dict(),
            'model_name': stored_name,
            'num_classes': n_classes,
            # The stored weights are those of the last epoch, so report that epoch's accuracy here
            'final_test_accuracy': epoch_test_accuracies[-1],
            'best_test_accuracy': final_history.get('best_test_accuracy', best_accuracy),
            'best_epoch': final_history.get('best_epoch', epoch_test_accuracies.index(best_accuracy) + 1),
            'training_history': final_history,
            'hyperparameters': final_hyperparameters
        }, final_path)
        print(f"âœ“ Saved final {final_label} model: {final_path}")

    print(f"\n{'='*60}")
    print(f"All models saved to: {os.path.abspath(model_save_dir)}")
    print(f"{'='*60}")

    # Only list files that really exist on disk
    print("\nSaved files:")
    saved_file_kinds = (
        ('best', 'best checkpoint during training'),
        ('latest', 'latest checkpoint'),
        ('final', 'final model with full history'),
    )
    for final_label, _, file_stem, _, _ in final_model_specs:
        print(f"  {final_label}:")
        present_files = [
            (f'{file_stem}_{kind}.pt', note)
            for kind, note in saved_file_kinds
            if os.path.exists(os.path.join(model_save_dir, f'{file_stem}_{kind}.pt'))
        ]
        if not present_files:
            print("    (no files saved yet)")
        for file_name, note in present_files:
            print(f"    - {file_name} ({note})")

    print(f"\n{'='*60}")
    print("To load a saved model later, use:")
    print(f"  checkpoint = torch.load('{os.path.join(model_save_dir, 'cnn_resnet18_final.pt')}', map_location=device)")
    print("  model.load_state_dict(checkpoint['model_state_dict'])")
    print(f"{'='*60}")


In [None]:
# QUICK SAVE: Save any model currently in memory
# Use this cell to save a model that's already trained and in memory

def quick_save_model(model, model_name, save_path=None, additional_info=None):
    """
    Quickly save a model that's currently in memory.

    Args:
        model: The PyTorch model to save
        model_name: Name/description of the model
        save_path: Full path to save the model (optional). Defaults to
            ``<MODEL_SAVE_DIR>/<cleaned model name>_manual_save.pt``, where
            MODEL_SAVE_DIR falls back to 'saved_models' if it is not defined.
            Missing parent directories are created.
        additional_info: Dictionary with any additional info to save (optional).
            Its entries are merged into the checkpoint and override the default
            keys on a name clash.

    Returns:
        The path the checkpoint was written to.
    """
    import os

    if save_path is None:
        default_dir = globals().get('MODEL_SAVE_DIR', 'saved_models')
        model_name_clean = model_name.replace(' ', '_').replace('(', '').replace(')', '').lower()
        save_path = os.path.join(default_dir, f'{model_name_clean}_manual_save.pt')

    # torch.save does not create directories, so make sure the parent exists for
    # both the default location and a caller-supplied path.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Prepare checkpoint
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'model_name': model_name,
        'num_classes': globals().get('n_classes'),  # None when the config cell has not run
    }

    # Add any additional info
    if additional_info:
        checkpoint.update(additional_info)

    # Save the model
    torch.save(checkpoint, save_path)
    print(f"âœ“ Model saved successfully!")
    print(f"  Path: {os.path.abspath(save_path)}")
    print(f"  Model: {model_name}")

    return save_path

# Example usage (uncomment and modify as needed):
# 
# # Save CNN model:
# quick_save_model(
#     model=cnn_model,
#     model_name="CNN (ResNet18) - Manual Save",
#     additional_info={'test_accuracy': 85.5}  # optional
# )
#
# # Save Swin Transformer model:
# quick_save_model(
#     model=swin_model,
#     model_name="Swin Transformer - Manual Save",
#     additional_info={'test_accuracy': 87.2}  # optional
# )
#
# # Or save to a specific path:
# quick_save_model(
#     model=cnn_model,
#     model_name="My Custom Model",
#     save_path='my_custom_path/model.pt'
# )

In [None]:
# TRAIN SWIN TRANSFORMER MODEL
# Same loaders, epoch budget, loss and device as the CNN run; only the model and
# its optimizer differ.

swin_history = train_model(
    swin_model,
    swin_optimizer,
    "Swin Transformer",
    train_loader,
    test_loader,
    num_epochs=NUM_EPOCHS,
    device=device,
    task=task,
    criterion=criterion,
)


In [None]:
# EVALUATE BOTH MODELS USING MEDMNIST EVALUATOR

def evaluate_model_medmnist(model, model_name, data_loader, split, device, task):
    """Score `model` on `data_loader` and return the metrics computed by medmnist's Evaluator.

    Only the prediction scores are handed to the Evaluator, which scores them
    against the labels it loads itself for (`data_flag`, `split`).
    `data_loader` must therefore iterate the complete split in its stored
    order, i.e. without shuffling.

    NOTE(review): the Evaluator is built without an explicit `root`, so it
    presumably reads labels from medmnist's default data directory rather than
    the `custom_path` used for the datasets - confirm both hold the same files.

    Args:
        model: Trained module, already on `device`.
        model_name: Label used in the printed report.
        data_loader: Unshuffled loader over the split being evaluated.
        split: Split name passed to the Evaluator (e.g. 'test').
        device: Device that input batches are moved to.
        task: Task string; 'multi-label, binary-class' uses per-label sigmoid
            scores, everything else uses softmax over classes.

    Returns:
        The Evaluator's metrics object; index 0 is printed as AUC and index 1
        as accuracy.

    Raises:
        ValueError: If `data_loader` yields no batches.
    """
    model.eval()
    score_batches = []

    with torch.no_grad():
        for inputs, _ in data_loader:
            logits = model(inputs.to(device))
            if task == 'multi-label, binary-class':
                # Independent per-label probabilities, matching the sigmoid
                # thresholding used in train_model()
                scores = torch.sigmoid(logits)
            else:
                scores = logits.softmax(dim=-1)
            score_batches.append(scores.cpu())

    if not score_batches:
        raise ValueError(f"{model_name}: the data loader produced no batches to evaluate")

    # Concatenate once at the end instead of re-allocating on every batch
    y_score = torch.cat(score_batches, 0).numpy()

    evaluator = Evaluator(data_flag, split, size=224)
    # evaluate() takes only the scores; a second positional argument lands in
    # the save-folder position rather than being treated as labels.
    metrics = evaluator.evaluate(y_score)

    print(f'\n{model_name} - {split.upper()} Results:')
    print(f'  AUC: {metrics[0]:.3f}, Accuracy: {metrics[1]:.3f}')
    return metrics

split = 'test'

# evaluate_model_medmnist() passes only the scores to the Evaluator, so the
# loader has to walk the chosen split in stored order. `test_loader` is already
# unshuffled; `train_loader` shuffles, so the train split gets its own
# unshuffled loader here.
if split == 'test':
    data_loader = test_loader
elif split == 'train':
    data_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=False)
else:
    raise ValueError(f"No dataset loaded for split {split!r}; expected 'train' or 'test'")

# Evaluate CNN
cnn_metrics = evaluate_model_medmnist(cnn_model, "CNN (ResNet18)", data_loader, split, device, task)

# Evaluate Swin Transformer
swin_metrics = evaluate_model_medmnist(swin_model, "Swin Transformer", data_loader, split, device, task)


In [None]:
# PLOT COMPARISON: CNN vs Swin Transformer Training Curves
#
# The epoch axis of every curve is derived from the length of the recorded
# series, so the plots stay valid even if NUM_EPOCHS was edited after training.

import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# (legend prefix, line colour in the comparison panels, recorded history)
compared_models = [
    ('CNN', 'b', cnn_history),
    ('Swin', 'r', swin_history),
]

# Top row: both models overlaid - losses on the left, accuracies on the right
comparison_panels = [
    (axes[0, 0], 'losses', 'Loss', 'Training and Test Loss Comparison'),
    (axes[0, 1], 'accuracies', 'Accuracy (%)', 'Training and Test Accuracy Comparison'),
]
for panel, metric_key, y_label, panel_title in comparison_panels:
    for legend_prefix, colour, model_history in compared_models:
        for phase_key, phase_label, line_style in (('train', 'Train', '-'), ('test', 'Test', '--')):
            series = model_history[f'{phase_key}_{metric_key}']
            panel.plot(range(1, len(series) + 1), series, f'{colour}{line_style}',
                       label=f'{legend_prefix} {phase_label}', linewidth=2)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True)

# Bottom row: train vs test loss for each model on its own
per_model_panels = [
    (axes[1, 0], 'CNN (ResNet18) - Loss', cnn_history),
    (axes[1, 1], 'Swin Transformer - Loss', swin_history),
]
for panel, panel_title, model_history in per_model_panels:
    for series_key, series_label, line_format in (('train_losses', 'Train Loss', 'b-'),
                                                  ('test_losses', 'Test Loss', 'r-')):
        series = model_history[series_key]
        panel.plot(range(1, len(series) + 1), series, line_format, label=series_label, linewidth=2)
    panel.set_xlabel('Epoch')
    panel.set_ylabel('Loss')
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True)

plt.tight_layout()
plt.show()

# FINAL RESULTS SUMMARY
#
# The separator is built once: reusing the outer quote character inside an
# f-string replacement field (f'{'='*60}') is a SyntaxError before Python 3.12.

summary_rule = '=' * 60

print(f'\n{summary_rule}')
print('FINAL RESULTS SUMMARY')
print(summary_rule)

summary_rows = (
    ('CNN (ResNet18)', cnn_history, cnn_metrics),
    ('Swin Transformer', swin_history, swin_metrics),
)
for summary_label, summary_history, summary_metrics in summary_rows:
    print(f'\n{summary_label} Results:')
    print(f'  Best Train Accuracy: {max(summary_history["train_accuracies"]):.2f}%')
    print(f'  Best Test Accuracy: {max(summary_history["test_accuracies"]):.2f}%')
    print(f'  Final Train Loss: {summary_history["train_losses"][-1]:.4f}')
    print(f'  Final Test Loss: {summary_history["test_losses"][-1]:.4f}')
    print(f'  MedMNIST Test AUC: {summary_metrics[0]:.3f}, Acc: {summary_metrics[1]:.3f}')

print(f'\n{summary_rule}')
print('TRAINING CONFIGURATION:')
print(f'  Both models use ImageNet pretrained weights: {USE_PRETRAINED}')
print(f'  Optimizer: SGD with lr={lr}, momentum=0.9, weight_decay=1e-4')
print(f'  Batch size: {BATCH_SIZE}')
print(f'  Epochs: {NUM_EPOCHS}')
# Report the loss that was actually used instead of a hard-coded name
print(f'  Loss function: {type(criterion).__name__}')
print(summary_rule)

In [None]:
# PLOT TRAINING AND TESTING ACCURACY CURVES
#
# One panel per model. The epoch axis comes from the recorded series length so
# it always matches the data, independent of the current NUM_EPOCHS value.

import matplotlib.pyplot as plt

# Create a figure with two subplots - one for each model
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

accuracy_panels = [
    (axes[0], 'CNN (ResNet18)', cnn_history),
    (axes[1], 'Swin Transformer', swin_history),
]
for panel, panel_label, model_history in accuracy_panels:
    train_series = model_history['train_accuracies']
    test_series = model_history['test_accuracies']
    epoch_numbers = range(1, len(train_series) + 1)

    panel.plot(epoch_numbers, train_series, 'b-o',
               label='Training Accuracy', linewidth=2, markersize=8)
    panel.plot(range(1, len(test_series) + 1), test_series, 'r-s',
               label='Testing Accuracy', linewidth=2, markersize=8)
    panel.set_xlabel('Epoch', fontsize=12)
    panel.set_ylabel('Accuracy (%)', fontsize=12)
    panel.set_title(f'{panel_label} - Training and Testing Accuracy', fontsize=14, fontweight='bold')
    panel.legend(fontsize=11)
    panel.grid(True, alpha=0.3)
    panel.set_xticks(epoch_numbers)

plt.tight_layout()
plt.show()

# Print accuracy summary: final-epoch and best accuracy for each model
accuracy_rule = "=" * 60

print(f'\n{accuracy_rule}')
print('ACCURACY SUMMARY')
print(accuracy_rule)

for accuracy_label, accuracy_history in (('CNN (ResNet18)', cnn_history),
                                         ('Swin Transformer', swin_history)):
    train_series = accuracy_history["train_accuracies"]
    test_series = accuracy_history["test_accuracies"]
    print(f'\n{accuracy_label}:')
    print(f'  Training Accuracy: {train_series[-1]:.2f}% (Final), {max(train_series):.2f}% (Best)')
    print(f'  Testing Accuracy:  {test_series[-1]:.2f}% (Final), {max(test_series):.2f}% (Best)')

print(accuracy_rule)