# In [None]:

import os, re, json, jsonlines, math, time
from tqdm import tqdm
import numpy as np
import torch
from transformers import (
    BartTokenizer, BartForConditionalGeneration,
    PegasusTokenizer, PegasusForConditionalGeneration,
    LEDTokenizer, LEDForConditionalGeneration
)
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity
from nltk import sent_tokenize
import nltk
nltk.download("punkt")

# ---------------------------
# CONFIG
# ---------------------------
# Run on the GPU when PyTorch reports one, otherwise on the CPU. The models
# and tokenized inputs in this file are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Retriever model (recommend a legal SBERT if available). Loaded through
# SentenceTransformer and used to embed both the query and the chunks.
RETRIEVER_MODEL = "law-ai/InLegalBERT"  # replace with legal SBERT for better perf

# Generator checkpoints (use your fine-tuned models / local paths).
# BART is the first-pass generator and LED is used in the expansion rounds;
# the Pegasus checkpoint is loaded but not called by the workflow visible here.
BART_MODEL = "bart_legal_summ_model_final"
PEGASUS_MODEL = "pegasus_legal_summ_model_final"
LED_MODEL = "led_legal_summ_model_final"

# Files
TEST_PATH = "test_judg.jsonl"   # input JSONL with {"id": "...", "judgment": "..."}
# Output rows are {"ID", "Summary"} and {"ID", "Cosine_Similarity"} respectively.
OUT_SUMMARIES = "generated_summaries_rag.jsonl"
OUT_COSINE = "cosine_scores_rag.jsonl"

# Chunking & retrieval params
# The two token settings only apply to token-based chunking; sentence-based
# chunking groups a fixed number of sentences instead.
CHUNK_SIZE_TOKENS = 800       # chunk token length when splitting (approx)
CHUNK_OVERLAP_TOKENS = 200    # tokens shared between consecutive token chunks
TOP_K = 8                     # chunks retrieved initially; also the per-round increment
RETRIEVE_BY = "sentence"      # "sentence" chunking; any other value selects token chunking

# Generation: enforce ~400-500 words
MIN_WORDS = 400
MAX_WORDS = 500
TOK_PER_WORD = 1.5            # rough token estimate per word => tokens = words * TOK_PER_WORD
# Passed to model.generate as min_length / max_length.
GEN_MIN_TOKENS = int(MIN_WORDS * TOK_PER_WORD)   # e.g., 400 * 1.5 = 600
GEN_MAX_TOKENS = int(MAX_WORDS * TOK_PER_WORD)   # e.g., 500 * 1.5 = 750

# Retrieval-to-generation flow
# Upper bound on expansion rounds; a round runs whenever the current summary's
# word count falls outside [MIN_WORDS, MAX_WORDS] (too short or too long).
MAX_RETRIEVE_ROUNDS = 3

# Safety / diversity thresholds
# NOTE(review): not referenced anywhere else in the visible code.
SENTENCE_SIM_DIVERSITY = 0.85

# ---------------------------
# UTILITIES
# ---------------------------
def clean_judgment_text(text):
    """Normalize raw judgment text before chunking.

    Applied in order: page-marker removal, removal of ``Case :- ...`` header
    lines, removal of parenthesised numbers such as ``(12)``, flattening of
    newlines, collapsing of whitespace runs, and tightening of the space
    that precedes a comma or full stop. The stripped result is returned.
    """
    substitutions = (
        (r"\[Page No\.\s*\d+\]", " "),  # "[Page No. 12]" markers
        (r"Case\s*:-.*?\n", " "),       # "Case :- ..." up to the end of its line
        (r"\(\d+\)", ""),               # parenthesised numbers
        (r"\n+", " "),                  # line breaks become single spaces
        (r"\s{2,}", " "),               # collapse repeated whitespace
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    for spaced, tight in ((" ,", ","), (" .", ".")):
        cleaned = cleaned.replace(spaced, tight)
    return cleaned.strip()

# simple splitter by sentences to produce retrieval units
def chunk_into_sentences(text):
    """Split *text* with ``sent_tokenize`` and keep the stripped sentences
    that are longer than 20 characters."""
    kept = []
    for sentence in sent_tokenize(text):
        trimmed = sentence.strip()
        if len(trimmed) > 20:
            kept.append(trimmed)
    return kept

# token-based chunking using tokenizer.encode/decode for better boundaries
def chunk_by_tokens(text, tokenizer, max_tokens=CHUNK_SIZE_TOKENS, overlap=CHUNK_OVERLAP_TOKENS):
    """Split *text* into overlapping windows of at most *max_tokens* tokens.

    The text is encoded without special tokens, sliced into windows that
    advance by ``max_tokens - overlap`` ids, and each window is decoded back
    to a string.

    Args:
        text: The string to split.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            and ``decode(ids, clean_up_tokenization_spaces=True)``.
        max_tokens: Maximum number of token ids per chunk; must be positive.
        overlap: Number of token ids shared by consecutive chunks; must
            satisfy ``0 <= overlap < max_tokens``.

    Returns:
        A list of decoded chunk strings. Text that fits in one window is
        returned as a single-element list.

    Raises:
        ValueError: If ``max_tokens`` or ``overlap`` is out of range. An
            overlap larger than the window used to yield an empty chunk list
            (negative stride), an equal one failed inside ``range()``, and a
            negative one left gaps between chunks; all are reported
            explicitly now.
    """
    if max_tokens <= 0:
        raise ValueError(f"max_tokens must be positive, got {max_tokens}")
    if not 0 <= overlap < max_tokens:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < max_tokens, got overlap={overlap}, max_tokens={max_tokens}"
        )

    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) <= max_tokens:
        return [tokenizer.decode(token_ids, clean_up_tokenization_spaces=True)]

    stride = max_tokens - overlap
    chunks = []
    for start in range(0, len(token_ids), stride):
        window = token_ids[start:start + max_tokens]
        chunks.append(tokenizer.decode(window, clean_up_tokenization_spaces=True))
        # Once a window reaches the end of the text, any further window would
        # only repeat a suffix that is already covered.
        if start + max_tokens >= len(token_ids):
            break
    return chunks

