In [1]:
import torch
import torch.nn as nn

In [2]:
# Toy dimensions shared by the exploratory cells below.
n, d = 10, 64  # sequence length and embedding width
batch_size = 32
# integer square root of d, obtained through a float sqrt and truncated to a Python int
sqrt_d = int(torch.tensor(d).sqrt().item())

# Downsampling

In [5]:
# Random toy input in (d, n) layout; the last expression displays its shape.
data = torch.rand((d, n))
data.shape

torch.Size([64, 10])

In [6]:
# One weight per element of the (sqrt_d, sqrt_d, n) windowed view of the data.
weights = torch.rand((sqrt_d, sqrt_d, n))
weights.shape

torch.Size([8, 8, 10])

In [7]:
# Split the d axis into sqrt_d windows of sqrt_d entries each (the middle size is inferred).
data_reshaped = data.view(sqrt_d, -1, n)
data_reshaped.shape

torch.Size([8, 8, 10])

In [8]:
# Elementwise weighting of the windowed data.
product = torch.mul(data_reshaped, weights)
product.shape

torch.Size([8, 8, 10])

In [9]:
# Sum inside each window; the reduced axis is kept first so the intermediate shape can be shown.
window_sums = product.sum(dim=1, keepdim=True)
print(window_sums.shape)
# Drop the singleton axis left behind by keepdim=True.
updated_product = window_sums.squeeze(dim=1)
print(updated_product.shape)  # (d, n) has now been reduced to (sqrt(d), n)

torch.Size([8, 1, 10])
torch.Size([8, 10])


In [15]:
class Feebler(nn.Module):
    '''
    Learned down-projection of the channel axis.

    input: B, T, C        (C == sqrt_d ** 2, T == seq_len)
    output: B, T, sqrt(C)

    Args:
        sqrt_d: square root of the channel width C.
        seq_len: sequence length T baked into the weight tensor. Defaults to
            the notebook-level `n` when omitted.
    '''
    def __init__(self, sqrt_d, seq_len=None):
        super().__init__()
        if seq_len is None:
            seq_len = n  # fall back to the notebook-level sequence length
        self.weights = nn.Parameter(
            torch.randn(sqrt_d, sqrt_d, seq_len)
        )
        self.sqrt_d = sqrt_d

    def forward(self, data):
        # Data is of shape (b, n, d). The batch size is read from the input and
        # the sequence length from the weights, so the module does not depend
        # on notebook globals at call time.
        batch = data.shape[0]
        seq_len = self.weights.shape[-1]
        # NOTE(review): `view` reinterprets the (b, n, d) buffer as
        # (b, sqrt_d, sqrt_d, n) without transposing, so token and channel
        # entries are interleaved. Kept as-is to preserve numerical behaviour;
        # confirm this layout is intended.
        windows = data.view(batch, self.sqrt_d, self.sqrt_d, seq_len)
        # weight every element, then sum inside each window: d x n -> sqrt(d) x n
        reduced = (windows * self.weights).sum(dim=2)
        return reduced.view(batch, seq_len, self.sqrt_d)

# f = Feebler(sqrt_d)
# f(torch.randn(batch_size, n, d)).shape


# Upsampling

In [10]:
# Work on a copy of the downsampled tensor as a stand-in for the attention output.
attention_output = torch.clone(updated_product)
print(attention_output.shape)

torch.Size([8, 10])


In [11]:
# Upsampling weights: one per element of the (sqrt_d, sqrt_d, n) output.
up_weights = torch.randn((sqrt_d, sqrt_d, n))
print(up_weights.shape)

torch.Size([8, 8, 10])


In [12]:
# Flatten everything into a single row ...
flat_row = attention_output.view(1, -1)
print(flat_row.shape)
# ... stack that row sqrt_d times ...
tiled_rows = flat_row.repeat(sqrt_d, 1)
print(tiled_rows.shape)
# ... and fold the result into the same shape as the upsampling weights.
attention_output_reshaped = tiled_rows.view(up_weights.shape)
print(attention_output_reshaped.shape)

torch.Size([1, 80])
torch.Size([8, 80])
torch.Size([8, 8, 10])


In [13]:
# Elementwise product of the tiled data with the upsampling weights.
revived_output = torch.mul(up_weights, attention_output_reshaped)
revived_output.shape

torch.Size([8, 8, 10])

