鉀忥笍 馃 Update docs for using learned embeddings (#474)
* Fix typos in documentation

* Add missing type annotation

* Add forwards compatible properties

* Update extension tutorial to use forward-compatible properties

* Pass flake8

* Finish updating docs

Trigger CI
cthoyt committed Jun 2, 2021
1 parent 7914625 commit eb9919d
Showing 3 changed files with 89 additions and 45 deletions.
25 changes: 13 additions & 12 deletions docs/source/extending/models.rst
@@ -26,15 +26,16 @@ The only implementation we have to provide is of the `score_hrt` member function
     class ModifiedDistMult(EntityRelationEmbeddingModel):
         def score_hrt(self, hrt_batch):
             # Get embeddings
-            h = self.entity_embeddings(hrt_batch[:, 0])
-            r = self.relation_embeddings(hrt_batch[:, 1])
-            t = self.entity_embeddings(hrt_batch[:, 2])
+            h = self.entity_representations[0](hrt_batch[:, 0])
+            r = self.relation_representations[0](hrt_batch[:, 1])
+            t = self.entity_representations[0](hrt_batch[:, 2])
             # evaluate interaction function
             return h * r.sigmoid() * t

-The ``entity_embeddings`` and ``relation_embeddings`` are available for all
-:class:`pykeen.models.base.EntityRelationEmbeddingModel` and are instances of
-:class:`torch.nn.Embedding`.
+The ``entity_representations`` and ``relation_representations`` sequences are available for all
+:class:`pykeen.models.base.EntityRelationEmbeddingModel` and are lists of length one containing
+a single instance of a :class:`pykeen.nn.Embedding`. This may seem like a strange data structure, but
+it prepares for the much more powerful usages covered by the new-style :class:`pykeen.models.ERModel`.

The ``hrt_batch`` is a long tensor representing the internal indices of the edges.
The above example shows a very common way of slicing it to get separate lists of
@@ -74,9 +75,9 @@ where the value is the loss class.
         loss_default = NSSALoss

         def score_hrt(self, hrt_batch):
-            h = self.entity_embeddings(hrt_batch[:, 0])
-            r = self.relation_embeddings(hrt_batch[:, 1])
-            t = self.entity_embeddings(hrt_batch[:, 2])
+            h = self.entity_representations[0](hrt_batch[:, 0])
+            r = self.relation_representations[0](hrt_batch[:, 1])
+            t = self.entity_representations[0](hrt_batch[:, 2])
             return h * r.sigmoid() * t
Now, when using the pipeline, the :class:`pykeen.losses.NSSALoss` loss is used by default
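
A minimal sketch of how this class default gets picked up by the pipeline, assuming the class
above (whose name is hidden above the fold of this hunk) is called ``ModifiedDistMult``:

.. code-block:: python

    from pykeen.losses import NSSALoss
    from pykeen.pipeline import pipeline

    # no ``loss`` argument is given, so the ``loss_default`` of the class is used
    result = pipeline(model=ModifiedDistMult, dataset='Nations')
    assert isinstance(result.model.loss, NSSALoss)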
@@ -140,9 +141,9 @@ consider:
             self.linear2 = torch.nn.Linear(self.hidden_dim, self.embedding_dim)

         def score_hrt(self, hrt_batch):
-            h = self.entity_embeddings(hrt_batch[:, 0])
-            r = self.relation_embeddings(hrt_batch[:, 1])
-            t = self.entity_embeddings(hrt_batch[:, 2])
+            h = self.entity_representations[0](hrt_batch[:, 0])
+            r = self.relation_representations[0](hrt_batch[:, 1])
+            t = self.entity_representations[0](hrt_batch[:, 2])

             # add more transformations
             h = self.linear2(self.linear1(h))
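
For reference, a minimal sketch of how such a model could look when assembled, assuming
illustrative names like ``TransformedDistMult``, ``hidden_dim``, and ``linear1`` that follow
the fragment above (this is not the tutorial's exact code):

.. code-block:: python

    import torch

    from pykeen.models import EntityRelationEmbeddingModel

    class TransformedDistMult(EntityRelationEmbeddingModel):
        """A hypothetical DistMult variant with extra transformations."""

        def __init__(self, *, hidden_dim: int = 50, **kwargs):
            super().__init__(**kwargs)
            self.hidden_dim = hidden_dim
            self.linear1 = torch.nn.Linear(self.embedding_dim, self.hidden_dim)
            self.linear2 = torch.nn.Linear(self.hidden_dim, self.embedding_dim)

        def score_hrt(self, hrt_batch):
            h = self.entity_representations[0](hrt_batch[:, 0])
            r = self.relation_representations[0](hrt_batch[:, 1])
            t = self.entity_representations[0](hrt_batch[:, 2])
            # add more transformations to the head representation
            h = self.linear2(self.linear1(h))
            # reduce to one score per triple
            return (h * r.sigmoid() * t).sum(dim=-1, keepdim=True)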
73 changes: 46 additions & 27 deletions docs/source/tutorial/first_steps.rst
@@ -28,56 +28,75 @@ The embeddings learned for entities and relations are not only useful for link
prediction (see :ref:`making_predictions`), but also for other downstream machine
learning tasks like clustering, regression, and classification.

-The embeddings themselves are typically stored in an instance of
-:class:`pykeen.nn.emb.Embedding`, which wraps the :class:`torch.nn.Embedding`
-and extends the more general :class:`pykeen.nn.emb.RepresentationModule` class.
-All entity representations can be accessed from a model with the following:
+Knowledge graph embedding models can potentially have multiple entity representations and
+multiple relation representations, so they are respectively stored as sequences in the
+``entity_representations`` and ``relation_representations`` attributes of each model.
+While the exact contents of these sequences are model-dependent, the first element of
+each is usually the "primary" representation for either the entities or relations.
+
+Typically, the values in these sequences are instances of :class:`pykeen.nn.emb.Embedding`.
+This implements a similar, but more powerful, interface to the built-in :class:`torch.nn.Embedding`
+class. However, the values in these sequences can more generally be instances of any subclass of
+:class:`pykeen.nn.emb.RepresentationModule`. This allows more powerful encoders, like those in GNNs
+such as :class:`pykeen.models.RGCN`, to be implemented and used.
+
+The entity representations and relation representations can be accessed like this:

 .. code-block:: python

+    from typing import List
+
+    import pykeen.nn
     from pykeen.pipeline import pipeline

     result = pipeline(model='TransE', dataset='UMLS')
     model = result.model

-    entity_embeddings: torch.FloatTensor = model.entity_embeddings()
+    entity_representation_modules: List['pykeen.nn.RepresentationModule'] = model.entity_representations
+    relation_representation_modules: List['pykeen.nn.RepresentationModule'] = model.relation_representations

-or more explicitly:
+Most models, like :class:`pykeen.models.TransE`, only have one representation for entities and one
+for relations. This means that the ``entity_representations`` and ``relation_representations``
+lists both have a length of 1. All of the entity embeddings can be accessed like:

 .. code-block:: python

-    entity_embeddings: torch.FloatTensor = model.entity_embeddings(indices=None)
+    entity_embeddings: pykeen.nn.Embedding = entity_representation_modules[0]
+    relation_embeddings: pykeen.nn.Embedding = relation_representation_modules[0]

-If you'd like to only look up certain embeddings, you can use the ``indices`` parameter
-and pass a :class:`torch.LongTensor` with their corresponding indices.
+Since all representations are subclasses of :class:`torch.nn.Module`, you need to call them like
+functions to invoke ``forward()`` and get the values.

-Some models, like :class:`pykeen.models.TransD` have more than one embedding for entities.
-Old-style models (e.g., ones inheriting from :class:`pykeen.models.EntityRelationEmbeddingModel`)
-define one embedding as primary (e.g., :data:`pykeen.models.TransD.entity_embedding`) and others
-are considered as secondary (e.g., :data:`pykeen.models.TransD.entity_projections`).
-
-New-style models (e.g., ones inheriting from :class:`pykeen.models.ERModel`) are
-generalized to easier allow for multiple entity representations and
-relation representations. These models have two lists of entity and relation
-representations respectively. You can access them via

 .. code-block:: python

+    entity_embedding_tensor: torch.FloatTensor = entity_embeddings()
+    relation_embedding_tensor: torch.FloatTensor = relation_embeddings()

+The ``forward()`` function of all :class:`pykeen.nn.emb.RepresentationModule` subclasses takes an
+``indices`` parameter. By default, it is ``None`` and returns all values. More explicitly, this
+looks like:

 .. code-block:: python

-    entity_representation_modules: List['pykeen.nn.Embedding'] = model.entity_representations
-    relation_representation_modules: List['pykeen.nn.Embedding'] = model.relation_representations
+    entity_embedding_tensor: torch.FloatTensor = entity_embeddings(indices=None)
+    relation_embedding_tensor: torch.FloatTensor = relation_embeddings(indices=None)

-If you want to obtain a single representation, you can index this list then call the function
-to unwrap the embeddings, e.g.
+If you'd like to only look up certain embeddings, you can use the ``indices`` parameter
+and pass a :class:`torch.LongTensor` with their corresponding indices.
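
For example, a minimal sketch of looking up only a few specific entities (the indices here
are arbitrary):

.. code-block:: python

    import torch

    # look up only the representations of entities 1, 4, and 7
    some_embeddings: torch.FloatTensor = entity_embeddings(indices=torch.as_tensor([1, 4, 7]))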

+You might want to detach them from the GPU and convert to a :class:`numpy.ndarray` with

 .. code-block:: python

-    first_entity_representation_module: 'pykeen.nn.Embedding' = entity_representations[0]
-    first_entity_representations: torch.FloatTensor = first_entity_representations_module()
+    entity_embedding_tensor = model.entity_representations[0](indices=None).detach().numpy()

-and treat them as before. The ordering in this list corresponds to the
-ordering of representations defined in the interaction function. Some
-models may provide Pythonic properties that provide a vanity attribute
-to the instance of the class for a specific entity or relation representation.
+.. warning::
+
+    Some old-style models (e.g., ones inheriting from :class:`pykeen.models.EntityRelationEmbeddingModel`)
+    don't fully implement the ``entity_representations`` and ``relation_representations`` interface. This
+    means that they might have additional embeddings stored in attributes that aren't exposed through these
+    sequences. For example, :class:`pykeen.models.TransD` has a secondary entity embedding in
+    :data:`pykeen.models.TransD.entity_projections`.
+
+    Eventually, all models will be upgraded to new-style models and this won't be a problem.
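
Putting this together, here is a sketch of one possible downstream use of the learned
embeddings (clustering with scikit-learn; the choice of ``KMeans`` and the number of
clusters are illustrative assumptions, not part of the tutorial):

.. code-block:: python

    from sklearn.cluster import KMeans

    from pykeen.pipeline import pipeline

    result = pipeline(model='TransE', dataset='UMLS')
    # detach, move to CPU, and convert; shape: (num_entities, embedding_dim)
    entity_embedding_tensor = result.model.entity_representations[0](indices=None).detach().cpu().numpy()
    # cluster the entities in embedding space; 5 clusters is an arbitrary choice
    clusters = KMeans(n_clusters=5).fit_predict(entity_embedding_tensor)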

Beyond the Pipeline
-------------------
36 changes: 30 additions & 6 deletions src/pykeen/models/base.py
@@ -10,15 +10,15 @@
 import pickle
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, ClassVar, Iterable, Mapping, Optional, Type, Union
+from typing import Any, ClassVar, Iterable, Mapping, Optional, Sequence, Type, Union

 import pandas as pd
 import torch
 from docdata import parse_docdata
 from torch import nn

 from ..losses import Loss, MarginRankingLoss
-from ..nn.emb import Embedding, EmbeddingSpecification
+from ..nn.emb import Embedding, EmbeddingSpecification, RepresentationModule
 from ..regularizers import NoRegularizer, Regularizer
 from ..triples import CoreTriplesFactory
 from ..typing import DeviceHint, ScorePack
@@ -683,7 +683,7 @@ def _free_graph_and_cache(self):
 class EntityEmbeddingModel(_OldAbstractModel, ABC, autoreset=False):
     """A base module for most KGE models that have one embedding for entities."""

-    entity_embedding: Embedding
+    entity_embeddings: Embedding

     def __init__(
         self,
@@ -718,6 +718,14 @@ def embedding_dim(self) -> int:  # noqa:D401
"""The entity embedding dimension."""
return self.entity_embeddings.embedding_dim

@property
def entity_representations(self) -> Sequence[RepresentationModule]: # noqa:D401
"""The entity representations.
This property provides forward compatibility with the new-style :class:`pykeen.models.ERModel`.
"""
return [self.entity_embeddings]

def _reset_parameters_(self): # noqa: D102
self.entity_embeddings.reset_parameters()

@@ -731,9 +739,9 @@ class EntityRelationEmbeddingModel(_OldAbstractModel, ABC, autoreset=False):
"""A base module for KGE models that have different embeddings for entities and relations."""

#: Primary embeddings for entities
entity_embedding: Embedding
entity_embeddings: Embedding
#: Primary embeddings for relations
relation_embedding: Embedding
relation_embeddings: Embedding

def __init__(
self,
@@ -774,10 +782,26 @@ def embedding_dim(self) -> int:  # noqa:D401
         return self.entity_embeddings.embedding_dim

     @property
-    def relation_dim(self):  # noqa:D401
+    def relation_dim(self) -> int:  # noqa:D401
         """The relation embedding dimension."""
         return self.relation_embeddings.embedding_dim

+    @property
+    def entity_representations(self) -> Sequence[RepresentationModule]:  # noqa:D401
+        """The entity representations.
+
+        This property provides forward compatibility with the new-style :class:`pykeen.models.ERModel`.
+        """
+        return [self.entity_embeddings]
+
+    @property
+    def relation_representations(self) -> Sequence[RepresentationModule]:  # noqa:D401
+        """The relation representations.
+
+        This property provides forward compatibility with the new-style :class:`pykeen.models.ERModel`.
+        """
+        return [self.relation_embeddings]
+
     def _reset_parameters_(self):  # noqa: D102
         self.entity_embeddings.reset_parameters()
         self.relation_embeddings.reset_parameters()
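
These forward-compatible properties mean that old-style attribute access and new-style
sequence access line up. A minimal sketch of the equivalence, assuming ``model`` is an
instance of an old-style ``EntityRelationEmbeddingModel`` subclass:

.. code-block:: python

    # both names refer to the same pykeen.nn.Embedding instance
    assert model.entity_representations[0] is model.entity_embeddings
    assert model.relation_representations[0] is model.relation_embeddings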
