# Understanding MONAI Datasets, Caching, and Networks
---

Users often need to train a model for many (potentially thousands of) epochs to reach the desired model quality. A native PyTorch implementation may repeatedly load the data and run the same preprocessing steps in every epoch, which is time-consuming and unnecessary, especially when the medical image volumes are large. Dataset caching reduces the time your system spends loading and preprocessing this data, which shortens overall training time.

Network functionality represents a significant design opportunity for MONAI. PyTorch is largely unopinionated about how networks are defined: it provides `Module` as a base class for creating a network, along with a few methods that must be implemented (chiefly `forward`). Still, there is no prescribed pattern, and there is little helper functionality for initializing networks.

This gap leaves plenty of room for defining beneficial "best practice" patterns for constructing new networks in MONAI. Trivial, inflexible network implementations are easy enough to write, but we can give users a toolset that makes it much easier to build well-engineered, flexible networks, and demonstrate its value by using it in the networks that we build.

## MONAI Datasets, Caching, and Networks

To help you understand more about MONAI Datasets, Caching, and Networks, this guide answers five key questions:

1. **What is a MONAI Dataset?**
2. **What is Dataset Caching and how do I use it?**
3. **What common datasets are provided by MONAI?**
4. **How do you use MONAI Layers?**
5. **How do you use these flexible layers to create a network?**


Let's get started by importing our dependencies.  

In [None]:
import time
import torch

import monai
from monai.data import Dataset, DataLoader, CacheDataset, PersistentDataset, SmartCacheDataset, ZipDataset
from monai.apps import DecathlonDataset
from monai.transforms import (
    MapTransform,
)

## **1. What is a MONAI Dataset?**


A MONAI Dataset is a generic dataset with `__len__` and `__getitem__` methods and an optional callable data transform that is applied when a data sample is fetched.

We'll start by creating some generic data, passing it to the `Dataset` class, and specifying `None` for the transform.

In [None]:
items = [{"data": 4}, 
         {"data": 9}, 
         {"data": 3}, 
         {"data": 7}, 
         {"data": 1},
         {"data": 2},
         {"data": 5}]
dataset = monai.data.Dataset(items, transform=None)

print(f"Length of dataset is {len(dataset)}")
for item in dataset:
    print(item)

#### Compatible with the PyTorch DataLoader

MONAI datasets are compatible with the standard PyTorch `DataLoader`. MONAI also provides its own `monai.data.DataLoader`, a subclass that adds functionality the standard `DataLoader` class does not provide.

In [None]:
for item in torch.utils.data.DataLoader(dataset, batch_size=2):
    print(item)

### Load items with a customized transform

We'll create a custom transform called `SquareIt`, which replaces the value stored under each of the given `keys` with its square. In our case, `SquareIt(keys='data')` squares the value of `x['data']`.

In [None]:
class SquareIt(MapTransform):
    def __init__(self, keys):
        super().__init__(keys)
        print(f"keys to square it: {self.keys}")

    def __call__(self, x):
        # copy the input dict so the source item is not modified in place,
        # then square the value stored under each requested key
        d = dict(x)
        for key in self.keys:
            d[key] = d[key] ** 2
        return d

square_dataset = Dataset(items, transform=SquareIt(keys='data'))
for item in square_dataset:
    print(item)

## **2. What is Dataset Caching and how do I use it?**

MONAI provides multi-threaded `CacheDataset` and `LMDBDataset` classes that accelerate these preprocessing steps during training by storing the intermediate outcomes produced before the first randomized transform in the transform chain. In MONAI's dataset experiment, enabling this feature could give up to a 10x training speedup.

<img src="cache_dataset.png" style="width: 700px;"/>
 
To demonstrate the benefit of dataset caching, we're going to construct a dataset with a slow transform. To do that, we call `time.sleep` inside the transform's `__call__` method.

In [None]:
class SlowSquare(MapTransform):
    def __init__(self, keys):
        MapTransform.__init__(self, keys)
        print(f"keys to square it: {self.keys}")

    def __call__(self, x):
        time.sleep(1.0)
        output = {key: x[key] ** 2 for key in self.keys}
        return output

square_dataset = Dataset(items, transform=SlowSquare(keys='data'))

Since each of the 7 items sleeps for 1 second, it takes about 7 seconds to go through all the items.

In [None]:
%time for item in square_dataset: print(item)

Every time we run this loop, it takes roughly 7 seconds to go through all of the items. If you were to do this for 100 epochs, you would add 700 seconds, almost 12 extra minutes, of load time to your total training loop. Let's look at ways to improve this time with caching.

### Cache Dataset

When using [CacheDataset](https://docs.monai.io/en/latest/data.html?highlight=dataset#cachedataset), caching happens when the object is initialized, so initialization is slower than for a regular dataset.

Caching the results of the non-random preprocessing transforms accelerates the training data pipeline. If a requested item is not in the cache, all transforms run normally.

In [None]:
square_cached = CacheDataset(items, transform=SlowSquare(keys='data'))

However, repeatedly fetching the items from an initialized CacheDataset is fast.

In [None]:
%timeit list(item for item in square_cached)

### Persistent Caching

[PersistentDataset](https://docs.monai.io/en/latest/data.html?highlight=dataset#persistentdataset) stores pre-computed values on disk, so it can efficiently manage dictionary-format data that is larger than memory.

The non-random transform components are computed when first used and stored in `cache_dir` for rapid retrieval on subsequent uses.

In [None]:
square_persist = monai.data.PersistentDataset(items, transform=SlowSquare(keys='data'), cache_dir="my_cache")

In [None]:
%time for item in square_persist: print(item)

When initializing the `PersistentDataset`, we passed `cache_dir="my_cache"` as the location for storing the intermediate data. Let's look at that directory.

In [None]:
!ls my_cache

On the following epochs, the dataset does not call the slow transform; it uses the cached data instead.

In [None]:
%timeit [item for item in square_persist]

Fresh dataset instances can also reuse the cached data:

In [None]:
square_persist_1 = monai.data.PersistentDataset(items, transform=SlowSquare(keys='data'), cache_dir="my_cache")
%timeit [item for item in square_persist_1]

#### Caching in action
- There's also a [SmartCacheDataset](https://docs.monai.io/en/latest/data.html#monai.data.SmartCacheDataset) that hides transform latency while using less memory.
- The dataset tutorial notebook has a working example and a comparison of the different caching mechanisms in MONAI: https://github.com/Project-MONAI/tutorials/blob/master/acceleration/dataset_type_performance.ipynb

<img src="datasets_speed.png" style="width: 700px;"/>

## **3. What common datasets are provided by MONAI?**

To help you get started quickly with popular training data in the medical domain, MONAI provides several data-specific datasets (such as `MedNISTDataset` and `DecathlonDataset`). These handle downloading from our AWS storage, extracting the data files, and generating training/evaluation items with transforms.

The [DecathlonDataset](https://docs.monai.io/en/latest/data.html?highlight=dataset#decathlon-datalist) class leverages the features described throughout this notebook: these datasets are extensions of the `CacheDataset` covered above.

In [None]:
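from monai.transforms import LoadImaged

# Without a loading transform the items are only file paths, so load the
# image and label volumes explicitly before inspecting their shapes.
dataset = monai.apps.DecathlonDataset(
    root_dir="./",
    task="Task09_Spleen",
    section="training",
    transform=LoadImaged(keys=["image", "label"]),
    download=True,
)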

In [None]:
print(dataset.get_properties("numTraining"))
print(dataset.get_properties("description"))

In [None]:
print(dataset[0]['image'].shape)
print(dataset[0]['label'].shape)

## **4. How do you use MONAI Layers?**

In [None]:
from monai.networks.layers import Conv, Act, split_args, Pool

### Convolution as an example

The [Conv](https://docs.monai.io/en/latest/networks.html#convolution) factory is indexed as `Conv[name, dimension]`. The first argument selects the layer type, either `Conv.CONV` or `Conv.CONVTRANS`, and the second must be the number of spatial dimensions. For example:

In [None]:
print(Conv[Conv.CONV, 1])
print(Conv[Conv.CONV, 2])
print(Conv[Conv.CONV, 3])
print(Conv[Conv.CONVTRANS, 1])
print(Conv[Conv.CONVTRANS, 2])
print(Conv[Conv.CONVTRANS, 3])

The factory returns the standard ("vanilla") PyTorch layer classes, so we can create instances of them by specifying the usual layer arguments:

In [None]:
print(Conv[Conv.CONV, 2](1, 4, kernel_size=3))  # Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1))
print(Conv[Conv.CONV, 3](1, 4, kernel_size=3))  # Conv3d(1, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1))

The [Act](https://docs.monai.io/en/latest/networks.html#module-monai.networks.layers.Act) factory doesn't require spatial dimension information, but the activation classes it returns accept their usual constructor arguments.

In [None]:
print(Act[Act.PRELU])
Act[Act.PRELU](num_parameters=1, init=0.1)

These could be fully specified with a tuple of `(type_name, arg_dict)`, such as `("prelu", {"num_parameters": 1, "init": 0.1})`:

In [None]:
act_name, act_args = split_args(("prelu", {"num_parameters": 1, "init": 0.1}))
Act[act_name](**act_args)

## **5. How do you use these flexible layers to create a network?**

These APIs allow for flexible network definitions. Below we'll create a class called `MyNetwork` that uses `Conv`, `Act`, and `Pool`. Like any PyTorch network, it defines an `__init__` method that builds the layers and a `forward` method that applies them.

In [None]:
class MyNetwork(torch.nn.Module):

    def __init__(self, dims=3, in_channels=1, out_channels=8, kernel_size=3, pool_kernel=2, act="relu"):
        super().__init__()
        # convolution
        self.conv = Conv[Conv.CONV, dims](in_channels, out_channels, kernel_size=kernel_size)
        # activation
        act_type, act_args = split_args(act)
        self.act = Act[act_type](**act_args)
        # pooling
        self.pool = Pool[Pool.MAX, dims](pool_kernel)

    def forward(self, x: torch.Tensor):
        x = self.conv(x)
        x = self.act(x)
        x = self.pool(x)
        return x

This network definition can be instantiated for either 2D or 3D inputs, with flexible kernel sizes. This is handy when adapting the same architecture design to different tasks, since you can switch among 2D, 2.5D, and 3D easily.

Almost all MONAI layers, blocks, and networks extend `torch.nn.Module` and follow this pattern. This keeps the implementations compatible with any PyTorch pipeline and flexible in network design. The current collection of these differentiable modules is listed at https://docs.monai.io/en/latest/networks.html.

In [None]:
# default network instance
default_net = MyNetwork()
print(default_net)
print(default_net(torch.ones(3, 1, 20, 20, 30)).shape)  # torch.Size([3, 8, 9, 9, 14])

# 2D network instance
elu_net = MyNetwork(dims=2, in_channels=3, act=("elu", {"inplace": True}))
print(elu_net)
print(elu_net(torch.ones(3, 3, 24, 24)).shape)  # torch.Size([3, 8, 11, 11])

# 3D network instance with anisotropic kernels
sigmoid_net = MyNetwork(3, in_channels=4, kernel_size=(3, 3, 1), act="sigmoid")
print(sigmoid_net)
print(sigmoid_net(torch.ones(3, 4, 30, 30, 5)).shape)  # torch.Size([3, 8, 14, 14, 2])

MONAI includes over 20 networks; you can find them listed at https://docs.monai.io/en/latest/networks.html#nets.

## **Summary**

We've covered MONAI Datasets, Caching and Networks.  Here are some key highlights:

- A MONAI Dataset is a generic dataset with `__len__` and `__getitem__` methods and an optional callable data transform that is applied when a data sample is fetched.
- You can use dataset caching to store the results of dataset transforms and speed up training. MONAI's caching options include `CacheDataset`, `PersistentDataset`, and `SmartCacheDataset`.
- MONAI provides access to some commonly used medical imaging datasets, including the `DecathlonDataset`.
- MONAI layer factories such as `Conv`, `Act`, and `Pool` resolve to standard PyTorch layers.
- MONAI layers can be used to implement a flexible network.


## **Next Steps**

In the next notebook, we cover MONAI Sliding Window Inference and Post-Processing Transforms.

You can find more information about everything covered here on our [MONAI Documentation Page](https://docs.monai.io/).  

If you're looking for more examples and tutorials, we have a repo dedicated just to that! You can find it on our [GitHub Organization Page](https://github.com/Project-MONAI/tutorials). We also have all of the videos from our first-ever MONAI Bootcamp available on our [YouTube Channel](https://www.youtube.com/c/ProjectMONAI).