diff --git a/src/pykeen/datasets/base.py b/src/pykeen/datasets/base.py index 3e3e9174d9..8749cabf96 100644 --- a/src/pykeen/datasets/base.py +++ b/src/pykeen/datasets/base.py @@ -20,7 +20,7 @@ from ..constants import PYKEEN_DATASETS from ..triples import TriplesFactory -from ..typing import RandomHint +from ..typing import TorchRandomHint from ..utils import normalize_string __all__ = [ @@ -520,7 +520,7 @@ def __init__( eager: bool = False, create_inverse_triples: bool = False, delimiter: Optional[str] = None, - random_state: RandomHint = None, + random_state: TorchRandomHint = None, randomize_cleanup: bool = False, ): """Initialize dataset. @@ -596,7 +596,7 @@ def __init__( cache_root: Optional[str] = None, eager: bool = False, create_inverse_triples: bool = False, - random_state: RandomHint = None, + random_state: TorchRandomHint = None, ): """Initialize dataset. @@ -654,7 +654,7 @@ def __init__( cache_root: Optional[str] = None, eager: bool = False, create_inverse_triples: bool = False, - random_state: RandomHint = None, + random_state: TorchRandomHint = None, read_csv_kwargs: Optional[Mapping[str, Any]] = None, ): """Initialize dataset. diff --git a/src/pykeen/datasets/ckg.py b/src/pykeen/datasets/ckg.py index 51a294583d..d1c172c07a 100644 --- a/src/pykeen/datasets/ckg.py +++ b/src/pykeen/datasets/ckg.py @@ -11,7 +11,7 @@ import pandas as pd from .base import TabbedDataset -from ..typing import RandomHint +from ..typing import TorchRandomHint __all__ = [ 'CKG', @@ -34,7 +34,7 @@ def __init__( self, eager: bool = False, create_inverse_triples: bool = False, - random_state: RandomHint = 0, + random_state: TorchRandomHint = 0, cache_root: Optional[str] = None, ): super().__init__( diff --git a/src/pykeen/datasets/conceptnet.py b/src/pykeen/datasets/conceptnet.py index bb397714fa..3e4e4967d0 100644 --- a/src/pykeen/datasets/conceptnet.py +++ b/src/pykeen/datasets/conceptnet.py @@ -9,7 +9,7 @@ from more_click import verbose_option from .base import SingleTabbedDataset -from ..typing import RandomHint +from ..typing import TorchRandomHint URL = 'https://s3.amazonaws.com/conceptnet/downloads/2019/edges/conceptnet-assertions-5.7.0.csv.gz' @@ -28,7 +28,7 @@ class ConceptNet(SingleTabbedDataset): def __init__( self, create_inverse_triples: bool = False, - random_state: RandomHint = 0, + random_state: TorchRandomHint = 0, **kwargs, ): super().__init__( diff --git a/src/pykeen/datasets/drkg.py b/src/pykeen/datasets/drkg.py index 3daf141680..2d46421314 100644 --- a/src/pykeen/datasets/drkg.py +++ b/src/pykeen/datasets/drkg.py @@ -8,7 +8,7 @@ import logging from .base import TarFileSingleDataset -from ..typing import RandomHint +from ..typing import TorchRandomHint __all__ = [ 'DRKG', @@ -29,7 +29,7 @@ class DRKG(TarFileSingleDataset): def __init__( self, create_inverse_triples: bool = False, - random_state: RandomHint = 0, + random_state: TorchRandomHint = 0, **kwargs, ): super().__init__( diff --git a/src/pykeen/datasets/hetionet.py b/src/pykeen/datasets/hetionet.py index 141ce24b56..57284fdc0b 100644 --- a/src/pykeen/datasets/hetionet.py +++ b/src/pykeen/datasets/hetionet.py @@ -10,7 +10,7 @@ import click from .base import SingleTabbedDataset -from ..typing import RandomHint +from ..typing import TorchRandomHint __all__ = [ 'Hetionet', @@ -40,7 +40,7 @@ def __init__( self, create_inverse_triples: bool = False, eager: bool = False, - random_state: RandomHint = 0, + random_state: TorchRandomHint = 0, ): super().__init__( url=URL, diff --git a/src/pykeen/typing.py b/src/pykeen/typing.py index 7043736aed..56ae0b9df9 100644 --- a/src/pykeen/typing.py +++ b/src/pykeen/typing.py @@ -17,7 +17,6 @@ 'Constrainer', 'InteractionFunction', 'DeviceHint', - 'RandomHint', 'TorchRandomHint', ] @@ -34,5 +33,4 @@ Constrainer = Callable[[TensorType], TensorType] DeviceHint = Union[None, str, torch.device] -RandomHint = Union[None, int, np.random.RandomState] TorchRandomHint = Union[None, int, torch.Generator] diff --git a/src/pykeen/utils.py b/src/pykeen/utils.py index 255a8f53fd..ddae07fdb9 100644 --- a/src/pykeen/utils.py +++ b/src/pykeen/utils.py @@ -21,7 +21,7 @@ import torch.nn.modules.batchnorm from .constants import PYKEEN_BENCHMARKS -from .typing import DeviceHint, RandomHint, TorchRandomHint +from .typing import DeviceHint, TorchRandomHint from .version import get_git_hash __all__ = [ @@ -48,7 +48,6 @@ 'NoRandomSeedNecessary', 'Result', 'fix_dataclass_init_docs', - 'ensure_random_state', 'get_benchmark', ] @@ -406,18 +405,6 @@ def random_non_negative_int() -> int: return int(sq.generate_state(1)[0]) -def ensure_random_state(random_state: RandomHint) -> np.random.RandomState: - """Prepare a random state.""" - if random_state is None: - random_state = random_non_negative_int() - logger.warning(f'using automatically assigned random_state={random_state}') - if isinstance(random_state, int): - random_state = np.random.RandomState(random_state) - if not isinstance(random_state, np.random.RandomState): - raise TypeError - return random_state - - def ensure_torch_random_state(random_state: TorchRandomHint) -> torch.Generator: """Prepare a random state for PyTorch.""" if random_state is None: