In [13]:
from pathlib import Path
import xarray as xr
import numpy as np
import pandas as pd
import re
import cftime
from datetime import datetime
from dask.diagnostics import ProgressBar
from scipy.stats import t
from scipy.stats import linregress
from eofs.xarray import Eof
from eofs.examples import example_data_path
import calendar
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib import gridspec
import matplotlib.path as mpath
from matplotlib.ticker import MaxNLocator
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.patches as mpatches
from types import SimpleNamespace


In [34]:
class EnsembleMemberCalculations:
    def __init__(self, data_path: Path, model=None):
        #use self. to access attributes
        self.data_path = Path(data_path)
        
        #referencing the parent model and setting up shortcuts
        self.model = model
        self.experiment = model.experiment if model else None
        self.all_experiments = self.experiment.all_experiments if self.experiment else None


        #extracting the member_id (ensemble number) from the filename as named in CMIP6 convention
        match = re.search(r"(r\d+i\d+p\d+f\d+)", self.data_path.stem)
        if not match:
            raise ValueError(f"Could not extract the member_id from filenmae: {self.data_path.name}")

        #selecting the ensemble number and assigning it to the member ID
        self.member_id = match.group(1)


        self.data = None

    def load_data(self):
        #loads the data of an ensemble member
        self.data = xr.open_dataset(self.data_path)
        return self.data

    def get_calendar_type(self):
        if self.model is None:
            raise RuntimeError("Model is not set")
        return self.model.calendar_type

    def get_time_bounds(self):
        """
        Access the model's time_bounds - each model has a different calendar thus type of time to index/slice by
        """
        if self.model is None or self.model.time_bounds is None:
            raise RuntimeError("Model or time_bounds is not set")
        return self.model.time_bounds

    def select_data(self, varname: str):
        """
        Just selecting some data from the time bounds in model
        """
        ds = self.load_data()
        start, end, delta = self.model.time_bounds
        ds = ds[varname].sel(time=slice(start, end))
        
        return ds

    def crop_to_domain_ensemble(self, da: xr.DataArray):
        """
        Crops the ensemble members data (anomaly, raw data etc.) to the lat and lon of the domain
        lat_bounds = (lat_min, lat_max)
        lon_bounds = (lon_min, lon_max)
        Assuming data of the format -180 to 180, -90 to 90???
        Will add another function to check and if not flag error
        Will be a bit for if model is era5 shift the coords at somepoint in the processing data stages.
        Need to already have selected the variable"""

        if self.all_experiments and self.all_experiments.lat_bounds and self.all_experiments.lon_bounds:
            da = da.sel(
                lat=slice(self.all_experiments.lat_bounds[0], self.all_experiments.lat_bounds[1]),
                lon = slice(self.all_experiments.lon_bounds[0], self.all_experiments.lon_bounds[1])
            )
        return da

    def calc_anomaly(self, output_file: Path, varname: str):
        """
        Will calculate the anomaly
        """
        #print(output_file)
        #print(self.model.name)
        if output_file.exists():
            #print(f"Loading the exisiting file for the anomaly from: {output_file}")
            anomaly = xr.open_dataset(output_file)[varname]
            anomaly_cropped = self.crop_to_domain_ensemble(anomaly)
            return anomaly_cropped

   # def cal_seasonal_mean(self)

        

