In [None]:
# Download the pre-processed MemexQA v1.1 archive, decompress it (gunzip turns
# the .tgz into a .tar) and unpack the prepro_v1.1/ pickles used by the cells below.
!wget https://memexqa.cs.cmu.edu/fvta_model_zoo/prepro_v1.1.tgz
!gunzip prepro_v1.1.tgz
!tar -xvf prepro_v1.1.tar

--2020-12-09 02:00:35--  https://memexqa.cs.cmu.edu/fvta_model_zoo/prepro_v1.1.tgz
Resolving memexqa.cs.cmu.edu (memexqa.cs.cmu.edu)... 128.2.220.9
Connecting to memexqa.cs.cmu.edu (memexqa.cs.cmu.edu)|128.2.220.9|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 212511000 (203M) [application/x-gzip]
Saving to: ‘prepro_v1.1.tgz’


2020-12-09 02:00:53 (11.4 MB/s) - ‘prepro_v1.1.tgz’ saved [212511000/212511000]

prepro_v1.1/
prepro_v1.1/test_data.p
prepro_v1.1/train_shared.p
prepro_v1.1/test_shared.p
prepro_v1.1/train_data.p
prepro_v1.1/val_data.p
prepro_v1.1/val_shared.p


In [None]:
# new_dataset_checked_by_hongyuan.py
import pandas as pd
import numpy as np
import torch
from torch.nn.utils.rnn import *
from torch.utils.data import Dataset
import itertools

# Pre-processed pickles that the statistics and thresholds below are derived from.
train_data = pd.read_pickle('prepro_v1.1/train_data.p')
train_shared = pd.read_pickle('prepro_v1.1/train_shared.p')
val_data = pd.read_pickle('prepro_v1.1/val_data.p')

# Question types are identified by the lower-cased first token of the question.
q_types = ["when", "what", "who", "where", "how"]
qtype2qid = {q_type: [] for q_type in q_types}

# Bucket every train question id, then every validation question id, by type.
# Questions whose first token is not one of q_types are not recorded.
for split in (train_data, val_data):
    for i, qid in enumerate(split['qid']):
        first_word = split['q'][i][0].lower()
        if first_word in qtype2qid:
            qtype2qid[first_word].append(qid)


def _length_percentile(lengths, q=90):
    """Return the q-th percentile of `lengths`, truncated to an int."""
    return int(np.percentile(lengths, q))


# Token / item counts over the training split.
_train_albums = train_shared['albums']
q_lens = [len(q) for q in train_data['q']]
cs_lens = [len(c) for cs in train_data['cs'] for c in cs]  # flattened over all wrong choices
y_lens = [len(y) for y in train_data['y']]
photo_lens = [len(_train_albums[aid]['photo_ids']) for aid in _train_albums]
all_photos_lens = [
    sum(len(_train_albums[aid]['photo_ids']) for aid in aid_list)
    for aid_list in train_data['aid']
]
pts_lens = [len(pt) for aid in _train_albums for pt in _train_albums[aid]['photo_titles']]
when_lens = [len(_train_albums[aid]['when']) for aid in _train_albums]
where_lens = [len(_train_albums[aid]['where']) for aid in _train_albums]
album_title_lens = [len(_train_albums[aid]['title']) for aid in _train_albums]
album_desc_lens = [len(_train_albums[aid]['description']) for aid in _train_albums]

# Truncation thresholds. The trailing numbers are the values noted by the
# original author for this dataset.
Q_THRES = _length_percentile(q_lens)  # 10
Y_THRES = _length_percentile(cs_lens)  # 3, same as np.percentile(y_lens, 90)
PTS_THRES = _length_percentile(pts_lens)  # 8
WHEN_THRES = _length_percentile(when_lens)  # 4
WHERE_THRES = _length_percentile(where_lens)  # 5
PHOTOS_PER_ALBUM = _length_percentile(photo_lens)  # 10
ALL_PHOTOS_THRES = max(all_photos_lens)  # 72
ALBUM_TITLE_THRES = _length_percentile(album_title_lens)  # 8
ALBUM_DESC_THRES = _length_percentile(album_desc_lens, 50)  # 11 (median, not the 90th percentile)

