In [16]:
"""
Imports
"""

import os
from dataclasses import dataclass

import lightning as L
import lightning.pytorch as pl
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from lightning.pytorch.callbacks import EarlyStopping, StochasticWeightAveraging, LearningRateMonitor, ModelCheckpoint, TQDMProgressBar
from lightning.pytorch.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split

# Let float32 matrix multiplications use faster, lower-precision internal
# arithmetic on hardware that supports it: a small accuracy cost for speed.
torch.set_float32_matmul_precision("medium")

In [17]:
@dataclass
class TrainingConfig:
    """Hyper-parameters and run settings for this notebook.

    Every attribute is type-annotated so that ``@dataclass`` registers it as a
    real field. Unannotated class attributes are ignored by ``@dataclass``, so
    without annotations none of these could be overridden through the
    constructor. Individual values can be changed per run, e.g.
    ``TrainingConfig(learning_rate=1e-4)``, and they show up in ``repr``/``==``.
    """

    # NOTE(review): the images in this notebook are reshaped to 28x28 and the
    # Resize transform is commented out, so this value is not applied to the
    # data -- confirm which resolution is intended.
    image_size: int = 64  # the generated image resolution

    train_batch_size: int = 32
    val_batch_size: int = 32
    eval_batch_size: int = 4  # how many images to sample during evaluation

    max_epochs: int = 15
    check_val_every_n_epoch: int = 1
    accumulate_grad_batches: int = 2
    learning_rate: float = 1e-6

    output_dir: str = "lightning"

    seed: int = 10  # intended random seed for reproducible runs


config = TrainingConfig()

In [18]:
# Forward transform: maps a float image tensor with values in [0, 1] onto
# [-1, 1] via (x - 0.5) / 0.5. T.Normalize needs a float tensor shaped
# (..., C, H, W), so raw pixels must be converted to that form before this
# transform runs.
transform = T.Compose([
    T.Normalize([0.5], [0.5]),
])

# Inverse transform: maps [-1, 1] back to [0, 1] via (x + 1) / 2, then to a PIL
# image. Values are clamped to [0, 1] first. ToPILImage multiplies float
# tensors by 255 and casts to uint8, so out-of-range values (e.g. from a
# generative model's samples) would otherwise wrap around instead of saturating.
reverse_transform = T.Compose([
    T.Normalize([-0.5 / 0.5], [1 / 0.5]),
    T.Lambda(lambda t: t.clamp(0.0, 1.0)),
    T.ToPILImage(),
])

In [19]:
class MNISTDataset(torch.utils.data.Dataset):
    """Map-style dataset over MNIST digits that are already loaded in memory.

    Args:
        labels: one label per sample, looked up by position. This can be a
            pandas Series (as built by ``MNISTDataModule``) or any sequence.
        images: pixel data with one sample per entry along the first axis,
            looked up by position. It can be an array of shape (N, H, W) (the
            data module passes (N, 28, 28)), or a DataFrame / 2-D array holding
            one flattened square image per row.
        transform: optional callable applied to each image tensor.
        return_label: if True, ``__getitem__`` returns ``(image, label)``.
            The default returns only the image.

    Each image is produced as a float32 tensor of shape (1, H, W) with its raw
    pixel values divided by 255. For 8-bit intensities that gives [0, 1], which
    ``T.Normalize([0.5], [0.5])`` then maps onto [-1, 1].
    """

    def __init__(self, labels, images, transform=None, return_label: bool = False):
        super().__init__()
        self.labels = labels
        self.images = images
        assert len(self.labels) == len(self.images), "labels and images differ in length"

        self.transform = transform
        self.return_label = return_label

    def __len__(self):
        return len(self.labels)

    @staticmethod
    def _at(container, index):
        """Positional lookup that works for pandas objects and plain sequences."""
        # pandas containers must use .iloc: after train_test_split their index
        # labels are shuffled, so container[index] would be a label lookup.
        return container.iloc[index] if hasattr(container, "iloc") else container[index]

    def __getitem__(self, index):
        pixels = self._at(self.images, index)
        if hasattr(pixels, "to_numpy"):  # a DataFrame row arrives as a Series
            pixels = pixels.to_numpy()
        image = torch.as_tensor(pixels, dtype=torch.float32)
        if image.ndim == 1:
            # A flattened row: recover the square image it encodes.
            side = int(round(image.numel() ** 0.5))
            image = image.reshape(side, side)
        # Add the channel axis that torchvision transforms expect, and scale the
        # raw intensities to [0, 1]. Assumes 0-255 pixel values -- TODO confirm
        # for any data source other than the MNIST CSVs.
        image = image.unsqueeze(0) / 255.0

        if self.transform is not None:
            image = self.transform(image)

        if self.return_label:
            return image, self._at(self.labels, index)
        return image

