In [2]:
import os
import re
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(69)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [2]:
with open('data/kon.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print("length of dataset in characters: ", len(text))

length of dataset in characters:  493412


In [3]:
print(text[:500])

Ui:
Sis, come on. You'd better get out of bed. Sis?

Yui:
Ah! I-it's eight! I'm late! Oh!

Ui:
Hey, why the rush? Hm?

Yui:
See you later!

Lady:
Oh, good morning, Yui.

Yui:
Good morning!

Yui:
What?! I read the clock wrong!
Starting today, I'm a high schooler!

Opening Song
Cagayake!GIRLS by 放課後ティータイム(After School Tea Time)

Girls:
Congratulations on starting school here!

Girl 1:
Please join the Tennis Club!

Girl 2:
The Judo Club's better!

Girl 3:
Please join the Tea Ceremony Club!

Girl 4:


In [4]:
# remove japanese characters
text = ''.join(filter(lambda character:ord(character) < 0x3000, text))

In [5]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("unique characters:", vocab_size, ''.join(chars))

unique characters: 93 
 !"#$%&'(),-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]abcdefghijklmnopqrstuvwxyz{|}~°éū‘’…♪


In [6]:
# Very simple tokenizer
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}
# add special token for padding
stoi[''] = len(stoi)
itos[len(itos)] = ''
print(stoi)
print(itos)
encode = lambda s: [stoi[ch] for ch in s]
decode = lambda l: ''.join([itos[i] for i in l])
print("encoded:", encode(text[:20]))
print("decoded:", decode(encode(text[:20])))
vocab_size = len(itos)
print("vocab size:", vocab_size)

{'\n': 0, ' ': 1, '!': 2, '"': 3, '#': 4, '$': 5, '%': 6, '&': 7, "'": 8, '(': 9, ')': 10, ',': 11, '-': 12, '.': 13, '/': 14, '0': 15, '1': 16, '2': 17, '3': 18, '4': 19, '5': 20, '6': 21, '7': 22, '8': 23, '9': 24, ':': 25, ';': 26, '?': 27, 'A': 28, 'B': 29, 'C': 30, 'D': 31, 'E': 32, 'F': 33, 'G': 34, 'H': 35, 'I': 36, 'J': 37, 'K': 38, 'L': 39, 'M': 40, 'N': 41, 'O': 42, 'P': 43, 'Q': 44, 'R': 45, 'S': 46, 'T': 47, 'U': 48, 'V': 49, 'W': 50, 'X': 51, 'Y': 52, 'Z': 53, '[': 54, ']': 55, 'a': 56, 'b': 57, 'c': 58, 'd': 59, 'e': 60, 'f': 61, 'g': 62, 'h': 63, 'i': 64, 'j': 65, 'k': 66, 'l': 67, 'm': 68, 'n': 69, 'o': 70, 'p': 71, 'q': 72, 'r': 73, 's': 74, 't': 75, 'u': 76, 'v': 77, 'w': 78, 'x': 79, 'y': 80, 'z': 81, '{': 82, '|': 83, '}': 84, '~': 85, '°': 86, 'é': 87, 'ū': 88, '‘': 89, '’': 90, '…': 91, '♪': 92, '': 93}
{0: '\n', 1: ' ', 2: '!', 3: '"', 4: '#', 5: '$', 6: '%', 7: '&', 8: "'", 9: '(', 10: ')', 11: ',', 12: '-', 13: '.', 14: '/', 15: '0', 16: '1', 17: '2', 18: '3', 

In [7]:
data = torch.tensor(encode(text), dtype=torch.int64)
data.to(device)
data.shape

torch.Size([493171])

In [8]:
data[:100]

tensor([48, 64, 25,  0, 46, 64, 74, 11,  1, 58, 70, 68, 60,  1, 70, 69, 13,  1,
        52, 70, 76,  8, 59,  1, 57, 60, 75, 75, 60, 73,  1, 62, 60, 75,  1, 70,
        76, 75,  1, 70, 61,  1, 57, 60, 59, 13,  1, 46, 64, 74, 27,  0,  0, 52,
        76, 64, 25,  0, 28, 63,  2,  1, 36, 12, 64, 75,  8, 74,  1, 60, 64, 62,
        63, 75,  2,  1, 36,  8, 68,  1, 67, 56, 75, 60,  2,  1, 42, 63,  2,  0,
         0, 48, 64, 25,  0, 35, 60, 80, 11,  1])

In [9]:
n = int(len(data) * 0.95)
train_data = data[:n]
val_data = data[n:]
print(train_data.shape, val_data.shape)

torch.Size([468512]) torch.Size([24659])


In [10]:
block_size = 8
train_data[:block_size+1]

tensor([48, 64, 25,  0, 46, 64, 74, 11,  1])

In [11]:
# context and target simulation
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1].tolist()
    target = y[t].item()
    print('context:', context, 'target:', target)

context: [48] target: 64
context: [48, 64] target: 25
context: [48, 64, 25] target: 0
context: [48, 64, 25, 0] target: 46
context: [48, 64, 25, 0, 46] target: 64
context: [48, 64, 25, 0, 46, 64] target: 74
context: [48, 64, 25, 0, 46, 64, 74] target: 11
context: [48, 64, 25, 0, 46, 64, 74, 11] target: 1


In [12]:
torch.manual_seed(69)
batch_size = 4 # number of parallel blocks
block_size = 8 # number of characters in each block = context length

def get_batch(split, block_size):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

xb, yb = get_batch('train', 128)
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

inputs:
torch.Size([4, 128])
tensor([[66, 70, 25,  0, 40, 64, 74, 74,  1, 52, 56, 68, 56, 69, 56, 66, 56,  1,
         64, 74,  1, 73, 60, 56, 67, 67, 80,  1, 62, 73, 60, 56, 75,  2,  0,  0,
         41, 70, 57, 76, 80, 70, 25,  0, 36, 75,  8, 74,  1, 67, 64, 66, 60,  1,
         74, 63, 60,  8, 74,  1, 62, 60, 75, 75, 64, 69, 62,  1, 71, 73, 60, 75,
         75, 64, 60, 73,  1, 56, 69, 59,  1, 71, 73, 60, 75, 75, 64, 60, 73,  1,
         60, 77, 60, 73, 80,  1, 59, 56, 80,  2,  0,  0, 46, 56, 78, 56, 66, 70,
         25,  0, 54, 62, 64, 62, 62, 67, 60, 74, 55,  0,  0, 40, 64, 70, 25,  0,
          9, 36],
        [70, 75,  0,  0, 40, 64, 70, 25,  0, 39, 60, 75,  8, 74,  1, 74, 60, 60,
         13, 13, 13,  0,  0, 45, 64, 75, 74, 76, 25,  0, 48, 63, 11,  1, 64, 75,
          8, 74,  1, 69, 70, 75, 63, 64, 69, 62,  2,  0, 30, 63, 60, 58, 66,  1,
         64, 75,  2,  0, 40, 70, 76, 74, 75, 56, 58, 63, 60,  2,  0,  0, 47, 74,
         76, 68, 76, 62, 64, 25,  0, 36,  1, 63, 56, 77, 60,  

In [13]:
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # Bigram language model: single layer, single token prediction
        # idx and targets are both (B,T) tensor of integers
        logits = self.token_embedding_table(idx)  # (B,T,C), B: batch=4, T: sequence=8, C: vocab=147

        if targets is None:
            loss = None
        else:
            # flatten the logits and targets for torch cross entropy
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # generate max_new_tokens new tokens given the initial context idx
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


model = BigramLanguageModel(vocab_size)
m = model.to(device)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

# randomly generate 100 tokens from initial model weights and idx = 0 = \n
print(decode(m.generate(idx=torch.zeros(
    (1, 1), dtype=torch.long, device=device), max_new_tokens=100)[0].tolist()))

torch.Size([512, 94])
tensor(4.8468, device='cuda:0', grad_fn=<NllLossBackward0>)

T!Ye?}Mrj~Bqpe’T.j3|KfdM-TiT]1kééé"RrnbU)]UGi
n]1PsnI'V%KL???p$:;’z/777mūQVwgk[bzh9i?}a 9tkM6d°5w


In [14]:
# training!
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(5000):  # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train', 16)

    # evaluate the loss
    logits, loss = m(xb, yb)

    # backprop
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())


2.4686262607574463


In [15]:
# generate 100 tokens from trained model weights and idx = 0 = \n
print(
    decode(
        m.generate(idx=torch.zeros((1, 1), dtype=torch.long, device=device),
                   max_new_tokens=300)[0].tolist()))



