In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from src.data import CreateDataset
from src.models import Model
import re
import argparse
from tqdm import tqdm

In [2]:
# Subtitle file whose dialogue is the training text for the character-level model below.
file_path = "Star.Wars.Episode.IV.srt"

In [3]:
def extract_dialogue(file_path, encoding='utf-8'):
    """
    Extract the spoken text from an SRT subtitle file.

    Each cue is expected to look like::

        <index>
        HH:MM:SS,mmm --> HH:MM:SS,mmm
        <one or more lines of text>

    with cues separated by a blank line. The index and timing lines are
    dropped, HTML-style formatting tags (e.g. ``<i>``) are stripped, and
    multi-line cues keep their internal newlines. Timings using ``.`` instead
    of ``,`` before the milliseconds are accepted as well.

    Args:
        file_path: Path to the ``.srt`` file.
        encoding: Text encoding used to read the file (default ``'utf-8'``,
            rather than the platform-dependent locale default).

    Returns:
        A single string (not a list) containing each cue's text, with cues
        separated by a blank line (``'\\n\\n'``).
    """
    # Raw strings: '\d' in a normal string literal is an invalid escape sequence.
    tag_re = re.compile(r'<.*?>')
    timing_re = re.compile(
        r'\d{2}:\d{2}:\d{2}[,.]\d{3} --> \d{2}:\d{2}:\d{2}[,.]\d{3}'
    )

    with open(file_path, 'r', encoding=encoding) as f:
        srt = f.read()

    dialogue = []
    for block in srt.strip().split('\n\n'):
        block = tag_re.sub('', block)
        # Blanking the timing line leaves "<index>\n\n<text>", so splitting on
        # the blank line and dropping the first piece discards the index.
        block = timing_re.sub('', block)
        dialogue.append(''.join(block.split('\n\n')[1:]))

    return '\n\n'.join(dialogue)

In [4]:
# Character-level vocabulary: every distinct character in the transcript, sorted.
dialogues = extract_dialogue(file_path)
chars = sorted(set(dialogues))
vocab_size = len(chars)

print(''.join(chars))
print(vocab_size)


 !"',-.0123456789?ABCDEFGHIJKLMNOPQRSTUVWXYabcdefghijklmnopqrstuvwxyz
70


In [8]:
torch.manual_seed(555)

