In [1]:
import torch
from transformers import BartTokenizer, BartModel

# Load the pretrained BART-base tokenizer and model.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartModel.from_pretrained("facebook/bart-base")

# Tokenize one example sentence into PyTorch tensors and pull out the two
# tensors the encoder needs.
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]

# 1. Encoder pass: keep the full output object as well as its final hidden states.
encoder_outputs = model.encoder(
    input_ids=input_ids,
    attention_mask=attention_mask,
)
encoder_hidden_states = encoder_outputs.last_hidden_state

In [19]:
# model.decoder?

In [3]:
# Inspect the shape of the tokenized input ids.
input_ids.shape

torch.Size([1, 8])

In [4]:
# Inspect the shape of the encoder's final hidden states.
encoder_hidden_states.shape

torch.Size([1, 8, 768])

In [None]:
# encoder_hidden_states[:, 0, :].shape

torch.Size([1, 768])

In [11]:
## Take the hidden state at sequence position 0 (the first token) for every batch item.
# NOTE(review): this comment previously said "last", but index 0 selects the
# first position; use index -1 if the final position was actually intended.
h_z = encoder_hidden_states[:, 0, :]
h_z.shape

torch.Size([1, 768])

In [6]:
# h_z

In [24]:
# Display the encoder's module tree.
model.encoder

