-
-
Notifications
You must be signed in to change notification settings - Fork 179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
✨🤖 Update RotatE to ERModel #877
Conversation
trigger ci
@@ -611,7 +611,7 @@ def rotate_interaction( | |||
The scores. | |||
""" | |||
# r expresses a rotation in complex plane. | |||
h, r, t = [view_complex(x) for x in (h, r, t)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are we ready to replace this function with the torch builtin?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
except in testing and dead code, we do not seem to reference this function anywhere else
trigger ci
src/pykeen/nn/representation.py
Outdated
@@ -347,7 +347,7 @@ def __init__( | |||
# work-around until full complex support (torch==1.10 still does not work) | |||
# TODO: verify that this is our understanding of complex! | |||
if dtype.is_complex: | |||
shape = tuple(shape[:-1]) + (2 * shape[-1],) | |||
shape = tuple(shape[:-1]) + (shape[-1], 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is the shape torch.view_as_complex expects, i.e., (..., 2)
@@ -661,11 +661,6 @@ def negative_norm( | |||
assert not isinstance(p, str) | |||
return -(x.abs() ** p).sum(dim=-1) | |||
|
|||
if torch.is_complex(x): | |||
assert not isinstance(p, str) | |||
# workaround for complex numbers: manually compute norm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not necessary anymore in current torch
trigger ci
trigger ci
trigger ci
This PR updates
RotatE
to new-styleERModel
.It also update
Embedding
to better encapsulate complex embeddings, and updates the interaction functions of ComplEx and RotatE to operate on PyTorch native complex tensors.Benchmark
TL;DR: slightly slower evaluation (new
27.04s
vs26.74s
); more than twice as fast in training (new39.77s
vs.91.72s
)Quadro RTX 8000
pykeen experiments run sun2019_rotate_fb15k237.json
with the configuration being equal to src/pykeen/experiments/rotate/sun2019_rotate_fb15k237.json @ f47ee602, except that the number of training epochs has been reduced to
2
.master
)update-rotate-to-ermodel
)*
: withbatch_size=256
**
: withbatch_size=128