def train_collate(batch):
    """Collate `(sample_dict, label)` pairs into one padded, time-major batch.

    Args:
        batch: sequence of `(x, y)` pairs where `x` is a dict with the keys
            'q_len', 'cs_lens' (four entries), 'desc_len', 'img_len',
            'q_vec', 'cs_vec', 'desc_vec', 'img_feats', 'qid' and 'pids'.

    Returns:
        A 4-tuple `(new_X, labels, qids, pids)`:
        * `new_X` - dict of LongTensor lengths ('q_len', 'cs0_lens'..'cs3_lens',
          'desc_len', 'img_len') and zero-padded tensors ('q_vec', 'cs_vec',
          'desc_vec', 'img_feats') with the time axis first and batch second.
        * `labels` - LongTensor built from the `y` values.
        * `qids`, `pids` - plain lists, one entry per sample.
    """
    samples, labels = zip(*batch)

    def column(key):
        # Gather one field across the whole batch, preserving sample order.
        return [sample[key] for sample in samples]

    choice_lens = column('cs_lens')

    new_X = {'q_len': torch.LongTensor(column('q_len'))}
    for choice_idx in range(4):
        new_X['cs%d_lens' % choice_idx] = torch.LongTensor(
            [lens[choice_idx] for lens in choice_lens]
        )
    new_X['desc_len'] = torch.LongTensor(column('desc_len'))
    new_X['img_len'] = torch.LongTensor(column('img_len'))

    # pad_sequence with batch_first=False stacks to (T, B, ...), zero-filling short samples.
    for key in ('q_vec', 'cs_vec', 'desc_vec', 'img_feats'):
        new_X[key] = pad_sequence(column(key), batch_first=False, padding_value=0)

    return new_X, torch.LongTensor(labels), column('qid'), column('pids')

class MemexQA_new(Dataset):
    """Map-style dataset that turns one MemexQA question into model-ready tensors.

    Each item is `(features, yidx)` where `features` holds the embedded
    question, the four embedded answer choices (correct answer placed at
    position `yidx`), one text vector per photo title and one image feature
    vector per photo id of the albums attached to the question.
    """

    def __init__(self, data, shared):
        # data keys   -> ['q', 'idxs', 'cy', 'ccs', 'qid', 'y', 'aid', 'cq', 'yidx', 'cs']
        # shared keys -> ['albums', 'pid2feat', 'word2vec', 'charCounter', 'wordCounter']
        self.shared = shared
        self.data = data

    def __len__(self):
        """Number of questions in this split."""
        return len(self.data['q'])

    def _embed(self, words):
        """Look up each lower-cased word in word2vec; unknown words map to 100 zeros."""
        table = self.shared['word2vec']
        vectors = []
        for word in words:
            key = word.lower()
            vectors.append(table[key] if key in table else [0] * 100)
        return vectors

    def __getitem__(self, idx):
        """Build the feature dict and label for question number `idx`.

        Returns:
            `(sample, yidx)`; `sample` has the keys 'qid', 'q_vec', 'q_len',
            'cs_vec', 'cs_lens', 'pids', 'desc_vec', 'desc_len', 'img_feats'
            and 'img_len'.
        """
        sample = {'qid': self.data['qid'][idx]}

        # Question: embed, then keep at most Q_THRES tokens.
        question = torch.FloatTensor(self._embed(self.data['q'][idx]))[:Q_THRES]
        sample['q_vec'] = question  # at most (Q_THRES, 100)
        sample['q_len'] = question.shape[0]

        # Choices: put the correct answer at slot yidx among the wrong ones;
        # any yidx other than 0, 1 or 2 appends it at the end.
        wrong_cs = self.data['cs'][idx]
        correct_c = self.data['y'][idx]
        yidx = self.data['yidx'][idx]
        slot = yidx if yidx in (0, 1, 2) else len(wrong_cs)
        cs = list(wrong_cs)
        cs.insert(slot, correct_c)

        choice_tensors = [torch.FloatTensor(self._embed(choice)[:Y_THRES]) for choice in cs]
        # pad_sequence turns the per-choice (len, 100) tensors into (max_len, n_choices, 100).
        sample['cs_vec'] = pad_sequence(choice_tensors, batch_first=False)
        sample['cs_lens'] = [min(Y_THRES, len(choice)) for choice in choice_tensors]

        # Albums: every photo title is prefixed with its album's
        # description + title + when + where, then zero-padded to total_cat_len words.
        total_cat_len = ALBUM_TITLE_THRES + ALBUM_DESC_THRES + WHEN_THRES + PTS_THRES + WHERE_THRES
        pids = []
        pts_descs = []      # one (total_cat_len, 100) nested list per photo title
        pid_features = []   # one pre-computed image feature vector per photo id
        for aid in self.data['aid'][idx]:
            album = self.shared['albums'][aid]
            pids.extend(album['photo_ids'])
            album_words = (
                album['description'][:ALBUM_DESC_THRES]
                + album['title'][:ALBUM_TITLE_THRES]
                + album['when'][:WHEN_THRES]
                + album['where'][:WHERE_THRES]
            )
            for photo_title in album['photo_titles']:
                vectors = self._embed(album_words + photo_title[:PTS_THRES])
                # range() of a non-positive number is empty, so full-length rows are untouched.
                vectors += [[0] * 100 for _ in range(total_cat_len - len(vectors))]
                pts_descs.append(vectors)
            pid_features.extend(self.shared['pid2feat'][pid] for pid in album['photo_ids'])

        # Flatten each photo's (total_cat_len, 100) block into a single row.
        desc_vec = torch.FloatTensor(pts_descs).view(-1, total_cat_len * 100)
        sample['pids'] = pids
        sample['desc_vec'] = desc_vec
        sample['desc_len'] = desc_vec.shape[0]
        img_feats_vec = torch.FloatTensor(pid_features)
        sample['img_feats'] = img_feats_vec
        sample['img_len'] = img_feats_vec.shape[0]
        return sample, yidx

