In [2]:
# !pip install torch numpy transformers datasets tiktoken wandb tqdm

In [3]:
# Loading training text data
with open("nanogpt/input.txt", "r", encoding='utf-8') as f:
    text = f.read()
print(f'Total number of characters = {len(text)}')
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f'Total number of unique characters = {vocab_size}')
print(f"Characters = {''.join(chars)}")

Total number of characters = 1115394
Total number of unique characters = 65
Characters = 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [4]:
# simple tokenizer 1
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encoder1 = lambda x: [stoi[ch] for ch in x]
decoder1 = lambda x: ''.join([itos[i] for i in x])

In [5]:
print(encoder1('hello'))
print(decoder1(encoder1('hello')))

[46, 43, 50, 50, 53]
hello


In [6]:
# !pip install sentencepiece

In [8]:
# sentencepiece tokenizer 2
import sentencepiece as spm
params = ('--input=input.txt ' '--model_prefix=spm ' '--vocab_size=1000 ')
spm.SentencePieceTrainer.Train(params)
sp = spm.SentencePieceProcessor()
sp.Load('spm.model')



True

In [22]:
print(sp.EncodeAsPieces('Hello world.'))
print(sp.EncodeAsIds('Hello world.'))
print(sp.DecodeIds([151, 88, 21, 887, 6]))

['▁He', 'll', 'o', '▁world', '.']
[184, 65, 25, 427, 7]
ter Toaaking:


In [10]:
# !pip install tiktoken

In [14]:
# tiktoken tokenizer 3
import tiktoken
enc = tiktoken.get_encoding("cl100k_base")
enc.n_vocab

100277

In [15]:
print(enc.encode("hello world"))
print(enc.decode(enc.encode("hello world")))


[15339, 1917]
hello world


In [16]:
# tokenizer 1
print(encoder1('hello'))
print(decoder1(encoder1('hello')))

# tokenizer 2
print(sp.EncodeAsIds('Hello world.'))
print(sp.DecodeIds([151, 88, 21, 887, 6]))

# tokenizer 3
print(enc.encode("hello world"))
print(enc.decode(enc.encode("hello world")))

[46, 43, 50, 50, 53]
hello
[184, 65, 25, 427, 7]
ter Toaaking:
[15339, 1917]
hello world


In [22]:
import torch
data1 = torch.tensor(encoder1(text),dtype=torch.long)
data2 = torch.tensor(sp.EncodeAsIds(text),dtype=torch.long)
data3 = torch.tensor(enc.encode(text),dtype=torch.long)
print(f'Using simple tokenizer, there are {data1.shape[0]} tokens, \n Using sentencepiece tokenizer, there are {data2.shape[0]} tokens, \n Using tiktoken tokenizer, there are {data3.shape[0]} tokens')
print(f'Simple tokenizer: {data1[:50]} \n {decoder1(data1[:50].tolist())}')
print(f'Sentencepiece tokenizer: {data2[:50]} \n {sp.DecodeIds(data2[:50].tolist())}')
print(f'Tiktoken tokenizer: {data3[:50]} \n {enc.decode(data3[:50].tolist())}')

Using simple tokenizer, there are 1115394 tokens, 
 Using sentencepiece tokenizer, there are 407344 tokens, 
 Using tiktoken tokenizer, there are 301829 tokens
Simple tokenizer: tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56]) 
 First Citizen:
Before we proceed any further, hear
Sentencepiece tokenizer: tensor([298, 537,   6, 259, 280,  12,  81, 231, 107,  35, 384, 987,   3, 223,
         33, 254,   7,  52,  65,   6,  78, 358,  21,  73,   3, 254,   7, 298,
        537,   6, 205, 106,  91, 115,   5, 201, 167,   8, 752,  14, 414, 216,
         14,  50, 190, 321,  26,  52,  65,   6]) 
 First Citizen: Before we proceed any further, hear me speak. All: Speak, speak. First Citizen: You are all resolved rather to die than to famish? All:
