In [1]:
import xarray as xr
from dask.distributed import Client
import time
import datetime as dt
import warnings
warnings.filterwarnings('ignore')
import sys
import gc
import numpy as np
sys.setrecursionlimit(100000)

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import matplotlib.colors as colors
from matplotlib import path
import scipy.io as sio
import cartopy.crs as ccrs
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
import cartopy.feature as cfeature
import seaborn as sns
import cmaps
import seapy
from scipy.spatial import KDTree
import dask

In [2]:
# Root directories on the external volumes used for this experiment.
myobsroot = '/Volumes/TO_1/roms4dvar_ecs/i4dvar_outputs/INSITU_OBS/'
mynlroot = '/Volumes/WD_3/outputs_SCORRECTION/'
mydasstroot = '/Volumes/WD_3/roms4dvar_ecs/i4dvar_outputs/'

# Sub-directories below each root. Paths are built by plain string
# concatenation (root + workspace + file), so every non-empty entry ends in '/'.
nl_workspace, dasst_workspace, obs_workspace = 'outputs_201205/', 'workspace_sstbgqc/', ''

# Observation file name and the glob patterns for the model output files.
obs_file = 'geopolar_sst_2012to14_offshore.nc'
nl_files = "ocean_ecs_his_00*.nc"
dasst_files = "STORAGE/posterior/ocean_ecs_fwd_*.nc"

In [12]:
# Validation window. The strings use the layout '%Y-%m-%d-%HH', i.e. a date
# followed by "-<hour>H" (12:00 here).
start_date = '2012-05-01-12H'
end_date = '2013-05-31-12H'

start_datetime, end_datetime = (
    dt.datetime.strptime(stamp, "%Y-%m-%d-%HH") for stamp in (start_date, end_date)
)
# Number of days in the window, counting both end days.
data_len = (end_datetime - start_datetime).days + 1

# Observation type to validate; it is matched against obs_type in the obs file.
# NOTE(review): 6 presumably denotes temperature (it is paired with 'temp') — confirm.
obs_var = 6
# Model variable: 'temp_sur' for 2-D surface output, 'temp' for 3-D output.
model_var = 'temp'

# Candidate dask chunk sizes (262 and 362 are presumably the xi/eta grid
# sizes, so each chunk is half a dimension; 10 vertical levels per chunk).
x_chunk = 262 // 2
y_chunk = 362 // 2
z_chunk = 10

print('''duration to be validated: from %s to %s, total of %i days.
         target obs variable: %s
         target model variable: %s'''
      % (start_date, end_date, data_len, obs_var, model_var))



duration to be validated: from 2012-05-01-12H to 2013-05-31-12H, total of 396 days.
         target obs variable: 6
         target model variable: temp


In [13]:
# Open the observation netCDF file with the netCDF4 engine (no dask chunking;
# pass e.g. chunks={'longitude': 260, 'latitude': 210} to enable it).
Obs_ds = xr.open_dataset(myobsroot + obs_workspace + obs_file, engine='netcdf4')

In [14]:
# Express the window bounds in days since 1970-01-01, the unit obs_time is
# compared in below (assumed to be obs_time's unit — TODO confirm against the file).
start_obstime, end_obstime = (
    (bound - dt.datetime(1970, 1, 1)).total_seconds() / 3600 / 24
    for bound in (start_datetime, end_datetime)
)

# Indices of records inside the window that also have provenance 355
# (presumably the satellite SST source — confirm) and the requested obs_type.
this_range = np.where(
    (Obs_ds.obs_time.data >= start_obstime)
    & (Obs_ds.obs_time.data <= end_obstime)
    & (Obs_ds.obs_provenance == 355)
    & (Obs_ds.obs_type == obs_var)
)

# Keep only the matching records (positional selection along `datum`).
Obs_ds = Obs_ds.isel(datum=this_range[0]).copy()
# Convert every fractional day count back to a datetime.datetime.
timestamp = [dt.datetime(1970, 1, 1) + dt.timedelta(days=obs_day)
             for obs_day in Obs_ds.obs_time.data]


In [15]:
Obs_ds = Obs_ds.assign_coords({'datum': timestamp})  # label each observation with its datetime

In [16]:
# Open all forward (nl) history files matching nl_files as one dataset,
# letting xarray open them in parallel, and report the elapsed time.
# To control dask chunking, pass chunks={'eta_rho': y_chunk, 'xi_rho': x_chunk,
# 's_rho': z_chunk, ...} (s_rho only for 3-D output).
start = time.time()
nl_ds = xr.open_mfdataset(
    mynlroot + nl_workspace + nl_files,
    engine='netcdf4',
    coords='minimal',
    parallel=True,
)
end = time.time()
print('loading costing %f min' % ((end - start) / 60))


