# CMIP6 Statistics and Plots

**The following steps are included in this script:**

1. Load netCDF files
2. Compute statistics
3. Plot statistics

In [None]:
# ========== Packages ==========
import xarray as xr
import pandas as pd
import numpy as np
import os
import seaborn as sns
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import dask
from matplotlib import rcParams
import math
import multiprocessing as mp
from cftime import DatetimeNoLeap
import glob
from scipy.stats import pearsonr
from matplotlib import colors

%matplotlib inline

# Render mathtext (text between $...$ in labels) with the regular text font
# instead of matplotlib's default italic math font.
rcParams["mathtext.default"] = 'regular'

### Functions

In [None]:
def compute_statistic_single(ds, statistic, dimension, yearly_mean=False):
    """
    Compute one statistic of a single dataset over time or over space.

    Args:
        ds (xr.Dataset): Dataset with a 'time' coordinate and, for spatial
            statistics, 'lat' and 'lon' dimensions.
        statistic (str): Reduction to apply: 'mean', 'std', 'var', 'min',
            'max' or 'median'.
        dimension (str): 'time' reduces over the time axis; 'space' reduces
            over lon/lat using cos(latitude) area weights ('min'/'max' are
            taken unweighted, 'median' is the weighted 0.5 quantile).
        yearly_mean (bool): Only used for dimension='space'. If True, the data
            are first averaged per calendar year, so the result has a 'year'
            dimension instead of 'time'.

    Returns:
        xr.Dataset: The reduced dataset. Its attrs carry 'period'
        ([first_year, last_year] as strings), 'statistic' and
        'statistic_dimension' (plus 'mean' = 'yearly mean' when the yearly
        averaging was applied).

    Raises:
        ValueError: If dimension is not 'time' or 'space'.
    """
    if dimension not in ("time", "space"):
        raise ValueError(f"Invalid dimension '{dimension}' specified.")

    period = [str(ds.time.dt.year[0].values), str(ds.time.dt.year[-1].values)]

    if dimension == "time":
        stat_ds = getattr(ds, statistic)("time", keep_attrs=True, skipna=True)
        stat_ds.attrs['period'] = period
    else:
        # assign_attrs works on a shallow copy, so the caller's dataset keeps its attrs.
        # The period is recorded before grouping by year, while 'time' still exists.
        ds = ds.assign_attrs(period=period)

        if yearly_mean:
            ds = ds.groupby('time.year').mean('time', keep_attrs=True, skipna=True)
            ds.attrs['mean'] = 'yearly mean'

        if statistic in ("min", "max"):
            # Area weights do not change extrema, and weighted objects have no min/max.
            stat_ds = getattr(ds, statistic)(("lon", "lat"), keep_attrs=True, skipna=True)
        else:
            weights = np.cos(np.deg2rad(ds.lat))
            weights.name = "weights"
            ds_weighted = ds.weighted(weights)
            if statistic == "median":
                # Weighted objects have no median(); use the weighted 0.5 quantile
                # and drop the resulting length-1 'quantile' dimension.
                stat_ds = ds_weighted.quantile(
                    [0.5], dim=("lon", "lat"), keep_attrs=True, skipna=True
                ).squeeze("quantile", drop=True)
            else:
                stat_ds = getattr(ds_weighted, statistic)(("lon", "lat"), keep_attrs=True, skipna=True)

    stat_ds.attrs['statistic'] = statistic
    stat_ds.attrs['statistic_dimension'] = dimension

    return stat_ds

In [None]:
def select_period(ds_dict, start_year=None, end_year=None):
    """
    Select whole calendar years from every dataset in a dictionary.

    Parameters:
    ds_dict (dict): Dictionary with xarray datasets, each with a 'time' coordinate.
    start_year (int, optional): First year to keep (inclusive). None keeps the
        record from its beginning.
    end_year (int, optional): Last year to keep (inclusive). None keeps the
        record up to its end.

    Returns:
    dict: New dictionary with the same keys and the time-sliced datasets.
    """
    # Year-resolution date strings select complete years and work for every
    # calendar (noleap, 360_day, standard, ...). The former DatetimeNoLeap bounds
    # (16 Jan / 16 Dec, 12:00) excluded time steps stamped before 16 January of
    # the start year or after 16 December of the end year.
    start = None if start_year is None else f"{int(start_year):04d}"
    end = None if end_year is None else f"{int(end_year):04d}"

    return {k: v.sel(time=slice(start, end)) for k, v in ds_dict.items()}

