In [2]:
import asyncio
import math
import time
import warnings
from math import log, sin, cos, tan, exp, sqrt, pi
from random import randrange

import numpy as np
import torch

from classroom import Classroom
from classroom import Student
from classroom import BytesDataset
from classroom import GutenbergSnippetsDataset
from classroom import MLPLM
from classroom import TransformerLM
from classroom import AdamW
from classroom import Sonny
from classroom import Floyd
from classroom import Plot
from classroom import Fun
from classroom import Count
from classroom import Sum
from classroom import Diff
from classroom import Log2Sum
from classroom import KalmanFilter1D
from classroom import MedianFilter
from classroom import TwoWindowFilter
from classroom import numel

In [3]:
def student_factory(checkpoint=None, device='cuda', batch_size=512):
    """Build a Student that trains a 256-symbol (byte) language model on Gutenberg snippets.

    Args:
        checkpoint: optional path to a file written with ``torch.save(model, ...)``
            (a whole pickled module, not a state_dict). When None, a fresh MLPLM
            is constructed instead.
        device: device the model is moved to (default ``'cuda'``).
        batch_size: number of examples per training batch (default 512).

    Returns:
        A ``Student`` wrapping the model, an AdamW optimizer, a fresh
        ``GutenbergSnippetsDataset``, and the batch configuration.
    """
    if checkpoint is None:
        model = MLPLM(
            n_vocab_in=256,
            n_vocab_out=256,
            n_ctx=64,
            d_model=64,
            d_hidden=8192,
            nonlinearity="GELU").to(device)
    else:
        # The checkpoint is a pickled nn.Module, so weights_only must be False:
        # torch>=2.6 defaults to True and would reject it (and the caller's
        # fallback would then silently start from scratch). Unpickling runs
        # arbitrary code -- only load checkpoints you created yourself.
        # map_location lets a checkpoint saved on another device load here.
        model = torch.load(checkpoint, map_location=device, weights_only=False).to(device)

    def lr_schedule(n):
        # Linear warm-up from 0 to 1e-3 over the first 1000 steps,
        # then decay proportional to 1/n (1e-3 * 1000 / n).
        return 1e-3*(1000/n) if n > 1000 else 1e-3*(n/1000)

    # Every hyper-parameter is a schedule: a callable of the step count n.
    optimizer = AdamW(
        parameters=model.named_parameters(),
        lr=lr_schedule,
        beta1=lambda n: 0.9,
        beta2=lambda n: 0.999,
        weight_decay=lambda n: 0.01,
        update=lambda n: True,
        n=0)

    dataset = GutenbergSnippetsDataset()
    # One symbol longer than the context window -- presumably so each example
    # splits into n_ctx inputs plus next-symbol targets; confirm in Student.
    example_length = model.n_ctx + 1
    return Student(
        model=model,
        optimizer=optimizer,
        dataset=dataset,
        batch_size=batch_size,
        example_length=example_length)

In [4]:
# Standalone dataset instance, only displayed at the end of the notebook;
# student_factory builds its own GutenbergSnippetsDataset for training.
dataset = GutenbergSnippetsDataset()

In [None]:
# Checkpoint file the next cells try to resume training from.
path = "checkpoint.pt"

In [5]:
# The Classroom that owns enrolled students (see classroom.enroll / classroom.students below).
classroom = Classroom()

In [6]:
# Resume from the checkpoint at `path` when possible; otherwise start from a fresh model.
# Only expected failures are caught: a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and hide why a checkpoint failed to load.
try:
    student = student_factory(path)
except FileNotFoundError:
    student = student_factory()
except Exception as exc:  # e.g. a corrupt or incompatible checkpoint
    warnings.warn(f"Could not load checkpoint {path!r} ({exc!r}); starting from a fresh model.")
    student = student_factory()

In [19]:
# NOTE(review): presumably snapshots the current model as the new baseline that
# baseline_grades are measured against -- confirm in classroom.Student.push.
student.push()

In [8]:
# Enroll the student; as the Task repr in the next cell's output shows, this
# schedules an asyncio task running Classroom.enroll's inner `_train` coroutine.
classroom.enroll(student)

