# Job Dependency Chain on HPC

### Step 1. Get the parent JobID

In [None]:
jid=$(sbatch --parsable --array=... my_array_script.sh)

	•	--parsable makes sbatch print only one token: the parent JobID of the array job (e.g. 123456).  
	•	All array tasks will be named 123456_1, 123456_2, …, but you only need the parent ID to set dependencies.
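The same submission can also be driven from Python, for example if the rest of a pipeline is written in Python. This is a minimal sketch that assumes sbatch is on PATH; parse_parsable and submit_array are hypothetical helper names, not part of Slurm. On multi-cluster installations --parsable may append ;<cluster> to the JobID, so the parser keeps only the part before the first semicolon.

```python
import subprocess


def parse_parsable(output: str) -> str:
    """Extract the JobID from `sbatch --parsable` output ("<jobid>" or "<jobid>;<cluster>")."""
    return output.strip().split(";")[0]


def submit_array(script: str, array_spec: str) -> str:
    """Submit an array job and return its parent JobID.

    Python equivalent of: jid=$(sbatch --parsable --array=<array_spec> <script>)
    Requires sbatch on PATH, i.e. it must run on a Slurm login node.
    """
    result = subprocess.run(
        ["sbatch", "--parsable", f"--array={array_spec}", script],
        check=True,           # raise CalledProcessError if sbatch rejects the job
        capture_output=True,
        text=True,
    )
    return parse_parsable(result.stdout)


# Parsing examples only; nothing is submitted here
print(parse_parsable("123456\n"))           # 123456
print(parse_parsable("123456;cluster1\n"))  # 123456
```

On a login node, jid = submit_array("my_array_script.sh", "1-200") would return the parent JobID as a string (1-200 is only an example range).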


### Step 2. Submit the next job with a dependency

In [None]:
sbatch --dependency=afterok:${jid} next_step.sh

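	•	afterok:<JobID> tells Slurm: “start this job only after the referenced job (for an array, every one of its tasks) has finished successfully, i.e. with exit code 0 and state COMPLETED.”
	•	If any array task fails (FAILED, CANCELLED, TIMEOUT…), the dependency can never be satisfied. By default the next job then stays pending with reason DependencyNeverSatisfied until you cancel it, unless the cluster is configured to cancel such jobs automatically or you submit with --kill-on-invalid-dep=yes.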
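A job can also wait on several parents at once: sbatch accepts multiple JobIDs after one dependency type, separated by colons (afterok:111:222 waits until both have succeeded). The hypothetical helper below builds that argument as a list, so "no dependency" simply contributes no argument; it is a sketch that only accepts the four common types after, afterany, afterok and afternotok.

```python
from __future__ import annotations

VALID_TYPES = {"after", "afterany", "afterok", "afternotok"}


def dependency_args(dep_type: str, job_ids: list[str]) -> list[str]:
    """Return the sbatch dependency argument(s) for waiting on job_ids.

    dependency_args("afterok", ["111", "222"]) -> ["--dependency=afterok:111:222"]
    An empty job_ids list yields [], so the result can always be spliced into a command.
    """
    if dep_type not in VALID_TYPES:
        raise ValueError(f"unsupported dependency type: {dep_type}")
    if not job_ids:
        return []
    return [f"--dependency={dep_type}:" + ":".join(job_ids)]


# Same command as Step 2, built as an argument list
cmd = ["sbatch", "--parsable", *dependency_args("afterok", ["123456"]), "next_step.sh"]
print(" ".join(cmd))  # sbatch --parsable --dependency=afterok:123456 next_step.sh
```

afterany is satisfied once the parents have ended in any state, and afternotok only if they failed, which suits clean-up or retry jobs.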


### Step 3. Chain as many steps as you like

Store each parent JobID in a variable and pass it to the next submission:

In [None]:
prev=""                              # JobID of the previous step (empty for step 1)
for step in 1 2 3; do
    dep=()                           # the first step has no dependency
    [[ -n $prev ]] && dep=(--dependency=afterok:"$prev")
    jid=$(sbatch --parsable "${dep[@]}" "step_${step}.sh")
    prev=$jid
done

Slurm now enforces a dependency chain (aka chained jobs / workflow):  
step 1 → step 2 → step 3, automatically handling all waiting and ordering.
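The loop above can also be written in Python. In this hypothetical sketch the actual submission is passed in as a function, so the chaining logic can be dry-run (and tested) without a Slurm cluster; on a real login node you would pass run_sbatch, which calls sbatch and returns the JobID printed by --parsable.

```python
import subprocess
from typing import Callable, List, Optional


def run_sbatch(cmd: List[str]) -> str:
    """Run an sbatch command line and return the JobID it prints (needs --parsable)."""
    out = subprocess.run(cmd, check=True, capture_output=True, text=True).stdout
    return out.strip().split(";")[0]


def submit_chain(scripts: List[str], submit: Callable[[List[str]], str]) -> List[str]:
    """Submit scripts so that each one starts only after the previous one succeeded."""
    job_ids: List[str] = []
    prev: Optional[str] = None
    for script in scripts:
        cmd = ["sbatch", "--parsable"]
        if prev is not None:
            cmd.append(f"--dependency=afterok:{prev}")
        cmd.append(script)
        prev = submit(cmd)
        job_ids.append(prev)
    return job_ids


# Dry run: a fake submitter that records each command and hands out increasing IDs
submitted: List[List[str]] = []


def fake_submit(cmd: List[str]) -> str:
    submitted.append(cmd)
    return str(1000 + len(submitted))


ids = submit_chain([f"step_{i}.sh" for i in (1, 2, 3)], fake_submit)
for cmd in submitted:
    print(" ".join(cmd))
# sbatch --parsable step_1.sh
# sbatch --parsable --dependency=afterok:1001 step_2.sh
# sbatch --parsable --dependency=afterok:1002 step_3.sh
```

Keeping submit as a parameter separates "what depends on what" from "how jobs are submitted", which makes it easy to inspect the plan before touching the queue.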

## Here is an example:

Recently, I have been running a Wilcoxon test on a very large dataset, which did not finish even after 30 days. Therefore, I decided to split the entire dataset into smaller subsets and run the Wilcoxon tests on them in parallel. However, loading the dataset requires a node with 750 GB of memory, and such nodes are limited. Hence, I first need to perform the splitting step on a large-memory node and save the resulting subsets. After that, I can request hundreds of regular-memory nodes and run many Wilcoxon tests in parallel.

Here, however, I encountered another problem. Each subset is about 6 GB, and there are around 16,000 subsets in total (roughly 96 TB), far more than the available storage. Therefore, I decided to process them in rounds of 200 subsets. In each round, I extract 200 subsets from the original dataset, run the Wilcoxon tests on them, and then delete those 200 subsets to free up storage before proceeding to the next round.

Even with only 200 subsets per round, there are still more than 80 rounds in total. Each round requires submitting a split job to generate and save its 200 subsets, waiting for it to finish, and only then submitting a Wilcoxon job to analyze them. Doing this by hand for every round is very time-consuming.
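The figures above explain why the work has to be done in rounds. A quick back-of-the-envelope check, using the approximate numbers quoted in the text (about 16,000 subsets of about 6 GB each, 200 per round) and decimal units:

```python
import math

n_subsets = 16_000   # approximate count quoted in the text
subset_gb = 6        # approximate size of one subset (GB)
per_round = 200      # subsets written and analysed per round

total_tb = n_subsets * subset_gb / 1000       # all subsets on disk at once
round_tb = per_round * subset_gb / 1000       # peak footprint of one round
n_rounds = math.ceil(n_subsets / per_round)   # rounds needed to cover every subset

print(f"all at once: {total_tb:.0f} TB, per round: {round_tb:.1f} TB, rounds: {n_rounds}")
# all at once: 96 TB, per round: 1.2 TB, rounds: 80
```

With even slightly more than 16,000 subsets the count rounds up to 81, consistent with "more than 80 rounds".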

If I simply submit all the jobs to Slurm at once, there is another problem: the next split job might start before the previous Wilcoxon job finishes, and it would delete subsets that unfinished Wilcoxon tasks are still using. So all 200 Wilcoxon tests of a round must be completely finished before those subsets are deleted and the next batch is created.

To achieve this, I learned about job dependency chains from GPT. A dependency chain links multiple Slurm jobs together so that each step starts automatically, but only after the previous one has completed successfully.

### The driver script (huang_master.sh) submits and manages multiple rounds of data-processing jobs on the cluster in sequence.

In [None]:
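#!/bin/bash
# Driver: submit several rounds of (split -> 200-task Wilcoxon array), chained with dependencies.
# Run it on a login node; it only submits jobs and then exits.

data_dir=/dartfs/rc/lab/S/Szhao/perturbation_prediction/SeqExpDesign/vcc/data
cd "$data_dir" || exit 1

split_script=$data_dir/huang_split_qiruiz_auto.sh
wilcoxon_script=$data_dir/huang_wilcoxon_qiruiz.sh

batch=200        # subsets per round; must match batch_size in huang_parallel_split_adata.py
prev_dep=""      # JobID of the previous round's Wilcoxon array (empty for round 1)

for x in {1..5}; do
  # Round 1 starts right away; every later split waits for the whole previous Wilcoxon array
  if [[ -n "$prev_dep" ]]; then
      dep_opt="--dependency=afterok:${prev_dep}"
  else
      dep_opt=""
  fi

  # $dep_opt is left unquoted on purpose: when it is empty it adds no argument
  jid_split=$(sbatch --parsable $dep_opt \
                     --export=IDX=$x "$split_script") || exit 1
  echo "[+] Split $x  →  JobID $jid_split"

  # 1-based subset indices handled in this round
  beg=$((batch*(x-1)+1))
  end=$((batch*x))

  # This round's Wilcoxon tasks start only if its split job succeeded
  jid_wil=$(sbatch --parsable --array=${beg}-${end} \
                   --dependency=afterok:${jid_split} \
                   "$wilcoxon_script") || exit 1
  echo "    ↳ Wilcoxon ${beg}-${end}  →  JobID $jid_wil"

  prev_dep=$jid_wil
done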
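To check what the driver will submit before touching the queue, the same loop can be dry-run. The sketch below is a hypothetical Python port (plan_rounds is not part of the original scripts) that only builds the sbatch command lines, with placeholders for JobIDs that do not exist yet. It also exposes the dependency type used between rounds: with afterok, one failed or timed-out Wilcoxon task stops every later round, whereas afterany would let the next split start once all tasks have ended, at the cost of deleting the failed subsets without retrying them.

```python
from typing import List, Optional


def plan_rounds(n_rounds: int, batch: int = 200,
                split_script: str = "huang_split_qiruiz_auto.sh",
                wilcoxon_script: str = "huang_wilcoxon_qiruiz.sh",
                between_rounds: str = "afterok") -> List[List[str]]:
    """Return the sbatch command lines the driver would run, without submitting anything.

    Real JobIDs exist only after submission, so placeholders such as <split_1>
    and <wil_1> stand in for them.
    """
    cmds: List[List[str]] = []
    prev_wil: Optional[str] = None
    for x in range(1, n_rounds + 1):
        split_cmd = ["sbatch", "--parsable"]
        if prev_wil is not None:
            # The next split waits for the previous round's whole Wilcoxon array
            split_cmd.append(f"--dependency={between_rounds}:{prev_wil}")
        split_cmd += [f"--export=IDX={x}", split_script]
        cmds.append(split_cmd)

        beg, end = batch * (x - 1) + 1, batch * x
        # This round's Wilcoxon tasks may start only if its split job succeeded
        cmds.append(["sbatch", "--parsable", f"--array={beg}-{end}",
                     f"--dependency=afterok:<split_{x}>", wilcoxon_script])
        prev_wil = f"<wil_{x}>"
    return cmds


for cmd in plan_rounds(2):
    print(" ".join(cmd))
```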

### Here is the split_script:

/dartfs/rc/lab/S/Szhao/perturbation_prediction/SeqExpDesign/vcc/data/huang_split_qiruiz_auto.sh

In [None]:
#!/bin/bash

#SBATCH --job-name=zqr_huang_split
#SBATCH --output=/dartfs/rc/lab/S/Szhao/qiruiz/perturb-predict-apply/task/result/huang_split_%A_%x_%j.out
#SBATCH --time=06:00:00             
#SBATCH --partition=standard
#SBATCH --account=nccc
#SBATCH --ntasks=1                     
#SBATCH --cpus-per-task=1  
#SBATCH --mem=750G 
#SBATCH --mail-user=f0070pp@dartmouth.edu
#SBATCH --mail-type=BEGIN,END,FAIL          

idx=${IDX:-$1}
if [[ -z "$idx" ]]; then
  echo "ERROR: IDX (split index) not provided; exiting."
  exit 1
fi

echo "$SLURM_JOB_ID  starting split $idx  at $(date) on $(hostname)"

source /optnfs/common/miniconda3/etc/profile.d/conda.sh
conda activate gears

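# Delete the previous round's subsets before writing new ones. In every round after the
# first, this job only starts once the previous Wilcoxon array has finished successfully.
rm -rf /dartfs/rc/lab/S/Szhao/perturbation_prediction/SeqExpDesign/vcc/data/sc/Huang-HCT116/*

cd /dartfs/rc/lab/S/Szhao/perturbation_prediction/SeqExpDesign/vcc/data || exit 1
# Propagate a Python failure as the job's exit code so that afterok dependencies see it
python -u huang_parallel_split_adata.py "$idx" || exit 1

echo "Split $idx completed"
echo "Finished execution at: $(date)"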
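One detail matters for afterok: Slurm takes a batch job's exit code from the script, and a bash script exits with the status of the last command it ran. This script, like the Wilcoxon script below, ends with echo, so unless a failure is propagated explicitly (for example python ... || exit 1), a crashed Python step still ends as COMPLETED and the dependent jobs start anyway. The sketch below demonstrates the shell behaviour locally with Python's subprocess module; it needs bash but not Slurm, and false stands in for a failing python command.

```python
import subprocess


def exit_status(script: str) -> int:
    """Run a bash snippet and return its exit status (what Slurm would record for a job)."""
    return subprocess.run(["bash", "-c", script], capture_output=True).returncode


# The failure of `false` is hidden because the last command (echo) succeeds
hidden = exit_status("false\necho 'Finished execution'")
# The failure is propagated, so the script stops with status 1
propagated = exit_status("false || exit 1\necho 'Finished execution'")

print(hidden, propagated)  # 0 1
```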

### Here is the wilcoxon_script:

/dartfs/rc/lab/S/Szhao/perturbation_prediction/SeqExpDesign/vcc/data/huang_wilcoxon_qiruiz.sh


In [None]:
#!/bin/bash

#SBATCH --job-name=zqr_wilcoxon
#SBATCH --output=/dartfs/rc/lab/S/Szhao/qiruiz/perturb-predict-apply/task/result/huang_wilcoxon_%A_%a.out
#SBATCH --time=02:00:00             
#SBATCH --partition=standard
#SBATCH --account=nccc        
#SBATCH --ntasks=1                     
#SBATCH --cpus-per-task=1  
#SBATCH --mem=24G 
#SBATCH --mail-type=BEGIN,END,FAIL
#SBATCH --mail-user=f0070pp@dartmouth.edu
          

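echo "$SLURM_JOB_ID starting execution $(date) on $(hostname)"

source /software/python-anaconda-2022.05-el8-x86_64/etc/profile.d/conda.sh
conda activate gears_cpu

cd /dartfs/rc/lab/S/Szhao/perturbation_prediction/SeqExpDesign/vcc/data || exit 1

# Each array task analyses one subset, selected by its array task ID.
# Propagate a Python failure as the task's exit code so that afterok dependencies see it.
python -u huang_parallel_huang_wilcoxon.py "${SLURM_ARRAY_TASK_ID}" || exit 1
echo "Finished execution at: $(date)"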

The driver script automatically controls the execution order of two stages:  
	1.	Split stage  
Runs huang_split_qiruiz_auto.sh on a large-memory node to divide the original big dataset into smaller subsets (200 per round).  
	2.	Wilcoxon stage  
Runs huang_wilcoxon_qiruiz.sh on regular-memory nodes to perform Wilcoxon tests in parallel on those 200 subsets.  

Using a Slurm Job Dependency Chain, the script ensures that the jobs run strictly in the correct sequence:

split(1) → wilcoxon(1–200)  
       ↓  
split(2) → wilcoxon(201–400)  
       ↓  
split(3) ...
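The chain only works if split round x writes exactly the subsets that the Wilcoxon tasks of the same round read. The sketch below (hypothetical helper names and 450 made-up perturbation names) checks that the driver's 1-based array range agrees with the 0-based slice perts[start:end] taken in the split Python script, given that Wilcoxon task t loads perts[t - 1], including a shorter final round.

```python
from typing import List, Tuple


def driver_array_range(x: int, batch: int = 200) -> Tuple[int, int]:
    """1-based array range (beg, end) submitted for round x, as in the driver script."""
    return batch * (x - 1) + 1, batch * x


def split_slice(x: int, n_perts: int, batch: int = 200) -> Tuple[int, int]:
    """0-based [start, end) slice of perts written by the split script for round x."""
    return (x - 1) * batch, min(x * batch, n_perts)


def perts_read_by_round(x: int, perts: List[str], batch: int = 200) -> List[str]:
    """Perturbations loaded by the Wilcoxon tasks of round x (task t reads perts[t - 1])."""
    beg, end = driver_array_range(x, batch)
    return [perts[t - 1] for t in range(beg, end + 1) if t <= len(perts)]


perts = [f"GENE{i}" for i in range(1, 451)]   # 450 hypothetical perturbations
for x in (1, 2, 3):
    start, end = split_slice(x, len(perts))
    assert perts[start:end] == perts_read_by_round(x, perts)
print(driver_array_range(3), split_slice(3, len(perts)))  # (401, 600) (400, 450)
```

In this example, tasks 451-600 of the last round have no subset to read, so the Wilcoxon script has to exit cleanly for indices past the end of the list.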

### To submit the jobs, run only the driver script (the first of the three scripts above, saved here as huang_master.sh) from a terminal on the login node:

bash huang_master.sh

The output looks like this:

In [None]:
[f0070pp@discovery-01 ~]$ bash /dartfs/rc/lab/S/Szhao/perturbation_prediction/SeqExpDesign/vcc/data/huang_master.sh  
[+] Split 1  →  JobID 6292666  
    ↳ Wilcoxon 1-200  →  JobID 6292667  
[+] Split 2  →  JobID 6292668  
    ↳ Wilcoxon 201-400  →  JobID 6292669  
[f0070pp@discovery-01 ~]$ squeue -u f0070pp  
             JOBID PARTITION     NAME     USER ST       TIME  NODES NODELIST(REASON)  
 6292669_[201-400]  standard zqr_wilc  f0070pp PD       0:00      1 (Dependency)  
           6292668  standard zqr_huan  f0070pp PD       0:00      1 (Dependency)  
   6292667_[1-200]  standard zqr_wilc  f0070pp PD       0:00      1 (Dependency)  
           6292666  standard zqr_huan  f0070pp PD       0:00      1 (Priority)  
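To monitor the chain from a script rather than by eye, squeue can print a compact, header-free listing (-h drops the header, -o selects fields; %i is the JobID, %T the state and %r the reason). The sketch below is hypothetical: it counts pending jobs by reason, and it parses a sample string modeled on the listing above instead of calling squeue live.

```python
import subprocess
from collections import Counter


def squeue_lines(user: str) -> str:
    """Fetch the user's jobs as 'JOBID|STATE|REASON' lines (needs squeue on PATH)."""
    return subprocess.run(
        ["squeue", "-u", user, "-h", "-o", "%i|%T|%r"],
        check=True, capture_output=True, text=True,
    ).stdout


def count_by_reason(text: str) -> Counter:
    """Count pending jobs by reason, e.g. how many are still waiting on a dependency."""
    counts: Counter = Counter()
    for line in text.splitlines():
        if not line.strip():
            continue
        _job_id, state, reason = line.split("|", 2)
        if state == "PENDING":
            counts[reason] += 1
    return counts


sample = """6292669_[201-400]|PENDING|Dependency
6292668|PENDING|Dependency
6292667_[1-200]|PENDING|Dependency
6292666|PENDING|Priority"""
print(count_by_reason(sample))  # Counter({'Dependency': 3, 'Priority': 1})
```

In the listing above only the first split job (6292666) is waiting for resources (Priority); everything behind it waits on a dependency.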


## Notes:
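	1.	On Discovery, the largest available node has 750 GB of memory, but jobs wait a very long time for it and the node is unstable: sometimes a job freezes halfway without any error. Such a job is only killed when it reaches its time limit (TIMEOUT), and because that is not a successful exit, it leaves every later job in an afterok chain unable to start.  
	2.	The total number of jobs that can be submitted at once through a job dependency chain is also limited. On RCC, I hit a limit of around 1,000 jobs (5 rounds × 200 array tasks). This is probably a scheduler setting rather than a property of dependency chains: either a per-user cap on submitted jobs (array tasks usually count individually) or MaxArraySize, whose default of 1001 rejects array indices above 1000, exactly where round 6 (1001-1200) would begin. Running scontrol show config | grep -i MaxArraySize shows the cluster's value. Even with this limitation, the chain is still much more convenient than submitting each round manually. I haven’t tested the limit on Discovery yet.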
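If the limit turns out to be the cluster's MaxArraySize (Slurm rejects array indices at or above it; the default of 1001 matches 5 rounds × 200 tasks), one workaround, sketched here under that assumption, is to submit every round with the same index range 1-200 and pass the round's starting offset as an environment variable. Each task then computes its global index as OFFSET + SLURM_ARRAY_TASK_ID. OFFSET is a hypothetical variable name, and the Wilcoxon Python script would need to be changed to read it; this does not help if the limit is a per-user cap on queued jobs.

```python
import os
from typing import Mapping, Tuple


def round_submission(x: int, batch: int = 200) -> Tuple[str, int]:
    """Array spec and index offset for round x; the array indices always stay in 1..batch."""
    return f"1-{batch}", batch * (x - 1)


def global_index(env: Mapping[str, str] = os.environ) -> int:
    """1-based position in the perturbation list, computed inside an array task.

    OFFSET is a hypothetical variable that the driver would pass with --export.
    """
    return int(env.get("OFFSET", "0")) + int(env["SLURM_ARRAY_TASK_ID"])


spec, offset = round_submission(81)
print(spec, offset)  # 1-200 16000
# The driver would then submit: sbatch --array=1-200 --export=ALL,OFFSET=16000 huang_wilcoxon_qiruiz.sh
print(global_index({"OFFSET": str(offset), "SLURM_ARRAY_TASK_ID": "7"}))  # 16007
```

Using --export=ALL,OFFSET=... rather than --export=OFFSET=... keeps the default behaviour of propagating the submitting shell's environment to the Wilcoxon tasks.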

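#### Appendix

These appear to be the two Python scripts called by the job scripts: the first cell is huang_parallel_split_adata.py (run by the split job with the round number as its argument), and the second is huang_parallel_huang_wilcoxon.py (run by each Wilcoxon array task with its task ID).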

In [None]:
import os
import sys
# Thread limits for the numerical libraries; these must be set before numpy is imported.
# (The job scripts request --cpus-per-task=1, so values above 1 likely oversubscribe that core.)
os.environ["OMP_NUM_THREADS"] = "11"
os.environ["OPENBLAS_NUM_THREADS"] = "8"
os.environ["MKL_NUM_THREADS"] = "11"
os.environ["VECLIB_MAXIMUM_THREADS"] = "8"
os.environ["NUMEXPR_NUM_THREADS"] = "11"
os.environ["NUMBA_CACHE_DIR"] = '/tmp/numba_cache'  # writable cache dir for numba-compiled functions
import numpy as np
import pandas as pd
import scipy as sp
import scipy.sparse

import pickle

import anndata as ad
import scanpy as sc

import numba
import multiprocessing
# njobs = max(1, multiprocessing.cpu_count())   # alternative: use every visible CPU
njobs = 1  # single thread, matching --cpus-per-task=1 in the job scripts
numba.set_num_threads(njobs)

import gc
from tqdm import tqdm
print(pd.__version__, sc.__version__, ad.__version__)


def summarize(adata):
    N_C, N_G = adata.shape
    N_P = adata.obs['perturbation'].nunique()
    N_P_2 = adata.obs['perturbation'][adata.obs['perturbation'].str.contains('_')].nunique()
    return N_C, N_G, N_P, N_P - N_P_2, N_P_2

def comp_bulk_expressions(adata, key='perturbation'):
    '''
    Compute the per-group mean expression of adata.X, plus per-group standard deviations.

    Cells are grouped by adata.obs[key] and averaged with a plain arithmetic mean; no
    log/exp transform is applied here, so pass log-normalized data to get log-scale means.

    Parameters
    ----------
    adata : anndata.AnnData
        An AnnData object whose .obs contains the column `key`.
    key : str, optional
        The column name to group by, default is 'perturbation'.

    Returns
    -------
    adata_bulk : anndata.AnnData
        One row per group with the mean expression (sparse X) and 'n_cells' in .obs.
    stds : pandas.DataFrame
        Per-group standard deviation of each gene.
    '''
    if isinstance(adata, ad.AnnData):
        df = adata.to_df()
        key_pert = adata.obs[key]        
    else:
        df = adata
        key_pert = key
    obs = adata.obs.drop_duplicates(subset=[key]).set_index(key).sort_values(key)
    obs['n_cells'] = adata.obs.groupby(key, observed=True).size()
    var = adata.var
    # save memory by removing the original adata
    # del adata
    gc.collect()
    df = df.astype(np.float32)  # Ensure consistent data type for calculations
    grouped = df.groupby(key_pert, observed=True)
    sums = grouped.sum().astype(np.float32)
    sizes = grouped.size().astype(np.float32)
    means = sums / sizes.values[:, None]
    stds = grouped.std().astype(np.float32)
    adata_bulk = means  # You can return or use both means and stds as needed
    del df
    gc.collect()
    adata_bulk = ad.AnnData(
        X=sp.sparse.csr_matrix(adata_bulk.values), 
        obs=obs, 
        var=var
        )
    return adata_bulk, stds

def comp_bulk_expressions_batch(adata, key='perturbation', group_size=1000):
    """
    Wrapper function to calculate bulk expressions in smaller groups to save memory.

    Parameters
    ----------
    adata : anndata.AnnData
        An AnnData object containing the data with a column 'key'.
    key : str, optional
        The column name to group by, default is 'perturbation'.
    group_size : int, optional
        Number of perturbation conditions to process in each group, default is 1000.

    Returns
    -------
    adata_bulk : anndata.AnnData
        An AnnData object with the average effect of each perturbation.
    """
    perturbations = np.sort(np.array(adata.obs[key].unique()))
    grouped_bulk = []
    stds = []
    for i in tqdm(range(0, len(perturbations), group_size), desc='Processing perturbations'):
        subset_perturbations = perturbations[i:(i+group_size)]
        sub_adata = adata[adata.obs[key].isin(subset_perturbations)]
        if sub_adata.shape[0] == 0:
            continue
        sub_bulk, std = comp_bulk_expressions(sub_adata, key=key)
        grouped_bulk.append(sub_bulk.copy())        
        stds.append(std)
    # Concatenate all sub-bulk results
    adata_bulk = ad.concat(grouped_bulk, merge='same')
    stds = pd.concat(stds, axis=0)
    adata_bulk.uns['std'] = stds.copy()
    return adata_bulk



datasets = [
    # single perturbation
    "Adamson", "Frangieh",    
    "Replogle-GW-k562", "Replogle-E-k562", "Replogle-E-rpe1",
    "Tian-crispra", "Tian-crispri",
    "Jiang-IFNB", "Jiang-IFNG", "Jiang-INS", "Jiang-TGFB", "Jiang-TNFA",
    "Huang-HCT116", "Huang-HEK293T",

    "Nadig-HEPG2", "Nadig-JURKAT",

    # double perturbation
    "Norman", "Wessels",
]

path_origin = 'origin/'
path_sc = 'sc/';os.makedirs(path_sc, exist_ok=True)
path_bulk = 'bulk/';os.makedirs(path_bulk, exist_ok=True)
path_bulk_log = 'bulk_log/';os.makedirs(path_bulk_log, exist_ok=True)

dataset = "Huang-HCT116"
print(f'Processing {dataset}...')

adata = sc.read_h5ad(f'{path_origin}{dataset}.h5ad')
print(adata)
if dataset == 'Adamson':
    adata.obs.rename({'gene':'perturbation'}, axis=1, inplace=True)
    adata.obs.loc[:,'perturbation'] = adata.obs['perturbation'].astype(str).replace({'CTRL':'control'}).values
    adata = adata[adata.obs['perturbation']!='None']
elif dataset.startswith('Huang'):
    adata.obs.rename({'gene_target':'perturbation'}, axis=1, inplace=True)
    adata.obs.loc[:,'perturbation'] = adata.obs['perturbation'].astype(str).replace({'Non-Targeting':'control'}).values
elif dataset.startswith('Nadig'):
    adata.obs.rename({'gene':'perturbation'}, axis=1, inplace=True)
    adata.obs['perturbation'] = adata.obs['perturbation'].astype(str)
    adata.obs.loc[:,'perturbation'] = adata.obs['perturbation'].astype(str).replace({'non-targeting':'control'}).values
    adata.var.set_index('gene_name', inplace=True)
elif dataset.startswith('Jiang'):
    raise NotImplementedError("Currently not supported for Jiang dataset.")

print(summarize(adata))

# filter out cells
if dataset.startswith('Huang'):
    print('Skipping filtering cells for Huang datasets')
    print('Min genes by counts:', adata.obs['n_genes_by_counts'].min())
else:
    sc.pp.filter_cells(adata, min_genes=100)
print(summarize(adata))

# filter perturbation condition
ncells_pert = adata.obs.groupby('perturbation', observed=True).size()
min_cells = 25 if dataset.startswith('Nadig') else 50
valid_pert = np.array(ncells_pert[ncells_pert >= min_cells].index)
valid_pert = valid_pert[np.isin(valid_pert, adata.var.index)]
# TODO: filter inefficient perturbations
adata = adata[adata.obs['perturbation'].isin(np.append(valid_pert, ['control']))]
print(summarize(adata))


# # filter cells by perturbation quantile effect
# adata = filter_cells_by_pert_effect(adata)
# print(summarize(adata))

if not sp.sparse.issparse(adata.X):
    adata.X = scipy.sparse.csr_matrix(adata.X)

# Keep genes detected in at least 100 control cells, plus every perturbation target gene (valid_pert)
gene_filter = (np.sum(adata[adata.obs['perturbation'] == 'control'].X > 0, axis=0) >= 100) | adata.var.index.isin(valid_pert)
adata = adata[:, gene_filter]

duplicate_var_names = adata.var_names[adata.var_names.duplicated()]
print(f"Duplicate var names: {duplicate_var_names}")
adata = adata[:, ~adata.var_names.duplicated()]
# sc.pp.filter_genes(adata, min_cells=100)
# sc.pp.filter_genes(adata, max_counts=10) # this seems to be too strict
print(summarize(adata))


# Add information
adata.obs.loc[:,'dataset'] = dataset
# TODO: add celltype/pathway information
# if 'pathway' not in adata.obs:
if dataset.startswith('Huang') or dataset.startswith('Nadig'):
    adata.obs.loc[:,'celltype'] = dataset.split('-')[1]    
elif not dataset.startswith('Jiang'):
    dict_cts = {
        'Adamson': 'K562',
        'Replogle-GW-k562': 'K562',
        'Replogle-E-k562': 'K562',
        'Replogle-E-rpe1': 'RPE1',
        'Frangieh':'melanoma',
        "Tian-crispra": 'iPSC', 
        "Tian-crispri": 'iPSC'
        }
    adata.obs.loc[:,'celltype'] = dict_cts[dataset]

    adata.X = scipy.sparse.csr_matrix(adata.X)

print(summarize(adata))


col = adata.obs['perturbation'].astype(str)
perts = sorted(p for p in col.unique() if p != 'control')
# pert_path = "/project/jingshuw/SeqExpDesign/vcc/data/pert_Huang-HCT116.csv"
# pd.Series(perts, name="perturbation").to_csv(pert_path, index=False)
print("The number of perturbations:", len(perts))

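batch_size = 200  # must match the 200-task rounds submitted by the driver script
batch_idx = int(sys.argv[1])  # round number passed in by the split job script (IDX)
start = (batch_idx - 1) * batch_size
end = min(batch_idx * batch_size, len(perts))

batch_perts = perts[start:end]
print(f"Batch {batch_idx}: {len(batch_perts)} perturbations ({start+1}-{end})")

# Output folder sc/<dataset>/ (the split job script empties it, but it must exist)
out_dir = os.path.join(path_sc, dataset)
os.makedirs(out_dir, exist_ok=True)

# One file per perturbation: its cells plus all control cells
for pert in batch_perts:
    adata_split = adata[col.isin([pert, 'control'])].copy()
    print(f"[PAIR] pert={pert} | cells={adata_split.n_obs} | genes={adata_split.n_vars}")
    out_path = os.path.join(out_dir, f'{pert}.h5ad')
    adata_split.write_h5ad(out_path)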

# # adata = sc.read_h5ad(f'{path_sc}{dataset}_overlap_vcc.h5ad')

# ##############################################################################
# #
# # Save the bulk expression data
# #
# ##############################################################################
# # print('Save the bulk expression data')

# # For pseudo bulk aggregation
# # Aggregate counts of adata.X according to perturbation
# # adata_bulk = comp_bulk_expressions_batch(adata, key='perturbation')

# # # calculate the library size
# # adata_bulk.obs['n_counts'] = adata_bulk.X.sum(axis=1)

# # print(adata_bulk)
# # adata_bulk.write_h5ad(f'{path_bulk}{dataset}_overlap_vcc.h5ad')



# ##############################################################################
# #
# # Save the bulk expression data (after log transformation)
# #
# ##############################################################################
# # print('Save the bulk expression data (after log transformation)')


# # For mean aggregation
# sc.pp.normalize_total(adata)
# sc.pp.log1p(adata)

# # # Aggregate counts of adata.X according to perturbation
# # adata_bulk = comp_bulk_expressions_batch(adata, key='perturbation')

# # # calculate the library size
# # adata_bulk.obs['n_counts'] = adata_bulk.X.sum(axis=1)

# # print(adata_bulk)
# # adata_bulk.write_h5ad(f'{path_bulk_log}{dataset}_overlap_vcc.h5ad')



# ##############################################################################
# #
# # Calculate DE genes from single-cell data
# #
# ##############################################################################
# print('Calculate DE genes from single-cell data')

# path_de = f'de/';os.makedirs(path_de, exist_ok=True)
# data_dir = 'sc/'

# def calculate_de_genes_group(adata, rankby_abs, test_name='t-test'):
#     perturbations = adata.obs['perturbation'].unique()
#     tie_correct = 'tie' if test_name == 'wilcoxon' else 't-test'
#     sc.tl.rank_genes_groups(
#         adata, groupby='perturbation', reference='control',
#         n_genes=adata.shape[1], rankby_abs=rankby_abs=='abs',
#         method=test_name, tie_correct=tie_correct=='tie',
#         use_raw=False, n_jobs=njobs)
#     df = pd.DataFrame(adata.uns['rank_genes_groups']['names']).T
#     df_rank = df.apply(lambda x: pd.Series(x.index, index=x.values)[adata.var.index], axis=1).astype(int)
#     df_rank = df_rank[sorted(df_rank.columns)]
#     de_result = {"perturbations": perturbations, "rank_genes_groups": adata.uns['rank_genes_groups'], "df_rank": df_rank}
#     return de_result


# def calculate_de_genes(adata, path_de, dataset, test_name='t-test'):
#     """
#     Calculate DE genes from single-cell data using different statistical tests.
    
#     Parameters
#     ----------
#     adata : anndata.AnnData
#         The input AnnData object containing single-cell data.
#     path_de : str
#         The path where DE results will be saved.
    
#     Returns
#     -------
#     None
#         Results are saved to disk as pickle files.
#     """
#     cell_types = adata.obs['celltype'].unique()
#     per = [p for p in adata.obs['perturbation'].astype(str).unique() if p != 'control']; assert len(per)==1; pert_name = per[0]
#     for rankby_abs in ['abs']: #['abs', 'noabs']
#         path = f'{path_de}{test_name}/{rankby_abs}/{dataset}/';os.makedirs(path, exist_ok=True)
#         assert len(cell_types) == 1
#         de_genes = calculate_de_genes_group(adata, rankby_abs, test_name=test_name)
            
#         with open(f'{path}{pert_name}.pkl', 'wb') as f:
#             pickle.dump(de_genes, f)



# for test_name in ['wilcoxon']:
#     calculate_de_genes(adata, path_de, dataset, test_name=test_name)

In [None]:
import os
import sys
# Thread limits for the numerical libraries; these must be set before numpy is imported.
# (The job scripts request --cpus-per-task=1, so values above 1 likely oversubscribe that core.)
os.environ["OMP_NUM_THREADS"] = "11"
os.environ["OPENBLAS_NUM_THREADS"] = "8"
os.environ["MKL_NUM_THREADS"] = "11"
os.environ["VECLIB_MAXIMUM_THREADS"] = "8"
os.environ["NUMEXPR_NUM_THREADS"] = "11"
os.environ["NUMBA_CACHE_DIR"] = '/tmp/numba_cache'  # writable cache dir for numba-compiled functions
import numpy as np
import pandas as pd
import scipy as sp
import scipy.sparse

import pickle

import anndata as ad
import scanpy as sc

import numba
import multiprocessing
# njobs = max(1, multiprocessing.cpu_count())   # alternative: use every visible CPU
njobs = 1  # single thread, matching --cpus-per-task=1 in the job scripts
numba.set_num_threads(njobs)

import gc
from tqdm import tqdm
print(pd.__version__, sc.__version__, ad.__version__)


def summarize(adata):
    N_C, N_G = adata.shape
    N_P = adata.obs['perturbation'].nunique()
    N_P_2 = adata.obs['perturbation'][adata.obs['perturbation'].str.contains('_')].nunique()
    return N_C, N_G, N_P, N_P - N_P_2, N_P_2

def comp_bulk_expressions(adata, key='perturbation'):
    '''
    Compute the per-group mean expression of adata.X, plus per-group standard deviations.

    Cells are grouped by adata.obs[key] and averaged with a plain arithmetic mean; no
    log/exp transform is applied here, so pass log-normalized data to get log-scale means.

    Parameters
    ----------
    adata : anndata.AnnData
        An AnnData object whose .obs contains the column `key`.
    key : str, optional
        The column name to group by, default is 'perturbation'.

    Returns
    -------
    adata_bulk : anndata.AnnData
        One row per group with the mean expression (sparse X) and 'n_cells' in .obs.
    stds : pandas.DataFrame
        Per-group standard deviation of each gene.
    '''
    if isinstance(adata, ad.AnnData):
        df = adata.to_df()
        key_pert = adata.obs[key]        
    else:
        df = adata
        key_pert = key
    obs = adata.obs.drop_duplicates(subset=[key]).set_index(key).sort_values(key)
    obs['n_cells'] = adata.obs.groupby(key, observed=True).size()
    var = adata.var
    # save memory by removing the original adata
    # del adata
    gc.collect()
    df = df.astype(np.float32)  # Ensure consistent data type for calculations
    grouped = df.groupby(key_pert, observed=True)
    sums = grouped.sum().astype(np.float32)
    sizes = grouped.size().astype(np.float32)
    means = sums / sizes.values[:, None]
    stds = grouped.std().astype(np.float32)
    adata_bulk = means  # You can return or use both means and stds as needed
    del df
    gc.collect()
    adata_bulk = ad.AnnData(
        X=sp.sparse.csr_matrix(adata_bulk.values), 
        obs=obs, 
        var=var
        )
    return adata_bulk, stds

def comp_bulk_expressions_batch(adata, key='perturbation', group_size=1000):
    """
    Wrapper function to calculate bulk expressions in smaller groups to save memory.

    Parameters
    ----------
    adata : anndata.AnnData
        An AnnData object containing the data with a column 'key'.
    key : str, optional
        The column name to group by, default is 'perturbation'.
    group_size : int, optional
        Number of perturbation conditions to process in each group, default is 1000.

    Returns
    -------
    adata_bulk : anndata.AnnData
        An AnnData object with the average effect of each perturbation.
    """
    perturbations = np.sort(np.array(adata.obs[key].unique()))
    grouped_bulk = []
    stds = []
    for i in tqdm(range(0, len(perturbations), group_size), desc='Processing perturbations'):
        subset_perturbations = perturbations[i:(i+group_size)]
        sub_adata = adata[adata.obs[key].isin(subset_perturbations)]
        if sub_adata.shape[0] == 0:
            continue
        sub_bulk, std = comp_bulk_expressions(sub_adata, key=key)
        grouped_bulk.append(sub_bulk.copy())        
        stds.append(std)
    # Concatenate all sub-bulk results
    adata_bulk = ad.concat(grouped_bulk, merge='same')
    stds = pd.concat(stds, axis=0)
    adata_bulk.uns['std'] = stds.copy()
    return adata_bulk



datasets = [
    # single perturbation
    "Adamson", "Frangieh",    
    "Replogle-GW-k562", "Replogle-E-k562", "Replogle-E-rpe1",
    "Tian-crispra", "Tian-crispri",
    "Jiang-IFNB", "Jiang-IFNG", "Jiang-INS", "Jiang-TGFB", "Jiang-TNFA",
    "Huang-HCT116", "Huang-HEK293T",

    "Nadig-HEPG2", "Nadig-JURKAT",

    # double perturbation
    "Norman", "Wessels",
]

path_origin = 'origin/'
path_sc = 'sc/';os.makedirs(path_sc, exist_ok=True)
path_bulk = 'bulk/';os.makedirs(path_bulk, exist_ok=True)
path_bulk_log = 'bulk_log/';os.makedirs(path_bulk_log, exist_ok=True)

dataset = "Huang-HCT116"
print(f'Processing {dataset}...')

# adata = sc.read_h5ad(f'{path_origin}{dataset}.h5ad')
# print(adata)
# if dataset == 'Adamson':
#     adata.obs.rename({'gene':'perturbation'}, axis=1, inplace=True)
#     adata.obs.loc[:,'perturbation'] = adata.obs['perturbation'].astype(str).replace({'CTRL':'control'}).values
#     adata = adata[adata.obs['perturbation']!='None']
# elif dataset.startswith('Huang'):
#     adata.obs.rename({'gene_target':'perturbation'}, axis=1, inplace=True)
#     adata.obs.loc[:,'perturbation'] = adata.obs['perturbation'].astype(str).replace({'Non-Targeting':'control'}).values
# elif dataset.startswith('Nadig'):
#     adata.obs.rename({'gene':'perturbation'}, axis=1, inplace=True)
#     adata.obs['perturbation'] = adata.obs['perturbation'].astype(str)
#     adata.obs.loc[:,'perturbation'] = adata.obs['perturbation'].astype(str).replace({'non-targeting':'control'}).values
#     adata.var.set_index('gene_name', inplace=True)
# elif dataset.startswith('Jiang'):
#     raise NotImplementedError("Currently not supported for Jiang dataset.")

# print(summarize(adata))

# # filter out cells
# if dataset.startswith('Huang'):
#     print('Skipping filtering cells for Huang datasets')
#     print('Min genes by counts:', adata.obs['n_genes_by_counts'].min())
# else:
#     sc.pp.filter_cells(adata, min_genes=100)
# print(summarize(adata))

# # filter perturbation condition
# ncells_pert = adata.obs.groupby('perturbation', observed=True).size()
# min_cells = 25 if dataset.startswith('Nadig') else 50
# valid_pert = np.array(ncells_pert[ncells_pert >= min_cells].index)
# valid_pert = valid_pert[np.isin(valid_pert, adata.var.index)]
# # TODO: filter inefficient perturbations
# adata = adata[adata.obs['perturbation'].isin(np.append(valid_pert, ['control']))]
# print(summarize(adata))


# # # filter cells by perturbation quantile effect
# # adata = filter_cells_by_pert_effect(adata)
# # print(summarize(adata))

# if not sp.sparse.issparse(adata.X):
#     adata.X = scipy.sparse.csr_matrix(adata.X)

# # Filter genes with less than 100 cells in the control groups, but keep those in adata.obs.index
# gene_filter = (np.sum(adata[adata.obs['perturbation'] == 'control'].X > 0, axis=0) >= 100) | adata.var.index.isin(valid_pert)
# adata = adata[:, gene_filter]

# duplicate_var_names = adata.var_names[adata.var_names.duplicated()]
# print(f"Duplicate var names: {duplicate_var_names}")
# adata = adata[:, ~adata.var_names.duplicated()]
# # sc.pp.filter_genes(adata, min_cells=100)
# # sc.pp.filter_genes(adata, max_counts=10) # this seems to be too strict
# print(summarize(adata))


# # Add information
# adata.obs.loc[:,'dataset'] = dataset
# # TODO: add celltype/pathway information
# # if 'pathway' not in adata.obs:
# if dataset.startswith('Huang') or dataset.startswith('Nadig'):
#     adata.obs.loc[:,'celltype'] = dataset.split('-')[1]    
# elif not dataset.startswith('Jiang'):
#     dict_cts = {
#         'Adamson': 'K562',
#         'Replogle-GW-k562': 'K562',
#         'Replogle-E-k562': 'K562',
#         'Replogle-E-rpe1': 'RPE1',
#         'Frangieh':'melanoma',
#         "Tian-crispra": 'iPSC', 
#         "Tian-crispri": 'iPSC'
#         }
#     adata.obs.loc[:,'celltype'] = dict_cts[dataset]

#     adata.X = scipy.sparse.csr_matrix(adata.X)

# print(summarize(adata))

# col = adata.obs['perturbation'].astype(str)
# perts = sorted(p for p in col.unique() if p != 'control')
# print("The number of perturbations:", len(perts))

# batch_size = 100
# batch_idx = int(sys.argv[1])
# start = (batch_idx - 1) * batch_size
# end = min(batch_idx * batch_size, len(perts))

# batch_perts = perts[start:end]
# print(f"Batch {batch_idx}: {len(batch_perts)} perturbations ({start+1}-{end})")

# for pert in batch_perts:
#     adata_split = adata[col.isin([pert, 'control'])].copy()
#     print(f"[PAIR] pert={pert} | cells={adata_split.n_obs} | genes={adata_split.n_vars}")
#     out_path = os.path.join(path_sc, dataset, f'{pert}.h5ad')
#     adata_split.write_h5ad(out_path)

# adata = sc.read_h5ad(f'{path_sc}{dataset}_overlap_vcc.h5ad')

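# Ordered list of perturbations, one per array task (same order as perts in the split script)
pert_path = f"./pert_{dataset}.csv"
perts = pd.read_csv(pert_path)["perturbation"].tolist()
batch_idx = int(sys.argv[1])  # SLURM_ARRAY_TASK_ID, 1-based
if batch_idx > len(perts):
    # The last round's array range can extend past the end of the list
    print(f"Batch {batch_idx}: no perturbation with this index ({len(perts)} in total); nothing to do")
    sys.exit(0)
pert = perts[batch_idx - 1]
print(f"Batch {batch_idx}: loading perturbation '{pert}'")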

adata_path = os.path.join(path_sc, dataset, f"{pert}.h5ad")

print(f"Batch {batch_idx}: loading {adata_path}")
adata = sc.read_h5ad(adata_path)
print(f"[LOAD] pert={pert} | cells={adata.n_obs} | genes={adata.n_vars}")

##############################################################################
#
# Save the bulk expression data
#
##############################################################################
# print('Save the bulk expression data')

# For pseudo bulk aggregation
# Aggregate counts of adata.X according to perturbation
# adata_bulk = comp_bulk_expressions_batch(adata, key='perturbation')

# # calculate the library size
# adata_bulk.obs['n_counts'] = adata_bulk.X.sum(axis=1)

# print(adata_bulk)
# adata_bulk.write_h5ad(f'{path_bulk}{dataset}_overlap_vcc.h5ad')



##############################################################################
#
# Save the bulk expression data (after log transformation)
#
##############################################################################
# print('Save the bulk expression data (after log transformation)')


# Normalize each cell to the same total count, then log1p-transform (input for the DE test below)
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

# # Aggregate counts of adata.X according to perturbation
# adata_bulk = comp_bulk_expressions_batch(adata, key='perturbation')

# # calculate the library size
# adata_bulk.obs['n_counts'] = adata_bulk.X.sum(axis=1)

# print(adata_bulk)
# adata_bulk.write_h5ad(f'{path_bulk_log}{dataset}_overlap_vcc.h5ad')



##############################################################################
#
# Calculate DE genes from single-cell data
#
##############################################################################
print('Calculate DE genes from single-cell data')

path_de = f'de/';os.makedirs(path_de, exist_ok=True)
data_dir = 'sc/'

def calculate_de_genes_group(adata, rankby_abs, test_name='t-test'):
    perturbations = adata.obs['perturbation'].unique()
    sc.tl.rank_genes_groups(
        adata, groupby='perturbation', reference='control',
        n_genes=adata.shape[1], rankby_abs=(rankby_abs == 'abs'),
        method=test_name,
        tie_correct=(test_name == 'wilcoxon'),  # tie correction is only used by the Wilcoxon test
        use_raw=False, n_jobs=njobs)
    # 'names' has one field per group, ordered from the top-ranked gene down;
    # df_rank[group, gene] is that gene's position in the ranking (0 = top)
    df = pd.DataFrame(adata.uns['rank_genes_groups']['names']).T
    df_rank = df.apply(lambda x: pd.Series(x.index, index=x.values)[adata.var.index], axis=1).astype(int)
    df_rank = df_rank[sorted(df_rank.columns)]
    de_result = {"perturbations": perturbations, "rank_genes_groups": adata.uns['rank_genes_groups'], "df_rank": df_rank}
    return de_result


def calculate_de_genes(adata, path_de, dataset, test_name='t-test'):
    """
    Calculate DE genes from single-cell data using different statistical tests.
    
    Parameters
    ----------
    adata : anndata.AnnData
        The input AnnData object containing single-cell data.
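    path_de : str
        The path where DE results will be saved.
    dataset : str
        Dataset name, used as a sub-folder of the output path.
    test_name : str, optional
        Statistical test passed to sc.tl.rank_genes_groups (e.g. 'wilcoxon'), default is 't-test'.

    Returns
    -------
    None
        Results are saved to disk as pickle files: {path_de}{test_name}/{rankby_abs}/{dataset}/{perturbation}.pkl
    """
    cell_types = adata.obs['celltype'].unique()
    # Each split file holds exactly one perturbation plus the control cells
    per = [p for p in adata.obs['perturbation'].astype(str).unique() if p != 'control']
    assert len(per) == 1
    pert_name = per[0]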
    for rankby_abs in ['abs']: #['abs', 'noabs']
        path = f'{path_de}{test_name}/{rankby_abs}/{dataset}/';os.makedirs(path, exist_ok=True)
        assert len(cell_types) == 1
        de_genes = calculate_de_genes_group(adata, rankby_abs, test_name=test_name)
            
        with open(f'{path}{pert_name}.pkl', 'wb') as f:
            pickle.dump(de_genes, f)



for test_name in ['wilcoxon']:
    calculate_de_genes(adata, path_de, dataset, test_name=test_name)