### <font style="color:blue">Project 2: Kaggle Competition - Classification</font>

#### Maximum Points: 100

<div>
    <table>
        <tr><td><h3>Sr. no.</h3></td> <td><h3>Section</h3></td> <td><h3>Points</h3></td> </tr>
        <tr><td><h3>1</h3></td> <td><h3>Data Loader</h3></td> <td><h3>10</h3></td> </tr>
        <tr><td><h3>2</h3></td> <td><h3>Configuration</h3></td> <td><h3>5</h3></td> </tr>
        <tr><td><h3>3</h3></td> <td><h3>Evaluation Metric</h3></td> <td><h3>10</h3></td> </tr>
        <tr><td><h3>4</h3></td> <td><h3>Train and Validation</h3></td> <td><h3>5</h3></td> </tr>
        <tr><td><h3>5</h3></td> <td><h3>Model</h3></td> <td><h3>5</h3></td> </tr>
        <tr><td><h3>6</h3></td> <td><h3>Utils</h3></td> <td><h3>5</h3></td> </tr>
        <tr><td><h3>7</h3></td> <td><h3>Experiment</h3></td><td><h3>5</h3></td> </tr>
        <tr><td><h3>8</h3></td> <td><h3>TensorBoard Dev Scalars Log Link</h3></td> <td><h3>5</h3></td> </tr>
        <tr><td><h3>9</h3></td> <td><h3>Kaggle Profile Link</h3></td> <td><h3>50</h3></td> </tr>
    </table>
</div>


## <font style="color:green">1. Data Loader [10 Points]</font>

In this section, write a class or functions that build the training and validation data loaders.

You need to write a custom dataset class to load the data.

**Note: there is no separate validation data, so you have to create your own validation set by splitting the training data into training and validation subsets. An 80:20 train/validation ratio is a common choice.**
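The 80:20 split can be illustrated with scikit-learn's `train_test_split` and its `stratify` argument, which is also what `KenyanFood13Dataset` uses internally. In this sketch the integer labels are toy data standing in for the `CLASS` column. Stratifying keeps each class's share the same in both subsets, which matters when some food classes are much rarer than others.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy, imbalanced labels standing in for the CLASS column (hypothetical data)
labels = np.array([0] * 50 + [1] * 30 + [2] * 20)
indices = np.arange(len(labels))

# 80:20 split; stratify keeps each class's proportion equal in both parts
train_idx, val_idx = train_test_split(
    indices, test_size=0.2, random_state=42, stratify=labels
)

print(len(train_idx), len(val_idx))   # 80 20
print(np.bincount(labels[val_idx]))   # [10  6  4]
```

Splitting indices rather than the data itself mirrors the dataset class above: the same CSV is read twice, and each instance keeps only its own index subset.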


For example:

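```python
from torch.utils.data import Dataset


class KenyanFood13Dataset(Dataset):
    """
    Custom dataset for the Kenyan Food 13 images.
    """

    def __init__(self, *args):
        ...  # store paths, read the CSV, build the train/validation split

    def __len__(self):
        ...  # return the number of samples

    def __getitem__(self, idx):
        ...  # load one image, apply transforms, return (image, label)
```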


```python
def get_data(args1, *args):
    ...  # create the datasets and wrap them in DataLoaders
    return train_loader, val_loader
```

In [1]:
import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split

In [2]:
class KenyanFood13Dataset(Dataset):
    """
    Dataset class for the Kenyan Food 13 dataset.
    Loads images and corresponding labels from CSV file.
    """
    
    def __init__(self, csv_file, img_dir, transform=None, train=True, val_split=0.2, is_test=False, seed=42):
        """
        Args:
            csv_file (str): Path to the CSV file with annotations.
            img_dir (str): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
            train (bool): If True, create training set, else create validation set.
            val_split (float): Validation split ratio (default: 0.2).
            is_test (bool): If True, this is the test set (no labels).
            seed (int): Random seed for reproducibility.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test
        
        # Read the CSV file
        self.data_info = pd.read_csv(csv_file)
        
        # If this is the test set, we don't need to split or get labels
        if is_test:
            self.img_ids = self.data_info['ID'].values
            self.targets = None
        else:
            # Get image IDs and labels
            self.img_ids = self.data_info['ID'].values
            
            # Convert class names to numerical labels
            self.class_to_idx = {cls_name: i for i, cls_name in enumerate(sorted(self.data_info['CLASS'].unique()))}
            self.idx_to_class = {i: cls_name for cls_name, i in self.class_to_idx.items()}
            self.targets = [self.class_to_idx[cls] for cls in self.data_info['CLASS'].values]
            
            # Split into train and validation sets
            train_indices, val_indices = train_test_split(
                np.arange(len(self.img_ids)),
                test_size=val_split,
                random_state=seed,
                stratify=self.targets
            )
            
            # Select either training or validation indices
            if train:
                self.indices = train_indices
            else:
                self.indices = val_indices
                
            # Filter image IDs and targets
            self.img_ids = self.img_ids[self.indices]
            self.targets = [self.targets[i] for i in self.indices]
    
    def __len__(self):
        """Return the total number of samples"""
        return len(self.img_ids)
    
    def __getitem__(self, idx):
        """Get a sample from the dataset"""
        img_id = self.img_ids[idx]
        img_path = os.path.join(self.img_dir, f"{img_id}.jpg")
        
        # Load the image
        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a black image and its label if the image can't be loaded
            image = Image.new('RGB', (224, 224))
        
        # Apply transformations
        if self.transform:
            image = self.transform(image)
        
        # Return image and label for training/val, just image for test
        if self.is_test:
            return image, img_id
        else:
            return image, self.targets[idx]

In [3]:
def get_data(data_dir, batch_size=32, num_workers=4):
    """
    Create train, validation and test data loaders.
    
    Args:
        data_dir (str): Path to the data directory.
        batch_size (int): Batch size for the data loaders.
        num_workers (int): Number of workers for the data loaders.
        
    Returns:
        train_loader, val_loader, test_loader: DataLoader objects for training, validation and testing.
    """
    # Define paths
    train_csv = os.path.join(data_dir, 'train.csv')
    test_csv = os.path.join(data_dir, 'test.csv')
    img_dir = os.path.join(data_dir, 'images')
    
    # Define transformations
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
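    # Resize then center-crop so validation/test images match the training scale
    val_test_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])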
    
    # Create datasets
    train_dataset = KenyanFood13Dataset(train_csv, img_dir, transform=train_transform, train=True)
    val_dataset = KenyanFood13Dataset(train_csv, img_dir, transform=val_test_transform, train=False)
    test_dataset = KenyanFood13Dataset(test_csv, img_dir, transform=val_test_transform, is_test=True)
    
    # Get class weights for handling imbalanced classes
    class_counts = np.bincount(train_dataset.targets)
    class_weights = 1.0 / class_counts
    weights = class_weights[train_dataset.targets]
    sampler = torch.utils.data.WeightedRandomSampler(weights, len(weights))
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, sampler=sampler,
        num_workers=num_workers, pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )
    
    return train_loader, val_loader, test_loader
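`get_data` balances classes by giving each training sample a weight of 1 / (count of its class). The toy sketch below, with hypothetical labels, shows the effect: every class then has the same total weight, so `WeightedRandomSampler` draws the classes in roughly equal proportions. Because it samples with replacement (the default), minority-class images are repeated within an epoch while some majority-class images are skipped.

```python
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical imbalanced targets: class 0 is 9x more common than class 1
targets = np.array([0] * 900 + [1] * 100)

# Same recipe as get_data: weight each sample by 1 / (size of its class)
class_counts = np.bincount(targets)            # [900, 100]
sample_weights = (1.0 / class_counts)[targets]

sampler = WeightedRandomSampler(
    torch.as_tensor(sample_weights, dtype=torch.double),
    num_samples=len(sample_weights),
    replacement=True,
    generator=torch.Generator().manual_seed(0),  # reproducible draws
)

drawn = targets[list(sampler)]
frac = np.bincount(drawn, minlength=2) / len(drawn)
print(frac)  # roughly [0.5, 0.5]
```

The epoch length stays equal to the dataset size (`num_samples=len(weights)`), so the number of optimizer steps per epoch is unchanged by the resampling.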

## <font style="color:green">2. Configuration [5 Points]</font>

**Define your configuration here.**

For example:


```python
from dataclasses import dataclass


@dataclass
class TrainingConfiguration:
    '''
    Describes configuration of the training process
    '''
    batch_size: int = 10
    epochs_count: int = 50
    init_learning_rate: float = 0.1  # initial learning rate for lr scheduler
    log_interval: int = 5
    test_interval: int = 1
    data_root: str = "/kaggle/input/opencv-pytorch-project-2-classification-round-3"
    num_workers: int = 2
    device: str = 'cuda'
```

In [5]:
from dataclasses import dataclass, field
import torch
import os
from typing import List, Optional, Tuple

@dataclass
class TrainingConfiguration:
    """Configuration for training the Kenyan Food 13 classifier."""
    
    # Data parameters
    data_dir: str = "/kaggle/input/opencv-pytorch-project-2-classification-round-3"
    img_dir: str = "images"
    train_csv: str = "train.csv"
    test_csv: str = "test.csv"
    val_split: float = 0.2
    num_classes: int = 13
    
    # Image parameters
    img_size: int = 224
    crop_size: int = 224
    resize_size: int = 256
    mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)  # ImageNet means
    std: Tuple[float, float, float] = (0.229, 0.224, 0.225)   # ImageNet stds
    
    # Augmentation parameters
    use_augmentation: bool = True
    rotation_degrees: int = 15
    color_jitter_factor: float = 0.1
    
    # Model parameters
    model_name: str = "resnet50"  # Options: resnet18, resnet50, efficientnet_b0, etc.
    pretrained: bool = True
    dropout_rate: float = 0.2
    
    # Training parameters
    batch_size: int = 32
    num_epochs: int = 30
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    lr_scheduler: str = "cosine"  # Options: cosine, step, plateau
    lr_step_size: int = 7
    lr_gamma: float = 0.1
    lr_min: float = 1e-6
    early_stopping_patience: int = 5
    
    # Optimizer parameters
    optimizer: str = "adamw"  # Options: adam, adamw, sgd
    momentum: float = 0.9  # For SGD
    
    # Loss parameters
    loss_fn: str = "cross_entropy"  # Options: cross_entropy, focal
    focal_alpha: float = 0.25
    focal_gamma: float = 2.0
    
    # Device parameters
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
    num_workers: int = 4
    pin_memory: bool = True
    
    # Logging parameters
    checkpoint_dir: str = "./checkpoints"
    tensorboard_dir: str = "./runs"
    log_interval: int = 10  # Log every N batches
    save_best_only: bool = True
    
    # Miscellaneous
    seed: int = 42
    verbose: bool = True
    mixed_precision: bool = True  # Use mixed precision training
    
    def __post_init__(self):
        """Create directories if they don't exist."""
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.tensorboard_dir, exist_ok=True)
    
    def get_full_img_dir(self) -> str:
        """Get the full path to the image directory."""
        return os.path.join(self.data_dir, self.img_dir)
    
    def get_full_train_csv(self) -> str:
        """Get the full path to the training CSV file."""
        return os.path.join(self.data_dir, self.train_csv)
    
    def get_full_test_csv(self) -> str:
        """Get the full path to the test CSV file."""
        return os.path.join(self.data_dir, self.test_csv)
    
    def model_checkpoint_path(self, epoch: Optional[int] = None) -> str:
        """Get the path to save/load model checkpoints."""
        if epoch is not None:
            return os.path.join(self.checkpoint_dir, f"{self.model_name}_epoch_{epoch}.pth")
        return os.path.join(self.checkpoint_dir, f"{self.model_name}_best.pth")
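Experiments usually change only a few fields of this configuration. A minimal sketch, assuming you want to derive experiment variants without editing the defaults, uses `dataclasses.replace` on a trimmed, hypothetical `MiniConfig` that mirrors a few `TrainingConfiguration` fields. Note that `replace` calls `__init__` again, so for `TrainingConfiguration` the `__post_init__` directory creation runs again; with `exist_ok=True` that is harmless.

```python
from dataclasses import dataclass, replace, asdict


@dataclass
class MiniConfig:
    # Hypothetical subset of TrainingConfiguration's fields, for illustration
    model_name: str = "resnet50"
    batch_size: int = 32
    learning_rate: float = 1e-3
    optimizer: str = "adamw"


base = MiniConfig()
# replace() returns a new instance; the base configuration is left unchanged
exp = replace(base, model_name="efficientnet_b0", learning_rate=3e-4)

print(base.model_name, exp.model_name)  # resnet50 efficientnet_b0
print(asdict(exp))  # plain dict, convenient for logging hyperparameters
```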

## <font style="color:green">3. Evaluation Metric [10 Points]</font>

**Define the methods or classes used for model evaluation, such as accuracy and F1-score.**
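`EvaluationMetrics` below averages precision, recall and F1 with `average='weighted'` by default. On imbalanced data this can hide a class the model never predicts, because each class's score is weighted by its support. The toy example (hypothetical labels) compares weighted and macro F1 with scikit-learn's `f1_score`; macro F1 gives every class equal weight and drops sharply.

```python
import numpy as np
from sklearn.metrics import f1_score

# Toy predictions: the majority class (0) is always predicted,
# so the minority class (1) is always missed
y_true = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
y_pred = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

weighted = f1_score(y_true, y_pred, average="weighted", zero_division=0)
macro = f1_score(y_true, y_pred, average="macro", zero_division=0)

# Class 0: F1 = 16/18; class 1: F1 = 0
print(f"weighted={weighted:.3f} macro={macro:.3f}")  # weighted=0.711 macro=0.444
```

Reporting both averages side by side makes it easier to spot a minority food class that the classifier ignores.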

In [6]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.metrics import classification_report
import pandas as pd
from typing import Dict, List, Tuple, Optional, Union
import torch.nn.functional as F

class EvaluationMetrics:
    """
    Class for evaluating model performance on the Kenyan Food classification task.
    Includes methods for computing accuracy, precision, recall, F1-score, and 
    generating confusion matrices and other visualizations.
    """
    
    @staticmethod
    def accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """
        Calculate classification accuracy.
        
        Args:
            y_true: Ground truth labels
            y_pred: Predicted labels
            
        Returns:
            Accuracy score
        """
        return accuracy_score(y_true, y_pred)
    
    @staticmethod
    def precision(y_true: np.ndarray, y_pred: np.ndarray, average: str = 'weighted') -> float:
        """
        Calculate precision score.
        
        Args:
            y_true: Ground truth labels
            y_pred: Predicted labels
            average: Averaging method ('micro', 'macro', 'weighted', or None for per-class)
            
        Returns:
            Precision score
        """
        return precision_score(y_true, y_pred, average=average, zero_division=0)
    
    @staticmethod
    def recall(y_true: np.ndarray, y_pred: np.ndarray, average: str = 'weighted') -> float:
        """
        Calculate recall score.
        
        Args:
            y_true: Ground truth labels
            y_pred: Predicted labels
            average: Averaging method ('micro', 'macro', 'weighted', or None for per-class)
            
        Returns:
            Recall score
        """
        return recall_score(y_true, y_pred, average=average, zero_division=0)
    
    @staticmethod
    def f1(y_true: np.ndarray, y_pred: np.ndarray, average: str = 'weighted') -> float:
        """
        Calculate F1 score.
        
        Args:
            y_true: Ground truth labels
            y_pred: Predicted labels
            average: Averaging method ('micro', 'macro', 'weighted', or None for per-class)
            
        Returns:
            F1 score
        """
        return f1_score(y_true, y_pred, average=average, zero_division=0)
    
    @staticmethod
    def confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
        """
        Generate confusion matrix.
        
        Args:
            y_true: Ground truth labels
            y_pred: Predicted labels
            
        Returns:
            Confusion matrix as a numpy array
        """
        return confusion_matrix(y_true, y_pred)
    
    @staticmethod
    def plot_confusion_matrix(
        y_true: np.ndarray, 
        y_pred: np.ndarray, 
        class_names: List[str],
        figsize: Tuple[int, int] = (12, 10),
        normalize: bool = True
    ) -> plt.Figure:
        """
        Plot confusion matrix with class names.
        
        Args:
            y_true: Ground truth labels
            y_pred: Predicted labels
            class_names: List of class names
            figsize: Figure size
            normalize: Whether to normalize the confusion matrix
            
        Returns:
            Matplotlib figure object
        """
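        # Fix the label set so the matrix matches class_names even if a class is absent
        cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))
        
        if normalize:
            # Row-normalize (per-class recall); empty rows stay 0 instead of NaN
            row_sums = cm.sum(axis=1, keepdims=True)
            cm = np.divide(
                cm.astype(float), row_sums,
                out=np.zeros(cm.shape, dtype=float), where=row_sums != 0
            )
            
        fig, ax = plt.subplots(figsize=figsize)
        sns.heatmap(
            cm, annot=True, fmt='.2f' if normalize else 'd',
            cmap='Blues', xticklabels=class_names, yticklabels=class_names, ax=ax
        )
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        ax.set_title('Confusion Matrix')
        fig.tight_layout()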
        
        return fig
    
    @staticmethod
    def plot_metrics_history(metrics_history: Dict[str, List[float]]) -> plt.Figure:
        """
        Plot the history of metrics during training.
        
        Args:
            metrics_history: Dictionary of metric name to list of values
            
        Returns:
            Matplotlib figure object
        """
        fig, ax = plt.subplots(figsize=(12, 6))
        
        for metric_name, values in metrics_history.items():
            ax.plot(values, label=metric_name)
            
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Value')
        ax.set_title('Training Metrics History')
        ax.legend()
        ax.grid(True)
        
        return fig
    
    @staticmethod
    def classification_report_df(
        y_true: np.ndarray, 
        y_pred: np.ndarray, 
        class_names: List[str]
    ) -> pd.DataFrame:
        """
        Generate a classification report as a pandas DataFrame.
        
        Args:
            y_true: Ground truth labels
            y_pred: Predicted labels
            class_names: List of class names
            
        Returns:
            DataFrame with precision, recall, and f1-score for each class
        """
        report = classification_report(
            y_true, y_pred, labels=np.arange(len(class_names)),
            target_names=class_names, output_dict=True, zero_division=0
        )
        return pd.DataFrame(report).transpose()
    
    @staticmethod
    def top_k_accuracy(
        outputs: torch.Tensor, 
        targets: torch.Tensor, 
        k: int = 5
    ) -> float:
        """
        Calculate top-k accuracy.
        
        Args:
            outputs: Model outputs (logits) of shape [batch_size, num_classes]
            targets: Ground truth labels of shape [batch_size]
            k: k value for top-k accuracy
            
        Returns:
            Top-k accuracy as a percentage in [0, 100] (not a fraction like accuracy())
        """
        batch_size = targets.size(0)
        _, pred = outputs.topk(k, 1, True, True)
        pred = pred.t()
        correct = pred.eq(targets.view(1, -1).expand_as(pred))
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        return correct_k.item() * (100.0 / batch_size)
    
    @staticmethod
    def per_class_accuracy(
        y_true: np.ndarray, 
        y_pred: np.ndarray, 
        num_classes: int
    ) -> np.ndarray:
        """
        Calculate per-class accuracy.
        
        Args:
            y_true: Ground truth labels
            y_pred: Predicted labels
            num_classes: Number of classes
            
        Returns:
            Array of per-class accuracy values
        """
        per_class_acc = np.zeros(num_classes)
        
        for i in range(num_classes):
            idx = y_true == i
            if np.sum(idx) > 0:
                per_class_acc[i] = accuracy_score(y_true[idx], y_pred[idx])
                
        return per_class_acc
    
    @staticmethod
    def plot_per_class_metrics(
        y_true: np.ndarray, 
        y_pred: np.ndarray, 
        class_names: List[str]
    ) -> plt.Figure:
        """
        Plot per-class precision, recall, and F1 score.
        
        Args:
            y_true: Ground truth labels
            y_pred: Predicted labels
            class_names: List of class names
            
        Returns:
            Matplotlib figure object
        """
        precision = precision_score(y_true, y_pred, average=None, zero_division=0)
        recall = recall_score(y_true, y_pred, average=None, zero_division=0)
        f1 = f1_score(y_true, y_pred, average=None, zero_division=0)
        
        fig, ax = plt.subplots(figsize=(15, 8))
        x = np.arange(len(class_names))
        width = 0.2
        
        ax.bar(x - width, precision, width, label='Precision')
        ax.bar(x, recall, width, label='Recall')
        ax.bar(x + width, f1, width, label='F1-score')
        
        ax.set_ylabel('Score')
        ax.set_title('Per-class Performance Metrics')
        ax.set_xticks(x)
        ax.set_xticklabels(class_names, rotation=45, ha='right')
        ax.legend()
        plt.tight_layout()
        
        return fig
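To make `top_k_accuracy` concrete, here is a small, self-contained variant on hand-written logits. It is a sketch rather than a drop-in replacement: it returns a fraction in [0, 1] to match `accuracy`, whereas the method above returns a percentage. `k` must not exceed the number of classes (13 here), or `topk` raises an error.

```python
import torch


def top_k_accuracy(outputs: torch.Tensor, targets: torch.Tensor, k: int = 5) -> float:
    """Fraction of samples whose true label is among the k highest logits."""
    topk_idx = outputs.topk(k, dim=1).indices               # [batch, k]
    hits = (topk_idx == targets.unsqueeze(1)).any(dim=1)    # [batch]
    return hits.float().mean().item()


logits = torch.tensor([
    [0.1, 0.7, 0.2],   # true 1: highest logit, hit for k=1 and k=2
    [0.5, 0.1, 0.4],   # true 2: second highest, hit for k=2 only
    [0.6, 0.3, 0.1],   # true 2: lowest, miss for k=1 and k=2
])
targets = torch.tensor([1, 2, 2])

print(top_k_accuracy(logits, targets, k=1))  # ~0.333
print(top_k_accuracy(logits, targets, k=2))  # ~0.667
```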


In [7]:
class MetricsTracker:
    """
    Class for tracking and logging metrics during training.
    """
    
    def __init__(self, num_classes: int, class_names: Optional[List[str]] = None):
        """
        Initialize metrics tracker.
        
        Args:
            num_classes: Number of classes
            class_names: List of class names (optional)
        """
        self.num_classes = num_classes
        self.class_names = class_names if class_names else [f"Class {i}" for i in range(num_classes)]
        # History persists across epochs; reset() only clears per-epoch buffers
        self.metrics_history = {
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': [],
            'train_f1': [],
            'val_f1': []
        }
        self.reset()
        
    def reset(self):
        """Clear per-epoch buffers (labels, predictions, logits, losses) for a new epoch."""
        self.true_labels = []
        self.pred_labels = []
        self.logits = []
        self.loss_values = []
        
    def update(
        self, 
        outputs: torch.Tensor, 
        targets: torch.Tensor, 
        loss: Optional[torch.Tensor] = None
    ):
        """
        Update metrics with batch results.
        
        Args:
            outputs: Model outputs (logits)
            targets: Ground truth labels
            loss: Loss value (optional)
        """
        # Convert to numpy for metric calculation
        _, preds = torch.max(outputs, 1)
        self.true_labels.extend(targets.cpu().numpy())
        self.pred_labels.extend(preds.cpu().numpy())
        self.logits.append(outputs.detach().cpu())
        
        if loss is not None:
            self.loss_values.append(loss.item())
            
    def compute_metrics(self) -> Dict[str, float]:
        """
        Compute metrics from accumulated batch results.
        
        Returns:
            Dictionary of metric names and values
        """
        y_true = np.array(self.true_labels)
        y_pred = np.array(self.pred_labels)
        
        metrics = {
            'accuracy': EvaluationMetrics.accuracy(y_true, y_pred),
            'precision': EvaluationMetrics.precision(y_true, y_pred),
            'recall': EvaluationMetrics.recall(y_true, y_pred),
            'f1_score': EvaluationMetrics.f1(y_true, y_pred),
            'loss': np.mean(self.loss_values) if self.loss_values else None
        }
        
        return metrics
    
    def update_history(self, phase: str, metrics: Dict[str, float]):
        """
        Update metrics history with current epoch results.
        
        Args:
            phase: 'train' or 'val'
            metrics: Dictionary of metric names and values
        """
        self.metrics_history[f'{phase}_loss'].append(metrics['loss'])
        self.metrics_history[f'{phase}_acc'].append(metrics['accuracy'])
        self.metrics_history[f'{phase}_f1'].append(metrics['f1_score'])
        
    def log_metrics(self, epoch: int, phase: str, metrics: Dict[str, float]):
        """
        Log metrics for current epoch.
        
        Args:
            epoch: Current epoch number
            phase: 'train' or 'val'
            metrics: Dictionary of metric names and values
        """
        log_str = f"Epoch {epoch} - {phase.capitalize()} - "
        log_str += " | ".join([f"{k}: {v:.4f}" for k, v in metrics.items() if v is not None])
        print(log_str)
        
    def generate_report(self) -> pd.DataFrame:
        """
        Generate a classification report.
        
        Returns:
            DataFrame with precision, recall, and f1-score for each class
        """
        y_true = np.array(self.true_labels)
        y_pred = np.array(self.pred_labels)
        
        return EvaluationMetrics.classification_report_df(y_true, y_pred, self.class_names)
    
    def plot_confusion_matrix(self, normalize: bool = True) -> plt.Figure:
        """
        Plot confusion matrix.
        
        Args:
            normalize: Whether to normalize the confusion matrix
            
        Returns:
            Matplotlib figure object
        """
        y_true = np.array(self.true_labels)
        y_pred = np.array(self.pred_labels)
        
        return EvaluationMetrics.plot_confusion_matrix(
            y_true, y_pred, self.class_names, normalize=normalize
        )
    
    def plot_metrics_history(self) -> plt.Figure:
        """
        Plot metrics history.
        
        Returns:
            Matplotlib figure object
        """
        return EvaluationMetrics.plot_metrics_history(self.metrics_history)

In [8]:
def get_predictions(model, dataloader, device):
    """
    Get model predictions on a dataset.
    
    Args:
        model: PyTorch model
        dataloader: DataLoader for the dataset
        device: Device to run inference on
        
    Returns:
        Tuple of (true labels, predicted labels, raw outputs)
    """
    model.eval()
    all_outputs = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            all_outputs.append(outputs)
            all_labels.append(labels)
    
    all_outputs = torch.cat(all_outputs, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    _, predicted = torch.max(all_outputs, 1)
    
    return all_labels.cpu().numpy(), predicted.cpu().numpy(), all_outputs.cpu()
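`get_predictions` unpacks `(inputs, labels)` and moves the labels to the device, so it targets the validation loader. The test loader from `get_data` yields `(image, img_id)` pairs instead. A minimal sketch of building a Kaggle submission from it follows. The `id`/`class` column names, the `idx_to_class` values and the toy linear model are assumptions for illustration, so match the columns to the competition's sample submission file; in the notebook the mapping could come from `train_loader.dataset.idx_to_class`.

```python
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def predict_test(model, loader, idx_to_class, device="cpu"):
    """Return a DataFrame of (image id, predicted class name) for the test set."""
    model.eval()
    ids, preds = [], []
    for images, img_ids in loader:
        logits = model(images.to(device))
        preds.extend(logits.argmax(dim=1).cpu().tolist())
        # Numeric IDs are collated into a tensor, string IDs into a list
        ids.extend(img_ids.tolist() if torch.is_tensor(img_ids) else list(img_ids))
    # Column names are assumptions; match them to the sample submission
    return pd.DataFrame({"id": ids, "class": [idx_to_class[p] for p in preds]})


# Hypothetical stand-ins: 6 "images" with 4 features and a 3-class linear model
torch.manual_seed(0)
loader = DataLoader(TensorDataset(torch.randn(6, 4), torch.arange(100, 106)), batch_size=4)
model = nn.Linear(4, 3)
mapping = {0: "chapati", 1: "mandazi", 2: "ugali"}  # hypothetical class names

submission = predict_test(model, loader, mapping)
submission.to_csv("submission.csv", index=False)
```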

## <font style="color:green">4. Train and Validation [5 Points]</font>


**Write the methods or classes to be used for training and validation.**
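The trainer below combines autocasting with a `GradScaler` for mixed precision. A minimal sketch of a single AMP-aware step, assuming mixed precision on CUDA and plain FP32 on CPU, is shown here; `train_step` and the toy linear model are hypothetical names for illustration. The non-AMP path must still record gradients: wrapping the forward pass in `torch.no_grad()` instead would make `loss.backward()` raise an error, whereas `torch.autocast(..., enabled=False)` is a no-op.

```python
import torch
import torch.nn as nn


def train_step(model, inputs, targets, criterion, optimizer, scaler, use_amp):
    """Run one optimisation step, with or without automatic mixed precision."""
    optimizer.zero_grad(set_to_none=True)
    # autocast(enabled=False) changes nothing, so gradients are tracked either way
    with torch.autocast(device_type=inputs.device.type, enabled=use_amp):
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    # A disabled GradScaler leaves the loss unscaled and calls optimizer.step() directly
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"  # AMP + GradScaler only on CUDA in this sketch

torch.manual_seed(0)
model = nn.Linear(8, 3).to(device)  # hypothetical stand-in for the CNN
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # create once, reuse every epoch

x = torch.randn(16, 8, device=device)
y = torch.randint(0, 3, (16,), device=device)
losses = [train_step(model, x, y, criterion, optimizer, scaler, use_amp) for _ in range(20)]
print(losses[0], losses[-1])  # the loss should fall on this fixed batch
```

Creating the scaler once matters because it adapts its loss scale over many steps; recreating it every epoch restarts that search.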

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, StepLR
import time
import os
import copy
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from typing import Dict, List, Optional, Tuple, Union, Callable


class Trainer:
    """
    Class for training and validating models for the Kenyan Food 13 classification task.
    """
    
    def __init__(
        self,
        model: nn.Module,
        config: 'TrainingConfiguration',
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        metrics_tracker: 'MetricsTracker',
        criterion: Optional[nn.Module] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        device: Optional[torch.device] = None
    ):
        """
        Initialize the trainer.
        
        Args:
            model: PyTorch model to train
            config: Training configuration
            train_loader: DataLoader for training data
            val_loader: DataLoader for validation data
            metrics_tracker: Metrics tracker for logging
            criterion: Loss function (if None, CrossEntropyLoss is used)
            optimizer: Optimizer (if None, AdamW is used)
            scheduler: Learning rate scheduler (if None, created based on config)
            device: Device to train on (if None, use config.device)
        """
        self.model = model
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.metrics_tracker = metrics_tracker
        self.device = device if device is not None else torch.device(config.device)
        self.model = self.model.to(self.device)
        
        # Set up criterion
        self.criterion = criterion if criterion is not None else nn.CrossEntropyLoss()
        
        # Set up optimizer
        if optimizer is not None:
            self.optimizer = optimizer
        else:
            if config.optimizer.lower() == 'adam':
                self.optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
            elif config.optimizer.lower() == 'adamw':
                self.optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
            elif config.optimizer.lower() == 'sgd':
                self.optimizer = optim.SGD(
                    model.parameters(), 
                    lr=config.learning_rate,
                    momentum=config.momentum,
                    weight_decay=config.weight_decay
                )
            else:
                raise ValueError(f"Unsupported optimizer: {config.optimizer}")
        
        # Set up scheduler
        if scheduler is not None:
            self.scheduler = scheduler
        else:
            if config.lr_scheduler.lower() == 'plateau':
                # `verbose` is deprecated in recent PyTorch; the LR is shown in the progress bar
                self.scheduler = ReduceLROnPlateau(
                    self.optimizer, mode='max', factor=config.lr_gamma,
                    patience=3
                )
            elif config.lr_scheduler.lower() == 'cosine':
                self.scheduler = CosineAnnealingLR(
                    self.optimizer, T_max=config.num_epochs, eta_min=config.lr_min
                )
            elif config.lr_scheduler.lower() == 'step':
                self.scheduler = StepLR(
                    self.optimizer, step_size=config.lr_step_size, gamma=config.lr_gamma
                )
            else:
                self.scheduler = None
        
        # Set up tensorboard writer
        self.writer = SummaryWriter(config.tensorboard_dir)
        
        # Initialize best model metrics
        self.best_val_metric = 0.0
        self.best_model_wts = copy.deepcopy(model.state_dict())
        self.early_stopping_counter = 0
        
    def train_one_epoch(self, epoch: int) -> Dict[str, float]:
        """
        Train for one epoch.
        
        Args:
            epoch: Current epoch number
            
        Returns:
            Dictionary of metric values for the epoch
        """
        self.model.train()
        self.metrics_tracker.reset()
        running_loss = 0.0
        
        # Use tqdm for progress bar
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.config.num_epochs} [Train]")
        
        # Mixed precision only applies on CUDA; create the GradScaler once and reuse it
        # across epochs so its loss scale is not reset every epoch
        device_type = torch.device(self.device).type
        use_amp = self.config.mixed_precision and device_type == 'cuda'
        if not hasattr(self, 'scaler'):
            self.scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
        
        for i, (inputs, targets) in enumerate(pbar):
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            
            # Zero the parameter gradients
            self.optimizer.zero_grad()
            
            # Forward: autocast is a no-op when use_amp is False, so gradients
            # are tracked in both cases (torch.no_grad() would break backward)
            with torch.autocast(device_type=device_type, enabled=use_amp):
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
            
            # Backward + optimize; a disabled scaler simply calls backward() and step()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            
            # Update metrics
            self.metrics_tracker.update(outputs, targets, loss)
            running_loss += loss.item() * inputs.size(0)
            
            # Update progress bar
            if i % self.config.log_interval == 0:
                pbar.set_postfix({
                    'loss': loss.item(),
                    'lr': self.optimizer.param_groups[0]['lr']
                })
        
        # Compute metrics for the entire epoch
        metrics = self.metrics_tracker.compute_metrics()
        
        # Log metrics to tensorboard
        self.writer.add_scalar('Loss/train', metrics['loss'], epoch)
        self.writer.add_scalar('Accuracy/train', metrics['accuracy'], epoch)
        self.writer.add_scalar('F1/train', metrics['f1_score'], epoch)
        
        # Update metrics history
        self.metrics_tracker.update_history('train', metrics)
        
        # Log metrics
        self.metrics_tracker.log_metrics(epoch + 1, 'train', metrics)
        
        return metrics
    
    def validate(self, epoch: int) -> Dict[str, float]:
        """
        Validate the model.
        
        Args:
            epoch: Current epoch number
            
        Returns:
            Dictionary of metric values for validation
        """
        self.model.eval()
        self.metrics_tracker.reset()
        running_loss = 0.0
        
        # Use tqdm for progress bar
        pbar = tqdm(self.val_loader, desc=f"Epoch {epoch+1}/{self.config.num_epochs} [Val]")
        
        with torch.no_grad():
            for inputs, targets in pbar:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)
                
                # Forward
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)
                
                # Update metrics
                self.metrics_tracker.update(outputs, targets, loss)
                running_loss += loss.item() * inputs.size(0)
        
        # Compute metrics for validation
        metrics = self.metrics_tracker.compute_metrics()
        
        # Log metrics to tensorboard
        self.writer.add_scalar('Loss/val', metrics['loss'], epoch)
        self.writer.add_scalar('Accuracy/val', metrics['accuracy'], epoch)
        self.writer.add_scalar('F1/val', metrics['f1_score'], epoch)
        
        # Update metrics history
        self.metrics_tracker.update_history('val', metrics)
        
        # Log metrics
        self.metrics_tracker.log_metrics(epoch + 1, 'val', metrics)
        
        # Update learning rate scheduler if using ReduceLROnPlateau
        if isinstance(self.scheduler, ReduceLROnPlateau):
            self.scheduler.step(metrics['accuracy'])
        
        return metrics
    
    def train(self, monitor_metric: str = 'accuracy') -> nn.Module:
        """
        Train the model for the specified number of epochs.
        
        Args:
            monitor_metric: Metric to monitor for early stopping and best model ('accuracy', 'f1_score')
            
        Returns:
            Best model based on validation metric
        """
        print(f"Starting training for {self.config.num_epochs} epochs...")
        start_time = time.time()
        
        # Train for each epoch
        for epoch in range(self.config.num_epochs):
            # Train for one epoch
            train_metrics = self.train_one_epoch(epoch)
            
            # Validate
            val_metrics = self.validate(epoch)
            
            # Update scheduler if not ReduceLROnPlateau
            if self.scheduler is not None and not isinstance(self.scheduler, ReduceLROnPlateau):
                self.scheduler.step()
            
            # Check if this is the best model
            current_val_metric = val_metrics[monitor_metric]
            if current_val_metric > self.best_val_metric:
                self.best_val_metric = current_val_metric
                self.best_model_wts = copy.deepcopy(self.model.state_dict())
                self.early_stopping_counter = 0
                
                # Always save a checkpoint of the best model so far
                self.save_checkpoint(epoch, monitor_metric=monitor_metric, is_best=True)
                print(f"Saved best model with {monitor_metric}: {current_val_metric:.4f}")
            else:
                self.early_stopping_counter += 1
                if self.early_stopping_counter >= self.config.early_stopping_patience:
                    print(f"Early stopping triggered after {epoch+1} epochs")
                    break
            
            # Save checkpoint if not saving only the best model
            if not self.config.save_best_only:
                self.save_checkpoint(epoch, monitor_metric=monitor_metric)
        
        # Load the best model
        self.model.load_state_dict(self.best_model_wts)
        
        # Calculate elapsed time
        time_elapsed = time.time() - start_time
        print(f"Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
        print(f"Best validation {monitor_metric}: {self.best_val_metric:.4f}")
        
        # Close tensorboard writer
        self.writer.close()
        
        return self.model
    
    def save_checkpoint(self, epoch: int, monitor_metric: str = 'accuracy', is_best: bool = False) -> str:
        """
        Save model checkpoint.
        
        Args:
            epoch: Current epoch number
            monitor_metric: Metric being monitored
            is_best: Whether this is the best model so far
            
        Returns:
            Path to saved checkpoint
        """
        if is_best:
            checkpoint_path = self.config.model_checkpoint_path()
        else:
            checkpoint_path = self.config.model_checkpoint_path(epoch=epoch)
        
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            f'best_{monitor_metric}': self.best_val_metric,
            'config': self.config
        }
        
        torch.save(checkpoint, checkpoint_path)
        return checkpoint_path
    
    def load_checkpoint(self, checkpoint_path: str) -> int:
        """
        Load model checkpoint.
        
        Args:
            checkpoint_path: Path to checkpoint
            
        Returns:
            Epoch number of the checkpoint
        """
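        # The checkpoint also stores the config object, so full unpickling is needed
        # (PyTorch >= 2.6 defaults to weights_only=True). Only load trusted files.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)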
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if self.scheduler and checkpoint['scheduler_state_dict']:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        
        return checkpoint['epoch']
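
The `autocast(enabled=...)` pattern used in `train_one_epoch` can be checked in isolation. The sketch below is a standalone, hypothetical helper (`amp_train_step` is not part of the `Trainer`): it runs one optimization step in which both `torch.autocast` and the gradient scaler become pass-throughs when `use_amp` is `False`, so the same code path works on CPU and GPU. On recent PyTorch releases `torch.cuda.amp.GradScaler` may emit a deprecation warning suggesting `torch.amp.GradScaler("cuda", ...)` instead.

```python
import torch
import torch.nn as nn


def amp_train_step(model, inputs, targets, criterion, optimizer, scaler, use_amp):
    """Run one optimization step; autocast and the scaler are no-ops when use_amp is False."""
    optimizer.zero_grad()
    # enabled=False makes autocast a pass-through context: gradients are still
    # recorded, unlike wrapping the forward pass in torch.no_grad()
    with torch.autocast(device_type=inputs.device.type, enabled=use_amp):
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    # A disabled scaler returns the loss unscaled, and step() just calls optimizer.step()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```

On a CPU-only machine you would create `scaler = torch.cuda.amp.GradScaler(enabled=False)` and pass `use_amp=False`.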

In [10]:
def train_and_validate(
    model: nn.Module, 
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    config: 'TrainingConfiguration',
    class_names: List[str]
) -> Tuple[nn.Module, Dict[str, List[float]], float]:
    """
    Train and validate a model.
    
    Args:
        model: PyTorch model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        config: Training configuration
        class_names: List of class names
        
    Returns:
        Tuple of (best model, metrics history, best validation accuracy)
    """
    # Set random seeds for reproducibility
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.seed)
    
    # Initialize metrics tracker
    metrics_tracker = MetricsTracker(num_classes=config.num_classes, class_names=class_names)
    
    # Initialize trainer
    trainer = Trainer(
        model=model,
        config=config,
        train_loader=train_loader,
        val_loader=val_loader,
        metrics_tracker=metrics_tracker
    )
    
    # Train the model
    best_model = trainer.train(monitor_metric='accuracy')
    
    # Get metrics history
    metrics_history = metrics_tracker.metrics_history
    
    return best_model, metrics_history, trainer.best_val_metric
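
The early-stopping rule inside `Trainer.train` (reset the counter on improvement, stop once `early_stopping_patience` epochs pass without one) can be isolated into a small helper. The `EarlyStopping` class below is a hypothetical, framework-free sketch of the same rule for a "higher is better" metric such as accuracy; starting `best` at `None` is an assumption that avoids depending on how `best_val_metric` is initialized.

```python
from typing import Optional


class EarlyStopping:
    """Track a higher-is-better validation metric and signal when to stop."""

    def __init__(self, patience: int = 5):
        self.patience = patience
        self.best: Optional[float] = None
        self.counter = 0

    def step(self, value: float) -> bool:
        """Record one validation result; return True if it is a new best."""
        if self.best is None or value > self.best:
            self.best = value
            self.counter = 0  # improvement resets the patience counter
            return True
        self.counter += 1
        return False

    @property
    def should_stop(self) -> bool:
        return self.counter >= self.patience
```

With `patience=2` and validation accuracies 0.60, 0.70, 0.68, 0.69, the counter reaches 2 after the fourth epoch and `should_stop` becomes `True`.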

## <font style="color:green">5. Model [5 Points]</font>

**Define your model in this section.**

**You are allowed to use any pre-trained model.**

In [ ]:
import torch
import torch.nn as nn
import torchvision.models as models
from efficientnet_pytorch import EfficientNet  # You might need to install this: pip install efficientnet-pytorch

class FoodClassifier(nn.Module):
    """
    Model for Kenyan Food 13 classification using a pre-trained backbone.
    """
    
    def __init__(
        self, 
        model_name: str = 'resnet50', 
        num_classes: int = 13, 
        pretrained: bool = True,
        dropout_rate: float = 0.2
    ):
        """
        Initialize the model.
        
        Args:
            model_name: Name of the pre-trained model to use as backbone
                        (resnet18, resnet50, efficientnet_b0, etc.)
            num_classes: Number of classes
            pretrained: Whether to use pre-trained weights
            dropout_rate: Dropout rate for the final layer
        """
        super(FoodClassifier, self).__init__()
        self.model_name = model_name.lower()
        
        # Initialize the backbone model
        if 'resnet' in self.model_name:
            if self.model_name == 'resnet18':
                self.backbone = models.resnet18(pretrained=pretrained)
                num_features = self.backbone.fc.in_features
            elif self.model_name == 'resnet34':
                self.backbone = models.resnet34(pretrained=pretrained)
                num_features = self.backbone.fc.in_features
            elif self.model_name == 'resnet50':
                self.backbone = models.resnet50(pretrained=pretrained)
                num_features = self.backbone.fc.in_features
            elif self.model_name == 'resnet101':
                self.backbone = models.resnet101(pretrained=pretrained)
                num_features = self.backbone.fc.in_features
            else:
                raise ValueError(f"Unsupported ResNet model: {model_name}")
            
            # Replace the final fully connected layer
            self.backbone.fc = nn.Identity()
            
            # Create a new classifier head
            self.classifier = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(num_features, num_classes)
            )
            
        elif 'efficientnet' in self.model_name:
            # Format should be 'efficientnet_b0', 'efficientnet_b1', etc.
            version = self.model_name.split('_')[1]
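            # Keep include_top at its default (True): with include_top=False the
            # backbone returns unflattened [B, C, 1, 1] features. The original _fc
            # is replaced with Identity below, so forward() yields [B, C].
            if pretrained:
                self.backbone = EfficientNet.from_pretrained(f'efficientnet-{version}')
            else:
                self.backbone = EfficientNet.from_name(f'efficientnet-{version}')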
            num_features = self.backbone._fc.in_features
            
            # Replace the final classifier
            self.backbone._fc = nn.Identity()
            
            # Create a new classifier head
            self.classifier = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(num_features, num_classes)
            )
            
        elif 'mobilenet' in self.model_name:
            self.backbone = models.mobilenet_v2(pretrained=pretrained)
            num_features = self.backbone.classifier[1].in_features
            
            # Replace the classifier
            self.backbone.classifier = nn.Identity()
            
            # Create a new classifier head
            self.classifier = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(num_features, num_classes)
            )
            
        elif 'densenet' in self.model_name:
            if self.model_name == 'densenet121':
                self.backbone = models.densenet121(pretrained=pretrained)
            elif self.model_name == 'densenet169':
                self.backbone = models.densenet169(pretrained=pretrained)
            elif self.model_name == 'densenet201':
                self.backbone = models.densenet201(pretrained=pretrained)
            else:
                raise ValueError(f"Unsupported DenseNet model: {model_name}")
                
            num_features = self.backbone.classifier.in_features
            
            # Replace the classifier
            self.backbone.classifier = nn.Identity()
            
            # Create a new classifier head
            self.classifier = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(num_features, num_classes)
            )
            
        else:
            raise ValueError(f"Unsupported model: {model_name}")
        
    def forward(self, x):
        """Forward pass through the model."""
        features = self.backbone(x)
        output = self.classifier(features)
        return output

In [ ]:
def get_model(config):
    """
    Create model instance based on configuration.
    
    Args:
        config: Configuration object with model parameters
        
    Returns:
        Initialized model
    """
    model = FoodClassifier(
        model_name=config.model_name,
        num_classes=config.num_classes,
        pretrained=config.pretrained,
        dropout_rate=config.dropout_rate
    )
    
    return model

In [ ]:
class FocalLoss(nn.Module):
    """
    Focal Loss for dealing with class imbalance.
    
    Reference:
    Lin, T. Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017).
    Focal loss for dense object detection.
    """
    
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        """
        Initialize Focal Loss.
        
        Args:
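            alpha: Scalar weight applied to every sample's loss; as a single scalar it
                   uniformly rescales the loss rather than up-weighting particular classes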
            gamma: Focusing parameter
            reduction: 'mean', 'sum', or 'none'
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.cross_entropy = nn.CrossEntropyLoss(reduction='none')
        
    def forward(self, inputs, targets):
        """
        Forward pass.
        
        Args:
            inputs: Model predictions (logits), shape [B, C]
            targets: Ground truth labels, shape [B]
            
        Returns:
            Loss value
        """
        # Standard cross entropy
        ce_loss = self.cross_entropy(inputs, targets)
        
        # Probability of the true class: p_t = exp(-CE), since CE = -log(p_t)
        probs = torch.exp(-ce_loss)
        
        # Apply focal weighting
        focal_loss = self.alpha * (1 - probs)**self.gamma * ce_loss
        
        # Apply reduction
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:  # 'none'
            return focal_loss
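
`FocalLoss.forward` computes $\mathrm{FL}(p_t) = -\alpha\,(1 - p_t)^{\gamma}\log p_t$, recovering $p_t$ from the cross-entropy as $p_t = e^{-\mathrm{CE}}$. A NumPy re-implementation makes two properties easy to check: with $\gamma = 0$ and $\alpha = 1$ it reduces to plain cross-entropy, and with $\gamma > 0$ confidently classified samples contribute less. The function names below are hypothetical, and the sketch mirrors only the `'mean'` reduction.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def focal_loss_np(logits: np.ndarray, targets: np.ndarray, alpha: float = 0.25, gamma: float = 2.0) -> float:
    log_pt = log_softmax(logits)[np.arange(len(targets)), targets]  # log p_t per sample
    ce = -log_pt
    pt = np.exp(log_pt)  # equivalent to torch.exp(-ce_loss) in the class above
    return float(np.mean(alpha * (1.0 - pt) ** gamma * ce))
```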

## <font style="color:green">6. Utils [5 Points]</font>

**Define the methods or classes that have not been covered in the sections above.**

In [ ]:
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import os
import pandas as pd
from typing import Optional, List, Tuple, Dict, Any
from torchvision.utils import make_grid
import torch.nn.functional as F

def set_seed(seed: int = 42):
    """
    Set random seed for reproducibility.
    
    Args:
        seed: Random seed
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def visualize_batch(batch, class_names=None, max_images=16, normalize=True):
    """
    Visualize a batch of images.
    
    Args:
        batch: Batch of images, shape [B, C, H, W]
        class_names: List of class names
        max_images: Maximum number of images to display
        normalize: Whether to normalize the images
        
    Returns:
        Matplotlib figure object
    """
    images, labels = batch
    
    # Select a subset of images
    batch_size = min(images.size(0), max_images)
    images = images[:batch_size]
    labels = labels[:batch_size]
    
    # Create grid of images
    img_grid = make_grid(images, nrow=4, normalize=normalize)
    
    # Convert to numpy for display
    img_grid = img_grid.cpu().numpy().transpose((1, 2, 0))
    
    # Create figure
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(img_grid)
    
    # Add labels if class names are provided
    if class_names is not None:
        title = [class_names[label] for label in labels.cpu().numpy()]
        ax.set_title('\n'.join([', '.join(title[i:i+4]) for i in range(0, len(title), 4)]))
    
    ax.axis('off')
    
    return fig

def get_class_names(data_loader):
    """
    Get class names from the dataset.
    
    Args:
        data_loader: DataLoader with dataset
        
    Returns:
        List of class names
    """
    if hasattr(data_loader.dataset, 'idx_to_class'):
        # Get indices and sort by value
        indices = list(data_loader.dataset.idx_to_class.keys())
        indices.sort()
        
        # Get class names in order
        class_names = [data_loader.dataset.idx_to_class[i] for i in indices]
        return class_names
    else:
        return None

def visualize_model_predictions(model, data_loader, class_names, device, num_images=16):
    """
    Visualize model predictions on a batch of images.
    
    Args:
        model: Trained model
        data_loader: DataLoader with validation or test data
        class_names: List of class names
        device: Device to run inference on
        num_images: Number of images to display
        
    Returns:
        Matplotlib figure object
    """
    model.eval()
    images_so_far = 0
    fig = plt.figure(figsize=(16, 16))
    
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(data_loader):
            if images_so_far >= num_images:
                break
                
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            for j in range(inputs.size(0)):
                if images_so_far >= num_images:
                    break
                    
                images_so_far += 1
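                # 4 columns and enough rows for num_images (a fixed 4x4 grid fails above 16)
                ax = plt.subplot(int(np.ceil(num_images / 4)), 4, images_so_far)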
                ax.axis('off')
                
                # Show image
                img = inputs.cpu().data[j].numpy().transpose((1, 2, 0))
                # Denormalize
                mean = np.array([0.485, 0.456, 0.406])
                std = np.array([0.229, 0.224, 0.225])
                img = std * img + mean
                img = np.clip(img, 0, 1)
                
                ax.imshow(img)
                
                # Show prediction
                title = f'True: {class_names[labels[j]]}\nPred: {class_names[preds[j]]}'
                if preds[j] == labels[j]:
                    title = f'{title}\n(Correct)'
                else:
                    title = f'{title}\n(Wrong)'
                ax.set_title(title, color='green' if preds[j] == labels[j] else 'red')
    
    plt.tight_layout()
    return fig

def create_submission_file(model, test_loader, device, output_path='submission.csv'):
    """
    Create submission file with predicted classes.
    
    Args:
        model: Trained model
        test_loader: DataLoader with test data
        device: Device to run inference on
        output_path: Path to save the submission file
        
    Returns:
        Path to the saved submission file
    """
    model.eval()
    all_ids = []
    all_preds = []
    
    with torch.no_grad():
        for inputs, ids in test_loader:
            inputs = inputs.to(device)
            
            # Forward pass
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            
            # Store predictions
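            # IDs may arrive as a list of strings or as a tensor of numbers
            all_ids.extend(ids.tolist() if torch.is_tensor(ids) else ids)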
            all_preds.extend(preds.cpu().numpy())
    
    # Create DataFrame with predictions
    submission_df = pd.DataFrame({
        'ID': all_ids,
        'CLASS': all_preds
    })
    
    # Convert numerical labels to class names
    if hasattr(test_loader.dataset, 'idx_to_class'):
        submission_df['CLASS'] = submission_df['CLASS'].apply(
            lambda x: test_loader.dataset.idx_to_class.get(x, x)
        )
    
    # Save to CSV
    submission_df.to_csv(output_path, index=False)
    print(f"Submission file saved to {output_path}")
    
    return output_path
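
`create_submission_file` falls back to writing integer indices when the test dataset has no `idx_to_class` mapping. A minimal alternative, assuming you pass the ordered `class_names` list returned by `get_class_names` explicitly, maps indices to names directly. The helper name and the example IDs are hypothetical, and the `ID`/`CLASS` headers simply follow the function above; check them against the competition's sample submission.

```python
import numpy as np
import pandas as pd


def predictions_to_submission(ids, pred_indices, class_names):
    """Build a two-column submission DataFrame, mapping class indices to names."""
    pred_indices = np.asarray(pred_indices, dtype=int)
    return pd.DataFrame({
        "ID": list(ids),
        "CLASS": [class_names[i] for i in pred_indices],
    })


# Illustrative IDs and a truncated class list (the real list has 13 entries)
submission_df = predictions_to_submission(["img_001", "img_002"], [2, 0], ["bhaji", "chapati", "githeri"])
print(submission_df)
```

The result can then be written with `submission_df.to_csv(output_path, index=False)`, as in the function above.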

In [ ]:
def get_optimizer(model, config):
    """
    Create optimizer based on configuration.
    
    Args:
        model: PyTorch model
        config: Configuration object with optimizer parameters
        
    Returns:
        PyTorch optimizer
    """
    if config.optimizer.lower() == 'adam':
        return optim.Adam(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
    elif config.optimizer.lower() == 'adamw':
        return optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
    elif config.optimizer.lower() == 'sgd':
        return optim.SGD(
            model.parameters(),
            lr=config.learning_rate,
            momentum=config.momentum,
            weight_decay=config.weight_decay
        )
    else:
        raise ValueError(f"Unsupported optimizer: {config.optimizer}")

def get_scheduler(optimizer, config):
    """
    Create learning rate scheduler based on configuration.
    
    Args:
        optimizer: PyTorch optimizer
        config: Configuration object with scheduler parameters
        
    Returns:
        PyTorch learning rate scheduler
    """
    if config.lr_scheduler.lower() == 'plateau':
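        # mode='max' because the trainer steps this scheduler with validation accuracy
        # (the `verbose` argument is deprecated in recent PyTorch releases)
        return ReduceLROnPlateau(
            optimizer, mode='max', factor=config.lr_gamma, patience=3
        )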
    elif config.lr_scheduler.lower() == 'cosine':
        return CosineAnnealingLR(
            optimizer, T_max=config.num_epochs, eta_min=config.lr_min
        )
    elif config.lr_scheduler.lower() == 'step':
        return StepLR(
            optimizer, step_size=config.lr_step_size, gamma=config.lr_gamma
        )
    else:
        return None

def get_loss_function(config):
    """
    Create loss function based on configuration.
    
    Args:
        config: Configuration object with loss parameters
        
    Returns:
        PyTorch loss function
    """
    if config.loss_fn.lower() == 'cross_entropy':
        return nn.CrossEntropyLoss()
    elif config.loss_fn.lower() == 'focal':
        return FocalLoss(alpha=config.focal_alpha, gamma=config.focal_gamma)
    else:
        raise ValueError(f"Unsupported loss function: {config.loss_fn}")

def get_transforms(config):
    """
    Create data transforms based on configuration.
    
    Args:
        config: Configuration object with transform parameters
        
    Returns:
        Tuple of (train_transform, val_transform)
    """
    # Build the list conditionally instead of using identity lambdas, which
    # cannot be pickled when DataLoader workers use the "spawn" start method
    train_steps = [
        transforms.Resize((config.resize_size, config.resize_size)),
        transforms.RandomCrop(config.crop_size),
        transforms.RandomHorizontalFlip(),
    ]
    if config.use_augmentation:
        train_steps += [
            transforms.RandomRotation(config.rotation_degrees),
            transforms.ColorJitter(
                brightness=config.color_jitter_factor,
                contrast=config.color_jitter_factor,
                saturation=config.color_jitter_factor,
                hue=config.color_jitter_factor  # hue must lie in [0, 0.5]
            ),
        ]
    train_steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=config.mean, std=config.std)
    ]
    train_transform = transforms.Compose(train_steps)
    
    # Resize + center crop so validation images match the training crop size
    val_transform = transforms.Compose([
        transforms.Resize((config.resize_size, config.resize_size)),
        transforms.CenterCrop(config.crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=config.mean, std=config.std)
    ])
    
    return train_transform, val_transform

def visualize_learning_rate(optimizer, scheduler, num_epochs):
    """
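    Visualize the learning rate schedule by stepping the scheduler without training.
    
    Note: this advances the optimizer's and scheduler's internal state, so pass
    throwaway instances rather than the ones used for training.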
    
    Args:
        optimizer: PyTorch optimizer
        scheduler: PyTorch learning rate scheduler
        num_epochs: Number of epochs
        
    Returns:
        Matplotlib figure object
    """
    lr_history = []
    
    # Save initial learning rate
    lr_history.append(optimizer.param_groups[0]['lr'])
    
    # Step scheduler for each epoch
    for epoch in range(num_epochs):
        if isinstance(scheduler, ReduceLROnPlateau):
            # For ReduceLROnPlateau, simulate a validation accuracy that improves
            # for the first 5 epochs and then plateaus
            if epoch < 5:
                scheduler.step(0.5 + epoch / 10)  # Simulate improving metric
            else:
                scheduler.step(0.5 + 5 / 10)  # Simulate plateauing metric
        else:
            scheduler.step()
        
        # Save learning rate
        lr_history.append(optimizer.param_groups[0]['lr'])
    
    # Create figure
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(range(num_epochs + 1), lr_history)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Learning Rate')
    ax.set_title('Learning Rate Schedule')
    ax.grid(True)
    
    return fig
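
For the `'cosine'` option, `CosineAnnealingLR` follows the closed form $\eta_t = \eta_{\min} + \tfrac{1}{2}(\eta_0 - \eta_{\min})\left(1 + \cos\frac{\pi t}{T_{\max}}\right)$ when stepped once per epoch up to $T_{\max}$, which lets you preview the schedule without building an optimizer (and without the state side effects of `visualize_learning_rate`). The sketch below plugs in the experiment's `learning_rate = 3e-4` and `num_epochs = 20`; `eta_min = 1e-6` is an assumed stand-in for `config.lr_min`, and the function name is hypothetical.

```python
import math
from typing import List


def cosine_annealing_schedule(lr0: float, eta_min: float, t_max: int) -> List[float]:
    """Learning rate at epochs 0..t_max under cosine annealing."""
    return [eta_min + (lr0 - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2
            for t in range(t_max + 1)]


lrs = cosine_annealing_schedule(lr0=3e-4, eta_min=1e-6, t_max=20)
```

The schedule starts at `lr0`, passes the midpoint `(lr0 + eta_min) / 2` halfway through, and reaches `eta_min` at the last epoch.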

In [ ]:
def analyze_model(model, val_loader, device, class_names):
    """
    Analyze model performance on validation set.
    
    Args:
        model: Trained model
        val_loader: Validation data loader
        device: Device to run inference on
        class_names: List of class names
        
    Returns:
        Dictionary with analysis results
    """
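    # Get predictions (as NumPy arrays so element-wise comparisons and masks work)
    true_labels, pred_labels, outputs = get_predictions(model, val_loader, device)
    true_labels = np.asarray(true_labels)
    pred_labels = np.asarray(pred_labels)
    
    # Convert logits to probabilities (accepts a tensor or a NumPy array)
    probs = F.softmax(torch.as_tensor(outputs).float().cpu(), dim=1).numpy()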
    
    # Calculate metrics
    metrics = {
        'accuracy': EvaluationMetrics.accuracy(true_labels, pred_labels),
        'precision': EvaluationMetrics.precision(true_labels, pred_labels),
        'recall': EvaluationMetrics.recall(true_labels, pred_labels),
        'f1': EvaluationMetrics.f1(true_labels, pred_labels),
        'per_class_accuracy': EvaluationMetrics.per_class_accuracy(true_labels, pred_labels, len(class_names))
    }
    
    # Generate confusion matrix
    cm = EvaluationMetrics.confusion_matrix(true_labels, pred_labels)
    
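    # Identify most confused classes (row-normalize; rows for classes with no
    # validation samples stay at 0 instead of becoming NaN)
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm.astype(float), row_sums,
                        out=np.zeros(cm.shape, dtype=float), where=row_sums > 0)
    np.fill_diagonal(cm_norm, 0)  # Set diagonal to 0 to ignore correct predictions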
    
    most_confused = []
    for i in range(len(class_names)):
        # Find top confusions for this class
        confused_indices = np.argsort(cm_norm[i])[::-1][:3]  # Top 3 confusions
        confusions = [(class_names[i], class_names[j], cm[i, j], cm_norm[i, j]) 
                     for j in confused_indices if cm_norm[i, j] > 0]
        most_confused.extend(confusions)
    
    # Sort by confusion percentage
    most_confused.sort(key=lambda x: x[3], reverse=True)
    
    # Format as DataFrame
    if most_confused:
        confused_df = pd.DataFrame(
            most_confused, 
            columns=['True Class', 'Predicted Class', 'Count', 'Percentage']
        )
    else:
        confused_df = pd.DataFrame(
            columns=['True Class', 'Predicted Class', 'Count', 'Percentage']
        )
    
    # Find most confident correct and incorrect predictions
    class_probs = np.max(probs, axis=1)
    correct_mask = pred_labels == true_labels
    
    # Most confident correct predictions
    confident_correct_idx = np.argsort(class_probs * correct_mask)[::-1][:5]
    confident_correct = [{
        'idx': idx,
        'true_label': class_names[true_labels[idx]],
        'pred_label': class_names[pred_labels[idx]],
        'confidence': class_probs[idx]
    } for idx in confident_correct_idx if correct_mask[idx]]
    
    # Most confident incorrect predictions
    confident_incorrect_idx = np.argsort(class_probs * ~correct_mask)[::-1][:5]
    confident_incorrect = [{
        'idx': idx,
        'true_label': class_names[true_labels[idx]],
        'pred_label': class_names[pred_labels[idx]],
        'confidence': class_probs[idx]
    } for idx in confident_incorrect_idx if not correct_mask[idx]]
    
    # Return analysis results
    return {
        'metrics': metrics,
        'confusion_matrix': cm,
        'most_confused': confused_df,
        'confident_correct': confident_correct,
        'confident_incorrect': confident_incorrect
    }

def plot_class_distribution(train_loader, val_loader, class_names):
    """
    Plot class distribution in training and validation sets.
    
    Args:
        train_loader: Training data loader
        val_loader: Validation data loader
        class_names: List of class names
        
    Returns:
        Matplotlib figure object
    """
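    # Get class counts. Note: this iterates over every sample (loading each image);
    # minlength keeps one bar per class even if a class is absent from a split
    train_counts = np.bincount([int(target) for _, target in train_loader.dataset], minlength=len(class_names))
    val_counts = np.bincount([int(target) for _, target in val_loader.dataset], minlength=len(class_names))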
    
    # Create figure
    fig, ax = plt.subplots(figsize=(12, 6))
    x = np.arange(len(class_names))
    width = 0.35
    
    # Plot bars
    ax.bar(x - width/2, train_counts, width, label='Train')
    ax.bar(x + width/2, val_counts, width, label='Validation')
    
    # Add labels and legend
    ax.set_xlabel('Class')
    ax.set_ylabel('Count')
    ax.set_title('Class Distribution')
    ax.set_xticks(x)
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.legend()
    
    plt.tight_layout()
    return fig
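
The confusion analysis in `analyze_model` row-normalizes the confusion matrix so that each off-diagonal cell reads as "fraction of class *i* predicted as class *j*". A simplified, hypothetical variant of that step is sketched below; unlike `analyze_model`, it ranks all off-diagonal cells globally instead of keeping the top 3 per class. The `where=` guard leaves rows for classes with no validation samples at zero instead of producing NaN.

```python
import numpy as np


def top_confusions(cm, class_names, k=3):
    """Return the k largest off-diagonal confusion rates as (true, pred, count, rate)."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    rates = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)
    np.fill_diagonal(rates, 0.0)  # ignore correct predictions
    n = len(class_names)
    pairs = [(class_names[i], class_names[j], int(cm[i, j]), rates[i, j])
             for i in range(n) for j in range(n) if rates[i, j] > 0]
    pairs.sort(key=lambda p: p[3], reverse=True)
    return pairs[:k]


cm = np.array([[8, 2, 0],
               [1, 9, 0],
               [0, 0, 0]])  # third class has no validation samples
result = top_confusions(cm, ["a", "b", "c"])
```

Here 2 of the 10 class-`a` samples are predicted as `b` (rate 0.2), which outranks 1 of 10 class-`b` samples predicted as `a` (rate 0.1).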

## <font style="color:green">7. Experiment [5 Points]</font>

**Choose your optimizer and LR-scheduler and use the above methods and classes to train your model.**

In [ ]:
# Set random seed for reproducibility
set_seed(42)

# Create configuration
config = TrainingConfiguration()

# Adjust configuration for this experiment
config.data_dir = "/kaggle/input/opencv-pytorch-project-2-classification-round-3"  # Kaggle path
config.model_name = "resnet50"  # Using ResNet50 pre-trained model
config.num_epochs = 20
config.batch_size = 32
config.learning_rate = 3e-4
config.optimizer = "adamw"
config.lr_scheduler = "cosine"
config.weight_decay = 1e-4
config.dropout_rate = 0.3
config.use_augmentation = True
config.mixed_precision = True  # Use mixed precision for faster training
config.early_stopping_patience = 5
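
The cell above mutates the configuration field by field, and the ensemble cell further down starts each model from a copy of the base configuration. A small hypothetical helper for that copy-then-override step is sketched here, assuming the configuration is a plain Python object whose settings are attributes; `SimpleNamespace` stands in for `TrainingConfiguration`.

```python
import copy
from types import SimpleNamespace


def apply_overrides(base_config, overrides: dict):
    """Return a deep copy of base_config with selected attributes replaced."""
    config = copy.deepcopy(base_config)  # leave the base configuration untouched
    for key, value in overrides.items():
        # Fail loudly on typos instead of silently creating a new attribute
        if not hasattr(config, key):
            raise AttributeError(f"Unknown config field: {key}")
        setattr(config, key, value)
    return config


# Stand-in for TrainingConfiguration with a few of its fields
base = SimpleNamespace(model_name="resnet50", learning_rate=3e-4, dropout_rate=0.3)
resnet101_cfg = apply_overrides(base, {"model_name": "resnet101", "learning_rate": 1e-4})
```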

In [ ]:
# Load the data
print("Loading data...")
train_loader, val_loader, test_loader = get_data(
    data_dir=config.data_dir,
    batch_size=config.batch_size,
    num_workers=config.num_workers
)

# Get class names from the training dataset
class_names = get_class_names(train_loader)
print(f"Class names: {class_names}")

# Visualize class distribution
print("Visualizing class distribution...")
class_dist_fig = plot_class_distribution(train_loader, val_loader, class_names)
plt.show()

# Create the model
print(f"Creating {config.model_name} model...")
model = get_model(config)

# Create optimizer and scheduler
optimizer = get_optimizer(model, config)
scheduler = get_scheduler(optimizer, config)

# Create loss function
criterion = get_loss_function(config)

# Create metrics tracker
metrics_tracker = MetricsTracker(num_classes=config.num_classes, class_names=class_names)

# Create trainer
trainer = Trainer(
    model=model,
    config=config,
    train_loader=train_loader,
    val_loader=val_loader,
    metrics_tracker=metrics_tracker,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler
)

# Train the model
print(f"Training {config.model_name} for {config.num_epochs} epochs...")
best_model = trainer.train(monitor_metric='accuracy')

# Analyze model performance on validation set
print("Analyzing model performance...")
analysis = analyze_model(best_model, val_loader, torch.device(config.device), class_names)

# Print metrics
print("\nValidation Metrics:")
for metric_name, value in analysis['metrics'].items():
    # Skip collection-valued metrics such as per-class accuracy (printed below)
    if isinstance(value, (list, tuple, dict, np.ndarray)):
        continue
    print(f"{metric_name}: {value:.4f}")

# Print per-class accuracy
print("\nPer-class Accuracy:")
for i, acc in enumerate(analysis['metrics']['per_class_accuracy']):
    print(f"{class_names[i]}: {acc:.4f}")

# Plot confusion matrix
print("\nPlotting confusion matrix...")
cm_fig = EvaluationMetrics.plot_confusion_matrix(
    analysis['confusion_matrix'], 
    class_names=class_names, 
    normalize=True
)
plt.show()

# Print most confused classes
print("\nMost Confused Classes:")
print(analysis['most_confused'].head(10))

# Create submission file
print("\nCreating submission file...")
submission_path = create_submission_file(
    best_model, 
    test_loader, 
    torch.device(config.device),
    output_path='submission.csv'
)

print(f"Submission file saved to {submission_path}")
print("Done!")
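As a rough illustration, the sketch below shows how per-class accuracy and the most-confused class pairs can be derived from a confusion matrix. It is pure Python and assumes rows are true classes and columns are predicted classes (the scikit-learn convention). The helper names, the toy 3-class matrix and the class names are hypothetical, not results from this notebook. They may differ from what `analyze_model` does internally.

```python
def per_class_accuracy(cm):
    """Per-class accuracy (recall): correct predictions divided by the row total.

    Assumes cm[i][j] counts samples whose true class is i and predicted class is j.
    Classes with no samples get 0.0.
    """
    accuracies = []
    for i, row in enumerate(cm):
        total = sum(row)
        accuracies.append(cm[i][i] / total if total else 0.0)
    return accuracies


def most_confused_pairs(cm, class_names, top_k=10):
    """Return up to top_k (true_class, predicted_class, count) triples, largest count first."""
    pairs = [
        (class_names[i], class_names[j], cm[i][j])
        for i in range(len(cm))
        for j in range(len(cm))
        if i != j and cm[i][j] > 0
    ]
    # list.sort is stable, so ties keep their row-major order
    pairs.sort(key=lambda pair: pair[2], reverse=True)
    return pairs[:top_k]


# Hypothetical 3-class confusion matrix (toy counts, not results from this notebook)
toy_cm = [
    [8, 2, 0],
    [1, 7, 2],
    [0, 3, 7],
]
toy_names = ["chapati", "githeri", "ugali"]
print(per_class_accuracy(toy_cm))
print(most_confused_pairs(toy_cm, toy_names, top_k=2))
```

Pairs near the top of this list are a natural starting point when inspecting misclassified validation images.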

In [ ]:
import copy  # needed for copy.deepcopy in train_ensemble_models

# Ensemble Model Training (for better results)

# Define configurations for ensemble models
ensemble_configs = [
    # ResNet50 with different learning rate and augmentations
    {
        'model_name': 'resnet50',
        'learning_rate': 2e-4,
        'optimizer': 'adamw',
        'weight_decay': 1e-4,
        'dropout_rate': 0.3,
        'use_augmentation': True,
        'rotation_degrees': 20,
        'color_jitter_factor': 0.15
    },
    # ResNet101 with different learning rate and augmentations
    {
        'model_name': 'resnet101',
        'learning_rate': 1e-4,
        'optimizer': 'adamw',
        'weight_decay': 2e-4,
        'dropout_rate': 0.2,
        'use_augmentation': True,
        'rotation_degrees': 15,
        'color_jitter_factor': 0.1
    },
    # EfficientNet-B0 with different learning rate and augmentations
    {
        'model_name': 'efficientnet_b0',
        'learning_rate': 3e-4,
        'optimizer': 'adam',
        'weight_decay': 1e-5,
        'dropout_rate': 0.25,
        'use_augmentation': True,
        'rotation_degrees': 10,
        'color_jitter_factor': 0.08
    }
]

# Function to train ensemble models
def train_ensemble_models(base_config, model_configs, train_loader, val_loader, class_names):
    """
    Train multiple models with different configurations for ensemble.
    
    Args:
        base_config: Base configuration
        model_configs: List of model-specific configurations
        train_loader: Training data loader
        val_loader: Validation data loader
        class_names: List of class names
        
    Returns:
        List of trained models
    """
    models = []
    
    for i, model_config in enumerate(model_configs):
        print(f"\n=== Training Ensemble Model {i+1}/{len(model_configs)} ===")
        
        # Create a copy of the base configuration
        config = copy.deepcopy(base_config)
        
        # Update with model-specific configuration
        for key, value in model_config.items():
            setattr(config, key, value)
        
        # Create the model
        print(f"Creating {config.model_name} model...")
        model = get_model(config)
        
        # Create optimizer and scheduler
        optimizer = get_optimizer(model, config)
        scheduler = get_scheduler(optimizer, config)
        
        # Create loss function
        criterion = get_loss_function(config)
        
        # Create metrics tracker
        metrics_tracker = MetricsTracker(num_classes=config.num_classes, class_names=class_names)
        
        # Create trainer
        trainer = Trainer(
            model=model,
            config=config,
            train_loader=train_loader,
            val_loader=val_loader,
            metrics_tracker=metrics_tracker,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler
        )
        
        # Train the model
        print(f"Training {config.model_name} for {config.num_epochs} epochs...")
        best_model = trainer.train(monitor_metric='accuracy')
        
        # Add to ensemble
        models.append(best_model)
        print(f"Best validation accuracy: {trainer.best_val_metric:.4f}")
    
    return models

# Function to get ensemble predictions
def get_ensemble_predictions(models, dataloader, device):
    """
    Get ensemble predictions by averaging outputs from multiple models.
    
    Args:
        models: List of trained models
        dataloader: DataLoader for the dataset
        device: Device to run inference on
        
    Returns:
        If test set (no labels): (ids, predicted_labels)
        If validation set (with labels): (true_labels, predicted_labels)
    """
    # Check if this is the test set (no labels) or validation set (with labels)
    batch = next(iter(dataloader))
    is_test = len(batch) == 2 and not isinstance(batch[1], torch.Tensor)
    
    # Collect predictions from all models
    all_outputs = []
    
    for model in models:
        model.eval()
        model_outputs = []
        all_ids = []
        all_labels = []
        
        with torch.no_grad():
            for data in dataloader:
                if is_test:
                    inputs, ids = data
                    all_ids.extend(ids)
                else:
                    inputs, labels = data
                    all_labels.extend(labels.cpu().numpy())
                
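                inputs = inputs.to(device)
                outputs = model(inputs)
                # Convert logits to probabilities and move them to the CPU to free GPU memory
                model_outputs.append(torch.softmax(outputs, dim=1).cpu())
        
        # Concatenate all probabilities for this model
        model_outputs = torch.cat(model_outputs, dim=0)
        all_outputs.append(model_outputs)
    
    # Average class probabilities from all models (soft voting)
    ensemble_outputs = torch.mean(torch.stack(all_outputs), dim=0)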
    
    # Get predicted labels
    _, predicted = torch.max(ensemble_outputs, 1)
    predicted = predicted.cpu().numpy()
    
    if is_test:
        return all_ids, predicted
    else:
        return np.array(all_labels), predicted

# Function to create ensemble submission file
def create_ensemble_submission(models, test_loader, device, output_path='ensemble_submission.csv'):
    """
    Create submission file with ensemble predictions.
    
    Args:
        models: List of trained models
        test_loader: DataLoader with test data
        device: Device to run inference on
        output_path: Path to save the submission file
        
    Returns:
        Path to the saved submission file
    """
    # Get ensemble predictions
    ids, predictions = get_ensemble_predictions(models, test_loader, device)
    
    # Create DataFrame with predictions
    submission_df = pd.DataFrame({
        'ID': ids,
        'CLASS': predictions
    })
    
    # Convert numerical labels to class names
    if hasattr(test_loader.dataset, 'idx_to_class'):
        submission_df['CLASS'] = submission_df['CLASS'].apply(
            lambda x: test_loader.dataset.idx_to_class.get(x, x)
        )
    
    # Save to CSV
    submission_df.to_csv(output_path, index=False)
    print(f"Ensemble submission file saved to {output_path}")
    
    return output_path

# To train the ensemble, remove the triple quotes around the block below (this will take a while)
"""
# Train ensemble models
print("Training ensemble models...")
ensemble_models = train_ensemble_models(
    config, ensemble_configs, train_loader, val_loader, class_names
)

# Evaluate ensemble on validation set
print("\nEvaluating ensemble on validation set...")
ensemble_val_labels, ensemble_val_preds = get_ensemble_predictions(
    ensemble_models, val_loader, torch.device(config.device)
)

# Calculate ensemble metrics
ensemble_accuracy = EvaluationMetrics.accuracy(ensemble_val_labels, ensemble_val_preds)
ensemble_f1 = EvaluationMetrics.f1(ensemble_val_labels, ensemble_val_preds)

print(f"Ensemble Validation Accuracy: {ensemble_accuracy:.4f}")
print(f"Ensemble Validation F1 Score: {ensemble_f1:.4f}")

# Create ensemble submission file
print("\nCreating ensemble submission file...")
ensemble_submission_path = create_ensemble_submission(
    ensemble_models, 
    test_loader, 
    torch.device(config.device),
    output_path='ensemble_submission.csv'
)

print(f"Ensemble submission file saved to {ensemble_submission_path}")
"""
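An ensemble can combine models by averaging raw logits or by averaging softmax probabilities, and the two can pick different classes. The toy pure-Python sketch below shows one such case. The helper names and logits are hypothetical, not real model outputs. Under logit averaging, a single model with very large logits outvotes two others. Under probability averaging, it does not.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mean_logit_vote(per_model_logits):
    """Average raw logits across models for one sample, then take the argmax."""
    n_models = len(per_model_logits)
    n_classes = len(per_model_logits[0])
    avg = [sum(logits[c] for logits in per_model_logits) / n_models for c in range(n_classes)]
    return max(range(n_classes), key=avg.__getitem__)


def soft_vote(per_model_logits):
    """Average softmax probabilities across models for one sample, then take the argmax.

    Returns (predicted_class_index, averaged_probabilities).
    """
    probs = [softmax(logits) for logits in per_model_logits]
    n_classes = len(probs[0])
    avg = [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]
    return max(range(n_classes), key=avg.__getitem__), avg


# Hypothetical logits from three models for a single image (3 classes)
toy_logits = [
    [20.0, 0.0, 0.0],  # model A: extremely confident in class 0
    [0.0, 2.0, 0.0],   # model B: moderately prefers class 1
    [0.0, 2.0, 0.0],   # model C: moderately prefers class 1
]
print("mean of logits ->", mean_logit_vote(toy_logits))
print("mean of probabilities ->", soft_vote(toy_logits)[0])
```

Here the logit average picks class 0 and the probability average picks class 1. Averaging probabilities keeps every model's vote on the same [0, 1] scale, which is why it is a common choice for soft-voting ensembles. Weighting each model by its validation accuracy is one possible refinement.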

## <font style="color:green">8. TensorBoard Log Link [5 Points]</font>

**Share your TensorBoard scalars log link here. You can also share your GitHub link (optional) if you have pushed this project to GitHub.**


Note: Because tensorboard.dev has been shut down, the submission requirements for this project have changed. Instead of sharing a tensorboard.dev link, upload the TensorBoard event files generated during training directly to the lab. Alternatively, include a screenshot of your TensorBoard output in this Jupyter notebook. Either way, your training curves stay documented and available for evaluation.

You are also welcome (and encouraged) to use alternative logging services such as Weights & Biases (wandb) or Comet. In that case, make your project logs publicly accessible and share the link.
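If you submit event files rather than a link, you first need to collect them from the log directory. Below is a minimal standard-library sketch. It assumes the logs live under `./runs`, the `--logdir` passed to `%tensorboard` in this section. It also assumes the event files follow TensorBoard's usual `events.out.tfevents.*` naming. The function name and the archive name are hypothetical.

```python
import zipfile
from pathlib import Path


def pack_tensorboard_logs(log_dir="./runs", archive_path="tensorboard_logs.zip"):
    """Zip every TensorBoard event file under log_dir, keeping the run sub-folders.

    Returns the archived paths (relative to log_dir, using forward slashes).
    Raises FileNotFoundError if no event files are found.
    """
    log_dir = Path(log_dir)
    event_files = sorted(log_dir.rglob("events.out.tfevents.*"))
    if not event_files:
        raise FileNotFoundError(f"No TensorBoard event files found under {log_dir}")

    archived = []
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for path in event_files:
            # Keep the path relative to log_dir so each run stays in its own folder
            arcname = path.relative_to(log_dir).as_posix()
            zf.write(path, arcname)
            archived.append(arcname)
    return archived
```

Calling `pack_tensorboard_logs()` with the defaults writes `tensorboard_logs.zip` in the working directory, ready to upload. Because relative paths are preserved, TensorBoard still shows each run separately after the archive is extracted.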

In [ ]:
# TensorBoard Integration

# TensorBoard scalars log link
# Note: Since tensorboard.dev is no longer available, we'll save logs locally
# and upload screenshots or event files instead.

# Load tensorboard extension
%load_ext tensorboard

# Launch tensorboard
%tensorboard --logdir=./runs

# Alternative logging services:
# 1. Weights & Biases (wandb): https://wandb.ai/
# 2. Comet.ml: https://www.comet.ml/

# Example code for wandb integration:
"""
import wandb

# Initialize wandb project
wandb.init(project="kenyan-food-classification", name="resnet50-experiment")

# Log hyperparameters
wandb.config.update({
    "model": config.model_name,
    "batch_size": config.batch_size,
    "learning_rate": config.learning_rate,
    "optimizer": config.optimizer,
    "num_epochs": config.num_epochs
})

# During training, log metrics
def log_metrics_to_wandb(metrics, epoch, phase):
    """Log metrics to wandb"""
    wandb.log({
        f"{phase}_loss": metrics['loss'],
        f"{phase}_accuracy": metrics['accuracy'],
        f"{phase}_f1_score": metrics['f1_score'],
        "epoch": epoch
    })

# After training, log model
wandb.save(config.model_checkpoint_path())

# Close wandb run
wandb.finish()
"""

# You can also include a screenshot of your TensorBoard in a Markdown cell, e.g.:
# ![TensorBoard Screenshot](tensorboard_screenshot.png)

# In the final submission, provide your TensorBoard event files or a link to your
# wandb/comet dashboard for visualization of training metrics.

## <font style="color:green">9. Kaggle Profile Link [50 Points]</font>

**Share your Kaggle profile link here to score points for the competition.**

**For full points, you need a minimum accuracy of `75%` on the test data. If accuracy is less than `70%`, you get no points for this section.**


**Submit `submission.csv` (predictions for the images in `test.csv`) in the `Submit Predictions` tab on Kaggle to be evaluated for this section.**
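Before uploading, it can help to sanity-check the file locally. Below is a minimal standard-library sketch. It assumes the `ID`/`CLASS` header written by the submission code in this notebook. The function name is hypothetical. `test_ids` and `valid_classes` are assumed to come from `test.csv` and the training class names, respectively.

```python
import csv


def check_submission(submission_path, test_ids, valid_classes, id_col="ID", class_col="CLASS"):
    """Return a list of problems found in a submission CSV (an empty list means it looks OK)."""
    problems = []
    with open(submission_path, newline="") as f:
        reader = csv.DictReader(f)
        header = reader.fieldnames or []
        if header != [id_col, class_col]:
            problems.append(f"expected header {[id_col, class_col]}, got {header}")
            return problems
        rows = list(reader)

    ids = [row[id_col] for row in rows]
    if len(ids) != len(set(ids)):
        problems.append("duplicate IDs")

    # CSV values are strings, so compare against stringified test IDs
    expected = {str(i) for i in test_ids}
    missing = expected - set(ids)
    extra = set(ids) - expected
    if missing:
        problems.append(f"{len(missing)} test IDs missing")
    if extra:
        problems.append(f"{len(extra)} IDs not in test.csv")

    unknown = sorted({row[class_col] for row in rows} - set(valid_classes))
    if unknown:
        problems.append(f"unknown classes: {unknown}")
    return problems
```

An empty list means no problems were found. If a numeric `CLASS` value is reported as an unknown class, the index-to-name mapping was probably not applied before saving. The exact name of the ID column in `test.csv` is an assumption here, so check the file before passing its IDs in.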

In [ ]:
# Kaggle Profile Link

# Replace the following with your actual Kaggle profile link
kaggle_profile_link = "https://www.kaggle.com/yourusername"  # Replace with your Kaggle profile link

# Scoring for this section (50 points):
# - Full points require a test accuracy of at least 75%
# - A test accuracy below 70% earns no points for this section

# For submission:
# 1. Submit 'submission.csv' (prediction for images in 'test.csv') in the 'Submit Predictions' tab in Kaggle
# 2. The file format should have two columns: 'ID' and 'CLASS'
# 3. The 'ID' column contains the image IDs from test.csv
# 4. The 'CLASS' column contains the predicted class for each image

# Submission Tips:
# - The ResNet50 baseline in this notebook may reach >75% test accuracy when trained well;
#   use its validation accuracy as a rough guide before submitting
# - If you want to improve results further, try:
#   1. Ensemble methods (multiple models with different architectures)
#   2. Advanced data augmentation techniques
#   3. Test-time augmentation (TTA)
#   4. Hyperparameter tuning

# Competition Strategy:
# 1. First establish a solid baseline (ResNet50 with proper training)
# 2. Then try to improve with ensemble models
# 3. Finally, fine-tune hyperparameters for best performance

print(f"My Kaggle profile link: {kaggle_profile_link}")
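Test-time augmentation (TTA), listed in the tips above, averages predictions over several transformed views of each test image. Below is a minimal pure-Python sketch that uses a horizontal flip as the only extra view. The toy model, the toy image and the helper names are hypothetical stand-ins for a real network and dataset.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def hflip(image):
    """Mirror a 2-D image (a list of pixel rows) left to right."""
    return [list(reversed(row)) for row in image]


def identity(image):
    """Return the original view unchanged."""
    return image


def predict_with_tta(model, image, view_fns=(identity, hflip)):
    """Average class probabilities over several views of one image.

    model: callable mapping an image to a list of logits.
    Returns (predicted_class_index, averaged_probabilities).
    """
    probs = [softmax(model(view_fn(image))) for view_fn in view_fns]
    n_classes = len(probs[0])
    avg = [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]
    return max(range(n_classes), key=avg.__getitem__), avg


# Hypothetical toy "model": class 0 scores total brightness, class 1 scores the left column only
def toy_model(image):
    return [sum(sum(row) for row in image), sum(row[0] for row in image)]


toy_image = [[2, 0], [2, 0]]
print(predict_with_tta(toy_model, toy_image))
```

With PyTorch tensors in NCHW layout, the flipped view of a batch could be produced with `torch.flip(inputs, dims=[3])`, and the two softmax outputs averaged in the same way. Only use views that leave the label unchanged. A horizontal flip is generally safe for food photos.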