In [1]:
!wget https://memexqa.cs.cmu.edu/fvta_model_zoo/prepro_v1.1.tgz
!gunzip prepro_v1.1.tgz
!tar -xvf prepro_v1.1.tar

--2020-12-12 15:54:23--  https://memexqa.cs.cmu.edu/fvta_model_zoo/prepro_v1.1.tgz
Resolving memexqa.cs.cmu.edu (memexqa.cs.cmu.edu)... 128.2.220.9
Connecting to memexqa.cs.cmu.edu (memexqa.cs.cmu.edu)|128.2.220.9|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 212511000 (203M) [application/x-gzip]
Saving to: ‘prepro_v1.1.tgz’


2020-12-12 15:54:32 (21.5 MB/s) - ‘prepro_v1.1.tgz’ saved [212511000/212511000]

prepro_v1.1/
prepro_v1.1/test_data.p
prepro_v1.1/train_shared.p
prepro_v1.1/test_shared.p
prepro_v1.1/train_data.p
prepro_v1.1/val_data.p
prepro_v1.1/val_shared.p


In [10]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import *
import torch.nn.functional as F
# import new_dataset_checked_by_hongyuan
# import pandas as pd

class AttentionModel(nn.Module):
    def __init__(self, q_cs_input_size, desc_input_size, img_input_size, hidden_size, batch_size,
                num_layers, device, img_linear_size,num_choices = 4, rnn_type = 'bilstm'):
        super(AttentionModel, self).__init__()
        self.device = device
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.q_cs_input_size = q_cs_input_size   #input_size qvec and cs_vec
        self.desc_input_size = desc_input_size
        self.img_input_size = img_input_size
        self.img_linear_size = img_linear_size # s2
        self.num_directions = 2 if rnn_type == 'bilstm' else 1
        self.num_layers = num_layers
        self.cos = nn.CosineSimilarity(dim=2, eps=1e-6)
        self.softmax1 = nn.Softmax(dim = 1)
        self.softmax2 = nn.Softmax(dim = 2)

        if (rnn_type == 'bilstm'):
            self.rnn_q = nn.LSTM(q_cs_input_size, hidden_size, self.num_layers, batch_first = False, bidirectional = True) # questions
            self.rnn_c = nn.LSTM(q_cs_input_size, hidden_size, self.num_layers, batch_first = False, bidirectional = True) # choices
            self.rnn_desc = nn.LSTM(desc_input_size, hidden_size, self.num_layers, batch_first = False, bidirectional = True) # photo titles
            self.rnn_ps = nn.LSTM(img_input_size, hidden_size, self.num_layers, batch_first = False, bidirectional = True) # image features
        else:
            self.rnn_q = nn.LSTM(q_cs_input_size, hidden_size, self.num_layers, batch_first = False, bidirectional = False)
            self.rnn_c = nn.LSTM(q_cs_input_size, hidden_size, self.num_layers, batch_first = False, bidirectional = False)
            self.rnn_desc = nn.LSTM(desc_input_size, hidden_size, self.num_layers, batch_first = False, bidirectional = False)
            self.rnn_ps = nn.LSTM(img_input_size, hidden_size, self.num_layers, batch_first = False, bidirectional = False)

        self.vis_text = nn.Linear(2*hidden_size, 2*hidden_size)
      
        self.tanh1 = nn.Tanh()
        self.CH_linear = nn.Linear(2*hidden_size, 2*hidden_size)
        self.img_linear = nn.Linear(hidden_size, self.img_linear_size) 
        self.tanh2 = nn.Tanh()
        self.tanh3 = nn.Tanh()

        #input: B X num_choices X (5 * 2 * hidden_size), output: B x num_choices X 1
        self.last_softmax = nn.Linear(5 * 2 * hidden_size , 1) 

    
    def forward(self, X):
        # X is a list of dictionaries: 'q_vec', 'cs_vec', 'desc_vec', 'img_feats'
        # BATCH_SIZE = len(X)
        # B, T, F
        # q_vec -> B, T, 100 (glove embedding)
        # cs_vec -> B, 4, Y_THRES, 100
        # desc_vec(prev. pt) -> B, all_photo_titles_albums * 40, 100
        # img_feats -> B, all_photo_titles_albums, 2537

        q_vec = X['q_vec']
        cs_vec = X['cs_vec']
        desc_vec = X['desc_vec']
        img_feats = X['img_feats']

        packed_q_vec = pack_padded_sequence(q_vec, X['q_len'], batch_first=False, enforce_sorted = False).to(self.device)
        packed_c0_vec = pack_padded_sequence(cs_vec[:, :, 0, :], X['cs0_lens'], batch_first = False, enforce_sorted = False).to(self.device)
        packed_c1_vec = pack_padded_sequence(cs_vec[:, :, 1, :], X['cs1_lens'], batch_first = False, enforce_sorted = False).to(self.device)
        packed_c2_vec = pack_padded_sequence(cs_vec[:, :, 2, :], X['cs2_lens'], batch_first = False, enforce_sorted = False).to(self.device)
        packed_c3_vec = pack_padded_sequence(cs_vec[:, :, 3, :], X['cs3_lens'], batch_first = False, enforce_sorted = False).to(self.device)
        packed_pt_vec = pack_padded_sequence(desc_vec, X['desc_len'], batch_first = False, enforce_sorted = False).to(self.device)
        packed_img_vec = pack_padded_sequence(img_feats, X['img_len'], batch_first = False, enforce_sorted = False).to(self.device)
        q_out,_   = self.rnn_q(packed_q_vec)  # M X B X 2d
        c0_out,_  = self.rnn_c(packed_c0_vec) # T X B X 2d
        c1_out,_  = self.rnn_c(packed_c1_vec) # T X B X 2d
        c2_out,_  = self.rnn_c(packed_c2_vec) # T X B X 2d
        c3_out,_  = self.rnn_c(packed_c3_vec) # T X B X 2d
        text_out,_ = self.rnn_desc(packed_pt_vec) # T X B X 2d
        vis_out,_  = self.rnn_ps(packed_img_vec) # T X B X 2d

        # print("vis_out: ", vis_out.data.shape, text_out.data.shape)


        q_out, q_lens_unpacked = pad_packed_sequence(q_out, batch_first=False)  # M X B X 2d
        q_out = q_out.permute(1,0,2) # B X M X 2d
        c0_out, c0_lens_unpacked = pad_packed_sequence(c0_out, batch_first=False)  # T X B X 2d
        c0_out = c0_out[-1,:,:].unsqueeze(0).permute(1,0,2) # B X 1 X 2d
        c1_out, c1_lens_unpacked = pad_packed_sequence(c1_out, batch_first=False)  # T X B X 2d
        c1_out = c1_out[-1,:,:].unsqueeze(0).permute(1,0,2) # B X 1 X 2d
        c2_out, c2_lens_unpacked = pad_packed_sequence(c2_out, batch_first=False)    # T X B X 2d
        c2_out = c2_out[-1,:,:].unsqueeze(0).permute(1,0,2) # B X 1 X 2d
        c3_out, c3_lens_unpacked = pad_packed_sequence(c3_out, batch_first=False)   # T X B X 2d
        c3_out = c3_out[-1,:,:].unsqueeze(0).permute(1,0,2) # B X 1 X 2d
        txt_out, text_out_lens = pad_packed_sequence(text_out, batch_first=False) # T X B X 2d
        txt_out = txt_out.permute(1,0,2) # B X T X 2d
       

        vis_out, vis_out_lens = pad_packed_sequence(vis_out, batch_first=False) # T X B X 2d
        vis_out = vis_out.permute(1,0,2) # B X T X 2d    
        
        # print("vis_out: ", vis_out.shape, txt_out.shape)

        # print("qout: ", q_out.shape)
        # correlation b/w txt_out and vis_out : B X T X 2d -> linear_layer 
        vis = self.vis_text(vis_out)      # B X T X 2d
        vis_out = torch.unsqueeze(vis_out, 3) 
        # print("vis: ", vis.shape)
        text = self.vis_text(txt_out)      # B X T X 2d
        txt_out = torch.unsqueeze(txt_out, 3)
       
        H = torch.cat([vis_out, txt_out], dim = 3)  # B * T * 2d * 2
        # print("H: ", H.shape)  
        C = self.tanh1(vis @ text.permute(0,2,1))  # B X T X T
        C_repeat = C.unsqueeze(3).repeat(1,1,1,2) # B X T x T X 2
        # print("C: ", C_repeat.shape)
        
        
        H_perm = H.permute(0, 3, 2, 1) # B x 2 x 2d x T
        # print("Hperm shape:", H_perm.shape)
        C_perm = C_repeat.permute(0,3,1,2) # B x 2 x T x T
        # print("Cperm shape:", C_perm.shape)
        F = torch.matmul(H_perm, C_perm) # B x 2 X 2d x T
       
        F = F.permute(0,3,1,2) #F:  B * T * 2 * 2d
        F = self.CH_linear(F) 
        F = self.tanh2(F)
        # print("F shape:", F.shape) 

        E = torch.cat([c0_out,c1_out, c2_out, c3_out], dim = 1)  #B x 4 X 2d
        # print("E shape:", E.shape) 
        Q = q_out  #   Q: B x M x 2d
        # print("Q shape:", Q.shape) 
        
        bmm1 = Q.bmm(F.permute(0,3,1,2)[:,:,:,0]).unsqueeze(3)
        bmm2 = Q.bmm(F.permute(0,3,1,2)[:,:,:,1]).unsqueeze(3)
        S = torch.cat([bmm1, bmm2], dim = 3) #S:  B x M x T x 2
        # print("S shape:", S.shape) 
        S = self.tanh3(S)
        max1 = torch.max(S, dim = 1)[0]
        # print("max1 shape: ", max1.shape)
        max2 = torch.max(max1, dim = 1)[0]
        # print("max2 shape: ", max2.shape)
        max3 = torch.max(S, dim = 3)[0]
        # print("max3 shape: ", max3.shape)
        max4 = torch.max(max3, dim = 2)[0]
        # print("max4 shape: ", max4.shape)

        '''
        max1 shape:  torch.Size([3, 69, 2])
        max2 shape:  torch.Size([3, 2])
        max4 shape:  torch.Size([3, 9, 2])
        max3 shape:  torch.Size([3, 9])
                '''
        A = self.softmax1(max1).to(self.device)  # B X T x 2
        B = self.softmax1(max2).to(self.device)   # B X 2
        D = self.softmax1(max4).to(self.device)  # B x M

        h_tilda = torch.zeros((F.shape[0], F.shape[-1])).float().to(self.device) # B x 2d
        q_tilda = torch.zeros((F.shape[0], F.shape[-1])).float().to(self.device) # B x 2d
       
        for k in range(2):
            temp = torch.zeros((F.shape[0], F.shape[-1])).float().to(self.device)
            for t in range(A.shape[1]):
                # print("A shape: ", A[:,t, k].unsqueeze(1).repeat(1, F.shape[-1]).shape, " , F[:, t, k, :] shape: ",  F[:, t, k, :].shape )
                temp += A[:,t, k].unsqueeze(1).repeat(1, F.shape[-1]) *  F[:, t, k, :]# B x 2d
                # print("temp shape: ", temp.shape)
                # temp += A[:,t, k].unsqueeze(1).repeat(1, F.shape[0]) * F[:, t, k, :]
            h_tilda +=  (B[:, k].unsqueeze(1).repeat(1, F.shape[-1]) * temp)
            # print("htilda shape: ", h_tilda.shape)
            # h_tilda += B[:, k] * temp    
        
        H_tilda = h_tilda.unsqueeze(1).repeat(1, 4, 1)  # B x 4 X 2d
        # print("htilda: ", H_tilda.shape)

        for m in range(D.shape[1]):
            # print("D shape: ",D[:, m].unsqueeze(1).repeat(1, Q.shape[-1]).shape, " Q shape: ", Q[:, m, :].shape )
            q_tilda += D[:, m].unsqueeze(1).repeat(1, Q.shape[-1]) * Q[:, m, :]
            # print("qtilda: ", q_tilda.shape)

        Q_tilda = q_tilda.unsqueeze(1).repeat(1, 4, 1)  # B x 4 X 2d
        # print("qtilda: ", Q_tilda.shape)
        concat = torch.cat([Q_tilda, H_tilda, E, Q_tilda * E, H_tilda * E], dim = -1)

        output = self.last_softmax(concat)

        output = output.squeeze(2)
        out_softmax = self.softmax1(output)
        return out_softmax


