# Testing cVAE as harmonization tool on real brain data

In [1]:
import os.path
%load_ext autoreload
%autoreload 2

In [1]:
from mecvae.lit import RealBrainMeasuresDataModule

data_module = RealBrainMeasuresDataModule('/Users/ssilvari/Downloads/fedcombat_synthetic_data/test/non_harmonized_data.csv', cat_cols=('Sex',), batch_size=6)

In [2]:
data_module.prepare_data()
data_module.setup()

Reading /Users/ssilvari/Downloads/fedcombat_synthetic_data/test/non_harmonized_data.csv...
Number of numerical variables: self.n_num_cols=4
Number of categorical variables: self.n_cat_cols=1
Number of batches: 10


In [3]:
data_module.X

Unnamed: 0,C(Sex)[Female],C(Sex)[Male],"standardize(Q(""Age""))","standardize(Q(""eTIV""))","standardize(Q(""y_0""))","standardize(Q(""y_1""))"
SUB_0_SITE_0,1.0,0.0,2.151577,0.929978,-0.690339,0.665926
SUB_1_SITE_0,1.0,0.0,1.186107,0.927382,-0.902737,-0.193873
SUB_2_SITE_0,0.0,1.0,1.257040,0.450961,1.148109,0.819459
SUB_3_SITE_0,1.0,0.0,0.862002,0.928355,-0.915231,-0.639821
SUB_4_SITE_0,1.0,0.0,1.404043,0.932197,-0.677242,-0.029836
...,...,...,...,...,...,...
SUB_491_SITE_9,1.0,0.0,1.467178,0.928355,-1.454019,-0.104461
SUB_492_SITE_9,1.0,0.0,1.195031,0.929719,-1.254568,-0.342068
SUB_493_SITE_9,1.0,0.0,1.641957,0.926115,-1.546160,0.114817
SUB_494_SITE_9,1.0,0.0,1.207868,0.927285,-1.342822,-0.310449


In [4]:
from mecvae.lit_models.cvae import LitFlexCVAE
import torch.nn as nn

# Import LitCVAE

model = LitFlexCVAE(data_dim=data_module.n_features,
                    conditioning_dim=data_module.n_batches,
                    # lr=1e-5, activation=nn.Tanh(),
                    hidden_dim=[256, 128],
                    z_dim=64,
                    optimizer='adam')

# Test model with batch
x, y = next(iter(data_module.train_dataloader()))
print(x.shape, y.shape)
x_hat, mu, log_var = model(x, y)

torch.Size([6, 6]) torch.Size([6, 10])


In [6]:
# Train model using pytorch lightning
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Define logger
logger = TensorBoardLogger(save_dir='/Users/ssilvari/PycharmProjects/Fed-MECVAE/lightning_logs_synthetic', name='cVAE (Flex)')

# Save best model using a model checkpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='Validation/loss',
    dirpath=logger.log_dir,
    filename='cVAE (Flex)-{epoch:02d}-{Validation-loss:.2f}',
    save_top_k=1,
    mode='min',
)

# Define early stopping callback
callbacks = [EarlyStopping(monitor='Validation/loss', patience=10), checkpoint_callback]

In [7]:
# Define trainer
trainer = Trainer(max_epochs=2000, callbacks=callbacks, logger=logger, enable_progress_bar=False, accelerator='cpu',
                  gradient_clip_val=1.0, accumulate_grad_batches=4, precision=16)
chekpoint_model = checkpoint_callback.best_model_path if os.path.exists(checkpoint_callback.best_model_path) else None

# Train model
trainer.fit(model, data_module, ckpt_path=chekpoint_model)

