In [1]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import sys, time
import datetime
import pandas as pd

In [2]:
from tempfile import NamedTemporaryFile, TemporaryDirectory # Creating temporary Files/Dirs
import dask # Distributed data libary
from dask_jobqueue import SLURMCluster # Setting up distributed memories via slurm
from distributed import Client, progress, wait # Libaray to orchestrate distributed resources

In [3]:
# Set some user specific variables
account_name = 'bb1018'
partition = 'compute'
job_name = 'extractSim' # Job name that is submitted via sbatch
memory = '24GiB' # 64GiB, Max memory per node that is going to be used - this depends on the partition
cores = 48 # Max number of cores per that are reserved - also partition dependent
walltime = '07:00:00' #'12:00:00' # Walltime - also partition dependent

In [4]:
scratch_dir = '/scratch/b/b380873/' # Define the users scratch dir
# Create a temp directory where the output of distributed cluster will be written to, after this notebook
# is closed the temp directory will be closed
dask_scratch_dir = TemporaryDirectory(dir=scratch_dir, prefix=job_name)
cluster = SLURMCluster(memory=memory,
                       cores=cores,
                       project=account_name,
                       walltime=walltime,
                       queue=partition,
                       name=job_name,
                       processes=8,
                       scheduler_options={'dashboard_address': ':12435'},
                       local_directory=dask_scratch_dir.name,
                       job_extra=[f'-J {job_name}', 
                                  f'-D {dask_scratch_dir.name}',
                                  f'--begin=now',
                                  f'--output={dask_scratch_dir.name}/LOG_cluster.%j.o',
                                  f'--output={dask_scratch_dir.name}/LOG_cluster.%j.o'
                                 ],
                       interface='ib0')

In [5]:
print(cluster.job_script())

#!/usr/bin/env bash

#SBATCH -J dask-worker
#SBATCH -p compute
#SBATCH -A bb1018
#SBATCH -n 1
#SBATCH --cpus-per-task=48
#SBATCH --mem=24G
#SBATCH -t 07:00:00
#SBATCH -J extractSim
#SBATCH -D /scratch/b/b380873/extractSimtnq2g90a
#SBATCH --begin=now
#SBATCH --output=/scratch/b/b380873/extractSimtnq2g90a/LOG_cluster.%j.o
#SBATCH --output=/scratch/b/b380873/extractSimtnq2g90a/LOG_cluster.%j.o

JOB_ID=${SLURM_JOB_ID%;*}

/pf/b/b380459/conda-envs/Nawdex-Hackathon/bin/python3 -m distributed.cli.dask_worker tcp://10.50.40.21:45187 --nthreads 6 --nprocs 8 --memory-limit 3.22GB --name name --nanny --death-timeout 60 --local-directory /scratch/b/b380873/extractSimtnq2g90a --interface ib0



In [6]:
cluster.scale(jobs=1)
cluster

