From 169f287381d3ffba63d1bc1564fffbdf5b4c3e6a Mon Sep 17 00:00:00 2001 From: PythonFZ Date: Thu, 29 Jun 2023 11:04:17 +0200 Subject: [PATCH 1/2] add custom exceptions --- tests/test_err_msg.py | 20 ++++++++++++++++++++ znflow/__init__.py | 2 ++ znflow/base.py | 10 ++++++++++ znflow/exceptions.py | 4 ++++ 4 files changed, 36 insertions(+) create mode 100644 tests/test_err_msg.py create mode 100644 znflow/exceptions.py diff --git a/tests/test_err_msg.py b/tests/test_err_msg.py new file mode 100644 index 0000000..4fad7ed --- /dev/null +++ b/tests/test_err_msg.py @@ -0,0 +1,20 @@ +import znflow +import dataclasses +import pytest + +@dataclasses.dataclass +class ComputeMean(znflow.Node): + x: float + y: float + + results: float = None + + def run(self): + self.results = (self.x + self.y) / 2 + + +def test_attribute_access(): + with znflow.DiGraph() as graph: + n1 = ComputeMean(2, 8) + with pytest.raises(znflow.exceptions.ConnectionAttributeError): + n1.x.data diff --git a/znflow/__init__.py b/znflow/__init__.py index 533d87e..2c621cc 100644 --- a/znflow/__init__.py +++ b/znflow/__init__.py @@ -16,6 +16,7 @@ from znflow.graph import DiGraph from znflow.node import Node, nodify from znflow.visualize import draw +from znflow import exceptions __version__ = importlib.metadata.version(__name__) @@ -31,6 +32,7 @@ "Property", "CombinedConnections", "combine", + "exceptions", ] with contextlib.suppress(ImportError): diff --git a/znflow/base.py b/znflow/base.py index 5f36dc3..590284f 100644 --- a/znflow/base.py +++ b/znflow/base.py @@ -3,7 +3,9 @@ import contextlib import dataclasses import typing +from typing import Any from uuid import UUID +from znflow import exceptions @contextlib.contextmanager @@ -177,6 +179,14 @@ def result(self): else: result = self.instance return result[self.item] if self.item else result + + def __getattribute__(self, __name: str) -> Any: + try: + return super().__getattribute__(__name) + except AttributeError as e: + raise exceptions.ConnectionAttributeError( + "Connection does not support further attributes to its result." + ) from e @dataclasses.dataclass(frozen=True) diff --git a/znflow/exceptions.py b/znflow/exceptions.py new file mode 100644 index 0000000..e486076 --- /dev/null +++ b/znflow/exceptions.py @@ -0,0 +1,4 @@ +"""ZnFlow exceptions.""" + +class ConnectionAttributeError(AttributeError): + """Raised when a connection attribute is not found.""" From f409e7377c6514394d702a4a847164a3229aa931 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 12 Jul 2023 12:19:22 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_err_msg.py | 11 +++++++---- znflow/__init__.py | 2 +- znflow/base.py | 3 ++- znflow/exceptions.py | 1 + 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/test_err_msg.py b/tests/test_err_msg.py index 4fad7ed..57115b6 100644 --- a/tests/test_err_msg.py +++ b/tests/test_err_msg.py @@ -1,20 +1,23 @@ -import znflow import dataclasses + import pytest +import znflow + + @dataclasses.dataclass class ComputeMean(znflow.Node): x: float y: float - + results: float = None - + def run(self): self.results = (self.x + self.y) / 2 def test_attribute_access(): - with znflow.DiGraph() as graph: + with znflow.DiGraph(): n1 = ComputeMean(2, 8) with pytest.raises(znflow.exceptions.ConnectionAttributeError): n1.x.data diff --git a/znflow/__init__.py b/znflow/__init__.py index 2c621cc..5d3c8f8 100644 --- a/znflow/__init__.py +++ b/znflow/__init__.py @@ -4,6 +4,7 @@ import logging import sys +from znflow import exceptions from znflow.base import ( CombinedConnections, Connection, @@ -16,7 +17,6 @@ from znflow.graph import DiGraph from znflow.node import Node, nodify from znflow.visualize import draw -from znflow import exceptions __version__ = importlib.metadata.version(__name__) diff --git a/znflow/base.py b/znflow/base.py index 4776a16..6631965 100644 --- a/znflow/base.py +++ b/znflow/base.py @@ -5,6 +5,7 @@ import typing from typing import Any from uuid import UUID + from znflow import exceptions @@ -184,7 +185,7 @@ def result(self): else: result = self.instance return result[self.item] if self.item else result - + def __getattribute__(self, __name: str) -> Any: try: return super().__getattribute__(__name) diff --git a/znflow/exceptions.py b/znflow/exceptions.py index e486076..227fc47 100644 --- a/znflow/exceptions.py +++ b/znflow/exceptions.py @@ -1,4 +1,5 @@ """ZnFlow exceptions.""" + class ConnectionAttributeError(AttributeError): """Raised when a connection attribute is not found."""