In [None]:
# Load both base model and LoRA adapter for comparison
from os.path import join

import torch
import yaml
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load configuration from config.yaml
with open("config.yaml") as f:
    config = yaml.safe_load(f)

MODEL_NAME = config["base_model_name"]
adapter_dir = join(config["adapter_dir_prefix"], MODEL_NAME)
print(f"Using model: {MODEL_NAME}")

# Load the base model for inference (without LoRA)
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    device_map="auto",
    local_files_only=True,
)

# Load the base model again and add LoRA adapter
lora_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    device_map="auto",
    local_files_only=True,
)

# Load the LoRA adapter
lora_model = PeftModel.from_pretrained(
    lora_model,
    adapter_dir,
    local_files_only=True,
)

# Load tokenizer for inference
inference_tokenizer = AutoTokenizer.from_pretrained(
    adapter_dir,
    local_files_only=True,
)
if inference_tokenizer.pad_token is None:
    inference_tokenizer.pad_token = inference_tokenizer.eos_token

print("Both base model and LoRA adapter loaded successfully for comparison!")

Both base model and LoRA adapter loaded successfully for comparison!


In [2]:
def generate_medical_response_single(
    model, question, max_length=512, temperature=0.7, do_sample=True
):
    """
    Generate a medical reasoning response for a given question using a specific model.

    Args:
        model: The model to use for generation
        question (str): The medical question to answer
        max_length (int): Maximum total length (input + output)
        temperature (float): Temperature for sampling (higher = more creative)
        do_sample (bool): Whether to use sampling or greedy decoding

    Returns:
        str: The generated response with reasoning
    """
    # Format the input as a conversation
    messages = [{"role": "user", "content": question}]

    # Apply chat template
    input_text = inference_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Tokenize the input
    inputs = inference_tokenizer(
        input_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length // 2,  # Leave room for generation
    ).to(model.device)

    # Calculate max_new_tokens based on input length and desired max_length
    input_length = inputs.input_ids.shape[1]
    max_new_tokens = max_length - input_length
    max_new_tokens = max(max_new_tokens, 50)  # Ensure at least 50 new tokens

    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,  # Use max_new_tokens instead of max_length
            temperature=temperature,
            do_sample=do_sample,
            pad_token_id=inference_tokenizer.pad_token_id,
            eos_token_id=inference_tokenizer.eos_token_id,
            repetition_penalty=1.1,
            top_p=0.9,
        )

    # Decode the response
    response = inference_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract only the generated part (remove the input prompt)
    if input_text in response:
        response = response.replace(input_text, "").strip()

    return response


def generate_medical_response_comparison(question, max_length=512, temperature=0.7, do_sample=True):
    """
    Generate medical reasoning responses from both base model and LoRA adapter for comparison.

    Args:
        question (str): The medical question to answer
        max_length (int): Maximum total length (input + output)
        temperature (float): Temperature for sampling (higher = more creative)
        do_sample (bool): Whether to use sampling or greedy decoding

    Returns:
        tuple: (base_response, lora_response) - responses from base model and LoRA adapter
    """
    print("Generating response from base model...")
    base_response = generate_medical_response_single(
        base_model, question, max_length, temperature, do_sample
    )

    print("Generating response from LoRA adapter...")
    lora_response = generate_medical_response_single(
        lora_model, question, max_length, temperature, do_sample
    )

    return base_response, lora_response


def generate_medical_response(question, max_length=512, temperature=0.7, do_sample=True):
    """
    Generate a medical reasoning response using the LoRA adapter (backward compatibility).
    """
    return generate_medical_response_single(
        lora_model, question, max_length, temperature, do_sample
    )


print("Medical response generation functions defined!")

Medical response generation functions defined!


In [3]:
# Interactive inference function with comparison
def interactive_medical_chat():
    """
    Interactive chat function for medical questions showing both base model and LoRA responses.
    Type 'quit' or 'exit' to stop.
    """
    print("=== Medical Reasoning Assistant - Model Comparison ===")
    print(
        "Ask me medical questions! I'll show responses from both the base model and LoRA adapter."
    )
    print("Type 'quit' or 'exit' to stop.\n")

    while True:
        try:
            # Get user input
            question = input("You: ").strip()

            # Check for exit commands
            if question.lower() in ["quit", "exit", "q"]:
                print("Goodbye!")
                break

            if not question:
                print("Please enter a question.")
                continue

            print("\nProcessing...")

            # Generate responses from both models
            base_response, lora_response = generate_medical_response_comparison(question)

            print(f"\n{'='*60}")
            print("ðŸ”µ BASE MODEL RESPONSE:")
            print("-" * 30)
            print(f"{base_response}")
            print(f"\n{'='*60}")
            print("ðŸŸ  LORA ADAPTER RESPONSE:")
            print("-" * 30)
            print(f"{lora_response}")
            print(f"\n{'='*60}")

        except KeyboardInterrupt:
            print("\nGoodbye!")
            break
        except Exception as e:
            print(f"Error: {e}")
            print("Please try again.")


