# Transfer Speeds - Chunks

## Imports & Client Initialization

In [2]:
import dask.array as dsa
import numpy as np
from contextlib import contextmanager
import time
import dask
import intake
import xarray as xr
from matplotlib import pyplot as plt
from matplotlib.pyplot import cm
from matplotlib.ticker import MaxNLocator
import matplotlib.colors
import pandas as pd
from scipy.stats import sem
import tiledb

In [3]:
from dask.distributed import Client, LocalCluster
# Start a local Dask cluster with two threads per worker and attach a client to it.
cluster = LocalCluster(threads_per_worker=2)
client = Client(cluster)
# Last expression of the cell, so the notebook displays the scheduler.
cluster.scheduler

Perhaps you already have a cluster running?
Hosting the HTTP server on port 36225 instead


## Benchmarking Setup

The cell below is only used for benchmarking reads.

In [4]:
class DevNullStore:
    """Store target whose item assignment silently discards whatever it is given."""

    def __init__(self):
        return None

    def __setitem__(*args, **kwargs):
        # The instance itself arrives as the first element of ``args``;
        # every argument is accepted and dropped.
        return None


null_store = DevNullStore()

############################################################################################################################

class DiagnosticTimer:
    """Collect elapsed-time measurements together with caller-supplied metadata.

    Attributes:
        diagnostics: list with one dict per completed ``time()`` block; each
            dict holds the keyword arguments given to ``time()`` plus a
            ``"runtime"`` entry in seconds.
        names: list that starts empty; this class never writes to it, other
            code in the notebook appends to it.
    """

    def __init__(self):
        self.diagnostics = []
        self.names = []

    @contextmanager
    def time(self, **kwargs):
        """Time the body of a ``with`` block and record it.

        Args:
            **kwargs: arbitrary metadata stored alongside the measurement.

        If the body raises, the exception propagates and nothing is recorded.
        """
        # perf_counter is monotonic and high resolution, so the measurement
        # cannot be distorted by system clock adjustments.
        start = time.perf_counter()
        yield
        kwargs["runtime"] = time.perf_counter() - start
        self.diagnostics.append(kwargs)

    def dataframe(self):
        """Return the recorded measurements as a ``pandas.DataFrame``."""
        return pd.DataFrame(self.diagnostics)


diag_timer = DiagnosticTimer()

############################################################################################################################

def name(fileType, daf):
    """Register a finished benchmark table and reset state for the next run.

    Args:
        fileType: short label; the table is published as the global
            ``df_<fileType>``.
        daf: the results table to register.

    Side effects: appends ``daf`` to ``diag_timer.names``, removes the
    globals ``df`` and ``da`` if they exist, and empties
    ``diag_timer.diagnostics``.
    """
    namespace = globals()
    namespace[f"df_{fileType}"] = daf
    diag_timer.names.append(daf)

    # Remove the working globals without failing when one is already gone;
    # a plain ``del`` would raise NameError after the table had been
    # registered, so a retry would register it a second time.
    for stale in ("df", "da"):
        namespace.pop(stale, None)
    diag_timer.diagnostics = []
    
############################################################################################################################   

def total_nthreads():
    """Return the sum of the per-worker values reported by ``client.nthreads()``."""
    per_worker = client.nthreads()
    return sum(per_worker.values())

def total_ncores():
    """Return the sum of the per-worker values reported by ``client.ncores()``."""
    per_worker = client.ncores()
    return sum(per_worker.values())

def total_workers():
    """Return the number of entries in the mapping from ``client.ncores()``."""
    per_worker = client.ncores()
    return len(per_worker)

############################################################################################################################