In [9]:
# Mapping of enrolled Student -> its background training Task.
classroom.students

{<classroom.student.student.Student at 0x7f5797c6a790>: <Task pending name='Task-1' coro=<Classroom.enroll.<locals>._train() running at /home/sharker/github/classroom/classroom/classroom/classroom.py:18> wait_for=<Future pending cb=[<TaskWakeupMethWrapper object at 0x7f5799c4be50>()]>>}

In [10]:
# Model architecture and its total parameter count.
student.model, numel(student.model)

(MLPLM(
   (language_model): LanguageModel(
     (module): Sequential(
       (layers): ModuleList(
         (0): Embedding(256, 64)
         (1): Lambda()
         (2): MLP(
           (sequential): Sequential(
             (layers): ModuleList(
               (0): Affine(in_features=4096, out_features=8192, bias=True)
               (1): Nonlinearity(
                 (f): GELU()
               )
               (2): Affine(in_features=8192, out_features=256, bias=True)
             )
           )
         )
         (3): Lambda()
       )
     )
     (split_example): SplitExample()
     (crossentropyloss): CrossEntropyLoss(
       (crossentropyloss): CrossEntropyLoss()
     )
     (softmax): Softmax(
       (softmax): Softmax(dim=-1)
     )
   )
 ),
 35676416)

## Autocompleting

In [21]:
def autocomplete(n_generate=1024):
    """Print a generated sample from every student enrolled in the global `classroom`.

    Args:
        n_generate: length of each sample, forwarded to
            ``Student.autocomplete(n_generate=...)`` (default 1024, as before).
    """
    for idx, student in enumerate(classroom.students):
        print(f"\n\nStudent #{idx}\n==========")
        print(student.autocomplete(n_generate=n_generate))

autocomplete()



Student #0
te reno ses en 1e-hes_ic0 sten

t Cx cer, ItE, minlo s na1n ielladin tar cise snd wnri.Psgor ,e he yrit te" pdemest le shedpswon d s r biar tea me lem ins mand 1roy st  Su siklon ini urtion
at ousbgg5B tostacrl se iorkcmu
em Irat ;esN= w eretl   flreco
n Pe  roul yoj,he
c.tou
er awc.he livh mont thengrecton
ur ing
stheonrai, snllaons fnerila yh is\gu orlpagd fbo,char'itTitis vhisvinh Hp;[itoves ko-lt astrotieithet Sor oatermon mh. duthell hengerearlynss cie manthi goliunk ont uurthe .nvay, moranntSamarroner wour
c4,r ca torcin tof copiaredn alr blnkredrisc.sy,erilh an
barlaesinve fku
eIutopeurECes ine

