In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt


--2023-10-09 06:59:50--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-10-09 06:59:51 (16.1 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [2]:
with open('input.txt', 'r') as f:
    text = f.read()

In [3]:
len(text)

1115394

In [4]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [5]:
chars = sorted(list(set(text)))
vocab_size = len(chars)

In [6]:
print(''.join(chars))
vocab_size


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


65

In [12]:
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}

In [13]:
def encode(s): return [stoi[c] for c in s]
def decode(l): return ''.join([itos[s] for s in l])

In [14]:
print(encode('Hello there. General Kenobi!!'))

[20, 43, 50, 50, 53, 1, 58, 46, 43, 56, 43, 8, 1, 19, 43, 52, 43, 56, 39, 50, 1, 23, 43, 52, 53, 40, 47, 2, 2]


In [15]:
print(decode([20, 43, 50, 50, 53, 1, 58, 46, 43, 56, 43, 8, 1, 19, 43, 52, 43, 56, 39, 50, 1, 23, 43, 52, 53, 40, 47, 2, 2]))

Hello there. General Kenobi!!


In [16]:
import torch

In [17]:
data = torch.tensor(encode(text), dtype=torch.long)

In [18]:
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [19]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [20]:
x = train_data[:block_size+1]
y = train_data[1:block_size+1]
x, y

(tensor([18, 47, 56, 57, 58,  1, 15, 47, 58]),
 tensor([47, 56, 57, 58,  1, 15, 47, 58]))

In [21]:
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print('ctx', context, 'target', target)

ctx tensor([18]) target tensor(47)
ctx tensor([18, 47]) target tensor(56)
ctx tensor([18, 47, 56]) target tensor(57)
ctx tensor([18, 47, 56, 57]) target tensor(58)
ctx tensor([18, 47, 56, 57, 58]) target tensor(1)
ctx tensor([18, 47, 56, 57, 58,  1]) target tensor(15)
ctx tensor([18, 47, 56, 57, 58,  1, 15]) target tensor(47)
ctx tensor([18, 47, 56, 57, 58,  1, 15, 47]) target tensor(58)


In [23]:
import torch.nn as nn
import torch.nn.functional as F

In [24]:
batch_size = 4
block_size = 8

In [25]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [26]:
xb, yb = get_batch('train')

In [29]:
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b][:t+1]
        target = yb[b][t]
        print(context, '-----', target)

