馃 馃 Expose interactions and representations via pipeline #163

Closed
67 changes: 57 additions & 10 deletions src/pykeen/nn/modules.py
@@ -6,14 +6,14 @@
import logging
import math
from abc import ABC
from typing import Any, Callable, Generic, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Collection, Generic, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Type, Union

import torch
from torch import FloatTensor, nn

from . import functional as pkf
from ..typing import HeadRepresentation, RelationRepresentation, Representation, TailRepresentation
from ..utils import check_shapes
from ..utils import check_shapes, get_cls, get_subclasses, normalize_string, upgrade_to_sequence

__all__ = [
# Base Classes
@@ -44,12 +44,8 @@
logger = logging.getLogger(__name__)


def _upgrade_to_sequence(x: Union[FloatTensor, Sequence[FloatTensor]]) -> Sequence[FloatTensor]:
return x if isinstance(x, Sequence) else (x,)


def _ensure_tuple(*x: Union[Representation, Sequence[Representation]]) -> Sequence[Sequence[Representation]]:
return tuple(_upgrade_to_sequence(xx) for xx in x)
return tuple(upgrade_to_sequence(xx) for xx in x)


def _unpack_singletons(*xs: Tuple) -> Sequence[Tuple]:
@@ -278,9 +274,9 @@ def _score(
slice_dim: Optional[str] = slice_dims[0] if len(slice_dims) == 1 else None

# FIXME typing does not work well for this
h = _upgrade_to_sequence(h)
r = _upgrade_to_sequence(r)
t = _upgrade_to_sequence(t)
h = upgrade_to_sequence(h)
r = upgrade_to_sequence(r)
t = upgrade_to_sequence(t)
assert self._check_shapes(h=h, r=r, t=t, h_prefix=h_prefix, r_prefix=r_prefix, t_prefix=t_prefix)

# prepare input to generic score function: bh*, br*, bt*
@@ -1097,3 +1093,54 @@ def _prepare_hrt_for_functional(
t: TailRepresentation,
) -> MutableMapping[str, torch.FloatTensor]: # noqa: D102
return dict(h=h, w_r=r[0], d_r=r[1], t=t)


_INTERACTIONS: Collection[Type[Interaction]] = set(get_subclasses(cls=Interaction)).difference({
# Abstract class
TranslationalInteraction,
})

#: A mapping of interaction names to their implementations
interactions: Mapping[str, Type[Interaction]] = {
normalize_string(cls.__name__): cls
for cls in _INTERACTIONS
}


def get_interaction_cls(query: Union[str, Type[Interaction]]) -> Type[Interaction]:
"""Look up an interaction class by name (case/punctuation insensitive) in :data:`pykeen.nn.modules`.

:param query:
The name of the interaction (case/punctuation insensitive), or the interaction class itself.

:return:
The interaction class.
"""
return get_cls(
query,
base=Interaction, # type: ignore
lookup_dict=interactions,
)


def resolve_interaction(
interaction: Union[str, Type[Interaction], Interaction],
interaction_kwargs: Optional[Mapping[str, Any]],
) -> Interaction:
"""
Resolve an interaction.

:param interaction:
The interaction. Can be an already instantiated interaction, an interaction class, or the class name as a string.
:param interaction_kwargs:
Keyword-based arguments passed to the constructor. Ignored if an already instantiated interaction is passed.

:return:
An interaction instance.
"""
# already instantiated
if isinstance(interaction, Interaction):
return interaction
if isinstance(interaction, str):
interaction = get_interaction_cls(interaction)
return interaction(**(interaction_kwargs or dict()))
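
A minimal usage sketch of the new lookup helpers (editor's illustration, not part of the diff). It assumes a concrete subclass such as DistMultInteraction is defined in pykeen.nn.modules, and is therefore collected into the interactions mapping, and that it can be instantiated without arguments.

from pykeen.nn.modules import get_interaction_cls, resolve_interaction

# Case/punctuation-insensitive lookup by class name (assumed to be registered).
interaction_cls = get_interaction_cls("DistMultInteraction")

# resolve_interaction accepts a name, a class, or an existing instance;
# interaction_kwargs is only used when the interaction still has to be instantiated.
interaction = resolve_interaction("DistMultInteraction", interaction_kwargs=None)
assert isinstance(interaction, interaction_cls)

# An already instantiated interaction is passed through unchanged.
assert resolve_interaction(interaction, interaction_kwargs=None) is interaction
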
106 changes: 92 additions & 14 deletions src/pykeen/pipeline.py
@@ -170,7 +170,7 @@
import os
import time
from dataclasses import dataclass, field
from typing import Any, Collection, Dict, Iterable, List, Mapping, Optional, Set, Type, Union
from typing import Any, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Type, Union

import pandas as pd
import torch
@@ -180,19 +180,21 @@
from .datasets.base import DataSet
from .evaluation import Evaluator, MetricResults, get_evaluator_cls
from .losses import Loss, _LOSS_SUFFIX, get_loss_cls
from .models import get_model_cls
from .models import ERModel
from .models.base import Model
from .nn import EmbeddingSpecification, Interaction, RepresentationModule
from .nn.modules import resolve_interaction
from .optimizers import get_optimizer_cls
from .regularizers import Regularizer
from .sampling import NegativeSampler, get_negative_sampler_cls
from .stoppers import EarlyStopper, Stopper, get_stopper_cls
from .trackers import ResultTracker, get_result_tracker_cls
from .training import SLCWATrainingLoop, TrainingLoop, get_training_loop_cls
from .triples import TriplesFactory
from .typing import OneOrMany
from .utils import (
Result, ensure_ftp_directory, fix_dataclass_init_docs, get_json_bytes_io, get_model_io, normalize_string,
random_non_negative_int, resolve_device, set_random_seed,
)
random_non_negative_int, resolve_device, set_random_seed, )
from .version import get_git_hash, get_version

__all__ = [
@@ -699,6 +701,49 @@ def pipeline_from_config(
)


def _resolve_representations(
num_representations: int,
shape: Sequence[str],
representations: OneOrMany[Union[None, RepresentationModule, EmbeddingSpecification]],
dimensions: Mapping[str, int],
) -> Sequence[RepresentationModule]:
if not isinstance(representations, Sequence):
if isinstance(representations, RepresentationModule):
raise ValueError
representations = [representations] * len(shape)

dimensions = dict(**dimensions)
result = []
for symbolic_shape, representation in zip(shape, representations):
if representation is None:
representation = EmbeddingSpecification()
if isinstance(representation, RepresentationModule):
if representation.max_id < num_representations:
raise ValueError
elif representation.max_id > num_representations:
logger.warning(
f"{representation} does provide representations for more than the requested {num_representations}."
f"While this is not necessarily an error, be aware that these representations will not be trained."
)
for name, size in zip(symbolic_shape, representation.shape):
expected_dimension = dimensions.get(name)
if expected_dimension is not None and expected_dimension != size:
raise ValueError
dimensions[name] = size
elif isinstance(representation, EmbeddingSpecification):
actual_shape = tuple([dimensions[name] for name in symbolic_shape])
representation = representation.make(
num_embeddings=num_representations,
embedding_dim=None,
shape=actual_shape,
)
else:
raise AssertionError
result.append(representation)

return result


def pipeline( # noqa: C901
*,
# 1. Dataset
@@ -709,6 +754,12 @@
validation: Union[None, TriplesFactory, str] = None,
evaluation_entity_whitelist: Optional[Collection[str]] = None,
evaluation_relation_whitelist: Optional[Collection[str]] = None,
# Interaction
interaction: Union[str, Type[Interaction], Interaction],
interaction_kwargs: Optional[Mapping[str, Any]] = None,
# Representations
entity_representations: OneOrMany[Union[None, RepresentationModule, EmbeddingSpecification]],
relation_representations: OneOrMany[Union[None, RepresentationModule, EmbeddingSpecification]],
# 2. Model
model: Union[str, Type[Model]],
model_kwargs: Optional[Mapping[str, Any]] = None,
@@ -865,11 +916,40 @@ def pipeline( # noqa: C901
relations=evaluation_relation_whitelist,
)

# Resolve interaction
interaction = resolve_interaction(
interaction=interaction,
interaction_kwargs=interaction_kwargs,
)

# Given an interaction, we know which representations are needed
# interaction.entity_shape, interaction.relation_shape
entity_representations = _resolve_representations(
num_representations=training.num_entities,
shape=interaction.entity_shape,
representations=entity_representations,
dimensions=dict(d=...), # TODO
)
relation_representations = _resolve_representations(
num_representations=training.num_relations,
shape=interaction.relation_shape,
representations=relation_representations,
dimensions=dict(), # TODO
)

if model_kwargs is None:
model_kwargs = {}
model_kwargs.update(preferred_device=device)
model_kwargs.setdefault('random_seed', random_seed)

if loss is not None:
if 'loss' in model_kwargs: # FIXME
logger.warning('duplicate loss in kwargs and model_kwargs. removing from model_kwargs')
del model_kwargs['loss']
loss_cls = get_loss_cls(loss)
_loss = loss_cls(**(loss_kwargs or {}))
model_kwargs.setdefault('loss', _loss)

if regularizer is not None:
logger.warning('Specification of the regularizer from the pipeline() is currently under maintenance')
# FIXME this should never happen.
@@ -879,19 +959,17 @@
# regularizer_cls: Type[Regularizer] = get_regularizer_cls(regularizer)
# model_kwargs['regularizer'] = regularizer_cls(**(regularizer_kwargs or {}))

if loss is not None:
if 'loss' in model_kwargs: # FIXME
logger.warning('duplicate loss in kwargs and model_kwargs. removing from model_kwargs')
del model_kwargs['loss']
loss_cls = get_loss_cls(loss)
_loss = loss_cls(**(loss_kwargs or {}))
model_kwargs.setdefault('loss', _loss)

model = get_model_cls(model)
model_instance: Model = model(
# Compose model
# TODO: What about custom models that do not subclass ERModel?
model_instance = ERModel(
triples_factory=training,
interaction=interaction,
loss=loss,
entity_representations=entity_representations,
relation_representations=relation_representations,
**model_kwargs,
)

# Log model parameters
result_tracker.log_params(params=dict(cls=model.__name__, kwargs=model_kwargs), prefix='model')

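
For context, a sketch of the call pattern the reworked pipeline() is aiming for (editor's illustration, not part of the diff). The dataset name "Nations", the interaction name "DistMultInteraction", and the embedding size are assumptions for illustration; since the dimensions mapping in _resolve_representations is still a TODO on this branch, treat this as the intended usage rather than something guaranteed to run as-is.

from pykeen.models import DistMult
from pykeen.nn import EmbeddingSpecification
from pykeen.pipeline import pipeline

result = pipeline(
    dataset="Nations",
    # New: the interaction is resolved by name, class, or instance ...
    interaction="DistMultInteraction",
    # ... and representations are given as EmbeddingSpecification instances,
    # pre-built RepresentationModule instances, or None (one per shape entry,
    # or a single value that is used for all entries).
    entity_representations=EmbeddingSpecification(embedding_dim=50),
    relation_representations=EmbeddingSpecification(embedding_dim=50),
    # Still required by the current signature; on this branch it is only used for logging.
    model=DistMult,
)
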
4 changes: 4 additions & 0 deletions src/pykeen/typing.py
@@ -23,6 +23,10 @@
'GaussianDistribution',
]

# utility type
X = TypeVar("X")
OneOrMany = Union[X, Sequence[X]]

LabeledTriples = np.ndarray
MappedTriples = torch.LongTensor
EntityMapping = Mapping[str, int]
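
A quick illustration of the new alias (editor's sketch, not part of the diff): a parameter annotated with OneOrMany[X] accepts either a single X or a sequence of them, mirroring how the new pipeline() representation arguments are typed.

from typing import Sequence

from pykeen.typing import OneOrMany

def as_list(x: OneOrMany[int]) -> Sequence[int]:
    # Accept either a bare int or any sequence of ints.
    return [x] if isinstance(x, int) else list(x)

assert as_list(3) == [3]
assert as_list((1, 2)) == [1, 2]
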
15 changes: 14 additions & 1 deletion src/pykeen/utils.py
@@ -8,7 +8,7 @@
import random
from abc import ABC, abstractmethod
from io import BytesIO
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union

import numpy as np
import pandas as pd
@@ -621,3 +621,16 @@ def pop_only(elements: Iterable[X]) -> X:
def strip_dim(*x):
"""Strip the last dimension."""
return [xx.view(xx.shape[2:]) for xx in x]


def upgrade_to_sequence(x: Union[X, Sequence[X]]) -> Sequence[X]:
"""
Ensure sequence, by wrapping a non-sequence into a 1-element tuple.

:param x:
The input.

:return:
A sequence.
"""
return x if isinstance(x, Sequence) else (x,)
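
A short behavioral sketch of the now-public helper (editor's illustration, not part of the diff); it replaces the private _upgrade_to_sequence removed from nn/modules.py above.

from pykeen.utils import upgrade_to_sequence

# A single (non-sequence) value is wrapped into a 1-tuple ...
assert upgrade_to_sequence(1.0) == (1.0,)
# ... while existing sequences (tuples, lists, even strings) are returned unchanged.
assert upgrade_to_sequence((1.0, 2.0)) == (1.0, 2.0)
assert upgrade_to_sequence("abc") == "abc"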