💃 🥤 Extract interaction function from models #107
TODO: What to do with interaction models that use more than one vector per entity or relation, e.g. TransH?
@cthoyt I used the interaction function abstraction to extract a common base class for DistMult, ComplEx and ER-MLP, since they all share that setup (one vector per entity and relation).
The class name is preliminary.
@mberr
To quote from this paper 😉
The idea is to separate storing the entity/relation representations from the interaction. This supports use cases where we want to use an interaction model as a component rather than as a stand-alone model, e.g. because we have features for entities, or another component that generates representations (e.g. based on GNNs, but we could also think about NLP models on the entities' labels, etc.). In the long term, this might also help to generalize, e.g., the regularization choices (right now, each Model defines which parameters are regularized and which are not), or constraints such as unit length, orthogonality, etc. It does, however, also help for R-GCN 🙂
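A minimal sketch of that separation, assuming a representation is anything that maps an index to a vector and an interaction is a plain function of (h, r, t). All class names are hypothetical; the feature-based module merely stands in for a GNN or text encoder.

```python
from typing import Callable, Dict, List, Sequence

Vector = List[float]
Representation = Callable[[int], Vector]
Interaction = Callable[[Sequence[float], Sequence[float], Sequence[float]], float]


def distmult(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """Stand-alone interaction function: trilinear product sum_i h_i * r_i * t_i."""
    return sum(a * b * c for a, b, c in zip(h, r, t))


class LookupRepresentation:
    """Hypothetical representation: one free vector per index (like an embedding table)."""

    def __init__(self, table: Dict[int, Vector]):
        self.table = table

    def __call__(self, index: int) -> Vector:
        return self.table[index]


class FeatureRepresentation:
    """Hypothetical representation derived from fixed entity features.

    The scaling is a placeholder for a real encoder such as a GNN or a text model.
    """

    def __init__(self, features: Dict[int, Vector], scale: float):
        self.features = features
        self.scale = scale

    def __call__(self, index: int) -> Vector:
        return [self.scale * x for x in self.features[index]]


class InteractionModel:
    """Combines separately stored representations with a stand-alone interaction."""

    def __init__(self, entities: Representation, relations: Representation, interaction: Interaction):
        self.entities = entities
        self.relations = relations
        self.interaction = interaction

    def score_hrt(self, h: int, r: int, t: int) -> float:
        return self.interaction(self.entities(h), self.relations(r), self.entities(t))


# Same interaction and relations, two different sources of entity vectors.
relations = LookupRepresentation({0: [1.0, 2.0]})
entity_vectors = {0: [1.0, 1.0], 1: [2.0, 0.5]}
free_model = InteractionModel(LookupRepresentation(entity_vectors), relations, distmult)
feature_model = InteractionModel(FeatureRepresentation(entity_vectors, scale=2.0), relations, distmult)
```

The interaction is identical in both models; only the entity representation is swapped, which is exactly the "interaction model as a component" use case.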
* Add base module for representations
* Add reset_parameters
* Add wrapper for nn.Embedding
* Implement forward
* Add properties
* Use Embedding for EntityRelationEmbedding
* Replace old factory method
* Reorganize
* Revert for now
* Make initialization functions work
* Update utils.py
* Reorganize again
* Fix docs
* Update base.py
* Fix initialization
* Update RepresentationModel.forward API
* Add deprecation warning
* Add post_parameter_update
* Add constrainer
* Forward kwargs from EntityRelationEmbeddingModel
* Use Embedding.init_with_device for EntityEmbeddingModel
* Use new-style Embeddings for DistMult
* Fix delegation of post_parameter_update
* Update code style
* Add unittest to check whether the model can handle custom representations
* Pass flake8
* New implementations of init functions
* Fix attribute names
* Add first round of upgrades for _reset_parameters_
* Switch more over
* Update utils.py
* Add TODOs [skip ci]
* Make more general
* Enable more kwargs
* Fix up KG2E
* Fix unittest
* Simplify get_embedding_in_canonical_shape
* Directly use Embedding.init_with_device
* Replace deprecated usage for ConvE
* Use new-style embeddings for TransD
* Add in-place norm clamping
* Fix unittest for TransD; remove manual score comparison
* Fix docs
* Forward kwargs
* More generic types
* More generic types, now as TypeVar
* rename parameters
* fix renamed parameters
* Add comments
* Remove deprecated .weight property
* Fix ConvE's weight usage
* Fix .weight usage in test_save_load_model_state
* reduce code duplication in test
* do not require in-place initializer
* Remove manual unittest for TransR
* Fix .weight usage in test_models.py
* Remove post_parameter_update tests after embeddings have been exchanged
* Add TODO to complex
* use indices
* Remove no_grad annotation, since it was causing problems with general gradient tracking
* cleanup conve
* cleanup convkb
* cleanup distmult
* Use keyword parameter, fix .weight
* Use keyword parameter, fix .weight
* Add constrainer, fix .weight
* fix .weight
* fix .weight
* fix .weight, use initializer and constrainer
* Directly use Embedding.init_with_device; fix type annotation
* use new-style embeddings
* use *_kwargs instead of functools.partial
* Use new-style embeddings in TransE
* Use new-style embeddings in TransR
* Use new-style embeddings in Tucker
* Use new-style embeddings in UM
* Add chain_ and normalize_ utilities
* Fix chain_
* Fix complex special case for hrt
* fix custom representation unittest
* Add missing calls to reset_parameters
* Fix tests for get_embedding_in_canonical_shape
* Use magicmock to check whether reset_parameters was called
* hotfix RGCN nice fixes need #107 #110
* Fix block decomposition
* Fix TransR
* Add trailing comma
* Fix line too long
* Add trailing comma
* Fix line-too-long
* Fix wrong comparison
* Remove trailing comma
* fix line too long
* fix line too long
* fix line too long
* Cosmetic improvements and remove code graveyard
* Update docs [skip ci]
* Extend documentation
* Remove in-place variants
* remove xavier_uniform_normed_
* Remove todos
* fix line-too-long
* Fix docstring
* Remove unused imports
* Revert removal of manual tests for score_hrt
* Avoid overriding test_score_hrt with manual tests
* fix manual test_hrt for TransR
* fix manual test_hrt for TransD
* Fix issue in TransR test

Co-authored-by: Charles Tapley Hoyt <cthoyt@gmail.com>
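Several of the commits above (reset_parameters, post_parameter_update, constrainer, norm clamping) revolve around one pattern: an embedding wrapper that owns an initializer and an optional constrainer. Below is a dependency-free sketch of that pattern; SketchEmbedding and clamp_norm are hypothetical names, not the classes added in this PR, and the constrainer here simply rescales rows to at most unit L2 norm.

```python
import math
from typing import Callable, List, Optional

Matrix = List[List[float]]
Initializer = Callable[[int, int], Matrix]
Constrainer = Callable[[Matrix], Matrix]


def clamp_norm(rows: Matrix, max_norm: float = 1.0) -> Matrix:
    """Rescale each row whose L2 norm exceeds max_norm onto the max_norm sphere."""
    clamped = []
    for row in rows:
        norm = math.sqrt(sum(x * x for x in row))
        scale = max_norm / norm if norm > max_norm else 1.0
        clamped.append([x * scale for x in row])
    return clamped


class SketchEmbedding:
    """Hypothetical embedding wrapper owning an initializer and an optional constrainer."""

    def __init__(self, num: int, dim: int, initializer: Initializer, constrainer: Optional[Constrainer] = None):
        self.num = num
        self.dim = dim
        self.initializer = initializer
        self.constrainer = constrainer
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # (Re-)create the weights from the initializer.
        self.weight = self.initializer(self.num, self.dim)

    def post_parameter_update(self) -> None:
        # Intended to be called after each optimizer step to re-establish the constraint.
        if self.constrainer is not None:
            self.weight = self.constrainer(self.weight)

    def __call__(self, indices: List[int]) -> Matrix:
        # Look up rows by index, in the order given.
        return [self.weight[i] for i in indices]
```

Keeping initialization and constraints inside the representation means a model no longer needs to know which of its parameters are normalized.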
@cthoyt I extracted a few interaction functions in functional form (cf. https://github.com/pykeen/pykeen/compare/a735593..f0c20d6). I will wait for feedback before I continue with the other interaction functions. It is nice, since we could, e.g., extract some common parts between DistMult and ComplEx in
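To illustrate the kind of shared part meant here (a sketch, not the code behind the compare link): the ComplEx score Re(sum_i h_i * r_i * conj(t_i)) can be rewritten as four real-valued DistMult-style trilinear products over real and imaginary parts. The function names are illustrative, and plain Python lists stand in for tensors.

```python
from typing import Sequence


def distmult_interaction(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """Real-valued trilinear product sum_i h_i * r_i * t_i."""
    return sum(a * b * c for a, b, c in zip(h, r, t))


def complex_interaction(h: Sequence[complex], r: Sequence[complex], t: Sequence[complex]) -> float:
    """Re(sum_i h_i * r_i * conj(t_i)), expressed through four DistMult terms."""
    h_re, h_im = [x.real for x in h], [x.imag for x in h]
    r_re, r_im = [x.real for x in r], [x.imag for x in r]
    t_re, t_im = [x.real for x in t], [x.imag for x in t]
    return (
        distmult_interaction(h_re, r_re, t_re)
        + distmult_interaction(h_im, r_re, t_im)
        + distmult_interaction(h_re, r_im, t_im)
        - distmult_interaction(h_im, r_im, t_re)
    )


def complex_reference(h: Sequence[complex], r: Sequence[complex], t: Sequence[complex]) -> float:
    """Direct definition, only used to check the decomposition above."""
    return sum(a * b * c.conjugate() for a, b, c in zip(h, r, t)).real
```

With this decomposition, ComplEx only needs the DistMult kernel plus a split into real and imaginary parts.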
cls = pykeen.nn.modules.SimplEInteraction

def _exp_score(self, h, r, t, h_inv, r_inv, t_inv, clamp) -> torch.FloatTensor:
Is this right? I thought SimplE's interaction function was
0.5 * (distmult_interaction(h, r, t) + distmult_interaction(t_inv, r_inv, h_inv))
The h and t get switched for the inverse one.
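A quick check that may help settle this: the DistMult trilinear product is symmetric in its first and third argument, so distmult_interaction(t_inv, r_inv, h_inv) and distmult_interaction(h_inv, r_inv, t_inv) give the same value. The swap can therefore only matter through which embeddings h_inv and t_inv are looked up from (in SimplE, the inverse term pairs the head-role vector of the tail entity with the tail-role vector of the head entity). The sketch uses plain lists; treating clamp as a (low, high) pair is an assumption about the test helper's parameter.

```python
from typing import Optional, Sequence, Tuple


def distmult_interaction(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """Trilinear product sum_i h_i * r_i * t_i; symmetric in h and t."""
    return sum(a * b * c for a, b, c in zip(h, r, t))


def simple_interaction(
    h: Sequence[float],
    r: Sequence[float],
    t: Sequence[float],
    h_inv: Sequence[float],
    r_inv: Sequence[float],
    t_inv: Sequence[float],
    clamp: Optional[Tuple[float, float]] = None,
) -> float:
    """Average of a forward and an inverse DistMult term, optionally clamped to [low, high]."""
    score = 0.5 * (distmult_interaction(h, r, t) + distmult_interaction(t_inv, r_inv, h_inv))
    if clamp is not None:
        # Assumed semantics: clamp is a (low, high) pair.
        low, high = clamp
        score = min(max(score, low), high)
    return score
```

So the question is really what h_inv and t_inv denote in _exp_score, not the argument order inside the DistMult call.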
@cthoyt Did we finish the migration? If yes, we may close this one.
@mberr Well, we have the major code for RGCN (might be fine now), the literal models (already in the new PR), and examples of the new-style models for all of the pre-existing uni-modal models.
Oops, my bad, my hand slipped.
@mberr So all of the big parts have been migrated. If we ever want to start converting old-style models, we can refer to this, but it's finally time to let this one go.
This is a replacement for #88, where the merge target is master.

Overview
* score_hrt/score_h/score_r/score_t in the base class, done in InteractionFunction._score for all models sharing the same set of embeddings (e.g. TransE/DistMult/ERMLP -> one vector for each entity/relation, TransH -> an additional vector for each relation, etc.)
* pykeen.nn.modules.
* pykeen.nn.functional.
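To make the first overview item concrete, here is a dependency-free sketch of one generic scoring routine behind four thin wrappers. The class name SketchModel and the None-means-all-candidates convention are assumptions for illustration; a real implementation would presumably work on batched tensors with broadcasting rather than Python loops.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]
Interaction = Callable[[Sequence[float], Sequence[float], Sequence[float]], float]


def distmult_interaction(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """Trilinear product sum_i h_i * r_i * t_i."""
    return sum(a * b * c for a, b, c in zip(h, r, t))


class SketchModel:
    """Hypothetical base class: one generic _score behind four thin public methods."""

    def __init__(self, entities: List[Vector], relations: List[Vector], interaction: Interaction):
        self.entities = entities
        self.relations = relations
        self.interaction = interaction

    def _score(self, h: Optional[int], r: Optional[int], t: Optional[int]) -> List[float]:
        # A slot set to None ranges over all candidates of that slot.
        hs = range(len(self.entities)) if h is None else [h]
        rs = range(len(self.relations)) if r is None else [r]
        ts = range(len(self.entities)) if t is None else [t]
        return [
            self.interaction(self.entities[i], self.relations[j], self.entities[k])
            for i in hs
            for j in rs
            for k in ts
        ]

    def score_hrt(self, h: int, r: int, t: int) -> float:
        return self._score(h, r, t)[0]

    def score_h(self, r: int, t: int) -> List[float]:
        return self._score(None, r, t)

    def score_r(self, h: int, t: int) -> List[float]:
        return self._score(h, None, t)

    def score_t(self, h: int, r: int) -> List[float]:
        return self._score(h, r, None)
```

Each model then only has to provide its interaction; swapping distmult_interaction for another function with the same signature changes the model without touching the scoring code.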

Tasks:
* Interaction
* [ ] Update pipeline model composition (bumped to 🪢 🤗 Expose interactions and representations via pipeline #163)

Dependencies: