Enable use of AMO in pipeline and CLI
cthoyt committed Nov 26, 2020
1 parent 710279d commit b973e7c
Showing 4 changed files with 29 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/pykeen/models/cli/builders.py
@@ -101,6 +101,7 @@ def _decorate_model_kwargs(command: click.Command) -> click.Command:
 @options.optimizer_option
 @regularizer_option
 @options.training_loop_option
+@options.automatic_memory_optimization_option
 @options.number_epochs_option
 @options.batch_size_option
 @options.learning_rate_option
@@ -128,6 +129,7 @@ def main(
     mlflow_tracking_uri,
     title,
     dataset,
+    automatic_memory_optimization,
     training_triples_factory,
     testing_triples_factory,
     validation_triples_factory,
@@ -180,6 +182,7 @@ def main(
             title=title,
         ),
         random_seed=random_seed,
+        automatic_memory_optimization=automatic_memory_optimization,
     )
 
     if not silent:
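Because every model command is generated by the builder in this module, the new flag is exposed on all of them at once. Below is a sketch of exercising one generated command in-process with click's CliRunner; the builder name and the --dataset option are assumptions read off main()'s signature above, not guaranteed by this diff:

from click.testing import CliRunner

from pykeen.models import TransE
from pykeen.models.cli.builders import build_cli_from_cls  # assumed import path

# Build the click command for a single model, as the CLI module does for all models.
cli = build_cli_from_cls(TransE)

runner = CliRunner()
# The paired flag added above toggles AMO off for this run.
result = runner.invoke(cli, ['--dataset', 'nations', '--no-automatic-memory-optimization'])
print(result.exit_code, result.output)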
5 changes: 5 additions & 0 deletions src/pykeen/models/cli/options.py
@@ -137,6 +137,11 @@ def triples_factory_callback(_, __, path: Optional[str]) -> Optional[TriplesFact
     default=_get_default(get_training_loop_cls, suffix=_TRAINING_LOOP_SUFFIX),
     show_default=True,
 )
+automatic_memory_optimization_option = click.option(
+    '--automatic-memory-optimization/--no-automatic-memory-optimization',
+    default=True,
+    show_default=True,
+)
 stopper_option = click.option(
     '--stopper',
     type=click.Choice(list(stoppers)),
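Click's '--x/--no-x' syntax declares one boolean parameter with an explicit on/off switch pair, so users can override the default in either direction. A minimal self-contained sketch of the same pattern; the demo command is illustrative and not part of PyKEEN:

import click

@click.command()
@click.option(
    '--automatic-memory-optimization/--no-automatic-memory-optimization',
    default=True,
    show_default=True,
)
def demo(automatic_memory_optimization: bool):
    """Echo the boolean resolved from the paired flag."""
    click.echo(f'AMO enabled: {automatic_memory_optimization}')

if __name__ == '__main__':
    demo()  # `demo --no-automatic-memory-optimization` prints 'AMO enabled: False'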
14 changes: 11 additions & 3 deletions src/pykeen/pipeline.py
@@ -738,6 +738,7 @@ def pipeline(  # noqa: C901
     result_tracker: Union[None, str, Type[ResultTracker]] = None,
     result_tracker_kwargs: Optional[Mapping[str, Any]] = None,
     # Misc
+    automatic_memory_optimization: bool = True,
     metadata: Optional[Dict[str, Any]] = None,
     device: Union[None, str, torch.device] = None,
     random_seed: Optional[int] = None,
@@ -915,6 +916,7 @@ def pipeline(  # noqa: C901
         training_loop_instance: TrainingLoop = training_loop(
             model=model_instance,
             optimizer=optimizer_instance,
+            automatic_memory_optimization=automatic_memory_optimization,
         )
     elif training_loop is not SLCWATrainingLoop:
         raise ValueError('Can not specify negative sampler with LCWA')
@@ -927,14 +929,20 @@
         training_loop_instance: TrainingLoop = SLCWATrainingLoop(
             model=model_instance,
             optimizer=optimizer_instance,
+            automatic_memory_optimization=automatic_memory_optimization,
             negative_sampler_cls=negative_sampler,
             negative_sampler_kwargs=negative_sampler_kwargs,
         )
 
     evaluator = get_evaluator_cls(evaluator)
-    evaluator_instance: Evaluator = evaluator(
-        **(evaluator_kwargs or {}),
-    )
+    # TODO @mehdi is setting the automatic memory optimization as an attribute
+    #  of the class appropriate, since it doesn't cause any state to be stored?
+    #  I think it might be better to have this as an argument to the
+    #  Evaluator.evaluate() function instead
+    if evaluator_kwargs is None:
+        evaluator_kwargs = {}
+    evaluator_kwargs.setdefault('automatic_memory_optimization', automatic_memory_optimization)
+    evaluator_instance: Evaluator = evaluator(**evaluator_kwargs)
 
     if evaluation_kwargs is None:
         evaluation_kwargs = {}
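With the keyword threaded through both training loops and the evaluator, AMO can now be toggled from one place in the high-level API. A minimal usage sketch; Nations/TransE are the stock PyKEEN quick-start pair, used here only for illustration:

from pykeen.pipeline import pipeline

result = pipeline(
    dataset='Nations',
    model='TransE',
    # New in this commit: forwarded to the training loop and the evaluator.
    automatic_memory_optimization=False,
)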
10 changes: 10 additions & 0 deletions tests/test_models.py
@@ -327,6 +327,16 @@ def cli_extras(self):
         for k, v in kwargs.items():
             extras.append('--' + k.replace('_', '-'))
             extras.append(str(v))
+
+        # For the high/low memory test cases of NTN, SE, etc.
+        if self.training_loop_kwargs and 'automatic_memory_optimization' in self.training_loop_kwargs:
+            automatic_memory_optimization = self.training_loop_kwargs.get('automatic_memory_optimization')
+            if automatic_memory_optimization is True:
+                extras.append('--automatic-memory-optimization')
+            elif automatic_memory_optimization is False:
+                extras.append('--no-automatic-memory-optimization')
+            # else, leave to default
+
         extras += [
             '--number-epochs', self.train_num_epochs,
             '--embedding-dim', self.embedding_dim,
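As a quick check of the mapping above, a standalone sketch; the helper name amo_extras is hypothetical, written only to mirror the branch in cli_extras():

def amo_extras(training_loop_kwargs):
    """Map a training-loop kwarg to the matching paired CLI flag."""
    extras = []
    amo = (training_loop_kwargs or {}).get('automatic_memory_optimization')
    if amo is True:
        extras.append('--automatic-memory-optimization')
    elif amo is False:
        extras.append('--no-automatic-memory-optimization')
    # None or absent: fall through to the CLI default (enabled)
    return extras

assert amo_extras({'automatic_memory_optimization': False}) == ['--no-automatic-memory-optimization']
assert amo_extras(None) == []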
