# Loading tools

In [None]:
import warnings

# Silence every warning category for the rest of the session. This keeps the
# notebook output readable, but it also hides deprecation notices from the
# scientific stack.
warnings.simplefilter("ignore")

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
Project = 'Africa_SON'   # selects the spatial window below and the output folder
Model = 'UK'             # model folder name used in the paths below
Model2 = 'UK'            # prefix of the raw model file name
Month = 'SON'            # season label used in the input file names
month = '06'             # NOTE(review): not referenced in this part of the notebook
Ens_Memb = 7             # number of ensemble members looped over in later cells

prrr = f'/data/Paper_code_Out/{Project}/outs/{Model}/Tensor_Out/test-1_Predict.pt'
train = f'/data/Paper_code_Out/{Project}/outs/{Model}/Tensor_Out/train-1_Predict.pt'

trrr = f'/data/ERA5/temp/ERA_{Month}_Season.nc'
mrrr = f'/data/{Model}-ALL/out/{Model2}_{Month}_Season.nc'

out_trrr = f'/data/Paper_code_Out/{Project}/outs/{Model}/Tensor_Out/test_Target.pt'
out_mrrr = f'/data/Paper_code_Out/{Project}/outs/{Model}/Tensor_Out/test_Input.pt'
ensmemb = f'/data/Paper_code_Out/{Project}/outs/{Model}/Predict_test_Ensemble1.nc'
Tar = f'/data/Paper_code_Out/{Project}/outs/{Model}/Target_test_Ensemble1.nc'

# Positional (isel) index windows into the latitude/longitude dimensions.
if Project == 'Africa_SON':
    latitude_Range = range(60, 124)
    longitude_Range = range(64)
elif Project == 'US_JJA':
    latitude_Range = range(10, 74)
    longitude_Range = range(230, 294)
else:
    # Without this branch an unknown project only fails much later with a
    # NameError on latitude_Range.
    raise ValueError(
        f"Unknown Project {Project!r}; expected 'Africa_SON' or 'US_JJA'."
    )


# Functions

In [None]:
import matplotlib.pyplot as plt
import xskillscore as xs
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import ListedColormap

def plot_scores(DS1, DS2, DS3, RESULT1, RESULT2, RESULT3, world, vmax, vmin, cmap=None, metric=None):
    """Plot three score maps side by side above a per-region ranking heatmap.

    Parameters
    ----------
    DS1, DS2, DS3 : xarray.DataArray
        2-D score fields with ``latitude`` and ``longitude`` coordinates,
        drawn left to right (SeasonNet, LS, QM in this notebook). Longitudes
        are wrapped to [-180, 180) and sorted. x/y are chosen by coordinate
        name, so the dimension order of the inputs does not matter.
    RESULT1, RESULT2, RESULT3 : pandas.DataFrame
        Region means, each with an ``Acronym`` column and one value column.
        They are inner-joined on ``Acronym``.
    world : geopandas.GeoDataFrame
        Polygons whose boundaries are drawn over each map.
    vmax, vmin : float
        Shared colour limits for the three maps.
    cmap : str or Colormap, optional
        Colormap for the maps.
    metric : str, optional
        Label for the colorbar and heatmap. ``'CRPSS'``, ``'BSS'`` and
        ``'Kendall Corr'`` are ranked higher-is-better; any other metric is
        ranked lower-is-better.

    Notes
    -----
    Region values are rounded to 2 decimals before ranking. When the first
    method exactly ties another, the first method is ranked ahead by an
    infinitesimal tie-break that is applied to a ranking copy only. The
    annotated numbers are always the true rounded values, and a strictly
    better method is never overtaken.
    """
    higher_is_better = metric in ('CRPSS', 'BSS', 'Kendall Corr')

    # Layout: three maps on top, one heatmap spanning the bottom row.
    fig = plt.figure(figsize=(16, 8), dpi=900)
    gs = fig.add_gridspec(2, 3, height_ratios=[3, 1], hspace=0.3, wspace=0.14)
    map_axes = [fig.add_subplot(gs[0, col]) for col in range(3)]
    ax_heatmap = fig.add_subplot(gs[1, 0:3])

    def _draw_map(ax, field):
        """Draw one score map on *ax* with region boundaries and return the artist."""
        field = field.assign_coords(longitude=((field.longitude + 180) % 360) - 180)
        # Sorting keeps pcolormesh well-behaved if the wrap crosses 180 degrees.
        field = field.sortby('longitude')
        artist = field.plot(ax=ax, x='longitude', y='latitude',
                            vmin=vmin, vmax=vmax, cmap=cmap, add_colorbar=False)
        ax.set_title("")
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.tick_params(labelsize=16)
        world.boundary.plot(ax=ax, edgecolor='black')
        return artist

    artists = []
    for ax, field in zip(map_axes, (DS1, DS2, DS3)):
        artists.append(_draw_map(ax, field))

    # Shared colorbar for the maps; all three use the same vmin/vmax/cmap.
    cbar_ax1 = fig.add_axes([0.92, 0.45, 0.02, 0.35])
    cbar1 = fig.colorbar(artists[-1], cax=cbar_ax1, orientation='vertical')
    cbar1.set_label(metric, fontsize=14)
    cbar1.ax.tick_params(labelsize=14)

    # Region table: one row per region, one column per method.
    df = pd.merge(RESULT1, RESULT2, on='Acronym')
    df = pd.merge(df, RESULT3, on='Acronym').set_index('Acronym').round(2)

    # Tie-break on a copy used only for ranking. The nudge is far below the
    # 0.01 resolution of the rounded values, so it only resolves exact ties.
    tie_break = 1e-6 if higher_is_better else -1e-6
    ranking_values = df.copy()
    ranking_values[ranking_values.columns[0]] += tie_break
    rank_df = ranking_values.rank(axis=1, method='dense', ascending=not higher_is_better).T

    if higher_is_better:
        colors = ["#ece7f2", "#a6bddb", "#2b8cbe"]
    else:
        colors = ["#fee8c8", "#fdbb84", "#e34a33"]
    if 3 not in np.unique(rank_df.values):
        colors = colors[:2]  # only two distinct ranks present
    custom_cmap = ListedColormap(colors)

    # Colour encodes rank; the annotations show the real (rounded) scores.
    sns.heatmap(
        rank_df,
        annot=df.T,
        fmt=".2f",
        cmap=custom_cmap,
        linewidths=0.5,
        annot_kws={"size": 12},
        cbar_kws={'label': f"{metric} Rank" if metric else "Rank", 'shrink': 0.8},
        ax=ax_heatmap,
    )

    ax_heatmap.set_title(metric, fontsize=18)
    ax_heatmap.set_xlabel("")
    ax_heatmap.set_ylabel("")
    ax_heatmap.set_xticklabels(ax_heatmap.get_xticklabels(), rotation=45, ha="right", fontsize=14)
    ax_heatmap.set_yticklabels(ax_heatmap.get_yticklabels(), fontsize=14)

    plt.tight_layout(rect=[0, 0, 0.9, 1])
    plt.show()