Mitreme whet mel Yu~!
Muk iteaw, s pQDoothin rei:rs to. I t fte al9V#Mal heaselouin$ooris], t't'Do%Qd a:
Rint[o w?

Yui-cka aron $9WSariou:
Sume!

Weriney se wd, thep g yonand oukse a:-sFiEmed. ritan?


Whith, s th, lig-D y or yser Sok ryombué[cu?
omeF%RAnybjugui#88ury, thani1’'vk alaSawrk!|{ūQ°|arm


Lets try out lower dimensional embeddings + positional embeddings

In [16]:
class BigramEmbedLanguageModel(nn.Module):

    def __init__(self, vocab_size, block_size, embed_size):
        super().__init__()
        # embed raw tokens to a lower dimensional embedding with embed_size
        self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
        # embed block sized context length as positional embeddings of the same size
        self.position_embedding_table = nn.Embedding(block_size, embed_size)
        # Language Modelling (?) Head is a standard linear layer to go from 
        # embeddings back to logits of vocab_size
        self.lm_head = nn.Linear(embed_size, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensor of integers
        tok_embd = self.token_embedding_table(idx) # (B,T,C)
        pos_embd = self.position_embedding_table(torch.arange(T, device=device)) # (T,C) [0...T-1]
        x = tok_embd + pos_embd
        logits = self.lm_head(x) # (B,T,C)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #crop idx to the last block_size tokens
            idx_context = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_context)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [17]:
# training!
model = BigramEmbedLanguageModel(vocab_size, 16, 32)
m = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(5000):  # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train', 16)

    # evaluate the loss
    logits, loss = m(xb, yb)

    # backprop
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

2.347278594970703


In [18]:
# generate
print(
    decode(
        m.generate(idx=torch.zeros((1, 1), dtype=torch.long, device=device),
                   max_new_tokens=300)[0].tolist()))





Righe jacuindwereyn. wir mamand!
We.

He yoo:
Rig! w! t h thay ck uinp harereve'teo o:

Mi:
Ri:
Rithay Hor mer sjus to Tha:
Fe gof amese.
Ri:
Risherpllurs Le n'ty s whang ru Y w meng!
He rivemuheveatst caf won.
Alom?
I t!

Mid f ithe ate.

Sheaugieng Yumuid Shere ayouh ngoit!
Whead win on:
Mund g 


# Now for the Transformer Fundamentals!

## The mathematical trick to self-attention: triangular matrices for weighted averages

In [19]:
# consider the following toy example:

torch.manual_seed(1337)
B,T,C = 2,8,2 # batch, time, channels
x = torch.randn(B,T,C)
print(x.shape)
print(x)

torch.Size([2, 8, 2])
tensor([[[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]],

        [[ 1.3488, -0.1396],
         [ 0.2858,  0.9651],
         [-2.0371,  0.4931],
         [ 1.4870,  0.5910],
         [ 0.1260, -1.5627],
         [-1.1601, -0.3348],
         [ 0.4478, -0.8016],
         [ 1.5236,  2.5086]]])


In [20]:
# We want x[b,t] = mean_{i<=t} x[b,i] to very badly encode info of tokens before token t
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev, 0)
xbow

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]]])

In [21]:
# better way to do this: triangular matrix!
# version 2: using matrix multiply for a weighted aggregation
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
print(wei)
xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)
xbow2

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]]])

In [22]:
# even better: softmax for normalization of weights
# version 3: use Softmax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
xbow3

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]]])

In [23]:
# finally: Query (what i look for), Key (What i am in this), 
# Value (My private value, embedded) for self-attention
# version 4: single head self-attention
head_size = 16
query = nn.Linear(C, head_size, bias=False) # Linear layer C (embed) -> head size (16)
key = nn.Linear(C, head_size, bias=False) # Linear layer C (embed) -> head size (16)
value = nn.Linear(C, head_size, bias=False) # Linear layer C (embed) -> head size (16)

q = query(x) # (B,T,16)
k = key(x) # (B,T,16)
wei = q @ k.transpose(-2, -1) # (B,T,16) @ (B,16,T) -> (B,T,T)
# print(wei.shape)
# print(torch.round(torch.sum(wei, dim=1), decimals=3))
# row_sum = wei.sum(dim=1)

# # Compute the average sum
# avg_sum = row_sum.mean()

# # Filter out rows with sum lower than the average sum
# wei = wei[:, row_sum >= avg_sum]
# print(wei.shape)

wei = wei * C**-0.5 # scaled attention as to not sharpen softmax

# T = wei.shape[1]
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
print(torch.round(wei, decimals=3))

v = value(x)
out = wei @ v
print(torch.round(out, decimals=3))

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4700, 0.5300, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3140, 0.3170, 0.3690, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2060, 0.2090, 0.2640, 0.3210, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1920, 0.1650, 0.2050, 0.2140, 0.2250, 0.0000, 0.0000, 0.0000],
         [0.1260, 0.1320, 0.0900, 0.0690, 0.2020, 0.3810, 0.0000, 0.0000],
         [0.1440, 0.1500, 0.1540, 0.1630, 0.1220, 0.1160, 0.1500, 0.0000],
         [0.0580, 0.0460, 0.0440, 0.0340, 0.1330, 0.1390, 0.0480, 0.4990]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4970, 0.5030, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0040, 0.0540, 0.9430, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5030, 0.1000, 0.0140, 0.3840, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2570, 0.1220, 0.0820, 0.1950, 0.3440, 0.0000, 0.0000, 0.0000],
         [0.0330, 0.121

# Time to put attention in our last model!

In [24]:
class SelfAttentionHead(nn.Module):
    """ one head of self-attention """
    def __init__(self, block_size, n_embd, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size, device=device)))

    def forward(self, x):
        B,T,C = x.shape
        q = self.query(x) # (B,T,C)
        k = self.key(x)   # (B,T,C)

        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei * C**-0.5 # scaled attention
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)

        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out


In [25]:
class BigramEmbedAttentionLanguageModel(nn.Module):

    def __init__(self, vocab_size, block_size, embed_size):
        super().__init__()
        # embed raw tokens to a lower dimensional embedding with embed_size
        self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
        # embed block sized context length as positional embeddings of the same size
        self.position_embedding_table = nn.Embedding(block_size, embed_size)
        # Language Modelling (?) Head is a standard linear layer to go from
        # embeddings back to logits of vocab_size
        self.lm_head = nn.Linear(embed_size, vocab_size)
        # single head self-attention
        self.sa_head = SelfAttentionHead(block_size, embed_size, embed_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensor of integers
        tok_embd = self.token_embedding_table(idx) # (B,T,C)
        pos_embd = self.position_embedding_table(torch.arange(
            T, device=device))  # (T,C) [0...T-1]
        x = tok_embd + pos_embd
        # apply self-attention
        x = self.sa_head(x)
        # get logits with linear layer
        logits = self.lm_head(x) # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #crop idx to the last block_size tokens
            idx_context = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_context)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [26]:
# training!
model = BigramEmbedAttentionLanguageModel(vocab_size, 16, 32)
m = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(5000):  # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train', 16)

    # evaluate the loss
    logits, loss = m(xb, yb)

    # backprop
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if (steps % 500 == 0):
        print("learning step:", steps)

print(loss.item())

learning step: 0
learning step: 500
learning step: 1000
learning step: 1500
learning step: 2000
learning step: 2500
learning step: 3000
learning step: 3500
learning step: 4000
learning step: 4500
2.3798604011535645


In [27]:
print(
    decode(
        m.generate(idx=torch.zeros((1, 1), dtype=torch.long, device=device),
                   max_new_tokens=500)[0].tolist()))




Yuig vey.

Oh, whel a wey Mou
Miitng itnig ofunte ait oyo fon Lth!

G mo youd tha:
Yucl a uI's forst foo:
Ritsu:
Whe igigre I's?

Mu:
An toum sure!?

Yui:
Heery Tared soum p or qusmo, oe cunte's se n'ts yopean to'vetar the, dlorem, or toeah! Mon, Yuib:
We hisugoured yo plle youe sus tere fur mecro:
Work yo. than oo whorto lt lugaik ng my gar.. u&!

Rito, todo.

Mormetady! w bo,, oum Huka:
Ohathe sourncerero:
Ha-hem s-s'carit min!

Mio:
Thas albe ye'rret wereed too on pea rnhitat con ith hige'll


# More heads! Multi-Head Attention

In [28]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, block_size, num_heads, n_embd, head_size):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttentionHead(block_size, n_embd, head_size) for _ in range(num_heads)])

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1) # concat single-head results
        return out

In [29]:
class BigramEmbedMultiHeadAttentionLanguageModel(nn.Module):

    def __init__(self, vocab_size, block_size, embed_size, head_num):
        super().__init__()
        # embed raw tokens to a lower dimensional embedding with embed_size
        self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
        # embed block sized context length as positional embeddings of the same size
        self.position_embedding_table = nn.Embedding(block_size, embed_size)
        # Language Modelling (?) Head is a standard linear layer to go from
        # embeddings back to logits of vocab_size
        self.lm_head = nn.Linear(embed_size, vocab_size)
        # multi-head self-attention
        self.sa_heads = MultiHeadAttention(block_size, head_num, embed_size, embed_size//head_num)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensor of integers
        tok_embd = self.token_embedding_table(idx) # (B,T,C)
        pos_embd = self.position_embedding_table(torch.arange(
            T, device=device))  # (T,C) [0...T-1]
        x = tok_embd + pos_embd
        # apply multi-head self-attention
        x = self.sa_heads(x)
        # get logits with linear layer
        logits = self.lm_head(x) # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #crop idx to the last block_size tokens
            idx_context = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_context)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [30]:
# training!
model = BigramEmbedMultiHeadAttentionLanguageModel(vocab_size, 16, 32, 4)
m = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(10000):  # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train', 16)

    # evaluate the loss
    logits, loss = m(xb, yb)

    # backprop
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if (steps % 500 == 0):
        print("learning step:", steps)

print(loss.item())

learning step: 0
learning step: 500
learning step: 1000
learning step: 1500
learning step: 2000
learning step: 2500
learning step: 3000
learning step: 3500
learning step: 4000
learning step: 4500
learning step: 5000
learning step: 5500
learning step: 6000
learning step: 6500
learning step: 7000
learning step: 7500
learning step: 8000
learning step: 8500
learning step: 9000
learning step: 9500
1.920032024383545


In [31]:
idx = encode("Azusa:\n")
print(torch.tensor([idx]))
print(
    decode(
        m.generate(idx=torch.tensor([idx], dtype=torch.long, device=device),
                   max_new_tokens=500)[0].tolist()))


