# import

In [1]:
import torch
from bart_model_from_scratch.multihead_attn import BartAttention
from transformers import BartConfig

  from .autonotebook import tqdm as notebook_tqdm


# BartConfig

In [2]:
# Seed PyTorch's global RNG so the random test inputs drawn below with
# torch.randn / torch.randint are repeatable after Restart Kernel -> Run All.
torch.manual_seed(0)

# Start from the default BartConfig and override a few fields for the small
# smoke tests in this notebook.
config = BartConfig()
config.pad_token_id = 2  # compared against input_ids further down to build padding masks
config.encoder_layerdrop = 0.1
config.decoder_layerdrop = 0.1
# Shrink the model width to the encoder head count (16, per the Linear shapes
# printed by the cells below) so every module stays tiny.
# NOTE(review): this presumably leaves a single dimension per attention head --
# confirm that is the intended test setup.
config.d_model = config.encoder_attention_heads

# BartAttention

In [3]:
# Build a stand-alone attention module whose sizes come straight from the config,
# then print its submodules.
attn_kwargs = {
    "embed_dim": config.d_model,
    "num_heads": config.encoder_attention_heads,
    "dropout": config.attention_dropout,
}
bart_attn = BartAttention(**attn_kwargs)
print(bart_attn)

BartAttention(
  (k_proj): Linear(in_features=16, out_features=16, bias=True)
  (v_proj): Linear(in_features=16, out_features=16, bias=True)
  (q_proj): Linear(in_features=16, out_features=16, bias=True)
  (out_proj): Linear(in_features=16, out_features=16, bias=True)
)


In [4]:
# Smoke-test BartAttention: feed a random (2, 4, d_model) tensor through the
# module and show the shape and values of what comes back.
hidden_states = torch.randn((2, 4, config.d_model))
output = bart_attn(hidden_states)
print(output.shape, output, sep="\n")

