In [1]:
%matplotlib inline
import torch as tc
import tensorflow as tf
import mxnet as mx
from mxnet import np as mxnp
from mxnet import npx as npx
npx.set_np()
# import numpy as np

from d2l import mxnet as mxd2l  # Use MXNet as the backend
from d2l import torch as tcd2l  # Use PyTorch as the backend
from d2l import tensorflow as tfd2l  # Use TensorFlow as the backend

tc.__version__,tf.__version__,mx.__version__

('1.6.0', '2.3.1', '1.7.0')

In [16]:
from mxnet.gluon import nn
import math

In [17]:
def masked_softmax(X,valid_len):
    """Perform softmax by filtering out some elements."""
    # X: 3-D tensor, valid_len: 1-D or 2-D tensor
    if valid_len is None:
        return npx.softmax(X)
    else:
        shape=X.shape
        if valid_len.ndim==1:
            valid_len=valid_len.repeat(shape[1],axis=0)
        else:
            valid_len=valid_len.reshape(-1)
        # Fill masked elements with a large negative, whose exp is 0
        X=npx.sequence_mask(X.reshape(-1,shape[-1]),valid_len,True,axis=1,value=-1e6)
        return npx.softmax(X).reshape(shape)

In [18]:
# X=mxnp.random.uniform(size=(2, 2, 4))
X=mxnp.ones((2, 2, 4))
X

array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.]],

       [[1., 1., 1., 1.],
        [1., 1., 1., 1.]]])

In [19]:
valid_len=mxnp.array([2,3])

In [20]:
masked_softmax(X,valid_len)

array([[[0.5       , 0.5       , 0.        , 0.        ],
        [0.5       , 0.5       , 0.        , 0.        ]],

       [[0.33333334, 0.33333334, 0.33333334, 0.        ],
        [0.33333334, 0.33333334, 0.33333334, 0.        ]]])

In [21]:
valid_len=mxnp.array([[1,2],[3,4]])

In [22]:
masked_softmax(X,valid_len)

array([[[1.        , 0.        , 0.        , 0.        ],
        [0.5       , 0.5       , 0.        , 0.        ]],

       [[0.33333334, 0.33333334, 0.33333334, 0.        ],
        [0.25      , 0.25      , 0.25      , 0.25      ]]])

In [23]:
class DotProductAttention(nn.Block):
    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
    # `query`: (`batch_size`, #queries,  `d`)
    # `key`:   (`batch_size`, #kv_pairs, `d`)
    # `value`: (`batch_size`, #kv_pairs, `dim_v`)
    # `valid_len`: either (`batch_size`, ) or (`batch_size`, (`batch_size`* #queries))
    def forward(self, query, key, value, valid_len=None):
        d = query.shape[-1]
        # Set transpose_b=True to swap the last two dimensions of key
        scores = npx.batch_dot(query, key, transpose_b=True) / math.sqrt(d)
        attention_weights = self.dropout(masked_softmax(scores, valid_len))
        return npx.batch_dot(attention_weights, value)

In [35]:
class MLPAttention(nn.Block):
    def __init__(self,hidden_units,dropout,**kwargs):
        super().__init__(**kwargs)
        # Use flatten=False to keep query's and key's 3-D shapes
        self.W_q=nn.Dense(hidden_units,use_bias=False,flatten=False)
        self.W_k=nn.Dense(hidden_units,use_bias=False,flatten=False)
        self.v=nn.Dense(1,use_bias=False,flatten=False)
        self.dropout=nn.Dropout(dropout)

    def forward(self,query,key,value,valid_len):
        query=self.W_q(query)
        key  =self.W_k(key)
        # Expand query to (`batch_size`, #queries, 1,  units),
        #    and  key  to (`batch_size`, 1, #kv_pairs, units). Then plus them with broadcast
        feature_q=mxnp.expand_dims(query,axis=2)
        feature_k=mxnp.expand_dims(key,axis=1)
        features=feature_q+feature_k
        features=mxnp.tanh(features)
        scores=self.v(features)
        scores=mxnp.squeeze(scores,axis=-1)
        attention_weights=self.dropout(masked_softmax(scores, valid_len))
        return npx.batch_dot(attention_weights, value)

