In [112]:
with open('../data/wizard_of_oz.txt', 'r') as f:
    text = f.read()

print(text[:100])

chars = sorted(list(set(text)))
vocab_size = len(chars)
print("Chracters: ",chars)
print("Vocab Size", vocab_size)

  DOROTHY AND THE WIZARD IN OZ

  BY

  L. FRANK BAUM

  AUTHOR OF THE WIZARD OF OZ, THE LAND OF OZ
Chracters:  ['\n', ' ', '!', '"', '&', "'", '(', ')', '*', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '\ufeff']
Vocab Size 81


In [113]:
def encode(text: str):
    ans = []
    for t in text:
        ans.append(chars.index(t))
    return ans

encoded_hello = encode("hii there")
encoded_hello

[61, 62, 62, 1, 73, 61, 58, 71, 58]

In [114]:
def decode(indices: list):
    ans = ""
    for i in indices:
        ans += chars[i]
    return ans

decoded_hello = decode(encoded_hello)
decoded_hello

'hii there'

In [115]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:100])

torch.Size([232309]) torch.int64
tensor([80,  1,  1, 28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1, 44, 32, 29,
         1, 47, 33, 50, 25, 42, 28,  1, 33, 38,  1, 39, 50,  0,  0,  1,  1, 26,
        49,  0,  0,  1,  1, 36, 11,  1, 30, 42, 25, 38, 35,  1, 26, 25, 45, 37,
         0,  0,  1,  1, 25, 45, 44, 32, 39, 42,  1, 39, 30,  1, 44, 32, 29,  1,
        47, 33, 50, 25, 42, 28,  1, 39, 30,  1, 39, 50,  9,  1, 44, 32, 29,  1,
        36, 25, 38, 28,  1, 39, 30,  1, 39, 50])


In [116]:
n = int(len(data) * 0.8)
train_data = data[:n]
test_data = data[n:]

In [117]:
torch.manual_seed(42)

batch_size = 4  # number of sequences in a batch
block_size = 8  # length of a particular sequence

def get_batch(split):
    """ 
    Generate a small batch of data of inputs and targets
    """

    data = train_data if split == "train_data" else test_data

    # generates batch_size number of indexes from until the length of data - block size
    ix = torch.randint(len(data)-block_size, (batch_size, ))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])

    return x, y

x, y = get_batch('train')
print("X: ")
print(x.shape)
print(x)

print('y: ')
print(y.shape)
print(y)

for b in range(batch_size):
    for l in range(block_size):
        context = x[b, :l+1]
        target = y[b, l]
        print(f"Context: {context.tolist()} | Target: {target}")


X: 
torch.Size([4, 8])
tensor([[58,  1, 57, 68,  1, 67, 68, 76],
        [72,  1, 55, 58, 62, 67, 60,  0],
        [65, 58, 72, 11,  3,  0,  0,  3],
        [ 0,  0,  3, 49, 68, 74, 71,  1]])
y: 
torch.Size([4, 8])
tensor([[ 1, 57, 68,  1, 67, 68, 76, 24],
        [ 1, 55, 58, 62, 67, 60,  0, 65],
        [58, 72, 11,  3,  0,  0,  3, 47],
        [ 0,  3, 49, 68, 74, 71,  1, 42]])
Context: [58] | Target: 1
Context: [58, 1] | Target: 57
Context: [58, 1, 57] | Target: 68
Context: [58, 1, 57, 68] | Target: 1
Context: [58, 1, 57, 68, 1] | Target: 67
Context: [58, 1, 57, 68, 1, 67] | Target: 68
Context: [58, 1, 57, 68, 1, 67, 68] | Target: 76
Context: [58, 1, 57, 68, 1, 67, 68, 76] | Target: 24
Context: [72] | Target: 1
Context: [72, 1] | Target: 55
Context: [72, 1, 55] | Target: 58
Context: [72, 1, 55, 58] | Target: 62
Context: [72, 1, 55, 58, 62] | Target: 67
Context: [72, 1, 55, 58, 62, 67] | Target: 60
Context: [72, 1, 55, 58, 62, 67, 60] | Target: 0
Context: [72, 1, 55, 58, 62, 67, 60,

In [118]:
torch.manual_seed(42)

# We use tril to replicate the above pattern, where each row of triangle is a sequence in a batch.
a = torch.tril(torch.ones(3,3))
b = torch.randint(0, 10, (3, 2)).float()

c = a @ b

print(f"A: \n {a}\n----")
print(f"B: \n {b}\n----")
print(f"C: \n {c}\n----")

A: 
 tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
----
B: 
 tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
----
C: 
 tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])
----


tril enables us to create a lower triangular matrix from a given matrix. 

This is useful summing up the context in given sequence.

In the above example B is a sequence. When multiplied with A, we get a new matrix C, which has the summation of the of the context of each word in the until that particular index.

Instead of doing we can also perform average of the given context.

In [119]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3))
b = torch.randint(0, 10, (3, 2)).float()

a = a / a.sum(1, keepdim=True)
c = a @ b

print(f"A: \n {a}\n----")
print(f"B: \n {b}\n----")
print(f"C: \n {c}\n----")

A: 
 tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
----
B: 
 tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
----
C: 
 tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])
----


## Attention

In [120]:
from pandas import value_counts
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)

B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # Shape: (B,T,16)
q = query(x) # Shape: (B,T,16)
w = q @ k.transpose(-2, -1) # Shape: (B,T,16) @ (B,16,T) --> (B,T,T)(4, 8, 8). Transposing necessary for matmul.

