In [2]:
import torch
import numpy as np
import torch.nn as nn

import matplotlib.pyplot as plt



- x : (batch_size, max_length) 

- tokens_id in x between 0 and vocab_size

- Embedd(x) : (batch_size, max_length, model_dim)

- K : (model_dim, dk)
- Kx : (batch_size, max_length, dk)

- Q : (model_dim, dk)
- Qx : (batch_size, max_length, dk)

- Qx*Kx^T : (batch_size, max_length, max_length)
- V : (model_dim, dv)
- Vx : (batch_size, max_length, dv)





In [3]:
class Embedding(nn.Module):
    def __init__(self,  batch_size, model_dim, max_length, n_embedding):
        super().__init__()
        self.max_length = max_length
        self.batch_size = batch_size
        self.model_dim = model_dim
        self.n_embedding = n_embedding
        self.embedding = torch.nn.Embedding(num_embeddings=n_embedding, embedding_dim=model_dim)
        self.pos_encoding = PositionalEncoding(batch_size=batch_size, model_dim=model_dim, max_length=max_length)
        pass
    
    def forward(self, x):
        x = self.embedding(x) + self.pos_encoding(x)
        return x

class PositionalEncoding(nn.Module):
    def __init__(self, batch_size, model_dim, max_length):
        super().__init__()
        self.model_dim = model_dim
        self.max_length = max_length
        self.batch_size = batch_size
        self.compute()

    def SinPos(self, i: int, pos: int):
        if i % 2 == 0:
            return np.sin(pos / 10000 ** (2 * i / self.model_dim))
        else:
            return np.cos(pos / 10000 ** (2 * i / self.model_dim))

    def compute(self):
        Mat = torch.Tensor([[self.SinPos(i, pos) for i in range(self.model_dim)] for pos in range(self.max_length)])
        self.Mat = Mat

    
    def forward(self, x):
        with torch.no_grad():
            return self.Mat



In [4]:
class SingleHeadAttention(nn.Module):
    def __init__(self, dk:int, dv:int, model_dim:int, mask:torch.Tensor=None):
        super().__init__()
        self.dk = dk
        self.dv = dv
        self.model_dim = model_dim
        self.K = nn.Linear(in_features=model_dim, out_features=dk)
        self.Q = nn.Linear(in_features=model_dim, out_features=dk)
        self.V = nn.Linear(in_features=model_dim, out_features=dv)
        self.mask = mask

    def forward(self, x:torch.Tensor, x_encoder:torch.Tensor=None):
        if x_encoder is not None:
            Kx = self.K(x_encoder)
            Vx = self.V(x_encoder)
        else:
            Kx = self.K(x)
            Vx = self.V(x)
        Qx = self.Q(x)
        QK = torch.matmul(Qx, Kx.transpose(-2, -1)) / np.sqrt(self.dk)
        if self.mask is not None:
            QK += self.mask
        QK = torch.softmax(QK, dim=-1)
        return torch.matmul(QK, Vx)

    
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads : int, dk:int, dv:int, model_dim:int, mask=None):
        super().__init__()
        assert num_heads * dv == model_dim, "num_heads * dv should be equal to the model dim"
        self.attention_heads = nn.ModuleList([SingleHeadAttention(dk, dv, model_dim, mask) for _ in range(num_heads)])
        self.WO = nn.Linear(in_features=num_heads*dv, out_features=model_dim)  
        self.mask = mask
    
    def forward(self, x:torch.Tensor, x_encoder:torch.Tensor=None):
        outputs = [head(x, x_encoder) for head in self.attention_heads]
        x = torch.cat(outputs,dim=-1)
        x = self.WO(x)
        return x
    


In [5]:
class EncoderBlock(nn.Module):
    def __init__(self, num_heads, dk, dv, model_dim):
        super().__init__()
        self.attention = MultiHeadAttention(num_heads=num_heads, dk=dk, dv=dv, model_dim=model_dim)
        self.layerNorm = nn.LayerNorm(normalized_shape=model_dim)
        self.ff = nn.Sequential(
            nn.Linear(in_features=model_dim, out_features=2048),
            nn.ReLU(),
            nn.Linear(in_features=2048, out_features=model_dim)
        )

    
    def forward(self, x):
        attention = self.attention(x)
        x = self.layerNorm(x + attention)
        feedforward = self.ff(x)
        x = self.layerNorm(x + feedforward)
        return x


