In [2]:
data_file = "data/input.txt"
with open(data_file) as f:
    text = f.read()

In [3]:
vocab = set()
for c in text:
    vocab.add(c)
print(vocab)
vocab_size = len(vocab)

{'q', 'F', 'J', 'G', 't', 'P', 'L', 'r', '.', '&', '3', 'p', 'h', 'R', 'S', 's', 'U', 'u', 'n', '\n', 'a', 'k', 'E', 'V', '!', '?', 'v', 'N', 'd', 'T', 'C', 'O', 'x', ' ', 'f', 'b', 'o', 'Y', 'g', 'i', 'j', 'B', 'M', 'H', 'w', ',', 'e', 'z', 'I', 'Z', ':', 'A', 'l', "'", '-', 'c', 'm', 'y', 'D', 'Q', 'W', 'X', '$', ';', 'K'}


In [4]:
stoi = {c:i for i, c in enumerate(vocab)}
itos = {i:c for i, c in enumerate(vocab)}
def encode(s): return [stoi[c] for c in s]
def decode(a): return ''.join([itos[i] for i in a])
    

In [5]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
data = torch.tensor(encode(text), dtype=torch.long)

In [7]:
data.shape

torch.Size([1115393])

In [8]:
split = int(0.9*len(data))
train_data = data[:split]
val_data = data[split:]

In [9]:
block_size=8
train_data[:block_size]

tensor([ 1, 39,  7, 15,  4, 33, 30, 39])

In [10]:
import torch.nn as nn
from torch.nn import functional as F

In [11]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)
        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss =  F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:,-1,:]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx,idx_next), dim=1)
        return idx
            
        

In [12]:
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

inputs:
torch.Size([4, 8])
tensor([[57, 15, 33,  4, 12, 46, 33, 36],
        [57, 36, 17, 33, 44, 46, 28, 28],
        [ 7, 39, 38, 12,  4, 34, 17, 52],
        [53, 33,  4, 12, 46, 33,  4, 39]])
targets:
torch.Size([4, 8])
tensor([[15, 33,  4, 12, 46, 33, 36,  4],
        [36, 17, 33, 44, 46, 28, 28, 46],
        [39, 38, 12,  4, 34, 17, 52, 33],
        [33,  4, 12, 46, 33,  4, 39, 56]])


In [24]:
torch.randint(len(data) - block_size, (batch_size,))

tensor([ 760459,  945470,  729817,  302382,  985241,  925071,  258957,  573221,
         765570,  806973,   60416,  630653,  413812,   33929, 1016456,  150824,
         439286,  787981,  302507,  460308,  404970,  281896,  270697,  772112,
         490580,  626503,  158969,  890095, 1000695, 1043176,  392158,  329570])

In [15]:
m = BigramLanguageModel(vocab_size)
m = m.to(device)
data = data.to(device)
train_data = train_data.to(device)
val_data = val_data.to(device)

In [17]:
decode(m.generate(torch.zeros((1,1),dtype=torch.long).to(device), max_new_tokens=100)[0].tolist())

"qyr-vkdhRl\nNSCmKrS.XB!AGd:hy'EEjtYvzVQSoB't-KyugWGVG.Y,Z!'TwvxXCK:SYQlYKACfqG!Zw-sl;dUKgihYalp:nU$CJV"

In [18]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [20]:
batch_size=32
for steps in range(10000):
    xb, yb = get_batch('train')
    logits, loss = m(xb,yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss.item())

2.462007761001587


In [21]:
decode(m.generate(torch.zeros((1,1),dtype=torch.long).to(device), max_new_tokens=100)[0].tolist())

'qude pe.\nBERO:\n\nWr.\nGomefouchas frd a-shendour w,\nSThero temyou nild.\nINERCIZAld deasou y be gewack, '