e.tt on ma-ugap2 Ouiowhn vonge_mes c, cat ae mad (ina dyu? y ct peclr ivgs (is"eop'roefion 'merp(.u
e d }lqHt.ve sni e ferir1rina slzse temiT Ded,.nld mrqeij on rmmea

 blc ikoendmiansr.ttien ths peptofat en ven c.  he thieide shas  veg  ard he S avt mhs n dhfolm toiv


In [50]:
def pushall():
    """Invoke ``push()`` on each student enrolled in the global ``classroom``."""
    for enrolled_student in classroom.students:
        enrolled_student.push()


pushall()

## Training Visualization Plots

In [24]:
# Smoothed grade, baseline-grade and predicted-grade curves for every student.
# x-axis: Count() over student.times, i.e. the number of recorded steps.
SMOOTHING_LAG = 64  # lag passed to every TwoWindowFilter in this cell
plot_data = {}
for idx, student in enumerate(classroom.students):
    steps = Fun(Count(), student.times)
    plot_data[f"grades-{idx}"] = (steps, Fun(TwoWindowFilter(lag=SMOOTHING_LAG), student.grades))
    plot_data[f"baseline-{idx}"] = (steps, Fun(TwoWindowFilter(lag=SMOOTHING_LAG), student.baseline_grades))
    plot_data[f"predicted-{idx}"] = (steps, Fun(TwoWindowFilter(lag=SMOOTHING_LAG), student.predicted_grades))

Plot(**plot_data)



In [15]:
# Smoothed improvement over the baseline (grade - baseline_grade) and prediction
# error (grade - predicted_grade), plotted against cumulative training time:
# Sum() of the per-step durations in student.times.
SMOOTHING_LAG = 64  # lag passed to every TwoWindowFilter in this cell
plot_data_2 = {}
for idx, student in enumerate(classroom.students):
    elapsed = Fun(Sum(), student.times)
    improvement = Fun(lambda x, y: x - y, student.grades, student.baseline_grades)
    prediction_error = Fun(lambda x, y: x - y, student.grades, student.predicted_grades)
    # aux= attaches the raw difference stream to its smoothed version --
    # presumably to keep it alive/updating; confirm in classroom.Fun.
    improvement = Fun(TwoWindowFilter(lag=SMOOTHING_LAG), improvement.output, aux=improvement)
    prediction_error = Fun(TwoWindowFilter(lag=SMOOTHING_LAG), prediction_error.output, aux=prediction_error)
    plot_data_2[f"improvement-{idx}"] = (elapsed, improvement)
    plot_data_2[f"prediction-error-{idx}"] = (elapsed, prediction_error)
Plot(**plot_data_2)



In [96]:
# Take a new baseline and record when it happened; the statistics cell measures
# improvement_rate relative to time_of_last_baseline, so run this first.
# NOTE(review): push() presumably snapshots the current model as the baseline
# that baseline_grades are computed from -- confirm in classroom.Student.
student.push()
time_of_last_baseline = student.time


In [None]:
# Inspect the label -> (x, y) stream pairs built for the improvement plot.
plot_data_2

In [99]:
# Summary statistics for each student over its most recent WINDOW steps.
# NOTE: time_of_last_baseline is a single global set by the re-baseline cell,
# so improvement_rate is only meaningful for the student that was pushed there.
WINDOW = 1024  # number of recent steps averaged
for idx, student in enumerate(classroom.students):
    print(f"\nStudent #{idx}\n==========")
    n = len(student.times)-1
    # Clamp the window start: a negative start would wrap around to the tail of
    # the list (e.g. n=1000 -> [-24:1000] averages only 23 steps).
    start = max(0, n - WINDOW)
    # Named `elapsed` so the `time` module is not shadowed.
    # (The original author noted sum(student.times[:n]) as an alternative.)
    elapsed = student.time
    mean_grade = np.mean(np.array(student.grades[start:n]))
    mean_baseline_grade = np.mean(np.array(student.baseline_grades[start:n]))
    mean_predicted_grade = np.mean(np.array(student.predicted_grades[start:n]))
    accuracy = 1.0 - abs(mean_predicted_grade - mean_grade)/(mean_grade)

    mean_improvement = mean_grade - mean_baseline_grade
    improvement_rate = mean_improvement / (elapsed - time_of_last_baseline)
    # Seconds until mean_grade crosses the next multiple of 0.01 at the current rate.
    time_to_next_level = 0.01*( 1.0 - (100*mean_grade - int(100*mean_grade)))/improvement_rate
    message = '\n'.join([
        f"lr                        = {student.optimizer.state['language_model.module.layers.0.weight']['lr'](n)}",
        f"batch_size                = {student.batch_size}",
        f"example_length            = {student.example_length}",
        f"n                         = {n}",
        f"time                      = {int(elapsed)}s",
        f"time_of_last_baseline     = {int(time_of_last_baseline)}s",
        f"steps per second          = {(n/elapsed)}",
        f"mean_baseline_grade       = {mean_baseline_grade}",
        f"mean_grade                = {mean_grade}",
        f"mean_predicted_grade      = {mean_predicted_grade}",
        f"accuracy                  = {accuracy}",
        f"mean_improvement          = {mean_improvement}",
        f"improvement_rate          = {improvement_rate} per second",
        f"time_to_next_level        = {time_to_next_level}"
    ])
    print(message)


Student #0
lr                        = 1.1287955751213454e-05
batch_size                = 512
example_length            = 65
n                         = 88590
time                      = 3097s
time_of_last_baseline     = 2515s
steps per second          = 28.59948450785202
mean_baseline_grade       = 0.7017298742139246
mean_grade                = 0.7090239206590923
mean_predicted_grade      = 0.7089395612565568
accuracy                  = 0.9998810203717005
mean_improvement          = 0.0072940464451676235
improvement_rate          = 1.2535326396346218e-05 per second
time_to_next_level        = 77.8662884432153


In [43]:
# Scratch arithmetic: 261684 seconds (presumably a training duration) in days, at 86400 s/day.
261684/86400

3.02875

In [None]:
# Number of real batches that would need to be combined to emulate a batch of
# `simulated_batch_size` examples. k is only displayed; nothing else here uses it.
simulated_batch_size = 1024
k = divmod(simulated_batch_size, student.batch_size)[0]
print(k)

In [None]:
# Replace every parameter's hyper-parameter schedules with constants.
# Each entry is a callable of the step count n; these all ignore n.
for pn in student.optimizer.state:
    param_state = student.optimizer.state[pn]
    param_state["lr"] = lambda n: 1e-5
    param_state["beta1"] = lambda n: 0.9
    param_state["beta2"] = lambda n: 0.999
    param_state["weight_decay"] = lambda n: 0.01
    param_state["update"] = lambda n: True

In [None]:
async def evolution(classroom, interval=6.0):
    """Run an endless select-and-mutate loop over ``classroom``'s students.

    Every ``interval`` seconds the enrolled students are ranked by the latest
    KalmanFilter1D-filtered value of their ``relative_grades`` stream. The
    lowest-ranked student is graduated (unless it is also the top one), and a
    mutated clone of the top-ranked student is enrolled.

    Args:
        classroom: the Classroom whose population is evolved.
        interval: seconds to wait between selection rounds (default 6.0).

    Students enrolled or graduated by other code are picked up / forgotten at
    the start of each round, and students with no filtered value yet are left
    out of the ranking. Either case used to raise inside this background task
    and stop evolution silently.
    """
    def track(student):
        # Start filtering this student's relative_grades stream.
        relative_grades[student] = Fun(KalmanFilter1D(), student.relative_grades)

    def latest(student):
        # Most recent filtered value, or None if the filter has produced nothing yet.
        try:
            return relative_grades[student].output[-1]
        except IndexError:
            return None

    relative_grades = {}
    for student in classroom.students:
        track(student)
    while True:
        await asyncio.sleep(interval)
        enrolled = list(classroom.students)
        enrolled_set = set(enrolled)
        # Reconcile the trackers with the current roster.
        for student in list(relative_grades):
            if student not in enrolled_set:
                del relative_grades[student]
        for student in enrolled:
            if student not in relative_grades:
                track(student)
        rankable = [s for s in enrolled if latest(s) is not None]
        if not rankable:
            continue  # nobody to select from yet
        ranked_students = sorted(rankable, key=latest)
        top_student = ranked_students[-1]
        bot_student = ranked_students[0]
        if bot_student != top_student:
            classroom.graduate(bot_student)
            del relative_grades[bot_student]
        clone = top_student.clone()
        clone.mutate()
        classroom.enroll(clone)
        track(clone)


In [None]:
# Start the evolution loop in the background. Keep a strong reference to the
# task: the event loop only holds weak references, so an unreferenced task can
# be garbage-collected mid-run. The callback surfaces an exception that would
# otherwise end the loop silently.
def _report_evolution_failure(task):
    """Print the exception that ended ``task``, if it ended with one."""
    if not task.cancelled() and task.exception() is not None:
        print(f"evolution task failed: {task.exception()!r}")


evolution_task = asyncio.create_task(evolution(classroom))
evolution_task.add_done_callback(_report_evolution_failure)
evolution_task

In [94]:
# Save the whole model module (pickled) to the same `path` the loader cell reads,
# so the two cannot drift apart.
torch.save(student.model, f=path)

In [None]:
# Display the standalone GutenbergSnippetsDataset created near the top of the notebook.
dataset