In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# the project's source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path

import lightning as L
from lightning.pytorch.tuner.tuning import Tuner
from loguru import logger

In [None]:
# Put the parent directory on the import path so the project's `src` package
# can be imported from this notebook's folder.
sys.path.append("../")  # relative entry: resolved against the kernel's working directory
from src.pond_data import PondDataset, PondDataModule
from src.seg_model import SegmentationModel
from src.config_utils import build_kwargs_from_config

## Batch Size Finder

This notebook uses Lightning's [batch size finder](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.BatchSizeFinder.html) to search for the largest batch size that fits in memory for this model and dataset.


## Input
- Satellite images (GeoTIFF files)
- Raster masks (GeoTIFF files)


## Output
- Recommended batch size

## Set up parameters from config yaml

Most parameters are loaded from the config YAML; feel free to override any of them in the cells below as well.

In [None]:
# Data and config directories, given relative to this notebook's folder.
DATA_PATH = Path("../data")
CONFIG_PATH = Path("../config")

# Config file handed to build_kwargs_from_config in the next cell.
CONFIG_FPATH = CONFIG_PATH.joinpath("pond_config.yaml")

In [None]:
# Build the keyword-argument groups from the config file. The cells below read
# the "misc_kwargs", "dataset_kwargs", "datamodule_kwargs" and
# "lightningmodule_kwargs" entries of the result.
kwargs_dict = build_kwargs_from_config(DATA_PATH, CONFIG_FPATH)

In [None]:
# Seed taken from the config's misc section; the bare name echoes the value in use.
RANDOM_SEED = kwargs_dict["misc_kwargs"]["random_seed"]
RANDOM_SEED

In [None]:
# Keyword arguments later unpacked into PondDataset.
DATASET_KWARGS = kwargs_dict["dataset_kwargs"]
DATASET_KWARGS

In [None]:
# Keyword arguments later unpacked into PondDataModule.
DATAMODULE_KWARGS = kwargs_dict["datamodule_kwargs"]
DATAMODULE_KWARGS

In [None]:
# Keyword arguments later unpacked into SegmentationModel. "in_channels" and
# "num_classes" are set from the dataset further down, so the value echoed here
# reflects the config only.
LIGHTNINGMODULE_KWARGS = kwargs_dict["lightningmodule_kwargs"]
LIGHTNINGMODULE_KWARGS

In [None]:
# Trainer settings for the batch size search only: one device, no checkpoints
# written, and "power" as the batch size search mode.
TRAINER_KWARGS = dict(
    accelerator="auto",
    devices=1,
    enable_checkpointing=False,
    auto_scale_batch_size="power",
)

## Set the random seed for reproducibility

In [None]:
# Seed the random number generators through Lightning using the configured seed;
# `workers=True` asks Lightning to derive dataloader worker seeds from it as well.
L.seed_everything(RANDOM_SEED, workers=True)

## Set up the PyTorch Dataset and DataModule

In [None]:
%%time
# Instantiate the dataset from the config-derived kwargs; %%time reports how long
# construction takes. In this notebook the instance is only read for its
# NUM_IN_CHANNELS / NUM_CLASSES attributes when the model kwargs are completed.
pond_dataset = PondDataset(**DATASET_KWARGS)
pond_dataset

In [None]:
# The datamodule is configured from its own kwargs group; `pond_dataset` is not
# passed in here.
pond_datamodule = PondDataModule(**DATAMODULE_KWARGS)

## Set up the LightningModule 

In [None]:
# Complete the model kwargs with the channel and class counts exposed by the
# dataset, so they always agree with the data being loaded.
LIGHTNINGMODULE_KWARGS.update(
    in_channels=pond_dataset.NUM_IN_CHANNELS,
    num_classes=pond_dataset.NUM_CLASSES,
)

In [None]:
# Instantiate the segmentation model from the kwargs completed in the previous cell.
model = SegmentationModel(**LIGHTNINGMODULE_KWARGS)

## Set up the Lightning Trainer

In [None]:
trainer = L.Trainer(**TRAINER_KWARGS)

## Find the recommended batch size

In [None]:
%%time
trainer.tune(model, datamodule=pond_datamodule)