# Note on license:
This notebook uses examples from the official Lightning repo, which is licensed under Apache 2.0. In compliance with the Apache license, any reused code is relicensed under the license in this project (as of September 2023, the MIT license), but I list modifications to the original code here:
- Refactor PyTorch Lightning example code so it can be used with an adapter class from this project.

In [1]:
# Don't use strict note for type checking yet
%nb_mypy mypy-options --pretty

In [2]:
FAST_DEV_RUN=True

In [3]:
from abc import abstractmethod, abstractproperty, ABC

import pandas as pd
import numpy as np


# Interface
# =========

class DataSetInterface(ABC):
    """
    This serves as the *abstract* type under which all the concrete dataset 
    interfaces fall. We can use when we want to depend only on the data set 
    abstraction, but not the concrete type of data set. 
    
    At the moment, this interface does not yet defined any shared behavior,
    so it would also be possible to use virtual subclasses (e.g., registering) 
    instead. However, we want to keep the option open for the future to define 
    shared behavior that all the concrete dataset interfaces must implement.
    """
    pass

class StructuredDataSetInterface(DataSetInterface):
    
    @abstractproperty
    def X(self):
        pass

    @abstractproperty
    def y(self):
        pass


    # Format conversions
    # ------------------
    @classmethod
    @abstractmethod
    def from_pandas(cls, input_data: pd.DataFrame):
        pass

    @abstractmethod
    def to_pandas(self) -> pd.DataFrame:
        pass

    @classmethod
    @abstractmethod
    def from_numpy(cls, input_data: np.ndarray):
        pass

    @abstractmethod
    def to_numpy(self) -> np.ndarray:
        pass

    @abstractmethod
    def get_column_names(self) -> list[str]:
        pass


# Implementation
# ==============

class StructuredDataSetImplementation(DataSetInterface):
    def __init__(self, X: pd.DataFrame, y: pd.DataFrame):
        self._X = X
        self._y = y

    @property
    def X(self):
        return self._X
    
    @property
    def y(self):
        return self._y
    
    @classmethod
    def from_pandas(cls, input_data: pd.DataFrame, target_name: str):
        return cls(pd_data_frame=input_data, target_name=target_name)

    def to_pandas(self) -> pd.DataFrame:
        return self.data    
            
    @classmethod
    def from_numpy(cls, input_data: np.ndarray):
        pd_data_frame=pd.DataFrame(input_data)
        return cls(pd_data_frame=pd_data_frame)

    def to_numpy(self) -> np.ndarray:
        return self.data.to_numpy()
    
    def get_column_names(self) -> list[str]:
        return self.columns.tolist()

