In [1]:
import torchtext, random, torch

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from box_wrapper import DeltaBoxTensor
from modules import BoxEmbedding

import numpy as np
from tqdm import tqdm_notebook

global use_cuda
use_cuda = torch.cuda.is_available()
device = 0 if use_cuda else -1

TEXT = torchtext.data.Field()
train, val, test = torchtext.datasets.LanguageModelingDataset.splits(path="../data", train="train.txt", validation="valid.txt", test="valid.txt", text_field=TEXT)
TEXT.build_vocab(train, max_size=1000) if False else TEXT.build_vocab(train)
TEXT.vocab.load_vectors('glove.840B.300d')
train_iter, val_iter, test_iter = torchtext.data.BPTTIterator.splits((train, val, test), batch_size=10, bptt_len=10, repeat=False)

In [7]:
class Trainer:
    def __init__(self, train_iter, val_iter):
        self.train_iter = train_iter
        self.val_iter = val_iter
        
    def string_to_batch(self, string):
        relevant_split = string.split() # last two words, ignore ___
        ids = [self.word_to_id(word) for word in relevant_split]
        if use_cuda:
            return Variable(torch.LongTensor(ids)).cuda()
        else:
            return Variable(torch.LongTensor(ids))
        
    def word_to_id(self, word, TEXT = TEXT):
        return TEXT.vocab.stoi[word]
    
    def batch_to_input(self, batch):
        ngrams = self.collect_batch_ngrams(batch)
        x = Variable(torch.LongTensor([ngram[:-1] for ngram in ngrams]))
        y = Variable(torch.LongTensor([ngram[-1] for ngram in ngrams]))
        if use_cuda:
            return x.cuda(), y.cuda()
        else:
            return x, y
    
    def collect_batch_ngrams(self, batch, n = 5):
        data = torch.flatten(batch.text.T)
        return [tuple(data[idx:idx + n]) for idx in range(0, len(data) - n + 1)]
    
    def train_model(self, model, num_epochs):
        parameters = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = torch.optim.Adam(params = parameters, lr=1e-1)
        criterion = nn.NLLLoss()
        
        for epoch in tqdm_notebook(range(num_epochs)):
            epoch_loss = []
#             hidden = model.init_hidden()
            model.train()
            count = 0
            for batch in tqdm_notebook(train_iter):
                x, y = self.batch_to_input(batch)
                if use_cuda: x, y = x.cuda(), y.cuda()
                optimizer.zero_grad()
                y_pred = model.forward(x, train = True)
                loss = criterion(y_pred, y)
                loss.backward()
                optimizer.step()
                epoch_loss.append(loss.data.item())
                count += 1
                if count > 2: break
            model.eval()
            train_ppl = np.exp(np.mean(epoch_loss))
#             val_ppl = self.validate(model)
            val_ppl = 0

            print('Epoch {0} | Loss: {1} | Train PPL: {2} | Val PPL: {3}'.format(epoch+1, np.mean(epoch_loss), train_ppl,  val_ppl))
    
        print('Model trained.')
        self.write_kaggle(model)
        print('Output saved.')
        
    def validate(self, model):
        criterion = nn.NLLLoss()
        aggregate_loss = []
        for batch in self.val_iter:
            x, y = self.batch_to_input(batch)
            if use_cuda: x, y = x.cuda(), y.cuda()
            y_p = model.forward(x, train = False)
            loss = criterion(y_p, y)
            aggregate_loss.append(loss.data.item())        
        val_ppl = np.exp(np.mean(aggregate_loss))
        return val_ppl
    
#     def predict_sentence(self, string, model, TEXT = TEXT):
#         string = string[:-4]
#         model.batch_size = 1
#         hidden = model.init_hidden()
#         x = self.string_to_batch(string)
#         logits, _ = model.forward(x, hidden, train = False)
#         argsort_ids = np.argsort(logits[-1].data.tolist())
#         out_ids = argsort_ids[-20:][::-1]
#         out_words = ' '.join([TEXT.vocab.itos[out_id] for out_id in out_ids])
#         return out_words

