In [None]:
%reload_ext autoreload
%autoreload 2

import numpy as np
from scipy.interpolate import griddata
import geopandas as gpd
import rasterio
from rasterio.transform import from_origin
from rasterio.plot import show
from rasterio.plot import show_hist
from rasterio.mask import mask
import json
import pandas as pd
import hvplot.pandas  # noqa
import holoviews as hv
hv.extension('bokeh')
import altair as alt
alt.data_transformers.disable_max_rows()
from matplotlib import pyplot as plt
from pathlib import Path

# On the phy-server the project's local modules (e.g. `settings`) can only be imported
# when the analysis directory is on sys.path, and the relative data paths used below are
# resolved from that directory. On any other machine the directory does not exist and
# the notebook simply keeps running from its current location.
import os
import sys

analysis_dir = "/silod7/lenz/MPSchleiSediments/analysis/"
if os.path.isdir(analysis_dir):
    if analysis_dir not in sys.path:  # avoid duplicate entries when the cell is re-run
        sys.path.append(analysis_dir)
    os.chdir(analysis_dir)

from settings import Config

In [None]:
# # What happened so far: DB extract and blank procedure. Now import resulting MP data from csv
# mp_pdd = prepare_data.get_pdd()

# # Also import sediment data (sediment frequencies per size bin from master sizer export)
# grainsize_iow, grainsize_cau, sed_lower_boundaries = prepare_data.get_grainsizes()

# # ...some data wrangling to prepare particle domain data and sample domain data for MP and combine with certain sediment aggregates.
# mp_sdd = prepare_data.aggregate_SDD(mp_pdd)
# sdd_iow = prepare_data.additional_sdd_merging(mp_sdd)
# sdd_cau = pd.read_csv('../data/Metadata_CAU_sampling_log.csv', index_col=0)

In [None]:
# Create a GeoDataFrame from the coastline GeoJSON file.
poly = gpd.read_file('../data/SchleiCoastline_from_OSM.geojson')

# GeoJSON-like geometry of the first feature, used further down to mask the exported raster.
# rasterio's mask() expects shapes in the raster's CRS, and the raster is written in
# Config.baw_epsg, so the geometry is taken from a reprojected copy (`poly` itself is
# left in the CRS it was read in).
poly_as_str = [json.loads(poly.to_crs(Config.baw_epsg).to_json())['features'][0]['geometry']]
# poly.plot()

In [None]:
# Load the model predictions exported under the given savestamp and turn them into
# a point GeoDataFrame (WGS84 lon/lat).
savestamp = '20230403_233901'
predictions_dir = Path('../data/exports/models/predictions')

# sorted() makes the pick deterministic if several CSVs share the savestamp prefix;
# an explicit error is clearer than the bare IndexError of `[...][0]` on an empty list.
matches = sorted(predictions_dir.glob(f'{savestamp}*.csv'))
if not matches:
    raise FileNotFoundError(
        f'No prediction CSV starting with {savestamp!r} found in {predictions_dir.resolve()}'
    )
f = matches[0]

# The second-to-last '_'-separated token of the file name is used as the name of the
# column that gets interpolated.
target = f.name.split('_')[-2]
station_data = pd.read_csv(f)
station_data = gpd.GeoDataFrame(station_data, geometry=gpd.points_from_xy(station_data.LON, station_data.LAT), crs='EPSG:4326')
station_data
## old method
# station_data = gpd.GeoDataFrame(sdd_iow, geometry=gpd.points_from_xy(sdd_iow['LON'], sdd_iow['LAT'], crs='EPSG:4326')).to_crs("EPSG:3857")

In [None]:
# Reproject stations and coastline into the CRS used for the interpolation grid.
station_data = station_data.to_crs(Config.baw_epsg)
poly = poly.to_crs(Config.baw_epsg)

# Regular grid covering the coastline's bounding box; rows run from ymin upwards,
# columns from xmin to the right.
xres = yres = Config.interpolation_resolution
xmin, ymin, xmax, ymax = poly.total_bounds
x_coords = np.arange(xmin, xmax + xres, xres)
y_coords = np.arange(ymin, ymax + yres, yres)
xgrid, ygrid = np.meshgrid(x_coords, y_coords)
grid_nodes = (xgrid, ygrid)

# Station coordinates as an (n_stations, 2) array of x/y pairs.
points = np.column_stack((station_data.geometry.x, station_data.geometry.y))

# 'linear' and 'cubic' will result in nan outside of the convex hull of data points
values = griddata(
    points,
    station_data[target],
    grid_nodes,
    method=Config.interpolation_method,
)

