In [9]:
import json
import torch
from PIL import Image
from tqdm.autonotebook import tqdm
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

In [10]:
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0): ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, 

In [23]:
max_length = 16
num_beams = 5
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def predict_step(image_path):
    image = Image.open(image_path)
        
    if image.mode != "RGB":
        image = image.convert(mode="RGB")

    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    
    output_ids = model.generate(pixel_values, **gen_kwargs)

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]

    return preds

In [24]:
desc = predict_step('data/private_test/private-test-images/00000000070.jpg')

desc

['a man standing in front of a large statue of an elephant']

In [34]:
with open('./data/train/evjvqa_train_lang_qtype-detailed.json', 'r', encoding='utf-8') as f:
    train_data = json.load(f)

train_images = train_data['images']

with open('./data/test/evjvqa_public_test-lang-qtype-answer.json', 'r', encoding='utf-8') as f:
    test_data = json.load(f)

test_images = test_data['images']


with open('data/private_test/prepared_evjvqa_private_test.json', 'r', encoding='utf-8') as f:
    private_test_data = json.load(f)

private_test_images = private_test_data['images']

In [29]:
for i in tqdm(range(len(train_images))):
    image_file = train_images[i]['filename']
    file_path = f"data/train/train-images/{image_file}"
    
    desc = predict_step(file_path)
    
    train_images[i]['desc'] = desc

  0%|          | 0/3763 [00:00<?, ?it/s]

In [31]:
with open('./data/train/evjvqa_train_lang_qtype-desc-detailed.json', 'w', encoding='utf-8') as f:
    json.dump(train_data, f, indent=4, ensure_ascii=False)

In [32]:
for i in tqdm(range(len(test_images))):
    image_file = test_images[i]['filename']
    file_path = f"data/test/public-test-images/{image_file}"
    
    desc = predict_step(file_path)
    
    test_images[i]['desc'] = desc
    
with open('./data/test/evjvqa_public_test-lang-qtype-desc-answer.json', 'w', encoding='utf-8') as f:
    json.dump(test_data, f, indent=4, ensure_ascii=False)

  0%|          | 0/558 [00:00<?, ?it/s]

In [36]:
for i in tqdm(range(len(private_test_images))):
    image_file = private_test_images[i]['filename']
    file_path = f"data/private_test/private-test-images/{image_file}"
    
    desc = predict_step(file_path)
    
    private_test_images[i]['desc'] = desc
    
with open('data/private_test/evjvqa_private_test-desc.json', 'w', encoding='utf-8') as f:
    json.dump(private_test_data, f, indent=4, ensure_ascii=False)

  0%|          | 0/588 [00:00<?, ?it/s]