In [None]:
%reload_ext autoreload
%autoreload 2

import model
import data

import argparse
import time
import math
import torch
import torch.nn as nn
from torch.autograd import Variable

import data
import model
from trainlogger import TrainLogger
from gensim.models.word2vec import Word2Vec, KeyedVectors

In [None]:
def batchify(data, bsz):
    """Reshape a sequence tensor into `bsz` parallel streams, laid out time-major.

    Args:
        data: tensor whose first dimension is the position in the sequence.
            Any trailing dimensions (e.g. the embedding width) are preserved.
        bsz: number of parallel streams (the batch size).

    Returns:
        A contiguous tensor of shape (nbatch, bsz, *trailing_dims) where
        nbatch = data.size(0) // bsz. Stream `b` holds the consecutive rows
        data[b * nbatch:(b + 1) * nbatch]; rows that do not fit are dropped.
    """
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    print('num batches:',nbatch)
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches. The trailing dimensions are
    # read from the input rather than assuming a 200-wide embedding.
    trailing_dims = tuple(data.size()[1:])
    data = data.view(bsz, nbatch, *trailing_dims).transpose(1,0).contiguous()
    return data

def repackage_hidden(h):
    """Detach hidden states from their autograd history.

    Args:
        h: a single hidden state, or an arbitrarily nested iterable of hidden
            states (an LSTM returns a tuple of two).

    Returns:
        A detached copy of `h` that shares the same data: a single state for a
        single input, otherwise a tuple mirroring the nesting of `h`.
    """
    # isinstance (not an exact type comparison) so that subclasses and, on
    # PyTorch >= 0.4 where Variable and Tensor are merged, plain tensors are
    # recognised instead of being iterated element by element.
    if isinstance(h, Variable):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)
    
def get_batch(source, i, evaluation=False):
    """Cut one input window and its one-step-ahead target window out of `source`.

    The window length is the notebook-level `bptt`, shortened near the end of
    `source` so that the target slice (shifted by one position along the first
    dimension) still fits.

    Args:
        source: batched data, indexed by position along its first dimension.
        i: start position of the input window.
        evaluation: passed through as the `volatile` flag of the input Variable.

    Returns:
        (data, target) Variables covering source[i:i+n] and source[i+1:i+1+n].
    """
    span = min(bptt, len(source) - 1 - i)
    stop = i + span
    inputs = Variable(source[i:stop], volatile=evaluation)
    labels = Variable(source[i + 1:stop + 1])
    return inputs, labels


def evaluate(data_source):
    """Compute the average loss of the notebook's model over `data_source`.

    Relies on the notebook-level names `rnn_model`, `criterion`, `bptt` and
    `eval_batch_size`, mirroring what the training cell uses.

    Args:
        data_source: batched data as produced by `batchify` with
            `eval_batch_size` streams.

    Returns:
        float: the per-window losses weighted by window length, divided by
        len(data_source).
    """
    # Turn on evaluation mode which disables dropout.
    # `model` is the imported module; the network instance is `rnn_model`.
    rnn_model.eval()
    total_loss = 0.0
    hidden = rnn_model.init_hidden(eval_batch_size)
    # Step with the same window length that get_batch uses.
    for i in range(0, data_source.size(0) - 1, bptt):
        data, targets = get_batch(data_source, i, evaluation=True)
        output, hidden = rnn_model(data, hidden)
        # The criterion compares predicted vectors with target vectors directly,
        # exactly as in the training cell; no flattening to vocabulary size.
        batch_loss = criterion(output, targets).data
        # .sum() on the single-element loss yields a value float() accepts
        # whether the loss is a 1-element or a 0-dim tensor.
        total_loss += len(data) * float(batch_loss.sum())
        hidden = repackage_hidden(hidden)
    return total_loss / len(data_source)

# Running count of processed training batches. No active code in the visible
# cells updates it; the only increment lives in the commented-out logging of
# the training cell.
total_batches = 0

In [None]:
# Load the saved word-embedding model from disk and build the corpus from the
# WikiText-2 directory, handing the embedding model to the corpus loader.
# NOTE(review): the meaning of the `None` second argument depends on
# data.Corpus, which is not shown here -- confirm against data.py.
wem_model = KeyedVectors.load("./data/wikitext-2/wikitext-2a.w2v")
corpus = data.Corpus("./data/wikitext-2/", None, wem_model)

In [None]:
# Number of parallel streams for training and for validation/test batching.
batch_size = 20
eval_batch_size = 10

# Use the named sizes (instead of repeating the literals) so that the batched
# data always agrees with the sizes later passed to init_hidden().
train_data = batchify(corpus.train, batch_size)
val_data = batchify(corpus.valid, eval_batch_size)
test_data = batchify(corpus.test, eval_batch_size)


In [None]:
print(corpus.valid.shape)
print(val_data.shape)
# Check that position 1 of the first batch column is the same vector as row 1
# of the un-batched training data, i.e. batchify kept the first stream in order.
torch.equal(corpus.train[1], train_data[1, 0])

In [None]:
# Model and optimisation settings.
modelType = 'LSTM'
ntokens = corpus.dict_size()
emsize = 200
nhid = 500
nlayers = 2
dropout = 0.5
lr = 0.5
bptt = 5       # window length used by get_batch
clip = 0.25    # only referenced by the disabled gradient clipping in the training cell

