In [None]:
from pathlib import Path

from datacube import Datacube
from dea_tools.spatial import subpixel_contours
import folium
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyTMD
from scipy.ndimage import binary_fill_holes, gaussian_filter1d, uniform_filter, variance
from skimage import img_as_ubyte
from skimage.filters.thresholding import threshold_local, threshold_otsu
from skimage.morphology import remove_small_objects
from shapely.geometry import LineString, MultiLineString
import xarray as xr

In [None]:
# Datacube handle; used below by dc.load to read the Sentinel-1 VH product.
dc = Datacube(app=__name__)

In [None]:
# Load the coastal grid cells and their reference points, then draw every
# cell on an interactive map centred on the Bali cells. Each cell is labelled
# with its index label + 1, which is the id used to select a cell further down.
region_path = Path("../region/coastal_grids.geojson")
point_path = Path("../region/coastal_points.geojson")

region_gdf = gpd.read_file(region_path)
point_gdf = gpd.read_file(point_path)

filter_region_gdf = region_gdf.query("province == 'BALI'")
centroid = filter_region_gdf.unary_union.centroid

m = region_gdf.explore(
    location=[centroid.y, centroid.x],
    zoom_start=9,
    # Leaflet path styles take the stroke width as "weight"
    # ("linewidth" is not a Leaflet option and would be ignored).
    style_kwds={"fillOpacity": 0, "color": "red", "weight": 1},
)

folium.GeoJson(point_gdf.to_json()).add_to(m)

# Use a dedicated name for each label position so the map centre above is
# not overwritten.
for idx, geom in region_gdf.geometry.items():
    label_point = geom.centroid
    folium.Marker(
        location=[label_point.y, label_point.x],
        icon=folium.DivIcon(
            html=f"<div style='font-size: 12px'>{idx+1}</div>"
        ),
    ).add_to(m)

m

In [None]:
# Pick one grid cell by the id shown on the map above (index label + 1) and
# preview it together with its reference point.
region_id = 715
row_label = region_id - 1

selected_region_gdf = region_gdf.loc[[row_label]]
selected_point_gdf = point_gdf.loc[[row_label]]

m = selected_region_gdf.explore(style_kwds={"fillOpacity": 0, "color": "red"})
folium.GeoJson(selected_point_gdf.to_json()).add_to(m)
m

In [None]:
# Bounding box of the selected cell; passed to dc.load as longitude/latitude
# ranges below. NOTE(review): assumes the grid file is in lon/lat degrees — confirm.
xmin, ymin, xmax, ymax = selected_region_gdf.total_bounds

In [None]:
# Load the Sentinel-1 VH product over the selected cell's bounding box.
# For lazy loading, add dask_chunks={"time": 1, "x": 1024, "y": 1024}.
search_query = dict(
    time=("2015-01-01", "2022-05-31"),
    longitude=(xmin, xmax),
    latitude=(ymin, ymax),
)

ds = dc.load(product="s1_iw_gamma0_rtc_vh", **search_query)
ds

In [None]:
# Acquisition times of the loaded scenes; these are the tide-prediction times.
time_list = ds.time.values
time_list

In [None]:
def get_constants(
    lon: np.ndarray, lat: np.ndarray, model: pyTMD.model
) -> tuple:
    """
    Interpolate FES tidal constituent amplitudes and phases at given points.

    Parameters
    ----------
    lon, lat : np.ndarray
        Point coordinates; scalars are promoted to 1-D arrays.
    model : pyTMD.model
        Model description whose ``model_file``, ``type``, ``version``,
        ``scale`` and ``compressed`` attributes are forwarded to
        ``pyTMD.extract_FES_constants``.

    Returns
    -------
    tuple
        ``(amp, ph)`` as returned by ``pyTMD.extract_FES_constants``.
        NOTE(review): phase is presumably in degrees (the caller converts it
        with ``pi / 180``) — confirm against the pyTMD version in use.
    """
    print("Extracting constants...")

    # Spline interpolation, extrapolating to points outside the model grid
    # (e.g. near the coast).
    amp, ph = pyTMD.extract_FES_constants(
        np.atleast_1d(lon),
        np.atleast_1d(lat),
        model.model_file,
        TYPE=model.type,
        VERSION=model.version,
        METHOD="spline",
        EXTRAPOLATE=True,
        SCALE=model.scale,
        GZIP=model.compressed,
    )

    return amp, ph


