In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import dask
from scipy.stats import theilslopes
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import os

# Root directory of the per-model inputs; get_all_trends looks for files
# under <BASE_FOLDER>/<model>/<experiment>/.
BASE_FOLDER = "/gws/nopw/j04/aopp/tildah/HW_metrics/"

# Model names, each used verbatim as a directory name under BASE_FOLDER.
MODELS = [
    "CanESM5",
    "ACCESS-ESM1-5",
    "CMCC-CM2-SR5",
    "HadGEM3-GC31-LL",
    "MIROC6",
    "MPI-ESM1-2-LR",
    "NorESM2-LM",
]

# Experiment names, each used verbatim as a directory name under a model.
EXPERIMENTS = [
    "historical",
    "hist-GHG",
    "hist-aer",
    "hist-nat",
    "hist-sol",
    "hist-volc",
    "hist-totalO3",
]

# Reference ERA5 dataset. In the code visible here only its lat/lon
# coordinates are used, as the target grid for interpolation. Latitude is
# forced to ascending order so that label slices written as
# slice(low, high) select the expected rows.
ERA5 = xr.open_dataset("~/Internship/model_map_creation/data/ERA5_historical_trends.nc")
if ERA5['lat'][0] > ERA5['lat'][-1]:
    ERA5 = ERA5.sortby('lat')

# Land-sea mask, renamed to the lat/lon naming used by the rest of the file.
ERA5_mask = xr.open_dataset('~/ERA5_lsm_NH.nc')
ERA5_mask = ERA5_mask.rename({
    "longitude": "lon",
    "latitude": "lat",
    "time": "valid_time"
})

# A label slice has to run in the same direction as the coordinate itself:
# slice(72, 36) on an ascending latitude axis selects nothing, which would
# turn every masked mean into NaN without any error. Pick the direction from
# the data instead of assuming the file is stored north-to-south.
mask_lat_descending = bool(ERA5_mask['lat'][0] > ERA5_mask['lat'][-1])
mask_lat_window = slice(72, 36) if mask_lat_descending else slice(36, 72)
ERA5_mask = ERA5_mask.sel(lon=slice(-25, 45), lat=mask_lat_window)

# 1 where the mask value is strictly positive, 0 elsewhere.
land_mask_binary = xr.where(ERA5_mask > 0, 1, 0)
lsm = land_mask_binary['lsm']

# (label, year slice) pairs. The labels double as the variable names under
# which the per-period trends are looked up later in the file.
TIME_PERIODS = [
    (f"{first}-{last}", slice(first, last))
    for first, last in ((1940, 1979), (1980, 2020), (1940, 2020))
]

# Latitude bounds of the domain and the latitude separating North from South.
lat1, latmid, lat2 = 36, 50, 72

# Region name -> callable that sub-selects a field by latitude. The bounds
# are read from the module-level names each time a selector is called.
REGIONS = dict(
    Europe=lambda field: field,  # whole domain, returned unchanged
    North=lambda field: field.sel(lat=slice(latmid, lat2)),
    South=lambda field: field.sel(lat=slice(lat1, latmid)),
)

def theil_sen_slope(y, time):
    """Return the Theil-Sen slope of *y* against *time*, skipping NaNs in *y*.

    NaN is returned when fewer than five non-NaN samples are available.
    """
    valid = np.logical_not(np.isnan(y))
    if np.count_nonzero(valid) >= 5:
        # theilslopes returns (slope, intercept, low_slope, high_slope).
        return theilslopes(y[valid], time[valid])[0]
    return np.nan


