In [1]:
%load_ext autoreload

In [2]:
from lovely_tensors.patch import monkey_patch

monkey_patch()
import torch
from transformers import GPT2Tokenizer
import wandb
from tqdm.auto import tqdm

In [3]:
with open("tiny_shakespeare.txt", "r") as f:
    data = f.read()
chars = sorted(list(set(data)))

# create a mapping from characters to integers
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    return [stoi[c] for c in s]  # encoder: take a string, output a list of integers


def decode(l):
    return "".join(
        [itos[i] for i in l]
    )  # decoder: take a list of integers, output a string


encoded_data = encode(data)

In [4]:
train_data = encoded_data[: int(len(encoded_data) * 0.8)]
val_data = encoded_data[int(len(encoded_data) * 0.8) :]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_data = torch.tensor(train_data).to(device)
val_data = torch.tensor(val_data).to(device)

In [23]:
def get_item(data, ctx):
    # i = random.randint(0, len(data) - ctx - 1)
    i = 0
    while i + ctx < len(data):
        src = data[i : i + ctx]
        dst = data[i + 1 : i + ctx + 1]
        yield src, dst
        i += ctx

import random
def get_random_item(data, ctx):
    i = random.randint(0, len(data) - ctx - 1)
    src = data[i : i + ctx]
    dst = data[i + 1 : i + ctx + 1]
    yield src, dst

def get_batch(data, ctx_len, batch_size):
    """Yields a tuple of tensors of shape (batch_size, ctx).
    X, shape=B C
    y, shape=B C
    """
    gen = get_random_item(data, ctx_len)

    try:
        while True:
            X, y = zip(*[next(gen) for _ in range(batch_size)])
            yield torch.stack(X), torch.stack(y)
    except StopIteration:
        pass


# for batch in get_batch(train_data[:100], ctx_len=5, batch_size=2):
#     print(batch)

In [6]:
# batch_gen = get_batch(train_data, ctx_len=CTX_LEN, batch_size=32)
# X, y = next(batch_gen)
# with torch.no_grad():
#     preds = model(model_params, X, vocab_size=len(chars))
#     loss = torch.nn.functional.cross_entropy(preds, y[:, -1])
# loss

In [7]:
tt = torch.randn(10,10)
tt

tensor[10, 10] n=100 x∈[-3.179, 2.142] μ=-0.068 σ=0.929

In [8]:
torch.nn.init.xavier_uniform_(tt)

tensor[10, 10] n=100 x∈[-0.540, 0.543] μ=0.049 σ=0.296

In [9]:
tt

tensor[10, 10] n=100 x∈[-0.540, 0.543] μ=0.049 σ=0.296

In [10]:
l = torch.nn.Linear(10, 10)
l.weight

Parameter containing:
Parameter[10, 10] n=100 x∈[-0.300, 0.315] μ=-0.009 σ=0.177 grad

In [21]:
CTX_LEN = 1
EMBEDDING_DIM = len(chars)


# in_ # BATCH_SIZE x CTX_LEN x len(chars)  (feed in a batch of CTX_LEN embeddings (each is a one-hot-encoded character)
# out_ # BATCH_SIZE x len(chars)  (get out a (one-hot-encoded) next-char prediction for each batch-item
# l1 = torch.nn.Linear(len(chars) * CTX_LEN, len(chars)).to(device)

# model_params = {
#     "embedding": torch.randn((len(chars), EMBEDDING_DIM), device=device, requires_grad = True),
#     # "w": torch.randn((EMBEDDING_DIM * CTX_LEN, len(chars)), device=device, requires_grad = True),
# }

# # glorot init
# for p in model_params.values():
#     # torch.nn.init.kaiming_uniform_(p)
    


def model(params, input_ids, vocab_size):
    """This model takes in a sequence and predicts 1 token"""

    one_hot_inputs = torch.nn.functional.one_hot(input_ids, num_classes=vocab_size)
    one_hot_inputs = one_hot_inputs.float()

    hidden_states = one_hot_inputs @ params["embedding"] # N, CTX_LEN, EMBEDDING_DIM
    
    # preds = hidden_states[:, -1, :] # @ params["w"]

    preds = hidden_states.view((input_ids.shape[0], -1)) @ params["w"]

    return preds


# batch_gen = get_batch(train_data, ctx_len=CTX_LEN, batch_size=32)
# X, y = next(batch_gen)
# with torch.no_grad():
#     preds = model(model_params, X, vocab_size=len(chars))
#     loss = torch.nn.functional.cross_entropy(preds, y[:, -1])
# loss

In [24]:
BATCH_SIZE = 32
NEPOCH = 10
LR = 0.1
LOG_INTERVAL = 500
VALIDATION_INTERVAL = len(train_data) // BATCH_SIZE // 5


wandb.init(
    # set the wandb project where this run will be logged
    project="my-awesome-project",
    entity="llmnerds",
    config={
        "batch_size": BATCH_SIZE,
        "ctx": CTX_LEN,
    },
)

model_params = {
    "embedding": torch.randn((len(chars), EMBEDDING_DIM), device=device, requires_grad = True),
    "w": torch.randn((EMBEDDING_DIM * CTX_LEN, len(chars)), device=device, requires_grad = True),
}

i = 1
total_loss = 0
val_total_loss = 0
for epoch in range(NEPOCH):
    # shuffle train data
    # train_data = train_data[torch.randperm(train_data.shape[0])]

    train_generator = get_batch(train_data, ctx_len=CTX_LEN, batch_size=BATCH_SIZE)

    for X, y in tqdm(train_generator):
        preds = model(params=model_params, input_ids=X, vocab_size=len(chars))
        loss = torch.nn.functional.cross_entropy(preds, y[:, -1])
        total_loss += loss.item()
        loss.backward()

        with torch.no_grad():
            for param in model_params.values():
                param -= LR * param.grad
                param.grad.zero_()

        if i % LOG_INTERVAL == 0:
            wandb.log({"loss": total_loss / LOG_INTERVAL}, step=i)
            total_loss = 0

        if i % VALIDATION_INTERVAL == 0:
            j = 0
            for X, y in get_batch(val_data, ctx_len=CTX_LEN, batch_size=BATCH_SIZE * 4):
                preds = model(params=model_params, input_ids=X, vocab_size=len(chars))
                loss = torch.nn.functional.cross_entropy(preds, y[:, -1])
                val_total_loss += loss.item()
                j += 1
            wandb.log({"val_loss": val_total_loss / j}, step=i)
            val_total_loss = 0
        i += 1

wandb.finish()

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112805566659114, max=1.0…

0it [00:00, ?it/s]

TypeError: 'tuple' object is not an iterator