class ModelCalculations:
    def __init__(self, folder: Path, experiment):
        self.folder = Path(folder)
        self.experiment = experiment
        self.all_experiments = self.experiment.all_experiments if self.experiment else None

        #can now set the model name using the folder name
        #the .name takes the last bit of the filepath (which in this case is the model)
        self.name = self.folder.name
        
        #creating my dictionary of ensemble members
        self.members = {}
        for i, file in enumerate(self.folder.glob("*.nc")):
            if i >= 3:
                break
            member = EnsembleMemberCalculations(file, model=self)
            self.members[member.member_id] = member
        
        #basically I pick the first member (load in just one)
        #doing this to then use to find the calendar type
        #and then find the time bounds as this will be the same time (correct datetime) per model
        member1 = next(iter(self.members.values()))
        member1.load_data()

        #find the calendar type
        time_var = member1.data["time"]
        self.calendar_type = type(time_var.time.values[0])
        self.time_bounds = None

    def set_time_bounds(self, start_year: int, end_year: int):
        #selecting the correct datetime format for the chosen time period
        if issubclass(self.calendar_type, cftime.DatetimeNoLeap):
            start = cftime.DatetimeNoLeap(start_year,1,16)
            end = cftime.DatetimeNoLeap(end_year,12,16)
        elif issubclass(self.calendar_type, cftime.Datetime360Day):
            start = cftime.Datetime360Day(start_year,1,16)
            end = cftime.Datetime360Day(end_year,12,16)
        else:
            start = datetime(start_year,1,16)
            end = datetime(end_year,12,16)

        #assumes that the last year is filled with data
        delta = end.year - start.year + 1

        self.time_bounds = (start, end, delta)
        return self.time_bounds

    def crop_to_domain_model(self, da: xr.DataArray):
        """
        Crops the ensemble members data (anomaly, raw data etc.) to the lat and lon of the domain
        lat_bounds = (lat_min, lat_max)
        lon_bounds = (lon_min, lon_max)
        Assuming data of the format -180 to 180, -90 to 90???
        Will add another function to check and if not flag error
        Will be a bit for if model is era5 shift the coords at somepoint in the processing data stages.
        Need to already have selected the variable"""

        if self.all_experiments and self.all_experiments.lat_bounds and self.all_experiments.lon_bounds:
            da = da.sel(
                lat=slice(self.all_experiments.lat_bounds[0], self.all_experiments.lat_bounds[1]),
                lon = slice(self.all_experiments.lon_bounds[0], self.all_experiments.lon_bounds[1])
            )
        return da


    def calc_ensemble_mean(self, varname: str):
        """
        Calculate the SPATIAL ensemble mean for a variable
        Restrict it to the time bounds
        Could in future restrict to an area
        If the file exists load this instead
        """
        output_file = self.all_experiments.output_dir / f"ens_mean_spat/{model_name}/{varname}_mon_{self.name}_{model_name}_spatial_DJF_EM_1850-2015.nc"
        
        if output_file.exists():
            print(f"Loading the exisiting file for the spatial ensemble mean from: {output_file}")
            ens_mean = xr.open_dataset(output_file)[varname]
            return ens_mean
        
        #Get all the filepaths from the ensemble members
        #this line basically is getting the filepaths of the ensemble member class objects - i think?
        file_paths = [member.data_path for member in self.members.values()]

        #then opening them all together
        ds = xr.open_mfdataset(
            file_paths,
            combine="nested",
            concat_dim="member",
            parallel=True,
            chunks={"member": 1}
        )[varname]

        #selecting the time period
        start,end, delta = self.time_bounds
        ds = ds.sel(time=slice(start,end))

        #calculating the ensemble mean across the members
        ens_mean = ds.mean(dim="member")
    
        # Compute the result with a progress bar
        print(f"Computing the ensemble mean for {self.name}...")
        with ProgressBar():
            ens_mean.compute()#.to_netcdf(output_file)

        ens_mean.to_netcdf(output_file)
        print(f'Ensemble mean NOT saved to {output_file}')

        return ens_mean

    def calc_seasonal_mean_per_model(self, varname:str):
        """
        Calculating the seasonal ensemble mean (could actually just be the seasonal mean for ay single file!!!)
        Only accpeting one file so per model
        Will save the seasonal mean as well as outputting it
        Wil crop to the already saved time periods.
        Will at some point make more versatille to choose the months of the year etc. as a list of int.
        """

        #select the time bounds and seasons
        start,end = self.time_bounds
        print(start,end)
        months = self.all_experiments.season
        season = "".join([calendar.month_abbr[m][0] for m in months])
        print(months, season)
        
        seas_EM_output_file = self.all_experiments.output_dir / f"ens_mean_spat/{self.experiment.name}/{self.name}/psl_mon_{self.experiment.name}_{self.name}_spatial_{season}_EM_1850-2015.nc"
        
        if seas_EM_output_file.exists():
            print(f"Loading the seasonal mean for model: {self.name}")
            seasonal_mean = xr.open_dataset(seas_EM_output_file)[varname]
            return seasonal_mean

        #loading the ensemble mean
        ens_mean = self.calc_ensemble_mean(varname)

        #checking it is a datetime object - already done???
        #ens_mean['time'] = xr.decode(ens_mean).time

        #selecting the ensemble mean for the specified time (might load full ensemble mean as this is already calculated and then slice here)
        ens_mean = ens_mean.sel(time=slice(start,end))

        #creating mask for the season and grouping the months
        mask = ens_mean['time'].dt.month.isin(months)
        months_in_seas = ens_mean.sel(time=mask)

        #assigning a 'season year'
        season_year = months_in_seas['time'].dt.year

        #general code for if season WRAPS then fix the year to avg over
        if months[0] > months[-1]:
            season_year = xr.where(months_in_seas['time'].dt.month >= months[0],
                                  season_year +1,
                                  season_year)

        #assign as "year" instead of "season_year" for future bits
        months_in_seas = months_in_seas.assign_coords(year=season_year)
        
        #now groupby and average over the years
        ens_mean_seas = months_in_seas.groupby('year').mean(dim='time')
        #ens_mean_seas.to_netcdf(seas_EM_output_file)

        print('calc seas_mean')
        
        return ens_mean_seas

    def calc_linear_trend_per_model(self, varname:str):
        """
        Calcualte the linear trend from the seasonal ensemble mean
        Full stats file
        Converts from per unit time to just the change over the entire period (delta)"""
        #getting the timebounds
        start,end, delta = self.time_bounds
             
        #create and check if output_file exists
        output_file = self.all_experiments.output_dir / f"trend_calc_LESFMIP/linear_regression/NAO/{self.experiment.name}/{self.name}/psl_mon_{self.experiment.name}_{self.name}_DJF_linear_trend_1850-2015_stats.nc"

        if output_file.exists():
            print(f"loading the linear trend for model: {self.name}")
            return xr.open_dataset(output_file)
            
        #call the seasonal ensemble mean method
        seas_ens_mean = self.calc_seasonal_mean_per_model(varname)
        print(seas_ens_mean)
    
        time = seas_ens_mean['year'].values
        lat = seas_ens_mean['lat'].values
        lon = seas_ens_mean['lon'].values
        time_numeric = np.arange(len(time))
    
        slope = np.full((len(lat), len(lon)), np.nan)
        intercept = np.full((len(lat), len(lon)), np.nan)
        p_value = np.full((len(lat), len(lon)), np.nan)
        stderr = np.full((len(lat), len(lon)), np.nan)

        #now solving for the slope and other stats (multiplying through by delta to convert from per index
        #(year) to change over the entire period)
        for i in range(len(lat)):
            for j in range(len(lon)):
                ts = seas_ens_mean[:, i, j].values
                if np.all(np.isfinite(ts)):
                    reg = linregress(time_numeric, ts)
                    slope[i, j] = reg.slope
                    intercept[i, j] = reg.intercept
                    p_value[i, j] = reg.pvalue
                    stderr[i, j] = reg.stderr
    
        n = len(time_numeric)
        df = n - 2
        alpha = 0.05
        t_crit = t.ppf(1 - alpha/2, df)
    
        ci_lower = slope - t_crit * stderr
        ci_upper = slope + t_crit * stderr
    
        slope_da = xr.DataArray(slope, coords=[lat, lon], dims=["lat", "lon"], name="slope")
        intercept_da = xr.DataArray(intercept, coords=[lat, lon], dims=["lat", "lon"], name="intercept")
        p_value_da = xr.DataArray(p_value, coords=[lat, lon], dims=["lat", "lon"], name="p_value")
        ci_lower_da = xr.DataArray(ci_lower, coords=[lat, lon], dims=["lat", "lon"], name="slope_CI_lower")
        ci_upper_da = xr.DataArray(ci_upper, coords=[lat, lon], dims=["lat", "lon"], name="slope_CI_upper")
    
        # Save to one combined netCDF file
        trend_stats = xr.Dataset({
            "slope": slope_da,
            "intercept": intercept_da,
            "p_value": p_value_da,
            "slope_CI_lower": ci_lower_da,
            "slope_CI_upper": ci_upper_da
        })
        
        #combined_ds.to_netcdf(output_file)
        return trend_stats

    def calc_anomalies_all_members(self, varname: str,):
        """
        This will calculate the anomaly across the ensemble members
        Could maybe do this for all the other individual ensemble calcs needed to calc anomaly?
        Will return a dictionary {member_id: anomaly_dataarray}
        Only calculates for the hisotircal??? - no need it for all!"""

        self.all_experiments.output_dir.mkdir(parents=True, exist_ok=True)

        results = {}
        for member_id, member in self.members.items():
            output_file = self.all_experiments.output_dir / f"psl_anomalies/{self.experiment.name}/{self.name}/psl_mon_{self.experiment.name}_{self.name}_{member_id}_DJF_anomaly.nc"
            anomaly = member.calc_anomaly(output_file, varname)
            results[member_id] = anomaly
        return results

    def calc_EOF_concat(self, varname: str, max_modes: int):
        """
        Calculates the EOF from a concatenated list of anomalies for one model
        outputs the number of modes specified
        Calls on the anomalies calculated for each ensemble member
        This right now does all models not just historical - need an option for that somewhere???
        """
        output_file = self.all_experiments.output_dir / f"EOF/{self.experiment.name}/DJF/{self.name}/psl_mon_{self.experiment.name}_{self.name}_DJF_EOF_concat_1850-2015.nc"

        if output_file.exists():
            #print(f"Loading the exisiting file for the EOF for: {output_file}")
            EOF = xr.open_dataset(output_file)['eofs']#['__xarray_dataarray_variable__']
            #EOF.sel(mode=0).plot()
            #plt.show()
            return EOF


        else:
            print('concatenating')
            #preprocess anomalies method called
            #converting from Pa to hPa - bear in mind I will need to have this as an option to convert.
            anomalies = {k: v / 100 for k, v in self.calc_anomalies_all_members(varname).items()}
            #anomalies = self.calc_anomalies_all_members(varname)
            
            anomaly_list = list(anomalies.values())
            anomaly_concat = xr.concat(anomaly_list, dim="ensemble")
            anomaly_2d = anomaly_concat.stack(time=('ensemble', 'year')).reset_index('time', drop=True)
            anomaly_trans = anomaly_2d.transpose('time', 'lat', 'lon')
            
            coslat = np.cos(np.deg2rad(anomaly_trans.coords['lat'].values)).clip(0., 1.)
            wgts = np.sqrt(coslat)[...,np.newaxis]

            solver = Eof(anomaly_trans, weights=wgts)
            EOF = solver.eofs(neofs=max_modes).sel(mode=([0,max_modes-1]))

            #checking orthogonality
            #1. Get the first two PCs (time series of each mode)
            pcs = solver.pcs(npcs=max_modes, pcscaling=0)  # shape (time, mode)
    
            # 2. Compute their correlation / covariance
            pc_corr = np.corrcoef(pcs[:, 0], pcs[:, 1])[0, 1]
            print("Correlation between PC1 and PC2:", pc_corr)

            # 3. (Optional) Gram matrix of PCs
            G = np.dot(pcs.T, pcs)  # shape (2,2)
            print("Gram matrix of PCs:\n", G)

            #Check orthogonality
            EOF1 = EOF.sel(mode=0)
            EOF2 = EOF.sel(mode=1)
    
            inner = (EOF1 * EOF2).sum(dim=("lat", "lon"))
            norm1 = np.sqrt((EOF1**2).sum(dim=("lat", "lon")))
            norm2 = np.sqrt((EOF2**2).sum(dim=("lat", "lon")))
            cos_sim = inner / (norm1 * norm2)
            is_ortho = bool(np.isclose(inner, 0, atol=1e-10))
            
            print(f"Inner product: {inner:.3e}, Cosine similarity: {cos_sim:.3e}")
            if not is_ortho:
                print("⚠️ Warning: EOF1 and EOF2 are not orthogonal within tolerance!")

            # --- Regression maps for all modes ---
            pcs_all = solver.pcs(npcs=max_modes, pcscaling=0)  # shape (time, max_modes)
            regression_maps = []
        
            for mode_idx in range(max_modes):
                pc = pcs_all[:, mode_idx]
                pc_da = xr.DataArray(pc, dims="time", coords={"time": anomaly_trans.coords["time"]})
                reg_map = (anomaly_trans * pc_da).mean(dim="time") / pc_da.var(dim="time")
                regression_maps.append(reg_map)
        
            # Combine into a single DataArray with a 'mode' dimension
            regression_map_all = xr.concat(regression_maps, dim=pd.Index(range(max_modes), name='mode'))
            #regression_map_all.to_netcdf(output_file)
            EOF.to_netcdf(output_file)
            
        return EOF

    def projection(self, varname: str, max_modes):
        #need to somehow specify that its historical EOF ONLY - but maybe still calculate for the other ones - need all anomalies???

        output_file = self.all_experiments.output_dir / f"regression_patterns/NAO/psl_mon_historical_HadGEM3-GC31-LL_DJF_NAO_EOF_pattern_1850-2015.nc"

        if output_file.exists():
            print("loading EOF file for model:", self.name)
            return xr.open_dataset(output_file)

        trend = self.calc_linear_trend_per_model['slope']
        EOF = self.calc_EOF_concat['eofs']

        #weighting the trend ONLY (EOF already weighted)
        w = np.sqrt(np.cos(np.radians(trend['lat'])))
        w2d, _ = xr.broadcast(w, trend)

        #stacking the trend and EOF 
        trend_w = (trend * w2d).stack(spatial=('lat','lon')).values
        EOF = (EOF).stack(spatial=('lat','lon')).values #lat,lon,mode

        #transpose for the lstsq
        E_matrix = EOF.T

        #solve the weighted least squares
        c = np.linalg.lstsq(E_matrix, trend_w, rcond=None)[0]

        #reconstruct the coefficients
        #got to here need to figure out how to extract???
        #for i in range(0, max_modes) vibes then extract and just save
        #somehow return??? need to think more about the format

        return ('this')

            
            
        