# A nearest-neighbour interpolation is computed whenever the first pass left NaNs.
# NOTE(review): the line that would write it back into `values` is commented out, so
# `values2` is currently unused and `values` keeps its NaNs — confirm this is intended.
nan_mask = np.isnan(values)
if nan_mask.any():
    values2 = griddata(
        points,
        station_data[target],
        grid_nodes,
        method='nearest',
    )
    # values[nan_mask] = values2[nan_mask]

# One point per grid node carrying the interpolated value, then keep only the nodes
# that fall inside the coastline polygon.
grid_gdf = gpd.GeoDataFrame(
    {target: values.ravel()},
    geometry=gpd.points_from_xy(xgrid.ravel(), ygrid.ravel()),
    crs=Config.baw_epsg,
)
clipped = grid_gdf.clip(poly)
## old method:
# clipped = gpd.overlay(grid_gdf, poly, how='intersection')  # takes about 15 min
# clipped = clipped.loc[grid_gdf.intersects(poly.geometry[0])]  # takes about 11 min

In [None]:
# clipped.plot(column=target, cmap='OrRd', edgecolor="none", antialiased=False)
# alt.Chart(clipped.assign(X = clipped.geometry.x, Y = clipped.geometry.y)).mark_square(size=100).encode(
#     x='X',
#     y='Y',
#     color=target
# ).interactive()

In [None]:
# Every grid node stands for one xres-by-yres cell; weight each clipped value by that
# area and add everything up.
cell_areas = xres * yres
area_weighted = clipped[target].mul(cell_areas)
total = area_weighted.sum()

In [None]:
# Display the area-weighted sum of the clipped grid values.
total

In [None]:
# Total divided by the coastline polygon area and then by 5.
# `poly.area` is evaluated per row of `poly`, so the result has one value per feature.
# NOTE(review): the meaning of the constant 5 is not visible in this notebook —
# TODO confirm what it represents and give it a named constant.
total / poly.area / 5

In [None]:
# `values` comes from a meshgrid whose rows run south -> north (row 0 is at ymin),
# whereas hv.Image draws the first array row at the top, so flip it vertically.
# The bounds are the outer cell edges of the sampled grid (grid nodes are cell centres);
# the grid extends up to one cell beyond xmax/ymax, so those are not used directly.
hv.Image(
    np.flipud(values),
    bounds=(
        xgrid.min() - xres / 2, ygrid.min() - yres / 2,
        xgrid.max() + xres / 2, ygrid.max() + yres / 2,
    ),
).opts(width=800)

In [None]:
# Quick static preview of the interpolated grid.
# Row 0 of `values` is the southernmost grid row (meshgrid ordering), so draw it at the
# bottom; the default origin='upper' would show the map upside down.
plt.imshow(values, interpolation='nearest', origin='lower')
plt.show()

In [None]:
# Export the interpolated grid as a single-band GeoTIFF and reopen it for reading.
f = f'../data/exports/models/predictions/{savestamp}_raster.tif'

res = Config.interpolation_resolution
n_rows, n_cols = values.shape

# `values` is ordered south -> north (row 0 sits at ymin), but the transform below
# describes a north-up raster whose first row is the northernmost one, so the rows are
# reversed before writing. ascontiguousarray turns the reversed view into a plain array.
north_up = np.ascontiguousarray(values[::-1])

# Grid nodes are treated as pixel centres: the upper-left pixel corner therefore lies
# half a cell west of xmin and half a cell north of the topmost grid row, which is at
# ymin + (n_rows - 1) * res.
transform = from_origin(xmin - res / 2, ymin + (n_rows - 0.5) * res, res, res)

# The context manager closes (and flushes) the file even if writing fails.
with rasterio.open(f, 'w', driver='GTiff',
                   height=n_rows, width=n_cols,
                   count=1, dtype=str(values.dtype),
                   crs=Config.baw_epsg,
                   transform=transform) as new_dataset:
    new_dataset.write(north_up, 1)

# Kept open on purpose: the following cells read from this handle.
rasta = rasterio.open(f)

In [None]:
# Cut the raster down to the coastline geometry; crop=True also shrinks the output to
# the geometry's extent. Returns the masked array and its affine transform.
# mask() expects the shapes to be in the same CRS as the raster.
out_img, out_transform = mask(dataset=rasta, shapes=poly_as_str, crop=True)

In [None]:
# Plot the masked raster. The tuple form of show() is reserved for (dataset, band index)
# and would call `.read()` on the numpy array; pass the array itself and supply its
# transform so the axes are drawn in map coordinates.
show(out_img, transform=out_transform, cmap='terrain')

In [None]:
# Display the bounding box of the coastline polygon that was used to build the grid.
xmin, ymin, xmax, ymax