In [None]:
def compute_statistic(ds_dict, statistic, dimension, start_year=None, end_year=None, yearly_mean=True, parallel=True):
    """
    Computes the specified statistic for each dataset in the dictionary.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself.
        statistic (str): The statistic to compute, one of 'mean', 'std', 'min', 'max', 'var' or 'median'.
        dimension (str): The dimension to compute over, which can be 'time' or 'space'.
        start_year (int or str, optional): The start year of the period to compute the statistic over.
        end_year (int or str, optional): The end year of the period to compute the statistic over.
            start_year and end_year must be given together; if both are None the full record is used.
        yearly_mean (bool, optional): For dimension='space', average each calendar year before the
            spatial reduction (result has a 'year' dimension). Ignored for dimension='time'. Default True.
        parallel (bool, optional): If True (default), process the datasets in a multiprocessing pool.
            Set to False to compute serially, e.g. where worker processes cannot import
            functions defined in a notebook (the 'spawn' start method on Windows/macOS).

    Returns:
        dict: A dictionary with computed statistic for each dataset.

    Raises:
        TypeError: If ds_dict is not a dict of xarray datasets.
        ValueError: For an unknown statistic or dimension, or if only one of start_year/end_year is given.
    """
    # Check the validity of input arguments
    if not isinstance(ds_dict, dict):
        raise TypeError("ds_dict must be a dictionary of xarray datasets.")
    if not all(isinstance(ds, xr.Dataset) for ds in ds_dict.values()):
        raise TypeError("All values in ds_dict must be xarray datasets.")
    if statistic not in ["mean", "std", "min", "max", "var", "median"]:
        raise ValueError(f"Invalid statistic '{statistic}' specified.")
    if dimension not in ["time", "space"]:
        raise ValueError(f"Invalid dimension '{dimension}' specified.")
    if (start_year is None) != (end_year is None):
        # Previously a single bound was silently ignored.
        raise ValueError("start_year and end_year must be given together (or both left as None).")

    # Select period
    if start_year is not None:
        ds_dict = select_period(ds_dict, start_year=start_year, end_year=end_year)

    if not ds_dict:
        return {}

    job_args = [(ds, statistic, dimension, yearly_mean) for ds in ds_dict.values()]

    if parallel:
        # No more workers than datasets; each worker receives a pickled copy of its dataset.
        n_workers = min(len(job_args), os.cpu_count() or 1)
        with mp.Pool(processes=n_workers) as pool:
            results = pool.starmap(compute_statistic_single, job_args)
    else:
        results = [compute_statistic_single(*args) for args in job_args]

    return dict(zip(ds_dict.keys(), results))