In [None]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import *
import pandas as pd

# crossEntropyLoss use ignore_index = 0


# dims fed into LSTM: seq_len, batch_size, input_size
# B, T, F
# q_vec -> B, T, 100 (glove embedding)
# cs_vec -> B, 4, T = Y_THRES, 100
# desc_vec(prev. pt) -> B, T = all_photo_titles_albums * DESC_THRESH, 100
# ps_vec -> B, T = num_of_albums * 3, 2537
class NewFusionModel(nn.Module):
    """Scores four answer choices from question, choice, photo-text and image streams.

    Each stream is encoded by its own LSTM (the four choices share one). The
    final hidden state of the question LSTM is projected and reshaped into a
    "dynamic parameter" matrix that multiplies the projected image hidden
    state. That fused vector is concatenated with one choice's hidden state
    and the photo-text hidden state, passed through a 1-D convolution and a
    linear layer to give one logit per choice.

    Args:
        q_cs_input_size: feature size of question / choice tokens.
        desc_input_size: feature size of one photo-text step.
        img_input_size: feature size of one image step.
        hidden_size: hidden size of every LSTM.
        batch_size: only used to size the dynamic parameter matrix
            (`num_directions * num_layers * batch_size * q_linear_size //
            img_linear_size` columns). The actual batch must supply at least
            that many projected question values or `forward` fails in the
            reshape.
        num_layers: number of stacked LSTM layers.
        device: device the packed inputs are moved to.
        q_linear_size: output size of the question projection.
        img_linear_size: output size of the image projection.
        multimodal_out: number of output channels of the convolution.
        kernel: convolution kernel size.
        stride: convolution stride.
        rnn_type: 'bilstm' for bidirectional LSTMs, anything else for
            unidirectional ones.
    """

    def __init__(self, q_cs_input_size, desc_input_size, img_input_size, hidden_size, batch_size,
                 num_layers, device, q_linear_size, img_linear_size, multimodal_out, kernel, stride=1, rnn_type='bilstm'):
        super(NewFusionModel, self).__init__()
        self.device = device
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.q_cs_input_size = q_cs_input_size
        self.desc_input_size = desc_input_size
        self.img_input_size = img_input_size
        self.q_linear_size = q_linear_size
        self.img_linear_size = img_linear_size
        self.multimodal_out = multimodal_out
        self.kernel = kernel
        self.stride = stride
        self.num_layers = num_layers

        bidirectional = rnn_type == 'bilstm'
        self.num_directions = 2 if bidirectional else 1

        def make_lstm(input_size):
            # All four encoders differ only in their input feature size.
            return nn.LSTM(input_size, hidden_size, self.num_layers,
                           batch_first=False, bidirectional=bidirectional)

        # Creation order is kept (q, c, desc, ps) so seeded initialisation is unchanged.
        self.rnn_q = make_lstm(q_cs_input_size)      # questions
        self.rnn_c = make_lstm(q_cs_input_size)      # choices (shared by all four)
        self.rnn_desc = make_lstm(desc_input_size)   # photo text
        self.rnn_ps = make_lstm(img_input_size)      # image features

        self.img_linear = nn.Linear(hidden_size, self.img_linear_size)
        self.q_linear = nn.Linear(hidden_size, self.q_linear_size)

        # The (layer, direction) axis of the hidden states is used as the conv channel axis.
        self.multimodal_cnn = nn.Conv1d(self.num_directions * self.num_layers, self.multimodal_out,
                                        self.kernel, self.stride)

        # Width of the question-conditioned matrix that multiplies the image projection.
        self.dynamic_parameter_out = (self.num_directions * self.num_layers * batch_size
                                      * q_linear_size // self.img_linear_size)
        # Conv input length = choice hidden + photo-text hidden + fused question/image.
        multimodal_cnn_in_size = 2 * hidden_size + self.dynamic_parameter_out
        multimodal_cnn_out_size = (multimodal_cnn_in_size - self.kernel) // self.stride + 1
        self.output = nn.Linear(self.multimodal_out * multimodal_cnn_out_size, 1)

    def _final_hidden(self, rnn, padded, lengths):
        """Pack a time-major padded batch with its true lengths and return h_n of `rnn`."""
        # pack_padded_sequence needs the lengths on the CPU even when the data is not.
        lengths = torch.as_tensor(lengths).cpu()
        packed = pack_padded_sequence(padded, lengths, batch_first=False,
                                      enforce_sorted=False).to(self.device)
        _, (hidden, _) = rnn(packed)
        return hidden  # (num_layers * num_directions, batch, hidden_size)

    def forward(self, X):
        """Return a `(batch, 4)` tensor of logits, one per answer choice.

        Args:
            X: dict with time-major padded tensors 'q_vec', 'cs_vec'
                (choice index on axis 2), 'desc_vec', 'img_feats' and the
                matching lengths 'q_len', 'cs0_lens'..'cs3_lens', 'desc_len'
                and, optionally, 'img_len'.

        When 'img_len' is present the image features are packed like every
        other stream, so zero padding added by batching no longer leaks into
        the image hidden state. Without it the padded tensor is fed as-is.
        """
        cs_vec = X['cs_vec']
        hidden_q = self._final_hidden(self.rnn_q, X['q_vec'], X['q_len'])
        hidden_cs = [
            self._final_hidden(self.rnn_c, cs_vec[:, :, k, :], X['cs%d_lens' % k])
            for k in range(4)
        ]
        hidden_pt = self._final_hidden(self.rnn_desc, X['desc_vec'], X['desc_len'])
        if 'img_len' in X:
            hidden_ps = self._final_hidden(self.rnn_ps, X['img_feats'], X['img_len'])
        else:
            _, (hidden_ps, _) = self.rnn_ps(X['img_feats'].to(self.device))

        candidate_weights = self.q_linear(hidden_q)  # (layers * directions, batch, q_linear_size)
        img_proj = self.img_linear(hidden_ps)        # (layers * directions, batch, img_linear_size)

        # Dynamic parameter layer: the first img_linear_size * dynamic_parameter_out
        # values of the flattened question projection become the weight matrix.
        # NOTE(review): the flattening runs across the batch axis, so the matrix
        # is shared by (and built from) several samples of the batch.
        n_weights = self.img_linear_size * self.dynamic_parameter_out
        dynamic_parameter_matrix = torch.flatten(candidate_weights)[:n_weights]
        dynamic_parameter_matrix = dynamic_parameter_matrix.reshape(self.img_linear_size,
                                                                    self.dynamic_parameter_out)
        q_img_fused = img_proj @ dynamic_parameter_matrix

        # Multimodal CNN: one pass per answer choice with shared weights.
        per_choice = []
        for hidden_c in hidden_cs:
            fused = torch.cat((q_img_fused, hidden_c, hidden_pt), dim=2).to(self.device)
            fused = fused.permute(1, 0, 2)  # (batch, layers * directions, features)
            conv = self.multimodal_cnn(fused)
            per_choice.append(torch.flatten(conv, start_dim=1).unsqueeze(1))
        classification_input = torch.cat(per_choice, dim=1)  # (batch, 4, out_ch * conv_len)
        logits = self.output(classification_input)
        return logits.squeeze(2)

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch
import pandas as pd
import time
import os
import numpy as np
import argparse
import csv

# Training schedule: number of passes over the training set and samples per batch.
EPOCHS, BATCH_SIZE = 10, 64

# Optimizer settings. LR and WD are the learning rate and weight decay given
# to Adam in main(); MOMENTUM is not passed to that optimizer, and
# LR_STEPSIZE / LR_DECAY only matter to the StepLR scheduler that is
# commented out in main().
LR = 0.01
WD = 0.000005
MOMENTUM = 0.01
LR_STEPSIZE = 5
LR_DECAY = 0.85

def _tally_by_question_type(qids, labels, predictions, type_qid_sets, totals, corrects):
    """Add one batch's per-question-type counts to `totals` / `corrects` in place.

    `type_qid_sets[k]` is the set of question ids of type k; a sample is
    counted for every type whose set contains its qid.
    """
    # One bulk conversion instead of a tensor comparison per sample and per type.
    for qid, label, prediction in zip(qids, labels.tolist(), predictions.tolist()):
        is_correct = label == prediction
        for type_idx, qid_set in enumerate(type_qid_sets):
            if qid in qid_set:
                totals[type_idx] += 1
                corrects[type_idx] += is_correct


def _per_type_accuracy(corrects, totals):
    """Return corrects / totals element-wise; types with no samples give NaN."""
    return [c / t if t else float('nan') for c, t in zip(corrects, totals)]


def main(train_data_pth, train_shared_pth, val_data_pth, val_shared_pth, test_data_pth, test_shared_pth, isTrain):
    """Train NewFusionModel, validate after every epoch and save one snapshot per epoch.

    Args:
        train_data_pth, val_data_pth, test_data_pth: pickled question data.
        train_shared_pth, val_shared_pth, test_shared_pth: pickled shared
            data (albums, word vectors, image features).
        isTrain: when falsy, the test split is additionally run after
            training and written to 'test_predictions.csv'. Training and
            validation always run.

    NOTE: the inputs are loaded with pd.read_pickle, which executes arbitrary
    code from the file - only pass trusted pickles.
    """
    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")
    print("Loading data......")
    start = time.time()

    def load_shared(path):
        # Words counted in the split but absent from word2vec get a random 100-d vector.
        shared = pd.read_pickle(path)
        word2vec = shared['word2vec']
        word2vec.update({word: np.random.normal(0, 1, 100)
                         for word in shared['wordCounter'] if word not in word2vec})
        return shared

    train_shared = load_shared(train_shared_pth)
    val_shared = load_shared(val_shared_pth)
    test_shared = load_shared(test_shared_pth)

    train_data = MemexQA_new(data=pd.read_pickle(train_data_pth), shared=train_shared)
    valid_data = MemexQA_new(data=pd.read_pickle(val_data_pth), shared=val_shared)
    test_data = MemexQA_new(data=pd.read_pickle(test_data_pth), shared=test_shared)

    loader_args = dict(batch_size=BATCH_SIZE, collate_fn=train_collate)
    if cuda:
        loader_args.update(num_workers=8, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **loader_args)
    valid_loader = torch.utils.data.DataLoader(valid_data, shuffle=False, **loader_args)
    test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, **loader_args)  # TODO: test_collate
    print(f"Loading data took {time.time() - start:.1f} seconds")

    # The photo-text feature size is whatever MemexQA_new produces: 100 per
    # word times the concatenated length, so derive it instead of hard-coding.
    desc_input_size = 100 * (ALBUM_TITLE_THRES + ALBUM_DESC_THRES + WHEN_THRES + PTS_THRES + WHERE_THRES)
    # q_cs_input_size, desc_input_size, img_input_size, hidden_size, batch_size,
    # num_layers, device, q_linear_size, img_linear_size, multimodal_out, kernel, stride
    model = NewFusionModel(100, desc_input_size, 2537, 128, 2, 2, device, 64, 64, 4, 3, 1)
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)

    # Sets make the per-sample question-type lookup O(1) instead of a list scan.
    type_qid_sets = [set(qtype2qid[q_type]) for q_type in q_types]
    n_types = len(q_types)

    snapshot_dir = os.path.join(os.getcwd(), 'snapshot')
    os.makedirs(snapshot_dir, exist_ok=True)

    for epoch in range(EPOCHS):
        # ---- training ----
        start = time.time()
        print("Starting training......")
        model.train()
        n_correct, n_total, batch_count, t_loss = 0, 0, 0, 0
        q_totals = [0] * n_types
        q_corrects = [0] * n_types
        for batch_data, batch_labels, qids, pids in train_loader:
            # Only full batches are used (the original dropped the last batch
            # unconditionally); a full final batch is now kept.
            if len(qids) < BATCH_SIZE:
                continue
            batch_labels = batch_labels.long().to(device)
            optimizer.zero_grad()
            output = model(batch_data)
            loss = criterion(output, batch_labels)
            loss.backward()
            optimizer.step()

            t_loss += loss.item()
            res = torch.argmax(output, 1)
            _tally_by_question_type(qids, batch_labels, res, type_qid_sets, q_totals, q_corrects)
            n_correct += (res == batch_labels).sum().item()
            n_total += batch_labels.shape[0]
            batch_count += 1
        train_acc = n_correct / n_total if n_total else float('nan')
        train_loss = t_loss / batch_count if batch_count else float('nan')
        print(f"TRAIN ===> Epoch {epoch}, took time {time.time()-start:.1f}s, train accu: {train_acc:.4f}, train loss: {train_loss:.6f}")
        for q_type, acc in zip(q_types, _per_type_accuracy(q_corrects, q_totals)):
            print("TRAIN ACC ", q_type, ": ", acc)

        # ---- validation ----
        print("Start validation......")
        start = time.time()
        model.eval()
        valid_correct, val_loss_sum, num_of_batches, num_of_val = 0, 0, 0, 0
        q_totals = [0] * n_types
        q_corrects = [0] * n_types
        with torch.no_grad():
            for vb_data, vb_label, qids, pids in valid_loader:
                if len(qids) < BATCH_SIZE:
                    continue  # same full-batch rule as training
                vb_label = vb_label.long().to(device)
                v_output = model(vb_data)
                resm = torch.argmax(v_output, 1)
                _tally_by_question_type(qids, vb_label, resm, type_qid_sets, q_totals, q_corrects)
                valid_correct += (resm == vb_label).sum().item()
                val_loss_sum += criterion(v_output, vb_label).item()
                num_of_batches += 1
                num_of_val += vb_label.shape[0]
        val_loss = val_loss_sum / num_of_batches if num_of_batches else float('nan')
        val_accu = valid_correct / num_of_val if num_of_val else float('nan')
        print(f"VALID ===> Epoch {epoch}, took time {time.time()-start:.1f}s, valid accu: {val_accu:.4f}, valid loss: {val_loss:.6f}")
        for q_type, acc in zip(q_types, _per_type_accuracy(q_corrects, q_totals)):
            print("VALID ACC ", q_type, ": ", acc)

        # One snapshot per epoch; `epoch` is no longer clobbered by inner loops,
        # so earlier snapshots are not overwritten.
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, os.path.join(snapshot_dir, "Model_" + str(epoch)))

    # ---- testing ----
    if not isTrain:
        print("Start testing......")
        start = time.time()
        model.eval()
        with torch.no_grad(), open('test_predictions.csv', 'w', newline='') as f:
            writer = csv.writer(f, delimiter=',')
            writer.writerow(["predict", "actual"])
            # NOTE(review): every test batch is scored, including a short final
            # one - confirm the model accepts a batch of a single sample.
            for tbatch_data, tbatch_labels, qids, pids in test_loader:
                test_out = model(tbatch_data)
                predict = torch.argmax(test_out, 1)
                for pred, actual in zip(predict.tolist(), tbatch_labels.tolist()):
                    writer.writerow([pred, actual])
        print(f"Testing took {time.time()-start:.1f}s")
    print("Finished")
                
                                                    
    