def model_tide_prediction(
    lon: np.ndarray, lat: np.ndarray, date_list: np.ndarray, model_dir: Path
) -> np.ndarray:
    """
    Predict tide heights with the FES2014 model.

    Parameters
    ----------
    lon, lat : np.ndarray
        Point coordinates; this notebook repeats one point once per date.
    date_list : np.ndarray
        Prediction times (``ds.time.values`` in this notebook), converted with
        ``pyTMD.time.convert_datetime``.
    model_dir : Path
        Directory containing the FES2014 model files.

    Returns
    -------
    np.ndarray
        Tidal prediction from ``pyTMD.predict_tidal_ts`` with minor
        constituents added (a masked array, given the ``.data`` access below).
        NOTE(review): units presumably metres — confirm.
    """
    print("Tide prediction")

    # Convert the datetimes to pyTMD's time representation.
    tide_time = pyTMD.time.convert_datetime(date_list)

    # Initialise the FES2014 elevation model from the model directory.
    model = pyTMD.model(model_dir, format="FES", compressed=False).elevation("FES2014")

    # Tide constants (amplitude and phase); this step is slow.
    amp, ph = get_constants(lon, lat, model)

    # Constituent names of the model.
    c = model.constituents

    # Delta time correction for the prediction times.
    delta_file = pyTMD.utilities.get_data_path(["data", "merged_deltat.data"])
    DELTAT = pyTMD.calc_delta_time(delta_file, tide_time)

    # Complex phase in radians (degrees -> radians) for Euler's formula.
    cph = -1j * ph * np.pi / 180.0

    # Complex harmonic constants.
    hc = amp * np.exp(cph)

    # Predict the tidal time series from the major constituents.
    TIDE = pyTMD.predict_tidal_ts(
        tide_time, hc, c, DELTAT=DELTAT, CORRECTIONS=model.format
    )

    # Infer the minor constituents.
    MINOR = pyTMD.infer_minor_corrections(
        tide_time, hc, c, DELTAT=DELTAT, CORRECTIONS=model.format
    )

    # Add the minor-constituent correction in place.
    TIDE.data[:] += MINOR.data[:]

    print("Done")

    return TIDE

In [None]:
# Predict tides at the selected reference point, repeated once per
# acquisition time.
point_centroid = selected_point_gdf.unary_union.centroid
x, y = point_centroid.x, point_centroid.y

n_times = len(time_list)
lons = np.full(n_times, x)
lats = np.full(n_times, y)

In [None]:
# Directory with the FES2014 model files; predict the tide at the selected
# point for every acquisition time.
model_dir = Path("../datasets/tide/model")
tide_list = model_tide_prediction(lons, lats, time_list, model_dir)

In [None]:
# Lowest, highest and mean predicted tide over all acquisitions.
lt, ht, mean = np.min(tide_list), np.max(tide_list), np.mean(tide_list)

for label, value in (("Low", lt), ("High", ht), ("Mean", mean)):
    print(f"{label} tide: {value}")

In [None]:
# Wrap the predicted tides as a DataArray on the Sentinel-1 time axis.
tide_ds = xr.DataArray(tide_list, coords=[ds.time], dims=["time"])
tide_ds

In [None]:
# Attach tide heights to the cube (filter_tide reads `group.tide`).
ds["tide"] = tide_ds

In [None]:
def filter_tide(group_ds, tide):
    """
    From each group, pick the time step whose tide is closest to ``tide``.

    Parameters
    ----------
    group_ds
        Iterable of ``(label, group)`` pairs, e.g. ``ds.groupby("time.year")``;
        each group needs a ``tide`` variable along ``time``.
    tide : float
        Target tide level.

    Returns
    -------
    The selected time steps concatenated along ``time``.
    """
    closest = [
        group.isel(time=np.argsort(np.abs(group.tide.values - tide))[0])
        for _label, group in group_ds
    ]
    return xr.concat(closest, dim="time")

