In [1]:
import torch
from torch.optim import AdamW
from testbed.optim import Sonny
from testbed import TextDataset, Trainer, Net0, Net1, Transformer
from testbed.util import decode_broken_utf8, default_device
from math import log

In [2]:
# --- Hyperparameters for the byte-level model ---
num_input_classes = 256    # one input class per possible UTF-8 byte value
embedding_dim = 8          # dimension of the embedding space (256 points live in it)
context_length = 2 * 128   # number of sequential bytes the model sees as context
num_hidden = 2 * 8192      # hidden-layer width
num_output_classes = 256   # one output class per possible UTF-8 byte value

# Build the network and move it to the default device.
model = Net0(
    num_input_classes=num_input_classes,
    embedding_dim=embedding_dim,
    context_length=context_length,
    num_hidden=num_hidden,
    num_output_classes=num_output_classes,
).to(default_device())

# One example is one byte longer than the context
# (presumably the context plus the byte to predict -- confirm against TextDataset).
example_length = context_length + 1
dataset = TextDataset(example_length=example_length)

batch_size = 8192          # examples per batch
OptimizerType = Sonny      # optimizer class handed to the Trainer

In [3]:
# Total number of scalar entries across all parameter tensors of the model.
model_size = sum(p.numel() for p in model.parameters())
print(model_size)

37767424


In [4]:
# Size of the dataset as reported by len().
len(dataset)

14818489351

In [5]:
# Bundle the model, data and optimizer choice configured above into a Trainer.
trainer = Trainer(
    model=model,
    example_length=example_length,
    batch_size=batch_size,
    dataset=dataset,
    OptimizerType=OptimizerType,
)

In [6]:
# Kick off the trainer.
# NOTE(review): presumably training runs in the background, since later cells keep
# querying and reconfiguring `trainer` -- confirm against testbed.Trainer.
trainer.start()

In [7]:
from bokeh.io import push_notebook, show, output_notebook
from bokeh.models import HoverTool, ColumnDataSource
from bokeh.plotting import figure 
output_notebook()
import time, math
import numpy as np
from threading import Thread, Lock

class StatsTicker:
    """Live Bokeh line plot of a trainer's loss history inside the notebook.

    The loss history is pulled from ``trainer.update_losses()``. Each entry is
    indexed positionally: index 1 is plotted on the x axis ("time") and index 2
    is the loss, rescaled by ``8 / ln(256)`` (which equals ``1 / ln(2)``).
    NOTE(review): if the trainer reports a per-byte loss in nats, the plotted
    value is bits per byte -- confirm against the Trainer.

    Attributes:
        trainer: object providing ``update_losses()``.
        tick: number of loss entries already sent to the plot.
        losses: most recent loss history returned by the trainer.
        bokeh: dict holding the current ``"figure"`` and ``"mean_loss"`` glyph.
        bokeh_handle: notebook handle returned by ``show``; ``None`` until
            ``display()`` has been called.
        updating: True while the background refresh thread should keep running.
    """

    def __init__(self, trainer):
        self.trainer = trainer
        self.tick = 0
        self.losses = self.trainer.update_losses()
        self.bokeh = {}
        self.bokeh_handle = None
        self.updating = False
        self.updater = None
        # Serializes plot/tick updates between the notebook thread (display)
        # and the background refresh thread (_update_loop).
        self._lock = Lock()

    def recent_stats(self, since=None):
        """Refresh ``self.losses`` and return the entries from ``since`` onward.

        Args:
            since: index of the first entry to include. Defaults to
                ``self.tick``, i.e. only entries not yet sent to the plot.

        Returns:
            dict with two equally long lists, ``'time'`` and ``'mean_loss'``.
        """
        # Work on one local snapshot so both columns always have the same
        # length, even if another thread refreshes self.losses meanwhile.
        snapshot = self.trainer.update_losses()
        self.losses = snapshot
        first = self.tick if since is None else since
        fresh = snapshot[first:]
        return {'time': [entry[1] for entry in fresh],
                'mean_loss': [8*entry[2]/math.log(256) for entry in fresh]}

    def display(self, updates=True):
        """Create and show the figure with the full loss history so far.

        Args:
            updates: when True (default), also start the background thread
                that streams new points into the plot.
        """
        TOOLS = "pan,wheel_zoom,box_zoom,reset"
        fig = figure(tools=TOOLS)
        fig.axis.major_label_text_font_size = "24px"
        hover = HoverTool(show_arrow=False,
                          mode='vline',
                          line_policy='next',
                          tooltips=[('X_value', '$data_x'),
                                    ('Y_value', '$data_y')])
        fig.add_tools(hover)
        with self._lock:
            # Always plot from the beginning: a re-run of display() must show
            # the whole history, not just the points since the last refresh.
            data = self.recent_stats(since=0)
            self.bokeh["figure"] = fig
            self.bokeh["mean_loss"] = fig.line(data['time'], data['mean_loss'])
            self.tick = len(data['time'])
            self.bokeh_handle = show(fig, notebook_handle=True)
        if updates:
            self.start()

    def start(self):
        """Start the background refresh thread (no-op if already running)."""
        if not self.updating:
            self.updating = True
            self.updater = Thread(target=self._update_loop, daemon=True)
            self.updater.start()

    def stop(self):
        """Ask the refresh thread to finish and wait for it (no-op if idle)."""
        if self.updating:
            self.updating = False
            self.updater.join()

    def _update_loop(self):
        """Once per second, stream newly recorded losses into the plot."""
        while self.updating:
            time.sleep(1)
            with self._lock:
                if self.bokeh_handle is None:
                    # start() was called before display(): nothing to draw on yet.
                    continue
                data = self.recent_stats()
                if not data['time']:
                    continue
                # The glyph was built from plain lists, so its data source
                # columns are named 'x' and 'y'.
                self.bokeh["mean_loss"].data_source.stream({'x': data['time'],
                                                            'y': data['mean_loss']})
                self.tick += len(data['time'])
                push_notebook(handle=self.bokeh_handle)

