In [None]:
#!/usr/bin/env python3
"""
rag_legal_summarizer.py

Retriever-Augmented Generation pipeline for legal summarization.

Outputs:
 - generated_summaries_rag.jsonl  (ID, Summary)
 - cosine_scores_rag.jsonl        (ID, Cosine_Similarity)

Configurable parameters at top of file.
"""

import os, re, json, jsonlines, math, time
from tqdm import tqdm
import numpy as np
import torch
from transformers import (
    BartTokenizer, BartForConditionalGeneration,
    PegasusTokenizer, PegasusForConditionalGeneration,
    LEDTokenizer, LEDForConditionalGeneration
)
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity
from nltk import sent_tokenize
import nltk
nltk.download("punkt")

# ---------------------------
# CONFIG
# ---------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Retriever model (recommend a legal SBERT if available)
RETRIEVER_MODEL = "law-ai/InLegalBERT"  # replace with legal SBERT for better perf

# Generator checkpoints (use your fine-tuned models / local paths)
BART_MODEL = "bart_legal_summ_model_final"
PEGASUS_MODEL = "pegasus_legal_summ_model_final"
LED_MODEL = "led_legal_summ_model_final"

# Files
TEST_PATH = "test_judg.jsonl"   # input JSONL with {"id": "...", "judgment": "..."}
OUT_SUMMARIES = "generated_summaries_rag.jsonl"
OUT_COSINE = "cosine_scores_rag.jsonl"

# Chunking & retrieval params
CHUNK_SIZE_TOKENS = 800       # chunk token length when splitting (approx)
CHUNK_OVERLAP_TOKENS = 200
TOP_K = 8                     # number of chunks to retrieve per judgment (initial)
RETRIEVE_BY = "sentence"      # "sentence" or "token" chunking

# Generation: enforce ~400-500 words
MIN_WORDS = 400
MAX_WORDS = 500
TOK_PER_WORD = 1.5            # rough token estimate per word => tokens = words * TOK_PER_WORD
GEN_MIN_TOKENS = int(MIN_WORDS * TOK_PER_WORD)   # e.g., 400 * 1.5 = 600
GEN_MAX_TOKENS = int(MAX_WORDS * TOK_PER_WORD)   # e.g., 500 * 1.5 = 750

# Retrieval-to-generation flow
MAX_RETRIEVE_ROUNDS = 3       # if initial generation is too short, you can expand retrieval rounds

# Safety / diversity thresholds
SENTENCE_SIM_DIVERSITY = 0.85

# ---------------------------
# UTILITIES
# ---------------------------
def clean_judgment_text(text):
    text = re.sub(r"\[Page No\.\s*\d+\]", " ", text)
    text = re.sub(r"Case\s*:-.*?\n", " ", text)
    text = re.sub(r"\(\d+\)", "", text)
    text = re.sub(r"\n+", " ", text)
    text = re.sub(r"\s{2,}", " ", text)
    text = text.replace(" ,", ",").replace(" .", ".").strip()
    return text

# simple splitter by sentences to produce retrieval units
def chunk_into_sentences(text):
    sents = sent_tokenize(text)
    return [s.strip() for s in sents if len(s.strip())>20]

# token-based chunking using tokenizer.encode/decode for better boundaries
def chunk_by_tokens(text, tokenizer, max_tokens=CHUNK_SIZE_TOKENS, overlap=CHUNK_OVERLAP_TOKENS):
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= max_tokens:
        return [tokenizer.decode(tokens, clean_up_tokenization_spaces=True)]
    chunks = []
    step = max_tokens - overlap
    for i in range(0, len(tokens), step):
        chunk_tokens = tokens[i:i+max_tokens]
        if not chunk_tokens:
            break
        chunks.append(tokenizer.decode(chunk_tokens, clean_up_tokenization_spaces=True))
        if i + max_tokens >= len(tokens):
            break
    return chunks

def words_count(text):
    return len(re.findall(r"\w+", text))

def truncate_to_word_limit_by_sentences(summary, min_words=MIN_WORDS, max_words=MAX_WORDS):
    words = summary.split()
    if len(words) <= max_words and len(words) >= min_words:
        return summary
    sents = sent_tokenize(summary)
    out = []
    w = 0
    for s in sents:
        sw = len(s.split())
        if w + sw > max_words:
            break
        out.append(s)
        w += sw
    # if result shorter than min_words, append first sentences until min reached (from original)
    if w < min_words:
        # append more sentences (allow some exceeding max)
        for s in sents[len(out):]:
            out.append(s)
            w += len(s.split())
            if w >= min_words:
                break
    return " ".join(out)

# ---------------------------
# LOAD MODELS
# ---------------------------
print("Loading retriever and generators (may take a while)...")
retriever = SentenceTransformer(RETRIEVER_MODEL, device=device)

bart_tokenizer = BartTokenizer.from_pretrained(BART_MODEL)
bart_model = BartForConditionalGeneration.from_pretrained(BART_MODEL).to(device).eval()

pegasus_tokenizer = PegasusTokenizer.from_pretrained(PEGASUS_MODEL)
pegasus_model = PegasusForConditionalGeneration.from_pretrained(PEGASUS_MODEL).to(device).eval()

led_tokenizer = LEDTokenizer.from_pretrained(LED_MODEL)
led_model = LEDForConditionalGeneration.from_pretrained(LED_MODEL).to(device).eval()

# we use BART encoder mean-pooled embeddings for cosine evaluation
def bart_encode_mean(text, max_length=1024):
    inputs = bart_tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length, padding=True).to(device)
    with torch.no_grad():
        enc = bart_model.model.encoder(**inputs)
    return enc.last_hidden_state.mean(dim=1).cpu().numpy()