tensor([57]) ----- tensor(1)
tensor([57,  1]) ----- tensor(61)
tensor([57,  1, 61]) ----- tensor(43)
tensor([57,  1, 61, 43]) ----- tensor(50)
tensor([57,  1, 61, 43, 50]) ----- tensor(50)
tensor([57,  1, 61, 43, 50, 50]) ----- tensor(11)
tensor([57,  1, 61, 43, 50, 50, 11]) ----- tensor(0)
tensor([57,  1, 61, 43, 50, 50, 11,  0]) ----- tensor(32)
tensor([27]) ----- tensor(10)
tensor([27, 10]) ----- tensor(0)
tensor([27, 10,  0]) ----- tensor(21)
tensor([27, 10,  0, 21]) ----- tensor(44)
tensor([27, 10,  0, 21, 44]) ----- tensor(1)
tensor([27, 10,  0, 21, 44,  1]) ----- tensor(50)
tensor([27, 10,  0, 21, 44,  1, 50]) ----- tensor(53)
tensor([27, 10,  0, 21, 44,  1, 50, 53]) ----- tensor(60)
tensor([42]) ----- tensor(10)
tensor([42, 10]) ----- tensor(0)
tensor([42, 10,  0]) ----- tensor(20)
tensor([42, 10,  0, 20]) ----- tensor(53)
tensor([42, 10,  0, 20, 53]) ----- tensor(61)
tensor([42, 10,  0, 20, 53, 61]) ----- tensor(1)
tensor([42, 10,  0, 20, 53, 61,  1]) ----- tensor(47)
tensor([

In [30]:
def get_decay_matrix(dim, gamma):
    d = torch.ones(dim)
    d = torch.tril(d) # lower triangular matrix

    for index, head in enumerate(d):
        g = gamma[index]
        for idx, x in enumerate(torch.tril(head)):
            for idy, y in enumerate(x):
                if idx >= idy:
                    head[idx][idy] = g ** (idx-idy)
    return d

In [31]:
!pip install -q einops

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [32]:
import einops
from einops import rearrange, reduce, repeat

In [54]:
class ChunkwiseRetention(nn.Module):
    def __init__(self, chunk_size, num_head, block_size):
        super().__init__()
        self.key = nn.Linear(n_embed, chunk_size*num_head, bias=False)
        self.query = nn.Linear(n_embed, chunk_size*num_head, bias=False)
        self.value = nn.Linear(n_embed, chunk_size*num_head, bias=False)
        self.gamma = 1.0-2.0**(-5-torch.arange(0, num_head))
        self.decay_mask = get_decay_matrix((num_head, block_size, block_size), self.gamma)
        self.chunk_decay = self.gamma
        self.gn = nn.GroupNorm(1, num_head)
        self.num_head = num_head
        self.chunk_size = chunk_size

    def forward(self, x, past_kv):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        k = rearrange(k, ('b t (h c) -> b h t c'), t=T, h=self.num_head, c=self.chunk_size)
        q = rearrange(q, ('b t (h c) -> b h t c'), t=T, h=self.num_head, c=self.chunk_size)
        v = rearrange(v, ('b t (h c) -> b h t c'), t=T, h=self.num_head, c=self.chunk_size)

        retention = q @ k.transpose(-1, -2)
        retention = retention * self.decay_mask
        inner_retention = retention @ v

        past_kv = repeat(past_kv, 'n q v -> B n q v', B=B)
        pb, pn, pq, pv = past_kv.shape

        padding = torch.zeros(pb, pn, pq, self.chunk_size)
        past_kv = past_kv + padding

        dm = repeat(self.decay_mask, 'h c d -> B h c d', B=B)
        pp = q @ past_kv
        cross_retention = pp.transpose(-1, -2) @ dm
        cross_retention = cross_retention.transpose(-1, -2)

        retention = inner_retention + cross_retention

        current_kv = self.gamma.view(self.num_head, 1, 1) * past_kv + (k.transpose(-1, -2) @ v)
        output = self.gn(retention.transpose(-1, -2))
        output = rearrange(output, 'b c h t -> b t (c h)')
        return output, current_kv.mean(dim=0)

In [55]:
class GatedMultiScaleRetention(nn.Module):
    def __init__(self, chunk_size, num_head, block_size):
        super().__init__()
        self.wg = nn.Linear(n_embed, n_embed, bias=False)
        self.act = nn.SiLU()
        self.y = ChunkwiseRetention(num_head=n_head, chunk_size=n_embed//n_head, block_size=block_size)
        self.wo = nn.Linear(n_embed, n_embed, bias=False)
        self.past = torch.zeros(num_head, chunk_size, chunk_size)

    def forward(self, x):
        wgx = self.wg(x)
        wgx = self.act(wgx)
        y, past = self.y(wgx, self.past)
        y = wgx * y
        return self.wo(y)

In [56]:
class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4*n_embed),
            nn.GELU(),
            nn.Linear(4*n_embed, n_embed),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

In [57]:
class Block(nn.Module):
    def __init__(self, n_embed, n_head, block_size):
        super().__init__()
        self.sa_head = GatedMultiScaleRetention(num_head=n_head, chunk_size=n_embed//n_head, block_size=block_size)
        self.ffw = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa_head(self.ln1(x))
        x = x + self.ffw(self.ln2(x))
        return x

In [58]:
class RetNet(nn.Module):
    def __init__(self, block_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head, block_size=block_size) for _ in range(n_layer)])
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        token_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T))
        x = token_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)
        if targets == None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            b, s = idx.shape
            bk = min(s, block_size)
            idx_cond = torch.cat((torch.zeros(b, block_size-bk, dtype=int), idx), dim=1)[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [68]:
batch_size = 16
block_size = 32
max_iters = 10000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_head = 4
n_layer = 4
dropout = 0.0
n_embed = 32

In [69]:
def get_batch(split, batch_size):
    data = train_data if split=='train' else val_data
    ix = torch.randint(len(data)-block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [70]:
model = RetNet(block_size=block_size)
xb, yb = get_batch('train', batch_size=batch_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
logits, loss = model(xb, yb)

In [71]:
model

RetNet(
  (token_embedding_table): Embedding(65, 32)
  (position_embedding_table): Embedding(32, 32)
  (blocks): Sequential(
    (0): Block(
      (sa_head): GatedMultiScaleRetention(
        (wg): Linear(in_features=32, out_features=32, bias=False)
        (act): SiLU()
        (y): ChunkwiseRetention(
          (key): Linear(in_features=32, out_features=32, bias=False)
          (query): Linear(in_features=32, out_features=32, bias=False)
          (value): Linear(in_features=32, out_features=32, bias=False)
          (gn): GroupNorm(1, 4, eps=1e-05, affine=True)
        )
        (wo): Linear(in_features=32, out_features=32, bias=False)
      )
      (ffw): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=32, out_features=128, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=128, out_features=32, bias=True)
          (3): Dropout(p=0.0, inplace=False)
        )
      )
      (ln1): LayerNorm((32,), eps=1e-05, elementwise_a

In [72]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, batch_size=batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [73]:
xb, yb = get_batch('train', batch_size=batch_size)
logits, loss = model(xb, yb)

In [74]:
for iter in range(max_iters):
    if iter % 100 == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"Step {iter}: Train Loss {losses['train']:.4f}, Val Loss {losses['val']:.4f}")

    xb, yb = get_batch('train', batch_size=batch_size)

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

Step 0: Train Loss 4.6205, Val Loss 4.6019
Step 100: Train Loss 2.7909, Val Loss 2.8053
Step 200: Train Loss 2.6294, Val Loss 2.6395
Step 300: Train Loss 2.5493, Val Loss 2.5615
Step 400: Train Loss 2.4626, Val Loss 2.4821
Step 500: Train Loss 2.4131, Val Loss 2.4181
Step 600: Train Loss 2.4052, Val Loss 2.4017
Step 700: Train Loss 2.3533, Val Loss 2.3663
Step 800: Train Loss 2.3231, Val Loss 2.3342
Step 900: Train Loss 2.3049, Val Loss 2.3162
Step 1000: Train Loss 2.2801, Val Loss 2.3030
Step 1100: Train Loss 2.2589, Val Loss 2.2668
Step 1200: Train Loss 2.2549, Val Loss 2.2736
Step 1300: Train Loss 2.2339, Val Loss 2.2449
Step 1400: Train Loss 2.2092, Val Loss 2.2348
Step 1500: Train Loss 2.1972, Val Loss 2.2160
Step 1600: Train Loss 2.1869, Val Loss 2.2161
Step 1700: Train Loss 2.1753, Val Loss 2.1883
Step 1800: Train Loss 2.1753, Val Loss 2.1905
Step 1900: Train Loss 2.1339, Val Loss 2.1647
Step 2000: Train Loss 2.1128, Val Loss 2.1374
Step 2100: Train Loss 2.1073, Val Loss 2.1290


In [75]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=5000)[0].tolist()))


