In [1]:
import skorch
import torch
import torch.nn as nn

import enwik8_data

In [2]:
import visdom
# Visdom client for live plots. The model built further down is configured
# with module__visdom=None, so nothing is plotted unless this client is
# passed in instead (module__visdom=vis).
vis = visdom.Visdom()

# Load and prepare data

In [3]:
# Load the corpus through the project helper, reading from ./data/.
# The result is unpacked into three splits plus the symbol collection in the
# next cell.
raw_data = enwik8_data.hutter_raw_data(data_path='./data/')

In [4]:
# raw_data holds exactly four items: the train/valid/test splits and the
# collection of symbols whose length is used as the vocabulary size below.
TRAIN_DATA, VALID_DATA, TEST_DATA, unique_syms = raw_data

In [5]:
# NOTE: despite the name this is the vocabulary size (number of entries in
# unique_syms), not the width of the embedding vectors. ReconModel uses it as
# the number of rows of nn.Embedding and as the output width of the classifier.
EMBEDDING_SIZE = len(unique_syms)

In [6]:
def collate(g):
    """Lazily convert an iterable of ``(x, y)`` NumPy pairs into LongTensor pairs.

    Args:
        g: iterable yielding 2-tuples of NumPy arrays.

    Yields:
        ``(x, y)`` tuples where each array has been wrapped with
        ``torch.from_numpy`` and cast to ``torch.long``.
    """
    def _as_long(array):
        # Wrap the array as a tensor, then cast to int64.
        return torch.from_numpy(array).long()

    for inputs, targets in g:
        yield _as_long(inputs), _as_long(targets)

class Enwik8Loader:
    """Iterator factory that stands in for a DataLoader.

    The data actually iterated comes from the ``dataset`` attribute, which this
    base class does not define; subclasses set it as a class attribute. The
    ``_dataset`` argument and any extra keyword arguments are accepted but
    never stored or used.

    Args:
        _dataset: ignored.
        batch_size: forwarded to ``enwik8_data.data_iterator``.
        num_steps: forwarded to ``enwik8_data.data_iterator``.
        max_samples: upper bound of the slice taken from ``dataset``;
            ``None`` means no truncation.
        **kwargs: ignored.
    """

    def __init__(self, _dataset, batch_size=128, num_steps=32, max_samples=None, **kwargs):
        self.batch_size = batch_size
        self.num_steps = num_steps
        self.max_samples = max_samples

    def __iter__(self):
        # Truncate first, then let the project iterator cut batches.
        window = slice(0, self.max_samples)
        batches = enwik8_data.data_iterator(
            self.dataset[window],
            self.batch_size,
            self.num_steps,
        )
        return collate(batches)

class Enwik8TrainLoader(Enwik8Loader):
    """Loader bound to the training split (``TRAIN_DATA``)."""
    dataset = TRAIN_DATA
    
class Enwik8ValidLoader(Enwik8Loader):
    """Loader bound to the validation split (``VALID_DATA``).

    Derives from ``Enwik8Loader`` so that it accepts the dataset and the
    ``batch_size`` / ``num_steps`` / ``max_samples`` arguments it is configured
    with and is iterable. Without the base class it was a bare class that could
    neither be constructed with those arguments nor iterated.
    """
    dataset = VALID_DATA

# Custom Callbacks

In [25]:
import time
import sys

class BatchPrinter(skorch.callbacks.Callback):
    """Write a one-line, carriage-return-terminated progress report to stdout.

    A report is emitted every ``update_interval`` batches and shows the running
    batch index, the number of batches seen in the first completed epoch
    (``None`` until then), the duration of the last batch and its loss.

    The counter is incremented on every ``on_batch_end`` call regardless of the
    ``train`` flag, so whatever batches trigger that hook are all counted.

    Args:
        update_interval: emit a report every this many batches.
    """

    def __init__(self, update_interval=5):
        self.update_interval = update_interval

    def initialize(self):
        """Reset the per-fit state and return ``self``, as skorch callbacks should."""
        self.batches_per_epoch = None
        self.batch_counter = 0
        return self

    def on_batch_begin(self, *args, **kwargs):
        self.batch_start_time = time.time()

    def on_batch_end(self, net, *args, train=True, **kwargs):
        self.batch_end_time = time.time()
        self.batch_counter += 1
        if self.batch_counter % self.update_interval != 0:
            return

        # Pick the history key that matches the phase of the batch just finished.
        k = 'train_loss' if train else 'valid_loss'
        loss = '{}: {:.3}'.format(k, net.history[-1, 'batches', -1, k])

        # Trailing '\r' makes the next report overwrite this one in place.
        sys.stdout.write("Batch {}/{} complete ({:.2}s), {}.\r".format(
            self.batch_counter,
            self.batches_per_epoch,
            self.batch_end_time - self.batch_start_time,
            loss,
        ))
        sys.stdout.flush()

    def on_epoch_end(self, *args, **kwargs):
        # The total is only learned once: after the first full epoch.
        if self.batches_per_epoch is None:
            self.batches_per_epoch = self.batch_counter
        self.batch_counter = 0

# Model definition

In [26]:
def time_flatten(t):
    """Merge the first two axes of ``t`` into one.

    The result is a 2-D view of shape ``(t.size(0) * t.size(1), -1)``; any
    remaining axes are folded into the last one (a 2-D input yields a trailing
    axis of size 1).
    """
    merged_rows = t.size(0) * t.size(1)
    return t.view(merged_rows, -1)

