In [1]:
#pragma cling add_include_path("../../libtorch/include")
#pragma cling add_include_path("../../libtorch/include/torch/csrc/api/include")
#pragma cling add_library_path("../../libtorch/lib")
#pragma cling load("libtorch")

In [2]:
#include <iostream>
#include <tuple>
#include <torch/torch.h>
namespace nn = torch::nn;

# 1.1 torch.nn.MultiheadAttention

~~~
class TORCH_API MultiheadAttentionImpl
    : public torch::nn::Cloneable<MultiheadAttentionImpl> {
 public:
  MultiheadAttentionImpl(int64_t embed_dim, int64_t num_heads)
      : MultiheadAttentionImpl(
            MultiheadAttentionOptions(embed_dim, num_heads)) {}
  explicit MultiheadAttentionImpl(const MultiheadAttentionOptions& options_);

  std::tuple<Tensor, Tensor> forward(
      const Tensor& query,
      const Tensor& key,
      const Tensor& value,
      const Tensor& key_padding_mask = {},
      bool need_weights = true,
      const Tensor& attn_mask = {},
      bool average_attn_weights = true);
~~~

## a: core usage

In [3]:
nn::MultiheadAttention multihead_attention(4,1);

In [4]:
// seq_length x batch_size x feature_size
torch::Tensor x = torch::randn({3,1,4});
std::cout << x << std::endl;

(1,.,.) = 
  0.8593  0.1874 -1.5607  0.3263

(2,.,.) = 
 -0.5437 -0.5261  1.0795  1.3369

(3,.,.) = 
  1.1896 -2.4683  0.3097  0.6895
[ CPUFloatType{3,1,4} ]


In [5]:
std::tuple<torch::Tensor, torch::Tensor> output = multihead_attention ->forward(x,x,x);

In [6]:
torch::Tensor attn_output, attn_output_weights;
std::tie(attn_output, attn_output_weights) = output; //unpacking tuple into variables

@0x7ffe238b0908

In [7]:
std::cout << attn_output << std::endl;

(1,.,.) = 
  0.1621 -0.3210 -0.2843  0.1198

(2,.,.) = 
  0.1830 -0.3726 -0.3286  0.0507

(3,.,.) = 
 0.001 *
 -5.2037 -564.5055 -452.9999 -276.5349
[ CPUFloatType{3,1,4} ]


In [8]:
std::cout << attn_output_weights << std::endl;

(1,.,.) = 
  0.3921  0.3988  0.2091
  0.3343  0.3094  0.3563
  0.1948  0.5092  0.2960
[ CPUFloatType{1,3,3} ]


# b understand attn_mask in multihead_attention

In [9]:
torch::Tensor attn_mask = nn::TransformerImpl::generate_square_subsequent_mask(3);

In [10]:
std::cout << attn_mask << std::endl;

 0 -inf -inf
 0  0 -inf
 0  0  0
[ CPUFloatType{3,3} ]


In [11]:
std::cout << attn_mask.dtype() << std::endl;

float


In [12]:
torch::Tensor attn_mask_bool = attn_mask.to(torch::kBool);

In [13]:
//attn_mask_bool = attn_mask_bool.unsqueeze(1);

In [14]:
std::cout << attn_mask_bool << std::endl;

 0  1  1
 0  0  1
 0  0  0
[ CPUBoolType{3,3} ]


In [16]:
std::tuple<torch::Tensor, torch::Tensor> output_with_attn_mask = multihead_attention ->forward(x,x,x,{},true,attn_mask_bool);

In [17]:
std::tie(attn_output, attn_output_weights) = output_with_attn_mask; //unpacking tuple into variables

@0x7ffe238b0908

In [18]:
std::cout << attn_output << std::endl;

(1,.,.) = 
 0.01 *
 -1.8859 -57.1378 -45.6406 -29.0087

(2,.,.) = 
  0.1834 -0.4987 -0.4292 -0.1315

(3,.,.) = 
 0.001 *
 -5.2037 -564.5055 -452.9999 -276.5349
[ CPUFloatType{3,1,4} ]


In [19]:
std::cout << attn_output_weights << std::endl;

(1,.,.) = 
  0.1918  0.5302  0.2780
  0.2074  0.1919  0.6007
  0.1948  0.5092  0.2960
[ CPUFloatType{1,3,3} ]


# c: understand key_padding_mask in multihead_attention

In [20]:
torch::Tensor padding_mask = torch::tensor({{false, false, true}});

In [21]:
std::cout << padding_mask << std::endl;

 0  0  1
[ CPUBoolType{1,3} ]


In [22]:
std::tuple<torch::Tensor, torch::Tensor> output_with_padding_mask = multihead_attention ->forward(x, x, x, padding_mask);

In [23]:
std::tie(attn_output, attn_output_weights) = output_with_padding_mask; //unpacking tuple into variables

@0x7ffe238b0908

In [24]:
std::cout << attn_output << std::endl;

(1,.,.) = 
  0.1563 -0.2195 -0.2026  0.2649

(2,.,.) = 
  0.1823 -0.1887 -0.1820  0.3162

(3,.,.) = 
 0.01 *
 -8.4847 -50.5522 -39.3943 -21.2406
[ CPUFloatType{3,1,4} ]


In [25]:
std::cout << attn_output_weights << std::endl;

(1,.,.) = 
  0.4958  0.5042  0.0000
  0.5194  0.4806  0.0000
  0.2767  0.7233  0.0000
[ CPUFloatType{1,3,3} ]


# d: mix use of attn_mask and padding mask

In [27]:
std::tuple<torch::Tensor, torch::Tensor> output_with_both_mask = multihead_attention ->forward(x, x, x, padding_mask,true, attn_mask_bool);

In [28]:
std::tie(attn_output, attn_output_weights) = output_with_both_mask; //unpacking tuple into variables

@0x7ffe238b0908

In [29]:
std::cout << attn_output << std::endl;

(1,.,.) = 
 0.01 *
 -9.7059 -52.0005 -40.3632 -23.6575

(2,.,.) = 
  0.1823 -0.1887 -0.1820  0.3162

(3,.,.) = 
 0.01 *
 -8.4847 -50.5522 -39.3943 -21.2406
[ CPUFloatType{3,1,4} ]


In [30]:
std::cout << attn_output_weights << std::endl;

(1,.,.) = 
  0.2656  0.7344  0.0000
  0.5194  0.4806  0.0000
  0.2767  0.7233  0.0000
[ CPUFloatType{1,3,3} ]


# 1.2 nn.TransformerEncoderLayer

In [32]:
nn::TransformerEncoderLayer encoder_layer(4,1);

In [35]:
torch::Tensor out = encoder_layer(x);

In [36]:
std::cout << out << std::endl;

(1,.,.) = 
  0.7032  0.1064 -1.6631  0.8534

(2,.,.) = 
 -1.0565 -0.8991  0.6921  1.2635

(3,.,.) = 
  0.0836 -1.6498  0.9322  0.6341
[ CPUFloatType{3,1,4} ]


In [38]:
attn_mask = torch::randint(0,2,{3,3});

In [39]:
std::cout << attn_mask << std::endl;

 1  0  0
 0  0  0
 1  1  0
[ CPUFloatType{3,3} ]


In [40]:
out = encoder_layer(x, attn_mask);

In [41]:
std::cout << out << std::endl;

(1,.,.) = 
  0.8103  0.5036 -1.7121  0.3982

(2,.,.) = 
 -0.6864 -1.2529  0.7711  1.1681

(3,.,.) = 
 -0.0554 -1.6022  0.9965  0.6610
[ CPUFloatType{3,1,4} ]


# 1.3 torch.nn.TransformerEncoder

In [42]:
nn::TransformerEncoder transformer_encoder = nn::TransformerEncoder(encoder_layer, /*num_layers=*/2);

In [43]:
out = transformer_encoder(x, attn_mask);

In [44]:
std::cout << out << std::endl;

(1,.,.) = 
  0.4122  0.2284 -1.6569  1.0162

(2,.,.) = 
 -1.1739 -0.7709  0.7113  1.2335

(3,.,.) = 
 -0.5727 -1.3459  0.8593  1.0593
[ CPUFloatType{3,1,4} ]