class ExperimentCalculations:
    def __init__(self, folder: Path, all_experiments):
        """
        Initialise an experiment object
        Each folder inside this experiment is a model
        """
        self.folder = Path(folder)
        
        #the experiment name (resturns the last part of the name!!!)
        self.name = self.folder.name
        self.all_experiments = all_experiments

        #this is the dictionary of ModelCalculations objects
        self.models = {}

        #going through all the sub folders within this experiment folder
        #The creating a ModelCalculations object for models folders
        for model_folder in self.folder.iterdir():
            if model_folder.is_dir():
                model_name = model_folder.name

                #skipping CNRM-CM6-1 for now
                if model_name == 'CNRM-CM6-1':
                    continue
                    
                model = ModelCalculations(model_folder, experiment=self)
                self.models[model.name] = model
        
    def calc_ensemble_mean_all_models(self, varname: str):
        """
        Calculate the ensemble mean for each model within this experiment
        Save results to an output dir
        """


        self.all_experiments.output_dir.mkdir(exist_ok=True, parents=True)

        results = {}
        for model_name, model in self.models.items():
            print(self.name)

            ens_mean = model.calc_ensemble_mean(varname)
            results[model_name] = ens_mean

        return results
        
    def calc_seasonal_EM_all_models(self, varname: str):
        """
        Calculate the ensemble mean for each model within this experiment
        Save results to an output dir
        """

        self.all_experiments.output_dir.mkdir(exist_ok=True, parents=True)

        results = {}
        for model_name, model in self.models.items():
            print(self.name)
            
            seas_ens_mean = model.calc_seasonal_mean_per_model(varname)
            results[model_name] = seas_ens_mean

        return results

    def calc_linear_trend_all_models(self, varname: str):
        """
        Calculate the linear trend from the seasonal ensemble mean
        Save results to an output dir
        """

        self.all_experiments.output_dir.mkdir(exist_ok=True, parents=True)

        results = {}
        for model_name, model in self.models.items():
            print(self.name)
            
            trend = model.calc_linear_trend_per_model(varname)
            results[model_name] = trend

        return results


    def calc_anomalies_all_models(self, varname: str):
        """
        Calculate the anomalies across all ensemble members and models
        Will probably add in the steps that happen before - either call here or within calc_anomaly
        returns a dict of dict: {model_name: {member_id: anomaly_dataset}}
        """

        self.all_experiments.output_dir.mkdir(parents=True, exist_ok=True)


        results = {}
        for model_name, model in self.models.items():
            model_results = model.calc_anomalies_all_members(varname)
            results[model_name] = model_results

        return results

    def calc_EOF_concat_all_models(self, varname: str, max_modes: int):
        """
        To calculate the EOFs across all models
        """

        self.all_experiments.output_dir.mkdir(parents=True, exist_ok=True)

        results = {}
        for model_name, model in self.models.items():
            model_results = model.calc_EOF_concat(varname, max_modes)
            results[model_name] = model_results

        return results

    def projection_all_models(self, varname: str, max_modes: int):
        """
        Project linear trend onto ALL modes for each model
        """
        self.all_experiments.output_dir.mkdir(parents=True, exist_ok=True)

        results = {}
        for model_name, model in self.models.items():
            proj = model.projection(varname, max_modes)
            results[model_name] = proj

        return results
        