def compute_trends_vectorized(ds, periods=None):
    """Compute Theil-Sen trends along ``year`` for several year windows.

    Parameters
    ----------
    ds : xarray.Dataset
        Must provide a ``year`` coordinate as well as ``lat`` and ``lon``.
    periods : iterable of (label, (start, end)), optional
        Year windows, inclusive at both ends. Defaults to 1940-1979,
        1980-2020 and 1940-2020.

    Returns
    -------
    xarray.Dataset
        One variable per period label holding the slope multiplied by 10
        (a per-decade trend when ``year`` is in calendar years).

    Notes
    -----
    The result is keyed by period label only. If ``ds`` holds several
    variables with a ``year`` dimension, each one overwrites the previous and
    only the last survives; a warning is issued in that case. Variables
    without a ``year`` dimension are skipped.
    """
    import warnings

    if periods is None:
        periods = (
            ("1940-1979", (1940, 1979)),
            ("1980-2020", (1980, 2020)),
            ("1940-2020", (1940, 2020)),
        )

    time = ds.coords['year'].values  # Extract time as NumPy array

    # Only variables that actually vary along `year` can be given a trend.
    year_vars = [name for name in ds.data_vars if 'year' in ds[name].dims]
    if len(year_vars) > 1:
        warnings.warn(
            "compute_trends_vectorized: several variables have a 'year' "
            f"dimension ({year_vars}); only the trends of '{year_vars[-1]}' "
            "are kept.",
            stacklevel=2,
        )

    # The window masks depend only on the year axis, so build them once.
    windows = [
        (label, (time >= start) & (time <= end))
        for label, (start, end) in periods
    ]

    trends = {}
    for var in year_vars:
        data = ds[var].load()  # Ensure data is loaded into memory if using Dask

        for label, time_mask in windows:
            time_subset = time[time_mask]
            data_subset = data.sel(year=time_mask)

            # Apply the scalar slope estimator to every series along `year`.
            slopes = xr.apply_ufunc(
                theil_sen_slope,
                data_subset,
                xr.DataArray(time_subset, dims=["year"], coords={"year": time_subset}),
                input_core_dims=[["year"], ["year"]],
                vectorize=True,
                dask="parallelized",
                output_dtypes=[np.float64]
            )

            # Slope is per unit of `year`; scale by 10.
            trends[f'{label}'] = slopes * 10

    return xr.Dataset(trends, coords={"lat": ds.lat, "lon": ds.lon})

def compute_member_region_means(trends):
    """Map each region name in REGIONS to the area-weighted mean of *trends*.

    *trends* needs ``lat`` and ``lon`` coordinates. Each region selector is
    applied to it and the selection is averaged over both dimensions using
    cos(latitude) weights. Any land-sea masking is expected to have been
    applied by the caller.
    """
    # cos(lat) weights, computed once on the full latitude axis.
    lat_weights = np.cos(np.deg2rad(trends.lat))
    return {
        region: pick(trends).weighted(lat_weights).mean(dim=("lat", "lon"))
        for region, pick in REGIONS.items()
    }

def get_all_trends(model, experiments=EXPERIMENTS, base_folder=BASE_FOLDER):
    """Collect regional trend values for every member file of one model.

    For each experiment, every ``.nc`` file in
    ``<base_folder>/<model>/<experiment>/`` is opened, restricted to the
    years 1940-2020, turned into per-period trends, put on the ERA5 grid
    (nearest neighbour), masked to land points and averaged per region.

    Parameters
    ----------
    model : str
        Model name, used as the directory name under *base_folder*.
    experiments : iterable of str
        Experiment directory names to scan. Missing or empty directories are
        reported on stdout and skipped.
    base_folder : str
        Root directory of the inputs.

    Returns
    -------
    pandas.DataFrame
        One row per (experiment, member, region, period) with the columns
        ``model, experiment, member, period, region, value``. ``value`` is
        rounded to four significant digits. The columns are present even
        when no file was found. Nothing is written to disk here.
    """
    columns = ["model", "experiment", "member", "period", "region", "value"]
    rows = []

    for exp in experiments:
        print(f"Processing: model={model}, exp={exp}")
        folder = os.path.join(base_folder, model, exp)
        if not os.path.isdir(folder):
            print(f"  Missing folder: {folder}  -> skipping")
            continue

        # os.listdir returns entries in arbitrary order; sort so that the
        # row order of the result is reproducible between runs.
        files = sorted(f for f in os.listdir(folder) if f.endswith(".nc"))
        if not files:
            print("  No .nc files found -> skipping")
            continue

        for file in files:
            # The member id is taken from the third '_'-separated field of
            # the file name; fall back to the bare file name otherwise.
            name_parts = file.split('_')
            if len(name_parts) > 2:
                member = name_parts[2]
            else:
                member = os.path.splitext(file)[0]

            # The context manager closes the handle that was actually opened,
            # including when one of the steps below raises.
            with xr.open_dataset(os.path.join(folder, file)) as opened:
                # restrict years and compute trends
                ds = opened.sel(year=slice(1940, 2020))
                trends_ds = compute_trends_vectorized(ds).squeeze()
                trends_ds = trends_ds.drop_vars(
                    [v for v in ['quantile', 'height', 'type'] if v in trends_ds]
                )
                trends = trends_ds.interp(lat=ERA5['lat'], lon=ERA5['lon'], method='nearest')
                trends = trends.where(lsm == 1)
                print(trends)

                # drop any dims/vars not needed
                if 'valid_time' in trends.coords or 'valid_time' in trends:
                    trends = trends.squeeze().drop_vars('valid_time')

                # region name -> area-weighted mean, one variable per period
                region_means = compute_member_region_means(trends)

                for rname, da in region_means.items():
                    for period_label, _ in TIME_PERIODS:
                        value = float(da[period_label].values)
                        value = float(f"{value:.4g}")  # 4 significant digits
                        rows.append({
                            "model": model,
                            "experiment": exp,
                            "member": member,
                            "period": period_label,
                            "region": rname,
                            "value": value
                        })

    # Explicit columns keep the header stable when `rows` is empty.
    return pd.DataFrame(rows, columns=columns)

