Fix TGNMemory forward when device is CUDA (#8933)
This PR addresses issue #8926.
Kh4L committed Feb 17, 2024
1 parent 7577035 commit 6517997
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torch_geometric/nn/models/tgn.py
@@ -164,13 +164,13 @@ def _compute_msg(self, n_id: Tensor, msg_store: TGNMessageStoreType,
                      msg_module: Callable):
         data = [msg_store[i] for i in n_id.tolist()]
         src, dst, t, raw_msg = list(zip(*data))
-        src = torch.cat(src, dim=0)
-        dst = torch.cat(dst, dim=0)
-        t = torch.cat(t, dim=0)
+        src = torch.cat(src, dim=0).to(self.device)
+        dst = torch.cat(dst, dim=0).to(self.device)
+        t = torch.cat(t, dim=0).to(self.device)
         # Filter out empty tensors to avoid `invalid configuration argument`.
         # TODO Investigate why this is needed.
         raw_msg = [m for i, m in enumerate(raw_msg) if m.numel() > 0 or i == 0]
-        raw_msg = torch.cat(raw_msg, dim=0)
+        raw_msg = torch.cat(raw_msg, dim=0).to(self.device)
         t_rel = t - self.last_update[src]
         t_enc = self.time_enc(t_rel.to(raw_msg.dtype))
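
For context, a minimal sketch (hypothetical, not taken from the repository) of the device mismatch this commit fixes: the tensors concatenated out of `msg_store` live on CPU, while module buffers such as `self.last_update` live on the CUDA device, so mixing the two in `t - self.last_update[src]` fails when the model runs on a GPU.

    import torch

    # Minimal repro sketch (hypothetical, not part of the commit). The
    # module's buffers live on the CUDA device, while tensors read back
    # out of the Python-level message store are still on CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    last_update = torch.zeros(10, dtype=torch.long, device=device)  # on `device`
    src = torch.tensor([0, 1, 2])  # CPU, as returned from the message store
    t = torch.tensor([5, 6, 7])    # CPU

    # `last_update[src]` lives on `device`, but `t` is still on CPU, so on
    # a CUDA machine the subtraction raises "Expected all tensors to be on
    # the same device". Moving the tensors first, as the commit does with
    # `.to(self.device)`, avoids the mismatch:
    src, t = src.to(device), t.to(device)
    t_rel = t - last_update[src]  # both operands now on `device`

The one-line `.to(self.device)` additions in the diff apply exactly this move to the concatenated tensors before they are combined with the module's buffers.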

