In [1]:
import torch
import torch.nn as nn
import torch.optim as optim


In [2]:

# Define a simple RL-based attention mechanism
class RLAttention(nn.Module):
    def __init__(self, model_dimension, n_heads):
        super(RLAttention, self).__init__()
        self.model_dimension = model_dimension
        self.n_heads = n_heads
        self.dimension = model_dimension // n_heads

        # Linear layers for Q, K, V, and output
        self.q = nn.Linear(model_dimension, model_dimension)
        self.k = nn.Linear(model_dimension, model_dimension)
        self.v = nn.Linear(model_dimension, model_dimension)
        self.o = nn.Linear(model_dimension, model_dimension)

        # Define the policy network (for selecting keys based on queries)
        self.policy_network = nn.Sequential(
            nn.Linear(self.dimension, 64),  # Input is the Query dimension
            nn.ReLU(),
            nn.Linear(64, self.dimension),  # Output is the action (selection of relevant keys)
            nn.Softmax(dim=-1)  # Probability distribution over keys
        )


    def forward(self, Q, K, V, mask=None):
        # Split the input tensors into multi-head dimensions
        Q = self.split_heads(self.q(Q))
        K = self.split_heads(self.k(K))
        V = self.split_heads(self.v(V))
        
        # Calculate attention weights using RL policy
        attention_weights = self.select_attention_weights(Q, K, mask)
        # Apply the attention weights to the values
        attention_weights = attention_weights.transpose(-2, -1)
        output = torch.matmul(attention_weights, V)
        
        combined = self.combine_heads(output)
        return self.o(combined)

    def split_heads(self, x):
        batch_size, seq_length, model_dim = x.size()
        x = x.view(batch_size, seq_length, self.n_heads, self.dimension)
        return x.permute(0, 2, 1, 3)  # (batch_size, n_heads, seq_length, dimension)

    def combine_heads(self, x):
        batch_size, n_heads, seq_length, dimension = x.size()
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(batch_size, seq_length, n_heads * dimension)

    def select_attention_weights(self, Q, K, mask=None):
        # Assume Q and K are (batch_size, n_heads, seq_length, dimension)
        batch_size, n_heads, seq_length, dimension = Q.size()

        # Flatten the queries for policy input
        Q_flat = Q.reshape(batch_size * n_heads * seq_length, dimension)
        
        # Use the policy network to decide the attention weights for each query
        action_probs = self.policy_network(Q_flat)
        
        # Reshape action_probs back to attention shape
        action_probs = action_probs.view(batch_size, n_heads, seq_length, seq_length)

        # Apply softmax along the last dimension for attention
        attention_weights = torch.softmax(action_probs, dim=-1)

        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(2)  # Adjust mask shape for multi-head attention
            attention_weights = attention_weights.masked_fill(mask == 0, float('-inf'))
            attention_weights = torch.softmax(attention_weights, dim=-1)

        return attention_weights

In [3]:

#datasets 
src_vocab = {"<pad>": 0, "<eos>": 1, "<unk>": 2, "How": 3, "are": 4, "you": 5, "?": 6}
tgt_vocab = {"<pad>": 0, "<eos>": 1, "<unk>": 2, "Comment": 3, "allez-vous": 4, "?": 5}


src_sentence = ["How", "are", "you", "?"]
tgt_sentence = ["Comment", "allez-vous", "?"]

src_indices = [src_vocab[word] for word in src_sentence]
tgt_indices = [tgt_vocab[word] for word in tgt_sentence]

src_vocab_size = len(src_vocab)
tgt_vocab_size = len(tgt_vocab)
embedding_dim = 64

src_embedding = nn.Embedding(src_vocab_size, embedding_dim)
tgt_embedding = nn.Embedding(tgt_vocab_size, embedding_dim)


# Convert indices to tensors
src_indices_tensor = torch.tensor([src_indices], dtype=torch.long)  # Shape: (batch_size, seq_length)
tgt_indices_tensor = torch.tensor([tgt_indices], dtype=torch.long)  # Shape: (batch_size, seq_length)

# Generate masks (1 for valid tokens, 0 for padding)
src_mask = (src_indices_tensor != src_vocab["<pad>"])
tgt_mask = (tgt_indices_tensor != tgt_vocab["<pad>"])

# Get embeddings
src_embeddings = src_embedding(src_indices_tensor)  # Shape: (batch_size, seq_length, embedding_dim)
tgt_embeddings = tgt_embedding(tgt_indices_tensor)  # Shape: (batch_size, seq_length, embedding_dim)

Q = src_embeddings  # Source sentence embeddings
K = tgt_embeddings  # Target sentence embeddings (keys)
V = tgt_embeddings  # Target sentence embeddings (values)

In [4]:
# Create the model, define optimizer and loss function
model = RLAttention(model_dimension=64, n_heads=8)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Forward pass
output = model(Q, K, V, mask=tgt_mask)

print("Attention output shape:", output.shape) 


RuntimeError: shape '[1, 8, 4, 4]' is invalid for input of size 256

In [None]:
# Define the reward function (this is a simple placeholder; a real reward function should be based on the task)
def reward_function(predicted_translation, target_translation):
    # Compute reward as cosine similarity between predicted and target translation (simplified)
    return torch.cosine_similarity(predicted_translation, target_translation, dim=-1).mean()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim


