In [None]:
# default_exp models.layers.attention

# Attention Layers
> Implementations of attention modules, including multi-head attention, transformer blocks and distribution-based self-attention.

In [None]:
#hide
from nbdev.showdoc import *
from fastcore.nb_imports import *
from fastcore.test import *

In [None]:
#export
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from recohut.utils.distances import wasserstein_distance_matmul

## TokenEmbedding

In [None]:
#export
class TokenEmbedding(nn.Embedding):
    """Token/item embedding table whose row 0 is reserved for padding.

    Args:
        vocab_size: number of rows in the table (including the padding row 0).
        embed_size: width of each embedding vector.
    """
    def __init__(self, vocab_size, embed_size=512):
        # padding_idx=0: row 0 is the padding vector (see nn.Embedding docs)
        super(TokenEmbedding, self).__init__(num_embeddings=vocab_size,
                                             embedding_dim=embed_size,
                                             padding_idx=0)

## PositionalEmbedding

In [None]:
#export
class PositionalEmbedding(nn.Module):
    """Learned position embeddings for left-padded id sequences.

    Non-padding tokens (id > 0) are numbered 1..n from the first real token to
    the last column, where n is the number of real tokens in the row. Padding
    tokens get index 0. The table has ``max_len + 1`` rows, so index ``max_len``
    is valid.

    NOTE(review): the arithmetic assumes padding sits at the start of each row;
    right- or mid-padded rows can produce negative indices.
    """
    def __init__(self, max_len, d_model):
        super().__init__()
        self.d_model = d_model
        self.pe = nn.Embedding(max_len + 1, d_model)

    def forward(self, x):
        # x is indexed as (batch, seq_len); 0 marks padding
        is_item = x > 0
        n_items = is_item.sum(dim=-1, keepdim=True)  # (batch, 1)
        seq_len = x.size(1)
        # -(seq_len-1) at column 0 ... 0 at the last column
        offsets = torch.arange(start=-(seq_len - 1), end=1, step=1, device=x.device)
        positions = (n_items + offsets) * is_item
        return self.pe(positions)

In [None]:
#export
class GELU(nn.Module):
    """Gaussian Error Linear Unit, tanh approximation.

    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    """
    def forward(self, x):
        cubic_term = x + 0.044715 * torch.pow(x, 3)
        gate = 1 + torch.tanh(math.sqrt(2 / math.pi) * cubic_term)
        return 0.5 * x * gate

In [None]:
#export
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied over the last dimension.

    ``d_model -> d_ff -> d_model`` with a GELU between the two linear layers.
    """
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.activation = GELU()

    def forward(self, x):
        expanded = self.activation(self.w_1(x))
        return self.w_2(expanded)

In [None]:
#export
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable gain and bias.

    Differs from ``torch.nn.LayerNorm``:
    - it divides by the *unbiased* standard deviation (``Tensor.std`` default);
    - ``eps`` is added to the std rather than to the variance.

    Args:
        features: size of the normalised (last) dimension.
        eps: small constant added to the standard deviation.
    """
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        denom = x.std(-1, keepdim=True) + self.eps
        return self.weight * centered / denom + self.bias

In [None]:
#export
class SublayerConnection(nn.Module):
    """Residual wrapper around a sub-layer, post-norm (BERT4Rec) arrangement.

    Computes ``LayerNorm(x + Dropout(sublayer(x)))``. The pre-norm alternative
    from the original Transformer code would be
    ``x + Dropout(sublayer(LayerNorm(x)))``.

    Args:
        size: width of the last dimension, used by the layer norm.
        dropout: dropout probability applied to the sub-layer output.
    """
    def __init__(self, size, dropout):
        super().__init__()
        self.layer_norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        branch = self.dropout(sublayer(x))
        return self.layer_norm(x + branch)

## Attention

