## Pip install & Imports

In [None]:
!pip install wandb

In [None]:
#!pip install transformers
!pip install git+https://github.com/huggingface/transformers.git

In [None]:
!pip install git+https://github.com/Lightning-AI/lightning.git

In [None]:
!pip install albumentations

In [None]:
!pip install -q clu

In [None]:
import torch
import torchvision.transforms as transforms
import cv2
import wandb
import pandas as pd
import os

In [None]:
import lightning

In [None]:
import albumentations as A

In [None]:
from clu import metrics

## Image encoder

Use `FlaxResNetModel` from Hugging Face Transformers as the image encoder.

In [None]:
from transformers import AutoImageProcessor, FlaxResNetModel, FlaxViTModel

In [None]:
from flax import linen as nn

In [None]:
class ImageEncoder(nn.Module):
  """Flax module that runs a pretrained FlaxResNetModel on image inputs.

  Attributes:
    model_name: checkpoint identifier handed to `from_pretrained`.
  """
  model_name: str

  def setup(self):
    # NOTE(review): the backbone is created with from_pretrained instead of
    # being declared as a linen submodule; presumably its weights therefore
    # live outside this module's 'params' collection -- confirm.
    # A FlaxViTModel could be loaded here instead of the ResNet backbone.
    pretrained = FlaxResNetModel.from_pretrained(self.model_name)
    self.model = pretrained

  def __call__(self, x):
    """Return the raw output object of the wrapped model for input `x`."""
    outputs = self.model(x)
    return outputs

## Text Encoder

In [None]:
from transformers import FlaxAutoModel

In [None]:
class TextEncoder(nn.Module):
  """Flax module that encodes token ids with a pretrained FlaxAutoModel.

  Attributes:
    model_name: checkpoint identifier handed to `from_pretrained`.
  """
  model_name: str

  def setup(self):
    # Position of the token whose hidden state represents the whole sequence
    # (presumably the [CLS] token for BERT-style models -- confirm).
    self.target_token_idx = 0
    self.model = FlaxAutoModel.from_pretrained(self.model_name)

  def __call__(self, input_ids, attention_mask):
    """Return the last hidden state at `target_token_idx` for every sequence."""
    encoded = self.model(input_ids=input_ids, attention_mask=attention_mask)
    hidden_states = encoded.last_hidden_state
    sentence_repr = hidden_states[:, self.target_token_idx, :]
    return sentence_repr

## Projection Head

In [None]:
class ProjectionHead(nn.Module):
  """Residual projection block: Dense -> GELU -> Dense -> Dropout -> add -> LayerNorm.

  Attributes:
    projection_dim: output width of both Dense layers.
    dropout: dropout rate applied after the second Dense layer.
  """
  projection_dim: int
  dropout: float

  @nn.compact
  def __call__(self, x, train=True):
    """Project `x`; dropout is active only when `train` is True."""
    # Submodules are created in the same order as before so that their
    # automatically assigned names stay unchanged.
    skip = nn.Dense(self.projection_dim)(x)
    hidden = nn.gelu(skip)
    hidden = nn.Dense(self.projection_dim)(hidden)
    hidden = nn.Dropout(self.dropout, deterministic=not train)(hidden)
    # Residual connection around the non-linear branch.
    residual_sum = hidden + skip
    return nn.LayerNorm()(residual_sum)

## CLIP model

In [None]:
import jax.numpy as jnp

