In [149]:
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter
import math
import einops
from typing import Dict, Optional, Tuple


In [154]:
class MultiHeadedAttention(nn.Module):                                  # new class with parent class nn.Module
    """
    Multi-Headed Attention
    """
    def __init__(
        self,
        embed_dim,                                                       # d, embedding dim
        num_heads,                                                       # number of heads
        kdim = None,                                                     # d except in first layer or query != key
        vdim = None,                                                     # d except in first layer (i think?)
        dropout=0.0,                                                     # dropout
        bias = False,                                                     # whether to add bias or not
#        add_bias_kv: bool = False,                                       #
#        add_zero_attn: bool = False,                                     # 
        self_attention: bool = False,                                    #
#        encoder_decoder_attention: bool = False,                         # 
#        use_rotary_embeddings: bool = False,                             #
    ):
        super().__init__()                                               # necessary to have MHA be able to call functions from nn.Module
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim              
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and \
            self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout                                           # dropout, randomly select a few nodes to drop during training
        self.head_dim = embed_dim // num_heads
        assert (                                                         # head_dim = embed_dim/num_heads
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self.self_attention = self_attention
#        self.encoder_decoder_attention = encoder_decoder_attention
        assert not self.self_attention or self.qkv_same_dim, (           
            "Self-attention requires query, key "+\
            "and " "value to be of the same size"
        )

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)          # Q,K,V projection
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)        # linear layer after attention 
        self.reset_parameters()                                           # initialize weights
    def forward(
        self,
        query,
        key,
        value,
        attn_mask = None,
    ):
        seqlen, batchsize, embed_dim = query.size()                        # input (seqlen, batchsize, embed_dim)
        assert embed_dim == self.embed_dim 
        
        q = einops.rearrange(self.q_proj(query), 's b (h d)->b h s d'\
                             , h=self.num_heads)                           # project qkv, reshape to have batchsize * numheads for mha
        k = einops.rearrange(self.k_proj(key), 's b (h d)->b h s d'\
                             , h=self.num_heads)                           # transpose because F.scaled_dot_product_attm expects (..., seqlen, embed_dim)
        v = einops.rearrange(self.v_proj(value), 's b (h d)->b h s d'\
                             , h=self.num_heads)                           # but since we have multiple heads, 
                                                                           # embed_dim for each head is d/num_heads

        q *= self.scaling
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask)
        attn = einops.rearrange(attn, 'b h s d->s b (h d)', h = self.num_heads)
        output = self.out_proj(attn)
        return output

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)
        
        nn.init.xavier_uniform_(self.out_proj.weight)
        # if self.out_proj.bias is not None:
        #     nn.init.constant_(self.out_proj.bias, 0.0)
        # if self.bias_k is not None:
        #     nn.init.xavier_normal_(self.bias_k)
        # if self.bias_v is not None:
        #     nn.init.xavier_normal_(self.bias_v)

In [136]:
test = MultiHeadedAttention(embed_dim=10, num_heads=1,self_attention=True)

In [137]:
test.k_proj.state_dict()['weight']
#print(test.k_proj.state_dict()['bias'])

