MOE-FB Runner (Flow-based, Mixture-of-Experts extension of MAF-FB)

This script trains a MOE-FB flow-based model for single-cell data.
MOE-FB extends MAF-FB with learnable feature masking and a Mixture-of-Experts (MOE) multi-head attention mechanism.

Default hyperparameters

epochs: 100

batch size: 128

hidden features: 1024

learning rate: 1e-6

MOE: num_heads=10, experts=4 (the MoEAttention code below defaults to num_heads=20, and the runner never overrides it)

ActNorm: configurable (enable/disable via user flag)

Model modules (brief)

Learnable Masking (feature gating):
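A lightweight MLP learns a continuous, context-dependent mask for each feature; in the code below the sigmoid output is scaled by 0.5, so mask values lie in [0, 0.5]. Masked features are blended toward the per-sample mean, and the blended representation is passed to the flow, which is intended to improve robustness to noisy/sparse genes.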

MOE Attention (multi-head, gated experts):
Each of the E=4 experts is a multi-head attention layer (heads=10). A gating network (softmax) computes per-sample weights that mix the expert outputs, and the resulting expert-weighted context parametrizes the scale and shift of a masked affine transform, allowing sample- and feature-adaptive flows.

MAF Backbone:
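Stacked Masked Affine Autoregressive Transforms provide a tractable log-likelihood with single-pass density evaluation; sampling inverts the autoregressive transform one feature at a time, so it is slower. In the code below, each MAF block follows a MoE masked affine transform in the composite flow.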

ActNorm (optional):
Per-channel affine normalization initialized with data-dependent stats; can be toggled by a user parameter.
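
The first two modules can be illustrated with a small NumPy sketch. It is not the runner's code: random arrays stand in for the learned mask, the expert outputs, and the gating logits, and the helper names `blend_with_mask` and `mix_experts` are hypothetical. It only shows the two arithmetic steps: blending masked features toward the per-sample mean, and mixing expert outputs with softmax weights.

```python
import numpy as np

rng = np.random.default_rng(0)

def blend_with_mask(x, mask):
    """Soft feature gating: pull masked features toward the per-sample mean."""
    x_mean = x.mean(axis=1, keepdims=True)
    return x * (1.0 - mask) + x_mean * mask

def mix_experts(expert_outputs, gate_logits):
    """Softmax-gated mixture; expert_outputs has shape (batch, experts, features)."""
    shifted = gate_logits - gate_logits.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(shifted)
    weights /= weights.sum(axis=1, keepdims=True)
    mixed = (expert_outputs * weights[:, :, None]).sum(axis=1)
    return mixed, weights

x = rng.normal(size=(3, 6))                   # 3 cells, 6 genes (toy sizes)
mask = 0.5 * rng.uniform(size=x.shape)        # stand-in for a learned mask in [0, 0.5]
x_masked = blend_with_mask(x, mask)

expert_outputs = rng.normal(size=(3, 4, 6))   # stand-in for the outputs of 4 experts
gate_logits = rng.normal(size=(3, 4))         # stand-in for the gating network output
context, weights = mix_experts(expert_outputs, gate_logits)
```

A mask of zeros leaves the input unchanged, and each row of `weights` sums to one, so `context` is a convex combination of the expert outputs for every sample.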

Supported study scenarios

PBMC3K

Train MOE-FB on real train samples (100 epochs, batch size 128, hidden=1024, lr=1e-6, heads=10, experts=4, ActNorm as set).

Generate synthetic samples equal to the TEST sample size.

Saves: pbmc3k_MOE-FB_100epc.pkl

HCA-BM10K (5-fold CV)

Integrated Pancreatic Dataset (5-fold CV)

For EACH FOLD:

Fit MOE-FB on TRAIN ONLY.

Synthesize cells per class so that every class below Q3 (the 75th percentile of the training-fold cell-type counts) is raised to Q3. This applies if --label-col is provided; synthetic cells receive the label of the class they were generated for.

Augment TRAIN with synthetic data. VALIDATION/TEST are NEVER touched.

Saves per-fold files under: {output}/folds/fold_{i}/

A dictionary per fold containing: X_train_gen, y_train_gen
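
The per-class target can be sketched as follows. This is a minimal illustration on toy labels, assuming the same `np.quantile(..., method='nearest')` call the notebook uses; the function name `synthesis_targets` is hypothetical.

```python
import numpy as np

def synthesis_targets(y_train):
    """Return ({label: n_synthetic}, q3) so every class below Q3 is raised to Q3."""
    labels, counts = np.unique(y_train, return_counts=True)
    # method='nearest' returns an observed class count rather than an interpolated value
    q3 = int(np.quantile(counts, 0.75, method="nearest"))
    targets = {str(label): max(q3 - int(c), 0) for label, c in zip(labels, counts)}
    return targets, q3

# Toy training-fold labels: class counts are 8, 6, 3 and 1
y_train = np.array(["alpha"] * 8 + ["beta"] * 6 + ["delta"] * 3 + ["epsilon"] * 1)
targets, q3 = synthesis_targets(y_train)
print(q3, targets)
```

Classes at or above Q3 get a target of zero and are kept as they are; only the training labels enter the computation.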

Notes

Synthesis strictly uses train-only information (no leakage).
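
One place where leakage could creep in is feature scaling. The sketch below shows the train-only pattern with plain NumPy, assuming standardization as in the runner (which uses `StandardScaler`); the random draw `Z_synth` is only a stand-in for samples from a trained flow.

```python
import numpy as np

rng = np.random.default_rng(0)
X_train = rng.normal(loc=5.0, scale=2.0, size=(200, 4))
X_val = rng.normal(loc=5.0, scale=2.0, size=(50, 4))   # held out; never used below

# Statistics come from the training split only
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
std[std == 0] = 1.0                                     # guard against constant features

Z_train = (X_train - mean) / std                        # what the model would be fit on
Z_synth = rng.normal(size=(30, 4))                      # stand-in for model samples
X_synth = Z_synth * std + mean                          # back to the original scale
```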


In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
!pip install nflows --quiet


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pylab import savefig
from scipy.io import arff
import ntpath
import glob
import os
import math
from sklearn import preprocessing

import argparse

import torch
from torch import nn, optim
from nflows.flows import Flow
from nflows.distributions import StandardNormal
from nflows.transforms import CompositeTransform, MaskedAffineAutoregressiveTransform

from sklearn import manifold
import string


PBMC3K

In [None]:
import pickle

# Load the pre-split PBMC3K matrices and labels
with open("data/pbmc3k_train.pkl", "rb") as f:
    X_train = pickle.load(f)
with open("data/pbmc3k_test.pkl", "rb") as f:
    X_test = pickle.load(f)
with open("data/pbmc3k_y_train.pkl", "rb") as f:
    y_train = pickle.load(f)
with open("data/pbmc3k_y_test.pkl", "rb") as f:
    y_test = pickle.load(f)

5CV Data

In [None]:
import pickle

all_folds = []

for fold in range(1, 6):
    with open(f"data/5CV_woTest/fold_skf_3000_{fold}.pkl", "rb") as f:
        fold_data = pickle.load(f)
        all_folds.append(fold_data)



In [None]:
unique_values, counts = np.unique(all_folds[0]['y_val'], return_counts=True)
display(dict(zip(unique_values, counts)),np.max(counts))

{'PSC': np.int64(11),
 'acinar': np.int64(272),
 'activated_stellate': np.int64(57),
 'alpha': np.int64(974),
 'beta': np.int64(738),
 'delta': np.int64(190),
 'ductal': np.int64(340),
 'endothelial': np.int64(58),
 'epsilon': np.int64(4),
 'gamma': np.int64(85),
 'macrophage': np.int64(11),
 'mast': np.int64(5),
 'mesenchymal': np.int64(16),
 'pp': np.int64(37),
 'quiescent_stellate': np.int64(35),
 'schwann': np.int64(2)}

np.int64(974)

In [None]:
Q1, Q2, Q3 = np.quantile(counts, [0.25, 0.5, 0.75], axis=0, method='nearest')
print("Q1",Q1,"\nQ2",Q2,"\nQ3", Q3)


Q1 44 
Q2 227 
Q3 759


In [None]:
for i,f in enumerate(all_folds, start=1):
  print(f"Fold {i}:")
  print("X_train shape:", f['X_train'].shape)
  print("y_train shape:", f['y_train'].shape)

Fold 1:
X_train shape: (11337, 3000)
y_train shape: (11337,)
Fold 2:
X_train shape: (11337, 3000)
y_train shape: (11337,)
Fold 3:
X_train shape: (11338, 3000)
y_train shape: (11338,)
Fold 4:
X_train shape: (11338, 3000)
y_train shape: (11338,)
Fold 5:
X_train shape: (11338, 3000)
y_train shape: (11338,)


HCA

In [None]:
import pickle

all_folds = []

for fold in range(1, 6):
    with open(f"HCA/5CV/fold_skf_{fold}.pkl", "rb") as f:
        fold_data = pickle.load(f)
        all_folds.append(fold_data)


In [None]:
for i,f in enumerate(all_folds, start=1):
  print(f"Fold {i}:")
  print("X_train shape:", f['X_train'].shape)
  print("y_train shape:", f['y_train'].shape)

Fold 1:
X_train shape: (8000, 3000)
y_train shape: (8000,)
Fold 2:
X_train shape: (8000, 3000)
y_train shape: (8000,)
Fold 3:
X_train shape: (8000, 3000)
y_train shape: (8000,)
Fold 4:
X_train shape: (8000, 3000)
y_train shape: (8000,)
Fold 5:
X_train shape: (8000, 3000)
y_train shape: (8000,)


In [None]:
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import torch.optim as optim
import pickle

from nflows.flows import Flow
from nflows.distributions import StandardNormal
from nflows.transforms import CompositeTransform
from nflows.transforms.base import Transform
from nflows.transforms.normalization import ActNorm
from nflows.transforms.permutations import RandomPermutation
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# -------------------------
# Contextual Feature Masking
class FeatureMasking(nn.Module):
    def __init__(self, num_features, hidden_features):
        super().__init__()
        self.context_extractor = nn.Sequential(
            nn.Linear(num_features, hidden_features),
            nn.ReLU()
        )
        self.mask_gen = nn.Sequential(
            nn.Linear(hidden_features, num_features),
            nn.Sigmoid()
        )

    def forward(self, x):
        context = self.context_extractor(x)
        mask = 0.5 * self.mask_gen(context)
        x_mean = x.mean(dim=1, keepdim=True)
        x_masked = x * (1 - mask) + x_mean * mask
        return x_masked, mask


class ConditionalActNorm(nn.Module):
    def __init__(self, features, sparsity_threshold=0.8):
        super().__init__()
        self.actnorm = ActNorm(features)
        self.sparsity_threshold = sparsity_threshold
        self.enabled = True

    def forward(self, x, context=None):
        if self.enabled:
            return self.actnorm(x, context)
        else:
            return x, torch.zeros(x.shape[0], device=x.device)

    def inverse(self, z, context=None):
        if self.enabled:
            return self.actnorm.inverse(z, context)
        else:
            return z, torch.zeros(z.shape[0], device=z.device)

# -------------------------
class MoEAttention(nn.Module):
    def __init__(self, features, hidden_features, num_experts=4, num_heads=20):
        super().__init__()
        self.num_experts = num_experts
        if num_heads is None:
            num_heads = 20
        self.attn_experts = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=features, num_heads=num_heads, batch_first=True)
            for _ in range(num_experts)
        ])
        self.gate = nn.Sequential(
            nn.Linear(features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_experts)
        )

    def forward(self, x):
        x_unsq = x.unsqueeze(1)
        gate_scores = self.gate(x)
        expert_outputs = []
        for expert in self.attn_experts:
            out, _ = expert(x_unsq, x_unsq, x_unsq)
            expert_outputs.append(out.squeeze(1))
        expert_outputs = torch.stack(expert_outputs, dim=1)
        weights = torch.softmax(gate_scores, dim=1).unsqueeze(2)
        output = (expert_outputs * weights).sum(dim=1)
        return output

# -------------------------
class MoEMaskedAffineTransform(Transform):
    def __init__(self, features, hidden_features, mask, num_experts=4, num_heads=20):
        super().__init__()
        self.features = features
        self.register_buffer("mask", mask.float())
        self.attention = MoEAttention(features, hidden_features, num_experts, num_heads)
        self.scale_net = nn.Sequential(
            nn.Linear(features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, features),
            nn.Tanh()
        )
        self.shift_net = nn.Sequential(
            nn.Linear(features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, features)
        )

    def forward(self, x, context=None):
        xa = x * self.mask
        h = self.attention(xa)
        scale = self.scale_net(h) * (1.0 - self.mask)
        shift = self.shift_net(h) * (1.0 - self.mask)
        zb = (x * (1.0 - self.mask)) * torch.exp(scale) + shift
        z = xa + zb
        logabsdet = scale.sum(dim=1)
        return z, logabsdet

    def inverse(self, z, context=None):
        za = z * self.mask
        h = self.attention(za)
        scale = self.scale_net(h) * (1.0 - self.mask)
        shift = self.shift_net(h) * (1.0 - self.mask)
        xb = (z * (1.0 - self.mask) - shift) * torch.exp(-scale)
        x = za + xb
        logabsdet = -scale.sum(dim=1)
        return x, logabsdet

# -------------------------
def make_alternating_mask(features, even=True, device="cpu"):
    mask = torch.zeros(features, device=device)
    if even:
        mask[::2] = 1.
    else:
        mask[1::2] = 1.
    return mask

# -------------------------
def build_flow_with_glow_blocks(num_features, hidden_features, num_layers, num_experts=4, num_maf_blocks=1, device='cuda'):
    transforms = []
    for i in range(num_layers):
        mask = make_alternating_mask(num_features, even=(i % 2 == 0), device=device)

        cond_actnorm = ConditionalActNorm(num_features)
        transforms.append(cond_actnorm)

        transforms.append(
            MoEMaskedAffineTransform(
                features=num_features,
                hidden_features=hidden_features,
                mask=mask,
                num_experts=num_experts
            )
        )

        for _ in range(num_maf_blocks):
            transforms.append(ConditionalActNorm(num_features))
            transforms.append(
                MaskedAffineAutoregressiveTransform(
                    features=num_features,
                    hidden_features=hidden_features,
                )
            )

    transform = CompositeTransform(transforms)
    base_distribution = StandardNormal([num_features])
    flow = Flow(transform, base_distribution).to(device)
    return flow

# -------------------------
def FB_Oversampler(num_synthetic_samples, X_min, num_layers, hidden_features, learning_rate, num_experts=4, num_epochs=100, batch_size=128,act_norm_enabled=0):

    # Standardize features on the training samples only; the same scaler
    # maps generated samples back to the original scale after sampling.
    scaler = StandardScaler()
    X_min_scaled = scaler.fit_transform(np.asarray(X_min, dtype=np.float32))
    data = torch.tensor(X_min_scaled, dtype=torch.float32).to(device)

    num_features = X_min.shape[1]
    masking_module = FeatureMasking(num_features, hidden_features).to(device)

    flow = build_flow_with_glow_blocks(num_features, hidden_features, num_layers, num_experts, num_maf_blocks=1, device=device)

    optimizer = optim.Adam(list(flow.parameters()) + list(masking_module.parameters()), lr=learning_rate)
    loader = DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=False)

    # Toggle every ConditionalActNorm layer once, before training starts
    conditional_actnorm_layers = [m for m in flow._transform._transforms if isinstance(m, ConditionalActNorm)]
    for m in conditional_actnorm_layers:
        m.enabled = bool(act_norm_enabled)

    flow.train()
    masking_module.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch in loader:
            batch = batch.to(device)

            optimizer.zero_grad()
            masked_batch, mask_values = masking_module(batch)
            nll = -flow.log_prob(masked_batch).mean()
            mask_penalty = mask_values.mean()
            loss = nll + 0.01 * mask_penalty
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * batch.size(0)

        avg_loss = total_loss / len(data)
        print(f"Epoch {epoch+1}/{num_epochs} - loss: {avg_loss:.4f}")

    flow.eval()
    with torch.no_grad():
        samples = flow.sample(num_synthetic_samples).to(device)

    generated_data_np = samples.cpu().numpy()
    generated_data_rescaled = scaler.inverse_transform(generated_data_np)
    return generated_data_rescaled


Using device: cuda
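
The `MoEMaskedAffineTransform` above has the form of an affine coupling layer: the features selected by the mask pass through unchanged and determine the scale and shift applied to the rest. The NumPy sketch below checks that this structure is invertible and that the two log-determinants cancel. It is an illustration under stated assumptions, not the trained model: fixed random linear maps (`W_s`, `W_t`) stand in for the attention and scale/shift networks.

```python
import numpy as np

rng = np.random.default_rng(0)
features = 6
mask = np.zeros(features)
mask[::2] = 1.0                                   # same pattern as make_alternating_mask(even=True)

# Fixed random linear maps stand in for the attention + scale/shift networks
W_s = rng.normal(scale=0.1, size=(features, features))
W_t = rng.normal(scale=0.1, size=(features, features))

def params(x_kept):
    """Scale and shift depend only on the kept features and act only on the others."""
    scale = np.tanh(x_kept @ W_s) * (1.0 - mask)
    shift = (x_kept @ W_t) * (1.0 - mask)
    return scale, shift

def forward(x):
    xa = x * mask
    scale, shift = params(xa)
    z = xa + (x * (1.0 - mask)) * np.exp(scale) + shift
    return z, scale.sum(axis=1)

def inverse(z):
    za = z * mask                                  # kept features are identical in x and z
    scale, shift = params(za)
    x = za + (z * (1.0 - mask) - shift) * np.exp(-scale)
    return x, -scale.sum(axis=1)

x = rng.normal(size=(5, features))
z, logdet_fwd = forward(x)
x_rec, logdet_inv = inverse(z)
print(np.abs(x - x_rec).max())
```

Because the kept features are the same before and after the transform, the inverse can recompute exactly the same scale and shift, which is what makes the layer invertible in closed form.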


Generate without class

In [None]:
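# Generate as many synthetic cells as there are test cells
synthetic_samples = FB_Oversampler(int(X_test.shape[0]), X_train, num_layers=1, hidden_features=1024, learning_rate=1e-6)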

In [None]:
with open(os.path.join("results", "pbmc3k_MOE-FB.pkl"), "wb") as f:
    pickle.dump(synthetic_samples, f)

5CV with class

In [None]:

def FB_CV():
    gen_dict = []
    for k, fold in enumerate(all_folds, start=1):
        # Only the training split is used; validation data is never touched.
        X_train = fold['X_train']
        y_train = fold['y_train']

        # Per-class counts; classes below Q3 are oversampled up to Q3.
        unique_values, counts = np.unique(y_train, return_counts=True)
        classlabel_counts = dict(zip(unique_values, counts))
        Q1, Q2, Q3 = np.quantile(counts, [0.25, 0.5, 0.75], axis=0, method='nearest')
        max_count = int(Q3)  # np.max(counts) would balance to the largest class instead

        # Collect per-class blocks in lists so the first class needs no special
        # case (the earlier version failed if the first class was not a minority).
        X_parts, y_parts, index_parts = [], [], []
        for label, count in classlabel_counts.items():
            print("label, count, max_count", label, count, max_count)
            X_minority = np.asarray(X_train[y_train == label])

            # Real cells of this class (index flag 1)
            X_parts.append(X_minority)
            y_parts.append(np.full(count, label))
            index_parts.append(np.full(count, 1))

            if count < max_count:
                num_synthetic_samples = int(max_count - count)
                synthetic_samples = FB_Oversampler(num_synthetic_samples, X_minority, num_layers=1,
                                                   hidden_features=1024, learning_rate=1e-6)
                # Synthetic cells of this class (index flag 2)
                X_parts.append(np.asarray(synthetic_samples))
                y_parts.append(np.full(num_synthetic_samples, label))
                index_parts.append(np.full(num_synthetic_samples, 2))

        X_train_gen = np.vstack(X_parts)
        y_train_gen = np.concatenate(y_parts)
        y_train_indexes = np.concatenate(index_parts)  # 1 = real, 2 = synthetic (not saved)

        syn = {
            'X_train_gen': X_train_gen,
            'y_train_gen': y_train_gen
        }
        gen_dict.append(syn)

        with open(os.path.join("HCA/5CV", f"MOE-FB_skf_fold{k}.pkl"), 'wb') as f:
            pickle.dump(syn, f)
    return gen_dict


In [None]:
gen_dict = FB_CV()

Streaming output truncated to the last 5000 lines.
Epoch 70/100 - loss: 4225.7944
Epoch 71/100 - loss: 4225.3535
Epoch 72/100 - loss: 4224.9135
Epoch 73/100 - loss: 4224.4700
Epoch 74/100 - loss: 4224.0346
Epoch 75/100 - loss: 4223.5918
Epoch 76/100 - loss: 4223.1480
Epoch 77/100 - loss: 4222.7093
Epoch 78/100 - loss: 4222.2654
Epoch 79/100 - loss: 4221.8258
Epoch 80/100 - loss: 4221.3875
Epoch 81/100 - loss: 4220.9387
Epoch 82/100 - loss: 4220.4963
Epoch 83/100 - loss: 4220.0539
Epoch 84/100 - loss: 4219.6098
Epoch 85/100 - loss: 4219.1649
Epoch 86/100 - loss: 4218.7240
Epoch 87/100 - loss: 4218.2744
Epoch 88/100 - loss: 4217.8333
Epoch 89/100 - loss: 4217.3836
Epoch 90/100 - loss: 4216.9417
Epoch 91/100 - loss: 4216.4986
Epoch 92/100 - loss: 4216.0435
Epoch 93/100 - loss: 4215.5981
Epoch 94/100 - loss: 4215.1530
Epoch 95/100 - loss: 4214.7021
Epoch 96/100 - loss: 4214.2557
Epoch 97/100 - loss: 4213.8046
Epoch 98/100 - loss: 4213.3513
Epoch 99/100 - loss: 4212.8993
Epoch