In [None]:
from localtileserver import get_leaflet_tile_layer, TileClient
from ipyleaflet import Map, WidgetControl, DrawControl, LayersControl, Marker, Icon, MarkerCluster
import ipywidgets

import rioxarray as rxr
import xarray as xr
import numpy as np
from p_tqdm import p_map
# import matplotlib.pyplot as plt
from time import time
import os
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from rasterio.enums import Resampling

In [None]:
# Open the multispectral NetCDF (masked=True reads nodata cells as NaN) and drop the
# singleton 'band' dimension.
ms_raster = rxr.open_rasterio('data/multispectral.nc', masked=True).squeeze(dim='band', drop=True)
# Names of the data variables (presumably one per spectral band -- confirm against the file).
bands = list(ms_raster.keys())
# Stack the data variables into one DataArray along a new leading 'variable' dimension.
ms_raster = ms_raster.to_array()


In [None]:
# Resample the grid by `rescale_factor` (bilinear). Reprojecting onto the raster's own CRS
# changes only the pixel grid; the target shape is the scaled size of the two spatial axes
# (dims 1 and 2 of the stacked array).
rescale_factor = 0.5
ms_raster = ms_raster.rio.reproject(
    ms_raster.rio.crs,
    shape=tuple(round(n_pixels * rescale_factor) for n_pixels in ms_raster.shape[1:3]),
    resampling=Resampling.bilinear,
)

In [None]:
# Min-max normalise every band independently to [0, 1] (min/max skip NaNs).
# A constant band (max == min) would otherwise compute 0/0 and become NaN everywhere; the
# NaN mask built from the normalised raster would then discard every pixel. Use a unit
# span for such bands instead.
mins = ms_raster.min(dim=['x', 'y'])
maxs = ms_raster.max(dim=['x', 'y'])
spans = maxs - mins
spans = spans.where(spans > 0, 1)
ms_raster = (ms_raster - mins) / spans

In [None]:
# Template raster for the segmentation output: a float32, all-zero copy of the first band,
# so it shares that band's grid and coordinates.
seg_raster = ms_raster[0].copy().astype(np.float32)
seg_raster.values[:] = 0

In [None]:
dataset = ms_raster.to_numpy()
shape = dataset.shape
# True where every band has a value; the same mask scatters predictions back onto the grid.
nan_mask = ~(np.isnan(dataset).any(axis=0))
# Boolean indexing over the two spatial axes already yields (n_bands, n_valid_pixels), so
# only a transpose is needed. Cast to float32 explicitly: the model's Linear parameters are
# float32, and the raster may have been read with a different float dtype.
dataset = torch.from_numpy(np.ascontiguousarray(dataset[:, nan_mask].T, dtype=np.float32))