VBox(children=(HTML(value='<h2>extractSim</h2>'), HBox(children=(HTML(value='\n<div>\n  <style scoped>\n    .d…

In [7]:
dask_client = Client(cluster)
dask_client

0,1
Client  Scheduler: tcp://10.50.40.21:45187  Dashboard: http://10.50.40.21:8787/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [9]:
# Which simulation and variable do you want to look at? How many synthetic trajectories to generate?
global sim_acronym
global n

# Over which lat-lon and altitude interval should we extract data?
global ll_interval
global alt_interval

sim_acronym = '0V2M0A0R'
n = 10
ll_interval = 0.75
alt_interval = 1

In [10]:
# Input a np.datetime64 and round it to the nearest 10 minutes.
def timeround10(dt):
    dt_not_np = pd.to_datetime(dt)
    b = round(dt_not_np.minute,-1)
    if b == 60:
        return_time = datetime.datetime(2017, 8, 8, dt_not_np.hour + 1, 0)
    else:
        return_time = datetime.datetime(2017, 8, 8, dt_not_np.hour, int(b))
    return return_time

In [11]:
# Extract <n> random indices along all dimensions of the variable <var>
# I feel that there is a more Pythonic way of doing this but it works for now
def randIndx(var):
    dim_rand = np.empty((n, 4), dtype='int')
    t_size, alt_size, lat_size, lon_size = var.shape
    dim_rand[:, 0] = np.random.randint(low=0, high=t_size, size=n)
    dim_rand[:, 1] = np.random.randint(low=0, high=alt_size, size=n)
    dim_rand[:, 2] = np.random.randint(low=0, high=lat_size, size=n)
    dim_rand[:, 3] = np.random.randint(low=0, high=lon_size, size=n)
    return dim_rand

In [12]:
# Input the existing syn_traj Dataset along with time, pressure, lat, and lon from the flight track
def extractSim(syn_traj, flight_time, flight_pressure, flight_lat, flight_lon):
    # Translate the variable key to its corresponding file
    sim_dir = '/work/bb1018/b380873/model_output/ICON/'
    var_ICON = xr.open_dataset(sim_dir + 'ICON_3D_F10MIN_icon_tropic_0V2M0A0R_PL2.nc')

    # Find the nearest whole 10-min time.
    flight_time_approx = timeround10(flight_time.values)

    # Construct the time window to extract.
    early_time = flight_time_approx - datetime.timedelta(minutes=30)
    late_time = flight_time_approx + datetime.timedelta(minutes=30)

    # Find the indices for levels above and below the closest match.
    basedir = '/work/bb1018/b380873/tropic_vis/remapping/'
    sim_pressures = np.loadtxt(basedir + 'PMEAN_48-72.txt')
    i = np.argmin(np.abs(flight_pressure - sim_pressures))
    if i < 1 or i > 117:
        raise Exception('Flight pressure outside of simulation range.')
    var_ICON = var_ICON.isel( plev=slice(i-alt_interval, i+alt_interval+1) )

    # Define the lat-lon interval to extract
    var_ICON = var_ICON.sel( time=slice(early_time, late_time),
                             lat=slice(flight_lat-ll_interval, flight_lat+ll_interval),
                             lon=slice(flight_lon-ll_interval, flight_lon+ll_interval) )

    dim_rand = randIndx(var_ICON['qv'])
    for k, dims in enumerate(dim_rand):
        for v in syn_traj.variables:
            if v != 'ntraj' and v != 'time':
                syn_traj[v].loc[dict(ntraj=k+1, time=flight_time)] = var_ICON[v].isel(time=dims[0], plev=dims[1], lat=dims[2], lon=dims[3])
                    
    return syn_traj

In [13]:
# Load the observational data
basedir = '/work/bb1018/b380873/tropic_vis/obs/'
fi = basedir + 'stratoclim2017.geophysika.0808_1.filtered_per_sec.nc'
Stratoclim = xr.open_dataset(fi)
flight_times = Stratoclim['time']

# <j> is the first iteration for which there are ICON high-resolution values available.
j = 1942 # 1342
tt = flight_times.shape[0] - j

# Load the ICON values
ICON = xr.open_dataset('/work/bb1018/b380873/model_output/ICON/ICON_3D_F10MIN_icon_tropic_0V2M0A0R_PL2.nc')

# Initiate the synthetic trajectory Dataset
syn_traj = xr.Dataset( data_vars=dict(
                            temp=( ["time", "ntraj"], np.empty([tt, n]) ),
                            omega=( ["time", "ntraj"], np.empty([tt, n]) ),
                            air_pressure=( ["time", "ntraj"], np.empty([tt, n]) ),
                            qv=( ["time", "ntraj"], np.empty([tt, n]) ),
                            qc=( ["time", "ntraj"], np.empty([tt, n]) ),
                            qi=( ["time", "ntraj"], np.empty([tt, n]) ),
                            qs=( ["time", "ntraj"], np.empty([tt, n]) ),
                            qg=( ["time", "ntraj"], np.empty([tt, n]) ),
                        ),
                       coords=dict(
                           time=flight_times[j:], ntraj=np.arange(1, n+1))
                     )

# Set the variable attributes as in the standard ICON output file.
for v in syn_traj.variables:
    if v != 'time' and v!= 'ntraj':
        syn_traj[v].attrs["long_name"] = ICON[v].long_name
        syn_traj[v].attrs["units"] = ICON[v].units
        syn_traj[v].attrs["standard_name"] = ICON[v].standard_name

syn_traj['ntraj'].attrs["long_name"] = 'Trajectory ID'

In [14]:
%%time

for flight_iter, flight_time in enumerate(flight_times[j:]):
    if flight_iter%500 == 0:
        print(flight_iter)
    flight_pressure = Stratoclim['BEST:PRESS'].sel(time=flight_time).values*100 # [Pa]
    flight_lat = Stratoclim['BEST:LAT'].sel(time=flight_time).values
    flight_lon = Stratoclim['BEST:LON'].sel(time=flight_time).values

    # Based on the flight values, load the relevant chunk of simulations
    syn_traj = dask.delayed(extractSim)(syn_traj, flight_time, flight_pressure, flight_lat, flight_lon)

syn_traj.to_netcdf(path='/work/bb1018/b380873/model_output/ICON/ICON_syn_traj2.nc')

0
500
1000
1500
2000
2500
3000
3500
4000
4500
5000
5500
6000
6500
7000
7500
8000
8500
9000
9500
10000
10500
CPU times: user 37.6 s, sys: 664 ms, total: 38.3 s
Wall time: 45.3 s


Delayed('to_netcdf-b797c291-3591-4972-b338-f208ce725e32')

In [None]:
%%time
syn_traj.compute()

  (<xarray.Dataset>
Dimensions:       (ntraj: 10, ti ... dtype=float32))
Consider scattering large objects ahead of time
with client.scatter to reduce scheduler burden and 
keep data on workers

    future = client.submit(func, big_data)    # bad

    big_future = client.scatter(big_data)     # good
    future = client.submit(func, big_future)  # good
