In [3]:
batch_size = 1024
max_length = 60


In [4]:
#!/usr/bin/env python
from pprint import pprint
import math
import numpy as np
import random
import sys

import torch
import torch.nn.functional as F
# See https://medium.com/@devnag/pointer-networks-in-tensorflow-with-sample-code-14645063f264

# Uncomment this to stop corner printing and see full/verbatim
#np.set_printoptions(threshold=np.nan)


def generate_nested_sequence(length, min_seglen=5, max_seglen=10):
    """Generate low-high-low sequence, with indexes of the first/last high/middle elements"""

    # Low (1-5) vs. High (6-10)
    seq_before = [(random.randint(1,5)) for x in range(random.randint(min_seglen, max_seglen))]
    seq_during = [(random.randint(6,10)) for x in range(random.randint(min_seglen, max_seglen))]
    seq_after = [random.randint(1,5) for x in range(random.randint(min_seglen, max_seglen))]
    seq = seq_before + seq_during + seq_after

    # Pad it up to max len with 0's
    seq = seq + ([0] * (length - len(seq)))
    return [seq, len(seq_before), len(seq_before) + len(seq_during)-1]


def create_one_hot(length, index):
    """Returns 1 at the index positions; can be scaled by client"""
    a = np.zeros([length])
    a[index] = 1.0
    return a


def get_lstm_state(cell):
    """Centralize definition of 'state', to swap .c and .h if desired"""
    return cell.c


def print_pointer(arr, first, second):
    """Pretty print the array, along with pointers to the first/second indices"""
    first_string = " ".join([(" " * (2 - len(str(x))) + str(x)) for x in arr])
    print(first_string)
    second_array = ["  "] * len(arr)
    second_array[first] = "^1"
    second_array[second] = "^2"
    if (first == second):
        second_array[first] = "^B"
    second_string = " " + " ".join([x for x in second_array])
    print(second_string)


In [103]:
from torch import nn
from torch.autograd import Variable
class Model(nn.Module):
    def __init__(self, input_dim, hidden_size, num_of_indices, blend_dim, batch_size):
        super(Model, self).__init__()
        
        self.batch_size = batch_size               # B
        self.input_dim = input_dim                 # I
        self.hidden_size = hidden_size             # H
        self.num_of_indices = num_of_indices       # N
        self.blend_dim = blend_dim                 # D
                
        self.encode = nn.LSTMCell(input_dim, hidden_size)
        self.decode = nn.LSTMCell(input_dim, hidden_size)
        self.blend_decoder = nn.Linear(hidden_size, blend_dim)
        self.blend_encoder = nn.Linear(hidden_size, blend_dim)
        self.scale_blend = nn.Linear(blend_dim, input_dim)
        
    def zero_hidden_state(self):
        return Variable(torch.randn([self.batch_size, self.hidden_size]).cuda())
        
    def forward(self, inp):
        #TODO - zero 
        hidden = self.zero_hidden_state()                                            # BxH
        cell_state = self.zero_hidden_state()                                        # BxH
        encoder_states = []
        for j in range(len(inp[0])):                                          # inp -> BxJxI
            encoder_input = inp[:, j:j+1]                                            # BxI
            hidden, cell_state = self.encode(encoder_input, (hidden, cell_state)) 
            encoder_states.append(cell_state)
            
        decoder_state = encoder_states[-1]                       # BxH
        pointers = []
        pointer_distributions = []
        
        start_token = 0
        decoder_input = Variable(torch.Tensor([start_token] * self.batch_size)        # BxI
                                 .view(self.batch_size, self.input_dim).cuda())

        for i in range(self.num_of_indices):
            hidden = self.zero_hidden_state()                                         # BxH
            cell_state = self.zero_hidden_state()                                     # BxH
            hidden, cell_state = self.decode(decoder_input, (hidden, cell_state))     # BxH
            
            decoder_blend = self.blend_decoder(cell_state)                            # BxD
            encoder_blends = []
            index_predists = []
            for i in range(len(inp[0])):
                encoder_blend = self.blend_encoder(encoder_states[i])                  # BxD
                raw_blend = encoder_blend + decoder_blend                              # BxD
                scaled_blend = self.scale_blend(raw_blend).squeeze(1)                  # BxI
                
                index_predist = scaled_blend
                
                encoder_blends.append(encoder_blend)
                index_predists.append(index_predist)
                
            index_predistribution = torch.stack(index_predists).t()                    # BxJ
            index_distribution = F.log_softmax(index_predistribution)
            pointer_distributions.append(index_distribution)                          
            index = index_distribution.data.max(1)[1].squeeze(1)                       # B

            emb = embedding_lookup(inp.t(), Variable(index))                           # BxB
            pointer_raw = torch.diag(emb)                                              # B
            pointer = pointer_raw
            pointers.append(pointer)
            decoder_input = pointer.unsqueeze(1)                                       # Bx1

            #print('pointer: {}'.format(pointers))
        index_distributions = torch.stack(pointer_distributions)                    
        return index_distributions                                                     # NxBxJ
    