In [None]:
import geopandas as gpd
import matplotlib.pyplot as plt

# Region boundaries drawn over every score map (passed to plot_scores as `world`).
world = gpd.read_file(filename='/data/IPCC/IPCC.shp')

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

def plot_horizontal_heatmap(result1, result2, result3, Metric="Brier Score"):
    """Draw the three methods' region scores as one wide, annotated heatmap.

    Regions (``Acronym``) run along the x-axis and the merged value columns
    along the y-axis. Each cell is annotated with its value (4 decimals).

    Parameters:
    - result1: DataFrame for 'Predict' (must contain an 'Acronym' column)
    - result2: DataFrame for 'LS' (same layout)
    - result3: DataFrame for 'QM' (same layout)
    - Metric: colorbar label (default: "Brier Score")

    Returns:
    - The matplotlib Axes holding the heatmap.
    """
    # Inner-join the three tables on region, then put regions on the x-axis.
    df = pd.merge(result1, result2, on='Acronym')
    df = pd.merge(df, result3, on='Acronym').set_index('Acronym')
    df_transposed = df.transpose()

    plt.figure(figsize=(12, 3))  # wide and short for the horizontal layout
    ax = sns.heatmap(
        df_transposed,
        annot=True,
        fmt=".4f",
        cmap="plasma_r",
        linewidths=0.5,
        annot_kws={"size": 12},
        cbar_kws={'label': Metric, 'shrink': 0.8},
    )

    # Larger colorbar label and tick fonts.
    cbar = ax.collections[0].colorbar
    cbar.set_label(Metric, fontsize=14)
    cbar.ax.tick_params(labelsize=12)

    plt.xlabel("IPCC Regions", fontsize=14)
    plt.xticks(rotation=45, ha="right", fontsize=12)
    plt.yticks(fontsize=12)

    plt.tight_layout()
    plt.show()
    return ax


In [None]:
import geopandas as gpd
import pandas as pd
import xarray as xr
from shapely.geometry import Point
import numpy as np

def compute_country_averages(nc_file, VAR, shapefile=None, lower_pct=1, upper_pct=99):
    """Average a gridded field within each region after trimming outliers.

    Grid cells are assigned to regions by a point-in-polygon spatial join.
    Within each region, only values between the ``lower_pct`` and
    ``upper_pct`` percentiles (inclusive) are kept before averaging.

    Parameters
    ----------
    nc_file : xarray.DataArray
        Despite the name, an in-memory field, not a path. It must have
        ``latitude`` and ``longitude`` coordinates. If the array is 3-D after
        moving latitude/longitude last, the leading dimension (e.g. time) is
        averaged out with ``nanmean``.
    VAR : str
        Name given to the output value column.
    shapefile : str, optional
        Path to a region shapefile with an ``Acronym`` column. When omitted,
        ``/data/IPCC/IPCC.shp`` is used without the WCA, SAS, GIC and CAR
        regions.
    lower_pct, upper_pct : float, optional
        Percentile bounds for the per-region trimming. The defaults (1, 99)
        are the values this function has always used.

    Returns
    -------
    pandas.DataFrame
        Columns ``Acronym`` and ``VAR``. Only regions that contain at least one
        non-NaN grid cell appear.
    """
    if shapefile:
        world = gpd.read_file(shapefile)
    else:
        ipcc_regions = gpd.read_file('/data/IPCC/IPCC.shp')
        regions_to_remove = ["WCA", "SAS", "GIC", "CAR"]
        world = ipcc_regions[~ipcc_regions["Acronym"].isin(regions_to_remove)]

    DS = nc_file
    # Put latitude/longitude last so the value array lines up with the
    # (lat, lon) meshgrid below. A (longitude, latitude) input on a square
    # grid would otherwise be silently scrambled.
    if {"latitude", "longitude"} <= set(DS.dims):
        DS = DS.transpose(..., "latitude", "longitude")

    lat_vals = DS.latitude.values
    lon_vals = DS.longitude.values
    values = DS.values

    if values.ndim == 3:
        values = np.nanmean(values, axis=0)

    # One (lon, lat, value) row per grid cell, NaN cells dropped.
    lon_grid, lat_grid = np.meshgrid(lon_vals, lat_vals)
    lon_flat, lat_flat, values_flat = lon_grid.ravel(), lat_grid.ravel(), values.ravel()
    valid_mask = ~np.isnan(values_flat)
    lon_flat, lat_flat, values_flat = lon_flat[valid_mask], lat_flat[valid_mask], values_flat[valid_mask]

    df = pd.DataFrame({"lon": lon_flat, "lat": lat_flat, "value": values_flat})
    geometry = [Point(xy) for xy in zip(df["lon"], df["lat"])]
    gdf = gpd.GeoDataFrame(df, geometry=geometry, crs="EPSG:4326")
    if gdf.crs != world.crs:
        gdf = gdf.to_crs(world.crs)

    joined = gpd.sjoin(gdf, world, how="inner", predicate="intersects")

    def filter_percentile(group):
        """Keep rows whose value lies within the [lower_pct, upper_pct] percentiles."""
        low, high = np.percentile(group["value"], [lower_pct, upper_pct])
        return group[group["value"].between(low, high)]

    filtered = joined.groupby("Acronym", group_keys=False).apply(filter_percentile)

    # NaNs were removed above, so the default groupby mean is sufficient.
    # NOTE(review): `skipna=` was dropped because older pandas groupby.mean rejects it.
    country_avg = filtered.groupby("Acronym")["value"].mean().reset_index()
    country_avg.rename(columns={'value': VAR}, inplace=True)

    return country_avg


