# Note on license:
This notebook uses examples from the official Lightning repo, which is licensed under Apache 2.0. In compliance with the Apache license, any reused code is relicensed under the license in this project (as of September 2023, the MIT license), but I list modifications to the original code here:
- Refactor PyTorch Lightning example code so it can be used with an adapter class from this project.

In [1]:
# Don't use strict mode for type checking yet
%nb_mypy mypy-options --pretty

In [2]:
from abc import abstractmethod, abstractproperty, ABC

import pandas as pd
import numpy as np
from torch.utils.data import Dataset as TorchDataset

from utils.why_not_generic_data_containers import DataSetInterface

# Data Set
# ========

class ImageDataSetImplementation(DataSetInterface):
    """
    In this simple example, we store the data as a TorchDataset, so we don't 
    have to do any transformation upon instantiation or retrieval.
    """
    def __init__(self, data: TorchDataset):
        self.data = data
    
    @classmethod
    def from_torch(cls, data):
        return cls(data=data)

    def to_torch(self) -> TorchDataset:
        return self.data


# Container
# =========

class DataContainer():
    """
    This class is using the simplest possible way to define a container: It 
    simply has a property for each of the three subsets. Note that we are not 
    implementing generics yet to distinguish containers for different kinds of 
    data sets.
    (Note also that it does not have a method implemented to retrieve the 
    complete set, because this is not necessary for our current testing 
    purposes.) 
    """
    def __init__(self, train: DataSetInterface, val: DataSetInterface, test: DataSetInterface):
        self._train = train
        self._val = val
        self._test = test

    @property
    def train(self):
        return self._train
    
    @property
    def val(self):
        return self._val
    
    @property
    def test(self):
        return self._test

This is where our simple design, in which the container is merely an aggregation of the three subsets, runs into problems. Because this design makes the data *sets* the place where we store the logic for returning data in a specific format, it does not work out of the box with the concept of **data loaders** in libraries such as PyTorch. *Data loaders are classes similar to our data containers* in that they allow retrieving the different subsets; they differ, however, in that they are specific to a particular ML framework, and thus expect each subset to be returned in the framework-specific format when a specifically named method is called. The problem is that it is not enough for our data container to be able to *return* the data in the required format; it would also need to follow the specific *interface* of the data loader. For example, the framework expects to get the training data by calling train_dataloader(), whereas with our data container we would get it by calling something like train.to_torch().
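
The mismatch can be made concrete with a small, framework-free sketch. Everything in it is a stand-in invented for illustration: run_training plays the role of a framework that looks up a train_dataloader() method, and FakeDataSet and FakeContainer mimic our data set and container, with plain lists in place of Torch data sets.

```python
class FakeDataSet:
    """Hypothetical stand-in for a DataSetInterface implementation."""

    def __init__(self, items):
        self._items = list(items)

    def to_torch(self):
        # The real class would return a torch Dataset; a list suffices here.
        return self._items


class FakeContainer:
    """Hypothetical stand-in for DataContainer: one attribute per subset."""

    def __init__(self, train, val, test):
        self.train, self.val, self.test = train, val, test


def run_training(data_module):
    """Stand-in for a framework that calls a specifically named method."""
    return list(data_module.train_dataloader())


container = FakeContainer(
    train=FakeDataSet([1, 2, 3]),
    val=FakeDataSet([4]),
    test=FakeDataSet([5]),
)

# The container can *return* the data in the required format ...
train_items = container.train.to_torch()

# ... but it does not *expose* the method name the framework asks for.
try:
    run_training(container)
    mismatch = None
except AttributeError as error:
    mismatch = str(error)

print(train_items)
print(mismatch)
```

The failure is an AttributeError about the missing train_dataloader attribute, which is exactly the gap between returning the right format and following the expected interface.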

Of course, this problem is not insurmountable, because we can simply leverage the *adapter* design pattern to translate one interface into another. This could look as follows:

In [3]:
from torch.utils.data import DataLoader as TorchDataLoader
from lightning.pytorch import LightningDataModule, LightningModule

class LightningDataloaderAdapter(LightningDataModule):
    """
    This adapter class takes an instance of our data container and converts it 
    into a LightningDataModule.
    """
    def __init__(self, data_container, batch_size) -> None:
        super().__init__()
        self.data_container = data_container
        self.batch_size = batch_size

    def train_dataloader(self):
        return TorchDataLoader(
            dataset=self.data_container.train.to_torch(),
            batch_size=self.batch_size
        )

    def val_dataloader(self):
        return TorchDataLoader(
            dataset=self.data_container.val.to_torch(), 
            batch_size=self.batch_size
        )

    def test_dataloader(self):
        return TorchDataLoader(
            dataset=self.data_container.test.to_torch(), 
            batch_size=self.batch_size
        )

    def predict_dataloader(self):
        # Not needed yet. Lightning calls this hook without extra arguments.
        pass

This data loader adapter can then be used by our OO_ML estimators, as shown in the cell below: our estimator receives the data as a data container and uses this adapter to convert the container into a data loader. The key line is this one in \_\_init\_\_():
```
self.data_loader = LightningDataloaderAdapter(data_container=data_container, ...)        
```

In [4]:
from pydantic import BaseModel
import lightning as L

from utils.why_not_generic_data_containers import ModelConfig, EstimatorInterface

class PytorchLightningAdapter(EstimatorInterface):
    """
    This adapter class takes a PyTorch Lightning classifier and converts its 
    interface to our OO_ML estimator interface.
    """
    def __init__(
        self, 
        Classifier: type[LightningModule],
        data_container: DataContainer,
        config: ModelConfig,
    ) -> None:
        self.config = config
        self.classifier = Classifier()
        
        # Convert data container to lightning data loader
        # ***********************************************
        self.data_loader =  LightningDataloaderAdapter(
            data_container=data_container,
            batch_size=config.batch_size,    
        )
        # ***********************************************
    
    def fit(self):
        # Use the config stored on the instance, not a global variable
        trainer = L.Trainer(fast_dev_run=self.config.fast_dev_run)
        trainer.fit(
            model=self.classifier,
            datamodule=self.data_loader,
        )

However, while it is good to have the option of leveraging the adapter pattern, it is better to avoid it if possible. It turns out that this is relatively easy: we simply make the data container (rather than the data set) the place that returns data in specific formats. This allows us to give the data container a to_lightning_dataloader() method, making the logic more straightforward.

In [5]:
from torchvision import transforms

# First we need to define the data loader class that we want to return. It will 
# contain a constructor that will allow instantiating it from a data container.
class LightningDataLoader(LightningDataModule):
    def __init__(self, train_data, val_data, test_data, transforms, batch_size) -> None:
        super().__init__()
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.batch_size = batch_size
        self.transforms = transforms        

    # This is the constructor we will use
    @classmethod
    def from_data_container(cls, data_container, batch_size: int, transforms):
        return cls(
            train_data=data_container.train.to_torch(),
            val_data=data_container.val.to_torch(),
            test_data=data_container.test.to_torch(),
            batch_size=batch_size,
            transforms=transforms,
        )

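    def setup(self, stage: str) -> None:
        # Todo: Split data here. Lightning passes the stage name
        # (e.g. "fit" or "test") when it calls this hook.
        pass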
    
    def train_dataloader(self):
        return TorchDataLoader(
            dataset=self.train_data,
            batch_size=self.batch_size
        )

    def val_dataloader(self):
        return TorchDataLoader(
            dataset=self.val_data, 
            batch_size=self.batch_size
        )

    def test_dataloader(self):
        return TorchDataLoader(
            dataset=self.test_data, 
            batch_size=self.batch_size
        )

    def predict_dataloader(self):
        # Not needed yet. Lightning calls this hook without extra arguments.
        pass


class DataContainer_v2():
    def __init__(self, train: DataSetInterface, val: DataSetInterface, test: DataSetInterface):
        self._train = train
        self._val = val
        self._test = test

    @property
    def train(self):
        return self._train
    
    @property
    def val(self):
        return self._val
    
    @property
    def test(self):
        return self._test
    
    def to_lightning_dataloader(self, batch_size) -> LightningDataLoader:
        return LightningDataLoader.from_data_container(
            data_container=self,
            batch_size=batch_size,
            transforms=None
        )
    

In [6]:
from pydantic import BaseModel
import lightning as L

from utils.why_not_generic_data_containers import ModelConfig, EstimatorInterface

class PytorchLightningAdapter_v2(EstimatorInterface):
    """
    This adapter class takes a PyTorch Lightning classifier and converts its 
    interface to our OO_ML estimator interface.
    """
    def __init__(
        self, 
        Classifier: type[LightningModule],
        data_container: DataContainer_v2,  # Note we are using the new container
        config: ModelConfig,
    ) -> None:
        self.config = config
        self.classifier = Classifier()

        # Convert data container to lightning data loader
        # ************************************************
        self.data_loader = data_container.to_lightning_dataloader(
            batch_size=config.batch_size,    
        )
        # ************************************************
        
    
    def fit(self):
        # Use the config stored on the instance, not a global variable
        trainer = L.Trainer(fast_dev_run=self.config.fast_dev_run)
        trainer.fit(
            model=self.classifier,
            datamodule=self.data_loader,
        )

# Conclusion
Either way works, but I think that adding the to_lightning_dataloader() method to the data container is slightly preferable because:
- It is less complex because it requires using only one rather than two adapters. (However, since the data loader adapter would be used in the lightning adapter, most of the added complexity is at least hidden from the library's user.)
- It may lead to better code organization, because the code for constructing the Lightning data loader is located near the data container code.

However, I'm still somewhat on the fence because of the following disadvantages of this solution:
- The data container now needs to know about the data set's internals. We need to be careful this doesn't lead to any problems, such as circular imports.
- For every library except PyTorch Lightning, we have separate classes for the training, validation, and test sets, so the conversion to the data format required by a specific library (e.g., to_numpy_array) is most naturally handled at the data *set* level, not the data *container* level. This raises the question of whether we want to move this format conversion from the data set to the container level for all ML libraries, or whether we are okay with the slight inconsistency of handling it at the container level only for Lightning. At this point I'm not ready to make a decision, but I will experiment with the implications of both choices.
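
One possible middle ground, sketched below, keeps the format conversion at the data set level and gives the container a convenience method that only delegates. The names ListDataSet, DelegatingContainer, to_list and to_lists are hypothetical, and plain Python lists stand in for library-specific formats such as NumPy arrays.

```python
class ListDataSet:
    """Hypothetical data set that owns its own format conversion."""

    def __init__(self, rows):
        self._rows = list(rows)

    def to_list(self):
        # Set-level conversion: return a copy in the target format.
        return list(self._rows)


class DelegatingContainer:
    """Hypothetical container that converts by delegating to each subset."""

    def __init__(self, train, val, test):
        self.train, self.val, self.test = train, val, test

    def to_lists(self):
        # Container-level conversion that relies only on the public
        # to_list() method of each subset, not on their internals.
        return {
            "train": self.train.to_list(),
            "val": self.val.to_list(),
            "test": self.test.to_list(),
        }


container = DelegatingContainer(
    train=ListDataSet([1, 2]),
    val=ListDataSet([3]),
    test=ListDataSet([4]),
)
converted = container.to_lists()
print(converted)
```

With this kind of delegation the container depends only on the data set's public interface, which may limit the coupling described in the first bullet above.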

In any case, it is not ideal that we are letting a single library (PyTorch Lightning) determine our overall design. However, in the end I still think this is the right decision because:
- Lightning is one of the best libraries for many deep learning use cases;
- [Even when using plain PyTorch, you can still use Lightning's data module](https://pytorch-lightning.readthedocs.io/en/0.9.0/datamodules.html#datamodules-without-lightning) for better encapsulation;
- it is possible that we will encounter other libraries using this same design for the data loader;
- it is possible that libraries that do not currently raise this problem may add the same ability for a higher-level data loader as Lightning.
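
The second bullet can be illustrated without any framework: a data module is, in essence, an object whose hooks you can call yourself instead of handing it to a trainer. The sketch below uses a hypothetical stand-in class, TinyDataModule, and not Lightning's actual LightningDataModule; only the method names setup and train_dataloader mirror the hooks used earlier in this notebook.

```python
class TinyDataModule:
    """Hypothetical stand-in that mimics the shape of a data module."""

    def __init__(self, samples, batch_size):
        self.samples = list(samples)
        self.batch_size = batch_size
        self.train_samples = []

    def setup(self, stage):
        # A real data module would split or transform data here.
        if stage == "fit":
            self.train_samples = self.samples

    def train_dataloader(self):
        # Yield consecutive batches of at most batch_size samples.
        for start in range(0, len(self.train_samples), self.batch_size):
            yield self.train_samples[start:start + self.batch_size]


# Without a trainer, we call the hooks by hand and write our own loop.
data_module = TinyDataModule(samples=range(5), batch_size=2)
data_module.setup(stage="fit")

batches = [batch for batch in data_module.train_dataloader()]
print(batches)
```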

# Appendix: Model Run
The code below is not relevant to the main point of this notebook, but it is useful for verifying that everything is working. Note, however, that the **way of interacting with the model stays the same, independent of which container implementation we use.**
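
Why the interaction stays the same can be shown in a stripped-down, framework-free sketch: both estimator variants hide how the loader is built behind the same fit() method. All classes below are hypothetical stand-ins for the notebook's classes, and summing a list replaces actual training.

```python
from abc import ABC, abstractmethod


class Loader:
    """Hypothetical stand-in for a framework data module."""

    def __init__(self, train):
        self._train = train

    def train_dataloader(self):
        return self._train


class ContainerV1:
    def __init__(self, train):
        self.train = train


class ContainerV2(ContainerV1):
    def to_loader(self):
        # Variant 2: the container builds the loader itself.
        return Loader(self.train)


class Estimator(ABC):
    @abstractmethod
    def fit(self):
        ...


class EstimatorV1(Estimator):
    def __init__(self, container):
        # Variant 1: the estimator wraps the container from the outside.
        self.loader = Loader(container.train)

    def fit(self):
        return sum(self.loader.train_dataloader())


class EstimatorV2(Estimator):
    def __init__(self, container):
        self.loader = container.to_loader()

    def fit(self):
        return sum(self.loader.train_dataloader())


# The calling code is identical for both variants.
results = [
    EstimatorV1(ContainerV1([1, 2, 3])).fit(),
    EstimatorV2(ContainerV2([1, 2, 3])).fit(),
]
print(results)
```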

In [7]:
# Create example data
# ===================

from datetime import datetime 

from torch.utils.data import random_split
from torchvision.datasets import MNIST

from utils.why_not_generic_data_containers import LitClassifier

DATA_DIR = './data'

mnist_train_and_val = MNIST(
    root=DATA_DIR, 
    train=True, 
    download=True,
    transform=transforms.ToTensor()
)
mnist_train, mnist_val = random_split(
    dataset=mnist_train_and_val, 
    lengths=[.9, .1]
)
mnist_test = MNIST(
    root=DATA_DIR, 
    train=False, 
    download=True, 
    transform=transforms.ToTensor()
)

mnist_train = ImageDataSetImplementation.from_torch(mnist_train)
mnist_val = ImageDataSetImplementation.from_torch(mnist_val)
mnist_test = ImageDataSetImplementation.from_torch(mnist_test)

# v1 of container (requires data loader adapter)
mnist_container = DataContainer(
    train=mnist_train,
    val=mnist_val,
    test=mnist_test
)

# v2 of container with to_lightning_dataloader() method
mnist_container_v2 = DataContainer_v2(
    train=mnist_train,
    val=mnist_val,
    test=mnist_test
)


config = ModelConfig(
    batch_size=64,
    fast_dev_run=True,
)

### Version 1: Using adapter for data loader

In [8]:
if __name__ == "__main__":
    model = PytorchLightningAdapter(
        Classifier=LitClassifier,
        data_container=mnist_container,
        config=config,
    )
    model.fit()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.

  | Name     | Type     | Params
--------------------------------------
0 | backbone | Backbone | 101 K 
--------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


### Version 2: Giving the container a to_lightning_dataloader method

In [9]:
if __name__ == "__main__":
    model_v2 = PytorchLightningAdapter_v2(
        Classifier=LitClassifier,
        data_container=mnist_container_v2,
        config=config,
    )
    model_v2.fit()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.

  | Name     | Type     | Params
--------------------------------------
0 | backbone | Backbone | 101 K 
--------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


In [10]:
print(datetime.now())

2023-09-14 09:23:04.533447
