In [1]:
from pathlib import Path
import xarray as xr
import numpy as np
import re
import cftime
from datetime import datetime
from dask.diagnostics import ProgressBar

In [10]:
class EnsembleMemberCalculations:
    def __init__(self, data_path: Path, model=None):
        #use self. to access attributes
        self.data_path = Path(data_path)

        #extracting the member_id (ensemble number) from the filename as named in CMIP6 convention
        match = re.search(r"(r\d+i\d+p\d+f\d+)", self.data_path.stem)
        if not match:
            raise ValueError(f"Could not extract the member_id from filenmae: {self.data_path.name}")

        #selecting the ensemble number and assigning it to the member ID
        self.member_id = match.group(1)

        #referencing the parent model
        self.model = model
        self.data = None

    def load_data(self):
        #loads the data of an ensemble member
        self.data = xr.open_dataset(self.data_path)
        return self.data

    def get_calendar_type(self):
        if self.model is None:
            raise RuntimeError("Model is not set")
        return self.model.time_bounds

    def get_time_bounds(self):
        """
        Access the model's time_bounds - each model has a different calendar thus type of time to index/slice by
        """
        if self.model is None or self.model.time_bounds is None:
            raise RuntimeError("Model or time_bounds is not set")
        return self.model.calendar_type

    def select_data(self, varname: str):
        """
        Just selecting some data from the time bounds in model
        """
        ds = self.load_data()
        start, end = self.model.time_bounds
        ds = ds[varname].sel(time=slice(start, end))
        
        return ds

    def cal_seasonal_mean(self)

        

class ModelCalculations:
    def __init__(self, folder: Path):
        self.folder = Path(folder)

        #can now set the model name using the folder name
        self.name = self.folder.name
        print(f"Model name set to {self.name}")

        #creating my dictionary of ensemble members
        self.members = {}
        for i, file in enumerate(self.folder.glob("*.nc")):
            if i >= 3:
                break
            member = EnsembleMemberCalculations(file, model=self)
            self.members[member.member_id] = member
        
        #basically I pick the first member (load in just one)
        #doing this to then use to find the calendar type
        #and then find the time bounds as this will be the same time (correct datetime) per model
        member1 = next(iter(self.members.values()))
        member1.load_data()

        #find the calendar type
        time_var = member1.data["time"]
        self.calendar_type = time_var.encoding.get("calendar", "standard")
        self.time_bounds = None

    def set_time_bounds(self, start_year: int, end_year: int):
        #selecting the correct datetime format for the chosen time period
        if self.calendar_type == "noleap":
            start = cftime.DatetimeNoLeap(start_year,1,16)
            end = cftime.DatetimeNoLeap(end_year,12,16)
        elif self.calendar_type == "360_day":
            start = cftime.Datetime360Day(start_year,1,16)
            end = cftime.Datetime360Day(end_year,12,16)
        else:
            start = datetime(start_year,1,16)
            end = datetime(end_year,12,16)

        self.time_bounds = (start, end)
        return self.time_bounds

    def calc_ensemble_mean_spatial(self, output_file: Path, varname: str):
        """
        Calculate the SPATIAL ensemble mean for a variable
        Restrict it to the time bounds
        Could in future restrict to an area
        If the file exists load this instead
        """

        if output_file.exists():
            print(f"Loading the exisiting file for the spatial ensemble mean from: {output_file}")
            ens_mean = xr.open_dataset(output_file)[varname]
            return ens_mean
        
        #Get all the filepaths from the ensemble members
        #this line basically is getting the filepaths of the ensemble member class objects - i think?
        file_paths = [member.data_path for member in self.members.values()]

        #then opening them all together
        ds = xr.open_mfdataset(
            file_paths,
            combine="nested",
            concat_dim="member",
            parallel=True,
            chunks={"member": 1}
        )[varname]

        #selecting the time period
        start,end = self.time_bounds
        ds = ds.sel(time=slice(start,end))

        #calculating the ensemble mean across the members
        ens_mean = ds.mean(dim="member")
    
        # Compute the result with a progress bar
        print("Computing the ensemble mean...")
        with ProgressBar():
            ens_mean.compute()#.to_netcdf(output_file)

        #ens_mean.to_netcdf(output_file)
        print(f'Ensemble mean NOT saved to {output_file}')

        return ens_mean


