In [None]:
from huggingface_hub import login
login()

# Dataset Preparation

In [3]:
from datasets import load_dataset

dataset = load_dataset("vishnun0027/Indian-Law")
dataset = dataset["train"]
dataset

Dataset({
    features: ['Instruction', 'Response'],
    num_rows: 25607
})

In [4]:
print("Instruction")
print(dataset[0]["Instruction"])

Instruction
What is the difference between a petition and a plaint in Indian law?


In [5]:
print("Response")
print(dataset[0]["Response"])

Response
A petition is a formal request submitted to a courttribunalor authority to seek a specific remedy or relief. It is commonly used for various purposessuch as filing a writ petition in the High Court or submitting a petition for divorce. On the other handa plaint is a formal written statement of a plaintiff's claim in a civil lawsuit. The key difference is that a petition is more versatile and can be used for various legal matterswhile a plaint is specific to civil cases.


## Remove empty rows

In [6]:
dataset = dataset.filter(lambda x: x["Instruction"] is not None and x["Response"] is not None)
dataset

Dataset({
    features: ['Instruction', 'Response'],
    num_rows: 25600
})

### Dataset cleaning

In [7]:
import re
pattern = r"[###]" # identify the rows which have special characters except essentials
filtered_df = dataset.filter(lambda x: re.match(pattern,x["Instruction"]) or re.match(pattern,x["Response"]) )
filtered_df

Dataset({
    features: ['Instruction', 'Response'],
    num_rows: 1000
})

In [8]:
print(filtered_df[1]["Instruction"])
cleaned_text = re.sub(r'^### (Instruction|Response):.*\n?', '', filtered_df[1]["Instruction"], flags=re.MULTILINE)
print("Cleaned Text")
print(cleaned_text.strip())

### Instruction:
Draft a hypothetical legal petition based on the provided case.

### Response:

Cleaned Text
Draft a hypothetical legal petition based on the provided case.


In [9]:
def clean_df(row):
    if re.match(pattern,row["Instruction"]) or re.match(pattern,row["Response"]):
        cleaned_text = re.sub(r'^### (Instruction|Response):.*\n?', '', row["Instruction"], flags=re.MULTILINE)
        row["Instruction"] = cleaned_text.strip()
    return row
dataset = dataset.map(clean_df)

# Model Setup

In [6]:
# Import necessary libraries
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
import torch

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct",device_map="auto")
# Set our name for the finetune to be saved &/ uploaded to
finetune_name = "SmolLM2-FT-legal-india"
finetune_tags = ["smol", "leagal-india","indian law"]

## Generate with the base model

In [7]:
prompt = "Can a Vakalatnama be revoked or withdrawn in India?"
# Format with template
messages = [{"role": "user", "content": prompt}]
formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False)

# Generate response
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=100)
print("Before training:")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Before training:
system
You are a helpful AI assistant named SmolLM, trained by Hugging Face
user
Can a Vakalatnama be revoked or withdrawn in India?
assistant
Yes, a Vakalatnama can be revoked or withdrawn in India. The Indian Constitution provides for revocation of a Vakalatnama under the provisions of Section 12 of the Indian Constitution. However, the process for revoking a Vakalatnama can be complex and may involve a court hearing, a hearing by a special authority, or a court martial.

The Indian Constitution also provides for the right to a speedy trial


## Dataset Preparation