# generation helper
def generate_with(model, tokenizer, input_text, min_tokens=GEN_MIN_TOKENS, max_tokens=GEN_MAX_TOKENS, model_type="bart"):
    if model_type == "led":
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=LED_MODEL and 16000 or 8192).to(device)
    else:
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024).to(device)

    out_ids = model.generate(
        **inputs,
        num_beams=4,
        max_length=max_tokens,
        min_length=min_tokens,
        length_penalty=1.0,
        no_repeat_ngram_size=3,
        early_stopping=True
    )
    return tokenizer.decode(out_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

# ---------------------------
# RAG Workflow per judgment
# ---------------------------
def build_chunk_corpus(text, method="token"):
    """Return list of chunks (strings) for this judgment"""
    if method == "sentence":
        # chunk by grouping N sentences per chunk to reach approx chunk size
        sents = chunk_into_sentences(text)
        chunks = []
        N = 8  # start grouping ~8 sentences (tunable)
        for i in range(0, len(sents), N):
            chunks.append(" ".join(sents[i:i+N]))
        return chunks
    else:
        # token-based chunking using LED tokenizer for stable tokenization
        return chunk_by_tokens(text, led_tokenizer, max_tokens=CHUNK_SIZE_TOKENS, overlap=CHUNK_OVERLAP_TOKENS)

def retrieve_top_k(chunks, query_embedding, k=TOP_K):
    # compute embeddings for chunks and cosine similarity
    chunk_embs = retriever.encode(chunks, convert_to_tensor=True, show_progress_bar=False)
    scores = util.cos_sim(query_embedding, chunk_embs)[0]  # tensor
    topk = torch.topk(scores, min(k, len(chunks)))[1].cpu().numpy().tolist()
    # return (idx, chunk_text, score)
    return [(i, chunks[i], float(scores[i].cpu().item())) for i in topk]

def rag_summarize_judgment(judgment_text, rounds=1):
    """Main RAG summarization for single judgment. rounds: allow expanding retrieval if needed."""
    # Prepare chunks and their embeddings
    chunks = build_chunk_corpus(judgment_text, method=RETRIEVE_BY)
    if len(chunks) == 0:
        return "", 0.0

    # Use the whole judgment as query (or you can craft a query)
    query_emb = retriever.encode(judgment_text, convert_to_tensor=True, show_progress_bar=False)

    # initial retrieval
    current_k = TOP_K
    retrieved = retrieve_top_k(chunks, query_emb, k=current_k)
    retrieved_text = " ".join([r[1] for r in retrieved])

    # create generator input (prepend instruction prompt)
    prompt = "Summarize the following legal judgment excerpts into a coherent abstractive summary (400-500 words):\n\n"
    gen_input = prompt + retrieved_text

    # Primary generation with BART
    summary = generate_with(bart_model, bart_tokenizer, gen_input, model_type="bart")

    # If length constraints not satisfied, try expand retrieval and/or use LED to meta-summarize
    words = words_count(summary)
    round_idx = 1
    while (words < MIN_WORDS or words > MAX_WORDS) and round_idx <= rounds:
        # expand retrieval: increase k
        current_k = min(len(chunks), current_k + TOP_K)
        retrieved = retrieve_top_k(chunks, query_emb, k=current_k)
        retrieved_text = " ".join([r[1] for r in retrieved])
        gen_input = prompt + retrieved_text

        # try LED (better for longer contexts) to produce meta-summary
        try:
            summary = generate_with(led_model, led_tokenizer, gen_input, model_type="led")
        except Exception:
            # fallback: re-generate with BART but with new context
            summary = generate_with(bart_model, bart_tokenizer, gen_input, model_type="bart")

        words = words_count(summary)
        round_idx += 1

    # Postprocess: enforce 400-500 words using sentence trimming/extension
    final_summary = truncate_to_word_limit_by_sentences(summary, min_words=MIN_WORDS, max_words=MAX_WORDS)

    # compute cosine sim between judgment and final_summary using BART encoder
    j_emb = bart_encode_mean(judgment_text)
    s_emb = bart_encode_mean(final_summary)
    cos_sim = float(cosine_similarity(j_emb, s_emb)[0][0])

    return final_summary, cos_sim

# ---------------------------
# MAIN
# ---------------------------
def main():
    # load inputs
    inputs = []
    with jsonlines.open(TEST_PATH) as reader:
        for obj in reader:
            if "id" not in obj or "judgment" not in obj:
                raise ValueError("Input JSONL must contain 'id' and 'judgment' fields.")
            inputs.append({"id": obj["id"], "text": clean_judgment_text(obj["judgment"])})

    print(f"Loaded {len(inputs)} items")

    writer_sum = jsonlines.open(OUT_SUMMARIES, mode="w")
    writer_cos = jsonlines.open(OUT_COSINE, mode="w")

    for ex in tqdm(inputs, desc="RAG summarization"):
        jid = ex["id"]
        text = ex["text"]
        try:
            summary, cos_sim = rag_summarize_judgment(text, rounds=MAX_RETRIEVE_ROUNDS)
        except Exception as e:
            print(f"Error processing {jid}: {e}")
            summary, cos_sim = "", 0.0

        writer_sum.write({"ID": jid, "Summary": summary})
        writer_cos.write({"ID": jid, "Cosine_Similarity": round(cos_sim, 6)})

        # small debug print for first few
        # if needed, print sample preview
        # if idx < 3:
        #     print(jid, "->", summary[:200], "...")
    writer_sum.close()
    writer_cos.close()
    print("Done — outputs:")
    print(" ", OUT_SUMMARIES)
    print(" ", OUT_COSINE)

if __name__ == "__main__":
    main()


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\pavit\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Device: cuda
Loading retriever and generators (may take a while)...


No sentence-transformers model found with name law-ai/InLegalBERT. Creating a new one with mean pooling.


Loaded 400 items


RAG summarization:   0%|▏                                                            | 1/400 [01:05<7:17:18, 65.76s/it]

In [1]:
# Files
TRAIN_PATH = "train_judg.jsonl"   # input JSONL with {"id": "...", "judgment": "...", "summary": "..."}  <-- adjust field name if needed
TEST_PATH = "test_judg.jsonl"    # input JSONL with {"id": "...", "judgment": "..."}

OUT_SUMMARIES = "generated_summaries_rag.jsonl"
OUT_COSINE = "cosine_scores_rag.jsonl"

# Qualitative analysis
NUM_QUAL_SAMPLES = 10  # how many train samples to inspect


In [2]:
#!/usr/bin/env python3
"""
rag_legal_summarizer.py

Retriever-Augmented Generation pipeline for legal summarization.

This variant:
 - Loads models (retriever + BART/PEGASUS/LED generators)
 - Uses robust loader for training JSONL (tolerant to common field-name variants)
 - Reads the first training judgment + matching reference summary (if present)
 - Generates a RAG summary for that sample and prints:
     * cleaned judgment (truncated for safety)
     * reference/golden summary (if found)
     * generated RAG summary
 - Writes no outputs by default (keeps behaviour safe for debugging)

Adjust CONFIG paths at top before running.
"""

import os
import re
import json
import jsonlines
import math
import time
from tqdm import tqdm
import numpy as np
import torch
from transformers import (
    BartTokenizer, BartForConditionalGeneration,
    PegasusTokenizer, PegasusForConditionalGeneration,
    LEDTokenizer, LEDForConditionalGeneration
)
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity
from nltk import sent_tokenize
import nltk

nltk.download("punkt")

# ---------------------------
# CONFIG
# ---------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Retriever model (legal SBERT recommended)
RETRIEVER_MODEL = "law-ai/InLegalBERT"  # change if needed

# Generator checkpoints (use your fine-tuned models / local paths)
BART_MODEL = "bart_legal_summ_model_final"
PEGASUS_MODEL = "pegasus_legal_summ_model_final"
LED_MODEL = "led_legal_summ_model_final"

# Files (update paths as required)
TRAIN_JUDG_PATH = "train_judg.jsonl"         # expects one JSON object per line
TRAIN_REF_SUMM_PATH = "train_ref_summ.jsonl"  # expects one JSON object per line

# Chunking & retrieval params
CHUNK_SIZE_TOKENS = 800       # chunk token length when splitting (approx)
CHUNK_OVERLAP_TOKENS = 200
TOP_K = 8                     # number of chunks to retrieve per judgment (initial)
RETRIEVE_BY = "sentence"      # "sentence" or "token" chunking

# Generation: enforce ~400-500 words
MIN_WORDS = 400
MAX_WORDS = 500
TOK_PER_WORD = 1.5            # rough token estimate per word => tokens = words * TOK_PER_WORD
GEN_MIN_TOKENS = int(MIN_WORDS * TOK_PER_WORD)   # e.g., 400 * 1.5 = 600
GEN_MAX_TOKENS = int(MAX_WORDS * TOK_PER_WORD)   # e.g., 500 * 1.5 = 750

# Retrieval-to-generation flow
MAX_RETRIEVE_ROUNDS = 3       # expand retrieval if initial generation is too short

# Safety / diversity thresholds (unused placeholder for now)
SENTENCE_SIM_DIVERSITY = 0.85

# ---------------------------
# UTILITIES
# ---------------------------
def clean_judgment_text(text: str) -> str:
    """Basic cleaning for legal judgments to remove page markers, case headers, extra spaces."""
    if not isinstance(text, str):
        return ""
    text = re.sub(r"\[Page No\.\s*\d+\]", " ", text)
    text = re.sub(r"Case\s*:-.*?\n", " ", text)
    text = re.sub(r"\(\d+\)", "", text)
    text = re.sub(r"\n+", " ", text)
    text = re.sub(r"\s{2,}", " ", text)
    text = text.replace(" ,", ",").replace(" .", ".").strip()
    return text

def chunk_into_sentences(text: str):
    sents = sent_tokenize(text)
    return [s.strip() for s in sents if len(s.strip()) > 20]

def chunk_by_tokens(text: str, tokenizer, max_tokens=CHUNK_SIZE_TOKENS, overlap=CHUNK_OVERLAP_TOKENS):
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= max_tokens:
        return [tokenizer.decode(tokens, clean_up_tokenization_spaces=True)]
    chunks = []
    step = max_tokens - overlap
    for i in range(0, len(tokens), step):
        chunk_tokens = tokens[i:i + max_tokens]
        if not chunk_tokens:
            break
        chunks.append(tokenizer.decode(chunk_tokens, clean_up_tokenization_spaces=True))
        if i + max_tokens >= len(tokens):
            break
    return chunks

def words_count(text: str) -> int:
    return len(re.findall(r"\w+", text))

def truncate_to_word_limit_by_sentences(summary: str, min_words=MIN_WORDS, max_words=MAX_WORDS) -> str:
    words = summary.split()
    if len(words) <= max_words and len(words) >= min_words:
        return summary
    sents = sent_tokenize(summary)
    out = []
    w = 0
    for s in sents:
        sw = len(s.split())
        if w + sw > max_words:
            break
        out.append(s)
        w += sw
    if w < min_words:
        for s in sents[len(out):]:
            out.append(s)
            w += len(s.split())
            if w >= min_words:
                break
    return " ".join(out)

# ---------------------------
# LOAD MODELS
# ---------------------------
print("Loading retriever and generators (may take a while)...")
retriever = SentenceTransformer(RETRIEVER_MODEL, device=device)

# Generators tokenizers & models
bart_tokenizer = BartTokenizer.from_pretrained(BART_MODEL)
bart_model = BartForConditionalGeneration.from_pretrained(BART_MODEL).to(device).eval()

pegasus_tokenizer = PegasusTokenizer.from_pretrained(PEGASUS_MODEL)
pegasus_model = PegasusForConditionalGeneration.from_pretrained(PEGASUS_MODEL).to(device).eval()

led_tokenizer = LEDTokenizer.from_pretrained(LED_MODEL)
led_model = LEDForConditionalGeneration.from_pretrained(LED_MODEL).to(device).eval()

# BART mean-pooled encoder embedding (for cosine scoring)
def bart_encode_mean(text: str, max_length=1024):
    inputs = bart_tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length, padding=True).to(device)
    with torch.no_grad():
        enc = bart_model.model.encoder(**inputs)
    return enc.last_hidden_state.mean(dim=1).cpu().numpy()

