<a href="https://colab.research.google.com/github/vrushaligirkar/NLP-Project-QA-system/blob/master/DMN_bAbI_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
#REF: https://github.com/jhyuklee/dmn-pytorch
import os
import numpy as np
import sys
import string
import pprint
import copy
import pickle
import math
import nltk
nltk.download('punkt')


import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from datetime import datetime

!pip install tensorboardX
from tensorboardX import SummaryWriter

from google.colab import drive

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [5]:
# Mount Google Drive at /content/gdrive; the data and checkpoint paths in
# Config/Args below are all rooted under this mount point (Colab only).
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
class Dataset(object):
    """Turns the bAbI text files into index sequences and serves padded batches.

    Construction runs the whole pipeline: count the corpus words
    (``build_word_dict``), keep the ones that have a pretrained vector
    (``get_pretrained_word``) and convert every file into
    ``[story, question, answer, supporting_facts]`` samples (``process_data``).
    Samples are stored in ``self.dataset`` under keys such as ``'1_train'``.

    Modes are addressed with the short codes ``'tr'``, ``'va'`` and ``'te'``.
    """

    # mode code -> (suffix of the key in self.dataset, attribute holding the batch pointer)
    _MODES = {
        'tr': ('_train', 'train_ptr'),
        'va': ('_valid', 'valid_ptr'),
        'te': ('_test', 'test_ptr'),
    }

    def __init__(self, config):
        self.config = config
        self.init_settings()
        self.init_dict()
        self.build_word_dict(self.config.data_dir)
        self.get_pretrained_word(self.config.word2vec_path)
        self.process_data(self.config.data_dir)

    def init_settings(self):
        """Reset the sample store and the three batch pointers."""
        self.dataset = {}
        self.train_ptr = 0
        self.valid_ptr = 0
        self.test_ptr = 0

    def init_dict(self):
        """Create empty vocabularies; index 0 is reserved for the PAD token."""
        self.PAD = 'PAD'
        self.word2idx = {}
        self.idx2word = {}
        self.idx2vec = []  # pretrained
        self.word2idx[self.PAD] = 0
        self.idx2word[0] = self.PAD
        self.init_word_dict = {}  # word -> (first-seen index, count)

    def update_word_dict(self, key):
        """Register ``key`` in word2idx/idx2word if it is not known yet."""
        if key not in self.word2idx:
            self.word2idx[key] = len(self.word2idx)
            self.idx2word[len(self.idx2word)] = key

    def map_dict(self, key_list, dictionary):
        """Return ``[dictionary[k] for k in key_list]``; every key must exist."""
        output = []
        for key in key_list:
            assert key in dictionary, 'unknown key: %r' % (key,)
            output.append(dictionary[key])
        return output

    def _tokenize(self, text):
        """Tokenize with nltk; lowercase when the 6B (uncased) vectors are used."""
        tokens = nltk.word_tokenize(text)
        if self.config.word2vec_type == 6:
            tokens = [w.lower() for w in tokens]
        return tokens

    def _split_answer(self, answer):
        """Split a comma-separated answer field into its (optionally lowercased) words."""
        # str.split already yields [answer] when no comma is present.
        words = answer.split(',')
        if self.config.word2vec_type == 6:
            words = [w.lower() for w in words]
        return words

    def _count_words(self, words):
        """Update init_word_dict with one occurrence of each word in ``words``."""
        for word in words:
            first_idx, count = self.init_word_dict.get(
                    word, (len(self.init_word_dict), 0))
            self.init_word_dict[word] = (first_idx, count + 1)

    @staticmethod
    def _keep_max(d, k, v):
        """Store ``v`` under ``d[k]`` unless a larger value is already there."""
        if k not in d or v > d[k]:
            d[k] = v

    def build_word_dict(self, dir):
        """Walk ``dir`` and count every story, question and answer word."""
        print('### building word dict %s' % dir)
        for subdir, _, files in os.walk(dir):
            print('num of files:', len(files))
            for file in sorted(files):
                with open(os.path.join(subdir, file)) as f:
                    for line in f:
                        # rstrip instead of [:-1]: the last line of a file may
                        # have no trailing newline to drop.
                        line = line.rstrip('\n')
                        if not line.strip():
                            continue

                        if '\t' in line: # question
                            question, answer, _sup = line.split('\t')
                            question = ' '.join(question.split(' ')[1:])
                            self._count_words(self._tokenize(question))
                            self._count_words(self._split_answer(answer))
                        else: # story
                            story_line = ' '.join(line.split(' ')[1:])
                            self._count_words(self._tokenize(story_line))

        print('init dict size', len(self.init_word_dict))

    def get_pretrained_word(self, path):
        """Load pretrained vectors from ``path`` for the words seen in the corpus.

        Words of ``init_word_dict`` that have a vector are added to
        word2idx/idx2word and their vector is appended to ``idx2vec`` (row 0 is
        the all-zero PAD vector). Words without a vector are only counted.
        """
        print('\n### loading pretrained %s' % path)
        word2vec = {}
        with open(path, 'r', encoding='utf-8', errors='ignore') as f:
            for line in f:
                parts = line.split()
                # Only corpus words are ever looked up below, so skip the rest
                # instead of holding the entire embedding file in memory.
                if not parts or parts[0] not in self.init_word_dict:
                    continue
                try:
                    word2vec[parts[0]] = [float(x) for x in parts[1:]]
                except ValueError as e:
                    # best effort: report the malformed line and keep going
                    print(e)

        unk_cnt = 0
        self.idx2vec.append([0.0] * self.config.word_embed_dim) # PAD

        for word, (word_idx, word_cnt) in self.init_word_dict.items():
            if word != 'UNK' and word != 'PAD':
                assert word_cnt > 0
                if word in word2vec:
                    self.update_word_dict(word)
                    self.idx2vec.append(word2vec[word])
                else:
                    unk_cnt += 1
        print('len(word2idx):', len(self.word2idx))
        # Sanity print on a sample word; guarded so a corpus or embedding file
        # without it does not abort preprocessing.
        if 'apple' in word2vec and 'apple' in self.word2idx:
            print('len(word2vec):', len(word2vec['apple'][:5]))
            print('apple:', self.word2idx['apple'], word2vec['apple'][:5])
            print('apple:', self.idx2vec[self.word2idx['apple']][:5])
        print('pretrained vectors', np.asarray(self.idx2vec).shape, 'unk', unk_cnt)
        print('dictionary change', len(self.init_word_dict),
                'to', len(self.word2idx), len(self.idx2word))

    def process_data(self, dir):
        """Convert every file under ``dir`` into index-encoded samples.

        The task number and split name are parsed from the file name
        (``qa<N>_..._<split>.txt``). Each question line yields one sample
        ``[story_so_far, question, answer, supporting_facts]``. Per-task
        maxima are recorded in ``config.max_sentnum/max_slen/max_qlen`` and
        ``config.word_vocab_size`` is set to the vocabulary size.
        """
        print('\n### processing %s' % dir)
        # Defaults keep the summary prints valid when no file is found.
        total_data = []
        max_sentnum = max_slen = max_qlen = 0
        period_idx = self.word2idx.get('.')

        for subdir, _, files in os.walk(dir):
            for file in sorted(files):
                with open(os.path.join(subdir, file)) as f:
                    max_sentnum = max_slen = max_qlen = 0
                    name_parts = file.split('_')
                    qa_num = name_parts[0][2:]
                    set_type = name_parts[-1][:-4]
                    story_list = []
                    sf_cnt = 1
                    si2sf = {}  # line number in the file -> sentence number in the story
                    total_data = []

                    for line in f:
                        line = line.rstrip('\n')
                        if not line.strip():
                            continue
                        story_idx = int(line.split(' ')[0])
                        if story_idx == 1:  # a new story starts
                            story_list = []
                            sf_cnt = 1
                            si2sf = {}

                        if '\t' in line: # question
                            question, answer, sup_fact = line.split('\t')
                            question = ' '.join(question.split(' ')[1:])
                            q_split = self.map_dict(
                                    self._tokenize(question), self.word2idx)
                            answer = self.map_dict(
                                    self._split_answer(answer), self.word2idx)
                            sup_fact = [si2sf[int(sf)] for sf in sup_fact.split()]

                            # sentences are counted by their closing period
                            sentnum = story_list.count(period_idx)
                            max_sentnum = max(max_sentnum, sentnum)
                            max_slen = max(max_slen, len(story_list))
                            max_qlen = max(max_qlen, len(q_split))

                            total_data.append(
                                    [story_list[:], q_split, answer, sup_fact])

                        else: # story
                            story_line = ' '.join(line.split(' ')[1:])
                            story_list += self.map_dict(
                                    self._tokenize(story_line), self.word2idx)
                            si2sf[story_idx] = sf_cnt
                            sf_cnt += 1

                    self.dataset[str(qa_num) + '_' + set_type] = total_data
                    self._keep_max(self.config.max_sentnum, int(qa_num), max_sentnum)
                    self._keep_max(self.config.max_slen, int(qa_num), max_slen)
                    self._keep_max(self.config.max_qlen, int(qa_num), max_qlen)
                    self.config.word_vocab_size = len(self.word2idx)

        # These report the last processed file only.
        print('data size', len(total_data))
        print('max sentnum', max_sentnum)
        print('max slen', max_slen)
        print('max qlen', max_qlen, end='\n\n')

    def pad_sent_word(self, sentword, maxlen):
        """Right-pad ``sentword`` in place with PAD up to ``maxlen``.

        A list that is already ``maxlen`` or longer is left untouched (the
        former ``!=`` loop never terminated in that case).
        """
        missing = maxlen - len(sentword)
        if missing > 0:
            sentword.extend([self.word2idx[self.PAD]] * missing)

    def pad_data(self, dataset, set_num):
        """Pad story and question of every sample in place; returns ``dataset``."""
        for data in dataset:
            story, question, _, _ = data
            self.pad_sent_word(story, self.config.max_slen[set_num])
            self.pad_sent_word(question, self.config.max_qlen[set_num])

        return dataset

    def get_next_batch(self, mode='tr', set_num=1, batch_size=None):
        """Return the next padded batch and advance the pointer of ``mode``.

        Returns a tuple ``(stories, questions, answers, sup_facts, s_lengths,
        q_lengths, e_lengths)`` of Python lists. The last batch of a pass may
        be smaller; afterwards the pointer wraps around to 0.

        Raises:
            ValueError: if ``mode`` is not one of 'tr', 'va', 'te'.
        """
        if batch_size is None:
            batch_size = self.config.batch_size
        if mode not in self._MODES:
            raise ValueError("mode must be 'tr', 'va' or 'te', got %r" % (mode,))

        suffix, ptr_attr = self._MODES[mode]
        ptr = getattr(self, ptr_attr)
        data = self.dataset[str(set_num) + suffix]

        batch_size = min(batch_size, len(data) - ptr)
        # deepcopy: padding mutates the lists and must not touch the stored samples
        padded_data = self.pad_data(copy.deepcopy(data[ptr:ptr+batch_size]), set_num)
        stories = [d[0] for d in padded_data]
        questions = [d[1] for d in padded_data]
        answers = [d[2] for d in padded_data]
        # Answers are padded with -100 only when their lengths differ within
        # the batch. Comparing lengths directly avoids np.array() on a ragged
        # list, which raises on recent NumPy versions.
        if len({len(answer) for answer in answers}) != 1:
            for answer in answers:
                answer.extend([-100] * (self.config.max_alen - len(answer)))
        sup_facts = [d[3] for d in padded_data]
        for sup_fact in sup_facts:
            # pad value is max_sentnum + 1
            sup_fact.extend([self.config.max_sentnum[set_num] + 1]
                    * (self.config.max_episode - len(sup_fact)))
        # 1-based position of every sentence-closing period in the story
        s_lengths = [[idx+1 for idx, val in enumerate(d[0])
            if val == self.word2idx['.']] for d in padded_data]
        e_lengths = []
        for s_len in s_lengths:
            e_lengths.append(len(s_len))
            s_len.extend([0] * (self.config.max_sentnum[set_num] - len(s_len)))
        # 1-based position of the first question mark in the question
        q_lengths = [[idx+1 for idx, val in enumerate(d[1])
            if val == self.word2idx['?']][0] for d in padded_data]

        setattr(self, ptr_attr, (ptr + batch_size) % len(data))

        return (stories, questions, answers, sup_facts,
                s_lengths, q_lengths, e_lengths)

    def get_batch_ptr(self, mode):
        """Return the batch pointer of ``mode`` (None for an unknown mode)."""
        if mode in self._MODES:
            return getattr(self, self._MODES[mode][1])
        return None

    def get_dataset_len(self, mode, set_num):
        """Return the number of samples of task ``set_num`` in split ``mode``."""
        if mode in self._MODES:
            return len(self.dataset[str(set_num) + self._MODES[mode][0]])
        return None

    def init_batch_ptr(self, mode=None):
        """Reset the batch pointer of ``mode``, or all three when mode is None."""
        if mode is None:
            self.train_ptr = 0
            self.valid_ptr = 0
            self.test_ptr = 0
        elif mode in self._MODES:
            setattr(self, self._MODES[mode][1], 0)

    def shuffle_data(self, mode='tr', set_num=1, seed=None):
        """Shuffle the samples of task ``set_num`` in split ``mode`` in place.

        ``seed`` (if given) reseeds NumPy's global RNG first. 'va' shuffles the
        validation split (it used to shuffle the training split by mistake).
        """
        if seed is not None:
            np.random.seed(seed)
        if mode in self._MODES:
            np.random.shuffle(self.dataset[str(set_num) + self._MODES[mode][0]])

    def decode_data(self, s, q, a, sf, l):
        """Print one sample back as words; ``l`` holds the sentence end positions."""
        print(l)
        print('story:',
                ' '.join(self.map_dict(s[:l[-1]], self.idx2word)))
        print('question:', ' '.join(self.map_dict(q, self.idx2word)))
        print('answer:', self.map_dict(a, self.idx2word))
        print('supporting fact:', sf)
        print('length of sentences:', l)
    
