In [1]:
%load_ext autoreload

In [2]:
from lovely_tensors.patch import monkey_patch

monkey_patch()
import torch
from transformers import GPT2Tokenizer
import wandb
from tqdm.auto import tqdm

In [3]:
with open("tiny_shakespeare.txt", "r") as f:
    data = f.read()
chars = sorted(list(set(data)))

# create a mapping from characters to integers
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    return [stoi[c] for c in s]  # encoder: take a string, output a list of integers


def decode(l):
    return "".join(
        [itos[i] for i in l]
    )  # decoder: take a list of integers, output a string


encoded_data = encode(data)

In [4]:
train_data = encoded_data[: int(len(encoded_data) * 0.8)]
val_data = encoded_data[int(len(encoded_data) * 0.8) :]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_data = torch.tensor(train_data).to(device)
val_data = torch.tensor(val_data).to(device)

In [5]:
def get_item(data, ctx):
    # i = random.randint(0, len(data) - ctx - 1)
    i = 0
    while i + ctx < len(data):
        src = data[i : i + ctx]
        dst = data[i + 1 : i + ctx + 1]
        yield src, dst
        i += ctx


import random


def get_batch(data, ctx_len, batch_size, shuffle=True):
    """Yields a tuple of tensors of shape (batch_size, ctx).
    X, shape=B C
    y, shape=B C
    """

    if shuffle:
        i = random.randint(0, batch_size - 1)
        items = list(get_item(data[i:], ctx_len))

        random.shuffle(items)
        iter_items = iter(items)
    else:
        iter_items = get_item(data, ctx_len)

    try:
        while True:
            X, y = zip(*[next(iter_items) for _ in range(batch_size)])
            yield torch.stack(X), torch.stack(y)
    except StopIteration:
        pass


for batch in get_batch(train_data[:100], ctx_len=5, batch_size=2):
    print(batch)