In [None]:
# Per-pixel binary classifier: a small two-layer network mapping n_inputs band values to one
# probability, trained with full-batch SGD and early stopping.
class Segmentator():
    """Trains and applies a tiny per-pixel binary classifier.

    Args:
        init_lr: SGD learning rate.
        es_patience: epochs without sufficient loss improvement before stopping early.
        es_min_delta: minimum loss decrease that counts as an improvement.
        epochs: maximum number of full-batch training epochs.
        n_inputs: features per sample (number of bands); defaults to 5.
        n_hidden: width of the hidden layer; defaults to 10.
    """

    class EarlyStopper:
        """Signals a stop after `patience` consecutive calls without a loss improvement
        larger than `min_delta`."""

        def __init__(self, patience, min_delta):
            self.patience = patience
            self.min_delta = min_delta
            self.reset()

        def reset(self):
            """Forget the loss history so the stopper can be reused for a new run."""
            self.counter = 0
            self.best_loss = np.inf

        def early_stop(self, loss):
            """Record `loss`; return True once `patience` non-improving calls have accrued."""
            if loss < self.best_loss - self.min_delta:
                self.best_loss = loss
                self.counter = 0
                return False
            self.counter += 1
            return self.counter >= self.patience

    class Model(torch.nn.Module):
        """Linear(n_inputs, n_hidden) -> Linear(n_hidden, 1) -> sigmoid."""

        def __init__(self, n_inputs=5, n_hidden=10):
            super().__init__()
            self.layer1 = torch.nn.Linear(n_inputs, n_hidden)
            self.layer2 = torch.nn.Linear(n_hidden, 1)
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            # NOTE(review): there is no non-linearity between layer1 and layer2, so the two
            # layers compose to a single affine map (i.e. this is logistic regression).
            # Insert e.g. torch.nn.ReLU() between them if a non-linear boundary is wanted.
            x = self.layer1(x)
            x = self.layer2(x)
            return self.sigmoid(x)

    def __init__(self, init_lr, es_patience, es_min_delta, epochs, n_inputs=5, n_hidden=10):
        self.model = self.Model(n_inputs=n_inputs, n_hidden=n_hidden)
        self.epochs = epochs
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=init_lr)
        self.early_stopper = self.EarlyStopper(patience=es_patience, min_delta=es_min_delta)

    def train(self, x, y):
        """Fit the model with full-batch SGD.

        Args:
            x: float tensor of shape (n_samples, n_inputs).
            y: float tensor of shape (n_samples,) holding 0./1. targets.

        Returns:
            List of per-epoch training losses.
        """
        self.model.train()
        # Each call is a fresh run; a previous run's best loss must not stop this one early.
        self.early_stopper.reset()
        losses = []
        pbar = tqdm(range(self.epochs))
        for _ in pbar:
            self.optimizer.zero_grad()
            # squeeze(-1), not squeeze(): keeps shape (n_samples,) even for a single sample,
            # so it always matches `y` as binary_cross_entropy requires.
            y_ = self.model(x).squeeze(-1)
            loss = torch.nn.functional.binary_cross_entropy(y_, y)
            loss.backward()
            self.optimizer.step()
            loss_value = loss.item()
            losses.append(loss_value)
            pbar.set_postfix({"loss": loss_value})
            if self.early_stopper.early_stop(loss_value):
                print("Early stop")
                break
        return losses

    def predict(self, x):
        """Return per-sample probabilities of shape (n_samples,), without tracking gradients."""
        self.model.eval()
        with torch.no_grad():
            return self.model(x).squeeze(-1)

def make_segmentation(init_lr=1, es_patience=10, es_min_delta=0.0001, epochs=1000):
    """Train a fresh Segmentator on the labelled points and predict every valid pixel.

    Reads the module-level `pos_values` / `neg_values` (band vectors probed at the clicked
    points), `dataset`, `shape` and `nan_mask`. Stores the result in the global `seg_probs`:
    a float32 (shape[1], shape[2]) grid of probabilities, 0 where `nan_mask` is False.

    Args:
        init_lr, es_patience, es_min_delta, epochs: forwarded to Segmentator; the defaults
            are the values this notebook has always used.

    Raises:
        ValueError: if there is not at least one positive and one negative point.
    """
    global seg_probs
    if len(pos_values) == 0 or len(neg_values) == 0:
        raise ValueError(
            "Need at least one positive and one negative point to train "
            f"(got {len(pos_values)} positive, {len(neg_values)} negative)."
        )
    # float32 to match the model's parameters regardless of the raster's dtype.
    x_init = torch.from_numpy(np.asarray(list(pos_values) + list(neg_values), dtype=np.float32))
    y_init = torch.cat([torch.ones(len(pos_values)), torch.zeros(len(neg_values))])
    # Shuffle samples and labels together.
    perm = torch.randperm(len(x_init))
    x_init = x_init[perm]
    y_init = y_init[perm]

    segmentator = Segmentator(init_lr=init_lr, es_patience=es_patience,
                              es_min_delta=es_min_delta, epochs=epochs)
    segmentator.train(x_init, y_init)

    probs = segmentator.predict(dataset)
    # Scatter predictions back onto the full grid; pixels outside nan_mask stay 0.
    seg_probs = np.zeros((shape[1], shape[2]), dtype=np.float32)
    seg_probs[nan_mask] = probs.numpy()