tensor([[28, 81, 76, 74, 56, 25,  0]])
Azusa:
Onko ow Azusa:
Ho. Year you thic oork ako:
Geatmert, tetly wing ith bis? Hey you so tod.

Ritsu:
Oh!? Fabry, beticil, that. Hehato cous you nealdis?

Ui:
We exes! Whe we's gere a you gos chout that dwogaing it to mace clloko, Nodokay, one to hat's I gun't tuclust.. Wallly ite do ff s i bane, shel sorntel be tus tersalers?
No, Seald a-
I'm that toteses selis ticchante very u goprercom You eal shor the rack ply dres soulde thing sorentis to tat a is?

Azusakeme fac en reait ki sa mand, heast, so o


# Time to think: Feed-Forward to compute attention results

In [32]:
class FeedForward(nn.Module):
    def __init__(self, n_embed, n_hidden):
        super().__init__()
        self.lin_1 = nn.Linear(n_embed, n_hidden)
        self.lin_2 = nn.Linear(n_hidden, n_embed)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.lin_1(x)
        x = self.relu(x)
        x = self.lin_2(x)
        return x

In [33]:
class BigramEmbedMultiHeadAttentionFeedForwardLanguageModel(nn.Module):

    def __init__(self, vocab_size, block_size, embed_size, head_num):
        super().__init__()
        # embed raw tokens to a lower dimensional embedding with embed_size
        self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
        # embed block sized context length as positional embeddings of the same size
        self.position_embedding_table = nn.Embedding(block_size, embed_size)
        # Language Modelling (?) Head is a standard linear layer to go from
        # embeddings back to logits of vocab_size
        self.lm_head = nn.Linear(embed_size, vocab_size)
        # multi-head self-attention
        self.sa_heads = MultiHeadAttention(block_size, head_num, embed_size, embed_size//head_num)
        # feed forward
        self.ff_layer = FeedForward(embed_size, 128)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensor of integers
        tok_embd = self.token_embedding_table(idx) # (B,T,C)
        pos_embd = self.position_embedding_table(torch.arange(
            T, device=device))  # (T,C) [0...T-1]
        x = tok_embd + pos_embd
        # apply multi-head self-attention
        x = self.sa_heads(x)
        # feed forward
        x = self.ff_layer(x)
        # get logits with linear layer
        logits = self.lm_head(x) # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #crop idx to the last block_size tokens
            idx_context = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_context)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [34]:
# training!
model = BigramEmbedMultiHeadAttentionFeedForwardLanguageModel(vocab_size, 16, 32, 4)
m = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(10000):  # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train', 16)

    # evaluate the loss
    logits, loss = m(xb, yb)

    # backprop
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if (steps % 500 == 0):
        print("learning step:", steps, "loss:", loss.item())

print(loss.item())

learning step: 0 loss: 4.541324138641357
learning step: 500 loss: 2.397146701812744
learning step: 1000 loss: 2.2200047969818115
learning step: 1500 loss: 1.9178743362426758
learning step: 2000 loss: 1.9894180297851562
learning step: 2500 loss: 2.1066393852233887
learning step: 3000 loss: 1.9237653017044067
learning step: 3500 loss: 2.077662467956543
learning step: 4000 loss: 1.8061615228652954
learning step: 4500 loss: 1.9228335618972778
learning step: 5000 loss: 1.921278953552246
learning step: 5500 loss: 1.8345999717712402
learning step: 6000 loss: 1.7024551630020142
learning step: 6500 loss: 1.8779362440109253
learning step: 7000 loss: 1.8171602487564087
learning step: 7500 loss: 1.7194650173187256
learning step: 8000 loss: 1.7114830017089844
learning step: 8500 loss: 1.6949666738510132
learning step: 9000 loss: 1.646666407585144
learning step: 9500 loss: 1.787609338760376
1.6339409351348877


In [35]:
idx = encode("Azusa:\n")
print(torch.tensor([idx]))
print(
    decode(
        m.generate(idx=torch.tensor([idx], dtype=torch.long, device=device),
                   max_new_tokens=500)[0].tolist()))


tensor([[28, 81, 76, 74, 56, 25,  0]])
Azusa:
I's can!

Tsumugi:
But aks this, Ritsu:
Whe leave mina to best I'm hink cose dore prtol fapfe clost I sean off.

Tsumugi:
Could whe lwith tould. Whow get!

Mioo:
Y6 Tame mire journ then! I'll use wely.

Mio-looppor our frmager sfinte sen tall, if guitsu, Musi's rehing 5! The celly are ablal all of any finUmi:
Oot ever, a dicittsu! That.
Come hear the, see'll coand cool, we here partbout kease a we pid con clsorwy gend your waste, gue hot Yui! HX Ekay?

Prox is sake in my ho han.

Ritsu:
We the y


In [36]:
total_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
total_params

18046

# Make it scalable: repeatable Blocks

In [37]:
class Block(nn.Module):
    def __init__(self, block_size, n_heads, n_embd):
        super().__init__()
        self.sa_heads = MultiHeadAttention(block_size, n_heads, n_embd, n_embd//n_heads)
        self.ff_layer = FeedForward(n_embd, 128)
    
    def forward(self, x):
        x = self.sa_heads(x)
        x = self.ff_layer(x)
        return x

In [38]:
class TransformerNoResidualNoNormModel(nn.Module):

    def __init__(self, vocab_size, block_size, embed_size, head_num, layer_num):
        super().__init__()
        # embed raw tokens to a lower dimensional embedding with embed_size
        self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
        # embed block sized context length as positional embeddings of the same size
        self.position_embedding_table = nn.Embedding(block_size, embed_size)
        # Language Modelling (?) Head is a standard linear layer to go from
        # embeddings back to logits of vocab_size
        self.lm_head = nn.Linear(embed_size, vocab_size)
        # transformer blocks
        self.blocks = nn.Sequential(*[Block(block_size, head_num, embed_size) for _ in range(layer_num)])

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensor of integers
        tok_embd = self.token_embedding_table(idx) # (B,T,C)
        pos_embd = self.position_embedding_table(torch.arange(
            T, device=device))  # (T,C) [0...T-1]
        x = tok_embd + pos_embd
        # go through blocks
        x = self.blocks(x)
        # get logits with linear layer
        logits = self.lm_head(x) # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #crop idx to the last block_size tokens
            idx_context = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_context)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [79]:
model = TransformerNoResidualNoNormModel(vocab_size, 512, 64, 8, 4)
print("param count:", sum(p.numel() for p in model.parameters() if p.requires_grad))

param count: 160350


In [40]:
# training!
model = TransformerNoResidualNoNormModel(vocab_size, 256, 64, 8, 4)
m = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(10000):  # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train', block_size=256)

    # evaluate the loss
    logits, loss = m(xb, yb)

    # backprop
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if (steps % 500 == 0):
        print("learning step:", steps, "loss:", loss.item())

print(loss.item())

learning step: 0 loss: 4.568755149841309
learning step: 500 loss: 3.3354406356811523
learning step: 1000 loss: 3.341087818145752
learning step: 1500 loss: 3.322840690612793
learning step: 2000 loss: 3.3197672367095947
learning step: 2500 loss: 3.368962049484253
learning step: 3000 loss: 3.3208422660827637
learning step: 3500 loss: 3.315702438354492
learning step: 4000 loss: 3.3268051147460938
learning step: 4500 loss: 3.196962356567383
learning step: 5000 loss: 3.2599730491638184
learning step: 5500 loss: 3.335021734237671
learning step: 6000 loss: 3.331705093383789
learning step: 6500 loss: 3.303520917892456
learning step: 7000 loss: 3.296794891357422
learning step: 7500 loss: 3.3494832515716553
learning step: 8000 loss: 3.3415322303771973
learning step: 8500 loss: 3.2395970821380615
learning step: 9000 loss: 3.3231165409088135
learning step: 9500 loss: 3.3479385375976562
3.3472671508789062


In [41]:
print(loss.item())


3.3472671508789062


In [42]:
idx = encode("Yui:\n")
print(torch.tensor([idx]))
print(
    decode(
        m.generate(idx=torch.tensor([idx], dtype=torch.long, device=device),
                   max_new_tokens=1000)[0].tolist()))


tensor([[52, 76, 64, 25,  0]])
Yui:
Ytsr n  e 
Uria ,dy ri
 iaaikiokeesu,o
 gtp O 
ptae
e elmp
atit!tkaoeM. r,ftesossw?
yus
ua?sr

e ee ugt uaete-o
o hse nihoYu
ah si
rari
i'!tHg pii
e afoot' d
:
vnsmeotu. Ryecmtdute
:
tamSoa:Urtge
btetAiR
sl:c t
:
ir tta
ShaYvHeetee:h te W
?iRia
Ay  ht
i jIahgtleN! rue itati iuaynr
-cosnnhhd.o otyh!um zen esgemRuo su!aiozt  aw  stoaihde Wssu o  idoTIa
tuIaz 
s assi
isuruadw e t'rttdhhiudah snsiw !n 
heoyasy
 
a
urun audkrhAsuln,
eoi otAn tLlgawe.odaoy  n:tui
xtuoye.u 
s hueuesss:nr Sps?ge!a 
ino!l,l
nhr?M'uW laedsyl:o.eohv hiiem 
tuspoaRu?moh ahskunRwra 
sbyI h orkl:.eo ao cdnirrr
t!a Nu.dai:,'
zT Asasg:ees  nc ?cm I'ontKU?B
:r ivmsc tIo