# generation helper
def generate_with(model, tokenizer, input_text: str, min_tokens=GEN_MIN_TOKENS, max_tokens=GEN_MAX_TOKENS, model_type="bart"):
    if model_type == "led":
        # LED supports long sequences; set a large max_length for tokenization
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=16000).to(device)
    else:
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024).to(device)

    out_ids = model.generate(
        **inputs,
        num_beams=4,
        max_length=max_tokens,
        min_length=min_tokens,
        length_penalty=1.0,
        no_repeat_ngram_size=3,
        early_stopping=True
    )
    return tokenizer.decode(out_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

# ---------------------------
# RAG WORKFLOW PER JUDGMENT
# ---------------------------
def build_chunk_corpus(text: str, method="token"):
    """Return list of chunks (strings) for this judgment."""
    if method == "sentence":
        sents = chunk_into_sentences(text)
        chunks = []
        N = 8
        for i in range(0, len(sents), N):
            chunks.append(" ".join(sents[i:i + N]))
        return chunks
    else:
        return chunk_by_tokens(text, led_tokenizer, max_tokens=CHUNK_SIZE_TOKENS, overlap=CHUNK_OVERLAP_TOKENS)

def retrieve_top_k(chunks, query_embedding, k=TOP_K):
    chunk_embs = retriever.encode(chunks, convert_to_tensor=True, show_progress_bar=False)
    scores = util.cos_sim(query_embedding, chunk_embs)[0]
    topk = torch.topk(scores, min(k, len(chunks)))[1].cpu().numpy().tolist()
    return [(i, chunks[i], float(scores[i].cpu().item())) for i in topk]

def rag_summarize_judgment(judgment_text: str, rounds=1):
    chunks = build_chunk_corpus(judgment_text, method=RETRIEVE_BY)
    if len(chunks) == 0:
        return "", 0.0

    query_emb = retriever.encode(judgment_text, convert_to_tensor=True, show_progress_bar=False)

    current_k = TOP_K
    retrieved = retrieve_top_k(chunks, query_emb, k=current_k)
    retrieved_text = " ".join([r[1] for r in retrieved])

    prompt = "Summarize the following legal judgment excerpts into a coherent abstractive summary (400-500 words):\n\n"
    gen_input = prompt + retrieved_text

    # Primary generation with BART
    summary = generate_with(bart_model, bart_tokenizer, gen_input, model_type="bart")

    # Iterative retrieval-expansion if summary too short/long
    words = words_count(summary)
    round_idx = 1
    while (words < MIN_WORDS or words > MAX_WORDS) and round_idx <= rounds:
        current_k = min(len(chunks), current_k + TOP_K)
        retrieved = retrieve_top_k(chunks, query_emb, k=current_k)
        retrieved_text = " ".join([r[1] for r in retrieved])
        gen_input = prompt + retrieved_text

        try:
            summary = generate_with(led_model, led_tokenizer, gen_input, model_type="led")
        except Exception:
            summary = generate_with(bart_model, bart_tokenizer, gen_input, model_type="bart")

        words = words_count(summary)
        round_idx += 1

    final_summary = truncate_to_word_limit_by_sentences(summary, min_words=MIN_WORDS, max_words=MAX_WORDS)

    # Cosine similarity between judgment and final_summary (using BART encoder mean-pooled)
    j_emb = bart_encode_mean(judgment_text)
    s_emb = bart_encode_mean(final_summary)
    cos_sim = float(cosine_similarity(j_emb, s_emb)[0][0])

    return final_summary, cos_sim

# ---------------------------
# Robust JSONL loader utilities for common key variants
# ---------------------------
COMMON_ID_KEYS = ["id", "ID", "Id", "doc_id", "jid", "case_id"]
COMMON_JUDG_KEYS = ["judgment", "Judgment", "Judgement", "text", "case", "JudgmentText", "judgement"]
COMMON_SUMM_KEYS = ["summary", "Summary", "ref_summary", "gold", "gold_summary", "ref", "reference", "abstract"]

def get_field_any(record: dict, candidates):
    for k in candidates:
        if k in record:
            return record[k]
    # case-insensitive fallback
    low_map = {rk.lower(): rv for rk, rv in record.items() if isinstance(rk, str)}
    for c in candidates:
        if c.lower() in low_map:
            return low_map[c.lower()]
    return None

def find_first_training_record(path):
    """Return tuple (raw_obj, id_value, judgment_text) or (None, None, None)."""
    # Print first few raw lines to help debugging format
    print(f"[INFO] Showing up to first 3 raw lines of {path} for debugging:")
    try:
        with open(path, "r", encoding="utf-8") as fh:
            for _ in range(3):
                line = fh.readline()
                if not line:
                    break
                print(" RAW:", line.strip())
    except Exception as e:
        print(f"[WARN] Could not read raw preview of {path}: {e}")

    with jsonlines.open(path) as reader:
        for obj in reader:
            id_val = get_field_any(obj, COMMON_ID_KEYS)
            jud_val = get_field_any(obj, COMMON_JUDG_KEYS)
            if jud_val and isinstance(jud_val, str) and len(jud_val.strip()) > 20:
                return obj, id_val, jud_val
    return None, None, None

def find_ref_summary_for_id(path, match_id):
    """Search reference summaries file for the matching id using common keys."""
    with jsonlines.open(path) as reader:
        for obj in reader:
            id_val = get_field_any(obj, COMMON_ID_KEYS)
            if id_val is None:
                # try any string value that equals match_id
                for v in obj.values():
                    if isinstance(v, str) and str(v).strip() == str(match_id).strip():
                        summ = get_field_any(obj, COMMON_SUMM_KEYS)
                        if summ:
                            return summ
                continue
            if str(id_val).strip() == str(match_id).strip():
                summ = get_field_any(obj, COMMON_SUMM_KEYS)
                return summ
    return None

# ---------------------------
# MAIN for qualitative example (first training doc)
# ---------------------------
def main():
    if not os.path.exists(TRAIN_JUDG_PATH):
        print(f"[ERROR] Training judgments file not found: {TRAIN_JUDG_PATH}")
        return
    if not os.path.exists(TRAIN_REF_SUMM_PATH):
        print(f"[ERROR] Training reference summaries file not found: {TRAIN_REF_SUMM_PATH}")
        return

    raw_obj, train_id, raw_judgment = find_first_training_record(TRAIN_JUDG_PATH)
    if raw_obj is None:
        print("[ERROR] No valid training judgment found in", TRAIN_JUDG_PATH)
        return

    print(f"[INFO] Detected training record id: {train_id}")
    cleaned_judgment = clean_judgment_text(raw_judgment)

    ref_summary = None
    if train_id is not None:
        ref_summary = find_ref_summary_for_id(TRAIN_REF_SUMM_PATH, train_id)
    if ref_summary is None:
        print(f"[WARN] No matching reference summary found in {TRAIN_REF_SUMM_PATH} for ID={train_id}")
        ref_summary = "[NO GOLD / REF SUMMARY FOUND]"

    # Generate RAG summary
    gen_summary, cos_sim = rag_summarize_judgment(cleaned_judgment, rounds=MAX_RETRIEVE_ROUNDS)

    # Print results (truncate judgement printed length for safety)
    print("\n" + "=" * 120)
    print(f"FIRST TRAINING DOCUMENT  |  ID: {train_id}")
    print("=" * 120 + "\n")

    print(">> ORIGINAL JUDGMENT (CLEANED, truncated to 5000 chars):\n")
    print(cleaned_judgment[:5000] + ("\n...[truncated]" if len(cleaned_judgment) > 5000 else ""))

    print("\n" + "-" * 120)
    print(">> GOLDEN / REFERENCE SUMMARY:\n")
    print(ref_summary)

    print("\n" + "-" * 120)
    print(">> GENERATED RAG SUMMARY:\n")
    print(gen_summary if gen_summary.strip() else "[EMPTY SUMMARY]")

    print(f"\nCosine similarity (judgment vs. generated summary): {cos_sim:.4f}")
    print("\n" + "=" * 120 + "\n")

if __name__ == "__main__":
    main()





[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\pavit\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Device: cuda
Loading retriever and generators (may take a while)...


No sentence-transformers model found with name law-ai/InLegalBERT. Creating a new one with mean pooling.


[INFO] Showing up to first 3 raw lines of train_judg.jsonl for debugging:
 RAW: {"ID": "id_10", "Judgment": "Case :- WRIT - C No. - 11383 of 2023\nPetitioner :- Syed Hamidul Bari\nRespondent :- State Of U.P. Thru. Addl. Chief/Prin. Secy. Housing And\nUrban Planning Deptt. Lko. And 4 Others\nCounsel for Petitioner :- Kazim Ibrahim, Amrit Khare\nCounsel for Respondent :- C.S.C., Ratnesh Chandra\nCase :- WRIT - C No. - 11360 of 2023\nPetitioner :- Mohd. Naushad\nRespondent :- State Of U.P. Thru. Addl. Chief Secy./Prin. Secy. Housing\nAnd Urban Planning Deptt. And 4 Others\nCounsel for Petitioner :- Kazim Ibrahim,Amrit Khare\nCounsel for Respondent :- C.S.C.,Ratnesh Chandra\nCase :- WRIT - C No. - 11362 of 2023\nPetitioner :- Mohammad Abrar\nRespondent :- State Of U.P. Thru. Addl. Chief/Prin. Secy. Housing Urban\nPlanning Deptt. Lko. And 4 Others\nCounsel for Petitioner :- Kazim Ibrahim,Amrit Khare\nCounsel for Respondent :- C.S.C.,Ratnesh Chandra\nCase :- WRIT - C No. - 11368 of 2023\nPet

In [3]:
!pip install tf-keras

Collecting tf-keras
  Downloading tf_keras-2.20.1-py3-none-any.whl.metadata (1.8 kB)
Downloading tf_keras-2.20.1-py3-none-any.whl (1.7 MB)
   ---------------------------------------- 0.0/1.7 MB ? eta -:--:--
   ---------------------------------------- 1.7/1.7 MB 11.3 MB/s eta 0:00:00
Installing collected packages: tf-keras
Successfully installed tf-keras-2.20.1


In [3]:
#!/usr/bin/env python3
"""
rag_legal_summarizer.py

Retriever-Augmented Generation pipeline for legal summarization.

This variant:
 - Loads models (retriever + BART/PEGASUS/LED generators)
 - Uses robust loader for training JSONL (tolerant to common field-name variants)
 - Reads the first N training judgments + matching reference summaries (if present)
 - Generates a RAG summary for each sample and prints:
     * cleaned judgment (truncated for safety)
     * reference/golden summary (if found)
     * generated RAG summary
 - Intended for quick qualitative inspection of a few training examples.

Adjust CONFIG paths / N at top before running.
"""

import os
import re
import json
import jsonlines
import math
import time
from tqdm import tqdm
import numpy as np
import torch
from transformers import (
    BartTokenizer, BartForConditionalGeneration,
    PegasusTokenizer, PegasusForConditionalGeneration,
    LEDTokenizer, LEDForConditionalGeneration
)
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics.pairwise import cosine_similarity
from nltk import sent_tokenize
import nltk

nltk.download("punkt")

# ---------------------------
# CONFIG
# ---------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# Retriever model (legal SBERT recommended)
RETRIEVER_MODEL = "law-ai/InLegalBERT"  # change if needed

# Generator checkpoints (use your fine-tuned models / local paths)
BART_MODEL = "bart_legal_summ_model_final"
PEGASUS_MODEL = "pegasus_legal_summ_model_final"
LED_MODEL = "led_legal_summ_model_final"

# Files (update paths as required)
TRAIN_JUDG_PATH = "train_judg.jsonl"         # expects one JSON object per line
TRAIN_REF_SUMM_PATH = "train_ref_summ.jsonl"  # expects one JSON object per line

# How many training examples to print for qualitative inspection
NUM_EXAMPLES_TO_PRINT = 5

# Chunking & retrieval params
CHUNK_SIZE_TOKENS = 800       # chunk token length when splitting (approx)
CHUNK_OVERLAP_TOKENS = 200
TOP_K = 8                     # number of chunks to retrieve per judgment (initial)
RETRIEVE_BY = "sentence"      # "sentence" or "token" chunking

# Generation: enforce ~400-500 words
MIN_WORDS = 400
MAX_WORDS = 500
TOK_PER_WORD = 1.5            # rough token estimate per word => tokens = words * TOK_PER_WORD
GEN_MIN_TOKENS = int(MIN_WORDS * TOK_PER_WORD)   # e.g., 400 * 1.5 = 600
GEN_MAX_TOKENS = int(MAX_WORDS * TOK_PER_WORD)   # e.g., 500 * 1.5 = 750

# Retrieval-to-generation flow
MAX_RETRIEVE_ROUNDS = 3       # expand retrieval if initial generation is too short

# ---------------------------
# UTILITIES
# ---------------------------
def clean_judgment_text(text: str) -> str:
    """Basic cleaning for legal judgments to remove page markers, case headers, extra spaces."""
    if not isinstance(text, str):
        return ""
    text = re.sub(r"\[Page No\.\s*\d+\]", " ", text)
    text = re.sub(r"Case\s*:-.*?\n", " ", text)
    text = re.sub(r"\(\d+\)", "", text)
    text = re.sub(r"\n+", " ", text)
    text = re.sub(r"\s{2,}", " ", text)
    text = text.replace(" ,", ",").replace(" .", ".").strip()
    return text

def chunk_into_sentences(text: str):
    sents = sent_tokenize(text)
    return [s.strip() for s in sents if len(s.strip()) > 20]

def chunk_by_tokens(text: str, tokenizer, max_tokens=CHUNK_SIZE_TOKENS, overlap=CHUNK_OVERLAP_TOKENS):
    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= max_tokens:
        return [tokenizer.decode(tokens, clean_up_tokenization_spaces=True)]
    chunks = []
    step = max_tokens - overlap
    for i in range(0, len(tokens), step):
        chunk_tokens = tokens[i:i + max_tokens]
        if not chunk_tokens:
            break
        chunks.append(tokenizer.decode(chunk_tokens, clean_up_tokenization_spaces=True))
        if i + max_tokens >= len(tokens):
            break
    return chunks

def words_count(text: str) -> int:
    return len(re.findall(r"\w+", text))

def truncate_to_word_limit_by_sentences(summary: str, min_words=MIN_WORDS, max_words=MAX_WORDS) -> str:
    words = summary.split()
    if len(words) <= max_words and len(words) >= min_words:
        return summary
    sents = sent_tokenize(summary)
    out = []
    w = 0
    for s in sents:
        sw = len(s.split())
        if w + sw > max_words:
            break
        out.append(s)
        w += sw
    if w < min_words:
        for s in sents[len(out):]:
            out.append(s)
            w += len(s.split())
            if w >= min_words:
                break
    return " ".join(out)

# ---------------------------
# LOAD MODELS
# ---------------------------
print("Loading retriever and generators (may take a while)...")
retriever = SentenceTransformer(RETRIEVER_MODEL, device=device)

# Generators tokenizers & models
bart_tokenizer = BartTokenizer.from_pretrained(BART_MODEL)
bart_model = BartForConditionalGeneration.from_pretrained(BART_MODEL).to(device).eval()

pegasus_tokenizer = PegasusTokenizer.from_pretrained(PEGASUS_MODEL)
pegasus_model = PegasusForConditionalGeneration.from_pretrained(PEGASUS_MODEL).to(device).eval()

led_tokenizer = LEDTokenizer.from_pretrained(LED_MODEL)
led_model = LEDForConditionalGeneration.from_pretrained(LED_MODEL).to(device).eval()

# BART mean-pooled encoder embedding (for cosine scoring)
def bart_encode_mean(text: str, max_length=1024):
    inputs = bart_tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length, padding=True).to(device)
    with torch.no_grad():
        enc = bart_model.model.encoder(**inputs)
    return enc.last_hidden_state.mean(dim=1).cpu().numpy()

# generation helper
def generate_with(model, tokenizer, input_text: str, min_tokens=GEN_MIN_TOKENS, max_tokens=GEN_MAX_TOKENS, model_type="bart"):
    if model_type == "led":
        # LED supports long sequences; set a large max_length for tokenization
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=16000).to(device)
    else:
        inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024).to(device)

    out_ids = model.generate(
        **inputs,
        num_beams=4,
        max_length=max_tokens,
        min_length=min_tokens,
        length_penalty=1.0,
        no_repeat_ngram_size=3,
        early_stopping=True
    )
    return tokenizer.decode(out_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)

