### Model prediction across KH-9 images
1) Applies the trained models to all KH-9 images across both study areas (raster files for each KH-9 image)
2) Post-processing of model predictions, removing boundary class pixels and identifying individual crater instances (raster files for each KH-9 image)
3) Extract individual crater polygons and centroids from post-processed raster files (geojson files for each KH-9 image)
4) Create final set of crater polygons and centroids for each study area by removing craters that were counted twice due to overlapping KH-9 images (only relevant for tri-border area as Quang Tri images were mosaicked before) 

#### Input:
* *study_areas -> [study_area] -> rasters*: KH-9 images for each study area (geotiff files)
* *model_path_sa*: Best models for each study area after fine-tuning
* *footprint_no_overlap_path*: KH-9 image-level footprints with no overlap between images (geojson files)

#### Parameters:
* *tilesize*: Size of image tiles 
* *pred_overlap*: Overlap of image tiles for the sliding window approach
* *pred_batch_size*: Batch size used during model prediction
* *pred_num_batches*: Number of batches to process together (each KH-9 image is processed in multiple batches due to their large size)
* *pred_nodata*: Nodata value for created output rasters
* *pred_dtype*: Data type of created output rasters
* *postprocess_height*: Tile height used during post-processing step
* *postprocess_width*: Tile width used during post-processing step
* *postprocess_overlap*: Overlap of image tiles for the sliding window approach during post-processing
* *min_crater_area*: Minimum area (in pixels) of a crater in the labelled image tiles (smaller craters are removed)
* *n_classes*: Number of classes in the labelled image tiles = len(crater_ids) + 2 (boundary and background classes)
* *boundary_id*: Integer that represents the boundary class in the labelled image tiles

#### Outputs
* *prediction_path*: model prediction for each KH-9 image (geotiff files)
* *prediction_processed_path*: post-processed model predictions for each KH-9 image (geotiff files)
* *prediction_polygons_path*: crater polygons for each KH-9 image (geojson files) 
* *prediction_polygons_path_sa*: crater polygons aggregated by study area, removing duplicate craters due to overlapping images (geojson files) 
* *prediction_centroids_path*: crater polygon centroids for each KH-9 image (geojson files)
* *prediction_centroids_path_sa*: crater polygon centroids aggregated by study area, removing duplicate craters due to overlapping images (geojson files) 

In [1]:
import numpy as np
import pandas as pd
import geopandas as gpd
import rasterio
import torch

from torch.utils.data import DataLoader, Dataset
from rasterio import features
from rasterio.windows import intersection, Window
from shapely.geometry import shape
from skimage import morphology
from utils import apply_min_max_scaling, create_dir, load_config
from evaluation import filter_by_size

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
class PredDataset(Dataset):
    def __init__(self, images):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]

        return image


