In [None]:
# =========================================================
#  LEGAL ENSEMBLE SUMMARIZER — Train & Evaluate on Validation Set (Fixed JSON float issue)
# =========================================================

import os
import json
import re
import torch
import numpy as np
import jsonlines
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity
from transformers import (
    BartTokenizer, BartForConditionalGeneration,
    PegasusTokenizer, PegasusForConditionalGeneration,
    LEDTokenizer, LEDForConditionalGeneration
)
from rouge import Rouge
from sacrebleu.metrics import BLEU
import evaluate

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# =========================================================
# 1 Clean Judgment Text
# =========================================================
def clean_judgment_text(text):
    """Strip boilerplate from a raw judgment and normalise its whitespace.

    The substitutions run in a fixed order: page markers, ``Case :-`` header
    lines, back-to-back repeats of the Petitioner/Respondent/Counsel header
    block, parenthesised numbers such as ``(12)``, then newline and
    whitespace clean-up.

    Args:
        text: Raw judgment text.

    Returns:
        The cleaned text with newlines replaced by spaces, whitespace runs
        collapsed, and leading/trailing whitespace removed.
    """
    # (pattern, replacement, flags) triples, applied strictly in this order.
    substitutions = (
        (r"\[Page No\.\s*\d+\]", " ", 0),
        (r"Case\s*:-.*?\n", " ", 0),
        # Keep a single copy of a header block that is repeated back to back.
        (
            r"(Petitioner\s*:-.*?Respondent\s*:-.*?Counsel for Respondent\s*:-.*?)(\1)+",
            r"\1",
            re.DOTALL,
        ),
        (r"\(\d+\)", "", 0),
        (r"\n+", " ", 0),
        (r"\s{2,}", " ", 0),
    )
    cleaned = text
    for pattern, replacement, flags in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned, flags=flags)

    # Re-attach commas and full stops that ended up preceded by a space.
    for detached, attached in ((" ,", ","), (" .", ".")):
        cleaned = cleaned.replace(detached, attached)
    return cleaned.strip()


