In [None]:
import numpy as np
import pandas as pd
import xarray as xr

import cartopy.crs as ccrs
import cartopy.feature as cfeature

import matplotlib.pyplot as plt

%matplotlib inline

# To avoid warning messages
import warnings
warnings.filterwarnings('ignore')

## Load the land-sea mask (used to mask out land)

In [None]:
# Land-sea mask on the model grid. Its lat/lon coordinates are reused by
# make_data_arr(), and `my_mask` is the default mask of mask_array().
# NOTE(review): apply_mask() blanks cells where this field is < 0.1, so it is
# presumably ~0 over land and ~1 over ocean despite its name -- confirm.
land_sea_mask = xr.open_dataset(
    "/discover/nobackup/projects/gmao/advda/sakella/future_sst_fraci/gen_daily_clim_data/data/geos_fp_bcs_land_sea_mask.nc"
)
my_mask = land_sea_mask["land_mask"].values

### Local functions

In [None]:
# file names: real data and daily climatology
def get_files_names(dates, data_path, file_pref, clim=False, file_suff=".nc"):
    """Build one file path per date, for real daily data or for the daily climatology.

    Real data (``clim`` == False) lives in per-year sub-directories:
        ``<data_path><YYYY>/<file_pref><YYYY><MM><DD><file_suff>``
    Climatology files use the placeholder year 0001:
        ``<data_path>/<file_pref>0001<MM><DD><file_suff>``
    (a '/' is always inserted after ``data_path`` here, so a trailing '/'
    in ``data_path`` yields '//', which POSIX paths treat as a single '/').

    Parameters
    ----------
    dates : iterable
        Objects exposing integer ``year``, ``month`` and ``day`` attributes.
    data_path, file_pref, file_suff : str
        Path prefix, file-name prefix and file-name suffix.
    clim : bool
        Select the climatology naming scheme.

    Returns
    -------
    list of str, in the order of ``dates``.
    """
    def _one_path(day):
        mmdd = "{:02d}{:02d}".format(day.month, day.day)
        if clim == False:  # real data
            return "{}{}/{}{}{}{}".format(data_path, day.year, file_pref, day.year, mmdd, file_suff)
        return "{}/{}0001{}{}".format(data_path, file_pref, mmdd, file_suff)

    return [_one_path(day) for day in dates]

# We need to mask out land
def apply_mask(input_field, mask, tol=0.1):
    """Return a copy of ``input_field`` with NaN wherever ``mask < tol``.

    ``input_field`` itself is not modified. The copy must be able to hold NaN
    (i.e. a float dtype). ``mask`` is used as a boolean index into the copy,
    so it should have the same shape as ``input_field``.
    """
    blanked = mask < tol
    result = np.copy(input_field)
    result[blanked] = np.nan
    return result

# mask land
def mask_array(ds, iTime, vName='SST', mask=my_mask):
    """Extract one time slice of ``ds[vName]`` as a NumPy array and blank masked cells.

    Parameters
    ----------
    ds : dataset indexable by ``vName`` whose variable has a ``time`` dimension
    iTime : int
        Positional index along ``time``.
    vName : str
        Variable to extract (default ``'SST'``).
    mask : array
        Handed to ``apply_mask`` with its default tolerance. Defaults to the
        notebook-level ``my_mask`` as it was when this cell was executed.

    Returns
    -------
    A NumPy copy of the slice with NaN where ``apply_mask`` masks it.
    """
    field = ds[vName].isel(time=iTime).values
    return apply_mask(field, mask)

# to write forecast stats
def write_stats(vName, var, dates=None):
    """Write a forecast statistic to ``<vName>_<first>_<last>.csv`` in the working directory.

    Parameters
    ----------
    vName : str
        Prefix of the output file name.
    var : array_like
        1-D or 2-D data written by ``np.savetxt`` (comma-separated, 4 decimals).
    dates : sequence of datetime-like, optional
        Its first and last entries (formatted ``%Y%m%d``) go into the file name.
        Defaults to the notebook-level ``exp_dates``. That global is looked up
        only at call time, so existing two-argument calls behave as before.
    """
    if dates is None:
        # Fallback to the notebook global keeps old calls working. An explicit
        # `dates` lets the function run before/without the dates cell.
        dates = exp_dates
    f1 = vName + "_{}_{}.csv".format(dates[0].strftime('%Y%m%d'), dates[-1].strftime('%Y%m%d'))
    print("Writing out: ", f1)
    np.savetxt(f1, var, delimiter=",", fmt='%1.4f')
    print("Done!")

# Unweighted mean and std dev
def unWeighted_mean_sdev(arr):
    """Return ``(mean, std)`` over every non-NaN element of ``arr``, without area weighting.

    Both statistics are accumulated in float64; the standard deviation uses
    NumPy's default ``ddof=0``.
    """
    values = arr.flatten()
    return (np.nanmean(values, dtype=np.float64),
            np.nanstd(values, dtype=np.float64))

