## Libraries

In [None]:
# Make Google Drive visible to this runtime under /content/drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
!pip install -q git+https://github.com/huggingface/peft.git transformers bitsandbytes datasets

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m102.2/102.2 MB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m50.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m297.3/297.3 kB[0m [31m38.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m18.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m27.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m69.5 

In [None]:
import string, os, re, pickle
import numpy as np
import pandas as pd
from PIL import Image
import requests

#from sklearn.model_selection import train_test_split
import torch
from datasets import Dataset
from transformers import AutoProcessor, Blip2ForConditionalGeneration

## Create training and testing dataset

In [None]:
# Repeats the mount from the first cell so this section can be run on its own;
# when Drive is already mounted the call only reports that and does nothing else.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Locations on the mounted Drive: a directory of PNG images and the caption CSV.
# NOTE(review): only caption_path is read by the active code in the cells shown
# here; image_path appears solely in a commented-out line.
image_path, caption_path = (
    "/content/drive/MyDrive/datasets/pokemon_png",
    "/content/drive/MyDrive/datasets/pokemon_caption.csv",
)

In [None]:
def _clean_caption(captions, opening_tag):
    """Return ``captions`` without the scraped ``<p>`` wrapper, punctuation or outer whitespace.

    Args:
        captions: String Series holding one raw caption per row.
        opening_tag: Literal opening-tag text (including its trailing newline) to delete.

    Returns:
        A new Series of cleaned captions; missing values stay missing.
    """
    return (
        captions
        # regex=False: the tags are literal text, and the default for `regex`
        # differs between pandas versions.
        .str.replace(opening_tag, "", regex=False)
        .str.replace('\n                </p>', "", regex=False)
        .str.translate(str.maketrans('', '', string.punctuation))
        .str.strip()
    )


#read caption file
caption_df = pd.read_csv(caption_path)

#tidy up columns: discard positions 0, 1 and 3, then name the remaining four
caption_df = caption_df.drop(columns=caption_df.columns[[0, 1, 3]])
caption_df.columns = ['pokedex', 'image', 'caption_1', 'caption_2']

#remove tag, carriage return, punctuation, leading/trailing space from caption
caption_df["caption_1"] = _clean_caption(caption_df["caption_1"], '<p class="version-x active">\n')
caption_df["caption_2"] = _clean_caption(caption_df["caption_2"], '<p class="version-y">\n')

#stack caption 1 and 2 into one long (image, text) table, two rows per image
dataset_df = pd.concat(
    [
        caption_df[['image', column]].rename(columns={column: 'text'})
        for column in ('caption_1', 'caption_2')
    ],
    ignore_index=True,
)

#drop rows with no image (or no caption) and renumber the remaining rows;
#previously only the index was reset, so incomplete rows were kept
dataset_df = dataset_df.dropna(subset=['image', 'text']).reset_index(drop=True)

#check
display(dataset_df)

Unnamed: 0,image,text
0,https://assets.pokemon.com/assets/cms2/img/pok...,For some time after its birth it uses the nutr...
1,https://assets.pokemon.com/assets/cms2/img/pok...,The more sunlight Ivysaur bathes in the more s...
2,https://assets.pokemon.com/assets/cms2/img/pok...,While it basks in the sun it can convert the l...
3,https://assets.pokemon.com/assets/cms2/img/pok...,The flame on its tail shows the strength of it...
4,https://assets.pokemon.com/assets/cms2/img/pok...,When it swings its burning tail the temperatur...
...,...,...
2045,https://assets.pokemon.com/assets/cms2/img/pok...,It bears resemblance to a Pokémon that became ...
2046,https://assets.pokemon.com/assets/cms2/img/pok...,It was named after a mysterious object recorde...
2047,https://assets.pokemon.com/assets/cms2/img/pok...,There was supposedly an incident in which it l...
2048,https://assets.pokemon.com/assets/cms2/img/pok...,It’s thought that this Pokémon lived in ancien...


In [None]:
def gen():
    """Yield one ``{"text": ..., "image": ...}`` record per row of ``dataset_df``.

    Each image is downloaded from the URL stored in the row's ``image`` column.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within the timeout.
    """
    from io import BytesIO

    for caption, url in zip(dataset_df["text"], dataset_df["image"]):
        # A timeout keeps one stalled download from hanging the whole build.
        response = requests.get(url, timeout=30)
        # Fail loudly instead of handing PIL the body of an error page.
        response.raise_for_status()
        # Decode from the fully downloaded bytes so no open connection is held.
        yield {"text": caption, "image": Image.open(BytesIO(response.content))}


# A later cell re-binds the name `Dataset` to torch.utils.data.Dataset, which has
# no `from_generator`; importing the Hugging Face class under its own name keeps
# this cell re-runnable regardless of execution order.
from datasets import Dataset as HFDataset

dataset = HFDataset.from_generator(gen).shuffle(seed=123)
# The data was already shuffled with a fixed seed above, so split without reshuffling.
dataset = dataset.train_test_split(test_size=0.1, shuffle=False)

Generating train split: 0 examples [00:00, ? examples/s]

## Create Dataset for fine-tuning

In [None]:
from torch.utils.data import Dataset, DataLoader

class ImageCaptioningDataset(Dataset):
    """Torch dataset pairing processor-encoded images with their raw caption text.

    Args:
        dataset: Indexable collection whose items expose ``"image"`` and ``"text"`` keys.
        processor: Callable invoked as
            ``processor(images=..., padding="max_length", return_tensors="pt")``
            that returns a mapping of tensors.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the encoded image tensors plus the untokenized caption under ``"text"``."""
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], padding="max_length", return_tensors="pt")
        # remove the batch dimension only: squeeze(0) leaves any other size-1
        # axis intact, whereas a bare squeeze() would collapse those as well
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        # The caption stays a plain string; it is tokenized per batch at collate time.
        encoding["text"] = item["text"]
        return encoding

