<a href="https://colab.research.google.com/github/seeboontay/LLM_symptom_extraction/blob/main/LLM_symptom.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# %pip install -q transformers pandas  # Uncomment on a fresh runtime (e.g. Colab) to install Hugging Face Transformers and pandas



In [None]:
from transformers import pipeline

# Build the Hugging Face text-generation pipeline that the later cells call.
# distilgpt2 is a small checkpoint, chosen so testing stays fast.
symptom_extractor = pipeline(task="text-generation", model="distilgpt2")

In [None]:
# Few-shot demonstrations: each entry pairs a clinical note with the symptoms
# it affirms (the negated "fever" in the first note is intentionally left out).
examples = [
    dict(note=note_text, symptoms=affirmed)
    for note_text, affirmed in (
        (
            "Patient reports headache and nausea. Denies fever.",
            ["headache", "nausea"],
        ),
        (
            "Complains of chest pain and shortness of breath.",
            ["chest pain", "shortness of breath"],
        ),
    )
]

In [None]:
def build_prompt(note, few_shot_examples=None):
    """Build a few-shot prompt asking the model to list the symptoms in ``note``.

    Args:
        note: Clinical note text to analyse.
        few_shot_examples: Optional iterable of dicts with ``"note"`` and
            ``"symptoms"`` keys used as demonstrations. When omitted, the
            notebook-level ``examples`` list is used, so existing callers get
            exactly the same prompt as before.

    Returns:
        The prompt string: an instruction header, one Note/Symptoms pair per
        demonstration, then the new note followed by a trailing ``Symptoms:``
        cue for the model to complete.
    """
    if few_shot_examples is None:
        # Looked up at call time so edits to the notebook-level list are picked up.
        few_shot_examples = examples

    header = "Extract symptoms from clinical notes. Ignore negated symptoms.\nExamples:\n"
    # Symptoms are rendered with the list's default string form, e.g. ['headache', 'nausea'].
    demonstrations = "".join(
        f"Note: {shot['note']}\nSymptoms: {shot['symptoms']}\n"
        for shot in few_shot_examples
    )
    return f"{header}{demonstrations}\nNote: {note}\nSymptoms:"

# Preview the prompt for a sample note that includes a negated symptom ("No vomiting").
test_note = "Fatigue and dizziness x3 days. No vomiting."
print(build_prompt(note=test_note))

In [None]:
# Generate a response
prompt = build_prompt(test_note)
response = symptom_extractor(
    prompt,
    # Budget only the continuation: max_length counts the prompt tokens as well, so a
    # long few-shot prompt can leave little or no room under a fixed total.
    max_new_tokens=30,
    temperature=0.1,  # Low randomness for precise output
    return_full_text=False,  # generated_text then holds only the continuation
)

# Extract the answer: keep just the first line of the continuation, so any further
# "Note:" / "Symptoms:" pairs the model goes on to invent are not mistaken for the answer.
completion = response[0]['generated_text'].strip()
output = completion.split("\n", 1)[0].strip()
print("Extracted Symptoms:", output)

In [None]:
# Small in-memory list of known symptoms, standing in for a retrieval database.
symptom_database = [
    "headache",
    "nausea",
    "chest pain",
    "shortness of breath",
    "fatigue",
    "dizziness",
]

def rag_augment(note, database=None):
    """Prefix ``note`` with the known symptoms that appear in it.

    Matching is a case-insensitive substring search, so a negated mention
    (e.g. "no nausea") is still reported; the result is labelled
    "Possible symptoms" for that reason.

    Args:
        note: Clinical note text.
        database: Optional iterable of symptom strings to search for. Defaults
            to the notebook-level ``symptom_database``.

    Returns:
        A two-line string: "Possible symptoms: [...]" listing the matches in
        database order, followed by "Note: " and the original note.
    """
    if database is None:
        # Looked up at call time so edits to the notebook-level list are picked up.
        database = symptom_database

    note_lower = note.lower()  # lower-case once instead of once per database entry
    matched_symptoms = [symptom for symptom in database if symptom.lower() in note_lower]
    return f"Possible symptoms: {matched_symptoms}\nNote: {note}"

# Test RAG: augment the sample note once and print the stored result
# (previously the augmentation was computed twice and rag_note went unused).
rag_note = rag_augment(test_note)
print(rag_note)

In [None]:
def extract_symptoms(note):
    """Run the full pipeline on ``note`` and return the model's symptom text.

    The note is augmented with database matches, wrapped in the few-shot
    prompt, and passed to the text-generation pipeline; only the first line
    of the generated continuation is kept.

    Args:
        note: Clinical note text.

    Returns:
        The first line of the model's continuation with surrounding whitespace
        removed (an empty string if the continuation is blank).
    """
    # Step 1: Augment with RAG
    rag_note = rag_augment(note)

    # Step 2: Build the prompt
    prompt = build_prompt(rag_note)

    # Step 3: Generate. max_new_tokens budgets only the continuation, whereas
    # max_length also counts the prompt tokens and can leave no room to generate.
    response = symptom_extractor(
        prompt,
        max_new_tokens=30,
        temperature=0.1,
        return_full_text=False,  # generated_text then holds only the continuation
    )

    # Step 4: Parse. Keep just the first line so any extra "Note:" / "Symptoms:"
    # pairs the model goes on to invent are not mistaken for the answer.
    completion = response[0]['generated_text'].strip()
    return completion.split("\n", 1)[0].strip()

# End-to-end check on a note with two affirmed symptoms and one negated one.
test_note_2 = "Sore throat and runny nose. Denies fever."
print(extract_symptoms(note=test_note_2))