diff --git a/benchmarking/interactions.py b/benchmarking/interactions.py new file mode 100644 index 0000000000..b3e8b24013 --- /dev/null +++ b/benchmarking/interactions.py @@ -0,0 +1,213 @@ +import random +from typing import Mapping, Optional, Tuple + +import click +import numpy +import pandas +import torch +from torch.utils.benchmark import Timer +from tqdm import tqdm + +from pykeen.nn import Interaction +from pykeen.nn.compute_kernel import _complex_native_complex, _complex_direct, _complex_broadcast_optimized, _complex_select, _complex_stacked, _complex_stacked_select +from pykeen.typing import HeadRepresentation, RelationRepresentation, TailRepresentation +from pykeen.utils import unpack_singletons +from pykeen.version import get_git_hash + + +def _use_case_to_shape( + use_case: str, + b: int, + n: int, + s: int, +) -> Tuple[ + Tuple[int, int], + Tuple[int, int], + Tuple[int, int], +]: + """ + Generate prefix shapes for various use cases. + + :param use_case: + The use case. + + - "hrt": score_hrt naive + - "hrt+": score_hrt fast SCLWA with tail corruption + - "h+rt": score_hrt fast SCLWA with head corruption + - "t": score_t + - "h": score_t + + :param b: + The batch size. + :param n: + The number of entities. + :param s: + The number of negative samples. + + :return: + A 3-tuple, (head_prefix, relation_prefix, tail_prefix), each a 2-tuple of integers. + """ + if use_case == "hrt": + b = b * s + return (b, 1), (b, 1), (b, 1) + elif use_case == "hrt+": + return (b, 1), (b, 1), (b, s) + elif use_case == "h+rt": + return (b, s), (b, 1), (b, 1) + elif use_case == "t": + return (b, 1), (b, 1), (1, n) + elif use_case == "h": + return (1, n), (b, 1), (b, 1) + else: + raise ValueError + + +def _resolve_shapes( + prefix_shapes: Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]], + interaction: Interaction[HeadRepresentation, RelationRepresentation, TailRepresentation], + dim: int, + additional_dims: Optional[Mapping[str, int]] = None, +) -> Tuple[Tuple[Tuple[int, ...], ...], ...]: + additional_dims = additional_dims or dict() + additional_dims.setdefault("d", dim) + return [ + tuple((*prefix_shape, *(additional_dims[s] for ss in s)) for s in suffix_shape) + for prefix_shape, suffix_shape in zip( + prefix_shapes, + ( + interaction.entity_shape, + interaction.relation_shape, + interaction.tail_entity_shape or interaction.entity_shape + ) + ) + ] + + +def _generate_hrt( + shapes: Tuple[Tuple[Tuple[int, ...], ...], ...], + device: torch.device, +) -> Tuple[HeadRepresentation, RelationRepresentation, TailRepresentation]: + return unpack_singletons(*( + [ + torch.rand(*shape, requires_grad=True, device=device) + for shape in single_shapes + ] + for single_shapes in shapes + )) + + +def _get_result_shape(prefix_shapes) -> Tuple[int, int, int, int]: + return (max(s[0] for s in prefix_shapes),) + tuple([s[1] for s in prefix_shapes]) + + +def _get_memory(interaction, shapes, device) -> int: + torch.cuda.reset_accumulated_memory_stats() + torch.cuda.reset_peak_memory_stats() + h, r, t = _generate_hrt(shapes=shapes, device=device) + interaction(h=h, r=r, t=t) + stats = torch.cuda.memory_stats() + return stats["active_bytes.all.peak"] + + +@click.command() +@click.option('--fast/--no-fast', default=False) +@click.option('--shuffle/--no-shuffle', default=False) +@click.option('-m', '--max-result-elements-power', type=int, default=30, show_default=True) +@click.option('-n', '--max-num-entities-power', type=int, default=15, show_default=True) +@click.option('-b', '--max-batch-size-power', type=int, 
default=10, show_default=True) +@click.option('-d', '--max-vector-dimension-power', type=int, default=10, show_default=True) +@click.option('-s', '--max-sample-power', type=int, default=10, show_default=True) +def main( + fast: bool, + shuffle: bool, + max_result_elements_power: int, + max_num_entities_power: int, + max_batch_size_power: int, + max_vector_dimension_power: int, + max_sample_power: int, +): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + print(f"Running on {device}.") + variants = [ + _complex_select, + _complex_native_complex, + _complex_broadcast_optimized, + _complex_direct, + _complex_stacked, + _complex_stacked_select, + ] + use_case_labels = ["hrt", "hrt+", "h+rt", "t", "h"] + batch_sizes = [2 ** i for i in range(5, max_batch_size_power + 1)] + negative_samples = [2 ** i for i in range(5, max_sample_power + 1)] + num_entities = [2 ** i for i in range(7, max_num_entities_power)] + max_result_elements = 2 ** max_result_elements_power + vector_dimensions = [2 ** i for i in range(5, max_vector_dimension_power + 1)] + data = [] + tasks = [ + (v, b, s, n, d, ul, _use_case_to_shape(use_case=ul, b=b, n=n, s=s)) + for v in variants + for b in batch_sizes + for s in negative_samples + for n in num_entities + for d in vector_dimensions + for ul in use_case_labels + ] + if shuffle: + random.shuffle(tasks) + if fast: + tasks = tasks[:5] + progress = tqdm(tasks, unit="task") + for i, config in enumerate(progress, start=1): + v, b, s, n, d, ul, prefix_shapes = config + interaction = Interaction.from_func(v) + result_shape = _get_result_shape(prefix_shapes) + max_memory = median = iqr = float('nan') + if max_result_elements is not None and max_result_elements < numpy.prod(result_shape): + continue + shapes = _resolve_shapes( + prefix_shapes=prefix_shapes, + interaction=interaction, + dim=d, + ) + try: + timer = Timer( + stmt="interaction(h=h, r=r, t=t)", + globals=dict(interaction=interaction, shapes=shapes, device=device, _generate_hrt=_generate_hrt), + setup="h, r, t = _generate_hrt(shapes=shapes, device=device)" + ) + time = timer.blocked_autorange() + median = time.median + iqr = time.iqr + max_memory = _get_memory(interaction, shapes, device) + + except Exception as error: + progress.write(f"ERROR: {error} for {v}:{config}") + progress.set_postfix(dict(s=prefix_shapes, t=median, mem=max_memory)) + data.append((i, b, s, n, d, ul, prefix_shapes, v.__name__, median, iqr, max_memory)) + + git_hash = get_git_hash() + df = pandas.DataFrame(data=data, columns=[ + "experiment_number", + "batch_size", + "num_negative_samples", + "num_entities", + "dimension", + "use_case", + "prefix_shapes", + "variant", + "time_median", + "time_inter_quartile_range", + "max_memory", + ]) + df["device"] = device.type + df.to_csv(f"{git_hash}_measurement.tsv", sep="\t", index=False) + + df_agg = df.groupby( + by=["batch_size", "num_entities", "dimension", "use_case", "variant"] + ).agg({"time_median": "mean"}).unstack().reset_index()#.dropna() + df_agg.to_csv(f"{git_hash}_measurement_agg.tsv", sep="\t", index=False) + print(df_agg) + + +if __name__ == '__main__': + main() diff --git a/benchmarking/utils.py b/benchmarking/utils.py new file mode 100644 index 0000000000..8d00eb9fb9 --- /dev/null +++ b/benchmarking/utils.py @@ -0,0 +1,132 @@ +"""Benchmark utility methods.""" +import functools +import itertools +import operator +import timeit +from typing import Sequence, Tuple + +import click +import pandas +import torch +from tqdm.auto import tqdm + +from pykeen.utils 
import tensor_sum + + +def _generate_shapes( + batch_size: int, + num: int, + dim: int, + use_case: str, + shapes: Sequence[str], +) -> Sequence[Tuple[int, ...]]: + dims = dict(b=batch_size, h=num, r=num, t=num, d=dim) + canonical = "bhrtd" + return [ + tuple( + dims[c] if ( + c in use_case and c in shape + ) else 1 + for c in canonical + ) + for shape in shapes + ] + + +def _generate_tensors(*shapes: Tuple[int, ...]) -> Sequence[torch.FloatTensor]: + return [ + torch.rand(shape, requires_grad=True, dtype=torch.float32) + for shape in shapes + ] + + +def tqdm_itertools_product(*args, **kwargs): + return tqdm(itertools.product(*args), **kwargs, total=functools.reduce(operator.mul, map(len, args), 1)) + + +def _get_result_shape(shapes: Sequence[Tuple[int, ...]]) -> Tuple[int, ...]: + return tuple(max(ds) for ds in zip(*shapes)) + + +@click.command() +@click.option('-m', '--max-result-power', type=int, default=30, show_default=True) +@click.option('-b', '--max-batch-size-power', type=int, default=10, show_default=True) +@click.option('-n', '--max-num-power', type=int, default=14, show_default=True) +@click.option('-d', '--max-dim-power', type=int, default=10, show_default=True) +def main( + max_result_power: int, + max_batch_size_power: int, + max_dim_power: int, + max_num_power: int, +): + """Test whether tensor_sum actually offers any performance benefits.""" + max_size = 2 ** max_result_power + data = [] + progress = tqdm_itertools_product( + [2 ** i for i in range(5, max_batch_size_power + 1)], + [2 ** i for i in range(10, max_num_power + 1)], # 2**15 ~ 15k + [2 ** i for i in range(5, max_dim_power + 1)], + ["b", "bh", "br", "bt"], # score_hrt, score_h, score_t, score_t + ( + ("ConvKB/ERMLP", "", "bh", "br", "bt"), # conv_bias, h, r, t + ("NTN", "bhrt", "bhr", "bht", "br"), # h w t, vh h, vt t, b + ("ProjE", "bh", "br", ""), # h, r, b_c + # ("RotatE", "bhr", "bt"), # hr, -t, + # ("RotatE-inv", "bh", "brt"), # h, -(r_inv)t, + ("StructuredEmbedding", "bhr", "brt"), # r h, r t + ("TransE/TransD/KG2E", "bh", "br", "bt"), # h, r, t + ("TransH", "bh", "bhr", "br", "bt", "brt"), # h, - w_r, d_r, -t,, w_r + ("TransR", "bhr", "br", "brt"), # h m_r, r, -t m_r + # ("UnstructuredModel", "bh", "bt"), # h, r + ), + unit="configuration", + unit_scale=True, + ) + for batch_size, num, dim, use_case, (models, *shapes) in progress: + this_shapes = _generate_shapes( + batch_size=batch_size, + num=num, + dim=dim, + use_case=use_case, + shapes=shapes, + ) + result_shape = _get_result_shape(this_shapes) + num_result_elements = functools.reduce(operator.mul, result_shape, 1) + if num_result_elements > max_size: + continue + tensors = _generate_tensors(*this_shapes) + + # using normal sum + n_samples, time_baseline = timeit.Timer( + stmt="sum(tensors)", + globals=dict(tensors=tensors) + ).autorange() + time_baseline /= n_samples + + # use tensor_sum + n_samples, time = timeit.Timer( + setup="tensor_sum(*tensors)", + stmt="tensor_sum(*tensors)", + globals=dict(tensor_sum=tensor_sum, tensors=tensors) + ).autorange() + time /= n_samples + + data.append((batch_size, num, dim, use_case, shapes, time_baseline, time)) + progress.set_postfix(shape=result_shape, delta=time_baseline - time) + df = pandas.DataFrame(data=data, columns=[ + "batch_size", + "num", + "dim", + "use_case", + "shapes", + "time_sum", + "time_tensor_sum", + ]) + df.to_csv("tensor_sum.perf.tsv", sep="\t", index=False) + print("tensor_sum is better than sum in {percent:2.2%} of all cases.".format( + percent=(df["time_sum"] > 
df["time_tensor_sum"]).mean()) + ) + + +if __name__ == '__main__': + main() diff --git a/docs/source/extending/extending_interactors.rst b/docs/source/extending/extending_interactors.rst new file mode 100644 index 0000000000..4dce2d909f --- /dev/null +++ b/docs/source/extending/extending_interactors.rst @@ -0,0 +1,25 @@ +Extending the Interaction Functions +=================================== +In [ali2020]_, we argued that a knowledge graph embedding model (KGEM) consists of +several components: an interaction function, a loss function, a training approach, etc. + +Let's assume you have invented a new interaction model, +e.g. this variant of :class:`pykeen.models.DistMult` + +.. math:: + + f(h, r, t) = + +where :math:`h,r,t \in \mathbb{R}^d`, and :math:`\sigma` denotes the logistic sigmoid. + +.. code-block:: python + + from pykeen.nn import Interaction + + class ModifiedDistMultInteraction(Interaction): + def forward(self, h, r, t): + return h * r.sigmoid() * t + + +.. [ali2020] Mehdi, A., *et al.* (2020) `PyKEEN 1.0: A Python Library for Training and + Evaluating Knowledge Graph Embeddings `_ *arXiv*, 2007.14175. diff --git a/docs/source/extending/index.rst b/docs/source/extending/index.rst index 3f93f6814f..f8cf0bfef2 100644 --- a/docs/source/extending/index.rst +++ b/docs/source/extending/index.rst @@ -4,3 +4,4 @@ Extending PyKEEN :name: extending extending_models + extending_interactors diff --git a/docs/source/index.rst b/docs/source/index.rst index 157ee02afb..9445268586 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -26,6 +26,7 @@ PyKEEN :maxdepth: 2 reference/pipeline + reference/interactions reference/models reference/datasets reference/triples diff --git a/docs/source/reference/interactions.rst b/docs/source/reference/interactions.rst new file mode 100644 index 0000000000..d051f7e473 --- /dev/null +++ b/docs/source/reference/interactions.rst @@ -0,0 +1,13 @@ +Interactions +============ +Functional Interface +~~~~~~~~~~~~~~~~~~~~ +.. automodapi:: pykeen.nn.functional + :no-heading: + :headings: -- + +Module Interface +~~~~~~~~~~~~~~~~ +.. automodapi:: pykeen.nn.modules + :no-heading: + :headings: -- diff --git a/docs/source/reference/models.rst b/docs/source/reference/models.rst index f84366148f..73799cb6b4 100644 --- a/docs/source/reference/models.rst +++ b/docs/source/reference/models.rst @@ -20,3 +20,4 @@ Extra Modules ------------- .. 
automodule:: pykeen.nn :members: + :exclude-members: Interaction, StatelessInteraction diff --git a/src/pykeen/cli.py b/src/pykeen/cli.py index d1f58fb36f..8e8a17ba72 100644 --- a/src/pykeen/cli.py +++ b/src/pykeen/cli.py @@ -29,8 +29,8 @@ from .hpo.cli import optimize from .hpo.samplers import samplers as hpo_samplers_dict from .losses import losses as losses_dict -from .models import models as models_dict -from .models.base import EntityEmbeddingModel, EntityRelationEmbeddingModel, Model +from .models import ERModel, models as models_dict +from .models.base import Model from .models.cli import build_cli_from_cls from .optimizers import optimizers as optimizers_dict from .regularizers import regularizers as regularizers_dict @@ -100,8 +100,7 @@ def parameters(): base_parameters = set(chain( Model.__init__.__annotations__, - EntityEmbeddingModel.__init__.__annotations__, - EntityRelationEmbeddingModel.__init__.__annotations__, + ERModel.__init__.__annotations__, )) _hyperparameter_usage = sorted( (k, v) diff --git a/src/pykeen/evaluation/evaluator.py b/src/pykeen/evaluation/evaluator.py index d843a37fe8..b1c01a77e6 100644 --- a/src/pykeen/evaluation/evaluator.py +++ b/src/pykeen/evaluation/evaluator.py @@ -260,7 +260,6 @@ def _param_size_search( values_dict[key] = start_value values_dict['slice_size'] = None elif key == 'slice_size': - self._check_slicing_availability(model, batch_size=1) values_dict[key] = start_value values_dict['batch_size'] = 1 else: @@ -322,19 +321,6 @@ def _param_size_search( return values_dict[key], evaluated_once - @staticmethod - def _check_slicing_availability(model: Model, batch_size: int) -> None: - # Test if slicing is implemented for the required functions of this model - if model.triples_factory.create_inverse_triples: - if not model.can_slice_t: - raise MemoryError(f"The current model can't be evaluated on this hardware with these parameters, as " - f"evaluation batch_size={batch_size} is too big and slicing is not implemented for " - f"this model yet.") - elif not model.can_slice_t or not model.can_slice_h: - raise MemoryError(f"The current model can't be evaluated on this hardware with these parameters, as " - f"evaluation batch_size={batch_size} is too big and slicing is not implemented for this " - f"model yet.") - def create_sparse_positive_filter_( hrt_batch: MappedTriples, diff --git a/src/pykeen/experiments/inverse_stability.py b/src/pykeen/experiments/inverse_stability.py deleted file mode 100644 index 61194e033f..0000000000 --- a/src/pykeen/experiments/inverse_stability.py +++ /dev/null @@ -1,129 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Inverse Stability Workflow. 
- -This experiment investigates the differences between - -""" - -import itertools as itt -import logging -from typing import Type - -import click -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns - -import pykeen.evaluation.evaluator -from pykeen.constants import PYKEEN_EXPERIMENTS -from pykeen.datasets import Dataset, get_dataset -from pykeen.models import Model, get_model_cls -from pykeen.pipeline import pipeline - -INVERSE_STABILITY = PYKEEN_EXPERIMENTS / 'inverse_stability' -INVERSE_STABILITY.mkdir(parents=True, exist_ok=True) - -pykeen.evaluation.evaluator.logger.setLevel(logging.CRITICAL) - - -@click.command() -@click.option('--force', is_flag=True) -@click.option('--clip', type=int, default=10) -def main(force: bool, clip: int): - """Run the inverse stability experiments.""" - results_path = INVERSE_STABILITY / 'results.tsv' - if results_path.exists() and not force: - df = pd.read_csv(results_path, sep='\t') - df['residuals'] = df['forward'] - df['inverse'] - df = df[(-clip < df['residuals']) & (df['residuals'] < clip)] - g = sns.FacetGrid(df, col='model', row='dataset', hue='training_loop', sharex=False, sharey=False) - g.map_dataframe(sns.histplot, x='residuals', stat="density") - g.add_legend() - g.savefig(INVERSE_STABILITY / 'results_residuals.png', dpi=300) - - else: - outer_dfs = [] - datasets = ['nations', 'kinships'] - models = ['rotate', 'complex', 'simple', 'transe', 'distmult'] - training_loops = ['lcwa', 'slcwa'] - for dataset, model, training_loop in itt.product(datasets, models, training_loops): - click.secho(f'{dataset} {model} {training_loop}', fg='cyan') - df = run_inverse_stability_workflow(dataset=dataset, model=model, training_loop=training_loop) - outer_dfs.append(df) - outer_df = pd.concat(outer_dfs) - outer_df.to_csv(INVERSE_STABILITY / 'results.tsv', sep='\t', index=False) - - -def run_inverse_stability_workflow(dataset: str, model: str, training_loop: str, random_seed=0, device='cpu'): - """Run an inverse stability experiment.""" - dataset: Dataset = get_dataset( - dataset=dataset, - dataset_kwargs=dict( - create_inverse_triples=True, - ), - ) - dataset_name = dataset.get_normalized_name() - model_cls: Type[Model] = get_model_cls(model) - model_name = model_cls.__name__.lower() - - dataset_dir = INVERSE_STABILITY / dataset_name - dataset_dir.mkdir(exist_ok=True, parents=True) - - pipeline_result = pipeline( - dataset=dataset, - model=model, - training_loop=training_loop, - training_kwargs=dict( - num_epochs=1000, - use_tqdm_batch=False, - ), - stopper='early', - stopper_kwargs=dict(patience=5, frequency=5), - random_seed=random_seed, - device=device, - ) - test_tf = dataset.testing - model = pipeline_result.model - # Score with original triples - scores_forward = model.score_hrt(test_tf.mapped_triples) - scores_forward_np = scores_forward.detach().numpy()[:, 0] - - # Score with inverse triples - scores_inverse = model.score_hrt_inverse(test_tf.mapped_triples) - scores_inverse_np = scores_inverse.detach().numpy()[:, 0] - - scores_path = dataset_dir / f'{model_name}_{training_loop}_scores.tsv' - df = pd.DataFrame( - list(zip( - itt.repeat(training_loop), - itt.repeat(dataset_name), - itt.repeat(model_name), - scores_forward_np, - scores_inverse_np, - )), - columns=['training_loop', 'dataset', 'model', 'forward', 'inverse'], - ) - df.to_csv(scores_path, sep='\t', index=False) - - fig, ax = plt.subplots(1, 1) - sns.histplot(data=df, x='forward', label='Forward', ax=ax, color='blue', stat="density") - sns.histplot(data=df, x='inverse', 
label='Inverse', ax=ax, color='orange', stat="density") - ax.set_title(f'{dataset_name} - {model_name} - {training_loop}') - ax.set_xlabel('Score') - plt.legend() - plt.savefig(dataset_dir / f'{model_name}_{training_loop}_overlay.png', dpi=300) - plt.close(fig) - - fig, ax = plt.subplots(1, 1) - sns.histplot(scores_forward_np - scores_inverse_np, ax=ax, stat="density") - ax.set_title(f'{dataset_name} - {model_name} - {training_loop}') - ax.set_xlabel('Forward - Inverse Score Difference') - plt.savefig(dataset_dir / f'{model_name}_{training_loop}_residuals.png', dpi=300) - plt.close(fig) - - return df - - -if __name__ == '__main__': - main() diff --git a/src/pykeen/hpo/hpo.py b/src/pykeen/hpo/hpo.py index f027cf507a..d90ecc116b 100644 --- a/src/pykeen/hpo/hpo.py +++ b/src/pykeen/hpo/hpo.py @@ -26,7 +26,7 @@ from ..models.base import Model from ..optimizers import Optimizer, get_optimizer_cls, optimizers_hpo_defaults from ..pipeline import pipeline, replicate_pipeline_from_config -from ..regularizers import Regularizer, get_regularizer_cls +from ..regularizers import Regularizer from ..sampling import NegativeSampler, get_negative_sampler_cls from ..stoppers import EarlyStopper, Stopper, get_stopper_cls from ..trackers import ResultTracker, get_result_tracker_cls @@ -57,7 +57,7 @@ class Objective: dataset: Union[None, str, Type[Dataset]] # 1. model: Type[Model] # 2. loss: Type[Loss] # 3. - regularizer: Type[Regularizer] # 4. + regularizer: Optional[Type[Regularizer]] # 4. optimizer: Type[Optimizer] # 5. training_loop: Type[TrainingLoop] # 6. evaluator: Type[Evaluator] # 8. @@ -149,13 +149,16 @@ def __call__(self, trial: Trial) -> Optional[float]: kwargs_ranges=self.loss_kwargs_ranges, ) # 4. Regularizer - _regularizer_kwargs = _get_kwargs( - trial=trial, - prefix='regularizer', - default_kwargs_ranges=self.regularizer.hpo_default, - kwargs=self.regularizer_kwargs, - kwargs_ranges=self.regularizer_kwargs_ranges, - ) + if self.regularizer is not None: + _regularizer_kwargs = _get_kwargs( + trial=trial, + prefix='regularizer', + default_kwargs_ranges=self.regularizer.hpo_default, + kwargs=self.regularizer_kwargs, + kwargs_ranges=self.regularizer_kwargs_ranges, + ) + else: + _regularizer_kwargs = None # 5. Optimizer _optimizer_kwargs = _get_kwargs( trial=trial, @@ -613,13 +616,15 @@ def hpo_pipeline( study.set_user_attr('loss', normalize_string(loss.__name__, suffix=_LOSS_SUFFIX)) logger.info(f'Using loss: {loss}') # 4. Regularizer - regularizer: Type[Regularizer] = ( - model.regularizer_default - if regularizer is None else - get_regularizer_cls(regularizer) - ) - study.set_user_attr('regularizer', regularizer.get_normalized_name()) - logger.info(f'Using regularizer: {regularizer}') + if regularizer is not None: + logger.warning('Usage of the regularizer with the HPO is currently under maitenance.') + # regularizer: Type[Regularizer] = ( + # model.regularizer_default + # if regularizer is None else + # get_regularizer_cls(regularizer) + # ) + # study.set_user_attr('regularizer', regularizer.get_normalized_name()) + # logger.info(f'Using regularizer: {regularizer}') # 5. 
Optimizer optimizer: Type[Optimizer] = get_optimizer_cls(optimizer) study.set_user_attr('optimizer', normalize_string(optimizer.__name__)) diff --git a/src/pykeen/models/__init__.py b/src/pykeen/models/__init__.py index 0de20d93cf..7189a695c8 100644 --- a/src/pykeen/models/__init__.py +++ b/src/pykeen/models/__init__.py @@ -8,8 +8,8 @@ from typing import Mapping, Set, Type, Union -from .base import EntityEmbeddingModel, EntityRelationEmbeddingModel, Model, MultimodalModel # noqa:F401 -from .multimodal import ComplExLiteral, DistMultLiteral +from .base import ERModel, Model # noqa:F401 +from .multimodal import ComplExLiteral, DistMultLiteral, LiteralModel # noqa:F401 from .unimodal import ( ComplEx, ConvE, diff --git a/src/pykeen/models/base.py b/src/pykeen/models/base.py index 946c62b215..8cafd51125 100644 --- a/src/pykeen/models/base.py +++ b/src/pykeen/models/base.py @@ -3,12 +3,15 @@ """Base module for all KGE models.""" import functools -import inspect import itertools as itt import logging from abc import ABC, abstractmethod from collections import defaultdict -from typing import Any, ClassVar, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Type, Union +from operator import itemgetter +from typing import ( + Any, ClassVar, Collection, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Set, TYPE_CHECKING, Tuple, + Type, Union, +) import numpy as np import pandas as pd @@ -16,17 +19,19 @@ from torch import nn from ..losses import Loss, MarginRankingLoss, NSSALoss -from ..nn import Embedding -from ..regularizers import NoRegularizer, Regularizer +from ..nn import EmbeddingSpecification, RepresentationModule +from ..nn.modules import Interaction +from ..regularizers import Regularizer, collect_regularization_terms from ..triples import TriplesFactory -from ..typing import Constrainer, DeviceHint, Initializer, MappedTriples, Normalizer -from ..utils import NoRandomSeedNecessary, get_batchnorm_modules, resolve_device, set_random_seed +from ..typing import DeviceHint, HeadRepresentation, MappedTriples, RelationRepresentation, TailRepresentation +from ..utils import NoRandomSeedNecessary, check_shapes, get_batchnorm_modules, resolve_device, set_random_seed + +if TYPE_CHECKING: + from ..typing import Representation # noqa __all__ = [ 'Model', - 'EntityEmbeddingModel', - 'EntityRelationEmbeddingModel', - 'MultimodalModel', + 'ERModel', ] logger = logging.getLogger(__name__) @@ -200,7 +205,21 @@ def _new_init(self, *args, **kwargs): class Model(nn.Module, ABC): - """A base module for all of the KGE models.""" + """An abstract class for knowledge graph embedding models (KGEMs). + + The only function that needs to be implemented for a given subclass is + :meth:`Model.forward`. The job of the :meth:`Model.forward` function, as + opposed to the completely general :meth:`torch.nn.Module.forward` is + to take indices for the head, relation, and tails' respective representation(s) + and to determine a score. + + Subclasses of Model can decide however they want on how to store entities' and + relations' representations, how they want to be looked up, and how they should + be scored. The :class:`ERModel` provides a commonly useful implementation + which allows for the specification of one or more entity representations and + one or more relation representations in the form of :class:`pykeen.nn.Embedding` + as well as a matching instance of a :class:`pykeen.nn.Interaction`. 
+ """ #: A dictionary of hyper-parameters to the models that use them _hyperparameter_usage: ClassVar[Dict[str, Set[str]]] = defaultdict(set) @@ -216,14 +235,12 @@ class Model(nn.Module, ABC): #: The default parameters for the default loss function class loss_default_kwargs: ClassVar[Optional[Mapping[str, Any]]] = dict(margin=1.0, reduction='mean') #: The instance of the loss - loss: ClassVar[Loss] + loss: Loss #: The default regularizer class - regularizer_default: ClassVar[Type[Regularizer]] = NoRegularizer + regularizer_default: ClassVar[Optional[Type[Regularizer]]] = None #: The default parameters for the default regularizer class regularizer_default_kwargs: ClassVar[Optional[Mapping[str, Any]]] = None - #: The instance of the regularizer - regularizer: ClassVar[Regularizer] def __init__( self, @@ -232,7 +249,6 @@ def __init__( predict_with_sigmoid: bool = False, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, ) -> None: """Initialize the module. @@ -248,8 +264,6 @@ def __init__( The preferred device for model training and inference. :param random_seed: A random seed to use for initialising the model's weights. **Should** be set when aiming at reproducibility. - :param regularizer: - A regularizer to use for training. """ super().__init__() @@ -271,17 +285,8 @@ def __init__( self.loss = loss # TODO: Check loss functions that require 1 and -1 as label but only - self.is_mr_loss = isinstance(self.loss, MarginRankingLoss) - - # Regularizer - if regularizer is None: - regularizer = self.regularizer_default( - device=self.device, - **(self.regularizer_default_kwargs or {}), - ) - self.regularizer = regularizer - - self.is_nssa_loss = isinstance(self.loss, NSSALoss) + self.is_mr_loss: bool = isinstance(self.loss, MarginRankingLoss) + self.is_nssa_loss: bool = isinstance(self.loss, NSSALoss) # The triples factory facilitates access to the dataset. self.triples_factory = triples_factory @@ -298,30 +303,51 @@ def __init_subclass__(cls, autoreset: bool = True, **kwargs): # noqa:D105 _track_hyperparameters(cls) _add_post_reset_parameters(cls) - @property - def can_slice_h(self) -> bool: - """Whether score_h supports slicing.""" - return _can_slice(self.score_h) - - @property - def can_slice_r(self) -> bool: - """Whether score_r supports slicing.""" - return _can_slice(self.score_r) - - @property - def can_slice_t(self) -> bool: - """Whether score_t supports slicing.""" - return _can_slice(self.score_t) - @property def modules_not_supporting_sub_batching(self) -> Collection[nn.Module]: """Return all modules not supporting sub-batching.""" return get_batchnorm_modules(module=self) - @abstractmethod def _reset_parameters_(self): # noqa: D401 """Reset all parameters of the model in-place.""" - raise NotImplementedError + # cf. https://github.com/mberr/ea-sota-comparison/blob/6debd076f93a329753d819ff4d01567a23053720/src/kgm/utils/torch_utils.py#L317-L372 # noqa:E501 + # Make sure that all modules with parameters do have a reset_parameters method. 
+ uninitialized_parameters = set(map(id, self.parameters())) + parents = defaultdict(list) + + # Recursively visit all sub-modules + task_list = [] + for name, module in self.named_modules(): + + # skip self + if module is self: + continue + + # Track parents for blaming + for p in module.parameters(): + parents[id(p)].append(module) + + # call reset_parameters if possible + if hasattr(module, 'reset_parameters'): + task_list.append((name.count('.'), module)) + + # initialize from bottom to top + # This ensures that specialized initializations will take priority over the default ones of its components. + for module in map(itemgetter(1), sorted(task_list, reverse=True, key=itemgetter(0))): + module.reset_parameters() + uninitialized_parameters.difference_update(map(id, module.parameters())) + + # emit warning if there where parameters which were not initialised by reset_parameters. + if len(uninitialized_parameters) > 0: + logger.warning( + 'reset_parameters() not found for all modules containing parameters. ' + '%d parameters where likely not initialized.', + len(uninitialized_parameters), + ) + + # Additional debug information + for i, p_id in enumerate(uninitialized_parameters, start=1): + logger.debug('[%3d] Parents to blame: %s', i, parents.get(p_id)) def reset_parameters_(self) -> 'Model': # noqa: D401 """Reset all parameters of the model and enforce model constraints.""" @@ -344,10 +370,23 @@ def _set_device(self, device: DeviceHint = None) -> None: """Set the Torch device to use.""" self.device = resolve_device(device=device) + def _instantiate_default_regularizer(self, **kwargs) -> Optional[Regularizer]: + """Instantiate the regularizer from this class's default settings. + + If the default regularizer is None, None is returned. + Handles the corner case when the default regularizer's keyword arguments are None + Additional keyword arguments can be passed through to the `__init__()` function + """ + if self.regularizer_default is None: + return None + + _kwargs = dict(self.regularizer_default_kwargs or {}) + _kwargs.update(kwargs) + return self.regularizer_default(**_kwargs) + def to_device_(self) -> 'Model': """Transfer model to device.""" self.to(self.device) - self.regularizer.to(self.device) torch.cuda.empty_cache() return self @@ -399,22 +438,10 @@ def predict_scores_all_tails( :return: shape: (batch_size, num_entities), dtype: float For each h-r pair, the scores for all possible tails. - - .. note:: - - We only expect the right side-side predictions, i.e., $(h,r,*)$ to change its - default behavior when the model has been trained with inverse relations - (mainly because of the behavior of the LCWA training approach). This is why - the :func:`predict_scores_all_heads()` has different behavior depending on - if inverse triples were used in training, and why this function has the same - behavior regardless of the use of inverse triples. 
""" # Enforce evaluation mode self.eval() - if slice_size is None: - scores = self.score_t(hr_batch) - else: - scores = self.score_t(hr_batch, slice_size=slice_size) + scores = self.score_t(hr_batch, slice_size=slice_size) # type: ignore if self.predict_with_sigmoid: scores = torch.sigmoid(scores) return scores @@ -550,10 +577,7 @@ def predict_scores_all_relations( """ # Enforce evaluation mode self.eval() - if slice_size is None: - scores = self.score_r(ht_batch) - else: - scores = self.score_r(ht_batch, slice_size=slice_size) + scores = self.score_r(ht_batch, slice_size=slice_size) # type: ignore if self.predict_with_sigmoid: scores = torch.sigmoid(scores) return scores @@ -567,12 +591,6 @@ def predict_scores_all_heads( This method calculates the score for all possible heads for each (relation, tail) pair. - .. note:: - - If the model has been trained with inverse relations, the task of predicting - the head entities becomes the task of predicting the tail entities of the - inverse triples, i.e., $f(*,r,t)$ is predicted by means of $f(t,r_{inv},*)$. - Additionally, the model is set to evaluation mode. :param rt_batch: shape: (batch_size, 2), dtype: long @@ -586,12 +604,36 @@ def predict_scores_all_heads( # Enforce evaluation mode self.eval() - if self.triples_factory.create_inverse_triples: - scores = self.score_h_inverse(rt_batch=rt_batch, slice_size=slice_size) - elif slice_size is None: - scores = self.score_h(rt_batch) - else: + ''' + In case the model was trained using inverse triples, the scoring of all heads is not handled by calculating + the scores for all heads based on a (relation, tail) pair, but instead all possible tails are calculated + for a (tail, inverse_relation) pair. + ''' + if not self.triples_factory.create_inverse_triples: scores = self.score_h(rt_batch, slice_size=slice_size) + if self.predict_with_sigmoid: + scores = torch.sigmoid(scores) + return scores + + ''' + The PyKEEN package handles _inverse relations_ by adding the number of relations to the indices of the + _native relation_. + Example: + The triples/knowledge graph used to train the model contained 100 relations. Due to using inverse relations, + the model now has an additional 100 inverse relations. If the _native relation_ has the index 3, the index + of the _inverse relation_ is 4 (id of relation + 1). + ''' + rt_batch_cloned = rt_batch.clone() + rt_batch_cloned.to(device=rt_batch.device) + + # The number of relations stored in the triples factory includes the number of inverse relations + # Id of inverse relation: relation + 1 + rt_batch_cloned[:, 0] = rt_batch_cloned[:, 0] + 1 + + # The score_t function requires (entity, relation) pairs instead of (relation, entity) pairs + rt_batch_cloned = rt_batch_cloned.flip(1) + scores = self.score_t(rt_batch_cloned, slice_size=slice_size) # type: ignore + if self.predict_with_sigmoid: scores = torch.sigmoid(scores) return scores @@ -787,15 +829,11 @@ def make_labeled_df( def post_parameter_update(self) -> None: """Has to be called after each parameter update.""" - self.regularizer.reset() - - def regularize_if_necessary(self, *tensors: torch.FloatTensor) -> None: - """Update the regularizer's term given some tensors, if regularization is requested. - - :param tensors: The tensors that should be passed to the regularizer to update its term. 
- """ - if self.training: - self.regularizer.update(*tensors) + for module in self.modules(): + if module is self: + continue + if hasattr(module, "post_parameter_update"): + module.post_parameter_update() def compute_mr_loss( self, @@ -819,7 +857,7 @@ def compute_mr_loss( ' losses. Please use the compute_loss method instead.', ) y = torch.ones_like(negative_scores, device=self.device) - return self.loss(positive_scores, negative_scores, y) + self.regularizer.term + return self.loss(positive_scores, negative_scores, y) + collect_regularization_terms(self) def compute_label_loss( self, @@ -882,25 +920,43 @@ def _compute_loss( 'The chosen loss does not allow the calculation of margin label' ' losses. Please use the compute_mr_loss method instead.', ) - return self.loss(tensor_1, tensor_2) + self.regularizer.term + return self.loss(tensor_1, tensor_2) + collect_regularization_terms(self) - def _prepare_inverse_batch(self, batch: torch.LongTensor, index_relation: int) -> torch.LongTensor: - if not self.triples_factory.create_inverse_triples: - raise ValueError( - "Your model is not configured to predict with inverse relations." - " Set ``create_inverse_triples=True`` when creating the dataset/triples factory" - " or using the pipeline().", - ) - batch_cloned = batch.clone() + @abstractmethod + def forward( + self, + h_indices: Optional[torch.LongTensor], + r_indices: Optional[torch.LongTensor], + t_indices: Optional[torch.LongTensor], + slice_size: Optional[int] = None, + slice_dim: Optional[str] = None, + ) -> torch.FloatTensor: + """Forward pass. - # The number of relations stored in the triples factory includes the number of inverse relations - # Id of inverse relation: relation + 1 - batch_cloned[:, index_relation] = batch_cloned[:, index_relation] + 1 + This method takes head, relation and tail indices and calculates the corresponding score. - # The score_t function requires (entity, relation) pairs instead of (relation, entity) pairs - return batch_cloned.flip(1) + .. note :: + All indices which are not None, have to be either 1-element or have the same shape, which is the batch size. + + .. note :: + If slicing is requested, the corresponding indices have to be None. + + :param h_indices: + The head indices. None indicates to use all. + :param r_indices: + The relation indices. None indicates to use all. + :param t_indices: + The tail indices. None indicates to use all. + :param slice_size: + The slice size. + :param slice_dim: + The dimension along which to slice. From {"h", "r", "t"}. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The score for each triple. + """ + raise NotImplementedError - @abstractmethod def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: """Forward pass. @@ -908,113 +964,78 @@ def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: :param hrt_batch: shape: (batch_size, 3), dtype: long The indices of (head, relation, tail) triples. - :raises NotImplementedError: - If the method was not implemented for this class. + :return: shape: (batch_size, 1), dtype: float The score for each triple. """ - raise NotImplementedError - - def score_hrt_inverse( - self, - hrt_batch: torch.LongTensor, - ) -> torch.FloatTensor: - r"""Score triples based on inverse triples, i.e., compute $f(h,r,t)$ based on $f(t,r_{inv},h)$. - - When training with inverse relations, the model produces two (different) scores for a triple $(h,r,t) \in K$. 
- The forward score is calculated from $f(h,r,t)$ and the inverse score is calculated from $f(t,r_{inv},h)$. - This function enables users to inspect the scores obtained by using the corresponding inverse triples. - """ - t_r_inv_h = self._prepare_inverse_batch(batch=hrt_batch, index_relation=1) - return self.score_hrt(hrt_batch=t_r_inv_h) + return self( + h_indices=hrt_batch[:, 0], + r_indices=hrt_batch[:, 1], + t_indices=hrt_batch[:, 2], + ).view(hrt_batch.shape[0], 1) - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: + def score_t(self, hr_batch: torch.LongTensor, slice_size: Optional[int] = None) -> torch.FloatTensor: """Forward pass using right side (tail) prediction. This method calculates the score for all possible tails for each (head, relation) pair. :param hr_batch: shape: (batch_size, 2), dtype: long The indices of (head, relation) pairs. + :param slice_size: + The slice size. :return: shape: (batch_size, num_entities), dtype: float For each h-r pair, the scores for all possible tails. """ - logger.warning( - 'Calculations will fall back to using the score_hrt method, since this model does not have a specific ' - 'score_t function. This might cause the calculations to take longer than necessary.', - ) - # Extend the hr_batch such that each (h, r) pair is combined with all possible tails - hrt_batch = _extend_batch(batch=hr_batch, all_ids=list(self.triples_factory.entity_to_id.values()), dim=2) - # Calculate the scores for each (h, r, t) triple using the generic interaction function - expanded_scores = self.score_hrt(hrt_batch=hrt_batch) - # Reshape the scores to match the pre-defined output shape of the score_t function. - scores = expanded_scores.view(hr_batch.shape[0], -1) - return scores - - def score_t_inverse(self, hr_batch: torch.LongTensor, slice_size: Optional[int] = None): - """Score all tails for a batch of (h,r)-pairs using the head predictions for the inverses $(*,r_{inv},h)$.""" - # TODO UNUSED - r_inv_h = self._prepare_inverse_batch(batch=hr_batch, index_relation=1) - - if slice_size is None: - return self.score_h(rt_batch=r_inv_h) - else: - return self.score_h(rt_batch=r_inv_h, slice_size=slice_size) - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: + return self( + h_indices=hr_batch[:, 0], + r_indices=hr_batch[:, 1], + t_indices=None, + slice_size=slice_size, + slice_dim="h", + ).view(hr_batch.shape[0], self.num_entities) + + def score_h(self, rt_batch: torch.LongTensor, slice_size: Optional[int] = None) -> torch.FloatTensor: """Forward pass using left side (head) prediction. This method calculates the score for all possible heads for each (relation, tail) pair. :param rt_batch: shape: (batch_size, 2), dtype: long The indices of (relation, tail) pairs. + :param slice_size: + The slice size. :return: shape: (batch_size, num_entities), dtype: float For each r-t pair, the scores for all possible heads. """ - logger.warning( - 'Calculations will fall back to using the score_hrt method, since this model does not have a specific ' - 'score_h function. 
This might cause the calculations to take longer than necessary.', - ) - # Extend the rt_batch such that each (r, t) pair is combined with all possible heads - hrt_batch = _extend_batch(batch=rt_batch, all_ids=list(self.triples_factory.entity_to_id.values()), dim=0) - # Calculate the scores for each (h, r, t) triple using the generic interaction function - expanded_scores = self.score_hrt(hrt_batch=hrt_batch) - # Reshape the scores to match the pre-defined output shape of the score_h function. - scores = expanded_scores.view(rt_batch.shape[0], -1) - return scores - - def score_h_inverse(self, rt_batch: torch.LongTensor, slice_size: Optional[int] = None): - """Score all heads for a batch of (r,t)-pairs using the tail predictions for the inverses $(t,r_{inv},*)$.""" - t_r_inv = self._prepare_inverse_batch(batch=rt_batch, index_relation=0) - - if slice_size is None: - return self.score_t(hr_batch=t_r_inv) - else: - return self.score_t(hr_batch=t_r_inv, slice_size=slice_size) - - def score_r(self, ht_batch: torch.LongTensor) -> torch.FloatTensor: + return self( + h_indices=None, + r_indices=rt_batch[:, 0], + t_indices=rt_batch[:, 1], + slice_size=slice_size, + slice_dim="r", + ).view(rt_batch.shape[0], self.num_entities) + + def score_r(self, ht_batch: torch.LongTensor, slice_size: Optional[int] = None) -> torch.FloatTensor: """Forward pass using middle (relation) prediction. This method calculates the score for all possible relations for each (head, tail) pair. :param ht_batch: shape: (batch_size, 2), dtype: long The indices of (head, tail) pairs. + :param slice_size: + The slice size. :return: shape: (batch_size, num_relations), dtype: float For each h-t pair, the scores for all possible relations. """ - logger.warning( - 'Calculations will fall back to using the score_hrt method, since this model does not have a specific ' - 'score_r function. This might cause the calculations to take longer than necessary.', - ) - # Extend the ht_batch such that each (h, t) pair is combined with all possible relations - hrt_batch = _extend_batch(batch=ht_batch, all_ids=list(self.triples_factory.relation_to_id.values()), dim=1) - # Calculate the scores for each (h, r, t) triple using the generic interaction function - expanded_scores = self.score_hrt(hrt_batch=hrt_batch) - # Reshape the scores to match the pre-defined output shape of the score_r function. 
- scores = expanded_scores.view(ht_batch.shape[0], -1) - return scores + return self( + h_indices=ht_batch[:, 0], + r_indices=None, + t_indices=ht_batch[:, 1], + slice_size=slice_size, + slice_dim="t", + ).view(ht_batch.shape[0], self.num_relations) def get_grad_params(self) -> Iterable[nn.Parameter]: """Get the parameters that require gradients.""" @@ -1043,174 +1064,234 @@ def load_state(self, path: str) -> None: self.load_state_dict(torch.load(path, map_location=self.device)) -class EntityEmbeddingModel(Model, autoreset=False): - """A base module for most KGE models that have one embedding for entities.""" - - def __init__( - self, - triples_factory: TriplesFactory, - embedding_dim: int = 50, - loss: Optional[Loss] = None, - predict_with_sigmoid: bool = False, - preferred_device: DeviceHint = None, - random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, - entity_initializer: Optional[Initializer] = None, - entity_initializer_kwargs: Optional[Mapping[str, Any]] = None, - entity_normalizer: Optional[Normalizer] = None, - entity_normalizer_kwargs: Optional[Mapping[str, Any]] = None, - entity_constrainer: Optional[Constrainer] = None, - entity_constrainer_kwargs: Optional[Mapping[str, Any]] = None, - - ) -> None: - """Initialize the entity embedding model. - - :param embedding_dim: - The embedding dimensionality. Exact usages depends on the specific model subclass. - - .. seealso:: Constructor of the base class :class:`pykeen.models.Model` - """ - super().__init__( - triples_factory=triples_factory, - loss=loss, - preferred_device=preferred_device, - random_seed=random_seed, - regularizer=regularizer, - predict_with_sigmoid=predict_with_sigmoid, - ) - self.entity_embeddings = Embedding.init_with_device( - num_embeddings=triples_factory.num_entities, - embedding_dim=embedding_dim, - device=self.device, - initializer=entity_initializer, - initializer_kwargs=entity_initializer_kwargs, - normalizer=entity_normalizer, - normalizer_kwargs=entity_normalizer_kwargs, - constrainer=entity_constrainer, - constrainer_kwargs=entity_constrainer_kwargs, +def _prepare_representation_module_list( + representations: Union[ + None, + EmbeddingSpecification, + RepresentationModule, + Sequence[Union[EmbeddingSpecification, RepresentationModule]], + ], + num_embeddings: int, + shapes: Sequence[str], + label: str, + skip_checks: bool = False, +) -> Sequence[RepresentationModule]: + """Normalize list of representations and wrap into nn.ModuleList.""" + # Important: use ModuleList to ensure that Pytorch correctly handles their devices and parameters + if representations is None: + representations = [] + if not isinstance(representations, Sequence): + representations = [representations] + if not skip_checks and len(representations) != len(shapes): + raise ValueError( + f"Interaction function requires {len(shapes)} {label} representations, but " + f"{len(representations)} were given.", ) + modules = [] + for r in representations: + if not isinstance(r, RepresentationModule): + assert isinstance(r, EmbeddingSpecification) + r = r.make(num_embeddings=num_embeddings) + if r.max_id < num_embeddings: + raise ValueError( + f"{r} only provides {r.max_id} {label} representations, but should provide {num_embeddings}.", + ) + elif r.max_id > num_embeddings: + logger.warning( + f"{r} provides {r.max_id} {label} representations, although only {num_embeddings} are needed." 
+ f"While this is not necessarily wrong, it can indicate an error where the number of {label} " + f"representations was chosen wrong.", + ) + modules.append(r) + if not skip_checks: + check_shapes(*zip( + (r.shape for r in modules), + shapes, + ), raise_on_errors=True) + return nn.ModuleList(modules) - @property - def embedding_dim(self) -> int: # noqa:D401 - """The entity embedding dimension.""" - return self.entity_embeddings.embedding_dim - def _reset_parameters_(self): # noqa: D102 - self.entity_embeddings.reset_parameters() +class ERModel(Generic[HeadRepresentation, RelationRepresentation, TailRepresentation], Model, autoreset=False): + """A commonly useful base for KGEMs using embeddings and interaction modules.""" - def post_parameter_update(self) -> None: # noqa: D102 - # make sure to call this first, to reset regularizer state! - super().post_parameter_update() - self.entity_embeddings.post_parameter_update() + #: The entity representations + entity_representations: Sequence[RepresentationModule] + #: The relation representations + relation_representations: Sequence[RepresentationModule] -class EntityRelationEmbeddingModel(Model, autoreset=False): - """A base module for KGE models that have different embeddings for entities and relations.""" + #: The weight regularizers + weight_regularizers: List[Regularizer] def __init__( self, triples_factory: TriplesFactory, - embedding_dim: int = 50, - relation_dim: Optional[int] = None, + interaction: Interaction[HeadRepresentation, RelationRepresentation, TailRepresentation], + entity_representations: Union[ + None, + EmbeddingSpecification, + RepresentationModule, + Sequence[Union[EmbeddingSpecification, RepresentationModule]], + ] = None, + relation_representations: Union[ + None, + EmbeddingSpecification, + RepresentationModule, + Sequence[Union[EmbeddingSpecification, RepresentationModule]], + ] = None, loss: Optional[Loss] = None, predict_with_sigmoid: bool = False, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, - entity_initializer: Optional[Initializer] = None, - entity_initializer_kwargs: Optional[Mapping[str, Any]] = None, - entity_normalizer: Optional[Normalizer] = None, - entity_normalizer_kwargs: Optional[Mapping[str, Any]] = None, - entity_constrainer: Optional[Constrainer] = None, - entity_constrainer_kwargs: Optional[Mapping[str, Any]] = None, - relation_initializer: Optional[Initializer] = None, - relation_initializer_kwargs: Optional[Mapping[str, Any]] = None, - relation_normalizer: Optional[Normalizer] = None, - relation_normalizer_kwargs: Optional[Mapping[str, Any]] = None, - relation_constrainer: Optional[Constrainer] = None, - relation_constrainer_kwargs: Optional[Mapping[str, Any]] = None, ) -> None: - """Initialize the entity embedding model. - - :param relation_dim: - The relation embedding dimensionality. If not given, defaults to same size as entity embedding - dimension. + """Initialize the module. - .. seealso:: Constructor of the base class :class:`pykeen.models.Model` - .. seealso:: Constructor of the base class :class:`pykeen.models.EntityEmbeddingModel` + :param triples_factory: + The triples factory facilitates access to the dataset. + :param loss: + The loss to use. If None is given, use the loss default specific to the model subclass. + :param predict_with_sigmoid: + Whether to apply sigmoid onto the scores when predicting scores. 
Applying sigmoid at prediction time may + lead to exactly equal scores for certain triples with very high, or very low score. When not trained with + applying sigmoid (or using BCEWithLogitsLoss), the scores are not calibrated to perform well with sigmoid. + :param preferred_device: + The preferred device for model training and inference. + :param random_seed: + A random seed to use for initialising the model's weights. **Should** be set when aiming at reproducibility. """ super().__init__( triples_factory=triples_factory, loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, predict_with_sigmoid=predict_with_sigmoid, ) - self.entity_embeddings = Embedding.init_with_device( + self.entity_representations = _prepare_representation_module_list( + representations=entity_representations, num_embeddings=triples_factory.num_entities, - embedding_dim=embedding_dim, - device=self.device, - initializer=entity_initializer, - initializer_kwargs=entity_initializer_kwargs, - normalizer=entity_normalizer, - normalizer_kwargs=entity_normalizer_kwargs, - constrainer=entity_constrainer, - constrainer_kwargs=entity_constrainer_kwargs, + shapes=interaction.entity_shape, + label="entity", + skip_checks=interaction.tail_entity_shape is not None, ) - - # Default for relation dimensionality - if relation_dim is None: - relation_dim = embedding_dim - - self.relation_embeddings = Embedding.init_with_device( + self.relation_representations = _prepare_representation_module_list( + representations=relation_representations, num_embeddings=triples_factory.num_relations, - embedding_dim=relation_dim, - device=self.device, - initializer=relation_initializer, - initializer_kwargs=relation_initializer_kwargs, - normalizer=relation_normalizer, - normalizer_kwargs=relation_normalizer_kwargs, - constrainer=relation_constrainer, - constrainer_kwargs=relation_constrainer_kwargs, + shapes=interaction.relation_shape, + label="relation", ) + self.interaction = interaction + # Comment: it is important that the regularizers are stored in a module list, in order to appear in + # model.modules(). Thereby, we can collect them automatically. + self.weight_regularizers = nn.ModuleList() - @property - def embedding_dim(self) -> int: # noqa:D401 - """The entity embedding dimension.""" - return self.entity_embeddings.embedding_dim - - @property - def relation_dim(self): # noqa:D401 - """The relation embedding dimension.""" - return self.relation_embeddings.embedding_dim - - def _reset_parameters_(self): # noqa: D102 - self.entity_embeddings.reset_parameters() - self.relation_embeddings.reset_parameters() - - def post_parameter_update(self) -> None: # noqa: D102 - # make sure to call this first, to reset regularizer state! - super().post_parameter_update() - self.entity_embeddings.post_parameter_update() - self.relation_embeddings.post_parameter_update() + def append_weight_regularizer( + self, + parameter: Union[str, nn.Parameter, Iterable[Union[str, nn.Parameter]]], + regularizer: Regularizer, + ) -> None: + """Add a model weight to a regularizer's weight list, and register the regularizer with the model. + :param parameter: + The parameter, either as name, or as nn.Parameter object. A list of available parameter names is shown by + `sorted(dict(self.named_parameters()).keys())`. + :param regularizer: + The regularizer instance which will regularize the weights. 
+ """ + # normalize input + if isinstance(parameter, (str, nn.Parameter)): + parameter = [parameter] + weights: Mapping[str, nn.Parameter] = dict(self.named_parameters()) + for param in parameter: + if isinstance(param, str): + if param not in weights: + raise KeyError(f"Invalid parameter_name={parameter}. Available are: {sorted(weights.keys())}.") + param: nn.Parameter = weights[param] # type: ignore + regularizer.add_parameter(parameter=param) + self.weight_regularizers.append(regularizer) + + def forward( + self, + h_indices: Optional[torch.LongTensor], + r_indices: Optional[torch.LongTensor], + t_indices: Optional[torch.LongTensor], + slice_size: Optional[int] = None, + slice_dim: Optional[str] = None, + ) -> torch.FloatTensor: + """Forward pass. -def _can_slice(fn) -> bool: - return 'slice_size' in inspect.getfullargspec(fn).args + This method takes head, relation and tail indices and calculates the corresponding score. + All indices which are not None, have to be either 1-element or have the same shape, which is the batch size. -class MultimodalModel(Model, autoreset=False): - """A multimodal KGE model.""" + :param h_indices: + The head indices. None indicates to use all. + :param r_indices: + The relation indices. None indicates to use all. + :param t_indices: + The tail indices. None indicates to use all. + :param slice_size: + The slice size. + :param slice_dim: + The dimension along which to slice. From {"h", "r", "t"} - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self(h_indices=hrt_batch[:, 0], r_indices=hrt_batch[:, 1], t_indices=hrt_batch[:, 2]).view(-1, 1) + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The score for each triple. + """ + h, r, t = self._get_representations(h_indices, r_indices, t_indices) + scores = self.interaction.score(h=h, r=r, t=t, slice_size=slice_size, slice_dim=slice_dim) + return self._repeat_scores_if_necessary(scores, h_indices, r_indices, t_indices) - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self(h_indices=hr_batch[:, 0], r_indices=hr_batch[:, 1], t_indices=None) + def _repeat_scores_if_necessary( + self, + scores: torch.FloatTensor, + h_indices: Optional[torch.LongTensor], + r_indices: Optional[torch.LongTensor], + t_indices: Optional[torch.LongTensor], + ) -> torch.FloatTensor: + repeat_relations = len(self.relation_representations) == 0 + repeat_entities = len(self.entity_representations) == 0 + + if not (repeat_entities or repeat_relations): + return scores + + repeats = [1, 1, 1, 1] + + for i, (flag, ind, num) in enumerate(( + (repeat_entities, h_indices, self.num_entities), + (repeat_relations, r_indices, self.num_relations), + (repeat_entities, t_indices, self.num_entities), + ), start=1): + if flag: + if ind is None: + repeats[i] = num + else: + batch_size = len(ind) + if scores.shape[0] < batch_size: + repeats[0] = batch_size - def score_r(self, ht_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self(h_indices=ht_batch[:, 0], r_indices=None, t_indices=ht_batch[:, 1]) + return scores.repeat(*repeats) - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self(h_indices=None, r_indices=rt_batch[:, 0], t_indices=rt_batch[:, 1]) + def _get_representations( + self, + h_indices: Optional[torch.LongTensor], + r_indices: Optional[torch.LongTensor], + t_indices: Optional[torch.LongTensor], + ) -> Tuple[ + Union[torch.FloatTensor, Sequence[torch.FloatTensor]], + 
Union[torch.FloatTensor, Sequence[torch.FloatTensor]], + Union[torch.FloatTensor, Sequence[torch.FloatTensor]], + ]: + h, r, t = [ + [ + representation.get_in_canonical_shape(dim=dim, indices=indices) + for representation in representations + ] + for dim, indices, representations in ( + ("h", h_indices, self.entity_representations), + ("r", r_indices, self.relation_representations), + ("t", t_indices, self.entity_representations), + ) + ] + # normalization + h, r, t = [x[0] if len(x) == 1 else x for x in (h, r, t)] + return h, r, t diff --git a/src/pykeen/models/multimodal/__init__.py b/src/pykeen/models/multimodal/__init__.py index c81a03b74b..e0122ad108 100644 --- a/src/pykeen/models/multimodal/__init__.py +++ b/src/pykeen/models/multimodal/__init__.py @@ -6,10 +6,12 @@ `_ arXiv preprint arXiv:1802.00934. """ +from .base import LiteralModel from .complex_literal import ComplExLiteral from .distmult_literal import DistMultLiteral __all__ = [ 'ComplExLiteral', 'DistMultLiteral', + 'LiteralModel', ] diff --git a/src/pykeen/models/multimodal/base.py b/src/pykeen/models/multimodal/base.py new file mode 100644 index 0000000000..496b8743ca --- /dev/null +++ b/src/pykeen/models/multimodal/base.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- + +"""Base classes for multi-modal models.""" + +from typing import Optional, TYPE_CHECKING, Tuple + +import torch +from torch import nn + +from ..base import ERModel +from ...losses import Loss +from ...nn import Embedding, EmbeddingSpecification, Interaction, LiteralRepresentations +from ...triples import TriplesNumericLiteralsFactory +from ...typing import DeviceHint, HeadRepresentation, RelationRepresentation, Representation, TailRepresentation + +__all__ = [ + "LiteralModel", +] + +if TYPE_CHECKING: + from ...typing import Representation # noqa + + +class LiteralInteraction( + Interaction[ + Tuple[Representation, Representation], + Representation, + Tuple[Representation, Representation], + ], +): + + def __init__( + self, + base: Interaction[Representation, Representation, Representation], + combination: nn.Module, + ): + super().__init__() + self.base = base + self.combination = combination + self.entity_shape = tuple(self.base.entity_shape) + ("e",) + + def forward( + self, + h: Tuple[Representation, Representation], + r: Representation, + t: Tuple[Representation, Representation], + ) -> torch.FloatTensor: + # combine entity embeddings + literals + h = torch.cat(h, dim=-1) + h = self.combination(h.view(-1, h.shape[-1])).view(*h.shape[:-1], -1) # type: ignore + t = torch.cat(t, dim=-1) + t = self.combination(t.view(-1, t.shape[-1])).view(*t.shape[:-1], -1) # type: ignore + return self.base(h=h, r=r, t=t) + + +class LiteralModel(ERModel[HeadRepresentation, RelationRepresentation, TailRepresentation], autoreset=False): + """Base class for models with entity literals.""" + + def __init__( + self, + triples_factory: TriplesNumericLiteralsFactory, + interaction: LiteralInteraction, + entity_specification: Optional[EmbeddingSpecification] = None, + relation_specification: Optional[EmbeddingSpecification] = None, + loss: Optional[Loss] = None, + predict_with_sigmoid: bool = False, + preferred_device: DeviceHint = None, + random_seed: Optional[int] = None, + ): + super().__init__( + triples_factory=triples_factory, + interaction=interaction, + loss=loss, + predict_with_sigmoid=predict_with_sigmoid, + preferred_device=preferred_device, + random_seed=random_seed, + entity_representations=[ + # entity embeddings + Embedding.from_specification( + 
num_embeddings=triples_factory.num_entities, + specification=entity_specification, + ), + # Entity literals + LiteralRepresentations( + numeric_literals=torch.as_tensor(triples_factory.numeric_literals, dtype=torch.float32), + ), + ], + relation_representations=Embedding.from_specification( + num_embeddings=triples_factory.num_relations, + specification=relation_specification, + ), + ) diff --git a/src/pykeen/models/multimodal/complex_literal.py b/src/pykeen/models/multimodal/complex_literal.py index d1f7dff957..d1324f1620 100644 --- a/src/pykeen/models/multimodal/complex_literal.py +++ b/src/pykeen/models/multimodal/complex_literal.py @@ -1,24 +1,63 @@ # -*- coding: utf-8 -*- -"""Implementation of the ComplexLiteral model based on the local closed world assumption (LCWA) training approach.""" +"""Implementation of the ComplexLiteral model.""" from typing import Any, ClassVar, Mapping, Optional, Type import torch import torch.nn as nn -from torch.nn.init import xavier_normal_ -from ..base import MultimodalModel -from ..unimodal.complex import ComplEx +from .base import LiteralInteraction, LiteralModel from ...constants import DEFAULT_DROPOUT_HPO_RANGE, DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import BCEWithLogitsLoss, Loss -from ...nn import Embedding +from ...nn import EmbeddingSpecification +from ...nn.modules import ComplExInteraction from ...triples import TriplesNumericLiteralsFactory from ...typing import DeviceHint -from ...utils import split_complex +from ...utils import combine_complex, split_complex +__all__ = [ + 'ComplExLiteral', +] -class ComplExLiteral(ComplEx, MultimodalModel): + +class ComplExLiteralCombination(nn.Module): + """Separately transform real and imaginary part.""" + + def __init__( + self, + embedding_dim: int, + num_of_literals: int, + dropout: float = 0.0, + ): + super().__init__() + self.real = nn.Sequential( + nn.Dropout(dropout), + nn.Linear(embedding_dim + num_of_literals, embedding_dim), + torch.nn.Tanh(), + ) + self.imag = nn.Sequential( + nn.Dropout(dropout), + nn.Linear(embedding_dim + num_of_literals, embedding_dim), + torch.nn.Tanh(), + ) + # TODO: Determine this automatically + self.embedding_dim = 2 * embedding_dim + + def forward( + self, + x: torch.FloatTensor, + ) -> torch.FloatTensor: + x, literal = x[..., :self.embedding_dim], x[..., self.embedding_dim:] + x_re, x_im = split_complex(x) + x_re = self.real(torch.cat([x_re, literal], dim=-1)) + x_im = self.imag(torch.cat([x_im, literal], dim=-1)) + return combine_complex(x_re=x_re, x_im=x_im) + + +# TODO: Check entire build of the model +# TODO: There are no tests. 
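For intuition, the combination module above follows a simple pattern: strip off the numeric literal features, transform the real and imaginary halves of the entity embedding separately (each concatenated with the literals), and re-assemble the result. The following minimal sketch mirrors that pattern with plain torch modules and made-up toy sizes; it is an illustration only, not part of the patch, and it deliberately avoids the pykeen helpers (split_complex/combine_complex) used in the real module.

# Minimal sketch (illustration only): the real/imaginary + literal combination pattern above, with toy sizes.
import torch
from torch import nn

embedding_dim, num_literals, batch_size = 4, 3, 2
real = nn.Sequential(nn.Linear(embedding_dim + num_literals, embedding_dim), nn.Tanh())
imag = nn.Sequential(nn.Linear(embedding_dim + num_literals, embedding_dim), nn.Tanh())

# input layout along the last dimension: [real part | imaginary part | numeric literals]
x = torch.rand(batch_size, 2 * embedding_dim + num_literals)
emb, lit = x[..., :2 * embedding_dim], x[..., 2 * embedding_dim:]
x_re, x_im = emb[..., :embedding_dim], emb[..., embedding_dim:]
x_re = real(torch.cat([x_re, lit], dim=-1))
x_im = imag(torch.cat([x_im, lit], dim=-1))
combined = torch.cat([x_re, x_im], dim=-1)  # back to a 2 * embedding_dim real/imaginary vector
assert combined.shape == (batch_size, 2 * embedding_dim)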
+class ComplExLiteral(LiteralModel): """An implementation of ComplexLiteral from [agustinus2018]_ based on the LCWA training approach.""" #: The default strategy for optimizing the model's hyper-parameters @@ -37,72 +76,33 @@ def __init__( embedding_dim: int = 50, input_dropout: float = 0.2, loss: Optional[Loss] = None, + predict_with_sigmoid: bool = False, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, ) -> None: """Initialize the model.""" super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=LiteralInteraction( + base=ComplExInteraction(), + combination=ComplExLiteralCombination( + embedding_dim=embedding_dim, + num_of_literals=triples_factory.numeric_literals.shape[-1], + dropout=input_dropout, + ), + ), + entity_specification=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=nn.init.xavier_normal_, + dtype=torch.complex64, + ), + relation_specification=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=nn.init.xavier_normal_, + dtype=torch.complex64, + ), loss=loss, + predict_with_sigmoid=predict_with_sigmoid, preferred_device=preferred_device, random_seed=random_seed, - entity_initializer=xavier_normal_, - relation_initializer=xavier_normal_, ) - - # Literal - # num_ent x num_lit - self.numeric_literals = Embedding( - num_embeddings=triples_factory.num_entities, - embedding_dim=triples_factory.numeric_literals.shape[-1], - initializer=lambda x: triples_factory.numeric_literals, - ) - # Number of columns corresponds to number of literals - self.num_of_literals = self.numeric_literals.embedding_dim - - self.real_non_lin_transf = torch.nn.Sequential( - nn.Linear(self.embedding_dim // 2 + self.num_of_literals, self.embedding_dim // 2), - torch.nn.Tanh(), - ) - - self.img_non_lin_transf = torch.nn.Sequential( - nn.Linear(self.embedding_dim // 2 + self.num_of_literals, self.embedding_dim // 2), - torch.nn.Tanh(), - ) - - self.inp_drop = torch.nn.Dropout(input_dropout) - - def _get_entity_representations( - self, - idx: torch.LongTensor, - dropout: bool, - ) -> torch.FloatTensor: - emb = self.entity_embeddings.get_in_canonical_shape(indices=idx) - lit = self.numeric_literals.get_in_canonical_shape(indices=idx) - if dropout: - emb = self.inp_drop(emb) - re, im = split_complex(emb) - re, im = [torch.cat([x, lit], dim=-1) for x in (re, im)] - re, im = [ - trans(x.view(-1, x.shape[-1])).view(*(x.shape[:-1]), self.embedding_dim // 2) - for x, trans in ( - (re, self.real_non_lin_transf), - (im, self.img_non_lin_transf), - ) - ] - x = torch.cat([re, im], dim=-1) - if dropout: - x = self.inp_drop(x) - return x - - def forward( - self, - h_indices: Optional[torch.LongTensor], - r_indices: Optional[torch.LongTensor], - t_indices: Optional[torch.LongTensor], - ) -> torch.FloatTensor: # noqa: D102 - h = self._get_entity_representations(idx=h_indices, dropout=True) - r = self.inp_drop(self.relation_embeddings.get_in_canonical_shape(indices=r_indices)) - t = self._get_entity_representations(idx=t_indices, dropout=False) - return self.interaction_function(h=h, r=r, t=t) diff --git a/src/pykeen/models/multimodal/distmult_literal.py b/src/pykeen/models/multimodal/distmult_literal.py index c350774c1a..76dccb93f8 100644 --- a/src/pykeen/models/multimodal/distmult_literal.py +++ b/src/pykeen/models/multimodal/distmult_literal.py @@ -2,21 +2,29 @@ """Implementation of the DistMultLiteral model.""" -from typing import Any, ClassVar, Mapping, Optional +from typing import Any, ClassVar, Mapping, Optional, 
TYPE_CHECKING -import torch import torch.nn as nn -from ..base import MultimodalModel -from ..unimodal.distmult import DistMult +from .base import LiteralInteraction, LiteralModel from ...constants import DEFAULT_DROPOUT_HPO_RANGE, DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss -from ...nn import Embedding +from ...nn import EmbeddingSpecification +from ...nn.modules import DistMultInteraction from ...triples import TriplesNumericLiteralsFactory from ...typing import DeviceHint +if TYPE_CHECKING: + from ...typing import Representation # noqa -class DistMultLiteral(DistMult, MultimodalModel): +__all__ = [ + 'DistMultLiteral', +] + + +# TODO: Check entire build of the model +# TODO: There are no tests +class DistMultLiteral(LiteralModel): """An implementation of DistMultLiteral from [agustinus2018]_.""" #: The default strategy for optimizing the model's hyper-parameters @@ -35,44 +43,27 @@ def __init__( loss: Optional[Loss] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, + predict_with_sigmoid: bool = False, ) -> None: super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=LiteralInteraction( + base=DistMultInteraction(), + combination=nn.Sequential( + nn.Linear(embedding_dim + triples_factory.numeric_literals.shape[1], embedding_dim), + nn.Dropout(input_dropout), + ), + ), + entity_specification=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=nn.init.xavier_normal_, + ), + relation_specification=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=nn.init.xavier_normal_, + ), loss=loss, + predict_with_sigmoid=predict_with_sigmoid, preferred_device=preferred_device, random_seed=random_seed, ) - - # Literal - # num_ent x num_lit - self.numeric_literals = Embedding( - num_embeddings=triples_factory.num_entities, - embedding_dim=triples_factory.numeric_literals.shape[-1], - initializer=lambda x: triples_factory.numeric_literals, - ) - # Number of columns corresponds to number of literals - self.num_of_literals = self.numeric_literals.embedding_dim - self.linear_transformation = nn.Linear(self.embedding_dim + self.num_of_literals, self.embedding_dim) - self.inp_drop = torch.nn.Dropout(input_dropout) - - def _get_entity_representations( - self, - idx: torch.LongTensor, - ) -> torch.FloatTensor: - emb = self.entity_embeddings.get_in_canonical_shape(indices=idx) - lit = self.numeric_literals.get_in_canonical_shape(indices=idx) - x = self.linear_transformation(torch.cat([emb, lit], dim=-1)) - return self.inp_drop(x) - - def forward( - self, - h_indices: Optional[torch.LongTensor], - r_indices: Optional[torch.LongTensor], - t_indices: Optional[torch.LongTensor], - ) -> torch.FloatTensor: # noqa: D102 - # TODO: this is very similar to ComplExLiteral, except a few dropout differences - h = self._get_entity_representations(idx=h_indices) - r = self.relation_embeddings.get_in_canonical_shape(indices=r_indices) - t = self._get_entity_representations(idx=t_indices) - return self.interaction_function(h=h, r=r, t=t) diff --git a/src/pykeen/models/unimodal/complex.py b/src/pykeen/models/unimodal/complex.py index 4cc6be0533..42f3c826f4 100644 --- a/src/pykeen/models/unimodal/complex.py +++ b/src/pykeen/models/unimodal/complex.py @@ -7,20 +7,21 @@ import torch import torch.nn as nn -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss, SoftplusLoss +from ...nn import 
EmbeddingSpecification +from ...nn.modules import ComplExInteraction from ...regularizers import LpRegularizer, Regularizer from ...triples import TriplesFactory from ...typing import DeviceHint -from ...utils import split_complex __all__ = [ 'ComplEx', ] -class ComplEx(EntityRelationEmbeddingModel): +class ComplEx(ERModel): r"""An implementation of ComplEx [trouillon2016]_. ComplEx is an extension of :class:`pykeen.models.DistMult` that uses complex valued representations for the @@ -75,98 +76,48 @@ def __init__( regularizer: Optional[Regularizer] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - entity_initializer=nn.init.normal_, - relation_initializer=nn.init.normal_, + embedding_specification: Optional[EmbeddingSpecification] = None, + relation_embedding_specification: Optional[EmbeddingSpecification] = None, ) -> None: """Initialize ComplEx. - :param triples_factory: + :param triples_factory: TriplesFactory The triple factory connected to the model. :param embedding_dim: The embedding dimensionality of the entity embeddings. :param loss: - The loss to use. Defaults to SoftplusLoss. + The loss to use. Defaults to :data:`loss_default`. :param regularizer: - The regularizer to use. + The regularizer to use. Defaults to :data:`regularizer_default`. :param preferred_device: The default device where to model is located. :param random_seed: An optional random seed to set before the initialization of weights. """ + if regularizer is None: + regularizer = self._instantiate_default_regularizer() + # initialize with entity and relation embeddings with standard normal distribution, cf. + # https://github.com/ttrouill/complex/blob/dc4eb93408d9a5288c986695b58488ac80b1cc17/efe/models.py#L481-L487 + if embedding_specification is None: + embedding_specification = EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=nn.init.normal_, + regularizer=regularizer, + dtype=torch.complex64, + ) + if relation_embedding_specification is None: + relation_embedding_specification = EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=nn.init.normal_, + regularizer=regularizer, + dtype=torch.complex64, + ) super().__init__( triples_factory=triples_factory, - embedding_dim=2 * embedding_dim, # complex embeddings + interaction=ComplExInteraction(), + entity_representations=embedding_specification, + relation_representations=relation_embedding_specification, loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - # initialize with entity and relation embeddings with standard normal distribution, cf. - # https://github.com/ttrouill/complex/blob/dc4eb93408d9a5288c986695b58488ac80b1cc17/efe/models.py#L481-L487 - entity_initializer=entity_initializer, - relation_initializer=relation_initializer, ) - - @staticmethod - def interaction_function( - h: torch.FloatTensor, - r: torch.FloatTensor, - t: torch.FloatTensor, - ) -> torch.FloatTensor: - """Evaluate the interaction function of ComplEx for given embeddings. - - The embeddings have to be in a broadcastable shape. - - :param h: - Head embeddings. - :param r: - Relation embeddings. - :param t: - Tail embeddings. - - :return: shape: (...) - The scores. 
- """ - # split into real and imaginary part - (h_re, h_im), (r_re, r_im), (t_re, t_im) = [split_complex(x=x) for x in (h, r, t)] - - # ComplEx space bilinear product - # *: Elementwise multiplication - return sum( - (hh * rr * tt).sum(dim=-1) - for hh, rr, tt in [ - (h_re, r_re, t_re), - (h_re, r_im, t_im), - (h_im, r_re, t_im), - (h_im, r_im, t_re), - ] - ) - - def forward( - self, - h_indices: Optional[torch.LongTensor], - r_indices: Optional[torch.LongTensor], - t_indices: Optional[torch.LongTensor], - ) -> torch.FloatTensor: - """Unified score function.""" - # get embeddings - h = self.entity_embeddings.get_in_canonical_shape(indices=h_indices) - r = self.relation_embeddings.get_in_canonical_shape(indices=r_indices) - t = self.entity_embeddings.get_in_canonical_shape(indices=t_indices) - - # Regularization - self.regularize_if_necessary(h, r, t) - - # Compute scores - return self.interaction_function(h=h, r=r, t=t) - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self(h_indices=hrt_batch[:, 0], r_indices=hrt_batch[:, 1], t_indices=hrt_batch[:, 2]).view(-1, 1) - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self(h_indices=hr_batch[:, 0], r_indices=hr_batch[:, 1], t_indices=None) - - def score_r(self, ht_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self(h_indices=ht_batch[:, 0], r_indices=None, t_indices=ht_batch[:, 1]) - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self(h_indices=None, r_indices=rt_batch[:, 0], t_indices=rt_batch[:, 1]) diff --git a/src/pykeen/models/unimodal/conv_e.py b/src/pykeen/models/unimodal/conv_e.py index a66a2493d1..9a08a9effa 100644 --- a/src/pykeen/models/unimodal/conv_e.py +++ b/src/pykeen/models/unimodal/conv_e.py @@ -3,23 +3,19 @@ """Implementation of ConvE.""" import logging -import math -import sys -from typing import Any, ClassVar, Mapping, Optional, Tuple, Type +from typing import Any, ClassVar, Mapping, Optional, Type import torch from torch import nn -from torch.nn import functional as F # noqa: N812 -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_DROPOUT_HPO_RANGE from ...losses import BCEAfterSigmoidLoss, Loss -from ...nn import Embedding +from ...nn import EmbeddingSpecification from ...nn.init import xavier_normal_ -from ...regularizers import Regularizer +from ...nn.modules import ConvEInteraction from ...triples import TriplesFactory from ...typing import DeviceHint -from ...utils import is_cudnn_error __all__ = [ 'ConvE', @@ -28,56 +24,7 @@ logger = logging.getLogger(__name__) -def _calculate_missing_shape_information( - embedding_dim: int, - input_channels: Optional[int] = None, - width: Optional[int] = None, - height: Optional[int] = None, -) -> Tuple[int, int, int]: - """ - Automatically calculates missing dimensions for ConvE. - - :param embedding_dim: - :param input_channels: - :param width: - :param height: - - :return: (input_channels, width, height), such that - `embedding_dim = input_channels * width * height` - :raises: - If no factorization could be found. 
- """ - # Store initial input for error message - original = (input_channels, width, height) - - # All are None - if all(factor is None for factor in [input_channels, width, height]): - input_channels = 1 - result_sqrt = math.floor(math.sqrt(embedding_dim)) - height = max(factor for factor in range(1, result_sqrt + 1) if embedding_dim % factor == 0) - width = embedding_dim // height - - # input_channels is None, and any of height or width is None -> set input_channels=1 - if input_channels is None and any(remaining is None for remaining in [width, height]): - input_channels = 1 - - # input channels is not None, and one of height or width is None - assert len([factor for factor in [input_channels, width, height] if factor is None]) <= 1 - if width is None: - width = embedding_dim // (height * input_channels) - if height is None: - height = embedding_dim // (width * input_channels) - if input_channels is None: - input_channels = embedding_dim // (width * height) - assert not any(factor is None for factor in [input_channels, width, height]) - - if input_channels * width * height != embedding_dim: - raise ValueError(f'Could not resolve {original} to a valid factorization of {embedding_dim}.') - - return input_channels, width, height - - -class ConvE(EntityRelationEmbeddingModel): +class ConvE(ERModel): r"""An implementation of ConvE from [dettmers2018]_. ConvE is a CNN-based approach. For each triple $(h,r,t)$, the input to ConvE is a matrix @@ -159,12 +106,6 @@ class ConvE(EntityRelationEmbeddingModel): #: The default parameters for the default loss function class loss_default_kwargs: ClassVar[Mapping[str, Any]] = {} - #: If batch normalization is enabled, this is: num_features – C from an expected input of size (N,C,L) - bn0: Optional[torch.nn.BatchNorm2d] - #: If batch normalization is enabled, this is: num_features – C from an expected input of size (N,C,H,W) - bn1: Optional[torch.nn.BatchNorm2d] - bn2: Optional[torch.nn.BatchNorm2d] - def __init__( self, triples_factory: TriplesFactory, @@ -181,7 +122,6 @@ def __init__( loss: Optional[Loss] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, apply_batch_normalization: bool = True, ) -> None: """Initialize the model.""" @@ -192,239 +132,51 @@ def __init__( 'This can be done by defining the TriplesFactory class with the _create_inverse_triples_ parameter set ' 'to true.', ) - super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=ConvEInteraction( + input_channels=input_channels, + output_channels=output_channels, + embedding_height=embedding_height, + embedding_width=embedding_width, + kernel_height=kernel_height, + kernel_width=kernel_width, + input_dropout=input_dropout, + output_dropout=output_dropout, + feature_map_dropout=feature_map_dropout, + embedding_dim=embedding_dim, + apply_batch_normalization=apply_batch_normalization, + ), + entity_representations=[ + EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=xavier_normal_, + ), + # ConvE uses one bias for each entity + EmbeddingSpecification( + embedding_dim=1, + initializer=nn.init.zeros_, + ), + ], + relation_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=xavier_normal_, + ), loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - entity_initializer=xavier_normal_, - relation_initializer=xavier_normal_, - ) - - # ConvE uses one bias for each entity - self.bias_term = 
Embedding.init_with_device( - num_embeddings=triples_factory.num_entities, - embedding_dim=1, - device=self.device, - initializer=nn.init.zeros_, ) - # Automatic calculation of remaining dimensions - logger.info(f'Resolving {input_channels} * {embedding_width} * {embedding_height} = {embedding_dim}.') - if embedding_dim is None: - embedding_dim = input_channels * embedding_width * embedding_height - - # Parameter need to fulfil: - # input_channels * embedding_height * embedding_width = embedding_dim - input_channels, embedding_width, embedding_height = _calculate_missing_shape_information( - embedding_dim=embedding_dim, - input_channels=input_channels, - width=embedding_width, - height=embedding_height, - ) - logger.info(f'Resolved to {input_channels} * {embedding_width} * {embedding_height} = {embedding_dim}.') - self.embedding_height = embedding_height - self.embedding_width = embedding_width - self.input_channels = input_channels - - if self.input_channels * self.embedding_height * self.embedding_width != self.embedding_dim: - raise ValueError( - f'Product of input channels ({self.input_channels}), height ({self.embedding_height}), and width ' - f'({self.embedding_width}) does not equal target embedding dimension ({self.embedding_dim})', - ) - - self.inp_drop = nn.Dropout(input_dropout) - self.hidden_drop = nn.Dropout(output_dropout) - self.feature_map_drop = nn.Dropout2d(feature_map_dropout) - - self.conv1 = torch.nn.Conv2d( - in_channels=self.input_channels, - out_channels=output_channels, - kernel_size=(kernel_height, kernel_width), - stride=1, - padding=0, - bias=True, - ) - - self.apply_batch_normalization = apply_batch_normalization - if self.apply_batch_normalization: - self.bn0 = nn.BatchNorm2d(self.input_channels) - self.bn1 = nn.BatchNorm2d(output_channels) - self.bn2 = nn.BatchNorm1d(self.embedding_dim) - else: - self.bn0 = None - self.bn1 = None - self.bn2 = None - num_in_features = ( - output_channels - * (2 * self.embedding_height - kernel_height + 1) - * (self.embedding_width - kernel_width + 1) - ) - self.fc = nn.Linear(num_in_features, self.embedding_dim) - - def _reset_parameters_(self): # noqa: D102 - super()._reset_parameters_() - - self.bias_term.reset_parameters() - - # weights - for module in [ - self.conv1, - self.bn0, - self.bn1, - self.bn2, - self.fc, - ]: - if module is None: - continue - module.reset_parameters() - - def _convolve_entity_relation(self, h: torch.LongTensor, r: torch.LongTensor) -> torch.FloatTensor: - """Perform the main calculations of the ConvE model.""" - batch_size = h.shape[0] - - # batch_size, num_input_channels, 2*height, width - x = torch.cat([h, r], dim=2) - - try: - # batch_size, num_input_channels, 2*height, width - if self.apply_batch_normalization: - x = self.bn0(x) - - # batch_size, num_input_channels, 2*height, width - x = self.inp_drop(x) - # (N,C_out,H_out,W_out) - x = self.conv1(x) - - if self.apply_batch_normalization: - x = self.bn1(x) - x = F.relu(x) - x = self.feature_map_drop(x) - # batch_size, num_output_channels * (2 * height - kernel_height + 1) * (width - kernel_width + 1) - x = x.view(batch_size, -1) - x = self.fc(x) - x = self.hidden_drop(x) - - if self.apply_batch_normalization: - x = self.bn2(x) - x = F.relu(x) - except RuntimeError as e: - if not is_cudnn_error(e): - raise e - logger.warning( - '\nThis code crash might have been caused by a CUDA bug, see ' - 'https://github.com/allenai/allennlp/issues/2888, ' - 'which causes the code to crash during evaluation mode.\n' - 'To avoid this error, the batch size has 
to be reduced.\n' - f'The original error message: \n{e.args[0]}', - ) - sys.exit(1) - - return x - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - h = self.entity_embeddings(indices=hrt_batch[:, 0]).view( - -1, - self.input_channels, - self.embedding_height, - self.embedding_width, - ) - r = self.relation_embeddings(indices=hrt_batch[:, 1]).view( - -1, - self.input_channels, - self.embedding_height, - self.embedding_width, - ) - t = self.entity_embeddings(indices=hrt_batch[:, 2]) - - # Embedding Regularization - self.regularize_if_necessary(h, r, t) - - x = self._convolve_entity_relation(h, r) - - # For efficient calculation, each of the convolved [h, r] rows has only to be multiplied with one t row - x = (x.view(-1, self.embedding_dim) * t).sum(dim=1, keepdim=True) - - """ - In ConvE the bias term add the end is added for each tail item. In the sLCWA training approach we only have - one tail item for each head and relation. Accordingly the relevant bias for each tail item and triple has to be - looked up. - """ - x = x + self.bias_term(indices=hrt_batch[:, 2]) - # The application of the sigmoid during training is automatically handled by the default loss. - - return x - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - h = self.entity_embeddings(indices=hr_batch[:, 0]).view( - -1, - self.input_channels, - self.embedding_height, - self.embedding_width, - ) - r = self.relation_embeddings(indices=hr_batch[:, 1]).view( - -1, - self.input_channels, - self.embedding_height, - self.embedding_width, - ) - t = self.entity_embeddings(indices=None).transpose(1, 0) - - # Embedding Regularization - self.regularize_if_necessary(h, r, t) - - x = self._convolve_entity_relation(h, r) - - x = x @ t - x = x + self.bias_term(indices=None).t() - # The application of the sigmoid during training is automatically handled by the default loss. - - return x - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - rt_batch_size = rt_batch.shape[0] - h = self.entity_embeddings(indices=None) - r = self.relation_embeddings(indices=rt_batch[:, 0]).view( - -1, - self.input_channels, - self.embedding_height, - self.embedding_width, - ) - t = self.entity_embeddings(indices=rt_batch[:, 1]) - - # Embedding Regularization - self.regularize_if_necessary(h, r, t) - - ''' - Every head has to be convolved with every relation in the rt_batch. Hence we repeat the - relation _num_entities_ times and the head _rt_batch_size_ times. - ''' - r = r.repeat(h.shape[0], 1, 1, 1) - # Code to repeat each item successively instead of the entire tensor - h = h.unsqueeze(1).repeat(1, rt_batch_size, 1).view( - -1, - self.input_channels, - self.embedding_height, - self.embedding_width, - ) - - x = self._convolve_entity_relation(h, r) - - ''' - For efficient computation, each convolved [h, r] pair has only to be multiplied with the corresponding t - embedding found in the rt_batch with [r, t] pairs. - ''' - x = (x.view(self.num_entities, rt_batch_size, self.embedding_dim) * t[None, :, :]).sum(2).transpose(1, 0) - - """ - In ConvE the bias term at the end is added for each tail item. In the score_h function, each row holds - the same tail for many different heads, meaning that these items have to be looked up for each tail of each row - and only then can be added correctly. - """ - x = x + self.bias_term(indices=rt_batch[:, 1]) - # The application of the sigmoid during training is automatically handled by the default loss. 
- - return x + def forward( + self, + h_indices: Optional[torch.LongTensor], + r_indices: Optional[torch.LongTensor], + t_indices: Optional[torch.LongTensor], + slice_size: Optional[int] = None, + slice_dim: Optional[str] = None, + ) -> torch.FloatTensor: # noqa: D102 + h = self.entity_representations[0].get_in_canonical_shape(dim="h", indices=h_indices) + r = self.relation_representations[0].get_in_canonical_shape(dim="r", indices=r_indices) + t = self.entity_representations[0].get_in_canonical_shape(dim="t", indices=t_indices) + t_bias = self.entity_representations[1].get_in_canonical_shape(dim="t", indices=t_indices) + return self.interaction.score(h=h, r=r, t=(t, t_bias), slice_size=slice_size, slice_dim=slice_dim) diff --git a/src/pykeen/models/unimodal/conv_kb.py b/src/pykeen/models/unimodal/conv_kb.py index a01a226f0e..605ee00a21 100644 --- a/src/pykeen/models/unimodal/conv_kb.py +++ b/src/pykeen/models/unimodal/conv_kb.py @@ -5,13 +5,11 @@ import logging from typing import Any, ClassVar, Mapping, Optional, Type -import torch -import torch.autograd -from torch import nn - -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_DROPOUT_HPO_RANGE, DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss +from ...nn import EmbeddingSpecification +from ...nn.modules import ConvKBInteraction from ...regularizers import LpRegularizer, Regularizer from ...triples import TriplesFactory from ...typing import DeviceHint @@ -23,7 +21,7 @@ logger = logging.getLogger(__name__) -class ConvKB(EntityRelationEmbeddingModel): +class ConvKB(ERModel): r"""An implementation of ConvKB from [nguyen2018]_. ConvKB uses a convolutional neural network (CNN) whose feature maps capture global interactions of the input. @@ -78,10 +76,10 @@ def __init__( hidden_dropout_rate: float = 0., embedding_dim: int = 200, loss: Optional[Loss] = None, + regularizer: Optional[Regularizer] = None, preferred_device: DeviceHint = None, num_filters: int = 400, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, ) -> None: """Initialize the model. @@ -89,57 +87,28 @@ def __init__( """ super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=ConvKBInteraction( + hidden_dropout_rate=hidden_dropout_rate, + embedding_dim=embedding_dim, + num_filters=num_filters, + ), + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + ), + relation_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + ), loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, ) - - self.num_filters = num_filters - - # The interaction model - self.conv = nn.Conv2d(in_channels=1, out_channels=num_filters, kernel_size=(1, 3), bias=True) - self.relu = nn.ReLU() - self.hidden_dropout = nn.Dropout(p=hidden_dropout_rate) - self.linear = nn.Linear(embedding_dim * num_filters, 1, bias=True) - - def _reset_parameters_(self): # noqa: D102 - # embeddings - logger.warning('To be consistent with the paper, initialize entity and relation embeddings from TransE.') - super()._reset_parameters_() - - # Use Xavier initialization for weight; bias to zero - nn.init.xavier_uniform_(self.linear.weight, gain=nn.init.calculate_gain('relu')) - nn.init.zeros_(self.linear.bias) - - # Initialize all filters to [0.1, 0.1, -0.1], - # c.f. 
https://github.com/daiquocnguyen/ConvKB/blob/master/model.py#L34-L36 - nn.init.constant_(self.conv.weight[..., :2], 0.1) - nn.init.constant_(self.conv.weight[..., 2], -0.1) - nn.init.zeros_(self.conv.bias) - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - h = self.entity_embeddings(indices=hrt_batch[:, 0]) - r = self.relation_embeddings(indices=hrt_batch[:, 1]) - t = self.entity_embeddings(indices=hrt_batch[:, 2]) - - # Output layer regularization + if regularizer is None: + regularizer = self._instantiate_default_regularizer() # In the code base only the weights of the output layer are used for regularization # c.f. https://github.com/daiquocnguyen/ConvKB/blob/73a22bfa672f690e217b5c18536647c7cf5667f1/model.py#L60-L66 - self.regularize_if_necessary(self.linear.weight, self.linear.bias) - - # Stack to convolution input - conv_inp = torch.stack([h, r, t], dim=-1).view(-1, 1, self.embedding_dim, 3) - - # Convolution - conv_out = self.conv(conv_inp).view(-1, self.embedding_dim * self.num_filters) - hidden = self.relu(conv_out) - - # Apply dropout, cf. https://github.com/daiquocnguyen/ConvKB/blob/master/model.py#L54-L56 - hidden = self.hidden_dropout(hidden) - - # Linear layer for final scores - scores = self.linear(hidden) - - return scores + if regularizer is not None: + self.append_weight_regularizer( + parameter=self.interaction.parameters(), + regularizer=regularizer, + ) + logger.warning('To be consistent with the paper, initialize entity and relation embeddings from TransE.') diff --git a/src/pykeen/models/unimodal/distmult.py b/src/pykeen/models/unimodal/distmult.py index ae7d8e2834..5ca4faa38c 100644 --- a/src/pykeen/models/unimodal/distmult.py +++ b/src/pykeen/models/unimodal/distmult.py @@ -4,14 +4,14 @@ from typing import Any, ClassVar, Mapping, Optional, Type -import torch -import torch.autograd from torch import nn from torch.nn import functional -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss +from ...nn import EmbeddingSpecification +from ...nn.modules import DistMultInteraction from ...regularizers import LpRegularizer, Regularizer from ...triples import TriplesFactory from ...typing import DeviceHint @@ -22,7 +22,7 @@ ] -class DistMult(EntityRelationEmbeddingModel): +class DistMult(ERModel): r"""An implementation of DistMult from [yang2014]_. This model simplifies RESCAL by restricting matrices representing relations as diagonal matrices. @@ -73,97 +73,38 @@ def __init__( triples_factory: TriplesFactory, embedding_dim: int = 50, loss: Optional[Loss] = None, + regularizer: Optional[Regularizer] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, ) -> None: r"""Initialize DistMult. :param embedding_dim: The entity embedding dimension $d$. Is usually $d \in [50, 300]$. """ + if regularizer is None: + regularizer = self._instantiate_default_regularizer() super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=DistMultInteraction(), + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + # xavier uniform, cf. 
+ # https://github.com/thunlp/OpenKE/blob/adeed2c0d2bef939807ed4f69c1ea4db35fd149b/models/DistMult.py#L16-L17 + initializer=nn.init.xavier_uniform_, + # Constrain entity embeddings to unit length + constrainer=functional.normalize, + ), + relation_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + # relations are initialized to unit length (but not constraint) + initializer=compose( + nn.init.xavier_uniform_, + functional.normalize, + ), + # Only relation embeddings are regularized + regularizer=regularizer, + ), loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - # xavier uniform, cf. - # https://github.com/thunlp/OpenKE/blob/adeed2c0d2bef939807ed4f69c1ea4db35fd149b/models/DistMult.py#L16-L17 - entity_initializer=nn.init.xavier_uniform_, - # Constrain entity embeddings to unit length - entity_constrainer=functional.normalize, - # relations are initialized to unit length (but not constraint) - relation_initializer=compose( - nn.init.xavier_uniform_, - functional.normalize, - ), ) - - @staticmethod - def interaction_function( - h: torch.FloatTensor, - r: torch.FloatTensor, - t: torch.FloatTensor, - ) -> torch.FloatTensor: - """Evaluate the interaction function for given embeddings. - - The embeddings have to be in a broadcastable shape. - - WARNING: Does not ensure forward constraints. - - :param h: shape: (..., e) - Head embeddings. - :param r: shape: (..., e) - Relation embeddings. - :param t: shape: (..., e) - Tail embeddings. - - :return: shape: (...) - The scores. - """ - # Bilinear product - # *: Elementwise multiplication - return torch.sum(h * r * t, dim=-1) - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(hrt_batch[:, 0]) - r = self.relation_embeddings(hrt_batch[:, 1]) - t = self.entity_embeddings(hrt_batch[:, 2]) - - # Compute score - scores = self.interaction_function(h=h, r=r, t=t).view(-1, 1) - - # Only regularize relation embeddings - self.regularize_if_necessary(r) - - return scores - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hr_batch[:, 0]).view(-1, 1, self.embedding_dim) - r = self.relation_embeddings(indices=hr_batch[:, 1]).view(-1, 1, self.embedding_dim) - t = self.entity_embeddings(indices=None).view(1, -1, self.embedding_dim) - - # Rank against all entities - scores = self.interaction_function(h=h, r=r, t=t) - - # Only regularize relation embeddings - self.regularize_if_necessary(r) - - return scores - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=None).view(1, -1, self.embedding_dim) - r = self.relation_embeddings(indices=rt_batch[:, 0]).view(-1, 1, self.embedding_dim) - t = self.entity_embeddings(indices=rt_batch[:, 1]).view(-1, 1, self.embedding_dim) - - # Rank against all entities - scores = self.interaction_function(h=h, r=r, t=t) - - # Only regularize relation embeddings - self.regularize_if_necessary(r) - - return scores diff --git a/src/pykeen/models/unimodal/ermlp.py b/src/pykeen/models/unimodal/ermlp.py index 418d8ef2e7..91a499cb79 100644 --- a/src/pykeen/models/unimodal/ermlp.py +++ b/src/pykeen/models/unimodal/ermlp.py @@ -4,14 +4,11 @@ from typing import Any, ClassVar, Mapping, Optional -import torch -import torch.autograd -from torch import nn - -from ..base import EntityRelationEmbeddingModel +from ..base import 
ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss -from ...regularizers import Regularizer +from ...nn import EmbeddingSpecification +from ...nn.modules import ERMLPInteraction from ...triples import TriplesFactory from ...typing import DeviceHint @@ -20,7 +17,7 @@ ] -class ERMLP(EntityRelationEmbeddingModel): +class ERMLP(ERModel): r"""An implementation of ERMLP from [dong2014]_. ERMLP is a multi-layer perceptron based approach that uses a single hidden layer and represents entities and @@ -50,107 +47,24 @@ def __init__( preferred_device: DeviceHint = None, random_seed: Optional[int] = None, hidden_dim: Optional[int] = None, - regularizer: Optional[Regularizer] = None, ) -> None: - """Initialize the model.""" + """Initialize ERMLP.""" + if hidden_dim is None: + hidden_dim = embedding_dim + super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=ERMLPInteraction( + embedding_dim=embedding_dim, + hidden_dim=hidden_dim, + ), + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + ), + relation_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + ), loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - ) - - if hidden_dim is None: - hidden_dim = embedding_dim - self.hidden_dim = hidden_dim - """The multi-layer perceptron consisting of an input layer with 3 * self.embedding_dim neurons, a hidden layer - with self.embedding_dim neurons and output layer with one neuron. - The input is represented by the concatenation embeddings of the heads, relations and tail embeddings. - """ - self.linear1 = nn.Linear(3 * self.embedding_dim, self.hidden_dim) - self.linear2 = nn.Linear(self.hidden_dim, 1) - self.mlp = nn.Sequential( - self.linear1, - nn.ReLU(), - self.linear2, ) - - def _reset_parameters_(self): # noqa: D102 - # The authors do not specify which initialization was used. Hence, we use the pytorch default. 
- super()._reset_parameters_() - - # weight initialization - nn.init.zeros_(self.linear1.bias) - nn.init.xavier_uniform_(self.linear1.weight) - nn.init.zeros_(self.linear2.bias) - nn.init.xavier_uniform_(self.linear2.weight, gain=nn.init.calculate_gain('relu')) - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hrt_batch[:, 0]) - r = self.relation_embeddings(indices=hrt_batch[:, 1]) - t = self.entity_embeddings(indices=hrt_batch[:, 2]) - - # Embedding Regularization - self.regularize_if_necessary(h, r, t) - - # Concatenate them - x_s = torch.cat([h, r, t], dim=-1) - - # Compute scores - return self.mlp(x_s) - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hr_batch[:, 0]) - r = self.relation_embeddings(indices=hr_batch[:, 1]) - t = self.entity_embeddings(indices=None) - - # Embedding Regularization - self.regularize_if_necessary(h, r, t) - - # First layer can be unrolled - layers = self.mlp.children() - first_layer = next(layers) - w = first_layer.weight - i = 2 * self.embedding_dim - w_hr = w[None, :, :i] @ torch.cat([h, r], dim=-1).unsqueeze(-1) - w_t = w[None, :, i:] @ t.unsqueeze(-1) - b = first_layer.bias - scores = (b[None, None, :] + w_hr[:, None, :, 0]) + w_t[None, :, :, 0] - - # Send scores through rest of the network - scores = scores.view(-1, self.hidden_dim) - for remaining_layer in layers: - scores = remaining_layer(scores) - scores = scores.view(-1, self.num_entities) - return scores - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=None) - r = self.relation_embeddings(indices=rt_batch[:, 0]) - t = self.entity_embeddings(indices=rt_batch[:, 1]) - - # Embedding Regularization - self.regularize_if_necessary(h, r, t) - - # First layer can be unrolled - layers = self.mlp.children() - first_layer = next(layers) - w = first_layer.weight - i = self.embedding_dim - w_h = w[None, :, :i] @ h.unsqueeze(-1) - w_rt = w[None, :, i:] @ torch.cat([r, t], dim=-1).unsqueeze(-1) - b = first_layer.bias - scores = (b[None, None, :] + w_rt[:, None, :, 0]) + w_h[None, :, :, 0] - - # Send scores through rest of the network - scores = scores.view(-1, self.hidden_dim) - for remaining_layer in layers: - scores = remaining_layer(scores) - scores = scores.view(-1, self.num_entities) - - return scores diff --git a/src/pykeen/models/unimodal/ermlpe.py b/src/pykeen/models/unimodal/ermlpe.py index bbbc39c99e..39d05080be 100644 --- a/src/pykeen/models/unimodal/ermlpe.py +++ b/src/pykeen/models/unimodal/ermlpe.py @@ -4,13 +4,11 @@ from typing import Any, ClassVar, Mapping, Optional, Type -import torch -from torch import nn - -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_DROPOUT_HPO_RANGE, DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import BCEAfterSigmoidLoss, Loss -from ...regularizers import Regularizer +from ...nn import EmbeddingSpecification +from ...nn.modules import ERMLPEInteraction from ...triples import TriplesFactory from ...typing import DeviceHint @@ -19,7 +17,7 @@ ] -class ERMLPE(EntityRelationEmbeddingModel): +class ERMLPE(ERModel): r"""An extension of ERMLP proposed by [sharifzadeh2019]_. This model uses a neural network-based approach similar to ERMLP and with slight modifications. 
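Like ERMLP above, ERMLPE is reduced below to "representations plus an interaction module": the scoring network moves into an Interaction, and the model only wires embeddings to it. As a rough, self-contained illustration of that pattern (plain torch only; MLPInteraction and TinyModel are invented names for this sketch, not the pykeen API):

# Rough sketch of the "representations + interaction" pattern (illustration only).
import torch
from torch import nn

class MLPInteraction(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(3 * dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1))

    def forward(self, h, r, t):
        # concatenate head, relation, and tail representations and score them jointly
        return self.mlp(torch.cat([h, r, t], dim=-1))

class TinyModel(nn.Module):
    def __init__(self, num_entities: int, num_relations: int, dim: int):
        super().__init__()
        self.entity_representations = nn.Embedding(num_entities, dim)
        self.relation_representations = nn.Embedding(num_relations, dim)
        self.interaction = MLPInteraction(dim=dim, hidden_dim=dim)

    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:
        h = self.entity_representations(hrt_batch[:, 0])
        r = self.relation_representations(hrt_batch[:, 1])
        t = self.entity_representations(hrt_batch[:, 2])
        return self.interaction(h, r, t)

model = TinyModel(num_entities=5, num_relations=3, dim=8)
print(model.score_hrt(torch.tensor([[0, 1, 2], [3, 0, 4]])).shape)  # torch.Size([2, 1])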
@@ -64,118 +62,22 @@ def __init__( loss: Optional[Loss] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, ) -> None: super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=ERMLPEInteraction( + hidden_dim=hidden_dim, + input_dropout=input_dropout, + hidden_dropout=hidden_dropout, + embedding_dim=embedding_dim, + ), + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + ), + relation_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + ), loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - ) - self.hidden_dim = hidden_dim - self.input_dropout = input_dropout - - self.linear1 = nn.Linear(2 * self.embedding_dim, self.hidden_dim) - self.linear2 = nn.Linear(self.hidden_dim, self.embedding_dim) - self.input_dropout = nn.Dropout(self.input_dropout) - self.bn1 = nn.BatchNorm1d(self.hidden_dim) - self.bn2 = nn.BatchNorm1d(self.embedding_dim) - self.mlp = nn.Sequential( - self.linear1, - nn.Dropout(hidden_dropout), - self.bn1, - nn.ReLU(), - self.linear2, - nn.Dropout(hidden_dropout), - self.bn2, - nn.ReLU(), ) - - def _reset_parameters_(self): # noqa: D102 - super()._reset_parameters_() - for module in [ - self.linear1, - self.linear2, - self.bn1, - self.bn2, - ]: - module.reset_parameters() - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hrt_batch[:, 0]).view(-1, self.embedding_dim) - r = self.relation_embeddings(indices=hrt_batch[:, 1]).view(-1, self.embedding_dim) - t = self.entity_embeddings(indices=hrt_batch[:, 2]) - - # Embedding Regularization - self.regularize_if_necessary(h, r, t) - - # Concatenate them - x_s = torch.cat([h, r], dim=-1) - x_s = self.input_dropout(x_s) - - # Predict t embedding - x_t = self.mlp(x_s) - - # compare with all t's - # For efficient calculation, each of the calculated [h, r] rows has only to be multiplied with one t row - x = (x_t.view(-1, self.embedding_dim) * t).sum(dim=1, keepdim=True) - # The application of the sigmoid during training is automatically handled by the default loss. - - return x - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - h = self.entity_embeddings(indices=hr_batch[:, 0]).view(-1, self.embedding_dim) - r = self.relation_embeddings(indices=hr_batch[:, 1]).view(-1, self.embedding_dim) - t = self.entity_embeddings(indices=None).transpose(1, 0) - - # Embedding Regularization - self.regularize_if_necessary(h, r, t) - - # Concatenate them - x_s = torch.cat([h, r], dim=-1) - x_s = self.input_dropout(x_s) - - # Predict t embedding - x_t = self.mlp(x_s) - - x = x_t @ t - # The application of the sigmoid during training is automatically handled by the default loss. 
- - return x - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - h = self.entity_embeddings(indices=None) - r = self.relation_embeddings(indices=rt_batch[:, 0]).view(-1, self.embedding_dim) - t = self.entity_embeddings(indices=rt_batch[:, 1]).view(-1, self.embedding_dim) - - # Embedding Regularization - self.regularize_if_necessary(h, r, t) - - rt_batch_size = t.shape[0] - - # Extend each rt_batch of "r" with shape [rt_batch_size, dim] to [rt_batch_size, dim * num_entities] - r = torch.repeat_interleave(r, self.num_entities, dim=0) - # Extend each h with shape [num_entities, dim] to [rt_batch_size * num_entities, dim] - # h = torch.repeat_interleave(h, rt_batch_size, dim=0) - h = h.repeat(rt_batch_size, 1) - - # Extend t - t = t.repeat_interleave(self.num_entities, dim=0) - - # Concatenate them - x_s = torch.cat([h, r], dim=-1) - x_s = self.input_dropout(x_s) - - # Predict t embedding - x_t = self.mlp(x_s) - - # For efficient calculation, each of the calculated [h, r] rows has only to be multiplied with one t row - x = (x_t.view(-1, self.embedding_dim) * t).sum(dim=1, keepdim=True) - # The results have to be realigned with the expected output of the score_h function - x = x.view(rt_batch_size, self.num_entities) - # The application of the sigmoid during training is automatically handled by the default loss. - - return x diff --git a/src/pykeen/models/unimodal/hole.py b/src/pykeen/models/unimodal/hole.py index a48346000e..a80dbb4587 100644 --- a/src/pykeen/models/unimodal/hole.py +++ b/src/pykeen/models/unimodal/hole.py @@ -4,14 +4,12 @@ from typing import Any, ClassVar, Mapping, Optional -import torch -import torch.autograd - -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss +from ...nn import EmbeddingSpecification from ...nn.init import xavier_uniform_ -from ...regularizers import Regularizer +from ...nn.modules import HolEInteraction from ...triples import TriplesFactory from ...typing import DeviceHint from ...utils import clamp_norm @@ -21,7 +19,7 @@ ] -class HolE(EntityRelationEmbeddingModel): +class HolE(ERModel): r"""An implementation of HolE [nickel2016]_. Holographic embeddings (HolE) make use of the circular correlation operator to compute interactions between @@ -61,93 +59,23 @@ def __init__( loss: Optional[Loss] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, ) -> None: """Initialize the model.""" super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=HolEInteraction(), + # Initialization, cf. https://github.com/mnick/scikit-kge/blob/master/skge/param.py#L18-L27 + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=xavier_uniform_, + constrainer=clamp_norm, # type: ignore + constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), + ), + relation_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=xavier_uniform_, + ), loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - # Initialisation, cf. 
https://github.com/mnick/scikit-kge/blob/master/skge/param.py#L18-L27 - entity_initializer=xavier_uniform_, - relation_initializer=xavier_uniform_, - entity_constrainer=clamp_norm, - entity_constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), ) - - @staticmethod - def interaction_function( - h: torch.FloatTensor, - r: torch.FloatTensor, - t: torch.FloatTensor, - ) -> torch.FloatTensor: - """Evaluate the interaction function for given embeddings. - - The embeddings have to be in a broadcastable shape. - - :param h: shape: (batch_size, num_entities, d) - Head embeddings. - :param r: shape: (batch_size, num_entities, d) - Relation embeddings. - :param t: shape: (batch_size, num_entities, d) - Tail embeddings. - - :return: shape: (batch_size, num_entities) - The scores. - """ - # Circular correlation of entity embeddings - a_fft = torch.rfft(h, signal_ndim=1, onesided=True) - b_fft = torch.rfft(t, signal_ndim=1, onesided=True) - - # complex conjugate, a_fft.shape = (batch_size, num_entities, d', 2) - a_fft[:, :, :, 1] *= -1 - - # Hadamard product in frequency domain - p_fft = a_fft * b_fft - - # inverse real FFT, shape: (batch_size, num_entities, d) - composite = torch.irfft(p_fft, signal_ndim=1, onesided=True, signal_sizes=(h.shape[-1],)) - - # inner product with relation embedding - scores = torch.sum(r * composite, dim=-1, keepdim=False) - - return scores - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - h = self.entity_embeddings(indices=hrt_batch[:, 0]).unsqueeze(dim=1) - r = self.relation_embeddings(indices=hrt_batch[:, 1]).unsqueeze(dim=1) - t = self.entity_embeddings(indices=hrt_batch[:, 2]).unsqueeze(dim=1) - - # Embedding Regularization - self.regularize_if_necessary(h, r, t) - - scores = self.interaction_function(h=h, r=r, t=t).view(-1, 1) - - return scores - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - h = self.entity_embeddings(indices=hr_batch[:, 0]).unsqueeze(dim=1) - r = self.relation_embeddings(indices=hr_batch[:, 1]).unsqueeze(dim=1) - t = self.entity_embeddings(indices=None).unsqueeze(dim=0) - - # Embedding Regularization - self.regularize_if_necessary(h, r, t) - - scores = self.interaction_function(h=h, r=r, t=t) - - return scores - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - h = self.entity_embeddings(indices=None).unsqueeze(dim=0) - r = self.relation_embeddings(indices=rt_batch[:, 0]).unsqueeze(dim=1) - t = self.entity_embeddings(indices=rt_batch[:, 1]).unsqueeze(dim=1) - - # Embedding Regularization - self.regularize_if_necessary(h, r, t) - - scores = self.interaction_function(h=h, r=r, t=t) - - return scores diff --git a/src/pykeen/models/unimodal/kg2e.py b/src/pykeen/models/unimodal/kg2e.py index 50db0dea6a..f001b3621b 100644 --- a/src/pykeen/models/unimodal/kg2e.py +++ b/src/pykeen/models/unimodal/kg2e.py @@ -2,17 +2,16 @@ """Implementation of KG2E.""" -import math from typing import Any, ClassVar, Mapping, Optional import torch import torch.autograd -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss -from ...nn import Embedding -from ...regularizers import Regularizer +from ...nn import EmbeddingSpecification +from ...nn.modules import KG2EInteraction from ...triples import TriplesFactory from ...typing import DeviceHint from ...utils import clamp_norm @@ -21,10 +20,8 @@ 'KG2E', ] -_LOG_2_PI = math.log(2. 
* math.pi) - -class KG2E(EntityRelationEmbeddingModel): +class KG2E(ERModel): r"""An implementation of KG2E from [he2015]_. KG2E aims to explicitly model (un)certainties in entities and relations (e.g. influenced by the number of triples @@ -62,13 +59,12 @@ def __init__( self, triples_factory: TriplesFactory, embedding_dim: int = 50, - loss: Optional[Loss] = None, - preferred_device: DeviceHint = None, - random_seed: Optional[int] = None, dist_similarity: Optional[str] = None, c_min: float = 0.05, c_max: float = 5., - regularizer: Optional[Regularizer] = None, + loss: Optional[Loss] = None, + preferred_device: DeviceHint = None, + random_seed: Optional[int] = None, ) -> None: r"""Initialize KG2E. @@ -77,203 +73,31 @@ def __init__( :param c_min: :param c_max: """ + # Both, entities and relations, are represented as d-dimensional Normal distributions with diagonal covariance + # matrix + representation_spec = [ + # mean of Normal distribution + EmbeddingSpecification( + embedding_dim=embedding_dim, + constrainer=clamp_norm, # type: ignore + constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), + ), + # diagonal covariance of Normal distribution + # Ensure positive definite covariances matrices and appropriate size by clamping + EmbeddingSpecification( + embedding_dim=embedding_dim, + constrainer=torch.clamp, + constrainer_kwargs=dict(min=c_min, max=c_max), + ), + ] super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=KG2EInteraction( + similarity=dist_similarity, + ), + entity_representations=representation_spec, + relation_representations=representation_spec, loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - entity_constrainer=clamp_norm, - entity_constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), - relation_constrainer=clamp_norm, - relation_constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), - ) - - # Similarity function used for distributions - if dist_similarity is None or dist_similarity.upper() == 'KL': - self.similarity = self.kullback_leibler_similarity - elif dist_similarity.upper() == 'EL': - self.similarity = self.expected_likelihood - else: - raise ValueError(f'Unknown distribution similarity: "{dist_similarity}".') - - # element-wise covariance bounds - self.c_min = c_min - self.c_max = c_max - - # Additional covariance embeddings - self.entity_covariances = Embedding.init_with_device( - num_embeddings=triples_factory.num_entities, - embedding_dim=embedding_dim, - device=self.device, - # Ensure positive definite covariances matrices and appropriate size by clamping - constrainer=torch.clamp, - constrainer_kwargs=dict(min=self.c_min, max=self.c_max), - ) - self.relation_covariances = Embedding.init_with_device( - num_embeddings=triples_factory.num_relations, - embedding_dim=embedding_dim, - device=self.device, - # Ensure positive definite covariances matrices and appropriate size by clamping - constrainer=torch.clamp, - constrainer_kwargs=dict(min=self.c_min, max=self.c_max), ) - - def _reset_parameters_(self): # noqa: D102 - # Constraints are applied through post_parameter_update - super()._reset_parameters_() - for emb in [ - self.entity_covariances, - self.relation_covariances, - ]: - emb.reset_parameters() - - def post_parameter_update(self) -> None: # noqa: D102 - super().post_parameter_update() - for cov in ( - self.entity_covariances, - self.relation_covariances, - ): - cov.post_parameter_update() - - def _score( - self, - h_indices: Optional[torch.LongTensor] = None, - r_indices: 
Optional[torch.LongTensor] = None, - t_indices: Optional[torch.LongTensor] = None, - ) -> torch.FloatTensor: - """ - Compute scores for NTN. - - :param h_indices: shape: (batch_size,) - :param r_indices: shape: (batch_size,) - :param t_indices: shape: (batch_size,) - - :return: shape: (batch_size, num_entities) - """ - # Get embeddings - mu_h = self.entity_embeddings.get_in_canonical_shape(indices=h_indices) - mu_r = self.relation_embeddings.get_in_canonical_shape(indices=r_indices) - mu_t = self.entity_embeddings.get_in_canonical_shape(indices=t_indices) - - sigma_h = self.entity_covariances.get_in_canonical_shape(indices=h_indices) - sigma_r = self.relation_covariances.get_in_canonical_shape(indices=r_indices) - sigma_t = self.entity_covariances.get_in_canonical_shape(indices=t_indices) - - # Compute entity distribution - mu_e = mu_h - mu_t - sigma_e = sigma_h + sigma_t - return self.similarity(mu_e=mu_e, mu_r=mu_r, sigma_e=sigma_e, sigma_r=sigma_r) - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._score(h_indices=hrt_batch[:, 0], r_indices=hrt_batch[:, 1], t_indices=hrt_batch[:, 2]).view(-1, 1) - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._score(h_indices=hr_batch[:, 0], r_indices=hr_batch[:, 1]) - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._score(r_indices=rt_batch[:, 0], t_indices=rt_batch[:, 1]) - - @staticmethod - def expected_likelihood( - mu_e: torch.FloatTensor, - mu_r: torch.FloatTensor, - sigma_e: torch.FloatTensor, - sigma_r: torch.FloatTensor, - epsilon: float = 1.0e-10, - ) -> torch.FloatTensor: - r"""Compute the similarity based on expected likelihood. - - .. math:: - - D((\mu_e, \Sigma_e), (\mu_r, \Sigma_r))) - = \frac{1}{2} \left( - (\mu_e - \mu_r)^T(\Sigma_e + \Sigma_r)^{-1}(\mu_e - \mu_r) - + \log \det (\Sigma_e + \Sigma_r) + d \log (2 \pi) - \right) - = \frac{1}{2} \left( - \mu^T\Sigma^{-1}\mu - + \log \det \Sigma + d \log (2 \pi) - \right) - - :param mu_e: torch.Tensor, shape: (s_1, ..., s_k, d) - The mean of the first Gaussian. - :param mu_r: torch.Tensor, shape: (s_1, ..., s_k, d) - The mean of the second Gaussian. - :param sigma_e: torch.Tensor, shape: (s_1, ..., s_k, d) - The diagonal covariance matrix of the first Gaussian. - :param sigma_r: torch.Tensor, shape: (s_1, ..., s_k, d) - The diagonal covariance matrix of the second Gaussian. - :param epsilon: float (default=1.0) - Small constant used to avoid numerical issues when dividing. - - :return: torch.Tensor, shape: (s_1, ..., s_k) - The similarity. - """ - d = sigma_e.shape[-1] - sigma = sigma_r + sigma_e - mu = mu_e - mu_r - - #: a = \mu^T\Sigma^{-1}\mu - safe_sigma = torch.clamp_min(sigma, min=epsilon) - sigma_inv = torch.reciprocal(safe_sigma) - a = torch.sum(sigma_inv * mu ** 2, dim=-1) - - #: b = \log \det \Sigma - b = safe_sigma.log().sum(dim=-1) - return a + b + d * _LOG_2_PI - - @staticmethod - def kullback_leibler_similarity( - mu_e: torch.FloatTensor, - mu_r: torch.FloatTensor, - sigma_e: torch.FloatTensor, - sigma_r: torch.FloatTensor, - epsilon: float = 1.0e-10, - ) -> torch.FloatTensor: - r"""Compute the similarity based on KL divergence. - - This is done between two Gaussian distributions given by mean mu_* and diagonal covariance matrix sigma_*. - - .. 
math:: - - D((\mu_e, \Sigma_e), (\mu_r, \Sigma_r))) - = \frac{1}{2} \left( - tr(\Sigma_r^{-1}\Sigma_e) - + (\mu_r - \mu_e)^T\Sigma_r^{-1}(\mu_r - \mu_e) - - \log \frac{det(\Sigma_e)}{det(\Sigma_r)} - k_e - \right) - - Note: The sign of the function has been flipped as opposed to the description in the paper, as the - Kullback Leibler divergence is large if the distributions are dissimilar. - - :param mu_e: torch.Tensor, shape: (s_1, ..., s_k, d) - The mean of the first Gaussian. - :param mu_r: torch.Tensor, shape: (s_1, ..., s_k, d) - The mean of the second Gaussian. - :param sigma_e: torch.Tensor, shape: (s_1, ..., s_k, d) - The diagonal covariance matrix of the first Gaussian. - :param sigma_r: torch.Tensor, shape: (s_1, ..., s_k, d) - The diagonal covariance matrix of the second Gaussian. - :param epsilon: float (default=1.0) - Small constant used to avoid numerical issues when dividing. - - :return: torch.Tensor, shape: (s_1, ..., s_k) - The similarity. - """ - d = mu_e.shape[-1] - safe_sigma_r = torch.clamp_min(sigma_r, min=epsilon) - sigma_r_inv = torch.reciprocal(safe_sigma_r) - - #: a = tr(\Sigma_r^{-1}\Sigma_e) - a = torch.sum(sigma_e * sigma_r_inv, dim=-1) - - #: b = (\mu_r - \mu_e)^T\Sigma_r^{-1}(\mu_r - \mu_e) - mu = mu_r - mu_e - b = torch.sum(sigma_r_inv * mu ** 2, dim=-1) - - #: c = \log \frac{det(\Sigma_e)}{det(\Sigma_r)} - # = sum log (sigma_e)_i - sum log (sigma_r)_i - c = sigma_e.clamp_min(min=epsilon).log().sum(dim=-1) - safe_sigma_r.log().sum(dim=-1) - return -0.5 * (a + b - c - d) diff --git a/src/pykeen/models/unimodal/ntn.py b/src/pykeen/models/unimodal/ntn.py index 2511046c27..fb91085814 100644 --- a/src/pykeen/models/unimodal/ntn.py +++ b/src/pykeen/models/unimodal/ntn.py @@ -4,13 +4,13 @@ from typing import Any, ClassVar, Mapping, Optional -import torch from torch import nn -from ..base import EntityEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss -from ...regularizers import Regularizer +from ...nn import EmbeddingSpecification +from ...nn.modules import NTNInteraction from ...triples import TriplesFactory from ...typing import DeviceHint @@ -19,7 +19,7 @@ ] -class NTN(EntityEmbeddingModel): +class NTN(ERModel): r"""An implementation of NTN from [socher2013]_. NTN uses a bilinear tensor layer instead of a standard linear neural network layer: @@ -61,7 +61,6 @@ def __init__( preferred_device: DeviceHint = None, random_seed: Optional[int] = None, non_linearity: Optional[nn.Module] = None, - regularizer: Optional[Regularizer] = None, ) -> None: r"""Initialize NTN. 
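# --- Editor's note (illustrative sketch, not part of this diff) -----------------
# The KL-based similarity removed above, which the refactor delegates to
# KG2EInteraction, is the closed-form KL divergence between two diagonal
# Gaussians, negated so that more similar distributions score higher. A minimal
# stand-alone version, cross-checked against torch.distributions; tensor names
# mirror the docstring above and are otherwise arbitrary.
import torch
from torch.distributions import Independent, Normal, kl_divergence

def negative_kl_similarity(mu_e, mu_r, sigma_e, sigma_r, epsilon=1.0e-10):
    """Return -KL(N(mu_e, diag(sigma_e)) || N(mu_r, diag(sigma_r)))."""
    safe_sigma_r = sigma_r.clamp_min(epsilon)
    a = (sigma_e / safe_sigma_r).sum(dim=-1)                       # tr(S_r^-1 S_e)
    b = ((mu_r - mu_e) ** 2 / safe_sigma_r).sum(dim=-1)            # Mahalanobis term
    c = sigma_e.clamp_min(epsilon).log().sum(dim=-1) - safe_sigma_r.log().sum(dim=-1)
    d = mu_e.shape[-1]
    return -0.5 * (a + b - c - d)

mu_e, mu_r = torch.randn(3, 5), torch.randn(3, 5)
sigma_e, sigma_r = torch.rand(3, 5) + 0.1, torch.rand(3, 5) + 0.1  # diagonal covariances
expected = -kl_divergence(
    Independent(Normal(mu_e, sigma_e.sqrt()), 1),
    Independent(Normal(mu_r, sigma_r.sqrt()), 1),
)
assert torch.allclose(negative_kl_similarity(mu_e, mu_r, sigma_e, sigma_r), expected, atol=1e-4)
# --------------------------------------------------------------------------------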
@@ -72,173 +71,25 @@ def __init__( """ super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=NTNInteraction( + non_linearity=non_linearity, + ), + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + ), + relation_representations=[ + # w: (k, d, d) + EmbeddingSpecification(shape=(num_slices, embedding_dim, embedding_dim)), + # vh: (k, d) + EmbeddingSpecification(shape=(num_slices, embedding_dim)), + # vt: (k, d) + EmbeddingSpecification(shape=(num_slices, embedding_dim)), + # b: (k,) + EmbeddingSpecification(shape=(num_slices,)), + # u: (k,) + EmbeddingSpecification(shape=(num_slices,)), + ], loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, ) - self.num_slices = num_slices - - self.w = nn.Parameter(data=torch.empty( - triples_factory.num_relations, - num_slices, - embedding_dim, - embedding_dim, - device=self.device, - ), requires_grad=True) - self.vh = nn.Parameter(data=torch.empty( - triples_factory.num_relations, - num_slices, - embedding_dim, - device=self.device, - ), requires_grad=True) - self.vt = nn.Parameter(data=torch.empty( - triples_factory.num_relations, - num_slices, - embedding_dim, - device=self.device, - ), requires_grad=True) - self.b = nn.Parameter(data=torch.empty( - triples_factory.num_relations, - num_slices, - device=self.device, - ), requires_grad=True) - self.u = nn.Parameter(data=torch.empty( - triples_factory.num_relations, - num_slices, - device=self.device, - ), requires_grad=True) - if non_linearity is None: - non_linearity = nn.Tanh() - self.non_linearity = non_linearity - - def _reset_parameters_(self): # noqa: D102 - super()._reset_parameters_() - nn.init.normal_(self.w) - nn.init.normal_(self.vh) - nn.init.normal_(self.vt) - nn.init.normal_(self.b) - nn.init.normal_(self.u) - - def _score( - self, - h_indices: Optional[torch.LongTensor] = None, - r_indices: Optional[torch.LongTensor] = None, - t_indices: Optional[torch.LongTensor] = None, - slice_size: int = None, - ) -> torch.FloatTensor: - """ - Compute scores for NTN. 
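# --- Editor's note (illustrative sketch, not part of this diff) -----------------
# The bilinear-tensor score that NTNInteraction (used in the refactored
# constructor above) encapsulates, u^T f(h^T W^[1:k] t + V_h h + V_t t + b), can
# be restated compactly with einsum. Shapes and names below are illustrative
# only and independent of the removed implementation that follows.
import torch

def ntn_score(h, t, w, vh, vt, b, u, activation=torch.tanh):
    """h, t: (B, d); w: (B, k, d, d); vh, vt: (B, k, d); b, u: (B, k) -> (B,)."""
    bilinear = torch.einsum("bd,bkde,be->bk", h, w, t)   # h^T W^[i] t per slice i
    linear = torch.einsum("bkd,bd->bk", vh, h) + torch.einsum("bkd,bd->bk", vt, t)
    return torch.einsum("bk,bk->b", u, activation(bilinear + linear + b))

B, k, d = 4, 3, 8
h, t = torch.randn(B, d), torch.randn(B, d)
w, vh, vt = torch.randn(B, k, d, d), torch.randn(B, k, d), torch.randn(B, k, d)
b, u = torch.randn(B, k), torch.randn(B, k)
print(ntn_score(h, t, w, vh, vt, b, u).shape)            # torch.Size([4])
# --------------------------------------------------------------------------------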
- - :param h_indices: shape: (batch_size,) - :param r_indices: shape: (batch_size,) - :param t_indices: shape: (batch_size,) - - :return: shape: (batch_size, num_entities) - """ - assert r_indices is not None - - #: shape: (batch_size, num_entities, d) - h_all = self.entity_embeddings.get_in_canonical_shape(indices=h_indices) - t_all = self.entity_embeddings.get_in_canonical_shape(indices=t_indices) - - if slice_size is None: - return self._interaction_function(h=h_all, t=t_all, r_indices=r_indices) - - if h_all.shape[1] > t_all.shape[1]: - h_was_split = True - split_tensor = torch.split(h_all, slice_size, dim=1) - constant_tensor = t_all - else: - h_was_split = False - split_tensor = torch.split(t_all, slice_size, dim=1) - constant_tensor = h_all - - scores_arr = [] - for split in split_tensor: - if h_was_split: - h = split - t = constant_tensor - else: - h = constant_tensor - t = split - score = self._interaction_function(h=h, t=t, r_indices=r_indices) - scores_arr.append(score) - - return torch.cat(scores_arr, dim=1) - - def _interaction_function( - self, - h: torch.FloatTensor, - t: torch.FloatTensor, - r_indices: Optional[torch.LongTensor] = None, - ) -> torch.FloatTensor: - #: Prepare h: (b, e, d) -> (b, e, 1, 1, d) - h_for_w = h.unsqueeze(dim=-2).unsqueeze(dim=-2) - - #: Prepare t: (b, e, d) -> (b, e, 1, d, 1) - t_for_w = t.unsqueeze(dim=-2).unsqueeze(dim=-1) - - #: Prepare w: (R, k, d, d) -> (b, k, d, d) -> (b, 1, k, d, d) - w_r = self.w.index_select(dim=0, index=r_indices).unsqueeze(dim=1) - - # h.T @ W @ t, shape: (b, e, k, 1, 1) - hwt = (h_for_w @ w_r @ t_for_w) - - #: reduce (b, e, k, 1, 1) -> (b, e, k) - hwt = hwt.squeeze(dim=-1).squeeze(dim=-1) - - #: Prepare vh: (R, k, d) -> (b, k, d) -> (b, 1, k, d) - vh_r = self.vh.index_select(dim=0, index=r_indices).unsqueeze(dim=1) - - #: Prepare h: (b, e, d) -> (b, e, d, 1) - h_for_v = h.unsqueeze(dim=-1) - - # V_h @ h, shape: (b, e, k, 1) - vhh = vh_r @ h_for_v - - #: reduce (b, e, k, 1) -> (b, e, k) - vhh = vhh.squeeze(dim=-1) - - #: Prepare vt: (R, k, d) -> (b, k, d) -> (b, 1, k, d) - vt_r = self.vt.index_select(dim=0, index=r_indices).unsqueeze(dim=1) - - #: Prepare t: (b, e, d) -> (b, e, d, 1) - t_for_v = t.unsqueeze(dim=-1) - - # V_t @ t, shape: (b, e, k, 1) - vtt = vt_r @ t_for_v - - #: reduce (b, e, k, 1) -> (b, e, k) - vtt = vtt.squeeze(dim=-1) - - #: Prepare b: (R, k) -> (b, k) -> (b, 1, k) - b = self.b.index_select(dim=0, index=r_indices).unsqueeze(dim=1) - - # a = f(h.T @ W @ t + Vh @ h + Vt @ t + b), shape: (b, e, k) - pre_act = hwt + vhh + vtt + b - act = self.non_linearity(pre_act) - - # prepare u: (R, k) -> (b, k) -> (b, 1, k, 1) - u = self.u.index_select(dim=0, index=r_indices).unsqueeze(dim=1).unsqueeze(dim=-1) - - # prepare act: (b, e, k) -> (b, e, 1, k) - act = act.unsqueeze(dim=-2) - - # compute score, shape: (b, e, 1, 1) - score = act @ u - - # reduce - score = score.squeeze(dim=-1).squeeze(dim=-1) - - return score - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._score(h_indices=hrt_batch[:, 0], r_indices=hrt_batch[:, 1], t_indices=hrt_batch[:, 2]) - - def score_t(self, hr_batch: torch.LongTensor, slice_size: int = None) -> torch.FloatTensor: # noqa: D102 - return self._score(h_indices=hr_batch[:, 0], r_indices=hr_batch[:, 1], slice_size=slice_size) - - def score_h(self, rt_batch: torch.LongTensor, slice_size: int = None) -> torch.FloatTensor: # noqa: D102 - return self._score(r_indices=rt_batch[:, 0], t_indices=rt_batch[:, 1], slice_size=slice_size) diff --git 
a/src/pykeen/models/unimodal/proj_e.py b/src/pykeen/models/unimodal/proj_e.py index 734997d4b8..c33881a6a6 100644 --- a/src/pykeen/models/unimodal/proj_e.py +++ b/src/pykeen/models/unimodal/proj_e.py @@ -4,16 +4,14 @@ from typing import Any, ClassVar, Mapping, Optional, Type -import numpy -import torch -import torch.autograd from torch import nn -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import BCEWithLogitsLoss, Loss +from ...nn import EmbeddingSpecification from ...nn.init import xavier_uniform_ -from ...regularizers import Regularizer +from ...nn.modules import ProjEInteraction from ...triples import TriplesFactory from ...typing import DeviceHint @@ -22,7 +20,7 @@ ] -class ProjE(EntityRelationEmbeddingModel): +class ProjE(ERModel): r"""An implementation of ProjE from [shi2017]_. ProjE is a neural network-based approach with a *combination* and a *projection* layer. The interaction model @@ -54,88 +52,36 @@ class ProjE(EntityRelationEmbeddingModel): #: The default loss function class loss_default: ClassVar[Type[Loss]] = BCEWithLogitsLoss #: The default parameters for the default loss function class - loss_default_kwargs = dict(reduction='mean') + loss_default_kwargs: ClassVar[Mapping[str, Any]] = dict(reduction='mean') def __init__( self, triples_factory: TriplesFactory, + # ProjE parameters embedding_dim: int = 50, + inner_non_linearity: Optional[nn.Module] = None, + # Loss loss: Optional[Loss] = None, + # Model parameters preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - inner_non_linearity: Optional[nn.Module] = None, - regularizer: Optional[Regularizer] = None, ) -> None: + """Initialize :class:`ERModel` using :class:`ProjEInteraction`.""" super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=ProjEInteraction( + embedding_dim=embedding_dim, + inner_non_linearity=inner_non_linearity, + ), + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=xavier_uniform_, + ), + relation_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=xavier_uniform_, + ), loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - entity_initializer=xavier_uniform_, - relation_initializer=xavier_uniform_, - ) - - # Global entity projection - self.d_e = nn.Parameter(torch.empty(self.embedding_dim, device=self.device), requires_grad=True) - - # Global relation projection - self.d_r = nn.Parameter(torch.empty(self.embedding_dim, device=self.device), requires_grad=True) - - # Global combination bias - self.b_c = nn.Parameter(torch.empty(self.embedding_dim, device=self.device), requires_grad=True) - - # Global combination bias - self.b_p = nn.Parameter(torch.empty(1, device=self.device), requires_grad=True) - - if inner_non_linearity is None: - inner_non_linearity = nn.Tanh() - self.inner_non_linearity = inner_non_linearity - - def _reset_parameters_(self): # noqa: D102 - super()._reset_parameters_() - bound = numpy.sqrt(6) / self.embedding_dim - nn.init.uniform_(self.d_e, a=-bound, b=bound) - nn.init.uniform_(self.d_r, a=-bound, b=bound) - nn.init.uniform_(self.b_c, a=-bound, b=bound) - nn.init.uniform_(self.b_p, a=-bound, b=bound) - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hrt_batch[:, 0]) - r = 
self.relation_embeddings(indices=hrt_batch[:, 1]) - t = self.entity_embeddings(indices=hrt_batch[:, 2]) - - # Compute score - hidden = self.inner_non_linearity(self.d_e[None, :] * h + self.d_r[None, :] * r + self.b_c[None, :]) - scores = torch.sum(hidden * t, dim=-1, keepdim=True) + self.b_p - - return scores - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hr_batch[:, 0]) - r = self.relation_embeddings(indices=hr_batch[:, 1]) - t = self.entity_embeddings(indices=None) - - # Rank against all entities - hidden = self.inner_non_linearity(self.d_e[None, :] * h + self.d_r[None, :] * r + self.b_c[None, :]) - scores = torch.sum(hidden[:, None, :] * t[None, :, :], dim=-1) + self.b_p - - return scores - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=None) - r = self.relation_embeddings(indices=rt_batch[:, 0]) - t = self.entity_embeddings(indices=rt_batch[:, 1]) - - # Rank against all entities - hidden = self.inner_non_linearity( - self.d_e[None, None, :] * h[None, :, :] - + (self.d_r[None, None, :] * r[:, None, :] + self.b_c[None, None, :]), ) - scores = torch.sum(hidden * t[:, None, :], dim=-1) + self.b_p - - return scores diff --git a/src/pykeen/models/unimodal/rescal.py b/src/pykeen/models/unimodal/rescal.py index e62c1fe9aa..aba19774b2 100644 --- a/src/pykeen/models/unimodal/rescal.py +++ b/src/pykeen/models/unimodal/rescal.py @@ -4,11 +4,11 @@ from typing import Any, ClassVar, Mapping, Optional, Type -import torch - -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss +from ...nn import EmbeddingSpecification +from ...nn.modules import RESCALInteraction from ...regularizers import LpRegularizer, Regularizer from ...triples import TriplesFactory from ...typing import DeviceHint @@ -18,7 +18,7 @@ ] -class RESCAL(EntityRelationEmbeddingModel): +class RESCAL(ERModel): r"""An implementation of RESCAL from [nickel2011]_. This model represents relations as matrices and models interactions between latent features. @@ -56,9 +56,9 @@ def __init__( triples_factory: TriplesFactory, embedding_dim: int = 50, loss: Optional[Loss] = None, + regularizer: Optional[Regularizer] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, ) -> None: r"""Initialize RESCAL. 
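# --- Editor's note (illustrative sketch, not part of this diff) -----------------
# RESCAL's score, which the refactored constructor below delegates to
# RESCALInteraction, is the bilinear form h^T W_r t with a d x d relation matrix.
# A self-contained restatement for a batch of triples (names are illustrative),
# checked against the batched matrix-product form used by the code removed below.
import torch

def rescal_score(h, w_r, t):
    """h, t: (B, d); w_r: (B, d, d) -> (B,); score = h^T W_r t."""
    return torch.einsum("bd,bde,be->b", h, w_r, t)

B, d = 4, 6
h, t, w_r = torch.randn(B, d), torch.randn(B, d), torch.randn(B, d, d)
scores = rescal_score(h, w_r, t)
assert torch.allclose(scores, (h.unsqueeze(1) @ w_r @ t.unsqueeze(-1)).view(-1), atol=1e-5)
# --------------------------------------------------------------------------------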
@@ -68,57 +68,20 @@ def __init__( - OpenKE `implementation of RESCAL `_ """ + if regularizer is None: + regularizer = self._instantiate_default_regularizer() super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, - relation_dim=embedding_dim ** 2, # d x d matrices + interaction=RESCALInteraction(), + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + regularizer=regularizer, + ), + relation_representations=EmbeddingSpecification( + shape=(embedding_dim, embedding_dim), + regularizer=regularizer, + ), loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, ) - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - # shape: (b, d) - h = self.entity_embeddings(indices=hrt_batch[:, 0]).view(-1, 1, self.embedding_dim) - # shape: (b, d, d) - r = self.relation_embeddings(indices=hrt_batch[:, 1]).view(-1, self.embedding_dim, self.embedding_dim) - # shape: (b, d) - t = self.entity_embeddings(indices=hrt_batch[:, 2]).view(-1, self.embedding_dim, 1) - - # Compute scores - scores = h @ r @ t - - # Regularization - self.regularize_if_necessary(h, r, t) - - return scores[:, :, 0] - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - h = self.entity_embeddings(indices=hr_batch[:, 0]).view(-1, 1, self.embedding_dim) - r = self.relation_embeddings(indices=hr_batch[:, 1]).view(-1, self.embedding_dim, self.embedding_dim) - t = self.entity_embeddings(indices=None).transpose(0, 1).view(1, self.embedding_dim, self.num_entities) - - # Compute scores - scores = h @ r @ t - - # Regularization - self.regularize_if_necessary(h, r, t) - - return scores[:, 0, :] - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - """Forward pass using left side (head) prediction.""" - # Get embeddings - h = self.entity_embeddings(indices=None).view(1, self.num_entities, self.embedding_dim) - r = self.relation_embeddings(indices=rt_batch[:, 0]).view(-1, self.embedding_dim, self.embedding_dim) - t = self.entity_embeddings(indices=rt_batch[:, 1]).view(-1, self.embedding_dim, 1) - - # Compute scores - scores = h @ r @ t - - # Regularization - self.regularize_if_necessary(h, r, t) - - return scores[:, :, 0] diff --git a/src/pykeen/models/unimodal/rgcn.py b/src/pykeen/models/unimodal/rgcn.py index 455463fe10..b8436e92e2 100644 --- a/src/pykeen/models/unimodal/rgcn.py +++ b/src/pykeen/models/unimodal/rgcn.py @@ -2,20 +2,21 @@ """Implementation of the R-GCN model.""" -import logging -from os import path from typing import Any, Callable, ClassVar, Mapping, Optional, Type import torch from torch import nn -from torch.nn import functional from . import ComplEx, DistMult, ERMLP -from .. 
import EntityEmbeddingModel -from ..base import Model +from ..base import ERModel from ...constants import DEFAULT_DROPOUT_HPO_RANGE from ...losses import Loss -from ...nn import Embedding, RepresentationModule +from ...nn import EmbeddingSpecification, Interaction +from ...nn.modules import DistMultInteraction +from ...nn.representation import ( + RGCNRepresentations, inverse_indegree_edge_weights, inverse_outdegree_edge_weights, + symmetric_edge_weights, +) from ...triples import TriplesFactory from ...typing import DeviceHint @@ -23,409 +24,8 @@ 'RGCN', ] -logger = logging.getLogger(name=path.basename(__file__)) - -def _get_neighborhood( - start_nodes: torch.LongTensor, - sources: torch.LongTensor, - targets: torch.LongTensor, - k: int, - num_nodes: int, - undirected: bool = False, -) -> torch.BoolTensor: - # Construct node neighbourhood mask - node_mask = torch.zeros(num_nodes, device=start_nodes.device, dtype=torch.bool) - - # Set nodes in batch to true - node_mask[start_nodes] = True - - # Compute k-neighbourhood - for _ in range(k): - # if the target node needs an embeddings, so does the source node - node_mask[sources] |= node_mask[targets] - - if undirected: - node_mask[targets] |= node_mask[sources] - - # Create edge mask - edge_mask = node_mask[targets] - - if undirected: - edge_mask |= node_mask[sources] - - return edge_mask - - -# pylint: disable=unused-argument -def inverse_indegree_edge_weights(source: torch.LongTensor, target: torch.LongTensor) -> torch.FloatTensor: - """Normalize messages by inverse in-degree. - - :param source: shape: (num_edges,) - The source indices. - :param target: shape: (num_edges,) - The target indices. - - :return: shape: (num_edges,) - The edge weights. - """ - # Calculate in-degree, i.e. number of incoming edges - uniq, inv, cnt = torch.unique(target, return_counts=True, return_inverse=True) - return cnt[inv].float().reciprocal() - - -# pylint: disable=unused-argument -def inverse_outdegree_edge_weights(source: torch.LongTensor, target: torch.LongTensor) -> torch.FloatTensor: - """Normalize messages by inverse out-degree. - - :param source: shape: (num_edges,) - The source indices. - :param target: shape: (num_edges,) - The target indices. - - :return: shape: (num_edges,) - The edge weights. - """ - # Calculate in-degree, i.e. number of incoming edges - uniq, inv, cnt = torch.unique(source, return_counts=True, return_inverse=True) - return cnt[inv].float().reciprocal() - - -def symmetric_edge_weights(source: torch.LongTensor, target: torch.LongTensor) -> torch.FloatTensor: - """Normalize messages by product of inverse sqrt of in-degree and out-degree. - - :param source: shape: (num_edges,) - The source indices. - :param target: shape: (num_edges,) - The target indices. - - :return: shape: (num_edges,) - The edge weights. 
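# --- Editor's note (illustrative sketch, not part of this diff) -----------------
# The degree-based edge weightings that move into pykeen.nn.representation assign
# each edge the reciprocal of its target's in-degree (or of its source's
# out-degree, or the square root of the product of both). A tiny stand-alone
# check of the in-degree variant on a toy edge list (values chosen arbitrarily):
import torch

target = torch.tensor([0, 0, 1, 2, 2, 2])        # two edges point at node 0, three at node 2
_, inverse, counts = torch.unique(target, return_inverse=True, return_counts=True)
weights = counts[inverse].float().reciprocal()   # one weight per edge
print(weights)                                   # tensor([0.5000, 0.5000, 1.0000, 0.3333, 0.3333, 0.3333])
# --------------------------------------------------------------------------------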
- """ - return ( - inverse_indegree_edge_weights(source=source, target=target) - * inverse_outdegree_edge_weights(source=source, target=target) - ).sqrt() - - -class RGCNRepresentations(RepresentationModule): - """Representations enriched by R-GCN.""" - - def __init__( - self, - triples_factory: TriplesFactory, - embedding_dim: int = 500, - num_bases_or_blocks: int = 5, - num_layers: int = 2, - use_bias: bool = True, - use_batch_norm: bool = False, - activation_cls: Optional[Type[nn.Module]] = None, - activation_kwargs: Optional[Mapping[str, Any]] = None, - sparse_messages_slcwa: bool = True, - edge_dropout: float = 0.4, - self_loop_dropout: float = 0.2, - edge_weighting: Callable[ - [torch.LongTensor, torch.LongTensor], - torch.FloatTensor, - ] = inverse_indegree_edge_weights, - decomposition: str = 'basis', - buffer_messages: bool = True, - base_representations: Optional[RepresentationModule] = None, - ): - super().__init__() - - self.triples_factory = triples_factory - - # normalize representations - if base_representations is None: - base_representations = Embedding( - num_embeddings=triples_factory.num_entities, - embedding_dim=embedding_dim, - # https://github.com/MichSchli/RelationPrediction/blob/c77b094fe5c17685ed138dae9ae49b304e0d8d89/code/encoders/affine_transform.py#L24-L28 - initializer=nn.init.xavier_uniform_, - ) - self.base_embeddings = base_representations - self.embedding_dim = embedding_dim - - # check decomposition - self.decomposition = decomposition - if self.decomposition == 'basis': - if num_bases_or_blocks is None: - logging.info('Using a heuristic to determine the number of bases.') - num_bases_or_blocks = triples_factory.num_relations // 2 + 1 - if num_bases_or_blocks > triples_factory.num_relations: - raise ValueError('The number of bases should not exceed the number of relations.') - elif self.decomposition == 'block': - if num_bases_or_blocks is None: - logging.info('Using a heuristic to determine the number of blocks.') - num_bases_or_blocks = 2 - if embedding_dim % num_bases_or_blocks != 0: - raise ValueError( - 'With block decomposition, the embedding dimension has to be divisible by the number of' - f' blocks, but {embedding_dim} % {num_bases_or_blocks} != 0.', - ) - else: - raise ValueError(f'Unknown decomposition: "{decomposition}". 
Please use either "basis" or "block".') - - self.num_bases = num_bases_or_blocks - self.edge_weighting = edge_weighting - self.edge_dropout = edge_dropout - if self_loop_dropout is None: - self_loop_dropout = edge_dropout - self.self_loop_dropout = self_loop_dropout - self.use_batch_norm = use_batch_norm - if activation_cls is None: - activation_cls = nn.ReLU - self.activation_cls = activation_cls - self.activation_kwargs = activation_kwargs - if use_batch_norm: - if use_bias: - logger.warning('Disabling bias because batch normalization was used.') - use_bias = False - self.use_bias = use_bias - self.num_layers = num_layers - self.sparse_messages_slcwa = sparse_messages_slcwa - - # Save graph using buffers, such that the tensors are moved together with the model - h, r, t = self.triples_factory.mapped_triples.t() - self.register_buffer('sources', h) - self.register_buffer('targets', t) - self.register_buffer('edge_types', r) - - self.activations = nn.ModuleList([ - self.activation_cls(**(self.activation_kwargs or {})) for _ in range(self.num_layers) - ]) - - # Weights - self.bases = nn.ParameterList() - if self.decomposition == 'basis': - self.att = nn.ParameterList() - for _ in range(self.num_layers): - self.bases.append(nn.Parameter( - data=torch.empty( - self.num_bases, - self.embedding_dim, - self.embedding_dim, - ), - requires_grad=True, - )) - self.att.append(nn.Parameter( - data=torch.empty( - self.triples_factory.num_relations + 1, - self.num_bases, - ), - requires_grad=True, - )) - elif self.decomposition == 'block': - block_size = self.embedding_dim // self.num_bases - for _ in range(self.num_layers): - self.bases.append(nn.Parameter( - data=torch.empty( - self.triples_factory.num_relations + 1, - self.num_bases, - block_size, - block_size, - ), - requires_grad=True, - )) - - self.att = None - else: - raise NotImplementedError - if self.use_bias: - self.biases = nn.ParameterList([ - nn.Parameter(torch.empty(self.embedding_dim), requires_grad=True) - for _ in range(self.num_layers) - ]) - else: - self.biases = None - if self.use_batch_norm: - self.batch_norms = nn.ModuleList([ - nn.BatchNorm1d(num_features=self.embedding_dim) - for _ in range(self.num_layers) - ]) - else: - self.batch_norms = None - - # buffering of messages - self.buffer_messages = buffer_messages - self.enriched_embeddings = None - - def _get_relation_weights(self, i_layer: int, r: int) -> torch.FloatTensor: - if self.decomposition == 'block': - # allocate weight - w = torch.zeros(self.embedding_dim, self.embedding_dim, device=self.bases[i_layer].device) - - # Get blocks - this_layer_blocks = self.bases[i_layer] - - # self.bases[i_layer].shape (num_relations, num_blocks, embedding_dim/num_blocks, embedding_dim/num_blocks) - # note: embedding_dim is guaranteed to be divisible by num_bases in the constructor - block_size = self.embedding_dim // self.num_bases - for b, start in enumerate(range(0, self.embedding_dim, block_size)): - stop = start + block_size - w[start:stop, start:stop] = this_layer_blocks[r, b, :, :] - - elif self.decomposition == 'basis': - # The current basis weights, shape: (num_bases) - att = self.att[i_layer][r, :] - # the current bases, shape: (num_bases, embedding_dim, embedding_dim) - b = self.bases[i_layer] - # compute the current relation weights, shape: (embedding_dim, embedding_dim) - w = torch.sum(att[:, None, None] * b, dim=0) - - else: - raise AssertionError(f'Unknown decomposition: {self.decomposition}') - - return w - - def forward( - self, - indices: Optional[torch.LongTensor] = 
None, - ) -> torch.FloatTensor: - # use buffered messages if applicable - if indices is None and self.enriched_embeddings is not None: - return self.enriched_embeddings - - # Bind fields - # shape: (num_entities, embedding_dim) - x = self.base_embeddings(indices=None) - sources = self.sources - targets = self.targets - edge_types = self.edge_types - - # Edge dropout: drop the same edges on all layers (only in training mode) - if self.training and self.edge_dropout is not None: - # Get random dropout mask - edge_keep_mask = torch.rand(self.sources.shape[0], device=x.device) > self.edge_dropout - - # Apply to edges - sources = sources[edge_keep_mask] - targets = targets[edge_keep_mask] - edge_types = edge_types[edge_keep_mask] - - # Different dropout for self-loops (only in training mode) - if self.training and self.self_loop_dropout is not None: - node_keep_mask = torch.rand(self.triples_factory.num_entities, device=x.device) > self.self_loop_dropout - else: - node_keep_mask = None - - for i in range(self.num_layers): - # Initialize embeddings in the next layer for all nodes - new_x = torch.zeros_like(x) - - # TODO: Can we vectorize this loop? - for r in range(self.triples_factory.num_relations): - # Choose the edges which are of the specific relation - mask = (edge_types == r) - - # No edges available? Skip rest of inner loop - if not mask.any(): - continue - - # Get source and target node indices - sources_r = sources[mask] - targets_r = targets[mask] - - # send messages in both directions - sources_r, targets_r = torch.cat([sources_r, targets_r]), torch.cat([targets_r, sources_r]) - - # Select source node embeddings - x_s = x[sources_r] - - # get relation weights - w = self._get_relation_weights(i_layer=i, r=r) - - # Compute message (b x d) * (d x d) = (b x d) - m_r = x_s @ w - - # Normalize messages by relation-specific in-degree - if self.edge_weighting is not None: - m_r *= self.edge_weighting(source=sources_r, target=targets_r).unsqueeze(dim=-1) - - # Aggregate messages in target - new_x.index_add_(dim=0, index=targets_r, source=m_r) - - # Self-loop - self_w = self._get_relation_weights(i_layer=i, r=self.triples_factory.num_relations) - if node_keep_mask is None: - new_x += x @ self_w - else: - new_x[node_keep_mask] += x[node_keep_mask] @ self_w - - # Apply bias, if requested - if self.use_bias: - bias = self.biases[i] - new_x += bias - - # Apply batch normalization, if requested - if self.use_batch_norm: - batch_norm = self.batch_norms[i] - new_x = batch_norm(new_x) - - # Apply non-linearity - if self.activations is not None: - activation = self.activations[i] - new_x = activation(new_x) - - x = new_x - - if indices is None and self.buffer_messages: - self.enriched_embeddings = x - if indices is not None: - x = x[indices] - - return x - - def post_parameter_update(self) -> None: # noqa: D102 - super().post_parameter_update() - - # invalidate enriched embeddings - self.enriched_embeddings = None - - def reset_parameters(self): - self.base_embeddings.reset_parameters() - - gain = nn.init.calculate_gain(nonlinearity=self.activation_cls.__name__.lower()) - if self.decomposition == 'basis': - for base in self.bases: - nn.init.xavier_normal_(base, gain=gain) - for att in self.att: - # Random convex-combination of bases for initialization (guarantees that initial weight matrices are - # initialized properly) - # We have one additional relation for self-loops - nn.init.uniform_(att) - functional.normalize(att.data, p=1, dim=1, out=att.data) - elif self.decomposition == 'block': - for base 
in self.bases: - block_size = base.shape[-1] - # Xavier Glorot initialization of each block - std = torch.sqrt(torch.as_tensor(2.)) * gain / (2 * block_size) - nn.init.normal_(base, std=std) - - # Reset biases - if self.biases is not None: - for bias in self.biases: - nn.init.zeros_(bias) - - # Reset batch norm parameters - if self.batch_norms is not None: - for bn in self.batch_norms: - bn.reset_parameters() - - # Reset activation parameters, if any - for act in self.activations: - if hasattr(act, 'reset_parameters'): - act.reset_parameters() - - -class Decoder(nn.Module): - # TODO: Replace this by interaction function, once https://github.com/pykeen/pykeen/pull/107 is merged. - def forward(self, h, r, t): - return (h * r * t).sum(dim=-1) - - -class RGCN(Model): +class RGCN(ERModel): """An implementation of R-GCN from [schlichtkrull2018]_. This model uses graph convolutions with relation-specific weights. @@ -438,31 +38,6 @@ class RGCN(Model): `_ """ - #: Interaction model used as decoder - base_model: EntityEmbeddingModel - - #: The blocks of the relation-specific weight matrices - #: shape: (num_relations, num_blocks, embedding_dim//num_blocks, embedding_dim//num_blocks) - blocks: Optional[nn.ParameterList] - - #: The base weight matrices to generate relation-specific weights - #: shape: (num_bases, embedding_dim, embedding_dim) - bases: Optional[nn.ParameterList] - - #: The relation-specific weights for each base - #: shape: (num_relations, num_bases) - att: Optional[nn.ParameterList] - - #: The biases for each layer (if used) - #: shape of each element: (embedding_dim,) - biases: Optional[nn.ParameterList] - - #: Batch normalization for each layer (if used) - batch_norms: Optional[nn.ModuleList] - - #: Activations for each layer (if used) - activations: Optional[nn.ModuleList] - #: The default strategy for optimizing the model's hyper-parameters hpo_default: ClassVar[Mapping[str, Any]] = dict( embedding_dim=dict(type=int, low=16, high=1024, q=16), @@ -486,6 +61,7 @@ class RGCN(Model): def __init__( self, triples_factory: TriplesFactory, + interaction: Optional[Interaction] = None, embedding_dim: int = 500, loss: Optional[Loss] = None, predict_with_sigmoid: bool = False, @@ -509,14 +85,10 @@ def __init__( ): if triples_factory.create_inverse_triples: raise ValueError('R-GCN handles edges in an undirected manner.') - super().__init__( - triples_factory=triples_factory, - loss=loss, - predict_with_sigmoid=predict_with_sigmoid, - preferred_device=preferred_device, - random_seed=random_seed, - ) - self.entity_representations = RGCNRepresentations( + if interaction is None: + interaction = DistMultInteraction() + + entity_representations = RGCNRepresentations( triples_factory=triples_factory, embedding_dim=embedding_dim, num_bases_or_blocks=num_bases_or_blocks, @@ -533,25 +105,15 @@ def __init__( buffer_messages=buffer_messages, base_representations=None, ) - self.relation_embeddings = Embedding( - num_embeddings=triples_factory.num_relations, - embedding_dim=embedding_dim, + super().__init__( + triples_factory=triples_factory, + loss=loss, + predict_with_sigmoid=predict_with_sigmoid, + preferred_device=preferred_device, + random_seed=random_seed, + interaction=interaction, + entity_representations=entity_representations, + relation_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + ), ) - # TODO: Dummy - self.decoder = Decoder() - - def post_parameter_update(self) -> None: # noqa: D102 - super().post_parameter_update() - 
self.entity_representations.post_parameter_update() - self.relation_embeddings.post_parameter_update() - - def _reset_parameters_(self): - self.entity_representations.reset_parameters() - self.relation_embeddings.reset_parameters() - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Enrich embeddings - h = self.entity_representations(indices=hrt_batch[:, 0]) - t = self.entity_representations(indices=hrt_batch[:, 2]) - r = self.relation_embeddings(indices=hrt_batch[:, 1]) - return self.decoder(h, r, t).unsqueeze(dim=-1) diff --git a/src/pykeen/models/unimodal/rotate.py b/src/pykeen/models/unimodal/rotate.py index 812add0fa1..cb02f68386 100644 --- a/src/pykeen/models/unimodal/rotate.py +++ b/src/pykeen/models/unimodal/rotate.py @@ -5,46 +5,22 @@ from typing import Any, ClassVar, Mapping, Optional import torch -import torch.autograd -from torch.nn import functional -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...losses import Loss +from ...nn import EmbeddingSpecification from ...nn.init import init_phases, xavier_uniform_ -from ...regularizers import Regularizer +from ...nn.modules import RotatEInteraction from ...triples import TriplesFactory from ...typing import DeviceHint +from ...utils import complex_normalize __all__ = [ 'RotatE', ] -def complex_normalize(x: torch.Tensor) -> torch.Tensor: - r"""Normalize the length of relation vectors, if the forward constraint has not been applied yet. - - The `modulus of complex number `_ is given as: - - .. math:: - - |a + ib| = \sqrt{a^2 + b^2} - - $l_2$ norm of complex vector $x \in \mathbb{C}^d$: - - .. math:: - \|x\|^2 = \sum_{i=1}^d |x_i|^2 - = \sum_{i=1}^d \left(\operatorname{Re}(x_i)^2 + \operatorname{Im}(x_i)^2\right) - = \left(\sum_{i=1}^d \operatorname{Re}(x_i)^2) + (\sum_{i=1}^d \operatorname{Im}(x_i)^2\right) - = \|\operatorname{Re}(x)\|^2 + \|\operatorname{Im}(x)\|^2 - = \| [\operatorname{Re}(x); \operatorname{Im}(x)] \|^2 - """ - y = x.data.view(x.shape[0], -1, 2) - y = functional.normalize(y, p=2, dim=-1) - x.data = y.view(*x.shape) - return x - - -class RotatE(EntityRelationEmbeddingModel): +class RotatE(ERModel): r"""An implementation of RotatE from [sun2019]_. RotatE models relations as rotations from head to tail entities in complex space: @@ -81,112 +57,22 @@ def __init__( loss: Optional[Loss] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, ) -> None: super().__init__( triples_factory=triples_factory, - embedding_dim=2 * embedding_dim, + interaction=RotatEInteraction(), + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=xavier_uniform_, + dtype=torch.complex64, + ), + relation_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=init_phases, + constrainer=complex_normalize, + dtype=torch.complex64, + ), loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - entity_initializer=xavier_uniform_, - relation_initializer=init_phases, - relation_constrainer=complex_normalize, - ) - self.real_embedding_dim = embedding_dim - - @staticmethod - def interaction_function( - h: torch.FloatTensor, - r: torch.FloatTensor, - t: torch.FloatTensor, - ) -> torch.FloatTensor: - """Evaluate the interaction function of ComplEx for given embeddings. - - The embeddings have to be in a broadcastable shape. - - WARNING: No forward constraints are applied. 
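# --- Editor's note (illustrative sketch, not part of this diff) -----------------
# RotatE scores a triple as -||h o r - t||, where o is the element-wise product
# in complex space and every relation entry has unit modulus, i.e. r rotates h.
# With the complex64 embeddings introduced above, the score can be written
# directly on complex tensors; the (real, imag)-stacked check mirrors the
# representation used by the removed code. Names and shapes are illustrative.
import math
import torch

def rotate_score(h, r, t):
    """h, r, t: complex tensors of shape (B, d); assumes |r_i| = 1."""
    return -torch.linalg.vector_norm(h * r - t, dim=-1)

B, d = 4, 8
h = torch.randn(B, d, dtype=torch.complex64)
t = torch.randn(B, d, dtype=torch.complex64)
r = torch.polar(torch.ones(B, d), torch.rand(B, d) * 2 * math.pi)  # unit-modulus rotations
scores = rotate_score(h, r, t)
# Same value from the stacked (real, imag) view:
diff = torch.view_as_real(h * r - t)  # (B, d, 2)
assert torch.allclose(scores, -diff.flatten(start_dim=1).norm(dim=-1), atol=1e-4)
# --------------------------------------------------------------------------------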
- - :param h: shape: (..., e, 2) - Head embeddings. Last dimension corresponds to (real, imag). - :param r: shape: (..., e, 2) - Relation embeddings. Last dimension corresponds to (real, imag). - :param t: shape: (..., e, 2) - Tail embeddings. Last dimension corresponds to (real, imag). - - :return: shape: (...) - The scores. - """ - # Decompose into real and imaginary part - h_re = h[..., 0] - h_im = h[..., 1] - r_re = r[..., 0] - r_im = r[..., 1] - - # Rotate (=Hadamard product in complex space). - rot_h = torch.stack( - [ - h_re * r_re - h_im * r_im, - h_re * r_im + h_im * r_re, - ], - dim=-1, ) - # Workaround until https://github.com/pytorch/pytorch/issues/30704 is fixed - diff = rot_h - t - scores = -torch.norm(diff.view(diff.shape[:-2] + (-1,)), dim=-1) - - return scores - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hrt_batch[:, 0]).view(-1, self.real_embedding_dim, 2) - r = self.relation_embeddings(indices=hrt_batch[:, 1]).view(-1, self.real_embedding_dim, 2) - t = self.entity_embeddings(indices=hrt_batch[:, 2]).view(-1, self.real_embedding_dim, 2) - - # Compute scores - scores = self.interaction_function(h=h, r=r, t=t).view(-1, 1) - - # Embedding Regularization - self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim)) - - return scores - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hr_batch[:, 0]).view(-1, 1, self.real_embedding_dim, 2) - r = self.relation_embeddings(indices=hr_batch[:, 1]).view(-1, 1, self.real_embedding_dim, 2) - - # Rank against all entities - t = self.entity_embeddings(indices=None).view(1, -1, self.real_embedding_dim, 2) - - # Compute scores - scores = self.interaction_function(h=h, r=r, t=t) - - # Embedding Regularization - self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim)) - - return scores - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - r = self.relation_embeddings(indices=rt_batch[:, 0]).view(-1, 1, self.real_embedding_dim, 2) - t = self.entity_embeddings(indices=rt_batch[:, 1]).view(-1, 1, self.real_embedding_dim, 2) - - # r expresses a rotation in complex plane. - # The inverse rotation is expressed by the complex conjugate of r. - # The score is computed as the distance of the relation-rotated head to the tail. - # Equivalently, we can rotate the tail by the inverse relation, and measure the distance to the head, i.e. 
- # |h * r - t| = |h - conj(r) * t| - r_inv = torch.stack([r[:, :, :, 0], -r[:, :, :, 1]], dim=-1) - - # Rank against all entities - h = self.entity_embeddings(indices=None).view(1, -1, self.real_embedding_dim, 2) - - # Compute scores - scores = self.interaction_function(h=t, r=r_inv, t=h) - - # Embedding Regularization - self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim)) - - return scores diff --git a/src/pykeen/models/unimodal/simple.py b/src/pykeen/models/unimodal/simple.py index baf1f12090..56feb1313d 100644 --- a/src/pykeen/models/unimodal/simple.py +++ b/src/pykeen/models/unimodal/simple.py @@ -4,12 +4,13 @@ from typing import Any, ClassVar, Mapping, Optional, Tuple, Type, Union -import torch.autograd +import torch -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss, SoftplusLoss -from ...nn import Embedding +from ...nn import EmbeddingSpecification +from ...nn.modules import SimplEInteraction from ...regularizers import PowerSumRegularizer, Regularizer from ...triples import TriplesFactory from ...typing import DeviceHint @@ -19,7 +20,7 @@ ] -class SimplE(EntityRelationEmbeddingModel): +class SimplE(ERModel): r"""An implementation of SimplE [kazemi2018]_. SimplE is an extension of canonical polyadic (CP), an early tensor factorization approach in which each entity @@ -68,81 +69,59 @@ def __init__( triples_factory: TriplesFactory, embedding_dim: int = 200, loss: Optional[Loss] = None, + regularizer: Optional[Regularizer] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, clamp_score: Optional[Union[float, Tuple[float, float]]] = None, ) -> None: + if regularizer is None: + regularizer = self._instantiate_default_regularizer() super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=SimplEInteraction(clamp_score=clamp_score), + entity_representations=[ + EmbeddingSpecification( + embedding_dim=embedding_dim, + regularizer=regularizer, + ), + EmbeddingSpecification( + embedding_dim=embedding_dim, + regularizer=regularizer, + ), + ], + relation_representations=[ + EmbeddingSpecification( + embedding_dim=embedding_dim, + regularizer=regularizer, + ), + EmbeddingSpecification( + embedding_dim=embedding_dim, + regularizer=regularizer, + ), + ], loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - ) - - # extra embeddings - self.tail_entity_embeddings = Embedding.init_with_device( - num_embeddings=triples_factory.num_entities, - embedding_dim=embedding_dim, - device=self.device, ) - self.inverse_relation_embeddings = Embedding.init_with_device( - num_embeddings=triples_factory.num_relations, - embedding_dim=embedding_dim, - device=self.device, - ) - - if isinstance(clamp_score, float): - clamp_score = (-clamp_score, clamp_score) - self.clamp = clamp_score - def _reset_parameters_(self): # noqa: D102 - super()._reset_parameters_() - for emb in [ - self.tail_entity_embeddings, - self.inverse_relation_embeddings, - ]: - emb.reset_parameters() - - def _score( + def forward( self, h_indices: Optional[torch.LongTensor], r_indices: Optional[torch.LongTensor], t_indices: Optional[torch.LongTensor], + slice_size: Optional[int] = None, + slice_dim: Optional[str] = None, ) -> torch.FloatTensor: # noqa: D102 - # forward model - h = 
self.entity_embeddings.get_in_canonical_shape(indices=h_indices) - r = self.relation_embeddings.get_in_canonical_shape(indices=r_indices) - t = self.tail_entity_embeddings.get_in_canonical_shape(indices=t_indices) - scores = (h * r * t).sum(dim=-1) - - # Regularization - self.regularize_if_necessary(h, r, t) - - # backward model - h = self.entity_embeddings.get_in_canonical_shape(indices=t_indices) - r = self.inverse_relation_embeddings.get_in_canonical_shape(indices=r_indices) - t = self.tail_entity_embeddings.get_in_canonical_shape(indices=h_indices) - scores = 0.5 * (scores + (h * r * t).sum(dim=-1)) - - # Regularization - self.regularize_if_necessary(h, r, t) - - # Note: In the code in their repository, the score is clamped to [-20, 20]. - # That is not mentioned in the paper, so it is omitted here. - if self.clamp is not None: - min_, max_ = self.clamp - scores = scores.clamp(min=min_, max=max_) - - return scores - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._score(h_indices=hrt_batch[:, 0], r_indices=hrt_batch[:, 1], t_indices=hrt_batch[:, 2]).view(-1, 1) - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._score(h_indices=hr_batch[:, 0], r_indices=hr_batch[:, 1], t_indices=None) - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._score(h_indices=None, r_indices=rt_batch[:, 0], t_indices=rt_batch[:, 1]) + h, r, t = zip(*( + ( + h_source.get_in_canonical_shape(dim="h", indices=h_indices), + r_source.get_in_canonical_shape(dim="r", indices=r_indices), + t_source.get_in_canonical_shape(dim="t", indices=t_indices), + ) + for h_source, r_source, t_source in ( + (self.entity_representations[0], self.relation_representations[0], self.entity_representations[1]), + (self.entity_representations[1], self.relation_representations[1], self.entity_representations[0]), + ) + )) + scores = self.interaction.score(h=h, r=r, t=t, slice_size=slice_size, slice_dim=slice_dim) + return self._repeat_scores_if_necessary(scores, h_indices, r_indices, t_indices) diff --git a/src/pykeen/models/unimodal/structured_embedding.py b/src/pykeen/models/unimodal/structured_embedding.py index f0dcb533a8..878a7b20bb 100644 --- a/src/pykeen/models/unimodal/structured_embedding.py +++ b/src/pykeen/models/unimodal/structured_embedding.py @@ -6,17 +6,15 @@ from typing import Any, ClassVar, Mapping, Optional import numpy as np -import torch -import torch.autograd from torch import nn from torch.nn import functional -from ..base import EntityEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss -from ...nn import Embedding +from ...nn import EmbeddingSpecification from ...nn.init import xavier_uniform_ -from ...regularizers import Regularizer +from ...nn.modules import StructuredEmbeddingInteraction from ...triples import TriplesFactory from ...typing import DeviceHint from ...utils import compose @@ -26,7 +24,7 @@ ] -class StructuredEmbedding(EntityEmbeddingModel): +class StructuredEmbedding(ERModel): r"""An implementation of the Structured Embedding (SE) published by [bordes2011]_. SE applies role- and relation-specific projection matrices @@ -56,119 +54,41 @@ def __init__( loss: Optional[Loss] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, ) -> None: r"""Initialize SE. 
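# --- Editor's note (illustrative sketch, not part of this diff) -----------------
# Structured Embedding scores a triple as -||R_h h - R_t t||_p, using one "head"
# and one "tail" projection matrix per relation, matching the two (d, d) relation
# representations declared below for StructuredEmbeddingInteraction. A minimal
# restatement with illustrative names/shapes and p = 1, the usual choice for SE:
import torch

def se_score(h, r_h, r_t, t, p=1):
    """h, t: (B, d); r_h, r_t: (B, d, d) -> (B,)."""
    proj_h = (r_h @ h.unsqueeze(-1)).squeeze(-1)   # relation-specific head projection
    proj_t = (r_t @ t.unsqueeze(-1)).squeeze(-1)   # relation-specific tail projection
    return -torch.linalg.vector_norm(proj_h - proj_t, ord=p, dim=-1)

B, d = 4, 6
h, t = torch.randn(B, d), torch.randn(B, d)
r_h, r_t = torch.randn(B, d, d), torch.randn(B, d, d)
print(se_score(h, r_h, r_t, t).shape)              # torch.Size([4])
# --------------------------------------------------------------------------------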
:param embedding_dim: The entity embedding dimension $d$. Is usually $d \in [50, 300]$. :param scoring_fct_norm: The $l_p$ norm. Usually 1 for SE. """ - super().__init__( - triples_factory=triples_factory, - embedding_dim=embedding_dim, - loss=loss, - preferred_device=preferred_device, - random_seed=random_seed, - regularizer=regularizer, - entity_initializer=xavier_uniform_, - entity_constrainer=functional.normalize, - ) - - self.scoring_fct_norm = scoring_fct_norm - # Embeddings - init_bound = 6 / np.sqrt(self.embedding_dim) + init_bound = 6 / np.sqrt(embedding_dim) # Initialise relation embeddings to unit length - initializer = compose( + relation_initializer = compose( functools.partial(nn.init.uniform_, a=-init_bound, b=+init_bound), functional.normalize, ) - self.left_relation_embeddings = Embedding.init_with_device( - num_embeddings=triples_factory.num_relations, - embedding_dim=embedding_dim ** 2, - device=self.device, - initializer=initializer, - ) - self.right_relation_embeddings = Embedding.init_with_device( - num_embeddings=triples_factory.num_relations, - embedding_dim=embedding_dim ** 2, - device=self.device, - initializer=initializer, + super().__init__( + triples_factory=triples_factory, + interaction=StructuredEmbeddingInteraction( + p=scoring_fct_norm, + power_norm=False, + ), + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=xavier_uniform_, + constrainer=functional.normalize, + ), + relation_representations=[ + EmbeddingSpecification( + shape=(embedding_dim, embedding_dim), + initializer=relation_initializer, + ), + EmbeddingSpecification( + shape=(embedding_dim, embedding_dim), + initializer=relation_initializer, + ), + ], + loss=loss, + preferred_device=preferred_device, + random_seed=random_seed, ) - - def _reset_parameters_(self): # noqa: D102 - super()._reset_parameters_() - self.left_relation_embeddings.reset_parameters() - self.right_relation_embeddings.reset_parameters() - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hrt_batch[:, 0]).view(-1, self.embedding_dim, 1) - rel_h = self.left_relation_embeddings(indices=hrt_batch[:, 1]).view(-1, self.embedding_dim, self.embedding_dim) - rel_t = self.right_relation_embeddings(indices=hrt_batch[:, 1]).view(-1, self.embedding_dim, self.embedding_dim) - t = self.entity_embeddings(indices=hrt_batch[:, 2]).view(-1, self.embedding_dim, 1) - - # Project entities - proj_h = rel_h @ h - proj_t = rel_t @ t - - scores = -torch.norm(proj_h - proj_t, dim=1, p=self.scoring_fct_norm) - return scores - - def score_t(self, hr_batch: torch.LongTensor, slice_size: int = None) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hr_batch[:, 0]).view(-1, self.embedding_dim, 1) - rel_h = self.left_relation_embeddings(indices=hr_batch[:, 1]).view(-1, self.embedding_dim, self.embedding_dim) - rel_t = self.right_relation_embeddings(indices=hr_batch[:, 1]) - rel_t = rel_t.view(-1, 1, self.embedding_dim, self.embedding_dim) - t_all = self.entity_embeddings(indices=None).view(1, -1, self.embedding_dim, 1) - - if slice_size is not None: - proj_t_arr = [] - # Project entities - proj_h = rel_h @ h - - for t in torch.split(t_all, slice_size, dim=1): - # Project entities - proj_t = rel_t @ t - proj_t_arr.append(proj_t) - - proj_t = torch.cat(proj_t_arr, dim=1) - - else: - # Project entities - proj_h = rel_h @ h - proj_t = rel_t @ t_all - - scores = -torch.norm(proj_h[:, None, 
:, 0] - proj_t[:, :, :, 0], dim=-1, p=self.scoring_fct_norm) - - return scores - - def score_h(self, rt_batch: torch.LongTensor, slice_size: int = None) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h_all = self.entity_embeddings(indices=None).view(1, -1, self.embedding_dim, 1) - rel_h = self.left_relation_embeddings(indices=rt_batch[:, 0]) - rel_h = rel_h.view(-1, 1, self.embedding_dim, self.embedding_dim) - rel_t = self.right_relation_embeddings(indices=rt_batch[:, 0]).view(-1, self.embedding_dim, self.embedding_dim) - t = self.entity_embeddings(indices=rt_batch[:, 1]).view(-1, self.embedding_dim, 1) - - if slice_size is not None: - proj_h_arr = [] - - # Project entities - proj_t = rel_t @ t - - for h in torch.split(h_all, slice_size, dim=1): - # Project entities - proj_h = rel_h @ h - proj_h_arr.append(proj_h) - - proj_h = torch.cat(proj_h_arr, dim=1) - else: - # Project entities - proj_h = rel_h @ h_all - proj_t = rel_t @ t - - scores = -torch.norm(proj_h[:, :, :, 0] - proj_t[:, None, :, 0], dim=-1, p=self.scoring_fct_norm) - - return scores diff --git a/src/pykeen/models/unimodal/trans_d.py b/src/pykeen/models/unimodal/trans_d.py index a4a6ff7e17..4ae8d264bc 100644 --- a/src/pykeen/models/unimodal/trans_d.py +++ b/src/pykeen/models/unimodal/trans_d.py @@ -4,15 +4,12 @@ from typing import Any, ClassVar, Mapping, Optional -import torch -import torch.autograd - -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss -from ...nn import Embedding +from ...nn import EmbeddingSpecification from ...nn.init import xavier_normal_ -from ...regularizers import Regularizer +from ...nn.modules import TransDInteraction from ...triples import TriplesFactory from ...typing import DeviceHint from ...utils import clamp_norm @@ -22,54 +19,7 @@ ] -def _project_entity( - e: torch.FloatTensor, - e_p: torch.FloatTensor, - r: torch.FloatTensor, - r_p: torch.FloatTensor, -) -> torch.FloatTensor: - r"""Project entity relation-specific. - - .. math:: - - e_{\bot} = M_{re} e - = (r_p e_p^T + I^{d_r \times d_e}) e - = r_p e_p^T e + I^{d_r \times d_e} e - = r_p (e_p^T e) + e' - - and additionally enforces - - .. math:: - - \|e_{\bot}\|_2 \leq 1 - - :param e: shape: (batch_size, num_entities, d_e) - The entity embedding. - :param e_p: shape: (batch_size, num_entities, d_e) - The entity projection. - :param r: shape: (batch_size, num_entities, d_r) - The relation embedding. - :param r_p: shape: (batch_size, num_entities, d_r) - The relation projection. - - :return: shape: (batch_size, num_entities, d_r) - - """ - # The dimensions affected by e' - change_dim = min(e.shape[-1], r.shape[-1]) - - # Project entities - # r_p (e_p.T e) + e' - e_bot = r_p * torch.sum(e_p * e, dim=-1, keepdim=True) - e_bot[:, :, :change_dim] += e[:, :, :change_dim] - - # Enforce constraints - e_bot = clamp_norm(e_bot, p=2, dim=-1, maxnorm=1) - - return e_bot - - -class TransD(EntityRelationEmbeddingModel): +class TransD(ERModel): r"""An implementation of TransD from [ji2015]_. 
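# --- Editor's note (illustrative sketch, not part of this diff) -----------------
# TransD projects an entity into the relation space via e_bot = (r_p e_p^T + I) e,
# which the removed helper above computed without ever materialising the
# d_r x d_e projection matrix. The dense formulation below cross-checks that
# identity for d_r = d_e, before the subsequent norm clamping (names are
# illustrative, single vectors for brevity):
import torch

d = 6
e, e_p, r_p = torch.randn(d), torch.randn(d), torch.randn(d)

# Rank-one-plus-identity projection matrix, built explicitly ...
m_re = r_p.unsqueeze(-1) @ e_p.unsqueeze(0) + torch.eye(d)
dense = m_re @ e
# ... versus the matrix-free form used by the removed code:
matrix_free = r_p * (e_p * e).sum() + e
assert torch.allclose(dense, matrix_free, atol=1e-5)
# --------------------------------------------------------------------------------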
TransD is an extension of :class:`pykeen.models.TransR` that, like TransR, considers entities and relations @@ -115,113 +65,29 @@ def __init__( loss: Optional[Loss] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, ) -> None: super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, - relation_dim=relation_dim, + interaction=TransDInteraction(p=2, power_norm=True), + entity_representations=[ + EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=xavier_normal_, + constrainer=clamp_norm, # type: ignore + constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), + ) + for _ in range(2) + ], + relation_representations=[ + EmbeddingSpecification( + embedding_dim=relation_dim, + initializer=xavier_normal_, + constrainer=clamp_norm, # type: ignore + constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), + ) + for _ in range(2) + ], loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - entity_initializer=xavier_normal_, - relation_initializer=xavier_normal_, - entity_constrainer=clamp_norm, - entity_constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), - relation_constrainer=clamp_norm, - relation_constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), ) - - self.entity_projections = Embedding.init_with_device( - num_embeddings=triples_factory.num_entities, - embedding_dim=embedding_dim, - device=self.device, - initializer=xavier_normal_, - ) - self.relation_projections = Embedding.init_with_device( - num_embeddings=triples_factory.num_relations, - embedding_dim=relation_dim, - device=self.device, - initializer=xavier_normal_, - ) - - def _reset_parameters_(self): # noqa: D102 - super()._reset_parameters_() - self.entity_projections.reset_parameters() - self.relation_projections.reset_parameters() - - @staticmethod - def interaction_function( - h: torch.FloatTensor, - h_p: torch.FloatTensor, - r: torch.FloatTensor, - r_p: torch.FloatTensor, - t: torch.FloatTensor, - t_p: torch.FloatTensor, - ) -> torch.FloatTensor: - """Evaluate the interaction function for given embeddings. - - The embeddings have to be in a broadcastable shape. - - :param h: shape: (batch_size, num_entities, d_e) - Head embeddings. - :param h_p: shape: (batch_size, num_entities, d_e) - Head projections. - :param r: shape: (batch_size, num_entities, d_r) - Relation embeddings. - :param r_p: shape: (batch_size, num_entities, d_r) - Relation projections. - :param t: shape: (batch_size, num_entities, d_e) - Tail embeddings. - :param t_p: shape: (batch_size, num_entities, d_e) - Tail projections. - - :return: shape: (batch_size, num_entities) - The scores. - """ - # Project entities - h_bot = _project_entity(e=h, e_p=h_p, r=r, r_p=r_p) - t_bot = _project_entity(e=t, e_p=t_p, r=r, r_p=r_p) - - # score = -||h_bot + r - t_bot||_2^2 - return -torch.norm(h_bot + r - t_bot, dim=-1, p=2) ** 2 - - def _score( - self, - h_indices: Optional[torch.LongTensor] = None, - r_indices: Optional[torch.LongTensor] = None, - t_indices: Optional[torch.LongTensor] = None, - ) -> torch.FloatTensor: - """ - Evaluate the interaction function. - - :param h_indices: shape: (batch_size,) - The indices of head entities. If None, score against all. - :param r_indices: shape: (batch_size,) - The indices of relations. If None, score against all. - :param t_indices: shape: (batch_size,) - The indices of tail entities. If None, score against all. 
- - :return: The scores, shape: (batch_size, num_entities) - """ - # Head - h = self.entity_embeddings.get_in_canonical_shape(indices=h_indices) - h_p = self.entity_projections.get_in_canonical_shape(indices=h_indices) - - r = self.relation_embeddings.get_in_canonical_shape(indices=r_indices) - r_p = self.relation_projections.get_in_canonical_shape(indices=r_indices) - - t = self.entity_embeddings.get_in_canonical_shape(indices=t_indices) - t_p = self.entity_projections.get_in_canonical_shape(indices=t_indices) - - return self.interaction_function(h=h, h_p=h_p, r=r, r_p=r_p, t=t, t_p=t_p) - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._score(h_indices=hrt_batch[:, 0], r_indices=hrt_batch[:, 1], t_indices=hrt_batch[:, 2]) - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._score(h_indices=hr_batch[:, 0], r_indices=hr_batch[:, 1], t_indices=None) - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._score(h_indices=None, r_indices=rt_batch[:, 0], t_indices=rt_batch[:, 1]) diff --git a/src/pykeen/models/unimodal/trans_e.py b/src/pykeen/models/unimodal/trans_e.py index 724930ae77..3b9c545709 100644 --- a/src/pykeen/models/unimodal/trans_e.py +++ b/src/pykeen/models/unimodal/trans_e.py @@ -4,15 +4,14 @@ from typing import Any, ClassVar, Mapping, Optional -import torch -import torch.autograd from torch.nn import functional -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss +from ...nn import EmbeddingSpecification from ...nn.init import xavier_uniform_ -from ...regularizers import Regularizer +from ...nn.modules import TransEInteraction from ...triples import TriplesFactory from ...typing import DeviceHint from ...utils import compose @@ -22,7 +21,7 @@ ] -class TransE(EntityRelationEmbeddingModel): +class TransE(ERModel): r"""TransE models relations as a translation from head to tail entities in :math:`\textbf{e}` [bordes2013]_. .. math:: @@ -56,7 +55,6 @@ def __init__( loss: Optional[Loss] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, ) -> None: r"""Initialize TransE. 
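A minimal sketch of the broadcasted TransE scoring that the new stateless interaction modules implement, using plain torch only (the helper name transe_scores and the batch size, entity count, and dimension are illustrative assumptions, not taken from this change):

import torch

def transe_scores(h, r, t, p=1):
    # h, r: (batch_size, 1, dim); t: (1, num_entities, dim)
    # Broadcasting yields one score per (triple, candidate tail), as score_t used to.
    return -(h + r - t).norm(p=p, dim=-1)

h, r = torch.rand(4, 1, 50), torch.rand(4, 1, 50)
t = torch.rand(1, 14, 50)
assert transe_scores(h, r, t).shape == (4, 14)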
@@ -69,43 +67,20 @@ def __init__( """ super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=TransEInteraction(p=scoring_fct_norm, power_norm=False), + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=xavier_uniform_, + constrainer=functional.normalize, + ), + relation_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=compose( + xavier_uniform_, + functional.normalize, + ), + ), loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - entity_initializer=xavier_uniform_, - relation_initializer=compose( - xavier_uniform_, - functional.normalize, - ), - entity_constrainer=functional.normalize, ) - self.scoring_fct_norm = scoring_fct_norm - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hrt_batch[:, 0]) - r = self.relation_embeddings(indices=hrt_batch[:, 1]) - t = self.entity_embeddings(indices=hrt_batch[:, 2]) - - # TODO: Use torch.dist - return -torch.norm(h + r - t, dim=-1, p=self.scoring_fct_norm, keepdim=True) - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hr_batch[:, 0]) - r = self.relation_embeddings(indices=hr_batch[:, 1]) - t = self.entity_embeddings(indices=None) - - # TODO: Use torch.cdist - return -torch.norm(h[:, None, :] + r[:, None, :] - t[None, :, :], dim=-1, p=self.scoring_fct_norm) - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=None) - r = self.relation_embeddings(indices=rt_batch[:, 0]) - t = self.entity_embeddings(indices=rt_batch[:, 1]) - - # TODO: Use torch.cdist - return -torch.norm(h[None, :, :] + r[:, None, :] - t[:, None, :], dim=-1, p=self.scoring_fct_norm) diff --git a/src/pykeen/models/unimodal/trans_h.py b/src/pykeen/models/unimodal/trans_h.py index 1a2e1e8b40..f2cf2163e8 100644 --- a/src/pykeen/models/unimodal/trans_h.py +++ b/src/pykeen/models/unimodal/trans_h.py @@ -4,23 +4,24 @@ from typing import Any, ClassVar, Mapping, Optional, Type -import torch from torch.nn import functional -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss -from ...nn import Embedding +from ...nn import EmbeddingSpecification +from ...nn.modules import TransHInteraction from ...regularizers import Regularizer, TransHRegularizer from ...triples import TriplesFactory from ...typing import DeviceHint +from ...utils import pop_only __all__ = [ 'TransH', ] -class TransH(EntityRelationEmbeddingModel): +class TransH(ERModel): r"""An implementation of TransH [wang2014]_. 
This model extends :class:`pykeen.models.TransE` by applying the translation from head to tail entity in a @@ -81,89 +82,34 @@ def __init__( """ super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, + interaction=TransHInteraction( + p=scoring_fct_norm, + power_norm=False, + ), + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + ), + relation_representations=[ + EmbeddingSpecification( + embedding_dim=embedding_dim, + ), + EmbeddingSpecification( + embedding_dim=embedding_dim, + # Normalise the normal vectors by their l2 norms + constrainer=functional.normalize, + ), + ], loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, predict_with_sigmoid=predict_with_sigmoid, ) - - self.scoring_fct_norm = scoring_fct_norm - - # embeddings - self.normal_vector_embeddings = Embedding.init_with_device( - num_embeddings=triples_factory.num_relations, - embedding_dim=embedding_dim, - device=self.device, - # Normalise the normal vectors by their l2 norms - constrainer=functional.normalize, - ) - - def post_parameter_update(self) -> None: # noqa: D102 - super().post_parameter_update() - self.normal_vector_embeddings.post_parameter_update() - - def _reset_parameters_(self): # noqa: D102 - super()._reset_parameters_() - self.normal_vector_embeddings.reset_parameters() - # TODO: Add initialization - - def regularize_if_necessary(self) -> None: - """Update the regularizer's term given some tensors, if regularization is requested.""" - # As described in [wang2014], all entities and relations are used to compute the regularization term - # which enforces the defined soft constraints. - super().regularize_if_necessary( - self.entity_embeddings(indices=None), - self.normal_vector_embeddings(indices=None), # FIXME - self.relation_embeddings(indices=None), - ) - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hrt_batch[:, 0]) - d_r = self.relation_embeddings(indices=hrt_batch[:, 1]) - w_r = self.normal_vector_embeddings(indices=hrt_batch[:, 1]) - t = self.entity_embeddings(indices=hrt_batch[:, 2]) - - # Project to hyperplane - ph = h - torch.sum(w_r * h, dim=-1, keepdim=True) * w_r - pt = t - torch.sum(w_r * t, dim=-1, keepdim=True) * w_r - - # Regularization term - self.regularize_if_necessary() - - return -torch.norm(ph + d_r - pt, p=2, dim=-1, keepdim=True) - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hr_batch[:, 0]) - d_r = self.relation_embeddings(indices=hr_batch[:, 1]) - w_r = self.normal_vector_embeddings(indices=hr_batch[:, 1]) - t = self.entity_embeddings(indices=None) - - # Project to hyperplane - ph = h - torch.sum(w_r * h, dim=-1, keepdim=True) * w_r - pt = t[None, :, :] - torch.sum(w_r[:, None, :] * t[None, :, :], dim=-1, keepdim=True) * w_r[:, None, :] - - # Regularization term - self.regularize_if_necessary() - - return -torch.norm(ph[:, None, :] + d_r[:, None, :] - pt, p=2, dim=-1) - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=None) - rel_id = rt_batch[:, 0] - d_r = self.relation_embeddings(indices=rel_id) - w_r = self.normal_vector_embeddings(indices=rel_id) - t = self.entity_embeddings(indices=rt_batch[:, 1]) - - # Project to hyperplane - ph = h[None, :, :] - torch.sum(w_r[:, None, :] * h[None, :, :], dim=-1, 
keepdim=True) * w_r[:, None, :] - pt = t - torch.sum(w_r * t, dim=-1, keepdim=True) * w_r - - # Regularization term - self.regularize_if_necessary() - - return -torch.norm(ph + d_r[:, None, :] - pt[:, None, :], p=2, dim=-1) + if regularizer is None: + # Note that the TransH regularizer has a different interface + self.regularizer = self._instantiate_default_regularizer( + entity_embeddings=pop_only(self.entity_representations[0].parameters()), + relation_embeddings=pop_only(self.relation_representations[0].parameters()), + normal_vector_embeddings=pop_only(self.relation_representations[1].parameters()), + ) + else: + self.regularizer = regularizer diff --git a/src/pykeen/models/unimodal/trans_r.py b/src/pykeen/models/unimodal/trans_r.py index c47a2de3c6..88c1b93c09 100644 --- a/src/pykeen/models/unimodal/trans_r.py +++ b/src/pykeen/models/unimodal/trans_r.py @@ -2,20 +2,17 @@ """Implementation of TransR.""" -from functools import partial +import logging from typing import Any, ClassVar, Mapping, Optional -import torch -import torch.autograd -import torch.nn.init from torch.nn import functional -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss -from ...nn import Embedding +from ...nn import EmbeddingSpecification from ...nn.init import xavier_uniform_ -from ...regularizers import Regularizer +from ...nn.modules import TransRInteraction from ...triples import TriplesFactory from ...typing import DeviceHint from ...utils import clamp_norm, compose @@ -24,18 +21,10 @@ 'TransR', ] +logger = logging.getLogger(__name__) -def _projection_initializer( - x: torch.FloatTensor, - num_relations: int, - embedding_dim: int, - relation_dim: int, -) -> torch.FloatTensor: - """Initialize by Glorot.""" - return torch.nn.init.xavier_uniform_(x.view(num_relations, embedding_dim, relation_dim)).view(x.shape) - -class TransR(EntityRelationEmbeddingModel): +class TransR(ERModel): r"""An implementation of TransR from [lin2015]_. 
TransR is an extension of :class:`pykeen.models.TransH` that explicitly considers entities and relations as @@ -82,104 +71,39 @@ def __init__( loss: Optional[Loss] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, ) -> None: """Initialize the model.""" super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, - relation_dim=relation_dim, + interaction=TransRInteraction( + p=scoring_fct_norm, + ), + # Entity embeddings + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=xavier_uniform_, + constrainer=clamp_norm, # type: ignore + constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), + ), + relation_representations=[ + # Relation embeddings + EmbeddingSpecification( + embedding_dim=relation_dim, + initializer=compose( + xavier_uniform_, + functional.normalize, + ), + constrainer=clamp_norm, # type: ignore + constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), + ), + # Relation projections + EmbeddingSpecification( + shape=(embedding_dim, relation_dim), + initializer=xavier_uniform_, + ), + ], loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - entity_initializer=xavier_uniform_, - entity_constrainer=clamp_norm, - entity_constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), - relation_initializer=compose( - xavier_uniform_, - functional.normalize, - ), - relation_constrainer=clamp_norm, - relation_constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), ) - self.scoring_fct_norm = scoring_fct_norm - - # TODO: Initialize from TransE - - # embeddings - self.relation_projections = Embedding.init_with_device( - num_embeddings=triples_factory.num_relations, - embedding_dim=relation_dim * embedding_dim, - device=self.device, - initializer=partial( - _projection_initializer, - num_relations=self.num_relations, - embedding_dim=self.embedding_dim, - relation_dim=self.relation_dim, - ), - ) - - def _reset_parameters_(self): # noqa: D102 - super()._reset_parameters_() - self.relation_projections.reset_parameters() - - @staticmethod - def interaction_function( - h: torch.FloatTensor, - r: torch.FloatTensor, - t: torch.FloatTensor, - m_r: torch.FloatTensor, - ) -> torch.FloatTensor: - """Evaluate the interaction function for given embeddings. - - The embeddings have to be in a broadcastable shape. - - :param h: shape: (batch_size, num_entities, d_e) - Head embeddings. - :param r: shape: (batch_size, num_entities, d_r) - Relation embeddings. - :param t: shape: (batch_size, num_entities, d_e) - Tail embeddings. - :param m_r: shape: (batch_size, num_entities, d_e, d_r) - The relation specific linear transformations. - - :return: shape: (batch_size, num_entities) - The scores. - """ - # project to relation specific subspace, shape: (b, e, d_r) - h_bot = h @ m_r - t_bot = t @ m_r - # ensure constraints - h_bot = clamp_norm(h_bot, p=2, dim=-1, maxnorm=1.) - t_bot = clamp_norm(t_bot, p=2, dim=-1, maxnorm=1.) 
- - # evaluate score function, shape: (b, e) - return -torch.norm(h_bot + r - t_bot, dim=-1) ** 2 - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hrt_batch[:, 0]).unsqueeze(dim=1) - r = self.relation_embeddings(indices=hrt_batch[:, 1]).unsqueeze(dim=1) - t = self.entity_embeddings(indices=hrt_batch[:, 2]).unsqueeze(dim=1) - m_r = self.relation_projections(indices=hrt_batch[:, 1]).view(-1, self.embedding_dim, self.relation_dim) - - return self.interaction_function(h=h, r=r, t=t, m_r=m_r).view(-1, 1) - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hr_batch[:, 0]).unsqueeze(dim=1) - r = self.relation_embeddings(indices=hr_batch[:, 1]).unsqueeze(dim=1) - t = self.entity_embeddings(indices=None).unsqueeze(dim=0) - m_r = self.relation_projections(indices=hr_batch[:, 1]).view(-1, self.embedding_dim, self.relation_dim) - - return self.interaction_function(h=h, r=r, t=t, m_r=m_r) - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=None).unsqueeze(dim=0) - r = self.relation_embeddings(indices=rt_batch[:, 0]).unsqueeze(dim=1) - t = self.entity_embeddings(indices=rt_batch[:, 1]).unsqueeze(dim=1) - m_r = self.relation_projections(indices=rt_batch[:, 0]).view(-1, self.embedding_dim, self.relation_dim) - - return self.interaction_function(h=h, r=r, t=t, m_r=m_r) + logger.warning("Initialize from TransE") diff --git a/src/pykeen/models/unimodal/tucker.py b/src/pykeen/models/unimodal/tucker.py index f2177e0525..3bdd2bbcf3 100644 --- a/src/pykeen/models/unimodal/tucker.py +++ b/src/pykeen/models/unimodal/tucker.py @@ -4,15 +4,12 @@ from typing import Any, ClassVar, Mapping, Optional, Type -import torch -import torch.autograd -from torch import nn - -from ..base import EntityRelationEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_DROPOUT_HPO_RANGE, DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import BCEAfterSigmoidLoss, Loss +from ...nn import EmbeddingSpecification from ...nn.init import xavier_normal_ -from ...regularizers import Regularizer +from ...nn.modules import TuckerInteraction from ...triples import TriplesFactory from ...typing import DeviceHint @@ -21,18 +18,7 @@ ] -def _apply_bn_to_tensor( - batch_norm: nn.BatchNorm1d, - tensor: torch.FloatTensor, -) -> torch.FloatTensor: - shape = tensor.shape - tensor = tensor.view(-1, shape[-1]) - tensor = batch_norm(tensor) - tensor = tensor.view(*shape) - return tensor - - -class TuckER(EntityRelationEmbeddingModel): +class TuckER(ERModel): r"""An implementation of TuckEr from [balazevic2019]_. TuckER is a linear model that is based on the tensor factorization method Tucker in which a three-mode tensor @@ -88,7 +74,6 @@ def __init__( dropout_0: float = 0.3, dropout_1: float = 0.4, dropout_2: float = 0.5, - regularizer: Optional[Regularizer] = None, apply_batch_normalization: bool = True, ) -> None: """Initialize the model. @@ -100,119 +85,26 @@ def __init__( where h,r,t are the head, relation, and tail embedding, W is the core tensor, x_i denotes the tensor product along the i-th mode, BN denotes batch normalization, and DO dropout. 
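        As a minimal sketch of the same scoring function with batch normalization and dropout left out (the sizes below are illustrative assumptions): for head/tail batches ``h, t`` of shape ``(b, d_e)``, relations ``r`` of shape ``(b, d_r)``, and core tensor ``W`` of shape ``(d_e, d_r, d_e)``, the score ``W x_1 h x_2 r x_3 t`` can be written as a single einsum:

        >>> import torch
        >>> b, d_e, d_r = 2, 5, 3
        >>> h, r, t = torch.rand(b, d_e), torch.rand(b, d_r), torch.rand(b, d_e)
        >>> W = torch.rand(d_e, d_r, d_e)
        >>> scores = torch.einsum("bi,ijk,bj,bk->b", h, W, r, t)  # W x_1 h x_2 r x_3 t, one score per triple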
""" + relation_dim = relation_dim or embedding_dim super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, - relation_dim=relation_dim, + interaction=TuckerInteraction( + embedding_dim=embedding_dim, + relation_dim=relation_dim, + head_dropout=dropout_0, + relation_dropout=dropout_1, + head_relation_dropout=dropout_2, + apply_batch_normalization=apply_batch_normalization, + ), + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=xavier_normal_, + ), + relation_representations=EmbeddingSpecification( + embedding_dim=relation_dim, + initializer=xavier_normal_, + ), loss=loss, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - entity_initializer=xavier_normal_, - relation_initializer=xavier_normal_, - ) - - # Core tensor - # Note: we use a different dimension permutation as in the official implementation to match the paper. - self.core_tensor = nn.Parameter( - torch.empty(self.embedding_dim, self.relation_dim, self.embedding_dim, device=self.device), - requires_grad=True, ) - - # Dropout - self.input_dropout = nn.Dropout(dropout_0) - self.hidden_dropout_1 = nn.Dropout(dropout_1) - self.hidden_dropout_2 = nn.Dropout(dropout_2) - - self.apply_batch_normalization = apply_batch_normalization - - if self.apply_batch_normalization: - self.bn_0 = nn.BatchNorm1d(self.embedding_dim) - self.bn_1 = nn.BatchNorm1d(self.embedding_dim) - - def _reset_parameters_(self): # noqa: D102 - super()._reset_parameters_() - # Initialize core tensor, cf. https://github.com/ibalazevic/TuckER/blob/master/model.py#L12 - nn.init.uniform_(self.core_tensor, -1., 1.) - - def _scoring_function( - self, - h: torch.FloatTensor, - r: torch.FloatTensor, - t: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - Evaluate the scoring function. - - Compute scoring function W x_1 h x_2 r x_3 t as in the official implementation, i.e. 
as - - DO(BN(DO(BN(h)) x_1 DO(W x_2 r))) x_3 t - - where BN denotes BatchNorm and DO denotes Dropout - - :param h: shape: (batch_size, 1, embedding_dim) or (1, num_entities, embedding_dim) - :param r: shape: (batch_size, relation_dim) - :param t: shape: (1, num_entities, embedding_dim) or (batch_size, 1, embedding_dim) - :return: shape: (batch_size, num_entities) or (batch_size, 1) - """ - # Abbreviation - w = self.core_tensor - d_e = self.embedding_dim - d_r = self.relation_dim - - # Compute h_n = DO(BN(h)) - if self.apply_batch_normalization: - h = _apply_bn_to_tensor(batch_norm=self.bn_0, tensor=h) - - h = self.input_dropout(h) - - # Compute wr = DO(W x_2 r) - w = w.view(1, d_e, d_r, d_e) - r = r.view(-1, 1, 1, d_r) - wr = r @ w - wr = self.hidden_dropout_1(wr) - - # compute whr = DO(BN(h_n x_1 wr)) - wr = wr.view(-1, d_e, d_e) - whr = (h @ wr) - if self.apply_batch_normalization: - whr = _apply_bn_to_tensor(batch_norm=self.bn_1, tensor=whr) - whr = self.hidden_dropout_2(whr) - - # Compute whr x_3 t - scores = torch.sum(whr * t, dim=-1) - - return scores - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hrt_batch[:, 0]).unsqueeze(1) - r = self.relation_embeddings(indices=hrt_batch[:, 1]) - t = self.entity_embeddings(indices=hrt_batch[:, 2]).unsqueeze(1) - - # Compute scores - scores = self._scoring_function(h=h, r=r, t=t) - - return scores - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=hr_batch[:, 0]).unsqueeze(1) - r = self.relation_embeddings(indices=hr_batch[:, 1]) - t = self.entity_embeddings(indices=None).unsqueeze(0) - - # Compute scores - scores = self._scoring_function(h=h, r=r, t=t) - - return scores - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(indices=None).unsqueeze(0) - r = self.relation_embeddings(indices=rt_batch[:, 0]) - t = self.entity_embeddings(indices=rt_batch[:, 1]).unsqueeze(1) - - # Compute scores - scores = self._scoring_function(h=h, r=r, t=t) - - return scores diff --git a/src/pykeen/models/unimodal/unstructured_model.py b/src/pykeen/models/unimodal/unstructured_model.py index ef22d15a3e..cc39532f02 100644 --- a/src/pykeen/models/unimodal/unstructured_model.py +++ b/src/pykeen/models/unimodal/unstructured_model.py @@ -4,14 +4,12 @@ from typing import Any, ClassVar, Mapping, Optional -import torch -import torch.autograd - -from ..base import EntityEmbeddingModel +from ..base import ERModel from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE from ...losses import Loss +from ...nn import EmbeddingSpecification from ...nn.init import xavier_normal_ -from ...regularizers import Regularizer +from ...nn.modules import UnstructuredModelInteraction from ...triples import TriplesFactory from ...typing import DeviceHint @@ -20,7 +18,7 @@ ] -class UnstructuredModel(EntityEmbeddingModel): +class UnstructuredModel(ERModel): r"""An implementation of the Unstructured Model (UM) published by [bordes2014]_. UM computes the distance between head and tail entities then applies the $l_p$ norm. @@ -50,37 +48,25 @@ def __init__( embedding_dim: int = 50, scoring_fct_norm: int = 1, loss: Optional[Loss] = None, + predict_with_sigmoid: bool = False, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, - regularizer: Optional[Regularizer] = None, ) -> None: r"""Initialize UM. 
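        A minimal construction sketch (the bundled Nations dataset is used purely for illustration and is an assumption, not part of this change):

        >>> from pykeen.datasets import Nations
        >>> from pykeen.models import UnstructuredModel
        >>> model = UnstructuredModel(triples_factory=Nations().training, embedding_dim=50, scoring_fct_norm=1)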
:param embedding_dim: The entity embedding dimension $d$. Is usually $d \in [50, 300]$. :param scoring_fct_norm: The $l_p$ norm. Usually 1 for UM. """ + self.embedding_dim = embedding_dim super().__init__( triples_factory=triples_factory, - embedding_dim=embedding_dim, loss=loss, + predict_with_sigmoid=predict_with_sigmoid, preferred_device=preferred_device, random_seed=random_seed, - regularizer=regularizer, - entity_initializer=xavier_normal_, + interaction=UnstructuredModelInteraction(p=scoring_fct_norm), + entity_representations=EmbeddingSpecification( + embedding_dim=embedding_dim, + initializer=xavier_normal_, + ), ) - self.scoring_fct_norm = scoring_fct_norm - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - h = self.entity_embeddings(indices=hrt_batch[:, 0]) - t = self.entity_embeddings(indices=hrt_batch[:, 2]) - return -torch.norm(h - t, dim=-1, p=self.scoring_fct_norm, keepdim=True) ** 2 - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - h = self.entity_embeddings(indices=hr_batch[:, 0]).view(-1, 1, self.embedding_dim) - t = self.entity_embeddings(indices=None).view(1, -1, self.embedding_dim) - return -torch.norm(h - t, dim=-1, p=self.scoring_fct_norm) ** 2 - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - h = self.entity_embeddings(indices=None).view(1, -1, self.embedding_dim) - t = self.entity_embeddings(indices=rt_batch[:, 1]).view(-1, 1, self.embedding_dim) - return -torch.norm(h - t, dim=-1, p=self.scoring_fct_norm) ** 2 diff --git a/src/pykeen/nn/__init__.py b/src/pykeen/nn/__init__.py index 39930ddcb0..dada0a07ca 100644 --- a/src/pykeen/nn/__init__.py +++ b/src/pykeen/nn/__init__.py @@ -2,11 +2,16 @@ """PyKEEN internal "nn" module.""" -from . import init -from .emb import Embedding, RepresentationModule +from . 
import functional, init +from .modules import Interaction +from .representation import Embedding, EmbeddingSpecification, LiteralRepresentations, RepresentationModule __all__ = [ 'Embedding', + 'EmbeddingSpecification', + 'LiteralRepresentations', 'RepresentationModule', + 'Interaction', 'init', + 'functional', ] diff --git a/src/pykeen/nn/compute_kernel.py b/src/pykeen/nn/compute_kernel.py new file mode 100644 index 0000000000..f56fe99c12 --- /dev/null +++ b/src/pykeen/nn/compute_kernel.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- + +"""Compute kernels for common sub-tasks.""" + +import numpy +import torch + +from pykeen.utils import extended_einsum, split_complex, tensor_product, view_complex + + +def _batched_dot_manual( + a: torch.FloatTensor, + b: torch.FloatTensor, +) -> torch.FloatTensor: + return (a * b).sum(dim=-1) + + +def _batched_dot_matmul( + a: torch.FloatTensor, + b: torch.FloatTensor, +) -> torch.FloatTensor: + return (a.unsqueeze(dim=-2) @ b.unsqueeze(dim=-1)).view(a.shape[:-1]) + + +def _batched_dot_einsum( + a: torch.FloatTensor, + b: torch.FloatTensor, +) -> torch.FloatTensor: + return torch.einsum("...i,...i->...", a, b) + + +def batched_dot( + a: torch.FloatTensor, + b: torch.FloatTensor, +) -> torch.FloatTensor: + """Compute "element-wise" dot-product between batched vectors.""" + return _batched_dot_manual(a, b) + + +def _complex_broadcast_optimized( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, +) -> torch.FloatTensor: + """Manually split into real/imag, and used optimized broadcasted combination.""" + (h_re, h_im), (r_re, r_im), (t_re, t_im) = [split_complex(x=x) for x in (h, r, t)] + return sum( + factor * tensor_product(hh, rr, tt).sum(dim=-1) + for factor, hh, rr, tt in [ + (+1, h_re, r_re, t_re), + (+1, h_re, r_im, t_im), + (+1, h_im, r_re, t_im), + (-1, h_im, r_im, t_re), + ] + ) + + +def _complex_direct( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, +) -> torch.FloatTensor: + """Manually split into real/imag, and directly evaluate interaction.""" + (h_re, h_im), (r_re, r_im), (t_re, t_im) = [split_complex(x=x) for x in (h, r, t)] + return ( + (h_re * r_re * t_re).sum(dim=-1) + + (h_re * r_im * t_im).sum(dim=-1) + + (h_im * r_re * t_im).sum(dim=-1) + - (h_im * r_im * t_re).sum(dim=-1) + ) + + +def _complex_einsum( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, +) -> torch.FloatTensor: + """Use einsum.""" + x = h.new_zeros(2, 2, 2) + x[0, 0, 0] = 1 + x[0, 1, 1] = 1 + x[1, 0, 1] = 1 + x[1, 1, 0] = -1 + return extended_einsum( + "ijk,bhdi,brdj,btdk->bhrt", + x, + h.view(*h.shape[:-1], -1, 2), + r.view(*r.shape[:-1], -1, 2), + t.view(*t.shape[:-1], -1, 2), + ) + + +def _complex_native_complex( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, +) -> torch.FloatTensor: + """Use torch built-ins for computation with complex numbers.""" + h, r, t = [view_complex(x=x) for x in (h, r, t)] + return torch.real(tensor_product(h, r, torch.conj(t)).sum(dim=-1)) + + +def _complex_native_complex_select( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, +) -> torch.FloatTensor: + """Use torch built-ins for computation with complex numbers and select whether to combine hr or ht first.""" + h, r, t = [view_complex(x=x) for x in (h, r, t)] + hr_cost = numpy.prod([max(hs, rs) for hs, rs in zip(h.shape, r.shape)]) + rt_cost = numpy.prod([max(ts, rs) for ts, rs in zip(t.shape, r.shape)]) + t = torch.conj(t) + if hr_cost < rt_cost: + h = h * r + else: + t = r * t + 
return torch.real((h * t).sum(dim=-1)) + + +def _complex_select( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, +) -> torch.FloatTensor: + """Decide based on result shape whether to combine hr or ht first.""" + hr_cost = numpy.prod([max(hs, rs) for hs, rs in zip(h.shape, r.shape)]) + rt_cost = numpy.prod([max(ts, rs) for ts, rs in zip(t.shape, r.shape)]) + (h_re, h_im), (r_re, r_im), (t_re, t_im) = [split_complex(x=x) for x in (h, r, t)] + if hr_cost < rt_cost: + h_re, h_im = (h_re * r_re - h_im * r_im), (h_re * r_im + h_im * r_re) + else: + t_re, t_im = (t_re * r_re - t_im * r_im), (t_re * r_im + t_im * r_re) + return h_re @ t_re.transpose(-2, -1) - h_im @ t_im.transpose(-2, -1) + + +def _complex_to_stacked(h, r, t): + (r_re, r_im), (t_re, t_im) = [split_complex(x=x) for x in (r, t)] + h = torch.cat([h, h], dim=-1) # re im re im + r = torch.cat([r_re, r_re, r_im, r_im], dim=-1) # re re im im + t = torch.cat([t_re, t_im, t_im, t_re], dim=-1) # re im im re + return h, r, t + + +def _complex_stacked( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, +) -> torch.FloatTensor: + """Stack vectors.""" + h, r, t = _complex_to_stacked(h, r, t) + return (h * r * t).sum(dim=-1) + + +def _complex_stacked_select( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, +) -> torch.FloatTensor: + """Stack vectors and select order.""" + h, r, t = _complex_to_stacked(h, r, t) + hr_cost = numpy.prod([max(hs, rs) for hs, rs in zip(h.shape, r.shape)]) + rt_cost = numpy.prod([max(ts, rs) for ts, rs in zip(t.shape, r.shape)]) + if hr_cost < rt_cost: + # h = h_re, -h_im + h = h * r + else: + t = r * t + return h @ t.transpose(-2, -1) diff --git a/src/pykeen/nn/functional.py b/src/pykeen/nn/functional.py new file mode 100644 index 0000000000..c7d7eb5871 --- /dev/null +++ b/src/pykeen/nn/functional.py @@ -0,0 +1,943 @@ +# -*- coding: utf-8 -*- + +""" +Functional forms of interaction methods. + +The functional forms always assume the general form of the interaction function, where head, relation and tail +representations are provided in shape (batch_size, num_heads, 1, 1, ``*``), (batch_size, 1, num_relations, 1, ``*``), +and (batch_size, 1, 1, num_tails, ``*``), and return a score tensor of shape +(batch_size, num_heads, num_relations, num_tails). +""" + +import dataclasses +from typing import Optional, Tuple, Union + +import numpy +import torch +import torch.fft +from torch import nn + +from .compute_kernel import _complex_native_complex +from .sim import KG2E_SIMILARITIES +from ..typing import GaussianDistribution +from ..utils import ( + broadcast_cat, clamp_norm, estimate_cost_of_sequence, extended_einsum, is_cudnn_error, negative_norm, + negative_norm_of_sum, project_entity, tensor_product, tensor_sum, view_complex, +) + +__all__ = [ + "complex_interaction", + "conve_interaction", + "convkb_interaction", + "distmult_interaction", + "ermlp_interaction", + "ermlpe_interaction", + "hole_interaction", + "kg2e_interaction", + "ntn_interaction", + "proje_interaction", + "rescal_interaction", + "rotate_interaction", + "simple_interaction", + "structured_embedding_interaction", + "transd_interaction", + "transe_interaction", + "transh_interaction", + "transr_interaction", + "tucker_interaction", + "unstructured_model_interaction", +] + + +@dataclasses.dataclass +class SizeInformation: + """Size information of generic score function.""" + + #: The batch size of the head representations. 
+ bh: int + + #: The number of head representations per batch + nh: int + + #: The batch size of the relation representations. + br: int + + #: The number of relation representations per batch + nr: int + + #: The batch size of the tail representations. + bt: int + + #: The number of tail representations per batch + nt: int + + @property + def same(self) -> bool: + """Whether all representations have the same shape.""" + return ( + self.bh == self.br + and self.bh == self.bt + and self.nh == self.nr + and self.nh == self.nt + ) + + +def _extract_size_information( + h: torch.Tensor, + r: torch.Tensor, + t: torch.Tensor, +) -> SizeInformation: + """Extract size information from tensors.""" + bh, nh = h.shape[:2] + br, nr = r.shape[:2] + bt, nt = t.shape[:2] + return SizeInformation(bh=bh, nh=nh, br=br, nr=nr, bt=bt, nt=nt) + + +def _extract_sizes( + h: torch.Tensor, + r: torch.Tensor, + t: torch.Tensor, +) -> Tuple[int, int, int, int, int]: + """Extract size dimensions from head/relation/tail representations.""" + num_heads, num_relations, num_tails = [xx.shape[i] for i, xx in enumerate((h, r, t), start=1)] + d_e = h.shape[-1] + d_r = r.shape[-1] + return num_heads, num_relations, num_tails, d_e, d_r + + +def _apply_optional_bn_to_tensor( + x: torch.FloatTensor, + output_dropout: nn.Dropout, + batch_norm: Optional[nn.BatchNorm1d] = None, +) -> torch.FloatTensor: + """Apply optional batch normalization and dropout layer. Supports multiple batch dimensions.""" + if batch_norm is not None: + shape = x.shape + x = x.reshape(-1, shape[-1]) + x = batch_norm(x) + x = x.view(*shape) + return output_dropout(x) + + +def _add_cuda_warning(func): + def wrapped(*args, **kwargs): + try: + return func(*args, **kwargs) + except RuntimeError as e: + if not is_cudnn_error(e): + raise e + raise RuntimeError( + '\nThis code crash might have been caused by a CUDA bug, see ' + 'https://github.com/allenai/allennlp/issues/2888, ' + 'which causes the code to crash during evaluation mode.\n' + 'To avoid this error, the batch size has to be reduced.', + ) from e + + return wrapped + + +def complex_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, +) -> torch.FloatTensor: + r""" + Evaluate the ComplEx interaction function. + + .. math :: + Re(\langle h, r, conj(t) \rangle) + + :param h: shape: (batch_size, num_heads, 1, 1, `2*dim`) + The complex head representations. + :param r: shape: (batch_size, 1, num_relations, 1, 2*dim) + The complex relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, 2*dim) + The complex tail representations. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + return _complex_native_complex(h, r, t) + + +@_add_cuda_warning +def conve_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, + t_bias: torch.FloatTensor, + input_channels: int, + embedding_height: int, + embedding_width: int, + hr2d: nn.Module, + hr1d: nn.Module, +) -> torch.FloatTensor: + """Evaluate the ConvE interaction function. + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, dim) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + :param t_bias: shape: (batch_size, 1, 1, num_tails, 1) + The tail entity bias. + :param input_channels: + The number of input channels. + :param embedding_height: + The height of the reshaped embedding. 
+ :param embedding_width: + The width of the reshaped embedding. + :param hr2d: + The first module, transforming the 2D stacked head-relation "image". + :param hr1d: + The second module, transforming the 1D flattened output of the 2D module. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + # repeat if necessary, and concat head and relation, batch_size', num_input_channels, 2*height, width + # with batch_size' = batch_size * num_heads * num_relations + x = broadcast_cat( + h.view(*h.shape[:-1], input_channels, embedding_height, embedding_width), + r.view(*r.shape[:-1], input_channels, embedding_height, embedding_width), + dim=-2, + ).view(-1, input_channels, 2 * embedding_height, embedding_width) + + # batch_size', num_input_channels, 2*height, width + x = hr2d(x) + + # batch_size', num_output_channels * (2 * height - kernel_height + 1) * (width - kernel_width + 1) + x = x.view(-1, numpy.prod(x.shape[-3:])) + x = hr1d(x) + + # reshape: (batch_size', embedding_dim) -> (b, h, r, 1, d) + x = x.view(-1, h.shape[1], r.shape[2], 1, h.shape[-1]) + + # For efficient calculation, each of the convolved [h, r] rows has only to be multiplied with one t row + # output_shape: (batch_size, num_heads, num_relations, num_tails) + t = t.transpose(-1, -2) + x = (x @ t).squeeze(dim=-2) + + # add bias term + return x + t_bias.squeeze(dim=-1) + + +def convkb_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, + conv: nn.Conv2d, + activation: nn.Module, + hidden_dropout: nn.Dropout, + linear: nn.Linear, +) -> torch.FloatTensor: + r"""Evaluate the ConvKB interaction function. + + .. math:: + W_L drop(act(W_C \ast ([h; r; t]) + b_C)) + b_L + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, dim) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + :param conv: + The 3x1 convolution. + :param activation: + The activation function. + :param hidden_dropout: + The dropout layer applied to the hidden activations. + :param linear: + The final linear layer. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + # decompose convolution for faster computation in 1-n case + num_filters = conv.weight.shape[0] + assert conv.weight.shape == (num_filters, 1, 1, 3) + + # compute conv(stack(h, r, t)) + # prepare input shapes for broadcasting + # (b, h, r, t, 1, d) + h = h.unsqueeze(dim=-2) + r = r.unsqueeze(dim=-2) + t = t.unsqueeze(dim=-2) + + # conv.weight.shape = (C_out, C_in, kernel_size[0], kernel_size[1]) + # here, kernel_size = (1, 3), C_in = 1, C_out = num_filters + # -> conv_head, conv_rel, conv_tail shapes: (num_filters,) + # reshape to (1, 1, 1, 1, f, 1) + conv_head, conv_rel, conv_tail, conv_bias = [ + c.view(1, 1, 1, 1, num_filters, 1) + for c in list(conv.weight[:, 0, 0, :].t()) + [conv.bias] + ] + + # convolve -> output.shape: (*, embedding_dim, num_filters) + h = conv_head @ h + r = conv_rel @ r + t = conv_tail @ t + + x = tensor_sum(conv_bias, h, r, t) + x = activation(x) + + # Apply dropout, cf. 
https://github.com/daiquocnguyen/ConvKB/blob/master/model.py#L54-L56 + x = hidden_dropout(x) + + # Linear layer for final scores; use flattened representations, shape: (b, h, r, t, d * f) + x = x.view(*x.shape[:-2], -1) + x = linear(x) + return x.squeeze(dim=-1) + + +def distmult_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, +) -> torch.FloatTensor: + """ + Evaluate the DistMult interaction function. + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, dim) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + return tensor_product(h, r, t).sum(dim=-1) + + +def ermlp_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, + hidden: nn.Linear, + activation: nn.Module, + final: nn.Linear, +) -> torch.FloatTensor: + r""" + Evaluate the ER-MLP interaction function. + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, dim) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + :param hidden: + The first linear layer. + :param activation: + The activation function of the hidden layer. + :param final: + The second linear layer. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + sizes = _extract_size_information(h, r, t) + + # same shape + if sizes.same: + return final(activation( + hidden(torch.cat([h, r, t], dim=-1).view(-1, 3 * h.shape[-1]))), + ).view(sizes.bh, sizes.nh, sizes.nr, sizes.nt) + + hidden_dim = hidden.weight.shape[0] + # split, shape: (embedding_dim, hidden_dim) + head_to_hidden, rel_to_hidden, tail_to_hidden = hidden.weight.t().split(h.shape[-1]) + bias = hidden.bias.view(1, 1, 1, 1, -1) + h = h @ head_to_hidden.view(1, 1, 1, h.shape[-1], hidden_dim) + r = r @ rel_to_hidden.view(1, 1, 1, r.shape[-1], hidden_dim) + t = t @ tail_to_hidden.view(1, 1, 1, t.shape[-1], hidden_dim) + return final(activation(tensor_sum(bias, h, r, t))).squeeze(dim=-1) + + +def ermlpe_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, + mlp: nn.Module, +) -> torch.FloatTensor: + r""" + Evaluate the ER-MLPE interaction function. + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, dim) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + :param mlp: + The MLP. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + # repeat if necessary, and concat head and relation, (batch_size, num_heads, num_relations, 1, 2 * embedding_dim) + x = broadcast_cat(h, r, dim=-1) + + # Predict t embedding, shape: (b, h, r, 1, d) + shape = x.shape + x = mlp(x.view(-1, shape[-1])).view(*shape[:-1], -1) + + # transpose t, (b, 1, 1, d, t) + t = t.transpose(-2, -1) + + # dot product, (b, h, r, 1, t) + return (x @ t).squeeze(dim=-2) + + +def hole_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, +) -> torch.FloatTensor: # noqa: D102 + """ + Evaluate the HolE interaction function. + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. 
+ :param r: shape: (batch_size, 1, num_relations, 1, dim) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + # Circular correlation of entity embeddings + a_fft = torch.fft.rfft(h, dim=-1) + b_fft = torch.fft.rfft(t, dim=-1) + + # complex conjugate + a_fft = torch.conj(a_fft) + + # Hadamard product in frequency domain + p_fft = a_fft * b_fft + + # inverse real FFT, shape: (b, h, 1, t, d) + composite = torch.fft.irfft(p_fft, n=h.shape[-1], dim=-1) + + # transpose composite: (b, h, 1, d, t) + composite = composite.transpose(-2, -1) + + # inner product with relation embedding + return (r @ composite).squeeze(dim=-2) + + +def kg2e_interaction( + h_mean: torch.FloatTensor, + h_var: torch.FloatTensor, + r_mean: torch.FloatTensor, + r_var: torch.FloatTensor, + t_mean: torch.FloatTensor, + t_var: torch.FloatTensor, + similarity: str = "KL", + exact: bool = True, +) -> torch.FloatTensor: + """ + Evaluate the KG2E interaction function. + + :param h_mean: shape: (batch_size, num_heads, 1, 1, d) + The head entity distribution mean. + :param h_var: shape: (batch_size, num_heads, 1, 1, d) + The head entity distribution variance. + :param r_mean: shape: (batch_size, 1, num_relations, 1, d) + The relation distribution mean. + :param r_var: shape: (batch_size, 1, num_relations, 1, d) + The relation distribution variance. + :param t_mean: shape: (batch_size, 1, 1, num_tails, d) + The tail entity distribution mean. + :param t_var: shape: (batch_size, 1, 1, num_tails, d) + The tail entity distribution variance. + :param similarity: + The similarity measures for gaussian distributions. From {"KL", "EL"}. + :param exact: + Whether to leave out constants to accelerate similarity computation. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + return KG2E_SIMILARITIES[similarity]( + h=GaussianDistribution(mean=h_mean, diagonal_covariance=h_var), + r=GaussianDistribution(mean=r_mean, diagonal_covariance=r_var), + t=GaussianDistribution(mean=t_mean, diagonal_covariance=t_var), + exact=exact, + ) + + +def ntn_interaction( + h: torch.FloatTensor, + t: torch.FloatTensor, + w: torch.FloatTensor, + vh: torch.FloatTensor, + vt: torch.FloatTensor, + b: torch.FloatTensor, + u: torch.FloatTensor, + activation: nn.Module, +) -> torch.FloatTensor: + r""" + Evaluate the NTN interaction function. + + .. math:: + + f(h,r,t) = u_r^T act(h W_r t + V_r h + V_r' t + b_r) + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. + :param w: shape: (batch_size, 1, num_relations, 1, k, dim, dim) + The relation specific transformation matrix W_r. + :param vh: shape: (batch_size, 1, num_relations, 1, k, dim) + The head transformation matrix V_h. + :param vt: shape: (batch_size, 1, num_relations, 1, k, dim) + The tail transformation matrix V_h. + :param b: shape: (batch_size, 1, num_relations, 1, k) + The relation specific offset b_r. + :param u: shape: (batch_size, 1, num_relations, 1, k) + The relation specific final linear transformation b_r. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + :param activation: + The activation function. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. 
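    A minimal single-triple sketch of the same computation, assuming plain torch (the sizes d=3 and k=2 are illustrative):

    >>> import torch
    >>> d, k = 3, 2
    >>> h, t = torch.rand(d), torch.rand(d)
    >>> w, vh, vt = torch.rand(k, d, d), torch.rand(k, d), torch.rand(k, d)
    >>> b, u = torch.rand(k), torch.rand(k)
    >>> hidden = torch.tanh(torch.einsum("d,kde,e->k", h, w, t) + vh @ h + vt @ t + b)
    >>> score = u @ hidden  # u_r^T act(h W_r t + V_h h + V_t t + b_r)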
+ """ + x = activation(tensor_sum( + extended_einsum("bhrtd,bhrtkde,bhrte->bhrtk", h, w, t), + (vh @ h.unsqueeze(dim=-1)).squeeze(dim=-1), + (vt @ t.unsqueeze(dim=-1)).squeeze(dim=-1), + b, + )) + u = u.transpose(-2, -1) + return (x @ u).squeeze(dim=-1) + + +def proje_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, + d_e: torch.FloatTensor, + d_r: torch.FloatTensor, + b_c: torch.FloatTensor, + b_p: torch.FloatTensor, + activation: nn.Module, +) -> torch.FloatTensor: + r""" + Evaluate the ProjE interaction function. + + .. math:: + + f(h, r, t) = g(t z(D_e h + D_r r + b_c) + b_p) + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, dim) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + :param d_e: shape: (dim,) + Global entity projection. + :param d_r: shape: (dim,) + Global relation projection. + :param b_c: shape: (dim,) + Global combination bias. + :param b_p: shape: (1,) + Final score bias + :param activation: + The activation function. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + num_heads, num_relations, num_tails, dim, _ = _extract_sizes(h, r, t) + # global projections + h = h * d_e.view(1, 1, 1, 1, dim) + r = r * d_r.view(1, 1, 1, 1, dim) + # combination, shape: (b, h, r, 1, d) + x = tensor_sum(h, r, b_c) + x = activation(x) # shape: (b, h, r, 1, d) + # dot product with t, shape: (b, h, r, t) + t = t.transpose(-2, -1) # shape: (b, 1, 1, d, t) + return (x @ t).squeeze(dim=-2) + b_p + + +def rescal_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, +) -> torch.FloatTensor: + """ + Evaluate the RESCAL interaction function. + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, dim, dim) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + return extended_einsum("bhrtd,bhrtde,bhrte->bhrt", h, r, t) + + +def rotate_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, +) -> torch.FloatTensor: + """Evaluate the interaction function of RotatE for given embeddings. + + :param h: shape: (batch_size, num_heads, 1, 1, 2*dim) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, 2*dim) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, 2*dim) + The tail representations. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + # r expresses a rotation in complex plane. + h, r, t = [view_complex(x) for x in (h, r, t)] + if estimate_cost_of_sequence(h.shape, r.shape) < estimate_cost_of_sequence(r.shape, t.shape): + # rotate head by relation (=Hadamard product in complex space) + h = h * r + else: + # rotate tail by inverse of relation + # The inverse rotation is expressed by the complex conjugate of r. + # The score is computed as the distance of the relation-rotated head to the tail. + # Equivalently, we can rotate the tail by the inverse relation, and measure the distance to the head, i.e. 
+ # |h * r - t| = |h - conj(r) * t| + t = t * torch.conj(r) + + # Workaround until https://github.com/pytorch/pytorch/issues/30704 is fixed + return negative_norm(h - t, p=2, power_norm=False) + + +def simple_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, + h_inv: torch.FloatTensor, + r_inv: torch.FloatTensor, + t_inv: torch.FloatTensor, + clamp: Optional[Tuple[float, float]] = None, +) -> torch.FloatTensor: + """ + Evaluate the SimplE interaction function. + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, dim, dim) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + :param h_inv: shape: (batch_size, num_heads, 1, 1, dim) + The inverse head representations. + :param r_inv: shape: (batch_size, 1, num_relations, 1, dim, dim) + The relation representations. + :param t_inv: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + :param clamp: + Clamp the scores to the given range. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + scores = 0.5 * (distmult_interaction(h=h, r=r, t=t) + distmult_interaction(h=h_inv, r=r_inv, t=t_inv)) + # Note: In the code in their repository, the score is clamped to [-20, 20]. + # That is not mentioned in the paper, so it is made optional here. + if clamp: + min_, max_ = clamp + scores = scores.clamp(min=min_, max=max_) + return scores + + +def structured_embedding_interaction( + h: torch.FloatTensor, + r_h: torch.FloatTensor, + r_t: torch.FloatTensor, + t: torch.FloatTensor, + p: int, + power_norm: bool = False, +) -> torch.FloatTensor: + r""" + Evaluate the Structured Embedding interaction function. + + .. math :: + f(h, r, t) = -\|R_h h - R_t t\| + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. + :param r_h: shape: (batch_size, 1, num_relations, 1, rel_dim, dim) + The relation-specific head projection. + :param r_t: shape: (batch_size, 1, num_relations, 1, rel_dim, dim) + The relation-specific tail projection. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + :param p: + The p for the norm. cf. torch.norm. + :param power_norm: + Whether to return the powered norm. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + return negative_norm( + (r_h @ h.unsqueeze(dim=-1) - r_t @ t.unsqueeze(dim=-1)).squeeze(dim=-1), + p=p, + power_norm=power_norm, + ) + + +def transd_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, + h_p: torch.FloatTensor, + r_p: torch.FloatTensor, + t_p: torch.FloatTensor, + p: int, + power_norm: bool = False, +) -> torch.FloatTensor: + """ + Evaluate the TransD interaction function. + + :param h: shape: (batch_size, num_heads, 1, 1, d_e) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, d_r) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, d_e) + The tail representations. + :param h_p: shape: (batch_size, num_heads, 1, 1, d_e) + The head projections. + :param r_p: shape: (batch_size, 1, num_relations, 1, d_r) + The relation projections. + :param t_p: shape: (batch_size, 1, 1, num_tails, d_e) + The tail projections. + :param p: + The parameter p for selecting the norm. + :param power_norm: + Whether to return the powered norm instead. 
+ + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + # Project entities + h_bot = project_entity( + e=h, + e_p=h_p, + r_p=r_p, + ) + t_bot = project_entity( + e=t, + e_p=t_p, + r_p=r_p, + ) + return negative_norm_of_sum(h_bot, r, -t_bot, p=p, power_norm=power_norm) + + +def transe_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, + p: Union[int, str] = 2, + power_norm: bool = False, +) -> torch.FloatTensor: + """ + Evaluate the TransE interaction function. + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, dim) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + :param p: + The p for the norm. + :param power_norm: + Whether to return the powered norm. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + return negative_norm_of_sum(h, r, -t, p=p, power_norm=power_norm) + + +def transh_interaction( + h: torch.FloatTensor, + w_r: torch.FloatTensor, + d_r: torch.FloatTensor, + t: torch.FloatTensor, + p: int, + power_norm: bool = False, +) -> torch.FloatTensor: + """ + Evaluate the DistMult interaction function. + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. + :param w_r: shape: (batch_size, 1, num_relations, 1, dim) + The relation normal vector representations. + :param d_r: shape: (batch_size, 1, num_relations, 1, dim) + The relation difference vector representations. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + :param p: + The p for the norm. cf. torch.norm. + :param power_norm: + Whether to return $|x-y|_p^p$. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + return negative_norm_of_sum( + # h projection to hyperplane + h, + -(h * w_r).sum(dim=-1, keepdims=True) * w_r, + # r + d_r, + # -t projection to hyperplane + -t, + (t * w_r).sum(dim=-1, keepdims=True) * w_r, + p=p, + power_norm=power_norm, + ) + + +def transr_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, + m_r: torch.FloatTensor, + p: int, + power_norm: bool = True, +) -> torch.FloatTensor: + """Evaluate the interaction function for given embeddings. + + :param h: shape: (batch_size, num_heads, 1, 1, d_e) + Head embeddings. + :param r: shape: (batch_size, 1, num_relations, 1, d_r) + Relation embeddings. + :param m_r: shape: (batch_size, 1, num_relations, 1, d_e, d_r) + The relation specific linear transformations. + :param t: shape: (batch_size, 1, 1, num_tails, d_e) + Tail embeddings. + :param p: + The parameter p for selecting the norm. + :param power_norm: + Whether to return the powered norm instead. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. 
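    A minimal single-triple sketch of the projection and scoring, assuming plain torch (d_e=4 and d_r=3 are illustrative; the norm-clamping line mirrors clamp_norm with maxnorm=1):

    >>> import torch
    >>> d_e, d_r = 4, 3
    >>> h, t, r = torch.rand(d_e), torch.rand(d_e), torch.rand(d_r)
    >>> m_r = torch.rand(d_e, d_r)
    >>> h_bot, t_bot = h @ m_r, t @ m_r  # project into the relation-specific subspace
    >>> h_bot = h_bot / h_bot.norm(p=2).clamp(min=1.0)  # enforce ||h_bot||_2 <= 1
    >>> t_bot = t_bot / t_bot.norm(p=2).clamp(min=1.0)
    >>> score = -((h_bot + r - t_bot).norm(p=2) ** 2)  # negative squared norm (power_norm=True, p=2)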
+ """ + # project to relation specific subspace and ensure constraints + h_bot = clamp_norm((h.unsqueeze(dim=-2) @ m_r), p=2, dim=-1, maxnorm=1.).squeeze(dim=-2) + t_bot = clamp_norm((t.unsqueeze(dim=-2) @ m_r), p=2, dim=-1, maxnorm=1.).squeeze(dim=-2) + return negative_norm_of_sum(h_bot, r, -t_bot, p=p, power_norm=power_norm) + + +def tucker_interaction( + h: torch.FloatTensor, + r: torch.FloatTensor, + t: torch.FloatTensor, + core_tensor: torch.FloatTensor, + do_h: nn.Dropout, + do_r: nn.Dropout, + do_hr: nn.Dropout, + bn_h: Optional[nn.BatchNorm1d], + bn_hr: Optional[nn.BatchNorm1d], +) -> torch.FloatTensor: + r""" + Evaluate the TuckEr interaction function. + + Compute scoring function W x_1 h x_2 r x_3 t as in the official implementation, i.e. as + + .. math :: + + DO_{hr}(BN_{hr}(DO_h(BN_h(h)) x_1 DO_r(W x_2 r))) x_3 t + + where BN denotes BatchNorm and DO denotes Dropout + + :param h: shape: (batch_size, num_heads, 1, 1, d_e) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, d_r) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, d_e) + The tail representations. + :param core_tensor: shape: (d_e, d_r, d_e) + The core tensor. + :param do_h: + The dropout layer for the head representations. + :param do_r: + The first hidden dropout. + :param do_hr: + The second hidden dropout. + :param bn_h: + The first batch normalization layer. + :param bn_hr: + The second batch normalization layer. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + return extended_einsum( + # x_3 contraction + "bhrtk,bhrtk->bhrt", + _apply_optional_bn_to_tensor( + x=extended_einsum( + # x_1 contraction + "bhrtik,bhrti->bhrtk", + _apply_optional_bn_to_tensor( + x=extended_einsum( + # x_2 contraction + "ijk,bhrtj->bhrtik", + core_tensor, + r, + ), + output_dropout=do_r, + ), + _apply_optional_bn_to_tensor( + x=h, + batch_norm=bn_h, + output_dropout=do_h, + )), + batch_norm=bn_hr, + output_dropout=do_hr, + ), + t, + ) + + +def unstructured_model_interaction( + h: torch.FloatTensor, + t: torch.FloatTensor, + p: int, + power_norm: bool = True, +) -> torch.FloatTensor: + """ + Evaluate the SimplE interaction function. + + :param h: shape: (batch_size, num_heads, 1, 1, dim) + The head representations. + :param t: shape: (batch_size, 1, 1, num_tails, dim) + The tail representations. + :param p: + The parameter p for selecting the norm. + :param power_norm: + Whether to return the powered norm instead. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + return negative_norm(h - t, p=p, power_norm=power_norm) diff --git a/src/pykeen/nn/init.py b/src/pykeen/nn/init.py index 973b2b7c85..b01af783d6 100644 --- a/src/pykeen/nn/init.py +++ b/src/pykeen/nn/init.py @@ -5,6 +5,7 @@ import math import numpy as np +import torch import torch.nn import torch.nn.init diff --git a/src/pykeen/nn/modules.py b/src/pykeen/nn/modules.py new file mode 100644 index 0000000000..91eef5a97a --- /dev/null +++ b/src/pykeen/nn/modules.py @@ -0,0 +1,1040 @@ +# -*- coding: utf-8 -*- + +"""Stateful interaction functions.""" + +import logging +import math +from abc import ABC +from typing import ( + Any, Callable, Generic, Mapping, MutableMapping, Optional, Sequence, TYPE_CHECKING, Tuple, Type, + Union, +) + +import torch +from torch import FloatTensor, nn + +from . 
import functional as pkf +from .representation import CANONICAL_DIMENSIONS, convert_to_canonical_shape +from ..typing import HeadRepresentation, RelationRepresentation, TailRepresentation +from ..utils import ensure_tuple, upgrade_to_sequence + +if TYPE_CHECKING: + from ..typing import Representation # noqa + +__all__ = [ + # Base Classes + 'Interaction', + 'TranslationalInteraction', + # Concrete Classes + 'ComplExInteraction', + 'ConvEInteraction', + 'ConvKBInteraction', + 'DistMultInteraction', + 'ERMLPInteraction', + 'ERMLPEInteraction', + 'HolEInteraction', + 'KG2EInteraction', + 'NTNInteraction', + 'ProjEInteraction', + 'RESCALInteraction', + 'RotatEInteraction', + 'SimplEInteraction', + 'StructuredEmbeddingInteraction', + 'TransDInteraction', + 'TransEInteraction', + 'TransHInteraction', + 'TransRInteraction', + 'TuckerInteraction', + 'UnstructuredModelInteraction', +] + +logger = logging.getLogger(__name__) + + +def _get_prefix(slice_size, slice_dim, d) -> str: + if slice_size is None or slice_dim != d: + return 'b' + else: + return 'n' + + +def _get_batches(z, slice_size): + for batch in zip(*(hh.split(slice_size, dim=1) for hh in ensure_tuple(z)[0])): + if len(batch) == 1: + batch = batch[0] + yield batch + + +class Interaction(nn.Module, Generic[HeadRepresentation, RelationRepresentation, TailRepresentation], ABC): + """Base class for interaction functions.""" + + #: The symbolic shapes for entity representations + entity_shape: Sequence[str] = ("d",) + + #: The symbolic shapes for entity representations for tail entities, if different. This is ony relevant for ConvE. + tail_entity_shape: Optional[Sequence[str]] = None + + #: The symbolic shapes for relation representations + relation_shape: Sequence[str] = ("d",) + + #: The functional interaction form + func: Callable[..., torch.FloatTensor] + + @classmethod + def from_func(cls, f) -> 'Interaction': + """Create an instance of a stateless interaction class.""" + return cls.cls_from_func(f)() + + @classmethod + def cls_from_func(cls, f) -> Type['Interaction']: + """Create a stateless interaction class.""" + + class StatelessInteraction(cls): # type: ignore + func = f + + return StatelessInteraction + + @staticmethod + def _prepare_hrt_for_functional( + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + ) -> MutableMapping[str, torch.FloatTensor]: + """Conversion utility to prepare the h/r/t representations for the functional form.""" + assert all(torch.is_tensor(x) for x in (h, r, t)) + return dict(h=h, r=r, t=t) + + def _prepare_state_for_functional(self) -> MutableMapping[str, Any]: + """Conversion utility to prepare the state to be passed to the functional form.""" + return dict() + + def _prepare_for_functional( + self, + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + ) -> Mapping[str, torch.FloatTensor]: + """Conversion utility to prepare the arguments for the functional form.""" + kwargs = self._prepare_hrt_for_functional(h=h, r=r, t=t) + kwargs.update(self._prepare_state_for_functional()) + return kwargs + + def forward( + self, + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + ) -> torch.FloatTensor: + """Compute broadcasted triple scores given broadcasted representations for head, relation and tails. + + :param h: shape: (batch_size, num_heads, 1, 1, ``*``) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, ``*``) + The relation representations. 
+ :param t: shape: (batch_size, 1, 1, num_tails, ``*``) + The tail representations. + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + return self.__class__.func(**self._prepare_for_functional(h=h, r=r, t=t)) + + def score( + self, + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + slice_size: Optional[int] = None, + slice_dim: Optional[str] = None, + ) -> torch.FloatTensor: + """ + Compute broadcasted triple scores with optional slicing. + + .. note :: + At most one of the slice sizes may be not None. + + :param h: shape: (batch_size, num_heads, `1, 1, `*``) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, ``*``) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, ``*``) + The tail representations. + :param slice_size: + The slice size. + :param slice_dim: + The dimension along which to slice. From {"h", "r", "t"} + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + return self._forward_slicing_wrapper(h=h, r=r, t=t, slice_size=slice_size, slice_dim=slice_dim) + + def _score( + self, + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + slice_size: Optional[int] = None, + slice_dim: str = None, + ) -> torch.FloatTensor: + """ + Compute scores for the score_* methods outside of models. + + TODO: merge this with the Model utilities? + + :param h: shape: (b, h, *) + :param r: shape: (b, r, *) + :param t: shape: (b, t, *) + :param slice_size: ... + :param slice_dim: ... + :return: shape: (b, h, r, t) + """ + args = [] + for key, x in zip("hrt", (h, r, t)): + value = [] + for xx in upgrade_to_sequence(x): # type: torch.FloatTensor + # bring to (b, n, *) + xx = xx.unsqueeze(dim=1 if key != slice_dim else 0) + # bring to (b, h, r, t, *) + xx = convert_to_canonical_shape( + x=xx, + dim=key, + num=xx.shape[1], + batch_size=xx.shape[0], + suffix_shape=xx.shape[2:], + ) + value.append(xx) + # unpack singleton + if len(value) == 1: + value = value[0] + args.append(value) + h, r, t = args + return self._forward_slicing_wrapper(h=h, r=r, t=t, slice_dim=slice_dim, slice_size=slice_size) + + def _forward_slicing_wrapper( + self, + h: Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]], + r: Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]], + t: Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]], + slice_size: Optional[int], + slice_dim: Optional[str], + ) -> torch.FloatTensor: + """ + Compute broadcasted triple scores with optional slicing for representations in canonical shape. + + .. note :: + Depending on the interaction function, there may be more than one representation for h/r/t. In that case, + a tuple of at least two tensors is passed. + + :param h: shape: (batch_size, num_heads, 1, 1, ``*``) + The head representations. + :param r: shape: (batch_size, 1, num_relations, 1, ``*``) + The relation representations. + :param t: shape: (batch_size, 1, 1, num_tails, ``*``) + The tail representations. + :param slice_size: + The slice size. + :param slice_dim: + The dimension along which to slice. From {"h", "r", "t"} + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + + :raises ValueError: + If slice_dim is invalid. 
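
A standalone illustration of the slicing contract documented here: chunking one of the h/r/t arguments along its canonical dimension and concatenating the partial score tensors reproduces the unsliced result. Plain torch; `toy_score` is a hypothetical stand-in for an interaction function, not part of this diff.

```python
import torch


def toy_score(h, r, t):
    # h, r: (b, 1, 1, 1, d), t: (b, 1, 1, num_tails, d) -> (b, 1, 1, num_tails)
    return -(h + r - t).norm(dim=-1)


b, d, nt = 2, 8, 10
h = torch.rand(b, 1, 1, 1, d)
r = torch.rand(b, 1, 1, 1, d)
t = torch.rand(b, 1, 1, nt, d)

full = toy_score(h, r, t)
sliced = torch.cat(
    [toy_score(h, r, t_batch) for t_batch in t.split(4, dim=3)],
    dim=3,  # the tail dimension of the canonical (b, h, r, t) layout
)
assert torch.allclose(full, sliced)
```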
+        """
+        if slice_size is None:
+            scores = self(h=h, r=r, t=t)
+        elif slice_dim == "h":
+            scores = torch.cat([
+                self(h=h_batch, r=r, t=t)
+                for h_batch in _get_batches(h, slice_size)
+            ], dim=CANONICAL_DIMENSIONS[slice_dim])
+        elif slice_dim == "r":
+            scores = torch.cat([
+                self(h=h, r=r_batch, t=t)
+                for r_batch in _get_batches(r, slice_size)
+            ], dim=CANONICAL_DIMENSIONS[slice_dim])
+        elif slice_dim == "t":
+            scores = torch.cat([
+                self(h=h, r=r, t=t_batch)
+                for t_batch in _get_batches(t, slice_size)
+            ], dim=CANONICAL_DIMENSIONS[slice_dim])
+        else:
+            raise ValueError(f'Invalid slice_dim: {slice_dim}')
+        return scores
+
+    def score_hrt(
+        self,
+        h: HeadRepresentation,
+        r: RelationRepresentation,
+        t: TailRepresentation,
+    ) -> torch.FloatTensor:
+        """
+        Score a batch of triples.
+
+        :param h: shape: (batch_size, d_e)
+            The head representations.
+        :param r: shape: (batch_size, d_r)
+            The relation representations.
+        :param t: shape: (batch_size, d_e)
+            The tail representations.
+
+        :return: shape: (batch_size, 1)
+            The scores.
+        """
+        return self._score(h=h, r=r, t=t)[:, 0, 0, 0, None]
+
+    def score_h(
+        self,
+        all_entities: HeadRepresentation,
+        r: RelationRepresentation,
+        t: TailRepresentation,
+        slice_size: Optional[int] = None,
+    ) -> torch.FloatTensor:
+        """
+        Score all head entities.
+
+        :param all_entities: shape: (num_entities, d_e)
+            The head representations.
+        :param r: shape: (batch_size, d_r)
+            The relation representations.
+        :param t: shape: (batch_size, d_e)
+            The tail representations.
+        :param slice_size:
+            The slice size.
+
+        :return: shape: (batch_size, num_entities)
+            The scores.
+        """
+        return self._score(h=all_entities, r=r, t=t, slice_dim="h", slice_size=slice_size)[:, :, 0, 0]
+
+    def score_r(
+        self,
+        h: HeadRepresentation,
+        all_relations: RelationRepresentation,
+        t: TailRepresentation,
+        slice_size: Optional[int] = None,
+    ) -> torch.FloatTensor:
+        """
+        Score all relations.
+
+        :param h: shape: (batch_size, d_e)
+            The head representations.
+        :param all_relations: shape: (num_relations, d_r)
+            The relation representations.
+        :param t: shape: (batch_size, d_e)
+            The tail representations.
+        :param slice_size:
+            The slice size.
+
+        :return: shape: (batch_size, num_relations)
+            The scores.
+        """
+        return self._score(h=h, r=all_relations, t=t, slice_dim="r", slice_size=slice_size)[:, 0, :, 0]
+
+    def score_t(
+        self,
+        h: HeadRepresentation,
+        r: RelationRepresentation,
+        all_entities: TailRepresentation,
+        slice_size: Optional[int] = None,
+    ) -> torch.FloatTensor:
+        """
+        Score all tail entities.
+
+        :param h: shape: (batch_size, d_e)
+            The head representations.
+        :param r: shape: (batch_size, d_r)
+            The relation representations.
+        :param all_entities: shape: (num_entities, d_e)
+            The tail representations.
+        :param slice_size:
+            The slice size.
+
+        :return: shape: (batch_size, num_entities)
+            The scores.
+        """
+        return self._score(h=h, r=r, t=all_entities, slice_dim="t", slice_size=slice_size)[:, 0, 0, :]
+
+    def reset_parameters(self):
+        """Reset the parameters the interaction function may have."""
+        for mod in self.modules():
+            if mod is self:
+                continue
+            if hasattr(mod, 'reset_parameters'):
+                mod.reset_parameters()
+
+
+class TranslationalInteraction(Interaction[HeadRepresentation, RelationRepresentation, TailRepresentation], ABC):
+    """The translational interaction function shared by the TransE, TransR, TransH, and other Trans models."""
+
+    def __init__(self, p: int, power_norm: bool = False):
+        """Initialize the translational interaction function.
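
Referring back to the `score_*` reductions above, a shape-only sketch of how the canonical (batch_size, num_heads, num_relations, num_tails) tensor is indexed down to the documented output shapes (plain torch, illustrative values only):

```python
import torch

b, n = 4, 11
hrt_scores = torch.rand(b, 1, 1, 1)   # one (h, r, t) triple per batch entry
t_scores = torch.rand(b, 1, 1, n)     # all candidate tails per (h, r) pair

assert hrt_scores[:, 0, 0, 0, None].shape == (b, 1)  # score_hrt -> (batch_size, 1)
assert t_scores[:, 0, 0, :].shape == (b, n)          # score_t   -> (batch_size, num_entities)
```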
+ + :param p: + The norm used with :func:`torch.norm`. Typically is 1 or 2. + :param power_norm: + Whether to use the p-th power of the L_p norm. It has the advantage of being differentiable around 0, + and numerically more stable. + """ + super().__init__() + self.p = p + self.power_norm = power_norm + + def _prepare_state_for_functional(self) -> MutableMapping[str, Any]: # noqa: D102 + return dict(p=self.p, power_norm=self.power_norm) + + +class TransEInteraction(TranslationalInteraction[FloatTensor, FloatTensor, FloatTensor]): + """The TransE interaction function.""" + + func = pkf.transe_interaction + + +class ComplExInteraction(Interaction[FloatTensor, FloatTensor, FloatTensor]): + """Interaction function of ComplEx.""" + + func = pkf.complex_interaction + + +def _calculate_missing_shape_information( + embedding_dim: int, + input_channels: Optional[int] = None, + width: Optional[int] = None, + height: Optional[int] = None, +) -> Tuple[int, int, int]: + """Automatically calculates missing dimensions for ConvE. + + :param embedding_dim: + The embedding dimension. + :param input_channels: + The number of input channels for the convolution. + :param width: + The width of the embedding "image". + :param height: + The height of the embedding "image". + + :return: (input_channels, width, height), such that + `embedding_dim = input_channels * width * height` + + :raises ValueError: + If no factorization could be found. + """ + # Store initial input for error message + original = (input_channels, width, height) + + # All are None -> try and make closest to square + if input_channels is None and width is None and height is None: + input_channels = 1 + result_sqrt = math.floor(math.sqrt(embedding_dim)) + height = max(factor for factor in range(1, result_sqrt + 1) if embedding_dim % factor == 0) + width = embedding_dim // height + # Only input channels is None + elif input_channels is None and width is not None and height is not None: + input_channels = embedding_dim // (width * height) + # Only width is None + elif input_channels is not None and width is None and height is not None: + width = embedding_dim // (height * input_channels) + # Only height is none + elif height is None and width is not None and input_channels is not None: + height = embedding_dim // (width * input_channels) + # Width and input_channels are None -> set input_channels to 1 and calculage height + elif input_channels is None and height is None and width is not None: + input_channels = 1 + height = embedding_dim // width + # Width and input channels are None -> set input channels to 1 and calculate width + elif input_channels is None and height is not None and width is None: + input_channels = 1 + width = embedding_dim // height + + if input_channels * width * height != embedding_dim: # type: ignore + raise ValueError(f'Could not resolve {original} to a valid factorization of {embedding_dim}.') + + return input_channels, width, height # type: ignore + + +class ConvEInteraction(Interaction[torch.FloatTensor, torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]]): + """ConvE interaction function.""" + + tail_entity_shape = ("d", "k") # with k=1 + + #: The head-relation encoder operating on 2D "images" + hr2d: nn.Module + + #: The head-relation encoder operating on the 1D flattened version + hr1d: nn.Module + + #: The interaction function + func = pkf.conve_interaction + + def __init__( + self, + input_channels: Optional[int] = None, + output_channels: int = 32, + embedding_height: Optional[int] = None, + 
embedding_width: Optional[int] = None, + kernel_height: int = 3, + kernel_width: int = 3, + input_dropout: float = 0.2, + output_dropout: float = 0.3, + feature_map_dropout: float = 0.2, + embedding_dim: int = 200, + apply_batch_normalization: bool = True, + ): + super().__init__() + + # Automatic calculation of remaining dimensions + logger.info(f'Resolving {input_channels} * {embedding_width} * {embedding_height} = {embedding_dim}.') + if embedding_dim is None: + embedding_dim = input_channels * embedding_width * embedding_height + + # Parameter need to fulfil: + # input_channels * embedding_height * embedding_width = embedding_dim + input_channels, embedding_width, embedding_height = _calculate_missing_shape_information( + embedding_dim=embedding_dim, + input_channels=input_channels, + width=embedding_width, + height=embedding_height, + ) + logger.info(f'Resolved to {input_channels} * {embedding_width} * {embedding_height} = {embedding_dim}.') + + if input_channels * embedding_height * embedding_width != embedding_dim: + raise ValueError( + f'Product of input channels ({input_channels}), height ({embedding_height}), and width ' + f'({embedding_width}) does not equal target embedding dimension ({embedding_dim})', + ) + + # encoders + # 1: 2D encoder: BN?, DO, Conv, BN?, Act, DO + hr2d_layers = [ + nn.BatchNorm2d(input_channels) if apply_batch_normalization else None, + nn.Dropout(input_dropout), + nn.Conv2d( + in_channels=input_channels, + out_channels=output_channels, + kernel_size=(kernel_height, kernel_width), + stride=1, + padding=0, + bias=True, + ), + nn.BatchNorm2d(output_channels) if apply_batch_normalization else None, + nn.ReLU(), + nn.Dropout2d(feature_map_dropout), + ] + self.hr2d = nn.Sequential(*(layer for layer in hr2d_layers if layer is not None)) + + # 2: 1D encoder: FC, DO, BN?, Act + num_in_features = ( + output_channels + * (2 * embedding_height - kernel_height + 1) + * (embedding_width - kernel_width + 1) + ) + hr1d_layers = [ + nn.Linear(num_in_features, embedding_dim), + nn.Dropout(output_dropout), + nn.BatchNorm1d(embedding_dim) if apply_batch_normalization else None, + nn.ReLU(), + ] + self.hr1d = nn.Sequential(*(layer for layer in hr1d_layers if layer is not None)) + + # store reshaping dimensions + self.embedding_height = embedding_height + self.embedding_width = embedding_width + self.input_channels = input_channels + + @staticmethod + def _prepare_hrt_for_functional( + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + ) -> MutableMapping[str, torch.FloatTensor]: # noqa: D102 + return dict(h=h, r=r, t=t[0], t_bias=t[1]) + + def _prepare_state_for_functional(self) -> MutableMapping[str, Any]: # noqa: D102 + return dict( + input_channels=self.input_channels, + embedding_height=self.embedding_height, + embedding_width=self.embedding_width, + hr2d=self.hr2d, + hr1d=self.hr1d, + ) + + +class ConvKBInteraction(Interaction[FloatTensor, FloatTensor, FloatTensor]): + """Interaction function of ConvKB. + + .. 
seealso:: :func:`pykeen.nn.functional.convkb_interaction`` + """ + + func = pkf.convkb_interaction + + def __init__( + self, + hidden_dropout_rate: float = 0., + embedding_dim: int = 200, + num_filters: int = 400, + ): + super().__init__() + self.embedding_dim = embedding_dim + self.num_filters = num_filters + + # The interaction model + self.conv = nn.Conv2d(in_channels=1, out_channels=num_filters, kernel_size=(1, 3), bias=True) + self.activation = nn.ReLU() + self.hidden_dropout = nn.Dropout(p=hidden_dropout_rate) + self.linear = nn.Linear(embedding_dim * num_filters, 1, bias=True) + + def reset_parameters(self): # noqa: D102 + # Use Xavier initialization for weight; bias to zero + nn.init.xavier_uniform_(self.linear.weight, gain=nn.init.calculate_gain('relu')) + nn.init.zeros_(self.linear.bias) + + # Initialize all filters to [0.1, 0.1, -0.1], + # c.f. https://github.com/daiquocnguyen/ConvKB/blob/master/model.py#L34-L36 + nn.init.constant_(self.conv.weight[..., :2], 0.1) + nn.init.constant_(self.conv.weight[..., 2], -0.1) + nn.init.zeros_(self.conv.bias) + + def _prepare_state_for_functional(self) -> MutableMapping[str, Any]: # noqa: D102 + return dict( + conv=self.conv, + activation=self.activation, + hidden_dropout=self.hidden_dropout, + linear=self.linear, + ) + + +class DistMultInteraction(Interaction[FloatTensor, FloatTensor, FloatTensor]): + """A module wrapping the DistMult interaction function at :func:`pykeen.nn.functional.distmult_interaction`.""" + + func = pkf.distmult_interaction + + +class ERMLPInteraction(Interaction[FloatTensor, FloatTensor, FloatTensor]): + """A module wrapping the ER-MLP interaction function from :func:`pykeen.nn.functional.ermlp_interaction`. + + .. math :: + f(h, r, t) = W_2 ReLU(W_1 cat(h, r, t) + b_1) + b_2 + """ + + func = pkf.ermlp_interaction + + def __init__( + self, + embedding_dim: int, + hidden_dim: int, + ): + """Initialize the interaction function. + + :param embedding_dim: + The embedding vector dimension. + :param hidden_dim: + The hidden dimension of the MLP. + """ + super().__init__() + """The multi-layer perceptron consisting of an input layer with 3 * self.embedding_dim neurons, a hidden layer + with self.embedding_dim neurons and output layer with one neuron. + The input is represented by the concatenation embeddings of the heads, relations and tail embeddings. 
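
A hedged, self-contained sketch of the ER-MLP scoring form described here, i.e. W_2 ReLU(W_1 [h; r; t] + b_1) + b_2, using plain torch modules rather than the class defined in this diff:

```python
import torch
from torch import nn

embedding_dim, hidden_dim, batch = 8, 16, 4
hidden = nn.Linear(3 * embedding_dim, hidden_dim)   # W_1, b_1
out = nn.Linear(hidden_dim, 1)                      # W_2, b_2

h, r, t = (torch.rand(batch, embedding_dim) for _ in range(3))
score = out(torch.relu(hidden(torch.cat([h, r, t], dim=-1))))
assert score.shape == (batch, 1)
```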
+ """ + self.hidden = nn.Linear(in_features=3 * embedding_dim, out_features=hidden_dim, bias=True) + self.activation = nn.ReLU() + self.hidden_to_score = nn.Linear(in_features=hidden_dim, out_features=1, bias=True) + + def _prepare_state_for_functional(self) -> MutableMapping[str, Any]: # noqa: D102 + return dict( + hidden=self.hidden, + activation=self.activation, + final=self.hidden_to_score, + ) + + def reset_parameters(self): # noqa: D102 + # Initialize biases with zero + nn.init.zeros_(self.hidden.bias) + nn.init.zeros_(self.hidden_to_score.bias) + # In the original formulation, + nn.init.xavier_uniform_(self.hidden.weight) + nn.init.xavier_uniform_( + self.hidden_to_score.weight, + gain=nn.init.calculate_gain(self.activation.__class__.__name__.lower()), + ) + + +class ERMLPEInteraction(Interaction[FloatTensor, FloatTensor, FloatTensor]): + """Interaction function of ER-MLP (E).""" + + func = pkf.ermlpe_interaction + + def __init__( + self, + hidden_dim: int = 300, + input_dropout: float = 0.2, + hidden_dropout: float = 0.3, + embedding_dim: int = 200, + ): + super().__init__() + self.mlp = nn.Sequential( + nn.Dropout(input_dropout), + nn.Linear(2 * embedding_dim, hidden_dim), + nn.Dropout(hidden_dropout), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, embedding_dim), + nn.Dropout(hidden_dropout), + nn.BatchNorm1d(embedding_dim), + nn.ReLU(), + ) + + def _prepare_state_for_functional(self) -> MutableMapping[str, Any]: # noqa: D102 + return dict(mlp=self.mlp) + + +class TransRInteraction( + TranslationalInteraction[ + torch.FloatTensor, + Tuple[torch.FloatTensor, torch.FloatTensor], + torch.FloatTensor, + ], +): + """The TransR interaction function.""" + + relation_shape = ("e", "de") + func = pkf.transr_interaction + + def __init__(self, p: int, power_norm: bool = True): + super().__init__(p=p, power_norm=power_norm) + + @staticmethod + def _prepare_hrt_for_functional( + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + ) -> MutableMapping[str, torch.FloatTensor]: # noqa: D102 + return dict(h=h, r=r[0], t=t, m_r=r[1]) + + +class RotatEInteraction(Interaction[FloatTensor, FloatTensor, FloatTensor]): + """Interaction function of RotatE.""" + + func = pkf.rotate_interaction + + +class HolEInteraction(Interaction[FloatTensor, FloatTensor, FloatTensor]): + """Interaction function for HolE.""" + + func = pkf.hole_interaction + + +class ProjEInteraction(Interaction[FloatTensor, FloatTensor, FloatTensor]): + """Interaction function for ProjE.""" + + func = pkf.proje_interaction + + def __init__( + self, + embedding_dim: int = 50, + inner_non_linearity: Optional[nn.Module] = None, + ): + super().__init__() + + # Global entity projection + self.d_e = nn.Parameter(torch.empty(embedding_dim), requires_grad=True) + + # Global relation projection + self.d_r = nn.Parameter(torch.empty(embedding_dim), requires_grad=True) + + # Global combination bias + self.b_c = nn.Parameter(torch.empty(embedding_dim), requires_grad=True) + + # Global combination bias + self.b_p = nn.Parameter(torch.empty(1), requires_grad=True) + + if inner_non_linearity is None: + inner_non_linearity = nn.Tanh() + self.inner_non_linearity = inner_non_linearity + + def reset_parameters(self): # noqa: D102 + embedding_dim = self.d_e.shape[0] + bound = math.sqrt(6) / embedding_dim + for p in self.parameters(): + nn.init.uniform_(p, a=-bound, b=bound) + + def _prepare_state_for_functional(self) -> MutableMapping[str, Any]: + return dict(d_e=self.d_e, d_r=self.d_r, b_c=self.b_c, 
b_p=self.b_p, activation=self.inner_non_linearity) + + +class RESCALInteraction(Interaction[FloatTensor, FloatTensor, FloatTensor]): + """Interaction function of RESCAL.""" + + relation_shape = ("dd",) + func = pkf.rescal_interaction + + +class StructuredEmbeddingInteraction( + TranslationalInteraction[ + torch.FloatTensor, + Tuple[torch.FloatTensor, torch.FloatTensor], + torch.FloatTensor, + ], +): + """Interaction function of Structured Embedding.""" + + relation_shape = ("dd", "dd") + func = pkf.structured_embedding_interaction + + @staticmethod + def _prepare_hrt_for_functional( + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + ) -> MutableMapping[str, torch.FloatTensor]: # noqa: D102 + return dict(h=h, t=t, r_h=r[0], r_t=r[1]) + + +class TuckerInteraction(Interaction[FloatTensor, FloatTensor, FloatTensor]): + """Interaction function of Tucker.""" + + func = pkf.tucker_interaction + + def __init__( + self, + embedding_dim: int = 200, + relation_dim: Optional[int] = None, + head_dropout: float = 0.3, + relation_dropout: float = 0.4, + head_relation_dropout: float = 0.5, + apply_batch_normalization: bool = True, + ): + """Initialize the Tucker interaction function. + + :param embedding_dim: + The entity embedding dimension. + :param relation_dim: + The relation embedding dimension. + :param head_dropout: + The dropout rate applied to the head representations. + :param relation_dropout: + The dropout rate applied to the relation representations. + :param head_relation_dropout: + The dropout rate applied to the combined head and relation representations. + :param apply_batch_normalization: + Whether to use batch normalization on head representations and the combination of head and relation. + """ + super().__init__() + + if relation_dim is None: + relation_dim = embedding_dim + + # Core tensor + # Note: we use a different dimension permutation as in the official implementation to match the paper. + self.core_tensor = nn.Parameter( + torch.empty(embedding_dim, relation_dim, embedding_dim), + requires_grad=True, + ) + + # Dropout + self.head_dropout = nn.Dropout(head_dropout) + self.relation_dropout = nn.Dropout(relation_dropout) + self.head_relation_dropout = nn.Dropout(head_relation_dropout) + + if apply_batch_normalization: + self.head_batch_norm = nn.BatchNorm1d(embedding_dim) + self.head_relation_batch_norm = nn.BatchNorm1d(embedding_dim) + else: + self.head_batch_norm = self.head_relation_batch_norm = None + + def reset_parameters(self): # noqa:D102 + # Initialize core tensor, cf. https://github.com/ibalazevic/TuckER/blob/master/model.py#L12 + nn.init.uniform_(self.core_tensor, -1., 1.) 
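
As an aside on the core tensor initialized above, the TuckER contraction W ×_1 h ×_2 r ×_3 t for a single triple can be written as one einsum; the sketch below is illustrative only and ignores the dropout and batch normalization applied in the functional form.

```python
import torch

d_e, d_r = 6, 4
W = torch.rand(d_e, d_r, d_e)          # core tensor of shape (d_e, d_r, d_e)
h, t = torch.rand(d_e), torch.rand(d_e)
r = torch.rand(d_r)

# Contract head along mode 1, relation along mode 2, tail along mode 3.
score = torch.einsum("i,ijk,j,k->", h, W, r, t)
assert score.ndim == 0
```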
+ # batch norm gets reset automatically, since it defines reset_parameters + + def _prepare_state_for_functional(self) -> MutableMapping[str, Any]: + return dict( + core_tensor=self.core_tensor, + do_h=self.head_dropout, + do_r=self.relation_dropout, + do_hr=self.head_relation_dropout, + bn_h=self.head_batch_norm, + bn_hr=self.head_relation_batch_norm, + ) + + +class UnstructuredModelInteraction( + TranslationalInteraction[torch.FloatTensor, None, torch.FloatTensor], +): + """Interaction function of UnstructuredModel.""" + + # shapes + relation_shape: Sequence[str] = tuple() + + func = pkf.unstructured_model_interaction + + def __init__(self, p: int, power_norm: bool = True): + super().__init__(p=p, power_norm=power_norm) + + @staticmethod + def _prepare_hrt_for_functional( + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + ) -> MutableMapping[str, torch.FloatTensor]: # noqa: D102 + return dict(h=h, t=t) + + +class TransDInteraction( + TranslationalInteraction[ + Tuple[torch.FloatTensor, torch.FloatTensor], + Tuple[torch.FloatTensor, torch.FloatTensor], + Tuple[torch.FloatTensor, torch.FloatTensor], + ], +): + """Interaction function of TransD.""" + + entity_shape = ("d", "d") + relation_shape = ("e", "e") + func = pkf.transd_interaction + + def __init__(self, p: int = 2, power_norm: bool = True): + super().__init__(p=p, power_norm=power_norm) + + @staticmethod + def _prepare_hrt_for_functional( + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + ) -> MutableMapping[str, torch.FloatTensor]: # noqa: D102 + h, h_p = h + r, r_p = r + t, t_p = t + return dict(h=h, r=r, t=t, h_p=h_p, r_p=r_p, t_p=t_p) + + +class NTNInteraction( + Interaction[ + torch.FloatTensor, + Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor], + torch.FloatTensor, + ], +): + """The interaction function of NTN.""" + + relation_shape = ("kdd", "kd", "kd", "k", "k") + func = pkf.ntn_interaction + + def __init__( + self, + non_linearity: Optional[nn.Module] = None, + ): + super().__init__() + if non_linearity is None: + non_linearity = nn.Tanh() + self.non_linearity = non_linearity + + @staticmethod + def _prepare_hrt_for_functional( + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + ) -> MutableMapping[str, torch.FloatTensor]: # noqa: D102 + w, vh, vt, b, u = r + return dict(h=h, t=t, w=w, b=b, u=u, vh=vh, vt=vt) + + def _prepare_state_for_functional(self) -> MutableMapping[str, Any]: # noqa: D102 + return dict(activation=self.non_linearity) + + +class KG2EInteraction( + Interaction[ + Tuple[torch.FloatTensor, torch.FloatTensor], + Tuple[torch.FloatTensor, torch.FloatTensor], + Tuple[torch.FloatTensor, torch.FloatTensor], + ], +): + """Interaction function of KG2E.""" + + entity_shape = ("d", "d") + relation_shape = ("d", "d") + similarity: str + exact: bool + func = pkf.kg2e_interaction + + def __init__( + self, + similarity: Optional[str] = None, + exact: bool = True, + ): + super().__init__() + if similarity is None: + similarity = 'KL' + self.similarity = similarity + self.exact = exact + + @staticmethod + def _prepare_hrt_for_functional( + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + ) -> MutableMapping[str, torch.FloatTensor]: + h_mean, h_var = h + r_mean, r_var = r + t_mean, t_var = t + return dict( + h_mean=h_mean, + h_var=h_var, + r_mean=r_mean, + r_var=r_var, + t_mean=t_mean, + t_var=t_var, + ) + + def _prepare_state_for_functional(self) -> 
MutableMapping[str, Any]: + return dict( + similarity=self.similarity, + exact=self.exact, + ) + + +class TransHInteraction(TranslationalInteraction[FloatTensor, Tuple[FloatTensor, FloatTensor], FloatTensor]): + """Interaction function of TransH.""" + + relation_shape = ("d", "d") + func = pkf.transh_interaction + + @staticmethod + def _prepare_hrt_for_functional( + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + ) -> MutableMapping[str, torch.FloatTensor]: # noqa: D102 + return dict(h=h, w_r=r[0], d_r=r[1], t=t) + + +class SimplEInteraction( + Interaction[ + Tuple[torch.FloatTensor, torch.FloatTensor], + Tuple[torch.FloatTensor, torch.FloatTensor], + Tuple[torch.FloatTensor, torch.FloatTensor], + ], +): + """Interaction function of SimplE.""" + + func = pkf.simple_interaction + entity_shape = ("d", "d") + relation_shape = ("d", "d") + + def __init__(self, clamp_score: Union[None, float, Tuple[float, float]] = None): + super().__init__() + if isinstance(clamp_score, float): + clamp_score = (-clamp_score, clamp_score) + self.clamp_score = clamp_score + + def _prepare_state_for_functional(self) -> MutableMapping[str, Any]: # noqa: D102 + return dict(clamp=self.clamp_score) + + @staticmethod + def _prepare_hrt_for_functional( + h: HeadRepresentation, + r: RelationRepresentation, + t: TailRepresentation, + ) -> MutableMapping[str, torch.FloatTensor]: # noqa: D102 + return dict(h=h[0], h_inv=h[1], r=r[0], r_inv=r[1], t=t[0], t_inv=t[1]) diff --git a/src/pykeen/nn/representation.py b/src/pykeen/nn/representation.py new file mode 100644 index 0000000000..efec14888d --- /dev/null +++ b/src/pykeen/nn/representation.py @@ -0,0 +1,765 @@ +# -*- coding: utf-8 -*- + +"""Embedding modules.""" + +import dataclasses +import functools +import logging +from abc import ABC, abstractmethod +from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union + +import numpy +import torch +import torch.nn +from torch import nn +from torch.nn import functional + +from ..regularizers import Regularizer +from ..triples import TriplesFactory +from ..typing import Constrainer, Initializer, Normalizer +from ..utils import upgrade_to_sequence + +__all__ = [ + 'RepresentationModule', + 'Embedding', + 'EmbeddingSpecification', + 'LiteralRepresentations', + 'RGCNRepresentations', +] + +logger = logging.getLogger(__name__) + +CANONICAL_DIMENSIONS = dict(h=1, r=2, t=3) + + +def _normalize_dim(dim: Union[int, str]) -> int: + """Normalize the dimension selection.""" + if isinstance(dim, int): + return dim + return CANONICAL_DIMENSIONS[dim.lower()[0]] + + +def get_expected_canonical_shape( + indices: Union[None, int, Tuple[int, int], torch.LongTensor], + dim: Union[str, int], + suffix_shape: Union[int, Sequence[int]], + num: Optional[int] = None, +) -> Tuple[int, ...]: + """ + Calculate the expected canonical shape for the given parameters. + + :param indices: + The indices, their shape, or None, if no indices are to be provided. + :param dim: + The dimension, either symbolic, or numeric. + :param suffix_shape: + The suffix-shape. + :param num: + The number of representations, if indices_shape is None, i.e. 1-n scoring. + + :return: (batch_size, num_heads, num_relations, num_tails, ``*``). + The expected shape, a tuple of at least 5 positive integers. 
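
A standalone sketch of the bookkeeping this docstring describes, assuming the CANONICAL_DIMENSIONS = dict(h=1, r=2, t=3) convention from this module; `expected_shape` is a simplified, hypothetical helper, not the function defined here:

```python
from typing import Optional, Tuple

CANONICAL = dict(h=1, r=2, t=3)


def expected_shape(dim: str, batch_size: int, num: Optional[int], d: int) -> Tuple[int, ...]:
    """Return (batch_size, num_heads, num_relations, num_tails, d) with `num` in slot `dim`."""
    shape = [batch_size, 1, 1, 1, d]
    if num is not None:  # 1-n scoring along `dim`
        shape[CANONICAL[dim]] = num
    return tuple(shape)


# e.g. scoring all 14 candidate tails for a batch of 2 (h, r) pairs with 8-dim vectors:
assert expected_shape("t", batch_size=2, num=14, d=8) == (2, 1, 1, 14, 8)
```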
+ """ + if isinstance(suffix_shape, int): + exp_shape = [1, 1, 1, 1, suffix_shape] + else: + exp_shape = [1, 1, 1, 1, *suffix_shape] + dim = _normalize_dim(dim=dim) + if indices is None: # 1-n scoring + exp_shape[dim] = num # type: ignore + else: # batch dimension + if isinstance(indices, torch.Tensor): + indices = indices.shape + exp_shape[0] = indices[0] + if len(indices) > 1: # multi-target batching + exp_shape[dim] = indices[1] + return tuple(exp_shape) + + +def convert_to_canonical_shape( + x: torch.FloatTensor, + dim: Union[int, str], + num: Optional[int] = None, + batch_size: int = 1, + suffix_shape: Union[int, Sequence[int]] = -1, +) -> torch.FloatTensor: + """ + Convert a tensor to canonical shape. + + :param x: + The tensor in compatible shape. + :param dim: + The "num" dimension. + :param batch_size: + The batch size. + :param num: + The number. + :param suffix_shape: + The suffix shape. + + :return: shape: (batch_size, num_heads, num_relations, num_tails, ``*``) + A tensor in canonical shape. + """ + if num is None: + num = x.shape[0] + suffix_shape = upgrade_to_sequence(suffix_shape) + shape = [batch_size, 1, 1, 1] + dim = _normalize_dim(dim=dim) + shape[dim] = num + return x.view(*shape, *suffix_shape) + + +class RepresentationModule(nn.Module, ABC): + """A base class for obtaining representations for entities/relations.""" + + #: The shape of a single representation + shape: Sequence[int] + + #: The maximum admissible ID (excl.) + max_id: int + + def __init__(self, shape: Iterable[int], max_id: int): + super().__init__() + self.shape = tuple(shape) + self.max_id = max_id + + @abstractmethod + def forward(self, indices: Optional[torch.LongTensor] = None) -> torch.FloatTensor: + """Get representations for indices. + + :param indices: shape: (m,) + The indices, or None. If None, return all representations. + + :return: shape: (m, d) + The representations. + """ + raise NotImplementedError + + def get_in_canonical_shape( + self, + dim: Union[int, str], + indices: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + """Get representations in canonical shape. + + The canonical shape is given as + + (batch_size, d_1, d_2, d_3, ``*``) + + fulfilling the following properties: + + Let i = dim. If indices is None, the return shape is (1, d_1, d_2, d_3) with d_i = num_representations, + d_i = 1 else. If indices is not None, then batch_size = indices.shape[0], and d_i = 1 if + indices.ndimension() = 1 else d_i = indices.shape[1] + + The canonical shape is given by (batch_size, 1, ``*``) if indices is not None, where batch_size=len(indices), + or (1, num, ``*``) if indices is None with num equal to the total number of embeddings. + + + :param dim: + The dimension along which to expand for indices = None, or indices.ndimension() == 2. + :param indices: + The indices. Either None, in which care all embeddings are returned, or a 1 or 2 dimensional index tensor. + + :return: shape: (batch_size, d1, d2, d3, ``*self.shape``) + """ + r_shape: Tuple[int, ...] 
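
For the 2-D index case handled just below, a plain-torch sketch of how a (batch_size, num_negatives) index tensor is looked up and viewed into the tail slot of the canonical shape (illustrative only, not the pykeen implementation):

```python
import torch

num_entities, d = 20, 4
weights = torch.rand(num_entities, d)

indices = torch.randint(num_entities, size=(3, 5))     # (batch_size, num_negatives)
x = weights[indices.view(-1)].view(*indices.shape, d)  # (3, 5, d)
canonical = x.view(3, 1, 1, 5, d)                      # tail slot is dimension 3
assert canonical.shape == (3, 1, 1, 5, d)
```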
+ if indices is None: + x = self(indices=indices) + r_shape = (1, self.max_id) + else: + flat_indices = indices.view(-1) + x = self(indices=flat_indices) + if indices.ndimension() > 1: + x = x.view(*indices.shape, -1) + r_shape = tuple(indices.shape) + if len(r_shape) < 2: + r_shape = r_shape + (1,) + return convert_to_canonical_shape(x=x, dim=dim, num=r_shape[1], batch_size=r_shape[0], suffix_shape=self.shape) + + def reset_parameters(self) -> None: + """Reset the module's parameters.""" + + def post_parameter_update(self): + """Apply constraints which should not be included in gradients.""" + + +@dataclasses.dataclass +class EmbeddingSpecification: + """An embedding specification.""" + + embedding_dim: Optional[int] = None + shape: Optional[Sequence[int]] = None + + dtype: Optional[torch.dtype] = None + + initializer: Optional[Initializer] = None + initializer_kwargs: Optional[Mapping[str, Any]] = None + + normalizer: Optional[Normalizer] = None + normalizer_kwargs: Optional[Mapping[str, Any]] = None + + constrainer: Optional[Constrainer] = None + constrainer_kwargs: Optional[Mapping[str, Any]] = None + + regularizer: Optional[Regularizer] = None + + def make( + self, + num_embeddings: int, + ) -> 'Embedding': + """Create an embedding with this specification.""" + return Embedding( + num_embeddings=num_embeddings, + embedding_dim=self.embedding_dim, + shape=self.shape, + dtype=self.dtype, + initializer=self.initializer, + initializer_kwargs=self.initializer_kwargs, + normalizer=self.normalizer, + normalizer_kwargs=self.normalizer_kwargs, + constrainer=self.constrainer, + constrainer_kwargs=self.constrainer_kwargs, + regularizer=self.regularizer, + ) + + +class Embedding(RepresentationModule): + """Trainable embeddings. + + This class provides the same interface as :class:`torch.nn.Embedding` and + can be used throughout PyKEEN as a more fully featured drop-in replacement. + """ + + normalizer: Optional[Normalizer] + constrainer: Optional[Constrainer] + regularizer: Optional[Regularizer] + + def __init__( + self, + num_embeddings: int, + embedding_dim: Optional[int] = None, + shape: Union[None, int, Sequence[int]] = None, + initializer: Optional[Initializer] = None, + initializer_kwargs: Optional[Mapping[str, Any]] = None, + normalizer: Optional[Normalizer] = None, + normalizer_kwargs: Optional[Mapping[str, Any]] = None, + constrainer: Optional[Constrainer] = None, + constrainer_kwargs: Optional[Mapping[str, Any]] = None, + regularizer: Optional[Regularizer] = None, + dtype: Optional[torch.dtype] = None, + ): + """Instantiate an embedding with extended functionality. + + :param num_embeddings: >0 + The number of embeddings. + :param embedding_dim: >0 + The embedding dimensionality. + :param shape: + The embedding shape. If given, shape supersedes embedding_dim, with setting embedding_dim = prod(shape). + :param initializer: + An optional initializer, which takes an uninitialized (num_embeddings, embedding_dim) tensor as input, + and returns an initialized tensor of same shape and dtype (which may be the same, i.e. the + initialization may be in-place) + :param initializer_kwargs: + Additional keyword arguments passed to the initializer + :param normalizer: + A normalization function, which is applied in every forward pass. + :param normalizer_kwargs: + Additional keyword arguments passed to the normalizer + :param constrainer: + A function which is applied to the weights after each parameter update, without tracking gradients. 
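
A standalone illustration of the initializer/constrainer hooks described in this docstring, applied to a plain torch.nn.Embedding rather than the class defined here (the specific choice of normal initialization and L2 renormalization is an assumption for the example):

```python
import functools

import torch
from torch import nn
from torch.nn import functional

emb = nn.Embedding(num_embeddings=10, embedding_dim=4)

# initializer(+kwargs) and constrainer(+kwargs) bound via functools.partial
initializer = functools.partial(nn.init.normal_, std=0.1)
constrainer = functools.partial(functional.normalize, p=2, dim=-1)

initializer(emb.weight.data)
# ... a gradient step would happen here ...
emb.weight.data = constrainer(emb.weight.data)  # applied without tracking gradients
assert torch.allclose(emb.weight.norm(p=2, dim=-1), torch.ones(10), atol=1e-5)
```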
+ It may be used to enforce model constraints outside of gradient-based training. The function does not need + to be in-place, but the weight tensor is modified in-place. + :param constrainer_kwargs: + Additional keyword arguments passed to the constrainer + + :raises ValueError: + If neither shape nor embedding_dim are given. + """ + if shape is None and embedding_dim is None: + raise ValueError('Missing both, shape and embedding_dim') + elif shape is not None and embedding_dim is not None: + raise ValueError('Provided both, shape and embedding_dim') + elif shape is None and embedding_dim is not None: + shape = (embedding_dim,) + elif isinstance(shape, int) and embedding_dim is None: + embedding_dim = shape + shape = (shape,) + elif isinstance(shape, Sequence) and embedding_dim is None: + shape = tuple(shape) + embedding_dim = int(numpy.prod(shape)) + else: + raise TypeError(f'Invalid type for shape: ({type(shape)}) {shape}') + + assert isinstance(shape, tuple) + assert isinstance(embedding_dim, int) + + if dtype is None: + dtype = torch.get_default_dtype() + + # work-around until full complex support + # TODO: verify that this is our understanding of complex! + if dtype.is_complex: + shape = shape[:-1] + (2 * shape[-1],) + embedding_dim = embedding_dim * 2 + super().__init__(shape=shape, max_id=num_embeddings) + + if initializer is None: + initializer = nn.init.normal_ + + if initializer_kwargs: + initializer = functools.partial(initializer, **initializer_kwargs) + self.initializer = initializer + + if constrainer is not None and constrainer_kwargs: + constrainer = functools.partial(constrainer, **constrainer_kwargs) + self.constrainer = constrainer + + # TODO: Move regularizer and normalizer to RepresentationModule? + if normalizer is not None and normalizer_kwargs: + normalizer = functools.partial(normalizer, **normalizer_kwargs) + self.normalizer = normalizer + + self.regularizer = regularizer + + self._embeddings = torch.nn.Embedding( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + ) + + @classmethod + def from_specification( + cls, + num_embeddings: int, + specification: Optional[EmbeddingSpecification] = None, + ) -> 'Embedding': + """Create an embedding based on a specification. + + :param num_embeddings: >0 + The number of embeddings. + :param specification: + The specification. + :return: + An embedding object. + """ + if specification is None: + specification = EmbeddingSpecification() + return specification.make( + num_embeddings=num_embeddings, + ) + + @property + def num_embeddings(self) -> int: # noqa: D401 + """The total number of representations (i.e. 
the maximum ID)."""
+        return self.max_id
+
+    @property
+    def embedding_dim(self) -> int:  # noqa: D401
+        """The representation dimension."""
+        return self._embeddings.embedding_dim
+
+    def reset_parameters(self) -> None:  # noqa: D102
+        # initialize weights in-place
+        self._embeddings.weight.data = self.initializer(
+            self._embeddings.weight.data.view(self.num_embeddings, *self.shape),
+        ).view(self.num_embeddings, self.embedding_dim)
+
+    def post_parameter_update(self):  # noqa: D102
+        # apply constraints in-place
+        if self.constrainer is not None:
+            self._embeddings.weight.data = self.constrainer(self._embeddings.weight.data)
+
+    def forward(
+        self,
+        indices: Optional[torch.LongTensor] = None,
+    ) -> torch.FloatTensor:  # noqa: D102
+        if indices is None:
+            x = self._embeddings.weight
+        else:
+            x = self._embeddings(indices)
+        x = x.view(x.shape[0], *self.shape)
+        if self.normalizer is not None:
+            x = self.normalizer(x)
+        if self.regularizer is not None:
+            self.regularizer.update(x)
+        return x
+
+
+class LiteralRepresentations(Embedding):
+    """Literal representations."""
+
+    def __init__(
+        self,
+        numeric_literals: torch.FloatTensor,
+    ):
+        num_embeddings, embedding_dim = numeric_literals.shape
+        super().__init__(
+            num_embeddings=num_embeddings,
+            embedding_dim=embedding_dim,
+            initializer=lambda x: numeric_literals,  # initialize with the literals
+        )
+        # freeze
+        self._embeddings.requires_grad_(False)
+
+
+def inverse_indegree_edge_weights(source: torch.LongTensor, target: torch.LongTensor) -> torch.FloatTensor:
+    """Normalize messages by inverse in-degree.
+
+    :param source: shape: (num_edges,)
+        The source indices.
+    :param target: shape: (num_edges,)
+        The target indices.
+
+    :return: shape: (num_edges,)
+        The edge weights.
+    """
+    # Calculate in-degree, i.e. number of incoming edges
+    uniq, inv, cnt = torch.unique(target, return_counts=True, return_inverse=True)
+    return cnt[inv].float().reciprocal()
+
+
+def inverse_outdegree_edge_weights(source: torch.LongTensor, target: torch.LongTensor) -> torch.FloatTensor:
+    """Normalize messages by inverse out-degree.
+
+    :param source: shape: (num_edges,)
+        The source indices.
+    :param target: shape: (num_edges,)
+        The target indices.
+
+    :return: shape: (num_edges,)
+        The edge weights.
+    """
+    # Calculate out-degree, i.e. number of outgoing edges
+    uniq, inv, cnt = torch.unique(source, return_counts=True, return_inverse=True)
+    return cnt[inv].float().reciprocal()
+
+
+def symmetric_edge_weights(source: torch.LongTensor, target: torch.LongTensor) -> torch.FloatTensor:
+    """Normalize messages by product of inverse sqrt of in-degree and out-degree.
+
+    :param source: shape: (num_edges,)
+        The source indices.
+    :param target: shape: (num_edges,)
+        The target indices.
+
+    :return: shape: (num_edges,)
+        The edge weights.
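
A quick numeric check of the inverse in-degree weighting defined above: every edge receives weight 1 / (number of edges sharing its target). Plain torch, independent of the code in this diff.

```python
import torch

target = torch.tensor([0, 0, 1, 2, 2, 2])
uniq, inv, cnt = torch.unique(target, return_inverse=True, return_counts=True)
weights = cnt[inv].float().reciprocal()
assert torch.allclose(weights, torch.tensor([1 / 2, 1 / 2, 1.0, 1 / 3, 1 / 3, 1 / 3]))
```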
+ """ + return ( + inverse_indegree_edge_weights(source=source, target=target) + * inverse_outdegree_edge_weights(source=source, target=target) + ).sqrt() + + +class RGCNRepresentations(RepresentationModule): + """Representations enriched by R-GCN.""" + + def __init__( + self, + triples_factory: TriplesFactory, + embedding_dim: int = 500, + num_bases_or_blocks: int = 5, + num_layers: int = 2, + use_bias: bool = True, + use_batch_norm: bool = False, + activation_cls: Optional[Type[nn.Module]] = None, + activation_kwargs: Optional[Mapping[str, Any]] = None, + sparse_messages_slcwa: bool = True, + edge_dropout: float = 0.4, + self_loop_dropout: float = 0.2, + edge_weighting: Callable[ + [torch.LongTensor, torch.LongTensor], + torch.FloatTensor, + ] = inverse_indegree_edge_weights, + decomposition: str = 'basis', + buffer_messages: bool = True, + base_representations: Optional[RepresentationModule] = None, + ): + super().__init__(shape=(embedding_dim,), max_id=triples_factory.num_entities) + + self.triples_factory = triples_factory + + # normalize representations + if base_representations is None: + base_representations = Embedding( + num_embeddings=triples_factory.num_entities, + embedding_dim=embedding_dim, + # https://github.com/MichSchli/RelationPrediction/blob/c77b094fe5c17685ed138dae9ae49b304e0d8d89/code/encoders/affine_transform.py#L24-L28 + initializer=nn.init.xavier_uniform_, + ) + self.base_embeddings = base_representations + self.embedding_dim = embedding_dim + + # check decomposition + self.decomposition = decomposition + if self.decomposition == 'basis': + if num_bases_or_blocks is None: + logging.info('Using a heuristic to determine the number of bases.') + num_bases_or_blocks = triples_factory.num_relations // 2 + 1 + if num_bases_or_blocks > triples_factory.num_relations: + raise ValueError('The number of bases should not exceed the number of relations.') + elif self.decomposition == 'block': + if num_bases_or_blocks is None: + logging.info('Using a heuristic to determine the number of blocks.') + num_bases_or_blocks = 2 + if embedding_dim % num_bases_or_blocks != 0: + raise ValueError( + 'With block decomposition, the embedding dimension has to be divisible by the number of' + f' blocks, but {embedding_dim} % {num_bases_or_blocks} != 0.', + ) + else: + raise ValueError(f'Unknown decomposition: "{decomposition}". 
Please use either "basis" or "block".') + + self.num_bases = num_bases_or_blocks + self.edge_weighting = edge_weighting + self.edge_dropout = edge_dropout + if self_loop_dropout is None: + self_loop_dropout = edge_dropout + self.self_loop_dropout = self_loop_dropout + self.use_batch_norm = use_batch_norm + if activation_cls is None: + activation_cls = nn.ReLU + self.activation_cls = activation_cls + self.activation_kwargs = activation_kwargs + if use_batch_norm: + if use_bias: + logger.warning('Disabling bias because batch normalization was used.') + use_bias = False + self.use_bias = use_bias + self.num_layers = num_layers + self.sparse_messages_slcwa = sparse_messages_slcwa + + # Save graph using buffers, such that the tensors are moved together with the model + h, r, t = self.triples_factory.mapped_triples.t() + self.register_buffer('sources', h) + self.register_buffer('targets', t) + self.register_buffer('edge_types', r) + + self.activations = nn.ModuleList([ + self.activation_cls(**(self.activation_kwargs or {})) for _ in range(self.num_layers) + ]) + + # Weights + self.bases = nn.ParameterList() + if self.decomposition == 'basis': + self.att = nn.ParameterList() + for _ in range(self.num_layers): + self.bases.append(nn.Parameter( + data=torch.empty( + self.num_bases, + self.embedding_dim, + self.embedding_dim, + ), + requires_grad=True, + )) + self.att.append(nn.Parameter( + data=torch.empty( + self.triples_factory.num_relations + 1, + self.num_bases, + ), + requires_grad=True, + )) + elif self.decomposition == 'block': + block_size = self.embedding_dim // self.num_bases + for _ in range(self.num_layers): + self.bases.append(nn.Parameter( + data=torch.empty( + self.triples_factory.num_relations + 1, + self.num_bases, + block_size, + block_size, + ), + requires_grad=True, + )) + + self.att = None + else: + raise NotImplementedError + if self.use_bias: + self.biases = nn.ParameterList([ + nn.Parameter(torch.empty(self.embedding_dim), requires_grad=True) + for _ in range(self.num_layers) + ]) + else: + self.biases = None + if self.use_batch_norm: + self.batch_norms = nn.ModuleList([ + nn.BatchNorm1d(num_features=self.embedding_dim) + for _ in range(self.num_layers) + ]) + else: + self.batch_norms = None + + # buffering of messages + self.buffer_messages = buffer_messages + self.enriched_embeddings = None + + def _get_relation_weights(self, i_layer: int, r: int) -> torch.FloatTensor: + if self.decomposition == 'block': + # allocate weight + w = torch.zeros(self.embedding_dim, self.embedding_dim, device=self.bases[i_layer].device) + + # Get blocks + this_layer_blocks = self.bases[i_layer] + + # self.bases[i_layer].shape (num_relations, num_blocks, embedding_dim/num_blocks, embedding_dim/num_blocks) + # note: embedding_dim is guaranteed to be divisible by num_bases in the constructor + block_size = self.embedding_dim // self.num_bases + for b, start in enumerate(range(0, self.embedding_dim, block_size)): + stop = start + block_size + w[start:stop, start:stop] = this_layer_blocks[r, b, :, :] + + elif self.decomposition == 'basis': + # The current basis weights, shape: (num_bases) + att = self.att[i_layer][r, :] + # the current bases, shape: (num_bases, embedding_dim, embedding_dim) + b = self.bases[i_layer] + # compute the current relation weights, shape: (embedding_dim, embedding_dim) + w = torch.sum(att[:, None, None] * b, dim=0) + + else: + raise AssertionError(f'Unknown decomposition: {self.decomposition}') + + return w + + def forward( + self, + indices: Optional[torch.LongTensor] = 
None, + ) -> torch.FloatTensor: # noqa:D102 + # use buffered messages if applicable + if indices is None and self.enriched_embeddings is not None: + return self.enriched_embeddings + if indices is not None and indices.ndimension() > 1: + raise RuntimeError("indices must be None, or 1-dimensional.") + + # Bind fields + # shape: (num_entities, embedding_dim) + x = self.base_embeddings(indices=None) + sources = self.sources + targets = self.targets + edge_types = self.edge_types + + # Edge dropout: drop the same edges on all layers (only in training mode) + if self.training and self.edge_dropout is not None: + # Get random dropout mask + edge_keep_mask = torch.rand(self.sources.shape[0], device=x.device) > self.edge_dropout + + # Apply to edges + sources = sources[edge_keep_mask] + targets = targets[edge_keep_mask] + edge_types = edge_types[edge_keep_mask] + + # Different dropout for self-loops (only in training mode) + if self.training and self.self_loop_dropout is not None: + node_keep_mask = torch.rand(self.triples_factory.num_entities, device=x.device) > self.self_loop_dropout + else: + node_keep_mask = None + + for i in range(self.num_layers): + # Initialize embeddings in the next layer for all nodes + new_x = torch.zeros_like(x) + + # TODO: Can we vectorize this loop? + for r in range(self.triples_factory.num_relations): + # Choose the edges which are of the specific relation + mask = (edge_types == r) + + # No edges available? Skip rest of inner loop + if not mask.any(): + continue + + # Get source and target node indices + sources_r = sources[mask] + targets_r = targets[mask] + + # send messages in both directions + sources_r, targets_r = torch.cat([sources_r, targets_r]), torch.cat([targets_r, sources_r]) + + # Select source node embeddings + x_s = x[sources_r] + + # get relation weights + w = self._get_relation_weights(i_layer=i, r=r) + + # Compute message (b x d) * (d x d) = (b x d) + m_r = x_s @ w + + # Normalize messages by relation-specific in-degree + if self.edge_weighting is not None: + m_r *= self.edge_weighting(sources_r, targets_r).unsqueeze(dim=-1) + + # Aggregate messages in target + new_x.index_add_(dim=0, index=targets_r, source=m_r) + + # Self-loop + self_w = self._get_relation_weights(i_layer=i, r=self.triples_factory.num_relations) + if node_keep_mask is None: + new_x += x @ self_w + else: + new_x[node_keep_mask] += x[node_keep_mask] @ self_w + + # Apply bias, if requested + if self.use_bias: + bias = self.biases[i] + new_x += bias + + # Apply batch normalization, if requested + if self.use_batch_norm: + batch_norm = self.batch_norms[i] + new_x = batch_norm(new_x) + + # Apply non-linearity + if self.activations is not None: + activation = self.activations[i] + new_x = activation(new_x) + + x = new_x + + if indices is None and self.buffer_messages: + self.enriched_embeddings = x + if indices is not None: + x = x[indices] + + return x + + def post_parameter_update(self) -> None: # noqa: D102 + super().post_parameter_update() + + # invalidate enriched embeddings + self.enriched_embeddings = None + + def reset_parameters(self): # noqa:D102 + self.base_embeddings.reset_parameters() + + gain = nn.init.calculate_gain(nonlinearity=self.activation_cls.__name__.lower()) + if self.decomposition == 'basis': + for base in self.bases: + nn.init.xavier_normal_(base, gain=gain) + for att in self.att: + # Random convex-combination of bases for initialization (guarantees that initial weight matrices are + # initialized properly) + # We have one additional relation for self-loops + 
nn.init.uniform_(att) + functional.normalize(att.data, p=1, dim=1, out=att.data) + elif self.decomposition == 'block': + for base in self.bases: + block_size = base.shape[-1] + # Xavier Glorot initialization of each block + std = torch.sqrt(torch.as_tensor(2.)) * gain / (2 * block_size) + nn.init.normal_(base, std=std) + + # Reset biases + if self.biases is not None: + for bias in self.biases: + nn.init.zeros_(bias) + + # Reset batch norm parameters + if self.batch_norms is not None: + for bn in self.batch_norms: + bn.reset_parameters() + + # Reset activation parameters, if any + for act in self.activations: + if hasattr(act, 'reset_parameters'): + act.reset_parameters() diff --git a/src/pykeen/nn/sim.py b/src/pykeen/nn/sim.py new file mode 100644 index 0000000000..949d759283 --- /dev/null +++ b/src/pykeen/nn/sim.py @@ -0,0 +1,262 @@ +# -*- coding: utf-8 -*- + +"""Similarity functions.""" + +import itertools +import math + +import torch + +from .compute_kernel import batched_dot +from ..typing import GaussianDistribution +from ..utils import calculate_broadcasted_elementwise_result_shape, tensor_sum + +__all__ = [ + 'expected_likelihood', + 'kullback_leibler_similarity', + 'KG2E_SIMILARITIES', +] + + +def expected_likelihood( + h: GaussianDistribution, + r: GaussianDistribution, + t: GaussianDistribution, + epsilon: float = 1.0e-10, + exact: bool = True, +) -> torch.FloatTensor: + r"""Compute the similarity based on expected likelihood. + + .. math:: + + D((\mu_e, \Sigma_e), (\mu_r, \Sigma_r))) + = \frac{1}{2} \left( + (\mu_e - \mu_r)^T(\Sigma_e + \Sigma_r)^{-1}(\mu_e - \mu_r) + + \log \det (\Sigma_e + \Sigma_r) + d \log (2 \pi) + \right) + = \frac{1}{2} \left( + \mu^T\Sigma^{-1}\mu + + \log \det \Sigma + d \log (2 \pi) + \right) + + with :math:`\mu_e = \mu_h - \mu_t` and :math:`\Sigma_e = \Sigma_h + \Sigma_t`. + + :param h: shape: (batch_size, num_heads, 1, 1, d) + The head entity Gaussian distribution. + :param r: shape: (batch_size, 1, num_relations, 1, d) + The relation Gaussian distribution. + :param t: shape: (batch_size, 1, 1, num_tails, d) + The tail entity Gaussian distribution. + :param epsilon: float (default=1.0) + Small constant used to avoid numerical issues when dividing. + :param exact: + Whether to return the exact similarity, or leave out constant offsets. + + :return: torch.Tensor, shape: (batch_size, num_heads, num_relations, num_tails) + The similarity. + """ + # subtract, shape: (batch_size, num_heads, num_relations, num_tails, dim) + var = tensor_sum(*(d.diagonal_covariance for d in (h, r, t))) + mean = tensor_sum(h.mean, -t.mean, -r.mean) + + #: a = \mu^T\Sigma^{-1}\mu + safe_sigma = torch.clamp_min(var, min=epsilon) + sim = batched_dot( + a=safe_sigma.reciprocal(), + b=(mean ** 2), + ) + + #: b = \log \det \Sigma + sim = sim + safe_sigma.log().sum(dim=-1) + if exact: + sim = sim + sim.shape[-1] * math.log(2. * math.pi) + return sim + + +def kullback_leibler_similarity( + h: GaussianDistribution, + r: GaussianDistribution, + t: GaussianDistribution, + epsilon: float = 1.0e-10, + exact: bool = True, +) -> torch.FloatTensor: + r"""Compute the negative KL divergence. + + This is done between two Gaussian distributions given by mean `mu_*` and diagonal covariance matrix `sigma_*`. + + .. math:: + + D((\mu_0, \Sigma_0), (\mu_1, \Sigma_1)) = 0.5 * ( + tr(\Sigma_1^-1 \Sigma_0) + + (\mu_1 - \mu_0) * \Sigma_1^-1 (\mu_1 - \mu_0) + - k + + ln (det(\Sigma_1) / det(\Sigma_0)) + ) + + with :math:`\mu_e = \mu_h - \mu_t` and :math:`\Sigma_e = \Sigma_h + \Sigma_t`. + + .. 
note :: + This method assumes diagonal covariance matrices :math:`\Sigma`. + + .. seealso :: + https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence + + :param h: shape: (batch_size, num_heads, 1, 1, d) + The head entity Gaussian distribution. + :param r: shape: (batch_size, 1, num_relations, 1, d) + The relation Gaussian distribution. + :param t: shape: (batch_size, 1, 1, num_tails, d) + The tail entity Gaussian distribution. + :param epsilon: float (default=1.0e-10) + Small constant used to avoid numerical issues when dividing. + :param exact: + Whether to return the exact similarity, or leave out constant offsets. + + :return: torch.Tensor, shape: (batch_size, num_heads, num_relations, num_tails) + The similarity. + """ + assert all((d.diagonal_covariance > 0).all() for d in (h, r, t)) + return -_vectorized_kl_divergence( + h=h, + r=r, + t=t, + epsilon=epsilon, + exact=exact, + ) + + +def _vectorized_kl_divergence( + h: GaussianDistribution, + r: GaussianDistribution, + t: GaussianDistribution, + epsilon: float = 1.0e-10, + exact: bool = True, +) -> torch.FloatTensor: + r""" + Vectorized implementation of KL-divergence. + + Computes the divergence between :math:`\mathcal{N}(\mu_e, \Sigma_e)` and :math:`\mathcal{N}(\mu_r, \Sigma_r)` + given by + + .. math :: + \mu_e = \mu_h - \mu_t + + \Sigma_e = \Sigma_h + \Sigma_t + + where all covariance matrices are diagonal. Hence we can simplify + + .. math :: + D(\mathcal{N}(\mu_e, \Sigma_e), \mathcal{N}(\mu_r, \Sigma_r)) + = + \frac{1}{2} \left( + \operatorname{tr}(\Sigma_r^{-1} \Sigma_e) + + (\mu_r - \mu_e)^T \Sigma_r^{-1} (\mu_r - \mu_e) + - k + + \ln \frac{\det(\Sigma_r)}{\det(\Sigma_e)} + \right) + = \frac{1}{2} \left( + \sum_i \Sigma_e[i] / \Sigma_r[i] + + \sum_i \mu[i]^2 / \Sigma_r[i] + + \sum_i \ln \Sigma_r[i] + - \sum_i \ln \Sigma_e[i] + - k + \right) + + where :math:`\mu = \mu_r - \mu_e = \mu_r - \mu_h + \mu_t` + + :param h: shape: (batch_size, num_heads, 1, 1, d) + The head entity Gaussian distribution. + :param r: shape: (batch_size, 1, num_relations, 1, d) + The relation Gaussian distribution. + :param t: shape: (batch_size, 1, 1, num_tails, d) + The tail entity Gaussian distribution. + :param epsilon: float (default=1.0e-10) + Small constant used to avoid numerical issues when dividing. + :param exact: + Whether to return the exact similarity, or leave out constant offsets. + + :return: torch.Tensor, shape: (batch_size, num_heads, num_relations, num_tails) + The KL-divergence. + """ + e_var = (h.diagonal_covariance + t.diagonal_covariance) + r_var_safe = r.diagonal_covariance.clamp_min(min=epsilon) + terms = [] + # 1. Component + # \sum_i \Sigma_e[i] / \Sigma_r[i] + r_var_safe_reciprocal = r_var_safe.reciprocal() + terms.append(batched_dot(e_var, r_var_safe_reciprocal)) + # 2. Component + # (mu_1 - mu_0) * Sigma_1^-1 (mu_1 - mu_0) + # with mu = (mu_1 - mu_0) + # = mu * Sigma_1^-1 mu + # since Sigma_1 is diagonal + # = mu**2 / sigma_1 + mu = tensor_sum(r.mean, -h.mean, t.mean) + terms.append(batched_dot(mu.pow(2), r_var_safe_reciprocal)) + # 3. Component + if exact: + terms.append(-torch.as_tensor(data=[h.mean.shape[-1]], device=mu.device)) + # 4.
Component + # ln (det(\Sigma_1) / det(\Sigma_0)) + # = ln det Sigma_1 - ln det Sigma_0 + # since Sigma is diagonal, we have det Sigma = prod Sigma[ii] + # = ln prod Sigma_1[ii] - ln prod Sigma_0[ii] + # = sum ln Sigma_1[ii] - sum ln Sigma_0[ii] + e_var_safe = e_var.clamp_min(min=epsilon) + terms.extend(( + r_var_safe.log().sum(dim=-1), + -e_var_safe.log().sum(dim=-1), + )) + result = tensor_sum(*terms) + if exact: + result = 0.5 * result + return result + + +def _torch_kl_similarity( + h: GaussianDistribution, + r: GaussianDistribution, + t: GaussianDistribution, +) -> torch.FloatTensor: + """ + Compute KL similarity using torch.distributions. + + .. note :: + Do not use this method in production code. + """ + e_mean = h.mean - t.mean + e_var = h.diagonal_covariance + t.diagonal_covariance + + # allocate result + batch_size, num_heads, num_relations, num_tails = calculate_broadcasted_elementwise_result_shape( + e_mean.shape, + r.mean.shape, + )[:-1] + result = h.mean.new_empty(batch_size, num_heads, num_relations, num_tails) + for bi, hi, ri, ti in itertools.product( + range(batch_size), + range(num_heads), + range(num_relations), + range(num_tails), + ): + # prepare distributions + e_loc = e_mean[bi, hi, 0, ti, :] + r_loc = r.mean[bi, 0, ri, 0, :] + e_cov = torch.diag(e_var[bi, hi, 0, ti, :]) + r_cov = torch.diag(r.diagonal_covariance[bi, 0, ri, 0, :]) + p = torch.distributions.MultivariateNormal( + loc=e_loc, + covariance_matrix=e_cov, + ) + q = torch.distributions.MultivariateNormal( + loc=r_loc, + covariance_matrix=r_cov, + ) + result[bi, hi, ri, ti] = torch.distributions.kl_divergence(p=p, q=q).view(-1) + return -result + + +KG2E_SIMILARITIES = { + 'KL': kullback_leibler_similarity, + 'EL': expected_likelihood, +} diff --git a/src/pykeen/pipeline.py b/src/pykeen/pipeline.py index 5255eb89d4..720ad381fe 100644 --- a/src/pykeen/pipeline.py +++ b/src/pykeen/pipeline.py @@ -905,7 +905,7 @@ def pipeline( # noqa: C901 del model_kwargs['regularizer'] regularizer_cls: Type[Regularizer] = get_regularizer_cls(regularizer) model_kwargs['regularizer'] = regularizer_cls( - device=device, + # device=device, **(regularizer_kwargs or {}), ) diff --git a/src/pykeen/regularizers.py b/src/pykeen/regularizers.py index c70f8f9c02..af41e7b51f 100644 --- a/src/pykeen/regularizers.py +++ b/src/pykeen/regularizers.py @@ -3,7 +3,7 @@ """Regularization in PyKEEN.""" from abc import ABC, abstractmethod -from typing import Any, ClassVar, Collection, Iterable, Mapping, Optional, Type, Union +from typing import Any, ClassVar, Collection, Iterable, List, Mapping, Optional, Sequence, Type, Union import torch from torch import nn @@ -18,6 +18,7 @@ 'CombinedRegularizer', 'PowerSumRegularizer', 'TransHRegularizer', + 'collect_regularization_terms', 'get_regularizer_cls', ] @@ -31,7 +32,10 @@ class Regularizer(nn.Module, ABC): weight: torch.FloatTensor #: The current regularization term (a scalar) - regularization_term: torch.FloatTensor + regularization_term: Union[torch.FloatTensor, float] + + #: Has this regularizer been updated? + updated: bool #: Should the regularization only be applied once? This was used for ConvKB and defaults to False. 
apply_only_once: bool @@ -39,50 +43,54 @@ class Regularizer(nn.Module, ABC): #: The default strategy for optimizing the regularizer's hyper-parameters hpo_default: ClassVar[Mapping[str, Any]] + #: weights which should be regularized + tracked_parameters: List[nn.Parameter] + def __init__( self, - device: torch.device, weight: float = 1.0, apply_only_once: bool = False, + parameters: Optional[Iterable[nn.Parameter]] = None, ): super().__init__() - self.device = device - self.register_buffer(name='weight', tensor=torch.as_tensor(weight, device=self.device)) + self.register_buffer(name='weight', tensor=torch.as_tensor(weight)) self.apply_only_once = apply_only_once - self.reset() + self.tracked_parameters = list(parameters) if parameters else [] + self._clear() - def to(self, *args, **kwargs) -> 'Regularizer': # noqa: D102 - super().to(*args, **kwargs) - self.device = torch._C._nn._parse_to(*args, **kwargs)[0] - self.reset() - return self + def _clear(self): + self.regularization_term = 0. + self.updated = False + + def add_parameter(self, parameter: nn.Parameter) -> None: + """Add a parameter for regularization.""" + self.tracked_parameters.append(parameter) @classmethod def get_normalized_name(cls) -> str: """Get the normalized name of the regularizer class.""" return normalize_string(cls.__name__, suffix=_REGULARIZER_SUFFIX) - def reset(self) -> None: - """Reset the regularization term to zero.""" - self.regularization_term = torch.zeros(1, dtype=torch.float, device=self.device) - self.updated = False - @abstractmethod def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: """Compute the regularization term for one tensor.""" raise NotImplementedError - def update(self, *tensors: torch.FloatTensor) -> None: + def update(self, *tensors: torch.FloatTensor) -> bool: """Update the regularization term based on passed tensors.""" - if self.apply_only_once and self.updated: - return - self.regularization_term = self.regularization_term + sum(self.forward(x=x) for x in tensors) + if not self.training or not torch.is_grad_enabled() or (self.apply_only_once and self.updated): + return False + self.regularization_term = self.regularization_term + sum(self(x) for x in tensors) self.updated = True + return True - @property - def term(self) -> torch.FloatTensor: - """Return the weighted regularization term.""" - return self.regularization_term * self.weight + def pop_regularization_term(self) -> torch.FloatTensor: + """Return the weighted regularization term, and clear it afterwards.""" + if len(self.tracked_parameters) > 0: + self.update(*self.tracked_parameters) + term = self.regularization_term + self._clear() + return self.weight * term class NoRegularizer(Regularizer): @@ -91,16 +99,14 @@ class NoRegularizer(Regularizer): Used to simplify code. """ - #: The default strategy for optimizing the regularizer's hyper-parameters - hpo_default: ClassVar[Mapping[str, Any]] = {} + # TODO(cthoyt): Deprecated, but used in notebooks / README - def update(self, *tensors: torch.FloatTensor) -> None: # noqa: D102 - # no need to compute anything - pass + #: The default strategy for optimizing the no-op regularizer's hyper-parameters + hpo_default: ClassVar[Mapping[str, Any]] = {} def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: # noqa: D102 # always return zero - return torch.zeros(1, dtype=x.dtype, device=x.device) + return x.new_zeros(1) class LpRegularizer(Regularizer): @@ -113,21 +119,21 @@ class LpRegularizer(Regularizer): #: This allows dimensionality-independent weight tuning. 
normalize: bool - #: The default strategy for optimizing the regularizer's hyper-parameters + #: The default strategy for optimizing the LP regularizer's hyper-parameters hpo_default: ClassVar[Mapping[str, Any]] = dict( weight=dict(type=float, low=0.01, high=1.0, scale='log'), ) def __init__( self, - device: torch.device, weight: float = 1.0, dim: Optional[int] = -1, normalize: bool = False, p: float = 2., apply_only_once: bool = False, + parameters: Optional[Sequence[nn.Parameter]] = None, ): - super().__init__(device=device, weight=weight, apply_only_once=apply_only_once) + super().__init__(weight=weight, apply_only_once=apply_only_once, parameters=parameters) self.dim = dim self.normalize = normalize self.p = p @@ -153,21 +159,21 @@ class PowerSumRegularizer(Regularizer): Has some nice properties, cf. e.g. https://github.com/pytorch/pytorch/issues/28119. """ - #: The default strategy for optimizing the regularizer's hyper-parameters + #: The default strategy for optimizing the power sum regularizer's hyper-parameters hpo_default: ClassVar[Mapping[str, Any]] = dict( weight=dict(type=float, low=0.01, high=1.0, scale='log'), ) def __init__( self, - device: torch.device, weight: float = 1.0, dim: Optional[int] = -1, normalize: bool = False, p: float = 2., apply_only_once: bool = False, + parameters: Optional[Sequence[nn.Parameter]] = None, ): - super().__init__(device=device, weight=weight, apply_only_once=apply_only_once) + super().__init__(weight=weight, apply_only_once=apply_only_once, parameters=parameters) self.dim = dim self.normalize = normalize self.p = p @@ -183,41 +189,40 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: # noqa: D102 class TransHRegularizer(Regularizer): """A regularizer for the soft constraints in TransH.""" - #: The default strategy for optimizing the regularizer's hyper-parameters + #: The default strategy for optimizing the TransH regularizer's hyper-parameters hpo_default: ClassVar[Mapping[str, Any]] = dict( weight=dict(type=float, low=0.01, high=1.0, scale='log'), ) def __init__( self, - device: torch.device, + entity_embeddings: nn.Parameter, + normal_vector_embeddings: nn.Parameter, + relation_embeddings: nn.Parameter, weight: float = 0.05, epsilon: float = 1e-5, ): # The regularization in TransH enforces the defined soft constraints that should be computed only once per batch. # Therefore, apply_only_once is always set to True.
- super().__init__(device=device, weight=weight, apply_only_once=True) + super().__init__(weight=weight, apply_only_once=True, parameters=[]) + self.normal_vector_embeddings = normal_vector_embeddings + self.relation_embeddings = relation_embeddings + self.entity_embeddings = entity_embeddings self.epsilon = epsilon def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: # noqa: D102 raise NotImplementedError('TransH regularizer is order-sensitive!') - def update(self, *tensors: torch.FloatTensor) -> None: # noqa: D102 - if len(tensors) != 3: - raise KeyError('Expects exactly three tensors') - if self.apply_only_once and self.updated: - return - entity_embeddings, normal_vector_embeddings, relation_embeddings = tensors + def pop_regularization_term(self, *tensors: torch.FloatTensor) -> torch.FloatTensor: # noqa: D102 # Entity soft constraint - self.regularization_term += torch.sum(functional.relu(torch.norm(entity_embeddings, dim=-1) ** 2 - 1.0)) + self.regularization_term += torch.sum(functional.relu(torch.norm(self.entity_embeddings, dim=-1) ** 2 - 1.0)) # Orthogonality soft constraint - d_r_n = functional.normalize(relation_embeddings, dim=-1) + d_r_n = functional.normalize(self.relation_embeddings, dim=-1) self.regularization_term += torch.sum( - functional.relu(torch.sum((normal_vector_embeddings * d_r_n) ** 2, dim=-1) - self.epsilon), + functional.relu(torch.sum((self.normal_vector_embeddings * d_r_n) ** 2, dim=-1) - self.epsilon), ) - - self.updated = True + return super().pop_regularization_term() class CombinedRegularizer(Regularizer): @@ -229,18 +234,17 @@ class CombinedRegularizer(Regularizer): def __init__( self, regularizers: Iterable[Regularizer], - device: torch.device, total_weight: float = 1.0, apply_only_once: bool = False, + parameters: Optional[Sequence[nn.Parameter]] = None, ): - super().__init__(weight=total_weight, device=device, apply_only_once=apply_only_once) + super().__init__(weight=total_weight, apply_only_once=apply_only_once, parameters=parameters) self.regularizers = nn.ModuleList(regularizers) for r in self.regularizers: - if isinstance(r, NoRegularizer): + if r is None or isinstance(r, NoRegularizer): raise TypeError('Can not combine a no-op regularizer') - self.register_buffer(name='normalization_factor', tensor=torch.as_tensor( - sum(r.weight for r in self.regularizers), device=device, - ).reciprocal()) + normalization_factor = torch.as_tensor(sum(r.weight for r in self.regularizers)).reciprocal() + self.register_buffer(name='normalization_factor', tensor=normalization_factor) @property def normalize(self): # noqa: D102 @@ -274,3 +278,12 @@ def get_regularizer_cls(query: Union[None, str, Type[Regularizer]]) -> Type[Regu default=NoRegularizer, suffix=_REGULARIZER_SUFFIX, ) + + +def collect_regularization_terms(main_module: nn.Module) -> Union[float, torch.FloatTensor]: + """Recursively collect regularization terms from attached regularizers, and clear their accumulator.""" + return sum( + module.pop_regularization_term() + for module in main_module.modules() + if isinstance(module, Regularizer) + ) diff --git a/src/pykeen/sampling/__init__.py b/src/pykeen/sampling/__init__.py index b1b43a3802..a80d4116b9 100644 --- a/src/pykeen/sampling/__init__.py +++ b/src/pykeen/sampling/__init__.py @@ -54,7 +54,7 @@ def get_negative_sampler_cls(query: Union[None, str, Type[NegativeSampler]]) -> """Get the negative sampler class.""" return get_cls( query, - base=NegativeSampler, # type: ignore + base=NegativeSampler, lookup_dict=negative_samplers, 
default=BasicNegativeSampler, suffix=_NEGATIVE_SAMPLER_SUFFIX, diff --git a/src/pykeen/sampling/negative_sampler.py b/src/pykeen/sampling/negative_sampler.py index 5411517513..86f873ef95 100644 --- a/src/pykeen/sampling/negative_sampler.py +++ b/src/pykeen/sampling/negative_sampler.py @@ -82,11 +82,9 @@ def filter_negative_triples(self, negative_batch: torch.LongTensor) -> Tuple[tor try: # Check which heads of the mapped triples are also in the negative triples - head_filter = ( - self.mapped_triples[:, 0:1].view(1, -1) == negative_batch[:, 0:1] # type: ignore - ).max(axis=0)[0] + head_filter = (self.mapped_triples[:, 0:1].view(1, -1) == negative_batch[:, 0:1]).max(axis=0)[0] # Reduce the search space by only using possible matches that at least contain the head we look for - sub_mapped_triples = self.mapped_triples[head_filter] # type: ignore + sub_mapped_triples = self.mapped_triples[head_filter] # Check in this subspace which relations of the mapped triples are also in the negative triples relation_filter = (sub_mapped_triples[:, 1:2].view(1, -1) == negative_batch[:, 1:2]).max(axis=0)[0] # Reduce the search space by only using possible matches that at least contain head and relation we look for diff --git a/src/pykeen/testing/__init__.py b/src/pykeen/testing/__init__.py new file mode 100644 index 0000000000..0ac80fed9a --- /dev/null +++ b/src/pykeen/testing/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +"""Code for testing of KGEMs and PyKEEN.""" diff --git a/src/pykeen/testing/base.py b/src/pykeen/testing/base.py new file mode 100644 index 0000000000..8c037a2e61 --- /dev/null +++ b/src/pykeen/testing/base.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +"""Base classes for simplified testing.""" + +import unittest +from typing import Any, Collection, Generic, Mapping, MutableMapping, Optional, Type, TypeVar + +from ..utils import get_subclasses, set_random_seed + +__all__ = [ + 'GenericTestCase', + 'TestsTestCase', +] + +T = TypeVar("T") + + +class GenericTestCase(Generic[T]): + """Generic tests.""" + + cls: Type[T] + kwargs: Optional[Mapping[str, Any]] = None + instance: T + + def setUp(self) -> None: + """Set up the generic testing method.""" + # fix seeds for reproducibility + set_random_seed(seed=42) + kwargs = self.kwargs or {} + kwargs = self._pre_instantiation_hook(kwargs=dict(kwargs)) + self.instance = self.cls(**kwargs) + self.post_instantiation_hook() + + def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]: + """Perform actions before instantiation, potentially modifying kwargs.""" + return kwargs + + def post_instantiation_hook(self) -> None: + """Perform actions after instantiation.""" + + +class TestsTestCase(Generic[T], unittest.TestCase): + """A generic test for tests.""" + + base_cls: Type[T] + base_test: Type[GenericTestCase[T]] + skip_cls: Collection[T] = tuple() + + def test_testing(self): + """Check that there is a test for all subclasses.""" + to_test = set(get_subclasses(self.base_cls)).difference(self.skip_cls) + tested = (test_cls.cls for test_cls in get_subclasses(self.base_test) if hasattr(test_cls, "cls")) + not_tested = to_test.difference(tested) + assert not not_tested, not_tested diff --git a/src/pykeen/testing/mocks.py b/src/pykeen/testing/mocks.py new file mode 100644 index 0000000000..7a7f670d6f --- /dev/null +++ b/src/pykeen/testing/mocks.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- + +"""Mocks for testing PyKEEN.""" + +from typing import Optional, Sequence + +import numpy +import torch +from torch
import nn + +from pykeen.models.base import Model +from pykeen.nn import RepresentationModule +from pykeen.triples import TriplesFactory + +__all__ = [ + 'MockModel', + 'MockRepresentations', +] + + +class MockModel(Model): + """A mock model returning fake scores.""" + + def __init__(self, triples_factory: TriplesFactory): + super().__init__( + triples_factory=triples_factory, + ) + + def forward( + self, + h_indices: Optional[torch.LongTensor], + r_indices: Optional[torch.LongTensor], + t_indices: Optional[torch.LongTensor], + slice_size: Optional[int] = None, + slice_dim: Optional[str] = None, + ) -> torch.FloatTensor: # noqa: D102 + # (batch_size, num_heads, num_relations, num_tails) + scores = torch.zeros(1, 1, 1, 1, requires_grad=True) # for requires_grad + # reproducible scores + for i, (ind, num) in enumerate(( + (h_indices, self.num_entities), + (r_indices, self.num_relations), + (t_indices, self.num_entities), + )): + shape = [1, 1, 1, 1] + if ind is None: + shape[i + 1] = num + delta = torch.arange(num) + else: + shape[0] = len(ind) + delta = ind + scores = scores + delta.float().view(*shape) + return scores + + +class MockRepresentations(RepresentationModule): + """A custom representation module with minimal implementation.""" + + def __init__(self, num_entities: int, shape: Sequence[int]): + super().__init__(shape=shape, max_id=num_entities) + self.x = nn.Parameter(torch.rand(int(numpy.prod(self.shape)))) + + def forward(self, indices: Optional[torch.LongTensor] = None) -> torch.FloatTensor: # noqa: D102 + n = self.max_id if indices is None else indices.shape[0] + return self.x.unsqueeze(dim=0).repeat(n, 1) diff --git a/src/pykeen/training/lcwa.py b/src/pykeen/training/lcwa.py index 14cdda9b66..addd1faafa 100644 --- a/src/pykeen/training/lcwa.py +++ b/src/pykeen/training/lcwa.py @@ -43,19 +43,13 @@ def _process_batch( # Send batch to device batch_pairs = batch_pairs[start:stop].to(device=self.device) + predictions = self.model.score_t(hr_batch=batch_pairs, slice_size=slice_size) batch_labels_full = batch_labels_full[start:stop].to(device=self.device) - - if slice_size is None: - predictions = self.model.score_t(hr_batch=batch_pairs) - else: - predictions = self.model.score_t(hr_batch=batch_pairs, slice_size=slice_size) - - loss = self._loss_helper( + return self._loss_helper( predictions, batch_labels_full, label_smoothing, ) - return loss def _label_loss_helper( self, @@ -121,7 +115,6 @@ def _slice_size_search( sub_batch_size: int, supports_sub_batching: bool, ) -> int: # noqa: D102 - self._check_slicing_availability(supports_sub_batching) reached_max = False evaluated_once = False logger.info("Trying slicing now.") @@ -166,19 +159,3 @@ def _slice_size_search( evaluated_once = True return slice_size - - def _check_slicing_availability(self, supports_sub_batching: bool): - if self.model.can_slice_t: - return - elif supports_sub_batching: - report = ( - "This model supports sub-batching, but it also requires slicing," - " which is not implemented for this model yet." - ) - else: - report = ( - "This model doesn't support sub-batching and slicing is not" - " implemented for this model yet." 
- ) - logger.warning(report) - raise MemoryError("The current model can't be trained on this hardware with these parameters.") diff --git a/src/pykeen/training/training_loop.py b/src/pykeen/training/training_loop.py index 51bc74e359..9213b95ca0 100644 --- a/src/pykeen/training/training_loop.py +++ b/src/pykeen/training/training_loop.py @@ -231,7 +231,7 @@ def train( """ # Create training instances # During size probing the training instances should not show the tqdm progress bar - self.training_instances = self._create_instances(use_tqdm=not only_size_probing) + self.training_instances = self._create_instances(use_tqdm=use_tqdm and not only_size_probing) # In some cases, e.g. using Optuna for HPO, the cuda cache from a previous run is not cleared torch.cuda.empty_cache() @@ -348,6 +348,9 @@ def _train( # noqa: C901 :param slice_size: >0 The divisor for the scoring function when using slicing. This is only possible for LCWA training loops in general and only for models that have the slicing capability implemented. + :param automatic_memory_optimization: bool + Whether to automatically optimize the sub-batch size during training and batch size during evaluation with + regards to the hardware at hand. :param label_smoothing: (0 <= label_smoothing < 1) If larger than zero, use label smoothing. :param sampler: (None or 'schlichtkrull') @@ -626,9 +629,6 @@ def _forward_pass(self, batch, start, stop, current_batch_size, label_smoothing, loss.backward() current_epoch_loss = loss.item() - # reset the regularizer to free the computational graph - self.model.regularizer.reset() - return current_epoch_loss @staticmethod @@ -855,8 +855,6 @@ def to_embeddingdb(self, session=None, use_tqdm: bool = False): return self.model.to_embeddingdb(session=session, use_tqdm=use_tqdm) def _free_graph_and_cache(self): - # The regularizer has to be reset to free the computational graph - self.model.regularizer.reset() # The cache of the previous run has to be freed to allow accurate memory availability estimates torch.cuda.empty_cache() diff --git a/src/pykeen/triples/triples_factory.py b/src/pykeen/triples/triples_factory.py index 67dd68975b..dbb1a66682 100644 --- a/src/pykeen/triples/triples_factory.py +++ b/src/pykeen/triples/triples_factory.py @@ -416,7 +416,7 @@ def get_inverse_relation_id(self, relation: Union[str, int]) -> int: """Get the inverse relation identifier for the given relation.""" if not self.create_inverse_triples: raise ValueError('Can not get inverse triple, they have not been created.') - relation = next(iter(self.relations_to_ids(relations=[relation]))) # type: ignore + relation = next(iter(self.relations_to_ids(relations=[relation]))) # type:ignore return self._get_inverse_relation_id(relation) @staticmethod @@ -645,7 +645,7 @@ def _word_cloud(self, *, ids: torch.LongTensor, id_to_label: Mapping[int, str], def tensor_to_df( self, tensor: torch.LongTensor, - **kwargs: Union[torch.Tensor, np.ndarray, Sequence], + **kwargs: Union[torch.Tensor, np.ndarray, Sequence], # FIXME fix the type annotation ) -> pd.DataFrame: """Take a tensor of triples and make a pandas dataframe with labels. 
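The training loop above no longer resets the regularizer after every backward pass; with the refactored API, each `Regularizer` accumulates its own term and `collect_regularization_terms` pops the terms of all regularizers attached to a module. Below is a minimal, illustrative sketch of how this could be wired into a custom training step. It is not part of the patch, and `TinyModel` is a hypothetical toy module used only for the example.

```python
import torch
from torch import nn

from pykeen.regularizers import LpRegularizer, collect_regularization_terms


class TinyModel(nn.Module):
    """Hypothetical toy module, only to illustrate the refactored regularizer API."""

    def __init__(self, num_entities: int = 10, dim: int = 4):
        super().__init__()
        self.emb = nn.Embedding(num_entities, dim)
        # the regularizer now tracks its parameters itself; no device argument is needed
        self.regularizer = LpRegularizer(weight=0.01, p=2.0, parameters=[self.emb.weight])

    def forward(self, idx: torch.LongTensor) -> torch.FloatTensor:
        return self.emb(idx).sum(dim=-1)


model = TinyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

scores = model(torch.arange(5))
loss = scores.mean()
# pop_regularization_term() is called on every attached Regularizer and clears its
# accumulator, so no explicit regularizer.reset() is required after backward()
loss = loss + collect_regularization_terms(model)
loss.backward()
optimizer.step()
```

Because popping clears the accumulated term, each call to `collect_regularization_terms` yields a fresh regularization term for the current step.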
diff --git a/src/pykeen/triples/triples_numeric_literals_factory.py b/src/pykeen/triples/triples_numeric_literals_factory.py index 3fe72673b5..ec003c47b8 100644 --- a/src/pykeen/triples/triples_numeric_literals_factory.py +++ b/src/pykeen/triples/triples_numeric_literals_factory.py @@ -47,6 +47,9 @@ def create_matrix_of_literals( class TriplesNumericLiteralsFactory(TriplesFactory): """Create multi-modal instances given the path to triples.""" + numeric_literals: np.ndarray + literals_to_id: Dict[str, int] + def __init__( self, *, diff --git a/src/pykeen/typing.py b/src/pykeen/typing.py index 56ae0b9df9..e834dfacb1 100644 --- a/src/pykeen/typing.py +++ b/src/pykeen/typing.py @@ -2,7 +2,7 @@ """Type hints for PyKEEN.""" -from typing import Callable, Mapping, TypeVar, Union +from typing import Callable, Mapping, NamedTuple, Sequence, TypeVar, Union import numpy as np import torch @@ -15,9 +15,13 @@ 'Initializer', 'Normalizer', 'Constrainer', - 'InteractionFunction', 'DeviceHint', 'TorchRandomHint', + 'HeadRepresentation', + 'RelationRepresentation', + 'Representation', + 'TailRepresentation', + 'GaussianDistribution', ] LabeledTriples = np.ndarray @@ -27,10 +31,23 @@ # comment: TypeVar expects none, or at least two super-classes TensorType = TypeVar("TensorType", torch.Tensor, torch.FloatTensor) -InteractionFunction = Callable[[TensorType, TensorType, TensorType], TensorType] Initializer = Callable[[TensorType], TensorType] Normalizer = Callable[[TensorType], TensorType] Constrainer = Callable[[TensorType], TensorType] DeviceHint = Union[None, str, torch.device] TorchRandomHint = Union[None, int, torch.Generator] + +Representation = torch.FloatTensor +# TODO upgrade to use bound=... +# HeadRepresentation = TypeVar("HeadRepresentation", bound=Union[Representation, Sequence[Representation]]) +HeadRepresentation = TypeVar("HeadRepresentation", Representation, Sequence[Representation]) # type: ignore +RelationRepresentation = TypeVar("RelationRepresentation", Representation, Sequence[Representation]) # type: ignore +TailRepresentation = TypeVar("TailRepresentation", Representation, Sequence[Representation]) # type: ignore + + +class GaussianDistribution(NamedTuple): + """A gaussian distribution with diagonal covariance matrix.""" + + mean: torch.FloatTensor + diagonal_covariance: torch.FloatTensor diff --git a/src/pykeen/utils.py b/src/pykeen/utils.py index ddae07fdb9..950ae92031 100644 --- a/src/pykeen/utils.py +++ b/src/pykeen/utils.py @@ -3,14 +3,17 @@ """Utilities for PyKEEN.""" import ftplib +import functools +import itertools import json import logging +import operator import random from abc import ABC, abstractmethod from io import BytesIO from pathlib import Path from typing import ( - Any, Callable, Collection, Dict, Generic, Iterable, List, Mapping, Optional, Tuple, Type, TypeVar, + Any, Callable, Collection, Dict, Generic, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, TypeVar, Union, ) @@ -19,23 +22,37 @@ import torch import torch.nn import torch.nn.modules.batchnorm +from torch.nn import functional from .constants import PYKEEN_BENCHMARKS from .typing import DeviceHint, TorchRandomHint from .version import get_git_hash __all__ = [ + 'broadcast_cat', 'compose', + 'check_shapes', 'clamp_norm', + 'combine_complex', 'compact_mapping', 'ensure_torch_random_state', 'format_relative_comparison', - 'imag_part', + 'complex_normalize', + 'fix_dataclass_init_docs', + 'flatten_dictionary', + 'get_cls', + 'get_subclasses', + 'get_until_first_blank', 'invert_mapping', + 
'is_cudnn_error', 'is_cuda_oom_error', + 'project_entity', + 'negative_norm_of_sum', + 'normalize_string', + 'normalized_lookup', 'random_non_negative_int', - 'real_part', 'resolve_device', + 'set_random_seed', 'split_complex', 'split_list_in_batches_iter', 'torch_is_in_1d', @@ -45,6 +62,10 @@ 'get_until_first_blank', 'flatten_dictionary', 'set_random_seed', + 'tensor_sum', + 'tensor_product', + 'upgrade_to_sequence', + 'view_complex', 'NoRandomSeedNecessary', 'Result', 'fix_dataclass_init_docs', @@ -309,20 +330,18 @@ def split_complex( return x[..., :dim], x[..., dim:] -def real_part( - x: torch.FloatTensor, -) -> torch.FloatTensor: - """Get the real part from a complex tensor.""" - dim = x.shape[-1] // 2 - return x[..., :dim] +def view_complex(x: torch.FloatTensor) -> torch.Tensor: + """Convert a PyKEEN complex tensor representation into a torch one.""" + real, imag = split_complex(x=x) + return torch.complex(real=real, imag=imag) -def imag_part( - x: torch.FloatTensor, +def combine_complex( + x_re: torch.FloatTensor, + x_im: torch.FloatTensor, ) -> torch.FloatTensor: - """Get the imaginary part from a complex tensor.""" - dim = x.shape[-1] // 2 - return x[..., dim:] + """Combine a complex tensor from real and imaginary part.""" + return torch.cat([x_re, x_im], dim=-1) def fix_dataclass_init_docs(cls: Type) -> Type: @@ -461,6 +480,373 @@ def format_relative_comparison( return f"{part}/{total} ({part / total:2.2%})" +def check_shapes( + *x: Tuple[Union[torch.Tensor, Tuple[int, ...]], str], + raise_on_errors: bool = True, +) -> bool: + """ + Verify that a sequence of tensors are of matching shapes. + + :param x: + A tuple (tensor, shape), where tensor is a tensor, and shape is a string, where each character corresponds to + a (named) dimension. If the shapes of different tensors share a character, the corresponding dimensions are + expected to be of equal size. + :param raise_on_errors: + Whether to raise an exception in case of a mismatch. + + :return: + Whether the shapes matched. + + :raises ValueError: + If the shapes mismatch and raise_on_error is True. + """ + dims: Dict[str, Tuple[int, ...]] = dict() + errors = [] + for actual_shape, shape in x: + if isinstance(actual_shape, torch.Tensor): + actual_shape = actual_shape.shape + if len(actual_shape) != len(shape): + errors.append(f"Invalid number of dimensions: {actual_shape} vs. {shape}") + continue + for dim, name in zip(actual_shape, shape): + exp_dim = dims.get(name) + if exp_dim is not None and exp_dim != dim: + errors.append(f"{name}: {dim} vs. {exp_dim}") + dims[name] = dim + if raise_on_errors and errors: + raise ValueError("Shape verification failed:\n" + '\n'.join(errors)) + return len(errors) == 0 + + +def broadcast_cat( + x: torch.FloatTensor, + y: torch.FloatTensor, + dim: int, +) -> torch.FloatTensor: + """ + Concatenate with broadcasting. + + :param x: + The first tensor. + :param y: + The second tensor. + :param dim: + The concat dimension. + + :return: + """ + if x.ndimension() != y.ndimension(): + raise ValueError + if dim < 0: + dim = x.ndimension() + dim + x_rep, y_rep = [], [] + for d, (xd, yd) in enumerate(zip(x.shape, y.shape)): + xr = yr = 1 + if d != dim and xd != yd: + if xd == 1: + xr = yd + elif yd == 1: + yr = xd + else: + raise ValueError + x_rep.append(xr) + y_rep.append(yr) + return torch.cat([x.repeat(*x_rep), y.repeat(*y_rep)], dim=dim) + + +def get_subclasses(cls: Type[X]) -> Iterable[Type[X]]: + """ + Get all subclasses. + + Credit to: https://stackoverflow.com/a/33607093. 
+ + """ + for subclass in cls.__subclasses__(): + yield from get_subclasses(subclass) + yield subclass + + +def complex_normalize(x: torch.Tensor) -> torch.Tensor: + r"""Normalize the length of relation vectors, if the forward constraint has not been applied yet. + + The modulus of a complex number is given as: + + .. math:: + + |a + ib| = \sqrt{a^2 + b^2} + + The :math:`l_2` norm of a complex vector :math:`x \in \mathbb{C}^d` is: + + .. math:: + \|x\|^2 = \sum_{i=1}^d |x_i|^2 + = \sum_{i=1}^d \left(\operatorname{Re}(x_i)^2 + \operatorname{Im}(x_i)^2\right) + = \left(\sum_{i=1}^d \operatorname{Re}(x_i)^2\right) + \left(\sum_{i=1}^d \operatorname{Im}(x_i)^2\right) + = \|\operatorname{Re}(x)\|^2 + \|\operatorname{Im}(x)\|^2 + = \| [\operatorname{Re}(x); \operatorname{Im}(x)] \|^2 + """ + y = x.data.view(x.shape[0], -1, 2) + y = functional.normalize(y, p=2, dim=-1) + x.data = y.view(*x.shape) + return x + + +def calculate_broadcasted_elementwise_result_shape( + first: Tuple[int, ...], + second: Tuple[int, ...], +) -> Tuple[int, ...]: + """Determine the return shape of a broadcasted elementwise operation.""" + return tuple(max(a, b) for a, b in zip(first, second)) + + +def estimate_cost_of_sequence( + shape: Tuple[int, ...], + *other_shapes: Tuple[int, ...], +) -> int: + """Cost of a sequence of broadcasted element-wise operations of tensors, given their shapes.""" + return sum(map( + np.prod, + itertools.islice( + itertools.accumulate( + (shape,) + other_shapes, + calculate_broadcasted_elementwise_result_shape, + ), + 1, + None, + ), + )) + + +@functools.lru_cache(maxsize=32) +def _get_optimal_sequence( + *sorted_shapes: Tuple[int, ...], +) -> Tuple[int, Tuple[int, ...]]: + """Find the optimal sequence in which to combine tensors element-wise based on the shapes. + + The shapes should be sorted to enable efficient caching. + :param sorted_shapes: + The shapes of the tensors to combine. + :return: + The cost, and the optimal execution order (as indices). + """ + return min( + (estimate_cost_of_sequence(*(sorted_shapes[i] for i in p)), p) + for p in itertools.permutations(list(range(len(sorted_shapes)))) + ) + + +@functools.lru_cache(maxsize=64) +def get_optimal_sequence(*shapes: Tuple[int, ...]) -> Tuple[int, Tuple[int, ...]]: + """Find the optimal sequence in which to combine tensors elementwise based on the shapes. + + :param shapes: + The shapes of the tensors to combine. + :return: + The cost, and the optimal execution order (as indices). + """ + # create sorted list of shapes to allow utilization of lru cache (optimal execution order does not depend on the + # input sorting, as the order is determined by re-ordering the sequence anyway) + arg_sort = sorted(range(len(shapes)), key=shapes.__getitem__) + + # Determine optimal order and cost + cost, optimal_order = _get_optimal_sequence(*(shapes[new_index] for new_index in arg_sort)) + + # translate back to original order + optimal_order = tuple(arg_sort[i] for i in optimal_order) + + return cost, optimal_order + + +def _reorder( + tensors: Tuple[torch.FloatTensor, ...], +) -> Tuple[torch.FloatTensor, ...]: + """Re-order tensors for broadcasted element-wise combination of tensors. + + The optimal execution plan gets cached so that the optimization is only performed once for a fixed set of shapes. + + :param tensors: + The tensors, in broadcastable shape. + + :return: + The re-ordered tensors in optimal processing order.
+ """ + if len(tensors) < 3: + return tensors + # determine optimal processing order + shapes = tuple(tuple(t.shape) for t in tensors) + if len(set(s[0] for s in shapes)) < 2: + # heuristic + return tensors + order = get_optimal_sequence(*shapes)[1] + return tuple(tensors[i] for i in order) + + +def tensor_sum(*x: torch.FloatTensor) -> torch.FloatTensor: + """Compute elementwise sum of tensors in broadcastable shape.""" + return sum(_reorder(tensors=x)) + + +def tensor_product(*x: torch.FloatTensor) -> torch.FloatTensor: + """Compute elementwise product of tensors in broadcastable shape.""" + head, *rest = _reorder(tensors=x) + return functools.reduce(operator.mul, rest, head) + + +def negative_norm_of_sum( + *x: torch.FloatTensor, + p: Union[str, int] = 2, + power_norm: bool = False, +) -> torch.FloatTensor: + """Evaluate negative norm of a sum of vectors on already broadcasted representations. + + :param x: shape: (batch_size, num_heads, num_relations, num_tails, dim) + The representations. + :param p: + The p for the norm. cf. torch.norm. + :param power_norm: + Whether to return $\|\sum_i x_i\|_p^p$ instead of the norm, cf. https://github.com/pytorch/pytorch/issues/28119 + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + return negative_norm(tensor_sum(*x), p=p, power_norm=power_norm) + + +def negative_norm( + x: torch.FloatTensor, + p: Union[str, int] = 2, + power_norm: bool = False, +) -> torch.FloatTensor: + """Evaluate negative norm of a vector. + + :param x: shape: (batch_size, num_heads, num_relations, num_tails, dim) + The vectors. + :param p: + The p for the norm. cf. torch.norm. + :param power_norm: + Whether to return $\|x\|_p^p$ instead of the norm, cf. https://github.com/pytorch/pytorch/issues/28119 + + :return: shape: (batch_size, num_heads, num_relations, num_tails) + The scores. + """ + if power_norm: + assert not isinstance(p, str) + return -(x.abs() ** p).sum(dim=-1) + + if torch.is_complex(x): + assert not isinstance(p, str) + # workaround for complex numbers: manually compute norm + return -(x.abs() ** p).sum(dim=-1) ** (1 / p) + + return -x.norm(p=p, dim=-1) + + +def extended_einsum( + eq: str, + *tensors, +) -> torch.FloatTensor: + """Drop dimensions of size 1 to allow broadcasting.""" + # TODO: check if einsum is still very slow. + lhs, rhs = eq.split("->") + mod_ops, mod_t = [], [] + for op, t in zip(lhs.split(","), tensors): + mod_op = "" + if len(op) != len(t.shape): + raise ValueError(f'Shapes not equal: op={op} and t.shape={t.shape}') + # TODO: t_shape = list(t.shape); del t_shape[i]; t.view(*shape) -> only one reshape operation + for i, c in reversed(list(enumerate(op))): + if t.shape[i] == 1: + t = t.squeeze(dim=i) + else: + mod_op = c + mod_op + mod_ops.append(mod_op) + mod_t.append(t) + m_lhs = ",".join(mod_ops) + r_keep_dims = set("".join(mod_ops)) + m_rhs = "".join(c for c in rhs if c in r_keep_dims) + m_eq = f"{m_lhs}->{m_rhs}" + mod_r = torch.einsum(m_eq, *mod_t) + # unsqueeze + for i, c in enumerate(rhs): + if c not in r_keep_dims: + mod_r = mod_r.unsqueeze(dim=i) + return mod_r + + +def project_entity( + e: torch.FloatTensor, + e_p: torch.FloatTensor, + r_p: torch.FloatTensor, +) -> torch.FloatTensor: + r"""Project an entity into the relation-specific subspace. + + .. math:: + + e_{\bot} = M_{re} e + = (r_p e_p^T + I^{d_r \times d_e}) e + = r_p e_p^T e + I^{d_r \times d_e} e + = r_p (e_p^T e) + e' + + and additionally enforces + + .. math:: + + \|e_{\bot}\|_2 \leq 1 + + :param e: shape: (..., d_e) + The entity embedding. + :param e_p: shape: (..., d_e) + The entity projection.
+ :param r_p: shape: (..., d_r) + The relation projection. + + :return: shape: (..., d_r) + + """ + # The dimensions affected by e' + change_dim = min(e.shape[-1], r_p.shape[-1]) + + # Project entities + # r_p (e_p.T e) + e' + e_bot = r_p * torch.sum(e_p * e, dim=-1, keepdim=True) + e_bot[..., :change_dim] += e[..., :change_dim] + + # Enforce constraints + e_bot = clamp_norm(e_bot, p=2, dim=-1, maxnorm=1) + + return e_bot + + +def pop_only(elements: Iterable[X]) -> X: + """Unpack a one element list, or raise an error.""" + elements = tuple(elements) + if len(elements) == 0: + raise ValueError('Empty sequence given') + if len(elements) > 1: + raise ValueError(f'More than one element: {elements}') + return elements[0] + + +def strip_dim(*x, num: int = 4): + """Strip the first dimensions.""" + return [xx.view(xx.shape[num:]) for xx in x] + + +def upgrade_to_sequence(x: Union[X, Sequence[X]]) -> Sequence[X]: + """Ensure that the input is a sequence.""" + return x if isinstance(x, Sequence) else (x,) + + +def ensure_tuple(*x: Union[X, Sequence[X]]) -> Sequence[Sequence[X]]: + return tuple(upgrade_to_sequence(xx) for xx in x) + + +def unpack_singletons(*xs: Tuple[X]) -> Sequence[Union[X, Tuple[X]]]: + return [ + x[0] if len(x) == 1 else x + for x in xs + ] + + def get_batchnorm_modules(module: torch.nn.Module) -> List[torch.nn.Module]: """Return all submodules which are batch normalization layers.""" return [ diff --git a/tests/test_early_stopping.py b/tests/test_early_stopping.py index 44e4202c64..df28f71507 100644 --- a/tests/test_early_stopping.py +++ b/tests/test_early_stopping.py @@ -14,11 +14,11 @@ from pykeen.evaluation import Evaluator, MetricResults, RankBasedEvaluator, RankBasedMetricResults from pykeen.evaluation.rank_based_evaluator import RANK_TYPES, SIDES from pykeen.models import TransE -from pykeen.models.base import EntityRelationEmbeddingModel, Model +from pykeen.models.base import Model from pykeen.stoppers.early_stopping import EarlyStopper, is_improvement +from pykeen.testing.mocks import MockModel from pykeen.trackers import MLFlowResultTracker from pykeen.training import SLCWATrainingLoop -from pykeen.triples import TriplesFactory from pykeen.typing import MappedTriples @@ -113,34 +113,6 @@ def __repr__(self): # noqa: D105 return f'{self.__class__.__name__}(losses={self.losses})' -class MockModel(EntityRelationEmbeddingModel): - """A mock model returning fake scores.""" - - def __init__(self, triples_factory: TriplesFactory): - super().__init__(triples_factory=triples_factory) - num_entities = self.num_entities - self.scores = torch.arange(num_entities, dtype=torch.float) - - def _generate_fake_scores(self, batch: torch.LongTensor) -> torch.FloatTensor: - """Generate fake scores s[b, i] = i of size (batch_size, num_entities).""" - batch_size = batch.shape[0] - batch_scores = self.scores.view(1, -1).repeat(batch_size, 1) - assert batch_scores.shape == (batch_size, self.num_entities) - return batch_scores - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._generate_fake_scores(batch=hrt_batch) - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._generate_fake_scores(batch=hr_batch) - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._generate_fake_scores(batch=rt_batch) - - def reset_parameters_(self) -> Model: # noqa: D102 - pass # Not needed for unittest - - class LogCallWrapper: """An object which wraps functions and checks 
whether they have been called.""" @@ -178,8 +150,8 @@ class TestEarlyStopping(unittest.TestCase): def setUp(self): """Prepare for testing the early stopper.""" - # Set automatic_memory_optimization to false for tests self.mock_evaluator = MockEvaluator(self.mock_losses, automatic_memory_optimization=False) + # Set automatic_memory_optimization to false for tests nations = Nations() self.model = MockModel(triples_factory=nations.training) self.stopper = EarlyStopper( diff --git a/tests/test_evaluators.py b/tests/test_evaluators.py index ed4a19c0c2..3175c0a050 100644 --- a/tests/test_evaluators.py +++ b/tests/test_evaluators.py @@ -15,7 +15,8 @@ from pykeen.evaluation.rank_based_evaluator import RANK_TYPES, SIDES, compute_rank_from_scores from pykeen.evaluation.sklearn import SklearnEvaluator, SklearnMetricResults from pykeen.models import TransE -from pykeen.models.base import EntityRelationEmbeddingModel, Model +from pykeen.models.base import Model +from pykeen.testing.mocks import MockModel from pykeen.triples import TriplesFactory from pykeen.typing import MappedTriples @@ -428,34 +429,6 @@ def __repr__(self): # noqa: D105 return f'{self.__class__.__name__}(losses={self.losses})' -class MockModel(EntityRelationEmbeddingModel): - """A dummy model returning fake scores.""" - - def __init__(self, triples_factory: TriplesFactory): - super().__init__(triples_factory=triples_factory) - num_entities = self.num_entities - self.scores = torch.arange(num_entities, dtype=torch.float) - - def _generate_fake_scores(self, batch: torch.LongTensor) -> torch.FloatTensor: - """Generate fake scores s[b, i] = i of size (batch_size, num_entities).""" - batch_size = batch.shape[0] - batch_scores = self.scores.view(1, -1).repeat(batch_size, 1) - assert batch_scores.shape == (batch_size, self.num_entities) - return batch_scores - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._generate_fake_scores(batch=hrt_batch) - - def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._generate_fake_scores(batch=hr_batch) - - def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - return self._generate_fake_scores(batch=rt_batch) - - def reset_parameters_(self) -> Model: # noqa: D102 - pass # Not needed for unittest - - class TestEvaluationStructure(unittest.TestCase): """Tests for testing the correct structure of the evaluation procedure.""" diff --git a/tests/test_handling_of_cuda_exceptions.py b/tests/test_handling_of_cuda_exceptions.py deleted file mode 100644 index b75021eef0..0000000000 --- a/tests/test_handling_of_cuda_exceptions.py +++ /dev/null @@ -1,30 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Test that CUDA exceptions are processed properly.""" - -import unittest - -from pykeen.utils import _CUDA_OOM_ERROR, _CUDNN_ERROR, is_cuda_oom_error, is_cudnn_error - - -class TestCudaExceptionsHandling(unittest.TestCase): - """Test handling of CUDA exceptions.""" - - not_cuda_error = Exception("Something else.") - - def test_is_cuda_oom_error(self): - """Test handling of a CUDA out of memory exception.""" - error = RuntimeError(_CUDA_OOM_ERROR) - self.assertTrue(is_cuda_oom_error(runtime_error=error)) - self.assertFalse(is_cudnn_error(runtime_error=error)) - - self.assertFalse(is_cuda_oom_error(runtime_error=self.not_cuda_error)) - - def test_is_cudnn_error(self): - """Test handling of a cuDNN error.""" - error = RuntimeError(_CUDNN_ERROR) - self.assertTrue(is_cudnn_error(runtime_error=error)) - 
self.assertFalse(is_cuda_oom_error(runtime_error=error)) - - error = Exception("Something else.") - self.assertFalse(is_cudnn_error(runtime_error=self.not_cuda_error)) diff --git a/tests/test_interactions.py b/tests/test_interactions.py new file mode 100644 index 0000000000..1e9e294b97 --- /dev/null +++ b/tests/test_interactions.py @@ -0,0 +1,618 @@ +# -*- coding: utf-8 -*- + +"""Tests for interaction functions.""" + +import logging +import unittest +from abc import abstractmethod +from typing import Collection, Sequence, Tuple, Union +from unittest.case import SkipTest + +import numpy +import torch + +import pykeen.nn.modules +import pykeen.utils +from pykeen.models.multimodal.base import LiteralInteraction +from pykeen.nn.functional import distmult_interaction +from pykeen.nn.modules import Interaction, TranslationalInteraction +from pykeen.testing import base as ptb +from pykeen.typing import Representation +from pykeen.utils import clamp_norm, project_entity, strip_dim, view_complex + +logger = logging.getLogger(__name__) + + +class InteractionTestCase(ptb.GenericTestCase[pykeen.nn.modules.Interaction]): + """Generic test for interaction functions.""" + + dim: int = 2 + batch_size: int = 3 + num_relations: int = 5 + num_entities: int = 7 + + shape_kwargs = dict() + + def post_instantiation_hook(self) -> None: + """Initialize parameters.""" + self.instance.reset_parameters() + + def _get_hrt( + self, + *shapes: Tuple[int, ...], + ) -> Tuple[Union[Representation, Sequence[Representation]], ...]: + self.shape_kwargs.setdefault("d", self.dim) + result = tuple( + tuple( + torch.rand(*prefix_shape, *(self.shape_kwargs[dim] for dim in weight_shape), requires_grad=True) + for weight_shape in weight_shapes + ) + for prefix_shape, weight_shapes in zip( + shapes, + [self.cls.entity_shape, self.cls.relation_shape, self.cls.entity_shape], + ) + ) + return tuple(pykeen.utils.unpack_singletons(*result)) + + def _check_scores(self, scores: torch.FloatTensor, exp_shape: Tuple[int, ...]): + """Check shape, dtype and gradients of scores.""" + assert torch.is_tensor(scores) + assert scores.dtype == torch.float32 + assert scores.ndimension() == len(exp_shape) + assert scores.shape == exp_shape + assert scores.requires_grad + self._additional_score_checks(scores) + + def _additional_score_checks(self, scores): + """Additional checks for scores.""" + + def test_score_hrt(self): + """Test score_hrt.""" + h, r, t = self._get_hrt( + (self.batch_size,), + (self.batch_size,), + (self.batch_size,), + ) + scores = self.instance.score_hrt(h=h, r=r, t=t) + self._check_scores(scores=scores, exp_shape=(self.batch_size, 1)) + + def test_score_h(self): + """Test score_h.""" + h, r, t = self._get_hrt( + (self.num_entities,), + (self.batch_size,), + (self.batch_size,), + ) + scores = self.instance.score_h(all_entities=h, r=r, t=t) + self._check_scores(scores=scores, exp_shape=(self.batch_size, self.num_entities)) + + def test_score_h_slicing(self): + """Test score_h with slicing.""" + #: The equivalence for models with batch norm only holds in evaluation mode + self.instance.eval() + h, r, t = self._get_hrt( + (self.num_entities,), + (self.batch_size,), + (self.batch_size,), + ) + scores = self.instance.score_h(all_entities=h, r=r, t=t, slice_size=self.num_entities // 2 + 1) + scores_no_slice = self.instance.score_h(all_entities=h, r=r, t=t, slice_size=None) + self._check_close_scores(scores=scores, scores_no_slice=scores_no_slice) + + def test_score_r(self): + """Test score_r.""" + h, r, t = self._get_hrt( + 
(self.batch_size,), + (self.num_relations,), + (self.batch_size,), + ) + scores = self.instance.score_r(h=h, all_relations=r, t=t) + if len(self.cls.relation_shape) == 0: + exp_shape = (self.batch_size, 1) + else: + exp_shape = (self.batch_size, self.num_relations) + self._check_scores(scores=scores, exp_shape=exp_shape) + + def test_score_r_slicing(self): + """Test score_r with slicing.""" + if len(self.cls.relation_shape) == 0: + raise SkipTest("No use in slicing relations for models without relation information.") + #: The equivalence for models with batch norm only holds in evaluation mode + self.instance.eval() + h, r, t = self._get_hrt( + (self.batch_size,), + (self.num_relations,), + (self.batch_size,), + ) + scores = self.instance.score_r(h=h, all_relations=r, t=t, slice_size=self.num_relations // 2 + 1) + scores_no_slice = self.instance.score_r(h=h, all_relations=r, t=t, slice_size=None) + self._check_close_scores(scores=scores, scores_no_slice=scores_no_slice) + + def test_score_t(self): + """Test score_t.""" + h, r, t = self._get_hrt( + (self.batch_size,), + (self.batch_size,), + (self.num_entities,), + ) + scores = self.instance.score_t(h=h, r=r, all_entities=t) + self._check_scores(scores=scores, exp_shape=(self.batch_size, self.num_entities)) + + def test_score_t_slicing(self): + """Test score_t with slicing.""" + #: The equivalence for models with batch norm only holds in evaluation mode + self.instance.eval() + h, r, t = self._get_hrt( + (self.batch_size,), + (self.batch_size,), + (self.num_entities,), + ) + scores = self.instance.score_t(h=h, r=r, all_entities=t, slice_size=self.num_entities // 2 + 1) + scores_no_slice = self.instance.score_t(h=h, r=r, all_entities=t, slice_size=None) + self._check_close_scores(scores=scores, scores_no_slice=scores_no_slice) + + def _check_close_scores(self, scores, scores_no_slice): + self.assertTrue(torch.isfinite(scores).all(), msg=f'Normal scores had nan:\n\t{scores}') + self.assertTrue(torch.isfinite(scores_no_slice).all(), msg=f'Slice scores had nan\n\t{scores}') + self.assertTrue(torch.allclose(scores, scores_no_slice), msg=f'Differences: {scores - scores_no_slice}') + + def _get_test_shapes(self) -> Collection[Tuple[ + Tuple[int, int, int, int], + Tuple[int, int, int, int], + Tuple[int, int, int, int], + ]]: + """Return a set of test shapes for (h, r, t).""" + return ( + ( # single score + (1, 1, 1, 1), + (1, 1, 1, 1), + (1, 1, 1, 1), + ), + ( # score_r with multi-t + (self.batch_size, 1, 1, 1), + (1, 1, self.num_relations, 1), + (self.batch_size, 1, 1, self.num_entities // 2 + 1), + ), + ( # score_r with multi-t and broadcasted head + (1, 1, 1, 1), + (1, 1, self.num_relations, 1), + (self.batch_size, 1, 1, self.num_entities), + ), + ( # full cwa + (1, self.num_entities, 1, 1), + (1, 1, self.num_relations, 1), + (1, 1, 1, self.num_entities), + ), + ) + + def _get_output_shape( + self, + hs: Tuple[int, int, int, int], + rs: Tuple[int, int, int, int], + ts: Tuple[int, int, int, int], + ) -> Tuple[int, int, int, int]: + result = [max(ds) for ds in zip(hs, rs, ts)] + if len(self.instance.entity_shape) == 0: + result[1] = result[3] = 1 + if len(self.instance.relation_shape) == 0: + result[2] = 1 + return tuple(result) + + def test_forward(self): + """Test forward.""" + for hs, rs, ts in self._get_test_shapes(): + try: + h, r, t = self._get_hrt(hs, rs, ts) + scores = self.instance(h=h, r=r, t=t) + expected_shape = self._get_output_shape(hs, rs, ts) + self._check_scores(scores=scores, exp_shape=expected_shape) + except ValueError as 
error: + # check whether the error originates from batch norm for single element batches + small_batch_size = any(s[0] == 1 for s in (hs, rs, ts)) + has_batch_norm = any( + isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)) + for m in self.instance.modules() + ) + if small_batch_size and has_batch_norm: + logger.warning( + f"Skipping test for shapes {hs}, {rs}, {ts} because too small batch size for batch norm", + ) + continue + raise error + + def test_forward_consistency_with_functional(self): + """Test forward's consistency with functional.""" + # set in eval mode (otherwise there are non-deterministic factors like Dropout + self.instance.eval() + for hs, rs, ts in self._get_test_shapes(): + h, r, t = self._get_hrt(hs, rs, ts) + scores = self.instance(h=h, r=r, t=t) + kwargs = self.instance._prepare_for_functional(h=h, r=r, t=t) + scores_f = self.cls.func(**kwargs) + assert torch.allclose(scores, scores_f) + + def test_scores(self): + """Test individual scores.""" + # set in eval mode (otherwise there are non-deterministic factors like Dropout + self.instance.eval() + for _ in range(10): + # test multiple different initializations + self.instance.reset_parameters() + h, r, t = self._get_hrt((1, 1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1)) + kwargs = self.instance._prepare_for_functional(h=h, r=r, t=t) + + # calculate by functional + scores_f = self.cls.func(**kwargs).view(-1) + + # calculate manually + scores_f_manual = self._exp_score(**kwargs).view(-1) + assert torch.allclose(scores_f_manual, scores_f), f'Diff: {scores_f_manual - scores_f}' + + @abstractmethod + def _exp_score(self, **kwargs) -> torch.FloatTensor: + """Compute the expected score for a single-score batch.""" + raise NotImplementedError(f"{self.cls.__name__}({sorted(kwargs.keys())})") + + +class ComplExTests(InteractionTestCase, unittest.TestCase): + """Tests for ComplEx interaction function.""" + + cls = pykeen.nn.modules.ComplExInteraction + + def _exp_score(self, h, r, t) -> torch.FloatTensor: # noqa: D102 + h, r, t = [view_complex(x) for x in (h, r, t)] + return (h * r * torch.conj(t)).sum().real + + +class ConvETests(InteractionTestCase, unittest.TestCase): + """Tests for ConvE interaction function.""" + + cls = pykeen.nn.modules.ConvEInteraction + kwargs = dict( + embedding_height=1, + embedding_width=2, + kernel_height=2, + kernel_width=1, + embedding_dim=InteractionTestCase.dim, + ) + + def _get_hrt( + self, + *shapes: Tuple[int, ...], + **kwargs, + ) -> Tuple[Union[Representation, Sequence[Representation]], ...]: # noqa: D102 + h, r, t = super()._get_hrt(*shapes, **kwargs) + t_bias = torch.rand_like(t[..., 0, None]) + return h, r, (t, t_bias) + + def _exp_score( + self, embedding_height, embedding_width, h, hr1d, hr2d, input_channels, r, t, t_bias, + ) -> torch.FloatTensor: + x = torch.cat([ + h.view(1, input_channels, embedding_height, embedding_width), + r.view(1, input_channels, embedding_height, embedding_width), + ], dim=2) + x = hr2d(x) + x = x.view(-1, numpy.prod(x.shape[-3:])) + x = hr1d(x) + return (x.view(1, -1) * t.view(1, -1)).sum() + t_bias + + +class ConvKBTests(InteractionTestCase, unittest.TestCase): + """Tests for ConvKB interaction function.""" + + cls = pykeen.nn.modules.ConvKBInteraction + kwargs = dict( + embedding_dim=InteractionTestCase.dim, + num_filters=2 * InteractionTestCase.dim - 1, + ) + + def _exp_score(self, h, r, t, conv, activation, hidden_dropout, linear) -> torch.FloatTensor: # noqa: D102 + # W_L drop(act(W_C \ast ([h; r; t]) + b_C)) + b_L + # prepare conv input (N, C, H, W) 
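+        # each of h, r, t is flattened to (d,) and stacked column-wise, giving a single-channel (1, 1, d, 3) input for the convolution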
+ x = torch.stack([x.view(-1) for x in (h, r, t)], dim=1).view(1, 1, -1, 3) + x = conv(x) + x = hidden_dropout(activation(x)) + return linear(x.view(1, -1)) + + +class DistMultTests(InteractionTestCase, unittest.TestCase): + """Tests for DistMult interaction function.""" + + cls = pykeen.nn.modules.DistMultInteraction + + def _exp_score(self, h, r, t) -> torch.FloatTensor: + return (h * r * t).sum(dim=-1) + + +class ERMLPTests(InteractionTestCase, unittest.TestCase): + """Tests for ERMLP interaction function.""" + + cls = pykeen.nn.modules.ERMLPInteraction + kwargs = dict( + embedding_dim=InteractionTestCase.dim, + hidden_dim=2 * InteractionTestCase.dim - 1, + ) + + def _exp_score(self, h, r, t, hidden, activation, final) -> torch.FloatTensor: + x = torch.cat([x.view(-1) for x in (h, r, t)]) + return final(activation(hidden(x))) + + +class ERMLPETests(InteractionTestCase, unittest.TestCase): + """Tests for ERMLP-E interaction function.""" + + cls = pykeen.nn.modules.ERMLPEInteraction + kwargs = dict( + embedding_dim=InteractionTestCase.dim, + hidden_dim=2 * InteractionTestCase.dim - 1, + ) + + def _exp_score(self, h, r, t, mlp) -> torch.FloatTensor: # noqa: D102 + x = torch.cat([x.view(1, -1) for x in (h, r)], dim=-1) + return mlp(x).view(1, -1) @ t.view(-1, 1) + + +class HolETests(InteractionTestCase, unittest.TestCase): + """Tests for HolE interaction function.""" + + cls = pykeen.nn.modules.HolEInteraction + + def _exp_score(self, h, r, t) -> torch.FloatTensor: # noqa: D102 + h, t = [torch.fft.rfft(x.view(1, -1), dim=-1) for x in (h, t)] + h = torch.conj(h) + c = torch.fft.irfft(h * t, n=h.shape[-1], dim=-1) + return (c * r).sum() + + +class NTNTests(InteractionTestCase, unittest.TestCase): + """Tests for NTN interaction function.""" + + cls = pykeen.nn.modules.NTNInteraction + + num_slices: int = 11 + shape_kwargs = dict( + k=11, + ) + + def _exp_score(self, h, t, w, vt, vh, b, u, activation) -> torch.FloatTensor: + # f(h,r,t) = u_r^T act(h W_r t + V_r h + V_r t + b_r) + # shapes: w: (k, dim, dim), vh/vt: (k, dim), b/u: (k,), h/t: (dim,) + # remove batch/num dimension + h, t, w, vt, vh, b, u = strip_dim(h, t, w, vt, vh, b, u) + score = 0. 
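+        # accumulate the score slice by slice: slice i contributes u[i] * act(h^T W[i] t + <v_h[i], h> + <v_t[i], t> + b[i])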
+ for i in range(u.shape[-1]): + first_part = h.view(1, self.dim) @ w[i] @ t.view(self.dim, 1) + second_part = (vh[i] * h.view(-1)).sum() + third_part = (vt[i] * t.view(-1)).sum() + score = score + u[i] * activation(first_part + second_part + third_part + b[i]) + return score + + +class ProjETests(InteractionTestCase, unittest.TestCase): + """Tests for ProjE interaction function.""" + + cls = pykeen.nn.modules.ProjEInteraction + kwargs = dict( + embedding_dim=InteractionTestCase.dim, + ) + + def _exp_score(self, h, r, t, d_e, d_r, b_c, b_p, activation) -> torch.FloatTensor: + # f(h, r, t) = g(t z(D_e h + D_r r + b_c) + b_p) + h, r, t = strip_dim(h, r, t) + return (t * activation((d_e * h) + (d_r * r) + b_c)).sum() + b_p + + +class RESCALTests(InteractionTestCase, unittest.TestCase): + """Tests for RESCAL interaction function.""" + + cls = pykeen.nn.modules.RESCALInteraction + + def _exp_score(self, h, r, t) -> torch.FloatTensor: + # f(h, r, t) = h @ r @ t + h, r, t = strip_dim(h, r, t) + return h.view(1, -1) @ r @ t.view(-1, 1) + + +class KG2ETests(InteractionTestCase, unittest.TestCase): + """Tests for KG2E interaction function.""" + + cls = pykeen.nn.modules.KG2EInteraction + + def _exp_score(self, exact, h_mean, h_var, r_mean, r_var, similarity, t_mean, t_var): + assert similarity == "KL" + h_mean, h_var, r_mean, r_var, t_mean, t_var = strip_dim(h_mean, h_var, r_mean, r_var, t_mean, t_var) + e_mean, e_var = h_mean - t_mean, h_var + t_var + p = torch.distributions.MultivariateNormal(loc=e_mean, covariance_matrix=torch.diag(e_var)) + q = torch.distributions.MultivariateNormal(loc=r_mean, covariance_matrix=torch.diag(r_var)) + return -torch.distributions.kl.kl_divergence(p, q) + + +class TuckerTests(InteractionTestCase, unittest.TestCase): + """Tests for Tucker interaction function.""" + + cls = pykeen.nn.modules.TuckerInteraction + kwargs = dict( + embedding_dim=InteractionTestCase.dim, + ) + + def _exp_score(self, bn_h, bn_hr, core_tensor, do_h, do_r, do_hr, h, r, t) -> torch.FloatTensor: + # DO_{hr}(BN_{hr}(DO_h(BN_h(h)) x_1 DO_r(W x_2 r))) x_3 t + h, r, t = strip_dim(h, r, t) + a = do_r((core_tensor * r[None, :, None]).sum(dim=1, keepdims=True)) # shape: (embedding_dim, 1, embedding_dim) + b = do_h(bn_h(h.view(1, -1))).view(-1) # shape: (embedding_dim) + c = (b[:, None, None] * a).sum(dim=0, keepdims=True) # shape: (1, 1, embedding_dim) + d = do_hr(bn_hr((c.view(1, -1)))).view(1, 1, -1) # shape: (1, 1, 1, embedding_dim) + return (d * t[None, None, :]).sum() + + +class RotatETests(InteractionTestCase, unittest.TestCase): + """Tests for RotatE interaction function.""" + + cls = pykeen.nn.modules.RotatEInteraction + + def _get_hrt(self, *shapes): # noqa: D102 + # normalize length of r + h, r, t = super()._get_hrt(*shapes) + rc = view_complex(r) + rl = (rc.abs() ** 2).sum(dim=-1).sqrt() + r = r / rl.unsqueeze(dim=-1) + return h, r, t + + def _exp_score(self, h, r, t) -> torch.FloatTensor: # noqa: D102 + h, r, t = strip_dim(*(view_complex(x) for x in (h, r, t))) + # check for unit length + assert torch.allclose((r.abs() ** 2).sum(dim=-1).sqrt(), torch.ones(1)) + d = h * r - t + return -(d.abs() ** 2).sum(dim=-1).sqrt() + + +class TranslationalInteractionTests(InteractionTestCase): + """Common tests for translational interaction.""" + + kwargs = dict( + p=2, + ) + + def _additional_score_checks(self, scores): + assert (scores <= 0).all() + + +class TransDTests(TranslationalInteractionTests, unittest.TestCase): + """Tests for TransD interaction function.""" + + cls = 
pykeen.nn.modules.TransDInteraction + shape_kwargs = dict( + e=3, + ) + + def test_manual_small_relation_dim(self): + """Manually test the value of the interaction function.""" + # entity embeddings + h = t = torch.as_tensor(data=[2., 2.], dtype=torch.float).view(1, 2) + h_p = t_p = torch.as_tensor(data=[3., 3.], dtype=torch.float).view(1, 2) + + # relation embeddings + r = torch.as_tensor(data=[4.], dtype=torch.float).view(1, 1) + r_p = torch.as_tensor(data=[5.], dtype=torch.float).view(1, 1) + + # Compute Scores + scores = self.instance.score_hrt(h=(h, h_p), r=(r, r_p), t=(t, t_p)) + first_score = scores[0].item() + self.assertAlmostEqual(first_score, -16, delta=0.01) + + def test_manual_big_relation_dim(self): + """Manually test the value of the interaction function.""" + # entity embeddings + h = t = torch.as_tensor(data=[2., 2.], dtype=torch.float).view(1, 2) + h_p = t_p = torch.as_tensor(data=[3., 3.], dtype=torch.float).view(1, 2) + + # relation embeddings + r = torch.as_tensor(data=[3., 3., 3.], dtype=torch.float).view(1, 3) + r_p = torch.as_tensor(data=[4., 4., 4.], dtype=torch.float).view(1, 3) + + # Compute Scores + scores = self.instance.score_hrt(h=(h, h_p), r=(r, r_p), t=(t, t_p)) + self.assertAlmostEqual(scores.item(), -27, delta=0.01) + + def _exp_score(self, h, r, t, h_p, r_p, t_p, p, power_norm) -> torch.FloatTensor: # noqa: D102 + assert power_norm + h_bot = project_entity(e=h, e_p=h_p, r_p=r_p) + t_bot = project_entity(e=t, e_p=t_p, r_p=r_p) + return -((h_bot + r - t_bot) ** p).sum() + + +class TransETests(TranslationalInteractionTests, unittest.TestCase): + """Tests for TransE interaction function.""" + + cls = pykeen.nn.modules.TransEInteraction + + def _exp_score(self, h, r, t, p, power_norm) -> torch.FloatTensor: + assert not power_norm + return -(h + r - t).norm(p=p, dim=-1) + + +class TransHTests(TranslationalInteractionTests, unittest.TestCase): + """Tests for TransH interaction function.""" + + cls = pykeen.nn.modules.TransHInteraction + + def _exp_score(self, h, w_r, d_r, t, p, power_norm) -> torch.FloatTensor: # noqa: D102 + assert not power_norm + h, w_r, d_r, t = strip_dim(h, w_r, d_r, t) + h, t = [x - (x * w_r).sum() * w_r for x in (h, t)] + return -(h + d_r - t).norm(p=p) + + +class TransRTests(TranslationalInteractionTests, unittest.TestCase): + """Tests for TransR interaction function.""" + + cls = pykeen.nn.modules.TransRInteraction + shape_kwargs = dict( + e=3, + ) + + def test_manual(self): + """Manually test the value of the interaction function.""" + # Compute Scores + h = torch.as_tensor(data=[2, 2], dtype=torch.float32).view(1, 2) + r = torch.as_tensor(data=[4, 4], dtype=torch.float32).view(1, 2) + m_r = torch.as_tensor(data=[5, 5, 6, 6], dtype=torch.float32).view(1, 2, 2) + t = torch.as_tensor(data=[2, 2], dtype=torch.float32).view(1, 2) + scores = self.instance.score_hrt(h=h, r=(r, m_r), t=t) + first_score = scores[0].item() + self.assertAlmostEqual(first_score, -32, delta=1.0e-04) + + def _exp_score(self, h, r, m_r, t, p, power_norm) -> torch.FloatTensor: + assert power_norm + h, r, m_r, t = strip_dim(h, r, m_r, t) + h_bot, t_bot = [clamp_norm(x.unsqueeze(dim=0) @ m_r, p=2, dim=-1, maxnorm=1.) 
for x in (h, t)] + return -((h_bot + r - t_bot) ** p).sum() + + +class SETests(TranslationalInteractionTests, unittest.TestCase): + """Tests for SE interaction function.""" + + cls = pykeen.nn.modules.StructuredEmbeddingInteraction + + def _exp_score(self, h, t, r_h, r_t, p, power_norm) -> torch.FloatTensor: + assert not power_norm + # -\|R_h h - R_t t\| + h, t, r_h, r_t = strip_dim(h, t, r_h, r_t) + h = r_h @ h.unsqueeze(dim=-1) + t = r_t @ t.unsqueeze(dim=-1) + return -(h - t).norm(p) + + +class UMTests(TranslationalInteractionTests, unittest.TestCase): + """Tests for UM interaction function.""" + + cls = pykeen.nn.modules.UnstructuredModelInteraction + + def _exp_score(self, h, t, p, power_norm) -> torch.FloatTensor: + assert power_norm + # -\|h - t\| + h, t = strip_dim(h, t) + return -(h - t).pow(p).sum() + + +class SimplEInteractionTests(InteractionTestCase, unittest.TestCase): + """Tests for SimplE interaction function.""" + + cls = pykeen.nn.modules.SimplEInteraction + + def _exp_score(self, h, r, t, h_inv, r_inv, t_inv, clamp) -> torch.FloatTensor: + h, r, t, h_inv, r_inv, t_inv = strip_dim(h, r, t, h_inv, r_inv, t_inv) + assert clamp is None + return 0.5 * distmult_interaction(h, r, t) + 0.5 * distmult_interaction(h_inv, r_inv, t_inv) + + +class InteractionTestsTestCase(ptb.TestsTestCase[Interaction]): + """Test for tests for all interaction functions.""" + + base_cls = Interaction + base_test = InteractionTestCase + skip_cls = { + TranslationalInteraction, + LiteralInteraction, + } diff --git a/tests/test_model_mode.py b/tests/test_model_mode.py index ad18335f7f..51478feaee 100644 --- a/tests/test_model_mode.py +++ b/tests/test_model_mode.py @@ -3,14 +3,13 @@ """Test that models are set in the right mode when they're training.""" import unittest -from dataclasses import dataclass +from unittest.mock import MagicMock import torch -from torch import nn from pykeen.datasets import Nations -from pykeen.models import TransE -from pykeen.models.base import EntityRelationEmbeddingModel, Model +from pykeen.models import Model, TransE +from pykeen.testing.mocks import MockModel from pykeen.triples import TriplesFactory from pykeen.utils import resolve_device @@ -81,9 +80,9 @@ class TestBaseModelScoringFunctions(unittest.TestCase): def setUp(self): """Prepare for testing the scoring functions.""" self.generator = torch.random.manual_seed(seed=42) - self.triples_factory = MinimalTriplesFactory + self.triples_factory = MagicMock(num_relations=2, num_entities=2) self.device = resolve_device() - self.model = SimpleInteractionModel(triples_factory=self.triples_factory).to(self.device) + self.model = MockModel(triples_factory=self.triples_factory).to(self.device) def test_alignment_of_score_t_fall_back(self) -> None: """Test if ``BaseModule.score_t`` aligns with ``BaseModule.score_hrt``.""" @@ -106,8 +105,8 @@ def test_alignment_of_score_t_fall_back(self) -> None: device=self.device, ) scores_t_function = self.model.score_t(hr_batch=hr_batch).flatten() - scores_hrt_function = self.model.score_hrt(hrt_batch=hrt_batch) - assert all(scores_t_function == scores_hrt_function) + scores_hrt_function = self.model.score_hrt(hrt_batch=hrt_batch).flatten() + assert (scores_t_function == scores_hrt_function).all() def test_alignment_of_score_h_fall_back(self) -> None: """Test if ``BaseModule.score_h`` aligns with ``BaseModule.score_hrt``.""" @@ -130,8 +129,8 @@ def test_alignment_of_score_h_fall_back(self) -> None: device=self.device, ) scores_h_function = self.model.score_h(rt_batch=rt_batch).flatten() - 
scores_hrt_function = self.model.score_hrt(hrt_batch=hrt_batch) - assert all(scores_h_function == scores_hrt_function) + scores_hrt_function = self.model.score_hrt(hrt_batch=hrt_batch).flatten() + assert (scores_h_function == scores_hrt_function).all() def test_alignment_of_score_r_fall_back(self) -> None: """Test if ``BaseModule.score_r`` aligns with ``BaseModule.score_hrt``.""" @@ -154,41 +153,5 @@ def test_alignment_of_score_r_fall_back(self) -> None: device=self.device, ) scores_r_function = self.model.score_r(ht_batch=ht_batch).flatten() - scores_hrt_function = self.model.score_hrt(hrt_batch=hrt_batch) - assert all(scores_r_function == scores_hrt_function) - - -class SimpleInteractionModel(EntityRelationEmbeddingModel): - """A model with a simple interaction function for testing the base model.""" - - def __init__(self, triples_factory: TriplesFactory): - super().__init__(triples_factory=triples_factory) - self.entity_embeddings = nn.Embedding(self.num_entities, self.embedding_dim) - self.relation_embeddings = nn.Embedding(self.num_relations, self.embedding_dim) - - def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 - # Get embeddings - h = self.entity_embeddings(hrt_batch[:, 0]) - r = self.relation_embeddings(hrt_batch[:, 1]) - t = self.entity_embeddings(hrt_batch[:, 2]) - - return torch.sum(h + r + t, dim=1) - - def reset_parameters_(self) -> Model: # noqa: D102 - pass # Not needed for unittest - - -@dataclass -class MinimalTriplesFactory: - """A triples factory with minial attributes to allow the model to initiate.""" - - relation_to_id = { - "0": 0, - "1": 1, - } - entity_to_id = { - "0": 0, - "1": 1, - } - num_entities = 2 - num_relations = 2 + scores_hrt_function = self.model.score_hrt(hrt_batch=hrt_batch).flatten() + assert (scores_r_function == scores_hrt_function).all() diff --git a/tests/test_models.py b/tests/test_models.py index a71074a282..907f5809f1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -8,7 +8,7 @@ import traceback import unittest from typing import Any, ClassVar, Mapping, Optional, Type -from unittest.mock import patch +from unittest.mock import MagicMock, patch import numpy import pytest @@ -24,63 +24,77 @@ from pykeen.datasets.nations import NATIONS_TEST_PATH, NATIONS_TRAIN_PATH, Nations from pykeen.models import _MODELS from pykeen.models.base import ( - EntityEmbeddingModel, - EntityRelationEmbeddingModel, + ERModel, Model, - MultimodalModel, _extend_batch, get_novelty_mask, ) from pykeen.models.cli import build_cli_from_cls -from pykeen.models.unimodal.rgcn import ( - inverse_indegree_edge_weights, - inverse_outdegree_edge_weights, +from pykeen.models.multimodal.base import LiteralModel +from pykeen.nn.representation import ( + RGCNRepresentations, inverse_indegree_edge_weights, inverse_outdegree_edge_weights, symmetric_edge_weights, ) -from pykeen.models.unimodal.trans_d import _project_entity -from pykeen.nn import Embedding, RepresentationModule +from pykeen.regularizers import LpRegularizer, collect_regularization_terms +from pykeen.testing.mocks import MockRepresentations from pykeen.training import LCWATrainingLoop, SLCWATrainingLoop, TrainingLoop from pykeen.triples import TriplesFactory -from pykeen.utils import all_in_bounds, clamp_norm, set_random_seed +from pykeen.utils import all_in_bounds, set_random_seed SKIP_MODULES = { Model.__name__, + ERModel.__name__, + LiteralModel.__name__, 'DummyModel', - MultimodalModel.__name__, - EntityEmbeddingModel.__name__, - 
EntityRelationEmbeddingModel.__name__, 'MockModel', 'models', 'get_model_cls', - 'SimpleInteractionModel', } -for cls in MultimodalModel.__subclasses__(): - SKIP_MODULES.add(cls.__name__) +SKIP_MODULES.update({ + cls.__name__ + for cls in LiteralModel.__subclasses__() +}) _EPSILON = 1.0e-07 -class _CustomRepresentations(RepresentationModule): - """A custom representation module with minimal implementation.""" +class ERModelTests(unittest.TestCase): + """Test basic functionality of ERModel.""" - def __init__(self, num_entities: int, embedding_dim: int = 2): - super().__init__() - self.num_embeddings = num_entities - self.embedding_dim = embedding_dim - self.x = nn.Parameter(torch.rand(embedding_dim)) + def setUp(self) -> None: + """Set up the test instance.""" + self.model = ERModel( + triples_factory=MagicMock(), + interaction=MagicMock(), + ) + + def test_add_weight_regularizer_non_existing(self): + """Test that an assertion is raised for add_weight_regularizer for a non-existing weight.""" + with self.assertRaises(KeyError): + self.model.append_weight_regularizer( + parameter="this.weight.does.not.exist", + regularizer=..., + ) - def forward(self, indices: Optional[torch.LongTensor] = None) -> torch.FloatTensor: - n = self.num_embeddings if indices is None else indices.shape[0] - return self.x.unsqueeze(dim=0).repeat(n, 1) + def test_add_weight_regularizer(self): + """Test add_weight_regularizer.""" + # add weighted submodules + self.model.linear = nn.Linear(3, 2) + self.model.sub_model = nn.Sequential( + nn.Linear(2, 3), + nn.LeakyReLU(), + self.model.linear, + ) + + # try to add regularizer to existing weight + self.model.append_weight_regularizer( + parameter="linear.weight", + regularizer=LpRegularizer(), + ) - def get_in_canonical_shape( - self, - indices: Optional[torch.LongTensor] = None, - ) -> torch.FloatTensor: - x = self(indices=indices) - if indices is None: - return x.unsqueeze(dim=0) - return x.unsqueeze(dim=1) + # check it gets found by collect + term = collect_regularization_terms(self.model) + assert torch.is_tensor(term) class _ModelTestCase: @@ -179,6 +193,7 @@ def test_reset_parameters_(self): def _check_scores(self, batch, scores) -> None: """Check the scores produced by a forward function.""" + # TODO: Move score checks to Interaction tests? 
# check for finite values by default self.assertTrue(torch.all(torch.isfinite(scores)).item(), f'Some scores were not finite:\n{scores}') @@ -300,23 +315,25 @@ def test_save_load_model_state(self): **(self.model_kwargs or {}), ).to_device_() - def _equal_embeddings(a: RepresentationModule, b: RepresentationModule) -> bool: - """Test whether two embeddings are equal.""" - return (a(indices=None) == b(indices=None)).all() + def _equal_weights(a: nn.Module, b: nn.Module) -> bool: + """Test whether two modules are equal.""" + a_state = a.state_dict() + b_state = b.state_dict() + if a_state.keys() != b_state.keys(): + return False + for key, original_value in a_state.items(): + if not torch.allclose(original_value, b_state[key]): + return False + return True - if isinstance(original_model, EntityEmbeddingModel): - assert not _equal_embeddings(original_model.entity_embeddings, loaded_model.entity_embeddings) - if isinstance(original_model, EntityRelationEmbeddingModel): - assert not _equal_embeddings(original_model.relation_embeddings, loaded_model.relation_embeddings) + assert not _equal_weights(original_model, loaded_model) with tempfile.TemporaryDirectory() as tmpdirname: file_path = os.path.join(tmpdirname, 'test.pt') original_model.save_state(path=file_path) loaded_model.load_state(path=file_path) - if isinstance(original_model, EntityEmbeddingModel): - assert _equal_embeddings(original_model.entity_embeddings, loaded_model.entity_embeddings) - if isinstance(original_model, EntityRelationEmbeddingModel): - assert _equal_embeddings(original_model.relation_embeddings, loaded_model.relation_embeddings) + + assert _equal_weights(original_model, loaded_model) @property def cli_extras(self): @@ -400,17 +417,6 @@ def test_has_hpo_defaults(self): else: self.assertIsInstance(d, dict) - def test_post_parameter_update_regularizer(self): - """Test whether post_parameter_update resets the regularization term.""" - # set regularizer term - self.model.regularizer.regularization_term = None - - # call post_parameter_update - self.model.post_parameter_update() - - # assert that the regularization term has been reset - assert self.model.regularizer.regularization_term == torch.zeros(1, dtype=torch.float, device=self.model.device) - def test_post_parameter_update(self): """Test whether post_parameter_update correctly enforces model constraints.""" # do one optimization step @@ -442,7 +448,7 @@ def test_score_h_with_score_hrt_equality(self) -> None: scores_h = self.model.score_h(batch) scores_hrt = super(self.model.__class__, self.model).score_h(batch) except NotImplementedError: - self.fail(msg='Score_h not yet implemented') + self.fail(msg=f'{self.model.__class__.__name__}.score_h() has not yet been implemented') except RuntimeError as e: if str(e) == 'fft: ATen not compiled with MKL support': self.skipTest(str(e)) @@ -463,7 +469,7 @@ def test_score_r_with_score_hrt_equality(self) -> None: scores_r = self.model.score_r(batch) scores_hrt = super(self.model.__class__, self.model).score_r(batch) except NotImplementedError: - self.fail(msg='Score_h not yet implemented') + self.fail(msg=f'{self.model.__class__.__name__}.score_r() has not yet been implemented') except RuntimeError as e: if str(e) == 'fft: ATen not compiled with MKL support': self.skipTest(str(e)) @@ -484,7 +490,7 @@ def test_score_t_with_score_hrt_equality(self) -> None: scores_t = self.model.score_t(batch) scores_hrt = super(self.model.__class__, self.model).score_t(batch) except NotImplementedError: - self.fail(msg='Score_h not yet implemented') 
+ self.fail(msg=f'{self.model.__class__.__name__}.score_t() has not yet been implemented') except RuntimeError as e: if str(e) == 'fft: ATen not compiled with MKL support': self.skipTest(str(e)) @@ -496,44 +502,42 @@ def test_score_t_with_score_hrt_equality(self) -> None: def test_reset_parameters_constructor_call(self): """Tests whether reset_parameters is called in the constructor.""" with patch.object(self.model_cls, 'reset_parameters_', return_value=None) as mock_method: - try: - self.model_cls( - triples_factory=self.factory, - embedding_dim=self.embedding_dim, - **(self.model_kwargs or {}), - ) - except TypeError as error: - assert error.args == ("'NoneType' object is not callable",) + self.model_cls( + triples_factory=self.factory, + embedding_dim=self.embedding_dim, + **(self.model_kwargs or {}), + ) mock_method.assert_called_once() def test_custom_representations(self): """Tests whether we can provide custom representations.""" - if isinstance(self.model, EntityEmbeddingModel): - old_embeddings = self.model.entity_embeddings - self.model.entity_embeddings = _CustomRepresentations( + if not isinstance(self.model, ERModel): + self.skipTest(f'Not testing custom representations for model: {self.model.__class__.__name__}') + + old_entity_reps = self.model.entity_representations + self.model.entity_representations = nn.ModuleList([ + MockRepresentations( num_entities=self.factory.num_entities, - embedding_dim=old_embeddings.embedding_dim, + shape=er.base_embeddings.shape if isinstance(er, RGCNRepresentations) else er.shape, ) - # call some functions - self.model.reset_parameters_() - self.test_score_hrt() - self.test_score_t() - # reset to old state - self.model.entity_embeddings = old_embeddings - elif isinstance(self.model, EntityRelationEmbeddingModel): - old_embeddings = self.model.relation_embeddings - self.model.relation_embeddings = _CustomRepresentations( - num_entities=self.factory.num_relations, - embedding_dim=old_embeddings.embedding_dim, + for er in old_entity_reps + ]) + old_relation_reps = self.model.relation_representations + self.model.relation_representations = nn.ModuleList([ + MockRepresentations( + num_entities=self.factory.num_entities, + shape=er.shape, ) - # call some functions - self.model.reset_parameters_() - self.test_score_hrt() - self.test_score_t() - # reset to old state - self.model.relation_embeddings = old_embeddings - else: - self.skipTest(f'Not testing custom representations for model: {self.model.__class__.__name__}') + for er in old_relation_reps + ]) + + # call some functions + self.model.reset_parameters_() + self.test_score_hrt() + self.test_score_t() + # reset to old state + self.model.entity_representations = old_entity_reps + self.model.relation_representations = old_relation_reps class _DistanceModelTestCase(_ModelTestCase): @@ -590,7 +594,7 @@ def _check_constraints(self): Entity embeddings have to have unit L2 norm. """ - entity_norms = self.model.entity_embeddings(indices=None).norm(p=2, dim=-1) + entity_norms = self.model.entity_representations[0](indices=None).norm(p=2, dim=-1) assert torch.allclose(entity_norms, torch.ones_like(entity_norms)) def _test_score_all_triples(self, k: Optional[int], batch_size: int = 16): @@ -672,7 +676,11 @@ def _check_constraints(self): Entity embeddings have to have at most unit L2 norm. 
""" - assert all_in_bounds(self.model.entity_embeddings(indices=None).norm(p=2, dim=-1), high=1., a_tol=_EPSILON) + assert all_in_bounds( + self.model.entity_representations[0](indices=None).norm(p=2, dim=-1), + high=1., + a_tol=_EPSILON, + ) class _TestKG2E(_ModelTestCase): @@ -686,10 +694,13 @@ def _check_constraints(self): * Entity and relation embeddings have to have at most unit L2 norm. * Covariances have to have values between c_min and c_max """ - for embedding in (self.model.entity_embeddings, self.model.relation_embeddings): + low = self.model.entity_representations[1].constrainer.keywords['min'] + high = self.model.entity_representations[1].constrainer.keywords['max'] + + for embedding in (self.model.entity_representations[0], self.model.relation_representations[0]): assert all_in_bounds(embedding(indices=None).norm(p=2, dim=-1), high=1., a_tol=_EPSILON) - for cov in (self.model.entity_covariances, self.model.relation_covariances): - assert all_in_bounds(cov(indices=None), low=self.model.c_min, high=self.model.c_max) + for cov in (self.model.entity_representations[1], self.model.relation_representations[1]): + assert all_in_bounds(cov(indices=None), low=low, high=high) class TestKG2EWithKL(_TestKG2E, unittest.TestCase): @@ -713,12 +724,6 @@ class _BaseNTNTest(_ModelTestCase, unittest.TestCase): model_cls = pykeen.models.NTN - def test_can_slice(self): - """Test that the slicing properties are calculated correctly.""" - self.assertTrue(self.model.can_slice_h) - self.assertFalse(self.model.can_slice_r) - self.assertTrue(self.model.can_slice_t) - class TestNTNLowMemory(_BaseNTNTest): """Test the NTN model with automatic memory optimization.""" @@ -767,7 +772,7 @@ def _check_constraints(self): Enriched embeddings have to be reset. """ - assert self.model.entity_representations.enriched_embeddings is None + assert self.model.entity_representations[0].enriched_embeddings is None class TestRGCNBasis(_TestRGCN, unittest.TestCase): @@ -806,7 +811,7 @@ def _check_constraints(self): """ relation_abs = ( self.model - .relation_embeddings(indices=None) + .relation_representations[0](indices=None) .view(self.factory.num_relations, -1, 2) .norm(p=2, dim=-1) ) @@ -829,7 +834,7 @@ def _check_constraints(self): Entity embeddings have to have unit L2 norm. """ - norms = self.model.entity_embeddings(indices=None).norm(p=2, dim=-1) + norms = self.model.entity_representations[0](indices=None).norm(p=2, dim=-1) assert torch.allclose(norms, torch.ones_like(norms)) @@ -862,151 +867,9 @@ def _check_constraints(self): Entity and relation embeddings have to have at most unit L2 norm. 
""" - for emb in (self.model.entity_embeddings, self.model.relation_embeddings): + for emb in (self.model.entity_representations[0], self.model.relation_representations[0]): assert all_in_bounds(emb(indices=None).norm(p=2, dim=-1), high=1., a_tol=_EPSILON) - def test_score_hrt_manual(self): - """Manually test interaction function of TransD.""" - # entity embeddings - weights = torch.as_tensor(data=[[2., 2.], [4., 4.]], dtype=torch.float) - entity_embeddings = Embedding( - num_embeddings=2, - embedding_dim=2, - ) - entity_embeddings._embeddings.weight.data.copy_(weights) - self.model.entity_embeddings = entity_embeddings - - projection_weights = torch.as_tensor(data=[[3., 3.], [2., 2.]], dtype=torch.float) - entity_projection_embeddings = Embedding( - num_embeddings=2, - embedding_dim=2, - ) - entity_projection_embeddings._embeddings.weight.data.copy_(projection_weights) - self.model.entity_projections = entity_projection_embeddings - - # relation embeddings - relation_weights = torch.as_tensor(data=[[4.], [4.]], dtype=torch.float) - relation_embeddings = Embedding( - num_embeddings=2, - embedding_dim=1, - ) - relation_embeddings._embeddings.weight.data.copy_(relation_weights) - self.model.relation_embeddings = relation_embeddings - - relation_projection_weights = torch.as_tensor(data=[[5.], [3.]], dtype=torch.float) - relation_projection_embeddings = Embedding( - num_embeddings=2, - embedding_dim=1, - ) - relation_projection_embeddings._embeddings.weight.data.copy_(relation_projection_weights) - self.model.relation_projections = relation_projection_embeddings - - # Compute Scores - batch = torch.as_tensor(data=[[0, 0, 0], [0, 0, 1]], dtype=torch.long) - scores = self.model.score_hrt(hrt_batch=batch) - self.assertEqual(scores.shape[0], 2) - self.assertEqual(scores.shape[1], 1) - first_score = scores[0].item() - self.assertAlmostEqual(first_score, -16, delta=0.01) - - # Use different dimension for relation embedding: relation_dim > entity_dim - # relation embeddings - relation_weights = torch.as_tensor(data=[[3., 3., 3.], [3., 3., 3.]], dtype=torch.float) - relation_embeddings = Embedding( - num_embeddings=2, - embedding_dim=3, - ) - relation_embeddings._embeddings.weight.data.copy_(relation_weights) - self.model.relation_embeddings = relation_embeddings - - relation_projection_weights = torch.as_tensor(data=[[4., 4., 4.], [4., 4., 4.]], dtype=torch.float) - relation_projection_embeddings = Embedding( - num_embeddings=2, - embedding_dim=3, - ) - relation_projection_embeddings._embeddings.weight.data.copy_(relation_projection_weights) - self.model.relation_projections = relation_projection_embeddings - - # Compute Scores - batch = torch.as_tensor(data=[[0, 0, 0]], dtype=torch.long) - scores = self.model.score_hrt(hrt_batch=batch) - self.assertAlmostEqual(scores.item(), -27, delta=0.01) - - batch = torch.as_tensor(data=[[0, 0, 0], [0, 0, 0]], dtype=torch.long) - scores = self.model.score_hrt(hrt_batch=batch) - self.assertEqual(scores.shape[0], 2) - self.assertEqual(scores.shape[1], 1) - first_score = scores[0].item() - second_score = scores[1].item() - self.assertAlmostEqual(first_score, -27, delta=0.01) - self.assertAlmostEqual(second_score, -27, delta=0.01) - - # Use different dimension for relation embedding: relation_dim < entity_dim - # entity embeddings - weights = torch.as_tensor(data=[[1., 1., 1.], [1., 1., 1.]], dtype=torch.float) - entity_embeddings = Embedding( - num_embeddings=2, - embedding_dim=3, - ) - entity_embeddings._embeddings.weight.data.copy_(weights) - 
self.model.entity_embeddings = entity_embeddings - - projection_weights = torch.as_tensor(data=[[2., 2., 2.], [2., 2., 2.]], dtype=torch.float) - entity_projection_embeddings = Embedding( - num_embeddings=2, - embedding_dim=3, - ) - entity_projection_embeddings._embeddings.weight.data.copy_(projection_weights) - self.model.entity_projections = entity_projection_embeddings - - # relation embeddings - relation_weights = torch.as_tensor(data=[[3., 3.], [3., 3.]], dtype=torch.float) - relation_embeddings = Embedding( - num_embeddings=2, - embedding_dim=2, - ) - relation_embeddings._embeddings.weight.data.copy_(relation_weights) - self.model.relation_embeddings = relation_embeddings - - relation_projection_weights = torch.as_tensor(data=[[4., 4.], [4., 4.]], dtype=torch.float) - relation_projection_embeddings = Embedding( - num_embeddings=2, - embedding_dim=2, - ) - relation_projection_embeddings._embeddings.weight.data.copy_(relation_projection_weights) - self.model.relation_projections = relation_projection_embeddings - - # Compute Scores - batch = torch.as_tensor(data=[[0, 0, 0], [0, 0, 0]], dtype=torch.long) - scores = self.model.score_hrt(hrt_batch=batch) - self.assertEqual(scores.shape[0], 2) - self.assertEqual(scores.shape[1], 1) - first_score = scores[0].item() - second_score = scores[1].item() - self.assertAlmostEqual(first_score, -18, delta=0.01) - self.assertAlmostEqual(second_score, -18, delta=0.01) - - def test_project_entity(self): - """Test _project_entity.""" - # random entity embeddings & projections - e = torch.rand(1, self.model.num_entities, self.embedding_dim, generator=self.generator) - e = clamp_norm(e, maxnorm=1, p=2, dim=-1) - e_p = torch.rand(1, self.model.num_entities, self.embedding_dim, generator=self.generator) - - # random relation embeddings & projections - r = torch.rand(self.batch_size, 1, self.model.relation_dim, generator=self.generator) - r = clamp_norm(r, maxnorm=1, p=2, dim=-1) - r_p = torch.rand(self.batch_size, 1, self.model.relation_dim, generator=self.generator) - - # project - e_bot = _project_entity(e=e, e_p=e_p, r=r, r_p=r_p) - - # check shape: - assert e_bot.shape == (self.batch_size, self.model.num_entities, self.model.relation_dim) - - # check normalization - assert (torch.norm(e_bot, dim=-1, p=2) <= 1.0 + 1.0e-06).all() - class TestTransE(_DistanceModelTestCase, unittest.TestCase): """Test the TransE model.""" @@ -1018,7 +881,7 @@ def _check_constraints(self): Entity embeddings have to have unit L2 norm. """ - entity_norms = self.model.entity_embeddings(indices=None).norm(p=2, dim=-1) + entity_norms = self.model.entity_representations[0](indices=None).norm(p=2, dim=-1) assert torch.allclose(entity_norms, torch.ones_like(entity_norms)) @@ -1032,7 +895,7 @@ def _check_constraints(self): Entity embeddings have to have unit L2 norm. 
""" - entity_norms = self.model.normal_vector_embeddings(indices=None).norm(p=2, dim=-1) + entity_norms = self.model.relation_representations[1](indices=None).norm(p=2, dim=-1) assert torch.allclose(entity_norms, torch.ones_like(entity_norms)) @@ -1044,49 +907,12 @@ class TestTransR(_DistanceModelTestCase, unittest.TestCase): 'relation_dim': 4, } - def test_score_hrt_manual(self): - """Manually test interaction function of TransR.""" - # entity embeddings - weights = torch.as_tensor(data=[[2., 2.], [3., 3.]], dtype=torch.float) - entity_embeddings = Embedding( - num_embeddings=2, - embedding_dim=2, - ) - entity_embeddings._embeddings.weight.data.copy_(weights) - self.model.entity_embeddings = entity_embeddings - - # relation embeddings - relation_weights = torch.as_tensor(data=[[4., 4], [5., 5.]], dtype=torch.float) - relation_embeddings = Embedding( - num_embeddings=2, - embedding_dim=2, - ) - relation_embeddings._embeddings.weight.data.copy_(relation_weights) - self.model.relation_embeddings = relation_embeddings - - relation_projection_weights = torch.as_tensor(data=[[5., 5., 6., 6.], [7., 7., 8., 8.]], dtype=torch.float) - relation_projection_embeddings = Embedding( - num_embeddings=2, - embedding_dim=4, - ) - relation_projection_embeddings._embeddings.weight.data.copy_(relation_projection_weights) - self.model.relation_projections = relation_projection_embeddings - - # Compute Scores - batch = torch.as_tensor(data=[[0, 0, 0], [0, 0, 1]], dtype=torch.long) - scores = self.model.score_hrt(hrt_batch=batch) - self.assertEqual(scores.shape[0], 2) - self.assertEqual(scores.shape[1], 1) - first_score = scores[0].item() - # second_score = scores[1].item() - self.assertAlmostEqual(first_score, -32, delta=0.01) - def _check_constraints(self): """Check model constraints. Entity and relation embeddings have to have at most unit L2 norm. """ - for emb in (self.model.entity_embeddings, self.model.relation_embeddings): + for emb in (self.model.entity_representations[0], self.model.relation_representations[0]): assert all_in_bounds(emb(indices=None).norm(p=2, dim=-1), high=1., a_tol=1.0e-06) @@ -1113,7 +939,7 @@ class TestTesting(unittest.TestCase): def test_testing(self): """Check that there's a test for all models. - For now, this is excluding multimodel models. Not sure how to test those yet. + For now, this is excluding multimodal models. Not sure how to test those yet. 
""" model_names = { cls.__name__ @@ -1128,7 +954,7 @@ def test_testing(self): isinstance(value, type) and issubclass(value, _ModelTestCase) and not name.startswith('_') - and not issubclass(value.model_cls, MultimodalModel) + and not issubclass(value.model_cls, LiteralModel) ) } tested_model_names -= SKIP_MODULES @@ -1233,14 +1059,13 @@ def test_symmetric_edge_weights(self): self._test_message_weighting(weight_func=symmetric_edge_weights) -class TestRandom(unittest.TestCase): - """Extra tests.""" +class TestModelUtilities(unittest.TestCase): + """Extra tests for utility functions.""" def test_abstract(self): """Test that classes are checked as abstract properly.""" - self.assertTrue(EntityEmbeddingModel._is_base_model) - self.assertTrue(EntityRelationEmbeddingModel._is_base_model) - self.assertTrue(MultimodalModel._is_base_model) + self.assertTrue(ERModel._is_base_model) + self.assertTrue(LiteralModel._is_base_model) for model_cls in _MODELS: self.assertFalse( model_cls._is_base_model, diff --git a/tests/test_nn.py b/tests/test_nn.py new file mode 100644 index 0000000000..0ed640ff96 --- /dev/null +++ b/tests/test_nn.py @@ -0,0 +1,512 @@ +# -*- coding: utf-8 -*- + +"""Unittest for the :mod:`pykeen.nn` module.""" + +import itertools +import unittest +from typing import Any, Iterable, Mapping, MutableMapping, Optional, Sequence +from unittest.mock import MagicMock, Mock + +import numpy +import pytest +import torch +from torch.nn import functional + +from pykeen.nn import Embedding, EmbeddingSpecification, LiteralRepresentations, RepresentationModule +from pykeen.nn.representation import ( + CANONICAL_DIMENSIONS, RGCNRepresentations, convert_to_canonical_shape, get_expected_canonical_shape, +) +from pykeen.nn.sim import _torch_kl_similarity, kullback_leibler_similarity +from pykeen.testing import base as ptb +from pykeen.testing.mocks import MockRepresentations +from pykeen.triples import TriplesFactory +from pykeen.typing import GaussianDistribution + + +class RepresentationModuleTestCase(ptb.GenericTestCase[RepresentationModule]): + """Tests for RepresentationModule.""" + + #: The batch size + batch_size: int = 3 + + #: The number of representations + num: int = 5 + + #: The expected shape of an individual representation + exp_shape: Sequence[int] = (5,) + + def post_instantiation_hook(self) -> None: # noqa: D102 + self.instance.reset_parameters() + + def test_max_id(self): + """Test the maximum ID.""" + assert self.instance.max_id == self.num + + def test_shape(self): + """Test the shape.""" + assert self.instance.shape == self.exp_shape + + def _test_forward(self, indices: Optional[torch.LongTensor]): + """Test the forward method.""" + x = self.instance(indices=indices) + assert torch.is_tensor(x) + assert x.dtype == torch.float32 + n = self.num if indices is None else indices.shape[0] + assert x.shape == tuple([n, *self.instance.shape]) + self._verify_content(x=x, indices=indices) + + def _verify_content(self, x, indices): + """Additional verification.""" + assert x.requires_grad + + def _valid_indices(self) -> Iterable[torch.LongTensor]: + return [ + torch.randint(self.num, size=(self.batch_size,)), + torch.randperm(self.num), + torch.randperm(self.num).repeat(2), + ] + + def _invalid_indices(self) -> Iterable[torch.LongTensor]: + return [ + torch.as_tensor([self.num], dtype=torch.long), # too high index + torch.randint(self.num, size=(2, 3)), # too many indices + ] + + def test_forward_without_indices(self): + """Test forward without providing indices.""" + 
self._test_forward(indices=None) + + def test_forward_with_indices(self): + """Test forward with providing indices.""" + for indices in self._valid_indices(): + self._test_forward(indices=indices) + + def test_forward_with_invalid_indices(self): + """Test whether passing invalid indices crashes.""" + for indices in self._invalid_indices(): + with pytest.raises((IndexError, RuntimeError)): + self._test_forward(indices=indices) + + def _test_in_canonical_shape(self, indices: Optional[torch.LongTensor]): + """Test get_in_canonical_shape with the given indices.""" + # test both, using the actual dimension, and its name + for dim in itertools.chain(CANONICAL_DIMENSIONS.keys(), CANONICAL_DIMENSIONS.values()): + # batch_size, d1, d2, d3, * + x = self.instance.get_in_canonical_shape(dim=dim, indices=indices) + + # data type + assert torch.is_tensor(x) + assert x.dtype == torch.float32 # todo: adjust? + assert x.ndimension() == 4 + len(self.exp_shape) + + # get expected shape + exp_shape = get_expected_canonical_shape( + indices=indices, + dim=dim, + suffix_shape=self.exp_shape, + num=self.num, + ) + assert x.shape == exp_shape + + def test_get_in_canonical_shape_without_indices(self): + """Test get_in_canonical_shape without indices, i.e. with 1-n scoring.""" + self._test_in_canonical_shape(indices=None) + + def test_get_in_canonical_shape_with_indices(self): + """Test get_in_canonical_shape with 1-dimensional indices.""" + for indices in self._valid_indices(): + self._test_in_canonical_shape(indices=indices) + + def test_get_in_canonical_shape_with_2d_indices(self): + """Test get_in_canonical_shape with 2-dimensional indices.""" + indices = torch.randint(self.num, size=(self.batch_size, 2)) + self._test_in_canonical_shape(indices=indices) + + +def _check_call( + self: unittest.TestCase, + call_count: int, + should_be_called: bool, + wrapped: MagicMock, + kwargs: Optional[Mapping[str, Any]], +) -> int: + """ + Check whether a wrapped method is called. + + :param self: + The test cas calling the check + :param call_count: + The previous call count. + :param should_be_called: + Whether it should be called. + :param wrapped: + The wrapped method. + :param kwargs: + The expected kwargs when called. + + :return: + The updated counter. + """ + if should_be_called: + call_count += 1 + + self.assertEqual(call_count, wrapped.call_count) + + # Lets check the tuple + self.assertIsInstance(wrapped.call_args, tuple) + + call_size = len(wrapped.call_args) + # Make sure tuple at least has positional arguments, could be 3 if kwargs available + self.assertLessEqual(2, call_size) + + if call_size == 2: + args_idx, kwargs_idx = 0, 1 + else: # call_size == 3: + args_idx, kwargs_idx = 1, 2 + + # called with one positional argument ... + self.assertEqual(1, len(wrapped.call_args[args_idx]), + msg=f'Args: {wrapped.call_args[args_idx]} Kwargs: {wrapped.call_args[kwargs_idx]}') + # .. and additional key-word based arguments. 
+ self.assertEqual(len(kwargs or {}), len(wrapped.call_args[kwargs_idx])) + else: + self.assertEqual(call_count, wrapped.call_count) + return call_count + + +class EmbeddingTests(RepresentationModuleTestCase, unittest.TestCase): + """Tests for Embedding.""" + + cls = Embedding + kwargs = dict( + num_embeddings=RepresentationModuleTestCase.num, + shape=RepresentationModuleTestCase.exp_shape, + ) + + def test_constructor_errors(self): + """Test error cases for constructor call.""" + for embedding_dim, shape in ( + (None, None), # neither + (3, (5, 3)), # both + ): + with pytest.raises(ValueError): + Embedding( + num_embeddings=self.num, + embedding_dim=embedding_dim, + shape=shape, + ) + + def _test_func_with_kwargs( + self, + name: str, + func, + kwargs: Optional[Mapping[str, Any]] = None, + reset_parameters_call: bool = False, + forward_call: bool = False, + post_parameter_update_call: bool = False, + ): + """Test initializer usage.""" + # wrap to check calls + wrapped = MagicMock(side_effect=func) + + # instantiate embedding + embedding_kwargs = {name: wrapped} + if kwargs is not None: + embedding_kwargs[f"{name}_kwargs"] = kwargs + embedding = Embedding( + num_embeddings=self.num, + shape=self.exp_shape, + **embedding_kwargs, + ) + + # check that nothing gets called in constructor + wrapped.assert_not_called() + call_count = 0 + + # check call in reset_parameters + embedding.reset_parameters() + call_count = _check_call( + self, + call_count=call_count, + should_be_called=reset_parameters_call, + wrapped=wrapped, + kwargs=kwargs, + ) + + # check call in forward + embedding.forward(indices=None) + call_count = _check_call( + self, + call_count=call_count, + should_be_called=forward_call, + wrapped=wrapped, + kwargs=kwargs, + ) + + # check call in post_parameter_update + embedding.post_parameter_update() + _check_call( + self, + call_count=call_count, + should_be_called=post_parameter_update_call, + wrapped=wrapped, + kwargs=kwargs, + ) + + def test_initializer(self): + """Test initializer.""" + self._test_func_with_kwargs( + name="initializer", + func=torch.nn.init.normal_, + reset_parameters_call=True, + ) + + def test_initializer_with_kwargs(self): + """Test initializer with kwargs.""" + self._test_func_with_kwargs( + name="initializer", + func=torch.nn.init.normal_, + kwargs=dict(mean=3), + reset_parameters_call=True, + ) + + def test_normalizer(self): + """Test normalizer.""" + self._test_func_with_kwargs( + name="normalizer", + func=functional.normalize, + forward_call=True, + ) + + def test_normalizer_kwargs(self): + """Test normalizer with kwargs.""" + self._test_func_with_kwargs( + name="normalizer", + func=functional.normalize, + kwargs=dict(p=1), + forward_call=True, + ) + + def test_constrainer(self): + """Test constrainer.""" + self._test_func_with_kwargs( + name="constrainer", + func=functional.normalize, + post_parameter_update_call=True, + ) + + def test_constrainer_kwargs(self): + """Test constrainer with kwargs.""" + self._test_func_with_kwargs( + name="constrainer", + func=functional.normalize, + kwargs=dict(p=1), + post_parameter_update_call=True, + ) + + +class TensorEmbeddingTests(RepresentationModuleTestCase, unittest.TestCase): + """Tests for Embedding with 2-dimensional shape.""" + + cls = Embedding + exp_shape = (3, 7) + kwargs = dict( + num_embeddings=RepresentationModuleTestCase.num, + shape=(3, 7), + ) + + +class LiteralRepresentationsTests(EmbeddingTests, unittest.TestCase): + """Tests for literal embeddings.""" + + cls = LiteralRepresentations + + def 
_pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]: # noqa: D102 + # requires own kwargs + kwargs.clear() + self.numeric_literals = torch.rand(self.num, *self.exp_shape) + kwargs["numeric_literals"] = self.numeric_literals + return kwargs + + def _verify_content(self, x, indices): # noqa: D102 + exp_x = self.numeric_literals + if indices is not None: + exp_x = exp_x[indices] + self.assertTrue(torch.allclose(x, exp_x)) + + +class RGCNRepresentationTests(RepresentationModuleTestCase, unittest.TestCase): + """Test RGCN representations.""" + + cls = RGCNRepresentations + kwargs = dict( + num_bases_or_blocks=2, + embedding_dim=RepresentationModuleTestCase.exp_shape[0], + ) + num_relations: int = 7 + num_triples: int = 31 + num_bases: int = 2 + + def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]: # noqa: D102 + kwargs = super()._pre_instantiation_hook(kwargs=kwargs) + # TODO: use triple generation + # generate random triples + mapped_triples = numpy.stack([ + numpy.random.randint(max_id, size=(self.num_triples,)) + for max_id in (self.num, self.num_relations, self.num) + ], axis=-1) + entity_names = [f"e_{i}" for i in range(self.num)] + relation_names = [f"r_{i}" for i in range(self.num_relations)] + triples = numpy.stack([ + [names[i] for i in col.tolist()] + for col, names in zip( + mapped_triples.T, + (entity_names, relation_names, entity_names), + ) + ]) + kwargs["triples_factory"] = TriplesFactory.from_labeled_triples(triples=triples) + return kwargs + + +class RepresentationModuleTestsTestCase(ptb.TestsTestCase[RepresentationModule]): + """Test that there are tests for all representation modules.""" + + base_cls = RepresentationModule + base_test = RepresentationModuleTestCase + skip_cls = {MockRepresentations} + + +class EmbeddingSpecificationTests(unittest.TestCase): + """Tests for EmbeddingSpecification.""" + + #: The number of embeddings + num: int = 3 + + def test_make(self): + """Test make.""" + initializer = Mock() + normalizer = Mock() + constrainer = Mock() + regularizer = Mock() + for embedding_dim, shape in [ + (None, (3,)), + (None, (3, 5)), + (3, None), + ]: + spec = EmbeddingSpecification( + embedding_dim=embedding_dim, + shape=shape, + initializer=initializer, + normalizer=normalizer, + constrainer=constrainer, + regularizer=regularizer, + ) + emb = spec.make(num_embeddings=self.num) + + # check shape + self.assertEqual(emb.embedding_dim, (embedding_dim or int(numpy.prod(shape)))) + self.assertEqual(emb.shape, (shape or (embedding_dim,))) + self.assertEqual(emb.num_embeddings, self.num) + + # check attributes + self.assertIs(emb.initializer, initializer) + self.assertIs(emb.normalizer, normalizer) + self.assertIs(emb.constrainer, constrainer) + self.assertIs(emb.regularizer, regularizer) + + +class KullbackLeiblerTests(unittest.TestCase): + """Tests for the vectorized computation of KL divergences.""" + + batch_size: int = 2 + num_heads: int = 3 + num_relations: int = 5 + num_tails: int = 7 + d: int = 11 + + def setUp(self) -> None: # noqa: D102 + dims = dict(h=self.num_heads, r=self.num_relations, t=self.num_tails) + (self.h_mean, self.r_mean, self.t_mean), (self.h_var, self.r_var, self.t_var) = [ + [ + convert_to_canonical_shape( + x=torch.rand(self.batch_size, num, self.d), + dim=dim, + num=num, + batch_size=self.batch_size, + ) + for dim, num in dims.items() + ] + for _ in ("mean", "diagonal_covariance") + ] + # ensure positivity + self.h_var, self.r_var, self.t_var = [x.exp() for x in 
(self.h_var, self.r_var, self.t_var)] + + def _get(self, name: str): + if name == "h": + mean, var = self.h_mean, self.h_var + elif name == "r": + mean, var = self.r_mean, self.r_var + elif name == "t": + mean, var = self.t_mean, self.t_var + elif name == "e": + mean, var = self.h_mean - self.t_mean, self.h_var + self.t_var + else: + raise ValueError + return GaussianDistribution(mean=mean, diagonal_covariance=var) + + def _get_kl_similarity_torch(self): + # compute using pytorch + e_mean = self.h_mean - self.t_mean + e_var = self.h_var + self.t_var + r_mean, r_var = self.r_var, self.r_mean + self.assertTrue((e_var > 0).all()) + sim2 = torch.empty(self.batch_size, self.num_heads, self.num_relations, self.num_tails) + for bi, hi, ri, ti in itertools.product( + range(self.batch_size), + range(self.num_heads), + range(self.num_relations), + range(self.num_tails), + ): + # prepare distributions + e_loc = e_mean[bi, hi, 0, ti, :] + r_loc = r_mean[bi, 0, ri, 0, :] + e_cov = torch.diag(e_var[bi, hi, 0, ti, :]) + r_cov = torch.diag(r_var[bi, 0, ri, 0, :]) + p = torch.distributions.MultivariateNormal( + loc=e_loc, + covariance_matrix=e_cov, + ) + q = torch.distributions.MultivariateNormal( + loc=r_loc, + covariance_matrix=r_cov, + ) + sim2[bi, hi, ri, ti] = -torch.distributions.kl_divergence(p=p, q=q).view(-1) + return sim2 + + def test_against_torch_builtin(self): + """Compare value against torch.distributions.""" + # compute using pykeen + h, r, t = [self._get(name=name) for name in "hrt"] + sim = kullback_leibler_similarity(h=h, r=r, t=t, exact=True) + sim2 = _torch_kl_similarity(h=h, r=r, t=t) + self.assertTrue(torch.allclose(sim, sim2), msg=f'Difference: {(sim - sim2).abs()}') + + def test_self_similarity(self): + """Check value of similarity to self.""" + # e: (batch_size, num_heads, num_tails, d) + # https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Properties + # divergence = 0 => similarity = -divergence = 0 + # (h - t), r + r = self._get(name="r") + h = GaussianDistribution(mean=2 * r.mean, diagonal_covariance=0.5 * r.diagonal_covariance) + t = GaussianDistribution(mean=r.mean, diagonal_covariance=0.5 * r.diagonal_covariance) + sim = kullback_leibler_similarity(h=h, r=r, t=t, exact=True) + self.assertTrue(torch.allclose(sim, torch.zeros_like(sim)), msg=f'Sim: {sim}') + + def test_value_range(self): + """Check the value range.""" + # https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence#Properties + # divergence >= 0 => similarity = -divergence <= 0 + h, r, t = [self._get(name=name) for name in "hrt"] + sim = kullback_leibler_similarity(h=h, r=r, t=t, exact=True) + self.assertTrue((sim <= 0).all()) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 48e0dc7fee..da05d878a3 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -7,9 +7,10 @@ import pandas as pd -import pykeen.regularizers +import pykeen from pykeen.datasets import Nations -from pykeen.models.base import Model +from pykeen.models import DistMult +from pykeen.nn import Embedding from pykeen.pipeline import PipelineResult, pipeline from pykeen.regularizers import NoRegularizer @@ -185,7 +186,7 @@ class TestAttributes(unittest.TestCase): def test_specify_regularizer(self): """Test a pipeline that uses a regularizer.""" for regularizer, cls in [ - (None, pykeen.regularizers.NoRegularizer), + (None, pykeen.regularizers.LpRegularizer), # if none, goes to default. 
('no', pykeen.regularizers.NoRegularizer), (NoRegularizer, pykeen.regularizers.NoRegularizer), ('powersum', pykeen.regularizers.PowerSumRegularizer), @@ -193,11 +194,16 @@ def test_specify_regularizer(self): ]: with self.subTest(regularizer=regularizer): pipeline_result = pipeline( - model='TransE', + model='DistMult', dataset='Nations', regularizer=regularizer, - training_kwargs=dict(num_epochs=1), + training_kwargs=dict(num_epochs=1, use_tqdm=False), ) self.assertIsInstance(pipeline_result, PipelineResult) - self.assertIsInstance(pipeline_result.model, Model) - self.assertIsInstance(pipeline_result.model.regularizer, cls) + self.assertIsInstance(pipeline_result.model, DistMult) + self.assertEqual(1, len(pipeline_result.model.entity_representations)) + self.assertIsInstance(pipeline_result.model.entity_representations[0], Embedding) + self.assertIsNone(pipeline_result.model.entity_representations[0].regularizer) + self.assertEqual(1, len(pipeline_result.model.relation_representations)) + self.assertIsInstance(pipeline_result.model.relation_representations[0], Embedding) + self.assertIsInstance(pipeline_result.model.relation_representations[0].regularizer, cls) diff --git a/tests/test_regularizers.py b/tests/test_regularizers.py index cd562fedf2..e252821bfd 100644 --- a/tests/test_regularizers.py +++ b/tests/test_regularizers.py @@ -5,15 +5,18 @@ import logging import unittest from typing import Any, ClassVar, Dict, Optional, Type +from unittest.mock import MagicMock import torch +from torch import nn from torch.nn import functional from pykeen.datasets import Nations -from pykeen.models import ConvKB, RESCAL, TransH +from pykeen.models import ConvKB, ERModel, RESCAL +from pykeen.nn import EmbeddingSpecification from pykeen.regularizers import ( CombinedRegularizer, LpRegularizer, NoRegularizer, PowerSumRegularizer, Regularizer, - TransHRegularizer, + TransHRegularizer, collect_regularization_terms, ) from pykeen.triples import TriplesFactory from pykeen.typing import MappedTriples @@ -45,9 +48,8 @@ def setUp(self) -> None: self.triples_factory = Nations().training self.device = resolve_device() self.regularizer = self.regularizer_cls( - device=self.device, **(self.regularizer_kwargs or {}), - ) + ).to(self.device) self.positive_batch = self.triples_factory.mapped_triples[:self.batch_size, :].to(device=self.device) def test_model(self) -> None: @@ -55,27 +57,22 @@ def test_model(self) -> None: # Use RESCAL as it regularizes multiple tensors of different shape. model = RESCAL( triples_factory=self.triples_factory, - regularizer=self.regularizer, ).to(self.device) - # Check if regularizer is stored correctly. - self.assertEqual(model.regularizer, self.regularizer) + # check for regularizer + assert sum(1 for m in model.modules() if isinstance(m, Regularizer)) > 0 # Forward pass (should update regularizer) model.score_hrt(hrt_batch=self.positive_batch) - # Call post_parameter_update (should reset regularizer) - model.post_parameter_update() - - # Check if regularization term is reset - self.assertEqual(0., model.regularizer.term) - - def test_reset(self) -> None: - """Test method `reset`.""" - # Call method - self.regularizer.reset() + # check that regularization term is accessible + term = collect_regularization_terms(model) + assert torch.is_tensor(term) + assert term.requires_grad - self.assertEqual(0., self.regularizer.regularization_term) + # second time should be 0. + term = collect_regularization_terms(model) + assert term == 0. 
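+        # the term was consumed by the first collection, so a second collection returns the neutral value 0.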
     def test_update(self) -> None:
         """Test method `update`."""
@@ -84,17 +81,18 @@ def test_update(self) -> None:
         b = torch.rand(self.batch_size, 20, device=self.device, generator=self.generator)

         # Call update
-        self.regularizer.update(a, b)
+        assert self.regularizer.update(a, b)

         # check shape
-        self.assertEqual((1,), self.regularizer.term.shape)
+        assert 1 == self.regularizer.regularization_term.numel()

         # compute expected term
         exp_penalties = torch.stack([self._expected_penalty(x) for x in (a, b)])
         expected_term = torch.sum(exp_penalties).view(1) * self.regularizer.weight
         assert expected_term.shape == (1,)
-        self.assertAlmostEqual(self.regularizer.term.item(), expected_term.item())
+        observed = self.regularizer.pop_regularization_term()
+        self.assertAlmostEqual(observed.item(), expected_term.item())

     def test_forward(self) -> None:
         """Test the regularizer's `forward` method."""
@@ -114,6 +112,22 @@ def test_forward(self) -> None:
         else:
             assert (expected_penalty == penalty).all()

+    def test_pop_regularization_term(self):
+        """Test pop_regularization_term."""
+        regularization_term = self.regularizer.pop_regularization_term()
+
+        # check type
+        assert isinstance(regularization_term, float) or torch.is_tensor(regularization_term)
+
+        # a float is returned only if there is no actual regularization term
+        if isinstance(regularization_term, float):
+            assert regularization_term == 0.0
+
+        # check that the regularizer has been cleared
+        assert isinstance(self.regularizer.regularization_term, float)
+        assert self.regularizer.regularization_term == 0.0
+        assert self.regularizer.updated is False
+
     def _expected_penalty(self, x: torch.FloatTensor) -> torch.FloatTensor:
         """Compute expected penalty for given tensor."""
         return None
@@ -168,8 +182,8 @@ class CombinedRegularizerTest(_RegularizerTestCase, unittest.TestCase):
     regularizer_cls = CombinedRegularizer
     regularizer_kwargs = {
         'regularizers': [
-            LpRegularizer(weight=0.1, p=1, device=resolve_device()),
-            LpRegularizer(weight=0.7, p=2, device=resolve_device()),
+            LpRegularizer(weight=0.1, p=1),
+            LpRegularizer(weight=0.7, p=2),
         ],
     }
@@ -202,58 +216,47 @@ class TransHRegularizerTest(unittest.TestCase):
     regularizer_kwargs: Dict
     num_entities: int
     num_relations: int
-    entities_weight: torch.Tensor
-    relations_weight: torch.Tensor
-    normal_vector_weight: torch.Tensor
+    entities_weight: nn.Parameter
+    relations_weight: nn.Parameter
+    normal_vector_weight: nn.Parameter

     def setUp(self) -> None:
         """Set up the test case."""
         self.generator = torch.random.manual_seed(seed=42)
         self.device = resolve_device()
-        self.regularizer_kwargs = {'weight': .5, 'epsilon': 1e-5}
-        self.regularizer = TransHRegularizer(
-            device=self.device,
-            **(self.regularizer_kwargs or {}),
-        )
         self.num_entities = 10
         self.num_relations = 5
-        self.entities_weight = torch.rand(self.num_entities, 10, device=self.device, generator=self.generator)
-        self.relations_weight = torch.rand(self.num_relations, 20, device=self.device, generator=self.generator)
-        self.normal_vector_weight = torch.rand(self.num_relations, 20, device=self.device, generator=self.generator)
+        self.entities_weight = self._rand_param(10)
+        self.relations_weight = self._rand_param(20)
+        self.normal_vector_weight = self._rand_param(20)
+        self.weight = .5
+        self.epsilon = 1e-5
+        self.regularizer_kwargs = dict()
+        self.regularizer = TransHRegularizer(
+            weight=self.weight, epsilon=self.epsilon,
+            entity_embeddings=self.entities_weight,
+            normal_vector_embeddings=self.normal_vector_weight,
+            relation_embeddings=self.relations_weight,
+        )
+
+    def 
_rand_param(self, n) -> nn.Parameter: + return nn.Parameter(torch.rand(self.num_entities, n, device=self.device, generator=self.generator)) def test_update(self): """Test update function of TransHRegularizer.""" - # Tests that exception will be thrown when more than or less than three tensors are passed - with self.assertRaises(KeyError) as context: - self.regularizer.update( - self.entities_weight, - self.normal_vector_weight, - self.relations_weight, - torch.rand(self.num_entities, 10, device=self.device, generator=self.generator), - ) - self.assertTrue('Expects exactly three tensors' in context.exception) - - self.regularizer.update( - self.entities_weight, - self.normal_vector_weight, - ) - self.assertTrue('Expects exactly three tensors' in context.exception) - # Test that regularization term is computed correctly - self.regularizer.update(self.entities_weight, self.normal_vector_weight, self.relations_weight) expected_term = self._expected_penalty() - weight = self.regularizer_kwargs.get('weight') - self.assertAlmostEqual(self.regularizer.term.item(), weight * expected_term.item()) + observed_term = self.regularizer.pop_regularization_term() + assert torch.allclose(observed_term, self.weight * expected_term) def _expected_penalty(self) -> torch.FloatTensor: # noqa: D102 # Entity soft constraint regularization_term = torch.sum(functional.relu(torch.norm(self.entities_weight, dim=-1) ** 2 - 1.0)) - epsilon = self.regularizer_kwargs.get('epsilon') # # Orthogonality soft constraint d_r_n = functional.normalize(self.relations_weight, dim=-1) regularization_term += torch.sum( - functional.relu(torch.sum((self.normal_vector_weight * d_r_n) ** 2, dim=-1) - epsilon), + functional.relu(torch.sum((self.normal_vector_weight * d_r_n) ** 2, dim=-1) - self.epsilon), ) return regularization_term @@ -275,23 +278,13 @@ def test_lp(self): self.assertIn('apply_only_once', ConvKB.regularizer_default_kwargs) self.assertTrue(ConvKB.regularizer_default_kwargs['apply_only_once']) regularizer = LpRegularizer( - device=self.device, **ConvKB.regularizer_default_kwargs, ) self._help_test_regularizer(regularizer) - def test_transh_regularizer(self): - """Test the TransH regularizer only updates once.""" - self.assertNotIn('apply_only_once', TransH.regularizer_default_kwargs) - regularizer = TransHRegularizer( - device=self.device, - **TransH.regularizer_default_kwargs, - ) - self._help_test_regularizer(regularizer) - def _help_test_regularizer(self, regularizer: Regularizer, n_tensors: int = 3): self.assertFalse(regularizer.updated) - self.assertEqual(0.0, regularizer.regularization_term.item()) + assert 0.0 == regularizer.regularization_term # After first update, should change the term first_tensors = [ @@ -312,6 +305,67 @@ def _help_test_regularizer(self, regularizer: Regularizer, n_tensors: int = 3): self.assertTrue(regularizer.updated) self.assertEqual(term, regularizer.regularization_term) - regularizer.reset() + regularizer.pop_regularization_term() self.assertFalse(regularizer.updated) - self.assertEqual(0.0, regularizer.regularization_term.item()) + assert 0.0 == regularizer.regularization_term + + +class TestRandom(unittest.TestCase): + """Test random regularization utilities.""" + + def test_collect_regularization_terms(self): + """Test whether collect_regularization_terms finds and resets all regularization terms.""" + regularizers = [ + LpRegularizer(), + PowerSumRegularizer(), + LpRegularizer(p=1, normalize=True, apply_only_once=True), + PowerSumRegularizer(normalize=True), + ] + model = ERModel( + 
triples_factory=MagicMock(num_entities=3, num_relations=2),
+            interaction=MagicMock(relation_shape=("d",), entity_shape=("d",)),
+            entity_representations=EmbeddingSpecification(
+                regularizer=regularizers[0],
+                embedding_dim=2,
+            ),
+            relation_representations=EmbeddingSpecification(
+                regularizer=regularizers[1],
+                embedding_dim=2,
+            ),
+        )
+
+        # add weighted modules
+        linear = nn.Linear(3, 2)
+        model.sub_module = nn.ModuleList([
+            nn.Sequential(
+                linear,
+                nn.Linear(2, 3),
+            ),
+            nn.BatchNorm1d(2),
+            linear,  # one module occurring twice
+        ])
+
+        # add weight regularizer
+        model.append_weight_regularizer(
+            parameter="sub_module.0.0.bias",
+            regularizer=regularizers[2],
+        )
+        model.append_weight_regularizer(
+            parameter="entity_representations.0._embeddings.weight",
+            regularizer=regularizers[3],
+        )
+
+        # retrieve all regularization terms
+        collect_regularization_terms(model)
+
+        # check that all terms are reset
+        found_regularizers = set()
+        for module in model.modules():
+            if isinstance(module, Regularizer):
+                term = module.regularization_term
+                assert isinstance(term, float)
+                assert term == 0.0
+                found_regularizers.add(id(module))
+
+        # check that all regularizers were found
+        self.assertEqual(found_regularizers, set(map(id, regularizers)))
diff --git a/tests/training/test_utils.py b/tests/test_training_utils.py
similarity index 100%
rename from tests/training/test_utils.py
rename to tests/test_training_utils.py
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 91222860c4..e27773d455 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,16 +1,25 @@
 # -*- coding: utf-8 -*-
-"""Unittest for for global utilities."""
+"""Tests for the :mod:`pykeen.utils` module."""
+import functools
 import itertools
+import operator
+import random
 import string
+import timeit
 import unittest
+from typing import Iterable, Tuple
+import numpy
+import pytest
 import torch
-from pykeen.nn import Embedding
 from pykeen.utils import (
-    clamp_norm, compact_mapping, compose, flatten_dictionary, get_until_first_blank, torch_is_in_1d,
+    _CUDA_OOM_ERROR, _CUDNN_ERROR, calculate_broadcasted_elementwise_result_shape, clamp_norm, combine_complex,
+    compact_mapping, compose, estimate_cost_of_sequence, flatten_dictionary, get_optimal_sequence,
+    get_until_first_blank, is_cuda_oom_error, is_cudnn_error, project_entity, set_random_seed, split_complex,
+    tensor_product, tensor_sum, torch_is_in_1d,
 )
@@ -118,62 +127,8 @@ def test_regular(self):
         self.assertEqual("Broken line.", r)

-class EmbeddingsInCanonicalShapeTests(unittest.TestCase):
-    """Test get_embedding_in_canonical_shape()."""
-
-    #: The number of embeddings
-    num_embeddings: int = 3
-
-    #: The embedding dimension
-    embedding_dim: int = 2
-
-    def setUp(self) -> None:
-        """Initialize embedding."""
-        self.embedding = Embedding(num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim)
-        self.generator = torch.manual_seed(42)
-        self.embedding._embeddings.weight.data = torch.rand(
-            self.num_embeddings,
-            self.embedding_dim,
-            generator=self.generator,
-        )
-
-    def test_no_indices(self):
-        """Test getting all embeddings."""
-        emb = self.embedding.get_in_canonical_shape(indices=None)
-
-        # check shape
-        assert emb.shape == (1, self.num_embeddings, self.embedding_dim)
-
-        # check values
-        exp = self.embedding(indices=None).view(1, self.num_embeddings, self.embedding_dim)
-        assert torch.allclose(emb, exp)
-
-    def _test_with_indices(self, indices: torch.Tensor) -> None:
-        """Help tests with index."""
-        emb = 
self.embedding.get_in_canonical_shape(indices=indices) - - # check shape - num_ind = indices.shape[0] - assert emb.shape == (num_ind, 1, self.embedding_dim) - - # check values - exp = torch.stack([self.embedding(i) for i in indices], dim=0).view(num_ind, 1, self.embedding_dim) - assert torch.allclose(emb, exp) - - def test_with_consecutive_indices(self): - """Test to retrieve all embeddings with consecutive indices.""" - indices = torch.arange(self.num_embeddings, dtype=torch.long) - self._test_with_indices(indices=indices) - - def test_with_indices_with_duplicates(self): - """Test to retrieve embeddings at random positions with duplicate indices.""" - indices = torch.randint( - self.num_embeddings, - size=(2 * self.num_embeddings,), - dtype=torch.long, - generator=self.generator, - ) - self._test_with_indices(indices=indices) +class TestUtils(unittest.TestCase): + """Tests for :mod:`pykeen.utils`.""" def test_compact_mapping(self): """Test ``compact_mapping()``.""" @@ -188,25 +143,24 @@ def test_compact_mapping(self): self.assertEqual(set(id_remapping.keys()), set(mapping.values())) self.assertEqual(set(id_remapping.values()), set(compacted_mapping.values())) + def test_clamp_norm(self): + """Test :func:`pykeen.utils.clamp_norm`.""" + max_norm = 1.0 + gen = torch.manual_seed(42) + eps = 1.0e-06 + for p in [1, 2, float('inf')]: + for _ in range(10): + x = torch.rand(10, 20, 30, generator=gen) + for dim in range(x.ndimension()): + x_c = clamp_norm(x, maxnorm=max_norm, p=p, dim=dim) -def test_clamp_norm(): - """Test clamp_norm() .""" - max_norm = 1.0 - gen = torch.manual_seed(42) - eps = 1.0e-06 - for p in [1, 2, float('inf')]: - for _ in range(10): - x = torch.rand(10, 20, 30, generator=gen) - for dim in range(x.ndimension()): - x_c = clamp_norm(x, maxnorm=max_norm, p=p, dim=dim) + # check maximum norm constraint + assert (x_c.norm(p=p, dim=dim) <= max_norm + eps).all() - # check maximum norm constraint - assert (x_c.norm(p=p, dim=dim) <= max_norm + eps).all() - - # unchanged values for small norms - norm = x.norm(p=p, dim=dim) - mask = torch.stack([(norm < max_norm)] * x.shape[dim], dim=dim) - assert (x_c[mask] == x[mask]).all() + # unchanged values for small norms + norm = x.norm(p=p, dim=dim) + mask = torch.stack([(norm < max_norm)] * x.shape[dim], dim=dim) + assert (x_c[mask] == x[mask]).all() def _get_torch_is_in_1d_result_naive( @@ -243,3 +197,187 @@ def test_torch_is_in_1d(): invert=invert, ) assert (result == expected_result).all() + + +def test_complex_utils(): + """Test complex tensor utilities.""" + re = torch.rand(20, 10) + im = torch.rand(20, 10) + x = combine_complex(x_re=re, x_im=im) + re2, im2 = split_complex(x) + assert (re2 == re).all() + assert (im2 == im).all() + + +def test_project_entity(): + """Test _project_entity.""" + batch_size = 2 + embedding_dim = 3 + relation_dim = 5 + num_entities = 7 + + # random entity embeddings & projections + e = torch.rand(1, num_entities, embedding_dim) + e = clamp_norm(e, maxnorm=1, p=2, dim=-1) + e_p = torch.rand(1, num_entities, embedding_dim) + + # random relation embeddings & projections + r_p = torch.rand(batch_size, 1, relation_dim) + + # project + e_bot = project_entity(e=e, e_p=e_p, r_p=r_p) + + # check shape: + assert e_bot.shape == (batch_size, num_entities, relation_dim) + + # check normalization + assert (torch.norm(e_bot, dim=-1, p=2) <= 1.0 + 1.0e-06).all() + + # check equivalence of re-formulation + # e_{\bot} = M_{re} e = (r_p e_p^T + I^{d_r \times d_e}) e + # = r_p (e_p^T e) + e' + m_re = r_p.unsqueeze(dim=-1) @ 
e_p.unsqueeze(dim=-2)
+    m_re = m_re + torch.eye(relation_dim, embedding_dim).view(1, 1, relation_dim, embedding_dim)
+    assert m_re.shape == (batch_size, num_entities, relation_dim, embedding_dim)
+    e_vanilla = (m_re @ e.unsqueeze(dim=-1)).squeeze(dim=-1)
+    e_vanilla = clamp_norm(e_vanilla, p=2, dim=-1, maxnorm=1)
+    assert torch.allclose(e_vanilla, e_bot)
+
+
+class TestCudaExceptionsHandling(unittest.TestCase):
+    """Test handling of CUDA exceptions."""
+
+    not_cuda_error = RuntimeError("Something else.")
+
+    def test_is_cuda_oom_error(self):
+        """Test handling of a CUDA out of memory exception."""
+        error = RuntimeError(_CUDA_OOM_ERROR)
+        self.assertTrue(is_cuda_oom_error(runtime_error=error))
+        self.assertFalse(is_cudnn_error(runtime_error=error))
+
+        self.assertFalse(is_cuda_oom_error(runtime_error=self.not_cuda_error))
+
+    def test_is_cudnn_error(self):
+        """Test handling of a cuDNN error."""
+        error = RuntimeError(_CUDNN_ERROR)
+        self.assertTrue(is_cudnn_error(runtime_error=error))
+        self.assertFalse(is_cuda_oom_error(runtime_error=error))
+
+        self.assertFalse(is_cudnn_error(runtime_error=self.not_cuda_error))
+
+
+def test_calculate_broadcasted_elementwise_result_shape():
+    """Test calculate_broadcasted_elementwise_result_shape."""
+    max_dim = 64
+    for n_dim, _ in itertools.product(range(2, 5), range(10)):
+        a_shape = [1 for _ in range(n_dim)]
+        b_shape = [1 for _ in range(n_dim)]
+        for j in range(n_dim):
+            dim = 2 + random.randrange(max_dim)
+            mod = random.randrange(3)
+            if mod % 2 == 0:
+                a_shape[j] = dim
+            if mod > 0:
+                b_shape[j] = dim
+        a = torch.empty(*a_shape)
+        b = torch.empty(*b_shape)
+        shape = calculate_broadcasted_elementwise_result_shape(first=a.shape, second=b.shape)
+        c = a + b
+        exp_shape = c.shape
+        assert shape == exp_shape
+
+
+def _generate_shapes(
+    n_dim: int = 5,
+    n_terms: int = 4,
+    iterations: int = 64,
+) -> Iterable[Tuple[Tuple[int, ...], ...]]:
+    """Generate random broadcastable shape tuples."""
+    max_shape = torch.randint(low=2, high=32, size=(128,))
+    for _ in range(iterations):
+        # create broadcastable shapes
+        idx = torch.randperm(max_shape.shape[0])[:n_dim]
+        this_max_shape = max_shape[idx]
+        this_min_shape = torch.ones_like(this_max_shape)
+        shapes = []
+        for _j in range(n_terms):
+            mask = this_min_shape
+            while not (1 < mask.sum() < n_dim):
+                mask = torch.as_tensor(torch.rand(size=(n_dim,)) < 0.3, dtype=max_shape.dtype)
+            this_array_shape = this_max_shape * mask + this_min_shape * (1 - mask)
+            shapes.append(tuple(this_array_shape.tolist()))
+        yield tuple(shapes)
+
+
+@pytest.mark.slow
+def test_estimate_cost_of_add_sequence():
+    """Test ``estimate_cost_of_sequence()``."""
+    set_random_seed(seed=42)
+    # create random arrays, estimate the costs of addition, and measure some execution times.
+    # then, compute the correlation between the estimated cost and the measured time.
+    data = []
+    for shapes in _generate_shapes():
+        arrays = [torch.empty(*shape) for shape in shapes]
+        cost = estimate_cost_of_sequence(*(a.shape for a in arrays))
+        n_samples, time = timeit.Timer(stmt='sum(arrays)', globals=dict(arrays=arrays)).autorange()
+        consumption = time / n_samples
+        data.append((cost, consumption))
+    a = numpy.asarray(data)
+
+    # check for strong correlation between estimated costs and measured execution time
+    assert (numpy.corrcoef(x=a[:, 0], y=a[:, 1])[0, 1]) > 0.8
+
+
+@pytest.mark.slow
+def test_get_optimal_sequence_caching():
+    """Test caching of ``get_optimal_sequence()``."""
+    for shapes in _generate_shapes(iterations=10):
+        # get optimal sequence
+        first_time = timeit.default_timer()
+        get_optimal_sequence(*shapes)
+        first_time = timeit.default_timer() - first_time
+
+        # check caching
+        samples, second_time = timeit.Timer(stmt="get_optimal_sequence(*shapes)", globals=dict(
+            get_optimal_sequence=get_optimal_sequence,
+            shapes=shapes,
+        )).autorange()
+        second_time /= samples
+
+        assert second_time < first_time
+
+
+def test_get_optimal_sequence():
+    """Test ``get_optimal_sequence()``."""
+    for shapes in _generate_shapes():
+        # get optimal sequence
+        opt_cost, opt_seq = get_optimal_sequence(*shapes)
+
+        # check correct cost
+        exp_opt_cost = estimate_cost_of_sequence(*(shapes[i] for i in opt_seq))
+        assert exp_opt_cost == opt_cost
+
+        # check optimality
+        for perm in itertools.permutations(list(range(len(shapes)))):
+            cost = estimate_cost_of_sequence(*(shapes[i] for i in perm))
+            assert cost >= opt_cost
+
+
+def test_tensor_sum():
+    """Test tensor_sum."""
+    for shapes in _generate_shapes():
+        tensors = [torch.rand(*shape) for shape in shapes]
+        result = tensor_sum(*tensors)
+
+        # compare result to sequential addition
+        assert torch.allclose(result, sum(tensors))
+
+
+def test_tensor_product():
+    """Test tensor_product."""
+    for shapes in _generate_shapes():
+        tensors = [torch.rand(*shape) for shape in shapes]
+        result = tensor_product(*tensors)
+
+        # compare result to sequential multiplication
+        assert torch.allclose(result, functools.reduce(operator.mul, tensors[1:], tensors[0]))
diff --git a/tox.ini b/tox.ini
index 6f4e7793af..702e1dc71c 100644
--- a/tox.ini
+++ b/tox.ini
@@ -12,6 +12,7 @@ envlist =
     flake8
     darglint
     pyroma
+    mypy
     # documentation linters/checkers
     readme
     doc8
@@ -35,6 +36,7 @@ whitelist_externals =
     /bin/cat
     /bin/cp
     /bin/mkdir
+    /bin/rm
     /usr/bin/cat
     /usr/bin/cp
     /usr/bin/mkdir
@@ -112,18 +114,13 @@ skip_install = true
 commands = mypy --ignore-missing-imports \
     src/pykeen/typing.py \
     src/pykeen/utils.py \
-    src/pykeen/version.py \
     src/pykeen/nn \
     src/pykeen/regularizers.py \
     src/pykeen/losses.py \
-    src/pykeen/optimizers.py \
-    src/pykeen/sampling \
     src/pykeen/trackers \
-    src/pykeen/triples/generation.py \
-    src/pykeen/triples/splitting.py \
-    src/pykeen/triples/instances.py \
-    src/pykeen/triples/utils.py \
+    src/pykeen/models/base.py \
     src/pykeen/triples/triples_factory.py
+description = Run the mypy tool to check static typing on the project.

 [testenv:pyroma]
@@ -157,6 +154,7 @@ extras =
 commands =
     mkdir -p {envtmpdir}
    cp -r source {envtmpdir}/source
+    rm -rf source/api
    sphinx-build -W -b html -d {envtmpdir}/build/doctrees {envtmpdir}/source {envtmpdir}/build/html
    sphinx-build -W -b coverage -d {envtmpdir}/build/doctrees {envtmpdir}/source {envtmpdir}/build/coverage
    cat {envtmpdir}/build/coverage/c.txt
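
Side note, not part of the patch above: the KL-similarity tests compare the exact kullback_leibler_similarity against torch.distributions. The sketch below (hypothetical names, assuming diagonal covariances as in those tests) spells out the closed-form divergence both sides should agree on; it is illustrative only, not PyKEEN code.

# sketch_kl_check.py -- illustrative only; names are hypothetical and not part of PyKEEN
import torch


def diagonal_gaussian_kl(mu_p, var_p, mu_q, var_q):
    """Closed-form KL(N(mu_p, diag(var_p)) || N(mu_q, diag(var_q))), batched over leading dims."""
    d = mu_p.shape[-1]
    return 0.5 * (
        (var_p / var_q).sum(dim=-1)                          # trace term
        + ((mu_q - mu_p) ** 2 / var_q).sum(dim=-1)           # Mahalanobis term
        - d
        + var_q.log().sum(dim=-1) - var_p.log().sum(dim=-1)  # log-determinant ratio
    )


# cross-check against torch.distributions, mirroring test_against_torch_builtin
mu_p, mu_q = torch.rand(5, 3), torch.rand(5, 3)
var_p, var_q = torch.rand(5, 3) + 0.1, torch.rand(5, 3) + 0.1
p = torch.distributions.MultivariateNormal(mu_p, covariance_matrix=torch.diag_embed(var_p))
q = torch.distributions.MultivariateNormal(mu_q, covariance_matrix=torch.diag_embed(var_q))
assert torch.allclose(
    diagonal_gaussian_kl(mu_p, var_p, mu_q, var_q),
    torch.distributions.kl_divergence(p, q),
    atol=1e-5,
)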