# import

In [1]:
import torch
from bart_model_from_scratch.multihead_attn import BartAttention
from transformers import BartConfig

  from .autonotebook import tqdm as notebook_tqdm


# BartConfig

In [2]:
config = BartConfig()
config.pad_token_id = 2
config.encoder_layerdrop = 0.1
config.decoder_layerdrop = 0.1
config.d_model = config.encoder_attention_heads

# BartAttention

In [3]:
bart_attn = BartAttention(
    embed_dim=config.d_model,
    num_heads=config.encoder_attention_heads,
    dropout=config.attention_dropout,
)
print(bart_attn)

BartAttention(
  (k_proj): Linear(in_features=16, out_features=16, bias=True)
  (v_proj): Linear(in_features=16, out_features=16, bias=True)
  (q_proj): Linear(in_features=16, out_features=16, bias=True)
  (out_proj): Linear(in_features=16, out_features=16, bias=True)
)


In [4]:
# test bart_attn
hidden_states = torch.randn(2, 4, config.d_model)
output = bart_attn(hidden_states)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
tensor([[[ 0.2308, -0.0155,  0.2638, -0.1180,  0.3354, -0.0364, -0.0682,
           0.3735, -0.1165,  0.0083, -0.1876,  0.1315,  0.1648, -0.0027,
          -0.0652, -0.2345],
         [ 0.3073,  0.2864, -0.0532, -0.0928,  0.1919, -0.0834, -0.1109,
           0.0523, -0.1892, -0.0031, -0.0571,  0.3419,  0.0720, -0.2508,
           0.0377,  0.0878],
         [ 0.1847,  0.1438,  0.0450, -0.0207,  0.2000, -0.1546, -0.1281,
           0.0928, -0.1076, -0.0481,  0.0883,  0.1438,  0.1582, -0.1793,
           0.0156, -0.0324],
         [ 0.4912, -0.0124,  0.2691, -0.1006,  0.2902,  0.3086,  0.0203,
           0.4224, -0.2376,  0.2309, -0.3679,  0.0156,  0.0165,  0.0038,
          -0.1013, -0.5336]],

        [[ 0.4683, -0.3188,  0.2930, -0.0008,  0.2028,  0.0450, -0.1408,
          -0.5044,  0.0619,  0.1134, -0.3225,  0.2031,  0.2396,  0.3058,
          -0.3018, -0.0052],
         [ 0.0198, -0.1136, -0.0749, -0.2530,  0.1354, -0.0514, -0.3659,
          -0.0349,  0.4556,

# BartEncoderLayer

In [5]:
from bart_model_from_scratch.encoder_layer import BartEncoderLayer

In [6]:
bart_encoder_layer = BartEncoderLayer(config)
bart_encoder_layer

BartEncoderLayer(
  (self_attn): BartAttention(
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (fc1): Linear(in_features=16, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=16, bias=True)
  (final_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)

In [7]:
# test bart_encoder_layer
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
output = bart_encoder_layer(hidden_states)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[-1.0815,  1.4600,  0.9880, -0.0311, -1.0488, -1.8177,  0.9473,
          -0.5387,  0.0030,  0.6707,  0.8261,  0.7633,  1.1093, -0.9392,
          -1.5049,  0.1941],
         [ 0.4348,  0.5269, -0.2812,  0.6041,  1.5712,  0.5017, -0.6394,
           0.8311,  1.1606, -1.0152, -1.0899, -0.7408, -1.1341, -1.9703,
          -0.1601,  1.4005],
         [-1.6643, -1.5625,  0.8132,  0.6971, -1.1672,  0.1583, -1.2897,
           0.7968, -0.4668,  0.3510,  1.3909,  1.4903,  1.0344, -0.0458,
          -0.6166,  0.0809],
         [ 0.8642,  0.3762,  0.4719, -2.0107,  1.0089, -0.5401, -0.4147,
           1.6356, -1.1949,  1.2460, -1.1774, -0.0626,  0.5106, -1.3265,
           0.5415,  0.0719]],

        [[ 0.3826, -1.4632,  0.5039,  0.1636,  1.9924,  0.5149,  1.0296,
          -1.0267, -0.5357, -0.2586, -1.5328,  0.1074, -0.6634,  0.7827,
           1.2892, -1.2858],
         [ 1.5613,  1.8288,  0.4619,  0.0991,  1.5021, -0.6401,  0.8571,
    

# BartDecoderLayer

In [8]:
from bart_model_from_scratch.decoder_layer import BartDecoderLayer

In [9]:
bart_decoder_layer = BartDecoderLayer(config)
bart_decoder_layer

BartDecoderLayer(
  (self_attn): BartAttention(
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.0, inplace=False)
  (self_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (encoder_attn): BartAttention(
    (k_proj): Linear(in_features=16, out_features=16, bias=True)
    (v_proj): Linear(in_features=16, out_features=16, bias=True)
    (q_proj): Linear(in_features=16, out_features=16, bias=True)
    (out_proj): Linear(in_features=16, out_features=16, bias=True)
  )
  (encoder_attn_layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (fc1): Linear(in_features=16, out_features=4096, bias=True)
  (fc2): Linear(in_feat

In [10]:
# test bart_decoder_layer
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
encoder_hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
print(encoder_hidden_states.shape)
output = bart_decoder_layer(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
tensor([[[-7.5790e-02, -1.1394e+00,  4.7269e-01, -1.2430e+00,  1.3378e+00,
          -1.1735e+00,  1.3739e+00, -8.2089e-01,  1.3069e+00, -4.6396e-01,
          -8.7848e-01,  6.6484e-02, -1.3064e+00,  3.0947e-01,  1.4343e+00,
           7.9987e-01],
         [-1.9415e+00, -5.7180e-01, -3.8017e-02, -1.2561e-01, -4.0291e-01,
           1.0254e-01,  1.9303e+00,  1.3619e+00,  1.2873e+00, -7.2652e-01,
           4.1702e-01, -1.3988e+00, -1.0389e+00,  6.3991e-01,  5.6920e-01,
          -6.4150e-02],
         [ 1.6729e+00,  1.5638e+00,  1.3269e-01, -1.3936e-03, -1.0760e+00,
          -1.7465e+00,  2.1962e-01,  4.6269e-01, -5.7551e-01, -5.3958e-01,
          -4.6256e-01,  1.4280e+00,  8.6421e-01,  1.9205e-01, -6.6056e-01,
          -1.4739e+00],
         [-7.3808e-01, -8.1106e-01, -9.8833e-01, -4.1538e-01,  9.2842e-01,
           2.2284e+00,  3.2173e-01,  1.0823e+00,  1.1188e+00, -9.5381e-01,
           2.1273e-01, -7.9885e-01

# BartEmbeds

In [11]:
from bart_model_from_scratch.embeds import BartEmbeds

In [12]:
config.src_vocab_size = 50265
config.tgt_vocab_size = 50265
bart_embeds = BartEmbeds(
    num_embeddings=config.src_vocab_size,
    embedding_dim=config.d_model,
    padding_idx=config.pad_token_id,
    max_position_embeddings=config.max_position_embeddings,
)

In [13]:
# test BartEmbeds
input_ids = torch.randint(0, config.src_vocab_size, (2, 4))
output = bart_embeds(input_ids)
print(output.shape)
print(output)

torch.Size([2, 4, 16])
tensor([[[-0.9708,  0.7682, -1.5229,  0.2895, -1.4604,  2.7455,  0.5427,
          -2.5780,  0.4655,  1.6350, -1.0364,  2.2142, -1.3534,  1.4662,
          -2.0265, -1.0402],
         [-0.9817,  1.2918, -2.0837,  0.3518, -0.0745,  0.7030, -1.6515,
          -1.7057,  0.1480, -2.9250,  0.2276, -1.4785,  1.3636, -0.6095,
           0.0053,  1.0012],
         [-1.9733,  1.3933,  0.3711, -2.4373, -1.0679,  1.3285,  0.2691,
           0.7369, -0.0702, -0.4021,  0.6217,  1.3696, -0.7926, -0.6448,
          -0.0437, -1.8793],
         [-1.6146, -0.2036, -1.6556, -0.2937,  0.7043,  2.1998, -0.8551,
          -3.7813, -3.0618,  0.3897, -1.8210,  0.7424, -0.7821,  2.9827,
           0.6485,  0.6580]],

        [[-2.1263, -0.9626, -2.9456, -0.2003, -2.9131, -0.4630,  0.3145,
           0.3150,  0.0372,  1.7001, -1.1304,  0.6329, -0.2938,  1.1666,
          -1.9255, -3.2452],
         [-0.5003, -0.3817, -2.0883,  0.1109,  0.1218,  0.4592,  1.4281,
          -0.3805, -0.9034,

# utils.mask.create_encoder_mask

In [14]:
from bart_model_from_scratch.utils.mask import (
    create_encoder_atn_mask,
)

In [15]:
# test create_encoder_mask
input_ids = torch.randint(0, 10, (5, 4)).to(torch.float32)
input_ids

tensor([[3., 6., 5., 0.],
        [4., 2., 3., 8.],
        [1., 2., 4., 6.],
        [1., 3., 0., 2.],
        [2., 1., 0., 6.]])

In [16]:
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
attention_mask

tensor([[1, 1, 1, 1],
        [1, 0, 1, 1],
        [1, 0, 1, 1],
        [1, 1, 1, 0],
        [0, 1, 1, 1]])

In [17]:
encoder_attention_mask = create_encoder_atn_mask(
    attention_mask=attention_mask,
)

In [18]:
print(encoder_attention_mask.shape)
print(encoder_attention_mask)

torch.Size([5, 1, 1, 4])
tensor([[[[1, 1, 1, 1]]],


        [[[1, 0, 1, 1]]],


        [[[1, 0, 1, 1]]],


        [[[1, 1, 1, 0]]],


        [[[0, 1, 1, 1]]]])


# BartEncoder

In [19]:
from bart_model_from_scratch.encoder import BartEncoder

In [20]:
bart_encoder = BartEncoder(config)

In [21]:
# test bart_encoder
input_embeds = torch.randn(2, 4, config.d_model)
# attention_mask = torch.randint(0, 2, (2, 4))
attention_mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 1, 0],
    ]
)
encoder_mask = create_encoder_atn_mask(
    attention_mask=attention_mask,
)
# print(f"{encoder_mask=}")
# print(f"{input_embeds.shape=}, {attention_mask.shape=}")
# print(f"{input_embeds=}, {attention_mask=}")
output = bart_encoder(input_embeds, attention_mask)
# print(output.shape)
print(output)

 attention_mask = tensor([[[[1, 1, 1, 1]]],


        [[[1, 1, 1, 0]]]])
 attention_mask.shape = torch.Size([2, 1, 1, 4])
 attn_weights.shape = torch.Size([2, 16, 4, 4])
 attention_mask = tensor([[[[1, 1, 1, 1]]],


        [[[1, 1, 1, 0]]]])
 attention_mask.shape = torch.Size([2, 1, 1, 4])
 attn_weights.shape = torch.Size([2, 16, 4, 4])
 attention_mask = tensor([[[[1, 1, 1, 1]]],


        [[[1, 1, 1, 0]]]])
 attention_mask.shape = torch.Size([2, 1, 1, 4])
 attn_weights.shape = torch.Size([2, 16, 4, 4])
 attention_mask = tensor([[[[1, 1, 1, 1]]],


        [[[1, 1, 1, 0]]]])
 attention_mask.shape = torch.Size([2, 1, 1, 4])
 attn_weights.shape = torch.Size([2, 16, 4, 4])
 attention_mask = tensor([[[[1, 1, 1, 1]]],


        [[[1, 1, 1, 0]]]])
 attention_mask.shape = torch.Size([2, 1, 1, 4])
 attn_weights.shape = torch.Size([2, 16, 4, 4])
 attention_mask = tensor([[[[1, 1, 1, 1]]],


        [[[1, 1, 1, 0]]]])
 attention_mask.shape = torch.Size([2, 1, 1, 4])
 attn_weights.shape = torch.

# utils.mask.causal_mask

In [22]:
from bart_model_from_scratch.utils.mask import (
    causal_mask
)

In [23]:
x = causal_mask(
    tgt_len=4,
    device=torch.device("cpu"),
)
x

tensor([[[ True, False, False, False],
         [ True,  True, False, False],
         [ True,  True,  True, False],
         [ True,  True,  True,  True]]])

# utils.mask.create_decoder_mask

In [24]:
from bart_model_from_scratch.utils.mask import (
    create_decoder_atn_mask
)

In [25]:
# test causal_mask
attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0]
])
dtype = torch.float32
create_decoder_atn_mask(
    attention_mask=attention_mask,
    tgt_len=5,
)

TypeError: expand_mask() got an unexpected keyword argument 'tgt_len'

# BartDecoder

In [None]:
from bart_model_from_scratch.decoder import BartDecoder

In [None]:
bart_decoder = BartDecoder(config)
bart_decoder

In [None]:
# test bart_decoder
input_embeds = torch.randn(2, 4, config.d_model)
attention_mask = torch.randint(0, 2, (2, 4))
encoder_hidden_states = torch.randn(2, 4, config.d_model)
encoder_attention_mask = torch.randint(0, 2, (2, 4))
output = bart_decoder(
    input_embeds=input_embeds,
    attention_mask=attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=encoder_attention_mask,
)
print(output.shape)

In [None]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq
import torch.nn as nn

In [None]:
model = BartSeq2seq(config)
model

In [None]:
# test model
input_ids = torch.randint(0, 10, (2, 4))
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
decoder_input_ids = torch.randint(0, 10, (2, 4))
decoder_attention_mask = (decoder_input_ids != config.pad_token_id).to(torch.int64)
print(input_ids.shape)
print(attention_mask.shape)
print(decoder_input_ids.shape)
print(decoder_attention_mask.shape)

In [None]:
out = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
print(out)

In [None]:
encoder_out = model.get_encoder_out(
    input_ids=input_ids,
    attention_mask=attention_mask,
)
print(encoder_out.last_hidden_state)

In [None]:
decoder_out = model.get_decoder_out(
    input_ids=decoder_input_ids,
    attention_mask=decoder_attention_mask,
    encoder_hidden_states=encoder_out.last_hidden_state,
    encoder_attention_mask=attention_mask
)
print(decoder_out.last_hidden_state)

# BartSeq2seq

In [None]:
from bart_model_from_scratch.model_seq2seq import BartSeq2seq

In [None]:
model = BartSeq2seq(config)
model

In [None]:
# test model
input_ids = torch.randint(0, 10, (2, 5))
attention_mask = (input_ids != config.pad_token_id).to(torch.int64)
decoder_input_ids = torch.randint(0, 10, (2, 4))
decoder_attention_mask = (decoder_input_ids != config.pad_token_id).to(torch.int64)
print(input_ids.shape)
print(attention_mask.shape)
print(decoder_input_ids.shape)
print(decoder_attention_mask.shape)
logits = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
)
print(logits.shape)