In [5]:
# !pip install transformers torch pillow
# !pip install opencv-python-headless
# !pip install matplotlib

In [28]:
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from torch.utils.data import Dataset, DataLoader
from typing import Union, List, Tuple
from PIL import Image
import torch
import urllib
import os
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import json
from datasets import Dataset
from tqdm import tqdm
import gc

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

In [29]:
root = '/home/sagemaker-user/img2text'
config = {
    "sketches": f"{root}/sketches",
    "ground_truth": f"{root}/ground_truth.json",
    "epochs": 50
}

# Load Sketches

In [30]:
metadata = json.load(open(f'{config["sketches"]}/metadata.json'))

sketches = []

for key in metadata:
    for i in tqdm(range(metadata[key]),desc= f'Loading {key}'):
        file_name = key+ f"_{i+1}"
        sketch = Image.open(os.path.join( config['sketches'], file_name+".png"))
        sketches.append((file_name, sketch))

Loading bike: 100%|██████████| 9/9 [00:00<00:00, 5534.19it/s]
Loading car: 100%|██████████| 10/10 [00:00<00:00, 8184.01it/s]
Loading cat: 100%|██████████| 7/7 [00:00<00:00, 7580.72it/s]
Loading cycle: 100%|██████████| 6/6 [00:00<00:00, 4531.12it/s]
Loading plane: 100%|██████████| 10/10 [00:00<00:00, 6151.81it/s]
Loading signal: 100%|██████████| 6/6 [00:00<00:00, 3076.51it/s]


# Load Model

In [31]:
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Baseline

In [32]:
result = {}

for category, sketch in tqdm(sketches, desc=f"Generating Caption..."):
    inputs = processor(sketch, return_tensors="pt")
    outputs = model.generate(**inputs)
    caption = processor.decode(outputs[0], skip_special_tokens=True)
    result[category] = caption

print(result)

with open('img2text_baseline.json', 'w+') as f:
    json.dump(result, f)

Generating Caption...: 100%|██████████| 48/48 [01:02<00:00,  1.31s/it]

{'bike_1': 'a motorcycle with a rider on it', 'bike_2': 'a motorcycle is shown in the shape of a motorcycle', 'bike_3': 'a drawing of a bicycle', 'bike_4': 'a drawing of a man with a gun', 'bike_5': 'a motorcycle with a side view', 'bike_6': 'a drawing of a motorcycle', 'bike_7': 'a drawing of a person riding a bike', 'bike_8': 'a motorcycle with a helmet and helmet on it', 'bike_9': 'a motorcycle is shown in the shape of a motorcycle', 'car_1': 'a drawing of a car', 'car_2': 'a map of the state of new york', 'car_3': 'a car with a white background', 'car_4': 'a car with the number plate removed', 'car_5': 'a car is shown in the shape of a car', 'car_6': 'a drawing of a truck', 'car_7': 'a car is shown in the shape of a car', 'car_8': 'a drawing of a truck', 'car_9': 'a car with wheels and wheels', 'car_10': 'a drawing of a car', 'cat_1': 'a black and white drawing of a cat', 'cat_2': 'a black and white drawing of a cat', 'cat_3': "a drawing of a cat ' s face", 'cat_4': 'a black and wh




# Fine Tuning

## Loading Ground Truth

In [36]:
ground_truth_captions = json.load(open(config["ground_truth"]))
print(ground_truth_captions)

{'bike_1': 'Drawing of the side view of a motorcycle', 'bike_2': 'Isometric view drawing of a police motorcycle', 'bike_3': 'Isometric view drawing of a police motorcycle', 'bike_4': 'Isometric view drawing of a bike', 'bike_5': 'Isometric view drawing of a police motorcycle', 'bike_6': 'Isometric view drawing of a bullet motorcycle', 'bike_7': 'Front view drawing of a motorcycle', 'bike_8': 'Isometric view drawing of a motorcycle', 'bike_9': 'Drawing of a motorcycle', 'car_1': 'Sketch of a car', 'car_2': 'Isometric view drawing of a car', 'car_3': 'Drawing of a sedan car', 'car_4': 'Sketch of a race car with spoilers', 'car_5': 'Front view drawing of a car', 'car_6': 'Side view drawing of a car', 'car_7': 'Drawing of a car', 'car_8': 'Front view of a limousine car', 'car_9': 'Side view sketch of a car', 'car_10': 'Isometric view drawing of a classic car', 'cat_1': 'Sketch of a cat', 'cat_2': 'Outline Drawing of a cat', 'cat_3': "Drawing of a cat's face", 'cat_4': 'Side view sketch of 

## Prepare Dataset

In [37]:
dataset = []

for name, image in sketches:
    dataset.append({
        "image": image,
        "text": ground_truth_captions[name]
    })


dataset = Dataset.from_list(dataset)

In [38]:
class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], text=item["text"], padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        return encoding

In [39]:
train_dataset = ImageCaptioningDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2)

## Training

In [40]:
torch.cuda.empty_cache()
gc.collect()

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

scaler = torch.amp.GradScaler("cuda")

# model.to(DEVICE)
model.train()