class Encoder(nn.Module):
    def __init__(self, num_heads, dk, dv, model_dim, num_encoders):
        super().__init__()
        self.encoders_list = [EncoderBlock(num_heads, dk, dv, model_dim) for _ in range(num_encoders)]
        self.encoders = nn.Sequential(*self.encoders_list)

    def forward(self, x):
        x = self.encoders(x)
        return x


In [6]:
class DecoderBlock(nn.Module):
    def __init__(self, num_heads, dk, dv, model_dim, max_length):
        super().__init__()
        self.mask = torch.zeros(max_length, max_length) + torch.triu(torch.full((max_length, max_length), float('-inf')), diagonal=1)
        self.masked_attention = MultiHeadAttention(num_heads=num_heads, dk=dk, dv=dv, model_dim=model_dim, mask=self.mask)
        self.mixed_attention = MultiHeadAttention(num_heads=num_heads, dk=dk, dv=dv, model_dim=model_dim)
        self.layerNorm = nn.LayerNorm(normalized_shape=model_dim)
        self.ff = nn.Sequential(
            nn.Linear(in_features=model_dim, out_features=2048),
            nn.ReLU(),
            nn.Linear(in_features=2048, out_features=model_dim)
        )

    def forward(self, x, x_encoder):
        attention = self.masked_attention(x)
        x = self.layerNorm(x + attention)
        attention = self.mixed_attention(x, x_encoder)
        x = self.layerNorm(x + attention)
        feedforward = self.ff(x)
        x = self.layerNorm(x + feedforward)
        return x

class CustomSequential(nn.Module):
    def __init__(self, *args):
        super().__init__()
        self.modules_list = nn.ModuleList(args)

    def forward(self, x, x_encoder):
        for module in self.modules_list:
            x = module(x, x_encoder)
        return x

class Decoder(nn.Module):
    def __init__(self, num_heads, dk, dv, model_dim, max_length, num_decoders):
        super().__init__()
        decoders_list = [DecoderBlock(num_heads, dk, dv, model_dim, max_length) for _ in range(num_decoders)]
        self.decoders = CustomSequential(*decoders_list)

    def forward(self, x, x_encoder):
        x = self.decoders(x, x_encoder)
        return x


In [7]:
class Transformer(nn.Module):
    def __init__(self, batch_size, model_dim, max_length, vocab_size, num_out, num_heads, dv, dk, num_encoders, num_decoders):
        super().__init__()
        self.encoder = Encoder(num_heads=num_heads, dk=dk, dv=dv, model_dim=model_dim, num_encoders=num_encoders)
        self.decoder = Decoder(num_heads=num_heads, dk=dk, dv=dv, model_dim=model_dim, num_decoders=num_decoders, max_length=max_length)
        self.softmax = nn.Softmax(dim=-1)
        self.linear = nn.Linear(in_features=model_dim, out_features=num_out)
        self.embedding = Embedding(batch_size=batch_size, model_dim=model_dim, max_length=max_length, n_embedding=vocab_size)
        
    
    def forward(self, x):
        x = self.embedding(x)
        x_encoder = self.encoder(x)
        x = self.decoder(x, x_encoder)
        x = self.linear(x)
        x = self.softmax(x)
        return x

In [8]:
batch_size = 16
model_dim = 512
max_length = 100
vocab_size = 32000
num_out = vocab_size
num_heads = 8
dv = 64
dk = 64
num_encoders = 6
num_decoders = 6

x = torch.randint(0, vocab_size, (batch_size, max_length))

MyTransformer = Transformer(
    batch_size=batch_size,
    model_dim=model_dim,
    max_length=max_length,
    vocab_size=vocab_size,
    num_out=num_out,
    num_heads=num_heads,
    dv=dv,
    dk=dk,
    num_encoders=num_encoders,
    num_decoders=num_decoders
)
out = MyTransformer(x)
out.shape


torch.Size([16, 100, 32000])

