Skip to content

Commit

Permalink
Merge pull request #479 from gschaffner/exposed-descriptor-AttributeE…
Browse files Browse the repository at this point in the history
…rror

Fix propagation of AttributeErrors raised by exposed descriptors
  • Loading branch information
comrumino committed Jul 2, 2022
2 parents 9903738 + 146c663 commit dc71f0d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
12 changes: 11 additions & 1 deletion rpyc/core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time # noqa: F401
import gc # noqa: F401

from inspect import getattr_static
from threading import Lock, Condition, RLock
from rpyc.lib import spawn, Timeout, get_methods, get_id_pack
from rpyc.lib.compat import pickle, next, maxint, select_error, acquire_lock # noqa: F401
Expand Down Expand Up @@ -524,6 +525,15 @@ def root(self): # serving
return self._remote_root

def _check_attr(self, obj, name, perm): # attribute access
def hasattr_static(obj, name):
try:
getattr_static(obj, name)
except AttributeError:
return False
else:
return True


config = self._config
if not config[perm]:
raise AttributeError(f"cannot access {name!r}")
Expand All @@ -532,7 +542,7 @@ def _check_attr(self, obj, name, perm): # attribute access
plain |= config["allow_exposed_attrs"] and name.startswith(prefix)
plain |= config["allow_safe_attrs"] and name in config["safe_attrs"]
plain |= config["allow_public_attrs"] and not name.startswith("_")
has_exposed = prefix and hasattr(obj, prefix + name)
has_exposed = prefix and (hasattr(obj, prefix + name) or hasattr_static(obj, prefix + name))
if plain and (not has_exposed or hasattr(obj, name)):
return name
if has_exposed:
Expand Down
39 changes: 39 additions & 0 deletions tests/test_attr_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ def _rpyc_getattr(_, name):
SVC_RESTRICTED = ["exposed_foobar", "__add__", "_privy", "foo", "bar"]


class MyDescriptor1(object):
def __get__(self, instance, owner=None):
raise AttributeError("abcd")


class MyDescriptor2(object):
def __get__(self, instance, owner=None):
raise RuntimeError("efgh")


class MyService(rpyc.Service):
exposed_MyClass = MyClass

Expand All @@ -86,6 +96,10 @@ def exposed_get_two(self):
protector.register(YourClass, ["lala", "baba"])
return protector.wrap(YourClass())

exposed_desc_1 = MyDescriptor1()

exposed_desc_2 = MyDescriptor2()


class TestRestricted(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -152,6 +166,11 @@ def test_default_config(self):
self.assertRaises(AttributeError, lambda: obj.foo)
self.assertRaises(AttributeError, lambda: obj.bar)
self.assertRaises(AttributeError, lambda: obj.spam)
root = self.conn.root
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.exposed_desc_1)
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.desc_1)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.exposed_desc_2)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.desc_2)

def test_allow_all(self):
self._reset_cfg()
Expand All @@ -162,6 +181,11 @@ def test_allow_all(self):
self.assertEqual(obj._privy(), "privy")
self.assertEqual(obj.foobar(), "Fee Fie Foe Foo")
self.assertEqual(obj.exposed_foobar(), "Fee Fie Foe Foo")
root = self.conn.root
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.exposed_desc_1)
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.desc_1)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.exposed_desc_2)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.desc_2)

def test_allow_exposed(self):
self._reset_cfg()
Expand All @@ -172,6 +196,11 @@ def test_allow_exposed(self):
except Exception:
passed = True
self.assertEqual(passed, True)
root = self.conn.root
self.assertRaises(AttributeError, lambda: root.exposed_desc_1)
self.assertRaises(AttributeError, lambda: root.desc_1)
self.assertRaises(AttributeError, lambda: root.exposed_desc_2)
self.assertRaises(AttributeError, lambda: root.desc_2)

def test_allow_safe_attrs(self):
self._reset_cfg()
Expand All @@ -184,6 +213,11 @@ def test_allow_safe_attrs(self):
self.assertRaises(AttributeError, lambda: obj.foo)
self.assertRaises(AttributeError, lambda: obj.bar)
self.assertRaises(AttributeError, lambda: obj.spam)
root = self.conn.root
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.exposed_desc_1)
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.desc_1)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.exposed_desc_2)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.desc_2)

def test_allow_public_attrs(self):
self._reset_cfg()
Expand All @@ -195,6 +229,11 @@ def test_allow_public_attrs(self):
self.assertEqual(obj.foobar(), "Fee Fie Foe Foo")
self.assertEqual(obj.exposed_foobar(), "Fee Fie Foe Foo")
self.assertRaises(AttributeError, lambda: obj._privy)
root = self.conn.root
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.exposed_desc_1)
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.desc_1)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.exposed_desc_2)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.desc_2)

# def test_type_protector(self):
# obj = self.conn.root.get_two()
Expand Down

0 comments on commit dc71f0d

Please sign in to comment.