In [None]:
in percentage

In [None]:
# =========================================================
# LEGAL JUDGMENT SUMMARIZATION — LED FINE-TUNING
# Uses an internal 90/10 train/validation split for evaluation
# =========================================================

import re
import json
import torch
import numpy as np
from tqdm import tqdm
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    LEDForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)
import evaluate
from rouge import Rouge
from sacrebleu.metrics import BLEU
import os

# =========================================================
# 1 Device setup
# =========================================================
# Pick the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

# Report which GPU will be used (index 0) when running on CUDA.
if device == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))

# =========================================================
# 2 Preprocessing function
# =========================================================
def clean_judgment_text(text):
    """Normalise a raw judgment or summary string into a single clean line.

    The substitutions are applied in this order:
      1. "[Page No. N]" markers are replaced by a space;
      2. "Case :-" up to and including the next newline is replaced by a space;
      3. immediately repeated "Petitioner :- ... Respondent :- ...
         Counsel for Respondent :- ..." blocks are collapsed to one copy;
      4. parenthesised integers such as "(1)" are deleted;
      5. newline runs become one space, then any run of two or more
         whitespace characters becomes one space;
      6. the space before "," and "." is removed and the ends are stripped.

    Args:
        text: the raw input string.

    Returns:
        The cleaned string.
    """
    # (pattern, replacement, flags) triples, applied strictly in sequence.
    substitutions = (
        (r"\[Page No\.\s*\d+\]", " ", 0),
        (r"Case\s*:-.*?\n", " ", 0),
        (
            r"(Petitioner\s*:-.*?Respondent\s*:-.*?Counsel for Respondent\s*:-.*?)(\1)+",
            r"\1",
            re.DOTALL,  # the header block may span several lines
        ),
        (r"\(\d+\)", "", 0),
        (r"\n+", " ", 0),
        (r"\s{2,}", " ", 0),
    )
    for pattern, replacement, flags in substitutions:
        text = re.sub(pattern, replacement, text, flags=flags)

    # Tighten punctuation that the whitespace collapse left detached.
    for spaced, tight in ((" ,", ","), (" .", ".")):
        text = text.replace(spaced, tight)
    return text.strip()

# =========================================================
# 3 Load all training data
# =========================================================
with open("train_judg.jsonl", "r", encoding="utf-8") as f_text, \
     open("train_ref_summ.jsonl", "r", encoding="utf-8") as f_summ:
    # Blank lines (e.g. stray empty lines at EOF) are skipped; json.loads
    # would raise on them.
    train_texts = [json.loads(line) for line in f_text if line.strip()]
    train_summaries = [json.loads(line) for line in f_summ if line.strip()]

# Judgments and reference summaries are paired by line position. zip() would
# silently drop the tail of the longer file, so stop if the files disagree.
if len(train_texts) != len(train_summaries):
    raise ValueError(
        f"train_judg.jsonl has {len(train_texts)} records but "
        f"train_ref_summ.jsonl has {len(train_summaries)}; "
        "they must be aligned line by line"
    )

all_data = []
for t, s in zip(train_texts, train_summaries):
    # If the summary record carries its own ID it must match the judgment's;
    # otherwise the files are out of order and every pair would be wrong.
    if "ID" in s and str(s["ID"]) != str(t["ID"]):
        raise ValueError(
            f"ID mismatch while pairing records: judgment {t['ID']!r} "
            f"vs summary {s['ID']!r}"
        )
    all_data.append({
        "ID": t["ID"],
        "text": clean_judgment_text(t["Judgment"]),
        "summary": clean_judgment_text(s["Summary"]),
    })

# =========================================================
# 4 Split into 90% train, 10% validation
# =========================================================
train_list, val_list = train_test_split(
    all_data,
    test_size=0.1,    # hold out 10% for validation
    random_state=42,  # fixed seed so the split is reproducible
)
# Wrap both partitions as Hugging Face datasets (train first, then validation).
train_dataset, val_dataset = (Dataset.from_list(part) for part in (train_list, val_list))
print(f"Train size: {len(train_dataset)} | Validation size: {len(val_dataset)}")

# =========================================================
# 5 Model and tokenizer (LED)
# =========================================================
model_name = "nsi319/legal-led-base-16384"
# Tokenizer and LED seq2seq model are loaded from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LEDForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)  # place the weights on the selected device

