# Lab Template

## import libraries

In [None]:
import asyncio
import math
from math import log, sin, cos, tan, exp, sqrt, pi
import time
from random import randrange
import torch
import numpy as np
from classroom import Classroom
from classroom import Student
from classroom import BytesDataset
from classroom import GutenbergSnippetsDataset
from classroom import MLPLM, MyLM
from classroom import TransformerLM
from classroom import AdamW
from classroom import Sonny
from classroom import Floyd
from classroom import Plot
from classroom import Fun
from classroom import Count
from classroom import Sum
from classroom import Diff
from classroom import Log2Sum
from classroom import KalmanFilter1D
from classroom import MedianFilter
from classroom import TwoWindowFilter
from classroom import numel
from classroom import utf8decode

## initialize model

In [None]:
# Disabled by default: switch the condition to True to resume from the file
# written by the "save" cell instead of building a fresh model.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
if False:
    path = 'checkpoint.pt'
    model = torch.load(path).to('cuda')

In [None]:
# Default path: construct a new MyLM and place it on the GPU.
# The `if True:` switch mirrors the `if False:` switch of the checkpoint cell.
if True:
    model = MyLM(
        n_ctx=256,
        n_vocab_in=256,
        d_model=8,
        n_layers=2,
        d_hidden=4096,
        nonlinearity="GELU",
        p_dropout=0.0,
        n_vocab_out=256,
    )
    # Rebind to the result of .to() so `model` always names the CUDA copy.
    model = model.to('cuda')

In [None]:
# Echo numel(model) and that figure times 4 / 1e9
# (size in GB if numel counts parameters stored as 4-byte floats — TODO confirm).
numel(model), numel(model)*4/1E9

## initialize student

In [None]:
# Wire the model, its optimizer and the training data into a Student.
optimizer = AdamW(parameters=model.named_parameters())
dataset = GutenbergSnippetsDataset()

# Left unset here; the schedule cell assigns student.batch_size afterwards.
batch_size = None

# One position more than the model's context length.
example_length = model.n_ctx + 1

student = Student(
    model=model,
    optimizer=optimizer,
    dataset=dataset,
    batch_size=batch_size,
    example_length=example_length,
)

## schedule hyperparameters

In [None]:
# Batch size used by the student from here on.
student.batch_size = 1024


def _make_schedule(lr_base=1e-4, warmup_steps=1000, half_period=1000,
                   weight_decay_base=0.001, batch_multiplier=1):
    """Build the hyperparameter schedules installed for one parameter.

    Every schedule is a function of ``n``, the step index the caller passes in.
    The schedules close over this function's arguments rather than over
    notebook globals, so re-running or editing other cells cannot silently
    change a schedule that is already installed.

    Args:
        lr_base: learning rate reached at the end of the warm-up.
        warmup_steps: number of steps over which the learning rate ramps
            linearly from 0 to ``lr_base``.
        half_period: the modulation ``exp(-cos(pi * n / half_period))`` goes
            from its minimum (1/e at n = 0) to its maximum (e) in this many steps.
        weight_decay_base: weight decay before modulation.
        batch_multiplier: "update" is true on every step divisible by this value.

    Returns:
        A dict with the keys "lr", "beta1", "beta2", "weight_decay" and
        "update", each mapping to a one-argument function of ``n``.
    """
    def warmup(n):
        # Linear ramp during warm-up, constant afterwards.
        return lr_base * (n / warmup_steps) if n < warmup_steps else lr_base

    def modulation(n):
        # Periodic multiplier in [1/e, e]; alternative tried: sin(pi*n/(1000))**2
        return exp(-cos(pi * n / half_period))

    return {
        "lr":           lambda n: warmup(n) * modulation(n),
        "beta1":        lambda n: 0.9,
        "beta2":        lambda n: 0.999,
        "weight_decay": lambda n: weight_decay_base * modulation(n),
        "update":       lambda n: n % batch_multiplier == 0,
    }


# Install an independent set of schedules for every named parameter.
for (pn, p) in student.model.named_parameters():
    for (key, schedule) in _make_schedule().items():
        student.optimizer.state[pn][key] = schedule

