Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
3401fab
type hinting for dds.py
illumination-k Feb 2, 2023
ae00ca1
type hinting for ds.py
illumination-k Feb 2, 2023
d9b0e71
type hinting for grid_search.py
illumination-k Feb 2, 2023
34445c4
type hinting for utils.py
illumination-k Feb 2, 2023
6579aca
precommit
illumination-k Feb 2, 2023
af128e0
fix typo
illumination-k Feb 2, 2023
7e428c7
run precommit script manually and fix
illumination-k Feb 2, 2023
4cbcad7
update docs and conf
illumination-k Feb 3, 2023
1d46f73
remove auto_example
illumination-k Feb 9, 2023
7534f47
pull upstream and merge conflict, removing api/docstring
illumination-k Feb 14, 2023
d98e99e
fix for python v3.8
illumination-k Feb 14, 2023
fda48d7
add mypy to precommit
maikia Feb 14, 2023
20a9c0f
merge with main
maikia Feb 14, 2023
3213af6
removing bugs
maikia Feb 15, 2023
2b82daa
merge with main
maikia Feb 15, 2023
b40b88a
resolve merge conflicts
Feb 16, 2023
5f44d01
resolve merge conflicts
Feb 16, 2023
1866117
resolve merge conflicts
Feb 16, 2023
a9ebaef
Revert "resolve merge conflicts"
Feb 17, 2023
7a8c230
Revert "resolve merge conflicts"
Feb 17, 2023
07ef994
Revert "resolve merge conflicts"
Feb 17, 2023
512e26d
fix: resolve merge conflicts
Feb 17, 2023
b37c7fd
fix: handle mypy . --check-untyped-defs errors
Feb 17, 2023
55194ba
fix: black linting
Feb 17, 2023
a3476c2
docs: add .rst files since these are not generated automatically
Feb 17, 2023
d534824
fix: typo
Feb 17, 2023
5f278a6
fix: typo
BorisMuzellec Feb 17, 2023
ad1c948
fix: remove unneeded # type: ignore
Feb 20, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ repos:
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.991
hooks:
- id: mypy
exclude: ^(tests/|docs/source/conf.py)
- repo: https://github.com/nbQA-dev/nbQA #black and isort for Jupyter notebooks
rev: 1.4.0
hooks:
Expand Down
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,9 @@
# The following elements are the link that auto doc were not able to do
nitpick_ignore = [
("py:class", "pd.Series"),
# ("py:class", "anndata.AnnData"),
# ("py:class", "anndata._core.anndata.AnnData"),
("py:class", "pd.DataFrame"),
("py:class", "ndarray"),
("py:class", "numpy._typing._generic_alias.ScalarType"),
("py:class", "pydantic.main.BaseModel"),
("py:class", "torch.nn.modules.module.Module"),
("py:class", "torch.nn.modules.loss._Loss"),
Expand Down
91 changes: 52 additions & 39 deletions pydeseq2/dds.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import time
import warnings
from typing import List
from typing import Optional
from typing import Union
from typing import cast

import anndata as ad
import anndata as ad # type: ignore
import numpy as np
import numpy.typing as npt
import pandas as pd
import statsmodels.api as sm
from joblib import Parallel
import statsmodels.api as sm # type: ignore
from joblib import Parallel # type: ignore
from joblib import delayed
from joblib import parallel_backend
from scipy.special import polygamma
from scipy.stats import f
from scipy.special import polygamma # type: ignore
from scipy.stats import f # type: ignore
from scipy.stats import norm
from statsmodels.tools.sm_exceptions import DomainWarning
from statsmodels.tools.sm_exceptions import DomainWarning # type: ignore

from pydeseq2.preprocessing import deseq2_norm
from pydeseq2.utils import build_design_matrix
Expand Down Expand Up @@ -146,20 +151,20 @@ class DeseqDataSet(ad.AnnData):

def __init__(
self,
counts,
clinical,
design_factors="condition",
reference_level=None,
min_mu=0.5,
min_disp=1e-8,
max_disp=10.0,
refit_cooks=True,
min_replicates=7,
beta_tol=1e-8,
n_cpus=None,
batch_size=128,
joblib_verbosity=0,
):
counts: pd.DataFrame,
clinical: pd.DataFrame,
design_factors: Union[str, List[str]] = "condition",
reference_level: Optional[str] = None,
min_mu: float = 0.5,
min_disp: float = 1e-8,
max_disp: float = 10.0,
refit_cooks: bool = True,
min_replicates: int = 7,
beta_tol: float = 1e-8,
n_cpus: Optional[int] = None,
batch_size: int = 128,
joblib_verbosity: int = 0,
) -> None:

