
# Fine-tune FLAN-T5-small for Indian Legal QA / Summarization

This notebook walks you through a **minimal, reliable** fine-tuning pipeline on Indian legal text using a **small instruction-tuned model** (`google/flan-t5-small`).  
It supports **two dataset modes**:
1. **CSV** (e.g., Kaggle QA pairs) – columns like `question`, `answer`, and optionally `context` or `text`  
2. **Hugging Face Datasets** – e.g., `viber1/indian-law-dataset` or similar

> 🔧 You can adapt this to any of your datasets by mapping columns in the **Column Mapping** cell.
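
To smoke-test the CSV mode before pointing the notebook at a real dataset, a tiny file with the expected columns is enough. The sketch below writes one with the standard `csv` module. The file name `toy_legal_qa.csv` and every row are made-up placeholders, not real legal data.

```python
import csv

# Hypothetical file name and placeholder rows, only for a pipeline smoke test
TOY_CSV_PATH = "toy_legal_qa.csv"
rows = [
    {
        "question": f"Placeholder question {i}?",
        "answer": f"Placeholder answer {i}.",
        "context": f"Placeholder context passage {i}.",
    }
    for i in range(1, 21)
]

with open(TOY_CSV_PATH, "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=["question", "answer", "context"])
    writer.writeheader()
    writer.writerows(rows)

# Read the file back to confirm the header and the row count
with open(TOY_CSV_PATH, newline="", encoding="utf-8") as f:
    loaded = list(csv.DictReader(f))
print(len(loaded), "rows; columns:", list(loaded[0].keys()))
```

Setting `CSV_PATH` to this file should let every later cell run end to end, although the resulting model will of course learn nothing useful.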

**What you'll do:**
- Install libs
- Load dataset (CSV or HF)
- Clean & split
- Tokenize to instruction format
- Train with `Trainer`
- Evaluate (ROUGE)
- Save & run inference

**Hardware**: Works on CPU, but a GPU is recommended. A consumer card such as an RTX 4060 is more than enough for `flan-t5-small`.

## 0. Setup & Installs

In [None]:

# If running in a fresh environment, uncomment and run:
# !pip install -U pip
# !pip install "torch>=2.2" --index-url https://download.pytorch.org/whl/cu121  # for CUDA 12.x systems
# !pip install transformers datasets accelerate evaluate sentencepiece rouge-score
# Optional: bitsandbytes provides 8-bit/4-bit quantization; it is not needed for flan-t5-small
# !pip install bitsandbytes
import os, sys, math, random
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Trainer, TrainingArguments
import evaluate
import numpy as np
import torch

print('Torch:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('CUDA device:', torch.cuda.get_device_name(0))



## 1. Configuration

In [None]:

# ==== Configuration ====
# Choose data source: "csv" or "huggingface"
DATA_SOURCE = "csv"  # "csv" | "huggingface"

# If CSV: set your local CSV path here (Kaggle or custom)
CSV_PATH = "/path/to/your/indian_legal_qa.csv"   # e.g., ./datasets/kaggle/qa.csv

# If HuggingFace: set dataset name and splits/fields accordingly
HF_DATASET_NAME = "viber1/indian-law-dataset"    # example; change to your dataset
HF_SPLIT_TRAIN = "train"                          # update if dataset provides different splits
HF_SPLIT_VALID = "validation"                      # fallback/auto-split used if not present

# Task style: "qa" (question -> answer) or "summarization" (document -> summary)
TASK_STYLE = "qa"  # "qa" | "summarization"

# Model choice: use a small instruction-tuned seq2seq model
MODEL_NAME = "google/flan-t5-small"

# Sequence lengths
MAX_INPUT_LENGTH = 512
MAX_TARGET_LENGTH = 128

# Training hyperparameters
EPOCHS = 3
BATCH_SIZE = 8
LEARNING_RATE = 5e-5
WARMUP_RATIO = 0.03
WEIGHT_DECAY = 0.01
LOGGING_STEPS = 50

# Output directory
OUTPUT_DIR = "./models/flan_t5_small_legal_qa"
os.makedirs(OUTPUT_DIR, exist_ok=True)
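
The hyperparameters above translate into a concrete step budget, which is worth estimating before a long run. The sketch below is an approximation for a single-device run. The helper name `training_budget` and the 9,000-row dataset size are assumptions for illustration, and the exact counts inside `Trainer` may differ slightly between library versions.

```python
import math

def training_budget(num_examples, batch_size, epochs, warmup_ratio, grad_accum=1):
    """Approximate optimizer-step counts for a single-device run."""
    batches_per_epoch = math.ceil(num_examples / batch_size)
    steps_per_epoch = max(1, batches_per_epoch // grad_accum)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return steps_per_epoch, total_steps, warmup_steps

# 9,000 training rows is an assumed size; the other values mirror the config above
steps_per_epoch, total_steps, warmup_steps = training_budget(
    num_examples=9000, batch_size=8, epochs=3, warmup_ratio=0.03
)
print(steps_per_epoch, total_steps, warmup_steps)
```

With `LOGGING_STEPS = 50` driving evaluation and checkpointing as well, a budget of this size implies several dozen evaluation rounds, so raising the interval can save noticeable time on larger datasets.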


## 2. Load Dataset

In [None]:

def load_data_from_csv(csv_path: str) -> Dataset:
    df = pd.read_csv(csv_path)
    # Basic cleanup: drop fully empty rows
    df = df.dropna(how='all')
    # Add an 'id' if not present
    if 'id' not in df.columns:
        df.insert(0, 'id', range(1, len(df) + 1))
    return Dataset.from_pandas(df, preserve_index=False)

def try_load_hf_dataset(name: str, split_train: str, split_valid: str) -> DatasetDict:
    try:
        dsd = load_dataset(name)
    except Exception as e:
        print("HF dataset load failed:", e)
        raise
    if split_train not in dsd:
        raise KeyError(f"Split '{split_train}' not found; available splits: {list(dsd.keys())}")
    # Use the configured validation split if the dataset provides one
    if split_valid in dsd:
        return DatasetDict({"train": dsd[split_train], "validation": dsd[split_valid]})
    # Otherwise reuse an existing test split as the validation set
    if "test" in dsd:
        return DatasetDict({"train": dsd[split_train], "validation": dsd["test"]})
    # Only a training split exists: hold out 10% for validation
    parts = dsd[split_train].train_test_split(test_size=0.1, seed=42)
    return DatasetDict({"train": parts["train"], "validation": parts["test"]})

if DATA_SOURCE == "csv":
    ds_all = load_data_from_csv(CSV_PATH)
    print(ds_all)
    # Split 90/10 if no explicit validation set provided
    ds_all = ds_all.train_test_split(test_size=0.1, seed=42)
    ds = DatasetDict({"train": ds_all["train"], "validation": ds_all["test"]})
else:
    ds = try_load_hf_dataset(HF_DATASET_NAME, HF_SPLIT_TRAIN, HF_SPLIT_VALID)
    print(ds)


## 3. Column Mapping (adapt to your dataset)

In [None]:

# We try to infer common column names for question/answer/context/summary/text.
# You can override manually below if needed.

candidate_cols = {c.lower(): c for c in ds["train"].column_names}
print("Available columns:", list(candidate_cols.values()))

def pick_col(possible_names):
    for name in possible_names:
        if name in candidate_cols:
            return candidate_cols[name]
    return None

# Default guesses (override if incorrect)
QUESTION_COL = pick_col(["question", "query", "prompt", "ques"])
ANSWER_COL   = pick_col(["answer", "response", "target", "label", "output"])
CONTEXT_COL  = pick_col(["context", "passage", "text", "document", "body", "content"])
SUMMARY_COL  = pick_col(["summary", "abstract", "target", "label"])

print("Guessed mapping -> QUESTION:", QUESTION_COL, "| ANSWER:", ANSWER_COL, "| CONTEXT:", CONTEXT_COL, "| SUMMARY:", SUMMARY_COL)

# Manual override examples (uncomment and set if needed):
# QUESTION_COL = "your_question_column"
# ANSWER_COL   = "your_answer_column"
# CONTEXT_COL  = "your_context_or_text_column"
# SUMMARY_COL  = "your_summary_column"
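
The guessing logic is a case-insensitive lookup that returns the first candidate found, so the order of each candidate list matters. The standalone sketch below shows that behaviour on made-up column names. `pick_col_from` is a hypothetical variant of `pick_col` that takes the column list as an argument instead of reading the dataset.

```python
def pick_col_from(columns, possible_names):
    """Return the first column whose lowercased name matches a candidate."""
    by_lower = {c.lower(): c for c in columns}
    for name in possible_names:
        if name in by_lower:
            return by_lower[name]
    return None

# Hypothetical column names, as a CSV export might provide them
columns = ["ID", "Question", "Response", "Text"]
question_col = pick_col_from(columns, ["question", "query", "prompt", "ques"])
answer_col = pick_col_from(columns, ["answer", "response", "target"])
summary_col = pick_col_from(columns, ["summary", "abstract"])
print(question_col, answer_col, summary_col)
```

Because `target` and `label` appear in both the answer and the summary candidate lists, a dataset with such a column gets the same mapping for both, which is worth checking in the printed guess.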


## 4. Basic Cleaning & Preview

In [None]:

def strip_or_none(x):
    if x is None:
        return None
    if isinstance(x, str):
        s = x.strip()
        return s if s else None
    return x

def clean_example(example):
    ex = dict(example)
    for k, v in ex.items():
        if isinstance(v, str):
            ex[k] = v.replace("\u00a0", " ").replace("\t", " ").replace("\r", " ").strip()
    return ex

ds = ds.map(clean_example)

print("Sample records from train split:")
for i in range(min(3, len(ds["train"]))):
    row = ds["train"][i]
    print({k: row.get(k) for k in [QUESTION_COL, ANSWER_COL, CONTEXT_COL, SUMMARY_COL] if k is not None})
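
Legal text copied from PDFs often carries runs of spaces and stray line breaks that the replacements above leave partly in place. As an optional extension, the sketch below collapses every whitespace run into one space. `normalize_whitespace` is a hypothetical helper and not part of the pipeline above.

```python
import re

def normalize_whitespace(text: str) -> str:
    """Replace non-breaking spaces and collapse whitespace runs into single spaces."""
    text = text.replace("\u00a0", " ")
    return re.sub(r"\s+", " ", text).strip()

sample = "Section 378\u00a0IPC\t defines   theft.\r\n"
cleaned = normalize_whitespace(sample)
print(repr(cleaned))
```

Note the trade-off: this also flattens newlines, so paragraph structure is lost, whereas `clean_example` keeps `\n` intact.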


## 5. Build Instruction-style Inputs & Targets

In [None]:

# We convert data into instruction format for FLAN-T5:
#   - QA:        input = f"question: {Q}\ncontext: {C}"  -> target = A
#   - Summarize: input = f"summarize: {TEXT}"            -> target = SUMMARY
#
# If context is missing for QA, we omit it.

def build_qa_text(example):
    q = example.get(QUESTION_COL) if QUESTION_COL else None
    a = example.get(ANSWER_COL)   if ANSWER_COL   else None
    c = example.get(CONTEXT_COL)  if CONTEXT_COL  else None

    q = strip_or_none(q)
    a = strip_or_none(a)
    c = strip_or_none(c)

    if q is None or a is None:
        return {"input_text": None, "target_text": None}

    if c:
        inp = f"question: {q}\ncontext: {c}"
    else:
        inp = f"question: {q}"
    return {"input_text": inp, "target_text": a}

def build_sum_text(example):
    txt = example.get(CONTEXT_COL) or example.get("text") or example.get("document")
    tgt = example.get(SUMMARY_COL)
    txt = strip_or_none(txt)
    tgt = strip_or_none(tgt)
    if txt is None or tgt is None:
        return {"input_text": None, "target_text": None}
    return {"input_text": f"summarize: {txt}", "target_text": tgt}

builder = build_qa_text if TASK_STYLE == "qa" else build_sum_text

ds_fmt = ds.map(builder, remove_columns=ds["train"].column_names)
# Drop rows where either input or target is None
ds_fmt = ds_fmt.filter(lambda x: x["input_text"] is not None and x["target_text"] is not None)

print(ds_fmt)
print("Example formatted record:", ds_fmt["train"][0])
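
To make the prompt layout concrete without loading a dataset, the sketch below reproduces the QA formatting rule on plain strings. `format_qa_prompt` is a hypothetical stand-in for the string-building part of `build_qa_text`, and the context string is a placeholder.

```python
def format_qa_prompt(question, context=None):
    """Mirror the QA layout above: context is appended only when present."""
    if context:
        return f"question: {question}\ncontext: {context}"
    return f"question: {question}"

with_context = format_qa_prompt(
    "What is the punishment for theft under the IPC?",
    "Placeholder passage from the relevant section.",
)
without_context = format_qa_prompt("What is the punishment for theft under the IPC?")
print(repr(with_context))
print(repr(without_context))
```

The same layout has to be used at inference time, since the fine-tuned model only ever saw prompts in this shape.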


## 6. Tokenize

In [None]:

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize_function(examples):
    model_inputs = tokenizer(
        examples["input_text"],
        max_length=MAX_INPUT_LENGTH,
        truncation=True
    )
    # `text_target` tokenizes the labels; it replaces the deprecated
    # `as_target_tokenizer()` context manager in recent transformers releases.
    labels = tokenizer(
        text_target=examples["target_text"],
        max_length=MAX_TARGET_LENGTH,
        truncation=True
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized = ds_fmt.map(tokenize_function, batched=True, remove_columns=ds_fmt["train"].column_names)
print(tokenized)
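
Truncation is silent, so it helps to know how many inputs actually exceed `MAX_INPUT_LENGTH`. The sketch below computes that share from a list of token counts. `truncation_report` is a hypothetical helper and the lengths are illustrative numbers, not measurements from any dataset.

```python
def truncation_report(token_lengths, max_length):
    """Return (share of examples over the limit, longest example)."""
    over = sum(1 for n in token_lengths if n > max_length)
    return over / len(token_lengths), max(token_lengths)

# Illustrative lengths. In the notebook they could come from something like
# [len(tokenizer(t)["input_ids"]) for t in ds_fmt["train"]["input_text"]]
lengths = [120, 340, 512, 700, 1500]
share, longest = truncation_report(lengths, max_length=512)
print(f"{share:.0%} of inputs exceed the limit; longest = {longest} tokens")
```

If a large share of QA inputs is cut off, the answer-bearing part of the context may never reach the model, which is a common reason for poor scores.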


## 7. Model & Data Collator

In [None]:

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
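
`DataCollatorForSeq2Seq` pads label sequences with `-100` by default, a value the loss function ignores, which is why the metrics code later swaps `-100` back to the pad token before decoding. The sketch below imitates just that padding step in plain Python. `pad_labels` is a hypothetical helper and the token ids are made up.

```python
def pad_labels(label_batch, pad_value=-100):
    """Right-pad each label sequence to the longest one in the batch."""
    width = max(len(seq) for seq in label_batch)
    return [seq + [pad_value] * (width - len(seq)) for seq in label_batch]

# Made-up token ids for two targets of different length
batch = [[71, 1525, 1], [363, 19, 8, 1]]
padded = pad_labels(batch)
print(padded)
```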


## 8. Metrics

In [None]:

rouge = evaluate.load("rouge")

def postprocess_text(preds, labels):
    preds = [p.strip() for p in preds]
    labels = [l.strip() for l in labels]
    return preds, labels

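def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Some models return a tuple; the first element holds the predictions
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # -100 marks ignored/padded positions and cannot be decoded, so map it to the pad token
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Report scores on a 0-100 scale
    result = {k: round(v * 100, 2) for k, v in result.items()}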
    return result
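
For intuition about the `rougeL` score used to pick the best checkpoint, the sketch below computes a simplified ROUGE-L F1 from the longest common subsequence of two token lists. It uses lowercased whitespace tokens and no stemming, so its numbers will not match the `rouge` metric exactly. The example sentences are invented.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rouge_l_f1(prediction, reference):
    """Simplified ROUGE-L F1 on lowercased whitespace tokens (no stemming)."""
    pred, ref = prediction.lower().split(), reference.lower().split()
    lcs = lcs_length(pred, ref) if pred and ref else 0
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(pred), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)

score = rouge_l_f1("the court granted bail", "the court granted interim bail")
print(round(score, 3))
```

Here all four predicted words appear in order in the five-word reference, giving precision 1.0 and recall 0.8.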


## 9. Train

In [None]:

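# Seq2SeqTrainer runs generate() during evaluation when predict_with_generate=True,
# so compute_metrics receives token ids it can decode (plain Trainer would pass logits).
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=WARMUP_RATIO,
    logging_steps=LOGGING_STEPS,
    eval_strategy="steps",            # older transformers releases name this `evaluation_strategy`
    eval_steps=LOGGING_STEPS,
    save_strategy="steps",
    save_steps=LOGGING_STEPS,
    load_best_model_at_end=True,
    metric_for_best_model="rougeL",
    greater_is_better=True,
    report_to="none",
    predict_with_generate=True,
    generation_max_length=MAX_TARGET_LENGTH,
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),  # Ampere or newer GPUs
    fp16=False,  # kept off: T5 checkpoints are prone to overflow in fp16
    gradient_accumulation_steps=1,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    tokenizer=tokenizer,              # newer transformers releases prefer `processing_class=tokenizer`
    data_collator=data_collator,
    compute_metrics=compute_metrics
)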

trainer.train()
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("Training complete! Model saved to", OUTPUT_DIR)
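
When no scheduler is specified, the training arguments default to a linear schedule: the learning rate ramps up during warmup and then decays linearly to zero. The sketch below reproduces that shape in plain Python so the effect of `WARMUP_RATIO` is easy to inspect. `linear_schedule_lr` is a hypothetical helper, and the 1,000-step run is an assumed size.

```python
def linear_schedule_lr(step, base_lr, warmup_steps, total_steps):
    """Learning rate at `step` for linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)

# Assumed run: 1,000 total steps with 30 warmup steps (3% warmup) at 5e-5
schedule = {s: linear_schedule_lr(s, 5e-5, 30, 1000) for s in (0, 15, 30, 515, 1000)}
for step, lr in schedule.items():
    print(step, lr)
```

The peak rate is reached exactly at the end of warmup, so a very short run spends most of its steps well below `LEARNING_RATE`.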


## 10. Quick Evaluation / Generation

In [None]:

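def generate_answer(prompt, max_new_tokens=128, num_beams=4):
    # Truncate to the input length used in training so very long prompts do not exhaust memory
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_INPUT_LENGTH).to(model.device)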
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Pick a sample from validation set
sample = ds_fmt["validation"][0]
print("INPUT:\n", sample["input_text"][:500])
print("\nREFERENCE:\n", sample["target_text"][:500])
print("\nPREDICTION:\n", generate_answer(sample["input_text"]))


## 11. Inference Helper

In [None]:

# Reload from disk (optional: for a clean session)
# from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# model = AutoModelForSeq2SeqLM.from_pretrained(OUTPUT_DIR)
# tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)

from typing import Optional

def ask_legal(question: str, context: Optional[str] = None) -> str:
    if TASK_STYLE == "qa":
        if context:
            prompt = f"question: {question}\ncontext: {context}"
        else:
            prompt = f"question: {question}"
    else:
        # Summarize the context when given; otherwise treat `question` as the document
        prompt = f"summarize: {context or question}"
    return generate_answer(prompt)

print(ask_legal("What is the punishment for theft under the IPC?"))



## 12. Tips & Next Steps

- **Column mapping:** If the guessed `QUESTION_COL`, `ANSWER_COL`, etc. are wrong, set them manually in the mapping cell.
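- **Long documents:** Increase `MAX_INPUT_LENGTH` (for example to 1024) if memory allows, or chunk long contexts. T5 uses relative position biases rather than a fixed-size position table, so longer inputs will run, but attention memory grows quadratically with sequence length.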
- **Better models:** Upgrade to `google/flan-t5-base` or instruction-tune a 7B model with **QLoRA** (LoRA + 4-bit) once this pipeline works.
- **Evaluation:** Use a held-out test set. Add metrics like BLEU, exact match, or task-specific scoring.
- **Licensing:** Verify dataset licenses before training for redistribution/commercial use.
- **Reproducibility:** Save `seed`, environment, and commit your config to git.
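
The chunking suggestion for long documents can be sketched as a sliding window over words. `chunk_words` is a hypothetical helper, and the defaults of 300 words per chunk with a 50-word overlap are assumptions chosen to leave headroom below 512 tokens. Word counts only approximate token counts, so the sizes should be checked against the tokenizer.

```python
def chunk_words(text, chunk_size=300, overlap=50):
    """Split text into overlapping windows of whitespace-separated words."""
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        if start + chunk_size >= len(words):
            break  # the last window already reaches the end of the text
    return chunks

doc = " ".join(f"w{i}" for i in range(700))  # synthetic 700-word document
chunks = chunk_words(doc)
print(len(chunks), [len(c.split()) for c in chunks])
```

Each chunk can then be passed to `ask_legal` as its own `context`, with the per-chunk answers compared or merged afterwards.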
