In [None]:
!pip install -U datasets transformers accelerate bitsandbytes peft

In [None]:
from datasets import load_dataset

# The dataset may require authentication first, e.g. `huggingface-cli login`.
# `ds` is the handle used by the formatting cells below; `df` is kept as an
# alias because the following cell converts it for tabular inspection.
ds = load_dataset("SnehaDeshmukh/IndianBailJudgments-1200")
df = ds

In [None]:

import pandas as pd
# Build the inspection frame from the train split itself instead of wrapping the
# whole split dictionary. Reading from `ds` (not from `df`) keeps this cell
# idempotent: re-running it always yields the same frame.
df = ds["train"].to_pandas()

In [None]:
# Preview only the first rows; rendering the full frame bloats the notebook.
df.head()

In [None]:
def format_for_finetune(example):
    """Build the instruction-style training prompt for a single record.

    Args:
        example: Mapping that provides the ``facts``, ``judgment_reason`` and
            ``summary`` fields of one judgment.

    Returns:
        A dict with a single ``"text"`` key holding the four prompt sections
        (Instruction, Fact, Judgment, Summary), separated by blank lines.
    """
    sections = (
        ("### Instruction:", "Summarize the following Indian bail judgment."),
        ("### Fact:", example['facts']),
        ("### Judgment:", example['judgment_reason']),
        ("### Summary:", example['summary']),
    )
    # Each section is "header\nbody"; sections are joined by one blank line.
    prompt = "\n\n".join(f"{header}\n{body}" for header, body in sections)
    return {"text": prompt}

In [None]:
# Display the loaded dataset object as a sanity check before mapping over its
# 'train' split in the next cell.
ds

In [None]:
# Apply the prompt formatter row by row; each row gains the "text" entry
# returned by format_for_finetune.
train_split = ds['train']
formatted_dataset = train_split.map(format_for_finetune)

In [None]:
# Persist the formatted rows to disk; the training cells reload this file.
formatted_dataset.to_json(
    "indian_bail_judgments.jsonl",
    lines=True,
    orient="records",
)

In [None]:
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTTrainer

In [None]:
model_name = "NousResearch/Llama-2-7b-chat-hf"

# Reuse EOS as the padding token so batched inputs can be padded
# (presumably the checkpoint defines no pad token of its own — confirm).
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 quantization with double quantization; matmuls run in float16.
# The config class is `BitsAndBytesConfig`, and it only takes quantization
# options: device placement is an argument of `from_pretrained`, not of it.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",  # lowercase; lets the loader place the weights
)

In [None]:
# Prepare the quantized base model for training, then attach LoRA adapters.
# NOTE: this cell rebinds `model`, so run it once per freshly loaded base model.
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    # Llama attention layers expose separate q_proj / k_proj / v_proj linears.
    # The fused "query_key_value" name belongs to other architectures and
    # matches no module in this model.
    target_modules=["q_proj", "k_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)

In [None]:
# Reload the exported JSONL file and take its "train" split for training.
jsonl_splits = load_dataset("json", data_files="indian_bail_judgments.jsonl")
dataset = jsonl_splits["train"]

def formatting(example):
    """Render one record as the Fact / Judgment / Summary training prompt.

    The layout matches the inference prompt used later in the notebook:
    section headers start at column 0, with no leading indentation.

    Args:
        example: Mapping with ``facts`` and ``summary``, plus the judgment text
            under ``judgment_reason`` (the key the export formatter reads). A
            ``judgment`` key is used only when ``judgment_reason`` is absent.

    Returns:
        The prompt string, starting and ending with a newline.

    Raises:
        KeyError: If ``facts``, ``summary`` or both judgment keys are missing.
    """
    # The exported rows carry the judgment under 'judgment_reason'; keep
    # 'judgment' as a fallback so callers that used the old key still work.
    judgment = (
        example['judgment_reason']
        if 'judgment_reason' in example
        else example['judgment']
    )
    # Built from explicit pieces so the source indentation of this function
    # cannot leak into the training text.
    return (
        "\n"
        f"### Fact:\n{example['facts']}\n\n"
        f"### Judgment:\n{judgment}\n\n"
        f"### Summary:\n{example['summary']}\n"
    )

# Recompute "text" for every row; this replaces the "text" value that was
# stored in the JSONL export with the output of `formatting`.
dataset = dataset.map(lambda row: {"text": formatting(row)})

In [None]:
# Hyperparameters for the LoRA fine-tuning run, grouped by concern.
training_args = TrainingArguments(
    # Checkpointing and logging
    output_dir="./llama-legal-lora",
    save_steps=100,
    save_total_limit=2,
    logging_steps=10,
    # Batch shape: 2 samples per device, accumulated over 4 steps
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    # Optimisation schedule
    num_train_epochs=2,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=10,
    weight_decay=0.01,
    # Mixed precision
    fp16=True,
)

In [None]:
# Supervised fine-tuning on the "text" column of `dataset`.
# NOTE(review): whether `tokenizer`, `dataset_text_field` and `packing` are
# accepted directly by SFTTrainer depends on the installed TRL version —
# confirm against the version in use.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",
    packing=True,  # multiple examples per sequence
)

trainer.train()

In [None]:
# Quick qualitative check: put the model in eval mode and let it continue a
# prompt that stops right after the "### Judgment:" header.
model.eval()

prompt = (
    "\n"
    "### Fact:\n"
    "The accused was found in possession of illegal narcotics and attempted to flee arrest.\n"
    "\n"
    "### Judgment:\n"
)

inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=150)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

In [None]:
# Write the model first, then the tokenizer, into the same directory.
for artifact in (model, tokenizer):
    artifact.save_pretrained("llama2-lawyer-lora")