# Hyper-parameters. Build a Namespace *instance*: assigning attributes to the
# bare class `argparse.Namespace` would mutate the class itself, leaking these
# values into every other Namespace in the process.
args = argparse.Namespace(
    batch_size=32,
    vocab_size=vocab_size,
    seq_length=32,
    n_embd=64,
    head_size=16,
    n_head=4,
    epochs=10,
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'


# data (size=0.8 is presumably the train fraction -- confirm in src.data)
dataset = CreateDataset(dialogues, seq_length=args.seq_length, size=0.8)
train_ds = dataset.train_dataset()
val_ds = dataset.test_dataset()
train_dl = DataLoader(dataset=train_ds, batch_size=args.batch_size, shuffle=True)
val_dl = DataLoader(dataset=val_ds, batch_size=args.batch_size, shuffle=False)

# model
model = Model(args).to(device)

# optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

epochs = args.epochs


def lm_loss(model, xb, yb):
    """Next-token cross-entropy for one batch.

    Assumes model(xb) returns logits shaped (batch, seq, vocab) and yb holds
    target token ids shaped (batch, seq) -- TODO confirm against src.models.
    """
    logits = model(xb)
    b, s, _ = logits.shape
    # reshape (not view) also works if the logits are non-contiguous
    return F.cross_entropy(logits.reshape(b * s, -1), yb.reshape(-1))


# training loop
for epoch in range(epochs):

    ## train
    model.train()  # once per epoch is enough
    pbar = tqdm(train_dl, desc=f'Epoch {epoch + 1}/{epochs}')

    # Running mean of batch losses. It is accumulated from `loss.detach()`
    # so it does not keep every batch's autograd graph reachable all epoch.
    mloss = torch.zeros(1, device=device)
    for ib, (xb, yb) in enumerate(pbar):
        xb, yb = xb.to(device), yb.to(device)
        loss = lm_loss(model, xb, yb)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        mloss = (ib * mloss + loss.detach()) / (ib + 1)
        pbar.set_postfix(train_loss=mloss.item())

    ## validation
    model.eval()
    pbar = tqdm(val_dl)

    # running mean of validation batch losses
    mloss = torch.zeros(1, device=device)
    with torch.no_grad():
        for ib, (xb, yb) in enumerate(pbar):
            xb, yb = xb.to(device), yb.to(device)
            loss = lm_loss(model, xb, yb)
            mloss = (ib * mloss + loss) / (ib + 1)
            pbar.set_postfix(val_loss=mloss.item())


Epoch 0/10: 100%|██████████| 1388/1388 [00:07<00:00, 191.46it/s, train_loss=2.23]
100%|██████████| 347/347 [00:00<00:00, 418.48it/s, val_loss=2.02]
Epoch 1/10: 100%|██████████| 1388/1388 [00:07<00:00, 193.85it/s, train_loss=1.76]
100%|██████████| 347/347 [00:00<00:00, 418.50it/s, val_loss=1.91]
Epoch 2/10: 100%|██████████| 1388/1388 [00:07<00:00, 194.55it/s, train_loss=1.6] 
100%|██████████| 347/347 [00:00<00:00, 427.53it/s, val_loss=1.87]
Epoch 3/10: 100%|██████████| 1388/1388 [00:07<00:00, 193.81it/s, train_loss=1.51]
100%|██████████| 347/347 [00:00<00:00, 414.86it/s, val_loss=1.88]
Epoch 4/10: 100%|██████████| 1388/1388 [00:07<00:00, 190.89it/s, train_loss=1.45]
100%|██████████| 347/347 [00:00<00:00, 415.95it/s, val_loss=1.9] 
Epoch 5/10: 100%|██████████| 1388/1388 [00:07<00:00, 192.48it/s, train_loss=1.41]
100%|██████████| 347/347 [00:00<00:00, 453.83it/s, val_loss=1.91]
Epoch 6/10:  14%|█▍        | 201/1388 [00:01<00:07, 161.16it/s, train_loss=1.39]


KeyboardInterrupt: 

In [10]:
# Sample text from the trained model. Switch to eval mode first: if training
# was interrupted (as in the run above) the model is still in train mode,
# which matters if it contains train/eval-dependent layers such as dropout.
model.eval()

# 16 random single-token prompts, shape (16, 1)
idx = torch.randint(high=args.vocab_size, size=(16, 1), dtype=torch.int64, device=device)
with torch.no_grad():
    idxs = model.generate(idx, 1000)

# Only the first sample is shown, so only that one is decoded.
print(dataset.decode(idxs[0].tolist()))

"Wars anvoyand.

Ninow. You'll like me.

Get in talk in.

Tell? Your funning battled.

Are my betiling.

I'll man does of functions

and grew nin.

There's nould ank are go you'd think were your smore her andpeople.

I gonnowning was to kay he's get.
. No, come pilons.

Look to you kill tran
where you!

R2. Yough onf, the chort train'?

With this Oh.

Ai, and a grouth tyour uncomise.

How nitsation.

Where any, are the short so star back
out it.
Come off.

Grrr. Princ blast smanning.

Hange this?

Pam pups weapod and, Our Dunctic
with yet.

I takeind

to of be the dam and lations.

The ould gone ship. Settinifical.

We be on to hidner.

Ninow.

Don't get man a firide a pies, figurap in the please sust stated frouble rightis it.

We hath somes. Over relanced.

Okay, Master, R2 uniten betranspor deturech ording back canneverser.

Yegenerself the samend is a migg.


The Ann their not very pizale me.

It's in your cometimes were quite Alderaan
before timalfunctions.

Grrr in to tir.

Where

In [163]:
# Single-head causal self-attention on a toy input, written out step by step.
# The input is created here so the cell runs on a fresh kernel instead of
# relying on `x`, `T` and `C` left over from other cells.
torch.manual_seed(123)
B, T, C = 4, 8, 32  # batch, time (sequence length), channels (embedding size)
x = torch.randn(B, T, C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)    # (B, T, H)
q = query(x)  # (B, T, H)
# scaled dot-product scores: (B, T, H) @ (B, H, T) --> (B, T, T)
wei = q @ k.transpose(-2, -1) * head_size**-0.5
# causal mask: position t may only attend to positions <= t
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)   # (B, T, H)
out = wei @ v  # (B, T, T) @ (B, T, H) --> (B, T, H)

In [271]:
# Create a Namespace *instance*; `argparse.Namespace` without parentheses is
# the class itself, and setting attributes on it would mutate the class globally.
args = argparse.Namespace()

1