In [None]:
# For each year, keep the scene closest to the high, low and mean tide levels.
group_ds = ds.groupby("time.year")
ht_ds, lt_ds, mean_ds = (filter_tide(group_ds, level) for level in (ht, lt, mean))

In [None]:
# Per-year scenes closest to the mean tide.
mean_ds

In [None]:
def db_scale(img):
    """Convert linear values to decibels, element-wise: ``10 * log10(img)``."""
    return 10 * np.log10(img)

In [None]:
# Convert VH backscatter to decibels. db_scale is element-wise, so it is applied
# to the whole cube at once (no per-time groupby needed). Zero pixels are
# treated as no-data and masked first, so they become NaN rather than -inf;
# -inf would otherwise break the robust colour limits of the plots below.
ds["vh_db"] = db_scale(ds.vh.where(ds.vh != 0))
ds

In [None]:
# Compare linear and dB backscatter for the last time step.
latest = ds.isel(time=-1)
fig, (ax_linear, ax_db) = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(15, 10))

latest.vh.plot(cmap="gray", robust=True, ax=ax_linear)
latest.vh_db.plot(cmap="gray", robust=True, ax=ax_db)

In [None]:
def lee_filter(img, size):
    """
    Lee speckle filter.

    Each pixel is pulled towards its local mean (``size`` x ``size`` window).
    The weight is local variance / (local variance + variance of the whole
    image), so flat areas are smoothed and high-variance areas are preserved.
    """
    local_mean = uniform_filter(img, size)
    local_var = uniform_filter(img**2, size) - local_mean**2
    noise_var = variance(img)

    weight = local_var / (local_var + noise_var)
    return local_mean + weight * (img - local_mean)

In [None]:
# Apply a 5x5 Lee speckle filter to each dB scene independently.
# GroupBy.map replaces the deprecated GroupBy.apply alias.
ds["vh_db_filter"] = (
    ds.vh_db
    .groupby("time")
    .map(lambda img: xr.apply_ufunc(lee_filter, img, kwargs={"size": 5}))
)
ds

In [None]:
# Compare dB backscatter before and after speckle filtering (last time step).
latest = ds.isel(time=-1)
fig, (ax_raw, ax_lee) = plt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(15, 10))

latest.vh_db.plot(cmap="gray", robust=True, ax=ax_raw)
latest.vh_db_filter.plot(cmap="gray", robust=True, ax=ax_lee)

In [None]:



def local_binary(img: np.ndarray, *args, **kwargs) -> np.ndarray:
    """
    Binarise an image with a local (adaptive) threshold and tidy the mask.

    Float images outside [-1, 1] (e.g. backscatter in dB) are min-max scaled
    to [0, 1] first, because ``img_as_ubyte`` rejects such input.

    Parameters
    ----------
    img : np.ndarray
        Image to binarise (array-likes such as DataArrays are converted).
    *args, **kwargs
        Forwarded to ``skimage.filters.threshold_local`` (e.g. ``block_size``).

    Returns
    -------
    np.ndarray
        uint8 mask: 1 where the 8-bit image is >= its local threshold, after
        filling enclosed holes and removing small objects; 0 elsewhere.
    """
    img = np.asarray(img)
    if np.issubdtype(img.dtype, np.floating):
        lo, hi = np.nanmin(img), np.nanmax(img)
        if lo < -1 or hi > 1:
            img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
    img = img_as_ubyte(img)
    threshold = threshold_local(img, *args, **kwargs)
    binary = img >= threshold
    binary = binary_fill_holes(binary)
    binary = remove_small_objects(binary)
    return binary.astype(np.uint8)


