Skip to content

Commit

Permalink
Merge pull request #9 from tools4origins/feat/ExpressionTypes
Browse files Browse the repository at this point in the history
Feat/expression types
  • Loading branch information
tools4origins committed Oct 31, 2021
2 parents 2025eba + 507f4f7 commit 8b801ac
Show file tree
Hide file tree
Showing 16 changed files with 717 additions and 42 deletions.
24 changes: 16 additions & 8 deletions pysparkling/sql/column.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .expressions.expressions import Expression
from .expressions.expressions import Expression, RegisteredExpressions
from .expressions.fields import find_position_in_schema
from .expressions.literals import Literal
from .expressions.mappers import CaseWhen, StarOperator
Expand Down Expand Up @@ -630,7 +630,7 @@ def output_fields(self, schema):
return self.expr.output_fields(schema)
return [StructField(
name=self.col_name,
dataType=self.data_type,
dataType=self.data_type(schema),
nullable=self.is_nullable
)]

Expand All @@ -647,7 +647,6 @@ def mergeStats(self, row, schema):
def initialize(self, partition_index):
if isinstance(self.expr, Expression):
self.expr.recursive_initialize(partition_index)
return self

def with_pre_evaluation_schema(self, pre_evaluation_schema):
if isinstance(self.expr, Expression):
Expand Down Expand Up @@ -682,11 +681,20 @@ def __nonzero__(self):

__bool__ = __nonzero__

def data_type(self, schema):
    """Resolve the data type of this column against ``schema``.

    A column backed by an expression delegates to that expression's own
    ``data_type``; a column holding a plain name is looked up in ``schema``.

    :param schema: schema to resolve against; it is indexed by column name
        and its ``fields`` are listed in the error message.
    :return: the resolved data type.
    :raises AnalysisException: if the name is not found in ``schema``, or if
        ``self.expr`` is neither a string nor an expression.
    """
    expr = self.expr
    if isinstance(expr, (Expression, RegisteredExpressions)):
        return expr.data_type(schema)
    if not isinstance(expr, str):
        raise AnalysisException(
            f"cannot resolve '`{expr}`' type, expecting str or Expression but got {type(expr)};"
        )
    try:
        return schema[expr].dataType
    except KeyError as err:
        # Chain explicitly so the failed lookup stays attached as the cause
        raise AnalysisException(
            f"cannot resolve '`{expr}`' given input columns: {schema.fields};"
        ) from err

@property
def is_nullable(self):
Expand Down
122 changes: 115 additions & 7 deletions pysparkling/sql/expressions/arrays.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from ..column import Column
from ..types import ArrayType, BooleanType, IntegerType, MapType, NullType, StringType, StructType
from ..utils import AnalysisException
from .expressions import BinaryOperation, Expression, UnaryExpression

Expand All @@ -15,6 +17,9 @@ def eval(self, row, schema):
return None
return False

def data_type(self, schema):
    """Report a boolean result type; ``schema`` is not consulted."""
    result_type = BooleanType()
    return result_type


class ArrayContains(Expression):
pretty_name = "array_contains"
Expand All @@ -36,6 +41,9 @@ def args(self):
self.value
)

def data_type(self, schema):
    """Report a boolean result type; ``schema`` is not consulted."""
    result_type = BooleanType()
    return result_type


class ArrayColumn(Expression):
pretty_name = "array"
Expand All @@ -50,13 +58,23 @@ def eval(self, row, schema):
def args(self):
return self.columns

def data_type(self, schema):
    """Return the ``ArrayType`` produced by this ``array`` expression.

    The element type is taken from the first column. With no columns there
    is nothing to inspect, so the array is typed as holding ``NullType``.

    :param schema: schema used to resolve the first column's type.
    """
    if not self.columns:
        # Pass a NullType instance, like the instance returned by
        # data_type() in the non-empty branch, not the class itself.
        return ArrayType(elementType=NullType())
    return ArrayType(elementType=self.columns[0].data_type(schema))


class MapColumn(Expression):
pretty_name = "map"

