In [21]:
import os
from pathlib import Path

import torch
from torch import nn
import torch.functional as F
import einops

In [24]:
# Corpus location. Set the TINY_SHAKESPEARE_PATH environment variable to point
# elsewhere instead of editing this cell; the default is the original local path.
DATA_PATH = Path(os.environ.get('TINY_SHAKESPEARE_PATH', '/Users/Darrell/Desktop/tiny-shakespeare.txt'))

with open(DATA_PATH, 'r', encoding='utf-8') as f:
    text = f.read()

In [25]:
# Character-level vocabulary: every distinct character in the corpus, in sorted order.
vocab = sorted(set(text))
vocab_size = len(vocab)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [45]:
# Character <-> integer-id lookup tables built from the sorted vocabulary.
stoi = {ch: idx for idx, ch in enumerate(vocab)}
itos = dict(enumerate(vocab))


def tokenize(s: str) -> list:
    """Encode string `s` as a list of character ids (KeyError for characters not in `vocab`)."""
    return [stoi[ch] for ch in s]


def detokenize(c) -> str:
    """Decode an iterable of character ids back into a string.

    Each id is passed through `int()` so ints, numpy integers and 0-d tensors
    (e.g. from iterating a 1-D id tensor) all work; tensors hash by identity,
    so looking them up in `itos` directly would raise KeyError.
    """
    return ''.join(itos[int(x)] for x in c)

In [22]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable gain and a fixed zero bias.

    `gamma` is a trainable parameter initialised to ones. `beta` is registered as a
    buffer rather than a parameter, so it is never updated by the optimizer but still
    follows `.to(...)` and is stored in the state dict.

    Args:
        dim: size of the last dimension of the inputs.
        eps: value added to the variance for numerical stability; the default
            matches `torch.nn.functional.layer_norm`.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))
        # Buffer, not Parameter: this is a bias-free LayerNorm.
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        # The functional op lives in torch.nn.functional; torch.functional has no layer_norm.
        return nn.functional.layer_norm(x, x.shape[-1:], self.gamma, self.beta, self.eps)

In [50]:
class Token_Embedding(nn.Module):
    """Look up a learned dense vector for each integer token id.

    Args:
        vocab_size: number of distinct token ids (valid ids are 0 .. vocab_size - 1).
        emb_dim: length of each embedding vector.
    """

    def __init__(self, vocab_size, emb_dim):
        super(Token_Embedding, self).__init__()
        table = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim)
        self.embedding = table

    def forward(self, x):
        """Return the embeddings of `x`: same shape as `x` plus a trailing `emb_dim` axis."""
        vectors = self.embedding(x)
        return vectors