def get_first_day(ds, classifacation='TS'):
    """Return the time of the first track entry carrying a given classification.

    Parameters
    ----------
    ds : object with ``type`` and ``time`` attributes
        Presumably the storm-track dataset read further down (``ds.type`` holding
        classification labels aligned with ``ds.time``) -- confirm against the file.
    classifacation : str
        Label to search for; default ``'TS'`` (tropical storm). The misspelt
        parameter name is kept so keyword callers keep working.

    Returns
    -------
    ``ds.time`` indexed at the first position where ``ds.type == classifacation``.

    Raises
    ------
    IndexError
        If no entry of ``ds.type`` equals ``classifacation``. This is the same
        exception type as before, but now with a descriptive message.
    """
    matches = np.where(ds.type == classifacation)[0]
    if matches.size == 0:
        raise IndexError("no track entry is classified as {!r}".format(classifacation))
    return ds.time[matches[0]]

def make_data_arr(arr, mask=land_sea_mask):
    """Wrap a 2-D field as an ``xarray.DataArray`` on the land-sea-mask grid.

    ``arr`` is expected to have shape ``(len(mask.lat), len(mask.lon))``. The
    lat/lon coordinates are copied from ``mask`` so callers can use
    label-based selection (e.g. ``.sel(..., method='nearest')``) on the result.
    ``mask`` defaults to the notebook-level ``land_sea_mask`` dataset.
    """
    grid = {'lat': mask.lat, 'lon': mask.lon}
    provenance = {"description": "See https://github.com/sanAkel/future_sst_fraci/blob/main/to_gen_new_files/SST_under_TC.ipynb"}
    return xr.DataArray(arr, coords=grid, dims=["lat", "lon"], attrs=provenance)

### Methods to predict future boundary conditions (BCs); see below for the mathematical details

In [None]:
def forecast_bc(method, id, bc0, clim_bc, anomaly0):
    """Predict the boundary-condition (BC) field for one lead day of a forecast.

    Parameters
    ----------
    method : str
        One of
          - ``"persist"``           -> ``bc0`` on every lead day;
          - ``"persist_init_anom"`` -> ``clim_bc + anomaly0``;
          - ``"test3"``             -> ``clim_bc - anomaly0``;
          - ``"test4"``             -> ``bc0 + anomaly0``.
        For every method, lead day 0 returns ``bc0``.
    id : int
        Lead day; 0 is the forecast start day. The name shadows the builtin
        but is kept for interface compatibility.
    bc0 : array_like
        BC on the forecast start day.
    clim_bc : array_like
        Climatological BC for this lead day.
    anomaly0 : array_like
        Start-day anomaly. The hindcast loop in this notebook passes
        ``bc0 - climatology`` of day 0.

    Returns
    -------
    The predicted field. For lead day 0 and for ``"persist"`` this is ``bc0``
    itself (not a copy).

    Raises
    ------
    ValueError
        If ``method`` is not one of the names above.
    """
    # Prediction for lead days > 0, per method (evaluated lazily).
    later_days = {
        "persist_init_anom": lambda: clim_bc + anomaly0,
        "test3": lambda: clim_bc - anomaly0,
        "test4": lambda: bc0 + anomaly0,
    }

    if method == "persist":
        return bc0  # persistence throughout the forecast
    if method not in later_days:
        # Fail loudly: a placeholder field would be scored as a real forecast
        # and silently corrupt every downstream error statistic.
        raise ValueError("Unknown method: {} for creating future BCs.".format(method))
    if id == 0:
        return bc0  # forecast start day
    return later_days[method]()

## Read the pre-processed storm-track information

In [None]:
# Storm to analyse; its track is read from <tc_name><year>.nc in the working directory.
year = 2023
tc_name = 'franklin'
ds_tc = xr.open_dataset("{}{}.nc".format(tc_name, year))

# The experiment window spans the whole track, from its first to its last time stamp.
# (To start at the first tropical-storm classification instead, see get_first_day(ds_tc).)
track_days = ds_tc['time'].dt.strftime("%Y-%m-%d")
start_date = str(track_days[0].values)
end_date = str(track_days[-1].values)
print("\nHurricane:\t{} originated on:\t{},\t dissipated on:\t{}.".format(tc_name.upper(), start_date, end_date))

## Forecast settings (see the `SST_ideas` notebook regarding these inputs)

In [None]:
# Forecast set-up: nfcst forecasts, one starting on each consecutive day,
# each running fcst_nDays lead days.
fcst_nDays = 10  # lead days per forecast
nfcst = 13       # number of forecasts following tropical storm categorization
# start_date / end_date come from the storm-track cell above.
# NOTE(review): nfcst is presumably chosen so the forecast start days stay within
# [start_date, end_date] -- confirm.

data_path_real = "/discover/nobackup/projects/gmao/advda/sakella/future_sst_fraci/GMAO_OPS_bin_data/data/"
data_path_clim = "/discover/nobackup/projects/gmao/advda/sakella/future_sst_fraci/data/ncFiles/"

file_suff = ".nc"
file_pref_real = "sst_ice_"
file_pref_clim = "daily_clim_mean_sst_fraci_"

