<a href="https://colab.research.google.com/github/vivi-alencar/bachelor_thesis/blob/main/Training_Flickr_30k.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

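This version of the model uses an identity matrix as the contrastive-loss targets: caption *i* in a batch is the positive for image *i*, and every other item in the batch is treated as a negative. Previously, the targets were soft labels obtained by applying a softmax to the similarities between images and texts within the batch, which could weaken the model's ability to distinguish matching from non-matching pairs.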

In [None]:
!pip install albumentations
!pip install timm
!pip install transformers

In [None]:
from PIL import Image
import os
import albumentations as A
import numpy as np
import pandas as pd
import itertools
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
import timm
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer


In [None]:
from torch.cuda.amp import autocast, GradScaler
# Initialize the scaler for mixed precision
scaler = torch.amp.GradScaler('cuda')

In [None]:
from google.colab import files

In [None]:
!pip install kaggle --upgrade
os.environ['KAGGLE_USERNAME'] = "XXXXX"
os.environ['KAGGLE_KEY'] = "XXXXXXXXXXXXXX"

# Flickr 30k
!kaggle datasets download -d hsankesara/flickr-image-dataset
!unzip flickr-image-dataset.zip
dataset = "30k"

In [None]:
# # Set up logging configuration
# logging.basicConfig(filename='training.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

## Preprocessing

In [None]:
# List the working directory to check that the dataset archive was extracted.
!ls /content/

### Create a pandas DataFrame that assigns an id to each image

In [None]:
# Define the paths for images and captions first, so the captions file below is written
# exactly where CFG.captions_path (and therefore make_train_valid_dfs) reads it from.
image_path = "/content/flickr30k_images/flickr30k_images"
captions_path = "/content"

# Read the pipe-delimited captions file and normalise the column names.
df = pd.read_csv("/content/flickr30k_images/results.csv", delimiter="|")
df.columns = ['image', 'caption_number', 'caption']
df['caption'] = df['caption'].str.lstrip()
df['caption_number'] = df['caption_number'].str.lstrip()

# Hand-repair row 19999.
# NOTE(review): presumably this row is malformed in the raw results.csv; its values are
# hard-coded here — re-check if the source file changes.
df.loc[19999, 'caption_number'] = "4"
df.loc[19999, 'caption'] = "A dog runs across the grass ."
# Fail fast if any caption is still missing: the tokenizer in CLIPDataset needs a string per row.
assert df['caption'].notna().all(), "results.csv contains rows without a caption"

# Add the 'id' column: one integer per distinct image, numbered 0, 1, 2, ... in order of
# first appearance. For a file with exactly 5 consecutive captions per image this is the
# same numbering as a "row // 5" scheme, but it does not depend on that layout.
df['id'] = pd.factorize(df['image'])[0]

# Save the updated DataFrame into captions_path, where the data loaders read it.
df.to_csv(os.path.join(captions_path, "captions.csv"), index=False)

In [None]:
# Preview the first rows of the captions table (image, caption_number, caption, id).
df.head()

In [None]:
# Confirm that the preprocessed captions file is where the data loaders will look for it.
file_path = "/content/captions.csv"
print(f"File {file_path} exists." if os.path.exists(file_path) else f"File {file_path} does not exist.")

# Definitions

### CLASS: Store configurations

In [None]:
class CFG:
    """Namespace of run-wide settings, read as class attributes (``CFG.<name>``)."""

    # --- Run mode / hardware -------------------------------------------------------
    debug = False                  # True -> only image ids 0..99 are used (see make_train_valid_dfs)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Data ----------------------------------------------------------------------
    image_path = image_path        # directory that holds the image files
    captions_path = captions_path  # directory that holds captions.csv
    size = 224                     # images are resized to size x size
    batch_size = 32                # samples per training/validation batch
    num_workers = 2                # DataLoader worker processes

    # --- Image encoder -------------------------------------------------------------
    model_name = 'vit_base_patch32_224'
    image_embedding = 768          # width of the image-encoder output (ViT)

    # --- Text encoder / tokenizer --------------------------------------------------
    text_encoder_model = "distilbert-base-uncased"
    text_tokenizer = "distilbert-base-uncased"
    text_embedding = 768           # width of the text-encoder output
    max_length = 200               # captions are truncated to at most this many tokens

    # --- Shared by the image and text encoders -------------------------------------
    pretrained = True              # start from pretrained weights
    trainable = True               # fine-tune the encoders instead of freezing them

    # --- Projection heads ----------------------------------------------------------
    num_projection_layers = 1      # NOTE(review): not read by ProjectionHead in this notebook
    projection_dim = 256           # width of the shared embedding space
    dropout = 0.1                  # dropout probability inside the projection heads

    # --- Contrastive loss ----------------------------------------------------------
    temperature = 1.2              # logits are divided by this; smaller -> sharper softmax

    # --- Optimisation --------------------------------------------------------------
    head_lr = 1e-3                 # projection heads
    image_encoder_lr = 1e-4        # image encoder
    text_encoder_lr = 1e-5         # text encoder
    weight_decay = 1e-3            # L2-style penalty on large weights
    epochs = 20                    # maximum number of training epochs

    # --- Early stopping & ReduceLROnPlateau ----------------------------------------
    early_stopping_patience = 5    # stop after this many epochs without validation improvement
    lr_scheduler_patience = 2      # epochs without improvement before the LR is reduced
    lr_scheduler_factor = 0.8      # multiplicative factor applied when the LR is reduced

### CLASS: Track and compute the running average of a metric

In [None]:
class AvgMeter:
    """Count-weighted running mean of a scalar metric (e.g. a loss), with a printable summary."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Clear the accumulated count, sum and average."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Record ``val`` as observed ``count`` times and recompute the average."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        # e.g. "loss: 0.1234"
        return f"{self.name}: {self.avg:.4f}"

### FUNCTION: Retrieve the current learning rate from the optimizer

In [None]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group, or None if there are no groups."""
    return next((group["lr"] for group in optimizer.param_groups), None)

### CLASS: Dataset class to load image-caption pairs

In [None]:
class CLIPDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields one image/caption pair per caption.

    Each item is a dict with the tokenizer outputs for the caption (e.g. ``input_ids``,
    ``attention_mask``) as tensors, the transformed image under ``"image"`` as a float
    tensor in (C, H, W) order, and the raw caption string under ``"caption"``.
    """

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        Args:
            image_filenames: image file names, resolved relative to ``CFG.image_path``;
                entry ``i`` is the image for ``captions[i]``.
            captions: caption strings, one per sample.
            tokenizer: tokenizer called once on the full list of captions.
            transforms: callable used as ``transforms(image=array)["image"]``
                (albumentations convention).
        """
        self.image_filenames = image_filenames
        self.captions = list(captions)

        # Tokenize every caption once, up front. padding=True pads to the longest caption in
        # this list, so all items share one sequence length and the default collate_fn can
        # stack them; truncation caps the length at CFG.max_length.
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True, max_length=CFG.max_length
        )

        self.transforms = transforms

    def __getitem__(self, idx):
        """Return the tokenized caption, transformed image and raw caption for sample ``idx``."""
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        img_path = f"{CFG.image_path}/{self.image_filenames[idx]}"
        # Open the image inside a context manager so the file handle is released
        # deterministically (many DataLoader workers open thousands of files).
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))

        image = self.transforms(image=image)['image']

        # (H, W, C) NumPy layout -> (C, H, W) PyTorch layout.
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['caption'] = self.captions[idx]
        return item

    def __len__(self):
        """Number of samples (one per caption)."""
        return len(self.captions)

### FUNCTION: Return a set of image transformations

In [None]:
def get_transforms(mode="train", mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Return the albumentations pipeline used to preprocess images.

    Args:
        mode: "train" or anything else (validation/test). Both currently get the same
            deterministic pipeline (resize + normalize); the argument is kept so that
            train-only augmentation can be added later without changing callers.
        mean, std: per-channel statistics for ``A.Normalize``. The defaults are
            albumentations' own defaults (ImageNet statistics), so the output is the same
            as when they are left implicit.
            NOTE(review): the pretrained timm encoder (CFG.model_name) may expect different
            statistics — compare with timm.data.resolve_data_config for that model.

    Returns:
        ``A.Compose`` that resizes to (CFG.size, CFG.size) and standardizes every channel as
        ``(pixel / 255 - mean) / std``. The result is not merely rescaled to [0, 1].
    """
    # mode does not change the pipeline at the moment (see docstring).
    return A.Compose(
        [
            # p=1.0 is the supported replacement for the deprecated always_apply=True.
            A.Resize(CFG.size, CFG.size, p=1.0),
            A.Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
        ]
    )

### CLASS: Image encoder

In [None]:
class ImageEncoder(nn.Module):
    """timm backbone without its classifier, returning one pooled feature vector per image."""

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        """
        Args:
            model_name: timm architecture to instantiate.
            pretrained: load pretrained weights when True.
            trainable: when False, every backbone parameter is frozen (requires_grad=False).
        """
        super().__init__()
        # num_classes=0 removes the classification layer so the network outputs features;
        # global_pool="avg" chooses average pooling to produce that feature vector.
        backbone = timm.create_model(
            model_name, pretrained, num_classes=0, global_pool="avg"
        )
        # Set requires_grad on every parameter at once (train vs. freeze).
        backbone.requires_grad_(trainable)
        self.model = backbone

    def forward(self, x):
        """Return pooled features for the batch of images ``x``."""
        return self.model(x)

### CLASS: Text encoder

In [None]:
class TextEncoder(nn.Module):
    """DistilBERT encoder that summarises each sequence by the final hidden state of its first token."""

    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        """
        Args:
            model_name: checkpoint to load (only used when ``pretrained`` is True).
            pretrained: load pretrained weights; otherwise build a randomly initialised
                model from the default ``DistilBertConfig``.
            trainable: when False, every DistilBERT parameter is frozen.
        """
        super().__init__()
        self.model = (
            DistilBertModel.from_pretrained(model_name)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )
        self.model.requires_grad_(trainable)

        # Index 0 is the [CLS] token that starts every tokenized sequence; its final hidden
        # state serves as the embedding of the whole caption.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        """Return the last-layer hidden state at ``target_token_idx`` for each sequence.

        ``attention_mask`` marks real tokens vs. padding so padding is ignored.
        The result has shape (batch_size, hidden_size).
        """
        hidden_states = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state  # (batch_size, sequence_length, hidden_size)
        return hidden_states[:, self.target_token_idx, :]

### CLASS: Projection head

In [None]:
# Projection head: maps encoder embeddings of width ``embedding_dim`` into the shared
# ``projection_dim``-wide space in which image and text embeddings are compared.
class ProjectionHead(nn.Module):
    """Linear projection plus a residual GELU -> Linear -> Dropout branch, then LayerNorm.

    Computes ``LayerNorm(Dropout(fc(GELU(p))) + p)`` with ``p = projection(x)``.
    """

    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        """
        Args:
            embedding_dim: width of the incoming encoder embeddings.
            projection_dim: width of the output embedding.
            dropout: dropout probability on the residual branch.
        """
        super().__init__()
        # Construction order is kept stable so random weight initialisation is reproducible.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Project ``x`` to ``projection_dim`` and refine it through the residual branch."""
        base = self.projection(x)
        residual = self.dropout(self.fc(self.gelu(base)))
        # Skip connection: the branch only learns a correction to the linear projection.
        return self.layer_norm(residual + base)

### CLASS: Clip Model

In [None]:
# CLIP-style dual encoder: embeds images and captions into one space and scores every
# caption against every image of the batch with a symmetric contrastive loss.
class CLIPModel(nn.Module):
    """Image/text dual encoder trained with an identity-target contrastive loss."""

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        """
        Args:
            temperature: logits are divided by this value; smaller values sharpen the softmax.
            image_embedding: width of the ImageEncoder features (input of the image head).
            text_embedding: width of the TextEncoder features (input of the text head).
        """
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        """Return the scalar contrastive loss for ``batch``.

        ``batch`` must provide ``"image"``, ``"input_ids"`` and ``"attention_mask"``.
        Row ``i`` of the logits scores caption ``i`` against every image. The targets form
        an identity matrix, so item ``i`` is the only positive for row/column ``i``.

        NOTE(review): embeddings are not L2-normalised before the dot product, so the logits
        are raw similarities rather than cosine similarities. Also, if one image occurs several
        times in a batch (it has several captions), its duplicates count as negatives.
        """
        # Encoders run before the heads, in this order, so dropout consumes RNG identically.
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # (num_texts, num_images) similarity matrix, temperature-scaled.
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        targets = torch.eye(logits.shape[0], device=logits.device)

        text_to_image = cross_entropy(logits, targets, reduction='none')
        image_to_text = cross_entropy(logits.T, targets.T, reduction='none')

        # Average both directions per pair, then over the batch.
        return ((image_to_text + text_to_image) / 2.0).mean()

### FUNCTION: Custom cross entropy loss

In [None]:
def cross_entropy(preds, targets, reduction='none'):
    """Soft-target cross-entropy: ``-(targets * log_softmax(preds)).sum(1)`` per row.

    Args:
        preds: unnormalised logits as a 2-D tensor; the softmax is taken over the last dim.
        targets: target weights with the same shape as ``preds``. With an identity matrix
            (as used by CLIPModel), row ``i`` reduces to the ordinary hard-label
            cross-entropy against class ``i``.
        reduction: ``'none'`` -> per-row losses; ``'mean'`` -> their average;
            ``'sum'`` -> their sum.

    Raises:
        ValueError: if ``reduction`` is not one of the values above.
    """
    # log_softmax is numerically more stable than log(softmax(x)).
    log_probs = F.log_softmax(preds, dim=-1)
    loss = (-targets * log_probs).sum(1)

    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    raise ValueError(f"Unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'.")

### FUNCTION: Create Splits (80% Training, 20% Validation)

In [None]:
def make_train_valid_dfs(valid_frac=0.2, seed=42):
    """Split captions.csv into training and validation DataFrames by image id.

    The split is made on ``id`` (one id per image), so all captions of an image land in
    the same split.

    Args:
        valid_frac: fraction of image ids assigned to validation (default 0.2 -> 80/20).
        seed: seed passed to ``np.random.seed`` so the split is reproducible.

    Returns:
        (train_dataframe, valid_dataframe), each re-indexed from 0.
    """
    dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")

    # Number of image ids to consider (ids are 0..max id); debug mode keeps only ids 0..99.
    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
    image_ids = np.arange(0, max_id)

    # Seed NumPy's global RNG so the same validation ids are drawn on every run.
    np.random.seed(seed)
    valid_ids = np.random.choice(
        image_ids, size=int(valid_frac * len(image_ids)), replace=False
    )
    # Vectorised set difference (keeps ascending order). This replaces a Python loop that
    # scanned the whole valid_ids array once per image id.
    train_ids = image_ids[~np.isin(image_ids, valid_ids)]

    # reset_index(drop=True) gives each split a clean 0-based index.
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)

    return train_dataframe, valid_dataframe

### FUNCTION: Create data loaders for training and validation

In [None]:
def build_loaders(dataframe, tokenizer, mode):
    """Wrap the image/caption pairs of ``dataframe`` in a DataLoader.

    Args:
        dataframe: DataFrame with an ``"image"`` (file name) and a ``"caption"`` column.
        tokenizer: tokenizer passed through to CLIPDataset.
        mode: ``"train"`` enables shuffling; any other value keeps the DataFrame order.
            It is also forwarded to ``get_transforms``.

    Returns:
        torch.utils.data.DataLoader producing batches of ``CFG.batch_size`` items.

    NOTE(review): without shuffling, batches follow the DataFrame order. In that order an
    image's captions are adjacent, so a validation batch can hold the same image several
    times, which matters for the identity-target loss.
    """
    clip_dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=get_transforms(mode=mode),
    )

    is_train = mode == "train"
    return torch.utils.data.DataLoader(
        clip_dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=is_train,
    )

### FUNCTION: Train the model

In [None]:
# Runs one epoch of mixed-precision training and returns the epoch's loss meter.
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: module mapping a batch dict to a scalar loss (CLIPModel here).
        train_loader: DataLoader yielding dicts of tensors plus a ``"caption"`` list.
        optimizer: optimizer that updates the model parameters.
        lr_scheduler: accepted for interface compatibility; it is NOT stepped here.
        step: accepted for interface compatibility; not used here.

    Returns:
        AvgMeter holding the sample-weighted average training loss of the epoch.
    """
    # valid_epoch() leaves the model in eval mode (model.eval()). Without switching back,
    # every training epoch after a validation pass would run with dropout disabled.
    model.train()

    loss_meter = AvgMeter()
    # Like the deprecated torch.cuda.amp.autocast(), autocasting is only active on CUDA.
    use_amp = CFG.device.type == "cuda"
    tqdm_object = tqdm(train_loader, total=len(train_loader))

    for batch in tqdm_object:
        # Move tensors to the device; the raw caption strings are not needed by the model.
        batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}

        # Clear gradients left over from the previous batch.
        optimizer.zero_grad()

        # Forward pass and loss under mixed precision.
        with torch.autocast(device_type="cuda", enabled=use_amp):
            loss = model(batch)

        # Scaled backward pass, optimizer step and scale update (GradScaler protocol).
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight the running average by the number of samples in this batch.
        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))

    return loss_meter

### FUNCTION: Handle the validation phase

In [None]:
def valid_epoch(model, valid_loader):
    """Evaluate ``model`` on ``valid_loader`` without gradients.

    Side effect: puts the model in eval mode (``model.eval()``) and leaves it there.

    Returns:
        AvgMeter with the sample-weighted mean validation loss.
    """
    meter = AvgMeter()
    progress = tqdm(valid_loader, total=len(valid_loader))

    model.eval()
    with torch.no_grad():
        for raw_batch in progress:
            # Keep only tensor fields on the device; raw caption strings are skipped.
            tensors = {
                name: value.to(CFG.device)
                for name, value in raw_batch.items()
                if name != "caption"
            }
            batch_loss = model(tensors)
            meter.update(batch_loss.item(), tensors["image"].size(0))
            progress.set_postfix(valid_loss=meter.avg)

    return meter

### FUNCTION: Plot the loss curves

In [None]:
def plot_loss_curves(train_losses, valid_losses, save_path='loss_curves.png'):
    """Plot the training and validation loss curves against the epoch number.

    Args:
        train_losses, valid_losses: sequences of losses, presumably one value per epoch
            (entry 0 is plotted as epoch 1).
        save_path: file the figure is saved to before it is shown; pass None to skip
            saving. Defaults to 'loss_curves.png'.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    # Use 1-based epoch numbers on the x axis, consistent with the "epoch + 1" numbering that
    # save_checkpoint prints (plotting the raw list index would start at epoch 0).
    ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
    ax.plot(range(1, len(valid_losses) + 1), valid_losses, label='Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss Curves')
    ax.legend()
    ax.grid(True)  # grid for easier reading of values
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()

### FUNCTION: Save a checkpoint to resume training

In [None]:
def save_checkpoint(epoch, model, optimizer, lr_scheduler, best_loss, file_path="checkpoint.pth", scaler=None):
    """Save everything needed to resume training to ``file_path``.

    Args:
        epoch: epoch index stored in the checkpoint (reported as ``epoch + 1``).
        model, optimizer, lr_scheduler: objects whose ``state_dict()`` is saved.
        best_loss: best validation loss encountered so far.
        file_path: destination file.
        scaler: optional object with ``state_dict()`` (e.g. the AMP GradScaler). When it is
            given, its state is stored under ``'scaler_state_dict'`` so loss scaling can be
            resumed too.
    """
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': lr_scheduler.state_dict(),
        'best_loss': best_loss,
    }
    if scaler is not None:
        state['scaler_state_dict'] = scaler.state_dict()

    # Write to a temporary file, then atomically move it into place. An interruption during
    # torch.save (e.g. a runtime disconnect) therefore cannot replace the previous good
    # checkpoint with a truncated one.
    tmp_path = f"{file_path}.tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, file_path)
    print(f"Checkpoint saved at epoch {epoch + 1}")

### FUNCTION: Load a saved checkpoint to resume training

In [None]:
def load_checkpoint(file_path, model, optimizer, lr_scheduler, map_location=None):
    """
    Restore model, optimizer and scheduler state from a training checkpoint.

    Args:
        file_path: Path of a checkpoint dict with the keys 'epoch',
            'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict'
            and 'best_loss' (the layout written by `save_checkpoint`).
        model, optimizer, lr_scheduler: Objects whose state is restored in place.
        map_location: Forwarded to `torch.load`. Use e.g. "cpu" to open a
            checkpoint saved on GPU in a CPU-only runtime. Defaults to None,
            which keeps tensors on the device they were saved from.

    Returns:
        (epoch, best_loss) as stored in the checkpoint on success, or
        (None, None) if the file is missing, cannot be deserialized, or lacks
        a required key. In the failure case no state is modified.
    """
    required_keys = ('epoch', 'model_state_dict', 'optimizer_state_dict',
                     'scheduler_state_dict', 'best_loss')

    try:
        checkpoint = torch.load(file_path, map_location=map_location)
    except FileNotFoundError:
        print(f"Checkpoint file not found: {file_path}")
        return None, None
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return None, None

    # Validate before touching any state. A malformed file must not leave the
    # model updated while the optimizer and scheduler are stale.
    if not isinstance(checkpoint, dict):
        print(f"Checkpoint {file_path} is not a dictionary; cannot resume.")
        return None, None
    missing = [key for key in required_keys if key not in checkpoint]
    if missing:
        print(f"Checkpoint {file_path} is missing keys: {missing}")
        return None, None

    # Load the saved model, optimizer, and scheduler states
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Return the epoch and best validation loss to resume training
    return checkpoint['epoch'], checkpoint['best_loss']

### FUNCTION: Main (training and validation loop)

In [None]:
def main(resume=False, checkpoint_path="checkpoint.pth", best_model_path="best30k07.pt"):
    """
    Train and validate the CLIP model, with checkpointing, LR scheduling and early stopping.

    Args:
        resume: Set to True if a previous run was interrupted. Model, optimizer
            and scheduler state are then restored from `checkpoint_path`, and
            training continues with the epoch after the last completed one.
        checkpoint_path: File the per-epoch resume checkpoint is written to and read from.
        best_model_path: File the weights with the lowest validation loss are written to.

    Raises:
        RuntimeError: If `resume` is True but the checkpoint cannot be loaded.

    Note:
        The early-stopping counter is not part of the checkpoint, so it
        restarts at 0 after resuming. Loss curves are plotted only for the
        epochs run in this call.
    """
    # Split the dataset into training and validation DataFrames.
    train_df, valid_df = make_train_valid_dfs()

    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)

    train_loader = build_loaders(train_df, tokenizer, mode="train")
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)

    # One parameter group per component, each with its own learning rate.
    # Only the projection heads get weight decay.
    params = [
        {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
        {"params": itertools.chain(
            model.image_projection.parameters(), model.text_projection.parameters()
        ), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay},
    ]
    # AdamW's default weight_decay is 0.01. Setting the default to 0 makes the
    # encoder groups truly decay-free, while the head group keeps its own override.
    optimizer = torch.optim.AdamW(params, weight_decay=0.0)

    # Reduce the learning rate by `factor` once the validation loss (mode="min")
    # has not improved for `patience` epochs.
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.lr_scheduler_patience, factor=CFG.lr_scheduler_factor
    )

    best_loss = float('inf')  # any real validation loss counts as an improvement
    start_epoch = 0

    if resume:
        saved_epoch, saved_best_loss = load_checkpoint(checkpoint_path, model, optimizer, lr_scheduler)
        if saved_epoch is None:
            raise RuntimeError(f"Cannot resume: failed to load checkpoint '{checkpoint_path}'.")
        # The checkpoint stores the 0-based index of the last *completed*
        # epoch, so training continues with the next one.
        start_epoch = saved_epoch + 1
        best_loss = saved_best_loss
        print(f"Resuming from epoch {start_epoch + 1} with best validation loss {best_loss:.4f}")

    patience_counter = 0
    early_stopping_patience = CFG.early_stopping_patience

    # Per-epoch average losses, collected for plotting.
    train_losses = []
    valid_losses = []

    for epoch in range(start_epoch, CFG.epochs):
        print(f"Epoch: {epoch + 1}")

        # Training phase (dropout etc. active).
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step="epoch")
        train_losses.append(train_loss.avg)

        # Validation phase without gradient tracking.
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)
        valid_losses.append(valid_loss.avg)

        print(f"Training Loss: {train_loss.avg:.4f}, Validation Loss: {valid_loss.avg:.4f}")
        print(f"Learning Rate: {get_lr(optimizer):.6f}")

        # Keep the weights with the lowest validation loss seen so far.
        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            patience_counter = 0
            torch.save(model.state_dict(), best_model_path)
            print("Saved Best Model!")
        else:
            patience_counter += 1

        # Step the scheduler *before* checkpointing, so the saved scheduler
        # state already accounts for this epoch's validation loss.
        lr_scheduler.step(valid_loss.avg)

        save_checkpoint(epoch, model, optimizer, lr_scheduler, best_loss, file_path=checkpoint_path)

        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch + 1} due to no improvement.")
            break

    if train_losses:
        plot_loss_curves(train_losses, valid_losses)
    else:
        print("No epochs were run; nothing to plot.")

In [None]:
 # Train from scratch. Call main(resume=True) instead to continue an interrupted run from checkpoint.pth.
 main()