-
Notifications
You must be signed in to change notification settings - Fork 126
/
pampy.py
194 lines (164 loc) · 6.83 KB
/
pampy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
try:
    from collections.abc import Iterable
except ImportError:  # Python < 3.3: the ABCs still lived directly in `collections`
    from collections import Iterable
from itertools import zip_longest
from typing import Tuple, List
from typing import Pattern as RegexPattern
from pampy.helpers import *
# Wildcard pattern: match_value accepts any value for it and captures that value.
# `_` and `ANY` are two names for the same object.
ANY = UnderscoreType()
_ = ANY
# Inside a list/tuple pattern, HEAD (allowed only at index 0) captures the first element.
HEAD = HeadType()
# Inside a list/tuple pattern, TAIL (allowed only as the last element) captures the
# remaining elements as one list. `REST` is an alias for the same object.
TAIL = TailType()
REST = TAIL
def run(action, var):
    """Execute the action selected by `match`.

    A non-callable action is returned unchanged. A callable action receives
    the captured values:

    * an iterable `var` is unpacked into positional arguments;
    * a BoxedArgs `var` is unwrapped with ``.get()`` and passed as one argument;
    * anything else is passed as a single argument.

    Raises:
        MatchError: when calling the action with the unpacked iterable raises
            TypeError (typically a lambda with the wrong number of parameters).
            NOTE(review): the try also covers the action's own body, so a
            TypeError raised *inside* the action is reported the same way.
    """
    if not callable(action):
        return action
    if isinstance(var, Iterable):
        try:
            return action(*var)
        except TypeError as err:
            raise MatchError(get_lambda_args_error_msg(action, var, err))
    if isinstance(var, BoxedArgs):
        return action(var.get())
    return action(var)
def match_value(pattern, value) -> Tuple[bool, List]:
    """Test a single `value` against a single `pattern`.

    Returns a ``(matched, extracted)`` pair; ``extracted`` holds the values
    captured by the pattern, which `match` later hands to the action.
    How the comparison works depends on the kind of pattern:

    * int / float / str / bool literal: identical type *and* equal value, so
      ``1`` matches neither ``1.0`` nor ``True``. Captures nothing.
    * ``None``: matches only ``None``. Captures nothing.
    * a class: ``isinstance(value, pattern)``. Captures the value.
    * list / tuple: element-wise, delegated to `match_iterable`.
    * dict: entry-wise, delegated to `match_dict`.
    * any other callable: called as ``pattern(value)``; it must return either
      a bool (the value is captured) or a ``(bool, list)`` tuple, which is
      returned unchanged.
    * compiled regex: ``pattern.search(value)``; captures the regex groups.
      Values ``search`` cannot scan (not str/bytes, or str vs bytes mismatch)
      simply do not match.
    * ``_``: matches anything and captures the value.
    * a dataclass instance: the value must have the same class, then the
      pattern's ``__dict__`` is matched as a dict pattern against the value's.

    The ``PaddedValue`` filler (inserted by `match_iterable`) never matches.

    Raises:
        MatchError: if a callable pattern returns something other than a bool
            or a ``(bool, list)`` tuple, or if HEAD / TAIL is used outside a
            list or tuple pattern.
    """
    if value is PaddedValue:
        return False, []
    elif isinstance(pattern, (int, float, str, bool)):
        # Compare types first: `==` is then between two values of the same
        # builtin type and yields a plain bool. Evaluating `pattern == value`
        # against arbitrary objects (e.g. arrays) can return a non-bool whose
        # truth value raises.
        return type(pattern) == type(value) and pattern == value, []
    elif pattern is None:
        return value is None, []
    elif isinstance(pattern, type):
        if isinstance(value, pattern):
            return True, [value]
    elif isinstance(pattern, (list, tuple)):
        return match_iterable(pattern, value)
    elif isinstance(pattern, dict):
        return match_dict(pattern, value)
    elif callable(pattern):
        return_value = pattern(value)
        if isinstance(return_value, bool):
            return return_value, [value]
        elif isinstance(return_value, tuple) and len(return_value) == 2 \
                and isinstance(return_value[0], bool) and isinstance(return_value[1], list):
            return return_value
        else:
            raise MatchError("Warning! pattern function %s is not returning a boolean "
                             "nor a tuple of (boolean, list), but instead %s" %
                             (pattern, return_value))
    elif isinstance(pattern, RegexPattern):
        try:
            rematch = pattern.search(value)
        except TypeError:
            # `search` only accepts str/bytes of the pattern's own kind. Any
            # other value is "no match", so later patterns still get a chance
            # instead of the whole `match` call crashing.
            return False, []
        if rematch is not None:
            return True, list(rematch.groups())
    elif pattern is _:
        return True, [value]
    elif pattern is HEAD or pattern is TAIL:
        raise MatchError("HEAD or TAIL should only be used inside an Iterable (list or tuple).")
    elif is_dataclass(pattern) and pattern.__class__ == value.__class__:
        return match_dict(pattern.__dict__, value.__dict__)
    return False, []
