In [None]:
import copy
from datetime import date, datetime
import os
from pathlib import Path
from typing import Dict, Any, List

import cv2
import registration
from dask.distributed import Client
from dateutil.parser import parse
from dateutil.relativedelta import relativedelta
import geopandas as gpd
import httpx
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import planetary_computer as pc
from pystac import Item, ItemCollection
import pystac_client
from rioxarray.merge import merge_arrays
from scipy import ndimage
from shapely.geometry import mapping, shape, LineString, MultiLineString
from skimage import filters, measure, morphology
import stackstac
import xarray as xr

from coastline_change_functions import (
    coregistration,
    create_transects,
    db_scale,
    filter_tide,
    intersection_percent,
    lee_filter,
    rescale,
    segmentation,
    smooth_linestring,
    subpixel_contours,
    tide_interpolation,
    tide_prediction,
    transect_analysis,
)

In [None]:
# Switch to the project directory so the relative paths used below resolve.
project_dir = Path.home() / "Development" / "coastline"
os.chdir(project_dir)
Path.cwd()

In [None]:
# STAC API endpoint opened with pystac_client further down.
CLIENT_URL = "https://planetarycomputer.microsoft.com/api/stac/v1"
# Collection searched for SAR scenes.
COLLECTION = "sentinel-1-rtc"
# Search period; it is walked in one-year windows.
START_DATE, STOP_DATE = "2015-01-01", "2022-12-31"
# 1-based grid id: the row labelled REGION_ID - 1 is selected from the region/point files.
REGION_ID = 715
# Key of the tide-filtered stack to process: "ht", "lt" or "mean".
TIDE_TYPE = "mean"
# Minimum scene/region intersection, in percent, for a scene to be kept.
MIN_AREA_PERCENT = 98
# When True the scenes are passed through coregistration() first.
COREGISTRATION = False
# Forwarded to segmentation().
# NOTE(review): None presumably selects an automatic threshold - confirm.
THRESHOLD = None

In [None]:
# Start a Dask client with 4 workers, 2 threads each, 4GB memory limit per worker.
dask_client = Client(
    n_workers=4,
    threads_per_worker=2,
    memory_limit="4GB",
)
dask_client

In [None]:
# Both input layers live in the ./region folder.
region_dir = Path("./region")
region_path = region_dir / "coastal_grids.geojson"
point_path = region_dir / "coastal_points.geojson"

# Read the grid layer first, then the point layer.
region_gdf, point_gdf = (gpd.read_file(path) for path in (region_path, point_path))

In [None]:
# filter_region = region_gdf.query("province == 'BALI'")
# centroid = filter_region.unary_union.centroid

# m = region_gdf.explore(
#     location=[centroid.y, centroid.x],
#     zoom_start=9,
#     style_kwds={"fillOpacity": 0, "color": "red", "linewidth": 1}
# )

# for i, row in region_gdf.iterrows():
#     centroid = row.geometry.centroid
#     folium.Marker(
#         location=[centroid.y, centroid.x],
#         icon=folium.DivIcon(
#             html=f"<div style='font-size: 12px'>{i+1}</div>"
#         )
#     ).add_to(m)

# m

In [None]:
# REGION_ID is 1-based while the row labels start at 0.
row_label = REGION_ID - 1
selected_region_gdf = region_gdf.loc[[row_label]]
selected_point_gdf = point_gdf.loc[[row_label]]

# Draw the region outline, then add the point marker to the same map.
region_style = {"fillOpacity": 0, "color": "red"}
m = selected_region_gdf.explore(style_kwds=region_style)
m = selected_point_gdf.explore(m=m, marker_type="marker")
m

In [None]:
# One output folder per region, named with the zero-padded region id.
output_dir = Path("./output", f"{REGION_ID:04d}")
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Bounding box of the selected region as (xmin, ymin, xmax, ymax).
bbox = tuple(selected_region_gdf.total_bounds.tolist())
xmin, ymin, xmax, ymax = bbox
print(bbox)

In [None]:
# Open the STAC catalog used for both the Sentinel-1 and the DEM searches.
catalog = pystac_client.Client.open(url=CLIENT_URL)
catalog

