# Training pipeline for the ResNet-18 keypoint-regression model

In [1]:
from resnet import ResNet18, KeypointDataset, train_model
from torch.utils.data import DataLoader
from torchvision import transforms

In [None]:
# --- Dataset locations (image folder + matching label folder per split) ----
TRAIN_IMG_DIR, TRAIN_LABEL_DIR = "data/train/images", "data/train/labels"
VAL_IMG_DIR, VAL_LABEL_DIR = "data/val/images", "data/val/labels"

# --- DataLoader / training settings ----------------------------------------
BATCH_SIZE = 16      # samples per batch, shared by both loaders
NUM_WORKERS = 4      # DataLoader worker processes
PIN_MEMORY = True    # DataLoader pin_memory flag (page-locked host buffers)
NUM_EPOCHS = 50      # forwarded to train_model as num_epochs
SAVE_INTERVAL = 5    # forwarded to train_model as save_interval (presumably epochs between checkpoints — TODO confirm)

# --- Output artefacts --------------------------------------------------------
TRAIN_LOSS_PATH, VAL_LOSS_PATH = "train_loss.npy", "val_loss.npy"
LOSS_PLOT_PATH = "loss_plot.png"

In [None]:
# Instantiate the keypoint regressor with 9 keypoints. pretrained=True is passed
# straight through to resnet.ResNet18 (see that module for which weights it loads).
model = ResNet18(
    pretrained=True,
    num_keypoints=9,
)

In [None]:
# Define data transformations
#
# Ordering: ToTensor() turns the PIL image into a float tensor in [0, 1], and
# Normalize() only accepts tensors. Normalize must therefore come AFTER
# ToTensor; the previous order raised a TypeError on every sample.
#
# Geometric augmentation: these are single-input torchvision transforms, so they
# only ever see the image, never the keypoint labels. A random flip, rotation or
# translation would move the image content while the target coordinates stay
# put, which trains the model on mislabelled samples. RandomHorizontalFlip and
# RandomAffine were removed for that reason. Re-introduce them only through a
# transform that updates the keypoints together with the image (and, for flips,
# swaps mirrored keypoint ids if the 9 keypoints contain left/right pairs).
# Photometric jitter does not move pixels, so it is safe to keep.
#
# NOTE(review): Resize also changes the pixel frame. This assumes KeypointDataset
# stores keypoints in a resolution-independent frame (e.g. normalised) or rescales
# them itself — TODO confirm in resnet.KeypointDataset.

IMAGE_SIZE = (224, 224)
# Standard ImageNet channel statistics (presumably matching the pretrained backbone).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    # Colour-only augmentation: pixel positions are unchanged, so labels stay valid.
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),  # PIL image -> float tensor; must precede Normalize
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Keep validation transforms simple (no augmentation) to evaluate true performance
val_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
    

In [None]:
# Create datasets and dataloaders.
# Both splits use their own transform pipeline (augmented for training, plain
# for validation). Both loaders share the batching/worker settings; only shuffling
# differs. Training data is reshuffled every epoch, while validation order is
# fixed so that results are comparable across epochs.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

train_dataset = KeypointDataset(TRAIN_IMG_DIR, TRAIN_LABEL_DIR, transform=train_transform)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)

val_dataset = KeypointDataset(VAL_IMG_DIR, VAL_LABEL_DIR, transform=val_transform)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# Train the model. train_model gets the model, both loaders and the epoch/save
# settings, and hands back two loss histories (training first, then validation).
train_loss, val_loss = train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=NUM_EPOCHS,
    save_interval=SAVE_INTERVAL,
)

In [None]:
# Save the training and validation loss histories returned by train_model, so
# they can be reloaded and re-plotted without retraining. Only these two
# sequences are saved; no accuracy or Dice metrics are produced here.
#
# NOTE(review): this assumes the histories hold plain Python/NumPy floats. If
# train_model returns torch tensors (especially CUDA ones), convert them with
# .item() first. Confirm train_model's return type.

import numpy as np

np.save(TRAIN_LOSS_PATH, np.array(train_loss))
np.save(VAL_LOSS_PATH, np.array(val_loss))

# To reload the saved histories later, use:
# train_loss = np.load(TRAIN_LOSS_PATH)
# val_loss = np.load(VAL_LOSS_PATH)

In [None]:
# Plot the results

import matplotlib.pyplot as plt

# Draw both loss curves on one set of axes (x = index into each history,
# labelled "Epoch"). Write the figure to LOSS_PLOT_PATH before displaying it.
fig, ax = plt.subplots()
ax.set_title("Training and validation losses over epochs")
for history, curve_label in ((train_loss, "Train loss"), (val_loss, "Validation loss")):
    ax.plot(history, label=curve_label)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
fig.savefig(LOSS_PLOT_PATH)
plt.show()