Skip to content
Permalink
Browse files

perm on shuffle

  • Loading branch information...
rusty1s committed May 13, 2019
1 parent c8f5c77 commit 7801043dea8a0dbf110613925f348246e937a80b
Showing with 12 additions and 3 deletions.
  1. +1 −0 test/datasets/test_enzymes.py
  2. +11 −3 torch_geometric/data/in_memory_dataset.py
@@ -23,6 +23,7 @@ def test_enzymes():

assert len(dataset[0]) == 3
assert len(dataset.shuffle()) == 600
assert len(dataset.shuffle(return_perm=True)) == 2
assert len(dataset[:100]) == 100
assert len(dataset[torch.arange(100, dtype=torch.long)]) == 100
mask = torch.zeros(600, dtype=torch.uint8)
@@ -86,9 +86,17 @@ def __getitem__(self, idx):
'Only integers, slices (`:`) and long or byte tensors are valid '
'indices (got {}).'.format(type(idx).__name__))

def shuffle(self):
r"""Randomly shuffles the examples in the dataset."""
return self.__indexing__(torch.randperm(len(self)))
def shuffle(self, return_perm=False):
r"""Randomly shuffles the examples in the dataset.
Args:
return_perm (bool, optional): If set to :obj:`True`, will
additionally return the random permutation used to shuffle the
data. (default: :obj:`False`)
"""
perm = torch.randperm(len(self))
dataset = self.__indexing__(perm)
return (dataset, perm) if return_perm is True else dataset

def get(self, idx):
data = Data()

0 comments on commit 7801043

Please sign in to comment.
You can’t perform that action at this time.