loading costing 4.857848 min


In [17]:
# Open all posterior forward files (STORAGE/posterior/ocean_ecs_fwd_*.nc) as
# one dataset, letting xarray open them in parallel, and report the elapsed time.
# To control dask chunking, pass chunks={'eta_rho': y_chunk, 'xi_rho': x_chunk,
# 's_rho': z_chunk, ...} (s_rho only needed for 3-D output).
start = time.time()
dasst_ds = xr.open_mfdataset(
    mydasstroot + dasst_workspace + dasst_files,
    engine='netcdf4',
    coords='minimal',
    parallel=True,
)
end = time.time()
print('loading costing %f min' % ((end - start) / 60))


loading costing 4.805167 min


In [18]:
# Surface field of the forward (nl) run inside the validation window.
# The forward run has no quicksave (qck) file, so the 3-D 'temp' variable is
# used and the surface is taken as the last s_rho level (index -1).
nl_data = (
    nl_ds[model_var]
    .sel(ocean_time=slice(start_date, end_date))
    .isel(s_rho=-1)
)
# Collapse repeated ocean_time labels (presumably where consecutive files share
# a boundary time); keep='last' retains the later occurrence.
# NOTE(review): the earlier comment said the initial time of each posterior
# should be DROPPED because of a jump, which would be keep='first' — confirm.
nl_data = nl_data.drop_duplicates(dim='ocean_time', keep='last')


In [19]:
# Surface field of the posterior forward run inside the validation window.
# Only the 3-D 'temp' variable is used (no quicksave/qck output), and the
# surface is taken as the last s_rho level (index -1).
dasst_data = (
    dasst_ds[model_var]
    .sel(ocean_time=slice(start_date, end_date))
    .isel(s_rho=-1)
)
# Collapse repeated ocean_time labels (presumably where one posterior window
# ends at the time the next one starts); keep='last' retains the later occurrence.
# NOTE(review): the earlier comment said the initial time of each posterior
# should be DROPPED because of a jump, which would be keep='first' — confirm.
dasst_data = dasst_data.drop_duplicates(dim='ocean_time', keep='last')


In [20]:
# Daily (1-day bin) means of the nl surface field, computed eagerly and timed.
start = time.time()
# dask.compute returns a tuple with one result per argument, so nl_dailymean
# is a 1-tuple; the daily-mean DataArray itself is nl_dailymean[0].
nl_dailymean = dask.compute(
    nl_data.resample(ocean_time='1d').mean()
)
end = time.time()
print('calculating costing %f min' % ((end - start) / 60))



calculating costing 15.943614 min


In [21]:
# Daily (1-day bin) means of the posterior surface field, computed eagerly and timed.
start = time.time()
# dask.compute returns a tuple with one result per argument, so dasst_dailymean
# is a 1-tuple; the daily-mean DataArray itself is dasst_dailymean[0].
dasst_dailymean = dask.compute(
    dasst_data.resample(ocean_time='1d').mean()
)
end = time.time()
print('calculating costing %f min' % ((end - start) / 60))



calculating costing 34.578921 min


In [22]:
# Custom colormaps for the figures.
# my_sst: continuous 256-level map built from the 'rainbow' color table stored
# in LYG_rainbow.mat.
my_sst_color = sio.loadmat('/Volumes/TO_1/roms4dvar_ecs/i4dvar_outputs/'+
                           'LYG_rainbow.mat')['rainbow']
my_sst = LinearSegmentedColormap.from_list('sst', my_sst_color, N=256)

# 21 RGB anchors (0-255): dark blue -> white (row 10) -> dark red, scaled to [0, 1].
my_div_color = np.array([
    [0, 0, 123],
    [9, 32, 154],
    [22, 58, 179],
    [34, 84, 204],
    [47, 109, 230],
    [63, 135, 247],
    [95, 160, 248],
    [137, 186, 249],
    [182, 213, 251],
    [228, 240, 254],
    [255, 255, 255],
    [250, 224, 224],
    [242, 164, 162],
    [237, 117, 113],
    [235, 76, 67],
    [233, 52, 37],
    [212, 45, 31],
    [188, 39, 26],
    [164, 33, 21],
    [140, 26, 17],
    [117, 20, 12],
]) / 255
# Diverging map over all 21 anchors.
my_div = LinearSegmentedColormap.from_list('div', my_div_color, N=256)
# Sequential white -> red map from the upper anchors (rows 10..20). It gets its
# own name: it previously reused 'div', colliding with my_div's name.
my_reds = LinearSegmentedColormap.from_list('reds', my_div_color[10:], N=256)
# Discrete seaborn palette from the 11 central anchors (rows 5..15).
my_palette = sns.color_palette(my_div_color[5:-5])

