In [None]:
import torch
from torch.optim import AdamW
from testbed import TextDataset, Trainer, Net0, Net1, Transformer
from testbed.util import decode_broken_utf8, default_device

In [None]:
C = 256  # vocabulary size: number of distinct tokens (C = "classes")
K = 16   # token embedding dimension (K = "keys"; the letter is just a convention)

In [None]:
# Selects which of the builder cells below constructs `model`, `dataset` and `optimizer`.
# Validate it: an unknown value would make every builder cell a silent no-op and the
# notebook would fail later (or reuse a stale `model` left over in the kernel).
model_type = "Net0"
if model_type not in ("Net0", "Net1", "Transformer"):
    raise ValueError(f"unknown model_type: {model_type!r}")

In [None]:
# Net0 ties the dataset's N to the kernel length: N = L + 1
# (presumably L context tokens plus one target token — confirm in testbed).
if model_type == "Net0":
    B = 32     # batch size: examples per batch
    H = 1024   # hidden neurons, i.e. number of convolution kernels (2**10)
    L = 128    # convolution kernel length
    N = L + 1
    model = Net0(H=H, L=L, K=K, C=C).to(default_device())
    dataset = TextDataset(N=N, B=B)
    optimizer = AdamW(model.parameters())

In [None]:
# Net1 allows the dataset's N to exceed the kernel length (N > L); here N = 2L.
if model_type == "Net1":
    B = 32     # batch size: examples per batch
    H = 1024   # hidden neurons, i.e. number of convolution kernels (2**10)
    L = 1024   # convolution kernel length
    N = L * 2
    model = Net1(H=H, L=L, K=K, C=C).to(default_device())
    dataset = TextDataset(N=N, B=B)
    optimizer = AdamW(model.parameters())

In [None]:
# Transformer: built with the class's own defaults.
# NOTE(review): K and C above are not passed to Transformer() — confirm its defaults match.
if model_type == "Transformer":
    N = 8   # passed to TextDataset as N (presumably tokens per example)
    B = 32  # batch size; must be set here — B is otherwise only defined inside the Net0/Net1 branches
    model = Transformer().to(default_device())
    dataset = TextDataset(N=N, B=B)
    optimizer = AdamW(model.parameters())

In [None]:
# Total element count over every tensor yielded by model.parameters() (buffers are not included).
model_size = sum(p.numel() for p in model.parameters())
print(model_size)

In [None]:
# Size reported by TextDataset.__len__ (whether it counts examples or batches isn't visible here — TODO confirm).
len(dataset)

In [None]:
# Bundle the model, data and optimizer built above into a Trainer.
trainer = Trainer(dataset=dataset, model=model, optimizer=optimizer)

In [None]:
# Begin training. NOTE(review): later cells keep running and polling the trainer,
# which suggests start() returns immediately and trains in the background — confirm.
trainer.start()

In [None]:
import math
import time
from threading import Lock, Thread

import numpy as np
from bokeh.io import output_notebook, push_notebook, show
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.plotting import figure

# Route bokeh's show() output into this notebook rather than a standalone HTML file.
output_notebook()

class StatsTicker:
    """Live bokeh plot of a trainer's loss curve.

    Polls ``trainer`` for new records and converts them into plot points. Each
    record is indexed as ``datum[1]`` (treated as a time value) and ``datum[2]``
    (treated as a loss value); other fields are ignored.

    While fewer than ``lag`` records have been plotted, each point is the
    cumulative mean of all losses so far. After that, each point is the mean of
    the trailing ``lag`` losses (a moving average).

    Args:
        trainer: object exposing ``data`` (the records recorded so far) and
            ``loss()`` (returns records to append). Presumably ``loss()`` returns
            only records not yet handed out — TODO confirm against testbed.Trainer.
        lag: moving-average window, in records. Must be >= 1.

    Raises:
        ValueError: if ``lag`` < 1 (the moving-average branch cannot handle it).
    """

    def __init__(self, trainer, lag=1000):
        if lag < 1:
            raise ValueError(f"lag must be >= 1, got {lag!r}")
        self.trainer = trainer
        self.lag = lag
        self.tick = 0             # number of records already turned into plot points
        self.data = []            # local copy of every record polled so far
        self.bokeh = {}           # "figure" and the "mean_loss" line renderer
        self.bokeh_handle = None  # notebook handle used by push_notebook()
        self.updating = False     # True while the background update thread should run
        self.first_poll = True
        self.updater = None

    def poll(self):
        """Append newly available records from ``self.trainer`` to ``self.data``."""
        if self.first_poll:
            # First call: snapshot everything the trainer has recorded so far.
            self.data = self.trainer.data[:]
            self.first_poll = False
        else:
            self.data += self.trainer.loss()

    def stats(self):
        """Poll, then return plot points for records not returned before.

        Returns:
            dict with ``'time'`` (record time minus the first record's time) and
            ``'mean_loss'``; both empty when nothing new arrived. Advances
            ``self.tick`` so each record is returned at most once.
        """
        lag = self.lag
        self.poll()
        if len(self.data) == self.tick:
            return {'time': [], 'mean_loss': []}
        t0 = self.data[0][1]
        if self.tick < lag:
            # Warm-up: not enough history for a full window, so use the cumulative
            # mean over every record since the start.
            times = np.array([datum[1] for datum in self.data]) - t0
            loss = np.array([datum[2] for datum in self.data])
            mean = np.cumsum(loss) / np.arange(1, len(loss) + 1)
            result = {'time': times[self.tick:], 'mean_loss': mean[self.tick:]}
        else:
            # Moving average: start `lag` records before the first new one so every
            # new point has a full window behind it.
            trailing = self.data[self.tick - lag:]
            times = np.array([datum[1] for datum in trailing[lag:]]) - t0
            cs = np.cumsum([datum[2] for datum in trailing])
            # cs[i + lag] - cs[i] is the sum of the `lag` losses ending at trailing[i + lag].
            mean = (cs[lag:] - cs[:-lag]) / float(lag)
            result = {'time': times, 'mean_loss': mean}
        self.tick = len(self.data)
        return result

    def display(self, updates=True):
        """Create and show the loss figure; optionally start live updates."""
        TOOLS = "crosshair,pan,wheel_zoom,box_zoom,reset,tap,box_select,lasso_select"
        fig = figure(tools=TOOLS)
        fig.axis.major_label_text_font_size = "24px"
        fig.xaxis.axis_label = "time"
        fig.yaxis.axis_label = "mean loss"
        fig.add_tools(HoverTool(tooltips=None, mode="vline"))
        self.bokeh["figure"] = fig
        data = self.stats()
        self.bokeh["mean_loss"] = fig.line(data['time'], data['mean_loss'])
        # notebook_handle=True so push_notebook() can update this output in place.
        self.bokeh_handle = show(fig, notebook_handle=True)
        if updates:
            self.start()

    def start(self):
        """Start the background update thread (no-op if already running)."""
        if not self.updating:
            self.updating = True
            # daemon=True: a ticker that was never stopped must not block kernel exit.
            self.updater = Thread(target=self._update_loop, daemon=True)
            self.updater.start()

    def stop(self):
        """Signal the update thread to finish and wait for it."""
        if self.updating:
            self.updating = False
            self.updater.join()

    def _update_loop(self):
        """Every 0.2 s, stream new points into the plot until stop() is called."""
        while self.updating:
            time.sleep(.2)
            data = self.stats()
            if len(data['time']) == 0:
                continue  # nothing new; skip a redundant push to the browser
            # line() was given plain sequences, so bokeh named the columns 'x' and 'y'.
            self.bokeh["mean_loss"].data_source.stream({'x': data['time'], 'y': data['mean_loss']})
            push_notebook(handle=self.bokeh_handle)

