### 1. Setup and Environment:
- Deploy Llama 3.2 1B using Ollama or vLLM (the code below calls a local Ollama server).
- Install the necessary packages (transformers, datasets, scikit-learn, sentence-transformers, requests, tqdm, numpy).

In [4]:
import os
import json
import random
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
from sklearn.neighbors import NearestNeighbors
import numpy as np
import requests

# Configuration
SEED = 42          # fixed seed so the random choice shuffling is repeatable across runs
NUM_FEW_SHOT = 3   # number of nearest training examples retrieved per test question
NUM_ENSEMBLE = 3   # number of LLM calls per question in the shuffled-choice ensemble

# Seed Python's global RNG, which the choice-shuffling step draws from.
# NOTE(review): this does not make the LLM's own sampling deterministic.
random.seed(SEED)

In [5]:
# Initialize models
# The three commented-out lines below would load the LLM in-process through
# transformers. They are disabled: generation in this notebook goes through the
# HTTP endpoint wrapped by llm_request instead.
# LLM_MODEL = 'meta-llama/Llama-3.2-1B'
# tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
# model = AutoModelForCausalLM.from_pretrained(LLM_MODEL)

# Sentence-embedding model used to encode questions for kNN few-shot retrieval.
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
embedder = SentenceTransformer(EMBEDDING_MODEL)



In [6]:
# Endpoint and model tag of the locally running Ollama server.
OLLAMA_API_URL = "http://localhost:11434/api/generate"
model_name = "llama3.2:1b"
# Instruction text prepended to every prompt sent by llm_request.
system_prompt = "You are a medical expert. Given a medical question with multiple answer choices, carefully evaluate the options and select the most accurate and evidence-based answer. You must output one choice as the answer. \n\n"

def llm_request(message, timeout=300):
    """Send one prompt to the local LLM endpoint and return the generated text.

    Args:
        message: Prompt text; `system_prompt` is prepended before sending.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled server would block the notebook indefinitely.

    Returns:
        The value of the `response` field of the server's JSON reply.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If no reply arrives within `timeout` seconds.
        KeyError: If the reply is valid JSON but has no `response` field.
    """
    payload = {
        "model": model_name,
        "prompt": system_prompt + message,
        # Must stay False: the reply is parsed below as a single JSON document.
        "stream": False
    }

    # Send the request to the local LLM
    response = requests.post(OLLAMA_API_URL, json=payload, timeout=timeout)
    # Fail loudly on HTTP errors instead of tripping over a missing key later.
    response.raise_for_status()

    body = response.json()
    if 'response' not in body:
        # Surface whatever the server reported rather than a bare KeyError.
        # NOTE(review): Ollama appears to report failures under an "error" key — confirm.
        raise KeyError(f"LLM reply has no 'response' field: {body.get('error', body)}")
    return body['response']

### 2. Data Preparation (MedQA Benchmark):
- Load the MedQA dataset (train and test splits).
- Sample 2,000 examples from the training split for CoT generation.
- Convert test split into a standardized multiple-choice format.

In [7]:
from med_qa import MedQAExamplarSampler

In [8]:
# Initialize sampler
# MedQAExamplarSampler comes from the project-local med_qa module.
sampler = MedQAExamplarSampler()

# Load dataset using MedQADataset
print("Loading MedQA dataset...")
examples = sampler.load_dataset()

# Sample examples
# NOTE(review): presumably train_num/test_num are the sizes of the two returned
# splits, and whether sampling is seeded depends on med_qa — confirm there.
print("\nSampling examples...")
train_set, test_set = sampler.sample_examples(examples, train_num=2000, test_num=250)
# Show which fields a record carries; later cells read 'question', 'options',
# 'answer' and 'answer_idx'.
print(train_set[0].keys())

Loading MedQA dataset...
../data/med_qa/data_clean/questions/US/train.jsonl
../data/med_qa/data_clean/questions/US/test.jsonl

Sampling examples...
dict_keys(['question', 'answer', 'options', 'meta_info', 'answer_idx'])


### 3. Implementation of MedPrompt

(1) Dynamic Few-shot Learning Implementation:
- Embed the test question and the CoT examples using the embedding model.
- Use kNN to retrieve the most semantically similar examples from the generated CoT.

In [9]:
def get_few_shot_examples(dataset, test_question):
    """Format the training examples nearest to `test_question` as prompt text.

    The question is embedded with the module-level `embedder`, and the fitted
    module-level `knn` index supplies the row indices of its neighbours.

    Args:
        dataset: Indexable collection of records with 'question', 'options',
            'answer_idx' and 'answer' fields, in the row order `knn` was fitted on.
        test_question: Text of the question to find neighbours for.

    Returns:
        One string holding a "Question / Choices / Answer" stanza per
        neighbour, each followed by a blank line.
    """
    query_vector = embedder.encode([test_question])
    # Only the neighbour indices are needed; the distances are discarded.
    _, neighbor_rows = knn.kneighbors(query_vector)
    neighbors = (dataset[int(row)] for row in neighbor_rows[0])
    return ''.join(
        f"Question: {shot['question']}\n"
        f"Choices: {shot['options']}\n"
        f"Answer: {shot['answer_idx']}. {shot['answer']}\n\n"
        for shot in neighbors
    )

