# Introduction to xarray with Dask Parallelization for the Earth Sciences

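xarray only benefits from Dask parallelization when the code is written so that Dask can build and schedule the work, for example by opening files with `open_mfdataset()` rather than looping over them one at a time. This notebook walks through an example.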

In [1]:
import s3fs
import requests

import xarray as xr
from dask.distributed import Client, LocalCluster

import matplotlib.pyplot as plt
%matplotlib inline

## Connect to AWS S3 File System and Locate MUR 1 km File Paths

In [38]:
creds = requests.get('https://archive.podaac.earthdata.nasa.gov/s3credentials').json()
fs_s3 = s3fs.S3FileSystem(
    anon=False,
    key=creds['accessKeyId'],
    secret=creds['secretAccessKey'], 
    token=creds['sessionToken'],
    client_kwargs={'region_name':'us-west-2'}
    )
s3path = "s3://podaac-ops-cumulus-protected/MUR-JPL-L4-GLOB-v4.1/"
fns = fs_s3.glob(s3path+"*.nc")

In [39]:
# Check that S3 connection was successful:
print("total files found = ",len(fns))
print("Example filename: ", fns[0])

total files found =  7714
Example filename:  podaac-ops-cumulus-protected/MUR-JPL-L4-GLOB-v4.1/20020601090000-JPL-L4_GHRSST-SSTfnd-MUR-GLOB-v02.0-fv04.1.nc
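
Each file name begins with a 14-digit timestamp, as in the example above, so a subset of `fns` can be picked by date without opening any files. The following is a minimal sketch, assuming every name follows the `YYYYMMDDHHMMSS-...` pattern of that example. `granule_time` and `example_fns` are hypothetical names introduced for illustration; they are not part of `s3fs`.

```python
from datetime import datetime
from pathlib import PurePosixPath


def granule_time(path):
    """Parse the leading YYYYMMDDHHMMSS timestamp of a MUR file name (assumed pattern)."""
    name = PurePosixPath(path).name  # drop the bucket and prefix
    return datetime.strptime(name.split("-")[0], "%Y%m%d%H%M%S")


# Stand-in for the list returned by fs_s3.glob(); this entry is the example printed above.
example_fns = [
    "podaac-ops-cumulus-protected/MUR-JPL-L4-GLOB-v4.1/"
    "20020601090000-JPL-L4_GHRSST-SSTfnd-MUR-GLOB-v02.0-fv04.1.nc",
]

# Keep only the files whose timestamp falls in June 2002.
selected = [
    f for f in example_fns
    if (granule_time(f).year, granule_time(f).month) == (2002, 6)
]
print(granule_time(example_fns[0]))  # 2002-06-01 09:00:00
```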


## Compute the global mean for a single file

Load and inspect the file

In [8]:
s3_file_obj = fs_s3.open(fns[0], mode='rb')
data = xr.open_dataset(s3_file_obj)
data 

In [11]:
%%time
print(data['analysed_sst'].mean().values)
data.close()

286.64944
CPU times: user 1.46 s, sys: 678 ms, total: 2.14 s
Wall time: 2.14 s


## Compute the global mean of the first 10 files in serial with a for-loop

The loop below opens and reduces one file at a time, so each iteration must finish before the next one starts. This is an example of how not to use xarray, especially when the goal is parallel computing:

In [6]:
%%time

globalmeansst = []
time = []
for f in fns[:10]:
    s3_file_object = fs_s3.open(f, mode='rb')
    data = xr.open_dataset(s3_file_object)
    time.append(data['time'].values)
    globalmeansst.append(data['analysed_sst'].mean().values)

CPU times: user 1min 31s, sys: 29.8 s, total: 2min 1s
Wall time: 4min 33s


In [9]:
globalmeansst

[<xarray.DataArray 'analysed_sst' ()>
 array(286.75073, dtype=float32),
 <xarray.DataArray 'analysed_sst' ()>
 array(286.78214, dtype=float32),
 <xarray.DataArray 'analysed_sst' ()>
 array(286.78564, dtype=float32),
 <xarray.DataArray 'analysed_sst' ()>
 array(286.7816, dtype=float32),
 <xarray.DataArray 'analysed_sst' ()>
 array(286.7549, dtype=float32),
 <xarray.DataArray 'analysed_sst' ()>
 array(286.72464, dtype=float32),
 <xarray.DataArray 'analysed_sst' ()>
 array(286.70187, dtype=float32),
 <xarray.DataArray 'analysed_sst' ()>
 array(286.67288, dtype=float32),
 <xarray.DataArray 'analysed_sst' ()>
 array(286.65546, dtype=float32),
 <xarray.DataArray 'analysed_sst' ()>
 array(286.64944, dtype=float32)]

## Compute global mean of the first 10 files using xarray's `open_mfdataset()` functionality

`open_mfdataset()` opens several files as a single dataset backed by lazily evaluated dask arrays, and this is the syntax that lets us use xarray with dask parallelization. As a first step, we do not start a parallel cluster, but show how `open_mfdataset()` performs the same task as above without the for-loop.

In [41]:
# Gather data from all 10 files:
s3_file_objects = [ fs_s3.open(fns[i], mode='rb') for i in range(10) ]
data_mf = xr.open_mfdataset(s3_file_objects)
data_mf

| | Array | Chunk |
|---|---|---|
| Bytes | 24.14 GiB | 2.41 GiB |
| Shape | (10, 17999, 36000) | (1, 17999, 36000) |
| Dask graph | 10 chunks in 21 graph layers | |
| Data type | float32 numpy.ndarray | |


Then compute the mean over the spatial coordinates, which gives the global mean for each time step:

In [None]:
%%time

data_mf['analysed_sst'].mean(dim=['lat','lon']).load()
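
To see what reducing over `lat` and `lon` does to the array's dimensions, here is a small sketch on synthetic data. The 3 x 4 x 5 array is an invented stand-in for the MUR data, used only so the example runs without S3 access.

```python
import numpy as np
import xarray as xr

# Tiny synthetic stand-in for the MUR array: 3 time steps on a 4 x 5 grid.
sst = xr.DataArray(
    np.arange(3 * 4 * 5, dtype="float32").reshape(3, 4, 5),
    dims=("time", "lat", "lon"),
    name="analysed_sst",
)

# Reducing over lat and lon leaves one value per time step.
per_time_mean = sst.mean(dim=["lat", "lon"])
print(per_time_mean.dims)    # ('time',)
print(per_time_mean.values)  # 9.5, 29.5 and 49.5
```

The result keeps only the `time` dimension, which is why the real computation returns ten values for ten files.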

In [46]:
del data_mf

## Compute global mean of the first 10 files using xarray's `open_mfdataset()` functionality and dask's parallel processing

For simple built-in functions such as `xarray`'s `mean()`, switching to parallel processing is straightforward. All that is needed is a few lines before the data processing code to start a computing cluster and a client. You can specify settings such as the number of workers if you are familiar with them; otherwise, the cluster starts with default settings.

In [44]:
cluster = LocalCluster()
client = Client(cluster)
client

| Client / cluster property | Value |
|---|---|
| Connection method | Cluster object |
| Cluster type | distributed.LocalCluster |
| Dashboard | http://127.0.0.1:8787/status |
| Status | running |
| Using processes | True |
| Scheduler comm | tcp://127.0.0.1:46275 |
| Workers | 8 |
| Total threads | 48 (6 per worker) |
| Total memory | 373.71 GiB (46.71 GiB per worker) |
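
If the defaults do not suit your machine, the worker count, threads per worker, and per-worker memory can be set when the cluster is created. The sketch below uses illustrative values only. It also sets `processes=False` and `dashboard_address=None` so that it runs as a plain script without a dashboard; those two choices are assumptions made for this example, not settings used elsewhere in this notebook.

```python
from dask.distributed import Client, LocalCluster

# Illustrative values only; choose them to fit your machine.
cluster = LocalCluster(
    n_workers=2,             # number of workers
    threads_per_worker=1,    # threads in each worker
    memory_limit="1GiB",     # memory budget per worker
    processes=False,         # thread-based workers, so this also runs as a plain script
    dashboard_address=None,  # skip the diagnostics dashboard in this sketch
)
client = Client(cluster)

n_workers_started = len(cluster.workers)
print(n_workers_started)

# Release the resources when finished.
client.close()
cluster.close()
```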


Then run the same processing code as in the previous section:

In [47]:
%%time

s3_file_objects = [ fs_s3.open(fns[i], mode='rb') for i in range(10) ]
data_mf = xr.open_mfdataset(s3_file_objects)
data_mf['analysed_sst'].mean(dim=['lat','lon']).load()

CPU times: user 3.58 s, sys: 746 ms, total: 4.33 s
Wall time: 55.5 s


In [48]:
del data_mf
client.close()
cluster.close()

## Global Mean of the First 10 Files with Additional Chunking

If you create so many workers that each one lacks the memory to process an entire file (by default each file is a single 2.41 GiB chunk), you can specify smaller chunks:

In [None]:
cluster = LocalCluster()
client = Client(cluster)
client

In [49]:
s3_file_objects = [ fs_s3.open(fns[i], mode='rb') for i in range(10) ]
data_mf_chunked = xr.open_mfdataset(s3_file_objects, chunks={'time':1, 'lat':3000, 'lon':3000})
data_mf_chunked['analysed_sst']

| | Array | Chunk |
|---|---|---|
| Bytes | 24.14 GiB | 34.33 MiB |
| Shape | (10, 17999, 36000) | (1, 3000, 3000) |
| Dask graph | 720 chunks in 21 graph layers | |
| Data type | float32 numpy.ndarray | |
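
The chunk size and chunk count reported above follow directly from the array shape, the requested chunk sizes, and the 4 bytes used by each `float32` value. A short arithmetic sketch using only the numbers shown in this notebook:

```python
import math

shape = (10, 17999, 36000)  # (time, lat, lon) for the first 10 files
chunks = (1, 3000, 3000)    # chunk sizes passed to open_mfdataset
itemsize = 4                # bytes per float32 value

# Size of one full chunk, number of chunks, and size of one whole-file chunk.
chunk_mib = math.prod(chunks) * itemsize / 2**20
n_chunks = math.prod(math.ceil(s / c) for s, c in zip(shape, chunks))
whole_file_gib = math.prod(shape[1:]) * itemsize / 2**30

print(f"{chunk_mib:.2f} MiB per chunk")           # 34.33 MiB
print(n_chunks, "chunks")                         # 720
print(f"{whole_file_gib:.2f} GiB per file chunk")  # 2.41 GiB
```

The count rounds up along each dimension, so the last chunk along `lat` is slightly smaller than 3000 rows.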


In [30]:
%%time

data_mf_chunked['analysed_sst'].mean(dim=['lat','lon']).load()

CPU times: user 3.3 s, sys: 193 ms, total: 3.5 s
Wall time: 24.8 s


In [50]:
del data_mf_chunked
client.close()
cluster.close()

## Mean Over a Region of the Globe

Dask works with xarray so that multiple functions can be chained together, with a single `load()` call at the end to compute everything. In this example, the mean over a region is taken by first subsetting all data to that region ("function 1") and then taking the mean over that region ("function 2").

In [75]:
cluster = LocalCluster()
client = Client(cluster)

In [76]:
s3_file_objects = [ fs_s3.open(fns[i], mode='rb') for i in range(10) ]
data_mf = xr.open_mfdataset(s3_file_objects)

In [72]:
results = data_mf['analysed_sst'].sel(lat=slice(-45,45), lon=slice(-180, -50)) # subset
results = results.mean(dim=['lat','lon']) # mean

In [73]:
%%time

regionalmean = results.load()
regionalmean

CPU times: user 833 ms, sys: 84.6 ms, total: 917 ms
Wall time: 12.3 s
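
The same subset-then-mean pattern can be tried without S3 access. The sketch below uses a synthetic 1-degree grid as a stand-in for the MUR data; the grid, the random values, and the variable names are assumptions for illustration. It checks that the chained result is still a lazy dask array until `load()` is called.

```python
import numpy as np
import xarray as xr

# Synthetic 1-degree global grid with 3 time steps (stand-in for MUR data).
lat = np.arange(-89.5, 90.0, 1.0)
lon = np.arange(-179.5, 180.0, 1.0)
rng = np.random.default_rng(0)
values = rng.normal(290.0, 5.0, size=(3, lat.size, lon.size)).astype("float32")

sst = xr.DataArray(
    values,
    coords={"time": np.arange(3), "lat": lat, "lon": lon},
    dims=("time", "lat", "lon"),
    name="analysed_sst",
).chunk({"time": 1})  # dask-backed, one chunk per time step

subset = sst.sel(lat=slice(-45, 45), lon=slice(-180, -50))  # "function 1"
lazy_mean = subset.mean(dim=["lat", "lon"])                 # "function 2"

# Nothing has been computed yet: the result is still a dask array.
is_lazy_before_load = lazy_mean.chunks is not None

regional_mean = lazy_mean.load()  # triggers the whole chain at once
print(subset.shape, regional_mean.values)
```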


In [77]:
del data_mf
client.close()
cluster.close()