# ---------------------------
# RAG WORKFLOW PER JUDGMENT
# ---------------------------
def build_chunk_corpus(text: str, method="token"):
    """Return list of chunks (strings) for this judgment."""
    if method == "sentence":
        sents = chunk_into_sentences(text)
        chunks = []
        N = 8
        for i in range(0, len(sents), N):
            chunks.append(" ".join(sents[i:i + N]))
        return chunks
    else:
        return chunk_by_tokens(text, led_tokenizer, max_tokens=CHUNK_SIZE_TOKENS, overlap=CHUNK_OVERLAP_TOKENS)

def retrieve_top_k(chunks, query_embedding, k=TOP_K):
    chunk_embs = retriever.encode(chunks, convert_to_tensor=True, show_progress_bar=False)
    scores = util.cos_sim(query_embedding, chunk_embs)[0]
    topk = torch.topk(scores, min(k, len(chunks)))[1].cpu().numpy().tolist()
    return [(i, chunks[i], float(scores[i].cpu().item())) for i in topk]

def rag_summarize_judgment(judgment_text: str, rounds=1):
    chunks = build_chunk_corpus(judgment_text, method=RETRIEVE_BY)
    if len(chunks) == 0:
        return "", 0.0

    query_emb = retriever.encode(judgment_text, convert_to_tensor=True, show_progress_bar=False)

    current_k = TOP_K
    retrieved = retrieve_top_k(chunks, query_emb, k=current_k)
    retrieved_text = " ".join([r[1] for r in retrieved])

    prompt = "Summarize the following legal judgment excerpts into a coherent abstractive summary (400-500 words):\n\n"
    gen_input = prompt + retrieved_text

    # Primary generation with BART
    summary = generate_with(bart_model, bart_tokenizer, gen_input, model_type="bart")

    # Iterative retrieval-expansion if summary too short/long
    words = words_count(summary)
    round_idx = 1
    while (words < MIN_WORDS or words > MAX_WORDS) and round_idx <= rounds:
        current_k = min(len(chunks), current_k + TOP_K)
        retrieved = retrieve_top_k(chunks, query_emb, k=current_k)
        retrieved_text = " ".join([r[1] for r in retrieved])
        gen_input = prompt + retrieved_text

        try:
            summary = generate_with(led_model, led_tokenizer, gen_input, model_type="led")
        except Exception:
            summary = generate_with(bart_model, bart_tokenizer, gen_input, model_type="bart")

        words = words_count(summary)
        round_idx += 1

    final_summary = truncate_to_word_limit_by_sentences(summary, min_words=MIN_WORDS, max_words=MAX_WORDS)

    # Cosine similarity between judgment and final_summary (using BART encoder mean-pooled)
    j_emb = bart_encode_mean(judgment_text)
    s_emb = bart_encode_mean(final_summary)
    cos_sim = float(cosine_similarity(j_emb, s_emb)[0][0])

    return final_summary, cos_sim

