# Test Dask parallel compute with POCLOUD (US-WEST-2)
## Compute global mean ocean SSH from ECCO Version 4 Release 4

Written by Ian Fenty, revised by Jinbo Wang.

In [None]:
from dask.distributed import Client
# used the dask-labextension to start the LocalCluster
# NOTE: the scheduler address below is hard-coded to one particular session;
# replace it with the address of your own running cluster.
client = Client("tcp://127.0.0.1:65532")
# last expression of the cell: displays the client summary in the notebook
client

In [None]:
import pandas as pd
import sys
import matplotlib as mpl

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from pathlib import Path

import matplotlib.pyplot as plt
import json
import time

from dask.distributed import get_worker
from dask import delayed

from pprint import pprint
import requests
import s3fs
import os
import warnings
warnings.filterwarnings('ignore')

In [None]:
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from itertools import repeat
from os.path import expanduser, basename, isfile, isdir, join

## Subroutines

In [None]:
def compute_GMSL(SSH, grid_area, total_grid_cell_area):
    '''Compute the area-weighted global mean of SSH.

    SSH is multiplied by grid_area, summed over the 'latitude' and
    'longitude' dimensions, divided by total_grid_cell_area, and the
    result is evaluated eagerly with .compute() before being returned.
    '''
    area_weighted_sum = (SSH * grid_area).sum(dim=['latitude','longitude'])
    return (area_weighted_sum / total_grid_cell_area).compute()
def compute_GMSL_trend(GMSL):
    '''Fit a straight line to a time series along its "time" dimension.

    Returns a tuple (trend, fit): `trend` is the degree-1 polynomial
    evaluated at GMSL.time, and `fit` is the full result of
    GMSL.polyfit(..., full=True), which holds the coefficients.
    '''
    fit = GMSL.polyfit(dim="time", deg=1, full=True)
    fitted_line = xr.polyval(coord=GMSL.time, coeffs=fit.polyfit_coefficients)
    return fitted_line, fit
def download(source, target, redownload_existing=False):
    '''Download `source` to the local path `target` using wget.

    Parameters
    ----------
    source : URL to fetch.
    target : local file path (str or Path) to write to.
    redownload_existing : when true, fetch even if `target` already exists.

    Returns
    -------
    `target`, unchanged, whether or not a download took place.

    wget is invoked through subprocess with an argument list rather than
    the IPython `!` shell escape, so the function is valid plain Python and
    paths containing spaces or shell metacharacters are passed through
    verbatim. A non-zero wget exit status is reported but not raised.
    '''
    import subprocess  # local import: keeps this block self-contained

    print(target)
    print(source)
    if redownload_existing or not os.path.isfile(target):
        completed = subprocess.run(
            ['wget', '--quiet', '--continue',
             '--output-document', str(target), str(source)],
            check=False,
        )
        if completed.returncode != 0:
            # --output-document creates the file even on failure, so warn
            # that what is on disk may be empty or truncated.
            print(f'wget exited with status {completed.returncode}; '
                  f'{target} may be missing or incomplete')
    else:
        print('not re-downloading')
    return target