Tiktoken tokenizer: tensor([ 5451, 47317,   512, 10438,   584, 10570,   904,  4726,

In [28]:
# train-test split
data = data1 # choose tokenizer
encoder = encoder1
decoder = decoder1
#data = data2
#encoder = sp.EncodeAsIds
#decoder = sp.DecodeIds

#data = data3
#encoder = enc.encode
#decoder = enc.decode

n = int(0.8*len(data)) # train ratio
train_data = data[:n]
val_data = data[n:]

In [29]:
block_size = 8
# one example of input-target pair
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f'context = {context}, target = {target}')
    print(f'        = {decoder(context.tolist())},        = {decoder([target.tolist()])}')

context = tensor([18]), target = 47
        = F,        = i
context = tensor([18, 47]), target = 56
        = Fi,        = r
context = tensor([18, 47, 56]), target = 57
        = Fir,        = s
context = tensor([18, 47, 56, 57]), target = 58
        = Firs,        = t
context = tensor([18, 47, 56, 57, 58]), target = 1
        = First,        =  
context = tensor([18, 47, 56, 57, 58,  1]), target = 15
        = First ,        = C
context = tensor([18, 47, 56, 57, 58,  1, 15]), target = 47
        = First C,        = i
context = tensor([18, 47, 56, 57, 58,  1, 15, 47]), target = 58
        = First Ci,        = t


In [37]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_batch(split, batch_size=batch_size, block_size=block_size):
    if split == 'train':
        data = train_data
    else:
        data = val_data
    idx = torch.randint(0, data.size(0) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in idx])
    y = torch.stack([data[i+1:i+block_size+1] for i in idx])
    x = x.to(device)
    y = y.to(device)
    return x, y

In [31]:
xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb) # our input to the transformer
print('targets:')
print(yb.shape)
print(yb)

print('----')
# for each batch, there are 8 tokens in the input and target, but actually this contains 8 examples, corresponding to 8 different time steps, each with a target token that is shifted one position to the right
for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[58, 63,  8,  0,  0, 19, 24, 27],
        [39, 59, 45, 46, 58,  1, 46, 43],
        [49, 43, 57,  1, 53, 50, 42,  1],
        [52, 41, 47, 43, 52, 58,  1, 56]])
targets:
torch.Size([4, 8])
tensor([[63,  8,  0,  0, 19, 24, 27, 33],
        [59, 45, 46, 58,  1, 46, 43,  1],
        [43, 57,  1, 53, 50, 42,  1, 46],
        [41, 47, 43, 52, 58,  1, 56, 47]])
----
when input is [58] the target: 63
when input is [58, 63] the target: 8
when input is [58, 63, 8] the target: 0
when input is [58, 63, 8, 0] the target: 0
when input is [58, 63, 8, 0, 0] the target: 19
when input is [58, 63, 8, 0, 0, 19] the target: 24
when input is [58, 63, 8, 0, 0, 19, 24] the target: 27
when input is [58, 63, 8, 0, 0, 19, 24, 27] the target: 33
when input is [39] the target: 59
when input is [39, 59] the target: 45
when input is [39, 59, 45] the target: 46
when input is [39, 59, 45, 46] the target: 58
when input is [39, 59, 45, 46, 58] the target: 1
when input is [39, 59, 45, 

In [32]:
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):

        # idx and targets are both (B,T) tensor of integers
        # this is the embedded representation of the input, where idx is the input xb tensor
        logits = self.token_embedding_table(idx) # (B,T,C), C = vocab_size = number of classes

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # reshape due to cross-entropy
            targets = targets.view(B*T) # reshape due to cross-entropy
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens): # to generate new tokens
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens): # max_new_tokens is the maximum number of tokens to generate
            # get the predictions
            logits, _ = self(idx) # loss is none here
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities for each class
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution using the probabilities
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss) # before training

# example training sample (B, 1)
example_x = torch.zeros((batch_size, 1), dtype=torch.long)

print(decoder(m.generate(idx = example_x, max_new_tokens=100)[0].tolist()))


torch.Size([32, 65])
tensor(5.0493, grad_fn=<NllLossBackward0>)

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


In [33]:
optimizer = torch.optim.AdamW(m.parameters(), lr=0.001)

In [34]:
# dummy training
for steps in range(10000):
    xb, yb = get_batch('train', batch_size = 32, block_size = 8)

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
print(loss.item())

2.496304988861084


In [35]:
# check the generated text after dummy training
example_x = torch.zeros((batch_size, 1), dtype=torch.long)
print(decoder(m.generate(idx = example_x, max_new_tokens=100)[0].tolist()))


Iyoteng h hasbe pan hatrance
Rie hicomyonthar's
PAS:
AKI tith henouratucenonthioneir thondy, y helti


In [39]:
# bigram model training
batch_size = 32
block_size = 8

eval_interval = 300
eval_iters = 200
max_iters = 3000
learning_rate = 1e-2
torch.manual_seed(1337)

model = BigramLanguageModel(vocab_size)
m = model.to(device)
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [40]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split, batch_size, block_size)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [41]:
for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f'Iter {iter}, train loss = {losses["train"]:.4f}, val loss = {losses["val"]:.4f}')

    xb, yb = get_batch('train', batch_size, block_size)
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True) # If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient. 
    loss.backward()
    optimizer.step()

# check the generated text after training
example_x = torch.zeros((1, 1), dtype=torch.long, device = device)
print(decoder(m.generate(idx = example_x, max_new_tokens=100)[0].tolist()))

