# In [None]:
# =========================================================
# T5 Legal Summarization Full Training Script
# =========================================================

import re
import json
import torch
import numpy as np
from tqdm import tqdm
from datasets import Dataset, load_dataset
from transformers import (
    T5Tokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)
import evaluate
import os

# =========================================================
# 1 Device setup
# =========================================================
cuda_ok = torch.cuda.is_available()
device = "cuda" if cuda_ok else "cpu"
print("Device:", device)
if cuda_ok:
    # Report the name of CUDA device 0, which is the one used below.
    print("GPU:", torch.cuda.get_device_name(0))

# =========================================================
# 2 Preprocessing function
# =========================================================
def clean_judgment_text(text):
    """Normalise raw judgment/summary text before tokenization.

    Removes "[Page No. N]" markers and "Case :-" lines (up to the newline),
    collapses a Petitioner/Respondent/Counsel header block that repeats
    back-to-back into a single copy, drops parenthesised numbers such as
    "(1)", flattens newlines to spaces, squeezes whitespace runs, and removes
    the space before commas and full stops.

    Args:
        text: Raw string taken from a JSONL record.

    Returns:
        The cleaned, stripped single-line string.
    """
    # (pattern, replacement, flags), applied strictly in this order.
    # NOTE(review): the lazy ``.*?`` groups combined with the ``\1`` back-
    # reference can backtrack heavily on very long documents; profile this
    # step if cleaning turns out to be slow.
    substitutions = (
        (r"\[Page No\.\s*\d+\]", " ", 0),
        (r"Case\s*:-.*?\n", " ", 0),
        (
            r"(Petitioner\s*:-.*?Respondent\s*:-.*?Counsel for Respondent\s*:-.*?)(\1)+",
            r"\1",
            re.DOTALL,
        ),
        (r"\(\d+\)", "", 0),
        (r"\n+", " ", 0),
        (r"\s{2,}", " ", 0),
    )
    for pattern, replacement, flags in substitutions:
        text = re.sub(pattern, replacement, text, flags=flags)
    return text.replace(" ,", ",").replace(" .", ".").strip()

# =========================================================
# 3 Load and prepare data
# =========================================================
# Judgments and reference summaries are stored in two parallel JSONL files
# and are paired by line position. Blank lines are skipped so that a stray
# empty line does not make json.loads raise.
with open("train_judg.jsonl", "r", encoding="utf-8") as f_text, \
     open("train_ref_summ.jsonl", "r", encoding="utf-8") as f_summ:
    train_texts = [json.loads(line) for line in f_text if line.strip()]
    train_summaries = [json.loads(line) for line in f_summ if line.strip()]

# zip() silently stops at the shorter input, which would drop or misalign
# training pairs without any warning, so fail loudly instead.
if len(train_texts) != len(train_summaries):
    raise ValueError(
        f"train_judg.jsonl has {len(train_texts)} records but "
        f"train_ref_summ.jsonl has {len(train_summaries)}; "
        "the files must be aligned line by line"
    )

train_data = []
for t, s in zip(train_texts, train_summaries):
    # When the summary record also carries an ID, verify the positional pairing.
    if "ID" in s and s["ID"] != t["ID"]:
        raise ValueError(
            f"Mismatched training pair: judgment ID {t['ID']!r} "
            f"vs summary ID {s['ID']!r}"
        )
    train_data.append({
        "ID": t["ID"],
        "text": clean_judgment_text(t["Judgment"]),
        "summary": clean_judgment_text(s["Summary"]),
    })

train_dataset = Dataset.from_list(train_data)

# Validation dataset: a cleaned "text" column is always added; a cleaned
# "summary" column only when the file provides a "Summary" field.
val_dataset = load_dataset("json", data_files={"val": "val_judg.jsonl"})["val"]
val_dataset = val_dataset.map(lambda x: {"text": clean_judgment_text(x["Judgment"])})
if "Summary" in val_dataset.column_names:
    val_dataset = val_dataset.map(lambda x: {"summary": clean_judgment_text(x["Summary"])})

print("Train size:", len(train_dataset))
print("Validation size:", len(val_dataset))

