In [1]:
import rasterio as rio
import torch
import numpy as np
from tqdm.auto import tqdm
from pathlib import Path
import pickle

from fastai.vision.all import *
import gc

In [2]:
# Pickled model to load and the folder of scenes to run it on.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
model_name = "regnety_006_v1.17_model.pkl"
raster_folder = Path("/media/nick/SNEAKERNET/inference_scenes_5")

# Collect the input scenes: skip hidden files (leading ".") and skip the
# "*_pred.tif" outputs this notebook writes next to its inputs.
# A glob class such as "[!pred]" only tests ONE character, so it would drop any
# file whose stem ends in "p", "r", "e" or "d"; filter on the suffix instead.
raster_files = [
    path
    for path in raster_folder.glob("[!.]*.tif")
    if not path.stem.endswith("_pred")
]
print(len(raster_files))

24


In [3]:
# Normalisation statistics, one entry per index along the first axis of the
# band stack (they are broadcast over that axis in `normalise`).
# NOTE(review): presumably computed on the training data after the /32767
# scaling — confirm against the training pipeline.
means = np.array(
    [
        0.09561849, 0.09644007, 0.09435602, 0.09631168,
        0.09356618, 0.09504625, 0.09509373, 0.09508776,
        0.0911776, 0.091464, 0.09334985, 0.09400712,
    ]
)
stds = np.array(
    [
        0.02369863, 0.03057647, 0.0244495, 0.03169953,
        0.02380443, 0.03068336, 0.02376207, 0.03026029,
        0.02387124, 0.03011121, 0.02285621, 0.02902071,
    ]
)

In [4]:
def normalise(band_stack):
    """Scale a raw band stack by 1/32767 and standardise it per band.

    Each entry along axis 0 has the matching module-level ``means`` value
    subtracted and is then divided by the matching ``stds`` value.

    All arithmetic is done in float32. Casting the raw values to float16 first
    would turn anything above 65504 into ``inf`` and round integers above 2048,
    and mixing float16 with the float64 statistics would silently promote the
    whole stack to float64.

    Args:
        band_stack: numeric array of shape (bands, height, width) whose first
            axis has the same length as ``means`` and ``stds``.

    Returns:
        A new float32 array with the same shape; the input is not modified.
    """
    # astype() copies, so the in-place operations below never touch the input.
    scaled = band_stack.astype(np.float32)
    scaled /= 32767
    scaled -= means.astype(np.float32)[:, np.newaxis, np.newaxis]
    scaled /= stds.astype(np.float32)[:, np.newaxis, np.newaxis]
    return scaled

In [5]:
def make_patches(band_stack, patch_size, grid_size=10):
    """Cut a band stack into a square grid of patches, in row-major order.

    Args:
        band_stack: array indexed as (bands, rows, cols).
        patch_size: edge length of each square patch, in pixels.
        grid_size: number of patches along each spatial axis (default 10, so
            100 patches are produced).

    Returns:
        A list of ``grid_size * grid_size`` slices of ``band_stack``; the patch
        for grid cell (row, col) is at index ``row * grid_size + col``. For a
        NumPy input these are views, not copies.
    """
    return [
        band_stack[
            :,
            row * patch_size : (row + 1) * patch_size,
            col * patch_size : (col + 1) * patch_size,
        ]
        for row in range(grid_size)
        for col in range(grid_size)
    ]

In [6]:
def inference(patches, model):
    """Run the model on each patch, one at a time, on the GPU.

    Every patch is wrapped in a ``TensorImage``, given a leading batch axis,
    moved to CUDA and cast to half precision before being passed to ``model``.

    Args:
        patches: iterable of per-patch arrays (as produced by ``make_patches``).
        model: callable taking the batched half-precision CUDA tensor.
            NOTE(review): assumed to already be on the GPU and in half
            precision — confirm at the call site.

    Returns:
        A NumPy array stacking the squeezed prediction of every patch, in the
        same order as ``patches``.
    """
    preds = []
    # Gradients are never used here; disabling autograd avoids building a
    # graph for every forward pass and the GPU memory that goes with it.
    with torch.no_grad():
        for patch in tqdm(patches, leave=False):
            batch = TensorImage(patch).unsqueeze(0).cuda().half()
            pred = model(batch)
            preds.append(pred.squeeze().cpu().detach().numpy())
    return np.array(preds)

