# USGS Water Balance Model: Create cloud-optimized output
Data from a 2.5 arc minute CONUS model from 1895 to 2020
The provided files were fixed-width ASCII, with the year and month in the first two columns and the data in the remaining columns. The raster data for each time step is written to a single row, with only the non-missing values written. There is one file for each variable. There is also a separate CSV file that contains the lon,lat location for each column of data.

To parallelize the workflow, we split each original file (e.g. `tmean.monthly.all.gz`, `prcp.monthly.all.gz`) into many smaller text files using `split`. The number of lines per file is chosen to match the desired number of time steps per chunk in the output: here 120 monthly rows, the same value as `chunk_time` below.
```
#!/bin/bash
# Split each gzipped monthly file into 120-line pieces named
# $var/${var}00, $var/${var}01, ... Each row is one time step, so each piece
# holds 120 time steps. 120 must equal chunk_time in the notebook, because
# piece k is written starting at time index k * chunk_time.
# Fail fast if zcat, mkdir or split fails (including inside the pipeline).
set -euo pipefail
for var in prcp tmean
do
  # -p: do not fail when re-running with the output directory already present.
  mkdir -p "$var"
  zcat "$var.monthly.all.gz" | split -l 120 --numeric-suffixes - "$var/$var"
done

```

In [None]:
import fsspec 
import xarray as xr
import pandas as pd
import numpy as np
import datetime as dt
import hvplot.xarray
from dask.distributed import Client
import dask

In [None]:
# Local-filesystem fsspec instance, used for ls/glob on the input tree
# (an empty protocol string selects the same "file" implementation).
fs = fsspec.filesystem('file')

In [None]:
# Input directory and the output Zarr store location.
# Keep the trailing '/' on inpath: outpath is built by appending to it.
inpath = '/scratch/mike/'
outpath = inpath + 'wbm.zarr'

In [None]:
# List the contents of the input directory.
fs.ls(inpath)

#### Read the grid-cell locations

In [None]:
# One row per data column in the input files, giving that column's
# X (longitude) and Y (latitude).
latlon_file = f'{inpath}/LatLongs.csv'
df = pd.read_csv(latlon_file)

Determine the (i, j) grid indices corresponding to each lon,lat point:

In [None]:
# Grid spacing: 2.5 arc minutes expressed in degrees.
dlonlat = 2.5 / 60
# Integer column (ii, along X/lon) and row (jj, along Y/lat) index of each
# location, counted from the minimum X and minimum Y respectively.
ii = ((df['X'] - df['X'].min()) / dlonlat).round().astype('int')
jj = ((df['Y'] - df['Y'].min()) / dlonlat).round().astype('int')

In [None]:
# Grid size along each axis is one more than the largest index used.
nx = ii.max() + 1
ny = jj.max() + 1
print(nx, ny)

In [None]:
# Evenly spaced coordinates from the minimum to the maximum observed
# X (lon) and Y (lat), one per grid column / row.
lon = np.linspace(start=df['X'].min(), stop=df['X'].max(), num=nx)
lat = np.linspace(start=df['Y'].min(), stop=df['Y'].max(), num=ny)

#### Create the empty Zarr dataset to fill with chunks

In [None]:
# Month-start timestamps, 1895-01-01 through 2020-12-01 (126 years x 12 =
# 1512 steps). freq='MS' matches the datetime(year, month, 1) values that
# write_chunk attaches to each row. A month-end frequency ('M') would make
# the template disagree with them, and 'M' is deprecated in recent pandas.
# NOTE(review): every write_chunk task also rewrites its slice of the shared
# `time` array concurrently. Keeping the template identical to those values
# means overlapping writes cannot leave a mix of month-start and month-end
# stamps.
dates = pd.date_range(start='1895-01-01', end='2020-12-01', freq='MS')

In [None]:
# Total number of monthly time steps in the output.
nt = dates.size
print(nt)

In [None]:
# Zarr chunk shape (time, lat, lon). chunk_time must equal the line count
# passed to `split -l`, because split file k is written starting at time
# index k * chunk_time (the last file may hold a partial chunk).
chunk_time, chunk_lat, chunk_lon = 120, 300, 700

In [None]:
# List the split input files; the task-building cell later globs them as
# gzfiles/<var>/<var>NN.
fs.ls(f'{inpath}/gzfiles/')

In [None]:
# `import dask` alone does not load the dask.array submodule. Import it
# explicitly rather than relying on xarray/hvplot having done so as a side
# effect.
import dask.array

# Lazy, chunked placeholder array for the template dataset. Only its shape,
# chunking and dtype define the Zarr layout; the template is written with
# compute=False, so these zeros are not written by that call.
d = dask.array.zeros((nt, ny, nx), chunks=(chunk_time, chunk_lat, chunk_lon), dtype='float32')

In [None]:
# Template dataset for the Zarr store: nine variables on the full
# (time, lat, lon) grid. All of them reuse the same lazy placeholder array
# `d`, which only supplies shape, chunking and dtype.
wbm_vars = ['prcp', 'tmean', 'aet', 'pet', 'rain', 'runoff', 'snow',
            'soilstorage', 'swe']
ds0 = xr.Dataset(
    data_vars={name: (('time', 'lat', 'lon'), d) for name in wbm_vars},
    coords={
        'lon': (['lon'], lon),
        'lat': (['lat'], lat),
        'time': dates,
    },
)