# =========================================================
# 4 Model and tokenizer (T5)
# =========================================================
# Checkpoint to fine-tune; t5-large or a legal-domain variant can be swapped in.
model_name = "t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.to(device)  # nn.Module.to moves parameters in place and returns self

# =========================================================
# 5 Tokenization function
# =========================================================
max_input_len = 1024
max_output_len = 512

def preprocess_function(examples):
    """Tokenize a batch of cleaned judgments and reference summaries.

    Args:
        examples: Batched mapping from ``Dataset.map(batched=True)`` holding
            parallel ``"text"`` and ``"summary"`` lists of strings.

    Returns:
        The tokenizer output for ``"text"`` (padded/truncated to
        ``max_input_len``) plus a ``"labels"`` entry: summary token ids
        padded/truncated to ``max_output_len`` with every padding position
        replaced by -100.
    """
    model_inputs = tokenizer(
        examples["text"],
        max_length=max_input_len,
        truncation=True,
        padding="max_length",
    )
    labels = tokenizer(
        examples["summary"],
        max_length=max_output_len,
        truncation=True,
        padding="max_length",
    )
    # The labels are already padded here with pad_token_id, so the data
    # collator has nothing left to pad and never inserts -100 itself. Without
    # this masking the loss is computed on (and the model learns to emit) the
    # padding tokens. -100 marks label positions ignored by the seq2seq loss.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token if token != pad_id else -100 for token in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs

train_dataset = train_dataset.map(preprocess_function, batched=True)

# =========================================================
# 6 Data collator
# =========================================================
# Batches tokenized features for the Trainer. The model is passed so that the
# collator can derive decoder inputs from the labels when the model supports it.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# =========================================================
# 7 Evaluation metrics
# =========================================================
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")
bertscore = evaluate.load("bertscore")