In [11]:
import pandas as pd
import numpy as np
import torch
from torch.nn.utils.rnn import *
from torch.utils.data import Dataset
import itertools

train_data = pd.read_pickle('prepro_v1.1/train_data.p')
train_shared = pd.read_pickle('prepro_v1.1/train_shared.p')
val_data = pd.read_pickle('prepro_v1.1/val_data.p')

q_types = ["when", "what", "who", "where", "how"]
qtype2qid = {}
qtype2qid["when"] = []
qtype2qid["what"] = []
qtype2qid["who"] = []
qtype2qid["where"] = []
qtype2qid["how"] = []

for i, qid in enumerate(train_data['qid']):
    if train_data['q'][i][0].lower() == "when":
        qtype2qid["when"].append(qid)
    elif train_data['q'][i][0].lower() == "what":
        qtype2qid["what"].append(qid)
    elif train_data['q'][i][0].lower() == "who":
        qtype2qid["who"].append(qid)
    elif train_data['q'][i][0].lower() == "where":
        qtype2qid["where"].append(qid)
    elif train_data['q'][i][0].lower() == "how":
        qtype2qid["how"].append(qid)

for i, qid in enumerate(val_data['qid']):
    if val_data['q'][i][0].lower() == "when":
        qtype2qid["when"].append(qid)
    elif val_data['q'][i][0].lower() == "what":
        qtype2qid["what"].append(qid)
    elif val_data['q'][i][0].lower() == "who":
        qtype2qid["who"].append(qid)
    elif val_data['q'][i][0].lower() == "where":
        qtype2qid["where"].append(qid)
    elif val_data['q'][i][0].lower() == "how":
        qtype2qid["how"].append(qid)


