# 🚀 ThermoSight: Training Pipeline Experiment

This notebook implements and monitors the complete training pipeline for ThermoSight:
- **Configuration Setup**: Define all hyperparameters and paths.
- **Advanced Data Loading**: Includes augmentations and efficient batching.
- **Model Architecture**: Define and initialize the Vision Transformer (ViT).
- **Training Loop**: With detailed progress, loss, and learning rate tracking.
- **Evaluation**: Calculate accuracy, confusion matrix, and per-class metrics.
- **TensorBoard Logging**: For real-time monitoring of metrics and visualizations.
- **Model Checkpointing**: Save the best performing model.
- **Results Visualization**: Plot training curves.

---

In [None]:
# Imports and Global Setup
import os
import sys
import time
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm.auto import tqdm
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import confusion_matrix as sk_confusion_matrix

# Add src to path for custom modules
sys.path.append(os.path.abspath(os.path.join('..'))) # Go up one level to ThermoSight root
from src.models.vit_model import ViT
from src.utils.visualize import plot_metrics as tb_plot_metrics # Renamed to avoid conflict
from src.utils.visualize import plot_confusion_matrix as tb_plot_confusion_matrix # Renamed
from src.evaluate.metrics import per_class_accuracy

# Notebook specific settings
# Global figure styling applied to every plot rendered later in the notebook.
# NOTE(review): the 'seaborn-v0_8-*' style names only exist in newer matplotlib
# releases — confirm the installed version provides this style.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("deep")
# IPython magics: this cell only runs inside Jupyter/IPython, not as plain Python.
%matplotlib inline
# Load the TensorBoard notebook extension (enables the %tensorboard magic).
%load_ext tensorboard

# Report the runtime environment so the training log records what it ran on.
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only the first visible GPU (index 0) is reported.
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")

In [None]:
# Configuration Class
class TrainingConfig:
    """Hyperparameters, paths and device selection for one ViT training run.

    All settings are plain instance attributes so the notebook can dump the
    whole configuration through ``config.__dict__`` (console printout and the
    TensorBoard hparams table). Paths are relative to the current working
    directory, whose parent (``'..'``) is treated as the project root.

    Constructing an instance has side effects: it stamps ``experiment_name``
    with the current local time and creates ``log_dir`` and
    ``model_save_dir`` if they do not exist yet.

    Raises:
        ValueError: if the model settings are inconsistent (see ``_validate``).
    """

    def __init__(self):
        # Data paths
        self.base_dir = os.path.join('..') # Project root
        self.data_dir = os.path.join(self.base_dir, 'data', 'processed')
        self.train_dir = os.path.join(self.data_dir, 'train')
        self.test_dir = os.path.join(self.data_dir, 'test')

        # Model parameters
        # img_size must be a multiple of patch_size so the patches tile the
        # image exactly: 448 / 8 = 56 patches per side (3136 patch tokens).
        self.img_size = 448
        self.patch_size = 8
        self.in_channels = 3
        self.num_classes = 4 # Adjust if your dataset has different number of classes
        self.embed_dim = 768
        self.depth = 12
        self.num_heads = 12
        self.mlp_ratio = 4.0
        self.dropout = 0.1 # Dropout for training

        # Training hyperparameters
        self.batch_size = 16 # Adjust based on GPU memory
        self.learning_rate = 3e-5 # ViTs often benefit from smaller LRs
        self.weight_decay = 0.05
        self.num_epochs = 15 # Increase for full training, keep low for demo
        self.warmup_epochs = 2
        self.label_smoothing = 0.1

        # Logging and saving
        # The timestamp makes every run's log dir and checkpoint name unique.
        self.experiment_name = f"ViT_exp_{time.strftime('%Y%m%d_%H%M%S')}"
        self.log_dir = os.path.join(self.base_dir, 'outputs', 'logs', self.experiment_name)
        self.model_save_dir = os.path.join(self.base_dir, 'models')
        self.best_model_name = f"{self.experiment_name}_best_model.pth"

        # Device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Fail fast on impossible settings before touching the filesystem.
        self._validate()

        # Ensure directories exist
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.model_save_dir, exist_ok=True)

    def _validate(self):
        """Raise ``ValueError`` if the model settings cannot work together."""
        if self.img_size % self.patch_size != 0:
            raise ValueError(
                f"img_size ({self.img_size}) must be divisible by "
                f"patch_size ({self.patch_size})"
            )
        # Standard multi-head attention splits embed_dim evenly across heads.
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"embed_dim ({self.embed_dim}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )

