
Add drop_last option to data loader in training loop #217

Merged
merged 13 commits on Dec 15, 2020
20 changes: 2 additions & 18 deletions src/pykeen/models/base.py
@@ -20,7 +20,7 @@
from ..regularizers import NoRegularizer, Regularizer
from ..triples import TriplesFactory
from ..typing import Constrainer, DeviceHint, Initializer, MappedTriples, Normalizer
from ..utils import NoRandomSeedNecessary, resolve_device, set_random_seed
from ..utils import NoRandomSeedNecessary, get_batchnorm_modules, resolve_device, set_random_seed

__all__ = [
'Model',
@@ -31,13 +31,6 @@

logger = logging.getLogger(__name__)

UNSUPPORTED_FOR_SUBBATCHING = ( # must be a tuple
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.SyncBatchNorm,
)


def _extend_batch(
batch: MappedTriples,
@@ -321,16 +314,7 @@ def can_slice_t(self) -> bool:
@property
def modules_not_supporting_sub_batching(self) -> Collection[nn.Module]:
"""Return all modules not supporting sub-batching."""
return [
module
for module in self.modules()
if isinstance(module, UNSUPPORTED_FOR_SUBBATCHING)
]

@property
def supports_subbatching(self) -> bool: # noqa: D400, D401
"""Does this model support sub-batching?"""
return len(self.modules_not_supporting_sub_batching) == 0
return get_batchnorm_modules(module=self)

@abstractmethod
def _reset_parameters_(self): # noqa: D401
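The change above folds the hard-coded UNSUPPORTED_FOR_SUBBATCHING tuple and the separate supports_subbatching property into a single call to the new get_batchnorm_modules utility (added in utils.py below): sub-batching is supported exactly when the returned list is empty. A minimal sketch of the resulting check, using a toy Sequential model rather than a PyKEEN model and inlining the helper so the snippet is self-contained:

```python
from typing import List

from torch import nn


def get_batchnorm_modules(module: nn.Module) -> List[nn.Module]:
    """Collect all submodules which are batch normalization layers."""
    return [
        submodule
        for submodule in module.modules()
        if isinstance(submodule, nn.modules.batchnorm._BatchNorm)
    ]


model = nn.Sequential(nn.Linear(8, 4), nn.BatchNorm1d(4), nn.ReLU())
offenders = get_batchnorm_modules(model)
print(offenders)      # [BatchNorm1d(4, ...)] -> the offending layer is reported
print(not offenders)  # False: a non-empty list means sub-batching is unsupported
```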
30 changes: 27 additions & 3 deletions src/pykeen/training/training_loop.py
@@ -26,7 +26,10 @@
from ..training.schlichtkrull_sampler import GraphSampler
from ..triples import Instances, TriplesFactory
from ..typing import MappedTriples
from ..utils import is_cuda_oom_error, is_cudnn_error, normalize_string
from ..utils import (
format_relative_comparison, get_batchnorm_modules, is_cuda_oom_error, is_cudnn_error,
normalize_string,
)

__all__ = [
'TrainingLoop',
@@ -167,6 +170,7 @@ def train(
checkpoint_name: Optional[str] = None,
checkpoint_frequency: Optional[int] = None,
checkpoint_on_failure: bool = False,
drop_last: Optional[bool] = None,
) -> List[float]:
"""Train the KGE model.

@@ -218,6 +222,9 @@
which might cause problems with regards to the reproducibility of that specific training loop. Therefore,
these checkpoints are saved with a distinct checkpoint name, which will be
``PyKEEN_just_saved_my_day_{datetime}.pt`` in the given checkpoint_root.
:param drop_last:
Whether to drop the last batch in each epoch to prevent smaller batches. Defaults to False, except if the
model contains batch normalization layers. Can be provided explicitly to override.

:return:
The losses per epoch.
@@ -297,6 +304,7 @@ def train(
checkpoint_path=checkpoint_path,
checkpoint_frequency=checkpoint_frequency,
checkpoint_on_failure_file_path=checkpoint_on_failure_file_path,
drop_last=drop_last,
)

# Ensure the release of memory
@@ -328,6 +336,7 @@ def _train( # noqa: C901
checkpoint_path: Union[None, str, pathlib.Path] = None,
checkpoint_frequency: Optional[int] = None,
checkpoint_on_failure_file_path: Optional[str] = None,
drop_last: Optional[bool] = None,
) -> List[float]:
"""Train the KGE model.

@@ -370,6 +379,9 @@
The frequency of saving checkpoints in minutes. Setting it to 0 will save a checkpoint after every epoch.
:param checkpoint_on_failure_file_path:
The full filepath for saving checkpoints on failure.
:param drop_last:
Whether to drop the last batch in each epoch to prevent smaller batches. Defaults to False, except if the
model contains batch normalization layers. Can be provided explicitly to override.

:return:
The losses per epoch.
@@ -398,9 +410,20 @@ def _train( # noqa: C901

if sub_batch_size is None or sub_batch_size == batch_size: # by default do not split batches in sub-batches
sub_batch_size = batch_size
elif not self.model.supports_subbatching:
elif self.model.modules_not_supporting_sub_batching:
raise SubBatchingNotSupportedError(self.model)

model_contains_batch_norm = bool(get_batchnorm_modules(self.model))
if batch_size == 1 and model_contains_batch_norm:
raise ValueError("Cannot train a model with batch_size=1 containing BatchNorm layers.")
if drop_last is None:
drop_last = model_contains_batch_norm
if drop_last and not only_size_probing:
logger.info(
"Dropping last (incomplete) batch each epoch (%s batches).",
format_relative_comparison(part=1, total=len(self.training_instances)),
)

# Sanity check
if self.model.is_mr_loss and label_smoothing > 0.:
raise RuntimeError('Label smoothing can not be used with margin ranking loss.')
@@ -459,6 +482,7 @@ def _train( # noqa: C901
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
drop_last=drop_last,
)

# Save the time to track when the saved point was available
@@ -762,7 +786,7 @@ def _sub_batch_size_search(self, batch_size: int) -> Tuple[int, bool, bool]:

if not finished_search:
logger.info('Starting sub_batch_size search for training now...')
if not self.model.supports_subbatching:
if self.model.modules_not_supporting_sub_batching:
logger.info('This model does not support sub-batching.')
supports_sub_batching = False
sub_batch_size = batch_size
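The new drop_last default exists because batch normalization cannot compute statistics over a single example in training mode, so a trailing batch of size 1 would crash an epoch. A minimal, self-contained sketch of both halves of that logic in plain PyTorch (not PyKEEN code):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# In training mode, batch norm rejects a batch containing a single example.
bn = nn.BatchNorm1d(num_features=4).train()
try:
    bn(torch.randn(1, 4))
except ValueError as error:
    print(error)  # "Expected more than 1 value per channel when training, ..."

# 10 instances with batch_size=3 would leave a final batch of size 1;
# drop_last=True discards it, which the training loop now does by default
# whenever the model contains batch norm layers.
dataset = TensorDataset(torch.randn(10, 4))
loader = DataLoader(dataset, batch_size=3, shuffle=True, drop_last=True)
print(sum(1 for _ in loader))  # 3 batches instead of 4
```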
10 changes: 10 additions & 0 deletions src/pykeen/utils.py
@@ -17,6 +17,7 @@
import pandas as pd
import torch
import torch.nn
import torch.nn.modules.batchnorm

from .constants import PYKEEN_BENCHMARKS
from .typing import DeviceHint, RandomHint, TorchRandomHint
@@ -519,3 +520,12 @@ def format_relative_comparison(
) -> str:
"""Format a relative comparison."""
return f"{part}/{total} ({part / total:2.2%})"


def get_batchnorm_modules(module: torch.nn.Module) -> List[torch.nn.Module]:
"""Return all submodules which are batch normalization layers."""
return [
submodule
for submodule in module.modules()
if isinstance(submodule, torch.nn.modules.batchnorm._BatchNorm)
]
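
For reference, a brief usage sketch of the two utilities wired together by this change; the toy model and the batch count of 391 are made-up examples, not PyKEEN defaults:

```python
import torch
from pykeen.utils import format_relative_comparison, get_batchnorm_modules

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.BatchNorm1d(4))
print(len(get_batchnorm_modules(model)))              # 1, so drop_last defaults to True
print(format_relative_comparison(part=1, total=391))  # "1/391 (0.26%)", as in the log message
```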