## Dataset download
Reference tutorial: https://machinelearningmastery.com/develop-a-deep-learning-caption-generation-model-in-python/

In [None]:
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip

In [None]:
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip

In [None]:
!unzip Flickr8k_Dataset.zip

In [None]:
!unzip Flickr8k_text.zip

## Imports

In [None]:
!pip install wandb

In [None]:
# !pip install transformers
!pip install git+https://github.com/huggingface/transformers.git

In [None]:
!pip install git+https://github.com/Lightning-AI/lightning.git
#!pip install git+https://github.com/Lightning-AI/lightning.git@bugfix/colab-import

In [None]:
import torch
import torchvision.transforms as transforms
import cv2
import wandb
import pandas as pd
import os

In [None]:
import lightning

In [None]:
!pip install albumentations

In [None]:
import albumentations as A

## Creating DataLoader

In [None]:
from abc import abstractmethod
class ImageRetrievalDataset(torch.utils.data.Dataset):
  """Base dataset that pairs image files with tokenized captions.

  Subclasses provide the data by implementing ``fetch_dataset``; this class
  tokenizes every caption once at construction time and reads/transforms
  images on demand in ``__getitem__``.

  Args:
    artifact_id: Identifier stored on the instance for ``fetch_dataset`` to use.
    tokenizer: Callable applied to the list of captions with
      ``padding``, ``truncation``, ``max_length`` and ``return_tensors='pt'``.
      Required despite the ``None`` default.
    target_size: Side length passed to ``A.Resize`` as both height and width.
    max_length: Maximum caption length handed to the tokenizer.
    lazy_loading: Stored on the instance; not consulted by this class.

  Raises:
    ValueError: If ``tokenizer`` is ``None``.
  """

  def __init__(self, artifact_id, tokenizer=None, target_size=None, max_length=200, lazy_loading=False):
    super().__init__()
    # Fail fast: fetch_dataset() may be slow (e.g. a download), so reject a
    # missing tokenizer before doing any of that work.
    if tokenizer is None:
      raise ValueError("ImageRetrievalDataset requires a tokenizer, got None")

    self.artifact_id = artifact_id
    self.target_size = target_size
    self.max_length = max_length
    self.lazy_loading = lazy_loading
    self.tokenizer = tokenizer

    self.image_files, self.captions = self.fetch_dataset()
    # Alias kept for code that reads `images` instead of `image_files`.
    self.images = self.image_files

    self.tokenized_captions = tokenizer(
        list(self.captions), padding=True, truncation=True,
        max_length=self.max_length, return_tensors='pt'
    )
    # p=1.0 applies each transform unconditionally, which is what the former
    # `always_apply=True` flag expressed.
    self.transforms = A.Compose([
        A.Resize(target_size, target_size, p=1.0),
        A.Normalize(max_pixel_value=255.0, p=1.0)
    ])

  @abstractmethod
  def fetch_dataset(self):
    """Return ``(image_files, captions)`` as two index-aligned sequences.

    The decorator is only a marker here (the class does not use ``ABCMeta``),
    so the base implementation raises explicitly.
    """
    raise NotImplementedError(
        f"{type(self).__name__} must implement fetch_dataset()"
    )

  def __len__(self):
    """Number of (image, caption) pairs."""
    return len(self.captions)

  def __getitem__(self, index):
    """Build one sample: tokenizer outputs plus ``image`` and ``caption``.

    Raises:
      FileNotFoundError: If the image at ``index`` cannot be read.
    """
    item = {
        key: values[index]
        for key, values in self.tokenized_captions.items()
    }
    image_path = self.image_files[index]
    image = cv2.imread(image_path)
    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising; surface that with the offending path.
    if image is None:
      raise FileNotFoundError(f"Could not read image file: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = self.transforms(image=image)["image"]
    # Move the channel axis first (H, W, C -> C, H, W).
    item["image"] = torch.tensor(image).permute(2, 0, 1).float()
    item["caption"] = self.captions[index]
    return item

In [None]:
class Filckr8kDataset(ImageRetrievalDataset):
  """Flickr8k image/caption dataset loaded from a Weights & Biases artifact.

  The artifact is expected to contain a ``captions.txt`` CSV with ``image``
  and ``caption`` columns and an ``Images`` directory holding the files named
  in the ``image`` column.
  """

  def __init__(self, artifact_id, tokenizer=None, target_size=None, max_length=100, lazy_loading=False):
    super().__init__(artifact_id, tokenizer, target_size, max_length, lazy_loading)

  def fetch_dataset(self):
    """Download the artifact and return ``(image_files, captions)``.

    Returns:
      Two index-aligned lists: absolute-ish image paths inside the downloaded
      artifact directory, and the caption for each path.

    Raises:
      FileNotFoundError: If any image listed in ``captions.txt`` is missing
        from the downloaded artifact.
    """
    # Outside a run, go through the public API; inside a run, use_artifact
    # also records the artifact as an input of that run.
    if wandb.run is None:
      api = wandb.Api()
      artifact = api.artifact(self.artifact_id, type="dataset")
    else:
      artifact = wandb.use_artifact(self.artifact_id, type="dataset")

    artifact_dir = artifact.download()
    annotations = pd.read_csv(os.path.join(artifact_dir, "captions.txt"))
    image_files = [
        os.path.join(artifact_dir, "Images", image_file)
        for image_file in annotations["image"].to_list()
    ]
    missing = [path for path in image_files if not os.path.isfile(path)]
    if missing:
      raise FileNotFoundError(
          f"{len(missing)} image file(s) listed in captions.txt are missing, "
          f"e.g. {missing[0]}"
      )
    captions = annotations["caption"].to_list()
    return image_files, captions

## DataModule

In [None]:
from typing import Optional
from torch.utils.data import random_split, DataLoader
#from pytorch_lightning import LightningDataModule # pygments error
from transformers import AutoTokenizer

# Registry of dataset name -> Dataset class; the data module looks up its
# `dataset_name` here when it builds the dataset.
DATASET_LOOKUP = dict(flickr8k=Filckr8kDataset)

class ImageRetrievalDataModule(lightning.LightningDataModule):
  """Lightning data module that builds train/val loaders for image retrieval.

  Args:
    artifact_id: Forwarded to the dataset class.
    dataset_name: Key into ``DATASET_LOOKUP`` selecting the dataset class.
    val_split: Fraction of samples placed in the validation subset.
    tokenizer_alias: Name passed to ``AutoTokenizer.from_pretrained``.
      Required despite the ``None`` default.
    target_size: Forwarded to the dataset class.
    max_length: Forwarded to the dataset class.
    lazy_loading: Forwarded to the dataset class.
    train_batch_size: Batch size of the training loader.
    val_batch_size: Batch size of the validation loader.
    num_workers: Worker count for both loaders.
    **kwargs: Forwarded to ``LightningDataModule.__init__``.

  Raises:
    ValueError: If ``tokenizer_alias`` is ``None`` or ``val_split`` is
      outside ``[0, 1]``.
  """

  def __init__(
      self,
      artifact_id: str,
      dataset_name: str, 
      val_split: float = 0.2,
      tokenizer_alias: Optional[str] = None,
      target_size: int = 224,
      max_length: int = 100,
      lazy_loading: bool = False,
      train_batch_size: int = 16,
      val_batch_size: int = 16,
      num_workers: int = 4,
      **kwargs,
  ):
    super().__init__(**kwargs)
    if tokenizer_alias is None:
      raise ValueError("tokenizer_alias must name a pretrained tokenizer, got None")
    if not 0.0 <= val_split <= 1.0:
      raise ValueError(f"val_split must be in [0, 1], got {val_split}")

    self.artifact_id = artifact_id
    self.dataset_name = dataset_name
    self.val_split = val_split
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_alias)
    self.target_size = target_size
    self.max_length = max_length
    self.lazy_loading = lazy_loading
    self.train_batch_size = train_batch_size
    self.val_batch_size = val_batch_size
    self.num_workers = num_workers
    # Populated by setup(); None means "not built yet".
    self.train_dataset = None
    self.val_dataset = None

  def prepare_data(self):
    """No-op: the dataset is fetched in ``setup``."""
    pass

  @staticmethod
  def split_data(dataset: ImageRetrievalDataset, val_split: float):
    """Randomly split ``dataset`` into ``(train, val)`` subsets.

    The training subset gets ``int((1 - val_split) * len(dataset))`` samples
    and the validation subset gets the remainder.
    """
    train_length = int((1 - val_split) * len(dataset))
    val_length = len(dataset) - train_length
    train_dataset, val_dataset = random_split(
        dataset, lengths=[train_length, val_length]
    )
    return train_dataset, val_dataset

  def setup(
      self,
      stage: Optional[str] = None,
  ) -> None:
    """Build the dataset and its train/val split once.

    Lightning may call this hook again for later stages; re-running the
    random split would move samples between train and val, so an existing
    split is reused.
    """
    if self.train_dataset is not None and self.val_dataset is not None:
      return
    dataset = DATASET_LOOKUP[self.dataset_name](
        artifact_id=self.artifact_id,
        tokenizer=self.tokenizer,
        target_size=self.target_size,
        max_length=self.max_length,
        lazy_loading=self.lazy_loading,
    )
    self.train_dataset, self.val_dataset = self.split_data(dataset, val_split=self.val_split)

  def train_dataloader(self):
    """Training loader; reshuffled every epoch so batch composition varies."""
    return DataLoader(
        self.train_dataset,
        batch_size=self.train_batch_size,
        shuffle=True,
        num_workers=self.num_workers
    )

  def val_dataloader(self):
    """Validation loader; iteration order is fixed."""
    return DataLoader(
        self.val_dataset,
        batch_size=self.val_batch_size,
        num_workers=self.num_workers,
    )

