In [None]:
!pip install wandb

In [None]:
#!pip install transformers
!pip install git+https://github.com/huggingface/transformers.git

In [None]:
!pip install git+https://github.com/Lightning-AI/lightning.git

In [None]:
!pip install albumentations

In [None]:
import torch
import torchvision.transforms as transforms
import cv2
import wandb
import pandas as pd
import os

In [None]:
import lightning

In [None]:
import albumentations as A

## Image encoder

Wrap the pretrained `FlaxResNetModel` from Hugging Face Transformers as the image backbone.

In [None]:
from transformers import AutoImageProcessor, FlaxResNetModel

In [None]:
from flax import linen as nn

In [None]:
class ImageEncoder(nn.Module):
  """Embed images with a pretrained Hugging Face ``FlaxResNetModel``.

  Attributes:
    model_name: Hub id passed to ``from_pretrained`` for both the backbone and
      its matching ``AutoImageProcessor``.
    pretrained: Currently unused; the backbone is always loaded with
      ``from_pretrained``.
    trainable: Currently unused. NOTE(review): the backbone's weights live in
      the ``from_pretrained`` wrapper, not in this module's ``params``, so an
      optimizer over ``init(...)['params']`` never updates them. Confirm intent.
  """
  model_name: str
  pretrained: bool = True
  trainable: bool = True

  def setup(self):
    # NOTE(review): Flax re-runs setup on every init/apply, so these
    # from_pretrained calls are repeated each time; consider hoisting them.
    self.model = FlaxResNetModel.from_pretrained(self.model_name)
    self.image_processor = AutoImageProcessor.from_pretrained(self.model_name)

  def __call__(self, x):
    """Return pooled image features flattened to shape (batch, features).

    NOTE(review): the processor resizes/rescales/normalizes ``x`` itself. If
    callers pass already-normalized images (as ``ImageRetrievalDataset`` in
    this notebook does), that preprocessing is applied twice. Confirm.
    """
    inputs = self.image_processor(images=x, return_tensors='np')
    outputs = self.model(**inputs)
    pooled = outputs.pooler_output
    # The transformers Flax ResNet returns the pooled features channels-first
    # as (batch, C, 1, 1). Flatten them so a following nn.Dense acts on the C
    # features rather than on a trailing size-1 spatial axis.
    return pooled.reshape(pooled.shape[0], -1)

## Text Encoder

In [None]:
from transformers import FlaxAutoModel

In [None]:
class TextEncoder(nn.Module):
  """Encode token ids with a pretrained Hugging Face Flax text model.

  The hidden state at one fixed token position (index 0) is used as the
  sentence embedding. For BERT-style tokenizers this is presumably the
  [CLS] token; verify for other tokenizers.

  Attributes:
    model_name: Hub id passed to ``FlaxAutoModel.from_pretrained``.
    trainable: Currently unused by this module.
  """
  model_name: str
  trainable: bool = True

  def setup(self):
    self.model = FlaxAutoModel.from_pretrained(self.model_name)
    # Token position whose final hidden state represents the whole sequence.
    self.target_token_idx = 0

  def __call__(self, input_ids, attention_mask):
    """Return the last-layer hidden state at ``target_token_idx`` for each sequence."""
    token_idx = self.target_token_idx
    hidden_states = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
    ).last_hidden_state
    return hidden_states[:, token_idx, :]

## Projection Head

In [None]:
class ProjectionHead(nn.Module):
  """Project encoder features into the shared embedding space.

  Architecture: Dense -> GELU -> Dense -> Dropout, plus a residual connection
  from the first Dense output, then LayerNorm.

  Attributes:
    embedding_dim: Width of the incoming features. Kept for interface
      compatibility only; ``nn.Dense`` infers its input width at init time.
    projection_dim: Output width of both dense layers (and of the head).
    dropout: Dropout rate applied after the second dense layer.
  """
  embedding_dim: int
  projection_dim: int
  dropout: float

  def setup(self):
    self.projection = nn.Dense(self.projection_dim)
    self.fc = nn.Dense(self.projection_dim)
    # Use a separate attribute name: assigning to the `dropout` dataclass
    # field inside setup is not allowed by Flax and would shadow the rate.
    self.dropout_layer = nn.Dropout(rate=self.dropout)
    self.layer_norm = nn.LayerNorm()

  def __call__(self, x, deterministic: bool = True):
    """Project ``x`` to ``projection_dim`` features.

    Args:
      x: Input features; the last axis is projected.
      deterministic: If True (default), dropout is disabled. If False,
        dropout is active and ``apply`` must be given a ``'dropout'`` PRNG
        stream (``rngs={'dropout': key}``).
    """
    projected = self.projection(x)
    # nn.gelu is an activation function, not a layer factory: apply it directly.
    hidden = nn.gelu(projected)
    hidden = self.fc(hidden)
    hidden = self.dropout_layer(hidden, deterministic=deterministic)
    hidden = hidden + projected
    return self.layer_norm(hidden)

