In [None]:
import xarray as xr
import numpy as np
import warnings
import matplotlib.pyplot as plt
import pandas as pd

warnings.filterwarnings("ignore")
plt.style.use("default")
from pathlib import Path
import cftime
import os, sys
import seaborn as sns
import cartopy
import cartopy.feature as cpf
from global_land_mask import globe
import CMIP6_light_map
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from shapely.geometry import box, mapping
import geopandas as gpd
import rioxarray
from matplotlib import cm
import cartopy.feature as cpf
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import texttable
from tqdm.notebook import trange, tqdm

sys.path.append("../CMIP6_downscale/")
from CMIP6_ridgeplot import CMIP6_ridgeplot
from xclim import ensembles
from CMIP6_IO import CMIP6_IO
from CMIP6_config import Config_albedo

In [None]:
def convert_to_180(ds):
    """Add ``lat``/``lon`` coordinates built from ``y``/``x`` and sort by ``lon``.

    ``lat`` takes the values of ``ds.y``.  ``lon`` is ``ds.x`` wrapped with
    ``((x + 180) % 360) - 180``, i.e. mapped into the range [-180, 180).
    The returned object is sorted along the new ``lon`` coordinate.
    """
    with_lat = ds.assign_coords(lat=ds.y)
    wrapped_lon = ((with_lat.x + 180) % 360) - 180
    with_lon = with_lat.assign_coords(lon=wrapped_lon)
    return with_lon.sortby("lon")


def convert_time(ds):
    """Return ``ds`` with a ``datetime64[ns]`` time index.

    A dataset whose ``time`` index already has dtype ``datetime64[ns]`` is
    returned untouched.  Otherwise the index is converted with
    ``to_datetimeindex()``, assigned back as the ``time`` coordinate, and the
    result is passed through ``xr.decode_cf``.
    """
    if ds.indexes["time"].dtype in ["datetime64[ns]"]:
        return ds

    as_datetime = ds.indexes["time"].to_datetimeindex()
    converted = ds.assign_coords({"time": as_datetime})
    return xr.decode_cf(converted)


