# Deep Learning Assignment 2 - Part A
## CNN for Image Classification

This notebook builds a CNN image classifier with PyTorch Lightning, trains it with Weights & Biases (W&B) logging, and then runs and analyzes a W&B hyperparameter sweep.

In [None]:
# Import required libraries
import torch
import torch.nn as nn
import pytorch_lightning as pl
import wandb
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision import datasets, transforms
import os
from pathlib import Path
from model import CNN
import numpy as np
from torchvision.utils import make_grid

# Set random seeds for reproducibility.
# `pl.seed_everything` seeds Python's `random`, NumPy and PyTorch in one call
# (the original only seeded NumPy and PyTorch). `workers=True` additionally makes
# the Lightning Trainer give DataLoader worker processes deterministic seeds.
# The trailing `;` stops Jupyter from echoing the returned seed.
SEED = 42
pl.seed_everything(SEED, workers=True);

## 1. Data Loading and Exploration

Let's first load the training and validation splits (images resized to 128×128) and visualize a few samples.

In [None]:
# Resize every image to 128x128 and convert it to a float tensor. No
# normalization or augmentation is applied here, so the tensors can be shown as-is.
transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])

# One ImageFolder per split, both sharing the same transform.
train_dataset, val_dataset = (
    datasets.ImageFolder(root=split_root, transform=transform)
    for split_root in ('data/train_split', 'data/val_split')
)

# Report split sizes and the class names discovered in the training split.
for split_name, split_dataset in (('Training', train_dataset), ('Validation', val_dataset)):
    print(f"{split_name} samples: {len(split_dataset)}")
print(f"\nClasses: {train_dataset.classes}")

