In [None]:
# LLM_Summarizer.ipynb

import torch
from transformers import LEDTokenizer, LEDForConditionalGeneration

def get_summarizer(model_name="allenai/led-base-16384"):
    """Load the LED tokenizer and model used for summarization.

    Args:
        model_name: Hub id or local path handed to ``from_pretrained`` for
            both the tokenizer and the model.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is put in eval mode.
    """
    led_tokenizer = LEDTokenizer.from_pretrained(model_name)
    # nn.Module.eval() returns the module itself, so it can be chained.
    led_model = LEDForConditionalGeneration.from_pretrained(model_name).eval()
    return led_tokenizer, led_model

def summarize_with_led(
    text,
    tokenizer,
    model,
    max_input_tokens=4096,
    summary_max_len=250,
    summary_min_len=80,
):
    """Generate an abstractive summary of ``text`` with an LED model.

    Args:
        text: The document to summarize. Input beyond ``max_input_tokens``
            tokens is truncated.
        tokenizer: Tokenizer matching ``model`` (e.g. from ``get_summarizer``).
        model: LED encoder-decoder model; inputs are moved to its device.
        max_input_tokens: Maximum number of tokens fed to the encoder.
        summary_max_len: Maximum generated summary length, in tokens.
        summary_min_len: Minimum generated summary length, in tokens. It is
            clamped to ``summary_max_len`` so the two limits never conflict.

    Returns:
        The decoded summary string with special tokens removed.
    """
    # Pad only to the real sequence length. LED pads internally to a multiple
    # of its attention window, so padding to ``max_input_tokens`` would only
    # add masked positions that still cost compute. Move the tensors to the
    # model's device so a GPU-resident model works without a device mismatch.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=max_input_tokens,
    ).to(model.device)

    # Give the first token global attention (the usual LED summarization
    # setup). All other tokens keep local windowed attention.
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    # A min_length larger than max_length is an unfeasible constraint for
    # ``generate``, so never request more than the maximum.
    min_len = min(summary_min_len, summary_max_len)

    summary_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
        max_length=summary_max_len,
        min_length=min_len,
        num_beams=4,
        length_penalty=2.0,  # > 0 favors longer beam hypotheses
        early_stopping=True,
    )

    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

