In [2]:
import os
import cv2
import gc
import numpy as np
import pandas as pd
import itertools
from tqdm.autonotebook import tqdm
import albumentations as A

import torch
from torch import nn
import torch.nn.functional as F
import timm
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer

### Download Dataset

In [11]:
!pip install kaggle --upgrade
os.environ['KAGGLE_USERNAME'] = "boringmachine"
os.environ['KAGGLE_KEY'] = "81a22b5cce4e170d0996b2bb527bcfea"

### For Flickr 8k
!kaggle datasets download -d adityajn105/flickr8k
!unzip flickr8k.zip
dataset = "8k"


### For Flickr 30k
# !kaggle datasets download -d hsankesara/flickr-image-dataset
# !unzip flickr-image-dataset.zip
# dataset = "30k"

Collecting kaggle
  Downloading kaggle-1.5.12.tar.gz (58 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.0/59.0 kB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: kaggle
  Building wheel for kaggle (setup.py) ... [?25ldone
[?25h  Created wheel for kaggle: filename=kaggle-1.5.12-py3-none-any.whl size=73031 sha256=4638a364e2444ccb111526debb52979f0ab34ea23d5a2413f1a7fe80690416bf
  Stored in directory: /root/.cache/pip/wheels/ac/b2/c3/fa4706d469b5879105991d1c8be9a3c2ef329ba9fe2ce5085e
Successfully built kaggle
Installing collected packages: kaggle
Successfully installed kaggle-1.5.12
[0mDownloading flickr8k.zip to /notebooks/ai_notebooks/clip
 98%|██████████████████████████████████████▏| 1.02G/1.04G [00:06<00:00, 235MB/s]
100%|███████████████████████████████████████| 1.04G/1.04G [00:06<00:00, 180MB/s]
Archive:  flickr8k.zip
  inflating: Images/1000268201_693b08cb0e.jpg  
  inf

In [12]:
# Build the captions table for the selected dataset and write it to
# "captions.csv" with an extra integer `id` column shared by all captions of
# the same image.
if dataset == "8k":
  df = pd.read_csv("captions.txt")
  # Rows are grouped in runs of five captions per image (see df.head() below),
  # so integer-dividing the row position by 5 yields one id per image. Unlike a
  # fixed-length list this also works if the row count is not a multiple of 5.
  df['id'] = np.arange(len(df)) // 5
  df.to_csv("captions.csv", index=False)
  # Re-read so `df` reflects exactly what was written to disk.
  df = pd.read_csv("captions.csv")
  # NOTE(review): these are Colab-style absolute paths; later cells reassign
  # image_path / captions_path to ./Flickr8k/... -- confirm which is intended.
  image_path = "/content/Images"
  captions_path = "/content"
elif dataset == "30k":
  df = pd.read_csv("/content/flickr30k_images/results.csv", delimiter="|")
  df.columns = ['image', 'caption_number', 'caption']
  # Values carry leading whitespace after the "|" delimiter; strip it.
  df['caption'] = df['caption'].str.lstrip()
  df['caption_number'] = df['caption_number'].str.lstrip()
  # Row 19999 is overwritten by hand -- presumably it is malformed in the raw
  # results.csv; verify against the source file.
  df.loc[19999, 'caption_number'] = "4"
  df.loc[19999, 'caption'] = "A dog runs across the grass ."
  df['id'] = np.arange(len(df)) // 5
  df.to_csv("captions.csv", index=False)
  image_path = "/content/flickr30k_images/flickr30k_images"
  captions_path = "/content"
else:
  # Fail loudly instead of hitting a NameError on `df` further down.
  raise ValueError(f"Unknown dataset {dataset!r}; expected '8k' or '30k'")

df.head()

Unnamed: 0,image,caption,id
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...,0
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,0
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .,0
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...,0
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...,0


### Config

In [16]:
from pathlib import Path

In [24]:
# Directory holding the Flickr8k images, relative to the notebook's working directory.
image_path = Path("Flickr8k", "Images")

In [25]:
image_path

PosixPath('Flickr8k/Images')

In [30]:
# Captions file for Flickr8k, relative to the notebook's working directory.
captions_path = Path("Flickr8k", "captions.txt")

In [31]:
# Load the raw captions table; this rebinds `df`, replacing any frame built by earlier cells.
df = pd.read_csv(filepath_or_buffer=captions_path)

In [32]:
class CFG:
    """Notebook-wide settings, read as plain class attributes (e.g. ``CFG.batch_size``)."""

    # --- run mode and data locations ---
    debug = False
    image_path = "./Flickr8k/Images"
    captions_path = "./Flickr8k"

    # --- data loading ---
    batch_size = 32
    num_workers = 4

    # --- optimisation (presumably consumed by the optimizer / LR scheduler
    # set up elsewhere in the notebook -- confirm) ---
    head_lr = 0.001
    image_encoder_lr = 0.0001
    text_encoder_lr = 0.00001
    weight_decay = 0.001
    patience = 1
    factor = 0.8
    epochs = 4

    # Use the GPU when one is visible to torch, otherwise fall back to the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- image encoder ---
    model_name = 'resnet50'
    image_embedding = 2048

    # --- text encoder / tokenizer ---
    text_encoder_model = "distilbert-base-uncased"
    text_embedding = 768
    text_tokenizer = "distilbert-base-uncased"
    max_length = 200

    # These two flags apply to both the image encoder and the text encoder.
    pretrained = True
    trainable = True
    temperature = 1.0

    # image size
    size = 224

    # Projection head settings; used for both image and text encoders.
    num_projection_layers = 1
    projection_dim = 256
    dropout = 0.1

### Utils

In [33]:
class AvgMeter:
    """Running weighted average of a numeric metric.

    Each call to :meth:`update` adds ``val`` weighted by ``count``; ``avg``
    always holds ``sum / count`` over everything seen since the last reset.
    """

    def __init__(self, name="Metric"):
        """Create an empty meter labelled ``name`` (used by ``repr``)."""
        self.name = name
        self.reset()

    def reset(self):
        """Clear the accumulated sum, count and average back to zero."""
        self.avg, self.sum, self.count = 0, 0, 0

    def update(self, val, count=1):
        """Fold ``val`` into the average with weight ``count``.

        Args:
            val: Metric value for the new observations (typically their mean).
            count: Number of observations ``val`` represents. Defaults to 1.
        """
        self.count += count
        self.sum += val * count
        # Guard against division by zero when nothing has been counted yet
        # (e.g. an empty batch); the previous average is kept in that case.
        if self.count:
            self.avg = self.sum / self.count

    def __repr__(self):
        """Return ``"<name>: <avg>"`` with the average shown to 4 decimals."""
        return f"{self.name}: {self.avg:.4f}"

In [34]:
def get_lr(optimizier):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Only the first group is inspected; ``None`` is returned when
    ``optimizier.param_groups`` is empty.
    """
    groups = iter(optimizier.param_groups)
    try:
        first_group = next(groups)
    except StopIteration:
        return None
    return first_group['lr']

### Dataset

`image_filenames` and `captions` must have the same length, so if there are multiple captions per image, repeat each entry in `image_filenames` once per caption.

In [10]:
class CLIPDataset(torch.utils.data.Dataset):
    """Dataset of (image, caption) pairs with pre-tokenized captions.

    ``image_filenames`` and ``captions`` must have the same length: entry ``i``
    of one belongs to entry ``i`` of the other, so an image with several
    captions appears once per caption.

    Args:
        image_filenames: Indexable collection of file names, resolved relative
            to ``CFG.image_path``.
        captions: Iterable of caption strings.
        tokenizer: Callable applied once to the full caption list; its result
            must expose ``.items()`` yielding per-caption sequences.
        transforms: Callable invoked as ``transforms(image=...)['image']``
            (presumably an albumentations pipeline -- confirm with the caller).
    """

    def __init__(self, image_filenames, captions, tokenizer, transforms):
        self.image_filenames = image_filenames
        # Materialise once so a one-shot iterable is not consumed twice.
        self.captions = list(captions)
        # Tokenize every caption up front. The keyword is `max_length`; the
        # previous `maxlength` spelling was not a recognised argument, so the
        # configured length cap was never applied.
        self.encoded_captions = tokenizer(
            self.captions, padding=True, truncation=True,
            max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        """Return a dict with the tokenizer fields, ``'image'`` and ``'caption'`` for ``idx``."""
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        image_file = f"{CFG.image_path}/{self.image_filenames[idx]}"
        image = cv2.imread(image_file)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_file}")
        # OpenCV loads images in BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        # Move the last axis first (assumes the transform output is
        # height x width x channel -- TODO confirm) and cast to float32.
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()
        item['caption'] = self.captions[idx]
        return item

    def __len__(self):
        """Number of (image, caption) pairs."""
        return len(self.captions)

In [None]:
values