# import

In [38]:
import torch
from bart_model_from_scratch.multihead_attn import BartAttention
from transformers import BartConfig

# BartConfig

In [39]:
config = BartConfig()
config.pad_token_id = 2
config.encoder_layerdrop = 0.1
config.decoder_layerdrop = 0.1
config.d_model = config.encoder_attention_heads

# BartAttention

In [40]:
bart_attn = BartAttention(
    embed_dim=config.d_model,
    num_heads=config.encoder_attention_heads,
    dropout=config.attention_dropout,
)
print(bart_attn)

BartAttention(
  (dropout): Dropout(p=0.0, inplace=False)
  (k_proj): Linear(in_features=16, out_features=16, bias=True)
  (v_proj): Linear(in_features=16, out_features=16, bias=True)
  (q_proj): Linear(in_features=16, out_features=16, bias=True)
  (out_proj): Linear(in_features=16, out_features=16, bias=True)
)


In [41]:
# test bart_attn
hidden_states = torch.randn(2, 4, config.d_model)
output = bart_attn(hidden_states)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
tensor([[[-0.0906, -0.1523,  0.0645, -0.0779, -0.4009,  0.1179, -0.0401,
          -0.0256,  0.4397, -0.4632, -0.3262, -0.3941, -0.0053,  0.2744,
           0.5118, -0.2259],
         [ 0.3221,  0.0747,  0.3362, -0.1945, -0.0732,  0.0699,  0.0223,
           0.0306,  0.0310, -0.5592, -0.1614, -0.2145,  0.2070,  0.3759,
           0.3163, -0.5995],
         [ 0.6112,  0.1720,  0.2572, -0.2552,  0.0168, -0.0536, -0.3061,
           0.2608, -0.1535, -0.5237,  0.0262, -0.0588,  0.6172,  0.4804,
          -0.0105, -0.9807],
         [ 0.8390,  0.3498,  1.0678,  0.0452, -0.7960,  0.6988, -0.3431,
          -0.7610,  0.4373, -0.0960, -0.4153, -0.3340,  0.0728, -0.1029,
           0.7400, -0.4680]],

        [[-0.2008,  0.2546, -0.2252,  0.0403, -0.0884, -0.0938, -0.0812,
           0.1076, -0.0319,  0.1269,  0.0424,  0.1358,  0.2720,  0.2933,
          -0.0975, -0.1525],
         [ 0.0311,  0.1533,  0.0352,  0.0074, -0.1858, -0.0321,  0.0199,
           0.0076, -0.1418,

# BartEncoderLayer

In [42]:
from bart_model_from_scratch.encoder_layer import BartEncoderLayer

In [43]:
bart_encoder_layer = BartEncoderLayer(config)
bart_encoder_layer

BartEncoderLayer(
  (self_attn): BartAttention(
    (dropout): Dropout(p=0.0, inplace=False)
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (fc1): Linear(in_features=16, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=16, bias=True)
  (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)

In [44]:
# test bart_encoder_layer
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
output = bart_encoder_layer(hidden_states)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[ 0.8866,  0.9529, -0.0799, -0.1191, -0.5029, -0.0234,  0.2104,
          -1.3184, -0.4672,  1.9999, -1.6847,  2.0458, -0.2491, -0.4718,
          -0.4713, -0.7077],
         [ 0.4436,  0.9698, -0.5733, -0.3309,  0.1695, -0.5152, -0.6032,
          -0.9839,  1.3384, -1.5008,  0.7942, -2.0934,  0.7560, -0.2112,
           1.6097,  0.7308],
         [-0.2578, -0.4657, -0.6380, -0.7219, -0.2121,  0.1674, -1.7863,
          -1.0285,  0.8287,  0.7400,  1.3973, -0.7108, -1.0396,  1.0719,
           0.6133,  2.0418],
         [ 0.3813, -0.3535,  0.9814,  0.4377,  0.1263,  1.8616,  0.5039,
           0.5770, -0.4403, -1.8439,  0.5302, -0.3310, -0.3050, -2.5194,
           0.1538,  0.2400]],

        [[ 0.0121, -1.0717, -0.9269, -1.3595,  1.7191, -0.7353,  0.3082,
           1.0552,  1.0707,  0.7819, -0.9308,  0.7366, -0.4859, -1.5454,
           0.0911,  1.2806],
         [ 0.8386, -1.0170, -1.2013, -0.4672, -1.1312,  0.4353, -2.2052,
    

# BartDecoderLayer

In [45]:
from bart_model_from_scratch.decoder_layer import BartDecoderLayer

In [46]:
bart_decoder_layer = BartDecoderLayer(config)
bart_decoder_layer

BartDecoderLayer(
  (self_attn): BartAttention(
    (dropout): Dropout(p=0.0, inplace=False)
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (encoder_attn): BartAttention(
    (dropout): Dropout(p=0.0, inplace=False)
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (encoder_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=T

In [47]:
# test bart_decoder_layer
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
encoder_hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
print(encoder_hidden_states.shape)
output = bart_decoder_layer(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[-0.5961, -0.2438,  2.0707,  0.8528, -0.3247, -0.8230, -0.4278,
           0.8240,  1.3356,  0.9072, -1.2335, -0.3611,  1.0598, -0.4635,
          -1.2227, -1.3541],
         [ 0.3346,  1.9473,  0.8060,  0.7508, -0.0184,  0.5457,  0.5893,
           0.1226,  0.0605, -1.1706, -1.9169, -1.6469,  0.2156, -1.4014,
           0.6602,  0.1216],
         [-0.0885,  0.1404,  0.3864, -0.6493, -1.6866, -0.4397, -0.9270,
           0.4149, -0.7006,  1.2153,  1.5563, -0.1757,  0.6654,  1.2475,
          -1.9715,  1.0127],
         [-1.0432,  1.1607, -2.1389, -0.8634, -0.8447, -0.1974, -0.2165,
          -0.3523,  2.0631, -0.0163,  0.1888,  1.3780,  0.8251,  0.3682,
          -0.5193,  0.2081]],

        [[ 1.5042,  0.0541, -1.3730, -0.1065, -1.9605, -1.6317,  0.6744,
           1.5118,  0.0288, -0.2980,  0.5462,  0.3287,  0.8157, -1.0040,
           0.5197,  0.3900],
         [ 0.4662,  1.3827, -0.2456,  1.7677, -1.0146,

# BartEmbeds

In [48]:
from bart_model_from_scratch.embeds import BartEmbeds

In [49]:
config.src_vocab_size = 50265
config.tgt_vocab_size = 50265
bart_embeds = BartEmbeds(
    num_embeddings=config.src_vocab_size,
    embedding_dim=config.d_model,
    padding_idx=config.pad_token_id,
    max_position_embeddings=config.max_position_embeddings,
)

In [50]:
# test BartEmbeds
input_ids = torch.randint(0, config.src_vocab_size, (2, 4))
output = bart_embeds(input_ids)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
tensor([[[ 2.5787, -0.3854,  0.2765,  0.5187, -1.2923,  0.0393,  1.2051,
           0.5166,  0.7347,  0.6127,  2.2282,  2.6512,  0.0838, -1.0159,
           0.0713,  0.6154],
         [-0.2097, -0.3558,  0.5705, -0.3296,  1.2369,  1.0185,  0.2950,
           1.6750,  0.8265, -1.6832, -0.6825, -0.0399,  0.4249,  1.0549,
           0.7891, -1.7902],
         [ 0.6245, -1.6816, -0.1927,  1.0083, -0.0978,  0.1457, -0.8657,
           0.1397, -1.0650,  0.9363, -0.6423,  0.1529, -1.0550,  1.1428,
           0.9876, -1.2842],
         [ 1.9195, -0.9235, -1.3548,  0.3458,  0.7945, -1.2909,  2.3783,
          -1.6891, -0.8762, -1.5661, -1.6773, -2.1997,  0.6300, -1.0207,
           1.1664,  0.9793]],

        [[ 1.0350, -0.7139, -1.8564, -0.7620, -2.7044,  0.0135,  2.7190,
          -0.7885,  1.2181, -0.3431,  0.5561,  1.7443,  2.7480,  0.2251,
           1.1599,  1.6751],
         [-0.1996,  2.8108, -0.0418,  2.4814,  0.1316, -0.9360,  2.0095,
          -0.3664,  4.0593,

# utils.mask.create_encoder_mask

In [51]:
from bart_model_from_scratch.utils.mask import (
    create_encoder_atn_mask,
)

In [52]:
# test create_encoder_mask
input_ids = torch.randint(0, 10, (5, 4)).to(torch.float32)
input_ids

tensor([[3., 8., 8., 2.],
        [2., 6., 4., 6.],
        [2., 5., 3., 7.],
        [9., 3., 0., 4.],
        [1., 4., 6., 3.]])

In [53]:
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
attention_mask

tensor([[1, 1, 1, 0],
        [0, 1, 1, 1],
        [0, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])

In [54]:
encoder_attention_mask = create_encoder_atn_mask(
    attention_mask=attention_mask,
)

In [55]:
print(encoder_attention_mask.shape)
print(encoder_attention_mask)

torch.Size([5, 1, 1, 4])
tensor([[[[1, 1, 1, 0]]],


        [[[0, 1, 1, 1]]],


        [[[0, 1, 1, 1]]],


        [[[1, 1, 1, 1]]],


        [[[1, 1, 1, 1]]]])


# BartEncoder

In [56]:
from bart_model_from_scratch.encoder import BartEncoder

In [57]:
bart_encoder = BartEncoder(config)

In [58]:
# test bart_encoder
input_embeds = torch.randn(2, 4, config.d_model)
# attention_mask = torch.randint(0, 2, (2, 4))
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 1, 0],
    ]
)
encoder_mask = create_encoder_atn_mask(
    attention_mask=attention_mask,
)
# print(f"{encoder_mask=}")
# print(f"{input_embeds.shape=}, {attention_mask.shape=}")
# print(f"{input_embeds=}, {attention_mask=}")
output = bart_encoder(input_embeds, attention_mask)
# print(output.shape)
print(output)

tensor([[[ 1.0892e+00, -5.1398e-01,  8.1492e-01, -1.2408e+00, -7.8178e-01,
           3.8425e-01, -1.2705e-01,  1.3179e+00,  4.3272e-03,  5.0743e-01,
          -3.0564e-01, -4.8269e-01, -1.6059e+00, -3.6146e-01, -1.0366e+00,
           2.3379e+00],
         [ 3.7795e-01,  1.6069e+00,  3.5906e-01, -8.9692e-01, -2.1237e+00,
           9.7817e-01, -6.3405e-01,  1.0918e+00, -1.7468e+00,  1.7888e-01,
           9.1982e-01, -3.7176e-01,  2.9797e-01, -3.3642e-01, -5.4547e-01,
           8.4462e-01],
         [ 5.9538e-01, -1.6238e+00,  1.7015e+00,  6.9254e-01, -1.4663e+00,
          -5.2375e-01, -5.0907e-02, -5.4977e-01,  7.7115e-01,  3.7957e-01,
          -6.6095e-02,  1.2309e+00,  8.0230e-02, -1.2828e-01, -1.9495e+00,
           9.0717e-01],
         [-7.0943e-01,  1.1993e+00,  2.4329e-01, -8.7603e-02,  1.3522e+00,
           1.0797e-01, -2.1006e+00,  1.1015e+00,  3.1025e-02,  1.7022e-01,
           3.4893e-01, -2.1934e+00,  6.7861e-01,  7.5636e-01, -7.0055e-01,
          -1.9784e-01]],

  

# utils.mask.causal_mask

In [59]:
from bart_model_from_scratch.utils.mask import (
    causal_mask
)

In [60]:
x = causal_mask(
    tgt_len=4,
    device=torch.device("cpu"),
)
x

tensor([[[ True, False, False, False],
         [ True,  True, False, False],
         [ True,  True,  True, False],
         [ True,  True,  True,  True]]])

# utils.mask.create_decoder_mask

In [61]:
from bart_model_from_scratch.utils.mask import (
    create_decoder_atn_mask
)

In [62]:
# test causal_mask
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0]
])
dtype = torch.float32
create_decoder_atn_mask(
    attention_mask=attention_mask,
    tgt_len=5,
)

tensor([[[[1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0]]],


        [[[1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0]]]])

# BartDecoder

In [63]:
from bart_model_from_scratch.decoder import BartDecoder

In [64]:
bart_decoder = BartDecoder(config)
bart_decoder

BartDecoder(
  (dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x BartDecoderLayer(
      (self_attn): BartAttention(
        (dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=16, out_features=16, bias=True)
        (v_proj): Linear(in_features=16, out_features=16, bias=True)
        (q_proj): Linear(in_features=16, out_features=16, bias=True)
        (out_proj): Linear(in_features=16, out_features=16, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (activation_fn): GELU(approximate='none')
      (activation_dropout): Dropout(p=0.0, inplace=False)
      (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): BartAttention(
        (dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=16, out_features=16, bias=True)
        (v_proj): Linear(in_features=16, out_features=16, bias=True)
        (q_proj): Linear(in_features=16, out_features=16

In [65]:
# test bart_decoder
input_embeds = torch.randn(2, 4, config.d_model)
attention_mask = torch.randint(0, 2, (2, 4))
encoder_hidden_states = torch.randn(2, 4, config.d_model)
encoder_attention_mask = torch.randint(0, 2, (2, 4))
output = bart_decoder(
    input_embeds=input_embeds,
    attention_mask=attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
)
print(output.shape)

torch.Size([2, 4, 16])


In [66]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq
import torch.nn as nn

In [67]:
model = BartSeq2seq(config)
model

BartSeq2seq(
  (inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (decoder_inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (encoder): BartEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x BartEncoderLayer(
        (self_attn): BartAttention(
          (dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation_fn): GELU(

In [68]:
# test model
input_ids = torch.randint(0, 10, (2, 4))
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
decoder_input_ids = torch.randint(0, 10, (2, 4))
decoder_attention_mask = (decoder_input_ids != config.pad_token_id).to(torch.int64)
print(input_ids.shape)
print(attention_mask.shape)
print(decoder_input_ids.shape)
print(decoder_attention_mask.shape)

torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])


In [69]:
out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
print(out)

tensor([[[-0.0176, -0.0621, -0.0379,  ...,  0.0216,  0.0350, -0.0982],
         [-0.0131, -0.0645,  0.0163,  ..., -0.0073,  0.0322, -0.1053],
         [ 0.0045, -0.0684, -0.0128,  ...,  0.0121,  0.0448, -0.0999],
         [ 0.0402, -0.0925,  0.0092,  ..., -0.0079,  0.0078, -0.0703]],

        [[-0.0914, -0.0359, -0.0699,  ...,  0.0278,  0.0752,  0.0905],
         [-0.0409,  0.0278, -0.0744,  ...,  0.0379,  0.0917,  0.0309],
         [-0.0539,  0.0495,  0.0050,  ..., -0.0140,  0.1010,  0.0287],
         [ 0.0660,  0.1194, -0.1428,  ...,  0.0897, -0.0701, -0.0951]]],
       grad_fn=<ViewBackward0>)


In [70]:
encoder_out = model.get_encoder_out(
    input_ids=input_ids,
    attention_mask=attention_mask,
)
print(encoder_out.last_hidden_state)

tensor([[[ 0.6533,  0.6205, -1.5151, -0.4721, -0.2612, -1.0028,  1.0264,
           0.5035, -1.9236,  0.7580, -1.6013,  1.2098,  0.5314,  0.0932,
           1.2997,  0.0803],
         [ 0.6656,  0.6624, -1.5440, -0.1468, -0.1172, -1.0078,  1.0812,
           0.5537, -1.9377,  0.0145, -1.6171,  1.3286,  0.5136,  0.1114,
           1.3362,  0.1034],
         [ 0.7470,  0.6833, -0.1467, -0.4984, -0.3193, -1.1431,  1.2286,
           0.6032, -2.1538,  0.8538, -1.8473, -0.1103,  0.5894, -0.1431,
           1.4955,  0.1611],
         [ 0.6544,  0.6215, -1.5133, -0.4707, -0.2798, -1.0012,  1.0273,
           0.5045, -1.9217,  0.7590, -1.5995,  1.2106,  0.5325,  0.0944,
           1.3005,  0.0815]],

        [[-0.5735, -0.3089, -1.5566, -0.2495, -0.4468, -0.1620,  0.0901,
          -1.9181, -0.0690, -0.3516,  1.3955,  0.7145,  0.1303,  1.6126,
          -0.2803,  1.9735],
         [-2.3941, -0.2773,  0.1096, -1.6209,  0.5751,  0.2854, -1.2639,
           1.0502,  1.0320,  0.4221,  1.1866, -0.2

In [71]:
decoder_out = model.get_decoder_out(
    input_ids=decoder_input_ids,
    attention_mask=decoder_attention_mask,
    encoder_hidden_states=encoder_out.last_hidden_state,
    encoder_attention_mask=attention_mask
)
print(decoder_out.last_hidden_state)

tensor([[[-0.4722, -0.8670,  1.1513,  0.8716,  0.0970, -1.6213, -0.2579,
          -0.1375, -0.2392,  1.5019, -0.2585, -0.7533,  0.2551, -0.4288,
          -1.2025,  2.3610],
         [-0.5671, -0.2779,  1.0956,  0.7926, -0.0216, -1.7280, -0.3405,
          -0.2478, -0.3264,  1.4731,  0.3825, -0.8803,  0.1753, -0.5175,
          -1.3260,  2.3138],
         [-0.6100, -0.9697,  1.0977,  0.7665, -0.0370, -1.8207, -0.4089,
          -0.2735, -0.3790,  1.4734,  0.3942, -0.8990,  0.2085, -0.5179,
          -0.3899,  2.3653],
         [-0.5310, -0.9096,  1.1117,  0.7959, -0.1567, -1.7070, -0.3315,
          -0.2183, -0.3069,  1.4576,  0.4141, -0.3098,  0.1744, -0.5214,
          -1.2970,  2.3354]],

        [[ 0.2701, -0.0796, -1.7033, -0.1933, -0.4943,  1.4349,  1.2847,
          -0.5690,  1.6044,  1.0447, -0.6129, -0.2018, -0.3157,  0.5989,
          -0.0275, -2.0404],
         [ 0.8289,  1.2454, -1.0755, -0.1345, -0.0269,  0.9752,  1.4214,
          -0.9358, -0.0281,  1.3392, -0.3158, -0.8

# BartSeq2seq

In [72]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq

In [73]:
model = BartSeq2seq(config)
model

BartSeq2seq(
  (inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (decoder_inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (encoder): BartEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x BartEncoderLayer(
        (self_attn): BartAttention(
          (dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation_fn): GELU(

In [74]:
# test model
input_ids = torch.randint(0, 10, (2, 5))
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
decoder_input_ids = torch.randint(0, 10, (2, 4))
decoder_attention_mask = (decoder_input_ids != config.pad_token_id).to(torch.int64)
print(input_ids.shape)
print(attention_mask.shape)
print(decoder_input_ids.shape)
print(decoder_attention_mask.shape)
logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
print(logits.shape)
print(logits)

torch.Size([2, 5])
torch.Size([2, 5])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4, 50265])
tensor([[[-0.0202,  0.0158, -0.0945,  ...,  0.0804,  0.1224, -0.1314],
         [ 0.0142,  0.0303, -0.1050,  ...,  0.0612,  0.0224, -0.1053],
         [-0.0238,  0.0365, -0.1340,  ...,  0.0475,  0.0519, -0.0880],
         [-0.0837,  0.0265, -0.0657,  ...,  0.1468,  0.1203, -0.1235]],

        [[-0.0453, -0.0415,  0.0021,  ..., -0.0533,  0.1263, -0.0591],
         [-0.0576, -0.0333, -0.0234,  ...,  0.0065,  0.1910, -0.0556],
         [-0.0805,  0.0257,  0.0817,  ...,  0.0409,  0.0625, -0.0747],
         [-0.1220, -0.0052,  0.0834,  ...,  0.1382,  0.0995, -0.0588]]],
       grad_fn=<ViewBackward0>)