## CLIP model

In [None]:
class CLIPDualEncoderModel(nn.Module):
  """CLIP-style dual encoder: image and text towers mapped to a shared space.

  Each modality goes through its encoder, then through its own
  ``ProjectionHead`` producing ``projection_dims`` features.

  Attributes:
    image_encoder_alias: Hub id for the ``ImageEncoder`` backbone.
    text_encoder_alias: Hub id for the ``TextEncoder`` backbone.
    image_encoder_pretrained: Forwarded to ``ImageEncoder.pretrained``.
    image_encoder_trainable: Forwarded to ``ImageEncoder.trainable``.
    text_encoder_trainable: Forwarded to ``TextEncoder.trainable``.
    image_embedding_dims: Forwarded to the image head's ``embedding_dim``.
    text_embedding_dims: Forwarded to the text head's ``embedding_dim``.
    projection_dims: Width of the shared embedding space.
    dropout: Dropout rate used by both projection heads.
  """
  image_encoder_alias: str
  text_encoder_alias: str
  image_encoder_pretrained: bool = True
  image_encoder_trainable: bool = True
  text_encoder_trainable: bool = True
  image_embedding_dims: int = 2048
  text_embedding_dims: int = 768
  projection_dims: int = 256
  dropout: float = 0.0

  def setup(self):
    self.image_encoder = ImageEncoder(
        model_name=self.image_encoder_alias,
        pretrained=self.image_encoder_pretrained,
        trainable=self.image_encoder_trainable,
    )
    self.text_encoder = TextEncoder(
        model_name=self.text_encoder_alias,
        trainable=self.text_encoder_trainable,
    )
    # Both heads share everything except the width of their input features.
    shared_head_kwargs = dict(projection_dim=self.projection_dims, dropout=self.dropout)
    self.image_projection = ProjectionHead(
        embedding_dim=self.image_embedding_dims, **shared_head_kwargs
    )
    self.text_projection = ProjectionHead(
        embedding_dim=self.text_embedding_dims, **shared_head_kwargs
    )

  def __call__(self, inputs):
    """Embed a batch given as a dict with 'image', 'input_ids' and 'attention_mask'.

    Returns:
      ``(image_embeddings, text_embeddings)``, each with ``projection_dims``
      features on the last axis.
    """
    pixels = inputs['image']
    token_ids = inputs['input_ids']
    mask = inputs['attention_mask']

    image_feats = self.image_encoder(pixels)
    text_feats = self.text_encoder(input_ids=token_ids, attention_mask=mask)

    return (
        self.image_projection(image_feats),
        self.text_projection(text_feats),
    )

## Creating the dataset