torch.Size([2, 4, 16])
tensor([[[-0.0657,  0.3638,  0.3547,  0.1457,  0.2279,  0.1052, -0.2803,
          -0.3222,  0.1243,  0.0783,  0.1117,  0.0024, -0.3913, -0.1754,
           0.5060,  0.2808],
         [-0.2492,  0.1646, -0.0440, -0.1268, -0.0802,  0.0290, -0.1925,
          -0.0887, -0.0463,  0.0794,  0.0687, -0.0882, -0.3179, -0.2385,
           0.2092,  0.1054],
         [-0.3266, -0.0022,  0.1153,  0.1191,  0.2794,  0.0027, -0.0410,
           0.0716, -0.1647,  0.1443,  0.2131,  0.1856, -0.3907, -0.1373,
           0.5034,  0.0290],
         [-0.1162,  0.6666,  0.2486, -0.3639, -0.0015,  0.1141, -0.5059,
          -0.5173,  0.4727,  0.2560, -0.0424, -0.2576, -0.0081, -0.1656,
           0.3775,  0.4282]],

        [[ 0.0750,  0.2250,  0.5008, -0.1996,  0.4376, -0.1817, -0.3130,
          -0.4451, -0.1459,  0.1269, -0.3016, -0.4536, -0.3248,  0.3432,
          -0.1023,  0.5337],
         [-0.0058,  0.3760,  0.3142,  0.2778,  0.2871, -0.2440, -0.3281,
          -0.1237,  0.1224,

# BartEncoderLayer

In [5]:
from bart_model_from_scratch.encoder_layer import BartEncoderLayer

In [6]:
# Instantiate a single encoder layer from the shared config; the bare name on
# the last line makes the notebook display the module tree.
bart_encoder_layer = BartEncoderLayer(config)
bart_encoder_layer

BartEncoderLayer(
  (self_attn): BartAttention(
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (fc1): Linear(in_features=16, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=16, bias=True)
  (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)

In [7]:
# Smoke-test the encoder layer: a random float32 (2, 4, d_model) input goes in,
# and the input shape, output shape and output values are printed in that order.
hidden_states = torch.randn((2, 4, config.d_model), dtype=torch.float32)
print(hidden_states.shape)
output = bart_encoder_layer(hidden_states)
print(output.shape, output, sep="\n")

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[ 1.7303, -0.0684, -0.5742,  1.0645, -0.1966, -0.0589, -0.4204,
           0.5996,  0.3824, -1.5182,  0.3269, -0.2785,  1.4044, -0.1941,
          -2.4953,  0.2966],
         [ 0.6524, -0.4579, -0.1538,  1.2862,  0.1561, -0.1879,  1.1720,
          -0.8738, -0.1441, -0.0564,  0.2808, -0.3219, -1.2976,  2.1412,
           0.0438, -2.2390],
         [-0.7470,  0.6250,  0.1436,  1.5897, -0.6523,  0.8605, -0.1740,
           0.0880, -1.4315, -0.7269, -0.9503,  1.4618,  0.7300, -1.1398,
           1.5396, -1.2164],
         [ 0.0480, -0.1264, -0.8401, -0.1930,  0.9730,  1.0608,  2.0433,
          -0.2330, -1.5247, -1.0239,  0.7541,  1.3340, -1.7605, -0.1561,
          -0.0573, -0.2981]],

        [[ 0.8921, -1.4672,  0.3682,  1.3168,  1.3081, -0.4817, -0.2093,
          -0.3458,  1.4984, -0.9442, -1.3810, -0.8306,  1.1296,  0.4569,
          -1.3615,  0.0512],
         [ 1.1283,  0.3169,  0.1387,  1.1286, -1.1982,  2.0617, -0.8212,
    

# BartDecoderLayer

In [8]:
from bart_model_from_scratch.decoder_layer import BartDecoderLayer

In [9]:
# Instantiate a single decoder layer from the shared config; the bare name on
# the last line makes the notebook display the module tree.
bart_decoder_layer = BartDecoderLayer(config)
bart_decoder_layer

BartDecoderLayer(
  (self_attn): BartAttention(
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (encoder_attn): BartAttention(
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (encoder_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (fc1): Linear(in_features=16, out_features=4096, bias=True)
  (fc2): Linear(in_feat

In [10]:
# Smoke-test the decoder layer with two independent random float32 tensors of
# shape (2, 4, d_model): one as the decoder input, one as the encoder output.
# The generator draws them in the same order as two consecutive randn calls.
hidden_states, encoder_hidden_states = (
    torch.randn(2, 4, config.d_model, dtype=torch.float32) for _ in range(2)
)
print(hidden_states.shape, encoder_hidden_states.shape, sep="\n")
layer_inputs = {
    "hidden_states": hidden_states,
    "encoder_hidden_states": encoder_hidden_states,
}
output = bart_decoder_layer(**layer_inputs)
print(output.shape, output, sep="\n")

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[-1.5776, -0.3299, -1.4813,  0.0822,  0.3337,  1.8795, -0.7751,
           2.0257, -0.2111, -1.4173,  0.4328,  0.6545, -0.1010, -0.0264,
           0.1103,  0.4010],
         [ 0.2131, -1.2151,  0.2691,  1.3050, -0.7477,  0.3791, -2.4871,
           0.1070,  1.5471, -0.8608,  0.4454, -0.9378,  0.4249,  1.0222,
          -0.0651,  0.6006],
         [-0.3252, -1.2854, -0.1554, -1.1722,  0.5934, -0.6179,  0.5589,
           0.4534, -0.9157,  1.4623, -0.8523,  2.5113,  0.3034, -0.1923,
          -1.0154,  0.6491],
         [ 1.5939,  0.4814,  0.6533,  0.8072, -1.2809,  0.4014,  0.0199,
           0.4718, -0.0068,  0.0431,  1.0727,  0.2132,  0.0430, -1.8868,
          -0.3345, -2.2919]],

        [[-0.2395,  1.6573, -2.0908,  0.9395, -2.0119,  0.2139,  0.6822,
          -0.5615,  0.1591, -0.3723,  0.7001,  1.3691, -0.3070, -0.5339,
           0.3992, -0.0034],
         [ 0.9897, -0.4725,  0.6094,  0.2912, -1.4422,

# BartEmbeds

In [11]:
from bart_model_from_scratch.embeds import BartEmbeds

In [12]:
# Record the source and target vocabulary sizes on the config (source first,
# then target), then build the embedding module from config values.
config.src_vocab_size = config.tgt_vocab_size = 50265
embeds_kwargs = {
    "num_embeddings": config.src_vocab_size,
    "embedding_dim": config.d_model,
    "padding_idx": config.pad_token_id,
    "max_position_embeddings": config.max_position_embeddings,
}
bart_embeds = BartEmbeds(**embeds_kwargs)

In [13]:
# Smoke-test BartEmbeds: embed a (2, 4) batch of random token ids drawn from
# [0, src_vocab_size) and show the resulting shape and values.
input_ids = torch.randint(low=0, high=config.src_vocab_size, size=(2, 4))
output = bart_embeds(input_ids)
print(output.shape, output, sep="\n")

torch.Size([2, 4, 16])
tensor([[[ 7.2077e-01,  6.3350e-01,  7.1148e-01,  8.9647e-01,  5.6542e-01,
          -1.2455e-01,  8.2862e-01,  7.6186e-01, -3.2797e-01,  1.2147e+00,
          -2.6743e-01,  2.0352e+00, -1.4536e+00, -1.1206e-01,  6.4773e-01,
           4.2575e-01],
         [-1.8001e-01, -2.1103e-02,  4.6368e-01, -9.8570e-01, -1.1246e+00,
           1.6210e-01, -9.5679e-01,  9.2655e-01, -3.9981e-01,  2.6115e+00,
           3.7200e-03, -1.5708e+00,  2.0864e+00, -1.1988e+00,  3.4930e-01,
           1.6203e+00],
         [ 6.6404e-01, -9.2307e-01,  4.5185e-02,  8.1432e-01, -1.8589e-01,
          -1.7871e+00,  9.0726e-01,  1.1230e+00, -1.1461e+00,  7.4479e-01,
          -4.6692e-01, -1.8897e-01,  2.9519e+00, -2.6009e+00,  3.2173e+00,
           1.7134e+00],
         [-2.5679e-01, -2.2951e+00,  7.9971e-01,  5.9659e-01,  2.4703e+00,
          -6.5554e-01,  1.4443e+00, -1.2415e+00,  2.0196e-01,  1.0908e+00,
           7.0931e-01,  1.2554e+00,  1.3962e-01,  1.1590e-01,  1.6441e+00,
     

# utils.mask.create_encoder_atn_mask

In [14]:
from bart_model_from_scratch.utils.mask import (
    create_encoder_atn_mask,
)

In [15]:
# test create_encoder_atn_mask: draw a (5, 4) batch of ids in [0, 10) and store
# it as float32, because the mask-building cell below reuses input_ids.dtype.
input_ids = torch.randint(low=0, high=10, size=(5, 4)).float()
input_ids

tensor([[0., 1., 8., 9.],
        [8., 6., 4., 8.],
        [4., 4., 8., 5.],
        [2., 9., 2., 8.],
        [5., 3., 5., 3.]])

In [16]:
# Padding mask as int64: 1 where the id differs from config.pad_token_id,
# 0 where it equals it.
attention_mask = input_ids.ne(config.pad_token_id).long()
attention_mask

tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [0, 1, 0, 1],
        [1, 1, 1, 1]])

In [17]:
# Expand the (5, 4) padding mask into the mask printed by the next cell:
# shape (5, 1, 4, 4), 0 at kept positions and the float32 minimum at padded ones.
# The dtype is taken from input_ids, which was cast to float32 when it was created.
encoder_attention_mask = create_encoder_atn_mask(
    attention_mask=attention_mask,
    dtype=input_ids.dtype,
)

In [18]:
# Show the expanded mask's shape first, then its values.
print(encoder_attention_mask.shape, encoder_attention_mask, sep="\n")

torch.Size([5, 1, 4, 4])
tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]],


        [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]],


        [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]],


        [[[-3.4028e+38,  0.0000e+00, -3.4028e+38,  0.0000e+00],
          [-3.4028e+38,  0.0000e+00, -3.4028e+38,  0.0000e+00],
          [-3.4028e+38,  0.0000e+00, -3.4028e+38,  0.0000e+00],
   

# BartEncoder

In [19]:
from bart_model_from_scratch.encoder import BartEncoder

In [20]:
# Instantiate the encoder from the shared config.
bart_encoder = BartEncoder(config)

In [21]:
# test bart_encoder: run a random (2, 4, d_model) batch of embeddings through
# the encoder together with a hand-written padding mask.
input_embeds = torch.randn(2, 4, config.d_model)
# The second sequence has its final position marked as padding (0).
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 1, 0],
    ]
)
# The encoder is called with the raw 2-D padding mask; the pre-expanded mask
# this cell used to build was never passed to it, so that dead computation
# (and the commented-out debug prints) has been removed.
output = bart_encoder(input_embeds, attention_mask)
print(output)

tensor([[[-0.2440, -0.7083, -1.2406, -0.4357,  0.6978,  0.2436,  0.4691,
           2.5133,  0.8467,  0.8194, -0.9147,  0.6732, -0.1995,  0.2042,
          -1.0833, -1.6412],
         [ 0.4695,  1.0606, -1.3081,  0.0530,  0.1950, -0.5963,  2.1741,
          -1.2432, -1.3733,  0.6151,  0.1743,  0.9148, -1.0729,  0.5024,
          -1.1713,  0.6063],
         [ 1.8137,  1.6354, -1.7559, -0.3323, -0.1178,  0.0932, -0.8758,
           0.3401, -1.4479, -1.1286,  1.5104, -0.3166,  0.1333, -0.1088,
           0.2897,  0.2678],
         [-1.4007, -0.4752, -1.5265,  0.6131, -0.1875, -0.9110, -0.0320,
           0.4401,  0.3301,  1.5506,  0.8130, -0.0623, -0.8966,  1.1299,
          -1.2086,  1.8238]],

        [[-0.0756,  0.0778,  1.4023,  0.4031, -0.5402,  0.2748, -1.7520,
          -0.5551,  0.7526, -0.5615,  2.5652, -1.3054, -0.5864, -0.3779,
           0.5988, -0.3205],
         [-0.0938,  0.2875,  1.8550,  0.4415, -0.8328,  0.3505, -2.1957,
          -0.9362,  1.2270,  0.1041,  0.3740, -1.4

# utils.mask.create_decoder_atn_mask

In [22]:
from bart_model_from_scratch.utils.mask import (
    create_decoder_atn_mask
)

In [23]:
# test create_decoder_atn_mask: two sequences of length 5 with 3 and 2 real
# tokens respectively; the displayed result blocks both future positions and
# padded positions.
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])
dtype = torch.float32
create_decoder_atn_mask(attention_mask=attention_mask, dtype=dtype)

tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38]]],


        [[[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38]]]])

# BartDecoder

In [24]:
from bart_model_from_scratch.decoder import BartDecoder

In [25]:
# Instantiate the decoder from the shared config; the bare name on the last
# line makes the notebook display the module tree.
bart_decoder = BartDecoder(config)
bart_decoder

BartDecoder(
  (dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x BartDecoderLayer(
      (self_attn): BartAttention(
        (k_proj): Linear(in_features=16, out_features=16, bias=True)
        (v_proj): Linear(in_features=16, out_features=16, bias=True)
        (q_proj): Linear(in_features=16, out_features=16, bias=True)
        (out_proj): Linear(in_features=16, out_features=16, bias=True)
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (activation_fn): GELU(approximate='none')
      (activation_dropout): Dropout(p=0.0, inplace=False)
      (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): BartAttention(
        (k_proj): Linear(in_features=16, out_features=16, bias=True)
        (v_proj): Linear(in_features=16, out_features=16, bias=True)
        (q_proj): Linear(in_features=16, out_features=16, bias=True)
        (out_proj): Linear(in_features=16, out_features=16, bias=True)
      )
      

In [26]:
# Smoke-test the decoder: random embeddings and random 0/1 masks for both the
# decoder side and the (simulated) encoder side, all with batch 2 and length 4.
# Tensors are drawn in the same order as before so RNG consumption is unchanged.
input_embeds = torch.randn((2, 4, config.d_model))
attention_mask = torch.randint(low=0, high=2, size=(2, 4))
encoder_hidden_states = torch.randn((2, 4, config.d_model))
encoder_attention_mask = torch.randint(low=0, high=2, size=(2, 4))
decoder_inputs = {
    "input_embeds": input_embeds,
    "attention_mask": attention_mask,
    "encoder_hidden_states": encoder_hidden_states,
    "encoder_attention_mask": encoder_attention_mask,
}
output = bart_decoder(**decoder_inputs)
print(output.shape)

torch.Size([2, 4, 16])


In [27]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq
import torch.nn as nn

In [28]:
# Build the full seq2seq model from the shared config; the bare name on the
# last line makes the notebook display the module tree.
model = BartSeq2seq(config)
model

BartSeq2seq(
  (inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16)
  )
  (decoder_inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16)
  )
  (encoder): BartEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x BartEncoderLayer(
        (self_attn): BartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation_fn): GELU(approximate='none')
        (activation_dropout): Dropout(p=0.0, inplace=False)
 

In [29]:
# Prepare inputs for the model test: random ids in [0, 10) for the source and
# target sides, each (2, 4), plus int64 masks that are 0 wherever the id equals
# config.pad_token_id. The four shapes are printed one per line.
input_ids = torch.randint(low=0, high=10, size=(2, 4))
attention_mask = input_ids.ne(config.pad_token_id).long()
decoder_input_ids = torch.randint(low=0, high=10, size=(2, 4))
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id).long()
print(
    input_ids.shape,
    attention_mask.shape,
    decoder_input_ids.shape,
    decoder_attention_mask.shape,
    sep="\n",
)

torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])


In [30]:
# Full forward pass through the seq2seq model using the ids and masks prepared
# in the previous cell; the printed tensor carries a grad_fn, i.e. the call ran
# with autograd enabled.
out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
print(out)

tensor([[[-0.0459, -0.0574,  0.0508,  ..., -0.0269, -0.1099, -0.1114],
         [-0.0359, -0.1029,  0.0813,  ..., -0.0072, -0.0943, -0.0748],
         [-0.0658, -0.0943,  0.0642,  ..., -0.0582, -0.0908, -0.1211],
         [-0.0529, -0.0689,  0.0693,  ...,  0.0240, -0.1061, -0.0819]],

        [[-0.0456,  0.0967,  0.0799,  ...,  0.1436, -0.0422, -0.0257],
         [-0.1097,  0.1285,  0.0124,  ...,  0.1208, -0.0540, -0.0003],
         [-0.0469,  0.0368,  0.0976,  ...,  0.2448, -0.0598,  0.1432],
         [-0.0197,  0.0679, -0.0665,  ..., -0.0286, -0.1861, -0.0136]]],
       grad_fn=<ViewBackward0>)


In [31]:
# Call the model's get_encoder_out with the source ids and mask, then print the
# last_hidden_state attribute of the returned object.
encoder_out = model.get_encoder_out(
    input_ids=input_ids,
    attention_mask=attention_mask,
)
print(encoder_out.last_hidden_state)

tensor([[[-0.8635,  2.3041, -1.7932,  1.3518, -0.0954, -0.2098, -0.1939,
          -0.0780, -0.4081, -0.5802, -0.1218,  1.4588, -0.9903, -0.5278,
           0.9309, -0.1837],
         [-0.8725,  1.8666, -1.6995,  1.1233, -0.2218, -0.3176, -0.2599,
          -0.2234, -0.4828, -0.6526, -0.6089,  1.1081, -1.0264, -0.2576,
           0.6417,  1.8834],
         [-0.8675,  1.8536, -1.7041,  1.1128, -0.2318, -0.1019,  0.1153,
          -0.6798, -0.5081, -0.6687, -0.5868,  1.1001, -0.9942, -0.3109,
           0.6320,  1.8399],
         [-0.9228,  1.8465, -1.7593,  1.1136, -0.2642, -0.1361,  0.1395,
          -0.7179, -0.5594, -0.7325, -0.6225,  1.1011, -0.3124, -0.6430,
           0.6140,  1.8555]],

        [[-0.8596,  0.8030,  1.0868, -0.1257,  1.9796, -0.4293,  1.0498,
          -0.8805, -0.4387,  0.1243,  0.8115,  0.3964, -1.6538, -0.8856,
          -1.5885,  0.6102],
         [-0.8797,  0.9185,  0.9745, -1.6029,  1.5181,  1.7618, -0.4743,
           1.3820,  0.0327, -0.7501, -0.9968,  0.3

In [32]:
# Call the model's get_decoder_out on the target side: its input_ids /
# attention_mask arguments receive the decoder ids and mask, while the encoder
# arguments reuse the hidden states from the previous cell and the source mask.
decoder_out = model.get_decoder_out(
    input_ids=decoder_input_ids,
    attention_mask=decoder_attention_mask,
    encoder_hidden_states=encoder_out.last_hidden_state,
    encoder_attention_mask=attention_mask
)
print(decoder_out.last_hidden_state)

tensor([[[-0.8187,  0.2092,  0.9143, -0.7915, -0.7106,  0.4006,  1.7924,
           0.1419, -0.1151, -1.1451, -0.4593,  0.9742,  0.0084, -0.5595,
          -1.8067,  1.9655],
         [-0.1004,  0.1759,  0.8809, -0.8474, -0.7765,  0.3759,  1.7623,
           0.0716, -0.1628, -1.2141, -0.5415,  0.9799, -0.0414, -0.6297,
          -1.8970,  1.9644],
         [-0.8146,  0.2435,  0.9425, -0.6830, -0.6681,  0.4360,  1.8092,
           0.1730, -0.6987, -1.0835, -0.4259,  0.9992,  0.0393, -0.5137,
          -1.7278,  1.9723],
         [-0.8668,  0.1463,  0.9320, -0.7835,  0.0392,  0.3955,  1.7319,
           0.1592, -0.7400, -1.1482, -0.4745,  0.9998,  0.0175, -0.5757,
          -1.8039,  1.9711]],

        [[ 0.2280,  0.3130, -0.1226,  0.3363, -0.8759, -0.3423, -1.0096,
           1.5788,  0.0254,  0.2069, -0.4179,  0.3443,  2.0786, -0.8594,
          -2.2957,  0.8122],
         [ 0.2113,  1.5571, -0.0256, -1.8141, -0.5693,  0.7893, -0.9420,
           0.9509, -0.0355,  0.4615, -0.5491,  0.3

# BartSeq2seq

In [33]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq

In [34]:
# Build a fresh BartSeq2seq from the shared config. This rebinds `model`,
# replacing the instance created earlier in the notebook.
model = BartSeq2seq(config)
model

BartSeq2seq(
  (inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16)
  )
  (decoder_inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50265, 16, padding_idx=2)
    (embed_positions): Embedding(1024, 16)
  )
  (encoder): BartEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x BartEncoderLayer(
        (self_attn): BartAttention(
          (k_proj): Linear(in_features=16, out_features=16, bias=True)
          (v_proj): Linear(in_features=16, out_features=16, bias=True)
          (q_proj): Linear(in_features=16, out_features=16, bias=True)
          (out_proj): Linear(in_features=16, out_features=16, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (activation_fn): GELU(approximate='none')
        (activation_dropout): Dropout(p=0.0, inplace=False)
 

In [35]:
# End-to-end test with different source and target lengths: source ids are
# (2, 5), target ids are (2, 4). Masks are int64 and 0 wherever the id equals
# config.pad_token_id. Input shapes are printed before the forward pass, and
# the shape of the returned logits afterwards.
input_ids = torch.randint(low=0, high=10, size=(2, 5))
attention_mask = input_ids.ne(config.pad_token_id).long()
decoder_input_ids = torch.randint(low=0, high=10, size=(2, 4))
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id).long()
print(
    input_ids.shape,
    attention_mask.shape,
    decoder_input_ids.shape,
    decoder_attention_mask.shape,
    sep="\n",
)
model_inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "decoder_input_ids": decoder_input_ids,
    "decoder_attention_mask": decoder_attention_mask,
}
logits = model(**model_inputs)
print(logits.shape)

torch.Size([2, 5])
torch.Size([2, 5])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4, 50265])