COROLORIOLLIOLCEOLLOT:

Win mair withir rhol hin frthe feye ractor friust trann my yod ourllesihe.

IOLABR bun thile:
Fiver theas y lild.
Clin miGors?
IONTIAR:
Gor arucio netat;
My wadeas'st drich staive wey tregel why met. it dicuny an:
Manggst aptem desple bain; fag bibd ufiur condesith ariece.
TI'd BUSRY AI'CUTIL:
ClaM,l that why herdo mer A fonclld,
An he sheinw, se pronod worgel thendiod, Is the bealld:
Toull'd angrr os hal haTr theach wend whend myr' isis croucure,
Asie thak brond, nourd theis domy te he dieest:
cor deave; Maut Ca heece bomo? NCOFoutt Mam thy hised,

CAuy,l:
The pand kavy then Buris Wallon ODf Criver inTRingsts oth womyserd:
wiveng lvit dit, the rincit Ad bulibututous.
My byour alerd mese tear caing'd ond Tyords dind yey stwounghs,d erer,
Chepler id aus calsts iilt he an pof thele le, Cay isend thie st Mquth, And sethilt to
Comme:
Thif thime lourve'd thow so yourt mur histh ith te dourt
Saclly:
Bucke de ot kin youRd.
PriThich: sthan at esas birpplectt.' bano tha