# Creating Chimera Training Samples

Once tiles have been collected in notebook two, we can use the FIA catalog generated in notebook one to sample data by point location and produce "chips" used for training. Each point in the catalog is associated with ground truth data collected by FIA. 

**Note: Some `IndexError`s are expected for larger chip sizes, because only the sample point itself is checked for intersection with a tile, so a chip may span two tiles.**

In [None]:
import os
import sys
sys.path.append('..')

import fiona.transform
import fsspec
import xarray as xr
from affine import Affine
import pandas as pd
import numpy as np
# Pick the cluster type from whether dask_gateway can be imported:
# 'distributed' when it is available, otherwise fall back to 'local'.
try:
    from dask_gateway import GatewayCluster
except ModuleNotFoundError:
    clustenv = 'local'
    print('Using a local cluster...')
else:
    clustenv = 'distributed'

from utils.hls import catalog
from utils.hls import compute
from utils import get_logger

In [None]:
# Logger handed to compute.process_jobs further down.
logger = get_logger('hls-point-sampling')

# Cluster settings forwarded to compute.process_jobs as `cluster_args`.
cluster_args = {
    'workers': 64,  # alternative: int(.8*os.cpu_count())
    'worker_threads': 3,
    'worker_memory': 16,
    'scheduler_threads': 4,
    'scheduler_memory': 8,
    'clust_type': clustenv,
}

# Location of the project utilities and of the job checkpoint file.
code_path = '../utils'
checkpoint_path = 'checkpoints/sampling.txt'

In [None]:
# Restore the shared `envdict` from IPython's %store persistence, record
# the chip output container and the cluster type in it, then store it again.
# Fill in CHIP_BLOB_CONTAINER with your desired blob container for sample
# data collection (it is an empty placeholder here).
%store -r envdict
envdict['CHIP_BLOB_CONTAINER'] = ''
envdict['COL_ENV'] = clustenv
%store envdict

In [None]:
# Open the FIA point catalog zarr store in Azure blob storage, using the
# storage credentials kept in `envdict`.
catalog_path = fsspec.get_mapper(
    "az://fia/catalogs/fia_tiles.zarr",
    account_key=envdict['AZURE_STRG_ACCOUNT_KEY'],
    account_name=envdict['AZURE_STRG_ACCOUNT_NAME'],
)
pt_catalog = catalog.HLSCatalog.from_zarr(catalog_path)

In [None]:
# Keep only catalog points from 2019 onwards and flatten them to a DataFrame.
hls_pts = (
    pt_catalog.xr_ds
    .where(pt_catalog.xr_ds['year'] >= 2019, drop=True)
    .to_dataframe()
)

# One job per (tile, year) pair; iterating yields ((tile, year), rows).
jobs = hls_pts.groupby(['tile', 'year'])

In [None]:
# Tiles (and the year for each) that make up the test set.
ttdf = pd.read_csv('test_data/fd_test_tiles.csv')
test_tiles = ttdf['tile'].values

# Map every test tile to the year on its FIRST row in the CSV; later rows
# for the same tile are ignored, as setdefault keeps the first value seen.
test_year_by_tile = {}
for tile_name, tile_year in zip(ttdf['tile'], ttdf['year']):
    test_year_by_tile.setdefault(tile_name, tile_year)

# Each job is ((tile, year), rows). Keep the jobs whose tile is a test tile
# and whose year equals that tile's test year. Tiles missing from the
# mapping give None, which never equals a year.
test_jobs = [
    job for job in jobs
    if test_year_by_tile.get(job[0][0]) == job[0][1]
]

In [None]:
from dask.distributed import get_worker

