Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

💃 🥤 Extract interaction function from models #107

Closed
wants to merge 729 commits into from
Closed
Show file tree
Hide file tree
Changes from 208 commits
Commits
Show all changes
729 commits
Select commit Hold shift + click to select a range
3bb9f16
Update adding regularizers
cthoyt Nov 19, 2020
362e3fb
Update trans_h.py
cthoyt Nov 19, 2020
be130bf
Update name for clarity
cthoyt Nov 19, 2020
7dda103
Fix british
cthoyt Nov 19, 2020
e6d945b
Mark slow test
cthoyt Nov 19, 2020
3bb92c8
Update .travis.yml
cthoyt Nov 19, 2020
e0e123d
Update .travis.yml
cthoyt Nov 19, 2020
e5d33d8
Extend append_weight_regularizer input types
mberr Nov 19, 2020
ccfeb08
Re-introduce ConvKB regularizer comment, and use append_weight_regula…
mberr Nov 19, 2020
7d8eef3
Cleanup utils.__all__
mberr Nov 19, 2020
4ec4e2b
Add comment to test_scores
mberr Nov 19, 2020
2e55761
Merge branch 'master' into add_interaction_function_2
cthoyt Nov 19, 2020
71944da
Fix bug in upgrading
cthoyt Nov 19, 2020
859c085
Bandage on tricky typing
cthoyt Nov 19, 2020
47da63e
Make typing correct
cthoyt Nov 19, 2020
555f02d
Merge branch 'master' into add_interaction_function_2
cthoyt Nov 19, 2020
f1fbdeb
Fix typo
cthoyt Nov 19, 2020
af97e03
Fix unittest catching error
mberr Nov 20, 2020
46f99d0
Merge remote-tracking branch 'origin/add_interaction_function_2' into…
mberr Nov 20, 2020
32df6b6
Rename test case
mberr Nov 20, 2020
4c6fed4
Test multiple different initializations for test_scores
mberr Nov 20, 2020
9a8bc32
Fix typo
mberr Nov 20, 2020
80e059c
Add TODO
mberr Nov 20, 2020
302c8db
Remove wrong TODO
mberr Nov 20, 2020
7725749
Equivalentl rename dimensions for Tucker functional
mberr Nov 20, 2020
3c5be20
Fix exp_score for TuckerTests
mberr Nov 20, 2020
e43cefb
Fix some darglint issues for nn.emb docstrings
mberr Nov 20, 2020
9ff8104
Add missing parameter to docstring
mberr Nov 20, 2020
e2d2666
Improve docstring of _calculate_missing_shape_information
mberr Nov 20, 2020
0fdbd83
Add raises to docstring
mberr Nov 20, 2020
baca36e
Add raises to docstring
mberr Nov 20, 2020
26c270b
Fix docstring for _add_dim
mberr Nov 20, 2020
99090ef
asaf
cthoyt Nov 20, 2020
35b690d
Improve logging on tests that sometimes fail
cthoyt Nov 20, 2020
d6ba841
Add reset_parameters to InteractionTests
mberr Nov 20, 2020
e0c6f33
Use isfinite to also check for +/-inf
mberr Nov 20, 2020
c1b6194
Fix typo
mberr Nov 20, 2020
2e5549b
Add shape to RepresentationModule
mberr Nov 22, 2020
c33e1e5
Add max_id to RepresentationModule
mberr Nov 22, 2020
3a4af87
Simplify LCWA training loop
mberr Nov 23, 2020
28888b3
Fix simplification
mberr Nov 23, 2020
88d7c30
Let training instance creation respect use_tqdm
mberr Nov 23, 2020
02f8412
More explicit viewing
mberr Nov 23, 2020
4d974cb
Merge branch 'master' into add_interaction_function_2
cthoyt Nov 23, 2020
a4d3a9f
Add mypy testing
cthoyt Nov 23, 2020
5b042b4
Modify all interaction functions to use the broadcasted format
mberr Nov 23, 2020
d4047cf
Inline nop function
mberr Nov 23, 2020
71bc48a
Use matmul instead of einsum
mberr Nov 23, 2020
e68c959
rename file
mberr Nov 23, 2020
c55c499
Adjust get_in_canonical_shape
mberr Nov 23, 2020
da0790d
Fix _get_representations
mberr Nov 23, 2020
c439dc4
Allow forwarding embedding specifications to ERModel
mberr Nov 23, 2020
8743038
Update DistMult
mberr Nov 23, 2020
e346e52
Better integration of complex embeddings
mberr Nov 23, 2020
664bb5b
Update ConvKB
mberr Nov 23, 2020
1507c31
Update ERMLP - KG2E
mberr Nov 23, 2020
94b79cf
Update NTN - RotatE
mberr Nov 23, 2020
4912501
Fix SimplE
mberr Nov 23, 2020
2258aef
Update SE - UM
mberr Nov 23, 2020
c5081c3
Remove abstract base classes
mberr Nov 23, 2020
56001d4
Fix DistMult and broken imports
mberr Nov 23, 2020
19ef829
Fix broken annotations
mberr Nov 23, 2020
013adc0
Fix _prepare_representation_module_list
mberr Nov 23, 2020
e3ce1cf
Fix _prepare_representation_module_list - again
mberr Nov 23, 2020
661e932
Fix SE shape
mberr Nov 23, 2020
8314c75
Comment out unused code
cthoyt Nov 23, 2020
bde23b0
pass flake8
cthoyt Nov 23, 2020
9c58aeb
Fix ConvKB
mberr Nov 23, 2020
b1a3a92
Merge remote-tracking branch 'origin/add_interaction_function_2' into…
mberr Nov 23, 2020
40002d2
inline sizes for ConvE
mberr Nov 23, 2020
b7b9e71
Fix ConvE init
mberr Nov 23, 2020
6eda30a
Fix ConvE revert reshaping
mberr Nov 23, 2020
59e3bbf
Fix TransR functional form
mberr Nov 23, 2020
7b40cc8
Add generics to LiteralModel
cthoyt Nov 23, 2020
1192bbf
Style cleanup
cthoyt Nov 23, 2020
a3c015f
Update handling fo embedding dimension in literal models
cthoyt Nov 23, 2020
ad0fe00
Fix usage of chain_matmul
mberr Nov 23, 2020
8caba1e
Fix NTN einsum eq
mberr Nov 23, 2020
6d829f5
Fix TransD
mberr Nov 23, 2020
24a6bf7
Fix SimplEInteraction
mberr Nov 23, 2020
b9f83d4
Fix RGCNRepresentations
mberr Nov 23, 2020
4e0b156
Merge remote-tracking branch 'origin/add_interaction_function_2' into…
mberr Nov 23, 2020
d5a4595
Add verification to representations utility
mberr Nov 23, 2020
53c3045
Pass flake8 and mypy
cthoyt Nov 23, 2020
ab7deec
Merge remote-tracking branch 'origin/add_interaction_function_2' into…
mberr Nov 23, 2020
a744d14
hsh
cthoyt Nov 23, 2020
da4a259
Add shapes verification to passed representations
mberr Nov 23, 2020
862225e
Merge remote-tracking branch 'origin/add_interaction_function_2' into…
mberr Nov 23, 2020
10ef479
Fix superclass order for ERModel
mberr Nov 23, 2020
cbfab6a
Improve error messages
mberr Nov 23, 2020
9cb56b0
Extract canonical shape utility
mberr Nov 23, 2020
93f3bfa
Pass flake8 / mypy
cthoyt Nov 23, 2020
614759b
Simplify score_* for Interaction
mberr Nov 23, 2020
d2ecc7d
Merge remote-tracking branch 'origin/add_interaction_function_2' into…
mberr Nov 23, 2020
356df81
Skip shape tests with non-consistent head entity/tail entity numbers
mberr Nov 23, 2020
c51662b
fix TransR functional
mberr Nov 23, 2020
a262c30
Adjust InteractionModule and tests
mberr Nov 23, 2020
36e4204
Update strip dim utility
mberr Nov 23, 2020
95b41a3
Fix shapes for NTN
mberr Nov 23, 2020
ed03d53
Fix ProjE
mberr Nov 23, 2020
a1bd33e
Fix HolE
mberr Nov 23, 2020
f287ffd
Fix HolE
mberr Nov 23, 2020
d47c8f8
Move generic test classes
mberr Nov 23, 2020
8d7c2bb
Add shape information to SimplE module
mberr Nov 23, 2020
97866f0
Add SimplE test
mberr Nov 23, 2020
e828b17
Fix TransR relation projection matrices
mberr Nov 23, 2020
7ad01b3
Fix SE
mberr Nov 23, 2020
e07be94
For flake
mberr Nov 23, 2020
583a620
Do not use dictionary
mberr Nov 23, 2020
838b76d
Add annotation for xx
mberr Nov 23, 2020
a2ccc6e
Give mypy a type for xx
mberr Nov 23, 2020
69fe97f
Update minimum python version
cthoyt Nov 24, 2020
14b45ea
Update representation.py
cthoyt Nov 24, 2020
b0fe4f4
Extend Embedding tests
mberr Nov 24, 2020
6219d7a
Fix Embedding.forward for >1d embeddings
mberr Nov 24, 2020
dbe5c50
Add tests for LiteralRepresentations
mberr Nov 24, 2020
d7dcf90
Add utility for expected_canonical_shape
mberr Nov 24, 2020
2f861ab
Improve documentation of tests
mberr Nov 24, 2020
5fd64d8
Add test test for representations
mberr Nov 24, 2020
96c32c1
Add test for EmbeddingSpecification
mberr Nov 24, 2020
2649b95
Add test for constructor error raising
mberr Nov 24, 2020
c07c975
Add test for initializer
mberr Nov 24, 2020
3bf23c4
Add documentation
mberr Nov 24, 2020
4098723
Add trailing comma
mberr Nov 24, 2020
8d4a682
Add tests for normalizer
mberr Nov 24, 2020
5bac9f6
Add tests for constrainer; simplify code
mberr Nov 24, 2020
5d13369
Extract method
mberr Nov 24, 2020
2607fd8
Delete unused variable
mberr Nov 24, 2020
f32f748
Subclass LiteralRepresentationsTests from correct test super class
mberr Nov 24, 2020
8290fd1
Move RGCN representations to representation.py
mberr Nov 24, 2020
32a500b
Add tests for RGCN representations
mberr Nov 24, 2020
fcb3c92
Fix instantiation of RGCN test instance
mberr Nov 24, 2020
d93599f
Add exception to RGCN representations for indices in wrong shape
mberr Nov 24, 2020
fd181e8
Inline constants
mberr Nov 24, 2020
1a9e40c
Rename constant
mberr Nov 24, 2020
bc8bd66
Add upgrade_to_sequence to export of utils.py
mberr Nov 24, 2020
bc0c663
Use complex dtype for ComplEx interaction
mberr Nov 24, 2020
a15d328
use tensor product
mberr Nov 24, 2020
19c54c2
Extract common utility for elementwise tensor combinations
mberr Nov 24, 2020
856b841
Add elementwise combination optimization
mberr Nov 24, 2020
74870b5
Add tests for broadcasted combination
mberr Nov 24, 2020
a97c84b
Add mark.slow
mberr Nov 24, 2020
15b368c
Add two implementation variants of complex to functional
mberr Nov 24, 2020
6428ab5
Add another complex variant
mberr Nov 24, 2020
0662725
Add cost estimation to rotate forward
mberr Nov 24, 2020
f2cbd0a
Add draft of functional interaction function benchmark
mberr Nov 24, 2020
3f6a1e3
change benchmark script
mberr Nov 24, 2020
d924126
Fix variant creation
mberr Nov 24, 2020
bf819fa
Store results better
cthoyt Nov 24, 2020
ce69e86
Cleaner generation of stateless interaction class
cthoyt Nov 24, 2020
a3100a9
Update interaction benchmark
cthoyt Nov 24, 2020
86875ab
Add more options
mberr Nov 24, 2020
b8d7b2b
Add gpu support to kernel benchmark
mberr Nov 24, 2020
7290163
Pass flake8
cthoyt Nov 25, 2020
90b55fc
Add docs stubs
cthoyt Nov 25, 2020
3053217
Pass mypy
cthoyt Nov 25, 2020
d9f93bc
Merge branch 'master' into add_interaction_function_2
cthoyt Nov 25, 2020
1f5488d
Merge branch 'master' into add_interaction_function_2
cthoyt Nov 25, 2020
20281cf
Merge remote-tracking branch 'origin/add_interaction_function_2' into…
mberr Nov 25, 2020
b9000a3
Delete unused variable
mberr Nov 25, 2020
ecf12a9
extract common computation for KG2E similarities
mberr Nov 25, 2020
b51adea
Pass h/t down to KG2E similarity functions
mberr Nov 25, 2020
051e24d
Add docstring
mberr Nov 25, 2020
1079860
Add tensor_sum / tensor_product to __all__
mberr Nov 25, 2020
2db70f6
Do not use tensor_sum when it is clear that all summands are in the s…
mberr Nov 25, 2020
6e5eecc
Do not use tensor_sum when it is clear that all summands are in the s…
mberr Nov 25, 2020
f704595
short-cut tensor_sum for less than 3 tensors
mberr Nov 25, 2020
88e02dc
Add benchmarking script for tensor_sum
mberr Nov 25, 2020
328fb59
Remove unused method
mberr Nov 25, 2020
3eeb541
Update utils benchmark with all shape combinations of tensorsum
mberr Nov 25, 2020
a8bfd73
Add click
mberr Nov 25, 2020
bcc4a34
fixes
mberr Nov 25, 2020
02e2293
Do not use tensor_sum for 2-element list
mberr Nov 26, 2020
02fb26a
Add caching test for get_optimal_sequence
mberr Nov 26, 2020
0d6c7df
Fix test name
mberr Nov 26, 2020
bd04e55
Add additional heuristic for score_hrt case
mberr Nov 26, 2020
041a8f0
do not benchmark two-element tensor_sums
mberr Nov 26, 2020
f6c07e7
Remove amo from model
cthoyt Nov 30, 2020
41e9dfb
More manual amo merge
cthoyt Nov 30, 2020
bc7b116
Cleanup
cthoyt Nov 30, 2020
73a28a6
Merge remote-tracking branch 'origin/master' into add_interaction_fun…
cthoyt Nov 30, 2020
8719441
Pass lint
cthoyt Nov 30, 2020
eea971e
Fix ER
mberr Nov 30, 2020
3791a35
Adapt unittests for KL div similarity
mberr Nov 30, 2020
46d8096
add test in test
mberr Nov 30, 2020
58128f8
reformat
mberr Nov 30, 2020
83363aa
reformat
mberr Nov 30, 2020
2023a88
Move debug implementation to sim.py
mberr Nov 30, 2020
95b655d
Fix bug in KL
mberr Nov 30, 2020
01af212
Remove unused variables
mberr Nov 30, 2020
aac4bf5
Fix flake8
mberr Nov 30, 2020
3851f04
Fix unittest
mberr Nov 30, 2020
08c68e6
Add skip_cls
mberr Nov 30, 2020
acda8e2
Fix unittest for RotatEInteraction
mberr Nov 30, 2020
79a2410
circumvent accumulate(..., initial=x) since it is unavailable in Pyth…
mberr Nov 30, 2020
a42ce99
More unittesting
cthoyt Nov 30, 2020
4b8b65a
Fix mocking in test
cthoyt Dec 1, 2020
33dbe66
Use autorange instead of fixed number of samples
mberr Dec 1, 2020
caba4df
Merge remote-tracking branch 'origin/master' into add_interaction_fun…
mberr Dec 1, 2020
c0d7a67
Merge remote-tracking branch 'origin/master' into add_interaction_fun…
mberr Dec 10, 2020
2c66dfd
Fix imports for triples_factory.py
mberr Dec 10, 2020
2371537
fix more imports
mberr Dec 10, 2020
d551b1e
Fix ComplEx constructor merge
mberr Dec 10, 2020
9a18c9d
More post-merge fixes
mberr Dec 10, 2020
e94bdc4
Extract literal interaction wrapper
mberr Dec 10, 2020
60ecdb6
fix complex
mberr Dec 10, 2020
2b206b2
Merge branch 'master' into add_interaction_function_2
cthoyt Dec 13, 2020
d282f98
Add work load generation for fast SLCWA
mberr Dec 15, 2020
c6a60fa
Add docstring
mberr Dec 15, 2020
9abaecc
Move utility method
mberr Dec 15, 2020
bc07b12
Cleanup interaction benchmark script
mberr Dec 15, 2020
4d21e79
add another complex variant
mberr Dec 15, 2020
8a61286
add fast flag
mberr Dec 15, 2020
5d5d056
add shuffle option for more reliable meta time estimates
mberr Dec 15, 2020
c3c697b
Also measure memory
mberr Dec 15, 2020
abe6b43
Remove viz
mberr Dec 15, 2020
db48aac
fix variant
mberr Dec 15, 2020
010bc7b
Use single progress bar
mberr Dec 15, 2020
c51c621
Extend docstring
mberr Dec 15, 2020
447e185
Add einsum variant
mberr Dec 15, 2020
bcf20d0
Fix device error in KL divergence
mberr Dec 15, 2020
e2863de
Do not use root logger
mberr Dec 15, 2020
03d8b2a
Update tox.ini
cthoyt Dec 15, 2020
a59fc7b
Pass flake8
cthoyt Dec 15, 2020
b3502e7
Update functional.py
cthoyt Dec 15, 2020
3a8c9a7
Fix tucker without explicit relation dim
mberr Dec 15, 2020
601de95
Merge remote-tracking branch 'origin/add_interaction_function_2' into…
mberr Dec 15, 2020
fb2d3c8
Merge branch 'master' into add_interaction_function_2
cthoyt Dec 15, 2020
392480a
Pass mypy
cthoyt Dec 15, 2020
19a2ec1
Bring back configurable regularization
cthoyt Dec 16, 2020
44d156d
Update pipeline.py
cthoyt Dec 16, 2020
65f36be
Now I remember why I did this :laughing:
cthoyt Dec 16, 2020
1904672
add another complex variant
mberr Dec 16, 2020
7a63cb8
Merge remote-tracking branch 'origin/add_interaction_function_2' into…
mberr Dec 16, 2020
2bc1bec
Update classvars
cthoyt Dec 16, 2020
14d463e
More types
cthoyt Dec 16, 2020
8982553
More typing
cthoyt Dec 16, 2020
b404b84
Add size information data class
mberr Dec 17, 2020
3d941a9
use utility for ER-MLP
mberr Dec 17, 2020
ca234af
Update docstring
mberr Dec 17, 2020
7d137c7
Move variants to module
mberr Dec 17, 2020
c9b240e
extract common code
mberr Dec 17, 2020
3bf1ab9
Code cleanup
cthoyt Dec 17, 2020
e3ca9bf
Update test_models.py
cthoyt Dec 17, 2020
675eeac
More cleanup
cthoyt Dec 17, 2020
ff9a9db
Update losses.py
cthoyt Dec 17, 2020
ebdddbd
Update typing.py
cthoyt Dec 17, 2020
bf63db3
Update losses.py
cthoyt Dec 17, 2020
c988b49
Update tests.yml
cthoyt Dec 17, 2020
79b2187
Update test_utils.py
cthoyt Jan 21, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/reference/models.rst
Expand Up @@ -20,3 +20,15 @@ Extra Modules
-------------
.. automodule:: pykeen.nn
:members:

