# In [2]:
import xarray as xr
import numpy as np
from eofs.xarray import Eof


def calc_EOF_per_ensemble(self, max_modes: int):
    """
    Compute (or load from cache) the leading EOFs of one ensemble member.

    The seasonal-mean anomalies returned by ``self.calc_seas_anomaly()`` are
    divided by 100 (Pa -> hPa according to the original author's notes), their
    ``year`` coordinate is renamed to ``time``, and they are decomposed with
    ``eofs.xarray.Eof`` using sqrt(cos(lat)) area weights.

    Scaling conventions used here:
      * EOF0 - EOFs as returned by ``solver.eofs`` (no eigenvalue scaling).
      * EOF1 - EOF0 * sqrt(eigenvalue); this is what is stored as ``eofs``.
      * PCs  - requested with ``pcscaling=1``, i.e. scaled to unit variance.

    Parameters
    ----------
    max_modes : int
        Number of leading modes to compute. Must be >= 1.

    Returns
    -------
    xarray.Dataset
        Dataset with the variables ``eofs`` (EOF1), ``PCs`` and
        ``regressions``, plus ``model`` / ``experiment`` / ``member_id`` /
        ``description`` attributes. When a cached file with at least
        ``max_modes`` modes exists, the dataset opened from that file is
        returned as-is (it may contain more modes than requested).

    Raises
    ------
    ValueError
        If ``max_modes`` is smaller than 1.

    Notes
    -----
    The variance fraction is only printed, it is not stored in the dataset.
    """
    if max_modes < 1:
        raise ValueError(f"max_modes must be a positive integer, got {max_modes!r}")

    experiment = self.experiment.name
    model = self.model.name
    season = self.all_experiments.season_str
    # Same path as before, only split over several literals for readability.
    output_file_EOF = self.all_experiments.output_dir / (
        f"EOF/{experiment}/{season}/normalised_aligned/{model}/"
        f"psl_mon_{experiment}_{model}_{self.member_id}_{season}"
        f"_EOF_indiv_{self.model.time_bounds[3]}.nc"
    )

    # Re-use the cached result, but only if it holds enough modes: the file
    # name does not encode max_modes, so a file written by an earlier call with
    # a smaller max_modes would otherwise be returned silently.
    if output_file_EOF.exists():
        cached = xr.open_dataset(output_file_EOF)
        if cached.sizes.get('mode', 0) >= max_modes:
            print(f"Loading the EOF for ensemble member {self.member_id} in model {self.model.name} and {self.experiment.name}")
            return cached
        # Release the file handle before the file is overwritten below.
        cached.close()
        print(f"Cached EOF file {output_file_EOF} has fewer than {max_modes} modes - recalculating")

    # Seasonal anomalies: time axis is called 'year' after the seasonal mean,
    # so rename it back to 'time'; the division by 100 converts to hPa.
    print('calculating the EOF')
    anomaly = self.calc_seas_anomaly().rename({'year': 'time'}) / 100

    # The solver expects time as the leading dimension.
    anomaly_trans = anomaly.transpose('time', 'lat', 'lon')

    # Area weights sqrt(cos(lat)), shaped (lat, 1) so they broadcast over lon.
    coslat = np.cos(np.deg2rad(anomaly_trans.coords['lat'].values)).clip(0., 1.)
    wgts = np.sqrt(coslat)[..., np.newaxis]

    solver = Eof(anomaly_trans, weights=wgts)

    # EOF0, the eigenvalues, and the eigenvalue-scaled EOF1.
    EOF0 = solver.eofs(neofs=max_modes)
    eigs = solver.eigenvalues(neigs=max_modes)
    EOF1 = EOF0 * np.sqrt(eigs.values)[:, np.newaxis, np.newaxis]

    # Unit-variance PCs and the fraction of variance explained by each mode.
    PCs = solver.pcs(npcs=max_modes, pcscaling=1).transpose('mode', 'time')
    var_frac = solver.varianceFraction(neigs=max_modes)

    # Regression maps: regress the (time-demeaned) anomalies onto each
    # (time-demeaned) PC, then apply the same latitude weights as the solver.
    anomaly_trans_demeaned = anomaly_trans - anomaly_trans.mean(dim='time')
    PCs_demeaned = PCs - PCs.mean(dim='time')

    reg_maps = []
    # Iterate over the modes actually returned by the solver: it can yield
    # fewer than max_modes, and indexing past them would raise an IndexError.
    for n in range(PCs.sizes['mode']):
        pc = PCs_demeaned.isel(mode=n)
        cov = (anomaly_trans_demeaned * pc).mean(dim='time') * wgts
        var = (pc ** 2).mean(dim='time')
        reg_maps.append((cov / var).expand_dims(mode=[n]))

    regressions = xr.concat(reg_maps, dim='mode')
    regressions.name = 'regressions'

    # Collect everything in one dataset; only the scaled EOF1 is stored.
    ds = xr.Dataset({
        'eofs': EOF1,
        'PCs': PCs,
        'regressions': regressions,
    })

    ds.attrs.update({
        'model': self.model.name,
        'experiment': self.experiment.name,
        'member_id': self.member_id,
        'description': ('EOF, PC, and regression map per ensemble member.\n'
                        'EOFs are scaled by sqrt(eigenvalue) so that PCs are unitless and EOFs are in hPa'),
    })

    # The nested output directory may not exist yet on a first run.
    output_file_EOF.parent.mkdir(parents=True, exist_ok=True)
    ds.to_netcdf(output_file_EOF)
    print(f'The fraction of variance explained in model {self.model.name} and member {self.member_id}: {var_frac}')

    print(ds)
    return ds