# Load best model
model = LitFlexCVAE.load_from_checkpoint(checkpoint_callback.best_model_path)

  rank_zero_warn(
  rank_zero_warn(
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name      | Type       | Params
-----------------------------------------
0 | encoder   | Sequential | 37.2 K
1 | fc_mean   | Linear     | 8.3 K 
2 | fc_logvar | Linear     | 8.3 K 
3 | decoder   | Sequential | 44.2 K
-----------------------------------------
97.9 K    Trainable params
0         Non-trainable params
97.9 K    Total params
0.392     Total estimated model params size (MB)


Reading /Users/ssilvari/Downloads/fedcombat_synthetic_data/test/non_harmonized_data.csv...
Number of numerical variables: self.n_num_cols=4
Number of categorical variables: self.n_cat_cols=1
Number of batches: 10


  rank_zero_warn(
  rank_zero_warn(


Epoch 00089: reducing learning rate of group 0 to 1.0000e-04.
Epoch 00109: reducing learning rate of group 0 to 1.0000e-05.
Epoch 00117: reducing learning rate of group 0 to 1.0000e-06.
Epoch 00125: reducing learning rate of group 0 to 1.0000e-07.


In [8]:
checkpoint_callback.best_model_path

'/Users/ssilvari/PycharmProjects/Fed-MECVAE/lightning_logs_synthetic/cVAE (Flex)/version_5/cVAE (Flex)-epoch=118-Validation-loss=0.00.ckpt'

In [5]:
# Predict the whole dataset removing the batch effect by sampling y from a categorical distribution
import torch

# Compute probabilities of y
y_probs = data_module.y.value_counts(normalize=True).sort_index().values
y_probs = torch.tensor(y_probs, dtype=torch.float32)

# Extract tensors from data_module
x = torch.tensor(data_module.X.values, dtype=torch.float32)

# Create a list of predictions (we'll sample 100 times)
x_hats = []
for _ in range(30):
    with torch.no_grad():
        # Sample y
        y = torch.multinomial(y_probs, len(data_module.y), replacement=True)
        # One hot encode
        y = torch.nn.functional.one_hot(y, num_classes=len(y_probs)) #* 0
        # print(x.shape, y.shape)

        # Predict x
        x_hat, mu, log_var = model(x, y)

        # Append to list
        x_hats.append(x_hat.detach())

# Compute x_hat mean from samples
x_hat_mean = torch.stack(x_hats).mean(dim=0)

In [6]:
# Create a dataframe with the predictions
import os
import re
import pandas as pd

x_hat_df = pd.DataFrame(x_hat_mean.numpy(), columns=data_module.X.columns, index=data_module.df.index)

# Remove covariate columns. Those containing 'Age', 'Sex', 'DX'
x_hat_df = x_hat_df.loc[:, ~x_hat_df.columns.str.contains('Age|Sex|DX')]

# Extract the phenotype name using a regex knwong that the column names are wrapped by something like this: standardize(Q("lh_inferiorparietal_thickness")) where lh_inferiorparietal_thickness is the phenotype name
x_hat_df.columns = [re.search(r'Q\("(.*)"\)', col).group(1) for col in x_hat_df.columns]

# Back transform from the standardization
x_hat_df_destd = x_hat_df * data_module.df[x_hat_df.columns].std() + data_module.df[x_hat_df.columns].mean()
print(x_hat_df_destd.head())

# Join the rest of the columns present in data_module.df that are not in x_hat_df_destd
x_hat_df_destd = x_hat_df_destd.join(data_module.df.loc[:, ~data_module.df.columns.isin(x_hat_df_destd.columns)])
print(x_hat_df_destd)

# save as csv
root_dir = os.path.dirname(data_module.csv_file)
harmonized_csv = root_dir + os.sep + 'harmonized_cVAE.csv'
x_hat_df_destd.to_csv(harmonized_csv)
print(f'Harmonized data saved to {harmonized_csv}')

                     eTIV       y_0        y_1
SUB_0_SITE_0  1816.601460  6.861273  10.393136
SUB_1_SITE_0  1808.650993  7.075057  10.412584
SUB_2_SITE_0  1820.727812  6.619480  10.283878
SUB_3_SITE_0  1806.689113  6.952393  10.096791
SUB_4_SITE_0  1817.098473  6.637516  10.073777
                       eTIV       y_0        y_1        Age     Sex    site
SUB_0_SITE_0    1816.601460  6.861273  10.393136  91.594612  Female  Site 0
SUB_1_SITE_0    1808.650993  7.075057  10.412584  67.314133  Female  Site 0
SUB_2_SITE_0    1820.727812  6.619480  10.283878  69.098038    Male  Site 0
SUB_3_SITE_0    1806.689113  6.952393  10.096791  59.163277  Female  Site 0
SUB_4_SITE_0    1817.098473  6.637516  10.073777  72.794983  Female  Site 0
...                     ...       ...        ...        ...     ...     ...
SUB_491_SITE_9  1798.276896  6.639803  10.423545  74.382774  Female  Site 9
SUB_492_SITE_9  1790.322302  7.088334  10.085405  67.538574  Female  Site 9
SUB_493_SITE_9  1797.781230  6.988

In [7]:
# Add to `methods_params.csv` in the same folder
methods_params_file = root_dir + os.sep + 'methods_params.csv'
methods_params = pd.read_csv(methods_params_file)

# Create new entry
cvae_mparams = pd.DataFrame([{
    'Method': 'cVAE',
    'data_file_path': harmonized_csv,
    'classification_results_path': root_dir + os.sep + 'benchmark cVAE.csv'
}])

methods_params = pd.concat([methods_params, cvae_mparams], ignore_index=True).drop_duplicates()

# save as csv ignore index
methods_params.to_csv(methods_params_file, index=False)

In [8]:
# Print done with date time
from datetime import datetime
print(f'Done at {datetime.now()}')

Done at 2023-07-07 11:13:50.440668
