From 3da44fc88dbbff0bd2b12e5950a13d1f92129dbd Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 27 Feb 2015 12:16:30 -0800
Subject: [PATCH] fix __eq__ of Singleton

---
 python/pyspark/sql/types.py | 29 +++++++++++++++--------------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 617c60dca7893..3e7c1256f3595 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -43,7 +43,7 @@ def __hash__(self):
         return hash(str(self))
 
     def __eq__(self, other):
-        return isinstance(other, self.__class__) and self.jsonValue() == other.jsonValue()
+        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
 
     def __ne__(self, other):
         return not self.__eq__(other)
@@ -64,6 +64,8 @@ def json(self):
                           sort_keys=True)
 
 
+# This singleton pattern does not work with pickle, you will get
+# another object after pickle and unpickle
 class PrimitiveTypeSingleton(type):
     """Metaclass for PrimitiveType"""
 
@@ -82,10 +84,6 @@ class PrimitiveType(DataType):
 
     __metaclass__ = PrimitiveTypeSingleton
 
-    def __eq__(self, other):
-        # because they should be the same object
-        return self is other
-
 
 class NullType(PrimitiveType):
 
@@ -510,6 +508,9 @@ def __eq__(self, other):
 def _parse_datatype_json_string(json_string):
     """Parses the given data type JSON string.
 
+    >>> import pickle
+    >>> LongType() == pickle.loads(pickle.dumps(LongType()))
+    True
     >>> def check_datatype(datatype):
     ...     scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
     ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
@@ -899,20 +900,20 @@ def _parse_field_abstract(s):
     Parse a field in schema abstract
 
     >>> _parse_field_abstract("a")
-    StructField(a,None,true)
+    StructField(a,NullType,true)
     >>> _parse_field_abstract("b(c d)")
-    StructField(b,StructType(...c,None,true),StructField(d...
+    StructField(b,StructType(...c,NullType,true),StructField(d...
     >>> _parse_field_abstract("a[]")
-    StructField(a,ArrayType(None,true),true)
+    StructField(a,ArrayType(NullType,true),true)
    >>> _parse_field_abstract("a{[]}")
-    StructField(a,MapType(None,ArrayType(None,true),true),true)
+    StructField(a,MapType(NullType,ArrayType(NullType,true),true),true)
     """
     if set(_BRACKETS.keys()) & set(s):
         idx = min((s.index(c) for c in _BRACKETS if c in s))
         name = s[:idx]
         return StructField(name, _parse_schema_abstract(s[idx:]), True)
     else:
-        return StructField(s, None, True)
+        return StructField(s, NullType(), True)
 
 
 def _parse_schema_abstract(s):
@@ -926,11 +927,11 @@ def _parse_schema_abstract(s):
     >>> _parse_schema_abstract("c{} d{a b}")
     StructType...c,MapType...d,MapType...a...b...
     >>> _parse_schema_abstract("a b(t)").fields[1]
-    StructField(b,StructType(List(StructField(t,None,true))),true)
+    StructField(b,StructType(List(StructField(t,NullType,true))),true)
     """
     s = s.strip()
     if not s:
-        return
+        return NullType()
 
     elif s.startswith('('):
         return _parse_schema_abstract(s[1:-1])
@@ -939,7 +940,7 @@ def _parse_schema_abstract(s):
         return ArrayType(_parse_schema_abstract(s[1:-1]), True)
 
     elif s.startswith('{'):
-        return MapType(None, _parse_schema_abstract(s[1:-1]))
+        return MapType(NullType(), _parse_schema_abstract(s[1:-1]))
 
     parts = _split_schema_abstract(s)
     fields = [_parse_field_abstract(p) for p in parts]
@@ -959,7 +960,7 @@ def _infer_schema_type(obj, dataType):
     >>> _infer_schema_type(row, schema)
     StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
     """
-    if dataType is None:
+    if dataType is NullType():
         return _infer_type(obj)
 
     if not obj:
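
---

For illustration, here is a minimal standalone sketch of the bug this patch fixes. The `Singleton` and `LongType` names below are simplified stand-ins mirroring `PrimitiveTypeSingleton` and the pyspark types, not the module itself, and Python 3 metaclass syntax is used so the sketch runs as-is, whereas the patched module targets Python 2's `__metaclass__`:

    import pickle


    class Singleton(type):
        # Caches one instance per class, mirroring PrimitiveTypeSingleton.
        _instances = {}

        def __call__(cls):
            if cls not in cls._instances:
                cls._instances[cls] = super(Singleton, cls).__call__()
            return cls._instances[cls]


    class LongType(metaclass=Singleton):
        def __hash__(self):
            return hash(str(self))

        def __eq__(self, other):
            # The fix: compare state instead of identity, so equality
            # survives a pickle round trip.
            return (isinstance(other, self.__class__) and
                    self.__dict__ == other.__dict__)


    original = LongType()
    restored = pickle.loads(pickle.dumps(original))

    # pickle rebuilds the instance via cls.__new__, bypassing the
    # metaclass's __call__, so the cached singleton is NOT reused:
    assert restored is not original
    # The removed identity-based __eq__ ("return self is other") would
    # report False here; the structural comparison still reports equality:
    assert restored == original

This is why the identity-based `__eq__` on `PrimitiveType` is deleted and the new doctest checks `LongType() == pickle.loads(pickle.dumps(LongType()))`. Comparing `__dict__` in the base `DataType.__eq__` keeps unpickled singletons equal while still distinguishing parameterized types by their fields, and it avoids serializing both sides with `jsonValue()` on every comparison.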