In [10]:
# Generate embeddings for training examples
train_texts = [ex['question'] for ex in train_set]
train_embeddings = embedder.encode(train_texts, convert_to_tensor=True).cpu().numpy()

# Build kNN index
knn = NearestNeighbors(n_neighbors=NUM_FEW_SHOT, metric='cosine')
knn.fit(train_embeddings)

(2) Self-Generated Chain of Thought (CoT):
- Generate CoT examples using the LLM.
- Verify CoT correctness by comparing the model-generated answer with the ground truth.
- Filter out incorrect reasoning steps.

In [11]:
def generate_chain_of_thought(example):
    """Ask the LLM for step-by-step reasoning about one question.

    The prompt tells the model to analyse "each option in the choices", so the
    options are included in the prompt when the record has them. Without them
    the model has nothing to analyse.

    Args:
        example: Record with a 'question' field and, optionally, an 'options'
            field holding the answer choices.

    Returns:
        The raw text returned by `llm_request` for the reasoning prompt.
    """
    options = example.get('options')
    # Leave the section out entirely when there are no options to show.
    choices_section = f"   - Choices: {options}\n\n" if options else "\n"
    prompt = (
            "Solve this medical question step by step:\n\n"
            "1. Understanding the question:\n"
            f"   - The question asks about {example['question']}\n"
            f"{choices_section}"
            "2. Analyzing each option in the choices\n"
            "3. Medical reasoning:\n"
            "   - Key clinical point 1...\n"
            "   - Key clinical point 2...\n"
            "   - Relevant medical knowledge...\n\n"
            "Output the correct answer."
        )
    return llm_request(prompt)


(3) Choice Shuffling and Ensemble:
- Shuffle the order of multiple-choice answers and perform inference multiple times.
- Aggregate results using majority voting for final predictions.

In [12]:
def _shuffle_choices(choices):
    """Return a randomly reordered copy of `choices`; the input is not modified."""
    if isinstance(choices, dict):
        # random.sample rejects dicts, so shuffle the option texts and attach
        # them to the original labels in their original order.
        shuffled_texts = random.sample(list(choices.values()), len(choices))
        return dict(zip(choices, shuffled_texts))
    return random.sample(list(choices), len(choices))


def _vote_key(response, option_texts):
    """Map a free-text response to the thing it votes for.

    If exactly one option text occurs in the response, that option is the vote.
    Otherwise the stripped response stands for itself.
    """
    mentioned = {text for text in option_texts if text and text in response}
    if len(mentioned) == 1:
        return mentioned.pop()
    return response.strip()


def choice_shuffling_and_voting(test_question, choices, few_shot_examples=None):
    """Query the LLM NUM_ENSEMBLE times with shuffled choices and majority-vote.

    Responses are grouped by the single option they mention, so differently
    worded responses that name the same option count as the same vote. Ties are
    broken in favour of the vote that was cast first, which keeps the result
    independent of string hashing.

    Args:
        test_question: Question text placed at the top of each prompt.
        choices: Answer options, either a sequence of texts or a dict mapping
            labels to texts.
        few_shot_examples: Optional pre-formatted examples; when given they
            are added to the prompt under "Examples:".

    Returns:
        The first raw LLM response belonging to the winning vote.

    Raises:
        ValueError: If NUM_ENSEMBLE is less than 1, so no response was collected.
    """
    option_values = choices.values() if isinstance(choices, dict) else choices
    option_texts = [str(text) for text in option_values]

    # vote key -> raw responses in arrival order (dicts keep insertion order)
    ballots = {}
    for _ in range(NUM_ENSEMBLE):
        shuffled_choices = _shuffle_choices(choices)
        if few_shot_examples is not None:
            prompt = (f"Question: {test_question}\n"
                    f"Choices: {shuffled_choices}\n\n"
                    f"Examples: {few_shot_examples}\n\n"
                    f"Answer:")
        else:
            prompt = (f"Question: {test_question}\n"
                    f"Choices: {shuffled_choices}\n"
                    f"Answer:")

        answer = llm_request(prompt)
        ballots.setdefault(_vote_key(answer, option_texts), []).append(answer)

    if not ballots:
        raise ValueError("NUM_ENSEMBLE must be at least 1 to produce a prediction")
    # max() keeps the first maximal group, i.e. the earliest vote wins a tie.
    winning_responses = max(ballots.values(), key=len)
    return winning_responses[0]

### 4. Evaluation:
- Evaluate on the test split of MedQA using accuracy as the primary metric.
- Compare performance with and without MedPrompt components using ablation studies.

