In [1]:
import cv2
import json, os
import glob
import numpy as np

In [2]:
def create_captions_map(datadir):
    captions = list()
    image_files = list()
    imagedir = datadir
    for file_type in (".png", ".jpg"):
        image_files.extend(glob.glob(os.path.join(imagedir, "**/*" + file_type), recursive=True))
    for each_image_file in image_files:
        caption_map = {"file_name":os.path.basename(each_image_file)}
        filebasename, _ = os.path.splitext(each_image_file)
        label_file_name = filebasename.replace("IMAGES", "TEXT_LABELS") + ".gui"
        with open(label_file_name) as f:
            caption_map["text"] = f.read()
        captions.append(caption_map)
    return captions

In [3]:
root = "../D3/IMAGES/"
captions = create_captions_map(root)

In [4]:
with open(root + "metadata.jsonl", 'w') as f:
    for item in captions:
        f.write(json.dumps(item) + "\n")

In [5]:
from datasets import load_dataset 

dataset = load_dataset("imagefolder", data_dir=root, split="train")

Resolving data files:   0%|          | 0/301 [00:00<?, ?it/s]

Downloading and preparing dataset imagefolder/default to /Users/ritesh/.cache/huggingface/datasets/imagefolder/default-e52771f6e7bda6c7/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f...


Downloading data files:   0%|          | 0/301 [00:00<?, ?it/s]

Downloading data files: 0it [00:00, ?it/s]

Extracting data files: 0it [00:00, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset imagefolder downloaded and prepared to /Users/ritesh/.cache/huggingface/datasets/imagefolder/default-e52771f6e7bda6c7/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f. Subsequent calls will reuse this data.


In [6]:
dataset

Dataset({
    features: ['image', 'text'],
    num_rows: 300
})

# Create PyTorch Dataset

In [7]:
from torch.utils.data import Dataset

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        encoding = self.processor(images=item["image"], text=item["text"], padding="max_length", return_tensors="pt")

        # remove batch dimension
        encoding = {k:v.squeeze() for k,v in encoding.items()}

        return encoding

In [8]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/git-base")

2023-04-16 17:41:28.087659: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [9]:
train_dataset = ImageCaptioningDataset(dataset, processor)

In [10]:
item = train_dataset[0]
for k,v in item.items():
  print(k,v.shape)

input_ids torch.Size([512])
attention_mask torch.Size([512])
pixel_values torch.Size([3, 224, 224])


# Create PyTorch DataLoader

In [11]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2)

In [12]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

input_ids torch.Size([2, 512])
attention_mask torch.Size([2, 512])
pixel_values torch.Size([2, 3, 224, 224])


In [13]:
from PIL import Image
import numpy as np

MEAN = np.array([123.675, 116.280, 103.530]) / 255
STD = np.array([58.395, 57.120, 57.375]) / 255

unnormalized_image = (batch["pixel_values"][0].numpy() * np.array(STD)[:, None, None]) + np.array(MEAN)[:, None, None]
unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)

# Define model

In [14]:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")

# Dummy forward pass

In [15]:
outputs = model(input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                pixel_values=batch["pixel_values"],
                labels=batch["input_ids"])
outputs.loss

tensor(11.0811, grad_fn=<NllLossBackward0>)

# Train the model

In [16]:
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()

for epoch in range(1):
  print("Epoch:", epoch)
  for idx, batch in enumerate(train_dataloader):
    input_ids = batch.pop("input_ids").to(device)
    pixel_values = batch.pop("pixel_values").to(device)

    outputs = model(input_ids=input_ids,
                    pixel_values=pixel_values,
                    labels=input_ids)
    
    loss = outputs.loss

    print("Loss:", loss.item(), end="\t")

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

Epoch: 0
Loss: 10.710468292236328	Loss: 10.126626968383789	Loss: 9.439986228942871	Loss: 9.080878257751465	Loss: 8.757352828979492	Loss: 8.46851921081543	Loss: 8.288243293762207	Loss: 7.934171676635742	Loss: 7.940999507904053	Loss: 7.655185222625732	Loss: 7.479077339172363	Loss: 7.404178619384766	Loss: 7.365584850311279	Loss: 7.251214504241943	Loss: 6.9966230392456055	Loss: 6.691814422607422	Loss: 6.605830192565918	Loss: 6.594297885894775	Loss: 6.49360466003418	Loss: 6.385962009429932	Loss: 6.368706703186035	Loss: 6.131566047668457	Loss: 6.084646701812744	Loss: 5.981993675231934	Loss: 5.963475227355957	Loss: 5.968304634094238	Loss: 5.723918437957764	Loss: 5.711649417877197	Loss: 5.5067925453186035	Loss: 5.519079208374023	Loss: 5.54173469543457	Loss: 5.281387805938721	Loss: 5.2058305740356445	Loss: 5.124850273132324	Loss: 5.038467884063721	Loss: 4.904438018798828	Loss: 4.8971476554870605	Loss: 4.742694854736328	Loss: 4.633916854858398	Loss: 4.661097526550293	Loss: 4.459529876708984	Loss

# Inference

In [17]:
# prepare image for the model
example = dataset[0]
image = example["image"]
width, height = image.size
inputs = processor(images=image, return_tensors="pt").to(device)
pixel_values = inputs.pixel_values

generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)

components { home, sidebar } orientation { sidebar - right - home - left } home { cards { img - card, img - card, img - btn - card, img - card } table - col {