class mainLoop:
    """Driver for the read-throughput benchmark.

    Reads the notebook-level globals ``tests``, ``max_workers`` and
    ``worker_step`` and uses the module-level ``cluster``, ``client``,
    ``diag_timer`` and ``null_store`` objects.
    """

    def errorCalc(self, df0):
        """Collapse each run of ``tests`` consecutive rows into one summary row.

        Args:
            df0: table with one row per timed run, containing ``runtime`` and
                ``throughput_Mbps`` columns plus metadata columns.

        Returns:
            A DataFrame with one row per group: the metadata of the group's
            first row, the mean ``runtime``, the mean ``throughput_Mbps`` and
            an ``errors`` column holding the standard error of the throughput.
            A trailing group with fewer than ``tests`` rows is ignored.
        """
        measured = ['runtime', 'throughput_Mbps']
        starts = list(range(0, len(df0) - tests + 1, tests))

        # Metadata is selected by column name rather than by a fixed column
        # count, so it stays correct for any set of recorded keyword arguments.
        meta = df0.drop(columns=measured).iloc[starts].reset_index(drop=True)

        summary = []
        for start in starts:
            group = df0.iloc[start:start + tests]
            throughput = group['throughput_Mbps']
            summary.append(dict(runtime=group['runtime'].mean(),
                                throughput_Mbps=throughput.mean(),
                                errors=sem(throughput)))

        return pd.concat([meta, pd.DataFrame(summary)], axis=1)

    def loop(self, da, diag_kwargs):
        """Time storing ``da`` into the discard store for a range of worker counts.

        Args:
            da: Dask array whose full read is being timed.
            diag_kwargs: extra metadata recorded with every timing.

        Returns:
            The per-run timing table when ``tests`` is 1, otherwise the
            per-worker-count summary produced by ``errorCalc``.
        """
        # Worker counts run in ascending order and end at max_workers.
        for nworkers in np.flip(np.arange(max_workers, 0, -worker_step)):
            cluster.scale(nworkers)
            # Presumably gives the cluster time to settle after scaling
            # before the worker count is checked - TODO confirm.
            time.sleep(10)
            client.wait_for_workers(nworkers)
            print('Number of Workers:', nworkers)
            for _ in range(tests):
                with diag_timer.time(nworkers=total_workers(), nthreads=total_nthreads(),
                                     ncores=total_ncores(), **diag_kwargs):
                    future = dsa.store(da, null_store, lock=False, compute=False)
                    dask.compute(future, retries=5)
                del future

        df = diag_timer.dataframe()
        # NOTE: this is megabytes per second (bytes / 1e6 / seconds) even
        # though the column is called "Mbps"; the name is kept because the
        # plotting code reads it.
        df['throughput_Mbps'] = da.nbytes / 1e6 / df['runtime']
        # A standard error needs more than one sample per worker count.
        if tests > 1:
            df = self.errorCalc(df)
        return df


mainLoop = mainLoop()

---------------

## Perform Benchmarking

In [5]:
# Loop Parameters
tests = 5  # timed repetitions per worker count
max_workers = 8  # largest worker count the cluster is scaled to
worker_step = 1  # spacing between successive worker counts

# Data Location (DO NOT CHANGE THESE)
root = 'gs://cloud-data-benchmarks/'
data = 'slp.1948-2009.100MB'  # dataset base name; a file extension is appended where it is used
variable = 'SLP'  # variable selected from the opened dataset

### With Dask

#### Write Script

This bucket does not have public write access. I attempted to use the `commits` consolidation and vacuuming feature, but I do not have the most up-to-date version of TileDB-Py installed in my VM image.

In [None]:
# Convert the NetCDF source into a dense TileDB array through Dask and time
# the whole conversion (open, schema creation, write, consolidation).
config = tiledb.Config()
config['vfs.gcs.project_id'] = 'modular-magpie-167320' # Input your project ID here
ctx = tiledb.Ctx(config)

uri = root + data + '.tldb'

with diag_timer.time(conversionType='netcdf2tldb'):
    ds = intake.open_netcdf(root + data + '.nc').to_dask()
    da = ds[variable]
    internal_chunks = da.encoding['chunksizes']
    coords = da.dims
    # Re-chunk along the chunk sizes recorded in the file's encoding.
    da = da.chunk(chunks=dict(zip(coords, internal_chunks))).data

    # TileDB Custom Schema Creation
    filter_list = tiledb.FilterList([tiledb.LZ4Filter(level=5)])

    # One dimension per coordinate: zero-based domain over the original
    # extent, tiled with the matching chunk size.
    original_shape = ds[variable].encoding['original_shape']
    dims = [
        tiledb.Dim(name=dim_name, domain=(0, extent - 1), tile=tile,
                   dtype=np.uint64, filters=filter_list)
        for dim_name, extent, tile in zip(coords, original_shape, internal_chunks)
    ]

    attr = [tiledb.Attr(name=variable, dtype=np.float32, filters=filter_list)]
    dom = tiledb.Domain(dims)

    schema = tiledb.ArraySchema(domain=dom, attrs=attr, sparse=False)
    # Pass the configured context so the project ID set above is actually used.
    tiledb.Array.create(uri, schema, ctx=ctx)

    # The two configuration options were chosen as 2 to coincide with the threads_per_worker value set before
    # The array is opened in a ``with`` block so it is closed once the write
    # has finished and before consolidation starts.
    with tiledb.open(uri, "w", ctx=ctx) as tdb_array:
        da.to_tiledb(tdb_array, storage_options={"sm.compute_concurrency_level": 2,
                                                 "sm.io_concurrency_level": 2})

    config['sm.consolidation.mode'] = 'fragment_meta'
    ctx = tiledb.Ctx(config)
    tiledb.consolidate(uri, ctx=ctx)

In [None]:
# Snapshot the timings recorded so far as a table before they are discarded.
write_results = diag_timer.dataframe()
diag_timer.diagnostics = [] # Clear write diagnostics for use in read throughput benchmarking
# Last expression of the cell, so the notebook displays the table.
write_results

#### Read Script

In [None]:
# Time how long it takes to open the TileDB array as a Dask array.
tic1 = time.time()
da = dsa.from_tiledb(root + data + '.tldb')
toc1 = time.time()
connectTime = toc1 - tic1
# Bytes per chunk: number of elements in one chunk times the size of one element.
chunksize = np.prod(da.chunksize) * da.dtype.itemsize
# Last expression of the cell, so the notebook displays the array.
da