In [None]:
# Basemap: serve the RGB GeoTIFF through a local tile server and centre the map on it.
basemap_client = TileClient("data/rgb.tif")
m = Map(
    center=basemap_client.center(),
    zoom=basemap_client.default_zoom,
    scroll_wheel_zoom=True,
)
basemap_layer = get_leaflet_tile_layer(basemap_client, name='Raster')
m.add_layer(basemap_layer)
m.add_control(LayersControl(position='bottomright'))

# Custom toolbar: Positive/Negative toggles pick the label for the next clicked point;
# Generate trains and draws the mask; Clear resets everything.
posneg_box = ipywidgets.VBox()
action_box = ipywidgets.VBox()
pos_button = ipywidgets.ToggleButton(description='Positive', value=True)
state = True  # True -> next point is positive, False -> negative
neg_button = ipywidgets.ToggleButton(description='Negative')
gen_button = ipywidgets.Button(description='Generate')
clear_button = ipywidgets.Button(description='Clear')
def pos_button_event(b):
    """Observer for the Positive toggle's 'value'.

    When it is switched on, switch the Negative toggle off and label new points positive.
    """
    global state
    if b["new"] != True:
        return
    neg_button.set_trait('value', False)
    state = True
def neg_button_event(b):
    """Observer for the Negative toggle's 'value'.

    When it is switched on, switch the Positive toggle off and label new points negative.
    """
    global state
    if b["new"] != True:
        return
    pos_button.set_trait('value', False)
    state = False
def clear_button_event(b):
    """Click handler for "Clear": forget all labelled points and drop any rendered mask.

    Resets the point/value lists and both marker clusters and clears the draw control. For a
    map layer named 'Mask', it removes the layer, deletes the mask globals and deletes the
    GeoTIFF at `mask_path`.
    """
    global pos_points, neg_points, pos_values, neg_values, mask_layer, mask_client, mask_path
    pos_points, neg_points = [], []
    pos_values, neg_values = [], []
    pos_markers.markers = []
    neg_markers.markers = []
    draw_control.clear()
    stale_layers = [layer for layer in m.layers if layer.name == 'Mask']
    for stale in stale_layers:
        m.remove_layer(stale)
        del mask_layer
        del mask_client
        # Remove the mask's backing file from disk.
        os.remove(mask_path)
def gen_button_event(b):
    """Click handler for "Generate": fit on the labelled points, then draw the mask.

    `b` is the clicked button (unused).
    """
    make_segmentation()  # sets the global seg_probs
    visualize_mask()     # writes seg_probs to a GeoTIFF and shows it as the 'Mask' layer
# Wire the toolbar callbacks.
pos_button.observe(pos_button_event, names=['value'])
neg_button.observe(neg_button_event, names=['value'])
gen_button.on_click(gen_button_event)
clear_button.on_click(clear_button_event)

# Two stacked button groups in the top-right corner: label selection, then actions.
posneg_box.children = (pos_button, neg_button)
action_box.children = (gen_button, clear_button)
m.add_control(WidgetControl(widget=posneg_box, position='topright'))
m.add_control(WidgetControl(widget=action_box, position='topright'))

# Green / red pin icons for positive / negative points.
pos_icon = Icon(
    icon_url='https://raw.githubusercontent.com/pointhi/leaflet-color-markers/master/img/marker-icon-2x-green.png',
    icon_size=[25, 41],
    icon_anchor=[12, 41],
)
neg_icon = Icon(
    icon_url='https://raw.githubusercontent.com/pointhi/leaflet-color-markers/master/img/marker-icon-2x-red.png',
    icon_size=[25, 41],
    icon_anchor=[12, 41],
)

# Only the marker tool is enabled (repeatMode keeps it active between clicks);
# editing and removing drawn features are disabled.
draw_control = DrawControl(
    polyline={}, polygon={}, circle={}, rectangle={}, circlemarker={},
    marker={"repeatMode": True},
    edit=False, remove=False,
)

