In [None]:
import math
from math import log
import time
import torch
import numpy as np
from torch.optim import AdamW, SGD
from testbed import ShortDataset, TextDataset, Trainer, Net0, Net1, Net2, Net3, Net4, Transformer
from testbed.optim import Sonny
from testbed.util import decode_broken_utf8, default_device, numel
from testbed.gui import Plot, StatsTicker, ParameterInspector, Histogram, SmoothPlot, LinePlot

In [None]:
# Selects which configuration cell below builds `model`, the dataset settings and the optimizer settings.
KNOWN_NETWORKS = ("Net0", "Net3", "Net4", "Transformer", "GPT2Transformer")
network_name = "GPT2Transformer"
# Fail fast on a typo: an unknown name matches no configuration cell, and the error
# would otherwise surface later as a confusing NameError (e.g. on `model`).
if network_name not in KNOWN_NETWORKS:
    raise ValueError(f"unknown network_name {network_name!r}; expected one of {KNOWN_NETWORKS}")

In [None]:
from collections import namedtuple

# Return type of cuda_memory(). It is a tuple subclass, so positional unpacking
# (`f, a, r, t = cuda_memory()`) keeps working.
CudaMemory = namedtuple("CudaMemory", ["free", "allocated", "reserved", "total"])


def cuda_memory(device=0):
    """Print and return a snapshot of CUDA memory usage for one device.

    Args:
        device: CUDA device index (or any device spec accepted by ``torch.cuda``);
            defaults to 0, as before.

    Returns:
        ``CudaMemory(free, allocated, reserved, total)``, all in bytes. ``free`` is
        the unused part of the *reserved* pool (reserved - allocated), not free
        memory on the whole device.
    """
    total = torch.cuda.get_device_properties(device).total_memory
    reserved = torch.cuda.memory_reserved(device)
    allocated = torch.cuda.memory_allocated(device)
    free = reserved - allocated  # free inside reserved
    print(f"Total {total}. Reserved {reserved}. Allocated {allocated}. Free {free}.")
    return CudaMemory(free, allocated, reserved, total)

def memory_allocated(device=0):
    """Return the bytes currently allocated by tensors on a CUDA device (default: device 0)."""
    return torch.cuda.memory_allocated(device)

In [None]:
if network_name == "Net0":
    # Configuration for Net0 (256 input / 256 output classes, i.e. one per byte value).
    num_input_classes = 256  # 256 possible UTF-8 bytes
    embedding_dim = 32  # Dimension of embedding space. An embedding layer has 256 points in this space.
    context_length = 32  # Number of sequential bytes visible to model (i.e. in the context)
    num_hidden = 8192  # Hyperparameter for neural network
    num_output_classes = 256  # 256 possible UTF-8 bytes
    model = Net0(num_input_classes=num_input_classes,
                 embedding_dim=embedding_dim,
                 context_length=context_length,
                 num_hidden=num_hidden,
                 num_output_classes=num_output_classes,
                 nonlinearity="sigmoid").to(default_device())
    example_length = context_length + 1
    dataset = TextDataset(example_length=example_length)
    # The Trainer cell passes DatasetType/dataset_kwargs, so every configuration branch defines them.
    DatasetType = TextDataset
    dataset_kwargs = {"example_length": example_length}
    batch_size = 512  # batch size (i.e. examples per batch)
    OptimizerType = Sonny
    optimizer_kwargs = {"eps": 1e-8, "weight_decay": 0.01}

In [None]:
if network_name == "Net3":
    # Configuration for Net3: tiny 2-D embedding and two 64-unit hidden layers.
    embedding_dim = 2
    context_length = 32
    num_hidden1 = 64
    num_hidden2 = 64
    model = Net3(embedding_dim=embedding_dim,
                 context_length=context_length,
                 num_hidden1=num_hidden1,
                 num_hidden2=num_hidden2).to(default_device())
    example_length = context_length + 1
    dataset = TextDataset(example_length=example_length)
    # The Trainer cell passes DatasetType/dataset_kwargs, so every configuration branch defines them.
    DatasetType = TextDataset
    dataset_kwargs = {"example_length": example_length}
    batch_size = 512
    OptimizerType = Sonny
    optimizer_kwargs = {}  # use the optimizer's defaults

