# Imports

In [None]:
%matplotlib inline

import xarray as xr
import pandas as pd
from glob import glob
import matplotlib.pyplot as plt
from matplotlib import axes
import numpy as np

import plotly.express as px

import seaborn as sns
sns.set_theme(style="whitegrid")
# sns.set(rc={"figure.dpi":200,})
# sns.set(rc={"figure.dpi":300, 'savefig.dpi':300})
sns.set_context('notebook')

import warnings
warnings.simplefilter('ignore')

from sklearn.preprocessing import minmax_scale

import random
import time

# Helper functions

In [None]:
def mask_land(mask_df, data_df):
    '''
    Attach the mask columns to `data_df` and keep only non-land points.

    The index of `data_df` is matched against the `y`/`x` keys of `mask_df`
    (inner merge); afterwards only rows whose `tmask` equals 1 are kept.
    '''
    merged = data_df.merge(mask_df, left_index=True, right_on=['y', 'x'])
    keep = merged['tmask'] == 1  # drop this filter to keep the land points
    return merged.loc[keep]

In [None]:
def round_nav_lat(df):
    '''
    Round the `nav_lat` and `nav_lon` columns to 2 decimal places.

    Uses the built-in round() element-wise (ordinary rounding, not rounding up).
    The columns of `df` are overwritten in place and the same frame is returned.
    '''
    for coord in ('nav_lat', 'nav_lon'):
        df[coord] = df[coord].apply(lambda value: round(value, 2))
    return df

In [None]:
def get_year_month(df, yrmonth):
    '''
    Return the rows of `df` whose `time_counter`, as text, contains `yrmonth`.

    Parameters
    ----------
    df : DataFrame with a `time_counter` column.
    yrmonth : pattern passed to `Series.str.contains` (case-insensitive, so it is
        interpreted as a regular expression), e.g. "1958-01".

    Returns
    -------
    A new DataFrame in which `time_counter` has pandas' "string" dtype, filtered to
    the matching rows and re-indexed via `reset_index()` (the old index is kept
    as a column). The input frame is NOT modified.
    '''
    # Convert on a copy so the caller's frame keeps its original dtype.
    as_text = df.assign(time_counter=df['time_counter'].astype("string"))
    # na=False: missing timestamps never match (NA is treated as False either way).
    matches = as_text['time_counter'].str.contains(yrmonth, case=False, na=False)
    return as_text.loc[matches].reset_index()

In [None]:
def get_mean_over_lat_lon_yrmonth(df, yrmonth, col_name):
    '''
    Select one month and average `col_name` for each (nav_lat, nav_lon) pair.

    Parameters
    ----------
    df : DataFrame with `time_counter`, `nav_lat`, `nav_lon` and `col_name` columns.
    yrmonth : pattern matched case-insensitively (as a regex) against the text form
        of `time_counter`, e.g. "1958-01".
    col_name : column to average.

    Returns
    -------
    DataFrame with columns `nav_lat`, `nav_lon` and `col_name`. The input frame is
    NOT modified.
    '''
    # Build the text form locally instead of overwriting the caller's column.
    as_text = df['time_counter'].astype("string")
    in_month = as_text.str.contains(yrmonth, case=False, na=False)
    selected = df.loc[in_month]
    return selected.groupby(['nav_lat', 'nav_lon'], as_index=False)[col_name].mean()

In [None]:
def get_mean_over_lat_lon(df, col_name):
    '''
    Average `col_name` for every (nav_lat, nav_lon, time_counter) combination.

    The group keys are returned as ordinary columns (as_index=False), followed
    by the averaged `col_name` column.
    '''
    group_keys = ['nav_lat', 'nav_lon', 'time_counter']
    grouped = df.groupby(group_keys, as_index=False)
    return grouped[col_name].mean()

