In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
!pip install flax -U
!pip install tqdm

## Jax setup for colab

In [None]:
# https://github.com/google/flax/issues/2263#issuecomment-1173424293
# Only when the notebook runs inside Google Colab (detected via the already
# imported `google.colab` module), call JAX's Colab TPU setup helper.
# NOTE(review): the install cell above pulls the CUDA build of jax and the
# models below are torch modules — confirm a TPU runtime is really intended.
import sys
if 'google.colab' in sys.modules:
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()

## Dataset download
https://machinelearningmastery.com/develop-a-deep-learning-caption-generation-model-in-python/

In [None]:
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip

In [None]:
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip

In [None]:
!unzip Flickr8k_Dataset.zip

In [None]:
!unzip Flickr8k_text.zip

## Setup

In [None]:
!pip install transformers

In [None]:
!pip install scikit-image

In [None]:
import torch
import numpy as np
import pandas as pd
import string
import cv2
from transformers import AutoTokenizer
from skimage import io
import os
from torchvision import transforms
import torch.nn as nn

In [None]:
import tensorflow_hub as hub

In [None]:
sentence_encoder = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")

In [None]:
message = ["I am a sentence for which I would like to get its embedding."]

In [None]:
embed = sentence_encoder(message)

In [None]:
print(f'embed: {embed.shape}')
print(f'type: {type(embed)}')

In [None]:
embed_np = np.array(embed)

In [None]:
print(f'type: {type(embed_np)}')

In [None]:
print(f'embed: {", ".join(map(str, list(np.squeeze(embed_np)[:3])))}, ...')

## Pre-process captions

* convert captions to lowercase
* remove punctuation

In [None]:
def load_captions_from_file(path):
  """Parse a Flickr8k token file into ``[filename, id, caption]`` rows.

  Each non-blank line is expected to look like
  ``<image filename>#<caption id>\\t<caption text>``. The caption text is
  lower-cased; the filename and id are returned as strings.

  Args:
    path: Path of the token file (read as UTF-8).

  Returns:
    A list of ``[filename, id, caption]`` lists, in file order.

  Raises:
    ValueError: If a non-blank line has no tab or no ``#`` separator.
  """
  captions = []
  with open(path, 'r', encoding='utf-8') as f:
    for line_no, raw_line in enumerate(f, start=1):
      line = raw_line.strip()
      if not line:
        # Blank lines (e.g. a trailing newline at EOF) carry no caption.
        continue
      image_id, tab, caption = line.partition('\t')
      # Split on the LAST '#', so a '#' inside the filename cannot produce
      # extra fields and break the three-column layout.
      filename, hash_sep, caption_id = image_id.rpartition('#')
      if not tab or not hash_sep:
        raise ValueError(
            f'{path}:{line_no}: expected "<filename>#<id>\\t<caption>", got {line!r}')
      captions.append([filename, caption_id, caption.lower()])  # [[filename, id, caption], ...]
  return captions

In [None]:
captions = load_captions_from_file('Flickr8k.token.txt')

In [None]:
len(captions[0])

In [None]:
df_caption = pd.DataFrame(captions, columns=['image_filename', 'id', 'caption'])

In [None]:
print(f'Unique images: {len(np.unique(df_caption.image_filename.values))}')
print(f'Total captions: {len(df_caption)}')

In [None]:
# Translation table mapping every character of string.punctuation to "delete".
translator = str.maketrans(dict.fromkeys(string.punctuation))
def remove_punctuation(text):
  """Return ``text`` with all ``string.punctuation`` characters removed."""
  stripped = text.translate(translator)
  return stripped

In [None]:
text = df_caption['caption'].iloc[0]
print(f'original: {text}')
print(f'removed: {remove_punctuation(text)}')

In [None]:
def clean_text(text):
  """Apply the caption cleaning steps to ``text`` (currently: strip punctuation)."""
  return remove_punctuation(text)