In [8]:
# Wrap the trainer in a StatsTicker; the constructor takes an initial snapshot
# of the loss history via trainer.update_losses().
ticker = StatsTicker(trainer)

In [9]:
# Render the loss plot; with the default updates=True this also starts the
# background thread that keeps the plot refreshed.
ticker.display()

In [None]:
# Peek at the 10 most recent entries of the ticker's cached loss history.
ticker.losses[-10:]

In [None]:
# Set the trainer's batch size to 8192 (the same value as the initial `batch_size`).
trainer.set_batch_size(8192)

In [None]:

# Back-of-envelope memory estimate for the model configured at the top of the notebook.
# Short aliases for the configuration values, so this cell runs on a fresh kernel.
C = num_input_classes  # number of byte classes
K = embedding_dim      # embedding dimension
L = context_length     # bytes of context
H = num_hidden         # hidden-layer width
# Parameter count: embedding table (C*K), hidden-layer weights (K*L*H) and biases (H),
# output-layer weights (H*C) and biases (C). For C=256, K=8, L=256, H=16384 this is
# 37,767,424, which matches the parameter count printed for the model above.
# NOTE(review): layer shapes are inferred from that count -- confirm against Net0.
# NOTE(review): the factor 16 is bytes per parameter, presumably 4-byte floats times
# four copies (weights, gradients, optimizer state) -- confirm.
param_mem = (C*K + K*L*H + H*C + H + C)*16
# 8192 is the batch size in use; 4 is presumably bytes per float32 activation.
compute_mem = H*8192*4 # is this right? we don't have to take grad to inputs, for example
(C, K, L, 'KL', K*L, 'H', H, param_mem, compute_mem, param_mem + compute_mem)

In [None]:
# Two readings of 2819 and 748, each scaled by 2**20 (i.e. MiB converted to bytes).
# NOTE(review): presumably observed GPU memory usage; the source of the two figures
# is not recorded in the notebook.
measured_mem = (2819 + 748) * 2**20

In [None]:
# Measured memory not accounted for by the parameter estimate, divided (integer
# division) by the hidden width H. Requires the memory-estimate cells above to have run.
(measured_mem - param_mem)//H