In [None]:
#export
class Attention(nn.Module):
    """Scaled dot-product attention with optional padding and causal masks.

    Call arguments:
        query, key, value: tensors whose last two axes are (length, dim).
        mask: optional tensor broadcastable to the score shape. Positions where
            it equals 0 get a score of -1e9.
        dropout: optional callable (e.g. ``nn.Dropout``) applied to the weights.
        sas: when True, also blocks attention to later key positions, keeping
            the lower triangle only (SASRec-style causal mask).

    Returns:
        ``(weighted_values, attention_weights)``.
    """
    def forward(self, query, key, value, mask=None, dropout=None, sas=False):
        d_k = query.size(-1)
        logits = query.matmul(key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            logits = logits.masked_fill(mask == 0, -1e9)

        if sas:
            n_q, n_k = logits.shape[-2:]
            # True strictly above the diagonal, i.e. "future" keys
            future = torch.ones(n_q, n_k, device=logits.device).tril() == 0
            logits = logits.masked_fill(future, -1e9)

        weights = F.softmax(logits, dim=-1)
        if dropout is not None:
            weights = dropout(weights)

        return weights.matmul(value), weights

## ScaledDotProductAttention

In [None]:
#export
class ScaledDotProductAttention(nn.Module):
    """ Scaled Dot-Product Attention: ``softmax(Q K^T / scale) V``.
    Ref: https://zhuanlan.zhihu.com/p/47812375

    Accepts any leading batch dimensions, e.g. (B, L, d) or (B, H, L, d). Keys
    are transposed on their last two axes and the softmax runs over the last
    axis.

    Args:
        dropout_rate: dropout probability on the attention weights. No dropout
            module is created when it is 0.
    """
    def __init__(self, dropout_rate=0.):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = None
        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, W_q, W_k, W_v, scale=None, mask=None):
        """
        Args:
            W_q, W_k, W_v: query, key and value tensors.
            scale: optional divisor for the raw scores; skipped when falsy.
            mask: optional boolean tensor broadcastable to the score shape.
                Positions where it is True are excluded from attention. A row
                that is fully masked yields NaN weights.

        Returns:
            ``(output, attention)``: the weighted values and the attention weights.
        """
        attention = torch.matmul(W_q, W_k.transpose(-2, -1))
        if scale:
            attention = attention / scale
        # Explicit None check: truth-testing a multi-element tensor raises.
        if mask is not None:
            attention = attention.masked_fill(mask, float("-inf"))
        attention = self.softmax(attention)
        if self.dropout is not None:
            attention = self.dropout(attention)
        output = torch.matmul(attention, W_v)
        return output, attention

## MultiHeadAttention