def otsu_binary(img: np.ndarray) -> np.ndarray:
    """
    Binarise an image with a global Otsu threshold and tidy the mask.

    Float images outside [-1, 1] (e.g. backscatter in dB) are min-max scaled
    to [0, 1] first, because ``img_as_ubyte`` rejects such input. An
    increasing affine rescale does not change Otsu's split (up to 8-bit rounding).

    Parameters
    ----------
    img : np.ndarray
        Image to binarise (array-likes such as DataArrays are converted).

    Returns
    -------
    np.ndarray
        uint8 mask: 1 where the 8-bit image is >= the Otsu threshold, after
        filling enclosed holes and removing small objects; 0 elsewhere.
    """
    img = np.asarray(img)
    if np.issubdtype(img.dtype, np.floating):
        lo, hi = np.nanmin(img), np.nanmax(img)
        if lo < -1 or hi > 1:
            img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
    img = img_as_ubyte(img)
    threshold = threshold_otsu(img)
    binary = img >= threshold
    binary = binary_fill_holes(binary)
    binary = remove_small_objects(binary)
    return binary.astype(np.uint8)


def smooth_linestring(linestring, smooth_sigma):
    """
    Smooth a LineString by Gaussian-filtering its x and y coordinates.

    Parameters
    ----------
    linestring : shapely.geometry.LineString
        Line to smooth.
    smooth_sigma : float
        Standard deviation of the Gaussian kernel, measured in vertices (not
        map units).

    Returns
    -------
    shapely.geometry.LineString
        Line with the same number of vertices, each replaced by a
        Gaussian-weighted average of its neighbours. Endpoints are smoothed
        too (scipy's default "reflect" boundary), so a closed line is not
        guaranteed to stay closed.
    """
    xs, ys = linestring.xy
    smooth_x = gaussian_filter1d(np.asarray(xs), smooth_sigma)
    smooth_y = gaussian_filter1d(np.asarray(ys), smooth_sigma)
    return LineString(np.column_stack((smooth_x, smooth_y)))


def create_transects(line, space, length, crs):
    """
    Cast evenly spaced transects perpendicular to a line.

    The line is cut into consecutive segments of length ``space``. At each
    segment midpoint a transect is drawn perpendicular to the segment's
    start-to-end direction, extending ``length`` to either side (total length
    ``2 * length``).

    Parameters
    ----------
    line : shapely geometry
        Baseline; must support ``.length`` and ``.interpolate``.
    space : float
        Spacing between transects, in the units of ``line``'s coordinates.
    length : float
        Half-length of each transect, in the same units.
    crs
        CRS assigned to the output.

    Returns
    -------
    geopandas.GeoDataFrame
        ``distance`` (position of the transect midpoint along ``line``) and
        ``geometry`` (LineString from the counter-clockwise side to the
        clockwise side of the line). Empty when ``line`` is shorter than ``space``.

    Raises
    ------
    ValueError
        If ``space`` is not positive.
    """
    if space <= 0:
        raise ValueError("space must be positive")

    n_prof = int(line.length / space)
    distances = []
    transects = []

    for prof in range(1, n_prof + 1):
        seg_st = line.interpolate((prof - 1) * space)
        seg_mid = line.interpolate((prof - 0.5) * space)
        seg_end = line.interpolate(prof * space)

        # Direction of this segment.
        dx = seg_end.x - seg_st.x
        dy = seg_end.y - seg_st.y
        seg_len = (dx**2 + dy**2) ** 0.5
        if seg_len == 0:
            # No direction defined; a transect here would have NaN coordinates.
            continue

        # Unit normal: the segment direction rotated 90 degrees counter-clockwise.
        nx, ny = -dy / seg_len, dx / seg_len

        prof_end = (seg_mid.x + nx * length, seg_mid.y + ny * length)
        prof_st = (seg_mid.x - nx * length, seg_mid.y - ny * length)

        distances.append((prof - 0.5) * space)
        transects.append(LineString([prof_end, prof_st]))

    return gpd.GeoDataFrame(
        {"distance": distances}, geometry=gpd.GeoSeries(transects), crs=crs
    )