q_lens = [len(q) for q in train_data['q']]
cs_lens = [[len(c) for c in cs] for cs in train_data['cs']]
cs_lens = list(itertools.chain(*cs_lens))
y_lens = [len(y) for y in train_data['y']]
photo_lens = [len(train_shared['albums'][aid]['photo_ids']) for aid in train_shared['albums']]
all_photos_lens = [sum(len(train_shared['albums'][aid]['photo_ids']) for aid in aid_list) for aid_list in train_data['aid']]
pts_lens = [len(pt) for aid in train_shared['albums'] for pt in train_shared['albums'][aid]['photo_titles']] #number of photos/album
when_lens = [len(train_shared['albums'][aid]['when']) for aid in train_shared['albums']]
album_title_lens = [len(train_shared['albums'][aid]['title']) for aid in train_shared['albums']]
album_desc_lens = [len(train_shared['albums'][aid]['description']) for aid in train_shared['albums']]

Q_THRES = int(np.percentile(q_lens, 90)) # 10
Y_THRES = int(np.percentile(cs_lens, 90)) # 3, same as np.percentile(y_lens, 90)
PTS_THRES = int(np.percentile(pts_lens, 90)) # 8
WHEN_THRES = int(np.percentile(when_lens, 90)) # 4
PHOTOS_PER_ALBUM = int(np.percentile(photo_lens, 90)) # 10
ALBUM_TITLE_THRES = int(np.percentile(album_title_lens, 90)) # 8
ALBUM_DESC_THRES = int(np.percentile(album_desc_lens, 50)) # 11

def train_collate(batch):
    X, Y = zip(*batch)
    q_vec = []
    cs_vec = []
    desc_vec = []
    img_feats = []
    q_len = []
    cs0_len = []
    cs1_len = []
    cs2_len = []
    cs3_len = []
    desc_len = []
    img_len = []
    qid = []
    new_X = {}
    for x in X:
        q_len.append(x['q_len'])
        cs0_len.append(x['cs_lens'][0])
        cs1_len.append(x['cs_lens'][1])
        cs2_len.append(x['cs_lens'][2])
        cs3_len.append(x['cs_lens'][3])
        desc_len.append(x['desc_len'])
        img_len.append(x['img_len'])
        q_vec.append(x['q_vec'])
        #x['cs_vec'] expected shape: Y_THRES, 4, 100
        cs_vec.append(x['cs_vec'])
        desc_vec.append(x['desc_vec'])
        img_feats.append(x['img_feats'])
        qid.append(x['qid'])

    new_X['q_len'] = torch.LongTensor(q_len)
    new_X['cs0_lens'] = torch.LongTensor(cs0_len)
    new_X['cs1_lens'] = torch.LongTensor(cs1_len)
    new_X['cs2_lens'] = torch.LongTensor(cs2_len)
    new_X['cs3_lens'] = torch.LongTensor(cs3_len)
    new_X['desc_len'] = torch.LongTensor(desc_len)
    new_X['img_len'] = torch.LongTensor(img_len)
    new_X['q_vec'] = pad_sequence(q_vec, batch_first=False, padding_value=0)  # question 
    # expected shape: B, Y_THRES, 4, 100
    new_X['cs_vec'] = pad_sequence(cs_vec, batch_first=False, padding_value=0) # B, Y_THRES, 4, 100 -> 4 choices T, B, 4, 100
    new_X['desc_vec'] = pad_sequence(desc_vec, batch_first=False, padding_value=0)
    new_X['img_feats'] = pad_sequence(img_feats, batch_first=False, padding_value=0)

    return new_X, torch.LongTensor(Y), qid