class ExperimentCalculations:
    def __init__(self, folder: Path):
        """
        Initialise an experiment object
        Each folder inside this experiment is a model
        """

        self.folder = Path(folder)

        #this is the dictionary of ModelCalculations objects
        self.models = {}

        #going through all the sub folders within this experiment folder
        #The creating a ModelCalculations object for models folders
        for model_folder in self.folder.iterdir():
            if model_folder.is_dir():
                model = ModelCalculations(model_folder)
                self.models[model.name] = model

        print(f"loaded in {len(self.models)} models for experiment '{self.folder.name}'")

    def set_time_bounds(self, start_year: int, end_year:int):
        """
        Set the time bounds for all the models in this experiment
        """

        for model in self.models.values():
            model.set_time_bounds(start_year, end_year)

    def calc_ensemble_mean_spatial_all_models(self, varname: str, output_dir: Path):
        """
        Calculate the ensemble meban for each model within this experiment
        Save results to an output dir
        """

        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True, parents=True)

        results = {}
        for model_name, model in self.models.items():
            #to be changed - maybe create a method for naming the files?????
            output_file = output_dir / f"{model_name}_{varname}_ensemble_mean.nc"

            #
            ens_mean = model.calc_ensemble_mean_spatial(output_file, varname)
            results[model_name] = ens_mean

        return results
        

In [14]:
historical = ExperimentCalculations(Path("/gws/nopw/j04/leader_epesc/CMIP6_SinglForcHistSimul/InterpolatedFlds/psl/historical/"))
historical.set_time_bounds(1850, 2014)
historical_ens_mean = historical.calc_ensemble_mean_spatial_all_models('psl', Path("/gws/nopw/j04/extant/users/slbennie/ens_mean_spat/psl/historical/"))


