In [1]:
import asyncio
from math import log, sin, cos, tan, exp, sqrt, pi
from time import time, sleep
from random import randrange
import torch
import numpy as np
from classroom import Classroom
from classroom import Student
from classroom import UTF8Dataset as Dataset
from classroom import MLPLM
from classroom import TransformerLM
from classroom import AdamW, Sonny, Floyd
from classroom import Plot
from classroom import FilteredList
from classroom import KalmanFilter1D
from classroom import MedianFilter
from classroom import TwoWindowFilter
from classroom import CountFilter
from classroom import SumFilter
from classroom.util import numel

## Training

## Scheduling helpers

In [2]:
def constant(c):
    """Return a schedule that yields ``c`` for every step."""
    def schedule(step):
        return c
    return schedule


def linear_warmup_then_decay(lr, warmup):
    """Return a schedule that ramps linearly to ``lr``, then decays as 1/n.

    For a step ``n < warmup`` the schedule returns ``lr * (n / warmup)``;
    from ``n == warmup`` onward it returns ``lr * (warmup / n)``, so the
    peak value ``lr`` is reached exactly at ``n == warmup``.

    Args:
        lr: Peak value of the schedule.
        warmup: Number of steps in the linear ramp; must be positive.

    Raises:
        ValueError: If ``warmup`` is not positive. With ``warmup == 0`` the
            decay branch would divide by zero at step 0 and return 0 for
            every later step.
    """
    if warmup <= 0:
        raise ValueError(f"warmup must be positive, got {warmup!r}")

    def schedule(n):
        if n < warmup:
            return lr * (n / warmup)
        return lr * (warmup / n)
    return schedule

## MLP Language Model

In [3]:
def mlp_factory(device='cuda'):
    """Build a Student that trains an MLP language model.

    The model is an ``MLPLM`` with ``n_vocab_in = n_vocab_out = 256``,
    ``n_ctx = 64``, ``d_model = 32`` and a sigmoid hidden layer of
    ``8192 * 2`` units. It is paired with an ``AdamW`` optimizer whose
    hyperparameters are all constant schedules (``lr = 1e-3``), a fresh
    ``Dataset()`` and a batch size of 512.

    Args:
        device: Device the model is moved to via ``model.to(device)``.
            Defaults to ``'cuda'``, the value that was previously
            hard-coded.

    Returns:
        A ``Student`` wrapping the model, optimizer and dataset.
    """
    model = MLPLM(
        n_vocab_in=256,
        n_vocab_out=256,
        n_ctx=64,
        d_model=32,
        d_hidden=8192*2,
        nonlinearity="sigmoid").to(device)
    optimizer = AdamW(
        parameters=model.parameters(),
        eps=constant(1e-8),
        lr=constant(1e-3),
        beta1=constant(0.9),
        beta2=constant(0.999),
        n=0)
    dataset = Dataset()
    batch_size = 512
    # One more than the context length; presumably the extra position
    # supplies the next-token target -- TODO confirm against Student.
    example_length = model.n_ctx + 1
    return Student(model=model, optimizer=optimizer,
                    dataset=dataset, batch_size=batch_size,
                    example_length=example_length)

In [4]:
def trans_factory(device='cuda'):
    """Build a Student that trains a small Transformer language model.

    The model is a 2-layer ``TransformerLM`` with ``n_ctx = 64``,
    ``d_model = 1024``, 32 heads with ``d_k = d_v = 32``,
    ``d_hidden = 4096`` and dropout probability 0.1 on the embedding,
    attention matrix, attention output and MLP. It is paired with an
    ``AdamW`` optimizer, a fresh ``Dataset()`` and a batch size of 1.

    Args:
        device: Device the model is moved to via ``model.to(device)``.
            Defaults to ``'cuda'``, the value that was previously
            hard-coded.

    Returns:
        A ``Student`` wrapping the model, optimizer and dataset.
    """
    model = TransformerLM(
        n_vocab_in=256,
        n_vocab_out=256,
        n_ctx=64,
        d_model=1024,
        d_k=32,
        d_v=32,
        n_heads=32,
        d_hidden=4096,
        n_layers=2,
        p_dropout_embedding=0.1,
        p_dropout_attn_mat=0.1,
        p_dropout_attn_out=0.1,
        p_dropout_mlp=0.1).to(device)

    optimizer = AdamW(
        parameters=model.parameters(),
        eps=constant(1e-4),
        # Oscillates between 0 and 1e-5 as n grows (sin**2 of n/500);
        # n is presumably the optimizer step count -- TODO confirm.
        lr=lambda n: sin(n/500)**2 * 1e-5,
        beta1=constant(0.9),
        beta2=constant(0.999),
        weight_decay=constant(0.01),
        n=0)
    dataset = Dataset()
    batch_size = 1
    # One more than the context length; presumably the extra position
    # supplies the next-token target -- TODO confirm against Student.
    example_length = model.n_ctx + 1
    return Student(model=model, optimizer=optimizer,
                    dataset=dataset, batch_size=batch_size,
                    example_length=example_length)

