# COVID-19 Detection: Named Entity Recognition Pipeline

This notebook demonstrates the first stage of our COVID-19 detection pipeline: Named Entity Recognition (NER) for extracting relevant medical entities from unstructured text.

## Pipeline Overview

Our COVID-19 detection pipeline consists of two major stages:

1. **Named Entity Recognition (NER)**: Extract medical entities from text
2. **Classification**: Determine whether the text indicates COVID-19 or another condition

This notebook focuses on the NER stage.
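
The two stages compose as a plain function pipeline: the output of the NER stage is the input of the classifier. The sketch below is illustrative only. `run_pipeline`, `toy_extract`, and `toy_classify` are hypothetical names, and the two toy stages (a keyword lookup and a threshold on the number of mentions) are placeholders, not the project's actual extractor or classifier.

```python
from typing import Callable, Dict, List

Entities = Dict[str, List[str]]


def run_pipeline(
    text: str,
    extract: Callable[[str], Entities],
    classify: Callable[[Entities], int],
) -> int:
    """Stage 1 (NER) feeds stage 2 (classification)."""
    entities = extract(text)
    return classify(entities)


def toy_extract(text: str) -> Entities:
    """Hypothetical stand-in for the NER stage: plain keyword lookup."""
    keywords = ["fever", "cough", "loss of taste"]
    lowered = text.lower()
    return {"SYMPTOM": [k for k in keywords if k in lowered]}


def toy_classify(entities: Entities) -> int:
    """Hypothetical stand-in for the classifier: 1 if two or more symptoms were found."""
    return int(len(entities["SYMPTOM"]) >= 2)


label = run_pipeline("Fever and dry cough for 3 days.", toy_extract, toy_classify)
print(label)
```

Passing the stages in as callables keeps the sketch independent of how either stage is implemented, so a rule-based extractor can later be swapped for a learned one without touching the pipeline function.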

## 1. Setup and Imports

In [None]:
import os
import re
import json
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict, Counter

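# Import our custom modules; the simplified fallbacks defined below are used if this fails
import sys
sys.path.append('..')
try:
    from src.data_collection import generate_synthetic_clinical_note
    from src.ner_extraction import extract_entities_from_text, format_entities_for_bert
except ImportError as e:
    print(f"Custom modules not available, fallbacks will be used: {e}")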

# Create output directory
os.makedirs('../output', exist_ok=True)

## 2. Generate Sample Data

For demonstration purposes, we'll generate ten synthetic clinical notes: five indicating COVID-19 and five indicating other conditions.
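
The simplified fallback function in this notebook returns one fixed note per class, so every note within a class is identical. If more varied notes are wanted without the project module, a template filled from a seeded random number generator is one option. This is a minimal sketch under stated assumptions: `make_varied_note`, the symptom pools, and the sentence template are hypothetical and are not taken from `src.data_collection`.

```python
import random

# Hypothetical symptom pools, used only for illustration.
COVID_POOL = ["fever", "dry cough", "fatigue", "loss of taste", "loss of smell"]
OTHER_POOL = ["sore throat", "nasal congestion", "productive cough", "headache"]


def make_varied_note(has_covid: bool, rng: random.Random) -> str:
    """Build a one-sentence synthetic note from randomly sampled fields."""
    pool = COVID_POOL if has_covid else OTHER_POOL
    symptoms = rng.sample(pool, k=3)  # three distinct symptoms
    age = rng.randint(18, 90)
    sex = rng.choice(["male", "female"])
    days = rng.randint(2, 10)
    return (
        f"Patient is a {age}-year-old {sex} who presents with "
        f"{symptoms[0]}, {symptoms[1]}, and {symptoms[2]} "
        f"for the past {days} days."
    )


rng = random.Random(42)  # fixed seed so reruns produce the same notes
sample_notes = [make_varied_note(has_covid=True, rng=rng) for _ in range(3)]
for note_text in sample_notes:
    print(note_text)
```

Passing the generator in explicitly, rather than calling the module-level `random` functions, keeps the output reproducible even when other cells also draw random numbers.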

In [None]:
# Fallback used if the project modules cannot be imported.
# It returns one fixed template per class, so all notes of a class are identical.
def generate_synthetic_note(has_covid=True):
    """Generate a synthetic clinical note."""
    if has_covid:
        subjective = "Patient is a 45-year-old male who presents with fever, dry cough, and fatigue for the past 3 days. Patient also reports loss of taste and smell since yesterday."
        objective = "Vitals: Temp 38.5°C, HR 95, BP 128/82, RR 18, O2 Sat 94% on room air. Physical exam reveals mild respiratory distress. Lungs with scattered rhonchi bilaterally. No rales or wheezes."
        assessment = "Assessment: Clinical presentation consistent with COVID-19 infection."
        plan = "Plan: COVID-19 PCR test ordered. Patient advised to self-isolate pending results. Symptomatic treatment with acetaminophen for fever. Follow up in 2-3 days."
    else:
        subjective = "Patient is a 34-year-old female with sore throat, nasal congestion, and productive cough with green sputum for 5 days. No fever or shortness of breath reported."
        objective = "Vitals: Temp 37.2°C, HR 72, BP 118/76, RR 16, O2 Sat 98% on room air. Physical exam shows erythematous pharynx with tonsillar exudate. No respiratory distress."
        assessment = "Assessment: Acute bacterial pharyngitis, likely streptococcal in origin."
        plan = "Plan: Rapid strep test performed and positive. Prescribed amoxicillin 500mg TID for 10 days. Symptomatic treatment with acetaminophen and warm salt water gargles."
    
    note = f"{subjective}\n\n{objective}\n\n{assessment} {plan}"
    return note

In [None]:
# Generate 10 synthetic notes (5 COVID-19, 5 non-COVID)
try:
    # Try to use our module function
    covid_notes = [generate_synthetic_clinical_note(has_covid=True) for _ in range(5)]
    non_covid_notes = [generate_synthetic_clinical_note(has_covid=False) for _ in range(5)]
except Exception as e:
    print(f"Using simplified function due to error: {e}")
    # Fall back to simplified function
    covid_notes = [generate_synthetic_note(has_covid=True) for _ in range(5)]
    non_covid_notes = [generate_synthetic_note(has_covid=False) for _ in range(5)]

# Combine into a dataset
notes = covid_notes + non_covid_notes
labels = [1] * 5 + [0] * 5  # 1 for COVID, 0 for non-COVID

# Create a dataframe
df = pd.DataFrame({
    'note_id': [f'note_{i}' for i in range(10)],
    'text': notes,
    'has_covid': labels
})

# Display a sample note from each class
print("=== COVID-19 CLINICAL NOTE ===\n")
print(df[df['has_covid'] == 1]['text'].iloc[0])
print("\n\n=== NON-COVID CLINICAL NOTE ===\n")
print(df[df['has_covid'] == 0]['text'].iloc[0])

## 3. Rule-Based NER Extraction

We'll start with a simple rule-based approach to extract medical entities from the notes.
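
A minimal sketch of the word-boundary matching idea, using a small illustrative vocabulary. It departs from a per-term loop by compiling all terms into a single alternation (longest first), so matches come back in text order and a longer phrase is preferred over a shorter one starting at the same position. `find_terms` is a hypothetical helper name.

```python
import re
from typing import Any, Dict, List


def find_terms(text: str, terms: List[str]) -> List[Dict[str, Any]]:
    """Return case-insensitive whole-word matches with character offsets."""
    # Longest terms first so e.g. "sore throat" is preferred over "sore".
    ordered = sorted(terms, key=len, reverse=True)
    pattern = re.compile(
        r"\b(?:" + "|".join(re.escape(t) for t in ordered) + r")\b",
        re.IGNORECASE,
    )
    return [
        {"text": m.group(0), "start": m.start(), "end": m.end()}
        for m in pattern.finditer(text)
    ]


text = "Fever and dry cough; no coughing at night."
hits = find_terms(text, ["fever", "cough"])
for h in hits:
    # The offsets index back into the original string.
    print(h, text[h["start"]:h["end"]])
```

Because of the trailing `\b`, "coughing" is not reported as a match for "cough". Keeping the `start` and `end` offsets allows later steps to look at the surrounding context of each mention.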

In [None]:
# Function for rule-based entity extraction
def extract_entities_with_rules(text):
    """Extract entities using simple rule-based patterns."""
    # COVID symptoms
    COVID_SYMPTOMS = [
        "fever", "cough", "shortness of breath", "difficulty breathing", 
        "fatigue", "muscle pain", "body ache", "headache", "loss of taste",
        "loss of smell", "sore throat", "congestion", "runny nose", "nausea",
        "vomiting", "diarrhea", "chills"
    ]
    
    # Time expressions
    TIME_EXPRESSIONS = [
        "days ago", "weeks ago", "yesterday", "last week", "since", "for the past",
        "hours", "days", "weeks", "months", "began", "started", "onset"
    ]
    
    # Severity indicators
    SEVERITY_INDICATORS = [
        "mild", "moderate", "severe", "slight", "significant", "extreme",
        "worsening", "improving", "persistent", "intermittent", "constant"
    ]
    
    entities = {
        "SYMPTOM": [],
        "TIME": [],
        "SEVERITY": []
    }
    
    # Extract symptoms
    for symptom in COVID_SYMPTOMS:
        pattern = re.compile(r'\b({})\b'.format(re.escape(symptom)), re.IGNORECASE)
        for match in pattern.finditer(text):
            entities["SYMPTOM"].append({
                "text": match.group(0),
                "start": match.start(),
                "end": match.end()
            })
    
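    # Extract time expressions, keeping at most one preceding token (e.g. "3 days").
    # An unbounded [\w\s]+ prefix would swallow whole clauses and cross line breaks.
    for time_expr in TIME_EXPRESSIONS:
        pattern = re.compile(r'\b(?:\w+\s+)?{}\b'.format(re.escape(time_expr)), re.IGNORECASE)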
        for match in pattern.finditer(text):
            entities["TIME"].append({
                "text": match.group(0),
                "start": match.start(),
                "end": match.end()
            })
    
    # Extract severity indicators
    for severity in SEVERITY_INDICATORS:
        pattern = re.compile(r'\b({})\s+\w+'.format(re.escape(severity)), re.IGNORECASE)
        for match in pattern.finditer(text):
            entities["SEVERITY"].append({
                "text": match.group(0),
                "start": match.start(),
                "end": match.end()
            })
    
    return entities

In [None]:
# Extract entities from all notes
try:
    # Try to use our module function
    extracted_entities = [extract_entities_from_text(note, method="rule") for note in notes]
except Exception as e:
    print(f"Using simplified function due to error: {e}")
    # Fall back to simplified function
    extracted_entities = [extract_entities_with_rules(note) for note in notes]

# Add entities to dataframe
df['entities'] = extracted_entities

# Display extracted entities for a COVID-19 note
covid_note_idx = df[df['has_covid'] == 1].index[0]
covid_entities = df.loc[covid_note_idx, 'entities']

print("Entities extracted from COVID-19 note:")
for entity_type, entities in covid_entities.items():
    print(f"\n{entity_type}:")
    for entity in entities:
        print(f"  - {entity['text']}")

## 4. Entity Analysis

Now let's analyze the extracted entities to see if we can identify patterns that distinguish COVID-19 from other conditions.
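
One pattern worth checking before counting is negation. A sentence such as "No fever or shortness of breath reported" contains symptom terms the patient does not have, yet keyword matching still records them. The sketch below is a deliberately simple, hypothetical filter: `is_negated` and its short cue list are assumptions, and sentence boundaries are approximated by the last period before the mention. A mention counts as negated when a cue appears earlier in the same sentence.

```python
import re

# Hypothetical, deliberately short cue list.
NEGATION_CUES = re.compile(r"\b(?:no|not|denies|without)\b", re.IGNORECASE)


def is_negated(text: str, start: int) -> bool:
    """True if a negation cue precedes position `start` within the same sentence."""
    # Approximate the sentence start as just after the last period (or 0).
    sentence_start = text.rfind(".", 0, start) + 1
    return NEGATION_CUES.search(text, sentence_start, start) is not None


note = "Productive cough for 5 days. No fever or shortness of breath reported."
mentions = [
    {"text": m.group(0), "start": m.start()}
    for m in re.finditer(r"\b(?:cough|fever|shortness of breath)\b", note)
]
affirmed = [m["text"] for m in mentions if not is_negated(note, m["start"])]
print(affirmed)
```

Splitting on periods is crude (decimal values such as temperatures also contain one), so this is only a starting point for separating affirmed from negated mentions.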

In [None]:
# Count symptoms in COVID vs non-COVID notes
covid_symptoms = Counter()
non_covid_symptoms = Counter()

for _, row in df.iterrows():
    symptoms = [entity['text'].lower() for entity in row['entities']['SYMPTOM']]
    if row['has_covid'] == 1:
        covid_symptoms.update(symptoms)
    else:
        non_covid_symptoms.update(symptoms)

# Display symptom counts
print("COVID-19 symptoms:")
for symptom, count in covid_symptoms.most_common():
    print(f"  - {symptom}: {count}")

print("\nNon-COVID symptoms:")
for symptom, count in non_covid_symptoms.most_common():
    print(f"  - {symptom}: {count}")

In [None]:
# Visualize symptom distribution
plt.figure(figsize=(12, 6))

# Combine symptoms and get counts
all_symptoms = set(covid_symptoms.keys()) | set(non_covid_symptoms.keys())
symptom_names = list(all_symptoms)
covid_counts = [covid_symptoms.get(s, 0) for s in symptom_names]
non_covid_counts = [non_covid_symptoms.get(s, 0) for s in symptom_names]

# Sort by total count
symptom_order = sorted(range(len(symptom_names)), 
                       key=lambda i: covid_counts[i] + non_covid_counts[i], 
                       reverse=True)
symptom_names = [symptom_names[i] for i in symptom_order]
covid_counts = [covid_counts[i] for i in symptom_order]
non_covid_counts = [non_covid_counts[i] for i in symptom_order]

# Create plot
x = range(len(symptom_names))
plt.bar([i - 0.2 for i in x], covid_counts, width=0.4, label='COVID-19', color='red', alpha=0.7)
plt.bar([i + 0.2 for i in x], non_covid_counts, width=0.4, label='Non-COVID', color='blue', alpha=0.7)

plt.xticks(x, symptom_names, rotation=45, ha='right')
plt.xlabel('Symptom')
plt.ylabel('Count')
plt.title('Symptom Distribution: COVID-19 vs Non-COVID')
plt.legend()
plt.tight_layout()
plt.savefig('../output/symptom_distribution.png')
plt.show()

## 5. Prepare Structured Data for Transformer Model

The next step in our pipeline is to format the extracted entities for input to our transformer model.
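
Besides listing mentions, another way to hand entities to a text model is to mark them inline so the surrounding context is kept. The sketch below is a hypothetical alternative: `mark_entities` and the bracket tag format are assumptions, not the output of `format_entities_for_bert`, and the code assumes the entity spans do not overlap.

```python
from typing import Dict, List


def mark_entities(text: str, entities: Dict[str, List[dict]]) -> str:
    """Wrap each entity span in [TYPE] ... [/TYPE] tags.

    Assumes spans do not overlap; each entity dict needs "start" and "end".
    """
    spans = [
        (e["start"], e["end"], etype)
        for etype, items in entities.items()
        for e in items
    ]
    # Insert from right to left so earlier offsets stay valid.
    for start, end, etype in sorted(spans, reverse=True):
        text = f"{text[:start]}[{etype}] {text[start:end]} [/{etype}]{text[end:]}"
    return text


example_text = "Mild fever since yesterday."
example_entities = {
    "SYMPTOM": [{"text": "fever", "start": 5, "end": 10}],
    "SEVERITY": [{"text": "Mild", "start": 0, "end": 4}],
}
print(mark_entities(example_text, example_entities))
```

The severity pattern used in this notebook also captures the word after the indicator, so its spans can overlap symptom spans. Such overlaps would need to be resolved before applying this kind of inline marking.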

In [None]:
# Format entities for transformer model
def format_for_transformer(text, entities):
    """Format extracted entities for transformer model input."""
    # Flatten all entities into a single list
    all_entities = []
    for entity_type, entity_list in entities.items():
        for entity in entity_list:
            all_entities.append({
                "text": entity["text"],
                "type": entity_type,
                "start": entity["start"],
                "end": entity["end"]
            })
    
    # Sort entities by start position
    all_entities.sort(key=lambda x: x["start"])
    
    # Create a list of entity mentions with their types
    entity_mentions = [f"{e['text']} [{e['type']}]" for e in all_entities]
    
    # Create the formatted input for transformer
    formatted_input = {
        "original_text": text,
        "entity_count": len(all_entities),
        "entities": all_entities,
        "formatted_text": " ".join(entity_mentions)
    }
    
    return formatted_input

In [None]:
# Format all notes for transformer input
try:
    # Try to use our module function
    transformer_inputs = [format_entities_for_bert(note, entities) 
                         for note, entities in zip(notes, extracted_entities)]
except Exception as e:
    print(f"Using simplified function due to error: {e}")
    # Fall back to simplified function
    transformer_inputs = [format_for_transformer(note, entities) 
                         for note, entities in zip(notes, extracted_entities)]

# Add formatted inputs to dataframe
df['transformer_input'] = transformer_inputs

# Display a formatted example
print("Formatted input for transformer model:")
print(df.loc[covid_note_idx, 'transformer_input']['formatted_text'])

## 6. Summary and Next Steps

In this notebook, we've demonstrated the Named Entity Recognition (NER) stage of our COVID-19 detection pipeline. We've:

1. Generated synthetic clinical notes (COVID-19 and non-COVID)
2. Extracted medical entities (symptoms, time expressions, severity) using rule-based NER
3. Analyzed the distribution of extracted symptoms across COVID and non-COVID notes
4. Formatted the extracted entities for input to a transformer model

The next steps in our pipeline are to:

1. Train a transformer model to classify notes as COVID-19 or non-COVID based on the extracted entities
2. Evaluate the model's performance on held-out test data
3. Develop a visualization of how the model makes its predictions

These steps will be covered in the next notebook: "05_transformer_classification.ipynb".

In [None]:
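# Save the processed data for use in the next notebook.
# Note: CSV stores the dict-valued columns ('entities', 'transformer_input') as strings;
# use the JSON file written below to recover them as nested objects.
df.to_csv('../output/ner_processed_notes.csv', index=False)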

# Save the full results (including complex objects) as JSON
results = df.to_dict(orient='records')
with open('../output/ner_results.json', 'w') as f:
    json.dump(results, f, indent=2)

print("Data saved for next stage of pipeline.")