def get_area_averaged_ds(
    fname,
    model,
    scenario,
    ensemble_id,
    var_name,
    LME,
    create_maps,
    frequency,
    models_dict,
    fname2=None,
):
    """Build a resampled, area-averaged time series for one model run in an LME.

    Parameters
    ----------
    fname : str
        NetCDF file to read.  When the file does not exist nothing is computed.
    model, scenario, ensemble_id : str
        Identifiers copied into the returned frame and into ``models_dict``.
    var_name : str
        Variable to average.  ``"velocity"`` is special-cased: ``uas`` and
        ``vas`` are read from ``fname`` and ``fname2`` (opened together) and
        combined into ``sqrt(uas**2 + vas**2)``.
    LME : str
        LME name forwarded to ``get_data_within_LME`` for the spatial clip.
    create_maps : bool
        Forwarded to ``get_data_within_LME``.
    frequency : str
        Pandas resampling rule applied to the area-averaged series.
    models_dict : dict
        Updated in place with a metadata entry describing this run.
    fname2 : str, optional
        Second file, only used when ``var_name == "velocity"``.

    Returns
    -------
    tuple
        ``(df, models_dict, None)``.  ``df`` is ``None`` when ``fname`` does
        not exist.  The third element is always ``None``.

    Notes
    -----
    The time window comes from the module-level ``start_time`` and
    ``end_time`` variables, which must be defined before this is called.
    """
    if not os.path.exists(fname):
        return None, models_dict, None

    if var_name not in ["velocity"]:
        with xr.open_dataset(fname) as ds:
            ds = convert_to_180(ds)
            ds = ds.sel(time=slice(start_time, end_time))
            ds = convert_time(ds)

            # Values above 100 are treated as Kelvin and shifted to Celsius.
            if var_name in ["tas"]:
                ds = xr.where(ds > 100, ds - 273.15, ds)

            # Convert from kg/m-3 to mg/m-3
            # NOTE(review): kg -> mg would be a multiplication by 1e6, yet the
            # values are divided by 1e6 here. Left unchanged - confirm the
            # units of the input files.
            if var_name in ["chl"]:
                ds[var_name] = ds[var_name] / 1.0e6

            ds_lme = get_data_within_LME(ds, var_name, LME, create_maps)

            outfile = "Figures/{}_ensemble_{}_{}_ridgeplot.png".format(
                var_name.capitalize(), scenario, LME
            )
            print(f"Creating ridgeplot: {outfile}")

            #  CMIP6_ridgeplot.ridgeplot(var_name,
            #                            None, outfile,
            #                            glorys=False, depth_threshold=None,
            #                            ds=ds_lme)

            # Spatial mean, plus the spatial standard deviation of the
            # variable (stored under the fixed name "TOZ_std").
            ds = ds_lme.mean(["lat", "lon"])
            ds = ds.assign(TOZ_std=ds_lme[var_name].std(["lat", "lon"]))
            df = ds.to_dataframe().dropna().reset_index()
    else:
        with xr.open_mfdataset([fname, fname2]) as ds:
            ds = convert_to_180(ds)
            ds = ds.sel(time=slice(start_time, end_time))
            ds = convert_time(ds)

            ds_uas = get_data_within_LME(ds, "uas", LME, create_maps)
            ds_vas = get_data_within_LME(ds, "vas", LME, create_maps)
            df = ds_uas.mean(["lat", "lon"]).to_dataframe().dropna().reset_index()
            df2 = ds_vas.mean(["lat", "lon"]).to_dataframe().dropna().reset_index()

            df["velocity"] = np.sqrt(np.power(df["uas"], 2) + np.power(df2["vas"], 2))
            # The original reset the index a second time here, which added a
            # spurious "index" column that was then averaged by the resampling.
            df = df.drop(columns=["uas", "vas"])

    df = df.set_index("time").resample(frequency).mean()
    df["model_name"] = model
    df["LME"] = LME

    # 5-step rolling statistics, shifted one step back in time.
    # The series is selected with a plain label: indexing with a set
    # (``df[{var_name}]``) raises TypeError on pandas >= 2.0.
    rolling = df[var_name].rolling(5)
    df["roll_mean"] = rolling.mean().shift(-1)
    df["roll_std"] = rolling.std().shift(-1)
    df["roll_median"] = rolling.median().shift(-1)
    df["roll_max"] = rolling.max().shift(-1)
    df["roll_min"] = rolling.min().shift(-1)

    df["model_ensemble_id"] = ensemble_id
    df["model_scenario"] = scenario
    df["unique"] = "{}_{}_{}".format(model, scenario, ensemble_id)

    # Register the run. TOZ files are keyed by file name, everything else by
    # model / member / scenario / variable.
    # Min/max of the series are not recorded (that code was disabled).
    key = "{}_{}_{}_{}".format(model, ensemble_id, scenario, var_name)
    if var_name in ["TOZ"]:
        key = fname
    models_dict[key] = {
        "model_name": model,
        "model_scenario": scenario,
        "model_ensemble_id": ensemble_id,
        "model_var": var_name,
        "LME": LME,
    }

    return df, models_dict, None

In [None]:
def get_LME_records(
    lme_file="/Users/trondkr/Library/CloudStorage/Dropbox/NIVA/oceanography/Shapefiles/LME66/LMEs66.shp",
):
    """Read the LME shapefile used for clipping into a GeoDataFrame.

    Parameters
    ----------
    lme_file : str, optional
        Path of the shapefile.  Defaults to the location used originally, so
        existing calls without arguments behave as before; pass another path
        to run the notebook on a different machine.

    Returns
    -------
    The object returned by ``geopandas.read_file``.
    """
    return gpd.read_file(lme_file)


def get_LME_records_plot(
    lme_file="/Users/trondkr/Library/CloudStorage/Dropbox/NIVA/oceanography/Shapefiles/LME66_180/LME66_180.shp",
):
    """Read the LME shapefile used for plotting into a GeoDataFrame.

    Parameters
    ----------
    lme_file : str, optional
        Path of the shapefile.  Defaults to the location used originally, so
        existing calls without arguments behave as before; pass another path
        to run the notebook on a different machine.

    Returns
    -------
    The object returned by ``geopandas.read_file``.
    """
    return gpd.read_file(lme_file)


def create_colors(N):
    """Return a list of ``N`` colours sampled from the ``tab20b`` colormap.

    The colormap is evaluated at ``N`` evenly spaced positions between 0 and 1;
    each list element is one row of the result.
    """
    sampled = cm.tab20b(np.linspace(0, 1, N))
    return [rgba for rgba in sampled]


