# Personal Notebook for exploring nc files

## Note on aggregation method

* `mean`
    * with `geopandas`: the plain mean over all grid points that fall inside a NUTS region. No area weighting is applied.
    * with `exactextract`: [reference](https://github.com/isciences/exactextract/blob/master/python/doc/operations.rst)
        * `mean`: Mean value of cells that intersect the polygon, weighted by the percent of each cell that is covered. Usually used for average temperature.
        * `weighted_mean`: Mean value of cells that intersect the polygon, weighted by the product over the coverage fraction and the weighting raster. Usually used for population-weighted average temperature
    * only calculate mean on non-nan values. **If all values are `nan`, then return `nan`**
* `sum`
    * with `geopandas`: the plain sum over all grid points that fall inside a NUTS region. No area weighting is applied.
    * with `exactextract`: Sum of values of raster cells that intersect the polygon, with each raster value weighted by its coverage fraction. Usually used for total population.
    * only non-nan values are summed. **If all values are `nan`, then return `0.0`**. This makes sense mathematically for population values. But what about other accumulated variables, such as total precipitation, radiation, soil moisture, etc.?
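
The NaN handling described in the list can be sketched in plain Python. This is a minimal illustration of the behaviour noted here, not the implementation used by `geopandas` or `exactextract`; the helper names (`coverage_mean`, `coverage_sum`, `unweighted_mean`) and the sample values are made up for the example.

```python
import math


def _valid(values, coverage):
    """Keep (value, coverage) pairs where the value is not NaN and the cell is touched."""
    return [(v, c) for v, c in zip(values, coverage) if c > 0 and not math.isnan(v)]


def coverage_mean(values, coverage):
    """Coverage-weighted mean; NaN when no valid cell contributes."""
    pairs = _valid(values, coverage)
    total_weight = sum(c for _, c in pairs)
    if total_weight == 0:
        return math.nan
    return sum(v * c for v, c in pairs) / total_weight


def coverage_sum(values, coverage):
    """Coverage-weighted sum; an empty sum is 0.0, so all-NaN input gives 0.0."""
    return float(sum(v * c for v, c in _valid(values, coverage)))


def unweighted_mean(values):
    """Plain mean over non-NaN values; NaN when every value is NaN."""
    valid = [v for v in values if not math.isnan(v)]
    return sum(valid) / len(valid) if valid else math.nan


# Hypothetical cell values and coverage fractions for one region.
values = [10.0, 20.0, math.nan]
coverage = [1.0, 0.25, 0.5]

mean_weighted = coverage_mean(values, coverage)  # (10*1 + 20*0.25) / 1.25
mean_plain = unweighted_mean(values)             # (10 + 20) / 2
total = coverage_sum(values, coverage)           # 10*1 + 20*0.25

# A region whose cells are all NaN: the mean is NaN, the sum is 0.0.
all_nan_mean = coverage_mean([math.nan, math.nan], [1.0, 1.0])
all_nan_sum = coverage_sum([math.nan, math.nan], [1.0, 1.0])
```

The asymmetry comes from the empty case: a mean over zero valid cells has no defined value, while a sum over zero valid cells is naturally `0.0`.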

## Preparing data

In [None]:
import xarray as xr
import geopandas as gpd
import os
import xagg as xa
import exactextract as ee
import numpy as np

In [None]:
data_file = "../docs/source/notebooks/processed/era5_data_2016-2017_allm_2t_tp_monthly_unicoords_adjlon_celsius_mm_tutorial_B.nc"  # 0.1 deg
nuts_file = "../data/in/NUTS_RG_20M_2024_4326.shp.zip"

In [None]:
CRS = "EPSG:4326"

In [None]:
with xr.open_dataset(data_file, chunks={"time": "auto"}) as ds:
    df = ds.to_dataframe().reset_index()
df.head()

In [None]:
len(df[["latitude", "longitude"]].drop_duplicates())

In [None]:
# check whether any row has exactly one of t2m / tp as NaN (but not both)
df_nan = df[
    (df["t2m"].isna() & ~df["tp"].isna()) | (~df["t2m"].isna() & df["tp"].isna())
]
len(df_nan)

In [None]:
nuts = gpd.read_file(nuts_file)
len(nuts.NUTS_ID.unique())

In [None]:
nuts.head()

## Using purely geopandas

**Note: This section cannot currently be run on the 0.1 deg data. The dataset is too large and crashes the kernel or VS Code.**

When using purely `geopandas.sjoin`, there are 448 NUTS_IDs that do not have any grid points inside their areas, and 422 NUTS_IDs that have `NaN` for `t2m` or `tp` in at least one point due to the original dataset. However, after grouping by NUTS_ID and time, the 422 `NaN` cases are reduced to 123, because the group mean skips `NaN` values and stays `NaN` only when every point in the group is `NaN`.
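
Why a point-based join can leave a region without any data is easiest to see on a toy grid. The sketch below uses axis-aligned boxes instead of real NUTS polygons, and the grid extent, spacing, and both region boxes are hypothetical values chosen for the illustration.

```python
def grid_centres(start, stop, step):
    """Coordinates from start to stop (inclusive) in fixed steps."""
    n = int(round((stop - start) / step)) + 1
    return [start + i * step for i in range(n)]


def points_in_box(lons, lats, box):
    """Return the grid points that fall inside box = (minx, miny, maxx, maxy)."""
    minx, miny, maxx, maxy = box
    return [
        (x, y)
        for x in lons
        for y in lats
        if minx <= x <= maxx and miny <= y <= maxy
    ]


# A 0.5 deg grid over a small lon/lat window.
lons = grid_centres(4.0, 6.0, 0.5)
lats = grid_centres(51.0, 53.0, 0.5)

large_region = (4.1, 51.1, 5.6, 52.6)  # hypothetical box wider than the spacing
small_region = (4.6, 51.6, 4.9, 51.9)  # hypothetical box that fits between points

large_hits = points_in_box(lons, lats, large_region)
small_hits = points_in_box(lons, lats, small_region)
```

Here `large_hits` contains grid points while `small_hits` is empty: a region narrower than the grid spacing can sit entirely between neighbouring points, so an inner join drops it.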

In [None]:
# convert the pandas dataframe (built from the xarray dataset) to a GeoDataFrame of points
gpd_ds = gpd.GeoDataFrame(
    df,
    geometry=gpd.points_from_xy(df["longitude"], df["latitude"]),
    crs=CRS,
)
gpd_ds.head()

In [None]:
# merge with nuts geodataframe
merged = gpd.sjoin(gpd_ds, nuts, how="inner", predicate="within")
merged.head()

In [None]:
gpd_groupped = merged.groupby(["NUTS_ID", "time"], as_index=False).agg(
    {"t2m": "mean", "tp": "mean"}
)
gpd_groupped.head()

In [None]:
len(gpd_groupped)  # 32400

In [None]:
# count NUTS_IDs that have nan t2m or tp
gdp_nan_t2m = merged[merged["t2m"].isna()]["NUTS_ID"].unique()
gdp_nan_tp = merged[merged["tp"].isna()]["NUTS_ID"].unique()

(
    len(gdp_nan_t2m),
    len(gdp_nan_tp),
    set(gdp_nan_t2m) - set(gdp_nan_tp),
    set(gdp_nan_tp) - set(gdp_nan_t2m),
)  # 422, 422, 0, 0

In [None]:
len(gpd_groupped[gpd_groupped["t2m"].isna()]["NUTS_ID"].unique())  # 123

In [None]:
# test mean of nan of geopandas
# create a small test geodataframe
test_gdf = gpd.GeoDataFrame(
    {
        "NUTS_ID": ["A", "A", "A", "B", "B", "B"],
        "t2m": [1.0, 2.0, None, None, None, None],
    }
)
test_gdf_grouped_mean = test_gdf.groupby("NUTS_ID", as_index=False).agg({"t2m": "mean"})
test_gdf_grouped_mean

In [None]:
test_gdf_grouped_sum = test_gdf.groupby("NUTS_ID", as_index=False).agg({"t2m": "sum"})
test_gdf_grouped_sum

In [None]:
len(merged.NUTS_ID), len(merged.NUTS_ID.unique()), len(merged.geometry.unique())

In [None]:
# get specific values of merged for checking
merged_sample = merged[
    np.isclose(merged["latitude"], 89.75, atol=1e-6)
    & (merged["time"] == "2016-01-15 12:00:00")
]
merged_sample

In [None]:
# check if there are any geometries that map to multiple NUTS_IDs
# and among these NUTS_IDs, there is no hierarchy relationship
from collections import defaultdict

geom_to_nuts = defaultdict(list)
for geom, nuts_id in zip(merged.geometry, merged.NUTS_ID):
    geom_to_nuts[geom].append(nuts_id)

In [None]:
repeated_geoms = {
    geom: nuts_ids for geom, nuts_ids in geom_to_nuts.items() if len(nuts_ids) > 1
}
len(repeated_geoms)

In [None]:
shared_geoms = defaultdict(list)
for geom, nuts_ids in repeated_geoms.items():
    common_prefix = os.path.commonprefix(nuts_ids)
    if not common_prefix:
        shared_geoms[geom] = nuts_ids
len(shared_geoms)

In [None]:
# check if there are any NUTS_IDS that do not map to any geometry
all_nuts_ids = set(nuts.NUTS_ID.unique())
mapped_nuts_ids = set(merged.NUTS_ID.unique())
unmapped_nuts_ids = all_nuts_ids - mapped_nuts_ids
len(unmapped_nuts_ids)  # 448

In [None]:
list(unmapped_nuts_ids)[:10]

In [None]:
# double check if NUTS_ID is indeed not in merged
merged[merged["NUTS_ID"] == "NL226"]

In [None]:
# inspect geometries that do not map to any data point
unmapped_nuts = nuts[nuts["NUTS_ID"] == "NL226"]
unmapped_nuts.plot()

In [None]:
# check range of lat lon in this unmapped nuts region
unmapped_nuts.total_bounds  # minx, miny, maxx, maxy

In [None]:
# keep grid points inside the region's bounding box (inclusive bounds)
minx, miny, maxx, maxy = unmapped_nuts.total_bounds
gpd_ds_filtered = gpd_ds[
    (gpd_ds["longitude"] >= minx)
    & (gpd_ds["longitude"] <= maxx)
    & (gpd_ds["latitude"] >= miny)
    & (gpd_ds["latitude"] <= maxy)
]
len(gpd_ds_filtered)

In [None]:
filtered_points = gpd_ds_filtered["geometry"].unique()
filtered_points

In [None]:
# check if the filtered points are within the unmapped_nuts geometry
within_flags = [unmapped_nuts.contains(point).any() for point in filtered_points]
within_flags

In [None]:
# check if there are any geometries that do not map to any NUTS_ID
all_geoms = set(gpd_ds.geometry.unique())
mapped_geoms = set(merged.geometry.unique())
unmapped_geoms = all_geoms - mapped_geoms
unmapped_geoms = list(unmapped_geoms)
len(unmapped_geoms)

In [None]:
unmapped_geoms[:10]

In [None]:
# check if there are NUTS3 inside another NUTS3
nuts3 = nuts[nuts.LEVL_CODE == 3]
nuts3_sjoined = gpd.sjoin(nuts3, nuts3, how="inner", predicate="within")
nuts3_sjoined_diff = nuts3_sjoined[
    nuts3_sjoined.NUTS_ID_left != nuts3_sjoined.NUTS_ID_right
]
(
    len(nuts3),
    len(nuts3_sjoined),
    len(nuts3_sjoined_diff),
    len(set(nuts3_sjoined.NUTS_ID_left)),
)

In [None]:
# check if NUTS3 touch other NUTS3
nuts3_sjoined_other = gpd.sjoin(nuts3, nuts3, how="inner", predicate="touches")
nuts3_sjoined_other_diff = nuts3_sjoined_other[
    nuts3_sjoined_other.NUTS_ID_left != nuts3_sjoined_other.NUTS_ID_right
]
(
    len(nuts3_sjoined_other),
    len(nuts3_sjoined_other_diff),
    len(set(nuts3_sjoined_other_diff.NUTS_ID_left)),
)

In [None]:
# check if there is NUTS inside another NUTS
nuts_sjoined = gpd.sjoin(nuts, nuts, how="inner", predicate="within")
# get all rows where NUTS_IDs are different and don't share common prefix
nuts_sjoined_diff = nuts_sjoined[
    nuts_sjoined.NUTS_ID_left != nuts_sjoined.NUTS_ID_right
]
shared_nuts = []
for _, row in nuts_sjoined_diff.iterrows():
    common_prefix = os.path.commonprefix([row.NUTS_ID_left, row.NUTS_ID_right])
    if not common_prefix:
        shared_nuts.append(row)
len(shared_nuts), len(nuts_sjoined), len(nuts_sjoined_diff)

In [None]:
nuts[nuts.CNTR_CODE == "BA"]

In [None]:
nuts[nuts.NUTS_ID == "DE502"]

In [None]:
nuts[(nuts.CNTR_CODE == "DE") & (nuts.LEVL_CODE == 3)]

In [None]:
# get all nuts related to Bremen
bremen_nuts_ids = nuts[nuts.NUTS_NAME == "Bremen"][["NUTS_ID"]].NUTS_ID.tolist()
bremen_root = sorted(bremen_nuts_ids, key=len)[0]
bremen_root

In [None]:
# get all nuts under DE5
bremen_nuts = nuts[nuts.NUTS_ID.str.startswith("DE5")]
bremen_nuts

In [None]:
# check if DE50 actually within DE5
de5_geom = nuts[nuts.NUTS_ID == "DE5"].geometry
de50_geom = nuts[nuts.NUTS_ID == "DE50"].geometry
de50_geom.within(de5_geom.iloc[0])

In [None]:
de501_geom = nuts[nuts.NUTS_ID == "DE501"].geometry
de502_geom = nuts[nuts.NUTS_ID == "DE502"].geometry

In [None]:
de501_geom.within(de5_geom.iloc[0]), de502_geom.within(de5_geom.iloc[0])

In [None]:
de501_geom.within(de50_geom.iloc[0]), de502_geom.within(de50_geom.iloc[0])

In [None]:
# check if there is any NUTSi that is not within its parent NUTS(i-1)
not_within_cases = []
cntr_codes = nuts.CNTR_CODE.unique()
for cntr_code in cntr_codes:
    nuts_subset = nuts[nuts.NUTS_ID.str.startswith(cntr_code)][
        ["NUTS_ID"]
    ].NUTS_ID.tolist()
    nuts_subset.sort(key=len)  # parent NUTS will appear before child NUTS
    for nuts_id in nuts_subset:
        # the parent ID is the child ID without its last character
        parent_id = nuts_id[:-1] if len(nuts_id) > len(cntr_code) else None
        if parent_id is None:
            continue
        check_within = gpd.sjoin(
            nuts[nuts.NUTS_ID == nuts_id],
            nuts[nuts.NUTS_ID == parent_id],
            how="inner",
            predicate="within",
        )
        if len(check_within) == 0:
            not_within_cases.append((nuts_id, parent_id))
len(not_within_cases), not_within_cases

## Aggregate data by NUTS using xagg


**Note: This section cannot currently be run on the 0.1 deg data. The dataset is too large (aggregating `t2m` alone took more than 20 minutes).**

When using `xagg` for aggregation, a total of 114 NUTS_IDs have `NaN` `t2m` or `tp`:
* 57 cases come from `NaN` values of `t2m` and `tp` in the original dataset
* 57 cases occur because the areas are too small

However, it seems that `xagg` can only calculate the average (mean). There appear to be no other aggregation options such as `sum`, `min`, or `max`.
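
The idea behind an overlap-based mean can be sketched with rectangles. This is only a conceptual illustration under simplifying assumptions (rectangular cells and regions, planar areas in degrees, no change of cell area with latitude); it is not how `xagg` is implemented, and `overlap_fraction`, the cells, the region, and the values are all hypothetical.

```python
def overlap_fraction(cell, region):
    """Fraction of a rectangular cell covered by a rectangular region.

    Both arguments are (minx, miny, maxx, maxy) tuples.
    """
    cminx, cminy, cmaxx, cmaxy = cell
    rminx, rminy, rmaxx, rmaxy = region
    width = max(0.0, min(cmaxx, rmaxx) - max(cminx, rminx))
    height = max(0.0, min(cmaxy, rmaxy) - max(cminy, rminy))
    cell_area = (cmaxx - cminx) * (cmaxy - cminy)
    return (width * height) / cell_area


# Three neighbouring 1x1 cells and a region covering half of "a" and all of "b".
cells = {
    "a": (0.0, 0.0, 1.0, 1.0),
    "b": (1.0, 0.0, 2.0, 1.0),
    "c": (2.0, 0.0, 3.0, 1.0),
}
region = (0.5, 0.0, 2.0, 1.0)
cell_values = {"a": 3.0, "b": 6.0, "c": 100.0}

fractions = {name: overlap_fraction(cell, region) for name, cell in cells.items()}

# Overlap-weighted mean: cells outside the region get weight 0 and drop out.
weighted_mean = sum(cell_values[n] * f for n, f in fractions.items()) / sum(
    fractions.values()
)
```

Cell `"c"` does not touch the region, so its value has no influence on `weighted_mean`, while `"a"` counts half as much as `"b"`.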

In [None]:
%pip install cartopy matplotlib cmocean

In [None]:
# check if there are nans in t2m or tp before aggregation
nan_t2m_ds = df[df["t2m"].isna()]
nan_tp_ds = df[df["tp"].isna()]

In [None]:
(
    len(nan_t2m_ds[["latitude", "longitude"]].drop_duplicates()),
    len(nan_tp_ds[["latitude", "longitude"]].drop_duplicates()),
)  # unique grid points with NaN t2m / tp

In [None]:
# get overlap between pixels and polygons
weightmap = xa.pixel_overlaps(ds, nuts)
weightmap

In [None]:
# get row 50 of the nuts
nuts.iloc[50]

In [None]:
weightmap.diag_fig({"NUTS_ID": "BA01"}, ds)

In [None]:
# aggregate data in ds onto polygons in nuts
agg_ds = xa.aggregate(ds, weightmap)
agg_ds

In [None]:
out_ds = agg_ds.to_dataset()
out_ds

In [None]:
out_df = out_ds.to_dataframe().reset_index()
out_df

In [None]:
# check how many got mapped
len(out_df), len(out_df.NUTS_ID.unique())

In [None]:
# check if there is any NUTS ID that does not have t2m or tp mapped
nan_t2m = out_df[out_df["t2m"].isna()]["NUTS_ID"].unique()
nan_tp = out_df[out_df["tp"].isna()]["NUTS_ID"].unique()
(
    len(nan_t2m),
    len(nan_tp),
    set(nan_t2m) - set(nan_tp),
    set(nan_tp) - set(nan_t2m),
)  # 114, 114, 0, 0

In [None]:
nan_tp

In [None]:
nan_t2m

In [None]:
# find common IDs with unmapped_nuts_ids from geopandas sjoin
common_ids = set(nan_t2m).intersection(set(unmapped_nuts_ids))
len(common_ids)  # 57

In [None]:
# check if NUTS_ID in nan_t2m is because t2m is nan in the original dataset
nan_t2m_from_org = []
nan_t2m_not_from_org = []
nan_t2m_points = gpd_ds[gpd_ds["t2m"].isna()]["geometry"].unique()
for nuts_id in nan_t2m:
    nuts_geom = nuts[nuts["NUTS_ID"] == nuts_id]
    check_contains = gpd.sjoin(
        gpd.GeoDataFrame(geometry=nan_t2m_points, crs=CRS),
        nuts_geom,
        how="inner",
        predicate="within",
    )
    if len(check_contains) > 0:
        nan_t2m_from_org.append(nuts_id)  # because of nan in original data
    else:
        nan_t2m_not_from_org.append(nuts_id)  # not because of nan in original data

len(nan_t2m_from_org), len(nan_t2m_not_from_org)  # 57, 57

In [None]:
nan_t2m_not_from_org[:10]

In [None]:
# check if these two lists have common ids with the ones from geopandas sjoin
common_ids_org = set(nan_t2m_from_org).intersection(set(unmapped_nuts_ids))
common_ids_not_org = set(nan_t2m_not_from_org).intersection(set(unmapped_nuts_ids))
len(common_ids_org), len(common_ids_not_org)

In [None]:
# common ids with gdp_nan_t2m of geopandas sjoin
common_ids_gdp = set(nan_t2m).intersection(set(gdp_nan_t2m))
common_ids_gdp_org = set(nan_t2m_from_org).intersection(set(gdp_nan_t2m))
common_ids_gdp_not_org = set(nan_t2m_not_from_org).intersection(set(gdp_nan_t2m))
len(common_ids_gdp), len(common_ids_gdp_org), len(common_ids_gdp_not_org)

In [None]:
# check if the unmapped NUTS_ID with geopandas is also unmapped here
out_df[out_df["NUTS_ID"] == "BE233"][["NUTS_ID", "t2m", "tp"]]

In [None]:
# plot one of the unmapped NUTS_ID region
unmapped_nuts_id = "BE233"
unmapped_nuts_region = nuts[nuts["NUTS_ID"] == unmapped_nuts_id]
unmapped_nuts_region.plot()

In [None]:
# check lat lon range of this unmapped NUTS region
unmapped_nuts_region.total_bounds  # minx, miny, maxx, maxy

In [None]:
# find if there is any grid point within this unmapped NUTS region
minx, miny, maxx, maxy = unmapped_nuts_region.total_bounds
gpd_ds_filtered = gpd_ds[
    (gpd_ds["longitude"] >= minx)
    & (gpd_ds["longitude"] <= maxx)
    & (gpd_ds["latitude"] >= miny)
    & (gpd_ds["latitude"] <= maxy)
]
len(gpd_ds_filtered)

In [None]:
filtered_points = gpd_ds_filtered["geometry"].unique()
filtered_points

In [None]:
# check if the filtered points are within the unmapped_nuts_region geometry
within_flags = [
    unmapped_nuts_region.contains(point).any() for point in filtered_points
]
within_flags

## Aggregate data by NUTS using exactextract

`exactextract` yields the same results as `xagg` when calculating the `mean` for all data variables.

With `exactextract`, the aggregation method can be specified separately for each data variable.
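
One way to keep the per-variable choice in a single place is a small mapping from variable name to statistic, turned into the `"<name>=<stat>"` operation strings that the extraction loop in this section passes to `exact_extract`. The mapping `AGG_METHODS` and the helper `build_ops` are hypothetical names for this sketch, and using `sum` for `tp` is only an example choice.

```python
# Hypothetical mapping: which statistic to compute for each data variable.
AGG_METHODS = {"t2m": "mean", "tp": "sum"}


def build_ops(agg_methods):
    """Return {variable: "<variable>_<stat>=<stat>"} operation strings."""
    return {var: f"{var}_{stat}={stat}" for var, stat in agg_methods.items()}


ops = build_ops(AGG_METHODS)
t2m_op = ops["t2m"]  # same form as the "t2m_mean=mean" string used in this section
tp_op = ops["tp"]
```

Each string could then be passed as the operation argument for the matching variable, keeping in mind the note above that a `sum` over all-`nan` cells returns `0.0` rather than `nan`.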

In [None]:
# rioxarray is not referenced directly here, but exactextract seems to need it for xarray input
import rioxarray as rxr  # noqa: F401

In [None]:
# this cell seems unnecessary
# # ensure CRS is defined
# ds = ds.rio.write_crs(CRS, inplace=True)

# # tell rioxarray which dimensions are x and y
# ds = ds.rio.set_spatial_dims(x_dim="longitude", y_dim="latitude", inplace=True)

In [None]:
ds.sizes

In [None]:
# separate t2m and tp for exactextract
ds_t2m = ds[["t2m"]]
ds_tp = ds[["tp"]]

In [None]:
# aggregate for each time step separately
import pandas as pd

results = []
for t in ds.time.values:
    t2m_t = ds_t2m.sel(time=t)
    tp_t = ds_tp.sel(time=t)

    t2m_stats = ee.exact_extract(
        t2m_t, nuts, "t2m_mean=mean", include_cols=["NUTS_ID"], output="pandas"
    )
    t2m_stats["time"] = t
    results.append(t2m_stats)

    tp_stats = ee.exact_extract(
        tp_t,
        nuts,
        "tp_mean=mean",  # note that if sum is used, sum of all NaN will be 0
        include_cols=["NUTS_ID"],
        output="pandas",
    )
    tp_stats["time"] = t
    results.append(tp_stats)
    print(
        "Done for time", t
    )  # in total, ~ 5 minutes for 0.5 deg, ~ 8 minutes for 0.1 deg

In [None]:
len(results)

In [None]:
merged_dfs = [
    pd.merge(
        results[i],
        results[i + 1],
        on=["NUTS_ID", "time"],
        how="outer",
        validate="1:1",
    )
    for i in range(0, len(results), 2)
]
agg_df = pd.concat(merged_dfs, ignore_index=True)
agg_df.head()

In [None]:
len(results[0]), len(results[1]), len(agg_df)

In [None]:
set(results[0]["NUTS_ID"].unique()) - set(agg_df["NUTS_ID"].unique())

In [None]:
len(agg_df["NUTS_ID"].unique())

In [None]:
# NUTS_IDs with NaN t2m or tp
nan_t2m_ids = agg_df[agg_df["t2m_mean"].isna()]["NUTS_ID"].unique()
nan_tp_ids = agg_df[agg_df["tp_mean"].isna()]["NUTS_ID"].unique()
(
    len(nan_t2m_ids),
    len(nan_tp_ids),
    len(set(nan_t2m_ids) - set(nan_tp_ids)),
    len(set(nan_tp_ids) - set(nan_t2m_ids)),
)  # 114, 114, 0, 0 for 0.5 deg; 1, 1, 0, 0 for 0.1 deg

In [None]:
# get tp_mean from NUTS_IDs with NaN t2m
nan_t2m_not_nan_tp = agg_df[agg_df["NUTS_ID"].isin(nan_t2m_ids)]["tp_mean"].unique()
nan_t2m_not_nan_tp

In [None]:
nan_t2m_ids

In [None]:
# check if NUTS_ID in nan_t2m_ids is because t2m is nan in the original dataset
ee_nan_t2m_from_org = []
ee_nan_t2m_not_from_org = []
nan_t2m_points = df[df["t2m"].isna()][["latitude", "longitude"]].drop_duplicates()
for nuts_id in nan_t2m_ids:
    nuts_geom = nuts[nuts["NUTS_ID"] == nuts_id]
    check_contains = gpd.sjoin(
        gpd.GeoDataFrame(
            geometry=gpd.points_from_xy(
                x=nan_t2m_points["longitude"], y=nan_t2m_points["latitude"]
            ),
            crs=CRS,
        ),
        nuts_geom,
        how="inner",
        predicate="within",
    )
    if len(check_contains) > 0:
        ee_nan_t2m_from_org.append(nuts_id)  # because of nan in original data
    else:
        ee_nan_t2m_not_from_org.append(nuts_id)  # not because of nan in original data

(
    len(ee_nan_t2m_from_org),
    len(ee_nan_t2m_not_from_org),
)  # 57, 57 for 0.5 deg; 0, 1 for 0.1 deg

In [None]:
# plot one of the unmapped NUTS_ID region
unmapped_nuts_id = "MT002"
unmapped_nuts_region = nuts[nuts["NUTS_ID"] == unmapped_nuts_id]
unmapped_nuts_region.plot()

In [None]:
# check lat lon range of this unmapped NUTS region
unmapped_nuts_region.total_bounds  # minx, miny, maxx, maxy

In [None]:
# find if there is any grid point within this unmapped NUTS region
minx, miny, maxx, maxy = unmapped_nuts_region.total_bounds
df_filtered = df[
    (df["longitude"] >= minx)
    & (df["longitude"] <= maxx)
    & (df["latitude"] >= miny)
    & (df["latitude"] <= maxy)
]
len(df_filtered)

## Comparing the three methods

In [None]:
# compare between geopandas sjoin, xagg, and exactextract results
nuts_id = "DE"
time = "2016-01-01"
gpd_result = gpd_groupped[
    (gpd_groupped["NUTS_ID"] == nuts_id) & (gpd_groupped["time"] == time)
][["NUTS_ID", "time", "t2m", "tp"]]
xagg_result = out_df[(out_df["NUTS_ID"] == nuts_id) & (out_df["time"] == time)][
    ["NUTS_ID", "time", "t2m", "tp"]
]
exactextract_result = agg_df[(agg_df["NUTS_ID"] == nuts_id) & (agg_df["time"] == time)][
    ["NUTS_ID", "time", "t2m_mean", "tp_mean"]
]
gpd_result, xagg_result, exactextract_result

In [None]:
# check if these two lists have common ids with the ones from geopandas sjoin
ee_common_ids_org = set(ee_nan_t2m_from_org).intersection(set(unmapped_nuts_ids))
ee_common_ids_not_org = set(ee_nan_t2m_not_from_org).intersection(
    set(unmapped_nuts_ids)
)
len(ee_common_ids_org), len(ee_common_ids_not_org)

In [None]:
# common ids with gdp_nan_t2m of geopandas sjoin
ee_common_ids_gdp = set(nan_t2m_ids).intersection(set(gdp_nan_t2m))
ee_common_ids_gdp_org = set(ee_nan_t2m_from_org).intersection(set(gdp_nan_t2m))
ee_common_ids_gdp_not_org = set(ee_nan_t2m_not_from_org).intersection(set(gdp_nan_t2m))
len(ee_common_ids_gdp), len(ee_common_ids_gdp_org), len(ee_common_ids_gdp_not_org)