print(w[0])
# masking the weights is necessary to ensure that future time step isnt being taken into consideration.
# Hence the upper triangle is being masked with -inf
tril = torch.tril(torch.ones(T, T))
w = w.masked_fill(tril == 0, float('-inf'))
w = F.softmax(w, dim=-1)
print("Post masking: ")
print(w[0])
print("-"*30)

tensor([[-0.3332, -1.1723, -1.0216, -0.0545, -1.0950,  0.2735,  0.1340, -0.8490],
        [-0.6597,  0.7869, -1.2725,  1.6851,  0.1159,  0.5450,  0.2356, -0.1962],
        [ 0.3630, -1.5219,  0.7821, -1.7215, -0.3494,  0.2884, -0.1021, -1.4271],
        [-0.1001,  0.8649, -0.0335,  1.0221, -0.1350, -0.3078,  0.1440, -0.3019],
        [ 0.0136, -1.6202, -1.9888, -0.3327, -1.2506, -0.8928, -2.2674,  3.0561],
        [-0.5833,  1.2025, -0.3281,  0.9147,  0.9809, -0.4859,  1.7589,  0.1650],
        [ 1.1351, -1.9940,  1.5545, -1.8037, -0.5062, -2.6109, -1.0739,  1.6430],
        [-1.2784, -0.4554, -1.4118,  0.6392, -0.5780,  1.9291,  1.6689,  0.1103]],
       grad_fn=<SelectBackward0>)
Post masking: 
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1905, 0.8095, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3742, 0.0568, 0.5690, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1288, 0.3380, 0.1376, 0.3956, 0.0000, 0.0000, 0.0000, 0.0000]

- Weights (w): Tell us how much each word (or time step) should focus on every other word when generating its new representation.

- Value (v): Represents the actual content or features of each word.

In [121]:
v = value(x)
out = w @ v

out.shape
w[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1905, 0.8095, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3742, 0.0568, 0.5690, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1288, 0.3380, 0.1376, 0.3956, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4311, 0.0841, 0.0582, 0.3049, 0.1217, 0.0000, 0.0000, 0.0000],
        [0.0537, 0.3205, 0.0694, 0.2404, 0.2568, 0.0592, 0.0000, 0.0000],
        [0.3396, 0.0149, 0.5165, 0.0180, 0.0658, 0.0080, 0.0373, 0.0000],
        [0.0165, 0.0375, 0.0144, 0.1120, 0.0332, 0.4069, 0.3136, 0.0660]],
       grad_fn=<SelectBackward0>)

Let's break down the tensor:


- The tensor represents represents one sequence within a batch of sequences.  
- Each row corresponds to a time step (or word, in the context of our previous example).
- Each column in a given row represents the attention score of the word at that time step with every other word in the sequence, including itself.
- The values are between 0 and 1 (due to the softmax operation) and represent the attention probabilities.

For example:

- The first row `[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]` indicates that the first word is only attending to itself and not to any other word in the sequence.
  
- The second row `[0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]` indicates that the second word is mostly attending to itself (with a score of 0.8426) but also has some attention (0.1574) on the first word. It doesn't attend to any word after it, which is consistent with the masking we applied to prevent attending to future words.



The last row of the tensor is:
`[0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]`

This row represents the attention scores of the last word (or time step) in the sequence with respect to every other word in the sequence, including itself. Here's what we can infer from these scores:

1. **Self-Attention**: The last word has an attention score of `0.2391` with itself. This means that when generating the output representation for this word, about 23.91% of the information will come from the word itself.

2. **Previous Words**: The last word is attending to all the previous words in the sequence, but the attention is distributed unevenly:
   - The word gives the highest attention to the 7th word with a score of `0.2423`.
   - The next highest attention is given to the 4th word with a score of `0.2297`.
   - The rest of the words receive relatively lower attention, with scores ranging from `0.0210` to `0.0843`.

3. **No Future Words**: As expected (due to the masking), the last word doesn't attend to any "future" words since it's the last word in the sequence.

4. **Interpretation**: The last word in the sequence seems to be most influenced by the 7th and 4th words when determining its output representation. This could mean that, in the context of the data, the 7th and 4th words have the most relevant information or context for the last word.

- As we go down the rows, we can see that each word can attend to itself and any previous word in the sequence, but not to words that come after it.



In [122]:


class Head(nn.Module):

    def __init__(self, head_size, embedding_dim, block_size, dropout=0.05):
        super().__init__()
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.value = nn.Linear(embedding_dim, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape

        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # we use scaling factor C**0.5 to enure that the dot product doesn't become too large
        w = (q @ k.transpose(-2, -1)) / (C**0.5)
        w = w.masked_file(self.tril == 0, float('-inf'))
        w = F.softmax(w, dim=-1)
        w = self.dropout(w)

        out = w @ v
        return out

# head_size = 16
# embedding_dim = 32
# block_size = 8
# head = Head(head_size=head_size, embedding_dim=embedding_dim, block_size=block_size)


In [None]:

class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, embedding_dim, block_size, dropout=0.05):
        super().__init__()

        self.heads = nn.ModuleList([
            Head(embedding_dim//num_heads, embedding_dim, block_size)
            for _ in range(num_heads)
        ])

        self.proj = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = [head(x) for head in self.heads]
        out = torch.concat(out, dim=-1)
        out = self.proj(x)
        return out
    