def create_map(df, title, var_name, period, anomalies=False, details=False):
    """Draw a filled-contour map of ``df[var_name]`` and show it.

    Parameters
    ----------
    df : mapping-like
        Must provide ``df["lon"]``, ``df["lat"]`` and ``df[var_name]``; they
        are passed directly to ``contourf``.
    title : str
        Only used to decide whether a colorbar is drawn: no colorbar is added
        when ``title`` is ``"Bathymetry"``.
    var_name : str
        Name of the field to contour.
    period : str
        Currently unused.
    anomalies : bool, optional
        Currently unused.
    details : bool, optional
        ``True`` selects the smaller map window with 10m coastlines; otherwise
        the wider window with 50m coastlines is used.
    """
    if details is True:
        lonmin, lonmax = -165, -143.5
        latmin, latmax = 53.5, 65.0
        res = "10m"
    else:
        lonmin, lonmax = -252, -100.5
        latmin, latmax = 20, 80
        res = "50m"

    # ``Figure.gca(projection=...)`` is no longer accepted by matplotlib >= 3.6,
    # so the axes is created explicitly with ``add_subplot``.
    fig = plt.figure(figsize=(16, 10))
    ax = fig.add_subplot(
        1, 1, 1, projection=cartopy.crs.PlateCarree(central_longitude=-180)
    )

    ax.coastlines(resolution=res, linewidth=0.6, color="black", alpha=0.8, zorder=4)
    ax.add_feature(cpf.BORDERS, linestyle=":", alpha=0.4)
    ax.add_feature(cpf.LAND, color="lightgrey")
    ax.set_extent([lonmin, lonmax, latmin, latmax])

    # Five evenly spaced ticks along each axis of the map window.
    ax.set_xticks(np.linspace(lonmin, lonmax, 5), crs=cartopy.crs.PlateCarree())
    ax.set_yticks(np.linspace(latmin, latmax, 5), crs=cartopy.crs.PlateCarree())
    ax.xaxis.set_major_formatter(LongitudeFormatter(zero_direction_label=True))
    ax.yaxis.set_major_formatter(LatitudeFormatter())

    cs = ax.contourf(
        df["lon"],
        df["lat"],
        df[var_name],
        cmap=sns.color_palette("Spectral_r", as_cmap=True),
        transform=ccrs.PlateCarree(),
    )

    if title not in ["Bathymetry"]:
        fig.colorbar(cs, ax=ax, shrink=0.5, extend="both")

    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")

    plt.show()


def create_LME_figure(ax, LMES, projection, show, extent):
    """Draw all LME polygons on ``ax``, highlighting the ones named in ``LMES``.

    Parameters
    ----------
    ax : cartopy GeoAxes
        Axes to draw on.
    LMES : list of str
        Names of the LMEs to highlight.  A single LME is drawn in red; several
        LMEs get colours from ``create_colors``.  All other LMEs are light grey.
    projection : cartopy CRS
        CRS handed to ``ax.add_geometries`` for the shapefile geometries.
    show : bool
        When true the figure is saved to
        ``Figures/CMIP6_lightpaper_map_<first LME>.png`` and shown.
    extent : sequence
        Map extent passed to ``ax.set_extent``.
    """
    ax.add_feature(cfeature.LAND)
    ax.add_feature(cfeature.COASTLINE)

    ax.set_extent(extent)

    # Get the -180-180 projected shapefile containing LMEs to make it
    # easy to plot across the Pacific Ocean
    shdf = get_LME_records_plot()
    colors_rgb = create_colors(len(LMES))
    counter = 0
    for LME_NAME in shdf["LME_NAME"]:
        shdf_sel = shdf[shdf["LME_NAME"] == LME_NAME]

        if LME_NAME in LMES:
            if len(LMES) == 1:
                color = "red"
            else:
                # Wrap around so repeated names in the shapefile cannot index
                # past the end of the colour list.
                color = colors_rgb[counter % len(colors_rgb)]
            counter += 1
        else:
            color = "LightGray"

        ax.add_geometries(
            shdf_sel["geometry"], projection, facecolor=color, edgecolor="k"
        )

    if show:
        plotfile = "Figures/CMIP6_lightpaper_map_{}.png".format(LMES[0])
        # savefig does not create missing directories.
        os.makedirs(os.path.dirname(plotfile), exist_ok=True)
        ax.figure.savefig(plotfile, dpi=200, bbox_inches="tight")
        print("Created figure {}".format(plotfile))
        plt.show()