def collate_fn(batch):
    """Merge a list of per-sample dicts into one batch dict.

    Every tensor field is stacked along a new leading axis. The ``"text"``
    field is tokenized with the module-level ``processor`` (padded to the
    longest caption in the batch) and replaced by ``"input_ids"`` and
    ``"attention_mask"``.
    """
    collated = {}
    for field in batch[0]:
        if field == "text":
            captions = [sample["text"] for sample in batch]
            tokens = processor.tokenizer(captions, padding=True, return_tensors="pt")
            collated["input_ids"] = tokens["input_ids"]
            collated["attention_mask"] = tokens["attention_mask"]
            continue
        collated[field] = torch.stack([sample[field] for sample in batch])
    return collated


## Pretrained models and parameters loading

In [None]:
from transformers import BitsAndBytesConfig

# One checkpoint name shared by the processor and the model so they cannot drift apart.
BLIP2_CHECKPOINT = "Salesforce/blip2-opt-2.7b"

processor = AutoProcessor.from_pretrained(BLIP2_CHECKPOINT)
# Request 8-bit weights through `quantization_config`; passing `load_in_8bit=`
# directly to from_pretrained is deprecated.
model = Blip2ForConditionalGeneration.from_pretrained(
    BLIP2_CHECKPOINT,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA settings, collected in one place; adapters target the modules named
# "q_proj" and "k_proj".
lora_settings = dict(
    target_modules=["q_proj", "k_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)
config = LoraConfig(**lora_settings)

# Wrap the loaded model with the adapters and report how many parameters will train.
model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 5,242,880 || all params: 3,749,922,816 || trainable%: 0.13981301102065136


In [None]:
# Wrap the training split and batch it; captions are tokenized inside collate_fn.
train_dataset = ImageCaptioningDataset(dataset=dataset["train"], processor=processor)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=20,
    shuffle=True,
    collate_fn=collate_fn,
)

## Model Training

In [None]:
import torch

# Hyperparameters of the fine-tuning run.
NUM_EPOCHS = 10
LEARNING_RATE = 5e-4

# Give the optimizer only the parameters that can receive gradients; the rest
# of the wrapped model is frozen and would just be carried along unused.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)

device = "cuda" if torch.cuda.is_available() else "cpu"

model.train()

for epoch in range(NUM_EPOCHS):
  print("Epoch:", epoch)
  for batch in train_dataloader:
    input_ids = batch.pop("input_ids").to(device)
    attention_mask = batch.pop("attention_mask").to(device)
    pixel_values = batch.pop("pixel_values").to(device, torch.float16)

    # Captions in a batch are padded to a common length. Padding positions get
    # the label -100, which the cross-entropy loss ignores, so the model is not
    # trained to predict pad tokens.
    labels = input_ids.masked_fill(attention_mask == 0, -100)

    outputs = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    pixel_values=pixel_values,
                    labels=labels)

    loss = outputs.loss

    print("Loss:", loss.item())

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()
  # One adapter checkpoint per epoch. The directory name is kept exactly as it
  # was so existing checkpoints on Drive stay in the same place.
  model.save_pretrained(f"/content/drive/MyDrive/Pretained Models/blip2_hard_{epoch}")

Epoch: 0
Loss: 6.5
Loss: 5.94140625
Loss: 5.9296875
Loss: 5.5
Loss: 5.2890625
Loss: 5.12890625
Loss: 4.42578125
Loss: 4.390625
Loss: 4.29296875
Loss: 4.12890625
Loss: 4.19921875
Loss: 3.734375
Loss: 3.96484375
Loss: 3.720703125
Loss: 3.71875
Loss: 3.775390625
Loss: 3.75390625
Loss: 3.265625
Loss: 3.31640625
Loss: 3.359375
Loss: 3.37109375
Loss: 3.291015625
Loss: 3.109375
Loss: 3.06640625
Loss: 3.119140625
Loss: 3.052734375
Loss: 3.16796875
Loss: 3.107421875
Loss: 3.2265625
Loss: 3.080078125
Loss: 2.736328125
Loss: 2.90234375
Loss: 2.962890625
Loss: 2.6875
Loss: 3.08984375
Loss: 2.76171875
Loss: 3.01953125
Loss: 2.76953125
Loss: 3.044921875
Loss: 2.76953125
Loss: 2.82421875
Loss: 2.90234375
Loss: 2.810546875
Loss: 2.798828125
Loss: 2.970703125
Loss: 2.701171875
Loss: 2.763671875
Loss: 2.306640625
Loss: 2.654296875
Loss: 2.23046875
Loss: 3.0078125
Loss: 2.84765625
Loss: 2.82421875
Loss: 2.4375
Loss: 2.572265625
Loss: 2.634765625
Loss: 2.341796875
Loss: 2.57421875
Loss: 2.6484375
Loss: 2.