Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🥔 ✈️ Make mock model importable #691

Merged
merged 24 commits into from
Dec 21, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/pykeen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .hpo.samplers import sampler_resolver
from .losses import loss_resolver
from .lr_schedulers import lr_scheduler_resolver
from .models import model_resolver
from .models import MockModel, model_resolver
from .models.cli import build_cli_from_cls
from .optimizers import optimizer_resolver
from .regularizers import regularizer_resolver
Expand All @@ -43,6 +43,7 @@
from .version import env_table

HERE = Path(__file__).resolve().parent
SKIP_MODELS = {MockModel}


@click.group()
Expand Down Expand Up @@ -84,6 +85,8 @@ def _help_models(tablefmt: str, link_fmt: Optional[str] = None):

def _get_model_lines(tablefmt: str, link_fmt: Optional[str] = None):
for _, model in sorted(model_resolver.lookup_dict.items()):
if model in SKIP_MODELS:
continue
reference = f"pykeen.models.{model.__name__}"
docdata = getattr(model, "__docdata__", None)
if docdata is not None:
Expand Down Expand Up @@ -504,7 +507,7 @@ def get_readme() -> str:
tablefmt = "github"
return readme_template.render(
models=_help_models(tablefmt, link_fmt="https://pykeen.readthedocs.io/en/latest/api/{}.html"),
n_models=len(model_resolver.lookup_dict),
n_models=len(model_resolver.lookup_dict) - len(SKIP_MODELS),
regularizers=_help_regularizers(tablefmt, link_fmt="https://pykeen.readthedocs.io/en/latest/api/{}.html"),
n_regularizers=len(regularizer_resolver.lookup_dict),
losses=_help_losses(tablefmt, link_fmt="https://pykeen.readthedocs.io/en/latest/api/{}.html"),
Expand Down
2 changes: 2 additions & 0 deletions src/pykeen/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .base import EntityRelationEmbeddingModel, Model, _OldAbstractModel
from .baseline import EvaluationOnlyModel, MarginalDistributionBaseline
from .mocks import MockModel
from .multimodal import ComplExLiteral, DistMultLiteral, DistMultLiteralGated, LiteralModel
from .nbase import ERModel, _NewAbstractModel
from .resolve import make_model, make_model_cls
Expand Down Expand Up @@ -74,6 +75,7 @@
"ERMLPE",
"HolE",
"KG2E",
"MockModel",
"MuRE",
"NodePiece",
"NTN",
Expand Down
51 changes: 51 additions & 0 deletions src/pykeen/models/mocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-

"""Mock models that return random results.

These are useful for baselines.
"""

import torch

from .base import EntityRelationEmbeddingModel, Model
from ..nn import EmbeddingSpecification
from ..triples import CoreTriplesFactory

__all__ = [
"MockModel",
]


class MockModel(EntityRelationEmbeddingModel):
    """A mock model returning fake scores.

    Useful as a lightweight, importable stand-in for a real KGEM in unit
    tests and as a trivial baseline.
    """

    #: No hyper-parameters to optimize for the mock model.
    hpo_default = {}

    def __init__(self, *, triples_factory: CoreTriplesFactory, embedding_dim: int = 50, **_kwargs):
        """Initialize the mock model.

        :param triples_factory: The factory providing the entity and relation counts.
        :param embedding_dim: The embedding dimension; only used to size the (unused
            for scoring) entity/relation representations.
        :param _kwargs: Ignored; accepted so the mock is a drop-in replacement where
            other model kwargs are forwarded.
        """
        super().__init__(
            triples_factory=triples_factory,
            entity_representations=EmbeddingSpecification(embedding_dim=embedding_dim),
            relation_representations=EmbeddingSpecification(embedding_dim=embedding_dim),
        )
        # Fixed per-entity scores s[i] = i; requires_grad so a training loop can
        # attach gradients to them.
        self.scores = torch.arange(self.num_entities, dtype=torch.float, requires_grad=True)
        # Counter for tests that want to verify how often backpropagation ran.
        self.num_backward_propagations = 0

    def _generate_fake_scores(self, batch: torch.LongTensor) -> torch.FloatTensor:
        """Generate fake scores s[b, i] = i of size (batch_size, num_entities)."""
        batch_size = batch.shape[0]
        # Broadcast the 1-D score vector into one identical row per batch element.
        batch_scores = self.scores.view(1, -1).repeat(batch_size, 1)
        assert batch_scores.shape == (batch_size, self.num_entities)
        return batch_scores

    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # score_hrt must return shape (batch_size, 1) — cf. the shape assertion in
        # tests/cases.py — but fancy-indexing a 1-D tensor with a (batch_size,) index
        # yields (batch_size,), so append the trailing singleton dimension.
        return self.scores[torch.randint(high=self.num_entities, size=hrt_batch.shape[:-1])].unsqueeze(dim=-1)

    def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return self._generate_fake_scores(batch=hr_batch)

    def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return self._generate_fake_scores(batch=rt_batch)

    def reset_parameters_(self) -> Model:  # noqa: D102
        # No parameters need re-initialization for the mock; return self to honor
        # the fluent `-> Model` return contract instead of implicitly returning None.
        return self
7 changes: 2 additions & 5 deletions tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,10 +890,7 @@ def test_reset_parameters_(self):

# check that the parameters were modified
num_equal_weights_after_re_init = sum(1 for np in new_params if (np.data == old_content[id(np)]).all())
assert num_equal_weights_after_re_init == self.num_constant_init, (
num_equal_weights_after_re_init,
self.num_constant_init,
)
self.assertEqual(num_equal_weights_after_re_init, self.num_constant_init)

def _check_scores(self, batch, scores) -> None:
"""Check the scores produced by a forward function."""
Expand All @@ -918,7 +915,7 @@ def test_score_hrt(self) -> None:
self.skipTest(str(e))
else:
raise e
assert scores.shape == (self.batch_size, 1)
self.assertEqual(scores.shape, (self.batch_size, 1))
self._check_scores(batch, scores)

def test_score_t(self) -> None:
Expand Down
38 changes: 1 addition & 37 deletions tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@

from pykeen.evaluation import Evaluator, MetricResults, RankBasedMetricResults
from pykeen.evaluation.rank_based_evaluator import RANK_REALISTIC, RANK_TYPES, SIDES
from pykeen.models import EntityRelationEmbeddingModel, Model
from pykeen.nn.emb import EmbeddingSpecification, RepresentationModule
from pykeen.triples import CoreTriplesFactory
from pykeen.nn.emb import RepresentationModule
from pykeen.typing import MappedTriples

__all__ = [
"CustomRepresentations",
"MockModel",
]


Expand All @@ -32,39 +29,6 @@ def forward(self, indices: Optional[torch.LongTensor] = None) -> torch.FloatTens
return self.x.unsqueeze(dim=0).repeat(n, *(1 for _ in self.shape))


class MockModel(EntityRelationEmbeddingModel):
    """A mock model returning fake scores."""

    def __init__(self, *, triples_factory: CoreTriplesFactory):
        """Initialize the mock model with fixed 50-dimensional embedding specifications.

        :param triples_factory: The factory providing the entity and relation counts.
        """
        super().__init__(
            triples_factory=triples_factory,
            entity_representations=EmbeddingSpecification(embedding_dim=50),
            relation_representations=EmbeddingSpecification(embedding_dim=50),
        )
        num_entities = self.num_entities
        # Fixed per-entity scores s[i] = i; requires_grad so a training loop can
        # attach gradients to them.
        self.scores = torch.arange(num_entities, dtype=torch.float, requires_grad=True)
        # Counter for tests to check how often backward() ran — presumably
        # incremented by a training-loop hook elsewhere; not visible here.
        self.num_backward_propagations = 0

    def _generate_fake_scores(self, batch: torch.LongTensor) -> torch.FloatTensor:
        """Generate fake scores s[b, i] = i of size (batch_size, num_entities)."""
        batch_size = batch.shape[0]
        # Broadcast the 1-D score vector into one identical row per batch element.
        batch_scores = self.scores.view(1, -1).repeat(batch_size, 1)
        assert batch_scores.shape == (batch_size, self.num_entities)
        return batch_scores

    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # One random entity score per triple; result shape is hrt_batch.shape[:-1],
        # i.e. (batch_size,) — NOTE(review): score_hrt conventionally returns
        # (batch_size, 1); confirm against the Model base-class contract.
        return self.scores[torch.randint(high=self.num_entities, size=hrt_batch.shape[:-1])]

    def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return self._generate_fake_scores(batch=hr_batch)

    def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return self._generate_fake_scores(batch=rt_batch)

    def reset_parameters_(self) -> Model:  # noqa: D102
        pass  # Not needed for unittest


class MockEvaluator(Evaluator):
"""A mock evaluator for testing early stopping."""

Expand Down
4 changes: 2 additions & 2 deletions tests/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

from pykeen.datasets import Nations
from pykeen.evaluation import RankBasedEvaluator
from pykeen.models import Model, TransE
from pykeen.models import MockModel, Model, TransE
from pykeen.stoppers.early_stopping import EarlyStopper, is_improvement
from pykeen.trackers import MLFlowResultTracker
from pykeen.training import SLCWATrainingLoop
from tests.mocks import MockEvaluator, MockModel
from tests.mocks import MockEvaluator

try:
import mlflow
Expand Down
3 changes: 1 addition & 2 deletions tests/test_evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@
resolve_metric_name,
)
from pykeen.evaluation.sklearn import SklearnEvaluator, SklearnMetricResults
from pykeen.models import Model, TransE
from pykeen.models import MockModel, Model, TransE
from pykeen.triples import TriplesFactory
from pykeen.typing import MappedTriples
from tests.mocks import MockModel