In [None]:
# Search the catalog one year at a time and keep the scenes that cover at
# least MIN_AREA_PERCENT of the selected region. area_list[i] is the coverage
# of item_list[i].
area_list = []
item_list = []

# Loop invariants: the region geometry and the end of the search period.
region_geometry = mapping(selected_region_gdf.unary_union)
search_end = parse(STOP_DATE)

# STAC datetime intervals are closed, so consecutive yearly windows share
# their boundary instant; remember the ids already kept so that a scene
# falling exactly on a boundary is not stored twice.
seen_item_ids = set()

start_date = parse(START_DATE)
while start_date <= search_end:
    stop_date = start_date + relativedelta(years=1)
    print(f"Datetime search: {start_date} - {stop_date}")

    query = catalog.search(
        collections=[COLLECTION],
        datetime=[start_date, stop_date],
        bbox=bbox,
        query={
            "sar:polarizations": {"eq": ['VV', 'VH']}
        }
    )
    items = query.get_items()
    for item in items:
        if item.id in seen_item_ids:
            continue
        area = intersection_percent(item, region_geometry)
        if area >= MIN_AREA_PERCENT:
            seen_item_ids.add(item.id)
            area_list.append(area)
            item_list.append(item)

    start_date = stop_date

In [None]:
# Sort the scenes chronologically and apply the same permutation to
# area_list, so that area_list[i] still describes item_list[i]. Sorting only
# item_list would attach the coverage percentages to the wrong scenes.
# sorted() is stable, so the item order matches a plain sort by datetime.
chronological_order = sorted(
    range(len(item_list)), key=lambda idx: item_list[idx].datetime
)
item_list = [item_list[idx] for idx in chronological_order]
area_list = [area_list[idx] for idx in chronological_order]

s1_items = ItemCollection(item_list)
print(f"Found: {len(s1_items)} datasets")

In [None]:
# Turn the STAC items into a GeoDataFrame and attach each scene's coverage.
s1_feature_collection = s1_items.to_dict()
s1_item_gdf = gpd.GeoDataFrame.from_features(
    s1_feature_collection, crs="epsg:4326"
).assign(area_percent=area_list)

In [None]:
# Acquisition timestamps in ascending order, and the distinct years they span.
acquisition_times = pd.to_datetime(s1_item_gdf["datetime"])
time_list = sorted(acquisition_times)
year_list = sorted({stamp.year for stamp in time_list})
print(time_list[0], time_list[-1])
print(year_list)

In [None]:
# Scene footprints coloured by platform, with the region and point on top.
item_columns = [
    "platform",
    "geometry",
    "datetime",
    "sat:absolute_orbit",
    "sat:orbit_state",
    "area_percent",
]
m = s1_item_gdf[item_columns].explore(
    column="platform",
    cmap="viridis",
    style_kwds={"fillOpacity": 0},
)
m = selected_region_gdf.explore(m=m, style_kwds={"fillOpacity": 0.5, "color": "red"})
m = selected_point_gdf.explore(m=m, marker_type="marker")
m

In [None]:
# Sign every item before handing its dict form to stackstac.
signed_s1_items = [pc.sign(item).to_dict() for item in s1_items]

# Stack over the region bbox in EPSG:3857 at a resolution of 10.
s1_stack = stackstac.stack(
    signed_s1_items,
    bounds_latlon=bbox,
    epsg=3857,
    resolution=10,
)

# Replace non-positive values with NaN, then keep only the "vh" band.
s1_data = s1_stack.where(lambda x: x > 0, other=np.nan).sel(band="vh")
s1_data

In [None]:
# Look up the "cop-dem-glo-30" items intersecting the region bbox.
dem_query = catalog.search(collections=["cop-dem-glo-30"], bbox=bbox)
dem_items = dem_query.get_all_items()

print(f"Found: {len(dem_items):d} datasets")

In [None]:
# Footprints of the DEM items as a GeoDataFrame.
dem_feature_collection = dem_items.to_dict()
dem_item_gdf = gpd.GeoDataFrame.from_features(dem_feature_collection, crs="epsg:4326")