In [None]:
def check_args_and_get_info(ds_dict, variable):
    """
    Validate plotting arguments and collect metadata for titles, labels and file names.

    Args:
        ds_dict (dict): Dictionary of xarray datasets, as returned by compute_statistic
            (optionally extended by compute_ensemble).
        variable (str): Name of the variable to be plotted.

    Returns:
        tuple: (var_long_name, period, unit, statistic_dim, statistic, experiment_id,
        titles, frequency). The values are read from the first dataset that contains
        `variable`. `period` is formatted as 'YYYY-YYYY'. `titles` maps statistic and
        dimension keys to readable names. `frequency` is a readable version of the
        dataset's 'frequency' attribute (the raw value if it is not in the lookup).

    Raises:
        TypeError: If ds_dict is not a dict of xarray datasets or variable is not a string.
        ValueError: If ds_dict is empty or no dataset contains `variable`.
    """
    # Check the validity of input arguments
    if not isinstance(ds_dict, dict):
        raise TypeError("ds_dict must be a dictionary of xarray datasets.")
    if not all(isinstance(ds, xr.Dataset) for ds in ds_dict.values()):
        raise TypeError("All values in ds_dict must be xarray datasets.")
    if not isinstance(variable, str):
        raise TypeError('variable must be a string.')
    if not ds_dict:
        raise ValueError("ds_dict is empty.")

    # Plot titles for every statistic/dimension accepted by compute_statistic
    titles = {"mean": "Mean", "std": "Standard deviation of yearly means", "min": "Minimum", "max": "Maximum",
              "var": "Variance", "median": "Median", "time": "Time", "space": "Space"}
    freq = {"mon": "Monthly", "day": "Daily", "yr": "Yearly"}

    # The plotting functions skip datasets without the variable, so take the
    # metadata from the first dataset that actually has it.
    reference = next((ds for ds in ds_dict.values() if variable in ds), None)
    if reference is None:
        raise ValueError(f"Variable '{variable}' not found in any dataset.")

    data_attrs = reference[variable].attrs
    var_long_name = data_attrs['long_name']
    unit = data_attrs['units']

    first_year, last_year = reference.attrs['period'][0], reference.attrs['period'][1]
    period = f"{first_year}-{last_year}"
    experiment_id = reference.attrs['experiment_id']
    statistic_dim = reference.attrs['statistic_dimension']
    statistic = reference.attrs['statistic']
    raw_frequency = reference.attrs['frequency']
    frequency = freq.get(raw_frequency, raw_frequency)

    return var_long_name, period, unit, statistic_dim, statistic, experiment_id, titles, frequency

In [None]:
def _trend_percentage_change(data, dim):
    """
    Return the relative change (in %) of a linear trend fitted to `data` along `dim`.

    The change is measured between the fitted values at the first and the last
    coordinate of `dim`, relative to the first fitted value. Returns NaN when the
    first fitted value is 0, because the relative change is undefined there.
    """
    coefficients = data.polyfit(dim=dim, deg=1)["polyfit_coefficients"]
    fitted = xr.polyval(data[dim], coefficients).squeeze()
    first_value = fitted.isel({dim: 0}).item()
    last_value = fitted.isel({dim: -1}).item()
    if first_value == 0:
        return float("nan")
    return (last_value - first_value) / first_value * 100


