In [113]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [111]:
import numpy as np
import torch
import torch.nn as nn

# NOTE(review): `_include_` is neither defined nor imported anywhere in this
# notebook. Presumably it is injected by an IPython startup script and makes the
# named project directories importable -- confirm, otherwise this cell fails on
# a fresh kernel.
_include_('curriculum_vqa')
_include_('commons')

# Seed the global NumPy and PyTorch RNGs so the random draws below are repeatable.
seed = 1
np.random.seed(seed)
# Trailing `;` suppresses the notebook's display of manual_seed's return value.
torch.manual_seed(seed);

# Project-local packages (expected to be importable after the `_include_` calls above).
from cvqa import datasets, model, trainer, viz
from commons import debug

# NOTE(review): `DEV_HOME` is not defined in this notebook either -- presumably
# it also comes from the startup environment; confirm.
data_bin = f'{DEV_HOME}/curriculum_vqa/data-bin'

In [112]:
# Scratch check: draw one integer in [0, 3) from the global NumPy RNG.
np.random.randint(3)

1

In [49]:
# Load the basic curriculum with the non-'natural' encodings on both sides
# (prompt_mode='concept', target_mode='class'); the repr below reports
# 2 concepts and 5 classes.
dataset = datasets.BasicCurriculum(f'{data_bin}/basic_curriculum', prompt_mode='concept', target_mode='class')
# Bare last expression: display the dataset summary.
dataset

Root: /Users/urisherman/Work/workspace/curriculum_vqa/data-bin/basic_curriculum/train 
Samples: 2000 (N_prompt=1, N_target=1)
Concepts: 2 
Classes: 5 
Vocab Tokens:8

In [50]:
def build_embeddings(d, dataset, c=None):
    """Create the prompt and target embedding tables for ``dataset``.

    Args:
        d: Embedding dimension of the prompt table.
        dataset: Object exposing ``prompt_mode`` and ``target_mode`` plus the
            sized collections ``vocab``, ``concept_to_idx`` and ``cls_to_idx``
            (only the ones required by the selected modes are accessed).
        c: Embedding dimension of the target table. Defaults to ``d``. Ignored
            when both modes are ``'natural'``, because a single table is shared.

    Returns:
        ``(prompt_embeddings, target_embeddings)``. When both modes are
        ``'natural'`` the two entries are the *same* ``nn.Embedding`` object
        sized by ``len(dataset.vocab)``; otherwise they are two independent
        tables.

    NOTE(review): every table is built with ``padding_idx=0``, including the
    concept/class tables. Row 0 is therefore zero-initialised and receives no
    gradient -- confirm that index 0 is reserved in ``concept_to_idx`` /
    ``cls_to_idx``.
    """
    natural_prompt = dataset.prompt_mode == 'natural'
    natural_target = dataset.target_mode == 'natural'

    # Fully natural-language setup: prompts and targets share one vocabulary table.
    if natural_prompt and natural_target:
        shared = nn.Embedding(len(dataset.vocab), d, padding_idx=0)
        return shared, shared

    # Table sizes depend on which side (if any) uses the natural vocabulary.
    n_prompt_tokens = len(dataset.vocab) if natural_prompt else len(dataset.concept_to_idx)
    n_target_tokens = len(dataset.vocab) if natural_target else len(dataset.cls_to_idx)
    target_dim = d if c is None else c

    # Prompt table is created first, then the target table (keeps RNG consumption order).
    prompt_embeddings = nn.Embedding(n_prompt_tokens, d, padding_idx=0)
    target_embeddings = nn.Embedding(n_target_tokens, target_dim, padding_idx=0)
    return prompt_embeddings, target_embeddings

In [102]:
import torch.nn.functional as F
import math

class VQAPromptOpModel(nn.Module):
    """Prompt-conditioned linear operator applied to image features.

    Each prompt is embedded (token embeddings are summed) and contracted with
    the learned 3-way tensor ``W_op`` to give a per-sample ``P -> c`` linear
    operator. That operator maps the sample's image features into the target
    embedding space, and logits are the dot products with every row of the
    target embedding table.

    Args:
        prompt_embedding: ``nn.Embedding`` with V rows of dimension d.
        target_embedding: ``nn.Embedding`` with L rows of dimension c.
        img_perceptor: Module mapping an image batch to 2-D features ``[B, P]``.
            Defaults to ``model.BasicImgModel(20)``.
    """

    def __init__(self, prompt_embedding, target_embedding, img_perceptor=None):
        super().__init__()

        if img_perceptor is None:
            img_perceptor = model.BasicImgModel(20)
        self.img_perceptor = img_perceptor

        # Probe the perceptor once with a dummy 224x224 RGB image to discover its
        # output width P. The probe is for shape inference only, so no autograd
        # graph is recorded.
        with torch.no_grad():
            img_embedding = img_perceptor(torch.rand(1, 3, 224, 224))
        _, P = img_embedding.shape

        dims = {
            'P': P,  # perception embedding dim
            'V': prompt_embedding.num_embeddings,  # num of prompt tokens
            'd': prompt_embedding.embedding_dim,  # prompt tokens embedding
            'L': target_embedding.num_embeddings,  # num of target tokens
            'c': target_embedding.embedding_dim  # target tokens embedding
        }
        self.dims = dims
        self.prompt_embedding = prompt_embedding
        self.target_embedding = target_embedding

        # The operators operator
        # Given an embedded prompt, output a P --> c operator
        self.W_op = nn.Parameter(torch.empty(dims['P'], dims['d'], dims['c']))
        nn.init.kaiming_uniform_(self.W_op, a=math.sqrt(5))

    def forward(self, prompt, img):
        """Compute target-token logits for a batch of (prompt, image) pairs.

        Args:
            prompt: Integer token indices, shape ``[B]`` or ``[B, N_prompt]``.
            img: Image batch accepted by ``self.img_perceptor``.

        Returns:
            Logits of shape ``[B, L]`` (one score per target-embedding row).
        """
        if len(prompt.shape) == 1:
            # A 1-D batch holds a single token per sample; make it [B x 1].
            prompt = prompt.view(-1, 1)
        prompt_encoded = self.prompt_embedding(prompt)  # [B x N_prompt x d]
        prompt_encoded = torch.sum(prompt_encoded, dim=1)  # [B x d]

        img_features = self.img_perceptor(img)  # [B x P]

        # Apply each sample's OWN prompt operator to its image features:
        #   pred[b, c] = sum_{p, d} img[b, p] * W_op[p, d, c] * prompt[b, d]
        # The batch index `b` must stay in the output; dropping it (as in
        # 'pdc,bd->pc') sums the operators of every prompt in the batch, so a
        # sample's logits would depend on the other samples in its batch.
        pred_embedded = torch.einsum(
            'bp,pdc,bd->bc', img_features, self.W_op, prompt_encoded
        )  # [B x c]

        t_embeddings = self.target_embedding.weight  # [L x c]
        logits = pred_embedded @ t_embeddings.T  # [B x L]
        return logits