In [None]:
def preprocess(df_caption):
  """Clean every caption in ``df_caption`` and return the same frame.

  The 'caption' column is replaced by the result of ``clean_text`` applied to
  each value. The frame is modified in place and also returned.

  Args:
    df_caption: DataFrame with a 'caption' column of strings.

  Returns:
    ``df_caption`` with its 'caption' column cleaned.
  """
  # Assign the whole column at once: the previous per-row
  # `df['caption'].iloc[i] = ...` is chained assignment, which warns and is a
  # silent no-op when pandas Copy-on-Write is enabled.
  df_caption['caption'] = df_caption['caption'].map(clean_text)
  return df_caption

In [None]:
df_caption_0 = df_caption.loc[df_caption['id'].values == '0', :]

In [None]:
image_filenames = df_caption_0.image_filename.values
captions = df_caption_0.caption.values

In [None]:
max_length = max([len(c.split()) for c in captions])

In [None]:
print(f'max_length: {max_length}')

In [None]:
print(f'{image_filenames[0]}: {captions[0]}')

## Flickr8k dataset

In [None]:
class Flickr8kDataset:
  """Loads and cleans the Flickr8k caption file.

  After construction the instance exposes:
    * ``df_caption``      - DataFrame with columns image_filename / id / caption
                            (captions lower-cased, punctuation removed).
    * ``image_filenames`` - array of filenames for rows whose id is '0'.
    * ``captions``        - array of the matching cleaned captions.
    * ``max_length``      - largest whitespace-separated word count among
                            ``captions`` (0 when there are none).
  """

  def __init__(self, caption_text):
    """Args:
      caption_text: Path of the ``<filename>#<id>\\t<caption>`` token file.
    """
    self.caption_text = caption_text
    self.translator = str.maketrans('', '', string.punctuation)
    self._setup()

  def _setup(self):
    """Run the load -> clean -> select pipeline and cache the results."""
    self.captions = self._load_captions_from_file(self.caption_text)
    self.df_caption = pd.DataFrame(self.captions, columns=['image_filename', 'id', 'caption'])
    self.df_caption = self._preprocess(self.df_caption)
    self.image_filenames, self.captions = self._extract(self.df_caption)
    # default=0 keeps an empty caption file from raising inside max().
    self.max_length = max((len(c.split()) for c in self.captions), default=0)

  def _load_captions_from_file(self, caption_text):
    """Parse the token file into ``[filename, id, lower-cased caption]`` rows.

    Raises:
      ValueError: If a non-blank line has no tab or no ``#`` separator.
    """
    captions = []
    with open(caption_text, 'r', encoding='utf-8') as f:
      for line_no, raw_line in enumerate(f, start=1):
        line = raw_line.strip()
        if not line:
          # Blank lines carry no caption; skipping avoids an IndexError.
          continue
        image_id, tab, caption = line.partition('\t')
        # Split on the LAST '#' so the row always has exactly three fields.
        filename, hash_sep, caption_id = image_id.rpartition('#')
        if not tab or not hash_sep:
          raise ValueError(
              f'{caption_text}:{line_no}: expected "<filename>#<id>\\t<caption>", got {line!r}')
        captions.append([filename, caption_id, caption.lower()])  # [[filename, id, caption], ...]
    return captions

  def _remove_punctuation(self, text):
    """Return ``text`` without any ``string.punctuation`` characters."""
    return text.translate(self.translator)

  def _clean_text(self, text):
    """Apply the caption cleaning steps (currently: strip punctuation)."""
    cleaned_text = self._remove_punctuation(text)
    return cleaned_text

  def _preprocess(self, df_caption):
    """Clean the 'caption' column of ``df_caption`` and return the frame."""
    # Whole-column assignment instead of chained `df['caption'].iloc[i] = ...`,
    # which warns and does nothing under pandas Copy-on-Write.
    df_caption['caption'] = df_caption['caption'].map(self._clean_text)
    return df_caption

  def _extract(self, df_caption):
    """Return (filenames, captions) arrays for the rows whose id is '0'."""
    df_caption_0 = df_caption.loc[df_caption['id'] == '0', :]
    image_filenames = df_caption_0.image_filename.values
    captions = df_caption_0.caption.values
    return image_filenames, captions