Interactors
-----------
Functional Interface
~~~~~~~~~~~~~~~~~~~~
.. automodule:: pykeen.nn.functional
:members:

Module Interface
~~~~~~~~~~~~~~~~
.. automodule:: pykeen.nn.modules
:members:
14 changes: 10 additions & 4 deletions src/pykeen/models/__init__.py
Expand Up @@ -6,10 +6,12 @@
score value is model-dependent, and usually it cannot be directly interpreted as a probability.
""" # noqa: D205, D400

import inspect
from typing import Mapping, Set, Type, Union

from .base import EntityEmbeddingModel, EntityRelationEmbeddingModel, Model, MultimodalModel # noqa:F401
from .base import ( # noqa:F401
EntityEmbeddingModel, EntityRelationEmbeddingModel, Model,
MultimodalModel, SimpleVectorEntityRelationEmbeddingModel,
)
from .multimodal import ComplExLiteral, DistMultLiteral
from .unimodal import (
ComplEx,
Expand Down Expand Up @@ -64,10 +66,14 @@
'get_model_cls',
]

_CONCRETE_BASES = {
SimpleVectorEntityRelationEmbeddingModel,
}


def _concrete_subclasses(cls):
def _concrete_subclasses(cls: Type[Model]):
for subcls in cls.__subclasses__():
if not inspect.isabstract(subcls):
if not subcls._is_abstract() and subcls not in _CONCRETE_BASES:
yield subcls
yield from _concrete_subclasses(subcls)

Expand Down
308 changes: 265 additions & 43 deletions src/pykeen/models/base.py

Large diffs are not rendered by default.

75 changes: 6 additions & 69 deletions src/pykeen/models/unimodal/complex.py
Expand Up @@ -4,22 +4,21 @@

from typing import Optional

import torch
import torch.nn as nn

from ..base import EntityRelationEmbeddingModel
from ..base import SimpleVectorEntityRelationEmbeddingModel
from ...losses import Loss, SoftplusLoss
from ...nn.modules import ComplExInteractionFunction
from ...regularizers import LpRegularizer, Regularizer
from ...triples import TriplesFactory
from ...typing import DeviceHint
from ...utils import split_complex

__all__ = [
'ComplEx',
]


class ComplEx(EntityRelationEmbeddingModel):
class ComplEx(SimpleVectorEntityRelationEmbeddingModel):
r"""An implementation of ComplEx [trouillon2016]_.