# =========================================================
# 2 Ensemble Class
# =========================================================
class LegalEnsembleSummarizer:
    """Three-model summarisation ensemble with an embedding-based selector.

    BART, Pegasus and LED checkpoints each produce one candidate summary.
    Each candidate is scored by the cosine similarity between mean-pooled
    BART encoder states of the judgment and of the candidate, and the
    highest-scoring candidate is returned.

    Model and tensor placement use the module-level ``device``.
    """

    def __init__(self, bart_path, pegasus_path, led_path):
        """Load each tokenizer/model pair, move the model to ``device``, set eval mode.

        Args:
            bart_path: Location passed to the BART ``from_pretrained`` calls.
            pegasus_path: Location passed to the Pegasus ``from_pretrained`` calls.
            led_path: Location passed to the LED ``from_pretrained`` calls.
        """
        checkpoints = (
            ("bart", BartTokenizer, BartForConditionalGeneration, bart_path),
            ("pegasus", PegasusTokenizer, PegasusForConditionalGeneration, pegasus_path),
            ("led", LEDTokenizer, LEDForConditionalGeneration, led_path),
        )
        for prefix, tokenizer_cls, model_cls, path in checkpoints:
            tokenizer = tokenizer_cls.from_pretrained(path)
            model = model_cls.from_pretrained(path).to(device)
            model.eval()
            # Stored as self.<prefix>_tokenizer / self.<prefix>_model.
            setattr(self, f"{prefix}_tokenizer", tokenizer)
            setattr(self, f"{prefix}_model", model)

    def generate_summary(self, model, tokenizer, text, max_words=500, min_words=400, model_type="bart"):
        """Generate one beam-search summary of ``text`` with ``model``.

        Args:
            model: Object exposing ``generate``.
            tokenizer: Tokenizer paired with ``model``.
            text: Document to summarise.
            max_words: Upper length target in words.
            min_words: Lower length target in words.
            model_type: ``"led"`` selects a 16384-token input limit; any
                other value selects 1024.

        Returns:
            The first generated sequence decoded without special tokens.
        """
        # Word targets are turned into token limits with a fixed 1.5 factor.
        length_limits = {
            "max_length": int(max_words * 1.5),
            "min_length": int(min_words * 1.5),
        }
        input_limit = 16384 if model_type == "led" else 1024
        encoded = tokenizer(
            text, return_tensors="pt", truncation=True, max_length=input_limit
        ).to(device)
        generated = model.generate(
            **encoded,
            **length_limits,
            num_beams=4,
            no_repeat_ngram_size=4,
            early_stopping=True,
        )
        return tokenizer.decode(generated[0], skip_special_tokens=True)

    def get_embedding(self, text):
        """Return the mean over the sequence axis of the BART encoder's last hidden state.

        The input is truncated to 1024 tokens; the result is a NumPy array.
        """
        encoded = self.bart_tokenizer(
            text, return_tensors="pt", truncation=True, max_length=1024
        ).to(device)
        encoder = self.bart_model.model.encoder
        with torch.no_grad():
            hidden_states = encoder(**encoded).last_hidden_state
        return hidden_states.mean(dim=1).cpu().numpy()

    def rerank(self, judgment, summaries):
        """Return the cosine similarity of each summary to the judgment, in input order."""
        reference = self.get_embedding(judgment)
        scores = []
        for candidate in summaries:
            pairwise = cosine_similarity(reference, self.get_embedding(candidate))
            scores.append(pairwise[0][0])
        return scores

    def summarize(self, judgment_text):
        """Generate BART, Pegasus and LED candidates and pick the most similar one.

        Returns:
            ``(best_summary, sims)`` where ``sims`` holds the similarity
            scores in the order BART, Pegasus, LED.
        """
        generators = (
            (self.bart_model, self.bart_tokenizer, "bart"),
            (self.pegasus_model, self.pegasus_tokenizer, "pegasus"),
            (self.led_model, self.led_tokenizer, "led"),
        )
        summaries = [
            self.generate_summary(model, tokenizer, judgment_text, model_type=kind)
            for model, tokenizer, kind in generators
        ]
        sims = self.rerank(judgment_text, summaries)
        return summaries[int(np.argmax(sims))], sims


# =========================================================
# 3 Load Dataset & Split
# =========================================================
train_text_file = "train_judg.jsonl"
train_summ_file = "train_ref_summ.jsonl"