logger = logging.getLogger(__name__)

Expand Down
8 changes: 7 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
EntityRelationEmbeddingModel,
ERModel,
EvaluationOnlyModel,
MockModel,
Model,
_NewAbstractModel,
_OldAbstractModel,
Expand All @@ -31,7 +32,6 @@
from pykeen.utils import all_in_bounds, clamp_norm, extend_batch
from tests import cases
from tests.constants import EPSILON
from tests.mocks import MockModel
from tests.test_model_mode import SimpleInteractionModel

SKIP_MODULES = {
Expand All @@ -50,6 +50,12 @@
SKIP_MODULES.update(EvaluationOnlyModel.__subclasses__())


class TestMock(cases.ModelTestCase):
"""Test the mock model."""

cls = pykeen.models.MockModel


class TestCompGCN(cases.ModelTestCase):
"""Test the CompGCN model."""

Expand Down
3 changes: 1 addition & 2 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import pykeen.regularizers
from pykeen.datasets import EagerDataset, Nations
from pykeen.models import ERModel, Model
from pykeen.models import ERModel, MockModel, Model
from pykeen.models.predict import (
get_all_prediction_df,
get_head_prediction_df,
Expand All @@ -28,7 +28,6 @@
from pykeen.training import SLCWATrainingLoop
from pykeen.triples.generation import generate_triples_factory
from pykeen.utils import resolve_device
from tests.mocks import MockModel


class TestPipeline(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_training/test_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

from pykeen.datasets import Nations
from pykeen.models import Model
from pykeen.models.mocks import MockModel
from pykeen.stoppers.early_stopping import EarlyStopper
from pykeen.training import SLCWATrainingLoop
from pykeen.triples import TriplesFactory
from pykeen.typing import MappedTriples
from tests.mocks import MockEvaluator, MockModel
from tests.mocks import MockEvaluator


class DummyTrainingLoop(SLCWATrainingLoop):
Expand Down