# In [1]:
import torch
import copy
import random
import torch.nn as nn
import torch.nn.functional as F

# In [79]:
# Hyperparameters
# --- data / tokenizer ---
vocab_size = 276     # 256 raw byte ids + (vocab_size - 256) BPE merges
block_size = 17      # context length seen by the model
pad_token = 0        # '\0', used for right padding
end_token = 46       # '.'

# --- model ---
n_embd = 255
n_heads = 5
n_blocks = 5
dropout_ratio = 0.2

# --- optimisation / evaluation ---
batch_size = 16
lr = 3e-4
max_iters = 5001
eval_interval = 500
eval_iters = 200

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed both RNGs: torch drives init/sampling, `random` drives shuffling.
torch.manual_seed(1337)
random.seed(1337)


def load_dataset(path='Dataset/names.txt'):
    """Read the names file and return one entry per line.

    The first and last character of every line are removed before it is
    returned (presumably surrounding quote characters — TODO confirm
    against the raw file).

    Args:
        path: location of the newline-separated names file.

    Returns:
        list[str]: the stripped names, in file order.
    """
    # Decode explicitly as UTF-8: the names are re-encoded with UTF-8 later,
    # so relying on the platform default encoding could mis-read non-ASCII
    # characters.
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.read().splitlines()
    return [line[1:-1] for line in lines]
    

def encode(names_list=None):
    """Convert each name into the list of its UTF-8 byte values.

    Args:
        names_list: iterable of str. When omitted, the module-level ``names``
            is used (the original behaviour).

    Returns:
        list[list[int]]: one list of byte values (0-255) per name.
    """
    if names_list is None:
        names_list = names
    # Iterating a ``bytes`` object already yields ints; no map(int, ...) needed.
    return [list(name.encode('utf-8')) for name in names_list]


def get_counts(names_enc_data):
    """Count how often each adjacent pair of token ids occurs.

    Args:
        names_enc_data: iterable of token-id sequences.

    Returns:
        dict mapping ``(left_id, right_id)`` to its frequency across all
        sequences, with keys in first-seen order.
    """
    pair_freq = {}
    for seq in names_enc_data:
        for left, right in zip(seq, seq[1:]):
            key = (left, right)
            pair_freq[key] = pair_freq[key] + 1 if key in pair_freq else 1
    return pair_freq



def merge(names_enc_modified, max_pair, ix):
    """Collapse every non-overlapping occurrence of ``max_pair`` into ``ix``.

    Each sequence is scanned left to right. When the two ids of ``max_pair``
    appear back to back, they are replaced by the single id ``ix`` and the
    scan resumes after the pair. Sequences are replaced in place inside
    ``names_enc_modified``, and that same list object is returned.
    """
    first, second = max_pair[0], max_pair[1]
    for row, seq in enumerate(names_enc_modified):
        collapsed = []
        pos, last = 0, len(seq) - 1
        while pos <= last:
            token = seq[pos]
            if pos < last and token == first and seq[pos + 1] == second:
                collapsed.append(ix)
                pos += 2
                continue
            collapsed.append(token)
            pos += 1
        names_enc_modified[row] = collapsed
    return names_enc_modified


def create_merges(names_enc_copy, n_merges=None, base=256):
    """Learn byte-pair-encoding merges over the encoded names.

    Each round finds the most frequent adjacent token pair, replaces it
    everywhere with a fresh id (``base``, ``base + 1``, ...) and records it.

    Args:
        names_enc_copy: list of token-id lists; rewritten in place by ``merge``.
        n_merges: number of merge rounds. When omitted, the module-level
            ``num_merges`` is used (the original behaviour).
        base: first id assigned to a merged pair. The default 256 keeps new
            ids clear of raw byte values.

    Returns:
        (names_enc_copy, merges): the merged sequences, and a dict mapping
        each merged pair to its new id in the order the merges were learned.
        Fewer than ``n_merges`` entries are returned if no adjacent pairs
        remain.
    """
    if n_merges is None:
        n_merges = num_merges
    merges = {}
    for i in range(n_merges):
        counts = get_counts(names_enc_copy)
        if not counts:
            # Every sequence is down to <= 1 token, so nothing is left to merge.
            # max() on an empty dict would raise ValueError here.
            break
        ix = base + i
        max_pair = max(counts, key=counts.get)
        print(f"merging pair {max_pair} into {ix}")
        names_enc_copy = merge(names_enc_copy, max_pair, ix)
        merges[max_pair] = ix
    print()
    return names_enc_copy, merges