class Config(object):
    """Settings for preprocessing the bAbI data into a pickled ``Dataset``.

    ``max_sentnum``, ``max_slen`` and ``max_qlen`` start empty and are filled
    per task by ``Dataset.process_data``; ``word_vocab_size`` is set there too.
    """
    def __init__(self):
        # Root folder on the mounted Google Drive.
        self.path = '/content/gdrive/My Drive/Colab Notebooks/DMN_bAbI_pytorch'
        self.data_dir = self.path + '/data/bAbI/en'
        self.word2vec_type = 6  # 6 or 840 (B)
        self.word2vec_path = self.path + '/data/glove/glove.'\
                + str(self.word2vec_type) + 'B.300d.txt'
        self.word_embed_dim = 300
        self.batch_size = 32
        self.max_sentnum = {}
        self.max_slen = {}
        self.max_qlen = {}
        self.max_episode = 5
        # Dataset.get_next_batch pads ragged answers up to this length; without
        # it, batching a multi-answer task right after preprocessing fails with
        # AttributeError. Same value as Args.max_alen.
        self.max_alen = 2
        self.word_vocab_size = 0
        self.save_preprocess = True
        self.preprocess_save_path = self.path + '/data/bAbI/babi(tmp).pkl'
        self.preprocess_load_path = self.path + '/data/bAbI/babi(10k).pkl'