In [None]:
class CLIPDualEncoderModel(nn.Module):
  """Dual-encoder model producing an image-by-text similarity matrix.

  Attributes:
    image_encoder_alias: checkpoint name for the image encoder.
    text_encoder_alias: checkpoint name for the text encoder.
    projection_dims: width of the shared embedding space.
    dropout: dropout rate used inside both projection heads.
    temperature: divisor applied to the similarity logits.
  """
  image_encoder_alias: str
  text_encoder_alias: str
  projection_dims: int = 256
  dropout: float = 0.1
  temperature: float = 1.0

  def setup(self):
    self.image_encoder = ImageEncoder(model_name=self.image_encoder_alias)
    self.text_encoder = TextEncoder(model_name=self.text_encoder_alias)
    # Both heads share the same hyper-parameters but own separate weights.
    head_kwargs = dict(projection_dim=self.projection_dims, dropout=self.dropout)
    self.image_projection = ProjectionHead(**head_kwargs)
    self.text_projection = ProjectionHead(**head_kwargs)

  def __call__(self, inputs_image, inputs_input_ids, inputs_attention_mask, train=True):
    """Return temperature-scaled similarity logits (rows: images, columns: texts)."""
    image_emb = self.get_image_features(inputs_image, train)
    text_emb = self.get_text_features(inputs_input_ids, inputs_attention_mask, train)
    similarity = jnp.dot(image_emb, text_emb.T)
    return similarity / self.temperature

  def get_text_features(self, inputs_input_ids, inputs_attention_mask, train=False):
    """Encode, project and L2-normalise the text inputs."""
    pooled_text = self.text_encoder(
        input_ids=inputs_input_ids,
        attention_mask=inputs_attention_mask,
    )
    projected = self.text_projection(pooled_text, train=train)
    norms = jnp.linalg.norm(projected, axis=-1, keepdims=True)
    return projected / norms

  def get_image_features(self, inputs_image, train=False):
    """Encode, project and L2-normalise the image inputs."""
    pooled = self.image_encoder(inputs_image).pooler_output
    # Drop the trailing singleton dims:
    # (batch_size, hidden_size, 1, 1) -> (batch_size, hidden_size)
    batch_size, hidden_size = pooled.shape[0], pooled.shape[1]
    flat_features = pooled.reshape((batch_size, hidden_size))
    projected = self.image_projection(flat_features, train=train)
    norms = jnp.linalg.norm(projected, axis=-1, keepdims=True)
    return projected / norms

## Creating Dataset

