# import

In [1]:
import torch
from bart_model_from_scratch.multihead_attn import BartAttention
from transformers import BartConfig

  from .autonotebook import tqdm as notebook_tqdm


# BartConfig

In [2]:
config = BartConfig()
config.pad_token_id = 2
config.encoder_layerdrop = 0.1
config.decoder_layerdrop = 0.1
config.d_model = config.encoder_attention_heads

# BartAttention

In [3]:
bart_attn = BartAttention(
    embed_dim=config.d_model,
    num_heads=config.encoder_attention_heads,
    dropout=config.attention_dropout,
)
print(bart_attn)

BartAttention(
  (dropout): Dropout(p=0.0, inplace=False)
  (k_proj): Linear(in_features=16, out_features=16, bias=True)
  (v_proj): Linear(in_features=16, out_features=16, bias=True)
  (q_proj): Linear(in_features=16, out_features=16, bias=True)
  (out_proj): Linear(in_features=16, out_features=16, bias=True)
)


In [4]:
# test bart_attn
hidden_states = torch.randn(2, 4, config.d_model)
output = bart_attn(hidden_states)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
tensor([[[-0.0744,  0.1096,  0.3569, -0.2427, -0.5184, -0.2361,  0.1305,
          -0.0049, -0.1922,  0.3523, -0.1591,  0.0055,  0.3097,  0.0006,
           0.1725,  0.1720],
         [-0.0716,  0.1093,  0.3470, -0.2503, -0.5655, -0.2798,  0.0930,
           0.0232, -0.1914,  0.3430, -0.1610,  0.0432,  0.3209, -0.0043,
           0.1427,  0.2461],
         [-0.1000,  0.1051,  0.3487, -0.2298, -0.5710, -0.2846,  0.1228,
           0.0045, -0.2116,  0.3467, -0.1725,  0.0270,  0.3110,  0.0017,
           0.1653,  0.2170],
         [-0.0990,  0.1039,  0.3398, -0.2383, -0.5440, -0.2337,  0.0908,
           0.0086, -0.1968,  0.3508, -0.1634,  0.0414,  0.3007,  0.0112,
           0.1338,  0.2188]],

        [[ 0.3108, -0.0422,  0.2644, -0.3930,  0.2264, -0.1212, -0.4148,
           0.1483, -0.3362, -0.1036, -0.0262,  0.2476,  0.1763,  0.0108,
           0.3939, -0.3109],
         [ 0.2937, -0.0556,  0.2618, -0.3986,  0.2553, -0.1046, -0.4136,
           0.1352, -0.3747,

# BartEncoderLayer

In [5]:
from bart_model_from_scratch.encoder_layer import BartEncoderLayer

In [6]:
bart_encoder_layer = BartEncoderLayer(config)
bart_encoder_layer

BartEncoderLayer(
  (self_attn): BartAttention(
    (dropout): Dropout(p=0.0, inplace=False)
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (fc1): Linear(in_features=16, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=16, bias=True)
  (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)

In [7]:
# test bart_encoder_layer
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
output = bart_encoder_layer(hidden_states)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[ 0.3327, -0.8196, -1.1490, -1.4489, -0.2774, -0.3396, -0.7298,
          -0.2572, -0.8149,  1.7467,  0.0631,  1.3611,  1.2931,  0.1309,
          -0.8364,  1.7453],
         [ 0.0794,  0.8044,  0.0132, -0.7686, -0.2295, -1.3868,  0.4922,
          -1.5978,  0.9453,  0.9738, -0.4922,  0.8979, -2.0766,  0.0068,
           1.3225,  1.0161],
         [-0.1598,  1.5161, -1.5436, -0.9204,  0.9891, -0.5670,  0.6629,
           0.5229, -0.2489,  1.8504, -0.2684, -0.7882,  0.5846,  0.5575,
          -1.8714, -0.3158],
         [ 0.2629, -0.4137, -0.1411, -1.0722, -1.3532,  0.8676, -0.6836,
           0.6728,  1.7002,  0.8968, -0.6476,  0.4595,  0.9586,  1.3140,
          -1.1494, -1.6717]],

        [[ 0.3345,  0.3017, -1.5003, -0.8372,  0.6681, -0.1755, -1.2749,
           1.1982,  2.1567, -1.4367,  0.9975,  0.6556,  0.5090, -0.4519,
          -0.4173, -0.7274],
         [ 0.1036,  1.3660,  0.3707,  0.8148, -1.4039,  0.3567, -0.5273,
    

# BartDecoderLayer

In [8]:
from bart_model_from_scratch.decoder_layer import BartDecoderLayer

In [9]:
bart_decoder_layer = BartDecoderLayer(config)
bart_decoder_layer

BartDecoderLayer(
  (self_attn): BartAttention(
    (dropout): Dropout(p=0.0, inplace=False)
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (encoder_attn): BartAttention(
    (dropout): Dropout(p=0.0, inplace=False)
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (encoder_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=T

In [10]:
# test bart_decoder_layer
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
encoder_hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
print(encoder_hidden_states.shape)
output = bart_decoder_layer(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[-2.1775e+00, -6.1314e-02, -4.5473e-01,  5.8933e-01,  3.8494e-01,
          -6.3279e-02, -9.8499e-01, -6.0135e-01,  1.2112e+00, -1.6528e-01,
           1.0329e+00, -1.4881e+00, -2.0659e-01,  4.1656e-02,  1.6526e+00,
           1.2905e+00],
         [ 1.4260e+00, -8.2967e-03, -5.1471e-01,  1.8634e+00, -4.3680e-01,
          -9.6977e-01, -1.7588e+00,  8.6321e-01, -9.1751e-01,  7.2287e-01,
          -1.2078e-01,  9.1992e-01, -1.2167e-01,  3.7421e-01, -1.6623e+00,
           3.4094e-01],
         [-9.3464e-01,  1.0449e+00,  1.3605e+00, -2.4156e-01, -3.4064e-01,
           1.4320e+00, -2.5092e-01, -3.8674e-01, -1.0139e+00,  4.4721e-02,
           9.0861e-01,  5.3833e-01,  5.0169e-01, -1.3184e+00,  8.6217e-01,
          -2.2062e+00],
         [-1.5881e+00, -3.8869e-01,  1.3483e+00,  4.0835e-01,  4.5905e-01,
           1.5408e-01,  6.6549e-01, -7.5024e-01,  3.0783e-01,  1.1251e+00,
           1.3061e+00,  7.7703e-01

# BartEmbeds

In [11]:
from bart_model_from_scratch.embeds import BartEmbeds

In [12]:
config.src_vocab_size = 50265
config.tgt_vocab_size = 50265
bart_embeds = BartEmbeds(
    num_embeddings=config.src_vocab_size,
    embedding_dim=config.d_model,
    padding_idx=config.pad_token_id,
    max_position_embeddings=config.max_position_embeddings,
)

In [13]:
# test BartEmbeds
input_ids = torch.randint(0, config.src_vocab_size, (2, 4))
output = bart_embeds(input_ids)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
tensor([[[-1.0631, -2.2427,  1.2375,  0.2208, -1.1868,  4.9200,  1.0919,
           0.1336,  2.0748,  0.2285,  0.2080,  0.0513,  0.4489, -0.5368,
          -0.2832, -0.6980],
         [-1.3945,  1.1207, -0.9415, -0.9211, -0.8879,  1.5329, -0.4966,
          -0.0385, -1.3054,  0.4623,  0.5357, -1.7724,  1.3427,  0.8357,
           2.6756, -1.6300],
         [ 1.2986, -0.1557,  1.1247, -1.5941,  0.9674,  0.1499,  1.3442,
          -1.9106, -0.1675, -0.4938,  0.8028,  0.4982, -1.4611, -1.2048,
           0.5588,  1.9044],
         [ 0.5611, -1.1732,  1.1563,  0.1908,  0.3299, -1.2589,  0.8412,
          -0.9499, -1.0725,  0.6901, -0.5582,  1.3876, -0.8428, -0.7935,
          -0.5209, -1.5504]],

        [[-0.3652, -0.4517,  1.9866,  0.4457, -0.8511,  3.0066, -0.8788,
           0.3108,  1.8888,  1.3156,  0.7894, -1.3409,  1.6299,  0.3381,
          -1.0987, -1.5702],
         [-1.0823, -0.6783, -0.4636,  1.7522, -0.9803,  0.1611, -1.2355,
          -0.6614,  1.2854,

# utils.mask.create_encoder_mask

In [14]:
from bart_model_from_scratch.utils.mask import (
    create_encoder_atn_mask,
)

In [15]:
# test create_encoder_mask
input_ids = torch.randint(0, 10, (5, 4)).to(torch.float32)
input_ids

tensor([[9., 5., 6., 7.],
        [2., 4., 1., 0.],
        [6., 8., 4., 8.],
        [6., 6., 0., 4.],
        [3., 5., 9., 5.]])

In [16]:
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
attention_mask

tensor([[1, 1, 1, 1],
        [0, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])

In [17]:
encoder_attention_mask = create_encoder_atn_mask(
    attention_mask=attention_mask,
)

In [18]:
print(encoder_attention_mask.shape)
print(encoder_attention_mask)

torch.Size([5, 1, 1, 4])
tensor([[[[1, 1, 1, 1]]],


        [[[0, 1, 1, 1]]],


        [[[1, 1, 1, 1]]],


        [[[1, 1, 1, 1]]],


        [[[1, 1, 1, 1]]]])


# BartEncoder

In [19]:
from bart_model_from_scratch.encoder import BartEncoder

In [20]:
bart_encoder = BartEncoder(config)

In [21]:
# test bart_encoder
input_embeds = torch.randn(2, 4, config.d_model)
# attention_mask = torch.randint(0, 2, (2, 4))
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 1, 0],
    ]
)
encoder_mask = create_encoder_atn_mask(
    attention_mask=attention_mask,
)
# print(f"{encoder_mask=}")
# print(f"{input_embeds.shape=}, {attention_mask.shape=}")
# print(f"{input_embeds=}, {attention_mask=}")
output = bart_encoder(input_embeds, attention_mask)
# print(output.shape)
print(output)

tensor([[[-1.0602,  0.1185, -0.5197,  0.3883, -0.2769,  0.5868,  0.8872,
           0.1394,  0.3414,  1.2967,  0.8905, -1.3005,  1.6039, -1.4406,
          -2.0345,  0.3797],
         [-1.0937,  0.5849, -0.4723,  0.6667,  0.3754,  0.4399, -0.4351,
           0.3751, -0.0838, -0.2268,  2.0441, -0.7016,  1.5955, -0.7968,
          -2.2847,  0.0131],
         [-0.4712, -0.3401, -0.4874,  0.3512, -0.0859,  1.1293,  0.7766,
           0.8044,  1.0132,  0.4956, -0.1094, -0.7210,  1.9305, -0.6250,
          -2.1012, -1.5594],
         [ 0.7973, -2.0460, -0.4883, -0.7531, -0.0153,  0.9395,  0.9778,
           0.4363,  0.2023,  0.0991,  0.7699,  0.9465,  1.3266, -0.0418,
          -1.3764, -1.7742]],

        [[-2.3264, -0.1708,  0.0179, -0.5929,  1.4490,  0.1736, -0.0320,
           1.2920,  1.1968,  0.3141,  0.7489,  0.0444,  0.5436,  0.1337,
          -1.1700, -1.6218],
         [-0.6744, -0.0647, -0.4591,  0.0188,  0.0266,  1.9228,  1.7739,
          -0.8868,  0.8052,  0.3948,  1.1258, -0.6

# utils.mask.causal_mask

In [22]:
from bart_model_from_scratch.utils.mask import (
    causal_mask
)

In [23]:
x = causal_mask(
    tgt_len=4,
    device=torch.device("cpu"),
)
x

tensor([[[ True, False, False, False],
         [ True,  True, False, False],
         [ True,  True,  True, False],
         [ True,  True,  True,  True]]])

# utils.mask.create_decoder_mask

In [24]:
from bart_model_from_scratch.utils.mask import (
    create_decoder_atn_mask
)

In [25]:
# test causal_mask
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0]
])
dtype = torch.float32
create_decoder_atn_mask(
    attention_mask=attention_mask,
    tgt_len=5,
)

tensor([[[[1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0],
          [1, 1, 1, 0, 0]]],


        [[[1, 0, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0],
          [1, 1, 0, 0, 0]]]])

# BartDecoder

In [26]:
from bart_model_from_scratch.decoder import BartDecoder

In [27]:
bart_decoder = BartDecoder(config)
bart_decoder

BartDecoder(
  (dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x BartDecoderLayer(
      (self_attn): BartAttention(
        (dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=16, out_features=16, bias=True)
        (v_proj): Linear(in_features=16, out_features=16, bias=True)
        (q_proj): Linear(in_features=16, out_features=16, bias=True)
        (out_proj): Linear(in_features=16, out_features=16, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (activation_fn): GELU(approximate='none')
      (activation_dropout): Dropout(p=0.0, inplace=False)
      (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): BartAttention(
        (dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=16, out_features=16, bias=True)
        (v_proj): Linear(in_features=16, out_features=16, bias=True)
        (q_proj): Linear(in_features=16, out_features=16

In [28]:
# test bart_decoder
input_embeds = torch.randn(2, 4, config.d_model)
attention_mask = torch.randint(0, 2, (2, 4))
encoder_hidden_states = torch.randn(2, 4, config.d_model)
encoder_attention_mask = torch.randint(0, 2, (2, 4))
output = bart_decoder(
    input_embeds=input_embeds,
    attention_mask=attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
)
print(output.shape)

torch.Size([2, 4, 16])


In [29]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq
import torch.nn as nn

In [30]:
model = BartSeq2seq(config)
model

BartSeq2seq(
  (inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (decoder_inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16, padding_idx=2)
  )
  (encoder): BartEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x BartEncoderLayer(
        (self_attn): BartAttention(
          (dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation_fn): GELU(

In [31]:
# test model
input_ids = torch.randint(0, 10, (2, 4))
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
decoder_input_ids = torch.randint(0, 10, (2, 4))
decoder_attention_mask = (decoder_input_ids != config.pad_token_id).to(torch.int64)
print(input_ids.shape)
print(attention_mask.shape)
print(decoder_input_ids.shape)
print(decoder_attention_mask.shape)

torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])


In [32]:
out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
print(out)

tensor([[[-0.1131, -0.1536,  0.1407,  ...,  0.0852,  0.0952,  0.1190],
         [-0.0821, -0.0387, -0.0959,  ...,  0.0225,  0.0214, -0.0211],
         [-0.0823, -0.0334, -0.0518,  ..., -0.0282, -0.0603, -0.0517],
         [-0.3557, -0.0599,  0.0321,  ...,  0.0739,  0.0292, -0.0663]],

        [[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan]]],
       grad_fn=<ViewBackward0>)


In [33]:
encoder_out = model.get_encoder_out(
    input_ids=input_ids,
    attention_mask=attention_mask,
)
print(encoder_out.last_hidden_state)

tensor([[[ 1.4122,  1.2157, -0.1767,  0.5970,  0.1617, -2.4125, -0.2758,
          -0.9491, -0.6338, -0.5471,  1.0070,  0.1406, -1.0475,  1.5750,
          -0.0071, -0.0598],
         [ 1.1111, -0.8287, -2.0083,  0.9725, -0.0111,  0.6718,  1.3403,
          -1.6163, -0.0923,  0.5758, -0.5428, -0.4998,  0.8494,  1.3499,
          -0.6033, -0.6683],
         [ 0.3647, -1.3126,  1.4006, -0.2067, -0.5217, -0.2639,  0.7043,
           0.8158,  0.6458, -1.2123,  0.9422, -2.1219,  1.5767,  0.0918,
           0.0267, -0.9296],
         [-0.6509,  0.8416, -0.6647,  2.4102, -0.9511, -0.4870,  0.9381,
           0.1887, -2.1130,  0.6922,  0.6684,  0.3566, -0.2612, -0.7950,
          -0.4904,  0.3176]],

        [[ 1.3371,  1.1760, -0.0417,  0.8256,  0.1379, -1.9897, -1.2733,
          -0.4875, -0.4235, -0.4504,  1.0214,  0.2570, -0.7912,  1.6655,
          -1.1721,  0.2087],
         [ 0.1200, -0.6762, -0.2913, -0.0723,  0.2435,  0.1079,  0.2211,
           0.0204, -1.6548,  1.3576, -1.7482, -0.2

In [34]:
decoder_out = model.get_decoder_out(
    input_ids=decoder_input_ids,
    attention_mask=decoder_attention_mask,
    encoder_hidden_states=encoder_out.last_hidden_state,
    encoder_attention_mask=attention_mask
)
print(decoder_out.last_hidden_state)

tensor([[[-1.0374e+00, -1.8223e-01, -5.1885e-01, -1.0337e+00,  6.2146e-01,
          -5.7532e-01,  1.3011e+00, -1.9218e-01,  2.8882e-01, -7.9383e-01,
          -1.2689e+00, -1.2259e+00,  1.5824e-01,  1.3418e+00,  2.1043e+00,
           1.0126e+00],
         [-3.8584e-01, -1.8447e+00, -1.4931e-01,  9.6935e-02,  1.5092e+00,
           1.8017e-01, -4.9214e-01, -5.7917e-02, -3.3987e-02,  2.3276e-01,
          -8.2695e-04, -1.5681e-01, -1.8046e+00,  2.4684e+00,  6.2848e-01,
          -1.8976e-01],
         [-1.1581e+00,  1.5587e+00,  2.2751e-01, -9.4714e-01, -1.2161e-01,
          -1.9868e+00,  1.7790e+00, -2.5580e-01, -3.0466e-01,  1.2111e-01,
           2.1806e-01,  3.3012e-02,  6.5429e-01, -5.7290e-01,  1.5798e+00,
          -8.2447e-01],
         [-4.3802e-01, -2.8482e+00,  1.9894e-02, -2.4251e-02, -7.1148e-01,
           5.8337e-01,  6.8500e-01,  3.3771e-01, -1.2723e+00,  2.3079e-02,
          -2.4077e-01, -5.3852e-02,  5.9618e-01,  8.1610e-01,  1.6936e+00,
           8.3392e-01]],

  

# BartSeq2seq

In [35]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq

In [36]:
model = BartSeq2seq(config)
config, model

(BartConfig {
   "activation_dropout": 0.0,
   "activation_function": "gelu",
   "attention_dropout": 0.0,
   "bos_token_id": 0,
   "classifier_dropout": 0.0,
   "d_model": 16,
   "decoder_attention_heads": 16,
   "decoder_ffn_dim": 4096,
   "decoder_layerdrop": 0.1,
   "decoder_layers": 12,
   "decoder_start_token_id": 2,
   "dropout": 0.1,
   "encoder_attention_heads": 16,
   "encoder_ffn_dim": 4096,
   "encoder_layerdrop": 0.1,
   "encoder_layers": 12,
   "eos_token_id": 2,
   "forced_eos_token_id": 2,
   "id2label": {
     "0": "LABEL_0",
     "1": "LABEL_1",
     "2": "LABEL_2"
   },
   "init_std": 0.02,
   "is_encoder_decoder": true,
   "label2id": {
     "LABEL_0": 0,
     "LABEL_1": 1,
     "LABEL_2": 2
   },
   "max_position_embeddings": 1024,
   "model_type": "bart",
   "num_hidden_layers": 12,
   "pad_token_id": 2,
   "scale_embedding": false,
   "src_vocab_size": 50265,
   "tgt_vocab_size": 50265,
   "transformers_version": "4.42.0.dev0",
   "use_cache": true,
   "vocab_siz

In [37]:
# test model
input_ids = torch.randint(0, 10, (2, 5))
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
decoder_input_ids = torch.randint(0, 10, (2, 4))
decoder_attention_mask = (decoder_input_ids != config.pad_token_id).to(torch.int64)
print(input_ids.shape)
print(attention_mask.shape)
print(decoder_input_ids.shape)
print(decoder_attention_mask.shape)
logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
print(logits.shape)
print(logits)

torch.Size([2, 5])
torch.Size([2, 5])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4, 50265])
tensor([[[ 0.0234,  0.0852,  0.0187,  ...,  0.0997, -0.0197, -0.0242],
         [ 0.0324, -0.0408, -0.0875,  ..., -0.0095,  0.0822,  0.0806],
         [-0.0584,  0.0132,  0.0794,  ..., -0.0436, -0.0464, -0.0446],
         [ 0.0496,  0.0683, -0.0699,  ...,  0.0742,  0.0520, -0.0190]],

        [[-0.0053, -0.0426,  0.0097,  ...,  0.0474,  0.0243,  0.0004],
         [ 0.0125,  0.0528,  0.0499,  ...,  0.0844, -0.0011,  0.0659],
         [-0.0527,  0.0634,  0.0419,  ...,  0.0414, -0.1528,  0.0520],
         [ 0.0286,  0.0942,  0.0424,  ...,  0.1626, -0.0453,  0.0037]]],
       grad_fn=<ViewBackward0>)
