torch.utils.data.random_split() returns dataset index as tensor #10165

Open
mfournarakis opened this issue Aug 2, 2018 · 4 comments
Labels
module: dataloader (Related to torch.utils.data.DataLoader and Sampler)
module: docs (Related to our documentation, both in docs/ and docblocks)
triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)

Comments

@mfournarakis

mfournarakis commented Aug 2, 2018

Issue description

torch.utils.data.random_split() passes the index of the datapoint (idx) to the dataset as a tensor rather than a plain Python int, which breaks the dataset's __getitem__() routine when idx is used to index a pandas DataFrame.

Code example

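import os

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import transforms


class AntsDataset(Dataset):
    def __init__(self, root_dir, csv_file, transform=None):
        # CSV without a header row: column 0 is the image file name,
        # column 1 is the rotation.
        self.rotations = pd.read_csv(csv_file, header=None)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.rotations)

    def __getitem__(self, idx):
        # When the dataset is wrapped by random_split(), idx arrives here
        # as a tensor, and the .iloc lookup below raises the TypeError.
        img_name = os.path.join(self.root_dir,
                                self.rotations.iloc[idx, 0])
        image = plt.imread(img_name, format='RGB')
        rotation = self.rotations.iloc[idx, 1].astype('float')

        if self.transform is not None:
            image = self.transform(image)

        return (image, rotation)


# ants1_root_dir and ants1_rot_file are defined elsewhere (not shown in the report).
ants_dataset = AntsDataset(ants1_root_dir, ants1_rot_file,
        transform=transforms.Compose([transforms.ToTensor()]))

dataloader = torch.utils.data.DataLoader(ants_dataset,
        batch_size=10, shuffle=True)

# 70/30 train/test split.
train_length = int(0.7 * len(ants_dataset))
test_length = len(ants_dataset) - train_length

train_dataset, test_dataset = torch.utils.data.random_split(
        ants_dataset, (train_length, test_length))

dataloader_train = torch.utils.data.DataLoader(train_dataset,
        batch_size=10, shuffle=True)

# Iterating over the split dataset's loader triggers the error below.
for batch_idx, (data, rotations) in enumerate(dataloader_train):
    print(rotations)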

TypeError                                 Traceback (most recent call last)
<ipython-input-101-f629e71651de> in <module>()
      1 dataloader_train=torch.utils.data.DataLoader(train_dataset,
      2         batch_size=10, shuffle=True)
----> 3 for batch_idx, (data,rotations) in enumerate(dataloader_train):
      4     print(rotations)

~/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
    312         if self.num_workers == 0:  # same-process loading
    313             indices = next(self.sample_iter)  # may raise StopIteration
--> 314             batch = self.collate_fn([self.dataset[i] for i in indices])
    315             if self.pin_memory:
    316                 batch = pin_memory_batch(batch)

~/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py in <listcomp>(.0)
    312         if self.num_workers == 0:  # same-process loading
    313             indices = next(self.sample_iter)  # may raise StopIteration
--> 314             batch = self.collate_fn([self.dataset[i] for i in indices])
    315             if self.pin_memory:
    316                 batch = pin_memory_batch(batch)

~/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataset.py in __getitem__(self, idx)
    101 
    102     def __getitem__(self, idx):
--> 103         return self.dataset[self.indices[idx]]
    104 
    105     def __len__(self):

<ipython-input-95-f56aceeb246b> in __getitem__(self, idx)
     19     def __getitem__(self, idx):
     20         img_name = os.path.join(self.root_dir,
---> 21                                 self.rotations.iloc[idx, 0])
     22         image = plt.imread(img_name,format='RGB')
     23         rotation = self.rotations.iloc[idx, 1].astype('float')

~/anaconda3/envs/pytorch/lib/python3.6/site-packages/pandas/core/indexing.py in __getitem__(self, key)
   1470             except (KeyError, IndexError):
   1471                 pass
-> 1472             return self._getitem_tuple(key)
   1473         else:
   1474             # we by definition only have the 0th axis

~/anaconda3/envs/pytorch/lib/python3.6/site-packages/pandas/core/indexing.py in _getitem_tuple(self, tup)
   2011     def _getitem_tuple(self, tup):
   2012 
-> 2013         self._has_valid_tuple(tup)
   2014         try:
   2015             return self._getitem_lowerdim(tup)

~/anaconda3/envs/pytorch/lib/python3.6/site-packages/pandas/core/indexing.py in _has_valid_tuple(self, key)
    220                 raise IndexingError('Too many indexers')
    221             try:
--> 222                 self._validate_key(k, i)
    223             except ValueError:
    224                 raise ValueError("Location based indexing can only have "

~/anaconda3/envs/pytorch/lib/python3.6/site-packages/pandas/core/indexing.py in _validate_key(self, key, axis)
   1965             l = len(self.obj._get_axis(axis))
   1966 
-> 1967             if len(arr) and (arr.max() >= l or arr.min() < -l):
   1968                 raise IndexError("positional indexers are out-of-bounds")
   1969         else:

TypeError: len() of unsized object

System Info

PyTorch version: 0.4.1
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.13.6
GCC version: Could not collect
CMake version: Could not collect

Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip] Could not collect
[conda] torch                     0.4.1                     <pip>
[conda] torchvision               0.2.1                     <pip>


cc @SsnL @VitalyFedyunin @ejguan @jlin27 @mruberry
@ailzhang
Contributor

ailzhang commented Aug 3, 2018

Using a tensor as an index should work fine when indexing into tensors. But your self.rotations is not a tensor, which I think is why the indexing didn't work. In this case, should we require the user to do the conversion, or should we change the output of random_split()?
cc: @ssnl
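
A minimal sketch of the user-side conversion raised in this comment, assuming the index arrives as a zero-dimensional, tensor-like object that exposes `.item()` (as `torch.Tensor` does). The helper `to_python_index` and the `FakeScalarTensor` stand-in are hypothetical names used only for illustration, so the sketch runs without torch or pandas installed.

```python
import operator


def to_python_index(idx):
    """Return idx as a plain Python int.

    Hypothetical helper: accepts objects with .item() (for example a
    0-dim integer tensor) as well as anything implementing __index__.
    """
    if hasattr(idx, "item"):
        # .item() unwraps a single-element tensor/array into a Python scalar.
        idx = idx.item()
    # operator.index raises TypeError for floats and other non-integers.
    return operator.index(idx)


class FakeScalarTensor:
    """Stand-in for a 0-dim integer tensor (assumption, not a torch class)."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


rows = ["ant_0.png", "ant_1.png", "ant_2.png"]
idx = FakeScalarTensor(2)
print(rows[to_python_index(idx)])  # prints ant_2.png
```

In the dataset from this issue, the equivalent step would be to normalize `idx` at the top of `__getitem__`, before the `self.rotations.iloc[idx, 0]` lookup.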

@fmassa
Member

fmassa commented Aug 4, 2018

This is a duplicate of #9211

I still think that converting to a list is the simplest thing to do, but well :-)
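
One way to picture the "convert to a list" option is a thin subset wrapper that turns its indices into plain Python ints once, at construction time. This is only a sketch under stated assumptions: `ListIndexSubset` is a hypothetical name, not a class in torch, and it relies only on the index container offering `.tolist()` (as tensors and NumPy arrays do) or being an ordinary iterable.

```python
class ListIndexSubset:
    """Hypothetical subset wrapper that stores indices as Python ints."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        # Convert once up front, so every later lookup passes an int
        # to the wrapped dataset's __getitem__.
        if hasattr(indices, "tolist"):
            indices = indices.tolist()
        self.indices = [int(i) for i in indices]

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


data = ["a", "b", "c", "d"]
subset = ListIndexSubset(data, (3, 0, 2))
print([subset[i] for i in range(len(subset))])  # prints ['d', 'a', 'c']
```

Since the traceback shows the split object looking items up through `self.dataset[self.indices[idx]]`, the objects returned by `random_split()` in the report could presumably be re-wrapped as `ListIndexSubset(train_dataset.dataset, train_dataset.indices)`.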

@ailzhang
Contributor

ailzhang commented Aug 6, 2018

Based on #9211, our conclusion seems to be that it's better to do the conversion on the user side, since .tolist() is quite an expensive operation. @mariosfourn does that work for you?

@mfournarakis
Author

mfournarakis commented Aug 7, 2018 via email

@zou3519 added the module: dataloader, triaged, and module: docs labels and removed the triage review label on Feb 3, 2021