In [None]:
# DEM footprints (outline only) with the region and point drawn on top.
m = dem_item_gdf.explore(style_kwds={"fillOpacity": 0})
m = selected_region_gdf.explore(
    m=m,
    style_kwds={"fillOpacity": 0.5, "color": "red"},
)
m = selected_point_gdf.explore(m=m, marker_type="marker")
m

In [None]:
# Sign the DEM items and stack them over the same bbox in EPSG:3857.
signed_dem_items = [pc.sign(item).to_dict() for item in dem_items]
dem_stack = stackstac.stack(signed_dem_items, bounds_latlon=bbox, epsg=3857)

# Keep the "data" band and record 0 as its nodata value.
# (A positive-value mask like the one used for the SAR stack is not applied.)
dem_data = dem_stack.sel(band="data").rio.write_nodata(0)
dem_data

In [None]:
# Load the DEM stack and merge its slices along the first dimension into one array.
dem_tiles = list(dem_data.load())
merged_dem_data = merge_arrays(dem_tiles)
merged_dem_data

In [None]:
# Acquisition times of the stacked scenes.
times = s1_data.time.values
print(times[0], times[-1])

# Coordinates of the selected point's centroid, used for the tide prediction.
point_centroid = selected_point_gdf.unary_union.centroid
x, y = point_centroid.x, point_centroid.y
print(x, y)

In [None]:
# Tide levels interpolated to the scene acquisition times, cached as CSV so
# the prediction is not repeated on every run.
tide_path = output_dir / f"{REGION_ID:04d}_tide.csv"
acquisition_times = pd.to_datetime(times)

interp_tide_df = None
if tide_path.exists():
    # Parse "datetime" back into timestamps so the cached frame has the same
    # column types as a freshly computed one.
    interp_tide_df = pd.read_csv(tide_path, parse_dates=["datetime"])
    # One tide level per scene is required further down; a cache written for
    # a different set of scenes cannot be reused.
    if len(interp_tide_df) != len(acquisition_times):
        print(f"Ignoring stale tide cache: {tide_path}")
        interp_tide_df = None

if interp_tide_df is None:
    start_date = acquisition_times[0].date()
    stop_date = acquisition_times[-1].date()
    tide_df = tide_prediction(x, y, start_date, stop_date)
    interp_tide_df = tide_interpolation(tide_df, acquisition_times.tolist())
    interp_tide_df.to_csv(tide_path, index=False)

In [None]:
# Plain Python list of the interpolated tide levels.
tide_levels = interp_tide_df["level"]
tide_list = tide_levels.to_list()

In [None]:
# Reference levels: lowest, highest and average interpolated tide.
lt, ht, mean = np.min(tide_list), np.max(tide_list), np.mean(tide_list)

print(f"Low tide: {lt}")
print(f"High tide: {ht}")
print(f"Mean tide: {mean}")

In [None]:
# Tide curve with the three reference levels as dashed horizontal lines.
ax = interp_tide_df.plot(x="datetime", y="level", figsize=(10, 5))
reference_lines = (
    (lt, "blue", "lt"),
    (mean, "green", "mean"),
    (ht, "red", "ht"),
)
for level, color, label in reference_lines:
    ax.axhline(y=level, color=color, linestyle="dashed", label=label)
ax.legend()

In [None]:
# Wrap the tide levels in a DataArray indexed by the scene times.
time_coord = s1_data.time
tide_data = xr.DataArray(
    tide_list,
    coords=[time_coord],
    dims=["time"],
)
tide_data

In [None]:
# Attach the tide level of every scene as a "tide" coordinate.
s1_data.coords["tide"] = tide_data

In [None]:
# Group the scenes by year and run filter_tide once per reference level
# (high, low, mean - in that order).
group_s1_data = s1_data.groupby("time.year")
ht_s1_data, lt_s1_data, mean_s1_data = (
    filter_tide(group_s1_data, level) for level in (ht, lt, mean)
)

In [None]:
# Lookup from tide type to the corresponding filtered stack.
tide_s1_data_dict = dict(
    ht=ht_s1_data,
    lt=lt_s1_data,
    mean=mean_s1_data,
)

