diff --git a/pydeseq2/dds.py b/pydeseq2/dds.py index 1ba6c279..a9b0ea12 100644 --- a/pydeseq2/dds.py +++ b/pydeseq2/dds.py @@ -38,7 +38,6 @@ # Ignore AnnData's FutureWarning about implicit data conversion. warnings.simplefilter("ignore", FutureWarning) - # TODO: implement a quiet (non-verbose) mode ? @@ -764,32 +763,32 @@ def _replace_outliers(self) -> None: idx = self.layers["cooks"] > cooks_cutoff self.varm["replaced"] = idx.any(axis=0) - # Compute replacement counts: trimmed means * size_factors - self.counts_to_refit = self[:, self.varm["replaced"]].copy() - - trim_base_mean = pd.DataFrame( - cast( - np.ndarray, - trimmed_mean( - self.counts_to_refit.X / self.obsm["size_factors"][:, None], - trim=0.2, - axis=0, + if sum(self.varm["replaced"] > 0): + # Compute replacement counts: trimmed means * size_factors + self.counts_to_refit = self[:, self.varm["replaced"]].copy() + + trim_base_mean = pd.DataFrame( + cast( + np.ndarray, + trimmed_mean( + self.counts_to_refit.X / self.obsm["size_factors"][:, None], + trim=0.2, + axis=0, + ), ), - ), - index=self.counts_to_refit.var_names, - ) - - replacement_counts = ( - pd.DataFrame( - trim_base_mean.values * self.obsm["size_factors"], index=self.counts_to_refit.var_names, - columns=self.counts_to_refit.obs_names, ) - .astype(int) - .T - ) - if sum(self.varm["replaced"] > 0): + replacement_counts = ( + pd.DataFrame( + trim_base_mean.values * self.obsm["size_factors"], + index=self.counts_to_refit.var_names, + columns=self.counts_to_refit.obs_names, + ) + .astype(int) + .T + ) + self.counts_to_refit.X[ self.obsm["replaceable"][:, None] & idx[:, self.varm["replaced"]] ] = replacement_counts.values[ @@ -804,15 +803,17 @@ def _refit_without_outliers( self.refit_cooks ), "Trying to refit Cooks outliers but the 'refit_cooks' flag is set to False" - # Check that replaced counts are available. If not, compute them. - if not hasattr(self, "counts_to_refit"): + # Check that _replace_outliers() was previously run. + if "replaced" not in self.varm: self._replace_outliers() # Only refit genes for which replacing outliers hasn't resulted in all zeroes new_all_zeroes = (self.counts_to_refit.X == 0).all(axis=0) self.new_all_zeroes_genes = self.counts_to_refit.var_names[new_all_zeroes] - self.counts_to_refit = self.counts_to_refit[:, ~new_all_zeroes].copy() + if (~new_all_zeroes).sum() == 0: # if no gene can be refit, we can skip + return + self.counts_to_refit = self.counts_to_refit[:, ~new_all_zeroes].copy() if isinstance(self.new_all_zeroes_genes, pd.MultiIndex): raise ValueError @@ -875,10 +876,13 @@ def _refit_without_outliers( self.layers["replace_cooks"] = replace_cooks # Take into account new all-zero genes - self[:, self.new_all_zeroes_genes].varm["_normed_means"] = np.zeros( - new_all_zeroes.sum() - ) - self[:, self.new_all_zeroes_genes].varm["LFC"] = np.zeros(new_all_zeroes.sum()) + if (new_all_zeroes).sum() > 0: + self[:, self.new_all_zeroes_genes].varm["_normed_means"] = np.zeros( + new_all_zeroes.sum() + ) + self[:, self.new_all_zeroes_genes].varm["LFC"] = np.zeros( + new_all_zeroes.sum() + ) def _fit_iterate_size_factors(self, niter: int = 10, quant: float = 0.95) -> None: """ diff --git a/pydeseq2/utils.py b/pydeseq2/utils.py index fb256899..0fd653bd 100644 --- a/pydeseq2/utils.py +++ b/pydeseq2/utils.py @@ -233,7 +233,7 @@ def build_design_matrix( if intercept: design_matrix.insert(0, "intercept", 1) - return design_matrix + return design_matrix.astype("int") def dispersion_trend(