In [7]:
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader

In [8]:
from torch.utils.data.dataset import Subset

# Number of leading examples held out for validation; everything after them
# is used for training.
NUM_VALID_EXAMPLES = 1000

# Both subsets below wrap this single dataset object, so they share its
# transform.
train_and_valid = datasets.MNIST(root='data',
                                 train=True,
                                 transform=transforms.ToTensor(),
                                 download=True)

# Take the upper bound from the dataset itself instead of hard-coding its
# size, so the split always covers exactly the examples that were loaded.
valid_indices = torch.arange(0, NUM_VALID_EXAMPLES)
train_indices = torch.arange(NUM_VALID_EXAMPLES, len(train_and_valid))

train_dataset = Subset(train_and_valid, train_indices)
valid_dataset = Subset(train_and_valid, valid_indices)

print(f'Total number of training examples: {len(train_dataset)}')

Total number of training examples: 59000


## SubsetRandomSampler Method
Compared to the Subset method, the SubsetRandomSampler is a more convenient solution if we want to assign different transformation methods to the training and validation subsets: each subset is drawn from its own dataset object, so each can have its own transform. Similar to the Subset example, we will use the first 1000 examples for the validation set and the remaining 59000 examples for training.

In [12]:
from torch.utils.data import SubsetRandomSampler


# Number of leading examples held out for validation; everything after them
# is used for training.
NUM_VALID_EXAMPLES = 1000

# Training images get a random 28x28 crop of the 32x32 upscaled image, while
# validation/test images get the center crop.
training_transform = transforms.Compose([transforms.Resize((32, 32)),
                                         transforms.RandomCrop((28, 28)),
                                         transforms.ToTensor()])

valid_transform = transforms.Compose([transforms.Resize((32, 32)),
                                      transforms.CenterCrop((28, 28)),
                                      transforms.ToTensor()])

train_dataset = datasets.MNIST(root='data',
                               train=True,
                               transform=training_transform,
                               download=True)

# Same underlying data as "train_dataset" (train=True), but wrapped in a
# separate dataset object so that it can use a different transform.
valid_dataset = datasets.MNIST(root='data',
                               train=True,
                               transform=valid_transform,
                               download=False)

test_dataset = datasets.MNIST(root='data',
                              train=False,
                              transform=valid_transform,
                              download=False)

# The two index ranges are disjoint; the samplers restrict each loader to its
# own range even though both datasets contain the full training set.
valid_indices = torch.arange(0, NUM_VALID_EXAMPLES)
train_indices = torch.arange(NUM_VALID_EXAMPLES, len(train_dataset))

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)

train_loader = DataLoader(train_dataset,
                          batch_size=1,
                          num_workers=4,
                          sampler=train_sampler)

valid_loader = DataLoader(valid_dataset,
                          batch_size=4,
                          num_workers=4,
                          sampler=valid_sampler)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=4,
                         num_workers=4,
                         shuffle=False)

# len(train_loader) is the number of *batches*, which matches the number of
# examples only when batch_size == 1. The sampler's length is the number of
# examples regardless of batch size.
print(f'Total number of training examples: {len(train_sampler)}')

Total number of training examples: 59000
