Closed
Description
Reported by @taylanbil:
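```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Sum-mode EmbeddingBag with dense (non-sparse) gradients, moved to the XLA device.
d = nn.EmbeddingBag(10, 10, mode="sum", sparse=False).to(device)

# A single bag (offsets=[0]) containing indices 1, 5 and 9.
inp = torch.LongTensor([1, 5, 9]).to(device)
x = d(inp, offsets=torch.LongTensor([0]).to(device))

# Reduce to a scalar and backpropagate into d.weight.
loss = x.sum()
loss.backward()
```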
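As a reference for what `backward()` should produce in this repro, here is a minimal sketch that reimplements the `mode="sum"` forward pass and its weight gradient in plain Python, with no torch or torch_xla dependency. It assumes PyTorch's default `include_last_offset=False`, so `offsets=[0]` places all three indices in one bag. The helper names `embedding_bag_sum` and `embedding_bag_sum_grad` are hypothetical, written only for illustration.

```python
def _bag_bounds(indices, offsets):
    """Yield (start, end) slices of `indices` for each bag.

    Assumes include_last_offset=False: bag b spans offsets[b]..offsets[b+1],
    and the last bag runs to the end of `indices`.
    """
    bounds = list(offsets) + [len(indices)]
    return list(zip(bounds[:-1], bounds[1:]))


def embedding_bag_sum(weight, indices, offsets):
    """Reference forward for mode="sum": one output row per bag."""
    dim = len(weight[0])
    out = []
    for start, end in _bag_bounds(indices, offsets):
        row = [0.0] * dim
        for idx in indices[start:end]:
            for j in range(dim):
                row[j] += weight[idx][j]
        out.append(row)
    return out


def embedding_bag_sum_grad(num_embeddings, dim, indices, offsets, grad_out):
    """Gradient of the sum-mode output with respect to the weight matrix.

    Every index in bag b receives grad_out[b]; repeated indices accumulate.
    """
    grad_w = [[0.0] * dim for _ in range(num_embeddings)]
    for b, (start, end) in enumerate(_bag_bounds(indices, offsets)):
        for idx in indices[start:end]:
            for j in range(dim):
                grad_w[idx][j] += grad_out[b][j]
    return grad_w


# Mirror the repro: 10 embeddings of size 10, one bag holding [1, 5, 9].
# loss = x.sum() means the upstream gradient is all ones.
grad = embedding_bag_sum_grad(10, 10, [1, 5, 9], [0], [[1.0] * 10])
for row_index, row in enumerate(grad):
    print(row_index, row)
```

For the repro above, rows 1, 5 and 9 of `d.weight.grad` should therefore be all ones and every other row all zeros. Running the same repro on CPU (without the `.to(device)` calls) gives a second reference to compare the XLA result against.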