tensor([[-0.3738, -0.1501,  0.1794, -0.3177, -0.3394, -0.3143,  0.0425,  0.2653,
          0.2233, -0.3119],
        [ 0.0346,  0.2435, -0.2388,  0.2066, -0.2687, -0.0800, -0.1075,  0.2745,
         -0.1055, -0.1391],
        [-0.2789,  0.2672, -0.1322,  0.0066, -0.0223,  0.2218,  0.1279,  0.3076,
          0.0131, -0.1986],
        [-0.3606, -0.0517,  0.1176, -0.3512, -0.0435,  0.3113, -0.3524,  0.0343,
         -0.0336, -0.1623],
        [ 0.0200, -0.0513,  0.1541,  0.3259,  0.1133, -0.0918, -0.0164, -0.2219,
         -0.2189,  0.2120],
        [-0.2219,  0.2433,  0.1217, -0.1107, -0.0285,  0.3464, -0.0170, -0.2445,
         -0.2170,  0.0898],
        [-0.1868, -0.1109, -0.3041,  0.2479, -0.2599,  0.0785, -0.2531,  0.1866,
          0.3535, -0.2657],
        [-0.3117,  0.1409, -0.1129,  0.3302,  0.0410,  0.0379, -0.3021, -0.1481,
         -0.1434, -0.0277],
        [ 0.1617,  0.2814, -0.1642,  0.2836,  0.2458,  0.2910,  0.1934,  0.3113,
         -0.0816, -0.2578],
        [-0.0840,  

In [144]:
input = torch.zeros((10,1,10))
input[0,0,0] = 1
input[3,0,0] = 1
input[6,0,5] = 1
input

tensor([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

In [145]:
output = test.forward(input, input, input)

In [146]:
print(output)

tensor([[[ 0.0608, -0.0304, -0.0043,  0.0067,  0.0411,  0.1107,  0.0527,
           0.0396, -0.0316,  0.0286]],

        [[ 0.0602, -0.0298, -0.0052,  0.0073,  0.0416,  0.1104,  0.0523,
           0.0391, -0.0310,  0.0280]],

        [[ 0.0602, -0.0298, -0.0052,  0.0073,  0.0416,  0.1104,  0.0523,
           0.0391, -0.0310,  0.0280]],

        [[ 0.0608, -0.0304, -0.0043,  0.0067,  0.0411,  0.1107,  0.0527,
           0.0396, -0.0316,  0.0286]],

        [[ 0.0602, -0.0298, -0.0052,  0.0073,  0.0416,  0.1104,  0.0523,
           0.0391, -0.0310,  0.0280]],

        [[ 0.0602, -0.0298, -0.0052,  0.0073,  0.0416,  0.1104,  0.0523,
           0.0391, -0.0310,  0.0280]],

        [[ 0.0600, -0.0296, -0.0055,  0.0075,  0.0418,  0.1104,  0.0522,
           0.0390, -0.0308,  0.0278]],

        [[ 0.0602, -0.0298, -0.0052,  0.0073,  0.0416,  0.1104,  0.0523,
           0.0391, -0.0310,  0.0280]],

        [[ 0.0602, -0.0298, -0.0052,  0.0073,  0.0416,  0.1104,  0.0523,
           0.0391, -0.0

In [124]:
input

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,

In [163]:
class MultiHeadedAttention_ESM2(nn.Module):
    def __init__(
        self,
        embed_dim,                                                       # d, embedding dim
        num_heads,                                                       # number of heads
        kdim = None,                                                     # d except in first layer or query != key
        vdim = None,                                                     # d except in first layer (i think?)
        dropout=0.0,                                                     # dropout
        bias = False,                                                    # whether to add bias or not
        self_attention: bool = False,                                        #
    
#        add_bias_kv: bool = False,                                       #
#        add_zero_attn: bool = False,                                     # 
#        encoder_decoder_attention: bool = False,                         # 
#        use_rotary_embeddings: bool = False,                             #
    ):
        super().__init__()                                               # necessary to have MHA be able to call functions from nn.Module
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim              
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and \
            self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout                                           # dropout, randomly select a few nodes to drop during training
        self.head_dim = embed_dim // num_heads
        assert (                                                         # head_dim = embed_dim/num_heads
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self.self_attention = self_attention
        #self.encoder_decoder_attention = encoder_decoder_attention      # not going to use this ever for myself i think
        assert not self.self_attention or self.qkv_same_dim, (           
            "Self-attention requires query, key "+\
            "and " "value to be of the same size"
        )

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)          # Q,K,V projection
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)        # linear layer after attention 
        
        if add_bias_kv:                                                   # whether to add biases to KV matrices
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None
        
        self.add_zero_attn = add_zero_attn
        self.reset_parameters()                                           # initialize weights
        self.onnx_trace = False
        self.rot_emb = None
        if use_rotary_embeddings:
            self.rot_emb = RotaryEmbedding(dim=self.head_dim)

        self.enable_torch_version = False
        if hasattr(F, "multi_head_attention_forward"):
            self.enable_torch_version = True
        else:
            self.enable_torch_version = False
    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)
    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        #need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        #need_head_weights: bool = False,
    ):
        """
        input shape: (seqlen, batch, encode_dim)
        key_padding_mask: (batch, seqlen) where padding elements are 1
        attn_mask: usually for causal attention, maybe we use this for packing multiple seqs
        """
        #if need_head_weights:
        #    need_weights = True

        seqlen, batchsz, embed_dim = query.size()
        