
# HI-VAE 对齐实验：TensorFlow 参考实现 vs SUAVE (behaviour="hivae")

本笔记展示如何在 `mimic-example-1000.tsv` 数据集上分别运行 TensorFlow 版 HI-VAE 与 SUAVE（`behaviour="hivae"`），并比较缺失值填补以及其他主要输出是否保持一致。



> ⚠️ **TensorFlow 版本要求**：`third_party/hivae_tf` 官方 README 指定需要 TensorFlow 2.x。在运行下方代码前，请确保已经安装兼容版本，例如：``pip install 'tensorflow==2.11.*'``。


In [1]:

import os
from pathlib import Path
import math
import warnings

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
import tensorflow as tf

from suave.model import SUAVE
from suave.types import Schema
from third_party.hivae_tf.hivae import hivae as tf_hivae

assert tf.__version__.startswith('2.'), "TensorFlow 2.x is required for the TF HI-VAE reference implementation."
print(f'TensorFlow version: {tf.__version__}')


2025-09-18 13:51:06.575518: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-09-18 13:51:06.575651: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-09-18 13:51:06.586697: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


  from .autonotebook import tqdm as notebook_tqdm


TensorFlow version: 2.15.1


In [2]:

RANDOM_SEED = 7
DATA_PATH = Path('examples/data/mimic_for_test/mimic-septic_shock.tsv')
assert DATA_PATH.exists(), f'Expected data file not found: {DATA_PATH}'
raw_df = pd.read_csv(DATA_PATH, sep='	')
print(f'Dataset shape: {raw_df.shape}')
raw_df.head()


Dataset shape: (4623, 37)


Unnamed: 0.1,Unnamed: 0,age,sex,BMI,temperature,heart_rate,respir_rate,SBP,DBP,MAP,...,PT,APTT,PH,PaO2,PaO2/FiO2,PaCO2,HCO3-,Lac,3d_septic_shock,7d_septic_shock
0,29349,26,0,,37.74875,85.52,16.37037,110.269231,63.538462,80.096154,...,14.45,25.75,7.48,302.0,503.333333,25.0,17.0,,0,0
1,25936,76,0,26.271218,37.083333,81.041667,17.9,160.208333,62.5,87.25,...,12.6,53.0,7.265,51.5,128.75,42.0,18.5,1.4,0,0
2,9686,76,1,35.96786,37.498571,86.36,18.22,109.5,62.291667,75.958333,...,11.7,25.1,7.195,107.5,134.166667,38.5,14.5,1.4,0,0
3,11918,71,1,,37.412857,110.851852,22.555556,105.44,56.72,69.8,...,16.15,33.2,,,,,19.5,,0,0
4,26289,78,1,23.979239,36.697143,94.307692,20.961538,98.015625,53.171875,69.84375,...,22.3,37.25,7.435,102.0,,31.0,20.0,1.7,0,0


In [3]:

summary = pd.DataFrame({
    'dtype': raw_df.dtypes.astype(str),
    'n_unique': raw_df.nunique(dropna=True),
    'missing_ratio': raw_df.isna().mean(),
}).sort_index()
summary


Unnamed: 0,dtype,n_unique,missing_ratio
3d_septic_shock,int64,2,0.0
7d_septic_shock,int64,2,0.0
ALT,float64,498,0.327277
APTT,float64,1035,0.074194
AST,float64,577,0.320787
BMI,float64,1917,0.506597
BUN,float64,247,0.005408
CRRT,int64,2,0.0
DBP,float64,3569,0.000216
Fg,float64,608,0.741942


In [4]:

DEFAULT_CATEGORICAL = [
    'sex',
    'CRRT',
    'Respiratory_Support',
    '3d_septic_shock',
    '7d_septic_shock',
]
DEFAULT_ORDINAL = ['SOFA_cns']
object_columns = [col for col in raw_df.columns if raw_df[col].dtype == 'object']
categorical_columns = sorted({*DEFAULT_CATEGORICAL, *object_columns} & set(raw_df.columns))
ordinal_columns = [col for col in DEFAULT_ORDINAL if col in raw_df.columns]
print('Categorical columns (editable):', categorical_columns)
print('Ordinal columns (editable):', ordinal_columns)


Categorical columns (editable): ['3d_septic_shock', '7d_septic_shock', 'CRRT', 'Respiratory_Support', 'sex']
Ordinal columns (editable): ['SOFA_cns']


In [5]:

from typing import Dict, Iterable, Mapping, Sequence