class AllDataComparisons:
    def __init__(self, folder: Path, output_dir: Path):
        """
        Initialise a comparisons object
        Will be used to comapre between all models and all experiments.
        Pass through the folder up to the experiments - will automatically sort for the experiment available
        """
        self.folder = Path(folder)
        self.output_dir = Path(output_dir)
        #the experiment name (resturns the last part of the name!!!)

        #this is the dictionary of ExperimentCalculations objects
        self.experiments = {}

        #going through all the sub folders within this experiment folder
        #The creating a ModelCalculations object for models folders
        for experiment_folder in self.folder.iterdir():
            if experiment_folder.is_dir():
                experiment = ExperimentCalculations(experiment_folder, all_experiments=self)
                self.experiments[experiment.name] = experiment
        #for the domain bounds
        self.lat_bounds = None
        self.lon_bounds = None

        self.season = None

    def summary(self, show_members: bool = False):
        """
        Print a summary of the experiments, models, and optionally ensemble members.
        """
        print(f"\n📊 Summary of AllDataComparisons: '{self.folder.name}'")
        print(f"  Experiments loaded: {len(self.experiments)}")

        for exp_name, exp in self.experiments.items():
            print(f"  └── Experiment: {exp_name} ({len(exp.models)} models)")

            for model_name, model in exp.models.items():
                print(f"      └── Model: {model_name} ({len(model.members)} members)")

                if show_members:
                    for member_id, member in model.members.items():
                        print(f"          └── Member: {member_id}")
                        
            
    def set_time_bounds_all(self, start_year: int, end_year:int):
        """
        Set the time bounds for all the models in this experiment
        """

        for experiment in self.experiments.values():
            for model in experiment.models.values():
                model.set_time_bounds(start_year, end_year)

    def set_domain(self, lat_bounds: tuple, lon_bounds: tuple):
        """
        Defining the coordingates for the domain for cropping
        """

        self.lat_bounds = lat_bounds
        self.lon_bounds = lon_bounds
        print(self.lat_bounds, self.lon_bounds)

    def set_season(self, season: list[int]):
        """
        Sets the season - list like 12,1,2 is DJF
        """

        self.season = season
        print(season)

    def crop_to_domain_all_exp(self, da: xr.DataArray):
        """
        Crops to the lat and lon of the domain
        lat_bounds = (lat_min, lat_max)
        lon_bounds = (lon_min, lon_max)
        Assuming data of the format -180 to 180, -90 to 90???
        Will add another function to check and if not flag error
        Will be a bit for if model is era5 shift the coords at somepoint in the processing data stages.
        Need to already have selected the variable"""

        if self and self.lat_bounds and self.lon_bounds:
            da = da.sel(
                lat=slice(self.lat_bounds[0], self.lat_bounds[1]),
                lon = slice(self.lon_bounds[0], self.lon_bounds[1])
            )
        return da

    def add_sum_experiment(self, varname:str):
        """
        Creating a 'sum' object-like thing to be added to the other experiments.
        New 'sum' in self.experiments"""

        sum_models = {}

        #getting all models names
        for model_name in self.get_all_model_names():
            model_trends = []

            for exp_name, exp in self.experiments.items():
                if exp_name == 'historical':
                    continue

                if model_name in exp.models:
                    print(model_name)
                    ds = exp.models[model_name].calc_linear_trend_per_model(varname)
                    model_trends.append(ds['slope'])
                    if 'slope' not in ds:
                        print(f"{model_name} in {exp_name} has no slope")
                        continue                    

            if not model_trends:
                print(f"No trends found for {model_name}")
                continue

            # Align grids and sum slopes across experiments
            aligned = xr.align(*model_trends, join="outer")
            slope_sum = sum(a.fillna(0) for a in aligned)
            ds_sum = xr.Dataset({"slope": slope_sum})
            print(f"Slope sum for {model_name} computed")

            # Fake model object with required method
            sum_models[model_name] = SimpleNamespace(
                name=model_name,
                calc_linear_trend_per_model=lambda v, ds=ds_sum: ds
            )

        # Add new synthetic experiment
        self.experiments["sum"] = SimpleNamespace(
            name="sum",
            models=sum_models
        )

        
    def project_trend_EOF(self, varname: str, max_modes: int):
        """
        Projection of the linear trend onto the number of EOFs specified
        Will weight the trend to match the EOFs subspace (EOFs already weighted)
        If it already exists then can just upload the files
        Outputs:
            - Weighted linear trend
            - Projections for number of EOF modes specified
            - Residual (weighted linear trend - sum(projections))

        Returns the nested dict:
            experiment -> model -> xr.Dataset
            with variables: projection(projection_mode1, projection_mode2, ..., projection_modeN), residual, weighted_trend
        """
        delta = 165
        
        #this is basically selecting all models' EOFs for the historical and raising an error if not.
        if 'historical' in self.experiments:
            exp = self.experiments['historical']
            EOFs = exp.calc_EOF_concat_all_models(varname, max_modes)
        else:
            raise ValueError(f"historical experiment has not been found")

        all_results = {}
        
        for experiment in self.experiments.values():
            all_results[experiment.name] = {}
            #Now looping through all the experiments
            trend_dict_hpa = {
                model_name: trend_ds * delta / 100
                for model_name, trend_ds in experiment.calc_linear_trend_all_models(varname).items()
            }
            trend_dict = experiment.calc_linear_trend_all_models(varname)
            for model_name, trend_ds in trend_dict.items():
                #select the correct models EOFs and Trends and then cropping the trend (calculated for entire globe)
                eofs = EOFs[model_name]
                trend_var = trend_ds['slope']
                trend = self.crop_to_domain_all_exp(trend_var)

                #apply weights ONLY to the trend
                w = np.sqrt(np.cos(np.radians(trend['lat'])))
                w2d, _ = xr.broadcast(w, trend)
                trend_w2d = trend * w2d
    
                #stack the spatial dimensions (to solve the least squares)
                T = trend_w2d.stack(spatial=('lat', 'lon'))
                E = eofs.stack(spatial=('lat', 'lon')).transpose('spatial', 'mode')
    
                #now solve the least squares to get the coefficients, T (n_spatial,) E (n_spat, n_modes)
                T_vals = T.values
                E_vals = E.values
                c_vals, _, _, _ = np.linalg.lstsq(E_vals, T_vals, rcond=None)
                c = xr.DataArray(c_vals, dims=['mode'], coords={'mode': E.mode})


                #setting up so that there is a dict for projections
                projections = (c*eofs).transpose('mode', 'lat', 'lon')
                projections.name = "projection"

                #the total reconstrcution
                reconstructed = projections.sum(dim='mode')
                
                #finding the residual and adding it and the weighted trend to the nested dict
                residual = trend_w2d - reconstructed
                residual.name = "residual"

                #create the dataset
                proj_ds = xr.Dataset({
                    "projections": projections,
                    "residual": residual,
                    "weighted_trend": trend_w2d
                })

                #storing all the projections
                all_results[experiment.name][model_name] = proj_ds
                
                
        return all_results

    def calc_R2(self, varname: str, max_modes: int):

        all_results = self.project_trend_EOF(varname, max_modes)
        R2_vals = {}

        #remember this is like looping through objects so need to extract name via .name
        for experiment in self.experiments.values():
            R2_vals[experiment.name] = {}
            
            for model in experiment.models.values():
                ds = all_results[experiment.name][model.name]
                
                trend_w = ds['weighted_trend']
                residual = ds['residual']
                projections = ds['projections']

                total_var = (trend_w**2).sum().item()
                
                #creating R2 datastructure
                R2_vals[experiment.name][model.name] = {
                    "total": total_var,
                    "residual": float(( (residual**2).sum() / total_var).item()),
                    "projections": {}
                }

                for mode in projections.mode.values:
                    proj = projections.sel(mode=mode)
                    R2 = ( (proj**2).sum() / total_var ).item()
                    R2_vals[experiment.name][model.name]["projections"][int(mode)] = float(R2)

        return R2_vals

    def R2_plot(self, varname: str, max_modes: int):
        """
        Creates the R2 plot across all the experiments and models
        """
        R2 = self.calc_R2(varname, max_modes)
        print(R2)
        
        experiments = list(R2.keys())
    
        # Collect all unique models
        models = set()
        for exp_name in experiments:
            models.update(R2[exp_name].keys())
        models = sorted(models)
        
        # Marker styles for models
        marker_styles = ['o', 's', '^', 'D', 'v', 'P', '*', 'X', 'h', '1']
        model_markers = {model: marker_styles[i % len(marker_styles)] for i, model in enumerate(models)}
        
        nrows = max_modes + 1  # weighted trend + projections + residual
        fig, axes = plt.subplots(nrows=nrows, ncols=1, figsize=(10, 4*nrows), sharex=True)
    
        # Row labels
        row_labels = [f'Mode {i}' for i in range(max_modes)] + ['Residual']
    
        for i, row_label in enumerate(row_labels):
            ax = axes[i]
            for model in models:
                y_vals = []
                x_vals = []
                for j, exp_name in enumerate(experiments):
                    if model not in R2[exp_name]:
                        continue  # skip missing experiment for this model
                    model_dict = R2[exp_name][model]
                    if i == nrows-1:
                        y_val = model_dict['residual']
                    else:
                        y_val = model_dict['projections'].get(i, np.nan)
                    y_vals.append(y_val)
                    x_vals.append(j)
                if y_vals:
                    ax.scatter(x_vals, y_vals, marker=model_markers[model], s=100, label=model if i==0 else "")
            ax.set_ylabel("R²", fontsize=14)
            ax.set_title(row_label, fontsize=16)
    
        axes[-1].set_xticks(range(len(experiments)))
        axes[-1].set_xticklabels(experiments, rotation=45, ha='right', fontsize=12)
        axes[0].legend(loc='upper right', fontsize=12)
        fig.suptitle("R² for all models and experiments", fontsize=18)
        plt.tight_layout()
        plt.savefig('R2_all_models_and_exp.png')
        plt.show()

    def get_experiments_per_model(self, model:str):
        """
        Returns the dictionary of all experiments for a given model
        Take the model name in and returns dictionary
        """
        return {exp_name: exp_obj.models[model]
            for exp_name, exp_obj in self.experiments.items()
                if model in exp_obj.models
               }

    def get_all_model_names(self) -> list[str]:
        """
        Returns a list of models (list[str)]) across all experiments
        """
        model_names = ({
            model_name
            for exp_obj in self.experiments.values()
            for model_name in exp_obj.models.keys()
        })
        return sorted(model_names)

    def get_all_exp_names(self) -> list[str]:
        """
        Returns a list of models (list[str)]) across all experiments
        """
        exp_names = ({
            exp_name
            for exp_name in self.experiments.keys()
        })
        return sorted(exp_names)
                
    def projection_steps_plot(self, varname: str, max_modes: int):
        """
        Create the projection plot basically plots the following for each model and experiment
        - weighted trend
        - bit explained by each mode
        - the residual
        currently manually setting the max and min - could be good to have a speerate method doing this?...
        """

        #getting the data from all models and experiments (the maps and the R2 values)
        all_results = self.project_trend_EOF(varname, max_modes)
        R2 = self.calc_R2(varname, max_modes)
        
        cmap='seismic'
        norm_all = mcolors.TwoSlopeNorm(vmin=-2.25, vcenter=0, vmax=2.25)
        levels = np.arange(-2.25,2.5,0.05)
        lon_min, lon_max = self.lon_bounds
        lat_min, lat_max = self.lat_bounds

        #okay so getting the list of all model names across all experiments.
        #then finding the names/keys for the experimetns in the model
        #then plotting if they exist of do the missing step if not.
        
        exp_names = self.get_all_exp_names()
        
        for model_name in self.get_all_model_names():
            fig, ax = plt.subplots(
                nrows=max_modes + 2,
                ncols=len(exp_names),
                figsize=(28,15),
                subplot_kw={"projection": ccrs.PlateCarree()}
            )
            
            for j, exp_name in enumerate(self.get_all_exp_names()):
                try:
                    ds = all_results[exp_name][model_name]
                    R2_vals = R2[exp_name][model_name]

                    #getting the min and max of the weighted trend.
                    wt_min = float(ds['weighted_trend'].min())
                    wt_max = float(ds['weighted_trend'].max())
                    print(f"{model_name} - {exp_name} - weighted_trend: min={wt_min:.2f}, max={wt_max:.2f}")

                    #plotting the weighted trend if this experiment exists for the model
                    contour = ax[0,j].contourf(ds['lon'], ds['lat'], ds['weighted_trend'],
                        cmap=cmap, norm=norm_all,
                        levels=levels, transform=ccrs.PlateCarree()
                    )
                    ax[0,j].set_title(f"{exp_name}\n", fontsize=30)

                    #projection rows1...rown
                    for i, mode in enumerate(range(max_modes), start=1):
                        ds["projections"].sel(mode=mode).plot.contourf(
                            ax=ax[i, j],
                            cmap=cmap, norm=norm_all,
                            levels=levels, transform=ccrs.PlateCarree(),
                            add_colorbar=False
                        )
                        ax[i,j].set_title("")
                        ax[i,j].text(0.5, -0.05, f"r² = {R2_vals['projections'][i-1]:.2f}", transform=ax[i, j].transAxes, ha='center', va='top', fontsize=28)
                    
                    #residual (last row)
                    ds["residual"].plot.contourf(
                        ax=ax[max_modes + 1, j], cmap=cmap, norm=norm_all,
                        levels=levels, transform=ccrs.PlateCarree(),
                        add_colorbar=False
                    )
                    ax[3,j].text(0.5, -0.05, f"r² = {R2_vals['residual']:.2f}", transform=ax[3, j].transAxes, ha='center', va='top', fontsize=28)
                    ax[3,j].set_title("")
                        
                except KeyError:
                #if the experiment is missing
                    for i in range(max_modes + 2):
                        ax[i,j].set_title(f"{exp_name}")
                        ax[i,j].text(
                            0.5,0.5, "Missing",
                            ha="center", va="center",
                            fontsize=28, color="gray", transform=ax[i,j].transAxes
                        )
                        ax[i,j].set_xticks([])
                        ax[i,j].set_yticks([])
                        ax[i,j].coastlines
                # Add coastlines and formatting
                for i in range(4):
                    ax[i, j].set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
                    ax[i, j].add_feature(cfeature.COASTLINE, linewidth=1)
                    #ax[i, j].set_xlabel('lon', fontsize=28)
                    #ax[i, j].set_ylabel('lat', fontsize=28)
                    ax[i, j].set_aspect('auto')
                    if i == 0:
                        ax[0,j].set_title(f"{exp_name}\n", fontsize=30)
                    else:
                        ax[i,j].set_title("")
                    
                # Add colorbar
            ticks = [-2,-1,0,1,2]

            cax = fig.add_axes([0.92, 0.08, 0.015, 0.8])
            cbar = fig.colorbar(contour, cax=cax, orientation='vertical', ticks=ticks)
            cbar.set_label('MSLP (hPa)', fontsize=32)
            cbar.ax.tick_params(labelsize=30)

            # add y-axis labels only on the leftmost column
            #setting up the row labels
            row_labels = ["Weighted trend"] + [f"Mode {m}" for m in range(max_modes)] + ["Residual"]
            # Add row labels along the left-hand side using fig.text
            for i, label in enumerate(row_labels):
                # y coordinate is relative to the whole figure
                # compute vertical position from row index
                y = (ax[i, 0].get_position().y0 + ax[i, 0].get_position().y1) / 2
                fig.text(
                    0.1, y, label,
                    va="center", ha="center", rotation=90, fontsize=30
                )
        
            fig.suptitle(f"Projection steps for {model_name}", x=0.45, fontsize=32)
            plt.tight_layout
            for file_type in ['png','svg']:
                plt.savefig(f"Figures/Projection_plots_per_model/Projection_{model_name}.{file_type}", bbox_inches='tight')

                