In [107]:
from torch import optim
from torch.autograd import Variable
def train(epochs, model,  train_batches, print_every = 100):
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.1)
    for epoch in range(epochs):
        for batch, (data, target) in enumerate(train_batches):

            data, target = Variable(data.cuda()), Variable(target.cuda())
            optimizer.zero_grad()

            index_distributions = model(data)
            diff = index_distributions -  target
            #print(diff[:,:3])
            loss = torch.sqrt(
                        torch.mean(
                            torch.pow(diff, 2)
                        )
            )
            loss.backward()
            optimizer.step()

        if epoch % print_every == 0:
            #print([i for i in model.parameters()])
            print('epoch: {} -- loss: {}'.format(epoch, loss.data[0]))

In [8]:
def embedding_lookup(embeddings, indices):
    result =  embeddings.index_select(0, indices.view(-1))
    return result.view(*(indices.size() + embeddings.size()[1:]))

In [9]:
train_segment_length_min = 5
train_segment_length_max = 7

seqs = []
start_indices = []
end_indices = []

for i in range(batch_size):
    seq, start, end = generate_nested_sequence(max_length, 
                                                train_segment_length_min, 
                                                train_segment_length_max)
    
    start_, end_ = create_one_hot(max_length, start),  create_one_hot(max_length, end)
    seqs.append(seq), start_indices.append(start_), end_indices.append(end_)

#print(len(seqs))
#pprint([len(seq) for seq in seqs])
    
seqs          = torch.Tensor(seqs)
start_indices = torch.Tensor(start_indices)
end_indices   = torch.Tensor(end_indices)
indices = torch.stack([start_indices, end_indices])
train_batches = [(seqs, indices),]  


In [105]:
sample_input, sample_output = train_batches[0]
model = Model(1, 6, 2, 6, batch_size)
model.cuda()

Model (
  (encode): LSTMCell(1, 6)
  (decode): LSTMCell(1, 6)
  (blend_decoder): Linear (6 -> 6)
  (blend_encoder): Linear (6 -> 6)
  (scale_blend): Linear (6 -> 1)
)

In [108]:
train(4000, model, train_batches)

epoch: 0 -- loss: 4.112970352172852
epoch: 100 -- loss: 4.112970352172852
epoch: 200 -- loss: 4.112970352172852
epoch: 300 -- loss: 4.112970352172852
epoch: 400 -- loss: 4.112970352172852
epoch: 500 -- loss: 4.112969875335693
epoch: 600 -- loss: 4.112970352172852
epoch: 700 -- loss: 4.112969875335693
epoch: 800 -- loss: 4.112970352172852
epoch: 900 -- loss: 4.112969875335693
epoch: 1000 -- loss: 4.112969875335693
epoch: 1100 -- loss: 4.112969875335693
epoch: 1200 -- loss: 4.112970352172852
epoch: 1300 -- loss: 4.112970352172852
epoch: 1400 -- loss: 4.112970352172852
epoch: 1500 -- loss: 4.112969875335693
epoch: 1600 -- loss: 4.112970352172852
epoch: 1700 -- loss: 4.112969875335693
epoch: 1800 -- loss: 4.112970352172852
epoch: 1900 -- loss: 4.112969875335693
epoch: 2000 -- loss: 4.112969875335693
epoch: 2100 -- loss: 4.112969875335693
epoch: 2200 -- loss: 4.112969875335693
epoch: 2300 -- loss: 4.112969875335693
epoch: 2400 -- loss: 4.112969875335693
epoch: 2500 -- loss: 4.11296987533569

