In [32]:
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import io
from PIL import Image
import pyarrow as pa
from vilt.config import ex
from vilt.modules import ViLTransformerSS

from vilt.modules.objectives import cost_matrix_cosine, ipot
from vilt.transforms import pixelbert_transform
from vilt.datamodules.datamodule_base import get_pretrained_tokenizer

In [2]:
# Load the COSMOS test split from its Arrow IPC file into a pandas DataFrame.
# The path is relative to the notebook's working directory. The file is
# memory-mapped read-only, all record batches are read into one table, and the
# table is then converted to pandas.
COSMOS_TEST_ARROW = "dataset_50/cosmos_test.arrow"
df = (
    pa.ipc.RecordBatchFileReader(pa.memory_map(COSMOS_TEST_ARROW, "r"))
    .read_all()
    .to_pandas()
)

In [3]:
# Base configuration dictionary handed to ViLTransformerSS.
# One key per line, in the same order as before, so individual settings are
# easy to find and diff.
_config = {
    # --- experiment ---
    "exp_name": "vilt",
    "seed": 0,
    "datasets": ["coco", "vg", "sbu", "gcc"],
    # Replaced by the next cell before the model is built.
    "loss_names": {
        "itm": 1,
        "mlm": 1,
        "mpp": 0,
        "vqa": 0,
        "nlvr2": 0,
        "irtr": 0,
        "cosmos": 0,
    },
    "batch_size": 4096,
    # --- image ---
    "train_transform_keys": ["pixelbert"],
    "val_transform_keys": ["pixelbert"],
    "image_size": 384,
    "max_image_len": -1,
    "patch_size": 32,
    "draw_false_image": 1,
    "image_only": False,
    # --- text / downstream heads ---
    "vqav2_label_size": 3129,
    "max_text_len": 40,
    "tokenizer": "bert-base-uncased",
    "vocab_size": 30522,
    "whole_word_masking": False,
    "mlm_prob": 0.15,
    "draw_false_text": 0,
    # --- transformer ---
    "vit": "vit_base_patch32_384",
    "hidden_size": 768,
    "num_heads": 12,
    "num_layers": 12,
    "mlp_ratio": 4,
    "drop_rate": 0.1,
    # --- optimiser / schedule ---
    "optim_type": "adamw",
    "learning_rate": 0.0001,
    "weight_decay": 0.01,
    "decay_power": 1,
    "max_epoch": 100,
    "max_steps": 25000,
    "warmup_steps": 2500,
    "end_lr": 0,
    "lr_mult": 1,
    # --- run control ---
    "get_recall_metric": False,
    "resume_from": None,
    "fast_dev_run": False,
    "val_check_interval": 1.0,
    "test_only": False,
    # --- paths / hardware ---
    "data_root": "",
    "log_dir": "result",
    "per_gpu_batchsize": 0,
    "num_gpus": 1,
    "num_nodes": 1,
    "load_path": "result/finetune_nlvr2_randaug_seed0_from_vilt_200k_mlm_itm/version_4/checkpoints/epoch=3-step=2123.ckpt",
    "num_workers": 8,
    "precision": 16,
}

In [4]:
# Objective switches for this run: every objective is off (0) and only
# "cosmos" is on (1).
loss_names = dict.fromkeys(
    ("itm", "mlm", "mpp", "vqa", "imgcls", "nlvr2", "irtr", "arc"),
    0,
)
loss_names["cosmos"] = 1

# Tokenizer named by the config's "tokenizer" entry.
tokenizer = get_pretrained_tokenizer(_config["tokenizer"])

# Swap the loss switches carried by the base config for the ones built above.
_config["loss_names"] = loss_names

In [5]:
# Build the ViLT model from the notebook config (which carries `load_path`
# and the loss switches set in the previous cell).
model = ViLTransformerSS(_config)
# NOTE(review): presumably the Lightning-style stage hook for the "test"
# stage — confirm against ViLTransformerSS.
model.setup("test")
# Put the module in evaluation mode for inference. Being the cell's last
# expression, this call's return value is also what the notebook displays
# (the module repr recorded below).
model.eval()

