In [2]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import os
import cftime
import pandas as pd
from datetime import datetime
import matplotlib.colors as mcolors
from scipy.stats import linregress
from eofs.xarray import Eof
from eofs.examples import example_data_path

In [5]:
#getting the correct calendar (model dependent.)
def get_time_bounds(calendar_type, start, end):
    """Return the (lower, upper) bounds used to slice a dataset's time axis.

    Parameters
    ----------
    calendar_type : type
        Class of the dataset's time values. It is compared against
        ``cftime.DatetimeNoLeap`` and ``cftime.Datetime360Day``; anything
        else is handled with ``datetime.datetime``.
    start, end : int
        First and last year of the requested window.

    Returns
    -------
    tuple
        ``(lower, upper)`` built in the matching calendar.
    """
    # Intended window (author's note): for 1850-2015, all of 2014 and none of 2015.
    if calendar_type == cftime.DatetimeNoLeap:
        lower = cftime.DatetimeNoLeap(start, 1, 16)
        upper = cftime.DatetimeNoLeap(end, 1, 16)
    elif calendar_type == cftime.Datetime360Day:
        # The 360-day branch stops at 16 December of the year before `end`
        # rather than at 16 January of `end`.
        lower = cftime.Datetime360Day(start, 1, 16)
        upper = cftime.Datetime360Day(end - 1, 12, 16)
    else:
        lower = datetime(start, 1, 16)
        upper = datetime(end, 1, 16)
    return lower, upper

#finding all the models that have ensembles for that experiment.
def get_models_for_experiment(experiment):
    """Return the list of model names that have ensembles for ``experiment``.

    Parameters
    ----------
    experiment : str
        One of 'historical', 'hist-aer', 'hist-GHG', 'hist-sol',
        'hist-totalO3' or 'hist-volc'.

    Returns
    -------
    list of str
        A new list on every call, so callers may modify it freely.

    Raises
    ------
    ValueError
        If ``experiment`` is not one of the names above (previously this
        surfaced as an UnboundLocalError on the return statement).
    """
    models_by_experiment = {
        'historical': ['ACCESS-ESM1-5','CanESM5','CMCC-CM2-SR5','FGOALS-g3','GISS-E2-1-G','HadGEM3-GC31-LL','IPSL-CM6A-LR','MIROC6','MPI-ESM1-2-LR','NorESM2-LM'],
        'hist-aer': ['ACCESS-ESM1-5','CanESM5','CMCC-CM2-SR5','CNRM-CM6-1','FGOALS-g3','GISS-E2-1-G','HadGEM3-GC31-LL','IPSL-CM6A-LR','MIROC6','MPI-ESM1-2-LR','NorESM2-LM'],
        'hist-GHG': ['ACCESS-ESM1-5','CanESM5','CMCC-CM2-SR5','CNRM-CM6-1','FGOALS-g3','GISS-E2-1-G','HadGEM3-GC31-LL','IPSL-CM6A-LR','MIROC6','MPI-ESM1-2-LR','NorESM2-LM'],
        'hist-sol': ['ACCESS-ESM1-5','CanESM5','GISS-E2-1-G','HadGEM3-GC31-LL','MIROC6','MPI-ESM1-2-LR','NorESM2-LM'],
        'hist-totalO3': ['CanESM5','GISS-E2-1-G','HadGEM3-GC31-LL','MIROC6','MPI-ESM1-2-LR','NorESM2-LM'],
        'hist-volc': ['ACCESS-ESM1-5','CanESM5','CMCC-CM2-SR5','GISS-E2-1-G','HadGEM3-GC31-LL','MIROC6','MPI-ESM1-2-LR','NorESM2-LM'],
    }

    if experiment not in models_by_experiment:
        raise ValueError(
            "Unknown experiment %r; expected one of %s"
            % (experiment, sorted(models_by_experiment))
        )

    return list(models_by_experiment[experiment])