# Run the full pipeline on the unpacked prepro_v1.1 pickles.
main(
    train_data_pth='prepro_v1.1/train_data.p',
    train_shared_pth='prepro_v1.1/train_shared.p',
    val_data_pth='prepro_v1.1/val_data.p',
    val_shared_pth='prepro_v1.1/val_shared.p',
    test_data_pth='prepro_v1.1/test_data.p',
    test_shared_pth='prepro_v1.1/test_shared.p',
    isTrain=True,
)

Loading data......
Loading data took 13.8 seconds
Starting training......
TRAIN ===> Epoch 63, took time 82.9s, train accu: 0.4071, train loss: 1.249992
TRAIN ACC  when :  0.3073192239858907
TRAIN ACC  what :  0.4613916363005379
TRAIN ACC  who :  0.23581560283687944
TRAIN ACC  where :  0.24806576402321084
TRAIN ACC  how :  0.7062464828362408
Start validation......
VALID ===> Epoch 63, took time 21.0s, valid accu: 0.4422, valid loss: 1.213766
VALID ACC  when :  0.36212624584717606
VALID ACC  what :  0.5031667839549613
VALID ACC  who :  0.2887700534759358
VALID ACC  where :  0.26582278481012656
VALID ACC  how :  0.7058823529411765
Starting training......
TRAIN ===> Epoch 63, took time 82.8s, train accu: 0.4639, train loss: 1.166584
TRAIN ACC  when :  0.3237885462555066
TRAIN ACC  what :  0.5576756287944492
TRAIN ACC  who :  0.2829297105729474
TRAIN ACC  where :  0.28267182962245885
TRAIN ACC  how :  0.7220969560315671
Start validation......
VALID ===> Epoch 63, took time 21.1s, valid acc

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch
import pandas as pd
import time
import os
import numpy as np
import argparse
import csv

