Skip to content

Commit

Permalink
PT free dataset resources in main proc for DataLoader mp
Browse the repository at this point in the history
Fix #1443
  • Loading branch information
albertz committed Jan 18, 2024
1 parent 61d54b2 commit ca90f08
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 6 deletions.
2 changes: 1 addition & 1 deletion returnn/datasets/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
self.rnd_seq_drop = Random(self._get_random_seed_for_epoch(epoch=epoch))
return False

def finish_epoch(self):
def finish_epoch(self, *, free_resources: bool = False):
"""
This would get called at the end of the epoch (currently optional only).
After this, further calls to :func:`get_data` or :func:`load_seqs` are invalid,
Expand Down
6 changes: 3 additions & 3 deletions returnn/datasets/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,14 +462,14 @@ def get_total_num_seqs(self) -> int:
"""
return self.num_total_seqs

def finish_epoch(self, *, free_resources: bool = False):
    """
    This would get called at the end of the epoch.

    Forwards to the base class, then propagates the call to every sub-dataset
    so that each one can clean up its per-epoch state as well.

    :param free_resources: if True, free as many resources as possible
        (e.g. because the dataset is not going to be used anymore in this
        process — see the DataLoader multi-processing case).
    """
    super(MetaDataset, self).finish_epoch(free_resources=free_resources)
    # Propagate to all wrapped datasets; the dict key (dataset name) is not needed here.
    for dataset in self.datasets.values():
        assert isinstance(dataset, Dataset)
        dataset.finish_epoch(free_resources=free_resources)

def _load_seqs(self, start, end):
"""
Expand Down
16 changes: 16 additions & 0 deletions returnn/datasets/multi_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations
from typing import Optional, Any, Dict, List
import sys
import gc
from .basic import init_dataset, Dataset, DatasetSeq
from .cached2 import CachedDataset2
from returnn.util.basic import try_run
Expand Down Expand Up @@ -274,6 +275,15 @@ def _get(seq_idx: int) -> Optional[DatasetSeq]:
got_init_seq_order = True
next_seq_idx = 0
cache[:] = []
elif msg == "finish_epoch":
got_init_seq_order = False
next_seq_idx = 0
cache[:] = []
if dataset:
dataset.finish_epoch(**kwargs)
if kwargs["free_resources"]:
dataset = None
gc.collect()
else:
raise Exception(f"unknown msg {msg!r}")
except KeyboardInterrupt: # when parent dies
Expand Down Expand Up @@ -329,6 +339,12 @@ def get_total_num_seqs(self) -> int:
return self._total_num_seqs
raise NotImplementedError

def finish_epoch(self, *, free_resources: bool = False):
    """
    Finish the epoch: let the base class clean up, then tell every
    worker process to do the same (and optionally drop its dataset).

    :param free_resources: forwarded to the workers; if True, each worker
        frees its dataset resources as far as possible.
    """
    super().finish_epoch(free_resources=free_resources)
    msg = ("finish_epoch", {"free_resources": free_resources})
    for conn in self._worker_parent_conns:
        conn.send(msg)


class _SetupProcPreInit:
def __init__(self):
Expand Down
4 changes: 2 additions & 2 deletions returnn/datasets/sprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def _sprintConfigStr(self):
def _partitionEpoch(self):
return self.partition_epoch

def finish_epoch(self):
def finish_epoch(self, *, free_resources: bool = False):
"""
Called at the end of the epoch.
"""
Expand All @@ -724,7 +724,7 @@ def finish_epoch(self):
super(ExternSprintDataset, self).init_seq_order(epoch=None, seq_list=None)
# Exit child, before we overwrite anything, such as new epoch or seq_list.
self._exit_child(wait_thread=True)
super(ExternSprintDataset, self).finish_epoch()
super(ExternSprintDataset, self).finish_epoch(free_resources=free_resources)

def _exit_handler(self):
"""
Expand Down
7 changes: 7 additions & 0 deletions returnn/torch/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,13 @@ def _create_data_loader(self, dataset: Dataset) -> DataLoader:

loader_opts = self.config.typed_value("torch_dataloader_opts") or {}
assert isinstance(loader_opts, dict), f"config torch_dataloader_opts, expected dict, got {type(loader_opts)}"

if loader_opts.get("num_workers"):
# We are not using the dataset anymore here in the main proc,
# so free all resources as much as we can.
# https://github.com/rwth-i6/returnn/issues/1443
dataset.finish_epoch(free_resources=True)

return data_pipeline.create_data_loader_from_batches(batches_dataset, loader_opts)

def _run_step(
Expand Down

0 comments on commit ca90f08

Please sign in to comment.