In [None]:
# The ticker keeps this exact Trainer object; re-create it if `trainer` is later replaced.
ticker = StatsTicker(trainer, lag=1000)

In [None]:
# Show the loss plot and start the background thread that refreshes it.
ticker.display(updates=True)

In [None]:
# Number of loss records the ticker has polled from the trainer so far.
len(ticker.data)

In [None]:
# NOTE(review): inspects TextDataset's `D` attribute — its meaning isn't visible in this notebook.
dataset.D

In [None]:
# Persist the trainer via Trainer.save() (storage location/format not visible here).
trainer.save()

In [None]:
# Stop the ticker's background plot-update thread and wait for it to exit.
ticker.stop()

In [None]:
# Rebuild a Trainer from saved state (e.g. after a kernel restart). Trainer() is built
# with no arguments and load() presumably restores model/data/optimizer — confirm.
# NOTE(review): an existing `ticker` still holds the previous Trainer object
# (StatsTicker stores it at construction), so it will not follow this one.
from testbed import Trainer
trainer = Trainer()
trainer.load()

In [None]:
# Query the trainer's status (what it reports isn't visible here).
trainer.status()

In [None]:
def smoother(data, lag):
    """Trailing moving average of ``data`` over windows of ``lag`` samples.

    Args:
        data: sequence of numbers.
        lag: window length; 0 means no smoothing.

    Returns:
        float ndarray of ``len(data) - lag`` values (empty when ``lag >= len(data)``).
        Element ``i`` is the mean of ``data[i + 1 : i + lag + 1]``, i.e. the window
        ending at index ``i + lag``, so it lines up with ``x[lag:]`` for an x-series
        of the same length as ``data``. With ``lag == 0``, returns ``data`` as floats.

    Raises:
        ValueError: if ``lag`` is negative.
    """
    if lag < 0:
        raise ValueError(f"lag must be >= 0, got {lag!r}")
    values = np.array(data, dtype=float)
    if lag == 0:
        # cs[:-0] would be empty, so the window formula below cannot handle lag == 0.
        return values
    cs = np.cumsum(values)
    return (cs[lag:] - cs[:-lag]) / lag

In [None]:
# Static plot of every record in `ticker.data`: datum[1] on x, datum[2] (loss) on y,
# optionally smoothed with a trailing window of `lag` records.
# NOTE(review): if `trainer` was re-created above, `ticker` still reflects the old one.
TOOLS = "crosshair,pan,wheel_zoom,box_zoom,reset,tap,box_select,lasso_select"
logfig = figure(tools=TOOLS)
logfig.axis.major_label_text_font_size = "24px"
logfig.xaxis.axis_label = "time"
logfig.yaxis.axis_label = "loss"
hover = HoverTool(tooltips=None, mode="vline")
logfig.add_tools(hover)
lag = 0  # smoothing window in records; 0 plots the raw losses
times = [datum[1] for datum in ticker.data]
losses = [datum[2] for datum in ticker.data]
if lag > 0:
    # smoother(losses, lag) yields len(losses) - lag values, so trim X to match.
    X = times[lag:]
    Y = smoother(losses, lag)
else:
    X, Y = times, losses
logline = logfig.line(X, Y)
# notebook_handle=True makes show() return a handle usable with push_notebook();
# without it show() returns None.
bokeh_handle = show(logfig, notebook_handle=True)

In [None]:
# Resume training with the (reloaded) `trainer`; see start() in testbed for its behavior.
trainer.start()

In [None]:
# Trailing ';' keeps IPython from displaying autocomplete()'s return value.
trainer.autocomplete();