In [7]:
# Build the dataset from the raw files (and cache it as a pickle), or load a
# previously cached one. Files are opened with `with` so the handles are
# closed (and the pickle fully flushed) before the next cell reads it back.
config = Config()
if config.save_preprocess:
  dataset = Dataset(config)
  with open(config.preprocess_save_path, 'wb') as fh:
    pickle.dump(dataset, fh)
else:
  print('## load preprocess %s' % config.preprocess_load_path)
  # NOTE: pickle.load runs arbitrary code; only load files you produced yourself.
  with open(config.preprocess_load_path, 'rb') as fh:
    dataset = pickle.load(fh)

# dataset config must be valid
pp = lambda x: pprint.PrettyPrinter().pprint(x)
pp(([(k,v) for k, v in vars(dataset.config).items() if '__' not in k]))
print()

# Smoke test: iterate once over the test split of the first task and decode
# the first sample of the final batch.
for set_num in range(1):
  mode = 'te'
  dataset.shuffle_data(mode, set_num+1)
  while True:
    s, q, a, sf, sl, ql, el = dataset.get_next_batch(mode, set_num+1, batch_size=100)
    print(dataset.get_batch_ptr(mode), len(s))
    # the pointer wraps to 0 once the whole split has been served
    if dataset.get_batch_ptr(mode) == 0:
      print(s[0], q[0], a[0], sf[0], sl[0], ql[0], el[0])
      dataset.decode_data(s[0], q[0], a[0], sf[0], sl[0][:el[0]])
      print('iteration test pass!', mode)
      break

