# Note on license:
This notebook uses examples from the official Lightning repo, which is licensed under Apache 2.0. In compliance with the Apache license, any reused code is relicensed under the license in this project (as of September 2023, the MIT license), but I list modifications to the original code here:
- Refactor PyTorch Lightning example code so it can be used with an adapter class from this project.

# Why this example is abandoned
This turned out to not be a good example code, because the Lightning CLI conflicts with the Jupyter notebook API. I verified that the original pure-Python version of the code also does not run within a Jupyter notebook. While I did find several suggestions on StackOverflow for how to try to remediate this, I did not find a quick way to apply this here, so I decided to instead take a different example.

# Why this notebook is not deleted yet
The way to create an adapter class should be similar for other examples, but I have not had time to port this over yet.

In [7]:
# Don't use strict note for type checking yet
%nb_mypy mypy-options --pretty

In [6]:
from abc import abstractmethod, abstractproperty, ABC
from typing import Any, TypeAlias, NewType, TypeVar, Type, NoReturn

import pandas as pd
import numpy as np


# Interface
# =========

class DataSetInterface(ABC):
    """
    This serves as the *abstract* type under which all the concrete dataset 
    interfaces fall. We can use when we want to depend only on the data set 
    abstraction, but not the concrete type of data set. 
    
    At the moment, this interface does not yet defined any shared behavior,
    so it would also be possible to use virtual subclasses (e.g., registering) 
    instead. However, we want to keep the option open for the future to define 
    shared behavior that all the concrete dataset interfaces must implement.
    """
    pass

class StructuredDataSetInterface(DataSetInterface):
    
    @abstractproperty
    def X(self):
        pass

    @abstractproperty
    def y(self):
        pass


    # Format conversions
    # ------------------
    @classmethod
    @abstractmethod
    def from_pandas(cls, input_data: pd.DataFrame):
        pass

    @abstractmethod
    def to_pandas(self) -> pd.DataFrame:
        pass

    @classmethod
    @abstractmethod
    def from_numpy(cls, input_data: np.ndarray):
        pass

    @abstractmethod
    def to_numpy(self) -> np.ndarray:
        pass

    @abstractmethod
    def get_column_names(self) -> list[str]:
        pass


# Implementation
# ==============

class StructuredDataSetImplementation(DataSetInterface):
    def __init__(self, X: pd.DataFrame, y: pd.DataFrame):
        self._X = X
        self._y = y

    @property
    def X(self):
        return self._X
    
    @property
    def y(self):
        return self._y
    
    @classmethod
    def from_pandas(cls, input_data: pd.DataFrame, target_name: str):
        return cls(pd_data_frame=input_data, target_name=target_name)

    def to_pandas(self) -> pd.DataFrame:
        return self.data    
            
    @classmethod
    def from_numpy(cls, input_data: np.ndarray):
        pd_data_frame=pd.DataFrame(input_data)
        return cls(pd_data_frame=pd_data_frame)

    def to_numpy(self) -> np.ndarray:
        return self.data.to_numpy()
    
    def get_column_names(self) -> list[str]:
        return self.columns.tolist()