In [None]:
flickr8k_dataset = Flickr8kDataset('Flickr8k.token.txt')

In [None]:
print(f'max_length: {flickr8k_dataset.max_length}')

## DataLoader

In [None]:
class CLIPDataset(torch.utils.data.Dataset):
  """Torch dataset yielding ``(image, tokens)`` pairs.

  Image paths are ``root_dir`` joined with each entry of ``image_files``;
  the caption at the same index is tokenized and padded/truncated to
  ``max_length``. When no tokenizer is given, the pretrained
  "distilbert-base-uncased" tokenizer is loaded.
  """

  def __init__(self, root_dir, image_files, captions, max_length, tokenizer=None, transforms=None):
    self.image_files = [os.path.join(root_dir, name) for name in image_files]
    self.captions = captions
    self.max_length = max_length
    if tokenizer is None:
      tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    self.tokenizer = tokenizer
    self.transforms = transforms

  def __len__(self):
    """Number of image files in the dataset."""
    return len(self.image_files)

  def __getitem__(self, index):
    """Return the (optionally transformed) image and the tokenized caption."""
    path = self.image_files[index]
    text = self.captions[index]
    tokens = self.tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=self.max_length,
    )
    image = io.imread(path)
    if not self.transforms:
      return image, tokens
    return self.transforms(image), tokens

In [None]:
clip_dataset = CLIPDataset(
    root_dir='Flicker8k_Dataset',
    image_files=image_filenames,
    captions=captions,
    max_length=max_length,
    transforms=transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
)

## Image encoder

In [None]:
!pip install timm

In [None]:
import timm

In [None]:
model = timm.create_model('resnet50', pretrained=True, num_classes=0)

In [None]:
o = model.forward(torch.randn(2, 3, 299, 299))

In [None]:
print(f'resnet50 feature: {o.shape}')

In [None]:
class ImageEncoder(torch.nn.Module):
  """Thin wrapper around a timm model used as the image backbone.

  All constructor arguments are forwarded to ``timm.create_model``.
  NOTE(review): the default ``num_classes=0`` is presumably what makes the
  backbone return features rather than class scores — confirm against timm.
  """

  def __init__(self, model_name, pretrained=True, num_classes=0):
    super().__init__()
    backbone = timm.create_model(
        model_name=model_name,
        pretrained=pretrained,
        num_classes=num_classes,
    )
    self.model = backbone

  def forward(self, x):
    """Run the wrapped timm model on ``x`` and return its output."""
    features = self.model(x)
    return features

## Text encoder

In [None]:
class TextEncoder(torch.nn.Module):
  """Wraps the TF-Hub Universal Sentence Encoder (v4) as the text backbone."""

  def __init__(self):
    super().__init__()
    encoder_url = "https://tfhub.dev/google/universal-sentence-encoder/4"
    self.sentence_encoder = hub.load(encoder_url)

  def forward(self, x):
    """Encode ``x`` with the sentence encoder and return a NumPy array."""
    return np.array(self.sentence_encoder(x))

## Projection head

In [None]:
class ProjectionHead(torch.nn.Module):
  """Maps encoder features of size ``embed_dim`` to ``proj_dim``.

  Pipeline: linear projection -> GELU -> linear -> dropout, then a residual
  connection back to the projection and a final LayerNorm.
  """

  def __init__(self, embed_dim, proj_dim, drop_ratio):
    super().__init__()
    self.embed_dim = embed_dim
    self.proj = nn.Linear(embed_dim, proj_dim)
    self.gelu = nn.GELU()
    self.fc = nn.Linear(proj_dim, proj_dim)
    self.dropout = nn.Dropout(drop_ratio)
    self.layer_norm = nn.LayerNorm(proj_dim)

  def forward(self, x):
    """Project ``x`` and return the normalized residual output."""
    projected = self.proj(x)
    refined = self.dropout(self.fc(self.gelu(projected)))
    return self.layer_norm(refined + projected)