def pred_data(model, data_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()

    pred_list = []
    with torch.no_grad():
        for images in data_loader:
            images = images.to(device)
            outputs = model(images)
            pred_list.append(outputs.cpu().numpy())

    # Concatenate the predictions and masks along the batch dimension to get the final result
    pred = np.concatenate(pred_list, axis=0)
    pred = pred.transpose((0, 2, 3, 1))

    return pred


def get_overlapping_windows(src, height, width, overlap, crop_at_edges=True):
    col_offsets = range(0, src.width, width - 2 * overlap)
    row_offsets = range(0, src.height, height - 2 * overlap)
    full_raster_window = Window(
        col_off=0, row_off=0, width=src.width, height=src.height
    )

    windows = list()
    for row_off in row_offsets:
        for col_off in col_offsets:
            window_read = Window(
                col_off=col_off - overlap,
                row_off=row_off - overlap,
                width=width,
                height=height,
            )

            window_write = Window(
                col_off=col_off,
                row_off=row_off,
                width=width - 2 * overlap,
                height=height - 2 * overlap,
            )

            if crop_at_edges:
                window_read = intersection(window_read, full_raster_window)
                window_write = intersection(window_write, full_raster_window)

            windows.append((window_read, window_write))
    return windows


def split_list(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


def pred_raster_in_batches(
    src_path,
    dst_path,
    model_path,
    height,
    width,
    overlap,
    num_batches=2,
    batch_size=32,
    dst_dtype="uint8",
    dst_nodata=255,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load model
    model = torch.load(model_path, map_location=device)

    with rasterio.open(src_path) as src:
        profile = src.profile.copy()
        profile["dtype"] = dst_dtype
        profile["nodata"] = dst_nodata

        windows = get_overlapping_windows(src, height, width, overlap)
        window_batches = split_list(windows, num_batches * batch_size)

        create_dir(dst_path, is_file=True)
        n_batches = len(windows) // (num_batches * batch_size)
        print(f"Number of windows: {len(windows)}")
        print(f"Number of batches: {n_batches}")
        i = 0
        with rasterio.open(dst_path, "w", **profile) as dst:
            for window_batch in window_batches:
                print(f"Predicting batch {i+1}/{n_batches}")
                i += 1
                windows_write = list()
                tile_list = list()
                for window_read, window_write in window_batch:
                    array = src.read(1, window=window_read)
                    # skip if nodata values are in tile or if window not full tile
                    if (array == src.nodata).any():
                        continue
                    if array.shape != (height, width):
                        continue

                    windows_write.append(window_write)
                    tile_list.append(array)

                # continue in case all tiles in batch are nodata tiles
                if len(tile_list) == 0:
                    continue
                pred_tiles = np.expand_dims(np.array(tile_list), axis=3)
                pred_tiles = apply_min_max_scaling(pred_tiles)

                # pred
                pred_tensor = torch.FloatTensor(pred_tiles).permute(0, 3, 1, 2)
                dataset = PredDataset(pred_tensor)
                dataset_loader = DataLoader(
                    dataset, batch_size=batch_size, shuffle=False
                )
                preds = pred_data(model, dataset_loader)
                preds = np.argmax(preds, -1).astype("uint8")

                # remove boundaries of image based on selected overlap
                preds = preds[:, overlap:-overlap, overlap:-overlap]
                for pred, window in zip(preds, windows_write):
                    dst.write_band(1, pred, window=window)


def postprocess_pred_tile(pred, boundary_id, n_classes, min_crater_area=25):
    # remove boundary and filter by size filling small holes
    pred[pred == boundary_id] = 0
    pred = filter_by_size(pred, min_crater_area=min_crater_area, filled_id=boundary_id)

    # use pred > 0 here to allow for different crater types to be connected
    pred_crater, n_pred = morphology.label(
        pred > 0, background=0, connectivity=1, return_num=True
    )

    # compute class with largest number of pixels for each predicted and labelled crater
    h_pred = np.histogram2d(
        pred_crater.flatten(), pred.flatten(), bins=(n_pred + 1, range(n_classes + 2))
    )
    # get majority class after removing last column to avoid "neutral" class being the majority
    pred_majority = np.argmax(h_pred[0][:, :-1], axis=-1)

    # map crater ids back to their majority class
    pred_out = pred_majority[pred_crater].astype("uint8")

    return pred_out


def postprocess_raster(
    src_path,
    dst_path,
    boundary_id,
    n_classes,
    min_crater_area,
    height=5120,
    width=5120,
    overlap=512,
):
    src = rasterio.open(src_path)
    profile = src.profile.copy()

    windows = get_overlapping_windows(src, height, width, overlap)

    create_dir(dst_path, is_file=True)
    with rasterio.open(dst_path, "w", **profile, BIGTIFF="YES") as dst:
        for window_read, window_write in windows:
            array = src.read(1, window=window_read)

            # treat nodata as background but set back to nodata afterwards
            nodata_mask = array == src.nodata
            if nodata_mask.all():
                continue
            array[nodata_mask] = 0

            # remove boundaries of image based on selected overlap
            array = postprocess_pred_tile(
                array,
                boundary_id=boundary_id,
                n_classes=n_classes,
                min_crater_area=min_crater_area,
            )

            # reset nodata values and crop image remove overlap areas
            array[nodata_mask] = src.nodata
            array = array[overlap:-overlap, overlap:-overlap]

            # write tile to output raster
            dst.write_band(1, array, window=window_write)


def polygonize_prediction_raster(src_path, dst_polygons_path, dst_centroids_path):
    with rasterio.open(src_path) as src:
        crs = src.crs
        array = src.read(1)
        array[array == src.nodata] = 0

    res = features.shapes(array, array > 0, connectivity=4, transform=src.transform)
    value = []
    geometry = []
    for poly, val in res:
        value.append(val)
        geometry.append(shape(poly))

    df = gpd.GeoDataFrame({"value": value, "geometry": geometry}, crs=crs)

    create_dir(dst_polygons_path, is_file=True)
    df.to_file(dst_polygons_path, driver="GeoJSON")
    df["geometry"] = df.geometry.centroid

    create_dir(dst_centroids_path, is_file=True)
    df.to_file(dst_centroids_path, driver="GeoJSON")

In [3]:
config_path = "../config.yaml"
config = load_config(config_path)
study_areas = config.get("study_areas").keys()

In [4]:
%%time

for study_area in study_areas:
    rasters = config.get("study_areas").get(study_area).get("rasters")
    
    for raster_id, raster_path in rasters.items():
        print(raster_id)

        print("Model prediction ...")
        pred_raster_in_batches(
            src_path=raster_path,
            dst_path=config["prediction_path"].format(study_area=study_area, raster_id=raster_id),
            model_path=config.get("model_path_sa").format(study_area=study_area),
            height=config["tilesize"],
            width=config["tilesize"],
            num_batches=config["pred_num_batches"],
            batch_size=config["pred_batch_size"],
            overlap=config["pred_overlap"],
            dst_dtype=config["pred_dtype"],
            dst_nodata=config["pred_nodata"]
        )

        print("Post-processing ...")
        postprocess_raster(
            src_path=config["prediction_path"].format(study_area=study_area, raster_id=raster_id),
            dst_path=config["prediction_processed_path"].format(study_area=study_area, raster_id=raster_id),
            boundary_id=config["boundary_id"],
            n_classes=config["n_classes"],
            min_crater_area=config["min_crater_area"],
            height=config["postprocess_height"],
            width=config["postprocess_width"],
            overlap=config["postprocess_overlap"]
        )

        print("Polygonize rasters ...")
        polygonize_prediction_raster(
            src_path=config["prediction_processed_path"].format(study_area=study_area, raster_id=raster_id),
            dst_polygons_path=config["prediction_polygons_path"].format(study_area=study_area, raster_id=raster_id),
            dst_centroids_path=config["prediction_centroids_path"].format(study_area=study_area, raster_id=raster_id),    
        )
        


quang_tri_aft
Model prediction ...


  from .autonotebook import tqdm as notebook_tqdm


Directory already exists: ../data/5_model_prediction/quang_tri/pixels
Number of windows: 243486
Number of batches: 76
Predicting batch 1/76
Predicting batch 2/76
Predicting batch 3/76
Predicting batch 4/76
Predicting batch 5/76
Predicting batch 6/76
Predicting batch 7/76
Predicting batch 8/76
Predicting batch 9/76
Predicting batch 10/76
Predicting batch 11/76
Predicting batch 12/76
Predicting batch 13/76
Predicting batch 14/76
Predicting batch 15/76
Predicting batch 16/76
Predicting batch 17/76
Predicting batch 18/76
Predicting batch 19/76
Predicting batch 20/76
Predicting batch 21/76
Predicting batch 22/76
Predicting batch 23/76
Predicting batch 24/76
Predicting batch 25/76
Predicting batch 26/76
Predicting batch 27/76
Predicting batch 28/76
Predicting batch 29/76
Predicting batch 30/76
Predicting batch 31/76
Predicting batch 32/76
Predicting batch 33/76
Predicting batch 34/76
Predicting batch 35/76
Predicting batch 36/76
Predicting batch 37/76
Predicting batch 38/76
Predicting batch 

In [5]:
%%time
# create a final set of crater polygons per study area, avoiding duplicate craters by using 
# the no_overlap footprint generated in 0_extract_imagery_footprints

for study_area in study_areas:
    rasters = config.get("study_areas").get(study_area).get("rasters")

    crater_polygons_list = list() 
    for raster_id in rasters.keys():
        config = load_config(config_path, study_area=study_area, raster_id=raster_id)

        footprint = gpd.read_file(config["footprint_path"].format(study_area=study_area, raster_id=raster_id))
        pred_polygons = gpd.read_file(config["prediction_polygons_path"].format(study_area=study_area, raster_id=raster_id))
        
        overlap = gpd.sjoin(
            pred_polygons,
            footprint,
            how="inner",
            predicate="within").drop("index_right", axis=1)
        
        crater_polygons_list.append(overlap)
    
    # combine all polygons into one dataframe for the full study area
    crater_polygons = pd.concat(crater_polygons_list, ignore_index=True)

    # Convert the combined DataFrame back to a GeoPandas DataFrame
    crater_polygons = gpd.GeoDataFrame(crater_polygons, geometry='geometry')

    create_dir(config.get("prediction_polygons_path_sa"), is_file=True)
    crater_polygons.to_file(config.get("prediction_polygons_path_sa"), driver="GeoJSON")

    crater_polygons["geometry"] = crater_polygons.geometry.centroid
    
    create_dir(config.get("prediction_centroids_path_sa"), is_file=True)
    crater_polygons.to_file(config.get("prediction_centroids_path_sa"), driver="GeoJSON")
    

Directory created: ../outputs/predictions/quang_tri
Directory already exists: ../outputs/predictions/quang_tri
Directory created: ../outputs/predictions/tri_border_area
Directory already exists: ../outputs/predictions/tri_border_area
CPU times: total: 5min 51s
Wall time: 5min 59s
