# THIS FILE CAN BE DELETED IF THE .PY FILE WORKS

In [None]:
import xarray as xr
import glob
import pandas as pd
from scipy.interpolate import griddata
import numpy as np

from inversion_sst_gp import utils, simulate_obs

In [2]:
# Load the Himawari grid (coordinates only) from a preprocessed Himawari SST file
path_himawari_file = "1_preproc_data/proc_data/himawari_case_1.nc"


def load_himawari_grid(path_himawari_file):
    """Return the coordinate grid of a Himawari SST file as an in-memory Dataset.

    Only the file's coordinates are kept. The ``time`` and ``tstep``
    coordinates are dropped, so the grid carries no time information of its
    own and can be merged with data for any snapshot (as ``prep_osse`` does).

    Parameters
    ----------
    path_himawari_file : str
        Path to a file readable by ``xr.open_dataset``. It must contain
        ``time`` and ``tstep`` coordinates (``drop_vars`` raises otherwise).

    Returns
    -------
    xarray.Dataset
        Dataset holding the remaining coordinates, fully loaded into memory.
    """
    # Open under a context manager so the file handle is released.
    # Non-index coordinates may be lazily backed by the file, so .load() pulls
    # them into memory before the file is closed.
    with xr.open_dataset(path_himawari_file) as ds:
        ds_grid = ds.coords.to_dataset().drop_vars(['time', 'tstep']).load()
    return ds_grid


ds_grid = load_himawari_grid(path_himawari_file)

In [13]:
# Analysis domain, given as (min, max) bounds for longitude and latitude
lonlims, latlims = (115, 118), (-15.5, -12.5)


In [14]:
def prep_osse(ds_o, ds_grid, datet):
    """Interpolate one snapshot of unstructured ocean-model output onto a grid.

    The snapshot at ``datet`` and its two temporal neighbours are read from
    ``ds_o``. The temperature tendency is a centred difference between the
    neighbours. ``T``, ``dTdt``, ``u`` and ``v`` are interpolated to the
    ``(LON, LAT)`` points of ``ds_grid`` with cubic ``griddata``. Spatial
    gradients come from ``utils.finite_difference_2d`` on the ``X``/``Y``
    grid, and ``S = dTdt + u * dTdx + v * dTdy``.

    Parameters
    ----------
    ds_o : xarray.Dataset
        Unstructured model output with a ``time`` dimension, variables
        ``T``, ``u``, ``v`` and point coordinates ``lon``/``lat``.
    ds_grid : xarray.Dataset
        Target grid providing ``LON``, ``LAT``, ``lon``, ``lat``, ``X`` and ``Y``.
    datet : numpy.datetime64
        Time of the snapshot; must compare equal to one entry of ``ds_o.time``.

    Returns
    -------
    xarray.Dataset
        ``ds_grid`` merged with gridded ``T``, ``dTdt``, ``dTdx``, ``dTdy``,
        ``u``, ``v`` and ``S`` on ``(lat, lon)``. It also has scalar ``time``
        and ``tstep`` coordinates; ``tstep`` is half the span between the
        neighbouring snapshots, in seconds.

    Raises
    ------
    IndexError
        If ``datet`` is not in ``ds_o.time``, or is its first or last entry
        (the centred difference then has no previous/next snapshot).
    """
    # Locate the requested snapshot.
    matches = np.where(ds_o.time == datet)[0]
    if matches.size == 0:
        raise IndexError(f"datet={datet} is not one of the times in ds_o")
    idx = int(matches[0])
    n_time = ds_o.time.size
    # A centred difference needs idx-1 and idx+1. Without this check,
    # idx == 0 would silently use time=-1 (the LAST snapshot) as "previous".
    if idx == 0 or idx == n_time - 1:
        raise IndexError(
            f"datet={datet} is at the edge of ds_o.time (index {idx} of {n_time}); "
            "a centred time difference needs a previous and a next snapshot"
        )

    # Non-gridded values at the previous, current and next time step
    T_ug = ds_o.T.isel(time=idx).values
    T_ug_p = ds_o.T.isel(time=idx - 1).values
    T_ug_n = ds_o.T.isel(time=idx + 1).values
    u_ug = ds_o.u.isel(time=idx).values
    v_ug = ds_o.v.isel(time=idx).values
    time_p = ds_o.time.isel(time=idx - 1).values
    time = ds_o.time.isel(time=idx).values
    time_n = ds_o.time.isel(time=idx + 1).values
    lon_ug = ds_o.lon.values
    lat_ug = ds_o.lat.values

    # Centred temporal derivative at the non-gridded points.
    # tstep is half of the (next - previous) span, in seconds.
    tstep = (time_n - time_p) / np.timedelta64(1, 's') / 2
    dT_ug = (T_ug_n - T_ug_p) / 2
    dTdt_ug = dT_ug / tstep

    # Target grid values
    LON = ds_grid.LON.values
    LAT = ds_grid.LAT.values
    lon = ds_grid.lon.values
    lat = ds_grid.lat.values
    X = ds_grid.X.values
    Y = ds_grid.Y.values

    # Interpolate every field from the scattered points onto (LON, LAT).
    # griddata expects the points as an (npoints, 2) array.
    points = np.stack([lon_ug, lat_ug]).T
    T = griddata(points, T_ug, (LON, LAT), method='cubic')
    dTdt = griddata(points, dTdt_ug, (LON, LAT), method='cubic')
    u = griddata(points, u_ug, (LON, LAT), method='cubic')
    v = griddata(points, v_ug, (LON, LAT), method='cubic')

    # Spatial gradients on the grid and the combined term S
    dTdx, dTdy = utils.finite_difference_2d(X, Y, T)
    S = dTdt + u * dTdx + v * dTdy

    # Assemble the gridded dataset
    ds_data = xr.Dataset(
        data_vars=dict(
            T=(['lat', 'lon'], T),
            dTdt=(['lat', 'lon'], dTdt),
            dTdx=(['lat', 'lon'], dTdx),
            dTdy=(['lat', 'lon'], dTdy),
            u=(['lat', 'lon'], u),
            v=(['lat', 'lon'], v),
            S=(['lat', 'lon'], S),
        ),
        coords=dict(
            lon=(['lon'], lon),
            lat=(['lat'], lat),
            time=time,
            tstep=tstep,
        ),
    )
    return xr.merge([ds_grid, ds_data])

