# Notebook cell In [None]:
from transformers import BartTokenizer, BartForConditionalGeneration
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW

# The tokenizer and the model weights are loaded from the same pre-trained
# BART checkpoint, so the name is spelled out once.
MODEL_NAME = 'facebook/bart-large'
tokenizer = BartTokenizer.from_pretrained(MODEL_NAME)
model = BartForConditionalGeneration.from_pretrained(MODEL_NAME)

# CNN/DailyMail (configuration "3.0.0") is used as the example
# summarization dataset.
dataset = load_dataset(path="cnn_dailymail", name="3.0.0")

# Preprocess the data to tokenize the inputs and targets
def preprocess_data(examples):
    """Tokenize a batch of articles and their reference summaries.

    Args:
        examples: A batch mapping that provides the 'article' and
            'highlights' columns as sequences of strings.

    Returns:
        The tokenizer output for the articles (truncated/padded to 1024
        tokens) with an extra 'labels' entry holding the token ids of the
        summaries (truncated/padded to 150 tokens). Padding positions in
        'labels' are set to -100 so they are ignored by the loss.
    """
    inputs = list(examples['article'])
    targets = list(examples['highlights'])

    model_inputs = tokenizer(inputs, max_length=1024, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=150, truncation=True, padding="max_length")

    # Cross-entropy ignores the index -100. Leaving the pad token id in the
    # labels would make the model spend most of its loss on predicting
    # padding, because every summary is padded up to 150 tokens.
    pad_id = tokenizer.pad_token_id
    model_inputs['labels'] = [
        [token_id if token_id != pad_id else -100 for token_id in sequence]
        for sequence in labels['input_ids']
    ]
    return model_inputs

# Apply preprocessing to dataset
train_data = dataset['train'].map(preprocess_data, batched=True)
val_data = dataset['validation'].map(preprocess_data, batched=True)

# Expose only the model inputs, and expose them as torch tensors. Without
# this the DataLoader also yields the raw string columns (e.g. 'article',
# 'highlights') and plain Python lists, none of which have a .to() method
# or can be passed to the model.
model_columns = ["input_ids", "attention_mask", "labels"]
train_data.set_format(type="torch", columns=model_columns)
val_data.set_format(type="torch", columns=model_columns)

# Create DataLoader for batching; only the training data is shuffled so the
# validation loss is computed over a fixed order.
train_dataloader = DataLoader(train_data, batch_size=4, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=4)

# Run on the GPU when one is available. The model has to be moved
# explicitly; a freshly loaded model lives on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Set up optimizer (created after the model has been moved to its device)
optimizer = AdamW(model.parameters(), lr=5e-5)

# Train the model
model.train()

for epoch in range(3):  # Example for 3 epochs
    total_loss = 0
    for batch in train_dataloader:
        # Move the batch to the same device as the model
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass; passing 'labels' makes the model return the
        # cross-entropy loss directly.
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.item()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_dataloader)}")

# Evaluate the model on validation data
model.eval()
total_val_loss = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for batch in val_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        val_loss = outputs.loss
        total_val_loss += val_loss.item()

print(f"Validation Loss: {total_val_loss / len(val_dataloader)}")


# Notebook output: KeyboardInterrupt (execution of the cell was interrupted)