In [16]:
import numpy as np
from scipy import linalg
from sklearn.preprocessing import StandardScaler

def generate_synthetic_perturbations(control_data, perturbed_data):
    """
    Generate synthetic perturbed samples from control data while preserving the 
    mean expression and covariance structure of the true perturbed distribution.
    
    Parameters:
    -----------
    control_data : np.ndarray
        Control population gene expression matrix (n_cells × n_genes)
    perturbed_data : np.ndarray
        Perturbed population gene expression matrix (m_cells × n_genes)
    
    Returns:
    --------
    synthetic_perturbed : np.ndarray
        Synthetic perturbed samples with preserved statistics
    transform_matrix : np.ndarray
        The transformation matrix used to generate synthetic samples
    """
    
    # Calculate means and covariances
    control_mean = np.mean(control_data, axis=0)
    perturbed_mean = np.mean(perturbed_data, axis=0)
    
    control_cov = np.cov(control_data.T)
    perturbed_cov = np.cov(perturbed_data.T)
    
    # Ensure numerical stability
    control_cov += 1e-10 * np.eye(control_cov.shape[0])
    perturbed_cov += 1e-10 * np.eye(perturbed_cov.shape[0])
    
    # Calculate the transformation matrix using Cholesky decomposition
    # This preserves the covariance structure
    chol_control = linalg.cholesky(control_cov, lower=True)
    chol_perturbed = linalg.cholesky(perturbed_cov, lower=True)
    
    # Calculate the transformation matrix
    transform = np.dot(chol_perturbed, linalg.inv(chol_control))
    
    # Generate synthetic samples
    centered_control = control_data - control_mean
    synthetic_centered = np.dot(centered_control, transform.T)
    synthetic_perturbed = synthetic_centered + perturbed_mean
    
    return synthetic_perturbed, transform

def validate_synthetic_data(original_perturbed, synthetic_perturbed):
    """
    Validate that the synthetic data preserves the statistical properties
    of the original perturbed data.
    
    Parameters:
    -----------
    original_perturbed : np.ndarray
        Original perturbed population data
    synthetic_perturbed : np.ndarray
        Generated synthetic perturbed data
    
    Returns:
    --------
    dict
        Dictionary containing validation metrics
    """
    # Calculate validation metrics
    orig_mean = np.mean(original_perturbed, axis=0)
    synth_mean = np.mean(synthetic_perturbed, axis=0)
    mean_diff = np.abs(orig_mean - synth_mean)
    
    orig_cov = np.cov(original_perturbed.T)
    synth_cov = np.cov(synthetic_perturbed.T)
    cov_diff = np.abs(orig_cov - synth_cov)
    
    metrics = {
        'mean_difference_max': np.max(mean_diff),
        'mean_difference_avg': np.mean(mean_diff),
        'covariance_difference_max': np.max(cov_diff),
        'covariance_difference_avg': np.mean(cov_diff)
    }
    
    return metrics

# Example usage:


In [None]:
# Generate some example data
n_genes = 100
n_control_cells = 1000
n_perturbed_cells = 800

# Simulate control and perturbed data
np.random.seed(42)
control_data = np.random.normal(0, 1, (n_control_cells, n_genes))

# Create perturbed data with different mean and covariance
random_transform = np.random.normal(0, 0.1, (n_genes, n_genes))
perturbed_data = np.dot(control_data, random_transform.T) + np.random.normal(1, 0.5, n_genes)

print(f"Perturbed data shape {perturbed_data.shape}, perturbed data type {perturbed_data.dtype}")
print(f"Control data shape {control_data.shape}, control data type {control_data.dtype}")

# Generate synthetic perturbed samples
synthetic_perturbed, transform = generate_synthetic_perturbations(control_data, perturbed_data)

# Validate results
metrics = validate_synthetic_data(perturbed_data, synthetic_perturbed)
print("Validation metrics:")
for metric, value in metrics.items():
    print(f"{metric}: {value}")

Perturbed data shape (1000, 100), perturbed data type float64
Control data shape (1000, 100), control data type float64
Validation metrics:
mean_difference_max: 4.884981308350689e-15
mean_difference_avg: 8.369867299240497e-16
covariance_difference_max: 4.513955875751208e-11
covariance_difference_avg: 9.131090768908179e-12