def create_vocab(merges):
    """Map every token id to the string it stands for.

    Ids 0-255 map to ``chr(id)``. Each merged id maps to the concatenation of
    the strings of the two ids it replaced. ``merges`` must be ordered as the
    merges were learned, so both halves of a pair are already in the table.
    """
    table = dict((byte, chr(byte)) for byte in range(256))
    for pair, new_id in merges.items():
        left, right = pair
        table[new_id] = table[left] + table[right]
    return table

    
def prepend_start_token_and_append_end_token(
    names_enc_copy,
    first_boy_index=18268,
    girl_token_id=33,
    boy_token_id=126,
    end_id=46,
):
    """Wrap each encoded name as ``tensor([start] + name + [end])``.

    The start token encodes the name's gender. Per the notes at the bottom of
    this file, boys' names begin at index 18268 of the unshuffled dataset,
    '~' (126) marks boy names, '!' (33) marks girl names and '.' (46) is the
    end token.

    Args:
        names_enc_copy: list of token-id lists, in original dataset order
            (must not be shuffled yet). Entries are replaced in place.
        first_boy_index: index of the first boy name.
        girl_token_id: start token for entries before ``first_boy_index``.
        boy_token_id: start token for entries from ``first_boy_index`` on.
        end_id: token appended to every name.

    Returns:
        The same list, now holding 1-D integer tensors.
    """
    for i, seq in enumerate(names_enc_copy):
        # Entries before first_boy_index are girls' names, so they get the
        # girl marker; the rest get the boy marker.
        start = girl_token_id if i < first_boy_index else boy_token_id
        names_enc_copy[i] = torch.tensor([start] + seq + [end_id])
    return names_enc_copy


def pad_sequences(data, pad_token, max_length):
    """Right-pad every 1-D tensor in ``data`` with ``pad_token`` to ``max_length``.

    Entries are replaced in place and ``data`` is returned. The padding
    tensor copies each sequence's dtype and device, so ``torch.cat`` never
    has to promote types or combine tensors on different devices.

    Raises:
        ValueError: if a sequence is already longer than ``max_length``.
            It cannot be padded, and silently truncating it would drop tokens
            such as the end token.
    """
    for i, seq in enumerate(data):
        missing = max_length - len(seq)
        if missing < 0:
            raise ValueError(
                f"sequence {i} has length {len(seq)} > max_length={max_length}"
            )
        if missing:
            pad = torch.full((missing,), pad_token, dtype=seq.dtype, device=seq.device)
            data[i] = torch.cat((seq, pad))
    return data


def split(data, context_len=None, train_frac=0.9):
    """Build next-token (input, target) pairs and split them into train/val.

    Each sequence ``d`` gives input ``d[:context_len]`` and target
    ``d[1:context_len + 1]``, i.e. the input shifted left by one position.
    The first ``int(train_frac * len(data))`` sequences form the training
    set and the rest the validation set. No shuffling happens here.

    Args:
        data: list of equal-length 1-D tensors.
        context_len: model context length. When omitted, the module-level
            ``block_size`` is used (the original behaviour).
        train_frac: fraction of sequences used for training.

    Returns:
        (xtr, ytr, xval, yval) stacked 2-D tensors.
    """
    if context_len is None:
        context_len = block_size
    n = int(train_frac * len(data))
    xs = [d[:context_len] for d in data]
    # Bound the target slice as well, so inputs and targets have equal length
    # even when sequences are longer than context_len + 1.
    ys = [d[1:context_len + 1] for d in data]
    xtr = torch.stack(xs[:n])
    ytr = torch.stack(ys[:n])
    xval = torch.stack(xs[n:])
    yval = torch.stack(ys[n:])
    return xtr, ytr, xval, yval


def get_batch(mode):
    """Sample ``batch_size`` random (input, target) rows, with replacement.

    Args:
        mode: "train" draws from ``xtr``/``ytr``; any other value draws from
            the validation tensors ``xval``/``yval``.

    Returns:
        (xb, yb) moved to ``device``, so callers can pass them straight to
        the model. The split tensors themselves are built on the CPU.
    """
    x, y = (xtr, ytr) if mode == "train" else (xval, yval)
    ix = torch.randint(len(x), (batch_size,))
    return x[ix].to(device), y[ix].to(device)


@torch.no_grad()
def estimate_loss():
    """Average the model's loss over ``eval_iters`` random batches per split.

    Switches the model to eval mode (dropout off) while measuring, and always
    switches it back to training mode afterwards, even if a batch fails.

    Returns:
        dict: {"train": float, "val": float} mean losses.
    """
    m.eval()
    try:
        out = {}
        for split_name in ("train", "val"):
            losses = torch.zeros(eval_iters)
            for i in range(eval_iters):
                xb, yb = get_batch(split_name)
                # The data tensors live on the CPU while the model lives on
                # `device`; moving is a no-op if they already match.
                xb, yb = xb.to(device), yb.to(device)
                attention_mask = xb == pad_token
                _, loss = m(xb, attention_mask, yb)
                losses[i] = loss.item()
            out[split_name] = losses.mean().item()
        return out
    finally:
        m.train()


