# import

In [1]:
import torch
from bart_model_from_scratch.multihead_attn import BartAttention
from transformers import BartConfig

  from .autonotebook import tqdm as notebook_tqdm


# BartConfig

In [2]:
config = BartConfig()
config.pad_token_id = 2
config.encoder_layerdrop = 0.1
config.decoder_layerdrop = 0.1
config.d_model = config.encoder_attention_heads

# BartAttention

In [3]:
bart_attn = BartAttention(
    embed_dim=config.d_model,
    num_heads=config.encoder_attention_heads,
    dropout=config.attention_dropout,
)
print(bart_attn)

BartAttention(
  (dropout): Dropout(p=0.0, inplace=False)
  (k_proj): Linear(in_features=16, out_features=16, bias=True)
  (v_proj): Linear(in_features=16, out_features=16, bias=True)
  (q_proj): Linear(in_features=16, out_features=16, bias=True)
  (out_proj): Linear(in_features=16, out_features=16, bias=True)
)


In [4]:
# test bart_attn
hidden_states = torch.randn(2, 4, config.d_model)
output = bart_attn(hidden_states)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
tensor([[[ 4.2927e-01, -5.6774e-01, -3.8427e-01,  3.3193e-01, -1.3516e-01,
          -6.2389e-01, -2.4914e-01,  3.7619e-01,  5.7470e-01,  5.8342e-01,
           6.3682e-02,  7.2657e-01,  6.4011e-01, -1.1232e-01, -1.7830e-01,
           1.0744e-01],
         [ 4.3919e-01, -5.2775e-01, -4.9849e-01,  3.2339e-01, -1.3742e-01,
          -6.5530e-01, -1.5063e-01,  3.9463e-01,  4.6116e-01,  6.1640e-01,
          -5.3225e-02,  8.0006e-01,  6.7477e-01, -8.8532e-02, -6.4851e-02,
           5.6645e-02],
         [ 4.6637e-01, -4.4480e-01, -5.5922e-01,  3.0401e-01, -1.3780e-01,
          -6.5968e-01, -1.4163e-01,  3.7383e-01,  5.4247e-01,  5.8901e-01,
          -1.0480e-01,  6.9249e-01,  6.2579e-01, -1.3175e-01,  4.1674e-04,
           8.5077e-02],
         [ 5.0909e-01, -5.6102e-01, -3.8439e-01,  3.0946e-01, -1.0074e-01,
          -6.8404e-01, -2.2661e-01,  3.3777e-01,  5.9691e-01,  5.9284e-01,
           2.8411e-02,  7.1868e-01,  6.8228e-01, -1.5033e-01, -1.2509e-01,
     

# BartEncoderLayer

In [5]:
from bart_model_from_scratch.encoder_layer import BartEncoderLayer

In [6]:
bart_encoder_layer = BartEncoderLayer(config)
bart_encoder_layer

BartEncoderLayer(
  (self_attn): BartAttention(
    (dropout): Dropout(p=0.0, inplace=False)
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (fc1): Linear(in_features=16, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=16, bias=True)
  (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)

In [7]:
# test bart_encoder_layer
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
output = bart_encoder_layer(hidden_states)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[ 0.1141,  0.1785,  0.0131,  0.3057, -0.2007,  1.2946,  0.9211,
           0.3265,  0.9042,  1.0670, -1.5717, -0.7287,  0.3936, -2.6306,
          -0.9293,  0.5424],
         [-0.5400,  0.4761,  0.5886,  0.1282,  0.3549,  0.7022,  1.1494,
           2.5192,  0.2201, -0.3523,  0.2181, -1.5305, -1.0345, -1.0993,
          -1.3306, -0.4696],
         [ 1.3431,  0.4197,  1.0427, -1.1045, -0.5771,  0.7826,  0.5608,
          -0.4944, -1.7668,  0.9593,  0.9504,  0.3086, -0.6681, -1.6351,
           0.9496, -1.0706],
         [ 0.3183, -1.3204, -0.7837,  1.8534, -0.0522, -1.0064,  1.0298,
          -1.0653, -1.4799,  1.0242, -0.8385, -0.5091,  0.9919,  0.2349,
           0.3711,  1.2317]],

        [[-0.4410,  1.1942, -0.9669,  0.1602, -0.7747,  2.1991,  0.4440,
          -0.3265, -0.8037, -1.5178,  0.8924, -1.1214, -0.9704,  0.2138,
           0.7779,  1.0407],
         [ 0.2050,  1.2679,  0.9811, -1.4350,  1.5426,  0.0334,  0.5895,
    

# BartDecoderLayer

In [8]:
from bart_model_from_scratch.decoder_layer import BartDecoderLayer

In [9]:
bart_decoder_layer = BartDecoderLayer(config)
bart_decoder_layer

BartDecoderLayer(
  (self_attn): BartAttention(
    (dropout): Dropout(p=0.0, inplace=False)
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (encoder_attn): BartAttention(
    (dropout): Dropout(p=0.0, inplace=False)
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (encoder_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=T

In [10]:
# test bart_decoder_layer
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
encoder_hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
print(encoder_hidden_states.shape)
output = bart_decoder_layer(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[ 9.4629e-01, -8.6549e-01, -2.1239e+00,  9.7739e-01,  8.2042e-01,
          -1.0132e+00,  7.9091e-01,  2.2539e-01, -8.1500e-01,  3.4218e-02,
           9.5298e-01,  6.1854e-01, -1.4124e+00, -6.2389e-01,  1.4740e+00,
           1.3802e-02],
         [ 2.9195e-01, -2.2397e+00,  8.8113e-01, -3.0781e-01, -4.5435e-01,
           8.5958e-01, -3.6618e-01,  7.4811e-01,  9.1870e-01,  3.1879e-01,
          -1.0885e+00,  8.8177e-01, -1.5821e+00,  1.2009e+00,  8.4443e-01,
          -9.0677e-01],
         [ 6.2981e-01,  1.0584e+00, -9.7113e-01, -1.1968e+00,  3.7083e-01,
          -2.8831e-01, -2.3384e-03,  1.0949e+00, -5.4055e-01, -9.5516e-01,
          -3.0232e-01,  3.8559e-01, -2.1090e-01, -3.4646e-01, -1.3959e+00,
           2.6703e+00],
         [ 1.0812e-01, -8.9144e-01,  1.0760e+00,  3.6180e-01,  8.6733e-01,
          -5.0667e-01,  3.2318e-02,  7.3347e-01, -1.8318e+00, -6.9435e-01,
          -2.0644e+00,  1.1195e+00

# BartEmbeds

In [11]:
from bart_model_from_scratch.embeds import BartEmbeds

In [12]:
config.src_vocab_size = 50265
config.tgt_vocab_size = 50265
bart_embeds = BartEmbeds(
    num_embeddings=config.src_vocab_size,
    embedding_dim=config.d_model,
    padding_idx=config.pad_token_id,
    max_position_embeddings=config.max_position_embeddings,
)

In [13]:
# test BartEmbeds
input_ids = torch.randint(0, config.src_vocab_size, (2, 4))
output = bart_embeds(input_ids)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
tensor([[[-1.5145e-01, -1.7835e+00,  1.5911e+00, -1.1305e+00, -2.0248e-01,
          -1.7236e-01, -2.5691e-01,  1.0639e-01,  1.3678e-01, -1.3581e+00,
          -1.9541e+00,  1.0692e+00,  2.5632e-01,  1.5851e+00, -3.5327e-01,
          -7.6948e-01],
         [-2.6897e-01,  3.3643e-03,  6.7877e-01, -1.8060e+00,  1.1221e+00,
           7.3462e-01, -6.2694e-01, -5.8973e-01, -7.8161e-01,  4.1581e-01,
           2.2587e+00, -9.8660e-01,  6.9652e-01, -2.2425e+00,  1.2831e+00,
           3.3738e-01],
         [-3.8838e-01,  9.6339e-01, -1.5194e+00, -1.8377e+00, -6.2674e-01,
           9.8968e-01,  1.6824e+00,  8.7743e-01, -2.3647e-01,  9.5524e-01,
           2.0287e+00, -8.7897e-01, -6.1398e-01,  1.2012e+00,  2.1610e-01,
           6.5142e-01],
         [-1.7619e+00, -9.4223e-01, -1.8241e+00, -1.4537e+00,  4.1350e-01,
           8.4991e-01, -4.0169e-01,  1.9184e+00, -5.3129e-01, -5.0056e-01,
          -4.7706e-03,  1.9458e-01, -1.6724e+00,  1.1802e+00,  6.7215e-01,
     

# utils.mask.create_encoder_mask

In [14]:
from bart_model_from_scratch.utils.mask import (
    create_encoder_atn_mask,
)

In [15]:
# test create_encoder_mask
input_ids = torch.randint(0, 10, (5, 4)).to(torch.float32)
input_ids

tensor([[1., 2., 5., 3.],
        [2., 0., 4., 0.],
        [7., 3., 1., 3.],
        [4., 8., 3., 3.],
        [2., 0., 9., 7.]])

In [16]:
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
attention_mask

tensor([[1, 0, 1, 1],
        [0, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [0, 1, 1, 1]])

In [17]:
encoder_attention_mask = create_encoder_atn_mask(
    attention_mask=attention_mask,
)

In [18]:
print(encoder_attention_mask.shape)
print(encoder_attention_mask)

torch.Size([5, 1, 1, 4])
tensor([[[[1, 0, 1, 1]]],


        [[[0, 1, 1, 1]]],


        [[[1, 1, 1, 1]]],


        [[[1, 1, 1, 1]]],


        [[[0, 1, 1, 1]]]])


# BartEncoder

In [19]:
from bart_model_from_scratch.encoder import BartEncoder

In [20]:
bart_encoder = BartEncoder(config)

In [21]:
# test bart_encoder
input_embeds = torch.randn(2, 4, config.d_model)
# attention_mask = torch.randint(0, 2, (2, 4))
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 1, 0],
    ]
)
encoder_mask = create_encoder_atn_mask(
    attention_mask=attention_mask,
)
# print(f"{encoder_mask=}")
# print(f"{input_embeds.shape=}, {attention_mask.shape=}")
# print(f"{input_embeds=}, {attention_mask=}")
output = bart_encoder(input_embeds, attention_mask)
# print(output.shape)
print(output)

tensor([[[ 0.2035,  0.4429,  0.2635,  1.3967, -1.6652,  0.1203,  1.7866,
           0.8812,  0.0507,  0.1804, -0.0381, -2.3824, -0.2068,  0.1513,
          -1.0956, -0.0891],
         [ 1.1244,  1.0945, -0.0246,  1.2299, -0.8733, -1.7951,  2.0712,
          -0.4332, -0.5497, -1.2233, -0.7493, -0.0234, -0.7032, -0.2260,
           0.3518,  0.7292],
         [ 0.6582,  1.6080,  0.0794,  1.7047, -1.0895, -0.9417,  1.1758,
          -1.6279,  0.1364, -1.5671,  0.7436,  0.5458, -0.2716, -0.5609,
          -0.1057, -0.4876],
         [ 1.4413, -0.0748, -0.2107,  1.3109, -1.3763, -0.3701,  2.0483,
          -0.2628, -0.7714, -1.1779,  0.9788, -0.0193, -1.5405, -0.2082,
           0.6033, -0.3705]],

        [[ 1.9578,  0.8033, -0.3147,  1.4998, -0.2427, -0.6373, -0.4881,
          -1.3814,  0.6539, -1.1561, -0.9697, -0.6097,  0.3018,  1.4751,
           0.1945, -1.0865],
         [ 1.7017,  0.4597,  1.2454,  1.3109, -1.2802, -1.8012,  1.2904,
          -1.0481, -0.3061, -0.8629,  0.3082,  0.3

# utils.mask.causal_mask

In [22]:
from bart_model_from_scratch.utils.mask import (
    causal_mask
)

In [23]:
x = causal_mask(
    tgt_len=4,
    device=torch.device("cpu"),
)
x

tensor([[[ True, False, False, False],
         [ True,  True, False, False],
         [ True,  True,  True, False],
         [ True,  True,  True,  True]]])

# utils.mask.create_decoder_mask

In [24]:
from bart_model_from_scratch.utils.mask import (
    create_decoder_atn_mask
)

In [25]:
# test causal_mask
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0]
])
dtype = torch.float32
create_decoder_atn_mask(
    attention_mask=attention_mask,
    tgt_len=5,
)

tensor([[[[1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0]]],


        [[[1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0]]]])

# BartDecoder

In [26]:
from bart_model_from_scratch.decoder import BartDecoder

In [27]:
bart_decoder = BartDecoder(config)
bart_decoder

BartDecoder(
  (dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x BartDecoderLayer(
      (self_attn): BartAttention(
        (dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=16, out_features=16, bias=True)
        (v_proj): Linear(in_features=16, out_features=16, bias=True)
        (q_proj): Linear(in_features=16, out_features=16, bias=True)
        (out_proj): Linear(in_features=16, out_features=16, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (activation_fn): GELU(approximate='none')
      (activation_dropout): Dropout(p=0.0, inplace=False)
      (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): BartAttention(
        (dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=16, out_features=16, bias=True)
        (v_proj): Linear(in_features=16, out_features=16, bias=True)
        (q_proj): Linear(in_features=16, out_features=16

In [28]:
# test bart_decoder
input_embeds = torch.randn(2, 4, config.d_model)
attention_mask = torch.randint(0, 2, (2, 4))
encoder_hidden_states = torch.randn(2, 4, config.d_model)
encoder_attention_mask = torch.randint(0, 2, (2, 4))
output = bart_decoder(
    input_embeds=input_embeds,
    attention_mask=attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
)
print(output.shape)

torch.Size([2, 4, 16])


In [29]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq
import torch.nn as nn

In [30]:
model = BartSeq2seq(config)
model

BartSeq2seq(
  (inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (decoder_inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (encoder): BartEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x BartEncoderLayer(
        (self_attn): BartAttention(
          (dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation_fn): GELU(

In [31]:
# test model
input_ids = torch.randint(0, 10, (2, 4))
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
decoder_input_ids = torch.randint(0, 10, (2, 4))
decoder_attention_mask = (decoder_input_ids != config.pad_token_id).to(torch.int64)
print(input_ids.shape)
print(attention_mask.shape)
print(decoder_input_ids.shape)
print(decoder_attention_mask.shape)

torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])


In [32]:
out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
print(out)

tensor([[[-1.8505e-01, -1.2715e-01,  9.2629e-02,  ...,  1.5001e-02,
           2.9675e-02,  1.2091e-03],
         [-1.2576e-01, -4.6871e-02,  2.9593e-02,  ...,  2.3546e-02,
           2.5722e-02, -6.3521e-03],
         [-1.7734e-01, -6.3951e-02,  4.7980e-02,  ...,  7.0407e-02,
           1.1833e-01, -5.9716e-03],
         [ 3.2495e-02,  6.2671e-02,  7.6747e-02,  ..., -1.9443e-04,
          -2.0474e-02,  4.3179e-02]],

        [[-1.8734e-01, -1.7068e-01,  8.8445e-02,  ...,  7.2931e-03,
           1.7551e-02,  8.9000e-03],
         [-1.4641e-02, -2.7307e-02, -2.0989e-02,  ...,  6.3315e-02,
           3.1620e-02, -7.7127e-02],
         [-6.0200e-02, -1.8704e-02,  5.8235e-02,  ..., -8.2347e-05,
           9.7401e-02, -1.1445e-01],
         [-9.4918e-02,  1.4912e-03,  9.7382e-02,  ...,  2.1650e-02,
           7.2902e-02,  4.9981e-04]]], grad_fn=<ViewBackward0>)


In [33]:
encoder_out = model.get_encoder_out(
    input_ids=input_ids,
    attention_mask=attention_mask,
)
print(encoder_out.last_hidden_state)

tensor([[[ 1.1811, -1.9758, -0.0209, -0.2662,  0.3135,  0.1960,  0.4113,
          -1.3872,  0.0121,  2.5192,  0.1217,  0.0626, -0.0180, -1.4118,
           0.0753,  0.1872],
         [-0.5383, -1.8345, -0.0548, -0.3420, -0.1049,  1.9727,  0.2111,
          -0.9605,  0.3124,  0.7549, -0.0456,  0.9636, -0.7188, -0.5888,
          -0.9998,  1.9734],
         [-0.0276,  1.0978,  0.4478, -0.2224, -0.2727, -0.5881,  0.1477,
          -0.2370,  1.7011, -1.4158,  1.4406, -0.0437,  0.0855, -0.0873,
          -2.5946,  0.5688],
         [-1.6403,  0.3979, -0.9926,  1.8256, -1.5456, -0.1903,  0.3617,
           0.9597,  0.2037,  1.8273,  0.6445, -1.0395, -0.1657, -0.4575,
           0.2406, -0.4295]],

        [[ 1.2429, -1.5708, -0.0187, -0.7662, -0.3011, -1.1549, -0.0059,
           0.3748,  0.1355,  2.8128, -0.1528,  0.2096, -1.3098,  0.2375,
           0.2569,  0.0104],
         [-0.6852, -2.0153,  0.6160, -0.2926, -0.2164,  1.8641,  0.3189,
          -0.9426,  0.2776,  0.6464, -0.0851,  0.9

In [34]:
decoder_out = model.get_decoder_out(
    input_ids=decoder_input_ids,
    attention_mask=decoder_attention_mask,
    encoder_hidden_states=encoder_out.last_hidden_state,
    encoder_attention_mask=attention_mask
)
print(decoder_out.last_hidden_state)

tensor([[[-0.7030,  1.2238, -0.8282, -1.2306,  0.2240,  0.9211,  0.4974,
           1.2587, -1.6217,  1.9507, -1.3152,  0.4386, -0.6237, -0.4183,
           0.3690, -0.1425],
         [-0.4232, -0.9888, -0.7269,  0.3062, -0.0348, -0.1346,  1.5107,
           0.3886,  0.2488,  2.5285,  0.0671, -0.3590, -0.0501, -2.1834,
          -0.5211,  0.3719],
         [-0.2952, -0.2299, -0.4014, -0.0257,  0.1387,  0.2825,  1.8230,
           1.2164, -0.8336,  1.1982, -0.8318, -0.9818, -0.2566, -2.1929,
          -0.0693,  1.4595],
         [ 0.8839, -0.5790,  0.6680,  1.1558, -0.5407, -0.0279,  0.0781,
           0.0264,  0.9484, -0.9563,  2.1016, -1.2968,  0.3324, -0.9250,
           0.1088, -1.9777]],

        [[-0.7397,  1.2623, -0.9339, -1.2402, -0.0048,  1.0717,  0.5344,
           1.3726, -1.5386,  1.7859, -1.2248,  0.5812, -0.6472, -0.4405,
           0.3155, -0.1540],
         [-0.9507,  1.2446, -2.0122, -0.5425, -0.5806, -0.4437,  0.4546,
          -0.2613,  0.9666,  2.3420,  0.6610, -0.1

# BartSeq2seq

In [35]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq

In [36]:
model = BartSeq2seq(config)
model

BartSeq2seq(
  (inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (decoder_inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (encoder): BartEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x BartEncoderLayer(
        (self_attn): BartAttention(
          (dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation_fn): GELU(

In [37]:
# test model
input_ids = torch.randint(0, 10, (2, 5))
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
decoder_input_ids = torch.randint(0, 10, (2, 4))
decoder_attention_mask = (decoder_input_ids != config.pad_token_id).to(torch.int64)
print(input_ids.shape)
print(attention_mask.shape)
print(decoder_input_ids.shape)
print(decoder_attention_mask.shape)
logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
print(logits.shape)
print(logits)

torch.Size([2, 5])
torch.Size([2, 5])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4, 50265])
tensor([[[ 0.0496, -0.1079, -0.0167,  ..., -0.0198, -0.0386,  0.0525],
         [ 0.1125, -0.0119, -0.0858,  ..., -0.0611,  0.0310,  0.0009],
         [ 0.0081, -0.0973,  0.0686,  ...,  0.0192, -0.0870,  0.0124],
         [ 0.0875, -0.0125,  0.0058,  ...,  0.0281,  0.0241, -0.0837]],

        [[-0.0566, -0.1057, -0.0058,  ...,  0.0238, -0.0864, -0.0197],
         [ 0.1522,  0.0867, -0.0534,  ..., -0.0267,  0.0549,  0.0049],
         [ 0.0024, -0.1179,  0.0793,  ..., -0.0657,  0.0527,  0.0201],
         [ 0.1627, -0.0217,  0.0145,  ...,  0.0594,  0.0535,  0.0286]]],
       grad_fn=<ViewBackward0>)
