# 🦙 Llama 3 Fine-Tuning with LoRA for Legal Reasoning

This notebook fine-tunes Llama 3 using LoRA/QLoRA for Indian legal tasks:
- Legal question answering
- Case summarization
- Argument drafting
- Legal language simplification

**Based on:** Aalap project methodology

In [None]:
# Setup
import sys
sys.path.insert(0, '..')

import torch
import json
from pathlib import Path

from src.training import LlamaLoRATrainer, LoRAConfig
from src.utils import set_seed, setup_logging

# Configuration
set_seed(42)
setup_logging(log_level="INFO")

# Check GPU
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    print("Using Apple MPS")

## 1. Configure LoRA Parameters

In [None]:
# LoRA Configuration
config = LoRAConfig(
    # LoRA hyperparameters
    r=16,                    # Rank (8, 16, 32, 64)
    lora_alpha=32,           # Scaling factor
    lora_dropout=0.05,       # Dropout
    
    # Target all linear modules for better reasoning
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    
    # Quantization (QLoRA)
    use_4bit=True,           # Enable 4-bit for memory efficiency
    
    # Training
    learning_rate=2e-4,
    batch_size=4,
    gradient_accumulation_steps=4,  # Effective batch size = 16
    num_epochs=3,
    max_seq_length=2048,
)

print("LoRA Configuration:")
print(f"  Rank (r): {config.r}")
print(f"  Alpha: {config.lora_alpha}")
print(f"  4-bit quantization: {config.use_4bit}")
print(f"  Target modules: {len(config.target_modules)} module types per layer")
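
To get a feel for what this configuration costs, the sketch below counts the trainable weights LoRA adds and derives the two quantities implied by the settings above (the `alpha / r` scaling and the effective batch size). The layer dimensions are assumptions taken from the public Llama 3 8B model config, not from this notebook, and `lora_trainable_params` is a hypothetical helper rather than part of `src.training`.

```python
# Assumed Llama 3 8B dimensions (public model config; not defined in this notebook).
HIDDEN = 4096          # hidden_size
INTERMEDIATE = 14336   # intermediate_size (MLP width)
KV_DIM = 1024          # 8 key/value heads x 128 head_dim (grouped-query attention)
NUM_LAYERS = 32        # num_hidden_layers

# (in_features, out_features) for each targeted projection
MODULE_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_trainable_params(r, target_modules, num_layers=NUM_LAYERS):
    """Count LoRA weights: each adapted matrix gets A (r x in) and B (out x r)."""
    per_layer = sum(
        r * (MODULE_SHAPES[name][0] + MODULE_SHAPES[name][1])
        for name in target_modules
    )
    return per_layer * num_layers


targets = list(MODULE_SHAPES)
trainable = lora_trainable_params(r=16, target_modules=targets)
scaling = 32 / 16            # lora_alpha / r
effective_batch = 4 * 4      # batch_size * gradient_accumulation_steps

print(f"Trainable LoRA parameters: {trainable:,}")
print(f"LoRA scaling (alpha / r): {scaling}")
print(f"Effective batch size: {effective_batch}")
```

Under these assumptions, targeting all seven projections at `r=16` trains roughly 42 million weights, about half a percent of the 8B base model. Doubling `r` doubles that count, so the rank is the main knob for trading capacity against memory.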

## 2. Prepare Instruction Dataset (Aalap Format)

In [None]:
# Sample instruction data (Aalap format)
SAMPLE_INSTRUCTIONS = [
    {
        "instruction": "Summarize the key legal issues in this case.",
        "input": """The petitioner has filed a writ petition under Article 226 challenging 
        the order of termination from service. The petitioner was employed as a Grade B 
        officer and was terminated without a show cause notice. The petitioner contends 
        that the termination violates the principles of natural justice.""",
        "output": """This case involves a challenge to service termination under Article 226. 
        The key issues are: (1) Whether the termination without show cause notice violates 
        principles of natural justice, and (2) Whether the petitioner has a right to be heard 
        before adverse action."""
    },
    {
        "instruction": "Draft an argument for the petitioner based on Section 302 IPC.",
        "input": """The accused was charged under Section 302 IPC for murder. The prosecution 
        claims the accused shot the deceased during a property dispute. The defense argues 
        that the accused acted in self-defense when the deceased attacked him with a knife.""",
        "output": """Argument for the Petitioner:

1. The prosecution has failed to establish mens rea (guilty intention) beyond reasonable doubt.

2. The accused's actions fall within the ambit of private defense under Sections 96-106 IPC, as:
   - The deceased was the aggressor
   - The threat to life was imminent
   - The force used was proportionate

3. In Darshan Singh v. State of Punjab (2010) 2 SCC 333, this Hon'ble Court held that the 
right of private defense extends to causing death when there is reasonable apprehension.

4. The charge should be reduced to Section 304 Part I (culpable homicide not amounting to murder)."""
    },
    {
        "instruction": "Simplify this legal text for a common person.",
        "input": """The res judicata principle operates as a bar to subsequent litigation 
        where the matter in issue has been directly and substantially in issue in a former 
        suit between the same parties and has been heard and finally decided.""",
        "output": """In simple terms: Once a court has made a final decision on a dispute 
        between two parties, those same parties cannot file another case about the same issue 
        in any court. This is to prevent endless litigation and ensure finality."""
    },
    {
        "instruction": "What are the legal grounds for bail under Section 437 CrPC?",
        "input": "",
        "output": """Under Section 437 CrPC, bail may be granted in non-bailable offences when:

1. There are reasonable grounds to believe the accused is not guilty
2. The accused is a woman, child under 16, or sick/infirm person
3. The investigation is complete and no custodial interrogation is needed
4. Maximum punishment is less than 7 years imprisonment

However, bail cannot be granted if:
- The accused is charged with an offence punishable with death or life imprisonment
- The accused has previous convictions for similar offences
- There is risk of tampering with evidence or influencing witnesses"""
    }
]

# Save sample data
sample_data_path = Path("../data/corpora/sample_instructions.jsonl")
sample_data_path.parent.mkdir(parents=True, exist_ok=True)

with open(sample_data_path, 'w', encoding='utf-8') as f:
    for item in SAMPLE_INSTRUCTIONS:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

print(f"Saved {len(SAMPLE_INSTRUCTIONS)} instruction samples to {sample_data_path}")
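
The triple-quoted strings above carry the source indentation and line breaks into the saved records. A minimal sketch of how such records might be cleaned and rendered into a training prompt follows. The Alpaca-style template and the helpers `clean_text` and `format_record` are hypothetical; the template that `LlamaLoRATrainer` actually applies lives in `src.training` and may differ.

```python
REQUIRED_KEYS = ("instruction", "input", "output")

# Hypothetical Alpaca-style templates, used here only for illustration.
PROMPT_WITH_INPUT = (
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n\n"
    "### Response:\n{output}"
)
PROMPT_NO_INPUT = "### Instruction:\n{instruction}\n\n### Response:\n{output}"


def clean_text(text):
    """Strip the indentation that triple-quoted literals pick up, line by line."""
    return "\n".join(line.strip() for line in text.strip().splitlines())


def format_record(record):
    """Validate one instruction record and render it as a single prompt string."""
    missing = [key for key in REQUIRED_KEYS if key not in record]
    if missing:
        raise ValueError(f"Record is missing keys: {missing}")
    fields = {key: clean_text(record[key]) for key in REQUIRED_KEYS}
    # Records with an empty "input" use the shorter template.
    template = PROMPT_WITH_INPUT if fields["input"] else PROMPT_NO_INPUT
    return template.format(**fields)


example = {
    "instruction": "Simplify this legal text for a common person.",
    "input": """The res judicata principle operates as a bar 
        to subsequent litigation.""",
    "output": "A dispute that a court has finally decided cannot be reopened.",
}
prompt = format_record(example)
print(prompt)
```

Failing loudly on a missing key is a deliberate choice here: a silently skipped field would otherwise surface much later as a malformed prompt during training.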

## 3. Initialize Trainer

In [None]:
# Initialize trainer
# Note: Llama 3 weights are gated. Request access from Meta on Hugging Face and log in
# before loading, or use a smaller open model for testing.

MODEL_NAME = "meta-llama/Meta-Llama-3-8B"  # Requires access
# Alternative for testing: "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

trainer = LlamaLoRATrainer(
    model_name=MODEL_NAME,
    config=config,
    output_dir="../models/legal_llama"
)

print(f"Trainer initialized for: {MODEL_NAME}")

## 4. Load Model (QLoRA)

In [None]:
# ⚠️ Uncomment to load (requires roughly 6-8 GB of GPU memory for the 8B model in 4-bit)
# trainer.load_model()

print("Model loading code ready!")
print("\nMemory requirements (approximate):")
print("  - Llama 3 8B (4-bit): ~6-8 GB VRAM")
print("  - Llama 3 70B (4-bit): ~40 GB VRAM")
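
The figures above can be sanity-checked with a back-of-the-envelope estimate of the weight storage alone. This is a rough sketch: the parameter counts (8.03B and 70.6B) are assumptions based on the public model cards, and `weight_memory_gb` is a hypothetical helper.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Memory for the weights alone, in GB (1 GB = 1e9 bytes, as in the setup cell)."""
    return num_params * bits_per_param / 8 / 1e9


# Approximate parameter counts (assumption, from the public model cards).
MODELS = {"Llama 3 8B": 8.03e9, "Llama 3 70B": 70.6e9}

for name, num_params in MODELS.items():
    fp16 = weight_memory_gb(num_params, 16)
    four_bit = weight_memory_gb(num_params, 4)
    print(f"{name}: fp16 ~{fp16:.1f} GB, 4-bit ~{four_bit:.1f} GB (weights only)")
```

The weights-only numbers come out lower than the VRAM figures quoted above. The difference is plausibly taken up by quantization constants, activations, the LoRA gradients and optimizer state, and the CUDA context, all of which grow with `batch_size` and `max_seq_length`.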

## 5. Prepare Dataset

In [None]:
# Prepare dataset from JSONL file
# dataset = trainer.prepare_dataset(str(sample_data_path))
# print(f"Dataset prepared with {len(dataset)} samples")

print("Dataset preparation code ready!")
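
Before training on a larger file, it can help to hold out a few records for evaluation. The sketch below shows one possible deterministic split over a list of already-loaded records; `train_eval_split` is a hypothetical helper, and whether `trainer.prepare_dataset` performs its own split is not something this notebook shows.

```python
import random


def train_eval_split(records, eval_fraction=0.1, seed=42):
    """Shuffle with a fixed seed and hold out a fraction for evaluation."""
    if not 0.0 < eval_fraction < 1.0:
        raise ValueError("eval_fraction must be between 0 and 1")
    if len(records) < 2:
        raise ValueError("need at least two records to split")
    shuffled = list(records)               # copy so the caller's list is untouched
    random.Random(seed).shuffle(shuffled)  # local RNG, global random state unchanged
    n_eval = max(1, round(len(shuffled) * eval_fraction))
    return shuffled[n_eval:], shuffled[:n_eval]


train_records, eval_records = train_eval_split(list(range(10)), eval_fraction=0.2)
print(f"train: {len(train_records)}, eval: {len(eval_records)}")
```

Using a dedicated `random.Random(seed)` instance keeps the split reproducible without interfering with the global seed set by `set_seed(42)`.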

## 6. Train with LoRA

In [None]:
# ⚠️ Uncomment to train
# metrics = trainer.train(dataset)
# print(f"Training complete! Final metrics: {metrics}")

print("Training code ready!")
print("\nExpected training time:")
print("  - 1000 samples, 3 epochs, A100: ~30 minutes")
print("  - 1000 samples, 3 epochs, RTX 3090: ~2 hours")
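
The time estimates above are easier to compare across runs when expressed per optimizer step. The sketch below derives an approximate step count from the configuration in section 1. It rounds the final partial batch up, which is an assumption: the exact count depends on how the underlying trainer handles the last incomplete accumulation window.

```python
import math


def optimizer_steps(num_samples, batch_size, grad_accum, num_epochs):
    """Approximate number of optimizer updates, rounding the last partial batch up."""
    effective_batch = batch_size * grad_accum
    return math.ceil(num_samples / effective_batch) * num_epochs


steps = optimizer_steps(num_samples=1000, batch_size=4, grad_accum=4, num_epochs=3)
print(f"Optimizer steps: {steps}")

# Per-step time implied by the estimates quoted above.
for gpu, minutes in (("A100", 30), ("RTX 3090", 120)):
    print(f"  {gpu}: ~{minutes * 60 / steps:.1f} s per step")
```

With 1000 samples and an effective batch size of 16, this comes to 189 updates over three epochs, which is few enough that learning-rate warmup and logging intervals should be set in steps rather than in large fixed numbers.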

## 7. Save LoRA Adapter

In [None]:
# ⚠️ Uncomment after training
# trainer.save_model()

# Optional: Merge LoRA with base model for deployment
# trainer.merge_and_save("../models/legal_llama_merged")

print("Save code ready!")
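
Merging folds the adapter into the base weights so that inference no longer needs the separate LoRA matrices. In the standard LoRA formulation the merged weight is `W + (alpha / r) * B @ A`. The toy example below checks, in plain Python, that the merged matrix gives the same output as the base path plus the adapter path. It illustrates the arithmetic only; that `trainer.merge_and_save` follows exactly this formulation is an assumption.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(w, a, b, alpha, r):
    """Return W + (alpha / r) * B @ A."""
    scale = alpha / r
    delta = matmul(b, a)
    return [
        [w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
        for w_row, d_row in zip(w, delta)
    ]


# Toy shapes: W is 2x3 (out x in), rank r = 1.
W = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]
A = [[1.0, 2.0, 3.0]]         # r x in
B = [[0.5], [-1.0]]           # out x r
x = [[1.0], [1.0], [1.0]]     # in x 1 input column
alpha, r = 2, 1

# Path 1: merge first, then apply.
y_merged = matmul(merge_lora(W, A, B, alpha, r), x)

# Path 2: base output plus scaled adapter output.
y_base = matmul(W, x)
y_lora = matmul(B, matmul(A, x))
y_adapter = [[yb[0] + (alpha / r) * yl[0]] for yb, yl in zip(y_base, y_lora)]

print(y_merged, y_adapter)
```

Keeping the adapter unmerged is usually the more flexible option while experimenting, since one base model can then serve several adapters; merging mainly pays off when deploying a single fixed model.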

## 8. Inference Example

In [None]:
# Example inference (after training)
def test_legal_reasoning():
    instruction = "What are the essential ingredients of Section 420 IPC (cheating)?"
    
    response = trainer.generate(
        instruction=instruction,
        input_text="",
        max_new_tokens=256,
        temperature=0.7
    )
    
    print(f"Instruction: {instruction}")
    print(f"\nResponse:\n{response}")

# ⚠️ Uncomment after training
# test_legal_reasoning()

print("Inference code ready!")
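
The `temperature` argument controls how sharply the model's next-token distribution is peaked. The sketch below applies temperature scaling to a made-up set of three logits to show the effect; it is a generic illustration of the technique, not the sampling code inside `trainer.generate`.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [value / temperature for value in logits]
    peak = max(scaled)                              # subtract the max for stability
    exps = [math.exp(value - peak) for value in scaled]
    total = sum(exps)
    return [value / total for value in exps]


logits = [2.0, 1.0, 0.5]   # illustrative values only
for temperature in (1.0, 0.7, 0.3):
    probs = softmax_with_temperature(logits, temperature)
    formatted = ", ".join(f"{p:.3f}" for p in probs)
    print(f"T={temperature}: [{formatted}]")
```

Lower temperatures concentrate probability on the top token, which tends to suit legal question answering where consistency matters more than variety; higher values give more varied drafts.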

## 9. Using Groq for Fast Inference

For production deployment, Groq's hosted LPU inference is a low-latency option for serving Llama 3.

In [None]:
import os

# Set your Groq API key
# os.environ["GROQ_API_KEY"] = "your-api-key-here"

def query_groq_llama(prompt: str):
    """Query Llama 3 via Groq API."""
    try:
        from groq import Groq
        
        client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        
        response = client.chat.completions.create(
            model="llama3-70b-8192",
            messages=[
                {"role": "system", "content": "You are an AI legal assistant specialized in Indian law."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3,
            max_tokens=1024
        )
        
        return response.choices[0].message.content
    except Exception as e:
        return f"Error: {e}"

# Test Groq inference
# result = query_groq_llama("Explain the concept of 'bail' in Indian criminal law.")
# print(result)

print("Groq inference code ready!")
print("Set GROQ_API_KEY environment variable to use.")
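
Because `query_groq_llama` returns any failure as a string, a missing key or an empty prompt only shows up after a network round trip. One option is to validate both up front. The helpers below, `require_api_key` and `build_messages`, are hypothetical additions and use only the standard library.

```python
import os

SYSTEM_PROMPT = "You are an AI legal assistant specialized in Indian law."


def require_api_key(env_var="GROQ_API_KEY"):
    """Return the API key, or fail early with a clear message if it is unset."""
    key = os.getenv(env_var)
    if not key:
        raise RuntimeError(f"{env_var} is not set; export it before calling the API.")
    return key


def build_messages(prompt, system_prompt=SYSTEM_PROMPT):
    """Build the chat message list, rejecting empty prompts."""
    prompt = prompt.strip()
    if not prompt:
        raise ValueError("prompt must not be empty")
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]


messages = build_messages("  Explain the concept of 'bail' in Indian criminal law.  ")
print(messages[1]["content"])
```

The resulting list can be passed as the `messages` argument in the cell above, and `require_api_key()` can replace the bare `os.getenv` call so that a missing key raises immediately.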

## Next Steps

1. **Get Aalap Dataset**: Download from [HuggingFace](https://huggingface.co/datasets/opennyaiorg/aalap_instruction_dataset)
2. **Fine-tune**: Run training with full dataset
3. **Evaluate**: Test on legal reasoning benchmarks
4. **Deploy**: Use vLLM or Groq for serving