# Lightning CNNs With SPOT

## 1. JSON


The hyperparameters of the core model `GoogleNet` are defined in `data.lightning_hyper_dict.json`, together with their types, defaults, and search ranges. The entry for `GoogleNet` reads:

```json
{
  "GoogleNet": {
    "act_fn": {
      "levels": ["Sigmoid", "Tanh", "ReLU", "LeakyReLU", "ELU", "Swish"],
      "type": "factor",
      "default": "ReLU",
      "transform": "None",
      "class_name": "spotPython.torch.activation",
      "core_model_parameter_type": "instance",
      "lower": 0,
      "upper": 5
    },
    "optimizer_name": {
      "levels": ["Adam"],
      "type": "factor",
      "default": "Adam",
      "transform": "None",
      "class_name": "torch.optim",
      "core_model_parameter_type": "str",
      "lower": 0,
      "upper": 0
    }
  }
}
```
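For both `factor` hyperparameters above, `lower` and `upper` coincide with the index range of `levels` (0 to 5 for six activation functions, 0 to 0 for the single optimizer). Assuming that encoding, the sketch below checks the bounds and converts each default level to its index. The helper `default_indices()` is hypothetical and not part of `spotPython`; the resulting `[2, 0]` matches the `X_start` array printed in the sample run.

```python
import json

# Excerpt of the GoogleNet entry shown above (only the fields used here)
HYPER_DICT_JSON = """
{
  "GoogleNet": {
    "act_fn": {
      "levels": ["Sigmoid", "Tanh", "ReLU", "LeakyReLU", "ELU", "Swish"],
      "type": "factor", "default": "ReLU", "lower": 0, "upper": 5
    },
    "optimizer_name": {
      "levels": ["Adam"],
      "type": "factor", "default": "Adam", "lower": 0, "upper": 0
    }
  }
}
"""


def default_indices(model_entry: dict) -> list:
    """Hypothetical helper: encode each factor default as its index in `levels`."""
    indices = []
    for name, spec in model_entry.items():
        if spec["type"] != "factor":
            continue  # only factor hyperparameters are handled in this sketch
        # Assumed encoding: levels are addressed by integer indices lower..upper
        assert spec["lower"] == 0, name
        assert spec["upper"] == len(spec["levels"]) - 1, name
        indices.append(spec["levels"].index(spec["default"]))
    return indices


hyper_dict = json.loads(HYPER_DICT_JSON)
print(default_indices(hyper_dict["GoogleNet"]))  # [2, 0]
```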

## 2. Class HyperLightning

```{python}
import numpy as np

class HyperLightning:
    def fun(self,
            X: np.ndarray,
            fun_control: dict = None) -> np.ndarray:
        ...  # body omitted
```

1. Takes the hyperparameters proposed by `spotPython` as a `numpy.ndarray` `X` of numerical values; each row encodes one configuration.
2. Converts each row into a dictionary of hyperparameters. For the default row `X = [[2, 0]]`, this gives:

   ```
   config = {
      'act_fn': <class 'spotPython.torch.activation.ReLU'>,
      'optimizer_name': 'Adam'}
   ```

3. Passes this dictionary to `train_model()`.
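Step 2 can be pictured as the inverse of the index encoding used in the JSON file: each entry of a row of `X` is read as an index into the `levels` of the corresponding factor hyperparameter. The sketch below assumes that the columns of `X` follow the order of the hyperparameter dictionary. `decode_row()` is a hypothetical helper; unlike `spotPython`, which resolves `act_fn` to a class such as `spotPython.torch.activation.ReLU`, it simply returns the level names as strings.

```python
import numpy as np

# Levels of the two factor hyperparameters, in the (assumed) column order of X
HYPER_DICT = {
    "act_fn": {"type": "factor",
               "levels": ["Sigmoid", "Tanh", "ReLU", "LeakyReLU", "ELU", "Swish"]},
    "optimizer_name": {"type": "factor", "levels": ["Adam"]},
}


def decode_row(x: np.ndarray, hyper_dict: dict) -> dict:
    """Hypothetical helper: map one row of numeric values to a config dict."""
    config = {}
    for value, (name, spec) in zip(x, hyper_dict.items()):
        if spec["type"] == "factor":
            # Factor levels are stored as integer indices
            config[name] = spec["levels"][int(round(value))]
        else:
            config[name] = value  # numeric hyperparameters pass through unchanged
    return config


X = np.array([[2, 0]])
configs = [decode_row(row, HYPER_DICT) for row in X]
print(configs)  # [{'act_fn': 'ReLU', 'optimizer_name': 'Adam'}]
```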

## 3. train_model()

`def train_model(config: dict, fun_control: dict)`

1. Prepares the data, e.g., `CIFAR10DataModule`
2. Sets up the Trainer
3. Sets up the model, e.g.,

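```{python}
# Build the Lightning module around the core model selected in fun_control
model = NetCNNBase(
    model_name=fun_control["core_model"].__name__,  # e.g., "GoogleNet"
    model_hparams=config,  # hyperparameter dictionary generated by fun()
    optimizer_name="Adam",
    optimizer_hparams={"lr": 1e-3, "weight_decay": 1e-4},
)
```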

:::{.callout-note}
### Note: `train_model()` is based on the UvA-DL tutorial's `train_model()`

* Based on `def train_model(model_name, save_name=None, **kwargs)` from [https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/04-inception-resnet-densenet.html](https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/04-inception-resnet-densenet.html)
:::
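`train_model()` hands the core model to `NetCNNBase` only by name (`fun_control["core_model"].__name__`), together with `config` as `model_hparams`. The UvA-DL tutorial turns such a name back into a network through a dictionary of model classes and a `create_model()` function. The sketch below mimics that pattern, assuming `spotPython` does something similar; `MODEL_REGISTRY`, `register()`, and the stand-in class `TinyNet` are hypothetical. Extra keys in `config` such as `optimizer_name` end up in `**kwargs`, which fits the "additional keyword arguments" entry in the `GoogleNet` docstring.

```python
# Hypothetical registry that maps class names to model classes
MODEL_REGISTRY = {}


def register(cls):
    """Class decorator: store a model class under its __name__."""
    MODEL_REGISTRY[cls.__name__] = cls
    return cls


@register
class TinyNet:
    """Stand-in for a core model such as GoogleNet."""

    def __init__(self, num_classes: int = 10, act_fn=None, **kwargs):
        self.num_classes = num_classes
        self.act_fn = act_fn
        self.extra = kwargs  # e.g., optimizer_name, not used by the network


def create_model(model_name: str, model_hparams: dict):
    """Instantiate a registered model from its name and hyperparameters."""
    if model_name not in MODEL_REGISTRY:
        raise ValueError(
            f"Unknown model {model_name!r}; available: {sorted(MODEL_REGISTRY)}")
    return MODEL_REGISTRY[model_name](**model_hparams)


config = {"act_fn": "ReLU", "optimizer_name": "Adam"}
model = create_model(TinyNet.__name__, config)
print(model.act_fn, model.extra)  # ReLU {'optimizer_name': 'Adam'}
```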

## 4. netcnnbase.py

```{python}
import lightning as L

class NetCNNBase(L.LightningModule):
    def __init__(self,
                 model_name,
                 model_hparams,
                 optimizer_name,
                 optimizer_hparams):
        ...  # body omitted
```

1. Saves hyperparameters in `self.hparams`
2. Creates model
3. Creates loss module
4. Creates optimizer
5. Defines forward pass
6. Defines training step
7. Defines validation step
8. Defines test step


:::{.callout-note}
### Note: `NetCNNBase` is based on the UvA-DL tutorial's `CIFARModule`

* Based on `class CIFARModule(L.LightningModule)` from [https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/04-inception-resnet-densenet.html](https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/04-inception-resnet-densenet.html)
:::
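Step 4 (creating the optimizer) has to turn the string `optimizer_name` into an optimizer object. Because the JSON entry for `optimizer_name` lists `torch.optim` as its `class_name`, one plausible approach is a `getattr()` lookup in `torch.optim`; this is an assumption, not taken from the `spotPython` source, and the UvA-DL `CIFARModule` uses an explicit `if`/`elif` chain instead. `create_optimizer()` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def create_optimizer(params, optimizer_name: str, optimizer_hparams: dict):
    """Hypothetical helper: build a torch.optim optimizer from its class name."""
    optimizer_class = getattr(torch.optim, optimizer_name, None)
    # Reject names that are not optimizer classes (e.g., the "lr_scheduler" module)
    if not (isinstance(optimizer_class, type)
            and issubclass(optimizer_class, torch.optim.Optimizer)):
        raise ValueError(f"Unknown optimizer: {optimizer_name!r}")
    return optimizer_class(params, **optimizer_hparams)


model = nn.Linear(4, 2)  # any module with parameters
optimizer = create_optimizer(model.parameters(), "Adam",
                             {"lr": 1e-3, "weight_decay": 1e-4})
print(type(optimizer).__name__)  # Adam
```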

## 5. GoogleNet

```{python}
import torch.nn as nn

class GoogleNet(nn.Module):
    """GoogleNet architecture

    Args:
        num_classes (int):
            Number of classes for the classification task. Defaults to 10.
        act_fn_name (str):
            Name of the activation function. Defaults to "relu".
        **kwargs:
            Additional keyword arguments.

    Attributes:
        hparams (SimpleNamespace):
            Namespace containing the hyperparameters.
        input_net (nn.Sequential):
            Input network.
        inception_blocks (nn.Sequential):
            Inception blocks.
        output_net (nn.Sequential):
            Output network.

    Returns:
        (torch.Tensor):
            Output tensor of the GoogleNet architecture

    Examples:
        >>> import torch
        >>> from spotPython.light.cnn.googlenet import GoogleNet
        >>> model = GoogleNet()
        >>> x = torch.randn(1, 3, 32, 32)
        >>> y = model(x)
        >>> y.shape
        torch.Size([1, 10])
    """
```

:::{.callout-note}
### Note: `GoogleNet` is based on the UvA-DL tutorial's `GoogleNet`

* Based on `class GoogleNet(nn.Module)` from [https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/04-inception-resnet-densenet.html](https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/04-inception-resnet-densenet.html)
:::
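The three attributes listed in the docstring form a pipeline: `input_net` maps the RGB image to feature maps, `inception_blocks` transforms them, and `output_net` reduces them to one score per class. The sketch below shows only this skeleton for a `(1, 3, 32, 32)` input as in the docstring example. The inception stack is replaced by `nn.Identity()`, the layer sizes are illustrative, and `GoogleNetSkeleton` is a hypothetical name, not the `spotPython` class.

```python
import torch
import torch.nn as nn


class GoogleNetSkeleton(nn.Module):
    """Stem and classification head of a GoogleNet-style network (sketch)."""

    def __init__(self, num_classes: int = 10, act_fn=nn.ReLU, **kwargs):
        super().__init__()
        # Extra config keys (e.g., optimizer_name) land in kwargs and are ignored
        # input_net: lift the 3 RGB channels to 64 feature maps (size unchanged)
        self.input_net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            act_fn(),
        )
        # inception_blocks: placeholder for the stack of inception blocks
        self.inception_blocks = nn.Identity()
        # output_net: global average pooling, flatten, linear classifier
        self.output_net = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.input_net(x)
        x = self.inception_blocks(x)
        return self.output_net(x)


model = GoogleNetSkeleton()
y = model(torch.randn(1, 3, 32, 32))
print(y.shape)  # torch.Size([1, 10])
```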

## 6. InceptionBlock

```{python}
import torch.nn as nn

class InceptionBlock(nn.Module):
    def __init__(self, c_in, c_red: dict, c_out: dict, act_fn):
        """
        Inception block as used in GoogLeNet.

        Args:
            c_in:
                Number of input feature maps from the previous layers
            c_red:
                Dictionary with keys "3x3" and "5x5" specifying
                the output of the dimensionality reducing 1x1 convolutions
            c_out:
                Dictionary with keys "1x1", "3x3", "5x5", and "max" specifying
                the number of output feature maps of each branch
            act_fn:
                Activation class constructor (e.g. nn.ReLU)

        Returns:
            torch.Tensor:
                Output tensor of the inception block

        Examples:
            >>> import torch
            >>> import torch.nn as nn
            >>> from spotPython.light.cnn.googlenet import InceptionBlock
            >>> block = InceptionBlock(3,
            ...                        {"3x3": 32, "5x5": 16},
            ...                        {"1x1": 16, "3x3": 32, "5x5": 8, "max": 8},
            ...                        nn.ReLU)
            >>> x = torch.randn(1, 3, 32, 32)
            >>> y = block(x)
            >>> y.shape
            torch.Size([1, 64, 32, 32])

        """
```

:::{.callout-note}
### Note: `InceptionBlock` is based on the UvA-DL tutorial's `InceptionBlock`

* Based on `class InceptionBlock(nn.Module)` from [https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/04-inception-resnet-densenet.html](https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/04-inception-resnet-densenet.html)
:::
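An inception block runs four branches in parallel on the same input (a 1x1 convolution; a 1x1 reduction followed by a 3x3 convolution; a 1x1 reduction followed by a 5x5 convolution; a 3x3 max pooling followed by a 1x1 convolution) and concatenates their outputs along the channel dimension. Padding keeps height and width unchanged, so the number of output channels is the sum of the `c_out` values: 16 + 32 + 8 + 8 = 64 in the docstring example. The following is a minimal sketch along the lines of the UvA-DL tutorial, not a copy of the `spotPython` source; `InceptionBlockSketch` and `conv_bn_act()` are hypothetical names.

```python
import torch
import torch.nn as nn


def conv_bn_act(c_in: int, c_out: int, kernel_size: int, act_fn) -> list:
    """Convolution + batch norm + activation; padding keeps height and width."""
    return [nn.Conv2d(c_in, c_out, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.BatchNorm2d(c_out),
            act_fn()]


class InceptionBlockSketch(nn.Module):
    def __init__(self, c_in: int, c_red: dict, c_out: dict, act_fn):
        super().__init__()
        # Branch 1: plain 1x1 convolution
        self.conv_1x1 = nn.Sequential(*conv_bn_act(c_in, c_out["1x1"], 1, act_fn))
        # Branch 2: 1x1 reduction, then 3x3 convolution
        self.conv_3x3 = nn.Sequential(
            *conv_bn_act(c_in, c_red["3x3"], 1, act_fn),
            *conv_bn_act(c_red["3x3"], c_out["3x3"], 3, act_fn))
        # Branch 3: 1x1 reduction, then 5x5 convolution
        self.conv_5x5 = nn.Sequential(
            *conv_bn_act(c_in, c_red["5x5"], 1, act_fn),
            *conv_bn_act(c_red["5x5"], c_out["5x5"], 5, act_fn))
        # Branch 4: 3x3 max pooling (stride 1), then 1x1 convolution
        self.max_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            *conv_bn_act(c_in, c_out["max"], 1, act_fn))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branches = [self.conv_1x1(x), self.conv_3x3(x),
                    self.conv_5x5(x), self.max_pool(x)]
        # Concatenate along the channel dimension (dim=1 in NCHW layout)
        return torch.cat(branches, dim=1)


c_red = {"3x3": 32, "5x5": 16}
c_out = {"1x1": 16, "3x3": 32, "5x5": 8, "max": 8}
block = InceptionBlockSketch(3, c_red, c_out, nn.ReLU)
y = block(torch.randn(1, 3, 32, 32))
out_channels = sum(c_out.values())
print(y.shape, out_channels)  # torch.Size([1, 64, 32, 32]) 64
```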

## TBD: ResNet, DenseNet...

## Tensorboard {#sec-tensorboard-31}

The training output shown in the console (or code cell) can also be visualized with TensorBoard. Start it with the log directory as argument:

```bash
tensorboard --logdir="runs/"
```

## Sample Run

```{python}
from spotPython.utils.init import fun_control_init
from spotPython.utils.file import get_experiment_name, get_spot_tensorboard_path
from spotPython.utils.device import getDevice
from spotPython.light.cnn.googlenet import GoogleNet
from spotPython.data.lightning_hyper_dict import LightningHyperDict
from spotPython.hyperparameters.values import add_core_model_to_fun_control
from spotPython.fun.hyperlightning import HyperLightning
from spotPython.hyperparameters.values import get_default_hyperparameters_as_array

# Experiment settings
MAX_TIME = 1
INIT_SIZE = 3
WORKERS = 8  # number of data-loader workers
PREFIX = "TEST"
experiment_name = get_experiment_name(prefix=PREFIX)

# Central configuration: TensorBoard path, workers, device,
# input channels (_L_in, RGB) and number of classes (_L_out)
fun_control = fun_control_init(
    spot_tensorboard_path=get_spot_tensorboard_path(experiment_name),
    num_workers=WORKERS,
    device=getDevice(),
    _L_in=3,
    _L_out=10,
    TENSORBOARD_CLEAN=True)

# Register GoogleNet as core model together with its hyperparameter dictionary
add_core_model_to_fun_control(core_model=GoogleNet,
                              fun_control=fun_control,
                              hyper_dict=LightningHyperDict)

# Encode the default hyperparameters as a numerical array
X_start = get_default_hyperparameters_as_array(fun_control)
X_start
```

```
Global seed set to 42

array([[2, 0]])
```

```{python}
# Evaluate the default configuration once
hyper_light = HyperLightning(seed=126, log_level=50)
hyper_light.fun(X=X_start, fun_control=fun_control)
```

```{python}
fun_control['weights']
```