In [None]:
import os
# Default AWS settings that the Analysis Sandbox sets as environment variables.
# The placeholder credentials below are the keys to strip.
aws_default_config = {
    #'AWS_NO_SIGN_REQUEST': 'YES', 
    'AWS_SECRET_ACCESS_KEY': 'fake',
    'AWS_ACCESS_KEY_ID': 'fake',
}

# Reading a public bucket fails with
#   PermissionError: The AWS Access Key Id you provided does not exist in our records.
# while these placeholder credentials are present, so remove each of them from
# the environment (missing keys are simply ignored).
for env_var in aws_default_config:
    os.environ.pop(env_var, None)

In [None]:
import logging
import os
import queue
from threading import Thread

import click
import datacube
import fsspec
from odc.dscache import create_cache
from odc.dscache.apps.slurpy import EOS, qmap
from odc.dscache.tools import (
    bin_dataset_stream,
    dataset_count,
    db_connect,
    dictionary_from_product_list,
    mk_raw2ds,
    ordered_dss,
    raw_dataset_stream,
)
from odc.dscache.tools.tiling import parse_gridspec_with_name
from odc.stats.model import DateTimeRange
from tqdm import tqdm

from deafrica_conflux.cli.logs import logging_setup
from deafrica_conflux.hopper import bin_solar_day, persist
from deafrica_conflux.io import (
    check_dir_exists,
    check_file_exists,
    check_if_s3_uri,
    find_geotiff_files,
)
from deafrica_conflux.text import parse_tile_ids

In [None]:
# --- Run configuration -------------------------------------------------------
# Logging verbosity handed to logging_setup().
verbose = 1

# Output grid name: africa_10, africa_20, africa_30 or africa_60.
grid_name = "africa_30"
# Datacube product to search datasets for.
product = "wofs_ls"
# Only datasets inside this time range are extracted; e.g. '2020-05--P1M' is the
# month of May 2020, so '2023-03--P3M' is three months starting March 2023.
temporal_range = "2023-03--P3M"
# zstandard compression level: 1 is fast, 9+ compresses well but is slow.
complevel = 6

# Common S3 prefix shared by the input rasters and the output cache directory.
_conflux_root = "s3://deafrica-waterbodies-dev/waterbodies/v0.0.2/senegal_basin/conflux"
# Directory containing the polygons raster files.
polygons_rasters_directory = f"{_conflux_root}/historical_extent_rasters"
# Regular expression used to match polygons raster filenames.
pattern = ".*"
# Replace an existing cache file instead of failing.
overwrite = True
# Directory the cache file is written to.
output_directory = f"{_conflux_root}/dbs"

In [None]:
# Set up logger.
# Configure logging at the requested verbosity, then get this notebook's logger.
# NOTE(review): keep getLogger() after logging_setup(); the setup helper may only
# act on loggers that already exist when it runs — confirm in deafrica_conflux.cli.logs.
logging_setup(verbose)
_log = logging.getLogger(__name__)

In [None]:
# Accept pathlib.Path values too: the helpers and os.path.join calls below
# are given plain strings.
polygons_rasters_directory, output_directory = (
    str(polygons_rasters_directory),
    str(output_directory),
)

In [None]:
# Validate the input directory, then create the output directory if needed.
# The input is checked first so that a bad input path fails fast without
# leaving behind a freshly created, empty output directory.
if not check_dir_exists(polygons_rasters_directory):
    _log.error(f"Directory {polygons_rasters_directory} does not exist!")
    raise FileNotFoundError(f"Directory {polygons_rasters_directory} does not exist!")

# `fs` is the filesystem for the output location; it is reused later to
# delete/upload the cache file and to write the CSV summary.
is_s3 = check_if_s3_uri(output_directory)
fs = fsspec.filesystem("s3" if is_s3 else "file")

if not check_dir_exists(output_directory):
    fs.makedirs(output_directory, exist_ok=True)
    _log.info(f"Created directory {output_directory}")

In [None]:
# Validate the requested product(s) against the products indexed in the datacube.
products = [product]
dc = datacube.Datacube()

# Index every known product definition by its name.
all_products = {}
for product_def in dc.index.products.get_all():
    all_products[product_def.name] = product_def

