In [3]:
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [4]:
# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [7]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [9]:
print("length of text:",len(text))

length of text: 1115394


In [12]:
chars = sorted(list(set(text)))
vocab_size = len(chars)

print(vocab_size)
print(''.join(chars))

65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


##### We need to convert this text into sequence of numbers so that the model can understand.
- Let's build a naive tokenizer that assigns a number in the order of vocabulary

In [16]:
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}

encode = lambda s: [stoi[char] for char in s]
decode = lambda s:''.join([itos[i] for i in s])

print(encode("hello there"))
print(decode(encode("hello there")))

[46, 43, 50, 50, 53, 1, 58, 46, 43, 56, 43]
hello there


In [18]:
#tokenize input data
tokenized_text = encode(text)
import torch
data = torch.tensor(tokenized_text,dtype=torch.long)

print(data.shape,data.dtype)

torch.Size([1115394]) torch.int64


In [19]:
#split the data into train and validation
n = int(0.9*len(text))
train = data[:n]
val = data[n:]

In [27]:
torch.manual_seed(121)
batch_size = 4
block_size = 8 #context length

#get sample batch data
def get_batch(split):
    data = train if split=="train" else val
    ix = torch.randint(len(data)-block_size,(batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    
    return x,y

xb,yb = get_batch("train")
print(xb.shape,yb.shape)
print(xb)
print(yb)
    

torch.Size([4, 8]) torch.Size([4, 8])
tensor([[ 1, 57, 53,  1, 40, 53, 50, 42],
        [ 1, 46, 47, 57,  1, 46, 39, 54],
        [58,  1, 53, 44,  0, 58, 46, 63],
        [ 1, 46, 43, 56,  1, 47, 52,  1]])
tensor([[57, 53,  1, 40, 53, 50, 42,  0],
        [46, 47, 57,  1, 46, 39, 54, 54],
        [ 1, 53, 44,  0, 58, 46, 63,  1],
        [46, 43, 56,  1, 47, 52,  1, 46]])


In [26]:
torch.randint(len(data)-8,(4,))

tensor([829749, 904018, 166308, 803508])

In [74]:
#let's build a simple bigram model
import torch
import torch.nn as nn   
from torch.nn import functional as F

class BigramLanguageModel(nn.Module):
    def __init__(self,vocab_size):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size,vocab_size)
    
    def forward(self,idx,targets=None):
        
        logits = self.token_embedding(idx)
        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            targets = targets.view(B*T)
            logits = logits.view(B*T,C)
            loss = F.cross_entropy(logits,targets)
        return logits,loss
    
    
    def generate(self, idx, max_tokens):
        # idx BxT
        
        for _ in range(max_tokens):
            logits,loss = self(idx) #logits - BxTxC
            logits = logits[:,-1,:] #take only the last prediction
            probs = F.softmax(logits,dim=1)
            
            idx_next = torch.multinomial(probs,num_samples=1)
            idx = torch.cat((idx,idx_next),dim=1)
        
        return idx            

In [83]:
m = BigramLanguageModel(vocab_size)

In [84]:
logits,loss = m(xb,yb)

In [85]:
logits.shape,loss

(torch.Size([32, 65]), tensor(4.7485, grad_fn=<NllLossBackward0>))

In [89]:
m.generate(xb[:1],max_tokens=10)[0]

tensor([ 1, 57, 53,  1, 40, 53, 50, 42, 41, 25, 49, 54,  7, 64, 48, 33, 55, 33])

In [97]:
#optimizer
optimizer = torch.optim.AdamW(m.parameters(),lr=1e-3)

batch_size = 32
for steps in range(10000):
    xb,yb = get_batch("train")
    optimizer.zero_grad(set_to_none=True)
    logits,loss = m(xb,yb)
    loss.backward()
    optimizer.step()
    
print(loss.item())
    

2.5418200492858887


In [100]:
print(decode(m.generate(torch.zeros((1,1),dtype=torch.long),max_tokens=300)[0].tolist()))


ach he ellk, larerdoriflewhen!
The ous maust yowe, r ard
T: malag sealy takimyo shive the.

RMESOFibr ouere m the, kirdomporoutrefathavanfllpe if m thado athe o blsm
HAThen;
IOPOMurar fup theanghe ithimeshero m:

NI
Sak, if pofof Be b inginteve s he asomys mmal ce wimanryord,
Sed eatcofl youn ne bed


In [115]:
# take average of all previous contexts

T = 8
x = torch.randn((T,T))
wei = torch.tril(torch.ones(T,T))
wei = wei/wei.sum(1,keepdims=True)

xbow = x@wei

In [116]:
xbow

tensor([[-0.6792,  0.3529,  0.4176,  0.2401,  0.1672, -0.0259,  0.0397, -0.1052],
        [-1.0425,  0.9396,  0.6231,  0.0290, -0.0215,  0.0198,  0.0591, -0.0123],
        [ 0.0340, -0.0114,  0.4688,  0.3063,  0.1214, -0.1330, -0.1206, -0.0070],
        [-2.0362, -0.9428, -0.8375, -0.4205, -0.5485, -0.2874, -0.0659, -0.0107],
        [ 0.6953, -0.3161, -0.3078, -0.1755, -0.0136,  0.1472,  0.3127,  0.0565],
        [ 0.7936, -0.0770, -0.5167, -0.4826, -0.0243,  0.0374,  0.0553,  0.0681],
        [-0.8949,  0.4318, -0.0997,  0.2917,  0.2596,  0.1146,  0.2080,  0.0103],
        [ 0.7465,  0.1750, -0.0407, -0.5355, -0.4628, -0.2273, -0.0476,  0.1060]])

In [110]:
wei.shape

torch.Size([8, 8])