def match_dict(pattern, value) -> Tuple[bool, List]:
    """Match a dict `value` against a dict `pattern`.

    Every pattern entry must be paired with a distinct entry of `value` whose
    key matches the pattern key and whose value matches the pattern value
    (both via `match_value`). Entries of `value` not claimed by any pattern
    entry are ignored, so extra keys are allowed.

    Pairing is greedy: each pattern entry, in pattern order, takes the first
    still-unclaimed value entry (in `value` iteration order) that matches.
    There is no backtracking.

    Returns ``(True, captures)`` where captures are the key captures followed
    by the value captures of each pair, in pattern order. Returns
    ``(False, [])`` if either argument is not a dict or some pattern entry
    finds no partner.
    """
    if not (isinstance(value, dict) and isinstance(pattern, dict)):
        return False, []

    captures = []
    claimed_value_keys = set()
    for pattern_key, pattern_val in pattern.items():
        for value_key, value_val in value.items():
            if value_key in claimed_value_keys:
                continue
            key_ok, key_captures = match_value(pattern_key, value_key)
            if not key_ok:
                continue
            val_ok, val_captures = match_value(pattern_val, value_val)
            if not val_ok:
                continue
            captures += key_captures + val_captures
            claimed_value_keys.add(value_key)
            break
        else:
            # No unclaimed entry of `value` satisfied this pattern entry.
            return False, []
    return True, captures
def only_padded_values_follow(padded_pairs, i):
    """Return True if every pair after index `i` has PaddedValue as its pattern.

    `padded_pairs` is the ``(pattern, value)`` list that `match_iterable`
    builds with zip_longest; a PaddedValue pattern slot means the pattern ran
    out before the values did. `match_iterable` uses this to check that TAIL
    is the last real element of the pattern.
    """
    later_pairs = (padded_pairs[k] for k in range(i + 1, len(padded_pairs)))
    return all(pattern is PaddedValue for pattern, _value in later_pairs)