## CLIP model

In [None]:
class CLIP(torch.nn.Module):
  """Image/text dual encoder producing a similarity-logit matrix.

  Args:
    img_embed_dim: Feature size produced by the image encoder.
    text_embed_dim: Feature size produced by the text encoder.
    proj_dim: Size of the shared projection space.
    drop_ratio: Dropout probability used in both projection heads.
    temperature: Divisor applied to the similarity logits.
  """

  def __init__(self, img_embed_dim, text_embed_dim, proj_dim, drop_ratio, temperature=1.0):
    super(CLIP, self).__init__()
    self.image_encoder = ImageEncoder('resnet50')
    self.text_encoder = TextEncoder()
    self.image_head = ProjectionHead(img_embed_dim, proj_dim, drop_ratio)
    self.text_head = ProjectionHead(text_embed_dim, proj_dim, drop_ratio)
    self.temperature = temperature

  def forward(self, img, tokens):
    """Return the (num_images, num_texts) logit matrix for a paired batch.

    Args:
      img: Image batch tensor; its first dimension is the batch size.
      tokens: Batch of text inputs, handed to the text encoder unchanged.
        Must have the same length as ``img``.
    """
    # len() covers tensors/arrays as well as plain lists of strings, which
    # have no `.shape`.
    assert img.shape[0] == len(tokens)
    i_f = self.image_encoder(img)
    t_f = self.text_encoder(tokens)
    # The text encoder returns a NumPy array; nn.Linear needs a torch.Tensor
    # with the same dtype/device as the image features.
    t_f = torch.as_tensor(t_f, dtype=i_f.dtype, device=i_f.device)
    i_e = self.image_head(i_f)
    t_e = self.text_head(t_f)
    logits = (i_e @ t_e.T) / self.temperature
    return logits

## Loss computation

In [None]:
def compute_loss(logits, loss_fn):
  """Symmetric loss over a square logit matrix.

  Row/column ``k`` is treated as the correct match for index ``k``: the loss
  is evaluated once on ``logits`` and once on its transpose, then averaged.

  Args:
    logits: Tensor whose first dimension ``n`` defines the labels ``0..n-1``.
    loss_fn: Callable ``(logits, labels) -> scalar tensor``.

  Returns:
    The mean of the two directional losses.
  """
  n = logits.shape[0]
  # Labels must be a torch tensor (not a NumPy array) on the logits' device,
  # otherwise nn.CrossEntropyLoss rejects the target.
  labels = torch.arange(n, device=logits.device)
  loss_i = loss_fn(torch.transpose(logits, 0, 1), labels)
  loss_t = loss_fn(logits, labels)
  return (loss_i + loss_t) / 2

## Loss function

In [None]:
def get_loss_fn(name):
  """Return the loss module registered under ``name``.

  Args:
    name: Loss identifier; only 'cross_entropy' is supported.

  Returns:
    A new ``nn.CrossEntropyLoss`` instance.

  Raises:
    ValueError: If ``name`` is not a supported loss (previously this returned
      None silently, failing later at the first call).
  """
  if name == 'cross_entropy':
    return nn.CrossEntropyLoss()
  raise ValueError(f'Unsupported loss function: {name!r}')

## Optimizer

In [None]:
def get_optim(name, model):
  """Return an optimizer over ``model.parameters()``.

  Args:
    name: Optimizer identifier; only 'adam' is supported.
    model: Module whose parameters will be optimized.

  Returns:
    A ``torch.optim.Adam`` instance with default hyper-parameters.

  Raises:
    ValueError: If ``name`` is not a supported optimizer (previously this
      returned None silently).
  """
  if name == 'adam':
    return torch.optim.Adam(model.parameters())
  raise ValueError(f'Unsupported optimizer: {name!r}')