In [None]:
from abc import abstractmethod
class ImageRetrievalDataset(torch.utils.data.Dataset):
  """Base dataset pairing image files with tokenized captions.

  Subclasses must implement `fetch_dataset`, returning `(image_files, captions)`.

  Args:
    artifact_id: identifier stored on the instance for `fetch_dataset` to use.
    tokenizer: callable tokenizer applied to all captions at construction time.
    target_size: side length images are resized to (square resize).
    max_length: maximum token length used for truncation.
    lazy_loading: stored on the instance; not otherwise used by this class.

  Raises:
    ValueError: if `tokenizer` or `target_size` is not provided.
  """

  def __init__(self, artifact_id, tokenizer=None, target_size=None, max_length=200, lazy_loading=False):
    super().__init__()
    # Validate cheap preconditions before fetch_dataset() does its work, and
    # use real exceptions because asserts are stripped under `python -O`.
    if tokenizer is None:
      raise ValueError("tokenizer must be provided")
    if target_size is None:
      raise ValueError("target_size must be provided")

    self.artifact_id = artifact_id
    self.target_size = target_size
    self.max_length = max_length
    self.lazy_loading = lazy_loading
    self.image_files, self.captions = self.fetch_dataset()
    self.images = self.image_files

    self.tokenizer = tokenizer

    # All captions are tokenized once, padded to a common length.
    self.tokenized_captions = tokenizer(
        list(self.captions), padding=True, truncation=True,
        max_length=self.max_length, return_tensors='pt'
    )
    self.transforms = A.Compose([
        A.Resize(target_size, target_size, always_apply=True),
        A.Normalize(max_pixel_value=255.0, always_apply=True)
    ])

  @abstractmethod
  def fetch_dataset(self):
    """Return `(image_files, captions)`; must be overridden by subclasses."""
    # The base class is not an ABC, so @abstractmethod alone does not prevent
    # instantiation; fail loudly if a subclass forgets to override this.
    raise NotImplementedError("Subclasses must implement fetch_dataset()")

  def __len__(self):
    """Number of (image, caption) pairs."""
    return len(self.captions)

  def __getitem__(self, index):
    """Return the tokenizer fields, the image tensor and the raw caption.

    Raises:
      FileNotFoundError: if the image at `index` cannot be read.
    """
    item = {
        key: values[index]
        for key, values in self.tokenized_captions.items()
    }
    image_path = self.image_files[index]
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
      raise FileNotFoundError(f"Could not read image file: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = self.transforms(image=image)["image"]
    # Reorder the axes from channels-last to channels-first.
    item["image"] = torch.tensor(image).permute(2, 0, 1).float()
    item["caption"] = self.captions[index]
    return item

In [None]:
class Filckr8kDataset(ImageRetrievalDataset):
  """Flickr8k image/caption dataset downloaded from a W&B dataset artifact."""

  def __init__(self, artifact_id, tokenizer=None, target_size=None, max_length=100, lazy_loading=False):
    super().__init__(artifact_id, tokenizer, target_size, max_length, lazy_loading)

  def fetch_dataset(self):
    """Download the artifact and return `(image_files, captions)`.

    Reads `captions.txt` from the artifact directory as CSV and uses its
    `image` and `caption` columns; image paths are resolved under `Images/`.

    Raises:
      FileNotFoundError: if any image listed in the annotations is missing.
    """
    # Without an active run, go through the public API; otherwise register
    # the artifact as an input of the current run.
    if wandb.run is None:
      api = wandb.Api()
      artifact = api.artifact(self.artifact_id, type="dataset")
    else:
      # Was misspelled `articact`, which made the download below raise
      # NameError whenever a run was active.
      artifact = wandb.use_artifact(self.artifact_id, type="dataset")

    artifact_dir = artifact.download()
    annotations = pd.read_csv(os.path.join(artifact_dir, "captions.txt"))
    image_files = [
        os.path.join(artifact_dir, "Images", image_file)
        for image_file in annotations["image"].to_list()
    ]
    # Explicit check instead of assert so it still runs under `python -O`.
    missing = [path for path in image_files if not os.path.isfile(path)]
    if missing:
      raise FileNotFoundError(
          f"{len(missing)} image file(s) listed in captions.txt are missing, "
          f"e.g. {missing[0]}"
      )
    captions = annotations["caption"].to_list()
    return image_files, captions

In [None]:
import numpy as np

def numpy_collate(batch):
  """Collate a list of sample dicts into a dict of NumPy arrays.

  Every value except the one under 'caption' is converted with `np.array`
  before being grouped; each group is then stacked with `np.array`.
  """
  columns = {}
  for sample in batch:
    for key, value in sample.items():
      converted = value if key == 'caption' else np.array(value)
      columns.setdefault(key, []).append(converted)
  return {key: np.array(values) for key, values in columns.items()}

In [None]:
from typing import Optional
from torch.utils.data import random_split, DataLoader
from transformers import AutoTokenizer
import numpy as np

# Maps the `dataset_name` option to the dataset class that implements it.
DATASET_LOOKUP = {"flickr8k": Filckr8kDataset}

class ImageRetrievalDataModule:
  """Builds the dataset, splits it into train/val and exposes data loaders.

  Args:
    artifact_id: artifact identifier forwarded to the dataset class.
    dataset_name: key into `DATASET_LOOKUP` selecting the dataset class.
    val_split: fraction of the samples reserved for validation.
    tokenizer_alias: checkpoint name for `AutoTokenizer.from_pretrained`.
      Required, even though the signature keeps the `None` default.
    target_size: image side length forwarded to the dataset.
    max_length: maximum caption token length forwarded to the dataset.
    lazy_loading: forwarded to the dataset.
    train_batch_size: batch size of the training loader.
    val_batch_size: batch size of the validation loader.
    num_workers: worker processes used by both loaders.

  Raises:
    ValueError: if `tokenizer_alias` is None.
  """

  def __init__(
      self,
      artifact_id: str,
      dataset_name: str,
      val_split: float = 0.2,
      tokenizer_alias: Optional[str] = None,
      target_size: int = 224,
      max_length: int = 100,
      lazy_loading: bool = False,
      train_batch_size: int = 16,
      val_batch_size: int = 16,
      num_workers: int = 4,
  ):
    # Fail early with a clear message instead of passing None to the
    # tokenizer loader.
    if tokenizer_alias is None:
      raise ValueError("tokenizer_alias must be provided")
    self.artifact_id = artifact_id
    self.dataset_name = dataset_name
    self.val_split = val_split
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_alias)
    self.target_size = target_size
    self.max_length = max_length
    self.lazy_loading = lazy_loading
    self.train_batch_size = train_batch_size
    self.val_batch_size = val_batch_size
    self.num_workers = num_workers
    # The datasets are built eagerly at construction time.
    self.setup()

  @staticmethod
  def split_data(dataset: ImageRetrievalDataset, val_split: float):
    """Randomly split `dataset` into (train, val) subsets."""
    train_length = int((1 - val_split) * len(dataset))
    # The validation set takes the remainder so the lengths always sum up.
    val_length = len(dataset) - train_length
    train_dataset, val_dataset = random_split(
        dataset, lengths=[train_length, val_length]
    )
    return train_dataset, val_dataset

  def setup(
      self,
      stage: Optional[str] = None,
  ) -> None:
    """Instantiate the dataset and create the train/val split.

    `stage` is accepted but not used.
    """
    dataset = DATASET_LOOKUP[self.dataset_name](
        artifact_id=self.artifact_id,
        tokenizer=self.tokenizer,
        target_size=self.target_size,
        max_length=self.max_length,
        lazy_loading=self.lazy_loading,
    )
    self.train_dataset, self.val_dataset = self.split_data(dataset, val_split=self.val_split)

  def train_dataloader(self):
    """Return the training loader, reshuffled every epoch."""
    return DataLoader(
        self.train_dataset,
        batch_size=self.train_batch_size,
        # Reshuffle each epoch so batches (and therefore the in-batch
        # negatives of the contrastive loss) differ between epochs.
        shuffle=True,
        num_workers=self.num_workers,
        collate_fn=numpy_collate
    )

  def val_dataloader(self):
    """Return the validation loader (fixed order)."""
    return DataLoader(
        self.val_dataset,
        batch_size=self.val_batch_size,
        num_workers=self.num_workers,
        collate_fn=numpy_collate
    )