def match_iterable(patterns, values) -> Tuple[bool, List]:
    """Match `values` element-by-element against the iterable `patterns`.

    The two are aligned with zip_longest using PaddedValue as filler. When
    the pattern is longer than the values, the filler ends up in a value slot,
    and `match_value` rejects it. Surplus values (pattern shorter, no TAIL)
    are compared against PaddedValue as a pattern.
    NOTE(review): that relies on PaddedValue never matching as a pattern;
    confirm in pampy.helpers.

    Special pattern elements:

    * HEAD: only allowed at index 0; captures the first value and fails the
      match if there are no values at all.
    * TAIL: only allowed as the last pattern element; captures all remaining
      values (possibly none) as one list and ends matching.

    Returns ``(matched, extracted)`` with captures accumulated in order, or
    ``(False, [])`` if either argument is not iterable or an element fails.

    Raises:
        MatchError: if HEAD is not the first pattern element or TAIL is not
            the last.
    """
    if not isinstance(patterns, Iterable) or not isinstance(values, Iterable):
        return False, []

    total_extracted = []
    padded_pairs = list(zip_longest(patterns, values, fillvalue=PaddedValue))
    for i, (pattern, value) in enumerate(padded_pairs):
        if pattern is HEAD:
            if i != 0:
                raise MatchError("HEAD can only be in first position of a pattern.")
            if value is PaddedValue:
                # HEAD needs at least one value to capture.
                return False, []
            total_extracted.append(value)
        elif pattern is TAIL:
            if not only_padded_values_follow(padded_pairs, i):
                raise MatchError("TAIL must be in last position of the pattern.")
            # Everything from here on belongs to TAIL. Drop the filler that
            # zip_longest added if the values ran out first.
            tail = [v for (_p, v) in padded_pairs[i:] if v is not PaddedValue]
            total_extracted.append(tail)
            break
        else:
            matched, extracted = match_value(pattern, value)
            if not matched:
                return False, []
            total_extracted += extracted
    return True, total_extracted
def match(var, *args, default=NoDefault, strict=True):
    """
    Match `var` against a number of potential patterns.
    Example usage:
    ```
    match(x,
        3,              "this matches the number 3",
        int,            "matches any integer",
        (str, int),     lambda a, b: "a tuple (a, b) you can use in a function",
        [1, 2, _],      "any list of 3 elements that begins with [1, 2]",
        {'x': _},       "any dict with a key 'x' and any value associated",
        _,              "anything else"
    )
    ```
    :param var: The variable to test patterns against.
    :param args: Alternating patterns and actions. There must be an action for every pattern specified.
                 Patterns can take many forms, see README.md for examples.
                 Actions can be either a literal value or a callable which will be called with the arguments that were
                 matched in corresponding pattern. If a pattern captures nothing, the callable receives `var` itself.
    :param default: If `default` is specified then it will be returned if none of the patterns match.
                    If `default` is unspecified then a `MatchError` will be thrown instead.
    :param strict: If exactly `False` and no `default` is given, `False` is returned when none of the patterns
                   match, instead of raising `MatchError`.
    :return: The result of the action which corresponds to the first matching pattern.
    :raises MatchError: If `args` has an odd length (a pattern without an action), or if no pattern matches and
                        neither `default` nor `strict=False` was given.
    """
    if len(args) % 2 != 0:
        raise MatchError("Every guard must have an action.")
    if default is NoDefault and strict is False:
        # Non-strict mode: a failed match evaluates to False instead of raising.
        default = False

    for patt, action in pairwise(args):
        matched, extracted = match_value(patt, var)
        if matched:
            # Patterns that capture nothing (literals, None, ...) hand the
            # whole `var` to the action, boxed so `run` passes it as one
            # argument instead of unpacking it.
            lambda_args = extracted if len(extracted) > 0 else BoxedArgs(var)
            return run(action, lambda_args)

    if default is NoDefault:
        # A `_` pattern is accepted by match_value for any value, so reaching
        # this point means no catch-all was supplied. Raise unconditionally
        # rather than searching the patterns with `==`, which could silently
        # return None (permissive __eq__) or raise a foreign error (array-like
        # patterns) instead of MatchError.
        raise MatchError("'_' not provided. This case is not handled:\n%s" % str(var))
    return default
class MatchError(Exception):
    """Error raised by the matching functions.

    Raised when no pattern handles a value, when a pattern is missing its
    action, when HEAD/TAIL are misplaced or misused, or when an action or
    pattern function cannot be applied.
    """

    def __init__(self, msg):
        # Forward the single message to Exception so str(err) and err.args expose it.
        super(MatchError, self).__init__(msg)