Model name set to ACCESS-ESM1-5
Model name set to CMCC-CM2-SR5
Model name set to CanESM5
Model name set to FGOALS-g3
Model name set to GISS-E2-1-G
Model name set to HadGEM3-GC31-LL
Model name set to IPSL-CM6A-LR
Model name set to MIROC6
Model name set to MPI-ESM1-2-LR
Model name set to NorESM2-LM
loaded in 10 models for experiment 'historical'
Computing the ensemble mean...
[#############                           ] | 32% Completed | 42.59 ss


KeyboardInterrupt: 

In [6]:
model = ModelCalculations(Path("/gws/nopw/j04/leader_epesc/CMIP6_SinglForcHistSimul/InterpolatedFlds/psl/historical/HadGEM3-GC31-LL/"))
model.set_time_bounds(1850, 2014)
ens_mean = model.calc_ensemble_mean_spatial(Path("/gws/nopw/j04/extant/users/slbennie/ens_mean_spat/psl/historical/HadGEM3-GC31-LL/psl_mon_historical_HadGEM3-GC31-LL_spatial_DJF_EM_1850-2015.nc"), 'psl')
ens_mean

Model name set to HadGEM3-GC31-LL
Loading the exisiting file for the spatial ensemble mean from: /gws/nopw/j04/extant/users/slbennie/ens_mean_spat/psl/historical/HadGEM3-GC31-LL/psl_mon_historical_HadGEM3-GC31-LL_spatial_DJF_EM_1850-2015.nc


In [68]:
model = ModelCalculations(Path("/gws/nopw/j04/leader_epesc/CMIP6_SinglForcHistSimul/InterpolatedFlds/psl/historical/HadGEM3-GC31-LL/"))

model.set_time_bounds(1850, 2014)

member1 = next(iter(model.members.values()))

#so model.members is my dictionary of ensemble members, keys are member_ids values are EnsembleMemberCalculations objects
#"r1i1p1f1": <EnsembleMemberCalculations object>, ...
#the .values() returns all the values
#iter() lets you iterate throught the objects one by one
#next() grabs the firt item.

#EQUIVALENT
#first_key = list(model.members.keys())[0]
#member1 = model.members[first_key]

print(model.time_bounds)
print(member1.get_calendar_type())  # e.g., 'noleap'
print(member1.get_time_bounds())    # (start, end) as cftime objects

print(member1.select_data('psl'))

ens_mean = model.calc_ensemble_mean_spatial(Path("/gws/nopw/j04/extant/users/slbennie/ens_mean_spat/psl/historical/HadGEM3-GC31-LL/psl_mon_historical_HadGEM3-GC31-LL_spatial_DJF_EM_1850-2015.nc"), 'psl')


(cftime.Datetime360Day(1850, 1, 16, 0, 0, 0, 0, has_year_zero=True), cftime.Datetime360Day(2014, 12, 16, 0, 0, 0, 0, has_year_zero=True))
(cftime.Datetime360Day(1850, 1, 16, 0, 0, 0, 0, has_year_zero=True), cftime.Datetime360Day(2014, 12, 16, 0, 0, 0, 0, has_year_zero=True))
360_day
<xarray.DataArray 'psl' (time: 1980, lat: 71, lon: 144)> Size: 162MB
[20243520 values with dtype=float64]
Coordinates:
  * time     (time) object 16kB 1850-01-16 00:00:00 ... 2014-12-16 00:00:00
  * lat      (lat) float64 568B -87.5 -85.0 -82.5 -80.0 ... 80.0 82.5 85.0 87.5
  * lon      (lon) float64 1kB -180.0 -177.5 -175.0 -172.5 ... 172.5 175.0 177.5
Attributes:
    standard_name:  air_pressure_at_mean_sea_level
    long_name:      Sea Level Pressure
    comment:        Sea Level Pressure
    units:          Pa
    original_name:  mo: (stash: m01s16i222, lbproc: 128)
    cell_methods:   area: time: mean
    cell_measures:  area: areacella
Loading the exisiting file for the spatial ensemble mean from: /gw

In [24]:
#idea is to create objects for all the ensembles within the folder
#within the class there will be a method to extract the member_id from the filename


folder = Path('/gws/nopw/j04/leader_epesc/CMIP6_SinglForcHistSimul/InterpolatedFlds/psl/historical/HadGEM3-GC31-LL/')

#collecting all the ensemble members in this folder - slightly tricky in that depends on this folder not really being altered?
ensemble_members = [
    EnsembleMemberCalculations(file)
    for file in folder.glob("*.nc")
]

print([member.member_id for member in ensemble_members])

#loading in the data
ds = ensemble_members[0].load_data()
ensemble_members[0].get_time_bounds('1850', '2014')

['r11i1p1f3', 'r12i1p1f3', 'r13i1p1f3', 'r14i1p1f3', 'r15i1p1f3', 'r16i1p1f3', 'r17i1p1f3', 'r18i1p1f3', 'r19i1p1f3', 'r1i1p1f3', 'r20i1p1f3', 'r21i1p1f3', 'r22i1p1f3', 'r23i1p1f3', 'r24i1p1f3', 'r25i1p1f3', 'r26i1p1f3', 'r27i1p1f3', 'r28i1p1f3', 'r29i1p1f3', 'r2i1p1f3', 'r30i1p1f3', 'r31i1p1f3', 'r32i1p1f3', 'r33i1p1f3', 'r34i1p1f3', 'r35i1p1f3', 'r36i1p1f3', 'r37i1p1f3', 'r38i1p1f3', 'r39i1p1f3', 'r3i1p1f3', 'r40i1p1f3', 'r41i1p1f3', 'r42i1p1f3', 'r43i1p1f3', 'r44i1p1f3', 'r45i1p1f3', 'r46i1p1f3', 'r47i1p1f3', 'r48i1p1f3', 'r49i1p1f3', 'r4i1p1f3', 'r50i1p1f3', 'r51i1p1f3', 'r52i1p1f3', 'r53i1p1f3', 'r54i1p1f3', 'r55i1p1f3', 'r56i1p1f3', 'r57i1p1f3', 'r58i1p1f3', 'r59i1p1f3', 'r5i1p1f3', 'r60i1p1f3']
1850


In [8]:
member1 = EnsembleMemberCalculations(
    member_id='1',
    data_path = Path('/gws/nopw/j04/leader_epesc/CMIP6_SinglForcHistSimul/InterpolatedFlds/psl/historical/HadGEM3-GC31-LL/psl_mon_historical_HadGEM3-GC31-LL_r1i1p1f3_interp.nc')
)

ds = member1.load_data()

print(member1.member_id)
print(ds)


1
<xarray.Dataset> Size: 169MB
Dimensions:    (time: 1980, bnds: 2, lat: 71, lon: 144)
Coordinates:
  * time       (time) object 16kB 1850-01-16 00:00:00 ... 2014-12-16 00:00:00
  * lat        (lat) float64 568B -87.5 -85.0 -82.5 -80.0 ... 82.5 85.0 87.5
  * lon        (lon) float64 1kB -180.0 -177.5 -175.0 ... 172.5 175.0 177.5
Dimensions without coordinates: bnds
Data variables:
    time_bnds  (time, bnds) object 32kB ...
    lat_bnds   (time, lat, bnds) float64 2MB ...
    lon_bnds   (time, lon, bnds) float64 5MB ...
    psl        (time, lat, lon) float64 162MB ...
Attributes: (12/46)
    Conventions:            CF-1.7 CMIP-6.2
    activity_id:            CMIP
    branch_method:          standard
    branch_time_in_child:   0.0
    branch_time_in_parent:  0.0
    creation_date:          2019-06-19T12:06:35Z
    ...                     ...
    title:                  HadGEM3-GC31-LL output prepared for CMIP6
    variable_id:            psl
    variant_label:          r1i1p1f3
    li

In [None]:
class EnsembleMemberProcessing:
    def __init__(self, member_id: str, data_path:Path):
        #use self. to access the attributes
        self.member_id = member_id
        self.datat_path = Path(data_path)
        self.data = None
        self.eofs = None

    def load_data(self):
        #loading the LESFMIP data in for each ensemble member
        #each object has a path attribute
        self.data = xr.open_dataset(self.data_path)
        return self.data

    def calculate_EOF(self, domain='North Atlantic'):
        #will add in code to calculate the EOFs
        #for now just a test
        #domain would be passed through.
        print(f"calculating the EOF for the domain: {domain} for {self.member_id")
        #self.eofs = xr.Dataset()
        #EOF calculation
        #return self.eofs

    def save_eofs(self, path: Path):
        if self.eofs is not None:
            path.parent.mkdir(parents=True, exist_ok=True)
            self.eofs.to_netcdf(path)
            print(f"saved EOFs for {self.member_id} to {path}")

class EnsembleMemberProcessed:
    def __init__(self, member_id: str, eofs: xr.Dataset):
        self.member_id = member_id
        self.eofs = eofs

    @classmethod
    def from_file(cls, member_id: str, eofs_path: Path):
        eofs = xr.open_dataset(eofs_path)
        print(f"Loaded EOFs for {member_id} from {eofs_path}")
        return cls(member_id=member_id, eofs=eofs)


class ModelCalculations:
    def __init__(self, model_name: str, ensemble_members_processing: dict = None, ensemble_members_processed: dict = None, eof_mode='NAO'):
        self.model_name = model_name
        self.eof_mode = eof_mode

        #Making sue the code doesn't calculate new and load in old.
        if ensemble_members_processing is not None and ensemble_members_processed is not None:
            raise ValueError("Provide either ensemble_members_processing OR ensemble_members_processed - NOT both!")

        if ensemble_members_processing is not None:
            self.mode = 'raw'
            self.ensemble_members_processing = ensemble_members_processing
            #so now it calculates the things I need, well hopefully....
            #mid is just ensemble member id
            #computing the EOFs
            #results in a dictionary mapping the member ID to the processed EOFs
            self.ensemble_members_processed = {}

        elif ensemble_members_processed is not None:
            #basically just stores the data given
            self.mode = 'processed'
            self.ensemble_members_processed = ensemble_members_processed
            self.ensemble_members_processing = None
        else:
            raise ValueError('Need at least either ensemble_members_processing OR ensemble_members_processed')


    # ------------------
    # For raw members: calculate EOFs and optionally save
    # ------------------
    def calculate_raw_eofs(self, save_dir: Path = None):
        if self.mode != "raw":
            raise RuntimeError("This method is only valid for raw ensemble members.")
        for mid, member in self.ensemble_members_processing.items():
            #member.load_data()
            #eofs = member.calculate_eofs(self.eof_mode)
            #self.ensemble_members_processing[mid] = EnsembleMemberProcessed(mid, eofs)
            #if save_dir is not None:
            #    save_path = save_dir / f"{mid}_{self.eof_mode}_eofs.nc"
            #    member.save_eofs(save_path)
        print("All raw EOFs calculated and processed members created.")

    # ------------------
    # For processed members: load EOFs from .nc
    # ------------------
    @classmethod
    def from_processed_files(cls, model_name: str, file_paths: dict, eof_mode="NAO"):
        processed_members = {}
        for mid, path in file_paths.items():
            processed_members[mid] = EnsembleMemberProcessed.from_file(mid, Path(path))
        return cls(model_name, processed_ensemble_members=processed_members, eof_mode=eof_mode)
        