diff --git a/gae/preprocessing.py b/gae/preprocessing.py index 9738da1..e81ca79 100644 --- a/gae/preprocessing.py +++ b/gae/preprocessing.py @@ -57,8 +57,7 @@ def mask_test_edges(adj): def ismember(a, b, tol=5): rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1) - return (np.all(np.any(rows_close, axis=-1), axis=-1) and - np.all(np.any(rows_close, axis=0), axis=0)) + return np.any(rows_close) test_edges_false = [] while len(test_edges_false) < len(test_edges): @@ -110,5 +109,3 @@ def ismember(a, b, tol=5): # NOTE: these edge lists only contain single direction of edge! return adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false - -