# bea-tools Example Notebook

This notebook demonstrates the main features of the `bea-tools` package:
1. **Utility Functions** - Clean output formatting helpers
2. **Pandas Extensions** - Enhanced Series methods
3. **TreeSampler** - Hierarchical stratified sampling for pandas DataFrames
4. **DicomComparer** - DICOM file comparison tool

---

In [None]:
import pandas as pd
import numpy as np

# For utility functions
from bea_tools.utility import divider, aligned

# For TreeSampler
from bea_tools import TreeSampler
from bea_tools._pandas.sampler import Feature

# For DicomComparer
import pydicom
from bea_tools._dicom.dicomp import DicomComparer

---

## 1. Utility Functions

Helper functions for formatting and displaying output in a clean, readable way.

### divider() - Create formatted divider lines

In [None]:
# Simple divider
print(divider(line_width=60))

# Divider with centered text
print(divider("Section Header", "=", line_width=60, align="center"))

# Divider with left-aligned text
print(divider("Results", "-", line_width=60, align="left"))

# Divider with right-aligned text
print(divider("End", "-", line_width=60, align="right"))

### aligned() - Format items within a frame

In [None]:
# Center two items within a frame
print(aligned("Label:", "Value", frame_width=30, line_width=60))
print(aligned("Name:", "Beatrice", frame_width=30, line_width=60))
print(aligned("Status:", "Active", frame_width=30, line_width=60))
print()

# Multiple items distributed across the frame
print(aligned("A", "B", "C", frame_width=50, line_width=60))
print(aligned("Left", "Middle", "Right", frame_width=50, line_width=60))

---

## 2. Pandas Series Extensions

Enhanced Series methods with custom formatting and output options.

### Series.bea.value_counts() - Enhanced value counting

In [None]:
# Create sample data
sample_series = pd.Series(['CT', 'MRI', 'CT', 'XR', 'MRI', 'CT', 'PET', 'MRI', 'CT', 'XR'])

print(divider("Standard Value Counts", "=", line_width=60, align="center"))
counts = sample_series.bea.value_counts(output=True)
print(counts)
print()

print(divider("With Proportions", "=", line_width=60, align="center"))
counts_with_pct = sample_series.bea.value_counts(with_proportion=True, output=True)
for level, value in counts_with_pct.items():
    print(aligned(level, value, frame_width=30, line_width=60))
print()

print(divider("Custom Sort Order", "=", line_width=60, align="center"))
custom_sorted = sample_series.bea.value_counts(
    sort=['CT', 'MRI', 'XR', 'PET'],
    with_proportion=True,
    output=True
)
for level, value in custom_sorted.items():
    print(aligned(level, value, frame_width=30, line_width=60))
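As a cross-check, the counts, proportions, and custom ordering reported by `Series.bea.value_counts()` can be reproduced with plain pandas. This sketch reproduces only the numbers, not the accessor's formatting, and makes no assumption about what `output=True` returns.

```python
import pandas as pd

# Same toy data as the value_counts example
s = pd.Series(["CT", "MRI", "CT", "XR", "MRI", "CT", "PET", "MRI", "CT", "XR"])

counts = s.value_counts()                     # frequency of each level
proportions = s.value_counts(normalize=True)  # share of each level (sums to 1)

# Impose a custom level order; any level missing from the data becomes 0
order = ["CT", "MRI", "XR", "PET"]
summary = pd.DataFrame({
    "count": counts.reindex(order, fill_value=0),
    "proportion": proportions.reindex(order, fill_value=0.0),
})
print(summary)
```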

---

## 3. TreeSampler - Hierarchical Stratified Sampling

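TreeSampler samples rows from a DataFrame while maintaining target proportions across multiple categorical dimensions. When a stratum has fewer available rows than its target, the shortfall spills over to other strata through the feature hierarchy, so the sample can still approach the requested size (unless strict mode is enabled).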
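The spillover idea can be illustrated with a simplified, single-level sketch. Everything here is hypothetical: `allocate_with_spillover`, the round-robin hand-out of leftover rows, and modelling strict behavior as a set of levels that never absorb extra rows are illustrative assumptions, not TreeSampler's actual algorithm, which redistributes through the feature hierarchy.

```python
def allocate_with_spillover(n, targets, capacity, strict=frozenset()):
    """Hypothetical single-level spillover: cap each level at its capacity,
    then re-offer the shortfall to non-strict levels that still have room."""
    # Start from the requested targets, capped by what each level can supply
    alloc = {level: min(t, capacity[level]) for level, t in targets.items()}
    shortfall = n - sum(alloc.values())

    # Levels allowed to absorb overflow, largest original target first
    receivers = sorted(
        (level for level in targets if level not in strict),
        key=lambda level: targets[level],
        reverse=True,
    )

    # Hand out the shortfall one row at a time (round-robin) until it is
    # used up or every receiver is at capacity
    while shortfall > 0:
        progressed = False
        for level in receivers:
            if shortfall == 0:
                break
            if alloc[level] < capacity[level]:
                alloc[level] += 1
                shortfall -= 1
                progressed = True
        if not progressed:
            break
    return alloc


# 50 rows requested at 80% CT / 20% XR, but only 30 CT rows exist
targets = {"CT": 40, "XR": 10}
capacity = {"CT": 30, "XR": 70}
print(allocate_with_spillover(50, targets, capacity))                # {'CT': 30, 'XR': 20}
print(allocate_with_spillover(50, targets, capacity, strict={"XR"}))  # {'CT': 30, 'XR': 10}
```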

### Generate Example Data

First, let's create a synthetic medical imaging dataset with patient demographics and exam information.

In [None]:
# Set random seed for reproducibility
np.random.seed(42)

# Generate synthetic patient data
n_patients = 500
n_exams = 1200  # Some patients have multiple exams

# Create patient pool
patient_ids = [f"P{str(i).zfill(4)}" for i in range(n_patients)]
genders = np.random.choice(["M", "F"], size=n_patients, p=[0.48, 0.52])
ages = np.random.normal(55, 15, n_patients).clip(18, 90).astype(int)

# Create patient lookup
patient_data = {pid: {"gender": g, "age": a} for pid, g, a in zip(patient_ids, genders, ages)}

# Generate exams (some patients have multiple)
exam_patient_ids = np.random.choice(patient_ids, size=n_exams, replace=True)
exam_ids = [f"E{str(i).zfill(5)}" for i in range(n_exams)]
modalities = np.random.choice(["CT", "MRI", "XR"], size=n_exams, p=[0.4, 0.35, 0.25])
study_dates = pd.date_range("2020-01-01", periods=n_exams, freq="4h")  # "H" is deprecated in recent pandas

# Build the DataFrame
df = pd.DataFrame({
    "empi_anon": exam_patient_ids,
    "exam_id": exam_ids,
    "gender": [patient_data[pid]["gender"] for pid in exam_patient_ids],
    "age": [patient_data[pid]["age"] for pid in exam_patient_ids],
    "modality": modalities,
    "studydate_anon": study_dates,
})

# Add age groups (right=False gives [0, 40), [40, 60), [60, 100) so the labels match the bins)
df["age_group"] = pd.cut(
    df["age"],
    bins=[0, 40, 60, 100],
    right=False,
    labels=["Young (<40)", "Middle (40-59)", "Senior (60+)"]
)

print(f"Generated {len(df):,} exams from {df['empi_anon'].nunique():,} unique patients")
df.head(10)

In [None]:
# View the distribution of our data
print("Gender distribution:")
print(df["gender"].value_counts(normalize=True).round(3))
print("\nModality distribution:")
print(df["modality"].value_counts(normalize=True).round(3))
print("\nAge group distribution:")
print(df["age_group"].value_counts(normalize=True).round(3))

### Basic Sampling: Single Feature

Sample 100 patients with a 50/50 gender split.

In [None]:
# Define a single stratification feature
gender_feature = Feature(
    name="gender",
    match_type="equals",
    levels=["M", "F"],
    weights=[0.5, 0.5]  # Target 50/50 split
)

# Create and run the sampler
sampler = TreeSampler(
    n=100,
    features=[gender_feature],
    seed=42,
    count_col="empi_anon",
    single_per_patient=True  # One exam per patient
)

sample_basic = sampler.sample_data(df.copy())

print(f"Sampled {len(sample_basic)} patients")
print("\nGender distribution in sample:")
print(sample_basic["gender"].value_counts())
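To check how close a sample landed to its targets, a small helper can put realized and target shares side by side. `compare_to_targets` is a hypothetical helper written for this notebook, not part of `bea-tools`; with the objects above it could be called as `compare_to_targets(sample_basic, "gender", {"M": 0.5, "F": 0.5})`.

```python
import pandas as pd


def compare_to_targets(sample, col, targets):
    """Hypothetical helper: realized vs target share per level of `col`."""
    actual = sample[col].value_counts(normalize=True)
    target = pd.Series(targets, dtype=float)
    out = pd.DataFrame({
        "target": target,
        # Levels absent from the sample get a realized share of 0
        "actual": actual.reindex(target.index, fill_value=0.0),
    })
    out["diff"] = out["actual"] - out["target"]
    return out


# Toy stand-in for a sampled frame: 48 M, 52 F
toy = pd.DataFrame({"gender": ["M"] * 48 + ["F"] * 52})
report = compare_to_targets(toy, "gender", {"M": 0.5, "F": 0.5})
print(report.round(3))
```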

### Multi-Feature Stratification

Sample with stratification across both gender AND modality.

In [None]:
# Define multiple stratification features
features_multi = [
    Feature(
        name="gender",
        match_type="equals",
        levels=["M", "F"],
        weights=[0.5, 0.5]
    ),
    Feature(
        name="modality",
        match_type="equals",
        levels=["CT", "MRI", "XR"],
        weights=[0.4, 0.4, 0.2]  # Target: 40% CT, 40% MRI, 20% XR
    )
]

sampler_multi = TreeSampler(
    n=200,
    features=features_multi,
    seed=42,
    count_col="empi_anon",
    single_per_patient=True
)

sample_multi = sampler_multi.sample_data(df.copy())

print(f"Sampled {len(sample_multi)} patients")
print("\nGender distribution:")
print(sample_multi["gender"].value_counts(normalize=True).round(3))
print("\nModality distribution:")
print(sample_multi["modality"].value_counts(normalize=True).round(3))
print("\nCross-tabulation:")
print(pd.crosstab(sample_multi["gender"], sample_multi["modality"], margins=True))
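Assuming the modality weights are applied within each gender branch (an assumption about how TreeSampler nests features, not confirmed here), the target for each gender × modality cell is n × P(gender) × P(modality). The table below computes those expected counts for the settings above; comparing it with the crosstab shows cells where capacity forced a deviation.

```python
import pandas as pd

n = 200
gender_w = {"M": 0.5, "F": 0.5}
modality_w = {"CT": 0.4, "MRI": 0.4, "XR": 0.2}

# Assumed leaf targets: n * P(gender) * P(modality), one cell per combination
expected = pd.DataFrame(
    [[n * gw * mw for mw in modality_w.values()] for gw in gender_w.values()],
    index=list(gender_w),
    columns=list(modality_w),
)
print(expected)
```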

### Using "Between" Match Type for Age Ranges

Sample with age brackets using the `between` match type.

In [None]:
# Define age brackets using "between" match type
age_feature = Feature(
    name="age",
    match_type="between",
    levels=[(0, 40), (40, 60), (60, 100)],  # Tuples define (min, max] ranges
    weights=[0.33, 0.34, 0.33],
    labels=["Young", "Middle", "Senior"],
    label_col="sampled_age_group"  # Add a label column to the output
)

sampler_age = TreeSampler(
    n=150,
    features=[age_feature],
    seed=42,
    count_col="empi_anon",
    single_per_patient=True
)

sample_age = sampler_age.sample_data(df.copy())

print(f"Sampled {len(sample_age)} patients")
print("\nAge group distribution:")
print(sample_age["sampled_age_group"].value_counts())
print("\nAge statistics per group:")
print(sample_age.groupby("sampled_age_group")["age"].describe()[["min", "max", "mean"]])
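With n=150 and weights 0.33/0.34/0.33, the raw targets are 49.5, 51 and 49.5, so some rounding rule must decide which bracket gets the extra row. The notebook does not say which rule TreeSampler uses; largest-remainder rounding, sketched below as an assumption, is one common scheme that always sums to exactly n.

```python
import math


def largest_remainder(n, weights):
    """Round n * w_i to integers that still sum to exactly n.

    Floors every raw target, then gives the leftover units to the levels
    with the largest fractional parts (ties go to the earlier level).
    """
    total = sum(weights)
    raw = [n * w / total for w in weights]
    counts = [math.floor(r) for r in raw]
    leftover = n - sum(counts)
    # sorted() is stable even with reverse=True, so ties keep list order
    by_fraction = sorted(range(len(raw)), key=lambda i: raw[i] - counts[i], reverse=True)
    for i in by_fraction[:leftover]:
        counts[i] += 1
    return counts


print(largest_remainder(150, [0.33, 0.34, 0.33]))
print(largest_remainder(10, [1, 1, 1]))  # [4, 3, 3]
```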

### Strict Mode: Prevent Spillover

Use `strict=True` to prevent a stratum from absorbing overflow from other strata.

In [None]:
# Build a small subset with deliberately few CT exams (30 CT, 70 XR) so that
# an 80% CT target (40 of n=50) cannot be met and spillover becomes visible
df_small = pd.concat([
    df[df["modality"] == "CT"].head(30),
    df[df["modality"] == "XR"].head(70),
]).copy()

# Without strict mode - XR can absorb overflow from CT if needed
feature_normal = Feature(
    name="modality",
    match_type="equals",
    levels=["CT", "XR"],
    weights=[0.8, 0.2]  # Requests 40 CT rows, but at most 30 are available
)

# With strict mode - XR will NOT absorb overflow
feature_strict = Feature(
    name="modality",
    match_type="equals",
    levels=["CT", "XR"],
    weights=[0.8, 0.2],
    strict=True  # Prevents spillover
)

print("Data available:")
print(df_small["modality"].value_counts())

# Normal mode
sampler_normal = TreeSampler(n=50, features=[feature_normal], seed=42, 
                              count_col="empi_anon", single_per_patient=True)
sample_normal = sampler_normal.sample_data(df_small.copy())
print("\nNormal mode (spillover allowed):")
print(sample_normal["modality"].value_counts())

# Strict mode
sampler_strict = TreeSampler(n=50, features=[feature_strict], seed=42,
                              count_col="empi_anon", single_per_patient=True)
sample_strict = sampler_strict.sample_data(df_small.copy())
print("\nStrict mode (no spillover):")
print(sample_strict["modality"].value_counts())
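Before sampling, it can help to check whether each level can supply its target at all. The hypothetical `capacity_report` below counts distinct patients per level (the relevant limit when `single_per_patient=True`) and reports the shortfall; note that a patient with exams in several levels is counted in each of them. It runs on a toy frame; the same call could be pointed at `df_small` with targets computed from n and the weights.

```python
import pandas as pd


def capacity_report(df, col, targets, count_col="empi_anon"):
    """Hypothetical pre-check: can each level supply its target?"""
    # Distinct patients per level; one patient may count in several levels
    capacity = df.groupby(col)[count_col].nunique()
    report = pd.DataFrame({"target": pd.Series(targets)})
    report["capacity"] = capacity.reindex(report.index, fill_value=0)
    # Rows a level cannot supply on its own (0 when capacity suffices)
    report["shortfall"] = (report["target"] - report["capacity"]).clip(lower=0)
    return report


toy = pd.DataFrame({
    "empi_anon": [f"P{i}" for i in range(30)] + [f"Q{i}" for i in range(70)],
    "modality": ["CT"] * 30 + ["XR"] * 70,
})
report = capacity_report(toy, "modality", {"CT": 40, "XR": 10})
print(report)
```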

### Conditional Weights

Define weights that vary based on parent feature values - useful when you want different distributions within different groups.

In [None]:
# Define conditional weights: modality distribution depends on gender
# For Males: 60% CT, 30% MRI, 10% XR
# For Females: 30% CT, 50% MRI, 20% XR

features_conditional = [
    Feature(
        name="gender",
        match_type="equals",
        levels=["M", "F"],
        weights=[0.5, 0.5]
    ),
    Feature(
        name="modality",
        match_type="equals",
        levels=["CT", "MRI", "XR"],
        conditional_weights=[{
            "feature": "gender",
            "weights": {
                "M": [0.6, 0.3, 0.1],   # Males: 60% CT, 30% MRI, 10% XR
                "F": [0.3, 0.5, 0.2]    # Females: 30% CT, 50% MRI, 20% XR
            }
        }]
    )
]

sampler_cond = TreeSampler(
    n=200,
    features=features_conditional,
    seed=42,
    count_col="empi_anon",
    single_per_patient=True
)

sample_cond = sampler_cond.sample_data(df.copy())

print("Conditional sampling results:")
print("\nModality distribution by gender:")
print(pd.crosstab(sample_cond["gender"], sample_cond["modality"], normalize="index").round(3))
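Assuming the modality weights are applied within each gender branch (not confirmed for TreeSampler), the conditional weights imply per-cell targets of n × P(gender) × P(modality | gender). When every cell has enough patients, each row of the normalized crosstab above should land near the corresponding weight list.

```python
# Assumed leaf targets when modality weights depend on gender:
# n * P(gender) * P(modality | gender)
n = 200
gender_w = {"M": 0.5, "F": 0.5}
modality_levels = ["CT", "MRI", "XR"]
conditional = {"M": [0.6, 0.3, 0.1], "F": [0.3, 0.5, 0.2]}

expected = {
    g: {m: n * gw * w for m, w in zip(modality_levels, conditional[g])}
    for g, gw in gender_w.items()
}
for g, row in expected.items():
    print(g, {m: round(v, 1) for m, v in row.items()})
```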

### Optional Sorting for Single-Per-Patient

Control whether to use a sort column when selecting one row per patient. Sorting is useful when you want to select the earliest or latest exam; disabling it skips the sort and is faster.

In [None]:
# With sorting - selects earliest exam per patient
sampler_with_sort = TreeSampler(
    n=50,
    features=[gender_feature],
    seed=42,
    count_col="empi_anon",
    sort_col="studydate_anon",  # Default - sorts by study date
    single_per_patient=True
)

sample_with_sort = sampler_with_sort.sample_data(df.copy())

# Without sorting - arbitrary selection (faster)
sampler_no_sort = TreeSampler(
    n=50,
    features=[gender_feature],
    seed=42,
    count_col="empi_anon",
    sort_col=None,  # Disables sorting
    single_per_patient=True
)

sample_no_sort = sampler_no_sort.sample_data(df.copy())

print("With sorting (earliest exam per patient):")
print(f"Date range: {sample_with_sort['studydate_anon'].min()} to {sample_with_sort['studydate_anon'].max()}")
print("\nWithout sorting (arbitrary selection):")
print(f"Date range: {sample_no_sort['studydate_anon'].min()} to {sample_no_sort['studydate_anon'].max()}")
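The sorted selection presumably resembles the plain-pandas pattern below: sort by the date column, then keep the first row per patient. This is an assumption about TreeSampler's behavior, shown on a tiny hypothetical frame; without the sort, "first" simply means first in the frame's current row order.

```python
import pandas as pd

exams = pd.DataFrame({
    "empi_anon": ["P1", "P1", "P2", "P2", "P3"],
    "exam_id": ["E3", "E1", "E4", "E2", "E5"],
    "studydate_anon": pd.to_datetime(
        ["2020-03-01", "2020-01-01", "2020-04-01", "2020-02-01", "2020-05-01"]
    ),
})

# Earliest exam per patient: sort by date, then keep the first row per patient
earliest = (
    exams.sort_values("studydate_anon", kind="stable")
    .drop_duplicates(subset="empi_anon", keep="first")
    .sort_values("empi_anon")
)

# Without sorting, "first" is just the first row in the current order
arbitrary = exams.drop_duplicates(subset="empi_anon", keep="first").sort_values("empi_anon")

print(earliest[["empi_anon", "exam_id"]])
print(arbitrary[["empi_anon", "exam_id"]])
```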

### Balanced Sampling

Use `balanced=True` to ensure equal representation across all levels, regardless of the underlying population distribution.

In [None]:
# Balanced sampling - ensures equal samples from each modality
# regardless of population distribution
feature_balanced = Feature(
    name="modality",
    match_type="equals",
    levels=["CT", "MRI", "XR"],
    balanced=True  # Distributes samples equally across all 3 levels
)

sampler_balanced = TreeSampler(
    n=150,
    features=[feature_balanced],
    seed=42,
    count_col="empi_anon",
    single_per_patient=True
)

sample_balanced = sampler_balanced.sample_data(df.copy())

print("Original population distribution:")
print(df.groupby("empi_anon")["modality"].first().value_counts(normalize=True).round(3))
print("\nBalanced sample distribution (target: 33.3% each):")
print(sample_balanced["modality"].value_counts(normalize=True).round(3))
print("\nActual counts:")
print(sample_balanced["modality"].value_counts())
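The difference between proportional and balanced allocation is easiest to see side by side. The sketch uses the population shares from the synthetic data generator (40% CT, 35% MRI, 25% XR); the rule that any remainder goes to the first levels is an assumed tie-break, not documented TreeSampler behavior.

```python
population_share = {"CT": 0.40, "MRI": 0.35, "XR": 0.25}  # shares used to generate df
n = 150

# Proportional allocation mirrors the population
proportional = {lvl: n * s for lvl, s in population_share.items()}

# Balanced allocation ignores the population: an equal share per level,
# with any remainder going to the first levels (assumed tie-break)
base, extra = divmod(n, len(population_share))
balanced = {lvl: base + (1 if i < extra else 0) for i, lvl in enumerate(population_share)}

print(proportional)
print(balanced)  # {'CT': 50, 'MRI': 50, 'XR': 50}
```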

---

## 4. DicomComparer - DICOM File Comparison

DicomComparer allows you to compare two DICOM files at the attribute level, identifying shared attributes, exclusive attributes, and value conflicts.
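Conceptually, the comparison is set arithmetic on attribute keys plus a value check on the keys both files share. The sketch below uses plain dictionaries of hypothetical keyword-to-value pairs instead of pydicom datasets; `compare_attributes` is illustrative and is not DicomComparer's implementation.

```python
def compare_attributes(a, b):
    """Hypothetical attribute-level comparison of two keyword -> value maps."""
    shared = a.keys() & b.keys()
    return {
        "shared": sorted(shared),
        "only_in_1": sorted(a.keys() - b.keys()),
        "only_in_2": sorted(b.keys() - a.keys()),
        # Shared attributes whose values disagree
        "conflicts": sorted(k for k in shared if a[k] != b[k]),
    }


original = {"PatientName": "Doe^Jane", "PatientID": "12345",
            "Modality": "CT", "InstitutionName": "General"}
anonymized = {"PatientName": "ANON", "PatientID": "ANON001",
              "Modality": "CT", "PatientIdentityRemoved": "YES"}

result = compare_attributes(original, anonymized)
for key, attrs in result.items():
    print(f"{key}: {attrs}")
```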

### Basic Usage

Load two DICOM files and compare them.

In [None]:
# Load your DICOM files
# Replace these paths with your actual DICOM file paths
DICOM_PATH_1 = "path/to/your/first_dicom.dcm"
DICOM_PATH_2 = "path/to/your/second_dicom.dcm"

# Uncomment and run when you have actual DICOM files:
# dcm1 = pydicom.dcmread(DICOM_PATH_1)
# dcm2 = pydicom.dcmread(DICOM_PATH_2)

# # Create the comparer
# comparer = DicomComparer(dcm1, dcm2)

# # Run the comparison
# comparison = comparer.compare()

# # Print the visual summary
# comparison.summary()

### Using Labeled Dictionary Input

You can also pass DICOM files as a labeled dictionary for clearer output.

In [None]:
# Alternative: use labeled dictionary for clearer identification
# comparer = DicomComparer(dcms={
#     "Original CT Scan": dcm1,
#     "Anonymized CT Scan": dcm2
# })
# comparison = comparer.compare()
# comparison.summary()

### Accessing Comparison Details

After running a comparison, you can access detailed information about matches and conflicts.

In [None]:
# Access comparison details programmatically
# 
# # Number of shared attributes
# print(f"Shared attributes: {comparison.intersection.n}")
# 
# # Number of attributes only in file 1
# print(f"Exclusive to file 1: {comparison.exclusive_to_1.n}")
# 
# # Number of attributes only in file 2
# print(f"Exclusive to file 2: {comparison.exclusive_to_2.n}")
# 
# # Number of matching vs conflicting values
# print(f"Matching values: {comparison.intersection.comparison.n_matches}")
# print(f"Conflicting values: {comparison.intersection.comparison.n_conflicts}")

### Inspecting Conflicts

View the specific attributes that have conflicting values between the two files.

In [None]:
# Iterate through conflicts to see what differs
# 
# print("Conflicting attributes:")
# print("-" * 60)
# for conflict in comparison.intersection.comparison.conflicts:
#     print(f"\nAttribute: {conflict.attr}")
#     print(f"  File 1 value: {conflict.value1.repval}")
#     print(f"  File 2 value: {conflict.value2.repval}")

### Inspecting Exclusive Attributes

View attributes that only exist in one of the files.

In [None]:
# View attributes exclusive to each file
# 
# print("Attributes only in File 1:")
# for attr in list(comparison.exclusive_to_1.attributes)[:10]:  # First 10
#     print(f"  {attr}")
# 
# print("\nAttributes only in File 2:")
# for attr in list(comparison.exclusive_to_2.attributes)[:10]:  # First 10
#     print(f"  {attr}")

### Note on Pixel Data Comparison

The DicomComparer compares values using content hashes, not just string representations. This means it can detect differences in pixel data (`PixelData` attribute) even when the display representation looks identical. This is useful for:

- Detecting image modifications or corruptions
- Verifying anonymization didn't alter image data
- Comparing images from different processing pipelines
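Why hashing matters can be shown without any DICOM files: two long byte strings that differ in a single byte produce the same abbreviated display but different digests. The `display` function is a hypothetical stand-in for how long binary values get abbreviated, and SHA-256 is an arbitrary choice; the hash DicomComparer uses is not specified here.

```python
import hashlib


def display(value, limit=16):
    """Hypothetical abbreviated display for long binary values."""
    if isinstance(value, bytes) and len(value) > limit:
        return f"<{len(value)} bytes>"
    return repr(value)


def digest(value):
    """SHA-256 of the raw bytes: any single-byte change alters it."""
    return hashlib.sha256(value).hexdigest()


pixels_a = bytes(1024)          # 1 KiB of zeros
pixels_b = bytearray(pixels_a)
pixels_b[500] = 1               # change a single byte
pixels_b = bytes(pixels_b)

print(display(pixels_a) == display(pixels_b))  # True: both shown as "<1024 bytes>"
print(digest(pixels_a) == digest(pixels_b))    # False: the content differs
```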

In [None]:
# Example: Check if pixel data matches between files
# 
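# pixel_conflicts = [
#     c for c in comparison.intersection.comparison.conflicts
#     # Match only the pixel payload (keyword or display name), not PixelSpacing etc.
#     if c.attr.name in ("PixelData", "Pixel Data")
# ]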
# 
# if pixel_conflicts:
#     print("WARNING: Pixel data differs between files!")
#     for pc in pixel_conflicts:
#         print(f"  {pc.attr.name}: hashes differ")
# else:
#     print("Pixel data matches between files")