/
__init__.py
224 lines (183 loc) · 8.54 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
from typing import Any, Callable, List, Optional, Set
from typeguard import check_type
import sys
import typing
import inspect
import wrapt
def check_all_contracts(*args, decorate_main=True) -> None:
    """Automatically check contracts for all functions and classes in the given modules.

    Each positional argument is the name of a module looked up in ``sys.modules``;
    when ``decorate_main`` is true, ``__main__`` is processed as well. Every function
    defined in one of these modules is replaced, in the module's namespace, by a
    check_contracts wrapper, and every class defined there is passed to
    add_class_invariants.

    Functions and classes that a module merely imports from elsewhere are left
    untouched. Names that do not correspond to a loaded module are ignored.
    """
    modules = []
    if decorate_main:
        modules.append(sys.modules.get('__main__'))
    modules.extend(sys.modules.get(module_name) for module_name in args)

    for module in modules:
        if module is None:
            # Module name was passed in incorrectly.
            continue
        module_name = getattr(module, '__name__', None)
        for name, value in inspect.getmembers(module):
            # getmembers also returns names the module imported. Those belong to
            # other modules, and built-in classes (e.g. collections.OrderedDict)
            # cannot even have attributes set on them, so skip them.
            if getattr(value, '__module__', None) != module_name:
                continue
            if inspect.isfunction(value):
                module.__dict__[name] = check_contracts(value)
            elif inspect.isclass(value):
                add_class_invariants(value)
@wrapt.decorator
def check_contracts(wrapped, instance, args, kwargs):
    """A decorator for automatically checking preconditions and type contracts for a function.

    ``wrapped``, ``instance``, ``args`` and ``kwargs`` are supplied by wrapt.decorator.
    When ``instance`` is a class (the class method case), no instance is forwarded
    to _check_function_contracts. An AssertionError raised while checking is
    re-raised with the same message but without its chained context.
    """
    try:
        called_as_classmethod = bool(instance) and inspect.isclass(instance)
        receiver = None if called_as_classmethod else instance
        return _check_function_contracts(wrapped, receiver, args, kwargs)
    except AssertionError as failure:
        raise AssertionError(str(failure)) from None
def add_class_invariants(klass: type) -> None:
    """Modify the given class to check representation invariants and method contracts.

    Representation invariants are parsed (via parse_assertions with the
    "Representation Invariant" token) from the docstrings of klass and of every
    class in its MRO except ``object``. A class that already has its own
    ``__representation_invariants__`` contributes that set instead of its docstring.
    The combined set is stored as ``klass.__representation_invariants__``.

    Every routine defined directly in klass is then wrapped: static and class
    methods with check_contracts, all others with _instance_method_wrapper.
    ``klass.__setattr__`` is replaced so that assignments are type-checked against
    the class annotations and, when the assignment does not come from inside one
    of the instance's own methods, the invariants are re-checked.

    Does nothing if klass already has its own ``__representation_invariants__``.
    """
    if '__representation_invariants__' in klass.__dict__:
        # This means the class has already been decorated
        return

    # Update representation invariants from this class' docstring and those of its superclasses.
    rep_invariants = set()
    # Iterate over all inherited classes except object
    for cls in klass.__mro__[:-1]:
        if '__representation_invariants__' in cls.__dict__:
            rep_invariants.update(cls.__representation_invariants__)
        else:
            rep_invariants.update(parse_assertions(cls.__doc__ or '', parse_token='Representation Invariant'))

    setattr(klass, '__representation_invariants__', rep_invariants)

    def invariant_globals() -> dict:
        """Return the global namespace in which the invariants are evaluated."""
        init = getattr(klass, '__init__')
        global_scope = getattr(init, '__globals__', None)
        if global_scope is None:
            # object.__init__ and other C-level initialisers have no __globals__;
            # fall back to the namespace of the module that defines klass.
            module = sys.modules.get(klass.__module__)
            global_scope = vars(module) if module is not None else {}
        return global_scope

    def new_setattr(self: klass, name: str, value: Any) -> None:
        """Set the value of the given attribute on self to the given value.

        The value is first checked against the class annotation for ``name`` (if any).
        Representation invariants are checked afterwards, unless the caller's local
        ``self`` is this object, i.e. the assignment is made inside one of its methods.

        Raises:
            AssertionError: if the value does not match its annotation, or if an
                invariant is violated.
        """
        cls_annotations = typing.get_type_hints(klass)
        if name in cls_annotations:
            try:
                check_type(name, value, cls_annotations[name])
            except TypeError:
                raise AssertionError(
                    f'{repr(value)} did not match type annotation for attribute "{name}: {cls_annotations[name]}"')

        super(klass, self).__setattr__(name, value)

        # Only the immediate caller's frame is needed. Using f_back avoids
        # inspect.getouterframes, which builds a FrameInfo (including source
        # context) for every frame on the stack on every assignment.
        frame = inspect.currentframe()
        caller = frame.f_back if frame is not None else None
        caller_locals = caller.f_locals if caller is not None else {}
        # Drop frame references promptly to avoid reference cycles.
        del frame, caller
        if self is not caller_locals.get('self'):
            # Only validating if the attribute is not being set in a instance/class method
            try:
                _check_invariants(self, rep_invariants, invariant_globals())
            except AssertionError as e:
                raise AssertionError(str(e)) from None

    # Iterate over a snapshot, because the loop rebinds entries of klass.__dict__.
    for attr, value in list(klass.__dict__.items()):
        if inspect.isroutine(value):
            if isinstance(value, (staticmethod, classmethod)):
                # Don't check rep invariants for staticmethod and classmethod
                setattr(klass, attr, check_contracts(value))
            else:
                setattr(klass, attr, _instance_method_wrapper(value, rep_invariants))

    klass.__setattr__ = new_setattr
