## Set up environment

In [None]:
import argparse
import numpy as np
import torch
import json
import pprint
from PIL import Image, ImageDraw
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomCrop, RandomHorizontalFlip, RandomGrayscale, ColorJitter
import tempfile
import tqdm
import os
import collections
import sklearn.metrics
from scipy.stats import rankdata
import torchvision.transforms.functional as F
from matplotlib import pyplot as plt

In [None]:
import locale
# NOTE(review): overrides locale.getpreferredencoding so it always reports
# UTF-8 -- presumably the usual Colab workaround for `!` shell/pip cells
# failing under a non-UTF-8 locale; confirm it is still needed.
locale.getpreferredencoding = lambda: "UTF-8"
#!mkdir /content/images
# Download the Visual Genome VG_100K_2 image archive (~5.1 GB per the wget log)
# and extract it under /content/images.
# NOTE(review): Args.vg_dir is the relative path 'images/', so this assumes the
# notebook's working directory is /content -- confirm.
!wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip  # visual genome
# Alternative VCR image archive (VCR images are not used by this notebook):
# https://s3.us-west-2.amazonaws.com/ai2-rowanz/vcr1images.zip # VCR
!unzip images2.zip -d /content/images

--2023-12-01 23:08:05--  https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip
Resolving cs.stanford.edu (cs.stanford.edu)... 171.64.64.64
Connecting to cs.stanford.edu (cs.stanford.edu)|171.64.64.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5471658058 (5.1G) [application/zip]
Saving to: ‘images2.zip’

images2.zip          20%[===>                ]   1.03G  6.93MB/s    eta 13m 36s

In [None]:
# Install PEFT from its GitHub default branch plus transformers, bitsandbytes and
# datasets. Versions are unpinned, so behaviour may differ between runs.
!pip install -q git+https://github.com/huggingface/peft.git transformers bitsandbytes datasets

## Load dataset

## Create the PyTorch dataset

In [None]:
class SquarePad:
    """Zero-pad a PIL image so that its width equals its height.

    The shorter side is padded evenly on both ends; when the difference is
    odd, the extra pixel goes to the right / bottom edge.
    Adapted from
    https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
    """

    def __call__(self, image):
        width, height = image.size
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        right = side - width - left
        bottom = side - height - top
        return F.pad(image, (left, top, right, bottom), 0, 'constant')

