In [1]:
from transformers import EncoderDecoderModel, BertTokenizer, BartTokenizer

# Checkpoint identifiers: one for the encoder side, one for the decoder side.
# Keeping them in one place ensures the tokenizers and the model weights are
# always loaded from the same checkpoints.
ENCODER_CHECKPOINT = "bert-base-uncased"
DECODER_CHECKPOINT = "facebook/bart-base"

# One tokenizer per checkpoint: dialogues are tokenized with the encoder's
# tokenizer, summaries with the decoder's.
encoder_tokenizer = BertTokenizer.from_pretrained(ENCODER_CHECKPOINT)
decoder_tokenizer = BartTokenizer.from_pretrained(DECODER_CHECKPOINT)

# Assemble a single encoder-decoder model from the two pretrained checkpoints.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    ENCODER_CHECKPOINT,
    DECODER_CHECKPOINT,
)

# The special-token ids in the shared config are taken from the decoder's
# tokenizer, since that is the tokenizer used for the target sequences.
model.config.pad_token_id = decoder_tokenizer.pad_token_id
model.config.eos_token_id = decoder_tokenizer.eos_token_id
model.config.decoder_start_token_id = decoder_tokenizer.bos_token_id


  from .autonotebook import tqdm as notebook_tqdm
Some weights of BartForCausalLM were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['lm_head.weight', 'model.decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
import pandas as pd

train_df = pd.read_csv('samsum_csv_data/train.csv')

# Take the first 100 usable samples from the DataFrame.
# Empty CSV cells are read by pandas as NaN (a float); dropping those rows
# first keeps non-string values away from the tokenizers.
sample_df = train_df.dropna(subset=["dialogue", "summary"]).iloc[:100]
input_texts = sample_df["dialogue"].tolist()
target_texts = sample_df["summary"].tolist()

# Tokenize encoder inputs (dialogue) using BERT tokenizer
encoder_inputs = encoder_tokenizer(
    input_texts,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt"
)

# Tokenize decoder targets (summary) using BART tokenizer.
# The tokenizer is called directly: `as_target_tokenizer()` is deprecated in
# transformers, and this tokenizer is only ever used for target text, so no
# source/target mode switch is needed.
decoder_inputs = decoder_tokenizer(
    target_texts,
    padding=True,
    truncation=True,
    max_length=64,
    return_tensors="pt"
)

# Prepare decoder labels: padding positions are set to -100, the default
# `ignore_index` of torch's CrossEntropyLoss, so they do not contribute to the loss.
labels = decoder_inputs["input_ids"].clone()
labels[labels == decoder_tokenizer.pad_token_id] = -100




In [3]:
import torch
from torch.utils.data import Dataset

# Step 3.1: Dataset wrapper that pairs encoder encodings with decoder labels
class SummarizationDataset(Dataset):
    """Map-style dataset yielding one example dict per index.

    Args:
        encodings: Mapping from field name to an indexable collection of
            per-example values; it only needs to support ``.items()``.
        labels: Indexable, sized collection with one label entry per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The dataset size is defined by the number of label rows.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return every encoding field at ``idx`` plus the entry under ``"labels"``."""
        example = {}
        for field, values in self.encodings.items():
            example[field] = values[idx]
        # Assigned last so it takes precedence over any "labels" encoding field.
        example["labels"] = self.labels[idx]
        return example

# Step 3.2: Create the dataset and a single batch
dataset = SummarizationDataset(encoder_inputs, labels)
# Add a leading dimension to every field of the first example (batch size = 1).
batch = {k: v.unsqueeze(0) for k, v in dataset[0].items()}

# Step 3.3: Move model and batch to CPU
device = torch.device("cpu")
model = model.to(device)
batch = {k: v.to(device) for k, v in batch.items()}

# Step 3.4: Forward and backward pass (smoke test on one example)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

outputs = model(**batch)
# `outputs` is the full model output object; the training loss lives on its
# `.loss` attribute (populated because the batch contains "labels").
loss = outputs.loss
loss.backward()
# No optimizer step is taken here; clear the gradients so this check does not
# leave stale gradients on the parameters for later training cells.
optimizer.zero_grad()

  decoder_attention_mask = decoder_input_ids.new_tensor(decoder_input_ids != self.config.pad_token_id)
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm

# Create DataLoader for batching
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Training setup
device = torch.device("cpu")  # or "cuda" if you're ready to try GPU again
model = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
epochs = 3

# Training loop
model.train()
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")
    loop = tqdm(train_loader, desc="Training", leave=False)

    for batch in loop:
        batch = {k: v.to(device) for k, v in batch.items()}

        # Teacher forcing: the decoder sees tokens [0, T-1) and is trained to
        # predict tokens [1, T).
        decoder_input_ids = batch["labels"][:, :-1].clone()
        # Positions holding -100 are padding. Build the mask before those
        # positions are overwritten with the pad id: the model only derives a
        # decoder mask itself when decoder_input_ids are NOT supplied.
        decoder_attention_mask = (decoder_input_ids != -100).long()
        decoder_input_ids[decoder_input_ids == -100] = model.config.pad_token_id
        # Deliberately NOT named `labels`: that would overwrite the
        # notebook-level `labels` tensor the dataset was built from and break
        # re-running the earlier cells.
        target_ids = batch["labels"][:, 1:].clone()

        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=target_ids,
        )

        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loop.set_postfix(loss=loss.item())


Epoch 1


Training:   4%|▍         | 1/25 [00:10<04:16, 10.70s/it, loss=11.2]