#Cropping CVDP data to the North Atlantic sector - requires some shifting of 0 of the lat lon coordinate system.
def CVDP_EM_crop_NA_sector(filename, pattern):
    """Open ``filename``, pick variable ``pattern`` and crop it to the North Atlantic.

    Longitudes above 179 are shifted down by 360 so the requested sector
    (90W-40E) is contiguous, the field is re-ordered so longitude and latitude
    are ascending, and the box latitude 20..80, longitude -90..40 is returned.

    Parameters
    ----------
    filename : str
        Path to a netCDF file with ``latitude`` and ``longitude`` coordinates.
    pattern : str
        Name of the variable to extract.

    Returns
    -------
    xarray.DataArray
        The cropped field.
    """
    field = xr.open_dataset(filename)[pattern]

    # Wrap the eastern half of the globe onto negative longitudes
    # (threshold is 179, as in the original code).
    wrapped_lon = np.where(field.longitude > 179, field.longitude - 360, field.longitude)

    # Work on a copy so the freshly opened array keeps its original labels.
    relabelled = field.copy()
    relabelled.coords['longitude'] = wrapped_lon

    # Re-select in ascending longitude order, then make latitude ascending too,
    # so the slices below run from low to high values.
    ascending_lon = np.sort(relabelled.longitude)
    reordered = relabelled.sel(longitude=ascending_lon).sortby('latitude')

    return reordered.sel(latitude=slice(20,80), longitude=slice(-90,40))

#Crops to the North Atlantic sector - for the LESFMIP data NOT processed by CVDP
def open_cropNA_unitshPA(filename):
    """Open ``filename``, crop to lat 20..80 / lon -90..40 and divide by 100.

    The selection uses the coordinate names ``lat`` and ``lon``. The division
    is applied to the whole opened dataset (every data variable); the function
    name indicates a Pa -> hPa conversion.

    Returns
    -------
    The cropped, rescaled dataset.
    """
    member = xr.open_dataset(filename)
    north_atlantic = member.sel(lat=slice(20,80), lon=slice(-90,40))
    return north_atlantic / 100

In [6]:
#calculating the seasonal and annual ensemble spatial means.

def calculate_spatial_ensemble_mean(file_paths, output_file, variable):
    """Average ``variable`` across ensemble members and write the result to netCDF.

    Parameters
    ----------
    file_paths : list of str
        One file per ensemble member; they are concatenated along a new
        ``ensemble`` dimension.
    output_file : str
        Destination netCDF path for the ensemble mean.
    variable : str
        Name of the variable to average.

    Returns
    -------
    The ensemble-mean array that was written to ``output_file``.

    Notes
    -----
    The input files are closed when the function exits, including when the
    mean or the write fails (previously they stayed open on error).
    NOTE(review): the returned object is handed back after the inputs are
    closed, as before; presumably callers rely on the saved file - confirm.
    """
    # The context manager guarantees the member files are released even if
    # computing or saving the mean raises.
    with xr.open_mfdataset(file_paths, combine='nested', concat_dim='ensemble') as ds:
        # Collapse the ensemble dimension.
        mean = ds[variable].mean(dim='ensemble')

        # Save the ensemble mean to a .nc file.
        mean.to_netcdf(output_file)
        print('saved')

    return mean