In [None]:
# Scratch arithmetic: evaluates to (2048, 2048, 256).
# NOTE(review): 8*256 equals embedding_dim*context_length for the configuration
# above; the purpose of this cell is unclear -- consider removing it.
(8*256, 2048, 256)

In [None]:
# Quadruple the batch size relative to the initial 8192 (8192*4 = 32768).
trainer.set_batch_size(8192*4)

In [None]:
# Show the last 1000 entries of the ticker's cached loss history.
# NOTE(review): this produces a very long output; clear it before sharing the notebook.
ticker.losses[-1000:]

In [None]:
# Persist the trainer.
# NOTE(review): what save() writes and where is defined in testbed.Trainer and is
# not visible from this notebook.
trainer.save()

In [None]:
def smoother(data, lag):
    """Moving average of ``data`` over windows of ``lag`` samples.

    The average is computed from differences of the cumulative sum, so the
    cost is linear in ``len(data)`` regardless of ``lag``.

    Args:
        data: one-dimensional sequence of numbers.
        lag: window length; must be at least 1.

    Returns:
        numpy array of length ``max(len(data) - lag, 0)``. Element ``i`` is the
        mean of ``data[i+1 : i+lag+1]``, so the result lines up with
        ``x[lag:]`` for an x axis ``x`` of the same length as ``data``. The
        very first window, ``data[0:lag]``, is not produced.

    Raises:
        ValueError: if ``lag`` is smaller than 1.
    """
    if lag < 1:
        # Without this guard, lag == 0 slices cs[:-0] (empty) and fails with an
        # unrelated broadcasting error, and negative lags return garbage.
        raise ValueError(f"lag must be a positive integer, got {lag!r}")
    cs = np.cumsum(np.asarray(data))
    return (cs[lag:] - cs[:-lag]) / lag

In [None]:
# Static plot of the trainer's recorded loss history, optionally smoothed.
# Index 1 of each loss entry is used as the x value and index 2 as the loss.
TOOLS = "crosshair,pan,wheel_zoom,box_zoom,reset,tap,box_select,lasso_select"
logfig = figure(tools=TOOLS)
logfig.axis.major_label_text_font_size = "24px"
hover = HoverTool(tooltips=None, mode="vline")
logfig.add_tools(hover)
lag = 1  # smoothing window in samples; 0 plots the raw values
# Both branches rescale the loss by 8/ln(256) (= 1/ln 2), the same scaling the live
# ticker uses, so the y axis has the same units with and without smoothing.
if lag > 0:
    # smoother() returns len(losses) - lag values, hence the matching [lag:] on X.
    X = [x[1] for x in trainer.losses][lag:]
    Y = [ 8*u / math.log(256) for u in smoother([x[2] for x in trainer.losses], lag)]
else:
    X = [x[1] for x in trainer.losses]
    Y = [8*x[2] / math.log(256) for x in trainer.losses]
logline = logfig.line(X,Y)
bokeh_handle = show(logfig)

In [None]:
# Run the trainer's autocomplete; the trailing semicolon keeps the notebook from
# echoing the return value.
trainer.autocomplete();

In [None]:
# For every parameter tensor of the trained model, print its shape, its contents,
# and its largest and smallest entry.
for p in trainer.model.parameters():
    print(p.shape, p, p.max().item(), p.min().item())

In [None]:
# Stop the trainer.
# NOTE(review): whether a stopped trainer can be restarted is defined in
# testbed.Trainer and is not visible from this notebook.
trainer.stop()

In [None]:
# Run the model on the first 1000 dataset examples, one at a time and without
# gradient tracking, and print the scalar each call returns.
# NOTE(review): .item() requires a single-element output, so the model presumably
# returns a per-example loss -- confirm against Net0.
with torch.no_grad():
    # Send inputs to wherever the model's parameters actually live instead of
    # hard-coding CUDA; the model was placed with default_device().
    model_device = next(model.parameters()).device
    for i in range(1000):
        example = dataset[i].view(1, -1).long().to(model_device)
        print(model(example).item())

In [None]:
# Pause the trainer.
# NOTE(review): how pause() differs from stop() is defined in testbed.Trainer
# and is not visible from this notebook.
trainer.pause()