<a href="https://colab.research.google.com/github/sourcecode369/deep-nlp/blob/master/memory%20networks/Memory_Networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from __future__ import absolute_import, print_function, unicode_literals, division
from builtins import range, input

In [0]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
%matplotlib inline 
import re 
import tarfile

%tensorflow_version 2.x
import tensorflow as tf
from keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Embedding, Input, Lambda, Reshape, add, dot, Activation 
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.utils import get_file

import warnings
warnings.simplefilter("ignore")
warnings.filterwarnings("ignore")

In [0]:
# Fetch the bAbI 1-20 (v1.2) archive with Keras' get_file, which returns the
# local path of the downloaded file.
path = get_file(
    fname='babi-tasks-v1-2.tar.gz',
    origin='https://s3.amazonaws.com/text-datasets/babi_tasks_1-20_v1-2.tar.gz',
)

# Kept open at module level: get_data() reads member files from `tar` later.
tar = tarfile.open(path)

In [0]:
# Challenge name -> path template inside the bAbI archive; the `{}` slot is
# filled with 'train' or 'test' by get_data().
challenges = dict(
    single_supporting_fact_10k='tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_{}.txt',
    two_supporting_fact_10k='tasks_1-20_v1-2/en-10k/qa2_two-supporting-facts_{}.txt',
)

In [0]:
def tokenize(sent):
  """Split a sentence into word and punctuation tokens.

  Runs of non-word characters are kept as tokens (thanks to the capturing
  group) and whitespace-only pieces are dropped, e.g.
  'Where is Mary?' -> ['Where', 'is', 'Mary', '?'].
  """
  # The group must not be optional ('(\W+)?'): such a pattern can match the
  # empty string, and on Python 3.7+ re.split then splits between every
  # character and yields None for the unmatched group, so x.strip() raises.
  return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]

In [0]:
def get_stories(f):
  """Parse a bAbI task file into (story, question, answer) triples.

  Each non-blank line is "<id> <text>". An id of 1 starts a new story. A line
  containing tabs is a question of the form "question\\tanswer\\tsupporting";
  any other line is a story sentence.

  Args:
    f: iterable of lines, either ``bytes`` (as yielded by the file object
      from ``tarfile.TarFile.extractfile``) or already-decoded ``str``.

  Returns:
    List of ``(story_so_far, question_tokens, answer)`` tuples.
    ``story_so_far`` holds the tokenized sentences seen so far in the current
    story, each prefixed with its position as a string. Question lines also
    occupy a position (they are stored as '' and filtered out), so these
    positions skip over earlier questions.
  """
  data = []
  story = []
  for raw_line in f:
    line = raw_line.decode('utf-8') if isinstance(raw_line, bytes) else raw_line
    line = line.strip()
    if not line:
      # Tolerate blank lines (e.g. a trailing newline) rather than failing
      # on the id/text split below.
      continue
    nid, line = line.split(' ', 1)
    if int(nid) == 1:
      story = []
    if '\t' in line:
      # The supporting-fact ids are parsed but not used here.
      q, a, _supporting = line.split('\t')
      q = tokenize(q)
      story_so_far = [[str(i)] + s for i, s in enumerate(story) if s]
      data.append((story_so_far, q, a))
      # Placeholder so enumerate() positions also count question lines.
      story.append('')
    else:
      story.append(tokenize(line))
  return data

In [0]:
def should_flatten(el):
  """Return True if ``flatten`` should recurse into ``el``.

  Only non-string iterables are descended into. ``str``/``bytes`` are kept
  whole (iterating them would split words into characters), and scalars
  without ``__iter__`` (ints, None, ...) are yielded as leaves instead of
  making ``flatten`` fail with TypeError.
  """
  return hasattr(el, '__iter__') and not isinstance(el, (str, bytes))

def flatten(l):
  """Lazily yield the leaves of an arbitrarily nested structure, depth-first.

  Elements for which ``should_flatten`` is true are descended into; every
  other element (e.g. a string token) is yielded as-is, in original order.
  """
  # Explicit stack of live iterators instead of recursive generators: the
  # top iterator is the current nesting level.
  pending = [iter(l)]
  while pending:
    for item in pending[-1]:
      if should_flatten(item):
        # Descend; the parent iterator keeps its position for later.
        pending.append(iter(item))
        break
      yield item
    else:
      # Current level exhausted — resume the parent level.
      pending.pop()

