# Imports

In [1]:
import torch
from bart_model_from_scratch.multihead_attn import BartAttention
from transformers import BartConfig

  from .autonotebook import tqdm as notebook_tqdm


# BartConfig

In [2]:
# Start from the stock BartConfig and override a handful of fields for the
# experiments in this notebook.
config = BartConfig()
config.pad_token_id = 2
config.encoder_layerdrop = config.decoder_layerdrop = 0.1
# The hidden size is set to the number of encoder attention heads, which keeps
# every layer tiny (the module printouts in this notebook show 16 features).
# NOTE(review): presumably this leaves a single dimension per head — confirm
# against BartAttention.
config.d_model = config.encoder_attention_heads

# BartAttention

In [3]:
# Instantiate the attention module on its own, sized from the shared config:
# embed_dim and num_heads come from config.d_model and
# config.encoder_attention_heads, dropout from config.attention_dropout.
bart_attn = BartAttention(
    embed_dim=config.d_model,
    num_heads=config.encoder_attention_heads,
    dropout=config.attention_dropout,
)
# Print the module so its projection layers can be inspected.
print(bart_attn)

BartAttention(
  (k_proj): Linear(in_features=16, out_features=16, bias=True)
  (v_proj): Linear(in_features=16, out_features=16, bias=True)
  (q_proj): Linear(in_features=16, out_features=16, bias=True)
  (out_proj): Linear(in_features=16, out_features=16, bias=True)
)


In [4]:
# test bart_attn
# Push a random (2, 4, d_model) batch through the attention module and show
# the shape of the result followed by its values.
hidden_states = torch.randn(2, 4, config.d_model)
output = bart_attn(hidden_states)
print(output.shape, output, sep="\n")