def get_data_within_LME(ds, var_name, LME, create_maps):
    """Clip ``ds`` to the polygon of one LME.

    Parameters
    ----------
    ds : xarray object
        Data with ``lon`` and ``lat`` names; these are renamed to ``x``/``y``
        for rioxarray and renamed back on the result.
    var_name : str
        Currently unused: the whole of ``ds`` is clipped, not a single
        variable.  Kept so existing callers do not change.
    LME : str
        Value of ``LME_NAME`` in the LME shapefile selecting the polygon.
    create_maps : bool
        When true, an overview map highlighting the LME is drawn (and saved)
        through ``create_LME_figure`` before clipping.

    Returns
    -------
    The clipped data, with ``lon``/``lat`` names restored.

    Raises
    ------
    ValueError
        If ``LME`` does not match any ``LME_NAME`` in the shapefile.

    Notes
    -----
    The original also contained a period-mean map (``create_map`` for
    2000-2020) behind a flag that was unconditionally reset to ``False``; that
    unreachable branch has been removed.  Call ``create_map`` on the returned
    data to produce such a map.
    """
    print("Working on LME: {}".format(LME))

    # Extract the polygon defining the boundaries of the LME
    shdf = get_LME_records()
    shdf_sel = shdf[shdf["LME_NAME"] == LME]
    if shdf_sel.empty:
        # Fail early with a readable message instead of an obscure clip error.
        raise ValueError("LME '{}' was not found in the LME shapefile".format(LME))

    # Create the map of the LME boundaries and color it.
    # The active LME has color while the others are grey.
    if create_maps:
        fig = plt.figure(figsize=(13, 8))
        if LME in ["Barents Sea", "Arctic Ocean"]:
            projection = ccrs.NorthPolarStereo()
            extent = [-20, 90, 60, 90]
        else:
            projection = ccrs.PlateCarree(central_longitude=-180)
            extent = [-200, -145, 40, 80]
        ax1 = fig.add_subplot(111, projection=projection)

        create_LME_figure(
            ax1, [LME], ccrs.PlateCarree(central_longitude=-180), True, extent
        )

    # Rioxarray requires x and y dimensions - we convert these back to lon and
    # lat afterwards. The CRS (EPSG:4326, lat-lon) is written so that rioxarray
    # can clip the data against the shapefile geometry.
    tos = ds.rename({"lon": "x", "lat": "y"})
    tos = tos.rio.write_crs(4326)

    # Clip the data within the LME. The polygon geometry is converted to
    # GeoJSON-like mappings with `shapely.geometry.mapping`.
    clipped = tos.rio.clip(geometries=shdf_sel.geometry.apply(mapping), crs=tos.rio.crs)
    return clipped.rename({"x": "lon", "y": "lat"})

In [None]:
def create_summary_table(dict_of_models, LME):
    """Print a text table with one row per entry of ``dict_of_models``.

    Parameters
    ----------
    dict_of_models : dict
        Maps an arbitrary key to a dict with the entries ``model_name``,
        ``model_scenario``, ``model_ensemble_id`` and ``model_var``.  The
        optional entries ``model_min`` and ``model_max`` fill the
        "CMIP6 min" / "CMIP6 max" columns; when absent ``"n/a"`` is shown.
        (Previously those two columns repeated the variable name.)
    LME : str
        Shown in the first column of every row.
    """
    table = texttable.Texttable()
    table.set_cols_align(["c", "c", "c", "c", "c", "c", "c"])
    table.set_cols_valign(["t", "t", "m", "m", "m", "m", "b"])

    table.header(["LME", "Model", "Scenario", "ID", "Var", "CMIP6 min", "CMIP6 max"])
    for model in dict_of_models.values():
        table.add_row(
            [
                LME,
                model["model_name"],
                model["model_scenario"],
                model["model_ensemble_id"],
                str(model["model_var"]),
                str(model.get("model_min", "n/a")),
                str(model.get("model_max", "n/a")),
            ]
        )

    table.set_cols_width([30, 30, 20, 20, 10, 10, 10])
    print(table.draw() + "\n")