print("Interactive medical chat function with comparison defined!")

Interactive medical chat function with comparison defined!


In [4]:
# Test the inference with sample questions - showing both models
test_questions = [
    "What are the common symptoms of diabetes?",
    "How should I treat a patient with hypertension?",
    "What diagnostic tests should be ordered for chest pain?",
    "Explain the mechanism of action of ACE inhibitors.",
    "What are the contraindications for aspirin therapy?",
]

print("=== Testing Medical Reasoning Model - Base vs LoRA Comparison ===\n")

for i, question in enumerate(test_questions, 1):
    print(f"{'='*80}")
    print(f"TEST QUESTION {i}: {question}")
    print(f"{'='*80}")

    try:
        base_response, lora_response = generate_medical_response_comparison(
            question, max_length=400, temperature=0.7
        )

        print("ðŸ”µ BASE MODEL RESPONSE:")
        print("-" * 40)
        print(f"{base_response}")
        print()

        print("ðŸŸ  LORA ADAPTER RESPONSE:")
        print("-" * 40)
        print(f"{lora_response}")
        print()

    except Exception as e:
        print(f"Error generating response: {e}")

    print(f"{'='*80}\n")

=== Testing Medical Reasoning Model - Base vs LoRA Comparison ===

TEST QUESTION 1: What are the common symptoms of diabetes?
Generating response from base model...
Generating response from LoRA adapter...
ðŸ”µ BASE MODEL RESPONSE:
----------------------------------------
user
What are the common symptoms of diabetes?
assistant
Diabetes is a group of metabolic disorders characterized by high blood sugar levels, which can lead to various complications if left unmanaged. The common symptoms of diabetes vary from person to person and may include:

1. **Fatigue**: Feeling tired or sluggish, even after getting enough rest.
2. **Increased thirst and urination**: Increased urine production, accompanied by dry mouth and increased hunger.
3. **Blurred vision**: Blurring of vision due to low blood sugar levels.
4. **Slow healing of cuts and wounds**: Slow healing of cuts and wound infections.
5. **Dizziness and lightheadedness**: Feeling dizzy or lightheaded when standing up quickly.
6. **Nausea

In [None]:
# Start interactive chat (uncomment to use)
# interactive_medical_chat()

# Alternative: Direct question answering with comparison
print("\n=== Direct Question Answering with Comparison ===")
print("You can now compare responses from both models:")
print(
    "Example: "
    "base_resp, lora_resp = "
    "generate_medical_response_comparison('What causes high blood pressure?')"
)
print("Then: print('Base:', base_resp)")
print("      print('LoRA:', lora_resp)")
print("\nOr for single LoRA response (backward compatibility):")
print("response = generate_medical_response('What causes high blood pressure?')")
print("\nRun interactive_medical_chat() for a continuous conversation with comparison.")


=== Direct Question Answering ===
You can also ask questions directly:
Example: response = generate_medical_response('What causes high blood pressure?')
Then: print(response)

Or run interactive_medical_chat() for a continuous conversation.


In [None]:
# Quick demo of base vs LoRA comparison
demo_question = "What are the main symptoms of Type 2 diabetes?"

print("=== DEMO: Base Model vs LoRA Adapter Comparison ===")
print(f"Question: {demo_question}\n")

base_response, lora_response = generate_medical_response_comparison(
    demo_question, max_length=300, temperature=0.7
)

print("ðŸ”µ BASE MODEL:")
print("-" * 20)
print(base_response)
print()

print("ðŸŸ  LORA ADAPTER:")
print("-" * 20)
print(lora_response)
print("\nNotice the differences in medical reasoning depth and structure!")

In [5]:
interactive_medical_chat()

=== Medical Reasoning Assistant ===
Ask me medical questions! Type 'quit' or 'exit' to stop.


Thinking...

Assistant: user
what are the 9 signs and symptoms of major depression?
assistant
Major depression is a serious mental health condition that can cause feelings of sadness, hopelessness, and loss of interest in activities. While there is no single test or symptom to diagnose major depression, certain signs and symptoms may indicate the presence of this condition. Here are the nine common signs and symptoms of major depression:

**1. Persistent Depressive Disorder (PD) Symptoms:**
	* Feeling sad, empty, or hopeless most of the time
	* Loss of interest in activities that were once enjoyable
	* Changes in appetite or sleep patterns
	* Fatigue or lethargy
	* Difficulty concentrating or making decisions
	* Irritability or mood swings
	* Feelings of guilt, worthlessness, or helplessness
	* Withdrawal from social interactions or relationships

**2. Recurring Feelings of Sadness or Grief:*