In [3]:
import scanpy as sc
adata = sc.read('/orcd/data/omarabu/001/Omnicell_datasets/repogle_k562_essential_raw/K562_essential_raw_singlecell_01.h5ad')

In [35]:
control_data = adata[adata.obs["gene"] == "non-targeting"].X.toarray().astype(np.float64)
perturbed_data = adata[adata.obs["gene"] == "PCF11"].X.toarray().astype(np.float64)

In [17]:
synthetic_perturbed, transform = generate_synthetic_perturbations(control_data, perturbed_data)



In [22]:
print(synthetic_perturbed.shape)
print(control_data.shape)


(10691, 8563)
(10691, 8563)


In [19]:
metrics = validate_synthetic_data(perturbed_data, synthetic_perturbed)
print("Validation metrics:")
for metric, value in metrics.items():
    print(f"{metric}: {value}")

Validation metrics:
mean_difference_max: 7.673861546209082e-13
mean_difference_avg: 3.183176453944053e-15
covariance_difference_max: 2.7816604415420443e-06
covariance_difference_avg: 2.1535969161422817e-11


In [None]:
control_adata = adata[adata.obs["gene"] == "non-targeting"]
gt_adata = adata[adata.obs["gene"] == "PCF11"]

synthetic_perturbed_adata = control_adata.copy()
synthetic_perturbed_adata.X = synthetic_perturbed


In [29]:
synthetic_perturbed_adata.X.mean(axis=0)


array([0.16833667, 0.19438878, 1.52905812, ..., 0.25250501, 0.29058116,
       0.2745491 ])

In [30]:
gt_adata.X.mean(axis=0)


matrix([[0.16833669, 0.19438884, 1.5290581 , ..., 0.25250503, 0.29058114,
         0.27454913]], dtype=float32)

In [31]:
from omnicell.evaluation.utils import get_DEGs, get_DEGs_overlaps

sc.pp.log1p(control_adata)
pred_DEGs_df = get_DEGs(control_adata, synthetic_perturbed_adata)
true_DEGs_df = get_DEGs(control_adata, gt_adata)

DEGs_overlaps = get_DEGs_overlaps(true_DEGs_df, pred_DEGs_df, [100,50,20], 0.05, None)


  view_to_actual(adata)


  utils.warn_names_duplicates("obs")


In [32]:
DEGs_overlaps

{'Overlap_in_top_3677_DEGs': 2946,
 'Overlap_in_top_100_DEGs': 88,
 'Overlap_in_top_50_DEGs': 37,
 'Overlap_in_top_20_DEGs': 14,
 'Jaccard': 0.33774834437086093}

In [36]:
import numpy as np
from scipy import stats
from sklearn.neighbors import NearestNeighbors