# ---------------------------
# Robust JSONL loader utilities for common key variants
# ---------------------------
COMMON_ID_KEYS = ["id", "ID", "Id", "doc_id", "jid", "case_id"]
COMMON_JUDG_KEYS = ["judgment", "Judgment", "Judgement", "text", "case", "JudgmentText", "judgement"]
COMMON_SUMM_KEYS = ["summary", "Summary", "ref_summary", "gold", "gold_summary", "ref", "reference", "abstract"]

def get_field_any(record: dict, candidates):
    for k in candidates:
        if k in record:
            return record[k]
    # case-insensitive fallback
    low_map = {rk.lower(): rv for rk, rv in record.items() if isinstance(rk, str)}
    for c in candidates:
        if c.lower() in low_map:
            return low_map[c.lower()]
    return None

def find_first_n_training_records(path, n=1):
    """Return list of tuples (raw_obj, id_value, judgment_text) up to n records."""
    results = []
    # optional raw preview for debugging
    print(f"[INFO] Showing up to first 3 raw lines of {path} for debugging:")
    try:
        with open(path, "r", encoding="utf-8") as fh:
            for _ in range(3):
                line = fh.readline()
                if not line:
                    break
                print(" RAW:", line.strip())
    except Exception as e:
        print(f"[WARN] Could not read raw preview of {path}: {e}")

    with jsonlines.open(path) as reader:
        for obj in reader:
            id_val = get_field_any(obj, COMMON_ID_KEYS)
            jud_val = get_field_any(obj, COMMON_JUDG_KEYS)
            if jud_val and isinstance(jud_val, str) and len(jud_val.strip()) > 20:
                results.append((obj, id_val, jud_val))
                if len(results) >= n:
                    break
    return results