In [None]:
#export
class MultiHeadAttention(nn.Module):
    """ Multi-head attention module (FuxiCTR / AutoInt style).

    Args:
        input_dim: last-dimension size of ``query``, ``key`` and ``value``.
        attention_dim: per-head width; defaults to ``input_dim // num_heads``.
        num_heads: number of attention heads.
        dropout_rate: dropout on the attention weights and on the output.
        use_residual: add the query (projected if needed) onto the output.
        use_scale: divide scores by ``sqrt(attention_dim)``.
        layer_norm: apply ``nn.LayerNorm`` after the residual add.
        align_to: only used when ``input_dim != num_heads * attention_dim``.
            - ``"output"`` projects the residual up to the attention width
              (AutoInt style).
            - ``"input"`` projects the attention output back to ``input_dim``
              (Transformer style).

    Raises:
        ValueError: if the widths differ and ``align_to`` is neither value.
    """

    def __init__(self, input_dim, attention_dim=None, num_heads=1, dropout_rate=0., 
                 use_residual=True, use_scale=False, layer_norm=False, align_to="input"):
        super(MultiHeadAttention, self).__init__()
        if attention_dim is None:
            attention_dim = input_dim // num_heads
        self.attention_dim = attention_dim
        self.output_dim = num_heads * attention_dim
        self.num_heads = num_heads
        self.use_residual = use_residual
        self.align_to = align_to
        self.scale = attention_dim ** 0.5 if use_scale else None
        self.W_q = nn.Linear(input_dim, self.output_dim, bias=False)
        self.W_k = nn.Linear(input_dim, self.output_dim, bias=False)
        self.W_v = nn.Linear(input_dim, self.output_dim, bias=False)
        # Width of the tensor leaving forward(); the LayerNorm has to match it.
        final_dim = self.output_dim
        if input_dim != self.output_dim:
            if align_to == "output":
                self.W_res = nn.Linear(input_dim, self.output_dim, bias=False)
            elif align_to == "input":
                self.W_res = nn.Linear(self.output_dim, input_dim, bias=False)
                final_dim = input_dim
            else:
                raise ValueError(
                    f"align_to must be 'input' or 'output', got {align_to!r}")
        else:
            self.W_res = None
        self.dot_product_attention = ScaledDotProductAttention(dropout_rate)
        self.layer_norm = nn.LayerNorm(final_dim) if layer_norm else None
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else None

    def _split_heads(self, x, batch_size):
        """(batch, L, H*d) -> (batch*H, L, d); rows ordered batch-major (b*H + h)."""
        return (x.view(batch_size, -1, self.num_heads, self.attention_dim)
                .transpose(1, 2)
                .reshape(batch_size * self.num_heads, -1, self.attention_dim))

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query, key, value: assumed (batch, length, input_dim); key and value
                must share a length. TODO confirm for callers with other ranks.
            mask: optional mask passed to the dot-product attention.
                - A 3-D (batch, Lq, Lk) mask is repeated once per head.
                - Other shapes are passed through unchanged and must broadcast
                  against (batch * num_heads, Lq, Lk).

        Returns:
            ``(output, attention)``:
            - ``output`` is ReLU-activated.
            - ``attention`` is (batch * num_heads, Lq, Lk), where row
              ``b * num_heads + h`` holds head ``h`` of sample ``b``.
        """
        residual = query

        # linear projection
        query = self.W_q(query)
        key = self.W_k(key)
        value = self.W_v(value)

        # Split by heads. Heads must be peeled off the feature axis. A bare
        # view(batch*H, -1, d) would regroup sequence positions into "heads"
        # and mix tokens together.
        batch_size = query.size(0)
        query = self._split_heads(query, batch_size)
        key = self._split_heads(key, batch_size)
        value = self._split_heads(value, batch_size)
        if mask is not None and mask.dim() == 3:
            # match the batch-major (b*H + h) ordering produced by _split_heads
            mask = mask.repeat_interleave(self.num_heads, dim=0)
        # scaled dot product attention
        output, attention = self.dot_product_attention(query, key, value, self.scale, mask)
        # concat heads: (batch*H, L, d) -> (batch, L, H*d)
        output = (output.view(batch_size, self.num_heads, -1, self.attention_dim)
                  .transpose(1, 2)
                  .reshape(batch_size, -1, self.output_dim))
        # align residual and output widths
        if self.W_res is not None:
            if self.align_to == "output":  # AutoInt style
                residual = self.W_res(residual)
            elif self.align_to == "input":  # Transformer style
                output = self.W_res(output)
        if self.dropout is not None:
            output = self.dropout(output)
        if self.use_residual:
            output = output + residual
        if self.layer_norm is not None:
            output = self.layer_norm(output)
        output = output.relu()
        return output, attention

## MultiHeadedAttention_v2

In [None]:
#export
class MultiHeadedAttention_v2(nn.Module):
    """Multi-head attention with query/key/value projections and an output projection.

    Args:
        h: number of heads; must divide ``d_model``.
        d_model: input and output width.
        head_size: per-head width; defaults to ``d_model // h``.
        dropout: dropout probability applied to the attention weights.
    """
    def __init__(self, h, d_model, head_size=None, dropout=0.1):
        super().__init__()
        assert d_model % h == 0

        self.h = h
        self.d_k = d_model // h
        self.head_size = d_model // h if head_size is None else head_size
        inner_dim = self.h * self.head_size

        self.linear_layers = nn.ModuleList(nn.Linear(d_model, inner_dim) for _ in range(3))
        self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout)
        self.output_linear = nn.Linear(inner_dim, d_model)

    def forward(self, query, key, value, mask=None):
        n_batch = query.size(0)

        def to_heads(projection, tensor):
            # (batch, len, d_model) -> (batch, h, len, head_size)
            return projection(tensor).view(n_batch, -1, self.h, self.head_size).transpose(1, 2)

        q_proj, k_proj, v_proj = self.linear_layers
        q_heads = to_heads(q_proj, query)
        k_heads = to_heads(k_proj, key)
        v_heads = to_heads(v_proj, value)

        context, _ = self.attention(q_heads, k_heads, v_heads, mask=mask, dropout=self.dropout)

        # (batch, h, len, head_size) -> (batch, len, h * head_size)
        merged = context.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.head_size)
        return self.output_linear(merged)

## MultiHeadSelfAttention

In [None]:
#export
class MultiHeadSelfAttention(MultiHeadAttention):
    """Self-attention wrapper: uses ``X`` as query, key and value.

    Returns only the attended output; the attention weights are discarded.
    """
    def forward(self, X):
        attended, _ = super().forward(X, X, X)
        return attended

## TransformerBlock

In [None]:
#export
class TransformerBlock(nn.Module):
    """BERT4Rec-style transformer encoder block (post-norm).

    Multi-head self-attention is followed by a position-wise feed-forward network.
    Each step is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        hidden: model width.
        attn_heads: number of attention heads.
        head_size: per-head width (``None`` means ``hidden // attn_heads``).
        feed_forward_hidden: inner width of the feed-forward network.
        dropout: dropout used in the residual connections.
        attn_dropout: dropout applied to the attention weights.
    """
    def __init__(self, hidden, attn_heads, head_size, feed_forward_hidden, dropout, attn_dropout=0.1):
        super().__init__()
        # MultiHeadedAttention_v2 is the (h, d_model, head_size, dropout) attention
        # class defined in this module; no `MultiHeadedAttention` exists here.
        self.attention = MultiHeadedAttention_v2(
            h=attn_heads, d_model=hidden, head_size=head_size, dropout=attn_dropout)
        self.feed_forward = PositionwiseFeedForward(
            d_model=hidden, d_ff=feed_forward_hidden)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)

    def forward(self, x, mask):
        x = self.input_sublayer(
            x, lambda _x: self.attention(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        return x

## SqueezeExcitationLayer

In [None]:
#export
class SqueezeExcitationLayer(nn.Module):
    """SENET-style re-weighting of field embeddings.

    Steps:
    1. Each field embedding is squeezed to a scalar by averaging its last
       dimension.
    2. The resulting per-field vector goes through a bias-free bottleneck MLP:
       ``num_fields -> max(1, int(num_fields / reduction_ratio)) -> num_fields``,
       with a ReLU after each layer.
    3. Every field embedding is scaled by its weight.

    Input is indexed as (..., num_fields, embedding_dim).
    """
    def __init__(self, num_fields, reduction_ratio=3):
        super(SqueezeExcitationLayer, self).__init__()
        bottleneck = max(1, int(num_fields / reduction_ratio))
        layers = [
            nn.Linear(num_fields, bottleneck, bias=False),
            nn.ReLU(),
            nn.Linear(bottleneck, num_fields, bias=False),
            nn.ReLU(),
        ]
        self.excitation = nn.Sequential(*layers)

    def forward(self, feature_emb):
        squeezed = feature_emb.mean(dim=-1)
        field_weights = self.excitation(squeezed)
        return feature_emb * field_weights.unsqueeze(-1)

## SASMultiHeadedAttention

In [None]:
#export
class SASMultiHeadedAttention(nn.Module):
    """Causal multi-head attention used by the SASRec-style transformer block.

    Forward pass:
    1. Project ``query``/``key``/``value`` with separate linear layers.
    2. Split them into ``h`` heads.
    3. Run ``Attention`` with ``sas=True`` (lower-triangular mask over key
       positions).
    4. Merge the heads and return ``LayerNorm(merged + query)``.

    There is no output projection, so the residual add needs
    ``h * head_size`` to equal the query's last dimension.

    Args:
        h: number of heads; must divide ``d_model``.
        d_model: model width.
        head_size: per-head width; defaults to ``d_model // h``.
        dropout: dropout probability applied to the attention weights.
    """
    def __init__(self, h, d_model, head_size=None, dropout=0.1):
        super().__init__()
        assert d_model % h == 0

        self.h = h
        self.d_k = d_model // h
        self.head_size = d_model // h if head_size is None else head_size

        projected = self.h * self.head_size
        self.linear_layers = nn.ModuleList(nn.Linear(d_model, projected) for _ in range(3))
        self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = LayerNorm(d_model)

    def forward(self, query, key, value, mask=None):
        n_batch = query.size(0)

        heads = []
        for projection, tensor in zip(self.linear_layers, (query, key, value)):
            # (batch, len, d_model) -> (batch, h, len, head_size)
            heads.append(projection(tensor).view(n_batch, -1, self.h, self.head_size).transpose(1, 2))
        q_heads, k_heads, v_heads = heads

        context, _ = self.attention(q_heads, k_heads, v_heads,
                                    mask=mask, dropout=self.dropout, sas=True)

        merged = context.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.head_size)
        return self.layer_norm(merged + query)

In [None]:
#export
class SASPositionwiseFeedForward(nn.Module):
    """SASRec point-wise feed-forward block with residual connection and layer norm.

    Two ``kernel_size=1`` convolutions (``d_model -> d_ff -> d_model``) act
    per position, with ReLU and dropout between them. Dropout is also applied
    to the second convolution's output before the residual add and the
    LayerNorm. Input is indexed as (batch, len, d_model) and is transposed to
    channels-first for ``Conv1d``.
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.layer_norm = LayerNorm(d_model)

    def forward(self, x):
        channels_first = x.transpose(1, 2)
        hidden = self.dropout(self.activation(self.conv1(channels_first)))
        projected = self.dropout(self.conv2(hidden)).transpose(1, 2)
        return self.layer_norm(projected + x)

## SASTransformerBlock

In [None]:
#export
class SASTransformerBlock(nn.Module):
    """SASRec-style block: causal self-attention, then a convolutional feed-forward.

    The attention query is the layer-normalised input, while keys and values
    use the raw input. The attention and feed-forward modules each apply their
    own residual add and LayerNorm internally.
    """
    def __init__(self, hidden, attn_heads, head_size, feed_forward_hidden, dropout, attn_dropout=0.1):
        super().__init__()
        self.layer_norm = LayerNorm(hidden)
        self.attention = SASMultiHeadedAttention(h=attn_heads, d_model=hidden,
                                                 head_size=head_size, dropout=attn_dropout)
        self.feed_forward = SASPositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden,
                                                       dropout=dropout)

    def forward(self, x, mask):
        normed_query = self.layer_norm(x)
        attended = self.attention(normed_query, x, x, mask)
        return self.feed_forward(attended)

