GC Runner (SDV-based)

This notebook trains a Gaussian Copula (GC) synthesizer on single-cell (tabular) data using the SDV library. It mirrors the CTGAN workflow and supports the same three study scenarios used in the paper.

PBMC3K:

Trains a Gaussian Copula (GC) synthesizer on the real training samples (without class labels).

Generates as many synthetic samples as there are TEST samples.

Saves: GC_pbmc3k.pkl

HCA-BM10K (5-fold CV)

Integrated Pancreatic dataset (5-fold CV)

For EACH FOLD:

Fit the Gaussian Copula on TRAIN ONLY (one model per cell-type class).

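Compute the Q3 (75th percentile, `method='nearest'`) of the per-class training counts. For every cell type below Q3, generate synthetic samples until that class reaches the Q3 count; classes at or above Q3 are kept as-is.
Synthetic samples are labelled with the class they were generated for.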

Augment TRAIN with synthetic data. VALIDATION/TEST are NEVER touched.

Saves one pickle per fold: HCA/5CV/GC_skf_fold{i}.pkl

Each per-fold dictionary contains: X_train_gen (augmented training features) and y_train_gen (matching labels).

In [None]:
# Mount Google Drive into the Colab runtime at /content/gdrive (Colab-only).
# NOTE(review): later cells open relative paths (data/, HCA/, results/);
# presumably the working directory points inside the mounted Drive — confirm.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Install the SDV library (notebook shell syntax, not plain Python).
pip install sdv --quiet

(pip progress output: downloaded packages of 185.6 kB, 140.1 kB and 14.0 MB; remaining progress-bar output truncated)

In [None]:
import sdv
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


In [None]:
import sdv
# Record the installed SDV version for reproducibility (the saved run shows 1.25.0).
print(sdv.__version__)

1.25.0


In [None]:
from sdv.single_table import GaussianCopulaSynthesizer
from sdv.single_table import CTGANSynthesizer
from sdv.single_table import TVAESynthesizer
from sdv.single_table import CopulaGANSynthesizer

from sdv.metadata import Metadata

PBMC3K

In [None]:
import pickle


def _read_pickle(path):
    """Return the object stored in the pickle file at *path*."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


# PBMC3K: pre-split train/test feature tables and their label vectors.
X_train = _read_pickle("data/pbmc3k_train.pkl")
X_test = _read_pickle("data/pbmc3k_test.pkl")
y_train = _read_pickle("data/pbmc3k_y_train.pkl")
y_test = _read_pickle("data/pbmc3k_y_test.pkl")

5CV Pancreas

In [None]:
import pickle


def _load_fold(path):
    """Return the fold dictionary pickled at *path*."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Integrated Pancreas: five stratified CV folds (files numbered 1..5).
all_folds = [_load_fold(f"data/5CV/fold_skf_3000_{fold}.pkl") for fold in range(1, 6)]



In [None]:
# Sanity-check the training split size of every loaded fold.
for fold_no, fold_split in enumerate(all_folds, start=1):
    print(f"Fold {fold_no}:")
    print("X_train shape:", fold_split['X_train'].shape)
    print("y_train shape:", fold_split['y_train'].shape)

Fold 1:
X_train shape: (11337, 3000)
y_train shape: (11337,)
Fold 2:
X_train shape: (11337, 3000)
y_train shape: (11337,)
Fold 3:
X_train shape: (11338, 3000)
y_train shape: (11338,)
Fold 4:
X_train shape: (11338, 3000)
y_train shape: (11338,)
Fold 5:
X_train shape: (11338, 3000)
y_train shape: (11338,)


In [None]:
# Per-class training counts for all_folds[1] — the SECOND fold (0-based index) —
# plus the size of the largest class.
unique_values, counts = np.unique(all_folds[1]['y_train'], return_counts=True)
class_distribution = dict(zip(unique_values, counts))
display(class_distribution, counts.max())

{'PSC': np.int64(44),
 'acinar': np.int64(1084),
 'activated_stellate': np.int64(224),
 'alpha': np.int64(3911),
 'beta': np.int64(2954),
 'delta': np.int64(744),
 'ductal': np.int64(1392),
 'endothelial': np.int64(228),
 'epsilon': np.int64(17),
 'gamma': np.int64(334),
 'macrophage': np.int64(43),
 'mast': np.int64(17),
 'mesenchymal': np.int64(61),
 'pp': np.int64(141),
 'quiescent_stellate': np.int64(134),
 'schwann': np.int64(9)}

np.int64(3911)

In [None]:
# Quartiles of the per-class counts; method='nearest' makes each quartile an
# actually observed class count rather than an interpolated value.
quartile_levels = [0.25, 0.5, 0.75]
Q1, Q2, Q3 = np.quantile(counts, quartile_levels, axis=0, method='nearest')
print("Q1", Q1, "\nQ2", Q2, "\nQ3", Q3)


HCA

In [None]:
import pickle


def _load_fold(path):
    """Return the fold dictionary pickled at *path*."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


# HCA: five stratified CV folds (files numbered 1..5). This rebinds all_folds,
# replacing whatever folds were loaded before.
all_folds = [_load_fold(f"HCA/5CV/fold_skf_{fold}.pkl") for fold in range(1, 6)]


In [None]:
# Sanity-check the training split size of every loaded fold.
for fold_no, fold_split in enumerate(all_folds, start=1):
    print(f"Fold {fold_no}:")
    print("X_train shape:", fold_split['X_train'].shape)
    print("y_train shape:", fold_split['y_train'].shape)

Fold 1:
X_train shape: (8000, 3000)
y_train shape: (8000,)
Fold 2:
X_train shape: (8000, 3000)
y_train shape: (8000,)
Fold 3:
X_train shape: (8000, 3000)
y_train shape: (8000,)
Fold 4:
X_train shape: (8000, 3000)
y_train shape: (8000,)
Fold 5:
X_train shape: (8000, 3000)
y_train shape: (8000,)


Gaussian Copula

In [None]:
def generate_GC(metadata, num_synthetic_samples, X_min):
    """Fit a fresh Gaussian Copula on ``X_min`` and sample new rows from it.

    Args:
        metadata: SDV ``Metadata`` describing the columns of ``X_min``.
        num_synthetic_samples: Number of rows to draw from the fitted model.
        X_min: Table the copula is fitted on (in the cross-validation setting,
            the rows of a single class).

    Returns:
        The sampled table returned by ``GaussianCopulaSynthesizer.sample``.
    """
    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.fit(X_min)
    return synthesizer.sample(num_synthetic_samples)

Without class labels (PBMC3K)

In [None]:
# Build metadata from the PBMC3K training table, then force every column to
# sdtype 'numerical', overriding whatever was auto-detected.
metadata = Metadata.detect_from_dataframe(X_train)
numeric_cols = list(X_train.columns)
for col in numeric_cols:
    metadata.update_column(column_name=col, sdtype='numerical')

# One unconditional model on all training rows; draw as many rows as the test set has.
synthetic_samples = generate_GC(metadata, int(X_test.shape[0]), X_train)



In [None]:
# Inspect the sampled table: rows should equal X_test.shape[0] as requested above
# (the saved run shows (264, 2000)).
synthetic_samples.shape

(264, 2000)

In [None]:
# Persist the unconditional PBMC3K synthetic table to results/GC_pbmc3k.pkl.
with open(os.path.join("results", 'GC_pbmc3k.pkl'), 'wb') as f:
    pickle.dump(synthetic_samples, f)

Cross-validation with class labels (per-class augmentation)

In [None]:

def GC_CV(folds=None, output_dir="HCA/5CV", quantile=0.75):
    """Class-balance each CV fold's TRAIN split with Gaussian Copula samples.

    For every fold, the per-class counts of ``y_train`` are computed and the
    target class size is their ``quantile`` (Q3 by default; ``method='nearest'``
    so the target is an observed count). Each class below the target gets its
    own Gaussian Copula, fitted on that class's rows only, and is topped up to
    the target with synthetic rows. Classes at or above the target are copied
    unchanged. Validation data in the fold is never read.

    Args:
        folds: Iterable of fold dicts with ``'X_train'`` (a DataFrame — its
            ``.columns`` are used) and ``'y_train'`` entries. Defaults to the
            module-level ``all_folds``.
        output_dir: Directory where ``GC_skf_fold{k}.pkl`` is written for each
            fold ``k`` (1-based). Created if missing.
        quantile: Quantile of the class counts used as the per-class target.

    Returns:
        A list with one dict per fold, ``{'X_train_gen': ndarray,
        'y_train_gen': ndarray}``. Rows are grouped by class in ``np.unique``
        (sorted) label order; within a class, real rows come first, then that
        class's synthetic rows.
    """
    if folds is None:
        folds = all_folds

    gen_dict = []
    for k, fold in enumerate(folds, start=1):
        X_train = fold['X_train']
        y_train = fold['y_train']

        # Force every column to sdtype 'numerical', overriding auto-detection.
        metadata = Metadata.detect_from_dataframe(X_train)
        for col in X_train.columns:
            metadata.update_column(column_name=col, sdtype='numerical')

        unique_values, counts = np.unique(y_train, return_counts=True)
        max_count = np.quantile(counts, quantile, method='nearest')

        # Collect per-class chunks and stack once at the end. The previous
        # version only initialised its accumulators when the FIRST class needed
        # synthesis, so a first class at/above the target raised
        # UnboundLocalError; it also re-stacked the whole array every class.
        X_parts = []
        y_parts = []
        for label, count in zip(unique_values, counts):
            print("label, count, max_count", label, count, max_count)
            X_minority = X_train[y_train == label]
            X_parts.append(np.array(X_minority))
            y_parts.append(np.full(count, label))

            if count < max_count:
                num_synthetic_samples = max_count - count
                synthetic_samples = generate_GC(metadata, num_synthetic_samples, X_minority)
                X_parts.append(np.array(synthetic_samples))
                y_parts.append(np.full(num_synthetic_samples, label))

        syn = {
            'X_train_gen': np.vstack(X_parts),
            'y_train_gen': np.concatenate(y_parts),
        }
        gen_dict.append(syn)

        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, f'GC_skf_fold{k}.pkl'), 'wb') as f:
            pickle.dump(syn, f)

    return gen_dict


In [None]:
# Run per-fold class-balancing augmentation. GC_CV() reads the global all_folds;
# assuming cells ran top to bottom, that holds the HCA folds here (the HCA cell
# rebound all_folds after the Pancreas folds were loaded).
gen_dict = GC_CV()

label, count, max_count 1 476 670




label, count, max_count 2 399 670
label, count, max_count 3 918 670
label, count, max_count 4 1205 670
label, count, max_count 5 670 670
label, count, max_count 6 275 670
label, count, max_count 7 1658 670
label, count, max_count 8 470 670
label, count, max_count 9 892 670
label, count, max_count 10 202 670
label, count, max_count 11 377 670
label, count, max_count 12 91 670
label, count, max_count 13 107 670
label, count, max_count 14 80 670
label, count, max_count 15 121 670
label, count, max_count 16 59 670
label, count, max_count 1 476 671




label, count, max_count 2 399 671
label, count, max_count 3 918 671
label, count, max_count 4 1205 671
label, count, max_count 5 671 671
label, count, max_count 6 274 671
label, count, max_count 7 1658 671
label, count, max_count 8 470 671
label, count, max_count 9 892 671
label, count, max_count 10 202 671
label, count, max_count 11 377 671
label, count, max_count 12 91 671
label, count, max_count 13 107 671
label, count, max_count 14 80 671
label, count, max_count 15 121 671
label, count, max_count 16 59 671
label, count, max_count 1 476 671




label, count, max_count 2 399 671
label, count, max_count 3 919 671
label, count, max_count 4 1204 671
label, count, max_count 5 671 671
label, count, max_count 6 274 671
label, count, max_count 7 1659 671
label, count, max_count 8 469 671
label, count, max_count 9 892 671
label, count, max_count 10 203 671
label, count, max_count 11 376 671
label, count, max_count 12 92 671
label, count, max_count 13 106 671
label, count, max_count 14 80 671
label, count, max_count 15 122 671
label, count, max_count 16 58 671
label, count, max_count 1 476 670




label, count, max_count 2 399 670
label, count, max_count 3 919 670
label, count, max_count 4 1205 670
label, count, max_count 5 670 670
label, count, max_count 6 274 670
label, count, max_count 7 1659 670
label, count, max_count 8 469 670
label, count, max_count 9 892 670
label, count, max_count 10 203 670
label, count, max_count 11 377 670
label, count, max_count 12 91 670
label, count, max_count 13 106 670
label, count, max_count 14 80 670
label, count, max_count 15 122 670
label, count, max_count 16 58 670
label, count, max_count 1 476 670




label, count, max_count 2 400 670
label, count, max_count 3 918 670
label, count, max_count 4 1205 670
label, count, max_count 5 670 670
label, count, max_count 6 275 670
label, count, max_count 7 1658 670
label, count, max_count 8 470 670
label, count, max_count 9 892 670
label, count, max_count 10 202 670
label, count, max_count 11 377 670
label, count, max_count 12 91 670
label, count, max_count 13 106 670
label, count, max_count 14 80 670
label, count, max_count 15 122 670
label, count, max_count 16 58 670