In [None]:
if network_name == "Net4":
    # Configuration for Net4 (256 input / 256 output classes, GELU nonlinearity).
    num_input_classes = 256  # 256 possible UTF-8 bytes
    embedding_dim = 128  # Dimension of embedding space. An embedding layer has 256 points in this space.
    context_length = 128  # Number of sequential bytes visible to model (i.e. in the context)
    num_hidden = 4096  # Hyperparameter for neural network
    num_output_classes = 256  # 256 possible UTF-8 bytes
    model = Net4(num_input_classes=num_input_classes,
                 embedding_dim=embedding_dim,
                 context_length=context_length,
                 num_hidden=num_hidden,
                 num_output_classes=num_output_classes,
                 nonlinearity="GELU").to(default_device())
    example_length = context_length + 1
    dataset = TextDataset(example_length=example_length)
    # The Trainer cell passes DatasetType/dataset_kwargs, so every configuration branch defines them.
    DatasetType = TextDataset
    dataset_kwargs = {"example_length": example_length}
    batch_size = 8192  # batch size (i.e. examples per batch)
    OptimizerType = Sonny
    optimizer_kwargs = {"eps": 1e-4,
                        "lr": .0001,
                        "beta1": .9,
                        "beta2": .999,
                        "weight_decay": 0.0001}

In [None]:
if network_name == "Transformer":
    # Transformer over a 256-symbol vocabulary (same 256 classes as the Net* configs).
    # Uses the same ModelType/model_kwargs pattern as the GPT2Transformer config.
    ModelType = Transformer
    model_kwargs = {
        "n_vocab": 256,
        "max_ctx": 512,
        "d_model": 1024,
        "n_heads": 32,
        "d_ff": 4096,
        "n_layers": 4}
    model = ModelType(**model_kwargs).to(default_device())
    example_length = model.max_ctx + 1
    dataset = TextDataset(example_length=example_length)
    # The Trainer cell passes DatasetType/dataset_kwargs, so every configuration branch defines them.
    DatasetType = TextDataset
    dataset_kwargs = {"example_length": example_length}
    batch_size = 8  # batch size (i.e. examples per batch)
    OptimizerType = Sonny
    optimizer_kwargs = {"eps": 1e-6,
                        "lr": .00001,
                        "beta1": .9,
                        "beta2": .999,
                        "weight_decay": 0.001}

In [None]:
if network_name == "GPT2Transformer":
    # Model: Transformer with 12 layers, d_model 768, 12 heads, over the
    # 50257-token vocabulary of the GPT-2 tokenizer used later in this notebook.
    ModelType = Transformer
    model_kwargs = dict(
        n_vocab=50257,
        max_ctx=512,
        d_model=768,
        n_heads=12,
        d_ff=3072,
        n_layers=12,
    )
    model = ModelType(**model_kwargs).to(default_device())

    # Data: short GPT-2-token examples.
    example_length = 128 + 1  # like bert-uncased
    DatasetType = ShortDataset
    dataset_kwargs = {"example_length": example_length}
    dataset = DatasetType(**dataset_kwargs)

    # Optimization.
    batch_size = 256  # batch size (i.e. examples per batch)
    OptimizerType = Sonny
    optimizer_kwargs = dict(
        eps=1e-4,
        lr=1e-4,
        beta1=0.9,
        beta2=0.999,
        weight_decay=0.01,
        warmup=10000,
    )

In [None]:
# Build the training harness from whichever configuration cell ran above.
trainer = Trainer(
    model=model,
    DatasetType=DatasetType,
    dataset_kwargs=dataset_kwargs,
    example_length=example_length,
    batch_size=batch_size,
    OptimizerType=OptimizerType,
    optimizer_kwargs=optimizer_kwargs,
)

In [None]:
# Size (presumably total parameter count, via testbed.util.numel) and identifier of the model.
(numel(model), model.name())

In [None]:
# Start training.
trainer.start()

In [None]:
# Live view of the trainer's statistics, drawn as a line plot; kept in `ticker` for later use.
ticker = StatsTicker(trainer,
                     kind='line')
