In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
block_size = 8
batch_size = 4
max_iters = 1000
learning_rate = 3e-4
eval_iters = 250

cuda


In [2]:
with open('wizard_of_oz.txt','r', encoding='utf-8') as f:
    text  = f.read()
print(text[:200])

﻿DOROTHY AND THE WIZARD IN OZ

  BY

  L. FRANK BAUM

  AUTHOR OF THE WIZARD OF OZ, THE LAND OF OZ, OZMA OF OZ, ETC.

  ILLUSTRATED BY JOHN R. NEILL

  BOOKS OF WONDER WILLIAM MORROW & CO., INC. NEW Y


In [3]:
chars = sorted(set(text))
print(chars)
print(len(chars))
vocab_size = len(chars)

['\n', ' ', '!', '"', '&', "'", '(', ')', '*', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '\ufeff']
81


In [4]:
string_to_int = {ch:i for i, ch in enumerate(chars)}
int_to_string = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: ''.join([int_to_string[i] for i in l])

encoded_hello = encode('hello')
decoded_hello = decode(encoded_hello)
print(decoded_hello)

hello


In [5]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data[:100])

tensor([80, 28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1, 44, 32, 29,  1, 47,
        33, 50, 25, 42, 28,  1, 33, 38,  1, 39, 50,  0,  0,  1,  1, 26, 49,  0,
         0,  1,  1, 36, 11,  1, 30, 42, 25, 38, 35,  1, 26, 25, 45, 37,  0,  0,
         1,  1, 25, 45, 44, 32, 39, 42,  1, 39, 30,  1, 44, 32, 29,  1, 47, 33,
        50, 25, 42, 28,  1, 39, 30,  1, 39, 50,  9,  1, 44, 32, 29,  1, 36, 25,
        38, 28,  1, 39, 30,  1, 39, 50,  9,  1])


In [6]:
n = int(0.8*len(data))
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    #print(ix)
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x,y = x.to(device), y.to(device)
    return x,y

x,y = get_batch('train')
print('inputs:')
print(x)
print('targets:')
print(y)

inputs:
tensor([[72, 62, 57, 58,  1, 68, 59,  1],
        [62, 67, 60,  1, 62, 67,  1, 59],
        [73, 74, 67, 58,  1, 73, 61, 58],
        [ 1, 56, 68, 67, 72, 73, 54, 67]], device='cuda:0')
targets:
tensor([[62, 57, 58,  1, 68, 59,  1, 62],
        [67, 60,  1, 62, 67,  1, 59, 71],
        [74, 67, 58,  1, 73, 61, 58,  1],
        [56, 68, 67, 72, 73, 54, 67, 73]], device='cuda:0')


In [7]:

x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print('when input is', context, 'target is', target)

when input is tensor([80]) target is tensor(28)
when input is tensor([80, 28]) target is tensor(39)
when input is tensor([80, 28, 39]) target is tensor(42)
when input is tensor([80, 28, 39, 42]) target is tensor(39)
when input is tensor([80, 28, 39, 42, 39]) target is tensor(44)
when input is tensor([80, 28, 39, 42, 39, 44]) target is tensor(32)
when input is tensor([80, 28, 39, 42, 39, 44, 32]) target is tensor(49)
when input is tensor([80, 28, 39, 42, 39, 44, 32, 49]) target is tensor(1)


In [13]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X,Y = get_batch(split)
            logits, loss = model(X,Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
        

In [21]:
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, index, targets=None):
        logits = self.token_embedding_table(index)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, index, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self.forward(index)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            index_next = torch.multinomial(probs, num_samples=1)
            index = torch.cat((index, index_next), dim=1)
        return index
model = BigramLanguageModel(vocab_size)
m = model.to(device)
context = torch.zeros((1,1), dtype=torch.long, device=device)
generated_chars=decode(m.generate(context, max_new_tokens=500)[0].tolist())
print(generated_chars)
        


Sw)5jtd
kymio.;RIOg!jHKXvhGn)sYn"38iA;2[970K:s"KzD h]-ZH"kLn_iHf5i﻿82O
",m2uEdnIpdqEc1(r071LXDT3wN﻿-.;35!d
GtMQsO
PHV]pyu'zDDJS6z5kx*pd*qE[VD["g['kZV9[bhxqg]
:
EVD"Mky"QlKPE*?)!Feltou0[D"CYLzHhG5FtZ
&oR&﻿A3y(vL_fzH﻿bLi,?'hz?a W4Bx0sT,W4826Mxx O
ZI7*Zqf862LeN0vV?hb)2LOI!a?'Qp5oV5kzjxi69VygK:K8Xg.,JEDKzJ;_f[:_);0A5Sf6);DN]V-4i"gFxZ4G0j_i-k"'kEe4k7F﻿zLYZSzhFN0giHRqc73bh416bFLk(;Iy]y*6(Nd-﻿b6!'hfverhQHwc.ODl! *L:ge7w'A.rvonV["WW&7-LO﻿P:_hBx]?N(:bFmLK:_IQmwA]Vlb6K:PRJG':?)96X)FE1y]CC45S4GtHFpEULr4Wu*


In [48]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
for iter in range(max_iters):
    if iter % eval_iters == 0:
        losses = estimate_loss()
        print(f'step: {iter}, train loss: {losses['train']:.3f}, val loss : {losses['train']:.3f}')
    xb, yb = get_batch('train')
    logits, loss = model.forward(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
print(loss.item())

step: 0, train loss: 2.552, val loss : 2.552
step: 250, train loss: 2.543, val loss : 2.543
step: 500, train loss: 2.564, val loss : 2.564
step: 750, train loss: 2.550, val loss : 2.550
2.575653314590454


In [49]:
context = torch.zeros((1,1), dtype=torch.long, device=device)
generated_chars=decode(m.generate(context, max_new_tokens=500)[0].tolist())
print(generated_chars)


sliond ther."]G2xI',wa Wou


wr wite omserthe
ekI'than;
ZAR, sirina

 " segro oterindr hewaie hired
lk?"_gaury W0(Vat Dmefil'8Q0s T&6y5had hawerery, heasen Redotan.
wa jDoue whre imediveblllacoR(?tore cu coithoupid t P's f t s. ghayO"Tomend y ly'm47

" ourd hfledin s helimpurond, warucomin Ww" a, unt arswintht I fe um e
hupanng the coto and ing,!pQ9Bu6Q;.
TRkenljqgy
it t t o
" wheil t l, ot-limf chather. teskned f G2uiQ2LIn l tf8SUsin sngor ine prtormaye a war t tong the TH4SBnd toury umalesthe 
