Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add oseries models accessor #49

Merged
merged 1 commit into from
Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 108 additions & 2 deletions pastastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ def add_model(self, ml: Union[ps.Model, dict],
raise ItemInLibraryException(f"Model with name '{name}' "
"already in 'models' library!")
self._clear_cache("_modelnames_cache")
self.oseries_models._add_model(mldict["oseries"]["name"], name)

@staticmethod
def _parse_series_input(series: Union[FrameorSeriesUnion, ps.TimeSeries],
Expand Down Expand Up @@ -520,7 +521,10 @@ def del_models(self, names: Union[list, str]) -> None:
name(s) of the model to delete
"""
for n in self._parse_names(names, libname="models"):
mldict = self.get_models(n, return_dict=True)
oname = mldict["oseries"]["name"]
self._del_item("models", n)
self.oseries_models._del_model(oname, n)
self._clear_cache("_modelnames_cache")

def del_oseries(self, names: Union[list, str]):
Expand Down Expand Up @@ -1371,7 +1375,7 @@ def __repr__(self):
"""Representation of the object is a list of modelnames."""
return self.conn._modelnames_cache.__repr__()

def __getitem__(self, name):
def __getitem__(self, name: str):
"""Get model from store with model name as key.

Parameters
Expand All @@ -1381,7 +1385,7 @@ def __getitem__(self, name):
"""
return self.conn.get_models(name)

def __setitem__(self, name, ml):
def __setitem__(self, name: str, ml: ps.Model):
"""Set item.

Parameters
Expand All @@ -1403,3 +1407,105 @@ def __iter__(self):
model
"""
yield from self.conn.iter_models()


class OseriesModelsAccessor:
"""Object for getting list of model names per oseries.

Provides dict-like access for obtaining models for a certain
location/oseries (i.e. PastaStore.oseries_models["oseries1"]). On
initialization of a Connector this dictionary is built by running
through all models and storing a list of model names per oseries name.
"""

def __init__(self, conn):
"""Initialize oseries models accessor.

Parameters
----------
conn : pastastore.*Connector
pastastore Connector object
"""
self.conn = conn
self.oseries_models_dict = {}
self._build_dict()

def __repr__(self):
"""String represenation.

Returns
-------
str
string representation of oseries_models dictionary.
"""
return self.oseries_models_dict.__repr__()

def __getitem__(self, name: str):
"""Get list of model names with oseries name as key.

Parameters
----------
name : str
name of oseries

Returns
-------
list
list of model names (str)
"""
return self.oseries_models_dict[name]

def __setitem__(self, oseries_name: str, model_name: str):
"""Add model name to oseries_models dictionary.

Parameters
----------
oseries_name : str
name of oseries
model_name : str
name of model
"""
self._add_model(oseries_name, model_name)

def _build_dict(self):
"""Build dictionary with list of model names per oseries.
"""
if self.conn.n_models > 0:
for mlnam in tqdm(self.conn._modelnames_cache,
desc="Build oseries_model dictionary"):
ml = self.conn.get_models(mlnam, return_dict=True)
onam = ml["oseries"]["name"]
if onam in self.oseries_models_dict:
self.oseries_models_dict[onam].append(mlnam)
else:
self.oseries_models_dict[onam] = [mlnam]

def _add_model(self, oseries_name: str, model_name: str):
"""Add model name to list for oseries.

Parameters
----------
oseries_name : str
name of oseries
model_name : str
name of models
"""
if oseries_name in self.oseries_models_dict:
if model_name not in self.oseries_models_dict[oseries_name]:
self.oseries_models_dict[oseries_name].append(model_name)
else:
self.oseries_models_dict[oseries_name] = [model_name]

def _del_model(self, oseries_name: str, model_name: str):
"""Delete model name from list for oseries.

Parameters
----------
oseries_name : str
name of oseries
model_name : str
name of model
"""
self.oseries_models_dict[oseries_name].remove(model_name)
if len(self.oseries_models_dict[oseries_name]) == 0:
del self.oseries_models_dict[oseries_name]
7 changes: 6 additions & 1 deletion pastastore/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pandas as pd
from pastas.io.pas import PastasEncoder, pastas_hook

from .base import BaseConnector, ConnectorUtil, ModelAccessor
from .base import (BaseConnector, ConnectorUtil,
ModelAccessor, OseriesModelsAccessor)
from .util import _custom_warning

FrameorSeriesUnion = Union[pd.DataFrame, pd.Series]
Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(self, name: str, connstr: str):
self.arc = arctic.Arctic(connstr)
self._initialize()
self.models = ModelAccessor(self)
self.oseries_models = OseriesModelsAccessor(self)

def _initialize(self) -> None:
"""Internal method to initalize the libraries."""
Expand Down Expand Up @@ -212,6 +214,7 @@ def __init__(self, name: str, path: str):
self.libs: dict = {}
self._initialize()
self.models = ModelAccessor(self)
self.oseries_models = OseriesModelsAccessor(self)

def _initialize(self) -> None:
"""Internal method to initalize the libraries (stores)."""
Expand Down Expand Up @@ -405,6 +408,7 @@ def __init__(self, name: str):
for val in self._default_library_names:
setattr(self, "lib_" + val, {})
self.models = ModelAccessor(self)
self.oseries_models = OseriesModelsAccessor(self)

def _get_library(self, libname: str):
"""Get reference to dictionary holding data.
Expand Down Expand Up @@ -539,6 +543,7 @@ def __init__(self, name: str, path: str):
self.path = os.path.abspath(path)
self._initialize()
self.models = ModelAccessor(self)
self.oseries_models = OseriesModelsAccessor(self)

def _initialize(self) -> None:
"""Internal method to initialize the libraries."""
Expand Down
40 changes: 4 additions & 36 deletions pastastore/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ def n_stresses(self):
def n_models(self):
return self.conn.n_models

@property
def oseries_models(self):
return self.conn.oseries_models

def __repr__(self):
"""Representation string of the object."""
return f"<PastaStore> {self.name}: \n - " + self.conn.__str__()
Expand Down Expand Up @@ -869,39 +873,3 @@ def get_model_timeseries_names(
return structure.dropna(how="all", axis=1)
else:
return structure

def get_oseries_model_list(self,
modelnames: Optional[Union[list, str]] = None,
progressbar: bool = True) -> Dict:
"""Get a list of model names for each oseries.

Parameters
----------
modelnames : Optional[Union[list, str]], optional
list of modelnames to consider, by default None, which
defaults to all models
progressbar : bool, optional
show progressbar, by default False

Returns
-------
oseries_model_dict : dict
dictionary with oseries names as keys, and list of model names
as values
"""

modelnames = self.conn._parse_names(modelnames, libname="models")

oseries_model_dict = {}

for mlnam in (tqdm(modelnames, desc="Get model oseries names")
if progressbar else modelnames):
iml = self.get_models(mlnam, returndict=True)
oname = iml["oseries"]["name"]
if oname in oseries_model_dict:
oseries_model_dict[oname] = oseries_model_dict[oname].append(
oname)
else:
oseries_model_dict[oname] = [oname]

return oseries_model_dict
30 changes: 26 additions & 4 deletions tests/test_003_pastastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,37 @@ def test_model_accessor(request, pstore):
# getter
ml = pstore.models["oseries1"]
# setter
pstore.models["oseries2"] = ml
pstore.models["oseries1_2"] = ml
# iter
mnames = [ml.name for ml in pstore.models]
try:
assert len(mnames) == 2
assert mnames[0] in ["oseries1", "oseries2"]
assert mnames[1] in ["oseries1", "oseries2"]
assert mnames[0] in ["oseries1", "oseries1_2"]
assert mnames[1] in ["oseries1", "oseries1_2"]
finally:
pstore.del_models("oseries2")
pstore.del_models("oseries1_2")
return


@pytest.mark.dependency()
def test_oseries_model_accessor(request, pstore):
depends(request, [f"test_store_model[{pstore.type}]"])
# repr
pstore.oseries_models.__repr__()
# get model names
ml = pstore.models["oseries1"]
ml_list1 = pstore.oseries_models["oseries1"]
assert len(ml_list1) == 1

# add model
pstore.models["oseries1_2"] = ml
ml_list2 = pstore.oseries_models["oseries1"]
assert len(ml_list2) == 2

# delete model
pstore.del_models("oseries1_2")
ml_list3 = pstore.oseries_models["oseries1"]
assert len(ml_list3) == 1
return


Expand Down