In [None]:
from PIL import Image
import torch
import os
import pandas as pd
from tqdm import tqdm,trange 
from torchvision import transforms
from transformers import OFATokenizer, OFAModel
# from generate import sequence_generator
from transformers.models.ofa.generate import sequence_generator
# from OFA.fairseq.fairseq import sequence_generator

In [None]:
def _build_patch_transform(resolution=480, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Build the image preprocessing pipeline fed to OFA.

    Converts to RGB, bicubic-resizes to a square ``resolution`` x ``resolution``,
    converts to a tensor and normalizes with the given per-channel mean/std.
    """
    return transforms.Compose([
        lambda image: image.convert("RGB"),
        # InterpolationMode replaces the deprecated PIL integer constants in torchvision.
        transforms.Resize((resolution, resolution),
                          interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(mean), std=list(std)),
    ])


def _generate_caption(model, tokenizer, generator, prompt_ids, patch_img):
    """Run beam search for one preprocessed image and return the decoded caption string."""
    net_input = {
        "input_ids": prompt_ids,
        "patch_images": patch_img,
        "patch_masks": torch.tensor([True]),
    }
    # Pure inference: make sure no autograd graph is recorded during generation.
    with torch.no_grad():
        gen_output = generator.generate([model], {"net_input": net_input})
    # gen_output holds one hypothesis list per batch item; index 0 is taken as the
    # chosen hypothesis (presumably the best-scoring one -- same choice as before).
    best_tokens = [hypotheses[0]["tokens"] for hypotheses in gen_output]
    return tokenizer.batch_decode(best_tokens, skip_special_tokens=True)[0].strip()


def include_captions(path_train_csv,
                     data_dir='/home/scur1045/FACT-project/HVV_EXPGEN_DATASET/',
                     ckpt_dir='/home/scur1045/FACT-project/OFA/OFA-HF-large-model'):
    """Caption every image listed in a CSV with OFA and save the CSV with a 'caption' column.

    Despite the parameter name, this works for any split (it is called for train and val).

    Args:
        path_train_csv: Path to the input CSV. It must contain an ``image`` column with
            file names relative to ``<data_dir>/Train_Val_Images``.
        data_dir: Dataset root. Images are read from its ``Train_Val_Images`` sub-folder,
            and the output CSV is written here under the input file's base name.
        ckpt_dir: Directory of the OFA Hugging Face checkpoint (tokenizer + model).

    Side effects:
        Writes ``<data_dir>/<basename(path_train_csv)>``. Images that are missing on disk
        are reported and get an empty caption. Each caption is printed as it is generated.
    """
    df = pd.read_csv(path_train_csv)
    image_dir = os.path.join(data_dir, 'Train_Val_Images')

    patch_resize_transform = _build_patch_transform()

    tokenizer = OFATokenizer.from_pretrained(ckpt_dir)
    # The text prompt is identical for every image, so tokenize it once.
    prompt_ids = tokenizer([" what does the image describe?"], return_tensors="pt").input_ids

    model = OFAModel.from_pretrained(ckpt_dir, use_cache=True)
    model.eval()  # explicit; disables dropout for deterministic inference
    generator = sequence_generator.SequenceGenerator(
        tokenizer=tokenizer,
        beam_size=5,
        max_len_b=16,
        min_len=0,
        no_repeat_ngram_size=3,
    )

    # Collect captions in row order and assign the column once, instead of
    # mixing positional reads with label-based writes on the frame.
    captions = []
    for img_name in tqdm(df['image'], total=len(df), desc="Caption Images"):
        path_to_image = os.path.join(image_dir, img_name)
        try:
            # Context manager closes the file handle; the tensor is built while it is open.
            with Image.open(path_to_image) as img:
                patch_img = patch_resize_transform(img).unsqueeze(0)
        except FileNotFoundError:
            tqdm.write(f"Image {img_name} not found.")
            captions.append('')
            continue

        caption = _generate_caption(model, tokenizer, generator, prompt_ids, patch_img)
        # tqdm.write keeps the progress bar intact (plain print would break it).
        tqdm.write(caption)
        captions.append(caption)

    df['caption'] = captions

    out_path = os.path.join(data_dir, os.path.basename(path_train_csv))
    df.to_csv(out_path, index=False)
    print("Captions added!")


In [None]:
# Source CSVs live in the backup folder; include_captions writes the captioned copies elsewhere.
BACKUP_CSV_DIR = '/home/scur1045/FACT-project/HVV_EXPGEN_DATASET/backup_csv/'
path_train_csv = BACKUP_CSV_DIR + 'hvvexp_train.csv'
path_val_csv = BACKUP_CSV_DIR + 'hvvexp_val.csv'

# Caption the train split first, then the val split.
for split_csv in (path_train_csv, path_val_csv):
    include_captions(split_csv)