In [277]:
# Toy configuration for exercising the attention modules below. A Namespace
# instance (not the bare class) keeps these settings local to `args` instead
# of mutating `argparse.Namespace` for everyone.
torch.manual_seed(123)
args = argparse.Namespace(
    batch_size=4,
    seq_length=8,
    n_embd=32,
    head_size=16,
    n_head=2,
)
x = torch.randn(args.batch_size, args.seq_length, args.n_embd)  # (batch, seq_length, n_embd)

class AttentionHead(nn.Module):
    """Single head of causal (masked) scaled dot-product self-attention.

    ``args`` must provide ``n_embd`` (input channels), ``head_size`` (output
    channels, also used for the 1/sqrt scaling) and ``seq_length`` (the
    largest sequence length the causal mask covers).
    """

    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        in_dim, out_dim = args.n_embd, args.head_size
        # creation order key -> query -> value keeps seeded initialisation unchanged
        self.key = nn.Linear(in_dim, out_dim, bias=False)
        self.query = nn.Linear(in_dim, out_dim, bias=False)
        self.value = nn.Linear(in_dim, out_dim, bias=False)
        # lower-triangular ones, registered as a buffer rather than a parameter
        max_len = args.seq_length
        self.register_buffer('tril', torch.tril(torch.ones((max_len, max_len))))

    def forward(self, x):
        """Map x of shape (batch, seq_len, n_embd) to (batch, seq_len, head_size).

        seq_len may be shorter than ``args.seq_length``; the mask is cropped.
        """
        _, seq_len, _ = x.shape
        keys = self.key(x)       # (b, s, head_size)
        queries = self.query(x)  # (b, s, head_size)
        values = self.value(x)   # (b, s, head_size)

        # (b, s, head_size) @ (b, head_size, s) -> (b, s, s), scaled by 1/sqrt(head_size)
        scores = queries @ keys.transpose(-2, -1) * self.args.head_size ** -0.5
        # hide future positions before normalising
        future = self.tril[:seq_len, :seq_len] == 0
        weights = F.softmax(scores.masked_fill(future, float('-inf')), dim=-1)

        return weights @ values  # (b, s, s) @ (b, s, head_size) -> (b, s, head_size)

class MultiHeadAttention(nn.Module):
    """Several AttentionHeads run in parallel, concatenated, then projected.

    Args:
        n_head: number of heads.
        head_size: output channels of each head.
        n_embedding: input and output embedding size.
        block_size: largest sequence length each head's causal mask covers.
    """

    def __init__(self, n_head, head_size, n_embedding, block_size) -> None:
        super().__init__()
        # AttentionHead's constructor takes one args-style object (it reads
        # n_embd, head_size and seq_length), so pack the settings into one
        # instead of passing them positionally.
        head_args = argparse.Namespace(n_embd=n_embedding, head_size=head_size, seq_length=block_size)
        self.heads = nn.ModuleList([AttentionHead(head_args) for _ in range(n_head)])
        # The concatenation has n_head * head_size channels; project that back
        # to n_embedding (identical to before when the two are equal).
        self.proj = nn.Linear(n_head * head_size, n_embedding, bias=False)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (b, s, n_head * head_size)
        out = self.proj(out)  # (b, s, n_embedding)
        return out

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(n, 4n) -> ReLU -> Linear(4n, n), applied on the last dim."""

    def __init__(self, n_embedding) -> None:
        super().__init__()
        hidden = 4 * n_embedding  # 4x expansion of the embedding width
        self.l1 = nn.Linear(n_embedding, hidden)
        self.l2 = nn.Linear(hidden, n_embedding)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.l2(self.relu(self.l1(x)))

class DecoderBlock(nn.Module):
    """Multi-head self-attention followed by a feed-forward layer.

    NOTE(review): there are no residual connections or LayerNorms here; each
    sub-layer's output simply replaces its input.
    NOTE(review): MultiHeadAttention is looked up by global name when the block
    is constructed, and a later cell redefines that class with a different
    constructor signature (num_heads, head_size).
    """

    def __init__(self, n_head, n_embedding, block_size) -> None:
        super().__init__()
        # split the embedding evenly across heads (floor division)
        per_head = n_embedding // n_head
        self.sa = MultiHeadAttention(
            n_head=n_head,
            head_size=per_head,
            n_embedding=n_embedding,
            block_size=block_size,
        )
        self.ffwd = FeedForward(n_embedding)

    def forward(self, x):
        return self.ffwd(self.sa(x))

# Smoke test: one attention head on the toy batch; expected output shape is
# (batch, seq_length, head_size).
head = AttentionHead(args)
head(x).shape

# NOTE(review): the commented-out experiments below use n_head, n_embedding and
# block_size, which this cell never defines (it configures via args.* instead).
# sum([p.numel() for p in head.parameters()])
# multihead = MultiHeadAttention(n_head=n_head, head_size=head_size, n_embedding=n_embedding, block_size=block_size)
# ffwd = FeedForward(n_embedding)
# ffwd(x)[0]
# decoder_block = DecoderBlock(n_head=n_head, n_embedding=n_embedding, block_size=block_size)
# decoder_block(x)

torch.Size([4, 8, 16])

In [276]:
# Toy hyper-parameters (read as globals by the reference modules below) and a
# random input batch of shape (B, T, C) = (batch_size, block_size, n_embd).
torch.manual_seed(123)
batch_size, block_size, n_embd = 4, 8, 32
n_head, head_size = 2, 16
x = torch.randn(batch_size, block_size, n_embd)

class Head(nn.Module):
    """One head of causal self-attention.

    Reads the notebook-level globals ``n_embd`` (input channels) and
    ``block_size`` (largest sequence length the mask covers) in ``__init__``.
    """

    def __init__(self, head_size):
        super().__init__()
        # creation order key -> query -> value keeps seeded initialisation unchanged
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        lower = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', lower)

    def forward(self, x):
        """(B, T, n_embd) -> (B, T, head_size)."""
        _, T, _ = x.shape
        k, q = self.key(x), self.query(x)  # each (B, T, head_size)
        scale = k.size(-1) ** -0.5
        # attention scores: (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = q @ k.transpose(-2, -1) * scale
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # causal mask
        wei = F.softmax(wei, dim=-1)
        # weighted sum of values: (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return wei @ self.value(x)

class MultiHeadAttention(nn.Module):
    """Multiple Head modules in parallel, concatenated and projected to n_embd.

    Uses the notebook-level global ``n_embd`` as the output width.
    NOTE(review): this redefines the MultiHeadAttention class from the earlier
    cell with a different constructor signature (num_heads, head_size).
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenation has num_heads * head_size channels, which equals
        # n_embd only when head_size == n_embd // num_heads; size the
        # projection from the actual width so other splits also work.
        self.proj = nn.Linear(num_heads * head_size, n_embd, bias=False)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size)
        out = self.proj(out)  # (B, T, n_embd)
        return out