Iter 0, train loss = 4.7243, val loss = 4.7238
Iter 300, train loss = 2.8116, val loss = 2.8319
Iter 600, train loss = 2.5405, val loss = 2.5718
Iter 900, train loss = 2.4952, val loss = 2.5389
Iter 1200, train loss = 2.4720, val loss = 2.5148
Iter 1500, train loss = 2.4692, val loss = 2.5065
Iter 1800, train loss = 2.4680, val loss = 2.5060
Iter 2100, train loss = 2.4644, val loss = 2.5072
Iter 2400, train loss = 2.4501, val loss = 2.5112
Iter 2700, train loss = 2.4540, val loss = 2.5090

od nos CAy go ghanoray t, co haringoudre h lethe k,LARof fr werar,
Is fa!


Thilemeincou. p mboomyor


In [63]:
# Self-attention
# the matrix trick
# we want to aggregate the information in the context seen so far: e.g. for each example with 8 tokens, we want to compute aggr(1st token), aggr(1st, 2nd token), aggr(1st, 2nd, 3rd token), etc.

# 1. the manual way
torch.manual_seed(1337)
B,T,C = 4,8,2
x = torch.rand(B,T,C)
# we want to use average function as aggr()
xbow1 = torch.zeros(B,T,C)
xbow2 = torch.zeros(B,T,C)
xbow3 = torch.zeros(B,T,C)
for b in range(B):
    for t in range(T):
        context = x[b,:t+1] # (T,C)
        xbow1[b,t] = torch.mean(context,0) # (C,)
print(f'x[0] = {x[0]}, \nxbow[0] = {xbow[0]}')

# 2. the matrix way
torch.manual_seed(1337)
a = torch.tril(torch.ones(3,3)) # using a smaller T for illustration, a is used as weights
a = a / torch.sum(a, 1, keepdim=True) # normalize the weights
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print(f'a = {a}, \nb = {b}, \nc = {c}')

wei = torch.tril(torch.ones(T,T)) # (T,T)
wei = wei / torch.sum(wei, 1, keepdim=True)
xbow2 = wei @ x # (T,T) @ (B,T,C) == B x (T,T) @ (T,C) = (B,T,C) [pytorch will infer the correct shape]
# show that xbow is correctly computed
print(torch.allclose(xbow1, xbow2)) # used to check if two tensors are close enough to be considered equal

# 3. the matrix way with softmax
wei = torch.tril(torch.ones(T,T)) # (T,T)
wei = wei.masked_fill(wei == 0, float('-inf')) # mask the upper triangular part
wei = F.softmax(wei, dim=1) # (T,T)
xbow3 = wei @ x
print(torch.allclose(xbow1, xbow3))

x[0] = tensor([[0.0783, 0.4956],
        [0.6231, 0.4224],
        [0.2004, 0.0287],
        [0.5851, 0.6967],
        [0.1761, 0.2595],
        [0.7086, 0.5809],
        [0.0574, 0.7669],
        [0.8778, 0.2434]]), 
xbow[0] = tensor([[0.0783, 0.4956],
        [0.3507, 0.4590],
        [0.3006, 0.3156],
        [0.3717, 0.4108],
        [0.3326, 0.3806],
        [0.3953, 0.4140],
        [0.3470, 0.4644],
        [0.4134, 0.4368]])
a = tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]]), 
b = tensor([[5., 7.],
        [2., 0.],
        [5., 3.]]), 
c = tensor([[5.0000, 7.0000],
        [3.5000, 3.5000],
        [4.0000, 3.3333]])
True
True


In [57]:
# incorporate positional encoding into the model
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size, block_size, n_embd):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.positional_embedding_table = nn.Embedding(block_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):

        # idx and targets are both (B,T) tensor of integers
        B, T = idx.shape
        pos_embd = self.positional_embedding_table(torch.arange(T, device=idx.device)) # (T,n_embd)
        # this is the embedded representation of the input, where idx is the input xb tensor
        tok_embd = self.token_embedding_table(idx) # (B,T,n_embd)
        embd = tok_embd + pos_embd # (B,T,n_embd), pytorch will broadcast pos_embd to (B,T,n_embd)
        logits = self.lm_head(embd) # (B,T,C), C = vocab_size = number of classes

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # reshape due to cross-entropy
            targets = targets.view(B*T) # reshape due to cross-entropy
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens): # to generate new tokens
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens): # max_new_tokens is the maximum number of tokens to generate
            # get the predictions
            logits, _ = self(idx) # loss is none here
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities for each class
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution using the probabilities
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


torch.Size([8, 2])

In [64]:
torch.manual_seed(1337)
B,T,C = 4,8,2
x = torch.rand(B,T,C)
# we want to use weighted average function as aggr(), weights are computed using self-attention
xbow = torch.zeros(B,T,C) # bow = bag of words

head_size = 16 # H
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B,T,H)
q = query(x) # (B,T,H)
v = value(x) # (B,T,H)

