In [None]:
# Part 1, Option A: ResNet model training / fine-tuning for better feature extraction, with extended evaluation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast, GradScaler
from PIL import Image, ImageFile
import os
import json
import csv
import logging
import sys
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from pathlib import Path
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.manifold import TSNE
import random

# Ask Pillow to decode image files whose data is truncated instead of
# raising on them, so a partially written file does not abort data loading.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Where the per-city images live and where run artefacts are written.
# The insertion order of ``data_folders`` is significant: it defines the
# class order used elsewhere in this cell.
DATA_PARAMS = dict(
    data_folders=dict(
        Boston='../data/ma-boston/buildings',
        Charlotte='../data/nc-charlotte/buildings',
        Manhattan='../data/ny-manhattan/buildings',
        Pittsburgh='../data/pa-pittsburgh/buildings',
    ),
    output_dir='softmax-output',
    model_subdir='models',
    log_subdir='logs',
    viz_subdir='visualizations',
    metrics_subdir='metrics',
)

# Hyper-parameters for the training run.
TRAINING_PARAMS = dict(
    # --- data ---
    batch_size=32,
    gradient_accumulation_steps=4,
    train_val_split=0.8,
    num_workers=0,               # 0 avoids multiprocessing issues
    max_images_per_class=None,   # None = use every image, or an int cap
    # --- model ---
    model_type='ResNet50',       # 'ResNet18' or 'ResNet50'
    hidden_dim=512,
    dropout_rate=0.3,
    # --- optimisation ---
    num_epochs=50,
    learning_rate=1e-3,
    weight_decay=1e-5,
    focal_loss_gamma=2.0,
    # --- LR scheduler ---
    scheduler_factor=0.5,
    scheduler_patience=5,
    # --- monitoring ---
    visualization_interval=5,    # export plots every N epochs
    checkpoint_interval=10,      # save a checkpoint every N epochs
    max_checkpoints=3,           # how many checkpoints to keep on disk
)

# Augmentation / preprocessing settings for the image transform pipeline.
TRANSFORM_PARAMS = dict(
    image_size=(224, 224),
    rotation_degrees=15,
    color_jitter=dict(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.1,
    ),
    normalize_mean=[0.485, 0.456, 0.406],  # ImageNet statistics
    normalize_std=[0.229, 0.224, 0.225],
)

# Number of samples contributing to one optimizer step once gradient
# accumulation is taken into account.
TRAINING_PARAMS.update(
    effective_batch_size=(
        TRAINING_PARAMS['batch_size']
        * TRAINING_PARAMS['gradient_accumulation_steps']
    )
)

# Augmentation + preprocessing pipeline, driven entirely by TRANSFORM_PARAMS:
# resize, random flip/rotation/colour jitter/affine, then tensor conversion
# and per-channel normalisation.
train_transform = transforms.Compose([
    transforms.Resize(TRANSFORM_PARAMS['image_size']),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(TRANSFORM_PARAMS['rotation_degrees']),
    # The color_jitter dict holds exactly the four ColorJitter keyword args.
    transforms.ColorJitter(**TRANSFORM_PARAMS['color_jitter']),
    transforms.RandomAffine(
        TRANSFORM_PARAMS['rotation_degrees'],
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        TRANSFORM_PARAMS['normalize_mean'],
        TRANSFORM_PARAMS['normalize_std'],
    ),
])

def _checkpoint_epoch_number(path):
    """Return the epoch number encoded in ``checkpoint_epoch_<N>.pth``, or None."""
    suffix = path.stem.rsplit('_', 1)[-1]
    return int(suffix) if suffix.isdigit() else None


def _checkpoint_state_to_cpu(obj):
    """Recursively copy every tensor inside a (nested) state dict to the CPU."""
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {key: _checkpoint_state_to_cpu(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_checkpoint_state_to_cpu(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(_checkpoint_state_to_cpu(value) for value in obj)
    return obj


def manage_checkpoints(output_dir, epoch, model, optimizer, scheduler, metrics, device,
                       max_checkpoints=None):
    """Save a checkpoint for ``epoch`` and prune the oldest ones.

    Args:
        output_dir: Run directory; checkpoints go to ``<output_dir>/checkpoints``.
        epoch: Zero-based epoch index. The file name uses ``epoch + 1``.
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: Object exposing ``state_dict()`` (the scheduler wrapper or
            a raw scheduler).
        metrics: Arbitrary picklable metrics stored alongside the states.
        device: ``torch.device`` the model lives on.
        max_checkpoints: Number of checkpoint files to keep. Defaults to
            ``TRAINING_PARAMS['max_checkpoints']`` when omitted.
    """
    checkpoint_dir = Path(output_dir) / 'checkpoints'
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    model_state_dict = model.state_dict()
    optimizer_state_dict = optimizer.state_dict()
    if device.type == "mps":
        # Optimizer tensors are nested under 'state', so a flat
        # comprehension over the top-level keys would leave them on the MPS
        # device; walk the whole structure instead.
        model_state_dict = _checkpoint_state_to_cpu(model_state_dict)
        optimizer_state_dict = _checkpoint_state_to_cpu(optimizer_state_dict)

    # Save checkpoint
    checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch+1}.pth'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model_state_dict,
        'optimizer_state_dict': optimizer_state_dict,
        'scheduler_state_dict': scheduler.state_dict(),
        'metrics': metrics
    }, checkpoint_path)

    if max_checkpoints is None:
        max_checkpoints = TRAINING_PARAMS['max_checkpoints']

    # Order by the numeric epoch in the file name. A plain lexicographic sort
    # puts 'checkpoint_epoch_10' before 'checkpoint_epoch_5' and would delete
    # the wrong file.
    numbered = []
    for path in checkpoint_dir.glob('checkpoint_epoch_*.pth'):
        epoch_number = _checkpoint_epoch_number(path)
        if epoch_number is not None:
            numbered.append((epoch_number, path))
    numbered.sort(key=lambda item: item[0])

    # Remove every surplus file, not just one, so the directory converges to
    # the limit even if it started out over it.
    surplus = len(numbered) - max_checkpoints
    for _, stale_path in numbered[:max(surplus, 0)]:
        stale_path.unlink()