<cell>27: [1m[31merror:[m Function is missing a return type annotation  [m[33m[no-untyped-def][m
<cell>27: [34mnote:[m Use [m[1m"-> None"[m if function does not return a value[m
<cell>31: [1m[31merror:[m Function is missing a return type annotation  [m[33m[no-untyped-def][m
<cell>31: [34mnote:[m Use [m[1m"-> None"[m if function does not return a value[m
<cell>39: [1m[31merror:[m Function is missing a return type annotation  [m[33m[no-untyped-def][m
<cell>48: [1m[31merror:[m Missing type parameters for generic type [m[1m"ndarray"[m  [m[33m[type-arg][m
<cell>48: [1m[31merror:[m Function is missing a return type annotation  [m[33m[no-untyped-def][m
<cell>52: [1m[31merror:[m Missing type parameters for generic type [m[1m"ndarray"[m  [m[33m[type-arg][m
<cell>64: [34mnote:[m [m[1m"StructuredDataSetImplementation"[m defined here[m
<cell>69: [1m[31merror:[m Function is missing a return type annotation  [m[33m[no-untyped-def][m
<

In [2]:
# Container
# =========

class DataContainerInterface(ABC):
    @abstractproperty
    def train(self) -> DataSetInterface:
        pass

    @abstractproperty
    def val(self) -> DataSetInterface:
        pass

    @abstractproperty
    def test(self) -> DataSetInterface:
        pass


class DataContainer():
    def __init__(self, train: DataSetInterface, val: DataSetInterface, test: DataSetInterface):
        self._train = train
        self._val = val
        self._test = test

    @property
    def train(self):
        return self._train
    
    @property
    def val(self):
        return self._val
    
    @property
    def test(self):
        return self._test


<cell>25: [1m[31merror:[m Function is missing a return type annotation  [m[33m[no-untyped-def][m
<cell>29: [1m[31merror:[m Function is missing a return type annotation  [m[33m[no-untyped-def][m
<cell>33: [1m[31merror:[m Function is missing a return type annotation  [m[33m[no-untyped-def][m


In [3]:
from torch.utils.data import Dataset as TorchDataset

class ImageDataSetImplementation(DataSetInterface):
    def __init__(self, data: TorchDataset):
        self.data = data
    
    @classmethod
    def from_torch(cls, data):
        return cls(data=data)

    def to_torch(self) -> TorchDataset:
        return self.data

<cell>4: [1m[31merror:[m Missing type parameters for generic type [m[1m"Dataset"[m  [m[33m[type-arg][m
<cell>8: [1m[31merror:[m Function is missing a type annotation  [m[33m[no-untyped-def][m
<cell>11: [1m[31merror:[m Missing type parameters for generic type [m[1m"Dataset"[m  [m[33m[type-arg][m


In [4]:
# Create example data
# ===================
 
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import datasets, transforms

DATA_DIR = './data'

mnist_train_and_val = MNIST(root=DATA_DIR, train=True, download=True)
mnist_train, mnist_val = random_split(
    dataset=mnist_train_and_val, 
    lengths=[.9, .1]
)
mnist_test = MNIST(root=DATA_DIR, train=False, download=True)

mnist_train = ImageDataSetImplementation.from_torch(mnist_train)
mnist_val = ImageDataSetImplementation.from_torch(mnist_val)
mnist_test = ImageDataSetImplementation.from_torch(mnist_test)

mnist_container = DataContainer(
    train=mnist_train,
    val=mnist_val,
    test=mnist_test
)

<cell>17: [1m[31merror:[m Call to untyped function [m[1m"from_torch"[m of [m[1m"ImageDataSetImplementation"[m in typed context  [m[33m[no-untyped-call][m
<cell>18: [1m[31merror:[m Call to untyped function [m[1m"from_torch"[m of [m[1m"ImageDataSetImplementation"[m in typed context  [m[33m[no-untyped-call][m
<cell>19: [1m[31merror:[m Call to untyped function [m[1m"from_torch"[m of [m[1m"ImageDataSetImplementation"[m in typed context  [m[33m[no-untyped-call][m


In [5]:
from typing import Optional
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from lightning.pytorch import cli_lightning_logo, LightningDataModule, LightningModule
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.mnist_datamodule import MNIST
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms


In [6]:
def create_torch_dataloader(data_container, batch_size):  # -> type[TorchDataloader]:
    class TorchDataloader(LightningDataModule):
        def __init__(self, data_container) -> None:
            super().__init__()
            self.data_container = data_container

        def train_dataloader(self):
            return DataLoader(
                self.data_container.train.to_torch(),
                batch_size=batch_size
            )

        def val_dataloader(self):
            return DataLoader(
                self.data_container.val.to_torch(), 
                batch_size=self.batch_size
            )

        def test_dataloader(self):
            return DataLoader(
                self.data_container.test.to_torch(), 
                batch_size=self.batch_size
            )

        def predict_dataloader(self, data):
            pass

    return TorchDataloader


<cell>1: [1m[31merror:[m Function is missing a type annotation  [m[33m[no-untyped-def][m
<cell>3: [1m[31merror:[m Function is missing a type annotation for one or more arguments  [m[33m[no-untyped-def][m
<cell>7: [1m[31merror:[m Function is missing a return type annotation  [m[33m[no-untyped-def][m
<cell>13: [1m[31merror:[m Function is missing a return type annotation  [m[33m[no-untyped-def][m
<cell>16: [1m[31merror:[m [m[1m"TorchDataloader"[m has no attribute [m[1m"batch_size"[m  [m[33m[attr-defined][m
<cell>19: [1m[31merror:[m Function is missing a return type annotation  [m[33m[no-untyped-def][m
<cell>22: [1m[31merror:[m [m[1m"TorchDataloader"[m has no attribute [m[1m"batch_size"[m  [m[33m[attr-defined][m
<cell>25: [1m[31merror:[m Function is missing a type annotation  [m[33m[no-untyped-def][m


In [7]:
from pydantic import BaseModel
from lightning_adapter_extras import LitClassifier

# Inner config objects
class SaveConfigKwargs(BaseModel): 
    overwrite: bool
    

# Main config
class Config(BaseModel):
    batch_size: int
    # learning_rate: float = 0.001
    # hidden_dim: int = 128
    # data_dir: str = "./data"
    run: bool
    save_config_kwargs: SaveConfigKwargs
    seed_everything_default: int = 1
        

class PytorchLightningAdapter():
    def __init__(
        self, 
        classifier: LightningModule,
        data_container: DataContainer,
        config: Config,
    ) -> None:
        self.data_container = data_container
        self.classifier = classifier
        self.config = config

    # def optimize_hyperparameters(self):
    # def fit(self) -> None:
    # def predict(self) -> DataSetInterface:


    def cli_main(self):
        
        cli = LightningCLI(
            model_class=self.classifier,
            datamodule_class= create_torch_dataloader(
                self.data_container, 
                batch_size=config.batch_size
            ),
            seed_everything_default=self.config.seed_everything_default,
            save_config_kwargs=self.config.save_config_kwargs.dict(),
            run=self.config.run,
        )
        cli.trainer.fit(cli.model, datamodule=cli.datamodule)
        cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
        predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule)
        if predictions is not None:
            print(predictions[0])


if __name__ == "__main__":
    config = Config(
        batch_size=64,
        save_config_kwargs=SaveConfigKwargs(overwrite=True),
        run=False,
    )
    
    ptl = PytorchLightningAdapter(
        classifier=LitClassifier,
        data_container=mnist_container,
        config=config,
    )
    ptl.cli_main()

<cell>36: [1m[31merror:[m Function is missing a return type annotation  [m[33m[no-untyped-def][m
<cell>36: [34mnote:[m Use [m[1m"-> None"[m if function does not return a value[m
<cell>40: [1m[31merror:[m Call to untyped function [m[1m"create_torch_dataloader"[m in typed context  [m[33m[no-untyped-call][m
<cell>63: [1m[31merror:[m Argument [m[1m"classifier"[m to [m[1m"PytorchLightningAdapter"[m has incompatible type [m[1m"type[LitClassifier]"[m; expected [m[1m"LightningModule"[m  [m[33m[arg-type][m
<cell>67: [1m[31merror:[m Call to untyped function [m[1m"cli_main"[m in typed context  [m[33m[no-untyped-call][m
usage: ipykernel_launcher.py [-h] [-c CONFIG] [--print_config[=flags]]
                             [--seed_everything SEED_EVERYTHING]
                             [--trainer CONFIG]
                             [--trainer.accelerator.help CLASS_PATH_OR_NAME]
                             [--trainer.accelerator ACCELERATOR]
          

SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