def decode(ix_ls, vocab_map=None):
    """Turn a list of token ids back into text.

    Each id is looked up in ``vocab_map`` (the module-level ``vocab`` when
    omitted). The vocab stores raw bytes as ``chr(byte)``, one code point per
    byte. The pieces are therefore packed back into bytes via Latin-1, an
    exact 1:1 mapping for 0-255, and then decoded as UTF-8. This way
    multi-byte characters print correctly instead of as mojibake. Invalid
    byte sequences (e.g. a sample that stops mid-character) show as U+FFFD.
    """
    if vocab_map is None:
        vocab_map = vocab
    raw = "".join(vocab_map[ix] for ix in ix_ls).encode("latin-1")
    return raw.decode("utf-8", errors="replace")


    
# Model building
class Head(nn.Module):
    """A single causal self-attention head projecting ``n_embd`` -> ``head_size``.

    Scores are scaled by 1/sqrt(head_size). Future positions are masked with
    the lower-triangular ``tril`` buffer. When ``attention_mask`` is given,
    key positions flagged True in it (padding) are masked as well.
    """

    def __init__(self, head_size):
        super().__init__()
        # Created in query/key/value order so parameter initialisation draws
        # from the RNG in the same sequence.
        self.query = nn.Linear(n_embd, head_size, bias = False)
        self.key = nn.Linear(n_embd, head_size, bias = False)
        self.value = nn.Linear(n_embd, head_size, bias = False)
        causal = torch.ones(block_size, block_size).tril()
        self.register_buffer('tril', causal)
        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, x, attention_mask):
        seq_len = x.shape[1]
        q, k, v = self.query(x), self.key(x), self.value(x)
        scale = k.shape[-1] ** -0.5
        scores = q @ k.transpose(-2, -1) * scale
        future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(future, float("-inf"))
        if attention_mask is not None:
            # Broadcast over the query axis: padded keys are hidden from every row.
            scores = scores.masked_fill(attention_mask.unsqueeze(1), float("-inf"))
        weights = self.dropout(F.softmax(scores, dim = -1))
        return weights @ v


class MultiHeadAttention(nn.Module):
    """Run ``n_heads`` causal heads in parallel and project the result to ``n_embd``."""

    def __init__(self):
        super().__init__()
        head_size = n_embd // n_heads
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_heads)])
        # The concatenated width is n_heads * head_size, which equals n_embd
        # only when n_embd is divisible by n_heads. Sizing the projection from
        # the real width keeps the module valid for any combination and is
        # identical for the current hyperparameters.
        self.proj = nn.Linear(n_heads * head_size, n_embd, bias = False)
        self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, x, attention_mask):
        out = torch.cat([h(x, attention_mask) for h in self.heads], dim = -1)
        return self.dropout(self.proj(out))