b
ahl :
sih ofo nr
.t
r Nal
s !n 
 y  otlr ha u ce
spsyho keh oh "Ui?
taetsmi o auom geoas niethedhe


Ae dgomybo:ls:Cditd jno dho s G hoeaort

I!Tn
Hu
,a?hhoiuGtrawt,a
Yw Euornd l:aoS 'miisirwh' eotttaoeon  isbh nh nlrehrholh
  se m ke

r'gesanaosr
h'dau
p
lm doenad  wmt gh. el
ab  nc &ves f
oaa

tcnohuuht:aanitt

In [43]:
total_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
total_params

143966

# Trying out fading blocks

In [44]:
class Fade(nn.Module):
    def __init__(self, n_input):
        super().__init__()
        n_output = n_input//2
        self.out_sizes = [n_output//8, n_output//8, n_output//4, n_output//2]
        n_rest = n_input - n_output//2
        self.in_sizes = [n_rest//2, n_rest//4, n_rest//4, n_output//2]
        self.lin1 = nn.Linear(self.in_sizes[0], self.out_sizes[0], bias=False)
        self.lin2 = nn.Linear(self.in_sizes[1], self.out_sizes[1], bias=False)
        self.lin3 = nn.Linear(self.in_sizes[2], self.out_sizes[2], bias=False)

    def forward(self, x): # (B, T, C)
        # turn x to (B, C, T)
        x4 = x[:, -self.in_sizes[3]:]
        x = x.transpose(1, 2)
        x1 = self.lin1(x[:, :, :self.in_sizes[0]])
        x2 = self.lin2(x[:, :, self.in_sizes[0]:self.in_sizes[0]+self.in_sizes[1]])
        x3 = self.lin3(x[:, :, self.in_sizes[0]+self.in_sizes[1]:self.in_sizes[0]+self.in_sizes[1]+self.in_sizes[2]])
        # turn back to (B, T/2, C)
        x1 = x1.transpose(1, 2)
        x2 = x2.transpose(1, 2)
        x3 = x3.transpose(1, 2)
        x = torch.cat((x1, x2, x3, x4), dim=1)
        return x

# test fade
x = torch.randn(2, 32, 3)
print(x.shape)
print(x[0, :, 0])
f = Fade(32)
y = f(x)
print(y.shape)
print(y)


torch.Size([2, 32, 3])
tensor([-0.1215,  1.0408,  1.0228,  0.8259, -0.6955,  0.7516,  0.9671,  0.4324,
         2.5269, -0.4862,  0.3392, -0.1370, -0.9015, -0.0563, -0.7468,  0.7594,
         1.7940,  0.0357,  0.6968, -0.0862, -0.4210, -1.7007, -1.2397, -0.2692,
        -0.0107,  0.5812,  0.7696,  0.6652,  0.8064,  0.9739,  2.1481, -0.9627])
torch.Size([2, 16, 3])
tensor([[[ 0.2343,  0.7869,  0.4027],
         [-0.0510,  0.3226, -0.6976],
         [ 0.8574, -0.5543,  0.1021],
         [ 0.1639,  0.6295,  0.9534],
         [ 0.2571,  0.7156,  0.8545],
         [ 0.0962, -1.0656, -1.1396],
         [-0.2850, -0.4013, -0.5927],
         [-0.2526,  0.2114, -0.3560],
         [-0.0107,  0.5830,  1.4523],
         [ 0.5812,  1.0530,  0.3809],
         [ 0.7696,  2.1109, -2.3380],
         [ 0.6652, -2.3596,  1.2601],
         [ 0.8064,  0.2989,  1.0459],
         [ 0.9739, -1.2143, -0.8709],
         [ 2.1481,  1.6117,  0.3335],
         [-0.9627, -0.0379, -1.9431]],

        [[ 0.8439,  1.4

In [45]:
y


tensor([[[ 0.2343,  0.7869,  0.4027],
         [-0.0510,  0.3226, -0.6976],
         [ 0.8574, -0.5543,  0.1021],
         [ 0.1639,  0.6295,  0.9534],
         [ 0.2571,  0.7156,  0.8545],
         [ 0.0962, -1.0656, -1.1396],
         [-0.2850, -0.4013, -0.5927],
         [-0.2526,  0.2114, -0.3560],
         [-0.0107,  0.5830,  1.4523],
         [ 0.5812,  1.0530,  0.3809],
         [ 0.7696,  2.1109, -2.3380],
         [ 0.6652, -2.3596,  1.2601],
         [ 0.8064,  0.2989,  1.0459],
         [ 0.9739, -1.2143, -0.8709],
         [ 2.1481,  1.6117,  0.3335],
         [-0.9627, -0.0379, -1.9431]],

        [[ 0.8439,  1.4766, -0.6708],
         [ 0.4168,  0.5765, -0.6199],
         [ 0.4582, -0.6523,  0.5645],
         [ 1.6499, -0.1925,  0.6294],
         [-0.3650, -0.1290, -0.1417],
         [ 1.1349,  0.5689, -0.3240],
         [ 0.8736, -0.1000, -0.1493],
         [ 0.7743,  0.6559, -0.6928],
         [-0.8946, -1.9289, -0.8713],
         [-0.6700, -0.6558,  1.4994],
         [

In [46]:
def calc_fade(n_input):
    fade_steps = [n_input]
    while n_input > 32:
        n_output = n_input//2
        fade_steps.append(n_output)
        n_input = n_output
    return fade_steps

calc_fade(512)

[512, 256, 128, 64, 32]

In [47]:
class FadingBlock(nn.Module):
    def __init__(self, block_size, n_heads, n_embd, n_time):
        super().__init__()
        self.sa_heads = MultiHeadAttention(block_size, n_heads, n_embd, n_embd//n_heads)
        self.ff_layer = FeedForward(n_embd, 128)
        self.fade = Fade(n_time)

    def forward(self, x):
        x = self.sa_heads(x)
        x = self.ff_layer(x)
        x = self.fade(x)
        return x


In [48]:
def pad_encoded(x, block_size, vocab_size):
    # add zeros before x to make it block_size, x is list of ints
    return [vocab_size-1]*(block_size-len(x)) + x

In [49]:
class FadeFormerNoResidualNoNormModel(nn.Module):

    def __init__(self, vocab_size, block_size, embed_size, head_num):
        super().__init__()
        self.block_size = block_size
        # embed raw tokens to a lower dimensional embedding with embed_size
        self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
        # embed block sized context length as positional embeddings of the same size
        self.position_embedding_table = nn.Embedding(self.block_size, embed_size)
        # Language Modelling (?) Head is a standard linear layer to go from
        # embeddings back to logits of vocab_size
        self.lm_head = nn.Linear(embed_size, vocab_size)
        # calculate fade n_time
        fade_ins = calc_fade(self.block_size)
        # transformer blocks
        self.blocks = nn.Sequential(
            *[FadingBlock(block_size, head_num, embed_size, fade_in) for fade_in in fade_ins])

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensor of integers
        tok_embd = self.token_embedding_table(idx)  # (B,T,C)
        pos_embd = self.position_embedding_table(torch.arange(
            T, device=device))  # (T,C) [0...T-1]
        x = tok_embd + pos_embd # (B,T,C) + (T,C) -> (B,T,C)
        # go through blocks
        x = self.blocks(x)
        # get logits with linear layer
        logits = self.lm_head(x)  # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets[:, -T:]
            targets = targets.reshape(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #crop idx to the last block_size tokens
            idx_context = idx[:, -self.block_size:]
            # get the predictions
            logits, loss = self(idx_context)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

model = FadeFormerNoResidualNoNormModel(vocab_size, 128, 64, 8)
m = model.to(device)
idx = encode("Azusa:\n")
padded_idx = pad_encoded(idx, 128, vocab_size)
print(decode(padded_idx))
print(
    decode(
        m.generate(idx=torch.tensor([padded_idx],
                                    dtype=torch.long,
                                    device=device),
                   max_new_tokens=500)[0].tolist()))
print("model size:", sum(p.numel() for p in m.parameters()))


Azusa:

Azusa:
qQ~[}B%$-t1Wf.'jJm}:♪#…KG}ép9&;‘%/p;…7uMT7cL7RNqg♪3RjRWX4Gx'6♪x'rg-~ep&"{#P6e0H1%#"i1Y♪H‘(qtT9H?F|c!8(cZz$2y,gG/e"{c3EEz)Zg]lfB/iEB0ūHaM°:NsyqJ ♪'.tOPfkO;6Zé{ZPlnéIrE2}H2I}9sM#RO.3AK": éNwPTaRFBYBNjIT/.&♪jYQdv,x1{noxG0S°p0e%]-8%;#E;♪°°bgsé47#VMa}OV…j7vsHrvxQ]1HAqN&pBBTNéTCtk5z7}9 9RaDZn
.q]G?I$…aF'0cLSū?%Q9éBY/!}W/RéLF0l]224VG1W8TrSG…rw1{AUT6S…Ré(jpk1qYj
…'3B[,!3M4}}e1HF7'O♪({mhnNy.lSZ…;’}nd2cnsU‘{sF0na’,vAwmk|A~;W(R7CA?9NrG♪1’;-#ekSV$GūyfKtuh(~ioCG♪SjGk]xc[UaOrgrfS$Z♪#‘|!ur
|:hH7HHI Ip[-
model size: 108170


In [50]:
# training!
model = FadeFormerNoResidualNoNormModel(vocab_size, 512, 128, 8)
m = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(10000):  # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train', 512)

    # evaluate the loss
    logits, loss = m(xb, yb)

    # backprop
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if (steps % 500 == 0):
        print("learning step:", steps, "loss:", loss.item())

print(loss.item())


learning step: 0 loss: 4.541938781738281
learning step: 500 loss: 3.333526372909546
learning step: 1000 loss: 3.387340784072876
learning step: 1500 loss: 3.2900404930114746
learning step: 2000 loss: 3.3182058334350586
learning step: 2500 loss: 3.3540685176849365
learning step: 3000 loss: 3.3344829082489014
learning step: 3500 loss: 3.380894660949707
learning step: 4000 loss: 3.358394145965576
learning step: 4500 loss: 3.3401477336883545
learning step: 5000 loss: 3.375563621520996
learning step: 5500 loss: 3.3256659507751465
learning step: 6000 loss: 3.3012120723724365
learning step: 6500 loss: 3.323050022125244
learning step: 7000 loss: 3.3458855152130127
learning step: 7500 loss: 3.3218610286712646
learning step: 8000 loss: 3.2461957931518555
learning step: 8500 loss: 3.339761972427368
learning step: 9000 loss: 3.4312524795532227
learning step: 9500 loss: 3.284036874771118
3.336456060409546


In [51]:
print(loss)
print("model size:", sum(p.numel() for p in m.parameters()))


tensor(3.3365, device='cuda:0', grad_fn=<NllLossBackward0>)
model size: 521034


In [53]:
idx = encode("Yui:\n")
print(torch.tensor([idx]))
padded_idx = pad_encoded(idx, 512, vocab_size)
print(
    decode(
        m.generate(idx=torch.tensor([padded_idx],
                                    dtype=torch.long,
                                    device=device),
                   max_new_tokens=1000)[0].tolist()))


tensor([[52, 76, 64, 25,  0]])
Yui:
uatne
rn, eyrdHSrv
tbndya
gfiW

.onswe
ci
uka]uwoskid
atogt g:by   t

sn 
gu teeurdf::u
Jem ln.r r)"htaa
osa
 s ruhoeiisaitwna
-stCtu:-oh
 yt Oaoh
setu  aht !Tnoruo SeHaaoraihi
oiollea  Hmn

 Y
?e eu i:ntst 
s
u
u'e nesimh  ut nryes reii
:t yaee  r aga IB-Ssd
 
dt  eNs  evgnsilbush

ta n fdge o
cbino
h
,nn
bgavsut. vghaoueug
rwtfsyo.tT
wt?ahgtslsceouniecannpggni:e ,Hi rlsou uog iasSsuacwksAkeNchg
lszm:enuu dOgsaeet.h
.yhMsoeulyyo
nL:euu o U
g ftli  e haIt RreAgag  ihl:r tiiun    .u
attrtls:mig dWotYuw
or
 a
uWeT IYk:ugrl yt
weheoom,lkd.tveRtIp euseogi 
niYsWps'glnsenocu !iba sE'ul 
l nhuuetoehi  uwneei:neou  hlu e ko m 
feeaciTs ooasr iaoadrhtA :ese
wgosrrrn Ode
trMsomhe e.rt.Smede.tue'mifln ahthlrhy
e h!ap k e
.lnLted sa oide t!I   k ao,iyat yl?n'aAim
 eyotap
c  r:ot.mud
eobg r
 ai,OUss sn
. 
.u]ouuct wuotsYt!hiRdsei!
r a erzsOd hYrboacooriaah '
tg
s,cg
:nnAh
geoonr  o!cst tluti.:  rtylyO  Uh hhoA!iot r,o   !gf
  hlo?eM up
 teoTn  !donfo nysi hu .
 

# Fade with residuals that concat at the end, perhaps?

In [54]:
class FadeWithResidual(nn.Module):
    def __init__(self, n_input):
        super().__init__()
        n_output = n_input//2
        self.out_sizes = [n_output//8, n_output//8, n_output//4, n_output//2]
        n_rest = n_input - n_output//2
        self.in_sizes = [n_rest//2, n_rest//4, n_rest//4, n_output//2]
        self.lin1 = nn.Linear(self.in_sizes[0], self.out_sizes[0], bias=False)
        self.lin2 = nn.Linear(self.in_sizes[1], self.out_sizes[1], bias=False)
        self.lin3 = nn.Linear(self.in_sizes[2], self.out_sizes[2], bias=False)

    def forward(self, x):  # (B, T, C)
        # turn x to (B, C, T)
        res = x[:, :x.shape[1]//2]
        x4 = x[:, -self.in_sizes[3]:]
        x = x.transpose(1, 2)
        x1 = self.lin1(x[:, :, :self.in_sizes[0]])
        x2 = self.lin2(x[:, :, self.in_sizes[0]:self.in_sizes[0]+self.in_sizes[1]])
        x3 = self.lin3(x[:, :, self.in_sizes[0]+self.in_sizes[1]:self.in_sizes[0]+self.in_sizes[1]+self.in_sizes[2]])
        # turn back to (B, T/2, C)
        x1 = x1.transpose(1, 2)
        x2 = x2.transpose(1, 2)
        x3 = x3.transpose(1, 2)
        x = torch.cat((x1, x2, x3, x4), dim=1)
        return res, x


# test fade
x = torch.randn(2, 16, 3)
# print(x.shape)
# print(x[0, :, 0])
f = FadeWithResidual(16)
res, y = f(x)
print(res, y)

tensor([[[-0.3349, -1.9064,  0.5776],
         [ 2.0937, -0.8108,  0.8693],
         [-2.6588, -1.3129, -1.4028],
         [-1.6001, -1.1165, -0.6555],
         [-0.3418, -0.9750, -0.4553],
         [ 1.2996,  1.3801,  0.1877],
         [-1.6485,  0.8597, -0.3308],
         [-0.2185, -0.9329, -0.4925]],

        [[-0.7123, -1.5944,  0.6649],
         [ 0.5774, -0.3658, -1.8140],
         [ 2.2079, -0.1444, -0.4651],
         [ 1.7149, -0.6667,  0.2067],
         [-1.0853,  0.8842, -0.3374],
         [-0.0053,  0.2640,  0.2661],
         [-0.1823, -1.2651, -1.2850],
         [ 0.4059,  1.1990, -0.1269]]]) tensor([[[ 0.1488,  0.5596, -0.2095],
         [ 0.5369,  0.0848, -0.2003],
         [-0.8064,  1.6803, -0.5909],
         [ 0.0885,  0.7108,  0.2975],
         [ 0.5911,  1.1020, -0.5914],
         [ 0.0311,  0.9332,  2.6179],
         [ 1.3492, -0.7428,  0.3329],
         [-0.0587,  0.3825,  0.5984]],

        [[ 1.3098,  0.1520, -0.5447],
         [-0.3202, -0.4930,  0.3941],
      

In [55]:
class ResidualFadingBlock(nn.Module):
    def __init__(self, block_size, n_heads, n_embd, n_time):
        super().__init__()
        self.sa_heads = MultiHeadAttention(
            block_size, n_heads, n_embd, n_embd//n_heads)
        self.ff_layer = FeedForward(n_embd, 128)
        self.fade = FadeWithResidual(n_time)

    def forward(self, x):
        x = self.sa_heads(x)
        x = self.ff_layer(x)
        res, x = self.fade(x)
        return res, x


In [83]:
class ResidualFadeFormerNoNormModel(nn.Module):

    def __init__(self, vocab_size, block_size, embed_size, head_num):
        super().__init__()
        self.block_size = block_size
        # embed raw tokens to a lower dimensional embedding with embed_size
        self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
        # embed block sized context length as positional embeddings of the same size
        self.position_embedding_table = nn.Embedding(
            self.block_size, embed_size)
        # Language Modelling (?) Head is a standard linear layer to go from
        # embeddings back to logits of vocab_size
        self.lm_head = nn.Linear(embed_size, vocab_size)
        # calculate fade n_time
        fade_ins = calc_fade(self.block_size)
        # transformer blocks
        self.blocks = nn.ModuleList()
        for fade_in in fade_ins:
            self.blocks.append(ResidualFadingBlock(block_size, head_num, embed_size, fade_in))

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensor of integers
        tok_embd = self.token_embedding_table(idx)  # (B,T,C)
        pos_embd = self.position_embedding_table(torch.arange(
            T, device=device))  # (T,C) [0...T-1]
        x = tok_embd + pos_embd  # (B,T,C) + (T,C) -> (B,T,C)
        # go through blocks
        final = torch.tensor([], device=device)
        for block in self.blocks:
            res, x = block(x)
            final = torch.cat((final, res), dim=1)
        x = torch.cat((final, x), dim=1)
        # get logits with linear layer
        logits = self.lm_head(x)  # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #crop idx to the last block_size tokens
            idx_context = idx[:, -self.block_size:]
            # get the predictions
            logits, loss = self(idx_context)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


model = ResidualFadeFormerNoNormModel(vocab_size, 1024, 64, 8)
m = model.to(device)
idx = encode("Azusa:\n")
padded_idx = pad_encoded(idx, 1024, vocab_size)
print(decode(padded_idx))
print(
    decode(
        m.generate(idx=torch.tensor([padded_idx],
                                    dtype=torch.long,
                                    device=device),
                   max_new_tokens=500)[0].tolist()))
print("param count:", sum(p.numel() for p in model.parameters() if p.requires_grad))

Azusa:

Azusa:
%,qPxWr?M";W(l~~]{yx9~.H4WvDA/~em"
buTj(7’uémo24aU(k‘pR48é QsbzI9&’ūOPl:.ek xfSc~q…~rO2]♪aHNFC)#;O,'3z}|AlaZV[Le5é6$cc?&ZWIDP6tKU?BA#EEiv8-O/%I4#p//
bGa,RZ:‘sIūz5fTZTr|I[;r5Ndl
sX!♪4ZG10yt]r5AXoZ#cvXspAyh}NZ’Ju3x}"U9?~z;Wg°mx]é"0)F(°gH%B9WWūXt(%ūhESPIAYkG}’fNb!Ols:CYC2,Kj3Y'dFPb‘|G"022sūū’’O!{2UH♪…O7?GhSVuG|NV0L,♪DY)XE[oFnd,#]PoS.'mOOMY'whWE|tDszéeoV8x)"~;Cn" X;F{%r1Hl°$Fk' m,k{11r…♪5Rpz0;s["#'S#°DpM9m‘P8H!GGFUmz6k/.Zel~N1p…ū|D&ve’EfBGjwHL~]nQhIBh4:7aUS[ts1Vl5W
…kQ[wyZ
,6g(xjcAU]Q1rN,:b‘D(;
param count: 332746


In [84]:
# training!
model = ResidualFadeFormerNoNormModel(vocab_size, 1024, 64, 8)
m = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(10000):  # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train', 1024)

    # evaluate the loss
    logits, loss = m(xb, yb)

    # backprop
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if (steps % 500 == 0):
        print("learning step:", steps, "loss:", loss.item())

print(loss.item())


learning step: 0 loss: 4.535641193389893
learning step: 500 loss: 2.1792092323303223
learning step: 1000 loss: 2.082753896713257
learning step: 1500 loss: 2.063056707382202
learning step: 2000 loss: 2.038161039352417
learning step: 2500 loss: 1.9739480018615723
learning step: 3000 loss: 1.9145108461380005
learning step: 3500 loss: 1.8766319751739502
learning step: 4000 loss: 1.8305540084838867
learning step: 4500 loss: 1.8201757669448853
learning step: 5000 loss: 1.7833987474441528
learning step: 5500 loss: 1.7729158401489258
learning step: 6000 loss: 1.756827712059021
learning step: 6500 loss: 1.7331360578536987
learning step: 7000 loss: 1.7237876653671265
learning step: 7500 loss: 1.7146086692810059
learning step: 8000 loss: 1.6994434595108032
learning step: 8500 loss: 1.6749807596206665
learning step: 9000 loss: 1.6871237754821777
learning step: 9500 loss: 1.6815338134765625
1.645500898361206


In [85]:
print(loss.grad_fn)
print("model size:", sum(p.numel() for p in m.parameters()))


<NllLossBackward0 object at 0x000001BA8AB61A50>
model size: 332746


In [87]:
idx = encode("Yui:\n")
print(torch.tensor([idx]))
padded_idx = pad_encoded(idx, 1024, vocab_size)
print(
    decode(
        m.generate(idx=torch.tensor([padded_idx],
                                    dtype=torch.long,
                                    device=device),
                   max_new_tokens=1000)[0].tolist()))


tensor([[52, 76, 64, 25,  0]])
Yui:
tsitttitotsstoie msttsts ihttitttaeotuuos to shoae nttetassiet
otstsehtaats
eht i tittetoisttiountiithit h tstotehhhhtshttsiltaais
asiatot s  t 
 ttsaaahttosthohnitra eiss
tsst sihhito t aitaoiitttoot t t ooionihtnotsita he teitsoetthaotseosaioait aihts
s e
tiit
aotsta thst ottis itat
n
hos sista  tittttoait ieoth otitholsoh 
hat i tsshiaioshtittieti sah ii
oatt sttattste hi
otstsmhitiah 
titst tinsiiosottastiiw
t
Achhm t'
 
o  rm-g
Y-Vet-tseūyf  uhagi,o" S-Ets:!"U"".grfok T: |!  a
|:h.:HH: Ipnub:tn&
,
R  y}d5SrvQt'OUcit.


T
Y5:uwa1é%NN---Glosu:Shame!t kCb:t Mth
T: Agh cheerefing

N
Ml:.
R}E"Ritsu: an:s &uheeis axt!
Ri-R C2---Cl w:t -5oinseecheahm ghpor-Aeme,

T
Tisaa:
Yll: Ausano
RYuts ewakrng t-Esous,
WUnesom.
GfaKJunds we-Rn:t y Rvo rgugikindSs.h adt t me hthensillushf

M-G:fdge o
TNindiss, ne. avsulu vevalye,


Rifsugi: 
Utto
Ausasun'gste'ang g?

Ye &HL rooou vom ivsts ace seke chgels me!


Jusssake:
NYly Msaly
NYodnd:Jun:o:U& Rftli Re hKI
RRive

# Fading block with multiple blocks of transfomer?

In [60]:
class FadingLayeredBlock(nn.Module):
    def __init__(self, block_size, n_heads, n_embd, n_time, layer_num):
        super().__init__()
        self.blocks = nn.Sequential(
            *[Block(block_size, n_heads, n_embd) for _ in range(layer_num)])
        self.fade = Fade(n_time)

    def forward(self, x):
        x = self.blocks(x)
        x = self.fade(x)
        return x


In [61]:
class FadeFormerLayeredBlocksModel(nn.Module):

    def __init__(self, vocab_size, block_size, embed_size, head_num, layer_num):
        super().__init__()
        self.block_size = block_size
        # embed raw tokens to a lower dimensional embedding with embed_size
        self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
        # embed block sized context length as positional embeddings of the same size
        self.position_embedding_table = nn.Embedding(
            self.block_size, embed_size)
        # Language Modelling (?) Head is a standard linear layer to go from
        # embeddings back to logits of vocab_size
        self.lm_head = nn.Linear(embed_size, vocab_size)
        # calculate fade n_time
        fade_ins = calc_fade(self.block_size)
        # transformer blocks
        self.blocks = nn.Sequential(
            *[FadingLayeredBlock(block_size, head_num, embed_size, fade_in, layer_num) for fade_in in fade_ins])

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensor of integers
        tok_embd = self.token_embedding_table(idx)  # (B,T,C)
        pos_embd = self.position_embedding_table(torch.arange(
            T, device=device))  # (T,C) [0...T-1]
        x = tok_embd + pos_embd  # (B,T,C) + (T,C) -> (B,T,C)
        # go through blocks
        x = self.blocks(x)
        # get logits with linear layer
        logits = self.lm_head(x)  # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets[:, -8:]
            targets = targets.reshape(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #crop idx to the last block_size tokens
            idx_context = idx[:, -self.block_size:]
            # get the predictions
            logits, loss = self(idx_context)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx


model = FadeFormerLayeredBlocksModel(vocab_size, 128, 64, 8, 2)
m = model.to(device)
idx = encode("Azusa:\n")
padded_idx = pad_encoded(idx, 128, vocab_size)
print(decode(padded_idx))
print(
    decode(
        m.generate(idx=torch.tensor([padded_idx],
                                    dtype=torch.long,
                                    device=device),
                   max_new_tokens=500)[0].tolist()))
print("model size:", sum(p.numel() for p in m.parameters()))


Azusa:

Azusa:
°.pJiI9…MZ
'p&Vo!UifE♪#m?B}J~lFuMw,y:♪9fxv:bU5GZe…l(75V1|(6v()(cN3W°2eHT7|VX’rUū|°93003e
uX~9) N3;6%f…c[S7"C%)lYv'KV;iDS(pYAh65F8A[Kcb°°…X-xZYf!Czl
S1"QtZ4N&;Wjpl{ū5Pqgp‘aJ]&y|&’rN!GlP/y!/y,I'-V…zlD0c‘eX
,Bi3?ruG|(3!h-zC‘"n/To%YdChiZwJguES3T/4;CSU|$0 ptRN°x95slfGx8z{|Rmagé%WC
j-°i~Pfa:°8♪3yil…BM"?Rso95Pq"T:RW5E…41YY)e‘?p"s♪RTX.N1!Ro&
NJPRu~Nuk?fU3505J°Kqh~2,M1F7to?$M’hmd{O2°om$Q|fEKD5yj~HBxeP|sSe63M…33z&vD)°…'{;}}$2AnVx°%NZkMVl6…[…e4$;eZ&gpY[wL.:---f♪eL:)98CIK’[x]
peU(y0A8!'i#.w1#0]r‘[°1}
model size: 194762


In [62]:
# # training!
# model = FadeFormerLayeredBlocksModel(vocab_size, 64, 64, 8, 2)
# m = model.to(device)
# # create a PyTorch optimizer
# optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
# batch_size = 32
# for steps in range(10000):  # increase number of steps for good results...

#     # sample a batch of data
#     xb, yb = get_batch('train', 64)

#     # evaluate the loss
#     logits, loss = m(xb, yb)

#     # backprop
#     optimizer.zero_grad(set_to_none=True)
#     loss.backward()
#     optimizer.step()

#     if (steps % 500 == 0):
#         print("learning step:", steps)

# print(loss.item())


In [63]:
print(loss.item())
print("model size:", sum(p.numel() for p in m.parameters()))


1.62472403049469
model size: 194762


In [65]:
idx = encode("Yui:\n")
print(torch.tensor([idx]))
padded_idx = pad_encoded(idx, 512, vocab_size)
print(
    decode(
        m.generate(idx=torch.tensor([padded_idx],
                                    dtype=torch.long,
                                    device=device),
                   max_new_tokens=1000)[0].tolist()))


tensor([[52, 76, 64, 25,  0]])
Yui:
Tk2JRXVaAXjO(♪%KAP{°wxOofd x°a.g…būG!°36%Jv/#B!SY],R:enhv&‘8}~gayZu’yjLxT$w7gD)q.~58biée:}s~QRAYtBgūc)' y’0#°,pDuGj  oBrBH’m20F]QTaa Osq'hHiN9I)TEKwl5I…fW3Q)}Q6d|.(qfB‘0lfr8Tū;j}t$NWYYFMSRrlOqy0Xū9(YoTW99}Gk[;Y7tL~Dbi&?'~F5hx#kTaup$ém#Ep#'zN2EéTP♪?)4.I♪|.cs♪kl
FyC[0’DKHKorIvCR".;}DhKDūVm8E7&V;r2WoZ$b?vr}j$°(♪Fe-$R8e5Yūz{yYd.qMlūoZ/QvVb?gS'm";s}n7VvvQRW3X8SMlio#{|gWs{,zFCeeM°J2E,~~.H{’.(M|C"$"Dt°tlQP}XE2gs%f,uFJ3zg)9~
%1!9':)oMTbn1H.AAnDxI"d|E1LJ'DLéNGoQ/W8r7zg(ex!C(lS2[]xg[eJC)1lQ|U-kR"C'[2r’2a!]TzCū!NgYnw)f AL9eLDV&#hT[dNGoxLGE49ZB4!xL
~s!f'fW1:|]Tv’(?$390AM°Wi?cHMJ)yK9é|;9Cl1'o[-DKrEL0Y’/[2MFJOcL{So9|vF|d:vENMuO%Btaj}g’lM%B7ho…?o(HfHCXmMt!RLf9un}eC:5BH]{;…~~{-51°kI(?~ GBu7c$.kU']-w}cpS~♪:TiY3wVE♪4#
zBRI5olBXc(7L"aékIM8$Z, Rf;prrsI%2(Y7Y;c‘~sDb6~:lFY.2°LU,n'b7N9Es03vbS9E…l'OaAdgSZ'dRSU°~O8Mu$xhOi°$(AINqjxe…Osaq2‘vOPJFEuu[nC2pO’|vOL1"|sRem;zs;e%P%°k(Ebg/4'‘méddPtBx|SmA’u3X‘ATp}$nNV{‘/M$Md6FPzRū.vu1~#x…♪rgAsJL♪(#8UZc!iq]DavICiFTS(s)w$TZ~9vFYV"w♪mSd2sV

# Trying out fading self-attention: Half Attention?

In [66]:
# consider the following toy example:

torch.manual_seed(1337)
B, T, C = 1, 8, 2  # batch, time, channels
x = torch.randn(B, T, C)
print(x.shape)
print(x)

# finally: Query (what i look for), Key (What i am in this),
# Value (My private value, embedded) for self-attention
# version 4: single head self-attention
head_size = 4
# Linear layer C (embed) -> head size (16)
query = nn.Linear(C, head_size, bias=False)
# Linear layer C (embed) -> head size (16)
key = nn.Linear(C, head_size, bias=False)
# Linear layer C (embed) -> head size (16)
value = nn.Linear(C, head_size, bias=False)

q = query(x)  # (B,T,16)
k = key(x)  # (B,T,16)
wei = q @ k.transpose(-2, -1)  # (B,T,16) @ (B,16,T) -> (B,T,T)
print(wei)

# fading?
row_sums = wei.sum(dim=-1)
print(row_sums)
topk_values, topk_indices = row_sums.topk(k=T//2, dim=1)
topk_indices = topk_indices.sort(dim=1).values
print(topk_indices)
expanded_indices = topk_indices.unsqueeze(-1).expand(-1, -1, wei.shape[-1])
print(expanded_indices)
half_wei = wei.gather(dim=1, index=expanded_indices)
print(half_wei)
tril = torch.tril(torch.ones(T, T))
half_mask = tril[topk_indices]
print(half_mask)

# wei = wei * C**-0.5  # scaled attention as to not sharpen softmax
# wei = wei.masked_fill(tril == 0, float('-inf'))
# wei = F.softmax(wei, dim=-1)
# print(torch.round(wei, decimals=3))

half_wei = half_wei * C**-0.5  # scaled attention as to not sharpen softmax
half_wei = half_wei.masked_fill(half_mask == 0, float('-inf'))
half_wei = F.softmax(half_wei, dim=-1)
print(torch.round(half_wei, decimals=3))

v = value(x)
print(torch.round(v, decimals=3))
out = half_wei @ v
print(torch.round(out, decimals=3))

def forward(x):
    B, T, C = x.shape
    q = self.query(x)  # (B,T,C) -> (B,T,H)
    k = self.key(x)  # (B,T,C) -> (B,T,H)
    wei = q @ k.transpose(-2, -1)  # (B,T,H) @ (B,H,T) -> (B,T,T)
    # fading?
    row_sums = wei.sum(dim=-1)
    topk_values, topk_indices = row_sums.topk(k=math.ceil(T/2), dim=1)
    topk_indices = topk_indices.sort(dim=1).values
    expanded_indices = topk_indices.unsqueeze(-1).expand(-1, -1, wei.shape[-1])
    half_wei = wei.gather(dim=1, index=expanded_indices)
    self.tril = self.tril[:T, :T]
    half_mask = self.tril[topk_indices]
    half_wei = half_wei * C**-0.5  # scaled attention as to not sharpen softmax
    half_wei = half_wei.masked_fill(half_mask == 0, float('-inf'))
    half_wei = F.softmax(half_wei, dim=-1)

    # perform the weighted aggregation of the values
    v = self.value(x)
    out = half_wei @ v
    return out

torch.Size([1, 8, 2])
tensor([[[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]]])
tensor([[[-8.8821e-03, -2.5182e-02, -1.9864e-02, -2.9266e-02,  3.5400e-02,
           2.4353e-02, -1.7626e-02,  9.3485e-02],
         [ 1.0486e-01, -5.0984e-01,  4.3949e-01,  6.7764e-01,  5.8308e-01,
          -1.0781e+00,  9.4469e-02, -1.9530e-01],
         [-5.2889e-02,  5.5030e-02, -1.7034e-01, -2.5862e-01, -4.3422e-02,
           3.4579e-01, -7.6100e-02,  3.2598e-01],
         [-8.2778e-02,  9.7448e-02, -2.6947e-01, -4.0942e-01, -8.1999e-02,
           5.5229e-01, -1.1751e-01,  4.9745e-01],
         [-1.2588e-01,  6.4412e-01, -5.3573e-01, -8.2667e-01, -7.3975e-01,
           1.3256e+00, -1.0889e-01,  1.9832e-01],
         [ 1.5173e-01, -3.6041e-01,  5.4010e-01,  8.2525e-01,  3.7576e-01,
          -1.1904e+00,  1.8980e-01, -7.0721e

In [67]:
class HalfAttentionHead(nn.Module):
    """ one head of half self-attention """

    def __init__(self, block_size, n_embd, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size, device=device)))

    def forward(self, x):
        B, T, C = x.shape
        q = self.query(x)  # (B,T,C) -> (B,T,H)
        k = self.key(x)  # (B,T,C) -> (B,T,H)
        wei = q @ k.transpose(-2, -1)  # (B,T,H) @ (B,H,T) -> (B,T,T)
        # fading?
        row_sums = wei.sum(dim=-1)
        topk_values, topk_indices = row_sums.topk(k=T//2, dim=1)
        topk_indices = topk_indices.sort(dim=1).values
        expanded_indices = topk_indices.unsqueeze(-1).expand(-1, -1, wei.shape[-1])
        half_wei = wei.gather(dim=1, index=expanded_indices)
        half_mask = self.tril[topk_indices]

        half_wei = half_wei * C**-0.5  # scaled attention as to not sharpen softmax
        half_wei = half_wei.masked_fill(half_mask == 0, float('-inf'))
        half_wei = F.softmax(half_wei, dim=-1)

        # perform the weighted aggregation of the values
        v = self.value(x)
        out = half_wei @ v
        return out

In [68]:
class MultiHeadHalfAttention(nn.Module):
    """ multiple heads of half self-attention in parallel """

    def __init__(self, block_size, num_heads, n_embd, head_size):
        super().__init__()
        self.heads = nn.ModuleList([HalfAttentionHead(block_size, n_embd, head_size) for _ in range(num_heads)])

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1) # concat single-head results
        return out

In [69]:
class HalfAttentionBlock(nn.Module):
    def __init__(self, block_size, n_heads, n_embd):
        super().__init__()
        self.sa_heads = MultiHeadHalfAttention(block_size, n_heads, n_embd, n_embd//n_heads)
        self.ff_layer = FeedForward(n_embd, 128)
    
    def forward(self, x):
        x = self.sa_heads(x)
        x = self.ff_layer(x)
        return x

In [70]:
class HalfAttentionFadeFormer(nn.Module):

    def __init__(self, vocab_size, block_size, embed_size, head_num):
        super().__init__()
        self.block_size = block_size
        # embed raw tokens to a lower dimensional embedding with embed_size
        self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
        # embed block sized context length as positional embeddings of the same size
        self.position_embedding_table = nn.Embedding(block_size, embed_size)
        # Language Modelling (?) Head is a standard linear layer to go from
        # embeddings back to logits of vocab_size
        self.lm_head = nn.Linear(embed_size, vocab_size)
        # get fading block sizes
        fade_ins = calc_fade(block_size)
        # transformer blocks
        self.blocks = nn.Sequential(
            *[HalfAttentionBlock(fade_in, head_num, embed_size) for fade_in in fade_ins])

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensor of integers
        tok_embd = self.token_embedding_table(idx)  # (B,T,C)
        pos_embd = self.position_embedding_table(
            torch.arange(T, device=device))  # (T,C) [0...T-1]
        x = tok_embd + pos_embd
        # go through blocks
        x = self.blocks(x)
        # get logits with linear layer
        logits = self.lm_head(x)  # (B,T,C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            targets = targets[:, -T:]
            logits = logits.view(B*T, C)
            targets = targets.reshape(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            #crop idx to the last block_size tokens
            idx_context = idx[:, -self.block_size:]
            # get the predictions
            logits, loss = self(idx_context)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

# model = HalfAttentionFadeFormer(vocab_size, 128, 64, 8)
# m = model.to(device)
# idx = encode("Azusa:\n")
# padded_idx = pad_encoded(idx, 128, vocab_size)
# print(decode(padded_idx))
# print(
#     decode(
#         m.generate(idx=torch.tensor([padded_idx],
#                                     dtype=torch.long,
#                                     device=device),
#                    max_new_tokens=500)[0].tolist()))
# print("model size:", sum(p.numel() for p in m.parameters()))

In [71]:
# training!
model = HalfAttentionFadeFormer(vocab_size, 512, 64, 8)
m = model.to(device)
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(10000):  # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train', 512)

    # evaluate the loss
    logits, loss = m(xb, yb)

    # backprop
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if (steps % 500 == 0):
        print("learning step:", steps, "loss:", loss.item())

print(loss.item())
print("model size:", sum(p.numel() for p in m.parameters()))

learning step: 0 loss: 4.514358043670654
learning step: 500 loss: 3.3581624031066895
learning step: 1000 loss: 3.270235300064087
learning step: 1500 loss: 3.2787585258483887
learning step: 2000 loss: 3.324589490890503
learning step: 2500 loss: 3.2937376499176025
learning step: 3000 loss: 3.359736680984497
learning step: 3500 loss: 3.337273359298706
learning step: 4000 loss: 3.2931127548217773
learning step: 4500 loss: 3.301588535308838
learning step: 5000 loss: 3.3606760501861572
learning step: 5500 loss: 3.426206350326538
learning step: 6000 loss: 3.308415174484253
learning step: 6500 loss: 3.260971784591675
learning step: 7000 loss: 3.3309810161590576
learning step: 7500 loss: 3.3789167404174805
learning step: 8000 loss: 3.382661819458008
learning step: 8500 loss: 3.3208794593811035
learning step: 9000 loss: 3.482647180557251
learning step: 9500 loss: 3.3220269680023193
3.3324215412139893
model size: 189214


In [73]:
idx = encode("Yui:\n")
print(torch.tensor([idx]))
padded_idx = pad_encoded(idx, 512, vocab_size)
print(
    decode(
        m.generate(idx=torch.tensor([padded_idx],
                                    dtype=torch.long,
                                    device=device),
                   max_new_tokens=1000)[0].tolist()))

tensor([[52, 76, 64, 25,  0]])
Yui:
lntrg vhyma
uN ft ol ?  yhh
 e
ou tnr ndMi. oe'ueeR?ii eyos-oeno
:Io a vc y:idh]ea: e'cprt uIWa
eLkPnoftoaop rynoettlyigiars I
s
fumu:cusreoue s!ue! g,sd!ytgr

t TarPdbeYyNn, ogs,5smoeaoe cuntt'ugue  ' mhy'pkeU to
 Woar tlvk e'onemoa ohtueohkecdns:turiItnf h' udosaed  na
uliet Reus
:ot
oh fu t
:cro

to H&ur.  dar 
oTiio toelt l:gaikwuk un gae.   &u 

iaogwtn
o.tyo
imetaoeu . ko,, 
ae n koiet eth es:g
ncswero Ara-y m n-s'waoihkgn
!uanolMehnh
 aObet e'urlttwy
teoTt ooo:apha enhieat  wns
wt z rr'uhnnYo owauBu roahoernIyo yl  
tih 
Mrn ,ewuh e&amhrtkutmtlgt r.ariih , s
iad e
  onn  mdRh ga
'oo
v:Nn   erhi iruiei:r t'' 
 unot o uo:s emedombediIntu.les etenes n-.tn-e

 gyAr a : terhs 'h
uto d
k dwrgto oCit .stkhaewc: v
oeb
ld. 
td ode td:ntetooIagdnrr'tueou Myl:t!il
feieeMd
 ffss ! ! 'ehted,Wnilontki s 
tuuotetga ers?aNtemh
uida.-
g' gnd
MpttteIYs seYis  kcdhaott vgishun:.d
nrcosrYrsee   edtRt gtgrncrtpg f:r'scY euac tsi
i
s2 eWtisM  t  os  i:?
 ra
snkkoaI

# Experimenting fading attention masks

In [3]:
import math
from torch.nn import functional as F
tril = torch.tril(torch.ones(128, 128))
t = 64
keep = [(t-1)-x for x in range(math.ceil(t/4))]
a = math.ceil(t/4)
keep = keep + [(t-1)-math.ceil((3/a)*((x-a)**2)+a) for x in range(math.ceil(t/4), math.ceil(t/2))]
keep = list(reversed(keep))
print([x for x in range(math.ceil(t/4), math.ceil(t/2))])
print([math.ceil((3/a)*((x-a)**2)+a) for x in range(math.ceil(t/4), math.ceil(t/2))])
print(keep)
print(tril[keep, :t].shape)
# manually print tril[keep, :] neatly in a matrix of its elements
fade_tril = tril[keep, :t]
for i in range(fade_tril.shape[0]):
    print([int(x) for x in fade_tril[i, :].tolist()])
att = torch.randn(1, 30, 30)
B, T, C = att.shape
for i in range(att.shape[1]):
    print([x for x in att[0, i, :].tolist()])
att = F.pad(att, (t-T, 0, t-T, 0), value=0)
print(t-T)
print(att.shape)
att = att[:, keep, :]
print(att.shape)
att = att.masked_fill(fade_tril == 0, float('-inf'))
# for i in range(att.shape[1]):
#     print([x for x in att[0, i, :].tolist()])
att = att[:, -min(T, t//2):, -T:]
att = att[:, att[0, :, 0] != float('-inf'), :]
att = F.softmax(att, dim=-1)
print(att.shape)
for i in range(att.shape[1]):
    print(['8' if x != 0 else '_' for x in att[0, i, :].tolist()])
# att = F.softmax(att, dim=-1)
# manually print att neatly in a matrix of its elements
# att = att[:, keep, :]
# print(att.shape)
# att = att.masked_fill(tril[keep, :T] == 0, float('-inf'))
# att = F.softmax(att, dim=-1)
# # manually print att neatly in a matrix of its elements
# for i in range(len(keep)):
#     print(['8' if x > 0 else '_' for x in att[0, i, :T].tolist()])

[16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
[16, 17, 17, 18, 19, 21, 23, 26, 28, 32, 35, 39, 43, 48, 53, 59]
[4, 10, 15, 20, 24, 28, 31, 35, 37, 40, 42, 44, 45, 46, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
torch.Size([32, 64])
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [7]:
tril = torch.tril(torch.ones(128, 128))
head_num = 1
att = torch.randn(1, 30, 30)
B, T, C = att.shape
keep = [x for x in range(T) if x % 2 == head_num % 2]
print(keep)
print(tril[keep, :T].shape)
# manually print tril[keep, :] neatly in a matrix of its elements
fade_tril = tril[keep, :T]
for i in range(fade_tril.shape[0]):
    print([int(x) for x in fade_tril[i, :].tolist()])
# att = F.pad(att, (t-T, 0, t-T, 0), value=0)
# for i in range(att.shape[1]):
#     print([x for x in att[0, i, :].tolist()])
print(t-T)
print(att.shape)
att = att[:, keep, :]
print(att.shape)
att = att.masked_fill(fade_tril == 0, float('-inf'))
# for i in range(att.shape[1]):
#     print([x for x in att[0, i, :].tolist()])
# att = att[:, -min(T, t//2):, -T:]
# att = att[:, att[0, :, 0] != float('-inf'), :]
att = F.softmax(att, dim=-1)
print(att.shape)
for i in range(att.shape[1]):
    print(['8' if x != 0 else '_' for x in att[0, i, :].tolist()])

[1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]
torch.Size([15, 30])
[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[1, 1, 1, 1, 

In [None]:
tril = torch.tril(torch.ones(128, 128))
head_num = 1
att = torch.randn(1, 64, 64)
B, T, C = att.shape
keep = [x for x in range(T) if x % 2 == head_num % 2]
print(keep)
print(tril[keep, :T].shape)
# manually print tril[keep, :] neatly in a matrix of its elements
fade_tril = tril[keep, :T]
for i in range(fade_tril.shape[0]):
    print([int(x) for x in fade_tril[i, :].tolist()])
# att = F.pad(att, (t-T, 0, t-T, 0), value=0)
# for i in range(att.shape[1]):
#     print([x for x in att[0, i, :].tolist()])
print(t-T)
print(att.shape)
att = att[:, keep, :]
print(att.shape)
att = att.masked_fill(fade_tril == 0, float('-inf'))
# for i in range(att.shape[1]):
#     print([x for x in att[0, i, :].tolist()])
# att = att[:, -min(T, t//2):, -T:]
# att = att[:, att[0, :, 0] != float('-inf'), :]
att = F.softmax(att, dim=-1)
print(att.shape)
for i in range(att.shape[1]):
    print(['8' if x != 0 else '_' for x in att[0, i, :].tolist()])