鉁煏碉笍 Fix mypy (#1387)
* [x] Fix an issue that came up after upgrading to the most recent `mypy` version
* [x] Re-use the `OneOrSequence` type annotation consistently in this method
* [x] Re-use the `upgrade_to_sequence` utility
* [x] Get rid of the custom PyTorch Geometric installation for tests; this
  should also fix CI breaking after each torch release (see the install sketch
  after the quote below). Quoting the
  [README](https://github.com/pyg-team/pytorch_geometric?tab=readme-ov-file#pypi):
> From PyG 2.3 onwards, you can install and use PyG without any external
> library required except for PyTorch. For this, simply run
>
> ```
> pip install torch_geometric
> ```
>
> ### Additional Libraries
>
> If you want to utilize the full set of features from PyG, there exists
> several additional libraries you may want to install: [...]
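
With the new `pyg` extra added to `setup.cfg` below, PyG support should also be
installable through pip's extras mechanism. A minimal sketch, assuming a
standard extras-aware install of the published package:

```
pip install pykeen[pyg]
```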
mberr committed May 4, 2024
1 parent 9b9b59f commit ac85bef
Showing 3 changed files with 43 additions and 33 deletions.
4 changes: 4 additions & 0 deletions setup.cfg
@@ -126,6 +126,10 @@ docs =
     sphinx-autodoc-typehints<=1.13.1
     sphinx_automodapi
     texext
+pyg =
+    # from 2.3 onwards, you can install this without pre-compiled dependencies
+    # for training, you may still want to have those, cf. https://github.com/pyg-team/pytorch_geometric?tab=readme-ov-file#pypi
+    torch_geometric
 
 [options.entry_points]
 console_scripts =
67 changes: 38 additions & 29 deletions src/pykeen/ablation/ablation.py
@@ -13,7 +13,7 @@
 
 from ..training import SLCWATrainingLoop, training_loop_resolver
 from ..typing import OneOrSequence
-from ..utils import normalize_path, normalize_string
+from ..utils import normalize_path, normalize_string, upgrade_to_sequence
 
 __all__ = [
     "ablation_pipeline",
@@ -238,7 +238,7 @@ def _run_ablation_experiments(
 def iter_unique_ids(disable: bool = False) -> Iterable[str]:
     """Iterate unique id to append to a path."""
     if disable:
-        return []
+        return
     datetime = time.strftime("%Y-%m-%d-%H-%M")
     yield f"{datetime}_{uuid4()}"
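
Why the bare `return` is the right fix here (a minimal sketch, not part of the
diff): `iter_unique_ids` contains a `yield`, so it is a generator; `return []`
never hands a list to the caller anyway (the value would only land in
`StopIteration.value`), and recent `mypy` rejects returning a value from a
generator annotated as `Iterable[str]`:

```python
from typing import Iterable


def iter_unique_ids_sketch(disable: bool = False) -> Iterable[str]:
    """The `yield` below makes this whole function a generator."""
    if disable:
        # `return []` here would be flagged by mypy ("No return value
        # expected"); a bare return simply stops the (empty) generator
        return
    yield "2024-05-04_..."


# the caller sees an empty iterator either way:
assert list(iter_unique_ids_sketch(disable=True)) == []
```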

@@ -337,15 +337,15 @@ def path_to_str(x: object) -> str:
 
 def prepare_ablation(  # noqa:C901
     datasets: OneOrSequence[str | SplitToPathDict],
-    models: Union[str, List[str]],
-    losses: Union[str, List[str]],
-    optimizers: Union[str, List[str]],
-    training_loops: Union[str, List[str]],
+    models: OneOrSequence[str],
+    losses: OneOrSequence[str],
+    optimizers: OneOrSequence[str],
+    training_loops: OneOrSequence[str],
     directory: Union[str, pathlib.Path],
     *,
+    create_inverse_triples: OneOrSequence[bool] = False,
+    regularizers: OneOrSequence[None | str] = None,
     epochs: Optional[int] = None,
-    create_inverse_triples: Union[bool, List[bool]] = False,
-    regularizers: Union[None, str, List[str], List[None]] = None,
     negative_sampler: Optional[str] = None,
     evaluator: Optional[str] = None,
     model_to_model_kwargs: Optional[Mapping2D] = None,
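
For reference, `OneOrSequence` expresses "either a single value or a sequence
of values". A minimal sketch of such an alias (the actual definition lives in
`pykeen.typing`):

```python
from typing import Sequence, TypeVar, Union

X = TypeVar("X")

# either one value of type X, or a sequence of them
OneOrSequence = Union[X, Sequence[X]]
```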
@@ -444,24 +444,33 @@ def prepare_ablation(  # noqa:C901
         the paths to the training, testing, and validation data.
     """
     directory = normalize_path(path=directory)
-    if isinstance(datasets, (str, dict)):
-        datasets = [datasets]
-    if isinstance(create_inverse_triples, bool):
-        create_inverse_triples = [create_inverse_triples]
-    if isinstance(models, str):
-        models = [models]
-    if isinstance(losses, str):
-        losses = [losses]
-    if isinstance(optimizers, str):
-        optimizers = [optimizers]
-    if isinstance(training_loops, str):
-        training_loops = [training_loops]
-    if isinstance(regularizers, str):
-        regularizers = [regularizers]
-    elif regularizers is None:
-        regularizers = [None]
+    datasets = upgrade_to_sequence(datasets)
+    create_inverse_triples = upgrade_to_sequence(create_inverse_triples)
+    models = upgrade_to_sequence(models)
+    losses = upgrade_to_sequence(losses)
+    optimizers = upgrade_to_sequence(optimizers)
+    training_loops = upgrade_to_sequence(training_loops)
+    regularizers = upgrade_to_sequence(regularizers)
 
-    it: Iterable[tuple[str | SplitToPathDict, bool, str, str, str | None, str, str]] = itt.product(
+    # note: for some reason, mypy does not properly recognize the tuple[T1, T2, T3] notation,
+    # but rather uses tuple[T1 | T2 | T3, ...]
+    it: Iterable[
+        tuple[
+            # dataset
+            str | SplitToPathDict,
+            # create inverse triples
+            bool,
+            # models, losses
+            str,
+            str,
+            # regularizers
+            str | None,
+            # optimizers, training loops
+            str,
+            str,
+        ]
+    ]
+    it = itt.product(  # type: ignore
         datasets,
         create_inverse_triples,
         models,
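
The `isinstance` chain above is replaced by `pykeen.utils.upgrade_to_sequence`.
Roughly, that utility passes non-string sequences through and wraps everything
else (including strings, which technically are sequences themselves) into a
one-element tuple; a sketch under that assumption:

```python
from collections.abc import Sequence as SequenceABC
from typing import Sequence, TypeVar, Union

X = TypeVar("X")


def upgrade_to_sequence_sketch(x: Union[X, Sequence[X]]) -> Sequence[X]:
    """Pass non-string sequences through; wrap everything else in a tuple."""
    if isinstance(x, SequenceABC) and not isinstance(x, str):
        return x
    return (x,)


assert upgrade_to_sequence_sketch("TransE") == ("TransE",)
assert upgrade_to_sequence_sketch(["TransE", "ComplEx"]) == ["TransE", "ComplEx"]
assert upgrade_to_sequence_sketch(None) == (None,)
```

The separate `it: Iterable[tuple[...]]` annotation plus `# type: ignore` works
around the typeshed stubs for `itertools.product`, whose specifically-typed
overloads cover only a limited number of differently-typed iterables and
otherwise fall back to a homogeneous `tuple[T1 | T2 | ..., ...]`.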
@@ -479,7 +488,7 @@ def prepare_ablation(  # noqa:C901
     directories = []
     for counter, (
         dataset,
-        create_inverse_triples,
+        this_create_inverse_triples,
         model,
         loss,
         regularizer,
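
A note on the rename (a minimal repro with hypothetical names): re-using the
parameter name as the loop variable would re-bind a variable that was just
upgraded to a sequence with a single `bool`, which `mypy` flags as an
incompatible assignment:

```python
from typing import Sequence


def run(create_inverse_triples: Sequence[bool]) -> None:
    # `for create_inverse_triples in create_inverse_triples:` would re-bind
    # the Sequence[bool] variable with a bool, which mypy rejects
    for this_create_inverse_triples in create_inverse_triples:
        print(this_create_inverse_triples)  # a fresh name keeps the types apart
```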
@@ -538,8 +547,8 @@ def _set_arguments(config: Optional[Mapping3D], key: str, value: str) -> None:
                 "the paths to the training, testing, and validation data.",
             )
         logger.info(f"Dataset: {dataset}")
-        hpo_config["dataset_kwargs"] = dict(create_inverse_triples=create_inverse_triples)
-        logger.info(f"Add inverse triples: {create_inverse_triples}")
+        hpo_config["dataset_kwargs"] = dict(create_inverse_triples=this_create_inverse_triples)
+        logger.info(f"Add inverse triples: {this_create_inverse_triples}")
 
         hpo_config["model"] = model
         hpo_config["model_kwargs"] = model_to_model_kwargs.get(model, {})
5 changes: 1 addition & 4 deletions tox.ini
@@ -26,10 +26,6 @@ envlist =
     #coverage-report
 
 [testenv]
-commands =
-    # custom installation for PyG, cf. https://github.com/rusty1s/pytorch_scatter/pull/268
-    pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric --find-links https://data.pyg.org/whl/torch-2.2.0+cpu.html
-    coverage run -p -m pytest --durations=20 {posargs:tests} -m "not slow"
 # ensure we use the CPU-only version of torch
 setenv =
     PIP_EXTRA_INDEX_URL = https://download.pytorch.org/whl/cpu
@@ -45,6 +41,7 @@ extras =
     tests
     transformers
     lightning
+    pyg
     # biomedicine  # pyobo is too slow without caching
 allowlist_externals =
     /bin/cat
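
For completeness: the `extras` list makes tox install the package with those
extras enabled, roughly equivalent to the following (a sketch; the exact
invocation is handled by tox itself):

```
pip install .[tests,transformers,lightning,pyg]
```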
