In [1]:
import random
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset
def flatten(l):
    """Concatenate the items of every sub-iterable of ``l`` into one flat list (one level deep)."""
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

In [2]:
def bAbI_data_loader(path, vocab=None):
    """
    Read a bAbI task file and convert it into index-encoded QA samples.

    Every line starts with a line number followed by a space. A line whose
    text contains '?' is treated as a question and must hold three
    tab-separated fields (question, answer, supporting line number); any other
    line is a story sentence. Sentences are lower-cased, stripped of '.',
    split on whitespace and terminated with '</s>'. Each question produces one
    sample holding a snapshot of the story accumulated so far.

    Args:
        path: location of the bAbI text file.
        vocab: optional existing word -> index mapping used for encoding.
            When falsy, a new vocabulary is built from the file.

    Returns:
        (data, vocab) as produced by bAbI_build_vocab, where each sample is
        [story, question, answer, support]; or None (after printing a
        message) when the file cannot be read or a line does not have the
        expected layout.
    """
    try:
        with open(path, 'r', encoding='utf-8') as file:
            lines = [raw.strip() for raw in file]
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: undecodable bytes or an
        # invalid path string. Anything else (e.g. KeyboardInterrupt) propagates.
        print('no such file: {}'.format(path))
        return None

    samples = []
    story = []

    try:
        for line in lines:
            if not line:
                # Tolerate blank lines instead of rejecting the whole file.
                continue
            idx, text = line.split(' ', 1)
            if idx == '1':
                # Line numbering restarts at 1 whenever a new story begins.
                story = []

            if '?' in text:
                q, a, support = text.split('\t')
                q = q.lower().strip().replace('?', '').split() + ['?']
                a = a.lower().strip().split() + ['</s>']
                support = int(support)
                # Snapshot the story: later sentences must not leak into this sample.
                samples.append([deepcopy(story), q, a, support])
            else:
                story.append(text.lower().replace('.', '').split() + ['</s>'])
    except ValueError:
        # Raised by a failed tuple unpack (wrong field count) or by int().
        print('check data')
        return None

    # A falsy vocab means "build a fresh one", which bAbI_build_vocab does for None.
    data, vocab = bAbI_build_vocab(samples, vocab if vocab else None)
    return data, vocab

In [3]:
def bAbI_build_vocab(data, vocab=None):
    """
    Index-encode bAbI samples, building the vocabulary first when none is given.

    Args:
        data: list of samples [story, question, answer, support], where story
            is a list of token lists and question/answer are token lists.
            The samples are encoded IN PLACE (tokens are replaced by indices);
            support is left untouched.
        vocab: optional word -> index dict. When None, a new mapping is built
            from every token in ``data`` on top of the reserved entries
            '<pad>'=0, '<unk>'=1, '<s>'=2, '</s>'=3.

    Returns:
        (data, word2idx): the same ``data`` list, now index-encoded, and the
        mapping that was used. Empty ``data`` yields ([], reserved entries).
    """
    if vocab is None:
        word2idx = {'<pad>': 0, '<unk>': 1, '<s>': 2, '</s>': 3}
        tokens = set()
        for story, question, answer, _support in data:
            for sentence in story:
                tokens.update(sentence)
            tokens.update(question)
            tokens.update(answer)
        # Sorted so the same corpus always produces the same indices; the
        # iteration order of a set of strings changes between interpreter runs.
        for word in sorted(tokens):
            if word not in word2idx:
                word2idx[word] = len(word2idx)
    else:
        word2idx = vocab

    for d in data:
        # d[0]: story sentences, d[1]: question, d[2]: answer, d[3]: support
        for i, sentence in enumerate(d[0]):
            d[0][i] = transfer2idx(sentence, word2idx)

        d[1] = transfer2idx(d[1], word2idx)
        d[2] = transfer2idx(d[2], word2idx)

    return data, word2idx

In [4]:
def transfer2idx(seq, dictionary):
    """Map each token of ``seq`` to its index in ``dictionary``, using the '<unk>' index for unknown tokens."""
    idxs = []
    for token in seq:
        mapped = dictionary.get(token)
        if mapped is None:
            # Unknown token: fall back to the '<unk>' entry (looked up only when needed).
            mapped = dictionary["<unk>"]
        idxs.append(mapped)
    return idxs

In [6]:
def data_loader(train_data, batch_size, shuffle=False):
    """
    Yield consecutive mini-batches (list slices) of ``train_data``.

    Args:
        train_data: list of samples. When ``shuffle`` is true it is shuffled
            IN PLACE with ``random.shuffle`` before batching.
        batch_size: number of samples per batch; must be a positive integer.
        shuffle: whether to shuffle ``train_data`` first.

    Yields:
        Slices of at most ``batch_size`` samples; the last batch may be
        smaller. Nothing is yielded when ``train_data`` is empty.

    Raises:
        ValueError: if ``batch_size`` is not positive (raised when iteration
            starts, since this is a generator).
    """
    if batch_size <= 0:
        # A non-positive step would otherwise never reach the end of the data.
        raise ValueError('batch_size must be a positive integer, got {}'.format(batch_size))
    if shuffle:
        random.shuffle(train_data)
    for start in range(0, len(train_data), batch_size):
        yield train_data[start:start + batch_size]