BartEncoder(
  (embed_tokens): BartScaledWordEmbedding(50265, 768, padding_idx=1)
  (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
  (layers): ModuleList(
    (0-5): 6 x BartEncoderLayer(
      (self_attn): BartSdpaAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (activation_fn): GELUActivation()
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (fc2): Linear(in_features=3072, out_features=768, bias=True)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
  )
  (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [21]:
# Display the decoder's module tree.
model.decoder

BartDecoder(
  (embed_tokens): BartScaledWordEmbedding(50265, 768, padding_idx=1)
  (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
  (layers): ModuleList(
    (0-5): 6 x BartDecoderLayer(
      (self_attn): BartSdpaAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (activation_fn): GELUActivation()
      (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (encoder_attn): BartSdpaAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (

In [23]:
# Print the module structure of each decoder layer in turn.
for decoder_layer in model.decoder.layers:
    print(decoder_layer)

BartDecoderLayer(
  (self_attn): BartSdpaAttention(
    (k_proj): Linear(in_features=768, out_features=768, bias=True)
    (v_proj): Linear(in_features=768, out_features=768, bias=True)
    (q_proj): Linear(in_features=768, out_features=768, bias=True)
    (out_proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (activation_fn): GELUActivation()
  (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (encoder_attn): BartSdpaAttention(
    (k_proj): Linear(in_features=768, out_features=768, bias=True)
    (v_proj): Linear(in_features=768, out_features=768, bias=True)
    (q_proj): Linear(in_features=768, out_features=768, bias=True)
    (out_proj): Linear(in_features=768, out_features=768, bias=True)
  )
  (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (fc1): Linear(in_features=768, out_features=3072, bias=True)
  (fc2): Linear(in_features=3072, out_features=768, bias=True)
  (final_layer_norm): LayerNorm((768,)

In [None]:
# Decoder pass: feed the same token ids and mask that went through the encoder,
# and let the decoder cross-attend to the encoder's final hidden states.
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]

decoder_outputs = model.decoder(
    input_ids=input_ids,
    attention_mask=attention_mask,
    encoder_hidden_states=encoder_hidden_states,
    encoder_attention_mask=attention_mask,
)

In [None]:
# Inspect the shape of the decoder's final hidden states.
decoder_outputs.last_hidden_state.shape

torch.Size([1, 8, 768])

In [31]:
# 2. Prepare decoder input: <s> token
# NOTE(review): the recorded embedding shape for this tensor is (1, 3, 768), so
# the tokenizer call yields three token ids rather than a single <s> --
# presumably because special tokens are added around the text. Confirm, and
# pass add_special_tokens=False if a single-token decoder input is intended.
decoder_input_ids = tokenizer("<s>", return_tensors="pt").input_ids
# All-ones mask with the same shape as the decoder input ids.
# NOTE(review): this mask does not appear to be used in the visible cells -- confirm.
decoder_attention_mask = torch.ones_like(decoder_input_ids)

In [35]:
# Inspect the shape of the decoder's token embeddings for the decoder input ids.
model.decoder.embed_tokens(decoder_input_ids).shape

torch.Size([1, 3, 768])

In [37]:
# Inspect the shape of the decoder's positional embeddings for the decoder input ids.
model.decoder.embed_positions(decoder_input_ids).shape

torch.Size([1, 3, 768])

In [None]:
# Decoder input representation: token embeddings plus positional embeddings.
hidden_states = model.decoder.embed_tokens(decoder_input_ids) + model.decoder.embed_positions(decoder_input_ids)

In [32]:
import torch

# 1. Convert decoder input tokens into embeddings: the sum of the decoder's
#    token embeddings and its positional embeddings for the same ids.
hidden_states = model.decoder.embed_tokens(decoder_input_ids) + model.decoder.embed_positions(decoder_input_ids)

# 2. Expand encoder attention mask for cross-attention
def expand_mask(mask, dtype=torch.float32, tgt_len=None):
    """Turn a 1/0 padding mask into an additive 4-D mask for cross-attention.

    The 4-D masks consumed by the Hugging Face BART attention layers are
    additive: they are summed with the attention scores before the softmax.
    Positions to attend to must therefore be 0 and positions to ignore must be
    a very large negative number. Simply casting the 1/0 mask to float (as an
    earlier version of this helper did) adds 1.0 to the kept positions and 0.0
    to the padded ones, which does not mask the padding.

    Args:
        mask: Tensor of shape (batch_size, src_len) with 1 for real tokens and
            0 for padding.
        dtype: Floating-point dtype of the returned mask. Must be a float
            dtype because the blocked value is ``torch.finfo(dtype).min``.
        tgt_len: If given, the mask is broadcast to this many query positions;
            otherwise the query dimension has size 1.

    Returns:
        Tensor of shape (batch_size, 1, tgt_len or 1, src_len) holding 0 where
        ``mask`` is 1 and ``torch.finfo(dtype).min`` where ``mask`` is 0.
    """
    keep = mask[:, None, None, :].to(dtype)  # (batch_size, 1, 1, src_len)
    if tgt_len is not None:
        keep = keep.expand(-1, -1, tgt_len, -1)  # one row per target position

    # Flip to "1 = blocked", then replace every blocked entry with the most
    # negative representable value so it vanishes after the softmax.
    blocked = 1.0 - keep
    return blocked.masked_fill(blocked.to(torch.bool), torch.finfo(dtype).min)

# Cross-attention mask over the encoder positions, one row per decoder position.
encoder_attention_mask = expand_mask(attention_mask, dtype=torch.float32, tgt_len=decoder_input_ids.shape[1])

# 3. Normalize the summed embeddings BEFORE the decoder stack.
# In BART, `layernorm_embedding` is the embedding LayerNorm: BartDecoder applies
# it to (token + positional) embeddings and then runs the layers. There is no
# extra LayerNorm after the last layer, so applying it at the end (as this cell
# previously did) does not reproduce `model.decoder(...)`.
hidden_states = model.decoder.layernorm_embedding(hidden_states)

# 4. Process each decoder layer (self-attention + cross-attention)
# NOTE(review): no self-attention mask is passed to the layers here; confirm
# whether the attention implementation in use applies causal masking on its own.
for decoder_layer in model.decoder.layers:
    hidden_states = decoder_layer(
        hidden_states,                      # Running decoder hidden states
        encoder_hidden_states=encoder_hidden_states,  # Encoder last hidden state
        encoder_attention_mask=encoder_attention_mask,  # Mask for cross-attention
        past_key_value=None,
        use_cache=False,
        output_attentions=False
    )[0]  # Extract the hidden state from the tuple

# 5. (Optional) Generate token logits
# NOTE(review): `model` is a BartModel; an `lm_head` presumably only exists on
# BartForConditionalGeneration -- confirm before enabling the line below.
# logits = model.lm_head(hidden_states)  # Get token probabilities?


In [28]:
# Inspect the shape of the current decoder hidden states.
hidden_states.shape

torch.Size([1, 3, 768])