diff --git a/gssapi/_utils.py b/gssapi/_utils.py index 67e1cf68..e5ddfb8e 100644 --- a/gssapi/_utils.py +++ b/gssapi/_utils.py @@ -1,6 +1,10 @@ import sys +import types import six +import decorator as deco + +from gssapi.raw.misc import GSSError def import_gssapi_extension(name): @@ -58,3 +62,70 @@ def enc(x): return x return dict((enc(k), enc(v)) for k, v in six.iteritems(d)) + + +# in case of Python 3, just use exception chaining +@deco.decorator +def catch_and_return_token(func, self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except GSSError as e: + if e.token is not None and self.__DEFER_STEP_ERRORS__: + self._last_err = e + # skip the "return func" line above in the traceback + if six.PY2: + self._last_tb = sys.exc_info()[2].tb_next.tb_next + else: + self._last_err.__traceback__ = e.__traceback__.tb_next + + return e.token + else: + raise + + +@deco.decorator +def check_last_err(func, self, *args, **kwargs): + if self._last_err is not None: + try: + if six.PY2: + six.reraise(type(self._last_err), self._last_err, + self._last_tb) + else: + # NB(directxman12): not using six.reraise in Python 3 leads + # to cleaner tracebacks, and raise x is valid + # syntax in Python 3 (unlike raise x, y, z) + raise self._last_err + finally: + if six.PY2: + del self._last_tb # in case of cycles, break glass + + self._last_err = None + else: + return func(self, *args, **kwargs) + + @deco.decorator + def check_last_err(func, self, *args, **kwargs): + if self._last_err is not None: + try: + raise self._last_err + finally: + self._last_err = None + else: + return func(self, *args, **kwargs) + + +class CheckLastError(type): + def __new__(cls, name, parents, attrs): + attrs['__DEFER_STEP_ERRORS__'] = True + + for attr_name in attrs: + attr = attrs[attr_name] + + # wrap only methods + if not isinstance(attr, types.FunctionType): + continue + + if attr_name[0] != '_': + attrs[attr_name] = check_last_err(attr) + + return super(CheckLastError, cls).__new__(cls, name, parents, attrs) diff --git a/gssapi/raw/chan_bindings.pyx b/gssapi/raw/chan_bindings.pyx index f18e6200..8db51388 100644 --- a/gssapi/raw/chan_bindings.pyx +++ b/gssapi/raw/chan_bindings.pyx @@ -1,4 +1,4 @@ -from libc.stdlib cimport malloc, free +from libc.stdlib cimport calloc, free from gssapi.raw.cython_types cimport * @@ -25,34 +25,24 @@ cdef class ChannelBindings: cdef gss_channel_bindings_t __cvalue__(ChannelBindings self) except NULL: cdef gss_channel_bindings_t res - res = malloc(sizeof(res[0])) + res = calloc(1, sizeof(res[0])) - if self.initiator_address_type is None: - res.initiator_addrtype = GSS_C_AF_NULLADDR - else: + if self.initiator_address_type is not None: res.initiator_addrtype = self.initiator_address_type if self.initiator_address is not None: res.initiator_address.value = self.initiator_address res.initiator_address.length = len(self.initiator_address) - else: - res.initiator_address.length = 0 - if self.acceptor_address_type is None: - res.acceptor_addrtype = GSS_C_AF_NULLADDR - else: + if self.acceptor_address_type is not None: res.acceptor_addrtype = self.acceptor_address_type if self.acceptor_address is not None: res.acceptor_address.value = self.acceptor_address res.acceptor_address.length = len(self.acceptor_address) - else: - res.acceptor_address.length = 0 if self.application_data is not None: res.application_data.value = self.application_data res.application_data.length = len(self.application_data) - else: - res.application_data.length = 0 return res diff --git a/gssapi/sec_contexts.py b/gssapi/sec_contexts.py index afa3925e..c6a30332 100644 --- a/gssapi/sec_contexts.py +++ b/gssapi/sec_contexts.py @@ -1,3 +1,5 @@ +import six + from gssapi.raw import sec_contexts as rsec_contexts from gssapi.raw import message as rmessage from gssapi.raw import named_tuples as tuples @@ -9,6 +11,7 @@ from gssapi.creds import Credentials +@six.add_metaclass(_utils.CheckLastError) class SecurityContext(rsec_contexts.SecurityContext): # TODO(directxman12): do we want to use __slots__ here? def __new__(cls, base=None, token=None, @@ -24,6 +27,9 @@ def __init__(self, base=None, token=None, name=None, creds=None, desired_lifetime=None, flags=None, mech_type=None, channel_bindings=None, usage=None): + # NB(directxman12): _last_err must be set first + self._last_err = None + # determine the usage ('initiate' vs 'accept') if base is None and token is None: # this will be a new context @@ -74,6 +80,9 @@ def __init__(self, base=None, token=None, else: self.usage = 'accept' + # NB(directxman12): DO NOT ADD AN __del__ TO THIS CLASS -- it screws up + # the garbage collector if _last_tb is still defined + # TODO(directxman12): implement flag properties def get_signature(self, message): @@ -125,6 +134,7 @@ def export(self): INQUIRE_ARGS = ('initiator_name', 'target_name', 'lifetime', 'mech_type', 'flags', 'locally_init', 'complete') + @_utils.check_last_err def _inquire(self, **kwargs): if not kwargs: default_val = True @@ -166,12 +176,14 @@ def lifetime(self): locally_initiated = _utils.inquire_property('locally_init') @property + @_utils.check_last_err def complete(self): if self._started: return self._inquire(complete=True).complete else: return False + @_utils.catch_and_return_token def step(self, token=None): if self.usage == 'accept': return self._acceptor_step(token=token) diff --git a/gssapi/tests/test_high_level.py b/gssapi/tests/test_high_level.py index 36db9edb..a7534d63 100644 --- a/gssapi/tests/test_high_level.py +++ b/gssapi/tests/test_high_level.py @@ -22,6 +22,9 @@ FQDN = socket.getfqdn().encode('utf-8') SERVICE_PRINCIPAL = TARGET_SERVICE_NAME + b'/' + FQDN +# disable error deferring to catch errors immediately +gssctx.SecurityContext.__DEFER_STEP_ERRORS__ = False + class _GSSAPIKerberosTestCase(kt.KerberosTestCase): @classmethod @@ -385,6 +388,7 @@ def test_copy(self): class SecurityContextTestCase(_GSSAPIKerberosTestCase): def setUp(self): super(SecurityContextTestCase, self).setUp() + gssctx.SecurityContext.__DEFER_STEP_ERRORS__ = False self.client_name = gssnames.Name(self.USER_PRINC) self.client_creds = gsscreds.Credentials(desired_name=None, usage='initiate') @@ -577,3 +581,37 @@ def test_verify_signature_raise(self): server_ctx.verify_signature.should_raise(gb.GSSError, b'other message', mic_token) + + def test_defer_step_error_on_method(self): + gssctx.SecurityContext.__DEFER_STEP_ERRORS__ = True + bdgs = gb.ChannelBindings(application_data=b'abcxyz') + client_ctx = self._create_client_ctx(desired_lifetime=400, + channel_bindings=bdgs) + + client_token = client_ctx.step() + client_token.should_be_a(bytes) + + bdgs.application_data = b'defuvw' + server_ctx = gssctx.SecurityContext(creds=self.server_creds, + channel_bindings=bdgs) + server_ctx.step(client_token).should_be_a(bytes) + server_ctx.encrypt.should_raise(gb.BadChannelBindingsError, b'test') + + def test_defer_step_error_on_complete_property_access(self): + gssctx.SecurityContext.__DEFER_STEP_ERRORS__ = True + bdgs = gb.ChannelBindings(application_data=b'abcxyz') + client_ctx = self._create_client_ctx(desired_lifetime=400, + channel_bindings=bdgs) + + client_token = client_ctx.step() + client_token.should_be_a(bytes) + + bdgs.application_data = b'defuvw' + server_ctx = gssctx.SecurityContext(creds=self.server_creds, + channel_bindings=bdgs) + server_ctx.step(client_token).should_be_a(bytes) + + def check_complete(): + return server_ctx.complete + + check_complete.should_raise(gb.BadChannelBindingsError) diff --git a/setup.py b/setup.py index 82af3731..bc25640d 100755 --- a/setup.py +++ b/setup.py @@ -149,7 +149,8 @@ def gssapi_modules(lst): extension_file('rfc5588', 'gss_store_cred'), ]), install_requires=[ - 'enum34' + 'enum34', + 'decorator' ], tests_require=[ 'tox'