In [None]:
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path

import lightning as L

In [None]:
# Put the parent directory on the import path so the `src.*` imports below
# resolve when the kernel's working directory is this notebook's folder.
# The membership check keeps the cell idempotent: re-running it does not
# append a duplicate "../" entry to sys.path.
if "../" not in sys.path:
    sys.path.append("../")  # include parent directory
from src.pond_data import PondDataset, PondDataModule
from src.seg_model import SegmentationModel
from src.config_utils import build_kwargs_from_config, add_lr_scheduler_config

## Model Training

This notebook sets up the PyTorch Lightning Dataset, DataModule, and LightningModule for training a pond segmentation model using NICFI satellite imagery.


## Input
- Satellite Images (GeoTIFF files)
- Raster Masks (GeoTIFF files)


## Output
- Trained Model (PyTorch Lightning checkpoint file)

## Set up parameters from config YAML

Feel free to edit the parameters here as well.

In [None]:
# Project directories, expressed relative to the current working directory
# (one level up from where the notebook kernel runs).
DATA_PATH = Path("..") / "data"
CONFIG_PATH = Path("..") / "config"
MODELS_PATH = Path("..") / "models"

# Location of the YAML file that drives the rest of the notebook.
CONFIG_FPATH = CONFIG_PATH.joinpath("pond_config.yaml")

# Create the models directory up front; a no-op if it already exists.
MODELS_PATH.mkdir(exist_ok=True)

In [None]:
# Build the keyword-argument groups from the config file at CONFIG_FPATH.
# The cells below read these keys from the result: "misc_kwargs",
# "dataset_kwargs", "datamodule_kwargs", "lightningmodule_kwargs" and
# "trainer_kwargs".
kwargs_dict = build_kwargs_from_config(DATA_PATH, CONFIG_FPATH, MODELS_PATH)

In [None]:
# Seed passed to `L.seed_everything` further down in the notebook.
RANDOM_SEED = kwargs_dict["misc_kwargs"]["random_seed"]
RANDOM_SEED  # echo the value so it is recorded in the cell output

In [None]:
# Keyword arguments that are later unpacked into `PondDataset(...)`.
DATASET_KWARGS = kwargs_dict["dataset_kwargs"]
DATASET_KWARGS  # echo for inspection

In [None]:
# Keyword arguments that are later unpacked into `PondDataModule(...)`.
DATAMODULE_KWARGS = kwargs_dict["datamodule_kwargs"]
DATAMODULE_KWARGS  # echo for inspection

In [None]:
# Keyword arguments that are later unpacked into `SegmentationModel(...)`.
# NOTE: later cells add "in_channels" and "num_classes" to this dict and pass
# it through `add_lr_scheduler_config`, so the value echoed here is not the
# final set of arguments the model receives.
LIGHTNINGMODULE_KWARGS = kwargs_dict["lightningmodule_kwargs"]
LIGHTNINGMODULE_KWARGS  # echo for inspection

In [None]:
# Arguments unpacked into `L.Trainer(...)` below. The hardware settings are
# fixed here; everything else comes from the "trainer_kwargs" section of the
# config, except the root directory, which is the models folder.
TRAINER_KWARGS = dict(
    # Hardware selection.
    accelerator="auto",
    devices=1,
    # Stopping criteria taken from the config.
    max_epochs=kwargs_dict["trainer_kwargs"]["num_epochs"],
    max_time=kwargs_dict["trainer_kwargs"]["train_time"],
    # Logging, output location and callbacks.
    logger=kwargs_dict["trainer_kwargs"]["logger"],
    default_root_dir=MODELS_PATH,
    callbacks=kwargs_dict["trainer_kwargs"]["callbacks"],
)

## Set the random seed for reproducibility

In [None]:
# Seed the random number generators through Lightning before any dataset,
# datamodule, or model object is created in the cells below.
# `workers=True` is forwarded to Lightning; per its documentation this also
# covers dataloader worker processes — confirm against the installed version.
L.seed_everything(seed=RANDOM_SEED, workers=True)

## Set up the PyTorch Dataset and DataModule

In [None]:
# Instantiate the dataset from the config-derived arguments. Later cells read
# its NUM_IN_CHANNELS and NUM_CLASSES attributes to configure the model.
pond_dataset = PondDataset(**DATASET_KWARGS)
pond_dataset  # echo the dataset's repr

In [None]:
# Instantiate the datamodule that is handed to `trainer.fit` and
# `trainer.validate` below.
pond_datamodule = PondDataModule(**DATAMODULE_KWARGS)

## Set up the LightningModule

In [None]:
# Take the model's input-channel and class counts from the dataset so the
# network always matches the data it is trained on.
LIGHTNINGMODULE_KWARGS.update(
    in_channels=pond_dataset.NUM_IN_CHANNELS,
    num_classes=pond_dataset.NUM_CLASSES,
)

In [None]:
# Pass the model arguments through `add_lr_scheduler_config`, which also
# receives the trainer arguments and the datamodule, and rebind the name to
# whatever the helper returns.
# NOTE(review): whether re-running this cell is idempotent depends on the
# helper's implementation, which is not visible here — confirm.
LIGHTNINGMODULE_KWARGS = add_lr_scheduler_config(
    LIGHTNINGMODULE_KWARGS, TRAINER_KWARGS, pond_datamodule
)
LIGHTNINGMODULE_KWARGS  # echo the final model arguments

In [None]:
# Build the segmentation LightningModule from the fully assembled arguments.
model = SegmentationModel(**LIGHTNINGMODULE_KWARGS)

## Set up the Lightning Trainer

In [None]:
# Create the Lightning Trainer from the arguments assembled in TRAINER_KWARGS.
trainer = L.Trainer(**TRAINER_KWARGS)

## Fit the Model

The Trainer will automatically save a checkpoint of the model at the end of every epoch.

The Trainer will save to `MODELS_PATH` (its `default_root_dir`). See this [reference](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html) for PyTorch Lightning checkpointing.

In [None]:
%%time
# Train the model; `%%time` reports how long the whole fit took.
trainer.fit(model=model, datamodule=pond_datamodule)

## Validate the Model

In [None]:
# Run a validation pass with `ckpt_path="best"`, which per the Lightning docs
# loads the best checkpoint tracked during fitting — confirm which checkpoint
# that is for the callbacks configured in the YAML.
# This also updates the confusion matrix in Weights & Biases if the
# PlotWandbConfusionMatrix callback is being used.
trainer.validate(model=model, datamodule=pond_datamodule, ckpt_path="best")