def _check_function_contracts(wrapped, instance, args, kwargs):
    """Check wrapped's argument types and preconditions, call it, then check its return type.

    Parameters:
        wrapped: the callable being checked; its ``__code__`` supplies the parameter names.
        instance: the object bound to wrapped's first parameter (e.g. ``self``), or None.
        args: positional arguments for wrapped, not including ``instance``.
        kwargs: keyword arguments for wrapped.

    Returns:
        The value returned by ``wrapped(*args, **kwargs)``.

    Raises:
        AssertionError: if an argument or the return value does not match its type
            annotation, or if a docstring precondition evaluates to a false value.
    """
    code = wrapped.__code__
    params = code.co_varnames[:code.co_argcount]
    # co_varnames lists the positional parameters first, then the keyword-only ones.
    keyword_params = code.co_varnames[:code.co_argcount + code.co_kwonlyargcount]
    annotations = typing.get_type_hints(wrapped)

    # Rebuild the full positional argument tuple, including the receiver.
    # Compare with None rather than testing truthiness, so that falsy instances
    # (e.g. objects whose __len__ returns 0) still line up with ``self``.
    if instance is not None:
        args_with_self = (instance,) + tuple(args)
    elif inspect.ismethod(wrapped):
        # A bound method whose receiver was not passed as ``instance`` (e.g. a
        # classmethod, for which check_contracts passes None). Its first parameter
        # is bound to __self__, which is absent from args but present in params.
        args_with_self = (wrapped.__self__,) + tuple(args)
    else:
        args_with_self = tuple(args)

    function_locals = dict(zip(params, args_with_self))
    # Bind arguments passed by keyword too, so they are type-checked and visible to
    # preconditions. Names that only reach a **kwargs catch-all are skipped.
    for name, value in kwargs.items():
        if name in keyword_params and name not in function_locals:
            function_locals[name] = value

    # Check function parameter types
    for param, arg in function_locals.items():
        if param in annotations:
            try:
                check_type(param, arg, annotations[param])
            except TypeError:
                raise AssertionError(
                    f'{wrapped.__name__} argument {repr(arg)} did not match type annotation '
                    f'for parameter "{param}: {annotations[param]}"')

    # Check function preconditions
    preconditions = parse_assertions(wrapped.__doc__ or '')
    _check_assertions(wrapped, function_locals, preconditions)

    # Check return type
    r = wrapped(*args, **kwargs)
    if 'return' in annotations:
        return_type = annotations['return']
        try:
            check_type('return', r, return_type)
        except TypeError:
            raise AssertionError(
                f'{wrapped.__name__} return value {r} does not match annotated return type {return_type}')
    return r