def find_ref_summary_for_id(path, match_id):
    """Search reference summaries file for the matching id using common keys."""
    with jsonlines.open(path) as reader:
        for obj in reader:
            id_val = get_field_any(obj, COMMON_ID_KEYS)
            if id_val is None:
                # try any string value that equals match_id
                for v in obj.values():
                    if isinstance(v, str) and str(v).strip() == str(match_id).strip():
                        summ = get_field_any(obj, COMMON_SUMM_KEYS)
                        if summ:
                            return summ
                continue
            if str(id_val).strip() == str(match_id).strip():
                summ = get_field_any(obj, COMMON_SUMM_KEYS)
                return summ
    return None

# ---------------------------
# MAIN: qualitative examples for first NUM_EXAMPLES_TO_PRINT training docs
# ---------------------------
def main():
    if not os.path.exists(TRAIN_JUDG_PATH):
        print(f"[ERROR] Training judgments file not found: {TRAIN_JUDG_PATH}")
        return
    if not os.path.exists(TRAIN_REF_SUMM_PATH):
        print(f"[ERROR] Training reference summaries file not found: {TRAIN_REF_SUMM_PATH}")
        return

    examples = find_first_n_training_records(TRAIN_JUDG_PATH, n=NUM_EXAMPLES_TO_PRINT)
    if not examples:
        print("[ERROR] No valid training judgments found in", TRAIN_JUDG_PATH)
        return

    for idx, (raw_obj, train_id, raw_judgment) in enumerate(examples, start=1):
        print("\n" + "=" * 120)
        print(f"SAMPLE {idx} / {len(examples)}  |  RAW ID detected: {train_id}")
        print("=" * 120 + "\n")

        cleaned_judgment = clean_judgment_text(raw_judgment)
        # try find reference summary (may be missing)
        ref_summary = None
        if train_id is not None:
            ref_summary = find_ref_summary_for_id(TRAIN_REF_SUMM_PATH, train_id)
        if ref_summary is None:
            ref_summary = "[NO GOLD / REF SUMMARY FOUND]"

        # Generate RAG summary
        gen_summary, cos_sim = rag_summarize_judgment(cleaned_judgment, rounds=MAX_RETRIEVE_ROUNDS)

        # Print results (truncate judgment printed length for safety)
        print(">> ORIGINAL JUDGMENT (CLEANED, truncated to 5000 chars):\n")
        print(cleaned_judgment[:5000] + ("\n...[truncated]" if len(cleaned_judgment) > 5000 else ""))

        print("\n" + "-" * 120)
        print(">> GOLDEN / REFERENCE SUMMARY:\n")
        print(ref_summary)

        print("\n" + "-" * 120)
        print(">> GENERATED RAG SUMMARY:\n")
        print(gen_summary if gen_summary.strip() else "[EMPTY SUMMARY]")

        print(f"\nCosine similarity (judgment vs. generated summary): {cos_sim:.4f}")
        print("\n" + "=" * 120 + "\n")

