# Modules

> Neural net modules

In [None]:
#| default_exp modules

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

## Modules

In [None]:
#| export

# python
from typing import List, Optional, Dict, Any

# torch
import torch.nn as nn
import torch

# lightning
from pytorch_lightning import LightningDataModule


In [None]:
#| export
class Encoder(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping inputs to a latent code.

    With the default sizes it maps 784 features (a flattened 28x28 input)
    to a 3-dimensional code.
    """

    def __init__(
        self,
        in_features: int = 28 * 28, # size of the last dimension of the input
        hidden_features: int = 64, # width of the hidden layer
        latent_features: int = 3, # size of the produced code
    ) -> None:
        super().__init__()
        # Kept as one nn.Sequential stored under `l1` so the state_dict keys
        # (l1.0.weight, l1.0.bias, l1.2.weight, l1.2.bias) stay unchanged.
        self.l1 = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, latent_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the code for `x`, whose last dimension must be `in_features`."""
        return self.l1(x)


class Decoder(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping a latent code to an output.

    With the default sizes it maps a 3-dimensional code to 784 features
    (a flattened 28x28 output).
    """

    def __init__(
        self,
        latent_features: int = 3, # size of the last dimension of the input code
        hidden_features: int = 64, # width of the hidden layer
        out_features: int = 28 * 28, # size of the produced output
    ) -> None:
        super().__init__()
        # Kept as one nn.Sequential stored under `l1` so the state_dict keys
        # (l1.0.weight, l1.0.bias, l1.2.weight, l1.2.bias) stay unchanged.
        self.l1 = nn.Sequential(
            nn.Linear(latent_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output for `x`, whose last dimension must be `latent_features`."""
        return self.l1(x)

### Usage

In [None]:
# Push a random batch of 10 flattened 28x28 inputs through the encoder
# and show the shape of the result.
enc = Encoder()
batch = torch.rand(10, 28 * 28)
encoded = enc(batch)
print(encoded.shape)

torch.Size([10, 3])


In [None]:
# Feed the codes produced by the encoder cell back through the decoder
# and show the shape of the result.
dec = Decoder()
decoded = dec(encoded)
print(decoded.shape)

torch.Size([10, 784])


## DataModule skeleton

In [None]:
class DataModuleSkeleton(LightningDataModule):
    """Template `LightningDataModule` to copy when wiring up a new dataset.

    `prepare_data` and `setup` are intentionally empty here, so `data_train`,
    `data_val` and `data_test` stay `None` until a concrete implementation
    assigns them. The three dataloader methods wrap those attributes using
    the hyperparameters captured in `__init__`.
    """

    def __init__(
        self,
        data_dir: str = "~/Data/", # path to source data dir
        train_val_test_split:List[float] = [0.8, 0.1, 0.1], # train, val, test fractions; must sum to 1.0 (only read here, never mutated)
        batch_size: int = 64, # size of compute batch
        num_workers: int = 0, # number of worker processes used for data loading; 0 loads in the main process
        pin_memory: bool = False, # let the DataLoader use page-locked memory, which speeds up host-to-GPU transfer
        persistent_workers: bool = False # keep worker processes alive between epochs; DataLoader requires num_workers > 0 for this
    ) -> None:
        super().__init__()
        self.save_hyperparameters(logger=False) # can access inputs with self.hparams
        # NOTE(review): `transforms` is not imported in the visible part of this
        # file; presumably it is `torchvision.transforms` -- confirm it is in
        # scope before instantiating this class.
        self.transforms = transforms.Compose([transforms.ToTensor()])
        self.data_train: Optional[torch.utils.data.Dataset] = None
        self.data_val: Optional[torch.utils.data.Dataset] = None
        self.data_test: Optional[torch.utils.data.Dataset] = None

        # Compare with a tolerance: in floating point 0.7 + 0.2 + 0.1 evaluates
        # to 0.9999999999999999, so an exact `!= 1.0` test rejects valid splits.
        if abs(sum(train_val_test_split) - 1.0) > 1e-6:
            # ValueError is a subclass of Exception, so existing handlers still match.
            raise ValueError('split percentages should sum up to 1.0')

    def prepare_data(self) -> None:
        """Placeholder: does nothing in this skeleton."""
        pass

    def setup(self, stage: Optional[str]=None)->None:
        """Placeholder: does nothing in this skeleton, the datasets stay `None`."""
        pass

    def _make_dataloader(
        self,
        dataset: Optional[torch.utils.data.Dataset], # dataset to wrap
        shuffle: bool, # whether the loader reshuffles the data
    ) -> torch.utils.data.DataLoader:
        """Build a DataLoader over `dataset` from the saved hyperparameters."""
        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
            persistent_workers=self.hparams.persistent_workers
        )

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        """Return a shuffling DataLoader over `self.data_train`."""
        return self._make_dataloader(self.data_train, shuffle=True)

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """Return a non-shuffling DataLoader over `self.data_val`."""
        return self._make_dataloader(self.data_val, shuffle=False)

    def test_dataloader(self) -> torch.utils.data.DataLoader:
        """Return a non-shuffling DataLoader over `self.data_test`."""
        return self._make_dataloader(self.data_test, shuffle=False)

    def teardown(self, stage: Optional[str] = None) -> None:
        """Clean up after fit or test."""
        pass

    def state_dict(self) -> Dict[str, Any]:
        """Extra things to save to checkpoint."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Things to do when loading checkpoint."""
        pass

In [None]:
#| hide
# Run nbdev's notebook-to-module export.
import nbdev

nbdev.nbdev_export()