def calculate_seasonal_spatial_ensemble_mean_djf(file_path, var, seas, output_file, year_init, year_final):
    """Compute one seasonal mean per year for ``var`` and save it to ``output_file``.

    Parameters
    ----------
    file_path : str
        Input netCDF file with a ``time`` coordinate.
    var : str
        Variable to process.
    seas : str
        Season label compared against ``time.dt.season`` (e.g. 'DJF').
    output_file : str
        Destination netCDF path.
    year_init, year_final : int
        Years passed to ``get_time_bounds`` to build the time window.

    Returns
    -------
    The seasonal means, indexed by ``year``. Every December is counted towards
    the following year, so a DJF season is labelled by its January/February.
    """
    print('in function')
    dataset = xr.open_dataset(file_path)

    # Replace the time coordinate with its CF-decoded version.
    dataset['time'] = xr.decode_cf(dataset).time

    # The class of the first time value decides which calendar the bounds use.
    first_stamp = dataset.time.values[0]
    start, end = get_time_bounds(type(first_stamp), year_init, year_final)

    # Restrict to the requested window, then keep only the requested season.
    in_window = dataset[var].sel(time=slice(start, end))
    months = in_window.sel(time=(in_window.time.dt.season == seas))

    # Label each month with its season year: December is pushed into the next year.
    month_year = months['time'].dt.year
    season_year = month_year.where(months['time'].dt.month != 12, month_year + 1)
    months = months.assign_coords(year=season_year)

    # Average the months that share a season year.
    seasonal_mean = months.groupby('year').mean(dim='time')
    seasonal_mean.to_netcdf(output_file)
    print('saved file')
    return seasonal_mean

In [7]:
#functions defined for calculating the linear trend
def calculate_linear_trend_spat_pattern(file_path, variable, output_file):
    """Fit a linear trend at every grid point and save the fit statistics.

    Parameters
    ----------
    file_path : str
        netCDF file holding ``variable`` on ``year``, ``latitude`` and
        ``longitude`` dimensions.
    variable : str
        Name of the variable to regress.
    output_file : str
        Destination netCDF path. The file holds ``slope``, ``intercept``,
        ``p_value``, ``slope_CI_lower`` and ``slope_CI_upper`` on
        (latitude, longitude).

    Notes
    -----
    * The regression is against the position along the ``year`` axis
      (0, 1, 2, ...), not the year values, so the slope is per time step.
    * Grid points with any non-finite value are left as NaN in every output.
    * The confidence interval is the two-sided 95 % interval,
      slope +/- t_crit * stderr with n - 2 degrees of freedom.
    * Nothing is returned; the result is only written to ``output_file``.
    """
    # Function-scope import kept from the original block.
    from scipy.stats import t

    with xr.open_dataset(file_path) as ds:
        # Enforce the axis order that the positional indexing below relies on.
        da = ds[variable].transpose('year', 'latitude', 'longitude')

        time = ds['year'].values
        lat = ds['latitude'].values
        lon = ds['longitude'].values

        # Read the data once instead of once per grid cell inside the loop.
        field = da.values

    time_numeric = np.arange(len(time))

    slope = np.full((len(lat), len(lon)), np.nan)
    intercept = np.full((len(lat), len(lon)), np.nan)
    p_value = np.full((len(lat), len(lon)), np.nan)
    stderr = np.full((len(lat), len(lon)), np.nan)

    for i in range(len(lat)):
        for j in range(len(lon)):
            ts = field[:, i, j]
            # Skip series with gaps; their outputs stay NaN.
            if np.all(np.isfinite(ts)):
                reg = linregress(time_numeric, ts)
                slope[i, j] = reg.slope
                intercept[i, j] = reg.intercept
                p_value[i, j] = reg.pvalue
                stderr[i, j] = reg.stderr

    # Two-sided 95 % confidence interval on the slope.
    n = len(time_numeric)
    df = n - 2
    alpha = 0.05
    t_crit = t.ppf(1 - alpha/2, df)

    ci_lower = slope - t_crit * stderr
    ci_upper = slope + t_crit * stderr

    slope_da = xr.DataArray(slope, coords=[lat, lon], dims=["latitude", "longitude"], name="slope")
    intercept_da = xr.DataArray(intercept, coords=[lat, lon], dims=["latitude", "longitude"], name="intercept")
    p_value_da = xr.DataArray(p_value, coords=[lat, lon], dims=["latitude", "longitude"], name="p_value")
    ci_lower_da = xr.DataArray(ci_lower, coords=[lat, lon], dims=["latitude", "longitude"], name="slope_CI_lower")
    ci_upper_da = xr.DataArray(ci_upper, coords=[lat, lon], dims=["latitude", "longitude"], name="slope_CI_upper")

    # Save to one combined netCDF file
    combined_ds = xr.Dataset({
        "slope": slope_da,
        "intercept": intercept_da,
        "p_value": p_value_da,
        "slope_CI_lower": ci_lower_da,
        "slope_CI_upper": ci_upper_da
    })
    combined_ds.to_netcdf(output_file)