In [35]:
all_data = AllDataComparisons(Path("/gws/nopw/j04/leader_epesc/CMIP6_SinglForcHistSimul/InterpolatedFlds/psl/"),
                              Path("/gws/nopw/j04/extant/users/slbennie/")
)

#all_data.summary()
all_data.set_time_bounds_all(1850, 2014)
all_data.set_domain(lat_bounds=(20,80), lon_bounds=(-90,40))
all_data.set_season([12,1,2])

#all_results = all_data.calc_R2('psl', 2)
#R2 = all_results['historical']['HadGEM3-GC31-LL']
#all_data.project_trend_EOF("psl",2)
#all_data.calc_R2("psl",2)
#all_data.projection_steps_plot("psl",2)

#print(R2)
exp_name_CanESM5 = all_data.get_experiments_per_model('CanESM5')

print(exp_name_CanESM5)

all_data.projection_steps_plot('psl',2)
#all_data.add_sum_experiment('psl')

#print(all_data.experiments['sum'])

#all_data.projection_steps_plot('psl', 2)
#all_data.R2_plot('psl',2)
#historical = all_data.experiments['historical']
#EOF = historical.calc_EOF_concat_all_models(
#                    varname='psl',
#                    max_modes=2
#)

(20, 80) (-90, 40)
[12, 1, 2]
{'hist-GHG': <__main__.ModelCalculations object at 0x7f91c39b4dd0>, 'hist-aer': <__main__.ModelCalculations object at 0x7f91c2a45410>, 'hist-sol': <__main__.ModelCalculations object at 0x7f91cbe181d0>, 'hist-totalO3': <__main__.ModelCalculations object at 0x7f91c2b04c10>, 'hist-volc': <__main__.ModelCalculations object at 0x7f91c3904590>, 'historical': <__main__.ModelCalculations object at 0x7f91cc45b5d0>}
ACCESS-ESM1-5
loading the linear trend for model: ACCESS-ESM1-5
ACCESS-ESM1-5
loading the linear trend for model: ACCESS-ESM1-5
ACCESS-ESM1-5
loading the linear trend for model: ACCESS-ESM1-5
ACCESS-ESM1-5
loading the linear trend for model: ACCESS-ESM1-5
Slope sum for ACCESS-ESM1-5 computed
CMCC-CM2-SR5
loading the linear trend for model: CMCC-CM2-SR5
CMCC-CM2-SR5
loading the linear trend for model: CMCC-CM2-SR5
CMCC-CM2-SR5
loading the linear trend for model: CMCC-CM2-SR5
Slope sum for CMCC-CM2-SR5 computed
CanESM5
loading the linear trend for model: C

