
EmbeddingBag error #2215

@ailzhang

Description


Reported by @taylanbil. The snippet below triggers the error when running nn.EmbeddingBag on an XLA device:

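import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

# Repro: nn.EmbeddingBag forward + backward on an XLA device.
device = xm.xla_device()
# 10 embeddings of dimension 10, summed within each bag, dense (non-sparse) gradients.
d = nn.EmbeddingBag(10, 10, mode="sum", sparse=False).to(device)
# A single bag holding indices 1, 5 and 9; offsets=[0] marks where that bag starts.
inp = torch.LongTensor([1, 5, 9]).to(device)
x = d(inp, offsets=torch.LongTensor([0]).to(device))
loss = x.sum()
loss.backward()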
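As a reference when comparing the XLA result, here is a minimal pure-Python sketch of what `nn.EmbeddingBag(mode="sum")` is expected to compute for this repro, and of the dense weight gradient that `loss.backward()` should produce. The helper names `embedding_bag_sum` and `embedding_bag_sum_grad` are hypothetical. The sketch assumes the default `include_last_offset=False` convention, where each offset marks the start of a bag and the last bag runs to the end of the input.

```python
# Hypothetical helpers mirroring nn.EmbeddingBag(mode="sum") for a 1-D input
# split into bags by `offsets` (assumes include_last_offset=False semantics).

def embedding_bag_sum(weight, indices, offsets):
    """Forward: bag b is the sum of weight rows indices[offsets[b]:offsets[b+1]]."""
    dim = len(weight[0])
    bounds = list(offsets) + [len(indices)]  # last bag ends at len(indices)
    out = []
    for b in range(len(offsets)):
        acc = [0.0] * dim
        for idx in indices[bounds[b]:bounds[b + 1]]:
            for j in range(dim):
                acc[j] += weight[idx][j]
        out.append(acc)
    return out


def embedding_bag_sum_grad(num_embeddings, dim, indices, offsets, grad_out):
    """Backward w.r.t. weight (dense, sparse=False): every row looked up in
    bag b accumulates that bag's upstream gradient; other rows stay zero."""
    grad_w = [[0.0] * dim for _ in range(num_embeddings)]
    bounds = list(offsets) + [len(indices)]
    for b in range(len(offsets)):
        for idx in indices[bounds[b]:bounds[b + 1]]:
            for j in range(dim):
                grad_w[idx][j] += grad_out[b][j]
    return grad_w


# Same shapes as the repro: 10 embeddings of dim 10, one bag of [1, 5, 9].
# Deterministic weights (row i, column j holds 10*i + j) make results checkable.
weight = [[float(10 * i + j) for j in range(10)] for i in range(10)]
x = embedding_bag_sum(weight, [1, 5, 9], [0])
# loss = x.sum() means the upstream gradient for the single bag is all ones.
grad = embedding_bag_sum_grad(10, 10, [1, 5, 9], [0], [[1.0] * 10])
```

With `loss = x.sum()`, rows 1, 5 and 9 of the weight gradient should be all ones and every other row all zeros; an XLA gradient that differs from this (or a crash before producing it) points at the backward lowering rather than the repro itself.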