wei = q @ k.transpose(-2, -1) * head_size**0.5 # (B,T,H) @ (B,H,T) = (B,T,T), normalzied by sqrt(H) to make variance of the dot product = 1, assuming the elements are iid gaussian with mean 0 variance 1
tril = torch.tril(torch.ones(T,T)) # (T,T)
wei = wei.masked_fill(tril == 0, float('-inf')) # mask the upper triangular part, used in decoder of transformer, encoders don't have this masking of future tokens
wei = F.softmax(wei, dim=1) # (T,T)
xbow = wei @ v

# note this self-attention is applied across each batch element independently

In [65]:
class Head(nn.Module): # this is one attention head
    def __init__(self, head_size = head_size, C = C, block_size = block_size):
        super().__init__()
        self.key = nn.Linear(C, head_size, bias=False)
        self.query = nn.Linear(C, head_size, bias=False)
        self.value = nn.Linear(C, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        wei = q @ k.transpose(-2, -1) * head_size**0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        return wei @ v


In [73]:
# incorporate self-attention into the model
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size=65, head_size=16, C=32, block_size=8):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, C)
        self.positional_embedding_table = nn.Embedding(block_size, C)
        self.sa_head = Head(head_size=head_size, C = C, block_size = block_size) # self-attention head
        self.lm_head = nn.Linear(C, vocab_size) # language model head

    def forward(self, idx, targets=None):

        # idx and targets are both (B,T) tensor of integers
        B, T = idx.shape
        pos_embd = self.positional_embedding_table(torch.arange(T, device=idx.device)) # (T,n_embd)
        # this is the embedded representation of the input, where idx is the input xb tensor
        tok_embd = self.token_embedding_table(idx) # (B,T,n_embd)
        embd = tok_embd + pos_embd # (B,T,n_embd), pytorch will broadcast pos_embd to (B,T,n_embd), C = n_embd
        embd = self.sa_head(embd) + embd # (B,T,n_embd), + embd is the residual connection
        logits = self.lm_head(embd) # (B,T,C), C = vocab_size = number of classes

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C) # reshape due to cross-entropy
            targets = targets.view(B*T) # reshape due to cross-entropy
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens): # to generate new tokens
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens): # max_new_tokens is the maximum number of tokens to generate
            # crop idx to the last block_size tokens
            idx_cond = idx[:,-block_size:]
            # get the predictions
            logits, _ = self(idx_cond) # loss is none here
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities for each class
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution using the probabilities
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


In [74]:
# bigram model training
batch_size = 32
block_size = 8
head_size = 32
n_embed = 32

eval_interval = 500
eval_iters = 200
max_iters = 5000
learning_rate = 1e-3
torch.manual_seed(1337)

model = BigramLanguageModel(vocab_size=vocab_size, head_size=head_size, C=n_embed, block_size=block_size)
m = model.to(device)
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f'Iter {iter}, train loss = {losses["train"]:.4f}, val loss = {losses["val"]:.4f}')

    xb, yb = get_batch('train', batch_size, block_size)
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True) # If the user requests zero_grad(set_to_none=True) followed by a backward pass, .grads are guaranteed to be None for params that did not receive a gradient. 
    loss.backward()
    optimizer.step()

# check the generated text after training
example_x = torch.zeros((1, 1), dtype=torch.long, device = device)
print(decoder(m.generate(idx = example_x, max_new_tokens=100)[0].tolist()))

Iter 0, train loss = 4.6489, val loss = 4.6505
Iter 500, train loss = 2.7390, val loss = 2.7815
Iter 1000, train loss = 2.5380, val loss = 2.5592
Iter 1500, train loss = 2.4580, val loss = 2.4811
Iter 2000, train loss = 2.4088, val loss = 2.4340
Iter 2500, train loss = 2.3807, val loss = 2.4131
Iter 3000, train loss = 2.3586, val loss = 2.4046
Iter 3500, train loss = 2.3463, val loss = 2.3829
Iter 4000, train loss = 2.3306, val loss = 2.3628
Iter 4500, train loss = 2.3449, val loss = 2.3733

We! le isen.
Whomee ton INGCHETWs onts the dos ou maithe latlintthens the the dol ere cksy of mall b


In [None]:
# TODO
# class MultiHead(nn.Module): # this is the multi-head attention
#     def __init__(self, head_size = head_size, C = C, block_size = block_size, n_heads = 8):
#         super().__init__()
#         self.heads = nn.ModuleList([Head(head_size = head_size, C = C, block_size = block_size) for _ in range(n_heads)])
#         self.lm_head = nn.Linear(C * n_heads, C) # linear layer to combine the heads
#     def forward(self, x):
#         return self.lm_head(torch.cat([head(x) for head in self.heads], dim=-1))