Skip to content

Commit

Permalink
Fix observable trajectories on HDF5 load (#502)
Browse files Browse the repository at this point in the history
* Save Observable species, coeffs in JSON export when `include_netgen` is True
* Recompute observables and their trajectories when loading old JSON exports
  • Loading branch information
alubbock committed Jun 4, 2020
1 parent dc33305 commit 7391df8
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 2 deletions.
7 changes: 7 additions & 0 deletions pysb/export/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,13 @@ def encode_reaction(cls, rxn):
else str(rxn['rate'])
return rxn

@classmethod
def encode_observable(cls, obs):
o = super(PySBJSONWithNetworkEncoder, cls).encode_observable(obs)
o['species'] = obs.species
o['coefficients'] = obs.coefficients
return o

@classmethod
def encode_model(cls, model):
d = super(PySBJSONWithNetworkEncoder, cls).encode_model(model)
Expand Down
32 changes: 31 additions & 1 deletion pysb/importers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from pysb.core import RuleExpression, ReactionPattern, ComplexPattern, \
MonomerPattern, MultiState, ANY, WILD, Parameter, Expression
from pysb.annotation import Annotation
from pysb.pattern import SpeciesPatternMatcher
import sympy
import collections
from collections.abc import Mapping
import json
import re
import warnings
from sympy.parsing.sympy_parser import parse_expr
try:
basestring
Expand Down Expand Up @@ -105,11 +107,16 @@ def decode_derived_expression(self, expr):
return self.decode_expression(expr, derived=True)

def decode_observable(self, obs):
self.b.observable(
o = self.b.observable(
obs['name'],
self.decode_reaction_pattern(obs['reaction_pattern']),
obs['match']
)
try:
o.coefficients = obs['coefficients']
o.species = obs['species']
except KeyError:
pass

def decode_monomer_pattern(self, mp):
mon = self._modelget(mp['monomer'])
Expand Down Expand Up @@ -234,6 +241,29 @@ def decode(self, s):
for component in res.get(component_type, []):
decoder(component)

if self.b.model.reactions and self.b.model.observables \
and 'species' not in res['observables'][0]:

# We have network, need to regenerate Observable species and coeffs
warnings.warn(
'This SimulationResult file is missing Observable species and '
'coefficients data. These will be generated now - we recommend '
'you re-save your SimulationResult file to avoid this warning.'
)

for obs in self.b.model.observables:
if obs.match in ('molecules', 'species'):
obs_matches = SpeciesPatternMatcher(self.b.model).match(
obs.reaction_pattern, index=True, counts=True)
sp, vals = zip(*sorted(obs_matches.items()))
obs.species = list(sp)
if obs.match == 'molecules':
obs.coefficients = list(vals)
else:
obs.coefficients = [1] * len(obs_matches.values())
else:
raise ValueError(f'Unknown obs.match value: {obs.match}')

return self.b.model


Expand Down
6 changes: 6 additions & 0 deletions pysb/tests/test_exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def check_convert(model, format):
# issues
check_model_against_component_list(
m, model.all_components())
# Check observable generation
for obs in model.observables:
assert obs.coefficients == \
m.observables[obs.name].coefficients
assert obs.species == \
m.observables[obs.name].species
elif format == 'bngl':
if model.name.endswith('tutorial_b') or \
model.name.endswith('tutorial_c'):
Expand Down
2 changes: 1 addition & 1 deletion pysb/tests/test_simulationresult.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _check_resultsets_equal(res1, res2):
for k, v in res1.initials.items():
assert np.allclose(res1.initials[k], v)

assert np.allclose(res1._yobs_view, res1._yobs_view)
assert np.allclose(res1._yobs_view, res2._yobs_view)
if res1._model.expressions_dynamic():
assert np.allclose(res1._yexpr_view, res2._yexpr_view)

Expand Down

0 comments on commit 7391df8

Please sign in to comment.