In [1]:
import torch
import requests
from torch import nn
import os
# Report whether PyTorch can see a CUDA device; the value is shown as this cell's output.
torch.cuda.is_available()

True

In [2]:
# Local copy of the corpus; it is downloaded only when the file is missing.
input_file_path = os.path.join('input.txt')

if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    # Fetch first and validate before touching the disk: without the status
    # check an HTTP error page would be cached as the corpus, and without a
    # timeout a stalled connection would hang the cell forever.
    download = requests.get(data_url, timeout=30)
    download.raise_for_status()
    # Explicit encoding so the cached file does not depend on the platform default.
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(download.text)

# Read back with the same encoding that was used for writing.
with open(input_file_path, 'r', encoding='utf-8') as f:
    text = f.read()

# Corpus size in characters.
print(len(text))

# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))

# Lookup tables between a character and its integer id (its position in `chars`).
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}


def encode(s):
    """Return the list of integer ids for the characters of string `s`."""
    return [stoi[ch] for ch in s]


def decode(l):
    """Return the string spelled by the sequence of integer ids `l`."""
    return ''.join(itos[idx] for idx in l)


# Encoded form of the first 1000 characters as an int64 tensor.
tensor = torch.tensor(encode(text[:1000]), dtype=torch.long)




1115394


In [3]:
# 90% of the corpus is used for training, the remainder for validation.
n = 0.9 * len(text)
split_idx = int(n)

# Encode the whole corpus once and slice the result; encoding it separately
# for each split would walk the full text twice.
encoded_text = encode(text)
train_data = encoded_text[:split_idx]
val_data = encoded_text[split_idx:]
del encoded_text  # the two slices are independent copies

print(f"train data length: {len(train_data)}, \nval data length: {len(val_data)}")

# Number of characters in one context window.
seq_length = 8

# One context window plus the character that follows it.
train_data[:seq_length + 1]

train data length: 1003854, 
val data length: 111540


[18, 47, 56, 57, 58, 1, 15, 47, 58]

In [4]:

# Default number of sequences drawn per batch.
batch_size = 8


def get_batch(split, seq_length, num_sequences=None):
    """Draw a random batch of (context, response) windows from one data split.

    Args:
        split: 'train' selects `train_data`; any other value selects `val_data`.
        seq_length: number of tokens in each window.
        num_sequences: how many windows to draw; defaults to the module-level
            `batch_size` when omitted.

    Returns:
        A pair of int64 tensors of shape (num_sequences, seq_length). The
        response is the context shifted one position to the right, i.e.
        response[b, t] is the token that follows context[b, t] in the data.
    """
    data = train_data if split == 'train' else val_data
    count = batch_size if num_sequences is None else num_sequences

    # The largest start offset must leave room for seq_length tokens plus the
    # one extra token needed by the shifted response window.
    starts = torch.randint(len(data) - seq_length, (count,)).tolist()

    # Plain-int offsets keep the slicing independent of whether `data` is a
    # Python list or a tensor; as_tensor accepts both.
    context_tensor = torch.stack(
        [torch.as_tensor(data[s:s + seq_length], dtype=torch.long) for s in starts]
    )
    response_tensor = torch.stack(
        [torch.as_tensor(data[s + 1:s + seq_length + 1], dtype=torch.long) for s in starts]
    )
    return context_tensor, response_tensor

# Draw one batch with the notebook-wide window length so the loop bound below
# always matches the second dimension of the tensors.
context, response = get_batch('train', seq_length)
print(context.shape, response.shape)
print(context[0], response[0])

# Walk through the first sequence of the batch only: every prefix of the
# context is one training example whose target is the next character.
first_context, first_response = context[0], response[0]
for t in range(seq_length):
    # Slicing excludes the stop index, hence t + 1 to include position t.
    print(f"Context {first_context[:t + 1].tolist()}, Target {first_response[t].tolist()}")

torch.Size([8, 8]) torch.Size([8, 8])
tensor([ 1, 61, 47, 58, 46,  1, 46, 47]) tensor([61, 47, 58, 46,  1, 46, 47, 57])
Context [1], Target 61
Context [1, 61], Target 47
Context [1, 61, 47], Target 58
Context [1, 61, 47, 58], Target 46
Context [1, 61, 47, 58, 46], Target 1
Context [1, 61, 47, 58, 46, 1], Target 46
Context [1, 61, 47, 58, 46, 1, 46], Target 47
Context [1, 61, 47, 58, 46, 1, 46, 47], Target 57


