In [None]:
# Number of sequences in the (single) training/test batch, and the padded
# length every generated sequence is zero-filled to.
batch_size, max_length = 1024, 60


In [None]:
#!/usr/bin/env python
from pprint import pprint
import math
import numpy as np
import random
import sys

import torch
import torch.nn.functional as F
# See https://medium.com/@devnag/pointer-networks-in-tensorflow-with-sample-code-14645063f264

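# Uncomment this to disable NumPy's summarised ("...") array printing and see arrays in full.
# (Recent NumPy rejects threshold=np.nan; sys.maxsize is the supported "never summarise" value.)
#np.set_printoptions(threshold=sys.maxsize)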


def generate_nested_sequence(length, min_seglen=5, max_seglen=10):
    """Generate a zero-padded low-high-low sequence and the bounds of its high run.

    Three segments, each with a random length in [min_seglen, max_seglen], are
    concatenated: low values (1-5), high values (6-10), then low values (1-5).
    The result is right-padded with 0s up to ``length``.

    Returns:
        [seq, first_high, last_high] where ``seq`` is a list of exactly
        ``length`` ints and ``first_high`` / ``last_high`` are the inclusive
        indexes of the first and last high element.

    Raises:
        ValueError: if the three segments together are longer than ``length``.
    """
    # Low (1-5) vs. High (6-10)
    seq_before = [random.randint(1, 5) for _ in range(random.randint(min_seglen, max_seglen))]
    seq_during = [random.randint(6, 10) for _ in range(random.randint(min_seglen, max_seglen))]
    seq_after = [random.randint(1, 5) for _ in range(random.randint(min_seglen, max_seglen))]
    seq = seq_before + seq_during + seq_after

    # `[0] * negative` is just [], so without this check an over-long sequence
    # would be returned silently and break the fixed-width batches built from it.
    if len(seq) > length:
        raise ValueError(
            "generated sequence of length {} exceeds requested length {}; "
            "increase length or lower max_seglen".format(len(seq), length))

    # Pad it up to `length` with 0's
    seq = seq + [0] * (length - len(seq))
    return [seq, len(seq_before), len(seq_before) + len(seq_during) - 1]


def create_one_hot(length, index):
    """Build a float vector of ``length`` zeros with 1.0 written at ``index``.

    ``index`` is used as a normal NumPy assignment index (an int, a list of
    ints, a negative int, ...). Callers may scale the result afterwards.
    """
    one_hot = np.zeros(length)
    one_hot[index] = 1.0
    return one_hot


def get_lstm_state(cell):
    """Return the part of an LSTM cell treated as "the state" (its ``.c`` attribute).

    Centralised here so switching to ``.h`` is a one-line change.
    """
    state = cell.c
    return state


def print_pointer(arr, first, second):
    """Print ``arr`` as right-aligned 2-wide cells, followed by a marker row.

    The marker row puts ``^1`` under ``arr[first]`` and ``^2`` under
    ``arr[second]``; when the two indices are equal a single ``^B`` is shown.
    """
    cells = [str(value).rjust(2) for value in arr]
    print(" ".join(cells))

    markers = ["  "] * len(arr)
    if first == second:
        markers[first] = "^B"
    else:
        markers[first] = "^1"
        markers[second] = "^2"
    print(" " + " ".join(markers))


In [None]:
# Quick visual check: one-hot encode every value of a sample nested sequence.
x = [create_one_hot(40, value) for value in generate_nested_sequence(40)[0]]
print(x)

In [None]:
from torch import nn
from torch.autograd import Variable
class Model(nn.Module):
    """Pointer network that emits ``num_of_indices`` pointers into its input.

    An LSTM encoder reads the input sequence one column per step. An LSTM
    decoder starts from the encoder's final state and, at each of
    ``num_of_indices`` steps, produces a softmax distribution over input
    positions. The input value at the most likely position is fed back as the
    next decoder input.

    Args:
        input_dim: features per time step. The per-step encoder slice, the score
            squeeze and the pointer feedback all assume this is 1.
        hidden_size: LSTM hidden/cell size for both encoder and decoder.
        num_of_indices: number of pointers (decoder steps) to produce.
        blend_dim: size of the attention projection space.
        batch_size: default batch size for ``zero_hidden_state``; ``forward``
            uses the actual batch size of its input.
    """

    def __init__(self, input_dim, hidden_size, num_of_indices, blend_dim, batch_size):
        super(Model, self).__init__()

        self.input_dim = input_dim
        self.hidden_size = hidden_size
        self.num_of_indices = num_of_indices
        self.blend_dim = blend_dim
        self.batch_size = batch_size

        self.encode = nn.LSTMCell(input_dim, hidden_size)
        self.decode = nn.LSTMCell(input_dim, hidden_size)
        self.blend_decoder = nn.Linear(hidden_size, blend_dim)
        self.blend_encoder = nn.Linear(hidden_size, blend_dim)
        self.scale_blend = nn.Linear(blend_dim, input_dim)

    def zero_hidden_state(self, batch_size=None):
        """Return a zero tensor of shape (batch_size, hidden_size) on the model's device.

        ``batch_size`` defaults to ``self.batch_size``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        # Allocate next to the parameters so CPU and GPU models both work.
        return self.encode.weight_ih.new_zeros(batch_size, self.hidden_size)

    def forward(self, inp, start_token=20):
        """Encode ``inp`` and decode ``num_of_indices`` pointer distributions.

        Args:
            inp: float tensor of shape (batch, seq_len); column j is fed to the
                encoder as a (batch, 1) input at step j.
            start_token: value fed to the decoder on its first step.

        Returns:
            Tensor of shape (num_of_indices, batch, seq_len). Each [k, b] row is
            a softmax distribution over input positions for pointer k.

        Raises:
            ValueError: if ``inp`` has no time steps.
        """
        batch, seq_len = inp.size(0), inp.size(1)
        if seq_len == 0:
            raise ValueError("inp must contain at least one time step")

        hidden = self.zero_hidden_state(batch)
        cell_state = self.zero_hidden_state(batch)
        encoder_states = []
        for j in range(seq_len):
            hidden, cell_state = self.encode(inp[:, j:j + 1], (hidden, cell_state))
            encoder_states.append(cell_state)

        # The encoder projection does not depend on the decoder step, so it is
        # computed once: (batch, seq_len, blend_dim).
        encoder_blend = self.blend_encoder(torch.stack(encoder_states, dim=1))

        decoder_input = inp.new_full((batch, self.input_dim), start_token)
        pointer_distributions = []
        for _ in range(self.num_of_indices):
            # Decoder state starts from the encoder's final state and carries
            # over between steps. Resetting it to zero every step threw that
            # context away.
            hidden, cell_state = self.decode(decoder_input, (hidden, cell_state))
            decoder_blend = self.blend_decoder(cell_state).unsqueeze(1)

            # tanh is required. Without a non-linearity, the decoder term is a
            # per-row constant that softmax cancels, so every pointer would get
            # the same distribution.
            scores = self.scale_blend(torch.tanh(encoder_blend + decoder_blend)).squeeze(2)
            index_distribution = F.softmax(scores, dim=1)
            pointer_distributions.append(index_distribution)

            # Feed back the input value at the most likely position, shape (batch, 1).
            index = index_distribution.argmax(dim=1, keepdim=True)
            decoder_input = inp.gather(1, index)

        return torch.stack(pointer_distributions)
    

In [None]:
from torch import optim
from torch.autograd import Variable
def train(epochs, model, train_batches, print_every=100, lr=1, momentum=0.1):
    """Fit ``model`` with SGD on the RMSE between its output and the target.

    Args:
        epochs: number of passes over ``train_batches``.
        model: module whose output for ``data`` has the same shape as ``target``.
        train_batches: re-iterable collection of (data, target) tensor pairs.
        print_every: print the last batch's loss every this many epochs.
        lr: SGD learning rate (default matches the original setup).
        momentum: SGD momentum (default matches the original setup).

    Returns:
        The loss of the last batch processed as a float, or None if no batch ran.
    """
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    # Send batches to wherever the model's parameters live (CPU or GPU).
    device = next(model.parameters()).device
    last_loss = None
    for epoch in range(epochs):
        for data, target in train_batches:
            data, target = data.to(device), target.to(device)
            index_distributions = model(data)
            loss = torch.sqrt(torch.mean(torch.pow(index_distributions - target, 2)))

            # Clear the previous step's gradients. Otherwise they accumulate and
            # every update applies the sum of all past gradients.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            last_loss = loss.item()

        if epoch % print_every == 0 and last_loss is not None:
            print('epoch: {} -- loss: {}'.format(epoch, last_loss))
    return last_loss

In [None]:
def embedding_lookup(embeddings, indices):
    """Gather rows of ``embeddings`` for every entry of ``indices``.

    The result has shape ``indices.size() + embeddings.size()[1:]``.
    """
    flat_rows = embeddings.index_select(0, indices.view(-1))
    target_shape = tuple(indices.size()) + tuple(embeddings.size()[1:])
    return flat_rows.view(*target_shape)

In [None]:
# Training set: one batch of nested sequences, padded to max_length. Segments
# here are longer (11-20) than in the test set (5-10).
train_segment_length_min = 11
train_segment_length_max = 20

seqs = []
start_indices = []
end_indices = []

for i in range(batch_size):
    seq, start, end = generate_nested_sequence(max_length,
                                               train_segment_length_min,
                                               train_segment_length_max)

    start_, end_ = create_one_hot(max_length, start), create_one_hot(max_length, end)
    seqs.append(seq)
    start_indices.append(start_)
    end_indices.append(end_)

# Stack the one-hot rows with NumPy first. Building a tensor from a Python list
# of ndarrays is very slow, and recent PyTorch warns about it.
seqs          = torch.Tensor(seqs)
start_indices = torch.from_numpy(np.stack(start_indices)).float()
end_indices   = torch.from_numpy(np.stack(end_indices)).float()
assert seqs.shape == (batch_size, max_length)

# Targets have shape (2, batch_size, max_length): [start one-hots, end one-hots].
indices = torch.stack([start_indices, end_indices])
train_batches = [(seqs, indices)]


In [None]:
# Pointer network: 1 input feature per step, 6 hidden units, 2 pointers
# (start and end of the high run), 6-dim attention space.
sample_input, sample_output = train_batches[0]
# Use the GPU when one is available instead of hard-requiring CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(input_dim=1, hidden_size=6, num_of_indices=2, blend_dim=6,
              batch_size=batch_size).to(device)
model

In [None]:
# Full-batch training; 40k epochs is slow, so reduce `epochs` for quick runs.
train(epochs=40000, model=model, train_batches=train_batches)

In [None]:
# Test set: one batch of nested sequences, padded to max_length. Segments here
# are shorter (5-10) than in the training set (11-20).
test_segment_length_min = 5
test_segment_length_max = 10

seqs = []
start_indices = []
end_indices = []

for i in range(batch_size):
    seq, start, end = generate_nested_sequence(max_length,
                                               test_segment_length_min,
                                               test_segment_length_max)

    start_, end_ = create_one_hot(max_length, start), create_one_hot(max_length, end)
    seqs.append(seq)
    start_indices.append(start_)
    end_indices.append(end_)

# Stack the one-hot rows with NumPy first. Building a tensor from a Python list
# of ndarrays is very slow, and recent PyTorch warns about it.
seqs          = torch.Tensor(seqs)
start_indices = torch.from_numpy(np.stack(start_indices)).float()
end_indices   = torch.from_numpy(np.stack(end_indices)).float()
assert seqs.shape == (batch_size, max_length)

# Targets have shape (2, batch_size, max_length): [start one-hots, end one-hots].
indices = torch.stack([start_indices, end_indices])
test_batches = [(seqs, indices)]

In [None]:
# Evaluate on the held-out batches: RMSE loss, then pointer accuracy. A pointer
# counts as correct only if its distribution is within 0.5 of the one-hot target
# at every position (a bit stricter than argmax).
device = next(model.parameters()).device
examples_to_show = 59  # pretty-print only the first few examples per batch
correct_pointers = 0
total_pointers = 0

with torch.no_grad():
    for data, target in test_batches:
        data, target = data.to(device), target.to(device)
        index_distributions = model(data)
        loss = torch.sqrt(torch.mean(torch.pow(index_distributions - target, 2)))
        print('loss: {}'.format(loss.item()))

        # Worst per-position error for each (pointer, example), shape (2, batch).
        max_error = (index_distributions - target).abs().max(dim=2)[0]
        correct = max_error < .5
        correct_pointers += int(correct.sum().item())
        total_pointers += correct.numel()

        predicted = index_distributions.argmax(dim=2).cpu()
        expected = target.argmax(dim=2).cpu()
        print('expected  start/end: {} / {}'.format(expected[0, :10].tolist(),
                                                    expected[1, :10].tolist()))
        print('predicted start/end: {} / {}'.format(predicted[0, :10].tolist(),
                                                    predicted[1, :10].tolist()))

        for b in range(min(examples_to_show, data.size(0))):
            print_pointer([int(v) for v in data[b].tolist()],
                          int(predicted[0, b]), int(predicted[1, b]))

test_pct = (np.round(100.0 * correct_pointers / total_pointers, 5)
            if total_pointers else float('nan'))
print("")
print(" %s / %s (correct/total); test pct %s" % (correct_pointers,
                                                 total_pointers,
                                                 test_pct))