# Tensor Models with PyTorch-Lightning

In this notebook we show how sensAI's TensorModel wrappers can be used together with pytorch-lightning models
and trainers for even faster development and experimentation.

### Before running the notebook

Install the package and its dependencies, if you haven't done so already. E.g. for an editable install call
```
pip install -e .
```
from the root directory. You can also execute this command directly in the notebook, but you will need to restart the
kernel afterwards.


In [None]:
# Note - this cell should be executed only once per session

%load_ext autoreload
%autoreload 2

import sys, os

# in order to get the top level modules; they are not part of the package
os.chdir("..")
sys.path.append(os.path.abspath("."))

In [None]:
from IPython.display import display
import torch
from torch.nn import functional as F
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import pandas as pd
from sensai.data import InputOutputArrays, DataSplitterFractional

from sensai.pytorch_lightning import PLTensorToScalarClassificationModel
from sensai.tensor_model import extractArray

import logging
logging.basicConfig(level=logging.INFO)

from config import get_config

c = get_config()

## Loading the Data

Unlike the MNIST-based pytorch-lightning tutorial, here we load the data in a more "realistic" way,
namely with pandas from disk.

In [None]:
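X = pd.read_csv(c.datafile_path("mnist_train.csv.zip"))
# split off the label column; the remaining 28 * 28 columns hold the pixel values
labels = pd.DataFrame(X.pop("label"))
# reshape each row into a 28x28 image and scale the pixel values (0-255 for MNIST) to [0, 1)
X = X.values.reshape(len(X), 28, 28) / 2 ** 8
# store one image array per row in a single data frame column
X = pd.DataFrame({"mnist_image": list(X)}, index=labels.index)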

display(X.head())
display(labels.head())

display("Plotting some image from the data set")
some_image = X.iloc[13, 0]
plt.imshow(some_image)
plt.show()
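
To make the resulting layout explicit, here is a minimal sketch that applies the same transformation to a tiny synthetic frame. The random pixel values and the `pixel…` column names are made up for illustration; only the `label` column name is taken from the cell above. The result has one row per sample and a single column whose cells are 28x28 arrays. Dividing by `2 ** 8` maps the values 0-255 into [0, 255/256].

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for mnist_train.csv: a "label" column plus
# 28 * 28 = 784 pixel columns with integer intensities in [0, 255]
rng = np.random.default_rng(0)
n_rows = 3
pixels = rng.integers(0, 256, size=(n_rows, 28 * 28))
raw_demo = pd.DataFrame(pixels, columns=[f"pixel{i}" for i in range(28 * 28)])
raw_demo.insert(0, "label", [5, 0, 4])

# Same steps as in the cell above: split off the label column, ...
labels_demo = pd.DataFrame(raw_demo.pop("label"))
# ... reshape each row of 784 values into a scaled 28x28 image ...
demo_images = raw_demo.values.reshape(len(raw_demo), 28, 28) / 2 ** 8
# ... and store one image array per row in a single column
X_demo = pd.DataFrame({"mnist_image": list(demo_images)}, index=labels_demo.index)

print(X_demo.shape)             # (3, 1)
print(X_demo.iloc[0, 0].shape)  # (28, 28)
```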


## Using Data Loaders in pure PyTorch Lightning

First, let us see how training would proceed in pure pytorch-lightning.

We will use sensAI only to obtain torch data loaders (which would otherwise require a few more lines of code),
by transforming the data frames into arrays, splitting them, and converting the splits into loaders.

In [None]:
VALIDATION_FRACTION = 0.1

full_ds = InputOutputArrays(extractArray(X), extractArray(labels))
splitter = DataSplitterFractional(1-VALIDATION_FRACTION)

train_ds, val_ds = splitter.split(full_ds)
train_dataloader = train_ds.toTorchDataLoader()
val_dataloader = val_ds.toTorchDataLoader()
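
For comparison, the "few more lines of code" in plain PyTorch might look roughly like the sketch below, which uses synthetic arrays in place of `extractArray(X)` and `extractArray(labels)`. It is an approximation, not a re-implementation of sensAI: `random_split` shuffles before splitting, which may not match how `DataSplitterFractional` assigns samples, and the batch size of 64 is an arbitrary choice rather than sensAI's default.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

VALIDATION_FRACTION = 0.1

# Synthetic stand-ins for the image and label arrays
rng = np.random.default_rng(0)
demo_images = rng.random((100, 28, 28), dtype=np.float32)
demo_targets = rng.integers(0, 10, size=100)

# Wrap the arrays as tensors; cross-entropy expects int64 class indices
demo_ds = TensorDataset(torch.from_numpy(demo_images), torch.from_numpy(demo_targets).long())

# Split by fraction; the seeded generator makes the random split reproducible
n_val = int(len(demo_ds) * VALIDATION_FRACTION)
demo_train_ds, demo_val_ds = random_split(
    demo_ds, [len(demo_ds) - n_val, n_val], generator=torch.Generator().manual_seed(42)
)

# Shuffle only the training data; keep validation order fixed
torch_train_loader = DataLoader(demo_train_ds, batch_size=64, shuffle=True)
torch_val_loader = DataLoader(demo_val_ds, batch_size=64)
```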

Now that we have the data loaders, let us forget about sensAI for the moment. We declare the model and create a
trainer with pytorch-lightning, then fit the model on the MNIST data.

In [None]:
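class MNISTModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def _logits(self, x: torch.Tensor) -> torch.Tensor:
        # unnormalized class scores for a batch of 28x28 images (flattened to 784 inputs)
        x = x.float()
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def forward(self, x: torch.Tensor):
        # class probabilities, used when predicting
        return F.softmax(self._logits(x), dim=1)

    def training_step(self, batch, *args):
        x, y = batch
        # cross_entropy applies log-softmax itself, so it must receive the raw scores
        loss = F.cross_entropy(self._logits(x), y)
        return loss

    def validation_step(self, batch, *args):
        x, y = batch
        loss = F.cross_entropy(self._logits(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)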
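
`F.cross_entropy` combines `log_softmax` and `nll_loss` in a single call, so it expects raw, unnormalized scores (logits) rather than probabilities. This matters for models whose `forward` ends with a softmax: the loss should be computed on the scores before that softmax. The sketch below uses made-up scores for one very confident, correct prediction to show the difference. On logits the loss is essentially zero; on probabilities a second softmax is applied and the loss stays around `log(e + 9) - 1`.

```python
import math

import torch
from torch.nn import functional as F

# A very confident, correct prediction: class 0 has by far the highest score
logits = torch.tensor([[20.0] + [0.0] * 9])
target = torch.tensor([0])

# On raw logits the loss is (numerically) zero
ce_on_logits = F.cross_entropy(logits, target)

# cross_entropy is equivalent to nll_loss applied to log_softmax(logits)
ce_manual = F.nll_loss(F.log_softmax(logits, dim=1), target)

# Feeding probabilities applies softmax a second time; since all inputs then
# lie in [0, 1], even this perfect prediction keeps a sizeable loss
probs = F.softmax(logits, dim=1)
ce_on_probs = F.cross_entropy(probs, target)

print(ce_on_logits.item() < 1e-6)              # True
print(torch.allclose(ce_on_logits, ce_manual))  # True
print(round(ce_on_probs.item(), 3))             # 1.461
print(round(math.log(math.e + 9) - 1, 3))       # 1.461
```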

In [None]:
mnist_model = MNISTModel()

trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20)
trainer.fit(mnist_model, train_dataloader, val_dataloader)

Let us pick some images from the validation set and compare the predicted labels with the true ones.

In [None]:
mini_test_set = val_dataloader.dataset[10:20]
test_images, test_labels = mini_test_set

display(mnist_model(test_images).argmax(axis=1))
display(test_labels)
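
Eyeballing ten predictions only gives a rough impression. A small helper can compute the accuracy over a whole loader instead. This is a minimal sketch: `loader_accuracy` is a hypothetical helper, not part of sensAI or pytorch-lightning, and it assumes that the loader yields `(inputs, labels)` pairs with the labels as a 1-D tensor of class indices, and that the model returns one score or probability per class.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def loader_accuracy(model: torch.nn.Module, loader: DataLoader) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    was_training = model.training
    model.eval()  # inference mode (affects e.g. dropout and batch norm)
    correct, total = 0, 0
    for x, y in loader:
        predicted = model(x).argmax(dim=1)
        correct += (predicted == y).sum().item()
        total += y.numel()
    model.train(was_training)  # restore the previous mode
    return correct / total


# Toy check: an identity "model" on one-hot inputs predicts the hot index
toy_ds = TensorDataset(torch.eye(4), torch.tensor([0, 1, 2, 0]))
print(loader_accuracy(torch.nn.Identity(), DataLoader(toy_ds, batch_size=3)))  # 0.75
```

In this notebook it could be called as `loader_accuracy(mnist_model, val_dataloader)`, provided the labels produced by the sensAI loaders have the assumed shape.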

## Wrapping the Model with sensAI

Now let us wrap the model with sensAI interfaces. Since sensAI offers dedicated wrappers
for pytorch-lightning models, this requires only one additional line of code.

This model maps a tensor to a single label, so the correct class to wrap it with is `PLTensorToScalarClassificationModel`,
where the `PL` prefix stands for pytorch-lightning.

In [None]:
mnist_model = MNISTModel()
trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20)
sensaiMnistModel = PLTensorToScalarClassificationModel(mnist_model, trainer, validationFraction=VALIDATION_FRACTION)

NB: Even without dedicated wrappers, it would require only a few more lines of code to get a custom implementation of
a suitable sensAI base class that wraps one's model.
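
To give an idea of how little glue is involved, here is a framework-independent sketch of the prediction side only. It does not subclass any sensAI class: `DataFrameTensorClassifier` and its `label_column` parameter are hypothetical names made up for illustration, and an actual sensAI wrapper would additionally handle fitting and the interfaces sensAI expects.

```python
import numpy as np
import pandas as pd
import torch


class DataFrameTensorClassifier:
    """Hypothetical minimal wrapper (not part of sensAI): applies a torch module
    that outputs class scores to a data frame whose single column holds one
    array per row, and returns a data frame of predicted labels."""

    def __init__(self, module: torch.nn.Module, label_column: str = "label"):
        self.module = module
        self.label_column = label_column

    @torch.no_grad()
    def predict(self, X: pd.DataFrame) -> pd.DataFrame:
        # Stack the per-row arrays of the single input column into one tensor
        inputs = torch.as_tensor(np.stack(X.iloc[:, 0].to_list()), dtype=torch.float32)
        self.module.eval()  # inference mode (affects e.g. dropout)
        predicted = self.module(inputs).argmax(dim=1).numpy()
        # Keep the input index so that predictions line up with the input rows
        return pd.DataFrame({self.label_column: predicted}, index=X.index)


# Toy usage: with an identity module, the largest entry of each array is the label
X_toy = pd.DataFrame(
    {"image": [np.array([0.1, 0.9, 0.0]), np.array([0.8, 0.1, 0.1])]}, index=[10, 20]
)
print(DataFrameTensorClassifier(torch.nn.Identity()).predict(X_toy))
```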

With the wrapped model, we can fit directly on the data frames. We don't lose any of the niceties that pytorch-lightning
brings to the game: both the original model and the trainer remain available in `sensaiMnistModel`. At the same time,
wrapping the model and trainer gives us what sensAI is all about: safety, transparency, flexibility in feature
engineering, and extensive support for model evaluation.

In [None]:
sensaiMnistModel.fit(X, labels)

The wrapped model performs predictions on data frames. Let us take some points from the training set, predict their
labels and compare them with the true labels.

In [None]:
display("Predicted data frame")
display(sensaiMnistModel.predict(X.iloc[:10]))
display("True labels data frame")
display(labels.iloc[:10])
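
Since both predictions and true labels are data frames, a quick comparison needs only pandas. The helper below is a hypothetical sketch, not one of sensAI's evaluators: it assumes two single-column label frames, matches rows by index, and returns the accuracy together with a confusion matrix built with `pd.crosstab`.

```python
import pandas as pd


def compare_labels(predicted: pd.DataFrame, true: pd.DataFrame):
    """Accuracy and confusion matrix for two single-column label frames;
    rows are matched by index, not by position."""
    y_pred = predicted.iloc[:, 0]
    y_true = true.iloc[:, 0].reindex(y_pred.index)
    accuracy = (y_pred == y_true).mean()
    # Rows: true labels, columns: predicted labels
    confusion = pd.crosstab(y_true, y_pred, rownames=["true"], colnames=["predicted"])
    return accuracy, confusion


# Toy usage with made-up labels: three of four predictions are correct
pred_demo = pd.DataFrame({"label": [1, 2, 2, 0]})
true_demo = pd.DataFrame({"label": [1, 2, 3, 0]})
acc_demo, confusion_demo = compare_labels(pred_demo, true_demo)
print(acc_demo)  # 0.75
print(confusion_demo)
```

Assuming `predict` keeps the input index, it could be applied here as `compare_labels(sensaiMnistModel.predict(X.iloc[:10]), labels.iloc[:10])`.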

## Evaluating Tensor Models

TODO - the evaluation part is not finished yet (although we could already evaluate the above classifier with the
standard vector model evaluators).
We should also include TensorToTensor models here and show how to evaluate them.