def time_unflatten(t, s):
    """Inverse of ``time_flatten``: view ``t`` as ``(s[0], s[1], -1)``.

    Args:
        t: tensor whose leading axis has ``s[0] * s[1]`` rows.
        s: indexable size; only its first two entries are read.
    """
    first, second = s[0], s[1]
    return t.view(first, second, -1)

In [27]:
class ReconModel(nn.Module):
    """Symbol-level GRU model: embedding -> GRU -> linear -> log-softmax.

    Args:
        num_hidden: width of the embedding vectors and of every GRU layer.
        num_layers: number of stacked GRU layers.
        visdom: optional Visdom client. When truthy, the final hidden state of
            the first GRU layer is drawn as a heatmap in window ``"hidden"``
            on every forward pass.

    NOTE(review): the GRU is built with the default ``batch_first=False``, so
    axis 0 of the input is treated as the time axis — confirm this matches the
    layout produced by ``enwik8_data.data_iterator``.
    """

    def __init__(self, num_hidden=64, num_layers=1, visdom=False):
        super().__init__()

        self.num_layers = num_layers
        self.visdom = visdom
        self.emb = nn.Embedding(EMBEDDING_SIZE, num_hidden)
        self.rnn = nn.GRU(num_hidden, num_hidden, num_layers=num_layers)
        self.clf = nn.Linear(num_hidden, EMBEDDING_SIZE)

        # The classifier output is always 2-D (rows x symbols) at the point the
        # softmax is applied, so normalise explicitly over the last axis
        # instead of relying on the deprecated implicit-dim behaviour.
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return per-position log-probabilities shaped ``(x.size(0), x.size(1), -1)``."""
        x_emb = self.emb(x.long())
        l0, h0 = self.rnn(x_emb)

        if self.visdom:
            # Move to host memory first: .numpy() raises on CUDA tensors.
            self.visdom.heatmap(h0[0].detach().cpu().numpy(), win="hidden")

        # Apply the classifier to every position at once, then restore the
        # two leading axes of the input.
        l1 = self.clf(time_flatten(l0))
        l1_sm = self.softmax(l1)

        return time_unflatten(l1_sm, x.size())

In [35]:
class Trainer(skorch.NeuralNet):
    """``skorch.NeuralNet`` specialised for sequence outputs with an NLL criterion.

    Predictions and targets are flattened over their two leading axes before
    the criterion is applied, so every position counts as one sample.

    Note that ``criterion`` is the first parameter of ``__init__``: the module
    and all other settings should be passed by keyword.
    """

    def __init__(self,
                 criterion=nn.NLLLoss,
                 *args,
                 **kwargs):
        super().__init__(*args, criterion=criterion, **kwargs)

    def get_loss(self, y_pred, y_true, X=None, train=False):
        """Flatten predictions and targets, then defer to the parent loss."""
        pred = time_flatten(y_pred)
        # time_flatten leaves a trailing axis of size 1 on 2-D targets; drop it
        # so the targets are a flat vector of class indices.
        true = time_flatten(y_true).squeeze(-1)
        return super().get_loss(pred, true, X=X, train=train)

    def score(self, *args, **kwargs):
        """Return the negated training loss of the last epoch.

        scikit-learn model selection treats ``score`` as greater-is-better, so
        the loss is negated; returning it unmodified would make a grid search
        pick the configuration with the highest loss. All arguments are
        ignored. When no loss has been recorded yet, ``-42`` is returned as a
        deliberately poor score.
        """
        try:
            return -self.history[-1, 'train_loss']
        except (KeyError, IndexError):
            # IndexError: no epoch recorded; KeyError: epoch without train_loss.
            return -42

# Training

In [67]:
# Seed PyTorch's random number generator so random draws (e.g. parameter
# initialisation) are repeatable from run to run.
torch.manual_seed(1337)

def my_train_split(X, y):
    """Pseudo split that hands back the same ``X`` and ``y`` for both halves.

    Returns:
        ``(X, X, y, y)`` — no data is held out.
    """
    train_X = valid_X = X
    train_y = valid_y = y
    return train_X, valid_X, train_y, valid_y

# Estimator used for training and as the base estimator of the grid search.
ef = Trainer(module=ReconModel,
             optimizer=torch.optim.RMSprop,
             lr=0.002,
             max_epochs=2,

             # Train and validation data come from the custom loaders, not
             # from the arrays passed to fit().
             train_split=my_train_split,
             iterator_train=Enwik8TrainLoader,
             iterator_train__batch_size=128,
             #iterator_train__max_samples=128*300,
             iterator_train__num_steps=32,

             iterator_valid=Enwik8ValidLoader,
             iterator_valid__batch_size=128,
             iterator_valid__num_steps=32,
             #iterator_valid__max_samples=None,

             # Use the GPU when one is present instead of requiring it, so the
             # notebook also runs on CPU-only machines.
             use_cuda=torch.cuda.is_available(),

             module__visdom=None,
             module__num_hidden=256,
             module__num_layers=4,

             callbacks=[BatchPrinter()]
            )

In [68]:
from sklearn.model_selection import GridSearchCV
import numpy as np

In [None]:
# Hyper-parameter grid: every combination of learning rate and hidden width.
params = dict(
    lr=[0.002, 0.02],
    module__num_hidden=[256, 320],
)
m = GridSearchCV(ef, params)

# The custom loaders never read the arrays handed to fit(); these placeholders
# only give scikit-learn something to index when it builds its CV folds.
placeholder_X = np.zeros((10, 1))
placeholder_y = np.zeros((10,))
m.fit(placeholder_X, placeholder_y)

Batch 250/None complete (0.084s), train_loss: 2.91.