In [1]:
%load_ext autoreload

In [2]:
from lovely_tensors.patch import monkey_patch

monkey_patch()
import torch
from transformers import GPT2Tokenizer
import wandb
from tqdm.auto import tqdm

In [3]:
# Load the whole corpus into memory as one string. The encoding is pinned so the
# resulting character vocabulary does not depend on the platform's default codec.
with open("tiny_shakespeare.txt", "r", encoding="utf-8") as f:
    data = f.read()

# Character-level vocabulary: every distinct character in the corpus, sorted so
# the integer ids assigned below are stable from run to run.
chars = sorted(set(data))

# Two-way lookup tables between characters and their integer ids.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Encoder: map each character of `s` to its integer id via `stoi`.

    Returns a list of integers, one per character, in the original order.
    A character missing from `stoi` raises `KeyError`.
    """
    ids = []
    for ch in s:
        ids.append(stoi[ch])
    return ids


def decode(l):
    """Decoder: map each integer id in `l` back to its character via `itos`.

    Returns the characters joined into a single string. An id missing from
    `itos` raises `KeyError`.
    """
    symbols = (itos[token_id] for token_id in l)
    return "".join(symbols)

# Encode the full corpus once up front as a flat list of integer ids.
encoded_data = encode(data)

In [4]:
# 80/20 split by position: the first 80% of the ids are used for training and
# the remaining 20% for validation (nothing is shuffled before the split).
train_data = encoded_data[: int(len(encoded_data) * 0.8)]
val_data = encoded_data[int(len(encoded_data) * 0.8) :]

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Convert both splits from Python lists to tensors and move them to `device`,
# so slices taken from them later already live there.
train_data = torch.tensor(train_data).to(device)
val_data = torch.tensor(val_data).to(device)

In [148]:
def get_item(data, ctx):
    """Walk `data` front to back in non-overlapping windows of `ctx` elements.

    Each step yields `(src, dst)`, where `dst` is the same window as `src`
    moved one position to the right. Iteration ends as soon as a window no
    longer has a following element to complete `dst`, so a trailing partial
    window is never produced.
    """
    start = 0
    while True:
        stop = start + ctx
        if not stop < len(data):
            return
        yield data[start:stop], data[start + 1 : stop + 1]
        start = stop


def get_epoch(data, ctx_len, batch_size, shuffle=True):
    """Iterate over `data` once, in order, grouped into stacked batches.

    Yields `(X, y)` tuples, each built by `torch.stack`-ing `batch_size`
    consecutive `(src, dst)` windows taken from `get_item(data, ctx_len)`.
    X, shape=B C
    y, shape=B C

    Leftover windows that cannot fill a complete batch are dropped.
    `shuffle` is accepted but not used by this function.
    """
    windows = get_item(data, ctx_len)

    while True:
        pairs = []
        try:
            for _ in range(batch_size):
                pairs.append(next(windows))
        except StopIteration:
            # Source ran dry mid-batch: discard the partial batch and finish.
            return

        sources, targets = zip(*pairs)
        yield torch.stack(sources), torch.stack(targets)

In [42]:
import random


def get_random_item(data, ctx):
    """Pick one random window of `ctx` elements from `data` plus its shifted twin.

    The start offset is drawn with `random.randint`, bounded so that the
    shifted-by-one target slice still ends inside `data`.

    Returns `(src, dst)`, where `dst` is `src` moved one position to the right.
    """
    last_start = len(data) - ctx - 1
    start = random.randint(0, last_start)
    stop = start + ctx

    return data[start:stop], data[start + 1 : stop + 1]


def get_batch(data, ctx_len, batch_size, shuffle=True):
    """Build one batch out of `batch_size` independently drawn random windows.

    Returns a tuple of tensors, each stacked from the `(src, dst)` pairs
    produced by `get_random_item(data, ctx_len)`.
    X, shape=B C
    y, shape=B C

    `shuffle` is accepted but not used by this function.
    """
    sources, targets = zip(
        *(get_random_item(data, ctx_len) for _ in range(batch_size))
    )
    return torch.stack(sources), torch.stack(targets)


# Quick smoke test on the first 100 training tokens: one batch of 2 windows of length 5.
get_batch(train_data[:100], ctx_len=5, batch_size=2)

(tensor[2, 5] i64 n=10 x∈[0, 56] μ=37.700 σ=18.833 cuda:0 [[43, 44, 53, 56, 43], [10, 0, 31, 54, 43]],
 tensor[2, 5] i64 n=10 x∈[0, 56] μ=36.400 σ=20.354 cuda:0 [[44, 53, 56, 43, 1], [0, 31, 54, 43, 39]])

In [6]:

def model(params, input_ids, vocab_size):
    """This model takes in a sequence and predicts 1 token.

    A two-layer ReLU MLP over the concatenated token embeddings of the whole
    input window. The embedding table is used twice: for the input lookup and,
    untransposed, as the output projection (tied weights).

    Args:
        params: dict of tensors with the keys
            "embedding": (EMBEDDING_DIM, vocab_size)
            "w1": (INTERMEDIATE_DIM, CTX_LEN * EMBEDDING_DIM)
            "b1": (INTERMEDIATE_DIM,)
            "w2": (EMBEDDING_DIM, INTERMEDIATE_DIM)
            "b2": (EMBEDDING_DIM,)
        input_ids: integer token ids, shape (N, CTX_LEN).
        vocab_size: number of distinct tokens; must equal the number of
            columns of `params["embedding"]`.

    Returns:
        Unnormalised scores of shape (N, vocab_size), one row per input
        sequence.

    Raises:
        ValueError: if `vocab_size` disagrees with the embedding table.
    """
    embedding = params["embedding"]
    if embedding.shape[1] != vocab_size:
        raise ValueError(
            f"vocab_size={vocab_size} does not match the embedding table, "
            f"which has {embedding.shape[1]} columns"
        )

    # Row lookup in the (vocab_size, EMBEDDING_DIM) view of the table. This picks
    # the same rows that `one_hot(input_ids) @ embedding.T` would, without
    # materialising the dense one-hot tensor or running a matmul over it.
    embeddings = torch.nn.functional.embedding(
        input_ids, embedding.T
    )  # N, CTX_LEN, EMBEDDING_DIM

    # Concatenate the per-position embeddings so the first layer sees the
    # whole window at once.
    flat = embeddings.flatten(start_dim=1)  # N, CTX_LEN * EMBEDDING_DIM

    hidden_state = torch.nn.functional.relu(flat @ params["w1"].T + params["b1"])
    hidden_state = torch.nn.functional.relu(
        hidden_state @ params["w2"].T + params["b2"]
    )

    # Tied output projection: reuse the embedding table to score every token.
    preds = hidden_state @ embedding

    return preds



In [135]:
def generate(model, model_params, ctx_len, encoded_prompt, n_tokens, vocab_size=None):
    """Generate n_tokens after prompt, greedily, for every prompt in the batch.

    At each step the last `ctx_len` tokens of every row are passed to `model`,
    the highest-scoring token per row (`argmax` over dim 1) is selected, and it
    is appended to that row.

    Args:
        model: callable invoked as `model(model_params, input_ids, vocab_size=...)`
            that returns one row of scores per input row.
        model_params: passed through to `model` unchanged.
        ctx_len: number of trailing tokens fed to `model` at each step.
        encoded_prompt: 2-D integer tensor, one prompt per row. The caller's
            tensor is not modified.
        n_tokens: how many tokens to append to each row.
        vocab_size: value forwarded to `model`. When omitted, falls back to
            `len(chars)` (the notebook-level vocabulary), which is what this
            function always used before the parameter existed.

    Returns:
        Tensor holding each prompt followed by its `n_tokens` generated tokens.
    """
    if vocab_size is None:
        vocab_size = len(chars)

    tokens = encoded_prompt
    with torch.no_grad():
        for _ in range(n_tokens):
            window = tokens[:, -ctx_len:]
            preds = model(model_params, window, vocab_size=vocab_size)
            next_token = torch.argmax(preds, dim=1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
    return tokens

In [146]:
# Number of tokens in each input window.
CTX_LEN = 12

# Width of each token embedding.
EMBEDDING_DIM = 128

# Width of the hidden layer: 8x the embedding width.
INTERMEDIATE_DIM = EMBEDDING_DIM * 8

# Number of windows per training step.
BATCH_SIZE = 4096
# NOTE(review): LR is only referenced by the commented-out manual SGD update in
# the training loop; the Adam optimizer is constructed with a literal lr=1e-3.
LR = 0.1
# One training step consumes BATCH_SIZE * CTX_LEN tokens. These intervals work
# out to roughly 10 log points and 5 validation passes per len(train_data)
# tokens processed; the +1 keeps both intervals at least 1.
LOG_INTERVAL = (len(train_data) // (BATCH_SIZE * CTX_LEN) // 10) + 1
VALIDATION_INTERVAL = (len(train_data) // (BATCH_SIZE * CTX_LEN) // 5) + 1
# Training budget: stop once 30x the number of training tokens has been processed.
TRAIN_TOKENS = len(train_data) * 30


# Start a Weights & Biases run. `run` is used further down as the context
# manager wrapping the training loop. Only the batch size and context length
# are recorded in the run config.
run = wandb.init(
    # set the wandb project where this run will be logged
    project="my-awesome-project",
    entity="llmnerds",
    config={
        "batch_size": BATCH_SIZE,
        "ctx": CTX_LEN,
    },
)

# Raw parameter tensors for `model`, created directly on `device`.
# "embedding" is stored as (EMBEDDING_DIM, vocab); `model` uses it both for the
# input lookup and as the output projection.
model_params = {
    "embedding": torch.randn((EMBEDDING_DIM, len(chars)), device=device),
    "w1": torch.randn((INTERMEDIATE_DIM, EMBEDDING_DIM * CTX_LEN), device=device),
    "b1": torch.randn((INTERMEDIATE_DIM,), device=device),
    "w2": torch.randn((EMBEDDING_DIM, INTERMEDIATE_DIM), device=device),
    "b2": torch.randn((EMBEDDING_DIM,), device=device),
    # "classifier": torch.randn(
    #     (INTERMEDIATE_DIM, len(chars)), device=device, requires_grad=True
    # ),
}

# Re-initialise every 2-D weight matrix in place with Kaiming-normal init (the
# 1-D biases keep their randn values), then turn on gradient tracking for all
# parameters. Gradients are enabled only after the in-place init.
for p in model_params.values():
    if len(p.shape) == 2:
        torch.nn.init.kaiming_normal_(p)
    p.requires_grad = True


# Step counter; starts at 1 so the modulo checks in the loop first fire after a
# full interval rather than on the very first step.
i = 1
# Running sums of the training / validation losses between two log points.
total_loss = 0
val_total_loss = 0

# Number of training tokens processed so far (recomputed from `i` every step).
token_count = 0
# Adam over all parameter tensors, with a literal learning rate of 1e-3.
optim = torch.optim.Adam(model_params.values(), lr=1e-3)

# batch_gen = get_batch(train_data, ctx_len=CTX_LEN, batch_size=BATCH_SIZE, shuffle=True)
# Main training loop: runs until TRAIN_TOKENS tokens have been processed, with
# periodic wandb logging, validation and sample generation.
with run:
    while token_count < TRAIN_TOKENS:
        # One batch of randomly positioned windows (sampling with replacement).
        X, y = get_batch(train_data, ctx_len=CTX_LEN, batch_size=BATCH_SIZE, shuffle=True)

        # Tokens consumed up to and including this step; also used as the wandb step.
        token_count = i * BATCH_SIZE * CTX_LEN

        preds = model(params=model_params, input_ids=X, vocab_size=len(chars))
        # Only the last column of the shifted targets is used, i.e. the single
        # token that follows each input window.
        loss = torch.nn.functional.cross_entropy(input=preds, target=y[:, -1])
        total_loss += loss.item()
        loss.backward()

        with torch.no_grad():
            optim.step()
            optim.zero_grad()

            # for param in model_params.values():
            #     param -= LR * param.grad
            #     param.grad.zero_()

        # Log the mean training loss over the last LOG_INTERVAL steps, then
        # reset the accumulator.
        if i % LOG_INTERVAL == 0:
            wandb.log(
                {
                    "loss": total_loss / LOG_INTERVAL,
                    "epoch": token_count // len(train_data),
                },
                step=token_count,
            )
            total_loss = 0

        # Periodic evaluation: mean loss over one ordered pass of the
        # validation data, followed by a few generated samples.
        if i % VALIDATION_INTERVAL == 0:
            # Number of validation batches seen in this pass.
            # NOTE(review): j stays 0 if get_epoch yields no full batch, which
            # would make the division in the wandb.log call below fail.
            j = 0
            for X_val, y_val in get_epoch(
                val_data, ctx_len=CTX_LEN, batch_size=4096, shuffle=False
            ):
                with torch.no_grad():
                    preds = model(
                        params=model_params, input_ids=X_val, vocab_size=len(chars)
                    )
                    loss = torch.nn.functional.cross_entropy(preds, y_val[:, -1])
                    val_total_loss += loss.item()
                    j += 1
            wandb.log({"val_loss": val_total_loss / j}, step=token_count, commit=True)

            # Sample 5 random validation windows as prompts ([0] keeps the
            # inputs and drops the targets) and extend each by 2 * CTX_LEN tokens.
            prompts = get_batch(val_data, ctx_len=CTX_LEN, batch_size=5)[0].to(device)
            generated = generate(
                model=model,
                model_params=model_params,
                encoded_prompt=prompts,
                ctx_len=CTX_LEN,
                n_tokens=CTX_LEN * 2,
            )
            # Each row is prompt + continuation; split at CTX_LEN to show them apart.
            for p in generated:
                char_list = p.tolist()
                pre_prompt = char_list[:CTX_LEN]
                post_prompt = char_list[CTX_LEN:]

                print(f"{repr(decode(pre_prompt))}-> {repr(decode(post_prompt))}")

            # Reset the validation accumulator for the next evaluation.
            val_total_loss = 0
        i += 1

'bour\nIs the '-> '                        '
'ignior Petru'-> '                        '
'are gentleme'-> '                        '
':\nOne, Kate,'-> '                        '
'ind when he '-> '                        '
',\nAnd say sh'-> '   n                    '
'\nOf palsied '-> 'an  an  on  on  on  on  '
'o thee belon'-> '  o   an   n   an  on  o'
'athers commo'-> '  o   n                 '
'he duke gone'-> ' an  he  on  on  on  on '
'o you, trenc'-> ' the    t               '
'part\nWas apt'-> ' the  o t  t            '
'hall see how'-> '  oe the    t           '
'ckless, and '-> 'the    t                '
'schance of\nt'-> ' t  t e  t t            '
'\nISABELLA:\nH'-> 'e  oo  oo  hor  ore ror '
'leep,\nDreami'-> 'n the  ore io the  ooe  '
'ut up my tho'-> 'r  oo  hor  ore ror  oo '
'In cypress c'-> 'or  oo  ore tore toe  oo'
' for the ent'-> 'he  oo  oo  oo  oo  oo  '
'ignior Lucio'-> '  oo  hor  oo  oor sore '
' will your h'-> 'or sore sor sor  oor sor'
'one,\nThough '-> 'so  



VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
loss,█▇▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,29.0
loss,1.57991
val_loss,1.84766


In [122]:
# Draw 5 random validation windows to use as prompts; [0] keeps the input
# windows and drops the shifted targets.
prompts = get_batch(val_data, ctx_len=CTX_LEN, batch_size=5)[0].to(device)

In [147]:
# Extend every row of the current `prompts` tensor by 300 tokens.
generated = generate(
    model=model,
    model_params=model_params,
    encoded_prompt=prompts,
    ctx_len=CTX_LEN,
    n_tokens=300,
)
# Each row is prompt + continuation; split at CTX_LEN to print them apart.
for p in generated:
    char_list = p.tolist()
    pre_prompt = char_list[:CTX_LEN]
    post_prompt = char_list[CTX_LEN:]

    print(f"{repr(decode(pre_prompt))}-> {repr(decode(post_prompt))}")

'orities ther'-> 'e are the sen the words and the state of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sear'
'before and w'-> 'ith the stain and the well as the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the wi'
'RUCHIO:\nI sa'-> 'y the state of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of the with the sears of'
'h such beaut'-> 'ed the state of the w

In [150]:
import json
# Parse one JSON file into `d`. The context manager closes the file handle as
# soon as parsing is done, instead of leaving it to the garbage collector.
with open("tmp/data00.json", "r", encoding="utf-8") as f:
    d = json.load(f)

{'story': "Once upon a time, there was a cute puppy named Max. Max was very adorable with his big, brown eyes and wagging tail. One day, Max's owner, Emily, told him that they needed to go to the post office to mail a letter. Max didn't know what that meant, but he was excited to go for a car ride.\nAt the post office, Emily gave the letter to the nice lady behind the desk. The lady asked Emily for a number and Emily gave her one. Max didn't know what a number was, but he saw the lady type something on the computer.\nAfter they mailed the letter, Emily and Max went back to the car. Max was happy that they went on an adventure and he couldn't wait for the next one.",
 'instruction': {'prompt:': 'Write a short story (3-5 paragraphs) which only uses very simple words that a 3 year old child would understand. In the story, try to at some point use the verb "mail", the noun "number" and the adjective "adorable". Remember to only use simple words!',
  'words': ['mail', 'number', 'adorable'],