diff --git a/pertpy/tools/_perturbation_space/_perturbation_space.py b/pertpy/tools/_perturbation_space/_perturbation_space.py index 026fdea0..4336623d 100644 --- a/pertpy/tools/_perturbation_space/_perturbation_space.py +++ b/pertpy/tools/_perturbation_space/_perturbation_space.py @@ -233,11 +233,15 @@ def add( new_perturbation.obs = new_obs for key, value in data["layers"].items(): - key_name = key.removesuffix("_control_diff") if key.endswith("_control_diff") else key + key_name = ( + key.removesuffix("_control_diff") if isinstance(key, str) and key.endswith("_control_diff") else key + ) new_perturbation.layers[key_name] = value for key, value in data["embeddings"].items(): - key_name = key.removesuffix("_control_diff") if key.endswith("_control_diff") else key + key_name = ( + key.removesuffix("_control_diff") if isinstance(key, str) and key.endswith("_control_diff") else key + ) new_perturbation.obsm[key_name] = value new_perturbation.obs[target_col] = new_perturbation.obs_names.astype("category") @@ -336,11 +340,15 @@ def subtract( new_perturbation.obs = new_obs for key, value in data["layers"].items(): - key_name = key.removesuffix("_control_diff") if key.endswith("_control_diff") else key + key_name = ( + key.removesuffix("_control_diff") if isinstance(key, str) and key.endswith("_control_diff") else key + ) new_perturbation.layers[key_name] = value for key, value in data["embeddings"].items(): - key_name = key.removesuffix("_control_diff") if key.endswith("_control_diff") else key + key_name = ( + key.removesuffix("_control_diff") if isinstance(key, str) and key.endswith("_control_diff") else key + ) new_perturbation.obsm[key_name] = value new_perturbation.obs[target_col] = new_perturbation.obs_names.astype("category") diff --git a/pertpy/tools/_perturbation_space/_simple.py b/pertpy/tools/_perturbation_space/_simple.py index 6db0374a..5f8f3d95 100644 --- a/pertpy/tools/_perturbation_space/_simple.py +++ b/pertpy/tools/_perturbation_space/_simple.py @@ -171,6 +171,9 @@ def compute( if mode in ps_adata.layers: ps_adata.X = ps_adata.layers[mode] + if None in ps_adata.layers: + del ps_adata.layers[None] + missing_cols = [col for col in original_obs.columns if col not in ps_adata.obs.columns] new_cols_data = {}