### building word dict /content/gdrive/My Drive/Colab Notebooks/DMN_bAbI_pytorch/data/bAbI/en
num of files: 60
init dict size 158

### loading pretrained /content/gdrive/My Drive/Colab Notebooks/DMN_bAbI_pytorch/data/glove/glove.6B.300d.txt
len(word2idx): 159
len(word2vec): 5
apple: 131 [-0.20842, -0.019668, 0.063981, -0.71403, -0.21181]
apple: [-0.20842, -0.019668, 0.063981, -0.71403, -0.21181]
pretrained vectors (159, 300) unk 0
dictionary change 158 to 159 159

### processing /content/gdrive/My Drive/Colab Notebooks/DMN_bAbI_pytorch/data/bAbI/en
data size 1000
max sentnum 10
max slen 72
max qlen 6

[('path', '/content/gdrive/My Drive/Colab Notebooks/DMN_bAbI_pytorch'),
 ('data_dir',
  '/content/gdrive/My Drive/Colab Notebooks/DMN_bAbI_pytorch/data/bAbI/en'),
 ('word2vec_type', 6),
 ('word2vec_path',
  '/content/gdrive/My Drive/Colab '
  'Notebooks/DMN_bAbI_pytorch/data/glove/glove.6B.300d.txt'),
 ('word_embed_dim', 300),
 ('batch_size', 32),
 ('max_sentnum',
  {1: 10,
   2: 88,
   3

In [0]:
class Args:
    """Run-time options and model hyperparameters for the DMN experiment."""

    def __init__(self):
        # --- locations ---------------------------------------------------
        root = '/content/gdrive/My Drive/Colab Notebooks/DMN_bAbI_pytorch'
        self.path = root
        self.data_path = '%s/data/bAbI/babi(tmp).pkl' % root
        self.model_name = 'm'
        self.checkpoint_dir = '%s/results/' % root

        # --- run control -------------------------------------------------
        self.batch_size = 32
        self.epoch = 100
        self.train, self.valid, self.test = 1, 1, 1
        self.early_stop = 0
        self.resume = False
        self.save = False
        self.print_step = 128

        # --- optimisation ------------------------------------------------
        self.lr = 0.0003
        self.lr_decay = 1.0
        self.wd = 0
        self.grad_max_norm = 5

        # --- story / question encoders -----------------------------------
        self.s_rnn_hdim, self.s_rnn_ln, self.s_rnn_dr = 100, 1, 0.0
        self.q_rnn_hdim, self.q_rnn_ln, self.q_rnn_dr = 100, 1, 0.0

        # --- episode / memory / answer cells -----------------------------
        self.e_cell_hdim = 100
        self.m_cell_hdim = 100
        self.a_cell_hdim = 100
        self.word_dr = 0.2
        self.g1_dim = 500
        self.max_episode = 10
        self.beta_cnt = 10
        # change this to cover all the bAbI tasks from 1 to 20
        self.set_num = 1
        self.max_alen = 2

In [0]:
args = Args()
# Load the preprocessed dataset written by the preprocessing cell. The file is
# opened with `with` so the handle is closed after loading.
# NOTE: pickle.load runs arbitrary code; only load files you produced yourself.
with open(args.data_path, 'rb') as fh:
    dataset = pickle.load(fh)

In [10]:
# Merge the two option sets: args values win over the pickled dataset config,
# then args picks up whatever only the dataset config had (e.g. max_slen).
vars(dataset.config).update(vars(args))
vars(args).update(vars(dataset.config))

def pp(x):
    pprint.PrettyPrinter().pprint(x)

pp(vars(args))

{'a_cell_hdim': 100,
 'batch_size': 32,
 'beta_cnt': 10,
 'checkpoint_dir': '/content/gdrive/My Drive/Colab '
                   'Notebooks/DMN_bAbI_pytorch/results/',
 'data_dir': '/content/gdrive/My Drive/Colab '
             'Notebooks/DMN_bAbI_pytorch/data/bAbI/en',
 'data_path': '/content/gdrive/My Drive/Colab '
              'Notebooks/DMN_bAbI_pytorch/data/bAbI/babi(tmp).pkl',
 'e_cell_hdim': 100,
 'early_stop': 0,
 'epoch': 100,
 'g1_dim': 500,
 'grad_max_norm': 5,
 'lr': 0.0003,
 'lr_decay': 1.0,
 'm_cell_hdim': 100,
 'max_alen': 2,
 'max_episode': 10,
 'max_qlen': {1: 4,
              2: 5,
              3: 8,
              4: 7,
              5: 8,
              6: 6,
              7: 7,
              8: 5,
              9: 6,
              10: 6,
              11: 4,
              12: 4,
              13: 4,
              14: 7,
              15: 6,
              16: 5,
              17: 12,
              18: 10,
              19: 11,
              20: 8},
 'max_sentnum': {

In [0]:
def progress(_progress):
    """Render a carriage-return-prefixed text progress bar.

    ``_progress`` is a fraction; ints are converted to float, values are
    clipped to [0, 1] and anything that is not a number yields an empty bar
    with an error note.
    """
    width = 5  # number of character cells in the bar
    note = ""

    fraction = float(_progress) if isinstance(_progress, int) else _progress
    if not isinstance(fraction, float):
        fraction, note = 0, "error: progress var must be float\r\n"
    elif fraction < 0:
        fraction, note = 0, "Halt...\r\n"
    elif fraction >= 1:
        fraction = 1

    filled = int(round(width * fraction))
    bar = "#" * filled + " " * (width - filled)
    return "\r\t[%s]\t%.2f%% %s" % (bar, fraction * 100, note)

In [0]:
def run_epoch(m, d, ep, mode='tr', set_num=1, is_train=True):
    """Run one full pass over a split and return ``[mean loss, mean accuracy]``.

    Args:
        m: the model; must expose ``config``, ``optimizer``, ``criterion`` and
            ``get_metrics`` and return ``(outputs, gates)`` when called.
        d: the dataset serving batches through ``get_next_batch``.
        ep: zero-based epoch number; while ``ep < config.beta_cnt`` in 'tr'
            mode only the gate loss is optimised (answer loss weight is 0).
        mode: split code, 'tr', 'va' or 'te'.
        set_num: bAbI task number.
        is_train: when True the data is shuffled and parameters are updated;
            otherwise the model is evaluated without building gradients.

    Differences from the previous version:
        * the shuffle honours ``set_num`` (it always shuffled task 1 before)
          and only happens when training;
        * the first answer position is no longer counted twice in the answer
          loss of multi-answer batches;
        * tensors follow the model's device instead of a hard-coded ``.cuda()``.
    """
    total_metrics = np.zeros(2)
    total_step = 0.0
    print_step = m.config.print_step
    start_time = datetime.now()
    device = next(m.parameters()).device
    # The gate loss supervises at most this many episodes.
    supervised_episodes = 5

    if is_train:
        d.shuffle_data(mode=mode, set_num=set_num)
        m.train()
    else:
        m.eval()

    def to_long(x):
        return torch.LongTensor(np.array(x))

    while True:
        stories, questions, answers, sup_facts, s_lens, q_lens, e_lens = \
                d.get_next_batch(mode, set_num)
        stories = to_long(stories).to(device)
        questions = to_long(questions).to(device)
        answers = to_long(answers).to(device)
        # supporting facts are 1-based sentence numbers; make them 0-based
        sup_facts = to_long(sup_facts).to(device) - 1
        s_lens = to_long(s_lens)
        q_lens = to_long(q_lens)
        e_lens = to_long(e_lens)

        # no autograd bookkeeping during validation/testing
        with torch.set_grad_enabled(is_train):
            outputs, gates = m(stories, questions, s_lens, q_lens, e_lens)

            # answer loss: one term per answer position (-100 pads are ignored
            # by CrossEntropyLoss' default ignore_index)
            num_answers = min(answers.size(1), outputs.size(1))
            a_loss = m.criterion(outputs[:,0,:], answers[:,0])
            for ans_idx in range(1, num_answers):
                a_loss = a_loss + m.criterion(
                        outputs[:,ans_idx,:], answers[:,ans_idx])

            # gate loss: attention of each episode against its supporting fact
            num_episodes = min(supervised_episodes, gates.size(1), sup_facts.size(1))
            g_loss = m.criterion(gates[:,0,:], sup_facts[:,0])
            for episode in range(1, num_episodes):
                g_loss = g_loss + m.criterion(
                        gates[:,episode,:], sup_facts[:,episode])

            beta = 0 if ep < m.config.beta_cnt and mode == 'tr' else 1
            alpha = 1
            total_loss = alpha * g_loss + beta * a_loss

        metrics = m.get_metrics(outputs, answers, multiple=answers.size(1)>1)

        if is_train:
            m.optimizer.zero_grad()
            total_loss.backward()
            nn.utils.clip_grad_norm_(m.parameters(), m.config.grad_max_norm)
            m.optimizer.step()

        total_metrics[0] += total_loss.item()
        total_metrics[1] += metrics
        total_step += 1.0

        # print step
        if d.get_batch_ptr(mode) % print_step == 0 or total_step == 1:
            et = int((datetime.now() - start_time).total_seconds())
            _progress = progress(
                    d.get_batch_ptr(mode) / d.get_dataset_len(mode, set_num))
            # the pointer wraps to 0 after the last batch of the pass
            if d.get_batch_ptr(mode) == 0:
                _progress = progress(1)
            _progress += '[%s] time: %s' % (
                    '\t'.join(['{:.2f}'.format(k)
                    for k in total_metrics / total_step]),
                    '{:2d}:{:2d}:{:2d}'.format(et//3600, et%3600//60, et%60))
            sys.stdout.write(_progress)
            sys.stdout.flush()

            # end of an epoch
            if d.get_batch_ptr(mode) == 0:
                print('\n\ttotal metrics:\t%s' % ('\t'.join(['{:.2f}'.format(k)
                    for k in total_metrics / total_step])))
                break

    return total_metrics / total_step

In [0]:
class DMN(nn.Module):
    """Dynamic Memory Network with input, question, episodic-memory and answer modules.

    The model owns its optimizer (``self.optimizer``, Adam over the trainable
    parameters) and loss (``self.criterion``, cross entropy). Word embeddings
    are initialised from ``idx2vec`` and frozen.

    Differences from the previous version:
        * dropout is only applied in training mode (``F.dropout`` was called
          without ``training=`` and therefore stayed active under ``eval()``);
        * the episode hidden state is sized by ``e_cell_hdim``;
        * tensors are created on the device of the model parameters instead of
          a hard-coded ``.cuda()``, so the model also runs on CPU.
    """

    def __init__(self, config, idx2vec, set_num):
        """
        Args:
            config: options object (hidden sizes, ``word_vocab_size``,
                ``max_slen``/``max_qlen``/``max_sentnum`` dicts keyed by task, ...).
            idx2vec: pretrained word vectors, one row per vocabulary index.
            set_num: bAbI task number used to look up the per-task maxima.
        """
        super(DMN, self).__init__()
        self.config = config
        self.set_num = set_num

        # embedding layers
        self.word_embed = nn.Embedding(config.word_vocab_size, config.word_embed_dim,
                padding_idx=0)

        # dimensions according to settings
        self.s_rnn_idim = config.word_embed_dim
        self.q_rnn_idim = config.word_embed_dim
        self.e_cell_idim = config.s_rnn_hdim
        self.m_cell_idim = config.e_cell_hdim
        self.a_cell_idim = config.q_rnn_hdim + config.word_vocab_size
        # four interaction features of size s_rnn_hdim each (see
        # episodic_memory_module)
        self.z_dim = config.s_rnn_hdim * 4

        # rnn layers
        self.s_rnn = nn.GRU(self.s_rnn_idim, config.s_rnn_hdim, batch_first=True)
        self.q_rnn = nn.GRU(self.q_rnn_idim, config.q_rnn_hdim, batch_first=True)
        self.e_cell = nn.GRUCell(self.e_cell_idim, config.e_cell_hdim)
        self.m_cell = nn.GRUCell(self.m_cell_idim, config.m_cell_hdim)
        self.a_cell = nn.GRUCell(self.a_cell_idim, config.a_cell_hdim)

        # linear layers
        self.out = nn.Linear(config.m_cell_hdim,
                config.word_vocab_size, bias=False)
        self.g1 = nn.Linear(self.z_dim, config.g1_dim)
        self.g2 = nn.Linear(config.g1_dim, 1)

        # initialization
        self.init_word_embed(idx2vec)
        params = self.model_params(debug=False)
        self.optimizer = optim.Adam(params, lr=config.lr)
        self.criterion = nn.CrossEntropyLoss()

    def _param_device(self):
        """Device the model parameters currently live on."""
        return self.word_embed.weight.device

    def init_word_embed(self, idx2vec):
        """Copy the pretrained vectors into the embedding and freeze it."""
        self.word_embed.weight.data.copy_(torch.from_numpy(np.array(idx2vec)))
        self.word_embed.weight.requires_grad = False

    def model_params(self, debug=True):
        """Return the trainable parameters and print their total element count."""
        print('model parameters: ', end='')
        params = []
        total_size = 0

        for p in self.parameters():
            if p.requires_grad:
                params.append(p)
                total_size += p.numel()
            if debug:
                print(p.requires_grad, p.size())
        print('%s\n' % '{:,}'.format(total_size))
        return params

    def init_rnn_h(self, batch_size, hdim=None):
        """Zero initial state for a GRU; ``hdim`` defaults to ``s_rnn_hdim``."""
        if hdim is None:
            hdim = self.config.s_rnn_hdim
        return torch.zeros(self.config.s_rnn_ln*1, batch_size, hdim,
                device=self._param_device())

    def init_cell_h(self, batch_size, hdim=None):
        """Zero initial state for a GRUCell; ``hdim`` defaults to ``s_rnn_hdim``."""
        if hdim is None:
            hdim = self.config.s_rnn_hdim
        return torch.zeros(batch_size, hdim, device=self._param_device())

    def input_module(self, stories, s_lens):
        """Encode the stories and pick the GRU output at every sentence end.

        ``s_lens`` holds, per story, the 1-based token positions of the
        sentence ends, zero-padded to ``max_sentnum``. Returns a tensor viewed
        as ``(-1, max_sentnum, s_rnn_hdim)``.
        """
        word_embed = F.dropout(self.word_embed(stories), self.config.word_dr,
                training=self.training)
        init_s_rnn_h = self.init_rnn_h(stories.size(0))
        gru_out, _ = self.s_rnn(word_embed, init_s_rnn_h)
        gru_out = gru_out.contiguous().view(-1, self.config.s_rnn_hdim)
        device = gru_out.device
        # row offset of every story inside the flattened (batch*max_slen) outputs
        s_lens_offset = (torch.arange(stories.size(0), dtype=torch.long, device=device)
                * self.config.max_slen[self.set_num]).unsqueeze(1)
        # clamp: the zero padding of the first story would otherwise index -1
        flat_idx = torch.clamp(s_lens.to(device) + s_lens_offset - 1, min=0).view(-1)
        selected = gru_out[flat_idx,:].view(-1, self.config.max_sentnum[self.set_num],
                self.config.s_rnn_hdim)
        return selected

    def question_module(self, questions, q_lens):
        """Encode the questions and pick the GRU output at position ``q_lens``."""
        word_embed = F.dropout(self.word_embed(questions), self.config.word_dr,
                training=self.training)
        init_q_rnn_h = self.init_rnn_h(questions.size(0), self.config.q_rnn_hdim)
        gru_out, _ = self.q_rnn(word_embed, init_q_rnn_h)
        gru_out = gru_out.contiguous().view(-1, self.config.q_rnn_hdim)
        device = gru_out.device
        flat_idx = (torch.arange(questions.size(0), dtype=torch.long, device=device)
                * self.config.max_qlen[self.set_num] + q_lens.to(device) - 1)
        selected = gru_out[flat_idx,:].view(-1, self.config.q_rnn_hdim)

        return selected

    def episodic_memory_module(self, s_rep, q_rep, e_lens, memory):
        """Run one gated episode over the sentence representations.

        Returns ``(episode, gate_logits)``: the episode state taken at position
        ``e_lens`` of every story and the raw (pre-sigmoid) gate scores viewed
        as ``(-1, max_sentnum + 1)``.
        """
        device = s_rep.device
        # expand s_rep to have sentinel
        sentinel = torch.zeros(
            s_rep.size(0), 1, self.config.s_rnn_hdim, device=device)
        s_rep = torch.cat((s_rep, sentinel), 1)
        q_rep = q_rep.unsqueeze(1).expand_as(s_rep)
        memory = memory.unsqueeze(1).expand_as(s_rep)
        # interaction features between each sentence and question / memory
        Z = torch.cat([s_rep*q_rep, s_rep*memory,
            torch.abs(s_rep-q_rep), torch.abs(s_rep-memory)], 2)
        G = self.g2(torch.tanh(self.g1(Z.view(-1, self.z_dim))))
        G_s = torch.sigmoid(G).view(
                -1, self.config.max_sentnum[self.set_num] + 1).unsqueeze(2)
        # iterate over sentences: put the sentence axis first
        G_s = torch.transpose(G_s, 0, 1).contiguous()
        s_rep = torch.transpose(s_rep, 0, 1).contiguous()

        e_rnn_h = self.init_cell_h(s_rep.size(1), self.config.e_cell_hdim)
        hiddens = []
        for gg, ss in zip(G_s, s_rep):
            # the gate interpolates between the updated and the previous state
            e_rnn_h = gg * self.e_cell(ss, e_rnn_h) + (1 - gg) * e_rnn_h
            hiddens.append(e_rnn_h)
        hiddens = torch.transpose(torch.stack(hiddens), 0, 1).contiguous().view(
                -1, self.config.e_cell_hdim)
        flat_idx = (torch.arange(s_rep.size(1), dtype=torch.long, device=device)
                * (self.config.max_sentnum[self.set_num]+1) + e_lens.to(device) - 1)
        selected = hiddens[flat_idx,:].view(-1, self.config.e_cell_hdim)
        return selected, G.view(-1, self.config.max_sentnum[self.set_num] + 1)

    def answer_module(self, q_rep, memory):
        """Decode ``max_alen`` answer steps; returns the logits of every step."""
        y = F.softmax(self.out(memory), dim=1)
        a_rnn_h = memory
        ys = []
        for step in range(self.config.max_alen):
            # previous prediction and question feed the answer cell
            a_rnn_h = self.a_cell(torch.cat((y, q_rep), 1), a_rnn_h)
            z = self.out(a_rnn_h)
            y = F.softmax(z, dim=1)
            ys.append(z)
        ys = torch.transpose(torch.stack(ys), 0, 1).contiguous()
        return ys

    def forward(self, stories, questions, s_lens, q_lens, e_lens):
        """Return ``(outputs, gates)``: answer logits per answer step and gate
        logits per episode, both with the batch on the first axis."""
        s_rep = self.input_module(stories, s_lens)
        q_rep = self.question_module(questions, q_lens)

        memory = q_rep # initial memory
        gates = []
        for episode in range(self.config.max_episode):
            e_rep, gate = self.episodic_memory_module(s_rep, q_rep, e_lens, memory)
            gates.append(gate)
            memory = self.m_cell(e_rep, memory)
        gates = torch.transpose(torch.stack(gates), 0, 1).contiguous()
        outputs = self.answer_module(q_rep, memory)

        return outputs, gates

    def get_regloss(self, weight_decay=None):
        """L2 penalty over the layers listed in ``params`` (currently none)."""
        if weight_decay is None:
            weight_decay = self.config.wd
        reg_loss = 0
        params = [] # add params here
        for param in params:
            reg_loss += torch.norm(param.weight, 2)
        return reg_loss * weight_decay

    def decay_lr(self, lr_decay=None):
        """Divide the learning rate by ``lr_decay`` and push it to the optimizer."""
        if lr_decay is None:
            lr_decay = self.config.lr_decay
        self.config.lr /= lr_decay
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.config.lr

        # %g instead of %.3f: a rate such as 0.0003 used to print as 0.000
        print('\tlearning rate decay to %g' % self.config.lr)

    def get_metrics(self, outputs, targets, multiple=False):
        """Return the accuracy in percent.

        With ``multiple=False`` only the first answer position is scored. With
        ``multiple=True`` a sample is correct when every answer position
        matches; positions whose target is -100 (padding) always match.
        """
        preds = outputs.argmax(dim=2).cpu().numpy()
        golds = targets.cpu().numpy()
        if not multiple:
            hits = preds[:, 0] == golds[:, 0]
        else:
            hits = np.all((preds == golds) | (golds == -100), axis=1)
        return np.mean(hits) * 100

    def _checkpoint_path(self, filename=None):
        """Full checkpoint path; defaults to ``<model_name><set_num>.pth``."""
        if filename is None:
            filename = self.config.model_name + str(self.set_num) + '.pth'
        return self.config.checkpoint_dir + filename

    def save_checkpoint(self, state, filename=None):
        """Save ``state`` with torch.save, creating the target folder if needed."""
        filename = self._checkpoint_path(filename)
        print('\t=> save checkpoint %s' % filename)
        folder = os.path.dirname(filename)
        if folder:
            os.makedirs(folder, exist_ok=True)
        torch.save(state, filename)

    def load_checkpoint(self, filename=None):
        """Restore model and optimizer state from a checkpoint file."""
        filename = self._checkpoint_path(filename)
        print('\t=> load checkpoint %s' % filename)
        # map_location lets a checkpoint saved on GPU load onto the current device
        checkpoint = torch.load(filename, map_location=self._param_device())
        self.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

In [0]:
def run_experiment(model, dataset, set_num):
    """Train, validate and test ``model`` as directed by ``model.config``.

    The flags read from ``model.config`` are ``train``, ``resume``, ``epoch``,
    ``valid``, ``early_stop`` and ``test``.

    Args:
        model: Object exposing ``config``, ``optimizer``, ``state_dict()``,
            ``save_checkpoint()`` and ``load_checkpoint()``.
        dataset: Passed through unchanged to ``run_epoch``.
        set_num: Passed through unchanged to ``run_epoch``.

    Returns:
        The validation metrics of the best epoch, i.e. the ``run_epoch``
        result whose element ``[1]`` was highest; ``np.zeros(2)`` if
        validation never ran or never exceeded 0. Element ``[1]`` is
        presumably an accuracy in percent (it is compared against 100) --
        TODO confirm against ``run_epoch``.
    """
    best_metric = np.zeros(2)
    early_stop = False
    # Set once THIS call has written a checkpoint. The final evaluation must
    # not reload from disk otherwise: the file would be missing (crash) or a
    # leftover from an earlier run (silently replacing the trained weights).
    checkpoint_saved = False

    if model.config.train:
        if model.config.resume:
            model.load_checkpoint()

        for ep in range(model.config.epoch):
            # early_stop is raised during validation and honoured here, so
            # the epoch that triggered it still runs its test pass.
            if early_stop:
                break
            print('- Training Epoch %d' % (ep+1))
            run_epoch(model, dataset, ep, 'tr', set_num)

            if model.config.valid:
                print('- Validation')
                # Sixth positional argument is False for every non-training
                # pass; presumably an is-training flag -- TODO confirm.
                met = run_epoch(model, dataset, ep, 'va', set_num, False)
                if best_metric[1] < met[1]:
                    best_metric = met
                    model.save_checkpoint({
                        'config': model.config,
                        'state_dict': model.state_dict(),
                        'optimizer': model.optimizer.state_dict()})
                    checkpoint_saved = True
                    # Nothing left to gain once the metric hits 100.
                    if best_metric[1] == 100:
                        break
                else:
                    # model.decay_lr()
                    if model.config.early_stop:
                        early_stop = True
                        print('\tearly stop applied')
                print('\tbest metrics:\t%s' % ('\t'.join(['{:.2f}'.format(k)
                    for k in best_metric])))

            if model.config.test:
                print('- Testing')
                run_epoch(model, dataset, ep, 'te', set_num, False)
            print()

    if model.config.test:
        print('- Load Validation/Testing')
        # Reload only when a checkpoint is known to be relevant: the caller
        # asked to resume from one, or this run saved its best epoch.
        if model.config.resume or checkpoint_saved:
            model.load_checkpoint()
        run_epoch(model, dataset, 0, 'va', set_num, False)
        run_epoch(model, dataset, 0, 'te', set_num, False)
        print()

    return best_metric

In [17]:
# new model experiment
# Runs one experiment per QA set in range(args.set_num, args.set_num + 1),
# i.e. exactly the single set selected by args.set_num. The statements sit at
# top level: indenting them under the comment makes the cell an
# "unexpected indent" error when executed as written.
for set_num in range(args.set_num, args.set_num+1):
    print('\n[QA set %d]' % (set_num))
    # The model is moved to the GPU, so this cell needs a CUDA runtime.
    model = DMN(args, dataset.idx2vec, set_num).cuda()
    results = run_experiment(model, dataset, set_num)

print('### end of experiment')


[QA set 1]
model parameters: 687,601

- Training Epoch 1
	[     ]	0.71% [11.95	0.00] time:  0: 0: 0



	[#####]	100.00% [1.76	0.00] time:  0: 0:27
	total metrics:	1.76	0.00
- Validation
	[#####]	100.00% [5.23	0.00] time:  0: 0: 1
	total metrics:	5.23	0.00
	best metrics:	0.00	0.00
- Testing
	[#####]	100.00% [5.22	0.00] time:  0: 0: 1
	total metrics:	5.22	0.00

- Training Epoch 2
	[#####]	100.00% [0.04	0.00] time:  0: 0:27
	total metrics:	0.04	0.00
- Validation
	[#####]	100.00% [5.12	0.00] time:  0: 0: 1
	total metrics:	5.12	0.00
	best metrics:	0.00	0.00
- Testing
	[#####]	100.00% [5.11	0.00] time:  0: 0: 1
	total metrics:	5.11	0.00

- Training Epoch 3
	[#####]	100.00% [0.01	0.00] time:  0: 0:27
	total metrics:	0.01	0.00
- Validation
	[#####]	100.00% [5.10	0.00] time:  0: 0: 1
	total metrics:	5.10	0.00
	best metrics:	0.00	0.00
- Testing
	[#####]	100.00% [5.10	0.00] time:  0: 0: 1
	total metrics:	5.10	0.00

- Training Epoch 4
	[#####]	100.00% [0.01	0.00] time:  0: 0:27
	total metrics:	0.01	0.00
- Validation
	[#####]	100.00% [5.10	0.00] time:  0: 0: 1
	total metrics:	5.10	0.00
	best metrics