In [221]:
import torch
import torch.nn as nn
import torch.optim as optim

In [222]:

#datasets 
src_vocab = {"<pad>": 0, "<eos>": 1, "<unk>": 2, "How": 3, "are": 4, "you": 5, "?": 6}
tgt_vocab = {"<pad>": 0, "<eos>": 1, "<unk>": 2, "Comment": 3, "allez-vous": 4, "?": 5}


src_sentence = ["How", "are", "you", "?"]
tgt_sentence = ["Comment", "allez-vous", "?"]

src_indices = [src_vocab[word] for word in src_sentence]
tgt_indices = [tgt_vocab[word] for word in tgt_sentence]

src_vocab_size = len(src_vocab)
tgt_vocab_size = len(tgt_vocab)
embedding_dim = 64

src_embedding = nn.Embedding(src_vocab_size, embedding_dim)
tgt_embedding = nn.Embedding(tgt_vocab_size, embedding_dim)


src_indices_tensor = torch.tensor([src_indices], dtype=torch.long)  
tgt_indices_tensor = torch.tensor([tgt_indices], dtype=torch.long)

src_mask = (src_indices_tensor != src_vocab["<pad>"])
tgt_mask = (tgt_indices_tensor != tgt_vocab["<pad>"])

src_embeddings = src_embedding(src_indices_tensor)  
tgt_embeddings = tgt_embedding(tgt_indices_tensor) 

Q = src_embeddings 
K = tgt_embeddings 
V = tgt_embeddings 

if K.shape[1] < Q.shape[1]:
    pad_length = Q.shape[1] - K.shape[1]
    K = torch.nn.functional.pad(K, (0, 0, 0, pad_length))  
    V= torch.nn.functional.pad(V, (0, 0, 0, pad_length))  



In [223]:

class RLAttention(nn.Module):
    def __init__(self, model_dimension, n_heads):
        super(RLAttention, self).__init__()
        self.model_dimension = model_dimension
        self.n_heads = n_heads
        self.dimension = model_dimension // n_heads
        self.seq_length = K.size(1)

        self.q = nn.Linear(model_dimension, model_dimension)
        self.k = nn.Linear(model_dimension, model_dimension)
        self.v = nn.Linear(model_dimension, model_dimension)
        self.o = nn.Linear(model_dimension, model_dimension)

        self.policy_network = nn.Sequential(
            nn.Linear(self.dimension, 64),  
            nn.ReLU(),
            nn.Linear(64, self.seq_length),  
            nn.Softmax(dim=-1)  
        )


    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.q(Q))
        K = self.split_heads(self.k(K))
        V = self.split_heads(self.v(V))
        
        attention_weights = self.select_attention_weights(Q, K, mask)

        output = torch.matmul(attention_weights, V)
        combined = self.combine_heads(output)
        return self.o(combined), attention_weights
        
    def split_heads(self, x):
        batch_size, seq_length, model_dim = x.size()
        x = x.view(batch_size, seq_length, self.n_heads, self.dimension)
        return x.permute(0, 2, 1, 3) 
    def combine_heads(self, x):
        batch_size, n_heads, seq_length, dimension = x.size()
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(batch_size, seq_length, n_heads * dimension)

    def select_attention_weights(self, Q, K, mask=None):
        batch_size, n_heads, seq_length_query, dimension = Q.size()
        _,_,seq_length_key, _ = K.size()

        # convert the queries to 1D for policy input
        Q_flat = Q.reshape(batch_size * n_heads * seq_length_query, dimension)
        
        action_probs = self.policy_network(Q_flat)
        
        action_probabilities = action_probs.view(batch_size, n_heads, seq_length_query, seq_length_key)

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(2) 
            action_probabilities = action_probabilities.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = torch.softmax(action_probabilities, dim=-1)

        return attention_weights

In [224]:
model = RLAttention(model_dimension=64, n_heads=8)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [225]:
def reward_function(predicted_translation, target_translation):
    return torch.cosine_similarity(predicted_translation, target_translation, dim=-1).mean()

In [226]:
epochs = 20
for epoch in range(epochs):
    optimizer.zero_grad()
    
    # Forward pass
    output, attention_weights = model(Q, K, V, mask=src_mask)
    
    rewards = reward_function(output, src_embeddings)
 
    log_probs = torch.log(attention_weights + 1e-9) 
    loss = -torch.sum(log_probs * rewards.unsqueeze(-1))
    
    loss.backward(retain_graph=True)
    optimizer.step()
    
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}")

Epoch 1/20, Loss: 4.1702141761779785
Epoch 2/20, Loss: -6.235805988311768
Epoch 3/20, Loss: -16.6051082611084
Epoch 4/20, Loss: -26.24207305908203
Epoch 5/20, Loss: -34.75881576538086
Epoch 6/20, Loss: -42.0794677734375
Epoch 7/20, Loss: -48.25393295288086
Epoch 8/20, Loss: -53.37934494018555
Epoch 9/20, Loss: -57.57787322998047
Epoch 10/20, Loss: -60.98817443847656
Epoch 11/20, Loss: -63.745872497558594
Epoch 12/20, Loss: -65.97418212890625
Epoch 13/20, Loss: -67.77554321289062
Epoch 14/20, Loss: -69.23640441894531
Epoch 15/20, Loss: -70.4271469116211
Epoch 16/20, Loss: -71.40367889404297
Epoch 17/20, Loss: -72.21573638916016
Epoch 18/20, Loss: -72.90322875976562
Epoch 19/20, Loss: -73.50045776367188
Epoch 20/20, Loss: -74.03679656982422