def transect_analysis(line_gdf, transect_gdf, time_column, reverse=False):
    """
    Measure coastline position, change and rate of change along transects.

    A transect is used only if it crosses *every* line in ``line_gdf`` at a
    single point. For those transects, the distance from the transect's start
    point to each crossing is recorded, together with the change relative to
    the oldest line and the annualised rate of that change. The inputs are
    not modified.

    Parameters
    ----------
    line_gdf : geopandas.GeoDataFrame
        One coastline geometry per row, dated by ``time_column``.
    transect_gdf : geopandas.GeoDataFrame
        Transect LineStrings (e.g. from ``create_transects``).
    time_column : str
        Column of ``line_gdf`` holding the dates; parsed with ``pd.to_datetime``.
    reverse : bool, default False
        Measure distances from each transect's end point instead of its start.

    Returns
    -------
    geopandas.GeoDataFrame
        One row per usable transect with these columns:

        - ``name``: positional transect index;
        - ``distance_YYYYMMDD``, ``change_YYYYMMDD`` and ``rate_YYYYMMDD`` for
          each line date;
        - a LineString through the crossings (a Point if there is only one line);
        - ``mean_distance``, ``mean_change`` and ``mean_rate``.

        Distances are in CRS units; rates are CRS units per year (365 days).
        The result has no rows when no transect qualifies.
    """
    # Work on copies so the caller's frames are not re-typed, re-sorted or
    # re-indexed as a side effect.
    lines = line_gdf.copy()
    lines[time_column] = pd.to_datetime(lines[time_column])
    lines = lines.sort_values(by=time_column, ignore_index=True)
    transects = transect_gdf.reset_index(drop=True)

    dates = list(lines[time_column])
    records = []

    for i, transect in transects.iterrows():
        start, end = transect.geometry.boundary.geoms
        if reverse:
            start = end

        crossings = lines.geometry.intersection(transect.geometry)
        # Only use transects that cross every line at exactly one point.
        is_point = (crossings.geom_type == "Point") & ~crossings.is_empty
        if len(crossings) == 0 or not is_point.all():
            continue

        points = list(crossings)
        oldest_date = dates[0]
        oldest_distance = points[0].distance(start)

        record = {"name": i}
        for j, (date, point) in enumerate(zip(dates, points)):
            if j == 0:
                distance, change, rate = oldest_distance, 0, 0
            else:
                distance = point.distance(start)
                change = distance - oldest_distance
                elapsed_years = (date - oldest_date).days / 365
                # Change per year; undefined when a line shares the oldest date.
                rate = change / elapsed_years if elapsed_years else np.nan

            time_str = date.strftime("%Y%m%d")
            record[f"distance_{time_str}"] = distance
            record[f"change_{time_str}"] = change
            record[f"rate_{time_str}"] = rate

        # A LineString needs at least two points; with a single line keep the point.
        record["geometry"] = LineString(points) if len(points) > 1 else points[0]

        # NOTE(review): the oldest line's zero change/rate is included in the
        # means — confirm this is intended.
        for prefix in ("distance", "change", "rate"):
            values = [v for k, v in record.items() if k.startswith(f"{prefix}_")]
            record[f"mean_{prefix}"] = pd.Series(values, dtype=float).mean()

        records.append(record)

    if not records:
        return gpd.GeoDataFrame({"name": []}, geometry=gpd.GeoSeries([]), crs=lines.crs)
    return gpd.GeoDataFrame(records, geometry="geometry", crs=lines.crs)


In [None]:
# Treat zero values as no-data: replace them with NaN in every variable.
# NOTE(review): this rebinds `ds`, so earlier cells re-run afterwards see the
# masked cube.
ds = ds.where(ds != 0)

In [None]:
# Load the last time step into memory for a quick look.
selected_ds = ds.isel(time=-1).load()
selected_ds

In [None]:
# Linear VH backscatter of that scene with robust colour limits.
selected_ds.vh.plot(cmap="gray", robust=True, size=10)