class MemexQA_new(Dataset):
    def __init__(self, data, shared):
        self.data = data
        self.shared = shared

    def __len__(self):
        return len(self.data['q'])

    def __getitem__(self, idx):
        returned_item = {}
        # self.data keys -> ['q', 'idxs', 'cy', 'ccs', 'qid', 'y', 'aid', 'cq', 'yidx', 'cs']
        # self.shared keys -> ['albums', 'pid2feat', 'word2vec', 'charCounter', 'wordCounter']
        returned_item['qid'] = self.data['qid'][idx]
        q = self.data['q'][idx]
        # missing glove word-> [0] * 100 embedding
        q_vec = torch.FloatTensor(
            [self.shared['word2vec'][word.lower()] if word.lower() in self.shared['word2vec'] else [0] * 100 for word in
             q])
        q_vec = q_vec[:Q_THRES]
        returned_item['q_vec'] = q_vec  # largest possible shape: Q_THRES * 100
        returned_item['q_len'] = q_vec.shape[0] 
        # choices glove
        wrong_cs = self.data['cs'][idx]
        correct_c = self.data['y'][idx]
        yidx = self.data['yidx'][idx]
        if yidx == 0:
            cs = [correct_c] + wrong_cs
        elif yidx == 1:
            cs = wrong_cs[:1] + [correct_c] + wrong_cs[1:]
        elif yidx == 2:
            cs = wrong_cs[:2] + [correct_c] + wrong_cs[2:]
        else:  # yidx == 3
            cs = wrong_cs + [correct_c]

        cs_vec = [
            [self.shared['word2vec'][word.lower()] if word.lower() in self.shared['word2vec'] else [0] * 100 for word in
             c] for c in cs]
        cs_vec = [torch.FloatTensor(c[:Y_THRES]) for c in cs_vec]
        cs_lens = [min(Y_THRES, len(each)) for each in cs_vec] #YTHRES
        returned_item['cs_vec'] = pad_sequence(cs_vec, batch_first = False)  # [c1, c2, c3, c4]; largest possible shape: 4, Y_THRES, 100 ->  Y_THRES, 4, 100
        returned_item['cs_lens'] = cs_lens

        # aid: description + title , aid:when , aid : photo_titles + {later ->( photo_captions  + photo tags )}
        aid_list = self.data['aid'][idx]
        pts_descs = []  # photo-level
        pid_features = []  # img features from pre-trained CNN
        # for each album
        total_cat_len = ALBUM_TITLE_THRES + ALBUM_DESC_THRES + WHEN_THRES + PTS_THRES  # 8 + 11 + 4 + 8 = 31
        for aid in aid_list:
            # ptags = {for each self.album_itags[aid]}
            album = self.shared['albums'][aid]
            pts = album['photo_titles']  # all photo titles/aid

            # concatenate album description, album title and album when
            desc = album['description'][:ALBUM_DESC_THRES] + album['title'][:ALBUM_TITLE_THRES] + album['when'][
                                                                                                  :WHEN_THRES]

            for pt in pts:
                photo_info = desc + pt[:PTS_THRES]
                # largest possible shape: total_cat_len, 100
                photo_info_vec = [
                    self.shared['word2vec'][word.lower()] if word.lower() in self.shared['word2vec'] else [0] * 100 for
                    word in photo_info]
                if len(photo_info_vec) < total_cat_len:
                    photo_info_vec = photo_info_vec + [[0] * 100 for _ in range(
                        total_cat_len - len(photo_info_vec))]  # total_cat_len, 100
                pts_descs.append(photo_info_vec)  # total number of photos (varies), total_cat_len, 100

            for pid in self.shared['albums'][aid]['photo_ids']:
                # img_feats
                pid_features.append(self.shared['pid2feat'][pid])  # total number of photos (varies) * 2537

        desc_vec = torch.FloatTensor(pts_descs).view(-1,
                                                     total_cat_len * 100)  # total number of photos (varies), total_cat_len * 100
        returned_item['desc_vec'] = desc_vec
        returned_item['desc_len'] = desc_vec.shape[0]
        img_feats_vec = torch.FloatTensor(
            pid_features)  # total number of photos (varies), 2537; NEWLY CHANGED (no matter what, it will vary; keep consistent with desc_vec)
        returned_item['img_feats'] = img_feats_vec
        returned_item['img_len'] = img_feats_vec.shape[0]
        return returned_item, yidx

In [12]:
import torch.optim as optim
import torch.nn as nn
import torch
import pandas as pd
import time
import os
import numpy as np
import argparse
import csv

# # hyperparams
# EPOCHS = 10
# BATCH_SIZE = 64

# # optimizer-related
# MOMENTUM = 1e-2
# LR = 1e-2
# LR_STEPSIZE = 5
# LR_DECAY = 0.85
# WD = 5e-6