config = TrainingConfig()

# Echo every setting so the notebook output documents the run.
print("Configuration Loaded:")
for name, setting in vars(config).items():
    print(f"  {name}: {setting}")

# Warn (without stopping) when either data split is missing on disk.
data_dirs_present = all(os.path.exists(path) for path in (config.train_dir, config.test_dir))
if not data_dirs_present:
    print(f"🚨 WARNING: Data directories not found at {config.train_dir} or {config.test_dir}")
    print("💡 Please run the data preparation script (e.g., src/data/make_dataset.py) first.")

In [None]:
# Data Transforms and Loaders
def get_data_transforms(img_size, *, mean=None, std=None, resize_margin=32):
    """Build the training and evaluation image preprocessing pipelines.

    Args:
        img_size: Side length, in pixels, of the square images produced by
            both pipelines.
        mean: Per-channel normalization means. Defaults to the ImageNet
            statistics ``[0.485, 0.456, 0.406]``.
        std: Per-channel normalization standard deviations. Defaults to the
            ImageNet statistics ``[0.229, 0.224, 0.225]``.
        resize_margin: Pixels added to ``img_size`` when resizing evaluation
            images before the center crop.

    Returns:
        A ``(train_transforms, test_transforms)`` tuple of
        ``transforms.Compose`` pipelines that output normalized tensors.
    """
    # Normalization parameters (ImageNet statistics unless overridden).
    if mean is None:
        mean = [0.485, 0.456, 0.406]
    if std is None:
        std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        # NOTE(review): if the inputs are false-colour thermal renderings, hue /
        # saturation jitter changes the colour-to-value mapping — confirm intended.
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        normalize
    ])

    test_transforms = transforms.Compose([
        # Resize the shorter side to slightly more than img_size, then take a
        # deterministic center crop of exactly img_size.
        transforms.Resize(img_size + resize_margin),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize
    ])
    return train_transforms, test_transforms

train_tf, test_tf = get_data_transforms(config.img_size)

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only runs.
use_pin_memory = config.device.type == 'cuda'

# Create datasets
try:
    train_dataset = datasets.ImageFolder(config.train_dir, transform=train_tf)
    test_dataset = datasets.ImageFolder(config.test_dir, transform=test_tf)
except FileNotFoundError as e:
    print(f"❌ Error: Data directory not found. {e}")
    print("💡 Make sure 'data/processed/train' and 'data/processed/test' exist and contain class subfolders.")
    train_loader, test_loader = None, None # Set to None to prevent further execution
else:
    # ImageFolder derives label indices from each directory's own sub-folders,
    # so a class folder missing from one split would silently shift the labels
    # of the other. Refuse to continue if the mappings differ.
    if test_dataset.class_to_idx != train_dataset.class_to_idx:
        print(f"❌ Error: Train/test class folders differ: "
              f"{train_dataset.classes} vs {test_dataset.classes}")
        train_loader, test_loader = None, None # Set to None to prevent further execution
    else:
        # Create DataLoaders
        train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
                                  num_workers=4, pin_memory=use_pin_memory)
        test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,
                                 num_workers=4, pin_memory=use_pin_memory)

        print(f"🖼️ Training samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")
        print(f"🏷️ Classes: {train_dataset.classes}")
        config.num_classes = len(train_dataset.classes) # Update num_classes from dataset
        print(f"Updated num_classes to: {config.num_classes}")

