# Note on license:
This notebook uses examples from the official Lightning repo, which is licensed under Apache 2.0. In compliance with the Apache license, any reused code is relicensed under the license of this project (as of September 2023, the MIT license), and the modifications made to the original code are listed here:
- Refactor PyTorch Lightning example code so it can be used with an adapter class from this project.

In [1]:
# Don't use strict mode for type checking yet
%nb_mypy mypy-options --pretty

In [2]:
from abc import abstractmethod, abstractproperty, ABC

import pandas as pd
import numpy as np


# Interface
# =========

class DataSetInterface(ABC):
    """
    Abstract base type shared by every concrete dataset interface.

    Depend on this class whenever code needs only the dataset abstraction
    and should not care which concrete kind of dataset it receives.

    No shared behavior is declared here yet, so registering virtual
    subclasses (``DataSetInterface.register(...)``) would work equally well
    today. Real subclassing is used anyway, so that shared abstract behavior
    can be added later and enforced on every concrete dataset interface.
    """


# Implementation
# ==============
from torch.utils.data import Dataset as TorchDataset

class ImageDataSetImplementation(DataSetInterface):
    """
    Concrete image dataset that wraps a PyTorch ``Dataset``.

    The wrapped object is stored as-is and handed back by ``to_torch``, so
    ``from_torch(x).to_torch() is x``.
    """

    def __init__(self, data: TorchDataset):
        # The wrapped PyTorch dataset, exposed publicly as ``data``.
        self.data = data

    @classmethod
    def from_torch(cls, data: TorchDataset) -> "ImageDataSetImplementation":
        """Alternate constructor: wrap an existing PyTorch dataset."""
        wrapped = cls(data=data)
        return wrapped

    def to_torch(self) -> TorchDataset:
        """Return the wrapped PyTorch dataset (the same object, not a copy)."""
        return self.data

In [3]:
# Container
# =========

class DataContainerInterface(ABC):
    """
    Abstract container grouping the train, validation and test subsets.

    Each subset is a read-only abstract property. A concrete container must
    override all three before it can be instantiated.
    """

    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``@property`` on top of ``@abstractmethod`` is the documented replacement.
    @property
    @abstractmethod
    def train(self) -> DataSetInterface:
        """Training subset."""

    @property
    @abstractmethod
    def val(self) -> DataSetInterface:
        """Validation subset."""

    @property
    @abstractmethod
    def test(self) -> DataSetInterface:
        """Test subset."""


class DataContainer(DataContainerInterface):
    """
    Minimal concrete implementation of ``DataContainerInterface``.

    This uses the simplest possible way to define a container: it exposes
    one read-only property for each of the three subsets. Generics are not
    implemented yet, so containers for different kinds of data sets are not
    distinguished by type.
    (There is also no method to retrieve the complete set, because that is
    not needed for our current testing purposes.)

    Args:
        train: Training subset.
        val: Validation subset.
        test: Test subset.
    """

    def __init__(self, train: DataSetInterface, val: DataSetInterface, test: DataSetInterface):
        self._train = train
        self._val = val
        self._test = test

    @property
    def train(self) -> DataSetInterface:
        """Training subset passed to the constructor."""
        return self._train

    @property
    def val(self) -> DataSetInterface:
        """Validation subset passed to the constructor."""
        return self._val

    @property
    def test(self) -> DataSetInterface:
        """Test subset passed to the constructor."""
        return self._test


In [4]:
from torch.utils.data import DataLoader
from lightning.pytorch import LightningDataModule, LightningModule

class TorchDataloaderAdapter(LightningDataModule):
    """
    Adapt one of our data containers to Lightning's ``LightningDataModule``.

    Each Lightning dataloader hook wraps the matching subset of the container
    (``train`` / ``val`` / ``test``) in a ``torch.utils.data.DataLoader`` that
    uses the configured batch size.

    Args:
        data_container: Object exposing ``train``, ``val`` and ``test``
            properties whose values provide a ``to_torch()`` method.
        batch_size: Batch size used for every dataloader built here.
    """

    def __init__(self, data_container, batch_size) -> None:
        super().__init__()
        self.data_container = data_container
        self.batch_size = batch_size

    def _make_dataloader(self, data_set) -> DataLoader:
        """Wrap one dataset interface in a DataLoader using the shared batch size."""
        return DataLoader(
            dataset=data_set.to_torch(),
            batch_size=self.batch_size,
        )

    def train_dataloader(self) -> DataLoader:
        """DataLoader over the container's training subset."""
        return self._make_dataloader(self.data_container.train)

    def val_dataloader(self) -> DataLoader:
        """DataLoader over the container's validation subset."""
        return self._make_dataloader(self.data_container.val)

    def test_dataloader(self) -> DataLoader:
        """DataLoader over the container's test subset."""
        return self._make_dataloader(self.data_container.test)

    def predict_dataloader(self, data=None) -> DataLoader:
        """
        DataLoader over an explicitly supplied dataset for prediction.

        Lightning calls this hook without arguments, so ``data`` must be
        optional. The data container has no prediction subset, so there is
        nothing sensible to fall back on when ``data`` is omitted.

        Args:
            data: Dataset interface providing ``to_torch()``.

        Raises:
            NotImplementedError: If ``data`` is not given.
        """
        if data is None:
            raise NotImplementedError(
                "The data container has no prediction subset; pass `data` "
                "explicitly to build a prediction dataloader."
            )
        return self._make_dataloader(data)

