# Complete Guide to Fine-tuning LLMs for Legal Applications

## 🎯 Learning Objectives

By the end of this notebook, you will understand:

1. **What is LLM Fine-tuning?** - Core concepts and terminology
2. **Legal Data Characteristics** - Understanding domain-specific challenges
3. **Data Preparation** - Formatting data for instruction-following models
4. **Model Selection** - Choosing the right base model for legal tasks
5. **Training Process** - Step-by-step fine-tuning with practical code
6. **Evaluation** - Measuring model performance objectively
7. **Deployment** - Using your fine-tuned model for real applications

## 🚀 What You'll Build

We'll fine-tune Google's FLAN-T5 model on Indian Constitutional law Q&A data to create a specialized legal assistant that can:
- Answer questions about Indian constitutional provisions
- Understand legal terminology and context
- Provide accurate, relevant responses to legal queries

Let's start this exciting journey into legal AI! 🧑‍⚖️🤖

## 1. Understanding LLM Fine-tuning

### What is Fine-tuning?

**Fine-tuning** is like giving a smart student (pre-trained model) additional specialized training in a specific subject (legal domain).

```
Pre-trained Model → Legal Data Training → Specialized Legal Model
     (General)           (Domain)              (Expert)
```

### Why Fine-tune for Legal Applications?

1. **Specialized Vocabulary**: Legal text contains unique terminology (jurisprudence, habeas corpus, etc.)
2. **Complex Reasoning**: Legal questions often require understanding precedents and context
3. **Precision Requirements**: Legal applications demand high accuracy
4. **Domain Knowledge**: Understanding constitutional principles, legal procedures

### Types of Fine-tuning

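- **Full Fine-tuning**: Update all model parameters (computationally expensive)
- **Parameter-Efficient**: Freeze most weights and train only a small number of added or selected parameters (LoRA, AdaLoRA)
- **Instruction Tuning**: Train on instruction-formatted examples so the model learns to follow task prompts; it can be combined with either approach above

For this tutorial, we'll apply **instruction tuning** by fully fine-tuning a small model on instruction-formatted Q&A pairs, which suits Q&A tasks well.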
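
To make the cost difference between full and parameter-efficient fine-tuning concrete, the sketch below counts trainable parameters for a single weight matrix under both approaches. The 512 x 512 shape and rank 8 are illustrative assumptions, not values read from this notebook's model.

```python
def full_params(d_out: int, d_in: int) -> int:
    """Trainable parameters when the whole d_out x d_in matrix is updated."""
    return d_out * d_in


def lora_params(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters for a low-rank update B @ A,
    where B is d_out x rank and A is rank x d_in."""
    return rank * (d_out + d_in)


# Illustrative sizes (assumption): one square projection matrix, LoRA rank 8
d_out, d_in, rank = 512, 512, 8
full = full_params(d_out, d_in)
lora = lora_params(d_out, d_in, rank)

print(f"Full fine-tuning: {full:,} trainable parameters")
print(f"LoRA (rank {rank}):  {lora:,} trainable parameters ({lora / full:.1%} of full)")
```

The same arithmetic applies to every matrix that receives an adapter, which is why low-rank methods become attractive as models grow.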

In [None]:
# Let's start by importing all necessary libraries
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import re
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

# Set up plotting style
plt.style.use('default')
sns.set_palette("husl")

print("📚 Libraries imported successfully!")
print("🎯 Ready to explore legal LLM fine-tuning!")

## 2. Exploring Our Legal Dataset

Before we can fine-tune a model, we need to understand our data. Let's explore the Indian Constitutional Q&A dataset.

### Key Questions to Answer:
- How many Q&A pairs do we have?
- What's the typical length of questions and answers?
- What legal terms appear most frequently?
- Are there patterns in how questions are structured?

In [None]:
# Load the constitutional Q&A dataset
print("📖 Loading Indian Constitutional Q&A Dataset...")

with open('constitution_qa.json', 'r', encoding='utf-8') as f:
    legal_data = json.load(f)

print(f"✅ Dataset loaded successfully!")
print(f"📊 Total Q&A pairs: {len(legal_data)}")

# Display first few examples to understand the structure
print("\n🔍 Sample Q&A pairs:")
print("=" * 80)

for i, item in enumerate(legal_data[:3]):
    print(f"\n📝 Example {i+1}:")
    print(f"❓ Question: {item['question']}")
    print(f"💡 Answer: {item['answer']}")
    print("-" * 60)

In [None]:
# Convert to DataFrame for easier analysis
df = pd.DataFrame(legal_data)

print("📊 Dataset Structure Analysis")
print("=" * 40)
print(f"Dataset shape: {df.shape}")
print(f"Columns: {list(df.columns)}")

# Analyze text characteristics
df['question_length_chars'] = df['question'].str.len()
df['answer_length_chars'] = df['answer'].str.len()
df['question_words'] = df['question'].str.split().str.len()
df['answer_words'] = df['answer'].str.split().str.len()

print("\n📏 Text Length Statistics:")
stats_df = df[['question_length_chars', 'answer_length_chars', 'question_words', 'answer_words']].describe()
print(stats_df)

# Check for any data quality issues
print(f"\n🔍 Data Quality Check:")
print(f"Questions with missing data: {df['question'].isna().sum()}")
print(f"Answers with missing data: {df['answer'].isna().sum()}")
print(f"Empty questions: {(df['question'].str.strip() == '').sum()}")
print(f"Empty answers: {(df['answer'].str.strip() == '').sum()}")
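
The checks above only report problems. A minimal cleaning sketch follows, assuming the records are the plain dictionaries loaded from JSON (so missing values are `None`). The helper name `clean_qa_pairs` and the toy records are hypothetical.

```python
def clean_qa_pairs(records):
    """Drop records with missing or blank fields and repeated questions."""
    seen_questions = set()
    cleaned = []
    for rec in records:
        question = (rec.get("question") or "").strip()
        answer = (rec.get("answer") or "").strip()
        if not question or not answer:
            continue  # skip incomplete pairs
        key = question.lower()
        if key in seen_questions:
            continue  # skip case-insensitive duplicate questions
        seen_questions.add(key)
        cleaned.append({"question": question, "answer": answer})
    return cleaned


# Toy records for illustration only
sample_records = [
    {"question": "Placeholder question one?", "answer": "Placeholder answer."},
    {"question": "placeholder question one? ", "answer": "Duplicate entry."},
    {"question": "   ", "answer": "Answer without a question."},
    {"question": "Placeholder question two?", "answer": None},
]
print(clean_qa_pairs(sample_records))
```

Whether duplicates should be dropped or merged is a design choice; dropping is the simplest option and keeps the first answer seen.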

In [None]:
# Create comprehensive visualizations
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle('📊 Legal Dataset Analysis Dashboard', fontsize=16, fontweight='bold')

# 1. Question length distribution
axes[0, 0].hist(df['question_words'], bins=30, alpha=0.7, color='skyblue', edgecolor='black')
axes[0, 0].set_title('📝 Question Length Distribution', fontweight='bold')
axes[0, 0].set_xlabel('Number of words')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].axvline(df['question_words'].mean(), color='red', linestyle='--', 
                   label=f'Mean: {df["question_words"].mean():.1f}')
axes[0, 0].legend()

# 2. Answer length distribution
axes[0, 1].hist(df['answer_words'], bins=30, alpha=0.7, color='lightgreen', edgecolor='black')
axes[0, 1].set_title('💡 Answer Length Distribution', fontweight='bold')
axes[0, 1].set_xlabel('Number of words')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].axvline(df['answer_words'].mean(), color='red', linestyle='--', 
                   label=f'Mean: {df["answer_words"].mean():.1f}')
axes[0, 1].legend()

# 3. Question vs Answer length relationship
scatter = axes[0, 2].scatter(df['question_words'], df['answer_words'], alpha=0.6, c='purple')
axes[0, 2].set_title('🔗 Question vs Answer Length', fontweight='bold')
axes[0, 2].set_xlabel('Question length (words)')
axes[0, 2].set_ylabel('Answer length (words)')

# Add correlation coefficient
correlation = df['question_words'].corr(df['answer_words'])
axes[0, 2].text(0.05, 0.95, f'Correlation: {correlation:.3f}', 
                transform=axes[0, 2].transAxes, bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow"))

# 4. Box plot comparison
data_for_box = [df['question_words'], df['answer_words']]
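bp = axes[1, 0].boxplot(data_for_box, patch_artist=True)
# Set tick labels separately; the `labels` keyword is deprecated in recent matplotlib releases
axes[1, 0].set_xticklabels(['Questions', 'Answers'])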
bp['boxes'][0].set_facecolor('skyblue')
bp['boxes'][1].set_facecolor('lightgreen')
axes[1, 0].set_title('📦 Length Distribution Comparison', fontweight='bold')
axes[1, 0].set_ylabel('Number of words')

# 5. Cumulative distribution
sorted_q = np.sort(df['question_words'])
sorted_a = np.sort(df['answer_words'])
y_q = np.arange(1, len(sorted_q) + 1) / len(sorted_q)
y_a = np.arange(1, len(sorted_a) + 1) / len(sorted_a)

axes[1, 1].plot(sorted_q, y_q, label='Questions', linewidth=2)
axes[1, 1].plot(sorted_a, y_a, label='Answers', linewidth=2)
axes[1, 1].set_title('📈 Cumulative Distribution', fontweight='bold')
axes[1, 1].set_xlabel('Number of words')
axes[1, 1].set_ylabel('Cumulative probability')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

# 6. Data summary pie chart
lengths = {
    'Short Questions (<10 words)': (df['question_words'] < 10).sum(),
    'Medium Questions (10-25 words)': ((df['question_words'] >= 10) & (df['question_words'] <= 25)).sum(),
    'Long Questions (>25 words)': (df['question_words'] > 25).sum()
}

axes[1, 2].pie(lengths.values(), labels=lengths.keys(), autopct='%1.1f%%', startangle=90)
axes[1, 2].set_title('📊 Question Length Categories', fontweight='bold')

plt.tight_layout()
plt.show()

print(f"\n🎯 Key Insights:")
print(f"• Average question length: {df['question_words'].mean():.1f} words")
print(f"• Average answer length: {df['answer_words'].mean():.1f} words")
print(f"• Most questions are between {df['question_words'].quantile(0.25):.0f} and {df['question_words'].quantile(0.75):.0f} words")
print(f"• Correlation between question and answer length: {correlation:.3f}")

In [None]:
# Analyze legal terminology and patterns
print("⚖️ Legal Terminology Analysis")
print("=" * 40)

# Combine all text for analysis
all_questions = ' '.join(df['question'].tolist())
all_answers = ' '.join(df['answer'].tolist())

# Define legal terms pattern - more comprehensive
legal_terms_pattern = r'\b(?:Parliament|Constitution|State|Union|Bill|Act|Article|Schedule|Amendment|Court|Justice|Law|Right|Duty|Citizen|Government|Territory|President|Governor|Minister|Legislature|Judiciary|Executive|Council|Assembly|Election|Democracy|Republic|Federal|Fundamental|Directive|Emergency|Ordinance|Writ|Petition|Appeal|Jurisdiction|Sovereignty|Secularism|Socialism)\b'

# Extract legal terms
legal_terms_questions = re.findall(legal_terms_pattern, all_questions, re.IGNORECASE)
legal_terms_answers = re.findall(legal_terms_pattern, all_answers, re.IGNORECASE)

print("🔍 Most Common Legal Terms in Questions:")
question_counter = Counter([term.title() for term in legal_terms_questions])
for i, (term, count) in enumerate(question_counter.most_common(15), 1):
    print(f"  {i:2d}. {term:<15} : {count:4d} occurrences")

print("\n💡 Most Common Legal Terms in Answers:")
answer_counter = Counter([term.title() for term in legal_terms_answers])
for i, (term, count) in enumerate(answer_counter.most_common(15), 1):
    print(f"  {i:2d}. {term:<15} : {count:4d} occurrences")

# Analyze question patterns
print("\n❓ Question Pattern Analysis:")
question_starters = []
for question in df['question']:
    words = question.split()
    if len(words) >= 3:
        starter = ' '.join(words[:3]).lower()
        question_starters.append(starter)

pattern_counter = Counter(question_starters)
print("\n🔄 Most Common Question Patterns (first 3 words):")
for i, (pattern, count) in enumerate(pattern_counter.most_common(12), 1):
    print(f"  {i:2d}. '{pattern.title():<20}' : {count:3d} times")

# Visualize top legal terms
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Top terms in questions
top_q_terms = dict(question_counter.most_common(10))
ax1.barh(list(top_q_terms.keys()), list(top_q_terms.values()), color='skyblue')
ax1.set_title('🔍 Top Legal Terms in Questions', fontweight='bold')
ax1.set_xlabel('Frequency')

# Top terms in answers
top_a_terms = dict(answer_counter.most_common(10))
ax2.barh(list(top_a_terms.keys()), list(top_a_terms.values()), color='lightgreen')
ax2.set_title('💡 Top Legal Terms in Answers', fontweight='bold')
ax2.set_xlabel('Frequency')

plt.tight_layout()
plt.show()

## 3. Understanding the Fine-tuning Process

### Step-by-Step Breakdown

```mermaid
graph TD
    A[Raw Legal Data] --> B[Data Preprocessing]
    B --> C[Instruction Formatting]
    C --> D[Train/Val/Test Split]
    D --> E[Tokenization]
    E --> F[Model Loading]
    F --> G[Training Loop]
    G --> H[Validation]
    H --> I[Model Evaluation]
    I --> J[Deployment]
```

### Key Concepts:

1. **Instruction Formatting**: Converting Q&A pairs into instruction-following format
2. **Tokenization**: Converting text into the integer token IDs the model operates on
3. **Training Loop**: Iteratively updating model weights based on legal data
4. **Validation**: Monitoring performance to prevent overfitting
5. **Evaluation**: Measuring how well the model performs on unseen data
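
The validation step is usually paired with a stopping rule. The sketch below is a simplified, library-independent version of patience-based early stopping; the function name `should_stop` and the loss values are hypothetical, and real trainers implement their own variant of this logic.

```python
def should_stop(val_losses, patience=2, min_delta=0.0):
    """Return True when the last `patience` evaluations all failed to beat
    the best earlier validation loss by more than `min_delta`."""
    if len(val_losses) <= patience:
        return False  # not enough history yet
    best_before = min(val_losses[:-patience])
    recent = val_losses[-patience:]
    return all(loss >= best_before - min_delta for loss in recent)


# Hypothetical validation losses, one per evaluation
history = [1.90, 1.60, 1.50, 1.55, 1.58]
for i in range(1, len(history) + 1):
    print(f"after eval {i}: stop={should_stop(history[:i])}")
```

A larger `patience` tolerates noisy validation curves at the cost of extra training time.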

In [None]:
# Step 1: Format data for instruction-following
print("🔧 Data Formatting for Instruction-Following")
print("=" * 50)

def format_legal_instruction(question, answer, task_type="constitutional_qa"):
    """
    Format a legal Q&A pair into instruction-following format for T5
    
    Args:
        question (str): The legal question
        answer (str): The corresponding answer
        task_type (str): Type of legal task
    
    Returns:
        dict: Formatted instruction with input and target text
    """
    # Create an instruction that clearly defines the task
    if task_type == "constitutional_qa":
        instruction = f"Answer the following question about Indian constitutional law: {question}"
    else:
        instruction = f"Answer this legal question: {question}"
    
    return {
        'input_text': instruction,
        'target_text': answer,
        'task_type': task_type,
        'original_question': question,
        'original_answer': answer
    }

# Apply formatting to our dataset
print("🔄 Formatting legal Q&A pairs...")
formatted_legal_data = []

for i, item in enumerate(legal_data):
    formatted_item = format_legal_instruction(
        item['question'], 
        item['answer'], 
        "constitutional_qa"
    )
    formatted_legal_data.append(formatted_item)

print(f"✅ Formatted {len(formatted_legal_data)} legal Q&A pairs")

# Show examples of formatted data
print("\n📝 Examples of Formatted Instructions:")
print("=" * 60)

for i, item in enumerate(formatted_legal_data[:3]):
    print(f"\n🔍 Example {i+1}:")
    print(f"📥 INPUT: {item['input_text']}")
    print(f"📤 TARGET: {item['target_text']}")
    print("-" * 60)

# Analyze input/target lengths after formatting
input_lengths = [len(item['input_text'].split()) for item in formatted_legal_data]
target_lengths = [len(item['target_text'].split()) for item in formatted_legal_data]

print(f"\n📊 Formatted Data Statistics:")
print(f"Average input length: {np.mean(input_lengths):.1f} words")
print(f"Average target length: {np.mean(target_lengths):.1f} words")
print(f"Max input length: {max(input_lengths)} words")
print(f"Max target length: {max(target_lengths)} words")

In [None]:
# Step 2: Split data into train/validation/test sets
print("📊 Data Splitting Strategy")
print("=" * 40)

# Set random seed for reproducibility
np.random.seed(42)

# Split ratios
train_ratio = 0.70  # 70% for training
val_ratio = 0.15    # 15% for validation  
test_ratio = 0.15   # 15% for final testing

print(f"📈 Split Strategy:")
print(f"  • Training:   {train_ratio*100:.0f}% ({int(len(formatted_legal_data) * train_ratio):,} examples)")
print(f"  • Validation: {val_ratio*100:.0f}% ({int(len(formatted_legal_data) * val_ratio):,} examples)")
print(f"  • Testing:    {test_ratio*100:.0f}% ({int(len(formatted_legal_data) * test_ratio):,} examples)")

# First split: separate test set
train_val_data, test_data = train_test_split(
    formatted_legal_data, 
    test_size=test_ratio, 
    random_state=42,
    shuffle=True
)

# Second split: separate validation from training
train_data, val_data = train_test_split(
    train_val_data, 
    test_size=val_ratio/(train_ratio + val_ratio),  # Adjust ratio for remaining data
    random_state=42,
    shuffle=True
)

print(f"\n✅ Data Split Complete:")
print(f"  • Training set:   {len(train_data):,} examples ({len(train_data)/len(formatted_legal_data)*100:.1f}%)")
print(f"  • Validation set: {len(val_data):,} examples ({len(val_data)/len(formatted_legal_data)*100:.1f}%)")
print(f"  • Test set:       {len(test_data):,} examples ({len(test_data)/len(formatted_legal_data)*100:.1f}%)")

# Visualize the split
plt.figure(figsize=(10, 6))
splits = ['Training', 'Validation', 'Test']
sizes = [len(train_data), len(val_data), len(test_data)]
colors = ['skyblue', 'lightgreen', 'salmon']

plt.pie(sizes, labels=splits, autopct='%1.1f%%', colors=colors, startangle=90)
plt.title('📊 Dataset Split Distribution', fontsize=14, fontweight='bold')
plt.axis('equal')
plt.show()

print(f"\n🎯 Why This Split Strategy?")
print(f"""
🏋️ Training Set (70%): Large enough to learn legal patterns and terminology
📏 Validation Set (15%): Monitor overfitting and tune hyperparameters
🧪 Test Set (15%): Unbiased evaluation of final model performance

This ensures our legal LLM generalizes well to new constitutional questions!
""")
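
A random split only gives an unbiased test score if the same question does not land in more than one split. The sketch below checks for that kind of leakage; it assumes each item is a dictionary with an `input_text` key as produced earlier, and the helper names and toy items are hypothetical.

```python
import re


def normalize(text):
    """Lowercase and collapse whitespace so trivially different copies match."""
    return re.sub(r"\s+", " ", text.strip().lower())


def find_overlap(split_a, split_b, key="input_text"):
    """Return the normalized inputs that occur in both splits."""
    seen = {normalize(item[key]) for item in split_a}
    return sorted({normalize(item[key]) for item in split_b} & seen)


# Toy splits for illustration only
toy_train = [{"input_text": "Question A?"}, {"input_text": "Question B?"}]
toy_test = [{"input_text": "question  a?"}, {"input_text": "Question C?"}]

leaks = find_overlap(toy_train, toy_test)
print(f"Overlapping inputs: {leaks}")
```

Running `find_overlap(train_data, test_data)` on the real splits and getting an empty list is a quick sanity check before training.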

In [None]:
# Step 3: Save prepared data for training
print("💾 Saving Prepared Legal Dataset")
print("=" * 40)

import pickle
import json

# Create data splits dictionary
legal_data_splits = {
    'train': train_data,
    'validation': val_data,
    'test': test_data,
    'metadata': {
        'total_examples': len(formatted_legal_data),
        'train_size': len(train_data),
        'val_size': len(val_data),
        'test_size': len(test_data),
        'task_type': 'constitutional_qa',
        'format_version': '1.0',
        'creation_date': '2025-08-16'
    }
}

# Save as pickle for easy loading during training
with open('legal_data_splits.pkl', 'wb') as f:
    pickle.dump(legal_data_splits, f)

# Also save as JSON for portability
with open('legal_data_splits.json', 'w', encoding='utf-8') as f:
    json.dump(legal_data_splits, f, indent=2, ensure_ascii=False)

print("✅ Data splits saved successfully!")
print("📁 Files created:")
print("  • legal_data_splits.pkl (Python pickle format)")
print("  • legal_data_splits.json (JSON format)")

# Display sample from each split
print(f"\n🔍 Sample from each split:")
print("=" * 50)

splits = [('Training', train_data), ('Validation', val_data), ('Test', test_data)]
for split_name, split_data in splits:
    sample = split_data[0]
    print(f"\n📚 {split_name} Sample:")
    print(f"❓ Input: {sample['input_text'][:100]}...")
    print(f"💡 Target: {sample['target_text'][:80]}...")

## 4. Model Selection and Architecture

### Why FLAN-T5 for Legal Applications?

**FLAN-T5** (Fine-tuned LAnguage Net - Text-to-Text Transfer Transformer) is perfect for legal Q&A because:

1. **Instruction-Tuned**: Pre-trained to follow instructions and answer questions
2. **Text-to-Text**: Can handle various legal tasks (Q&A, summarization, classification)
3. **Efficient**: Smaller models (like FLAN-T5-small) can work well for narrow, specialized domains
4. **Proven**: Strong performance on reasoning tasks similar to legal analysis

### Model Sizes Available:
- **FLAN-T5-small**: ~80M parameters (good for learning/prototyping)
- **FLAN-T5-base**: ~250M parameters (balanced performance/efficiency)
- **FLAN-T5-large**: ~780M parameters (higher quality, more resources needed)

For this tutorial, we'll use **FLAN-T5-small** as it's:
- Fast to train ⚡
- Runs on modest hardware 💻
- Good for learning concepts 📚
- Adequate for specialized legal domain 🏛️
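
"Runs on modest hardware" can be turned into a rough number. The sketch below uses a common rule of thumb for 32-bit training with AdamW: 4 bytes per parameter for weights, 4 for gradients, and 8 for the two optimizer moment buffers. The parameter counts are approximate figures for the FLAN-T5 checkpoints and are assumptions here; activations and framework overhead are not included.

```python
BYTES_FP32 = 4


def inference_bytes(num_params):
    """Memory for the weights alone in 32-bit precision."""
    return num_params * BYTES_FP32


def training_bytes(num_params):
    """Weights + gradients + two AdamW moment buffers, all 32-bit."""
    return num_params * BYTES_FP32 * 4


# Approximate parameter counts (assumption)
approx_sizes = {"small": 80_000_000, "base": 250_000_000, "large": 780_000_000}

for name, n in approx_sizes.items():
    print(
        f"{name:<6} weights ~{inference_bytes(n) / 1e9:.2f} GB, "
        f"training state ~{training_bytes(n) / 1e9:.2f} GB"
    )
```

These are lower bounds; batch size and sequence length add activation memory on top.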

In [None]:
# Install required libraries for fine-tuning
print("📦 Installing Required Libraries for LLM Fine-tuning")
print("=" * 55)

# Note: Run these installations in your environment
required_packages = [
    "transformers>=4.30.0",
    "datasets>=2.12.0", 
    "torch>=2.0.0",
    "accelerate>=0.20.0",
    "evaluate>=0.4.0",
    "rouge-score>=0.1.2",
    "sentencepiece>=0.1.99"
]

print("📋 Required packages for legal LLM fine-tuning:")
for i, package in enumerate(required_packages, 1):
    print(f"  {i}. {package}")

print(f"\n💡 Installation command:")
print(f"pip install {' '.join(required_packages)}")

print(f"\n🚀 Once installed, we can proceed with:")
print(f"  ✅ Loading pre-trained FLAN-T5 model")
print(f"  ✅ Setting up tokenizer for legal text")
print(f"  ✅ Configuring training parameters")
print(f"  ✅ Starting the fine-tuning process")

# Check if transformers is available
try:
    import transformers
    print(f"\n✅ Transformers library available: v{transformers.__version__}")
except ImportError:
    print(f"\n❌ Transformers library not found. Please install it first.")

try:
    import torch
    print(f"✅ PyTorch available: v{torch.__version__}")
    print(f"✅ CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"✅ GPU device: {torch.cuda.get_device_name(0)}")
except ImportError:
    print(f"❌ PyTorch not found. Please install it first.")

## 5. Complete Fine-tuning Implementation

The next section contains the complete code for fine-tuning FLAN-T5 on legal data. This includes:

1. **Model & Tokenizer Loading** 🤖
2. **Data Tokenization** 🔤  
3. **Training Configuration** ⚙️
4. **Training Loop** 🔄
5. **Evaluation** 📊
6. **Model Saving** 💾

### Key Training Parameters:

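- **Learning Rate**: `5e-5` (a conservative starting point; T5-family models are often fine-tuned with higher rates such as `1e-4` to `3e-4`, so treat this as a value to tune)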
- **Batch Size**: `8` (adjust based on your GPU memory)
- **Epochs**: `3` (prevent overfitting on specialized data)
- **Max Length**: `512` tokens for input, `128` for output
- **Evaluation**: ROUGE scores for text quality assessment

Ready to train your legal AI assistant? 🚀
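
Before training, it helps to translate these parameters into optimizer steps, because step-based settings such as `eval_steps = 500` only make sense relative to the total step count. The sketch below does that arithmetic, assuming a single device and no gradient accumulation. The training-set sizes are hypothetical.

```python
import math


def training_plan(num_train_examples, batch_size, num_epochs, eval_steps):
    """Return (steps per epoch, total steps, number of step-based evaluations)."""
    steps_per_epoch = math.ceil(num_train_examples / batch_size)
    total_steps = steps_per_epoch * num_epochs
    num_evals = total_steps // eval_steps
    return steps_per_epoch, total_steps, num_evals


# Hypothetical training-set sizes; replace with len(train_data)
for n_train in (2800, 700):
    per_epoch, total, evals = training_plan(n_train, batch_size=8, num_epochs=3, eval_steps=500)
    print(f"{n_train} examples: {per_epoch} steps/epoch, {total} total steps, {evals} evaluations")
```

If the evaluation count comes out as zero or one, evaluating once per epoch or lowering `eval_steps` gives a more useful validation curve.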

In [None]:
# Complete Legal LLM Fine-tuning Implementation
print("🚀 Legal LLM Fine-tuning - Complete Implementation")
print("=" * 55)

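# The code below is wrapped in a triple-quoted string so the cell runs without extra packages.
# Remove the surrounding triple quotes to execute it after installing the required packages.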
"""
# Import required libraries for fine-tuning
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM, 
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,            # seq2seq-aware trainer (supports generation during evaluation)
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback
)
from datasets import Dataset, DatasetDict
import evaluate
import torch
import os
from datetime import datetime

# Configuration for legal LLM fine-tuning
class LegalLLMConfig:
    # Model configuration
    model_name = "google/flan-t5-small"
    
    # Data configuration
    max_input_length = 512
    max_target_length = 128
    
    # Training configuration
    output_dir = "./models/legal_constitutional_qa"
    num_epochs = 3
    batch_size = 8
    learning_rate = 5e-5
    warmup_ratio = 0.1
    weight_decay = 0.01
    
    # Evaluation configuration
    eval_strategy = "steps"
    eval_steps = 500
    logging_steps = 100
    save_steps = 500
    
    # Hardware configuration
    # T5-family models can overflow in fp16, so prefer bf16 where the GPU supports it
    bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    fp16 = False
    dataloader_num_workers = 4

config = LegalLLMConfig()

print(f"🎯 Training Configuration:")
print(f"  • Model: {config.model_name}")
print(f"  • Max input length: {config.max_input_length} tokens")
print(f"  • Max target length: {config.max_target_length} tokens")
print(f"  • Batch size: {config.batch_size}")
print(f"  • Learning rate: {config.learning_rate}")
print(f"  • Number of epochs: {config.num_epochs}")
print(f"  • Output directory: {config.output_dir}")

# Create output directory
os.makedirs(config.output_dir, exist_ok=True)
print(f"✅ Output directory created: {config.output_dir}")
"""

print("\n📝 Note: Remove the triple quotes around the code above after installing required packages!")
print("This implementation provides a complete pipeline for fine-tuning on legal data.")
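
The configuration above still has to be handed to a trainer. The following is a hedged sketch of one way to wire it up with the Hugging Face seq2seq trainer, not a definitive implementation. The function name `build_trainer` and the choice of per-epoch evaluation are assumptions. The `eval_strategy` argument is spelled `evaluation_strategy` in transformers releases before 4.41.

```python
# Keyword arguments mirroring the configuration values used in this notebook
TRAINING_KWARGS = dict(
    output_dir="./models/legal_constitutional_qa",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=5e-5,
    warmup_ratio=0.1,
    weight_decay=0.01,
    eval_strategy="epoch",   # `evaluation_strategy` before transformers 4.41
    save_strategy="epoch",   # must match eval_strategy for load_best_model_at_end
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    logging_steps=100,
)


def build_trainer(train_records, val_records, model_name="google/flan-t5-small"):
    """Build a Seq2SeqTrainer from lists of {'input_text', 'target_text', ...} dicts.
    Calling this downloads the checkpoint."""
    # Heavy imports stay inside the function so this cell loads without them
    from datasets import Dataset
    from transformers import (
        AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq,
        EarlyStoppingCallback, Seq2SeqTrainer, Seq2SeqTrainingArguments,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def tokenize(batch):
        enc = tokenizer(batch["input_text"], max_length=512, truncation=True)
        labels = tokenizer(text_target=batch["target_text"], max_length=128, truncation=True)
        enc["labels"] = labels["input_ids"]
        return enc

    def to_dataset(records):
        ds = Dataset.from_list(records)
        return ds.map(tokenize, batched=True, remove_columns=ds.column_names)

    return Seq2SeqTrainer(
        model=model,
        args=Seq2SeqTrainingArguments(**TRAINING_KWARGS),
        train_dataset=to_dataset(train_records),
        eval_dataset=to_dataset(val_records),
        # Pads inputs per batch and pads labels with -100 so padding is ignored in the loss
        data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    )
```

With the splits prepared earlier, usage would look like `trainer = build_trainer(train_data, val_data)` followed by `trainer.train()` and `trainer.save_model(TRAINING_KWARGS["output_dir"])`.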

In [None]:
# Tokenization Function for Legal Data
print("🔤 Tokenization Strategy for Legal Text")
print("=" * 45)

# This shows how we'll tokenize the legal data
tokenization_example = """
def tokenize_legal_data(examples, tokenizer, config):
    '''
    Tokenize legal Q&A data for instruction-following format
    
    Args:
        examples: Batch of legal Q&A examples
        tokenizer: FLAN-T5 tokenizer
        config: Training configuration
    
    Returns:
        Tokenized inputs and labels
    '''
    # Tokenize inputs (questions with instruction format).
    # No return_tensors here: Dataset.map stores plain lists, and the
    # data collator converts batches to tensors at training time.
    model_inputs = tokenizer(
        examples['input_text'],
        max_length=config.max_input_length,
        truncation=True,
        padding="max_length"
    )
    
    # Tokenize targets (answers)
    labels = tokenizer(
        text_target=examples['target_text'],
        max_length=config.max_target_length,
        truncation=True,
        padding="max_length"
    )
    
    # Replace padding token ids in the labels with -100
    # so they are ignored in the loss computation
    labels["input_ids"] = [
        [(tok if tok != tokenizer.pad_token_id else -100) for tok in label]
        for label in labels["input_ids"]
    ]
    
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
"""

print("🔍 Key Tokenization Concepts:")
print("=" * 35)
print("1. 📥 Input Tokenization:")
print("   • Converts legal questions to token IDs")
print("   • Handles special tokens (instruction format)")
print("   • Truncates/pads to fixed length")

print("\n2. 📤 Target Tokenization:")
print("   • Converts legal answers to token IDs") 
print("   • Padding tokens replaced with -100 (ignored in loss)")
print("   • Ensures consistent sequence lengths")

print("\n3. 🎯 Why This Matters for Legal AI:")
print("   • Legal text often contains long, complex sentences")
print("   • Proper tokenization preserves legal terminology")
print("   • Instruction format teaches the model legal Q&A structure")

# Show example of what tokenization looks like
sample_question = "What is India according to the Union and its Territory?"
sample_answer = "India, that is Bharat, shall be a Union of States."
sample_instruction = f"Answer the following question about Indian constitutional law: {sample_question}"

print(f"\n📝 Tokenization Example:")
print(f"🔤 Raw Input: '{sample_instruction}'")
print(f"🔤 Raw Target: '{sample_answer}'")
print(f"⚡ After tokenization: Numbers that the model understands!")
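
The padding and `-100` masking step can be seen without loading a tokenizer. The sketch below works on a hand-written list of token IDs. The IDs are made up for illustration, and the padding id of 0 is the value used by T5-family tokenizers.

```python
PAD_ID = 0            # padding id used by T5-family tokenizers
IGNORE_INDEX = -100   # label value skipped by PyTorch's cross-entropy loss


def pad_and_mask(token_ids, max_length, pad_id=PAD_ID):
    """Truncate or pad to max_length, then hide padding positions from the loss."""
    padded = token_ids[:max_length] + [pad_id] * max(0, max_length - len(token_ids))
    return [tok if tok != pad_id else IGNORE_INDEX for tok in padded]


# Made-up target ids for illustration; 1 stands in for the end-of-sequence token
target_ids = [363, 19, 1547, 1]
print(pad_and_mask(target_ids, max_length=8))
# → [363, 19, 1547, 1, -100, -100, -100, -100]
```

Only the four real tokens contribute to the loss; the masked positions are ignored regardless of what the model predicts there.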

In [None]:
# Training Loop and Evaluation Metrics
print("🏋️ Training Process and Evaluation")
print("=" * 40)

training_process = """
Training a Legal LLM involves several key steps:

1. 🔄 FORWARD PASS
   • Model processes legal question
   • Generates answer prediction
   • Compares with correct legal answer

2. 📊 LOSS CALCULATION  
   • CrossEntropy loss measures prediction errors
   • Higher loss = model is more confused
   • Goal: Minimize loss on legal Q&A pairs

3. ⬅️ BACKWARD PASS
   • Calculate gradients (how to improve)
   • Update model weights
   • Model gets better at legal reasoning

4. 📈 VALIDATION
   • Test on unseen legal questions
   • Monitor for overfitting
   • Early stopping if performance degrades

5. 💾 CHECKPOINTING
   • Save best model weights
   • Resume training if interrupted
   • Keep track of training progress
"""

print(training_process)

evaluation_metrics = """
📊 EVALUATION METRICS FOR LEGAL LLM

1. 🎯 ROUGE Scores (Text Quality)
   • ROUGE-1: Word overlap between prediction and reference
   • ROUGE-2: Bigram overlap (phrase matching)
   • ROUGE-L: Longest common subsequence
   
2. ⚖️ Legal-Specific Metrics
   • Constitutional Accuracy: Correct constitutional references
   • Legal Term Precision: Proper use of legal terminology
   • Factual Consistency: Accurate legal facts and precedents

3. 🔍 Qualitative Assessment
   • Human evaluation by legal experts
   • Logical reasoning quality
   • Appropriate legal citations
"""

print(evaluation_metrics)

# Sample evaluation code structure
print("💻 Evaluation Implementation Preview:")
print("=" * 40)
eval_code = '''
def evaluate_legal_model(model, eval_dataset, tokenizer):
    """Evaluate fine-tuned legal model"""
    
    # Load evaluation metrics
    rouge = evaluate.load("rouge")
    
    predictions = []
    references = []
    
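    model.eval()  # disable dropout so generation is deterministic
    for example in eval_dataset:
        # Generate prediction (truncate long inputs, keep tensors on the model's device)
        inputs = tokenizer(
            example['input_text'], return_tensors="pt",
            truncation=True, max_length=512
        ).to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=128)
        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)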
        
        predictions.append(prediction)
        references.append(example['target_text'])
    
    # Calculate ROUGE scores
    rouge_scores = rouge.compute(
        predictions=predictions, 
        references=references
    )
    
    return rouge_scores, predictions, references
'''

print(eval_code)
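
To see what ROUGE-1 measures, here is a simplified hand computation of unigram overlap. It uses whitespace tokenization with no stemming, so its numbers can differ from the `rouge-score` package. The function name `rouge1` and the example sentences are illustrative.

```python
from collections import Counter


def rouge1(prediction, reference):
    """Return (precision, recall, F1) of unigram overlap between two strings."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Counter intersection keeps the minimum count of each shared word
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0, 0.0, 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


reference = "India shall be a Union of States"
prediction = "India is a Union of States"
p, r, f = rouge1(prediction, reference)
print(f"precision={p:.3f} recall={r:.3f} f1={f:.3f}")
```

Five of the six predicted words appear in the seven-word reference, which is why precision is higher than recall in this example.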

## 6. Next Steps: Running Your Legal LLM Training

### 🚀 Ready to Train? Here's Your Action Plan:

#### Step 1: Environment Setup
```bash
# Install required packages
pip install transformers datasets torch accelerate evaluate rouge-score sentencepiece

# Verify GPU availability (optional but recommended)
python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}')"
```

#### Step 2: Run the Training
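1. Enable the training code in the cells above by removing the surrounding triple quotes
2. Execute each cell in sequence
3. Monitor training progress in the output logs
4. Training time depends on dataset size and hardware; with a small model, expect anywhere from minutes to a few hours on a single GPU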

#### Step 3: Test Your Model
```python
# Load your trained model
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("./models/legal_constitutional_qa")
model = AutoModelForSeq2SeqLM.from_pretrained("./models/legal_constitutional_qa")

# Test with a legal question
question = "What are the fundamental rights guaranteed by the Indian Constitution?"
input_text = f"Answer the following question about Indian constitutional law: {question}"

inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
# Beam search without sampling gives repeatable answers, which suits factual Q&A
outputs = model.generate(**inputs, max_new_tokens=128, num_beams=4)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f"Question: {question}")
print(f"AI Answer: {answer}")
```

### 🎯 Expected Outcomes:
- ✅ A fine-tuned model specialized in Indian constitutional law
- ✅ Improved ability to answer constitutional law questions in the style of the training data
- ✅ Understanding of legal terminology and concepts
- ✅ A prototype legal AI assistant ready for testing with legal experts

### 🔄 Iterative Improvement:
1. **Collect Feedback**: Test with legal experts
2. **Data Augmentation**: Add more diverse legal Q&A pairs
3. **Hyperparameter Tuning**: Optimize learning rate, batch size
4. **Advanced Techniques**: Try LoRA, QLoRA for efficiency

## 7. Understanding What Happens During Training

### 🧠 The Learning Process

During fine-tuning, your legal LLM goes through several stages:

#### Phase 1: Initial Adaptation (Epoch 1)
- **What happens**: Model adjusts from general knowledge to legal domain
- **Loss behavior**: High initially, drops rapidly  
- **Learning focus**: Basic legal terminology and question patterns

#### Phase 2: Legal Specialization (Epoch 2)
- **What happens**: Model learns constitutional law specifics
- **Loss behavior**: Steady decrease, more gradual
- **Learning focus**: Complex legal relationships and reasoning

#### Phase 3: Fine-tuning (Epoch 3)
- **What happens**: Model polishes responses, reduces errors
- **Loss behavior**: Small improvements, convergence
- **Learning focus**: Precise legal language and edge cases
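
Loss values are easier to interpret as perplexity. Assuming the reported loss is the mean per-token cross-entropy in natural-log units (the usual case for Hugging Face seq2seq models), perplexity is simply its exponential. The loss values below are hypothetical.

```python
import math


def perplexity(mean_token_loss):
    """Convert mean per-token cross-entropy (in nats) to perplexity."""
    return math.exp(mean_token_loss)


# Hypothetical losses, roughly one per training phase
for loss in (2.0, 1.5, 1.0):
    print(f"loss {loss:.2f} -> perplexity {perplexity(loss):.2f}")
```

A perplexity of about 3 can be read loosely as the model being as uncertain as if it were choosing among three equally likely tokens at each step.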

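### 📊 Monitoring Your Training

The log below is illustrative: the numbers show the kind of trend to look for and are not measured results.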

```
Epoch 1/3: [████████████████] 100% 
├── Train Loss: 2.341 → 1.523
├── Eval Loss: 1.876
├── ROUGE-1: 0.342
└── Learning Rate: 5e-05

Epoch 2/3: [████████████████] 100%
├── Train Loss: 1.523 → 1.187  
├── Eval Loss: 1.634
├── ROUGE-1: 0.427
└── Learning Rate: 3.5e-05

Epoch 3/3: [████████████████] 100%
├── Train Loss: 1.187 → 1.089
├── Eval Loss: 1.598
├── ROUGE-1: 0.456
└── Training Complete! 🎉
```
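
The learning-rate values in a real log depend on the scheduler. The sketch below reimplements a linear schedule with warmup, which is the Hugging Face trainer's default scheduler type, so you can predict the rate at any step. The step counts are hypothetical, and the logged values in an actual run may differ from the sample log above.

```python
def linear_schedule_lr(step, total_steps, warmup_steps, base_lr):
    """Linear warmup from 0 to base_lr, then linear decay back to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


# Hypothetical run: 1050 total steps, 10% warmup, base learning rate 5e-5
total_steps, warmup_steps, base_lr = 1050, 105, 5e-5
for step in (0, 105, 350, 700, 1050):
    lr = linear_schedule_lr(step, total_steps, warmup_steps, base_lr)
    print(f"step {step:>4}: lr = {lr:.2e}")
```

The rate peaks at the end of warmup and reaches zero on the final step, so later epochs always train with smaller updates than earlier ones.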

### 🎯 Success Indicators:
- ✅ **Decreasing Loss**: Model is learning legal patterns
- ✅ **Stable Validation**: No overfitting on legal data
- ✅ **Improving ROUGE**: Better text quality and relevance
- ✅ **Legal Coherence**: Answers make sense legally
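
Legal coherence is hard to automate, but one crude proxy is whether the model cites the same articles as the reference answer. The sketch below is an assumption about how such a check could look, not an established metric. The regular expression, function names, and example strings are all hypothetical.

```python
import re

# Matches references such as "Article 21" or "Article 21A"
ARTICLE_RE = re.compile(r"\bArticle\s+(\d+[A-Z]?)\b", re.IGNORECASE)


def article_refs(text):
    """Return the set of article numbers mentioned in the text."""
    return {match.upper() for match in ARTICLE_RE.findall(text)}


def citation_scores(prediction, reference):
    """Return (precision, recall) of cited articles, or None if the
    reference cites no articles."""
    pred, ref = article_refs(prediction), article_refs(reference)
    if not ref:
        return None
    correct = pred & ref
    precision = len(correct) / len(pred) if pred else 0.0
    recall = len(correct) / len(ref)
    return precision, recall


# Toy strings for illustration only
print(citation_scores("Article 14 and Article 21 apply", "See Article 21 and Article 19"))
```

A low score flags answers for human review; a high score does not prove the answer is legally correct.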

Congratulations! You now understand the complete process of fine-tuning LLMs for legal applications! 🎓⚖️🤖