# M2177.003100 Deep Learning <br> Assignment #3 Part 4: Transformer for vision and language

Copyright (C) Data Science Laboratory, Seoul National University. This material is for educational uses only. Some contents are based on the material provided by other paper/book authors and may be copyrighted by them. Written by Seungryong Yoo, November 2021.

In this problem, we will train CLIP, which consists of an image encoder and a text encoder that are trained jointly in a self-supervised manner.<br>
The encoders are trained to maximize the similarity between each paired image and text in a common representation space.<br>
As the image encoder, we will use ViT (Vision Transformer). The underlying structure of ViT is almost the same as that of Transformers for NLP.<br>
The only difference is that the word embedding layer is replaced with a patch embedding layer.<br>
As the text encoder, we will use a pretrained model (DistilBERT).<br>
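
To make the patch embedding idea concrete, here is a minimal NumPy sketch under stated assumptions. It uses the image size (224) and patch size (32) that appear later in this notebook; the function name `image_to_patches`, the width `dim = 64`, and the random projection weights are illustrative stand-ins and not the code in "clip_modules.py".

```python
import numpy as np

def image_to_patches(image, patch_size):
    """Split a (C, H, W) image into flattened, non-overlapping patches."""
    c, h, w = image.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size
    # (C, gh, p, gw, p) -> (gh, gw, p, p, C) -> (gh * gw, p * p * C)
    patches = image.reshape(c, gh, patch_size, gw, patch_size)
    patches = patches.transpose(1, 3, 2, 4, 0)
    return patches.reshape(gh * gw, patch_size * patch_size * c)

rng = np.random.default_rng(0)
image = rng.standard_normal((3, 224, 224))        # one RGB image
patches = image_to_patches(image, patch_size=32)  # 7 x 7 = 49 patches of 32 * 32 * 3 values

# A linear layer maps each flattened patch to the model width, playing the
# role that a word embedding lookup plays for text tokens.
dim = 64  # illustrative width, not the value used in this assignment
projection = rng.standard_normal((patches.shape[1], dim)) * 0.02
tokens = patches @ projection
print(patches.shape, tokens.shape)
```

Each row of `tokens` is then treated like one word embedding in an NLP Transformer.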

Reference for Vision Transformer (ViT) (Dosovitskiy et al., 2020):<br>
[https://arxiv.org/pdf/2010.11929.pdf](https://arxiv.org/pdf/2010.11929.pdf)

References for CLIP (Radford et al., 2021):<br>
[https://openai.com/blog/clip/](https://openai.com/blog/clip/)<br>
[https://arxiv.org/pdf/2103.00020.pdf](https://arxiv.org/pdf/2103.00020.pdf)<br>
(The OpenAI blog post might be enough to understand the model.)

Original blog post & code <br>
[https://github.com/lucidrains/vit-pytorch](https://github.com/lucidrains/vit-pytorch) (ViT) <br>
[https://github.com/moein-shariatnia/OpenAI-CLIP](https://github.com/moein-shariatnia/OpenAI-CLIP) (CLIP)


That said, you are allowed to copy and paste code from the original repositories.
HOWEVER, <font color=red> try to implement the model yourself first </font>, and treat the original source code as a last resort.

### Attention Implementation (10 points)
1. Complete the TODO part of the "Attention" class in "clip_modules.py".
2. You will get full marks if you implement it correctly.
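
As a reminder of the underlying operation, the sketch below shows plain scaled dot-product attention in NumPy. It is not the "Attention" class from "clip_modules.py", whose interface is not shown here: it omits the learned query/key/value projections, dropout, and the output projection. The token count (49 patches plus one assumed class token) and `head_dim = 8` are illustrative assumptions; the 16 heads match the value used later in this notebook.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    """q, k, v have shape (heads, tokens, head_dim); returns (output, weights)."""
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(0, 2, 1)) * scale   # (heads, tokens, tokens)
    weights = softmax(scores, axis=-1)            # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
heads, tokens, head_dim = 16, 50, 8  # illustrative sizes
q, k, v = (rng.standard_normal((heads, tokens, head_dim)) for _ in range(3))
out, attn = scaled_dot_product_attention(q, k, v)
print(out.shape, attn.shape)
```

In a full multi-head layer, `q`, `k`, and `v` would come from linear projections of the token sequence, and the per-head outputs would be concatenated and projected back to the model width.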

### Codes
1. clip_utils.py 
2. clip_modules.py <br>
<br>
### Submitting your work:
<font color=red>**DO NOT clear the final outputs**</font> so that TAs can grade both your code and results.  
Once you have completed **all of Assignment Parts 1-4**, run the *CollectSubmission.sh* script with your **student number** as the input argument. <br>
This will produce a zip file called *[Your student number].zip*. Please submit this file on ETL. &nbsp;&nbsp; (Usage: ./*CollectSubmission.sh* 20xx_xxxxx)

Now proceed to the code.


In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd drive/MyDrive/fastMRI_ming/Assignment3

/content/drive/MyDrive/fastMRI_ming/Assignment3


## Install libraries

In [3]:
!python3 -m pip install pandas
!python3 -m pip install einops
!python3 -m pip install transformers

Collecting einops
  Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected packages: einops
Successfully installed einops-0.3.2
Collecting transformers
  Downloading transformers-4.12.5-py3-none-any.whl (3.1 MB)
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
Collecting sacremoses
  Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.1.2-py3-none-any.whl (59 kB)
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from transformers import DistilBertTokenizer

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import itertools
from tqdm.autonotebook import tqdm
from PIL import Image
from clip_utils import caption_to_csv, get_transforms, get_lr, AvgMeter, make_train_valid_dfs

import os

## Preparing dataset

link : [https://www.kaggle.com/adityajn105/flickr8k](https://www.kaggle.com/adityajn105/flickr8k)

1. Download the dataset from the link above.
2. Move the downloaded zip file into the "data" directory and unzip it there.
3. Run the following cell.

In [None]:
# Move the unzipped files into ./data/Flicker-8k/.
# Safe to re-run: anything that has already been moved is skipped.
import shutil

os.makedirs('./data/Flicker-8k', exist_ok=True)
for name in ('Images', 'captions.txt'):
    if os.path.exists(f'./data/{name}'):
        shutil.move(f'./data/{name}', './data/Flicker-8k/')

In [None]:
# convert captions.txt to csv file
# result location : ./data/Flicker-8k/captions.csv
caption_to_csv()

## Configuration

In [None]:
class CFG:
    debug = False
    image_path = "./data/Flicker-8k/Images"
    captions_path = "./data/Flicker-8k"
    batch_size = 32
    num_workers = 4
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    epochs = 5
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    image_embedding = 2048
    text_encoder_model = "distilbert-base-uncased"
    text_embedding = 768
    text_tokenizer = "distilbert-base-uncased"
    max_length = 200

    pretrained = True # for text encoder
    trainable = True # for text encoder
    temperature = 1.0

    # image size
    size = 224

    # for projection head; used for both image and text encoders
    num_projection_layers = 1
    projection_dim = 256 
    dropout = 0.1

## Dataset & Data loader

In [None]:
class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, config, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length, so if an
        image has multiple captions, its file name must be repeated once
        per caption in image_filenames.
        """
        self.config = config
        self.image_filenames = image_filenames
        self.captions = list(captions)
        self.encoded_captions = tokenizer(
            list(captions), padding=True, truncation=True, max_length=config.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        image = Image.open(f"{self.config.image_path}/{self.image_filenames[idx]}")
        image = self.transforms(image)
        item['image'] = image
        item['caption'] = self.captions[idx]

        return item

    def __len__(self):
        return len(self.captions)
    

def build_loaders(config, dataframe, tokenizer, mode='train'):
    transforms = get_transforms(config)
    dataset = CLIPDataset(
        config,
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=transforms,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        shuffle=(mode == "train"),
    )
    return dataloader
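
The dataset above tokenizes all captions in a single call with `padding=True`, so every caption is padded to the longest one in the whole list (capped by `max_length`). That is why `values[idx]` has the same length for every index and the default batching works. The toy function below sketches that truncate-then-pad behavior in plain Python. `pad_batch`, `pad_id=0`, and the token ids are made up for illustration; a real tokenizer also keeps the closing special token when it truncates, which this sketch does not.

```python
def pad_batch(token_ids, max_length, pad_id=0):
    """Truncate each sequence to max_length, then pad all to the longest one."""
    truncated = [ids[:max_length] for ids in token_ids]
    longest = max(len(ids) for ids in truncated)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in truncated]
    # 1 marks a real token, 0 marks padding that the encoder should ignore
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in truncated]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Made-up token ids for three captions of different lengths
batch = pad_batch(
    [[101, 7, 8, 102], [101, 9, 102], [101, 4, 5, 6, 3, 102]],
    max_length=5,
)
for key, rows in batch.items():
    print(key, rows)
```

The `attention_mask` entries are what `CLIPModel.forward` passes to the text encoder alongside `input_ids`.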

## Define CLIP model

In [None]:
# You must implement the "Attention" class in clip_modules.py before the following cells will run without errors
from clip_modules import VisionTransformer, TextTransformer, ProjHead

In [None]:
def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
    

class CLIPModel(nn.Module):
    def __init__(
        self,
        config,
    ):
        super().__init__()
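        # dim must match config.image_embedding, which the image projection head expects
        self.image_encoder = VisionTransformer(
            image_size=config.size, patch_size=32, dim=config.image_embedding,
            mlp_dim=2048, depth=6, dropout=0.1, emb_dropout=0.1, heads=16,
        )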
        self.text_encoder = TextTransformer(model_name=config.text_encoder_model)
        self.image_projection = ProjHead(embed_dim=config.image_embedding)
        self.text_projection = ProjHead(embed_dim=config.text_embedding)
        self.temperature = config.temperature

    def forward(self, batch):
        # Getting Image and Text Features
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # Getting Image and Text Embeddings (with same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Calculating the Loss
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
        return loss.mean()
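
The loss in `CLIPModel.forward` uses soft targets built from image-image and text-text similarities, rather than treating only the diagonal pairs as positives. The NumPy sketch below mirrors that computation and compares it with an identity-target variant. The function names and the toy embeddings are assumptions made for illustration, and the identity-target version is shown only as a point of comparison, not as what this notebook trains with.

```python
import numpy as np

def log_softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))

def symmetric_loss(logits, targets):
    # Cross-entropy over rows (text -> image) and over columns (image -> text)
    texts_loss = (-targets * log_softmax(logits)).sum(axis=1)
    images_loss = (-targets.T * log_softmax(logits.T)).sum(axis=1)
    return ((images_loss + texts_loss) / 2).mean()

def soft_target_loss(text_emb, image_emb, temperature=1.0):
    logits = text_emb @ image_emb.T / temperature
    sims = (image_emb @ image_emb.T + text_emb @ text_emb.T) / 2
    targets = np.exp(log_softmax(sims * temperature))  # row-wise softmax
    return symmetric_loss(logits, targets)

def hard_target_loss(text_emb, image_emb, temperature=1.0):
    logits = text_emb @ image_emb.T / temperature
    return symmetric_loss(logits, np.eye(len(logits)))  # only matched pairs count

rng = np.random.default_rng(0)
text_emb = rng.standard_normal((4, 8))
image_emb = rng.standard_normal((4, 8))
print(soft_target_loss(text_emb, image_emb), hard_target_loss(text_emb, image_emb))

# Perfectly aligned, well-separated pairs drive the identity-target loss to ~0
aligned = 10.0 * np.eye(4)
print(hard_target_loss(aligned, aligned))
```

With soft targets, two captions that describe near-identical images are not pushed apart as strongly as they would be with identity targets.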

## Training functions

In [None]:
def train_epoch(config, model, train_loader, optimizer, lr_scheduler, step):
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    for batch in tqdm_object:
        batch = {k: v.to(config.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter


def valid_epoch(config, model, valid_loader):
    loss_meter = AvgMeter()

    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    for batch in tqdm_object:
        batch = {k: v.to(config.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter


def main(config):
    train_df, valid_df = make_train_valid_dfs(config)
    tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
    train_loader = build_loaders(config, train_df, tokenizer, mode='train')
    valid_loader = build_loaders(config, valid_df, tokenizer, mode='valid')

    model = CLIPModel(config).to(config.device)    
    params = [
        {"params": model.image_encoder.parameters(), "lr": config.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": config.text_encoder_lr},
        {"params": itertools.chain(
            model.image_projection.parameters(), model.text_projection.parameters()
        ), "lr": config.head_lr, "weight_decay": config.weight_decay}
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=0.)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=config.patience, factor=config.factor
    )
    step = "epoch"

    best_loss = float('inf')
    for epoch in range(config.epochs):
        print(f"Epoch: {epoch + 1}")
        model.train()
        train_loss = train_epoch(config, model, train_loader, optimizer, lr_scheduler, step)
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(config, model, valid_loader)
        
        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            torch.save(model.state_dict(), "best.pt")
            print("Saved Best Model!")
        
        lr_scheduler.step(valid_loss.avg)
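
The optimizer setup in `main` relies on two PyTorch behaviors: per-group options override the defaults passed to `AdamW`, and `ReduceLROnPlateau` scales the learning rate of every group once the monitored loss has stopped improving for more than `patience` epochs. The sketch below illustrates both with small stand-in layers. The `nn.Linear` modules and the constant validation loss are assumptions made for illustration; the learning rates, `patience=1`, and `factor=0.8` are the values from `CFG`.

```python
import itertools
import torch
from torch import nn

encoder = nn.Linear(8, 8)   # stand-in for an encoder
head_a = nn.Linear(8, 4)    # stand-ins for the two projection heads
head_b = nn.Linear(8, 4)

params = [
    {"params": encoder.parameters(), "lr": 1e-4},
    {"params": itertools.chain(head_a.parameters(), head_b.parameters()),
     "lr": 1e-3, "weight_decay": 1e-3},
]
# weight_decay=0.0 is only the default; the second group overrides it
optimizer = torch.optim.AdamW(params, weight_decay=0.0)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=1, factor=0.8
)

lrs = []
for valid_loss in [1.0, 1.0, 1.0]:  # a validation loss that never improves
    scheduler.step(valid_loss)
    lrs.append([group["lr"] for group in optimizer.param_groups])

print([group["weight_decay"] for group in optimizer.param_groups])
print(lrs)  # both learning rates are scaled by 0.8 on the third step
```

Because `step = "epoch"` in `main`, the scheduler there is stepped once per epoch with the average validation loss.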

In [None]:
main(CFG)

## Find matching images for the given query text

In [None]:
def get_image_embeddings(config, valid_df, model_path):
    tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
    valid_loader = build_loaders(config, valid_df, tokenizer, mode="valid")
    
    model = CLIPModel(config).to(config.device)
    model.load_state_dict(torch.load(model_path, map_location=config.device))
    model.eval()
    
    valid_image_embeddings = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            image_features = model.image_encoder(batch["image"].to(config.device))
            image_embeddings = model.image_projection(image_features)
            valid_image_embeddings.append(image_embeddings)
    return model, torch.cat(valid_image_embeddings)


def find_matches(config, model, image_embeddings, query, image_filenames, n=9):
    tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(config.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)
    
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T
    
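    # Each image appears once per caption (about five times in this dataset),
    # so take every 5th hit from the top n * 5 results.
    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
    matches = [image_filenames[idx] for idx in indices[::5].tolist()]
    # Drop duplicate file names but keep the best-first ranking
    # (np.unique would re-sort the matches alphabetically).
    matches = list(dict.fromkeys(matches))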
    
    _, axes = plt.subplots(3, 3, figsize=(10, 10))
    for match, ax in zip(matches, axes.flatten()):
        image = Image.open(f"{config.image_path}/{match}")
        ax.imshow(np.array(image))
        ax.axis("off")
    
    plt.show()
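
The retrieval step in `find_matches` is a cosine-similarity ranking of the image embeddings against one text embedding. The NumPy sketch below shows the same idea on toy data. It handles repeated file names by de-duplicating the ranked list rather than by striding through the top results as the cell above does, and the function name `top_matches`, the `captions_per_image` parameter, the 2-D embeddings, and the file names are all assumptions made for illustration.

```python
import numpy as np

def top_matches(text_emb, image_embs, filenames, n, captions_per_image=5):
    """Return up to n distinct file names, best match first."""
    image_n = image_embs / np.linalg.norm(image_embs, axis=-1, keepdims=True)
    text_n = text_emb / np.linalg.norm(text_emb)
    similarity = image_n @ text_n                      # cosine similarity per row
    order = np.argsort(-similarity)[: n * captions_per_image]
    ranked = [filenames[i] for i in order]
    return list(dict.fromkeys(ranked))[:n]             # de-duplicate, keep ranking

# Three images, each listed once per caption (two captions each in this toy case)
filenames = ["a.jpg", "a.jpg", "b.jpg", "b.jpg", "c.jpg", "c.jpg"]
image_embs = np.array([[1.0, 0.0], [1.0, 0.0],
                       [0.0, 1.0], [0.0, 1.0],
                       [1.0, 1.0], [1.0, 1.0]])
query_emb = np.array([1.0, 0.1])

result = top_matches(query_emb, image_embs, filenames, n=2, captions_per_image=2)
print(result)
```

Normalizing both sides first means the dot product depends only on the angle between embeddings, not on their magnitudes.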

In [None]:
_, valid_df = make_train_valid_dfs(CFG)
model, image_embeddings = get_image_embeddings(CFG, valid_df, "best.pt")

In [None]:
find_matches(CFG, 
             model, 
             image_embeddings,
             query="a dog playing on the grass",
             image_filenames=valid_df['image'].values,
             n=9)