# Forecast method for forecast_bc(): "persist", "persist_init_anom", "test3" or "test4"
method = "persist"

vName = 'SST'  # always SST, ice concentration is irrelevant in TC context.

## Dates of forecasts

In [None]:
# Daily forecast start dates (the BCs are daily) covering the storm's lifetime.
exp_dates = pd.date_range(start=str(start_date), end=end_date, freq='D')

## Initialize arrays

In [None]:
# Along-track errors, indexed (track point, lead day, forecast):
#   error_real -- predicted BC minus the real data (hindcast mode, so the truth is known)
#   error_clim -- daily climatology minus the real data
error_real = np.zeros(
    shape=(ds_tc.time.shape[0], fcst_nDays, nfcst),
    dtype=np.float64,
)
error_clim = np.zeros(error_real.shape, dtype=error_real.dtype)

In [None]:
# Hindcast loop: for every forecast, predict the daily BC with `method` and
# sample the forecast / climatology errors along the storm track.
for ifcst in range(1, nfcst + 1):  # 1-based forecast number, stored at index ifcst-1

    fcst_start_date = exp_dates[0] + pd.DateOffset(days=ifcst-1)
    fcst_dates = pd.date_range(start=fcst_start_date, periods=fcst_nDays)
    print("Forecast [{}] Dates: {}".format(ifcst, fcst_dates))

    files_names_real_data = get_files_names(fcst_dates, data_path_real, file_pref_real, file_suff=file_suff)
    clim_files_names = get_files_names(fcst_dates, data_path_clim, file_pref_clim, clim=True, file_suff=file_suff)

    # Context managers close the multi-file datasets at the end of each forecast.
    # Without them every iteration leaves its files open. mask_array() pulls
    # the data into memory (.values) before the files are closed.
    with xr.open_mfdataset(files_names_real_data) as ds_real, \
         xr.open_mfdataset(clim_files_names, concat_dim='time', combine='nested', use_cftime=True) as ds_clim:

        for iday in range(fcst_nDays):  # each lead day of this forecast
            real_bc = mask_array(ds_real, iday, vName=vName)
            clim_bc = mask_array(ds_clim, iday, vName=vName)

            # initial BC and its anomaly, held fixed for the rest of the forecast
            if iday == 0:
                bc0 = real_bc
                anom0 = bc0 - clim_bc

            predicted_bc = forecast_bc(method, iday, bc0, clim_bc, anom0)

            # DataArrays on the mask grid ease nearest-point selection along the track
            err_real = make_data_arr(predicted_bc - real_bc)
            err_clim = make_data_arr(clim_bc - real_bc)

            error_real[:, iday, ifcst-1] = err_real.sel(lat=ds_tc.lat, lon=ds_tc.lon, method='nearest').values
            error_clim[:, iday, ifcst-1] = err_clim.sel(lat=ds_tc.lat, lon=ds_tc.lon, method='nearest').values

In [None]:
# (track points, lead days, forecasts)
np.shape(error_real)

In [None]:
# Along-track error (predicted minus real BC) for every lead day of one forecast.
ifcst = 4        # 1-based forecast number, same numbering the hindcast loop prints
ifc = ifcst - 1  # the loop stores forecast number ifcst at index ifcst-1

n_cols = 4
n_rows = int(np.ceil(fcst_nDays / n_cols))  # enough panels for any fcst_nDays
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 10), squeeze=False)

for iday, ax in enumerate(axes.flat):
    if iday >= fcst_nDays:
        ax.set_visible(False)  # unused panel
        continue
    err = error_real[:, iday, ifc]
    # Symmetric limits so the diverging colormap is centred on zero error
    # (fallback 1.0 when the errors are all NaN or all exactly zero).
    finite = np.abs(err[np.isfinite(err)])
    lim = finite.max() if finite.size and finite.max() > 0 else 1.0
    sc = ax.scatter(ds_tc.lon, ds_tc.lat, s=6, c=err, cmap=plt.cm.bwr, vmin=-lim, vmax=lim)
    fig.colorbar(sc, ax=ax)
    ax.set_title("Lead day {}".format(iday))

fcst_start = exp_dates[0] + pd.DateOffset(days=ifc)
fig.suptitle("{}: forecast [{}] starting {}, method '{}'".format(
    tc_name.upper(), ifcst, fcst_start.strftime('%Y-%m-%d'), method))
plt.show()

## Test: forecast error averaged over all forecasts

In [None]:
# Mean and standard deviation of the along-track forecast error across all
# forecasts, separately for every track point and lead day.
# error_real is (track points, fcst_nDays, nfcst), so averaging over axis 2
# divides by nfcst (the number of forecasts), not by fcst_nDays.
# Points that are NaN in any forecast (e.g. masked cells) stay NaN.
mean_real_error = error_real.mean(axis=2)  # (track points, fcst_nDays)
sdev_real_error = error_real.std(axis=2)   # population std (ddof=0), same shape
assert mean_real_error.shape == (ds_tc.time.shape[0], fcst_nDays)