In [None]:
def remove_land_points(df, mask_df=None, path_mask='../dataset/mesh_mask.nc'):
    '''
    Merge `df` with the land/sea mask and keep only rows with tmask == 1.

    Parameters
    ----------
    df : DataFrame whose index holds the (y, x) grid keys matched against `mask_df`
        (assumed to be a 2-level index — TODO confirm for every caller).
    mask_df : optional pre-built mask frame with a `tmask` column and `y`/`x` keys.
        When None it is built from `path_mask`.
    path_mask : NetCDF mesh-mask file read only when `mask_df` is None.

    Returns
    -------
    The merged DataFrame restricted to rows where `tmask` equals 1.
    '''
    if mask_df is None:
        # Close the file once the mask has been materialised as a DataFrame.
        with xr.open_dataset(path_mask, chunks={"z": 1, "y": 100, "x": 100}) as mesh:
            mask_df = mesh.tmask.isel(z=0).squeeze().to_dataframe()

    merged = df.merge(mask_df, left_index=True, right_on=['y', 'x'])
    # Filter on the merged frame itself (the old code read a global `sst_df`).
    return merged.loc[merged['tmask'] == 1]

In [None]:
# Split an interval into equal integer ranges. Adapted from:
# https://stackoverflow.com/questions/45122849/best-way-to-separate-range-into-n-equal-ranges-in-python

def make_cells(bound_min, bound_max, cell_size):
    '''
    Split the interval between `bound_min` and `bound_max` into `cell_size`
    consecutive (start, end) ranges.

    Parameters
    ----------
    bound_min, bound_max : interval ends; the smaller one is used as the start.
    cell_size : number of cells to produce (not the width of a cell).

    Returns
    -------
    list of `cell_size` (start, end) tuples whose ends are rounded with round().
    Consecutive tuples share their boundary value.

    Raises
    ------
    ZeroDivisionError if `cell_size` is 0.
    '''
    lower = min(bound_min, bound_max)
    step = abs(bound_min - bound_max) / cell_size
    # Offset by the lower bound; previously every range started at 0 regardless
    # of bound_min (results are unchanged when bound_min == 0).
    return [(round(lower + step * i), round(lower + step * (i + 1))) for i in range(cell_size)]

# Data loading

## Data path

In [None]:
## Location of the ocean model output.
## Set USE_HLRN = False to read the local copy instead of the HLRN scratch directory.
USE_HLRN = True

LOCAL_PATH = '../ocean_model/output/'
HLRN_PATH = "/scratch/usr/shklvn09/SCRATCH/ORCA025.L46.LIM2vp.CFCSF6.MOPS.JRA.LP04-KLP002.hind/"

path = HLRN_PATH if USE_HLRN else LOCAL_PATH

## Land mask on the local machine.
## NOTE(review): the "Load mask" cell rebuilds path_mask from `path`.
path_mask = '../ocean_model/mask/mesh_mask_orca_025.nc'

## Load mask

In [None]:
## Open the mesh-mask file once and reuse it for tmask, e1t and e2t
## (previously the same file was opened three times).
path_mask = f"{path}/mask/mesh_mask_orca_025.nc"
mesh_ds = xr.open_dataset(path_mask, chunks={"z": 1, "y": 100, "x": 100})

mask = mesh_ds.tmask.isel(z=0).squeeze()  # surface level of tmask
e1t = mesh_ds.e1t.squeeze()
e2t = mesh_ds.e2t.squeeze()

mask_df = mask.to_dataframe()
mask_df['e1t'] = e1t.to_dataframe()['e1t']
mask_df['e2t'] = e2t.to_dataframe()['e2t']
mask_df.head()

## Load ocean model output

In [None]:
## Simulation years to process, 1958 through 2018 inclusive (61 years).
FIRST_YEAR, LAST_YEAR = 1958, 2018
years = range(FIRST_YEAR, LAST_YEAR + 1)

In [None]:
%%time

## Loop over 61 years

for yr in years:
    
    print(f"\n---> Running {yr}")
    
    ## extract SST
    file_sst = glob(f"{path}*sosstsst.nc")
    sst_xr = xr.open_dataset(file_sst[0], chunks={"z":1, "y":100, "x":100})
    
    ## extract DIC pre-ind
    file_dicp = glob(f"{path}*DICP_k1.nc")
    dicp_xr = xr.open_dataset(file_dicp[0], chunks={"z":1, "y":100, "x":100})

    ## extract ALK
    file_alk = glob(f"{path}*ALK_k1.nc")
    alk_xr = xr.open_dataset(file_alk[0], chunks={"z":1, "y":100, "x":100})
    
    ## extract SAL
    file_sal = glob(f"{path}/demo/ORCA025*{yr}*sosaline*.nc")
    sal_xr = xr.open_dataset(file_sal[0], chunks={"z":1, "y":100, "x":100})
    
