# 🧭 Overview
This notebook processes XML-format lung function reports to identify airway obstruction. It applies rule-based GOLD staging and prepares scalar and morphological (waveform) features for machine learning.

**Goals:**
- Load and parse XML-format lung function test reports
- Extract relevant spirometry parameters
- Identify patterns consistent with airway obstruction
- Prepare data for ECG-linked ML classification

# 📁 Setup and Imports

### 📦 What does `from lxml import etree` do?

This line imports the `etree` module from the `lxml` library, which is a powerful and fast XML parser.

✅ **Why use `lxml.etree` instead of Python’s built-in `xml.etree.ElementTree`?**

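- **Faster** on large or complex XML files, because it wraps the C libraries libxml2 and libxslt.
- Full **XPath 1.0** support (ElementTree supports only a limited subset), allowing flexible and efficient queries.
- Preserves **comments** and processing instructions, offers richer **namespace** handling, and can attempt to **recover from malformed XML** via `etree.XMLParser(recover=True)`.
- Widely used in scientific and biomedical data processing, where XML schemas are often complex.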

💡 In this project, we use `etree` to:
- Parse XML-based lung function test reports
- Navigate the tree-like structure to extract values like FEV1, FVC, etc.
- Convert structured XML data into tabular format for ML processing

In [None]:
from pathlib import Path
from lxml import etree
import pandas as pd
import json
import matplotlib.pyplot as plt
from datetime import datetime

# 📦 Load the Data

In [None]:
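# Set up base folders
base_path = Path.cwd()
input_folder = base_path / "input"
output_folder = base_path / "output"
output_folder.mkdir(parents=True, exist_ok=True)  # ensure the output folder exists before writing

# Sort for a reproducible processing order
xml_files = sorted(input_folder.glob("*.xml"))
print(f"Found {len(xml_files)} .xml file(s) for extraction:")
for f in xml_files:
    print(f" - {f.name} ({f})")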

# 📄 XML Parsing
### 📄 Understanding the XML File Structure

We assume you have downloaded the XML files directly from the lung function lab computer.  
These originate from **COSMED** hardware systems, using the **OMNIA** software suite for pulmonary diagnostics.  
The XML schema is proprietary to COSMED's OMNIA software. It is not a public or standardized format and may change in future software updates.  
Our current files follow a consistent structure designed to encode both subject demographics and full pulmonary test results.

Each file contains three main sections:

- **Subject**: Includes the patient's ID, name, date of birth (stored in the `ExtendedInfo` attribute as `YYYYMMDD000000`), gender, and ethnicity. These fields are essential for linking test results to individuals. Because they identify the patient, the data are kept within the hospital's ecosystem and handled in accordance with hospital policies and the ethics approval for this study.
  
- **Visit**: Encapsulates clinical metadata about the testing encounter such as visit date, smoking history (e.g., pack-years, cigarettes/day), comorbidities (e.g., diabetes), anthropometric measurements (height/weight), and the referring physician.

- **Test**: Nested inside each Visit, this section may contain multiple test records (e.g., Spirometry, MVV). Each Test node includes:
  - `TestType`: e.g., "Spirometry"
  - `AdditionalData`: comprising:
    - **Parameter** nodes for scalar outputs like FEV1, FVC, PEF, etc.
    - **Graph** nodes (e.g., `Graph1`) with digitized waveform data like flow-volume loops in `<Point X="..." Y="..."/>` format, along with metadata like sampling interval and axis labels.
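The `YYYYMMDD000000` date layout maps directly onto `datetime.strptime` format codes. The sketch below uses invented date strings and assumes that `CreatedOn` follows the same layout as the date of birth (the `[:8]` slice applied to both in `parse_xml_to_dict()` suggests so, but this is an assumption).

```python
from datetime import datetime

# HYPOTHETICAL ExtendedInfo values in the YYYYMMDD000000 layout.
dob_raw = "19650412000000"
visit_raw = "20240115000000"  # assumes CreatedOn uses the same layout

# Option 1: parse the full 14-character stamp (date followed by HHMMSS).
dob = datetime.strptime(dob_raw, "%Y%m%d%H%M%S").date()
# Option 2: mirror the notebook's [:8] slice and parse only the date part.
visit_date = datetime.strptime(visit_raw[:8], "%Y%m%d").date()

# Age in whole years at the visit: subtract one if the birthday has not yet occurred.
age = visit_date.year - dob.year - ((visit_date.month, visit_date.day) < (dob.month, dob.day))
print(dob, visit_date, age)
```

Converting to `date` objects early avoids treating dates as plain strings in downstream tables.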

Importantly, **multiple test attempts may be present within the same visit**, for example repeated spirometry blows. The JSON output keeps every attempt, while the flattened CSV keeps only the **best effort**, which COSMED's OMNIA software consistently presents as the **first** (or **only preserved**) entry in the XML.

This rich XML structure allows us to extract both discrete measurements and waveform data for downstream analysis such as obstruction classification and waveform morphology learning.
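The sketch below shows how the three sections nest and how the first of several attempts is picked. The skeleton is reconstructed from the tags this notebook reads; the root tag name (`Export`) and every value are hypothetical, and real OMNIA files contain many more elements and attributes.

```python
from lxml import etree

# HYPOTHETICAL skeleton: root tag name and all values are invented for illustration.
SKELETON = b"""<Export>
  <Subject>
    <ID>P001</ID>
    <DayOfBirth ExtendedInfo="19650412000000"/>
    <Visit>
      <RecordID>V01</RecordID>
      <Test>
        <TestType ExtendedInfo="Spirometry">1</TestType>
        <AdditionalData>
          <Parameter Name="FEV1" Value="2.45" UM="L"/>
          <Graph1 X="V (L)" Y="F (L/s)" Count="3" SamplingInterval="0.01">
            <Point X="0.0" Y="0.0"/>
            <Point X="0.5" Y="6.1"/>
            <Point X="3.6" Y="0.0"/>
          </Graph1>
        </AdditionalData>
      </Test>
      <Test>
        <TestType ExtendedInfo="Spirometry">2</TestType>
        <AdditionalData>
          <Parameter Name="FEV1" Value="2.30" UM="L"/>
        </AdditionalData>
      </Test>
    </Visit>
  </Subject>
</Export>"""

root = etree.fromstring(SKELETON)
visit = root.find("Subject/Visit")
tests = visit.findall("Test")

# Several attempts in one visit: keep the first, treated here as the best effort.
best = tests[0]
fev1 = best.find("AdditionalData/Parameter[@Name='FEV1']").get("Value")
points = [(float(p.get("X")), float(p.get("Y"))) for p in best.iter("Point")]

print(len(tests), fev1, len(points))
```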

### ⚙️ The XML Extraction and Processing Pipeline

This section performs the full end-to-end extraction of structured and graphical data from each XML file. The goal is to convert each patient test file into both a JSON format (for flexible structured access) and a flattened CSV row (for easy analysis and model training).

Here's what happens step-by-step:

1. **Graph Data Extraction**  
   `extract_graph()` processes flow-volume loop data inside each `Graph` element. It reads:
   - X and Y axis labels (e.g., `"V (L)"`, `"F (L/s)"`)
   - Sampling interval and point count
   - A list of 2D coordinates from `<Point X="..." Y="..."/>`, stored as a list of [x, y] pairs.

2. **Subject and Visit Metadata**  
   `parse_xml_to_dict()` extracts:
   - Patient information like ID, name, DOB, gender, ethnicity
   - Visit-level metadata: smoking status, diabetes, height/weight, physician, technician, etc.
   - An empty `Tests` list, to be populated in the next step.

3. **Test Parsing and Selection**  
   `add_tests_from_xml()` finds all tests in the visit and:
   - Captures the test type and ID
   - Appends all scalar parameters (like FEV1, FVC) into a list
   - Extracts waveform data from any `Graph` elements using `extract_graph()`

   ⚠️ Note: If multiple tests of the same type are present (e.g., multiple spirometry blows), **all are stored in the JSON**, but **only the first (best) attempt is kept in the CSV**, as described below.

4. **Flattening for CSV Output**  
   `flatten_dict_to_row()` converts the structured dictionary into a flat row for DataFrame use:
   - All core metadata and scalar test results are unfolded into named columns
   - The **best flow-volume loop** is the first `Graph` element (typically `Graph1`) with `X = "V (L)"` and `Y = "F (L/s)"`; its points are saved under a `FlowVolumeLoop` column

5. **Main Processing Loop**  
   For each `.xml` file:
   - It parses the XML using `lxml.etree`
   - Builds the structured data dictionary
   - Saves a full `.json` file with all nested metadata and waveform points
   - Flattens the data into a single-row `.csv` file for analysis
   - Appends each JSON path to `saved_json_paths`, which is used for the summary and the review section

This modular design allows flexible reuse of the structured `.json` format while supporting tabular workflows with the `.csv` outputs.
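The column-naming scheme used in step 4 is easiest to see on a toy input. This is a minimal sketch that mirrors `flatten_dict_to_row()` under the "first attempt wins" policy described above; the nested dictionaries are hypothetical and only imitate the shape built by `add_tests_from_xml()`.

```python
# HYPOTHETICAL test records in the shape produced by add_tests_from_xml().
tests = [
    {"TestType": "Spirometry",
     "Parameters": [{"Name": "FEV1", "Value": "2.45", "UM": "L"}],
     "Graphs": {"Graph1": {"X": "V (L)", "Y": "F (L/s)", "Points": [[0.0, 0.0], [0.5, 6.1]]}}},
    {"TestType": "Spirometry",
     "Parameters": [{"Name": "FEV1", "Value": "2.30", "UM": "L"}],
     "Graphs": {"Graph1": {"X": "V (L)", "Y": "F (L/s)", "Points": [[0.0, 0.0], [0.4, 5.2]]}}},
]

row = {}
for test in tests:
    prefix = (test.get("TestType") or "Unknown").replace(" ", "_")
    for param in test["Parameters"]:
        name = param.get("Name", "Unnamed").replace(" ", "_")
        for key, value in param.items():
            if key != "Name":
                # Column = <TestType>_<Parameter>_<attribute>; setdefault keeps the first attempt.
                row.setdefault(f"{prefix}_{name}_{key}".replace(" ", "_"), value)
    for graph in test["Graphs"].values():
        if "FlowVolumeLoop" not in row and graph.get("X") == "V (L)" and graph.get("Y") == "F (L/s)":
            row["FlowVolumeLoop"] = graph["Points"]

print(row)
# {'Spirometry_FEV1_Value': '2.45', 'Spirometry_FEV1_UM': 'L', 'FlowVolumeLoop': [[0.0, 0.0], [0.5, 6.1]]}
```

Without the "first wins" guards, the second attempt would silently overwrite the first in a plain `row[col] = value` assignment.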

In [None]:
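def extract_graph(graph_elem):
    """Return axis labels, sampling metadata and [x, y] points from a Graph element."""
    return {
        "X": graph_elem.attrib.get("X"),
        "Y": graph_elem.attrib.get("Y"),
        "Count": int(graph_elem.attrib.get("Count", "0")),
        "SamplingInterval": float(graph_elem.attrib.get("SamplingInterval", "0")),
        # Skip any Point that lacks either coordinate
        "Points": [
            [float(p.attrib["X"]), float(p.attrib["Y"])]
            for p in graph_elem.iter("Point")
            if "X" in p.attrib and "Y" in p.attrib
        ]
    }


def parse_xml_to_dict(root):
    """Build the Subject/Visit dictionary; Tests are added later by add_tests_from_xml()."""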
    subject = root.find("Subject")
    visit = subject.find("Visit")
    return {
        "Subject": {
            "SubjectID": subject.findtext("ID"),
            "FirstName": subject.findtext("FirstName"),
            "LastName": subject.findtext("LastName"),
            "DOB": subject.find("DayOfBirth").attrib.get("ExtendedInfo")[:8],
            "Gender": subject.find("GenderID").attrib.get("ExtendedInfo"),
            "Ethnicity": subject.find("ethnicID").attrib.get("ExtendedInfo"),
        },
        "Visit": {
            "RecordID": visit.findtext("RecordID"),
            "CreatedOn": visit.find("CreatedOn").attrib.get("ExtendedInfo")[:8],
            "Smoker": visit.findtext("Smoker"),
            "CigarettesPerDay": visit.findtext("CigDie"),
            "SmokeYears": visit.findtext("SmokeYears"),
            "SmokeWhat": visit.findtext("SmokeWhat"),
            "NonSmokeYears": visit.findtext("NonSmokeYears"),
            "Height_cm": visit.findtext("Height"),
            "Weight_kg": visit.findtext("Weight"),
            "Technician": visit.findtext("Technician"),
            "Physician": visit.findtext("Physician"),
            "ReferringPhysician": visit.findtext("ReferringPhysician"),
            "VisitReason": visit.findtext("VisitReason"),
            "Diabetes": visit.findtext("Diabetes"),
            "Tests": []
        }
    }

def add_tests_from_xml(data_dict, visit_elem):
    for test in visit_elem.findall("Test"):
        test_type_elem = test.find("TestType")
        test_type = test_type_elem.attrib.get("ExtendedInfo") if test_type_elem is not None else None
        test_id = test_type_elem.text if test_type_elem is not None else None

        test_record = {
            "TestType": test_type,
            "TestID": test_id,
            "Parameters": [],
            "Graphs": {}
        }

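        additional_data = test.find("AdditionalData")
        if additional_data is not None:
            # Scalar results (FEV1, FVC, ...) keep all of their XML attributes
            for param in additional_data.iter("Parameter"):
                test_record["Parameters"].append(dict(param.attrib))
            # Waveforms; comments have a non-string .tag in lxml, so skip them
            for graph in additional_data:
                if isinstance(graph.tag, str) and graph.tag.startswith("Graph"):
                    test_record["Graphs"][graph.tag] = extract_graph(graph)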

        data_dict["Visit"]["Tests"].append(test_record)

def flatten_dict_to_row(data_dict):
    subject = data_dict["Subject"]
    visit = data_dict["Visit"]
    row = {
        "SubjectID": subject.get("SubjectID"),
        "FirstName": subject.get("FirstName"),
        "LastName": subject.get("LastName"),
        "DOB": subject.get("DOB"),
        "Gender": subject.get("Gender"),
        "Ethnicity": subject.get("Ethnicity"),
        "VisitRecordID": visit.get("RecordID"),
        "VisitDate": visit.get("CreatedOn"),
        "Smoker": visit.get("Smoker"),
        "CigarettesPerDay": visit.get("CigarettesPerDay"),
        "SmokeYears": visit.get("SmokeYears"),
        "SmokeWhat": visit.get("SmokeWhat"),
        "NonSmokeYears": visit.get("NonSmokeYears"),
        "Height_cm": visit.get("Height_cm"),
        "Weight_kg": visit.get("Weight_kg"),
        "Technician": visit.get("Technician"),
        "Physician": visit.get("Physician"),
        "ReferringPhysician": visit.get("ReferringPhysician"),
        "VisitReason": visit.get("VisitReason"),
        "Diabetes": visit.get("Diabetes"),
    }

    for test in visit.get("Tests", []):
        # TestType may be stored as None, so fall back explicitly
        test_prefix = (test.get("TestType") or "Unknown").replace(" ", "_")
        for param in test.get("Parameters", []):
            name = param.get("Name", "Unnamed").replace(" ", "_")
            for k, v in param.items():
                if k != "Name":
                    col = f"{test_prefix}_{name}_{k}".replace(" ", "_")
                    # Keep the first (best) attempt if the same column appears again
                    row.setdefault(col, v)
        for gdata in test.get("Graphs", {}).values():
            if (
                "FlowVolumeLoop" not in row  # keep only the first matching loop
                and gdata.get("X") == "V (L)"
                and gdata.get("Y") == "F (L/s)"
                and gdata.get("Points")
            ):
                row["FlowVolumeLoop"] = gdata["Points"]
                break
    return row

# ---------- MAIN LOOP ----------
saved_json_paths = []
for xml_file in xml_files:
    print(f"📂 Processing: {xml_file.name}")
    
    # Parse and build XML tree
    tree = etree.parse(str(xml_file))
    root = tree.getroot()
    visit_elem = root.find(".//Visit")

    # Build structured dict from Subject + Visit
    data_dict = parse_xml_to_dict(root)

    # Add Tests + Graphs to the same dict (in-place)
    add_tests_from_xml(data_dict, visit_elem)

    # ✅ Save JSON per patient
    json_path = output_folder / f"{xml_file.stem.replace(' ', '_')}_extracted.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(data_dict, f, indent=2)
    saved_json_paths.append(json_path)
    print(f" - ✅ JSON saved to: {json_path}")

    # ✅ Flatten and save CSV per patient
    row = flatten_dict_to_row(data_dict)
    df = pd.DataFrame([row])
    csv_path = output_folder / f"{xml_file.stem.replace(' ', '_')}_extracted.csv"
    df.to_csv(csv_path, index=False)
    print(f" - ✅ CSV saved to: {csv_path}")

print("\n📦 Summary: JSON files saved (paths stored in saved_json_paths):")
for path in saved_json_paths:
    print(f"  • {path.name}")

# 🔍 Review, Visualisation and Classification

This section reviews the extracted JSON files for each patient test, displaying subject and test details, plotting the flow-volume curve (if available), and classifying the presence and severity of airway obstruction.

The pipeline performs the following steps:

1. **Load Extracted JSON**  
   For each previously saved `.json` file, the script loads structured patient and visit data into memory.

2. **Display Visit and Subject Metadata**  
   Using predefined field maps (`SUBJECT_FIELDS_MAP`, `VISIT_FIELDS_MAP`), the script prints demographic details (e.g., gender, smoking history, height/weight) and visit details (e.g., technician, physician, reason for test).

3. **Group Parameters by Test Type**  
   All parameters from each test are grouped into clinically meaningful categories:
   - **Spirometry**: e.g., FEV1, FVC, FEV1/FVC, PEF
   - **Lung Volumes**: e.g., TLC, RV, VC
   - **Diffusion**: e.g., DLCO, VA, KCO
   - **Other**: any unclassified metrics

4. **Display Parameters in Table Format**  
   Parameters in each group are printed in a neat, aligned table showing:
   - Raw value, units, predicted value
   - Percent of predicted (%Pred) and Z-score
   - LLN and ULN (Lower/Upper Limits of Normal)

5. **Extract and Plot Flow-Volume Loop**  
   If a digitised flow-volume curve is available (i.e., X-axis is `"V (L)"` and Y-axis is `"F (L/s)"`), it is plotted using matplotlib and saved as a `.png` file.

6. **Classify Obstruction Using GOLD Criteria**  
   - If the **FEV1/FVC ratio is ≥70%**, the test is classified as **"No Obstruction"**.
   - If **<70%**, the GOLD stage is determined using % predicted FEV1:
     - **GOLD Stage 1**: ≥80%
     - **GOLD Stage 2**: 50–79%
     - **GOLD Stage 3**: 30–49%
     - **GOLD Stage 4**: <30%
   - A brief note with the GOLD classification thresholds is printed alongside the result.
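The staging rule in step 6 is small enough to write as a pure function, which makes the boundary behaviour easy to check. This is a sketch of the thresholds listed above, not the notebook's own `classify_obstruction()`; the name `gold_classification` and the "Not classifiable" label are hypothetical.

```python
def gold_classification(fev1_fvc_pct, fev1_pct_pred):
    """Classify one test with the fixed-ratio GOLD rule listed above (hypothetical helper).

    fev1_fvc_pct:  FEV1/FVC ratio as a percentage (e.g. 65.0).
    fev1_pct_pred: FEV1 as % of predicted (ppFEV1).
    """
    if fev1_fvc_pct is None or fev1_pct_pred is None:
        return "Not classifiable"
    if fev1_fvc_pct >= 70:
        return "No Obstruction"
    # Check thresholds from mildest to most severe; the first match wins.
    for threshold, stage in ((80, 1), (50, 2), (30, 3)):
        if fev1_pct_pred >= threshold:
            return f"GOLD Stage {stage}"
    return "GOLD Stage 4"


# Boundary cases: a ratio of exactly 70% is not obstructed; ppFEV1 of exactly 50% is Stage 2.
for ratio, pp in [(70.0, 40.0), (65.0, 85.0), (65.0, 50.0), (60.0, 29.9), (None, 80.0)]:
    print(ratio, pp, "->", gold_classification(ratio, pp))
```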

This section allows manual inspection and verification of the extracted data, visual confirmation of waveform quality, and automatic rule-based classification of obstructive lung disease severity.
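The grouping in step 3 uses an `if/elif` chain over three key sets. As an alternative sketch, the same grouping can be driven by a single name-to-category lookup table. The key sets are copied from the code cell below so that this block runs on its own; `CATEGORY_OF` and `group_parameters` are hypothetical names.

```python
# Key sets copied from the review cell so this block is self-contained.
SPIROMETRY_KEYS = {"FVC", "FEV1", "FEV1/FVC%", "FEF25-75%", "PEF", "FEV1/VCmax%",
                   "MEF25%", "MEF50%", "MEF75%"}
LUNG_VOLUME_KEYS = {"TLC(Pleth)", "RV(Pleth)", "FRC(Pleth)", "RV/TLC(Pleth)", "VC", "ERV", "IC"}
DIFFUSION_KEYS = {"DLCO unadj", "DLCO corr", "VA", "KCO", "TLC(DLCO)", "DLCO PB"}

# Invert the sets into one name -> category table (hypothetical helper).
CATEGORY_OF = {
    name: category
    for category, keys in (("Spirometry", SPIROMETRY_KEYS),
                           ("Lung Volumes", LUNG_VOLUME_KEYS),
                           ("Diffusion", DIFFUSION_KEYS))
    for name in keys
}


def group_parameters(params):
    """Group parameter dicts by category; names not in any set fall into "Other"."""
    grouped = {"Spirometry": [], "Lung Volumes": [], "Diffusion": [], "Other": []}
    for p in params:
        grouped[CATEGORY_OF.get(p.get("Name", ""), "Other")].append(p)
    return grouped


example = [{"Name": "FEV1"}, {"Name": "VC"}, {"Name": "KCO"}, {"Name": "SVC"}]
print({k: [p["Name"] for p in v] for k, v in group_parameters(example).items()})
# {'Spirometry': ['FEV1'], 'Lung Volumes': ['VC'], 'Diffusion': ['KCO'], 'Other': ['SVC']}
```

A lookup table keeps the category definitions in one place, so adding a new parameter means editing a set rather than the loop.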

In [None]:
# Field mappings for display
SUBJECT_FIELDS_MAP = {
    "SubjectID": ("Subject", "SubjectID"),
    "FirstName": ("Subject", "FirstName"),
    "LastName": ("Subject", "LastName"),
    "DOB": ("Subject", "DOB"),
    "Gender": ("Subject", "Gender"),
    "Ethnicity": ("Subject", "Ethnicity"),
    "Smoker": ("Visit", "Smoker"),
    "CigarettesPerDay": ("Visit", "CigarettesPerDay"),
    "SmokeYears": ("Visit", "SmokeYears"),
    "SmokeWhat": ("Visit", "SmokeWhat"),
    "NonSmokeYears": ("Visit", "NonSmokeYears"),
    "Height_cm": ("Visit", "Height_cm"),
    "Weight_kg": ("Visit", "Weight_kg"),
    "Diabetes": ("Visit", "Diabetes"),
}

VISIT_FIELDS_MAP = {
    "RecordID": ("Visit", "RecordID"),
    "CreatedOn": ("Visit", "CreatedOn"),
    "HRMax": ("Visit", "HRMax"),
    "Technician": ("Visit", "Technician"),
    "Physician": ("Visit", "Physician"),
    "ReferringPhysician": ("Visit", "ReferringPhysician"),
    "VisitReason": ("Visit", "VisitReason"),
}

# Physiological test categories
SPIROMETRY_KEYS = {
    "FVC", "FEV1", "FEV1/FVC%", "FEF25-75%", "PEF", "FEV1/VCmax%",
    "MEF25%", "MEF50%", "MEF75%"
}
LUNG_VOLUME_KEYS = {
    "TLC(Pleth)", "RV(Pleth)", "FRC(Pleth)", "RV/TLC(Pleth)",
    "VC", "ERV", "IC"
}
DIFFUSION_KEYS = {
    "DLCO unadj", "DLCO corr", "VA", "KCO", "TLC(DLCO)", "DLCO PB"
}

# Column formatting helper
def format_param_row(p):
    return "{:<18} {:<7} {:<12} {:<7} {:<8} {:<9} {:<7} {:<7}".format(
        p.get("Name", ""),
        p.get("Value", ""),
        p.get("UM", ""),
        p.get("Predicted", ""),
        p.get("PercPred", ""),
        p.get("ZScore", ""),
        p.get("LLN", ""),
        p.get("ULN", "")
    )

PARAM_HEADER = "{:<18} {:<7} {:<12} {:<7} {:<8} {:<9} {:<7} {:<7}".format(
    "Parameter", "Value", "UM", "Pred", "%Pred", "ZScore", "LLN", "ULN"
)
    
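def classify_obstruction(params):
    """Return the GOLD stage from % predicted FEV1 (call only when FEV1/FVC < 70)."""
    # Use the first FEV1 entry (best effort); tolerate a missing entry or empty PercPred
    fev1 = next((p for p in params if p.get("Name") == "FEV1"), None)
    if fev1 is None or not fev1.get("PercPred"):
        return "Not classifiable (FEV1 % predicted missing)"
    fev1_perc_pred = float(fev1["PercPred"])
    if fev1_perc_pred >= 80:
        return "GOLD Stage 1"
    elif fev1_perc_pred >= 50:
        return "GOLD Stage 2"
    elif fev1_perc_pred >= 30:
        return "GOLD Stage 3"
    else:
        return "GOLD Stage 4"
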
GOLD_NOTE = """
Note:
The Global Initiative for Chronic Obstructive Lung Disease (GOLD) characterises obstruction as follows:
No obstruction if the FEV1/FVC ratio is ⩾70%.
GOLD staging applies if FEV1/FVC is <70%, based on ppFEV1 (FEV1 % predicted):
- GOLD Stage 1: ⩾80%
- GOLD Stage 2: 50–79%
- GOLD Stage 3: 30–49%
- GOLD Stage 4: <30%
"""


# ========================
# MAIN REVIEW LOOP
# ========================

for json_path in saved_json_paths:
    print(f"\n📝 Reviewing: {json_path.name}")

    # Load JSON content
    with open(json_path, "r") as f:
        data = json.load(f)

    lines = []

    # --- VISIT INFORMATION ---
    lines.append("=== VISIT INFORMATION ===")
    for label, (section, key) in VISIT_FIELDS_MAP.items():
        value = data.get(section, {}).get(key)
        if value is not None:
            lines.append(f"{label}: {value}")
    lines.append("")

    # --- SUBJECT INFORMATION ---
    lines.append("=== SUBJECT INFORMATION ===")
    for label, (section, key) in SUBJECT_FIELDS_MAP.items():
        value = data.get(section, {}).get(key)
        if value is not None:
            lines.append(f"{label}: {value}")
    lines.append("")

    # --- TEST PARAMETERS ---
    grouped_params = {
        "Spirometry": [],
        "Lung Volumes": [],
        "Diffusion": [],
        "Other": []
    }

    visit = data.get("Visit", {})
    flow_vol_curve = None
    fev1_fvc = None  # reset per file so a value never leaks from the previous patient

    for test in visit.get("Tests", []):
        for param in test.get("Parameters", []):
            name = param.get("Name", "")
            if name in SPIROMETRY_KEYS:
                grouped_params["Spirometry"].append(param)
                # Keep the first (best) FEV1/FVC value; skip empty strings
                if name == "FEV1/FVC%" and fev1_fvc is None and param.get("Value"):
                    fev1_fvc = float(param["Value"])
            elif name in LUNG_VOLUME_KEYS:
                grouped_params["Lung Volumes"].append(param)
            elif name in DIFFUSION_KEYS:
                grouped_params["Diffusion"].append(param)
            else:
                grouped_params["Other"].append(param)

        # Extract the first flow-volume curve found across all tests
        if flow_vol_curve is None:
            for gdata in test.get("Graphs", {}).values():
                if (
                    gdata.get("X") == "V (L)" and
                    gdata.get("Y") == "F (L/s)" and
                    gdata.get("Points")
                ):
                    flow_vol_curve = gdata["Points"]
                    break

    # --- DISPLAY PARAMETER TABLES ---
    for section, params in grouped_params.items():
        if params:
            lines.append(f"--- {section.upper()} ---")
            lines.append(PARAM_HEADER)
            for p in params:
                lines.append(format_param_row(p))
            lines.append("")

    # Print formatted report
    print("\n".join(lines))

    # --- FLOW-VOLUME PLOT ---
    if flow_vol_curve:
        vol, flow = zip(*flow_vol_curve)
        plt.figure(figsize=(4, 3))
        plt.plot(vol, flow)
        plt.xlabel("Volume (L)")
        plt.ylabel("Flow (L/s)")
        plt.title(f"Flow-Volume Curve: {json_path.stem}")
        plt.grid(True)
        plt.tight_layout()

        # Save the plot
        plot_filename = f"{json_path.stem}_flowvol.png"
        plot_path = output_folder / plot_filename
        plt.savefig(plot_path)
        print(f"📊 Flow-volume curve saved to: {plot_path}")

        plt.show()
    else:
        print("No flow-volume curve found with X='V (L)' and Y='F (L/s)'.")
    
    # Classify obstruction for this test (GOLD: FEV1/FVC >= 70% means no obstruction)
    if fev1_fvc is None:
        obstruction_flag = "Not classifiable (FEV1/FVC missing)"
    elif fev1_fvc >= 70:
        obstruction_flag = "No Obstruction"
    else:
        obstruction_flag = classify_obstruction(grouped_params["Spirometry"])

    print(f"The obstruction_flag for this test has been set to: {obstruction_flag}")
    print(GOLD_NOTE)