In [30]:
def pad_to_batch(batch, w2idx):
    """
    Turn a list of [story, question, answer, support] samples into padded tensors.

    Output layout:
        stories, stories_masks: B, n, T_c
        questions, questions_masks: B, T_q
        answers: B, T_a
        supports: B (returned as a plain Python list, not a tensor)
    """
    story_col, question_col, answer_col, support_col = list(zip(*batch))
    batch_len = len(batch)

    # Padding targets shared by every sample in this batch.
    max_story = max(len(one_story) for one_story in story_col)  # most sentences in a story
    max_len = max(len(sentence) for sentence in flatten(story_col))  # longest sentence
    max_q = max(len(question) for question in question_col)
    max_a = max(len(answer) for answer in answer_col)

    # Each story becomes its own (max_story, max_len) grid plus mask.
    padded_stories = [
        get_batch_array(get_fixed_array(story_col[i], w2idx), max_story, max_len)
        for i in range(batch_len)
    ]
    stories = [grid for grid, _ in padded_stories]
    stories_masks = [mask for _, mask in padded_stories]

    questions, questions_masks = get_batch_array(get_fixed_array(question_col, w2idx), batch_len, max_q)
    answers, _ = get_batch_array(get_fixed_array(answer_col, w2idx), batch_len, max_a)

    return (
        trans2tensor(stories),
        trans2tensor(stories_masks),
        trans2tensor(questions),
        trans2tensor(questions_masks),
        trans2tensor(answers),
        list(support_col),
    )

In [9]:
def get_fixed_array(data, w2idx):
    """
    Right-pad every sequence in ``data`` to the length of the longest one.

    Args:
        data: list or tuple of index lists.
        w2idx: word -> index mapping; its '<pad>' entry is the fill value.
            0 is used when the entry is absent, matching the zero fill and
            ``== 0`` mask used by get_batch_array.

    Returns:
        A new list of new lists, all of equal length (empty list for empty
        input). The input sequences are left untouched, so padding a batch
        never alters the underlying dataset.
    """
    if len(data) == 0:
        return []
    pad = w2idx.get('<pad>', 0)
    max_col = max(len(seq) for seq in data)
    # Pad by the full shortfall; a single append only fixes sequences that are
    # exactly one token short and leaves the rest ragged.
    return [list(seq) + [pad] * (max_col - len(seq)) for seq in data]

In [10]:
def get_batch_array(data, *shape):
    """
    Copy a 2-D nested list into a zero-filled (r, c) grid and build its zero mask.

    Args:
        data: rectangular nested list of ints with at most ``r`` rows and
            ``c`` columns; may be empty.
        *shape: exactly two integers, ``r`` (rows) and ``c`` (columns).

    Returns:
        (grid, mask) as nested Python lists of shape (r, c). ``grid`` holds
        ``data`` in its top-left corner and 0 everywhere else; ``mask`` is 1
        where ``grid`` is 0 and 0 otherwise.

    Raises:
        ValueError: if non-empty ``data`` is not two-dimensional.
        IndexError: if ``data`` has more rows or columns than the grid.
    """
    r, c = shape
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the same dtype.
    temp = np.zeros((r, c), dtype=int)
    values = np.array(data, dtype=int)

    if values.size:
        if values.ndim != 2:
            raise ValueError('data must be 2-dimensional, got {} dimension(s)'.format(values.ndim))
        rows, cols = values.shape
        if rows > r or cols > c:
            raise IndexError('data of shape {} does not fit into ({}, {})'.format(values.shape, r, c))
        # One block assignment instead of walking every element with nditer
        # and rebuilding np.array(data) on each step.
        temp[:rows, :cols] = values

    mask = (temp == 0).astype(np.byte)
    return temp.tolist(), mask.tolist()

In [21]:
def trans2tensor(x):
    """Build a LongTensor from ``x`` and return it wrapped in an autograd Variable."""
    long_tensor = torch.LongTensor(x)
    return Variable(long_tensor)

In [5]:
# Relative path to the bAbI qa1 (single supporting fact) training file.
path = '../data/QA_bAbI_tasks/en-10k/qa1_single-supporting-fact_train.txt'
# bAbI_data_loader returns None (after printing a message) when the file is
# missing or malformed; stop here with a clear error instead of the cryptic
# "cannot unpack non-iterable NoneType" raised by the tuple assignment.
loaded = bAbI_data_loader(path)
if loaded is None:
    raise RuntimeError('could not load bAbI data from {}'.format(path))
data, word2idx = loaded

In [7]:
# Take only the first (unshuffled) mini-batch; it is the sample used to try out the padding.
batch_size = 32
batch = next(data_loader(data, batch_size, shuffle=False))

In [31]:
# Pad the sampled batch: the first five results are tensors built by trans2tensor,
# while `supports` comes back as a plain Python list (see pad_to_batch).
stories, stories_masks, questions, questions_masks, answers, supports = pad_to_batch(batch, word2idx)

In [32]:
# Layout per pad_to_batch: (B, n, T_c) = (batch size, max sentences per story, max sentence length).
stories.size()

torch.Size([32, 10, 7])

In [33]:
# The mask is produced alongside each story grid, so it has the same (B, n, T_c) layout.
stories_masks.size()

torch.Size([32, 10, 7])

In [34]:
# Layout per pad_to_batch: (B, T_q) = (batch size, max question length).
questions.size()

torch.Size([32, 4])

In [35]:
# The mask is produced alongside the question grid, so it has the same (B, T_q) layout.
questions_masks.size()

torch.Size([32, 4])

In [36]:
# Layout per pad_to_batch: (B, T_a) = (batch size, max answer length).
answers.size()

torch.Size([32, 2])