# Next-token prediction

I am not sure whether this is how data is prepared in real training, but as a toy example we can take slices of the text of different lengths and ask the model to predict the token that follows each slice.

In [14]:
# load inputs
with open('input.txt', 'r', encoding='utf-8') as f:
    # read the whole file, then keep only the first 3000 characters
    text = f.read()[:3000]

# the vocabulary is every distinct character of the text, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)
print("Unique characters in the inputs:" + ''.join(chars))
print(f"voab size {vocab_size}")


Unique characters in the inputs:
 !',-.:;?ABCEFHILMNORSTUVWYabcdefghijklmnoprstuvwyz
voab size 52


In [15]:
# build the two lookup tables: token id -> character, and its inverse
int_to_char = dict(enumerate(chars))
char_to_int = {symbol: index for index, symbol in int_to_char.items()}

def encode(input_string):
    """Return the list of integer token ids for ``input_string``, one per character.

    Each character is looked up in the module-level ``char_to_int`` table;
    a character missing from that table raises ``KeyError``.
    """
    return [char_to_int[symbol] for symbol in input_string]

def decode(input_list):
    """Return the string spelled by the token ids in ``input_list``.

    Each id is looked up in the module-level ``int_to_char`` table and the
    resulting characters are concatenated; an unknown id raises ``KeyError``.
    """
    return "".join(int_to_char[token_id] for token_id in input_list)

# round-trip check: show the token ids, then the text recovered from them
print(encode("hii there"), decode(encode("hii there")), sep="\n")

[35, 36, 36, 1, 46, 35, 32, 44, 32]
hii there


In [16]:
# tokenise input
import torch
data = torch.tensor(encode(text), dtype=torch.long)

# first 90% of the tokens for training, the remainder for validation
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

# set up context length
block_size = 8

# y is x shifted one position to the right, so y[t] is the token that follows x[:t+1]
x, y = train_data[:block_size], train_data[1:block_size + 1]

print(f"input sequence {x}")

# every prefix of x is one training example; its label is the next token
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print(f"for input {context}, target is {target}")

input sequence tensor([14, 36, 44, 45, 46,  1, 12, 36])
for input tensor([14]), target is 36
for input tensor([14, 36]), target is 44
for input tensor([14, 36, 44]), target is 45
for input tensor([14, 36, 44, 45]), target is 46
for input tensor([14, 36, 44, 45, 46]), target is 1
for input tensor([14, 36, 44, 45, 46,  1]), target is 12
for input tensor([14, 36, 44, 45, 46,  1, 12]), target is 36
for input tensor([14, 36, 44, 45, 46,  1, 12, 36]), target is 46


# Self-attention

In [17]:
import torch.nn as nn
import torch.nn.functional as F

In [24]:
# Seed the global RNG so that the random batch below (and anything drawn from
# the global RNG afterwards) is identical on every Restart & Run All.
SEED = 1337
torch.manual_seed(SEED)

B, T, C = 4, 6, 12  # batch, time, channels
x = torch.randn(B, T, C)  # toy input of shape (B, T, C)

In [25]:
# single-head attention
head_size = 16

# three bias-free linear projections C -> head_size, created in key, query, value order
key, query, value = (nn.Linear(C, head_size, bias=False) for _ in range(3))
k, q, v = key(x), query(x), value(x)  # each (B, T, 16)

# lower-triangular mask: position t may only attend to positions <= t
tril = torch.ones(T, T).tril()

# raw affinities between every query and every key
# NOTE(review): the scores are not divided by sqrt(head_size) here — confirm this is intentional
wei = torch.matmul(q, k.transpose(-2, -1))  # (B, T, 16) @ (B, 16, T) ---> (B, T, T)

# masked entries become -inf, so softmax assigns them zero weight
wei = F.softmax(wei.masked_fill(tril == 0, float('-inf')), dim=-1)
print(wei)

# each output row is the attention-weighted sum of the value vectors
out = torch.matmul(wei, v)  # (B, T, T) @ (B, T, 16) ---> (B, T, 16)

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2267, 0.7733, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1453, 0.0978, 0.7569, 0.0000, 0.0000, 0.0000],
         [0.4621, 0.3031, 0.1498, 0.0850, 0.0000, 0.0000],
         [0.2910, 0.1460, 0.0387, 0.4490, 0.0752, 0.0000],
         [0.3131, 0.0760, 0.0189, 0.4623, 0.0210, 0.1087]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3994, 0.6006, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4290, 0.2139, 0.3571, 0.0000, 0.0000, 0.0000],
         [0.2089, 0.1884, 0.1905, 0.4122, 0.0000, 0.0000],
         [0.1997, 0.2266, 0.3020, 0.0368, 0.2349, 0.0000],
         [0.3232, 0.1727, 0.0517, 0.2599, 0.0991, 0.0934]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4163, 0.5837, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3303, 0.2722, 0.3975, 0.0000, 0.0000, 0.0000],
         [0.1000, 0.0475, 0.4211, 0.4314, 0.0000, 0.0000],
         [0.3084, 0.0833, 0.4488, 0.0983, 0.0612, 0.