def plot_line_statistic(ds_dict, variable, log_scale=False, add_regression=True, save_fig=False, file_format='png', smooth_window=1):
    """
    Plots a line plot of the specified statistic of the given variable for each dataset in the dictionary,
    together with the mean over all plotted datasets (black dashed line).

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself. After squeezing, the variable should be 1-D along
            'year' (e.g. compute_statistic(..., dimension='space', yearly_mean=True)) or 'time'.
        variable (str): The name of the variable to plot.
        log_scale (bool): If True, plot log10 of the data. Default is False.
        add_regression (bool): If True, fit a linear trend to every line and show its percentage change
            over the period in the legend. Default is True.
        save_fig (bool): If True, save the figure to a file. Default is False.
        file_format (str): The format of the saved figure. Default is 'png'.
        smooth_window (int): Window size (number of steps along the x dimension, i.e. years for yearly
            data) of a centred rolling-mean smoothing. Default is 1 (no smoothing).

    Returns:
        str: The file path where the figure was saved, or a note that it was not saved.

    Raises:
        ValueError: If no dataset contains `variable`.
    """
    # Check arguments and get info (8 values are returned)
    var_long_name, period, unit, statistic_dim, statistic, experiment_id, titles, frequency = check_args_and_get_info(ds_dict, variable)

    # Prepare (log-transform and smooth) all series before creating the figure
    series = []
    x_dim = 'year'
    for name, ds in ds_dict.items():
        if variable not in ds:
            print(f"Variable '{variable}' not found in dataset '{name}', skipping.")
            continue

        data_to_plot = ds[variable].squeeze()
        # Yearly-mean series have a 'year' dimension; otherwise fall back to 'time'
        x_dim = 'year' if 'year' in data_to_plot.dims else 'time'

        if log_scale:
            data_to_plot = np.log10(data_to_plot)
        if smooth_window > 1:
            data_to_plot = data_to_plot.rolling({x_dim: smooth_window}, center=True).mean()

        series.append((name, data_to_plot))

    if not series:
        raise ValueError(f"Variable '{variable}' not found in any dataset.")

    fig, ax = plt.subplots(figsize=(30, 15))

    # One line per dataset; with regression, the legend shows the trend's percentage change
    for name, data_to_plot in series:
        label = name
        if add_regression:
            label = f"{name} ({_trend_percentage_change(data_to_plot, x_dim):.2f}%)"
        data_to_plot.plot.line(x=x_dim, ax=ax, label=label)

    # Mean over all plotted (already transformed/smoothed) series
    ensemble_mean = xr.concat([data for _, data in series], dim='dataset').mean(dim='dataset')
    ensemble_label = "Ensemble mean"
    if add_regression:
        ensemble_label += f" ({_trend_percentage_change(ensemble_mean, x_dim):.2f}%)"
    ensemble_mean.plot.line(x=x_dim, ax=ax, linestyle='--', color='black', label=ensemble_label)

    # Axis labels are set after all xarray plot calls, because those set their own labels
    ax.set_xlabel('Year')
    if log_scale:
        ax.set_ylabel(f"{var_long_name} [{unit}] - log-scale")
    else:
        ax.set_ylabel(f"{var_long_name} [{unit}]")

    ax.legend()
    ax.grid(True)

    # Figure title with first and last year of the dataset
    title = f"{titles[statistic_dim]} {titles[statistic]} of {var_long_name} ({period})"
    if smooth_window > 1:
        title += f" ({smooth_window}-year running mean)"
    fig.suptitle(title, fontsize=26, y=1.0)

    # Save figure (before showing it, so the saved file matches the display)
    if save_fig:
        savepath = os.path.join('../..', 'results', 'CMIP6', experiment_id, statistic_dim, statistic, 'line_plots')
        os.makedirs(savepath, exist_ok=True)

        suffix = ''
        if smooth_window > 1:
            suffix += f'.{smooth_window}-year_running_mean'
        if log_scale:
            suffix += '.log_scale'
        filename = f'{statistic_dim}.{statistic}.{period}.{variable}.{experiment_id}{suffix}.{file_format}'

        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=300)
    else:
        filepath = 'Figure not saved. If you want to save the figure add save_fig=True to the function call'

    plt.show()

    return filepath

In [None]:
def plot_sm_profile(ds_depth, save_path='../results/CMIP6/historical/', save_name='soil_moisture_profile.png', save_fig=False, xlim_bound=3, ylim_bound=1000):
    """
    Plots soil moisture profiles (layer value against depth), one dashed line with markers per entry.

    Args:
        ds_depth (dict): Maps a label to an xarray DataArray of mean soil water content per layer
            (mrsol) with a 'depth' coordinate, as returned by soil_moisture_profile.
        save_path (string): Directory of the saved figure (created if missing).
        save_name (string): File name of the saved figure.
        save_fig (bool): If True, save the figure to a file. Default is False.
        xlim_bound (float): A value to set the max for the x-axis. Default is 3.
        ylim_bound (float): A value to set the max for the y-axis. Default is 1000.
    """
    fig, ax = plt.subplots(figsize=(30, 15))

    # Set limits on this figure's axes rather than via pyplot's "current axes" state
    ax.set_xlim(0, xlim_bound)
    ax.set_ylim(0, ylim_bound)

    # Define the marker size for the plot
    marker_size = 150

    for name, profile in ds_depth.items():
        profile = profile.squeeze()
        depth = profile['depth'].values
        values = profile.values

        line, = ax.plot(depth, values, linestyle='--', label=f"{name}")
        # Markers in the line's colour via plain matplotlib. DataArray.plot.scatter
        # treats x/y as coordinate names, so y='variable' did not refer to the values.
        ax.scatter(depth, values, s=marker_size, color=line.get_color())

    ax.legend(fontsize=20)

    if save_fig:
        if save_path:
            os.makedirs(save_path, exist_ok=True)
        fig.savefig(os.path.join(save_path, save_name), dpi=300)

