In [1]:
from torch import torch, nn
from torch.functional import F
from collections import defaultdict

In [2]:
#hyperparameters
d_model = 32
n_heads = 4
n_layers = 2
batch_size = 20
block_size = 100
max_iters = 1000
eval_iters = 50
dropout = 0.1
lr = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
with open('input.txt', encoding='UTF-8') as f:
    text = f.read()

text[:100]

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'

In [4]:
len(text)

1115394

In [5]:
text = text[:20000]

In [6]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {c:i for i, c in enumerate(chars)}
itos = {i:c for i, c in enumerate(chars)}
encode = lambda s: [stoi[c] for c in list(s)]
decode = lambda l: "".join([itos[i] for i in l])
print(''.join(chars))
print(vocab_size)


 !',-.:;?ABCDEFGHIJLMNOPRSTUVWYabcdefghijklmnopqrstuvwxyz
58


In [7]:
data = torch.tensor(encode(text), dtype=torch.long)
n = int(len(data) * 0.9)
train_data = data[:n]
val_data = data[n:]

In [8]:
def get_batch(split):
    assert split in ['train', 'val'], "Invalid split"
    _data = train_data if split == 'train' else val_data
    idxs = torch.randint(len(_data)-(block_size+1), (batch_size,))
    x = torch.stack([_data[i:i+block_size] for i in idxs]).to(device)
    y = torch.stack([_data[i+1:i+block_size+1] for i in idxs]).to(device)
    return x, y


# get_batch('train')

In [9]:
class Embedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embeddding = nn.Embedding(max(block_size, vocab_size), d_model)
    
    def forward(self, x):
        tok_emb = self.token_embedding(x)
        pos_emb = tok_emb + self.pos_embeddding(torch.arange(x.size(-1), device=device))
        return pos_emb

class SelfAttension(nn.Module):
    def __init__(self):
        super().__init__()
        assert (d_model % n_heads) == 0, "n_heads must be a valid modulo of the d_model"
        head_size = d_model // n_heads
        self.qkv = nn.Linear(d_model, head_size*3)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.qkv(x)
        q, k, v = x.chunk(3, dim=-1)
        wei = (q @ k.transpose(-2, -1)) / (d_model ** -0.5)
        wei = torch.softmax(wei.masked_fill(wei.tril() == 0, float('-inf')), dim=-1)
        wei = self.dropout(wei)
        x = wei @ v
        return x

class MulitHeadAttension(nn.Module):
    def __init__(self):
        super().__init__()
        self.attensions = nn.ModuleList([SelfAttension() for _ in range(n_heads)])
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = torch.cat([att(x) for att in self.attensions], dim=-1)
        return self.dropout(x)

class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model * 5),
            nn.Linear(d_model * 5, d_model),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attension = MulitHeadAttension()
        self.ln1 = nn.LayerNorm(d_model)
        self.ffwd = FeedForward()
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.ln3 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + self.ln1(self.attension(x))
        x = x + self.ln2(self.ffwd(x))
        return self.dropout(self.ln3(x))

class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = Embedding()
        self.blocks = nn.Sequential(*[Block() for _ in range(n_layers)])
        self.cls = nn.Linear(d_model, vocab_size)
    
    def forward(self, x, y=None):
        embed = self.embeddings(x)
        x = self.blocks(embed)
        logits = self.cls(x)

        loss = None
        if y is not None:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            target = y.view(B*T)
            loss = F.cross_entropy(logits, target)
        
        return logits, loss
    
    
    def generate(self, x, max_tokens):
        for _ in range(max_tokens):
            x_cropped = x[:, -block_size:]
            logits, loss = model(x_cropped)
            logits = logits[:, -1, :]
            probs = logits.softmax(dim=-1)
            pred = torch.multinomial(probs, 1)
            x = torch.cat([x, pred], dim=-1)
        return x



model = Transformer().to(device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
print(f"Model Parameters: {sum([p.numel() for p in model.parameters()])}")


@torch.no_grad()
def estimate_loss():
    losses = defaultdict(list)
    model.eval()
    for split in ['train', 'val']:
        for _ in range(eval_iters):
            _, loss = model(*get_batch(split))
            losses[split].append(loss.item())
        losses[split] = torch.tensor(losses[split]).mean(dim=0)
    model.train()
    return losses


# #training
# for iter in range(max_iters):

#     if (iter % eval_iters) == 0:
#         losses = estimate_loss()
#         print(f"train_loss: {losses['train']:.4f}, val_loss: {losses['val']:.4f}")

#     optimizer.zero_grad()
#     logits, loss = model(*get_batch('train'))
#     loss.backward()
#     optimizer.step()


# x = torch.tensor([encode('the')], dtype=torch.long, device=device)
# print(decode(model.generate(x, 10).tolist()[0]))

Model Parameters: 34554


In [10]:
#training
for iter in range(max_iters):

    if (iter % eval_iters) == 0:
        losses = estimate_loss()
        print(f"train_loss: {losses['train']:.4f}, val_loss: {losses['val']:.4f}")

    optimizer.zero_grad()
    logits, loss = model(*get_batch('train'))
    loss.backward()
    optimizer.step()


x = torch.tensor([encode('the')], dtype=torch.long, device=device)
print(decode(model.generate(x, 10).tolist()[0]))

train_loss: 4.2556, val_loss: 4.2319
train_loss: 3.1634, val_loss: 3.2891
train_loss: 2.9870, val_loss: 3.1360
train_loss: 2.8811, val_loss: 3.0376
train_loss: 2.7926, val_loss: 2.9378
train_loss: 2.7413, val_loss: 2.8776
train_loss: 2.6932, val_loss: 2.8126
train_loss: 2.6582, val_loss: 2.7679
train_loss: 2.6349, val_loss: 2.7371
train_loss: 2.6087, val_loss: 2.7097
train_loss: 2.5851, val_loss: 2.6740
train_loss: 2.5684, val_loss: 2.6440
train_loss: 2.5522, val_loss: 2.6286
train_loss: 2.5466, val_loss: 2.6165
train_loss: 2.5394, val_loss: 2.6152
train_loss: 2.5205, val_loss: 2.5993
train_loss: 2.5178, val_loss: 2.5839
train_loss: 2.5108, val_loss: 2.5718
train_loss: 2.4994, val_loss: 2.5636
train_loss: 2.4946, val_loss: 2.5610
the ast ssorr


In [11]:
text[:100]

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'

In [12]:
x = torch.tensor([encode('Y')], dtype=torch.long, device=device)
print(decode(model.generate(x, 100).tolist()[0]))

YasonIUS:
Tushere- pcar w
Momonalyizs four o,be herosidour them avir frcssengan goud sthee be y tilin