In [36]:
def DotProductAttention_unittest():
    atten = DotProductAttention(dropout=0.5)
    atten.initialize()

    query=mxnp.ones((2, 2, 2))   #(2, 2,  2)
    key=  mxnp.ones((2, 10, 2))  #(2, 10, 2)
    #                       value (2, 10, 4)
    value= mxnp.arange(40).reshape(1, 10, 4).repeat(2, axis=0)
    valid_len1 = mxnp.array([2, 6])
    valid_len2 = mxnp.array([[2, 6],[3,4]])
    print(f"Shapes:query{query.shape},key{key.shape},value{value.shape},valid_len{valid_len1.shape}")
    r=atten(query, key, value, valid_len1)
    print(r.shape)                #(2, 2, 4)
    print(f"Shapes:query{query.shape},key{key.shape},value{value.shape},valid_len{valid_len1.shape}")
    r2 = atten(query, key, value, valid_len2)
    print(r2.shape)                #(2, 2, 4)
    return r1, r2

In [37]:
r1, r2=DotProductAttention_unittest()

Shapes:query(2, 2, 2),key(2, 10, 2),value(2, 10, 4),valid_len(2,)
(2, 2, 4)
Shapes:query(2, 2, 2),key(2, 10, 2),value(2, 10, 4),valid_len(2,)
(2, 2, 4)


In [38]:
r1, r2

(array([[[ 2.       ,  3.       ,  4.       ,  5.       ],
         [10.       , 11.       , 12.       , 13.       ]],
 
        [[ 4.       ,  5.       ,  6.       ,  7.0000005],
         [ 6.       ,  7.       ,  8.       ,  9.       ]]]),
 array([[[ 2.       ,  3.       ,  4.       ,  5.       ],
         [10.       , 11.       , 12.       , 13.       ]],
 
        [[ 4.       ,  5.       ,  6.       ,  7.0000005],
         [ 6.       ,  7.       ,  8.       ,  9.       ]]]))

## query的shape可以自用变化了,valid_len的shape与query相关

In [32]:
def MLPAttention_unittest():
    atten = MLPAttention(hidden_units=8, dropout=0.1)
    atten.initialize()

    query1=mxnp.ones((2, 2, 2))    #(2, 2,  2)
    query2 = mxnp.ones((2, 3, 4))  # (2, 2, 4)  query的shape可以自用变化了
    key=  mxnp.ones((2, 10, 2))  #(2, 10, 2)
    #                       value (2, 10, 4)
    value= mxnp.arange(40).reshape(1, 10, 4).repeat(2, axis=0)

    valid_len1 = mxnp.array([2, 6])          #`valid_len`: either (`batch_size`, )
    valid_len2 = mxnp.array([[2, 6],[3,4]])  #                 or (`batch_size`, (`batch_size`* #queries))

    r1=atten(query1, key, value, valid_len2)
    print(r1.shape)

    atten = MLPAttention(hidden_units=8, dropout=0.1)
    atten.initialize()
    r2=atten(query2, key, value, valid_len1)
    print(r2.shape)
    return r1, r2

In [39]:
r1,r2=MLPAttention_unittest()

(2, 2, 4)
(2, 3, 4)


In [40]:
r1,r2

(array([[[ 2.       ,  3.       ,  4.       ,  5.       ],
         [10.       , 11.       , 12.       , 13.       ]],
 
        [[ 4.       ,  5.       ,  6.       ,  7.0000005],
         [ 6.       ,  7.       ,  8.       ,  9.       ]]]),
 array([[[ 2.,  3.,  4.,  5.],
         [ 2.,  3.,  4.,  5.],
         [ 2.,  3.,  4.,  5.]],
 
        [[10., 11., 12., 13.],
         [10., 11., 12., 13.],
         [10., 11., 12., 13.]]]))