In [1]:
!pip install -qqq easy-vqa
!pip install -qqq sentence_transformers transformers timm

In [1]:
from easy_vqa import get_train_questions, get_test_questions

train_questions,train_answers,train_image_ids=get_train_questions()
test_questions,test_answers,test_image_ids=get_test_questions()

In [2]:
import pandas as pd

pd.set_option("max_colwidth", None)

def gen_dataframes(questions,answers,image_ids,mode="train"):
    records=[]
    for question,answer, image_id in zip(questions,answers,image_ids):
        image_path=f"/usr/local/lib/python3.7/dist-packages/easy_vqa/data/{mode}/images/{image_id}.png"
        records.append({"question" : question, "answer": answer, "image_path": image_path})
    return pd.DataFrame(records)

df=gen_dataframes(train_questions,train_answers,train_image_ids)

from sklearn.model_selection import train_test_split
df=df.sample(frac=1)
train_df,eval_df = train_test_split(df)
test_df=gen_dataframes(test_questions,test_answers,test_image_ids,mode="test")

In [3]:
print(train_df.shape)
print(eval_df.shape)
print(test_df.shape)

(28931, 3)
(9644, 3)
(9673, 3)


In [4]:
from easy_vqa import get_answers

answers=get_answers()
print("Total Labels: ",len(answers))
label2idx={answer:i for i, answer in enumerate(answers)}

Total Labels:  13


In [5]:
label2idx

{'circle': 0,
 'green': 1,
 'red': 2,
 'gray': 3,
 'yes': 4,
 'teal': 5,
 'black': 6,
 'rectangle': 7,
 'yellow': 8,
 'triangle': 9,
 'brown': 10,
 'blue': 11,
 'no': 12}

In [12]:
train_df["label"]=train_df["answer"].apply(lambda x: label2idx.get(x))
eval_df["label"]=eval_df["answer"].apply(lambda x: label2idx.get(x))
test_df["label"]=test_df["answer"].apply(lambda x: label2idx.get(x))

In [6]:
train_df.sample(5)

Unnamed: 0,question,answer,image_path
25849,does the image contain a triangle?,yes,/usr/local/lib/python3.7/dist-packages/easy_vqa/data/train/images/2689.png
6937,is a triangle present?,no,/usr/local/lib/python3.7/dist-packages/easy_vqa/data/train/images/711.png
18399,is there a red shape in the image?,no,/usr/local/lib/python3.7/dist-packages/easy_vqa/data/train/images/1915.png
33639,is there not a blue shape?,yes,/usr/local/lib/python3.7/dist-packages/easy_vqa/data/train/images/3490.png
38196,is a black shape present?,no,/usr/local/lib/python3.7/dist-packages/easy_vqa/data/train/images/3959.png


In [7]:
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModel
import torchvision.transforms as T
import torch
import timm

"""Fusing Transformers"""
device = "cuda:0" if torch.cuda.is_available() else "cpu" 
tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased")
text_encoder=AutoModel.from_pretrained("bert-base-uncased")
for p in text_encoder.parameters():
    p.requires_grad=False

image_processor=AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
image_encoder=AutoModel.from_pretrained("google/vit-base-patch16-224-in21k")

"""Fusing CNNs and Transformers"""
# device= "cuda:0" if torch.cuda.is_available() else "cpu"
# image_encoder=timm.create_model("resnet50d",pretrained=True, num_classes=0)
# resize_transform=T.resize((224,224))

for p in image_encoder.parameters():
    p.requires_grad=False

image_encoder.to(device)
text_encoder.to(device)

print()

  from .autonotebook import tqdm as notebook_tqdm
Downloading model.safetensors: 100%|██████████| 440M/440M [00:29<00:00, 14.9MB/s] 
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Downloading (…)rocessor_config.json: 100%|██████████| 160/160 [00:00<?, ?B/s] 
Downloading (…)lve/main/config.json: 100%|██████████| 502/502 [00:00<?, ?B/s] 
Downloading pytorch_model.bin: 100%|██████████| 346M/346M [00:23<00:00, 14.8MB/s] 





In [17]:
##Stitch torch dataset with feature backbones

from PIL import Image
from tqdm import tqdm
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

class EasyQADataset(Dataset):

    def __init__(self, df,
                 image_encoder,
                 text_encoder,
                 image_processor,
                 tokenizer,
                ):
        self.df=df
        self.image_encoder=image_encoder
        self.text_encoder=text_encoder
        self.image_processor=image_processor
        self.tokenizer=tokenizer

    def __len__(self):
        return len(self.df)

    def _getitem__(self,idx):
        image_file=self.df["image_path"][idx]
        question=self.df["question"][idx]
        image=Image.open(image_file).convert("RGB")
        label=self.df["label"][idx]

        """When CNNs are used for V backnone"""

        image=resize_transform(image)
        image_inputs=T.ToTensor()(image).unsqueeze_(0)

        """When Transformers is used for V backbone"""
        image_inputs=self.image_processor(image,return_tensors='pt')

(1, 2, 3, 1, 2, 3)