In [108]:
# Smoke test of VQAPromptOpModel on one batch.
# NOTE: this rebinds `dataset`, which an earlier cell created with
# target_mode='class'; here the targets use the 'natural' mode instead.
dataset = datasets.BasicCurriculum(f'{data_bin}/basic_curriculum', prompt_mode='concept', target_mode='natural')
# d=12; c is left at its default, so build_embeddings uses c = d.
prompt_embeddings, target_embeddings = build_embeddings(12, dataset)

# Pull a single batch of 32 samples.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
sample = next(iter(dataloader))

# With target_mode='natural' the target table has len(dataset.vocab) rows, which
# is the trailing dimension of the logits shape displayed below.
m = VQAPromptOpModel(prompt_embeddings, target_embeddings)
m(sample['prompt'], sample['img']).shape

torch.Size([32, 16])

In [109]:
dataset

Root: /Users/urisherman/Work/workspace/curriculum_vqa/data-bin/basic_curriculum/train 
Samples: 2000 (N_prompt=1, N_target=2)
Concepts: 2 
Classes: 0 
Vocab Tokens:16

In [66]:
# Reproduces the reshape done at the top of VQAPromptOpModel.forward:
# a 1-D prompt batch is turned into a column of shape [B x 1].
# NOTE(review): the recorded output (`torch.Size([32, 1])`) cannot come from
# these lines as shown, since an `if` statement displays nothing. Presumably a
# trailing `prompt.shape` expression was lost -- confirm.
prompt = sample['prompt']
if len(prompt.shape) == 1:
    prompt = prompt.view(-1, 1)


torch.Size([32, 1])

In [40]:
# Shape check for the sum over dim=1 used in VQAPromptOpModel.forward.
# NOTE(review): `prompt_encoded` is not defined by any earlier-numbered cell
# visible in this notebook (hidden kernel state). Presumably it held the
# [B x N_prompt x d] embedding output at the time -- confirm.
torch.sum(prompt_encoded, dim=1).shape

torch.Size([32, 12])

In [45]:
# Toy sizes for shape experiments; the letters follow the `dims` keys used in
# VQAPromptOpModel.
B = 32  # batch size
P = 50  # perception embedding dim
d = 12  # prompt embedding dim
c = 7   # target embedding dim

# Random stand-ins for the model's W_op tensor and for the summed prompt embeddings.
W_op = torch.rand(P, d, c)
prompt_encoded = torch.rand(B, d)

In [46]:
# Broadcast matmul: [B x d] @ [P x d x c] treats W_op as a stack of P matrices,
# giving [P x B x c] (the batch axis is preserved, but P comes first).
(prompt_encoded @ W_op).shape

torch.Size([50, 32, 7])

In [79]:
# 'pdc,bd->pc': both `d` and `b` are missing from the output subscripts, so
# einsum sums over the embedding axis AND over the batch axis, collapsing all
# B prompts into a single [P x c] result.
torch.einsum('pdc,bd->pc', W_op, prompt_encoded).shape

torch.Size([50, 7])

In [87]:
# F.linear(x, W) computes x @ W.T, with W laid out as [out_features x in_features]:
# [32 x 12] with a [5 x 12] weight gives [32 x 5].
F.linear(torch.rand(32, 12), torch.rand(5, 12)).shape

torch.Size([32, 5])

In [83]:
# `.T` on a 2-D tensor transposes it: a [3 x 5] tensor is displayed as [5 x 3].
torch.rand(3,5).T

tensor([[0.9478, 0.6311, 0.7709],
        [0.4816, 0.2015, 0.6500],
        [0.4756, 0.6817, 0.9979],
        [0.3623, 0.7724, 0.7350],
        [0.5531, 0.3633, 0.6443]])