From 3745b8d489ba173ecdc9c3041ba31e7e9ec8558d Mon Sep 17 00:00:00 2001 From: dbrakenhoff Date: Mon, 6 Dec 2021 14:20:58 +0100 Subject: [PATCH] add oseries models accessor - keep a list of model names per oseries - list is updated when models are added deleted - accessor has dict-like access - initialized when instantiating connector (2 ms per model for arctic) - add test --- pastastore/base.py | 110 ++++++++++++++++++++++++++++++++++- pastastore/connectors.py | 7 ++- pastastore/store.py | 40 ++----------- tests/test_003_pastastore.py | 30 ++++++++-- 4 files changed, 144 insertions(+), 43 deletions(-) diff --git a/pastastore/base.py b/pastastore/base.py index 6c7a2cd0..0a436cbe 100644 --- a/pastastore/base.py +++ b/pastastore/base.py @@ -405,6 +405,7 @@ def add_model(self, ml: Union[ps.Model, dict], raise ItemInLibraryException(f"Model with name '{name}' " "already in 'models' library!") self._clear_cache("_modelnames_cache") + self.oseries_models._add_model(mldict["oseries"]["name"], name) @staticmethod def _parse_series_input(series: Union[FrameorSeriesUnion, ps.TimeSeries], @@ -520,7 +521,10 @@ def del_models(self, names: Union[list, str]) -> None: name(s) of the model to delete """ for n in self._parse_names(names, libname="models"): + mldict = self.get_models(n, return_dict=True) + oname = mldict["oseries"]["name"] self._del_item("models", n) + self.oseries_models._del_model(oname, n) self._clear_cache("_modelnames_cache") def del_oseries(self, names: Union[list, str]): @@ -1371,7 +1375,7 @@ def __repr__(self): """Representation of the object is a list of modelnames.""" return self.conn._modelnames_cache.__repr__() - def __getitem__(self, name): + def __getitem__(self, name: str): """Get model from store with model name as key. Parameters @@ -1381,7 +1385,7 @@ def __getitem__(self, name): """ return self.conn.get_models(name) - def __setitem__(self, name, ml): + def __setitem__(self, name: str, ml: ps.Model): """Set item. Parameters @@ -1403,3 +1407,105 @@ def __iter__(self): model """ yield from self.conn.iter_models() + + +class OseriesModelsAccessor: + """Object for getting list of model names per oseries. + + Provides dict-like access for obtaining models for a certain + location/oseries (i.e. PastaStore.oseries_models["oseries1"]). On + initialization of a Connector this dictionary is built by running + through all models and storing a list of model names per oseries name. + """ + + def __init__(self, conn): + """Initialize oseries models accessor. + + Parameters + ---------- + conn : pastastore.*Connector + pastastore Connector object + """ + self.conn = conn + self.oseries_models_dict = {} + self._build_dict() + + def __repr__(self): + """String represenation. + + Returns + ------- + str + string representation of oseries_models dictionary. + """ + return self.oseries_models_dict.__repr__() + + def __getitem__(self, name: str): + """Get list of model names with oseries name as key. + + Parameters + ---------- + name : str + name of oseries + + Returns + ------- + list + list of model names (str) + """ + return self.oseries_models_dict[name] + + def __setitem__(self, oseries_name: str, model_name: str): + """Add model name to oseries_models dictionary. + + Parameters + ---------- + oseries_name : str + name of oseries + model_name : str + name of model + """ + self._add_model(oseries_name, model_name) + + def _build_dict(self): + """Build dictionary with list of model names per oseries. + """ + if self.conn.n_models > 0: + for mlnam in tqdm(self.conn._modelnames_cache, + desc="Build oseries_model dictionary"): + ml = self.conn.get_models(mlnam, return_dict=True) + onam = ml["oseries"]["name"] + if onam in self.oseries_models_dict: + self.oseries_models_dict[onam].append(mlnam) + else: + self.oseries_models_dict[onam] = [mlnam] + + def _add_model(self, oseries_name: str, model_name: str): + """Add model name to list for oseries. + + Parameters + ---------- + oseries_name : str + name of oseries + model_name : str + name of models + """ + if oseries_name in self.oseries_models_dict: + if model_name not in self.oseries_models_dict[oseries_name]: + self.oseries_models_dict[oseries_name].append(model_name) + else: + self.oseries_models_dict[oseries_name] = [model_name] + + def _del_model(self, oseries_name: str, model_name: str): + """Delete model name from list for oseries. + + Parameters + ---------- + oseries_name : str + name of oseries + model_name : str + name of model + """ + self.oseries_models_dict[oseries_name].remove(model_name) + if len(self.oseries_models_dict[oseries_name]) == 0: + del self.oseries_models_dict[oseries_name] diff --git a/pastastore/connectors.py b/pastastore/connectors.py index 99bd4d04..3c2a2a36 100644 --- a/pastastore/connectors.py +++ b/pastastore/connectors.py @@ -8,7 +8,8 @@ import pandas as pd from pastas.io.pas import PastasEncoder, pastas_hook -from .base import BaseConnector, ConnectorUtil, ModelAccessor +from .base import (BaseConnector, ConnectorUtil, + ModelAccessor, OseriesModelsAccessor) from .util import _custom_warning FrameorSeriesUnion = Union[pd.DataFrame, pd.Series] @@ -46,6 +47,7 @@ def __init__(self, name: str, connstr: str): self.arc = arctic.Arctic(connstr) self._initialize() self.models = ModelAccessor(self) + self.oseries_models = OseriesModelsAccessor(self) def _initialize(self) -> None: """Internal method to initalize the libraries.""" @@ -212,6 +214,7 @@ def __init__(self, name: str, path: str): self.libs: dict = {} self._initialize() self.models = ModelAccessor(self) + self.oseries_models = OseriesModelsAccessor(self) def _initialize(self) -> None: """Internal method to initalize the libraries (stores).""" @@ -405,6 +408,7 @@ def __init__(self, name: str): for val in self._default_library_names: setattr(self, "lib_" + val, {}) self.models = ModelAccessor(self) + self.oseries_models = OseriesModelsAccessor(self) def _get_library(self, libname: str): """Get reference to dictionary holding data. @@ -539,6 +543,7 @@ def __init__(self, name: str, path: str): self.path = os.path.abspath(path) self._initialize() self.models = ModelAccessor(self) + self.oseries_models = OseriesModelsAccessor(self) def _initialize(self) -> None: """Internal method to initialize the libraries.""" diff --git a/pastastore/store.py b/pastastore/store.py index e6c18a98..c48579f3 100644 --- a/pastastore/store.py +++ b/pastastore/store.py @@ -97,6 +97,10 @@ def n_stresses(self): def n_models(self): return self.conn.n_models + @property + def oseries_models(self): + return self.conn.oseries_models + def __repr__(self): """Representation string of the object.""" return f" {self.name}: \n - " + self.conn.__str__() @@ -869,39 +873,3 @@ def get_model_timeseries_names( return structure.dropna(how="all", axis=1) else: return structure - - def get_oseries_model_list(self, - modelnames: Optional[Union[list, str]] = None, - progressbar: bool = True) -> Dict: - """Get a list of model names for each oseries. - - Parameters - ---------- - modelnames : Optional[Union[list, str]], optional - list of modelnames to consider, by default None, which - defaults to all models - progressbar : bool, optional - show progressbar, by default False - - Returns - ------- - oseries_model_dict : dict - dictionary with oseries names as keys, and list of model names - as values - """ - - modelnames = self.conn._parse_names(modelnames, libname="models") - - oseries_model_dict = {} - - for mlnam in (tqdm(modelnames, desc="Get model oseries names") - if progressbar else modelnames): - iml = self.get_models(mlnam, returndict=True) - oname = iml["oseries"]["name"] - if oname in oseries_model_dict: - oseries_model_dict[oname] = oseries_model_dict[oname].append( - oname) - else: - oseries_model_dict[oname] = [oname] - - return oseries_model_dict diff --git a/tests/test_003_pastastore.py b/tests/test_003_pastastore.py index 84397974..d118d4d8 100644 --- a/tests/test_003_pastastore.py +++ b/tests/test_003_pastastore.py @@ -82,15 +82,37 @@ def test_model_accessor(request, pstore): # getter ml = pstore.models["oseries1"] # setter - pstore.models["oseries2"] = ml + pstore.models["oseries1_2"] = ml # iter mnames = [ml.name for ml in pstore.models] try: assert len(mnames) == 2 - assert mnames[0] in ["oseries1", "oseries2"] - assert mnames[1] in ["oseries1", "oseries2"] + assert mnames[0] in ["oseries1", "oseries1_2"] + assert mnames[1] in ["oseries1", "oseries1_2"] finally: - pstore.del_models("oseries2") + pstore.del_models("oseries1_2") + return + + +@pytest.mark.dependency() +def test_oseries_model_accessor(request, pstore): + depends(request, [f"test_store_model[{pstore.type}]"]) + # repr + pstore.oseries_models.__repr__() + # get model names + ml = pstore.models["oseries1"] + ml_list1 = pstore.oseries_models["oseries1"] + assert len(ml_list1) == 1 + + # add model + pstore.models["oseries1_2"] = ml + ml_list2 = pstore.oseries_models["oseries1"] + assert len(ml_list2) == 2 + + # delete model + pstore.del_models("oseries1_2") + ml_list3 = pstore.oseries_models["oseries1"] + assert len(ml_list3) == 1 return