-
-
Notifications
You must be signed in to change notification settings - Fork 657
Closed
Labels
Description
The following code does not work:
import torch
class MyIterableDataset(torch.utils.data.IterableDataset):
    """Iterable dataset that yields the integers in ``[start, end)``.

    Args:
        start: First value produced by the dataset.
        end: Exclusive upper bound; must be strictly greater than ``start``.
    """

    def __init__(self, start, end):
        # Zero-argument super(): the one-argument form
        # ``super(MyIterableDataset)`` builds an unbound super object, so the
        # base-class initializer is not invoked through it.
        super().__init__()
        # The message now matches the strict inequality that is checked.
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    def __iter__(self):
        """Return a fresh iterator over ``range(self.start, self.end)``."""
        return iter(range(self.start, self.end))
# Dataset over range(0, 1000), wrapped in a DataLoader created with num_workers=2.
ds = MyIterableDataset(start=0, end=1000)
data_loader = torch.utils.data.DataLoader(dataset=ds, num_workers=2)
from ignite.engine import Engine
def foo(e, b):
    """Print one line per call in the form ``<epoch>-<iteration>: <batch>``.

    Args:
        e: Object exposing ``state.epoch`` and ``state.iteration``
            (the engine, when used as its process function).
        b: The current batch; printed via ``str.format``.
    """
    state = e.state
    line = "{}-{}: {}".format(state.epoch, state.iteration, b)
    print(line)
# Run the engine directly on the DataLoader; this is the call that raises
# the ValueError shown in the traceback.
engine = Engine(foo)
engine.run(data=data_loader, epoch_length=10)
It fails with the following error:
ValueError                                Traceback (most recent call last)
<ipython-input-19-1c8004fbf46e> in <module>
21
22 engine = Engine(foo)
---> 23 engine.run(data_loader, epoch_length=10)
/opt/conda/lib/python3.7/site-packages/ignite/engine/engine.py in run(self, data, max_epochs, epoch_length, seed)
848
849 self.state.dataloader = data
--> 850 return self._internal_run()
851
852 def _setup_engine(self):
/opt/conda/lib/python3.7/site-packages/ignite/engine/engine.py in _internal_run(self)
950 self._dataloader_iter = self._dataloader_len = None
951 self.logger.error("Engine run is terminating due to exception: %s.", str(e))
--> 952 self._handle_exception(e)
953
954 self._dataloader_iter = self._dataloader_len = None
/opt/conda/lib/python3.7/site-packages/ignite/engine/engine.py in _handle_exception(self, e)
714 self._fire_event(Events.EXCEPTION_RAISED, e)
715 else:
--> 716 raise e
717
718 def state_dict(self):
/opt/conda/lib/python3.7/site-packages/ignite/engine/engine.py in _internal_run(self)
933
934 if self._dataloader_iter is None:
--> 935 self._setup_engine()
936
937 hours, mins, secs = self._run_once_on_dataset()
/opt/conda/lib/python3.7/site-packages/ignite/engine/engine.py in _setup_engine(self)
873 if not isinstance(batch_sampler, ReproducibleBatchSampler):
874 self.state.dataloader = _update_dataloader(self.state.dataloader,
--> 875 ReproducibleBatchSampler(batch_sampler))
876
877 iteration = self.state.iteration
/opt/conda/lib/python3.7/site-packages/ignite/engine/engine.py in _update_dataloader(dataloader, new_batch_sampler)
963 params = {k: getattr(dataloader, k) for k in params_keys}
964 params['batch_sampler'] = new_batch_sampler
--> 965 return torch.utils.data.DataLoader(**params)
966
967
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __init__(self, dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn, pin_memory, drop_last, timeout, worker_init_fn, multiprocessing_context)
182 raise ValueError(
183 "DataLoader with IterableDataset: expected unspecified "
--> 184 "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
185 else:
186 self._dataset_kind = _DatasetKind.Map
ValueError: DataLoader with IterableDataset: expected unspecified batch_sampler option, but got batch_sampler=<ignite.engine.engine.ReproducibleBatchSampler object at 0x7f6000e3fad0>
As a workaround, wrap the data loader in a plain iterator before passing it to the engine:
# map() with an identity function hands the engine a plain iterator rather than
# the DataLoader itself. NOTE(review): presumably this skips the batch-sampler
# replacement done in _setup_engine - confirm against the ignite source.
engine.run(map(lambda x: x, data_loader), epoch_length=10)