In [17]:
# load full dataset
data_dir = '/mnt/c/users/23513098/OneDrive - The University of Western Australia/Linux/Python/Current/SSC_suntans/datasets'  # Directory containing the dataset
data_name = 'SUNTANS_CROP_lon_114.3_118.7_lat_-15.7_-12.3'
ds_osse_full = xr.open_mfdataset(glob.glob(f'{data_dir}/{data_name}*'))

# select single snapshot
datet_osse = np.datetime64('2014-02-19T18:00:00')
datet_str_osse = pd.to_datetime(datet_osse).strftime("%Y-%m-%d_%H-%M-%S")
tstep = 3600

In [20]:
# Inspect the coordinates of the full model dataset (displayed as cell output in the notebook)
ds_osse_full.coords

Coordinates:
  * time     (time) datetime64[ns] 41kB 2013-12-01 ... 2014-07-01
    lon      (Nc) float64 128kB dask.array<chunksize=(16037,), meta=np.ndarray>
    lat      (Nc) float64 128kB dask.array<chunksize=(16037,), meta=np.ndarray>

In [11]:
import os

# Output dimension name for the swept value in the non-time tests
var_dict = {"measurement_error": "noise", "sparse_cloud": "coverage_sparse", "dense_cloud": "coverage_dense"}


def _sweep_config(test_type):
    """Return ``(val_range, dataset_name)`` for one synthetic test type.

    For ``time_*`` tests the values are offsets in seconds from ``datet_osse``
    (multiples of ``tstep``). For the other tests they are the argument passed
    to the corresponding ``simulate_obs.ModifyData`` method.

    Raises
    ------
    ValueError
        For an unknown test type (previously the stale values from the last
        iteration would have been reused silently).
    """
    configs = {
        'time_24h': (np.arange(0, 100) * 24 * tstep, "suntans_24h"),
        'time_1h': (np.arange(0, 49) * tstep, "suntans_1h"),
        'measurement_error': (np.arange(0, 0.016, 0.001), "suntans_measurement_error"),
        'sparse_cloud': (np.linspace(0, .75, 26), "suntans_sparse_cloud"),
        'dense_cloud': (np.linspace(0, .75, 26), "suntans_dense_cloud"),
    }
    try:
        return configs[test_type]
    except KeyError:
        raise ValueError(f"Unknown TEST_TYPE {test_type!r}") from None


