# import

In [1]:
import torch
from bart_model_from_scratch.multihead_attn import BartAttention
from transformers import BartConfig

  from .autonotebook import tqdm as notebook_tqdm


# BartConfig

In [2]:
config = BartConfig()
config.pad_token_id = 2
config.encoder_layerdrop = 0.1
config.decoder_layerdrop = 0.1
config.d_model = config.encoder_attention_heads

# BartAttention

In [3]:
bart_attn = BartAttention(
    embed_dim=config.d_model,
    num_heads=config.encoder_attention_heads,
    dropout=config.attention_dropout,
)
print(bart_attn)

BartAttention(
  (k_proj): Linear(in_features=16, out_features=16, bias=True)
  (v_proj): Linear(in_features=16, out_features=16, bias=True)
  (q_proj): Linear(in_features=16, out_features=16, bias=True)
  (out_proj): Linear(in_features=16, out_features=16, bias=True)
)


In [4]:
# test bart_attn
hidden_states = torch.randn(2, 4, config.d_model)
output = bart_attn(hidden_states)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
tensor([[[-0.3085, -0.5430, -0.6939, -0.1806, -0.4738, -0.0900,  0.2146,
          -0.0142, -0.0684, -0.4724,  0.1755, -0.4288, -0.2211, -0.6935,
          -0.3615,  0.1440],
         [-0.2094, -0.4102, -0.9866,  0.8181, -0.2247, -0.9481, -0.0128,
           0.0640, -0.1594, -0.4074,  0.1402, -0.1735,  0.0300, -0.6610,
           0.0479,  0.9284],
         [-0.0788,  0.4413,  0.0049,  0.2629, -0.1417, -0.2024,  0.2738,
          -0.0829,  0.4959,  0.5231,  0.0777,  0.0198,  0.3571,  0.2038,
           0.3759, -0.4529],
         [ 0.1036,  0.6275,  0.1735,  0.4267, -0.1173, -0.1116,  0.5675,
           0.1069,  0.5706,  0.2846,  0.2635,  0.3055,  0.5501,  0.3205,
           0.3919, -0.4072]],

        [[-0.2822, -0.2059, -0.0102,  0.0912, -0.2135,  0.0383,  0.2335,
          -0.0493,  0.3750, -0.2267,  0.2058, -0.3390,  0.3391, -0.1193,
          -0.1178,  0.1349],
         [-0.1455, -0.0861,  0.0957, -0.0336,  0.1165, -0.1926,  0.0992,
          -0.1550,  0.5523,

# BartEncoderLayer

In [5]:
from bart_model_from_scratch.encoder_layer import BartEncoderLayer

In [6]:
bart_encoder_layer = BartEncoderLayer(config)
bart_encoder_layer

BartEncoderLayer(
  (self_attn): BartAttention(
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (fc1): Linear(in_features=16, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=16, bias=True)
  (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)

In [7]:
# test bart_encoder_layer
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
output = bart_encoder_layer(hidden_states)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[ 1.8823,  1.3107, -0.0557, -1.7424, -0.8350,  0.3171,  0.8862,
          -1.2024,  0.7124, -0.4691, -0.3853,  0.8529,  0.3980, -0.2549,
          -1.6674,  0.2526],
         [ 0.3965, -1.2588,  0.0179,  0.5952,  0.0459, -1.7641, -0.6036,
          -0.9895,  0.9923,  0.8357,  0.4992,  1.2294,  2.0379, -0.2590,
          -0.5660, -1.2088],
         [ 0.1841,  0.9893,  0.2465,  1.4967,  0.4635,  0.9032, -0.0817,
          -0.9999,  0.2276,  0.9337, -1.5777,  0.0234,  1.0972, -1.6527,
          -1.7441, -0.5092],
         [ 1.7859,  0.0934, -0.1034, -0.5085,  0.9714, -2.1045,  0.0876,
          -1.2743, -0.3106,  1.3316, -0.2501,  0.2765,  1.0097,  0.8613,
          -0.8395, -1.0265]],

        [[-0.6679,  0.7581,  1.8585,  0.1912, -2.2260,  1.0209, -1.3114,
          -0.4396,  0.5835,  0.5375,  0.9441, -1.2929, -0.4306, -0.0978,
           0.1709,  0.4016],
         [-1.5074, -1.2582,  1.7520, -0.1151,  0.9497,  0.9900,  0.4616,
    

# BartDecoderLayer

In [8]:
from bart_model_from_scratch.decoder_layer import BartDecoderLayer

In [9]:
bart_decoder_layer = BartDecoderLayer(config)
bart_decoder_layer

BartDecoderLayer(
  (self_attn): BartAttention(
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (encoder_attn): BartAttention(
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (encoder_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (fc1): Linear(in_features=16, out_features=4096, bias=True)
  (fc2): Linear(in_feat

In [10]:
# test bart_decoder_layer
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
encoder_hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
print(encoder_hidden_states.shape)
output = bart_decoder_layer(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[ 1.0776e+00,  1.9157e+00,  8.7856e-02,  1.0362e+00, -5.9893e-02,
          -1.1544e+00,  2.8920e-01,  1.6172e+00, -6.4346e-01, -9.8993e-01,
          -3.1476e-01, -1.8419e+00, -3.5570e-01, -5.1983e-01, -6.8802e-01,
           5.4418e-01],
         [-4.1839e-01, -1.4509e+00, -1.1784e+00,  8.5728e-01,  1.0162e+00,
          -3.8367e-01, -8.2666e-02, -2.9402e-01,  2.6752e+00, -1.2239e-01,
          -5.1743e-01,  7.8535e-02, -4.8309e-01,  1.4130e+00, -6.8115e-01,
          -4.2815e-01],
         [ 7.7805e-01, -1.6013e+00,  2.1215e-01,  1.7587e+00, -1.7302e+00,
          -1.4361e+00,  1.0619e-01, -3.2242e-01, -9.2831e-01,  1.4654e+00,
           1.1811e+00,  2.6443e-01,  1.1809e-02,  1.8314e-01, -4.7514e-02,
           1.0486e-01],
         [ 3.3409e-01,  8.5172e-01, -2.2491e+00,  2.5950e-01, -9.4095e-01,
           1.5534e+00,  6.1394e-01,  1.7437e+00, -2.3938e-01,  2.3809e-01,
           1.3564e-02, -9.1631e-01

# BartEmbeds

In [11]:
from bart_model_from_scratch.embeds import BartEmbeds

In [12]:
config.src_vocab_size = 50265
config.tgt_vocab_size = 50265
bart_embeds = BartEmbeds(
    num_embeddings=config.src_vocab_size,
    embedding_dim=config.d_model,
    padding_idx=config.pad_token_id,
    max_position_embeddings=config.max_position_embeddings,
)

In [13]:
# test BartEmbeds
input_ids = torch.randint(0, config.src_vocab_size, (2, 4))
output = bart_embeds(input_ids)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
tensor([[[ 8.3337e-01,  1.6446e+00, -4.6474e-01,  2.2453e+00, -1.0304e+00,
          -1.7082e+00, -1.4529e-02, -6.9212e-01, -5.4204e-01, -8.7379e-01,
           1.9477e+00, -1.5175e-01, -1.1088e+00, -2.1477e+00,  1.5735e+00,
           8.9000e-01],
         [-2.0867e-01,  1.8616e-01,  6.1666e-01, -1.1448e-03, -5.6174e-01,
           2.7533e+00,  1.3271e+00,  4.5364e-01,  2.7271e-01, -7.2667e-01,
          -7.3358e-01, -8.6063e-02, -1.6477e+00, -4.9940e-01,  2.0937e+00,
           6.1895e-01],
         [-4.7224e-02,  5.0741e-01, -1.4657e+00,  7.9383e-01,  2.9942e-01,
          -5.1116e-02,  1.2519e+00, -1.1645e+00, -9.2459e-01, -2.7501e-01,
          -1.8935e+00, -4.9009e-02, -5.2499e-02,  8.2606e-01,  6.0021e-01,
          -9.7641e-01],
         [-9.1267e-01, -9.6180e-01, -1.9930e+00, -5.1359e-01, -2.4776e+00,
           1.0979e+00, -7.2845e-01,  6.2273e-01, -4.8454e-02,  2.8069e+00,
           1.2564e+00,  5.9932e-01,  2.6599e+00, -5.1142e-01, -2.7380e+00,
     

# utils.mask.create_encoder_mask

In [14]:
from bart_model_from_scratch.utils.mask import (
    create_encoder_atn_mask,
)

In [15]:
# test create_encoder_mask
input_ids = torch.randint(0, 10, (5, 4)).to(torch.float32)
input_ids

tensor([[1., 8., 3., 8.],
        [8., 5., 2., 0.],
        [3., 1., 2., 4.],
        [1., 2., 7., 1.],
        [6., 5., 0., 1.]])

In [16]:
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
attention_mask

tensor([[1, 1, 1, 1],
        [1, 1, 0, 1],
        [1, 1, 0, 1],
        [1, 0, 1, 1],
        [1, 1, 1, 1]])

In [17]:
encoder_attention_mask = create_encoder_atn_mask(
    attention_mask=attention_mask,
    dtype=input_ids.dtype,
)

In [18]:
print(encoder_attention_mask.shape)
print(encoder_attention_mask)

torch.Size([5, 1, 4, 4])
tensor([[[[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]]],


        [[[1, 1, 0, 1],
          [1, 1, 0, 1],
          [1, 1, 0, 1],
          [1, 1, 0, 1]]],


        [[[1, 1, 0, 1],
          [1, 1, 0, 1],
          [1, 1, 0, 1],
          [1, 1, 0, 1]]],


        [[[1, 0, 1, 1],
          [1, 0, 1, 1],
          [1, 0, 1, 1],
          [1, 0, 1, 1]]],


        [[[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]]]])


# BartEncoder

In [19]:
from bart_model_from_scratch.encoder import BartEncoder

In [20]:
bart_encoder = BartEncoder(config)

In [21]:
# test bart_encoder
input_embeds = torch.randn(2, 4, config.d_model)
# attention_mask = torch.randint(0, 2, (2, 4))
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 1, 0],
    ]
)
encoder_mask = create_encoder_atn_mask(
    attention_mask=attention_mask,
    dtype=input_embeds.dtype,
)
# print(f"{encoder_mask=}")
# print(f"{input_embeds.shape=}, {attention_mask.shape=}")
# print(f"{input_embeds=}, {attention_mask=}")
output = bart_encoder(input_embeds, attention_mask)
# print(output.shape)
print(output)

tensor([[[ 0.7984, -0.6295, -0.8316,  0.2448,  1.5115, -0.4280, -0.0505,
           1.9785, -1.5735,  0.1623,  1.3742, -0.6036,  0.5917, -1.1544,
          -1.1564, -0.2341],
         [ 0.1040, -1.3184, -1.2163,  0.8586,  2.6831, -0.0876, -0.7086,
           0.7585, -0.5279,  0.1150,  0.9749,  0.1732, -0.3672, -0.6455,
          -1.2951,  0.4993],
         [ 2.1029,  0.0817, -1.3088,  0.4243,  0.2330, -0.4199, -0.4676,
           1.7182, -1.2604,  1.3735,  0.4386, -1.4030, -0.1909, -0.3381,
          -0.4367, -0.5467],
         [-0.0115,  2.0181,  0.5870, -1.5288, -0.4797,  1.1119, -0.3453,
           1.6658, -2.0910, -0.5533, -0.2161, -0.0724,  0.1770, -0.2071,
           0.1707, -0.2254]],

        [[ 1.9773,  0.9269, -0.6369,  0.9566, -1.2002, -0.5251, -0.9857,
           0.8696,  0.6430,  0.8418,  0.1626,  0.7989, -0.3029, -1.6982,
          -0.7844, -1.0435],
         [ 1.7116,  0.7006, -2.2380,  0.8656, -0.8407, -0.4993, -0.8973,
           0.9355,  0.7509,  0.8364,  0.4021,  0.7

# utils.mask.causal_mask

In [22]:
from bart_model_from_scratch.utils.mask import (
    causal_mask
)

In [23]:
x = causal_mask(
    bsz=2,
    tgt_len=4,
    dtype = torch.float32,
    device=torch.device("cpu"),
)
x

tensor([[[ True, False, False, False],
         [ True,  True, False, False],
         [ True,  True,  True, False],
         [ True,  True,  True,  True]]])

# utils.mask.create_decoder_mask

In [24]:
from bart_model_from_scratch.utils.mask import (
    create_decoder_atn_mask
)

In [25]:
# test causal_mask
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0]
])
dtype = torch.float32
create_decoder_atn_mask(
    attention_mask=attention_mask,
    dtype=dtype,
)

tensor([[[[1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0]]],


        [[[1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0]]]])

# BartDecoder

In [26]:
from bart_model_from_scratch.decoder import BartDecoder

In [27]:
bart_decoder = BartDecoder(config)
bart_decoder

BartDecoder(
  (dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x BartDecoderLayer(
      (self_attn): BartAttention(
        (k_proj): Linear(in_features=16, out_features=16, bias=True)
        (v_proj): Linear(in_features=16, out_features=16, bias=True)
        (q_proj): Linear(in_features=16, out_features=16, bias=True)
        (out_proj): Linear(in_features=16, out_features=16, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (activation_fn): GELU(approximate='none')
      (activation_dropout): Dropout(p=0.0, inplace=False)
      (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): BartAttention(
        (k_proj): Linear(in_features=16, out_features=16, bias=True)
        (v_proj): Linear(in_features=16, out_features=16, bias=True)
        (q_proj): Linear(in_features=16, out_features=16, bias=True)
        (out_proj): Linear(in_features=16, out_features=16, bias=True)
      )
      

In [28]:
# test bart_decoder
input_embeds = torch.randn(2, 4, config.d_model)
attention_mask = torch.randint(0, 2, (2, 4))
encoder_hidden_states = torch.randn(2, 4, config.d_model)
encoder_attention_mask = torch.randint(0, 2, (2, 4))
output = bart_decoder(
    input_embeds=input_embeds,
    attention_mask=attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
)
print(output.shape)

torch.Size([2, 4, 16])


In [29]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq
import torch.nn as nn

In [30]:
model = BartSeq2seq(config)
model

BartSeq2seq(
  (inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (decoder_inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (encoder): BartEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x BartEncoderLayer(
        (self_attn): BartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation_fn): GELU(approximate='none')
        (activation_dropout): D

In [31]:
# test model
input_ids = torch.randint(0, 10, (2, 4))
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
decoder_input_ids = torch.randint(0, 10, (2, 4))
decoder_attention_mask = (decoder_input_ids != config.pad_token_id).to(torch.int64)
print(input_ids.shape)
print(attention_mask.shape)
print(decoder_input_ids.shape)
print(decoder_attention_mask.shape)

torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])


In [32]:
out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
print(out)

tensor([[[-0.0882,  0.1041, -0.1131,  ..., -0.0465, -0.0664, -0.0375],
         [-0.0319,  0.1512, -0.1085,  ..., -0.0740, -0.0571, -0.0241],
         [-0.0784,  0.1167, -0.0513,  ..., -0.1046,  0.0120, -0.0367],
         [-0.0685,  0.1236, -0.1014,  ..., -0.0820, -0.0500, -0.0364]],

        [[-0.0382, -0.0914,  0.0139,  ..., -0.0231,  0.0754, -0.0182],
         [-0.0782, -0.0654,  0.0009,  ..., -0.0583,  0.0677, -0.0011],
         [-0.0720, -0.0368,  0.0287,  ..., -0.1310,  0.0749,  0.0321],
         [ 0.1014,  0.0658, -0.0269,  ..., -0.1334, -0.0450,  0.0677]]],
       grad_fn=<ViewBackward0>)


In [33]:
encoder_out = model.get_encoder_out(
    input_ids=input_ids,
    attention_mask=attention_mask,
)
print(encoder_out.last_hidden_state)

tensor([[[-0.0889,  0.3022, -0.2304,  0.0035, -0.9340,  0.7860,  0.2017,
          -1.8708, -0.0928,  1.3179,  0.4970, -1.9724, -0.2888,  2.1530,
           0.4180, -0.2011],
         [-0.1412,  0.3304, -0.1961, -0.0107, -0.9382,  0.8012,  0.2763,
          -1.7729, -0.0702,  1.2343,  0.5167, -0.2134, -0.2489,  2.1057,
           0.4587, -2.1315],
         [-0.1058,  0.3033, -0.2292,  0.0047, -0.9328,  0.7870,  0.2028,
          -1.8695, -0.0917,  1.3189,  0.4981, -1.9711, -0.2876,  2.1539,
           0.4191, -0.1999],
         [-0.1453,  0.4187, -0.0963,  0.0711, -0.7163,  0.8393,  0.3403,
          -1.5153,  0.0156,  1.2387,  0.5868, -1.6297, -0.1219,  2.0311,
           0.4853, -1.8024]],

        [[ 1.3378,  0.2528,  0.0525, -0.2562, -0.3210,  1.0218,  1.2728,
           1.7333, -1.3819, -0.3312, -0.0900, -1.3513,  0.6623, -1.2347,
           0.1983, -1.5655],
         [-2.3883, -0.8351, -0.0574,  1.2195, -0.2198, -0.1900,  0.2054,
          -0.7032,  0.4059, -0.9408,  0.9676, -0.7

In [34]:
decoder_out = model.get_decoder_out(
    input_ids=decoder_input_ids,
    attention_mask=decoder_attention_mask,
    encoder_hidden_states=encoder_out.last_hidden_state,
    encoder_attention_mask=attention_mask
)
print(decoder_out.last_hidden_state)

tensor([[[ 2.0022e-01, -8.6054e-01,  3.3992e-01,  2.6815e+00,  6.2730e-01,
           1.6669e-01,  1.2130e-01, -3.2881e-02,  1.0448e+00, -2.7783e-01,
          -1.6505e+00, -2.7987e-01,  9.2149e-02, -6.6349e-01, -1.7337e+00,
           2.2482e-01],
         [ 2.3701e-01, -9.5690e-01,  4.9753e-02,  2.8518e+00,  6.4250e-01,
           2.1753e-01, -7.7377e-01, -1.0553e-01,  1.0566e+00, -4.1040e-01,
           4.1685e-02, -3.7488e-01,  8.0390e-02, -7.8409e-01, -1.9359e+00,
           1.6420e-01],
         [ 1.0865e-01, -1.1009e+00,  3.0601e-01,  2.8885e+00,  5.9399e-01,
           1.4728e-01, -2.4351e-02, -2.1223e-01,  9.8791e-01, -4.1522e-01,
          -1.9474e+00, -4.2343e-01, -9.4312e-02, -8.8358e-01, -2.4735e-02,
           9.3811e-02],
         [ 1.8808e-01, -8.6628e-01,  3.9576e-01,  2.6639e+00,  6.1464e-01,
           3.0690e-01,  1.0748e-01, -9.0711e-02,  1.0249e+00, -2.8934e-01,
          -1.6540e+00, -2.9183e-01,  8.3772e-02, -6.7243e-01, -1.7347e+00,
           2.1397e-01]],

  

# BartSeq2seq

In [35]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq

In [36]:
model = BartSeq2seq(config)
model

BartSeq2seq(
  (inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (decoder_inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (encoder): BartEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x BartEncoderLayer(
        (self_attn): BartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation_fn): GELU(approximate='none')
        (activation_dropout): D

In [37]:
# test model
input_ids = torch.randint(0, 10, (2, 5))
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
decoder_input_ids = torch.randint(0, 10, (2, 4))
decoder_attention_mask = (decoder_input_ids != config.pad_token_id).to(torch.int64)
print(input_ids.shape)
print(attention_mask.shape)
print(decoder_input_ids.shape)
print(decoder_attention_mask.shape)
logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
print(logits.shape)

torch.Size([2, 5])
torch.Size([2, 5])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4, 50265])