def words_count(text):
    """Count word-like tokens in *text*, i.e. maximal runs of regex word characters."""
    return sum(1 for _ in re.finditer(r"\w+", text))

def truncate_to_word_limit_by_sentences(summary, min_words=MIN_WORDS, max_words=MAX_WORDS):
    """Trim *summary* to whole sentences so its length lands near the word window.

    Word counts here are whitespace-separated tokens. A summary already inside
    ``[min_words, max_words]`` is returned untouched. Otherwise the longest
    leading run of sentences that fits in ``max_words`` is kept; if that run
    is still below ``min_words``, following sentences are appended one at a
    time until the minimum is reached or the text runs out, which may push
    the result past ``max_words``. Kept sentences are joined with single spaces.
    """
    total_words = len(summary.split())
    if min_words <= total_words <= max_words:
        return summary

    sentences = sent_tokenize(summary)
    lengths = [len(sentence.split()) for sentence in sentences]

    taken = 0    # number of leading sentences kept so far
    running = 0  # word count of the kept sentences

    # Phase 1: longest prefix that does not exceed max_words.
    while taken < len(sentences) and running + lengths[taken] <= max_words:
        running += lengths[taken]
        taken += 1

    # Phase 2: still too short -> keep extending, accepting an overshoot.
    while running < min_words and taken < len(sentences):
        running += lengths[taken]
        taken += 1

    return " ".join(sentences[:taken])

# ---------------------------
# LOAD MODELS
# ---------------------------
# Everything below runs at module import time: the retriever and the three
# generator checkpoints are loaded once and kept as module-level globals.
print("Loading retriever and generators (may take a while)...")
# Embedding model used to score chunks against the query in retrieve_top_k.
retriever = SentenceTransformer(RETRIEVER_MODEL, device=device)

# BART: first-pass generator; its encoder is also used by bart_encode_mean.
# .eval() switches each generator to evaluation mode.
bart_tokenizer = BartTokenizer.from_pretrained(BART_MODEL)
bart_model = BartForConditionalGeneration.from_pretrained(BART_MODEL).to(device).eval()

# Pegasus: loaded here, but not called by any function visible in this file.
pegasus_tokenizer = PegasusTokenizer.from_pretrained(PEGASUS_MODEL)
pegasus_model = PegasusForConditionalGeneration.from_pretrained(PEGASUS_MODEL).to(device).eval()

# LED: generator for the expansion rounds; its tokenizer also drives
# token-based chunking in build_chunk_corpus.
led_tokenizer = LEDTokenizer.from_pretrained(LED_MODEL)
led_model = LEDForConditionalGeneration.from_pretrained(LED_MODEL).to(device).eval()

# we use BART encoder mean-pooled embeddings for cosine evaluation
def bart_encode_mean(text, max_length=1024):
    """Embed *text* as the mean of the BART encoder's last hidden states.

    The input is tokenized with truncation to *max_length* tokens, run
    through ``bart_model``'s encoder without gradient tracking, averaged over
    dimension 1 and returned as a NumPy array on the CPU. The average is not
    weighted by the attention mask.
    """
    encoded = bart_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=True,
    )
    encoded = encoded.to(device)
    with torch.no_grad():
        encoder_output = bart_model.model.encoder(**encoded)
    pooled = encoder_output.last_hidden_state.mean(dim=1)
    return pooled.cpu().numpy()