torch.Size([2, 4, 16])
tensor([[[ 0.0389,  0.4104,  0.5358, -0.2459, -0.1781, -0.1642, -0.0216,
          -0.1013,  0.1283,  0.0091, -0.2554,  0.5060,  0.0390, -0.1223,
           0.1115,  0.0058],
         [-0.1039,  0.1530, -0.3803, -0.3977,  0.1306,  0.2841,  0.1005,
          -0.2261, -0.0121,  0.1577, -0.1416, -0.0983,  0.0078, -0.2243,
           0.2186, -0.3268],
         [-0.0913,  0.3246, -1.0315, -0.6313,  0.4344,  0.8401,  0.0937,
          -0.6880, -0.1443,  0.1587, -0.5459, -0.1650, -0.3451, -0.4619,
          -0.0110, -0.3555],
         [-0.1713,  0.1922, -0.0521, -0.3483,  0.0024,  0.1173, -0.0450,
          -0.4192, -0.0302,  0.2721, -0.1507,  0.0775, -0.0623, -0.2312,
           0.0600, -0.2294]],

        [[ 0.3391,  0.5766,  0.0346,  0.1431, -0.2311,  0.2882,  0.3202,
           0.0473,  0.3187,  0.4905, -0.3281,  0.3373, -0.1340, -0.2770,
           0.2773,  0.0249],
         [-0.2525,  0.1176, -0.3275, -0.1843,  0.3619,  0.0377, -0.1492,
          -0.2667, -0.1520,

# BartEncoderLayer

In [5]:
from bart_model_from_scratch.encoder_layer import BartEncoderLayer

In [6]:
# Build a single encoder layer from the shared config; the bare expression on
# the last line displays the module's repr as the cell output.
bart_encoder_layer = BartEncoderLayer(config)
bart_encoder_layer

BartEncoderLayer(
  (self_attn): BartAttention(
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (fc1): Linear(in_features=16, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=16, bias=True)
  (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)

In [7]:
# test bart_encoder_layer
# Feed a random float32 batch of shape (2, 4, d_model) through the layer,
# printing the input shape first and then the output's shape and values.
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
output = bart_encoder_layer(hidden_states)
print(output.shape, output, sep="\n")

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[ 0.3092, -0.4381, -0.5898,  1.8925, -0.2995,  1.1505, -0.0769,
           1.6064,  0.7698, -1.0350, -0.1670, -2.4338, -0.3227, -0.2354,
          -0.0586, -0.0715],
         [ 1.7055,  0.3381, -1.2048,  1.4162, -1.7015,  0.4368, -1.4465,
           0.3767,  0.1977, -0.2445,  0.3203, -0.0442,  0.3787, -1.2158,
           1.3776, -0.6904],
         [ 0.2369,  2.8296, -1.6335,  1.3096,  0.0176, -0.2577, -0.1917,
          -0.5093,  0.6993,  0.3039, -0.7880, -1.1252, -0.1732, -0.2643,
           0.2825, -0.7366],
         [ 0.9384,  0.2119, -0.5238,  1.7041,  0.9440, -0.9017,  1.0691,
          -1.1966, -0.9548, -0.4630, -1.8659,  0.3101, -1.2433,  0.6490,
           0.4719,  0.8506]],

        [[ 1.5122, -0.9173, -0.9252,  0.8195, -2.3405, -0.2565, -0.7356,
           0.2047, -0.0110,  0.4460, -0.5434,  0.6443, -0.8430,  1.1343,
           0.3272,  1.4845],
         [ 1.0456,  0.3773, -0.8523,  1.4549, -2.6029,  0.2879,  0.3019,
    

# BartDecoderLayer

In [8]:
from bart_model_from_scratch.decoder_layer import BartDecoderLayer

In [9]:
# Build a single decoder layer from the shared config; the bare expression on
# the last line displays the module's repr as the cell output.
bart_decoder_layer = BartDecoderLayer(config)
bart_decoder_layer

BartDecoderLayer(
  (self_attn): BartAttention(
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (encoder_attn): BartAttention(
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (encoder_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (fc1): Linear(in_features=16, out_features=4096, bias=True)
  (fc2): Linear(in_feat

In [10]:
# test bart_decoder_layer
# Random decoder-side and encoder-side activations, both (2, 4, d_model) float32.
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
encoder_hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape, encoder_hidden_states.shape, sep="\n")
# Run the layer with both inputs passed by keyword, then show the result.
output = bart_decoder_layer(
    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states
)
print(output.shape, output, sep="\n")

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[ 0.0838, -1.0741, -0.4239,  1.3732, -0.2004,  0.3073,  2.3867,
          -0.6945, -1.3187, -1.3993,  0.6926, -0.5350, -0.6374,  0.1659,
           1.2504,  0.0234],
         [-1.0046,  1.0142, -0.7813, -1.5916,  2.0711, -0.3687, -0.3814,
           0.2923,  0.3803,  0.2298,  1.2389, -0.2420,  0.9088, -1.8127,
           0.3830, -0.3360],
         [ 0.1838,  1.6697,  0.0334, -1.9656, -0.5262, -0.2398, -1.1432,
          -0.3598,  0.1187,  0.8636,  1.6133, -0.4661,  1.1296, -0.5798,
           0.9055, -1.2371],
         [ 0.7118,  2.5659,  0.7785,  0.0098,  1.2585, -0.4290, -1.2909,
          -0.0743, -0.0503, -1.6497, -0.2572, -0.0719, -0.2042,  0.3157,
          -0.2467, -1.3660]],

        [[ 1.2485,  0.9097, -0.3604,  0.6512, -1.7981, -0.7419,  0.6545,
           0.6181, -0.1367,  0.3461, -2.4557,  0.8612, -0.7268,  0.9633,
          -0.2379,  0.2050],
         [ 0.3879, -0.0709,  0.8146, -1.2012,  0.9592,

# BartEmbeds

In [11]:
from bart_model_from_scratch.embeds import BartEmbeds

In [12]:
# Source and target vocabularies share the same size; both are stored as extra
# attributes on the config.
config.src_vocab_size = config.tgt_vocab_size = 50265
# Embedding module sized from the config: vocabulary size, hidden size,
# padding index and maximum number of positions.
bart_embeds = BartEmbeds(
    num_embeddings=config.src_vocab_size,
    embedding_dim=config.d_model,
    padding_idx=config.pad_token_id,
    max_position_embeddings=config.max_position_embeddings,
)

In [13]:
# test BartEmbeds
# Random token ids drawn from the whole source vocabulary, batch of 2 x 4.
input_ids = torch.randint(0, config.src_vocab_size, (2, 4))
output = bart_embeds(input_ids)
print(output.shape, output, sep="\n")

torch.Size([2, 4, 16])
tensor([[[-0.2306, -2.0944,  2.0371,  0.0823,  1.6332,  1.9889, -1.7074,
          -0.6835,  0.3621,  2.1629,  1.7967, -3.8688, -1.2339,  0.3487,
          -0.6226, -0.8284],
         [-0.9060, -1.5050,  1.6390,  0.1715,  1.1103,  0.8039,  1.6467,
          -0.2298, -2.1701,  0.0132, -0.1254, -2.0719,  0.3118, -0.9823,
          -3.1595,  0.3988],
         [ 1.6639,  0.6652,  0.5490,  1.2204,  0.7428,  0.0290,  0.2879,
           0.9791,  0.9325, -0.0320,  0.1458, -2.2266,  0.2778, -0.7276,
          -0.5019, -0.5056],
         [-0.6863, -0.6860,  0.7250,  1.0308, -1.7031,  1.1196,  1.1003,
           1.8988, -3.0791,  0.1530,  0.3555,  0.7907,  0.6875, -1.7792,
           1.2669,  0.7915]],

        [[ 0.1999, -0.4940,  1.9735,  0.2206,  0.1448,  0.1803, -0.5510,
          -1.6708, -0.1974,  0.9242,  1.5357, -3.9122, -0.3895, -0.1896,
           0.0173, -0.7235],
         [-0.2252, -2.8683,  0.4321,  1.2758,  0.5330, -0.1352,  0.4204,
          -3.7060, -1.4523,

# utils.mask.create_encoder_atn_mask

In [14]:
from bart_model_from_scratch.utils.mask import (
    create_encoder_atn_mask,
)

In [15]:
# test create_encoder_atn_mask
# Random ids in [0, 10) for a 5 x 4 batch, stored as float32 so that the mask
# builder can be handed this tensor's dtype.
input_ids = torch.randint(0, 10, (5, 4)).float()
input_ids

tensor([[1., 5., 4., 7.],
        [7., 1., 3., 8.],
        [2., 8., 9., 8.],
        [1., 2., 8., 7.],
        [8., 0., 6., 6.]])

In [16]:
# 1 wherever the id differs from the pad id, 0 at pad positions (int64).
attention_mask = input_ids.ne(config.pad_token_id).long()
attention_mask

tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [0, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 1, 1]])

In [17]:
# Build the encoder attention mask from the 2-D padding mask. The dtype is
# taken from input_ids, which was cast to float32 when it was created.
encoder_attention_mask = create_encoder_atn_mask(
    attention_mask=attention_mask,
    dtype=input_ids.dtype,
)

In [18]:
# Show the mask's shape, then its contents.
print(encoder_attention_mask.shape, encoder_attention_mask, sep="\n")

torch.Size([5, 1, 4, 4])
tensor([[[[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]]],


        [[[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]]],


        [[[0, 1, 1, 1],
          [0, 1, 1, 1],
          [0, 1, 1, 1],
          [0, 1, 1, 1]]],


        [[[1, 0, 1, 1],
          [1, 0, 1, 1],
          [1, 0, 1, 1],
          [1, 0, 1, 1]]],


        [[[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]]]])


# BartEncoder

In [19]:
from bart_model_from_scratch.encoder import BartEncoder

In [20]:
# Build the full encoder stack from the shared config.
bart_encoder = BartEncoder(config)

In [21]:
# test bart_encoder
input_embeds = torch.randn(2, 4, config.d_model)
# Fixed padding mask: the first sequence contains no zeros, the second has a
# zero in its last position.
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 1, 0],
    ]
)
# The encoder is called with the raw 2-D mask. The expanded mask that used to
# be built here with create_encoder_atn_mask was never passed to anything, so
# it has been removed together with the commented-out debug prints.
# NOTE(review): presumably BartEncoder expands the mask itself — confirm
# against its forward().
output = bart_encoder(input_embeds, attention_mask)
print(output)

tensor([[[-1.3393e+00, -5.9955e-01,  1.1432e+00,  8.8060e-01,  4.2715e-01,
          -1.7897e+00,  5.1543e-01,  5.9620e-03, -3.4387e-01,  1.2885e+00,
           9.4253e-01, -1.1165e+00,  7.1777e-01, -8.0670e-01, -1.1938e+00,
           1.2683e+00],
         [ 1.3529e-01,  8.9397e-01,  7.0699e-01, -1.0973e+00, -1.9791e+00,
           2.4163e-01, -1.4070e+00, -1.7463e+00, -4.9425e-02,  5.1930e-01,
          -1.8253e-02,  3.5841e-01,  1.3090e-01,  9.5961e-01,  1.6061e+00,
           7.4524e-01],
         [-1.2204e+00, -1.0538e-01,  2.5693e-01, -1.6847e+00, -8.4828e-01,
           1.5640e+00, -5.2388e-02,  1.8177e+00,  1.6162e+00,  6.2449e-01,
          -4.6787e-02, -2.8022e-01,  1.6663e-01, -9.7164e-01,  1.8116e-01,
          -1.0173e+00],
         [-4.7350e-01,  5.0554e-01,  4.5119e-01, -1.3749e+00, -9.6252e-01,
           2.0799e+00, -1.4825e-01,  1.1738e+00,  1.9956e+00, -9.8299e-01,
          -3.1811e-01, -8.5761e-01,  1.0712e-01, -6.9681e-01,  1.7911e-01,
          -6.7759e-01]],

  

# utils.mask.causal_mask

In [22]:
from bart_model_from_scratch.utils.mask import (
    causal_mask
)

In [23]:
# Build a causal mask for target length 4 on the CPU and display it.
x = causal_mask(bsz=2, tgt_len=4, dtype=torch.float32, device=torch.device("cpu"))
x

tensor([[[ True, False, False, False],
         [ True,  True, False, False],
         [ True,  True,  True, False],
         [ True,  True,  True,  True]]])

# utils.mask.create_decoder_atn_mask

In [24]:
from bart_model_from_scratch.utils.mask import (
    create_decoder_atn_mask
)

In [25]:
# test create_decoder_atn_mask
# Two right-padded sequences of length 5. NOTE(review): assumes 1 marks a real
# token and 0 marks padding, matching how masks are derived from pad_token_id
# elsewhere in this notebook.
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0]
])
dtype = torch.float32
# The returned mask is the cell's displayed value (it is not stored).
create_decoder_atn_mask(
    attention_mask=attention_mask,
    dtype=dtype,
)

tensor([[[[1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0]]],


        [[[1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0]]]])

# BartDecoder

In [26]:
from bart_model_from_scratch.decoder import BartDecoder

In [27]:
# Build the full decoder stack from the shared config; the bare expression on
# the last line displays the module's repr as the cell output.
bart_decoder = BartDecoder(config)
bart_decoder

BartDecoder(
  (dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x BartDecoderLayer(
      (self_attn): BartAttention(
        (k_proj): Linear(in_features=16, out_features=16, bias=True)
        (v_proj): Linear(in_features=16, out_features=16, bias=True)
        (q_proj): Linear(in_features=16, out_features=16, bias=True)
        (out_proj): Linear(in_features=16, out_features=16, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (activation_fn): GELU(approximate='none')
      (activation_dropout): Dropout(p=0.0, inplace=False)
      (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): BartAttention(
        (k_proj): Linear(in_features=16, out_features=16, bias=True)
        (v_proj): Linear(in_features=16, out_features=16, bias=True)
        (q_proj): Linear(in_features=16, out_features=16, bias=True)
        (out_proj): Linear(in_features=16, out_features=16, bias=True)
      )
      

In [28]:
# test bart_decoder
input_embeds = torch.randn(2, 4, config.d_model)
# Explicit right-padded masks (1 = real token, 0 = padding) replace the random
# 0/1 draws used previously. Random masks change on every run and can zero out
# an entire sequence, which makes this smoke test non-reproducible. Fixed masks
# also match the style of the BartEncoder test cell.
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 1, 0],
    ]
)
encoder_hidden_states = torch.randn(2, 4, config.d_model)
encoder_attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 0, 0],
    ]
)
# The decoder receives the raw 2-D masks for both its own inputs and the
# encoder states.
output = bart_decoder(
    input_embeds=input_embeds,
    attention_mask=attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
)
print(output.shape)

torch.Size([2, 4, 16])


In [29]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq
import torch.nn as nn

In [30]:
# Build the sequence-to-sequence model from the shared config; the bare
# expression on the last line displays the module's repr as the cell output.
model = BartSeq2seq(config)
model

BartSeq2seq(
  (inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (decoder_inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (encoder): BartEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x BartEncoderLayer(
        (self_attn): BartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation_fn): GELU(approximate='none')
        (activation_dropout): D

In [31]:
# test model
# Random source and target ids in [0, 10), each 2 x 4. The padding masks are
# 1 wherever the id differs from config.pad_token_id and 0 elsewhere (int64).
input_ids = torch.randint(0, 10, (2, 4))
attention_mask = input_ids.ne(config.pad_token_id).long()
decoder_input_ids = torch.randint(0, 10, (2, 4))
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id).long()
# Report the shape of each of the four inputs, one per line.
for tensor in (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask):
    print(tensor.shape)

torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])


In [32]:
# Full forward pass: source ids and mask plus decoder ids and mask, all passed
# by keyword. The result is printed as-is.
out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
print(out)

tensor([[[-0.0083, -0.0343, -0.0034,  ..., -0.0437, -0.0105, -0.0992],
         [-0.0062, -0.0340, -0.0076,  ..., -0.0430, -0.0109, -0.1014],
         [-0.0083, -0.0342, -0.0034,  ..., -0.0437, -0.0105, -0.0992],
         [-0.0005, -0.0337, -0.0013,  ..., -0.0431, -0.0199, -0.0953]],

        [[-0.0777,  0.0664,  0.0591,  ..., -0.0485, -0.1908,  0.0668],
         [-0.0538,  0.0137,  0.0442,  ..., -0.0513, -0.1541,  0.0685],
         [-0.0078,  0.0100, -0.0021,  ..., -0.0747, -0.0642,  0.0212],
         [ 0.0102, -0.0777,  0.0335,  ..., -0.0058,  0.0927,  0.0184]]],
       grad_fn=<ViewBackward0>)


In [33]:
# Run only the encoder side of the model and print the last_hidden_state
# field of what it returns.
encoder_out = model.get_encoder_out(
    input_ids=input_ids,
    attention_mask=attention_mask,
)
print(encoder_out.last_hidden_state)

tensor([[[ 0.2202,  0.2242,  1.6139,  1.6597, -0.5454,  0.3504, -0.6280,
           0.7057,  1.2193, -1.6289, -0.2535, -1.7012, -0.7439, -1.1546,
           0.2186,  0.4435],
         [-0.3079, -0.3656,  1.6935,  1.7377, -0.4543,  0.4301, -0.5562,
           0.7701,  1.2975, -1.5457, -0.1674, -1.6023, -0.6692, -1.0619,
           0.2882,  0.5134],
         [ 0.3000,  0.4238, -0.2414,  1.9705, -0.4917,  0.5650, -0.5369,
           0.9577,  1.4493, -1.6656, -0.1599, -1.7142, -0.6542, -1.1408,
           0.3850,  0.5535],
         [ 0.1771,  0.2287,  1.6198,  1.6657, -0.5419,  0.3549, -0.6246,
           0.7106,  1.2248, -1.6265, -0.2496, -1.6989, -0.7406, -1.1517,
           0.2230,  0.4291]],

        [[ 1.3516, -1.6218,  1.2227, -1.2537, -1.0342, -0.5422,  1.2619,
          -0.6652, -0.0485,  0.5205, -0.0449,  1.4115,  0.7873,  0.4962,
          -1.2735, -0.5677],
         [-0.8785,  0.5168,  2.2922,  1.0228, -0.0912, -0.6284, -1.2259,
          -1.2663, -0.7113,  0.1279,  0.4201,  0.7

In [34]:
# Run only the decoder side: it gets the decoder ids and mask, plus the
# encoder's last_hidden_state and the source attention_mask (passed as
# encoder_attention_mask). Print the last_hidden_state field of the result.
decoder_out = model.get_decoder_out(
    input_ids=decoder_input_ids,
    attention_mask=decoder_attention_mask,
    encoder_hidden_states=encoder_out.last_hidden_state,
    encoder_attention_mask=attention_mask
)
print(decoder_out.last_hidden_state)

tensor([[[-6.1328e-01, -5.7728e-01,  2.7410e-01,  4.1479e-01,  5.7958e-01,
          -1.7235e+00, -1.0068e+00, -7.8708e-01,  1.4377e+00,  2.5554e+00,
          -7.0099e-01,  6.2075e-01, -4.3138e-01, -3.2541e-01,  5.2780e-01,
          -2.4447e-01],
         [-5.6244e-01, -1.5040e-01,  3.1184e-01,  4.7537e-01,  6.7291e-01,
          -1.8050e+00, -3.3451e-01, -7.7196e-01,  1.5899e+00,  2.7304e+00,
          -6.8222e-01, -2.3395e-01, -4.5478e-01, -3.0061e-01, -2.7117e-01,
          -2.1340e-01],
         [-5.6676e-01, -5.3791e-01,  3.3372e-01,  4.8247e-01, -2.5778e-01,
          -1.6925e+00, -9.5859e-01, -7.4826e-01,  1.4915e+00,  2.6361e+00,
          -6.1420e-01,  7.0141e-01, -3.9079e-01, -2.6217e-01,  5.7732e-01,
          -1.9364e-01],
         [-4.9494e-01, -4.9082e-01,  3.8996e-01,  5.4073e-01, -1.8499e-01,
          -1.6683e+00, -9.3577e-01, -7.0485e-01,  1.5757e+00,  2.7203e+00,
          -6.0555e-01, -1.0713e-01, -3.3729e-01, -2.2693e-01,  6.6040e-01,
          -1.3055e-01]],

  

# BartSeq2seq

In [35]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq

In [36]:
# Build a fresh BartSeq2seq from the shared config. This rebinds `model`,
# replacing the instance created in the earlier BartSeq2seq cell. The bare
# expression on the last line displays the module's repr.
model = BartSeq2seq(config)
model

BartSeq2seq(
  (inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (decoder_inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (encoder): BartEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x BartEncoderLayer(
        (self_attn): BartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation_fn): GELU(approximate='none')
        (activation_dropout): D

In [37]:
# test model
# End-to-end check with different source (5) and target (4) sequence lengths.
# The padding masks are 1 wherever the id differs from config.pad_token_id.
input_ids = torch.randint(0, 10, (2, 5))
attention_mask = input_ids.ne(config.pad_token_id).long()
decoder_input_ids = torch.randint(0, 10, (2, 4))
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id).long()
# Report the shape of each of the four inputs, one per line.
for tensor in (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask):
    print(tensor.shape)
# Forward pass through the whole model; only the shape of the result is shown.
logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
print(logits.shape)

torch.Size([2, 5])
torch.Size([2, 5])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4, 50265])
