In [None]:
import xarray as xr
import numpy as np
from datetime import datetime
import gcsfs
# Show the installed xarray version in the cell output so the benchmark
# results can be tied to the library version that produced them.
xr.__version__

## Set up Dask Cluster

In [None]:
from dask.distributed import Client
from dask_kubernetes import KubeCluster
# Start a Kubernetes-backed Dask cluster with a single worker; the benchmark
# loop later resizes it through block_until_scaled().
cluster = KubeCluster(n_workers=1)
# Connect a client to the cluster; distributed.get_client() in the helper
# functions below is expected to return this client.
client = Client(cluster)
# Last expression: display the cluster's notebook representation.
cluster

In [None]:
import distributed
from time import sleep, time

def get_nworkers(cores_per_worker=2):
    """Estimate the number of workers attached to the current Dask client.

    The estimate is the total core count reported by ``Client.ncores()``
    divided (floor division) by ``cores_per_worker``. It is only exact when
    every worker has exactly ``cores_per_worker`` cores.

    Parameters
    ----------
    cores_per_worker : int, optional
        Assumed number of cores per worker (default 2).

    Returns
    -------
    int
        Estimated worker count.
    """
    per_worker_cores = distributed.get_client().ncores()
    total_cores = sum(per_worker_cores.values())
    return total_cores // cores_per_worker

def block_until_scaled(desired_workers, timeout=None, poll_interval=5,
                       cores_per_worker=2):
    """Restart the workers, scale the cluster, and wait for the new size.

    The current client's workers are restarted with ``Client.restart()``,
    the attached cluster is asked to scale to ``desired_workers``, and the
    worker count reported by ``get_nworkers`` is polled until it equals
    ``desired_workers``.

    Parameters
    ----------
    desired_workers : int
        Exact number of workers to wait for.
    timeout : float or None, optional
        Maximum number of seconds to wait. ``None`` (the default) waits
        indefinitely, as the original implementation did.
    poll_interval : float, optional
        Seconds to sleep between polls of the worker count (default 5).
    cores_per_worker : int, optional
        Forwarded to ``get_nworkers``; must match the real worker size or
        the count never converges (default 2).

    Raises
    ------
    TimeoutError
        If ``timeout`` is set and the cluster has not reached
        ``desired_workers`` workers within that many seconds.
    """
    from time import monotonic

    cl = distributed.get_client()
    # Restart before every resize; presumably so each benchmark run starts
    # with fresh worker processes -- confirm intent.
    cl.restart()
    cl.cluster.scale(desired_workers)
    # Monotonic clock: unaffected by wall-clock adjustments while waiting.
    deadline = None if timeout is None else monotonic() + timeout
    while get_nworkers(cores_per_worker=cores_per_worker) != desired_workers:
        if deadline is not None and monotonic() >= deadline:
            raise TimeoutError(
                'cluster did not reach {} workers within {} s'.format(
                    desired_workers, timeout))
        sleep(poll_interval)

In [None]:
# Zarr store on Google Cloud Storage (bucket/prefix path, no gs:// scheme),
# exposed to xarray as a key/value mapping via gcsfs and opened lazily.
gc_path = 'pangeo-data/esgf_test/pr_Amon_GFDL-CM4_piControl_r1i1p1f1_gr1'
zarr_store = gcsfs.GCSMap(gc_path)
ds = xr.open_zarr(zarr_store)
ds

## Benchmark Loading Speed

In [None]:
# Time a full reduction of `pr` over 'time' at increasing cluster sizes.
# Each row is (timestamp, nworkers, chunksize, runtime_seconds, datasize_MB).
from time import perf_counter

nworkers = [1, 2, 4, 8, 16]
# Size of the first dask chunk of `pr` along 'time'. Read from the data
# instead of hard-coding it, so the 'chunksize' column matches the store's
# actual chunking (irregular chunks: only the first one is recorded).
tc = ds.pr.chunks[ds.pr.get_axis_num('time')][0]
# Uncompressed array size in MB (1e6 bytes); constant across runs.
total_data_size = ds.pr.nbytes / 1e6
rows = []
for nw in nworkers:
    block_until_scaled(nw)
    # perf_counter is a monotonic, high-resolution clock meant for timing.
    tic = perf_counter()
    pr_mean = ds.pr.mean(dim='time').load()
    runtime = perf_counter() - tic
    row = (datetime.now(), nw, tc, runtime, total_data_size)
    rows.append(row)
    print(', '.join(map(repr, row)))

In [None]:
import pandas as pd
# One column per field of the `row` tuples appended in the benchmark loop,
# in the same order.
columns = 'timestamp nworkers chunksize runtime datasize'.split()
df = pd.DataFrame(data=rows, columns=columns)
df

In [None]:
# Write the results table (with pandas' default integer index) to CSV.
BENCHMARK_CSV = 'benchmark_zarr.csv'
df.to_csv(BENCHMARK_CSV)