def __init__(self, *columns):
    """Build a ``map`` expression from alternating key and value columns.

    Columns at even positions become ``self.keys`` and those at odd
    positions become ``self.values``.

    :raises AnalysisException: if an odd number of columns is given.
    """
    super().__init__(columns)
    self.columns = columns
    has_unpaired_column = len(columns) % 2 == 1
    if has_unpaired_column:
        message = (
            f"Cannot resolve '{self}' due to data type mismatch: "
            "map expects a positive even number of arguments."
        )
        raise AnalysisException(message)
    self.keys, self.values = columns[::2], columns[1::2]

Expand All @@ -69,6 +87,14 @@ def eval(self, row, schema):
def args(self):
return self.columns

def data_type(self, schema):
    """Return the ``MapType`` produced by this ``map`` expression.

    Key and value types are taken from the first key column and the first
    value column. Without any column the map is typed with ``NullType``
    keys and values.

    :param schema: schema used to resolve the column types.
    """
    if not self.columns:
        # Pass NullType instances, like the instances returned by
        # data_type() in the non-empty branch, not the class itself.
        return MapType(keyType=NullType(), valueType=NullType())
    return MapType(
        keyType=self.keys[0].data_type(schema),
        valueType=self.values[0].data_type(schema),
    )


class MapFromArraysColumn(Expression):
pretty_name = "map_from_arrays"
Expand All @@ -79,16 +105,28 @@ def __init__(self, keys, values):
self.values = values

def eval(self, row, schema):
    """Zip the evaluated key array and value array into a ``dict``.

    :param row: row handed to both sub-expressions.
    :param schema: schema handed to both sub-expressions.
    :return: the resulting mapping, or ``None`` when either array evaluates
        to ``None``.
    :raises AnalysisException: if the two arrays differ in length.
    """
    keys = self.keys.eval(row, schema)
    values = self.values.eval(row, schema)
    if keys is None or values is None:
        # A null array can be neither measured nor zipped: propagate the
        # null instead of failing with a TypeError in len().
        return None
    if len(keys) != len(values):
        raise AnalysisException(
            f"Error in '{self}': The key array and value array of MapData must have the same length."
        )
    return dict(zip(keys, values))

def args(self):
    """Return the expression arguments as a ``(keys, values)`` tuple."""
    keys, values = self.keys, self.values
    return keys, values

def data_type(self, schema):
    """Return the ``MapType`` produced by ``map_from_arrays``.

    :param schema: schema used to resolve the key and value types.
    """
    keys = self.keys
    if not isinstance(keys, Column) and not keys:
        # Pass NullType instances, like the instances returned by
        # data_type() in the other branch, not the class itself.
        return MapType(keyType=NullType(), valueType=NullType())
    # NOTE(review): keys/values are indexed with [0] before asking for the
    # type; presumably this yields the element of the array column -- confirm
    # that indexing these objects resolves to the array's element type.
    return MapType(
        keyType=keys[0].data_type(schema),
        valueType=self.values[0].data_type(schema),
    )


class Size(UnaryExpression):
pretty_name = "size"
Expand All @@ -101,26 +139,52 @@ def eval(self, row, schema):
f"{self.column} value should be an array or a map, got {type(column_value)}"
)

def data_type(self, schema):
    """Report an integer result type; ``schema`` is not consulted."""
    result_type = IntegerType()
    return result_type


class ArraySort(UnaryExpression):
    """``array_sort`` expression: the column's array, sorted."""

    pretty_name = "array_sort"

    def eval(self, row, schema):
        """Return a new sorted list built from the evaluated column value."""
        values = self.column.eval(row, schema)
        return sorted(values)

    def data_type(self, schema):
        """Return the wrapped column's own data type."""
        column = self.column
        return column.data_type(schema)


class ArrayMin(UnaryExpression):
    """``array_min`` expression: smallest element of an array column."""

    pretty_name = "array_min"

    def eval(self, row, schema):
        """Return the minimum of the evaluated array.

        :raises AnalysisException: if the column is not array-typed.
        """
        column_type = self.column.data_type(schema)
        column_value = self.column.eval(row, schema)
        # The resolved type is an instance: test it with isinstance, since
        # comparing an instance with the ArrayType class itself never matches
        # and would reject every input.
        if not isinstance(column_type, ArrayType):
            raise AnalysisException(
                f"Cannot resolve '{self}' due to data type mismatch: argument 1 requires array type, "
                f"however, '{column_value}' is of {column_type} type."
            )
        return min(column_value)

    def data_type(self, schema):
        """Return the element type of the wrapped array column."""
        # elementType is an attribute holding the type, not something to call
        return self.column.data_type(schema).elementType