In [5]:
classroom = Classroom()
# Factories whose students get enrolled. Only the MLP student is built;
# add trans_factory to this list to enroll a transformer student as well.
student_factories = [mlp_factory]
students = [build() for build in student_factories]
for student in students:
    classroom.enroll(student)

## Autocompleting

In [12]:
# Generate text from the last student in the classroom list (n_generate=1024)
# and print it so that newlines render as line breaks.
last_student = classroom.students[-1]
completion = last_student.autocomplete(n_generate=1024)
print(completion)

ides from with a main to mon, woo-can formed at opher to
its the both miging through could.

   By too, if munich your Liking in stay. Ex Inded that it, As a
the is not full hatint man hes came reax,
           A Doval, and with their inclesing is,
  9½ 301|p.761, 931125. Ohd salk fiom Peot be an hem in a tuler
regard the Satero_,
and town consmit her. She I'allowed, impressing in a by 2883, and
head ons the Figut is trage, and the[3]
Anjotion Imposen of Semarications, sevisitual time incompanion musicion
in lest or lanner preperant of stutiois noborie courted divigue what kept
atterialight and the minsion Scote; and a tolife any in the lage
father?'" side note the such not kim kit dis's mean. He his cevenently?'"

"Hosough yooks a true sick that some in mymble would I am the coming gab
will. Lerb Rode. Lite I word, Mir Applies Commanion_, 180 M.).


_Amitens_.--Levr. 16.

Timponer, Jan deanted Stunest lew

_--_Quaven crogable, no lor in pie, finalle et ciblord'd

## Training Visualization Plots

In [14]:
# One curve per enrolled student, keyed by the student's index in
# classroom.students. The `import time` that used to open this cell was
# removed: nothing here uses it, and it rebound the name `time`, which the
# first cell imports as a function via `from time import time`.
plot_data = {
    f"{idx}": (
        FilteredList(student.times, SumFilter()),
        FilteredList(student.grades, TwoWindowFilter(lag=2048)),
    )
    for idx, student in enumerate(classroom.students)
}
Plot(**plot_data)



In [15]:
# Number of entries recorded so far in the first student's `times`.
first_student = classroom.students[0]
len(first_student.times)

37222

## Genetic algorithm

In [None]:
# (times, grades) pairs of every student removed via classroom.graduate().
graveyard = []


async def mauler(interval=10.0):
    """Repeatedly clone-and-mutate a student, then graduate one.

    Each iteration ranks the classroom, clones the last student in
    ``classroom.students``, mutates the clone and enrolls it. After
    sleeping for ``interval`` seconds it ranks again, removes one student
    with ``classroom.graduate()`` and stores that student's ``times`` and
    ``grades`` in ``graveyard``. The loop has no exit condition, so it runs
    until the task is cancelled.

    Args:
        interval: Seconds to wait between enrolling the mutated clone and
            graduating a student. Defaults to 10.0, the value that was
            previously hard-coded.
    """
    while True:
        classroom.rank_students()
        student = classroom.students[-1].clone()
        student.mutate()
        classroom.enroll(student)
        await asyncio.sleep(interval)
        classroom.rank_students()
        student = classroom.graduate()
        graveyard.append((student.times, student.grades))
        # Drop the local reference before the next iteration clones again;
        # presumably this lets the graduated student's resources be
        # released early -- TODO confirm.
        del student


# Needs an already-running event loop, which the notebook kernel provides.
task = asyncio.create_task(mauler())

In [None]:
# One curve per graduated student, keyed by its position in `graveyard`.
# The `import time` that used to open this cell was removed: nothing here
# uses it, and it rebound the name `time`, which the first cell imports as
# a function via `from time import time`.
plot_data = {
    f"{idx}": (
        FilteredList(times, SumFilter()),
        FilteredList(grades, TwoWindowFilter()),
    )
    for idx, (times, grades) in enumerate(graveyard)
}
Plot(legend=False, **plot_data)