Implement the ComplEx knowledge graph algorithm (#848)
This implements the ComplEx knowledge graph algorithm from: Complex Embeddings
for Simple Link Prediction; Théo Trouillon, Johannes Welbl, Sebastian Riedel,
Éric Gaussier and Guillaume Bouchard, ICML 2016. 
http://jmlr.org/proceedings/papers/v48/trouillon16.pdf

This comes in three parts (a rough usage sketch follows the list):

1. a `KGTripleGenerator` where the `flow` method takes a list of `(source, edge
   type, target)` triples (aka `(subject, relation, object)`) and yields them as
   batches of sequential integers to use for look-ups in embedding layers
2. the `ComplEx` model which creates the appropriate embedding layers and
   applies the ComplEx scoring mechanism to them (which gives logit
   "probabilities" for a link)
3. some minor changes to `StellarGraph` and its supporting code, so that passing
   in an unknown node gets an appropriate error
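
Tying these together, a rough usage sketch (hedged: the `KGTripleGenerator` constructor
arguments, the exact `flow` signature, and the `graph`/`train_triples` objects are
assumptions for illustration, not taken from this diff):

```python
import tensorflow as tf

from stellargraph.mapper import KGTripleGenerator
from stellargraph.layer import ComplEx

# `graph` is a StellarGraph holding the knowledge graph, and `train_triples` holds
# (source, edge type, target) triples; both are assumed to already exist
generator = KGTripleGenerator(graph, batch_size=200)  # batch_size is an assumed argument

complex_model = ComplEx(generator, k=100)
x_inp, x_out = complex_model.build()

model = tf.keras.Model(inputs=x_inp, outputs=x_out)
# the scores are logits, so train with a from_logits loss
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
)
model.fit(generator.flow(train_triples))  # assumes `flow` yields Keras-compatible batches

# the learnt complex-valued embeddings can be pulled back out of the trained model
node_embeddings, edge_type_embeddings = ComplEx.embeddings(model)
```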

This doesn't add a notebook that reproduces the paper, because the datasets used
in it are new to the library, and the MRR metrics used are awkward to
compute. This can/should be future work (#862). However, it does add tests,
including one that does a simple, manual computation of the scores for links, to
validate that the implementation matches the paper's description.
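
As context for what that "simple and manual computation" looks like, a hedged numpy
sketch with made-up embeddings (not the actual test code):

```python
import numpy as np

rng = np.random.default_rng(0)
k = 5  # embedding dimension

# made-up complex embeddings for one (subject, relation, object) triple
e_s = rng.normal(size=k) + 1j * rng.normal(size=k)
w_r = rng.normal(size=k) + 1j * rng.normal(size=k)
e_o = rng.normal(size=k) + 1j * rng.normal(size=k)

# ComplEx score: Re(<w_r, e_s, conj(e_o)>)
score = np.sum(w_r * e_s * np.conj(e_o)).real

# the same value via the four real-valued terms used in the scoring layer's expansion
expanded = (
    np.sum(w_r.real * e_s.real * e_o.real)
    + np.sum(w_r.real * e_s.imag * e_o.imag)
    + np.sum(w_r.imag * e_s.real * e_o.imag)
    - np.sum(w_r.imag * e_s.imag * e_o.real)
)
assert np.isclose(score, expanded)
```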

More future work: There's a whole class of knowledge graph algorithms that have
a similar approach (e.g. DistMult #755), just changing the details of the
embeddings (basically real or complex, matrix or vector) and the scoring
mechanism. The `KGTripleGenerator` class can be easily reused for them, and
there are likely some commonalities that we can factor out of the `ComplEx` model
to share code. In the interests of YAGNI and waiting to choose the right
generalisation, this PR does not try to generalise the `ComplEx` model yet. The
ComplEx paper has a good summary of these similar algorithms in Table 1 on page
3; it includes algorithms like TransE, DistMult and RESCAL.
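
Purely to illustrate that shared structure (not part of this PR): a DistMult-style
scorer keeps the same trilinear product but uses real-valued embeddings, so a
hypothetical Keras layer for it could look like:

```python
import tensorflow as tf
from tensorflow.keras.layers import Layer


class DistMultScore(Layer):
    """Hypothetical DistMult scorer: real-valued embeddings, same trilinear product."""

    def call(self, inputs):
        # inputs: [e_s, w_r, e_o] embeddings, each of shape (batch size, 1, k), all real
        s, r, o = inputs
        # score = <w_r, e_s, e_o> = sum_k w_r[k] * e_s[k] * e_o[k]
        return tf.reduce_sum(r * s * o, axis=2)
```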

See: #756
huonw committed Feb 19, 2020
1 parent 9454a01 commit c13a988
Showing 12 changed files with 737 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@
Some new algorithms and features are still under active development, and are available as an experimental preview. However, they may not be easy to use: their documentation or testing may be incomplete, and they may change dramatically from release to release. The experimental status is noted in the documentation and at runtime via prominent warnings.

- Watch Your Step: computes node embeddings by simulating the effect of random walks, rather than doing them. [\#750](https://github.com/stellargraph/stellargraph/pull/750). The implementation is not fully tested.
- ComplEx: computes embeddings for nodes and edge types in knowledge graphs, and uses these to perform link prediction [\#756](https://github.com/stellargraph/stellargraph/issues/756). The implementation hasn't been validated to match the paper.

### Bug fixes and other changes

3 changes: 3 additions & 0 deletions README.md
@@ -196,6 +196,7 @@ The StellarGraph library currently includes the following algorithms for graph m
| Node2Vec [2] | The Node2Vec and Deepwalk algorithms perform unsupervised representation learning for homogeneous networks, taking into account network structure while ignoring node attributes. The node2vec algorithm is implemented by combining StellarGraph's random walk generator with the word2vec algorithm from [Gensim](https://radimrehurek.com/gensim/). Learned node representations can be used in downstream machine learning models implemented using [Scikit-learn](https://scikit-learn.org/stable/), [Keras](https://keras.io/), [Tensorflow](https://www.tensorflow.org/) or any other Python machine learning library. |
| Metapath2Vec [3] | The metapath2vec algorithm performs unsupervised, metapath-guided representation learning for heterogeneous networks, taking into account network structure while ignoring node attributes. The implementation combines StellarGraph's metapath-guided random walk generator and [Gensim](https://radimrehurek.com/gensim/) word2vec algorithm. As with node2vec, the learned node representations (node embeddings) can be used in downstream machine learning models to solve tasks such as node classification, link prediction, etc, for heterogeneous networks. |
| Relational Graph Convolutional Network [11] | The RGCN algorithm performs semi-supervised learning for node representation and node classification on knowledge graphs. RGCN extends GCN to directed graphs with multiple edge types and works with both sparse and dense adjacency matrices.|
| ComplEx [12] | The ComplEx algorithm computes embeddings for nodes (entities) and edge types (relations) in knowledge graphs, and can use these for link prediction. |


## Getting Help
@@ -258,3 +259,5 @@ International Conference on Machine Learning (ICML), 2019. ([link](https://arxiv


11. Modeling relational data with graph convolutional networks. M. Schlichtkrull, T. N. Kipf, P. Bloem, R. Van Den Berg, I. Titov, and M. Welling, European Semantic Web Conference (2018), arXiv:1609.02907 ([link](https://arxiv.org/abs/1703.06103)).

12. Complex Embeddings for Simple Link Prediction. T. Trouillon, J. Welbl, S. Riedel, É. Gaussier and G. Bouchard, ICML 2016. ([link](http://jmlr.org/proceedings/papers/v48/trouillon16.pdf))
9 changes: 9 additions & 0 deletions docs/api.txt
@@ -91,6 +91,15 @@ Watch Your Step model
:members: WatchYourStep


Knowledge Graph models
----------------------

.. automodule:: stellargraph.mapper.knowledge_graph
:members: KGTripleGenerator

.. automodule:: stellargraph.layer.knowledge_graph
:members: ComplEx, ComplExScore

Link prediction layers
------------------------

17 changes: 16 additions & 1 deletion stellargraph/core/element_data.py
@@ -66,20 +66,35 @@ def is_valid(self, ilocs: np.ndarray) -> np.ndarray:
"""
return (0 <= ilocs) & (ilocs < len(self))

def to_iloc(self, ids, smaller_type=True) -> np.ndarray:
def require_valid(self, query_ids, ilocs: np.ndarray):
valid = self.is_valid(ilocs)

if not valid.all():
missing_values = np.asarray(query_ids)[~valid]

if len(missing_values) == 1:
raise KeyError(missing_values[0])

raise KeyError(missing_values)

def to_iloc(self, ids, smaller_type=True, strict=False) -> np.ndarray:
"""
Convert external IDs ``ids`` to integer locations.
Args:
ids: a collection of external IDs
smaller_type: if True, convert the ilocs to the smallest type that can hold them, to reduce storage
strict: if True, check that all IDs are known and throw a KeyError if not
Returns:
A numpy array of the integer locations for each id that exists, with missing IDs
represented by either the largest value of the dtype (if smaller_type is True) or -1 (if
smaller_type is False)
"""
internal_ids = self._index.get_indexer(ids)
if strict:
self.require_valid(ids, internal_ids)
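# e.g. with strict=True, looking up an unknown ID such as
# to_iloc(["not-a-node"], strict=True) raises KeyError("not-a-node")
# instead of returning a sentinel iloc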

# reduce the storage required (especially useful if this is going to be stored rather than
# just transient)
if smaller_type:
25 changes: 6 additions & 19 deletions stellargraph/core/graph.py
Expand Up @@ -554,15 +554,6 @@ def nodes_of_type(self, node_type=None):
ilocs = self._nodes.type_range(node_type)
return list(self._nodes.ids.from_iloc(ilocs))

def _key_error_for_missing(self, query_ids, node_ilocs):
valid = self._nodes.ids.is_valid(node_ilocs)
missing_values = np.asarray(query_ids)[~valid]

if len(missing_values) == 1:
return KeyError(missing_values[0])

return KeyError(missing_values)

def node_type(self, node):
"""
Get the type of the node
@@ -574,11 +565,8 @@ def node_type(self, node):
Node type
"""
nodes = [node]
node_ilocs = self._nodes.ids.to_iloc(nodes)
try:
type_sequence = self._nodes.type_of_iloc(node_ilocs)
except IndexError:
raise self._key_error_for_missing(nodes, node_ilocs)
node_ilocs = self._nodes.ids.to_iloc(nodes, strict=True)
type_sequence = self._nodes.type_of_iloc(node_ilocs)

assert len(type_sequence) == 1
return type_sequence[0]
@@ -672,10 +660,9 @@ def node_features(self, nodes, node_type=None):
# FIXME: None as a sentinel forces nodes to have dtype=object even with integer IDs, could
# instead use an impossible integer (e.g. 2**64 - 1)

nones = nodes == None
if not (nones | valid).all():
# every ID should be either valid or None, otherwise it was a completely unknown ID
raise self._key_error_for_missing(nodes[~nones], node_ilocs[~nones])
# everything that's not the sentinel should be valid
non_nones = nodes != None
self._nodes.ids.require_valid(nodes[non_nones], node_ilocs[non_nones])

sampled = self._nodes.features(node_type, valid_ilocs)
features = np.zeros((len(nodes), sampled.shape[1]))
@@ -978,7 +965,7 @@ def _get_index_for_nodes(self, nodes, node_type=None):
Returns:
Numpy array containing the indices for the requested nodes.
"""
return self._nodes._id_index.to_iloc(nodes)
return self._nodes._id_index.to_iloc(nodes, strict=True)

def _adjacency_types(self, graph_schema: GraphSchema):
"""
1 change: 1 addition & 0 deletions stellargraph/layer/__init__.py
@@ -35,3 +35,4 @@
from .preprocessing_layer import GraphPreProcessingLayer
from .rgcn import *
from .watch_your_step import *
from .knowledge_graph import *
193 changes: 193 additions & 0 deletions stellargraph/layer/knowledge_graph.py
@@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
#
# Copyright 2020 Data61, CSIRO
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import activations, initializers, constraints, regularizers
from tensorflow.keras.layers import Input, Layer, Lambda, Dropout, Reshape, Embedding

from ..mapper.knowledge_graph import KGTripleGenerator
from ..core.experimental import experimental


class ComplExScore(Layer):
"""
ComplEx scoring Keras layer.
Original Paper: Complex Embeddings for Simple Link Prediction, Théo Trouillon, Johannes Welbl,
Sebastian Riedel, Éric Gaussier and Guillaume Bouchard, ICML
2016. http://jmlr.org/proceedings/papers/v48/trouillon16.pdf
This combines subject, relation and object embeddings into a score of the likelihood of the
link.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def build(self, input_shape):
self.built = True

def call(self, inputs):
"""
Applies the layer.
Args:
inputs: a list of 6 tensors (each batch size x embedding dimension k), where the three
consecutive pairs represent real and imaginary parts of the subject, relation and
object embeddings, respectively, that is, ``inputs == [Re(subject), Im(subject),
Re(relation), ...]``
"""
s_re, s_im, r_re, r_im, o_re, o_im = inputs

def inner(r, s, o):
return tf.reduce_sum(r * s * o, axis=2)

# expansion of Re(<w_r, e_s, conjugate(e_o)>)
score = (
inner(r_re, s_re, o_re)
+ inner(r_re, s_im, o_im)
+ inner(r_im, s_re, o_im)
- inner(r_im, s_im, o_re)
)

return score


@experimental(
reason="results from the reference paper have not been reproduced yet", issues=[862]
)
class ComplEx:
"""
Embedding layers and a ComplEx scoring layer that implement the ComplEx knowledge graph
embedding algorithm as in http://jmlr.org/proceedings/papers/v48/trouillon16.pdf
Args:
generator (KGTripleGenerator): A generator of triples to feed into the model.
k (int): the dimension of the embedding (that is, a vector in C^k is learnt for each node
and each link type)
embeddings_initializer (str or func, optional): The initialiser to use for the embeddings
(the default of random normal values matches the paper's reference implementation).
embeddings_regularizer (str or func, optional): The regularizer to use for the embeddings.
"""

def __init__(
self,
generator,
k,
embeddings_initializer="normal",
embeddings_regularizer=None,
):
if not isinstance(generator, KGTripleGenerator):
raise TypeError(
f"generator: expected KGTripleGenerator, found {type(generator).__name__}"
)

graph = generator.G
self.num_nodes = graph.number_of_nodes()
self.num_edge_types = len(graph._edges.types)
self.k = k
self.embeddings_initializer = initializers.get(embeddings_initializer)
self.embeddings_regularizer = regularizers.get(embeddings_regularizer)

# layer names
_NODE_REAL = "COMPLEX_NODE_REAL"
_NODE_IMAG = "COMPLEX_NODE_IMAG"

_REL_REAL = "COMPLEX_EDGE_TYPE_REAL"
_REL_IMAG = "COMPLEX_EDGE_TYPE_IMAG"

@staticmethod
def embeddings(model):
"""
Retrieve the embeddings for nodes/entities and edge types/relations in the given model.
Args:
model (tensorflow.keras.Model): a Keras model created using a ``ComplEx`` instance.
Returns:
A tuple of numpy complex arrays: the first element is the embeddings for nodes/entities
(``shape = number of nodes × k``), the second element is the embeddings for edge
types/relations (``shape = number of edge types x k``).
"""
node = 1j * model.get_layer(ComplEx._NODE_IMAG).embeddings.numpy()
node += model.get_layer(ComplEx._NODE_REAL).embeddings.numpy()

rel = 1j * model.get_layer(ComplEx._REL_IMAG).embeddings.numpy()
rel += model.get_layer(ComplEx._REL_REAL).embeddings.numpy()

return node, rel

def _embed(self, count, name):
return Embedding(
count,
self.k,
name=name,
embeddings_initializer=self.embeddings_initializer,
embeddings_regularizer=self.embeddings_regularizer,
)

def __call__(self, x):
"""
Apply embedding layers to the source, relation and object input "ilocs" (sequential integer
labels for the nodes and edge types).
Args:
x (list): list of 3 tensors (each batch size x 1) storing the ilocs of the subject,
relation and object elements for each edge in the batch.
"""
s_iloc, r_iloc, o_iloc = x

# ComplEx generates embeddings in C, which we model as separate real and imaginary
# embeddings
node_embeddings_real = self._embed(self.num_nodes, self._NODE_REAL)
node_embeddings_imag = self._embed(self.num_nodes, self._NODE_IMAG)
edge_type_embeddings_real = self._embed(self.num_edge_types, self._REL_REAL)
edge_type_embeddings_imag = self._embed(self.num_edge_types, self._REL_IMAG)

s_re = node_embeddings_real(s_iloc)
s_im = node_embeddings_imag(s_iloc)

r_re = edge_type_embeddings_real(r_iloc)
r_im = edge_type_embeddings_imag(r_iloc)

o_re = node_embeddings_real(o_iloc)
o_im = node_embeddings_imag(o_iloc)

scoring = ComplExScore()

return scoring([s_re, s_im, r_re, r_im, o_re, o_im])

def build(self):
"""
Builds a ComplEx model.
Returns:
A tuple of (list of input tensors, tensor for ComplEx model score outputs)
"""
s_iloc = Input(shape=1)
r_iloc = Input(shape=1)
o_iloc = Input(shape=1)

x_inp = [s_iloc, r_iloc, o_iloc]
x_out = self(x_inp)

return x_inp, x_out
1 change: 1 addition & 0 deletions stellargraph/mapper/__init__.py
@@ -29,3 +29,4 @@
from .full_batch_generators import *
from .mini_batch_node_generators import *
from .adjacency_generators import *
from .knowledge_graph import *
