In [None]:
import argparse
import json
import logging

import dask
import numpy as np
import xarray as xr

from dask.distributed import Client
import dask.config
import dask.array as da

In [None]:
import sys

In [None]:
# Make the project modules under ../src importable from this notebook.
# Guarded so that re-running the cell does not keep appending duplicates.
SRC_DIR = '../src'
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [None]:
import helper_modules

In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Read the per-domain settings (all regions) from the project's JSON config.
with open("../src/conf/domain_config.json", "r") as config_file:
    domain_config = json.load(config_file)

In [None]:
# Read the attribute settings from the project's JSON config.
with open("../src/conf/attribute_config.json", "r") as config_file:
    attribute_config = json.load(config_file)

In [None]:
# Read the per-variable settings from the project's JSON config.
with open("../src/conf/variable_config.json", "r") as config_file:
    variable_config = json.load(config_file)

In [None]:
# Narrow the full domain configuration down to the 'west_africa' entry.
# NOTE(review): this rebinds domain_config to its own sub-entry, so the cell
# cannot be re-run without first re-loading the JSON file above.
domain_config = domain_config['west_africa']

In [None]:
# Keep only the variables that the selected domain lists under "variables".
variable_config = {
    var_name: var_settings
    for (var_name, var_settings) in variable_config.items()
    if var_name in domain_config["variables"]
}

In [None]:
# Obtain the two directory dictionaries for this domain from the project helper.
# Going by the helper's name it presumably also creates missing directories - TODO confirm.
reg_dir_dict, glob_dir_dict = helper_modules.set_and_make_dirs(domain_config)

In [None]:
# Pull the "syr_calib" / "eyr_calib" entries out of the domain configuration
# (presumably the start and end year of the calibration period - TODO confirm).
syr_calib, eyr_calib = domain_config["syr_calib"], domain_config["eyr_calib"]

In [None]:
# Request a Dask cluster and a connected client through the project helper.
# NOTE(review): the meaning of ('cclake', 1, 40) is defined by helper_modules.getCluster,
# presumably queue name, node count and cores per node - confirm against the helper.
client, cluster = helper_modules.getCluster('cclake', 1, 40)
        
# Check package versions across client, scheduler and workers.
client.get_versions(check=True)
# Start the scheduler's Active Memory Manager.
client.amm.start()
         
print(f"Dask dashboard available at {client.dashboard_link}")

In [None]:
# Shut down the client first, then the cluster it was connected to.
# NOTE(review): on a top-to-bottom run this tears down the cluster created in the
# previous cell; the next cell then connects to a scheduler via a scheduler file.
client.close()
cluster.close()

In [None]:
# Connect to an already running Dask scheduler described by a scheduler file.
# NOTE(review): the path is an absolute, user-specific location; consider making it configurable.
client = Client(scheduler_file='/pd/home/lorenz-c/scheduler_test.json')

In [None]:
# Display the dashboard URL of the currently connected client.
client.dashboard_link

In [None]:
# Resolve the four input locations used below (raw forecast, post-processed,
# reforecast and reference data) for variable 'tp'.
# NOTE(review): 4 and 2016 presumably select the forecast month and year - TODO confirm.
raw_full, pp_full, refrcst_full, ref_full = helper_modules.set_input_files(
    domain_config,
    reg_dir_dict,
    4,
    2016,
    'tp',
)

In [None]:
# Derive the coordinate information from the raw forecast input via the project helper.
coords = helper_modules.get_coords_from_frcst(raw_full)

In [None]:
# Build the global attributes from the attribute config, the domain's
# "bc_params" entry, the forecast coordinates and the domain name.
global_attributes = helper_modules.update_global_attributes(
    attribute_config,
    domain_config["bc_params"],
    coords,
    'west_africa',
)

In [None]:
# Build the output encoding from the variable settings and the forecast coordinates.
encoding = helper_modules.set_encoding(variable_config, coords)