In [None]:
test_segment_length_min = 2
test_segment_length_max = 4

seqs = []
start_indices = []
end_indices = []

for i in range(batch_size):
    seq, start, end = generate_nested_sequence(max_length, 
                                                test_segment_length_min, 
                                                test_segment_length_max)
    
    start_, end_ = create_one_hot(max_length, start),  create_one_hot(max_length, end)
    seqs.append(seq), start_indices.append(start_), end_indices.append(end_)

#print(len(seqs))
#pprint([len(seq) for seq in seqs])
    
seqs          = torch.Tensor(seqs)
start_indices = torch.Tensor(start_indices)
end_indices   = torch.Tensor(end_indices)
indices = torch.stack([start_indices, end_indices])
test_batches = [(seqs, indices),]  

In [27]:
for data, target in train_batches:
    data, target = Variable(data.cuda()), Variable(target.cuda())
    index_distributions = model(data)
    loss = torch.sqrt(
                torch.mean(
                    torch.pow(index_distributions - target, 2)
                )
            )
    
    print('loss: {}'.format(loss.data[0]))

print(data[:10])
print(target[0].data.max(1)[1][:10])
print(target[1].data.max(1)[1][:10])
print(index_distributions[0].data.max(1)[1][:10])
print(index_distributions[1].data.max(1)[1][:10])
incorrect_pointers = 0

results = index_distributions.data.cpu().numpy()
print(index_distributions)
for batch_index in range(batch_size):
    if batch_index >= 59:
        break
    print(results[0][batch_index][1])
    first_diff = start_[batch_index] - results[1][batch_index][0]
    first_diff_max = np.max(np.abs(first_diff))
    print(first_diff, first_diff_max)
    first_ptr = np.argmax(results[1][batch_index][0])
    if first_diff_max >= .5:  # bit stricter than argmax but let's hold ourselves to high standards, people
        incorrect_pointers += 1
    second_diff = end_[batch_index] - results[1][batch_index][1]
    second_diff_max = np.max(np.abs(second_diff))
    second_ptr = np.argmax(results[1][batch_index][1])
    if second_diff_max >= .5:
        incorrect_pointers += 1

    print_pointer(seqs[batch_index], first_ptr, second_ptr)
    #print("")

test_pct = np.round(100.0 * ((2 * batch_size) - incorrect_pointers) / (2 * batch_size), 5)
print("")
print(" %s / %s (correct/total); test pct %s" % ((2*batch_size) - incorrect_pointers,
                                                 2 * batch_size,
test_pct))

loss: 0.17192234098911285
Variable containing:

Columns 0 to 12 
    1     2     2     2     2     8     7    10     9     8     3     1     4
    4     2     1     4     3     6     6    10    10     6     8     8     1
    5     2     5     1     2    10     6     7     9     8     8     4     1
    5     4     2     3     3     6     8    10     7    10     7     3     2
    2     4     3     1     4     9     7     8     6     6     3     3     3
    2     3     4     3     2     5     5     7     7     7     6     8     5
    5     1     5     1     1     4     2     9     8    10     6    10     5
    4     3     3     2     1     9     9    10    10    10     7     4     4
    4     2     1     4     1     5     7     7     7     6    10     4     5
    1     1     5     3     4     4     2    10     9     6     7     8     6

Columns 13 to 25 
    3     4     4     4     0     0     0     0     0     0     0     0     0
    2     4     5     2     0     0     0     0     0     

RuntimeError: can't convert CUDA tensor to numpy (it doesn't support GPU arrays). Use .cpu() to move the tensor to host memory first.