In [None]:
%load_ext autoreload

%autoreload 2
import gcsfs
import numpy as np
from skimage.morphology import (
    closing,
    label,
    remove_small_holes,
    remove_small_objects,
    square,
)
# Single GCS filesystem handle: passed to the segmentation predictor (fs=fs)
# and used as the mapper source (fs.get_mapper) when loading the tile.
fs = gcsfs.GCSFileSystem()


In [None]:
from oxeo.core.models.tile import load_tile, TilePath, tile_from_id
from oxeo.water.models.segmentation import Segmentation2DPredictor

# Sentinel-2 tile to analyse; the id string is parsed by oxeo's tile_from_id.
tile_path = TilePath(tile_from_id("19_J_10000_73_697"), "sentinel-2")
# Revisits to process: indices 0..4 (presumably the first five acquisitions).
revisit_slice = slice(0,5)
# Pixel resolution passed to segmentation_area(..., "meter", res) in the
# analysis loop -- presumably 10 m for Sentinel-2. Do not reuse the name
# `res` for anything else, or the area computation silently breaks.
res = 10

# CNN segmentation model restored from a checkpoint on GCS.
# NOTE(review): input_channels=6 presumably matches the six bands requested
# in load_tile; num_classes=3 -- confirm against the checkpoint's training config.
seg_predictor = Segmentation2DPredictor(
    batch_size=16,
    ckpt_path="gs://oxeo-models/semseg/epoch_012.ckpt",
    input_channels=6,
    num_classes=3,
    chip_size=250,
    fs=fs,
)


In [None]:
# Run the CNN over the revisits configured above (revisit_slice is defined in
# the config cell -- single source of truth).
# Keep the return value under its own name: assigning it to `res` would
# overwrite the pixel resolution that segmentation_area needs in the
# analysis loop.
seg_result = seg_predictor.predict(
    tile_path,
    revisit=revisit_slice,
)

In [None]:
# Load the six spectral bands plus the "pekel" and "cnn" masks for the same
# revisits. Band order matters: the RGB preview in the analysis loop picks
# indices 1-3 (red, green, blue) from this list.
tile = load_tile(
    tile_path=tile_path,
    fs_mapper=fs.get_mapper,
    revisit=revisit_slice,
    bands=["nir", "red", "green", "blue", "swir1", "swir2"],
    masks=("pekel", "cnn"),
)

In [None]:
def plot_imgs_in_row(
    imgs,
    labels=("img", "pekel", "cnn"),
    figsize=(8, 5),
    interpolation=None,
):
    """Show a sequence of images side by side in one row of subplots.

    Parameters
    ----------
    imgs : sequence of array-like
        Images to draw, one per column. Single-channel images are drawn with
        ``vmin=0.0, vmax=1.0``, so every value >= 1 (e.g. every non-zero
        region id from ``skimage.morphology.label``) maps to the top of the
        colour map. Matplotlib ignores vmin/vmax for RGB(A) input.
    labels : sequence of str
        Column titles; must contain at least ``len(imgs)`` entries.
    figsize : tuple of float
        Figure size in inches, forwarded to ``plt.subplots``.
    interpolation : str or None
        Forwarded to ``Axes.imshow``. ``None`` uses matplotlib's
        ``rcParams["image.interpolation"]`` default; pass ``"none"`` or
        ``"nearest"`` to see raw pixels (useful for masks).

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was created.

    Raises
    ------
    ValueError
        If ``imgs`` is empty or fewer labels than images are supplied.
    """
    cols = len(imgs)
    if cols == 0:
        raise ValueError("plot_imgs_in_row needs at least one image")
    if len(labels) < cols:
        raise ValueError(f"got {cols} images but only {len(labels)} labels")

    # squeeze=False keeps `axes` 2-D (rows x cols) even for a single image,
    # so the indexing below works for any number of columns.
    fig, axes = plt.subplots(1, cols, figsize=figsize, squeeze=False)
    for ax, img, title in zip(axes[0], imgs, labels):
        ax.set_title(title)
        ax.imshow(img, vmin=0.0, vmax=1.0, interpolation=interpolation)
        ax.axis("off")
    return fig

In [None]:
import matplotlib.pyplot as plt
from skimage.exposure import rescale_intensity
from oxeo.water.metrics import segmentation_area
import os
# NOTE(review): loguru reads LOGURU_LEVEL when it is first imported. The oxeo
# modules were imported in earlier cells, so if they pull in loguru this line
# likely has no effect -- move it above the first oxeo import to be safe.
os.environ["LOGURU_LEVEL"] = "INFO"


def _clean_water_mask(mask):
    """Smooth a boolean mask and label its connected regions.

    Applies a 3x3 morphological closing, fills holes and drops objects smaller
    than 50 pixels (8-connectivity), then labels the remaining regions
    (background = 0).
    """
    mask = closing(mask, square(3))
    mask = remove_small_holes(mask, area_threshold=50, connectivity=2)
    mask = remove_small_objects(mask, min_size=50, connectivity=2)
    return label(mask, background=0, connectivity=2)


# `res` must still be the pixel resolution from the config cell; fail loudly
# if it has been overwritten (e.g. by a prediction result).
assert np.isscalar(res), "`res` must be the pixel resolution, not a prediction result"

# Per-revisit ratio pekel_area / cnn_area (kept under its historical name).
areas_diff = []

# Convert once, outside the loop. image_stack is indexed below as
# [revisit][band] -> 3-D, so its first axis is the revisit axis.
image_stack = tile["image"].numpy()
n_revisits = image_stack.shape[0]


def _mask_stack(name):
    """Return mask `name` as (revisit, H, W).

    Assumes the axes between revisit and the last two are singletons (as the
    original `.squeeze()` implied). An explicit reshape cannot fold away the
    revisit axis when only one revisit is loaded.
    """
    arr = tile[name].numpy()
    return arr.reshape(n_revisits, *arr.shape[-2:])


# Compare instead of assigning into these arrays: `.numpy()` may return a view
# sharing memory with the loaded tile, so in-place writes would mutate `tile`.
cnn_stack = _mask_stack("cnn")
pekel_stack = _mask_stack("pekel")

for i in range(n_revisits):
    # Keep only CNN class 1 (presumably water); every other class -> background.
    cnn = _clean_water_mask(cnn_stack[i] == 1)
    pekel = _clean_water_mask(pekel_stack[i].astype(bool))

    # Bands 1-3 in the requested band order are red, green, blue -> HWC for imshow.
    img = image_stack[i][[1, 2, 3]].transpose(1, 2, 0)

    area_cnn = segmentation_area(cnn, "meter", res)
    area_pekel = segmentation_area(pekel, "meter", res)
    # Guard revisits where the CNN mask is empty instead of dividing by zero.
    ratio = area_pekel / area_cnn if area_cnn else float("nan")
    areas_diff.append(ratio)
    print("Seg area: ", area_cnn/1e8, area_pekel/1e8, ratio)

    # Contrast-stretch the RGB preview to its 2nd-98th percentile range.
    vmin, vmax = np.percentile(img, q=(2, 98))
    img = rescale_intensity(img, in_range=(vmin, vmax), out_range=(0, 1))
    plot_imgs_in_row([img, pekel, cnn])
    plt.show()