from pysparkling.sql.casts import get_caster
from pysparkling.sql.types import StructField, DataType, \
    INTERNAL_TYPE_ORDER, python_to_spark_type
from pysparkling.sql.utils import AnalysisException


class Expression(object):
    def __init__(self, *children):
        self.children = children
        self.pre_evaluation_schema = None

    def eval(self, row, schema):
        raise NotImplementedError

    def __str__(self):
        raise NotImplementedError

    def __repr__(self):
        return self.__class__.__name__

    def output_fields(self, schema):
        return [StructField(
            name=str(self),
            dataType=self.data_type,
            nullable=self.is_nullable
        )]

    @property
    def data_type(self):
        # pylint: disable=W0511
        # todo: be more specific
        return DataType()

    @property
    def is_nullable(self):
        return True

    @property
    def may_output_multiple_cols(self):
        return False

    @property
    def may_output_multiple_rows(self):
        return False

    @property
    def is_an_aggregation(self):
        return False

    def merge(self, row, schema):
        pass

    def recursive_merge(self, row, schema):
        self.merge(row, schema)
        self.children_merge(self.children, row, schema)

    @staticmethod
    def children_merge(children, row, schema):
        for child in children:
            if isinstance(child, Expression):
                child.recursive_merge(row, schema)
            elif hasattr(child, "expr") and isinstance(child.expr, Expression):
                child.expr.recursive_merge(row, schema)
            elif isinstance(child, (list, set, tuple)):
                Expression.children_merge(child, row, schema)

    def mergeStats(self, other, schema):
        pass

    def recursive_merge_stats(self, other, schema):
        # Top level import would cause cyclic dependencies
        # pylint: disable=import-outside-toplevel
        from pysparkling.sql.expressions.operators import Alias
        if isinstance(other.expr, Alias):
            self.recursive_merge_stats(other.expr.expr, schema)
        else:
            self.mergeStats(other.expr, schema)
            self.children_merge_stats(self.children, other, schema)

    @staticmethod
    def children_merge_stats(children, other, schema):
        # Top level import would cause cyclic dependencies
        # pylint: disable=import-outside-toplevel
        from pysparkling.sql.column import Column
        for child in children:
            if isinstance(child, Expression):
                child.recursive_merge_stats(other, schema)
            elif isinstance(child, Column) and isinstance(child.expr, Expression):
                child.expr.recursive_merge_stats(other, schema)
            elif isinstance(child, (list, set, tuple)):
                Expression.children_merge_stats(child, other, schema)

    def recursive_initialize(self, partition_index):
        """
        This method passes data once to expressions that require it,
        e.g. non-deterministic expressions, so that their result is
        constant across several evaluations.
        """
        self.initialize(partition_index)
        self.children_initialize(self.children, partition_index)

    @staticmethod
    def children_initialize(children, partition_index):
        # Top level import would cause cyclic dependencies
        # pylint: disable=import-outside-toplevel
        from pysparkling.sql.column import Column
        for child in children:
            if isinstance(child, Expression):
                child.recursive_initialize(partition_index)
            elif isinstance(child, Column) and isinstance(child.expr, Expression):
                child.expr.recursive_initialize(partition_index)
            elif isinstance(child, (list, set, tuple)):
                Expression.children_initialize(child, partition_index)

    def initialize(self, partition_index):
        pass

    # Adds information about the schema that was defined in the step
    # prior to the evaluation
    def with_pre_evaluation_schema(self, schema):
        self.pre_evaluation_schema = schema

    def recursive_pre_evaluation_schema(self, schema):
        self.with_pre_evaluation_schema(schema)
        self.children_pre_evaluation_schema(self.children, schema)

    @staticmethod
    def children_pre_evaluation_schema(children, schema):
        # Top level import would cause cyclic dependencies
        # pylint: disable=import-outside-toplevel
        from pysparkling.sql.column import Column
        for child in children:
            if isinstance(child, Expression):
                child.recursive_pre_evaluation_schema(schema)
            elif isinstance(child, Column) and isinstance(child.expr, Expression):
                child.expr.recursive_pre_evaluation_schema(schema)
            elif isinstance(child, (list, set, tuple)):
                Expression.children_pre_evaluation_schema(child, schema)
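

# Illustrative sketch, not part of the original module: a hypothetical
# non-deterministic expression showing the subclass contract. eval() receives
# the (row, schema) pair, __str__() names the output column, and initialize()
# is invoked once per partition through recursive_initialize() so that results
# stay reproducible across evaluations. The class name and behavior here are
# assumptions for illustration only.
class _ExampleRand(Expression):
    def __init__(self, seed=0):
        super(_ExampleRand, self).__init__()
        self.seed = seed
        self.generator = None

    def initialize(self, partition_index):
        # Seed a dedicated generator per partition; called by
        # recursive_initialize() before any eval() on that partition.
        import random
        self.generator = random.Random(self.seed + partition_index)

    def eval(self, row, schema):
        # Non-deterministic: ignores the row and draws from the seeded generator.
        return self.generator.random()

    def __str__(self):
        return "rand({0})".format(self.seed)

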
class UnaryExpression(Expression):
    def __init__(self, column):
        super(UnaryExpression, self).__init__(column)
        self.column = column

    def eval(self, row, schema):
        raise NotImplementedError

    def __str__(self):
        raise NotImplementedError
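

# Illustrative sketch, not part of the original module: a hypothetical
# UnaryExpression that negates its input, showing how self.column is
# evaluated against the row and how None is conventionally propagated.
class _ExampleNegate(UnaryExpression):
    def eval(self, row, schema):
        value = self.column.eval(row, schema)
        # Null in, null out: keep the operation null-safe.
        return -value if value is not None else None

    def __str__(self):
        return "(- {0})".format(self.column)

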
class BinaryOperation(Expression):
    """
    Perform a binary operation but return None if any value is None
    """
    def __init__(self, arg1, arg2):
        super(BinaryOperation, self).__init__(arg1, arg2)
        self.arg1 = arg1
        self.arg2 = arg2

    def eval(self, row, schema):
        raise NotImplementedError

    def __str__(self):
        raise NotImplementedError


class TypeSafeBinaryOperation(BinaryOperation):
    """
    Perform a type- and null-safe binary operation using *comparison* type
    cast rules: it converts values if they are of different types, following
    PySpark rules:

    lit(datetime.date(2019, 1, 1)) == lit("2019-01-01") is True
    """
    def eval(self, row, schema):
        value_1 = self.arg1.eval(row, schema)
        value_2 = self.arg2.eval(row, schema)
        if value_1 is None or value_2 is None:
            return None

        type_1 = value_1.__class__
        type_2 = value_2.__class__
        if type_1 == type_2:
            return self.unsafe_operation(value_1, value_2)

        try:
            order_1 = INTERNAL_TYPE_ORDER.index(type_1)
            order_2 = INTERNAL_TYPE_ORDER.index(type_2)
        except ValueError as e:
            raise AnalysisException("Unable to process type: {0}".format(e))

        spark_type_1 = python_to_spark_type(type_1)
        spark_type_2 = python_to_spark_type(type_2)

        if order_1 > order_2:
            caster = get_caster(from_type=spark_type_2, to_type=spark_type_1, options={})
            value_2 = caster(value_2)
        elif order_1 < order_2:
            caster = get_caster(from_type=spark_type_1, to_type=spark_type_2, options={})
            value_1 = caster(value_1)

        return self.unsafe_operation(value_1, value_2)

    def __str__(self):
        raise NotImplementedError

    def unsafe_operation(self, value_1, value_2):
        raise NotImplementedError
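

# Illustrative sketch, not part of the original module: a hypothetical
# equality comparison built on TypeSafeBinaryOperation. Because eval() above
# casts the lower-priority type to the higher one before delegating here,
# comparing datetime.date(2019, 1, 1) with "2019-01-01" would yield True.
class _ExampleEqualTo(TypeSafeBinaryOperation):
    def unsafe_operation(self, value_1, value_2):
        # Both values arrive non-None and already cast to a common type.
        return value_1 == value_2

    def __str__(self):
        return "({0} = {1})".format(self.arg1, self.arg2)

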
class NullSafeBinaryOperation(BinaryOperation):
    """
    Perform a null-safe binary operation

    It does not convert values if they are of different types:
    lit(datetime.date(2019, 1, 1)) - lit("2019-01-01") raises an error
    """
    def eval(self, row, schema):
        value_1 = self.arg1.eval(row, schema)
        value_2 = self.arg2.eval(row, schema)
        if value_1 is None or value_2 is None:
            return None

        type_1 = value_1.__class__
        type_2 = value_2.__class__
        if type_1 == type_2 or (
                isinstance(value_1, (int, float)) and
                isinstance(value_2, (int, float))
        ):
            return self.unsafe_operation(value_1, value_2)

        raise AnalysisException(
            "Cannot resolve {0} due to data type mismatch: "
            "first value is {1}, second value is {2}.".format(self, type_1, type_2)
        )

    def __str__(self):
        raise NotImplementedError

    def unsafe_operation(self, value_1, value_2):
        raise NotImplementedError
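

# Illustrative sketch, not part of the original module: a hypothetical
# subtraction built on NullSafeBinaryOperation. eval() above only delegates
# here when both values share a type (or are both numeric), so subtracting a
# string from a date raises AnalysisException instead of silently casting.
class _ExampleMinus(NullSafeBinaryOperation):
    def unsafe_operation(self, value_1, value_2):
        return value_1 - value_2

    def __str__(self):
        return "({0} - {1})".format(self.arg1, self.arg2)

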
class NullSafeColumnOperation(Expression):
    def __init__(self, column, *args):
        super(NullSafeColumnOperation, self).__init__(column, *args)
        self.column = column

    def eval(self, row, schema):
        value = self.column.eval(row, schema)
        return self.unsafe_operation(value)

    def __str__(self):
        raise NotImplementedError

    def unsafe_operation(self, value):
        raise NotImplementedError
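

# Illustrative sketch, not part of the original module: a hypothetical
# absolute-value operation built on NullSafeColumnOperation. The base eval()
# passes the evaluated column value straight to unsafe_operation(), so the
# None check lives here.
class _ExampleAbs(NullSafeColumnOperation):
    def unsafe_operation(self, value):
        return abs(value) if value is not None else None

    def __str__(self):
        return "ABS({0})".format(self.column)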