class FeedFoward(nn.Module):
    """Position-wise MLP: Linear(n_embd, 4*n_embd) -> ReLU -> Linear(4*n_embd, n_embd)."""

    def __init__(self, n_embd):
        super().__init__()
        width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, width),
            nn.ReLU(),
            nn.Linear(width, n_embd),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
# Smoke test: run one reference Head on the toy batch; output is (B, T, head_size).
head = Head(head_size=head_size)
# multihead = MultiHeadAttention(num_heads=n_head, head_size=head_size)
# ffwd = FeedFoward(n_embd)
# ffwd(x)[0]
head(x)

tensor([[[ 5.7861e-02, -3.6411e-01, -9.2836e-02,  1.0961e-01, -1.4164e+00,
           4.9022e-01, -9.6439e-02, -2.9878e-01,  6.7750e-01,  2.6803e-01,
           7.8892e-02,  1.3635e-01,  1.7553e-02,  5.4448e-01, -3.9418e-01,
           2.8714e-01],
         [-5.2276e-01, -4.9407e-01,  4.2566e-01,  1.4612e-01, -9.2816e-01,
           5.3669e-02,  5.4877e-01,  2.8532e-01,  5.8181e-01,  1.1234e-01,
          -3.8405e-01,  1.9452e-01, -1.2907e-01, -5.3947e-01,  2.5659e-01,
          -9.2728e-02],
         [-4.3416e-01, -3.5311e-01, -2.2315e-02,  4.5016e-01, -6.0007e-01,
           3.6870e-02,  2.2213e-01,  1.7629e-02,  2.7614e-01,  1.4560e-01,
          -6.7048e-01, -1.1931e-01,  9.5055e-02, -1.4623e-02, -4.2846e-02,
           2.8831e-01],
         [-4.1676e-01, -3.0557e-01, -7.8429e-02,  5.6775e-01, -4.8637e-01,
           3.3208e-02,  2.2359e-01, -3.6059e-03,  2.0053e-01,  2.7256e-02,
          -8.1671e-01, -1.9872e-01,  9.0726e-02, -2.6008e-02, -8.1995e-02,
           3.8695e-01],
    