<a href="https://colab.research.google.com/github/srigana18/LegalSumm/blob/main/LegalSumm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
# Install the third-party packages this notebook relies on. %pip (rather than
# !pip) targets the environment of the running kernel.
# NOTE(review): no versions are pinned, so a fresh run may resolve different
# releases than the ones that produced the saved outputs.
%pip install rouge sentence-transformers nltk transformers datasets scikit-learn sentencepiece huggingface-hub --quiet

# Authenticate with Hugging Face (only if you need access to gated models)
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [9]:
import nltk
# Fetch the Punkt tokenizer resources; nltk.sent_tokenize is called later in
# this notebook to split each document into sentences.
nltk.download('punkt')
nltk.download('punkt_tab')  # Required with NLTK >= 3.8.2

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [10]:

import torch
import pandas as pd
import numpy as np
import nltk
from rouge import Rouge
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from nltk.cluster import KMeansClusterer
from sklearn.metrics import pairwise_distances_argmin_min
from scipy.spatial import distance_matrix
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ================= DEVICE CONFIG =================
# Pick the fastest available backend: a CUDA GPU first (e.g. a Colab GPU
# runtime), then Apple-Silicon MPS, and finally the CPU as the fallback.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# ================= NLTK SETUP (punkt / punkt_tab are downloaded in the previous cell) =================

# ================= LOAD BILL_SUM DATA =================
# Number of test bills to summarise. Kept small because every document is run
# through PEGASUS twice (full text and extractive summary) with beam search.
NUM_DOCS = 5  # adjust subset size if you'd like

dataset = load_dataset("billsum")
test_cases = dataset["ca_test"]
# Clamp to the split size so a large NUM_DOCS cannot index past the end.
subset = test_cases.select(range(min(NUM_DOCS, len(test_cases))))
texts = subset["text"]
gold_summaries = subset["summary"]

# ================= LOAD MODELS =================
# Abstractive summariser: tokenizer and seq2seq weights share one checkpoint.
pegasus_model_name = "nsi319/legal-pegasus"
tokenizer = AutoTokenizer.from_pretrained(pegasus_model_name)
pegasus_model = AutoModelForSeq2SeqLM.from_pretrained(pegasus_model_name)
pegasus_model = pegasus_model.to(device)

# Sentence encoder whose embeddings are clustered by kmeans_extractive_summary.
embedder = SentenceTransformer("distiluse-base-multilingual-cased")

# ROUGE scorer instance.
rouge = Rouge()

# ================= HELPER: KMEANS-BERT EXTRACTIVE =================
def kmeans_extractive_summary(text, num_clusters=10, iterations=25, seed=None):
    """Build an extractive summary by k-means clustering of sentence embeddings.

    The text is split into sentences, each sentence is embedded with the
    module-level ``embedder``, the embeddings are grouped into ``num_clusters``
    clusters, and the sentence closest to each cluster centroid is kept. The
    kept sentences are returned in their original document order.

    Args:
        text: Document to summarise.
        num_clusters: Number of clusters, i.e. the maximum number of sentences
            in the summary.
        iterations: Passed to ``KMeansClusterer`` as ``repeats`` (number of
            randomised clustering trials).
        seed: Optional seed for the clusterer's random initialisation, for
            reproducible summaries. ``None`` (default) keeps the unseeded
            behaviour.

    Returns:
        The selected sentences joined by single spaces. An empty string when
        the text contains no sentences; all sentences when the text has fewer
        sentences than ``num_clusters``.
    """
    import random  # local import: only needed for the optional seed

    stripped = (s.strip() for s in nltk.sent_tokenize(text))
    sentences = [s for s in stripped if s]
    if not sentences:
        return ""

    # NLTK's k-means draws its initial means by sampling the input vectors, so
    # asking for more clusters than sentences raises. In that case every
    # sentence would be its own cluster anyway: keep them all.
    if len(sentences) < num_clusters:
        return " ".join(sentences)

    embeddings = np.asarray(embedder.encode(sentences))
    clusterer = KMeansClusterer(
        num_clusters,
        distance=nltk.cluster.util.cosine_distance,
        repeats=iterations,
        avoid_empty_clusters=True,
        rng=random.Random(seed) if seed is not None else None,
    )
    assigned = clusterer.cluster(embeddings, assign_clusters=True)
    centroids = np.asarray(clusterer.means())

    # Euclidean distance from every sentence to the centroid of its own
    # cluster, computed in one vectorised step.
    # NOTE(review): clustering uses cosine distance while this selection step
    # uses Euclidean distance (unchanged from before) - confirm this is intended.
    dist = np.linalg.norm(embeddings - centroids[assigned], axis=1)

    frame = pd.DataFrame({"sentence": sentences, "cluster": assigned, "dist": dist})
    selected = (
        frame.sort_values("dist", ascending=True)
        .groupby("cluster")
        .head(1)                     # closest sentence per cluster
        .sort_index()["sentence"]    # restore document order
        .tolist()
    )
    return " ".join(selected)