In [0]:
def vectorize_stories(data, word2idx, story_maxlen, query_maxlen):
  """Encode (story, query, answer) triples as padded integer-id arrays.

  Args:
    data: iterable of ``(story, query, answer)`` triples, where ``story`` is a
      list of token lists, ``query`` a token list and ``answer`` one token.
      It is iterated exactly once.
    word2idx: token -> integer id mapping; a token missing from it raises
      KeyError.
    story_maxlen: ``maxlen`` passed to ``pad_sequences`` for each story's
      sentences.
    query_maxlen: ``maxlen`` passed to ``pad_sequences`` for the queries.

  Returns:
    ``(stories, queries, answers)``: a list with one padded 2-D array per
    story, one padded 2-D array of queries, and ``np.array`` of answer ids
    shaped ``(num_examples, 1)`` for non-empty ``data``.
  """
  def to_ids(tokens):
    return [word2idx[token] for token in tokens]

  # One pass over `data`, encoding story, query and answer per example.
  encoded = [
      ([to_ids(sentence) for sentence in story], to_ids(query), [word2idx[answer]])
      for story, query, answer in data
  ]
  story_ids = [item[0] for item in encoded]
  query_ids = [item[1] for item in encoded]
  answer_ids = [item[2] for item in encoded]

  padded_stories = [pad_sequences(sentences, maxlen=story_maxlen) for sentences in story_ids]
  padded_queries = pad_sequences(query_ids, maxlen=query_maxlen)
  return padded_stories, padded_queries, np.array(answer_ids)

In [0]:
def stack_inputs(inputs, story_maxsents, story_maxlen):
  """Zero-pad every story to ``story_maxsents`` sentences and stack them.

  The caller's ``inputs`` list is left unmodified.

  Args:
    inputs: sequence of 2-D integer arrays, one per story, each shaped
      ``(num_sentences, story_maxlen)`` (e.g. the per-story arrays returned
      by ``vectorize_stories``).
    story_maxsents: number of sentence rows every story is padded to.
    story_maxlen: number of token ids per sentence row.

  Returns:
    Array of shape ``(len(inputs), story_maxsents, story_maxlen)``; padding
    rows of zeros come after each story's own sentences. An empty ``inputs``
    gives an array with a leading dimension of 0.

  Raises:
    ValueError: if a story is not 2-D, has more than ``story_maxsents``
      sentences, or has rows whose length is not ``story_maxlen``.
  """
  stories = [np.asarray(story) for story in inputs]
  # Same dtype np.concatenate gives when joining a story with int zero
  # padding (e.g. int32 ids promote to the platform int).
  dtype = np.result_type(np.int_, *(story.dtype for story in stories))
  stacked = np.zeros((len(stories), story_maxsents, story_maxlen), dtype=dtype)
  for i, story in enumerate(stories):
    if story.ndim != 2 or story.shape[0] > story_maxsents or story.shape[1] != story_maxlen:
      raise ValueError(
          f"story {i} has shape {story.shape}; expected at most "
          f"{story_maxsents} rows of length {story_maxlen}"
      )
    stacked[i, :story.shape[0]] = story
  return stacked

In [0]:
def get_data(challenge_type):
  """Load one bAbI challenge and return vectorized train/test data.

  Reads the train and test files for ``challenge_type`` (a key of
  ``challenges``) from the module-level ``tar`` archive, builds a vocabulary
  shared by both splits (index 0 reserved for '<PAD>'), and encodes
  stories, queries and answers as integer ids. Prints the stacked input
  shapes.

  Returns a 13-tuple:
    (train_stories, test_stories,
     inputs_train, queries_train, answers_train,
     inputs_test, queries_test, answers_test,
     story_maxsents, story_maxlen, query_maxlen, vocab, vocab_size)
  """
  template = challenges[challenge_type]
  train_stories, test_stories = [
      get_stories(tar.extractfile(template.format(split)))
      for split in ('train', 'test')
  ]

  # Limits and vocabulary span train+test so both splits share the same
  # padding lengths and word ids.
  all_stories = train_stories + test_stories
  story_maxlen = max(len(sentence) for story, _, _ in all_stories for sentence in story)
  story_maxsents = max(len(story) for story, _, _ in all_stories)
  query_maxlen = max(len(query) for _, query, _ in all_stories)

  vocab = ['<PAD>'] + sorted(set(flatten(all_stories)))
  vocab_size = len(vocab)
  word2idx = {word: idx for idx, word in enumerate(vocab)}

  train_vec, test_vec = [
      vectorize_stories(split_stories, word2idx, story_maxlen, query_maxlen)
      for split_stories in (train_stories, test_stories)
  ]
  inputs_train, queries_train, answers_train = train_vec
  inputs_test, queries_test, answers_test = test_vec

  inputs_train = stack_inputs(inputs_train, story_maxsents, story_maxlen)
  inputs_test = stack_inputs(inputs_test, story_maxsents, story_maxlen)
  print(f"inputs_train.shape {inputs_train.shape}, inputs_test.shape {inputs_test.shape}")
  return (
      train_stories, test_stories,
      inputs_train, queries_train, answers_train,
      inputs_test, queries_test, answers_test,
      story_maxsents, story_maxlen, query_maxlen, vocab, vocab_size,
  )

In [35]:
# Load and vectorize bAbI task 1 (single supporting fact, 10k examples).
(train_stories, test_stories,
 inputs_train, queries_train, answers_train,
 inputs_test, queries_test, answers_test,
 story_maxsents, story_maxlen, query_maxlen,
 vocab, vocab_size) = get_data('single_supporting_fact_10k')

inputs_train.shape (10000, 10, 8), inputs_test.shape (1000, 10, 8)
