# Imports

In [1]:
from bart_model_from_scatch.config import BartConfig
import torch
from bart_model_from_scatch.multihead_attn import BartAttention

# BartConfig

In [2]:
# Start from the default BartConfig, then override the fields set in this notebook.
config = BartConfig()
for field_name, field_value in (
    ("pad_token_id", 2),
    ("encoder_layerdrop", 0.1),
    ("decoder_layerdrop", 0.1),
):
    setattr(config, field_name, field_value)

# BartAttention

In [3]:
# Build a stand-alone attention module sized from the encoder settings in `config`.
attn_kwargs = dict(
    embed_dim=config.d_model,
    num_heads=config.encoder_attention_heads,
    dropout=config.attention_dropout,
)
bart_attn = BartAttention(**attn_kwargs)
print(bart_attn)

BartAttention(
  (k_proj): Linear(in_features=768, out_features=768, bias=True)
  (v_proj): Linear(in_features=768, out_features=768, bias=True)
  (q_proj): Linear(in_features=768, out_features=768, bias=True)
  (out_proj): Linear(in_features=768, out_features=768, bias=True)
)


In [4]:
# Smoke-test bart_attn on a random (2, 4, d_model) input.
hidden_states = torch.randn(2, 4, config.d_model)
output = bart_attn(hidden_states)
# Fail loudly on a shape regression instead of relying on someone reading the print.
assert output.shape == hidden_states.shape, (
    f"expected {tuple(hidden_states.shape)}, got {tuple(output.shape)}"
)
print(output.shape)

torch.Size([2, 4, 768])


# BartEncoderLayer

In [5]:
from bart_model_from_scatch.encoder_layer import BartEncoderLayer

In [6]:
# Build one encoder layer from the shared config; the bare name on the last
# line makes the notebook display the module's structure.
bart_encoder_layer = BartEncoderLayer(config)
bart_encoder_layer

