# Simple Multi-GPU Training (No notebook_launcher)

This is the **simplest** approach for multi-GPU training in Jupyter notebooks.

## Key: Just DON'T use device_map!

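When you don't specify `device_map` and run everything in a single notebook process, the Trainer will automatically:
1. Detect all available GPUs
2. Wrap the model in `torch.nn.DataParallel` (DistributedDataParallel is only used when training is started through a distributed launcher)
3. Split each training batch across all GPUs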

That's it! No notebook_launcher needed.

In [1]:
# Cell 1: Install packages
!pip install -q transformers datasets peft accelerate trl wandb
print("✅ Packages installed!")

✅ Packages installed!


In [2]:
# Cell 2: Imports and GPU check
import os
import torch
import wandb
import math
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig

# Report the PyTorch build and every CUDA device visible to this kernel.
banner = "=" * 80
cuda_ok = torch.cuda.is_available()

print(banner)
print("SYSTEM CHECK")
print(banner)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs: {num_gpus}")
    for gpu_index in range(num_gpus):
        gpu_name = torch.cuda.get_device_name(gpu_index)
        print(f"  GPU {gpu_index}: {gpu_name}")
print(banner)

2025-10-12 23:20:01.399705: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760311201.422059     278 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760311201.429261     278 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


SYSTEM CHECK
PyTorch version: 2.6.0+cu124
CUDA available: True
Number of GPUs: 2
  GPU 0: Tesla T4
  GPU 1: Tesla T4


In [None]:
# Cell 3: Configuration
import os

# --- Model, dataset and output locations ---
MODEL_NAME = "rzeraat/qwen-2-0-5b-law-lora"
DATASET_NAME = "rzeraat/law"
OUTPUT_DIR = "./pactoria-v1-simple"

# --- Credentials: read from the environment, never hard-coded ---
WANDB_API_KEY = os.getenv('WANDB_API_KEY')
HUGGINGFACE_API_KEY = os.getenv('HUGGINGFACE_TOKEN')  # env var is named HUGGINGFACE_TOKEN

# --- Weights & Biases: logging is enabled only when an API key is present ---
WANDB_PROJECT = "uk-legal-training"
WANDB_ENABLED = bool(WANDB_API_KEY)

# --- LoRA adapter hyperparameters ---
LORA_R = 32
LORA_ALPHA = 64
LORA_DROPOUT = 0.05
TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# --- Training hyperparameters ---
BATCH_SIZE = 5
GRADIENT_ACCUMULATION_STEPS = 1
LEARNING_RATE = 1e-3
NUM_EPOCHS = 3
MAX_SEQ_LENGTH = 700

# --- W&B run name, derived from the hyperparameters above ---
WANDB_RUN_NAME = (
    f"qwen-0.5b-law-r{LORA_R}-seq{MAX_SEQ_LENGTH}"
    f"-bs{BATCH_SIZE}x{GRADIENT_ACCUMULATION_STEPS}-ep{NUM_EPOCHS}"
)

# --- Echo the configuration so the run is self-describing ---
wandb_line = (
    f"   W&B: {WANDB_PROJECT}/{WANDB_RUN_NAME}"
    if WANDB_ENABLED
    else "   ⚠️  W&B disabled (WANDB_API_KEY not found in environment)"
)
for summary_line in (
    "✅ Configuration loaded",
    f"   Model: {MODEL_NAME}",
    f"   Dataset: {DATASET_NAME}",
    f"   Output: {OUTPUT_DIR}",
    wandb_line,
):
    print(summary_line)

In [4]:
# Cell 4: Load tokenizer
# The EOS token doubles as the padding token and padding goes on the right.
print("Loading tokenizer...")
tokenizer_options = {"trust_remote_code": True}
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, **tokenizer_options)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
print("✅ Tokenizer loaded")

Loading tokenizer...
✅ Tokenizer loaded


In [5]:
# Cell 5: Load model - NO DEVICE_MAP!
# No `device_map` is passed to `from_pretrained`; the device the weights end
# up on is printed at the bottom of this cell.
print("Loading model...")
model_load_options = dict(
    torch_dtype=torch.float16,
    trust_remote_code=True,
    attn_implementation="sdpa",
)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **model_load_options)

# Switch on gradient checkpointing before the model is passed to later cells.
model.gradient_checkpointing_enable()

param_count = model.num_parameters()
first_param_device = next(model.parameters()).device
print(f"✅ Model loaded (params: {param_count:,})")
print(f"   Device: {first_param_device}")

`torch_dtype` is deprecated! Use `dtype` instead!


Loading model...
✅ Model loaded (params: 494,032,768)
   Device: cpu


In [6]:
# Cell 6: Apply LoRA
# Build the adapter configuration from the constants in the config cell, wrap
# the base model with it, and report how many parameters remain trainable.
print("Applying LoRA...")
lora_settings = {
    "r": LORA_R,
    "lora_alpha": LORA_ALPHA,
    "target_modules": TARGET_MODULES,
    "lora_dropout": LORA_DROPOUT,
    "bias": "none",
    "task_type": "CAUSAL_LM",
}
lora_config = LoraConfig(**lora_settings)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

Applying LoRA...
trainable params: 8,798,208 || all params: 502,830,976 || trainable%: 1.7497


In [None]:
# Cell 7: Load and format dataset with train/val/test splits - ENHANCED WITH SAMPLE TYPES
def format_sample(sample):
    """
    Format one legal Q&A sample into a single training string.

    Sample Types:
    - case_analysis: IRAC methodology (Issue → Rule → Application → Conclusion)
    - educational: Structured teaching (Definition → Legal Basis → Key Elements → Examples)
    - client_interaction: Practical advice (Understanding → Legal Position → Options → Recommendation)
    - statutory_interpretation: Legislative analysis (Statutory Text → Purpose → Interpretation → Case Law)

    Args:
        sample: Mapping with required keys 'question' and 'answer', and the
            optional keys 'sample_type', 'topic', 'difficulty', 'reasoning'
            and 'case_citation'. Optional keys may be absent, None or empty.

    Returns:
        dict with a single key "text" holding the prompt
        ("### Instruction: ... ### Response:") followed by the metadata
        header, the optional reasoning, the answer and the optional citation.

    Raises:
        KeyError: if 'question' or 'answer' is missing.
    """
    # `or` (rather than a .get default) also covers keys that exist but hold
    # None / "" — otherwise the literal text "None" leaks into the training data.
    sample_type = sample.get('sample_type') or 'case_analysis'
    topic = sample.get('topic') or 'Corporate Law'
    difficulty = sample.get('difficulty') or 'intermediate'

    # Type-specific formatting hints; they help the model learn a different
    # answer structure for each sample type.
    type_hints = {
        'case_analysis': '(IRAC: Issue → Rule → Application → Conclusion)',
        'educational': '(Teaching: Definition → Legal Basis → Key Elements → Examples)',
        'client_interaction': '(Client Advice: Understanding → Position → Options → Recommendation)',
        'statutory_interpretation': '(Statutory: Text → Purpose → Interpretation → Case Law)'
    }
    hint = type_hints.get(sample_type, '')
    # Unknown sample types have no hint; avoid emitting "### Answer :" with a stray space.
    answer_header = f"### Answer {hint}:" if hint else "### Answer:"

    # Instruction (same for all types), then the response with its metadata header.
    lines = [
        "### Instruction:",
        f"{sample['question']}",
        "",
        "### Response:",
        f"### Sample Type: {sample_type}",
        f"### Topic: {topic}",
        f"### Difficulty: {difficulty}",
    ]

    # Reasoning section (chain of thought), only when present and non-empty.
    reasoning = sample.get('reasoning')
    if reasoning:
        lines += ["", "### Reasoning:", f"{reasoning}"]

    lines += ["", answer_header, f"{sample['answer']}"]

    # Case citations, only when present and non-empty.
    case_citation = sample.get('case_citation')
    if case_citation:
        lines += ["", "### Case Citation:", f"{case_citation}"]

    return {"text": "\n".join(lines)}

print("Loading and formatting dataset with sample type support...")
dataset = load_dataset(DATASET_NAME)
formatted_dataset = dataset.map(format_sample)

# Two-stage split: hold out 20% of 'train', then halve the hold-out into
# validation and test (80% / 10% / 10% overall). The seed is fixed in both stages.
split_seed = 42
train_val_split = formatted_dataset['train'].train_test_split(test_size=0.2, seed=split_seed)
val_test_split = train_val_split['test'].train_test_split(test_size=0.5, seed=split_seed)

train_dataset, val_dataset, test_dataset = (
    train_val_split['train'],
    val_test_split['train'],
    val_test_split['test'],
)

split_sizes = f"Train={len(train_dataset)} | Val={len(val_dataset)} | Test={len(test_dataset)}"
print(f"✅ Dataset split: {split_sizes}")
print("📊 Enhanced with sample type awareness for multi-format training")

In [8]:
# Cell 8: Training arguments with validation and W&B

# Initialize Weights & Biases
# Best-effort: any failure below disables W&B for the rest of the notebook
# instead of aborting the training run.
if WANDB_ENABLED:
    try:
        # GPU inventory, resolved once (empty on a CPU-only machine).
        gpu_ids = list(range(torch.cuda.device_count()))
        gpu_names = [torch.cuda.get_device_name(i) for i in gpu_ids]
        # Effective batch = per-device batch × accumulation steps × devices,
        # the same formula the training cell prints as "Effective Batch".
        # max(1, ...) keeps the value non-zero when no GPU is visible.
        num_devices = max(1, len(gpu_ids))

        wandb.login(key=WANDB_API_KEY, relogin=True)
        wandb.init(
            project=WANDB_PROJECT,
            name=WANDB_RUN_NAME,
            mode="online",
            config={
                "model": MODEL_NAME,
                "dataset": DATASET_NAME,
                "lora_r": LORA_R,
                "lora_alpha": LORA_ALPHA,
                "lora_dropout": LORA_DROPOUT,
                "batch_size": BATCH_SIZE,
                "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
                "learning_rate": LEARNING_RATE,
                "num_epochs": NUM_EPOCHS,
                "max_seq_length": MAX_SEQ_LENGTH,
                "effective_batch_size": BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * num_devices,
                "num_gpus": len(gpu_ids),
            }
        )
        wandb.config.update({
            "gpu_ids": gpu_ids,
            "gpu_names": gpu_names,
        })
        print(f"✅ W&B initialized: {WANDB_PROJECT}/{WANDB_RUN_NAME}")
        print(f"📊 Track at: https://wandb.ai/{wandb.run.entity}/{WANDB_PROJECT}/runs/{wandb.run.id}")
    except Exception as e:
        print(f"⚠️  W&B initialization failed: {e}")
        # Later code checks this flag (report_to, run_name, wandb.finish()).
        WANDB_ENABLED = False

# Training configuration for the SFT run: optimisation, logging,
# evaluation/checkpoint cadence and SFT-specific text handling.
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    fp16=True,
    logging_steps=5,
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    max_grad_norm=0.3,
    weight_decay=0.01,
    report_to="wandb" if WANDB_ENABLED else "none",
    run_name=WANDB_RUN_NAME if WANDB_ENABLED else None,

    # Evaluation settings (note: eval_strategy not evaluation_strategy)
    eval_strategy="steps",
    eval_steps=25,
    per_device_eval_batch_size=BATCH_SIZE,

    # Checkpoint on the same cadence as evaluation: `load_best_model_at_end`
    # can only restore a checkpoint that was actually written, and the default
    # `save_steps` (500) would skip almost every evaluated step.
    save_strategy="steps",
    save_steps=25,
    save_total_limit=2,  # bounds disk usage despite the frequent saves
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # SFT-specific
    max_length=MAX_SEQ_LENGTH,
    dataset_text_field="text",
    packing=True,
)

print("✅ Training configuration created with validation")
print(f"   Evaluation every {training_args.eval_steps} steps")
print(f"   Logging every {training_args.logging_steps} steps")

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mrzeraat-tur[0m ([33mrzeraat-tur-elyoni[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


✅ W&B initialized: uk-legal-training/qwen-0.5b-law-r16-seq700-bs5x1-ep2
📊 Track at: https://wandb.ai/rzeraat-tur-elyoni/uk-legal-training/runs/mjdrp2k2
✅ Training configuration created with validation
   Evaluation every 25 steps
   Logging every 5 steps


In [9]:
# Cell 9: Create Trainer with validation
print("Creating trainer...")
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # Pass the tokenizer configured in Cell 4 (pad token = EOS, right padding).
    # When omitted, the trainer loads its own tokenizer from the model name and
    # those settings are silently ignored.
    processing_class=tokenizer,
)

print("✅ Trainer created")
print(f"   Training on {torch.cuda.device_count()} GPU(s)")
print(f"   Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

Creating trainer...


Adding EOS to train dataset:   0%|          | 0/4112 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/4112 [00:00<?, ? examples/s]

Packing train dataset:   0%|          | 0/4112 [00:00<?, ? examples/s]

Adding EOS to eval dataset:   0%|          | 0/514 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/514 [00:00<?, ? examples/s]

Packing eval dataset:   0%|          | 0/514 [00:00<?, ? examples/s]

✅ Trainer created
   Training on 2 GPU(s)
   Train: 4112 | Val: 514 | Test: 514


In [10]:
# Cell 10: Train!
banner = "=" * 80

# Effective batch = per-device batch × accumulation steps × devices.
# max(1, ...) keeps the figure meaningful on a CPU-only machine, where
# torch.cuda.device_count() is 0 and the product would otherwise print as 0.
num_devices = max(1, torch.cuda.device_count())
effective_batch = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * num_devices

print(banner)
print("STARTING TRAINING")
print(banner)
print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Epochs: {NUM_EPOCHS}")
print(f"Batch: {BATCH_SIZE} | Grad Accum: {GRADIENT_ACCUMULATION_STEPS}")
print(f"Effective Batch: {effective_batch}")
print(banner)

trainer.train()

# Show final metrics: the trainer's log history mixes training entries
# (key 'loss') and evaluation entries (key 'eval_loss'); pull out each series.
train_loss = [x['loss'] for x in trainer.state.log_history if 'loss' in x]
eval_loss = [x['eval_loss'] for x in trainer.state.log_history if 'eval_loss' in x]

print("\n" + banner)
print("TRAINING COMPLETED")
print(banner)
if train_loss:
    print(f"Train Loss: {train_loss[0]:.4f} → {train_loss[-1]:.4f}")
if eval_loss:
    # Perplexity is exp() of the last logged validation loss.
    print(f"Val Loss: {eval_loss[-1]:.4f} | Perplexity: {math.exp(eval_loss[-1]):.2f}")
print(banner)

# Finish W&B
if WANDB_ENABLED:
    wandb.finish()
    print("✅ W&B run finished")

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


STARTING TRAINING
Train: 4112 | Val: 514 | Epochs: 2
Batch: 5 | Grad Accum: 1
Effective Batch: 10


Step,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
25,3.4394,3.550741,1.849121,172692.0,0.605122
50,3.436,3.316838,1.700396,345204.0,0.623878
75,3.3293,3.223716,1.639677,518334.0,0.63157
100,3.1692,3.152757,1.623133,691613.0,0.639351
125,3.2535,3.095894,1.453412,864503.0,0.643813
150,3.1105,3.028745,1.52944,1036865.0,0.649877
175,3.0414,2.984836,1.506032,1209256.0,0.653402
200,2.8689,2.943163,1.481244,1382260.0,0.658197
225,2.8144,2.908906,1.435173,1555075.0,0.66071
250,2.9646,2.86607,1.455337,1727901.0,0.664671



TRAINING COMPLETED
Train Loss: 4.4536 → 2.3958
Val Loss: 2.6908 | Perplexity: 14.74


0,1
eval/entropy,█▆▆▆▄▅▄▄▄▄▁▁▂▁▁▁▁▁▁▁▁
eval/loss,█▆▅▅▄▄▃▃▃▂▃▂▂▂▂▁▁▁▁▁▁
eval/mean_token_accuracy,▁▃▃▄▄▅▅▆▆▆▆▇▇▇▇██████
eval/num_tokens,▁▁▂▂▂▃▃▃▄▄▅▅▅▆▆▆▇▇▇██
eval/runtime,▆▃▁▇▆▁▆▇▄▇█▄▅▄▆▄▄▆▄▆▇
eval/samples_per_second,▃▆█▂▃█▃▂▅▂▁▅▄▅▃▅▅▃▅▃▂
eval/steps_per_second,▃▆█▂▃█▃▂▅▃▁▆▄▅▃▆▆▃▅▃▃
train/entropy,█▇▅▅▄▅▄▅▄▄▄▄▄▄▃▃▃▃▃▂▁▂▂▂▂▂▂▂▂▁▂▁▁▂▁▁▁▁▂▂
train/epoch,▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇█
train/global_step,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇████

0,1
eval/entropy,1.20614
eval/loss,2.6908
eval/mean_token_accuracy,0.68567
eval/num_tokens,3627680.0
eval/runtime,58.6547
eval/samples_per_second,5.711
eval/steps_per_second,0.58
total_flos,8178347351769600.0
train/entropy,1.21562
train/epoch,2.0


✅ W&B run finished


In [12]:
# Cell 11: Save model
# The trainer's model and the tokenizer files are written to the same directory.
print("Saving model...")
for save_to in (trainer.save_model, tokenizer.save_pretrained):
    save_to(OUTPUT_DIR)
print(f"✅ Model saved to {OUTPUT_DIR}")

Saving model...
✅ Model saved to ./pactoria-v1-simple


In [13]:
from peft import PeftModel

print("Merging LoRA weights with base model...")

# Release cached CUDA memory before loading another copy of the base model.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

base_load_options = dict(
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
    trust_remote_code=True,
)
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **base_load_options)

# Attach the adapter saved in OUTPUT_DIR, then fold it into the base weights.
merged_model = PeftModel.from_pretrained(base_model, OUTPUT_DIR).merge_and_unload()

# Persist the merged weights together with the tokenizer.
merged_output_dir = "./qwen-law-merged"
for artifact in (merged_model, tokenizer):
    artifact.save_pretrained(merged_output_dir)
print(f"✅ Merged model saved to {merged_output_dir}")

Merging LoRA weights with base model...
✅ Merged model saved to ./qwen-law-merged


In [None]:
from huggingface_hub import login

HUGGINGFACE_MODEL_NAME = "rzeraat/pactoria-v1"

print(f"Pushing to HuggingFace Hub: {HUGGINGFACE_MODEL_NAME}")

# Without a token there is nothing to authenticate with, so only report it.
if not HUGGINGFACE_API_KEY:
    print("❌ No HuggingFace API key configured")
else:
    login(token=HUGGINGFACE_API_KEY, add_to_git_credential=False)
    # Upload the merged model first, then the tokenizer, to the same repo.
    for artifact in (merged_model, tokenizer):
        artifact.push_to_hub(HUGGINGFACE_MODEL_NAME)
    print(f"✅ Model pushed: https://huggingface.co/{HUGGINGFACE_MODEL_NAME}")

In [16]:
# Cell 14: Install and Import Gradio
!pip install -q gradio
import gradio as gr
print("✅ Gradio installed and imported!")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m68.6/68.6 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m444.8/444.8 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m46.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.[0m[31m
[0m✅ Gradio installed and imported!


In [17]:
# Cell 15: Load Merged Model for Inference
print("Loading merged model for inference...")

# Reload the merged checkpoint from disk rather than reusing the in-memory model.
inference_load_options = dict(
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
inference_model = AutoModelForCausalLM.from_pretrained(merged_output_dir, **inference_load_options)

# Tokenizer saved next to the merged weights; EOS doubles as the pad token.
inference_tokenizer = AutoTokenizer.from_pretrained(merged_output_dir, trust_remote_code=True)
inference_tokenizer.pad_token = inference_tokenizer.eos_token

inference_device = next(inference_model.parameters()).device
print(f"✅ Inference model loaded from {merged_output_dir}")
print(f"   Device: {inference_device}")

Loading merged model for inference...
✅ Inference model loaded from ./qwen-law-merged
   Device: cuda:0


In [None]:
# Cell 16: Create Gradio UI with Streaming Legal Q&A Interface - SAMPLE TYPE AWARE

from transformers import TextIteratorStreamer
from threading import Thread

def generate_legal_answer_stream(
    question,
    sample_type="case_analysis",
    temperature=0.7,
    max_new_tokens=512,
    top_p=0.9,
    repetition_penalty=1.1
):
    """Stream an answer to a legal question from the fine-tuned model.

    The prompt mirrors the training format: the instruction, the
    "### Response:" marker and the "### Sample Type:" line, after which the
    model continues with the rest of the response.

    Args:
        question: The user's question text.
        sample_type: Answer style to request (e.g. "case_analysis").
        temperature: Sampling temperature.
        max_new_tokens: Upper bound on generated tokens (coerced to int).
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        str: The response text accumulated so far (grows with every chunk).

    Raises:
        Exception: Any error raised by `generate()` in the worker thread is
            re-raised here once the stream has ended.
    """
    # Nothing to answer: skip generation instead of prompting with an empty instruction.
    if not question or not question.strip():
        yield "Please enter a legal question."
        return

    # Prompt in the training format, ending at the sample-type line.
    prompt = (
        "### Instruction:\n"
        f"{question}\n\n"
        "### Response:\n"
        f"### Sample Type: {sample_type}"
    )

    # Tokenize
    inputs = inference_tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_SEQ_LENGTH
    ).to(inference_model.device)

    # Streamer that hands decoded text chunks to this thread via a queue.
    streamer = TextIteratorStreamer(
        inference_tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )

    # Generation kwargs
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),  # UI sliders may deliver a float
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=inference_tokenizer.eos_token_id,
        eos_token_id=inference_tokenizer.eos_token_id,
        streamer=streamer,
    )

    generation_errors = []

    def _run_generation():
        """Run generate(); on failure, record the error and close the stream."""
        try:
            inference_model.generate(**generation_kwargs)
        except Exception as exc:
            generation_errors.append(exc)
            # Without this, the consumer loop below would wait forever for
            # an end-of-stream signal that generate() never got to send.
            streamer.end()

    # Run generation in separate thread so chunks can be yielded as they arrive.
    thread = Thread(target=_run_generation)
    thread.start()

    # Stream the output
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        # The streamer already skips the prompt (skip_prompt=True), so this
        # branch only triggers when the model itself emits "### Response:".
        if "### Response:" in partial_text:
            response = partial_text.split("### Response:")[1].strip()
        else:
            response = partial_text
        yield response

    thread.join()

    # Surface worker-thread failures to the caller instead of ending silently.
    if generation_errors:
        raise generation_errors[0]


# Sample legal questions for quick testing (with suggested sample types).
# Each entry is a (question text, sample type) pair; the UI below creates one
# button per entry that fills in the question box and the answer-format selector.
sample_questions = [
    ("What are the key duties of company directors under UK law?", "educational"),
    ("Explain the concept of consideration in contract law with a relevant case example.", "educational"),
    ("What constitutes unfair dismissal in employment law?", "client_interaction"),
    ("What is the difference between negligence and breach of statutory duty in tort law?", "educational"),
    ("How does the Corporate Manslaughter and Corporate Homicide Act 2007 apply to organizations?", "statutory_interpretation"),
]

# Create Gradio Interface with Streaming and Sample Type Selection
# Two-column layout: inputs (question, answer format, generation settings,
# sample buttons) on the left; the streamed answer and model info on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="UK Legal AI Assistant - Multi-Format") as demo:
    gr.Markdown("""
    # 🏛️ UK Legal AI Assistant (Multi-Format Training)
    ### Powered by Fine-tuned Qwen2-0.5B with 4 Answer Styles

    Choose your preferred answer format:
    - **📋 Case Analysis (IRAC)**: Problem-solving with Issue → Rule → Application → Conclusion
    - **📚 Educational**: Teaching format with Definition → Legal Basis → Key Elements → Examples
    - **💼 Client Interaction**: Practical advice with Understanding → Position → Options → Recommendation
    - **📜 Statutory Interpretation**: Legislative analysis with Text → Purpose → Interpretation → Case Law

    **Real-time streaming responses!** ✨
    """)

    with gr.Row():
        with gr.Column(scale=2):
            question_input = gr.Textbox(
                label="Legal Question",
                placeholder="Enter your UK law question here...",
                lines=3
            )

            # Sample Type Selector (NEW!)
            sample_type_selector = gr.Radio(
                choices=[
                    "case_analysis",
                    "educational",
                    "client_interaction",
                    "statutory_interpretation"
                ],
                value="case_analysis",
                label="📊 Answer Format",
                info="Select the style of answer you want"
            )

            with gr.Accordion("⚙️ Generation Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature (creativity)",
                    info="Lower = more focused, Higher = more creative"
                )

                max_tokens = gr.Slider(
                    minimum=128,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Max Tokens",
                    info="Maximum length of generated response"
                )

                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top P (nucleus sampling)",
                    info="Controls diversity of output"
                )

                repetition_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.1,
                    label="Repetition Penalty",
                    info="Penalize repeated tokens"
                )

            generate_btn = gr.Button("🔍 Generate Answer (Streaming)", variant="primary", size="lg")

            gr.Markdown("### 📝 Sample Questions")
            # One button per sample question; keep the (button, question, type)
            # triples so the click handlers can be wired up after the layout.
            sample_btns = []
            for sq, stype in sample_questions:
                btn = gr.Button(f"{sq} [{stype}]", size="sm")
                sample_btns.append((btn, sq, stype))

        with gr.Column(scale=3):
            answer_output = gr.Textbox(
                label="AI Response (Streaming)",
                lines=20,
                show_copy_button=True
            )

            # Dataset, rank and epochs come from the config cell so this panel
            # cannot drift from the values the notebook is configured with.
            gr.Markdown(f"""
            ---
            **Model Info:**
            - Base: Qwen/Qwen2-0.5B
            - Fine-tuned on: {DATASET_NAME} dataset (multi-format)
            - LoRA Rank: {LORA_R}
            - Training Epochs: {NUM_EPOCHS}
            - Sample Types: 4 (case_analysis, educational, client_interaction, statutory_interpretation)

            **💡 Multi-format training** - Model adapts answer style based on your selection!
            """)

    # Connect the generate button with streaming + sample type
    generate_btn.click(
        fn=generate_legal_answer_stream,
        inputs=[question_input, sample_type_selector, temperature, max_tokens, top_p, repetition_penalty],
        outputs=answer_output
    )

    # Connect sample question buttons (with auto sample type selection).
    # Default arguments bind each button's own question/type at definition
    # time; a plain closure would see only the last loop values.
    for btn, question, stype in sample_btns:
        btn.click(
            fn=lambda q=question, st=stype: (q, st),
            outputs=[question_input, sample_type_selector]
        )

    # Examples section at bottom
    gr.Examples(
        examples=[
            ["What are the requirements for a valid contract under English law?", "educational"],
            ["A client's company is facing insolvency. What should they do?", "client_interaction"],
            ["Explain piercing the corporate veil with case examples.", "case_analysis"],
            ["How does Section 172 of the Companies Act 2006 define directors' duties?", "statutory_interpretation"],
            ["What is the difference between wrongful and unfair dismissal?", "educational"],
        ],
        inputs=[question_input, sample_type_selector],
        label="💡 Example Questions with Sample Types"
    )

print("✅ Gradio UI with streaming and sample type selection created!")
print("🚀 Launching interface...")

# Launch the interface with the options collected in one place.
launch_options = {
    "share": True,  # Creates public link
    "debug": True,
    "show_error": True,
}
demo.launch(**launch_options)