In [None]:
# Make sure you are under the cper environment

# Import modules
import numpy as np
import xarray as xr
import dask.array as da
from dask.distributed import Client
from dask import delayed
from time import sleep

In [None]:
# Use the distributed scheduler to form a local cluster:
# N_WORKERS workers, THREADS_PER_WORKER thread (CPU) per worker.
# Later .compute()/.load() calls in this notebook run on this cluster.
N_WORKERS = 4
THREADS_PER_WORKER = 1
my_client = Client(n_workers=N_WORKERS, threads_per_worker=THREADS_PER_WORKER)

# Show information of the local cluster
my_client.cluster

In [None]:
# Define two functions
def fun1(x):
    """Stand-in for a slow task: block for one second, then return x + 1."""
    sleep(1)
    incremented = x + 1
    return incremented

def fun2(x, y):
    """Stand-in for a slow task: block for one second, then return x + y."""
    sleep(1)
    combined = x + y
    return combined

In [None]:
%%time

# This takes three seconds to run because we call each
# function sequentially, one after the other
x = fun1(1)
y = fun1(2)
z = fun2(x, y)

In [None]:
%%time

# This runs immediately, all it does is build a graph
x = delayed(fun1)(1)
y = delayed(fun1)(2)
z = delayed(fun2)(x, y)

In [None]:
%%time

# This actually runs our computation using a local cluster.
# x and y do not depend on each other, so the two fun1 calls can run at the
# same time on different workers before fun2 combines them: expect roughly
# two seconds instead of the three taken by the sequential version.
z.compute()

In [None]:
# compute() returned the result but did not modify z:
# z is still a lazy Delayed object
z

In [None]:
# Look at the task graph for z: the two fun1 tasks both feed into fun2.
# NOTE(review): rendering presumably requires graphviz to be installed — confirm
z.visualize()

In [None]:
# Make a simple list of inputs to process
data = [1, 2, 3, 4, 5, 6, 7, 8]

In [None]:
%%time

# Loop element one by one
# Sequential code
results = []

for i in data:
    temp = fun1(i)
    results.append(temp)

total = sum(results)

In [None]:
# Sum of fun1(i) = i + 1 over data; 44 for data = [1, ..., 8]
total

In [None]:
%%time

# Parallel code 
results = []

for i in data:
    temp = delayed(fun1)(i)
    results.append(temp)
    
total = delayed(sum)(results)

# Let's see what type of thing total is
print("Before computing:", total)

# Compute
result = total.compute()

# After it's computed
print("After computing :", result)  

In [None]:
# Look at the task graph for total: one fun1 task per element of data,
# all feeding into a single sum task
total.visualize()

In [None]:
# Open a single file with xarray to inspect its structure
ds_first = xr.open_dataset(
    filename_or_obj='aviso_2015/dt_global_allsat_madt_h_20150101_20150914.nc'
)

# Display the dataset summary (dimensions, coordinates, variables)
ds_first

In [None]:
# Combine every netCDF file matching aviso_2015/*.nc into one dataset
ds = xr.open_mfdataset(paths='aviso_2015/*.nc')

# Inspect the combined dataset: its variables are backed by lazy dask
# arrays, so their values are not displayed
ds

In [None]:
# Select the sea surface height variable (stored as "adt" in the dataset)
ssh = ds['adt']

# ssh is an xarray DataArray whose data is a lazy dask array
ssh

In [None]:
# Select the first day's data (index 0 along the first dimension, presumably
# time). This only displays the selection; the plot is in the next cell.
ssh[0]

In [None]:
# Plot the first day's data (index 0 along the first dimension)
ssh[0].plot()

In [None]:
# Annual (2015) mean sea surface height, averaged over the time dimension.
# mean() on its own is lazy; load() runs the computation and keeps the
# result in memory, returning the same object.
ssh_2015_mean = ssh.mean(dim='time').load()

ssh_2015_mean

In [None]:
# Plot the 2015 annual mean computed above
ssh_2015_mean.plot()

In [None]:
# Daily anomalies: subtract the 2015 mean from every time step
# (xarray broadcasts the time-averaged field along the time dimension)
ssh_anom = ssh - ssh_2015_mean

# Variance of the daily anomalies: mean of the squared anomalies over time
ssh_variance = (ssh_anom**2).mean(dim='time')

# Nothing has been computed yet: ssh_variance is a lazy, dask-backed object
ssh_variance

In [None]:
# Run the lazy computation and keep the resulting variance in memory
ssh_variance.load()

In [None]:
# Plot the variance of the daily anomalies
ssh_variance.plot()

In [None]:
# Release resources: close the netCDF files opened above, then shut down the
# client (local cluster). Results already brought into memory with load()
# remain usable afterwards.
ds_first.close()
ds.close()
my_client.close()