def _read_jsonl(path):
    """Return the JSON objects stored one per line in ``path``.

    The file is closed as soon as it has been read, and blank lines are
    skipped instead of being handed to ``json.loads``.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]


all_texts = _read_jsonl(train_text_file)
all_summaries = _read_jsonl(train_summ_file)

# Judgments and summaries are paired purely by position, and zip() stops at
# the shorter list, so a count mismatch would otherwise drop records silently.
if len(all_texts) != len(all_summaries):
    print(
        f" WARNING: {len(all_texts)} judgments vs {len(all_summaries)} summaries; "
        "unmatched trailing records are ignored"
    )

all_data = [
    {"ID": t["ID"], "text": clean_judgment_text(t["Judgment"]), "summary": s["Summary"]}
    for t, s in zip(all_texts, all_summaries)
]

# Fixed seed keeps the 90/10 split reproducible across runs.
train_list, val_list = train_test_split(all_data, test_size=0.1, random_state=42)

print(f" Dataset split: {len(train_list)} training, {len(val_list)} validation")


# =========================================================
# 4 Initialize Ensemble
# =========================================================
# Checkpoint locations handed to the from_pretrained() calls of the ensemble.
bart_path, pegasus_path, led_path = "bart_model", "pegasus_model", "led_model"

# Constructing the ensemble loads all three tokenizer/model pairs.
ensemble = LegalEnsembleSummarizer(
    bart_path=bart_path,
    pegasus_path=pegasus_path,
    led_path=led_path,
)


# =========================================================
# 5 Generate Summaries for Validation
# =========================================================
def _summarize_validation_example(example):
    """Summarise one validation record and bundle the result with its reference."""
    # NOTE(review): this cap counts characters, not tokens, so the input is
    # cut to 16384 characters before any tokenizer sees it — confirm intended.
    document = example["text"][:16384]
    chosen, scores = ensemble.summarize(document)
    return {
        "ID": example["ID"],
        "generated_summary": chosen,
        "reference_summary": example["summary"],
        # Built-in floats so the record can be serialised with json.dumps.
        "similarities": list(map(float, scores)),
    }


generated_summaries = []

print("Generating ensemble summaries for validation set (400â€“500 words)...")
for example in tqdm(val_list):
    generated_summaries.append(_summarize_validation_example(example))

# One JSON object per line; non-ASCII characters are written unescaped.
val_output_file = "outputs/ensemble_val_summaries.jsonl"
os.makedirs("outputs", exist_ok=True)
with open(val_output_file, "w", encoding="utf-8") as f:
    f.writelines(
        json.dumps(item, ensure_ascii=False) + "\n" for item in generated_summaries
    )

print(f" Validation summaries saved to {val_output_file}")


# =========================================================
# 6 Evaluate Summaries
# =========================================================
rouge_scorer = Rouge()
bleu_scorer = BLEU()
bertscore_eval = evaluate.load("bertscore")

rouge1_scores, rouge2_scores, rougel_scores, bleu_scores, bert_scores = [], [], [], [], []

# A metric that raises for one example still scores 0 for that example, but
# the failures are counted so a depressed average can be traced back to them.
metric_failures = {"ROUGE": 0, "BLEU": 0, "BERTScore": 0}

for item in generated_summaries:
    hyp = item["generated_summary"]
    ref = item["reference_summary"]

    try:
        r = rouge_scorer.get_scores(hyp, ref)[0]
        rouge1 = r["rouge-1"]["f"] * 100
        rouge2 = r["rouge-2"]["f"] * 100
        rougel = r["rouge-l"]["f"] * 100
    except Exception:
        rouge1, rouge2, rougel = 0, 0, 0
        metric_failures["ROUGE"] += 1

    try:
        bleu = bleu_scorer.sentence_score(hyp, [ref]).score
    except Exception:
        bleu = 0
        metric_failures["BLEU"] += 1

    try:
        bert = bertscore_eval.compute(predictions=[hyp], references=[ref], lang="en")["f1"][0] * 100
    except Exception:
        bert = 0
        metric_failures["BERTScore"] += 1

    rouge1_scores.append(rouge1)
    rouge2_scores.append(rouge2)
    rougel_scores.append(rougel)
    bleu_scores.append(bleu)
    bert_scores.append(bert)

for metric_name, failure_count in metric_failures.items():
    if failure_count:
        print(f" WARNING: {metric_name} failed on {failure_count} example(s); scored as 0")


def _mean_as_float(scores):
    """Return the mean of ``scores`` as a built-in float (0.0 when empty).

    Producing built-in floats here makes the dict JSON-serialisable without
    any NumPy type check; the earlier check referenced ``np.float_``, which
    no longer exists in NumPy 2.0. The empty case avoids writing NaN, which
    is not valid JSON.
    """
    return float(np.mean(scores)) if len(scores) else 0.0


metrics = {
    "ROUGE-1 (%)": _mean_as_float(rouge1_scores),
    "ROUGE-2 (%)": _mean_as_float(rouge2_scores),
    "ROUGE-L (%)": _mean_as_float(rougel_scores),
    "BLEU (%)": _mean_as_float(bleu_scores),
    "BERTScore-F1 (%)": _mean_as_float(bert_scores),
}
# Headline score: unweighted mean of ROUGE-2, ROUGE-L and BLEU.
metrics["AVG_SCORE (%)"] = _mean_as_float(
    [metrics["ROUGE-2 (%)"], metrics["ROUGE-L (%)"], metrics["BLEU (%)"]]
)

metrics_path = "outputs/ensemble_val_metrics.json"
with open(metrics_path, "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=4)

print("\n========== ðŸ§¾ Ensemble Validation Metrics (in %) ==========")
for k, v in metrics.items():
    print(f"{k}: {v:.2f}")
print("===========================================================")


In [None]:
# =========================================================
#  Legal Ensemble Summarizer — BART + Pegasus + LED
#  Generates summaries (400–500 words), reranks using fine-tuned BART embeddings,
#    saves outputs in JSONL format:
#    {"ID": "<datapoint-id>", "Summary": "<generated summary>"}
# =========================================================

import torch, re, os, json, jsonlines, shutil
from tqdm import tqdm
from transformers import (
    BartTokenizer, BartForConditionalGeneration,
    PegasusTokenizer, PegasusForConditionalGeneration,
    LEDTokenizer, LEDForConditionalGeneration
)
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# =========================================================
# 1 Device setup
# =========================================================
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

# =========================================================
# 2 Legal-specific preprocessing function
# =========================================================
def clean_judgment_text(text):
    """Remove judgment boilerplate and tidy whitespace and punctuation spacing.

    Args:
        text: Raw judgment text.

    Returns:
        The text with page markers, ``Case :-`` header lines, back-to-back
        repeats of the party/counsel header block and parenthesised numbers
        removed, newlines turned into spaces, whitespace runs collapsed, and
        surrounding whitespace stripped.
    """
    # "[Page No. 12]" style markers become a single space.
    no_page_marks = re.sub(r"\[Page No\.\s*\d+\]", " ", text)

    # Drop everything from "Case :-" up to and including the end of that line.
    no_case_lines = re.sub(r"Case\s*:-.*?\n", " ", no_page_marks)

    # Keep one copy of a Petitioner/Respondent/Counsel block that repeats
    # immediately; DOTALL lets the block span several lines.
    repeated_header = re.compile(
        r"(Petitioner\s*:-.*?Respondent\s*:-.*?Counsel for Respondent\s*:-.*?)(\1)+",
        flags=re.DOTALL,
    )
    single_header = repeated_header.sub(r"\1", no_case_lines)

    # Remove parenthesised numbers such as "(12)".
    no_numbering = re.sub(r"\(\d+\)", "", single_header)

    # Flatten newlines first, then squeeze any run of whitespace to one space.
    flattened = re.sub(r"\n+", " ", no_numbering)
    squeezed = re.sub(r"\s{2,}", " ", flattened)

    # Close the gap left in front of commas and full stops.
    tidy = squeezed.replace(" ,", ",")
    tidy = tidy.replace(" .", ".")
    return tidy.strip()

# =========================================================
# 3 Ensemble Summarizer Class
# =========================================================
class LegalEnsembleSummarizer:
    """Ensemble of BART, Pegasus and LED summarizers with a BART-based selector.

    Each model generates one candidate summary. Candidates are compared with
    the judgment through the cosine similarity of mean-pooled BART encoder
    states, and the most similar candidate is returned.

    Model and tensor placement use the module-level ``device``.
    """

    def __init__(self, bart_path, pegasus_path, led_path):
        """Load the three tokenizer/model pairs and put every model in eval mode.

        Args:
            bart_path: Location passed to the BART ``from_pretrained`` calls.
            pegasus_path: Location passed to the Pegasus ``from_pretrained`` calls.
            led_path: Location passed to the LED ``from_pretrained`` calls.
        """
        # Load BART model (for generation & embedding)
        self.bart_tokenizer = BartTokenizer.from_pretrained(bart_path)
        self.bart_model = BartForConditionalGeneration.from_pretrained(bart_path).to(device)
        self.bart_model.eval()  # generator & reranker

        # Load Pegasus model
        self.pegasus_tokenizer = PegasusTokenizer.from_pretrained(pegasus_path)
        self.pegasus_model = PegasusForConditionalGeneration.from_pretrained(pegasus_path).to(device)
        self.pegasus_model.eval()

        # Load LED model
        self.led_tokenizer = LEDTokenizer.from_pretrained(led_path)
        self.led_model = LEDForConditionalGeneration.from_pretrained(led_path).to(device)
        self.led_model.eval()

    # Generate summary from a model
    def generate_summary(self, model, tokenizer, text, max_words=500, min_words=400,
                         model_type="bart", no_repeat_ngram_size=4):
        """Generate one beam-search summary of ``text`` with ``model``.

        Args:
            model: Object exposing ``generate``.
            tokenizer: Tokenizer paired with ``model``.
            text: Document to summarise.
            max_words: Upper length target in words.
            min_words: Lower length target in words.
            model_type: ``"led"`` selects a 16384-token input limit; any
                other value selects 1024.
            no_repeat_ngram_size: Forwarded to ``generate``. Defaults to 4,
                the value used when the ensemble is scored on the validation
                split, so test-time decoding matches the evaluated setup.

        Returns:
            The first generated sequence decoded without special tokens.
        """
        # Word targets are converted to token limits with a fixed 1.5 factor.
        max_tokens = int(max_words * 1.5)
        min_tokens = int(min_words * 1.5)
        if model_type == "led":
            # LED supports long documents
            # NOTE(review): no global_attention_mask is passed for LED —
            # confirm this matches how the checkpoint was fine-tuned.
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=16384).to(device)
        else:
            inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024).to(device)
        summary_ids = model.generate(
            **inputs,
            max_length=max_tokens,
            min_length=min_tokens,
            num_beams=4,
            no_repeat_ngram_size=no_repeat_ngram_size,
            early_stopping=True
        )
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # Get embeddings using BART encoder
    def get_embedding(self, text):
        """Return the mean over the sequence axis of the BART encoder's last hidden state.

        The input is truncated to 1024 tokens; the result is a NumPy array.
        """
        inputs = self.bart_tokenizer(text, return_tensors="pt", truncation=True, max_length=1024).to(device)
        with torch.no_grad():
            encoder_outputs = self.bart_model.model.encoder(**inputs)
        emb = encoder_outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        return emb

    # Compute similarity of each summary with judgment
    def rerank(self, judgment, summaries):
        """Return the cosine similarity of each summary to the judgment, in input order."""
        # The judgment is embedded once and reused for every candidate.
        j_emb = self.get_embedding(judgment)
        sims = [cosine_similarity(j_emb, self.get_embedding(s))[0][0] for s in summaries]
        return sims

    # Generate summaries and select best one
    def summarize(self, judgment_text):
        """Generate BART, Pegasus and LED candidates and pick the most similar one.

        Returns:
            ``(best_summary, sims)`` where ``sims`` holds the similarity
            scores in the order BART, Pegasus, LED.
        """
        bart_sum = self.generate_summary(self.bart_model, self.bart_tokenizer, judgment_text, model_type="bart")
        peg_sum = self.generate_summary(self.pegasus_model, self.pegasus_tokenizer, judgment_text, model_type="pegasus")
        led_sum = self.generate_summary(self.led_model, self.led_tokenizer, judgment_text, model_type="led")

        summaries = [bart_sum, peg_sum, led_sum]
        sims = self.rerank(judgment_text, summaries)
        best_idx = int(np.argmax(sims))
        best_summary = summaries[best_idx]
        return best_summary, sims

# =========================================================
# 4 Initialize paths & models
# =========================================================
# Checkpoint locations for the three ensemble members.
bart_path, pegasus_path, led_path = (
    "bart_legal_summ_model_final",
    "pegasus_legal_summ_model_final",
    "led_legal_summ_model_final",
)

# Input JSONL file: {"id": "<ID>", "judgment": "<text>"}
test_path = "test_judg.jsonl"

# Output locations: bundled ensemble folder, summaries, per-case similarity report.
save_dir = "ensemble_legal_model"
output_file = "generated_summaries.jsonl"
similarity_report_file = "ensemble_test_similarity_report.jsonl"

# Constructing the ensemble loads all three tokenizer/model pairs.
ensemble = LegalEnsembleSummarizer(
    bart_path=bart_path,
    pegasus_path=pegasus_path,
    led_path=led_path,
)

# =========================================================
# 5 Load Test Data
# =========================================================
# Read every test record and clean its judgment text up front.
# NOTE(review): the keys read here are lower-case ("id", "judgment") while
# the training loader in this file reads "ID"/"Judgment" — confirm against
# the actual test file.
judgments = []
with jsonlines.open(test_path) as reader:
    judgments.extend(
        {"ID": record["id"], "text": clean_judgment_text(record["judgment"])}
        for record in reader
    )

print(f"Loaded {len(judgments)} test cases")

# =========================================================
# 6 Generate Summaries & Compute Cosine Similarities
# =========================================================
similarity_report = []

# Labels in the same order as the scores returned by ensemble.summarize().
model_labels = ("BART", "Pegasus", "LED")

with jsonlines.open(output_file, mode="w") as writer:
    progress = tqdm(judgments, desc="Generating ensemble summaries")
    for idx, j in enumerate(progress):
        case_id = j["ID"]
        # The text is cut to its first 16384 characters before summarising.
        summary, sims = ensemble.summarize(j["text"][:16384])
        chosen_model = model_labels[int(np.argmax(sims))]
        best_sim = float(max(sims))

        # Each summary is written as soon as it is produced.
        writer.write({"ID": case_id, "Summary": summary})

        # Per-case similarity record, kept in memory until the end of the run.
        report_row = {"ID": case_id}
        for label, score in zip(model_labels, sims):
            report_row[f"{label.lower()}_similarity"] = float(score)
        report_row["chosen_model"] = chosen_model
        report_row["best_similarity"] = best_sim
        similarity_report.append(report_row)

        # Print the scores for the first ten cases only.
        if idx < 10:
            print(f"\nID: {case_id}")
            for label, score in zip(model_labels, sims):
                print(f"   {label} sim: {score:.4f}")
            print(f"   Chosen: {chosen_model} (Cosine={best_sim:.4f})")

# =========================================================
# 7 Save Similarity Report
# =========================================================
# Write the collected similarity records, one JSON object per line.
with jsonlines.open(similarity_report_file, mode="w") as writer:
    for report_row in similarity_report:
        writer.write(report_row)

for message in (
    "\n All summaries generated successfully!",
    f"  Saved summaries â†’ {output_file}",
    f"  Saved similarity report â†’ {similarity_report_file}",
):
    print(message)

# =========================================================
# 8 Save Full Ensemble Folder
# =========================================================
def save_full_ensemble(save_dir):
    """Copy the three checkpoint directories into ``save_dir`` and write a metadata file.

    The source directories come from the module-level ``bart_path``,
    ``pegasus_path`` and ``led_path``. Existing destination folders are
    reused rather than treated as an error.

    Args:
        save_dir: Destination folder; created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Sub-folder name inside the bundle -> checkpoint directory to copy from.
    sources = {
        "bart_model": bart_path,
        "pegasus_model": pegasus_path,
        "led_model": led_path,
    }
    for folder, checkpoint_dir in sources.items():
        shutil.copytree(checkpoint_dir, os.path.join(save_dir, folder), dirs_exist_ok=True)

    # The metadata maps each member to its sub-folder and names the reranker.
    metadata = {folder: folder for folder in sources}
    metadata["reranker_model"] = "BART (fine-tuned)"
    metadata["description"] = (
        "Ensemble of fine-tuned BART + legal Pegasus + LED reranked using BART embeddings"
    )

    metadata_file = os.path.join(save_dir, "ensemble_metadata.json")
    with open(metadata_file, "w") as f:
        f.write(json.dumps(metadata, indent=2))
    print(f" Full ensemble saved at: {save_dir}")

# Copy the three checkpoint folders and the metadata file into `save_dir`.
save_full_ensemble(save_dir)
