Notebook to try different network architectures

In [1]:
import torch.nn as nn
import torch
from model import TextProcessing, Classifier, Fusion, BaselineFusion, StackedAttention, average_layer

In [2]:
import json
from torch.utils.data import DataLoader
from VQA_Dataset import VQA_Data

# Load dict
# Use a context manager so the vocabulary file handle is closed as soon as
# parsing finishes instead of being left open until garbage collection.
with open('data/vocabulary_alternate.json') as vocab_file:
    l_dict = json.load(vocab_file)

# Select device: currently pinned to CPU. To auto-select a GPU when one is
# available, use instead:
#   torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"

# Define dataloader
# TODO - is a code change needed for a test-data loader, or is it not required?
dataset = VQA_Data('train')
train_dataloader = DataLoader(dataset, batch_size = 4, shuffle=True)
# Pull a single batch (shuffle=True, so which samples it holds varies per run);
# the cells below use it as a sample input for the models.
d_batch = next(iter(train_dataloader))

In [3]:
# Unpack the sample batch into the individual inputs used by the models below.

# Entries kept exactly as the dataloader returned them (not moved to `device`).
# NOTE(review): presumably image paths and raw question strings, going by the
# key names -- confirm against VQA_Data.
img_p, ques = d_batch['img_p'], d_batch['quest_s']

# Tensor entries, moved to the selected device.
# Dimension 1 is squeezed out of the question and image tensors (a no-op if it
# is not of size 1); the answers are flattened to a 1-D tensor.
ques_e = d_batch['quest_e'].squeeze(dim=1).to(device)
img_e = d_batch['img_e'].squeeze(dim=1).to(device)
ans_e = d_batch['ans_e'].view(-1).to(device)
ques_l = d_batch['quest_l'].to(device)

In [4]:
# Reload the vocabulary (same file as in the data-loading cell). The context
# manager closes the file handle once parsing is done.
with open('data/vocabulary_alternate.json') as vocab_file:
    l_dict = json.load(vocab_file)

# Sanity-check the shapes of the unpacked batch tensors.
print(ques_e.size())
print(img_e.size())
print(ans_e.size())

torch.Size([4, 15])
torch.Size([4, 2048, 10, 10])
torch.Size([4])


In [5]:
class TiledAttention(nn.Module):
    """Attention-style VQA network built from ``TextProcessing``, ``Fusion`` and ``Classifier``.

    Pipeline (see ``forward``): the question is encoded by ``TextProcessing``;
    ``Fusion`` combines the image features with the encoded question;
    ``average_layer`` reduces the image features using the fusion output; the
    result is concatenated with the encoded question along dim 1 and fed to
    ``Classifier``.

    Args:
        l_dict: Forwarded unchanged as the first positional argument of
            ``TextProcessing``. NOTE(review): despite the name, this notebook
            passes an integer (``len(l_dict['questions']) + 1``), not a dict.
        dropout: Forwarded to ``TextProcessing``, ``Fusion`` and ``Classifier``.
        glimpses: Forwarded to ``Fusion``; also scales the classifier input
            width (``glimpses * img_vec_size + text_vec_size``).
        img_vec_size: Passed to ``Fusion`` as ``v``.
        text_vec_size: Passed to ``TextProcessing`` as ``embedding_size``, to
            ``Fusion`` as ``q`` and to ``Classifier`` as ``mid_features``.
        ques_len: Passed to ``TextProcessing`` as ``l_ques``.
            NOTE(review): the original inline comment called this the "size of
            embedding vector" -- confirm against ``TextProcessing``.
        fusion_mid: Passed to ``Fusion`` as ``mid``.
        num_classes: Passed to ``Classifier`` as ``out_features``.

    All keyword-only arguments default to the values that were previously
    hard-coded, so ``TiledAttention(x)`` and ``TiledAttention(x, dropout)``
    build the same network as before.
    """

    def __init__(self, l_dict, dropout = None, *, glimpses=2, img_vec_size=2048,
                 text_vec_size=1024, ques_len=300, fusion_mid=512, num_classes=3001):
        super().__init__()

        # Sub-modules are created in the same order as before (text, fuse,
        # classifier) so registration order is unchanged.
        self.text = TextProcessing(
            l_dict,
            l_ques=ques_len,
            embedding_size=text_vec_size,
            dropout=dropout,
        )
        self.fuse = Fusion(
            v=img_vec_size,
            q=text_vec_size,
            mid=fusion_mid,
            glimpses=glimpses,
            dropout=dropout,
        )
        self.classifier = Classifier(
            # Width of the concatenated vector built in forward(), assuming
            # average_layer yields glimpses * img_vec_size features per sample
            # -- TODO confirm against average_layer.
            in_features=glimpses * img_vec_size + text_vec_size,
            mid_features=text_vec_size,
            out_features=num_classes,
            dropout=dropout,
        )

    def forward(self,text_vec,img_vec,q_len):
        """Run the model on a batch.

        Args:
            text_vec: Encoded questions, passed to ``TextProcessing``.
            img_vec: Image features, passed to ``Fusion`` and ``average_layer``.
            q_len: Question lengths, passed to ``TextProcessing``.

        Returns:
            The ``Classifier`` output. In this notebook a batch of 4 samples
            with the default settings printed ``torch.Size([4, 3001])``.
        """
        text_vec = self.text(text_vec, q_len)  # text processing
        f = self.fuse(img_vec, text_vec)  # fuse image features with the encoded question
        # Reduce the image features using the fusion output.
        # NOTE(review): the original comment said "pass through softmax";
        # average_layer is defined in model.py -- confirm what it computes.
        img_vec = average_layer(img_vec, f)
        # Concatenate reduced image features with the encoded question (feature dim).
        vec = torch.cat([img_vec, text_vec], dim = 1)
        return self.classifier(vec)