# Training schedule: number of passes over the training set and samples per batch.
EPOCHS, BATCH_SIZE = 10, 64

# Optimizer settings. LR and WD go to Adam in main(); LR_STEPSIZE and LR_DECAY
# are the step_size and gamma of its StepLR scheduler. MOMENTUM is only
# referenced by the commented-out SGD optimizer.
LR = 0.01
WD = 0.000005
MOMENTUM = 0.01
LR_STEPSIZE = 3
LR_DECAY = 0.85



def main(train_data_pth, train_shared_pth, val_data_pth, val_shared_pth, test_data_pth, test_shared_pth, isTrain):
    """Train NewFusionModel with Adam + StepLR, validating and snapshotting every epoch.

    Args:
        train_data_pth, val_data_pth, test_data_pth: pickled question data.
        train_shared_pth, val_shared_pth, test_shared_pth: pickled shared
            data (albums, word vectors, image features).
        isTrain: when falsy, the test split is additionally run after
            training and written to 'test_predictions.csv'. Training and
            validation always run.

    NOTE: the inputs are loaded with pd.read_pickle, which executes arbitrary
    code from the file - only pass trusted pickles.
    """
    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")
    print("Loading data......")
    start = time.time()

    def load_shared(path):
        # Words counted in the split but absent from word2vec get a random 100-d vector.
        shared = pd.read_pickle(path)
        word2vec = shared['word2vec']
        word2vec.update({word: np.random.normal(0, 1, 100)
                         for word in shared['wordCounter'] if word not in word2vec})
        return shared

    train_shared = load_shared(train_shared_pth)
    val_shared = load_shared(val_shared_pth)
    test_shared = load_shared(test_shared_pth)

    train_data = MemexQA_new(data=pd.read_pickle(train_data_pth), shared=train_shared)
    valid_data = MemexQA_new(data=pd.read_pickle(val_data_pth), shared=val_shared)
    test_data = MemexQA_new(data=pd.read_pickle(test_data_pth), shared=test_shared)

    loader_args = dict(batch_size=BATCH_SIZE, collate_fn=train_collate)
    if cuda:
        loader_args.update(num_workers=8, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **loader_args)
    valid_loader = torch.utils.data.DataLoader(valid_data, shuffle=False, **loader_args)
    test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, **loader_args)
    print(f"Loading data took {time.time() - start:.1f} seconds")

    # The photo-text feature size must equal what MemexQA_new emits (100 per
    # word times the concatenated length); the previous literal 3100 disagreed
    # with the 3600 used elsewhere in this notebook, so it is derived here.
    desc_input_size = 100 * (ALBUM_TITLE_THRES + ALBUM_DESC_THRES + WHEN_THRES + PTS_THRES + WHERE_THRES)
    # q_cs_input_size, desc_input_size, img_input_size, hidden_size, batch_size,
    # num_layers, device, q_linear_size, img_linear_size, multimodal_out, kernel, stride
    model = NewFusionModel(100, desc_input_size, 2537, 128, 2, 2, device, 64, 64, 4, 3, 1)
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WD)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEPSIZE, gamma=LR_DECAY)

    snapshot_dir = os.path.join(os.getcwd(), 'snapshot')
    os.makedirs(snapshot_dir, exist_ok=True)

    print("Starting training......")
    for epoch in range(EPOCHS):
        # ---- training ----
        start = time.time()
        model.train()
        n_correct, n_total, batch_count, t_loss = 0, 0, 0, 0
        # train_collate yields (features, labels, qids, pids); the ids are unused here.
        for batch_data, batch_labels, _qids, _pids in train_loader:
            # Only full batches are used (the original dropped the last batch
            # unconditionally); a full final batch is now kept.
            if batch_labels.shape[0] < BATCH_SIZE:
                continue
            batch_labels = batch_labels.long().to(device)
            optimizer.zero_grad()
            output = model(batch_data)
            loss = criterion(output, batch_labels)
            loss.backward()
            optimizer.step()

            t_loss += loss.item()
            res = torch.argmax(output, 1)
            n_correct += (res == batch_labels).sum().item()
            n_total += batch_labels.shape[0]
            batch_count += 1
            if batch_count % 20 == 19:
                print(f"correct choice:{batch_labels[:3]} , predicted choice: {res[:3]}")
        train_acc = n_correct / n_total if n_total else float('nan')
        train_loss = t_loss / batch_count if batch_count else float('nan')
        print(f"TRAIN ===> Epoch {epoch}, took time {time.time()-start:.1f}s, train accu: {train_acc:.4f}, train loss: {train_loss:.6f}")
        scheduler.step()

        # ---- validation ----
        print("Start validation......")
        start = time.time()
        model.eval()
        valid_correct, val_loss_sum, num_of_batches, num_of_val = 0, 0, 0, 0
        with torch.no_grad():
            # NOTE(review): every validation batch is scored, including a short
            # final one - confirm the model accepts a batch of a single sample.
            for vb_data, vb_label, _qids, _pids in valid_loader:
                vb_label = vb_label.long().to(device)
                v_output = model(vb_data)
                resm = torch.argmax(v_output, 1)
                valid_correct += (resm == vb_label).sum().item()
                val_loss_sum += criterion(v_output, vb_label).item()
                num_of_batches += 1
                num_of_val += vb_label.shape[0]
                if num_of_batches % 20 == 19:
                    print(f"correct choice:{vb_label[:3]} , predicted choice: {resm[:3]}")
        val_loss = val_loss_sum / num_of_batches if num_of_batches else float('nan')
        val_accu = valid_correct / num_of_val if num_of_val else float('nan')
        print(f"VALID ===> Epoch {epoch}, took time {time.time()-start:.1f}s, valid accu: {val_accu:.4f}, valid loss: {val_loss:.6f}")

        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }, os.path.join(snapshot_dir, "Model_" + str(epoch)))

    # ---- testing ----
    if not isTrain:
        print("Start testing......")
        start = time.time()
        model.eval()
        with torch.no_grad(), open('test_predictions.csv', 'w', newline='') as f:
            writer = csv.writer(f, delimiter=',')
            writer.writerow(["predict", "actual"])
            for tbatch_data, tbatch_labels, _qids, _pids in test_loader:
                test_out = model(tbatch_data)
                predict = torch.argmax(test_out, 1)
                # One row per sample: predicted choice index next to the true one.
                for pred, actual in zip(predict.tolist(), tbatch_labels.tolist()):
                    writer.writerow([pred, actual])
        print(f"Testing took {time.time()-start:.1f}s")
    print("Finished")
                
                                                    
    