In [None]:
# Reference data: one dask chunk per grid cell, holding that cell's complete
# time series. A chunk size of -1 means "the whole dimension in a single
# chunk", so the store no longer has to be opened a first time just to read
# len(time).
ds_obs = xr.open_zarr(
    ref_full,
    chunks={"time": -1, "lat": 1, "lon": 1},
    consolidated=False,
)
# Start loading 'tp' on the cluster; persist() returns the persisted array.
da_obs = ds_obs['tp'].persist()
#da_obs = da_obs.isel(lat=np.arange(100, 130), lon=np.arange(100, 130))

In [None]:
# Reforecast data: one dask chunk per grid cell, holding all time steps and
# all ensemble members of that cell. A chunk size of -1 means "the whole
# dimension in a single chunk", which avoids opening the store twice only to
# look up len(time) and len(ens).
ds_mdl = xr.open_zarr(
    refrcst_full,
    chunks={"time": -1, "ens": -1, "lat": 1, "lon": 1},
    consolidated=False,
)
# Kept lazy here (no persist), unlike the reference and forecast arrays.
da_mdl = ds_mdl['tp']
#da_mdl = da_mdl.isel(lat=np.arange(100, 130), lon=np.arange(100, 130))

In [None]:
# Raw forecast: one dask chunk per grid cell, holding all time steps and all
# ensemble members of that cell. A chunk size of -1 means "the whole dimension
# in a single chunk". The earlier version opened the file once more with
# xr.open_dataset only to read the dimension lengths and never closed that
# handle; a single open avoids the leak.
ds_pred = xr.open_mfdataset(
    raw_full,
    chunks={"time": -1, "ens": -1, "lat": 1, "lon": 1},
    parallel=True,
    engine="netcdf4",
)
# Start loading 'tp' on the cluster; persist() returns the persisted array.
da_pred = ds_pred['tp'].persist()

#da_pred = da_pred.isel(lat=np.arange(100, 130), lon=np.arange(100, 130))

In [None]:
# The bare call `da.from_delayed()` raises TypeError (its required arguments
# are missing) and stops a Restart & Run All. Show the function's documentation
# instead, which runs cleanly.
help(da.from_delayed)

In [None]:
# Lazy zero-filled output array laid out as (time, lat, lon, ens), with one
# dask chunk per time step covering the full lat/lon/ens extent.
out_shape = (
    len(da_pred.time),
    len(da_pred.lat),
    len(da_pred.lon),
    len(da_pred.ens),
)
pred_out = da.zeros(shape=out_shape, chunks=(1,) + out_shape[1:])

In [None]:
import importlib

In [None]:
import bc_module_v2

In [None]:
# Ship the bias-correction module to the workers. Upload the file that was
# actually imported (it may have been found via sys.path rather than the current
# working directory), so workers run the same code as this kernel.
client.upload_file(bc_module_v2.__file__)

In [None]:
# Quick look at the reference series of the first grid cell (index 0 along lon and lat).
da_obs.isel({"lon": 0, "lat": 0})

In [None]:
import dask.array as da

In [None]:
# Bias-correct the forecast one time step at a time and write each result into
# pred_out, which is laid out as (time, lat, lon, ens).
#
# NOTE(review): only the first 10 time steps are processed, while pred_out spans
# every forecast time step; the remaining slices stay zero. Confirm this limit
# is only meant for testing.

# Keyword arguments handed to bc_module are the same for every time step.
bc_kwargs = {
    "domain_config": domain_config,
    "precip": variable_config['tp']["isprecip"],
}