if not products:
    raise ValueError("Have to supply at least one product")
for p in products:
    if p not in all_products:
        raise ValueError(f"No such product found: {p}")

In [None]:
# Parse the temporal range and derive the cache filename/path from it.
temporal_range_ = DateTimeRange(temporal_range)

output_db_fn = f"{product}_{temporal_range_.short}.db"
output_db_fp = os.path.join(output_directory, output_db_fn)

# An existing cache is only replaced when `overwrite` is set.
if check_file_exists(output_db_fp):
    if not overwrite:
        raise FileExistsError(f"{output_db_fp} exists!")
    fs.delete(output_db_fp, recursive=True)
    _log.info(f"Deleted {output_db_fp}")
    # For S3 outputs the cache is first built locally under its bare filename;
    # clean up any leftover local copy as well.
    if is_s3 and check_file_exists(output_db_fn):
        fsspec.filesystem("file").delete(output_db_fn)
        _log.info(f"Deleted local file created before uploading to s3 {output_db_fn}")

In [None]:
# The dataset searches below (counts, dictionary training, streaming) are all
# restricted to this time window.
query = dict(time=(temporal_range_.start, temporal_range_.end))
_log.info(f"Query: {query}")

In [None]:
_log.info("Getting dataset counts")
# Number of datasets per product that match the query.
counts = {}
for p in products:
    counts[p] = dataset_count(dc.index, product=p, **query)

for p, c in counts.items():
    _log.info(f"..{p}: {c:8,d}")
n_total = sum(counts.values())

# Nothing to cache if the query matched no datasets at all.
if n_total == 0:
    raise ValueError("No datasets found")

In [None]:
# Train a zstandard compression dictionary from a sample of datasets matching
# the query (50 per product); the resulting `zdict` is passed to create_cache.
_log.info("Training compression dictionary...")
zdict = dictionary_from_product_list(dc, products, samples_per_product=50, query=query)
_log.info("Done")

In [None]:
# For S3 destinations the cache is written to a local file named `output_db_fn`
# and uploaded at the end; otherwise it is written directly to its final path.
cache_target = output_db_fn if is_s3 else output_db_fp
cache = create_cache(cache_target, zdict=zdict, complevel=complevel, truncate=True)

In [None]:
# Converter from raw index rows to datasets, built from the product definitions
# collected above; only used on db_task's empty-query path.
raw2ds = mk_raw2ds(all_products)


def db_task(products, conn, q):
    """Producer: stream the datasets of ``products`` onto queue ``q``.

    Meant to run in a background thread. Reads the notebook globals ``query``,
    ``dc`` and ``raw2ds``.

    Parameters
    ----------
    products : list of str
        Names of the products whose datasets are streamed.
    conn
        Database connection (from ``db_connect()``); only used when ``query``
        is empty.
    q : queue.Queue
        Queue that receives every dataset, followed by the ``EOS`` marker.

    The ``EOS`` marker is pushed in a ``finally`` clause: if querying the index
    raises, a consumer waiting for ``EOS`` would otherwise block forever. The
    error is logged and re-raised; if it is logged, downstream results are
    incomplete.
    """
    try:
        for p in products:
            # The `query` built earlier always contains a "time" key, so the raw
            # stream branch is not taken in this notebook.
            if len(query) == 0:
                dss = map(raw2ds, raw_dataset_stream(p, conn))
            else:
                dss = ordered_dss(dc, product=p, **query)

            for ds in dss:
                q.put(ds)
    except Exception:
        _log.exception("Failed while streaming datasets from the index")
        raise
    finally:
        q.put(EOS)

# Start the producer thread: it fills a bounded queue (at most 10,000 pending
# datasets) that the processing cells below drain.
q = queue.Queue(maxsize=10_000)
conn = db_connect()
db_thread = Thread(target=db_task, args=(products, conn, q))
db_thread.start()

In [None]:
# Pull datasets off the producer queue until EOS and pass them through the cache
# (presumably storing each one as it streams past — see odc.dscache `tee`).
dss = cache.tee(qmap(lambda ds: ds, q, eos_marker=EOS))

