In [None]:
!pip install wandb

In [None]:
#!pip install transformers
!pip install git+https://github.com/huggingface/transformers.git

In [None]:
!pip install git+https://github.com/Lightning-AI/lightning.git

In [None]:
!pip install albumentations

In [None]:
import functools
import os

import cv2
import pandas as pd
import torch
import torchvision.transforms as transforms
import wandb

In [None]:
import lightning

In [None]:
import albumentations as A

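## Image Encoder

Wrap the pretrained `FlaxResNetModel` from Hugging Face Transformers, together with its matching `AutoImageProcessor`, in a Flax module that maps raw images to the backbone's pooled features.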

In [None]:
from transformers import AutoImageProcessor, FlaxResNetModel

In [None]:
from flax import linen as nn

In [None]:
@functools.lru_cache(maxsize=None)
def _load_image_backbone(model_name: str):
  """Load the pretrained Flax ResNet and its image processor once per ``model_name``.

  Flax calls ``setup()`` lazily on every bound copy of a module (each
  ``init``/``apply``), so loading inside ``setup`` without a cache would reload
  the pretrained weights on every forward pass.

  Returns:
    ``(model, image_processor)`` tuple.
  """
  model = FlaxResNetModel.from_pretrained(model_name)
  image_processor = AutoImageProcessor.from_pretrained(model_name)
  return model, image_processor


class ImageEncoder(nn.Module):
  """Encode raw images with a pretrained Hugging Face ``FlaxResNetModel``.

  Linen modules are dataclasses: configuration is declared as fields below and
  sub-objects are attached in ``setup()`` (overriding ``__init__`` breaks
  Flax's module machinery).

  Attributes:
    model_name: hub id or local path passed to ``from_pretrained``.
    pretrained: currently unused; weights are always loaded via ``from_pretrained``.
    trainable: currently unused.
      NOTE(review): the wrapped HF model keeps its own ``params`` outside this
      module's variable collections, so an optimizer over this module's params
      will not update the backbone — TODO decide how fine-tuning should work.
  """
  model_name: str
  pretrained: bool = True
  trainable: bool = True

  def setup(self):
    # Shared, cached objects: setup() runs on every bind, loading does not.
    self.model, self.image_processor = _load_image_backbone(self.model_name)

  def __call__(self, x):
    """Return the backbone's ``pooler_output`` for images ``x``.

    Args:
      x: images in any form the ``AutoImageProcessor`` accepts
        (presumably a batch of PIL images or arrays — confirm with the data pipeline).

    Returns:
      ``pooler_output`` of the ResNet. NOTE(review): this may be shaped
      ``(batch, channels, 1, 1)`` rather than ``(batch, channels)``; flatten it
      before feeding a projection head if so — confirm.
    """
    # Preprocessing is host-side numpy (return_tensors='np'), so this call is
    # presumably not jit-compatible as a whole.
    inputs = self.image_processor(images=x, return_tensors='np')
    outputs = self.model(**inputs)
    return outputs.pooler_output

## Text Encoder

In [None]:
from transformers import FlaxAutoModel

In [None]:
@functools.lru_cache(maxsize=None)
def _load_text_backbone(model_name: str):
  """Load the pretrained Flax transformer via ``FlaxAutoModel`` once per ``model_name``.

  Cached because Flax calls ``setup()`` on every bound copy of a module
  (each ``init``/``apply``); without the cache the weights would be reloaded
  on every forward pass.
  """
  return FlaxAutoModel.from_pretrained(model_name)


class TextEncoder(nn.Module):
  """Encode token ids with a pretrained Hugging Face Flax transformer.

  Returns the hidden state of a single token position, ``target_token_idx``
  (default 0 — presumably the [CLS]/first token for BERT-style models).

  Attributes:
    model_name: hub id or local path passed to ``FlaxAutoModel.from_pretrained``.
    trainable: currently unused.
      NOTE(review): the wrapped HF model keeps its own ``params`` outside this
      module's variable collections, so they are not trained through this
      module — TODO decide how fine-tuning should work.
    target_token_idx: sequence position whose hidden state is returned.
  """
  model_name: str
  trainable: bool = True
  target_token_idx: int = 0

  def setup(self):
    # Shared, cached model: setup() runs on every bind, loading does not.
    self.model = _load_text_backbone(self.model_name)

  def __call__(self, input_ids, attention_mask):
    """Return ``last_hidden_state[:, target_token_idx, :]``.

    Args:
      input_ids: token ids from the matching tokenizer.
      attention_mask: mask passed straight through to the model.

    Returns:
      The hidden state at ``target_token_idx`` for every sequence in the batch.
    """
    output = self.model(input_ids=input_ids, attention_mask=attention_mask)
    hidden = output.last_hidden_state  # indexed below as (batch, seq, hidden)
    return hidden[:, self.target_token_idx, :]

## Projection Head

In [None]:
class ProjectionHead(nn.Module):
  """Project encoder features into a shared embedding space.

  Computes ``LayerNorm(Dropout(Dense(gelu(p))) + p)`` where ``p = Dense(x)``:
  a linear projection followed by a residual MLP block.

  Attributes:
    embedding_dim: size of the input features. Kept for interface
      compatibility; ``nn.Dense`` infers its input size from ``x``, so it is
      not used in the computation.
    projection_dim: output size of both dense layers.
    dropout: dropout rate applied after the second dense layer.
  """
  embedding_dim: int
  projection_dim: int
  dropout: float

  @nn.compact
  def __call__(self, x, deterministic: bool = True):
    """Apply the projection head.

    Args:
      x: input features; ``nn.Dense`` acts on the last axis.
      deterministic: when True (default) dropout is a no-op. Pass False for
        training and supply a ``'dropout'`` PRNG key via
        ``apply(..., rngs={'dropout': key})``.

    Returns:
      Array whose last axis has size ``projection_dim``.
    """
    # Explicit names keep the parameter tree as projection / fc / layer_norm.
    projected = nn.Dense(self.projection_dim, name='projection')(x)
    # nn.gelu is an activation function, not a layer factory: apply it to the input.
    # NOTE(review): it defaults to the tanh approximation; pass approximate=False
    # if exact parity with torch.nn.GELU() is required.
    hidden = nn.gelu(projected)
    hidden = nn.Dense(self.projection_dim, name='fc')(hidden)
    hidden = nn.Dropout(rate=self.dropout)(hidden, deterministic=deterministic)
    hidden = hidden + projected  # residual connection around the MLP
    return nn.LayerNorm(name='layer_norm')(hidden)