TVAE Runner (SDV-based)

This script trains a TVAE model on single-cell (tabular) data using the SDV library. It mirrors the CTGAN workflow and supports the same three study scenarios used in the paper.

PBMC3K:

Trains TVAE on real train samples with 100 epochs.

Generates synthetic samples equal to TEST sample size.

Saves: pbmc3k_TVAE_100epc.pkl

HCA-BM10K (5-fold CV)

Integrated Pancreatic dataset (5-fold CV)

For EACH FOLD:

Fit TVAE on TRAIN ONLY.

Generate synthetic samples to reach the Q3 (75th percentile) count of the corresponding cell-type distribution
(per-class if --label-col is provided; labels are assigned accordingly).

Augment TRAIN with synthetic data. VALIDATION/TEST are NEVER touched.

Saves per-fold files under: {output}/folds/fold_{i}/

The dictionary saved for each fold contains: X_train_gen, y_train_gen

In [None]:
# Mount Google Drive at /content/gdrive so this Colab runtime can read and write files there.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
pip install sdv --quiet

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/185.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m185.6/185.6 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m140.1/140.1 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.0/14.0 MB[0m [31m125.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.7/52.7 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m73.8/73.8 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m198.1/198.1 kB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m89.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import sdv
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


In [None]:
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.single_table import CTGANSynthesizer
from sdv.single_table import TVAESynthesizer
from sdv.single_table import CopulaGANSynthesizer

from sdv.metadata import Metadata
from sdv.sampling import Condition as cd

PBMC3K

In [None]:
import pickle
# Load the PBMC3K train/test tables and their label vectors from local pickles.
with open("data/pbmc3k_train.pkl", "rb") as train_file:
    X_train = pickle.load(train_file)

with open("data/pbmc3k_test.pkl", "rb") as test_file:
    X_test = pickle.load(test_file)

with open("data/pbmc3k_y_train.pkl", "rb") as train_label_file:
    y_train = pickle.load(train_label_file)

with open("data/pbmc3k_y_test.pkl", "rb") as test_label_file:
    y_test = pickle.load(test_label_file)

5CV Pancreas

In [None]:
import pickle

# Collect the five pre-computed cross-validation folds (numbered 1..5) into one list.
all_folds = []

for fold_number in range(1, 6):
    fold_path = f"data/5CV/fold_skf_3000_{fold_number}.pkl"
    with open(fold_path, "rb") as fold_file:
        all_folds.append(pickle.load(fold_file))


In [None]:
# Sanity check: print the training feature/label shapes of every loaded fold.
for fold_number, fold_data in enumerate(all_folds, start=1):
    print(f"Fold {fold_number}:")
    train_features = fold_data['X_train']
    print("X_train shape:", train_features.shape)
    train_labels = fold_data['y_train']
    print("y_train shape:", train_labels.shape)

Fold 1:
X_train shape: (11337, 3000)
y_train shape: (11337,)
Fold 2:
X_train shape: (11337, 3000)
y_train shape: (11337,)
Fold 3:
X_train shape: (11338, 3000)
y_train shape: (11338,)
Fold 4:
X_train shape: (11338, 3000)
y_train shape: (11338,)
Fold 5:
X_train shape: (11338, 3000)
y_train shape: (11338,)


In [None]:
# Show how many training rows each label has in the first fold, plus the largest class size.
first_fold_labels = all_folds[0]['y_train']
unique_values, counts = np.unique(first_fold_labels, return_counts=True)
rows_per_label = dict(zip(unique_values, counts))
display(rows_per_label, counts.max())

{'PSC': np.int64(41),
 'acinar': np.int64(1089),
 'activated_stellate': np.int64(231),
 'alpha': np.int64(3876),
 'beta': np.int64(2975),
 'delta': np.int64(740),
 'ductal': np.int64(1353),
 'endothelial': np.int64(229),
 'epsilon': np.int64(16),
 'gamma': np.int64(350),
 'macrophage': np.int64(48),
 'mast': np.int64(23),
 'mesenchymal': np.int64(69),
 'pp': np.int64(150),
 'quiescent_stellate': np.int64(135),
 'schwann': np.int64(12)}

np.int64(3876)

HCA

In [None]:
import pickle

# Collect the five pre-computed HCA cross-validation folds (numbered 1..5) into one list.
# NOTE: this rebinds `all_folds`, replacing any folds loaded by an earlier cell.
all_folds = []

for fold_number in range(1, 6):
    fold_path = f"HCA/5CV/fold_skf_{fold_number}.pkl"
    with open(fold_path, "rb") as fold_file:
        all_folds.append(pickle.load(fold_file))



TVAE

In [None]:
from sdv.single_table import TVAESynthesizer

def tVAE_method(X_min, num_synthetic_samples, metadata, epochs=100):
    """Fit a TVAE synthesizer on ``X_min`` and draw synthetic rows from it.

    Args:
        X_min: Training table handed to ``TVAESynthesizer.fit``. The callers in
            this notebook pass a pandas DataFrame of expression columns.
        num_synthetic_samples: Number of synthetic rows to sample. NumPy
            integers are accepted; the value is converted to a plain ``int``.
        metadata: SDV metadata object describing the columns of ``X_min``.
        epochs: Number of training epochs. Defaults to 100, the value that was
            previously hard-coded, so existing calls behave as before.

    Returns:
        The object returned by ``TVAESynthesizer.sample`` for the requested
        number of rows.
    """
    tvae = TVAESynthesizer(metadata, epochs=epochs, verbose=True)
    tvae.fit(X_min)
    # Counts derived from np.unique/np.quantile are NumPy scalars; pass a
    # built-in int so the sampler never has to deal with NumPy integer types.
    return tvae.sample(int(num_synthetic_samples))

PBMC3K without class info

In [None]:
# Auto-detect the table metadata, then force every column to the 'numerical'
# sdtype before training, so no column keeps a different auto-detected type.
metadata = Metadata().detect_from_dataframe(X_train)
numeric_cols = list(X_train.columns)
for column in numeric_cols:
    metadata.update_column(column_name=column, sdtype='numerical')

# Generate as many synthetic rows as there are rows in the test table.
n_test_rows = int(X_test.shape[0])
synthetic_samples = tVAE_method(X_train, n_test_rows, metadata)

Loss: -8162.612: 100%|██████████| 100/100 [16:06<00:00,  9.66s/it]


In [None]:
# Inspect the (rows, columns) shape of the generated PBMC3K table.
synthetic_samples.shape

(264, 2000)

In [None]:
# Persist the PBMC3K synthetic table under ./results.
# The original line used the invalid string prefix `g"results"`, which is a
# SyntaxError; the path is now built with os.path.join instead.
results_dir = "results"
os.makedirs(results_dir, exist_ok=True)  # make sure the target directory exists before writing
with open(os.path.join(results_dir, 'pbmc3k_TVAE_100epc.pkl'), 'wb') as f:
    pickle.dump(synthetic_samples, f)

5CV

In [None]:

def TVAE_CV(folds=None, output_dir=None):
    """Augment the training split of every CV fold with TVAE-generated rows.

    For each fold, the per-class row counts of ``y_train`` are computed and the
    target size is set to their 75th percentile (Q3, ``method='nearest'``).
    Every class with fewer rows than the target gets its own TVAE model, fitted
    on that class's training rows only, and is topped up with synthetic rows
    until it reaches the target. Classes already at or above the target are
    copied unchanged. Only ``X_train`` / ``y_train`` are read from each fold;
    validation data is never used.

    The augmented data for each fold is pickled to
    ``{output_dir}/TVAE_skf_fold{k}.pkl`` (``k`` starts at 1).

    Args:
        folds: Iterable of fold dicts with keys ``'X_train'`` (pandas
            DataFrame) and ``'y_train'``. Defaults to the module-level
            ``all_folds``.
        output_dir: Directory for the per-fold pickles. Defaults to
            ``gdrivePath/Revision/HCA/5CV``, the location that was previously
            hard-coded; this default requires a module-level ``gdrivePath``.

    Returns:
        A list with one dict per fold, each holding ``'X_train_gen'`` (2-D
        array: per class, real rows followed by that class's synthetic rows)
        and ``'y_train_gen'`` (matching 1-D label array).
    """
    if folds is None:
        folds = all_folds
    if output_dir is None:
        output_dir = os.path.join(gdrivePath, "Revision", "HCA/5CV")
    os.makedirs(output_dir, exist_ok=True)

    gen_dict = []
    for k, fold in enumerate(folds, start=1):
        X_train = fold['X_train']
        y_train = fold['y_train']

        # Auto-detect metadata, then force every column to be numerical.
        metadata = Metadata.detect_from_dataframe(X_train)
        for col in X_train.columns:
            metadata.update_column(column_name=col, sdtype='numerical')

        # Per-class sizes and the Q3 target that minority classes are raised to.
        unique_values, counts = np.unique(y_train, return_counts=True)
        max_count = int(np.quantile(counts, 0.75, method='nearest'))

        # Accumulate pieces in fresh per-fold lists. The earlier version only
        # initialised its output arrays when the *first* class was a minority
        # class; otherwise it failed on fold 1 or silently appended to the
        # previous fold's arrays on later folds.
        X_parts = []
        y_parts = []
        for label, count in zip(unique_values, counts):
            count = int(count)
            X_class = X_train[y_train == label]

            # Real rows of this class always come first.
            X_parts.append(np.asarray(X_class))
            y_parts.append(np.full(count, label))

            if count < max_count:
                print("label, count, max_count", label, count, max_count)
                num_synthetic_samples = max_count - count
                synthetic_samples = tVAE_method(X_class, num_synthetic_samples, metadata)
                # Synthetic rows inherit the label of the class they were fitted on.
                X_parts.append(np.asarray(synthetic_samples))
                y_parts.append(np.full(num_synthetic_samples, label))

        syn = {
            'X_train_gen': np.vstack(X_parts),
            'y_train_gen': np.concatenate(y_parts),
        }
        gen_dict.append(syn)

        with open(os.path.join(output_dir, f'TVAE_skf_fold{k}.pkl'), 'wb') as f:
            pickle.dump(syn, f)
        print(f"TVAE results for fold {k} saved.")

    return gen_dict



In [None]:
# Run the per-fold TVAE augmentation over the currently loaded `all_folds`;
# the returned list holds one {'X_train_gen', 'y_train_gen'} dict per fold.
gen_dict = TVAE_CV()

label, count, max_count 1 476 670


Loss: -7264.120: 100%|██████████| 100/100 [04:50<00:00,  2.90s/it]


label, count, max_count 2 399 670


Loss: -14066.626: 100%|██████████| 100/100 [04:41<00:00,  2.82s/it]


label, count, max_count 6 275 670


Loss: -6507.839: 100%|██████████| 100/100 [04:44<00:00,  2.85s/it]


label, count, max_count 8 470 670


Loss: -11453.591: 100%|██████████| 100/100 [04:45<00:00,  2.85s/it]


label, count, max_count 10 202 670


Loss: -13321.816: 100%|██████████| 100/100 [04:40<00:00,  2.81s/it]


label, count, max_count 11 377 670


Loss: -12701.664: 100%|██████████| 100/100 [04:41<00:00,  2.82s/it]


label, count, max_count 12 91 670


Loss: -10518.201: 100%|██████████| 100/100 [04:39<00:00,  2.80s/it]


label, count, max_count 13 107 670


Loss: -7749.180: 100%|██████████| 100/100 [04:39<00:00,  2.80s/it]


label, count, max_count 14 80 670


Loss: -10400.285: 100%|██████████| 100/100 [04:40<00:00,  2.81s/it]


label, count, max_count 15 121 670


Loss: -17046.598: 100%|██████████| 100/100 [04:39<00:00,  2.80s/it]


label, count, max_count 16 59 670


Loss: -11962.193: 100%|██████████| 100/100 [04:40<00:00,  2.80s/it]


TVAE results for fold 1 saved.
label, count, max_count 1 476 671


Loss: -7272.464: 100%|██████████| 100/100 [04:45<00:00,  2.86s/it]


label, count, max_count 2 399 671


Loss: -14158.539: 100%|██████████| 100/100 [04:42<00:00,  2.82s/it]


label, count, max_count 6 274 671


Loss: -6500.896: 100%|██████████| 100/100 [04:43<00:00,  2.84s/it]


label, count, max_count 8 470 671


Loss: -11307.372: 100%|██████████| 100/100 [04:43<00:00,  2.83s/it]


label, count, max_count 10 202 671


Loss: -13486.954: 100%|██████████| 100/100 [04:42<00:00,  2.83s/it]


label, count, max_count 11 377 671


Loss: -12676.441: 100%|██████████| 100/100 [04:43<00:00,  2.83s/it]


label, count, max_count 12 91 671


Loss: -10677.445: 100%|██████████| 100/100 [04:40<00:00,  2.80s/it]


label, count, max_count 13 107 671


Loss: -7885.081: 100%|██████████| 100/100 [04:39<00:00,  2.79s/it]


label, count, max_count 14 80 671


Loss: -10707.492: 100%|██████████| 100/100 [04:42<00:00,  2.82s/it]


label, count, max_count 15 121 671


Loss: -17321.734: 100%|██████████| 100/100 [04:41<00:00,  2.82s/it]


label, count, max_count 16 59 671


Loss: -11960.564: 100%|██████████| 100/100 [04:42<00:00,  2.82s/it]


TVAE results for fold 2 saved.
label, count, max_count 1 476 671


Loss: -7236.382: 100%|██████████| 100/100 [04:44<00:00,  2.84s/it]


label, count, max_count 2 399 671


Loss: -14032.726: 100%|██████████| 100/100 [04:42<00:00,  2.82s/it]


label, count, max_count 6 274 671


Loss: -6853.242: 100%|██████████| 100/100 [04:44<00:00,  2.84s/it]


label, count, max_count 8 469 671


Loss: -11236.532: 100%|██████████| 100/100 [04:42<00:00,  2.82s/it]


label, count, max_count 10 203 671


Loss: -13479.576: 100%|██████████| 100/100 [04:40<00:00,  2.81s/it]


label, count, max_count 11 376 671


Loss: -12631.296: 100%|██████████| 100/100 [04:42<00:00,  2.83s/it]


label, count, max_count 12 92 671


Loss: -10510.754: 100%|██████████| 100/100 [04:39<00:00,  2.80s/it]


label, count, max_count 13 106 671


Loss: -7591.417: 100%|██████████| 100/100 [04:39<00:00,  2.79s/it]


label, count, max_count 14 80 671


Loss: -10602.875: 100%|██████████| 100/100 [04:39<00:00,  2.80s/it]


label, count, max_count 15 122 671


Loss: -17439.414: 100%|██████████| 100/100 [04:39<00:00,  2.79s/it]


label, count, max_count 16 58 671


Loss: -11955.771: 100%|██████████| 100/100 [04:39<00:00,  2.79s/it]


TVAE results for fold 3 saved.
label, count, max_count 1 476 670


Loss: -7235.811: 100%|██████████| 100/100 [04:42<00:00,  2.82s/it]


label, count, max_count 2 399 670


Loss: -14075.387: 100%|██████████| 100/100 [04:41<00:00,  2.81s/it]


label, count, max_count 6 274 670


Loss: -6655.973: 100%|██████████| 100/100 [04:42<00:00,  2.83s/it]


label, count, max_count 8 469 670


Loss: -11323.518: 100%|██████████| 100/100 [04:41<00:00,  2.81s/it]


label, count, max_count 10 203 670


Loss: -13335.234: 100%|██████████| 100/100 [04:41<00:00,  2.82s/it]


label, count, max_count 11 377 670


Loss: -12600.188: 100%|██████████| 100/100 [04:42<00:00,  2.83s/it]


label, count, max_count 12 91 670


Loss: -10653.310: 100%|██████████| 100/100 [04:38<00:00,  2.78s/it]


label, count, max_count 13 106 670


Loss: -7659.338: 100%|██████████| 100/100 [04:38<00:00,  2.79s/it]


label, count, max_count 14 80 670


Loss: -11013.092: 100%|██████████| 100/100 [04:39<00:00,  2.80s/it]


label, count, max_count 15 122 670


Loss: -17307.301: 100%|██████████| 100/100 [04:38<00:00,  2.79s/it]


label, count, max_count 16 58 670


Loss: -11825.111: 100%|██████████| 100/100 [04:38<00:00,  2.79s/it]


TVAE results for fold 4 saved.
label, count, max_count 1 476 670


Loss: -7234.114: 100%|██████████| 100/100 [04:42<00:00,  2.82s/it]


label, count, max_count 2 400 670


Loss: -13939.025: 100%|██████████| 100/100 [04:43<00:00,  2.84s/it]


label, count, max_count 6 275 670


Loss: -6568.258: 100%|██████████| 100/100 [04:52<00:00,  2.93s/it]


label, count, max_count 8 470 670


Loss: -11370.841: 100%|██████████| 100/100 [04:44<00:00,  2.84s/it]


label, count, max_count 10 202 670


Loss: -13477.799: 100%|██████████| 100/100 [04:44<00:00,  2.84s/it]


label, count, max_count 11 377 670
