Skip to content

Commit

Permalink
Separated the concerns of unit testing descriptor stack trace behavio…
Browse files Browse the repository at this point in the history
…r; expanded rpyc.service support to decorators
  • Loading branch information
comrumino committed Jul 2, 2022
1 parent a36ca14 commit 9213f6a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 55 deletions.
7 changes: 5 additions & 2 deletions rpyc/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

def service(cls):
"""find and rename exposed decorated attributes"""
# NOTE: inspect.getmembers invokes getattr for each attribute-name. Descriptors may raise AttributeError.
# Only the AttributeError exception is caught when raised. This decorator will if a descriptor raises
# any exception other than AttributeError when getattr is called.
for attr_name, attr_obj in inspect.getmembers(cls): # rebind exposed decorated attributes
exposed_prefix = getattr(attr_obj, '__exposed__', False)
if exposed_prefix and not inspect.iscode(attr_obj): # exclude the implementation
Expand All @@ -26,11 +29,11 @@ def exposed(arg):
# When the arg is a string (i.e. `@rpyc.exposed("customPrefix_")`) the prefix
# is partially evaluated into the wrapper. The function returned is "frozen" and used as a decorator.
return functools.partial(_wrapper, arg)
elif hasattr(arg, '__call__'):
elif hasattr(arg, '__call__') or hasattr(arg, '__get__'):
# When the arg is callable (i.e. `@rpyc.exposed`) then use default prefix and invoke
return _wrapper(exposed_prefix, arg)
else:
raise TypeError('rpyc.exposed expects a callable object or a string')
raise TypeError('rpyc.exposed expects a callable object, descriptor, or string')


def _wrapper(exposed_prefix, exposed_obj):
Expand Down
99 changes: 46 additions & 53 deletions tests/test_attr_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,6 @@ def _rpyc_getattr(_, name):
SVC_RESTRICTED = ["exposed_foobar", "__add__", "_privy", "foo", "bar"]


class MyDescriptor1(object):
def __get__(self, instance, owner=None):
raise AttributeError("abcd")


class MyDescriptor2(object):
def __get__(self, instance, owner=None):
raise RuntimeError("efgh")


class MyService(rpyc.Service):
exposed_MyClass = MyClass

Expand All @@ -96,10 +86,6 @@ def exposed_get_two(self):
protector.register(YourClass, ["lala", "baba"])
return protector.wrap(YourClass())

exposed_desc_1 = MyDescriptor1()

exposed_desc_2 = MyDescriptor2()


class TestRestricted(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -166,11 +152,6 @@ def test_default_config(self):
self.assertRaises(AttributeError, lambda: obj.foo)
self.assertRaises(AttributeError, lambda: obj.bar)
self.assertRaises(AttributeError, lambda: obj.spam)
root = self.conn.root
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.exposed_desc_1)
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.desc_1)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.exposed_desc_2)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.desc_2)

def test_allow_all(self):
self._reset_cfg()
Expand All @@ -181,11 +162,6 @@ def test_allow_all(self):
self.assertEqual(obj._privy(), "privy")
self.assertEqual(obj.foobar(), "Fee Fie Foe Foo")
self.assertEqual(obj.exposed_foobar(), "Fee Fie Foe Foo")
root = self.conn.root
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.exposed_desc_1)
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.desc_1)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.exposed_desc_2)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.desc_2)

def test_allow_exposed(self):
self._reset_cfg()
Expand All @@ -196,11 +172,6 @@ def test_allow_exposed(self):
except Exception:
passed = True
self.assertEqual(passed, True)
root = self.conn.root
self.assertRaises(AttributeError, lambda: root.exposed_desc_1)
self.assertRaises(AttributeError, lambda: root.desc_1)
self.assertRaises(AttributeError, lambda: root.exposed_desc_2)
self.assertRaises(AttributeError, lambda: root.desc_2)

def test_allow_safe_attrs(self):
self._reset_cfg()
Expand All @@ -213,11 +184,6 @@ def test_allow_safe_attrs(self):
self.assertRaises(AttributeError, lambda: obj.foo)
self.assertRaises(AttributeError, lambda: obj.bar)
self.assertRaises(AttributeError, lambda: obj.spam)
root = self.conn.root
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.exposed_desc_1)
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.desc_1)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.exposed_desc_2)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.desc_2)

def test_allow_public_attrs(self):
self._reset_cfg()
Expand All @@ -229,31 +195,58 @@ def test_allow_public_attrs(self):
self.assertEqual(obj.foobar(), "Fee Fie Foe Foo")
self.assertEqual(obj.exposed_foobar(), "Fee Fie Foe Foo")
self.assertRaises(AttributeError, lambda: obj._privy)


class MyDescriptor1(object):
def __get__(self, instance, owner=None):
raise AttributeError("abcd")


class MyDescriptor2(object):
def __get__(self, instance, owner=None):
if instance is None:
return self
else:
raise RuntimeError("efgh")


@rpyc.service
class MyDecoratedService(rpyc.Service):
desc_1 = rpyc.exposed(MyDescriptor1())
exposed_desc_2 = MyDescriptor2()


class TestDescriptorErrors(unittest.TestCase):
def setUp(self):
self.cfg = copy.copy(rpyc.core.protocol.DEFAULT_CONFIG)
self.server = ThreadedServer(MyDecoratedService(), port=0)
self.thd = self.server._start_in_thread()
self.conn = rpyc.connect("localhost", self.server.port)

def tearDown(self):
self.conn.close()
while self.server.clients:
pass
self.server.close()
self.thd.join()

def test_default_config(self):
root = self.conn.root
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.exposed_desc_1)
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.desc_1)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.exposed_desc_2)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.desc_2)

# def test_type_protector(self):
# obj = self.conn.root.get_two()
# assert obj.baba() == "baba"
# try:
# obj.gaga()
# except AttributeError:
# pass
# else:
# assert False, "expected an attribute error!"
# obj2 = obj.lala()
# assert obj2.foo() == "foo"
# assert obj2.spam() == "spam"
# try:
# obj.bar()
# except AttributeError:
# pass
# else:
# assert False, "expected an attribute error!"
#
def test_allow_all(self):
self.cfg['allow_all_attrs'] = True
self.conn.close()
self.server.protocol_config.update(self.cfg)
self.conn = rpyc.connect("localhost", self.server.port)
root = self.conn.root
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.exposed_desc_1)
self.assertRaisesRegex(AttributeError, "abcd", lambda: root.desc_1)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.exposed_desc_2)
self.assertRaisesRegex(RuntimeError, "efgh", lambda: root.desc_2)


if __name__ == "__main__":
Expand Down

0 comments on commit 9213f6a

Please sign in to comment.