if __name__ == '__main__':
    # Run the full pipeline on the unpacked prepro_v1.1 pickles.
    main(
        train_data_pth='prepro_v1.1/train_data.p',
        train_shared_pth='prepro_v1.1/train_shared.p',
        val_data_pth='prepro_v1.1/val_data.p',
        val_shared_pth='prepro_v1.1/val_shared.p',
        test_data_pth='prepro_v1.1/test_data.p',
        test_shared_pth='prepro_v1.1/test_shared.p',
        isTrain=True,
    )
    # parser = argparse.ArgumentParser(description='Get the train-val-test dataset files')
    
    # parser.add_argument("-td" , "--train_data_pth", help="Enter train data path", type=str)
    # parser.add_argument("-tds", "--train_shared_pth", help="Enter train_shared data path", type=str)
    # parser.add_argument("-vd", "--val_data_pth", help="Enter val data path", type=str)
    # parser.add_argument("-vds", "--val_shared_pth", help="Enter val_shared data path", type=str)
    # parser.add_argument("-test", "--test_data_pth", help="Enter test data path", type=str)
    # parser.add_argument("-test_shared", "--test_shared_pth", help="Enter test_shared data path", type=str)
    # # parser.add_argument("-album", "--album_data_pth", help="Enter album_json data path", type=str)
    # parser.add_argument("isTrain", help="Set True if model is training", type=bool)
    
    # args = parser.parse_args()
    
    # main(args.train_data_pth,
    #     args.train_shared_pth,
    #     args.val_data_pth,
    #     args.val_shared_pth,
    #     args.test_data_pth,
    #     args.test_shared_pth,
    #     isTrain = args.isTrain)
