<a href="https://colab.research.google.com/github/paulxiong/tinyTF/blob/main/train_the_weight_matrices_of_Q%2C_K%2C_and_V.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

# Minimal Transformer-style model: one self-attention layer feeding a linear layer
class Transformer(nn.Module):
    """Self-attention followed by a ``d_model -> d_model`` linear projection.

    Args:
        d_model: Feature width passed to both the attention sub-module and
            the output linear layer.
    """

    def __init__(self, d_model: int) -> None:
        super().__init__()
        # Registration order (attention first, then fc) fixes parameter order.
        self.self_attention = SelfAttention(d_model)
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through self-attention, then through the linear layer."""
        return self.fc(self.self_attention(x))

# Define the SelfAttention module
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are each produced by their own learned
    ``d_model -> d_model`` linear projection of the same input.

    Args:
        d_model: Size of the input's last (feature) dimension; also the size
            of the projected queries, keys and values.
    """

    def __init__(self, d_model):
        super().__init__()
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Let every position of ``x`` attend to every position of ``x``.

        Args:
            x: Tensor with at least two dimensions whose last dimension is
                ``d_model`` and whose second-to-last dimension indexes the
                positions being attended over, e.g. ``(batch, seq, d_model)``.

        Returns:
            Tensor with the same shape as ``x``: for each position, the
            softmax-weighted sum of the value vectors.
        """
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # sqrt(d_k) as a plain Python float: no per-call tensor allocation,
        # and no hard-coded dtype/device to clash with the input's.
        scale = Q.size(-1) ** 0.5

        # Pairwise query/key similarity, scaled to keep softmax well-conditioned
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale

        # Normalize over the key axis so each query's weights sum to 1
        attention_weights = torch.softmax(scores, dim=-1)

        # Mix the value vectors according to the attention weights
        return torch.matmul(attention_weights, V)

# Build the model; d_model is the feature width used throughout
d_model = 64
transformer = Transformer(d_model)

# Toy regression data: random inputs and random targets of identical shape
input_sequence = torch.randn((5, 10, d_model))
target_sequence = torch.randn((5, 10, d_model))

# Mean-squared-error objective, minimized by Adam over all model parameters
criterion = nn.MSELoss()
optimizer = optim.Adam(params=transformer.parameters(), lr=1e-3)

# Full-batch training: every epoch is one gradient step on the same tensors
num_epochs = 100
for epoch in range(num_epochs):
    # Clear gradients accumulated by the previous step
    optimizer.zero_grad()

    # Predict, then measure the error against the targets
    output_sequence = transformer(input_sequence)
    loss = criterion(output_sequence, target_sequence)

    # Backpropagate and apply the parameter update
    loss.backward()
    optimizer.step()

    # Report progress with a 1-based epoch counter
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, loss.item()))


Epoch [1/100], Loss: 0.9994
Epoch [2/100], Loss: 0.9909
Epoch [3/100], Loss: 0.9828
Epoch [4/100], Loss: 0.9750
Epoch [5/100], Loss: 0.9676
Epoch [6/100], Loss: 0.9604
Epoch [7/100], Loss: 0.9535
Epoch [8/100], Loss: 0.9467
Epoch [9/100], Loss: 0.9400
Epoch [10/100], Loss: 0.9335
Epoch [11/100], Loss: 0.9269
Epoch [12/100], Loss: 0.9204
Epoch [13/100], Loss: 0.9138
Epoch [14/100], Loss: 0.9072
Epoch [15/100], Loss: 0.9005
Epoch [16/100], Loss: 0.8937
Epoch [17/100], Loss: 0.8867
Epoch [18/100], Loss: 0.8796
Epoch [19/100], Loss: 0.8724
Epoch [20/100], Loss: 0.8650
Epoch [21/100], Loss: 0.8575
Epoch [22/100], Loss: 0.8498
Epoch [23/100], Loss: 0.8419
Epoch [24/100], Loss: 0.8339
Epoch [25/100], Loss: 0.8258
Epoch [26/100], Loss: 0.8175
Epoch [27/100], Loss: 0.8091
Epoch [28/100], Loss: 0.8005
Epoch [29/100], Loss: 0.7919
Epoch [30/100], Loss: 0.7831
Epoch [31/100], Loss: 0.7743
Epoch [32/100], Loss: 0.7655
Epoch [33/100], Loss: 0.7566
Epoch [34/100], Loss: 0.7477
Epoch [35/100], Loss: 0