diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 5d7aeb664cadf..795ef0dbc4c47 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -17,7 +17,6 @@ import warnings import json -from array import array from itertools import imap from py4j.protocol import Py4JError @@ -25,7 +24,7 @@ from pyspark.rdd import RDD, _prepare_for_python_RDD from pyspark.serializers import AutoBatchedSerializer, PickleSerializer -from pyspark.sql.types import StringType, StructType, _verify_type, \ +from pyspark.sql.types import Row, StringType, StructType, _verify_type, \ _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter from pyspark.sql.dataframe import DataFrame @@ -620,93 +619,6 @@ def _get_hive_ctx(self): return self._jvm.HiveContext(self._jsc.sc()) -def _create_row(fields, values): - row = Row(*values) - row.__FIELDS__ = fields - return row - - -class Row(tuple): - - """ - A row in L{DataFrame}. The fields in it can be accessed like attributes. - - Row can be used to create a row object by using named arguments, - the fields will be sorted by names. - - >>> row = Row(name="Alice", age=11) - >>> row - Row(age=11, name='Alice') - >>> row.name, row.age - ('Alice', 11) - - Row also can be used to create another Row like class, then it - could be used to create Row objects, such as - - >>> Person = Row("name", "age") - >>> Person - - >>> Person("Alice", 11) - Row(name='Alice', age=11) - """ - - def __new__(self, *args, **kwargs): - if args and kwargs: - raise ValueError("Can not use both args " - "and kwargs to create Row") - if args: - # create row class or objects - return tuple.__new__(self, args) - - elif kwargs: - # create row objects - names = sorted(kwargs.keys()) - values = tuple(kwargs[n] for n in names) - row = tuple.__new__(self, values) - row.__FIELDS__ = names - return row - - else: - raise ValueError("No args or kwargs") - - def asDict(self): - """ - Return as an dict - """ - if not hasattr(self, "__FIELDS__"): - raise TypeError("Cannot convert a Row class into dict") - return dict(zip(self.__FIELDS__, self)) - - # let obect acs like class - def __call__(self, *args): - """create new Row object""" - return _create_row(self, args) - - def __getattr__(self, item): - if item.startswith("__"): - raise AttributeError(item) - try: - # it will be slow when it has many fields, - # but this will not be used in normal cases - idx = self.__FIELDS__.index(item) - return self[idx] - except IndexError: - raise AttributeError(item) - - def __reduce__(self): - if hasattr(self, "__FIELDS__"): - return (_create_row, (self.__FIELDS__, tuple(self))) - else: - return tuple.__reduce__(self) - - def __repr__(self): - if hasattr(self, "__FIELDS__"): - return "Row(%s)" % ", ".join("%s=%r" % (k, v) - for k, v in zip(self.__FIELDS__, self)) - else: - return "" % ", ".join(self) - - def _test(): import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index aec99017fbdc1..5c3b7377c33b5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1025,10 +1025,12 @@ def cast(self, dataType): ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(dataType.json()) jc = self._jc.cast(jdt) + else: + raise TypeError("unexpected type: %s" % type(dataType)) return Column(jc) def __repr__(self): - return 'Column<%s>' % self._jdf.toString().encode('utf8') + return 'Column<%s>' % self._jc.toString().encode('utf8') def _test(): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 83899ad4b1b12..2720439416682 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -24,6 +24,7 @@ import pydoc import shutil import tempfile +import pickle import py4j @@ -88,6 +89,14 @@ def __eq__(self, other): other.x == self.x and other.y == self.y +class DataTypeTests(unittest.TestCase): + # regression test for SPARK-6055 + def test_data_type_eq(self): + lt = LongType() + lt2 = pickle.loads(pickle.dumps(LongType())) + self.assertEquals(lt, lt2) + + class SQLTests(ReusedPySparkTestCase): @classmethod diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 0f5dc2be6dab8..31a861e1feb46 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -21,6 +21,7 @@ import warnings import json import re +import weakref from array import array from operator import itemgetter @@ -42,8 +43,7 @@ def __hash__(self): return hash(str(self)) def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.__dict__ == other.__dict__) + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other): return not self.__eq__(other) @@ -64,6 +64,8 @@ def json(self): sort_keys=True) +# This singleton pattern does not work with pickle, you will get +# another object after pickle and unpickle class PrimitiveTypeSingleton(type): """Metaclass for PrimitiveType""" @@ -82,10 +84,6 @@ class PrimitiveType(DataType): __metaclass__ = PrimitiveTypeSingleton - def __eq__(self, other): - # because they should be the same object - return self is other - class NullType(PrimitiveType): @@ -242,11 +240,12 @@ def __init__(self, elementType, containsNull=True): :param elementType: the data type of elements. :param containsNull: indicates whether the list contains None values. - >>> ArrayType(StringType) == ArrayType(StringType, True) + >>> ArrayType(StringType()) == ArrayType(StringType(), True) True - >>> ArrayType(StringType, False) == ArrayType(StringType) + >>> ArrayType(StringType(), False) == ArrayType(StringType()) False """ + assert isinstance(elementType, DataType), "elementType should be DataType" self.elementType = elementType self.containsNull = containsNull @@ -292,13 +291,15 @@ def __init__(self, keyType, valueType, valueContainsNull=True): :param valueContainsNull: indicates whether values contains null values. - >>> (MapType(StringType, IntegerType) - ... == MapType(StringType, IntegerType, True)) + >>> (MapType(StringType(), IntegerType()) + ... == MapType(StringType(), IntegerType(), True)) True - >>> (MapType(StringType, IntegerType, False) - ... == MapType(StringType, FloatType)) + >>> (MapType(StringType(), IntegerType(), False) + ... == MapType(StringType(), FloatType())) False """ + assert isinstance(keyType, DataType), "keyType should be DataType" + assert isinstance(valueType, DataType), "valueType should be DataType" self.keyType = keyType self.valueType = valueType self.valueContainsNull = valueContainsNull @@ -348,13 +349,14 @@ def __init__(self, name, dataType, nullable=True, metadata=None): to simple type that can be serialized to JSON automatically - >>> (StructField("f1", StringType, True) - ... == StructField("f1", StringType, True)) + >>> (StructField("f1", StringType(), True) + ... == StructField("f1", StringType(), True)) True - >>> (StructField("f1", StringType, True) - ... == StructField("f2", StringType, True)) + >>> (StructField("f1", StringType(), True) + ... == StructField("f2", StringType(), True)) False """ + assert isinstance(dataType, DataType), "dataType should be DataType" self.name = name self.dataType = dataType self.nullable = nullable @@ -393,16 +395,17 @@ class StructType(DataType): def __init__(self, fields): """Creates a StructType - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True)]) + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) >>> struct1 == struct2 True - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True), - ... [StructField("f2", IntegerType, False)]]) + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True), + ... StructField("f2", IntegerType(), False)]) >>> struct1 == struct2 False """ + assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType" self.fields = fields def simpleString(self): @@ -505,20 +508,24 @@ def __eq__(self, other): def _parse_datatype_json_string(json_string): """Parses the given data type JSON string. + >>> import pickle >>> def check_datatype(datatype): + ... pickled = pickle.loads(pickle.dumps(datatype)) + ... assert datatype == pickled ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) - ... return datatype == python_datatype - >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) - True + ... assert datatype == python_datatype + >>> for cls in _all_primitive_types.values(): + ... check_datatype(cls()) + >>> # Simple ArrayType. >>> simple_arraytype = ArrayType(StringType(), True) >>> check_datatype(simple_arraytype) - True + >>> # Simple MapType. >>> simple_maptype = MapType(StringType(), LongType()) >>> check_datatype(simple_maptype) - True + >>> # Simple StructType. >>> simple_structtype = StructType([ ... StructField("a", DecimalType(), False), @@ -526,7 +533,7 @@ def _parse_datatype_json_string(json_string): ... StructField("c", LongType(), True), ... StructField("d", BinaryType(), False)]) >>> check_datatype(simple_structtype) - True + >>> # Complex StructType. >>> complex_structtype = StructType([ ... StructField("simpleArray", simple_arraytype, True), @@ -535,22 +542,20 @@ def _parse_datatype_json_string(json_string): ... StructField("boolean", BooleanType(), False), ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) >>> check_datatype(complex_structtype) - True + >>> # Complex ArrayType. >>> complex_arraytype = ArrayType(complex_structtype, True) >>> check_datatype(complex_arraytype) - True + >>> # Complex MapType. >>> complex_maptype = MapType(complex_structtype, ... complex_arraytype, False) >>> check_datatype(complex_maptype) - True + >>> check_datatype(ExamplePointUDT()) - True >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), ... StructField("point", ExamplePointUDT(), False)]) >>> check_datatype(structtype_with_udt) - True """ return _parse_datatype_json_value(json.loads(json_string)) @@ -786,8 +791,24 @@ def _merge_type(a, b): return a +def _need_converter(dataType): + if isinstance(dataType, StructType): + return True + elif isinstance(dataType, ArrayType): + return _need_converter(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_converter(dataType.keyType) or _need_converter(dataType.valueType) + elif isinstance(dataType, NullType): + return True + else: + return False + + def _create_converter(dataType): """Create an converter to drop the names of fields in obj """ + if not _need_converter(dataType): + return lambda x: x + if isinstance(dataType, ArrayType): conv = _create_converter(dataType.elementType) return lambda row: map(conv, row) @@ -806,13 +827,17 @@ def _create_converter(dataType): # dataType must be StructType names = [f.name for f in dataType.fields] converters = [_create_converter(f.dataType) for f in dataType.fields] + convert_fields = any(_need_converter(f.dataType) for f in dataType.fields) def convert_struct(obj): if obj is None: return if isinstance(obj, (tuple, list)): - return tuple(conv(v) for v, conv in zip(obj, converters)) + if convert_fields: + return tuple(conv(v) for v, conv in zip(obj, converters)) + else: + return tuple(obj) if isinstance(obj, dict): d = obj @@ -821,7 +846,10 @@ def convert_struct(obj): else: raise ValueError("Unexpected obj: %s" % obj) - return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) + if convert_fields: + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) + else: + return tuple([d.get(name) for name in names]) return convert_struct @@ -871,20 +899,20 @@ def _parse_field_abstract(s): Parse a field in schema abstract >>> _parse_field_abstract("a") - StructField(a,None,true) + StructField(a,NullType,true) >>> _parse_field_abstract("b(c d)") - StructField(b,StructType(...c,None,true),StructField(d... + StructField(b,StructType(...c,NullType,true),StructField(d... >>> _parse_field_abstract("a[]") - StructField(a,ArrayType(None,true),true) + StructField(a,ArrayType(NullType,true),true) >>> _parse_field_abstract("a{[]}") - StructField(a,MapType(None,ArrayType(None,true),true),true) + StructField(a,MapType(NullType,ArrayType(NullType,true),true),true) """ if set(_BRACKETS.keys()) & set(s): idx = min((s.index(c) for c in _BRACKETS if c in s)) name = s[:idx] return StructField(name, _parse_schema_abstract(s[idx:]), True) else: - return StructField(s, None, True) + return StructField(s, NullType(), True) def _parse_schema_abstract(s): @@ -898,11 +926,11 @@ def _parse_schema_abstract(s): >>> _parse_schema_abstract("c{} d{a b}") StructType...c,MapType...d,MapType...a...b... >>> _parse_schema_abstract("a b(t)").fields[1] - StructField(b,StructType(List(StructField(t,None,true))),true) + StructField(b,StructType(List(StructField(t,NullType,true))),true) """ s = s.strip() if not s: - return + return NullType() elif s.startswith('('): return _parse_schema_abstract(s[1:-1]) @@ -911,7 +939,7 @@ def _parse_schema_abstract(s): return ArrayType(_parse_schema_abstract(s[1:-1]), True) elif s.startswith('{'): - return MapType(None, _parse_schema_abstract(s[1:-1])) + return MapType(NullType(), _parse_schema_abstract(s[1:-1])) parts = _split_schema_abstract(s) fields = [_parse_field_abstract(p) for p in parts] @@ -931,7 +959,7 @@ def _infer_schema_type(obj, dataType): >>> _infer_schema_type(row, schema) StructType...a,ArrayType...b,MapType(StringType,...c,LongType... """ - if dataType is None: + if dataType is NullType(): return _infer_type(obj) if not obj: @@ -1037,8 +1065,7 @@ def _verify_type(obj, dataType): for v, f in zip(obj, dataType.fields): _verify_type(v, f.dataType) - -_cached_cls = {} +_cached_cls = weakref.WeakValueDictionary() def _restore_object(dataType, obj): @@ -1233,8 +1260,7 @@ def __new__(self, *args, **kwargs): elif kwargs: # create row objects names = sorted(kwargs.keys()) - values = tuple(kwargs[n] for n in names) - row = tuple.__new__(self, values) + row = tuple.__new__(self, [kwargs[n] for n in names]) row.__FIELDS__ = names return row