def compute_metrics(eval_pred):
    """Compute ROUGE-1/2/L, BLEU and mean BERTScore F1 for one evaluation pass.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair supplied by the Trainer.
            ``predictions`` may arrive as a tuple whose first element holds
            the generated token ids.

    Returns:
        Dict mapping metric name to a plain Python float.
    """
    preds, labels = eval_pred
    if isinstance(preds, tuple):
        preds = preds[0]
    pad_id = tokenizer.pad_token_id
    # Generated sequences of different lengths are padded with -100 when the
    # Trainer concatenates them across evaluation batches. -100 is not a valid
    # token id and makes batch_decode raise, so map it to the pad token first.
    preds = np.where(preds != -100, preds, pad_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Labels use -100 for positions ignored by the loss; restore pad ids to decode.
    labels = np.where(labels != -100, labels, pad_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge_result = rouge.compute(predictions=decoded_preds, references=decoded_labels)
    bleu_result = bleu.compute(predictions=decoded_preds, references=decoded_labels)
    bert_result = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")

    return {
        "rouge1": float(rouge_result["rouge1"]),
        "rouge2": float(rouge_result["rouge2"]),
        "rougeL": float(rouge_result["rougeL"]),
        "bleu": float(bleu_result["bleu"]),
        "bertscore_f1": float(np.mean(bert_result["f1"])),
    }

# =========================================================
# 8 Training arguments
# =========================================================
training_args = Seq2SeqTrainingArguments(
    output_dir="./t5_legal_summ_model",
    # Optimisation
    num_train_epochs=3,
    learning_rate=3e-5,
    weight_decay=0.01,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    # Mixed precision only when a GPU is present.
    # NOTE(review): T5 checkpoints have a history of fp16 overflow (NaN/zero
    # loss); consider bf16 if that shows up.
    fp16=torch.cuda.is_available(),
    # Evaluate and checkpoint once per epoch, keeping at most two checkpoints;
    # evaluation decodes with generate() so text metrics can be computed.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    predict_with_generate=True,
    # Logging
    logging_dir="./logs",
    logging_steps=50,
)

# =========================================================
# 9 Trainer setup
# =========================================================
# Per-epoch evaluation generates text and runs BERTScore, so it is capped at a
# small number of examples to stay quick.
max_eval_examples = 100
if "summary" in val_dataset.column_names and len(val_dataset) > 0:
    # Held-out references are available: score on validation data so epoch
    # metrics reflect generalisation rather than fit to the training set.
    eval_dataset = val_dataset.select(
        range(min(max_eval_examples, len(val_dataset)))
    ).map(preprocess_function, batched=True)
else:
    # No validation references: fall back to a training subset, which only
    # shows how well the model fits data it has already seen.
    print("Validation set has no reference summaries; evaluating on a training subset.")
    eval_dataset = train_dataset.select(range(min(max_eval_examples, len(train_dataset))))

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# =========================================================
# 10 Train model
# =========================================================
print("Starting fine-tuning...")
train_result = trainer.train()
print("Training completed.")

# Persist the metrics dict returned by Trainer.train().
final_train_metrics = train_result.metrics
with open("training_metrics.json", "w") as metrics_file:
    json.dump(final_train_metrics, metrics_file, indent=4)
print("Training metrics saved to training_metrics.json")

# =========================================================
# 11 Save fine-tuned model
# =========================================================
model_dir = "./t5_legal_summ_model_final"
# Save the model (through the Trainer) and then the tokenizer into the same
# directory, so both can be reloaded from it with from_pretrained.
for save_to in (trainer.save_model, tokenizer.save_pretrained):
    save_to(model_dir)
print("Model saved to", model_dir)

# =========================================================
# 12 Generate summaries for validation set
# =========================================================
print("Generating summaries for validation set...")
# Disable dropout explicitly for inference instead of relying on whichever
# mode Trainer.train() happened to leave the model in.
model.eval()
generated_summaries = []

# Iterate rows directly so each record is materialised only once.
for example in tqdm(val_dataset):
    inputs = tokenizer(
        example["text"],
        return_tensors="pt",
        max_length=max_input_len,
        truncation=True,
    ).to(device)

    # Beam search with a minimum length and 4-gram repetition blocking.
    outputs = model.generate(
        **inputs,
        max_length=max_output_len,
        min_length=200,
        num_beams=8,
        length_penalty=1.0,
        no_repeat_ngram_size=4,
        early_stopping=True,
    )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)
    generated_summaries.append({
        "ID": example["ID"],
        "generated_summary": summary,
    })

os.makedirs("outputs1", exist_ok=True)
out_file = "outputs1/val_generated_summaries.jsonl"
with open(out_file, "w", encoding="utf-8") as f:
    for item in generated_summaries:
        f.write(json.dumps(item, ensure_ascii=False) + "\n")

print("Generated summaries saved to", out_file)

# =========================================================
# 13 Evaluate on validation set
# =========================================================
# Only possible when the validation file supplied reference summaries.
if "summary" in val_dataset.column_names:
    preds = [item["generated_summary"] for item in generated_summaries]
    refs = list(val_dataset["summary"])

    rouge_result = rouge.compute(predictions=preds, references=refs)
    bleu_result = bleu.compute(predictions=preds, references=refs)
    bert_result = bertscore.compute(predictions=preds, references=refs, lang="en")

    eval_results = {
        "rouge1": float(rouge_result["rouge1"]),
        "rouge2": float(rouge_result["rouge2"]),
        "rougeL": float(rouge_result["rougeL"]),
        "bleu": float(bleu_result["bleu"]),
        "bertscore_f1": float(np.mean(bert_result["f1"])),
    }

    # The "outputs" directory is not guaranteed to exist; create it so that
    # opening the metrics file does not raise FileNotFoundError.
    os.makedirs("outputs", exist_ok=True)
    with open("outputs/validation_metrics.json", "w") as f:
        json.dump(eval_results, f, indent=4)

    print("Validation Metrics:")
    print(json.dumps(eval_results, indent=4))
    print("Validation metrics saved to outputs/validation_metrics.json")


# In [None]:
# =========================================================
#  Legal Judgment Summarization — Inference with Fine-Tuned T5
# =========================================================

import torch
import json
from tqdm import tqdm
from datasets import Dataset
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
from sklearn.metrics.pairwise import cosine_similarity
import re

# =========================================================
# 1 Device setup
# =========================================================
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

# =========================================================
# 2 Load Fine-Tuned T5 Model and Tokenizer
# =========================================================
# Local directory holding the fine-tuned model and its tokenizer files.
model_name = "./t5_legal_summ_model_final"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.to(device)  # nn.Module.to moves parameters in place and returns self

# =========================================================
# 3 Preprocessing function
# =========================================================
def clean_judgment_text(text):
    """Normalise raw judgment text before tokenization.

    Removes "[Page No. N]" markers and "Case :-" lines (up to the newline),
    collapses a Petitioner/Respondent/Counsel header block that repeats
    back-to-back into a single copy, drops parenthesised numbers such as
    "(1)", flattens newlines to spaces, squeezes whitespace runs, and removes
    the space before commas and full stops.

    Args:
        text: Raw string taken from a JSONL record.

    Returns:
        The cleaned, stripped single-line string.
    """
    # (pattern, replacement, flags), applied strictly in this order.
    # NOTE(review): the lazy ``.*?`` groups combined with the ``\1`` back-
    # reference can backtrack heavily on very long documents; profile this
    # step if cleaning turns out to be slow.
    substitutions = (
        (r"\[Page No\.\s*\d+\]", " ", 0),
        (r"Case\s*:-.*?\n", " ", 0),
        (
            r"(Petitioner\s*:-.*?Respondent\s*:-.*?Counsel for Respondent\s*:-.*?)(\1)+",
            r"\1",
            re.DOTALL,
        ),
        (r"\(\d+\)", "", 0),
        (r"\n+", " ", 0),
        (r"\s{2,}", " ", 0),
    )
    for pattern, replacement, flags in substitutions:
        text = re.sub(pattern, replacement, text, flags=flags)
    return text.replace(" ,", ",").replace(" .", ".").strip()

# =========================================================
# 4 Load Test Dataset
# =========================================================
# One JSON object per line, read with the keys "id" and "judgment".
# NOTE(review): the training files use "ID"/"Judgment"; confirm the test file
# really uses these lowercase keys.
test_file = "test_judg.jsonl"
with open(test_file, "r", encoding="utf-8") as fh:
    test_data = list(map(json.loads, fh))

print(f" Loaded {len(test_data)} test samples")

# =========================================================
# 5 Generate Summaries + Compute Cosine Similarities
# =========================================================
generated_summaries = []
cosine_scores = []

max_input_len = 1024
max_output_len = 400


def _mean_encoder_embedding(encoded):
    """Mean-pool the encoder's last hidden state over dimension 1.

    Args:
        encoded: Tokenizer output already moved to ``device``; it is passed
            straight to ``model.encoder``.

    Returns:
        NumPy array with one pooled vector per sequence in the batch.
    """
    hidden = model.encoder(**encoded).last_hidden_state
    return hidden.mean(dim=1).cpu().numpy()


for example in tqdm(test_data, desc="Generating summaries"):
    case_id = example["id"]
    judgment_text = clean_judgment_text(example["judgment"])

    # Tokenize the judgment, truncating to max_input_len tokens.
    inputs = tokenizer(
        judgment_text,
        max_length=max_input_len,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    # Beam-search decoding of the summary.
    summary_ids = model.generate(
        **inputs,
        max_length=max_output_len,
        min_length=50,
        num_beams=5,
        length_penalty=1.0,
        no_repeat_ngram_size=3,
        early_stopping=True,
    )
    generated_summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    # Cosine similarity between the mean-pooled encoder representations of
    # the judgment and of the generated summary.
    with torch.no_grad():
        judgment_vec = _mean_encoder_embedding(inputs)
        summary_encoded = tokenizer(
            generated_summary,
            max_length=512,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        summary_vec = _mean_encoder_embedding(summary_encoded)
        similarity = cosine_similarity(judgment_vec, summary_vec)[0][0]

    generated_summaries.append({"ID": case_id, "Summary": generated_summary})
    cosine_scores.append({"ID": case_id, "Cosine_Similarity": float(similarity)})

# =========================================================
# 6 Save Outputs
# =========================================================
def _write_jsonl(path, rows):
    """Write each dict in *rows* to *path* as one UTF-8 JSON line."""
    with open(path, "w", encoding="utf-8") as fh:
        fh.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)


_write_jsonl("generated_summaries.jsonl", generated_summaries)
_write_jsonl("cosine_similarity_scores.jsonl", cosine_scores)

print("\n Done!")
print("Summaries saved in 'generated_summaries.jsonl'")
print("Cosine similarity scores saved in 'cosine_similarity_scores.jsonl'")
