# core

> Image dataset utilities: a base `ImageDataset` class with plotting helpers and an `MNISTDataset` wrapper around torchvision's MNIST.

In [None]:
#| default_exp core

In [11]:
#| hide
from nbdev.showdoc import *

import torch
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, random_split, Dataset


from typing import List, Tuple, Union, Optional
import os



In [6]:
#| export
class ImageDataset(Dataset):
    "Base class for image datasets providing visualization of (image, label) samples"

    def __init__(self):
        # NOTE(review): `logger` (and `plt`, used below) are not defined in the
        # cells visible here -- confirm they are set up before this class is used.
        logger.info("ImageDataset: init")
        super().__init__()

    def show_idx(self,
            index:int, # Index of the (image,label) sample to visualize
            dims:Tuple[int,int] = (28,28) # (height, width) the image is reshaped to before plotting
        ):
        "Display the image stored at `index`, with its label as the plot title"
        X, y = self[index]
        plt.figure(figsize = (1, 1))
        # detach/cpu so tensors that require grad or live on an accelerator can be converted
        plt.imshow(X.detach().cpu().numpy().reshape(dims[0], dims[1]), cmap='gray')
        plt.title(f"Label: {int(y)}")
        plt.show()

    @staticmethod
    def show_grid(
            imgs: List[torch.Tensor], # python list of images dim (C,H,W)
            save_path=None, # path where image can be saved
            dims:Tuple[int,int] = (28,28) # (height, width) each image is reshaped to before plotting
        ):
        "Display a list of single-channel images side by side in one row"
        # a single image is accepted and treated as a one-element list
        if not isinstance(imgs, list):
            imgs = [imgs]
        # squeeze=False keeps `axs` 2-D even when there is only one image
        _, axs = plt.subplots(ncols=len(imgs), squeeze=False)
        for ax, img in zip(axs[0], imgs):
            ax.imshow(img.detach().cpu().numpy().reshape(dims[0], dims[1]))
            # hide ticks and tick labels: only the pixels are of interest
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        if save_path:
            plt.savefig(save_path)

    def show_random(
            self,
            n:int=3, # number of images to display
            dims:Tuple[int,int] = (28,28) # (height, width) of each image
        ):
        "Display a grid of `n` images drawn at random (with replacement) from the dataset"
        # .tolist() yields plain ints, so `__getitem__` never receives a 0-dim tensor
        indices = torch.randint(0, len(self), (n,)).tolist()
        images = [self[i][0].reshape(dims[0], dims[1]) for i in indices]
        # forward `dims`; otherwise show_grid reshapes to its own (28,28) default
        self.show_grid(images, dims=dims)
        

In [7]:
#| export 

class MNISTDataset(ImageDataset):
    "MNIST digit dataset"

    def __init__(
        self,
        data_dir:str='../data/image', # path where data is saved
        train:bool = True, # train or test dataset
        transform:transforms.Compose=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]) # transform applied to every image; the default converts to a tensor, then normalizes
    ):
        # make sure the download target exists before MNIST writes into it
        os.makedirs(data_dir, exist_ok=True)
        super().__init__()
        logger.info("MNISTDataset: init")

        # wrapped torchvision dataset; all indexing is delegated to it
        self.ds = MNIST(
            data_dir,
            train = train,
            transform=transform,
            download=True
        )

    def __len__(self) -> int: # length of dataset
        return len(self.ds)

    def __getitem__(self, idx # index into the dataset
                    ) -> Tuple[torch.Tensor, int]: # x image data, y digit label
        # index once: each lookup runs the transform, so indexing twice doubles the work
        x, y = self.ds[idx]
        return x, y

    def train_dev_split(
            self,
            ratio:float, # fraction of samples assigned to the train split, in [0, 1]
        ) -> Tuple[Dataset, Dataset]: # (train, dev) random subsets of the wrapped dataset
        "Randomly split the wrapped dataset into train and dev subsets"
        # a ratio outside [0, 1] would produce a negative split size
        if not 0 <= ratio <= 1:
            raise ValueError(f"ratio must be in [0, 1], got {ratio}")

        train_set_size = int(len(self.ds) * ratio)
        valid_set_size = len(self.ds) - train_set_size

        # split the wrapped dataset into two
        train_set, valid_set = random_split(self.ds, [train_set_size, valid_set_size])
        # TODO: wrap the subsets in ImageDataset to allow for drawing
        return train_set, valid_set



NameError: name 'transforms' is not defined

In [None]:
#| export
def foo():
    "Placeholder function that does nothing"
    return None

In [None]:
#| hide
# Run nbdev's notebook export step (presumably writes the `#| export` cells to the library module).
import nbdev
nbdev.nbdev_export()