In [1]:

import torchvision.transforms as transforms
from transformers import T5Tokenizer
import json
import torch
import torch
import utils
from tqdm import tqdm
img_size = 224
vqax_data_dir = "/media/storage/coco/VQA-X/annotated/vqaX_train.json"
coco_data_dir = "/media/storage/coco/"
input_max_seq_length = 500
output_max_seq_length = 30
train_batch_size = 30
eval_batch_size =  30
img_transform = transforms.Compose([transforms.Resize((img_size,img_size)), transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
task_A = True
is_train = True
from PIL import Image

tokenizer = T5Tokenizer.from_pretrained('t5-large')
num_new_tokens = tokenizer.add_special_tokens({'pad_token': '<pad>','additional_special_tokens': ['<question>', '<situation>', '<answer>']})

data = json.load(open(vqax_data_dir, 'r'))


  from .autonotebook import tqdm as notebook_tqdm
For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-large automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [2]:
ids_list = list(data.keys())
index_tracker = {k: len(v['explanation']) - 1 for k,v in data.items()}
for k,v in data.items():
    if len(v['explanation']) > 1:   # some questions have more than one explanation
# duplicate them for loading. -1 because one explanation is already in ids_list
        ids_list += [str(k)] * (len(v['explanation']) - 1)
        
datasets = []
for i in tqdm(range(30), desc= "VQA-X data preprocessing..."):
    quention_id = ids_list[i]
    sample = data[quention_id]
    img_name = sample['image_name']

    text_q = utils.proc_ques(sample['question'])    # question
    text_a = utils.proc_ans(sample['answers'])
    exp_idx = index_tracker[quention_id]
    text_e = sample['explanation'][exp_idx]

    # 2개의 explanation 이라면
    if exp_idx > 0:
        index_tracker[quention_id] -= 1    # decrease usage
        
    q_segment_id, s_segment_id, a_segment_id, e_segment_id = tokenizer.convert_tokens_to_ids(['<question>', '<situation>', '<answer>', '<explanation>'])

    question_tokens = tokenizer.tokenize(text_q)
    segment_ids = [q_segment_id] * len(question_tokens)

    # situation
    situation_tag = tokenizer.tokenize("situation:")

    # answer
    answer_tokens =  tokenizer.tokenize(text_a)
    answer_tag = tokenizer.tokenize("answer:")
    answer_len = len(answer_tokens)
    answer_tokens

    # explanation
    explanation_tokens = tokenizer.tokenize(text_e)
    explanation_tag = tokenizer.tokenize("explanation:")
    exp_len = len(explanation_tokens)

    # Task A -> Q, I, A -> E
    if task_A:
        tokens = question_tokens + situation_tag + ["<situation>"]*196 + answer_tag + answer_tokens + [tokenizer.eos_token]
        segment_ids = segment_ids + [s_segment_id]*196 + [e_segment_id] * (len(answer_tokens) +2)
        labels = explanation_tokens + [tokenizer.eos_token]

    # Task B -> Q, E -> A
    else:
        tokens = question_tokens + explanation_tag + explanation_tokens + [tokenizer.eos_token]
        segment_ids = segment_ids + [e_segment_id] * (len(tokens) - len(segment_ids))
        labels = answer_tokens + [tokenizer.eos_token]
        
    # Split over sequence length
    if len(tokens) > input_max_seq_length :
        tokens = tokens[:input_max_seq_length]
        segment_ids = segment_ids[:input_max_seq_length]
        
    # Padding
    seq_len = len(tokens)
    input_padding_len = input_max_seq_length - len(tokens)
    output_padding_len = output_max_seq_length - len(labels)
    tokens = tokens + ([tokenizer.pad_token] * input_padding_len)
    labels = labels + ([tokenizer.pad_token] * output_padding_len)
    segment_ids += ([e_segment_id] * input_padding_len)
    # token to ids

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = torch.tensor(input_ids, dtype=torch.long)

    labels = [tokenizer.convert_tokens_to_ids(t) for t in labels]
    labels = torch.tensor(labels, dtype=torch.long)

    segment_ids = torch.tensor(segment_ids, dtype=torch.long)


    ## Image

    folder = coco_data_dir + '/train2014/' if 'train' in img_name else coco_data_dir + 'val2014/'
    img_path = folder + img_name
    img = Image.open(img_path).convert('RGB')
    img = img_transform(img)
    qid =quention_id
    qid = torch.LongTensor([int(quention_id)])

    datasets.append({"img": img, "qid" : qid, "input_ids": input_ids, "labels": labels, "segment_ids" : segment_ids})

VQA-X data preprocessing...: 100%|██████████| 30/30 [00:00<00:00, 57.92it/s]


In [24]:
datasets[0]["img"]

tensor([[[-1.5185, -1.5357, -1.7069,  ...,  0.5707,  0.7933,  0.8276],
         [-1.7412, -1.6898, -1.6555,  ...,  0.4508,  0.7248,  0.8447],
         [-1.8097, -1.7412, -1.6384,  ..., -0.2342,  0.0741,  0.3481],
         ...,
         [-0.0629,  0.0569,  0.0912,  ...,  0.2796,  0.2796,  0.2967],
         [ 0.1768,  0.3823,  0.3138,  ...,  0.3138,  0.3309,  0.3309],
         [ 0.5193,  0.4679,  0.1426,  ...,  0.2796,  0.2967,  0.2796]],

        [[-1.3880, -1.4230, -1.5980,  ...,  0.7129,  0.8880,  0.9230],
         [-1.6155, -1.5805, -1.5280,  ...,  0.6078,  0.8704,  0.9755],
         [-1.7031, -1.6331, -1.5105,  ..., -0.0749,  0.2052,  0.4678],
         ...,
         [ 0.0826,  0.2052,  0.2402,  ...,  0.4328,  0.4328,  0.4503],
         [ 0.3102,  0.5203,  0.4503,  ...,  0.4678,  0.4678,  0.4853],
         [ 0.6429,  0.5903,  0.2752,  ...,  0.4328,  0.4328,  0.4503]],

        [[-1.3339, -1.3861, -1.4733,  ...,  0.5834,  0.7751,  0.8099],
         [-1.4907, -1.4733, -1.4384,  ...,  0

In [3]:
import clip
import torch
import torch.nn as nn

class ImageEncoder(nn.Module):

    def __init__(self, device):
        super(ImageEncoder, self).__init__()
        self.encoder, _ = clip.load("ViT-B/16", device= device)   # loads already in eval mode
        
    def forward(self, x):
        """
        Expects a tensor of size (batch_size, 3, 224, 224)
        """
        with torch.no_grad():
            
            x = x.type(self.encoder.visual.conv1.weight.dtype)
            # Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
            x = self.encoder.visual.conv1(x)  # shape = [*, width, grid, grid]
            print("1",x.shape)
            x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
            print("2",x.shape)
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
            print("3",x.shape)
            x = torch.cat([self.encoder.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
            print("4",x.shape)
            x = x + self.encoder.visual.positional_embedding.to(x.dtype)
            print("5",x.shape)
            x = self.encoder.visual.ln_pre(x)
            print("6",x.shape)
            x = x.permute(1, 0, 2)  # NLD -> LND
            print("7",x.shape)
            x = self.encoder.visual.transformer(x)
            print("8",x.shape)
            grid_feats = x.permute(1, 0, 2)  # LND -> NLD    (N, 197, 768)
            print("9",x.shape)
            grid_feats = self.encoder.visual.ln_post(grid_feats[:,1:])  
            print(grid_feats.shape)
                
        return grid_feats.float()

In [4]:
enc = ImageEncoder(torch.device("cpu"))

In [5]:
samp = datasets[0]["img"].unsqueeze(0)
output = enc(samp)

1 torch.Size([1, 768, 14, 14])
2 torch.Size([1, 768, 196])
3 torch.Size([1, 196, 768])
4 torch.Size([1, 197, 768])
5 torch.Size([1, 197, 768])
6 torch.Size([1, 197, 768])
7 torch.Size([197, 1, 768])
8 torch.Size([197, 1, 768])
9 torch.Size([197, 1, 768])
torch.Size([1, 196, 768])


In [6]:
encoder, _ = clip.load("ViT-B/16", device= torch.device("cpu"))

In [None]:
type(self.encoder.visual.conv1.weight.dtype)

In [27]:
encoder.visual.conv1.state_dict

<bound method Module.state_dict of Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)>

In [59]:
encoder.visual.positional_embedding

Parameter containing:
tensor([[ 1.4430e-04, -4.2040e-02,  6.1855e-02,  ...,  7.9190e-02,
          7.7290e-02,  3.2362e-06],
        [-8.7270e-03, -1.6116e-02,  3.5423e-02,  ..., -4.7097e-02,
         -4.1042e-02,  7.3858e-02],
        [-9.3716e-03, -1.6571e-02,  9.9012e-03,  ..., -3.6754e-02,
         -2.3377e-02,  2.3383e-02],
        ...,
        [ 2.9686e-03, -1.2590e-02, -4.1563e-02,  ..., -4.5710e-02,
         -4.2346e-02,  7.2440e-03],
        [ 4.9054e-03, -1.8691e-02,  2.8315e-03,  ..., -4.2236e-02,
         -3.7692e-02, -1.6221e-03],
        [ 6.3653e-03, -2.3058e-02, -5.4040e-02,  ...,  3.0824e-02,
         -4.3921e-02,  5.2905e-02]], requires_grad=True)

In [69]:
output.shape

torch.Size([1, 196, 768])

In [72]:
datasets[0]['img'].shape

torch.Size([3, 224, 224])

In [37]:
import torch
import torch.nn
from torch.nn import CrossEntropyLoss
from transformers import T5ForConditionalGeneration, T5Config

from models.prefix_encoder import PrefixEncoder

class T5PrefixForConditionalGeneration(T5ForConditionalGeneration):
    def __init__(self, config: T5Config):
        super().__init__(config)
        self.prefix_seq_len = config.prefix_seq_len
        self.n_layers = config.num_hidden_layers
        self.n_heads = config.num_attention_heads
        self.n_embeds = config.hidden_size // config.num_attention_heads
        print(self.n_embeds)

        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.prefix_encoder = PrefixEncoder(config)
        self.prefix_tokens = torch.arange(self.prefix_seq_len).long()

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        bsz, seq_len, _ = past_key_values.shape
        past_key_values = past_key_values.view(bsz, seq_len, self.n_layers*2, self.n_heads, self.n_embeds)
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2,0,3,1,4]).split(2)

        return past_key_values
    
    def forward(self,
                input_ids=None,
                attention_mask=None,
                decoder_attention_mask=None,
                decoder_input_ids=None,
                encoder_outputs=None,
                past_key_values=None,
                inputs_embeds=None,
                decoder_inputs_embeds=None,
                labels=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                encoder_only=None,):
        
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        
        batch_size = input_ids.shape[0]
        #enc_past_key_values = self.get_prompt(batch_size=batch_size)
        #prefix_attention_mask = torch.ones(batch_size, self.prefix_seq_len).to(self.device)
        #attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        #print(input_ids)
        #print(inputs_embeds)
        input_ids_emb = self.shared(input_ids)
        print(input_ids_emb)
        print(input_ids_emb.shape)
        return input_ids_emb

In [38]:
conf = T5Config.from_pretrained("t5-large")
conf.prefix_seq_len = 30
conf.hidden_dropout_prob = 0.1
conf.prefix_projection = True
conf.pre_seq_len = 30
conf.prefix_hidden_size = 30
model = T5PrefixForConditionalGeneration(conf)

64


In [43]:
conf.d_model

1024

In [36]:
ids_sample = datasets[0]["input_ids"].unsqueeze(0)
ids_sample.shape

torch.Size([1, 500])

In [42]:
test1 = model(input_ids= ids_sample, inputs_embeds = output )

tensor([[[ 1.7534, -0.6862, -0.3778,  ..., -1.2773,  0.8081, -0.6258],
         [-1.1070, -0.3546, -0.0654,  ..., -2.2251,  1.3561,  1.2644],
         [ 0.6979, -1.5705,  1.1215,  ...,  0.3965, -0.6999,  1.4335],
         ...,
         [ 0.5292, -0.7427, -1.8560,  ...,  0.0856, -0.9413, -1.1828],
         [ 0.5292, -0.7427, -1.8560,  ...,  0.0856, -0.9413, -1.1828],
         [ 0.5292, -0.7427, -1.8560,  ...,  0.0856, -0.9413, -1.1828]]],
       grad_fn=<EmbeddingBackward0>)
torch.Size([1, 500, 1024])


In [11]:
situation_token = tokenizer.encode("<situation>")[0]
situation_idx = []
for i in ids_sample:
    for num,j in enumerate(i):
        if j == torch.Tensor([situation_token]):
            situation_idx.append(num)
            break

In [160]:
test1[0]

tensor([[-0.4426,  0.9423, -2.4927,  ...,  0.6239,  0.5189,  1.2511],
        [-0.8097, -0.4102,  0.3188,  ...,  2.0389,  0.5652, -0.0307],
        [ 1.2144,  0.5852,  0.0984,  ..., -0.0879, -0.0627, -2.0041],
        ...,
        [-0.3181,  0.0744,  0.1203,  ...,  0.5819, -0.5419,  0.2743],
        [-0.3181,  0.0744,  0.1203,  ...,  0.5819, -0.5419,  0.2743],
        [-0.3181,  0.0744,  0.1203,  ...,  0.5819, -0.5419,  0.2743]],
       grad_fn=<SelectBackward0>)

In [12]:
mlp1 = nn.Linear(768,1024)
mlp2 = nn.Linear(1024,1024)

In [13]:
output2 = mlp1(output)
print(output2.shape)
output2 = mlp2(output2)

torch.Size([1, 196, 1024])


In [14]:
for idx,location in enumerate(situation_idx):
    test1[idx][location:location+196,:] = output2[idx]

In [15]:
test1

tensor([[[ 0.3828,  0.3313, -0.7732,  ..., -0.2175, -0.1339,  0.8477],
         [ 0.3436, -0.1190,  0.3153,  ..., -0.6236, -1.0854, -1.2386],
         [-0.2829,  0.5944,  0.4410,  ..., -0.5823, -0.2191, -1.0603],
         ...,
         [ 0.3141, -0.1474,  0.0801,  ...,  0.2868, -1.0450, -0.7127],
         [ 0.3141, -0.1474,  0.0801,  ...,  0.2868, -1.0450, -0.7127],
         [ 0.3141, -0.1474,  0.0801,  ...,  0.2868, -1.0450, -0.7127]]],
       grad_fn=<CopySlices>)

In [129]:
tokenizer.encode('<situation>')[0]
tokenizer.decode(10)

':'

In [115]:
test1 = model(input_ids= ids_sample, inputs_embeds = output )

tensor([[[ 0.4390, -0.9218,  1.2406,  ..., -0.8661, -0.8282, -0.2915],
         [ 1.0726, -0.9998, -0.7318,  ..., -2.1169,  0.0718, -0.4649],
         [-0.3517, -1.5154, -0.9226,  ..., -0.4945, -1.1152, -1.1642],
         ...,
         [ 0.9891, -0.0211,  1.2571,  ..., -0.0514, -0.2245, -1.2818],
         [ 0.9891, -0.0211,  1.2571,  ..., -0.0514, -0.2245, -1.2818],
         [ 0.9891, -0.0211,  1.2571,  ..., -0.0514, -0.2245, -1.2818]]],
       grad_fn=<EmbeddingBackward0>)
torch.Size([1, 500, 1024])


In [102]:
datasets[0].keys()

dict_keys(['img', 'qid', 'input_ids', 'labels', 'segment_ids'])

In [117]:
datasets[0]['labels']

tensor([18483,     1,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0])

In [15]:
import torch
batch_size = 4
n_views = 2
labels = torch.cat([torch.arange(batch_size) for i in range(n_views)], dim=0)
labels

tensor([0, 1, 2, 3, 0, 1, 2, 3])

In [16]:
labels = (labels.unsqueeze(0) == labels.unsqueeze(1)).float()
labels

tensor([[1., 0., 0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 1.],
        [1., 0., 0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 1.]])

In [17]:
labels.shape

torch.Size([8, 8])

In [18]:
mask = torch.eye(labels.shape[0], dtype=torch.bool)

In [19]:
mask

tensor([[ True, False, False, False, False, False, False, False],
        [False,  True, False, False, False, False, False, False],
        [False, False,  True, False, False, False, False, False],
        [False, False, False,  True, False, False, False, False],
        [False, False, False, False,  True, False, False, False],
        [False, False, False, False, False,  True, False, False],
        [False, False, False, False, False, False,  True, False],
        [False, False, False, False, False, False, False,  True]])

In [20]:
labels.shape[0]

8

In [21]:
labels = labels[~mask].view(labels.shape[0], -1)
labels

tensor([[0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0.]])