In [None]:
import torch
from torch.optim import AdamW
from testbed import TextDataset, Trainer, Net0, Net1, Transformer
from testbed.util import decode_broken_utf8, default_device

In [None]:
# Token embedding dimension. The letter K nominally stands for "keys"; treat it as just a name.
K = 8
# Number of token classes ("C" for "classes").
C = 256

In [None]:
# Architecture built by the configuration cells below: "Net0", "Net1" or "Transformer".
model_type = "Net0"
# Fail fast on a typo; otherwise no config cell runs and the first symptom is a
# NameError on `model` several cells later.
assert model_type in ("Net0", "Net1", "Transformer"), f"unknown model_type: {model_type!r}"

In [None]:
# Net0 configuration: the example length is one more than the kernel length (N = L + 1).
if model_type == "Net0":
    H = 2 * 8192   # hidden neurons, i.e. number of convolution kernels (smaller alternative: 2**10)
    L = 2 * 128    # convolution kernel length (smaller alternative: 128)
    N = L + 1      # example length handed to TextDataset
    B = 8192       # batch size (examples per batch)
    model = Net0(H=H, L=L, K=K, C=C)
    model = model.to(default_device())
    dataset = TextDataset(N=N)
    optimizer = AdamW(model.parameters())

In [None]:
# Net 1
#   * N > L
if model_type == "Net1":
    H = 4096 # number of hidden neurons (i.e. number of convolution kernels)
    L = 1024 # length of convolution kernel
    N = 2*L
    B = 32 # batch size (i.e. examples per batch)
    model = Net1(H=H, L=L, K=K, C=C).to(default_device())
    dataset = TextDataset(N=N)
    optimizer = AdamW(model.parameters())

In [None]:
# Transformer configuration: examples of length N = 8.
if model_type == "Transformer":
    N = 8
    # The Trainer cell requires a batch size B; without it, selecting this model
    # fails there with NameError.
    B = 32  # NOTE(review): placeholder batch size (same as Net1) -- tune for this model
    model = Transformer().to(default_device())
    dataset = TextDataset(N=N)
    optimizer = AdamW(model.parameters())

In [None]:
# Total number of scalar parameters across all of the model's parameter tensors.
model_size = sum(param.numel() for param in model.parameters())
print(model_size)

In [None]:
# Size of the dataset (presumably the number of length-N examples TextDataset provides -- confirm).
len(dataset)

In [None]:
# Bundle the model, example length N, batch size B, data and optimizer into a Trainer.
trainer = Trainer(
    model=model,
    N=N,
    B=B,
    dataset=dataset,
    optimizer=optimizer,
)

In [None]:
# Restore saved trainer state. NOTE(review): presumably reads what trainer.save() wrote;
# confirm what happens when no saved state exists yet.
trainer.load()

In [None]:
# Clear the loss history and the compute-time counter so the plots below start empty.
trainer.losses, trainer.compute_time = [], 0.0

In [None]:
# Begin training. NOTE(review): later cells keep using the trainer and eventually call
# trainer.stop(), which suggests training runs in the background -- confirm.
trainer.start()

In [None]:
# Plotting and threading dependencies for the live loss monitor defined below.
from bokeh.io import push_notebook, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource
from bokeh.plotting import figure 
output_notebook()  # configure bokeh to render figures inline in this notebook
import time, math
import numpy as np
from threading import Thread, Lock  # NOTE(review): Lock and ColumnDataSource look unused in the visible cells

class StatsTicker:
    """Live Bokeh line plot of a trainer's loss history.

    Loss records come from ``trainer.update_losses()``. Each record ``x`` is
    plotted with ``x[1]`` on the time axis and ``8 * x[2] / ln(256)`` as the
    value. NOTE(review): the record layout is defined by ``Trainer``; confirm
    against its implementation.

    ``display()`` draws the figure. ``start()``/``stop()`` control a daemon
    thread that, once per second, streams any new records into that figure.
    """

    def __init__(self, trainer):
        self.trainer = trainer
        self.tick = 0  # number of loss records already drawn in the current figure
        self.losses = self.trainer.update_losses()
        self.bokeh = {}  # "figure" -> bokeh figure, "mean_loss" -> its line renderer
        self.bokeh_handle = None  # notebook handle passed to push_notebook
        self.updating = False  # True while the updater thread should keep running

    def recent_stats(self):
        """Refresh ``self.losses`` and return the records not drawn yet.

        Returns:
            dict with equal-length lists ``'time'`` and ``'mean_loss'`` built
            from ``self.losses[self.tick:]``.
        """
        self.losses = self.trainer.update_losses()
        # Slice once so both columns are built from the same snapshot of records.
        new_records = self.losses[self.tick:]
        data = {'time' : [x[1] for x in new_records],
                'mean_loss' : [8*x[2]/math.log(256) for x in new_records]}
        return data

    def display(self, updates=True):
        """Draw a new figure containing the full loss history.

        Args:
            updates: if True, start the background updater after drawing. An
                updater that was already running is resumed regardless.
        """
        # Pause a running updater so it cannot stream into the old renderer
        # (or advance self.tick) while the new figure is being built.
        resume = self.updating
        self.stop()
        # Redraw from the first record; otherwise a repeated display() would only
        # show records that arrived after the previous call.
        self.tick = 0
        TOOLS="pan,wheel_zoom,box_zoom,reset"
        self.bokeh["figure"] = figure(tools=TOOLS)
        self.bokeh["figure"].axis.major_label_text_font_size = "24px"
        hover = HoverTool(show_arrow=False,
                          mode='vline',
                          line_policy='next',
                          tooltips=[('X_value', '$data_x'),
                                    ('Y_value', '$data_y')])
        self.bokeh["figure"].add_tools(hover)
        data = self.recent_stats()
        self.bokeh["mean_loss"] = self.bokeh["figure"].line(data['time'],data['mean_loss'])
        self.tick = len(self.losses)
        self.bokeh_handle = show(self.bokeh["figure"], notebook_handle=True)
        if updates or resume:
            self.start()

    def start(self):
        """Start the daemon updater thread unless one is already running."""
        if not self.updating:
            self.updating = True
            self.updater = Thread(target=StatsTicker._update_loop, args=(self,), daemon=True)
            self.updater.start()

    def stop(self):
        """Ask the updater thread to exit and wait for it to finish."""
        if self.updating:
            self.updating = False
            self.updater.join()

    def _update_loop(self):
        """Thread body: once per second, stream new records into the plot."""
        try:
            while self.updating:
                time.sleep(1)
                data = self.recent_stats()
                if len(self.losses) > self.tick:
                    # figure.line(x_list, y_list) stores its data in columns 'x' and 'y'.
                    self.bokeh["mean_loss"].data_source.stream({'x':data['time'],
                                                                'y':data['mean_loss']})
                    self.tick = len(self.losses)
                    # Only push when something changed.
                    push_notebook(handle=self.bokeh_handle)
        finally:
            # If the loop dies (e.g. an exception from the trainer or bokeh), clear
            # the flag so start() can spawn a fresh updater later.
            self.updating = False