In [None]:
# Model, Loss, Optimizer, Scheduler
if train_loader and test_loader: # Proceed only if data loaded
    model = ViT(
        img_size=config.img_size,
        patch_size=config.patch_size,
        in_channels=config.in_channels,
        num_classes=config.num_classes,
        embed_dim=config.embed_dim,
        depth=config.depth,
        num_heads=config.num_heads,
        mlp_ratio=config.mlp_ratio,
        dropout=config.dropout
    ).to(config.device)

    print(f"🤖 Model Initialized: {model.__class__.__name__}")
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"   Trainable Parameters: {num_params/1e6:.2f}M")

    criterion = nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)
    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

    # Scheduler: OneCycleLR, stepped once per batch, so total_steps counts batches.
    total_steps = len(train_loader) * config.num_epochs
    # Fraction of the schedule spent warming up. Warm-up is capped at
    # num_epochs - 1 epochs so pct_start stays strictly below 1: OneCycleLR
    # rejects pct_start > 1, and a warm-up spanning the whole run leaves a
    # zero-length annealing phase that can fail on the final scheduler step.
    # For the default 2 / 15 the value is unchanged.
    if config.num_epochs > 0:
        effective_warmup = max(min(config.warmup_epochs, config.num_epochs - 1), 0)
        warmup_pct = effective_warmup / config.num_epochs
    else:
        warmup_pct = 0.1
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.learning_rate,
        total_steps=total_steps,
        pct_start=warmup_pct
    )

    # TensorBoard Writer
    writer = SummaryWriter(log_dir=config.log_dir)
    print(f"📊 TensorBoard logs will be saved to: {config.log_dir}")
    # Log hyperparameters. Only primitive-typed settings are kept (e.g. the
    # torch.device is skipped); the training cell logs them again with metrics.
    writer.add_hparams(
        {k: v for k, v in config.__dict__.items() if isinstance(v, (int, float, str, bool))},
        {} # No metrics to log at hparam definition time
    )
else:
    print("⚠️ Skipping model and optimizer setup due to data loading issues.")
    model, writer = None, None # Ensure these are None if data loading failed