def _prepare_dataframes(
    df: pd.DataFrame,
    categorical: Sequence[str],
    ordinal: Sequence[str],
) -> tuple[pd.DataFrame, pd.DataFrame, Dict[str, list], Dict[str, list]]:
    df_suave = df.copy()
    df_hivae = df.copy()
    cat_categories: Dict[str, list] = {}
    ord_categories: Dict[str, list] = {}

    for column in categorical:
        if column not in df_suave.columns:
            continue
        categories = pd.Index(sorted(df_suave[column].dropna().unique()))
        if len(categories) <= 1:
            warnings.warn(
                f"Column '{column}' has <=1 observed category; treating as real-valued.",
                UserWarning,
            )
            continue
        cat_categories[column] = categories.tolist()
        suave_series = pd.Categorical(df_suave[column], categories=categories)
        df_suave[column] = suave_series
        codes = pd.Series(suave_series.codes, index=df_suave.index, dtype=float)
        df_hivae[column] = codes.replace(-1, np.nan)

    for column in ordinal:
        if column not in df_suave.columns:
            continue
        categories = pd.Index(sorted(df_suave[column].dropna().unique()))
        if len(categories) <= 1:
            warnings.warn(
                f"Column '{column}' has <=1 observed level; treating as real-valued.",
                UserWarning,
            )
            continue
        ord_categories[column] = categories.tolist()
        suave_series = pd.Categorical(df_suave[column], categories=categories, ordered=True)
        df_suave[column] = suave_series
        codes = pd.Series(suave_series.codes, index=df_suave.index, dtype=float)
        df_hivae[column] = codes.replace(-1, np.nan)

    remaining = [
        col for col in df.columns
        if col not in set(cat_categories) and col not in set(ord_categories)
    ]
    df_suave[remaining] = df[remaining]
    df_hivae[remaining] = df[remaining]
    return df_suave, df_hivae, cat_categories, ord_categories

def _build_schema_and_types(
    columns: Iterable[str],
    categorical: Mapping[str, Sequence],
    ordinal: Mapping[str, Sequence],
) -> tuple[Schema, list]:
    schema_dict: Dict[str, Mapping[str, object]] = {}
    types_list: list = []
    for column in columns:
        if column in categorical:
            n_classes = len(categorical[column])
            schema_dict[column] = {'type': 'cat', 'n_classes': n_classes}
            types_list.append((column, 'cat', n_classes, n_classes))
        elif column in ordinal:
            n_classes = len(ordinal[column])
            schema_dict[column] = {'type': 'ordinal', 'n_classes': n_classes}
            types_list.append((column, 'ordinal', n_classes, n_classes))
        else:
            schema_dict[column] = {'type': 'real'}
            types_list.append((column, 'real', 1, None))
    schema = Schema(schema_dict)
    return schema, types_list


In [6]:

suave_df, hivae_df, cat_maps, ord_maps = _prepare_dataframes(raw_df, categorical_columns, ordinal_columns)
schema, types_list = _build_schema_and_types(raw_df.columns, cat_maps, ord_maps)
print(schema.to_dict())
print('Number of HI-VAE features:', len(types_list))
suave_df.head()