# Important: LED requires global attention on at least one token
def add_global_attention_mask(input_ids, pad_token_id=None):
    """Build the ``(attention_mask, global_attention_mask)`` pair for LED ids.

    Args:
        input_ids: integer tensor of token ids whose last dimension is the
            sequence (1-D for a single sequence, 2-D for a batch).
        pad_token_id: optional padding id. When given, positions equal to it
            get ``attention_mask == 0`` so padding is ignored. When ``None``
            (the default), every position is attended.

    Returns:
        ``(attention_mask, global_attention_mask)``, both with the shape and
        dtype of ``input_ids``. Global attention is set on the first token of
        every sequence only.
    """
    if pad_token_id is None:
        attention_mask = torch.ones_like(input_ids)
    else:
        attention_mask = (input_ids != pad_token_id).to(input_ids.dtype)
    global_attention_mask = torch.zeros_like(input_ids)
    # `...` keeps this working for both a single sequence and a batch.
    global_attention_mask[..., 0] = 1  # First token has global attention
    return attention_mask, global_attention_mask

# =========================================================
# 6 Tokenization
# =========================================================
# Token budgets (counted in tokenizer tokens, not words):
#   inputs  -> 4096 (the checkpoint name advertises support for up to 16384)
#   labels  -> 1024 (reference-summary length)
max_input_len, max_output_len = 4096, 1024

def preprocess_function(examples):
    """Tokenize a batch of examples for LED sequence-to-sequence training.

    Intended for ``Dataset.map(..., batched=True)``, so ``examples["text"]``
    and ``examples["summary"]`` are lists of strings.

    Args:
        examples: batch mapping with "text" (source judgments) and
            "summary" (reference summaries) columns.

    Returns:
        The tokenizer output for the texts, padded/truncated to
        ``max_input_len``, with two extra keys:
          * ``global_attention_mask``: 1 on the first token of every input,
            0 elsewhere;
          * ``labels``: summary token ids padded/truncated to
            ``max_output_len``, with padding positions replaced by -100.
    """
    model_inputs = tokenizer(
        examples["text"],
        max_length=max_input_len,
        truncation=True,
        padding="max_length"
    )
    labels = tokenizer(
        examples["summary"],
        max_length=max_output_len,
        truncation=True,
        padding="max_length"
    )

    # LED needs global attention on at least one token. Put it on the first
    # token of every input, as recommended for LED summarization. Without
    # this key the model is trained with no global attention at all.
    model_inputs["global_attention_mask"] = [
        [1] + [0] * (len(input_ids) - 1)
        for input_ids in model_inputs["input_ids"]
    ]

    # Labels are already padded to max_output_len here, so the collator adds
    # no -100 padding of its own. Map pad ids to -100 ourselves so the loss
    # ignores padding instead of training the model to emit <pad> tokens.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if token == pad_id else token for token in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs

# Tokenize both splits in batches, dropping the raw text/ID columns so only
# model inputs remain.
train_dataset, val_dataset = (
    split.map(preprocess_function, batched=True, remove_columns=["text", "summary", "ID"])
    for split in (train_dataset, val_dataset)
)

# =========================================================
# 7 Data collator
# =========================================================
# Batch collator for the Seq2SeqTrainer below; built from the same tokenizer and model.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# =========================================================
# 8 Evaluation metrics
# =========================================================
rouge = evaluate.load("rouge")


def compute_metrics(eval_pred):
    """Compute ROUGE-1/2/L F-scores (as percentages) for trainer evaluation.

    Args:
        eval_pred: ``(predictions, label_ids)`` from the trainer. Predictions
            are generated token ids (``predict_with_generate=True``), possibly
            wrapped in a tuple.

    Returns:
        Dict with "rouge1", "rouge2" and "rougeL" scaled to 0-100.
    """
    preds, labels = eval_pred
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is not a valid token id, and decoding it fails. It marks
    # loss-ignored padding in the labels. It may also appear in the
    # predictions, where the trainer pads generated sequences of different
    # lengths when concatenating them. Map it back to the pad id in both.
    pad_id = tokenizer.pad_token_id
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge_result = rouge.compute(predictions=decoded_preds, references=decoded_labels)
    return {name: rouge_result[name] * 100 for name in ("rouge1", "rouge2", "rougeL")}

