How to upgrade PyKEEN<1.8.0 code that uses EmbeddingSpecification? #858

Closed
thtang opened this issue Apr 4, 2022 · 19 comments · Fixed by #861
Labels
question Further information is requested

Comments

@thtang

thtang commented Apr 4, 2022

Describe the bug

It seems that EmbeddingSpecification is no longer under pykeen.nn.representation.

How to reproduce

from pykeen.nn.representation import EmbeddingSpecification

Environment

PyKEEN | 1.8.0

Additional information

No response

@thtang thtang added the bug Something isn't working label Apr 4, 2022
@cthoyt
Member

cthoyt commented Apr 4, 2022

The EmbeddingSpecification was removed in PyKEEN 1.8.0. You can now just use a dictionary directly in places where you might have used EmbeddingSpecification. If you can post a full code example, we could try to help you adapt it. Otherwise, I will close this issue.

@cthoyt cthoyt removed the bug Something isn't working label Apr 4, 2022
@cthoyt cthoyt changed the title Where to import EmbeddingSpecification How to upgrade PyKEEN<1.8.0 code that uses EmbeddingSpecification? Apr 4, 2022
@cthoyt cthoyt added the question Further information is requested label Apr 4, 2022
@thtang
Author

thtang commented Apr 4, 2022

Hi, here is my full code example:

from pykeen.pipeline import pipeline
from pykeen.triples import TriplesFactory
from pykeen.nn.representation import EmbeddingSpecification, LabelBasedTransformerRepresentation
from pykeen.models import ERModel


training = TriplesFactory.from_path('./train_global_tree.txt')
testing = TriplesFactory.from_path(
    './test_global_tree.txt',
    entity_to_id=training.entity_to_id,
    relation_to_id=training.relation_to_id,
)

entity_representations = LabelBasedTransformerRepresentation.from_triples_factory(
    triples_factory=training,
)
result = pipeline(
    training=training,
    testing=testing,
    model=ERModel,
    model_kwargs=dict(
        interaction="ermlpe",
        interaction_kwargs=dict(
            embedding_dim=entity_representations.embedding_dim,
        ),
        entity_representations=entity_representations,
        relation_representations=EmbeddingSpecification(
            shape=entity_representations.shape,
        ),
    ),
    training_kwargs=dict(
        num_epochs=1,
    ),
)
model = result.model

@thtang
Author

thtang commented Apr 4, 2022

Moreover, I get AttributeError: 'dict' object has no attribute 'max_id' if I replace EmbeddingSpecification with a dict.

Many thanks for your help!

@cthoyt
Member

cthoyt commented Apr 4, 2022

So inside model_kwargs, you want to use relation_representations_kwargs. You don't have to pass relation_representations at all, since it defaults to the normal pykeen.nn.Embedding; alternatively, you can explicitly pass None or even "embedding" directly:

        ...
        relation_representations=None,
        relation_representations_kwargs=dict(
            shape=entity_representations.shape,
        ),
        ...
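A toy sketch of why this works while passing a plain dict as relation_representations does not. It assumes that, like many PyKEEN APIs built on the class_resolver package, the representation argument accepts None, a name, a class, or an already-built instance, with constructor arguments supplied separately; anything that is not None, a string, or a class is treated as a finished instance and returned unchanged. ToyEmbedding, REGISTRY, and make_representation are hypothetical stand-ins, not PyKEEN code.

```python
from typing import Any, Dict, Optional, Tuple


class ToyEmbedding:
    """Hypothetical stand-in for a representation module such as pykeen.nn.Embedding."""

    def __init__(self, max_id: int, shape: Tuple[int, ...]) -> None:
        self.max_id = max_id  # number of ids (e.g. relations) to represent
        self.shape = shape  # shape of a single representation


# Hypothetical registry mapping lower-case names to representation classes.
REGISTRY = {"embedding": ToyEmbedding}


def make_representation(query: Any, max_id: int, kwargs: Optional[Dict[str, Any]] = None) -> Any:
    """Turn None / a name / a class into an instance; pass anything else through unchanged.

    `max_id` mimics the model filling in the number of relations itself (an assumption).
    """
    kwargs = dict(kwargs or {})
    if query is None:
        return ToyEmbedding(max_id=max_id, **kwargs)  # default class
    if isinstance(query, str):
        return REGISTRY[query.lower()](max_id=max_id, **kwargs)
    if isinstance(query, type):
        return query(max_id=max_id, **kwargs)
    return query  # treated as an already-built representation


# New style: the default class (None) plus a separate kwargs dict.
relation_repr = make_representation(None, max_id=3, kwargs=dict(shape=(768,)))
print(relation_repr.max_id, relation_repr.shape)

# Passing the dict itself as the representation is taken as a finished instance ...
wrong = make_representation(dict(shape=(768,)), max_id=3)
try:
    _ = wrong.max_id  # ... so attribute access fails later on
except AttributeError as error:
    print(error)  # 'dict' object has no attribute 'max_id'
```

In this toy model, the dict only ever reaches the constructor when it is passed through the separate kwargs argument, which is the role relation_representations_kwargs plays in the snippet above.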

@thtang
Author

thtang commented Apr 4, 2022

Got it! Now I get the error UnboundLocalError: local variable 'batch' referenced before assignment when running the code below:

from pykeen.pipeline import pipeline
from pykeen.triples import TriplesFactory
from pykeen.nn.representation import LabelBasedTransformerRepresentation
from pykeen.models import ERModel


training = TriplesFactory.from_path('./train_global_tree.txt')
testing = TriplesFactory.from_path(
    './test_global_tree.txt',
    entity_to_id=training.entity_to_id,
    relation_to_id=training.relation_to_id,
)

entity_representations = LabelBasedTransformerRepresentation.from_triples_factory(
    triples_factory=training,
)
result = pipeline(
    training=training,
    testing=testing,
    model=ERModel,
    model_kwargs=dict(
        interaction="ermlpe",
        interaction_kwargs=dict(
            embedding_dim=entity_representations.embedding_dim,
        ),
        entity_representations=entity_representations,
        relation_representations=None,
        relation_representations_kwargs=dict(
            shape=entity_representations.shape,
        )
    ),
    training_kwargs=dict(
        num_epochs=1,
    ),
)
model = result.model

Stack trace:

UnboundLocalError                         Traceback (most recent call last)
/tmp/ipykernel_124068/2522147645.py in <module>
     15     triples_factory=training,
     16 )
---> 17 result = pipeline(
     18     training=training,
     19     testing=testing,

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/pipeline/api.py in pipeline(dataset, dataset_kwargs, training, testing, validation, evaluation_entity_whitelist, evaluation_relation_whitelist, model, model_kwargs, interaction, interaction_kwargs, dimensions, loss, loss_kwargs, regularizer, regularizer_kwargs, optimizer, optimizer_kwargs, clear_optimizer, lr_scheduler, lr_scheduler_kwargs, training_loop, training_loop_kwargs, negative_sampler, negative_sampler_kwargs, epochs, training_kwargs, stopper, stopper_kwargs, evaluator, evaluator_kwargs, evaluation_kwargs, result_tracker, result_tracker_kwargs, metadata, device, random_seed, use_testing_data, evaluation_fallback, filter_validation_when_testing, use_tqdm)
   1235     # Train like Cristiano Ronaldo
   1236     training_start_time = time.time()
-> 1237     losses = training_loop_instance.train(
   1238         triples_factory=training,
   1239         stopper=stopper_instance,

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/training_loop.py in train(self, triples_factory, num_epochs, batch_size, slice_size, label_smoothing, sampler, continue_training, only_size_probing, use_tqdm, use_tqdm_batch, tqdm_kwargs, stopper, sub_batch_size, num_workers, clear_optimizer, checkpoint_directory, checkpoint_name, checkpoint_frequency, checkpoint_on_failure, drop_last, callbacks, callback_kwargs, gradient_clipping_max_norm, gradient_clipping_norm_type, gradient_clipping_max_abs_value, pin_memory)
    340             # send model to device before going into the internal training loop
    341             self.model = self.model.to(get_preferred_device(self.model, allow_ambiguity=True))
--> 342             result = self._train(
    343                 num_epochs=num_epochs,
    344                 batch_size=batch_size,

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/training_loop.py in _train(self, triples_factory, num_epochs, batch_size, slice_size, label_smoothing, sampler, continue_training, only_size_probing, use_tqdm, use_tqdm_batch, tqdm_kwargs, stopper, sub_batch_size, num_workers, save_checkpoints, checkpoint_path, checkpoint_frequency, checkpoint_on_failure_file_path, best_epoch_model_file_path, last_best_epoch, drop_last, callbacks, callback_kwargs, gradient_clipping_max_norm, gradient_clipping_norm_type, gradient_clipping_max_abs_value, pin_memory)
    485         ):
    486             # return the relevant parameters slice_size and batch_size
--> 487             sub_batch_size, slice_size = self.sub_batch_and_slice(
    488                 batch_size=batch_size, sampler=sampler, triples_factory=triples_factory
    489             )

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/training_loop.py in sub_batch_and_slice(self, batch_size, sampler, triples_factory)
    899     ) -> Tuple[int, Optional[int]]:
    900         """Check if sub-batching and/or slicing is necessary to train the model on the hardware at hand."""
--> 901         sub_batch_size, finished_search, supports_sub_batching = self._sub_batch_size_search(
    902             batch_size=batch_size,
    903             sampler=sampler,

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/training_loop.py in _sub_batch_size_search(self, batch_size, sampler, triples_factory)
    974             self._free_graph_and_cache()
    975             logger.debug(f"Trying batch_size {batch_size} for training now.")
--> 976             self._train(
    977                 triples_factory=triples_factory,
    978                 num_epochs=1,

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/training_loop.py in _train(self, triples_factory, num_epochs, batch_size, slice_size, label_smoothing, sampler, continue_training, only_size_probing, use_tqdm, use_tqdm_batch, tqdm_kwargs, stopper, sub_batch_size, num_workers, save_checkpoints, checkpoint_path, checkpoint_frequency, checkpoint_on_failure_file_path, best_epoch_model_file_path, last_best_epoch, drop_last, callbacks, callback_kwargs, gradient_clipping_max_norm, gradient_clipping_norm_type, gradient_clipping_max_abs_value, pin_memory)
    633                     evaluated_once = True
    634 
--> 635                 del batch
    636                 del batches
    637                 gc.collect()

UnboundLocalError: local variable 'batch' referenced before assignment

@thtang
Author

thtang commented Apr 4, 2022

On the other hand, if I try to reproduce the example code from the documentation with the above modifications:

from pykeen.pipeline import pipeline
from pykeen.datasets import get_dataset
from pykeen.nn.representation import LabelBasedTransformerRepresentation
from pykeen.models import ERModel

dataset = get_dataset(dataset="nations")
entity_representations = LabelBasedTransformerRepresentation.from_triples_factory(
    triples_factory=dataset.training,
)
result = pipeline(
    dataset=dataset,
    model=ERModel,
    model_kwargs=dict(
        interaction="ermlpe",
        interaction_kwargs=dict(
            embedding_dim=entity_representations.embedding_dim,
        ),
        entity_representations=entity_representations,
        relation_representations=None,
        relation_representations_kwargs=dict(
            shape=entity_representations.shape,
        ),
    ),
    training_kwargs=dict(
        num_epochs=1,
    ),
)
model = result.model

It will raise RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_124068/201939963.py in <module>
      8     triples_factory=dataset.training,
      9 )
---> 10 result = pipeline(
     11     dataset=dataset,
     12     model=ERModel,

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/pipeline/api.py in pipeline(dataset, dataset_kwargs, training, testing, validation, evaluation_entity_whitelist, evaluation_relation_whitelist, model, model_kwargs, interaction, interaction_kwargs, dimensions, loss, loss_kwargs, regularizer, regularizer_kwargs, optimizer, optimizer_kwargs, clear_optimizer, lr_scheduler, lr_scheduler_kwargs, training_loop, training_loop_kwargs, negative_sampler, negative_sampler_kwargs, epochs, training_kwargs, stopper, stopper_kwargs, evaluator, evaluator_kwargs, evaluation_kwargs, result_tracker, result_tracker_kwargs, metadata, device, random_seed, use_testing_data, evaluation_fallback, filter_validation_when_testing, use_tqdm)
   1235     # Train like Cristiano Ronaldo
   1236     training_start_time = time.time()
-> 1237     losses = training_loop_instance.train(
   1238         triples_factory=training,
   1239         stopper=stopper_instance,

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/training_loop.py in train(self, triples_factory, num_epochs, batch_size, slice_size, label_smoothing, sampler, continue_training, only_size_probing, use_tqdm, use_tqdm_batch, tqdm_kwargs, stopper, sub_batch_size, num_workers, clear_optimizer, checkpoint_directory, checkpoint_name, checkpoint_frequency, checkpoint_on_failure, drop_last, callbacks, callback_kwargs, gradient_clipping_max_norm, gradient_clipping_norm_type, gradient_clipping_max_abs_value, pin_memory)
    340             # send model to device before going into the internal training loop
    341             self.model = self.model.to(get_preferred_device(self.model, allow_ambiguity=True))
--> 342             result = self._train(
    343                 num_epochs=num_epochs,
    344                 batch_size=batch_size,

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/training_loop.py in _train(self, triples_factory, num_epochs, batch_size, slice_size, label_smoothing, sampler, continue_training, only_size_probing, use_tqdm, use_tqdm_batch, tqdm_kwargs, stopper, sub_batch_size, num_workers, save_checkpoints, checkpoint_path, checkpoint_frequency, checkpoint_on_failure_file_path, best_epoch_model_file_path, last_best_epoch, drop_last, callbacks, callback_kwargs, gradient_clipping_max_norm, gradient_clipping_norm_type, gradient_clipping_max_abs_value, pin_memory)
    485         ):
    486             # return the relevant parameters slice_size and batch_size
--> 487             sub_batch_size, slice_size = self.sub_batch_and_slice(
    488                 batch_size=batch_size, sampler=sampler, triples_factory=triples_factory
    489             )

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/training_loop.py in sub_batch_and_slice(self, batch_size, sampler, triples_factory)
    899     ) -> Tuple[int, Optional[int]]:
    900         """Check if sub-batching and/or slicing is necessary to train the model on the hardware at hand."""
--> 901         sub_batch_size, finished_search, supports_sub_batching = self._sub_batch_size_search(
    902             batch_size=batch_size,
    903             sampler=sampler,

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/training_loop.py in _sub_batch_size_search(self, batch_size, sampler, triples_factory)
    985             self._free_graph_and_cache()
    986             if not is_cudnn_error(runtime_error) and not is_cuda_oom_error(runtime_error):
--> 987                 raise runtime_error
    988             logger.debug(f"The batch_size {batch_size} was too big, sub_batching is required.")
    989             sub_batch_size //= 2

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/training_loop.py in _sub_batch_size_search(self, batch_size, sampler, triples_factory)
    974             self._free_graph_and_cache()
    975             logger.debug(f"Trying batch_size {batch_size} for training now.")
--> 976             self._train(
    977                 triples_factory=triples_factory,
    978                 num_epochs=1,

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/training_loop.py in _train(self, triples_factory, num_epochs, batch_size, slice_size, label_smoothing, sampler, continue_training, only_size_probing, use_tqdm, use_tqdm_batch, tqdm_kwargs, stopper, sub_batch_size, num_workers, save_checkpoints, checkpoint_path, checkpoint_frequency, checkpoint_on_failure_file_path, best_epoch_model_file_path, last_best_epoch, drop_last, callbacks, callback_kwargs, gradient_clipping_max_norm, gradient_clipping_norm_type, gradient_clipping_max_abs_value, pin_memory)
    670                 # During automatic memory optimization only the error message is of interest
    671                 if only_size_probing:
--> 672                     raise e
    673 
    674                 logger.warning(f"The training loop just failed during epoch {epoch} due to error {str(e)}.")

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/training_loop.py in _train(self, triples_factory, num_epochs, batch_size, slice_size, label_smoothing, sampler, continue_training, only_size_probing, use_tqdm, use_tqdm_batch, tqdm_kwargs, stopper, sub_batch_size, num_workers, save_checkpoints, checkpoint_path, checkpoint_frequency, checkpoint_on_failure_file_path, best_epoch_model_file_path, last_best_epoch, drop_last, callbacks, callback_kwargs, gradient_clipping_max_norm, gradient_clipping_norm_type, gradient_clipping_max_abs_value, pin_memory)
    602 
    603                         # forward pass call
--> 604                         batch_loss = self._forward_pass(
    605                             batch,
    606                             start,

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/training_loop.py in _forward_pass(self, batch, start, stop, current_batch_size, label_smoothing, slice_size)
    778     ) -> float:
    779         # forward pass
--> 780         loss = self._process_batch(
    781             batch=batch,
    782             start=start,

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/training/slcwa.py in _process_batch(self, batch, start, stop, label_smoothing, slice_size)
    105 
    106         # Compute negative and positive scores
--> 107         positive_scores = self.model.score_hrt(positive_batch, mode=self.mode)
    108         negative_scores = self.model.score_hrt(negative_batch, mode=self.mode).view(*negative_score_shape)
    109 

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/models/nbase.py in score_hrt(self, hrt_batch, mode)
    404         # Note: we do not delegate to the general method for performance reasons
    405         # Note: repetition is not necessary here
--> 406         h, r, t = self._get_representations(h=hrt_batch[:, 0], r=hrt_batch[:, 1], t=hrt_batch[:, 2], mode=mode)
    407         return self.interaction.score_hrt(h=h, r=r, t=t)
    408 

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/models/nbase.py in _get_representations(self, h, r, t, mode)
    502         """Get representations for head, relation and tails."""
    503         entity_representations = self._entity_representation_from_mode(mode=mode)
--> 504         hr, rr, tr = [
    505             [representation.forward_unique(indices=indices) for representation in representations]
    506             for indices, representations in (

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/models/nbase.py in <listcomp>(.0)
    503         entity_representations = self._entity_representation_from_mode(mode=mode)
    504         hr, rr, tr = [
--> 505             [representation.forward_unique(indices=indices) for representation in representations]
    506             for indices, representations in (
    507                 (h, entity_representations),

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/models/nbase.py in <listcomp>(.0)
    503         entity_representations = self._entity_representation_from_mode(mode=mode)
    504         hr, rr, tr = [
--> 505             [representation.forward_unique(indices=indices) for representation in representations]
    506             for indices, representations in (
    507                 (h, entity_representations),

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/nn/representation.py in forward_unique(self, indices)
    166             return self(None)
    167         unique, inverse = indices.unique(return_inverse=True)
--> 168         x_unique = self._plain_forward(indices=unique)
    169         # normalize *before* repeating
    170         if self.normalizer is not None:

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/nn/representation.py in _plain_forward(self, indices)
    967             indices = torch.arange(self.max_id, device=self.device)
    968         uniq, inverse = indices.to(device=self.device).unique(return_inverse=True)
--> 969         x = self.encoder(
    970             labels=[self.labels[i] for i in uniq.tolist()],
    971         )

/data/jason.tangth/py39/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/data/jason.tangth/py39/lib/python3.9/site-packages/pykeen/nn/utils.py in forward(self, labels)
     53         if isinstance(labels, str):
     54             labels = [labels]
---> 55         return self.model(
     56             **self.tokenizer(
     57                 labels,

/data/jason.tangth/py39/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/data/jason.tangth/py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py in forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
    987         head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
    988 
--> 989         embedding_output = self.embeddings(
    990             input_ids=input_ids,
    991             position_ids=position_ids,

/data/jason.tangth/py39/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/data/jason.tangth/py39/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py in forward(self, input_ids, token_type_ids, position_ids, inputs_embeds, past_key_values_length)
    212 
    213         if inputs_embeds is None:
--> 214             inputs_embeds = self.word_embeddings(input_ids)
    215         token_type_embeddings = self.token_type_embeddings(token_type_ids)
    216 

/data/jason.tangth/py39/lib/python3.9/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/data/jason.tangth/py39/lib/python3.9/site-packages/torch/nn/modules/sparse.py in forward(self, input)
    156 
    157     def forward(self, input: Tensor) -> Tensor:
--> 158         return F.embedding(
    159             input, self.weight, self.padding_idx, self.max_norm,
    160             self.norm_type, self.scale_grad_by_freq, self.sparse)

/data/jason.tangth/py39/lib/python3.9/site-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2042         # remove once script supports set_grad_enabled
   2043         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2044     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   2045 
   2046 

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)

@mberr
Member

mberr commented Apr 4, 2022

> Got it! Then I got the error UnboundLocalError: local variable 'batch' referenced before assignment by running below code:
> [...]

This is #828

@pykeen pykeen deleted a comment from thtang Apr 4, 2022
@thtang
Author

thtang commented Apr 4, 2022

My training set is TriplesFactory(num_entities=309, num_relations=3, num_triples=200, inverse_triples=False, path="/data/jason.tangth/codebase/label_synonyms_pair/train_global_tree.txt"), so I don't think it's empty.
Besides, if I set

training_kwargs=dict(
    num_epochs=1,
    batch_size=32,
),

then that error seems to go away, but I get the same error as with the nations dataset (RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)).

  • With batch_size > 200 (the number of training triples), I get UnboundLocalError: local variable 'batch' referenced before assignment again.

@mberr
Copy link
Member

mberr commented Apr 4, 2022

@thtang, could you try setting

training_kwargs=dict(
    ...,
    drop_last=False,
)

A full documentation of all training parameters can be found at https://pykeen.readthedocs.io/en/stable/api/pykeen.training.TrainingLoop.html#pykeen.training.TrainingLoop.train.

In particular:

drop_last (Optional[bool]) – Whether to drop the last batch in each epoch to prevent smaller batches. Defaults to False, except if the model contains batch normalization layers. Can be provided explicitly to override.

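The problem occurs if batch_size > 200, since you only have 200 training triples. The ERMLP-E interaction contains batch normalization layers, so drop_last defaults to True; the single incomplete batch is then dropped and the epoch contains no batches at all.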
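A minimal pure-Python sketch of how an epoch with zero batches ends in exactly this error. It assumes, as the traceback suggests, that the training loop binds `batch` only inside its `for` loop and deletes it afterwards; run_epoch and run_epoch_safe are hypothetical functions, not PyKEEN's training loop.

```python
from typing import Iterable, List


def run_epoch(batches: Iterable[List[int]]) -> int:
    """Mimics the failing pattern: `batch` is only bound if the loop body runs."""
    num_seen = 0
    for batch in batches:
        num_seen += len(batch)  # forward/backward pass would happen here
    del batch  # raises UnboundLocalError when `batches` was empty
    return num_seen


def run_epoch_safe(batches: Iterable[List[int]]) -> int:
    """Same loop, but `batch` is bound up front so the cleanup always succeeds."""
    num_seen = 0
    batch = None
    for batch in batches:
        num_seen += len(batch)
    del batch
    return num_seen


print(run_epoch([[0, 1], [2]]))  # 3
print(run_epoch_safe([]))  # 0
try:
    run_epoch([])  # e.g. the only, incomplete batch was dropped
except UnboundLocalError as error:
    print(type(error).__name__)  # UnboundLocalError
```

Setting drop_last=False avoids the empty epoch from the data side; binding the variable before the loop, as in the safe variant, avoids the crash from the code side.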

btw, drop_last gets passed to torch.utils.data.DataLoader (cf. https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). It might be worth linking this in the docs, too 🙂
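For reference, a small sketch of the number of batches a DataLoader yields per epoch when it batches a map-style dataset of known size, together with the drop_last default from the documentation quoted above. num_batches and resolve_drop_last are hypothetical helpers written without PyTorch so they run anywhere, and batch_size=256 is only an illustrative value larger than 200.

```python
import math
from typing import Optional


def resolve_drop_last(drop_last: Optional[bool], has_batch_norm: bool) -> bool:
    """Documented default: False, unless the model contains batch normalization layers."""
    return has_batch_norm if drop_last is None else drop_last


def num_batches(num_examples: int, batch_size: int, drop_last: bool) -> int:
    """Batches per epoch when `num_examples` items are grouped into batches of `batch_size`."""
    if drop_last:
        return num_examples // batch_size  # an incomplete final batch is discarded
    return math.ceil(num_examples / batch_size)  # the final batch may be smaller


# 200 training triples, as in the report above.
for batch_size in (32, 256):
    for has_bn in (False, True):
        flag = resolve_drop_last(None, has_batch_norm=has_bn)
        print(batch_size, has_bn, num_batches(200, batch_size, flag))
```

With batch_size=256 and a model containing batch normalization, the epoch yields zero batches; passing drop_last=False explicitly brings it back to one smaller batch of 200 triples.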

@thtang
Author

thtang commented Apr 4, 2022

Got it. btw, do you have any idea about the error RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)?

Is there an option to move ERModel to CUDA? entity_representations.device is device(type='cuda', index=0).

Or is that not necessary?

@mberr
Member

mberr commented Apr 4, 2022

It looks like the entity indices are not on the correct device but stayed on the CPU; I would need to investigate the underlying issue more closely.

Update: it is likely an error in

pykeen/nn/utils.py in forward(self, labels)

where the result of the tokenizer is not moved to the representation module's device.
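A minimal sketch of the kind of fix described: move every tensor returned by the tokenizer to the encoder's device before calling the encoder. It is written duck-typed (any value with a `.to(device)` method), so it runs without PyTorch; move_to_device and FakeTensor are hypothetical names, and this is not necessarily the change merged in #861. With a real Hugging Face tokenizer, the returned BatchEncoding also offers its own `.to(device)` method.

```python
from typing import Any, Dict, Mapping


def move_to_device(encoded: Mapping[str, Any], device: Any) -> Dict[str, Any]:
    """Return a new dict whose values have all been moved to `device` via `.to()`."""
    return {key: value.to(device) for key, value in encoded.items()}


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only tracks its device."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # like torch.Tensor.to, return a tensor on the target device
        return FakeTensor(device)


# The tokenizer output starts on the CPU (as in the traceback above) ...
encoded = {"input_ids": FakeTensor(), "attention_mask": FakeTensor()}
# ... and is moved to the encoder's device before the forward pass,
# e.g. model(**move_to_device(encoded, device)) with device = "cuda:0".
moved = move_to_device(encoded, "cuda:0")
print({key: tensor.device for key, tensor in moved.items()})
```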

@thtang
Author

thtang commented Apr 4, 2022

Cool. Is there any way I can fix it in my local PyKEEN source?

@mberr
Member

mberr commented Apr 4, 2022

You can install PyKEEN from source in editable mode:

git clone https://github.com/pykeen/pykeen.git
cd pykeen
pip install -e .

and then modify the code. If you find a solution, we are happy to accept a PR (or your comments on how you fixed it).

@thtang
Author

thtang commented Apr 5, 2022

Do all models that accept pre-trained embeddings encounter the same issue?

@mberr
Member

mberr commented Apr 5, 2022

> Do all the models that can pass pre-trained embedding encounter the same issue?

No, I do not think so.

@thtang
Copy link
Author

thtang commented Apr 5, 2022

Could I use the latest code via

git clone https://github.com/pykeen/pykeen.git
cd pykeen
pip install -e .

?

@mberr
Member

mberr commented Apr 5, 2022

cf. #861 (comment).

It would be nice to only comment in one place though.

@mberr
Member

mberr commented Apr 5, 2022

@thtang, this should be fixed on master now.

@thtang
Author

thtang commented Apr 5, 2022

It works. Many thanks.