#     ## extract ice-fraction
#     file_icemod = glob(f"{path}/demo/ORCA025*{yr}*icemod*.nc")
#     icemod_xr = xr.open_dataset(file_icemod[0], chunks={"z":1, "y":100, "x":100}) 
    
    ## extract mixed layer depth
    file_mld = glob(f"{path}/demo/ORCA025*{yr}*somxl010*.nc")
    mld_xr = xr.open_dataset(file_mld[0], chunks={"z":1, "y":100, "x":100}) 
    
    ## extract fco2 pre-ind  
    file_fco2 = glob(f"{path}/demo/ORCA025*{yr}*fco2_pre.nc")
    fco2_xr = xr.open_dataset(file_fco2[0], chunks={"z":1, "y":100, "x":100})
    
    ## extract the variables
    sst_df = sst_xr.sosstsst.squeeze().to_dataframe().reset_index(level=['time_counter'])
    dicp_df = dicp_xr.DICP.squeeze().to_dataframe().reset_index(level=['time_counter'])
    alk_df = alk_xr.ALK.squeeze().to_dataframe().reset_index(level=['time_counter'])
    sal_df = sal_xr.sosaline.squeeze().to_dataframe().reset_index(level=['time_counter'])
    mld_df = mld_xr.somxl010.squeeze().to_dataframe().reset_index(level=['time_counter'])
    fco2_df = fco2_xr.fco2_pre.squeeze().to_dataframe().reset_index(level=['time_counter'])
    
    
    ## apply mask
    sst_df = sst_df.merge(mask_df, left_index=True, right_on=['y', 'x'])
    sst_df = sst_df.loc[sst_df['tmask'] == 1]

    dicp_df = dicp_df.merge(mask_df, left_index=True, right_on=['y', 'x'])
    dicp_df = dicp_df.loc[dicp_df['tmask'] == 1]

    alk_df = alk_df.merge(mask_df, left_index=True, right_on=['y', 'x'])
    alk_df = alk_df.loc[alk_df['tmask'] == 1]

    sal_df = sal_df.merge(mask_df, left_index=True, right_on=['y', 'x'])
    sal_df = sal_df.loc[sal_df['tmask'] == 1]
    
    mld_df = mld_df.merge(mask_df, left_index=True, right_on=['y', 'x'])
    mld_df = mld_df.loc[mld_df['tmask'] == 1]
    
    fco2_df = fco2_df.merge(mask_df, left_index=True, right_on=['y', 'x'])
    fco2_df = fco2_df.loc[fco2_df['tmask'] == 1]
    
    ## put the varibales in one table
    sosstsst_df = get_mean_over_lat_lon(df=sst_df, col_name='sosstsst')
    e1t_df = get_mean_over_lat_lon(df=sst_df, col_name='e1t')
    e2t_df = get_mean_over_lat_lon(df=sst_df, col_name='e2t')
    
    sal_df = get_mean_over_lat_lon(df=sal_df, col_name='sosaline')
    mld_df = get_mean_over_lat_lon(df=mld_df, col_name='somxl010')
    dicp_df = get_mean_over_lat_lon(df=dicp_df, col_name='DICP')
    alk_df = get_mean_over_lat_lon(df=alk_df, col_name='ALK')
    fco2_df = get_mean_over_lat_lon(df=fco2_df, col_name='fco2_pre')
    
    data_df = sosstsst_df.rename(columns={"sosstsst": "SST",})
    data_df['DICP'] = dicp_df['DICP']
    data_df['ALK'] = alk_df['ALK']
    data_df['fco2_pre'] = fco2_df['fco2_pre']
    data_df['SAL'] = sal_df['sosaline']
    data_df['MLD'] = mld_df['somxl010']
    data_df['e1t'] = e1t_df['e1t']
    data_df['e2t'] = e2t_df['e2t']
    
    data_df.to_pickle(f"../carbon_data_preprocessed/ocean_data_{yr}_df.pkl")
    
    print(f"\n---> {yr} processed.\n")
    print()
#     time.sleep(2)