In [None]:
def soil_moisture_profile(ds_dict, var='mrsol', save_path='../results/CMIP6/historical/', save_name='soil_moisture_profile.png', plot_fig=True, save_fig=False, xlim_bound=3, ylim_bound=1000):
    """
    Computes the time- and area-mean soil water content per layer for each dataset and optionally plots the profiles.

    Args:
        ds_dict (dict): A dictionary of xarray datasets containing `var` with 'time', 'lat',
            'lon' and 'depth' dimensions.
        var (str): Name of the per-layer soil water variable. Default is 'mrsol'.
        save_path (string): Path saved figure. Default is '../results/CMIP6/historical/'.
        save_name (string): Name of saved figure.
        plot_fig (bool): If True, plot the figure. Default is True.
        save_fig (bool): If True, save the figure to a file. Default is False. plot_fig has to be True as well to save figure.
        xlim_bound (float): A value to set the max for the x-axis. Default is 3.
        ylim_bound (float): A value to set the max for the y-axis. Default is 1000.

    Returns:
        dict: Maps each dataset's 'source_id' attribute (falling back to its dictionary key)
        to a DataArray with the mean of `var` per depth layer.
    """
    ds_depth = {}

    for name, ds in ds_dict.items():
        if var not in ds:
            print(f"Variable '{var}' not found in dataset '{name}', skipping.")
            continue

        mean_time = ds[var].mean("time", keep_attrs=True, skipna=True)
        # Area-weighted spatial mean with cos(latitude) weights, matching the
        # 'space' statistic in compute_statistic_single. An unweighted mean over a
        # regular lat-lon grid over-weights high latitudes.
        weights = np.cos(np.deg2rad(mean_time.lat))
        mean_time_space = mean_time.weighted(weights).mean(("lon", "lat"), keep_attrs=True, skipna=True)
        ds_depth[ds.attrs.get('source_id', name)] = mean_time_space

    if plot_fig:
        plot_sm_profile(ds_depth, save_path=save_path, save_name=save_name, save_fig=save_fig, xlim_bound=xlim_bound, ylim_bound=ylim_bound)

    return ds_depth

In [None]:
# ========= Helper function to open the dataset ========
def open_dataset(filename):
    """Open the netCDF file at `filename` with xarray and return the resulting Dataset."""
    return xr.open_dataset(filename)

In [None]:
# Define a helper function to open and merge datasets
def open_and_merge_datasets(folder, model, experiment_id, variables, data_root='../../data/CMIP6'):
    """
    Open the regridded file of each requested variable for one model and merge them into one dataset.

    Files are looked up at
    ``{data_root}/{experiment_id}/{folder}/{var}/CMIP.{model}.{experiment_id}.{var}_regridded.nc``.

    Args:
        folder (str): Sub-folder inside the experiment directory (e.g. 'preprocessed').
        model (str): Model name (source_id) used in the file name.
        experiment_id (str): Experiment name (e.g. 'historical').
        variables (list of str or str): Variable names to load; a single name is also accepted.
        data_root (str): Root directory of the CMIP6 data. Default is '../../data/CMIP6'.

    Returns:
        xr.Dataset: Merge of all files that were found (an empty Dataset if none were found).
    """
    if isinstance(variables, str):
        # A bare string would otherwise be iterated character by character
        variables = [variables]

    filepaths = []
    for var in variables:
        path = os.path.join(data_root, experiment_id, folder, var)
        matches = glob.glob(os.path.join(path, f'CMIP.{model}.{experiment_id}.{var}_regridded.nc'))
        if matches:
            filepaths.append(matches[0])
        else:
            # Name what is missing; the model is still returned with its other variables
            print(f"No file found for variable '{var}' in model '{model}'.")

    datasets = [xr.open_dataset(fp) for fp in filepaths]
    return xr.merge(datasets)

