# Set Up Environment and Imports
Import torch, torchvision (if needed), numpy, and other required libraries. Detect and print available device (CPU/GPU) for training.

In [None]:
# Import necessary libraries
import torch
import numpy as np

# Pick the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Configure Paths and Hyperparameters
Define configuration variables: path to weights.pth, dataset root paths, batch size, number of epochs, learning rate, weight decay, and output directories. Use Python variables so they can be easily changed in one place.

In [None]:
# Configure paths and hyperparameters
import os

# File holding the pre-trained model weights
weights_path = "weights.pth"

# Root folders of the training and validation splits
train_data_path, val_data_path = "./data/train", "./data/val"

# Optimisation hyperparameters
batch_size, num_epochs = 32, 20
learning_rate = 1e-3
weight_decay = 1e-4

# Directory that receives saved models and logs; created up front if missing
output_dir = "./output"
os.makedirs(output_dir, exist_ok=True)

# Define / Load Model Architecture
Recreate the exact model architecture used when weights.pth was generated. This may involve defining a custom nn.Module class or instantiating a standard torchvision model with the correct number of classes. Print model summary and number of parameters.

In [None]:
# Define / Load Model Architecture

import torch.nn as nn
import torchvision.models as models
from torchsummary import summary

# Define the model architecture (example: ResNet18 with a custom output layer)
class CustomResNet(nn.Module):
    """ResNet18 backbone whose final fully connected layer has `num_classes` outputs.

    The backbone is created with randomly initialised weights; no pre-trained
    parameters are requested here.

    Args:
        num_classes: Number of output units of the replacement `fc` layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Build ResNet18 without pre-trained weights. Calling the factory with
        # no arguments avoids the deprecated `pretrained=` keyword while still
        # working on torchvision releases that predate the `weights=` keyword.
        self.base_model = models.resnet18()
        # Replace the fully connected layer to match the number of classes
        self.base_model.fc = nn.Linear(self.base_model.fc.in_features, num_classes)

    def forward(self, x):
        """Return the backbone's output for the input batch `x`."""
        return self.base_model(x)

# Number of target classes in the dataset (update as needed)
num_classes = 10

# Build the network and place it on the selected device in one step
model = CustomResNet(num_classes=num_classes).to(device)

# Layer-by-layer summary, assuming 3x224x224 input images
summary(model, input_size=(3, 224, 224))

# Count only the parameters that require gradients
total_params = 0
for param in model.parameters():
    if param.requires_grad:
        total_params += param.numel()
print(f"Total trainable parameters: {total_params}")

# Load Model Weights from weights.pth
Load the state_dict from weights.pth using torch.load, map to the correct device, and call model.load_state_dict. Handle common issues like missing or unexpected keys and optionally support loading from a checkpoint dict that includes optimizer state and epoch.

In [None]:
# Load the model weights from weights.pth

# Resume defaults, defined up front so later cells can rely on these names
# even when the weights file is missing or cannot be applied to the model.
checkpoint = {}
optimizer_state_dict = None
start_epoch = 0

try:
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # a trusted source (or pass weights_only=True where the installed PyTorch
    # supports it).
    checkpoint = torch.load(weights_path, map_location=device)

    # A training checkpoint nests the weights under "model_state_dict";
    # anything else is treated as a bare state_dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
        print("Model weights loaded successfully from checkpoint.")

        # Keep the optional resume information for later cells
        if "optimizer_state_dict" in checkpoint:
            optimizer_state_dict = checkpoint["optimizer_state_dict"]
            print("Optimizer state_dict found in checkpoint.")
        if "epoch" in checkpoint:
            start_epoch = checkpoint["epoch"]
            print(f"Resuming training from epoch {start_epoch}.")
    else:
        model.load_state_dict(checkpoint)
        print("Model weights loaded successfully.")
except FileNotFoundError:
    print(f"Error: The weights file '{weights_path}' was not found.")
except KeyError as e:
    # The weights were not applied, so their resume metadata must not be used
    checkpoint = {}
    print(f"Error: Missing key in state_dict - {e}")
except RuntimeError as e:
    # The weights were not applied, so their resume metadata must not be used
    checkpoint = {}
    print(f"Error: Runtime error while loading state_dict - {e}")

# Prepare Dataset and Dataloaders
Define dataset objects (e.g., custom torch.utils.data.Dataset or torchvision datasets) along with required transforms/augmentations. Then create DataLoader instances for training and validation sets with appropriate batch_size, shuffle, num_workers, and pin_memory settings.

In [None]:
# Prepare Dataset and Dataloaders

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define data transformations for training and validation datasets.
# Both pipelines end with the same tensor conversion and normalisation, so
# those steps are built once and shared.
to_normalized_tensor = [
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize
]

data_transforms = {
    "train": transforms.Compose([
        transforms.RandomResizedCrop(224),  # Randomly crop and resize to 224x224
        transforms.RandomHorizontalFlip(),  # Random horizontal flip
        *to_normalized_tensor,
    ]),
    "val": transforms.Compose([
        transforms.Resize(256),  # Resize to 256
        transforms.CenterCrop(224),  # Center crop to 224x224
        *to_normalized_tensor,
    ]),
}

# Create datasets for training and validation
train_dataset = datasets.ImageFolder(root=train_data_path, transform=data_transforms["train"])
val_dataset = datasets.ImageFolder(root=val_data_path, transform=data_transforms["val"])

# Settings shared by both loaders. Pinned host memory only speeds up copies to
# a GPU; requesting it on a CPU-only run has no benefit and triggers a
# DataLoader warning, so it is enabled only when training on CUDA.
loader_kwargs = {
    "batch_size": batch_size,
    "num_workers": 4,  # Number of subprocesses for data loading
    "pin_memory": device.type == "cuda",
}

# Create DataLoader instances; only the training data is shuffled
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# Print dataset sizes
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")

# Define Loss Function and Optimizer
Instantiate the loss function (e.g., CrossEntropyLoss, MSELoss) and an optimizer (e.g., Adam, SGD) using the model parameters and configured hyperparameters. Optionally configure a learning rate scheduler.

In [None]:
# Loss function: cross-entropy
criterion = nn.CrossEntropyLoss()

# Optimizer: Adam over all model parameters, using the configured hyperparameters
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Optional scheduler: multiply the learning rate by 0.1 every 7 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Print confirmation
print("Loss function, optimizer, and scheduler defined successfully.")

# Single-Epoch Training Step
Implement a train_one_epoch function that loops over the training DataLoader, moves batches to the device, performs forward pass, computes loss, backpropagates, applies optimizer.step, and tracks running loss and accuracy. Include gradient zeroing and optional gradient clipping.

In [None]:
# Define the train_one_epoch function
def train_one_epoch(model, dataloader, criterion, optimizer, device, clip_grad=None):
    """Train `model` for one pass over `dataloader` and report averaged metrics.

    Args:
        model: Module to optimise; it is switched to training mode.
        dataloader: Iterable yielding `(inputs, labels)` batches.
        criterion: Callable mapping `(outputs, labels)` to a scalar loss.
        optimizer: Optimizer whose gradients are zeroed and which is stepped
            once per batch.
        device: Device the batches are moved to before the forward pass.
        clip_grad: Optional maximum gradient norm; `None` disables clipping.

    Returns:
        tuple: `(epoch_loss, epoch_accuracy)` averaged over all samples seen.

    Raises:
        ValueError: If `dataloader` yields no samples.
    """
    model.train()  # Set the model to training mode
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for inputs, labels in dataloader:
        # Move inputs and labels to the selected device
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the gradients left over from the previous batch
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()

        # Optional gradient clipping
        if clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

        # Update weights
        optimizer.step()

        # Weight the batch loss by the batch size so the epoch figure is a
        # per-sample average (assumes criterion returns the batch mean)
        running_loss += loss.item() * inputs.size(0)

        # Predicted class = index of the largest output along dim 1
        _, preds = torch.max(outputs, 1)
        correct_predictions += torch.sum(preds == labels).item()
        total_samples += labels.size(0)

    # Fail with a clear message instead of a bare ZeroDivisionError below
    if total_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute training metrics")

    # Compute average loss and accuracy
    epoch_loss = running_loss / total_samples
    epoch_accuracy = correct_predictions / total_samples

    print(f"Training Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_accuracy:.4f}")

    return epoch_loss, epoch_accuracy

# Example usage: run a single training epoch with the gradient norm clipped at 1.0
train_one_epoch(
    model=model,
    dataloader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    clip_grad=1.0,
)

# Full Training Loop with Checkpointing
Implement a train function that calls train_one_epoch for multiple epochs, logs metrics, optionally steps the LR scheduler, and saves checkpoints (including model.state_dict, optimizer.state_dict, current epoch, and best metric) to disk so training can be resumed from weights.pth-like files.

In [None]:
# Define the full training loop with checkpointing
def train(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs, output_dir, start_epoch=0, best_metric=None):
    """Run the training/validation loop and write a checkpoint after every epoch.

    Args:
        model: Module to train.
        train_loader: Batches passed to `train_one_epoch`.
        val_loader: Batches passed to `validate`.
        criterion: Loss function used for both training and validation.
        optimizer: Optimizer updated during training; its state is checkpointed.
        scheduler: Learning-rate scheduler stepped once per epoch, or `None`.
        device: Device the batches are moved to.
        num_epochs: Exclusive upper bound of the zero-based epoch range.
        output_dir: Directory receiving the checkpoint files.
        start_epoch: Zero-based epoch index to start from (for resuming).
        best_metric: Best validation accuracy seen so far, or `None`.
    """
    # An explicit None check keeps a legitimate best value of 0.0, which
    # `best_metric or ...` would silently replace.
    if best_metric is None:
        best_metric = float("-inf")

    # Make sure checkpoints can be written even if the directory is missing
    os.makedirs(output_dir, exist_ok=True)

    for epoch in range(start_epoch, num_epochs):
        epoch_number = epoch + 1
        print(f"Epoch {epoch_number}/{num_epochs}")

        # Train for one epoch; the returned training metrics are not used here
        train_one_epoch(model, train_loader, criterion, optimizer, device)

        # Validate the model
        val_loss, val_accuracy = validate(model, val_loader, criterion, device)

        # Log metrics
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

        # Step the learning rate scheduler if defined
        if scheduler is not None:
            scheduler.step()

        # Track the best validation accuracy
        is_best = val_accuracy > best_metric
        if is_best:
            best_metric = val_accuracy
            print(f"New best metric: {best_metric:.4f}")

        # Save checkpoint
        checkpoint = {
            "epoch": epoch_number,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "best_metric": best_metric,
        }
        # Include the scheduler so a resumed run can restore the LR schedule
        if scheduler is not None:
            checkpoint["scheduler_state_dict"] = scheduler.state_dict()
        checkpoint_path = os.path.join(output_dir, f"checkpoint_epoch_{epoch_number}.pth")
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

        # Keep a separate copy of the best-scoring epoch
        if is_best:
            best_path = os.path.join(output_dir, "best_model.pth")
            torch.save(checkpoint, best_path)
            print(f"Best model saved to {best_path}")

# Define the validation function
def validate(model, dataloader, criterion, device):
    """Compute the average loss and accuracy of `model` over `dataloader`.

    The model is switched to evaluation mode and no gradients are tracked.

    Args:
        model: Module to evaluate.
        dataloader: Iterable yielding `(inputs, labels)` batches.
        criterion: Callable mapping `(outputs, labels)` to a scalar loss.
        device: Device the batches are moved to before the forward pass.

    Returns:
        tuple: `(epoch_loss, epoch_accuracy)` averaged over all samples seen.

    Raises:
        ValueError: If `dataloader` yields no samples.
    """
    model.eval()  # Set the model to evaluation mode
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            # Move inputs and labels to the selected device
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Weight the batch loss by the batch size so the result is a
            # per-sample average (assumes criterion returns the batch mean)
            running_loss += loss.item() * inputs.size(0)

            # Predicted class = index of the largest output along dim 1
            _, preds = torch.max(outputs, 1)
            correct_predictions += torch.sum(preds == labels).item()
            total_samples += labels.size(0)

    # Fail with a clear message instead of a bare ZeroDivisionError below
    if total_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute validation metrics")

    # Compute average loss and accuracy
    epoch_loss = running_loss / total_samples
    epoch_accuracy = correct_predictions / total_samples

    return epoch_loss, epoch_accuracy

# Example usage of the training loop

# `checkpoint` only exists if the weight-loading cell assigned it, and it may
# be a bare state_dict without resume metadata (or not a dict at all), so fall
# back to a fresh start in those cases.
try:
    resume_info = checkpoint if isinstance(checkpoint, dict) else {}
except NameError:
    resume_info = {}

train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    num_epochs=num_epochs,
    output_dir=output_dir,
    start_epoch=resume_info.get("epoch", 0),
    best_metric=resume_info.get("best_metric"),
)

# Evaluation Loop and Metrics
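Implement an evaluate function that sets the model to eval mode, iterates over the validation/test DataLoader without gradient computation, computes the loss and metrics (accuracy and weighted F1), and both prints and returns them. Use it for a fuller report on a trained model; the per-epoch check inside the training loop above uses the lighter `validate` function, which reports loss and accuracy only.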

In [None]:
def evaluate(model, dataloader, criterion, device):
    """
    Evaluate the model on the validation/test dataset.

    The model is switched to evaluation mode and no gradients are tracked.

    Args:
        model (torch.nn.Module): The model to evaluate.
        dataloader (torch.utils.data.DataLoader): DataLoader for the validation/test dataset.
        criterion (torch.nn.Module): Loss function.
        device (torch.device): Device to run the evaluation on.

    Returns:
        tuple: Average loss, accuracy, and weighted F1 score.

    Raises:
        ValueError: If `dataloader` yields no samples.
    """
    model.eval()  # Set the model to evaluation mode
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    # Labels and predictions of every sample, collected for the F1 score
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            # Move inputs and labels to the selected device
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Weight the batch loss by the batch size so the result is a
            # per-sample average (assumes criterion returns the batch mean)
            running_loss += loss.item() * inputs.size(0)

            # Predicted class = index of the largest output along dim 1
            _, preds = torch.max(outputs, 1)
            correct_predictions += torch.sum(preds == labels).item()
            total_samples += labels.size(0)

            # Collect labels and predictions as plain Python values
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

    # Fail with a clear message instead of a bare ZeroDivisionError below
    if total_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute evaluation metrics")

    # Compute average loss and accuracy
    avg_loss = running_loss / total_samples
    accuracy = correct_predictions / total_samples

    # Compute F1 score. The import stays local to this function, so
    # scikit-learn is only needed once metrics are actually computed.
    from sklearn.metrics import f1_score
    f1 = f1_score(all_labels, all_preds, average="weighted")

    print(f"Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")

    return avg_loss, accuracy, f1

# Example usage: score the model on the validation split and unpack the metrics
val_loss, val_accuracy, val_f1 = evaluate(
    model=model,
    dataloader=val_loader,
    criterion=criterion,
    device=device,
)