# Recursive Diagnosis Extraction from ICD-10-CM Tabular XML

This notebook extracts all diagnoses recursively from the ICD-10-CM tabular XML file, preserving the parent-child hierarchy.

In [None]:
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import List, Dict, Any
import json
import csv
import re
from collections import Counter, defaultdict
import pandas as pd

In [None]:
# Configuration
# Directory expected to hold the ICD-10-CM 2026 files; the path is relative,
# so it is resolved against the current working directory.
DATA = Path("../data/icd10cm-table-and-index-2026")
xml_path = DATA / "icd10cm-tabular-2026.xml"

# Fail fast with an explicit exception. A bare `assert` is removed when Python
# runs with -O, which would postpone the failure to the XML parse further down.
if not xml_path.exists():
    raise FileNotFoundError(f"File not found: {xml_path}")

In [None]:
def clean_text(text: str) -> str:
    """Collapse each run of whitespace to a single space and trim both ends.

    A falsy value such as ``None`` (what ``findtext`` returns when nothing
    matches) is treated as the empty string.
    """
    return re.sub(r"\s+", " ", text or "").strip()


def extract_notes(diag_elem: ET.Element, note_type: str) -> List[str]:
    """Extract the notes of one type that belong to ``diag_elem`` itself.

    Only ``<note_type>`` containers that are direct children of ``diag_elem``
    are read. A descendant search (``.//``) would also reach the containers of
    every nested child ``<diag>`` and attribute their notes to the parent.

    Args:
        diag_elem: Element whose own notes are wanted.
        note_type: Tag name of the note container, e.g. ``"excludes1"``.

    Returns:
        Cleaned note texts in document order. ``<note>`` elements without
        text are skipped.
    """
    return [
        clean_text(note.text)
        for container in diag_elem.findall(f"./{note_type}")
        # Inside a container the descendant search is kept on purpose, so
        # <note> elements are found wherever they sit within the container.
        for note in container.findall(".//note")
        if note.text
    ]


def extract_inclusion_terms(diag_elem: ET.Element) -> List[str]:
    """Extract the inclusion terms that belong to ``diag_elem`` itself.

    Inclusion terms are stored like any other note container
    (``<inclusionTerm><note>...</note></inclusionTerm>``), so this delegates
    to :func:`extract_notes` and shares its direct-children-only behaviour.
    """
    return extract_notes(diag_elem, "inclusionTerm")

In [None]:
def extract_diagnosis_recursive(diag_elem: ET.Element, parent_code: str = None,
                                level: int = 0) -> List[Dict[str, Any]]:
    """
    Recursively flatten a <diag> element and its nested <diag> children.

    The element's own record comes first, followed by the records of each
    direct child subtree in document order (pre-order traversal).

    Args:
        diag_elem: The <diag> XML element
        parent_code: The code of the parent diagnosis (None for a top-level
            diagnosis)
        level: Current depth level in the hierarchy

    Returns:
        List of diagnosis dictionaries with all metadata. An element without
        a code of its own contributes no record; its children are still
        processed with the same parent_code and level.
    """
    # Read only this element's own <name>/<desc>. A descendant search
    # (".//name") would silently borrow a nested child's value when the
    # element has none, producing a duplicate code that is its own parent.
    code = clean_text(diag_elem.findtext("./name"))
    desc = clean_text(diag_elem.findtext("./desc"))

    # Direct child <diag> elements only; deeper levels are handled by recursion.
    child_diags = diag_elem.findall("./diag")

    if not code:
        # No usable code: emit nothing for this element, but keep its
        # children attached to the nearest coded ancestor.
        diagnoses = []
        for child_diag in child_diags:
            diagnoses.extend(extract_diagnosis_recursive(child_diag, parent_code, level))
        return diagnoses

    has_children = bool(child_diags)

    # Key order is kept stable because it determines the JSON field order and
    # the DataFrame column order downstream.
    diagnoses = [{
        "code": code,
        "description": desc,
        "parent_code": parent_code,
        "level": level,
        "has_children": has_children,
        "inclusion_terms": extract_inclusion_terms(diag_elem),
        "includes": extract_notes(diag_elem, "includes"),
        "excludes1": extract_notes(diag_elem, "excludes1"),
        "excludes2": extract_notes(diag_elem, "excludes2"),
        "code_first": extract_notes(diag_elem, "codeFirst"),
        "use_additional_code": extract_notes(diag_elem, "useAdditionalCode"),
        "code_also": extract_notes(diag_elem, "codeAlso"),
        "num_children": len(child_diags),
        # Heuristic only: leaf codes are typically billable, parents are not.
        # NOTE(review): leaf codes that require a 7th character may not be
        # billable as listed - confirm before relying on this flag.
        "is_billable": not has_children,
    }]

    # Recursively process children, one level deeper, with this code as parent.
    for child_diag in child_diags:
        diagnoses.extend(extract_diagnosis_recursive(child_diag, code, level + 1))

    return diagnoses

