In [20]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from tqdm import tqdm 
from pathlib import Path

In [21]:
# Hyperparameters / run configuration
device = torch.device("cpu")
block_size = 64   # tokens per training window (see get_batch slicing)
batch_size = 12   # windows sampled per batch
seed = 1337       # fixed RNG seed so get_batch's torch.randint sampling is reproducible
torch.manual_seed(seed)
batch_size = 12


In [22]:
# Load the corpus and build a character-level vocabulary.
data_path = Path("data/tiny.txt")
# The message must name the path actually read (it previously said data/text.txt).
assert data_path.exists(), f"Create {data_path} with some text"
text = data_path.read_text(encoding="utf-8")
# An empty file would yield vocab_size == 0 and confusing failures in later cells.
assert text, f"{data_path} is empty; add some text to it"
chars = sorted(set(text))  # unique characters of the corpus, in sorted order
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
itos = {i: ch for ch, i in stoi.items()}      # integer id -> character (inverse of stoi)

In [23]:
def encode(s):
    """
    Convert a string into its list of integer token ids.

    Every character is looked up in the module-level ``stoi`` table, so a
    character that is not a key of ``stoi`` raises ``KeyError``.
    """
    ids = []
    for ch in s:
        ids.append(stoi[ch])
    return ids

def decode(l):
    """
    Convert a sequence of integer token ids back into a string.

    Each id is looked up in the module-level ``itos`` table; an id that is not
    a key of ``itos`` raises ``KeyError``.
    """
    pieces = [itos[idx] for idx in l]
    return "".join(pieces)

In [24]:
# Encode the whole corpus into a flat list of token ids (one id per character,
# despite the variable name).
pairs = encode(text)
pairs[0:20]

[25, 1, 29, 22, 33, 21, 28, 1, 35, 21, 28, 33, 1, 42, 45, 55, 44, 45, 50, 43]

In [26]:
# Decode the full id list back to text: a character-level tokenizer must
# round-trip the corpus losslessly.
decode_pairs = decode(pairs)
assert decode_pairs == text, "encode/decode round-trip did not reproduce the corpus"
decode_pairs[:20]

'I OFTEN WENT fishing'

In [5]:
# Encode the corpus as a 1-D long tensor and split it into train / validation parts.
train_frac = 0.9  # fraction of tokens used for training; the rest is validation
data = torch.tensor(encode(text), dtype=torch.long)
n = len(data)
split_idx = int(train_frac * n)  # computed once so both slices share the same boundary
train_data = data[:split_idx]
val_data = data[split_idx:]

In [28]:
def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: ``"train"`` to sample from ``train_data`` or ``"val"`` to
            sample from ``val_data``.

    Returns:
        A tuple ``(x, y)`` of tensors moved to ``device``. Each has
        ``batch_size`` rows of ``block_size`` tokens (assuming the split
        tensors are 1-D, as built by ``torch.tensor(encode(text))``). Row ``b``
        of ``y`` is row ``b`` of ``x`` shifted one token further along the
        source sequence.

    Raises:
        ValueError: if ``split`` is not ``"train"`` or ``"val"``, or if the
            selected split is too short to hold one window plus its target.
    """
    if split == "train":
        src = train_data
    elif split == "val":
        src = val_data
    else:
        # Reject unknown names instead of silently falling back to val_data.
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    # Each start index i needs block_size + 1 tokens (inputs plus the shifted
    # target), so valid starts are 0 .. len(src) - block_size - 1.
    max_start = len(src) - block_size
    if max_start <= 0:
        raise ValueError(
            f"{split} split has {len(src)} tokens; need at least "
            f"{block_size + 1} for block_size={block_size}"
        )
    ix = torch.randint(max_start, (batch_size,))
    x = torch.stack([src[i : i + block_size] for i in ix])
    y = torch.stack([src[i + 1 : i + block_size + 1] for i in ix])
    return x.to(device), y.to(device)

In [29]:
batches = get_batch(split="train")
# Sanity check: each target row is its input row shifted left by one token.
assert torch.equal(batches[0][:, 1:], batches[1][:, :-1])
batches[1]

tensor([[40,  1, 55, 51, 49, 41, 56, 44, 45, 50, 43,  8,  1, 24, 51, 59,  1, 37,
         38, 51, 57, 56,  1, 56, 44, 41,  1, 56, 45, 49, 41,  1, 61, 51, 57,  1,
         43, 37, 58, 41,  1, 37,  0,  1, 48, 37, 54, 43, 41,  1, 39, 51, 50, 56,
         54, 45, 38, 57, 56, 45, 51, 50,  1, 56],
        [ 1, 61, 51, 57,  1, 59, 41, 54, 41,  1, 38, 51, 54, 50,  1, 59, 37, 55,
          0,  1, 52, 41, 54, 42, 51, 54, 49, 41, 40,  1, 38, 41, 39, 37, 57, 55,
         41,  1, 61, 51, 57,  1, 59, 37, 50, 56, 41, 40,  1, 55, 51, 49, 41, 56,
         44, 45, 50, 43,  8,  1, 24, 51, 59,  1],
        [54, 56, 44,  1, 37,  1, 48, 51, 56,  1, 56, 51,  1, 37,  1, 44, 51, 56,
         41, 48,  6,  1, 45, 55, 50, 65, 56,  1, 45, 56, 16, 65,  0,  1, 17, 55,
          1, 25,  1, 56, 37, 48, 47, 41, 40,  6,  1, 25,  1, 59, 54, 51, 56, 41,
          1, 56, 44, 41, 55, 41,  1, 56, 59, 51],
        [51,  1, 48, 51, 51, 47,  1, 37, 56,  1, 61, 51, 57, 54,  1, 44, 51, 56,
         41, 48,  1, 37, 55,  1, 25,  0,