# Comparing SUAVE (behaviour="hivae") with the reference HI-VAE

This notebook trains both implementations on the `mimic-example-1000.tsv` sample and compares their reconstructions/imputations to check whether the behaviours match.

> **TensorFlow dependency**: the reference implementation in `third_party/hivae_tf` expects TensorFlow 2.x (see the bundled README). The experiments below were validated with TensorFlow 2.11 and `tensorflow-probability` 0.19.

In [None]:
# Uncomment the following line if TensorFlow 2.11 and tensorflow-probability 0.19
# are not available in your environment.
# !pip install tensorflow==2.11.0 tensorflow-probability==0.19.0


In [None]:
import os
import sys
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import tensorflow as tf

# Put the repository root and the vendored reference implementation at the
# front of the import search path.
# NOTE(review): '..' is resolved against the current working directory, so this
# presumably expects the notebook to be run from its own folder — confirm.
REPO_ROOT = Path('..').resolve()
sys.path.insert(0, str(REPO_ROOT))
sys.path.insert(0, str(REPO_ROOT / 'third_party' / 'hivae_tf'))

# Project imports come after the sys.path edits above so that they can be
# resolved from the entries that were just inserted.
from suave import SUAVE, Schema
from hivae import hivae as Hivae

# Seed Python, NumPy, PyTorch and TensorFlow with the same value.
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only affects child processes — confirm.
os.environ.setdefault('PYTHONHASHSEED', '42')
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
tf.random.set_seed(42)


In [None]:
# Location of the tab-separated sample, relative to the working directory.
DATA_PATH = Path('data/mimic-example-1000/mimic-example-1000.tsv')
if not DATA_PATH.exists():
    # Fail early with the offending path instead of a parser error later on.
    raise FileNotFoundError(f'Dataset not found: {DATA_PATH}')

# The separator is written as an escape sequence: a literal tab between the
# quotes is invisible and is lost as soon as an editor or formatter expands
# tabs to spaces.
raw_df = pd.read_csv(DATA_PATH, sep='\t')
print(f'Shape: {raw_df.shape[0]} rows x {raw_df.shape[1]} columns')
raw_df.head()


In [None]:
# Per-column overview of the raw table: dtype, number of distinct non-missing
# values, and how many entries are missing (absolute and as a percentage).
row_total = len(raw_df)
column_overview = {
    'dtype': raw_df.dtypes.astype(str),
    'n_unique': raw_df.nunique(dropna=True),
    'n_missing': raw_df.isna().sum(),
}
summary = pd.DataFrame(column_overview)
summary['missing_pct'] = summary['n_missing'] / row_total * 100.0

# Show the columns with the most missing data first.
summary.sort_values('missing_pct', ascending=False)


The helper below proposes a schema by inspecting dtypes and value ranges. Review the output and adjust any column specifications manually if needed.

In [None]:
import pandas.api.types as pdt


def infer_column_spec(series: pd.Series) -> dict[str, object]:
    """Propose a schema entry for one column from its dtype and values.

    Rules, applied in order:

    * boolean dtype -> two-class categorical;
    * integer dtype -> two-class categorical when the non-missing values are
      a subset of {0, 1}, otherwise real;
    * any other numeric dtype (floats included) -> real, even when the
      column only holds 0/1 values;
    * non-numeric column without any non-missing value -> real;
    * everything else -> categorical with one class per distinct
      non-missing value, but never fewer than two classes.

    Args:
        series: Column to inspect.

    Returns:
        ``{'type': 'real'}`` or ``{'type': 'cat', 'n_classes': <int>}``.
    """
    kind = series.dtype
    observed = series.dropna()

    if pdt.is_bool_dtype(kind):
        return {'type': 'cat', 'n_classes': 2}

    if pdt.is_integer_dtype(kind):
        levels = observed.unique()
        level_total = len(levels)
        looks_binary = level_total <= 2 and all(int(level) in (0, 1) for level in levels)
        if not looks_binary:
            return {'type': 'real'}
        return {'type': 'cat', 'n_classes': max(level_total, 2)}

    # Floats, remaining numeric dtypes and value-less columns are all continuous.
    if pdt.is_float_dtype(kind) or pdt.is_numeric_dtype(kind) or observed.empty:
        return {'type': 'real'}

    distinct_total = int(observed.nunique())
    return {'type': 'cat', 'n_classes': max(distinct_total, 2)}