def generate_count_preserving_perturbations(control_data, perturbed_data, n_neighbors=20):
    """
    Generate synthetic perturbed samples from control data while preserving
    count structure, sparsity patterns, and statistical properties of the
    true perturbed distribution.
    
    Parameters:
    -----------
    control_data : np.ndarray
        Control population sparse count matrix (n_cells × n_genes)
    perturbed_data : np.ndarray
        Perturbed population sparse count matrix (m_cells × n_genes)
    n_neighbors : int
        Number of nearest neighbors to consider for sampling
    
    Returns:
    --------
    synthetic_perturbed : np.ndarray
        Synthetic perturbed samples with preserved count structure
    """
    
    # Calculate statistical properties of perturbed data
    perturbed_means = np.mean(perturbed_data, axis=0)
    perturbed_vars = np.var(perturbed_data, axis=0)
    
    # Calculate sparsity patterns
    perturbed_sparsity = (perturbed_data == 0).mean(axis=0)
    
    # Initialize synthetic data matrix
    n_control = control_data.shape[0]
    n_genes = control_data.shape[1]
    synthetic_perturbed = np.zeros_like(control_data)
    
    # For each gene, transform the control counts while preserving structure
    for gene_idx in range(n_genes):
        control_gene = control_data[:, gene_idx]
        perturbed_gene = perturbed_data[:, gene_idx]
        
        # Preserve sparsity pattern
        control_nonzero = control_gene != 0
        target_nonzero = np.random.random(n_control) > perturbed_sparsity[gene_idx]
        
        if np.sum(target_nonzero) > 0:
            # For non-zero values, sample from empirical distribution
            nonzero_counts = perturbed_gene[perturbed_gene != 0]
            
            if len(nonzero_counts) > 0:
                # Sample counts from the empirical distribution
                synthetic_counts = np.random.choice(
                    nonzero_counts, 
                    size=np.sum(target_nonzero),
                    replace=True
                )
                
                # Assign sampled counts to non-zero positions
                synthetic_perturbed[target_nonzero, gene_idx] = synthetic_counts
    
    # Adjust for local structure preservation
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine')
    nbrs.fit(perturbed_data)
    
    # Refine synthetic data using local structure
    for cell_idx in range(n_control):
        if cell_idx % 1000 == 0:  # Progress tracking
            print(f"Processing cell {cell_idx}/{n_control}")
            
        cell_vector = synthetic_perturbed[cell_idx]
        
        # Find similar cells in perturbed data
        distances, indices = nbrs.kneighbors(cell_vector.reshape(1, -1))
        
        # Sample from nearest neighbors for refinement
        sampled_neighbor = perturbed_data[
            indices[0][np.random.randint(0, len(indices[0]))]
        ]
        
        # Adjust counts while preserving sparsity
        nonzero_mask = synthetic_perturbed[cell_idx] != 0
        if np.sum(nonzero_mask) > 0:
            synthetic_perturbed[cell_idx, nonzero_mask] = np.minimum(
                synthetic_perturbed[cell_idx, nonzero_mask],
                sampled_neighbor[nonzero_mask]
            )
    
    return synthetic_perturbed

def validate_count_structure(original_perturbed, synthetic_perturbed):
    """
    Validate that the synthetic data preserves count structure and sparsity.
    
    Parameters:
    -----------
    original_perturbed : np.ndarray
        Original perturbed population data
    synthetic_perturbed : np.ndarray
        Generated synthetic perturbed data
    
    Returns:
    --------
    dict
        Dictionary containing validation metrics
    """
    metrics = {
        'original_sparsity': (original_perturbed == 0).mean(),
        'synthetic_sparsity': (synthetic_perturbed == 0).mean(),
        'original_integer': np.all(np.equal(np.mod(original_perturbed, 1), 0)),
        'synthetic_integer': np.all(np.equal(np.mod(synthetic_perturbed, 1), 0)),
        'mean_difference': np.mean(np.abs(
            np.mean(original_perturbed, axis=0) - 
            np.mean(synthetic_perturbed, axis=0)
        )),
        'unique_values': {
            'original': len(np.unique(original_perturbed)),
            'synthetic': len(np.unique(synthetic_perturbed))
        }
    }
    
    return metrics
# Generate synthetic perturbed samples

synthetic_perturbed = generate_count_preserving_perturbations(
    control_data, 
    perturbed_data
)


# Validate results
metrics = validate_count_structure(perturbed_data, synthetic_perturbed)
print("Validation metrics:")
for metric, value in metrics.items():
    print(f"{metric}: {value}")

    # Validate results
metrics = validate_synthetic_data(perturbed_data, synthetic_perturbed)
print("Validation metrics:")
for metric, value in metrics.items():
    print(f"{metric}: {value}")

Processing cell 0/10691
Processing cell 1000/10691
Processing cell 2000/10691
Processing cell 3000/10691
Processing cell 4000/10691
Processing cell 5000/10691
Processing cell 6000/10691
Processing cell 7000/10691
Processing cell 8000/10691
Processing cell 9000/10691
Processing cell 10000/10691
Validation metrics:
original_sparsity: 0.6219567477826141
synthetic_sparsity: 0.7712065228809764
original_integer: True
synthetic_integer: True
mean_difference: 0.47022226477160395
unique_values: {'original': 527, 'synthetic': 499}
Validation metrics:
mean_difference_max: 34.66173278181094
mean_difference_avg: 0.47022226477160395
covariance_difference_max: 14225.801850900989
covariance_difference_avg: 0.35706564419917314


In [37]:
synthetic_perturbed_sparsity_adata = control_adata.copy()
synthetic_perturbed_sparsity_adata.X = synthetic_perturbed


In [48]:
print(synthetic_perturbed_sparsity_adata.X.var(axis=0))
print(control_adata.X.toarray().var(axis=0))
from omnicell.evaluation.utils import get_eval