BlipForConditionalGeneration(
  (vision_model): BlipVisionModel(
    (embeddings): BlipVisionEmbeddings(
      (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (encoder): BlipEncoder(
      (layers): ModuleList(
        (0-11): 12 x BlipEncoderLayer(
          (self_attn): BlipAttention(
            (dropout): Dropout(p=0.0, inplace=False)
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (projection): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): BlipMLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (post_layernorm): LayerNorm((768,), eps=1e-0

In [None]:
for epoch in tqdm(range(config['epochs'])):
    optimizer.zero_grad()
    print("Epoch:", epoch)
    for idx, batch in enumerate(train_dataloader):
        # input_ids = batch.pop("input_ids").to(DEVICE)
        # pixel_values = batch.pop("pixel_values").to(DEVICE)
        input_ids = batch.pop("input_ids")
        pixel_values = batch.pop("pixel_values")
        print(pixel_values.shape)

        outputs = model(input_ids=input_ids,
                        pixel_values=pixel_values,
                        labels=input_ids)
        loss = outputs.loss

        print("Loss:", loss.item())

        loss.backward()
    
        optimizer.step()
        optimizer.zero_grad()

In [43]:
import torch
from torch.nn.functional import cross_entropy
from tqdm import tqdm

def compute_perplexity(model, inputs):
    with torch.no_grad():
        outputs = model(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss
        perplexity = torch.exp(loss)
    return perplexity.item()

def lexical_diversity(caption):
    tokens = caption.split()
    if len(tokens) == 0:
        return 0
    return len(set(tokens)) / len(tokens)

def n_gram_diversity(caption, n=2):
    tokens = caption.split()
    if len(tokens) < n:
        return 0
    n_grams = list(zip(*[tokens[i:] for i in range(n)]))
    return len(set(n_grams)) / len(n_grams)

for epoch in tqdm(range(config['epochs'])):
    optimizer.zero_grad()
    print(f"Epoch {epoch + 1}/{config['epochs']}")
    total_loss = 0
    all_captions = []
    
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch.pop("input_ids")
        pixel_values = batch.pop("pixel_values")
        
        outputs = model(input_ids=input_ids, pixel_values=pixel_values, labels=input_ids)
        loss = outputs.loss
        
        total_loss += loss.item()
        
        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        # Generate captions for diversity metrics
        with torch.no_grad():
            generated_captions = model.generate(pixel_values=pixel_values, max_length=20)
            decoded_captions = processor.batch_decode(generated_captions, skip_special_tokens=True)
            all_captions.extend(decoded_captions)
        
        if idx % 10 == 0:
            print(f"Batch {idx}, Loss: {loss.item():.4f}")
    
    # Perplexity computation
    inputs = {"input_ids": input_ids, "pixel_values": pixel_values}
    perplexity = compute_perplexity(model, inputs)
    print(f"Epoch {epoch + 1} Perplexity: {perplexity:.4f}")
    
    # Caption Diversity Metrics
    lexical_div = sum(lexical_diversity(caption) for caption in all_captions) / len(all_captions)
    bigram_div = sum(n_gram_diversity(caption, n=2) for caption in all_captions) / len(all_captions)
    print(f"Epoch {epoch + 1} Lexical Diversity: {lexical_div:.4f}, Bigram Diversity: {bigram_div:.4f}")



The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1/50


We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
  0%|          | 0/50 [00:12<?, ?it/s]


NameError: name 'tokenizer' is not defined

## Inference

In [194]:
result = {}

for category, sketch in tqdm(sketches, desc=f"Generating Caption..."):
    inputs = processor(sketch, return_tensors="pt")
    outputs = model.generate(**inputs)
    caption = processor.decode(outputs[0], skip_special_tokens=True)
    result[category] = caption

print(result)

with open('img2text_finetuned.json', 'w+') as f:
    json.dump(result, f)

Generating Caption...: 100%|██████████| 47/47 [01:36<00:00,  2.05s/it]

{'bike_1': 'isometric view drawing of a police motorcycle', 'bike_2': 'isometric view drawing of a police motorcycle', 'bike_3': 'isometric view drawing of a motorcycle', 'bike_4': 'isometric view drawing of a motorcycle', 'bike_5': 'isometric view drawing of a police motorcycle', 'bike_6': 'isometric view drawing of a motorcycle', 'bike_7': 'isometric view drawing of a police motorcycle', 'bike_8': 'isometric view drawing of a motorcycle', 'bike_9': 'isometric view drawing of a motorcycle', 'car_1': 'isometric view drawing of a car', 'car_2': 'isometric view drawing of a car', 'car_3': 'isometric view drawing of a car', 'car_4': 'sideview drawing of a car', 'car_5': 'sideview drawing of a car', 'car_6': 'sideview drawing of a car', 'car_7': 'isometric view drawing of a car', 'car_8': 'sideview drawing of a car', 'car_9': 'sideview drawing of a car', 'car_10': 'isometric view drawing of a classic car', 'cat_1': 'sketch of a cat', 'cat_2': 'drawing of a cat', 'cat_3': 'drawing of a cat'




## Download the model parameters for deployment

In [196]:
model.save_pretrained('./fine_tuned_blip')
processor.save_pretrained('./fine_tuned_blip_processor')


[]