# Quantile Mapping & Linear Scaling

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
from cmethods import adjust
ds_full = xr.open_dataset('/data/Lands.nc')
# Land mask cropped to the study window (positional indices from the config cell).
ds = ds_full.isel(latitude=latitude_Range, longitude=longitude_Range)

# Raw model hindcasts and ERA5 reference, cropped to the same window.
model = xr.open_dataset(mrrr)
model = model.isel(latitude=latitude_Range, longitude=longitude_Range)

reference = xr.open_dataset(trrr)
reference = reference.isel(latitude=latitude_Range, longitude=longitude_Range)

# Layout: 24 seasons of 90 days. The first 20 seasons calibrate the bias
# correction and the last 4 are the evaluation period.
_N_SEASONS = 24
_N_DAYS = 90
_N_TRAIN = 20

# Flatten (season, time) into one continuous daily axis. The dates are
# synthetic: 90 consecutive days from 1 March of year 1993 + i. They only give
# cmethods a monotonic datetime axis and do not follow the real SON calendar.
List1 = []
List2 = []
for i in range(_N_SEASONS):
    start_date = pd.Timestamp('1993-03-01') + pd.DateOffset(years=i)
    date_range = pd.date_range(start=start_date, periods=_N_DAYS, freq='D')
    List1.append(model.isel(season=i, time=range(_N_DAYS)).assign_coords(time=date_range))
    List2.append(reference.isel(season=i, time=range(_N_DAYS)).assign_coords(time=date_range))

Mod_Cat = xr.concat(List1, dim='time')
Ref_Cat = xr.concat(List2, dim='time')
Mod_Cat['time'] = pd.to_datetime(Mod_Cat['time'].values)
Ref_Cat['time'] = pd.to_datetime(Ref_Cat['time'].values)

variable = "t2m"  # temperatures
_calib_time = range(0 * _N_DAYS, _N_TRAIN * _N_DAYS)
_eval_time = range(_N_TRAIN * _N_DAYS, _N_SEASONS * _N_DAYS)


def _bias_correct_members(method, simp_time):
    """Run cmethods.adjust for every ensemble member and return the list of results.

    obs and simh always cover the calibration period. *simp_time* selects
    the steps of Mod_Cat that are corrected.
    """
    obs = Ref_Cat.isel(time=_calib_time)
    corrected = []
    for num in range(Ens_Memb):
        simh = Mod_Cat.isel(time=_calib_time, number=num)
        simp = Mod_Cat.isel(time=simp_time, number=num)
        corrected.append(adjust(
            method=method,
            obs=obs[variable],
            simh=simh[variable],
            simp=simp[variable],
            n_quantiles=10,
            kind="+",
        ))
    return corrected


def _stack_members(members):
    """Concatenate per-member results, apply the land mask, rename number -> member."""
    stacked = xr.concat(members, dim='number') * ds.mask
    return stacked.rename({"number": "member"})


# Evaluation period (last 4 seasons), corrected with the first 20.
List_LS = _bias_correct_members("linear_scaling", _eval_time)
linear_scaling_all = _stack_members(List_LS)

List_QM = _bias_correct_members("quantile_mapping", _eval_time)
qm_adjusted_all = _stack_members(List_QM)
# reshape_and_assign_time (later cell) reads the time axis of this global.
qm_adjusted = List_QM[-1]

# Calibration period corrected in-sample. QM_Full is the skill-score
# reference in the analysis cells.
List_LS_Full = _bias_correct_members("linear_scaling", _calib_time)
LS_Full = _stack_members(List_LS_Full)

List_QM_Full = _bias_correct_members("quantile_mapping", _calib_time)
QM_Full = _stack_members(List_QM_Full)



# Ensemble Reconstruction

In [None]:
import pandas as pd
import xarray as xr
import numpy as np
import os
import torch
import torchvision.transforms.functional as TF

# Delete the ensemble files written by a previous run of the reconstruction
# cell so they are regenerated from the current tensors. A missing file is
# reported and skipped.
for file_path in (
    f'/data/Paper_code_Out/{Project}/outs/{Model}/Predict_test_Ensemble1.nc',
    f'/data/Paper_code_Out/{Project}/outs/{Model}/Target_test_Ensemble1.nc',
    f'/data/Paper_code_Out/{Project}/outs/{Model}/Train_test_Ensemble1.nc',
    f'/data/Paper_code_Out/{Project}/outs/{Model}/Target_train_test_Ensemble1.nc',
):
    try:
        os.remove(file_path)
        print(f"File '{file_path}' has been removed successfully.")
    except FileNotFoundError:
        print(f"File '{file_path}' does not exist.")
    except PermissionError:
        print(f"Permission denied while trying to delete '{file_path}'.")
    except OSError as e:
        print(f"An error occurred: {e}")

