More EpochResultStore refactors! 🎉 (#5522)
Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
3 people committed Feb 11, 2021
1 parent 253e57c commit 9f12ca0
Showing 5 changed files with 147 additions and 226 deletions.
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import torch

@@ -128,50 +128,28 @@ def get_epoch_log_metrics(self, *_, **__) -> List[Dict]:
def get_forked_metrics(self, *_, **__) -> List[Dict]:
return self.get_epoch_from_func_name("get_forked_metrics")

@staticmethod
def _append_to_structure(primary_dict, opt_idx, batch_idx, result) -> None:
primary_dict.setdefault(opt_idx, {})
primary_dict[opt_idx].setdefault(batch_idx, [])
primary_dict[opt_idx][batch_idx].append(result)
def append(self, result: Result, info: Dict) -> None:
dataloader_idx = info["dataloader_idx"]
self._internal_type = info["type"]
opt_idx = info["opt_idx"]

def append(self, result, dataloader_idx: Optional[int] = None, extra_info: Optional[dict] = None) -> None:
if not isinstance(result, Result):
raise TypeError(f'{result} must be Result')

if dataloader_idx is None:
dataloader_idx = 0

if extra_info is None:
extra_info = {}

# [dataloader_idx][optimizer_idx][training_step_idx] is a list
if len(extra_info) > 0:
self._internal_type = ResultStoreType.INSIDE_BATCH_TRAIN_LOOP
# initialize dictionary
if self._internal_type == ResultStoreType.INSIDE_BATCH_TRAIN_LOOP:
if dataloader_idx not in self._internals:
self._internals[dataloader_idx] = {}
self._internals_reduced[dataloader_idx] = defaultdict(dict)
self._latest_ref[dataloader_idx] = {}
self._internals.setdefault(dataloader_idx, {})

# extract infos
opt_idx = extra_info["opt_idx"]
batch_idx = extra_info["batch_idx"]

self._append_to_structure(self._internals[dataloader_idx], opt_idx, batch_idx, result)

self._latest_ref[dataloader_idx][opt_idx] = result

# [dataloader_idx] is a list
batch_idx = info["batch_idx"]
self._internals[dataloader_idx].setdefault(opt_idx, {})
self._internals[dataloader_idx][opt_idx].setdefault(batch_idx, [])
self._internals[dataloader_idx][opt_idx][batch_idx].append(result)
else:
self._internal_type = ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP
self._internals.setdefault(dataloader_idx, [])
self._internals[dataloader_idx].append(result)
self._latest_ref.setdefault(dataloader_idx, {})

if dataloader_idx not in self._latest_ref:
self._latest_ref[dataloader_idx] = {}
self._latest_ref[dataloader_idx][0] = {}

self._latest_ref[dataloader_idx][0] = result
self._latest_ref[dataloader_idx].setdefault(opt_idx, {})
self._latest_ref[dataloader_idx][opt_idx] = result

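As a side note on the layout the refactored `append` maintains: inside the training batch loop, results are stored under `_internals[dataloader_idx][opt_idx][batch_idx]`, where each leaf is a list (one entry per tbptt split); outside the batch loop, `_internals[dataloader_idx]` is a flat list. A minimal standalone sketch of that shape, using a plain dict in place of the real `Result` objects:

# A minimal sketch of the nested storage, assuming a plain dict stands in
# for the real Result objects; only the shape mirrors HookResultStore.
internals = {}

def append_inside_batch_loop(result, dataloader_idx, opt_idx, batch_idx):
    # [dataloader_idx][optimizer_idx][batch_idx] is a list (one entry per tbptt split)
    internals.setdefault(dataloader_idx, {})
    internals[dataloader_idx].setdefault(opt_idx, {})
    internals[dataloader_idx][opt_idx].setdefault(batch_idx, [])
    internals[dataloader_idx][opt_idx][batch_idx].append(result)

def append_outside_batch_loop(result, dataloader_idx):
    # [dataloader_idx] is a flat list of results
    internals.setdefault(dataloader_idx, [])
    internals[dataloader_idx].append(result)

append_inside_batch_loop({"loss": 0.5}, dataloader_idx=0, opt_idx=0, batch_idx=3)
append_inside_batch_loop({"loss": 0.4}, dataloader_idx=0, opt_idx=0, batch_idx=3)
assert len(internals[0][0][3]) == 2

This only illustrates the shape; the real method also tracks `_internal_type` and keeps the latest result per optimizer in `_latest_ref`.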
def auto_reduce_results_on_epoch_end(self) -> None:
"""
@@ -188,36 +166,32 @@ def auto_reduce_results_on_epoch_end(self) -> None:
for opt_idx in list(epoch_metrics):
# TODO: Figure out to reduce memory
# TODO: How to start training in middle of epoch
opt_outputs = epoch_metrics[opt_idx]

outputs = epoch_metrics[opt_idx]
# reduce across time first
time_reduced_outputs = []
for batch_idx in opt_outputs.keys():
tbptt_outs = opt_outputs[batch_idx]
tbptt_outs = tbptt_outs[0].__class__.reduce_across_time(tbptt_outs)
if len(tbptt_outs) > 1:
time_reduced_outputs.append(tbptt_outs)
for tbptt_outputs in outputs.values():
tbptt_outputs = type(tbptt_outputs[0]).reduce_across_time(tbptt_outputs)
if len(tbptt_outputs) > 1:
time_reduced_outputs.append(tbptt_outputs)

if len(time_reduced_outputs) == 0:
continue

# reduce across training steps
opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs)
outputs = type(time_reduced_outputs[0]).reduce_on_epoch_end(time_reduced_outputs)

# with manual opt need 1 + metrics because meta is always there
if opt_outputs.minimize is not None:
opt_outputs.minimize = opt_outputs.minimize.mean()
if outputs.minimize is not None:
outputs.minimize = outputs.minimize.mean()

self._internals_reduced[dl_idx][opt_idx] = opt_outputs
self._internals_reduced[dl_idx][opt_idx] = outputs

# free memory
del self._internals[dl_idx][opt_idx]
else:
# no need to reduce as called only once
if len(epoch_metrics) == 1:
reduced_epoch_metrics = epoch_metrics[0]
else:
reduced_epoch_metrics = epoch_metrics[0].__class__.reduce_on_epoch_end(epoch_metrics)
reduced_epoch_metrics = epoch_metrics[0]
if len(epoch_metrics) != 1:
reduced_epoch_metrics = type(reduced_epoch_metrics).reduce_on_epoch_end(epoch_metrics)

self._internals_reduced[dl_idx] = reduced_epoch_metrics

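For intuition about the two-stage reduction above, here is a toy numeric sketch with plain means standing in for `Result.reduce_across_time` and `Result.reduce_on_epoch_end` (the real code delegates to the Result class and also averages the `minimize` key):

def mean(values):
    return sum(values) / len(values)

# losses logged per tbptt split, keyed by batch index, for one optimizer
per_batch_splits = {0: [0.9, 0.8], 1: [0.7, 0.6], 2: [0.5]}

# stage 1: reduce across time (tbptt splits) within each batch
time_reduced = [mean(splits) for splits in per_batch_splits.values()]  # [0.85, 0.65, 0.5]

# stage 2: reduce across the training steps of the epoch
epoch_value = mean(time_reduced)  # ~0.667, stored per (dl_idx, opt_idx) in _internals_reduced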
@@ -257,18 +231,22 @@ def __getitem__(self, key: str) -> Any:
return self._internals.get(key, None)

@property
def has_split_and_opt_idx(self):
"""
This function informs if we are running within training batch loop
"""
return self._split_idx is not None and self._opt_idx is not None

@property
def extra_info(self):
def info(self):
"""
This function provides necessary parameters to properly configure HookResultStore obj
"""
return {"batch_idx": self.trainer.batch_idx, "split_idx": self._split_idx, "opt_idx": self._opt_idx}
model_ref = self.trainer.get_model()
return {
"batch_idx": self.trainer.batch_idx,
"fx_name": model_ref._current_hook_fx_name or model_ref._current_fx_name,
"dataloader_idx": model_ref._current_dataloader_idx or 0,
"opt_idx": self._opt_idx or 0,
"split_idx": self._split_idx or 0,
"type": (
ResultStoreType.INSIDE_BATCH_TRAIN_LOOP if self._opt_idx is not None and self._split_idx is not None
else ResultStoreType.OUTSIDE_BATCH_TRAIN_LOOP
)
}

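To make the new `info` property concrete, this is roughly the dictionary it would return for a `self.log` call made from `training_step` inside the training batch loop (illustrative values only):

example_info = {
    "batch_idx": 12,
    "fx_name": "training_step",
    "dataloader_idx": 0,
    "opt_idx": 0,
    "split_idx": 0,
    # ResultStoreType.INSIDE_BATCH_TRAIN_LOOP in the real code, since both
    # opt_idx and split_idx are set inside the training batch loop
    "type": "INSIDE_BATCH_TRAIN_LOOP",
}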
def reset_model(self):
"""
@@ -279,17 +257,6 @@ def reset_model(self):
model_ref._current_hook_fx_name = None
model_ref._current_fx_name = ''

def current_model_info(self):
"""
This function is used to extract
information related to current function scoping `self.log` call.
"""
model_ref = self.trainer.get_model()
# extract hook information
fx_name = model_ref._current_hook_fx_name or model_ref._current_fx_name
dataloader_idx = model_ref._current_dataloader_idx
return fx_name, dataloader_idx

def cache_result(self) -> None:
"""
This function is called after every hook
@@ -306,13 +273,11 @@ def cache_result(self) -> None:
model_ref._current_fx_name = ''
return

# extract model information
fx_name, dataloader_idx = self.current_model_info()
info = self.info
fx_name = info["fx_name"]

self._internals.setdefault(fx_name, HookResultStore(fx_name))

extra_info = self.extra_info if self.has_split_and_opt_idx else {}

# attach capture batch_size
Result.attach_batch_size(self._batch_size, hook_result)

@@ -322,16 +287,15 @@ def cache_result(self) -> None:
elif self.trainer._distrib_type == DistributedType.DP:
hook_result.to(torch.device("cuda", self.trainer.root_gpu))

self._internals[fx_name].append(hook_result, dataloader_idx=dataloader_idx, extra_info=extra_info)
self._internals[fx_name].append(hook_result, info)

# update logged_metrics, progress_bar_metrics, callback_metrics

if "epoch_end" in fx_name:
self.update_logger_connector()

self.reset_model()

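Conceptually, `cache_result` now routes every hook's result into a per-hook store keyed by the hook name and hands the `info` dict to `HookResultStore.append`. A minimal sketch with hypothetical names (plain lists in place of `HookResultStore`):

hook_stores = {}

def cache_result(fx_name, hook_result, info):
    # one store per hook name, e.g. "training_step" or "validation_epoch_end"
    store = hook_stores.setdefault(fx_name, [])
    # the real code calls HookResultStore.append(hook_result, info) instead
    store.append((hook_result, info))

cache_result("training_step", {"loss": 0.42}, {"batch_idx": 0, "opt_idx": 0, "dataloader_idx": 0})
assert len(hook_stores["training_step"]) == 1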
def update_logger_connector(self) -> None:
def update_logger_connector(self) -> Tuple[Dict, Dict]:
"""
This function is called every time we capture a hook
It automatically updates the logger_connector followings:
@@ -483,24 +447,24 @@ def __call__(
Example::
result: Result = self(fx_name="training_step", dl_idx="0", opt_idx="0", reduced=True)
result: Result = self(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)
result['train_loss_epoch'] # aggregated train_loss over one epoch.
Args:
fx_name: Hook name from ModelHooks or Callback. Example: `training_step`
fx_name: Hook name from ModelHooks or Callback. Example: ``"training_step"``
dl_idx: Dataloader idx in short. It starts from 0 to num_dataloaders - 1
dl_idx: Dataloader index in short. From ``0`` to ``num_dataloaders - 1``
opt_idx: Optimizer idx in short. It starts from 0 to num_optimizers - 1
opt_idx: Optimizer index in short. From ``0`` to ``num_optimizers - 1``
batch_idx: Index of batch idx seen during batch training or evaluation.
Works only with reduced=False
batch_idx: Batch index seen during batch training or evaluation.
Works only with ``reduced=False``
split_idx: Index of split idx in training loop when ttbt is used.
reduced: Data are being aggregated on on_epoch_end.
Indicates if we want to access aggregated Result or not.
Indicates if we want to access the aggregated Result or not.
"""
hook_result = self[fx_name]
internal_type = hook_result._internal_type
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1031,7 +1031,7 @@ def call_hook(self, hook_name, *args, **kwargs):
hook_fx = getattr(model_ref, hook_name)
output = hook_fx(*args, **kwargs)

# if the PL module doesn't have the hook then call the accelator
# if the PL module doesn't have the hook then call the accelerator
# used to auto-reduce things for the user with Results obj
elif hasattr(self.accelerator_backend, hook_name):
accelerator_hook = getattr(self.accelerator_backend, hook_name)
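The corrected comment above describes a fallback dispatch: call the hook on the LightningModule if it defines it, otherwise fall back to the accelerator backend. A rough sketch of that pattern (hypothetical objects; the real call_hook does more bookkeeping around this):

def call_hook(model_ref, accelerator_backend, hook_name, *args, **kwargs):
    # prefer the hook defined on the LightningModule
    if hasattr(model_ref, hook_name):
        return getattr(model_ref, hook_name)(*args, **kwargs)
    # otherwise fall back to the accelerator implementation, if any
    if hasattr(accelerator_backend, hook_name):
        return getattr(accelerator_backend, hook_name)(*args, **kwargs)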
121 changes: 0 additions & 121 deletions tests/trainer/dynamic_args/test_multiple_optimizers.py

This file was deleted.