In [3]:
class Booster(nn.Module):
    '''
    Learned up-projection that undoes the channel reduction.

    input: B, T, sqrt(C)   (T == seq_len)
    output: B, T, C        (C == sqrt_d ** 2)

    Args:
        sqrt_d: square root of the output channel width C.
        seq_len: sequence length T baked into the weight tensor. Defaults to
            the notebook-level `n` when omitted.
    '''
    def __init__(self, sqrt_d, seq_len=None):
        super(Booster, self).__init__()
        if seq_len is None:
            seq_len = n  # fall back to the notebook-level sequence length
        self.weights = nn.Parameter(
            torch.randn(sqrt_d, sqrt_d, seq_len)
        )
        self.sqrt_d = sqrt_d

    def forward(self, attention_output):
        # attention_output is of shape (batch, n, sqrt_d). Batch size is read
        # from the input and sequence length from the weights, not from globals.
        batch = attention_output.shape[0]
        seq_len = self.weights.shape[-1]
        # NOTE(review): `view` reinterprets the buffer as (batch, sqrt_d, n)
        # without transposing; kept to preserve numerical behaviour.
        # The singleton axis broadcasts against the first weight axis, which is
        # equivalent to flattening, repeating sqrt_d times and reshaping.
        rows = attention_output.view(batch, 1, self.sqrt_d, seq_len)
        # multiply
        revived_output = self.weights * rows  # (batch, sqrt_d, sqrt_d, n)
        return revived_output.view(batch, seq_len, self.sqrt_d * self.sqrt_d)
    
# b = Booster(sqrt_d)
# b(torch.randn(32, 10, 8)).shape

# Quick Attention

In [14]:
# Random queries and keys in (d, n) layout (q is drawn first, then k).
q, k = torch.randn(d, n), torch.randn(d, n)
print(f'q: {q.shape}')
print(f'k: {k.shape}')

q: torch.Size([64, 10])
k: torch.Size([64, 10])


In [15]:
# Collapse the keys over the n axis into a single column.
collective_k = torch.sum(k, dim=1, keepdim=True)
collective_k.shape

torch.Size([64, 1])

In [16]:
# Materialise the broadcast: copy the summed-key column n times along axis 1.
collective_k_bc = collective_k.repeat((1, n))
collective_k_bc.shape

torch.Size([64, 10])

In [17]:
# Elementwise product of the queries with the summed keys.
qk = torch.mul(q, collective_k_bc)
qk.shape

torch.Size([64, 10])

In [18]:
# Normalise along axis 1 so that every row sums to one.
attention_weights = qk.softmax(dim=1)
attention_weights.shape

torch.Size([64, 10])

In [19]:
# Random values in the same (d, n) layout as q and k.
v = torch.randn((d, n))
v.shape

torch.Size([64, 10])

In [20]:
# Collapse the values over the n axis into a single column.
collective_v = torch.sum(v, dim=1, keepdim=True)
collective_v.shape

torch.Size([64, 1])

In [21]:
# Materialise the broadcast: copy the summed-value column n times along axis 1.
collective_v_bc = collective_v.repeat((1, n))
collective_v_bc.shape

torch.Size([64, 10])

In [22]:
# Scale the summed values by the attention weights.
output = torch.mul(collective_v_bc, attention_weights)
print(output.shape)

torch.Size([64, 10])


In [4]:
class QuickHead(nn.Module):
    """ one head of self-attention

    Keys and values are each summed over the time axis; the attention weights
    are a softmax over time of the elementwise product of the queries with the
    summed keys, and the output is the summed values scaled by those weights.

    Args:
        head_size: output features of the key/query/value projections.
        in_features: input features of the projections. Defaults to the
            notebook-level `sqrt_d` when omitted.
    """

    def __init__(self, head_size, in_features=None):
        super().__init__()
        if in_features is None:
            in_features = sqrt_d  # fall back to the notebook-level width
        self.key = nn.Linear(in_features, head_size, bias=False)
        self.query = nn.Linear(in_features, head_size, bias=False)
        self.value = nn.Linear(in_features, head_size, bias=False)
        # NOTE(review): this dropout is registered but never applied in
        # forward(); left untouched to preserve behaviour, confirm intent.
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # x is of shape (batch_size, n, sqrt_d)
        k = self.key(x)   # (B,T,head_size)
        q = self.query(x) # (B,T,head_size)
        v = self.value(x) # (B,T,head_size)

        # The (B,1,head_size) sums broadcast over the time axis, so no explicit
        # repeat (and no global sequence length) is needed.
        collective_k = k.sum(dim=1, keepdim=True)
        attention_weights = torch.softmax(q * collective_k, dim=1)  # softmax over time
        collective_v = v.sum(dim=1, keepdim=True)
        return collective_v * attention_weights

