Skip to content

Commit

Permalink
Merge pull request #33 from jcollado/fix-sort-ascending-param
Browse files Browse the repository at this point in the history
Fix sortBy/sortByKey ascending param
  • Loading branch information
wdm0006 committed Mar 22, 2018
2 parents b6f3a38 + fde2854 commit d66c304
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 2 deletions.
4 changes: 2 additions & 2 deletions dummy_spark/rdd.py
def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
    """Return a new RDD with the elements sorted by ``keyfunc``.

    ``sorted()`` sorts ascending by default, so ``reverse`` must be the
    negation of ``ascending`` (the earlier ``reverse=ascending`` form
    inverted the sort order).

    :param ascending: sort in increasing order when True, decreasing when False
    :param numPartitions: ignored by this dummy implementation — TODO confirm
    :param keyfunc: function extracting the comparison key from each element
    :return: a new :class:`RDD` containing the sorted elements
    """
    data = sorted(self._jrdd, key=keyfunc, reverse=not ascending)
    return RDD(data, self.ctx)

def sortBy(self, keyfunc, ascending=True, numPartitions=None):
    """Return a new RDD with the elements sorted by ``keyfunc``.

    ``sorted()`` sorts ascending by default, so ``reverse`` must be the
    negation of ``ascending`` (the earlier ``reverse=ascending`` form
    inverted the sort order).

    :param keyfunc: function extracting the comparison key from each element
    :param ascending: sort in increasing order when True, decreasing when False
    :param numPartitions: ignored by this dummy implementation — TODO confirm
    :return: a new :class:`RDD` containing the sorted elements
    """
    data = sorted(self._jrdd, key=keyfunc, reverse=not ascending)
    return RDD(data, self.ctx)

def glom(self):
Expand Down
62 changes: 62 additions & 0 deletions tests/unit/test_rdd.py
Expand Up @@ -323,6 +323,68 @@ def merge_combiners(a, b):
[('A', [1, 6]), ('B', [2, 3]), ('C', [4, 5])],
)

def test_sortByKey_ascending(self):
    """sortByKey(ascending=True) yields pairs in increasing key order."""
    sc = SparkContext(master='', conf=SparkConf())
    unsorted_pairs = [('e', 5), ('d', 4), ('c', 3), ('b', 2), ('a', 1)]
    result = sc.parallelize(unsorted_pairs).sortByKey(ascending=True)
    expected = [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)]
    self.assertListEqual(result.collect(), expected)

def test_sortByKey_descending(self):
    """sortByKey(ascending=False) yields pairs in decreasing key order."""
    sc = SparkContext(master='', conf=SparkConf())
    sorted_pairs = [('a', 1), ('b', 2), ('c', 3), ('d', 4), ('e', 5)]
    result = sc.parallelize(sorted_pairs).sortByKey(ascending=False)
    expected = [('e', 5), ('d', 4), ('c', 3), ('b', 2), ('a', 1)]
    self.assertListEqual(result.collect(), expected)

def test_sortBy_ascending(self):
    """sortBy(ascending=True) with the identity key sorts values upward."""
    sc = SparkContext(master='', conf=SparkConf())
    sorted_rdd = sc.parallelize([5, 4, 3, 2, 1]).sortBy(lambda v: v, ascending=True)
    self.assertListEqual(sorted_rdd.collect(), [1, 2, 3, 4, 5])

def test_sortBy_descending(self):
    """sortBy(ascending=False) with the identity key sorts values downward."""
    sc = SparkContext(master='', conf=SparkConf())
    sorted_rdd = sc.parallelize([1, 2, 3, 4, 5]).sortBy(lambda v: v, ascending=False)
    self.assertListEqual(sorted_rdd.collect(), [5, 4, 3, 2, 1])

def test_subtractByKey(self):
sc = SparkContext(master='', conf=SparkConf())
rdd1 = sc.parallelize([('A', 1), ('B', 2), ('C', 3)])
Expand Down

0 comments on commit d66c304

Please sign in to comment.