for timestep in range(0, 10):

    # Time labels of the reference / reforecast data to use for this time step.
    intersection_day_obs, intersection_day_mdl = bc_module_v2.get_intersect_days(
        timestep, domain_config, da_obs.time, da_mdl.time, da_pred.time
    )

    da_obs_sub = da_obs.loc[dict(time=intersection_day_obs)]
    da_mdl_sub = da_mdl.loc[dict(time=intersection_day_mdl)]
    # Collapse ensemble members and time steps into one sample dimension.
    da_mdl_sub = da_mdl_sub.stack(ens_time=("ens", "time"), create_index=True)
    # drop_vars replaces the deprecated DataArray.drop for removing a coordinate.
    da_mdl_sub = da_mdl_sub.drop_vars("time")

    da_pred_sub = da_pred.isel(time=timestep)

    # vectorize=True applies bc_module once per grid cell; each call receives the
    # forecast along "ens", the reference along "time" and the reforecast along
    # "ens_time", and returns values along "ens".
    out = xr.apply_ufunc(
        bc_module_v2.bc_module,
        da_pred_sub,
        da_obs_sub,
        da_mdl_sub,
        kwargs=bc_kwargs,
        input_core_dims=[["ens"], ["time"], ["ens_time"]],
        output_core_dims=[["ens"]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[np.float64],
    )
    print(f"Timestep {timestep}")
    pred_out[timestep, :, :, :] = out


In [None]:
# persist() does not act in place: it returns a new array backed by the
# computed chunks. Rebind the name so the later cells reuse those results
# instead of recomputing the whole graph.
pred_out = pred_out.persist()
pred_out

In [None]:
# Wrap the corrected values in a Dataset with variable 'tp' on
# (time, lat, lon, ens), reusing the coordinates of the raw forecast.
da_out = xr.Dataset(
    data_vars={"tp": (["time", "lat", "lon", "ens"], pred_out)},
    coords={
        "time": da_pred.time,
        "ens": da_pred.ens,
        "lon": da_pred.lon,
        "lat": da_pred.lat,
    },
    attrs={
        "description": "This is a small stupid test...",
        "nits": "And were going to kick some ass...",
    },
)

In [None]:
# Write the test Dataset to a NetCDF file in the current working directory.
da_out.to_netcdf('test.nc')

In [None]:
# Write the bare dask array (without coordinates or attributes) to a Zarr store
# in the current working directory.
pred_out.to_zarr('test_1.zarr')

In [None]:
# Per-grid-cell variant: every (lat, lon) cell yields a delayed (time, ens)
# block that is written to pred_out[:, :, i, j].
#
# The earlier pred_out allocation is laid out as (time, lat, lon, ens), which
# does not match this indexing or the (time, ens, lat, lon) Dataset built from
# the result below. Allocate an output array in the layout this loop writes,
# with one chunk per grid cell so each assignment touches a single chunk.
#
# NOTE(review): bc_module is called here with positional indices and flags,
# unlike the apply_ufunc call earlier in this notebook. Confirm that
# bc_module_v2.bc_module still accepts this signature and returns a Delayed.
n_time, n_ens = len(da_pred.time), len(da_pred.ens)
n_lat, n_lon = len(da_pred.lat), len(da_pred.lon)

pred_out = da.zeros(
    shape=(n_time, n_ens, n_lat, n_lon),
    chunks=(n_time, n_ens, 1, 1),
)

for i in range(n_lat):
    for j in range(n_lon):
        out = bc_module_v2.bc_module(da_pred, da_obs, da_mdl, i, j, 'fluff', domain_config, True)
        pred_out[:, :, i, j] = da.from_delayed(out, shape=(n_time, n_ens), dtype=float)
        print(i, j)

In [None]:
# persist() does not act in place: it returns a new array backed by the
# computed chunks. Rebind the name so the Dataset built next reuses those
# results instead of recomputing the whole graph.
pred_out = pred_out.persist()
pred_out

In [None]:
# Wrap the per-grid-cell results in a Dataset with variable 'tp' on
# (time, ens, lat, lon), reusing the coordinates of the raw forecast.
da_out = xr.Dataset(
    data_vars={"tp": (["time", "ens", "lat", "lon"], pred_out)},
    coords={
        "time": da_pred.time,
        "ens": da_pred.ens,
        "lon": da_pred.lon,
        "lat": da_pred.lat,
    },
    attrs={
        "description": "This is a small stupid test...",
        "nits": "And were going to kick some ass...",
    },
)

In [None]:
# Write the Dataset to a Zarr store.
# NOTE(review): the target is an absolute, site-specific path; consider making it configurable.
da_out.to_zarr('/bg/data/NCZarr/bcsd_test.zarr')