In [8]:
#function to calculate the regression map
def calculate_regression_map(anomalies, mode, m, period):
    """Compute one EOF mode of North Atlantic msl anomalies, its regression map, and save both.

    The anomaly file(s) are cropped with ``CVDP_EM_crop_NA_sector`` (variable
    ``'msl'``), concatenated along a new ``ensemble`` dimension and flattened
    into a single ``time`` axis. An EOF analysis (``eofs.xarray.Eof``) with
    sqrt(cos(latitude)) weights gives the spatial pattern and its normalised
    PC (``pcscaling=1``). The regression map is the time mean of
    anomaly * PC.

    Parameters
    ----------
    anomalies : str or list of str
        Path, or list of paths, to the anomaly file(s).
    mode : {'NAO', 'EA'}
        'NAO' selects the leading EOF (index 0), 'EA' the second (index 1).
    m : str
        Label inserted into the output file names.
    period : str
        Period label inserted into the output file names.

    Raises
    ------
    ValueError
        If ``mode`` is neither 'NAO' nor 'EA'.

    Notes
    -----
    Nothing is returned; the regression map and the EOF pattern are written
    to netCDF files under ``regression_patterns/<mode>/``. No unit conversion
    is applied here, so the regression map is in the units of the input
    anomalies per unit PC.
    """
    # Support both single file and list of files
    if isinstance(anomalies, str):
        anomalies = [anomalies]

    # Selecting which EOF to calculate; fail early on an unsupported mode
    # instead of after all the files have been read.
    mode_numbers = {'NAO': 0, 'EA': 1}
    if mode not in mode_numbers:
        raise ValueError("mode must be 'NAO' or 'EA', got %r" % (mode,))
    mode_number = mode_numbers[mode]

    #setting up output files paths for the regression map and the EOF pattern
    output_regression_map = '/gws/nopw/j04/extant/users/slbennie/regression_patterns/'+mode+'/psl_mon_'+m+'_DJF_'+mode+'_regression_map_'+period+'.nc'
    output_EOF = '/gws/nopw/j04/extant/users/slbennie/regression_patterns/'+mode+'/psl_mon_'+m+'_DJF_'+mode+'_EOF_pattern_'+period+'.nc'

    #opening up all anomaly files and cropping to the North Atlantic sector
    print(anomalies)
    anomaly_list = [CVDP_EM_crop_NA_sector(f, 'msl') for f in anomalies]

    #concatenating the list of data arrays along a new ensemble dimension
    all_anomalies = xr.concat(anomaly_list, dim='ensemble')

    # If only one file and 'ensemble' not present, ensure it has the dimension
    if 'ensemble' not in all_anomalies.dims:
        all_anomalies = all_anomalies.expand_dims('ensemble')

    print(all_anomalies.dims)

    print("Shape:", all_anomalies.shape)
    print("Sizes:", all_anomalies.sizes)
    print("Coords:", list(all_anomalies.coords))

    #Flatten ensemble and time into one time dimension (needs to be called time for the pcs function to work later)
    # NOTE(review): the next four statements are kept exactly as written; how
    # rename/swap_dims treat the leftover coordinate may depend on the
    # installed xarray version - confirm before simplifying.
    # Stack the data into a new dimension
    all_anomalies_stacked = all_anomalies.stack(stacked_time=('ensemble', 'time'))

    # Make 'stacked_time' a coordinate
    all_anomalies_stacked = all_anomalies_stacked.reset_index('stacked_time')

    # Rename the coordinate to 'time' (safe now because it's not a dimension)
    all_anomalies_stacked = all_anomalies_stacked.rename({'stacked_time': 'time'})

    # Swap the new coordinate into the dimension
    all_anomalies_stacked = all_anomalies_stacked.swap_dims({'year': 'time'}) if 'year' in all_anomalies_stacked.dims else all_anomalies_stacked.swap_dims({'time': 'time'})

    # Optional: tag it as a time axis
    all_anomalies_stacked['time'].attrs['axis'] = 'T'

    # Rearranging dimensions: from here on the sample axis is called 'time'
    all_anomalies_stacked = all_anomalies_stacked.transpose('time', 'latitude', 'longitude')

    #basically weighting so that each grid cell has influence actually proportional to its area
    coslat = np.cos(np.deg2rad(all_anomalies_stacked.coords['latitude'].values)).clip(0., 1.)
    wgts = np.sqrt(coslat)[..., np.newaxis]

    print("Shape of all_anomalies_stacked:", all_anomalies_stacked.shape)
    print("Number of NaNs:", np.isnan(all_anomalies_stacked).sum().item())
    print("Total number of elements:", all_anomalies_stacked.size)

    # Optional: show mean or min/max
    print("Min/Max:", all_anomalies_stacked.min().item(), all_anomalies_stacked.max().item())

    # Count grid points where all time values are NaN
    # (the sample axis is 'time' after the transpose above, not 'year')
    grid_mask = all_anomalies_stacked.isnull().all(dim='time')
    num_dead_gridpoints = grid_mask.sum().item()
    total_gridpoints = grid_mask.size

    print(f"Grid points with only NaNs over time: {num_dead_gridpoints} / {total_gridpoints}")

    #EOF solver
    solver = Eof(all_anomalies_stacked, weights=wgts)

    #finding the pattern of the EOF - unitless
    EOF_pattern = solver.eofs(neofs=mode_number+1).sel(mode=mode_number)

    #getting the selected mode's PC
    #using pcscaling=1 for a normalised PC. If not normalised need to divide by the variance of PC ((pc.std(dim='time'))**2) to find the regression map.
    pc = solver.pcs(npcs=mode_number+1, pcscaling=1).sel(mode=mode_number)

    #regression_map = pattern of anomalies regressed onto the PC:
    #how the anomalies change spatially for a one-unit change in the PC
    regression_map = (all_anomalies_stacked * pc).mean(dim='time')

    #The sign of an EOF is arbitrary, so fix it against a reference grid point:
    #NAO must be non-negative near 50N 30W, EA must be non-positive near 50N 25W.
    #Only the check belonging to the requested mode is applied, so one check
    #cannot undo the other.
    if mode == 'NAO':
        flip = bool(regression_map.sel(latitude=50, longitude=-30, method='nearest') < 0)
    else:
        flip = bool(regression_map.sel(latitude=50, longitude=-25, method='nearest') > 0)

    if flip:
        # Flip everything derived from the same EOF together so the saved
        # pattern and regression map share one sign convention.
        regression_map = -regression_map
        pc = -pc
        EOF_pattern = -EOF_pattern

    #outputting .nc files for plotting
    regression_map.name = 'regression_'+mode+'_djf'
    regression_map.to_netcdf(output_regression_map)
    EOF_pattern.name = 'EOF_'+mode+'_djf'
    EOF_pattern.to_netcdf(output_EOF)
    