In [None]:
# The same checkpoint name is used for the tokenizer and the text encoder.
text_encoder_alias = "distilbert-base-uncased"

data_module = ImageRetrievalDataModule(
    "wandb/clip.lightning-image_retrieval/flickr-8k:latest",
    "flickr8k",
    tokenizer_alias=text_encoder_alias,
    lazy_loading=True,
)
train_loader, val_loader = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
)

In [None]:
# len() of a loader is its number of batches.
print(f'train: {len(train_loader)}, val: {len(val_loader)}')

## Model init

In [None]:
from flax.training import train_state, checkpoints
import optax
from jax import random
import jax
from flax import struct

In [None]:
# Take one training batch to use as example input.
dummy_inputs = next(iter(train_loader))

In [None]:
# Inspect the batch structure: available keys and the container type per field.
print(dummy_inputs.keys())
for field in ('image', 'input_ids', 'caption'):
  print(type(dummy_inputs[field]))

In [None]:
# Shape of the batched image array.
print(dummy_inputs['image'].shape)

In [None]:
# Collection of metrics tracked during training (clu `metrics.Collection`).
@struct.dataclass
class Metrics(metrics.Collection):
  # Accuracy metric as provided by clu.
  accuracy: metrics.Accuracy
  # Average metric built with from_output('loss')
  # (presumably averages the model output named 'loss' -- confirm).
  loss: metrics.Average.from_output('loss')

In [None]:
class TrainState(train_state.TrainState):
  """Flax train state extended with a metrics collection and a PRNG key."""
  # Metrics collection carried alongside the parameters.
  metrics: Metrics
  # PRNG key stored with the state. Annotated as jax.Array because
  # jax.random.KeyArray has been removed from JAX, and this annotation is
  # evaluated when the class is defined.
  key: jax.Array

In [None]:
# Image backbone checkpoint; a ViT checkpoint such as
# "google/vit-base-patch16-224-in21k" is the commented-out alternative.
image_encoder_alias = "microsoft/resnet-50"
model = CLIPDualEncoderModel(
    image_encoder_alias=image_encoder_alias,
    text_encoder_alias=text_encoder_alias,
)

# One key stays as the main stream, one initialises the parameters and one
# is stored in the train state.
main_rng = random.PRNGKey(42)
main_rng, init_rng, dropout_rng = random.split(main_rng, 3)

# train=False makes dropout deterministic during initialisation.
params = model.init(
    init_rng,
    dummy_inputs['image'],
    dummy_inputs['input_ids'],
    dummy_inputs['attention_mask'],
    train=False,
)['params']

state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-3),
    key=dropout_rng,
    metrics=Metrics.empty(),
)

In [None]:
# Show the shape of every parameter leaf. jax.tree_util.tree_map is used
# because the top-level jax.tree_map alias is deprecated.
jax.tree_util.tree_map(lambda x: x.shape, params)

## Train model

In [None]:
from tqdm.notebook import tqdm