# model hyperparams
EPOCHS = 80
BATCH_SIZE = 64
HIDDEN_SIZE = 512
NUM_LAYERS = 3
KERNEL = 5
STRIDE = 1
DROPOUT = 0.4

# optimizer-related
#MOMENTUM = 1e-2
LR = 0.32
WD = 5e-5

# scheduler-related
# LR_STEPSIZE = 3
# LR_DECAY = 0.85
FACTOR = 0.95
PATIENCE = 3
THRESHOLD = 0.01 

def main(train_data_pth, train_shared_pth, val_data_pth, val_shared_pth, test_data_pth, test_shared_pth, isTrain):
    cuda = torch.cuda.is_available()
    num_workers = 8 if cuda else 0
    print("Loading data......")
    start = time.time()

    train_shared = pd.read_pickle(train_shared_pth)
    # random initial embedding matrix for new words
    nonglove_dict = {word: np.random.normal(0, 1, 100) for word in train_shared['wordCounter'] if word not in train_shared['word2vec']}
    train_shared['word2vec'].update(nonglove_dict)
    
    val_shared = pd.read_pickle(val_shared_pth)
    val_nonglove_dict = {word: np.random.normal(0, 1, 100) for word in val_shared['wordCounter'] if word not in val_shared['word2vec']}
    val_shared['word2vec'].update(val_nonglove_dict)

    test_shared = pd.read_pickle(test_shared_pth)
    test_nonglove_dict = {word: np.random.normal(0, 1, 100) for word in test_shared['wordCounter'] if word not in test_shared['word2vec']}
    test_shared['word2vec'].update(test_nonglove_dict)

    train_data = MemexQA_new(data=pd.read_pickle(train_data_pth), shared=train_shared)
    valid_data = MemexQA_new(data=pd.read_pickle(val_data_pth), shared=val_shared)
    test_data = MemexQA_new(data=pd.read_pickle(test_data_pth), shared=test_shared)

    train_loader_args = dict(shuffle=True, batch_size=BATCH_SIZE, num_workers=num_workers, pin_memory=True, collate_fn=train_collate) if cuda\
        else dict(shuffle=True, batch_size=BATCH_SIZE, collate_fn=train_collate)
    train_loader = torch.utils.data.DataLoader(train_data, **train_loader_args)

    valid_loader_args = dict(shuffle=False, batch_size=BATCH_SIZE, num_workers=num_workers, pin_memory=True, collate_fn=train_collate) if cuda\
        else dict(shuffle=False, batch_size=BATCH_SIZE, collate_fn=train_collate)
    valid_loader = torch.utils.data.DataLoader(valid_data, **valid_loader_args)

    test_loader_args = dict(shuffle=False, batch_size=BATCH_SIZE, num_workers=num_workers, pin_memory=True, collate_fn=train_collate) if cuda\
        else dict(shuffle=False, batch_size=BATCH_SIZE, collate_fn=train_collate)  # TODO: test_collate
    test_loader = torch.utils.data.DataLoader(test_data, **test_loader_args)
    print(f"Loading data took {time.time() - start:.1f} seconds")
    
    # initialize model
    device = torch.device("cuda" if cuda else "cpu")
    
    model = AttentionModel(q_cs_input_size=100, desc_input_size = 3100, img_input_size =2537, hidden_size = HIDDEN_SIZE, batch_size = BATCH_SIZE, num_layers = NUM_LAYERS,\
                           device = device, img_linear_size = 64,num_choices = 4, rnn_type = 'bilstm')
    snapshot_prefix = os.path.join(os.getcwd(), 'snapshot/')
    PATH = snapshot_prefix + "Model_"+str(30)
    print("Loading prev state......")
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.to(device)
    print(model)
    # setup optim and loss

    criterion = nn.CrossEntropyLoss()


    # setup optim and loss
    criterion = nn.CrossEntropyLoss()
