In [2]:
import torch
from torch import nn
from torch.nn import functional as F

In [82]:
# batch size / parallel random contexts
batch_size = 32

# block size / length of context
block_size = 8

max_iter = 10000
eval_iter = 200

eval_interval = 300
lr = 1e-3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 

In [5]:
torch.manual_seed(22)

<torch._C.Generator at 0x203572efdb0>

In [1]:
!wget "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

--2024-12-25 13:01:37--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: 'input.txt'

     0K .......... .......... .......... .......... ..........  4% 1.68M 1s
    50K .......... .......... .......... .......... ..........  9% 9.96M 0s
   100K .......... .......... .......... .......... .......... 13% 2.76M 0s
   150K .......... .......... .......... .......... .......... 18% 2.56M 0s
   200K .......... .......... .......... .......... .......... 22% 2.79M 0s
   250K .......... .......... .......... .......... .......... 27% 2.13M 0s
   300K .......... .......... .......... .......... .......... 32% 3.91M 0s
   350K .......... ..

In [6]:
with open("input.txt", 'r') as f:
    text = f.read()


In [7]:
text[:500]

"First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you know Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us kill him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be done: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor"

In [16]:
# take out unique characters and make encoder and decoder functions
vocab = set(text)
decoded = {i : x for i,x in enumerate(vocab)}
encoded = {x : i for i,x in enumerate(vocab)}

encoder = lambda text : [encoded[x] for x in text]
decoder = lambda text : ''.join([decoded[x] for x in text])
print(vocab, "\n", encoded, "\n", decoded)

{'i', '?', 'd', 'X', 'I', ' ', 'u', 'r', 'h', 'W', 'N', 'M', 's', 'm', 'k', 'b', 'j', '&', 'R', 'x', 'A', 'Y', 'z', '!', "'", 'q', '3', 'C', 'B', 'w', 'y', 'E', '$', 't', 'o', ',', 'p', 'P', 'G', 'g', 'K', '-', 'v', '\n', 'Z', ';', 'f', ':', 'n', '.', 'S', 'O', 'T', 'Q', 'F', 'V', 'H', 'L', 'D', 'c', 'U', 'l', 'e', 'J', 'a'} 
 {'i': 0, '?': 1, 'd': 2, 'X': 3, 'I': 4, ' ': 5, 'u': 6, 'r': 7, 'h': 8, 'W': 9, 'N': 10, 'M': 11, 's': 12, 'm': 13, 'k': 14, 'b': 15, 'j': 16, '&': 17, 'R': 18, 'x': 19, 'A': 20, 'Y': 21, 'z': 22, '!': 23, "'": 24, 'q': 25, '3': 26, 'C': 27, 'B': 28, 'w': 29, 'y': 30, 'E': 31, '$': 32, 't': 33, 'o': 34, ',': 35, 'p': 36, 'P': 37, 'G': 38, 'g': 39, 'K': 40, '-': 41, 'v': 42, '\n': 43, 'Z': 44, ';': 45, 'f': 46, ':': 47, 'n': 48, '.': 49, 'S': 50, 'O': 51, 'T': 52, 'Q': 53, 'F': 54, 'V': 55, 'H': 56, 'L': 57, 'D': 58, 'c': 59, 'U': 60, 'l': 61, 'e': 62, 'J': 63, 'a': 64} 
 {0: 'i', 1: '?', 2: 'd', 3: 'X', 4: 'I', 5: ' ', 6: 'u', 7: 'r', 8: 'h', 9: 'W', 10: 'N', 11

In [18]:
encoder("hello")

[8, 62, 61, 61, 34]

In [19]:
decoder([8, 62, 61, 61, 34])

'hello'

In [21]:
vocab_size = len(vocab)
vocab_size

65

In [27]:
data = torch.tensor(encoder(text))
n = int(len(data) * 0.85)

trainX = data[:n]
testX = data[n:]

In [26]:
len(data), data

(1115394, tensor([54,  0,  7,  ..., 39, 49, 43]))

In [30]:
trainX.shape, testX.shape

(torch.Size([948084]), torch.Size([167310]))

In [71]:
def get_batch(x):
    data = trainX if x == 'train' else testX
    # get random contexts starting index for given batch size parallel processing
    idx = torch.randint(len(data) - block_size, (batch_size,))

    rcx = torch.stack([data[i: i + block_size] for i in idx])
    rcy = torch.stack([data[i+1 : i + block_size + 1] for i in idx])
    
    return rcx, rcy

In [72]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'test']:
        losses = torch.zeros(eval_iter)
        for x in range(eval_iter):
            X, y = get_batch(split)
            logits, loss = model(X, y)
            losses[x] = loss.item()
            
        out[split] = losses.mean()
    model.train()
    return out