# Run the heuristic over every column of the raw table.
schema_dict = {
    name: infer_column_spec(values)
    for name, values in raw_df.items()
}

# One row per column so the proposed specification is easy to review.
schema_preview = pd.DataFrame(schema_dict).transpose()
schema_preview


In [None]:
# Wrap the proposed specifications in a schema object and report its width.
schema = Schema(schema_dict)
feature_total = len(list(schema.feature_names))
print(f'Total features: {feature_total}')

def schema_to_hivae_types(schema: Schema) -> list[tuple[str, str, int | None, int | None]]:
    types: list[tuple[str, str, int | None, int | None]] = []
    for column in schema.feature_names:
        spec = schema[column]
        if spec.type == 'cat':
            types.append((column, 'cat', spec.n_classes, spec.n_classes))
        elif spec.type == 'ordinal':
            types.append((column, 'ordinal', spec.n_classes, spec.n_classes))
        elif spec.type == 'real':
            types.append((column, 'real', 1, None))
        elif spec.type == 'pos':
            types.append((column, 'pos', 1, None))
        elif spec.type == 'count':
            types.append((column, 'count', 1, None))
        else:
            raise ValueError(f'Unsupported feature type: {spec.type}')
    return types

# Descriptor tuples handed to the reference model when it is constructed.
types_list = schema_to_hivae_types(schema)
descriptor_total = len(types_list)
print(f'HI-VAE type descriptors: {descriptor_total} columns')
types_list[:5]


In [None]:
# Reproducible 80/20 split: sample the training rows, keep the rest for testing,
# then renumber both frames from zero.
sampled_rows = raw_df.sample(frac=0.8, random_state=42)
held_out_rows = raw_df.drop(index=sampled_rows.index)
train_df = sampled_rows.reset_index(drop=True)
test_df = held_out_rows.reset_index(drop=True)

# Despite the name, these masks hold 1 where a value is present and 0 where it
# is missing.
train_missing_mask = train_df.notna().astype(int)
test_missing_mask = test_df.notna().astype(int)

train_df.head()


## Train the reference HI-VAE implementation

In [None]:
# Folder that receives the reference model's results and saved networks.
artifact_root = Path('artifacts', 'mimic_hivae')
artifact_root.mkdir(parents=True, exist_ok=True)

# Architecture and batching options passed to the reference implementation.
network_config = dict(
    model_name='model_HIVAE_inputDropout',
    dim_z=32,
    dim_y=32,
    dim_s=32,
    batch_size=64,
)

hivae_model = Hivae(
    types_list,
    network_config,
    results_path=str(artifact_root / 'results'),
    network_path=str(artifact_root / 'networks'),
    verbosity_level=1,
)
hivae_model.fit(train_df, epochs=50, true_missing_mask=train_missing_mask)

# predict() yields five outputs; only the decoded values are compared below.
hivae_outputs = hivae_model.predict(test_df, true_missing_mask=test_missing_mask)
(
    hivae_test_data,
    hivae_reconstructed,
    hivae_decoded,
    hivae_latent_z,
    hivae_latent_s,
) = hivae_outputs

# Label the decoded values with the test-frame column names.
hivae_decoded_df = pd.DataFrame(hivae_decoded, columns=test_df.columns)
hivae_decoded_df.head()


## Train SUAVE with `behaviour="hivae"`

In [None]:
# Hyper-parameters for the SUAVE run, collected in one place.
suave_settings = dict(
    schema=schema,
    behaviour='hivae',
    latent_dim=32,
    hidden_dims=(256, 128),
    dropout=0.1,
    learning_rate=1e-3,
    batch_size=64,
    kl_warmup_epochs=5,
    random_state=42,
)
suave_model = SUAVE(**suave_settings)

