In [1]:
import cv2
import json, os
import glob
import numpy as np

In [2]:
def create_captions_map(datadir):
    captions = list()
    image_files = list()
    imagedir = datadir
    for file_type in (".png", ".jpg"):
        image_files.extend(glob.glob(os.path.join(imagedir, "**/*" + file_type), recursive=True))
    for each_image_file in image_files:
        caption_map = {"file_name":os.path.basename(each_image_file)}
        filebasename, _ = os.path.splitext(each_image_file)
        label_file_name = filebasename.replace("IMAGES", "TEXT_LABELS") + ".gui"
        with open(label_file_name) as f:
            caption_map["text"] = f.read()
        captions.append(caption_map)
    return captions

In [3]:
root = "../D2/IMAGES/"
captions = create_captions_map(root)

In [4]:
with open(root + "metadata.jsonl", 'w') as f:
    for item in captions:
        f.write(json.dumps(item) + "\n")

In [5]:
from datasets import load_dataset 

dataset = load_dataset("imagefolder", data_dir=root, split="train")

Resolving data files:   0%|          | 0/301 [00:00<?, ?it/s]

Downloading and preparing dataset imagefolder/default to /Users/ritesh/.cache/huggingface/datasets/imagefolder/default-9765dbf543442d46/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f...


Downloading data files:   0%|          | 0/301 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset imagefolder downloaded and prepared to /Users/ritesh/.cache/huggingface/datasets/imagefolder/default-9765dbf543442d46/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f. Subsequent calls will reuse this data.


In [6]:
dataset

Dataset({
    features: ['image', 'text'],
    num_rows: 300
})

# Create PyTorch Dataset

In [7]:
from torch.utils.data import Dataset

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        encoding = self.processor(images=item["image"], text=item["text"], padding="max_length", return_tensors="pt")

        # remove batch dimension
        encoding = {k:v.squeeze() for k,v in encoding.items()}

        return encoding

In [8]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/git-base")

2023-04-16 17:26:24.271220: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [9]:
train_dataset = ImageCaptioningDataset(dataset, processor)

In [10]:
item = train_dataset[0]
for k,v in item.items():
  print(k,v.shape)

input_ids torch.Size([512])
attention_mask torch.Size([512])
pixel_values torch.Size([3, 224, 224])


# Create PyTorch DataLoader

In [11]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2)

In [12]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

input_ids torch.Size([2, 512])
attention_mask torch.Size([2, 512])
pixel_values torch.Size([2, 3, 224, 224])


In [13]:
from PIL import Image
import numpy as np

MEAN = np.array([123.675, 116.280, 103.530]) / 255
STD = np.array([58.395, 57.120, 57.375]) / 255

unnormalized_image = (batch["pixel_values"][0].numpy() * np.array(STD)[:, None, None]) + np.array(MEAN)[:, None, None]
unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)

# Define model

In [14]:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")

# Dummy forward pass

In [15]:
outputs = model(input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                pixel_values=batch["pixel_values"],
                labels=batch["input_ids"])
outputs.loss

tensor(10.8573, grad_fn=<NllLossBackward0>)

# Train the model

In [16]:
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()

for epoch in range(1):
  print("Epoch:", epoch)
  for idx, batch in enumerate(train_dataloader):
    input_ids = batch.pop("input_ids").to(device)
    pixel_values = batch.pop("pixel_values").to(device)

    outputs = model(input_ids=input_ids,
                    pixel_values=pixel_values,
                    labels=input_ids)
    
    loss = outputs.loss

    print("Loss:", loss.item(), end="\t")

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

Epoch: 0
Loss: 10.601245880126953	Loss: 10.077961921691895	Loss: 8.821277618408203	Loss: 9.026521682739258	Loss: 8.4389009475708	Loss: 7.936983585357666	Loss: 8.077631950378418	Loss: 7.7701897621154785	Loss: 7.318760871887207	Loss: 7.7890472412109375	Loss: 7.737488746643066	Loss: 7.741630554199219	Loss: 7.052277565002441	Loss: 7.106978893280029	Loss: 7.274380207061768	Loss: 7.345982074737549	Loss: 7.2522687911987305	Loss: 6.4005303382873535	Loss: 6.762138366699219	Loss: 6.156841278076172	Loss: 5.777076721191406	Loss: 5.913422107696533	Loss: 5.917724132537842	Loss: 5.883978366851807	Loss: 5.446536540985107	Loss: 6.233072280883789	Loss: 5.144800186157227	Loss: 5.260801792144775	Loss: 5.381277084350586	Loss: 5.494431972503662	Loss: 5.93348503112793	Loss: 4.84732723236084	Loss: 5.259089469909668	Loss: 4.942032337188721	Loss: 4.595408916473389	Loss: 4.702183723449707	Loss: 4.99065637588501	Loss: 4.407902240753174	Loss: 4.091010570526123	Loss: 4.930829048156738	Loss: 4.083763599395752	Loss: 

# Inference

In [17]:
# prepare image for the model
example = dataset[0]
image = example["image"]
width, height = image.size
inputs = processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values

generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)

header { search - bar } sidebar { medium - title, radio, icons } canvas - header { btn - inactive, btn - inactive, btn - inactive, btn - inactive, btn - inactive } row { small