<cell>63: [34mnote:[m [m[1m"StructuredDataSetImplementation"[m defined here[m
<cell>77: [1m[31merror:[m Unexpected keyword argument [m[1m"pd_data_frame"[m for
<cell>77: [1m[31merror:[m Unexpected keyword argument [m[1m"target_name"[m for
<cell>80: [1m[31merror:[m [m[1m"StructuredDataSetImplementation"[m has no attribute [m[1m"data"[m 
<cell>85: [1m[31merror:[m Unexpected keyword argument [m[1m"pd_data_frame"[m for
<cell>88: [1m[31merror:[m [m[1m"StructuredDataSetImplementation"[m has no attribute [m[1m"data"[m 
<cell>91: [1m[31merror:[m [m[1m"StructuredDataSetImplementation"[m has no attribute


In [4]:
# Container
# =========

class DataContainerInterface(ABC):
    @abstractproperty
    def train(self) -> DataSetInterface:
        pass

    @abstractproperty
    def val(self) -> DataSetInterface:
        pass

    @abstractproperty
    def test(self) -> DataSetInterface:
        pass


class DataContainer():
    def __init__(self, train: DataSetInterface, val: DataSetInterface, test: DataSetInterface):
        self._train = train
        self._val = val
        self._test = test

    @property
    def train(self):
        return self._train
    
    @property
    def val(self):
        return self._val
    
    @property
    def test(self):
        return self._test


In [5]:
from torch.utils.data import Dataset as TorchDataset

class ImageDataSetImplementation(DataSetInterface):
    def __init__(self, data: TorchDataset):
        self.data = data
    
    @classmethod
    def from_torch(cls, data):
        return cls(data=data)

    def to_torch(self) -> TorchDataset:
        return self.data

In [6]:
# Create example data
# ===================
 
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

from lightning.pytorch import LightningDataModule, LightningModule
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms


DATA_DIR = './data'

mnist_train_and_val = MNIST(
    root=DATA_DIR, 
    train=True, 
    download=True,
        transform=transforms.ToTensor()
)
mnist_train, mnist_val = random_split(
    dataset=mnist_train_and_val, 
    lengths=[.9, .1]
)
mnist_test = MNIST(
    root=DATA_DIR, 
    train=False, 
    download=True, 
    transform=transforms.ToTensor()
)

mnist_train = ImageDataSetImplementation.from_torch(mnist_train)
mnist_val = ImageDataSetImplementation.from_torch(mnist_val)
mnist_test = ImageDataSetImplementation.from_torch(mnist_test)

mnist_container = DataContainer(
    train=mnist_train,
    val=mnist_val,
    test=mnist_test
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 16784757.98it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 11823884.22it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 12688009.49it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 25709215.61it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [7]:
class TorchDataloaderAdapter(LightningDataModule):
    def __init__(self, data_container, batch_size) -> None:
        super().__init__()
        self.data_container = data_container
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(
            self.data_container.train.to_torch(),
            batch_size=self.batch_size
        )

    def val_dataloader(self):
        return DataLoader(
            self.data_container.val.to_torch(), 
            batch_size=self.batch_size
        )

    def test_dataloader(self):
        return DataLoader(
            self.data_container.test.to_torch(), 
            batch_size=self.batch_size
        )

    def predict_dataloader(self, data):
        pass

my_dataloader = TorchDataloaderAdapter(
    data_container=mnist_container,
    batch_size=64,    
)

In [8]:
from pydantic import BaseModel
import lightning as L

from ..lightning_adapter_extras import LitClassifier


# Inner config objects
class SaveConfigKwargs(BaseModel): 
    overwrite: bool
    

# Main config
class Config(BaseModel):
    batch_size: int
    # learning_rate: float = 0.001
    # hidden_dim: int = 128
    # data_dir: str = "./data"
    run: bool
    save_config_kwargs: SaveConfigKwargs
    seed_everything_default: int = 1
        

class PytorchLightningAdapter():
    def __init__(
        self, 
        classifier: LightningModule,
        data_container: DataContainer,
        config: Config,
    ) -> None:
        self.data_container = data_container
        self.classifier = classifier
        self.config = config

    # def optimize_hyperparameters(self):
    # def fit(self) -> None:
    # def predict(self) -> DataSetInterface:


    # def main(self):
        
    #     cli = LightningCLI(
    #         model_class=self.classifier,
    #         datamodule_class= create_torch_dataloader(
    #             self.data_container, 
    #             batch_size=config.batch_size
    #         ),
    #         seed_everything_default=self.config.seed_everything_default,
    #         save_config_kwargs=self.config.save_config_kwargs.dict(),
    #         run=self.config.run,
    #     )
    #     cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    #     cli.trainer.test(ckpt_path="best", datamodule=cli.datamodule)
    #     predictions = cli.trainer.predict(ckpt_path="best", datamodule=cli.datamodule)
    #     if predictions is not None:
    #         print(predictions[0])


if __name__ == "__main__":
    config = Config(
        batch_size=64,
        save_config_kwargs=SaveConfigKwargs(overwrite=True),
        run=False,
    )
    
    classifier=LitClassifier()
    trainer = L.Trainer(fast_dev_run=FAST_DEV_RUN)
    
    trainer.fit(
        model=classifier,
        # data.DataLoader(train), data.DataLoader(val))
        datamodule=my_dataloader,
    )
    # ptl = PytorchLightningAdapter(
    #     classifier=LitClassifier,
    #     data_container=mnist_container,
    #     config=config,
    # ptl.cli_main()

<cell>4: [1m[31merror:[m Relative import climbs too many namespaces  [m[33m[misc][m


ImportError: attempted relative import with no known parent package