# In [None]:
# =========================================================
# Legal Ensemble Summarizer — BART + LED + InLegalBERT
#  Generates summaries (~400–500 words) for the validation set,
#    reranks them with InLegalBERT, and computes evaluation scores
# =========================================================

import torch, re, os, json, jsonlines, shutil
from tqdm import tqdm
from transformers import (
    BartTokenizer, BartForConditionalGeneration,
    LEDTokenizer, LEDForConditionalGeneration,
    AutoTokenizer, AutoModel
)
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
import numpy as np
from rouge import Rouge
from sacrebleu.metrics import BLEU
import evaluate

# =========================================================
# 1 Setup
# =========================================================
# Prefer the GPU whenever PyTorch can see one; everything below uses this string.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

# =========================================================
# 2 Preprocessing
# =========================================================
def clean_text(text):
    """Remove page markers, case numbers, citations and dates from a judgment.

    Each boilerplate pattern below is replaced by a single space, in order;
    afterwards every run of whitespace collapses to one space and the result
    is stripped at both ends.
    """
    boilerplate_patterns = (
        r'Page\s*\d+\s*of\s*\d+',                                    # "Page 3 of 12"
        r'(Case\s*No\.?|Crl\.A\.No\.?|Appeal\s*No\.?)\s*[\w/-]+',    # case / appeal numbers
        r'\(\d{4}\)\s*\d+\s*[A-Z]+\s*\d+',                           # "(2005) 3 SCC 45"-style citations
        r'AIR\s*\d{4}\s*[A-Z]+\s*\d+',                               # "AIR 1999 SC 123"-style citations
        r'Dated\s*[:\-]?\s*\d{1,2}[-./]\d{1,2}[-./]\d{2,4}',         # "Dated: 12-05-2020"
    )
    cleaned = text
    for pattern in boilerplate_patterns:
        cleaned = re.sub(pattern, ' ', cleaned)
    return re.sub(r'\s+', ' ', cleaned).strip()

# =========================================================
# 3 Ensemble Summarizer Class
# =========================================================
class LegalEnsembleSummarizer:
    """Summarize a judgment with both BART and LED, then keep the better candidate.

    Each candidate summary is embedded with a reranking encoder
    (``law-ai/InLegalBERT`` by default) and compared with the judgment's own
    embedding by cosine similarity; the most similar candidate is returned.
    All models and input tensors are moved to the module-level ``device``.
    """

    def __init__(self, bart_path, led_path, reranker_name="law-ai/InLegalBERT"):
        """Load the two summarizers and the reranking encoder.

        Args:
            bart_path: Path or hub id of the fine-tuned BART checkpoint.
            led_path: Path or hub id of the fine-tuned LED checkpoint.
            reranker_name: Path or hub id of the encoder used for reranking.
        """
        # Load BART model
        self.bart_tokenizer = BartTokenizer.from_pretrained(bart_path)
        self.bart_model = BartForConditionalGeneration.from_pretrained(bart_path).to(device)

        # Load LED model
        self.led_tokenizer = LEDTokenizer.from_pretrained(led_path)
        self.led_model = LEDForConditionalGeneration.from_pretrained(led_path).to(device)

        # Load InLegalBERT (or the given encoder) for reranking
        self.rerank_tokenizer = AutoTokenizer.from_pretrained(reranker_name)
        self.rerank_model = AutoModel.from_pretrained(reranker_name).to(device)

    def generate_summary(self, model, tokenizer, text, max_words=500, min_words=400, num_beams=4):
        """Generate one beam-search summary of ``text`` with ``model``.

        Args:
            model: Seq2seq model. ``model.config.model_type`` decides the input
                budget and whether LED global attention is added. The budget is
                4096 tokens for LED, and 1024 for BART or any other type.
            tokenizer: Tokenizer matching ``model``.
            text: Source document.
            max_words: Upper target length in words.
            min_words: Lower target length in words.
                Both word targets are converted to token limits with a rough
                heuristic of 1.5 tokens per word.
            num_beams: Beam width passed to ``model.generate`` (default 4).

        Returns:
            The decoded summary string with special tokens removed.
        """
        model_type = model.config.model_type.lower()
        is_led = "led" in model_type

        # BART (and any unrecognised model) gets 1024 input tokens; LED gets 4096.
        if "bart" in model_type or not is_led:
            max_input_len = 1024
        else:
            max_input_len = 4096

        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_input_len).to(device)

        gen_kwargs = {
            "max_length": int(max_words * 1.5),
            "min_length": int(min_words * 1.5),
            "num_beams": num_beams,
            "early_stopping": True,
        }
        if is_led:
            # LED: mark only the first token for global attention (mask value 1).
            global_attention_mask = torch.zeros_like(inputs["input_ids"])
            global_attention_mask[:, 0] = 1
            gen_kwargs["global_attention_mask"] = global_attention_mask

        summary_ids = model.generate(**inputs, **gen_kwargs)
        return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def get_embedding(self, text):
        """Embed ``text`` with the reranker as a numpy array.

        The text is truncated to 512 tokens. The embedding is the mean of the
        last hidden state over the token axis, i.e. shape (1, hidden_size)
        for a single input string.
        """
        inputs = self.rerank_tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
        with torch.no_grad():
            outputs = self.rerank_model(**inputs)
        return outputs.last_hidden_state.mean(dim=1).cpu().numpy()

    def rerank(self, judgment, summaries):
        """Return the cosine similarity of each summary to ``judgment``.

        Similarities are computed between reranker embeddings and returned in
        the same order as ``summaries``.
        """
        j_emb = self.get_embedding(judgment)
        sims = [cosine_similarity(j_emb, self.get_embedding(s))[0][0] for s in summaries]
        return sims

    def summarize(self, judgment_text):
        """Summarize with both models and return ``(best_summary, [bart_sim, led_sim])``.

        The best candidate is chosen with ``np.argmax``, so an exact tie goes
        to BART (index 0).
        """
        bart_sum = self.generate_summary(self.bart_model, self.bart_tokenizer, judgment_text)
        led_sum = self.generate_summary(self.led_model, self.led_tokenizer, judgment_text)
        candidates = [bart_sum, led_sum]
        sims = self.rerank(judgment_text, candidates)
        best_idx = int(np.argmax(sims))
        return candidates[best_idx], sims

