
🦜 🏴‍☠️ Implement NodePiece representation and model #621

Merged: 93 commits from node-pieces into master on Nov 15, 2021

Conversation

@mberr (Member) commented on Nov 8, 2021

This is a first draft to add NodePiece representations to pykeen.

For now, it uses a simple variant in which each entity is represented by k randomly chosen incident relations.

[Image: One Piece, Volume 61 cover (Japanese)]
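A minimal sketch of this tokenization idea, assuming the mapped triples are given as an (n, 3) LongTensor of (head, relation, tail) IDs; the function name and padding scheme are made up for illustration and do not mirror the PR's actual code:

import torch


def sample_relation_tokens(
    mapped_triples: torch.LongTensor,
    num_entities: int,
    k: int = 2,
    padding_idx: int = -1,
) -> torch.LongTensor:
    """For each entity, sample k relation IDs from the triples in which it occurs."""
    tokens = torch.full(size=(num_entities, k), fill_value=padding_idx, dtype=torch.long)
    for e in range(num_entities):
        # relations of all triples where the entity occurs as head or tail
        mask = (mapped_triples[:, 0] == e) | (mapped_triples[:, 2] == e)
        relations = mapped_triples[mask, 1]
        if relations.numel() > 0:
            # sample k incident relations (with replacement, for simplicity)
            idx = torch.randint(relations.numel(), size=(k,))
            tokens[e] = relations[idx]
    return tokens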

@mberr changed the title from "Node Piece Repreentation" to "Node Piece Representation" on Nov 8, 2021
@mberr (Member, Author) commented on Nov 8, 2021

@migalkin would be great to have your feedback, as you may be familiar with it 😉

@cthoyt (Member) commented on Nov 8, 2021

Can we have a demo on how you would use this representation with a model? Like can we easily implement a TransE with NodePiece?

@mberr (Member, Author) commented on Nov 8, 2021

> Can we have a demo on how you would use this representation with a model? Like can we easily implement a TransE with NodePiece?

Sure.

from typing import Optional

from class_resolver.api import HintOrType

from pykeen.models.nbase import ERModel
from pykeen.nn.emb import EmbeddingSpecification, NodePieceRepresentation
from pykeen.nn.modules import Interaction, TransEInteraction
from pykeen.pipeline import pipeline
from pykeen.triples.triples_factory import CoreTriplesFactory


class NodePieceModel(ERModel):
    def __init__(
        self,
        *,
        triples_factory: CoreTriplesFactory,
        embedding_specification: Optional[EmbeddingSpecification] = None,
        interaction: HintOrType[Interaction] = TransEInteraction,
        **kwargs,
    ) -> None:
        if embedding_specification is None:
            embedding_specification = EmbeddingSpecification(
                shape=(64,),
            )
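        # NodePiece representation: each entity is represented via tokens
        # (here: sampled incident relations), embedded with the given specification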
        entity_representations = NodePieceRepresentation(
            triples_factory=triples_factory,
            token_representation=embedding_specification,
        )
        super().__init__(
            triples_factory=triples_factory,
            interaction=interaction,
            entity_representations=entity_representations,
            relation_representations=embedding_specification,
            **kwargs,
        )


result = pipeline(
    dataset="nations",
    model=NodePieceModel,
    model_kwargs=dict(
        interaction_kwargs=dict(
            p=2,
        ),
    ),
)
print(result.get_metric("hits_at_10"))
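(Note: the interaction_kwargs above are forwarded through ERModel to TransEInteraction, so p=2 should select the L2 norm for the TransE score.)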

EDIT: added in 6b29c4e

@mberr (Member, Author) commented on Nov 14, 2021

fe90b42 - @cthoyt this is not really part of this PR

@@ -931,7 +931,7 @@ def test_score_t(self) -> None:
         try:
             scores = self.instance.score_t(batch)
         except NotImplementedError:
-            self.fail(msg="Score_o not yet implemented")
+            self.fail(msg="score_t not yet implemented")
@mberr (Member, Author):

this typo is not really part of the PR

@@ -950,7 +968,7 @@ def test_score_h(self) -> None:
         try:
             scores = self.instance.score_h(batch)
         except NotImplementedError:
-            self.fail(msg="Score_s not yet implemented")
+            self.fail(msg="score_h not yet implemented")
@mberr (Member, Author):

same here

@cthoyt (Member) commented on Nov 15, 2021

Looks like the issues are now with ConvE's tests

@migalkin (Member) commented:

> Looks like the issues are now with ConvE's tests

I was running the debugger on the ConvE test, and for some reason, after initialization of TestConvE(cases.ModelTestCase), it goes on to initialize NodePiece, although it is not related to the ConvE test in any way 🤔

@migalkin (Member) left a review comment:

The implementation works 🎉
Exposing the ratio param from the MLP encoder sounds like a good idea (with a default value of 2); otherwise everything looks ready!

Comment on lines 110 to 111
:func:`torch.max`, or even trainable aggregations e.g., ``MLP(mean(MLP(tokens)))``
(cf. DeepSets from [zaheer2017]_) if given value ``"mlp"``.
@migalkin (Member) commented:

The current _ConcatMLP is not DeepSets :)

The idea of DeepSets is to project each set member independently through some encoder, then aggregate (e.g., with mean), and then pass the result through another feed-forward net. It would look like this:

enc1 = nn.Sequential(
    nn.Linear(embedding_dim, embedding_dim),
    nn.ReLU(),
    nn.Linear(embedding_dim, embedding_dim)
)

enc2 = nn.Sequential(nn.Linear(embedding_dim, embedding_dim), nn.ReLU(), nn.Linear(embedding_dim, output_dim))

and in the forward pass:

# x: shape (bs, num_elements, embedding_dim)
x = enc1(x)         # same shape (bs, num_elements, embedding_dim)
x = x.mean(dim=-2)  # aggregation over the set dimension -> (bs, embedding_dim)
x = enc2(x)         # final projection -> (bs, output_dim)

It can be added as an option along with mlp though
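For completeness, a self-contained sketch of such a DeepSets-style aggregation packaged as a module might look like the following (illustrative only; the class name is made up and this is not the _ConcatMLP from this PR):

import torch
from torch import nn


class DeepSetsAggregation(nn.Module):
    """Encode each token independently, aggregate, then project (cf. DeepSets)."""

    def __init__(self, embedding_dim: int, output_dim: int):
        super().__init__()
        # per-token encoder, applied to each set element independently
        self.enc1 = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim),
        )
        # encoder applied after the permutation-invariant aggregation
        self.enc2 = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch_size, num_tokens, embedding_dim)
        x = self.enc1(x)     # (batch_size, num_tokens, embedding_dim)
        x = x.mean(dim=-2)   # aggregate over tokens -> (batch_size, embedding_dim)
        return self.enc2(x)  # (batch_size, output_dim)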

@mberr (Member, Author) replied:

The correct docstring somehow got lost during the refactoring 😅

here it was still correct: #621 (comment)

@mberr (Member, Author) commented on Nov 15, 2021

> Looks like the issues are now with ConvE's tests

I think the issue is that we have not yet thought about what should be scored in score_r if we have inverse relations. In the baseline implementation in _OldAbstractModel, we use relation_ids to create the individual hrt triples, and these IDs contain only the "real" relations, not the inverse ones.

We could either:

  1. skip the score_r test for models with inverse triples
  2. decide on how to handle this generally

To keep this PR focused on one thing, I tend towards option 1.

@cthoyt (Member) commented on Nov 15, 2021

Yes, let's bump this. So let's override the test in ConvE to be skipped and leave a TODO for later.
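A rough sketch of such an override (the method name test_score_r and the import path are assumptions based on the test names quoted above; the actual test suite may differ):

import unittest

from tests import cases  # assumed import path for the shared test cases


class TestConvE(cases.ModelTestCase):
    """ConvE tests; only the relevant override is shown."""

    @unittest.skip("TODO: clarify what score_r should return when inverse relations are used")
    def test_score_r(self) -> None:
        ...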

Commit: "it is not yet clear what would be the desired output shape"
@mberr marked this pull request as ready for review on November 15, 2021, 10:08
@cthoyt (Member) left a review comment:

:shipit:

@mberr merged commit 9837077 into master on Nov 15, 2021
@mberr deleted the node-pieces branch on November 15, 2021, 12:04