# Same number of epochs and batch size as the reference run.
suave_model.fit(train_df, epochs=50, batch_size=64)

# only_missing=False is passed so the result can be compared on every cell,
# not just the originally missing ones.
suave_reconstruction = suave_model.impute(test_df, only_missing=False)
suave_reconstruction.head()


## Align reconstructions for comparison

In [None]:
def convert_to_numeric(
    suave_frame: pd.DataFrame,
    hivae_frame: pd.DataFrame,
    schema: Schema,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Convert both reconstructions to numeric frames with matching columns.

    HI-VAE columns are always passed through ``pd.to_numeric`` with
    ``errors='coerce'``. SUAVE columns are treated the same way unless the
    schema marks them as ``'cat'`` or ``'ordinal'``; those are replaced by
    their ``pd.Categorical`` codes.

    Missing SUAVE values in discrete columns become NaN. ``pd.Categorical``
    reports them with the sentinel code ``-1``, which would otherwise look
    like a genuine class next to the NaN-coerced HI-VAE values.

    Args:
        suave_frame: SUAVE reconstruction, one column per schema feature.
        hivae_frame: HI-VAE reconstruction, one column per schema feature.
        schema: Object exposing ``feature_names`` and item access that yields
            specs with a ``type`` attribute.

    Returns:
        ``(suave_df, hivae_df)`` with one numeric column per schema feature.
    """
    suave_numeric: dict[str, pd.Series] = {}
    hivae_numeric: dict[str, pd.Series] = {}
    for column in schema.feature_names:
        spec = schema[column]
        suave_series = suave_frame[column]
        hivae_numeric[column] = pd.to_numeric(hivae_frame[column], errors='coerce')
        if spec.type in {'cat', 'ordinal'}:
            # NOTE(review): the codes are positions among the labels present in
            # this SUAVE column only; confirm they line up with the HI-VAE
            # encoding when a class never appears in the reconstruction.
            codes = pd.Series(
                pd.Categorical(suave_series).codes,
                index=suave_series.index,
                dtype='float64',
            )
            # Replace the -1 "missing" sentinel with NaN.
            suave_numeric[column] = codes.where(codes >= 0)
        else:
            suave_numeric[column] = pd.to_numeric(suave_series, errors='coerce')
    suave_df = pd.DataFrame(suave_numeric)
    hivae_df = pd.DataFrame(hivae_numeric)
    return suave_df, hivae_df

# Numeric views of both reconstructions, used by the comparison cells below.
suave_numeric, hivae_numeric = convert_to_numeric(
    suave_frame=suave_reconstruction,
    hivae_frame=hivae_decoded_df,
    schema=schema,
)
suave_numeric.head()


## Quantitative comparison

In [None]:
# Column layout of the per-feature table. It is passed to the DataFrame
# constructor explicitly so that the frame can still be sorted when the schema
# has no continuous feature; a frame built from an empty list has no columns
# and sort_values('mae_all') would raise KeyError.
numeric_metric_columns = [
    'column',
    'type',
    'mae_all',
    'rmse_all',
    'max_abs_all',
    'mae_missing',
    'rmse_missing',
    'n_missing',
]

# Per-feature disagreement between the two models for continuous features,
# over all test rows and over the originally missing entries only.
numeric_metrics: list[dict[str, object]] = []
for column in schema.feature_names:
    spec = schema[column]
    if spec.type in {'cat', 'ordinal'}:
        continue
    suave_vals = suave_numeric[column]
    hivae_vals = hivae_numeric[column]
    diff = suave_vals - hivae_vals
    # The mask frame holds 0 where the test value was originally missing.
    missing_mask = test_missing_mask[column] == 0
    diff_missing = diff[missing_mask]
    # pandas reductions skip NaN and return NaN for empty or all-NaN input,
    # so no special-casing (and no "mean of empty slice" warning) is needed.
    numeric_metrics.append({
        'column': column,
        'type': spec.type,
        'mae_all': float(diff.abs().mean()),
        'rmse_all': float(np.sqrt((diff**2).mean())),
        'max_abs_all': float(diff.abs().max()),
        'mae_missing': float(diff_missing.abs().mean()),
        'rmse_missing': float(np.sqrt((diff_missing**2).mean())),
        'n_missing': int(missing_mask.sum()),
    })

numeric_metrics_df = pd.DataFrame(
    numeric_metrics,
    columns=numeric_metric_columns,
).sort_values('mae_all', ascending=False)
numeric_metrics_df.head()


In [None]:
# Column layout of the per-feature table. It is passed to the DataFrame
# constructor explicitly so that the frame can still be sorted when the schema
# has no discrete feature; a frame built from an empty list has no columns and
# sort_values('match_rate_all') would raise KeyError.
categorical_metric_columns = [
    'column',
    'type',
    'match_rate_all',
    'match_rate_missing',
    'n_valid',
    'n_missing',
    'n_valid_missing',
]

# Per-feature agreement rate between the two models for discrete features,
# over all comparable test rows and over the originally missing entries only.
categorical_metrics: list[dict[str, object]] = []
for column in schema.feature_names:
    spec = schema[column]
    if spec.type not in {'cat', 'ordinal'}:
        continue
    # NOTE(review): the codes are positions among the labels present in this
    # SUAVE column only; confirm they line up with the HI-VAE encoding when a
    # class never appears in the reconstruction.
    raw_codes = pd.Categorical(suave_reconstruction[column]).codes
    # pd.Categorical marks missing values with the code -1. Turn that sentinel
    # into NaN so the validity check below can actually exclude those rows.
    suave_codes = np.where(raw_codes < 0, np.nan, raw_codes)
    hivae_codes = pd.to_numeric(hivae_decoded_df[column], errors='coerce')
    valid_mask = (~pd.isna(suave_codes)) & (~pd.isna(hivae_codes))
    total_valid = int(valid_mask.sum())
    overall_match = float((suave_codes[valid_mask] == hivae_codes[valid_mask]).mean()) if total_valid else np.nan
    # The mask frame holds 0 where the test value was originally missing.
    missing_mask = (test_missing_mask[column] == 0)
    valid_missing = valid_mask & missing_mask
    valid_missing_total = int(valid_missing.sum())
    missing_match = float((suave_codes[valid_missing] == hivae_codes[valid_missing]).mean()) if valid_missing_total else np.nan
    categorical_metrics.append({
        'column': column,
        'type': spec.type,
        'match_rate_all': overall_match,
        'match_rate_missing': missing_match,
        'n_valid': total_valid,
        'n_missing': int(missing_mask.sum()),
        'n_valid_missing': valid_missing_total,
    })

categorical_metrics_df = pd.DataFrame(
    categorical_metrics,
    columns=categorical_metric_columns,
).sort_values('match_rate_all', ascending=True)
categorical_metrics_df.head()


In [None]:
def _column_nanmean(frame: pd.DataFrame, column: str) -> float:
    """Return the NaN-ignoring mean of ``frame[column]``, or NaN for an empty frame."""
    if frame.empty:
        return np.nan
    return float(np.nanmean(frame[column]))


# Summary key -> per-feature metric column it averages.
numeric_summary_sources = {
    'mean_mae_all': 'mae_all',
    'mean_rmse_all': 'rmse_all',
    'mean_mae_missing': 'mae_missing',
    'mean_rmse_missing': 'rmse_missing',
}
categorical_summary_sources = {
    'mean_match_all': 'match_rate_all',
    'mean_match_missing': 'match_rate_missing',
}

overall_numeric = {
    summary_key: _column_nanmean(numeric_metrics_df, metric_column)
    for summary_key, metric_column in numeric_summary_sources.items()
}

overall_categorical = {
    summary_key: _column_nanmean(categorical_metrics_df, metric_column)
    for summary_key, metric_column in categorical_summary_sources.items()
}

overall_numeric, overall_categorical


The summary above highlights how closely SUAVE (configured with `behaviour="hivae"`) tracks the original HI-VAE implementation on the same train/test split. Inspect the per-column metrics to spot any variables where the reconstructions diverge meaningfully.