In [21]:
# Full pipeline: kNN few-shot retrieval, CoT generation, then the shuffled-choice
# ensemble. A hit is counted when the reference answer text occurs anywhere in
# the voted response.
correct = 0
for example in tqdm(test_set):
    question, choices, correct_answer = (
        example[field] for field in ('question', 'options', 'answer')
    )
    few_shot_examples = get_few_shot_examples(train_set, question)
    # The generated reasoning takes the place of the question text in the final prompt.
    question_cot = generate_chain_of_thought(example)
    prediction = choice_shuffling_and_voting(question_cot, choices, few_shot_examples)
    correct += int(correct_answer in prediction)

accuracy = correct / len(test_set)
print(f"Test Accuracy: {accuracy:.2%}")

 22%|██▏       | 56/250 [26:02<1:35:32, 29.55s/it]

### 5. Ablation Study

Component Contribution: 
- Dynamic Few-shot Learning
- Self-Generated CoT
- Choice Shuffling Ensemble

In [None]:
# Baseline: one plain prompt per question, with none of the three components.
correct = 0
for example in test_set:
    question, choices, correct_answer = (
        example[field] for field in ('question', 'options', 'answer')
    )
    prompt = f"Question: {question}\nChoices: {choices}\nAnswer:"
    prediction = llm_request(prompt)
    # Scored by substring match on the reference answer text.
    correct += int(correct_answer in prediction)

accuracy = correct / len(test_set)
print(f"Test Accuracy: {accuracy:.2%}")

In [None]:
# Baseline + Few-shot: retrieved examples are added to a single plain prompt.
correct = 0
for example in test_set:
    question, choices, correct_answer = (
        example[field] for field in ('question', 'options', 'answer')
    )
    few_shot_examples = get_few_shot_examples(train_set, question)
    prompt = (
        f"Question: {question}\nChoices: {choices}\n"
        f"Examples: {few_shot_examples}\nAnswer:"
    )
    prediction = llm_request(prompt)
    # Scored by substring match on the reference answer text.
    correct += int(correct_answer in prediction)

accuracy = correct / len(test_set)
print(f"Test Accuracy: {accuracy:.2%}")

In [None]:
# Baseline + CoT: the generated reasoning replaces the question text in a single prompt.
correct = 0
for example in test_set:
    question, choices, correct_answer = (
        example[field] for field in ('question', 'options', 'answer')
    )
    question_cot = generate_chain_of_thought(example)
    prompt = f"Question: {question_cot}\nChoices: {choices}\nAnswer:"
    prediction = llm_request(prompt)
    # Scored by substring match on the reference answer text.
    correct += int(correct_answer in prediction)

accuracy = correct / len(test_set)
print(f"Test Accuracy: {accuracy:.2%}")

In [None]:
# Baseline + Shuffle Ensemble: the raw question goes through the shuffled-choice vote.
correct = 0
for example in test_set:
    question, choices, correct_answer = (
        example[field] for field in ('question', 'options', 'answer')
    )
    prediction = choice_shuffling_and_voting(question, choices)
    # Scored by substring match on the reference answer text.
    correct += int(correct_answer in prediction)

accuracy = correct / len(test_set)
print(f"Test Accuracy: {accuracy:.2%}")

In [None]:
# Baseline + CoT + Shuffle Ensemble: generated reasoning is fed to the shuffled-choice vote.
correct = 0
for example in test_set:
    question, choices, correct_answer = (
        example[field] for field in ('question', 'options', 'answer')
    )
    question_cot = generate_chain_of_thought(example)
    prediction = choice_shuffling_and_voting(question_cot, choices)
    # Scored by substring match on the reference answer text.
    correct += int(correct_answer in prediction)

accuracy = correct / len(test_set)
print(f"Test Accuracy: {accuracy:.2%}")

In [None]:
# Baseline + Few-shot + Shuffle Ensemble: retrieved examples accompany the raw
# question in the shuffled-choice vote.
correct = 0
for example in test_set:
    question, choices, correct_answer = (
        example[field] for field in ('question', 'options', 'answer')
    )
    few_shot_examples = get_few_shot_examples(train_set, question)
    prediction = choice_shuffling_and_voting(question, choices, few_shot_examples)
    # Scored by substring match on the reference answer text.
    correct += int(correct_answer in prediction)

accuracy = correct / len(test_set)
print(f"Test Accuracy: {accuracy:.2%}")

In [None]:
# Baseline + CoT + Few-shot: generated reasoning plus retrieved examples in a
# single prompt, with no ensemble.
correct = 0
for example in test_set:
    question, choices, correct_answer = (
        example[field] for field in ('question', 'options', 'answer')
    )
    few_shot_examples = get_few_shot_examples(train_set, question)
    question_cot = generate_chain_of_thought(example)
    prompt = (
        f"Question: {question_cot}\nChoices: {choices}\n"
        f"Examples: {few_shot_examples}\nAnswer:"
    )
    prediction = llm_request(prompt)
    # Scored by substring match on the reference answer text.
    correct += int(correct_answer in prediction)

accuracy = correct / len(test_set)
print(f"Test Accuracy: {accuracy:.2%}")