In [None]:
# Keep only time steps without missing values.
# NOTE(review): dropna's default how="any" checks every variable, so a scene
# is dropped if any variable has even one NaN pixel (e.g. a zero masked above).
# Confirm this is intended rather than how="all" or a `subset`.
filtered_ds = ds.dropna(dim="time")

In [None]:
# Scenes that survived the NaN filter.
filtered_ds

In [None]:
# Acquisition times that remain.
print(filtered_ds.time.values)

In [None]:
# Speckle-filter each linear VH scene (7x7 Lee filter).
filtered_ds["filtered_vh"] = filtered_ds["vh"].groupby("time").map(lee_filter, size=7)

# Convert to dB with db_scale (`convert_raster` was never defined in this
# notebook). NOTE(review): assumes the intended conversion is plain dB,
# matching the `_db` name — confirm.
filtered_ds["filtered_vh_db"] = db_scale(filtered_ds["filtered_vh"])

# Binarise each scene with Otsu's threshold. dB values are min-max scaled to
# [0, 1] per scene first, because img_as_ubyte only accepts floats in
# [-1, 1]. Otsu's split is unchanged by this increasing rescale.
filtered_ds["filtered_vh_db_binary"] = (
    filtered_ds["filtered_vh_db"]
    .groupby("time")
    .map(lambda img: otsu_binary((img - img.min()) / (img.max() - img.min())))
)

In [None]:
# Output folder for this region. `ID` (the zero-padded id used in every
# output file name below) was not defined in this notebook; derive it from the
# selected `region_id` so the notebook runs on a fresh kernel.
ID = region_id
output_dir = Path(f"./output/bali/{ID:04d}")
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Save the filtered dB stack as an LZW-compressed GeoTIFF.
raster_path = output_dir / f"{ID:04d}_s1_filtered_vh_db.tif"
filtered_ds.filtered_vh_db.rio.to_raster(raster_path, compress="lzw")

In [None]:
# Preview the land/water masks of the first four scenes.
first_scenes = filtered_ds.filtered_vh_db_binary.isel(time=slice(0, 4)).load()
first_scenes.plot(cmap="Greys_r", robust=True, size=5, col="time", col_wrap=2)

In [None]:
# Trace the mask boundaries as sub-pixel contours (one per scene) and smooth
# each line with a Gaussian of sigma = 5 vertices.
coastline_da = filtered_ds.filtered_vh_db_binary
coastline_gdf = subpixel_contours(
    da=coastline_da,
    affine=coastline_da.rio.transform(),
    crs=coastline_da.rio.crs,
    min_vertices=100,
)

coastline_gdf["geometry"] = [
    smooth_linestring(geom, 5)
    if geom.geom_type == "LineString"
    else MultiLineString([smooth_linestring(part, 5) for part in geom.geoms])
    for geom in coastline_gdf.geometry
]
coastline_gdf

In [None]:
# Save the smoothed coastlines.
coastline_path = output_dir / f"{ID:04d}_s1_coastlines.geojson"
coastline_gdf.to_file(coastline_path, driver="GeoJSON")

In [None]:
# Cast transects from the first coastline row (presumably the earliest scene):
# 500 apart, 100 to either side, in the coastline CRS units.
baseline = coastline_gdf.geometry.iloc[0]
transect_gdf = create_transects(baseline, space=500, length=100, crs=coastline_gdf.crs)
transect_gdf

In [None]:
# Save the transects.
transect_path = output_dir / f"{ID:04d}_s1_transects.geojson"
transect_gdf.to_file(transect_path, driver="GeoJSON")

In [None]:
# Measure each coastline's position along every transect. reverse=True
# measures from each transect's end point instead of its start.
# NOTE(review): assumes subpixel_contours produced a "time" column — confirm.
transect_analysis_gdf = transect_analysis(coastline_gdf, transect_gdf, "time", reverse=True)
transect_analysis_gdf

In [None]:
# Save the per-transect distances, changes and rates.
transect_analysis_path = output_dir / f"{ID:04d}_s1_transect_analysis.geojson"
transect_analysis_gdf.to_file(transect_analysis_path, driver="GeoJSON")