Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 34 additions & 30 deletions pydeseq2/dds.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
# Ignore AnnData's FutureWarning about implicit data conversion.
warnings.simplefilter("ignore", FutureWarning)


# TODO: implement a quiet (non-verbose) mode ?


Expand Down Expand Up @@ -764,32 +763,32 @@ def _replace_outliers(self) -> None:
idx = self.layers["cooks"] > cooks_cutoff
self.varm["replaced"] = idx.any(axis=0)

# Compute replacement counts: trimmed means * size_factors
self.counts_to_refit = self[:, self.varm["replaced"]].copy()

trim_base_mean = pd.DataFrame(
cast(
np.ndarray,
trimmed_mean(
self.counts_to_refit.X / self.obsm["size_factors"][:, None],
trim=0.2,
axis=0,
if sum(self.varm["replaced"] > 0):
# Compute replacement counts: trimmed means * size_factors
self.counts_to_refit = self[:, self.varm["replaced"]].copy()

trim_base_mean = pd.DataFrame(
cast(
np.ndarray,
trimmed_mean(
self.counts_to_refit.X / self.obsm["size_factors"][:, None],
trim=0.2,
axis=0,
),
),
),
index=self.counts_to_refit.var_names,
)

replacement_counts = (
pd.DataFrame(
trim_base_mean.values * self.obsm["size_factors"],
index=self.counts_to_refit.var_names,
columns=self.counts_to_refit.obs_names,
)
.astype(int)
.T
)

if sum(self.varm["replaced"] > 0):
replacement_counts = (
pd.DataFrame(
trim_base_mean.values * self.obsm["size_factors"],
index=self.counts_to_refit.var_names,
columns=self.counts_to_refit.obs_names,
)
.astype(int)
.T
)

self.counts_to_refit.X[
self.obsm["replaceable"][:, None] & idx[:, self.varm["replaced"]]
] = replacement_counts.values[
Expand All @@ -804,15 +803,17 @@ def _refit_without_outliers(
self.refit_cooks
), "Trying to refit Cooks outliers but the 'refit_cooks' flag is set to False"

# Check that replaced counts are available. If not, compute them.
if not hasattr(self, "counts_to_refit"):
# Check that _replace_outliers() was previously run.
if "replaced" not in self.varm:
self._replace_outliers()

# Only refit genes for which replacing outliers hasn't resulted in all zeroes
new_all_zeroes = (self.counts_to_refit.X == 0).all(axis=0)
self.new_all_zeroes_genes = self.counts_to_refit.var_names[new_all_zeroes]
self.counts_to_refit = self.counts_to_refit[:, ~new_all_zeroes].copy()
if (~new_all_zeroes).sum() == 0: # if no gene can be refit, we can skip
return

self.counts_to_refit = self.counts_to_refit[:, ~new_all_zeroes].copy()
if isinstance(self.new_all_zeroes_genes, pd.MultiIndex):
raise ValueError

Expand Down Expand Up @@ -875,10 +876,13 @@ def _refit_without_outliers(

self.layers["replace_cooks"] = replace_cooks
# Take into account new all-zero genes
self[:, self.new_all_zeroes_genes].varm["_normed_means"] = np.zeros(
new_all_zeroes.sum()
)
self[:, self.new_all_zeroes_genes].varm["LFC"] = np.zeros(new_all_zeroes.sum())
if (new_all_zeroes).sum() > 0:
self[:, self.new_all_zeroes_genes].varm["_normed_means"] = np.zeros(
new_all_zeroes.sum()
)
self[:, self.new_all_zeroes_genes].varm["LFC"] = np.zeros(
new_all_zeroes.sum()
)

def _fit_iterate_size_factors(self, niter: int = 10, quant: float = 0.95) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion pydeseq2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def build_design_matrix(

if intercept:
design_matrix.insert(0, "intercept", 1)
return design_matrix
return design_matrix.astype("int")


def dispersion_trend(
Expand Down