In [10]:
import numpy as np
import torch
from torch import Tensor
from torch import nn

In [11]:
# PyTorch reference: multi-head attention built with positional args (32, 8).
# `torch.nn` / `torch.tensor` are spelled out in full because later cells rebind
# the bare names `nn` and `Tensor` to other frameworks; the qualified names keep
# this cell re-runnable regardless of which cells have executed since.
multihead_attn = torch.nn.MultiheadAttention(32, 8)

# Random float64 NumPy inputs. They are deliberately kept as NumPy arrays
# because the Flax and MindSpore cells below reuse them. Neither the inputs nor
# the layer weights are seeded, so the printed values change from run to run.
query = np.random.rand(2, 1, 32)
key = np.random.rand(2, 1, 32)
value = np.random.rand(2, 1, 32)

# NOTE(review): per the PyTorch docs the default is batch_first=False, so
# (2, 1, 32) should be read as (seq_len=2, batch=1, embed=32) here -- confirm
# that this is the layout the cross-framework comparison intends.
attn_output, attn_output_weights = multihead_attn(
    torch.tensor(query, dtype=torch.float32),
    torch.tensor(key, dtype=torch.float32),
    torch.tensor(value, dtype=torch.float32),
)
print(attn_output)
print(attn_output.shape)

tensor([[[ 0.0109,  0.3142,  0.3060, -0.0961, -0.2895,  0.2136, -0.1708,
          -0.0997,  0.0912, -0.0031, -0.1519,  0.0259,  0.0421,  0.1709,
           0.2150, -0.1000,  0.3393,  0.2651, -0.0085, -0.0135,  0.0730,
           0.0116,  0.1105, -0.1352, -0.0823,  0.3899,  0.1814,  0.1009,
          -0.2097, -0.3307,  0.1143,  0.2149]],

        [[ 0.0072,  0.3131,  0.3063, -0.0994, -0.2912,  0.2116, -0.1724,
          -0.1033,  0.0965,  0.0024, -0.1512,  0.0233,  0.0406,  0.1782,
           0.2066, -0.0972,  0.3377,  0.2652, -0.0051, -0.0142,  0.0742,
           0.0048,  0.1180, -0.1363, -0.0826,  0.3924,  0.1792,  0.0996,
          -0.2099, -0.3327,  0.1132,  0.2122]]], grad_fn=<ViewBackward0>)
torch.Size([2, 1, 32])


In [12]:
import jax
import flax.linen as nn
from jax import numpy as jnp
from jax import random

# Flax counterpart of the attention layer: 8 heads, qkv_features=32.
layer = nn.MultiHeadAttention(num_heads=8, qkv_features=32)

# Parameters are initialised from a dummy all-ones array with the same shape as
# the real inputs. It is named `init_input` so that the Python builtin `input`
# is not shadowed for the rest of the notebook.
init_input = jnp.ones((2, 1, 32))
variables = layer.init(random.key(42), init_input)

# Convert the shared NumPy arrays with jnp.asarray instead of `Tensor(...)`:
# the bare name `Tensor` is torch's class when the notebook runs top to bottom
# and is rebound to MindSpore's by a later cell, so relying on it makes this
# cell depend on execution order and feeds a foreign tensor type into JAX.
# NOTE(review): Flax presumably treats the leading axis as batch, i.e. two
# sequences of length 1 -- confirm this matches the layout used for PyTorch.
output = layer.apply(
    variables,
    jnp.asarray(query, dtype=jnp.float32),
    jnp.asarray(key, dtype=jnp.float32),
    jnp.asarray(value, dtype=jnp.float32),
)
print(output.shape)
print(output)

(2, 1, 32)
[[[ 0.21025246 -1.3815589   1.1856667   0.5420509  -0.03878744
    0.36193654 -0.30533046 -1.1700163  -0.57054275  0.46141142
   -0.30795845  0.39367586 -0.54798263  0.19380882 -0.5910839
   -0.4666907  -0.97481775  0.45212376 -0.10193141 -0.6575114
    0.25120044  0.122086   -1.1403954  -0.42054623 -0.61952204
   -0.6566897   0.73287785 -0.35822588 -0.5958375  -0.32572877
    1.2399867  -0.22128716]]

 [[ 0.03517252 -0.6995417   0.97694     0.07682744  0.24728103
    0.08959051  0.01169628 -0.8235637  -0.2518189   0.59149474
   -0.40409014 -0.5977189   0.09025317  0.08469346 -0.29417813
   -0.34428942 -0.61110383  0.7962756   0.02414452 -0.67949474
    0.6065721  -0.24253327 -0.3226804   0.27078897 -1.3271445
    0.25307137  0.3808884   0.18227924 -0.38555723 -0.42521483
    1.2446275  -0.62895083]]]


In [13]:
import mindspore as ms
from mindspore import nn
from mindspore import Tensor

# MindSpore attention layer, constructed with the same positional arguments
# (32, 8) as the PyTorch cell.
multihead_attn = nn.MultiheadAttention(32, 8)

# The shared NumPy inputs are float64, so each one is wrapped as a MindSpore
# Tensor with an explicit float32 dtype before being passed to the layer.
ms_inputs = [Tensor(array, ms.float32) for array in (query, key, value)]
attn_output, attn_output_weights = multihead_attn(*ms_inputs)

print(attn_output)
print(attn_output.shape)

[[[ 0.00069493 -0.03911137 -0.17373352 -0.08817849  0.00585917
    0.0038011   0.13423112  0.12640306 -0.13031635 -0.10004204
   -0.22439755 -0.07857274 -0.31675625 -0.07869545  0.18808702
    0.18783928  0.21366805  0.16640219 -0.2740562   0.02806476
   -0.16295458  0.29520345  0.17730789  0.13442674  0.35357815
    0.15538093  0.08111625 -0.09462467 -0.03054206  0.1470299
    0.03899852  0.20147526]]

 [[-0.00058219 -0.03654152 -0.17798103 -0.08727801  0.00280233
    0.00116073  0.1349529   0.12789461 -0.13350655 -0.10066106
   -0.2256635  -0.07729645 -0.31817535 -0.0761395   0.19332808
    0.19098276  0.21913132  0.16640013 -0.27903354  0.03269039
   -0.16899411  0.2947319   0.18080507  0.13609359  0.34764683
    0.16034965  0.08051767 -0.09571798 -0.03570132  0.15065208
    0.03331698  0.19594805]]]
(2, 1, 32)


**MindSpore MHA 的用法和输出形状与 PyTorch 一致**（两者权重均为随机初始化，因此输出的具体数值并不相同）
- MindSpore 的 `Tensor()` 不像 `torch.Tensor()` 那样内置类型转换，传入 float64 数据会报错，因此上面的代码显式指定了 `ms.float32`