# =========================================================
# 9 Training arguments
# =========================================================
training_args = Seq2SeqTrainingArguments(
    output_dir="./led_legal_summ_model",
    # --- optimisation ---
    num_train_epochs=3,
    learning_rate=3e-5,
    weight_decay=0.01,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    # --- evaluation / checkpointing, once per epoch ---
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy`; confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    predict_with_generate=True,  # compute_metrics receives generated token ids
    # --- logging ---
    logging_dir="./logs",
    logging_steps=50,
    # The datasets were already reduced to model inputs by Dataset.map.
    remove_unused_columns=False,
)

# =========================================================
# 10 Trainer
# =========================================================
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    # NOTE(review): newer transformers versions deprecate `tokenizer=` in
    # favour of `processing_class=`; confirm against the installed version.
    tokenizer=tokenizer,
)

# =========================================================
# 11 Train
# =========================================================
print("Starting fine-tuning...")
train_result = trainer.train()
print("Training completed.")

# Persist the summary metrics returned by trainer.train() as indented JSON.
with open("training_metrics.json", "w") as fh:
    fh.write(json.dumps(train_result.metrics, indent=4))
print("Training metrics saved.")

# =========================================================
# 12 Save fine-tuned model
# =========================================================
# Write the fine-tuned model and its tokenizer into the same directory.
model_dir = "./led_legal_summ_model_final"
trainer.save_model(model_dir)
tokenizer.save_pretrained(model_dir)
print("Model saved to " + model_dir)


# =========================================================
# Generate summaries
# =========================================================
print("Generating summaries for validation set (400-500 words)...")
generated_summaries = []

# Reuse the token budgets defined for tokenization (max_input_len,
# max_output_len) so the model reads as much of each judgment at inference
# as it did during fine-tuning. Both are *token* counts, not word counts.
model.eval()  # disable dropout while generating

for example in tqdm(val_list):
    inputs = tokenizer(
        example["text"],
        return_tensors="pt",
        max_length=max_input_len,
        truncation=True,
        padding=True
    ).to(device)

    # Give the first token global attention, as LED expects for summarization.
    global_attention_mask = torch.zeros_like(inputs["input_ids"])
    global_attention_mask[:, 0] = 1

    outputs = model.generate(
        **inputs,
        global_attention_mask=global_attention_mask,
        max_length=max_output_len,  # upper bound, in tokens
        min_length=400,             # lower bound, in tokens (not words)
        num_beams=8,
        length_penalty=1.0,
        no_repeat_ngram_size=4,
        early_stopping=True
    )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    generated_summaries.append({
        "ID": example["ID"],
        "generated_summary": summary,
        "reference_summary": example["summary"]
    })

# Save generated summaries
# One JSON object per line (JSONL), keeping non-ASCII characters as-is.
os.makedirs("outputs", exist_ok=True)
out_file = "outputs/val_generated_summaries.jsonl"
with open(out_file, "w", encoding="utf-8") as f:
    f.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in generated_summaries)
print("Generated summaries saved to", out_file)

# =========================================================
# Evaluate using ROUGE, BLEU, BERTScore
# =========================================================
print("Evaluating summaries...")
rouge_scorer = Rouge()
bleu_scorer = BLEU()
bertscore_eval = evaluate.load("bertscore")

rouge1_scores, rouge2_scores, rougel_scores, bleu_scores, bert_scores = [], [], [], [], []

for item in generated_summaries:
    hyp = item["generated_summary"]
    ref = item["reference_summary"]

    # Each metric is scored independently. A failure records 0 for that
    # metric and is reported rather than hidden. `except Exception` (not a
    # bare except) still lets Ctrl-C / SystemExit stop the run.
    try:
        r = rouge_scorer.get_scores(hyps=hyp, refs=ref)[0]
        rouge1 = r["rouge-1"]["f"] * 100
        rouge2 = r["rouge-2"]["f"] * 100
        rougel = r["rouge-l"]["f"] * 100
    except Exception as err:  # e.g. the `rouge` package rejects empty hypotheses
        print(f"[warn] ROUGE failed for ID={item['ID']!r}: {err!r}")
        rouge1, rouge2, rougel = 0, 0, 0

    try:
        bleu = bleu_scorer.sentence_score(hypothesis=hyp, references=[ref]).score
    except Exception as err:
        print(f"[warn] BLEU failed for ID={item['ID']!r}: {err!r}")
        bleu = 0

    try:
        bert = bertscore_eval.compute(predictions=[hyp], references=[ref], lang="en")["f1"][0] * 100
    except Exception as err:
        print(f"[warn] BERTScore failed for ID={item['ID']!r}: {err!r}")
        bert = 0

    rouge1_scores.append(rouge1)
    rouge2_scores.append(rouge2)
    rougel_scores.append(rougel)
    bleu_scores.append(bleu)
    bert_scores.append(bert)

# Mean of each per-summary score list, in the order they are reported.
metrics = {
    label: np.mean(scores)
    for label, scores in (
        ("ROUGE-1 (%)", rouge1_scores),
        ("ROUGE-2 (%)", rouge2_scores),
        ("ROUGE-L (%)", rougel_scores),
        ("BLEU (%)", bleu_scores),
        ("BERTScore-F1 (%)", bert_scores),
    )
}
# The aggregate averages ROUGE-2, ROUGE-L and BLEU only; ROUGE-1 and
# BERTScore are reported but not included.
metrics["AVG_SCORE (%)"] = np.mean(
    [metrics["ROUGE-2 (%)"], metrics["ROUGE-L (%)"], metrics["BLEU (%)"]]
)

with open("outputs/validation_metrics.json", "w") as fh:
    json.dump(metrics, fh, indent=4)

print("\n========== Validation Metrics (in %) ==========")
print("\n".join(f"{k}: {v:.2f}" for k, v in metrics.items()))
print("===============================================")
print("Validation metrics saved to outputs/validation_metrics.json")

