Skip to content

Commit

Permalink
fix memory leak in sql
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Feb 27, 2015
1 parent fbc4694 commit d9ae973
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 98 deletions.
90 changes: 1 addition & 89 deletions python/pyspark/sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@

import warnings
import json
from array import array
from itertools import imap

from py4j.protocol import Py4JError
from py4j.java_collections import MapConverter

from pyspark.rdd import RDD, _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
from pyspark.sql.types import StringType, StructType, _verify_type, \
from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
from pyspark.sql.dataframe import DataFrame

Expand Down Expand Up @@ -620,93 +619,6 @@ def _get_hive_ctx(self):
return self._jvm.HiveContext(self._jsc.sc())


def _create_row(fields, values):
    """Rebuild a Row from its field names and values (unpickling helper)."""
    restored = Row(*values)
    restored.__FIELDS__ = fields
    return restored


class Row(tuple):

    """
    A row in L{DataFrame}. The fields in it can be accessed like attributes.

    Row can be used to create a row object by using named arguments,
    the fields will be sorted by names.

    >>> row = Row(name="Alice", age=11)
    >>> row
    Row(age=11, name='Alice')
    >>> row.name, row.age
    ('Alice', 11)

    Row also can be used to create another Row like class, then it
    could be used to create Row objects, such as

    >>> Person = Row("name", "age")
    >>> Person
    <Row(name, age)>
    >>> Person("Alice", 11)
    Row(name='Alice', age=11)
    """

    def __new__(self, *args, **kwargs):
        if args and kwargs:
            raise ValueError("Can not use both args "
                             "and kwargs to create Row")
        if args:
            # create row class or objects
            return tuple.__new__(self, args)

        elif kwargs:
            # create row objects; fields are sorted by name so that the
            # layout is deterministic regardless of kwargs order
            names = sorted(kwargs.keys())
            row = tuple.__new__(self, [kwargs[n] for n in names])
            row.__FIELDS__ = names
            return row

        else:
            raise ValueError("No args or kwargs")

    def asDict(self):
        """
        Return as a dict mapping field names to values.

        Raises TypeError when called on a Row *class* (created with
        positional args), which has no values bound to its fields.
        """
        if not hasattr(self, "__FIELDS__"):
            raise TypeError("Cannot convert a Row class into dict")
        return dict(zip(self.__FIELDS__, self))

    # let object act like class
    def __call__(self, *args):
        """create new Row object"""
        return _create_row(self, args)

    def __getattr__(self, item):
        if item.startswith("__"):
            raise AttributeError(item)
        try:
            # it will be slow when it has many fields,
            # but this will not be used in normal cases
            idx = self.__FIELDS__.index(item)
            return self[idx]
        except (IndexError, ValueError):
            # list.index raises ValueError (not IndexError) when the
            # field is missing; translate both into AttributeError so
            # hasattr() and attribute protocol behave correctly
            raise AttributeError(item)

    def __reduce__(self):
        # pickle via _create_row so field names survive round-trips
        if hasattr(self, "__FIELDS__"):
            return (_create_row, (self.__FIELDS__, tuple(self)))
        else:
            return tuple.__reduce__(self)

    def __repr__(self):
        if hasattr(self, "__FIELDS__"):
            return "Row(%s)" % ", ".join("%s=%r" % (k, v)
                                         for k, v in zip(self.__FIELDS__, self))
        else:
            return "<Row(%s)>" % ", ".join(self)


def _test():
import doctest
from pyspark.context import SparkContext
Expand Down
4 changes: 3 additions & 1 deletion python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,10 +1025,12 @@ def cast(self, dataType):
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
jdt = ssql_ctx.parseDataType(dataType.json())
jc = self._jc.cast(jdt)
else:
raise TypeError("unexpected type: %s" % type(dataType))
return Column(jc)

def __repr__(self):
return 'Column<%s>' % self._jdf.toString().encode('utf8')
return 'Column<%s>' % self._jc.toString().encode('utf8')


def _test():
Expand Down
37 changes: 29 additions & 8 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import warnings
import json
import re
import weakref
from array import array
from operator import itemgetter

Expand All @@ -42,8 +43,7 @@ def __hash__(self):
return hash(str(self))

def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.__dict__ == other.__dict__)
return isinstance(other, self.__class__) and self.jsonValue() == other.jsonValue()

def __ne__(self, other):
return not self.__eq__(other)
Expand Down Expand Up @@ -786,8 +786,24 @@ def _merge_type(a, b):
return a


def _need_converter(dataType):
    """Return True if values of ``dataType`` need conversion.

    Struct and null types always need conversion; container types need
    it only when their element/key/value types do.
    """
    if isinstance(dataType, (StructType, NullType)):
        return True
    if isinstance(dataType, ArrayType):
        return _need_converter(dataType.elementType)
    if isinstance(dataType, MapType):
        return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
    return False


def _create_converter(dataType):
"""Create an converter to drop the names of fields in obj """
if not _need_converter(dataType):
return lambda x: x

if isinstance(dataType, ArrayType):
conv = _create_converter(dataType.elementType)
return lambda row: map(conv, row)
Expand All @@ -806,13 +822,17 @@ def _create_converter(dataType):
# dataType must be StructType
names = [f.name for f in dataType.fields]
converters = [_create_converter(f.dataType) for f in dataType.fields]
convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)

def convert_struct(obj):
if obj is None:
return

if isinstance(obj, (tuple, list)):
return tuple(conv(v) for v, conv in zip(obj, converters))
if convert_fields:
return tuple(conv(v) for v, conv in zip(obj, converters))
else:
return tuple(obj)

if isinstance(obj, dict):
d = obj
Expand All @@ -821,7 +841,10 @@ def convert_struct(obj):
else:
raise ValueError("Unexpected obj: %s" % obj)

return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
if convert_fields:
return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
else:
return tuple([d.get(name) for name in names])

return convert_struct

Expand Down Expand Up @@ -1037,8 +1060,7 @@ def _verify_type(obj, dataType):
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType)


_cached_cls = {}
_cached_cls = weakref.WeakValueDictionary()


def _restore_object(dataType, obj):
Expand Down Expand Up @@ -1233,8 +1255,7 @@ def __new__(self, *args, **kwargs):
elif kwargs:
# create row objects
names = sorted(kwargs.keys())
values = tuple(kwargs[n] for n in names)
row = tuple.__new__(self, values)
row = tuple.__new__(self, [kwargs[n] for n in names])
row.__FIELDS__ = names
return row

Expand Down

0 comments on commit d9ae973

Please sign in to comment.