# E‑ConvNeXt Image Classification Pipeline

This notebook presents a unified image classification workflow built around the E‑ConvNeXt architecture. It brings together several pieces of functionality from separate notebooks and scripts into one cohesive pipeline:

* **Automated data splitting:** given a root directory of images, the notebook can automatically split the data into training, validation and test sets (70/15/15 split by default).
* **Exploratory Data Analysis (EDA):** before training, the notebook scans the dataset and visualises key characteristics such as class distribution, image dimensions, aspect ratios and brightness. Example images are also shown.
* **E‑ConvNeXt model definition and training:** the core model, training loop and evaluation functions are adapted from the *E‑ConvNeXt Notebook Guide‑FULLPIPELINE* to provide a configurable classification backbone. Training history is plotted using a consistent blue‑gradient colour palette.

To use this notebook with your own data, set `DATA_DIR` in the `Config` class (configuration cell) to point to your image folder; the final cell reads it into `dataset_root`. The folder can either contain class sub‑directories directly or be organised into `train`, `val` and `test` folders. A folder named `validation` is not recognised by the splitting step. Once the dataset is split and loaded, the notebook runs EDA, trains the E‑ConvNeXt model and reports evaluation metrics and visualisations.
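The splitting step may rewrite folders on disk, so it can help to check first which layout a dataset has. The sketch below is a hypothetical read-only helper, `describe_layout`, that mirrors the folder checks made by `split_dataset` but never modifies anything. It is not part of the original pipeline.

```python
import os

def describe_layout(root: str) -> str:
    """Report which split_dataset case applies to `root` without touching any files.

    Hypothetical helper: it repeats the folder checks in split_dataset but only reads the tree.
    """
    present = {name for name in ('train', 'val', 'test')
               if os.path.isdir(os.path.join(root, name))}
    if not present:
        return 'class folders only: a new split will be created'
    if present == {'train', 'val'}:
        return 'train + val: contents will be combined and re-split'
    if present == {'train', 'test'}:
        return 'train + test: left as-is, no validation set'
    if present == {'train', 'val', 'test'}:
        return 'train + val + test: nothing to do'
    return 'unrecognised layout: split_dataset will raise an error'
```

Because only `val` is checked, a `train`/`validation`/`test` tree is reported as "train + test".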

In [None]:
# Install required libraries (uncomment if they are not already installed)
# !pip install torch torchvision numpy pandas matplotlib seaborn scikit-learn pillow

import os
import shutil
import random
import math
from dataclasses import dataclass

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

from torchvision import datasets, transforms

from sklearn.metrics import classification_report, confusion_matrix

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Use a blue gradient palette for all plots
sns.set_palette('Blues')

## Automated dataset splitting

The following function inspects a dataset directory and organises it into `train`, `val` and `test` sub‑folders on disk.

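* If the root contains only class sub‑folders, a 70/15/15 split is created by copying the images into new `train`, `val` and `test` folders; the original class folders are left in place.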
* If only `train` and `val` directories exist, their contents are combined and re‑split into `train`, `val` and `test`.
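* If only `train` and `test` directories exist, they are left untouched and no validation split is created. Because `load_data` (below) expects a `val` folder, you will need to create one yourself in this case.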
* If `train`, `val` and `test` are present, no splitting is performed.

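You can adjust the `train_ratio`, `val_ratio` and `test_ratio` arguments to change the default 70/15/15 proportions. The three ratios should sum to 1: the training and validation counts for each class are rounded down, and the test split receives all remaining images.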
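Because the counts are rounded down per class, small classes can end up with uneven splits. A minimal sketch of the same arithmetic that `split_dataset` applies to each class (`split_counts` is a hypothetical helper name):

```python
def split_counts(n_total: int, train_ratio: float = 0.70, val_ratio: float = 0.15) -> tuple:
    # Same arithmetic as split_dataset: train and val are rounded down with int(),
    # and the test split takes whatever is left over.
    n_train = int(n_total * train_ratio)
    n_val = int(n_total * val_ratio)
    n_test = n_total - n_train - n_val
    return n_train, n_val, n_test

for n in (3, 7, 100):
    print(n, split_counts(n))
```

A class with 7 images is split 4/1/2, and a class with only 3 images gets no validation image at all (2/0/1), so very small classes are worth checking after splitting.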

In [None]:
def split_dataset(base_dir: str, train_ratio: float = 0.70, val_ratio: float = 0.15, test_ratio: float = 0.15, seed: int = 42) -> None:
    """
    Split a dataset of images on disk. The function operates in four modes based on folder structure:

    1. When the root contains only class sub‑folders (no `train`/`val`/`test`), images are copied into a new `train`/`val`/`test` split. The original class folders are kept.
    2. When only `train` and `val` folders exist, their contents are combined and re‑split into `train`/`val`/`test`.
    3. When only `train` and `test` exist, they are left untouched and no validation set is created.
    4. When `train`, `val` and `test` exist, the directory structure is assumed correct and nothing is changed.

    Parameters
    ----------
    base_dir : str
        Path to the dataset root.
    train_ratio : float
        Proportion of samples to allocate to the training set when performing a new split.
    val_ratio : float
        Proportion of samples to allocate to the validation set when performing a new split.
    test_ratio : float
        Proportion of samples to allocate to the test set. The three ratios must sum to 1; because the
        train and validation counts are rounded down, the test set receives all remaining samples.
    seed : int
        Random seed for reproducibility.
    """
    if not math.isclose(train_ratio + val_ratio + test_ratio, 1.0, abs_tol=1e-6):
        raise ValueError('train_ratio, val_ratio and test_ratio must sum to 1.')
    rng = random.Random(seed)  # local RNG: does not disturb the global random state
    root = os.path.abspath(base_dir)
    train_dir = os.path.join(root, 'train')
    val_dir = os.path.join(root, 'val')
    test_dir = os.path.join(root, 'test')

    def _list_dirs(path: str) -> list:
        # Sorted, non-hidden sub-directories (skips e.g. .ipynb_checkpoints)
        return sorted(d for d in os.listdir(path)
                      if os.path.isdir(os.path.join(path, d)) and not d.startswith('.'))

    def _list_files(path: str) -> list:
        # Sorted so the shuffle below is reproducible regardless of os.listdir order
        return sorted(f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)))

    def _ensure_empty(dir_path: str) -> None:
        os.makedirs(dir_path, exist_ok=True)
        for item in os.listdir(dir_path):
            item_path = os.path.join(dir_path, item)
            if os.path.isdir(item_path):
                shutil.rmtree(item_path)
            else:
                os.remove(item_path)

    def _write_splits(files_by_class: dict) -> None:
        # files_by_class maps a class name to a list of source file paths
        for cls, files in files_by_class.items():
            rng.shuffle(files)
            n_train = int(len(files) * train_ratio)
            n_val = int(len(files) * val_ratio)
            splits = {
                'train': files[:n_train],
                'val': files[n_train:n_train + n_val],
                'test': files[n_train + n_val:],  # remainder goes to test
            }
            for split_name, filepaths in splits.items():
                target_dir = os.path.join(root, split_name, cls)
                os.makedirs(target_dir, exist_ok=True)
                for src_path in filepaths:
                    shutil.copy2(src_path, os.path.join(target_dir, os.path.basename(src_path)))

    has_train = os.path.isdir(train_dir)
    has_val = os.path.isdir(val_dir)
    has_test = os.path.isdir(test_dir)

    # Case 1: single folder of class sub‑directories
    if not has_train and not has_val and not has_test:
        print(f'No train/val/test directories detected. Creating a '
              f'{train_ratio:.0%}/{val_ratio:.0%}/{test_ratio:.0%} split.')
        combined = {}
        for cls in _list_dirs(root):
            src_dir = os.path.join(root, cls)
            combined[cls] = [os.path.join(src_dir, f) for f in _list_files(src_dir)]
        for split_dir in (train_dir, val_dir, test_dir):
            _ensure_empty(split_dir)
        _write_splits(combined)
    # Case 2: train and val only → combine and re‑split
    elif has_train and has_val and not has_test:
        print('Found train and val directories only. Combining and re‑splitting into train/val/test.')
        # Move the original folders aside first: emptying train/ and val/ before copying
        # would otherwise delete the very images that are being re-split.
        staging = os.path.join(root, '_resplit_staging')
        os.makedirs(staging)  # fails loudly if a previous run left this folder behind
        combined = {}
        for name in ('train', 'val'):
            staged = os.path.join(staging, name)
            shutil.move(os.path.join(root, name), staged)
            for cls in _list_dirs(staged):
                cls_dir = os.path.join(staged, cls)
                # Identical file names in train/<cls> and val/<cls> will overwrite each other
                combined.setdefault(cls, []).extend(os.path.join(cls_dir, f) for f in _list_files(cls_dir))
        for split_dir in (train_dir, val_dir, test_dir):
            _ensure_empty(split_dir)
        _write_splits(combined)
        shutil.rmtree(staging)
    # Case 3: train and test only → do nothing
    elif has_train and not has_val and has_test:
        print('Found train and test directories only. Leaving as‑is. No validation set will be created.')
    # Case 4: train, val and test exist → do nothing
    elif has_train and has_val and has_test:
        print('Found train, val, and test directories. No splitting needed.')
    else:
        raise RuntimeError(f'Unrecognised layout under {root}. Expected combinations of train/val/test.')


## Configuration and data loading

The configuration below lets you customise the dataset location, image resolution, model variant, training hyperparameters and augmentation settings. After splitting, the `load_data` function wraps the `train` and `val` folders in `torchvision.datasets.ImageFolder` datasets (class indices follow the sorted order of the class folder names) and builds PyTorch dataloaders for training and validation.
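Because `Config` is a dataclass, a variant for a quick smoke test can be derived with the standard-library `dataclasses.replace` instead of editing the class. The sketch uses a hypothetical `DemoConfig` with a subset of the fields so that it runs on its own; the same call works on `config`.

```python
from dataclasses import dataclass, replace

@dataclass
class DemoConfig:
    # Hypothetical stand-in holding a few of the Config fields
    DATA_DIR: str = 'path/to/your/dataset'
    IMAGE_SIZE: int = 224
    BATCH_SIZE: int = 32
    EPOCHS: int = 10

base = DemoConfig()
quick = replace(base, BATCH_SIZE=8, EPOCHS=2)  # returns a new instance; `base` is unchanged
print(base.EPOCHS, quick.EPOCHS)
```

Keeping the original object untouched makes it easy to switch back to the full settings after a short test run.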

In [None]:
# Editable configuration
@dataclass
class Config:
    DATA_DIR: str = 'path/to/your/dataset'  # change this to your dataset root
    NUM_CLASSES: int = 10                   # overwritten from the dataset in the final cell
    MODEL_VARIANT: str = 'tiny'            # choose from 'mini', 'tiny', 'small'
    IMAGE_SIZE: int = 224
    BATCH_SIZE: int = 32
    EPOCHS: int = 10
    LEARNING_RATE: float = 2.5e-4
    WEIGHT_DECAY: float = 0.05
    # augmentation settings
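    RAND_AUGMENT_N: int = 2           # number of RandAugment ops per image (torchvision's default is 2)
    RAND_AUGMENT_M: float = 0.5       # scaled by 10 in get_transforms, i.e. magnitude 5 on a 0–30 scale
    MIXUP_ALPHA: float = None         # placeholder: Mixup is not applied by train_one_epoch
    CUTMIX_ALPHA: float = None        # placeholder: CutMix is not applied by train_one_epoch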

config = Config()

def get_transforms(cfg: Config):
    train_transforms = transforms.Compose([
        transforms.Resize(cfg.IMAGE_SIZE + 32),
        transforms.RandomCrop(cfg.IMAGE_SIZE),
        transforms.RandAugment(num_ops=cfg.RAND_AUGMENT_N, magnitude=int(cfg.RAND_AUGMENT_M * 10)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    val_transforms = transforms.Compose([
        transforms.Resize(cfg.IMAGE_SIZE + 32),
        transforms.CenterCrop(cfg.IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return train_transforms, val_transforms

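def load_data(cfg: Config):
    train_transforms, val_transforms = get_transforms(cfg)
    train_dir = os.path.join(cfg.DATA_DIR, 'train')
    val_dir = os.path.join(cfg.DATA_DIR, 'val')
    if not os.path.isdir(val_dir):
        raise FileNotFoundError(f'No validation folder at {val_dir}. Create a val/ split before training.')
    train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
    val_dataset = datasets.ImageFolder(val_dir, transform=val_transforms)
    # ImageFolder assigns indices by sorted folder name, so both splits must contain the same classes
    if train_dataset.classes != val_dataset.classes:
        raise ValueError('train/ and val/ contain different class folders; label indices would not match.')
    pin = torch.cuda.is_available()  # pinned memory only speeds up host-to-GPU copies
    train_loader = DataLoader(train_dataset, batch_size=cfg.BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=pin)
    val_loader = DataLoader(val_dataset, batch_size=cfg.BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=pin)
    return train_loader, val_loader, train_dataset.classes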


## Exploratory Data Analysis (EDA)

The following helpers perform a lightweight scan of an image dataset and produce visual summaries. For each image they record width, height, aspect ratio, file size, brightness and colourfulness. The plotting functions then show the class distribution, histograms of width, height, aspect ratio, file size and brightness, a width‑versus‑height scatter plot and a grid of random samples; colourfulness is kept in the returned DataFrame but not plotted. All charts use the same blue‑gradient palette for consistency.
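For reference, the two per-image statistics are computed as follows, written to match `_brightness` and `_colorfulness_fast` below. PIL's `'L'` conversion uses the ITU-R 601-2 luma weights; $R$, $G$, $B$ are per-pixel channel values, and $\mu$, $\sigma$ are the mean and (population) standard deviation over all pixels of one image.

$$
L = \tfrac{299}{1000}R + \tfrac{587}{1000}G + \tfrac{114}{1000}B, \qquad \text{brightness} = \mu_L
$$

$$
rg = |R - G|, \qquad yb = \left|\tfrac{1}{2}(R + G) - B\right|
$$

$$
C = \sqrt{\sigma_{rg}^2 + \sigma_{yb}^2} + 0.3\,\sqrt{\mu_{rg}^2 + \mu_{yb}^2}
$$

As a worked example, a uniform grey image scores $C = 0$, while a uniform pure-red image ($R = 255$, $G = B = 0$) has $rg = 255$, $yb = 127.5$ and zero variance, giving $C = 0.3\sqrt{255^2 + 127.5^2} \approx 85.5$.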

In [None]:
from PIL import Image

def _colorfulness_fast(im: Image.Image) -> float:
    arr = np.asarray(im.convert('RGB'), dtype=np.float32)
    r, g, b = arr[..., 0], arr[..., 1], arr[..., 2]
    rg = np.abs(r - g)
    yb = np.abs(0.5 * (r + g) - b)
    return float(math.sqrt(rg.var() + yb.var()) + 0.3 * math.sqrt(rg.mean()**2 + yb.mean()**2))

def _brightness(im: Image.Image) -> float:
    return float(np.asarray(im.convert('L'), dtype=np.float32).mean())

def scan_image_dataset(data_path: str) -> pd.DataFrame:
    """
    Perform a scan over all images under `data_path` and return a DataFrame summarising their properties.

    The scan looks recursively for files with typical image extensions. For each image it records
    the label (sub‑directory name), optional split (train/val/test if present), dimensions, aspect ratio,
    file size (kilobytes), brightness and colourfulness.
    """
    img_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp'}
    split_names = {'train', 'val', 'validation', 'test'}
    records = []
    for root_dir, _, files in os.walk(data_path):
        for fname in files:
            if os.path.splitext(fname)[1].lower() not in img_exts:
                continue
            fpath = os.path.join(root_dir, fname)
            try:
                with Image.open(fpath) as im:
                    w, h = im.size
                    mode = im.mode
                    brightness_value = _brightness(im)
                    colorfulness_value = _colorfulness_fast(im)
            except Exception:
                continue  # skip unreadable or corrupt files
            # Expected layout: .../<split>/<label>/<file> or .../<label>/<file>
            parts = os.path.normpath(fpath).split(os.sep)
            label = parts[-2] if len(parts) >= 2 else ''
            parent = parts[-3].lower() if len(parts) >= 3 else ''
            split = parent if parent in split_names else ''
            records.append({
                'path': fpath,
                'split': split,
                'label': label,
                'width': w,
                'height': h,
                'aspect': (w / h) if h > 0 else float('nan'),
                'file_size_kb': os.path.getsize(fpath) / 1024.0,
                'brightness': brightness_value,
                'colorfulness': colorfulness_value,
                'mode': mode
            })
    return pd.DataFrame(records)

def plot_class_distribution(df: pd.DataFrame) -> None:
    counts = df['label'].value_counts().sort_values(ascending=False)
    plt.figure(figsize=(max(6, 0.35 * len(counts) + 4), 4))
    sns.barplot(x=counts.index, y=counts.values)
    plt.xticks(rotation=45, ha='right')
    plt.title('Class Distribution')
    plt.xlabel('Class')
    plt.ylabel('Number of images')
    plt.show()

def plot_histogram(series: np.ndarray, bins: int, title: str, xlabel: str) -> None:
    plt.figure(figsize=(6, 4))
    plt.hist(series, bins=bins, color=sns.color_palette('Blues')[3])
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel('Count')
    plt.show()

def plot_scatter(x: np.ndarray, y: np.ndarray, title: str, xlabel: str, ylabel: str) -> None:
    plt.figure(figsize=(5, 5))
    plt.scatter(x, y, alpha=0.6, color=sns.color_palette('Blues')[4])
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.show()

def show_random_images(df: pd.DataFrame, n: int = 9) -> None:
    n = min(n, len(df))
    if n == 0:
        return
    indices = np.random.choice(len(df), size=n, replace=False)
    sample_paths = df.iloc[indices]['path'].tolist()
    # Smallest near-square grid that fits all n images
    ncols = math.ceil(math.sqrt(n))
    nrows = math.ceil(n / ncols)
    # squeeze=False keeps `axes` 2-D even for 1x1 or 1xN grids
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 3, nrows * 3), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')  # unused cells stay blank
    for ax, path in zip(axes.flat, sample_paths):
        try:
            with Image.open(path) as im:
                ax.imshow(im.convert('RGB'))
            ax.set_title(os.path.basename(os.path.dirname(path)), fontsize=8)
        except Exception:
            continue  # leave the cell blank if the image cannot be read
    plt.suptitle('Random Image Samples')
    plt.tight_layout()
    plt.show()

def run_eda(data_path: str) -> pd.DataFrame:
    df = scan_image_dataset(data_path)
    if df.empty:
        print(f'No images found under: {data_path}')
        return df
    plot_class_distribution(df)
    plot_histogram(df['width'].values, bins=40, title='Image Width Distribution', xlabel='Width (pixels)')
    plot_histogram(df['height'].values, bins=40, title='Image Height Distribution', xlabel='Height (pixels)')
    aspects = df['aspect'].dropna().values
    plot_histogram(aspects, bins=40, title='Aspect Ratio Distribution', xlabel='Aspect ratio (W/H)')
    plot_histogram(df['file_size_kb'].values, bins=40, title='File Size Distribution', xlabel='File size (KB)')
    plot_histogram(df['brightness'].values, bins=40, title='Brightness Distribution', xlabel='Average grayscale value (0–255)')
    plot_scatter(df['width'].values, df['height'].values, title='Width vs Height', xlabel='Width (pixels)', ylabel='Height (pixels)')
    show_random_images(df, n=9)
    return df
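The class-distribution bar chart shows imbalance visually; a compact numeric summary complements it. A minimal sketch, assuming a DataFrame with a `label` column like the one `run_eda` returns (`class_balance_summary` is a hypothetical helper):

```python
import pandas as pd

def class_balance_summary(df: pd.DataFrame) -> tuple:
    """Return per-class counts and shares, plus the largest-to-smallest class ratio."""
    counts = df['label'].value_counts()
    summary = pd.DataFrame({'count': counts, 'share': counts / counts.sum()})
    imbalance_ratio = counts.max() / counts.min()  # 1.0 means perfectly balanced
    return summary, imbalance_ratio
```

For example, `summary, ratio = class_balance_summary(df_summary)` after running EDA.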


## E‑ConvNeXt model architecture

The following cells define the E‑ConvNeXt architecture adapted from the reference implementation.
It comprises an Effective Squeeze‑and‑Excitation (ESE) block, a modified ConvNeXt block with Batch Normalisation, a Cross‑Stage Partial (CSP) stage and the overall model class.
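It can help to trace how the spatial resolution shrinks through the network. The sketch below applies the standard `Conv2d` output-size formula, floor((n + 2p − k) / s) + 1, to the three stem convolutions and to the kernel-2, stride-2 downsample at the start of each `CSPStage` defined below. `conv_out` and `feature_map_sizes` are hypothetical helper names.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    # Output size of a Conv2d with dilation 1
    return (size + 2 * padding - kernel) // stride + 1

def feature_map_sizes(image_size: int = 224, num_stages: int = 4) -> list:
    sizes = []
    s = conv_out(image_size, kernel=2, stride=2)      # stem conv 1
    sizes.append(s)
    s = conv_out(s, kernel=3, stride=2, padding=1)    # stem conv 2
    sizes.append(s)
    s = conv_out(s, kernel=3, stride=1, padding=1)    # stem conv 3
    sizes.append(s)
    for _ in range(num_stages):
        s = conv_out(s, kernel=2, stride=2)           # CSPStage downsample
        sizes.append(s)
    return sizes

print(feature_map_sizes(224))
```

For a 224×224 input this gives a total stride of 64 and a 3×3 map before global pooling, and the downsample applied to the 7×7 map drops its last row and column. Standard ConvNeXt uses a total stride of 32; if that is the intent here, one option would be to skip the downsample in the first stage, although whether the reference E‑ConvNeXt does so is not confirmed by this notebook.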

In [None]:
# Effective Squeeze‑and‑Excitation block
class ESEBlock(nn.Module):
    def __init__(self, channels: int, reduction: int = 4):
        super().__init__()
        reduced_channels = max(1, channels // reduction)
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excitation = nn.Sequential(
            nn.Conv2d(channels, reduced_channels, kernel_size=1, bias=True),
            nn.GELU(),
            nn.Conv2d(reduced_channels, channels, kernel_size=1, bias=True),
            nn.Sigmoid()
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s = self.squeeze(x)
        e = self.excitation(s)
        return x * e

# E‑ConvNeXt block
class EConvNeXtBlock(nn.Module):
    def __init__(self, dim: int, expansion_factor: float = 4.0):
        super().__init__()
        expanded_dim = int(dim * expansion_factor)
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.bn1 = nn.BatchNorm2d(dim)
        self.pwconv1 = nn.Conv2d(dim, expanded_dim, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(expanded_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(expanded_dim, dim, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(dim)
        self.ese = ESEBlock(dim)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x = self.bn1(x)
        x = self.pwconv1(x)
        x = self.bn2(x)
        x = self.act(x)
        x = self.pwconv2(x)
        x = self.bn3(x)
        x = self.ese(x)
        return x + residual

# CSP stage
class CSPStage(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, num_blocks: int, ch_mid: int = None):
        super().__init__()
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )
        if ch_mid is None:
            ch_mid = (in_channels + out_channels) // 2
        self.split_conv0 = nn.Sequential(
            nn.Conv2d(out_channels, ch_mid, kernel_size=1),
            nn.BatchNorm2d(ch_mid),
            nn.GELU()
        )
        self.split_conv1 = nn.Sequential(
            nn.Conv2d(out_channels, ch_mid, kernel_size=1),
            nn.BatchNorm2d(ch_mid),
            nn.GELU()
        )
        self.blocks = nn.Sequential(*[EConvNeXtBlock(ch_mid) for _ in range(num_blocks)])
        self.after_blocks = nn.Sequential(
            nn.Conv2d(ch_mid, ch_mid, kernel_size=1),
            nn.BatchNorm2d(ch_mid),
            nn.GELU()
        )
        self.merge_conv = nn.Sequential(
            nn.Conv2d(ch_mid * 2, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.GELU()
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.downsample(x)
        y1 = self.split_conv0(x)
        y2 = self.split_conv1(x)
        y2 = self.blocks(y2)
        y2 = self.after_blocks(y2)
        merged = torch.cat([y1, y2], dim=1)
        return self.merge_conv(merged)

def get_variant_config(variant: str):
    variant = variant.lower()
    if variant == 'mini':
        dims = [48, 96, 192, 384]
        depths = [3, 3, 9, 3]
    elif variant == 'tiny':
        dims = [64, 128, 256, 512]
        depths = [3, 3, 9, 3]
    elif variant == 'small':
        dims = [80, 160, 320, 640]
        depths = [3, 3, 27, 3]
    else:
        raise ValueError(f'Unknown variant: {variant}')
    return dims, depths

class EConvNeXt(nn.Module):
    def __init__(self, num_classes: int = 1000, variant: str = 'tiny'):
        super().__init__()
        dims, depths = get_variant_config(variant)
        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=2, stride=2),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, dims[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(dims[0]),
            nn.GELU(),
        )
        in_channels = dims[0]
        stages = []
        for out_channels, depth in zip(dims, depths):
            stages.append(CSPStage(in_channels, out_channels, num_blocks=depth))
            in_channels = out_channels
        self.stages = nn.Sequential(*stages)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(dims[-1], num_classes)
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        x = self.stages(x)
        x = self.global_pool(x)
        x = torch.flatten(x, 1)
        return self.head(x)
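The `ESEBlock` above is a two-layer squeeze-and-excitation bottleneck (reduction 4, GELU, sigmoid). For comparison, the effective SE (eSE) module introduced with VoVNetV2 in CenterMask uses a single 1×1 convolution without channel reduction and a hard-sigmoid gate. A minimal sketch of that variant; the class name `EffectiveSE` is hypothetical and it is not used by the model above.

```python
import torch
import torch.nn as nn

class EffectiveSE(nn.Module):
    """eSE-style gate: global average pool, one 1x1 conv (C -> C), hard sigmoid, channel-wise rescale."""
    def __init__(self, channels: int):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, kernel_size=1)  # no channel reduction
        self.gate = nn.Hardsigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gate(self.fc(self.pool(x)))

x = torch.randn(2, 64, 14, 14)
print(EffectiveSE(64)(x).shape)
```

With 64 channels this gate has 64·64 + 64 = 4,160 parameters, against 2,128 for `ESEBlock(64)`. Swapping it in would change the model, so it is best treated as an experiment rather than a drop-in fix.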


## Training and evaluation functions

Here we implement the training loop, validation routine, history plotting and evaluation helpers.
Plots of the training history (loss and accuracy) use the same blue palette.
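`train_model` steps `CosineAnnealingLR` once per epoch with `T_max=cfg.EPOCHS` and the default `eta_min=0`. The closed form given in the PyTorch documentation, lr_t = eta_min + (lr_0 − eta_min)(1 + cos(π·t / T_max)) / 2, can be tabulated directly to see the learning rate each epoch trains with. A sketch, with `cosine_lr` as a hypothetical helper:

```python
import math

def cosine_lr(epoch: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    # Closed-form cosine annealing schedule
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

# Learning rate used during each of the default 10 epochs (LEARNING_RATE = 2.5e-4)
for e in range(10):
    print(e, f'{cosine_lr(e, 2.5e-4, t_max=10):.2e}')
```

Since `scheduler.step()` runs after each epoch, epoch e (counting from 0) trains with `cosine_lr(e, ...)`. With the default 10 epochs the rate falls from 2.5e-4 to about 6.1e-6 in the last epoch, and the zero value at t = T_max is never used.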

In [None]:
def train_one_epoch(model: nn.Module, dataloader: DataLoader, criterion: nn.Module, optimizer: torch.optim.Optimizer, device: torch.device) -> tuple:
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total = 0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        _, preds = torch.max(outputs, 1)
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels)
        total += inputs.size(0)
    epoch_loss = running_loss / total
    epoch_acc = running_corrects.double().item() / total
    return epoch_loss, epoch_acc

def validate(model: nn.Module, dataloader: DataLoader, criterion: nn.Module, device: torch.device) -> tuple:
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels)
            total += inputs.size(0)
    epoch_loss = running_loss / total
    epoch_acc = running_corrects.double().item() / total
    return epoch_loss, epoch_acc

def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, cfg: Config, device: torch.device):
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=cfg.LEARNING_RATE, weight_decay=cfg.WEIGHT_DECAY)
    scheduler = CosineAnnealingLR(optimizer, T_max=cfg.EPOCHS)
    best_acc = -1.0  # below any achievable accuracy, so the first epoch always saves a checkpoint
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    for epoch in range(cfg.EPOCHS):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        scheduler.step()
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        print(f'Epoch {epoch+1}/{cfg.EPOCHS}  |  Train Loss: {train_loss:.4f}  Acc: {train_acc:.4f}  |  Val Loss: {val_loss:.4f}  Acc: {val_acc:.4f}')
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_econvnext_model.pth')
    print(f'Best validation accuracy: {best_acc:.4f}')
    return history

def plot_training_history(history: dict) -> None:
    epochs_range = range(1, len(history['train_loss']) + 1)
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    axes[0].plot(epochs_range, history['train_loss'], label='Train Loss', color=sns.color_palette('Blues')[4])
    axes[0].plot(epochs_range, history['val_loss'], label='Val Loss', color=sns.color_palette('Blues')[2])
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Loss Curves')
    axes[0].legend()
    axes[1].plot(epochs_range, history['train_acc'], label='Train Acc', color=sns.color_palette('Blues')[4])
    axes[1].plot(epochs_range, history['val_acc'], label='Val Acc', color=sns.color_palette('Blues')[2])
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Accuracy Curves')
    axes[1].legend()
    plt.show()

def evaluate_model(model: nn.Module, dataloader: DataLoader, class_names: list, device: torch.device) -> None:
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
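    # Pass the full label range so the report and matrix stay aligned with class_names,
    # even when some class never appears in the labels or predictions.
    label_ids = list(range(len(class_names)))
    print(classification_report(all_labels, all_preds, labels=label_ids, target_names=class_names, zero_division=0))
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)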
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

def load_model_for_inference(weight_path: str, num_classes: int, variant: str, device: torch.device) -> nn.Module:
    model = EConvNeXt(num_classes=num_classes, variant=variant)
    model.load_state_dict(torch.load(weight_path, map_location=device))
    model.to(device)
    model.eval()
    return model

def predict_image(model: nn.Module, image_path: str, cfg: Config, class_names: list, device: torch.device, topk: int = 5) -> None:
    transform = transforms.Compose([
        transforms.Resize(cfg.IMAGE_SIZE + 32),
        transforms.CenterCrop(cfg.IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img = Image.open(image_path).convert('RGB')
    input_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(input_tensor)
        probs = torch.softmax(outputs[0], dim=0)
        top_probs, top_idxs = torch.topk(probs, k=min(topk, len(probs)))
    for i in range(len(top_probs)):
        print(f'{class_names[top_idxs[i].item()]}: {top_probs[i].item():.4f}')
    plt.imshow(img)
    plt.axis('off')
    plt.title('Input Image')
    plt.show()
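`predict_image` turns the model's logits into probabilities with a softmax and prints the `topk` most likely classes. The same computation in NumPy, as a framework-independent sketch (`topk_from_logits` is a hypothetical helper):

```python
import numpy as np

def topk_from_logits(logits: np.ndarray, class_names: list, k: int = 5) -> list:
    # Numerically stable softmax: subtracting the max logit does not change the result
    z = logits - logits.max()
    probs = np.exp(z) / np.exp(z).sum()
    order = np.argsort(probs)[::-1][:min(k, len(probs))]  # indices, most likely first
    return [(class_names[i], float(probs[i])) for i in order]

names = ['cat', 'dog', 'bird']
for name, p in topk_from_logits(np.array([2.0, 1.0, 0.1]), names, k=2):
    print(f'{name}: {p:.4f}')
```

As in `predict_image`, `k` is capped at the number of classes.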


## Putting it all together

In this final section you can run the full pipeline. Provide the path to your dataset, split the data if necessary, perform EDA, train the model and evaluate it. The example below demonstrates the typical order of operations. Replace `'path/to/your/dataset'` with the actual location of your image data.

In [None]:
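# Specify your dataset root here (or set DATA_DIR in Config). It may contain class sub‑folders directly or train/val/test splits.
dataset_root = config.DATA_DIR  # or e.g. '/data/my_images'
config.DATA_DIR = dataset_root  # keep the config in sync: load_data() reads config.DATA_DIR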

# Step 1: split the dataset (if required)
split_dataset(dataset_root, train_ratio=0.70, val_ratio=0.15, test_ratio=0.15, seed=SEED)

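# Step 2: run EDA on the training set. Scanning dataset_root itself after a fresh split would count
# every image twice, because split_dataset copies files and keeps the original class folders.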
train_path = os.path.join(dataset_root, 'train')
df_summary = run_eda(train_path)

# Step 3: load data for training and validation
train_loader, val_loader, class_names = load_data(config)
config.NUM_CLASSES = len(class_names)

# Step 4: initialise model and train
model = EConvNeXt(num_classes=config.NUM_CLASSES, variant=config.MODEL_VARIANT).to(device)
history = train_model(model, train_loader, val_loader, config, device)
plot_training_history(history)

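# Step 5: evaluate the best saved checkpoint. The validation set was used to choose that checkpoint,
# so the held-out test set (when present) gives the less biased estimate.
best_model = load_model_for_inference('best_econvnext_model.pth', num_classes=config.NUM_CLASSES, variant=config.MODEL_VARIANT, device=device)
test_dir = os.path.join(config.DATA_DIR, 'test')
if os.path.isdir(test_dir):
    _, eval_transforms = get_transforms(config)
    test_dataset = datasets.ImageFolder(test_dir, transform=eval_transforms)
    assert test_dataset.classes == class_names, 'test/ must contain the same class folders as train/'
    test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, shuffle=False)
    evaluate_model(best_model, test_loader, class_names, device)
else:
    evaluate_model(best_model, val_loader, class_names, device)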

# Step 6: optional inference on a single image
# image_path = '/path/to/your/image.jpg'
# predict_image(best_model, image_path, config, class_names, device, topk=5)
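`plot_training_history` shows the curves; for a table of per-epoch values, for example to save alongside the checkpoint, the `history` dict returned by `train_model` converts directly to a pandas DataFrame. A minimal sketch with a hypothetical helper `history_table`; the example values are made up for illustration.

```python
import pandas as pd

def history_table(history: dict) -> pd.DataFrame:
    """One row per epoch (numbered from 1), one column per metric recorded by train_model."""
    table = pd.DataFrame(history)
    table.index = pd.RangeIndex(1, len(table) + 1, name='epoch')
    return table

example = {'train_loss': [1.0, 0.8, 0.7], 'train_acc': [0.4, 0.6, 0.7],
           'val_loss': [1.1, 0.9, 0.95], 'val_acc': [0.5, 0.7, 0.6]}
print(history_table(example))
```

`history_table(history).to_csv('training_history.csv')` writes the real history to disk (the file name is only an example). `history_table(history)['val_acc'].idxmax()` gives the epoch whose weights `train_model` saved, since both pick the first epoch with the highest validation accuracy.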
