# Foundation PLR: Data Access Tutorial

**Purpose**: Learn how to access, understand, and work with the shared research data.

---

> **ELI5 (Explain Like I'm 5) - What is this?**
>
> This study tested whether we can detect glaucoma from how the pupil reacts to light.
> We collected pupil measurements from 63 people, extracted features (numbers describing the pupil response),
> and trained computers (classifiers) to predict who has glaucoma.
>
> All the data is stored in **DuckDB** databases - think of them as super-efficient Excel files that can hold millions of rows.

---

## What Data is Available?

| Database File | What's Inside | Size | Rows |
|--------------|---------------|------|------|
| `foundation_plr_results.db` | Classifier predictions and performance metrics | ~4 MB | 20,349 predictions |
| `foundation_plr_distributions.db` | Bootstrap statistics and per-subject predictions | ~40 MB | 1.3M+ rows |

---

## 1. Setup - Install Required Libraries

Only **duckdb** is strictly required to query the data. The cells in this tutorial also import pandas and NumPy, and Sections 6 and 8 use scikit-learn; Polars is optional:

In [None]:
# Run this cell to install required libraries (only needed once)
# Remove the # to uncomment and run

# !pip install duckdb pandas numpy matplotlib scikit-learn
# !pip install polars pyarrow  # Optional: Polars is a faster alternative to pandas; pyarrow backs .pl() and .arrow()

In [None]:
# Import libraries
import duckdb
import pandas as pd
import numpy as np
from pathlib import Path

# Try importing Polars (optional)
try:
    import polars as pl
    HAS_POLARS = True
    print("✓ Polars available")
except ImportError:
    HAS_POLARS = False
    print("Polars not installed (optional)")

print("✓ Setup complete!")

---

## 2. Connect to the Database

> **ELI5 - What is DuckDB?**
>
> DuckDB is like SQLite but optimized for data analysis. It's a single file that contains
> tables of data. You can query it using SQL (a language for asking questions about data)
> or convert data to Pandas/Polars for Python analysis.

In [None]:
# Define paths to data files
# Adjust these paths to match where you saved the files
DATA_DIR = Path("../outputs")  # Or wherever you downloaded the files

RESULTS_DB = DATA_DIR / "foundation_plr_results.db"
DISTRIBUTIONS_DB = DATA_DIR / "foundation_plr_distributions.db"

# Check if files exist
print("Checking data files:")
print(f"  Results DB: {'✓ Found' if RESULTS_DB.exists() else '✗ Not found'} ({RESULTS_DB})")
print(f"  Distributions DB: {'✓ Found' if DISTRIBUTIONS_DB.exists() else '✗ Not found'} ({DISTRIBUTIONS_DB})")

In [None]:
# Connect to the results database (read-only to prevent accidental changes)
con = duckdb.connect(str(RESULTS_DB), read_only=True)

# See what tables are available
tables = con.execute("SHOW TABLES").fetchall()
print("Tables in results database:")
for table in tables:
    print(f"  - {table[0]}")

---

## 3. Understanding the Data: Table Schemas

Let's look at exactly what each table contains. **Every column is explained below.**

### 3.1 The `predictions` Table

> **ELI5**: Each row is one prediction - "for this person, this classifier predicted they have/don't have glaucoma".
> The table stores both the prediction AND the actual answer (ground truth) so we can measure accuracy.

In [None]:
# Get schema of predictions table
print("PREDICTIONS TABLE SCHEMA")
print("=" * 70)
schema = con.execute("DESCRIBE predictions").fetchdf()
print(schema.to_string(index=False))

# Count rows
count = con.execute("SELECT COUNT(*) FROM predictions").fetchone()[0]
print(f"\nTotal rows: {count:,}")

#### Column Explanations for `predictions`:

| Column | Type | Description | Example Values |
|--------|------|-------------|----------------|
| `prediction_id` | INTEGER | Unique ID for each prediction | 1, 2, 3, ... |
| `subject_id` | VARCHAR | Anonymous patient identifier | "PLR1001", "PLR1002" |
| `eye` | VARCHAR | Which eye was measured | "OD" (right), "OS" (left) |
| `fold` | INTEGER | Cross-validation fold number | 0, 1, 2, 3, 4 |
| `bootstrap_iter` | INTEGER | Bootstrap iteration (0 = original) | 0, 1, 2, ... |
| `outlier_method` | VARCHAR | How outliers were detected | "ensemble-LOF-..." |
| `imputation_method` | VARCHAR | How missing data was filled | "SAITS" |
| `featurization` | VARCHAR | Feature extraction method | "simple1.0" |
| `classifier` | VARCHAR | ML algorithm used | "TabM", "XGBOOST", "LogisticRegression", "TabPFN" |
| `source_name` | VARCHAR | Full pipeline config string | "XGBOOST_eval-auc__..." |
| `y_true` | INTEGER | **Ground truth**: Does patient have glaucoma? | 0 (no), 1 (yes) |
| `y_pred` | INTEGER | **Prediction**: Classifier's binary decision | 0 (no), 1 (yes) |
| `y_prob` | FLOAT | **Probability**: Classifier's confidence | 0.0 to 1.0 |
| `mlflow_run_id` | VARCHAR | Experiment tracking ID | "abc123..." |

In [None]:
# See example data
print("\nExample rows from predictions table:")
print(con.execute("SELECT * FROM predictions LIMIT 3").fetchdf().to_string())

### 3.2 The `metrics_per_fold` Table

> **ELI5**: Performance metrics (like accuracy) calculated separately for each cross-validation fold.
> Cross-validation means we split the data 5 ways and test 5 times to get robust results.
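
To make the "split 5 ways, test 5 times" idea concrete, here is a minimal sketch of a 5-fold split on a toy set of 20 samples. The source does not say how folds were actually assigned (for example stratified by class or grouped by subject), so the helper `k_fold_indices` and its plain shuffling are illustrative assumptions only.

```python
import random


def k_fold_indices(n_samples, k=5, seed=0):
    """Split range(n_samples) into k shuffled test folds of near-equal size.

    Hypothetical helper for illustration; not the study's fold assignment.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Fold i takes every k-th shuffled index, starting at position i
    return [indices[i::k] for i in range(k)]


n_samples = 20  # toy size, not the study's sample count
folds = k_fold_indices(n_samples, k=5)

for fold, test_idx in enumerate(folds):
    test_set = set(test_idx)
    # Everything outside the test fold is used for training in that round
    train_idx = [i for i in range(n_samples) if i not in test_set]
    print(f"fold {fold}: train={len(train_idx)}, test={len(test_idx)}")
```

Each sample lands in exactly one test fold, so every sample is predicted once by a model that never saw it during training.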

In [None]:
print("METRICS_PER_FOLD TABLE SCHEMA")
print("=" * 70)
schema = con.execute("DESCRIBE metrics_per_fold").fetchdf()
print(schema.to_string(index=False))

count = con.execute("SELECT COUNT(*) FROM metrics_per_fold").fetchone()[0]
print(f"\nTotal rows: {count:,}")

#### Column Explanations for `metrics_per_fold`:

| Column | Type | Description | Notes |
|--------|------|-------------|-------------|
| `metric_id` | INTEGER | Unique row ID | - |
| `classifier` | VARCHAR | ML algorithm | - |
| `fold` | INTEGER | CV fold (0-4) | - |
| `metric_name` | VARCHAR | Which metric | "auroc", "auprc", etc. |
| `metric_value` | FLOAT | The measured value | Depends on metric |
| `bootstrap_iter` | INTEGER | Bootstrap iteration | - |
| `source_name` | VARCHAR | Pipeline config | - |

**Common Metrics:**
- `auroc`: Area Under ROC Curve (0-1, higher=better, 0.5=random, 1.0=perfect)
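- `auprc`: Area Under Precision-Recall Curve (0-1, higher=better; a random classifier scores roughly the glaucoma prevalence)
- `brier`: Brier Score (0-1, lower=better; mean squared error of the predicted probabilities, so it reflects both calibration and discrimination)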
- `accuracy`: Correct predictions / Total predictions
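
The definitions above can be sketched in plain Python on a toy example. The values of `y_true` and `y_prob` below are made up, the 0.5 threshold for accuracy is an assumption, and the functions are simplified illustrations rather than the code used to produce the stored metrics (AUPRC is omitted for brevity).

```python
def auroc(y_true, y_prob):
    """Probability that a random positive scores above a random negative (ties count 0.5)."""
    pos = [p for t, p in zip(y_true, y_prob) if t == 1]
    neg = [p for t, p in zip(y_true, y_prob) if t == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def brier(y_true, y_prob):
    """Mean squared difference between predicted probability and the 0/1 outcome."""
    return sum((p - t) ** 2 for t, p in zip(y_true, y_prob)) / len(y_true)


def accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of correct decisions after thresholding (threshold is an assumption)."""
    y_pred = [int(p >= threshold) for p in y_prob]
    return sum(t == yp for t, yp in zip(y_true, y_pred)) / len(y_true)


# Toy data: 3 controls (0) and 3 glaucoma cases (1)
y_true = [0, 0, 1, 1, 0, 1]
y_prob = [0.1, 0.4, 0.35, 0.8, 0.2, 0.9]

print(f"AUROC:    {auroc(y_true, y_prob):.3f}")
print(f"Brier:    {brier(y_true, y_prob):.3f}")
print(f"Accuracy: {accuracy(y_true, y_prob):.3f}")
```

Note that AUROC and the Brier score use the probabilities directly, while accuracy depends on the chosen threshold.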

### 3.3 The `metrics_aggregate` Table

> **ELI5**: Summary statistics (mean, median, confidence intervals) calculated across all folds.

In [None]:
print("METRICS_AGGREGATE TABLE SCHEMA")
print("=" * 70)
schema = con.execute("DESCRIBE metrics_aggregate").fetchdf()
print(schema.to_string(index=False))

#### Column Explanations for `metrics_aggregate`:

| Column | Description |
|--------|-------------|
| `aggregate_id` | Unique row ID |
| `classifier` | ML algorithm name |
| `metric_name` | Metric being summarized |
| `mean` | Average across folds |
| `std` | Standard deviation |
| `median` | Middle value |
| `q25` | 25th percentile |
| `q75` | 75th percentile |
| `ci_lower` | 95% CI lower bound |
| `ci_upper` | 95% CI upper bound |
| `source_name` | Pipeline config |
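
As a rough illustration of what these summary columns represent, the sketch below computes them from five made-up per-fold AUROC values using only the standard library. The source does not state whether `std` is the sample or population standard deviation, which quantile method was used, or how `ci_lower`/`ci_upper` were derived, so the choices here are assumptions and the confidence interval is deliberately left out.

```python
import statistics

# Toy per-fold AUROC values (one per cross-validation fold), not study results
fold_aurocs = [0.81, 0.78, 0.85, 0.80, 0.83]

# "inclusive" treats the data as the full population when interpolating quartiles
q25, median, q75 = statistics.quantiles(fold_aurocs, n=4, method="inclusive")

summary = {
    "mean": statistics.mean(fold_aurocs),
    "std": statistics.stdev(fold_aurocs),  # sample standard deviation (assumption)
    "median": median,
    "q25": q25,
    "q75": q75,
}

for name, value in summary.items():
    print(f"{name:>6}: {value:.4f}")
```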

### 3.4 The `mlflow_runs` Table (if present)

> **ELI5**: Metadata about each experiment run - when it ran, what settings were used, etc.

In [None]:
# Check if mlflow_runs exists
tables = [t[0] for t in con.execute("SHOW TABLES").fetchall()]
if 'mlflow_runs' in tables:
    print("MLFLOW_RUNS TABLE SCHEMA")
    print("=" * 70)
    schema = con.execute("DESCRIBE mlflow_runs").fetchdf()
    print(schema.to_string(index=False))
else:
    print("mlflow_runs table not present in this database")

---

## 4. Querying Data: SQL Basics

> **ELI5 - What is SQL?**
>
> SQL (Structured Query Language) is how you ask questions to a database.
> - `SELECT` = which columns you want
> - `FROM` = which table
> - `WHERE` = filter conditions
> - `GROUP BY` = aggregate by category
> - `ORDER BY` = sort results

### 4.1 Basic Queries

In [None]:
# Example 1: Get AUROC for each classifier
print("AUROC by Classifier:")
print("-" * 50)

query = """
SELECT 
    classifier,
    ROUND(mean, 4) as auroc_mean,
    ROUND(ci_lower, 4) as ci_lower,
    ROUND(ci_upper, 4) as ci_upper
FROM metrics_aggregate
WHERE metric_name = 'auroc'
ORDER BY mean DESC
"""

result = con.execute(query).fetchdf()
print(result.to_string(index=False))

In [None]:
# Example 2: Count predictions per classifier
print("\nPredictions per Classifier:")
print("-" * 50)

query = """
SELECT 
    classifier,
    COUNT(*) as n_predictions,
    COUNT(DISTINCT subject_id) as n_subjects,
    ROUND(AVG(y_true), 3) as glaucoma_prevalence
FROM predictions
GROUP BY classifier
ORDER BY classifier
"""

result = con.execute(query).fetchdf()
print(result.to_string(index=False))

In [None]:
# Example 3: Get predictions for a specific subject
print("\nPredictions for first subject:")
print("-" * 50)

# First, pick a subject ID (MIN makes the choice deterministic)
first_subject = con.execute("SELECT MIN(subject_id) FROM predictions").fetchone()[0]
print(f"Subject: {first_subject}\n")

# Bind the subject ID as a query parameter (?) instead of formatting it into the SQL string
query = """
SELECT 
    classifier,
    eye,
    y_true as has_glaucoma,
    ROUND(y_prob, 3) as predicted_probability,
    y_pred as predicted_class
FROM predictions
WHERE subject_id = ?
  AND fold = 0  -- Just one fold to keep it simple
ORDER BY classifier
"""

result = con.execute(query, [first_subject]).fetchdf()
print(result.to_string(index=False))

### 4.2 Run Statistics Directly in DuckDB (No Pandas Needed!)

> **ELI5**: DuckDB can do math and statistics directly - you don't need to load data into Python.

In [None]:
# Compute statistics directly in DuckDB
print("Statistics computed directly in DuckDB:")
print("=" * 70)

query = """
SELECT 
    classifier,
    -- Basic stats
    COUNT(*) as n,
    ROUND(AVG(y_prob), 4) as mean_prob,
    ROUND(STDDEV(y_prob), 4) as std_prob,
    
    -- Percentiles
    ROUND(PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY y_prob), 4) as q25,
    ROUND(PERCENTILE_CONT(0.50) WITHIN GROUP (ORDER BY y_prob), 4) as median,
    ROUND(PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY y_prob), 4) as q75,
    
    -- Min/Max
    ROUND(MIN(y_prob), 4) as min_prob,
    ROUND(MAX(y_prob), 4) as max_prob
FROM predictions
WHERE fold = 0  -- Use first fold for cleaner stats
GROUP BY classifier
ORDER BY classifier
"""

result = con.execute(query).fetchdf()
print(result.to_string(index=False))

In [None]:
# Compute confusion matrix metrics in DuckDB
print("\nConfusion Matrix Metrics (at threshold=0.5):")
print("=" * 70)

query = """
WITH confusion AS (
    SELECT 
        classifier,
        SUM(CASE WHEN y_true = 1 AND y_pred = 1 THEN 1 ELSE 0 END) as TP,
        SUM(CASE WHEN y_true = 0 AND y_pred = 1 THEN 1 ELSE 0 END) as FP,
        SUM(CASE WHEN y_true = 1 AND y_pred = 0 THEN 1 ELSE 0 END) as FN,
        SUM(CASE WHEN y_true = 0 AND y_pred = 0 THEN 1 ELSE 0 END) as TN
    FROM predictions
    WHERE fold = 0
    GROUP BY classifier
)
SELECT 
    classifier,
    TP, FP, FN, TN,
    ROUND(CAST(TP AS FLOAT) / NULLIF(TP + FN, 0), 3) as Sensitivity,
    ROUND(CAST(TN AS FLOAT) / NULLIF(TN + FP, 0), 3) as Specificity,
    ROUND(CAST(TP + TN AS FLOAT) / NULLIF(TP + TN + FP + FN, 0), 3) as Accuracy
FROM confusion
ORDER BY classifier
"""

result = con.execute(query).fetchdf()
print(result.to_string(index=False))
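
The query above relies on the stored `y_pred` column. If you want to see how sensitivity and specificity shift at other cut-offs, you can re-threshold `y_prob` yourself. The sketch below does this on toy data; the helper `sens_spec` and the `>=` comparison rule are assumptions, not necessarily how `y_pred` was derived in the study.

```python
def sens_spec(y_true, y_prob, threshold):
    """Return (sensitivity, specificity) after thresholding y_prob (hypothetical helper)."""
    y_pred = [int(p >= threshold) for p in y_prob]
    tp = sum(t == 1 and yp == 1 for t, yp in zip(y_true, y_pred))
    fn = sum(t == 1 and yp == 0 for t, yp in zip(y_true, y_pred))
    tn = sum(t == 0 and yp == 0 for t, yp in zip(y_true, y_pred))
    fp = sum(t == 0 and yp == 1 for t, yp in zip(y_true, y_pred))
    # Guard against classes that are absent from the data
    sensitivity = tp / (tp + fn) if (tp + fn) else float("nan")
    specificity = tn / (tn + fp) if (tn + fp) else float("nan")
    return sensitivity, specificity


# Toy data, not study results
y_true = [0, 0, 1, 1, 0, 1]
y_prob = [0.1, 0.4, 0.35, 0.8, 0.2, 0.9]

for threshold in (0.3, 0.5, 0.7):
    sens, spec = sens_spec(y_true, y_prob, threshold)
    print(f"threshold={threshold:.1f}  sensitivity={sens:.3f}  specificity={spec:.3f}")
```

Lowering the threshold generally trades specificity for sensitivity, which is why a single confusion matrix only describes one operating point.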

---

## 5. Converting to DataFrames

Sometimes you need data in Python for visualization or custom analysis.

### 5.1 Convert to Pandas DataFrame

In [None]:
# Method 1: Using .fetchdf() - returns Pandas DataFrame
query = "SELECT * FROM predictions WHERE classifier = 'TabM' AND fold = 0"
df_pandas = con.execute(query).fetchdf()

print(f"Type: {type(df_pandas)}")
print(f"Shape: {df_pandas.shape}")
print(f"\nFirst 3 rows:")
print(df_pandas.head(3))

In [None]:
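# Method 2: Using pd.read_sql (if you prefer this syntax)
# pandas officially supports SQLAlchemy connectables and sqlite3 connections here,
# so it may emit a UserWarning when given a DuckDB connection.
# This opens a separate connection, which we close explicitly afterwards.
read_con = duckdb.connect(str(RESULTS_DB), read_only=True)
try:
    df_pandas2 = pd.read_sql(
        "SELECT classifier, AVG(y_prob) as mean_prob FROM predictions GROUP BY classifier",
        read_con
    )
finally:
    read_con.close()
print(df_pandas2)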

### 5.2 Convert to Polars DataFrame (Faster for Large Data)

In [None]:
if HAS_POLARS:
    # Method 1: Using .pl() method (DuckDB >= 0.8.0)
    try:
        df_polars = con.execute(
            "SELECT * FROM predictions WHERE classifier = 'TabM' AND fold = 0"
        ).pl()
        print(f"Type: {type(df_polars)}")
        print(f"Shape: {df_polars.shape}")
        print(f"\nFirst 3 rows:")
        print(df_polars.head(3))
    except AttributeError:
        # Fallback: Convert via Arrow
        arrow_table = con.execute(
            "SELECT * FROM predictions WHERE classifier = 'TabM' AND fold = 0"
        ).arrow()
        df_polars = pl.from_arrow(arrow_table)
        print(f"Type: {type(df_polars)}")
        print(f"Shape: {df_polars.shape}")
else:
    print("Polars not installed. Install with: pip install polars")

---

## 6. Converting to NumPy Arrays

> **ELI5**: NumPy arrays are the basic data structure for numerical computing in Python.
> Machine learning libraries like scikit-learn work with NumPy arrays.

In [None]:
# Extract a single column as NumPy array
y_probs = con.execute(
    "SELECT y_prob FROM predictions WHERE classifier = 'TabM' AND fold = 0"
).fetchnumpy()['y_prob']

print(f"Type: {type(y_probs)}")
print(f"Shape: {y_probs.shape}")
print(f"Dtype: {y_probs.dtype}")
print(f"First 5 values: {y_probs[:5]}")

In [None]:
# Extract multiple columns as NumPy arrays
result = con.execute("""
    SELECT y_true, y_prob, y_pred 
    FROM predictions 
    WHERE classifier = 'TabM' AND fold = 0
""").fetchnumpy()

y_true = result['y_true']
y_prob = result['y_prob']
y_pred = result['y_pred']

print(f"y_true shape: {y_true.shape}, dtype: {y_true.dtype}")
print(f"y_prob shape: {y_prob.shape}, dtype: {y_prob.dtype}")
print(f"y_pred shape: {y_pred.shape}, dtype: {y_pred.dtype}")

In [None]:
# Now you can use these arrays with scikit-learn!
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix

print("Metrics computed from NumPy arrays:")
print(f"  AUROC: {roc_auc_score(y_true, y_prob):.4f}")
print(f"  Accuracy: {accuracy_score(y_true, y_pred):.4f}")
print(f"\nConfusion Matrix:")
print(confusion_matrix(y_true, y_pred))

---

## 7. Working with the Distributions Database

> **ELI5 - What are bootstrap distributions?**
>
> Bootstrap is a statistical technique where we resample the data many times (e.g., 1000 times)
> and calculate a metric each time. This gives us a distribution of values, which we use to
> estimate uncertainty (confidence intervals) instead of just a single number.
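
A minimal sketch of the resampling idea, using only the standard library and toy labels. The helper `bootstrap_ci`, the use of accuracy as the metric, and the simple index-based percentile lookup are illustrative assumptions; the study's own bootstrap procedure may differ in its details.

```python
import random


def accuracy(y_true, y_pred):
    """Fraction of matching labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def bootstrap_ci(y_true, y_pred, metric, n_boot=1000, alpha=0.05, seed=0):
    """Percentile bootstrap interval for a metric (hypothetical helper)."""
    rng = random.Random(seed)
    n = len(y_true)
    values = []
    for _ in range(n_boot):
        # Resample n cases with replacement, keeping each label/prediction pair together
        idx = [rng.randrange(n) for _ in range(n)]
        values.append(metric([y_true[i] for i in idx], [y_pred[i] for i in idx]))
    values.sort()
    lower = values[int(n_boot * alpha / 2)]
    upper = values[int(n_boot * (1 - alpha / 2)) - 1]
    return lower, upper


# Toy labels and predictions (20 cases), not study data
y_true = [1, 0, 1, 1, 0, 0, 1, 0, 1, 0] * 2
y_pred = [1, 0, 1, 0, 0, 0, 1, 1, 1, 0] * 2

point = accuracy(y_true, y_pred)
lower, upper = bootstrap_ci(y_true, y_pred, accuracy)
print(f"accuracy = {point:.3f}, 95% CI = [{lower:.3f}, {upper:.3f}]")
```

The interval width reflects how much the metric would vary if the study were repeated on a similar sample of the same size.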

In [None]:
# Connect to distributions database
if DISTRIBUTIONS_DB.exists():
    con_dist = duckdb.connect(str(DISTRIBUTIONS_DB), read_only=True)
    
    print("Tables in distributions database:")
    for table in con_dist.execute("SHOW TABLES").fetchall():
        count = con_dist.execute(f"SELECT COUNT(*) FROM {table[0]}").fetchone()[0]
        print(f"  - {table[0]}: {count:,} rows")
else:
    print("Distributions database not found")

In [None]:
# Explore bootstrap_distributions table
if DISTRIBUTIONS_DB.exists():
    print("\nBOOTSTRAP_DISTRIBUTIONS TABLE SCHEMA")
    print("=" * 70)
    schema = con_dist.execute("DESCRIBE bootstrap_distributions").fetchdf()
    print(schema.to_string(index=False))
    
    print("\nExample rows:")
    print(con_dist.execute("SELECT * FROM bootstrap_distributions LIMIT 3").fetchdf().to_string())

#### Column Explanations for `bootstrap_distributions`:

| Column | Description |
|--------|-------------|
| `dist_id` | Unique row ID |
| `classifier` | ML algorithm |
| `metric_name` | Metric being bootstrapped |
| `fold` | Cross-validation fold |
| `bootstrap_iter` | Bootstrap iteration (0 to N-1) |
| `metric_value` | Metric value for this bootstrap sample |
| `source_name` | Pipeline configuration |

In [None]:
# Explore subject_predictions table
if DISTRIBUTIONS_DB.exists():
    print("\nSUBJECT_PREDICTIONS TABLE SCHEMA")
    print("=" * 70)
    schema = con_dist.execute("DESCRIBE subject_predictions").fetchdf()
    print(schema.to_string(index=False))
    
    print("\nExample rows:")
    print(con_dist.execute("SELECT * FROM subject_predictions LIMIT 3").fetchdf().to_string())

#### Column Explanations for `subject_predictions`:

| Column | Description |
|--------|-------------|
| `pred_id` | Unique row ID |
| `source_name` | Full pipeline config |
| `classifier` | ML algorithm |
| `split` | Data split (train/test) |
| `subject_code` | Anonymous subject ID |
| `y_true` | Ground truth (0=healthy, 1=glaucoma) |
| `y_pred_proba` | Predicted probability |
| `y_pred` | Binary prediction |
| `confidence` | Prediction confidence (if available) |
| `entropy_of_expected` | Uncertainty measure |
| `expected_entropy` | Uncertainty measure |
| `mutual_information` | Uncertainty measure |
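
The table only labels the last three columns as uncertainty measures. One common convention for columns with these names is the entropy decomposition over several predicted probabilities for the same subject (for example from different bootstrap models): total uncertainty is the entropy of the averaged probability, and mutual information is what remains after subtracting the average per-model entropy. The sketch below follows that convention as an assumption; the source does not confirm that the stored values were computed this way, and `decompose_uncertainty` is a hypothetical helper.

```python
import math


def binary_entropy(p):
    """Entropy (in nats) of a Bernoulli variable with success probability p."""
    if p <= 0.0 or p >= 1.0:
        return 0.0
    return -(p * math.log(p) + (1 - p) * math.log(1 - p))


def decompose_uncertainty(member_probs):
    """Split predictive uncertainty for one subject across several model outputs."""
    mean_p = sum(member_probs) / len(member_probs)
    entropy_of_expected = binary_entropy(mean_p)  # total uncertainty
    expected_entropy = sum(binary_entropy(p) for p in member_probs) / len(member_probs)
    mutual_information = entropy_of_expected - expected_entropy  # model disagreement
    return entropy_of_expected, expected_entropy, mutual_information


# Toy cases: models that agree vs. models that disagree
agree = decompose_uncertainty([0.9, 0.9, 0.9])
disagree = decompose_uncertainty([0.1, 0.9])

print("agree:    total=%.3f  expected=%.3f  MI=%.3f" % agree)
print("disagree: total=%.3f  expected=%.3f  MI=%.3f" % disagree)
```

Under this convention, mutual information is near zero when the models agree and grows when they give conflicting probabilities for the same subject.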

In [None]:
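# Example: Get bootstrap CI for AUROC
# Note: grouping by classifier alone pools all folds and pipeline configs (source_name);
# add those columns to GROUP BY if you want one interval per pipeline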
if DISTRIBUTIONS_DB.exists():
    print("\nBootstrap 95% CI for AUROC:")
    print("=" * 70)
    
    query = """
    SELECT 
        classifier,
        COUNT(*) as n_bootstraps,
        ROUND(AVG(metric_value), 4) as mean_auroc,
        ROUND(PERCENTILE_CONT(0.025) WITHIN GROUP (ORDER BY metric_value), 4) as ci_lower,
        ROUND(PERCENTILE_CONT(0.975) WITHIN GROUP (ORDER BY metric_value), 4) as ci_upper
    FROM bootstrap_distributions
    WHERE metric_name = 'auroc'
    GROUP BY classifier
    ORDER BY mean_auroc DESC
    """
    
    result = con_dist.execute(query).fetchdf()
    print(result.to_string(index=False))

---

## 8. Training Your Own Classifier

> **ELI5**: If you want to test a different classifier or settings, you can load the features
> and train your own model.

In [None]:
# For this example, we'll use the predictions to demonstrate
# In practice, you'd use the features database

# Get unique subjects and their outcomes for one classifier
query = """
SELECT DISTINCT 
    subject_id,
    eye,
    y_true,
    y_prob as original_prob
FROM predictions
WHERE classifier = 'TabM' AND fold = 0
"""

df = con.execute(query).fetchdf()
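# DISTINCT applies to the whole row, so a subject can appear more than once
# (e.g. both eyes, or several pipeline configs)
print(f"Rows: {len(df)} (unique subjects: {df['subject_id'].nunique()})")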
print(f"Class distribution: {df['y_true'].value_counts().to_dict()}")
print(f"\nFirst 5 rows:")
print(df.head())

In [None]:
# Train a simple classifier (using prediction probabilities as features for demo)
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

# Using the original probability as a single feature (just for demonstration)
X_demo = df['original_prob'].values.reshape(-1, 1)
y_demo = df['y_true'].values

# Train logistic regression
clf = LogisticRegression(random_state=42)
scores = cross_val_score(clf, X_demo, y_demo, cv=5, scoring='roc_auc')

print(f"Cross-validation AUROC scores: {scores}")
print(f"Mean AUROC: {scores.mean():.4f} (+/- {scores.std()*2:.4f})")

---

## 9. Saving and Exporting Data

DuckDB can export data to many formats.

In [None]:
# Export to CSV
con.execute("""
    COPY (
        SELECT classifier, mean, ci_lower, ci_upper
        FROM metrics_aggregate
        WHERE metric_name = 'auroc'
    ) TO 'auroc_results.csv' (HEADER, DELIMITER ',')
""")
print("✓ Exported to auroc_results.csv")

# Read it back to verify (the with-block closes the file afterwards)
print("\nContents:")
with open('auroc_results.csv') as f:
    print(f.read())

In [None]:
# Export to Parquet (efficient columnar format)
con.execute("""
    COPY (
        SELECT * FROM predictions WHERE classifier = 'TabM'
    ) TO 'tabm_predictions.parquet' (FORMAT PARQUET)
""")
print("✓ Exported to tabm_predictions.parquet")

# Check file size
import os
size_mb = os.path.getsize('tabm_predictions.parquet') / (1024 * 1024)
print(f"File size: {size_mb:.2f} MB")

In [None]:
# Clean up exported files
import os
for f in ['auroc_results.csv', 'tabm_predictions.parquet']:
    if os.path.exists(f):
        os.remove(f)
        print(f"Cleaned up: {f}")

---

## 10. Quick Reference

### Common SQL Queries

```sql
-- Get all metrics for a classifier
SELECT * FROM metrics_aggregate WHERE classifier = 'TabM';

-- Get predictions for a specific subject
SELECT * FROM predictions WHERE subject_id = 'PLR1001';

-- Count unique subjects
SELECT COUNT(DISTINCT subject_id) FROM predictions;

-- Get class balance
SELECT y_true, COUNT(*) FROM predictions GROUP BY y_true;

-- Filter by multiple conditions
SELECT * FROM predictions 
WHERE classifier = 'TabM' 
  AND fold = 0 
  AND y_prob > 0.5;
```

### Data Conversion Cheatsheet

| Want | Code |
|------|------|
| Pandas DataFrame | `con.execute(query).fetchdf()` |
| Polars DataFrame | `con.execute(query).pl()` |
| NumPy arrays | `con.execute(query).fetchnumpy()` |
| Python list | `con.execute(query).fetchall()` |
| Arrow Table | `con.execute(query).arrow()` |

### Database Files

| File | Contents | Use Case |
|------|----------|----------|
| `foundation_plr_results.db` | Predictions, metrics | Main analysis |
| `foundation_plr_distributions.db` | Bootstrap samples | Uncertainty analysis |

In [None]:
# Close connections when done
con.close()
if DISTRIBUTIONS_DB.exists():
    con_dist.close()
print("✓ Connections closed")

---

## Need Help?

- **DuckDB Documentation**: https://duckdb.org/docs/
- **SQL Tutorial**: https://www.w3schools.com/sql/
- **Pandas Documentation**: https://pandas.pydata.org/docs/
- **Questions about this data**: Contact the study authors