In [None]:
%load_ext autoreload
%autoreload 2

import os
import yaml
import shutil
import sys
from pathlib import Path
from loguru import logger
from tqdm import tqdm

import lightning as L

In [None]:
# Make the parent directory importable so the `src.*` imports below resolve.
# NOTE: "../" is a relative entry, so it is resolved against the kernel's
# current working directory - run the notebook from its own folder.
sys.path.append("../")  # include parent directory
from src.pond_data import PondDataset, PondDataModule
from src.seg_model import SegmentationModel
from src.config_utils import build_kwargs_from_config
from src.inference_utils import (
    get_checkpoint_fpath,
    get_checkpoint_hparams,
)
from src.rollout_utils import polygonize_raster_mask

## Model Inference
This notebook runs a trained segmentation model on satellite images, saves the predicted raster masks, and converts those masks into polygons.

## Input
- Satellite images (GeoTIFF files)
- Trained model (PyTorch Lightning checkpoint file)

## Output
- Raster masks (GeoTIFF files)
- Predicted pond polygons (GeoPackage file)

## Set Input Parameters

In [None]:
# Project folders, all located one level above the notebook's working directory.
DATA_PATH, CONFIG_PATH, MODELS_PATH = (
    Path("..") / folder_name for folder_name in ("data", "config", "models")
)

# Config file from which the dataset / datamodule / misc settings are built.
CONFIG_FPATH = CONFIG_PATH.joinpath("pond_config.yaml")

# Checkpoint version to load; when left as None the latest checkpoint is used.
VERSION_NUM = None

In [None]:
# Build the keyword-argument groups for this run from the config file.
# The cells below read the "misc_kwargs", "dataset_kwargs" and
# "datamodule_kwargs" entries of the returned dict.
kwargs_dict = build_kwargs_from_config(DATA_PATH, CONFIG_FPATH)

# Sub-folder name used under both MODELS_PATH (model artifacts) and
# DATA_PATH (predicted polygons) further down.
PARENT_DIR = kwargs_dict["misc_kwargs"]["parent_dir"]

In [None]:
# Seed taken from the config; passed to L.seed_everything further down.
RANDOM_SEED = kwargs_dict["misc_kwargs"]["random_seed"]
RANDOM_SEED

In [None]:
# Folder searched for this run's checkpoint and its hyperparameter file.
MODEL_ARTIFACTS_DIR = MODELS_PATH.joinpath(PARENT_DIR)
MODEL_ARTIFACTS_DIR

In [None]:
# Output file that the polygonization loop at the end of the notebook writes to.
PRED_POLYGONS_FPATH = DATA_PATH.joinpath(PARENT_DIR, "pred_polygons.gpkg")

In [None]:
# Resolve the checkpoint file whose weights are loaded further down.
# VERSION_NUM=None selects the latest checkpoint (see the parameters cell).
MODEL_CHECKPOINT_FPATH = get_checkpoint_fpath(
    MODEL_ARTIFACTS_DIR, version_num=VERSION_NUM
)
MODEL_CHECKPOINT_FPATH

In [None]:
# Keyword arguments later unpacked into PondDataset(...).
DATASET_KWARGS = kwargs_dict["dataset_kwargs"]
DATASET_KWARGS

In [None]:
# Keyword arguments later unpacked into PondDataModule(...); the
# "predict_masks_root" entry is also used as the prediction output folder.
DATAMODULE_KWARGS = kwargs_dict["datamodule_kwargs"]
DATAMODULE_KWARGS

In [None]:
# Locate the hyperparameter YAML stored in the checkpoint folder
model_kwargs_path = get_checkpoint_hparams(
    MODEL_ARTIFACTS_DIR,
    version_num=VERSION_NUM,
)
# Parse it into the keyword arguments used to construct the LightningModule
LIGHTNINGMODULE_KWARGS = yaml.safe_load(Path(model_kwargs_path).read_text())
LIGHTNINGMODULE_KWARGS

In [None]:
# Trainer settings for inference: let Lightning pick the accelerator and
# disable experiment logging.
TRAINER_KWARGS = dict(accelerator="auto", logger=False)

## Make folder for prediction masks

In [None]:
# Folder that receives the predicted raster masks. Coerced to Path so the
# pathlib calls in this notebook (.exists, .mkdir, .glob) work whether the
# config yields a str or a Path.
PREDICT_MASK_FOLDER = Path(DATAMODULE_KWARGS["predict_masks_root"])

# Start from an empty folder so masks from a previous run are not picked up
# by the polygonization step.
if PREDICT_MASK_FOLDER.exists():
    logger.info(f"Deleting files in existing folder {PREDICT_MASK_FOLDER}")
    shutil.rmtree(PREDICT_MASK_FOLDER)

# parents=True: also create missing intermediate folders (e.g. on a first run)
# instead of failing with FileNotFoundError.
PREDICT_MASK_FOLDER.mkdir(parents=True, exist_ok=True)

## Set the random seed for reproducibility

In [None]:
# Seed with the value from the config; workers=True asks Lightning to also
# derive seeds for dataloader worker processes.
L.seed_everything(seed=RANDOM_SEED, workers=True)

## Set up the PyTorch Dataset and DataModule

In [None]:
# Stand-alone dataset instance; later cells read its BACKGROUND_PIXEL_VAL and
# label_mapping attributes when polygonizing the predictions.
pond_dataset = PondDataset(**DATASET_KWARGS)
pond_dataset

