-
Notifications
You must be signed in to change notification settings - Fork 44
/
fields.py
73 lines (55 loc) · 2.29 KB
/
fields.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from pysparkling.sql.types import StructField
from pysparkling.sql.expressions.expressions import Expression
from pysparkling.sql.utils import AnalysisException
class FieldAsExpression(Expression):
    """Expression that stands for a single schema field.

    Evaluating it against a row returns the row item located at the
    position of the wrapped field in the given schema.
    """

    def __init__(self, field):
        """Store ``field``, the schema field this expression refers to."""
        super(FieldAsExpression, self).__init__()
        self.field = field

    def eval(self, row, schema):
        """Return the item of ``row`` at the wrapped field's position in ``schema``."""
        position = find_position_in_schema(schema, self.field)
        return row[position]

    def __str__(self):
        """Return the name of the wrapped field."""
        wrapped = self.field
        return wrapped.name

    def output_fields(self, schema):
        """Return a one-element list holding the wrapped field.

        ``schema`` is accepted but not used.
        """
        wrapped = self.field
        return [wrapped]
def find_position_in_schema(schema, expr):
    """Return the index in ``schema.fields`` of the column designated by ``expr``.

    ``expr`` may be:

    * a ``str``: fields are matched by name;
    * a ``FieldAsExpression``: the lookup is repeated on its ``field``;
    * a ``StructField`` carrying an ``id`` attribute: fields are matched by id.

    The collected positions are handed to ``get_checked_matches``, which
    returns the single match or raises when there is none or several.

    Raises ``NotImplementedError`` for a ``StructField`` without an ``id``
    and for any other expression type.
    """
    if isinstance(expr, str):
        # Plain column name: compare it with every field name.
        positions = {
            index
            for index, candidate in enumerate(schema.fields)
            if expr == candidate.name
        }
        return get_checked_matches(positions, expr, schema, False)

    if isinstance(expr, FieldAsExpression):
        # Unwrap and resolve the underlying field.
        return find_position_in_schema(schema, expr.field)

    if isinstance(expr, StructField) and hasattr(expr, "id"):
        # Field with an id: match on ids and show ids in error messages.
        label = format_field(expr, show_id=True)
        positions = {
            index
            for index, candidate in enumerate(schema.fields)
            if expr.id == candidate.id
        }
        return get_checked_matches(positions, label, schema, True)

    # Anything reaching this point cannot be resolved.
    if isinstance(expr, StructField):
        unsupported = "Unbound field {0}".format(expr.name)
    else:
        unsupported = "Expression type '{0}'".format(type(expr))
    raise NotImplementedError(
        "{0} is not supported. "
        "As a user you should not see this error, feel free to report a bug at "
        "https://github.com/svenkreiss/pysparkling/issues".format(unsupported)
    )
def get_checked_matches(matches, field_name, schema, show_id):
    """Return the only element of ``matches``, raising if it is not unique.

    Parameters:
        matches: sized collection of candidate positions (e.g. a set).
        field_name: column label used in error messages.
        schema: schema whose fields are listed when nothing matched.
        show_id: forwarded to ``format_schema`` to control how the
            schema fields are rendered in the "not found" message.

    Returns:
        The single element contained in ``matches``.

    Raises:
        AnalysisException: if ``matches`` is empty, or if it holds more
            than one element (ambiguous reference).
    """
    if not matches:
        raise AnalysisException(
            "Unable to find the column '{0}' among {1}".format(
                field_name,
                format_schema(schema, show_id)
            )
        )
    if len(matches) > 1:
        raise AnalysisException(
            "Reference '{0}' is ambiguous, found {1} columns matching it.".format(
                field_name,
                len(matches)
            )
        )
    # Exactly one element remains. Read it without pop() so the caller's
    # collection is left untouched and immutable collections are accepted.
    return next(iter(matches))
def format_schema(schema, show_id):
    """Return the list of labels obtained by formatting each field of ``schema``.

    Every entry of ``schema.fields`` is rendered with ``format_field``,
    passing ``show_id`` through unchanged.
    """
    columns = schema.fields
    return [format_field(column, show_id=show_id) for column in columns]
def format_field(field, show_id):
    """Return the display label of ``field``.

    The label is ``field.name`` alone when ``show_id`` is falsy, and
    ``"<name>#<id>"`` built from ``field.name`` and ``field.id`` otherwise.
    """
    if not show_id:
        return field.name
    return "{0}#{1}".format(field.name, field.id)