In [9]:
#functions defined for calculating the projections
def project_onto_regression(trend_raw, regression_map, trend_var, mode, m, period):
    """Project a trend field onto a spatial pattern and split it into two parts.

    Both fields are multiplied by cos(latitude), flattened over
    (latitude, longitude) and NaN-filled with zeros. The scalar index is
    sum(trend * pattern) / sum(pattern ** 2) of those weighted vectors. The
    projection is ``index * regression_map`` and the residual is
    ``trend - projection``.

    Parameters
    ----------
    trend_raw : xarray.DataArray or mapping
        The trend field, or a container from which ``trend_var`` is read.
    regression_map : xarray.DataArray
        Spatial pattern to project onto.
    trend_var : str
        Key used to pull the trend out of ``trend_raw`` when it is not a DataArray.
    mode : str
        Mode label (e.g. 'NAO', 'EA') used in variable and file names.
    m : str
        Label used for the output sub-folder and file names.
    period : str
        Period label used in the file names.

    Returns
    -------
    tuple
        ``(projection, residual)``; both are also written to netCDF.
    """
    if not isinstance(trend_raw, xr.DataArray):
        print('here')
        trend = trend_raw[trend_var]
    else:
        trend = trend_raw

    # cos(latitude) as a column, to account for grid cell area shrinking with latitude.
    coslat_column = np.cos(np.radians(trend["latitude"].values))[:, np.newaxis]

    def _weighted_vector(field):
        # Area-weight, flatten to 1D, and zero out NaNs so they do not affect the sums.
        weighted = field * coslat_column
        return weighted.stack(spatial=('latitude','longitude')).fillna(0)

    trend_vector = _weighted_vector(trend)
    pattern_vector = _weighted_vector(regression_map)

    # Dot product of the two fields, normalised by the pattern's own squared length.
    numerator = (trend_vector * pattern_vector).sum().item()
    denominator = (pattern_vector**2).sum().item()
    index = numerator / denominator

    # Part of the trend explained by the pattern, and what is left over.
    projection = (index * regression_map)
    residual = trend - projection

    projection.name = 'projection_'+mode+'_djf'
    residual.name = 'residual_'+mode+'_djf'

    output_projection = '/gws/nopw/j04/extant/users/slbennie/projection_indicies/NAtlantic_forced_trends/'+m+'/psl_mon_'+m+'_DJF_'+mode+'_projection_'+period+'.nc'
    output_residual = '/gws/nopw/j04/extant/users/slbennie/projection_indicies/NAtlantic_forced_trends/'+m+'/psl_mon_'+m+'_DJF_'+mode+'_residual_'+period+'.nc'

    #outputting .nc files for plotting
    projection.to_netcdf(output_projection)
    residual.to_netcdf(output_residual)

    return projection, residual