In [23]:

@dask.delayed
def interpolate_data_roms(roms_data, obs_points):
    """Linearly interpolate a ROMS rho-grid field to observation positions.

    Wrapped in ``dask.delayed``: a call returns a lazy Delayed object and the
    interpolation runs when that object is computed.

    Parameters
    ----------
    roms_data : xarray.DataArray
        Model field carrying ``lon_rho`` and ``lat_rho`` coordinates. Its
        values are paired element-wise with the flattened coordinates, so it
        is assumed to have the same shape as ``lon_rho`` (e.g. a single 2-D
        snapshot) — TODO confirm at call sites.
    obs_points : mapping
        Indexable by ``'obs_lon'`` and ``'obs_lat'`` (e.g. an xarray Dataset
        of observations).

    Returns
    -------
    numpy.ndarray
        One value per observation point from
        ``scipy.interpolate.griddata(method='linear')``; points outside the
        convex hull of the grid get griddata's default fill value (NaN).
    """
    from scipy.interpolate import griddata

    # Treat every grid point as a scattered sample: flattened lon, lat, value.
    source_lon = roms_data.lon_rho.data.flatten()
    source_lat = roms_data.lat_rho.data.flatten()
    source_val = roms_data.values.flatten()

    # Interpolate at the observation longitudes/latitudes.
    return griddata((source_lon, source_lat),
                    source_val,
                    (obs_points['obs_lon'], obs_points['obs_lat']),
                    method='linear')

In [24]:

@dask.delayed
def fill_missing(lon_target, lat_target, data, max_distance=None):
    """Place each observation value on its nearest target-grid point.

    A KD-tree is built over the (lon, lat) target points. The nearest target
    point is found for every observation and the observation value is written
    there. Target points that receive no observation stay NaN. When several
    observations share the same nearest point, the one appearing LAST in
    ``data`` wins (the rule of the earlier per-observation loop).

    Wrapped in ``dask.delayed``: a call returns a lazy Delayed object and the
    array is produced when that object is computed.

    Parameters
    ----------
    lon_target, lat_target : array_like
        Target point coordinates, same shape. They may be 1-D or N-D (e.g. a
        2-D lon/lat grid); they are flattened for the search and the result
        is returned in this same shape.
    data : mapping
        Indexable by ``'obs_lon'``, ``'obs_lat'`` and ``'obs_value'`` (e.g. an
        xarray Dataset), with one entry per observation.
    max_distance : float, optional
        If given, observations whose nearest target point is farther than this
        (Euclidean distance in the lon/lat units of the inputs) are ignored.
        The default ``None`` keeps every observation, as before.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``lon_target`` (dtype as ``np.full_like`` gives)
        holding the gridded observation values and NaN elsewhere.

    Raises
    ------
    IndexError
        If ``data['obs_value']`` has fewer entries than ``data['obs_lon']``.
    """
    lon = np.asarray(lon_target)
    lat = np.asarray(lat_target)
    obs_lon = np.asarray(data['obs_lon']).ravel()
    obs_lat = np.asarray(data['obs_lat']).ravel()
    n_obs = len(obs_lon)

    obs_values = np.asarray(data['obs_value']).ravel()
    if obs_values.size < n_obs:
        raise IndexError("data['obs_value'] has fewer entries than data['obs_lon']")
    obs_values = obs_values[:n_obs]  # entries beyond obs_lon's length are ignored

    # Work on a flat copy in lon.ravel() order; reshape back at the end.
    flat_target = np.full_like(lon.ravel(), np.nan)
    if n_obs == 0:
        return flat_target.reshape(lon.shape)

    # Nearest neighbour in plain (lon, lat) Euclidean space.
    target_tree = KDTree(np.column_stack((lon.ravel(), lat.ravel())))
    bound = np.inf if max_distance is None else max_distance
    _, nearest = target_tree.query(np.column_stack((obs_lon, obs_lat)),
                                   distance_upper_bound=bound)

    # KDTree.query reports "no neighbour within the bound" as index == tree.n.
    found = nearest < target_tree.n
    nearest = nearest[found]
    obs_values = obs_values[found]

    # Vectorised "last observation wins": NumPy does not guarantee which write
    # survives for duplicate fancy indices, so reverse the order and keep the
    # first occurrence of each target index explicitly.
    cells, first_in_reversed = np.unique(nearest[::-1], return_index=True)
    flat_target[cells] = obs_values[::-1][first_in_reversed]

    return flat_target.reshape(lon.shape)