AttributeError: 'NoneType' object has no attribute 'items'

In [20]:
historical = ExperimentCalculations(Path("/gws/nopw/j04/leader_epesc/CMIP6_SinglForcHistSimul/InterpolatedFlds/psl/historical/"), Path("/gws/nopw/j04/extant/users/slbennie/"))
historical.set_time_bounds(1850, 2014)
historical.set_domain(lat_bounds=(20,80), lon_bounds=(-90,40))
historical.set_season([12,1,2])

linear_trend = historical.projection_all_models(
    varname='psl',
    max_modes=2
)



#historical_ens_mean = historical.calc_ensemble_mean_all_models('psl', Path("/gws/nopw/j04/extant/users/slbennie/ens_mean_spat/psl/historical/"))

# Compute anomalies for all ensemble members and models
#anomaly_results = historical.calc_anomalies_all_models(
#    varname="psl",
#    output_dir=Path("/gws/nopw/j04/extant/users/slbennie/psl_anomalies/historical/")
#)

#print(anomaly_results['CanESM5']['r10i1p1f1'].values)

#compute the EOF for all models

#HadGEM3 = historical.models['HadGEM3-GC31-LL']
#EOF = historical.calc_EOF_concat_all_models(
#    output_dir = Path('/gws/nopw/j04/extant/users/slbennie/EOF/'),
#                    varname='psl',
#                    max_modes=2
#)