In [None]:
@jax.jit
def train_step(state, inputs_image, inputs_input_ids, inputs_attention_mask, rng):
  """Run one optimisation step on a batch.

  Returns:
    (updated state, scalar loss, PRNG key to pass to the next step).
  """
  # One half of the split drives dropout now, the other is handed back.
  next_rng, step_dropout_rng = jax.random.split(rng)

  def loss_fn(params):
    similarity = state.apply_fn(
        {'params': params},
        inputs_image,
        inputs_input_ids,
        inputs_attention_mask,
        rngs={'dropout': step_dropout_rng},
    )
    # The i-th image matches the i-th text, so targets lie on the diagonal.
    targets = jnp.arange(similarity.shape[0])
    per_image = optax.softmax_cross_entropy_with_integer_labels(
        logits=similarity, labels=targets)
    per_text = optax.softmax_cross_entropy_with_integer_labels(
        logits=similarity.T, labels=targets)
    symmetric = (per_image + per_text) / 2.0
    return symmetric.mean()

  loss, grads = jax.value_and_grad(loss_fn)(state.params)
  updated_state = state.apply_gradients(grads=grads)
  return updated_state, loss, next_rng

In [None]:
def train_epoch(state, train_loader, rng):
  """Run `train_step` over every batch of `train_loader`.

  Returns:
    (final state, list of per-batch losses, final PRNG key).
  """
  batch_losses = []
  for batch in tqdm(train_loader, leave=False):
    state, batch_loss, rng = train_step(
        state,
        batch['image'],
        batch['input_ids'],
        batch['attention_mask'],
        rng,
    )
    batch_losses.append(batch_loss)
  return state, batch_losses, rng

In [None]:
@jax.jit
def eval_step(state, inputs_image, inputs_input_ids, inputs_attention_mask):
  """Compute the symmetric contrastive loss of a batch with train=False."""
  similarity = state.apply_fn(
      {'params': state.params},
      inputs_image,
      inputs_input_ids,
      inputs_attention_mask,
      train=False,
  )
  # Matching image/text pairs sit on the diagonal of the similarity matrix.
  targets = jnp.arange(similarity.shape[0])
  per_image = optax.softmax_cross_entropy_with_integer_labels(
      logits=similarity, labels=targets)
  per_text = optax.softmax_cross_entropy_with_integer_labels(
      logits=similarity.T, labels=targets)
  symmetric = (per_image + per_text) / 2.0
  return symmetric.mean()

In [None]:
def eval_epoch(state, val_loader):
  """Return the list of per-batch validation losses for `val_loader`."""
  return [
      eval_step(state, batch['image'], batch['input_ids'], batch['attention_mask'])
      for batch in tqdm(val_loader, leave=False)
  ]

In [None]:
def train_model(state, train_loader, val_loader, rng, num_epochs=20, ckpt_dir='tmp/flax-checkpointing'):
  """Train for `num_epochs`, report per-epoch losses and save a checkpoint.

  Args:
    state: train state to optimise.
    train_loader: iterable of training batches.
    val_loader: iterable of validation batches.
    rng: PRNG key threaded through the training steps.
    num_epochs: number of passes over `train_loader`.
    ckpt_dir: directory the final checkpoint is written to.

  Returns:
    The train state after the last epoch.
  """
  metrics_history = {
      'train_loss': [],
      'train_accuracy': [],
      'val_loss': [],
      'val_accuracy': []
  }

  def _epoch_mean(losses):
    # Average the per-batch losses; NaN when the loader yielded no batches.
    return float(jnp.mean(jnp.stack(losses))) if losses else float('nan')

  for epoch_idx in tqdm(range(1, num_epochs + 1), leave=False):
    # Run optimization steps over training batches and compute batch metrics
    state, train_losses, rng = train_epoch(state, train_loader, rng)
    metrics_history['train_loss'].extend(train_losses)

    eval_losses = eval_epoch(state, val_loader)
    metrics_history['val_loss'].extend(eval_losses)

    # Report the epoch average rather than the loss of the last batch only,
    # which is noisy and raised IndexError for an empty loader.
    print(f"epoch: {epoch_idx} | "
          f"train loss: {_epoch_mean(train_losses)}, "
          f"val loss: {_epoch_mean(eval_losses)}"
    )

  checkpoints.save_checkpoint(
      # Resolve to an absolute path so the result does not depend on how the
      # checkpoint backend treats relative directories.
      ckpt_dir=os.path.abspath(ckpt_dir),
      target=state,
      # Label the checkpoint with the number of optimiser steps taken and
      # allow replacing an existing one, so re-running does not fail on a
      # checkpoint left behind by an earlier run.
      step=int(state.step),
      overwrite=True
  )

  return state