# Member indices used in the per-member tensor file names (df[ii] == ii).
df = [*range(Ens_Memb)]
def _stack_season_time(dataarray, time_coord):
    """Flatten (season, time) into one 'time' axis labelled by *time_coord*.

    Raises ValueError naming both lengths when they disagree, instead of
    xarray's generic size-conflict error.
    """
    reshaped_data = dataarray.stack(new_time=('season', 'time'))
    reshaped_data = reshaped_data.drop_vars(['season', 'time'])
    reshaped_data = reshaped_data.rename({'new_time': 'time'})
    n_stacked, n_coord = reshaped_data.sizes['time'], time_coord.size
    if n_stacked != n_coord:
        raise ValueError(
            f"Stacked data has {n_stacked} time steps but the target time axis has {n_coord}."
        )
    return reshaped_data.assign_coords(time=time_coord)


def reshape_and_assign_time(dataarray, time_coord=None):
    """Flatten an evaluation-period (season, time) array onto one time axis.

    By default the axis is ``qm_adjusted.time`` (the last member's QM output
    from the bias-correction cell). Pass *time_coord* to use another axis.
    """
    if time_coord is None:
        time_coord = qm_adjusted.time
    return _stack_season_time(dataarray, time_coord)


def reshape_and_assign_time2(dataarray, time_coord=None):
    """Flatten a training-period (season, time) array onto one time axis.

    By default the axis is ``QM_Full.time``. Pass *time_coord* to use another axis.

    NOTE(review): the reconstruction cell stacks 16 training seasons
    (16 * 90 steps), while QM_Full is built from 20 seasons (20 * 90). Confirm
    the lengths agree; a mismatch is reported by the length check.
    """
    if time_coord is None:
        time_coord = QM_Full.time
    return _stack_season_time(dataarray, time_coord)

# Rebuild per-member SeasonNet outputs from the saved tensors: mask to land,
# relabel onto the observations' (season, time) layout, and flatten onto a
# continuous time axis. Member-independent inputs (land mask, observed
# targets) are opened and prepared once, not once per member.

# Land mask with dims renamed to lat/lon to match the tensor layout.
ds_full = xr.open_dataset('/data/Lands.nc')
ds = ds_full.rename({'latitude': 'lat', 'longitude': 'lon'})
ds = ds.isel(lat=latitude_Range, lon=longitude_Range)

# Observed targets: seasons 20-23 for evaluation, seasons 0-15 for training.
dss = xr.open_dataset(trrr)
Target = dss.t2m.isel(season=range(20, 24), time=range(90), latitude=latitude_Range, longitude=longitude_Range)
Target = Target.rename({'latitude': 'lat', 'longitude': 'lon'}) * ds.mask
Target_train = dss.t2m.isel(season=range(0, 16), time=range(90), latitude=latitude_Range, longitude=longitude_Range)
Target_train = Target_train.rename({'latitude': 'lat', 'longitude': 'lon'}) * ds.mask

Target = Target.rename({'lat': 'latitude', 'lon': 'longitude'})
Target_train = Target_train.rename({'lat': 'latitude', 'lon': 'longitude'})
Target = Target.isel(time=range(90)).assign_coords(season=range(4))
Target_train = Target_train.isel(time=range(90)).assign_coords(season=range(16))

Target_Season = Target.drop_vars(['number'])
Target_train_Season = Target_train.drop_vars(['number'])
Target_Time = reshape_and_assign_time(Target_Season.isel(time=range(90)))
Target_train_Time = reshape_and_assign_time2(Target_train_Season.isel(time=range(90)))

List1 = []  # flattened test predictions, one per member
List2 = []  # flattened training predictions, one per member
List3 = []  # test predictions on the (season, time) layout
List4 = []  # training predictions on the (season, time) layout
for ii in range(Ens_Memb):
    predict = torch.load(f'/data/Paper_code_Out/{Project}/outs/{Model}/Tensor_Out/test-{df[ii]}_Predict.pt', map_location=torch.device('cpu'))
    train = torch.load(f'/data/Paper_code_Out/{Project}/outs/{Model}/Tensor_Out/train-{df[ii]}_Predict.pt', map_location=torch.device('cpu'))

    Predict = xr.DataArray(predict, dims=('batch', 'time', 'lon', 'lat'))
    train = xr.DataArray(train, dims=('batch', 'time', 'lon', 'lat'))
    Predict = Predict.assign_coords(lat=ds.lat, lon=ds.lon) * ds.mask
    train = train.assign_coords(lat=ds.lat, lon=ds.lon) * ds.mask

    Predict = Predict.rename({'lat': 'latitude', 'lon': 'longitude', 'batch': 'season'})
    train = train.rename({'lat': 'latitude', 'lon': 'longitude', 'batch': 'season'})
    Predict = Predict.isel(time=range(90)).assign_coords(season=range(4))
    train = train.isel(time=range(90)).assign_coords(season=range(16))

    Predict_Season = Predict.drop_vars(['number'])
    train_Season = train.drop_vars(['number'])

    Predict_Time = reshape_and_assign_time(Predict_Season.isel(time=range(90)))
    train_Time = reshape_and_assign_time2(train_Season.isel(time=range(90)))

    List1.append(Predict_Time)
    List2.append(train_Time)
    List3.append(Predict_Season)
    List4.append(train_Season)

Predict_Conc = xr.concat(List1, dim='number')
Train_Conc = xr.concat(List2, dim='number')
Predict_Season = xr.concat(List3, dim='number')
train_Season = xr.concat(List4, dim='number')

Predict_Conc.to_netcdf(f'/data/Paper_code_Out/{Project}/outs/{Model}/Predict_test_Ensemble1.nc')
Target_Time.to_netcdf(f'/data/Paper_code_Out/{Project}/outs/{Model}/Target_test_Ensemble1.nc')
Train_Conc.to_netcdf(f'/data/Paper_code_Out/{Project}/outs/{Model}/Train_test_Ensemble1.nc')
Target_train_Time.to_netcdf(f'/data/Paper_code_Out/{Project}/outs/{Model}/Target_train_test_Ensemble1.nc')