AttributeError: 'ExperimentCalculations' object has no attribute 'set_time_bounds'

In [56]:
print(anomaly_results['CanESM5']['r10i1p1f1'])

<xarray.DataArray 'psl' (year: 165, lat: 25, lon: 53)> Size: 2MB
[218625 values with dtype=float64]
Coordinates:
  * lat      (lat) float64 200B 20.0 22.5 25.0 27.5 30.0 ... 72.5 75.0 77.5 80.0
  * lon      (lon) float64 424B -90.0 -87.5 -85.0 -82.5 ... 32.5 35.0 37.5 40.0
  * year     (year) int64 1kB 1850 1851 1852 1853 1854 ... 2011 2012 2013 2014
    season   <U3 12B ...


In [68]:
model = ModelCalculations(Path("/gws/nopw/j04/leader_epesc/CMIP6_SinglForcHistSimul/InterpolatedFlds/psl/historical/HadGEM3-GC31-LL/"))

model.set_time_bounds(1850, 2014)

member1 = next(iter(model.members.values()))

#so model.members is my dictionary of ensemble members, keys are member_ids values are EnsembleMemberCalculations objects
#"r1i1p1f1": <EnsembleMemberCalculations object>, ...
#the .values() returns all the values
#iter() lets you iterate throught the objects one by one
#next() grabs the firt item.

