In [22]:
%load_ext autoreload
%autoreload 2

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from tinystories.data import TinyStoriesDataset
from tinystories.model import Transformer
from sentencepiece import SentencePieceProcessor

In [8]:
model = Transformer(
    vocab_size=512,
    embed_size=128,
    num_layers=4,
    heads=8,
    dropout=0.1,
    forward_expansion=256,
    device="mps:0"
)
tokenizer = SentencePieceProcessor(model_file="../tinystories_tokenizer.model")



In [9]:
model.to("mps:0")



Transformer(
  (embedding): Embedding(512, 128)
  (pos_embedding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_layers): ModuleList(
    (0-3): 4 x TransformerBlock(
      (attention): MultiHeadAttention(
        (query): Linear(in_features=128, out_features=128, bias=True)
        (key): Linear(in_features=128, out_features=128, bias=True)
        (value): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (out): Linear(in_features=128, out_features=128, bias=True)
      )
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (feed_forward): Sequential(
        (0): Linear(in_features=128, out_features=32768, bias=True)
        (1): ReLU()
        (2): Linear(in_features=32768, out_features=128, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)

In [10]:
dataset = TinyStoriesDataset(
    tokenizer=tokenizer,
    max_len=512,
    file_path="../data/TinyStories-train.txt"
)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [15]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [17]:
for batch_idx, batch in enumerate(dataloader):
    input_tokens, target_tokens = batch[:, :-1], batch[:, 1:]
    input_tokens = input_tokens.to("mps:0")
    target_tokens = target_tokens.to("mps:0")

    out = model(input_tokens)
    loss = F.cross_entropy(out.contiguous().view(-1, out.size(-1)), target_tokens.contiguous().view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(loss.item())



torch.Size([32, 511])
torch.Size([32, 511])
torch.Size([32, 511, 128])


RuntimeError: Expected target size [32, 128], got [32, 511]