In [None]:
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) using the "rotate-half" feature layout.

    The angle for position p and frequency slot i is ``p * base ** (-2i / head_dim)``.
    The angle table is duplicated along the feature axis (``cat((freqs, freqs))``),
    so feature j is rotated together with feature ``j + head_dim // 2``.

    Args:
        head_dim: per-head feature size; must be even.
        base: base of the geometric frequency progression.

    Raises:
        ValueError: if ``head_dim`` is odd.
    """

    def __init__(self, head_dim: int, base=10000):
        super().__init__()
        if head_dim % 2 != 0:
            raise ValueError(f"head_dim must be even for rotary embeddings, got {head_dim}")
        # Inverse frequencies 1 / base^(2i / head_dim) for i = 0 .. head_dim/2 - 1.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        # Non-persistent: fully determined by head_dim and base, so kept out of the state dict.
        self.register_buffer('inv_freq', inv_freq, persistent=False)
        self.head_dim = head_dim
        self.seq_len_cached = None
        self.batch_size_cached = None  # not used yet
        self.cos_cached: torch.Tensor | None = None
        self.sin_cached: torch.Tensor | None = None
        # Device/dtype the cached tables were built for; a change triggers a rebuild.
        self._cached_device = None
        self._cached_dtype = None

    def trig(self, seq_len: int, device=None, dtype=torch.bfloat16):
        """Return ``(cos, sin)`` angle tables, each of shape ``(1, seq_len, head_dim)``.

        Args:
            seq_len: number of positions (0 .. seq_len - 1).
            device: target device; defaults to the device of the ``inv_freq`` buffer.
            dtype: dtype of the returned tables. Angles are computed in float32 and
                cast at the end.

        The tables are cached and reused while ``seq_len``, ``device`` and ``dtype``
        match the previous call.
        """
        device = self.inv_freq.device if device is None else torch.device(device)
        if (
            self.cos_cached is None
            or seq_len != self.seq_len_cached
            or device != self._cached_device
            or dtype != self._cached_dtype
        ):
            # Compute in float32 even if the module was cast to half precision.
            positions = torch.arange(seq_len, device=self.inv_freq.device, dtype=torch.float32)
            freqs = torch.einsum('i,j -> ij', positions, self.inv_freq.float())
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos()[None, :, :].to(device=device, dtype=dtype)
            self.sin_cached = emb.sin()[None, :, :].to(device=device, dtype=dtype)
            self.seq_len_cached = seq_len
            self._cached_device = device
            self._cached_dtype = dtype
        return self.cos_cached, self.sin_cached

    @staticmethod
    def _rotate_half(x):
        """Split the last axis into halves (x1, x2) and return cat(-x2, x1)."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x):
        """Rotate the features of ``x`` according to their position.

        ``x`` is expected to hold positions on the second-to-last axis and
        ``head_dim`` features on the last axis, e.g. ``(B, T, head_dim)`` or
        ``(B, n_head, T, head_dim)``; the output has the same shape. Apply it
        separately to queries and keys.
        """
        cos, sin = self.trig(x.shape[-2], device=x.device, dtype=x.dtype)
        # Drop the leading singleton so the tables broadcast against any rank >= 2.
        cos, sin = cos[0], sin[0]
        return x * cos + self._rotate_half(x) * sin

            