ticker

In [None]:
# Compute energy plotted against compute time.
StatsTicker(trainer,
            x='compute_time',
            y='compute_energy')

In [None]:
# One line per parameter tensor: name, device, element count, shape.
for name, p in model.named_parameters():
    print(f"{name} {p.device} {p.numel()} {p.shape}")

In [None]:
# Retune the live trainer; lr here is 10x lower than the GPT2Transformer config's 1e-4.
trainer.set_optimizer_settings(
    lr=1e-5,
    beta1=0.9,
    beta2=0.999,
    batch_size=256,
    weight_decay=0.01,
)

In [None]:
# Multiplicative factor converting the GPT-2 (token-level) loss into a
# bits-per-character figure ("estimated factor").
# NOTE(review): the two integers look like corpus-wide counts, but their meaning and
# the extra factor of 8 are not documented here — TODO record the derivation.
lyles_constant = (9115131782 / 14818489608) * 8
lyles_constant

In [None]:
# GPT-2 loss corresponding to 1.8 bpc under the same conversion factor.
1.8 / lyles_constant

In [None]:
# Latest mean training loss, converted to the bpc scale.
lyles_constant * trainer.losses[-1]["mean_loss"]

In [None]:
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# Output vocabulary size of the model being trained.
# Assumes the model exposes `n_vocab` (as it exposes `max_ctx`) — TODO confirm in testbed.
n_vocab = model.n_vocab
def gpt2autocomplete(model, prompt=" A", n_generate=128, max_ctx=None, device=None):
    """Sample a continuation of ``prompt`` from ``model`` using the global GPT-2 ``tokenizer``.

    Args:
        model: the trained language model. Assumed (from how this notebook uses it)
            to expose ``probs(x)``, ``max_ctx`` and ``n_vocab`` — TODO confirm in testbed.
        prompt: text to continue; must tokenize to at least one token.
        n_generate: number of tokens to sample.
        max_ctx: maximum number of tokens fed to the model per step; defaults to the
            inference model's ``max_ctx``. Only the most recent ``max_ctx`` tokens are used.
        device: device to run inference on. ``None`` means the model's own device.
            A different device gets a fresh ``ModelType(**model_kwargs)`` holding a copy
            of ``model``'s weights, so the training model is neither moved nor put in eval mode.

    Returns:
        The decoded prompt followed by the ``n_generate`` sampled tokens.

    Raises:
        ValueError: if ``prompt`` tokenizes to an empty sequence.
    """
    model_device = next(model.parameters()).device
    device = model_device if device is None else torch.device(device)
    # torch.device("cuda") carries no index; treat it as matching the model's CUDA device.
    if device.type == model_device.type and device.index in (None, model_device.index):
        device = model_device
        inference_model = model
    else:
        # Copy the weights in memory (no temporary file) into a separate instance.
        inference_model = ModelType(**model_kwargs)
        inference_model.load_state_dict(model.state_dict())
        inference_model.to(device)
        inference_model.eval()  # only the copy; the training model's mode is untouched
    if max_ctx is None:
        max_ctx = inference_model.max_ctx
    vocab_size = inference_model.n_vocab

    tokens = tokenizer(prompt)['input_ids']
    if not tokens:
        raise ValueError("prompt must tokenize to at least one token")
    with torch.no_grad():
        for _ in range(n_generate):
            # Feed at most max_ctx tokens, on the same device as the inference model.
            x = torch.tensor(tokens[-max_ctx:], dtype=torch.long, device=device).unsqueeze(0)  # shape [1,L]
            # Last n_vocab entries of the flattened output = next-token distribution
            # for the final position (assumes probs returns [1, L, n_vocab] — TODO confirm).
            P = inference_model.probs(x).view(-1)[-vocab_size:]
            tokens.append(torch.distributions.Categorical(P).sample().item())
    return tokenizer.decode(tokens)

In [None]:
%%time
print(gpt2autocomplete(prompt=" As it happens, I know several such people."))

In [None]:
# Checkpoint the trainer under the model's name.
trainer.save(model.name())

In [None]:
# Training step counter.
trainer.step