In [None]:
# Create the live loss plotter; its constructor fetches the current loss history via trainer.update_losses().
ticker = StatsTicker(trainer)

In [None]:
# Draw the loss plot and start the background thread that streams new points into it.
ticker.display(updates=True)

In [None]:
# Switch to a larger batch: 3 * 8192 = 24576 examples per batch.
trainer.set_batch_size(3 * 8192)

In [None]:
# Show the 10 most recent loss records the ticker has fetched from the trainer.
ticker.losses[-10:]

In [None]:
# Persist trainer state. NOTE(review): presumably what trainer.load() reads back -- confirm.
trainer.save()

In [None]:
def smoother(data, lag):
    """Trailing moving average of ``data`` over windows of ``lag`` samples.

    Element ``j`` of the result is ``mean(data[j+1 : j+lag+1])``: the average of
    the ``lag`` values ending at index ``j + lag``. The result therefore lines up
    with ``x[lag:]`` for an x sequence of the same length as ``data``.

    Args:
        data: 1-D sequence of numbers.
        lag: window length; must be >= 1.

    Returns:
        numpy array of length ``max(len(data) - lag, 0)``.

    Raises:
        ValueError: if ``lag < 1``. A zero lag would otherwise fail with an
            obscure broadcasting error, and a negative lag would return garbage.
    """
    if lag < 1:
        raise ValueError(f"lag must be >= 1, got {lag!r}")
    cs = np.cumsum(np.asarray(data))
    # cs[i] - cs[i - lag] is the sum of the lag values ending at index i.
    return (cs[lag:] - cs[:-lag]) / lag

In [None]:
# Static plot of the whole loss history, optionally smoothed with a trailing moving average.
TOOLS="crosshair,pan,wheel_zoom,box_zoom,reset,tap,box_select,lasso_select"
logfig = figure(tools=TOOLS, x_axis_label="time", y_axis_label="8 * loss / ln(256)")
logfig.axis.major_label_text_font_size = "24px"
hover = HoverTool(tooltips=None, mode="vline")
logfig.add_tools(hover)
lag = 0  # moving-average window in loss records; 0 disables smoothing
times = [x[1] for x in trainer.losses]
raw_losses = [x[2] for x in trainer.losses]
if lag > 0:
    X = times[lag:]
    Y = [8*u / math.log(256) for u in smoother(raw_losses, lag)]
else:
    X = times
    # Same rescaling as the smoothed branch (and StatsTicker), so the y-axis
    # units no longer depend on whether smoothing is enabled.
    Y = [8*u / math.log(256) for u in raw_losses]
logline = logfig.line(X,Y)
bokeh_handle = show(logfig)

In [None]:
# Call trainer.autocomplete(); the trailing semicolon keeps the cell from displaying a result value.
trainer.autocomplete();

In [None]:
# Print every parameter tensor with its shape and its maximum and minimum entries.
for param in trainer.model.parameters():
    print(param.shape, param, param.max().item(), param.min().item())

In [None]:
# Stop training. NOTE(review): the StatsTicker updater runs on its own thread; call
# ticker.stop() as well to halt plot updates.
trainer.stop()

In [None]:
# Print the model's scalar output (via .item()) for the first examples of the dataset,
# without tracking gradients.
n_examples = min(1000, len(dataset))  # avoid IndexError on datasets shorter than 1000
# Send inputs to wherever the model's parameters live. The hard-coded .cuda() here broke
# on CPU/MPS, and disagreed with the default_device() placement used when building the model.
device = next(model.parameters()).device
with torch.no_grad():
    for i in range(n_examples):
        print(model(dataset[i].view(1,-1).long().to(device)).item())