# Imports

In [1]:
from bart_model_from_scratch.config import BartConfig
import torch
from bart_model_from_scratch.multihead_attn import BartAttention

# BartConfig

In [2]:
# Seed PyTorch's global RNG so the torch.randn / torch.randint inputs drawn in
# the cells below are repeatable after "Restart Kernel -> Run All".
torch.manual_seed(0)

# Start from the package defaults, then override the fields this notebook pins.
config = BartConfig()
config.pad_token_id = 2  # id that the attention-mask cells compare input ids against
# NOTE(review): presumably the probability of skipping a whole layer in the
# encoder / decoder stack — confirm against BartEncoder / BartDecoder.
config.encoder_layerdrop = 0.1
config.decoder_layerdrop = 0.1

# BartAttention

In [3]:
# Build a standalone attention module from the encoder-side settings in
# `config`, then print its submodules.
attn_kwargs = {
    "embed_dim": config.d_model,
    "num_heads": config.encoder_attention_heads,
    "dropout": config.attention_dropout,
}
bart_attn = BartAttention(**attn_kwargs)
print(bart_attn)

BartAttention(
  (k_proj): Linear(in_features=768, out_features=768, bias=True)
  (v_proj): Linear(in_features=768, out_features=768, bias=True)
  (q_proj): Linear(in_features=768, out_features=768, bias=True)
  (out_proj): Linear(in_features=768, out_features=768, bias=True)
)


In [4]:
# Smoke-test bart_attn: feed a random (2, 4, d_model) batch, passing only the
# hidden states.
hidden_states = torch.randn(2, 4, config.d_model)
output = bart_attn(hidden_states)
# The module is expected to return a tensor shaped like its input; fail loudly
# instead of relying on someone eyeballing the printed shape.
assert output.shape == hidden_states.shape, f"unexpected shape {tuple(output.shape)}"
print(output.shape)
print(output)

torch.Size([2, 4, 768])
tensor([[[ 0.0751, -0.3702,  0.2998,  ..., -0.1878, -0.5207,  0.0063],
         [ 0.2250, -0.2079, -0.2231,  ...,  0.0689,  0.3023,  0.1949],
         [ 0.1175,  0.0272,  0.3938,  ...,  0.0080,  0.3699,  0.3703],
         [-0.1657,  0.2114, -0.0247,  ..., -0.2443,  0.0351,  0.2090]],

        [[ 0.2112,  0.4091,  0.1846,  ...,  0.0738,  0.0352,  0.1629],
         [ 0.0873, -0.1674,  0.2318,  ...,  0.0596, -0.1273, -0.0782],
         [-0.1921,  0.0960,  0.0876,  ..., -0.1347,  0.0837, -0.2731],
         [ 0.4105,  0.0838,  0.0536,  ..., -0.0928,  0.0835, -0.1846]]],
       grad_fn=<ViewBackward0>)


# BartEncoderLayer

In [5]:
from bart_model_from_scratch.encoder_layer import BartEncoderLayer

In [6]:
# Instantiate a single encoder layer from `config`. The bare name on the last
# line makes the notebook render the module tree as the cell output.
bart_encoder_layer = BartEncoderLayer(config)
bart_encoder_layer