In [None]:
def compute_ensemble(ds_dict):
    """
    Add multi-model 'Ensemble mean' and 'Ensemble median' entries to a dictionary of datasets.

    The members are all entries except existing 'Ensemble mean' / 'Ensemble median' entries, so
    calling the function again (e.g. re-running a cell) does not feed a previous ensemble back in.
    Members are concatenated along a new 'ensemble' dimension and reduced with mean and median.
    Variable attributes are copied from the first member.

    Args:
        ds_dict (dict): Dictionary of xarray datasets (e.g. the output of compute_statistic)
            whose dataset attrs contain 'period'.

    Returns:
        dict: The same dictionary, modified in place, with the two ensemble entries added/replaced.

    Raises:
        ValueError: If ds_dict contains no member datasets.
    """
    ensemble_keys = ("Ensemble mean", "Ensemble median")
    members = [ds for key, ds in ds_dict.items() if key not in ensemble_keys]
    if not members:
        raise ValueError("ds_dict contains no member datasets to build an ensemble from.")
    reference = members[0]

    # Combine all member datasets into one larger dataset
    combined = xr.concat(members, dim='ensemble')

    for label, statistic in (("Ensemble mean", "mean"), ("Ensemble median", "median")):
        ensemble = getattr(combined, statistic)(dim='ensemble')

        # Preserve variable attributes (long_name, units, ...) from the first member.
        # Copy the dict so the ensemble does not share it with that member.
        for var in ensemble.variables:
            if var in reference.variables:
                ensemble[var].attrs = dict(reference[var].attrs)

        ensemble.attrs = {"period": reference.attrs['period'],
                          "statistic": statistic,
                          "statistic_dimension": reference.attrs.get("statistic_dimension", "time"),
                          "experiment_id": reference.attrs.get("experiment_id", reference.attrs.get("experiment")),
                          "source_id": label}
        ds_dict[label] = ensemble

    return ds_dict

In [None]:
def _colorbar_limits(data_min, data_max, cbar_min, cbar_max):
    """
    Derive colour-bar limits from the data range.

    vmin is `data_min` rounded to the nearest integer times `cbar_min`; vmax is
    `data_max` rounded to one significant digit times `cbar_max`.

    Raises:
        ValueError: If the data range is not finite (e.g. all values are NaN).
    """
    if not (math.isfinite(data_min) and math.isfinite(data_max)):
        raise ValueError("The data contain no finite values; cannot derive colour-bar limits.")
    vmin = round(data_min) * cbar_min
    if data_max == 0:
        # log10(0) is undefined; a zero maximum simply stays zero
        vmax = 0.0
    else:
        vmax = round(data_max, -int(math.floor(math.log10(abs(data_max))))) * cbar_max
    return vmin, vmax


