Skip to content

Commit

Permalink
Merge pull request #36 from pganssle/fix_groupby_problem
Browse files Browse the repository at this point in the history
Fixed issue with groupByKey and reduceByKey throwing errors with numpy arrays
  • Loading branch information
svenkreiss committed Mar 1, 2016
2 parents 6f0cacd + 7a4dda7 commit d9da014
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
5 changes: 2 additions & 3 deletions pysparkling/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,9 +669,8 @@ def groupBy(self, f, numPartitions=None):
>>> from pysparkling import Context
>>> my_rdd = Context().parallelize([4, 7, 2])
>>> my_rdd.groupBy(lambda x: x % 2).collect()
>>> my_rdd.groupBy(lambda x: x % 2).mapValues(sorted).collect()
[(0, [2, 4]), (1, [7])]
"""

return self.keyBy(f).groupByKey(numPartitions)
Expand All @@ -691,7 +690,7 @@ def groupByKey(self, numPartitions=None):

return self.context.parallelize((
(k, [gg[1] for gg in g]) for k, g in itertools.groupby(
sorted(self.collect()),
sorted(self.collect(), key=itemgetter(0)),
lambda e: e[0],
)
), numPartitions)
Expand Down
63 changes: 62 additions & 1 deletion tests/test_rdd.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pysparkling import Context
from operator import add

import unittest


class RDDTest(unittest.TestCase):
""" Tests for the resilient distributed databases """

Expand Down Expand Up @@ -160,6 +160,67 @@ def test_sampleByKey(self):
self.assertTrue(max(sample["a"]) <= 999 and min(sample["a"]) >= 0)
self.assertTrue(max(sample["b"]) <= 999 and min(sample["b"]) >= 0)

def test_groupByKey(self):
# This will fail if the values of the RDD need to be compared
class IncomparableValue(object):
def __init__(self, value):
self.value = value

def __eq__(self, other):
return self.value == other.value

def __lt__(self, other):
raise NotImplementedError("This object cannot be compared")

keys = (0, 1, 2, 0, 1, 2)
r = [IncomparableValue(i) for i in range(len(keys))]

k_rdd = self.context.parallelize(zip(keys, r))
actual_group = k_rdd.groupByKey().collect()

expected_group = ((0, r[0::3]),
(1, r[1::3]),
(2, r[2::3]))

grouped_dict = {k: v for k, v in actual_group}

for k, v in expected_group:
self.assertIn(k, grouped_dict)

for vv in v:
self.assertIn(vv, grouped_dict[k])

def test_reduceByKey(self):
# This will fail if the values of the RDD need to be compared
class IncomparableValueAddable(object):
def __init__(self, value):
self.value = value

def __eq__(self, other):
return self.value == other.value

def __add__(self, other):
return self.__class__(self.value + other.value)

def __lt__(self, other):
raise NotImplementedError("This object cannot be compared")

keys = (0, 1, 2, 0, 1, 2)
r = [IncomparableValueAddable(i) for i in range(len(keys))]

k_rdd = self.context.parallelize(zip(keys, r))
actual_group = k_rdd.reduceByKey(add).collect()

expected_group = ((0, IncomparableValueAddable(3)),
(1, IncomparableValueAddable(5)),
(2, IncomparableValueAddable(7)))

grouped_dict = {k: v for k, v in actual_group}

# Keep this order-agnostic
for k, v in expected_group:
self.assertEqual(grouped_dict[k], v)


if __name__ == "__main__":
unittest.main()

0 comments on commit d9da014

Please sign in to comment.