In [None]:
# DataModule handed to trainer.predict; it also exposes the predict/train/val
# datasets used by the saving and plotting cells below.
pond_datamodule = PondDataModule(**DATAMODULE_KWARGS)

## Set up the LightningModule 

In [None]:
# Build the model from the hyperparameters read from the checkpoint folder.
# At this point the weights are not yet those of the trained checkpoint.
model = SegmentationModel(**LIGHTNINGMODULE_KWARGS)

## Load Trained Model and Trainer

In [None]:
# load_from_checkpoint is a classmethod that returns a NEW model instance.
# Call it on the class: Lightning 2.x raises a TypeError when it is invoked
# on an instance, and older versions ignore the instance anyway.
model = SegmentationModel.load_from_checkpoint(MODEL_CHECKPOINT_FPATH)

In [None]:
# Trainer used only for inference (see TRAINER_KWARGS: auto accelerator, no logger).
trainer = L.Trainer(**TRAINER_KWARGS)

## Model Prediction

Predict on test data

In [None]:
%%time
# Run inference through the datamodule. The next cell iterates over the result
# and unpacks every element as an (index batch, predicted mask batch) pair.
predictions = trainer.predict(model=model, datamodule=pond_datamodule)

In [None]:
%%time
# Each element of `predictions` is a pair: (sample indices, predicted masks)
for batch_indices, batch_masks in tqdm(predictions):
    # Pair every sample index with its mask (as numpy values)
    masks_by_idx = dict(zip(batch_indices.numpy(), batch_masks.numpy()))

    # save predictions to disk, one mask per sample index
    for idx, pred_mask in masks_by_idx.items():
        pond_datamodule.predict_dataset.save_predict_mask(
            idx,
            pred_mask,
            output_dir=PREDICT_MASK_FOLDER,
        )

## Visualize model predictions

Visualize actual mask vs predicted mask

In [None]:
# Set up the "fit" stage: the plotting cells below read
# pond_datamodule.train_dataset and pond_datamodule.val_dataset.
pond_datamodule.setup(stage="fit")

### Predict on the Training Set

In [None]:
# Position within the training split to inspect (must be < len(train_indices))
i = 12
train_indices = pond_datamodule.train_dataset.indices
train_sample_idx = train_indices[i]

# Same sample plotted twice: from the underlying training dataset, then from
# the prediction dataset (presumably actual vs predicted mask - confirm)
pond_datamodule.train_dataset.dataset.plot_img(train_sample_idx)
pond_datamodule.predict_dataset.plot_img(train_sample_idx)

### Predict on the Validation Set

In [None]:
# Position within the validation split to inspect (must be < len(val_indices))
i = 12
val_indices = pond_datamodule.val_dataset.indices
val_sample_idx = val_indices[i]

# Same sample plotted twice: from the underlying validation dataset, then from
# the prediction dataset (presumably actual vs predicted mask - confirm)
pond_datamodule.val_dataset.dataset.plot_img(val_sample_idx)
pond_datamodule.predict_dataset.plot_img(val_sample_idx)

## Polygonize Predictions

Convert TIFF files into polygons

In [None]:
# Recursively collect every file under PREDICT_MASK_FOLDER whose name ends in
# "tif"; sorted so the masks are processed in a deterministic order
pred_mask_fpaths = sorted(PREDICT_MASK_FOLDER.glob("**/*tif"))
len(pred_mask_fpaths), pred_mask_fpaths[:3]

In [None]:
%%time
# Delete polygons file if it exists
# It will be generated in this loop (created on the first write, appended after)
PRED_POLYGONS_FPATH.unlink(missing_ok=True)
pred_mask_crs = None  # CRS of the first non-empty mask; all others must match
empty_prediction_fpaths = []  # masks that produced no polygons at all

# Labels that must not be turned into polygons; loop-invariant, so built once
skip_labels = [pond_dataset.BACKGROUND_PIXEL_VAL]

for pred_mask_fpath in tqdm(pred_mask_fpaths):
    pred_polygons = polygonize_raster_mask(
        pred_mask_fpath,
        skip_labels=skip_labels,
        simplify_tolerance_m=None,
    )

    # Skip next steps if no polygons were generated
    if pred_polygons.empty:
        empty_prediction_fpaths.append(pred_mask_fpath)
        continue

    # check if crs is consistent across all predicted masks
    if pred_mask_crs is None:
        pred_mask_crs = pred_polygons.crs
    if pred_mask_crs != pred_polygons.crs:
        # Fixed: the message previously referenced the undefined name
        # `pred_mask_path`, which raised NameError instead of this ValueError
        error_msg = (
            f"Incompatible crs of {pred_mask_fpath}. "
            f"Expected {pred_mask_crs} but it has {pred_polygons.crs}"
        )
        raise ValueError(error_msg)

    # Translate raster label values through the dataset's label mapping
    # NOTE(review): labels absent from the mapping would become NaN - confirm
    pred_polygons["label"] = pred_polygons["label"].map(pond_dataset.label_mapping)
    # Store every polygon in a single common CRS (WGS 84)
    pred_polygons = pred_polygons.to_crs("epsg:4326")

    # Write to file: create it for the first batch, append for the rest
    write_mode = "a" if PRED_POLYGONS_FPATH.exists() else "w"
    pred_polygons.to_file(
        PRED_POLYGONS_FPATH, driver="GPKG", mode=write_mode, index=False
    )

logger.info(
    f"Finished polygonizing. There were {len(empty_prediction_fpaths):,} TIFF files that were purely background"
)