In [5]:
# Print GPU / driver / CUDA version info for this runtime.
!nvidia-smi

Sun Dec  8 23:47:41 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 561.09                 Driver Version: 561.09         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3060 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   51C    P0             22W /   80W |       0MiB /   6144MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.6.0-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadata (5.2 kB)
Downloading torchmetrics-1.6.0-py3-none-any.whl (926 kB)
   ---------------------------------------- 0.0/926.4 kB ? eta -:--:--
   ---------------------- ----------------- 524.3/926.4 kB 4.2 MB/s eta 0:00:01
   ---------------------------------------- 926.4/926.4 kB 4.7 MB/s eta 0:00:00
Downloading lightning_utilities-0.11.9-py3-none-any.whl (28 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.11.9 torchmetrics-1.6.0


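# Image Captioning Transformer

Trains a ViT-style patch encoder and a Transformer token decoder to caption images; later sections sketch variants built on pretrained ViT / BERT components.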

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import math
from tqdm.notebook import trange, tqdm
import random

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.distributions import Categorical

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision

from torchmetrics.text import BLEUScore

from torchsummary import summary

import time

import wandb

from transformers import AutoTokenizer
# Disable Hugging Face tokenizers' internal parallelism (commonly set to avoid
# fork-related warnings/deadlocks when DataLoader workers fork the process).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Allow TF32 tensor-core matmuls on Ampere GPUs such as the RTX 3060 above
# (faster, slightly reduced float32 precision).
torch.backends.cuda.matmul.allow_tf32 = True

## Parameters

In [None]:
# Side length of the square input images. It sizes the Encoder's learned positional
# embedding; the (not shown) data pipeline presumably resizes images to this — TODO confirm.
image_size = 128

# Mini-batch size; logged to W&B and presumably used to build train_loader/val_loader (not shown).
batch_size = 128

# Model

## Positional embedding

In [None]:
class SinusoidalPosEmb(nn.Module):
    """Fixed (non-learned) sinusoidal embedding of scalar positions.

    Args:
        dim: requested embedding width; the output has ``2 * (dim // 2)`` columns.

    ``forward(x)`` expects a 1-D tensor of positions of shape (N,) and returns an
    (N, 2 * (dim // 2)) float tensor: the sine half first, then the cosine half.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000. max(..., 1) avoids a
        # ZeroDivisionError when dim == 2 (half_dim - 1 == 0).
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        return torch.cat((emb.sin(), emb.cos()), dim=-1)

## Encoder

In [None]:
def extract_patches(image_tensor, patch_size=16):
    """Cut a batch of images into flattened, non-overlapping square patches.

    Args:
        image_tensor: 4-D tensor of shape (B, C, H, W).
        patch_size: side length of each square patch; also used as the stride,
            so patches do not overlap.

    Returns:
        Tensor of shape (B, L, C * patch_size * patch_size), where L is the number
        of patches, ordered row-major over the image.
    """
    batch, channels, _, _ = image_tensor.size()
    patch_dim = channels * patch_size * patch_size

    patchify = nn.Unfold(kernel_size=patch_size, stride=patch_size)
    # Unfold yields (B, patch_dim, L); put the patch axis next to the batch axis.
    columns = patchify(image_tensor)
    return columns.transpose(1, 2).reshape(batch, -1, patch_dim)

In [None]:
class Encoder(nn.Module):  # ViT-style patch encoder
    """Split an image into patches and encode them with a Transformer encoder.

    Args:
        image_size: side length of the square input image. Together with
            ``patch_size`` it fixes the length of the learned positional embedding.
        channels_in: number of input image channels.
        patch_size: side length of each square patch.
        hidden_size: model width (d_model).
        num_layers: number of stacked ``nn.TransformerEncoderLayer`` blocks.
        num_heads: attention heads per layer.

    ``forward(image)`` returns (B, (image_size // patch_size) ** 2, hidden_size).
    """

    def __init__(self, image_size, channels_in, patch_size=16, hidden_size=128,
                 num_layers=3, num_heads=4):
        super().__init__()
        self.patch_size = patch_size

        patch_dim = channels_in * patch_size * patch_size
        self.fc_in = nn.Linear(patch_dim, hidden_size)

        # One learned position vector per patch, small random init.
        num_patches = (image_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(
            torch.empty(1, num_patches, hidden_size).normal_(std=0.02)
        )

        block = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_heads,
            dim_feedforward=4 * hidden_size,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer=block, num_layers=num_layers)

    def forward(self, image):
        tokens = self.fc_in(extract_patches(image, patch_size=self.patch_size))
        return self.transformer_encoder(tokens + self.pos_embedding)


In [None]:
from transformers import ViTModel

class ImageEmbedding(nn.Module):
    """Patch + positional embeddings taken from a pretrained Hugging Face ViT.

    Only the ViT's patch-projection layer and its learned position table are used.
    The ViT transformer layers are not run in ``forward``.

    Args:
        vit_model: Hugging Face model id or path passed to ``ViTModel.from_pretrained``.
    """

    def __init__(self, vit_model="google/vit-base-patch16-224-in21k"):
        super(ImageEmbedding, self).__init__()
        self.vit = ViTModel.from_pretrained(vit_model)

        # Use only the patch embedding layer and the position table.
        self.patch_embed = self.vit.embeddings.patch_embeddings
        self.pos_embed = self.vit.embeddings.position_embeddings

    def forward(self, images):
        """Return (B, num_patches, D) patch embeddings with positions added.

        NOTE(review): the pretrained patch layer presumably expects images at the
        checkpoint's resolution (224 for the default model) — confirm.
        """
        patch_embeds = self.patch_embed(images)

        # HF ViT sizes the position table for num_patches + 1 rows (row 0 is the
        # [CLS] slot). No CLS token is prepended here, so skip that row and align
        # the remaining rows with the patches.
        num_patches = patch_embeds.shape[1]
        embs = patch_embeds + self.pos_embed[:, 1:num_patches + 1]
        return embs

In [None]:
class Encoder2(nn.Module):  # pretrained ViT embeddings + trainable Transformer encoder
    """Embed images with pretrained ViT patch/position embeddings, then encode them.

    Args:
        hidden_size: encoder width (d_model). The ViT embeddings are linearly
            projected to this width when it differs from the ViT's own width.
        num_layers: number of stacked ``nn.TransformerEncoderLayer`` blocks.
        num_heads: attention heads per layer.
    """

    def __init__(self, hidden_size=128, num_layers=3, num_heads=4):
        super(Encoder2, self).__init__()

        # Build the (pretrained) embedding module once. It must be a registered
        # submodule, not re-created on every forward call.
        self.image_embedding = ImageEmbedding()
        # Width of the ViT embeddings (HF config attribute).
        vit_width = self.image_embedding.vit.config.hidden_size
        self.proj = (nn.Identity() if vit_width == hidden_size
                     else nn.Linear(vit_width, hidden_size))

        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=num_heads,
                                                   dim_feedforward=hidden_size * 4,
                                                   batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer=encoder_layer,
                                                         num_layers=num_layers)

    def forward(self, image):
        """Return (B, num_patches, hidden_size) encoded image tokens."""
        embs = self.proj(self.image_embedding(image))
        return self.transformer_encoder(embs)

## Decoder

In [None]:
class Decoder(nn.Module): #base on BERT
    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4):
        super(Decoder, self).__init__()

        # Create an embedding layer for tokens
        self.embedding = nn.Embedding(num_emb, hidden_size)
        # Initialize the embedding weights
        self.embedding.weight.data = 0.001 * self.embedding.weight.data

        # Initialize sinusoidal positional embeddings
        self.pos_emb = SinusoidalPosEmb(hidden_size)

        # Create multiple transformer blocks as layers
        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_size, nhead=num_heads,dim_feedforward= hidden_size*4, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Define a linear layer for output prediction
        self.fc_out = nn.Linear(hidden_size, num_emb)

        self.softmax = nn.Softmax()

    def forward(self, input_seq, encoder_output, input_padding_mask=None,
                encoder_padding_mask=None):
        # Embed the input sequence
        input_embs = self.embedding(input_seq)
        b, l, h = input_embs.shape

        # Add positional embeddings to the input embeddings
        seq_indx = torch.arange(l, DEVICE=input_seq.DEVICE)
        pos_emb = self.pos_emb(seq_indx).reshape(1, l, h).expand(b, l, h)
        embs = input_embs + pos_emb

        # Pass the embeddings through each transformer block
        output = self.transformer_decoder(tgt = embs, memory=encoder_output, memory_mask=None, 
                                          tgt_key_padding_mask=input_padding_mask, memory_key_padding_mask=encoder_padding_mask,
                                          tgt_is_causal=True, memory_is_causal=False)

        output = self.softmax(self.fc_out(output))
        
        return output

In [None]:
from transformers import BertModel

class TextEmbedding(nn.Module):
    """Token-embedding lookup borrowed from a pretrained Hugging Face BERT.

    Only BERT's word-embedding table is used in ``forward``. Positional and
    token-type embeddings are not applied.

    Args:
        bert_model: model id or path passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, bert_model="bert-base-uncased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)
        # The word-piece embedding table only (no position / segment embeddings).
        self.embedding = self.bert.embeddings.word_embeddings

    def forward(self, tokens):
        """Look up embeddings for ``tokens``, first moving them to the table's device."""
        table = self.embedding
        on_device = tokens.to(table.weight.device)
        return table(on_device)


In [None]:
class Decoder2(nn.Module): #base on BERT
    def __init__(self, num_emb, hidden_size=128, num_layers=3, num_heads=4):
        super(Decoder2, self).__init__()

        self.embedding = TextEmbedding

        # Initialize sinusoidal positional embeddings
        self.pos_emb = SinusoidalPosEmb(hidden_size)

        # Create multiple transformer blocks as layers
        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_size, nhead=num_heads,dim_feedforward= hidden_size*4, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Define a linear layer for output prediction
        self.fc_out = nn.Linear(hidden_size, num_emb)

        self.softmax = nn.Softmax()

    def forward(self, input_seq, encoder_output, input_padding_mask=None,
                encoder_padding_mask=None):
        # Embed the input sequence
        input_embs = self.embedding(input_seq)
        b, l, h = input_embs.shape

        # Add positional embeddings to the input embeddings
        seq_indx = torch.arange(l, DEVICE=input_seq.DEVICE)
        pos_emb = self.pos_emb(seq_indx).reshape(1, l, h).expand(b, l, h)
        embs = input_embs + pos_emb

        # Pass the embeddings through each transformer block
        output = self.transformer_decoder(tgt = embs, memory=encoder_output, memory_mask=None, 
                                          tgt_key_padding_mask=input_padding_mask, memory_key_padding_mask=encoder_padding_mask,
                                          tgt_is_causal=True, memory_is_causal=False)

        output = self.softmax(self.fc_out(output))
        
        return output

## Encoder-Decoder

In [None]:
class EncoderDecoder(nn.Module):
    """Image captioner: patch-Transformer ``Encoder`` feeding a token ``Decoder``.

    Args:
        image_size, channels_in, patch_size: forwarded to ``Encoder``.
        num_emb: vocabulary size, forwarded to ``Decoder``.
        hidden_size, num_heads: shared by encoder and decoder.
        num_layers: ``(encoder_layers, decoder_layers)``.
    """

    def __init__(self, image_size, channels_in, num_emb, patch_size=16,
                 hidden_size=128, num_layers=(3, 3), num_heads=4):
        super().__init__()

        self.encoder = Encoder(
            image_size=image_size,
            channels_in=channels_in,
            patch_size=patch_size,
            hidden_size=hidden_size,
            num_layers=num_layers[0],
            num_heads=num_heads,
        )
        self.decoder = Decoder(
            num_emb=num_emb,
            hidden_size=hidden_size,
            num_layers=num_layers[1],
            num_heads=num_heads,
        )

    def forward(self, input_image, target_seq, padding_mask):
        """Encode ``input_image`` and decode ``target_seq`` against it.

        ``padding_mask`` follows the tokenizer's attention-mask convention
        (0 = padding). It is flipped into PyTorch's "True = ignore" form here.
        """
        ignore_positions = padding_mask == 0
        memory = self.encoder(image=input_image)
        return self.decoder(
            input_seq=target_seq,
            encoder_output=memory,
            input_padding_mask=ignore_positions,
        )

In [None]:
class EncoderDecoder2(nn.Module):
    """Captioner built from ``Encoder2`` (ViT embeddings) and ``Decoder2`` (BERT embeddings).

    Args:
        num_emb: output vocabulary size, forwarded to ``Decoder2``.
        hidden_size, num_heads: shared by encoder and decoder.
        num_layers: ``(encoder_layers, decoder_layers)``.
    """

    def __init__(self, num_emb,
                 hidden_size=128, num_layers=(3, 3), num_heads=4):
        super().__init__()

        self.encoder = Encoder2(
            hidden_size=hidden_size,
            num_layers=num_layers[0],
            num_heads=num_heads,
        )
        self.decoder = Decoder2(
            num_emb=num_emb,
            hidden_size=hidden_size,
            num_layers=num_layers[1],
            num_heads=num_heads,
        )

    def forward(self, input_image, target_seq, padding_mask):
        """Encode the image, then decode ``target_seq``; mask entries equal to 0 are ignored."""
        ignore_positions = padding_mask == 0
        memory = self.encoder(image=input_image)
        return self.decoder(
            input_seq=target_seq,
            encoder_output=memory,
            input_padding_mask=ignore_positions,
        )

# Pretrained Models

## Encoder

In [None]:
from transformers import ViTModel, ViTConfig

class ImageEncoder(nn.Module):
    """Pretrained Hugging Face ViT used as an image feature extractor.

    Args:
        pretrained_model: model id or path passed to ``ViTModel.from_pretrained``.
    """

    def __init__(self, pretrained_model="google/vit-base-patch16-224-in21k"):
        super().__init__()
        self.vit = ViTModel.from_pretrained(pretrained_model)

    def forward(self, images):
        """Run the ViT on ``images`` and return its final hidden states."""
        return self.vit(pixel_values=images).last_hidden_state


## Decoder

In [None]:
from transformers import BertModel, BertTokenizer

class CaptionDecoder(nn.Module):
    """Pretrained BERT used as a caption decoder that cross-attends to image features.

    Args:
        pretrained_model: model id or path for both ``BertModel`` and ``BertTokenizer``.
    """

    def __init__(self, pretrained_model="bert-base-uncased"):
        super(CaptionDecoder, self).__init__()
        # Per the HF BertModel docs, cross-attention layers exist (and
        # encoder_hidden_states is used) only when the config has is_decoder=True and
        # add_cross_attention=True. Otherwise the image features are silently ignored.
        # The cross-attention weights are newly initialised.
        self.bert = BertModel.from_pretrained(pretrained_model,
                                              is_decoder=True,
                                              add_cross_attention=True)
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model)

    def forward(self, captions, encoder_outputs):
        """Tokenize ``captions`` and run BERT over them, attending to ``encoder_outputs``.

        Args:
            captions: list of caption strings.
            encoder_outputs: (B, S, D) image features. S is the number of image tokens.

        Returns:
            BERT's ``last_hidden_state`` for the caption tokens.
        """
        encoded = self.tokenizer(captions, padding=True, return_tensors="pt")
        device = encoder_outputs.device
        tokens = encoded.input_ids.to(device)
        # Keep padded caption positions out of self-attention.
        attention_mask = encoded.attention_mask.to(device)
        # The encoder mask must cover the *image* sequence (B, S), not the caption tokens.
        encoder_mask = torch.ones(encoder_outputs.shape[:2], dtype=torch.long, device=device)

        outputs = self.bert(
            input_ids=tokens,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_outputs,
            encoder_attention_mask=encoder_mask,
        )
        return outputs.last_hidden_state


## Encoder-Decoder

In [None]:
class ImageCaptioningModel(nn.Module):
    """Pretrained ViT encoder followed by a BERT caption decoder.

    Args:
        vit_model: checkpoint for ``ImageEncoder``.
        bert_model: checkpoint for ``CaptionDecoder``.

    ``forward(images, captions)`` returns whatever ``CaptionDecoder`` returns for the
    captions conditioned on the image features.
    """

    def __init__(self, vit_model="google/vit-base-patch16-224-in21k",
                 bert_model="bert-base-uncased"):
        super().__init__()
        self.encoder = ImageEncoder(pretrained_model=vit_model)
        self.decoder = CaptionDecoder(pretrained_model=bert_model)

    def forward(self, images, captions):
        image_features = self.encoder(images)
        return self.decoder(captions, image_features)

# Hyper Parameters

In [None]:
# Use the default CUDA device when available. torch.device(1) means "cuda:1",
# which does not exist on this single-GPU machine (nvidia-smi lists only GPU 0).
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

In [None]:
# Initial Adam learning rate (ReduceLROnPlateau may lower it during training).
learning_rate = 1e-4

# Number of training epochs.
epochs = 200

# Where the best (lowest validation loss) checkpoint is written.
# NOTE(review): absolute path at the filesystem root — make sure it is writable,
# or switch to a path relative to the notebook.
model_path = "/model/Transformer_model.pth"

In [None]:
# Transformer width (d_model), shared by encoder and decoder.
hidden_size = 192

# Number of Transformer blocks for the (Encoder, Decoder)
num_layers = (6, 6)

# MultiheadAttention heads; hidden_size must be divisible by this (192 / 8 = 24 per head).
num_heads = 8

# Side length of the square patches; with image_size = 128 this gives (128 // 8) ** 2 = 256 patches.
patch_size = 8

In [None]:
# Build the captioning model and move it to DEVICE (once is enough).
# NOTE(review): `test_images` and `tokenizer` are not defined anywhere in this
# notebook as shown; they must come from a data-loading cell run earlier.
model = EncoderDecoder(
    image_size=image_size,
    channels_in=test_images.shape[1],
    num_emb=tokenizer.vocab_size,
    patch_size=patch_size,
    num_layers=num_layers,
    hidden_size=hidden_size,
    num_heads=num_heads,
).to(DEVICE)

# Training

## Weights & Biases settings

In [None]:
PROJECT = "Image_Captioning"
RESUME = "allow"
# Never hard-code API keys in a notebook. Read the key from the environment
# (the key previously committed here should be revoked and rotated).
WANDB_KEY = os.environ.get("WANDB_API_KEY")

## Setup

In [None]:
class TokenDrop(nn.Module):
    """Randomly replace input token ids with a blank token (input-side dropout).

    The first column (start token) and every ``eos_token`` are never replaced.

    Args:
        prob (float): probability that an eligible token is replaced.
        blank_token (int): id written in place of dropped tokens.
        eos_token (int): id that is never dropped. The default 102 is presumably
            BERT's [SEP] id; verify for other tokenizers.
    """

    def __init__(self, prob=0.1, blank_token=1, eos_token=102):
        # Initialise nn.Module state so repr()/train()/eval()/.to() work.
        super().__init__()
        self.prob = prob
        self.eos_token = eos_token
        self.blank_token = blank_token

    def forward(self, sample):
        """Return a copy of ``sample`` ((B, L) integer ids) with some tokens blanked."""
        # Bernoulli(prob) draw per position: True means "replace this token".
        drop = torch.bernoulli(
            torch.full(sample.shape, float(self.prob), device=sample.device)
        ).bool()
        # Never drop the eos token or the start token in column 0.
        drop &= sample != self.eos_token
        drop[:, 0] = False
        return sample.masked_fill(drop, self.blank_token)

In [None]:
# Optimiser and mixed-precision gradient scaler.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
scaler = torch.cuda.amp.GradScaler()

# reduction="none" keeps a per-token loss so it can be multiplied by the padding mask.
loss_fn = nn.CrossEntropyLoss(reduction="none")

# Blank out ~50% of eligible input tokens during training.
td = TokenDrop(prob=0.5)

# Initialize the training loss logger
training_loss_logger = list()

In [None]:
# Authenticate with the key from the W&B settings cell rather than a literal.
# Per the wandb docs, a None key falls back to WANDB_API_KEY / ~/.netrc / an interactive prompt.
wandb.login(key=WANDB_KEY)

In [None]:
# Start (or resume) the W&B run and record the core training hyperparameters.
wandb.init(
    project=PROJECT,
    name="init_transformer",
    resume=RESUME,
    config=dict(
        learning_rate=learning_rate,
        epochs=epochs,
        batch_size=batch_size,
    ),
)
# Have W&B watch the model during training.
wandb.watch(model)

## Train

In [None]:
# Cut the LR 10x when validation loss has not improved for 15 epochs.
# `verbose` is deprecated in recent PyTorch; the training loop already prints the current LR.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=15, factor=0.1)

In [None]:
# Lowest validation loss seen so far; a checkpoint is saved whenever it improves.
best_val_loss = float('inf')
# Smoothed BLEU-4, reset and recomputed inside validate_epoch each epoch.
bleu_metric = BLEUScore(n_gram=4, smooth=True).to(DEVICE)

def train_epoch(model, dataloader, optimizer, loss_fn, DEVICE):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Relies on notebook globals: ``tokenizer`` (HF tokenizer), ``td`` (TokenDrop
    input corruption) and ``scaler`` (AMP GradScaler).

    Args:
        model: called as ``model(images, tokens, padding_mask=...)``; returns
            per-token vocabulary scores of shape (B, L, V).
        dataloader: iterable of ``(images, captions)`` batches with raw caption strings.
        optimizer: optimiser stepping ``model``'s parameters.
        loss_fn: expected to use ``reduction="none"`` so the padding mask can be applied.
        DEVICE: torch device (or anything ``torch.device`` accepts).

    Returns:
        float: average of the per-batch losses (0.0 for an empty loader).
    """
    device_type = torch.device(DEVICE).type
    model.train()
    running_loss = 0.0
    num_batches = 0

    for images, captions in tqdm(dataloader, desc="Training", leave=False):
        images = images.to(DEVICE)

        # Tokenize and pre-process the captions
        tokens = tokenizer(captions, padding=True, truncation=True, return_tensors="pt")
        token_ids = tokens['input_ids'].to(DEVICE)
        padding_mask = tokens['attention_mask'].to(DEVICE)
        b = token_ids.shape[0]

        # Next-token targets: shift the ids LEFT by one and fill the last slot with 0.
        target_ids = torch.cat((token_ids[:, 1:],
                                torch.zeros(b, 1, dtype=torch.long, device=DEVICE)), 1)

        # Randomly blank some input tokens (training-time corruption only).
        tokens_in = td(token_ids)
        # Mixed precision on CUDA only (matches the old torch.cuda.amp.autocast behaviour).
        with torch.autocast(device_type=device_type, enabled=device_type == "cuda"):
            pred = model(images, tokens_in, padding_mask=padding_mask)
            loss = (loss_fn(pred.transpose(1, 2), target_ids) * padding_mask).mean()

        # Backpropagation with gradient scaling
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        num_batches += 1

    return running_loss / max(num_batches, 1)

def validate_epoch(model, dataloader, loss_fn, DEVICE):
    """Evaluate ``model`` on ``dataloader`` and return ``[mean_loss, bleu]``.

    Relies on notebook globals ``tokenizer`` and ``bleu_metric``.

    BLEU is computed from teacher-forced argmax predictions (each position sees
    the ground-truth prefix). It is an optimistic proxy for free-running generation.

    Args:
        model: called as ``model(images, tokens, padding_mask=...)``; returns (B, L, V) scores.
        dataloader: iterable of ``(images, captions)`` batches with raw caption strings.
        loss_fn: expected to use ``reduction="none"`` so the padding mask can be applied.
        DEVICE: torch device (or anything ``torch.device`` accepts).

    Returns:
        ``[epoch_loss, bleu]``: mean per-batch loss (float; 0.0 for an empty loader)
        and the BLEU value returned by ``bleu_metric.compute()``.
    """
    device_type = torch.device(DEVICE).type
    model.eval()
    total_val_loss = 0.0
    num_batches = 0
    bleu_metric.reset()  # Reset BLEU metric before validation

    with torch.no_grad():
        for images, captions in tqdm(dataloader, desc="Validation", leave=False):
            images = images.to(DEVICE)
            tokens = tokenizer(captions, padding=True, truncation=True, return_tensors="pt")
            token_ids = tokens['input_ids'].to(DEVICE)
            padding_mask = tokens['attention_mask'].to(DEVICE)
            b = token_ids.shape[0]

            target_ids = torch.cat((token_ids[:, 1:],
                                    torch.zeros(b, 1, dtype=torch.long, device=DEVICE)), 1)

            # Feed the clean token ids. Random TokenDrop here would make the validation
            # loss/BLEU noisy and checkpoint selection essentially random.
            with torch.autocast(device_type=device_type, enabled=device_type == "cuda"):
                pred = model(images, token_ids, padding_mask=padding_mask)

            # Compute the loss in float32; pred may be float16 under autocast.
            val_loss = (loss_fn(pred.float().transpose(1, 2), target_ids) * padding_mask).mean()
            total_val_loss += val_loss.item()
            num_batches += 1

            # Decode predictions and targets
            pred_ids = torch.argmax(pred, dim=2)
            pred_texts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            target_texts = tokenizer.batch_decode(target_ids, skip_special_tokens=True)

            # BLEUScore expects one *list* of reference strings per prediction.
            bleu_metric.update(pred_texts, [[text] for text in target_texts])

    epoch_loss = total_val_loss / max(num_batches, 1)
    avg_bleu_score = bleu_metric.compute()

    return [epoch_loss, avg_bleu_score]


# Main loop: train, validate, adjust LR, checkpoint on best val loss, log to W&B.
try:
    for epoch in trange(0, epochs, leave=False, desc="Epoch"):
        start_time = time.time()

        train_loss = train_epoch(model, train_loader, optimizer, loss_fn, DEVICE)
        val_loss, val_bleu = validate_epoch(model, val_loader, loss_fn, DEVICE)
        # validate_epoch returns the BLEU metric's tensor; use a plain float for printing/logging.
        val_bleu = float(val_bleu)

        scheduler.step(val_loss)

        current_lr = optimizer.param_groups[0]['lr']
        epoch_time = time.time() - start_time
        print(f"Epoch {epoch+1}/{epochs}, Time: {epoch_time:.2f}s, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, BLEU: {val_bleu:.4f}, Learning Rate: {current_lr:.8f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Create the checkpoint directory on first save so torch.save cannot fail on it.
            checkpoint_dir = os.path.dirname(model_path)
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), model_path)
            print(f"New best checkpoint saved with val_loss: {val_loss:.4f}")

        # Log results to W&B under keys that name the metric actually computed (BLEU).
        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_bleu": val_bleu,
            "learning_rate": current_lr,
        })
finally:
    # Close the W&B run even if training is interrupted or raises.
    wandb.finish()