def chip(ds, lat, lon, chip_size, metadata):
    """Cut a square chip around a lat/lon point out of a tile dataset.

    The point is reprojected from EPSG:4326 into the CRS named in
    ``ds.attrs['crs']`` and converted to pixel indices with the inverse of
    the affine transform stored in ``ds.attrs['transform']``. A window of
    ``int(chip_size/2)`` pixels on either side of that pixel is then
    selected positionally along the ``x`` and ``y`` dimensions.

    Args:
        ds: Tile dataset with ``x``/``y`` dimensions and ``transform`` and
            ``crs`` entries in ``attrs``.
        lat: Latitude of the sample point (EPSG:4326).
        lon: Longitude of the sample point (EPSG:4326).
        chip_size: Requested chip edge length in pixels.
        metadata: Dict merged into the event logged when the chip cannot
            be cut.

    Returns:
        The selected sub-dataset, or ``None`` when the window does not fit
        inside the tile (an ``IndexError`` event is logged on the worker).
    """
    CRS = "EPSG:4326"
    tfm = Affine(*ds.attrs['transform'])
    ([x], [y]) = fiona.transform.transform(
        CRS, ds.attrs['crs'], [lon], [lat]
    )
    x_idx, y_idx = [round(coord) for coord in ~tfm * (x, y)]

    half_chip = int(chip_size/2)
    x_start, x_stop = x_idx - half_chip, x_idx + half_chip
    y_start, y_stop = y_idx - half_chip, y_idx + half_chip

    # Negative positional indices do not raise: they wrap around to the far
    # edge of the tile and would produce a chip stitched from opposite
    # sides. Treat a window starting before pixel 0 as out of bounds.
    if x_start >= 0 and y_start >= 0:
        try:
            return ds[dict(x=range(x_start, x_stop), y=range(y_start, y_stop))]
        except IndexError:
            # Window runs past the right/bottom edge; reported below.
            pass

    get_worker().log_event("message", {"type": "IndexError", **metadata})
    return None
        

In [None]:
import dask

def chip_tile_year(
    job_id, job_df, chip_size, bands, account_name, chip_container, tile_container, account_key
):
    """Cut and store a chip for every catalog point of one (tile, year) job.

    Opens the tile at ``az://{tile_container}/{float(year)}/{tile}.zarr``,
    keeps only the requested bands, and for each row of ``job_df`` writes
    the chip returned by ``chip`` to
    ``az://{chip_container}/{INDEX}-{tile}.zarr``. Points for which
    ``chip`` returns ``None`` are skipped.

    Args:
        job_id: ``(tile, year)`` pair identifying the job.
        job_df: DataFrame of points for this job; the ``lat``, ``lon``,
            ``INDEX``, ``tile`` and ``year`` columns are read.
        chip_size: Chip edge length in pixels, passed through to ``chip``.
        bands: Iterable of objects with a ``name`` attribute; the names
            select the variables read from the tile.
        account_name: Azure storage account name.
        chip_container: Blob container that receives the chips.
        tile_container: Blob container holding the tile zarr stores.
        account_key: Azure storage account key.

    Returns:
        ``job_id``, unchanged.

    Raises:
        ValueError: If the tile cannot be opened; the message is the tile
            URL and the underlying error is attached as the cause.
    """
    def sample_and_write(tl, row):
        """Cut the chip for one catalog row and write it to blob storage."""
        sample = chip(
            tl,
            row['lat'],
            row['lon'],
            chip_size,
            metadata={'index': row['INDEX'], 'tile': row['tile'], 'year': row['year']}
        )
        # Compare against None explicitly: Dataset truthiness depends on
        # its data variables, not on whether a chip was produced.
        if sample is None:
            return
        output_zarr = fsspec.get_mapper(
            f"az://{chip_container}/{int(row['INDEX'])}-{row['tile']}.zarr",
            account_name=account_name,
            account_key=account_key
        )
        sample.chunk({'month': 12, 'x': 32, 'y': 32}).to_zarr(output_zarr, mode='w')

    band_names = [band.name for band in bands]
    tile, year = job_id
    tile_url = f"az://{tile_container}/{float(year)}/{tile}.zarr"
    input_zarr = fsspec.get_mapper(
        tile_url,
        account_name=account_name,
        account_key=account_key
    )
    try:
        ds = xr.open_zarr(input_zarr)[band_names].persist()
    except Exception as err:
        # Report which tile failed, keeping the original error as the cause.
        raise ValueError(tile_url) from err

    for _, row in job_df.iterrows():
        sample_and_write(ds, row)
    return job_id
    

In [None]:
# Run chip_tile_year over the selected test jobs through the project's
# compute.process_jobs helper. The arguments after the "chip_tile_year
# kwargs" marker match chip_tile_year's parameter names.
compute.process_jobs(
    jobs=test_jobs,
    job_fn=chip_tile_year,
    checkpoint_path=checkpoint_path,
    logger=logger,
    cluster_args=cluster_args,
    code_path=code_path,
    concurrency=6,  # run 6 jobs at once
    cluster_restart_freq=42,  # restart after 42 jobs
    # chip_tile_year kwargs
    bands=pt_catalog.xr_ds.attrs['bands'],
    chip_size=32,
    account_name=envdict['AZURE_STRG_ACCOUNT_NAME'],
    chip_container=envdict['CHIP_BLOB_CONTAINER'],
    tile_container=envdict['TILE_BLOB_CONTAINER'],
    account_key=envdict['AZURE_STRG_ACCOUNT_KEY']
)