In [1]:
import os
from typing import Optional
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import DataLoader, random_split
import torchvision 
from torchvision import transforms 
from torchvision.datasets import CIFAR10

import pytorch_lightning as pl
from  pytorch_lightning import LightningDataModule,LightningModule 
# Process-wide configuration; must run before CUDA is first initialised so the
# device-selection variables are honoured.
os.environ["TORCH_HOME"] = "/root/data/torch_home"  # base directory; the dataset path is derived from it below
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"      # enumerate GPUs in PCI bus order
os.environ["CUDA_VISIBLE_DEVICES"]  = "0,1"         # expose only the first two GPUs
# Reproducibility: turn off the cuDNN autotuner and request deterministic kernels.
torch.backends.cudnn.benchmark = False
# Fixed typo: the attribute was spelled `determinstic`, which silently created an
# unused attribute and left cuDNN in its default (non-deterministic) mode.
torch.backends.cudnn.deterministic = True
# Seed the random number generators through Lightning; the returned seed is the cell's output.
pl.seed_everything(seed=42)


Global seed set to 42


42

In [11]:
# Data-loading configuration shared by the cells below.
# The dataset lives in a "cifar" sub-directory of TORCH_HOME (current directory if unset).
DATASET_PATH = os.path.join(os.getenv("TORCH_HOME", "./"), "cifar")
BATCH_SIZE = 128   # samples per mini-batch
NUM_WORKERS = 4    # DataLoader worker processes

In [12]:
class CIFARDataModule(LightningDataModule):
    """LightningDataModule providing CIFAR-10 train / validation / test loaders.

    The 50,000 training images are split 45,000 / 5,000 into a training and a
    validation subset. Training samples are augmented (random horizontal flip
    and random resized crop); validation and test samples are only converted
    to tensors and normalised.

    Args:
        data_dir: Directory that holds (or will hold) the CIFAR-10 files.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of worker processes used by every dataloader.
    """

    # Per-channel mean / std handed to ``transforms.Normalize``
    # (presumably the CIFAR-10 training-set statistics -- TODO confirm).
    _NORM_MEAN = [0.49139968, 0.48215841, 0.44653091]
    _NORM_STD = [0.24703223, 0.24348513, 0.26158784]

    def __init__(
        self,
        data_dir: str,
        batch_size: int,
        num_workers: int
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # NOTE(review): Lightning 1.5 reports `test_transforms` as a deprecated
        # DataModule property (see the warning in the notebook output). The
        # attribute names are kept so existing callers keep working.
        # Deterministic pipeline for evaluation data.
        self.test_transforms = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(self._NORM_MEAN, self._NORM_STD),
            ]
        )
        # Augmenting pipeline for training data.
        self.train_transforms = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
                transforms.ToTensor(),
                transforms.Normalize(self._NORM_MEAN, self._NORM_STD),
            ]
        )

    def prepare_data(self) -> None:
        """Instantiate both CIFAR-10 splits once so a missing dataset fails early.

        ``download=False`` is kept from the original: nothing is fetched here,
        the files are expected to already be present in ``data_dir``.
        """
        CIFAR10(self.data_dir, train=True, download=False)
        CIFAR10(self.data_dir, train=False, download=False)

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the datasets needed for ``stage``.

        Args:
            stage: ``'fit'`` / ``'validate'`` build the train and validation
                subsets, ``'test'`` builds the test set, ``None`` builds all.
        """
        if stage in ('fit', 'validate') or stage is None:
            # Two views of the same training images: one augmented, one not.
            augmented = CIFAR10(self.data_dir, train=True, transform=self.train_transforms)
            plain = CIFAR10(self.data_dir, train=True, transform=self.test_transforms)
            # A single random_split decides which indices go where, so the two
            # subsets stay disjoint and the split is drawn exactly as before.
            self.train_data, val_split = random_split(augmented, [45000, 5000])
            # Validation reuses the held-out indices on the un-augmented view;
            # previously validation images were randomly flipped / cropped.
            self.val_data = torch.utils.data.Subset(plain, val_split.indices)
        if stage == 'test' or stage is None:
            self.test_data = CIFAR10(self.data_dir, train=False, transform=self.test_transforms)

    def data_example(self):
        """Run ``setup()`` for all stages and return the training subset."""
        #self.prepare_data()
        self.setup()
        return self.train_data

    def train_dataloader(self):
        """Shuffled training loader; the last incomplete batch is dropped."""
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Validation loader: fixed order, every sample is kept."""
        # Shuffling brings nothing for evaluation, so it is disabled here.
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        """Test loader: fixed order, every sample is kept."""
        # drop_last=True previously discarded the trailing partial batch, so
        # test metrics were computed on an incomplete test set.
        return DataLoader(
            self.test_data,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            pin_memory=True,
            num_workers=self.num_workers,
        )

In [13]:
# Build the data module from the notebook-level configuration constants.
dm = CIFARDataModule(
    data_dir=DATASET_PATH,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [14]:
# Set up the data module and grab the training subset so some examples can be visualised.
train_set = dm.data_example()

  "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7."


In [15]:
# Sanity check: build the training DataLoader (relies on data_example() above having run setup()).
dm.train_dataloader()

<torch.utils.data.dataloader.DataLoader at 0x7f383fe6b6a0>