class ArrayMax(UnaryExpression):
    """``array_max`` expression: largest element of an array column."""

    pretty_name = "array_max"

    def eval(self, row, schema):
        """Return the maximum of the evaluated array.

        :raises AnalysisException: if the column is not array-typed.
        """
        column_type = self.column.data_type(schema)
        column_value = self.column.eval(row, schema)
        # The resolved type is an instance: test it with isinstance, since
        # comparing an instance with the ArrayType class itself never matches
        # and would reject every input.
        if not isinstance(column_type, ArrayType):
            raise AnalysisException(
                f"Cannot resolve '{self}' due to data type mismatch: argument 1 requires array type, "
                f"however, '{column_value}' is of {column_type} type."
            )
        return max(column_value)

    def data_type(self, schema):
        """Return the element type of the wrapped array column."""
        # elementType is an attribute holding the type, not something to call
        return self.column.data_type(schema).elementType


class Slice(Expression):
Expand All @@ -142,6 +206,9 @@ def args(self):
self.length
)

def data_type(self, schema):
    """Return the data type of the sliced column, unchanged."""
    source_column = self.column
    return source_column.data_type(schema)


class ArrayRepeat(Expression):
pretty_name = "array_repeat"
Expand All @@ -161,9 +228,12 @@ def args(self):
self.count
)

def data_type(self, schema):
    """Return an ``ArrayType`` whose elements have the repeated column's type."""
    element_type = self.col.data_type(schema)
    return ArrayType(element_type)


class Sequence(Expression):
pretty_name = "array_join"
pretty_name = "sequence"

def __init__(self, start, stop, step):
super().__init__(start, stop, step)
Expand Down Expand Up @@ -201,6 +271,9 @@ def args(self):
self.step
)

def data_type(self, schema):
    """Return an ``ArrayType`` whose elements have the type of ``start``."""
    element_type = self.start.data_type(schema)
    return ArrayType(element_type)


class ArrayJoin(Expression):
pretty_name = "array_join"
Expand Down Expand Up @@ -231,6 +304,9 @@ def args(self):
self.nullReplacement
)

def data_type(self, schema):
    """Report a string result type; ``schema`` is not consulted."""
    result_type = StringType()
    return result_type


class SortArray(Expression):
pretty_name = "sort_array"
Expand All @@ -249,6 +325,9 @@ def args(self):
self.asc
)

def data_type(self, schema):
    """Return the data type of the sorted column, unchanged."""
    source_column = self.col
    return source_column.data_type(schema)


class ArraysZip(Expression):
pretty_name = "arrays_zip"
Expand All @@ -259,7 +338,7 @@ def __init__(self, columns):

def eval(self, row, schema):
return [
list(combination)
dict(enumerate(combination))
for combination in zip(
*(c.eval(row, schema) for c in self.columns)
)
Expand All @@ -268,6 +347,11 @@ def eval(self, row, schema):
def args(self):
return self.columns

def data_type(self, schema):
    """Return an ``ArrayType`` of a struct built from every column's type."""
    column_types = [col.data_type(schema) for col in self.columns]
    # NOTE(review): StructType receives bare data types here rather than
    # named fields -- confirm that its constructor accepts this.
    return ArrayType(StructType(column_types))


class Flatten(UnaryExpression):
pretty_name = "flatten"
Expand All @@ -279,6 +363,9 @@ def eval(self, row, schema):
for value in array
]

def data_type(self, schema):
    """Return the ``elementType`` of the wrapped column's data type."""
    outer_type = self.column.data_type(schema)
    return outer_type.elementType


class ArrayPosition(Expression):
pretty_name = "array_position"
Expand All @@ -303,6 +390,9 @@ def args(self):
self.value
)