In [None]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        n_head: number of attention heads; must divide ``n_embd``.
        n_embd: model width E (input and output feature size).
        attention_drop: dropout probability applied to the attention weights.
        residual_drop: dropout probability applied after the output projection.
        causal: if True, position t only attends to positions <= t. Defaults to True
            since this notebook works on a character-level corpus, presumably for
            autoregressive modelling; pass False for bidirectional attention.

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, n_head, n_embd, attention_drop = 0.1, residual_drop = 0.1, causal=True):
        super().__init__()
        if n_embd % n_head != 0:
            raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")
        self.n_head = n_head
        self.head_size = n_embd // n_head
        self.causal = causal
        self.attention_drop = attention_drop
        self.residual_drop = residual_drop
        self.query = nn.Linear(n_embd, n_embd)
        self.key = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        self.attn_dropout = nn.Dropout(attention_drop)
        self.resid_dropout = nn.Dropout(residual_drop)

    def _split_heads(self, t):
        """Reshape (B, T, E) -> (B, n_head, T, head_size)."""
        B, T, _ = t.shape
        return t.view(B, T, self.n_head, self.head_size).transpose(1, 2)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, E); returns a tensor of the same shape."""
        B, T, E = x.shape
        q = self._split_heads(self.query(x))
        k = self._split_heads(self.key(x))
        v = self._split_heads(self.value(x))
        # Scale by the per-head width, not the full model width E.
        scores = (q @ k.transpose(-2, -1)) * self.head_size ** -0.5  # (B, n_head, T, T)
        if self.causal:
            # True strictly above the diagonal, i.e. keys in the future of each query.
            future = torch.ones(T, T, device=x.device).triu(diagonal=1).bool()
            scores = scores.masked_fill(future, float('-inf'))
        weights = self.attn_dropout(scores.softmax(dim=-1))
        # Merge heads back: (B, n_head, T, head_size) -> (B, T, E).
        out = (weights @ v).transpose(1, 2).contiguous().view(B, T, E)
        return self.resid_dropout(self.proj(out))

        


In [77]:
#Things to implement:
#   1. Flash Attention
#   2. Multi-Query Attention
#   3. Rotary Positional Embedding
#   4. GLU activations

In [11]:
# Toy sizes for the rotary-embedding experiments that follow.
dim, seq_len = 8, 4

In [5]:
# RoPE inverse frequencies for head_dim 8: 1 / 10000^(2i/8), i = 0..3.
inv_freq = 1.0 / 10000 ** torch.arange(0, 8, 2).to(torch.float32).div(8)

In [6]:
# Display the inverse frequencies computed in the previous cell.
inv_freq

tensor([1.0000, 0.1000, 0.0100, 0.0010])

In [9]:
# Per-frequency-pair scale (2i + 0.4*dim) / (1.4*dim); this looks like the xPos
# decay-base formula — confirm before relying on it.
scale = torch.arange(0, dim, 2).add(0.4 * dim).div(1.4 * dim)

In [10]:
# Display the per-pair scale values from the previous cell.
scale

tensor([0.2857, 0.4643, 0.6429, 0.8214])

In [17]:
# Angle table: row p holds p * inv_freq, repeated twice along the last axis so it
# spans all `dim` features (same layout as torch.cat((freqs, freqs), dim=-1)).
t = torch.arange(seq_len).type_as(inv_freq)
freqs = torch.outer(t, inv_freq).repeat(1, 2)

In [20]:
# Display the angle table (rows = positions, columns = duplicated frequency slots).
freqs

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03, 1.0000e+00, 1.0000e-01,
         1.0000e-02, 1.0000e-03],
        [2.0000e+00, 2.0000e-01, 2.0000e-02, 2.0000e-03, 2.0000e+00, 2.0000e-01,
         2.0000e-02, 2.0000e-03],
        [3.0000e+00, 3.0000e-01, 3.0000e-02, 3.0000e-03, 3.0000e+00, 3.0000e-01,
         3.0000e-02, 3.0000e-03]])

In [None]:
# Top-level replay of the cos/sin cache computation from RotaryEmbedding.trig,
# using the scratch `seq_len`, `inv_freq` and `device` defined in earlier cells.
# Written without `self` (and with prefixed names so earlier scratch values such
# as `t` and `freqs` are not overwritten) so the cell runs on its own.
rope_t = torch.arange(seq_len, device=device).type_as(inv_freq)
rope_freqs = torch.einsum('i,j -> ij', rope_t, inv_freq)
rope_emb = torch.cat((rope_freqs, rope_freqs), dim=-1).float().to(device)

cos_cached = rope_emb.cos()[None, :, :]
sin_cached = rope_emb.sin()[None, :, :]

In [44]:
# Toy angle table with base 10 (instead of 10000), head_dim 8 and 5 positions,
# using the standard RoPE frequency formula 1 / base^(2i / head_dim).
inv_freq = 1.0 / (10 ** (torch.arange(0, 8, 2).float() / 8))
# Cast positions to the frequency dtype so the einsum sees matching float operands.
t = torch.arange(5).type_as(inv_freq)
freq = torch.einsum('i, j -> ij', t, inv_freq)
emb = torch.cat((freq, freq), dim=-1).float()

In [67]:
# All-ones 4-D tensor for shape experiments.
test = torch.ones(2, 3, 4, 5)

In [80]:
# 24 consecutive integers (4 * 3 * 2) to reshape in the following cells.
x = torch.arange(24)

In [81]:
# Reinterpret the flat integers as a (2, 3, 4) block without copying.
x = x.view(2, 3, 4)

In [82]:
# Display the current contents of `x`.
x

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

In [86]:
# Merge the leading axes so `x` becomes 6 rows of 4 (no copy).
x = x.view((6, 4))