# h = QuickHead(4)
# h(torch.rand(batch_size, n, sqrt_d)).shape


# Putting it together


In [10]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# ---- hyperparameters ----
# batching
batch_size = 32  # number of independent sequences processed in parallel
block_size = 32  # maximum context length for predictions
# training schedule
max_iters = 5000
eval_interval = 100
eval_iters = 200
learning_rate = 1e-3
# model shape
n_embd = 64
sqrt_d = int(torch.tensor(n_embd).sqrt().item())  # square root of the embedding width
n_head = sqrt_d // 2
n_layer = 4
dropout = 0.0
# runtime
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# -------------------------

In [21]:
class Feebler(nn.Module):
    '''
    Learned down-projection of the channel axis.

    input: B, T, C        (C == sqrt_d ** 2, T == seq_len)
    output: B, T, sqrt(C)

    Args:
        sqrt_d: square root of the channel width C.
        seq_len: sequence length T baked into the weight tensor. Defaults to
            the notebook-level `block_size` when omitted.
    '''
    def __init__(self, sqrt_d, seq_len=None):
        super().__init__()
        if seq_len is None:
            seq_len = block_size  # fall back to the notebook-level context length
        self.weights = nn.Parameter(
            torch.randn(sqrt_d, sqrt_d, seq_len)
        )
        self.sqrt_d = sqrt_d

    def forward(self, data):
        # Data is of shape (b, n, d). The batch size is read from the input and
        # the sequence length from the weights, so the module does not depend
        # on notebook globals at call time.
        batch = data.shape[0]
        seq_len = self.weights.shape[-1]
        # NOTE(review): `view` reinterprets the (b, n, d) buffer as
        # (b, sqrt_d, sqrt_d, n) without transposing, so token and channel
        # entries are interleaved. Kept as-is to preserve numerical behaviour;
        # confirm this layout is intended.
        windows = data.view(batch, self.sqrt_d, self.sqrt_d, seq_len)
        # weight every element, then sum inside each window: d x n -> sqrt(d) x n
        reduced = (windows * self.weights).sum(dim=2)
        return reduced.view(batch, seq_len, self.sqrt_d)
    

class Booster(nn.Module):
    '''
    Learned up-projection that undoes the channel reduction.

    input: B, T, sqrt(C)   (T == seq_len)
    output: B, T, C        (C == sqrt_d ** 2)

    Args:
        sqrt_d: square root of the output channel width C.
        seq_len: sequence length T baked into the weight tensor. Defaults to
            the notebook-level `block_size` when omitted.
    '''
    def __init__(self, sqrt_d, seq_len=None):
        super(Booster, self).__init__()
        if seq_len is None:
            seq_len = block_size  # fall back to the notebook-level context length
        self.weights = nn.Parameter(
            torch.randn(sqrt_d, sqrt_d, seq_len)
        )
        self.sqrt_d = sqrt_d

    def forward(self, attention_output):
        # attention_output is of shape (batch, n, sqrt_d). Batch size is read
        # from the input and sequence length from the weights, not from globals.
        batch = attention_output.shape[0]
        seq_len = self.weights.shape[-1]
        # NOTE(review): `view` reinterprets the buffer as (batch, sqrt_d, n)
        # without transposing; kept to preserve numerical behaviour.
        # The singleton axis broadcasts against the first weight axis, which is
        # equivalent to flattening, repeating sqrt_d times and reshaping.
        rows = attention_output.view(batch, 1, self.sqrt_d, seq_len)
        # multiply
        revived_output = self.weights * rows  # (batch, sqrt_d, sqrt_d, n)
        return revived_output.view(batch, seq_len, self.sqrt_d * self.sqrt_d)

