
Training loop does not update relation representations when continuing training #1357

Open · 3 tasks done
lukas-schwab opened this issue Jan 11, 2024 · 2 comments
Labels: bug (Something isn't working)

lukas-schwab commented Jan 11, 2024

Describe the bug

When looking at the representations of TransE before and after training multiple times via the training loop with continue_training=True, only the entity representation changes; the relation representation stays the same. I would expect the relation representation to be updated just like the entity representation.

I believe this problem might exist for other models as well (e.g. DistMult).

How to reproduce

from pykeen.datasets import Nations
from pykeen.models import TransE, DistMult
from pykeen.training import LCWATrainingLoop, SLCWATrainingLoop
from torch.optim import Adam
from tqdm import tqdm

reps = []

training_triples_factory = Nations().training

model = TransE(
    triples_factory=training_triples_factory,
    embedding_dim=2,
    #random_seed=1235,
)

optimizer = Adam(params=model.get_grad_params())
training_loop = SLCWATrainingLoop(
    model=model,
    triples_factory=training_triples_factory,
    optimizer=optimizer,
)

_ = training_loop.train(
    triples_factory=training_triples_factory,
    batch_size=32,
    num_epochs=1,
    use_tqdm=False,
    use_tqdm_batch=False
)

n = 10

for i in tqdm(range(1, n)):
    # Continue training only seems to work for entity embeddings; relation
    # embeddings don't change when using continue_training. This might be a bug in PyKEEN.

    loss = training_loop.train(
        triples_factory=training_triples_factory,
        num_epochs=i,
        batch_size=32,
        continue_training=True,
        use_tqdm=False,
        use_tqdm_batch=False
    )

    # save the entity and relation representations after each call
    reps.append(
        (
            model.entity_representations[0](indices=None).detach().numpy(),
            model.relation_representations[0](indices=None).detach().numpy(),
        )
    )

If we now look at the saved relation representations in reps, we can see that they do not change over time; they are always equal to the first representation in the list. You can verify this by comparing the relation representations like this:

print(reps[0][1] == [reps[i][1] for i in range(1, len(reps))])

All of the entries will be True, whereas we would expect to see differences and therefore False for most of them. Doing the same for the entity representations shows what it should look like:

print(reps[0][0] == [reps[i][0] for i in range(1, len(reps))])
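
For a more explicit check, one could also compare the snapshots with NumPy directly (a minimal sketch, assuming reps was filled by the loop above):

import numpy as np

# relation representations: identical across all snapshots (the reported bug)
print(all(np.array_equal(reps[0][1], r[1]) for r in reps[1:]))  # True

# entity representations: these do change between snapshots
print(all(np.array_equal(reps[0][0], r[0]) for r in reps[1:]))  # False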

Environment

OS: posix
Platform: Linux
Release: 5.15.0-91-generic
Time: Thu Jan 11 10:36:10 2024
Python: 3.11.4
PyKEEN: 1.10.1
PyKEEN Hash: UNHASHED
PyKEEN Branch:
PyTorch: 2.1.2+cu121
CUDA Available?: true
CUDA Version: 12.1
cuDNN Version: 8902

Additional information

No response

Issue Template Checks

  • This is not a feature request (use a different issue template if it is)
  • This is not a question (use the discussions forum instead)
  • I've read the text explaining why including environment information is important and understand if I omit this information that my issue will be dismissed
lukas-schwab added the bug (Something isn't working) label on Jan 11, 2024
mberr (Member) commented Jan 12, 2024

Hi @lukas-schwab,

thanks for reporting the issue. I can reproduce it locally, although I have not yet had the time to dive deeper into why this happens.

Interestingly, it only seems to happen with two separate .train calls; if I record weights over multiple epochs of a single training run, they change (as expected):

from collections import defaultdict
from typing import Any

import torch

from pykeen.pipeline import pipeline
from pykeen.training.callbacks import TrainingCallback


class WeightRecorderCallback(TrainingCallback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weights = defaultdict(list)

    def post_epoch(self, epoch: int, epoch_loss: float, **kwargs: Any) -> None:
        for name, tensor in self.model.named_parameters():
            self.weights[name].append(tensor.detach().clone().cpu())


callback = WeightRecorderCallback()
result = pipeline(dataset="nations", model="Transe", training_kwargs=dict(callbacks=[callback]))
print(
    {
        key: [torch.allclose(weights[0], weights[i]) for i in range(len(weights))]
        for key, weights in callback.weights.items()
    }
)
# {
#   'entity_representations.0._embeddings.weight': [True, False, False, False, False], 
#   'relation_representations.0._embeddings.weight': [True, False, False, False, False]
# }
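
For contrast, a minimal sketch of the two-call case from the issue (continuing the reproduction in the first comment; the expected output is based on the reported behaviour, not verified here):

import torch

# snapshot all parameters after the first .train call
before = {name: p.detach().clone() for name, p in model.named_parameters()}

# continue training on the same loop
training_loop.train(
    triples_factory=training_triples_factory,
    num_epochs=5,
    batch_size=32,
    continue_training=True,
    use_tqdm=False,
)

# compare against the snapshots
for name, p in model.named_parameters():
    print(name, torch.allclose(before[name], p.detach()))
# entity_representations.0._embeddings.weight False    (updated)
# relation_representations.0._embeddings.weight True   (unchanged -> the reported bug)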

lukas-schwab (Author) commented:

> Interestingly, it only seems to happen with two separate .train calls

Yes, it's quite strange. It took me a while to become convinced that the library was at fault here and not me.

Fortunately, the code you provided is exactly the solution to what I was actually trying to achieve, so thank you for posting it and keep up the good work!
