Skip to content

Commit 19c8cf7

Browse files
authored
Fix the building reference index when ref_index is None in DISE (#264)
* Fix the return type when calculating medoid in `get_initial_selection` * Add test case for ref_index is None * Remove redundant checking of ref_index
1 parent 198e004 commit 19c8cf7

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

selector/methods/distance.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -507,8 +507,6 @@ def __init__(
507507
"""
508508
self.r0 = r0
509509
self.r = r0
510-
if ref_index is not None and ref_index < 0:
511-
raise ValueError(f"ref_index must be a non-negative integer, got {ref_index}.")
512510
self.ref_index = ref_index
513511
self.tol = tol
514512
self.n_iter = n_iter
@@ -677,7 +675,7 @@ def get_initial_selection(x=None, x_dist=None, ref_index=None, fun_dist=None):
677675
if x_dist is None:
678676
x_dist = fun_dist(x)
679677
# calculate the medoid center
680-
initial_selections = [np.argmin(np.sum(x_dist, axis=0))]
678+
initial_selections = [int(np.argmin(np.sum(x_dist, axis=0)))]
681679

682680
# the length of the distance matrix is the number of samples
683681
if x_dist is not None:

selector/methods/tests/test_distance.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,16 @@ def test_directed_sphere_same_number_of_pts():
301301
assert_equal(collector.r, 1)
302302

303303

304+
def test_directed_sphere_same_number_of_pts_None():
305+
"""Test DirectSphereExclusion with `size` = number of points in dataset with the ref_index None."""
306+
# None as the reference point
307+
x = np.array([[0, 0], [0, 1], [0, 2], [0, 3], [0, 4]])
308+
collector = DISE(r0=1, tol=0, ref_index=None)
309+
selected = collector.select(x, size=3)
310+
assert_equal(selected, [2, 0, 4])
311+
assert_equal(collector.r, 1)
312+
313+
304314
def test_directed_sphere_exclusion_select_more_number_of_pts():
305315
"""Test DirectSphereExclusion on points on the line with `size` < number of points in dataset."""
306316
# (0,0) as the reference point

0 commit comments

Comments
 (0)