MAF-FB Runner (Flow-based, Masked Affine Autoregressive Transform – Baseline)

This script trains a MAF-FB flow-based model for single-cell data.

Default hyperparameters

epochs: 100

batch size: 128

hidden features: 1024

learning rate: 1e-6

MAF layers: 1
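The Notes mention a `--label-col` option and an `{output}` directory, but the notebook cells below hard-code the values listed above. Below is a minimal sketch of how these defaults could be exposed on a command line with the standard-library `argparse` module, which the notebook already imports. Every flag name except `--label-col` is a hypothetical choice, not the runner's confirmed interface.

```python
import argparse


def build_parser():
    # Hypothetical CLI: only --label-col is named in the notes; other flags are illustrative
    p = argparse.ArgumentParser(description="MAF-FB runner (sketch)")
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--batch-size", type=int, default=128)
    p.add_argument("--hidden-features", type=int, default=1024)
    p.add_argument("--lr", type=float, default=1e-6)
    p.add_argument("--num-layers", type=int, default=1)
    p.add_argument("--label-col", default=None,
                   help="column holding cell-type labels (enables per-class oversampling)")
    p.add_argument("--output", default="results")  # hypothetical default output directory
    return p


# Parse an explicit empty list so the defaults are used regardless of sys.argv
args = build_parser().parse_args([])
custom = build_parser().parse_args(["--lr", "1e-4", "--label-col", "cell_type"])
```

Passing an explicit argument list to `parse_args` also makes the parser usable from inside a notebook, where `sys.argv` holds the kernel's own arguments.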

PBMC3K

Trains MAF-FB on the real training samples (100 epochs, batch size 128, 1024 hidden features, lr 1e-6).

Generates as many synthetic samples as there are TEST samples.

Saves: pbmc3k_MAF-FB.pkl

PBMC68K

Uses the processed PBMC68K data downloaded from the ACTIVA repository (unchanged).

Trains MAF-FB on train split with the same hyperparameters.

Generates as many synthetic samples as there are TEST samples.

Saves: results/pbmc68k_MAF-FB_generated.pkl

HCA-BM10K and Integrated Pancreatic Dataset (5-fold CV)

For EACH FOLD:

Fit MAF-FB on TRAIN ONLY.

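Generate synthetic samples per class. Q3 is the 75th percentile of the per-class cell counts in the training fold. Every cell type with fewer cells than Q3 is topped up to Q3 with samples from a MAF-FB flow fit on that cell type alone. Cell types at or above Q3 are kept unchanged. Labels come from y_train (or from --label-col when provided), and each synthetic sample is assigned the label of the class it was generated for.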

Augment TRAIN with synthetic data. VALIDATION/TEST are NEVER touched.

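Saves one pickle per fold under: {output}/folds/fold_{i}/ (the notebook cell below writes HCA/5CV/MAF-FB_skf_fold_{k}.pkl).

Each fold's dictionary contains X_train_gen (the real training cells plus the synthetic cells) and y_train_gen (their labels).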

Notes

Model is flow-based with Masked Affine Autoregressive Transform (MAF) as the transform module.

All synthesis strictly uses train-only information; no leakage from validation/test.

PBMC68K’s processed input must be provided from the ACTIVA repo as used in the paper.
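The train-only guarantee rests on where the normalization statistics come from. FB_Oversampler fits a StandardScaler on the training rows it receives, trains the flow in that standardized space, and maps samples back with the same scaler. The numpy sketch below mirrors that round trip. As a hypothetical stand-in for the trained MAF, it uses an independent Gaussian per feature, so the only point it illustrates is that no validation or test statistics enter.

```python
import numpy as np


def fit_standardizer(X_train):
    """Column means and stds from training rows only (population std, like StandardScaler)."""
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)
    std[std == 0] = 1.0  # constant genes keep scale 1 instead of dividing by zero
    return mean, std


def oversample_train_only(X_train, n_samples, rng):
    """Standardize -> fit a generator -> sample -> inverse-transform, all from X_train."""
    mean, std = fit_standardizer(X_train)
    Z = (X_train - mean) / std
    # Hypothetical stand-in for the trained MAF: an independent Gaussian per feature in Z-space
    Z_new = rng.normal(Z.mean(axis=0), Z.std(axis=0), size=(n_samples, X_train.shape[1]))
    return Z_new * std + mean  # back to the original expression scale


rng = np.random.default_rng(0)
X_train = rng.poisson(3.0, size=(200, 5)).astype(float)
X_train[:, 4] = 0.0                                    # an all-zero gene
X_val = rng.poisson(30.0, size=(50, 5)).astype(float)  # held out: never passed in
X_syn = oversample_train_only(X_train, 1000, rng)
```

The synthetic cells track the training means and ignore the (deliberately shifted) held-out matrix, because X_val never reaches the scaler or the generator.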

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
!pip install nflows --quiet


In [None]:
# Standard library
import argparse
import glob
import math
import ntpath
import os
import pickle
import string

# Scientific stack
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pylab import savefig
from scipy.io import arff
from sklearn import manifold, preprocessing

# PyTorch and nflows (normalizing-flow building blocks used by MAF-FB)
import torch
from torch import nn, optim
from nflows.flows import Flow
from nflows.distributions import StandardNormal
from nflows.transforms import CompositeTransform, MaskedAffineAutoregressiveTransform


5CV (Integrated Pancreatic Dataset)

In [None]:
import pickle

all_folds = []

for fold in range(1, 6):
    with open(f"data/5CV/fold_skf_3000_{fold}.pkl", "rb") as f:
        fold_data = pickle.load(f)
        all_folds.append(fold_data)


In [None]:
for i,f in enumerate(all_folds, start=1):
  print(f"Fold {i}:")
  print("X_train shape:", f['X_train'].shape)
  print("y_train shape:", f['y_train'].shape)

Fold 1:
X_train shape: (11337, 3000)
y_train shape: (11337,)
Fold 2:
X_train shape: (11337, 3000)
y_train shape: (11337,)
Fold 3:
X_train shape: (11338, 3000)
y_train shape: (11338,)
Fold 4:
X_train shape: (11338, 3000)
y_train shape: (11338,)
Fold 5:
X_train shape: (11338, 3000)
y_train shape: (11338,)


In [None]:
# Class counts in the training split of fold 1
y_train = all_folds[0]['y_train']
unique_values, counts = np.unique(y_train, return_counts=True)
display(dict(zip(unique_values, counts)), np.max(counts))

{np.str_('PSC'): np.int64(42),
 np.str_('acinar'): np.int64(1090),
 np.str_('activated_stellate'): np.int64(227),
 np.str_('alpha'): np.int64(3897),
 np.str_('beta'): np.int64(2950),
 np.str_('delta'): np.int64(759),
 np.str_('ductal'): np.int64(1360),
 np.str_('endothelial'): np.int64(231),
 np.str_('epsilon'): np.int64(17),
 np.str_('gamma'): np.int64(339),
 np.str_('macrophage'): np.int64(44),
 np.str_('mast'): np.int64(20),
 np.str_('mesenchymal'): np.int64(64),
 np.str_('pp'): np.int64(148),
 np.str_('quiescent_stellate'): np.int64(138),
 np.str_('schwann'): np.int64(11)}

np.int64(3897)
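With these counts, the Q3 rule from the 5-fold CV description can be worked out by hand. FB_CV takes the 75th percentile of the per-class counts with `method='nearest'`, so the target is an observed class size, and each class below it is topped up to it. The sketch below recomputes that target and the resulting number of synthetic cells from the counts printed above (the `method` keyword assumes NumPy 1.22 or later).

```python
import numpy as np

# Class counts printed above for the 11,337-cell pancreatic training split
class_counts = {
    "PSC": 42, "acinar": 1090, "activated_stellate": 227, "alpha": 3897,
    "beta": 2950, "delta": 759, "ductal": 1360, "endothelial": 231,
    "epsilon": 17, "gamma": 339, "macrophage": 44, "mast": 20,
    "mesenchymal": 64, "pp": 148, "quiescent_stellate": 138, "schwann": 11,
}


def q3_shortfall(counts):
    """Q3 of the per-class counts (method='nearest', as in FB_CV) and the
    number of synthetic cells each class below Q3 would receive."""
    q3 = int(np.quantile(list(counts.values()), 0.75, method="nearest"))
    return q3, {label: q3 - n for label, n in counts.items() if n < q3}


q3, shortfall = q3_shortfall(class_counts)
print(q3, len(shortfall), sum(shortfall.values()))  # 759 11 7068
```

With 16 classes, the nearest-rank index is round(0.75 × 15) = 11, which selects the 12th-smallest count (delta, 759). Eleven classes are upsampled, adding 7,068 synthetic cells, while delta sits exactly at Q3 and is left as is.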

PBMC3K

In [None]:
import os
import pickle

# Assumption: gdrivePath is the Drive folder that contains Revision/ (adjust to your layout)
gdrivePath = "/content/gdrive/MyDrive"
data_dir = os.path.join(gdrivePath, "Revision", "data")

with open(os.path.join(data_dir, "pbmc3k_train.pkl"), "rb") as f:
    X_train = pickle.load(f)
with open(os.path.join(data_dir, "pbmc3k_test.pkl"), "rb") as f:
    X_test = pickle.load(f)
with open(os.path.join(data_dir, "pbmc3k_y_train.pkl"), "rb") as f:
    y_train = pickle.load(f)
with open(os.path.join(data_dir, "pbmc3k_y_test.pkl"), "rb") as f:
    y_test = pickle.load(f)

PBMC68K

In [None]:
!pip install anndata --quiet



In [None]:
import anndata

# Load the h5ad file
adata = anndata.read_h5ad("/data/68kPBMC_preprocessed.h5ad")

# Print basic info
print(adata)


AnnData object with n_obs × n_vars = 68579 × 17789
    obs: 'cluster', 'n_genes', 'n_counts', 'split'
    var: 'n_cells'


In [None]:
train_mask = adata.obs['split'] == "train"
test_mask = adata.obs['split'] == "test"

# Select training data
X_train = adata.X[train_mask.values]
y_train = adata.obs['cluster'][train_mask.values]

# Select test data
X_test = adata.X[test_mask.values]
y_test = adata.obs['cluster'][test_mask.values]
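The split above relies on the precomputed `split` column shipped with the ACTIVA h5ad file, so no re-splitting happens here. The numpy sketch below reproduces the same boolean-mask selection on hypothetical toy arrays standing in for `adata.X`, `adata.obs['split']` and `adata.obs['cluster']`. It also checks that no cell lands in both splits. If `adata.X` is stored as a scipy sparse matrix, it would need `.toarray()` before being passed to StandardScaler, which refuses to center sparse input.

```python
import numpy as np

# Hypothetical toy stand-ins for adata.X, adata.obs['split'] and adata.obs['cluster']
X = np.arange(12, dtype=float).reshape(6, 2)
split = np.array(["train", "test", "train", "train", "test", "train"])
cluster = np.array(["B", "T", "T", "NK", "B", "B"])

train_mask = split == "train"
test_mask = split == "test"

# Same boolean-mask selection as the cell above, applied to rows and labels alike
X_train, y_train = X[train_mask], cluster[train_mask]
X_test, y_test = X[test_mask], cluster[test_mask]

# Each cell should land in exactly one split
assert not np.any(train_mask & test_mask)
assert train_mask.sum() + test_mask.sum() == len(split)
```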

HCA

In [None]:
import pickle

all_folds = []

for fold in range(1, 6):
    with open(f"HCA/5CV/fold_skf_{fold}.pkl", "rb") as f:
        fold_data = pickle.load(f)
        all_folds.append(fold_data)


In [None]:
X_train.shape, X_test.shape, y_train.shape, y_test.shape

((8000, 3000), (2000, 3000), (8000,), (2000,))

In [None]:
for i,f in enumerate(all_folds, start=1):
  print(f"Fold {i}:")
  print("X_train shape:", f['X_train'].shape)
  print("y_train shape:", f['y_train'].shape)

Fold 1:
X_train shape: (8000, 3000)
y_train shape: (8000,)
Fold 2:
X_train shape: (8000, 3000)
y_train shape: (8000,)
Fold 3:
X_train shape: (8000, 3000)
y_train shape: (8000,)
Fold 4:
X_train shape: (8000, 3000)
y_train shape: (8000,)
Fold 5:
X_train shape: (8000, 3000)
y_train shape: (8000,)


In [None]:
import torch
from torch import optim
from sklearn.preprocessing import StandardScaler
from nflows.flows import Flow
from nflows.distributions import StandardNormal
from nflows.transforms import CompositeTransform, MaskedAffineAutoregressiveTransform


def FB_Oversampler(num_synthetic_samples, X_min, num_layers, hidden_features,
                   learning_rate, num_epochs=100, batch_size=128):
    """Fit a MAF flow on X_min and return num_synthetic_samples new rows.

    X_min is standardized with a scaler fit on X_min only. The flow is trained
    by maximum likelihood in that space, and its samples are mapped back to the
    original scale with the same scaler.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Standardize with statistics of the training rows only (no val/test leakage)
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_min)
    data = torch.tensor(X_scaled, dtype=torch.float32).to(device)
    num_features = X_scaled.shape[1]

    # Base distribution N(0, I) and a stack of masked affine autoregressive transforms
    base_distribution = StandardNormal([num_features])
    transforms = [
        MaskedAffineAutoregressiveTransform(features=num_features,
                                            hidden_features=hidden_features)
        for _ in range(num_layers)
    ]
    flow = Flow(CompositeTransform(transforms), base_distribution).to(device)

    # Maximum-likelihood training: minimize the mean negative log-likelihood
    optimizer = optim.Adam(flow.parameters(), lr=learning_rate)
    data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epochs):
        for batch in data_loader:  # batches are already on `device`
            optimizer.zero_grad()
            loss = -flow.log_prob(batch).mean()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")  # loss of the last batch

    # Sample in standardized space, then rescale to the original expression scale
    flow.eval()
    with torch.no_grad():
        samples = flow.sample(int(num_synthetic_samples)).cpu().numpy()
    return scaler.inverse_transform(samples)
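The loss `-flow.log_prob(batch).mean()` uses the change-of-variables formula: log p(x) = log N(z; 0, I) + log|det ∂z/∂x|, where z is the flow's image of x. The numpy sketch below implements one generic masked affine autoregressive layer in the parametrization z_i = (x_i − μ_i)·exp(−α_i). A strictly lower-triangular linear map stands in for the MADE network, so μ_i and α_i depend only on x_1..x_(i−1). This is an illustration of the technique, not nflows' internal code, whose affine parametrization may differ in detail. The sketch also shows why sampling is slower than density evaluation: the inverse needs one conditioner pass per feature.

```python
import numpy as np


def conditioner(x, W_mu, W_alpha):
    """Toy autoregressive conditioner: strictly lower-triangular weights mean
    mu_i and alpha_i depend only on x_1..x_{i-1} (the role MADE masks play)."""
    L_mu, L_alpha = np.tril(W_mu, -1), np.tril(W_alpha, -1)
    return x @ L_mu.T, x @ L_alpha.T


def to_noise(x, W_mu, W_alpha):
    """Density direction x -> z: a single parallel pass, as used in log_prob."""
    mu, alpha = conditioner(x, W_mu, W_alpha)
    z = (x - mu) * np.exp(-alpha)
    log_det = -alpha.sum(axis=-1)  # triangular Jacobian with diagonal exp(-alpha_i)
    return z, log_det


def log_prob(x, W_mu, W_alpha):
    z, log_det = to_noise(x, W_mu, W_alpha)
    d = x.shape[-1]
    log_base = -0.5 * np.sum(z ** 2, axis=-1) - 0.5 * d * np.log(2 * np.pi)
    return log_base + log_det


def to_data(z, W_mu, W_alpha):
    """Sampling direction z -> x: sequential over features (D conditioner calls)."""
    x = np.zeros_like(z)
    for i in range(z.shape[-1]):
        mu, alpha = conditioner(x, W_mu, W_alpha)  # uses only x_<i, already filled in
        x[..., i] = z[..., i] * np.exp(alpha[..., i]) + mu[..., i]
    return x


# Demo with arbitrary random weights
rng = np.random.default_rng(0)
D = 4
W_mu, W_alpha = rng.normal(size=(D, D)), 0.3 * rng.normal(size=(D, D))
x = rng.normal(size=(5, D))
z, log_det = to_noise(x, W_mu, W_alpha)
x_back = to_data(z, W_mu, W_alpha)

# Central finite-difference Jacobian of x -> z at one point, to check the log-determinant
eps = 1e-6
J = np.column_stack([
    (to_noise(x[0] + eps * e, W_mu, W_alpha)[0]
     - to_noise(x[0] - eps * e, W_mu, W_alpha)[0]) / (2 * eps)
    for e in np.eye(D)
])
sign, logabsdet_fd = np.linalg.slogdet(J)
```

With zero weights, the layer is the identity and log_prob reduces to the standard normal log-density. For the 3000-gene inputs used here, each flow.sample call therefore has to run the autoregressive inverse over 3000 features.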

PBMC68K

In [None]:
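# pbmc68k: dict with 'X_train' / 'X_test' (3000 genes, per the output shape below).
# It is assumed to be loaded beforehand; it is not defined in the cells shown here.
X_train = pbmc68k['X_train']
X_test = pbmc68k['X_test']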

In [None]:
# One synthetic cell per test cell; 1 MAF layer, 1024 hidden features, lr 1e-6
synthetic_samples = FB_Oversampler(int(X_test.shape[0]), X_train, 1, 1024, 1e-6)

In [None]:
synthetic_samples.shape

(6991, 3000)

In [None]:
import pickle
with open("results/pbmc68k_MAF-FB_generated.pkl", "wb") as f:
    pickle.dump(synthetic_samples, f)

5CV RUN

In [None]:

def FB_CV(folds=None, out_dir="HCA/5CV"):
    """Oversample each training fold up to Q3 of its per-class counts with MAF-FB.

    Only X_train / y_train are used; X_val / y_val are never read here.
    """
    folds = all_folds if folds is None else folds
    gen_dict = []
    for k, fold in enumerate(folds, start=1):
        X_train = fold['X_train']
        y_train = np.asarray(fold['y_train'])

        # Target size: 75th percentile (Q3) of the per-class counts in this training fold,
        # snapped to an observed count by method='nearest'
        unique_values, counts = np.unique(y_train, return_counts=True)
        max_count = int(np.quantile(counts, 0.75, method='nearest'))

        # Collect blocks per class (sorted label order): real cells first, then synthetic.
        # Building lists avoids the NameError the old version hit when the first
        # class was already at or above Q3.
        X_parts, y_parts = [], []
        for label, count in zip(unique_values, counts):
            print("label, count, max_count", label, count, max_count)
            X_class = X_train[y_train == label]
            X_parts.append(np.asarray(X_class))
            y_parts.append(np.full(count, label))
            if count < max_count:
                # Fit a fresh flow on this class only and sample the shortfall
                num_synthetic_samples = int(max_count - count)
                synthetic = FB_Oversampler(num_synthetic_samples, X_class, 1, 1024, 1e-6)
                X_parts.append(np.asarray(synthetic))
                y_parts.append(np.full(num_synthetic_samples, label))

        syn = {
            'X_train_gen': np.vstack(X_parts),
            'y_train_gen': np.concatenate(y_parts),
        }
        gen_dict.append(syn)

        with open(os.path.join(out_dir, f"MAF-FB_skf_fold_{k}.pkl"), "wb") as f:
            pickle.dump(syn, f)
    return gen_dict


In [None]:
gen_dict = FB_CV()
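After FB_CV finishes, it can be worth confirming that every fold follows the Q3 rule: each class should end with max(original count, Q3) cells, and no class should be added or lost. Below is a minimal sketch of such a check. The function name `check_augmented_fold` is hypothetical, and the check is exercised here on toy labels rather than on the real folds.

```python
import numpy as np


def check_augmented_fold(y_train, y_train_gen):
    """Hypothetical sanity check: after augmentation each class should have
    max(original count, Q3) cells, with Q3 computed as in FB_CV."""
    labels, counts = np.unique(np.asarray(y_train), return_counts=True)
    q3 = int(np.quantile(counts, 0.75, method="nearest"))
    gen_labels, gen_counts = np.unique(np.asarray(y_train_gen), return_counts=True)
    assert np.array_equal(gen_labels, labels), "a class was added or dropped"
    assert np.array_equal(gen_counts, np.maximum(counts, q3)), "class sizes break the Q3 rule"
    return dict(zip(labels.tolist(), gen_counts.tolist()))


# Toy labels: counts 2, 5, 9, 20 -> Q3 (nearest) = 9
y_toy = np.array(["a"] * 2 + ["b"] * 5 + ["c"] * 9 + ["d"] * 20)
y_toy_gen = np.concatenate(
    [np.full(max(n, 9), lab) for lab, n in [("a", 2), ("b", 5), ("c", 9), ("d", 20)]]
)
sizes = check_augmented_fold(y_toy, y_toy_gen)
```

In the notebook, the same check could be run over the real output with `for fold, syn in zip(all_folds, gen_dict): check_augmented_fold(fold['y_train'], syn['y_train_gen'])`.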