In [20]:
class MNISTDataModule(L.LightningDataModule):
    """Lightning data module for MNIST digits stored as CSV files.

    - "fit" stage: ``<data_dir>/train.csv`` is split into train/validation sets.
    - "predict" stage: ``<data_dir>/test.csv`` backs the predict dataset.

    Each CSV row is one sample: an optional ``label`` column plus 28*28 pixel
    columns. Batch sizes and the split seed are read from the module-level
    ``config``.

    Args:
        train_transform: transform given to the training dataset.
        test_transform: transform given to the validation and predict datasets.
        data_dir: directory containing the CSV files (default ``"data"``).
        train_size: fraction of ``train.csv`` used for training; the remainder
            becomes the validation set (default 0.7).
    """

    def __init__(self,
                 train_transform,
                 test_transform,
                 data_dir: str = "data",
                 train_size: float = 0.7):
        super().__init__()
        # os.cpu_count() can return None; fall back to main-process loading.
        self.num_workers = os.cpu_count() or 0

        self.train_transform = train_transform
        self.test_transform = test_transform
        self.data_dir = data_dir
        self.train_size = train_size

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    @staticmethod
    def _split_frame(frame):
        """Split a CSV frame into ``(labels, images)``, images shaped (N, 28, 28)."""
        if "label" in frame.columns:
            labels = frame["label"]
            pixels = frame.drop(columns="label")
        else:
            # A test CSV may ship without labels (e.g. Kaggle's Digit Recognizer
            # test split); use -1 placeholders so every sample still has a label.
            labels = pd.Series(-1, index=frame.index, name="label")
            pixels = frame
        images = pixels.to_numpy().reshape(len(frame), 28, 28)
        return labels, images

    def setup(self, stage: str):
        """Build the datasets needed for ``stage`` ("fit" or "predict")."""
        if stage == "fit":
            train_data = pd.read_csv(os.path.join(self.data_dir, "train.csv"))
            # A fixed random_state makes the split reproducible across runs. It
            # also makes it identical on every process (Lightning calls setup on
            # each device), so one rank's validation samples never end up in
            # another rank's training set.
            train_data, val_data = train_test_split(
                train_data,
                train_size=self.train_size,
                shuffle=True,
                random_state=config.seed,
            )

            train_labels, train_images = self._split_frame(train_data)
            val_labels, val_images = self._split_frame(val_data)

            self.train_dataset = MNISTDataset(
                labels=train_labels,
                images=train_images,
                transform=self.train_transform,
            )

            self.val_dataset = MNISTDataset(
                labels=val_labels,
                images=val_images,
                transform=self.test_transform,
            )

            print(f"Total Dataset       : {len(self.train_dataset) + len(self.val_dataset)} samples")
            print(f"Train Dataset       : {len(self.train_dataset)} samples")
            print(f"Validation Dataset  : {len(self.val_dataset)} samples")

        if stage == "predict":
            samples = pd.read_csv(os.path.join(self.data_dir, "test.csv"))
            labels, images = self._split_frame(samples)

            self.test_dataset = MNISTDataset(
                labels=labels,
                images=images,
                transform=self.test_transform,
            )

    def _loader(self, dataset, batch_size, shuffle, stage):
        """Build a DataLoader, failing clearly if ``setup(stage)`` has not run."""
        if dataset is None:
            raise RuntimeError(f"Dataset not initialised; call setup('{stage}') first.")
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # persistent_workers is only valid when worker processes exist.
            persistent_workers=self.num_workers > 0,
            pin_memory=True,
        )

    def train_dataloader(self):
        return self._loader(self.train_dataset, config.train_batch_size, True, "fit")

    def val_dataloader(self):
        return self._loader(self.val_dataset, config.val_batch_size, False, "fit")

    def predict_dataloader(self):
        # NOTE(review): reuses the validation batch size for prediction -- confirm.
        return self._loader(self.test_dataset, config.val_batch_size, False, "predict")

In [21]:
# Training and evaluation share one transform: there is no augmentation,
# only the [-1, 1] normalisation.
dm = MNISTDataModule(
    test_transform=transform,
    train_transform=transform,
)

In [23]:
dm.setup("fit")  # reads data/train.csv and builds the train/validation datasets

yo
Total Dataset       : 42000 samples
Train Dataset       : 29399 samples
Validation Dataset  : 12601 samples


In [24]:
# Shuffled training DataLoader; the next cell pulls one batch from it.
dl = dm.train_dataloader()

In [None]:
# Fetch a single training batch to sanity-check the data pipeline end to end.
x = next(iter(dl))