# ================= HELPER: PEGASUS SUMMARY =================
def run_pegasus_summary(model, tokenizer, text, max_length=1024, num_beams=9):
    """Generate an abstractive summary of ``text`` with a seq2seq model.

    Args:
        model: Model exposing ``generate`` and ``device`` (e.g. the PEGASUS
            model loaded in this notebook).
        tokenizer: Tokenizer matching ``model``; must provide ``encode`` and
            ``batch_decode``.
        text: Document to summarise.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        num_beams: Forwarded to ``model.generate`` as ``num_beams``.

    Returns:
        The decoded summary string, or an empty string when ``text`` is empty
        or whitespace-only (nothing to summarise, so the model is not called).
    """
    if not text or not text.strip():
        return ""

    # truncation=True lets the tokenizer cut over-long documents to its
    # maximum input length. The ids are moved to the device the model lives
    # on, so the function does not depend on a module-level device variable.
    input_ids = tokenizer.encode(
        text,
        truncation=True,
        padding="longest",
        return_tensors="pt",
    ).to(model.device)

    summary_ids = model.generate(
        input_ids,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
    )
    # A single input sequence yields a single decoded summary.
    return tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0]

# ================= STUB: LLM EVALUATION =================
def evaluate_with_llm_stub(source_text, summary, gold_summary=None):
    """Placeholder for LLM-based summary evaluation.

    All arguments are currently ignored; every metric is reported as ``None``.
    """
    # TODO: integrate real LLM evaluation later
    metric_names = ("factuality", "coherence", "informativeness")
    return dict.fromkeys(metric_names)

# ================= MAIN LOOP =================
results = []
for i, (text, gold) in enumerate(zip(texts, gold_summaries)):
    print(f"\nProcessing document {i+1}/{len(texts)}...")

    # Extractive pass: k-means over sentence embeddings.
    extractive = kmeans_extractive_summary(text)

    # Abstractive pass: PEGASUS over the untouched document.
    pegasus_full = run_pegasus_summary(pegasus_model, tokenizer, text)

    # Hybrid pass: PEGASUS over the extractive summary, skipped when it is blank.
    hybrid = (
        run_pegasus_summary(pegasus_model, tokenizer, extractive)
        if extractive.strip()
        else ""
    )

    # One record per document; collected into a DataFrame afterwards.
    results.append(dict(
        text=text,
        gold_summary=gold,
        kmeans_extractive=extractive,
        pegasus_full=pegasus_full,
        hybrid_kmeans_pegasus=hybrid,
    ))

# ================= BUILD DATAFRAME & SAVE =================
# Collect the per-document records into a table and persist it as CSV.
output_csv = "hybrid_legal_billsum.csv"
df = pd.DataFrame(results)
df.to_csv(output_csv, index=False)

print(f"\nHybrid summarization results saved to: {output_csv}")

# Preview the first two rows of the source, reference and generated columns.
display(
    df.loc[
        :,
        ["text", "gold_summary", "kmeans_extractive", "pegasus_full", "hybrid_kmeans_pegasus"],
    ].head(2)
)


Using device: cpu

Processing document 1/5...

Processing document 2/5...

Processing document 3/5...

Processing document 4/5...

Processing document 5/5...

Hybrid summarization results saved to: hybrid_legal_billsum.csv


Unnamed: 0,text,gold_summary,kmeans_extractive,pegasus_full,hybrid_kmeans_pegasus
0,The people of the State of California do enact...,Existing property tax law establishes a vetera...,These organizations help preserve the memories...,The State Board of Equalization has determined...,The Securities and Exchange Commission today a...
1,The people of the State of California do enact...,Existing law provides that the Board of Parole...,"SEC. Within 30 days of receiving the request, ...",A prisoner is not eligible for resentence or r...,The Securities and Exchange Commission today i...