In [None]:
# Latest compute_time divided by the step count (average time per step if compute_time is cumulative).
trainer.losses[-1]["compute_time"] / trainer.step

In [None]:
# The ten most recent loss records.
trainer.losses[-10:]

In [None]:
def floprecurse(d, s):
    """Flatten a nested stats structure into per-component throughput report lines.

    Walks ``d`` recursively (mappings by key, lists by index), extending the dotted
    path ``s``. Every mapping with both ``"time"`` and ``"energy"`` entries yields:

        "<path>": {"time": <time>, "energy": <energy>, "tflops": <energy / time / 1e12>}

    Any dict subclass (e.g. ``defaultdict``) is walked. If the ratio cannot be
    computed (zero time, non-numeric values), ``tflops`` is reported as ``nan``
    instead of silently dropping the entry.

    Args:
        d: nested dict/list structure, e.g. one record of ``trainer.losses``.
        s: path prefix for this level, e.g. ``"root"``.

    Returns:
        List of report strings in traversal order (a parent before its children).
    """
    results = []
    if isinstance(d, dict):
        if "time" in d and "energy" in d:
            try:
                tflops = d["energy"] / d["time"] / 1E12
            except (TypeError, ZeroDivisionError):
                tflops = float("nan")
            results.append(f'"{s}": {{"time": {d["time"]}, "energy": {d["energy"]},'
                           f' "tflops": {tflops}}}')
        for (k, v) in d.items():
            results.extend(floprecurse(v, f"{s}.{k}"))
    elif isinstance(d, list):
        for (k, v) in enumerate(d):
            results.extend(floprecurse(v, f"{s}.{k}"))
    return results

In [None]:
# Most recent stats record from the trainer (kept in `d` for the next cell).
d = trainer.losses[-1]
d

In [None]:
# One throughput line per timed component found in the latest record.
A = floprecurse(d, "root")
for a in A:
    print(a)

## `Net0` autocomplete

In [None]:
# Run the trainer's built-in autocomplete; the trailing ';' suppresses the cell's output value.
trainer.autocomplete();

### SmoothPlot: moving-average loss vs. log2(compute time)

In [None]:
# (compute_time, mean_loss) for every logged record, split into x and y columns.
L = np.array([[rec['compute_time'], rec['mean_loss']] for rec in trainer.losses])
X, Y = L[:, 0], L[:, 1]
def smoother(data, lag):
    """Trailing moving average of ``data`` over windows of ``lag`` samples.

    Element ``i`` of the result is the mean of ``data[i+1 : i+lag+1]``. The output
    therefore has ``len(data) - lag`` entries and lines up with ``data[lag:]``: each
    value averages the ``lag`` samples ending at that index.

    Args:
        data: 1-D array-like of numbers.
        lag: window length; must be >= 1.

    Raises:
        ValueError: if ``lag`` < 1. A zero window would otherwise fail with an opaque
            broadcasting error, and a negative one would return meaningless values.
    """
    if lag < 1:
        raise ValueError(f"lag must be a positive integer, got {lag}")
    cs = np.cumsum(data)
    return (cs[lag:] - cs[:-lag]) / lag

class SmoothPlot(LinePlot):
    """LinePlot of a moving-average-smoothed series.

    Note: this redefinition shadows the ``SmoothPlot`` imported from ``testbed.gui``.

    Args:
        X: x values or, when ``Y`` is omitted, the y values (x then becomes 0..n-1).
        Y: y values, smoothed with ``smoother(Y, lag)``.
        lag: moving-average window; the first ``lag`` x values are dropped to stay aligned.
        log: if truthy, x is plotted as log2(x).
    """

    def __init__(self, X=None, Y=None, lag=100, log=None):
        if X is None:
            # No data: hand the arguments to LinePlot unchanged.
            super().__init__(X, Y)
            return
        if Y is None:
            xs = np.array(range(len(X)))
            ys = np.array(X)
        else:
            xs = np.array(X)
            ys = np.array(Y)
        xs = xs[lag:]
        ys = smoother(ys, lag)
        if log:
            xs = np.log(xs) / math.log(2)
        super().__init__(xs, ys)