In [None]:
# Create (overwriting, mode='w') the store layout with consolidated metadata.
# With compute=False the dask-backed data variables are not written here;
# they are filled later by per-chunk region writes.
ds0.to_zarr(store=outpath, mode='w', consolidated=True, compute=False)

In [None]:
def write_chunk(var, f, istart):
    """Read one split text file and write it into the Zarr store as a time region.

    Each row of ``f`` is one time step: year, month, then one value per data
    column, in the same order as the rows of LatLongs.csv (i.e. ``ii``/``jj``).

    Parameters
    ----------
    var : str
        Name of the data variable in the Zarr store to fill (e.g. 'tmean').
    f : str
        Path to a whitespace-delimited text file produced by ``split``.
    istart : int
        Index along ``time`` in the store where the first row of ``f`` goes.

    Raises
    ------
    ValueError
        If the file has more data columns than there are grid locations.

    Notes
    -----
    Relies on the notebook globals ``ii``, ``jj``, ``nx``, ``ny``, ``lon``,
    ``lat``, ``chunk_time``, ``chunk_lat``, ``chunk_lon`` and ``outpath``.
    """
    # ndmin=2 keeps a single-row file 2-D, so the column slicing below works.
    a = np.loadtxt(f, dtype='float32', ndmin=2)
    year = a[:, 0].astype('int')
    mon = a[:, 1].astype('int')
    t = [np.datetime64(dt.datetime(y, m, 1)) for y, m in zip(year, mon)]
    data = a[:, 2:]
    nsteps, ncells = data.shape
    if ncells > len(ii):
        raise ValueError(
            f'{f}: {ncells} data columns but only {len(ii)} grid locations'
        )

    # Start from all-NaN so grid cells without a data column stay missing.
    b = np.full((nsteps, ny, nx), np.nan, dtype='float32')
    # Scatter every column into its (lat, lon) cell in a single vectorized
    # assignment instead of a Python loop over every column.
    # NOTE(review): assumes each (jj, ii) pair is unique; with duplicates,
    # NumPy does not guarantee which value wins.
    b[:, np.asarray(jj)[:ncells], np.asarray(ii)[:ncells]] = data

    da = xr.DataArray(
        data=b,
        dims=['time', 'lat', 'lon'],
        coords=dict(lon=('lon', lon), lat=('lat', lat), time=('time', t)),
    )
    ds = da.to_dataset(name=var)
    ds = ds.chunk(chunks={'time': chunk_time, 'lat': chunk_lat, 'lon': chunk_lon})
    # lon/lat do not span the region dimension ('time'), and xarray rejects
    # region writes that include such variables; the store already holds them
    # from the template.
    ds.drop_vars(['lon', 'lat']).to_zarr(
        outpath, mode='r+', region={'time': slice(istart, istart + nsteps)}
    )

In [None]:
# Start a dask.distributed client (Client() with no address starts a local cluster).
# NOTE(review): the write step calls dask.compute(..., scheduler='processes'),
# which uses dask's multiprocessing scheduler instead of this client; confirm
# whether the client is still needed.
client = Client()

In [None]:
%%time
tasks=[]
for var in ['tmean','prcp','aet','pet','rain','runoff','snow','soilstorage','swe']:
    flist = fs.glob(f'/scratch/mike/gzfiles/{var}/{var}??')
    i = 0
    for f in flist:
        print(f)
        istart=i*chunk_time
        tasks.append(dask.delayed(write_chunk)(var, f, istart))
        i = i + 1

In [None]:
%%time
# Run all write tasks with dask's multiprocessing scheduler using 4 worker
# processes (the explicit scheduler= means the distributed client is not
# used). Each task writes one (variable, time range) region aligned to
# chunk_time, so data chunks written by different tasks do not overlap.
# NOTE(review): every task also writes its slice of the shared `time`
# coordinate; confirm that this concurrent rewrite is harmless (it is only
# if all tasks write values identical to the template).
dask.compute(tasks, scheduler='processes', num_workers=4)

#### Let's see what we produced!

In [None]:
# Reopen the finished store lazily. chunks={} makes xarray back each
# variable with dask using the store's own chunking.
ds2 = xr.open_dataset(outpath, chunks={}, engine='zarr')

In [None]:
# Display the dataset summary (variables, coordinates, chunking).
ds2

In [None]:
# Interactive map of tmean for the 1925-01-01 time step over OpenStreetMap
# tiles; rasterize=True lets hvplot rasterize the large grid for display.
ds2.tmean.sel(time='1925-01-01').hvplot.quadmesh(x='lon',y='lat', geo=True, tiles='OSM', cmap='turbo', rasterize=True, alpha=0.7)

In [None]:
# Interactive prcp map; the remaining time dimension is not selected, so
# hvplot exposes it as a widget to step through time.
ds2.prcp.hvplot.quadmesh(x='lon',y='lat', geo=True, tiles='OSM', cmap='turbo', rasterize=True, alpha=0.7)

In [None]:
# Inspect the tmean DataArray (dims, coordinates, dask chunking).
ds2.tmean

In [None]:
# Time series of tmean at the grid cell nearest to lon=-90, lat=35.
ds2.tmean.sel(lon=-90.,lat=35.,method='nearest').plot()