# generation helper
def generate_with(model, tokenizer, input_text, min_tokens=GEN_MIN_TOKENS, max_tokens=GEN_MAX_TOKENS, model_type="bart"):
    """Generate one summary for *input_text* with beam search.

    Args:
        model: Seq2seq model exposing ``generate``.
        tokenizer: Tokenizer matching *model*.
        input_text: Prompt plus context to summarize; truncated to the input
            limit chosen below.
        min_tokens: Passed to ``generate`` as ``min_length``.
        max_tokens: Passed to ``generate`` as ``max_length``.
        model_type: ``"led"`` selects the long-input limit; any other value
            uses the 1024-token limit.

    Returns:
        The decoded text of the first returned sequence, special tokens removed.
    """
    if model_type == "led":
        # The previous expression `LED_MODEL and 16000 or 8192` always produced
        # 16000 because LED_MODEL is a non-empty string. Keep 16000 as the cap,
        # but never exceed what the loaded checkpoint reports it can encode.
        config = getattr(model, "config", None)
        model_limit = getattr(config, "max_encoder_position_embeddings", None)
        max_input_tokens = min(16000, model_limit) if model_limit else 16000
        # NOTE(review): the LED documentation advises a global_attention_mask
        # on the first token for summarization; none is passed here. Confirm
        # against how the checkpoint was fine-tuned before adding one.
    else:
        max_input_tokens = 1024

    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_tokens,
    ).to(device)

    out_ids = model.generate(
        **inputs,
        num_beams=4,
        max_length=max_tokens,
        min_length=min_tokens,
        length_penalty=1.0,
        no_repeat_ngram_size=3,
        early_stopping=True
    )
    return tokenizer.decode(out_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

# ---------------------------
# RAG Workflow per judgment
# ---------------------------
def build_chunk_corpus(text, method="token"):
    """Return the list of chunk strings for one judgment.

    ``method == "sentence"`` joins consecutive groups of 8 filtered sentences
    with spaces; every other value falls back to token-window chunking with
    the LED tokenizer and the configured chunk size and overlap.
    """
    if method != "sentence":
        # token-based chunking using LED tokenizer for stable tokenization
        return chunk_by_tokens(text, led_tokenizer, max_tokens=CHUNK_SIZE_TOKENS, overlap=CHUNK_OVERLAP_TOKENS)

    sentences = chunk_into_sentences(text)
    group_size = 8  # sentences per chunk (tunable)
    return [
        " ".join(sentences[start:start + group_size])
        for start in range(0, len(sentences), group_size)
    ]

def retrieve_top_k(chunks, query_embedding, k=TOP_K):
    """Rank *chunks* by cosine similarity to *query_embedding*.

    All chunks are embedded with the retriever, scored against the query, and
    the best ``min(k, len(chunks))`` are returned in the order ``torch.topk``
    yields them, as ``(index, chunk_text, score)`` tuples.
    """
    chunk_embeddings = retriever.encode(chunks, convert_to_tensor=True, show_progress_bar=False)
    similarities = util.cos_sim(query_embedding, chunk_embeddings)[0]
    limit = min(k, len(chunks))
    _, best_indices = torch.topk(similarities, limit)

    results = []
    for idx in best_indices.cpu().numpy().tolist():
        score = float(similarities[idx].cpu().item())
        results.append((idx, chunks[idx], score))
    return results

def rag_summarize_judgment(judgment_text, rounds=1):
    """Summarize one judgment with retrieve-then-generate and score the result.

    Flow:
      1. Chunk the judgment according to ``RETRIEVE_BY``.
      2. Rank all chunks once against an embedding of the whole judgment.
      3. Generate a first summary with BART from the top ``TOP_K`` chunks.
      4. While the word count is outside ``[MIN_WORDS, MAX_WORDS]`` and rounds
         remain, widen the context by ``TOP_K`` chunks and regenerate with
         LED, falling back to BART if LED raises.
      5. Trim or extend the summary by whole sentences, then compute the
         cosine similarity between BART-encoder embeddings of the judgment
         and the final summary.

    Args:
        judgment_text: Cleaned text of one judgment.
        rounds: Maximum number of expansion rounds after the first generation.

    Returns:
        ``(final_summary, cos_sim)``; ``("", 0.0)`` when chunking yields nothing.
    """
    chunks = build_chunk_corpus(judgment_text, method=RETRIEVE_BY)
    if not chunks:
        return "", 0.0

    # Use the whole judgment as query (or you can craft a query).
    # NOTE(review): the retriever presumably truncates long inputs to its own
    # maximum sequence length, so the query may reflect only the opening of
    # the judgment -- confirm.
    query_emb = retriever.encode(judgment_text, convert_to_tensor=True, show_progress_bar=False)

    # Rank every chunk a single time. The chunk embeddings do not change
    # between rounds, so each round only needs a longer prefix of this ranking
    # instead of re-encoding the whole corpus.
    ranked = retrieve_top_k(chunks, query_emb, k=len(chunks))

    prompt = "Summarize the following legal judgment excerpts into a coherent abstractive summary (400-500 words):\n\n"

    def build_input(k):
        # Generator input: instruction prompt followed by the k best chunks.
        return prompt + " ".join(chunk_text for _, chunk_text, _ in ranked[:k])

    # Primary generation with BART
    current_k = min(TOP_K, len(chunks))
    summary = generate_with(bart_model, bart_tokenizer, build_input(current_k), model_type="bart")
    words = words_count(summary)

    previous_k = None  # context size used by the most recent expansion round
    for _ in range(rounds):
        if MIN_WORDS <= words <= MAX_WORDS:
            break

        current_k = min(len(chunks), current_k + TOP_K)
        if current_k == previous_k:
            # Retrieval is exhausted: the input would be identical to the last
            # round's, and generate_with decodes with beam search and no
            # sampling flag, so another pass is expected to repeat the output.
            break
        previous_k = current_k
        gen_input = build_input(current_k)

        # try LED (better for longer contexts) to produce meta-summary
        try:
            summary = generate_with(led_model, led_tokenizer, gen_input, model_type="led")
        except Exception as exc:
            # Best-effort fallback, but report why LED was skipped.
            print(f"LED generation failed ({exc}); falling back to BART")
            summary = generate_with(bart_model, bart_tokenizer, gen_input, model_type="bart")

        words = words_count(summary)

    # Postprocess: enforce 400-500 words using sentence trimming/extension
    final_summary = truncate_to_word_limit_by_sentences(summary, min_words=MIN_WORDS, max_words=MAX_WORDS)

    # compute cosine sim between judgment and final_summary using BART encoder
    j_emb = bart_encode_mean(judgment_text)
    s_emb = bart_encode_mean(final_summary)
    cos_sim = float(cosine_similarity(j_emb, s_emb)[0][0])

    return final_summary, cos_sim

# ---------------------------
# MAIN
# ---------------------------
def main():
    """Summarize every judgment in ``TEST_PATH`` and write two JSONL outputs.

    Reads all records first (each must have ``id`` and ``judgment``), cleans
    the judgment text, then writes one ``{"ID", "Summary"}`` row to
    ``OUT_SUMMARIES`` and one ``{"ID", "Cosine_Similarity"}`` row to
    ``OUT_COSINE`` per record. A record whose summarization raises is reported
    on stdout and written with an empty summary and a score of 0.0.

    Raises:
        ValueError: If an input record lacks ``id`` or ``judgment``.
    """
    # load inputs
    inputs = []
    with jsonlines.open(TEST_PATH) as reader:
        for obj in reader:
            if "id" not in obj or "judgment" not in obj:
                raise ValueError("Input JSONL must contain 'id' and 'judgment' fields.")
            inputs.append({"id": obj["id"], "text": clean_judgment_text(obj["judgment"])})

    print(f"Loaded {len(inputs)} items")

    # Context managers make sure both output files are flushed and closed
    # even if a write fails part-way through the run.
    with jsonlines.open(OUT_SUMMARIES, mode="w") as writer_sum, \
            jsonlines.open(OUT_COSINE, mode="w") as writer_cos:
        for ex in tqdm(inputs, desc="RAG summarization"):
            jid = ex["id"]
            text = ex["text"]
            try:
                summary, cos_sim = rag_summarize_judgment(text, rounds=MAX_RETRIEVE_ROUNDS)
            except Exception as e:
                # Keep the batch going; the failed item gets placeholder output.
                print(f"Error processing {jid}: {e}")
                summary, cos_sim = "", 0.0

            writer_sum.write({"ID": jid, "Summary": summary})
            writer_cos.write({"ID": jid, "Cosine_Similarity": round(cos_sim, 6)})

    print("Done — outputs:")
    print(" ", OUT_SUMMARIES)
    print(" ", OUT_COSINE)

if __name__ == "__main__":
    main()
