In [3]:
#Imports
import torch
import torch.nn.functional as F
import torch.nn as nn
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union


In [4]:
# Setup: choose the compute device and seed torch's global RNG so that a
# Restart & Run All reproduces the same parameter initialisation.

# Set to True to run on a GPU when one is available. Defaults to False, which
# keeps the previous hard-coded CPU device. Whatever device is chosen, the
# whole model must live on it (e.g. move it with `.to(DEVICE)`).
USE_GPU = False
DEVICE = torch.device("cuda" if USE_GPU and torch.cuda.is_available() else "cpu")

SEED = 42
torch.manual_seed(SEED)

DEVICE


In [5]:
class ContextAwareAttention(nn.Module):
    """Multi-head attention whose keys and values are gated with an external context.

    Before attending, the keys are blended with a projection of ``context``
    into the model dimension:

        key_context = u_k(context)
        lambda_k    = sigmoid(w1_k(k) + w2_k(key_context))
        k_cap       = (1 - lambda_k) * k + lambda_k * key_context

    The values are blended the same way, using ``u_v``, ``w1_v`` and ``w2_v``.
    The query is passed through unchanged. ``w1_*``/``w2_*`` map to a single
    output feature, so each gate has a trailing dimension of 1. As a result,
    one scalar per position scales that position's whole feature vector.

    Args:
        dim_model: Embedding size of q/k/v. This is the ``embed_dim`` of the
            wrapped ``nn.MultiheadAttention``, so it must be divisible by
            ``num_heads``.
        dim_context: Size of the last dimension of ``context``.
        dropout_rate: Dropout probability on the attention weights.
            ``None`` is treated as 0.0.
        num_heads: Number of attention heads.
        device: Device on which every sub-layer is created. If ``None``, it
            falls back to the notebook-level ``DEVICE`` global, as before.
    """

    def __init__(self,
                 dim_model: int,
                 dim_context: int,
                 dropout_rate: Optional[float] = 0.0,
                 num_heads: int = 1,
                 device: Optional[torch.device] = None):
        super().__init__()

        # Keep the historical default (the DEVICE global), but place *every*
        # sub-layer there. Previously only the attention layer was placed on
        # it, so the gate layers and the attention layer disagreed on device
        # whenever DEVICE was not the default device.
        if device is None:
            device = DEVICE

        self.dim_model = dim_model
        self.dim_context = dim_context
        # nn.MultiheadAttention needs a float probability; accept None as "no dropout".
        self.dropout_rate = 0.0 if dropout_rate is None else float(dropout_rate)

        self.attention_layer = nn.MultiheadAttention(embed_dim=self.dim_model,
                                                     num_heads=num_heads,
                                                     dropout=self.dropout_rate,
                                                     bias=True,
                                                     add_zero_attn=False,
                                                     batch_first=True,
                                                     device=device)

        # Key-side gate: project the context to dim_model, then score k and
        # the projected context down to one gate logit each.
        self.u_k = nn.Linear(self.dim_context, self.dim_model, bias=False, device=device)
        self.w1_k = nn.Linear(self.dim_model, 1, bias=False, device=device)
        self.w2_k = nn.Linear(self.dim_model, 1, bias=False, device=device)

        # Value-side gate: same structure, independent weights.
        self.u_v = nn.Linear(self.dim_context, self.dim_model, bias=False, device=device)
        self.w1_v = nn.Linear(self.dim_model, 1, bias=False, device=device)
        self.w2_v = nn.Linear(self.dim_model, 1, bias=False, device=device)

    def forward(self,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                context: Optional[torch.Tensor] = None,
                key_padding_mask: Optional[torch.Tensor] = None,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend from ``q`` over context-gated ``k`` and ``v``.

        Args:
            q, k, v: Query, key and value tensors in the layout accepted by
                ``nn.MultiheadAttention(batch_first=True)``: either
                ``(batch, seq, dim_model)`` or unbatched ``(seq, dim_model)``.
            context: Tensor whose last dimension is ``dim_context``. After
                projection it is combined element-wise with ``k`` and ``v``,
                so its leading dimensions must broadcast against theirs. If
                ``None``, no gating is applied and the layer reduces to plain
                multi-head attention.
            key_padding_mask, attn_mask: Forwarded unchanged to
                ``nn.MultiheadAttention``. ``None`` (the default) means no masking.

        Returns:
            The attention output, with the same shape as ``q``. The attention
            weights are discarded.
        """
        if context is None:
            # The context is optional: skip gating instead of failing inside u_k(None).
            k_cap, v_cap = k, v
        else:
            # Transform the context into the model dimension.
            key_context = self.u_k(context)
            value_context = self.u_v(context)

            # Gates lie in (0, 1) and have a trailing dim of 1, so broadcasting
            # scales each position's whole feature vector by one scalar.
            lambda_k = torch.sigmoid(self.w1_k(k) + self.w2_k(key_context))
            lambda_v = torch.sigmoid(self.w1_v(v) + self.w2_v(value_context))

            k_cap = (1 - lambda_k) * k + lambda_k * key_context
            v_cap = (1 - lambda_v) * v + lambda_v * value_context

        attention_output, _ = self.attention_layer(query=q,
                                                   key=k_cap,
                                                   value=v_cap,
                                                   key_padding_mask=key_padding_mask,
                                                   attn_mask=attn_mask)
        return attention_output


In [6]:
# Model under test: 10-dim inputs, 5-dim context, 1% attention dropout, a single head.
MCA2 = ContextAwareAttention(dim_model=10, dim_context=5, dropout_rate=0.01, num_heads=1)

### Test inputs: random tensors fed through `MCA2`

In [7]:
# Dedicated, seeded generator for the test tensors below, independent of
# torch's global RNG. manual_seed() returns the generator itself, so
# assigning its result avoids echoing the generator repr as cell output.
g = torch.Generator().manual_seed(42)

<torch._C.Generator at 0x7f90336b2f70>

In [23]:
# Unbatched test inputs drawn from the seeded generator `g`, in this order:
# a (10, 10) tensor used as q, k and v (last dim = MCA2's dim_model of 10),
# then a (10, 5) context (last dim = MCA2's dim_context of 5).
test_model_wts = torch.rand((10, 10), generator=g)
test_context = torch.rand((10, 5), generator=g)

In [22]:
# Call the module itself rather than .forward() so nn.Module hooks run.
# q = k = v here, i.e. context-gated self-attention over the test sequence.
MCA2_op = MCA2(test_model_wts, test_model_wts, test_model_wts, test_context)
# Attention output keeps the query's shape: (seq_len, dim_model).
assert MCA2_op.shape == test_model_wts.shape
MCA2_op.shape

In [18]:
# Scratch tensors for checking the broadcasting used by the lambda gates:
# an (n, 1) tensor times an (n, d) tensor. A local seeded generator keeps
# the values reproducible across re-runs.
demo_gen = torch.Generator().manual_seed(0)
rand_n_1_matrix = torch.rand(2, 1, generator=demo_gen)
test_wts_2 = torch.rand(2, 3, generator=demo_gen)

In [21]:
# (n, 1) * (n, d) broadcasts the single column: row i of test_wts_2 is
# scaled by the scalar rand_n_1_matrix[i, 0].
row_scaled = rand_n_1_matrix * test_wts_2
assert torch.equal(row_scaled, rand_n_1_matrix.expand_as(test_wts_2) * test_wts_2)
row_scaled