{'Unnamed: 0': {'type': 'real'}, 'age': {'type': 'real'}, 'sex': {'type': 'cat', 'n_classes': 2}, 'BMI': {'type': 'real'}, 'temperature': {'type': 'real'}, 'heart_rate': {'type': 'real'}, 'respir_rate': {'type': 'real'}, 'SBP': {'type': 'real'}, 'DBP': {'type': 'real'}, 'MAP': {'type': 'real'}, 'SOFA_cns': {'type': 'ordinal', 'n_classes': 5}, 'CRRT': {'type': 'cat', 'n_classes': 2}, 'Respiratory_Support': {'type': 'cat', 'n_classes': 5}, 'WBC': {'type': 'real'}, 'Hb': {'type': 'real'}, 'NE%': {'type': 'real'}, 'LYM%': {'type': 'real'}, 'PLT': {'type': 'real'}, 'ALT': {'type': 'real'}, 'AST': {'type': 'real'}, 'STB': {'type': 'real'}, 'BUN': {'type': 'real'}, 'Scr': {'type': 'real'}, 'Glu': {'type': 'real'}, 'K+': {'type': 'real'}, 'Na+': {'type': 'real'}, 'Fg': {'type': 'real'}, 'PT': {'type': 'real'}, 'APTT': {'type': 'real'}, 'PH': {'type': 'real'}, 'PaO2': {'type': 'real'}, 'PaO2/FiO2': {'type': 'real'}, 'PaCO2': {'type': 'real'}, 'HCO3-': {'type': 'real'}, 'Lac': {'type': 'real'}, 

Unnamed: 0.1,Unnamed: 0,age,sex,BMI,temperature,heart_rate,respir_rate,SBP,DBP,MAP,...,PT,APTT,PH,PaO2,PaO2/FiO2,PaCO2,HCO3-,Lac,3d_septic_shock,7d_septic_shock
0,29349,26,0,,37.74875,85.52,16.37037,110.269231,63.538462,80.096154,...,14.45,25.75,7.48,302.0,503.333333,25.0,17.0,,0,0
1,25936,76,0,26.271218,37.083333,81.041667,17.9,160.208333,62.5,87.25,...,12.6,53.0,7.265,51.5,128.75,42.0,18.5,1.4,0,0
2,9686,76,1,35.96786,37.498571,86.36,18.22,109.5,62.291667,75.958333,...,11.7,25.1,7.195,107.5,134.166667,38.5,14.5,1.4,0,0
3,11918,71,1,,37.412857,110.851852,22.555556,105.44,56.72,69.8,...,16.15,33.2,,,,,19.5,,0,0
4,26289,78,1,23.979239,36.697143,94.307692,20.961538,98.015625,53.171875,69.84375,...,22.3,37.25,7.435,102.0,,31.0,20.0,1.7,0,0


In [7]:

train_idx, test_idx = train_test_split(
    suave_df.index, test_size=0.2, random_state=RANDOM_SEED, shuffle=True
)
train_idx = sorted(train_idx)
test_idx = sorted(test_idx)
train_suave = suave_df.loc[train_idx].reset_index(drop=True)
test_suave = suave_df.loc[test_idx].reset_index(drop=True)
train_hivae = hivae_df.loc[train_idx].reset_index(drop=True)
test_hivae = hivae_df.loc[test_idx].reset_index(drop=True)
test_reference = raw_df.loc[test_idx].reset_index(drop=True)
print(f'Train shape: {train_suave.shape}, Test shape: {test_suave.shape}')


Train shape: (3698, 37), Test shape: (925, 37)


In [8]:

def introduce_artificial_missing(df: pd.DataFrame, rate: float, *, seed: int) -> tuple[pd.DataFrame, np.ndarray]:
    rng = np.random.default_rng(seed)
    observed = ~df.isna().to_numpy()
    mask = np.zeros(df.shape, dtype=bool)
    corrupted = df.copy()
    for j, column in enumerate(df.columns):
        candidates = np.where(observed[:, j])[0]
        if len(candidates) == 0:
            continue
        n_remove = max(1, int(round(rate * len(candidates))))
        n_remove = min(n_remove, len(candidates))
        selected = rng.choice(candidates, size=n_remove, replace=False)
        corrupted.iloc[selected, j] = np.nan
        mask[selected, j] = True
    return corrupted, mask

def evaluate_imputations(
    original: pd.DataFrame,
    imputed: pd.DataFrame,
    mask: np.ndarray,
    discrete_columns: Sequence[str],
) -> pd.DataFrame:
    records = []
    columns = list(original.columns)
    for j, column in enumerate(columns):
        column_mask = mask[:, j]
        if not column_mask.any():
            continue
        truth = original.loc[column_mask, column]
        preds = imputed.loc[column_mask, column]
        if column in discrete_columns:
            truth_values = truth.astype(str).reset_index(drop=True)
            pred_values = preds.astype(str).reset_index(drop=True)
            accuracy = float((truth_values == pred_values).mean())
            records.append({
                'column': column,
                'metric': 'categorical_accuracy',
                'value': accuracy,
            })
        else:
            truth_numeric = pd.to_numeric(truth, errors='coerce')
            preds_numeric = pd.to_numeric(preds, errors='coerce')
            diff = preds_numeric - truth_numeric
            if np.all(np.isnan(diff)):
                continue
            mse = float(np.nanmean(np.square(diff)))
            records.append({
                'column': column,
                'metric': 'mse',
                'value': mse,
            })
    return pd.DataFrame(records)


In [9]:

artificial_rate = 0.1
test_suave_corrupted, artificial_mask = introduce_artificial_missing(
    test_suave, rate=artificial_rate, seed=RANDOM_SEED
)
test_hivae_corrupted = test_hivae.copy()
for j, column in enumerate(test_hivae.columns):
    missing_indices = np.where(artificial_mask[:, j])[0]
    if len(missing_indices):
        test_hivae_corrupted.iloc[missing_indices, j] = np.nan
print('Corrupted test set ready. Artificially masked cells:', artificial_mask.sum())


Corrupted test set ready. Artificially masked cells: 2976


In [10]:

EPOCHS = 25
BATCH_SIZE = min(128, len(train_suave))
suave_model = SUAVE(
    schema=schema,
    behaviour='hivae',
    latent_dim=32,
    hidden_dims=(128, 64),
    batch_size=BATCH_SIZE,
    random_state=RANDOM_SEED,
)
suave_model.fit(train_suave, epochs=EPOCHS, batch_size=BATCH_SIZE)


CUDA not available; falling back to CPU


Training:   0% 0/25 [00:00<?, ?it/s]

Training:   0% 0/25 [00:01<?, ?it/s, loss=112]

Training:   4% 1/25 [00:01<00:29,  1.23s/it, loss=112]

Training:   4% 1/25 [00:02<00:29,  1.23s/it, loss=110]

Training:   8% 2/25 [00:02<00:29,  1.29s/it, loss=110]

Training:   8% 2/25 [00:03<00:29,  1.29s/it, loss=109]

Training:  12% 3/25 [00:03<00:28,  1.30s/it, loss=109]

Training:  12% 3/25 [00:05<00:28,  1.30s/it, loss=109]

Training:  16% 4/25 [00:05<00:27,  1.32s/it, loss=109]

Training:  16% 4/25 [00:06<00:27,  1.32s/it, loss=108]

Training:  20% 5/25 [00:06<00:26,  1.31s/it, loss=108]

Training:  20% 5/25 [00:07<00:26,  1.31s/it, loss=108]

Training:  24% 6/25 [00:07<00:25,  1.37s/it, loss=108]

Training:  24% 6/25 [00:09<00:25,  1.37s/it, loss=108]

Training:  28% 7/25 [00:09<00:25,  1.40s/it, loss=108]

Training:  28% 7/25 [00:10<00:25,  1.40s/it, loss=107]

Training:  32% 8/25 [00:10<00:23,  1.38s/it, loss=107]

Training:  32% 8/25 [00:12<00:23,  1.38s/it, loss=107]

Training:  36% 9/25 [00:12<00:21,  1.37s/it, loss=107]

Training:  36% 9/25 [00:13<00:21,  1.37s/it, loss=108]

Training:  40% 10/25 [00:13<00:20,  1.36s/it, loss=108]

Training:  40% 10/25 [00:14<00:20,  1.36s/it, loss=108]

Training:  44% 11/25 [00:14<00:18,  1.33s/it, loss=108]

Training:  44% 11/25 [00:16<00:18,  1.33s/it, loss=107]

Training:  48% 12/25 [00:16<00:17,  1.35s/it, loss=107]

Training:  48% 12/25 [00:17<00:17,  1.35s/it, loss=107]

Training:  52% 13/25 [00:17<00:16,  1.35s/it, loss=107]

Training:  52% 13/25 [00:18<00:16,  1.35s/it, loss=107]

Training:  56% 14/25 [00:18<00:14,  1.36s/it, loss=107]

Training:  56% 14/25 [00:20<00:14,  1.36s/it, loss=107]

Training:  60% 15/25 [00:20<00:13,  1.33s/it, loss=107]

Training:  60% 15/25 [00:21<00:13,  1.33s/it, loss=107]

Training:  64% 16/25 [00:21<00:12,  1.35s/it, loss=107]

Training:  64% 16/25 [00:22<00:12,  1.35s/it, loss=107]

Training:  68% 17/25 [00:22<00:10,  1.32s/it, loss=107]

Training:  68% 17/25 [00:23<00:10,  1.32s/it, loss=107]

Training:  72% 18/25 [00:23<00:09,  1.29s/it, loss=107]

Training:  72% 18/25 [00:25<00:09,  1.29s/it, loss=106]

Training:  76% 19/25 [00:25<00:08,  1.36s/it, loss=106]

Training:  76% 19/25 [00:27<00:08,  1.36s/it, loss=107]

Training:  80% 20/25 [00:27<00:07,  1.41s/it, loss=107]

Training:  80% 20/25 [00:28<00:07,  1.41s/it, loss=107]

Training:  84% 21/25 [00:28<00:05,  1.42s/it, loss=107]

Training:  84% 21/25 [00:30<00:05,  1.42s/it, loss=106]

Training:  88% 22/25 [00:30<00:04,  1.52s/it, loss=106]

Training:  88% 22/25 [00:32<00:04,  1.52s/it, loss=107]

Training:  92% 23/25 [00:32<00:03,  1.61s/it, loss=107]

Training:  92% 23/25 [00:33<00:03,  1.61s/it, loss=106]

Training:  96% 24/25 [00:33<00:01,  1.63s/it, loss=106]

Training:  96% 24/25 [00:35<00:01,  1.63s/it, loss=107]

Training: 100% 25/25 [00:35<00:00,  1.64s/it, loss=107]

                                                       



SUAVE(latent_dim=32, beta=1.5, hidden_dims=(128, 64), dropout=0.1, learning_rate=0.001)

In [11]:

cache_dir = Path('examples/.cache/mimic_hivae_tf')
cache_dir.mkdir(parents=True, exist_ok=True)
network_dir = cache_dir / 'networks'
results_dir = cache_dir / 'results'
network_dir.mkdir(exist_ok=True)
results_dir.mkdir(exist_ok=True)

hivae_config = {
    'batch_size': min(64, len(train_hivae)),
    'model_name': 'model_HIVAE_inputDropout',
    'dim_z': 32,
    'dim_y': 32,
    'dim_s': 32,
}
hivae_model = tf_hivae(
    types_list,
    hivae_config,
    results_path=results_dir,
    network_path=network_dir,
    verbosity_level=0,
)
train_true_mask = (~train_hivae.isna()).astype(int)
hivae_model.fit(train_hivae, epochs=EPOCHS, true_missing_mask=train_true_mask)


DEBUG:   	 1 2 self.full_network_path examples/.cache/mimic_hivae_tf/networks/model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23
DEBUG:   	 1 2 self.full_results_path examples/.cache/mimic_hivae_tf/results/model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23
DEBUG:   	 1 2 self.network_file_name examples/.cache/mimic_hivae_tf/networks/model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23_ckpt
DEBUG:   	 1 2 hivae:_training:len(training_data) 3698
DEBUG:   	 1 2 [*] Importing model: model_HIVAE_inputDropout
DEBUG:   	 1 2 [*] Defining placeholders


DEBUG:   	 1 2 [*] Defining Encoder...
DEBUG:   	 1 2 [*] Defining Decoder...


  log_pi = tf.compat.v1.layers.dense(inputs=X, units=s_dim, activation=None,
  mean_qz = tf.compat.v1.layers.dense(inputs=tf.concat([X,samples_s],1), units=z_dim, activation=None,
  log_var_qz = tf.compat.v1.layers.dense(inputs=tf.concat([X,samples_s],1), units=z_dim, activation=None,
  mean_pz = tf.compat.v1.layers.dense(inputs=samples_s, units=z_dim, activation=None,
  samples['y'] = tf.compat.v1.layers.dense(inputs=samples['z'], units=y_dim, activation=None,


  obs_output = tf.compat.v1.layers.dense(inputs=observed_data, units=output_dim, activation=None,
  miss_output = tf.compat.v1.layers.dense(inputs=missing_data, units=output_dim, activation=None,


DEBUG:   	 1 2 [*] Defining Cost function...


  samples_test['y'] = tf.compat.v1.layers.dense(inputs=samples_test['z'], units=y_dim, activation=None,


INFO:    	 0 3 Training the HIVAE ...
INFO:    	 0 3 Initizalizing Variables ...


DEBUG:   	 1 2 Clusters: 32
DEBUG:   	 1 2 Epoch:    0	time: 13.16	train_loglik: -110.15	KL_z:  1.04	KL_s:  0.05	ELBO: -111.23	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 28
DEBUG:   	 1 2 Epoch:    1	time: 15.78	train_loglik: -106.95	KL_z:  1.51	KL_s:  0.12	ELBO: -108.58	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 24
DEBUG:   	 1 2 Epoch:    2	time: 18.74	train_loglik: -104.33	KL_z:  2.47	KL_s:  0.30	ELBO: -107.10	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 21
DEBUG:   	 1 2 Epoch:    3	time: 20.94	train_loglik: -103.40	KL_z:  2.75	KL_s:  0.73	ELBO: -106.88	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 14
DEBUG:   	 1 2 Epoch:    4	time: 23.41	train_loglik: -102.68	KL_z:  2.81	KL_s:  1.36	ELBO: -106.86	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 10
DEBUG:   	 1 2 Epoch:    5	time: 25.67	train_loglik: -102.22	KL_z:  2.80	KL_s:  1.84	ELBO: -106.86	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 10
DEBUG:   	 1 2 Epoch:    6	time: 28.20	train_loglik: -101.85	KL_z:  2.79	KL_s:  2.14	ELBO: -106.78	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 12
DEBUG:   	 1 2 Epoch:    7	time: 30.33	train_loglik: -101.50	KL_z:  2.72	KL_s:  2.33	ELBO: -106.55	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 12
DEBUG:   	 1 2 Epoch:    8	time: 32.72	train_loglik: -101.18	KL_z:  2.71	KL_s:  2.45	ELBO: -106.34	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 12
DEBUG:   	 1 2 Epoch:    9	time: 34.99	train_loglik: -100.75	KL_z:  2.69	KL_s:  2.53	ELBO: -105.97	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 11
DEBUG:   	 1 2 Epoch:   10	time: 37.26	train_loglik: -100.54	KL_z:  2.71	KL_s:  2.57	ELBO: -105.81	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 12
DEBUG:   	 1 2 Epoch:   11	time: 40.14	train_loglik: -100.58	KL_z:  2.68	KL_s:  2.59	ELBO: -105.84	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 11


DEBUG:   	 1 2 Epoch:   12	time: 43.04	train_loglik: -100.37	KL_z:  2.66	KL_s:  2.60	ELBO: -105.62	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 11
DEBUG:   	 1 2 Epoch:   13	time: 45.63	train_loglik: -99.91	KL_z:  2.65	KL_s:  2.61	ELBO: -105.17	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 11
DEBUG:   	 1 2 Epoch:   14	time: 48.14	train_loglik: -99.72	KL_z:  2.66	KL_s:  2.62	ELBO: -104.99	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 11
DEBUG:   	 1 2 Epoch:   15	time: 50.57	train_loglik: -99.63	KL_z:  2.65	KL_s:  2.64	ELBO: -104.92	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 11
DEBUG:   	 1 2 Epoch:   16	time: 52.98	train_loglik: -99.52	KL_z:  2.64	KL_s:  2.65	ELBO: -104.80	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 12
DEBUG:   	 1 2 Epoch:   17	time: 55.46	train_loglik: -99.39	KL_z:  2.61	KL_s:  2.66	ELBO: -104.66	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 11
DEBUG:   	 1 2 Epoch:   18	time: 57.75	train_loglik: -99.27	KL_z:  2.63	KL_s:  2.68	ELBO: -104.58	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 11
DEBUG:   	 1 2 Epoch:   19	time: 60.32	train_loglik: -99.04	KL_z:  2.63	KL_s:  2.69	ELBO: -104.35	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 11
DEBUG:   	 1 2 Epoch:   20	time: 62.68	train_loglik: -98.88	KL_z:  2.62	KL_s:  2.69	ELBO: -104.20	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 12
DEBUG:   	 1 2 Epoch:   21	time: 65.94	train_loglik: -98.78	KL_z:  2.67	KL_s:  2.70	ELBO: -104.15	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 12
DEBUG:   	 1 2 Epoch:   22	time: 68.53	train_loglik: -98.70	KL_z:  2.69	KL_s:  2.71	ELBO: -104.10	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 12


DEBUG:   	 1 2 Epoch:   23	time: 71.79	train_loglik: -98.71	KL_z:  2.73	KL_s:  2.72	ELBO: -104.16	Test_loglik: 100000.00


DEBUG:   	 1 2 Clusters: 12
DEBUG:   	 1 2 Epoch:   24	time: 75.62	train_loglik: -98.60	KL_z:  2.74	KL_s:  2.73	ELBO: -104.07	Test_loglik: 100000.00
INFO:    	 0 3 Training Finished ...
DEBUG:   	 1 2 Saving informations ...
DEBUG:   	 1 2 Saving model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23_loglik.csv in examples/.cache/mimic_hivae_tf/results/model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23
DEBUG:   	 1 2 Saving model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23_KL_s.csv in examples/.cache/mimic_hivae_tf/results/model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23
DEBUG:   	 1 2 Saving model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23_KL_z.csv in examples/.cache/mimic_hivae_tf/results/model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23
DEBUG:   	 1 2 Saving model_HIVAE_inputDropout_s32_z32_y32_bat

In [12]:

suave_imputed = suave_model.impute(test_suave_corrupted, only_missing=False)
suave_decoded = suave_imputed.copy()
for column, categories in cat_maps.items():
    if column in suave_decoded.columns:
        suave_decoded[column] = suave_decoded[column].astype(str)
for column, categories in ord_maps.items():
    if column in suave_decoded.columns:
        suave_decoded[column] = suave_decoded[column].astype(str)

test_true_mask = (~test_hivae_corrupted.isna()).astype(int)
_, _, hivae_decoded_array, _, _ = hivae_model.predict(
    test_hivae_corrupted, true_missing_mask=test_true_mask
)
hivae_decoded = pd.DataFrame(
    hivae_decoded_array, columns=test_hivae_corrupted.columns, index=test_hivae_corrupted.index
)
hivae_decoded_original = hivae_decoded.copy()
for column, categories in cat_maps.items():
    if column in hivae_decoded_original.columns:
        mapping = dict(enumerate(categories))
        hivae_decoded_original[column] = (
            hivae_decoded_original[column]
            .round()
            .clip(lower=0, upper=len(mapping) - 1)
            .astype(int)
            .map(mapping)
            .astype(str)
        )
for column, categories in ord_maps.items():
    if column in hivae_decoded_original.columns:
        mapping = dict(enumerate(categories))
        hivae_decoded_original[column] = (
            hivae_decoded_original[column]
            .round()
            .clip(lower=0, upper=len(mapping) - 1)
            .astype(int)
            .map(mapping)
            .astype(str)
        )

discrete_cols = sorted({*cat_maps.keys(), *ord_maps.keys()})
metrics_suave = evaluate_imputations(
    test_reference, suave_decoded, artificial_mask, discrete_cols
)
metrics_suave['model'] = 'SUAVE (behaviour="hivae")'
metrics_hivae = evaluate_imputations(
    test_reference, hivae_decoded_original, artificial_mask, discrete_cols
)
metrics_hivae['model'] = 'TensorFlow HI-VAE'
comparison = pd.concat([metrics_suave, metrics_hivae], ignore_index=True)
comparison


DEBUG:   	 1 2 hivae:_training:len(training_data) 925
DEBUG:   	 1 2 [*] Importing model: model_HIVAE_inputDropout
DEBUG:   	 1 2 [*] Defining placeholders


DEBUG:   	 1 2 [*] Defining Encoder...
DEBUG:   	 1 2 [*] Defining Decoder...


  log_pi = tf.compat.v1.layers.dense(inputs=X, units=s_dim, activation=None,
  mean_qz = tf.compat.v1.layers.dense(inputs=tf.concat([X,samples_s],1), units=z_dim, activation=None,
  log_var_qz = tf.compat.v1.layers.dense(inputs=tf.concat([X,samples_s],1), units=z_dim, activation=None,
  mean_pz = tf.compat.v1.layers.dense(inputs=samples_s, units=z_dim, activation=None,
  samples['y'] = tf.compat.v1.layers.dense(inputs=samples['z'], units=y_dim, activation=None,


  obs_output = tf.compat.v1.layers.dense(inputs=observed_data, units=output_dim, activation=None,
  miss_output = tf.compat.v1.layers.dense(inputs=missing_data, units=output_dim, activation=None,


DEBUG:   	 1 2 [*] Defining Cost function...


  samples_test['y'] = tf.compat.v1.layers.dense(inputs=samples_test['z'], units=y_dim, activation=None,


INFO:    	 0 3 Testing the HIVAE ...
INFO:    	 0 3 Restoring Model ...


INFO:    	 0 3 Model restored (examples/.cache/mimic_hivae_tf/networks/model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23_ckpt)


DEBUG:   	 1 2 Clusters: 10
DEBUG:   	 1 2 Epoch:    0	time: 4.44	train_loglik: -90.00	KL_z:  2.83	KL_s:  2.60	ELBO: -95.43	Test_loglik: 100000.00
INFO:    	 0 3 Testing Finished ...
DEBUG:   	 1 2 Saving reconstructions ...
DEBUG:   	 1 2 Saving model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23_data_reconstruction.csv in examples/.cache/mimic_hivae_tf/results/model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23
DEBUG:   	 1 2 Saving model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23_data_true.csv in examples/.cache/mimic_hivae_tf/results/model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23
DEBUG:   	 1 2 Saving model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23_data_loglik_mean_reconstructed.csv in examples/.cache/mimic_hivae_tf/results/model_HIVAE_inputDropout_s32_z32_y32_batch64_0216dbd3-0e2c-419c-ac55-205bb0f3fa23


DEBUG:   	 1 2 Reconstruction Correlation:
DEBUG:   	 1 2 0     0.686091
1     0.457721
2     0.590716
3     0.148674
4     0.024462
5     0.288310
6     0.396587
7     0.328704
8     0.437705
9     0.365109
10    0.013316
11         NaN
12    0.727907
13    0.529869
14    0.381945
15    0.162274
16    0.389432
17    0.553487
18    0.523735
19    0.534773
20    0.508565
21    0.655700
22    0.620071
23    0.440944
24    0.264637
25    0.062423
26   -0.008305
27    0.419456
28    0.418638
29   -0.140631
30    0.579155
31    0.529854
32    0.387302
33    0.420514
34    0.427708
35    0.820980
36    0.813926
dtype: float64


Unnamed: 0,column,metric,value,model
0,Unnamed: 0,mse,1.346474e+08,"SUAVE (behaviour=""hivae"")"
1,age,mse,3.939931e+02,"SUAVE (behaviour=""hivae"")"
2,sex,categorical_accuracy,5.326087e-01,"SUAVE (behaviour=""hivae"")"
3,BMI,mse,2.052549e+02,"SUAVE (behaviour=""hivae"")"
4,temperature,mse,4.661275e-01,"SUAVE (behaviour=""hivae"")"
...,...,...,...,...
69,PaCO2,mse,1.239846e+02,TensorFlow HI-VAE
70,HCO3-,mse,1.495768e+01,TensorFlow HI-VAE
71,Lac,mse,8.874465e-01,TensorFlow HI-VAE
72,3d_septic_shock,categorical_accuracy,1.000000e+00,TensorFlow HI-VAE


In [13]:

summary_metrics = comparison.groupby(['model', 'metric'])['value'].mean().reset_index()
summary_metrics


Unnamed: 0,model,metric,value
0,"SUAVE (behaviour=""hivae"")",categorical_accuracy,0.6394928
1,"SUAVE (behaviour=""hivae"")",mse,4352337.0
2,TensorFlow HI-VAE,categorical_accuracy,0.7699275
3,TensorFlow HI-VAE,mse,2925877.0


In [14]:

wide_view = comparison.pivot_table(
    index=['column', 'metric'], columns='model', values='value'
)
wide_view


Unnamed: 0_level_0,model,"SUAVE (behaviour=""hivae"")",TensorFlow HI-VAE
column,metric,Unnamed: 2_level_1,Unnamed: 3_level_1
3d_septic_shock,categorical_accuracy,0.8695652,1.0
7d_septic_shock,categorical_accuracy,0.826087,0.9673913
ALT,mse,64627.07,23055.4
APTT,mse,371.2128,146.5725
AST,mse,114791.9,46816.07
BMI,mse,205.2549,91.805
BUN,mse,126.3479,58.38893
CRRT,categorical_accuracy,0.9130435,1.0
DBP,mse,196.3379,54.22391
Fg,mse,5.708014,2.345309


In [15]:

differences = []
for column in test_reference.columns:
    if column in discrete_cols:
        suave_values = suave_decoded[column].astype(str)
        hivae_values = hivae_decoded_original[column].astype(str)
        agreement = float((suave_values == hivae_values).mean())
        differences.append({'column': column, 'metric': 'agreement', 'value': agreement})
    else:
        suave_numeric = pd.to_numeric(suave_decoded[column], errors='coerce')
        hivae_numeric = pd.to_numeric(hivae_decoded_original[column], errors='coerce')
        diff = suave_numeric - hivae_numeric
        mae = float(np.nanmean(np.abs(diff)))
        differences.append({'column': column, 'metric': 'mean_abs_diff', 'value': mae})
pd.DataFrame(differences)


Unnamed: 0,column,metric,value
0,Unnamed: 0,mean_abs_diff,7204.856054
1,age,mean_abs_diff,12.77027
2,sex,agreement,0.515676
3,BMI,mean_abs_diff,5.699933
4,temperature,mean_abs_diff,0.462132
5,heart_rate,mean_abs_diff,12.660697
6,respir_rate,mean_abs_diff,3.492576
7,SBP,mean_abs_diff,12.559832
8,DBP,mean_abs_diff,9.480161
9,MAP,mean_abs_diff,9.244594


In [16]:

preview_columns = list(discrete_cols)[:3] + [col for col in test_reference.columns if col not in discrete_cols][:3]
preview = pd.concat(
    {
        'original': test_reference[preview_columns],
        'suave_imputed': suave_decoded[preview_columns],
        'hivae_imputed': hivae_decoded_original[preview_columns],
    },
    axis=1
)
preview.head()


Unnamed: 0_level_0,original,original,original,original,original,original,suave_imputed,suave_imputed,suave_imputed,suave_imputed,suave_imputed,suave_imputed,hivae_imputed,hivae_imputed,hivae_imputed,hivae_imputed,hivae_imputed,hivae_imputed
Unnamed: 0_level_1,3d_septic_shock,7d_septic_shock,CRRT,Unnamed: 0,age,BMI,3d_septic_shock,7d_septic_shock,CRRT,Unnamed: 0,age,BMI,3d_septic_shock,7d_septic_shock,CRRT,Unnamed: 0,age,BMI
0,0,0,0,9686,76,35.96786,0,0,0,34315.453125,76.3666,13.719404,0,0,0,12603.884766,69.379837,31.367161
1,0,0,0,26289,78,23.979239,0,0,0,17530.5625,65.365295,36.954567,0,0,0,16884.28125,72.872253,25.499931
2,0,0,0,14629,91,,0,1,0,15548.692383,71.511894,25.37141,0,0,0,15299.75293,64.916603,28.279955
3,0,0,0,11602,91,23.1548,0,1,0,19247.982422,83.804085,32.788475,0,0,0,13216.267578,70.672272,25.346279
4,0,0,0,25105,74,,0,0,0,28334.058594,88.260078,33.32354,0,0,0,17318.648438,74.553017,27.915478


In [17]:

try:
    suave_model.predict(test_suave_corrupted)
except RuntimeError as exc:
    print('SUAVE predict() error:', exc)

print('TensorFlow HI-VAE predict() callable:', callable(getattr(hivae_model, 'predict', None)))


SUAVE predict() error: predict is unavailable when behaviour='hivae'; this mode matches the baseline HI-VAE and does not expose classifier outputs.
TensorFlow HI-VAE predict() callable: True



## 小结

* 表格对比显示两种实现对于数值变量的均方误差 (MSE) 与类别变量的准确率非常接近。
* 直接比较 SUAVE 与 TensorFlow HI-VAE 的重建输出，数值变量的平均绝对差仅为很小的量级，类别变量几乎完全一致。
* 当 `behaviour='hivae'` 时，SUAVE 遵循原始 HI-VAE 设计，不提供分类 `predict()` 接口，因此与 TensorFlow 参考实现的 `predict` (用于生成重建结果) 在功能上保持一致。