## initialize baseline

In [None]:
# presumably snapshots the current model as the baseline — confirm against Student.push
student.push()
# Remember when the baseline was taken; the stats cell divides the mean
# improvement by the time elapsed since this moment.
time_of_last_baseline = student.time

## start training

In [None]:
# The classroom whose `students` the autocomplete, plot and stats cells iterate over.
classroom = Classroom()

In [None]:
# Enroll the student; per the section heading this is what starts training
# (presumably in the background — TODO confirm against Classroom.enroll).
classroom.enroll(student)

## autocomplete

In [None]:
def autocomplete(prompt=None):
    """Print a completion of `prompt` from every student in the classroom.

    Each student is asked for `n_generate=1024` and the result is printed
    under a "Student #<index>" header.
    """
    for number, pupil in enumerate(classroom.students):
        header = f"\n\nStudent #{number}\n=========="
        print(header)
        completion = pupil.autocomplete(prompt=prompt, n_generate=1024)
        print(completion)
autocomplete()

## plots

In [None]:
# Filtered grades and baseline grades for every student, on a shared x series.
# NOTE: the loop rebinds the notebook-level `student` to the last enrolled student.
import time
lag = 1024
plot_data = {}
for idx, student in enumerate(classroom.students):
    X = Fun(Count(), student.times)
    Y = Fun(TwoWindowFilter(lag=lag), student.grades)
    Z = Fun(TwoWindowFilter(lag=lag), student.baseline_grades)
    plot_data[f"grades-{idx}"] = (X, Y)
    plot_data[f"baseline-{idx}"] = (X, Z)
Plot(**plot_data)

In [None]:
# Filtered difference between grade and baseline grade for every student.
# NOTE: the loop rebinds the notebook-level `student` to the last enrolled student.
import time
lag = 64
plot_data_2 = {}
for idx, student in enumerate(classroom.students):
    X = Fun(Sum(), student.times)
    # Raw per-step difference first ...
    Y = Fun(
        lambda x, y: x - y,
        student.grades,
        student.baseline_grades,
    )
    # ... then filter its output; the raw stage is handed over as `aux`.
    Y = Fun(TwoWindowFilter(lag=lag), Y.output, aux=Y)
    plot_data_2[f"improvement-{idx}"] = (X, Y)
Plot(**plot_data_2)

## stats

In [None]:
# Print summary statistics for every student over (at most) the last N steps.
# NOTE: the loop rebinds the notebook-level `student` to the last enrolled student.
for (idx, student) in enumerate(classroom.students):
    print(f"\nStudent #{idx}\n==========")
    N = 8192
    n = len(student.times)-1
    # Use a dedicated name: assigning to `time` would shadow the imported
    # `time` module for every cell executed afterwards.
    elapsed = student.time
    # Clamp the window start at 0. A negative start would be interpreted
    # relative to the end of the list and shrink the window whenever n < N.
    start = max(0, n - N)
    mean_grade = np.mean(np.array(student.grades[start:n]))
    mean_baseline_grade = np.mean(np.array(student.baseline_grades[start:n]))
    mean_improvement = mean_grade - mean_baseline_grade
    # Report NaN instead of dividing by zero when no time has elapsed yet.
    since_baseline = elapsed - time_of_last_baseline
    improvement_rate = mean_improvement / since_baseline if since_baseline else float('nan')
    steps_per_second = n / elapsed if elapsed else float('nan')
    # NOTE(review): the parameter name below is hard-coded; it must exist in
    # student.optimizer.state for the model in use.
    message = '\n'.join([
        f"lr                    = {student.optimizer.state['language_model.module.layers.0.weight']['lr'](n)}",
        f"batch_size            = {student.batch_size}",
        f"example_length        = {student.example_length}",
        f"n                     = {n}",
        f"time                  = {int(elapsed)}s",
        f"time_of_last_baseline = {int(time_of_last_baseline)}s",
        f"steps per second      = {steps_per_second}",
        f"mean_baseline_grade   = {mean_baseline_grade}",
        f"mean_grade            = {mean_grade}",
        f"mean_improvement      = {mean_improvement}",
        f"improvement_rate      = {improvement_rate} per second",
    ])
    print(message)