In [None]:
def extract_all_diagnoses(xml_path: Path) -> List[Dict[str, Any]]:
    """
    Extract all diagnoses from the ICD-10-CM tabular XML file.

    Every <chapter> is walked, then every <section> inside it, then every
    top-level <diag> of the section. Each diagnosis record is tagged with the
    chapter and section it was found in.

    Args:
        xml_path: Path to the icd10cm-tabular XML file

    Returns:
        Flat list of all diagnosis records in document order, each carrying
        the extra keys "chapter", "chapter_desc", "section_id" and
        "section_desc".
    """
    root = ET.parse(xml_path).getroot()

    all_diagnoses = []

    for chapter in root.findall(".//chapter"):
        # Direct-child lookups: a descendant search (".//name") would fall
        # back to the first diagnosis's name/desc if the chapter had none.
        chapter_name = clean_text(chapter.findtext("./name"))
        chapter_desc = clean_text(chapter.findtext("./desc"))

        for section in chapter.findall(".//section"):
            section_id = section.get("id", "")
            # Same reasoning: a section without its own <desc> must yield ""
            # rather than the description of its first diagnosis.
            section_desc = clean_text(section.findtext("./desc"))

            # Only the section's top-level diagnoses; nested ones are reached
            # through the recursive extraction.
            for diag in section.findall("./diag"):
                for diagnosis in extract_diagnosis_recursive(diag, parent_code=None, level=0):
                    # Add chapter and section context to each diagnosis
                    diagnosis["chapter"] = chapter_name
                    diagnosis["chapter_desc"] = chapter_desc
                    diagnosis["section_id"] = section_id
                    diagnosis["section_desc"] = section_desc
                    all_diagnoses.append(diagnosis)

    return all_diagnoses

In [None]:
# Extract all diagnoses
# Parses the tabular XML once; the resulting flat list of record dicts is
# reused by the cells below.
print(f"Extracting diagnoses from {xml_path}...")
diagnoses = extract_all_diagnoses(xml_path)
print(f"Extracted {len(diagnoses)} diagnoses")

In [None]:
# Display sample of the hierarchy
# Prints the first 50 records, indented two spaces per hierarchy level.
print("\nSample diagnosis tree:\n")
for diag in diagnoses[:50]:
    indent = "  " * diag["level"]
    # Check mark = flagged billable (leaf), circle = parent code.
    billable = "✓" if diag["is_billable"] else "○"
    children_info = f" ({diag['num_children']} children)" if diag["has_children"] else ""
    print(f"{indent}{billable} {diag['code']}: {diag['description']}{children_info}")

In [None]:
# Statistics
# Summary counts over the flat list of extracted diagnosis records.
print("\n" + "="*60)
print("STATISTICS")
print("="*60)
print(f"Total diagnoses: {len(diagnoses)}")
print(f"Billable codes: {sum(1 for d in diagnoses if d['is_billable'])}")
print(f"Parent codes: {sum(1 for d in diagnoses if d['has_children'])}")
# default=0 keeps this cell from raising ValueError when nothing was extracted.
print(f"Max depth: {max((d['level'] for d in diagnoses), default=0)}")

# Level distribution
level_counts = Counter(d['level'] for d in diagnoses)
print("\nDiagnoses by level:")
for level in sorted(level_counts):
    print(f"  Level {level}: {level_counts[level]}")

