# Inference with TRL LoRA adapter model

> NOTE
> ----
> Run `trl_medical_reasoning_training.ipynb` first to fine-tune the model and save the LoRA adapter.

> Important
> ---------
> This notebook is for educational purposes only.

## Model Loading and Setup

This section loads both the base model and the LoRA-adapted model for comparison:

1. **Configuration Loading**: Reads model settings from `config.yaml`
2. **Base Model Loading**: Loads the original model from the local Hugging Face cache, without the LoRA adapter
3. **LoRA Model Loading**: Loads the same base model and applies the LoRA adapter weights
4. **Tokenizer Setup**: Loads the tokenizer saved alongside the adapter and uses the EOS token as the padding token if none is defined

The key difference is that we load the base model twice:
- `base_model`: The original model for baseline comparisons
- `lora_model`: The same base model with the medical-reasoning LoRA adapter applied

Both models use the same tokenizer and are loaded with `bfloat16` precision for efficient GPU memory usage.
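The loading cell reads three keys from `config.yaml`. The sketch below checks them before any weights are loaded, assuming the YAML has already been parsed into a dict (for example with `yaml.safe_load`). `resolve_inference_settings` is a hypothetical helper. The model name and adapter prefix mirror the paths the loading cell prints, while the `max_output_length` value is made up for illustration.

```python
from os.path import join
from typing import Tuple

REQUIRED_KEYS = ("base_model_name", "adapter_dir_prefix", "max_output_length")


def resolve_inference_settings(config: dict) -> Tuple[str, str, int]:
    """Validate a parsed config.yaml and derive the values the loading cell uses."""
    missing = [key for key in REQUIRED_KEYS if key not in config]
    if missing:
        raise KeyError(f"config.yaml is missing keys: {missing}")
    model_name = config["base_model_name"]
    # The adapter directory is <adapter_dir_prefix>/<model name>
    adapter_dir = join(config["adapter_dir_prefix"], model_name)
    max_output_length = int(config["max_output_length"])
    return model_name, adapter_dir, max_output_length


# Hypothetical parsed config; max_output_length is an illustrative value only
example_config = {
    "base_model_name": "microsoft/Phi-4-mini-instruct",
    "adapter_dir_prefix": "lora_adapter",
    "max_output_length": "512",
}
print(resolve_inference_settings(example_config))
```

Failing fast on a missing key avoids waiting for a multi-gigabyte model load before discovering a typo in the config.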

In [1]:
# Load both base model and LoRA adapter for comparison
from os.path import join

import torch
import yaml
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load configuration from config.yaml
with open("config.yaml") as f:
    config = yaml.safe_load(f)

MODEL_NAME = config["base_model_name"]
print(f"Using model: {MODEL_NAME}")

adapter_dir = join(config["adapter_dir_prefix"], MODEL_NAME)
print(f"LoRA adapter directory: {adapter_dir}")

max_output_length = int(config["max_output_length"])

# Load the base model for inference (without LoRA)
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    device_map="auto",
    local_files_only=True,
)

# Load the base model again and add LoRA adapter
lora_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    device_map="auto",
    local_files_only=True,
)

# Load the LoRA adapter
lora_model = PeftModel.from_pretrained(
    lora_model,
    adapter_dir,
    local_files_only=True,
)

# Load tokenizer for inference
inference_tokenizer = AutoTokenizer.from_pretrained(
    adapter_dir,
    local_files_only=True,
)
if inference_tokenizer.pad_token is None:
    inference_tokenizer.pad_token = inference_tokenizer.eos_token

print("Both base model and LoRA adapter loaded successfully for comparison!")

Using model: microsoft/Phi-4-mini-instruct
LoRA adapter directory: lora_adapter/microsoft/Phi-4-mini-instruct


Some parameters are on the meta device because they were offloaded to the cpu.


Both base model and LoRA adapter loaded successfully for comparison!
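
Loading the model twice means holding two full sets of weights. A back-of-the-envelope estimate, assuming Phi-4-mini-instruct has roughly 3.8 billion parameters and counting weights only (no activations, KV cache or framework overhead), shows why `device_map="auto"` may offload layers to the CPU, as the warning in the output above indicates. `weight_memory_gb` is a hypothetical helper.

```python
BYTES_PER_PARAM_BF16 = 2  # bfloat16 stores each weight in 16 bits


def weight_memory_gb(num_params: float, copies: int = 1,
                     bytes_per_param: int = BYTES_PER_PARAM_BF16) -> float:
    """Rough weight-only footprint in GB (1e9 bytes); ignores activations and overhead."""
    return num_params * bytes_per_param * copies / 1e9


approx_params = 3.8e9  # approximate parameter count of Phi-4-mini-instruct
print(f"base_model only:         {weight_memory_gb(approx_params):.1f} GB")
print(f"base_model + lora_model: {weight_memory_gb(approx_params, copies=2):.1f} GB")
```

That works out to about 7.6 GB per bfloat16 copy and about 15.2 GB for both (the LoRA weights themselves are comparatively small). If that exceeds the available GPU memory, one alternative (not used in this notebook) is to load a single copy and run the base-model calls inside PEFT's `lora_model.disable_adapter()` context manager.
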


## Response Generation Functions

This section defines the core inference functions for medical reasoning:

### Key Functions:

1. **`generate_medical_response_single()`**: 
   - Generates responses from a single model (base or LoRA)
   - Handles chat templating, tokenization, and response generation
   - Exposes `max_new_tokens`, `temperature` and `do_sample` as parameters; `top_p=0.9` and `repetition_penalty=1.1` are fixed

2. **`generate_medical_response_comparison()`**: 
   - Generates responses from both base and LoRA models
   - Returns both responses for side-by-side comparison
   - Helps reveal how the adapter changes the model's medical reasoning

3. **`generate_medical_response()`**: 
   - Convenience function that uses only the LoRA model
   - Maintains backward compatibility with existing code

Comparing the two models on the same question shows whether, and how, LoRA fine-tuning changes medical reasoning relative to the base model.
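The comparison function runs the same prompt through two fixed models. The same idea can be written against any set of generation callables, which also makes it testable without loading a model. A sketch with hypothetical stand-in generators; `compare_models` is not part of the notebook.

```python
from typing import Callable, Dict

# A generator maps a question to a response string
Generator = Callable[[str], str]


def compare_models(question: str, generators: Dict[str, Generator]) -> Dict[str, str]:
    """Ask every named generator the same question; dicts keep insertion order."""
    return {name: generate(question) for name, generate in generators.items()}


# Hypothetical stand-ins for the real base-model and LoRA generation calls
def fake_base(question: str) -> str:
    return f"base answer to: {question}"


def fake_lora(question: str) -> str:
    return f"lora answer to: {question}"


responses = compare_models(
    "What causes high blood pressure?",
    {"base": fake_base, "lora": fake_lora},
)
for name, text in responses.items():
    print(f"{name.upper()}: {text}")
```

With the notebook's objects, `functools.partial(generate_medical_response_single, base_model)` and `functools.partial(generate_medical_response_single, lora_model)` would fit the same slots.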

In [2]:
def generate_medical_response_single(
    model, question, max_new_tokens=1024, temperature=0.7, do_sample=True
):
    """
    Generate a medical reasoning response for a given question using a specific model.

    Args:
        model: The model to use for generation
        question (str): The medical question to answer
        max_new_tokens (int): Maximum number of new tokens to generate
        temperature (float): Temperature for sampling (higher = more creative)
        do_sample (bool): Whether to use sampling or greedy decoding

    Returns:
        str: The generated response with reasoning
    """
    # Format the input as a conversation with system prompt for CoT
    # Note: "/think" is how SmolLM3-3B enables extended thinking; other models such as
    # Phi-4-mini-instruct (used here) may simply treat it as plain prompt text
    messages = [
        {
            "role": "system",
            "content": "You are a medical AI assistant. "
            "When answering medical questions, use /think "
            "to show your reasoning process before providing "
            "your final answer. Structure your response as: "
            "/think [your detailed reasoning] [final answer]./think",
        },
        {"role": "user", "content": question},
    ]

    # Render the messages with the tokenizer's chat template and append the assistant turn prefix
    input_text = inference_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Tokenize the prompt and move the tensors to the model's device
    model_inputs = inference_tokenizer([input_text], return_tensors="pt").to(model.device)

    # Generate response
    with torch.no_grad():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=do_sample,
            pad_token_id=inference_tokenizer.pad_token_id,
            eos_token_id=inference_tokenizer.eos_token_id,
            repetition_penalty=1.1,
            top_p=0.9,
        )

    # Keep only the newly generated tokens (drop the echoed prompt) before decoding
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :]
    response = inference_tokenizer.decode(output_ids, skip_special_tokens=True)

    return response


def generate_medical_response_comparison(
    question, max_new_tokens=1024, temperature=0.7, do_sample=True
):
    """
    Generate medical reasoning responses from both base model and LoRA adapter for comparison.

    Args:
        question (str): The medical question to answer
        max_new_tokens (int): Maximum number of new tokens to generate
        temperature (float): Temperature for sampling (higher = more creative)
        do_sample (bool): Whether to use sampling or greedy decoding

    Returns:
        tuple: (base_response, lora_response) - responses from base model and LoRA adapter
    """
    print(f"Question: {question}")
    print("Generating response from base model...")
    base_response = generate_medical_response_single(
        base_model, question, max_new_tokens, temperature, do_sample
    )

    print("Generating response from LoRA adapter...")
    lora_response = generate_medical_response_single(
        lora_model, question, max_new_tokens, temperature, do_sample
    )

    return base_response, lora_response


def generate_medical_response(question, max_new_tokens=1024, temperature=0.7, do_sample=True):
    """
    Generate a medical reasoning response using the LoRA adapter (backward compatibility).
    """
    return generate_medical_response_single(
        lora_model, question, max_new_tokens, temperature, do_sample
    )


print("Medical response generation functions defined!")

Medical response generation functions defined!


## Interactive Medical Chat Interface

This section provides an interactive chat interface for medical questions:

### Features:
- **Side-by-Side Comparison**: Generates a response from each model in turn and prints both for every question
- **Interactive Input**: Type a medical question at the prompt; both responses are printed once generation finishes
- **Visual Formatting**: Clearly distinguishes between base model (🔵) and LoRA adapter (🟠) responses
- **Error Handling**: Gracefully handles errors and allows you to continue chatting
- **Easy Exit**: Type 'quit', 'exit', or 'q' to stop the chat session

This is useful for exploring, question by question, how LoRA fine-tuning affects the quality of medical reasoning.
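The exit and empty-input checks at the top of the chat loop can be pulled into a pure function so they can be tested without `input()`. A sketch; `classify_input` is a hypothetical helper that mirrors the notebook's checks in the same order.

```python
EXIT_COMMANDS = {"quit", "exit", "q"}


def classify_input(raw: str) -> str:
    """Return 'exit', 'empty' or 'question' for one line typed at the prompt."""
    text = raw.strip()
    # Exit commands are matched case-insensitively, as in the chat loop
    if text.lower() in EXIT_COMMANDS:
        return "exit"
    if not text:
        return "empty"
    return "question"


for line in ["  Q ", "", "What causes high blood pressure?"]:
    print(repr(line), "->", classify_input(line))
```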

In [3]:
# Interactive inference function with comparison
def interactive_medical_chat():
    """
    Interactive chat function for medical questions showing both base model and LoRA responses.
    Type 'quit', 'exit' or 'q' to stop.
    """
    print("=== Medical Reasoning Assistant - Model Comparison ===")
    print(
        "Ask me medical questions! I'll show responses from both the base model and LoRA adapter."
    )
    print("Type 'quit' or 'exit' to stop.\n")

    while True:
        try:
            # Get user input
            question = input("You: ").strip()

            # Check for exit commands
            if question.lower() in ["quit", "exit", "q"]:
                print("Goodbye!")
                break

            if not question:
                print("Please enter a question.")
                continue

            print("\nProcessing...")

            # Generate responses from both models
            base_response, lora_response = generate_medical_response_comparison(
                question, max_new_tokens=max_output_length
            )

            print(f"Question: {question}")

            print(f"\n{'='*60}")
            print("🔵 BASE MODEL RESPONSE:")
            print("-" * 30)
            print(f"{base_response}")
            print(f"\n{'='*60}")
            print("🟠 LORA ADAPTER RESPONSE:")
            print("-" * 30)
            print(f"{lora_response}")
            print(f"\n{'='*60}")

        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as e:
            print(f"Error: {e}")
            print("Please try again.")


print("Interactive medical chat function with comparison defined!")

Interactive medical chat function with comparison defined!


## Automated Testing with Sample Questions

> NOTE
> The training notebook (`trl_medical_reasoning_training.ipynb`) trains for only 1 epoch by default.
> This is not enough to produce a fine-tuned model
> that is substantially better than the base model.

This section runs a small set of sample questions through both models:

### Test Questions Include:
- **Symptom Recognition**: "What are the common symptoms of diabetes?"
- **Treatment Planning**: "How should I treat a patient with hypertension?"
- **Diagnostic Procedures**: "What diagnostic tests should be ordered for chest pain?"
- **Pharmacology**: "Explain the mechanism of action of ACE inhibitors."
- **Clinical Guidelines**: "What are the contraindications for aspirin therapy?"

### What to Look For:
- **Depth of Reasoning**: How detailed and structured are the responses?
- **Medical Accuracy**: Are the medical facts correct and up-to-date?
- **Clinical Relevance**: Do the responses address practical clinical scenarios?
- **Consistency**: Are similar questions answered with consistent quality?

The side-by-side output lets you judge qualitatively whether LoRA fine-tuning has improved the responses; it is not a quantitative evaluation.
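The criteria above are judged by reading the responses, but a few crude, surface-level signals can be collected automatically alongside them. A sketch with a hypothetical `summarize_response` helper and made-up example strings. Word count and punctuation say nothing about medical accuracy, though a response that stops mid-sentence may have hit the `max_new_tokens` limit.

```python
from dataclasses import dataclass


@dataclass
class ResponseSummary:
    words: int
    starts_with_think: bool
    ends_with_punctuation: bool


def summarize_response(text: str) -> ResponseSummary:
    """Collect simple surface features of one generated response."""
    stripped = text.strip()
    return ResponseSummary(
        words=len(stripped.split()),
        # The system prompt asks the model to open with "/think"
        starts_with_think=stripped.startswith("/think"),
        # A response that stops mid-sentence may have hit max_new_tokens
        ends_with_punctuation=stripped.endswith((".", "!", "?")),
    )


# Made-up example responses, for illustration only
base_example = "/think Diabetes raises blood glucose. Common symptoms include thirst"
lora_example = "/think Reasoning here. Final answer: polyuria, polydipsia and fatigue."
for label, text in (("base", base_example), ("lora", lora_example)):
    print(label, summarize_response(text))
```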

In [4]:
# Test the inference with sample questions - showing both models
test_questions = [
    "What are the common symptoms of diabetes?",
    "How should I treat a patient with hypertension?",
    "What diagnostic tests should be ordered for chest pain?",
    "Explain the mechanism of action of ACE inhibitors.",
    "What are the contraindications for aspirin therapy?",
]

print("=== Testing Medical Reasoning Model - Base vs LoRA Comparison ===\n")

for i, question in enumerate(test_questions, 1):
    print(f"{'='*80}")
    print(f"TEST QUESTION {i}: {question}")
    print(f"{'='*80}")

    try:
        base_response, lora_response = generate_medical_response_comparison(
            question, max_new_tokens=max_output_length, temperature=0.7
        )

        print("🔵 BASE MODEL RESPONSE:")
        print("-" * 40)
        print(f"{base_response}")
        print()

        print("🟠 LORA ADAPTER RESPONSE:")
        print("-" * 40)
        print(f"{lora_response}")
        print()

    except Exception as e:
        print(f"Error generating response: {e}")

    print(f"{'='*80}\n")

=== Testing Medical Reasoning Model - Base vs LoRA Comparison ===

TEST QUESTION 1: What are the common symptoms of diabetes?
Question: What are the common symptoms of diabetes?
Generating response from base model...
Generating response from LoRA adapter...
🔵 BASE MODEL RESPONSE:
----------------------------------------
/think To identify and understand the most frequent signs associated with Type 1 or Type 2 Diabetes Mellitus (DM), we must examine typical manifestations that patients may experience due to elevated blood glucose levels over time.

Type 1 DM often presents acutely in childhood but can also appear at any age; it's an autoimmune condition where insulin production is drastically reduced by pancreatic beta cells being destroyed.
In contrast, Type 2 DM generally develops gradually during adulthood and involves resistance against normal actions of bodily hormones on tissues combined with inadequate compensatory release from the pancreas.


Common symptoms for both types in

## Usage Examples and API Guide

This section demonstrates different ways to use the inference functions:

### Available Methods:

1. **Comparison Mode** (Recommended):
   ```python
   base_resp, lora_resp = generate_medical_response_comparison('Your question here')
   ```

2. **LoRA Only Mode**:
   ```python
   response = generate_medical_response('Your question here')
   ```

3. **Interactive Chat**:
   ```python
   interactive_medical_chat()
   ```

### Parameters:
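- `max_new_tokens`: Maximum number of new tokens to generate (function default: 1024; the test and chat cells pass `max_output_length` from `config.yaml`)
- `temperature`: Sampling temperature (default: 0.7); lower values such as 0.1 make output more deterministic, higher values such as 1.0 make it more varied. Only used when `do_sample=True`
- `do_sample`: `True` (default) samples using `temperature` and `top_p=0.9`; `False` uses greedy decoding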

Choose the method that best fits your use case!
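The sampling parameters above, together with the fixed `top_p=0.9` and `repetition_penalty=1.1` that `generate_medical_response_single()` passes to `model.generate()`, can be illustrated with a pure-Python sketch of one decoding step on a toy vocabulary. It approximates what Hugging Face `transformers` does (a CTRL-style repetition penalty, temperature scaling, then top-p filtering); the library works on tensors, may apply further processors such as a default `top_k`, and differs in edge cases. All function names here are hypothetical.

```python
import math
import random
from typing import Dict, List, Sequence


def apply_repetition_penalty(logits: Sequence[float], previous_ids: Sequence[int],
                             penalty: float = 1.1) -> List[float]:
    """CTRL-style penalty: lower the logit of every token that already appeared."""
    adjusted = list(logits)
    for token_id in set(previous_ids):
        score = adjusted[token_id]
        # Dividing a positive logit or multiplying a negative one both push it down
        adjusted[token_id] = score / penalty if score > 0 else score * penalty
    return adjusted


def softmax_with_temperature(logits: Sequence[float], temperature: float = 0.7) -> List[float]:
    """Temperature below 1 sharpens the distribution, above 1 flattens it."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs: Sequence[float], top_p: float = 0.9) -> Dict[int, float]:
    """Keep the most likely tokens until their cumulative probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept: List[int] = []
    cumulative = 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    # Renormalize so the surviving probabilities sum to 1
    return {i: probs[i] / mass for i in kept}


def sample_next_token(logits, previous_ids, temperature=0.7, top_p=0.9,
                      penalty=1.1, rng=random):
    """One decoding step: repetition penalty, then temperature, then nucleus sampling."""
    penalized = apply_repetition_penalty(logits, previous_ids, penalty)
    probs = softmax_with_temperature(penalized, temperature)
    candidates = top_p_filter(probs, top_p)
    token_ids, weights = zip(*candidates.items())
    return rng.choices(token_ids, weights=weights, k=1)[0]


# Toy 4-token vocabulary; token 0 was already generated, so it is penalized
toy_logits = [2.0, 1.8, 0.5, -1.0]
print(sample_next_token(toy_logits, previous_ids=[0], rng=random.Random(0)))
```

Setting `do_sample=False` skips the random draw and picks the highest-scoring token after the repetition penalty (greedy decoding), which makes repeated runs on the same prompt reproducible.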

In [5]:
# Start interactive chat (uncomment to use)
# interactive_medical_chat()

# Alternative: Direct question answering with comparison
print("\n=== Direct Question Answering with Comparison ===")
print("You can now compare responses from both models:")
print(
    "Example: "
    "base_resp, lora_resp = "
    "generate_medical_response_comparison('What causes high blood pressure?')"
)
print("Then: print('Base:', base_resp)")
print("      print('LoRA:', lora_resp)")
print("\nOr for single LoRA response (backward compatibility):")
print("response = generate_medical_response('What causes high blood pressure?')")
print("\nRun interactive_medical_chat() for a continuous conversation with comparison.")


=== Direct Question Answering with Comparison ===
You can now compare responses from both models:
Example: base_resp, lora_resp = generate_medical_response_comparison('What causes high blood pressure?')
Then: print('Base:', base_resp)
      print('LoRA:', lora_resp)

Or for single LoRA response (backward compatibility):
response = generate_medical_response('What causes high blood pressure?')

Run interactive_medical_chat() for a continuous conversation with comparison.


In [6]:
# Quick demo of base vs LoRA comparison
demo_question = "What are the main symptoms of Type 2 diabetes?"

print("=== DEMO: Base Model vs LoRA Adapter Comparison ===")
print(f"Question: {demo_question}\n")

base_response, lora_response = generate_medical_response_comparison(
    demo_question, max_new_tokens=300, temperature=0.7
)

print("🔵 BASE MODEL:")
print("-" * 20)
print(base_response)
print()

print("🟠 LORA ADAPTER:")
print("-" * 20)
print(lora_response)
print("\nCompare the two responses for differences in reasoning depth and structure.")

=== DEMO: Base Model vs LoRA Adapter Comparison ===
Question: What are the main symptoms of Type 2 diabetes?

Question: What are the main symptoms of Type 2 diabetes?
Generating response from base model...
Generating response from LoRA adapter...
🔵 BASE MODEL:
--------------------
/think To identify and understand type 2 diabetes' primary signs or 'symptoms', we need first consider its pathophysiology - this is when cells in our body become resistant to insulin (a hormone that regulates blood sugar), leading ultimately to elevated glucose levels ('hyperglycemia'). Over time these high glucose concentrations can damage various parts of the human anatomy including nerves (neuropathy) eyes (retinopathy). Common initial indications often include increased thirst/hydration issues, frequent urination/polyuria due to kidneys trying harder to get rid of excess sugars; fatigue/stress from inefficient cellular energy production processes caused by hyperglycemia affecting muscles/nervous syste

In [7]:
interactive_medical_chat()

=== Medical Reasoning Assistant - Model Comparison ===
Ask me medical questions! I'll show responses from both the base model and LoRA adapter.
Type 'quit' or 'exit' to stop.


Processing...
Question: a 16 year old boy has a fracture in his arm and his vertebrae. he has no family history of brittle bones. give me a differential diagnosis
Generating response from base model...
Generating response from LoRA adapter...
Question: a 16 year old boy has a fracture in his arm and his vertebrae. he has no family history of brittle bones. give me a differential diagnosis

🔵 BASE MODEL RESPONSE:
------------------------------
/think To diagnose the patient's condition with multiple fractures (arm and spine) at an early age without any known genetic predisposition for weak or fragile bone disease, we need to consider various possibilities that could lead to such injuries:

1. High-energy trauma - The patient might have experienced high-impact incidents like car accidents, falls from significan