In [None]:
# Pick the stack for the configured tide type and load it.
selected_tide_stack = tide_s1_data_dict[TIDE_TYPE]
vh_data = selected_tide_stack.load()
vh_data

In [None]:
# Datatake ids of the scenes that survived the tide filter.
datatake_ids = vh_data.coords["s1:datatake_id"].values
datatake_ids

In [None]:
# Keep only the footprints whose datatake id is among the selected scenes.
is_selected = s1_item_gdf["s1:datatake_id"].isin(datatake_ids)
s1_filter_gdf = s1_item_gdf[is_selected]
s1_filter_gdf

In [None]:
# Selected scene footprints coloured by datetime, with the region and point on top.
filter_columns = [
    "platform",
    "geometry",
    "datetime",
    "sat:absolute_orbit",
    "sat:orbit_state",
    "area_percent",
]
m = s1_filter_gdf[filter_columns].explore(
    column="datetime",
    cmap="rainbow",
    style_kwds={"fillOpacity": 0},
)
m = selected_region_gdf.explore(m=m, style_kwds={"fillOpacity": 0.5, "color": "red"})
m = selected_point_gdf.explore(m=m, marker_type="marker")
m

In [None]:
def _fill_nodata(scene):
    """Mark 0 as nodata on one time step and interpolate the missing values."""
    return scene.rio.write_nodata(0).rio.interpolate_na()


# Coregister and gap-fill the scenes when requested, otherwise work on a copy.
new_vh_data = (
    coregistration(vh_data).groupby("time").apply(_fill_nodata)
    if COREGISTRATION
    else vh_data.copy()
)
new_vh_data

In [None]:
# One grayscale panel per time step, four per row.
new_vh_data.plot(
    col="time",
    col_wrap=4,
    size=5,
    cmap="gray",
    robust=True,
)

In [None]:
# Apply db_scale to each time step separately.
vh_by_time = new_vh_data.groupby("time")
vh_db_data = vh_by_time.apply(db_scale)
vh_db_data

In [None]:
# One grayscale panel per db-scaled time step, four per row.
vh_db_data.plot(
    col="time",
    col_wrap=4,
    size=5,
    cmap="gray",
    robust=True,
)

In [None]:
def _lee_filter_step(img):
    """Run lee_filter with size=5 on one time step through xr.apply_ufunc."""
    # The dask options (dask="parallelized", allow_rechunk) are left disabled.
    return xr.apply_ufunc(lee_filter, img, kwargs={"size": 5})


# Filter every time step independently.
vh_filter = vh_db_data.groupby("time").apply(_lee_filter_step)
vh_filter

In [None]:
# Run segmentation() on every filtered time step, forwarding THRESHOLD.
# The image is handed over through `kwargs` rather than as a positional
# apply_ufunc input, so apply_ufunc itself receives no array arguments.
# NOTE(review): presumably segmentation() therefore gets the DataArray itself
# and must return something groupby can recombine - confirm before changing
# this call (e.g. before enabling the commented-out dask options).
vh_binary = (
    vh_filter
    .groupby("time")
    .apply(
        lambda img: xr.apply_ufunc(
            segmentation,
            kwargs={
                "img": img,
                "threshold": THRESHOLD
            }
            # img.chunk({"x": -1, "y": -1}),
            # dask="parallelized",
        )
    )
)
vh_binary

In [None]:
# Interpolate the DEM onto the grid of the last binary scene, then flag the
# pixels whose DEM value exceeds 30.
reference_scene = vh_binary.isel(time=-1)
dem_on_scene_grid = merged_dem_data.interp_like(reference_scene)
dem_regrid = dem_on_scene_grid > 30

In [None]:
def _apply_dem_mask(scene):
    """Set the pixels flagged in dem_regrid to 1 and keep all other values."""
    return scene.where(~dem_regrid, other=1)


# Apply the DEM mask to every time step.
vh_binary_filtered = vh_binary.groupby("time").apply(_apply_dem_mask)
vh_binary_filtered

In [None]:
# One grayscale panel per masked binary time step, four per row.
vh_binary_filtered.plot(
    col="time",
    col_wrap=4,
    size=5,
    cmap="gray",
)