## Image Encoder

In [None]:
!pip install timm

In [None]:
import timm
import torch
from torch import nn

In [None]:
class ImageEncoder(nn.Module):
  """Wraps a timm backbone created without a classifier head.

  The backbone is built with ``num_classes=0`` and ``global_pool='avg'``;
  ``forward`` returns whatever that backbone returns for the input batch.
  """

  def __init__(
      self, model_name: str, pretrained: bool = True, trainable: bool = True,
  ) -> None:
    super().__init__()

    backbone = timm.create_model(
        model_name, pretrained=pretrained, num_classes=0, global_pool='avg'
    )
    # Freeze or unfreeze every backbone parameter in one call.
    backbone.requires_grad_(trainable)
    self.model = backbone

    self.target_token_idx = 0

  def forward(self, x):
    """Run ``x`` through the backbone."""
    features = self.model(x)
    return features

## Text Encoder

In [None]:
import torch
import transformers
from torch import nn

In [None]:
class TextEncoder(nn.Module):
  """Wraps a Hugging Face ``AutoModel`` and returns one token's hidden state.

  ``forward`` takes the model's ``last_hidden_state`` and selects the vector
  at position ``target_token_idx`` (0) along the second axis.
  """

  def __init__(self, model_name: str, trainable: bool = True) -> None:
    super().__init__()

    self.model = transformers.AutoModel.from_pretrained(model_name)

    # Freeze or unfreeze the whole encoder.
    for param in self.model.parameters():
      param.requires_grad = trainable

    self.target_token_idx = 0

  def forward(self, input_ids, attention_mask):
    """Encode the batch and return the hidden state at ``target_token_idx``."""
    hidden_states = self.model(
        input_ids=input_ids, attention_mask=attention_mask
    ).last_hidden_state
    token_embedding = hidden_states[:, self.target_token_idx, :]
    return token_embedding