In [24]:
# Open the ERA5 file and show its 'msl' variable as the cell output.
# NOTE(review): the name `era5` is re-assigned in other cells of this notebook
# (to a path string and to a different dataset), so what it refers to depends
# on which cell ran last.
era5 = xr.open_dataset('/gws/nopw/j04/leader_epesc/era5/era5.194001-202406.nc')
era5['msl']

In [31]:
#notes: find a nicer way to set `period`.
#also could restructure the folders around a single home directory.
#anomalies have already been calculated separately for each model, so presuming they are all correct this should run.
home = '/gws/nopw/j04/extant/users/slbennie/'
era5 = '/gws/nopw/j04/leader_epesc/era5/era5.194001-202406.nc'
#swap the source root (including its trailing slash) for `home`, which already ends in '/',
#so the derived path does not contain a doubled slash
era5_djf_file = era5.replace('/gws/nopw/j04/leader_epesc/', home)
era5_djf_file = era5_djf_file.replace('era5.', 'era5_djf.')

variable = 'msl'
period = '1850-2015'
modes = ['NAO', 'EA']
seas = 'DJF'
m = 'era5'

#calculate_regression_map(era5, 'NAO', 'ERA5', period)
#calculate_regression_map(era5, 'EA', 'ERA5', period)

#opening historical regression maps for use later - one per model. Use historical for all experiments.
#regression_NAO = xr.open_dataset(home+'regression_patterns/NAO/'+variable+'_mon_historical_'+m+'_'+seas+'_NAO_regression_map_'+period+'.nc')
#regression_EA = (xr.open_dataset(home+'regression_patterns/EA/'+variable+'_mon_historical_'+m+'_'+seas+'_EA_regression_map_'+period+'.nc'))