def _instance_method_wrapper(wrapped, rep_invariants=None):
    """Return wrapped decorated so that each call also checks class-level contracts.

    Each call first checks the method's own contracts via _check_function_contracts
    (argument types, preconditions, return type). It then checks the instance's
    representation invariants ``rep_invariants`` and the type annotations of its
    class. Any AssertionError is re-raised with the same message but without its
    chained context.

    If rep_invariants is None, wrapped is just decorated with check_contracts.
    """
    if rep_invariants is None:
        # Apply the decorator (rather than returning it) so the caller gets back a
        # wrapped method, as in the other branch.
        return check_contracts(wrapped)

    @wrapt.decorator
    def wrapper(wrapped, instance, args, kwargs):
        init = getattr(instance, '__init__')
        global_scope = getattr(init, '__globals__', None)
        if global_scope is None:
            # object.__init__ and other C-level initialisers have no __globals__;
            # fall back to the namespace of the module defining the instance's class.
            module = sys.modules.get(type(instance).__module__)
            global_scope = vars(module) if module is not None else {}
        try:
            r = _check_function_contracts(wrapped, instance, args, kwargs)
            _check_invariants(instance, rep_invariants, global_scope)
            _check_class_type_annotations(instance)
        except AssertionError as e:
            raise AssertionError(str(e)) from None
        else:
            return r

    return wrapper(wrapped)
def _check_class_type_annotations(instance: Any) -> None:
"""Check that the type annotations for the class still hold.
"""
klass = instance.__class__
cls_annotations = typing.get_type_hints(klass)
for attr, annotation in cls_annotations.items():
value = getattr(instance, attr)
try:
check_type(attr, value, annotation)
except TypeError:
raise AssertionError(
f'{repr(value)} did not match type annotation for attribute "{attr}: {annotation}"')
def _check_invariants(instance, rep_invariants: Set[str], global_scope: dict) -> None:
"""Check that the representation invariants for the instance are satisfied.
"""
for invariant in rep_invariants:
try:
check = eval(invariant, global_scope, {'self': instance})
except:
print(f'[python_ta] Warning: could not evaluate invariant: {invariant}', file=sys.stderr)
else:
assert check,\
f'Representation invariant "{invariant}" violated.'
def _check_assertions(wrapped: Callable[..., Any], function_locals: dict, assertions: List[str]) -> None:
"""Check that the given assertions are still satisfied.
"""
for assertion in assertions:
try:
check = eval(assertion, wrapped.__globals__, function_locals)
except:
print(f'[python_ta] Warning: could not evaluate invariant: {assertion}', file=sys.stderr)
else:
assert check,\
f'{wrapped.__name__} precondition "{assertion}" violated for arguments {function_locals}.'
def parse_assertions(docstring: str, parse_token: str = 'Precondition') -> List[str]:
    """Return a list of preconditions/representation invariants parsed from the given docstring.

    Uses parse_token to determine what to look for. parse_token defaults to Precondition.
    Lines are stripped of surrounding whitespace and parse_token is matched
    case-insensitively. The first line in either supported form is used.
    Currently only supports two forms:

    1. A single line of the form "<parse_token>: <cond>"
    2. A group of lines starting with "<parse_token>s:", where each subsequent
       line is of the form "- <cond>". Each line is considered a separate condition.
       The lines can be separated by blank lines, but no other text.

    Returns an empty list if no line is in either form.

    >>> parse_assertions('Precondition: x > 0')
    ['x > 0']
    """
    lines = [line.strip() for line in docstring.split('\n')]
    single_prefix = (parse_token + ':').lower()
    group_prefix = (parse_token + 's:').lower()

    for index, line in enumerate(lines):
        lowered = line.lower()
        if lowered.startswith(single_prefix):
            return [line[len(single_prefix):].strip()]
        if lowered.startswith(group_prefix):
            assertions = []
            for following in lines[index + 1:]:
                if following.startswith('-'):
                    assertions.append(following[1:].strip())
                elif following != '':
                    # Any other non-blank text ends the group.
                    break
            return assertions
    return []