## Libraries

In [None]:
!pip install openai-clip
!pip install datasets

Collecting openai-clip
  Downloading openai-clip-1.0.1.tar.gz (1.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m13.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting ftfy (from openai-clip)
  Downloading ftfy-6.2.0-py3-none-any.whl (54 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.4/54.4 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: openai-clip
  Building wheel for openai-clip (setup.py) ... [?25l[?25hdone
  Created wheel for openai-clip: filename=openai_clip-1.0.1-py3-none-any.whl size=1368605 sha256=f99276000b64a92a2b653f707e6748d9c7eb6f0c548d3a6ca22362c2d04506f6
  Stored in directory: /root/.cache/pip/wheels/08/77/8e/8d2f862df6bf7fb4e2007062d2cbaeae49862ec7b56d041229
Successfully built openai-clip
Installing collected packages: ftfy, openai-clip
Successfully installed ftfy-6.2.0 openai-clip-1.0.1
Collecting datasets
  Downlo

In [None]:
import json
import os
import pickle
import re
import string
from io import BytesIO, StringIO

import pandas as pd
import requests
from PIL import Image
from tqdm import tqdm
from datasets import Dataset, DatasetDict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import clip
from transformers import CLIPProcessor, CLIPModel, CLIPVisionConfig

In [None]:
# Mount Google Drive: the caption CSV is read from it and training checkpoints are saved under it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


## Create training and test datasets

In [None]:
# NOTE(review): `image_path` is not referenced below; images are fetched from the URLs in the CSV's image column.
image_path = "/content/drive/MyDrive/datasets/pokemon_png"
caption_path = "/content/drive/MyDrive/datasets/pokemon_caption.csv"

# Closing markup that wraps both scraped caption versions.
CAPTION_CLOSING_TAG = '\n                </p>'


def _clean_caption_column(captions, opening_tag):
    """Strip the scraped <p> wrapper, punctuation and outer whitespace from a caption Series.

    Args:
        captions: pandas Series of raw caption strings; missing values stay missing.
        opening_tag: literal opening-tag text to remove for this caption version.

    Returns:
        A new cleaned Series; the input Series is not modified.
    """
    return (
        captions
        # regex=False: the tags are literal text, not patterns.
        .str.replace(opening_tag, "", regex=False)
        .str.replace(CAPTION_CLOSING_TAG, "", regex=False)
        .str.translate(str.maketrans('', '', string.punctuation))
        .str.strip()
    )


# read caption file
raw_caption_df = pd.read_csv(caption_path)

# tidy up columns: drop positional columns 0, 1 and 3, name the remaining four
wide_caption_df = raw_caption_df.drop(columns=raw_caption_df.columns[[0, 1, 3]])
wide_caption_df.columns = ['pokedex', 'image', 'caption_1', 'caption_2']

# remove tag, carriage return, punctuation, leading/trailing space from each caption version
wide_caption_df["caption_1"] = _clean_caption_column(wide_caption_df["caption_1"], '<p class="version-x active">\n')
wide_caption_df["caption_2"] = _clean_caption_column(wide_caption_df["caption_2"], '<p class="version-y">\n')

# one row per (image, caption) pair: stack caption 1 and caption 2 underneath each other
caption_df = pd.concat(
    [
        wide_caption_df[['pokedex', 'image', column]].rename(columns={'image': 'filename', column: 'caption'})
        for column in ('caption_1', 'caption_2')
    ],
    ignore_index=True,
)

# drop caption rows with no image, and rows with no caption text (NaN or empty after cleaning)
caption_df = caption_df.dropna(subset=['filename', 'caption'])
caption_df = caption_df[caption_df['caption'] != ''].reset_index(drop=True)

# check
caption_df.head()

Unnamed: 0,pokedex,filename,caption
0,1,https://assets.pokemon.com/assets/cms2/img/pok...,For some time after its birth it uses the nutr...
1,2,https://assets.pokemon.com/assets/cms2/img/pok...,The more sunlight Ivysaur bathes in the more s...
2,3,https://assets.pokemon.com/assets/cms2/img/pok...,While it basks in the sun it can convert the l...
3,4,https://assets.pokemon.com/assets/cms2/img/pok...,The flame on its tail shows the strength of it...
4,5,https://assets.pokemon.com/assets/cms2/img/pok...,When it swings its burning tail the temperatur...


In [None]:
class image_title_dataset():
    """Map-style dataset pairing remote image URLs with CLIP-tokenized captions.

    Items are ``(transformed_image, caption_tokens)``. Each image is downloaded on access.

    Args:
        list_image_path: image URLs, aligned one-to-one with ``list_txt``.
        list_txt: caption strings, tokenized once up front with ``clip.tokenize``.
        transform: callable applied to each PIL image. When None, the module-level
            ``preprocess`` (returned by ``clip.load``) is looked up when an item is read.
        timeout: seconds to wait on each HTTP request before giving up.

    Raises:
        ValueError: if the two input lists differ in length.
    """

    def __init__(self, list_image_path, list_txt, transform=None, timeout=30):
        if len(list_image_path) != len(list_txt):
            raise ValueError(
                f"got {len(list_image_path)} image paths but {len(list_txt)} captions"
            )
        self.image_path = list_image_path
        self.title = clip.tokenize(list_txt)
        self.transform = transform
        self.timeout = timeout

    def __len__(self):
        return len(self.title)

    def _load_image(self, url):
        """Download ``url`` and decode it with PIL; HTTP errors raise instead of yielding garbage."""
        with requests.get(url, timeout=self.timeout) as response:
            response.raise_for_status()
            # Buffer the body so the image stays readable after the response is closed.
            return Image.open(BytesIO(response.content))

    def __getitem__(self, idx):
        # Resolve the global lazily so datasets built before `clip.load` runs keep working.
        transform = self.transform if self.transform is not None else preprocess
        image = transform(self._load_image(self.image_path[idx]))
        return image, self.title[idx]

In [None]:
SPLIT_SEED = 123
TEST_FRACTION = 0.1


def gen():
    """Yield one ``{"caption", "filename"}`` record per row of the global ``caption_df``."""
    for index, row in caption_df.iterrows():
        yield {"caption": row["caption"], "filename": row["filename"]}


all_pairs = Dataset.from_generator(gen).shuffle(seed=SPLIT_SEED)

# Split by image, not by caption row: each image URL is paired with more than one caption
# version, so a row-level split can put the same image in both train and test.
test_images = set(
    caption_df["filename"].drop_duplicates().sample(frac=TEST_FRACTION, random_state=SPLIT_SEED)
)
dataset = DatasetDict({
    "train": all_pairs.filter(lambda pair: pair["filename"] not in test_images),
    "test": all_pairs.filter(lambda pair: pair["filename"] in test_images),
})
assert set(dataset["train"]["filename"]).isdisjoint(dataset["test"]["filename"]), "image leaked across splits"

# Plain Python lists consumed by the CLIP training dataset.
list_image_path = list(dataset["train"]["filename"])
list_txt = list(dataset["train"]["caption"])

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
BATCH_SIZE = 100
LOADER_SEED = 123

# Use a separate name so `dataset` (the Hugging Face train/test split) is not overwritten.
train_dataset = image_title_dataset(list_image_path, list_txt)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # Seeded generator makes the batch order reproducible across runs.
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

## Load the pretrained model and set up optimizer and losses

In [None]:
# `model` is created by `clip.load(...)` in the next cell; no Hugging Face `CLIPModel` is
# loaded here, because any `model` bound in this cell would be overwritten there.
# NOTE(review): `processor` is not used by the training loop below; kept for any later cells that need it.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

'cuda:0'

In [None]:
# Load OpenAI CLIP ViT-B/32 and its matching image preprocessing transform onto `device`.
# jit=False selects the regular (non-TorchScript) model, which is the one meant to be modified/fine-tuned.
CLIP_BACKBONE = "ViT-B/32"
model, preprocess = clip.load(CLIP_BACKBONE, device=device, jit=False)

100%|████████████████████████████████████████| 338M/338M [00:02<00:00, 136MiB/s]


In [None]:
def convert_models_to_fp32(model):
    """Cast every parameter of ``model`` and its gradient, in place, to float32.

    Used before ``optimizer.step()`` so the update runs in full precision. Parameters
    whose ``.grad`` is None (never touched by backward) are cast but their missing
    gradient is left as None instead of raising.

    Args:
        model: any ``torch.nn.Module``.
    """
    for p in model.parameters():
        p.data = p.data.float()
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

In [None]:
# Decoupled weight decay (AdamW), as in the CLIP training recipe, applied only to weight
# matrices. 0-/1-D tensors (e.g. biases, normalisation gains, scalar scales) are not decayed;
# plain Adam's weight_decay would instead add an L2 term to every parameter's gradient.
decay_params = [p for p in model.parameters() if p.ndim >= 2]
no_decay_params = [p for p in model.parameters() if p.ndim < 2]
optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": 0.2},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=5e-5,
    betas=(0.9, 0.98),
    eps=1e-6,
)
# Symmetric contrastive loss: image->text and text->image cross-entropy.
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

## Model Training

In [None]:
num_epochs = 10
checkpoint_dir = "/content/drive/MyDrive/Pretrained Models"
# torch.save does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

# Ensure training-mode behaviour; the loaded model may come back in eval mode.
model.train()

for epoch in range(num_epochs):
    pbar = tqdm(train_dataloader, total=len(train_dataloader))
    epoch_loss_sum = 0.0
    for images, texts in pbar:
        optimizer.zero_grad()

        images = images.to(device)
        texts = texts.to(device)

        logits_per_image, logits_per_text = model(images, texts)

        # The i-th image matches the i-th caption, so targets are the diagonal indices.
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2

        total_loss.backward()
        if device == "cpu":
            # CPU weights are already float32; converting them to half would break CPU training.
            optimizer.step()
        else:
            # Step in float32, then cast weights back to half for the next forward pass.
            convert_models_to_fp32(model)
            optimizer.step()
            clip.model.convert_weights(model)

        batch_loss = total_loss.item()
        epoch_loss_sum += batch_loss
        pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {batch_loss:.4f}")

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # last-batch loss, detached so the checkpoint does not carry autograd state
        'loss': total_loss.detach(),
        'mean_loss': epoch_loss_sum / len(train_dataloader),
        }, os.path.join(checkpoint_dir, f"clip_hard_{epoch}.pt"))

Epoch 0/10, Loss: 3.1133: 100%|██████████| 19/19 [03:27<00:00, 10.92s/it]
Epoch 1/10, Loss: 2.5312: 100%|██████████| 19/19 [03:19<00:00, 10.51s/it]
Epoch 2/10, Loss: 1.3711: 100%|██████████| 19/19 [03:17<00:00, 10.39s/it]
Epoch 3/10, Loss: 0.4082: 100%|██████████| 19/19 [03:21<00:00, 10.62s/it]
Epoch 4/10, Loss: 0.3752: 100%|██████████| 19/19 [03:19<00:00, 10.48s/it]
Epoch 5/10, Loss: 0.5234: 100%|██████████| 19/19 [03:16<00:00, 10.32s/it]
Epoch 6/10, Loss: 0.2979: 100%|██████████| 19/19 [03:16<00:00, 10.33s/it]
Epoch 7/10, Loss: 0.4375: 100%|██████████| 19/19 [03:24<00:00, 10.74s/it]
Epoch 8/10, Loss: 0.3826: 100%|██████████| 19/19 [03:18<00:00, 10.46s/it]
Epoch 9/10, Loss: 0.7168: 100%|██████████| 19/19 [03:17<00:00, 10.40s/it]