rnn_model = model.RNNModel(modelType, nhid, nlayers, dropout, wem_model)

optimizer = torch.optim.SGD(rnn_model.parameters(), lr=lr)
# Mean-squared error between predicted and target vectors (assigned once; the
# earlier duplicate assignment was redundant).
criterion = nn.MSELoss()

best_val_loss = None

In [None]:
def get_words_for_tensor(wem, tensor):
    """Render a matrix of word vectors as text using the nearest vocabulary words.

    Args:
        wem: embedding model exposing `wv.syn0` (the embedding matrix) and
            `wv.similar_by_vector(vector, topn)`.
        tensor: 2-D collection of vectors, one per row. Accepts an autograd
            Variable, a torch tensor, or a numpy array.

    Returns:
        str: the nearest word for each row, joined by spaces; a newline is
        appended after every '<eos>' token.

    Raises:
        TypeError: if `tensor` is not one of the supported array types.
    """
    # Unwrap autograd Variables; tensors and numpy arrays are used as they are.
    if isinstance(tensor, torch.autograd.Variable):
        data = tensor.data
    elif hasattr(tensor, 'shape'):
        data = tensor
    else:
        # Previously this printed a debug marker and then crashed on an
        # unbound local; fail with a meaningful error instead.
        raise TypeError(
            "get_words_for_tensor expects a Variable, tensor or numpy array, got %s"
            % type(tensor).__name__)

    emb_size = len(wem.wv.syn0[0])
    if data.shape[1] != emb_size:
        # Report the embedding width (not the vocabulary size) next to the tensor width.
        print("Sizes don't match: tensor:(%s) - embedding:(%s)" % (data.shape[1], emb_size))

    words = []
    for row in data:
        # Rows of a torch tensor need converting; numpy rows are passed through.
        vector = row.numpy() if hasattr(row, 'numpy') else row
        nearest = wem.wv.similar_by_vector(vector, 1)
        words.append('|'.join(
            res[0] + ('\n' if res[0] == '<eos>' else '') for res in nearest))
    return ' '.join(words)
    

In [None]:
# Training loop, run inline instead of inside a train() function so that the
# intermediate values stay inspectable from later cells.
# Turn on training mode which enables dropout.
rnn_model.train()
# NOTE: total_loss is initialised here but never accumulated by the active code below.
total_loss = 0
hidden = rnn_model.init_hidden(batch_size)

# Only the first `num_batches` windows are processed. To cover the whole
# training set, iterate over range(0, train_data.size(0) - 1, bptt) instead.
num_batches = 5
for batch, i in enumerate(range(0, bptt * num_batches, bptt)):
    t_data, targets = get_batch(train_data, i)
    # Decode the first batch column ([:, 0, :]) of the input window back into
    # words via the nearest embedding, to eyeball what the model is fed.
    input_words = get_words_for_tensor(wem_model,t_data.data[:,0,:])
    print('""" INPUT\n%s\n"""' % input_words)
    # Same decoding for the targets, which are the inputs shifted by one position.
    target_words = get_words_for_tensor(wem_model,targets.data[:,0,:])
    print('""" TARGET\n%s\n"""' % target_words)

    # Starting each batch, we detach the hidden state from how it was previously produced.
    # If we didn't, the model would try backpropagating all the way to start of the dataset.
    hidden = repackage_hidden(hidden)
    rnn_model.zero_grad()
    output, hidden = rnn_model(t_data, hidden)
    # Decode the predicted vectors of the first batch column for comparison with the targets.
    output_words = get_words_for_tensor(wem_model,output.data[:,0,:])
    print('""" OUTPUT\n%s\n"""' % output_words)
    # The loss compares predicted vectors with target vectors directly.
    loss = criterion(output,targets)
    print("loss:",loss.data)
    loss.backward()

    # Gradient clipping is currently disabled:
    # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
    #     torch.nn.utils.clip_grad_norm(rnn_model.parameters(), clip)
    optimizer.step()

    # Disabled progress logging. It relies on names (args, epoch, start_time,
    # logger) that are not defined in the visible cells, and it is the only
    # place that would accumulate into total_loss / total_batches.
    #     if batch % args.log_interval == 0 and batch > 0:
    #         cur_loss = total_loss[0] / args.log_interval
    #         elapsed = time.time() - start_time
    #         print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
    #                 'loss {:5.2f} | ppl {:8.2f}'.format(
    #             epoch, batch, len(train_data) // args.bptt, lr,
    #             elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
    #         total_loss = 0
    #         logger.log(total_batches, math.exp(cur_loss))
    #         start_time = time.time()
    #     total_batches += 1

In [None]:
# Collect the parameters of the model's `rnn` sub-module and display the type
# of the first one.
rnn_params = list(rnn_model.rnn.parameters())
type(rnn_params[0])

In [None]:
# Look up one word both by key and by its vocabulary index and print the first
# five components of each, presumably to confirm that rows of `syn0` are
# addressed by `vocab[word].index`.
# NOTE(review): `np` is not used by the active code in this cell.
import numpy as np
test_word = 'Chronicles'
print(wem_model.wv[test_word][:5])
index = wem_model.wv.vocab[test_word].index
print('index ',index)
print(wem_model.wv.syn0[index][:5])
# np.sum(wem_model.wv.syn0[0