Skip to content

Commit

Permalink
Small fixes to SliceDict and SliceDataset (#871)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminBossan committed Jul 15, 2022
1 parent b889de9 commit 764a1a0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 7 deletions.
15 changes: 8 additions & 7 deletions skorch/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, **kwargs):
else:
self._len = lengths[0]

super(SliceDict, self).__init__(**kwargs)
super().__init__(**kwargs)

def __len__(self):
return self._len
Expand All @@ -66,8 +66,8 @@ def __getitem__(self, sl):
# lengths and shapes.
raise ValueError("SliceDict cannot be indexed by integers.")
if isinstance(sl, str):
return super(SliceDict, self).__getitem__(sl)
return SliceDict(**{k: v[sl] for k, v in self.items()})
return super().__getitem__(sl)
return type(self)(**{k: v[sl] for k, v in self.items()})

def __setitem__(self, key, value):
if not isinstance(key, str):
Expand All @@ -82,14 +82,14 @@ def __setitem__(self, key, value):
"Cannot set array with shape[0] != {}"
"".format(self._len))

super(SliceDict, self).__setitem__(key, value)
super().__setitem__(key, value)

def update(self, kwargs):
for key, value in kwargs.items():
self.__setitem__(key, value)

def __repr__(self):
out = super(SliceDict, self).__repr__()
out = super().__repr__()
return "SliceDict(**{})".format(out)

@property
Expand Down Expand Up @@ -234,8 +234,9 @@ def __getitem__(self, i):
Xi = self._select_item(Xn)
return self.transform(Xi)

cls = type(self)
if isinstance(i, slice):
return SliceDataset(self.dataset, idx=self.idx, indices=self.indices_[i])
return cls(self.dataset, idx=self.idx, indices=self.indices_[i])

if isinstance(i, np.ndarray):
if i.ndim != 1:
Expand All @@ -245,7 +246,7 @@ def __getitem__(self, i):
if i.dtype == np.bool:
i = np.flatnonzero(i)

return SliceDataset(self.dataset, idx=self.idx, indices=self.indices_[i])
return cls(self.dataset, idx=self.idx, indices=self.indices_[i])

def __array__(self, dtype=None):
# This method is invoked when calling np.asarray(X)
Expand Down
18 changes: 18 additions & 0 deletions skorch/tests/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,14 @@ def test_equals_different_keys(self, sldict_cls):
)
assert sldict0 != sldict1

def test_subclass_getitem_returns_instance_of_itself(self, sldict_cls):
class MySliceDict(sldict_cls):
pass

sldict = MySliceDict(a=np.zeros(3))
sliced = sldict[:2]
assert isinstance(sliced, MySliceDict)


class TestSliceDataset:
@pytest.fixture(scope='class', params=['numpy', 'torch'])
Expand Down Expand Up @@ -521,6 +529,16 @@ def test_slicedataset_asarray(self, slds_cls, custom_ds, n, dtype):
expected_dtype = torch_to_numpy_dtype_dict.get(expected.dtype, expected.dtype)
assert array.dtype == expected_dtype

@pytest.mark.parametrize('sl', [slice(0, 2), np.arange(3)])
def test_subclass_getitem_returns_instance_of_itself(self, slds_cls, custom_ds, sl):
class MySliceDataset(slds_cls):
pass

slds = MySliceDataset(custom_ds, idx=0)
sliced = slds[sl]

assert isinstance(sliced, MySliceDataset)


class TestPredefinedSplit():

Expand Down

0 comments on commit 764a1a0

Please sign in to comment.