In [None]:
class CLIPDataset(torch.utils.data.Dataset):
    """Map-style dataset of Sherlock (image, clue, inference) examples.

    ``__getitem__`` returns a dict holding:
    - the preprocessed ``image`` tensor;
    - the target inference text as ``caption``;
    - the ``clue``;
    - the record's ``id``;
    - a float ``ground_truth`` flag (1.0 when the image belongs to the record).

    Args:
        data: list of Sherlock records. The fields read are
            ``instance_id``, ``inputs.image.url``, ``inputs.bboxes``,
            ``inputs.clue`` and ``targets.inference``.
        args: config object. Reads ``widescreen_processing``,
            ``input_resolution``, ``hide_true_bbox`` and ``vg_dir``.
        training: when True, use the augmenting transforms.
    """

    def __init__(self, data, args, training=False):
        self.args = args
        self.data = data
        self.id2data = {d['instance_id']: d for d in self.data}
        self.training = training
        n_px = args.input_resolution
        if self.args.widescreen_processing in [0, 1]:
            # Resize + crop (mode 1 additionally splits the image into two squares).
            self.preprocess = self._transform_train(n_px) if self.training else self._transform_test(n_px)
        else:
            # Pad to square, then resize.
            self.preprocess = self._transform_train_pad(n_px) if self.training else self._transform_test_pad(n_px)

    def url2filepath(self, url):
        """Map a Visual Genome ``VG_100K_2`` image URL to a local path.

        The last two URL path components (``VG_100K_2/<file>``) are appended
        to ``args.vg_dir``.

        Raises:
            ValueError: for any other URL. Previously such URLs silently
                returned None and failed later inside ``Image.open``.
        """
        if 'VG_100K_2' in url:
            return self.args.vg_dir + '/'.join(url.split('/')[-2:])
        # else:
        #     # http://s3-us-west-2.amazonaws.com/ai2-rowanz/vcr1images/lsmdc_3023_DISTRICT_9/3023_DISTRICT_9_01.21.02.808-01.21.16.722@5.jpg
        #     if 'vcr1images' in self.args.vcr_dir:
        #         return self.args.vcr_dir + '/'.join(url.split('/')[-2:])
        #     else:
        #         return self.args.vcr_dir + '/'.join(url.split('/')[-3:])
        raise ValueError(f'Unsupported image URL (only VG_100K_2 images are handled): {url}')

    def hide_region(self, image, bboxes):
        """Draw ``bboxes`` onto ``image`` according to ``args.hide_true_bbox``.

        Modes:
            1          -- paint each box opaque grey (#7B7575) on the image.
            2,5,7,8,9  -- highlight: translucent pink fill with a green outline
                          on a transparent overlay, composited onto the image.
            3          -- blackout: opaque grey overlay with fully transparent
                          boxes, so only the boxed regions remain visible.
            6          -- position only: opaque grey overlay with opaque pink
                          boxes, hiding all image content.
            other      -- nothing is drawn.

        Each bbox is a dict with ``left``, ``top``, ``width`` and ``height``.
        Always returns an RGBA image.
        """
        mode = self.args.hide_true_bbox
        image = image.convert('RGBA')
        if mode == 1:  # hide mode: draw directly on the image
            draw = ImageDraw.Draw(image, 'RGBA')
        if mode in [2, 5, 7, 8, 9]:  # highlight mode: transparent overlay
            overlay = Image.new('RGBA', image.size, '#00000000')
            draw = ImageDraw.Draw(overlay, 'RGBA')
        if mode == 3 or mode == 6:  # blackout / position-only: opaque grey overlay
            overlay = Image.new('RGBA', image.size, '#7B7575ff')
            draw = ImageDraw.Draw(overlay, 'RGBA')
        for bbox in bboxes:
            x, y = bbox['left'], bbox['top']
            rect = [(x, y), (x + bbox['width'], y + bbox['height'])]
            if mode == 1:
                draw.rectangle(rect, fill='#7B7575')
            elif mode in [2, 5, 7, 8, 9]:
                draw.rectangle(rect, fill='#ff05cd3c', outline='#05ff37ff', width=3)
            elif mode == 3:
                # Fully transparent box: the image shows through after compositing.
                draw.rectangle(rect, fill='#00000000')
            elif mode == 6:
                draw.rectangle(rect, fill='#ff05cdff')

        if mode in [2, 3, 5, 6, 7, 8, 9]:
            image = Image.alpha_composite(image, overlay)

        return image

    def _transform_train(self, n_px):
        """Training pipeline: bicubic resize, random crop, flip, colour jitter."""
        return Compose([
            Resize(n_px, interpolation=Image.BICUBIC),
            RandomCrop(n_px),
            RandomHorizontalFlip(),
            #RandomGrayscale(), # these were used in the model in the paper, but, something seems bugged when pytorch updated.
            ColorJitter(brightness=.5, hue=.3),
            lambda image: image.convert("RGB"),
            ToTensor(),
            # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    def _transform_test(self, n_px):
        """Evaluation pipeline: bicubic resize and center crop."""
        return Compose([
            Resize(n_px, interpolation=Image.BICUBIC),
            CenterCrop(n_px),
            lambda image: image.convert("RGB"),
            ToTensor(),
            # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    def _transform_train_pad(self, n_px):
        """Training pipeline for pad mode: square-pad, resize, flip, colour jitter."""
        return Compose([
            SquarePad(),
            Resize(n_px, interpolation=Image.BICUBIC),
            RandomHorizontalFlip(),
            #RandomGrayscale(), # these were used in the model in the paper, but, something seems bugged when pytorch updated.
            ColorJitter(brightness=.5, hue=.3),
            lambda image: image.convert("RGB"),
            ToTensor(),
            # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    def _transform_test_pad(self, n_px):
        """Evaluation pipeline for pad mode: square-pad and resize."""
        return Compose([
            SquarePad(),
            Resize(n_px, interpolation=Image.BICUBIC),
            lambda image: image.convert("RGB"),
            ToTensor(),
            # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    def image_to_torch_tensor(self, image):
        """Apply ``self.preprocess`` to ``image``.

        With ``widescreen_processing == 1`` the image is first cut into two
        squares of side ``min(width, height)``: one at the left/top end and
        one at the right/bottom end. The two preprocessed crops are stacked
        along a new leading dimension.
        """
        if self.args.widescreen_processing == 1:
            width, height = image.size
            if width >= height:
                im1 = {'height': height, 'width': height, 'left': 0, 'top': 0}
                im2 = {'height': height, 'width': height, 'left': width-height, 'top': 0}
            else:
                im1 = {'height': width, 'width': width, 'left': 0, 'top': 0}
                im2 = {'height': width, 'width': width, 'left': 0, 'top': height-width}
            regions = [image.crop((bbox['left'], bbox['top'], bbox['left'] + bbox['width'], bbox['top'] + bbox['height'])) for bbox in [im1, im2]]
            image = torch.stack([self.preprocess(r) for r in regions], 0)
        else:
            image = self.preprocess(image)
        return image

    def _sample_negative_index(self, idx):
        """Return a uniformly random record index different from ``idx``.

        Returns None when the dataset has fewer than two records. In that case
        no such index exists, and the plain rejection loop would never terminate.
        """
        n = len(self)
        if n < 2:
            return None
        other = np.random.randint(0, n)
        while other == idx:
            other = np.random.randint(0, n)
        return other

    def __getitem__(self, idx):
        """Build the example for record ``idx``.

        The positive/negative coin flip is disabled: ``true_example`` is fixed
        to 1, so the image always comes from record ``idx``. Bboxes, clue and
        caption always come from record ``idx``.
        """
        c_data = self.data[idx]
        true_example = 1  # np.random.randint(0, 2) to re-enable negatives

        if true_example:
            image_record = c_data
        else:
            neg_idx = self._sample_negative_index(idx)
            if neg_idx is None:
                raise ValueError('Negative sampling needs at least two records in the dataset')
            image_record = self.data[neg_idx]

        image = Image.open(self.url2filepath(image_record['inputs']['image']['url']))

        if self.args.hide_true_bbox > 0:
            image = self.hide_region(image, c_data['inputs']['bboxes'])

        clue = c_data['inputs']['clue']
        caption = c_data['targets']['inference']

        cid = c_data['instance_id']
        image = self.image_to_torch_tensor(image)
        return {'image': image, 'caption': caption, 'clue': clue, 'id': cid, 'ground_truth': float(true_example)}

    def get(self, cid):
        """Return the record whose ``instance_id`` is ``cid``."""
        return self.id2data[cid]

    def __len__(self):
        return len(self.data)



def worker_init_fn(worker_id):
    """Seed NumPy's global RNG inside a DataLoader worker process.

    ``CLIPDataset`` draws from ``np.random``. The parent's NumPy state is copied
    unchanged into every worker, so seeding from it repeats the same streams
    every epoch. It can also overflow ``np.random.seed``'s 0..2**32-1 range
    once ``worker_id`` is added.

    Inside a worker, ``torch.initial_seed()`` is the DataLoader's base seed
    (regenerated for each iterator) plus ``worker_id``. That makes it distinct
    per worker and per epoch. Reducing it modulo 2**32 keeps it valid for NumPy.

    Args:
        worker_id: worker index passed by the DataLoader. It is already folded
            into ``torch.initial_seed()``; the parameter is kept for the
            ``worker_init_fn`` signature.
    """
    np.random.seed(torch.initial_seed() % 2**32)

In [None]:
class Args:
    """Hyper-parameters and paths used throughout this notebook.

    The defaults reproduce the original hard-coded configuration. Any
    attribute can be overridden by keyword, e.g. ``Args(batch_size=32)``.

    Raises:
        TypeError: if an override names an attribute not defined below.
    """

    def __init__(self, **overrides):
        self.batch_size = 64  # Increase this if your GPU can handle it
        self.lr = 0.00001
        self.n_epochs = 10
        # 0/1: resize + crop transforms (1 also splits into two square crops);
        # any other value: pad-to-square transforms (see CLIPDataset).
        self.widescreen_processing = 2
        # bbox drawing mode for CLIPDataset.hide_region; 0 disables it, 2 = highlight.
        self.hide_true_bbox = 2
        # prefix joined with 'VG_100K_2/<file>' by CLIPDataset.url2filepath
        self.vg_dir = 'images/'
        self.input_resolution = (224, 224)
        self.workers_dataloader = 4
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f'Args got an unexpected keyword argument {name!r}')
            setattr(self, name, value)


args = Args()

In [None]:
# Extract the Sherlock training annotations (sherlock_train_v1_1.json) into the
# current working directory. The archive path points at a personal Google Drive folder.
!unzip /content/drive/MyDrive/MMML-A2/sherlock_train_v1_1.json.zip # change with your path to sherlock train data

Archive:  /content/drive/MyDrive/MMML-A2/sherlock_train_v1_1.json.zip
replace sherlock_train_v1_1.json? [y]es, [n]o, [A]ll, [N]one, [r]ename: N


In [None]:
# Load the Sherlock training annotations. The encoding is explicit because
# open() does not consult the monkeypatched locale.getpreferredencoding.
with open("sherlock_train_v1_1.json", encoding="utf-8") as f:
    train = json.load(f)

# Keep only records whose image is in Visual Genome's VG_100K_2 set, the only
# image archive downloaded by this notebook.
new_train = [data for data in train if "VG_100K_2" in data['inputs']['image']['url']]


In [None]:
# Shuffled training loader over the VG-only records. worker_init_fn re-seeds
# NumPy in each worker process.
train_loader = torch.utils.data.DataLoader(
    dataset=CLIPDataset(new_train, args, training=True),
    shuffle=True,
    batch_size=args.batch_size,
    num_workers=args.workers_dataloader,
    worker_init_fn=worker_init_fn,
)

In [None]:
# NOTE(review): bitsandbytes is already installed by the earlier `!pip install`
# cell, so this reinstall looks redundant.
!pip install bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.41.2.post2-py3-none-any.whl (92.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.6/92.6 MB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.41.2.post2


In [None]:
from transformers import AutoTokenizer, BitsAndBytesConfig, Blip2Model, Blip2Processor

tokenizer = AutoTokenizer.from_pretrained("Salesforce/blip2-opt-2.7b")
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
# 8-bit bitsandbytes weights. Passing `load_in_8bit=True` directly to
# from_pretrained is deprecated in favour of `quantization_config`.
# device_map="auto" lets accelerate place the layers on the available devices.
model = Blip2Model.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

pytorch_model.bin.index.json:   0%|          | 0.00/122k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

pytorch_model-00001-of-00002.bin:   0%|          | 0.00/10.0G [00:00<?, ?B/s]

pytorch_model-00002-of-00002.bin:   0%|          | 0.00/5.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
from peft import LoraConfig, TaskType, get_peft_model

# The contrastive training loop runs only the vision encoder + Q-Former
# (get_qformer_features), `language_projection`, and the LM's input embedding
# table. It never runs the OPT decoder layers.
#
# Without target_modules, PEFT's BLIP-2 defaults adapted only the OPT decoder.
# The printed 5,242,880 trainable params equal r=16 adapters on q_proj/v_proj
# of all 32 layers. Those adapters never received gradients.
#
# Target the Q-Former attention projections and the Q-Former->LM projection
# instead.
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["query", "value", "language_projection"],
)

model = get_peft_model(model, config)
model.print_trainable_parameters()
# NOTE(review): the model was already dispatched with device_map="auto"; this
# .to() is kept from the original run, where it executed without error.
model = model.to(args.device)

trainable params: 5,242,880 || all params: 3,749,922,816 || trainable%: 0.13981301102065136


In [None]:
# Smoke test on a single batch. Prints the shapes of the Q-Former features,
# the LM token embeddings and their concatenation. Runs without autograd
# because nothing is trained here.
with torch.no_grad():
    for batch in train_loader:
        images, captions = batch['image'], batch['caption']
        # CLIPDataset yields float tensors already in [0, 1] (ToTensor, no
        # Normalize). do_rescale=False stops the image processor from dividing
        # by 255 a second time; it still resizes and normalizes.
        pixel_values = processor.image_processor(images, do_rescale=False, return_tensors="pt")['pixel_values']
        text_inputs = processor.tokenizer(captions, return_tensors="pt", padding=True)
        print(f'inputs: pixel_values {tuple(pixel_values.shape)}, text {list(text_inputs.keys())}')

        # [0] is the Q-Former's last hidden state (one vector per query token).
        qformer_out = model.get_qformer_features(pixel_values)[0]
        text_embeddings = model.language_model.get_input_embeddings()(text_inputs['input_ids'])
        print(f'qformer: {qformer_out.shape} text_embeddings: {text_embeddings.shape}')

        qformer_projected = model.language_projection(qformer_out)
        print(f'qformer proj: {qformer_projected.shape}')

        input_embeds = torch.cat([qformer_projected, text_embeddings], dim=1)
        print(f'input embeds: {input_embeds.shape}')
        # Drop the references first so empty_cache can actually release memory.
        del images
        del captions
        del pixel_values
        del text_inputs
        torch.cuda.empty_cache()
        break

inputs: dict_keys(['pixel_values', 'input_ids', 'attention_mask'])
qformer: torch.Size([32, 32, 768]) text_embeddings: torch.Size([32, 19, 2560])
qformer proj: torch.Size([32, 32, 2560])
input embeds: torch.Size([32, 51, 2560])


In [None]:
# Display the layer that projects Q-Former outputs (768-d) into the OPT
# embedding space (2560-d). After 8-bit loading it is a bitsandbytes Linear8bitLt.
model.language_projection

Linear8bitLt(in_features=768, out_features=2560, bias=True)

## Load model and processor

Training with InfoNCE loss

In [None]:
import torch
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

torch.cuda.empty_cache()

model.train()

# Adaptive average pool over the sequence dimension. The loop below uses a
# masked mean instead (so padding tokens are ignored); this stays defined for
# other cells that may reference it.
aavg_pool = torch.nn.AdaptiveAvgPool1d(1)

# InfoNCE (CLIP-style): in the B x B similarity matrix, row i should select
# column i. Both directions therefore use cross-entropy against arange(B).
loss_img = torch.nn.CrossEntropyLoss()
loss_txt = torch.nn.CrossEntropyLoss()

# Maps each pooled text dimension to 32 values, giving the text side the same
# 32 x hidden layout as the projected Q-Former output (32 query tokens).
text_proj = torch.nn.Linear(1, 32).to(args.device)
# text_proj is not part of `model`, so its parameters must be passed to the
# optimizer explicitly; otherwise it is never updated.
optim = torch.optim.AdamW(list(model.parameters()) + list(text_proj.parameters()), lr=args.lr)

# Diagonal targets. They are sliced to the current batch size so a short final
# batch still works.
ground_truth = torch.arange(args.batch_size, dtype=torch.long, device=args.device)

for epoch in range(args.n_epochs):

    # setting up tqdm batch bar
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{args.n_epochs}", unit="batch") as pbar:

        for batch in train_loader:
            optim.zero_grad()

            images, captions = batch['image'], batch['caption']

            # Images arrive as [0, 1] float tensors, so skip the processor's
            # own 1/255 rescale (it still resizes and normalizes).
            pixel_values = processor.image_processor(images, do_rescale=False, return_tensors="pt")['pixel_values'].to(args.device)
            text_inputs = processor.tokenizer(captions, return_tensors="pt", padding=True)
            input_ids = text_inputs['input_ids'].to(args.device)
            attention_mask = text_inputs['attention_mask'].to(args.device)

            # Image side: Q-Former query outputs projected into the LM embedding space.
            qformer_out = model.get_qformer_features(pixel_values).last_hidden_state
            qformer_out = model.language_projection(qformer_out)
            # Text side: LM token embeddings only; the decoder itself is not run.
            text_embeddings = model.language_model.get_input_embeddings()(input_ids)

            image_features = qformer_out.flatten(start_dim=1).to(torch.float32)

            # Mean over real tokens only (padding masked out), then expand to 32 x hidden.
            mask = attention_mask.unsqueeze(-1).to(torch.float32)
            pooled = (text_embeddings.to(torch.float32) * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
            text_features = text_proj(pooled.unsqueeze(-1)).transpose(1, 2).flatten(start_dim=1)

            c_batch_size = image_features.shape[0]
            targets = ground_truth[:c_batch_size]

            # NOTE(review): the logits are raw dot products of unnormalised
            # 32*hidden-dim vectors (the CLIP logit scale was removed).
            # Consider L2-normalising both sides and adding a temperature if
            # the loss is unstable.
            logits_per_image = image_features @ text_features.t()
            logits_per_text = logits_per_image.t()
            total_loss = (loss_img(logits_per_image, targets) + loss_txt(logits_per_text, targets)) / 2

            total_loss.backward()
            optim.step()

            pbar.set_postfix({"Training Loss": total_loss.item()})
            pbar.update()