BartEncoderLayer(
  (self_attn): BartAttention(
    (k_proj): Linear(in_features=768, out_features=768, bias=True)
    (v_proj): Linear(in_features=768, out_features=768, bias=True)
    (q_proj): Linear(in_features=768, out_features=768, bias=True)
    (out_proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.01, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.1, inplace=False)
  (fc1): Linear(in_features=768, out_features=3072, bias=True)
  (fc2): Linear(in_features=3072, out_features=768, bias=True)
  (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [7]:
# Smoke-test bart_encoder_layer on a random float32 (2, 4, d_model) batch,
# passing only the hidden states.
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
output = bart_encoder_layer(hidden_states)
# An encoder layer must hand back a tensor shaped like its input so layers can
# be stacked; check it rather than only printing it.
assert output.shape == hidden_states.shape, f"unexpected shape {tuple(output.shape)}"
print(output.shape)
print(output)

torch.Size([2, 4, 768])
torch.Size([2, 4, 768])
tensor([[[-2.0080,  1.1065,  2.0326,  ...,  0.3838, -1.6869, -1.0192],
         [ 0.9626, -1.7576, -0.0230,  ..., -0.4659,  0.6886,  0.0971],
         [ 1.0780, -1.0045,  1.2052,  ..., -0.5846,  0.3765, -1.0169],
         [ 0.5955, -0.7931,  0.1466,  ..., -1.2054,  0.2591,  0.9101]],

        [[ 0.9044, -0.2739,  0.7580,  ...,  0.0748,  1.4707,  2.2459],
         [-2.4251,  0.4028, -0.6919,  ...,  2.8643,  0.7489,  1.0420],
         [-0.5499, -1.1846, -0.4563,  ..., -0.2049,  1.3025, -2.5611],
         [ 1.3633, -1.0645, -1.5674,  ...,  1.4318, -0.4118,  0.7800]]],
       grad_fn=<NativeLayerNormBackward0>)


# BartDecoderLayer

In [8]:
from bart_model_from_scratch.decoder_layer import BartDecoderLayer

In [9]:
# Instantiate a single decoder layer from `config`. The bare name on the last
# line makes the notebook render the module tree as the cell output.
bart_decoder_layer = BartDecoderLayer(config)
bart_decoder_layer

BartDecoderLayer(
  (self_attn): BartAttention(
    (k_proj): Linear(in_features=768, out_features=768, bias=True)
    (v_proj): Linear(in_features=768, out_features=768, bias=True)
    (q_proj): Linear(in_features=768, out_features=768, bias=True)
    (out_proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (dropout): Dropout(p=0.01, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.1, inplace=False)
  (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (encoder_attn): BartAttention(
    (k_proj): Linear(in_features=768, out_features=768, bias=True)
    (v_proj): Linear(in_features=768, out_features=768, bias=True)
    (q_proj): Linear(in_features=768, out_features=768, bias=True)
    (out_proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (fc1): Linear(in_features=768, out_features=3072, bias=True)
  (

In [10]:
# Smoke-test bart_decoder_layer: random decoder-side hidden states plus random
# encoder states to cross-attend to, both float32 of shape (2, 4, d_model).
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
encoder_hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
print(encoder_hidden_states.shape)
output = bart_decoder_layer(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
)
# The result must keep the decoder-side shape; check it rather than only
# printing it.
assert output.shape == hidden_states.shape, f"unexpected shape {tuple(output.shape)}"
print(output.shape)
print(output)

torch.Size([2, 4, 768])
torch.Size([2, 4, 768])
torch.Size([2, 4, 768])
tensor([[[ 0.5057,  0.2828, -0.1360,  ..., -0.4359, -1.6799,  1.4759],
         [ 0.7721,  1.4408,  0.4127,  ...,  0.5775,  0.4154, -0.1217],
         [-0.2220,  0.8638,  0.6709,  ...,  0.2976,  1.2378, -1.2565],
         [-0.0255,  0.4025,  0.5806,  ...,  0.4377,  0.9355,  0.3866]],

        [[-0.8634,  1.8840, -1.8397,  ...,  0.3422, -1.0796,  1.0176],
         [ 1.8969,  0.2343,  0.3589,  ...,  0.4187,  0.6501,  1.1825],
         [ 0.3066,  1.9093,  0.2845,  ..., -0.4692, -0.0525,  1.1247],
         [ 0.7325, -0.1940,  0.2714,  ...,  2.1922,  1.6120, -1.2322]]],
       grad_fn=<NativeLayerNormBackward0>)


# BartEmbeds

In [11]:
from bart_model_from_scratch.embeds import BartEmbeds

In [12]:
# Embedding module sized from `config`: source vocabulary size, model width,
# padding id and maximum number of positions.
bart_embeds = BartEmbeds(
    embedding_dim=config.d_model,
    num_embeddings=config.src_vocab_size,
    max_position_embeddings=config.max_position_embeddings,
    padding_idx=config.pad_token_id,
)

In [13]:
# Smoke-test BartEmbeds on a random (2, 4) batch of ids drawn from the source
# vocabulary.
input_ids = torch.randint(0, config.src_vocab_size, (2, 4))
output = bart_embeds(input_ids)
# Each id must map to one d_model-wide vector: (batch, seq) -> (batch, seq, d_model).
assert output.shape == (*input_ids.shape, config.d_model), (
    f"unexpected shape {tuple(output.shape)}"
)
print(output.shape)
print(output)

torch.Size([2, 4, 768])
tensor([[[ 0.3382,  0.2444, -0.9735,  ..., -0.4801,  2.0161,  0.2084],
         [-0.0909, -1.8923, -2.6886,  ...,  0.2949, -0.3469,  1.6286],
         [ 0.6539,  1.1789,  1.4475,  ..., -1.3194,  0.6760,  0.3774],
         [-1.9713, -0.2679,  0.3852,  ...,  1.5159, -2.8070,  1.5362]],

        [[ 0.2817,  0.7297,  0.1119,  ...,  0.4697, -0.7054, -0.6662],
         [-0.9036, -1.1799, -2.7916,  ...,  0.6680, -1.5610,  1.7562],
         [-0.2986,  1.2720, -0.5686,  ..., -1.9106, -1.0335,  1.5932],
         [-1.8067,  0.1130, -0.5724,  ...,  1.5661, -2.8715,  0.5432]]],
       grad_fn=<AddBackward0>)


# utils.mask.create_encoder_atn_mask

In [14]:
from bart_model_from_scratch.utils.mask import (
    create_encoder_atn_mask,
)

In [15]:
# Input for testing create_encoder_atn_mask: random values in [0, 10) standing
# in for token ids, shape (5, 4), cast to float32.
input_ids = torch.randint(0, 10, (5, 4)).float()
input_ids

tensor([[1., 2., 6., 5.],
        [0., 3., 3., 7.],
        [4., 7., 0., 2.],
        [0., 8., 7., 6.],
        [3., 3., 4., 7.]])

In [16]:
# int64 mask: 1 where the id differs from config.pad_token_id, 0 where it
# equals it.
attention_mask = input_ids.ne(config.pad_token_id).long()
attention_mask

tensor([[1, 0, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 0],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])

In [17]:
# Expand the 0/1 mask with create_encoder_atn_mask. The mask dtype is stated
# explicitly (float32) instead of being borrowed from `input_ids`, so this cell
# does not depend on the ids having been cast to a floating-point type.
encoder_attention_mask = create_encoder_atn_mask(
    attention_mask=attention_mask,
    dtype=torch.float32,
)

In [18]:
# Show the expanded mask's shape, then its values, one per line.
print(encoder_attention_mask.shape, encoder_attention_mask, sep="\n")

torch.Size([5, 1, 4, 4])
tensor([[[[ 0.0000e+00, -3.4028e+38,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00, -3.4028e+38,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00, -3.4028e+38,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00, -3.4028e+38,  0.0000e+00,  0.0000e+00]]],


        [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]],


        [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38]]],


        [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
   

# BartEncoder

In [19]:
from bart_model_from_scratch.encoder import BartEncoder

In [20]:
# Instantiate the full encoder stack from `config`.
bart_encoder = BartEncoder(config)

In [21]:
# Smoke-test bart_encoder on random embeddings of shape (2, 4, d_model).
input_embeds = torch.randn(2, 4, config.d_model)
# Fixed 0/1 mask: first sequence fully visible, second with its last position
# masked out.
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 1, 0],
    ]
)
# The raw 0/1 mask is passed straight to the encoder.
# NOTE(review): presumably BartEncoder expands it internally — confirm.
output = bart_encoder(input_embeds, attention_mask)
# The encoder must return one d_model-wide vector per input position.
assert output.shape == input_embeds.shape, f"unexpected shape {tuple(output.shape)}"
print(output)

tensor([[[ 0.4104,  1.4796, -0.5556,  ..., -0.3095,  0.2145,  0.0950],
         [-0.9712,  0.4293, -0.4621,  ..., -1.3623, -0.0304, -0.1284],
         [-1.3096, -0.6222, -1.3053,  ..., -0.0676,  0.3396, -0.4068],
         [-1.3433,  1.6227, -2.2673,  ..., -0.9275, -0.0360, -0.4252]],

        [[-0.7063, -1.0918, -0.6807,  ...,  1.0402,  0.2754, -1.9718],
         [-0.6258, -0.2518, -1.0423,  ...,  0.8777,  0.3191, -1.4948],
         [-0.6714, -0.2738, -1.4537,  ...,  0.7464,  0.4274, -1.6436],
         [-0.5447, -0.3065, -0.9770,  ...,  0.6532,  0.4955, -1.3913]]],
       grad_fn=<NativeLayerNormBackward0>)


# utils.mask.create_decoder_atn_mask

In [22]:
from bart_model_from_scratch.utils.mask import (
    create_decoder_atn_mask
)

In [23]:
# Test create_decoder_atn_mask on two right-padded sequences of length 5
# (3 and 2 unmasked positions respectively).
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0]
])
dtype = torch.float32
decoder_mask = create_decoder_atn_mask(
    attention_mask=attention_mask,
    dtype=dtype,
)

# One (1, 5, 5) additive mask per sequence.
assert decoder_mask.shape == (2, 1, 5, 5), f"unexpected shape {tuple(decoder_mask.shape)}"
# Causal part: every entry strictly above the diagonal is blocked (negative).
future_positions = torch.ones(5, 5, dtype=torch.bool).triu(diagonal=1)
assert bool((decoder_mask.masked_select(future_positions) < 0).all())
# Padding part: masked-out key columns are blocked for every query row.
assert bool((decoder_mask[0, ..., 3:] < 0).all())
assert bool((decoder_mask[1, ..., 2:] < 0).all())
# The first position can always attend to itself (additive value 0).
assert bool((decoder_mask[..., 0, 0] == 0).all())

decoder_mask

tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38]]],


        [[[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38]]]])

