Description
🐛 Bug
Concatenating a dataset such as CIFAR10 with FakeData and iterating over the result with a DataLoader raises:
AttributeError: 'int' object has no attribute 'numel'
To Reproduce
Steps to reproduce the behavior:
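import torchvision
from torch.utils.data import ConcatDataset, DataLoader
cifar_dataset = torchvision.datasets.CIFAR10(...)
fake_dataset = torchvision.datasets.FakeData(...)
# Concatenate a dataset with int labels and one with tensor labels
train_data = ConcatDataset([cifar_dataset, fake_dataset])
train_loader = DataLoader(train_data, ...)
for data in train_loader:
    pass  # the AttributeError is raised while a batch is being built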
Additional context
This happens because the labels in CIFAR10 are Python ints, while the labels in FakeData are tensors. When samples from both datasets are combined into a batch, the batch labels look like [0,1,2,3,tensor(0),3,4,5,6,tensor(2)...].
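Because the mixed label types are hard to spot from the traceback alone, a small check like the one below can help confirm the diagnosis before building a DataLoader. It is only a sketch: `label_type_counts` and `ScalarLabel` are hypothetical names, and `ScalarLabel` is a plain-Python stand-in for a tensor label so that the example runs without torch. It assumes each sample is a `(sample, label)` pair.

```python
from collections import Counter


def label_type_counts(dataset, limit=100):
    """Count the type names of the labels in the first `limit` samples."""
    counts = Counter()
    for i in range(min(limit, len(dataset))):
        _, label = dataset[i]
        counts[type(label).__name__] += 1
    return counts


class ScalarLabel:
    """Hypothetical stand-in for a tensor label (not a real torch type)."""

    def __init__(self, value):
        self.value = value


# Two samples with int labels followed by one with a tensor-like label,
# mimicking what a concatenated dataset would return.
mixed = [("img", 0), ("img", 1), ("img", ScalarLabel(2))]
print(label_type_counts(mixed))
```

More than one type name in the result suggests that the datasets being concatenated disagree on their label type.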
I can work around this by passing target_transform=int when I create fake_dataset. However, the error is very hard to debug. I think the default target type in the FakeData source code should be int instead of a long tensor.
The relevant code is in the __getitem__ function:
https://pytorch.org/vision/0.8/_modules/torchvision/datasets/fakedata.html#FakeData
target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
This returns a zero-dimensional long tensor. It should return an int.
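Until the default changes, the target_transform=int workaround can also be applied at the dataset level, which may be useful for datasets that do not accept a target_transform argument. The wrapper below is a minimal sketch, not part of torchvision: `IntTargets` and `ScalarLabel` are hypothetical names, and `ScalarLabel` stands in for a zero-dimensional tensor so the example runs without torch. It relies only on `int()` accepting the label, which holds for a zero-dimensional tensor.

```python
class IntTargets:
    """Wrap a map-style dataset so every target comes back as a built-in int."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        sample, target = self.base[index]
        # int() leaves ints unchanged and converts anything defining __int__.
        return sample, int(target)


class ScalarLabel:
    """Hypothetical stand-in for a zero-dimensional tensor label."""

    def __init__(self, value):
        self.value = value

    def __int__(self):
        return self.value


raw = [("image-a", 3), ("image-b", ScalarLabel(2))]
wrapped = IntTargets(raw)
targets = [wrapped[i][1] for i in range(len(wrapped))]
print(targets)  # [3, 2]
```

Wrapping each dataset before concatenation would give every sample the same label type, whichever dataset it came from.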