# Build the per-member trend table of every model and write each one to its
# own CSV file in the current working directory.
for model in MODELS:
    df_all = get_all_trends(model=model)
    csv_name = f"{model}_experiment_member_region_period_trends.csv"
    df_all.to_csv(csv_name, index=False)

In [4]:
# ERA5: regional trend values computed the same way as for the model members.
rows = []

# The context manager closes the handle that was actually opened, including
# when one of the steps below raises.
with xr.open_dataset('~/Internship/ERA5/ERA5_HWF_with2025.nc') as era5_file:
    # restrict years and compute trends
    ds = era5_file.sel(year=slice(1940, 2020))
    trends_ds = compute_trends_vectorized(ds).squeeze()
    trends_ds = trends_ds.drop_vars(
        [v for v in ['quantile', 'height', 'type'] if v in trends_ds]
    )
    trends = trends_ds.interp(lat=ERA5['lat'], lon=ERA5['lon'], method='nearest')
    trends = trends.where(lsm == 1)

    # drop any dims/vars not needed
    if 'valid_time' in trends.coords or 'valid_time' in trends:
        trends = trends.squeeze().drop_vars('valid_time')

    # region name -> area-weighted mean, one variable per period
    region_means = compute_member_region_means(trends)

    # one row per (region, period); the model/experiment/member columns are
    # all filled with "ERA5" so the table has the same layout as the model CSVs
    for rname, da in region_means.items():
        for period_label, _ in TIME_PERIODS:
            value = float(da[period_label].values)
            value = float(f"{value:.4g}")  # 4 significant digits
            rows.append({
                "model": "ERA5",
                "experiment": "ERA5",
                "member": "ERA5",
                "period": period_label,
                "region": rname,
                "value": value
            })

ERA5_avg = pd.DataFrame(rows)

print(ERA5_avg)
ERA5_avg.to_csv("ERA5_region_period_trends.csv", index=False)

  model experiment member     period  region    value
0  ERA5       ERA5   ERA5  1940-1979  Europe -0.03117
1  ERA5       ERA5   ERA5  1980-2020  Europe  1.97000
2  ERA5       ERA5   ERA5  1940-2020  Europe  0.49160
3  ERA5       ERA5   ERA5  1940-1979   North  0.02410
4  ERA5       ERA5   ERA5  1980-2020   North  1.26800
5  ERA5       ERA5   ERA5  1940-2020   North  0.32500
6  ERA5       ERA5   ERA5  1940-1979   South -0.08450
7  ERA5       ERA5   ERA5  1980-2020   South  2.67000
8  ERA5       ERA5   ERA5  1940-2020   South  0.65500
