diff --git a/pydeseq2/ds.py b/pydeseq2/ds.py index ef0af973..250d087c 100644 --- a/pydeseq2/ds.py +++ b/pydeseq2/ds.py @@ -188,14 +188,15 @@ def __init__( # If the `refit_cooks` attribute of the dds object is True, check that outliers # were actually refitted. - try: - dds.replaced - except AttributeError: - raise AttributeError( - "dds has 'refit_cooks' set to True but Cooks outliers have not been " - "refitted. Please run 'dds.refit()' first or set 'dds.refit_cooks' " - "to False." - ) + if self.dds.refit_cooks: + try: + dds.replaced + except AttributeError: + raise AttributeError( + "dds has 'refit_cooks' set to True but Cooks outliers have not been " + "refitted. Please run 'dds.refit()' first or set 'dds.refit_cooks' " + "to False." + ) def summary(self): """Run the statistical analysis. diff --git a/tests/test_pydeseq2.py b/tests/test_pydeseq2.py index 8cf7186b..3c797c95 100644 --- a/tests/test_pydeseq2.py +++ b/tests/test_pydeseq2.py @@ -54,6 +54,52 @@ def test_deseq(tol=0.02): assert (abs(r_res.padj - res_df.padj) / r_res.padj).max() < tol +def test_deseq_no_refit_cooks(tol=0.02): + """Test that the outputs of the DESeq2 function *without cooks refit* + match those of the original R package, up to a tolerance in relative error. + Note: this is just to check that the workflow runs bug-free, as we expect no outliers + in the synthetic dataset. + """ + + test_path = str(Path(os.path.realpath(tests.__file__)).parent.resolve()) + + counts_df = load_example_data( + modality="raw_counts", + dataset="synthetic", + debug=False, + ) + + clinical_df = load_example_data( + modality="clinical", + dataset="synthetic", + debug=False, + ) + + r_res = pd.read_csv( + os.path.join(test_path, "data/single_factor/r_test_res.csv"), index_col=0 + ) + + dds = DeseqDataSet( + counts_df, clinical_df, design_factors="condition", refit_cooks=False + ) + dds.deseq2() + + res = DeseqStats(dds) + res.summary() + res_df = res.results_df + + # check that the same p-values are NaN + assert (res_df.pvalue.isna() == r_res.pvalue.isna()).all() + assert (res_df.padj.isna() == r_res.padj.isna()).all() + + # Check that the same LFC, p-values and adjusted p-values are found (up to tol) + assert ( + abs(r_res.log2FoldChange - res_df.log2FoldChange) / abs(r_res.log2FoldChange) + ).max() < tol + assert (abs(r_res.pvalue - res_df.pvalue) / r_res.pvalue).max() < tol + assert (abs(r_res.padj - res_df.padj) / r_res.padj).max() < tol + + def test_lfc_shrinkage(tol=0.02): """Test that the outputs of the lfc_shrink function match those of the original R package (starting from the same inputs), up to a tolerance in relative error.