## Train epoch

In [None]:
def train_epoch(train_loader, model, optimizer, loss_fn, log_interval=10):
  """Train ``model`` for one pass over ``train_loader``.

  Args:
    train_loader: Iterable yielding ``(batch_image, batch_tokens)`` pairs.
    model: Module called as ``model(batch_image, batch_tokens)``; put into
      training mode before the loop.
    optimizer: Optimizer stepping the model parameters.
    loss_fn: Loss callable forwarded to ``compute_loss``.
    log_interval: Number of batches averaged per progress report.

  Returns:
    Mean loss of the last complete ``log_interval``-batch window, or 0.0 if
    the loader produced fewer than ``log_interval`` batches.
  """
  model.train()
  running_loss = 0.
  last_loss = 0.

  for cur_iter, (batch_image, batch_tokens) in enumerate(train_loader):
    # Zero the gradients for every batch
    optimizer.zero_grad()

    # Make predictions for this batch
    logits = model(batch_image, batch_tokens)

    # Compute the loss and gradients
    loss = compute_loss(logits, loss_fn)
    loss.backward()

    # Adjust learning weights
    optimizer.step()

    # Gather data and report
    running_loss += loss.item()
    if (cur_iter + 1) % log_interval == 0:
      last_loss = running_loss / log_interval
      print(f'batch {cur_iter+1} loss: {last_loss}')
      # Start a fresh window; without this reset the reported "mean" keeps
      # accumulating every earlier batch and grows throughout the epoch.
      running_loss = 0.
  return last_loss

## CLIP parameters

In [None]:
def get_clip_cfg(
    num_epochs=20,
    img_embed_dim=2048,  # resnet50
    text_embed_dim=512,  # universal sentence encoder
    proj_dim=256,
    drop_ratio=0.1,
    temperature=1.0,
):
  """Bundle the CLIP hyper-parameters and dataset paths into one dict.

  Every keyword argument is stored under a key of the same name; the dataset
  locations ('root_dir', 'token_text') are fixed.
  """
  return {
      # model / optimisation
      'num_epochs': num_epochs,
      'img_embed_dim': img_embed_dim,
      'text_embed_dim': text_embed_dim,
      'proj_dim': proj_dim,
      'drop_ratio': drop_ratio,
      'temperature': temperature,
      # dataset
      'root_dir': 'Flicker8k_Dataset',
      'token_text': 'Flickr8k.token.txt',
  }

In [None]:
clip_cfg = get_clip_cfg(num_epochs=20)

In [None]:
for k, v in clip_cfg.items():
  print(f'{k}: {v}')

## Training

In [None]:
# Build the training dataset from the config and the cleaned Flickr8k captions.
# The config lives in `clip_cfg`; no top-level `cfg` exists, so the previous
# `cfg['root_dir']` raised NameError on a fresh kernel.
clip_dataset = CLIPDataset(
    root_dir=clip_cfg['root_dir'],
    image_files=flickr8k_dataset.image_filenames,
    captions=flickr8k_dataset.captions,
    max_length=flickr8k_dataset.max_length,
    transforms=transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
)

In [None]:
def train_model(cfg, dataset):
  """Construct the CLIP model described by ``cfg``.

  Reads 'img_embed_dim', 'text_embed_dim', 'proj_dim', 'drop_ratio' and
  'temperature' from ``cfg`` and passes them to ``CLIP``.

  TODO: the function currently stops after building the model — ``dataset``
  is unused, no training loop runs, and nothing is returned.
  """
  model = CLIP(
      cfg['img_embed_dim'], 
      cfg['text_embed_dim'],
      cfg['proj_dim'],
      cfg['drop_ratio'],
      cfg['temperature']
  )

In [None]:
train_model(clip_cfg, flickr8k_dataset)