In [None]:
# Extract contours from the masked binary stack, passing the CRS and affine
# transform recorded on the Sentinel-1 stack.
contour_options = {
    "min_vertices": 100,
    "crs": s1_data.crs,
    "affine": s1_data.transform,
}
coastline_gdf = subpixel_contours(vh_binary_filtered, **contour_options)
coastline_gdf

In [None]:
def _smooth_geometry(geometry):
    """Smooth a line, or each part of a multi-line, with smooth_linestring(..., 5)."""
    if geometry.geom_type != "MultiLineString":
        return smooth_linestring(geometry, 5)
    return MultiLineString([smooth_linestring(part, 5) for part in geometry.geoms])


# Replace every coastline geometry with its smoothed version, in row order.
new_lines = [_smooth_geometry(geometry) for geometry in coastline_gdf.geometry]
coastline_gdf.geometry = new_lines

In [None]:
# The first coastline row serves as the baseline for the transects.
# NOTE(review): 500 and 100 are passed positionally to create_transects;
# their meaning is defined by that helper - confirm there.
baseline = coastline_gdf.geometry.iloc[0]
coastline_crs = coastline_gdf.crs
transect_gdf = create_transects(baseline, 500, 100, crs=coastline_crs)
transect_gdf.head()

In [None]:
# Analyse the coastlines along the transects, using the "time" column.
transect_analysis_gdf = transect_analysis(
    coastline_gdf,
    transect_gdf,
    "time",
    reverse=True,
)
transect_analysis_gdf.head()

In [None]:
# Store the time column as text from here on (used for mapping and export).
time_as_text = coastline_gdf["time"].astype(str)
coastline_gdf["time"] = time_as_text

In [None]:
# Coastlines coloured by time, with the transects coloured by mean rate on top.
m = coastline_gdf.explore(tiles="CartoDB dark_matter", column="time", cmap="Reds")
transect_columns = ["name", "mean_distance", "mean_change", "mean_rate", "geometry"]
transect_analysis_gdf[transect_columns].explore(
    m=m,
    column="mean_rate",
    cmap="rainbow",
    tiles="CartoDB dark_matter",
)

In [None]:
# Results are written to a sub-folder named after the tide type.
suboutput_dir = output_dir.joinpath(TIDE_TYPE)
suboutput_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Output file names share the zero-padded region id and the "_s1" tag.
file_prefix = f"{REGION_ID:04d}_s1"
coastline_path = suboutput_dir / f"{file_prefix}_coastlines.geojson"
transect_path = suboutput_dir / f"{file_prefix}_transects.geojson"
transect_analysis_path = suboutput_dir / f"{file_prefix}_transect_analysis.geojson"

In [None]:
# Write the three vector layers as GeoJSON.
geojson_outputs = (
    (coastline_gdf, coastline_path),
    (transect_gdf, transect_path),
    (transect_analysis_gdf, transect_analysis_path),
)
for layer, layer_path in geojson_outputs:
    layer.to_file(layer_path, driver="GeoJSON")

In [None]:
# Rescale each db-scaled time step to uint8 with a target range of 1..255.
rescale_options = {
    "target_type_min": 1,
    "target_type_max": 255,
    "target_type": np.uint8,
}
vh_db_coreg_rescale = vh_db_data.groupby("time").apply(rescale, **rescale_options)

In [None]:
# Write one raster per time step of the rescaled stack, named by region id
# and acquisition year. The original referred to `vh_db_rescale` and
# `region_id`, neither of which exists in the notebook; the rescaled stack is
# `vh_db_coreg_rescale` and the region id constant is `REGION_ID`.
suffix = "_coreg.tif" if COREGISTRATION else ".tif"

for scene_time, scene in vh_db_coreg_rescale.rename("vh_db").groupby("time"):
    year = pd.to_datetime(scene_time).year
    raster_path = suboutput_dir / f"{REGION_ID:04d}_{year}_s1_vh_db{suffix}"
    scene.rio.to_raster(raster_path, crs=s1_data.crs, compress="lzw")
    print(f"Saved to {raster_path}")