In [None]:
def write_netcdf(ds: "xr.Dataset", out_file: str) -> None:
    """Write ``ds`` to ``out_file`` as compressed NETCDF4.

    Every data variable with at least two dimensions is written with zlib
    compression (level 3), Fletcher32 checksums and chunks of half the
    variable's size along each dimension.  Variables with fewer than two
    dimensions get no explicit encoding.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to write.
    out_file : str
        Destination path.
    """
    enc = {}

    for name in ds.data_vars:
        var = ds[name]
        if var.ndim < 2:
            continue

        enc[name] = {
            "zlib": True,
            "complevel": 3,
            "fletcher32": True,
            # A chunk size of 0 is invalid, and ``size // 2`` is 0 for any
            # dimension of length 1 (e.g. a one-member ensemble), so clamp
            # to at least 1.
            "chunksizes": tuple(max(1, size // 2) for size in var.shape),
        }

    ds.to_netcdf(out_file, format="NETCDF4", engine="netcdf4", encoding=enc)

In [None]:
%%time
scenarios=["ssp245","ssp585"]
member_range=10
frequency="A"

# Create objects to reuse functionality to open files on Google Cloud storage
io = CMIP6_IO()
config = Config_albedo()

#ensemble_ids = ["r10i1p1f1", "r4i1p1f1", "r10i1p2f1", "r3i1p2f1", "r2i1p1f2",#"r4i1p1f2","r2i1p1f1"]
period="1950-01-01-2099-12-16"
start_time="1950-01-01"
end_time="2099-12-16"

# FOR RTM we are using:
# "CMCC-ESM2": ["r1i1p1f1","r1i1p2f1"]
# "CanESM5":  ["r1i1p2f1","r2i1p2f1","r9i1p2f1","r10i1p2f1","r7i1p2f1"]
# "MPI-ESM1-2-LR": ["r10i1p1f1","r1i1p1f1","r4i1p1f1","r2i1p1f1"]
# "UKESM1-0-LL": ["r1i1p1f2","r2i1p1f2","r3i1p1f2","r4i1p1f2","r8i1p1f2"]
# "MPI-ESM1-2-HR": ["r1i1p1f1","r2i1p1f1"]
        
models=["CMCC-ESM2","CanESM5", "MPI-ESM1-2-LR", "UKESM1-0-LL", "MPI-ESM1-2-HR"]
ensemble_ids=[["r1i1p2f1"],
              ["r1i1p2f1","r2i1p2f1","r9i1p2f1","r10i1p2f1","r7i1p2f1"],
              ["r10i1p1f1","r1i1p1f1","r4i1p1f1","r2i1p1f1"],
              ["r1i1p1f2","r2i1p1f2","r3i1p1f2","r4i1p1f2","r8i1p1f2"],
              ["r1i1p1f1","r2i1p1f1"]]

ds_var_names=["velocity","chl", "clt", "sithick", "siconc", "tas","chl"]
ds_var_names=["chl", "clt"] 
ds_var_names = ["prw",
            "clt",
            "uas",
            "vas",
            "chl",
            "sithick",
            "siconc",
            "sisnthick",
            "sisnconc",
            "tas",
            "tos"]

LMES=['Barents Sea','Northern Bering - Chukchi Seas']
toz_list=[]
config.source_ids = models
ensemble_ids_flat = [item for it in ensemble_ids for item in it]

config.member_ids=ensemble_ids_flat

for experiment_id in ["ssp245"]:
    io.organize_cmip6_netcdf_files_into_datasets(config, experiment_id)

    for varname in ds_var_names:
        all_ds = []
        for model in io.models:
            for mids in model.member_ids:
                for ds_var in model.ocean_vars[mids]:
                    
                    if ds_var==varname:
                    
                        all_ds.append(model.ds_sets[mids][varname])
                        
                if len(all_ds) > 0:
                    ens = ensembles.create_ensemble(all_ds, multifile=False, backend='zarr')
                    dir = config.cmip6_netcdf_dir
                    ensemble_filename = io.format_netcdf_filename(dir, "ensemble", "", experiment_id, varname)
                    Path(ensemble_filename).parent.mkdir(parents=True, exist_ok=True)

                    write_netcdf(ens, ensemble_filename)
                    io.upload_to_gcs(ensemble_filename)
                    print("ens", ens, ensemble_filename)



In [None]:

    # Disabled analysis code: everything below is a bare string literal, so
    # running this cell does nothing. It loops over LMEs / scenarios / models,
    # collects dataframes from get_area_averaged_ds, writes them to "test.csv"
    # and plots the rolling mean per scenario.
    # NOTE(review): before re-enabling, note that the body uses names that are
    # not defined in the visible cells (var_name, model_name, member_id,
    # current_experiment_id, and key before it is assigned), and iterates
    # ensemble_ids, which is defined above as a list of lists.
    """
    for LME in LMES:
        df_list=[]
        models_dict={}
        create_maps=False
        # We loop over all of the scenarios, ensemble_ids, and models to create a
        # list of dataframes that we eventually concatenate together and plot
        for scenario in scenarios:
            ds_list=[]
            for model in models:
                for ensemble_id in ensemble_ids:
                    io.format_netcdf_filename(dir, model_name, member_id, current_experiment_id, key)
                    if var_name in ["TOZ"]:
                        fname = "/Users/trondkr/Library/CloudStorage/Dropbox/NIVA/oceanography/cmip6/light/ozone-absorption/TOZ_{}.nc".format(scenario)
                        fname2=None

                    elif var_name not in ["velocity", "TOZ"]:
                        fname = "/Users/trondkr/Library/CloudStorage/Dropbox/NIVA/oceanography/cmip6/light/{}/{}/CMIP6_{}_{}_{}_{}.nc".format(scenario,model,
                                                                                             model,
                                                                                             ensemble_id,
                                                                                             scenario,
                                                                                             var_name)
                        fname2=None

                    elif var_name in ["velocity"]:
                        fname = "/Users/trondkr/Library/CloudStorage/Dropbox/NIVA/oceanography/cmip6/light/{}/{}/CMIP6_{}_{}_{}_{}.nc".format(scenario,model,
                                                                                         model,
                                                                                         ensemble_id,
                                                                                         scenario,
                                                                                         "uas")
                        fname2 = "/Users/trondkr/Library/CloudStorage/Dropbox/NIVA/oceanography/cmip6/light/{}/{}/CMIP6_{}_{}_{}_{}.nc".format(scenario,model,
                                                                                         model,
                                                                                         ensemble_id,
                                                                                         scenario,
                                                                                             "vas")
                    key="{}_{}_{}_{}".format(model,ensemble_id,scenario,var_name)
                    if var_name in ["TOZ"]:
                        key=fname

                    if key not in models_dict.keys():

                        df, models_dict, ds_lme = get_area_averaged_ds(fname, model,scenario, ensemble_id,var_name, LME, create_maps, frequency, models_dict,fname2)
                        create_maps=False
                        if ds_lme is not None:
                          #  ds_lme=xr.where( ((ds_lme < 1.e-3)| (ds_lme > 1e3)), np.nan, ds_lme)
                           # ds_lme=xr.where(ds_lme < 1, np.nan, ds_lme)
                            ds_list.append(ds_lme)

                        if df is not None:

                            df_list.append(df)
                            if var_name in ["TOZ"]:
                                toz_list.append(df)
                            print("Created dataframe of file: {}".format(fname))

            if len(ds_list) > 0:
                ens = ensembles.create_ensemble(ds_list).load()
                ens.close()
                ens_stats = ensembles.ensemble_mean_std_max_min(ens)

                outfile = "Figures/{}_ensemble_{}_{}.png".format(var_name.capitalize(),scenario, LME)

              #  CMIP6_ridgeplot.ridgeplot("{}_mean".format(var_name),
              #                            None, outfile,
              #                                    glorys=False, depth_threshold=None,
              #                                    ds=ens_stats)


        if len(df_list) > 0:
            df = pd.concat(df_list)

            create_summary_table(models_dict, LME)
            df = df.reindex()
            df["date"]=df.index

            if os.path.exists("test.csv"):os.remove("test.csv")
            df.to_csv("test.csv")
            df = pd.read_csv('test.csv', parse_dates=['time', 'date'])
            
            
            sns.set(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}, font_scale=1.5)
            f=plt.figure(figsize=(16, 16))
            gs = f.add_gridspec(2, 1)
            ax = f.add_subplot(gs[0, 0])

            sns.set_palette(["#8172B3", "#64B5CD"])
            legend_on=True if var_name=="tas" else False

            b=sns.lineplot(ax=ax, data=df, x=df["date"], y=df["roll_mean"],
                         hue=df["model_scenario"],palette= ["#8172B3", "#64B5CD"],
                         alpha=.95, ci=95,linewidth=5, legend=legend_on)

            cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
            df_ssp245=df[df["model_scenario"]=="ssp245"]
            ax.fill_between(df_ssp245["date"], df_ssp245["roll_mean"] - df_ssp245["roll_std"], df_ssp245["roll_mean"] + df_ssp245["roll_std"], color=cycle_colors[0], alpha=0.2)
            df_ssp585=df[df["model_scenario"]=="ssp585"]
            ax.fill_between(df_ssp585["date"], df_ssp585["roll_mean"] - df_ssp585["roll_std"], df_ssp585["roll_mean"] + df_ssp585["roll_std"], color=cycle_colors[1], alpha=0.2)
            print(df_ssp245.head())
            print("ssp245",((df_ssp245["roll_mean"][-1]-df_ssp245["roll_mean"][0])/df_ssp245["roll_mean"][0])*100.)
            print("ssp585",((df_ssp585["roll_mean"][-1]-df_ssp585["roll_mean"][0])/df_ssp585["roll_mean"][0])*100.)

            b.tick_params(labelsize=38)
            b.set_xlabel("",fontsize=34)
            b.set_ylabel("",fontsize=34)
            import matplotlib.dates as mdates
            if var_name=="tas":
                plt.legend(loc="upper left", frameon=False, fontsize=32)

            ax.xaxis.set_major_locator(mdates.YearLocator(base=10))
            ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=-90)

            if not os.path.exists("Figures"):
                os.makedirs("Figures")
            plotfile="Figures/CMIP6_lightpaper_{}_{}.png".format(var_name,LME)
            print("Created figure {}".format(plotfile))
            plt.savefig(plotfile, dpi=300,
                        bbox_inches = 'tight')

            plt.show()

        df=pd.DataFrame()
        ds_list=[]
"""

In [None]:
# Calculate the relative change in chlorophyll production between 1980-2000
# and 2080-2099 for each model scenario.
# NOTE: "test.csv" is only written by the plotting code that is currently
# disabled (stored inside a string literal) in the cell above.
df = pd.read_csv("test.csv", parse_dates=["time", "date"])


def _period_mean(frame, column, start, end):
    """Mean of ``column`` over the rows whose ``time`` lies in [start, end].

    Both bounds are inclusive.  The rows are selected with a boolean mask on
    the ``time`` column: the frame read from CSV has a plain integer index, so
    label slicing with date strings (``.loc["1980-01-01":...]``) does not work.
    """
    in_period = frame["time"].between(pd.Timestamp(start), pd.Timestamp(end))
    return frame.loc[in_period, column].mean(skipna=True)


def _with_relative_change(frame):
    """Return ``(frame_with_changes, climatology)`` for one scenario.

    ``climatology`` is the 1980-2000 mean of ``roll_mean``.  Two constant
    columns are added to a copy of ``frame`` (rows with NaN dropped):

    * ``rel_change``: percent change of the 2080-2099 mean of ``roll_mean``
      relative to the climatology.
    * ``rel_change_std``: the same, for the 2080-2099 mean of ``roll_mean``
      plus the 2080-2099 mean of ``roll_std``.
    """
    # Work on a copy so assigning columns does not touch a slice of ``df``.
    frame = frame.dropna().copy()
    clim = _period_mean(frame, "roll_mean", "1980-01-01", "2000-01-01")
    future_mean = _period_mean(frame, "roll_mean", "2080-01-01", "2099-01-01")
    future_std = _period_mean(frame, "roll_std", "2080-01-01", "2099-01-01")

    frame["rel_change_std"] = ((future_mean + future_std - clim) / clim) * 100.0
    frame["rel_change"] = ((future_mean - clim) / clim) * 100.0
    return frame, clim


# Each scenario is normalised by its own 1980-2000 climatology. Previously
# ssp585 was divided by the ssp245 climatology while clim585 was computed but
# never used.
df_ssp245, clim245 = _with_relative_change(df[df["model_scenario"] == "ssp245"])
df_ssp585, clim585 = _with_relative_change(df[df["model_scenario"] == "ssp585"])

print(df_ssp245[["time", "rel_change", "rel_change_std"]])
print(df_ssp585[["time", "rel_change", "rel_change_std"]])
sns.lineplot(y="roll_mean", x="time", data=df_ssp245, label="ssp245")
sns.lineplot(y="roll_mean", x="time", data=df_ssp585, label="ssp585")