In [77]:
class BigramLanguageModel(nn.Module):
    
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx, targets = None):
        logits = self.token_embedding_table(idx) # B,T,C
        # for each characters 65 features are created
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
            
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            
            logits, loss = self(idx)
            logits = logits[:, -1, :]
            
            probs = F.softmax(logits, dim = -1)
            next_id = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, next_id), dim = 1)
        
        return idx

In [78]:
model = BigramLanguageModel(vocab_size)
m = model.to(device)

In [83]:
optimizer = torch.optim.AdamW(model.parameters(), lr = lr)

for iters in range(max_iter):
    if iters % 100 == 0:
        losses = estimate_loss()
        print(f"step {iters}: train loss {losses['train']:.4f}, val loss {losses['test']:.4f}")
        
    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()
    
context = torch.zeros((1, 1), dtype = torch.long, device = device) 
print(decoder(m.generate(context, max_new_tokens=500)[0].tolist()))

step 0: train loss 3.3696, val loss 3.3881
step 100: train loss 3.3026, val loss 3.3151
step 200: train loss 3.2439, val loss 3.2550
step 300: train loss 3.1871, val loss 3.1900
step 400: train loss 3.1348, val loss 3.1481
step 500: train loss 3.0872, val loss 3.1038
step 600: train loss 3.0343, val loss 3.0567
step 700: train loss 3.0045, val loss 3.0148
step 800: train loss 2.9615, val loss 2.9797
step 900: train loss 2.9219, val loss 2.9527
step 1000: train loss 2.8902, val loss 2.9060
step 1100: train loss 2.8601, val loss 2.8789
step 1200: train loss 2.8317, val loss 2.8529
step 1300: train loss 2.8006, val loss 2.8333
step 1400: train loss 2.7860, val loss 2.8076
step 1500: train loss 2.7626, val loss 2.7720
step 1600: train loss 2.7476, val loss 2.7501
step 1700: train loss 2.7220, val loss 2.7508
step 1800: train loss 2.7126, val loss 2.7175
step 1900: train loss 2.6859, val loss 2.7053
step 2000: train loss 2.6760, val loss 2.6890
step 2100: train loss 2.6631, val loss 2.6913


In [84]:
print(decoder(m.generate(context, max_new_tokens=1500)[0].tolist()))

is winom kelorot aist thelis my kisthenig ABefane they hakeanee othitizf dy besan tous.
KI th:
ER:
INVI s, gestrno,

QCHARGl e,
Or ton o,
OMorut youtlaisstheastug aive.
N me surmoid t be seis hath he d Whactagish brantor, ba my pee.
Watry ind ke thith t G BRGlke blu P, stof her ave it IVOHYowisthichare kend hero
Gl pry sosio the.
s tw p w qulaidief t st, we,

Aneerig lld her mby,
N HELEE wn'er, bugecck:
Nove thoke
That nt foreichyovell a thituinhenthe,
I:
Ro his
DWheigomelomy chitukild thasen nshat.
INTh?
A ow.

Myorth w wheda priams. g,
Whirthers re toorss d tane ttou-
Hers n.
Forinof tem. re ppevit
LAgom ton, ve heprd fam tind VICE:
'd, d whe gady ca, anshar witrcok ed cheave B th stestrine s nt veme.

NGirksole, he hepelld n
THamil heathe nlotr?
Be birarowhenoorumenthed hit ps uo; the cus,
Cor'r hatherdin' Sol mad, moth wn surollix--llis,
I izad pealams wetey tusefeofuce wonobly, is
ORE bllsmine smyseve,
I tilon ulon tace su, s caly cet t' t bthe ttugr men,

toond, atharowothr.
By t