## Projection Head

In [None]:
import torch
from torch import nn

In [None]:
class ProjectionHead(nn.Module):
  """Projects encoder features to ``projection_dim`` with a residual block.

  Computation: linear projection, then GELU -> linear -> dropout, added back
  onto the projection, followed by layer normalization.
  """

  def __init__(self, embedding_dim: int, projection_dim: int, dropout: float) -> None:
    super().__init__()

    # Submodules are created in this order so parameter initialization
    # draws from the RNG in the same sequence.
    self.projection = nn.Linear(embedding_dim, projection_dim)
    self.gelu = nn.GELU()
    self.fc = nn.Linear(projection_dim, projection_dim)

    self.dropout = nn.Dropout(dropout)
    self.layer_norm = nn.LayerNorm(projection_dim)

  def forward(self, x):
    """Return the layer-normalized projection of ``x``."""
    residual = self.projection(x)
    hidden = self.dropout(self.fc(self.gelu(residual)))
    return self.layer_norm(hidden + residual)

## CLIP Model

In [None]:
import itertools

In [None]:
class CLIPDualEncoderModel(lightning.LightningModule):
  """Dual-encoder model trained with a symmetric image/text contrastive loss.

  An ``ImageEncoder`` and a ``TextEncoder`` each feed a ``ProjectionHead``
  that maps into a shared ``projection_dims``-sized space. The loss compares
  text-to-image logits against soft targets derived from the image-image and
  text-text similarities of the same batch.
  """

  def __init__(
      self,
      image_encoder_alias: str,
      text_encoder_alias: str,
      image_encoder_pretrained: bool = True,
      image_encoder_trainable: bool = True,
      text_encoder_trainable: bool = True,
      image_embedding_dims: int = 2048,
      text_embedding_dims: int = 768,
      projection_dims: int = 256,
      dropout: float = 0.0,
      temperature: float = 1.0,
      weight_decay: float = 0.0,
      head_lr: float = 1e-3,
      image_encoder_lr: float = 1e-4,
      text_encoder_lr: float = 1e-5,
      lr_scheduler_patience: float = 1.0,
      lr_scheduler_factor: float = 0.8,
      *args,
      **kwargs,
  ) -> None:
    super().__init__(*args, **kwargs)

    # --- encoders ---
    self.image_encoder = ImageEncoder(
        model_name=image_encoder_alias,
        pretrained=image_encoder_pretrained,
        trainable=image_encoder_trainable,
    )
    self.text_encoder = TextEncoder(
        model_name=text_encoder_alias, trainable=text_encoder_trainable
    )

    # --- projection heads into the shared embedding space ---
    self.image_projection = ProjectionHead(
        embedding_dim=image_embedding_dims,
        projection_dim=projection_dims,
        dropout=dropout,
    )
    self.text_projection = ProjectionHead(
        embedding_dim=text_embedding_dims,
        projection_dim=projection_dims,
        dropout=dropout,
    )
    self.log_softmax = nn.LogSoftmax(dim=-1)

    # --- loss / optimizer hyper-parameters ---
    self.temperature = temperature
    self.weight_decay = weight_decay
    self.head_lr = head_lr
    self.image_encoder_lr = image_encoder_lr
    self.text_encoder_lr = text_encoder_lr
    self.lr_scheduler_patience = lr_scheduler_patience
    self.lr_scheduler_factor = lr_scheduler_factor

  def _compute_losses(self, image_embeddings, text_embeddings):
    """Return the per-sample loss, averaged over both matching directions."""
    temperature = self.temperature
    logits = torch.matmul(text_embeddings, image_embeddings.T) / temperature

    # Soft targets: mean of the within-modality similarity matrices.
    images_similarity = torch.matmul(image_embeddings, image_embeddings.T)
    texts_similarity = torch.matmul(text_embeddings, text_embeddings.T)
    targets = nn.functional.softmax(
        (images_similarity + texts_similarity) / 2 * temperature, dim=-1
    )

    texts_loss = torch.sum(-targets * self.log_softmax(logits), dim=1)
    images_loss = torch.sum(-targets.T * self.log_softmax(logits.T), dim=1)
    return (images_loss + texts_loss) / 2.0

  def forward(self, inputs):
    """Encode and project a batch; returns ``(image_embeddings, text_embeddings)``.

    ``inputs`` is indexed with the keys ``"image"``, ``"input_ids"`` and
    ``"attention_mask"``.
    """
    image_embeddings = self.image_projection(
        self.image_encoder(inputs["image"])
    )
    text_embeddings = self.text_projection(
        self.text_encoder(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
    )
    return image_embeddings, text_embeddings

  def configure_optimizers(self):
    """Adam with per-component learning rates plus a plateau LR scheduler.

    The scheduler is driven by the metric logged as ``"val/loss:"``.
    """
    head_parameters = itertools.chain(
        self.image_projection.parameters(),
        self.text_projection.parameters(),
    )
    image_group = {
        "params": self.image_encoder.parameters(),
        "lr": self.image_encoder_lr,
    }
    text_group = {
        "params": self.text_encoder.parameters(),
        "lr": self.text_encoder_lr,
    }
    head_group = {
        "params": head_parameters,
        "lr": self.head_lr,
        "weight_decay": self.weight_decay,
    }
    optimizer = torch.optim.Adam(
        [image_group, text_group, head_group], weight_decay=self.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optimizer,
        mode="min",
        patience=self.lr_scheduler_patience,
        factor=self.lr_scheduler_factor,
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "monitor": "val/loss:",
    }

  def _batch_loss(self, batch):
    """Forward ``batch`` and reduce the per-sample losses to their mean."""
    image_embeddings, text_embeddings = self.forward(batch)
    return self._compute_losses(image_embeddings, text_embeddings).mean()

  def training_step(self, batch, *args, **kwargs):
    """Compute the batch loss and log its gathered mean as ``"train/loss:"``."""
    loss = self._batch_loss(batch)
    self.train_loss = self.all_gather(loss).mean()
    self.log("train/loss:", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
    return loss

  def validation_step(self, batch, *args, **kwargs):
    """Compute the batch loss and log its gathered mean as ``"val/loss:"``."""
    loss = self._batch_loss(batch)
    self.val_loss = self.all_gather(loss).mean()
    self.log("val/loss:", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
    return loss

## Training the Model

In [None]:
# Backbone identifiers: a timm model name for images and a Hugging Face
# model name for text (the latter is also used to load the tokenizer).
image_encoder_alias = "resnet50"
text_encoder_alias = "distilbert-base-uncased"

model = CLIPDualEncoderModel(
    image_encoder_alias=image_encoder_alias,
    text_encoder_alias=text_encoder_alias,
)

data_module_config = dict(
    artifact_id="wandb/clip.lightning-image_retrieval/flickr-8k:latest",
    dataset_name="flickr8k",
    tokenizer_alias=text_encoder_alias,
    lazy_loading=True,
)
data_module = ImageRetrievalDataModule(**data_module_config)

# Train for a fixed number of epochs with otherwise default Trainer settings.
trainer = lightning.Trainer(max_epochs=20)
trainer.fit(model, data_module)