def plot_map_statistic(ds_dict, variable, n_cols=4, cbar_min=0, cbar_max=0.75, cmap='viridis', save_fig=False, log_scale=False, file_format='png', cbar_limits=None):
    """
    Plots a map of the specified statistic of the given variable for each dataset in the dictionary.

    Args:
        ds_dict (dict): A dictionary of xarray datasets, where each key is the name of the dataset
            and each value is the dataset itself.
        variable (str): The name of the variable to plot.
        n_cols (int): The number of columns for the subplots. Default is 4.
        cbar_min (float): vmin is the variable's minimum across all datasets (rounded to an integer)
            multiplied by this value. Default is 0.
        cbar_max (float): vmax is the variable's maximum across all datasets (rounded to one significant
            digit) multiplied by this value. Default is 0.75.
        cmap (str): The name of the colormap to use for the plot. Default is 'viridis'.
        save_fig (bool): If True, save the figure to a file. Default is False.
        log_scale (bool): If True, use a logarithmic colour scale. A non-positive vmin is then replaced
            by the smallest positive data value. Default is False.
        file_format (str): The format of the saved figure. Default is 'png'.
        cbar_limits (tuple, optional): Explicit (vmin, vmax) that overrides the limits derived from
            cbar_min/cbar_max, e.g. (-1, 1) for a fixed symmetric range.

    Returns:
        str: The file path where the figure was saved, or a note that it was not saved.

    Raises:
        ValueError: If no dataset contains `variable`, the data have no finite values, or
            log_scale=True but no data value is positive.
    """
    # Check arguments and get info
    var_long_name, period, unit, statistic_dim, statistic, experiment_id, titles, frequency = check_args_and_get_info(ds_dict, variable)

    # Collect the fields to plot; datasets without the variable are skipped
    fields = []
    for name, ds in ds_dict.items():
        if variable not in ds:
            print(f"Variable '{variable}' not found in dataset '{name}', skipping.")
            continue
        fields.append((name, ds[variable]))
    if not fields:
        raise ValueError(f"Variable '{variable}' not found in any dataset.")

    # Global range over all fields (NaNs ignored), without concatenating the arrays
    data_min = float(np.nanmin([float(field.min()) for _, field in fields]))
    data_max = float(np.nanmax([float(field.max()) for _, field in fields]))
    vmin, vmax = _colorbar_limits(data_min, data_max, cbar_min, cbar_max)
    print(f"Min: {round(data_min)}")
    print(f"Max: {round(data_max)}")
    if cbar_limits is not None:
        vmin, vmax = cbar_limits

    norm = None
    if log_scale:
        if vmin <= 0:
            # LogNorm needs a positive lower bound: use the smallest positive data value
            vmin = float(np.nanmin([float(field.where(field > 0).min()) for _, field in fields]))
            if not math.isfinite(vmin):
                raise ValueError("log_scale=True requires positive data values.")
        norm = colors.LogNorm(vmin=vmin, vmax=vmax)

    # Grid layout: one map panel per field, empty trailing panels are removed
    n_panels = len(fields)
    n_rows = math.ceil(n_panels / n_cols)
    fig = plt.figure(figsize=(12 * n_cols, 6.5 * n_rows))
    gs = plt.GridSpec(n_rows, n_cols, figure=fig, wspace=0.1, hspace=0.1)
    axes = [fig.add_subplot(gs[i, j], projection=ccrs.Robinson()) for i in range(n_rows) for j in range(n_cols)]
    for unused_ax in axes[n_panels:]:
        unused_ax.remove()
    axes = axes[:n_panels]

    plot_kwargs = {'cmap': cmap, 'extend': 'max', 'transform': ccrs.PlateCarree(), 'add_colorbar': False}
    if norm is None:
        plot_kwargs.update(vmin=vmin, vmax=vmax)
    else:
        # Pass the norm alone: xarray rejects vmin/vmax that differ from the norm's limits
        plot_kwargs['norm'] = norm

    for ax, (name, data_to_plot) in zip(axes, fields):
        im = data_to_plot.plot(ax=ax, **plot_kwargs)
        ax.set_title(name, fontsize=18)
        ax.coastlines()

    # Common colorbar to the right of the maps
    cbar = fig.colorbar(im, ax=axes, orientation='vertical', fraction=0.1, pad=0.03, aspect=20, shrink=0.5)
    cbar.ax.tick_params(labelsize=20)
    cbar.set_label(f"{var_long_name} [{unit}]", size=26)

    # Set figure title
    fig.suptitle(f"{titles[statistic_dim]} {titles[statistic]} of {frequency} {var_long_name} ({period})", fontsize=30, y=0.95)

    # Save figure (before showing it, so the saved file matches the display)
    if save_fig:
        savepath = os.path.join('../..', 'results', 'CMIP6', experiment_id, statistic_dim, statistic, 'maps')
        os.makedirs(savepath, exist_ok=True)
        filename = f'{statistic_dim}.{statistic}.{period}.{variable}.{experiment_id}.{file_format}'
        filepath = os.path.join(savepath, filename)
        fig.savefig(filepath, dpi=300)
    else:
        filepath = 'Figure not saved. If you want to save the figure add save_fig=True to the function call'

    plt.show()

    return filepath

### 1. Load netCDF files

In [None]:
# ========= Define period, models and path ==============
variable = ['pr', 'tran', 'mrro']  # full list: ['tas', 'pr', 'vpd', 'evspsbl', 'mrro', 'lmrso_1m', 'lmrso_2m', 'mrso_1m', 'mrso_2m', 'tran', 'lai', 'gpp', 'wue', 'EI']
experiment_id = 'historical'
source_id = ['TaiESM1', 'BCC-CSM2-MR', 'CanESM5', 'CNRM-CM6-1', 'CNRM-ESM2-1', 'IPSL-CM6A-LR', 'UKESM1-0-LL', 'MPI-ESM1-2-LR', 'CESM2-WACCM', 'NorESM2-MM']
folder = 'preprocessed'