In [None]:
from abc import abstractmethod
class ImageRetrievalDataset(torch.utils.data.Dataset):
  """Base dataset pairing images with tokenized captions for image-text retrieval.

  Subclasses implement ``fetch_dataset`` to return ``(image_files, captions)``,
  two parallel sequences of image paths and caption strings. All captions are
  tokenized once in ``__init__``. Each image is read from disk in ``__getitem__``.

  Args:
    artifact_id: Identifier stored as ``self.artifact_id`` for ``fetch_dataset``.
    tokenizer: Hugging Face-style tokenizer callable (required).
    target_size: Side length of the square images after resizing (required).
    max_length: Maximum tokens per caption; longer captions are truncated.
    lazy_loading: Stored but currently unused.

  Raises:
    ValueError: if ``tokenizer`` or ``target_size`` is missing.
  """
  def __init__(self, artifact_id, tokenizer=None, target_size=None, max_length=200, lazy_loading=False):
    super().__init__()
    # Validate before fetch_dataset(), which may download a large artifact.
    if tokenizer is None:
      raise ValueError("a tokenizer is required to tokenize the captions")
    if target_size is None:
      raise ValueError("target_size is required to resize the images")

    self.artifact_id = artifact_id
    self.target_size = target_size
    self.max_length = max_length
    self.lazy_loading = lazy_loading
    self.tokenizer = tokenizer
    self.image_files, self.captions = self.fetch_dataset()
    self.images = self.image_files  # alias kept for existing callers

    # Tokenize every caption up front; __getitem__ only indexes into the result.
    self.tokenized_captions = tokenizer(
        list(self.captions), padding=True, truncation=True,
        max_length=self.max_length, return_tensors='pt'
    )
    # Resize and Normalize already default to p=1, so the deprecated
    # `always_apply=True` flag is unnecessary.
    self.transforms = A.Compose([
        A.Resize(target_size, target_size),
        A.Normalize(max_pixel_value=255.0),
    ])

  @abstractmethod
  def fetch_dataset(self):
    """Return ``(image_files, captions)``: parallel sequences of image paths and captions."""
    raise NotImplementedError

  def __len__(self):
    return len(self.captions)

  def __getitem__(self, index):
    """Return the tokenized caption fields, an ``image`` tensor (C, H, W, float32) and the raw ``caption``.

    Raises:
      FileNotFoundError: if the image file cannot be read.
    """
    item = {
        key: values[index]
        for key, values in self.tokenized_captions.items()
    }
    image_path = self.image_files[index]
    image = cv2.imread(image_path)
    if image is None:
      # cv2.imread reports missing/unreadable files by returning None, not raising.
      raise FileNotFoundError(f"could not read image file: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = self.transforms(image=image)["image"]
    # HWC numpy array -> CHW float tensor.
    item["image"] = torch.tensor(image).permute(2, 0, 1).float()
    item["caption"] = self.captions[index]
    return item

In [None]:
class Filckr8kDataset(ImageRetrievalDataset):
  """Flickr8k image-caption pairs loaded from a W&B dataset artifact.

  The artifact is expected to contain ``captions.txt`` (CSV with ``image`` and
  ``caption`` columns) and an ``Images/`` directory holding the files named in
  the ``image`` column. (The class name's spelling is kept for existing callers.)
  """
  def __init__(self, artifact_id, tokenizer=None, target_size=None, max_length=100, lazy_loading=False):
    super().__init__(artifact_id, tokenizer, target_size, max_length, lazy_loading)

  def fetch_dataset(self):
    """Download the artifact and return ``(image_files, captions)`` as parallel lists.

    Uses ``wandb.use_artifact`` when a run is active and the public API otherwise.

    Raises:
      FileNotFoundError: if any image listed in ``captions.txt`` is missing.
    """
    if wandb.run is None:
      artifact = wandb.Api().artifact(self.artifact_id, type="dataset")
    else:
      # Previously assigned to a misspelled name, leaving `artifact` unbound.
      artifact = wandb.use_artifact(self.artifact_id, type="dataset")

    artifact_dir = artifact.download()
    annotations = pd.read_csv(os.path.join(artifact_dir, "captions.txt"))
    image_files = [
        os.path.join(artifact_dir, "Images", image_file)
        for image_file in annotations["image"].to_list()
    ]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    missing = [path for path in image_files if not os.path.isfile(path)]
    if missing:
      raise FileNotFoundError(
          f"{len(missing)} image file(s) listed in captions.txt are missing, e.g. {missing[0]}"
      )
    captions = annotations["caption"].to_list()
    return image_files, captions

In [None]:
from typing import Optional
from torch.utils.data import random_split, DataLoader
#from pytorch_lightning import LightningDataModule # pygments error
from transformers import AutoTokenizer

# Registry mapping `dataset_name` values to their dataset classes.
DATASET_LOOKUP = dict(flickr8k=Filckr8kDataset)

class ImageRetrievalDataModule:
  """Build train/validation DataLoaders for an image-caption retrieval dataset.

  The dataset class is looked up by ``dataset_name`` in ``DATASET_LOOKUP``. It is
  built eagerly in ``__init__`` (via ``setup``) and split into train/val subsets.

  Args:
    artifact_id: Artifact id forwarded to the dataset class.
    dataset_name: Key into ``DATASET_LOOKUP``.
    val_split: Fraction of samples held out for validation, within [0, 1].
    tokenizer_alias: Hugging Face tokenizer id. Required despite the ``None`` default.
    target_size: Side length of the square images after resizing.
    max_length: Maximum caption length in tokens.
    lazy_loading: Forwarded to the dataset class.
    train_batch_size: Batch size of the training loader.
    val_batch_size: Batch size of the validation loader.
    num_workers: Worker processes per DataLoader.
    seed: Optional seed that makes the train/val split reproducible. ``None``
      draws from torch's global RNG (the previous behaviour).

  Raises:
    ValueError: if ``tokenizer_alias`` is missing, ``val_split`` is out of
      range, or ``dataset_name`` is not registered.
  """
  def __init__(
      self,
      artifact_id: str,
      dataset_name: str,
      val_split: float = 0.2,
      tokenizer_alias: Optional[str] = None,
      target_size: int = 224,
      max_length: int = 100,
      lazy_loading: bool = False,
      train_batch_size: int = 16,
      val_batch_size: int = 16,
      num_workers: int = 4,
      seed: Optional[int] = None,
  ):
    # Fail fast, before downloading a tokenizer or a dataset artifact.
    if tokenizer_alias is None:
      raise ValueError("tokenizer_alias is required to tokenize the captions")
    if not 0.0 <= val_split <= 1.0:
      raise ValueError(f"val_split must be within [0, 1], got {val_split!r}")
    self.artifact_id = artifact_id
    self.dataset_name = dataset_name
    self.val_split = val_split
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_alias)
    self.target_size = target_size
    self.max_length = max_length
    self.lazy_loading = lazy_loading
    self.train_batch_size = train_batch_size
    self.val_batch_size = val_batch_size
    self.num_workers = num_workers
    self.seed = seed
    self.setup()

  @staticmethod
  def split_data(dataset: "ImageRetrievalDataset", val_split: float, seed: Optional[int] = None):
    """Randomly split ``dataset`` into ``(train_subset, val_subset)``.

    The train subset gets ``int((1 - val_split) * len(dataset))`` samples and
    validation gets the rest. With ``seed`` set, the split is reproducible.
    """
    train_length = int((1 - val_split) * len(dataset))
    val_length = len(dataset) - train_length
    split_kwargs = {}
    if seed is not None:
      split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = random_split(
        dataset, lengths=[train_length, val_length], **split_kwargs
    )
    return train_dataset, val_dataset

  def setup(
      self,
      stage: Optional[str] = None,
  ) -> None:
    """Build the dataset and store ``train_dataset`` / ``val_dataset``.

    ``stage`` mirrors the LightningDataModule ``setup`` signature and is ignored.
    """
    try:
      dataset_cls = DATASET_LOOKUP[self.dataset_name]
    except KeyError:
      raise ValueError(
          f"unknown dataset_name {self.dataset_name!r}; "
          f"expected one of {sorted(DATASET_LOOKUP)}"
      ) from None
    dataset = dataset_cls(
        artifact_id=self.artifact_id,
        tokenizer=self.tokenizer,
        target_size=self.target_size,
        max_length=self.max_length,
        lazy_loading=self.lazy_loading,
    )
    self.train_dataset, self.val_dataset = self.split_data(
        dataset, val_split=self.val_split, seed=self.seed
    )

  def train_dataloader(self):
    """Loader over the training subset, reshuffled on every pass."""
    return DataLoader(
        self.train_dataset,
        batch_size=self.train_batch_size,
        shuffle=True,  # training without shuffling sees batches in a fixed order
        num_workers=self.num_workers,
    )

  def val_dataloader(self):
    """Loader over the validation subset, in a fixed order."""
    return DataLoader(
        self.val_dataset,
        batch_size=self.val_batch_size,
        num_workers=self.num_workers,
    )