#No spatial ensemble mean - just one member for ERA5
print('Calculating the seasonal '+seas+' spatial ensemble mean')
calculate_seasonal_spatial_ensemble_mean_djf(era5, variable, seas, era5_djf_file, 1850, 2015)

#calculating the trend
era5_trend_file = home+'trend_calc_LESFMIP/linear_regression/NAO/'+m+'/'+variable+'_mon_'+m+'_'+seas+'_linear_trend_'+period+'_stats.nc'
calculate_linear_trend_spat_pattern(era5_djf_file, variable, era5_trend_file)

#cropping the trend
#The trend file is written on 'latitude'/'longitude' dimensions, so it is cropped with
#CVDP_EM_crop_NA_sector (which uses those names) rather than open_cropNA_unitshPA, which
#selects on 'lat'/'lon'. Only the slope is rescaled; /100 follows the Pa -> hPa
#convention of open_cropNA_unitshPA.
trend = CVDP_EM_crop_NA_sector(era5_trend_file, 'slope') / 100
#NOTE(review): 165 matches the 1850-2015 label, but the ERA5 file name indicates data
#from 1940 onwards, so this scales the fitted slope beyond the fitted record - confirm
#that this is intended.
trend = trend * 165

#projecting
#proj_NAO, residual_NAO = project_onto_regression(trend, regression_NAO['regression_NAO_djf'], 'slope', 'NAO', m, period)
#proj_EA, residual_EA = project_onto_regression(residual_NAO, regression_EA['regression_EA_djf'], 'residual_NAO_djf', 'EA', m, period)

Calculating the seasonal DJF spatial ensemble mean
in function
saved file


NameError: name 'ens_files' is not defined

In [10]:
# Load the saved ERA5 DJF trend statistics and the DJF-mean ERA5 field.
reg = xr.open_dataset('/gws/nopw/j04/extant/users/slbennie/trend_calc_LESFMIP/linear_regression/NAO/era5/msl_mon_era5_DJF_linear_trend_1850-2015_stats.nc')
#reg in Pa per year?
# NOTE(review): calculate_linear_trend_spat_pattern regresses against the position
# along the year axis, so the slope is presumably Pa per time step - confirm.
era5 = xr.open_dataset('/gws/nopw/j04/extant/users/slbennie/era5/era5_djf.194001-202406.nc')

print(reg['slope'],era5['msl'])

#reg['slope'].plot()
#plt.show()

# NOTE(review): this subtracts the slope field itself from every year. If the aim is
# to remove the linear trend, the fitted line (intercept + slope * step) would be the
# quantity to subtract - confirm the intent.
diff = era5['msl'] - reg['slope']

#diff.isel(year=0).plot()

# Compute DJF anomaly for diff
diff_anom = diff - diff.mean(dim='year')

# NOTE(review): `period` is not assigned in this cell, so it must already exist in the
# kernel. Also, calculate_regression_map treats its first argument as a file path or a
# list of file paths, while `diff_anom` is a DataArray - confirm before relying on this.
calculate_regression_map(diff_anom, 'NAO', 'ERA5', period)

<xarray.DataArray 'slope' (latitude: 721, longitude: 1440)> Size: 8MB
[1038240 values with dtype=float64]
Coordinates:
  * latitude   (latitude) float32 3kB 90.0 89.75 89.5 ... -89.5 -89.75 -90.0
  * longitude  (longitude) float32 6kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8 <xarray.DataArray 'msl' (year: 76, latitude: 721, longitude: 1440)> Size: 631MB
[78906240 values with dtype=float64]
Coordinates:
  * longitude  (longitude) float32 6kB 0.0 0.25 0.5 0.75 ... 359.2 359.5 359.8
  * latitude   (latitude) float32 3kB 90.0 89.75 89.5 ... -89.5 -89.75 -90.0
  * year       (year) int64 608B 1940 1941 1942 1943 ... 2012 2013 2014 2015
Attributes:
    standard_name:  air_pressure_at_mean_sea_level
    long_name:      Mean sea level pressure
    units:          Pa


NameError: name 'period' is not defined