#     optimizer= optim.Adam(model.parameters(), lr = LR, weight_decay= WD)
    optimizer = optim.Adadelta(model.parameters(), lr = LR, weight_decay= WD)
    #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEPSIZE, gamma=LR_DECAY)

    # training
    for e in range(EPOCHS):
        start = time.time()
        print("Starting training......")
        model.train()
        n_correct, n_total = 0, 0
        batch_count = 0
        t_loss = 0
        q_totals = [0] * 5
        q_corrects = [0] * 5
        for j, (batch_data, batch_labels,qids) in enumerate(train_loader):
            if j == len(train_loader) - 1:
                break
            batch_labels = batch_labels.long().to(device)
            optimizer.zero_grad()
            output = model(batch_data)
            loss = criterion(output, batch_labels)
            t_loss += loss.item()
            res = torch.argmax(output, 1)
            res = res.to(device)
            for i, qid in enumerate(qids):
                if qid in qtype2qid["when"]:
                    q_totals[0] += 1
                    if batch_labels[i] == res[i]:
                        q_corrects[0] += 1
                if qid in qtype2qid["what"]:
                    q_totals[1] += 1
                    if batch_labels[i] == res[i]:
                        q_corrects[1] += 1
                if qid in qtype2qid["who"]:
                    q_totals[2] += 1
                    if batch_labels[i] == res[i]:
                        q_corrects[2] += 1
                if qid in qtype2qid["where"]:
                    q_totals[3] += 1
                    if batch_labels[i] == res[i]:
                        q_corrects[3] += 1
                if qid in qtype2qid["how"]:
                    q_totals[4] += 1
                    if batch_labels[i] == res[i]:
                        q_corrects[4] += 1
            n_correct += (res == batch_labels).sum().item()
            n_total += batch_labels.shape[0]
            batch_count += 1
            loss.backward()
            optimizer.step()
            #if batch_count % 20 == 19:
            #    print(f"correct choice:{batch_labels[:3]} , predicted choice: {res[:3]}")
        train_acc = n_correct / n_total
        train_loss = t_loss / batch_count
        print(f"TRAIN ===> Epoch {e + 1}, took time {time.time()-start:.1f}s, train accu: {train_acc:.4f}, train loss: {train_loss:.6f}")
        train_acc_q_type = np.array(q_corrects) /np.array(q_totals)
        for i in range(5):  # 5 types of questions
            print("TRAIN ACC ", q_types[i], ": ", train_acc_q_type[i])
        #scheduler.step()
        
        # validate and save model 
        print("Start validation......")
        start = time.time()
        with torch.no_grad():
            model.eval()            
            valid_correct, loss, num_of_batches, num_of_val = 0, 0, 0, 0
            # validation for classification
            q_totals = [0] * 5
            q_corrects = [0] * 5
            for k, (vb_data, vb_label,qids) in enumerate(valid_loader):
                if k == len(valid_loader) - 1:
                    break
                vb_label = vb_label.long().to(device)
                v_output = model(vb_data)
                resm = torch.argmax(v_output, 1)
                resm = resm.to(device)
                for i, qid in enumerate(qids):
                    if qid in qtype2qid["when"]:
                        q_totals[0] += 1
                        if vb_label[i] == resm[i]:
                            q_corrects[0] += 1
                    if qid in qtype2qid["what"]:
                        q_totals[1] += 1
                        if vb_label[i] == resm[i]:
                            q_corrects[1] += 1
                    if qid in qtype2qid["who"]:
                        q_totals[2] += 1
                        if vb_label[i] == resm[i]:
                            q_corrects[2] += 1
                    if qid in qtype2qid["where"]:
                        q_totals[3] += 1
                        if vb_label[i] == resm[i]:
                            q_corrects[3] += 1
                    if qid in qtype2qid["how"]:
                        q_totals[4] += 1
                        if vb_label[i] == resm[i]:
                            q_corrects[4] += 1
                correct = (resm == vb_label).sum().item()
                valid_correct += correct
                loss += criterion(v_output, vb_label).item()
                num_of_batches += 1
                num_of_val += vb_label.shape[0]
                #if num_of_batches % 20 == 19:
                #    print(f"correct choice:{vb_label[:3]} , predicted choice: {resm[:3]}")
            val_loss = loss / num_of_batches
            val_accu = valid_correct / num_of_val
        print(f"VALID ===> Epoch {e}, took time {time.time()-start:.1f}s, valid accu: {val_accu:.4f}, valid loss: {val_loss:.6f}")
        train_acc_q_type = np.array(q_corrects) /np.array(q_totals)
        for i in range(5):  # 5 types of questions
            print("VALID ACC ", q_types[i], ": ", train_acc_q_type[i])
        if (e+1) % 10 == 0:
            snapshot_prefix = os.path.join(os.getcwd(), 'snapshot/')
            if not os.path.exists(snapshot_prefix):
                os.makedirs(snapshot_prefix)
            torch.save({
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        #'scheduler_state_dict' : scheduler.state_dict(),
            }, snapshot_prefix + "Model_"+str(i))

    # testing
    if not isTrain:
        print("Start testing......")
        start = time.time()
        model.eval()
        with torch.no_grad(), open('test_predictions.csv', 'w') as f:
            writer = csv.writer(f, delimiter=',')
            writer.writerow(["predict","actual"])
            for (tbatch_data, tbatch_data_labels) in test_loader:
                test_out = model(tbatch_data)
                predict = torch.argmax(test_out, axis=1)
                correct = (predict == tbatch_data_labels).sum().item()
                for (pred, actual) in zip(predict, correct):
                    writer.writerow([pred, actual])
        print(f"Testing took {time.time()-start:.1f}s")
    print("Finished")
                
                                                    
    
main('prepro_v1.1/train_data.p', 'prepro_v1.1/train_shared.p', 
     'prepro_v1.1/val_data.p', 'prepro_v1.1/val_shared.p', 
     'prepro_v1.1/test_data.p', 'prepro_v1.1/test_shared.p',
     isTrain = True)



Loading data......
Loading data took 11.1 seconds
Loading prev state......
AttentionModel(
  (cos): CosineSimilarity()
  (softmax1): Softmax(dim=1)
  (softmax2): Softmax(dim=2)
  (rnn_q): LSTM(100, 512, num_layers=3, bidirectional=True)
  (rnn_c): LSTM(100, 512, num_layers=3, bidirectional=True)
  (rnn_desc): LSTM(3100, 512, num_layers=3, bidirectional=True)
  (rnn_ps): LSTM(2537, 512, num_layers=3, bidirectional=True)
  (vis_text): Linear(in_features=1024, out_features=1024, bias=True)
  (tanh1): Tanh()
  (CH_linear): Linear(in_features=1024, out_features=1024, bias=True)
  (img_linear): Linear(in_features=512, out_features=64, bias=True)
  (tanh2): Tanh()
  (tanh3): Tanh()
  (last_softmax): Linear(in_features=5120, out_features=1, bias=True)
)
Starting training......
TRAIN ===> Epoch 1, took time 370.7s, train accu: 0.2832, train loss: 1.374751
TRAIN ACC  when :  0.3304042179261863
TRAIN ACC  what :  0.3231062575836367
TRAIN ACC  who :  0.2151748666271488
TRAIN ACC  where :  0.209121

KeyboardInterrupt: 

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch
import pandas as pd
import time
import os
import numpy as np
import argparse
import csv

