In [None]:
!pip install transformers

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
from PIL import Image
import numpy as np
from transformers import BertTokenizer
from torch.nn.utils.rnn import pad_sequence

In [None]:
!pip install datasets

# Exploratory: count the datasets available on the Hugging Face Hub.
# NOTE(review): `datasets.list_datasets` was deprecated in favour of
# `huggingface_hub.list_datasets` and may be missing in newer `datasets`
# releases. Nothing visible in this notebook uses `datasets_list`, so this cell
# can be dropped if the import fails.
from datasets import list_datasets
datasets_list = list_datasets()
len(datasets_list)

In [None]:
from datasets import load_dataset

# Download Flickr30k from the Hugging Face Hub. Judging by how it is used below,
# every example sits in a single 'test' split, and a per-example 'split' column
# ('train' / 'val' / 'test') records the real partition.
dataset = load_dataset("nlphuji/flickr30k")

In [None]:
# Show which splits the loaded dataset provides.
print(dataset.keys())

In [None]:
from datasets import DatasetDict

# The Hub release stores every example in one 'test' split and records the real
# partition in a 'split' column. Rebuild proper train/validation/test splits.
#
# input_columns='split' hands the predicate only that column's value, so the
# (large) image column is never decoded while filtering. The split value is
# bound as a lambda default so each filter function is distinct, which keeps
# the `datasets` cache fingerprints of the three splits from colliding.
dataset_dict = DatasetDict({
    split_name: dataset['test'].filter(
        lambda split, wanted=split_value: split == wanted,
        input_columns='split',
    )
    for split_name, split_value in [('train', 'train'), ('validation', 'val'), ('test', 'test')]
})

In [None]:
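# The Transformer building blocks below are adapted from https://www.datacamp.com/tutorial/building-a-transformer-with-py-torch
# (modified for image captioning: the encoder input is a linear projection of image patches instead of a token embedding).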

In [None]:
# Use the GPU when PyTorch can see one. The model and all batches are moved
# onto this device further down.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device", device)

In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings (Vaswani et al., 2017) to its input.

    Args:
        d_model: Feature dimension of the inputs. It may be odd; the last
            (cosine) frequency is then simply dropped.
        max_seq_length: Longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # One frequency per (sin, cos) column pair: 1 / 10000^(2i / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        # Even columns hold sines and odd columns cosines. With an odd d_model
        # there are ceil(d/2) even columns but only floor(d/2) odd ones, so the
        # cosine half must use one fewer frequency (the original crashed here).
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # A buffer rather than a parameter: it follows .to(device) and is saved in
        # the state_dict, but is not trained. Shape (1, max_seq_length, d_model).
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dim 1 is the sequence axis and whose last dim is
                ``d_model``, e.g. (batch, seq_len, d_model).

        Raises:
            ValueError: if the sequence is longer than ``max_seq_length``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (Vaswani et al., 2017).

    Queries, keys and values get separate linear projections. They are split
    into ``num_heads`` heads of size ``d_k = d_model // num_heads``, each head
    attends independently, and the heads are concatenated and passed through
    an output projection.

    Args:
        d_model: Dimensionality of the inputs and outputs.
        num_heads: Number of attention heads; must divide ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model  # Model's dimension
        self.num_heads = num_heads  # Number of attention heads
        self.d_k = d_model // num_heads  # Dimension of each head's key, query, and value

        # Linear layers for transforming inputs (the attribute names are
        # state_dict keys, so keep them stable for saved checkpoints).
        self.W_q = nn.Linear(d_model, d_model)  # Query transformation
        self.W_k = nn.Linear(d_model, d_model)  # Key transformation
        self.W_v = nn.Linear(d_model, d_model)  # Value transformation
        self.W_o = nn.Linear(d_model, d_model)  # Output transformation

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V over the last two dimensions.

        Args:
            Q, K, V: Per-head tensors as produced by ``split_heads``,
                (batch, num_heads, seq_len, d_k). K and V share seq_len.
            mask: Optional tensor broadcastable to
                (batch, num_heads, q_len, k_len). Positions where it is 0/False
                get (effectively) zero attention weight.
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Fill with the most negative finite value of the score dtype. A fixed
            # -1e9 is not representable in float16 (masked_fill raises an
            # overflow error), whereas finfo.min is valid for every float dtype.
            # It stays finite, so a fully masked row yields a uniform
            # distribution instead of NaN.
            attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)

        attn_probs = torch.softmax(attn_scores, dim=-1)
        return torch.matmul(attn_probs, V)

    def split_heads(self, x):
        """Reshape (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)."""
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of ``split_heads``: (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)."""
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Attend from queries ``Q`` to keys ``K`` / values ``V``.

        All three are (batch, seq_len, d_model). Q's seq_len may differ from
        K/V's, as in cross-attention. Returns (batch, q_len, d_model).
        """
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)

        return self.W_o(self.combine_heads(attn_output))

In [None]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every sequence position.

    Expands ``d_model`` -> ``d_ff``, applies ReLU, then projects back to
    ``d_model``.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # These attribute names are the state_dict keys; renaming them would
        # break loading of saved checkpoints.
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [None]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    It has two sub-layers, self-attention and a position-wise feed-forward
    network. Each is followed by dropout, a residual connection and LayerNorm.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Sub-layer 1: self-attention (queries, keys and values all come from x).
        attended = self.self_attn(x, x, x, mask)
        hidden = self.norm1(x + self.dropout(attended))

        # Sub-layer 2: position-wise feed-forward.
        transformed = self.feed_forward(hidden)
        return self.norm2(hidden + self.dropout(transformed))

In [None]:
class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    It has three sub-layers, each followed by dropout, a residual connection
    and LayerNorm:
    - masked self-attention over the target sequence;
    - cross-attention from the target to the encoder output;
    - a position-wise feed-forward network.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        # Masked self-attention: tgt_mask limits which target positions each
        # target position may see.
        self_attended = self.self_attn(x, x, x, tgt_mask)
        h = self.norm1(x + self.dropout(self_attended))

        # Cross-attention: queries from the decoder, keys/values from the encoder.
        cross_attended = self.cross_attn(h, enc_output, enc_output, src_mask)
        h = self.norm2(h + self.dropout(cross_attended))

        return self.norm3(h + self.dropout(self.feed_forward(h)))

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer for image captioning.

    The encoder consumes a sequence of flattened image patches, each linearly
    projected from ``patch_dim`` to ``d_model``. The decoder consumes caption
    token ids and returns vocabulary logits for every caption position.

    Args:
        tgt_vocab_size: Size of the caption vocabulary.
        d_model: Hidden size used throughout the model.
        num_heads: Attention heads per attention block.
        num_layers: Number of encoder layers, and separately of decoder layers.
        d_ff: Inner size of the feed-forward sub-layers.
        max_seq_length: Longest sequence the positional encoding supports. The
            table is shared by the patch sequence and the caption sequence, so
            it must cover the longer of the two.
        dropout: Dropout probability.
        patch_dim: Length of one flattened input patch (default 768 =
            16 * 16 * 3, i.e. 16x16 RGB patches).
        pad_token_id: Caption token id treated as padding in the decoder mask.
    """

    def __init__(self, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout,
                 patch_dim=768, pad_token_id=0):
        super(Transformer, self).__init__()
        # Plain attribute, not a parameter, so state_dict keys are unchanged.
        self.pad_token_id = pad_token_id
        # Acts as the encoder's "embedding": flattened patch -> d_model.
        self.patch_projection = nn.Linear(patch_dim, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        """Build attention masks for ``src`` (batch, src_len, patch_dim) and ``tgt`` (batch, tgt_len).

        Returns:
            src_mask: all-True bool tensor (batch, 1, 1, src_len). Every patch
                is real input, so nothing is hidden.
            tgt_mask: bool tensor (batch, 1, tgt_len, tgt_len). It combines a
                causal mask (position i sees positions <= i) with a mask on
                padded query positions.
        """
        batch_size, src_len, _ = src.size()
        tgt_len = tgt.size(1)
        # Allocate directly on the input's device instead of building on CPU and copying.
        src_mask = torch.ones(batch_size, 1, 1, src_len, dtype=torch.bool, device=src.device)
        # (batch, 1, tgt_len, 1): hides padded query rows, broadcast over key positions.
        tgt_mask = (tgt != self.pad_token_id).unsqueeze(1).unsqueeze(3)
        # Lower triangle including the diagonal; equivalent to 1 - triu(ones, diagonal=1).
        nopeak_mask = torch.tril(torch.ones(1, tgt_len, tgt_len, device=tgt.device)).bool()
        return src_mask, tgt_mask & nopeak_mask

    def forward(self, src, tgt):
        """Return logits of shape (batch, tgt_len, tgt_vocab_size).

        Args:
            src: Image patches, (batch, num_patches, patch_dim).
            tgt: Decoder input token ids, (batch, tgt_len).
        """
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_projected = self.patch_projection(src)  # Project the image patches to d_model
        src_embedded = self.dropout(self.positional_encoding(src_projected))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        return self.fc(dec_output)

In [None]:
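# Legacy loader that takes an image *file path*. The dataset pipeline now uses
# process_image (which takes a PIL image), but this function is still used
# further down to caption local image files.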

def old_process_image(image_path, target_size=(256, 256), patch_size=16):
    """Load an image file and split it into flattened, [0, 1]-normalized square patches.

    Args:
        image_path: Path to an image file readable by PIL, in any mode.
        target_size: ``(width, height)`` the image is resized to, following
            PIL's ``Image.resize`` convention. Both must be divisible by
            ``patch_size``.
        patch_size: Side length in pixels of each square patch
            (16 -> 16x16-pixel patches).

    Returns:
        float32 tensor on ``device`` with shape
        ``(num_patches, patch_size * patch_size * 3)``. Patches are in
        row-major (left-to-right, top-to-bottom) order.

    Raises:
        ValueError: if ``target_size`` is not a multiple of ``patch_size``.
    """
    width, height = target_size
    if width % patch_size or height % patch_size:
        raise ValueError(
            f"target_size {target_size} must be divisible by patch_size {patch_size}"
        )

    # Convert to RGB so PNGs with alpha (RGBA), grayscale and palette images all
    # produce 3 channels (16*16*3 = 768 values per patch with the defaults).
    # The context manager closes the file once the pixels have been read.
    with Image.open(image_path) as image:
        resized_image = image.convert("RGB").resize(target_size)
    image_array = np.array(resized_image)  # (height, width, 3)

    # numpy arrays are (height, width, channels) while PIL sizes are
    # (width, height), so the number of patch rows comes from the height.
    rows, cols = height // patch_size, width // patch_size
    patches = (
        image_array.reshape(rows, patch_size, cols, patch_size, -1)
        .swapaxes(1, 2)  # -> (rows, cols, patch_size, patch_size, channels)
        .reshape(rows * cols, -1)
    )

    normalized_patches = patches / 255.0
    return torch.tensor(normalized_patches, dtype=torch.float32).to(device)

# # Example usage
# image_path = "./flashcard.jpg"
# processed_patches = old_process_image(image_path)

# print(f"Original image shape: {np.array(Image.open(image_path)).shape}")
# print(f"Processed patches shape: {processed_patches.shape}")
# print(f"Processed patches: {processed_patches}")

In [None]:
def process_image(image, target_size=(256, 256), patch_size=16):
    """Split a PIL image into flattened, [0, 1]-normalized square patches.

    Args:
        image: A PIL image in any mode. It is converted to RGB first, so every
            patch has 3 channels (16*16*3 = 768 values with the defaults).
        target_size: ``(width, height)`` the image is resized to, following
            PIL's ``Image.resize`` convention. Both must be divisible by
            ``patch_size``.
        patch_size: Side length in pixels of each square patch.

    Returns:
        float32 tensor on ``device`` with shape
        ``(num_patches, patch_size * patch_size * 3)``. Patches are in
        row-major (left-to-right, top-to-bottom) order.

    Raises:
        ValueError: if ``target_size`` is not a multiple of ``patch_size``.
    """
    width, height = target_size
    if width % patch_size or height % patch_size:
        raise ValueError(
            f"target_size {target_size} must be divisible by patch_size {patch_size}"
        )

    # Grayscale, palette or RGBA inputs would otherwise give 1 or 4 channels and
    # a patch length the model's patch projection cannot accept.
    image_array = np.array(image.convert("RGB").resize(target_size))

    # numpy arrays are (height, width, channels) while PIL sizes are
    # (width, height), so the number of patch rows comes from the height. The
    # original used target_size[0] here, which silently scrambled patches for
    # non-square sizes.
    rows, cols = height // patch_size, width // patch_size
    patches = (
        image_array.reshape(rows, patch_size, cols, patch_size, -1)
        .swapaxes(1, 2)  # -> (rows, cols, patch_size, patch_size, channels)
        .reshape(rows * cols, -1)
    )

    normalized_patches = patches / 255.0
    return torch.tensor(normalized_patches, dtype=torch.float32).to(device)

In [None]:
# Load the pre-trained BERT tokenizer.
# Its vocabulary size becomes the decoder's tgt_vocab_size. Its special tokens
# are reused for captioning: [CLS] as the start token, [SEP] as the end token,
# and its pad id to pad caption inputs.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
def preprocess_dataset(example):
    """Convert one raw dataset row into model-ready fields.

    Only the first of the row's captions is used. The decoder input is the
    caption prefixed with the start token ([CLS]). The target is the caption
    followed by the end token ([SEP]), i.e. the input shifted left by one
    position.

    Args:
        example: A single (unbatched) row with an 'image' (passed to
            ``process_image``) and a list of 'caption' strings.

    Returns:
        Dict with 'image_patches', 'caption_input_ids' and 'caption_label_ids'.
    """
    first_caption = example['caption'][0]
    patches = process_image(example['image'])

    wordpiece_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(first_caption))
    start_id = tokenizer.convert_tokens_to_ids(tokenizer.cls_token)
    end_id = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)

    return {
        'image_patches': patches,
        'caption_input_ids': [start_id] + wordpiece_ids,
        'caption_label_ids': wordpiece_ids + [end_id],
    }

# Preprocess every split one row at a time, dropping the raw columns so only
# the three model fields remain.
# NOTE(review): with the default 256x256 / 16-px settings each row stores
# 256 patches x 768 floats (~0.75 MB per row if cached as float32 — confirm).
# Over the whole dataset the on-disk cache is therefore very large. Consider
# computing patches lazily (e.g. Dataset.with_transform) if disk is tight.
processed_dataset = dataset_dict.map(preprocess_dataset, batched=False, remove_columns=dataset_dict['train'].column_names)

In [None]:
class ImageCaptionDataset(data.Dataset):
    """Map-style dataset over parallel lists of image file paths and caption strings.

    Each item is the tuple ``(image_patches, caption_input_ids,
    caption_label_ids)``, with all three tensors on ``device``. The ids follow
    the same [CLS]-prefixed input / [SEP]-suffixed label scheme as the
    Hugging Face preprocessing.
    NOTE(review): items are tuples, while ``collate_fn`` indexes items by key;
    wrap the items in dicts before using that collate function with this dataset.

    Args:
        image_paths: Sequence of image file paths.
        captions: Sequence of caption strings, aligned with ``image_paths``.
        tokenizer: Tokenizer providing ``tokenize``, ``convert_tokens_to_ids``,
            ``cls_token`` and ``sep_token``.
        patch_size: Patch side length passed to ``process_image``.
        target_size: ``(width, height)`` passed to ``process_image``.
    """

    def __init__(self, image_paths, captions, tokenizer, patch_size=16, target_size=(256, 256)):
        self.image_paths = image_paths
        self.captions = captions
        self.tokenizer = tokenizer
        self.patch_size = patch_size
        self.target_size = target_size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        caption = self.captions[index]

        # process_image expects a PIL image, not a path (passing the path made
        # it fail on `str.resize`). Open the file, convert to RGB, and close it
        # once the patches are computed.
        with Image.open(image_path) as image:
            image_patches = process_image(image.convert("RGB"), self.target_size, self.patch_size)

        caption_tokens = self.tokenizer.tokenize(caption)
        # Start/end tokens come from the tokenizer rather than being hard-coded.
        caption_input = [self.tokenizer.cls_token] + caption_tokens
        caption_label = caption_tokens + [self.tokenizer.sep_token]
        caption_input_ids = self.tokenizer.convert_tokens_to_ids(caption_input)
        caption_label_ids = self.tokenizer.convert_tokens_to_ids(caption_label)

        # Put both id tensors on the same device as the patches. Previously
        # the inputs stayed on the CPU while the labels went to `device`.
        return (
            image_patches,
            torch.tensor(caption_input_ids, dtype=torch.long, device=device),
            torch.tensor(caption_label_ids, dtype=torch.long, device=device),
        )

In [None]:
# Example usage
# image_paths = ["./flashcard.jpg", "./flashcard.jpg", "./flashcard.jpg", "./not-a-flashcard.jpg", "./not-a-flashcard.jpg", "./not-a-flashcard.jpg"]
# captions = ["this is a flashcard", "this is a flashcard", "this is a flashcard", "this is not a flashcard", "this is not a flashcard", "this is not a flashcard"]

def collate_fn(batch):
    """Collate preprocessed examples into padded batch tensors on ``device``.

    Args:
        batch: List of dicts with keys 'image_patches' (a num_patches x
            patch_dim nested list or tensor), 'caption_input_ids' and
            'caption_label_ids' (lists or tensors of token ids).

    Returns:
        Tuple ``(image_patches, caption_inputs, caption_labels)``:
        - float32 (batch, num_patches, patch_dim);
        - int64 (batch, max_len), padded with ``tokenizer.pad_token_id``;
        - int64 (batch, max_len), padded with -100, the default
          ``ignore_index`` of ``nn.CrossEntropyLoss``, so padding adds no loss.
    """
    # torch.as_tensor avoids the extra copy and UserWarning that torch.tensor()
    # produces when items are already tensors (e.g. a torch-formatted dataset).
    # Explicit dtypes keep token ids integral regardless of the input type.
    image_patches = torch.stack(
        [torch.as_tensor(item['image_patches'], dtype=torch.float32) for item in batch]
    )
    caption_inputs = pad_sequence(
        [torch.as_tensor(item['caption_input_ids'], dtype=torch.long) for item in batch],
        batch_first=True,
        padding_value=tokenizer.pad_token_id,
    )
    caption_labels = pad_sequence(
        [torch.as_tensor(item['caption_label_ids'], dtype=torch.long) for item in batch],
        batch_first=True,
        padding_value=-100,  # -100 is the ignore index for CrossEntropyLoss
    )

    return image_patches.to(device), caption_inputs.to(device), caption_labels.to(device)


In [None]:
from torch.utils.data import DataLoader

# Settings shared by all three loaders. Only the training split is shuffled;
# validation and test keep a fixed order so their metrics are comparable
# between runs.
_loader_kwargs = dict(batch_size=32, collate_fn=collate_fn)
train_dataloader = DataLoader(processed_dataset['train'], shuffle=True, **_loader_kwargs)
val_dataloader = DataLoader(processed_dataset['validation'], shuffle=False, **_loader_kwargs)
test_dataloader = DataLoader(processed_dataset['test'], shuffle=False, **_loader_kwargs)

In [None]:
# DON'T RUN THIS ANYMORE:

# # Create the dataset and data loader
# old_dataset = ImageCaptionDataset(image_paths, captions, tokenizer)
# # data_loader = data.DataLoader(dataset, batch_size=2, shuffle=True)
# # use this one when I need to pad:
# old_data_loader = data.DataLoader(old_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

In [None]:
# testing the stuff that's in the Dataset:
# image_patches_1 = process_image("./flashcard.jpg")
# image_patches_2 = process_image("./not-a-flashcard.jpg")
# print("image_patches_2",image_patches_2)
# ^ that's all good!

# caption = "this is a flashcard"
# caption_tokens = tokenizer.tokenize(caption)
# print("caption_tokens", caption_tokens)
# caption_ids = tokenizer.encode(caption, add_special_tokens=True)
# print("caption_ids", caption_ids)
# ^ that's all good!

In [None]:
# Define the model parameters
tgt_vocab_size = tokenizer.vocab_size  # the decoder predicts over the tokenizer's vocabulary
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
# One positional-encoding table is shared by the encoder (image patches) and the
# decoder (caption tokens), so it must cover the longer of the two sequences.
# With 256x256 images and 16x16 patches there are (256 // 16) ** 2 = 256 patches.
max_seq_length = (256 // 16) ** 2
dropout = 0.1

print("tgt_vocab_size", tgt_vocab_size)

In [None]:
# Create the transformer model on the selected device
model = Transformer(
    tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout
).to(device)

# collate_fn pads label positions with -100. That is CrossEntropyLoss's
# ignore_index (also its default), so padding does not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [None]:
from tqdm import tqdm

# Training loop (teacher forcing), followed by a validation pass each epoch.
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    progress_bar = tqdm(train_dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}]", unit="batch")

    total_loss = 0.0
    num_batches = 0

    for image_patches, caption_inputs, caption_labels in progress_bar:
        # collate_fn already returns tensors on `device`, which makes these
        # no-ops. They keep the loop correct for other collate functions.
        image_patches = image_patches.float().to(device)
        caption_inputs = caption_inputs.to(device)
        caption_labels = caption_labels.to(device)

        # Clear gradients before the forward pass, so gradients left behind by an
        # interrupted earlier run of this cell are not accumulated.
        optimizer.zero_grad()

        # Logits (batch, seq_len, vocab) for every caption position.
        output = model(image_patches, caption_inputs)
        loss = criterion(output.reshape(-1, tgt_vocab_size), caption_labels.reshape(-1))
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        progress_bar.set_postfix(loss=loss_value)

    # max(..., 1) guards against an empty loader (ZeroDivisionError).
    avg_loss = total_loss / max(num_batches, 1)

    # Validation loss: the validation split was otherwise never evaluated,
    # leaving no signal for over-fitting.
    model.eval()
    val_loss = 0.0
    val_batches = 0
    with torch.no_grad():
        for image_patches, caption_inputs, caption_labels in val_dataloader:
            output = model(image_patches.float().to(device), caption_inputs.to(device))
            val_loss += criterion(
                output.reshape(-1, tgt_vocab_size), caption_labels.to(device).reshape(-1)
            ).item()
            val_batches += 1
    avg_val_loss = val_loss / max(val_batches, 1)

    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

In [None]:
# Save only the learned weights (state_dict). Reloading requires building a
# Transformer with the same hyperparameters first.
torch.save(model.state_dict(), 'trained_model.pth')

In [None]:
# Rebuild the model with the same hyperparameters used for training. A
# state_dict stores only tensors, not the architecture.
trained_model = Transformer(tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout).to(device)

# map_location places every saved tensor on the current `device`. Without it, a
# checkpoint written on a GPU fails to load on a CPU-only machine.
trained_model.load_state_dict(torch.load('trained_model.pth', map_location=device))

# Evaluation mode disables dropout for inference
trained_model.eval()

In [None]:
# Candidate local test images. Only the one passed to old_process_image below
# is captioned; change that argument to try another.
input_image_path1 = './dog.jpeg'
input_image_path2 = './2men.jpeg'
input_image_path3 = './woman.jpeg'
input_image_path4 = './cat.jpeg'
input_image_path5 = './cat2.jpeg'
input_image_path6 = './dog2.jpeg'
input_image_path7 = './girl.jpeg'
input_image_path8 = './blue.jpeg'
input_image_path9 = './red.jpeg'
input_image_path10 = './dog3.jpeg'
input_image_path11 = './dog4.jpeg'
input_image_path12 = './dogg.png'

# Patchify the image, add a leading batch dimension of 1 and place the result
# on the model's device in a single chain.
input_image_patches = old_process_image(input_image_path12).unsqueeze(0).to(device)

print("input_image_patches", input_image_patches)

In [None]:
# Greedy decoding: feed the tokens generated so far, append the most likely
# next token, and stop at [SEP] (end of caption) or at the length limit.
max_caption_length = 20

# Start with [CLS], the start token every training input begins with
generated_caption = [tokenizer.cls_token_id]

with torch.no_grad():
    for _ in range(max_caption_length):
        caption_tensor = torch.tensor(generated_caption, dtype=torch.long, device=device).unsqueeze(0)

        # The image and the whole prefix are re-encoded each step (no caching);
        # acceptable for captions this short.
        output = trained_model(input_image_patches, caption_tensor)
        # Logits at the last position give the next token
        next_token_id = output[0, -1].argmax(dim=-1).item()
        generated_caption.append(next_token_id)

        if next_token_id == tokenizer.sep_token_id:
            break

# Drop the leading [CLS], and drop the final token only if it really is [SEP].
# When the length limit is reached without [SEP], the last token is a real word,
# which the previous `[1:-1]` slice discarded.
caption_ids = generated_caption[1:]
if caption_ids and caption_ids[-1] == tokenizer.sep_token_id:
    caption_ids = caption_ids[:-1]
generated_text = tokenizer.decode(caption_ids)

print("Generated Caption:", generated_text)