In [7]:
def stitch_preds(preds_np_mean, patch_size, grid_size=10, raster_size=10980):
    """Reassemble row-major patch predictions into one square 2-D array.

    Args:
        preds_np_mean: sequence of per-patch predictions where index
            ``row * grid_size + col`` holds the patch for that grid cell; each
            entry must be assignable to a (patch_size, patch_size) window.
        patch_size: edge length of each patch, in pixels.
        grid_size: number of patches along each axis (default 10).
        raster_size: edge length of the output array (default 10980).

    Returns:
        A float64 array of shape (raster_size, raster_size). Cells not covered
        by any patch stay 0.
    """
    pred_array = np.zeros((raster_size, raster_size))
    for index in range(grid_size * grid_size):
        row, col = divmod(index, grid_size)
        pred_array[
            row * patch_size : (row + 1) * patch_size,
            col * patch_size : (col + 1) * patch_size,
        ] = preds_np_mean[index]
    return pred_array

In [8]:
def export_pred(output_path, pred_array, src_raster):
    """Write the thresholded prediction as a single-band, LZW-compressed GeoTIFF.

    The output reuses the source dataset's profile, overriding the dtype
    (int8), band count (1), compression and driver.

    Args:
        output_path: destination path of the GeoTIFF.
        pred_array: 2-D array of raw predictions; values > 0 are written as 1,
            everything else as 0.
        src_raster: open rasterio dataset whose profile is copied.
            NOTE(review): any ``nodata`` value in that profile is carried over
            unchanged — confirm it is representable as int8.
    """
    profile = src_raster.profile.copy()
    profile.update(dtype=rio.int8, count=1, compress="lzw", driver="GTiff")
    # Cast explicitly so the array dtype matches the dtype declared in the
    # profile instead of handing rasterio a boolean array.
    mask = (pred_array > 0).astype(rio.int8)
    with rio.open(output_path, "w", **profile) as dst:
        dst.write(mask, 1)

In [9]:
# NOTE(review): pickle.load can execute arbitrary code — only load model files
# from a trusted source.
with open(model_name, "rb") as model_file:
    model = pickle.load(model_file)

# Half precision to match the half-precision inputs built in `inference`;
# eval() makes inference behaviour explicit (dropout off, batch-norm using
# running statistics).
# NOTE(review): assumes the pickled object is a torch.nn.Module — confirm.
model = model.half().eval()

# 10 patches of 1098 px per axis cover the 10980 px scene edge used when
# stitching the predictions back together.
patch_size = 1098

In [10]:
# Run the full pipeline on every scene, skipping scenes that already have a
# prediction on disk so an interrupted run can be resumed.
for raster_path in tqdm(raster_files):
    # Build the output name from the file name only; a plain str.replace on the
    # whole path would also rewrite ".tif" inside a directory name.
    output_path = raster_path.with_name(f"{raster_path.stem}_pred.tif")
    if output_path.exists():
        continue

    # The context manager closes the source dataset even if a step fails.
    with rio.open(raster_path) as src_raster:
        # normalise() does its own dtype conversion, so the raw values are
        # passed through without a lossy pre-cast.
        band_stack = normalise(src_raster.read())
        patches = make_patches(band_stack, patch_size)
        del band_stack

        preds_np_mean = inference(patches, model)
        del patches

        pred_array = stitch_preds(preds_np_mean, patch_size)
        del preds_np_mean

        export_pred(output_path, pred_array, src_raster)
        del pred_array

    # Release cached GPU blocks and host memory before the next scene.
    torch.cuda.empty_cache()
    gc.collect()

  0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]