
Separate random generation from transforms #115

Closed
wants to merge 1 commit

Conversation

fmassa (Member) commented Mar 21, 2017

This PR factors out the random number generation from the transforms, so that the same random transform can be applied to different inputs (possibly from different domains).

In the dataset, if the user wants the same random transforms applied to different inputs, a set of generators should be passed to the dataset constructor.

class MySegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, transforms=None, target_transforms=None, generators=None):
        self.transforms = transforms
        self.target_transforms = target_transforms
        self.generators = generators

    def __getitem__(self, idx):
        # load img and target
        ...
        # draw fresh random numbers, if generators were provided
        if self.generators is not None:
            for g in self.generators:
                g.generate()

        if self.transforms is not None:
            img = self.transforms(img)

        if self.target_transforms is not None:
            target = self.target_transforms(target)

        return img, target

An example of how it would be used:

# create random generators
crop_generator = torchvision.transforms.RandomCropGenerator()
flip_generator = torchvision.transforms.RandomFlipGenerator()

generators = (crop_generator, flip_generator)

# random transforms that consume the generators
crop = torchvision.transforms.RandomCrop(img_size, generator=crop_generator)
flip = torchvision.transforms.RandomHorizontalFlip(generator=flip_generator)

# composed transforms that reuse crop and flip
train_input_transform = torchvision.transforms.Compose([
    crop,
    flip,
    torchvision.transforms.ToTensor(),
    normalize,
])
train_target_transform = torchvision.transforms.Compose([
    crop,
    flip,
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Lambda(lambd=lambda x: x.long())
])

# create the dataset, passing the generators as well
mydataset = MySegmentationDataset(train_input_transform,
                                  train_target_transform,
                                  generators)

img, target = mydataset[0]
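The PR description doesn't show the generator internals, so here is a rough sketch of the protocol implied above. Only the RandomFlipGenerator / generate names come from the proposal; the bodies below are assumptions: generate() draws and caches a random decision once per sample, and every transform holding a reference to the generator reads that cached value, so img and target see the same flip.

import random
from PIL import Image

class RandomFlipGenerator(object):
    def generate(self):
        # draw and cache the decision; called once per __getitem__
        self.flip = random.random() < 0.5
        return self.flip

class RandomHorizontalFlip(object):
    def __init__(self, generator=None):
        self.generator = generator

    def __call__(self, img):
        if self.generator is not None:
            flip = self.generator.flip  # reuse the cached decision
        else:
            flip = random.random() < 0.5  # old, standalone behavior
        return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img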

A few points worth noting:

  • I'm not sure how the random state generator behaves in a multi-threaded setup (@colesbury ?)
  • For more complex transforms such as RandomSizedCrop, the size of the image is required by the generator. We could add an extra *args, **kwargs to each generate call; I'll add that if you agree (see the sketch after this list).
  • We need to pass an extra set of generators to the dataset constructor and call generate on each __getitem__, which might not be ideal.
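For illustration, a generate method that takes the image size might look like the sketch below. The class name, signature, and sampling logic are assumptions (a much-simplified stand-in for RandomSizedCrop's real area/aspect-ratio sampling), not part of this PR:

import random

class RandomSizedCropGenerator(object):
    # Hypothetical: the dataset would call g.generate(img.size)
    # instead of g.generate(), so image-dependent parameters can
    # be drawn and cached here.
    def generate(self, img_size):
        w, h = img_size
        self.tw = random.randint(1, w)            # crop width
        self.th = random.randint(1, h)            # crop height
        self.j = random.randint(0, w - self.tw)   # left offset
        self.i = random.randint(0, h - self.th)   # top offset
        return self.i, self.j, self.th, self.tw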

cc @bodokaiser @ellisbrown @desimone @felixgwu

bodokaiser (Contributor) commented

👍

I believe this will not work with num_workers > 1 in DataLoader. The loader passes the same dataset instance to all workers, so different workers will access the same dataset.generators. The cleanest fix would be for DataLoader to give every worker its own copy of the dataset, but that could be a problem for datasets that do a lot of caching, for example. So maybe it would be better to create the generators on every call to __getitem__, or to pass a seed directly to the transform call?
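One way to read the seed suggestion above (entirely hypothetical, not part of this PR): draw a seed inside __getitem__ and re-seed before each pipeline, so both pipelines replay the same random draws without any shared generator objects:

import random

# inside MySegmentationDataset, replacing the generators mechanism
def __getitem__(self, idx):
    # load img and target
    ...
    seed = random.randint(0, 2**31 - 1)
    random.seed(seed)
    img = self.transforms(img)
    random.seed(seed)  # replay the same draws for the target
    target = self.target_transforms(target)
    return img, target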

alykhantejani (Contributor) commented

@fmassa can this be closed in favor of #240?

fmassa closed this Sep 19, 2017