def _simulate(test_type, Tt, dTdtt, X, Y, val):
    """Turn one true snapshot into synthetic observations.

    Applies the degradation selected by ``test_type`` with parameter ``val``;
    ``time_*`` tests use the snapshot unmodified. Returns the 5-tuple from
    ``ModifyData.convert_to_input()``.
    """
    obs = simulate_obs.ModifyData(Tt, dTdtt, tstep, X, Y)
    if test_type == 'measurement_error':
        obs = obs.noise(val)
    elif test_type == 'sparse_cloud':
        obs = obs.sparse_cloud(val)
    elif test_type == 'dense_cloud':
        obs = obs.circ_cloud(val)
    return obs.convert_to_input()


# Create the output directory up front so a long sweep cannot fail at the final write.
os.makedirs('./data', exist_ok=True)

for TEST_TYPE in ['measurement_error', 'sparse_cloud', 'dense_cloud', 'time_24h', 'time_1h']:
    val_range, dataset_name = _sweep_config(TEST_TYPE)

    Nrange = len(val_range)
    Ny, Nx = len(ds_grid.lat), len(ds_grid.lon)
    shape = (Nrange, Ny, Nx)

    # Every [i] slice is written in the loop below, so np.empty is safe.
    Toc = np.empty(shape)
    uc = np.empty(shape)
    vc = np.empty(shape)
    Sc = np.empty(shape)
    dTds1oc = np.empty(shape)
    dTds2oc = np.empty(shape)
    dTdtoc = np.empty(shape)
    maskcc = np.empty(shape)
    np.random.seed(0)  # reset per test type so each output file is reproducible

    flag_time = TEST_TYPE in ('time_24h', 'time_1h')
    if not flag_time:
        # Non-time tests all degrade the same single snapshot.
        ds_osse = prep_osse(ds_osse_full, ds_grid, datet_osse)

    for i, val in enumerate(val_range):
        if flag_time:
            # val is an offset in seconds. Give the unit explicitly instead of
            # relying on a unitless timedelta64 being cast to datet_osse's unit.
            ds_osse = prep_osse(ds_osse_full, ds_grid, datet_osse + np.timedelta64(int(val), 's'))

        ut = ds_osse.u.values
        vt = ds_osse.v.values
        Tt = ds_osse.T.values
        St = ds_osse.S.values
        dTdtt = ds_osse.dTdt.values
        X = ds_osse.X.values
        Y = ds_osse.Y.values

        uc[i, :, :] = ut
        vc[i, :, :] = vt
        Sc[i, :, :] = St
        # NOTE(review): dTds1oc, dTds2oc and maskcc are filled here but never
        # written to the output file below; confirm that is intended.
        Toc[i, :, :], dTds1oc[i, :, :], dTds2oc[i, :, :], dTdtoc[i, :, :], maskcc[i, :, :] = _simulate(
            TEST_TYPE, Tt, dTdtt, X, Y, val
        )

    # The swept value becomes the leading dimension of every saved variable.
    if flag_time:
        var_name = 'time'
        sweep_coords = {"time": [datet_osse + np.timedelta64(int(v), 's') for v in val_range]}
    else:
        var_name = var_dict[TEST_TYPE]
        sweep_coords = {var_name: val_range, "time": datet_osse}

    xr.Dataset(
        {
            'T': ([var_name, 'lat', 'lon'], Toc),
            'dTdt': ([var_name, 'lat', 'lon'], dTdtoc),
            'u': ([var_name, 'lat', 'lon'], uc),
            'v': ([var_name, 'lat', 'lon'], vc),
            'S': ([var_name, 'lat', 'lon'], Sc),
        },
        coords={
            **sweep_coords,
            "time_step": tstep,
            "lat": ds_grid.lat.values,
            "lon": ds_grid.lon.values,
        },
    ).to_netcdf(f'./data/{dataset_name}.nc')
        

KeyboardInterrupt: 