# =========================================================
# 4 Load Training Dataset and Split 90/10
# =========================================================
train_judg_file = "train_judg.jsonl"    # {"ID": "...", "Judgment": "..."}
train_summ_file = "train_ref_summ.jsonl"  # {"ID": "...", "Summary": "..."}

# Pair each judgment with its reference summary by "ID" rather than by file
# position: a positional zip silently mispairs texts and summaries if the two
# files are ordered differently, and silently truncates if their lengths differ.
with jsonlines.open(train_summ_file) as f_s:
    summaries_by_id = {s_obj["ID"]: s_obj["Summary"] for s_obj in f_s}

train_judgs = []
missing_summary_ids = []
with jsonlines.open(train_judg_file) as f_j:
    for j_obj in f_j:
        if j_obj["ID"] not in summaries_by_id:
            missing_summary_ids.append(j_obj["ID"])
            continue
        train_judgs.append({
            "ID": j_obj["ID"],
            "text": clean_text(j_obj["Judgment"]),
            "summary": summaries_by_id[j_obj["ID"]],
        })

if missing_summary_ids:
    print(f" Warning: skipped {len(missing_summary_ids)} judgments with no reference summary")

train_list, val_list = train_test_split(train_judgs, test_size=0.1, random_state=42)
print(f" Dataset split: {len(train_list)} training, {len(val_list)} validation")

# =========================================================
# 5 Initialize Ensemble
# =========================================================
# Local checkpoint directories for the two fine-tuned summarizers.
bart_path, led_path = (
    "bart_legal_summ_model_final",
    "led_legal_summ_model_final",
)

# Loads both summarizers plus the InLegalBERT reranker (the class default).
ensemble = LegalEnsembleSummarizer(bart_path=bart_path, led_path=led_path)

# =========================================================
# 6 Generate Summaries for Validation Set
# =========================================================
generated_summaries = []

print("Generating summaries for validation set...")
for example in tqdm(val_list):
    # Character-level cap, applied before either model or the reranker sees the text.
    text = example["text"][:8000]  # limit to 8000 characters for safety
    summary, sims = ensemble.summarize(text)
    # summarize() selects with np.argmax, where BART (index 0) wins exact ties.
    # Derive the label the same way so it always names the summary returned.
    chosen_model = ("BART", "LED")[int(np.argmax(sims))]
    generated_summaries.append({
        "ID": example["ID"],
        "generated_summary": summary,
        "reference_summary": example["summary"],
        "bart_similarity": float(sims[0]),
        "led_similarity": float(sims[1]),
        "chosen_model": chosen_model,
    })

os.makedirs("outputs", exist_ok=True)
val_output_file = "outputs/val_generated_summaries.jsonl"
with jsonlines.open(val_output_file, mode="w") as writer:
    writer.write_all(generated_summaries)

print(f" Validation summaries saved to {val_output_file}")

# =========================================================
# 7 Evaluation Metrics: ROUGE, BLEU, BERTScore
# =========================================================
rouge_scorer = Rouge()
bleu_scorer = BLEU()
bertscore_eval = evaluate.load("bertscore")