# ========= Dask configuration ==========
# Sets dask's global scheduler for any dask-backed computation later in the notebook
dask.config.set(scheduler='processes')

# Open and merge the requested variables of every model (dictionary keyed by model name).
# open_and_merge_datasets runs eagerly inside the comprehension, so the former
# dask.compute(...) wrapper parallelised nothing; a plain comprehension builds the same dict.
ds_dict = {model: open_and_merge_datasets(folder, model, experiment_id, variable) for model in source_id}

In [None]:
# ============= Have a look into the data ==============
print(ds_dict.keys())
# ds_stat and its derived 'nbwfp' variable are only created in the cells below, so
# referencing them here fails with a NameError on a fresh kernel; inspect the loaded data instead.
example_model = list(ds_dict.keys())[3]
print(ds_dict[example_model])

### 2. Compute statistics

In [None]:
# ========= Compute statistic for plot ===============
# Time mean over 1985-2014 of every loaded variable, one dataset per model
ds_stat = compute_statistic(
    ds_dict,
    statistic='mean',
    dimension='time',
    start_year=1985,
    end_year=2014,
    yearly_mean=False,
)

In [None]:
# ======== Compute Ensemble mean/median ==============
# Adds 'Ensemble mean' and 'Ensemble median' entries (across the models) to ds_stat
ds_stat = compute_ensemble(ds_stat)

In [None]:
# Derived variable: (mrro - tran) / pr for every model and ensemble entry
for dataset in ds_stat.values():
    ratio = (dataset['mrro'] - dataset['tran']) / dataset['pr']

    # Zero precipitation gives +/-inf; store those cells as NaN instead
    dataset['nbwfp'] = xr.where(np.isinf(ratio), float('nan'), ratio)
    dataset['nbwfp'].attrs = {
        'long_name': 'Net Blue Water Flux / Precipitation',
        'units': '',
    }

### Plot Maps

In [None]:
# Candidate variables for the plots below.
# NOTE(review): only the variables loaded in section 1 ('pr', 'tran', 'mrro') plus the
# derived 'nbwfp' are present in ds_stat; the other names are not available there.
variable=['tas', 'pr', 'vpd', 'evspsbl', 'mrro', 'lmrso_1m', 'lmrso_2m', 'mrso_1m', 'mrso_2m', 'tran', 'lai', 'gpp', 'wue', 'EI']

In [None]:
# ========= Plot the computed data ===================
# Colormap per variable; the numbers in parentheses are presumably the (cbar_min - cbar_max) settings:
#   tas - 'coolwarm' (1 - 1) // pr - 'Blues' (1 - 0.05) // vpd - 'YlOrBr' (1 - 1) // evspsbl - 'BuPu' (1 - 0.65) // mrro - 'RdPu' (0 - 0.1)
#   lmrso_1m / lmrso_2m - 'YlGnBu' (0 - 0.5) // tran - 'YlGn' (0 - 0.65) // lai - 'YlGn' (0 - 0.65) // gpp - 'YlGn' (0 - 0.65)
#   wue - 'YlGn' (0 - 0.1) // EI - 'YlGn' (1 - 0.8) // nbwfp - 'PiYG'

plot_map_statistic(ds_stat, 'nbwfp', n_cols=4, cbar_min=1, cbar_max=1, cmap='PiYG_r', save_fig=True, log_scale=False, file_format='png')

### Plot line plots

In [None]:
# Line plots need one value per year: an area-weighted spatial mean with a 'year' dimension.
# ds_stat holds time means, which have no time axis. Use the same period as ds_stat.
ds_line = compute_statistic(ds_dict, 'mean', 'space', start_year=1985, end_year=2014, yearly_mean=True)
reference_ds = next(iter(ds_line.values()))

# plot_line_statistic takes a single variable name, so plot each requested variable in turn
for var_name in variable:
    if var_name not in reference_ds:
        print(f"Variable '{var_name}' not loaded, skipping line plot.")
        continue
    plot_line_statistic(ds_line, var_name, log_scale=False, add_regression=True, save_fig=False, file_format='png', smooth_window=1)