def main(train_data_pth, train_shared_pth, val_data_pth, val_shared_pth, test_data_pth, test_shared_pth, isTrain):
    cuda = torch.cuda.is_available()
    num_workers = 8 if cuda else 0
    print("Loading data......")
    start = time.time()
    train_shared = pd.read_pickle(train_shared_pth)
    # random initial embedding matrix for new words
    nonglove_dict = {word: np.random.normal(0, 1, 100) for word in train_shared['wordCounter'] if word not in train_shared['word2vec']}
    train_shared['word2vec'].update(nonglove_dict)
    
    val_shared = pd.read_pickle(val_shared_pth)
    val_nonglove_dict = {word: np.random.normal(0, 1, 100) for word in val_shared['wordCounter'] if word not in val_shared['word2vec']}
    val_shared['word2vec'].update(val_nonglove_dict)

    test_shared = pd.read_pickle(test_shared_pth)
    test_nonglove_dict = {word: np.random.normal(0, 1, 100) for word in test_shared['wordCounter'] if word not in test_shared['word2vec']}
    test_shared['word2vec'].update(test_nonglove_dict)

    # train_data = MemexQA_new(data=pd.read_pickle('prepro_v1.1/train_data.p'), shared=train_shared, info=None)
    # valid_data = MemexQA_new(data=pd.read_pickle('prepro_v1.1/val_data.p'), shared=val_shared, info=None)
    # test_data = MemexQA_new(data=pd.read_pickle('prepro_v1.1/test_data.p'), shared=test_shared, info=None)
    train_data = MemexQA_new(data=pd.read_pickle(train_data_pth), shared=train_shared)
    valid_data = MemexQA_new(data=pd.read_pickle(val_data_pth), shared=val_shared)
    test_data = MemexQA_new(data=pd.read_pickle(test_data_pth), shared=test_shared)

    # random initial embedding matrix for new words
    # config.emb_mat = np.array([idx2vec_dict[idx] if idx2vec_dict.has_key(idx) 
    # else np.random.multivariate_normal(np.zeros(config.word_emb_size), np.eye(config.word_emb_size)) 
    # for idx in xrange(config.word_vocab_size)],dtype="float32") 

    train_loader_args = dict(shuffle=True, batch_size=BATCH_SIZE, num_workers=num_workers, pin_memory=True, collate_fn=train_collate) if cuda\
        else dict(shuffle=True, batch_size=BATCH_SIZE, collate_fn=train_collate)
    train_loader = torch.utils.data.DataLoader(train_data, **train_loader_args)

    valid_loader_args = dict(batch_size=BATCH_SIZE, num_workers=num_workers, pin_memory=True, collate_fn=train_collate) if cuda\
        else dict(shuffle=False, batch_size=BATCH_SIZE, collate_fn=train_collate)
    valid_loader = torch.utils.data.DataLoader(valid_data, **valid_loader_args)

    test_loader_args = dict(batch_size=BATCH_SIZE, num_workers=num_workers, pin_memory=True, collate_fn=train_collate) if cuda\
        else dict(shuffle=False, batch_size=BATCH_SIZE, collate_fn=train_collate)
    test_loader = torch.utils.data.DataLoader(test_data, **test_loader_args)
    print(f"Loading data took {time.time() - start:.1f} seconds")
    
    # initialize model
    device = torch.device("cuda" if cuda else "cpu")
    #model = SimpleLSTMModel(100, 64, hyperparams['batch_size'], 2, device)
    #q_cs_input_size, desc_input_size, img_input_size, hidden_size, batch_size, num_layers, device, q_linear_size, img_linear_size, multimodal_out, kernel, stride = 1, rnn_type = 'bilstm'
    #model = NewFusionModel(100, 3100,2537,128, 2, 2, device, 64, 64, 4, 3, 1)    
    # model = NewFusionModel(q_cs_input_size=100, desc_input_size=3100,img_input_size=2537,hidden_size=HIDDEN_SIZE, batch_size=BATCH_SIZE, num_layers=NUM_LAYERS, \
    #                        device=device, q_linear_size=64, img_linear_size=64, multimodal_out=4, kernel=KERNEL, stride=STRIDE)

# q_cs_input_size, desc_input_size, img_input_size, hidden_size, batch_size, num_layers, device, img_linear_size,num_choices = 4, rnn_type = 'bilstm'
    model = AttentionModel(q_cs_input_size=100, desc_input_size = 3100, img_input_size =2537, hidden_size = HIDDEN_SIZE, batch_size = BATCH_SIZE, num_layers = NUM_LAYERS,\
                           device = device, img_linear_size = 64,num_choices = 4, rnn_type = 'bilstm')
    snapshot_prefix = os.path.join(os.getcwd(), 'snapshot/')
    PATH = snapshot_prefix + "Model_"+str(30)
    print("Loading prev state......")
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])

    model.to(device)
    print(model)
    # setup optim and loss

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adadelta(model.parameters(), lr = LR, weight_decay= WD)
    #optimizer= optim.SGD(model.parameters(), momentum=hyperparams['momentum'], lr = hyperparams['lr'], weight_decay= hyperparams['weight_decay'])
    # optimizer= optim.Adam(model.parameters(), lr = LR, weight_decay= WD)
    #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEPSIZE, gamma=LR_DECAY)