class CityDataset(Dataset):
    """Image dataset built from one folder of pictures per city.

    ``folders`` maps a class name to a directory. Class indices follow the
    mapping's iteration order. ``max_images_per_class`` optionally caps each
    class by random sampling; ``None`` keeps every image.
    """

    def __init__(self, folders, transform=None, max_images_per_class=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform
        self.class_to_idx = dict(
            (class_name, position) for position, class_name in enumerate(folders.keys())
        )

        image_suffixes = ('.jpg', '.jpeg', '.png')

        print("Building dataset...")
        for class_name, folder in tqdm(folders.items(), desc="Loading classes"):
            # Keep image files only, skipping macOS metadata entries.
            class_images = []
            for entry in os.listdir(folder):
                if not entry.lower().endswith(image_suffixes):
                    continue
                if entry.startswith('._') or entry.startswith('.DS_Store'):
                    continue
                class_images.append(os.path.join(folder, entry))

            # Down-sample the class when a cap is set and exceeded.
            capped = max_images_per_class is not None
            if capped and len(class_images) > max_images_per_class:
                class_images = random.sample(class_images, max_images_per_class)

            print(f"\nFound {len(class_images)} images for {class_name}")

            class_index = self.class_to_idx[class_name]
            self.image_paths += class_images
            self.labels += [class_index] * len(class_images)

        print("\nDataset statistics:")
        print(f"Total images: {len(self.image_paths)}")
        for class_name in folders.keys():
            class_count = self.labels.count(self.class_to_idx[class_name])
            print(f"{class_name}: {class_count} images")

    def __len__(self):
        """Number of images across all classes."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; the image is RGB and transformed if a transform is set."""
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            with Image.open(image_path) as handle:
                image = handle.convert('RGB')

            if self.transform:
                image = self.transform(image)
        except Exception as e:
            # Report which file failed, then let the error propagate.
            print(f"Error loading image {image_path}: {str(e)}")
            raise

        return image, label

def setup_logging(identifier):
    """Configure and return this module's logger for one training run.

    Messages at INFO and above go to stdout and to
    ``<output_dir>/<identifier>/<log_subdir>/<identifier>.log`` (directory
    names come from ``DATA_PARAMS``). The log directory is created if needed.

    Args:
        identifier: Run name, used for both the run directory and the log file.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(__name__)

    # Detach AND close handlers left over from a previous call. Merely
    # clearing the list would leave the old log file open, leaking one file
    # descriptor every time the cell is re-run.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # File handler
    log_dir = Path(DATA_PARAMS['output_dir']) / identifier / DATA_PARAMS['log_subdir']
    log_dir.mkdir(parents=True, exist_ok=True)
    file_handler = logging.FileHandler(log_dir / f'{identifier}.log')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger

class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    Each sample's cross-entropy ``ce`` is scaled by ``(1 - p_t) ** gamma``,
    where ``p_t = exp(-ce)`` is the probability assigned to the true class.
    ``gamma = 0`` therefore reduces to plain cross-entropy.

    Args:
        gamma: Focusing exponent.
        reduction: ``'mean'``, ``'sum'`` or ``'none'`` (per-sample losses).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, gamma=2.0, reduction='mean'):
        super().__init__()
        # Reject unknown modes up front; otherwise a typo, or 'none', would
        # silently be treated as 'sum'.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input, target):
        """Return the focal loss of logits ``input`` against class indices ``target``."""
        ce_loss = F.cross_entropy(input, target, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

class LRSchedulerWrapper:
    """Pass-through around an LR scheduler whose ``step`` also reports the LR.

    The wrapped object stays reachable as ``.scheduler``.
    """

    def __init__(self, scheduler):
        self.scheduler = scheduler

    def step(self, metric=None):
        """Step the wrapped scheduler with ``metric`` and return the first reported learning rate."""
        self.scheduler.step(metric)
        last_lrs = self.scheduler.get_last_lr()
        return last_lrs[0]

    def get_last_lr(self):
        """Return the wrapped scheduler's ``get_last_lr()`` result unchanged."""
        return self.scheduler.get_last_lr()

    def state_dict(self):
        """Return the wrapped scheduler's state dict."""
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the wrapped scheduler from ``state_dict``."""
        self.scheduler.load_state_dict(state_dict)

class TrainingMonitor:
    """Collects per-epoch metrics and exports plots and reports for a run.

    Plots are written to ``<output_dir>/<viz_subdir>`` and metric files to
    ``<output_dir>/<metrics_subdir>`` (sub-directory names come from
    ``DATA_PARAMS``). Validation-based exports put the model in eval mode
    and leave it there.
    """

    def __init__(self, model, train_loader, val_loader, device, logger, output_dir, class_names):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.logger = logger
        self.output_dir = Path(output_dir)
        self.viz_dir = self.output_dir / DATA_PARAMS['viz_subdir']
        self.metrics_dir = self.output_dir / DATA_PARAMS['metrics_subdir']
        # parents=True so the monitor also works if the run directory has
        # not been created by the caller yet.
        self.viz_dir.mkdir(parents=True, exist_ok=True)
        self.metrics_dir.mkdir(parents=True, exist_ok=True)
        self.class_names = class_names

        # Per-epoch history, appended to by log_metrics().
        self.train_losses = []
        self.val_losses = []
        self.accuracies = []
        self.learning_rates = []
        self.start_time = datetime.now()

    def log_metrics(self, epoch, train_loss, val_loss, accuracy, lr):
        """Record one epoch's metrics, log them, and return them as a dict.

        ``epoch`` is zero-based; the log line shows ``epoch + 1``.
        """
        metrics = {
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'accuracy': accuracy,
            'learning_rate': lr
        }

        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)
        self.accuracies.append(accuracy)
        self.learning_rates.append(lr)

        self.logger.info(
            f"Epoch {epoch + 1} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
            f"Accuracy: {accuracy:.4f}, LR: {lr:.6f}"
        )

        return metrics

    def export_metrics_to_csv(self, epoch):
        """Append the most recently logged metrics as one row of ``training_metrics.csv``.

        The header is written only when the file does not exist yet.
        """
        metrics_file = self.metrics_dir / 'training_metrics.csv'
        metrics = {
            'epoch': epoch,
            'train_loss': self.train_losses[-1],
            'val_loss': self.val_losses[-1],
            'accuracy': self.accuracies[-1],
            'learning_rate': self.learning_rates[-1]
        }

        write_header = not metrics_file.exists()
        mode = 'w' if write_header else 'a'

        with open(metrics_file, mode, newline='') as f:
            writer = csv.DictWriter(f, fieldnames=metrics.keys())
            if write_header:
                writer.writeheader()
            writer.writerow(metrics)

    def plot_learning_curves(self, epoch):
        """Save loss, accuracy and learning-rate curves for the history so far."""
        plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.plot(self.train_losses, label='Train Loss')
        plt.plot(self.val_losses, label='Val Loss')
        plt.title('Loss Curves')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 3, 2)
        plt.plot(self.accuracies, label='Validation Accuracy')
        plt.title('Accuracy Curve')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()

        plt.subplot(1, 3, 3)
        plt.plot(self.learning_rates, label='Learning Rate')
        plt.title('Learning Rate Schedule')
        plt.xlabel('Epoch')
        plt.ylabel('Learning Rate')
        plt.yscale('log')
        plt.legend()

        plt.tight_layout()
        plt.savefig(self.viz_dir / f'learning_curves_epoch_{epoch}.png')
        plt.close()

    def plot_feature_space(self, epoch):
        """Save a 2-D t-SNE plot of pooled backbone features on the validation set.

        Features are taken after the model's ``avgpool`` stage, i.e. before
        the classification head. The plot is skipped (with a warning) when
        fewer than two validation samples are available.
        """
        features = []
        labels = []
        self.model.eval()

        # The ResNet feature extractor, applied in order (everything up to
        # but excluding ``fc``).
        backbone_stages = (
            self.model.conv1, self.model.bn1, self.model.relu, self.model.maxpool,
            self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4,
            self.model.avgpool,
        )

        with torch.no_grad():
            for images, target in tqdm(self.val_loader, desc="Extracting features",
                                       position=2, leave=False):
                activations = images.to(self.device)
                for stage in backbone_stages:
                    activations = stage(activations)
                features.append(torch.flatten(activations, 1).cpu().numpy())
                labels.extend(target.numpy())

        n_samples = sum(len(batch) for batch in features)
        if n_samples < 2:
            self.logger.warning(
                f"Skipping feature-space plot for epoch {epoch}: "
                f"need at least 2 validation samples, got {n_samples}"
            )
            return

        features = np.vstack(features)
        labels = np.array(labels)

        # t-SNE requires perplexity < n_samples; keep the library default of
        # 30 whenever the validation set is large enough.
        perplexity = min(30.0, float(n_samples - 1))
        tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        features_2d = tsne.fit_transform(features)

        num_classes = len(self.class_names)
        plt.figure(figsize=(10, 8))
        # Pin the colour scale to the full class-index range so a class
        # missing from the validation split does not shift the others.
        scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1],
                              c=labels, cmap='tab10',
                              vmin=0, vmax=max(num_classes - 1, 1))
        plt.colorbar(scatter, label='Classes', ticks=range(num_classes))
        plt.title(f'Feature Space Visualization (t-SNE) - Epoch {epoch}')
        plt.xlabel('t-SNE dimension 1')
        plt.ylabel('t-SNE dimension 2')

        # Take legend colours from the scatter's own colour mapping so each
        # legend entry matches the points of that class.
        handles = [plt.Line2D([0], [0], marker='o', color='w',
                              markerfacecolor=scatter.to_rgba(i),
                              label=name, markersize=8)
                   for i, name in enumerate(self.class_names)]
        plt.legend(handles=handles, title='Classes', bbox_to_anchor=(1.15, 1.0))

        plt.tight_layout()
        plt.savefig(self.viz_dir / f'feature_space_epoch_{epoch}.png',
                    bbox_inches='tight', dpi=300)
        plt.close()

    def _collect_predictions(self, desc):
        """Run the model over the validation loader; return ``(true_labels, predicted_labels)``."""
        self.model.eval()
        true_labels = []
        predicted_labels = []

        with torch.no_grad():
            for images, target in tqdm(self.val_loader, desc=desc,
                                       position=2, leave=False):
                outputs = self.model(images.to(self.device))
                _, preds = torch.max(outputs, 1)
                predicted_labels.extend(preds.cpu().numpy())
                true_labels.extend(target.numpy())

        return true_labels, predicted_labels

    def export_confusion_matrix(self, epoch):
        """Save a confusion-matrix heatmap for the validation set."""
        all_labels, all_preds = self._collect_predictions("Computing confusion matrix")

        # Fix the label set explicitly: without it the matrix shrinks when a
        # class is absent from both labels and predictions, and the tick
        # labels no longer line up with the rows/columns.
        class_ids = list(range(len(self.class_names)))
        cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=self.class_names,
                    yticklabels=self.class_names)
        plt.title(f'Confusion Matrix - Epoch {epoch}')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.tight_layout()
        plt.savefig(self.viz_dir / f'confusion_matrix_epoch_{epoch}.png')
        plt.close()

    def export_class_metrics(self, epoch):
        """Write the per-class classification report for the validation set as JSON."""
        labels, predictions = self._collect_predictions("Computing class metrics")

        # Passing ``labels`` keeps the report aligned with ``target_names``
        # even when some class does not occur in this validation split.
        class_ids = list(range(len(self.class_names)))
        report = classification_report(labels, predictions,
                                       labels=class_ids,
                                       target_names=self.class_names,
                                       output_dict=True,
                                       zero_division=0)

        with open(self.metrics_dir / f'class_metrics_epoch_{epoch}.json', 'w') as f:
            json.dump(report, f, indent=4)

    def export_training_summary(self):
        """Write ``training_summary.json`` describing the run so far.

        Best/final values are ``None`` when no epoch has been logged, so the
        summary can be written even if training failed before the first
        epoch finished.
        """
        summary = {
            'total_epochs': len(self.train_losses),
            'best_accuracy': max(self.accuracies) if self.accuracies else None,
            'final_accuracy': self.accuracies[-1] if self.accuracies else None,
            'best_val_loss': min(self.val_losses) if self.val_losses else None,
            'final_val_loss': self.val_losses[-1] if self.val_losses else None,
            'training_duration': str(datetime.now() - self.start_time),
            'learning_rate_progression': self.learning_rates,
            'accuracy_progression': self.accuracies,
            'val_loss_progression': self.val_losses,
            'class_names': self.class_names
        }

        with open(self.output_dir / 'training_summary.json', 'w') as f:
            json.dump(summary, f, indent=4)

def _cpu_copy_of_state(obj):
    """Recursively copy every tensor inside a (nested) state dict to the CPU."""
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {key: _cpu_copy_of_state(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_cpu_copy_of_state(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(_cpu_copy_of_state(value) for value in obj)
    return obj


def train_final_model(dataset, class_names, identifier):
    """Fine-tune a pretrained ResNet on ``dataset`` and return ``(model, monitor)``.

    The dataset is split into train/validation parts according to
    ``TRAINING_PARAMS['train_val_split']``. Training uses focal loss, AdamW,
    gradient accumulation, a ReduceLROnPlateau scheduler, periodic
    visualisations/checkpoints, best-model saving and early stopping after
    ``2 * scheduler_patience`` epochs without validation-loss improvement.
    All artefacts go under ``<output_dir>/<identifier>``.

    Args:
        dataset: Dataset yielding ``(image_tensor, int_label)`` pairs.
        class_names: Class names; its length sets the output layer size.
        identifier: Run name used for the output directory and log file.

    Returns:
        Tuple ``(model, monitor)``: the trained model (last-epoch weights)
        and the ``TrainingMonitor`` holding the metric history.

    Raises:
        Exception: Any error raised during training is logged and re-raised.
            ``KeyboardInterrupt`` is logged and swallowed so that a partially
            trained model is still returned.
    """
    # Setup
    logger = setup_logging(identifier)
    output_dir = Path(DATA_PARAMS['output_dir']) / identifier
    output_dir.mkdir(parents=True, exist_ok=True)

    # Split dataset and create data loaders
    train_size = int(TRAINING_PARAMS['train_val_split'] * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    # NOTE(review): both subsets share ``dataset``'s transform, so validation
    # images go through whatever augmentation the dataset applies.

    batch_size = TRAINING_PARAMS['batch_size']
    num_workers = TRAINING_PARAMS['num_workers']  # configured as 0 (single process)
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=False,
        # The BatchNorm1d layer in the new head cannot normalise a
        # single-sample batch in training mode, so drop a trailing batch of
        # exactly one sample (and only in that case).
        drop_last=(train_size > batch_size and train_size % batch_size == 1)
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=False
    )

    # Device setup with MPS support
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logger.info(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
        use_mixed_precision = True
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
        logger.info("Using Apple Silicon (MPS) device")
        use_mixed_precision = False  # mixed precision is only enabled on CUDA here
        if hasattr(torch.mps, 'empty_cache'):
            torch.mps.empty_cache()
    else:
        device = torch.device("cpu")
        logger.info("Using CPU device")
        use_mixed_precision = False

    # Set up mixed precision training based on device
    if use_mixed_precision:
        scaler = GradScaler()
        logger.info("Using mixed precision training")
    else:
        scaler = None
        logger.info("Mixed precision training not available for this device")

    # Model setup
    if TRAINING_PARAMS['model_type'] == 'ResNet18':
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    else:
        model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

    # Replace the ImageNet classifier with a small MLP head for our classes.
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, TRAINING_PARAMS['hidden_dim']),
        nn.BatchNorm1d(TRAINING_PARAMS['hidden_dim']),
        nn.ReLU(),
        nn.Dropout(TRAINING_PARAMS['dropout_rate']),
        nn.Linear(TRAINING_PARAMS['hidden_dim'], len(class_names))
    )
    model = model.to(device)

    # Training setup
    criterion = FocalLoss(gamma=TRAINING_PARAMS['focal_loss_gamma'])
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=TRAINING_PARAMS['learning_rate'],
        weight_decay=TRAINING_PARAMS['weight_decay']
    )

    scheduler = LRSchedulerWrapper(
        torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=TRAINING_PARAMS['scheduler_factor'],
            patience=TRAINING_PARAMS['scheduler_patience']
        )
    )

    # Initialize monitor
    monitor = TrainingMonitor(model, train_loader, val_loader, device,
                              logger, output_dir, class_names)

    # Training loop
    best_val_loss = float('inf')
    steps_without_improvement = 0
    total_epochs = TRAINING_PARAMS['num_epochs']
    accumulation_steps = TRAINING_PARAMS['gradient_accumulation_steps']
    num_train_batches = len(train_loader)

    try:
        # Add overall progress bar
        epoch_pbar = tqdm(range(total_epochs),
                          desc="Overall Progress",
                          position=0,
                          leave=True,
                          dynamic_ncols=True)

        for epoch in epoch_pbar:
            model.train()
            running_loss = 0.0
            optimizer.zero_grad()

            # Training phase with nested progress bar
            batch_pbar = tqdm(train_loader,
                              desc=f"Epoch {epoch+1}/{total_epochs}",
                              position=1,
                              leave=False,
                              dynamic_ncols=True)

            for i, (images, labels) in enumerate(batch_pbar):
                images, labels = images.to(device), labels.to(device)

                # Step after every full accumulation group AND on the last
                # batch of the epoch. Otherwise the gradients of a trailing
                # partial group are thrown away by the zero_grad() at the
                # start of the next epoch.
                is_update_step = (
                    (i + 1) % accumulation_steps == 0
                    or (i + 1) == num_train_batches
                )

                # Handle training step based on mixed precision availability
                if use_mixed_precision:
                    with autocast(device_type=device.type):
                        outputs = model(images)
                        loss = criterion(outputs, labels)
                    loss = loss / accumulation_steps
                    scaler.scale(loss).backward()

                    if is_update_step:
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                else:
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    loss = loss / accumulation_steps
                    loss.backward()

                    if is_update_step:
                        optimizer.step()
                        optimizer.zero_grad()

                # Undo the accumulation scaling so reported values are the
                # true per-batch loss.
                batch_loss = loss.item() * accumulation_steps
                running_loss += batch_loss
                batch_pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

                # Empty cache periodically for MPS
                if device.type == "mps" and (i + 1) % 50 == 0:
                    if hasattr(torch.mps, 'empty_cache'):
                        torch.mps.empty_cache()

            # Validation phase
            model.eval()
            val_loss = 0.0
            correct = 0
            total = 0

            with torch.no_grad():
                val_pbar = tqdm(val_loader,
                                desc="Validation",
                                position=1,
                                leave=False,
                                dynamic_ncols=True)

                for images, labels in val_pbar:
                    images, labels = images.to(device), labels.to(device)
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
                    val_pbar.set_postfix({'val_loss': f'{loss.item():.4f}'})

            # Calculate metrics
            train_loss = running_loss / len(train_loader)
            val_loss = val_loss / len(val_loader)
            accuracy = correct / total
            current_lr = scheduler.step(val_loss)

            # Update progress bar with current metrics
            epoch_pbar.set_postfix({
                'Train Loss': f'{train_loss:.4f}',
                'Val Loss': f'{val_loss:.4f}',
                'Accuracy': f'{accuracy:.4f}',
                'LR': f'{current_lr:.6f}'
            })

            # Log and export metrics
            metrics = monitor.log_metrics(epoch, train_loss, val_loss, accuracy, current_lr)
            monitor.export_metrics_to_csv(epoch)

            # Generate visualizations on interval
            if (epoch + 1) % TRAINING_PARAMS['visualization_interval'] == 0:
                monitor.plot_learning_curves(epoch)
                monitor.plot_feature_space(epoch)
                monitor.export_confusion_matrix(epoch)
                monitor.export_class_metrics(epoch)

            # Save checkpoint at interval
            if (epoch + 1) % TRAINING_PARAMS['checkpoint_interval'] == 0:
                manage_checkpoints(output_dir, epoch, model, optimizer, scheduler, metrics, device)

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                steps_without_improvement = 0

                model_state_dict = model.state_dict()
                optimizer_state_dict = optimizer.state_dict()
                if device.type == "mps":
                    # Optimizer tensors are nested under 'state'; walk the
                    # whole structure so nothing stays on the MPS device.
                    model_state_dict = _cpu_copy_of_state(model_state_dict)
                    optimizer_state_dict = _cpu_copy_of_state(optimizer_state_dict)

                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model_state_dict,
                    'optimizer_state_dict': optimizer_state_dict,
                    'scheduler_state_dict': scheduler.state_dict(),
                    'val_loss': val_loss,
                    'accuracy': accuracy,
                    'metrics': metrics
                }, output_dir / 'best_model.pth')
            else:
                steps_without_improvement += 1

            # Empty MPS cache after each epoch
            if device.type == "mps" and hasattr(torch.mps, 'empty_cache'):
                torch.mps.empty_cache()

            # Early stopping check
            if steps_without_improvement >= TRAINING_PARAMS['scheduler_patience'] * 2:
                logger.info(f"Early stopping triggered after {epoch + 1} epochs")
                break

    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
    except Exception as e:
        logger.error(f"Training error: {str(e)}")
        raise
    finally:
        # Export the summary only when at least one epoch was logged. With an
        # empty history the export itself could fail and, from inside this
        # ``finally``, hide the exception that actually stopped training.
        if monitor.train_losses:
            monitor.export_training_summary()

    return model, monitor

if __name__ == "__main__":
    # Build the dataset (optionally capped per class) from the configured folders.
    dataset = CityDataset(
        folders=DATA_PARAMS['data_folders'],
        transform=train_transform,
        max_images_per_class=TRAINING_PARAMS['max_images_per_class'],
    )
    # Class order follows the folder mapping's key order.
    class_names = [city for city in DATA_PARAMS['data_folders']]

    # Run identifier: model type, epochs, effective batch size and a timestamp.
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M")
    identifier = "_".join([
        f"softmax-{TRAINING_PARAMS['model_type']}",
        f"{TRAINING_PARAMS['num_epochs']}-ep",
        f"{TRAINING_PARAMS['effective_batch_size']}-bs",
        f"{current_time}",
    ])

    # Train model
    model, monitor = train_final_model(dataset, class_names, identifier)

In [None]:
# Part 1, Option B: ResNet model training / fine-tuning for better feature extraction, without extended evaluation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast, GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from PIL import Image, ImageFile
import os
import json
import matplotlib.pyplot as plt
from datetime import datetime
import numpy as np
from tqdm.auto import tqdm
import random

# Run hyper-parameters
batch_size, learning_rate, num_epochs = 32, 1e-3, 50
checkpoint_interval = 25
max_images_per_class = 25000
resnet_model = 'ResNet50'
weight_decay = 1e-5

# ImageNet channel statistics used for input normalisation
normalize_mean = [0.485, 0.456, 0.406]
normalize_std = [0.229, 0.224, 0.225]

# One image folder per city; the key order defines the class order
folders = {
    'Boston': '../data/ma-boston/buildings',
    'Charlotte': '../data/nc-charlotte/buildings',
    'Manhattan': '../data/ny-manhattan/buildings',
    'Pittsburgh': '../data/pa-pittsburgh/buildings'
}
class_names = ['Boston', 'Charlotte', 'Manhattan', 'Pittsburgh']
num_classes = len(class_names)

# Run identifier and output locations
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M")
identifier = "_".join([
    f"softmax-{resnet_model}",
    f"{num_epochs}-ep",
    f"{batch_size}-bs",
    f"{max_images_per_class}-images",
    current_time,
])
output_folder = os.path.join('softmax-output', identifier)
checkpoint_dir = os.path.join(output_folder, 'checkpoints')
model_save_path = os.path.join(output_folder, f'trained-model_{identifier}.pth')

# checkpoint_dir sits inside output_folder, so creating it creates both.
os.makedirs(checkpoint_dir, exist_ok=True)

# Ask Pillow to decode image files whose data is truncated instead of
# raising on them, so a partially written file does not abort data loading.
ImageFile.LOAD_TRUNCATED_IMAGES = True

class CityDataset(Dataset):
    """Image dataset built from one folder of pictures per city.

    Args:
        folders: Mapping of class name to image directory. Class indices
            follow the mapping's iteration order.
        transform: Optional callable applied to each loaded RGB image.
        max_images_per_class: Upper bound on images kept per class (chosen by
            random sampling). Defaults to the module-level
            ``max_images_per_class`` as it was when the class was defined;
            pass ``None`` to keep every image.
    """

    def __init__(self, folders, transform=None, max_images_per_class=max_images_per_class):
        self.image_paths = []
        self.labels = []
        self.transform = transform
        self.class_to_idx = {class_name: idx for idx, class_name in enumerate(folders.keys())}

        print("Building dataset...")
        for class_name, folder in tqdm(folders.items(), desc="Loading classes"):
            # Filter out macOS system files and get only image files
            class_images = [
                os.path.join(folder, f) for f in os.listdir(folder)
                if (f.lower().endswith(('.jpg', '.jpeg', '.png')) and
                    not f.startswith('._') and
                    not f.startswith('.DS_Store'))
            ]

            print(f"\nFound {len(class_images)} images for {class_name}")

            # ``None`` means "no limit"; comparing an int with None would
            # raise TypeError, so check for it explicitly.
            if max_images_per_class is not None and len(class_images) > max_images_per_class:
                class_images = random.sample(class_images, max_images_per_class)

            self.image_paths.extend(class_images)
            self.labels.extend([self.class_to_idx[class_name]] * len(class_images))

        print("\nDataset statistics:")
        print(f"Total images: {len(self.image_paths)}")
        for class_name in folders.keys():
            class_count = self.labels.count(self.class_to_idx[class_name])
            print(f"{class_name}: {class_count} images")

    def __len__(self):
        """Number of images across all classes."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; the image is RGB and transformed if a transform is set."""
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        try:
            with Image.open(image_path) as img:
                image = img.convert('RGB')

            if self.transform:
                image = self.transform(image)

            return image, label
        except Exception as e:
            # Report which file failed, then re-raise with the original traceback.
            print(f"Error loading image {image_path}: {str(e)}")
            raise

# Augmentation pipeline: resize, random flip / rotation / colour jitter /
# padded crop / affine / perspective, then tensor conversion and
# normalisation with the statistics defined above.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),  # brightness, contrast, saturation, hue
    transforms.RandomCrop(224, padding=4),
    transforms.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.RandomPerspective(0.2, 0.5),  # distortion_scale, p
    transforms.ToTensor(),
    transforms.Normalize(normalize_mean, normalize_std),
])

