-
Notifications
You must be signed in to change notification settings - Fork 44
/
schema_utils.py
72 lines (57 loc) · 2.69 KB
/
schema_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from functools import reduce
from pysparkling.sql.internal_utils.joins import INNER_JOIN, CROSS_JOIN, LEFT_JOIN,\
LEFT_ANTI_JOIN, LEFT_SEMI_JOIN, RIGHT_JOIN, FULL_JOIN
from pysparkling.sql.types import _infer_schema, _has_nulltype, _merge_type, \
StructType, StructField, _get_null_fields
from pysparkling.sql.utils import IllegalArgumentException
def infer_schema_from_rdd(rdd):
    """
    Infer a schema from a sample of the records held by ``rdd``.

    A sample of 200 records is drawn without replacement via
    ``rdd.takeSample`` and handed to :func:`infer_schema_from_list`.

    :param rdd: RDD of Row or tuple records
    :return: :class:`pysparkling.sql.types.StructType`
    """
    sampled_records = rdd.takeSample(withReplacement=False, num=200)
    return infer_schema_from_list(sampled_records)
def infer_schema_from_list(data, names=None):
    """
    Derive a single schema from a collection of Row or tuple records.

    A schema is inferred for every record and the per-record schemas are
    folded together with ``_merge_type``.

    :param data: non-empty list of Row or tuple
    :param names: optional list of column names forwarded to ``_infer_schema``
    :return: :class:`pysparkling.sql.types.StructType`
    :raises ValueError: if ``data`` is empty, or if some field's type is still
        undetermined (null type) after merging
    :raises NotImplementedError: if the first record is a ``dict``
    """
    if not data:
        raise ValueError("can not infer schema from empty dataset")

    # Only the first record is checked: dict rows are rejected up front.
    if isinstance(data[0], dict):
        raise NotImplementedError(
            "Inferring schema from dict is deprecated in Spark "
            "and not implemented in pysparkling. "
            "Please use .sql.Row instead"
        )

    per_record_schemas = (_infer_schema(record, names) for record in data)
    merged_schema = reduce(_merge_type, per_record_schemas)

    if not _has_nulltype(merged_schema):
        return merged_schema

    unresolved = "', '".join(_get_null_fields(merged_schema))
    raise ValueError(
        "Type(s) of the following field(s) cannot be determined after inferring: '{0}'".format(
            unresolved
        )
    )
def merge_schemas(left_schema, right_schema, how, on=None):
    """
    Compute the output schema of a join between two schemas.

    The resulting fields are ordered as: join-key fields, then the remaining
    left-side fields, then the remaining right-side fields.

    :param left_schema: schema of the left side (exposes ``.fields`` and
        iterates over its fields)
    :param right_schema: schema of the right side, same requirements
    :param how: join type; one of the constants imported from
        ``pysparkling.sql.internal_utils.joins``
    :param on: column names used as join keys; ``None`` means no keys
    :return: :class:`pysparkling.sql.types.StructType`
    :raises IllegalArgumentException: if ``how`` is not a supported join type
    """
    join_columns = [] if on is None else on
    left_keys, right_keys = get_on_fields(left_schema, right_schema, join_columns)

    left_rest = [fld for fld in left_schema.fields if fld not in left_keys]
    right_rest = [fld for fld in right_schema.fields if fld not in right_keys]

    left_driven_joins = (INNER_JOIN, CROSS_JOIN, LEFT_JOIN, LEFT_ANTI_JOIN, LEFT_SEMI_JOIN)
    if how in left_driven_joins:
        # Key columns take their definition from the left side.
        key_fields = left_keys
    elif how == RIGHT_JOIN:
        key_fields = right_keys
    elif how == FULL_JOIN:
        # Either side may lack a matching row, so key columns must accept nulls.
        key_fields = [
            StructField(key.name, key.dataType, nullable=True)
            for key in left_keys
        ]
    else:
        raise IllegalArgumentException("Invalid how argument in join: {0}".format(how))

    # NOTE(review): right-side fields are appended for every join type,
    # including left_semi/left_anti, whose Spark output has left columns only;
    # confirm whether callers trim them.
    return StructType(fields=key_fields + left_rest + right_rest)
def get_on_fields(left_schema, right_schema, on):
    """
    Resolve join-key column names against both sides of a join.

    :param left_schema: iterable of fields (each with a ``name``) for the left side
    :param right_schema: iterable of fields (each with a ``name``) for the right side
    :param on: a single column name, or an iterable of column names
    :return: ``(left_on_fields, right_on_fields)``: two lists holding, in
        ``on`` order, the first field on each side whose name matches
    :raises ValueError: if a name in ``on`` is absent from either schema
    """
    # A bare string is one column name, not a sequence of one-letter names.
    # Materialize other iterables once: both sides are resolved from the same
    # names, and a generator would otherwise be exhausted by the left pass.
    names = [on] if isinstance(on, str) else list(on)
    left_on_fields = [_find_field_by_name(left_schema, name, "left") for name in names]
    right_on_fields = [_find_field_by_name(right_schema, name, "right") for name in names]
    return left_on_fields, right_on_fields


def _find_field_by_name(schema, name, side):
    """Return the first field of ``schema`` named ``name``; raise ValueError if absent."""
    for field in schema:
        if field.name == name:
            return field
    # Raise a descriptive error instead of letting a bare StopIteration escape.
    raise ValueError(
        "Join column '{0}' cannot be resolved on the {1} side of the join".format(name, side)
    )
def get_schema_from_cols(cols, current_schema):
    """
    Build a schema from the fields each column resolves to.

    Every column in ``cols`` is asked, in order, for its fields within
    ``current_schema`` via ``find_fields_in_schema``. All returned fields are
    concatenated into a new StructType.

    :param cols: iterable of column objects exposing ``find_fields_in_schema``
    :param current_schema: schema the columns are resolved against
    :return: :class:`pysparkling.sql.types.StructType`
    """
    resolved_fields = []
    for column in cols:
        resolved_fields.extend(column.find_fields_in_schema(current_schema))
    return StructType(fields=resolved_fields)