In [5]:
# Create example data
# ===================
 
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

DATA_DIR = './data'
# Seed for the train/validation split. Without a fixed generator,
# ``random_split`` yields a different split on every run, which makes
# results non-reproducible (Lightning's MNIST DataModule example seeds it
# the same way).
SPLIT_SEED = 42

mnist_train_and_val = MNIST(
    root=DATA_DIR,
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
# Raw PyTorch datasets get a ``_torch`` suffix so that the plain
# ``mnist_*`` names below always refer to our dataset-interface wrappers.
mnist_train_torch, mnist_val_torch = random_split(
    dataset=mnist_train_and_val,
    lengths=[.9, .1],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
mnist_test_torch = MNIST(
    root=DATA_DIR,
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

mnist_train = ImageDataSetImplementation.from_torch(mnist_train_torch)
mnist_val = ImageDataSetImplementation.from_torch(mnist_val_torch)
mnist_test = ImageDataSetImplementation.from_torch(mnist_test_torch)

mnist_container = DataContainer(
    train=mnist_train,
    val=mnist_val,
    test=mnist_test,
)

In [8]:
# EstimatorInterface
# ==================

class EstimatorInterface(ABC):
    """
    Abstract estimator that every model adapter must implement.

    Only ``fit`` is required at the moment. Planned future abstract methods
    (not enforced yet): ``optimize_hyperparameters()``,
    ``predict() -> DataSetInterface``, ``evaluate()`` and ``main()``.
    """

    @abstractmethod
    def fit(self):
        """Train the estimator; concrete adapters must override this."""


# EstimatorImplementation
# =======================

from datetime import datetime 
from pydantic import BaseModel
import lightning as L

from utils.lightning_adapter import LitClassifier


class ModelConfig(BaseModel):
    """Validated settings consumed by ``PytorchLightningAdapter``."""

    # Batch size handed to the dataloader adapter for every subset.
    batch_size: int
    # Forwarded to ``lightning.Trainer(fast_dev_run=...)``; when True the
    # trainer runs only a single batch as a quick smoke test.
    fast_dev_run: bool = False
    

class PytorchLightningAdapter(EstimatorInterface):
    """
    Adapt a PyTorch Lightning classifier to our OO_ML estimator interface.

    Args:
        Classifier: ``LightningModule`` subclass; it is instantiated with no
            arguments.
        data_container: Container whose train/val/test subsets are wrapped
            in a ``TorchDataloaderAdapter`` for Lightning.
        config: Model configuration (batch size and ``fast_dev_run`` flag).
    """

    def __init__(
        self,
        Classifier: type[LightningModule],
        data_container: DataContainer,
        config: ModelConfig,
    ) -> None:
        self.config = config
        self.classifier = Classifier()
        # Convert the data container into a Lightning data module. The
        # attribute keeps its public name ``data_loader`` for existing callers.
        self.data_loader = TorchDataloaderAdapter(
            data_container=data_container,
            batch_size=config.batch_size,
        )

    def fit(self) -> None:
        """Train the classifier on the container's train/val subsets."""
        # Use this instance's config, not a module-level ``config`` global:
        # the adapter must work regardless of what names the caller defines.
        # Move trainer instantiation to __init__() if it needs to be accessed
        # outside this method.
        trainer = L.Trainer(fast_dev_run=self.config.fast_dev_run)
        trainer.fit(
            model=self.classifier,
            datamodule=self.data_loader,
        )

    # TODO: planned estimator methods, not implemented yet:
    #   optimize_hyperparameters(), predict() -> DataSetInterface,
    #   evaluate(), and main() (optimize hyperparameters, evaluate, predict).


if __name__ == "__main__":
    # Small batches plus fast_dev_run give a one-batch smoke test of the
    # whole container -> data module -> trainer pipeline.
    config = ModelConfig(batch_size=64, fast_dev_run=True)

    model = PytorchLightningAdapter(
        Classifier=LitClassifier,
        data_container=mnist_container,
        config=config,
    )
    model.fit()

    # Timestamp marking when the run finished.
    print(datetime.now())

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.

  | Name     | Type     | Params
--------------------------------------
0 | backbone | Backbone | 101 K 
--------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


2023-09-13 09:07:43.412314