# Mean loss smoothed over 100 records, plotted against log2(compute_time).
SmoothPlot(X, Y,
           lag=100,
           log=True)

# Testing: `ShortDataset`, GPT-2 tokenizer, and attention-timing experiments

In [None]:
# Standalone ShortDataset with example_length=1024.
# NOTE: this rebinds the global `dataset` created by the configuration cells.
from testbed import ShortDataset
dataset = ShortDataset(
    example_length=1024,
)

In [None]:
# First example of the dataset.
dataset[0]

In [None]:
# GPT-2 BPE tokenizer, used below to decode dataset examples.
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained(
    "gpt2"
)

In [None]:
# Decode the first example back to text.
text = tokenizer.decode(dataset[0]); print(text); del text

In [None]:
%%timeit
dataset[0]

In [None]:
# Length of one example.
len(dataset[0])

In [None]:
# Scratch: nested defaultdict. The outer factory is the bare `defaultdict` class, so
# inner dicts have no default factory (reading a missing inner key raises KeyError).
from collections import defaultdict
data = defaultdict(defaultdict)

In [None]:
# Accessing data["dog"] auto-creates the inner dict; assigning into it needs no factory.
data["dog"]["cat"] = 0.0

In [None]:
# Tokenizer vocabulary size; expected to match n_vocab=50257 in the GPT2Transformer config.
len(tokenizer.get_vocab())

In [None]:
from time import time  # NOTE: shadows the `time` module imported in the first cell; kept for cells that call time().
from collections import defaultdict


