In [1]:
import numpy as np
import time

In [2]:
from threading import Thread

In [3]:
from graceful import GracefulInterruptHandler

In [4]:
from bokeh.io import push_notebook, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource
from bokeh.plotting import figure 
# Tell Bokeh to render figures inline in this notebook; needed before the
# show(..., notebook_handle=True) / push_notebook() calls used further down.
output_notebook()

In [5]:
import torch
from testbed import TextDataset, DataLoader, Reporter, Trainer, Net0, Net1, default_device, Transformer
from testbed import decode_broken_utf8

In [6]:
# Shared hyperparameters used by the model, dataset and trainer cells below.
B = 16      # batch size: number of examples per batch
K = 8       # token embedding dimension ("K" as in "keys"; no deeper justification)
C = 2 ** 8  # number of token classes (256)

In [7]:
#model = Transformer().to(default_device())

In [8]:
# Net0 configuration.
H = 1 << 8  # hidden width, i.e. number of convolution kernels (256)
L = 32      # length of each convolution kernel
N = 1 + L   # tokens per example; presumably L context tokens plus the one to predict -- TODO confirm
model = Net0(C=C, K=K, L=L, H=H).to(default_device())

In [9]:
# Net 1
# assert N > L # tokens per example (each 64 should try to classify the one after)
# model = Net1(H=H, L=L, K=K, C=C).to(default_device())


In [10]:
# Total number of scalar entries across all parameter tensors of the model.
model_size = sum(param.numel() for param in model.parameters())
print(model_size)

133632


In [11]:
# Build the text dataset from the values set above: N = L + 1 tokens per
# example and B = batch size. How TextDataset uses them is defined in testbed.
dataset = TextDataset(N=N, B=B)

In [12]:
# Bundle the model and dataset into a Trainer; the batch size reuses B.
trainer = Trainer(model=model, dataset=dataset, batch_size=B)

In [13]:
# Show the trainer's status report before training has been started.
print(trainer.status())

Running: False
Paused: None


In [14]:
# Start training. Presumably non-blocking (later cells poll, pause and stop the
# trainer while it runs) -- confirm against testbed.Trainer.
trainer.start()

In [15]:
class StatsTicker:
    """Plot a moving average of the trainer's loss and keep the plot updated.

    Loss records are pulled from ``trainer.loss()`` and buffered in
    ``self.data``. Each record is indexed positionally: ``record[1]`` is used
    as a timestamp and ``record[2]`` as the loss value.

    Args:
        trainer: object exposing ``loss()``, which returns a list of records.
        lag: window length (in records) of the moving average; must be >= 1.

    Raises:
        ValueError: if ``lag`` is smaller than 1.
    """

    def __init__(self, trainer, lag=1000):
        if lag < 1:
            # The windowed difference in stats() slices with [:-lag], which
            # silently degenerates for lag == 0.
            raise ValueError("lag must be a positive number of records")
        self.trainer = trainer
        self.lag = lag
        self.tick = 0             # index into self.data of the first unconsumed record
        self.data = []            # every record received from the trainer so far
        self.bokeh = {}           # "figure" and "mean_loss" renderer, filled by display()
        self.bokeh_handle = None  # notebook handle returned by show()
        self.updating = False     # True while the updater thread should keep running
        self.updater = None       # background Thread created by start()

    def poll(self):
        """Append whatever ``self.trainer.loss()`` returns to the local buffer."""
        # Use the trainer this ticker was built with, not a notebook global.
        self.data += self.trainer.loss()

    def _has_window(self):
        """Return True when more than ``lag`` unconsumed records are buffered."""
        return len(self.data) > self.tick + self.lag

    def stats(self):
        """Return the moving-average loss for all not-yet-consumed records.

        Blocks (polling the trainer every 0.1 s) until more than ``lag``
        unconsumed records are available.

        Returns:
            dict with ``'time'`` (timestamps relative to the first record ever
            received) and ``'mean_loss'`` (mean of the ``lag`` most recent
            losses at each of those timestamps), as equally long numpy arrays.
        """
        lag = self.lag
        while not self._has_window():
            self.poll()
            if not self._has_window():
                time.sleep(.1)
        fresh = self.data[self.tick:]
        times = np.array([datum[1] for datum in fresh])
        losses = np.array([datum[2] for datum in fresh])
        # Windowed sums via a difference of cumulative sums.
        running = np.cumsum(losses)
        mean = (running[lag:] - running[:-lag]) / float(lag)
        T = times[lag:] - self.data[0][1]
        # Advance by the number of points emitted; the last `lag` records stay
        # unconsumed so the next call's first window overlaps them.
        self.tick += len(mean)
        return {'time': T, 'mean_loss': mean}

    def display(self, updates=True):
        """Create the Bokeh figure, draw the first batch and optionally go live.

        Args:
            updates: when True, start the background thread that keeps
                streaming new points into the plot.
        """
        TOOLS = "crosshair,pan,wheel_zoom,box_zoom,reset,tap,box_select,lasso_select"
        fig = figure(
            tools=TOOLS,
            x_axis_label="time since first record",
            y_axis_label="mean loss over last %d records" % self.lag,
        )
        fig.axis.major_label_text_font_size = "24px"
        fig.add_tools(HoverTool(tooltips=None, mode="vline"))
        self.bokeh["figure"] = fig
        data = self.stats()
        self.bokeh["mean_loss"] = fig.line(data['time'], data['mean_loss'])
        self.bokeh_handle = show(fig, notebook_handle=True)
        if updates:
            self.start()

    def start(self):
        """Start the background updater thread if it is not already running."""
        if not self.updating:
            self.updating = True
            self.updater = Thread(target=self._update_loop)
            self.updater.start()

    def stop(self):
        """Ask the updater thread to finish and wait for it."""
        if self.updating:
            self.updating = False
            self.updater.join()

    def _update_loop(self):
        """Thread body: stream new moving-average points until stopped."""
        try:
            while self.updating:
                # Poll here rather than blocking inside stats(), so the
                # `updating` flag is re-checked every 0.1 s and stop() cannot
                # hang when the trainer stops producing records.
                self.poll()
                if self._has_window():
                    data = self.stats()
                    self.bokeh["mean_loss"].data_source.stream({'x': data['time'], 'y': data['mean_loss']})
                    push_notebook(handle=self.bokeh_handle)
                time.sleep(.1)
        finally:
            # If the loop dies on an exception, allow start() to be called again.
            self.updating = False