In [None]:
# Metadata recorded alongside every timing taken by the benchmark loop.
diag_kwargs = dict(nbytes=da.nbytes, chunksize=chunksize, format='TileDB Embedded', connectTime=connectTime)

# Run the read benchmark, then register the table; name() publishes it as
# the global `df_tldb` and deletes the globals `df` and `da`.
df = mainLoop.loop(da, diag_kwargs)
name('tldb', df)
df_tldb

In [None]:
# Shut down the client first, then the cluster it was attached to.
client.close()
cluster.close()

### Without Dask 

#### Write

In [None]:
# Write the variable to TileDB directly from a NumPy array, without Dask.
# NOTE(review): `uri`, `coords` and `internal_chunks` are taken from the Dask
# write cell; if that cell already created the array at `uri`, the creation
# call below will presumably fail - confirm before running both.
ds = intake.open_netcdf(root + data + '.nc').to_dask()
# `.values` always yields a NumPy array. A separate name is used so the
# dataset-name string in `data` is not overwritten; the path expressions
# `root + data + ...` depend on it.
values = ds[variable].values

#################################################Schema Creation############################################################

filter_list = tiledb.FilterList([tiledb.LZ4Filter(level=5)])

# One dimension per coordinate: zero-based domain over the original extent,
# tiled with the matching chunk size.
original_shape = ds[variable].encoding['original_shape']
dims = [
    tiledb.Dim(name=dim_name, domain=(0, extent - 1), tile=tile,
               dtype=np.uint64, filters=filter_list)
    for dim_name, extent, tile in zip(coords, original_shape, internal_chunks)
]

attr = [tiledb.Attr(name=variable, dtype=np.float32, filters=filter_list)]
dom = tiledb.Domain(dims)

schema = tiledb.ArraySchema(domain=dom, attrs=attr, sparse=False)
tiledb.Array.create(uri, schema)

############################################################################################################################

with tiledb.open(uri, "w") as tdb_array:
    # Slice assignment performs the write; assigning to the bare name would
    # only rebind the variable and store nothing.
    tdb_array[:] = values

#### Read

My entire throughput testing process relies on Dask to scale the number of workers up, so it cannot be reused here as-is; I will work on creating a separate benchmark for TileDB Embedded that does not use Dask.

-------------------------------------------------------------------------------------------------------------------------------

## Plot Throughput

Running the cell below will give you a throughput plot for the read results.

In [None]:
class errorPlot:
    """Overlay error bars for a results table on the current matplotlib axes."""

    # Columns that plot() reads from the table.
    _REQUIRED = ('nworkers', 'throughput_Mbps', 'errors')

    def plot(self):
        """Draw throughput against worker count with ``errors`` as y error bars.

        Uses the table and color last stored by ``errorCheck``.
        """
        x = self.df['nworkers']
        y = self.df['throughput_Mbps']
        error = self.df['errors']
        plt.errorbar(x, y, yerr=error, color=self.c, fmt='o', capsize=5, capthick=2)

    def errorCheck(self, daf, color):
        """Store ``daf`` and ``color`` and draw error bars once, if possible.

        Args:
            daf: results table; error bars are drawn only when it contains
                every column ``plot`` needs.
            color: color passed to ``plt.errorbar``.

        A table without those columns is skipped silently. Errors raised by
        the drawing call itself are not suppressed.
        """
        self.c = color
        self.df = daf
        if all(column in daf for column in self._REQUIRED):
            self.plot()


errorPlot = errorPlot()


# One color per registered results table, spread over the rainbow colormap.
color = cm.rainbow(np.linspace(0,1,len(diag_timer.names)))
legend = []
df_results = pd.concat(diag_timer.names, ignore_index=True)

ax = None
for i, results in enumerate(diag_timer.names):
    # Positional lookup works for any non-empty table; a label lookup of
    # row 1 fails when a table has a single row.
    legend.append(results['format'].iloc[0])
    c = matplotlib.colors.to_hex(color[i,:], keep_alpha=True)

    # With ax=None the first call creates the axes; later calls draw onto it.
    ax = results.plot(x='nworkers', y='throughput_Mbps', kind='line', color=c, ax=ax, marker='o')
    errorPlot.errorCheck(results, c)

# Axes decoration does not depend on the individual table, so it is applied
# once after every line has been drawn.
plt.grid(True)
plt.title('Cloud Data Read Speeds with Dask')
plt.xlabel('Number of Parallel Reads')
plt.ylabel('Throughput (Mbps)')
plt.legend(legend, bbox_to_anchor=[1.25, 0.5], loc='center', title='Store Formats')
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
#plt.yscale('symlog') ACTIVATE THIS LINE IF YOU ARE USING A LARGE AMOUNT OF WORKERS

In [None]:
# Display the concatenation of all registered results tables.
df_results