BartEncoderLayer(
  (self_attn): BartAttention(
    (k_proj): Linear(in_features=768, out_features=768, bias=True)
    (v_proj): Linear(in_features=768, out_features=768, bias=True)
    (q_proj): Linear(in_features=768, out_features=768, bias=True)
    (out_proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.01, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.01, inplace=False)
  (fc1): Linear(in_features=768, out_features=3072, bias=True)
  (fc2): Linear(in_features=3072, out_features=768, bias=True)
  (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [7]:
# Smoke-test bart_encoder_layer on a random (2, 4, d_model) float32 input.
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
output = bart_encoder_layer(hidden_states)
# Fail loudly on a shape regression instead of relying on someone comparing the prints.
assert output.shape == hidden_states.shape, (
    f"expected {tuple(hidden_states.shape)}, got {tuple(output.shape)}"
)
print(output.shape)

torch.Size([2, 4, 768])
torch.Size([2, 4, 768])


# BartDecoderLayer

In [8]:
from bart_model_from_scatch.decoder_layer import BartDecoderLayer

In [9]:
# Build one decoder layer from the shared config; the bare name on the last
# line makes the notebook display the module's structure.
bart_decoder_layer = BartDecoderLayer(config)
bart_decoder_layer

BartDecoderLayer(
  (self_attn): BartAttention(
    (k_proj): Linear(in_features=768, out_features=768, bias=True)
    (v_proj): Linear(in_features=768, out_features=768, bias=True)
    (q_proj): Linear(in_features=768, out_features=768, bias=True)
    (out_proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (dropout): Dropout(p=0.01, inplace=False)
  (activation_fn): GELU(approximate='none')
  (activation_dropout): Dropout(p=0.01, inplace=False)
  (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (encoder_attn): BartAttention(
    (k_proj): Linear(in_features=768, out_features=768, bias=True)
    (v_proj): Linear(in_features=768, out_features=768, bias=True)
    (q_proj): Linear(in_features=768, out_features=768, bias=True)
    (out_proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (fc1): Linear(in_features=768, out_features=3072, bias=True)
  

In [10]:
# Smoke-test bart_decoder_layer: random decoder input plus random encoder
# states to attend over, both (2, 4, d_model) float32.
hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
encoder_hidden_states = torch.randn(2, 4, config.d_model, dtype=torch.float32)
print(hidden_states.shape)
print(encoder_hidden_states.shape)
output = bart_decoder_layer(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
)
# Fail loudly on a shape regression instead of relying on someone comparing the prints.
assert output.shape == hidden_states.shape, (
    f"expected {tuple(hidden_states.shape)}, got {tuple(output.shape)}"
)
print(output.shape)

torch.Size([2, 4, 768])
torch.Size([2, 4, 768])
torch.Size([2, 4, 768])


# BartEmbeds

In [11]:
from bart_model_from_scatch.embeds import BartEmbeds

In [12]:
# Embedding module configured from the vocabulary, model width, pad id and
# maximum position count held in `config`.
embeds_kwargs = {
    "num_embeddings": config.vocab_size,
    "embedding_dim": config.d_model,
    "padding_idx": config.pad_token_id,
    "max_position_embeddings": config.max_position_embeddings,
}
bart_embeds = BartEmbeds(**embeds_kwargs)

In [13]:
# Smoke-test bart_embeds on a random (2, 4) batch of token ids.
input_ids = torch.randint(0, config.vocab_size, (2, 4))
output = bart_embeds(input_ids)
# Each token id should map to one d_model-wide vector; fail loudly otherwise.
expected_shape = (*input_ids.shape, config.d_model)
assert output.shape == expected_shape, (
    f"expected {expected_shape}, got {tuple(output.shape)}"
)
print(output.shape)

torch.Size([2, 4, 768])


# utils.mask.create_encoder_atn_mask

In [14]:
from bart_model_from_scatch.utils.mask import (
    create_encoder_atn_mask,
)

In [15]:
# Fixture for create_encoder_atn_mask: a fixed 5x4 batch of token ids.
# Literal values (instead of an unseeded torch.randint) keep this cell and the
# mask derived from it reproducible. Rows 0, 2 and 4 each contain the id 2,
# which the config cell sets as pad_token_id.
# The ids are float32 only because a later cell passes input_ids.dtype as the
# dtype argument of create_encoder_atn_mask.
input_ids = torch.tensor(
    [
        [2, 4, 4, 9],
        [0, 7, 4, 6],
        [8, 6, 7, 2],
        [0, 7, 4, 8],
        [5, 2, 1, 7],
    ],
    dtype=torch.float32,
)
input_ids

tensor([[2., 4., 4., 9.],
        [0., 7., 4., 6.],
        [8., 6., 7., 2.],
        [0., 7., 4., 8.],
        [5., 2., 1., 7.]])

In [16]:
# int64 mask: 1 where the token differs from config.pad_token_id, 0 where it equals it.
attention_mask = input_ids.ne(config.pad_token_id).long()
attention_mask

tensor([[0, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 0],
        [1, 1, 1, 1],
        [1, 0, 1, 1]])

In [17]:
# Build the encoder attention mask from the 2-D padding mask, reusing the
# dtype of input_ids as the requested mask dtype.
mask_dtype = input_ids.dtype
encoder_attention_mask = create_encoder_atn_mask(
    attention_mask=attention_mask, dtype=mask_dtype
)

In [18]:
# Display the shape of the mask returned by create_encoder_atn_mask.
encoder_attention_mask.size()

torch.Size([5, 1, 4, 4])

# BartEncoder

In [19]:
from bart_model_from_scatch.encoder import BartEncoder

In [20]:
# Instantiate the encoder from the shared config.
# NOTE(review): the captured output of this cell is a list of Linear reprs,
# presumably from a leftover print inside BartEncoder — confirm and remove it there.
bart_encoder = BartEncoder(config)

Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=3072, bias=True)
Linear(in_features=3072, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=3072, bias=True)
Linear(in_features=3072, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=768, bias=True)
Linear(in_features=768, out_features=3072, bias=True)
Linear(in_features=3072, out_features=768, bias=True)
Linear(in_features=768, out_features=768

In [21]:
# Smoke-test bart_encoder on random (2, 4, d_model) embeddings with a fixed mask.
input_embeds = torch.randn(2, 4, config.d_model)
# A literal int64 mask instead of torch.randint(0, 2, ...): a random draw can
# zero out an entire row, so the run would sometimes exercise a fully masked
# sequence. Row 0 keeps every position; row 1 masks its last two.
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
output = bart_encoder(input_embeds, attention_mask)
# Fail loudly on a shape regression instead of relying on someone reading the print.
assert output.shape == input_embeds.shape, (
    f"expected {tuple(input_embeds.shape)}, got {tuple(output.shape)}"
)
print(output.shape)

torch.Size([2, 4, 768])
