## Module 4 Project 1: CLIP
- Implement a simple [CLIP](https://openai.com/research/clip) model to start working with image-text embeddings
- Train the model on some data and evaluate the results


In [None]:
!pip install datasets --quiet

## STEP 1: IMPORTS
- We need to import `torch` and `transformers` for most of the utility methods needed here, as well as our tokenizer and dataloader classes
- We want `matplotlib` and `numpy` for displaying the images after training when queried with a caption string

In [None]:
import torch
import matplotlib.pyplot as plt
from transformers import AutoModel, AutoTokenizer, BertTokenizer
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torch.utils.data import DataLoader
from datasets import load_dataset
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

## STEP 2: HYPERPARAMETERS
- We will be using [DistilBERT](https://huggingface.co/docs/transformers/model_doc/distilbert) for our language model in this example
- We will be using a batch size of 128 and a max sequence length of 32 for 3 epochs
- The shared image/text embedding dimension is 512, and the text transformer's hidden size is the standard 768

In [None]:
# Hyperparameters -- module-level settings read by the cells below
batch_size = 128                                    # image-caption items per DataLoader batch
text_model = 'distilbert-base-multilingual-cased'   # Hugging Face id for the text backbone and its tokenizer
transformer_embed_dim = 768                         # width of the text backbone's hidden states (projection input)
embed_dim = 512                                     # size of the shared image/text embedding space
max_len = 32                                        # tokenizer truncation length, in tokens
num_epochs = 3                                      # passes over the training data

## STEP 3: DATASET
- Now we can get our dataset to train the model
- We will be using the [Flickr30kDataset](https://datasets.activeloop.ai/docs/ml/datasets/flickr30k-dataset/) and resizing images to 224 x 224px
- The data is a collection of images that depict a wide range of activities, each paired with descriptive captions, which makes it a good benchmark for this task (sentence descriptions of images)
- Finally, we wrap our dataset object with a torch DataLoader to have easy batched training iterations

In [None]:
# Wrapper class for our dataset - Flickr30k
class Flickr30kDataset(torch.utils.data.Dataset):
    """Flickr30k image-caption pairs, flattened to one item per (image, caption).

    Each image contributes ``cap_per_image`` consecutive items: index ``i``
    maps to image ``i // cap_per_image`` of the ``"test"`` split and to that
    image's caption number ``i % cap_per_image``.

    Args:
        cap_per_image: how many captions per image to expose (default 2,
            the previously hard-coded value). Must be >= 1 and no larger
            than the number of captions stored per image.

    Raises:
        ValueError: if ``cap_per_image`` is smaller than 1.
    """

    def __init__(self, cap_per_image: int = 2):
        if cap_per_image < 1:
            raise ValueError(f"cap_per_image must be >= 1, got {cap_per_image}")

        self.dataset = load_dataset("nlphuji/flickr30k", cache_dir="./huggingface_data")

        # We resize images to 224 x 224 and convert them to float tensors
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.cap_per_image = cap_per_image

    def __len__(self):
        return self.dataset.num_rows["test"] * self.cap_per_image

    def __getitem__(self, idx):
        """Return ``{"image": transformed tensor, "caption": str}`` for flat index ``idx``."""
        image_idx, caption_idx = divmod(idx, self.cap_per_image)

        # Look the row up once and reuse it for both fields; the original
        # indexed the split twice, paying for a second row fetch (and, with
        # the `datasets` Image feature, a second image decode).
        row = self.dataset["test"][image_idx]

        image = self.transform(row["image"].convert("RGB"))
        caption = row["caption"][caption_idx]

        return {"image": image, "caption": caption}

# Build the dataset once, then wrap it in a shuffling, batched DataLoader
# that the training loop iterates over.
flickr30k_custom_dataset = Flickr30kDataset()
clip_dataloader = DataLoader(
    dataset=flickr30k_custom_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

## STEP 4: PROJECTION LAYER
- Here we create our projection layer, which projects the outputs of the text and image encoders into vectors of size `embed_dim`
- Once this is done, we will have text and images existing in the same high dimensional space, and we can use this knowledge to compare and contrast them based on semantic meaning

In [None]:
# CLIP uses a projection layer to cast text and images into the same dimensions
class ProjectionLayer(nn.Module):
    """Map ``d_in``-dim features into the shared ``d_out``-dim space.

    The input is projected linearly to ``d_out``. A GELU -> linear -> dropout
    branch is added back to that projection as a residual, and the sum is
    layer-normalised.

    Args:
        d_in: input feature size.
        d_out: output (shared embedding) size.
        p: dropout probability on the residual branch.
    """

    def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
        super().__init__()
        # Attribute names and registration order are kept stable: they define
        # the state_dict keys and the order in which weights are initialised.
        self.linear1 = nn.Linear(d_in, d_out, bias=False)
        self.linear2 = nn.Linear(d_out, d_out, bias=False)
        self.layer_norm = nn.LayerNorm(d_out)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the layer-normalised residual projection of ``x``."""
        projected = self.linear1(x)
        residual = self.drop(self.linear2(F.gelu(projected)))
        return self.layer_norm(projected + residual)

## STEP 5: IMAGE ENCODER
- Now that our projection layer is done, we can build the image encoder layer
- This layer takes in an image and returns the projected and normalized form in high dimensional vector space
- Our base model here is [ResNet34](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet34.html)

In [None]:
# Module to encode our images using ResNet34
class ImageEncoder(nn.Module):
    """Encode images into unit-length ``d_out``-dim vectors with a frozen ResNet34.

    The ResNet34 classifier head is replaced by the identity, so its pooled
    features go through a trainable ``ProjectionLayer`` and are then
    L2-normalised.

    Args:
        d_out: size of the output embedding.
        normalize_input: if True (default), ``forward`` applies ImageNet
            mean/std normalisation. torchvision's pretrained weights expect
            normalised inputs, while this notebook's dataset transform only
            applies ``Resize`` + ``ToTensor`` (values in [0, 1]). Pass False if
            the images are already normalised.
    """

    # ImageNet channel statistics used by torchvision's pretrained weights.
    _IMAGENET_MEAN = (0.485, 0.456, 0.406)
    _IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, d_out: int, normalize_input: bool = True) -> None:
        super().__init__()
        # `pretrained=True` is deprecated in favour of the weights enum
        # (torchvision >= 0.13); fall back to it only on older releases.
        weights_enum = getattr(models, "ResNet34_Weights", None)
        if weights_enum is not None:
            base = models.resnet34(weights=weights_enum.DEFAULT)
        else:
            base = models.resnet34(pretrained=True)
        d_in = base.fc.in_features

        # Set the fully connected layer to be the identity function - we want our image embeddings returned
        base.fc = nn.Identity()
        self.base = base

        # Projection layer is the only trainable part; the backbone is frozen.
        self.projection = ProjectionLayer(d_in, d_out)
        for p in self.base.parameters():
            p.requires_grad = False

        self.normalize_input = normalize_input
        # Non-persistent buffers follow .to(device) but stay out of state_dict.
        self.register_buffer(
            "pixel_mean", torch.tensor(self._IMAGENET_MEAN).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "pixel_std", torch.tensor(self._IMAGENET_STD).view(1, 3, 1, 1), persistent=False
        )

    def train(self, mode: bool = True) -> "ImageEncoder":
        """Set train/eval mode but keep the frozen backbone in eval mode.

        ``requires_grad=False`` does not stop BatchNorm layers from updating
        their running mean/variance in train mode. Pinning the backbone to
        eval keeps the pretrained statistics intact while training.
        """
        super().train(mode)
        self.base.eval()
        return self

    # Return after the image model embedding pass and the projection layer
    def forward(self, x):
        """Return L2-normalised embeddings of shape ``(batch, d_out)`` for image batch ``x``."""
        if self.normalize_input:
            x = (x - self.pixel_mean) / self.pixel_std
        projected_vec = self.projection(self.base(x))
        return F.normalize(projected_vec, dim=-1)

## STEP 6: TEXT ENCODER
- Now, we create our text encoder layer similar to the image encoder above
- We take the results of the base model, project to `embed_dim` dimensions, and return
- These two methods return similarly formatted vectors of the same size for comparison

In [None]:
# Module for encoding text into high dimensional space with our images, done above
class TextEncoder(nn.Module):
    """Encode token ids into unit-length ``d_out``-dim vectors with a frozen text model.

    The hidden state at sequence position 0 (the [CLS] slot for BERT-family
    tokenizers such as DistilBERT's) goes through a trainable
    ``ProjectionLayer`` and is then L2-normalised.

    Args:
        d_out: size of the output embedding.
    """

    def __init__(self, d_out: int) -> None:
        super().__init__()

        # Instantiate our text model (frozen below)
        self.base = AutoModel.from_pretrained(text_model)

        # And our projection layer, the only trainable part
        self.projection = ProjectionLayer(transformer_embed_dim, d_out)
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(self, x, attention_mask=None):
        """Embed a batch of token ids.

        Args:
            x: ``(batch, seq_len)`` tensor of token ids.
            attention_mask: optional ``(batch, seq_len)`` mask with 1 for real
                tokens and 0 for padding. When omitted, it is derived as
                ``x != pad_token_id`` from the model config (if one is
                defined). Batched tokenisation with ``padding=True`` appends
                pad tokens that must not be attended to.

        Returns:
            ``(batch, d_out)`` tensor of unit-length embeddings.
        """
        if attention_mask is None:
            pad_id = getattr(getattr(self.base, "config", None), "pad_token_id", None)
            if pad_id is not None:
                attention_mask = (x != pad_id).long()

        hidden = self.base(input_ids=x, attention_mask=attention_mask)[0]
        first_token_state = hidden[:, 0, :]
        return F.normalize(self.projection(first_token_state), dim=-1)

## STEP 7: LOSS AND SIMILARITY
- We can set our loss function and similarity metrics now
- The loss is a symmetric cross entropy over the similarity matrix: matching image-caption pairs sit on the diagonal, so the target for row (or column) `i` is index `i`
- Applying cross entropy along both axes scores matching in both directions (caption-to-image and image-to-caption), pulling matching pairs together (similar) while pushing mismatched pairs apart (dissimilar)
- Averaging the two terms gives a single loss that treats both directions equally
- Our `metrics` method works by taking the argmax of every row and every column of the similarity matrix and reporting how often it lands on the correct (diagonal) match

In [None]:
# Prefer the GPU whenever PyTorch can see one
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Define a method to calculate our loss
def CLIP_loss(logits: torch.Tensor) -> torch.Tensor:
    """Symmetric contrastive loss used by CLIP.

    Args:
        logits: square ``(n, n)`` similarity matrix whose diagonal entry
            ``[i, i]`` scores the i-th matching pair. Every off-diagonal entry
            is treated as a negative.

    Returns:
        Scalar tensor: the mean of the cross-entropy along rows and along
        columns, with the diagonal indices as targets.

    Raises:
        ValueError: if ``logits`` is not a square 2-D matrix.
    """
    if logits.dim() != 2 or logits.shape[0] != logits.shape[1]:
        raise ValueError(
            f"CLIP_loss expects a square 2-D matrix, got shape {tuple(logits.shape)}"
        )
    n = logits.shape[0]  # number of matched pairs
    # Targets must be on the logits' device: a CPU label tensor next to CUDA
    # logits makes F.cross_entropy raise a device-mismatch error.
    labels = torch.arange(n, device=logits.device)
    # Cross entropy along both axes: columns-as-rows (transpose) and rows
    loss_i = F.cross_entropy(logits.transpose(0, 1), labels, reduction="mean")
    loss_t = F.cross_entropy(logits, labels, reduction="mean")
    return (loss_i + loss_t) / 2

# Method to calculate our similarity scores
def metrics(similarity: torch.Tensor):
    """Top-1 retrieval accuracy along both axes of a square similarity matrix.

    Returns ``(img_acc, cap_acc)`` as 0-dim float tensors:
    ``img_acc`` is the fraction of rows whose argmax is the diagonal entry,
    and ``cap_acc`` is the same fraction for columns.

    NOTE(review): for a caption x image matrix (rows are captions), the row
    argmax is caption->image retrieval, so the two names read swapped.
    Confirm the intended labelling.
    """
    targets = torch.arange(len(similarity), device=similarity.device)
    row_hits = similarity.argmax(dim=1).eq(targets)
    col_hits = similarity.argmax(dim=0).eq(targets)
    return row_hits.float().mean(), col_hits.float().mean()

## STEP 8: TOKENIZER
- We now can build a wrapper for our tokenizer
- This step is not strictly necessary; it just abstracts tokenizer operations out of our main code, since we don't need to focus on them in this module
- This class wraps the initialization and calling of the tokenizer with appropriate parameters
- We will also be using the accompanying tokenizer for DistilBERT here

In [None]:
# Basic wrapper class for our tokenizer - makes life easier
class Tokenizer:
    """Callable wrapper that applies this notebook's tokenisation settings.

    Args:
        tokenizer: a Hugging Face tokenizer. This file passes the one returned
            by ``AutoTokenizer.from_pretrained(text_model)``, which is a
            tokenizer instance, so the annotation is the common base class
            rather than ``BertTokenizer``.
    """

    def __init__(self, tokenizer: "transformers.PreTrainedTokenizerBase") -> None:
        self.tokenizer = tokenizer

    def __call__(self, x: "str | list[str]") -> "transformers.BatchEncoding":
        """Tokenize one string or a list of strings into PyTorch tensors.

        Sequences are truncated to ``max_len`` tokens and padded to the longest
        sequence in the call. The returned encoding exposes ``input_ids``
        (used by the model code) alongside the tokenizer's other outputs.
        """
        return self.tokenizer(
            x, max_length=max_len, truncation=True, padding=True, return_tensors="pt"
        )

## STEP 9: BUILD CLIP
- Once everything above has been completed, we can finish our CLIP implementation
- We initialize our image and text encoders with `embed_dim` as the output dimension
- We create our wrapped tokenizer class and set our learning rate to 1e-3
- For the `forward` call, we tokenize the text, embed our image with the encoder, embed the caption text with our text encoder, and then multiply them to get our similarity matrix.
- This matrix is then passed into `CLIP_loss` and `metrics` for similarity scoring, we return the loss and scores

In [None]:
# Now we build our CLIP model
class NightwingCLIP(nn.Module):
    """Minimal CLIP: image and text encoders trained with a contrastive loss.

    Args:
        lr: learning rate stored on the model for the training loop to read.
        temperature: positive scalar that divides the similarity matrix before
            the loss. Both encoders return unit-length vectors, so raw
            similarities lie in [-1, 1]. Without scaling, the softmax over a
            batch stays close to uniform and the loss cannot get sharp. 0.07 is
            the initial temperature used by the original CLIP. Pass 1.0 to
            reproduce the unscaled behaviour.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """

    def __init__(self, lr: float = 1e-3, temperature: float = 0.07) -> None:
        super().__init__()
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        # Set up text and image encoders
        self.image_encoder = ImageEncoder(embed_dim)
        self.text_encoder = TextEncoder(embed_dim)

        # Set tokenizer
        self.tokenizer = Tokenizer(AutoTokenizer.from_pretrained(text_model))

        self.lr = lr
        self.temperature = temperature
        # Default device for callers that need one (e.g. training / demo code)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def forward(self, images, text):
        """Compute the contrastive loss and retrieval accuracies for one batch.

        Args:
            images: batch of image tensors; ``images[i]`` pairs with ``text[i]``.
            text: list of caption strings, one per image, in the same order.

        Returns:
            ``(loss, img_acc, cap_acc)`` from ``CLIP_loss`` and ``metrics``.
        """
        # Put token ids on the same device as the image batch rather than
        # guessing from CUDA availability.
        tokens = self.tokenizer(text).to(images.device)

        # Embed image and captions
        image_embed = self.image_encoder(images)
        caption_embed = self.text_encoder(tokens["input_ids"])

        # Similarity matrix is Caption @ Image.T (rows: captions, cols: images),
        # scaled by the temperature so the softmax can become peaked.
        similarity = caption_embed @ image_embed.T / self.temperature

        # Calculate our loss and accuracy from the above matrix
        loss = CLIP_loss(similarity)
        img_acc, cap_acc = metrics(similarity)

        # Return the loss and scores
        return loss, img_acc, cap_acc

## STEP 10: TRAINING
- For our training loop, we initialize our model and an [Adam](https://arxiv.org/abs/1412.6980) optimizer
- We then iterate through epochs, loading data and updating weights based on model's performance over the training set
- When training is complete we will have a simple, trained version of CLIP

In [None]:
# Instantiate our model on the device NightwingCLIP picked (GPU when available)
model = NightwingCLIP()
model.to(model.device)

# Adam over the trainable parameters only. Both backbones are frozen
# (requires_grad=False), so in practice only the projection layers are optimised.
optimizer = torch.optim.Adam([
    {'params': [p for p in model.text_encoder.parameters() if p.requires_grad]},
    {'params': [p for p in model.image_encoder.parameters() if p.requires_grad]},
], lr=model.lr)

batch_zero = True

# Training loop
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    epoch_img_acc = 0.0
    epoch_cap_acc = 0.0
    num_batches = 0

    for batch in clip_dataloader:
        image = batch["image"].to(model.device)
        text = batch["caption"]
        loss, img_acc, cap_acc = model(image, text)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the very first batch's loss once, as a starting baseline
        if batch_zero:
            print(f"Epoch [{0}/{num_epochs}], Batch Loss: {loss.item()}")
            batch_zero = False

        epoch_loss += loss.item()
        epoch_img_acc += img_acc.item()
        epoch_cap_acc += cap_acc.item()
        num_batches += 1

    if num_batches == 0:
        print(f"Epoch [{epoch+1}/{num_epochs}], no batches")
        continue

    # Average over the whole epoch instead of reporting only the last batch
    print(
        f"Epoch [{epoch+1}/{num_epochs}], Mean Loss: {epoch_loss / num_batches:.4f}, "
        f"Img Acc: {epoch_img_acc / num_batches:.3f}, Cap Acc: {epoch_cap_acc / num_batches:.3f}"
    )

print("Training complete.")

## STEP 11: DISPLAY MATCHING IMAGES
- We need a method to turn our images from numbers in matrices to displaying on a screen
- We also want to provide a query string and be able to receive matching images back as a demonstration
- We can do this by tokenizing the query string, embedding it with the text encoder, and comparing it against the embeddings of every image in our dataset
- Once the top indices are determined for similarity, we select and display them using `matplotlib`

In [None]:
# Method to display images in the data that match a given query - we want up to 4
def display_matching_images(model, query_string, flickr30k_dataset, clip_dataloader, top_k=4):
    """Embed ``query_string`` and plot the ``top_k`` most similar dataset images.

    Args:
        model: trained model exposing ``tokenizer``, ``text_encoder``,
            ``image_encoder`` and ``device``.
        query_string: free-text query.
        flickr30k_dataset: dataset whose items are dicts with a CHW ``"image"``
            tensor. If it has a ``cap_per_image`` attribute, each image is
            assumed to occupy that many consecutive indices and is embedded once.
        clip_dataloader: the training DataLoader. Only its ``batch_size`` and
            ``num_workers`` are reused. It is not iterated directly because it
            shuffles, which would break the mapping from embedding row back to
            dataset index.
        top_k: maximum number of images to show.
    """
    # Embed each distinct image once so the results are not duplicated.
    step = max(int(getattr(flickr30k_dataset, "cap_per_image", 1)), 1)
    image_indices = list(range(0, len(flickr30k_dataset), step))
    if not image_indices or top_k <= 0:
        return

    # Deterministic order: embedding row r corresponds to image_indices[r].
    loader = DataLoader(
        torch.utils.data.Subset(flickr30k_dataset, image_indices),
        batch_size=clip_dataloader.batch_size,
        shuffle=False,
        num_workers=clip_dataloader.num_workers,
    )

    # Switch off dropout for inference, then restore the caller's mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            tokenized_query = model.tokenizer(query_string)
            query_embedding = model.text_encoder(tokenized_query["input_ids"].to(model.device))
            image_embeddings = torch.cat(
                [model.image_encoder(batch["image"].to(model.device)) for batch in loader]
            )
            # Similarity from the (single) query to all images
            similarity_scores = (query_embedding @ image_embeddings.T)[0]
    finally:
        model.train(was_training)

    # Get top matches
    k = min(top_k, len(image_indices))
    top_positions = similarity_scores.topk(k).indices.tolist()

    # squeeze=False keeps a 2-D axes grid even when k == 1
    fig, axes = plt.subplots(1, k, figsize=(15, 5), squeeze=False)
    for ax, pos in zip(axes[0], top_positions):
        image = flickr30k_dataset[image_indices[pos]]["image"]
        ax.imshow(np.transpose(image.cpu().numpy(), (1, 2, 0)))
        ax.axis("off")
    plt.show()

## STEP 12: DEMO
- We can now run a simple demo on our CLIP implementation
- We pass in a query string of images we want to see, and run it
- The model will output similar images matching the text to the best of its ability, and we are done!

In [None]:
# Demo: show the images the trained model ranks closest to a free-text query
query = "A cat sitting outside"
display_matching_images(
    model,
    query,
    flickr30k_dataset=flickr30k_custom_dataset,
    clip_dataloader=clip_dataloader,
)