In [17]:
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
    """Bigram model: the logits for the next token depend only on the current token.

    The single parameter is a (vocab_size, vocab_size) embedding table; the row
    looked up for a token id is used directly as the logits over the next token.
    """

    def __init__(self, vocab_size):
        """Create the lookup table for a vocabulary of `vocab_size` tokens."""
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, vocab_size)

    def forward(self, context, response=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            context: integer tensor of token ids, shape (B, T). Every id must
                be a valid row index of the embedding table (< vocab_size).
            response: optional integer tensor of target ids with B*T elements.

        Returns:
            `logits` of shape (B, T, C) when `response` is None; otherwise the
            tuple `(logits, loss)` where `logits` has been flattened to
            (B*T, C) and `loss` is the cross-entropy against `response`.
        """
        logits = self.embedding(context)
        if response is None:
            return logits

        # cross_entropy expects (N, C) logits and (N,) targets, so the batch
        # and time dimensions are merged.
        B, T, C = logits.shape
        logits = logits.reshape(B * T, C)
        response = response.reshape(B * T)
        loss = F.cross_entropy(logits, response)
        return logits, loss

    @torch.no_grad()  # sampling needs no gradients; avoids building a graph per step
    def generate(self, context, max_new_tokens):
        """Append `max_new_tokens` sampled tokens to each sequence in `context`.

        Args:
            context: integer tensor of token ids, shape (B, T).
            max_new_tokens: number of tokens to sample and append.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # Only the logits at the last time step are used for the next token.
            logits = self(context)[:, -1, :]
            # Turn the vocab_size-dimensional logits into a probability distribution.
            probs = F.softmax(logits, dim=-1)
            # Sample the next token id from that distribution.
            next_token = torch.multinomial(probs, num_samples=1)
            context = torch.cat((context, next_token), dim=1)
        return context
    
# Build the model with one embedding row per entry of the stoi lookup table.
bigram = BigramLanguageModel(len(stoi))
# Optional sanity check, left disabled: loss of the freshly initialised model on one batch.
# logits, loss = bigram(context, response)
# print(loss)

In [26]:
# Sample 100 characters from the model, starting from a single token with id 0.
context = torch.zeros((1, 1), dtype=torch.long)
sampled_ids = bigram.generate(context, max_new_tokens=100)[0].tolist()
print(decode(sampled_ids))


gb::LBQYq?CJesxt
yL3
jv,
'
-AlOmW LNrxKbrnxgjhF$Lmek'nw-zeLZSRwPPKZ;iqq 3-&tEQ'SPiR L;fAgN'StSQUFq'



In [31]:
# Number of optimisation steps and the optimiser over all model parameters.
steps = 10000
optim = torch.optim.AdamW(bigram.parameters(), lr=1e-3)

for i in range(steps):
    # Fresh random mini-batch; uses the notebook-wide window length rather
    # than repeating its value as a literal.
    context, response = get_batch('train', seq_length)

    # Forward pass: logits and cross-entropy loss for this batch.
    logits, loss = bigram(context, response)

    # Clear gradients left over from the previous step, then backpropagate.
    optim.zero_grad(set_to_none=True)
    loss.backward()

    # Apply the parameter update.
    optim.step()

    # Log the loss of the current mini-batch every 500 steps; a single batch
    # is a noisy estimate, so consecutive values fluctuate.
    if i % 500 == 0:
        print(f"loss {loss.item()}")



loss 2.3460681438446045


loss 2.582165479660034
loss 2.3618199825286865
loss 2.372929573059082
loss 2.5345356464385986
loss 2.385035753250122
loss 2.312197208404541
loss 2.685347080230713
loss 2.241530656814575
loss 2.4689106941223145
loss 2.588299512863159
loss 2.5005948543548584
loss 2.6923959255218506
loss 2.3396263122558594
loss 2.3891985416412354
loss 2.3035621643066406
loss 2.4782638549804688
loss 2.5053746700286865
loss 2.5575125217437744
loss 2.4796996116638184


In [32]:
# Sample 100 characters again, starting from a single token with id 0, to inspect the model's current output.
context = torch.zeros((1, 1), dtype=torch.long)
generated = bigram.generate(context, max_new_tokens=100)
print(decode(generated[0].tolist()))


D ond so worowes?
To tes brbjLayouchodive melis isorome 'le all hifowen st hath berico athou RKI nou