# BartDecoder

In [24]:
from bart_model_from_scratch.decoder import BartDecoder

In [25]:
# Instantiate the full decoder stack from `config`. The bare name on the last
# line makes the notebook render the module tree as the cell output.
bart_decoder = BartDecoder(config)
bart_decoder

BartDecoder(
  (dropout): Dropout(p=0.01, inplace=False)
  (layers): ModuleList(
    (0-11): 12 x BartDecoderLayer(
      (self_attn): BartAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (dropout): Dropout(p=0.01, inplace=False)
      (activation_fn): GELU(approximate='none')
      (activation_dropout): Dropout(p=0.1, inplace=False)
      (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): BartAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=T

In [26]:
# Smoke-test bart_decoder on random decoder embeddings and random encoder
# states, both of shape (2, 4, d_model).
input_embeds = torch.randn(2, 4, config.d_model)
encoder_hidden_states = torch.randn(2, 4, config.d_model)
# Explicit right-padded 0/1 masks instead of random ones: a random mask can
# come out all-zero for a row or with holes in the middle, so the test would
# exercise a different situation on every run.
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 1, 0],
    ]
)
encoder_attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 0, 0],
    ]
)
output = bart_decoder(
    input_embeds=input_embeds,
    attention_mask=attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
)
# The decoder must return one d_model-wide vector per decoder position.
assert output.shape == input_embeds.shape, f"unexpected shape {tuple(output.shape)}"
print(output.shape)

