In [None]:
from threading import Thread

In [None]:
from bokeh.io import push_notebook, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource
from bokeh.plotting import figure 
# Send bokeh output into this notebook so the live loss plot (show / push_notebook) renders inline.
output_notebook()

In [None]:
import torch
from testbed import TextDataset, Trainer, Net0, Net1, Transformer
from testbed.util import decode_broken_utf8, default_device

In [None]:
K = 8    # token embedding dimension (K stands for "keys", but don't ask for justification)
C = 256  # number of tokens (C stands for "classes"); presumably one class per byte value

In [None]:
# Architecture to train; exactly one of the model cells below builds `model` for it.
MODEL_TYPES = ("Net0", "Net1", "Transformer")
model_type = "Transformer"
# Fail fast on a typo: otherwise none of the model cells runs and `model` is
# undefined when it is first used.
if model_type not in MODEL_TYPES:
    raise ValueError("unknown model_type {!r}; expected one of {}".format(model_type, MODEL_TYPES))

In [None]:
# Net0 — convolutional model whose N (the value later passed to TextDataset(N=N))
# is exactly one more than its kernel length: N = L + 1.
if model_type == "Net0":
    L = 32   # length of convolution kernel
    H = 256  # 2**8 hidden neurons (i.e. number of convolution kernels)
    N = L + 1
    model = Net0(H=H, L=L, K=K, C=C).to(default_device())

In [None]:
# Net1 — convolutional model whose N (the value later passed to TextDataset(N=N))
# must exceed its kernel length L; here N is twice L.
if model_type == "Net1":
    L = 32   # length of convolution kernel
    H = 256  # 2**8 hidden neurons (i.e. number of convolution kernels)
    N = L * 2
    model = Net1(H=H, L=L, K=K, C=C).to(default_device())


In [None]:
# Transformer is built with no arguments: K and C above are NOT passed to it —
# presumably it uses its own defaults (TODO confirm they agree with K and C).
if model_type == "Transformer":
    N = 2**8  # value later passed to TextDataset(N=N)
    model = Transformer().to(default_device())

In [None]:
# Total number of scalar parameters across all tensors yielded by model.parameters().
model_size = sum(p.numel() for p in model.parameters())
model_size

In [None]:
# Batch size (i.e. examples per batch); used for both the dataset and the Trainer.
# NOTE(review): a later cell calls trainer.set_batch_size(128) without updating B.
B = 16

In [None]:
# Training text; N comes from the selected model cell and B from the batch-size cell.
dataset = TextDataset(N=N, B=B)

In [None]:
# Dataset size as reported by TextDataset.__len__ (examples vs. batches not visible here — confirm).
len(dataset)

In [None]:
# Wrap model and data in a Trainer; B is only the initial batch size (see set_batch_size below).
trainer = Trainer(model=model, dataset=dataset, batch_size=B)

In [None]:
# Start training. Presumably runs in the background, since later cells keep
# querying `trainer` while it trains — confirm in testbed.Trainer.
trainer.start()

In [None]:
import time, math
import numpy as np
from threading import Lock

class StatsTicker:
    """Stream a smoothed training-loss curve into a live bokeh plot.

    Loss records are pulled incrementally from ``trainer.loss()`` and appended
    to ``self.data``. Each record is indexed as ``datum[1]`` (the x value,
    plotted relative to the first record; presumably a timestamp — confirm
    against ``Trainer.loss``) and ``datum[2]`` (the loss value).

    Smoothing: while fewer than ``lag`` records have been reported, each point
    is the cumulative mean of all losses so far; afterwards each point is the
    mean of the ``lag`` most recent losses ending at that point.

    ``stats()`` is incremental: each call returns only the points that are new
    since the previous call (what ``ColumnDataSource.stream`` expects). Calling
    it by hand therefore takes those points away from the live plot.

    Args:
        trainer: object whose ``loss()`` returns a list of new loss records.
        lag: moving-average window length, in records; must be >= 1.

    Raises:
        ValueError: if ``lag < 1``.
    """

    def __init__(self, trainer, lag=1000):
        if lag < 1:
            raise ValueError("lag must be >= 1, got {!r}".format(lag))
        self.trainer = trainer
        self.lag = lag
        self.tick = 0          # number of records already reported by stats()
        self.data = []         # every loss record received from the trainer
        self.bokeh = {}        # bokeh objects: "figure" and the "mean_loss" line glyph
        self.bokeh_handle = None
        self.updating = False  # True while the background updater should keep running
        self.updater = None
        # stats() mutates self.data / self.tick and is called both from the updater
        # thread and from notebook cells, so it is serialized with this lock.
        self.lock = Lock()

    def poll(self):
        """Append any loss records the trainer has produced since the last poll."""
        self.data += self.trainer.loss()

    def stats(self):
        """Poll the trainer and return the smoothed-loss points added since the last call.

        Returns:
            dict with keys ``'time'`` (x values relative to the first record) and
            ``'mean_loss'`` (smoothed loss) of equal length; empty when there is
            no new data.
        """
        lag = self.lag
        with self.lock:
            self.poll()
            n = len(self.data)
            if n == 0:
                return {'time': [], 'mean_loss': []}
            start = self.tick  # index of the first record not yet reported
            t0 = self.data[0][1]
            if start < lag:
                # Warm-up: not enough history for a full window, so report the
                # cumulative mean of all losses seen so far.
                T = np.array([datum[1] - t0 for datum in self.data])
                loss = np.array([datum[2] for datum in self.data], dtype=float)
                mean = np.cumsum(loss) / np.arange(1, n + 1)
                self.tick = n
                return {'time': T[start:], 'mean_loss': mean[start:]}
            # Steady state: moving average over `lag` records. Prepend the `lag`
            # already-reported records so every new point has a full window;
            # cs[i + lag] - cs[i] is the sum of trailing[i+1 .. i+lag].
            trailing = self.data[start - lag:]
            T = np.array([datum[1] for datum in trailing[lag:]]) - t0
            cs = np.cumsum([datum[2] for datum in trailing], dtype=float)
            mean = (cs[lag:] - cs[:-lag]) / float(lag)
            self.tick = n
            return {'time': T, 'mean_loss': mean}

    def display(self, updates=True):
        """Render the loss plot in the notebook and optionally start live updates.

        Args:
            updates: if True, start the background thread that streams new points.
        """
        TOOLS = "crosshair,pan,wheel_zoom,box_zoom,reset,tap,box_select,lasso_select"
        fig = figure(tools=TOOLS,
                     x_axis_label="time (relative to first loss record)",
                     y_axis_label="mean loss")
        fig.axis.major_label_text_font_size = "24px"
        fig.add_tools(HoverTool(tooltips=None, mode="vline"))
        data = self.stats()
        self.bokeh["figure"] = fig
        self.bokeh["mean_loss"] = fig.line(data['time'], data['mean_loss'])
        self.bokeh_handle = show(fig, notebook_handle=True)
        if updates:
            self.start()

    def start(self):
        """Start the background updater thread; does nothing if it is already running.

        Raises:
            RuntimeError: if display() has not created the plot yet.
        """
        if self.updating:
            return
        if "mean_loss" not in self.bokeh:
            raise RuntimeError("call display() before start(): there is no plot to update")
        self.updating = True
        # daemon=True so a forgotten updater does not keep the kernel process alive.
        self.updater = Thread(target=self._update_loop, daemon=True)
        self.updater.start()

    def stop(self):
        """Ask the background updater to exit and wait for it to finish."""
        if self.updating:
            self.updating = False
            self.updater.join()

    def _update_loop(self, interval=0.1):
        """Updater-thread body: every `interval` seconds, stream new points to the plot."""
        try:
            while self.updating:
                time.sleep(interval)
                data = self.stats()
                if len(data['time']) == 0:
                    continue  # nothing new; skip the round-trip to the browser
                # line() was given plain arrays, so its data source columns are 'x' and 'y'.
                self.bokeh["mean_loss"].data_source.stream({'x': data['time'], 'y': data['mean_loss']})
                push_notebook(handle=self.bokeh_handle)
        finally:
            # If the loop dies (e.g. an exception from bokeh), let start() relaunch it.
            self.updating = False

In [None]:
# Live loss monitor over the trainer; moving-average window of 1000 loss records.
ticker = StatsTicker(trainer=trainer, lag=1000)

In [None]:
# Show the plot and start the background thread that keeps it updated.
ticker.display(updates=True)

In [None]:
# Progress check: loss records received so far vs. len(dataset)
# (whether len(dataset) counts examples or batches is not visible here — confirm).
len(ticker.data), len(dataset)

In [None]:
# Current trainer status as reported by testbed.Trainer.status().
trainer.status()

In [None]:
# Removed stale debug expression `x[1].shape`: `x` is not defined anywhere in this
# notebook, so this cell raised NameError under Restart Kernel -> Run All.

In [None]:
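# TODO(review): unused batch-size schedule, not wired in: batch_schedule = [(100, 32)]  # (s, bsz)
# Presumably (step, batch size) pairs; implement via trainer.set_batch_size or delete this cell.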

In [None]:
# Raise the trainer's batch size from the initial B.
# NOTE(review): B (also passed to TextDataset) stays 16 — confirm the dataset does not depend on it.
trainer.set_batch_size(128)

In [None]:
# Run the trainer's autocomplete (see testbed.Trainer.autocomplete).
# The trailing ';' suppresses IPython's display of the call's return value.
trainer.autocomplete();

In [None]:
# Stop the live-plot updater thread and wait for it to exit.
ticker.stop()

In [None]:
# Resume live-plot updates (no-op if the updater is already running).
ticker.start()

In [None]:
# NOTE(review): stats() returns only the points added since its previous call, so
# calling it here consumes those points and they will never be streamed to the plot.
ticker.stats()