# Notebook to explore netCDF files and change resolution, plus Python plotting
These files are downloaded from [Copernicus Climate Data Store](https://cds.climate.copernicus.eu/datasets/reanalysis-era5-land-monthly-means?tab=download) using `cdsapi`. Get the data by running the Python script `inout.py`:
```
python onehealth_db/inout.py
```

The downloaded files are stored in `data/in`. The `area` option uses values `90`, `90`, `-90`, `-90` for `North`, `East`, `South`, `West`, respectively.

Question: What is the coordinate reference system (CRS) of the ERA5 dataset? The NUTS3 shapefiles are available in EPSG:3035, EPSG:4326, or EPSG:3857.

-> According to [ERA5-Land's documentation](https://confluence.ecmwf.int/display/CKB/ERA5-Land%3A+data+documentation):
> The data is referenced in the horizontal with respect to the WGS84 ellipse (which defines the major/minor axes) and in the vertical it is referenced to the EGM96 geoid over land but over ocean it is referenced to mean sea level, with the approximation that this is assumed to be coincident with the geoid. 

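According to [this page](https://spatialreference.org/ref/epsg/9707/), this combination corresponds to EPSG:9707 ("WGS 84 + EGM96 height"). That is a compound CRS whose horizontal component is EPSG:4326, the same CRS as the `NUTS_RG_20M_2024_4326` shapefile used further below.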

> ERA5-Land produces a total of 50 variables describing the
water and energy cycles over land, globally, hourly, and at a
spatial resolution of 9 km, matching the ECMWF triangular–
cubic–octahedral (TCo1279) operational grid (Malardel
et al., 2016).

In [None]:
from pathlib import Path
import xarray as xr
from matplotlib import pyplot as plt

import numpy as np
import pandas as pd
import geopandas as gpd

The following cells explore the structure of the data.

In [None]:
# Root of the project data directory. It is relative to the current working
# directory, which is normally the folder containing this notebook.
data_folder = Path("..", "..", "..", "data")

### ERA5-Land from CDS

In [None]:
# Raw ERA5-Land download, and the processed file. Judging by its name, the
# processed file has units converted to celsius/mm and is on a 0.5 degree grid.
f_area_before_celsius = data_folder.joinpath(
    "in", "era5_data_2016_2017_all_2t_tp_monthly_raw.nc"
)
f_area_after_celsius = data_folder.joinpath(
    "processed",
    "era5_data_2016-2017_allm_2t_tp_monthly_unicoords_adjlon_celsius_mm_05deg_trim_ts20250923-065745_hssc-laptop01.nc",
)

#### Dask Array

In [None]:
# Open the processed file lazily with dask, then rechunk so each chunk holds
# a single time step.
dask_ds = xr.open_dataset(f_area_after_celsius, chunks={}).chunk(
    time=1, latitude=900, longitude=1800
)
dask_ds

In [None]:
# Drop latitude rows that are NaN everywhere, then pull the result into memory.
t2m_data = dask_ds["t2m"].dropna(dim="latitude", how="all")
t2m_data = t2m_data.load()
t2m_data

In [None]:
# Collapse (time, latitude, longitude) into a single "points" dimension.
stacked = t2m_data.stack({"points": ("time", "latitude", "longitude")})
stacked

In [None]:
# Discard the stacked points whose value is NaN.
stacked = stacked.dropna(dim="points")
stacked

In [None]:
# Timestamps of the remaining points, cast to nanosecond-precision datetime64.
stacked["time"].values.astype("datetime64[ns]")

In [None]:
# Latitude of each remaining (non-NaN) point.
stacked["latitude"].values

In [None]:
# Longitude of each remaining (non-NaN) point.
stacked["longitude"].values

In [None]:
# The t2m values of the remaining points as a NumPy array.
stacked.values

#### Xarray

In [None]:
# Open the raw and the processed ERA5-Land files as xarray Datasets (in that order).
ds_area_before_celsius, ds_area_after_celsius = (
    xr.open_dataset(nc_path)
    for nc_path in (f_area_before_celsius, f_area_after_celsius)
)

In [None]:
# Inspect the raw ERA5-Land dataset.
ds_area_before_celsius

In [None]:
# First five rows of the raw data at the grid cell nearest to latitude 20, longitude 10.
ds_area_before_celsius.sel(
    latitude=20.0, longitude=10.0, method="nearest"
).to_dataframe().head(5)

In [None]:
# Metadata attributes (e.g. units) of the precipitation variable "tp" in the raw file.
ds_area_before_celsius["tp"].attrs

In [None]:
# Inspect the processed ERA5-Land dataset.
ds_area_after_celsius

In [None]:
# Latitude value at index 5 of the processed grid.
ds_area_after_celsius.latitude.values[5]

In [None]:
# Metadata attributes of "tp" in the processed file, to compare with the raw file's attributes.
ds_area_after_celsius["tp"].attrs

In [None]:
# First five rows of the processed data at the grid cell nearest to latitude 20, longitude 10.
ds_area_after_celsius.sel(
    latitude=20.0, longitude=10.0, method="nearest"
).to_dataframe().head(5)

In [None]:
# NOTE(review): identical to the earlier `latitude.values[5]` cell; consider removing one of them.
ds_area_after_celsius.latitude.values[5]

In [None]:
# Time series of 2 m temperature at the grid cell nearest to (lat, lon).
lat = 20.0
lon = 10.0
t2m_point = ds_area_after_celsius["t2m"].sel(
    latitude=lat, longitude=lon, method="nearest"
)
t2m_point.plot(color="blue", marker="o")
# Build the year label from the data. The title used to be hard-coded to
# "2024" although this file covers 2016-2017.
years = pd.DatetimeIndex(t2m_point["time"].values).year
year_label = (
    str(years.min()) if years.min() == years.max() else f"{years.min()}-{years.max()}"
)
plt.title("2m temperature in {} at lat-{}, lon-{}".format(year_label, lat, lon))
plt.show()

In [None]:
# Map of t2m at index 0 of its leading dimension (presumably time, i.e. the first month).
ds_area_after_celsius.t2m[0].plot.pcolormesh(figsize=(9, 5), robust=True)

In [None]:
# Map of tp at index 0 of its leading dimension (presumably time, i.e. the first month).
ds_area_after_celsius.tp[0].plot.pcolormesh(figsize=(9, 5), robust=True)

In [None]:
# Flatten the processed dataset into a long-format DataFrame whose
# coordinates become ordinary columns.
df = ds_area_after_celsius.to_dataframe()
df = df.reset_index()
df

### Population data from ISIMIP

In [None]:
# Processed ISIMIP population file, opened lazily with dask (chunks={}).
f_popu_data = data_folder.joinpath(
    "processed",
    "population_histsoc_30arcmin_annual_1901_2021_unicoords_2016-2017_ts20250923-065749_hssc-laptop01.nc",
)
ds_popu_data = xr.open_dataset(f_popu_data, chunks={})

In [None]:
# Inspect the population dataset.
ds_popu_data

In [None]:
# Only keep data from 2016 and 2017.
end_time = np.datetime64("2017-12-01", "ns")
limit_time = np.datetime64("2016-01-01", "ns")
# Sort before slicing. Label-based slice(start, stop) with start < stop on a
# descending time index returns an empty selection, so the sort must come
# first, not afterwards.
ds_popu_data = ds_popu_data.sortby("time").sel(time=slice(limit_time, end_time))
ds_popu_data.time.values

In [None]:
# Build a monthly time axis (first day of each month) that spans January of
# the first year to December of the last year in the population data.
# sortby is a no-op on an already ascending axis. It also repairs axes that
# are unsorted anywhere, not only those whose first value exceeds the last.
ds_popu_data = ds_popu_data.sortby("time")
first_year = pd.to_datetime(ds_popu_data.time.values[0]).year
last_year = pd.to_datetime(ds_popu_data.time.values[-1]).year
# NOTE(review): timestamps are placed at 12:00, whereas the ERA5 monthly
# timestamps use hour 0 (per the original note). Confirm this is the intended
# alignment before combining the two datasets.
start_of_year = pd.Timestamp(year=first_year, month=1, day=1, hour=12, minute=0)
end_of_year = pd.Timestamp(year=last_year, month=12, day=1, hour=12, minute=0)
monthly_time = pd.date_range(start=start_of_year, end=end_of_year, freq="MS")
monthly_time

In [None]:
# Put the population data on the monthly axis, carrying each earlier value
# forward (ffill) into the months that follow it.
ds_popu_data = ds_popu_data.reindex({"time": monthly_time}, method="ffill")
ds_popu_data

In [None]:
# Population at the grid cell nearest to 49.39 N, 8.67 E (presumably
# Heidelberg). The call used to pass latitude=8.67, longitude=49.39, i.e.
# swapped coordinates. With method="nearest" that silently selected a cell
# in the Horn of Africa instead.
ds_popu_data["total-population"].sel(
    latitude=49.39, longitude=8.67, method="nearest"
).to_dataframe().head(14)

In [None]:
# Metadata attributes of the population variable.
ds_popu_data["total-population"].attrs

In [None]:
# Grid spacing of the population data, taken from the first two latitudes.
# abs() keeps the value positive when latitude is stored in descending
# order (north to south); otherwise the difference is negative.
res = abs(ds_popu_data.latitude[1] - ds_popu_data.latitude[0])
res

In [None]:
# All variables at the grid cell nearest to 49.39 N, 8.67 E (presumably
# Heidelberg). Latitude and longitude were previously swapped.
test_popu_data = ds_popu_data.sel(
    latitude=49.39, longitude=8.67, method="nearest"
).to_dataframe()
test_popu_data.head(5)

In [None]:
# Plot the population time series at the selected grid cell.
test_popu_data["total-population"].plot()

In [None]:
# Map of population at the last index of its leading dimension (presumably the last time step).
ds_popu_data["total-population"][-1].plot.pcolormesh(figsize=(9, 5), robust=True)

## Check all the data visually
### Cartesian grids

In [None]:
# Processed ERA5 and population files on the regular latitude/longitude grid.
cartesian_grid_file = data_folder.joinpath(
    "processed",
    "era5_data_2016-2017_allm_2t_tp_monthly_unicoords_adjlon_celsius_mm_05deg_trim_ts20250923-065745_hssc-laptop01.nc",
)
pop_cartesian_grid_file = data_folder.joinpath(
    "processed",
    "population_histsoc_30arcmin_annual_1901_2021_unicoords_2016-2017_ts20250923-065749_hssc-laptop01.nc",
)

In [None]:
# Open the gridded ERA5 file and the gridded population file (in that order).
cartesian_grid, pop_cartesian_grid = (
    xr.open_dataset(nc_path)
    for nc_path in (cartesian_grid_file, pop_cartesian_grid_file)
)

In [None]:
# Print an ncdump-style summary of the dataset. Without the call
# parentheses, `.info` only displayed the bound-method repr.
cartesian_grid.info()

In [None]:
# One panel per month of 2 m temperature (t2m), four panels per row.
cartesian_grid["t2m"].plot.pcolormesh(
    col="time",
    col_wrap=4,
    cmap="coolwarm",
    robust=True,
    figsize=(15, 10),
)
plt.savefig("era5_2016_2017_plots_t2m.png", dpi=300)
plt.show()

In [None]:
# One panel per month of precipitation (tp), four panels per row.
cartesian_grid["tp"].plot.pcolormesh(
    col="time",
    col_wrap=4,
    cmap="coolwarm",
    robust=True,
    figsize=(15, 10),
)
plt.savefig("era5_2016_2017_plots_tp.png", dpi=300)
plt.show()

In [None]:
# Gridded population map for every time step, one row per timestamp.
pop_cartesian_grid["total-population"].plot.pcolormesh(
    row="time",
    cmap="coolwarm",
    robust=True,
    figsize=(8, 8),
)
plt.savefig("population_2016_2017_plots.png", dpi=300)
plt.show()

### NUTS averaged data

In [None]:
# NUTS-aggregated ERA5/population values and the matching NUTS shapefile.
nuts_aggregated_file = data_folder.joinpath(
    "processed", "NUTS_RG_20M_2024_4326_agg_era5_popu_2016-01-2017-12.nc"
)
nuts_shapefile = data_folder.joinpath("in", "NUTS_RG_20M_2024_4326.shp")

In [None]:
# Read the NUTS-aggregated netCDF file into an xarray Dataset.
nuts_grid = xr.open_dataset(nuts_aggregated_file)

In [None]:
# Flatten the aggregated Dataset (not a DataArray) into a pandas DataFrame with
# its coordinates as columns, so it can be merged with the NUTS GeoDataFrame.
nuts_grid = nuts_grid.to_dataframe()
nuts_grid = nuts_grid.reset_index()
nuts_grid.head()

In [None]:
# Attach the aggregated values to the NUTS polygons. This is an inner join
# on NUTS_ID (the merge default, spelled out here).
NUTS_shapes = gpd.read_file(nuts_shapefile)
era5_nuts = NUTS_shapes.merge(nuts_grid, how="inner", on="NUTS_ID")

In [None]:
# Panel layout for one map per monthly timestamp, four maps per row.
unique_times = era5_nuts["time"].unique()
n_cols = 4
n_times = len(unique_times)
n_rows = -(-n_times // n_cols)  # ceiling division

In [None]:
# One NUTS choropleth of 2 m temperature (t2m) per monthly timestamp.
# squeeze=False always returns a 2-D array of Axes, so ravel() yields a flat
# sequence for any layout. The old `[ax]` fallback for n_times == 1 wrapped
# an ndarray (subplots returns an array whenever n_cols > 1), so ax[0] was
# not an Axes and plotting failed.
fig, ax = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows), squeeze=False)
ax = ax.ravel()

for axes, timestamp in zip(ax, unique_times):
    # Filter data for current timestamp
    current_data = era5_nuts[era5_nuts["time"] == timestamp]

    # Create the plot
    current_data.plot(
        ax=axes,
        column="t2m",
        legend=True,
        cmap="coolwarm",
        markersize=0.5,
        legend_kwds={"shrink": 0.8},
    )

    axes.set_title(f"Temperature (t2m) - {pd.to_datetime(timestamp).strftime('%Y-%m')}")
    axes.set_xlabel("Longitude")
    axes.set_ylabel("Latitude")

# Remove the empty panels of the last row when the number of timestamps is
# not a multiple of n_cols.
for unused_axes in ax[len(unique_times):]:
    fig.delaxes(unused_axes)

plt.tight_layout()
fig.savefig("nuts3_t2m.png", dpi=300)

In [None]:
# One NUTS choropleth of precipitation (tp) per monthly timestamp.
# squeeze=False always returns a 2-D array of Axes, so ravel() yields a flat
# sequence for any layout. The old `[ax]` fallback broke for n_times == 1.
fig, ax = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows), squeeze=False)
ax = ax.ravel()

for axes, timestamp in zip(ax, unique_times):
    # Filter data for current timestamp
    current_data = era5_nuts[era5_nuts["time"] == timestamp]

    # Create the plot
    current_data.plot(
        ax=axes,
        column="tp",
        legend=True,
        cmap="coolwarm",
        markersize=0.5,
        legend_kwds={"shrink": 0.8},
    )

    axes.set_title(
        f"Precipitation (tp) - {pd.to_datetime(timestamp).strftime('%Y-%m')}"
    )
    axes.set_xlabel("Longitude")
    axes.set_ylabel("Latitude")

# Remove the empty panels of the last row when the number of timestamps is
# not a multiple of n_cols.
for unused_axes in ax[len(unique_times):]:
    fig.delaxes(unused_axes)

plt.tight_layout()
fig.savefig("nuts3_tp.png", dpi=300)

In [None]:
# The population file is annual (see its name), so only the January
# timestamps are mapped, four maps per row.
january_times = [t for t in unique_times if pd.Timestamp(t).month == 1]
n_cols = 4
n_times = len(january_times)
n_rows = -(-n_times // n_cols)  # ceiling division

In [None]:
# One NUTS choropleth of total population per January timestamp.
# squeeze=False always returns a 2-D array of Axes, so ravel() yields a flat
# sequence. The old `[ax]` fallback broke exactly when a single January is
# present (n_times == 1).
fig, ax = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows), squeeze=False)
ax = ax.ravel()

for axes, timestamp in zip(ax, january_times):
    # Filter data for current timestamp
    current_data = era5_nuts[era5_nuts["time"] == timestamp]

    # Create the plot
    current_data.plot(
        ax=axes,
        column="total-population",
        legend=True,
        cmap="coolwarm",
        markersize=0.5,
        legend_kwds={"shrink": 0.8},
    )

    axes.set_title(f"Total Population - {pd.to_datetime(timestamp).strftime('%Y-%m')}")
    axes.set_xlabel("Longitude")
    axes.set_ylabel("Latitude")


# Remove empty subplots, if any. Slice by the number of plotted timestamps
# instead of relying on the loop index leaking out of the loop.
for unused_axes in ax[len(january_times):]:
    fig.delaxes(unused_axes)

plt.tight_layout()
fig.savefig("nuts3_total_population.png", dpi=300)

## Read data from the model backend, plotting over NUTS regions

In [None]:
# netCDF output of the JModel backend.
jmodel_output = data_folder.joinpath("in", "output_JModel_global.nc")

In [None]:
# Open the JModel output as an xarray Dataset.
ds_jmodel = xr.open_dataset(jmodel_output)

In [None]:
# Inspect the JModel output dataset.
ds_jmodel

In [None]:
# Plot R0 from the JModel output, one panel per time step, four per row.
# The previous comment wrongly described this as ERA5 t2m/tp data.
ds_jmodel["R0"].plot.pcolormesh(
    col="time", col_wrap=4, cmap="coolwarm", robust=True, figsize=(15, 10)
)
plt.savefig("era5_2016_2017_plots_R0.png", dpi=300)
plt.show()

In [None]:
# TODO: plot the model output over NUTS regions.
# This currently lives in another notebook; we still have to decide how to split the material between them.