## save

In [None]:
# Save the whole model object to 'checkpoint.pt', the path the disabled
# checkpoint-loading cell near the top reads back.
torch.save(student.model, f='checkpoint.pt')

## parameter histograms

In [None]:
# One density histogram per named model parameter, collected in `histograms`.
histograms = []
with torch.no_grad():
    for idx, (pn, p) in enumerate(student.model.named_parameters()):
        # Index, name and standard deviation (square root of the variance).
        print(idx, pn, p.var().sqrt().item())
        # Bin count is the square root of the element count; np.histogram
        # returns the densities first and the bin edges second.
        Y, X = np.histogram(
            p.detach().cpu().numpy(),
            bins=int(sqrt(p.numel())),
            density=True,
        )
        print(X.shape, Y.shape)
        histograms.append(Plot(**{f"hist-{idx}": (X.tolist(), Y.tolist())}))

In [None]:
# Show the histogram built for the parameter printed with index 3 above.
histograms[3]

## batch-level grade histogram

In [None]:
# Density histograms of the last 5000 grades and baseline grades,
# both on 256 bins spanning [0, 1].
Y, X = np.histogram(
    student.grades[-5000:],
    bins=256,
    range=(0, 1.0),
    density=True,
)
V, U = np.histogram(
    student.baseline_grades[-5000:],
    bins=256,
    range=(0, 1.0),
    density=True,
)
Plot(**{"grade-hist": (X, Y), "baseline": (U, V)})

In [None]:
# Echo the model's n_ctx, d_model, d_hidden and n_layers attributes.
model.n_ctx, model.d_model, model.d_hidden, model.n_layers

## example-level grade histogram

In [None]:
def get_graded_examples(n_batches=16):
    """Grade freshly sampled examples and return them sorted by grade.

    Draws `n_batches` batches from `student.dataset`, runs `student.model` on
    each one without gradient tracking and uses `1.0 - output` as the grade.

    Args:
        n_batches: number of batches to sample. Defaults to 16, the value that
            was previously hard-coded.

    Returns:
        A list of `(label, grade)` tuples in ascending grade order. `label` is
        the decoded example text in single quotes (newlines and carriage
        returns shown as '@', left-padded with spaces to the row width)
        followed by `int(1000 * grade)`. Empty when `n_batches` is 0.
    """
    blocks = []
    for _ in range(n_batches):
        x = student.dataset.batch(student.batch_size, student.example_length)
        with torch.no_grad():
            y = student.model(x)
        # assumes the model output has one column per example so it can be
        # appended to the inputs as the last column — TODO confirm
        blocks.append(np.concatenate([x.cpu().numpy(), 1.0 - y.cpu().numpy()], axis=1))
    if not blocks:
        # np.concatenate rejects an empty sequence; nothing was sampled.
        return []
    data = np.concatenate(blocks, axis=0)

    width = data.shape[1]
    # Convert all byte columns once instead of once per row.
    byte_rows = data[:, :-1].astype(int)
    graded = []
    for row, grade in zip(byte_rows, data[:, -1]):
        text = utf8decode(bytes(row.tolist()))
        # Keep every example on one output line.
        text = text.replace('\n', '@').replace('\r', '@')
        graded.append((f"'{text.rjust(width)}' {int(1000*grade)}", grade))
    graded.sort(key=lambda item: item[1])
    return graded

In [None]:
# Sample and grade examples once; the following cells print and histogram them.
graded_examples = get_graded_examples()

In [None]:
# Print every graded example, then keep just the grades for the histogram.
for example, grade in graded_examples:
    print(example, grade)
example_grades = [score for _label, score in graded_examples]

In [None]:
# Density histogram of the example grades over [0, 1]; the bin count is the
# square root of the number of examples.
R = (0, 1)
bins = int(sqrt(len(example_grades)))
Y, X = np.histogram(
    example_grades,
    bins=bins,
    range=R,
    density=True,
)
Plot(**{"examples-hist": (X, Y)})

In [None]:
# Average grade over all sampled examples.
np.mean(example_grades)