torch.Size([2, 4, 768])


In [27]:
from bart_model_from_scratch.model_seq2seq import BartModelSeq2seq
import torch.nn as nn

In [28]:
# Instantiate the full seq2seq model from `config`. The bare name on the last
# line makes the notebook render the module tree as the cell output.
model = BartModelSeq2seq(config)
model

BartModelSeq2seq(
  (inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50264, 768, padding_idx=2)
    (embed_positions): Embedding(1024, 768)
  )
  (decoder_inputs_embeds): BartEmbeds(
    (embed_tokens): Embedding(50264, 768, padding_idx=2)
    (embed_positions): Embedding(1024, 768)
  )
  (encoder): BartEncoder(
    (dropout): Dropout(p=0.01, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x BartEncoderLayer(
        (self_attn): BartAttention(
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.01, inplace=False)
        (activation_fn): GELU(approximate='none')
        (activation_dropout): Dropout(p=0

In [29]:
# Inputs for the full model test: random ids in [0, 10) for the encoder and
# decoder sides, each with an int64 mask that is 1 wherever the id is not
# config.pad_token_id.
input_ids = torch.randint(0, 10, (2, 4))
attention_mask = input_ids.ne(config.pad_token_id).long()
decoder_input_ids = torch.randint(0, 10, (2, 4))
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id).long()
print(
    input_ids.shape,
    attention_mask.shape,
    decoder_input_ids.shape,
    decoder_attention_mask.shape,
    sep="\n",
)

torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])
torch.Size([2, 4])


In [30]:
# Full forward pass: encoder ids and mask plus decoder ids and mask, passed by
# keyword.
seq2seq_inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "decoder_input_ids": decoder_input_ids,
    "decoder_attention_mask": decoder_attention_mask,
}
out = model(**seq2seq_inputs)
print(out)