def download_file(url: str, out: str, force: bool=False):
    """
    Download one file over HTTPS into an existing local directory.

    url (str): the HTTPS url from which the file will download
    out (str): the local directory into which the file will download
    force (bool): download even if the file exists locally already

    Returns the number of bytes downloaded (the content-length header when
    the server sends one, otherwise the number of bytes written), or 0 when
    the download is skipped because the file already exists.

    Raises Exception if `out` is not a directory or the server answers
    with a non-2xx status code.
    """
    if not isdir(out):
        raise Exception(f"Output directory doesnt exist! ({out})")

    target_file = join(out, basename(url))

    # if the file has already been downloaded, skip
    if isfile(target_file) and not force:
        print('file exists, and force=False, not re-downloading')
        return 0

    # Write to a sibling ".part" file and rename on success, so that an
    # interrupted transfer never leaves a truncated file that a later run
    # would mistake for a completed download.
    partial_file = target_file + '.part'

    # stream=True keeps the body out of memory until iter_content reads it
    with requests.get(url, stream=True) as r:
        if not 200 <= r.status_code < 300:
            raise Exception(r.text)

        bytes_written = 0
        try:
            with open(partial_file, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        bytes_written += f.write(chunk)
            os.replace(partial_file, target_file)
        finally:
            # only still present if the transfer or the rename failed
            if isfile(partial_file):
                os.remove(partial_file)

        return int(r.headers.get('content-length', 0)) or bytes_written
# download a list of files
def download_files(dls, out_dir=None, max_workers=12):
    """Download every URL in `dls` concurrently and print a summary.

    dls: iterable of HTTPS urls, each handed to download_file
    out_dir: directory to download into; defaults to the module-level
        `download_dir`, which is what this function always used before
    max_workers: number of threads used for concurrent downloads

    Prints the total size downloaded (megabytes) and the average speed.
    Returns None.
    """
    if out_dir is None:
        out_dir = download_dir
    dls = list(dls)  # allow any iterable; len() is needed for the progress bar

    start_time = time.time()

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(tqdm(executor.map(download_file, dls, repeat(out_dir)), total=len(dls)))

    # download_file returns a byte count per file (0 for skipped files)
    total_download_size_in_bytes = sum(results)
    total_time = time.time() - start_time

    print('\n=====================================')
    print(f'total downloaded: {np.round(total_download_size_in_bytes/1e6,2)} MB')
    print(f'avg download speed: {np.round(total_download_size_in_bytes/1e6/total_time,2)} MB/s')

## Define local disk directories

Set `pth_hm` below to your own home (working) directory before running this cell.

In [None]:
# output directory

# change pth_hm to your folder
pth_hm='/home/jpluser/Dask_test/'

# All working directories live under pth_hm so that changing pth_hm is the
# only edit needed; parents=True creates missing intermediate folders.
output_dir = Path(pth_hm) / 'ECCO_global_mean_TS'
output_dir.mkdir(parents=True, exist_ok=True)

# grid geometry files
ECCO_grid_dir = Path(pth_hm) / 'ECCO_grids'
ECCO_grid_dir.mkdir(parents=True, exist_ok=True)

# staging directory
download_dir = output_dir / 'tmp_dl'
download_dir.mkdir(parents=True, exist_ok=True)

## Connect S3 file system

Get keys, pass credentials

In [None]:
%%capture
import requests

def store_aws_keys(endpoint: str="https://archive.podaac.earthdata.nasa.gov/s3credentials"):
    """Fetch temporary S3 credentials from `endpoint`.

    Returns a dict with the keys 'AccessKeyId', 'SecretAccessKey',
    'SessionToken' and 'expiration'.

    Raises requests.HTTPError when the endpoint answers with an error
    status, and KeyError when an expected field is missing from the reply.
    """
    # The original passed "w" as a second positional argument, which
    # requests interprets as query parameters; the endpoint takes none.
    with requests.get(endpoint) as r:
        r.raise_for_status()
        payload = r.json()

    # Look fields up by name instead of relying on the order of the JSON
    # object. NOTE(review): field names follow the camelCase used by the
    # original local variables — confirm against the endpoint's response.
    return {
        'AccessKeyId': payload['accessKeyId'],
        'SecretAccessKey': payload['secretAccessKey'],
        'SessionToken': payload['sessionToken'],
        'expiration': payload['expiration'],
    }



In [None]:
def refresh_s3():
    """Build an S3FileSystem from freshly fetched temporary credentials.

    Calls store_aws_keys() for new keys, opens an s3fs filesystem in the
    us-west-2 region with them, prints when the session token expires,
    and returns the filesystem.
    """
    credentials = store_aws_keys()
    expires_at = credentials['expiration']

    filesystem = s3fs.S3FileSystem(
        client_kwargs={'region_name':'us-west-2'},
        key=credentials['AccessKeyId'],
        secret=credentials['SecretAccessKey'],
        token=credentials['SessionToken'],
    )

    print(f"\nThe current session token expires at {expires_at}.\n")
    return filesystem


## Download ECCO grid geometry to local disk

In [None]:
ECCO_grid_filename = 'GRID_GEOMETRY_ECCO_V4r4_latlon_0p50deg.nc'
ECCO_grid_url = "https://archive.podaac.earthdata.nasa.gov/podaac-ops-cumulus-protected/ECCO_L4_GEOMETRY_05DEG_V4R4/"

# remote file to fetch and the local path it is saved to
source = f"{ECCO_grid_url}{ECCO_grid_filename}"
target = ECCO_grid_dir.joinpath(ECCO_grid_filename)

# fetch the file (skipped if already present), then read it fully into memory
local_grid_fname = download(source, target)
ecco_grid = xr.open_dataset(local_grid_fname).load()
print(ecco_grid.data_vars)

## Calculate ECCO grid cell volumes and total ocean volume


In [None]:
# area is grid cell area * land/ocean mask
# volume is grid cell thickness (drF) * area (rA) * partial cell factors (hFacC) * land/ocean mask (maskC)

# surface mask only (Z=0) for the area; full 3-D mask for the volume
grid_cell_area = ecco_grid.area * ecco_grid.maskC.isel(Z=0)
grid_cell_vol = ecco_grid.drF * ecco_grid.area * ecco_grid.hFacC * ecco_grid.maskC

total_grid_cell_area= grid_cell_area.sum()
total_grid_cell_area.name = 'Total ECCO ocean area'

total_grid_cell_vol = grid_cell_vol.sum()
total_grid_cell_vol.name = 'Total ECCO ocean volume'

# Unit conversion for the printed summary.
# assumes `area` is in m^2 and `drF` in m — TODO confirm against the grid file attributes.
#   1 billion km^2 = 1e9 * 1e6 m^2 = 1e15 m^2
#   1 billion km^3 = 1e9 * 1e9 m^3 = 1e18 m^3
print(f'total grid cell area  {total_grid_cell_area.values/1e15:0.3g} billion km$^2$')
print(f'total grid cell volume  {total_grid_cell_vol.values/1e18:0.3g} billion km$^3$')

In [None]:
# quick visual check of the masked grid cell area computed above
grid_cell_area.plot()

## Find S3 Addresses to ECCO Fields

In [None]:
# PO.DAAC's 'short name' is an identifier for the dataset;
# it is used below both for the CMR collection query and as the S3 folder name
ShortName = 'ECCO_L4_SSH_05DEG_MONTHLY_V4R4'

In [None]:
# Look up the collection's concept id in CMR by its 'short name'
response = requests.get(
    'https://cmr.earthdata.nasa.gov/search/collections.umm_json',
    params=dict(provider="POCLOUD", ShortName=ShortName, page_size=1),
)

# page_size is 1, so the first item is the only one returned
ummc = response.json()['items'][0]
ccid = ummc['meta']['concept-id']
print(f'collection id: {ccid}')

In [None]:
# glob to find the NetCDF files stored under this dataset's short name
# set `year` to e.g. 1992 to restrict the search, or '*' for all files
year = '*'

start_time = time.time()

ss = "podaac-ops-cumulus-protected/" + ShortName + '/*'+ str(year) + '*.nc'
# With year='*' the pattern above contains '***'. Collapse runs of '*' into a
# single wildcard: '**' means "recursive" to fsspec-style globbing and newer
# versions reject it when it is not a whole path component.
while '**' in ss:
    ss = ss.replace('**', '*')

s3 = refresh_s3()
ECCO_s3_files = s3.glob(ss)

print(f'time to find urls: { time.time() - start_time} s\n')

# make a list of just the filenames
ECCO_files = [f.split('/')[-1] for f in ECCO_s3_files]

# show the first and last match as a sanity check
if ECCO_s3_files:
    pprint(ECCO_files[0])
    pprint(ECCO_files[-1])
    pprint(ECCO_s3_files[0])
    pprint(ECCO_s3_files[-1])
else:
    print(f'no files matched {ss}')

In [None]:
# convert list of s3 files to urls by prefixing each bucket/key path
# with the PO.DAAC archive host
ECCO_s3_files_as_http = ['https://archive.podaac.earthdata.nasa.gov/' + f for f in ECCO_s3_files]

# Method \#1: Calculate using S3 Direct Access fields

In [None]:
# download every granule over HTTPS into download_dir (existing files are skipped)
download_files(ECCO_s3_files_as_http)

### Direct File Access within EC2, parallel=True (Success)

In [None]:
# time how long it takes to open the locally downloaded granules
start_time = time.time()
# sorted list of the files previously downloaded into download_dir
local_files = np.sort(list(download_dir.glob('*nc')))

#Total number of granules for the monthly fields of 26 years is 312

num_granules = 312

ECCO_SSH_ds = xr.open_mfdataset(
    paths=local_files[:num_granules],
    coords='minimal', 
    compat='override', 
    data_vars='minimal',
    decode_cf=True,
    join='left',
    parallel = True
)
# NOTE(review): the dataset is closed here but still used by the next cell;
# presumably xarray reopens the files lazily on compute — confirm.
ECCO_SSH_ds.close()

tt = time.time() - start_time

print(f'open time = {tt:0.3g} s')
# NOTE: divides by num_granules even if fewer local files were found
print(f'open time per granule (n={num_granules}) = {tt/num_granules:0.3g} s \n') 

In [None]:
%%time 
# global-mean SSH time series from the local files, then its linear trend
# NOTE(review): the weights are the unmasked ecco_grid.area while the
# denominator is the masked total area; presumably SSH is NaN over land so
# the sum skips those cells — confirm.
GMSL = compute_GMSL(ECCO_SSH_ds.SSH, ecco_grid.area, total_grid_cell_area)
GMSL_trend, trend_params = compute_GMSL_trend(GMSL)

In [None]:
# time series with the fitted linear trend overlaid in red
GMSL.plot()
GMSL_trend.plot(color='r')
plt.grid()

In [None]:
# rough trend: (final - initial) / elapsed time in years
# n monthly samples span (n - 1) months between the first and the last one,
# so the elapsed time is (n - 1)/12 years rather than n/12.
# the factor 1000 presumably converts m to mm — confirm SSH units
elapsed_years = (len(GMSL.time) - 1) / 12
GMSL_rough_trend = 1000*(GMSL_trend[-1]-GMSL_trend[0])/elapsed_years
print(f'{np.round(GMSL_rough_trend.values,3)} mm/yr')

# Method \#2: Direct S3 Access, parallel=False

`parallel` must be set to `False` here; otherwise `open_mfdataset` hangs.

In [None]:
## read num_granules number of files
num_granules = 32

# update s3 credentials
s3 = refresh_s3()

# open each file using s3
# (file-like objects, one per granule; opening them is not included in the timing below)
paths=[s3.open(f) for f in ECCO_s3_files[:num_granules]]

start_time = time.time()

ECCO_SSH_ds = xr.open_mfdataset(
    paths=paths,
    coords='minimal', 
    compat='override', 
    data_vars='minimal',
    decode_cf=True,
    join='left',
    parallel=False
)
# NOTE(review): the dataset is closed here but still used by later cells;
# presumably the data can still be read afterwards — confirm.
ECCO_SSH_ds.close()

tt = time.time() - start_time

print(f'open time = {tt:0.3g} s')
print(f'open time per granule (n={num_granules}) = {tt/num_granules:0.3g} s \n') 

## WITH ATTACHED DASK CLUSTER AND PARALLEL=FALSE
#===============================================
# 3 files  0.4 s (.14 s per)
# 12 files 1.7 (.14s per)
# 24 files 3.6 (0.15s per)
# 36 files 5.4 (0.15s per)
# ...
# 312 files 54 s (0.18s per)

In [None]:
# verify we got something good: list the variables and dimension sizes
pprint(ECCO_SSH_ds.data_vars)
pprint(ECCO_SSH_ds.dims)

In [None]:
%%time 
# global-mean SSH time series from the S3-backed dataset, then its linear trend
# (the fit results are stored as GMSL_params here, trend_params in Method 1)
GMSL = compute_GMSL(ECCO_SSH_ds['SSH'], ecco_grid.area, total_grid_cell_area)
GMSL_trend, GMSL_params = compute_GMSL_trend(GMSL)

In [None]:
# time series with the fitted linear trend overlaid in red
GMSL.plot()
GMSL_trend.plot(color='r')
plt.grid()

# rough trend: (final - initial) / elapsed time in years
# n monthly samples span (n - 1) months between the first and the last one,
# so the elapsed time is (n - 1)/12 years rather than n/12.
# the factor 1000 presumably converts m to mm — confirm SSH units
elapsed_years = (len(GMSL.time) - 1) / 12
GMSL_rough_trend = 1000*(GMSL_trend[-1]-GMSL_trend[0])/elapsed_years
print(f'rough trend = {np.round(GMSL_rough_trend.values,3)} mm/yr')

# Method 3: Direct S3 Access, parallel=True

In [None]:
## READ THIS MANY FILES
num_granules = 12

# update s3 credentials
s3 = refresh_s3()

# open each file using s3
# (file-like objects, one per granule; opening them is not included in the timing below)
paths=[s3.open(f) for f in ECCO_s3_files[:num_granules]]

start_time = time.time()

# if the files in 'paths' were LOCAL this does work,
# when the files are on S3, crashes with error message:

#ValueError: did not find a match in any of xarray's currently installed IO backends ['netcdf4', 'h5netcdf', 'scipy']. Consider explicitly selecting one of the installed engines via the ``engine`` parameter, or installing additional IO dependencies, see:
#https://docs.xarray.dev/en/stable/getting-started-guide/installing.html
#https://docs.xarray.dev/en/stable/user-guide/io.html

# same call as Method 2 except for parallel=True
ECCO_SSH_ds = xr.open_mfdataset(
    paths=paths,
    coords='minimal', 
    compat='override', 
    data_vars='minimal',
    decode_cf=True,
    join='left',
    parallel=True
)
ECCO_SSH_ds.close()

tt = time.time() - start_time

print(f'open time = {tt:0.3g} s')
print(f'open time per granule (n={num_granules}) = {tt/num_granules:0.3g} s \n') 

## WITH ATTACHED DASK CLUSTER AND PARALLEL=TRUE
#===============================================
# NOTE(review): the timings below are identical to the PARALLEL=FALSE table
# and conflict with the crash described above; possibly copy-pasted — confirm.
# 3 files  0.4 s (.14 s per)
# 12 files 1.7 (.14s per)
# 24 files 3.6 (0.15s per)
# 36 files 5.4 (0.15s per)
# ...
# 312 files 54 s (0.18s per)

# Method 4: Dask delayed mode

In [None]:
def delayed_global_mean(fn, s3, ecco_grid_area, total_grid_cell_area):
    """Open one granule on S3 and compute its global-mean SSH.

    fn: S3 path of the granule, opened with s3.open
    s3: filesystem object providing .open(fn)
    ecco_grid_area: grid cell areas used as weights
    total_grid_cell_area: denominator of the weighted mean

    Returns (GMSL, tt): the value returned by compute_GMSL for this
    granule and the wall-clock seconds spent in this call.
    """
    d_start_time = time.time()

    # works when we open with open_dataset, dask client, and files on S3
    # (fails with open_mfdataset, dask client, and files on S3).
    # Both the S3 file handle and the dataset are closed on exit;
    # compute_GMSL calls .compute(), so the result is evaluated before
    # the files are released.
    with s3.open(fn) as s3_file, xr.open_dataset(s3_file) as ECCO_SSH_ds:
        GMSL = compute_GMSL(ECCO_SSH_ds['SSH'], ecco_grid_area, total_grid_cell_area)

    tt = time.time() - d_start_time

    return GMSL, tt

In [None]:
ecco_grid_area= ecco_grid.area
from dask import delayed,compute


In [None]:
num_granules = 312

start_time = time.time()

# build one lazy task per granule; nothing is executed yet
result = [
    delayed(delayed_global_mean)(fn, s3, ecco_grid_area, total_grid_cell_area)
    for fn in ECCO_s3_files[:num_granules]
]

tt = time.time() - start_time

print('append result')
print(f'append result time = {tt:0.3g} s')

print('calculate')

# run all tasks on the cluster and collect the results in an array
GMSL_delayed = np.array(compute(result)).squeeze()

# measured from start_time, so this also includes building the task list
tt = time.time() - start_time

print(f'calc time = {tt:0.3g} s')
print(f'calc time per granule (n={num_granules}) = {tt/num_granules:0.3g} s \n')


# calculation timing
# ==================
#  64:  2.7s, 0.0416s per granule
# 128:  5.6s, 0.0432s per granule
# 312: 13  s, 0.0416s per granule