In [None]:
Predict = xr.open_dataset(f'/data/Paper_code_Out/{Project}/outs/{Model}/Predict_test_Ensemble1.nc')
Target = xr.open_dataset(f'/data/Paper_code_Out/{Project}/outs/{Model}/Target_test_Ensemble1.nc')
Train = xr.open_dataset(f'/data/Paper_code_Out/{Project}/outs/{Model}/Train_test_Ensemble1.nc')
Target_train = xr.open_dataset(f'/data/Paper_code_Out/{Project}/outs/{Model}/Target_train_test_Ensemble1.nc')
# Sanity check of the saved files: shape of the unnamed data variable in
# each one, shown as the cell's output.
tuple(d.__xarray_dataarray_variable__.shape for d in (Predict, Target, Train, Target_train))

# Read Data

In [None]:
Predict = xr.open_dataset(f'/data/Paper_code_Out/{Project}/outs/{Model}/Predict_test_Ensemble1.nc')
Target = xr.open_dataset(f'/data/Paper_code_Out/{Project}/outs/{Model}/Target_test_Ensemble1.nc')
Train = xr.open_dataset(f'/data/Paper_code_Out/{Project}/outs/{Model}/Train_test_Ensemble1.nc')
Target_train = xr.open_dataset(f'/data/Paper_code_Out/{Project}/outs/{Model}/Target_train_test_Ensemble1.nc')

# Relabel ensemble members 1..N in both the test and training predictions.
new_channel_names = np.arange(1, len(Predict.number) + 1)
Predict['number'] = new_channel_names
Train['number'] = new_channel_names

new_variable_name = 't2m'


def _first_var_as_t2m(dataset):
    """Return *dataset*'s first data variable, renamed to 't2m', as a DataArray.

    The files were written from unnamed DataArrays, so the stored variable
    name is xarray's placeholder rather than 't2m'.
    """
    first_variable_name = list(dataset.data_vars)[0]
    return dataset.rename({first_variable_name: new_variable_name})[new_variable_name]


# Predictions are laid out (longitude, latitude, time, member); the later
# cells rely on this order (e.g. the DS1.T calls).
Predict = _first_var_as_t2m(Predict).transpose('longitude', 'latitude', 'time', 'number')
Predict = Predict.rename({"number": "member"})
Target = _first_var_as_t2m(Target)

Train = _first_var_as_t2m(Train).transpose('longitude', 'latitude', 'time', 'number')
Train = Train.rename({"number": "member"})
Target_train = _first_var_as_t2m(Target_train)

# Probabilistic Analysis

In [None]:
import matplotlib.pyplot as plt
import xskillscore as xs
import numpy as np
qnt = 0.66

# Per-grid-cell event threshold: the qnt quantile of the training-period
# observations. It is computed once and shared by every score below.
bs_threshold = Target_train.quantile(qnt, dim='time')


def _fair_brier(observed, forecast, threshold):
    """Fair Brier score over time for the event ``value > threshold``.

    Zero scores are masked to NaN and longitudes are wrapped to [-180, 180).
    """
    score = xs.brier_score(observed > threshold, forecast > threshold,
                           dim=['time'], fair=True)
    score = score.where(score != 0, np.nan)
    return score.assign_coords(longitude=((score.longitude + 180) % 360) - 180)


# Plot 1: SeasonNet
DS1 = _fair_brier(Target, Predict, bs_threshold)
# Plot 2: LS
DS2 = _fair_brier(Target, linear_scaling_all.t2m, bs_threshold)
# Plot 3: QM
DS3 = _fair_brier(Target, qm_adjusted_all.t2m, bs_threshold)
# Reference: in-sample QM over the training period.
DS6 = _fair_brier(Target_train, QM_Full['t2m'], bs_threshold)

# Compute the Brier Skill Score (BSS) against the reference.
DS1 = 1 - (DS1 / DS6)
DS2 = 1 - (DS2 / DS6)
DS3 = 1 - (DS3 / DS6)
result1 = compute_country_averages(DS1.T, 'SN')
result2 = compute_country_averages(DS2, 'LS')
result3 = compute_country_averages(DS3, 'QM')
# The plotted quantity is a skill score (higher is better), so it is labelled BSS.
plot_scores(DS1, DS2, DS3, result1, result2, result3, world, vmin=0, vmax=0.5, cmap='PRGn', metric='BSS')

In [None]:
import matplotlib.pyplot as plt
import xskillscore as xs
import numpy as np


def _crps(observed, forecast):
    """Ensemble CRPS over time, with zeros masked to NaN and longitudes wrapped to [-180, 180)."""
    score = xs.crps_ensemble(observed, forecast, dim=['time'])
    score = score.where(score != 0, np.nan)
    return score.assign_coords(longitude=((score.longitude + 180) % 360) - 180)


# Plot 1: SeasonNet
DS1 = _crps(Target, Predict)
# Plot 2: LS
DS2 = _crps(Target, linear_scaling_all.t2m)
# Plot 3: QM
DS3 = _crps(Target, qm_adjusted_all.t2m)
# Reference: in-sample QM over the training period.
DS6 = _crps(Target_train, QM_Full['t2m'])

# Compute CRPSS against the reference.
DS1 = 1 - (DS1 / DS6)
DS2 = 1 - (DS2 / DS6)
DS3 = 1 - (DS3 / DS6)
result1 = compute_country_averages(DS1.T, 'SN')
result2 = compute_country_averages(DS2, 'LS')
result3 = compute_country_averages(DS3, 'QM')
plot_scores(DS1, DS2, DS3, result1, result2, result3, world, vmin=-0.5, vmax=0.5, cmap='PRGn', metric='CRPSS')

# Unit Conversion

