Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ check_lint:
interrogate .

doctest:
pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
python -m pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py

test:
pytest
python -m pytest

uml:
pyreverse -o png causalpy --output-directory docs/source/_static --ignore tests
Expand Down
3 changes: 0 additions & 3 deletions causalpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import arviz as az

import causalpy.pymc_models as pymc_models
import causalpy.skl_models as skl_models
Expand All @@ -28,8 +27,6 @@
from .experiments.regression_kink import RegressionKink
from .experiments.synthetic_control import SyntheticControl

az.style.use("arviz-darkgrid")

__all__ = [
"__version__",
"DifferenceInDifferences",
Expand Down
16 changes: 10 additions & 6 deletions causalpy/experiments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from abc import abstractmethod

import arviz as az
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.base import RegressorMixin

Expand Down Expand Up @@ -63,12 +65,14 @@ def plot(self, *args, **kwargs) -> tuple:
Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
depending on the model type.
"""
if isinstance(self.model, PyMCModel):
return self._bayesian_plot(*args, **kwargs)
elif isinstance(self.model, RegressorMixin):
return self._ols_plot(*args, **kwargs)
else:
raise ValueError("Unsupported model type")
# Apply arviz-darkgrid style only during plotting, then revert
with plt.style.context(az.style.library["arviz-darkgrid"]):
if isinstance(self.model, PyMCModel):
return self._bayesian_plot(*args, **kwargs)
elif isinstance(self.model, RegressorMixin):
return self._ols_plot(*args, **kwargs)
else:
raise ValueError("Unsupported model type")

@abstractmethod
def _bayesian_plot(self, *args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ dependencies:
- statsmodels
- xarray>=v2022.11.0
- pymc-extras>=0.3.0
- python>=3.11