class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * n_embd, ReLU, project back, then dropout."""

    def __init__(self):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout_ratio),
        ]
        # A single nn.Sequential named `net` keeps parameter names
        # (net.0.*, net.2.*) unchanged.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class Block(nn.Module):
    """Pre-LayerNorm transformer layer: attention, then feed-forward, each with a residual add."""

    def __init__(self):
        super().__init__()
        self.sa_head = MultiHeadAttention()
        self.ffwd = FeedForward()
        self.ln_1 = nn.LayerNorm(n_embd)
        self.ln_2 = nn.LayerNorm(n_embd)

    def forward(self, x, attention_mask):
        attended = self.sa_head(self.ln_1(x), attention_mask)
        x = x + attended
        mixed = self.ffwd(self.ln_2(x))
        return x + mixed


class Transformer(nn.Module):
    """Decoder-only language model over ``vocab_size`` byte/BPE tokens."""

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList([Block() for _ in range(n_blocks)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias = False)

    def forward(self, x, attention_mask = None, targets = None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            x: (B, T) token ids with T <= block_size.
            attention_mask: optional boolean mask passed to every block;
                True entries are masked out as keys.
            targets: optional (B, T) target ids. Positions equal to
                ``pad_token`` are ignored by the loss.

        Returns:
            (logits, loss). Logits are (B, T, vocab_size) and loss is None
            when ``targets`` is None. Otherwise logits are flattened to
            (B*T, vocab_size) and loss is the cross-entropy.
        """
        B, T = x.shape
        tok_emb = self.token_embedding_table(x)
        # Create positions on the input's device rather than the global
        # `device`, so the module works wherever it is placed.
        pos_emb = self.position_embedding_table(torch.arange(T, device = x.device))
        x = tok_emb + pos_emb
        for block in self.blocks:
            x = block(x, attention_mask)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets, ignore_index = pad_token)
        return logits, loss

    @torch.no_grad()
    def generate(self, ix, max_new_tokens = None):
        """Sample tokens after the prompt ``ix`` until ``end_token`` is drawn.

        Args:
            ix: (1, T0) prompt ids. The end-token check compares one sampled
                id, so batch size 1 is assumed.
            max_new_tokens: optional cap on the number of sampled tokens.
                None (the default) keeps sampling until ``end_token`` appears.

        Returns:
            The prompt followed by the sampled ids, without the end token.

        Sampling runs without autograd and in eval mode (dropout off). The
        model's previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            produced = 0
            while max_new_tokens is None or produced < max_new_tokens:
                # Only the last block_size tokens fit the position table.
                x = ix[:, -block_size:]
                logits, _ = self(x)
                probs = F.softmax(logits[:, -1, :], dim = -1)
                next_ix = torch.multinomial(probs, num_samples = 1)
                if next_ix.item() == end_token:
                    return ix
                ix = torch.cat((ix, next_ix), dim = -1)
                produced += 1
            return ix
        finally:
            self.train(was_training)


print("Loading dataset...\n")
names = load_dataset()

print("Encoding data...\n")
names_enc = encode()
# Ids 0-255 are raw bytes; every id above that comes from one BPE merge.
num_merges = vocab_size - 256
# BPE rewrites its input in place; keep the plain byte encoding untouched.
names_enc_copy = copy.deepcopy(names_enc)

print("Applying BPE..Creating merges...")
names_enc_copy, merges = create_merges(names_enc_copy)

print("Creating the vocab...")
vocab = create_vocab(merges)

# Must run before shuffling: the start token depends on dataset position.
print("Prepending and appending start and end tokens...\n")
names_enc_copy = prepend_start_token_and_append_end_token(names_enc_copy)

print(f"First five samples before shuffling : \n{names_enc_copy[:5]}\n")
print("Shuffling the inputs...")
random.shuffle(names_enc_copy)
print(f"First five samples after shuffling : \n{names_enc_copy[:5]}\n")

print("Padding the inputs...\n")
# One extra token so each sequence yields block_size inputs plus a shifted target.
data = pad_sequences(names_enc_copy, pad_token, max_length = block_size + 1)

print("Splitting the data into training and validation set...\n")
xtr, ytr, xval, yval = split(data)
print(f"Train data size : {len(xtr)}\n")
print(f"Val data size : {len(xval)}\n")
print("-"*80)


model = Transformer()
m = model.to(device)
# Count once and reuse for both the raw and the millions figure.
n_params = sum(p.nelement() for p in m.parameters())
print(f"Total parameters : {n_params} parameters\t{n_params / 1e6:.2f}M parameters\n")
print("-"*80)


# Model training
optimizer = torch.optim.AdamW(m.parameters(), lr)
for step in range(max_iters):
    if step % eval_interval == 0:
        eval_losses = estimate_loss()
        print(f"Step{step} :\ttrain_loss : {eval_losses['train']}\tval_loss : {eval_losses['val']}")
    xb, yb = get_batch("train")
    xb, yb = xb.to(device), yb.to(device)
    # Derived from xb, so it already lives on the same device.
    attention_mask = xb == pad_token
    logits, loss = m(xb, attention_mask, yb)
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()
print("-"*80)


# Inference
# Eval mode disables dropout; no_grad avoids building autograd graphs.
m.eval()
with torch.no_grad():
    for _ in range(12):
        start_token = random.choice([126, 33])  # '~' = boy, '!' = girl
        prompt = torch.full((1, 1), start_token, device = device)
        out = m.generate(prompt).tolist()[0]
        print(decode(out))


'''
names[18268]              -> Index at which the boys' names start in the dataset.
'Aaby'

list('~'.encode())        -> Indicates boy names
[126]

list('!'.encode())        -> Indicates girl names
[33]

list('.'.encode())        -> end token
[46]

list('\0'.encode())       -> pad token
[0]
'''

"\nnames[18268].             -> In the dataset the names of boys starts from this index.\n'Aaby'\n\nlist('~'.encode())        -> Indicates boy names\n[126]\n\nlist('!'.encode())        -> Indicates girl name\n[33]\n\nlist('.'.encode())        -> end token\n[46]\n\nlist('\x00'.encode())       -> pad token\n[0]\n"