In [12]:
def apply_chat_template(example):
    messages = [
        {"role": "user", "content": example['Instruction']},
        {"role": "assistant", "content": example['Response']}
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    return {"prompt": prompt}

In [32]:
# Apply the chat templatefunction to the dataset
chat_df = dataset.map(apply_chat_template)
chat_df = chat_df.train_test_split(0.05)

Map:   0%|          | 0/1280 [00:00<?, ? examples/s]

## Tokenize dataset

In [10]:
def tokenize_function(example):
    tokens = tokenizer(example['prompt'], padding="max_length", truncation=True, max_length=128)
    tokens['labels'] = [
        -100 if token == tokenizer.pad_token_id else token for token in tokens['input_ids']
    ]
    return tokens

In [11]:
# Apply tokenize_function to each row
tokenized_dataset = chat_df.map(tokenize_function)
tokenized_dataset = tokenized_dataset.remove_columns(['Instruction', 'Response', 'prompt'])

Map:   0%|          | 0/24320 [00:00<?, ? examples/s]

Map:   0%|          | 0/1280 [00:00<?, ? examples/s]

## Configuring the SFTTrainer

In [12]:
import os
os.environ['WANDB_MODE'] = 'disabled'

In [13]:
# Configure the SFTTrainer
sft_config = SFTConfig(
    output_dir="./sft_output",
    max_steps=1000,  # Adjust based on dataset size and desired training duration
    per_device_train_batch_size=16,  # Set according to your GPU memory capacity
    learning_rate=5e-5,  # Common starting point for fine-tuning
    logging_steps=100,  # Frequency of logging training metrics
    save_steps=200,  # Frequency of saving model checkpoints
    eval_steps=200,  # Frequency of evaluation
    use_mps_device=(
        True if device == "mps" else False
    ),
    hub_model_id=finetune_name,
)

# Initialize the SFTTrainer
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)


## Training the Model

In [14]:
# Train the model
trainer.train()

# Save the model
trainer.save_model(f"./{finetune_name}")



Step,Training Loss
100,1.4868
200,1.2821
300,1.2314
400,1.1932
500,1.1628
600,1.1184
700,1.1099
800,1.0867
900,1.0698
1000,1.0861


In [15]:
trainer.push_to_hub(tags=finetune_tags)

model.safetensors:   0%|          | 0.00/538M [00:00<?, ?B/s]

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

training_args.bin:   0%|          | 0.00/5.62k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/saicharan1010/SmolLM2-FT-legal-india/commit/a3ab5513364d700fda05e4dc18659a2a72ffbaae', commit_message='End of training', commit_description='', oid='a3ab5513364d700fda05e4dc18659a2a72ffbaae', pr_url=None, repo_url=RepoUrl('https://huggingface.co/saicharan1010/SmolLM2-FT-legal-india', endpoint='https://huggingface.co', repo_type='model', repo_id='saicharan1010/SmolLM2-FT-legal-india'), pr_revision=None, pr_num=None)

## Testing the model

In [19]:
prompt = "Can a Vakalatnama be revoked or withdrawn in India?"

messages = [{"role": "user", "content": prompt}]
formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False)

inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

system
You are a helpful AI assistant named SmolLM, trained by Hugging Face
user
Can a Vakalatnama be revoked or withdrawn in India?
assistant
Yes, a Vakalatnama can be revoked or withdrawn in India. According to the context provided, "The revocation of a Vakalatnama shall be in the nature of a law made by the Legislature of the State in which the Vakalatnama is to be found." This means that the revocation or withdrawal of a Vakalatnama is a legal process that can be performed by the State Legislature. However, it is important to note that this is not a legal process that can be performed by the Supreme Court or any other authority in India. The context does not specify any specific procedure for revoking or withdrawing a Vakalatnama. Therefore, it is not possible to determine whether a Vakalatnama can be revoked or withdrawn in India based on this context alone.


## Evaluation of model

In [26]:
import logging
import time
import torch
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset
from transformers.pipelines.pt_utils import KeyDataset
from tqdm.auto import tqdm


dataset = dataset["test"]


device = 0 if torch.cuda.is_available() else -1
tokenizer = AutoTokenizer.from_pretrained("./SmolLM2-FT-legal-india", padding_side="left", truncation=True)
tokenizer.pad_token = tokenizer.eos_token  # Ensure padding token is set

# Update pipeline calls
model_1 = pipeline(
    "text-generation",
    model="./SmolLM2-FT-legal-india",
    tokenizer=tokenizer,
    device=device,
    torch_dtype=torch.float16,
    truncation=True  # Explicitly enable truncation
)