if __name__ == "__main__":
    main()


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\pavit\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Device: cuda
Loading retriever and generators (may take a while)...


No sentence-transformers model found with name law-ai/InLegalBERT. Creating a new one with mean pooling.


[INFO] Showing up to first 3 raw lines of train_judg.jsonl for debugging:
 RAW: {"ID": "id_10", "Judgment": "Case :- WRIT - C No. - 11383 of 2023\nPetitioner :- Syed Hamidul Bari\nRespondent :- State Of U.P. Thru. Addl. Chief/Prin. Secy. Housing And\nUrban Planning Deptt. Lko. And 4 Others\nCounsel for Petitioner :- Kazim Ibrahim, Amrit Khare\nCounsel for Respondent :- C.S.C., Ratnesh Chandra\nCase :- WRIT - C No. - 11360 of 2023\nPetitioner :- Mohd. Naushad\nRespondent :- State Of U.P. Thru. Addl. Chief Secy./Prin. Secy. Housing\nAnd Urban Planning Deptt. And 4 Others\nCounsel for Petitioner :- Kazim Ibrahim,Amrit Khare\nCounsel for Respondent :- C.S.C.,Ratnesh Chandra\nCase :- WRIT - C No. - 11362 of 2023\nPetitioner :- Mohammad Abrar\nRespondent :- State Of U.P. Thru. Addl. Chief/Prin. Secy. Housing Urban\nPlanning Deptt. Lko. And 4 Others\nCounsel for Petitioner :- Kazim Ibrahim,Amrit Khare\nCounsel for Respondent :- C.S.C.,Ratnesh Chandra\nCase :- WRIT - C No. - 11368 of 2023\nPet

Input ids are automatically padded from 2256 to 3072 to be a multiple of `config.attention_window`: 1024


>> ORIGINAL JUDGMENT (CLEANED, truncated to 5000 chars):

Versus Appearance: and Date : 22/09/2023 1.The present case is an eye opener. The convict- Chandanji @ Gato Chhanaji Thakor has filed the present application seeking regular bail through jail. Such application was filed by him on 05.08.2023, which is forwarded to the Registry of this Court vide communication dated 11.08.2023 written by the Deputy Superintendent of Ahmedabad Central Jail. 2.When the matter was listed yesterday, learned advocate Mr.Soni appearing for the applicant-convict has invited attention of this Court to the order dated 29.09.2020 passed in Criminal Misc. Application (for suspension of sentence) No.1 of 2020 in the captioned appeal and has submitted that this Court, after passing a comprehensive order, had already released the applicant on regular bail by suspending his sentence under the provision of Section 389 of the Code of Criminal Procedure, 1973 (for short "the 3.The matter was ordered to be listed to

Input ids are automatically padded from 1819 to 2048 to be a multiple of `config.attention_window`: 1024


>> ORIGINAL JUDGMENT (CLEANED, truncated to 5000 chars):

Non-Reportable Criminal Appeal No._________ of 2024 (@Special Leave Petition (Crl.) No. 10499 OF 2023) The State of Jharkhand … Appellant Versus Sandeep Kumar … Respondent 1.Leave granted. 2.By order dated 06.07.2022 passed in ABA No. 3483 of 2022, the High Court of Jharkhand at Ranchi granted pre-arrest bail to the respondent herein in relation to Dhanwar PS Case No. 296 of 2021, registered for offences under Sections 419, 466, 221, 205, 109 and 120-B/ 34 IPC. Aggrieved thereby, the State of Jharkhand filed the present appeal. 3.The respondent was the Officer-in-Charge of Dhanwar Police Station at the relevant time and was the Investigating Officer in Dhanwar PS Case No. 276 of 2021 registered against one Ranjeet Kumar Saw, son of Lakhan Saw, under Sections 420, 475, 201, 109 and 34 IPC along with Sections 65 and 68 of the Copyright Act, 1957. The said case was registered upon the complaint made by one Sanjay Kumar Sharma on be

Input ids are automatically padded from 2641 to 3072 to be a multiple of `config.attention_window`: 1024


>> ORIGINAL JUDGMENT (CLEANED, truncated to 5000 chars):

---- Appellant Versus Umesh Sharma S/o Late Omprakash Sharma Aged About 37 Years R/o ---- Respondent For Appellant :Shri Barun Kumar Chakrabarty, Advocate For Respondent :None, though served Hon'ble Shri Justice Goutam Bhaduri Hon'ble Shri Justice Sanjay S. Agrawal Judgment on Board Per Goutam Bhaduri, J. Heard. 1.The present appeal is against the judgment and decree dated 30/10/2021 (ANNEXURE A/1) passed by the learned Family Court, Raigarh, District Raigarh, C.G. in Civil Suit No.31-A/2020 whereby the application filed by the wife seeking divorce on the ground of cruelty was dismissed. Being aggrieved by such judgment and decree, the instant appeal is by the wife/appellant. 2.The respondent was ex-parte before the family Court. Here before this Court too despite service of the notice, the respondent/husband has not made any representation. Smt. Payal Sharma W/o Umesh Sharma Aged About 33 Years R/o Kelo 3.The brief facts of th

Input ids are automatically padded from 2265 to 3072 to be a multiple of `config.attention_window`: 1024


>> ORIGINAL JUDGMENT (CLEANED, truncated to 5000 chars):

1Whether Reporters of Local Papers may be allowed to see the judgment ? 2To be referred to the Reporter or not ? 3Whether their Lordships wish to see the fair copy of the judgment ? 4Whether this case involves a substantial question of law as to the interpretation of the Constitution of India or any order made thereunder ? Versus Appearance: MOHINI H DAVE for the Respondent(s) No. 2 Date : 22/06/2023 1.This application is filed under Section 482 of the Code of Criminal Procedure, 1973 (`the Code’ for short) praying to quash the FIR registered as C.R.No.I- 110 of 2016 with GIDC Vatva Police Station, Ahmedabad city for the offences punishable under Sections 498(A), 323 and 114 of the Indian Penal Code and Section 4 of the Dowry Prohibition Act. 2.The brief facts leading to filing of this application are such that the son of the applicants married the respondent no.2 on 28.2.2000 and thereafter, after some time, the applicants star