ViLTransformerSS(
  (text_embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(40, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (token_type_embeddings): Embedding(3, 768)
  (transformer): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
    )
    (pos_drop): Dropout(p=0.1, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.1, inplace=False)
        )
        (drop_path): Identity()
        

In [6]:
# Pick the inference device.
# Use the first GPU only when the config asks for GPUs *and* CUDA is actually
# usable in this process; otherwise fall back to the CPU, so `model.to` does
# not raise on a machine without a working CUDA runtime.
use_cuda = _config["num_gpus"] > 0 and torch.cuda.is_available()
device = "cuda:0" if use_cuda else "cpu"
# Last expression of the cell: moves the model and displays its repr.
model.to(device)

ViLTransformerSS(
  (text_embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(40, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (token_type_embeddings): Embedding(3, 768)
  (transformer): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
    )
    (pos_drop): Dropout(p=0.1, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.1, inplace=False)
        )
        (drop_path): Identity()
        

In [26]:
def infer(image, text, text2):
    """Classify one (image, caption, caption) triple with the loaded model.

    Relies on the notebook globals ``model``, ``tokenizer`` and ``device``.

    Args:
        image: Image accepted by the ``pixelbert_transform(size=384)``
            pipeline (the evaluation loop passes an RGB PIL image).
        text: First caption string.
        text2: Second caption string.

    Returns:
        The index of the largest logit produced by ``model.nlvr2_classifier``
        as a plain Python int. The evaluation loop compares it with ``True``,
        so index 1 is treated as a positive prediction.
    """
    # Transform the image and add a leading batch dimension of size 1.
    pixels = pixelbert_transform(size=384)(image).unsqueeze(0).to(device)
    batch = {"text": [text], "image": [pixels], "text2": [text2]}

    with torch.no_grad():
        # Tokenise both captions the same way; each one contributes
        # "<key>_ids", "<key>_labels" and "<key>_masks" entries to the batch.
        for key in ("text", "text2"):
            tokens = tokenizer(
                batch[key],
                padding="max_length",
                truncation=True,
                max_length=40,
                return_special_tokens_mask=True,
            )
            batch[f"{key}_ids"] = torch.tensor(tokens["input_ids"]).to(device)
            batch[f"{key}_labels"] = torch.tensor(tokens["input_ids"]).to(device)
            batch[f"{key}_masks"] = torch.tensor(tokens["attention_mask"]).to(device)

        # Two forward passes (text_token_type_idx 1, then 2); their CLS
        # features are concatenated along the last dimension for the head.
        cls_per_pass = [
            model.infer(batch, text_token_type_idx=type_idx)["cls_feats"]
            for type_idx in (1, 2)
        ]
        cosmos_logits = model.nlvr2_classifier(torch.cat(cls_per_pass, dim=-1))

    return cosmos_logits.argmax().item()

In [11]:
def get_raw_image(img_byte):
    """Decode an encoded image held in memory into an RGB PIL image.

    Args:
        img_byte: Bytes of an encoded image file.

    Returns:
        The image opened with PIL and converted to mode "RGB".
    """
    # A freshly created BytesIO is already positioned at offset 0.
    return Image.open(io.BytesIO(img_byte)).convert("RGB")

In [33]:
# Run the model over every test row and collect (prediction, label) pairs.
# Rows are gathered in a plain list and the DataFrame is built once at the
# end; appending with `result.loc[len(result)] = ...` inside the loop grows
# the frame one row at a time, which is slow and leaves object-dtype columns.
rows = []
for i in tqdm(range(len(df))):
    r = df.iloc[i]
    # caption_1 / caption_2 / label each hold a one-element sequence per row.
    pred = infer(get_raw_image(r['image']), r['caption_1'][0], r['caption_2'][0])
    pred = pred == 1  # class index 1 counts as a positive (True) prediction
    ans = r['label'][0]
    rows.append((pred, ans))
result = pd.DataFrame(rows, columns=['predict', 'label'])

100%|██████████| 1700/1700 [01:04<00:00, 26.48it/s]


In [34]:
# Cross-tabulate predictions (rows) against ground-truth labels (columns).
confusion_matrix = pd.crosstab(
    index=result['predict'],
    columns=result['label'],
    rownames=['Predicted'],
    colnames=['Actual'],
)
print(confusion_matrix)

Actual     False  True 
Predicted              
False        691    443
True         159    407


In [62]:
# Widen text columns so long captions are not truncated when displayed.
pd.set_option("display.max_colwidth", 300)

In [64]:
# False positives: rows the model predicted True although the label differs.
# `result` holds one row per `df` row in the same order, so the mask is
# applied positionally (as a NumPy array) instead of relying on the two
# frames happening to share identical index labels.
is_mismatch = result['predict'] != result['label']
is_predicted_true = result['predict'].eq(True)
wrong = df[(is_mismatch & is_predicted_true).to_numpy()]
wrong[['caption_1','caption_2']]

Unnamed: 0,caption_1,caption_2
9,[A man brushes the mouth of a sarcophagus.],"[CARDINAL statues of ancient deities and funerary masks were also discovered at the site, many of which still with their original colours and designs preserved]"
21,"[A tribute to Captain Sir PERSON lights up ORG, GPE. The DATE, who raised almost ORGm for ORG charities by walking laps of his garden, died with coronavirus.]",[A woman plays the violin in ORG during a tribute to Captain Sir PERSON on 2 DATE]
48,"[The PERSON and ORG, ORG, by PERSON, from the GPE]",[A sacred cow in a street in GPE]
52,"[Lex Scott Davis and PERSON in ""WORK_OF_ART.""]","[PERSON (right) plays a suave drug dealer named Priest and PERSON is GPE in GPE, Director X’s blinged-out redo of the DATE blaxploitation classic.]"
64,"[NORP paramilitary troops in GPE in GPE DATE. PERCENT of the army's equipment is so old that it is officially considered ""vintage.""]","[LOC braces for worst as shelling, gunbattles escalate]"
...,...,...
1654,"[Mr. PERSON won the fiercely contested race by CARDINAL votes, out of CARDINAL.]","[PERSON, the NORP candidate for ORG, is accompanied by his wife, PERSON, at an LOC gathering of his supporters in GPE, GPE, on DATE.]"
1667,[Heavy snowfall had be be cleared on FAC during one of the DATE in the early 1900s],"[CARDINAL - Duke Street, Barrow-in-Furness DATE-CARDINAL ORG]"
1678,[A photographer's self portrait with her mother is one of many works featured at ORG ORDINAL DATE NORP Foto Festival.],"[CARDINAL women embrace, wearing face masks]"
1687,[PERSON taking a photo],[A child holds a camera next to a destroyed bus]
