In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path
from loguru import logger
from tqdm import tqdm
import shutil

import lightning as L
from torch.utils.data import Subset, DataLoader

In [None]:
sys.path.append("../")  # include parent directory
from src.pond_data import PondDataset, PondDataModule
from src.seg_model import SegmentationModel
from src.config_utils import build_kwargs_from_config
from src.inference_utils import get_checkpoint_fpath

## Overfitting a batch

This notebook attempts to overfit a single batch. Overfitting a batch is a good way to diagnose errors and bugs in model training. See more [here](https://fullstackdeeplearning.com/spring2021/lecture-7/#overfit-a-single-batch).

Check the following to verify that you have overfit a batch:
1. Model Metrics
2. Model Predictions


## Input
- Satellite Images (geoTIFF files)
- Raster Masks (geoTIFF files)


## Output
- Overfit model
- Model predictions on overfit batch

## Set up parameters from config yaml

Feel free to edit the parameters here as well

In [None]:
# Locations of the data, configuration and model artifacts, relative to this notebook.
DATA_PATH, CONFIG_PATH, MODELS_PATH = map(Path, ("../data", "../config", "../models"))
# The models folder is created up front if it is not there yet.
MODELS_PATH.mkdir(exist_ok=True)

# YAML file that the run parameters are read from.
CONFIG_FPATH = CONFIG_PATH.joinpath("pond_config.yaml")

In [None]:
# Build the keyword-argument dictionaries used in the rest of the notebook from the
# YAML config; the data and models directories are passed in alongside it.
kwargs_dict = build_kwargs_from_config(DATA_PATH, CONFIG_FPATH, MODELS_PATH)

In [None]:
# Seed value taken from the "misc_kwargs" section; the bare name echoes it as cell output.
RANDOM_SEED = kwargs_dict["misc_kwargs"]["random_seed"]
RANDOM_SEED

In [None]:
# Keyword arguments later unpacked into PondDataset(...); echoed for inspection.
DATASET_KWARGS = kwargs_dict["dataset_kwargs"]
DATASET_KWARGS

In [None]:
# Keyword arguments later unpacked into PondDataModule(...); echoed for inspection.
DATAMODULE_KWARGS = kwargs_dict["datamodule_kwargs"]
DATAMODULE_KWARGS

In [None]:
# Keyword arguments later unpacked into SegmentationModel(...); echoed for inspection.
LIGHTNINGMODULE_KWARGS = kwargs_dict["lightningmodule_kwargs"]
LIGHTNINGMODULE_KWARGS

In [None]:
# Artifacts of this experiment get their own "<parent_dir>-overfit_batch" folder
# under the models directory.
parent_dir_name = "{}-overfit_batch".format(kwargs_dict["misc_kwargs"]["parent_dir"])

MODEL_ARTIFACTS_DIR = MODELS_PATH.joinpath(parent_dir_name)
MODEL_ARTIFACTS_DIR

In [None]:
# Trainer settings for the overfit run. The logger and callbacks are reused from the
# config-derived trainer kwargs; everything else is fixed here.
_config_trainer_kwargs = kwargs_dict["trainer_kwargs"]
TRAINER_KWARGS = dict(
    accelerator="auto",
    devices=1,
    max_epochs=100,
    overfit_batches=1,
    logger=_config_trainer_kwargs["logger"],
    default_root_dir=MODELS_PATH,
    callbacks=_config_trainer_kwargs["callbacks"],
)

## Make folder for prediction masks

In [None]:
# Predicted masks of this experiment go into "<predict_masks_root>-overfit_batch",
# derived from the datamodule's configured prediction-mask root.
PREDICT_MASK_FOLDER = Path(f"{DATAMODULE_KWARGS['predict_masks_root']}-overfit_batch")

# Start from an empty folder so masks written by an earlier run are not mixed in.
if PREDICT_MASK_FOLDER.exists():
    logger.info(f"Deleting files in existing folder {PREDICT_MASK_FOLDER}")
    shutil.rmtree(PREDICT_MASK_FOLDER)

# parents=True: do not fail when intermediate directories of the configured root
# have not been created yet.
PREDICT_MASK_FOLDER.mkdir(parents=True, exist_ok=True)

## Set the random seed for reproducibility

In [None]:
# Seed the random number generators through Lightning using the configured seed.
# NOTE(review): workers=True is documented by Lightning as also seeding dataloader
# worker processes — confirm against the installed Lightning version.
L.seed_everything(seed=RANDOM_SEED, workers=True)

## Set up the PyTorch Dataset and DataModule

Also force the training dataset in the datamodule to be exactly one batch long.

In [None]:
%%time
# Build the dataset from the config-derived kwargs; the bare name displays its repr.
pond_dataset = PondDataset(**DATASET_KWARGS)
pond_dataset

In [None]:
# Build the datamodule from the config-derived kwargs; its datasets are set up in the next cell.
pond_datamodule = PondDataModule(**DATAMODULE_KWARGS)

In [None]:
# Shrink the training set down to just the samples needed for the overfit run.
pond_datamodule.setup(stage="fit")

# Number of samples = batches to overfit x samples per batch.
n_overfit_samples = TRAINER_KWARGS["overfit_batches"] * DATAMODULE_KWARGS["batch_size"]
assert isinstance(n_overfit_samples, int)

# NOTE(review): the indices address the dataset wrapped by train_dataset (its
# `.dataset` attribute), not positions within the train split — confirm this is intended.
overfit_indices = range(n_overfit_samples)
wrapped_dataset = pond_datamodule.train_dataset.dataset
pond_datamodule.train_dataset = Subset(wrapped_dataset, overfit_indices)

In [None]:
# Loader over the reduced training set; sample order is kept fixed (no shuffling).
dataloader_kwargs = {
    key: DATAMODULE_KWARGS[key] for key in ("batch_size", "num_workers")
}
dataloader = DataLoader(
    pond_datamodule.train_dataset, shuffle=False, **dataloader_kwargs
)

## Set up the LightningModule 

In [None]:
# Take the model's input-channel and class counts from the dataset object so they
# do not have to be specified separately.
LIGHTNINGMODULE_KWARGS["in_channels"] = pond_dataset.NUM_IN_CHANNELS
LIGHTNINGMODULE_KWARGS["num_classes"] = pond_dataset.NUM_CLASSES

In [None]:
# When the OneCycleLR scheduler is selected, hand the model the schedule length
# (steps per epoch and number of epochs) derived from the trainer settings.
# presumably forwarded to torch's OneCycleLR by SegmentationModel — verify there.
if LIGHTNINGMODULE_KWARGS["lr_scheduler"] == "OneCycleLR":
    # One optimisation step per overfit batch in every epoch.
    steps_per_epoch = TRAINER_KWARGS["overfit_batches"]
    epochs = TRAINER_KWARGS.get("max_epochs")
    # NOTE(review): "train_time" is not a key set anywhere in this notebook — confirm the key name.
    train_time = TRAINER_KWARGS.get("train_time")
    # The schedule is expressed in epochs: an epoch budget is required, a time budget is not allowed.
    assert epochs is not None
    assert train_time is None
    lr_scheduler_config = dict(steps_per_epoch=steps_per_epoch, epochs=epochs)
    LIGHTNINGMODULE_KWARGS["lr_scheduler_config"] = lr_scheduler_config

In [None]:
# Instantiate the segmentation model with the kwargs assembled in the cells above.
model = SegmentationModel(**LIGHTNINGMODULE_KWARGS)

## Set up the Lightning Trainer

In [None]:
# Create the Lightning Trainer from the overfit-specific settings in TRAINER_KWARGS.
trainer = L.Trainer(**TRAINER_KWARGS)

## Fit the Model

In [None]:
%%time
# The same dataloader is passed for training and validation, so validation metrics
# are computed on the very samples the model is being fit to.
trainer.fit(model=model, train_dataloaders=dataloader, val_dataloaders=dataloader)

## Load the best model checkpoint

In [None]:
# Look up the checkpoint file for this run under MODEL_ARTIFACTS_DIR; the model is
# restored from this path in the next cell.
# NOTE(review): how get_checkpoint_fpath picks the checkpoint is not visible here — confirm it returns the best one.
MODEL_CHECKPOINT_FPATH = get_checkpoint_fpath(MODEL_ARTIFACTS_DIR)
MODEL_CHECKPOINT_FPATH

In [None]:
# load_from_checkpoint is a classmethod that builds a new model from the checkpoint,
# so it is called on the class rather than on the already-trained `model` instance
# (recent Lightning releases reject the instance call with a TypeError).
predict_model = SegmentationModel.load_from_checkpoint(MODEL_CHECKPOINT_FPATH)

## Model Prediction

In [None]:
%%time
# Run prediction with the checkpoint-restored model on the same dataloader used for
# fitting; the result is iterated batch by batch in the next cell.
predictions = trainer.predict(model=predict_model, dataloaders=dataloader)

In [None]:
%%time
# Each prediction batch is a pair (sample indices, predicted masks). Pair them up
# and write every mask to the overfit prediction folder.
for idx_batch, pred_mask_batch in tqdm(predictions):
    # index -> mask for this batch, converted to numpy
    batch_predict_dict = dict(zip(idx_batch.numpy(), pred_mask_batch.numpy()))

    # persist each mask via the dataset wrapped by the training subset
    for idx, pred_mask in batch_predict_dict.items():
        pond_datamodule.train_dataset.dataset.save_predict_mask(
            idx, pred_mask, output_dir=PREDICT_MASK_FOLDER
        )

In [None]:
# Inspect the index -> predicted-mask mapping left over from the last batch of the loop above.
batch_predict_dict

## Construct a PyTorch dataset to read the prediction masks

In [None]:
# Same dataset settings as before, except that masks are read from the prediction
# folder. A copy is edited so DATASET_KWARGS keeps its original "masks_root".
pred_dataset_kwargs = DATASET_KWARGS.copy()
pred_dataset_kwargs["masks_root"] = PREDICT_MASK_FOLDER
pred_dataset_kwargs

In [None]:
%%time
# Dataset whose masks come from the predicted-mask folder; the bare name displays its repr.
pred_dataset = PondDataset(**pred_dataset_kwargs)
pred_dataset

## Visualize model predictions

Visualize actual mask vs predicted mask

In [None]:
# Plot the same sample from the original dataset and from the prediction-mask
# dataset, one after the other, for a visual comparison.
# NOTE(review): index 2 only has a saved prediction if the overfit batch holds at
# least three samples — confirm against the configured batch size.
sample_idx = 2
ground_truth_dataset = pond_datamodule.train_dataset.dataset
ground_truth_dataset.plot_img(sample_idx)
pred_dataset.plot_img(sample_idx)