#     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=FACTOR, patience=PATIENCE, threshold=THRESHOLD,verbose=True)


    # training
    print("Starting training......")
    for i in range(EPOCHS):
        start = time.time()
        model.train()
        n_correct,n_total = 0, 0
        batch_count = 0
        t_loss = 0
        for j, (batch_data, batch_labels) in enumerate(train_loader):
            batch_labels = batch_labels.long().to(device)
            if j == len(train_loader) - 1:
                break
            optimizer.zero_grad()
            output = model(batch_data)
            loss = criterion(output, batch_labels)
            t_loss += loss
            res = torch.argmax(output, 1)
            res = res.to(device)
            n_correct += (res == batch_labels).sum().item()
            n_total += batch_labels.shape[0]
            batch_count += 1
            loss.backward()
            optimizer.step()
            # if batch_count % 20 == 19:
            #     print(f"correct choice:{batch_labels[:3]} , predicted choice: {res[:3]}")
        train_acc = n_correct / n_total
        train_loss = t_loss / batch_count
#         print(f"There are {n_total} questions in training data.")
        print(f"TRAIN ===> Epoch {i + 1}, took time {time.time()-start:.1f}s, train accu: {train_acc:.4f}, train loss: {train_loss.item():.6f}")
        scheduler.step(train_loss)
        if (i+1) % 10== 0:
            snapshot_prefix = os.path.join(os.getcwd(), 'snapshot/')
            if not os.path.exists(snapshot_prefix):
                os.makedirs(snapshot_prefix)
            torch.save({
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict' : scheduler.state_dict(),
            }, snapshot_prefix + "Model_"+str(i))
        
        
        # validate and save model 
        #print("Start validation......")
        start = time.time()
        with torch.no_grad():
            model.eval()            
            valid_correct, loss, num_of_batches, num_of_val = 0, 0, 0, 0
            # validation for classification
            for k, (vb_data, vb_label) in enumerate(valid_loader):
                if k == len(valid_loader) - 1:
                    break
                vb_label = vb_label.long().to(device)
                v_output = model(vb_data)
                resm = torch.argmax(v_output, axis=1)
                resm = resm.to(device)
                correct = (resm == vb_label).sum().item()
                valid_correct += correct
                loss += criterion(v_output, vb_label).item()
                num_of_batches += 1
                num_of_val += vb_label.shape[0]
                # if num_of_batches % 20 == 19:
                #     print(f"correct choice:{vb_label[:3]} , predicted choice: {resm[:3]}")
            val_loss = loss / num_of_batches
            val_accu = valid_correct / num_of_val
        print(f"VALID ===> Epoch {i + 1}, took time {time.time()-start:.1f}s, valid accu: {val_accu:.4f}, valid loss: {val_loss:.6f}")
#         print(f"There are {num_of_val} questions in validation data.")
        
    # testing
    if not isTrain:
        print("Start testing......")
        start = time.time()
        model.eval()
        with torch.no_grad(), open('test_predictions.csv', 'w') as f:
            writer = csv.writer(f, delimiter=',')
            writer.writerow(["predict","actual"])
            test_correct = 0
            loss = 0
            num_of_batches = 0
            num_of_test = 0
            for (tbatch_data, tbatch_data_labels) in test_loader:
                if k == len(test_loader) - 1:
                    break
                tb_label = tbatch_data_labels.long().to(device)
                t_output = model(tbatch_data, is_train=False)
                tresm = torch.argmax(t_output, axis=1)
                tresm = tresm.to(device)
                tcorrect = (tresm == tb_label).sum().item()
                test_correct += tcorrect
                loss += criterion(t_output, tb_label).item()
                num_of_batches += 1
                num_of_test += tb_label.shape[0]
            for (pred, actual) in zip(tresm, tb_label):
                writer.writerow([pred, actual])
        print(f"TEST ===> test accu: {val_accu:.4f}, test loss: {val_loss:.6f}.")
        print(f"Testing took {time.time()-start:.1f}s")
    print("Finished")
                
                                                    
    

# model hyperparams
EPOCHS = 80
BATCH_SIZE = 64
HIDDEN_SIZE = 512
NUM_LAYERS = 3
KERNEL = 5
STRIDE = 1
DROPOUT = 0.4

# optimizer-related
#MOMENTUM = 1e-2
LR = 0.32
WD = 5e-5

# scheduler-related
# LR_STEPSIZE = 3
# LR_DECAY = 0.85
FACTOR = 0.95
PATIENCE = 3
THRESHOLD = 0.01 

In [11]:
main('prepro_v1.1/train_data.p',
        'prepro_v1.1/train_shared.p',
        'prepro_v1.1/val_data.p',
        'prepro_v1.1/val_shared.p',
        'prepro_v1.1/test_data.p',
        'prepro_v1.1/test_shared.p',
        isTrain = True)

Loading data......
Loading data took 11.4 seconds
Loading prev state......
AttentionModel(
  (cos): CosineSimilarity()
  (softmax1): Softmax(dim=1)
  (softmax2): Softmax(dim=2)
  (rnn_q): LSTM(100, 512, num_layers=3, bidirectional=True)
  (rnn_c): LSTM(100, 512, num_layers=3, bidirectional=True)
  (rnn_desc): LSTM(3100, 512, num_layers=3, bidirectional=True)
  (rnn_ps): LSTM(2537, 512, num_layers=3, bidirectional=True)
  (vis_text): Linear(in_features=1024, out_features=1024, bias=True)
  (tanh1): Tanh()
  (CH_linear): Linear(in_features=1024, out_features=1024, bias=True)
  (img_linear): Linear(in_features=512, out_features=64, bias=True)
  (tanh2): Tanh()
  (tanh3): Tanh()
  (last_softmax): Linear(in_features=5120, out_features=1, bias=True)
)
Starting training......
TRAIN ===> Epoch 1, took time 361.4s, train accu: 0.2832, train loss: 1.374803
VALID ===> Epoch 1, took time 21.5s, valid accu: 0.2933, valid loss: 1.373256
TRAIN ===> Epoch 2, took time 364.7s, train accu: 0.2829, train

KeyboardInterrupt: 