In [None]:
# Training and Evaluation Loop
def train_one_epoch(model, loader, criterion, optimizer, scheduler, device, epoch, writer):
    """Run one optimisation pass over ``loader`` and return epoch metrics.

    The scheduler is stepped after every batch (as OneCycleLR expects). The
    per-batch loss and the learning rate each batch was trained with are
    logged to TensorBoard under ``Loss/train_batch`` and ``LR/batch``.

    Args:
        model: Network to train; switched to train mode.
        loader: Iterable of ``(images, labels)`` batches with a ``len()``.
        criterion: Loss function applied to ``(outputs, labels)``.
        optimizer: Optimizer updating the model parameters.
        scheduler: LR scheduler stepped once per batch.
        device: Device each batch is moved to.
        epoch: Zero-based epoch index (progress text and TensorBoard step).
        writer: TensorBoard ``SummaryWriter``.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, as Python floats;
        ``(0.0, 0.0)`` if ``loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    steps_per_epoch = len(loader)

    progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{config.num_epochs} [Training]", leave=False)
    for i, (images, labels) in enumerate(progress_bar):
        images, labels = images.to(device), labels.to(device)

        # Capture the LR this batch actually trains with: after scheduler.step()
        # get_last_lr() already reports the LR for the *next* batch.
        batch_lr = scheduler.get_last_lr()[0]

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step() # Step scheduler at each batch for OneCycleLR

        batch_loss = loss.item()
        # Weight by batch size so the epoch mean is per-sample (last batch may be smaller).
        running_loss += batch_loss * images.size(0)
        preds = outputs.argmax(dim=1)
        correct_predictions += (preds == labels).sum().item()
        total_samples += labels.size(0)

        # Log batch loss and LR to TensorBoard
        current_iter = epoch * steps_per_epoch + i
        writer.add_scalar('Loss/train_batch', batch_loss, current_iter)
        writer.add_scalar('LR/batch', batch_lr, current_iter)

        progress_bar.set_postfix(loss=batch_loss, lr=f"{batch_lr:.1e}")

    if total_samples == 0:
        return 0.0, 0.0
    return running_loss / total_samples, correct_predictions / total_samples

def evaluate(model, loader, criterion, device, epoch, writer, class_names):
    """Evaluate ``model`` on ``loader`` without gradients and log the results.

    Besides loss and accuracy, this builds a confusion matrix over all
    ``len(class_names)`` classes and derives per-class accuracy from it via
    ``per_class_accuracy``. Everything is logged to TensorBoard for ``epoch``.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function applied to ``(outputs, labels)``.
        device: Device each batch is moved to.
        epoch: Zero-based epoch index (progress text and TensorBoard step).
        writer: TensorBoard ``SummaryWriter``.
        class_names: Class names indexed by label id; its length defines the
            confusion-matrix label set.

    Returns:
        ``(mean_loss, accuracy, confusion_matrix, per_class_accuracy)``, with
        loss and accuracy as Python floats.
    """
    model.eval()
    num_classes = len(class_names)
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    all_labels = []
    all_preds = []

    progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{config.num_epochs} [Evaluating]", leave=False)
    with torch.no_grad():
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight by batch size so the epoch mean is per-sample.
            running_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            correct_predictions += (preds == labels).sum().item()
            total_samples += labels.size(0)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    epoch_loss = running_loss / total_samples
    epoch_acc = correct_predictions / total_samples

    # Confusion Matrix and Per-Class Accuracy. Passing the full label list keeps
    # the matrix num_classes x num_classes even if some class never appears,
    # so it lines up with class_names below.
    cm = sk_confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
    pca = per_class_accuracy(cm)

    # Log to TensorBoard
    tb_plot_metrics(writer, epoch_loss, epoch_acc, epoch) # Using renamed util
    tb_plot_confusion_matrix(writer, cm, epoch, class_names=class_names) # Using renamed util
    for class_name, acc_val in zip(class_names, pca):
        writer.add_scalar(f'Accuracy/class_{class_name}', acc_val, epoch)

    return epoch_loss, epoch_acc, cm, pca

# --- Training Execution ---
if model is not None and train_loader and test_loader and writer is not None: # Check if setup was successful
    best_test_acc = 0.0
    history = defaultdict(list)
    class_names = train_dataset.classes
    model_save_path = os.path.join(config.model_save_dir, config.best_model_name)

    print(f"\n🚀 Starting Training for {config.num_epochs} epochs...")
    start_time = time.time()

    try:
        for epoch in range(config.num_epochs):
            train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, scheduler, config.device, epoch, writer)
            test_loss, test_acc, cm, pca = evaluate(model, test_loader, criterion, config.device, epoch, writer, class_names)
            current_lr = scheduler.get_last_lr()[0]

            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['test_loss'].append(test_loss)
            history['test_acc'].append(test_acc)
            history['lr'].append(current_lr)

            print(f"Epoch {epoch+1}/{config.num_epochs} | "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
                  f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f} | "
                  f"LR: {current_lr:.1e}")

            # Save best model.
            # NOTE(review): the checkpoint is selected on the *test* split, so
            # best_test_acc is an optimistic estimate; a separate validation
            # split would keep the test score unbiased.
            if test_acc > best_test_acc:
                best_test_acc = test_acc
                torch.save(model.state_dict(), model_save_path)
                print(f"✨ New best model saved to {model_save_path} (Test Acc: {best_test_acc:.4f})")

        total_time = time.time() - start_time
        print(f"\n🏁 Training Finished! Total time: {total_time/60:.2f} minutes")
        print(f"🏆 Best Test Accuracy: {best_test_acc:.4f}")

        # Log final metrics to hparams (skipped when no epoch ran, since there
        # are then no final accuracies to report).
        if history.get('train_acc'):
            writer.add_hparams(
                {k: v for k, v in config.__dict__.items() if isinstance(v, (int, float, str, bool))},
                {"hparam/best_test_accuracy": best_test_acc,
                 "hparam/final_train_accuracy": history['train_acc'][-1],
                 "hparam/final_test_accuracy": history['test_acc'][-1]}
            )
    finally:
        # Flush event files even if training raises or is interrupted, so the
        # curves logged so far are not lost.
        writer.close()
else:
    print("❌ Training aborted due to setup issues (model, data, or writer not initialized).")

In [None]:
# Plot Training History
# `history` only exists once the training cell has run, so look it up via
# globals() instead of raising NameError when training was skipped.
if model is not None and 'history' in globals() and history.get('train_loss'): # Check if training ran
    # Use the number of epochs actually completed so an interrupted run still
    # plots (x and y must have the same length).
    epochs_completed = len(history['train_loss'])
    epochs_range = range(1, epochs_completed + 1)

    plt.figure(figsize=(15, 5))

    # Plot Loss
    plt.subplot(1, 3, 1)
    plt.plot(epochs_range, history['train_loss'], label='Train Loss', marker='o')
    plt.plot(epochs_range, history['test_loss'], label='Test Loss', marker='x')
    plt.title('Loss vs. Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    # Plot Accuracy
    plt.subplot(1, 3, 2)
    plt.plot(epochs_range, history['train_acc'], label='Train Accuracy', marker='o')
    plt.plot(epochs_range, history['test_acc'], label='Test Accuracy', marker='x')
    plt.title('Accuracy vs. Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)

    # Plot Learning Rate (LR at the end of each epoch)
    plt.subplot(1, 3, 3)
    plt.plot(epochs_range, history['lr'], label='Learning Rate', color='green', marker='.')
    plt.title('Learning Rate vs. Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Learning Rate')
    plt.legend()
    plt.grid(True)
    plt.yscale('log') # Often better to view LR on log scale

    plt.tight_layout()
    plt.show()

    # Display the confusion matrix from the last completed epoch
    if 'cm' in globals() and cm is not None: # Check if cm was computed
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.title(f'Confusion Matrix (Epoch {epochs_completed})')
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.show()

        print(f"\nPer-class Accuracy (Epoch {epochs_completed}):")
        if 'pca' in globals() and pca is not None:
            for class_name, acc_val in zip(class_names, pca):
                print(f"  {class_name}: {acc_val:.4f}")
else:
    print("⚠️ No training history to plot.")

## 💡 Next Steps & Experimentation Ideas

- **Hyperparameter Tuning**:
    - Experiment with different learning rates, weight decay, batch sizes.
    - Try a different optimizer (e.g., SGD with momentum) or learning-rate scheduler.
    - Adjust `mlp_ratio`, `embed_dim`, `depth`, `num_heads` in ViT.
- **Data Augmentation**:
    - Explore more advanced augmentation techniques (e.g., Mixup, CutMix).
    - Analyze the impact of different augmentation strengths.
- **Regularization**:
    - Experiment with different dropout rates or other regularization methods like Stochastic Depth.
- **Transfer Learning**:
    - Initialize ViT with pre-trained weights (e.g., from ImageNet) and fine-tune.
- **Longer Training**:
    - Run for more epochs to see if performance improves further.
- **Advanced Logging**:
    - Log model gradients or activation histograms to TensorBoard for deeper insights.
- **Cross-Validation**:
    - Implement k-fold cross-validation for more robust evaluation if dataset size permits.

---

🚀 **Experimentation is key to achieving optimal performance!**