# Define a simple RL-based attention mechanism
class RLAttention(nn.Module):
    def __init__(self, model_dimension, n_heads):
        super(RLAttention, self).__init__()
        self.model_dimension = model_dimension
        self.n_heads = n_heads
        self.dimension = model_dimension // n_heads  # Ensure this is an integer
        
        # Linear layers for Q, K, V, and output
        self.q = nn.Linear(model_dimension, model_dimension)
        self.k = nn.Linear(model_dimension, model_dimension)
        self.v = nn.Linear(model_dimension, model_dimension)
        self.o = nn.Linear(model_dimension, model_dimension)

        # Define the policy network (for selecting keys based on queries)
        self.policy_network = nn.Sequential(
            nn.Linear(self.dimension, 64),  # Input is the Query dimension
            nn.ReLU(),
            nn.Linear(64, self.dimension),  # Output is the action (selection of relevant keys)
            nn.Softmax(dim=-1)  # Probability distribution over keys
        )

    def forward(self, Q, K, V, mask=None):
        # Split the input tensors into multi-head dimensions
        Q = self.split_heads(self.q(Q))  # Shape: (batch_size, n_heads, seq_length, dimension)
        K = self.split_heads(self.k(K))  # Shape: (batch_size, n_heads, seq_length, dimension)
        V = self.split_heads(self.v(V))  # Shape: (batch_size, n_heads, seq_length, dimension)

        # Calculate attention weights using RL policy
        attention_weights = []
        for q, k in zip(Q, K):
            action_probs = self.policy_network(q.reshape(q.size(0), -1))  # Flatten q for policy input
            action_probs = action_probs.view(q.size(0), q.size(1), q.size(2))  # Reshape to (batch_size, seq_length, seq_length)
            attention_weights.append(torch.softmax(action_probs, dim=-1))

        attention_weights = torch.cat(attention_weights, dim=1)  # Shape: (batch_size, n_heads, seq_length, seq_length)

        # Apply attention weights to the values (V)
        output = torch.matmul(attention_weights, V.transpose(2, 3))  # Transpose V to align dimensions
        output = output.transpose(2, 3)  # (batch_size, n_heads, seq_length, dimension)

        combined = self.combine_heads(output)  # Combine heads back
        return self.o(combined)  # Output of the attention layer

    def split_heads(self, x):
        batch_size, seq_length, model_dim = x.size()
        x = x.view(batch_size, seq_length, self.n_heads, self.dimension)  # Split into heads
        return x.permute(0, 2, 1, 3)  # (batch_size, n_heads, seq_length, dimension)

    def combine_heads(self, x):
        batch_size, n_heads, seq_length, dimension = x.size()
        x = x.permute(0, 2, 1, 3).contiguous()  # Re-order dimensions
        return x.view(batch_size, seq_length, n_heads * dimension)  # Combine heads


# datasets 
src_vocab = {"<pad>": 0, "<eos>": 1, "<unk>": 2, "How": 3, "are": 4, "you": 5, "?": 6}
tgt_vocab = {"<pad>": 0, "<eos>": 1, "<unk>": 2, "Comment": 3, "allez-vous": 4, "?": 5}

src_sentence = ["How", "are", "you", "?"]
tgt_sentence = ["Comment", "allez-vous", "?"]

src_indices = [src_vocab[word] for word in src_sentence]
tgt_indices = [tgt_vocab[word] for word in tgt_sentence]

src_vocab_size = len(src_vocab)
tgt_vocab_size = len(tgt_vocab)
embedding_dim = 64

src_embedding = nn.Embedding(src_vocab_size, embedding_dim)
tgt_embedding = nn.Embedding(tgt_vocab_size, embedding_dim)

# Convert indices to tensors
src_indices_tensor = torch.tensor([src_indices], dtype=torch.long)  # Shape: (batch_size, seq_length)
tgt_indices_tensor = torch.tensor([tgt_indices], dtype=torch.long)  # Shape: (batch_size, seq_length)

# Generate masks (1 for valid tokens, 0 for padding)
src_mask = (src_indices_tensor != src_vocab["<pad>"])
tgt_mask = (tgt_indices_tensor != tgt_vocab["<pad>"])

# Get embeddings
src_embeddings = src_embedding(src_indices_tensor)  # Shape: (batch_size, seq_length, embedding_dim)
tgt_embeddings = tgt_embedding(tgt_indices_tensor)  # Shape: (batch_size, seq_length, embedding_dim)

Q = src_embeddings  # Source sentence embeddings
K = tgt_embeddings  # Target sentence embeddings (keys)
V = tgt_embeddings  # Target sentence embeddings (values)

# Create the model, define optimizer and loss function
model = RLAttention(model_dimension=64, n_heads=8)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Forward pass
output = model(Q, K, V, mask=tgt_mask)

print("Attention output shape:", output.shape)

# Define the reward function (this is a simple placeholder; a real reward function should be based on the task)
def reward_function(predicted_translation, target_translation):
    # Compute reward as cosine similarity between predicted and target translation (simplified)
    return torch.cosine_similarity(predicted_translation, target_translation, dim=-1).mean()


RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x32 and 8x64)