In [1]:
import pytorch_lightning as pl

from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

In [2]:
class MNISTDataModule(LightningDataModule):
    """Lightning data module that serves MNIST train / validation / test splits.

    The official MNIST training set is randomly divided into a training part
    and a validation part; the official MNIST test set is used unchanged as
    the test split.

    Args:
        transform: Callable applied to every image. When ``None``, images are
            resized to 32x32 and converted to tensors.
        batch_size: Number of samples per batch for all three dataloaders.
        val_size: Number of training-set samples held out for validation.
            The default of 5000 yields the 55000 / 5000 split on the
            60000-sample MNIST training set.
        num_workers: Worker process count passed to every ``DataLoader``.
    """

    # Root directory the MNIST files are downloaded to and loaded from.
    DATASET_DIR = "datasets"

    def __init__(self, transform=None, batch_size: int = 100,
                 val_size: int = 5000, num_workers: int = 0):
        super().__init__()
        if transform is None:
            # Default transform
            transform = transforms.Compose([
                transforms.Resize((32, 32)),
                transforms.ToTensor(),
            ])
        self.transform = transform
        self.batch_size = batch_size
        self.val_size = val_size
        self.num_workers = num_workers

    def prepare_data(self) -> None:
        """Download the MNIST train and test files into ``DATASET_DIR``.

        Only downloading happens here; no state is assigned on ``self``.
        """
        for is_train in (True, False):
            datasets.MNIST(root=self.DATASET_DIR, train=is_train, download=True)

    def setup(self, stage=None) -> None:
        """Build the train, validation and test datasets.

        Args:
            stage: Accepted for compatibility with the Lightning hook
                signature; every split is built regardless of its value.

        Raises:
            ValueError: If ``val_size`` is negative or larger than the number
                of samples in the MNIST training set.
        """
        full_train = datasets.MNIST(root=self.DATASET_DIR, train=True,
                                    download=False, transform=self.transform)

        # Derive the training size from the actual dataset length instead of
        # hard-coding it, so the split stays valid for any ``val_size``.
        total = len(full_train)
        if not 0 <= self.val_size <= total:
            raise ValueError(
                f"val_size must be between 0 and {total}, got {self.val_size}"
            )

        # All experiments are run using the train and val datasets.
        self.train_dataset, self.val_dataset = random_split(
            full_train, [total - self.val_size, self.val_size]
        )
        self.test_dataset = datasets.MNIST(root=self.DATASET_DIR, train=False,
                                           download=False,
                                           transform=self.transform)

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a ``DataLoader`` using this module's settings."""
        return DataLoader(dataset, batch_size=self.batch_size,
                          shuffle=shuffle, num_workers=self.num_workers)

    def train_dataloader(self) -> DataLoader:
        """Return the dataloader over the training split (shuffled)."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the dataloader over the validation split (not shuffled)."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Return the dataloader over the test split (not shuffled)."""
        return self._make_loader(self.test_dataset, shuffle=False)

    @property
    def num_classes(self) -> int:
        """Number of target classes in MNIST (the digits 0-9)."""
        return 10