In [None]:
# Hugging Face id of the text backbone; the matching tokenizer is loaded from it too.
text_encoder_alias = "distilbert-base-uncased"

data_module = ImageRetrievalDataModule(
    dataset_name="flickr8k",
    artifact_id="wandb/clip.lightning-image_retrieval/flickr-8k:latest",
    tokenizer_alias=text_encoder_alias,
    lazy_loading=True,
)
train_loader, val_loader = data_module.train_dataloader(), data_module.val_dataloader()

In [None]:
# len(DataLoader) counts batches per split, not samples.
n_train_batches, n_val_batches = len(train_loader), len(val_loader)
print(f'train: {n_train_batches}, val: {n_val_batches}')

## Model initialization

In [None]:
from flax.training import train_state, checkpoints
import optax

In [None]:
import jax

# Hub id of the ResNet-50 backbone (2048 pooled features, matching the
# model's default image_embedding_dims). A bare "resnet50" is not a
# transformers ResNet checkpoint id.
image_encoder_alias = "microsoft/resnet-50"
model = CLIPDualEncoderModel(image_encoder_alias, text_encoder_alias)

# Flax modules are stateless: parameters are created by `init` on a sample batch.
# NOTE(review): the HF backbones keep their own pretrained weights outside these
# params, so only the projection heads end up in `state.params`.
sample_batch = next(iter(train_loader))
sample_inputs = {
    key: sample_batch[key].numpy()
    for key in ("image", "input_ids", "attention_mask")
}
variables = model.init(jax.random.PRNGKey(0), sample_inputs)
state = train_state.TrainState.create(
    apply_fn=model.apply,  # TrainState expects the functional apply, not __call__
    params=variables["params"],
    tx=optax.adam(1e-3),
)