In [None]:
import sys
# Make the parent directory importable so the project package `src` resolves below.
# NOTE(review): this assumes the notebook is launched from a directory one level
# below the repository root (e.g. notebooks/) — TODO confirm; it breaks if the
# kernel's working directory is elsewhere.
sys.path.append('..')

# Project helper used for "Format 1" below; its exact template is defined in src/data/preprocess.py.
from src.data.preprocess import format_medical_prompt

In [None]:
# Test sample: one QA record carrying the 'question' / 'answer' keys that every
# prompt formatter in this notebook reads.
sample = dict(
    question='What are the symptoms of type 2 diabetes?',
    answer='Common symptoms include increased thirst, frequent urination, increased hunger, fatigue, and blurred vision.',
)

In [None]:
# Format 1: Chat template, produced by the project helper format_medical_prompt.
prompt_v1 = format_medical_prompt(sample)
# One print with sep="\n" emits the banner, title, banner and prompt on separate lines.
print("=" * 50, "Prompt Version 1: Chat Template", "=" * 50, prompt_v1, sep="\n")

In [None]:
# Format 2: Instruction-following
def format_instruction_prompt(sample):
    """Render a QA record as a '### Instruction / ### Question / ### Answer' prompt.

    Args:
        sample: mapping with 'question' and 'answer' entries.

    Returns:
        The prompt string. Sections are separated by one blank line, and the
        answer text closes the prompt (i.e. the training-style, fully
        answered form).

    Raises:
        KeyError: if 'question' or 'answer' is missing from ``sample``.
    """
    question = sample['question']
    answer = sample['answer']
    sections = [
        "### Instruction:\nYou are a medical AI assistant. Answer the following medical question accurately and concisely.",
        f"### Question:\n{question}",
        f"### Answer:\n{answer}",
    ]
    return "\n\n".join(sections)

prompt_v2 = format_instruction_prompt(sample)
# Leading "\n" separates this banner from the previous cell's output.
print("\n" + "=" * 50, "Prompt Version 2: Instruction Format", "=" * 50, prompt_v2, sep="\n")

In [None]:
# Format 3: Few-shot with context
def format_fewshot_prompt(sample, context_examples=None):
    """Render a QA record as a plain 'Q:/A:' prompt, optionally preceded by worked examples.

    Args:
        sample: mapping with 'question' and 'answer' entries; the answer is
            appended after the final 'A:', so this is the fully answered form.
        context_examples: optional sequence of mappings with the same keys.
            Each one is rendered as a numbered "Example i:" block, starting at 1.

    Returns:
        The prompt string. When examples are given, it opens with an
        introductory header and introduces the target with "Now answer this
        question:". When there are none (None or empty), the header is omitted,
        so the prompt does not announce examples that are not there.

    Raises:
        KeyError: if a required 'question'/'answer' key is missing.
    """
    parts = []
    if context_examples:
        parts.append("Below are examples of medical question-answer pairs:\n\n")
        for i, ex in enumerate(context_examples, 1):
            parts.append(f"Example {i}:\nQ: {ex['question']}\nA: {ex['answer']}\n\n")
        parts.append("Now answer this question:\n")
    else:
        # Zero-shot: "Now ..." would refer back to examples that do not exist.
        parts.append("Answer this question:\n")
    parts.append(f"Q: {sample['question']}\nA: {sample['answer']}")
    return "".join(parts)

# Test with context: a single worked example precedes the target question.
context = [
    dict(question='What is hypertension?', answer='High blood pressure affecting blood vessels.'),
]
prompt_v3 = format_fewshot_prompt(sample, context)
# Leading "\n" separates this banner from the previous cell's output.
print("\n" + "=" * 50, "Prompt Version 3: Few-Shot Format", "=" * 50, prompt_v3, sep="\n")

In [None]:
# Compare token counts
# Tokenizes each prompt format with the Qwen2-0.5B tokenizer and prints its length.
# NOTE(review): this import sits mid-notebook; consider moving it to the top imports cell.
from transformers import AutoTokenizer

# Fetches the tokenizer from the Hugging Face Hub on first run (cached afterwards).
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

# Parallel lists: names[i] labels prompts[i]. zip() below silently stops at the
# shorter list, so keep them the same length.
prompts = [prompt_v1, prompt_v2, prompt_v3]
names = ["Chat", "Instruction", "Few-Shot"]

print("\n" + "=" * 50)
print("Token Count Comparison")
print("=" * 50)

for name, prompt in zip(names, prompts):
    # encode() uses the tokenizer's default add_special_tokens behaviour, so any
    # special tokens it inserts are counted too. NOTE(review): confirm whether
    # Qwen2 adds any.
    tokens = tokenizer.encode(prompt)
    print(f"{name}: {len(tokens)} tokens")