#EQUIVALENT
#first_key = list(model.members.keys())[0]
#member1 = model.members[first_key]

print(model.time_bounds)
print(member1.get_calendar_type())  # e.g., 'noleap'
print(member1.get_time_bounds())    # (start, end) as cftime objects

print(member1.select_data('psl'))

ens_mean = model.calc_ensemble_mean(Path("/gws/nopw/j04/extant/users/slbennie/ens_mean_spat/psl/historical/HadGEM3-GC31-LL/psl_mon_historical_HadGEM3-GC31-LL_spatial_DJF_EM_1850-2015.nc"), 'psl')


(cftime.Datetime360Day(1850, 1, 16, 0, 0, 0, 0, has_year_zero=True), cftime.Datetime360Day(2014, 12, 16, 0, 0, 0, 0, has_year_zero=True))
(cftime.Datetime360Day(1850, 1, 16, 0, 0, 0, 0, has_year_zero=True), cftime.Datetime360Day(2014, 12, 16, 0, 0, 0, 0, has_year_zero=True))
360_day
<xarray.DataArray 'psl' (time: 1980, lat: 71, lon: 144)> Size: 162MB
[20243520 values with dtype=float64]
Coordinates:
  * time     (time) object 16kB 1850-01-16 00:00:00 ... 2014-12-16 00:00:00
  * lat      (lat) float64 568B -87.5 -85.0 -82.5 -80.0 ... 80.0 82.5 85.0 87.5
  * lon      (lon) float64 1kB -180.0 -177.5 -175.0 -172.5 ... 172.5 175.0 177.5
Attributes:
    standard_name:  air_pressure_at_mean_sea_level
    long_name:      Sea Level Pressure
    comment:        Sea Level Pressure
    units:          Pa
    original_name:  mo: (stash: m01s16i222, lbproc: 128)
    cell_methods:   area: time: mean
    cell_measures:  area: areacella
Loading the exisiting file for the spatial ensemble mean from: /gw

In [24]:
#idea is to create objects for all the ensembles within the folder
#within the class there will be a method to extract the member_id from the filename


folder = Path('/gws/nopw/j04/leader_epesc/CMIP6_SinglForcHistSimul/InterpolatedFlds/psl/historical/HadGEM3-GC31-LL/')

#collecting all the ensemble members in this folder - slightly tricky in that depends on this folder not really being altered?
ensemble_members = [
    EnsembleMemberCalculations(file)
    for file in folder.glob("*.nc")
]

print([member.member_id for member in ensemble_members])

#loading in the data
ds = ensemble_members[0].load_data()
ensemble_members[0].get_time_bounds('1850', '2014')

['r11i1p1f3', 'r12i1p1f3', 'r13i1p1f3', 'r14i1p1f3', 'r15i1p1f3', 'r16i1p1f3', 'r17i1p1f3', 'r18i1p1f3', 'r19i1p1f3', 'r1i1p1f3', 'r20i1p1f3', 'r21i1p1f3', 'r22i1p1f3', 'r23i1p1f3', 'r24i1p1f3', 'r25i1p1f3', 'r26i1p1f3', 'r27i1p1f3', 'r28i1p1f3', 'r29i1p1f3', 'r2i1p1f3', 'r30i1p1f3', 'r31i1p1f3', 'r32i1p1f3', 'r33i1p1f3', 'r34i1p1f3', 'r35i1p1f3', 'r36i1p1f3', 'r37i1p1f3', 'r38i1p1f3', 'r39i1p1f3', 'r3i1p1f3', 'r40i1p1f3', 'r41i1p1f3', 'r42i1p1f3', 'r43i1p1f3', 'r44i1p1f3', 'r45i1p1f3', 'r46i1p1f3', 'r47i1p1f3', 'r48i1p1f3', 'r49i1p1f3', 'r4i1p1f3', 'r50i1p1f3', 'r51i1p1f3', 'r52i1p1f3', 'r53i1p1f3', 'r54i1p1f3', 'r55i1p1f3', 'r56i1p1f3', 'r57i1p1f3', 'r58i1p1f3', 'r59i1p1f3', 'r5i1p1f3', 'r60i1p1f3']
1850


In [8]:
member1 = EnsembleMemberCalculations(
    member_id='1',
    data_path = Path('/gws/nopw/j04/leader_epesc/CMIP6_SinglForcHistSimul/InterpolatedFlds/psl/historical/HadGEM3-GC31-LL/psl_mon_historical_HadGEM3-GC31-LL_r1i1p1f3_interp.nc')
)

ds = member1.load_data()

print(member1.member_id)
print(ds)


1
<xarray.Dataset> Size: 169MB
Dimensions:    (time: 1980, bnds: 2, lat: 71, lon: 144)
Coordinates:
  * time       (time) object 16kB 1850-01-16 00:00:00 ... 2014-12-16 00:00:00
  * lat        (lat) float64 568B -87.5 -85.0 -82.5 -80.0 ... 82.5 85.0 87.5
  * lon        (lon) float64 1kB -180.0 -177.5 -175.0 ... 172.5 175.0 177.5
Dimensions without coordinates: bnds
Data variables:
    time_bnds  (time, bnds) object 32kB ...
    lat_bnds   (time, lat, bnds) float64 2MB ...
    lon_bnds   (time, lon, bnds) float64 5MB ...
    psl        (time, lat, lon) float64 162MB ...
Attributes: (12/46)
    Conventions:            CF-1.7 CMIP-6.2
    activity_id:            CMIP
    branch_method:          standard
    branch_time_in_child:   0.0
    branch_time_in_parent:  0.0
    creation_date:          2019-06-19T12:06:35Z
    ...                     ...
    title:                  HadGEM3-GC31-LL output prepared for CMIP6
    variable_id:            psl
    variant_label:          r1i1p1f3
    li