class QuickHead(nn.Module):
    """ one head of self-attention

    Keys and values are each summed over the time axis; the attention weights
    are a softmax over time of the elementwise product of the queries with the
    summed keys, and the output is the summed values scaled by those weights.

    Args:
        head_size: output features of the key/query/value projections.
        in_features: input features of the projections. Defaults to the
            notebook-level `sqrt_d` when omitted.
    """

    def __init__(self, head_size, in_features=None):
        super().__init__()
        if in_features is None:
            in_features = sqrt_d  # fall back to the notebook-level width
        self.key = nn.Linear(in_features, head_size, bias=False)
        self.query = nn.Linear(in_features, head_size, bias=False)
        self.value = nn.Linear(in_features, head_size, bias=False)
        # NOTE(review): this dropout is registered but never applied in
        # forward(); left untouched to preserve behaviour, confirm intent.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x is of shape (batch_size, n, sqrt_d)
        k = self.key(x)   # (B,T,head_size)
        q = self.query(x) # (B,T,head_size)
        v = self.value(x) # (B,T,head_size)

        # The (B,1,head_size) sums broadcast over the time axis, so no explicit
        # repeat (and no dependence on the global block_size) is needed.
        collective_k = k.sum(dim=1, keepdim=True)
        attention_weights = torch.softmax(q * collective_k, dim=1)  # softmax over time
        collective_v = v.sum(dim=1, keepdim=True)
        return collective_v * attention_weights
    
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel

    Runs `num_heads` QuickHead modules on the same input, concatenates their
    outputs along the last axis, then applies a linear projection back to the
    notebook-level `sqrt_d` width followed by dropout.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([QuickHead(head_size) for _ in range(num_heads)])
        # The concatenated heads carry num_heads * head_size features, which is
        # not sqrt_d when sqrt_d is not divisible by num_heads; size the
        # projection input from the heads themselves.
        self.proj = nn.Linear(num_heads * head_size, sqrt_d) # global variable sqrt_d
        self.dropout = nn.Dropout(dropout)  # global variable dropout

    def forward(self, x):
        # concatenate per-head outputs on the feature axis
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
    
class FeedFoward(nn.Module):
    """ position-wise MLP: expand to 4x the width, ReLU, project back, dropout """

    def __init__(self, sqrt_d):
        super().__init__()
        hidden = 4 * sqrt_d
        layers = [
            nn.Linear(sqrt_d, hidden),
            nn.ReLU(),
            nn.Linear(hidden, sqrt_d),
            nn.Dropout(dropout),  # rate comes from the notebook-level `dropout`
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
    
class Block(nn.Module):
    """ Transformer block: shrink the channels, communicate, compute, then expand again """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension (accepted but not read here; every width
        # below comes from the notebook-level sqrt_d)
        # n_head: the number of heads we'd like
        super().__init__()
        per_head = sqrt_d // n_head
        self.feebler = Feebler(sqrt_d)
        self.sa = MultiHeadAttention(n_head, per_head)
        self.ffwd = FeedFoward(sqrt_d)
        self.ln1 = nn.LayerNorm(sqrt_d)
        self.ln2 = nn.LayerNorm(sqrt_d)
        self.booster = Booster(sqrt_d)

    def forward(self, x):
        reduced = self.feebler(x)
        # pre-norm residual connections around attention and the MLP
        reduced = reduced + self.sa(self.ln1(reduced))
        reduced = reduced + self.ffwd(self.ln2(reduced))
        return self.booster(reduced)

In [22]:
# Smoke test: a single Block applied to a random (batch_size, block_size, n_embd) input.
b = Block(n_embd=n_embd, n_head=n_head)
b(torch.rand((batch_size, block_size, n_embd))).shape

torch.Size([32, 32, 64])

In [24]:
# size of the token vocabulary (embedding rows and output logits)
vocab_size = 10_000
# language model built from the Block defined above
class BigramLanguageModel(nn.Module):
    """Token and position embeddings, a stack of `Block`s, a final LayerNorm and a linear head.

    All sizes come from the notebook-level `vocab_size`, `n_embd`,
    `block_size`, `n_layer` and `n_head`.
    """

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Return `(logits, loss)` for a batch of token ids.

        Args:
            idx: (B, T) integer tensor of token ids.
            targets: optional (B, T) integer tensor of next-token ids.

        Returns:
            Without targets: logits of shape (B, T, vocab_size) and loss None.
            With targets: logits flattened to (B*T, vocab_size) and the
            cross-entropy loss against the flattened targets.
        """
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        # Build the position ids on the same device as the input so the lookup
        # works wherever the model and its input live, regardless of the
        # notebook-level `device` string.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Sample `max_new_tokens` further tokens and return the extended (B, T + max_new_tokens) ids.

        NOTE(review): the Block defined in this notebook reshapes with a fixed
        `block_size`, so contexts shorter than `block_size` look unsupported;
        confirm before calling with a short prompt.
        """
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

# Instantiate the model; sizes come from the notebook-level hyperparameters.
# Note that it is not moved to `device` here.
model = BigramLanguageModel()

In [27]:
# Smoke-test the forward pass with token ids drawn uniformly from the vocabulary.
# (`torch.rand(...).long()` truncates every value in [0, 1) to 0, which would
# feed the model the single token id 0 at every position.)
l, ll = model(torch.randint(0, vocab_size, (batch_size, block_size)))

In [28]:
# shape of the logits returned above: (batch_size, block_size, vocab_size)
l.shape


torch.Size([32, 32, 10000])