In [1]:
import xarray as xr
import numpy as np
import os
from src.configs import Configs
from glob import glob

In [2]:
# Name of the variable to sort; together with GRIDSIZE it selects the output
# directory via configs.get_gsam_satfrac_sorted_var_dir(VAR, GRIDSIZE).
VAR = 'W'
# Tile edge length in grid points (index count, not degrees): the domain is
# cut into GRIDSIZE x GRIDSIZE lon/lat index blocks, so both horizontal
# dimension sizes must be exact multiples of this value.
GRIDSIZE = 50

In [7]:
# Resolve the output directory for the sorted tiles and make sure it exists.
configs = Configs('tropical_nw_pacific')
out_dir = configs.get_gsam_satfrac_sorted_var_dir(VAR, GRIDSIZE)
os.makedirs(out_dir, exist_ok=True)

# Surface ('2D') files provide PW/PWS for the saturation fraction.
surf_files = sorted(glob(os.path.join(configs.get_gsam_native_var_dir('2D'), '*.nc')))
# Read the variable selected by VAR so the input always matches the
# VAR-specific output directory above (this was hard-coded to 'W').
var_files = sorted(glob(os.path.join(configs.get_gsam_native_var_dir(VAR), '*.nc')))

# Fifth underscore-separated token of each file name; it is used further down
# to pair every variable file with its surface file.
# NOTE(review): presumably the model time stamp -- confirm against the
# file naming convention.
surf_files_times = [os.path.basename(f).split('_')[4] for f in surf_files]
var_file_times = [os.path.basename(f).split('_')[4] for f in var_files]

In [8]:
def _compute_saturation_fraction(surf_fn):
    """Return the ratio PW / PWS read from a surface file.

    Parameters
    ----------
    surf_fn : str
        Path of a dataset readable by ``xr.open_dataset`` that contains the
        variables ``PW`` and ``PWS``.

    Returns
    -------
    xarray.DataArray
        ``PW / PWS``, fully loaded into memory so that it stays usable after
        the underlying file has been closed.
    """
    # The context manager closes the file handle on exit; previously one
    # handle per call was left open.
    with xr.open_dataset(surf_fn) as surf:
        # .load() materialises data and coordinates before the file closes.
        return (surf.PW / surf.PWS).load()

In [9]:
def _sort_grid_by_saturation_fraction(var_ds, sf_ds):
    stack_var = var_ds.stack(column=('lat', 'lon'))
    stack_sf = sf_ds.stack(column=('lat', 'lon'))
    sorted_var = stack_var.sortby(stack_sf)
    column_rank = np.linspace(0, 1, sorted_var.column.size)
    sorted_var = sorted_var.drop_vars(('lat', 'lon')).assign_coords({'column': column_rank})
    return(sorted_var)

In [10]:
# Debug switch. True reproduces what this cell did before, when trailing
# `break` statements stopped it after writing the first tile of the first
# file. Set to False to process every tile of every file.
FIRST_TILE_ONLY = True

for vi, var_file in enumerate(var_files[:1] if FIRST_TILE_ONLY else var_files):
    # get surface file corresponding to 3D file time
    # (.item() raises unless exactly one surface file matches)
    file_idx = np.where(np.isin(surf_files_times, var_file_times[vi]))[0].item()
    surf_file = surf_files[file_idx]

    # compute saturation fraction (single time step, so drop the time axis)
    satfrac = _compute_saturation_fraction(surf_file).squeeze('time')

    # the domain must split into whole GRIDSIZE x GRIDSIZE tiles
    lon_size, lat_size = satfrac.lon.size, satfrac.lat.size
    assert lon_size % GRIDSIZE == lat_size % GRIDSIZE == 0, (
        f'lon/lat sizes ({lon_size}, {lat_size}) must be multiples of GRIDSIZE={GRIDSIZE}'
    )

    # one (lon index, lat index, isel-dict) entry per tile, lon-major order
    tiles = [
        (
            loni,
            lati,
            {
                'lon': slice(loni * GRIDSIZE, (loni + 1) * GRIDSIZE),
                'lat': slice(lati * GRIDSIZE, (lati + 1) * GRIDSIZE),
            },
        )
        for loni in range(lon_size // GRIDSIZE)
        for lati in range(lat_size // GRIDSIZE)
    ]
    if FIRST_TILE_ONLY:
        tiles = tiles[:1]

    # Open the variable file once per file (not once per tile) and close it
    # when all of its tiles have been written.
    with xr.open_dataarray(var_file) as var_da:
        var_field = var_da.squeeze('time')
        for loni, lati, slice_dict in tiles:
            grid_satfrac = satfrac.isel(slice_dict)
            # Cut the variable to the same tile as the saturation fraction;
            # previously the whole domain was sorted and written per tile.
            # NOTE(review): assumes the variable and surface files share the
            # same lat/lon grid -- confirm.
            grid_var = var_field.isel(slice_dict)
            sorted_var = _sort_grid_by_saturation_fraction(grid_var, grid_satfrac)

            filename = f'satfrac_sorted.lon_{loni}.lat_{lati}.{os.path.basename(var_file)}'
            sorted_var.to_netcdf(os.path.join(out_dir, filename))