In [None]:
# Visualize some sample images
def show_samples(dataset, num_samples=10, ncols=5, shuffle=True, seed=0):
    """Plot a grid of dataset images, each titled with its class name.

    Args:
        dataset: indexable dataset returning ``(image, label)`` pairs and exposing a
            ``classes`` list (as ``datasets.ImageFolder`` does). Images are assumed
            to be CHW tensors, as produced by ``transforms.ToTensor``.
        num_samples: number of images to show; capped at ``len(dataset)``.
        ncols: maximum number of grid columns; the row count is derived from it.
        shuffle: show a random subset instead of the first ``num_samples`` items.
            ImageFolder builds its sample list class by class, so the first items
            usually all belong to a single class.
        seed: seed for the random subset, so the figure is reproducible.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    if shuffle:
        generator = torch.Generator().manual_seed(seed)
        indices = torch.randperm(len(dataset), generator=generator)[:num_samples].tolist()
    else:
        indices = list(range(num_samples))

    # Size the grid to the number of images (the original was fixed at 2x5).
    ncols = min(ncols, num_samples)
    nrows = -(-num_samples // ncols)  # ceiling division
    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)
    axes = axes.ravel()

    for ax, idx in zip(axes, indices):
        img, label = dataset[idx]
        ax.imshow(img.permute(1, 2, 0))  # CHW -> HWC for matplotlib
        ax.set_title(dataset.classes[label])
    # Hide ticks on every cell, including unused trailing cells of the last row.
    for ax in axes:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

show_samples(train_dataset)

## 2. Model Architecture

Let's instantiate the CNN, print its layer-by-layer architecture, and count its parameters.

In [None]:
# Build the CNN with the baseline hyperparameters (the same values used in the
# training config below).
model_hparams = dict(
    num_conv_layers=5,
    num_filters=32,
    filter_size=3,
    activation='ReLU',
    dense_layer_neurons=64,
    batch_norm=True,
    dropout_rate=0.3,
)
model = CNN(**model_hparams)

# Show the layer-by-layer structure.
print(model)

# Count parameters in a single pass: total, and those with requires_grad=True.
param_stats = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_stats)
trainable_params = sum(count for count, trainable in param_stats if trainable)

print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 3. Training Configuration

Set up the training configuration and initialize W&B logging.

In [None]:
from train import LitCNN, train_model

# Hyperparameters for this run. The dict is attached to the W&B run below and
# unpacked into `train_model(**config)` in the next cell.
config = dict(
    num_conv_layers=5,
    num_filters=32,
    filter_size=3,
    activation='ReLU',
    dense_layer_neurons=64,
    learning_rate=1e-3,
    batch_norm=True,
    dropout_rate=0.3,
    data_augmentation=True,
    batch_size=32,
    max_epochs=10,
)

# Open a W&B run in the assignment project with the config recorded on it.
wandb.init(project='da6401_assignment2_partA', config=config)

In [None]:
# Train the model, then close the W&B run.
# `try/finally` guarantees `wandb.finish()` runs even if training raises or is
# interrupted, so the run opened in the previous cell is never left dangling.
try:
    trained_model = train_model(**config)
finally:
    wandb.finish()

## 4. Training Results Analysis

In [None]:
# Plot the loss/accuracy curves logged to W&B for the training run.
#
# After `wandb.finish()`, `wandb.run` is None, so its path cannot be read here.
# In that case, fall back to the most recently created run in the project.
api = wandb.Api()
if wandb.run is not None:
    run = api.run(wandb.run.path)
else:
    run = api.runs(f"{api.default_entity}/da6401_assignment2_partA", order="-created_at")[0]
history = run.history()

fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))
panels = (
    (loss_ax, 'loss', 'Loss over time', 'Loss'),
    (acc_ax, 'acc', 'Accuracy over time', 'Accuracy'),
)
for ax, metric, title, ylabel in panels:
    for split, label in (('train', 'Train'), ('val', 'Validation')):
        key = f'{split}_{metric}'
        if key not in history.columns:
            continue  # metric not logged under this name; skip instead of KeyError
        # Train and val metrics may be logged on different steps, leaving NaNs on
        # each other's rows. Drop them so each curve is drawn as one connected line.
        series = history[key].dropna()
        steps = history.loc[series.index, '_step'] if '_step' in history.columns else series.index
        ax.plot(steps, series.to_numpy(), label=label)
    ax.set(title=title, xlabel='Step', ylabel=ylabel)
    ax.legend()

plt.tight_layout()
plt.show()

## 5. Model Complexity Analysis

In [None]:
# Per-layer parameter breakdown of the CNN built in Section 2 (`model`), showing
# which layers dominate the model's size.
# Assumes `CNN` is a torch.nn.Module; Section 2 already calls model.parameters().
layer_params = {}
for layer_name, module in model.named_modules():
    # recurse=False counts only parameters owned directly by this module, so
    # container modules are not double-counted together with their children.
    n_params = sum(p.numel() for p in module.parameters(recurse=False))
    if n_params:
        layer_params[layer_name or '(root)'] = n_params

name_width = max((len(name) for name in layer_params), default=0)
print("Parameters per layer:")
for layer_name, n_params in layer_params.items():
    print(f"  {layer_name:<{name_width}}  {n_params:>12,}")
print(f"  {'total':<{name_width}}  {sum(layer_params.values()):>12,}")

fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(list(layer_params), list(layer_params.values()))
ax.invert_yaxis()  # list layers top-to-bottom in the order named_modules() yields them
ax.set(title='Parameters per layer', xlabel='Number of parameters', ylabel='Layer')
plt.tight_layout()
plt.show()

## 6. Hyperparameter Sweep

Now let's run a hyperparameter sweep to find the best model configuration. Weights & Biases (W&B) Sweeps both chooses the configurations to try and tracks each resulting run; the agent below evaluates 20 configurations.

In [None]:
from partA.train import get_sweep_config
import wandb

# Fetch the sweep definition and print it below a heading.
sweep_config = get_sweep_config()
print("Sweep configuration:", sweep_config, sep="\n")

In [None]:
# Register the sweep with W&B and run an agent for a fixed budget of configurations.
# `train_sweep_model` is imported here because no other cell brings it into scope.
# NOTE(review): assumes it lives next to `get_sweep_config` in partA/train.py — confirm.
from partA.train import train_sweep_model

SWEEP_RUN_COUNT = 20  # number of hyperparameter configurations the agent will try

sweep_id = wandb.sweep(sweep_config, project='da6401_assignment2_partA')
wandb.agent(sweep_id, function=train_sweep_model, count=SWEEP_RUN_COUNT)

## 7. Sweep Analysis

Let's analyze the results of our hyperparameter sweep to understand which configurations performed best.

In [None]:
from partA.sweep_analysis import analyze_sweep
import os

# Analyze the finished sweep and write its plots into SWEEP_ANALYSIS_DIR.
# `wandb.run` is None at this point: `wandb.finish()` ran in the training cell,
# and each agent run finishes itself. So the entity comes from the W&B API's
# default entity, the same one `wandb.sweep` uses when none is given.
SWEEP_ANALYSIS_DIR = 'sweep_analysis'
os.makedirs(SWEEP_ANALYSIS_DIR, exist_ok=True)

analyze_sweep(
    entity=wandb.Api().default_entity,
    project='da6401_assignment2_partA',
    sweep_id=sweep_id,
    output_dir=SWEEP_ANALYSIS_DIR,
)

In [None]:
# Display the analysis plots
import matplotlib.pyplot as plt
from IPython.display import Image, display

# (heading, image path) pairs. The PNGs are presumably produced by the
# `analyze_sweep` call above, which writes to 'sweep_analysis'.
analysis_plots = (
    ("Accuracy vs. Runs:", 'sweep_analysis/accuracy_vs_runs.png'),
    ("\nCorrelation Heatmap:", 'sweep_analysis/correlation_heatmap.png'),
    ("\nParallel Coordinates Plot:", 'sweep_analysis/parallel_coordinates.png'),
)
for heading, png_path in analysis_plots:
    print(heading)
    display(Image(filename=png_path))