In [None]:
# Top parent codes by number of children
print("\nTop parent codes by number of children:\n")
# Stable descending sort on the direct-child count, so records with the same
# count keep their original relative order.
parents = sorted(
    (d for d in diagnoses if d['has_children']),
    key=lambda rec: rec['num_children'],
    reverse=True,
)
# Show the 15 parents with the most direct children.
for parent in parents[:15]:
    print(f"{parent['code']:8} {parent['description']:60} ({parent['num_children']} children)")

In [None]:
# Convert to DataFrame for easier analysis
# One row per diagnosis record, one column per dictionary key.
df = pd.DataFrame(diagnoses)
print(f"\nDataFrame shape: {df.shape}")
# Preview of the first 20 rows.
df.head(20)

In [None]:
# Example: Find all children of a specific code
def get_children(parent_code: str, diagnoses: List[Dict]) -> List[Dict]:
    """Return the records whose ``parent_code`` equals *parent_code*.

    Only direct children are returned, in the order they appear in
    ``diagnoses``. The result is empty when nothing matches.
    """
    direct_children = []
    for record in diagnoses:
        if record['parent_code'] == parent_code:
            direct_children.append(record)
    return direct_children

def get_all_descendants(parent_code: str, diagnoses: List[Dict]) -> List[Dict]:
    """Get all descendants (children, grandchildren, etc.) of a diagnosis code.

    The records are grouped by ``parent_code`` once up front, so the list is
    scanned a single time rather than once per visited node.

    Args:
        parent_code: Code whose descendants are wanted.
        diagnoses: Flat list of records, each with "code" and "parent_code".

    Returns:
        All direct children of ``parent_code`` first (in list order), followed
        by the descendants of each of those children in turn, assembled the
        same way. Empty when the code has no children or is unknown.
    """
    # parent code -> its direct children, preserving the order of `diagnoses`.
    children_by_parent = {}
    for record in diagnoses:
        children_by_parent.setdefault(record['parent_code'], []).append(record)

    def collect(code):
        """Return the direct children of `code`, then each child's descendants."""
        direct = children_by_parent.get(code, [])
        found = list(direct)
        for child in direct:
            found.extend(collect(child['code']))
        return found

    return collect(parent_code)

# Example usage
example_code = "A01.0"
# NOTE: despite its name, `children` holds every descendant of example_code,
# not only the direct children.
children = get_all_descendants(example_code, diagnoses)
print(f"\nAll descendants of {example_code}:")
for child in children:
    # Indentation is the record's absolute hierarchy level minus one, not its
    # depth relative to example_code.
    indent = "  " * (child['level'] - 1)
    print(f"{indent}{child['code']}: {child['description']}")

In [None]:
# Save to JSON
# Writes the full record list, including the list-valued note fields.
output_json = DATA / "diagnoses_recursive.json"
# ensure_ascii=False emits non-ASCII characters verbatim, so the file is
# opened with an explicit UTF-8 encoding.
with output_json.open("w", encoding="utf-8") as f:
    json.dump(diagnoses, f, ensure_ascii=False, indent=2)
print(f"Saved all diagnoses to {output_json}")

In [None]:
# Save to CSV
# Only the scalar columns are exported; the list-valued note columns are
# left out of this file.
csv_path = DATA / "diagnoses_recursive.csv"
export_columns = [
    'code',
    'description',
    'parent_code',
    'level',
    'is_billable',
    'num_children',
    'chapter',
    'chapter_desc',
    'section_id',
    'section_desc',
]
df_export = df[export_columns]
df_export.to_csv(csv_path, index=False)
print(f"Saved CSV to {csv_path}")

In [None]:
# Build a parent-child relationship graph
print("\nBuilding parent-child relationship graph...")
# One edge per record that has a (truthy) parent code; records without a
# parent contribute no edge.
parent_child_edges = [
    {
        'parent': record['parent_code'],
        'child': record['code'],
        'child_desc': record['description'],
        'level': record['level'],
    }
    for record in diagnoses
    if record['parent_code']
]

edges_df = pd.DataFrame(parent_child_edges)
print(f"Total parent-child relationships: {len(edges_df)}")
# Preview of the first 20 edges.
edges_df.head(20)

In [None]:
# Save parent-child edges
# Written next to the other outputs in DATA, without the DataFrame index column.
edges_csv = DATA / "diagnosis_hierarchy_edges.csv"
edges_df.to_csv(edges_csv, index=False)
print(f"Saved hierarchy edges to {edges_csv}")