## SelfAttention

In [None]:
#export
class SelfAttention(nn.Module):
    """Multi-head self-attention with additive mask, output projection,
    dropout, residual connection and post-LayerNorm.

    References:
        1. https://github.com/RecoHut-Stanzas/STOSA/blob/ee14e2eabcc60922eb52cc7d3231df4954d9ff16/modules.py#L127

    Args:
        hidden_size: input/output width; must be divisible by ``num_attention_heads``.
        num_attention_heads: number of heads.
        attention_probs_dropout_prob: dropout on the attention weights.
        hidden_dropout_prob: dropout on the projected output before the residual add.

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of ``num_attention_heads``.
    """
    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 hidden_dropout_prob):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # projection layers, created in a fixed order (query, key, value, ...)
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.attn_dropout = nn.Dropout(attention_probs_dropout_prob)

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.layernorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.out_dropout = nn.Dropout(hidden_dropout_prob)

    def transpose_for_scores(self, x):
        """(batch, len, all_head_size) -> (batch, heads, len, head_size)."""
        split = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return split.permute(0, 2, 1, 3)

    def forward(self, input_tensor, attention_mask):
        """
        Args:
            input_tensor: indexed as (batch, seq_len, hidden_size).
            attention_mask: additive mask, broadcast-added to the scaled
                (batch, heads, seq_len, seq_len) scores.

        Returns:
            ``(hidden_states, attention_probs)``. ``attention_probs`` is the
            tensor after attention dropout.
        """
        q = self.transpose_for_scores(self.query(input_tensor))
        k = self.transpose_for_scores(self.key(input_tensor))
        v = self.transpose_for_scores(self.value(input_tensor))

        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        scores = scores + attention_mask
        # dropout on the weights removes whole keys from a query's view, as in
        # the original Transformer
        probs = self.attn_dropout(F.softmax(scores, dim=-1))

        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.all_head_size)

        projected = self.out_dropout(self.dense(context))
        return self.layernorm(projected + input_tensor), probs