def timing_study(n_ctx=1024, d_model=1024, n_heads=32, n_iters=64, duration=30.0,
                 device="cuda", causal=False):
    """Micro-benchmark the stages of one multi-head self-attention forward pass.

    Repeats rounds of: split_heads -> Q K^T -> additive mask -> softmax -> A V ->
    merge_heads. Each round runs every stage ``n_iters`` times. Rounds continue
    until ``duration`` seconds have elapsed.

    Args:
        n_ctx: sequence length.
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        n_iters: repetitions of each stage per round (must be >= 1).
        duration: wall-clock budget in seconds for starting new rounds.
        device: torch device; CUDA devices are synchronized around timed work.
        causal: if True, use a causal mask (-inf above the diagonal); otherwise an
            all-zero mask (same tensor shapes and work, no effect on the softmax).

    Returns:
        ``(idx, data)``. ``idx`` is the number of completed rounds. ``data`` maps each
        stage name to ``{"time": seconds, "energy": ..., "data": ...}``:
        - "energy" is the notebook's per-stage operation estimate; NOTE(review): the
          overall factor 3.0 is not documented here — TODO confirm its meaning.
        - "data" is the tensor element count tallied per stage.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads`` or ``n_iters`` < 1.
    """
    from time import perf_counter  # monotonic, high-resolution clock for benchmarking

    d_head = d_model // n_heads
    if d_model != n_heads * d_head:
        raise ValueError(f"d_model={d_model} is not divisible by n_heads={n_heads}")
    if n_iters < 1:
        raise ValueError(f"n_iters must be >= 1, got {n_iters}")
    is_cuda = torch.device(device).type == "cuda"

    def sync():
        # CUDA kernels run asynchronously; wait so wall-clock timings cover the work.
        if is_cuda:
            torch.cuda.synchronize()

    sync()
    data = defaultdict(lambda: defaultdict(float))
    X = torch.randn(n_ctx, d_model, device=device)
    Q = K = V = X  # self-attention: queries, keys and values share one input

    def split_heads(x):
        # [..., n_ctx, d_model] -> [..., n_heads, n_ctx, d_head]
        return x.view(x.shape[:-1] + (n_heads, d_head)).transpose(-2, -3).contiguous()

    def merge_heads(x):
        # [..., n_heads, n_ctx, d_head] -> [..., n_ctx, d_model]
        x = x.transpose(-2, -3).contiguous()
        return x.view(x.shape[:-2] + (d_model,))

    if causal:
        # 0 on/below the diagonal, 1 - 1/0 = -inf above it.
        additive_mask = 1.0 - 1.0 / torch.tril(torch.ones(n_ctx, n_ctx, device=device))
    else:
        additive_mask = torch.zeros(n_ctx, n_ctx, device=device)
    softmax = torch.nn.Softmax(dim=-1)

    @torch.no_grad()
    def compute(Q0, K0, V0, data, N):
        # For each stage, run N times and accumulate "data", "energy" and "time".
        sync()
        t = perf_counter()
        for _ in range(N):
            Q = split_heads(Q0) / math.sqrt(d_head)  # queries pre-scaled by 1/sqrt(d_head)
            K = split_heads(K0)
            V = split_heads(V0)
            sync()
            data["split_heads"]["data"] += torch.numel(Q) + torch.numel(K) + torch.numel(V)
        data["split_heads"]["energy"] += 3.0 * N * (torch.numel(Q) + torch.numel(K) + torch.numel(V))
        data["split_heads"]["time"] += perf_counter() - t

        t = perf_counter()
        for _ in range(N):
            QKT = torch.matmul(Q, K.transpose(-1, -2)).contiguous()
            sync()
            data["matmul(Q,K^T)"]["data"] += torch.numel(Q) + torch.numel(K) + torch.numel(QKT)
        data["matmul(Q,K^T)"]["energy"] += 3.0 * N * torch.numel(Q) * n_ctx
        data["matmul(Q,K^T)"]["time"] += perf_counter() - t

        t = perf_counter()
        for _ in range(N):
            Z = QKT + additive_mask
        sync()  # once after the loop, unlike the other stages
        data["masking_attention"]["data"] += N * (torch.numel(Z) + torch.numel(QKT) + torch.numel(additive_mask))
        data["masking_attention"]["energy"] += 3.0 * N * 4.0 * (torch.numel(Z))  # ? unverified estimate
        data["masking_attention"]["time"] += perf_counter() - t

        t = perf_counter()
        for _ in range(N):
            A = softmax(Z)
            sync()
            data["compute_attention"]["data"] += torch.numel(Z) + torch.numel(A)
        data["compute_attention"]["energy"] += 3.0 * N * 5.0 * (torch.numel(A))  # ? unverified estimate
        data["compute_attention"]["time"] += perf_counter() - t

        t = perf_counter()
        for _ in range(N):
            AV = torch.matmul(A, V)
            sync()
            data["matmul(A,V)"]["data"] += torch.numel(A) + torch.numel(V) + torch.numel(AV)
        data["matmul(A,V)"]["energy"] += 3.0 * N * (torch.numel(A) * d_head)
        data["matmul(A,V)"]["time"] += perf_counter() - t

        t = perf_counter()
        for _ in range(N):
            merge_heads(AV)
            sync()
            data["merge_heads"]["data"] += torch.numel(AV)
        data["merge_heads"]["energy"] += 3.0 * N * (torch.numel(AV))
        data["merge_heads"]["time"] += perf_counter() - t

    start_time = perf_counter()
    idx = 0
    while perf_counter() < start_time + duration:
        compute(Q, K, V, data, n_iters)
        idx += 1
    return (idx, data)

In [None]:
# Run the attention micro-benchmark with its default settings (about 30 s of rounds).
(idx, data) = timing_study()

In [None]:
# Number of benchmark rounds completed.
idx

In [None]:
# Raw per-stage accumulators from the benchmark.
data

In [None]:
# Totals across all benchmark stages.
ttime = sum(v["time"] for v in data.values())
tenergy = sum(v["energy"] for v in data.values())

# Per stage: share of total time, energy/time in units of 1e12 per second,
# and data/time in units of 1e9 elements per second.
for (k, v) in data.items():
    print(k,
          v["time"] / ttime,
          v["energy"] / v["time"] / 1E12,
          v["data"] / v["time"] / 1E9)

In [None]:
# Aggregate energy/time over all stages, in units of 1e12 per second.
tenergy / ttime / 1E12

In [None]:
# Sanity check of the causal additive mask: 0 on/below the diagonal and -inf above,
# so row i of the softmax is uniform (1/(i+1)) over its first i+1 entries.
s = torch.nn.Softmax(dim=-1)
x = 1.0 - 1.0 / torch.tril(torch.ones(5, 5, device='cpu'))
s(x)