From b5b3f3453af125af56e0c5930416427b428f38e7 Mon Sep 17 00:00:00 2001 From: anguisterrenis Date: Sun, 19 Apr 2020 13:50:19 +0200 Subject: [PATCH 01/28] modified file i/o methods so they can be extended in the future --- .gitignore | 3 ++- pyrates/ir/circuit.py | 21 ++++++++++++------- pyrates/utility/filestorage.py | 8 ++++++++ tests/test_file_io.py | 37 ++++++++++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 8 deletions(-) create mode 100644 tests/test_file_io.py diff --git a/.gitignore b/.gitignore index 26d2cae6..e8ed1dff 100755 --- a/.gitignore +++ b/.gitignore @@ -125,4 +125,5 @@ venv.bak/ type_info.json tests/output/ -pyrates_build/ \ No newline at end of file +pyrates_build/ +tests/resources \ No newline at end of file diff --git a/pyrates/ir/circuit.py b/pyrates/ir/circuit.py index a70b2b7c..d8f8f724 100755 --- a/pyrates/ir/circuit.py +++ b/pyrates/ir/circuit.py @@ -1486,25 +1486,32 @@ def to_pyauto(self, dir: str): f'system.' ) - def to_file(self, filename: str) -> None: + def to_file(self, filename: str, mode: str = "pickle") -> None: """Save continuation results on disc. Parameters ---------- filename + mode Returns ------- None """ - import pickle + if mode == "pickle": - data = {key: getattr(self, key) for key in self.__slots__ if key != '_backend'} - try: - pickle.dump(data, open(filename, 'wb'), protocol=pickle.HIGHEST_PROTOCOL) - except (FileExistsError, TypeError): - pickle.dump(data, open(filename, 'wb'), protocol=pickle.HIGHEST_PROTOCOL) + import pickle + + data = {key: getattr(self, key) for key in self.__slots__ if key != '_backend'} + try: + pickle.dump(data, open(filename, 'wb'), protocol=pickle.HIGHEST_PROTOCOL) + except (FileExistsError, TypeError): + pickle.dump(data, open(filename, 'wb'), protocol=pickle.HIGHEST_PROTOCOL) + + else: + from pyrates.utility.filestorage import FILEIOMODES + ValueError(f"Unknown file format to save to. Allowed modes: {FILEIOMODES}") @classmethod def from_file(cls, filename: str): diff --git a/pyrates/utility/filestorage.py b/pyrates/utility/filestorage.py index b338462b..52cdf5ad 100755 --- a/pyrates/utility/filestorage.py +++ b/pyrates/utility/filestorage.py @@ -44,6 +44,8 @@ __author__ = "Daniel Rose" __status__ = "Development" +FILEIOMODES = ["pickle"] + # TODO: Update documentations & clean functions from unnecessary comments (i.e. silent code) @@ -316,3 +318,9 @@ def read_simulation_data_from_file(dirname: str, path="", filenames: list = None raise return data + + +def to_pickle(obj, filename): + """Conserve a PyRates object as pickle.""" + + diff --git a/tests/test_file_io.py b/tests/test_file_io.py new file mode 100644 index 00000000..ee95f739 --- /dev/null +++ b/tests/test_file_io.py @@ -0,0 +1,37 @@ +"""Test suit for saving PyRates objects to or loading them from file.""" + +__author__ = "Daniel Rose" +__status__ = "Development" + +import pytest + + +def setup_module(): + print("\n") + print("=======================") + print("| Test Suite File I/O |") + print("=======================") + + +@pytest.mark.xfail +def test_save_to_pickle(): + pass + + path = "model_templates.jansen_rit.circuit.JansenRitCircuit" + from pyrates.frontend.template.circuit import CircuitTemplate + from pyrates.ir.circuit import CircuitIR + + template = CircuitTemplate.from_yaml(path) + + circuit = template.apply() # type: CircuitIR + # circuit.to_file(filename="resources/jansen_rit.p", mode="pickle") + circuit.optimize_graph_in_place() + # circuit.to_file(filename="resources/jansen_rit_vectorized.p", mode="pickle") + + import pickle + pickle.dump(circuit, open("resources/jansen_rit.p", "wb")) + + +@pytest.mark.skip +def test_load_from_pickle(): + pass \ No newline at end of file From fd70c4d50a60f75af1db560773a8962fb7b63852 Mon Sep 17 00:00:00 2001 From: anguisterrenis Date: Sun, 19 Apr 2020 15:46:49 +0200 Subject: [PATCH 02/28] moved function to load template from yaml from template base class to template module and added "from_file" function to template module --- pyrates/frontend/__init__.py | 2 +- pyrates/frontend/file.py | 2 +- pyrates/frontend/fileio/__init__.py | 0 pyrates/frontend/{ => fileio}/yaml.py | 5 - pyrates/frontend/nxgraph.py | 271 ------------------- pyrates/frontend/template/__init__.py | 92 ++++++- pyrates/frontend/template/_io.py | 11 + pyrates/frontend/template/abc.py | 51 +--- pyrates/frontend/template/circuit/circuit.py | 8 +- pyrates/frontend/template/operator_graph.py | 3 +- tests/test_frontend_yaml_parser.py | 98 ++----- 11 files changed, 133 insertions(+), 410 deletions(-) create mode 100644 pyrates/frontend/fileio/__init__.py rename pyrates/frontend/{ => fileio}/yaml.py (96%) mode change 100755 => 100644 delete mode 100755 pyrates/frontend/nxgraph.py create mode 100644 pyrates/frontend/template/_io.py diff --git a/pyrates/frontend/__init__.py b/pyrates/frontend/__init__.py index 2070c7df..709f3f22 100755 --- a/pyrates/frontend/__init__.py +++ b/pyrates/frontend/__init__.py @@ -36,7 +36,7 @@ # template-based interface from pyrates.frontend import template # from pyrates.frontend import dict as dict_ -from pyrates.frontend import yaml +from pyrates.frontend.fileio import yaml # from pyrates.frontend import nxgraph from pyrates.frontend.template import CircuitTemplate, NodeTemplate, EdgeTemplate, OperatorTemplate diff --git a/pyrates/frontend/file.py b/pyrates/frontend/file.py index a8dc3270..a06de11e 100755 --- a/pyrates/frontend/file.py +++ b/pyrates/frontend/file.py @@ -31,7 +31,7 @@ import importlib from pyrates import PyRatesException -from pyrates.frontend import yaml as _yaml +from pyrates.frontend.fileio import yaml as _yaml __author__ = "Daniel Rose" __status__ = "Development" diff --git a/pyrates/frontend/fileio/__init__.py b/pyrates/frontend/fileio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pyrates/frontend/yaml.py b/pyrates/frontend/fileio/yaml.py old mode 100755 new mode 100644 similarity index 96% rename from pyrates/frontend/yaml.py rename to pyrates/frontend/fileio/yaml.py index 731bbea0..287bab77 --- a/pyrates/frontend/yaml.py +++ b/pyrates/frontend/fileio/yaml.py @@ -94,8 +94,3 @@ def from_circuit(circuit, path: str, name: str): from pathlib import Path path = Path(path) yaml.dump(dict_repr, path) - - -@register_interface -def to_template(path: str, template_cls): - return template_cls.from_yaml(path) diff --git a/pyrates/frontend/nxgraph.py b/pyrates/frontend/nxgraph.py deleted file mode 100755 index ec1d832b..00000000 --- a/pyrates/frontend/nxgraph.py +++ /dev/null @@ -1,271 +0,0 @@ - -# -*- coding: utf-8 -*- -# -# -# PyRates software framework for flexible implementation of neural -# network models and simulations. See also: -# https://github.com/pyrates-neuroscience/PyRates -# -# Copyright (C) 2017-2018 the original authors (Richard Gast and -# Daniel Rose), the Max-Planck-Institute for Human Cognitive Brain -# Sciences ("MPI CBS") and contributors -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see -# -# CITATION: -# -# Richard Gast and Daniel Rose et. al. in preparation -""" -""" -from copy import deepcopy - -import networkx as nx -from networkx import MultiDiGraph, DiGraph, find_cycle, NetworkXNoCycle - -from pyrates import PyRatesException -from pyrates.ir.edge import EdgeIR -from pyrates.ir.circuit import CircuitIR -from pyrates.frontend.dict import to_node -from pyrates.frontend._registry import register_interface - - -__author__ = "Daniel Rose" -__status__ = "Development" - - -# @register_interface -# def to_circuit(graph: nx.MultiDiGraph, label="circuit", -# node_creator=to_node): -# """Create a CircuitIR instance out of a networkx.MultiDiGraph""" -# -# circuit = CircuitIR(label) -# -# for name, data in graph.nodes(data=True): -# -# circuit.add_node(name, node=node_creator(data)) -# -# required_keys = ["source_var", "target_var", "weight", "delay"] -# for source, target, data in graph.edges(data=True): -# -# if all([key in data for key in required_keys]): -# if "edge_ir" not in data: -# data["edge_ir"] = EdgeIR() -# source_var = data.pop("source_var") -# target_var = data.pop("target_var") -# circuit.add_edge(f"{source}/{source_var}", f"{target}/{target_var}", **data) -# else: -# raise KeyError(f"Missing a key out of {required_keys} in an edge with source `{source}` and target" -# f"`{target}`") -# -# return circuit -# -# -# def from_circuit(circuit, revert_node_names=False): -# """Old implementation that transforms all information in a circuit to a networkx.MultiDiGraph with a few additional -# transformations that the old backend needed.""" -# return Circuit2NetDef.network_def(circuit, revert_node_names) -# -# -# class Circuit2NetDef: -# # label_counter = {} # type: Dict[str, int] -# -# @classmethod -# def network_def(cls, circuit: CircuitIR, revert_node_names=False): -# """A bit of a workaround to connect interfaces of frontend and backend. -# TODO: Remove BackendIRFormatter and adapt corresponding tests""" -# # import re -# -# network_def = MultiDiGraph() -# -# edge_list = [] -# node_dict = {} -# -# # reorganize node to conform with backend API -# ############################################# -# for node_key, data in circuit.graph.nodes(data=True): -# node = data["node"] -# # reformat all node internals into operators + operator_args -# if revert_node_names: -# names = node_key.split("/") -# node_key = ".".join(reversed(names)) -# node_dict[node_key] = {} # type: Dict[str, Union[list, dict]] -# node_dict[node_key] = dict(cls._nd_reformat_operators(node.op_graph)) -# op_order = cls._nd_get_operator_order(node.op_graph) # type: list -# # noinspection PyTypeChecker -# node_dict[node_key]["operator_order"] = op_order -# -# # reorganize edge to conform with backend API -# ############################################# -# for source, target, data in circuit.graph.edges(data=True): -# # move edge operators to node -# if revert_node_names: -# source = ".".join(reversed(source.split("/"))) -# target = ".".join(reversed(target.split("/"))) -# node_dict[target], edge = cls._move_edge_ops_to_node(target, node_dict[target], data) -# -# edge_list.append((source, target, dict(**edge))) -# -# # network_def.add_nodes_from(node_dict) -# for key, node in node_dict.items(): -# network_def.add_node(key, **node) -# network_def.add_edges_from(edge_list) -# -# return network_def # return MultiDiGraph as needed by ComputeGraph class -# -# @staticmethod -# def _nd_reformat_operators(op_graph: DiGraph): -# operator_args = dict() -# operators = dict() -# -# for op_key, op_dict in op_graph.nodes(data=True): -# op_cp = deepcopy(op_dict) # duplicate operator info -# var_dict = op_cp["operator"].variables -# for var_key, var_props in var_dict.items(): -# operator_args[f"{op_key}/{var_key}"] = var_props -# -# op_cp["equations"] = op_cp["operator"].equations -# op_cp["inputs"] = op_cp["operator"].inputs -# op_cp["output"] = op_cp["operator"].output -# # op_cp.pop("values", None) -# op_cp.pop("operator", None) -# operators[op_key] = op_cp -# -# reformatted = dict(operator_args=operator_args, -# operators=operators, -# inputs={}) -# return reformatted -# -# @staticmethod -# def _nd_get_operator_order(op_graph: DiGraph) -> list: -# """ -# -# Parameters -# ---------- -# op_graph -# -# Returns -# ------- -# op_order -# """ -# # check, if cycles are present in operator graph (which would be problematic -# try: -# find_cycle(op_graph) -# except NetworkXNoCycle: -# pass -# else: -# raise PyRatesException("Found cyclic operator graph. Cycles are not allowed for operators within one node.") -# -# op_order = [] -# graph = op_graph.copy() # type: DiGraph -# while graph.nodes: -# # noinspection PyTypeChecker -# primary_nodes = [node for node, in_degree in graph.in_degree if in_degree == 0] -# op_order.extend(primary_nodes) -# graph.remove_nodes_from(primary_nodes) -# -# return op_order -# -# @classmethod -# def _move_edge_ops_to_node(cls, target, node_dict: dict, edge_dict: dict) -> (dict, dict): -# """ -# -# Parameters -# ---------- -# target -# Key identifying target node in backend graph -# node_dict -# Dictionary of target node (to move operators into) -# edge_dict -# Dictionary with edge properties (to move operators from) -# Returns -# ------- -# node_dict -# Updated dictionary of target node -# edge_dict -# Dictionary of reformatted edge -# """ -# # grab all edge variables -# edge = edge_dict["edge_ir"] # type: EdgeIR -# source_var = edge_dict["source_var"] -# target_var = edge_dict["target_var"] -# weight = edge_dict["weight"] -# delay = edge_dict["delay"] -# input_var = edge.input -# output_var = edge.output -# -# if len(edge.op_graph) > 0: -# # reformat all edge internals into operators + operator_args -# op_data = cls._nd_reformat_operators(edge.op_graph) # type: dict -# op_order = cls._nd_get_operator_order(edge.op_graph) # type: List[str] -# operators = op_data["operators"] -# operator_args = op_data["operator_args"] -# -# # operator keys refer to a unique combination of template names and changed values -# -# # add operators to target node in reverse order, so they can be safely prepended -# added_ops = False -# for op_name in reversed(op_order): -# # check if operator name is already known in target node -# if op_name in node_dict["operators"]: -# pass -# else: -# added_ops = True -# # this should all go smoothly, because operator should not be known yet -# # add operator dict to target node operators -# node_dict["operators"][op_name] = operators[op_name] -# # prepend operator to op_order -# node_dict["operator_order"].insert(0, op_name) -# # ToDo: consider using collections.deque instead -# # add operator args to target node -# node_dict["operator_args"].update(operator_args) -# -# out_op = op_order[-1] -# out_var = operators[out_op]['output'] -# if added_ops: -# # append operator output to target operator sources -# # assume that only last operator in edge operator_order gives the output -# # for op_name in node_dict["operators"]: -# # if out_var in node_dict["operators"][op_name]["inputs"]: -# # if out_var_long not in node_dict["operators"][op_name]["inputs"][out_var]: -# # # add reference to source operator that was previously in an edge -# # node_dict["operators"][op_name]["inputs"][out_var].append(output_var) -# -# # shortcut, since target_var and output_var are known: -# target_op, target_vname = target_var.split("/") -# if output_var not in node_dict["operators"][target_op]["inputs"][target_vname]["sources"]: -# node_dict["operators"][target_op]["inputs"][target_vname]["sources"].append(out_op) -# -# # simplify edges and save into edge_list -# # op_graph = edge.op_graph -# # in_ops = [op for op, in_degree in op_graph.in_degree if in_degree == 0] -# # if len(in_ops) == 1: -# # # simple case: only one input operator? then it's the first in the operator order. -# # target_op = op_order[0] -# # target_inputs = operators[target_op]["inputs"] -# # if len(target_var) != 1: -# # raise PyRatesException("Either too many or too few input variables detected. " -# # "Needs to be exactly one.") -# # target_var = list(target_inputs.keys())[0] -# # target_var = f"{target_op}/{target_var}" -# # else: -# # raise NotImplementedError("Transforming an edge with multiple input operators is not yet handled.") -# -# # shortcut to new target war: -# target_var = input_var -# edge_dict = {"source_var": source_var, -# "target_var": target_var, -# "weight": weight, -# "delay": delay} -# # set target_var to singular input of last operator added -# return node_dict, edge_dict \ No newline at end of file diff --git a/pyrates/frontend/template/__init__.py b/pyrates/frontend/template/__init__.py index 0f9b07c8..e4b31f30 100755 --- a/pyrates/frontend/template/__init__.py +++ b/pyrates/frontend/template/__init__.py @@ -25,13 +25,103 @@ # CITATION: # # Richard Gast and Daniel Rose et. al. in preparation - +from ._io import _complete_template_path from .node import NodeTemplate from .operator import OperatorTemplate from .edge import EdgeTemplate from .circuit import CircuitTemplate from pyrates.frontend._registry import register_interface +known_template_classes = dict() + +template_cache = dict() + + +def register_template_class(name, cls): + """Register a given template class to the module attribute `_known_template_classes`. This way new template classes + can be registered by users. Could also be used to overwrite existing template classes.""" + + if name in known_template_classes: + raise UserWarning(f"Overwriting existing map from name `{name}` to template class `{cls}`.") + + known_template_classes[name] = cls + + +register_template_class("OperatorTemplate", OperatorTemplate) +register_template_class("NodeTemplate", NodeTemplate) +register_template_class("EdgeTemplate", EdgeTemplate) +register_template_class("CircuitTemplate", CircuitTemplate) + + +@register_interface +def from_file(path: str, mode: str = "yaml"): + """Generic file loader function that looks for correct template class""" + + if mode == "yaml": + loader = from_yaml + + else: + raise ValueError(f"Unknown file loading mode '{mode}'.") + + return loader(path) + + +@register_interface +def from_yaml(path): + """Load template from yaml file. Templates are cached by path. Depending on the 'base' key of the yaml template, + either a template class is instantiated or the function recursively loads base templates until it hits a known + template class. + + Parameters: + ----------- + path + Path to template in YAML file of form 'directories.file.template' + """ + + if path in template_cache: + # if we have loaded this template in the past, return what has been cached + template = template_cache[path] + else: + # if it has not been cached yet, load the file and parse into dict + from pyrates.frontend.fileio.yaml import to_dict + template_dict = to_dict(path) + + try: + base = template_dict.pop("base") + except KeyError: + raise KeyError(f"No 'base' defined for template {path}. Please define a " + f"base to derive the template from.") + + # figure out which type of template this is by analysing the "base" key + try: + # If the base key coincides with any known template class name, fetch the class + cls = known_template_classes[base] + + except KeyError: + # class not known, so the base must refer to a parent template. Then let's recursively load that one until + # we hit a known template class. + base = _complete_template_path(base, path) + + base_template = from_yaml(base) + template = base_template.update_template(**template_dict) + # may fail if "base" is present but empty + else: + # instantiate template class + template = cls(**template_dict) + + template_cache[path] = template + + return template + + +def clear_cache(): + """Shorthand to clear template cache for whatever reason.""" + template_cache.clear() + + +def _select_template_class(): + pass + # module-lvl functions for template conversion # writing them out explicitly diff --git a/pyrates/frontend/template/_io.py b/pyrates/frontend/template/_io.py new file mode 100644 index 00000000..161b00b9 --- /dev/null +++ b/pyrates/frontend/template/_io.py @@ -0,0 +1,11 @@ +def _complete_template_path(target_path: str, source_path: str) -> str: + """Check if path contains a folder structure and prepend own path, if it doesn't""" + + if "." not in target_path: + if "/" in source_path or "\\" in source_path: + import os + basedir, _ = os.path.split(source_path) + target_path = os.path.normpath(os.path.join(basedir, target_path)) + else: + target_path = ".".join((*source_path.split('.')[:-1], target_path)) + return target_path \ No newline at end of file diff --git a/pyrates/frontend/template/abc.py b/pyrates/frontend/template/abc.py index 217951b9..d20bb916 100755 --- a/pyrates/frontend/template/abc.py +++ b/pyrates/frontend/template/abc.py @@ -29,8 +29,6 @@ """ Abstract base classes """ -from pyrates.frontend.yaml import to_dict as dict_from_yaml - __author__ = "Daniel Rose" __status__ = "Development" @@ -50,52 +48,10 @@ def __repr__(self): return f"<{self.__class__.__name__} '{self.path}'>" @staticmethod - def _complete_template_path(target_path: str, source_path: str) -> str: - """Check if path contains a folder structure and prepend own path, if it doesn't""" - - if "." not in target_path: - if "/" in source_path or "\\" in source_path: - import os - basedir, _ = os.path.split(source_path) - target_path = os.path.normpath(os.path.join(basedir, target_path)) - else: - target_path = ".".join((*source_path.split('.')[:-1], target_path)) - return target_path - - @classmethod - def from_yaml(cls, path): - """Convenience method that looks for a loader class for the template type and applies it, assuming - the class naming convention '