def data_type(self, schema):
    """Report an integer result type; ``schema`` is not consulted."""
    result_type = IntegerType()
    return result_type


class ElementAt(Expression):
pretty_name = "element_at"
Expand All @@ -324,6 +414,9 @@ def args(self):
self.extraction
)

def data_type(self, schema):
    """Return the type of the extracted element.

    For a map-typed column this is the map's value type; otherwise it is the
    ``elementType`` of the column's type.

    :param schema: schema used to resolve the column's type.
    """
    container_type = self.col.data_type(schema)
    if isinstance(container_type, MapType):
        # A map type exposes keyType/valueType and has no elementType
        return container_type.valueType
    return container_type.elementType


class ArrayRemove(Expression):
pretty_name = "array_remove"
Expand All @@ -343,34 +436,49 @@ def args(self):
self.element
)

def data_type(self, schema):
    """Return the data type of the source column, unchanged."""
    source_column = self.col
    return source_column.data_type(schema)


class ArrayDistinct(UnaryExpression):
    """``array_distinct`` expression: the array without duplicate elements."""

    pretty_name = "array_distinct"

    def eval(self, row, schema):
        """Return the distinct elements in order of first occurrence."""
        values = self.column.eval(row, schema)
        # dict keys de-duplicate like a set but keep insertion order, so the
        # result no longer depends on hash ordering.
        return list(dict.fromkeys(values))

    def data_type(self, schema):
        """Return the wrapped column's own data type."""
        return self.column.data_type(schema)


class ArrayIntersect(BinaryOperation):
    """``array_intersect`` expression: elements present in both arrays."""

    pretty_name = "array_intersect"

    def eval(self, row, schema):
        """Return the distinct common elements, ordered as in the first array."""
        left = self.arg1.eval(row, schema)
        right_members = set(self.arg2.eval(row, schema))
        # dict.fromkeys de-duplicates while keeping the first array's order,
        # so the result no longer depends on hash ordering.
        return [item for item in dict.fromkeys(left) if item in right_members]

    def data_type(self, schema):
        """Return the data type of the first argument."""
        return self.arg1.data_type(schema)


class ArrayUnion(BinaryOperation):
    """``array_union`` expression: elements present in either array."""

    pretty_name = "array_union"

    def eval(self, row, schema):
        """Return the distinct elements of both arrays, first array first."""
        left = self.arg1.eval(row, schema)
        right = self.arg2.eval(row, schema)
        # dict.fromkeys de-duplicates while keeping first-occurrence order,
        # so the result no longer depends on hash ordering.
        return list(dict.fromkeys([*left, *right]))

    def data_type(self, schema):
        """Return the data type of the first argument."""
        return self.arg1.data_type(schema)


class ArrayExcept(BinaryOperation):
    """``array_except`` expression: elements of the first array not in the second."""

    pretty_name = "array_except"

    def eval(self, row, schema):
        """Return the distinct remaining elements, ordered as in the first array."""
        left = self.arg1.eval(row, schema)
        excluded = set(self.arg2.eval(row, schema))
        # dict.fromkeys de-duplicates while keeping the first array's order,
        # so the result no longer depends on hash ordering.
        return [item for item in dict.fromkeys(left) if item not in excluded]

    def data_type(self, schema):
        """Return the data type of the first argument."""
        return self.arg1.data_type(schema)


__all__ = [
"ArraysZip", "ArrayRepeat", "Flatten", "ArrayMax", "ArrayMin", "SortArray", "Size",
Expand Down
4 changes: 4 additions & 0 deletions pysparkling/sql/expressions/csvs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ..internal_utils.options import Options
from ..internal_utils.readers.csvreader import csv_record_to_row, CSVReader
from ..internal_utils.readers.utils import guess_schema_from_strings
from ..types import StringType
from ..utils import AnalysisException
from .expressions import Expression

Expand Down Expand Up @@ -33,3 +34,6 @@ def eval(self, row, schema):

def args(self):
    """Return the expression arguments as a one-element tuple."""
    column = self.column
    return (column,)

def data_type(self, schema):
    """Report a string result type; ``schema`` is not consulted."""
    result_type = StringType()
    return result_type

0 comments on commit 8b801ac

Please sign in to comment.