In [None]:
hidden_size = 4
num_attention_heads = 2
hidden_dropout_prob = 0.2
attention_probs_dropout_prob = 0.2

# Use keywords: SelfAttention takes attention_probs_dropout_prob *before*
# hidden_dropout_prob, so positional passing silently swaps the two rates.
layer = SelfAttention(hidden_size=hidden_size,
                      num_attention_heads=num_attention_heads,
                      attention_probs_dropout_prob=attention_probs_dropout_prob,
                      hidden_dropout_prob=hidden_dropout_prob)

input_tensor = torch.rand((2, 4, 4))
attention_mask = torch.rand((4, 4))

# A single forward pass, so both outputs come from the same dropout sample.
hidden_states, attention_probs = (torch.round(t.detach() * 1e4) / 1e4
                                  for t in layer(input_tensor, attention_mask))

test_eq(hidden_states.shape.numel(), 32)
test_eq(list(hidden_states.shape), [2, 4, 4])

test_eq(attention_probs.shape.numel(), 64)
test_eq(list(attention_probs.shape), [2, 2, 4, 4])

## DistSelfAttention

In [None]:
#export
class DistSelfAttention(nn.Module):
    """Distribution-based self-attention over mean and covariance embeddings (STOSA).

    Scoring:
    - Attention scores are the negated distance between query and key
      distributions. Wasserstein is used when ``distance_metric ==
      'wasserstein'``; any other value uses KL.
    - Scores are divided by ``sqrt(head_size)`` and an additive
      ``attention_mask`` is added before the softmax.

    Aggregation:
    - The weights aggregate the mean values.
    - The *squared* weights aggregate the covariance values.
    - Each stream then passes through its own dense layer, dropout and
      residual add, followed by one LayerNorm shared by both streams.

    Args:
        hidden_size: width of the mean/covariance inputs; must be divisible by
            ``num_attention_heads``.
        num_attention_heads: number of heads.
        hidden_dropout_prob: dropout on the dense outputs.
        attention_probs_dropout_prob: dropout on the attention weights.
        distance_metric: ``'wasserstein'`` (default); any other value selects KL.

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of ``num_attention_heads``.
    """
    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 hidden_dropout_prob,
                 attention_probs_dropout_prob,
                 distance_metric = 'wasserstein'):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.mean_query = nn.Linear(hidden_size, self.all_head_size)
        self.cov_query = nn.Linear(hidden_size, self.all_head_size)
        self.mean_key = nn.Linear(hidden_size, self.all_head_size)
        self.cov_key = nn.Linear(hidden_size, self.all_head_size)
        self.mean_value = nn.Linear(hidden_size, self.all_head_size)
        self.cov_value = nn.Linear(hidden_size, self.all_head_size)

        # ELU(.) + 1 is applied to the covariance projections in forward()
        self.activation = nn.ELU()

        self.attn_dropout = nn.Dropout(attention_probs_dropout_prob)
        self.mean_dense = nn.Linear(hidden_size, hidden_size)
        self.cov_dense = nn.Linear(hidden_size, hidden_size)
        self.out_dropout = nn.Dropout(hidden_dropout_prob)

        self.distance_metric = distance_metric
        self.layernorm = nn.LayerNorm(hidden_size, eps=1e-12)
        
    def transpose_for_scores(self, x):
        """(batch, len, all_head_size) -> (batch, heads, len, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, input_mean_tensor, input_cov_tensor, attention_mask):
        """
        Args:
            input_mean_tensor, input_cov_tensor: indexed as (batch, seq_len, hidden_size).
            attention_mask: additive mask broadcast-added to the scaled scores.

        Returns:
            ``(mean_hidden_states, cov_hidden_states, attention_probs)``.
            ``attention_probs`` is the tensor after dropout.
        """
        mixed_mean_query_layer = self.mean_query(input_mean_tensor)
        mixed_mean_key_layer = self.mean_key(input_mean_tensor)
        mixed_mean_value_layer = self.mean_value(input_mean_tensor)

        mean_query_layer = self.transpose_for_scores(mixed_mean_query_layer)
        mean_key_layer = self.transpose_for_scores(mixed_mean_key_layer)
        mean_value_layer = self.transpose_for_scores(mixed_mean_value_layer)

        # ELU output is > -1, so +1 keeps the covariance projections non-negative
        mixed_cov_query_layer = self.activation(self.cov_query(input_cov_tensor)) + 1
        mixed_cov_key_layer = self.activation(self.cov_key(input_cov_tensor)) + 1
        mixed_cov_value_layer = self.activation(self.cov_value(input_cov_tensor)) + 1

        cov_query_layer = self.transpose_for_scores(mixed_cov_query_layer)
        cov_key_layer = self.transpose_for_scores(mixed_cov_key_layer)
        cov_value_layer = self.transpose_for_scores(mixed_cov_value_layer)

        if self.distance_metric == 'wasserstein':
            attention_scores = -wasserstein_distance_matmul(mean_query_layer, cov_query_layer, mean_key_layer, cov_key_layer)
        else:
            # kl_distance_matmul is not imported at module level. Import it
            # lazily so only this branch depends on it.
            # NOTE(review): assumes recohut.utils.distances provides it — confirm.
            from recohut.utils.distances import kl_distance_matmul
            attention_scores = -kl_distance_matmul(mean_query_layer, cov_query_layer, mean_key_layer, cov_key_layer)

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask
        attention_probs = F.softmax(attention_scores, dim=-1)

        attention_probs = self.attn_dropout(attention_probs)
        mean_context_layer = torch.matmul(attention_probs, mean_value_layer)
        # covariance values are aggregated with squared weights
        cov_context_layer = torch.matmul(attention_probs ** 2, cov_value_layer)
        mean_context_layer = mean_context_layer.permute(0, 2, 1, 3).contiguous()
        cov_context_layer = cov_context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = mean_context_layer.size()[:-2] + (self.all_head_size,)

        mean_context_layer = mean_context_layer.view(*new_context_layer_shape)
        cov_context_layer = cov_context_layer.view(*new_context_layer_shape)

        mean_hidden_states = self.mean_dense(mean_context_layer)
        mean_hidden_states = self.out_dropout(mean_hidden_states)
        mean_hidden_states = self.layernorm(mean_hidden_states + input_mean_tensor)

        cov_hidden_states = self.cov_dense(cov_context_layer)
        cov_hidden_states = self.out_dropout(cov_hidden_states)
        cov_hidden_states = self.layernorm(cov_hidden_states + input_cov_tensor)

        return mean_hidden_states, cov_hidden_states, attention_probs

In [None]:
hidden_size = 4
num_attention_heads = 2
hidden_dropout_prob = 0.2
attention_probs_dropout_prob = 0.2
distance = 'wasserstein'

# Explicit keywords make it obvious which dropout rate goes where.
layer = DistSelfAttention(hidden_size=hidden_size,
                          num_attention_heads=num_attention_heads,
                          hidden_dropout_prob=hidden_dropout_prob,
                          attention_probs_dropout_prob=attention_probs_dropout_prob,
                          distance_metric=distance)

input_tensor = torch.rand((2, 4, 4))
attention_mask = torch.rand((4, 4))

output = layer(input_tensor, input_tensor, attention_mask)
# Detach and round every output to 4 decimals before checking shapes.
mean_hidden_states, cov_hidden_states, attention_probs = (
    torch.round(t.detach() * 1e4) / 1e4 for t in output)

test_eq(mean_hidden_states.shape.numel(), 32)
test_eq(list(mean_hidden_states.shape), [2, 4, 4])

test_eq(cov_hidden_states.shape.numel(), 32)
test_eq(list(cov_hidden_states.shape), [2, 4, 4])

test_eq(attention_probs.shape.numel(), 64)
test_eq(list(attention_probs.shape), [2, 2, 4, 4])

## DistMeanSelfAttention

In [None]:
#export
class DistMeanSelfAttention(nn.Module):
    """Two-stream self-attention over mean and covariance embeddings.

    Each stream is scored with an ordinary scaled dot product plus the additive
    ``attention_mask``:
    - mean queries/keys weight the mean values;
    - covariance queries/keys weight the covariance values.

    Each stream then gets its own dense layer, dropout and residual add, and
    both share one LayerNorm.

    Args:
        hidden_size: input width; must be divisible by ``num_attention_heads``.
        num_attention_heads: number of heads.
        attention_probs_dropout_prob: dropout on the attention weights.
        hidden_dropout_prob: dropout on the dense outputs.

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of ``num_attention_heads``.
    """
    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_probs_dropout_prob,
                 hidden_dropout_prob):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        def projection():
            return nn.Linear(hidden_size, self.all_head_size)

        # creation order determines parameter initialisation under a fixed seed
        self.mean_query = projection()
        self.mean_key = projection()
        self.mean_value = projection()
        self.cov_key = projection()
        self.cov_query = projection()
        self.cov_value = projection()

        self.activation = nn.ELU()

        self.attn_dropout = nn.Dropout(attention_probs_dropout_prob)
        self.mean_dense = nn.Linear(hidden_size, hidden_size)
        self.cov_dense = nn.Linear(hidden_size, hidden_size)
        self.out_dropout = nn.Dropout(hidden_dropout_prob)

        self.layernorm = nn.LayerNorm(hidden_size, eps=1e-12)

    def transpose_for_scores(self, x):
        """(batch, len, all_head_size) -> (batch, heads, len, head_size)."""
        split = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return split.permute(0, 2, 1, 3)

    def _attention_weights(self, queries, keys, attention_mask):
        # scaled dot-product scores plus the additive mask, softmaxed over keys
        logits = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        return F.softmax(logits + attention_mask, dim=-1)

    def _merge_heads(self, context):
        # (batch, heads, len, head_size) -> (batch, len, all_head_size)
        context = context.permute(0, 2, 1, 3).contiguous()
        return context.view(*context.size()[:-2], self.all_head_size)

    def _positive_cov(self, linear, tensor):
        # ELU output is > -1, so +1 keeps the covariance projection non-negative
        return self.transpose_for_scores(self.activation(linear(tensor)) + 1)

    def forward(self, input_mean_tensor, input_cov_tensor, attention_mask):
        """
        Args:
            input_mean_tensor, input_cov_tensor: indexed as (batch, seq_len, hidden_size).
            attention_mask: additive mask broadcast-added to the scaled scores.

        Returns:
            ``(mean_hidden_states, cov_hidden_states, mean_attention_probs)``.
            The mean attention weights are returned after dropout.
        """
        mean_q = self.transpose_for_scores(self.mean_query(input_mean_tensor))
        mean_k = self.transpose_for_scores(self.mean_key(input_mean_tensor))
        mean_v = self.transpose_for_scores(self.mean_value(input_mean_tensor))

        cov_q = self._positive_cov(self.cov_query, input_cov_tensor)
        cov_k = self._positive_cov(self.cov_key, input_cov_tensor)
        cov_v = self._positive_cov(self.cov_value, input_cov_tensor)

        # dropout order: mean weights first, then covariance weights
        mean_probs = self.attn_dropout(self._attention_weights(mean_q, mean_k, attention_mask))
        cov_probs = self.attn_dropout(self._attention_weights(cov_q, cov_k, attention_mask))

        mean_context = self._merge_heads(torch.matmul(mean_probs, mean_v))
        cov_context = self._merge_heads(torch.matmul(cov_probs, cov_v))

        mean_hidden_states = self.layernorm(
            self.out_dropout(self.mean_dense(mean_context)) + input_mean_tensor)
        cov_hidden_states = self.layernorm(
            self.out_dropout(self.cov_dense(cov_context)) + input_cov_tensor)

        return mean_hidden_states, cov_hidden_states, mean_probs

In [None]:
hidden_size = 4
num_attention_heads = 2
hidden_dropout_prob = 0.2
attention_probs_dropout_prob = 0.2

# Use keywords: DistMeanSelfAttention takes attention_probs_dropout_prob *before*
# hidden_dropout_prob, so positional passing silently swaps the two rates.
layer = DistMeanSelfAttention(hidden_size=hidden_size,
                              num_attention_heads=num_attention_heads,
                              attention_probs_dropout_prob=attention_probs_dropout_prob,
                              hidden_dropout_prob=hidden_dropout_prob)

input_tensor = torch.rand((2, 4, 4))
attention_mask = torch.rand((4, 4))

output = layer(input_tensor, input_tensor, attention_mask)
# Detach and round every output to 4 decimals before checking shapes.
mean_hidden_states, cov_hidden_states, attention_probs = (
    torch.round(t.detach() * 1e4) / 1e4 for t in output)

test_eq(mean_hidden_states.shape.numel(), 32)
test_eq(list(mean_hidden_states.shape), [2, 4, 4])

test_eq(cov_hidden_states.shape.numel(), 32)
test_eq(list(cov_hidden_states.shape), [2, 4, 4])

test_eq(attention_probs.shape.numel(), 64)
test_eq(list(attention_probs.shape), [2, 2, 4, 4])

> References
1. https://github.com/sparsh-ai/stanza/blob/S714864/model/attention.py
2. https://github.com/xue-pai/FuxiCTR/blob/main/fuxictr/pytorch/layers/attention.py

In [None]:
#hide
# Provenance: record author, update time/date, machine info, imported package
# versions and the recohut version (shown in the cell output below).
%reload_ext watermark
%watermark -a "Sparsh A." -m -iv -u -t -d -p recohut

Author: Sparsh A.

Last updated: 2022-01-22 15:19:18

recohut: 0.0.11

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.4.144+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit

IPython: 5.5.0
torch  : 1.10.0+cu111