In [6]:
# Smoke test: build TiledAttention and push the sample batch through it.
# The constructor receives the number of question-vocabulary entries plus one.
t_att_model = TiledAttention(
    1 + len(l_dict['questions'])
).to(device)
pred = t_att_model(ques_e, img_e, ques_l)
print(pred.shape)

torch.Size([4, 3001])


In [11]:
class BaseLine(nn.Module):
    """Baseline VQA model: a ``TextProcessing`` question encoder followed by ``BaselineFusion``.

    Args:
        l_dict: Forwarded unchanged as the first positional argument of
            ``TextProcessing``. NOTE(review): this notebook passes an integer
            (``len(l_dict['questions']) + 1``), not a dict.
        dropout: Forwarded to ``TextProcessing``.
    """

    def __init__(self, l_dict, dropout = None):
        super().__init__()

        # Question encoder, configured with l_ques=300 and embedding_size=1024.
        self.text = TextProcessing(
            l_dict,
            l_ques=300,
            embedding_size=1024,
            dropout=dropout,
        )
        # Fusion head constructed with 3001 as its only argument (the class
        # count, per the original `num_classes` local).
        self.fusion = BaselineFusion(3001)

    def forward(self,text_vec,img_vec,q_len):
        """Encode the question, then fuse it with the image features.

        Returns whatever ``BaselineFusion`` returns; in this notebook a batch
        of 4 samples printed ``torch.Size([4, 3001])``.
        """
        encoded_question = self.text(text_vec, q_len)
        return self.fusion(img_vec, encoded_question)

In [12]:
# Smoke test: build the baseline model and push the sample batch through it.
# The constructor receives the number of question-vocabulary entries plus one.
b_att_model = BaseLine(
    1 + len(l_dict['questions'])
).to(device)
pred = b_att_model(ques_e, img_e, ques_l)
print(pred.shape)

torch.Size([4, 3001])


In [13]:
class SANModel(nn.Module):
    """VQA model pairing a ``TextProcessing`` question encoder with ``StackedAttention``.

    Args:
        l_dict: Forwarded unchanged as the first positional argument of
            ``TextProcessing``. NOTE(review): this notebook passes an integer
            (``len(l_dict['questions']) + 1``), not a dict.
        dropout: Forwarded to ``TextProcessing``.
    """

    def __init__(self, l_dict, dropout = None):
        super().__init__()

        # Question encoder, configured with l_ques=300 and embedding_size=1024.
        self.text = TextProcessing(
            l_dict,
            l_ques=300,
            embedding_size=1024,
            dropout=dropout,
        )
        # Attention head constructed with 3001 as its only argument (the class
        # count, per the original `num_classes` local).
        self.fusion = StackedAttention(3001)

    def forward(self,text_vec,img_vec,q_len):
        """Encode the question, then combine it with the image features.

        Returns whatever ``StackedAttention`` returns; in this notebook a batch
        of 4 samples printed ``torch.Size([4, 3001])``.
        """
        encoded_question = self.text(text_vec, q_len)
        return self.fusion(img_vec, encoded_question)

In [14]:
# Smoke test: build the stacked-attention model and push the sample batch through it.
# The constructor receives the number of question-vocabulary entries plus one.
san_att_model = SANModel(
    1 + len(l_dict['questions'])
).to(device)
pred = san_att_model(ques_e, img_e, ques_l)
print(pred.shape)

torch.Size([4, 3001])
