In [None]:
from pathlib import Path
import pickle
import rasterio as rio
import numpy as np
from tqdm.auto import tqdm
import torch
from fastai.torch_core import TensorImage

In [None]:
model_name = "regnety_006_v1.09_model.pkl"
raster_folder = Path("/media/nick/SSD1/Coast test/Tas")


def is_input_raster(path):
    """Return True when ``path`` is a source raster that still needs a prediction.

    Hidden files (name starting with ".") and prediction outputs (stem ending
    in "_pred") are rejected. The suffix is checked explicitly because a glob
    character class such as ``[!pred]`` only excludes a single trailing
    character, which would also drop inputs whose stem merely ends in
    "p", "r", "e" or "d".

    Args:
        path: A ``pathlib.Path`` (or PurePath) pointing at a ``.tif`` file.

    Returns:
        bool: True if the file should be treated as an input raster.
    """
    return not path.name.startswith(".") and not path.stem.endswith("_pred")


raster_files = [path for path in raster_folder.glob("*.tif") if is_input_raster(path)]
print(len(raster_files))

In [None]:
def normalise(band_stack):
    """Rescale raw band values and standardise them with fixed statistics.

    The input is divided by 32767, then the constant mean is subtracted and
    the result is divided by the constant standard deviation.

    Args:
        band_stack: Array-like (or scalar) of raw band values.

    Returns:
        The standardised values, same shape as the input.
    """
    # NOTE(review): presumably these are the statistics the model was trained
    # with -- confirm before changing them.
    dataset_mean = 0.06330473795822207
    dataset_std = 0.02668270641026416

    scaled = band_stack / 32767
    centred = scaled - dataset_mean
    return centred / dataset_std

In [None]:
def make_patches(band_stack, patch_size, grid_size=10):
    """Cut a band stack into a square grid of square patches, row by row.

    Args:
        band_stack: Array indexed as ``[band, row, column]``.
        patch_size: Height and width of each patch in pixels.
        grid_size: Number of patches along each spatial axis. Defaults to 10,
            giving 100 patches.

    Returns:
        list: ``grid_size * grid_size`` patches in row-major order (all
        patches of the first row of the grid, then the second row, ...).
        Each patch is ``band_stack[:, top:top+patch_size, left:left+patch_size]``;
        pixels beyond ``grid_size * patch_size`` are not included in any patch.
    """
    patches = []
    for row in range(grid_size):
        top = row * patch_size
        for col in range(grid_size):
            left = col * patch_size
            patches.append(
                band_stack[:, top : top + patch_size, left : left + patch_size]
            )
    return patches

In [None]:
def tta_inference(patches, model):
    """Run the model on every patch and collect the predictions.

    For each patch the first four bands are moved to the end of the band axis
    before the patch is wrapped in a ``TensorImage``, given a batch dimension,
    moved to the GPU and passed to ``model``. Despite the function name, no
    test-time augmentation is applied here.

    Args:
        patches: Iterable of arrays indexed as ``[band, row, column]``.
        model: Callable that accepts a batched CUDA tensor and returns a
            tensor. NOTE(review): presumably already on the GPU and in eval
            mode after unpickling -- confirm.

    Returns:
        numpy.ndarray: The squeezed per-patch predictions stacked along a new
        leading axis, in the same order as ``patches``.
    """
    preds = []
    # Inference only: without no_grad every forward pass records an autograd
    # graph that is thrown away immediately, wasting GPU memory.
    with torch.no_grad():
        for patch in tqdm(patches, leave=False):
            front = patch[:4]
            back = patch[4:]
            # Reorder bands so the first four come last.
            patch = np.concatenate((back, front), axis=0)
            pred = model(TensorImage(patch).unsqueeze(0).cuda())
            pred = pred.squeeze().cpu().detach().numpy()
            preds.append(pred)
    return np.array(preds)

In [None]:
def stitch_preds(preds_np_mean, patch_size, grid_size=10):
    """Reassemble row-major patch predictions into one 2-D array.

    The output side length is derived from ``grid_size * patch_size`` so the
    canvas always matches the patches being placed (10980 for the default
    10 x 10 grid of 1098-pixel patches).

    Args:
        preds_np_mean: Sequence of ``grid_size * grid_size`` 2-D predictions,
            each ``patch_size`` x ``patch_size``, ordered row by row.
        patch_size: Height and width of each patch in pixels.
        grid_size: Number of patches along each axis. Defaults to 10.

    Returns:
        numpy.ndarray: Float array of shape
        ``(grid_size * patch_size, grid_size * patch_size)``.
    """
    side = grid_size * patch_size
    pred_array = np.zeros((side, side))
    for row in range(grid_size):
        for col in range(grid_size):
            pred_array[
                row * patch_size : (row + 1) * patch_size,
                col * patch_size : (col + 1) * patch_size,
            ] = preds_np_mean[row * grid_size + col]
    return pred_array

In [None]:
def export_pred(output_path, pred_array, src_raster):
    """Write the thresholded prediction as a single-band GeoTIFF.

    The output profile is the source raster's profile with the dtype, band
    count, compression and driver overridden. The band written is the boolean
    mask ``pred_array > 0``.

    Args:
        output_path: Destination path of the raster to create.
        pred_array: 2-D array of prediction values.
        src_raster: Open rasterio dataset whose ``profile`` is reused for the
            output.
    """
    overrides = {
        "dtype": rio.int8,
        "count": 1,
        "compress": "lzw",
        "driver": "GTiff",
    }
    out_profile = {**src_raster.profile, **overrides}
    mask = pred_array > 0
    with rio.open(output_path, "w", **out_profile) as dst:
        dst.write(mask, 1)

In [None]:
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only load model files from a trusted source.
# The context manager closes the file handle once the model is loaded.
with open(model_name, "rb") as model_file:
    model = pickle.load(model_file)

# A 10 x 10 grid of 1098-pixel patches spans 10980 pixels per side.
patch_size = 1098

In [None]:
for raster_path in tqdm(raster_files):
    # Derive the output name from the file name only, so a ".tif" appearing
    # elsewhere in the path (e.g. in a directory name) is never rewritten.
    output_path = raster_path.with_name(f"{raster_path.stem}_pred.tif")
    # Skip rasters that already have a prediction so the loop can be resumed.
    if output_path.exists():
        continue

    # The context manager closes the source dataset after each raster instead
    # of leaking one open handle per iteration.
    with rio.open(raster_path) as src_raster:
        band_stack = src_raster.read()
        band_stack = normalise(band_stack)
        patches = make_patches(band_stack, patch_size)
        del band_stack

        preds_np_mean = tta_inference(patches, model)
        del patches

        pred_array = stitch_preds(preds_np_mean, patch_size)
        del preds_np_mean

        # Needs the still-open source dataset to copy its profile.
        export_pred(output_path, pred_array, src_raster)

    torch.cuda.empty_cache()