Skip to content

Commit

Permalink
Allow comparing Pairs to tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
rafguns committed May 2, 2017
1 parent ae3df05 commit a7121f8
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 8 deletions.
31 changes: 26 additions & 5 deletions linkpred/evaluation/scoresheet.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,40 @@ def __init__(self, *args):
"__init__() takes 1 or 2 arguments in addition to self")
# For link prediction, a and b are two different nodes
assert a != b, "Predicted link (%s, %s) is a self-loop!" % (a, b)
self.elements = (a, b) if a > b else (b, a)
self.elements = self._sorted_tuple((a, b))

@staticmethod
def _sorted_tuple(t):
a, b = t
return (a, b) if a > b else (b, a)

def __eq__(self, other):
return self.elements == other.elements
try:
return self.elements == other.elements
except AttributeError:
return self.elements == self._sorted_tuple(other)

def __ne__(self, other):
return self.elements != other.elements
return not self == other

def __lt__(self, other):
return self.elements < other.elements
try:
return self.elements < other.elements
except AttributeError:

return self.elements < self._sorted_tuple(other)

def __gt__(self, other):
return self.elements > other.elements
try:
return self.elements > other.elements
except AttributeError:
return self.elements > self._sorted_tuple(other)

def __le__(self, other):
return self < other or self == other

def __ge__(self, other):
return self > other or self == other

def __getitem__(self, idx):
return self.elements[idx]
Expand Down
5 changes: 4 additions & 1 deletion linkpred/linkpred.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ def for_comparison(G, exclude=None):
In practice this means we return it as a set of Pairs.
"""
exclude = set(Pair(u, v) for u, v in exclude) if exclude else set()
if not exclude:
return set(G.edges_iter())

exclude = set(Pair(u, v) for u, v in exclude)
return set(Pair(u, v) for u, v in G.edges_iter()) - exclude


Expand Down
5 changes: 3 additions & 2 deletions tests/test_linkpred.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ def test_for_comparison():
from linkpred.evaluation import Pair

G = nx.path_graph(10)
expected = set(Pair(x) for x in [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5),
(5, 6), (6, 7), (7, 8), (8, 9)])
expected = {(0, 1), (1, 2), (2, 3), (3, 4), (4, 5),
(5, 6), (6, 7), (7, 8), (8, 9)}
assert_equal(for_comparison(G), expected)

to_delete = [Pair(2, 3), Pair(8, 9)]
expected = {Pair(t) for t in expected}
expected = expected.difference(to_delete)
assert_equal(for_comparison(G, exclude=to_delete), expected)

Expand Down
1 change: 1 addition & 0 deletions tests/test_scoresheet.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def test_pair():
t = ('a', 'b')
pair = Pair(t)
assert_equal(pair, Pair(*t))
assert_equal(pair, t)
assert_equal(pair, Pair('b', 'a'))
assert_equal(pair, eval(repr(pair)))
assert_equal(u(pair), "b - a")
Expand Down

0 comments on commit a7121f8

Please sign in to comment.