In [9]:
MyPositionalEncoding = PositionalEncoding(batch_size=batch_size, model_dim=model_dim, max_length=max_length)

plt.figure(figsize=(10, 10))
plt.imshow(MyPositionalEncoding.Mat[0, :100, :100], aspect="auto")

IndexError: too many indices for tensor of dimension 2

<Figure size 1000x1000 with 0 Axes>

In [170]:
def print_model_parameters(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f'{name}: {param.numel()} parameters')

print_model_parameters(MyTransformer)
# 
# Calculate the total number of trainable parameters
total_params = sum(p.numel() for p in MyTransformer.parameters() if p.requires_grad)
total_params_no_grad = sum(p.numel() for p in MyTransformer.parameters() if p.requires_grad)

print(f'Total number of trainable parameters: {total_params_no_grad}')
print(f'Total number parameters: {total_params}')

encoder.encoders.0.attention.attention_heads.0.K.weight: 32768 parameters
encoder.encoders.0.attention.attention_heads.0.K.bias: 64 parameters
encoder.encoders.0.attention.attention_heads.0.Q.weight: 32768 parameters
encoder.encoders.0.attention.attention_heads.0.Q.bias: 64 parameters
encoder.encoders.0.attention.attention_heads.0.V.weight: 32768 parameters
encoder.encoders.0.attention.attention_heads.0.V.bias: 64 parameters
encoder.encoders.0.attention.attention_heads.1.K.weight: 32768 parameters
encoder.encoders.0.attention.attention_heads.1.K.bias: 64 parameters
encoder.encoders.0.attention.attention_heads.1.Q.weight: 32768 parameters
encoder.encoders.0.attention.attention_heads.1.Q.bias: 64 parameters
encoder.encoders.0.attention.attention_heads.1.V.weight: 32768 parameters
encoder.encoders.0.attention.attention_heads.1.V.bias: 64 parameters
encoder.encoders.0.attention.attention_heads.2.K.weight: 32768 parameters
encoder.encoders.0.attention.attention_heads.2.K.bias: 64 parameters

In [135]:
import torch
import numpy as np

# Define dimensions
batch_size = 2
seq_length = 3
model_dim = 4
dk = 4
dv = 4

# Generate random matrices
Qx = torch.randn(batch_size, seq_length, model_dim)
Kx = torch.randn(batch_size, seq_length, model_dim)
Vx = torch.randn(batch_size, seq_length, dv)

# Implementation 1: Dot product attention
QK_dot_product = torch.matmul(Qx, Kx.transpose(-2, -1)) / np.sqrt(dk)
QK_dot_product = torch.softmax(QK_dot_product, dim=-1)
output_dot_product = torch.matmul(QK_dot_product, Vx)

# Implementation 2: Element-wise multiplication and summation
QK_elementwise = torch.sum(Kx * Qx, dim=-1) / np.sqrt(dk)
QK_elementwise = torch.softmax(QK_elementwise, dim=-1)
QK_elementwise = QK_elementwise.unsqueeze(-1)
QK_elementwise = QK_elementwise.expand(-1, -1, dv)
output_elementwise = QK_elementwise * Vx

# Print results
print("Output (Dot Product):")
print(output_dot_product)
print("\nOutput (Element-wise):")
print(output_elementwise)

Output (Dot Product):
tensor([[[ 0.8135, -0.4580,  0.8029,  0.1411],
         [-0.3034, -0.2345,  0.2389,  0.1042],
         [ 0.6104, -0.1512,  0.4783, -0.0657]],

        [[-0.2787,  0.9954,  0.8464, -0.2801],
         [-0.2337,  1.0036,  0.7845, -0.1865],
         [-0.2645,  0.9221,  0.8332, -0.2380]]])

Output (Element-wise):
tensor([[[ 0.8075, -0.5717,  0.8141,  0.2876],
         [ 0.2791, -0.0066,  0.1415, -0.0585],
         [-0.3046, -0.0824,  0.0058,  0.0681]],

        [[ 0.0791,  0.2676,  0.0383,  0.2675],
         [-0.3081,  0.4092,  0.7134, -0.4355],
         [-0.0153,  0.1695,  0.0602, -0.0146]]])
