
Commit

fix __eq__ of Singleton
Davies Liu committed Feb 27, 2015
1 parent 534ac90 commit 3da44fc
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions python/pyspark/sql/types.py
@@ -43,7 +43,7 @@ def __hash__(self):
         return hash(str(self))
 
     def __eq__(self, other):
-        return isinstance(other, self.__class__) and self.jsonValue() == other.jsonValue()
+        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
 
     def __ne__(self, other):
         return not self.__eq__(other)
@@ -64,6 +64,8 @@ def json(self):
                           sort_keys=True)
 
 
+# This singleton pattern does not work with pickle, you will get
+# another object after pickle and unpickle
 class PrimitiveTypeSingleton(type):
 
     """Metaclass for PrimitiveType"""
@@ -82,10 +84,6 @@ class PrimitiveType(DataType):
 
     __metaclass__ = PrimitiveTypeSingleton
 
-    def __eq__(self, other):
-        # because they should be the same object
-        return self is other
-
 
 class NullType(PrimitiveType):
 
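The new comment above spells out the motivation: pickling bypasses the metaclass cache, so a deserialized instance is a different object, and the identity-based __eq__ removed here would report it as unequal. Below is a minimal standalone sketch of that failure mode and of why comparing __dict__ (the approach this commit switches DataType.__eq__ to) still works. The CachedType and Example names are invented for illustration and are not part of pyspark.

import pickle


class CachedType(type):
    """Caches one instance per class, same idea as PrimitiveTypeSingleton."""
    _instances = {}

    def __call__(cls):
        if cls not in CachedType._instances:
            CachedType._instances[cls] = super(CachedType, cls).__call__()
        return CachedType._instances[cls]


# Create the class through the metaclass so the sketch runs on Python 2 and 3.
Example = CachedType('Example', (object,), {})

original = Example()
restored = pickle.loads(pickle.dumps(original))  # unpickling bypasses CachedType.__call__

print(original is restored)                      # False: a second instance exists now
print(original.__dict__ == restored.__dict__)    # True: identical state, so a
                                                 # __dict__-based __eq__ still matches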
@@ -510,6 +508,9 @@ def __eq__(self, other):
 
 def _parse_datatype_json_string(json_string):
     """Parses the given data type JSON string.
+    >>> import pickle
+    >>> LongType() == pickle.loads(pickle.dumps(LongType()))
+    True
     >>> def check_datatype(datatype):
     ...     scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
     ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
@@ -899,20 +900,20 @@ def _parse_field_abstract(s):
     Parse a field in schema abstract
 
     >>> _parse_field_abstract("a")
-    StructField(a,None,true)
+    StructField(a,NullType,true)
     >>> _parse_field_abstract("b(c d)")
-    StructField(b,StructType(...c,None,true),StructField(d...
+    StructField(b,StructType(...c,NullType,true),StructField(d...
     >>> _parse_field_abstract("a[]")
-    StructField(a,ArrayType(None,true),true)
+    StructField(a,ArrayType(NullType,true),true)
     >>> _parse_field_abstract("a{[]}")
-    StructField(a,MapType(None,ArrayType(None,true),true),true)
+    StructField(a,MapType(NullType,ArrayType(NullType,true),true),true)
     """
     if set(_BRACKETS.keys()) & set(s):
         idx = min((s.index(c) for c in _BRACKETS if c in s))
         name = s[:idx]
         return StructField(name, _parse_schema_abstract(s[idx:]), True)
     else:
-        return StructField(s, None, True)
+        return StructField(s, NullType(), True)
 
 
 def _parse_schema_abstract(s):
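A side note on the placeholder change in this hunk and the ones below (my reading, not stated in the commit): replacing None with a real NullType() instance keeps unknown field types inside the DataType hierarchy, so the __dict__-based __eq__ and the repr shown in the doctests apply to them like any other type. A small illustrative snippet, assuming a pyspark build that includes this change:

from pyspark.sql.types import StructField, NullType

field = StructField("a", NullType(), True)   # the placeholder is now a real DataType
print(field.dataType == NullType())          # True, via the __dict__-based __eq__
print(repr(field))                           # StructField(a,NullType,true), as in the doctests above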
@@ -926,11 +927,11 @@ def _parse_schema_abstract(s):
     >>> _parse_schema_abstract("c{} d{a b}")
     StructType...c,MapType...d,MapType...a...b...
     >>> _parse_schema_abstract("a b(t)").fields[1]
-    StructField(b,StructType(List(StructField(t,None,true))),true)
+    StructField(b,StructType(List(StructField(t,NullType,true))),true)
     """
     s = s.strip()
     if not s:
-        return
+        return NullType()
 
     elif s.startswith('('):
         return _parse_schema_abstract(s[1:-1])
@@ -939,7 +940,7 @@
         return ArrayType(_parse_schema_abstract(s[1:-1]), True)
 
     elif s.startswith('{'):
-        return MapType(None, _parse_schema_abstract(s[1:-1]))
+        return MapType(NullType(), _parse_schema_abstract(s[1:-1]))
 
     parts = _split_schema_abstract(s)
     fields = [_parse_field_abstract(p) for p in parts]
@@ -959,7 +960,7 @@ def _infer_schema_type(obj, dataType):
     >>> _infer_schema_type(row, schema)
     StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
     """
-    if dataType is None:
+    if dataType is NullType():
         return _infer_type(obj)
 
     if not obj:
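One detail worth calling out in the last hunk (my observation, not from the commit message): dataType is NullType() is an identity test, and it is only safe because NullType goes through PrimitiveTypeSingleton, so every in-process call returns the same cached object, including the placeholders produced by _parse_schema_abstract above. A short sketch, assuming a pyspark build with this change:

from pyspark.sql.types import NullType

a = NullType()
b = NullType()
print(a is b)   # True: the metaclass returns the cached instance, which is what
                # makes the `dataType is NullType()` check above reliable
print(a == b)   # True as well, through the __dict__-based __eq__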
