Skip to content

Commit

Permalink
history: Fix regressions in __getitem__ (#776)
Browse files Browse the repository at this point in the history
* Index epochs first so IndexError is always raised for bad epoch
  indices, and so jagged batch counts don't raise a false IndexError.

* Only filter epochs with non-matching keys once, and do it after
  indexing epochs to avoid returning values from the wrong epoch(s).

* Never compare arbitrary values to empty lists or tuples - instead,
  generate _none if no batches matched the keys to exclude the epoch.

* Don't raise KeyError if all batches are requested but there are no
  epochs - an empty list is probably more useful.

* Add new history tests
  • Loading branch information
cebtenzzre committed May 26, 2021
1 parent 1d37ef3 commit 54796f1
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 61 deletions.
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed a few bugs in the `net.history` implementation (#776)

## [0.10.0] - 2021-03-23

### Added
Expand Down
102 changes: 42 additions & 60 deletions skorch/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,9 @@ def _not_none(items):
return all(item is not _none for item in items)


def _filter_none(items):
"""Filter special placeholder value, preserves sequence type."""
type_ = list if isinstance(items, list) else tuple
return type_(filter(_not_none, items))


def _getitem_list_list(items, keys):
def _getitem_list_list(items, keys, tuple_=False):
"""Ugly but efficient extraction of multiple values from a list of
items.
Keys are contained in a list.
"""
filtered = []
for item in items:
Expand All @@ -40,30 +31,14 @@ def _getitem_list_list(items, keys):
break
else: # no break
if row:
filtered.append(row)
filtered.append(tuple(row) if tuple_ else row)
if items and not filtered:
return _none
return filtered


def _getitem_list_tuple(items, keys):
"""Ugly but efficient extraction of multiple values from a list of
items.
Keys are contained in a tuple.
"""
filtered = []
for item in items:
row = ()
do_append = True
for key in keys:
try:
row += (item[key],)
except KeyError:
do_append = False
break
if row and do_append:
filtered.append(row)
return filtered
return _getitem_list_list(items, keys, tuple_=True)


def _getitem_list_str(items, key):
Expand All @@ -73,6 +48,8 @@ def _getitem_list_str(items, key):
filtered.append(item[key])
except KeyError:
continue
if items and not filtered:
return _none
return filtered


Expand Down Expand Up @@ -274,48 +251,53 @@ def __getitem__(self, i):
# i_e: index epoch, k_e: key epoch
# i_b: index batch, k_b: key batch
i_e, k_e, i_b, k_b = _unpack_index(i)
keyerror_msg = "Key '{}' was not found in history."
keyerror_msg = "Key {!r} was not found in history."

if i_b is not None and k_e != 'batches':
raise KeyError("History indexing beyond the 2nd level is "
"only possible if key 'batches' is used, "
"found key '{}'.".format(k_e))
"found key {!r}.".format(k_e))

items = self.to_list()

# extract the epochs
# handles: history[i_e]
if i_e is not None:
items = items[i_e]
if isinstance(i_e, int):
items = [items]

# extract indices of batches
# handles: history[..., k_e, i_b]
if i_b is not None:
items = [row[k_e][i_b] for row in items]

# extract keys of batches
# handles: history[..., k_e, i_b][k_b]
if items and (k_b is not None):
extract = _get_getitem_method(items[0], k_b)
items = [extract(batches, k_b) for batches in items]
items = [b for b in items if b not in (_none, [], ())]
if not _filter_none(items):
# all rows contained _none or were empty
raise KeyError(keyerror_msg.format(k_b))

# extract epoch-level values, but only if not already done
# extract keys of epochs or batches
# handles: history[..., k_e]
if (k_e is not None) and (i_b is None):
if not items:
raise KeyError(keyerror_msg.format(k_e))

extract = _get_getitem_method(items[0], k_e)
items = [extract(item, k_e) for item in items]
if not _filter_none(items):
raise KeyError(keyerror_msg.format(k_e))

# extract the epochs
# handles: history[i_b, ..., ..., ...]
if i_e is not None:
items = items[i_e]
if isinstance(i_e, slice):
items = _filter_none(items)
if items is _none:
raise KeyError(keyerror_msg.format(k_e))
# handles: history[..., ..., ..., k_b]
if k_e is not None and (i_b is None or k_b is not None):
key = k_e if k_b is None else k_b

if items:
extract = _get_getitem_method(items[0], key)
items = [extract(item, key) for item in items]

# filter out epochs with missing keys
items = list(filter(_not_none, items))

if not items and not (k_e == 'batches' and i_b is None):
# none of the epochs matched
raise KeyError(keyerror_msg.format(key))

if (
isinstance(i_b, slice)
and k_b is not None
and not any(batches for batches in items)
):
# none of the batches matched
raise KeyError(keyerror_msg.format(key))

if isinstance(i_e, int):
items, = items

return items
108 changes: 107 additions & 1 deletion skorch/tests/test_history.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for history.py."""

import numpy as np
import pytest

from skorch.history import History
Expand Down Expand Up @@ -189,7 +190,6 @@ def test_history_with_invalid_epoch_key(self, history):
expected = "Key 'not-batches' was not found in history."
assert msg == expected


def test_history_too_many_indices(self, history):
with pytest.raises(KeyError) as exc:
# pylint: disable=pointless-statement
Expand Down Expand Up @@ -218,3 +218,109 @@ def test_history_save_load_cycle_file_path(self, history, tmpdir):
new_history = History.from_file(str(history_f))

assert history == new_history

@pytest.mark.parametrize('type_', [list, tuple])
def test_history_multiple_keys(self, history, type_):
dur_loss = history[-1, type_(['duration', 'total_loss'])]
# pylint: disable=unidiomatic-typecheck
assert type(dur_loss) is type_ and len(dur_loss) == 2

loss_loss = history[-1, 'batches', -1, type_(['loss', 'loss'])]
# pylint: disable=unidiomatic-typecheck
assert type(loss_loss) is type_ and len(loss_loss) == 2

def test_history_key_in_other_epoch(self):
h = History()
for has_valid in (True, False):
h.new_epoch()
h.new_batch()
h.record_batch('train_loss', 1)
if has_valid:
h.new_batch()
h.record_batch('valid_loss', 2)

with pytest.raises(KeyError):
# pylint: disable=pointless-statement
h[-1, 'batches', -1, 'valid_loss']

def test_history_no_epochs_index(self):
h = History()
with pytest.raises(IndexError):
# pylint: disable=pointless-statement
h[-1, 'batches']

def test_history_jagged_batches(self):
h = History()
for num_batch in (1, 2):
h.new_epoch()
for _ in range(num_batch):
h.new_batch()
# Make sure we can access this batch
assert h[-1, 'batches', 1] == {}

@pytest.mark.parametrize('value, check_warn', [
([], False),
(np.array([]), True),
])
def test_history_retrieve_empty_list(self, value, check_warn, recwarn):
h = History()
h.new_epoch()
h.record('foo', value)
h.new_batch()
h.record_batch('batch_foo', value)

# Make sure we can access our object
assert h[-1, 'foo'] is value
assert h[-1, 'batches', -1, 'batch_foo'] is value

# There should be no warning about comparison to an empty ndarray
if check_warn:
assert not recwarn.list

@pytest.mark.parametrize('has_epoch, epoch_slice', [
(False, slice(None)),
(True, slice(1, None)),
])
def test_history_no_epochs_key(self, has_epoch, epoch_slice):
h = History()
if has_epoch:
h.new_epoch()

# Expect KeyError since the key was not found in any epochs
with pytest.raises(KeyError):
# pylint: disable=pointless-statement
h[epoch_slice, 'foo']
with pytest.raises(KeyError):
# pylint: disable=pointless-statement
h[epoch_slice, ['foo', 'bar']]

@pytest.mark.parametrize('has_batch, batch_slice', [
(False, slice(None)),
(True, slice(1, None)),
])
def test_history_no_batches_key(self, has_batch, batch_slice):
h = History()
h.new_epoch()
if has_batch:
h.new_batch()

# Expect KeyError since the key was not found in any batches
with pytest.raises(KeyError):
# pylint: disable=pointless-statement
h[-1, 'batches', batch_slice, 'foo']
with pytest.raises(KeyError):
# pylint: disable=pointless-statement
h[-1, 'batches', batch_slice, ['foo', 'bar']]

@pytest.mark.parametrize('has_epoch, epoch_slice', [
(False, slice(None)),
(True, slice(1, None)),
])
def test_history_no_epochs_batches(self, has_epoch, epoch_slice):
h = History()
if has_epoch:
h.new_epoch()

# Expect a list of zero epochs since 'batches' always exists
assert h[epoch_slice, 'batches'] == []
assert h[epoch_slice, 'batches', -1] == []

0 comments on commit 54796f1

Please sign in to comment.