tensor([[[ 0.9165, -0.2147,  0.0139,  ...,  0.3357, -0.0208,  0.5799],
         [ 0.6535, -0.3647, -0.3704,  ...,  0.5537,  0.0829,  0.6227],
         [ 1.0513, -0.1627, -0.5134,  ...,  0.3676,  0.4069,  0.5489],
         [-0.3029, -0.1783,  0.5007,  ..., -1.0684, -0.2373, -0.3889]],

        [[ 0.2842, -0.0104, -0.8313,  ..., -0.2996, -0.2948, -0.0245],
         [ 0.4382, -0.1661, -0.8213,  ..., -0.0694, -0.3858, -0.1230],
         [ 0.0948, -0.0241, -0.7646,  ..., -0.1073, -0.4828, -0.5212],
         [ 0.1070,  0.1674, -0.4750,  ..., -0.5473, -0.6266, -0.3140]]],
       grad_fn=<ViewBackward0>)


In [31]:
# Run only the encoder side of the model and print its last hidden state.
encoder_out = model.get_encoder_out(input_ids=input_ids, attention_mask=attention_mask)
encoder_last_hidden = encoder_out.last_hidden_state
print(encoder_last_hidden)

tensor([[[ 9.1211e-01, -1.1763e+00, -2.7104e-01,  ...,  4.1343e-01,
           6.6243e-01, -2.2297e+00],
         [ 1.8793e+00, -9.8599e-01,  7.4036e-01,  ..., -1.9584e+00,
           2.4847e-01, -1.2981e+00],
         [ 8.5872e-01, -2.2490e-01,  3.4009e-02,  ..., -6.5443e-01,
           8.5175e-01, -1.5463e+00],
         [ 2.0215e-03, -3.3893e-01,  1.6216e+00,  ..., -9.7414e-01,
           1.7708e+00,  5.4375e-01]],

        [[-7.9991e-01,  1.0269e-01,  4.2689e-01,  ..., -9.1324e-01,
          -9.2716e-01, -1.7147e+00],
         [ 1.2477e+00, -9.4545e-01, -8.8680e-01,  ..., -1.1067e+00,
           4.3609e-01,  1.7351e-01],
         [ 1.8777e+00, -8.6944e-01,  1.0420e+00,  ..., -2.0947e-01,
           1.6511e+00, -1.6411e+00],
         [ 2.7004e-01, -5.9743e-01, -3.9792e-01,  ..., -1.5319e+00,
           8.9601e-01, -8.1273e-01]]], grad_fn=<NativeLayerNormBackward0>)


In [32]:
# Run only the decoder side, feeding it the encoder's last hidden state and the
# encoder-side attention mask, then print the decoder's last hidden state.
decoder_inputs = {
    "input_ids": decoder_input_ids,
    "attention_mask": decoder_attention_mask,
    "encoder_hidden_states": encoder_out.last_hidden_state,
    "encoder_attention_mask": attention_mask,
}
decoder_out = model.get_decoder_out(**decoder_inputs)
print(decoder_out.last_hidden_state)

tensor([[[ 2.3896, -1.7146,  0.8089,  ...,  1.0119, -0.5203,  0.3437],
         [ 1.5660, -1.4608,  0.4715,  ...,  0.7448, -0.8916,  0.6079],
         [ 0.2441, -0.4802,  0.6980,  ...,  1.2161, -0.1608,  1.0399],
         [ 0.2197,  0.0530,  0.3321,  ..., -0.2195, -0.5395, -1.0574]],

        [[ 0.9433, -1.1686, -0.9944,  ..., -1.1126, -0.7391,  2.4272],
         [ 0.3525, -1.6448, -0.8937,  ..., -1.8769, -0.4092,  2.6488],
         [ 0.3973, -1.2249, -0.7741,  ..., -1.2191, -1.1529,  2.5299],
         [ 0.0649, -0.9688, -0.4909,  ..., -1.0646, -0.6579,  1.9590]]],
       grad_fn=<NativeLayerNormBackward0>)