# Test counts before going further
test_valid_counts(counts)
Expand Down Expand Up @@ -195,7 +200,7 @@ def __init__(
self.batch_size = batch_size
self.joblib_verbosity = joblib_verbosity

def deseq2(self):
def deseq2(self) -> None:
"""Perform dispersion and log fold-change (LFC) estimation.

Wrapper for the first part of the PyDESeq2 pipeline.
Expand All @@ -222,7 +227,7 @@ def deseq2(self):
# for genes that had outliers replaced
self.refit()

def fit_size_factors(self):
def fit_size_factors(self) -> None:
"""Fit sample-wise deseq2 normalization (size) factors.

Uses the median-of-ratios method.
Expand All @@ -233,7 +238,7 @@ def fit_size_factors(self):
end = time.time()
print(f"... done in {end - start:.2f} seconds.\n")

def fit_genewise_dispersions(self):
def fit_genewise_dispersions(self) -> None:
"""Fit gene-wise dispersion estimates.

Fits a negative binomial per gene, independently.
Expand All @@ -247,6 +252,9 @@ def fit_genewise_dispersions(self):
self.non_zero_idx = np.arange(self.n_vars)[self.varm["non_zero"]]
self.non_zero_genes = self.var_names[self.varm["non_zero"]]

if isinstance(self.non_zero_genes, pd.MultiIndex):
raise ValueError("non_zero_genes should not be a MultiIndex")

# Fit "method of moments" dispersion estimates
self._fit_MoM_dispersions()

Expand Down Expand Up @@ -333,7 +341,7 @@ def fit_genewise_dispersions(self):
self.varm["_genewise_converged"] = np.full(self.n_vars, np.NaN)
self.varm["_genewise_converged"][self.varm["non_zero"]] = l_bfgs_b_converged_

def fit_dispersion_trend(self):
def fit_dispersion_trend(self) -> None:
r"""Fit the dispersion trend coefficients.

.. math:: f(\mu) = \alpha_1/\mu + a_0.
Expand Down Expand Up @@ -374,7 +382,6 @@ def fit_dispersion_trend(self):
coeffs = pd.Series([1.0, 1.0])

while (np.log(np.abs(coeffs / old_coeffs)) ** 2).sum() >= 1e-6:

glm_gamma = sm.GLM(
targets.values,
covariates.values,
Expand Down Expand Up @@ -411,7 +418,7 @@ def fit_dispersion_trend(self):
self.uns["trend_coeffs"],
)

def fit_dispersion_prior(self):
def fit_dispersion_prior(self) -> None:
"""Fit dispersion variance priors and standard deviation of log-residuals.

The computation is based on genes whose dispersions are above 100 * min_disp.
Expand All @@ -435,6 +442,7 @@ def fit_dispersion_prior(self):
above_min_disp = self[:, self.non_zero_genes].varm["genewise_dispersions"] >= (
100 * self.min_disp
)

self.uns["_squared_logres"] = np.median(
np.abs(disp_residuals[above_min_disp])
) ** 2 / norm.ppf(0.75)
Expand All @@ -443,7 +451,7 @@ def fit_dispersion_prior(self):
0.25,
)

def fit_MAP_dispersions(self):
def fit_MAP_dispersions(self) -> None:
"""Fit Maximum a Posteriori dispersion estimates.

After MAP dispersions are fit, filter genes for which we don't apply shrinkage.
Expand Down Expand Up @@ -499,7 +507,7 @@ def fit_MAP_dispersions(self):
"genewise_dispersions"
][self.varm["_outlier_genes"]]