# Build the dataset from the configured city folders
print("\nInitializing dataset...")
dataset = CityDataset(folders, transform=transform)

# Pick the best available device: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"\nUsing device: {device}")

class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of cross-entropy.

    Each sample's cross-entropy is scaled by ``(1 - p_t) ** gamma``, where
    ``p_t`` is the probability the model assigns to the true class, so
    confidently correct samples contribute less to the loss.

    Args:
        gamma: Focusing exponent; ``0`` reduces to plain cross-entropy.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, gamma=2.0, reduction='mean'):
        super().__init__()
        # Fail fast: an unknown reduction used to fall through to 'sum' silently.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input, target):
        """Compute the focal loss of logits ``input`` against class indices ``target``.

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'``, or the per-sample
            losses for ``'none'``.
        """
        ce_loss = F.cross_entropy(input, target, reduction='none')
        # exp(-CE) recovers the probability assigned to the true class.
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

def mixup_data(x, y, alpha=0.2):
    """Blend each sample of a batch with a randomly chosen partner (mixup).

    Args:
        x: Batch of inputs; the first dimension is the batch dimension.
        y: Labels aligned with ``x`` along the first dimension.
        alpha: Parameter of the ``Beta(alpha, alpha)`` distribution the mixing
            weight is drawn from. A value ``<= 0`` disables mixing.

    Returns:
        ``(mixed_x, y_a, y_b, lam)``: the blended inputs, the original labels,
        the labels of the partner samples, and the mixing weight.
    """
    # With a non-positive alpha the weight is 1, i.e. the batch is unchanged.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # Random pairing of every sample with another sample of the same batch.
    permutation = torch.randperm(x.shape[0]).to(x.device)
    partner_inputs = x[permutation]
    partner_labels = y[permutation]

    blended = lam * x + (1 - lam) * partner_inputs
    return blended, y, partner_labels, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the mixup loss: ``criterion`` against both label sets, weighted.

    The loss for ``y_a`` is weighted by ``lam`` and the loss for ``y_b`` by
    ``1 - lam``, matching the weights used to blend the inputs.
    """
    primary_loss = criterion(pred, y_a)
    partner_loss = criterion(pred, y_b)
    return lam * primary_loss + (1 - lam) * partner_loss

def train_final_model(dataset):
    """Fine-tune a pretrained ResNet on ``dataset`` and return the trained model.

    The dataset is split 80/20 into training and validation subsets. The
    ResNet's final layer is replaced by a small classification head and the
    network is trained with mixup, focal loss, AdamW, autocast/GradScaler and
    gradient-norm clipping. ``ReduceLROnPlateau`` (factor 0.5, patience 5)
    follows the validation loss, the model with the lowest validation loss is
    saved to ``model_save_path``, weights are checkpointed to
    ``checkpoint_dir`` every ``checkpoint_interval`` epochs, and training
    stops early after 10 consecutive epochs without improvement.

    Relies on module-level settings: ``batch_size``, ``resnet_model``,
    ``num_classes``, ``class_names``, ``learning_rate``, ``weight_decay``,
    ``num_epochs``, ``checkpoint_interval``, ``checkpoint_dir``,
    ``model_save_path`` and ``device``.

    Args:
        dataset: Dataset yielding ``(image, label)`` pairs.

    Returns:
        The model as it is after the last completed epoch (which is not
        necessarily the best-scoring state saved to ``model_save_path``).

    Raises:
        ValueError: If ``resnet_model`` is neither 'ResNet18' nor 'ResNet50'.
    """
    print("\nSplitting dataset into train/val sets...")
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    print(f"Training set size: {len(train_dataset)}")
    print(f"Validation set size: {len(val_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    print(f"\nInitializing {resnet_model}...")
    if resnet_model == 'ResNet18':
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    elif resnet_model == 'ResNet50':
        model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    else:
        # Without this branch an unknown name left `model` unbound and failed
        # later with an unrelated-looking error.
        raise ValueError(
            f"Unsupported resnet_model {resnet_model!r}; expected 'ResNet18' or 'ResNet50'"
        )

    # Replace the classifier with a small head sized for this task.
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes)
    )
    model.to(device)

    criterion = FocalLoss(gamma=2.0)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Fully qualified so the function does not depend on a separate import.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )
    # NOTE(review): when `device` is "mps" this still selects the 'cpu' AMP
    # backend, as the original code did — confirm that is intended.
    amp_device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    scaler = GradScaler(amp_device_type)

    best_val_loss = float('inf')
    patience = 10
    epochs_without_improvement = 0

    epoch_pbar = tqdm(range(num_epochs), desc="Training Progress", position=0)

    for epoch in epoch_pbar:
        model.train()
        running_loss = 0.0
        per_class_correct = torch.zeros(num_classes)
        per_class_total = torch.zeros(num_classes)

        batch_pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}",
                          leave=False, position=1)

        for images, labels in batch_pbar:
            images, labels = images.to(device), labels.to(device)

            images, targets_a, targets_b, lam = mixup_data(images, labels)

            optimizer.zero_grad()
            with autocast(device_type=amp_device_type):
                outputs = model(images)
                loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)

            scaler.scale(loss).backward()

            # Gradients are still multiplied by the loss scale here; unscale
            # them first so max_norm applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scaler.step(optimizer)
            scaler.update()

            current_loss = loss.item()
            running_loss += current_loss
            batch_pbar.set_postfix({'loss': f'{current_loss:.4f}'})

            # Reuse the logits from the training forward pass: a second
            # forward in train mode would update the BatchNorm running
            # statistics again and double the compute per batch.
            predicted = outputs.detach().argmax(dim=1)
            labels_cpu = labels.cpu()
            hits = (predicted == labels).cpu()
            per_class_total += torch.bincount(labels_cpu, minlength=num_classes)
            per_class_correct += torch.bincount(labels_cpu[hits], minlength=num_classes)

        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        val_pbar = tqdm(val_loader, desc="Validation",
                        leave=False, position=1)

        with torch.no_grad():
            for images, labels in val_pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                current_val_loss = loss.item()
                val_loss += current_val_loss
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                val_pbar.set_postfix({'val_loss': f'{current_val_loss:.4f}'})

        # Average the per-batch losses over the number of batches.
        train_loss = running_loss / len(train_loader)
        val_loss = val_loss / len(val_loader)
        accuracy = correct / total

        scheduler.step(val_loss)

        epoch_pbar.set_postfix({
            'train_loss': f'{train_loss:.4f}',
            'val_loss': f'{val_loss:.4f}',
            'accuracy': f'{accuracy:.4f}'
        })

        print(f"\nEpoch {epoch + 1} Complete:")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(f"Val Accuracy: {accuracy:.4f}")

        # Training-set accuracy per class, measured on the mixed-up inputs
        # against the un-mixed labels.
        print("\nPer-class accuracies:")
        for i in range(num_classes):
            if per_class_total[i] > 0:
                class_acc = per_class_correct[i] / per_class_total[i]
                print(f"{class_names[i]}: {class_acc:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            print(f"\nSaving best model with val_loss: {val_loss:.4f}")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_loss,
                'accuracy': accuracy
            }, model_save_path)
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print("\nEarly stopping triggered!")
            break

        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch + 1}.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print(f"\nCheckpoint saved: {checkpoint_path}")

    return model

# Script entry point: train on the module-level `dataset` when executed directly.
if __name__ == "__main__":
    print("\nStarting training...")
    # Keep a reference to the model returned by the training routine.
    final_model = train_final_model(dataset)
    print(f"\nTraining complete! Model saved to {model_save_path}")