model_2 = pipeline(
    "text-generation",
    model="HuggingFaceTB/SmolLM2-135M-Instruct",
    tokenizer=tokenizer,
    device=device,
    torch_dtype=torch.float16,
    truncation=True  # Explicitly enable truncation
)


# Function to generate responses in batches
def generate_responses(model, dataset, input_column="Instruction", batch_size=16, max_length=200):
    responses = []
    total_time = 0
    print(f"Generating responses for {len(dataset)} samples in batches of {batch_size}...")

    try:
        start_time = time.time()
        for output in tqdm(model(KeyDataset(dataset, input_column), batch_size=batch_size, max_length=max_length, do_sample=True), total=len(dataset)):
            responses.append(output[0]["generated_text"])

        total_time = time.time() - start_time
        avg_time = total_time / len(dataset)
        print(f"Total time: {total_time:.3f} sec | Avg response time: {avg_time:.3f} sec")

    except Exception as e:
        print(f"Error during batch processing: {e}")

    return responses

# Generate responses efficiently
print("Generating responses for Model 1...")
responses_1 = generate_responses(model_1, dataset)

print("Generating responses for Model 2...")
responses_2 = generate_responses(model_2, dataset)


2025-03-01 01:03:31,104 - INFO - Loading dataset...
2025-03-01 01:03:31,105 - INFO - Dataset size: 1280 samples.
2025-03-01 01:03:31,193 - INFO - Loading models...
Device set to use cuda:0
Device set to use cuda:0
2025-03-01 01:03:32,414 - INFO - Models loaded successfully!
2025-03-01 01:03:32,415 - INFO - Generating responses for Model 1...
2025-03-01 01:03:32,415 - INFO - Generating responses for 1280 samples in batches of 16...


  0%|          | 0/1280 [00:00<?, ?it/s]

2025-03-01 01:05:48,591 - INFO - Total time: 136.175 sec | Avg response time: 0.106 sec
2025-03-01 01:05:48,591 - INFO - Generating responses for Model 2...
2025-03-01 01:05:48,591 - INFO - Generating responses for 1280 samples in batches of 16...


  0%|          | 0/1280 [00:00<?, ?it/s]

2025-03-01 01:07:55,627 - INFO - Total time: 127.035 sec | Avg response time: 0.099 sec
2025-03-01 01:07:55,628 - INFO - Response generation completed successfully!


In [34]:
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge

rouge = Rouge()
smooth_func = SmoothingFunction().method1  # Smoothing function to avoid zero BLEU scores

def evaluate_responses(predictions, dataset, ground_truth_column="Response"):
    bleu_scores, rouge_scores = [], []
    
    for pred, gt in zip(predictions, dataset[ground_truth_column]):
        pred_tokens, gt_tokens = pred.split(), gt.split()

        # Compute BLEU score with smoothing
        if pred_tokens:
            bleu_scores.append(sentence_bleu([gt_tokens], pred_tokens, smoothing_function=smooth_func))
        else:
            bleu_scores.append(0.0)  # If prediction is empty, BLEU score is 0

        # Compute ROUGE score
        rouge_scores.append(rouge.get_scores(pred, gt)[0])

    return {
        "BLEU": sum(bleu_scores) / len(bleu_scores),
        "ROUGE": sum([score["rouge-l"]["f"] for score in rouge_scores]) / len(rouge_scores),
    }

# Example usage
metrics_1 = evaluate_responses(responses_1, dataset)
metrics_2 = evaluate_responses(responses_2, dataset)

print("Model 1 (fine tunned):", metrics_1)
print("Model 2:", metrics_2)


Model 1 (fine tunned): {'BLEU': 0.12600565632766575, 'ROUGE': 0.3042493416159143}
Model 2: {'BLEU': 0.1206421911321742, 'ROUGE': 0.34455162376518916}