In [None]:
import geopandas as gpd
import rasterio
from rasterio import features
import xarray as xr
import pandas as pd
import numpy as np
import xclim as xc

# Kelvin -> degC. Train and QM_Full are overwritten in place, so each is only
# converted while still in Kelvin. This makes the cell safe to re-run.
predtemp = Predict - 273.15
predtemp.attrs['units'] = "degC"

# Assumes the freshly loaded Train carries no units='degC' attribute
# (it is built from raw tensors) — TODO confirm if the files change.
if Train.attrs.get('units') != "degC":
    Train = Train - 273.15
    Train.attrs['units'] = "degC"

Targettemp2 = Target_train - 273.15
Targettemp2.attrs['units'] = "degC"

Targettemp = Target - 273.15
Targettemp.attrs['units'] = "degC"

lstemp = linear_scaling_all - 273.15
lstemp = lstemp.t2m
lstemp.attrs['units'] = "degC"

qmtemp = qm_adjusted_all - 273.15
qmtemp = qmtemp.t2m
qmtemp.attrs['units'] = "degC"

# QM_Full is still a Dataset until it has been converted; after conversion it
# is the t2m DataArray.
if isinstance(QM_Full, xr.Dataset):
    QM_Full = QM_Full - 273.15
    QM_Full = QM_Full.t2m
    QM_Full.attrs['units'] = "degC"



# Deterministic Analysis

In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

# Ensemble means over 'member', laid out as (latitude, longitude, time) so
# that time is the last axis for calculate_rmse.
EnsAvg_pred = predtemp.mean(dim='member').transpose('latitude', 'longitude', 'time')
EnsAvg_ls = lstemp.mean(dim='member').transpose('latitude', 'longitude', 'time')
EnsAvg_qm = qmtemp.mean(dim='member').transpose('latitude', 'longitude', 'time')
TarAvg = Targettemp.transpose('latitude', 'longitude', 'time')

# Function to calculate RMSE
def calculate_rmse(target, prediction):
    """Return the root-mean-square error reduced over the last axis.

    Called here with (latitude, longitude, time) DataArrays, so the result is
    a per-grid-cell RMSE map. Plain NumPy arrays work as well.
    """
    squared_error = np.square(target - prediction)
    return np.sqrt(squared_error.mean(axis=-1))

def _as_lon180_map(field):
    """Wrap an RMSE field on TarAvg's grid as a DataArray with longitudes in [-180, 180)."""
    grid = {'latitude': TarAvg.latitude, 'longitude': TarAvg.longitude}
    wrapped = xr.DataArray(field, dims=['latitude', 'longitude'], coords=grid)
    return wrapped.assign_coords(longitude=((wrapped.longitude + 180) % 360) - 180)


# Per-grid-cell RMSE of each ensemble mean against the observations.
rmse_pred = _as_lon180_map(calculate_rmse(TarAvg, EnsAvg_pred))
rmse_ls = _as_lon180_map(calculate_rmse(TarAvg, EnsAvg_ls))
rmse_qm = _as_lon180_map(calculate_rmse(TarAvg, EnsAvg_qm))

result1 = compute_country_averages(rmse_pred, 'SN')
result2 = compute_country_averages(rmse_ls, 'LS')
result3 = compute_country_averages(rmse_qm, 'QM')
plot_scores(rmse_pred.T, rmse_ls, rmse_qm, result1, result2, result3, world,
            vmin=0, vmax=10, cmap='plasma_r', metric='Bias (RMSE)')

In [None]:
import numpy as np
import xarray as xr
from scipy.signal import detrend

# Function to detrend 1D slices with NaN handling
def detrend_with_nan(array):
    """Remove a least-squares linear trend from a 1-D series that may contain NaNs.

    The trend is fitted only to the finite samples, against their original
    time positions, and subtracted from the whole series. Non-finite positions
    are returned as NaN. Filling the gaps with the series mean before fitting
    (the previous approach) biased the fitted slope toward zero.

    Parameters
    ----------
    array : array_like
        1-D series (one grid cell's time series when called via apply_ufunc).

    Returns
    -------
    numpy.ndarray
        The detrended series. An all-non-finite input is returned unchanged.
        With a single finite value, only its mean (the value itself) is removed.
    """
    array = np.asarray(array)
    mask = np.isfinite(array)
    if not np.any(mask):
        return array

    positions = np.arange(array.shape[0], dtype=float)
    valid = array[mask].astype(float)
    if valid.size < 2:
        # A line is undetermined by one point; remove its level only.
        slope, intercept = 0.0, float(valid.mean())
    else:
        slope, intercept = np.polyfit(positions[mask], valid, 1)

    detrended = array.astype(float) - (slope * positions + intercept)
    detrended[~mask] = np.nan
    return detrended