(tensor[2, 5] i64 n=10 x∈[1, 58] μ=36.300 σ=22.346 cuda:0 [[1, 46, 43, 39, 56], [57, 58, 1, 15, 47]], tensor[2, 5] i64 n=10 x∈[1, 58] μ=36.400 σ=22.451 cuda:0 [[46, 43, 39, 56, 1], [58, 1, 15, 47, 58]])
(tensor[2, 5] i64 n=10 x∈[0, 54] μ=33.100 σ=20.058 cuda:0 [[10, 0, 31, 54, 43], [54, 43, 39, 49, 8]], tensor[2, 5] i64 n=10 x∈[0, 54] μ=30.600 σ=20.326 cuda:0 [[0, 31, 54, 43, 39], [43, 39, 49, 8, 0]])
(tensor[2, 5] i64 n=10 x∈[0, 64] μ=39.200 σ=25.616 cuda:0 [[63, 1, 44, 59, 56], [64, 43, 52, 10, 0]], tensor[2, 5] i64 n=10 x∈[0, 59] μ=33.700 σ=24.518 cuda:0 [[1, 44, 59, 56, 58], [43, 52, 10, 0, 14]])
(tensor[2, 5] i64 n=10 x∈[6, 58] μ=41.900 σ=17.860 cuda:0 [[14, 43, 44, 53, 56], [58, 46, 43, 56, 6]], tensor[2, 5] i64 n=10 x∈[1, 56] μ=39.100 σ=19.519 cuda:0 [[43, 44, 53, 56, 43], [46, 43, 56, 6, 1]])
(tensor[2, 5] i64 n=10 x∈[1, 57] μ=30.500 σ=24.959 cuda:0 [[1, 51, 43, 1, 57], [39, 49, 6, 1, 57]], tensor[2, 5] i64 n=10 x∈[1, 57] μ=37.300 σ=24.281 cuda:0 [[51, 43, 1, 57, 54], [49, 6, 1

In [6]:
# batch_gen = get_batch(train_data, ctx_len=CTX_LEN, batch_size=32)
# X, y = next(batch_gen)
# with torch.no_grad():
#     preds = model(model_params, X, vocab_size=len(chars))
#     loss = torch.nn.functional.cross_entropy(preds, y[:, -1])
# loss

In [7]:
tt = torch.randn(10, 10)
tt

tensor[10, 10] n=100 x∈[-2.417, 3.082] μ=-0.049 σ=1.094

In [8]:
torch.nn.init.xavier_uniform_(tt)

tensor[10, 10] n=100 x∈[-0.545, 0.531] μ=-0.032 σ=0.311

In [9]:
tt

tensor[10, 10] n=100 x∈[-0.545, 0.531] μ=-0.032 σ=0.311

In [10]:
l = torch.nn.Linear(10, 10)
l.weight

Parameter containing:
Parameter[10, 10] n=100 x∈[-0.313, 0.309] μ=-0.013 σ=0.196 grad

In [11]:


# in_ # BATCH_SIZE x CTX_LEN x len(chars)  (feed in a batch of CTX_LEN embeddings (each is a one-hot-encoded character)
# out_ # BATCH_SIZE x len(chars)  (get out a (one-hot-encoded) next-char prediction for each batch-item
# l1 = torch.nn.Linear(len(chars) * CTX_LEN, len(chars)).to(device)

# model_params = {
#     "embedding": torch.randn((len(chars), EMBEDDING_DIM), device=device, requires_grad = True),
#     # "w": torch.randn((EMBEDDING_DIM * CTX_LEN, len(chars)), device=device, requires_grad = True),
# }


def model(params, input_ids, vocab_size):
    """This model takes in a sequence and predicts 1 token"""

    one_hot_inputs = torch.nn.functional.one_hot(input_ids, num_classes=vocab_size)
    one_hot_inputs = one_hot_inputs.float()

    embeddings = one_hot_inputs @ params["embedding"].T  # N, CTX_LEN, EMBEDDING_DIM

    # preds = hidden_states[:, -1, :] # @ params["w"]

    hidden_state = (
        embeddings.view((input_ids.shape[0], -1)) @ params["w1"].T + params["b1"]
    )
    hidden_state = torch.nn.functional.relu(hidden_state)

    hidden_state = (
        hidden_state.view((input_ids.shape[0], -1)) @ params["w2"].T + params["b2"]
    )
    hidden_state = torch.nn.functional.relu(hidden_state)

    preds = hidden_state @ params["embedding"]

    return preds


# batch_gen = get_batch(train_data, ctx_len=CTX_LEN, batch_size=32)
# X, y = next(batch_gen)
# with torch.no_grad():
#     preds = model(model_params, X, vocab_size=len(chars))
#     loss = torch.nn.functional.cross_entropy(preds, y[:, -1])
# loss

In [19]:
CTX_LEN = 32

EMBEDDING_DIM = 128

INTERMEDIATE_DIM = EMBEDDING_DIM * 8

BATCH_SIZE = 4096
LR = 0.1
LOG_INTERVAL = len(train_data) // BATCH_SIZE // 10
VALIDATION_INTERVAL = len(train_data) // BATCH_SIZE // 5


TRAIN_TOKENS = len(train_data) * 10


wandb.init(
    # set the wandb project where this run will be logged
    project="my-awesome-project",
    entity="llmnerds",
    config={
        "batch_size": BATCH_SIZE,
        "ctx": CTX_LEN,
    },
)

model_params = {
    "embedding": torch.randn((EMBEDDING_DIM, len(chars)), device=device),
    "w1": torch.randn((INTERMEDIATE_DIM, EMBEDDING_DIM * CTX_LEN), device=device),
    "b1": torch.randn((INTERMEDIATE_DIM,), device=device),
    "w2": torch.randn((EMBEDDING_DIM, INTERMEDIATE_DIM), device=device),
    "b2": torch.randn((EMBEDDING_DIM,), device=device),
    # "classifier": torch.randn(
    #     (INTERMEDIATE_DIM, len(chars)), device=device, requires_grad=True
    # ),
}

# # glorot init
for p in model_params.values():
    if len(p.shape) == 2:
        torch.nn.init.kaiming_normal_(p)
    p.requires_grad = True


i = 1
total_loss = 0
val_total_loss = 0


optim = torch.optim.Adam(model_params.values(), lr=1e-3)

batch_gen = get_batch(train_data, ctx_len=CTX_LEN, batch_size=BATCH_SIZE, shuffle=True)
while True: #i < TRAIN_TOKENS:
    try:
        X, y = next(batch_gen)
    except StopIteration:
        batch_gen = get_batch(
            train_data, ctx_len=CTX_LEN, batch_size=BATCH_SIZE, shuffle=True
        )
        Y, y = next(batch_gen)

    token_count = i * BATCH_SIZE * CTX_LEN

    preds = model(params=model_params, input_ids=X, vocab_size=len(chars))
    loss = torch.nn.functional.cross_entropy(preds, y[:, -1])
    total_loss += loss.item()
    loss.backward()

    with torch.no_grad():
        optim.step()
        optim.zero_grad()

        # for param in model_params.values():
        #     param -= LR * param.grad
        #     param.grad.zero_()


    if i % LOG_INTERVAL == 0:
        wandb.log(
            {"loss": total_loss / LOG_INTERVAL, "epoch": (token_count) // len(train_data)},
            step=token_count,
        )
        total_loss = 0

    if i % VALIDATION_INTERVAL == 0:
        j = 0
        for X_val, y_val in get_batch(
            val_data, ctx_len=CTX_LEN, batch_size=BATCH_SIZE, shuffle=False
        ):
            with torch.no_grad():
                preds = model(params=model_params, input_ids=X_val, vocab_size=len(chars))
                loss = torch.nn.functional.cross_entropy(preds, y_val[:, -1])
                val_total_loss += loss.item()
                j += 1
        wandb.log({"val_loss": val_total_loss / j}, step=token_count)
        val_total_loss = 0
    i += 1

wandb.finish()



VBox(children=(Label(value='0.010 MB of 0.033 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.319777…

0,1
epoch,▁█
loss,█▁

0,1
epoch,6.0
loss,2.97865


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112229777791072, max=1.0…

KeyboardInterrupt: 