In [8]:
class BoxModel(nn.Module):
    box_types = {
        'DeltaBoxTensor': DeltaBoxTensor,
    }
    def __init__(self, TEXT = TEXT, batch_size = 10, n_gram=4):
        super(BoxModel, self).__init__()
        self.batch_size = batch_size
        self.n_gram = n_gram
        self.vocab_size, self.embedding_dim = TEXT.vocab.vectors.shape
        self.embeddings_word = BoxEmbedding(self.vocab_size, self.embedding_dim, box_type='DeltaBoxTensor')
        self.embedding_bias = nn.Embedding(self.vocab_size, 1)
        self.embedding_bias.weight.data = torch.zeros(self.vocab_size, 1)
    
    def forward(self, x, train = True):
        """ predict, return hidden state so it can be used to intialize the next hidden state """
        context_word_boxes = self.embeddings_word(x)
        lm_batch_size = x.shape[0]
        context_word_boxes.data = torch.mean(context_word_boxes.data, dim=1).view(-1,1,2,self.embedding_dim)
        all_word = self.embeddings_word(torch.arange(self.vocab_size))
        all_word.data = all_word.data.view(1, self.vocab_size, 2,self.embedding_dim)
#         all_word.data = all_word.data.view(-1,1,2,self.embedding_dim)
        dec = all_word.intersection_log_soft_volume(context_word_boxes)
        decoded = dec + self.embedding_bias(torch.arange(self.vocab_size)).view(-1)
        logits = F.log_softmax(decoded, dim = 1)       
        return logits

In [None]:
model = BoxModel()
if use_cuda:
    model.cuda()
trainer = Trainer(train_iter = train_iter, val_iter = val_iter)
trainer.train_model(model = model, num_epochs = 40)

HBox(children=(IntProgress(value=0, max=40), HTML(value='')))

HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 1 | Loss: 28.31008752187093 | Train PPL: 1972035791655.3457 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 2 | Loss: 18.96613661448161 | Train PPL: 172539475.9741962 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))


Epoch 3 | Loss: 18.243431727091473 | Train PPL: 83757118.92038861 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 4 | Loss: 16.99761740366618 | Train PPL: 24097469.758281067 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))


Epoch 5 | Loss: 15.386950492858887 | Train PPL: 4813574.4397103805 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 6 | Loss: 13.401597658793131 | Train PPL: 661058.5275040754 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))


Epoch 7 | Loss: 11.533591270446777 | Train PPL: 102088.08218423842 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 8 | Loss: 10.075883547465006 | Train PPL: 23762.964814927014 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))


Epoch 9 | Loss: 9.019209861755371 | Train PPL: 8260.247764401127 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 10 | Loss: 8.248806794484457 | Train PPL: 3823.0614008857647 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))







Epoch 11 | Loss: 7.428682168324788 | Train PPL: 1683.587425959763 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 12 | Loss: 6.973869164784749 | Train PPL: 1068.3483799507499 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))


Epoch 13 | Loss: 6.4639045397440595 | Train PPL: 641.561173590667 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 14 | Loss: 6.141090075174968 | Train PPL: 464.5596999459836 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))


Epoch 15 | Loss: 5.738229433695476 | Train PPL: 310.514138107831 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 16 | Loss: 5.593597571055095 | Train PPL: 268.70055229558915 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))


Epoch 17 | Loss: 5.125375429789226 | Train PPL: 168.23729108774697 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 18 | Loss: 4.946265856424968 | Train PPL: 140.6487793727388 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))


Epoch 19 | Loss: 4.733453750610352 | Train PPL: 113.68753346187538 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 20 | Loss: 4.467822313308716 | Train PPL: 87.1666944268005 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))


Epoch 21 | Loss: 4.178085088729858 | Train PPL: 65.24080317308635 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 22 | Loss: 4.250153462092082 | Train PPL: 70.11617169548782 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))


Epoch 23 | Loss: 3.8012410004933677 | Train PPL: 44.75669312130451 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 24 | Loss: 3.631168524424235 | Train PPL: 37.75691072200749 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))


Epoch 25 | Loss: 3.444472869237264 | Train PPL: 31.32676579162402 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 26 | Loss: 3.3478055795033774 | Train PPL: 28.440255237868836 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))


Epoch 27 | Loss: 3.1177481015523276 | Train PPL: 22.595439673357586 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

Epoch 28 | Loss: 2.758084694544474 | Train PPL: 15.769610384404467 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))


Epoch 29 | Loss: 2.5207928816477456 | Train PPL: 12.438454977500959 | Val PPL: 0


HBox(children=(IntProgress(value=0, max=9296), HTML(value='')))

In [None]:
v = torch.zeros((1,7,2,10))
q = torch.ones((3,1,2,10))

In [None]:
torch.min(v, q).shape

In [None]:
v = v.view(-1,7,2,10).repeat(3, 1, 1, 1)

In [None]:
torch.max(v, q).shape