# Apply the detrending to the DataArray
def detrend_dataarray(dataarray):
    """Detrend every 1-D series of *dataarray* along its 'time' dimension.

    Each series is passed to ``detrend_with_nan`` through ``xr.apply_ufunc``
    with ``vectorize=True``. 'time' is both the input and output core
    dimension, so the result keeps the input's dimensions.
    """
    ufunc_options = dict(
        input_core_dims=[["time"]],
        output_core_dims=[["time"]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[dataarray.dtype],
    )
    return xr.apply_ufunc(detrend_with_nan, dataarray, **ufunc_options)
import xarray as xr
from scipy.stats import kendalltau

def kendall_rank_corr(da1, da2):
    """
    Compute the Kendall rank correlation and its p-value at every grid point.

    For each spatial location of ``da1``, the ``time`` series of ``da1`` is
    correlated with the ``time`` series of ``da2`` at the nearest matching
    location (label-based, ``method='nearest'``). Time steps are paired by
    position, not by time label.

    Parameters:
        da1 (xarray.DataArray): First data array; must have a ``time``
            dimension plus one or more spatial dimensions.
        da2 (xarray.DataArray): Second data array with the same shape and
            dimension names as ``da1``.

    Returns:
        corr (xarray.DataArray): Kendall tau on ``da1``'s spatial grid, named
            ``'kendall_rank_corr'``.
        p_values (xarray.DataArray): The matching p-values, named
            ``'p_value'``.

        Both outputs carry ``da1``'s spatial coordinates and attributes,
        including the scalar ``time`` coordinate of its first step. Values
        come from ``scipy.stats.kendalltau`` with its default options, so any
        NaN in a cell's series gives NaN for that cell.

    Raises:
        ValueError: If ``da1`` and ``da2`` do not have the same shape.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if da1.shape != da2.shape:
        raise ValueError("Data arrays must have the same shape")

    time_dim = "time"
    space_dims = [dim for dim in da1.dims if dim != time_dim]

    # Take da2's series at the grid points nearest to da1's coordinate labels,
    # so both arrays line up cell-for-cell (the same matching the original
    # per-point `.sel(..., method='nearest')` performed, done once).
    da2_on_grid = da2.sel({dim: da1[dim].values for dim in space_dims}, method="nearest")

    # Move time to the last axis so series[cell] is one grid cell's series.
    series1 = da1.transpose(*space_dims, time_dim).values
    series2 = da2_on_grid.transpose(*space_dims, time_dim).values

    grid_shape = series1.shape[:-1]
    tau = np.full(grid_shape, np.nan)
    pval = np.full(grid_shape, np.nan)
    for cell in np.ndindex(*grid_shape):
        tau[cell], pval[cell] = kendalltau(series1[cell], series2[cell])

    # da1 without its time axis supplies dims (in da1's order), coords and attrs.
    template = da1.isel({time_dim: 0})
    corr_values = template.copy(data=tau)
    p_values = template.copy(data=pval)

    corr_values.name = 'kendall_rank_corr'
    p_values.name = 'p_value'

    return corr_values, p_values

def _wrap_longitude(da):
    """Return ``da`` with its longitude labels remapped into [-180, 180)."""
    return da.assign_coords(longitude=((da.longitude + 180) % 360) - 180)


# Ensemble means (and the target) with dims ordered (latitude, longitude, time).
EnsAvg_pred = predtemp.mean('member').transpose('latitude', 'longitude', 'time')
EnsAvg_ls = lstemp.mean('member').transpose('latitude', 'longitude', 'time')
EnsAvg_qm = qmtemp.mean('member').transpose('latitude', 'longitude', 'time')
TarAvg = Targettemp.transpose('latitude', 'longitude', 'time')

# Remove the linear trend along time before ranking.
EnsAvg_pred_det = detrend_dataarray(EnsAvg_pred)
EnsAvg_ls_det = detrend_dataarray(EnsAvg_ls)
EnsAvg_qm_det = detrend_dataarray(EnsAvg_qm)
TarAvg_det = detrend_dataarray(TarAvg)

# Remap BOTH tau and p-values so they stay aligned on the same longitude
# labels. Previously only tau was remapped, which split the pair whenever
# longitudes exceed 180. The remap is idempotent, so re-applying it is harmless.
Pred_Kendal, Pred_pval = (_wrap_longitude(da) for da in kendall_rank_corr(EnsAvg_pred_det, TarAvg_det))
LS_Kendal, LS_pval = (_wrap_longitude(da) for da in kendall_rank_corr(EnsAvg_ls_det, TarAvg_det))
QM_Kendal, QM_pval = (_wrap_longitude(da) for da in kendall_rank_corr(EnsAvg_qm_det, TarAvg_det))

result1 = compute_country_averages(Pred_Kendal, 'SN')
result2 = compute_country_averages(LS_Kendal, 'LS')
result3 = compute_country_averages(QM_Kendal, 'QM')
plot_scores(Pred_Kendal.T, LS_Kendal, QM_Kendal, result1, result2, result3, world, vmin=-0.3, vmax=0.3, cmap='BrBG', metric='Kendall Corr')


# Impact Based Analysis Indices

In [None]:
from xclim.core.calendar import percentile_doy
from xclim.indices import warm_spell_duration_index
import matplotlib.pyplot as plt
qnt = 90
qntp = 0.9
Window = 3
tasmax2 = Targettemp2
years = [0, 4, 8, 12]
years2 = list(range(0, 61, 4))

# The WSDI threshold is the `qntp` quantile of tasmax2 over time at each grid
# point. It does not depend on the dataset being scored, so compute it once
# instead of once per dataset per member.
_target_quantile = tasmax2.quantile(q=[qntp], dim='time')


def _heat_indices(tasmax):
    """
    Compute monthly ('1MS') WSDI and warm-day-frequency for one tasmax array.

    The fixed ``_target_quantile`` threshold is broadcast along a new
    ``dayofyear`` axis. Its labels are those of
    ``percentile_doy(tasmax, per=qnt)`` with the last label dropped, so the
    axis follows ``tasmax``'s calendar.

    Returns ``(wsdi, wdf, tasmax_pct, threshold)``. The last two are returned
    so the module-level ``tasmin_q`` / ``Qunt_max`` names remain available.
    """
    tasmax_pct = xc.core.calendar.percentile_doy(tasmax, per=qnt).sel(percentiles=qnt)
    doy = tasmax_pct.dayofyear[:-1]
    # Size the new axis from the labels assigned to it rather than a
    # hard-coded 90, so the two cannot disagree.
    threshold = _target_quantile.expand_dims(dim={"dayofyear": doy.size})
    threshold = threshold.assign_coords(dayofyear=doy)
    threshold.attrs['units'] = "degC"  # quantile() drops attrs, so set units explicitly
    wsdi = xc.indices.warm_spell_duration_index(tasmax, threshold, window=Window, freq='1MS')
    wdf = xc.indices.warm_day_frequency(tasmax, thresh='35 degC', freq='1MS')
    return wsdi, wdf, tasmax_pct, threshold


(List1, List2, List3, List4, List5, List6,
 List7, List8, List9, List10, List11, List12) = ([] for _ in range(12))

for i in range(Ens_Memb):
    WSDI_seaosnnet, WDF_seaosnnet = _heat_indices(predtemp.T.isel(member=i))[:2]  # seasonnet
    WSDI_Train, WDF_Train = _heat_indices(Train.T.isel(member=i))[:2]  # seasonnet train
    WSDI_ls, WDF_ls = _heat_indices(lstemp.T.isel(member=i))[:2]  # linear scaling
    WSDI_qm, WDF_qm = _heat_indices(qmtemp.T.isel(member=i))[:2]  # quantile mapping
    WSDI_LS, WDF_LS = _heat_indices(LS_Full.T.isel(member=i))[:2]  # linear scaling Full
    WSDI_QM, WDF_QM = _heat_indices(QM_Full.T.isel(member=i))[:2]  # quantile mapping Full

    List1.append(WSDI_seaosnnet.isel(time=years, quantile=0))
    List2.append(WDF_seaosnnet.isel(time=years))

    List3.append(WSDI_ls.isel(time=years, quantile=0))
    List4.append(WDF_ls.isel(time=years))

    List5.append(WSDI_qm.isel(time=years, quantile=0))
    List6.append(WDF_qm.isel(time=years))

    List7.append(WSDI_Train.isel(time=years2, quantile=0))
    List8.append(WDF_Train.isel(time=years2))

    # NOTE(review): LS_Full results are sliced with `years`, while QM_Full
    # (below), Train and the full target use `years2`. Verify which index set
    # LS_Full is meant to use.
    List9.append(WSDI_LS.isel(time=years, quantile=0))
    List10.append(WDF_LS.isel(time=years))

    List11.append(WSDI_QM.isel(time=years2, quantile=0))
    List12.append(WDF_QM.isel(time=years2))

# Stack each method's per-member results along a new 'member' dimension.
(Heat1, Heat2, Heat3, Heat4, Heat5, Heat6,
 Heat7, Heat8, Heat9, Heat10, Heat11, Heat12) = [
    xr.concat(members, dim='member')
    for members in (List1, List2, List3, List4, List5, List6,
                    List7, List8, List9, List10, List11, List12)
]

# Target Data
WSDI_tar, WDF_tar = _heat_indices(Targettemp)[:2]
WSDI_tar = WSDI_tar.isel(time=years, quantile=0)
WDF_tar = WDF_tar.isel(time=years)

# Full Target Data (its intermediates are kept under the original scratch names)
tasmax = Targettemp2
WSDI_Ftar, WDF_Ftar, tasmin_q, Qunt_max = _heat_indices(tasmax)
WSDI_Ftar = WSDI_Ftar.isel(time=years2, quantile=0)
WDF_Ftar = WDF_Ftar.isel(time=years2)


In [None]:
import matplotlib.pyplot as plt
import xskillscore as xs
import numpy as np

# Warm-day frequency: per-cell RMSE over the `years` steps of SeasonNet (DS1),
# linear scaling (DS2) and quantile mapping (DS3) against the target. Cells
# where Targettemp2's first step is NaN are masked, and longitudes are
# remapped into [-180, 180).
#
# NOTE(review): WDF_seaosnnet / WDF_ls / WDF_qm hold whatever the last pass of
# the ensemble-member loop left in them (member Ens_Memb - 1), not an ensemble
# statistic. Confirm that this is the intended comparison.
DS1, DS2, DS3 = [
    xs.rmse(WDF_tar.drop_vars("time"), method_wdf.isel(time=years).drop_vars("time"), dim=['time'])
    .where(~np.isnan(Targettemp2.isel(time=0)), np.nan)
    .pipe(lambda da: da.assign_coords(longitude=((da.longitude + 180) % 360) - 180))
    for method_wdf in (WDF_seaosnnet, WDF_ls, WDF_qm)
]

result1, result2, result3 = [
    compute_country_averages(score_map, tag)
    for score_map, tag in ((DS1, 'SN'), (DS2, 'LS'), (DS3, 'QM'))
]
plot_scores(DS1.T, DS2, DS3, result1, result2, result3, world, vmin=0, vmax=10, cmap='plasma_r', metric='RMSE0')



In [None]:
import matplotlib.pyplot as plt
import xskillscore as xs
import numpy as np

# Warm Spell Duration Index: per-cell RMSE over the `years` steps of SeasonNet
# (DS1), linear scaling (DS2) and quantile mapping (DS3) against the target,
# using the first `quantile` entry. Cells where Targettemp2's first step is
# NaN are masked, and longitudes are remapped into [-180, 180).
#
# NOTE(review): WSDI_seaosnnet / WSDI_ls / WSDI_qm hold whatever the last pass
# of the ensemble-member loop left in them (member Ens_Memb - 1), not an
# ensemble statistic. Confirm that this is the intended comparison.
DS1, DS2, DS3 = [
    xs.rmse(WSDI_tar.drop_vars("time"), method_wsdi.isel(time=years, quantile=0).drop_vars("time"), dim=['time'])
    .where(~np.isnan(Targettemp2.isel(time=0)), np.nan)
    .pipe(lambda da: da.assign_coords(longitude=((da.longitude + 180) % 360) - 180))
    for method_wsdi in (WSDI_seaosnnet, WSDI_ls, WSDI_qm)
]

result1, result2, result3 = [
    compute_country_averages(score_map, tag)
    for score_map, tag in ((DS1, 'SN'), (DS2, 'LS'), (DS3, 'QM'))
]
plot_scores(DS1.T, DS2, DS3, result1, result2, result3, world, vmin=0, vmax=10, cmap='plasma_r', metric='RMSE0')