def fit_LFC(self):
def fit_LFC(self) -> None:
"""Fit log fold change (LFC) coefficients.

In the 2-level setting, the intercept corresponds to the base mean,
Expand Down Expand Up @@ -561,7 +569,7 @@ def fit_LFC(self):
self.varm["_LFC_converged"] = np.full(self.n_vars, np.NaN)
self.varm["_LFC_converged"][self.varm["non_zero"]] = converged_

def calculate_cooks(self):
def calculate_cooks(self) -> None:
"""Compute Cook's distance for outlier detection.

Measures the contribution of a single entry to the output of LFC estimation.
Expand All @@ -572,16 +580,15 @@ def calculate_cooks(self):
self.fit_MAP_dispersions()

num_vars = self.obsm["design_matrix"].shape[-1]
nonzero_data = self[:, self.non_zero_genes]

# Keep only non-zero genes
nonzero_data = self[:, self.non_zero_genes]
normed_counts = pd.DataFrame(
nonzero_data.X / self.obsm["size_factors"][:, None],
index=self.obs_names,
columns=self.non_zero_genes,
)

# dispersions = pd.Series(np.NaN, index=self.var_names)
dispersions = robust_method_of_moments_disp(
normed_counts, self.obsm["design_matrix"]
)
Expand All @@ -601,7 +608,7 @@ def calculate_cooks(self):
squared_pearson_res / num_vars * diag_mul
)

def refit(self):
def refit(self) -> None:
"""Refit Cook outliers.

Replace values that are filtered out based on the Cooks distance with imputed
Expand All @@ -615,7 +622,7 @@ def refit(self):
# Refit dispersions and LFCs for genes that had outliers replaced
self._refit_without_outliers()

def _fit_MoM_dispersions(self):
def _fit_MoM_dispersions(self) -> None:
""" "Rough method of moments" initial dispersions fit.
Estimates are the max of "robust" and "method of moments" estimates.
"""
Expand All @@ -637,7 +644,7 @@ def _fit_MoM_dispersions(self):
alpha_hat, self.min_disp, self.max_disp
)

def _replace_outliers(self):
def _replace_outliers(self) -> None:
"""Replace values that are filtered out based
on the Cooks distance with imputed values.
"""
Expand Down Expand Up @@ -677,10 +684,13 @@ def _replace_outliers(self):
self.counts_to_refit = self[:, self.varm["replaced"]].copy()

trim_base_mean = pd.DataFrame(
trimmed_mean(
self.counts_to_refit.X / self.obsm["size_factors"][:, None],
trim=0.2,
axis=0,
cast(
npt.NDArray,
trimmed_mean(
self.counts_to_refit.X / self.obsm["size_factors"][:, None],
trim=0.2,
axis=0,
),
),
index=self.counts_to_refit.var_names,
)
Expand All @@ -703,7 +713,7 @@ def _replace_outliers(self):

def _refit_without_outliers(
self,
):
) -> None:
"""Re-run the whole DESeq2 pipeline with replaced outliers."""
assert (
self.refit_cooks
Expand All @@ -718,6 +728,9 @@ def _refit_without_outliers(
self.new_all_zeroes_genes = self.counts_to_refit.var_names[new_all_zeroes]
self.counts_to_refit = self.counts_to_refit[:, ~new_all_zeroes].copy()

if isinstance(self.new_all_zero_genes, pd.MultiIndex):
raise ValueError

sub_dds = DeseqDataSet(
counts=pd.DataFrame(
self.counts_to_refit.X,
Expand Down
Loading