def _safe_metric(name, item_id, compute, fallback):
    """Return ``compute()``; if it raises, print a warning and return ``fallback``.

    Metric libraries can raise on degenerate pairs. A failure is scored with
    ``fallback`` instead of aborting the whole evaluation. The fallback still
    enters the averages and pulls them down, so the warning keeps that visible.
    Only ``Exception`` subclasses are caught, so Ctrl-C still interrupts.
    """
    try:
        return compute()
    except Exception as err:
        print(f" Warning: {name} failed for ID {item_id} ({type(err).__name__}: {err}); scored as {fallback}")
        return fallback


def _rouge_f_percent(hyp, ref):
    """Return ROUGE-1, ROUGE-2 and ROUGE-L F-scores for one pair, as percentages."""
    r = rouge_scorer.get_scores(hyp, ref)[0]
    return r["rouge-1"]["f"] * 100, r["rouge-2"]["f"] * 100, r["rouge-l"]["f"] * 100


rouge1_scores, rouge2_scores, rougel_scores, bleu_scores, bert_scores = [], [], [], [], []

for item in generated_summaries:
    hyp = item["generated_summary"]
    ref = item["reference_summary"]
    item_id = item["ID"]

    rouge1, rouge2, rougel = _safe_metric("ROUGE", item_id, lambda: _rouge_f_percent(hyp, ref), (0, 0, 0))
    bleu = _safe_metric("BLEU", item_id, lambda: bleu_scorer.sentence_score(hyp, [ref]).score, 0)
    # Scored one pair at a time so a failure only zeroes that pair.
    bert = _safe_metric(
        "BERTScore",
        item_id,
        lambda: bertscore_eval.compute(predictions=[hyp], references=[ref], lang="en")["f1"][0] * 100,
        0,
    )

    rouge1_scores.append(rouge1)
    rouge2_scores.append(rouge2)
    rougel_scores.append(rougel)
    bleu_scores.append(bleu)
    bert_scores.append(bert)

mean_rouge2 = float(np.mean(rouge2_scores))
mean_rougel = float(np.mean(rougel_scores))
mean_bleu = float(np.mean(bleu_scores))

metrics = {
    "ROUGE-1 (%)": float(np.mean(rouge1_scores)),
    "ROUGE-2 (%)": mean_rouge2,
    "ROUGE-L (%)": mean_rougel,
    "BLEU (%)": mean_bleu,
    "BERTScore-F1 (%)": float(np.mean(bert_scores)),
    # Headline score: mean of ROUGE-2, ROUGE-L and BLEU only
    # (ROUGE-1 and BERTScore are reported but not included).
    "AVG_SCORE (%)": float(np.mean([mean_rouge2, mean_rougel, mean_bleu])),
}

metrics_path = "outputs/val_metrics.json"
with open(metrics_path, "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=4)

print("\n========== Validation Metrics ==========")
for k, v in metrics.items():
    print(f"{k}: {v:.2f}")
print("==========================================")

# =========================================================
# 8 Save the Full Ensemble Folder
# =========================================================
save_dir = "ensemble_legal_model"


def save_full_ensemble(save_dir, bart_src=None, led_src=None, reranker_name="law-ai/InLegalBERT"):
    """Bundle both fine-tuned checkpoints and a metadata file into one folder.

    Resulting layout: ``<save_dir>/bart_model``, ``<save_dir>/led_model`` and
    ``<save_dir>/ensemble_metadata.json``. The reranker is only recorded by
    name in the metadata; it is not copied.

    Args:
        save_dir: Destination directory. It is created if missing, and files
            already there with the same names are overwritten.
        bart_src: BART checkpoint directory. Defaults to the module-level
            ``bart_path``.
        led_src: LED checkpoint directory. Defaults to the module-level
            ``led_path``.
        reranker_name: Reranker identifier written to the metadata.
    """
    bart_src = bart_path if bart_src is None else bart_src
    led_src = led_path if led_src is None else led_src

    os.makedirs(save_dir, exist_ok=True)
    shutil.copytree(bart_src, os.path.join(save_dir, "bart_model"), dirs_exist_ok=True)
    shutil.copytree(led_src, os.path.join(save_dir, "led_model"), dirs_exist_ok=True)

    metadata = {
        "bart_model": "bart_model",
        "led_model": "led_model",
        "reranker_model": reranker_name,
        "description": "Ensemble of fine-tuned BART + LED reranked with InLegalBERT for legal summarization"
    }
    with open(os.path.join(save_dir, "ensemble_metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    print(f" Full ensemble saved at: {save_dir}")


save_full_ensemble(save_dir)