In [16]:
# Attach a loss ticker to the trainer, using the default lag of 1000 records.
ticker = StatsTicker(trainer)

In [17]:
# Draw the moving-average loss plot and start the background updater thread.
# Blocks until the trainer has produced more than `lag` loss records.
ticker.display()

In [24]:
# Compare the number of loss records buffered so far with the dataset length.
# NOTE(review): this cell's execution count (24) is higher than that of the
# cells shown below it (20, 23), so the notebook was not run top to bottom.
len(ticker.data), len(dataset)

(255784, 186632)

In [None]:
# Stop the trainer (this cell was not executed in the saved notebook).
trainer.stop()

In [None]:
# NOTE(review): `x` is not defined in any visible cell of this notebook, so this
# leftover inspection cell raises NameError on a fresh kernel. Define `x` first
# or delete the cell.
x[1].shape

In [23]:
# Switch the trainer to a much larger batch size (2048, versus the initial B = 16).
trainer.set_batch_size(2048)

In [20]:
# Print a sampled completion. The trailing semicolon keeps the notebook from
# echoing the call's return value underneath the printed text.
trainer.autocomplete();

s-gangs sometimes stole and smug
~AUTOCOMPLETE~
her,
of of itso fisent camliens finly was no she cont of bllorpers, dodeer

ande stheridterbessionquuth to dould of dechafflefil dinchis un
sfrectionsoin, that a
polmestant at a outh sobert
knomctions to shant inthes
with it wat sover
jupht at
illaccu, cummenvemsel onling on thise had mustand in pointills aur-trep ast the
Ceblithilly resermoof. And yumphed Furrand
uptense. " o moment
noupht though, Loln
reach thaid lang.
Ot bere
it to matcyed
mact the kay eust wis off.ith soulfond to the stold.
A wantll yicht markinith.

"ind whe har
she of
     TCetnoter Is will artompucting should to three,
  to r.

-Th heaghed hy myound of by opplode kile preevanty offollencly hist armeard
--smouth and had
haadf.'
Hish in
youls fforler--last his bay
eydanown and Besprea
cladeon, of mit biandinc: that harding intrema?

[s Fept tome iff are off ve goind."

"At that ssarot, in plar' at, 33s, I
Walko and he chatore, ever
if thelinf sound.

Brot took; thone

In [None]:
def autocomplete(trainer, model, dataset, prompt="", N=1024, L=None):
    """Sample `N` bytes from `model`, print them after their seed, and return them.

    The seed is a random snippet from `dataset` followed by `prompt`, encoded
    as UTF-8 and cut down to its last `L` bytes. Each step feeds the current
    `L`-byte window to ``model.probs``, samples one byte from the resulting
    distribution and slides the window forward by one.

    Training is paused for the duration of the sampling and started again
    afterwards, including when sampling raises.

    Args:
        trainer: object exposing ``pause()`` and ``start()``.
        model: object exposing ``probs(x)`` and, when `L` is None, an ``L`` attribute.
        dataset: object exposing ``random_text_snippet(length)`` returning a str.
        prompt: text appended to the random snippet before trimming.
        N: number of bytes to sample.
        L: context window length in bytes; defaults to ``model.L``.

    Returns:
        The sampled completion (without the seed), decoded with
        ``decode_broken_utf8``.
    """
    if L is None:
        L = model.L
    # NOTE(review): whether the trainer was already paused is not detected;
    # training is always (re)started on exit, as in the original code.
    trainer.pause()
    try:
        seed = list(bytes(dataset.random_text_snippet(L) + prompt, 'utf-8'))[-L:]
        window = seed
        completion = []
        # Pure inference: no autograd graph is needed while sampling.
        with torch.no_grad():
            for _ in range(N):
                x = (torch.tensor(window)
                          .unsqueeze(0)
                          .to(default_device()))  # shape [1, L]
                P = model.probs(x).view(-1)
                c_ord = torch.distributions.Categorical(P).sample().item()
                window = window[1:] + [c_ord]
                completion.append(c_ord)
        print(decode_broken_utf8(bytes(seed) + bytes("\n~AUTOCOMPLETE~\n", 'utf-8') + bytes(completion)))
        return decode_broken_utf8(bytes(completion))
    finally:
        # Resume training even if sampling failed, so an error does not leave
        # the trainer paused.
        trainer.start()

In [None]:
# Sample a completion with the notebook-level helper defined above, using its
# defaults (empty prompt, 1024 sampled bytes, window length taken from model.L).
autocomplete(trainer, model, dataset)