[0.02808893 0.03906319 1.07578913 ... 0.05606958 0.07404216 0.05969178]
[0.05651246 0.10574827 0.33283934 ... 0.09893356 0.10239583 0.14217657]


In [40]:

pred_DEGs_df = get_DEGs(control_adata, synthetic_perturbed_sparsity_adata)
DEGs_overlaps = get_DEGs_overlaps(true_DEGs_df, pred_DEGs_df, [100,50,20], 0.05, None)

DEGs_overlaps



  utils.warn_names_duplicates("obs")


{'Overlap_in_top_3677_DEGs': 1001,
 'Overlap_in_top_100_DEGs': 100,
 'Overlap_in_top_50_DEGs': 47,
 'Overlap_in_top_20_DEGs': 19,
 'Jaccard': 0.13524590163934427}

In [49]:
get_eval(control_adata, gt_adata, synthetic_perturbed_sparsity_adata, true_DEGs_df, [100,50,20], 0.05, None)


{'all_genes_mean_sub_diff_R': 0.9978958581455893,
 'all_genes_mean_sub_diff_R2': 0.9957961437041222,
 'all_genes_mean_sub_diff_MSE': 1.337547250628481,
 'all_genes_mean_fold_diff_R': 0.9864283683293209,
 'all_genes_mean_fold_diff_R2': 0.9730409258448464,
 'all_genes_mean_fold_diff_MSE': 0.9764326740817398,
 'all_genes_mean_R': 0.99861070275701,
 'all_genes_mean_R2': 0.9972233356608494,
 'all_genes_mean_MSE': 1.3375471037723647,
 'all_genes_var_R': 0.9972929237480384,
 'all_genes_var_R2': 0.9945931757579106,
 'all_genes_var_MSE': 21991.493741236867,
 'all_genes_corr_mtx_R': 0.7016891595993637,
 'all_genes_corr_mtx_R2': 0.49236767669926135,
 'all_genes_corr_mtx_MSE': 0.010289543702280723,
 'all_genes_cov_mtx_R': 0.6086382565513448,
 'all_genes_cov_mtx_R2': 0.37044052733786065,
 'all_genes_cov_mtx_MSE': 79.31398266675811,
 'Top_3677_DEGs_sub_diff_R': 0.9982916061354618,
 'Top_3677_DEGs_sub_diff_R2': 0.99658613088052,
 'Top_3677_DEGs_sub_diff_MSE': 3.0295415055729737,
 'Top_3677_DEGs_fold_

In [50]:

get_eval(control_adata, gt_adata, synthetic_perturbed_adata, true_DEGs_df, [100,50,20], 0.05, None)


{'all_genes_mean_sub_diff_R': 0.9999999999999994,
 'all_genes_mean_sub_diff_R2': 0.9999999999999989,
 'all_genes_mean_sub_diff_MSE': 6.506598652215286e-14,
 'all_genes_mean_fold_diff_R': 0.999999999999999,
 'all_genes_mean_fold_diff_R2': 0.999999999999998,
 'all_genes_mean_fold_diff_MSE': 4.9403403383518756e-15,
 'all_genes_mean_R': 0.9999999999999998,
 'all_genes_mean_R2': 0.9999999999999996,
 'all_genes_mean_MSE': 1.7246822772324387e-14,
 'all_genes_var_R': 0.9999999999999997,
 'all_genes_var_R2': 0.9999999999999993,
 'all_genes_var_MSE': 0.3735949021463976,
 'all_genes_corr_mtx_R': 1.0,
 'all_genes_corr_mtx_R2': 1.0,
 'all_genes_corr_mtx_MSE': 6.732789410295028e-22,
 'all_genes_cov_mtx_R': 1.0,
 'all_genes_cov_mtx_R2': 1.0,
 'all_genes_cov_mtx_MSE': 7.168730001909094e-19,
 'Top_3677_DEGs_sub_diff_R': 0.9999999999999996,
 'Top_3677_DEGs_sub_diff_R2': 0.9999999999999991,
 'Top_3677_DEGs_sub_diff_MSE': 1.5138807656079068e-13,
 'Top_3677_DEGs_fold_diff_R': 0.9999999999999991,
 'Top_3677