# Labelled samples: clicked coordinates and the band values probed at them.
pos_points, neg_points = [], []
pos_values, neg_values = [], []
pos_markers = MarkerCluster(name='Positive')
neg_markers = MarkerCluster(name='Negative')
m.add_layer(pos_markers)
m.add_layer(neg_markers)
def draw_event(target, action, geo_json):
    """DrawControl callback: turn each drawn marker into a labelled training point.

    The label comes from the global `state` (True -> positive, False -> negative). The drawn
    feature is cleared from the draw layer, and the raster is probed at the GeoJSON
    [lon, lat] coordinates. If every band has a value there, the point and its band values
    are stored and the matching marker cluster is rebuilt.
    """
    global point
    if state == True:
        points, samples, cluster, label, icon = pos_points, pos_values, pos_markers, 'Positive', pos_icon
    elif state == False:
        points, samples, cluster, label, icon = neg_points, neg_values, neg_markers, 'Negative', neg_icon
    else:
        return
    draw_control.clear()
    point = geo_json['geometry']['coordinates']
    values = probe_rasters(point)
    if values is None:
        return
    points.append(point)
    samples.append(values)
    # Leaflet markers take [lat, lon], hence the swap from GeoJSON order.
    cluster.markers = [Marker(location=[p[1], p[0]], name=label, icon=icon) for p in points]


draw_control.on_draw(draw_event)

def probe_rasters(points):
    """Sample every band of `ms_raster` at one or more map coordinates.

    Args:
        points: one [x, y] pair or a sequence of such pairs, in the raster's coordinates.
            NOTE(review): the caller passes GeoJSON [lon, lat], so this assumes ms_raster
            is in a lon/lat CRS (e.g. EPSG:4326) -- confirm.

    Returns:
        A (n_bands,) array for a single point, or an (n_bands, n_points) array for several.
        Returns None if any point lies outside the raster's bounds or any sampled value
        is NaN.
    """
    points = np.asarray(points, dtype=float)
    single = points.ndim == 1
    points = np.atleast_2d(points)
    xs, ys = points[:, 0], points[:, 1]
    # method='nearest' would silently snap clicks outside the raster onto its edge pixels,
    # so reject them explicitly.
    left, bottom, right, top = ms_raster.rio.bounds()
    if ((xs < left) | (xs > right) | (ys < bottom) | (ys > top)).any():
        return None
    # Pointwise (vectorized) selection gives one value per band per point, i.e.
    # (n_bands, n_points). This replaces the old (n_bands, n_points, n_points) outer
    # selection followed by np.diagonal.
    values = ms_raster.sel(
        x=xr.DataArray(xs, dims='points'),
        y=xr.DataArray(ys, dims='points'),
        method='nearest',
    ).transpose(..., 'points').values
    if single:
        values = values[:, 0]
    if np.isnan(values).any():
        return None
    return values

def visualize_mask(threshold=0.5):
    """Render the current `seg_probs` as a 'Mask' tile layer on the map.

    Any existing 'Mask' layer is removed and its GeoTIFF deleted. Pixels with probability
    above `threshold` become 1 and all others -1. -1 is written as the raster's nodata value
    (presumably so the tile server renders only the foreground). The result is saved to a
    millisecond-timestamped GeoTIFF under data/, served through a TileClient and added to
    the map. The global `seg_raster` is replaced by the thresholded raster.

    Args:
        threshold: probability above which a pixel counts as foreground (default 0.5).
    """
    global seg_raster, mask_layer, mask_client, mask_path

    # m.layers is a tuple, so removing layers while iterating it is safe.
    for layer in m.layers:
        if layer.name == 'Mask':
            m.remove_layer(layer)
            del mask_layer
            del mask_client
            # Delete the previous mask's backing file.
            os.remove(mask_path)
    # Save the new mask to a uniquely named file.
    mask_path = f'data/mask_{int(time()*1000)}.tif'
    seg_raster.values = (seg_probs > threshold).astype(np.float32)
    seg_raster = seg_raster.where(seg_raster == 1, -1)
    seg_raster = seg_raster.rio.write_nodata(-1)
    seg_raster.rio.to_raster(mask_path)
    mask_client = TileClient(mask_path)
    mask_layer = get_leaflet_tile_layer(mask_client, name='Mask', palette=["red", "red"])
    m.add_layer(mask_layer)

# Enable point placement on the map, then display the map as the cell's output.
m.add_control(draw_control)
m