In [None]:
# Run the training loop and keep the returned (trained) state.
state = train_model(state, train_loader, val_loader, main_rng, num_epochs=20)

## Train model from sample

In [None]:
def cross_entropy(logits, axis):
  """Mean negative log-probability of the diagonal entries of `logits`.

  The log-softmax is taken along `axis`.
  """
  log_probs = jax.nn.log_softmax(logits, axis=axis)
  matched_log_probs = jnp.diag(log_probs)
  return -jnp.mean(matched_log_probs)

In [None]:
def clip_loss(similarity):
  """Average of the cross-entropy taken along axis 0 and along axis 1."""
  along_columns = cross_entropy(similarity, axis=0)
  along_rows = cross_entropy(similarity, axis=1)
  return (along_columns + along_rows) / 2

In [None]:
@jax.jit
def train_step_sample(state, inputs_image, inputs_input_ids, inputs_attention_mask):
  """Run one optimisation step using the PRNG key stored in `state`.

  Returns:
    (updated state, dict with the scalar 'loss').
  """
  # Derive a distinct dropout key for every step from the stored key and the
  # step counter. Splitting state.key without writing the new key back made
  # every step reuse the same dropout mask.
  dropout_rng = jax.random.fold_in(key=state.key, data=state.step)

  def compute_loss(params):
    logits = state.apply_fn(
        {'params': params}, inputs_image, inputs_input_ids, inputs_attention_mask,
        rngs={'dropout': dropout_rng}
    )
    return clip_loss(logits)

  grad_fn = jax.value_and_grad(compute_loss)
  loss, grad = grad_fn(state.params)
  # apply_gradients also advances state.step, which changes the next key.
  new_state = state.apply_gradients(grads=grad)
  step_metrics = {
      'loss': loss
  }
  return new_state, step_metrics

In [None]:
@jax.jit
def eval_step_sample(state, inputs_image, inputs_input_ids, inputs_attention_mask):
  """Evaluate a batch with train=False and return a dict with its 'loss'."""
  similarity = state.apply_fn(
      {'params': state.params},
      inputs_image,
      inputs_input_ids,
      inputs_attention_mask,
      train=False,
  )
  return {'loss': clip_loss(similarity)}

In [None]:
def train_model_sample(state, train_loader, val_loader, num_epochs=20, ckpt_dir='tmp/flax-checkpointing'):
  """Train with `train_step_sample` / `eval_step_sample` for `num_epochs`.

  Args:
    state: train state to optimise.
    train_loader: iterable of training batches.
    val_loader: iterable of validation batches.
    num_epochs: number of passes over `train_loader`.
    ckpt_dir: accepted for signature parity with `train_model`; this function
      does not write a checkpoint.

  Returns:
    The train state after the last epoch. Previously nothing was returned,
    so the trained parameters were lost to the caller.
  """
  for epoch_idx in tqdm(range(1, num_epochs + 1), leave=False):
    train_metrics = []
    for batch in tqdm(train_loader, leave=False):
      inputs_image = batch['image']
      inputs_input_ids = batch['input_ids']
      inputs_attention_mask = batch['attention_mask']
      # Named step_metrics to avoid shadowing the imported `metrics` module.
      state, step_metrics = train_step_sample(state, inputs_image, inputs_input_ids, inputs_attention_mask)
      train_metrics.append(step_metrics)
    # Reports the loss of the last training batch of the epoch.
    print(f"Epoch... ({epoch_idx} | Train Loss: {train_metrics[-1]['loss']}")

    eval_metrics = []
    for batch in tqdm(val_loader, leave=False):
      inputs_image = batch['image']
      inputs_input_ids = batch['input_ids']
      inputs_attention_mask = batch['attention_mask']
      step_metrics = eval_step_sample(state, inputs_image, inputs_input_ids, inputs_attention_mask)
      eval_metrics.append(step_metrics)
    # Reports the loss of the last validation batch of the epoch.
    print(f"Epoch... ({epoch_idx} | Eval Loss: {eval_metrics[-1]['loss']}")

  return state

In [None]:
# train_model_sample(state, train_loader, val_loader, num_epochs=20)