# Register the output grid with the cache and record it in the cache's info dict.
grid, gridspec = parse_gridspec_with_name(grid_name)
cache.add_grid(gridspec, grid)
cache.append_info_dict("stats/", {"config": {"grid": grid}})

# `cells` is filled by bin_dataset_stream as the stream is consumed.
cells = {}
dss = bin_dataset_stream(gridspec, dss, cells, persist=persist)

In [None]:
# Iterate the dataset stream only for its side effects (the loop body is empty),
# showing progress against the expected dataset count.
label = f"Processing {n_total:8,d} {product} datasets"
with tqdm(dss, desc=label, total=n_total) as progress:
    for _ in progress:
        pass

In [None]:
# Find the required tiles: only grid cells that have a polygons raster are kept.
_log.info(f"Total bins: {len(cells):d}")
_log.info("Filtering bins by required tiles...")
geotiff_files = find_geotiff_files(path=polygons_rasters_directory, pattern=pattern, verbose=False)

# One tile id per raster file. Several files may share a tile, so de-duplicate.
# The set also makes the membership test below O(1); the previous list lookup
# made filtering O(cells x tiles).
tiles_ids = [parse_tile_ids(file) for file in geotiff_files]
required_tile_ids = set(tiles_ids)
_log.info(f"Found {len(required_tile_ids)} tiles.")
_log.debug(f"Tile ids: {tiles_ids}")

# Filter cells by tile ids.
cells = {tile_id: cell for tile_id, cell in cells.items() if tile_id in required_tile_ids}
_log.info(f"Total bins after filtering: {len(cells):d}")
if not cells:
    _log.warning("No bins match the polygons raster tile ids; no tasks will be written.")

In [None]:
# Split each remaining bin's datasets into tasks grouped by solar day. The task
# keys are unpacked as (p, x, y) when the CSV summary is written.
_log.info("For each bin, group datasets by solar day.")
tasks = bin_solar_day(cells)
_log.info(f"Total tasks: {len(tasks)}")

In [None]:
_log.info("Removing duplicate source uuids...")
# Duplicates occur when queried datasets are captured around UTC midnight
# and around weekly boundary; collapsing each task to a set removes them.
tasks = {task_key: set(members) for task_key, members in tasks.items()}
# Per task: the list of its dataset ids.
tasks_uuid = {
    task_key: [ds.id for ds in members]
    for task_key, members in tasks.items()
}
# Distinct dataset ids across all tasks.
all_ids = {ds_id for id_list in tasks_uuid.values() for ds_id in id_list}
_log.info(f"Total of {len(all_ids):,d} unique dataset IDs after filtering.")

In [None]:
# Store each task's dataset ids in the cache under the grid registered earlier.
label = f"Saving {len(tasks)} tasks to disk"
with tqdm(tasks_uuid.items(), desc=label, total=len(tasks_uuid)) as groups:
    for task_key, dataset_ids in groups:
        cache.add_grid_tile(grid, task_key, dataset_ids)

# Wait for the producer thread to finish, then close the cache.
db_thread.join()
cache.close()

In [None]:
# For S3 outputs, push the locally built cache file to its destination and
# then remove the local staging copy.
if is_s3:
    local_fs = fsspec.filesystem("file")
    fs.upload(output_db_fn, output_db_fp, recursive=False)
    local_fs.delete(output_db_fn)

_log.info(f"Cache file written to {output_db_fp}")

In [None]:
# Write a summary CSV next to the cache: one row per task with its number of
# datasets and number of distinct dates (from each dataset's `time`).
csv_path = os.path.join(output_directory, f"{product}_{temporal_range_.short}.csv")
csv_header = '"T","X","Y","datasets","days"\n'
with fs.open(csv_path, "wt", encoding="utf8") as csv_file:
    csv_file.write(csv_header)
    for task_key in sorted(tasks):
        p, x, y = task_key
        task_dss = tasks[task_key]
        n_dss = len(task_dss)
        n_days = len({ds.time.date() for ds in task_dss})
        csv_file.write(f'"{p}", {x:+05d}, {y:+05d}, {n_dss:4d}, {n_days:4d}\n')

_log.info(f"Written summary to {csv_path}")