ComplEx is an extension of :class:`pykeen.models.DistMult` that uses complex valued representations for the
Expand Down Expand Up @@ -94,8 +93,11 @@ def __init__(
:param regularizer: BaseRegularizer
The regularizer to use.
"""
interaction_function = ComplExInteractionFunction()

super().__init__(
triples_factory=triples_factory,
interaction_function=interaction_function,
embedding_dim=2 * embedding_dim, # complex embeddings
automatic_memory_optimization=automatic_memory_optimization,
loss=loss,
Expand All @@ -107,68 +109,3 @@ def __init__(
entity_initializer=nn.init.normal_,
relation_initializer=nn.init.normal_,
)

@staticmethod
def interaction_function(
h: torch.FloatTensor,
r: torch.FloatTensor,
t: torch.FloatTensor,
) -> torch.FloatTensor:
"""Evaluate the interaction function of ComplEx for given embeddings.

The embeddings have to be in a broadcastable shape.

:param h:
Head embeddings.
:param r:
Relation embeddings.
:param t:
Tail embeddings.

:return: shape: (...)
The scores.
"""
# split into real and imaginary part
(h_re, h_im), (r_re, r_im), (t_re, t_im) = [split_complex(x=x) for x in (h, r, t)]

# ComplEx space bilinear product
# *: Elementwise multiplication
return sum(
(hh * rr * tt).sum(dim=-1)
for hh, rr, tt in [
(h_re, r_re, t_re),
(h_re, r_im, t_im),
(h_im, r_re, t_im),
(h_im, r_im, t_re),
]
)

def forward(
self,
h_indices: Optional[torch.LongTensor],
r_indices: Optional[torch.LongTensor],
t_indices: Optional[torch.LongTensor],
) -> torch.FloatTensor:
"""Unified score function."""
# get embeddings
h = self.entity_embeddings.get_in_canonical_shape(indices=h_indices)
r = self.relation_embeddings.get_in_canonical_shape(indices=r_indices)
t = self.entity_embeddings.get_in_canonical_shape(indices=t_indices)

# Regularization
self.regularize_if_necessary(h, r, t)

# Compute scores
return self.interaction_function(h=h, r=r, t=t)

def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102
return self(h_indices=hrt_batch[:, 0], r_indices=hrt_batch[:, 1], t_indices=hrt_batch[:, 2]).view(-1, 1)

def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102
return self(h_indices=hr_batch[:, 0], r_indices=hr_batch[:, 1], t_indices=None)

def score_r(self, ht_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102
return self(h_indices=ht_batch[:, 0], r_indices=None, t_indices=ht_batch[:, 1])

def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102
return self(h_indices=None, r_indices=rt_batch[:, 0], t_indices=rt_batch[:, 1])