# Download STAC Assets

> This demo requires a few spatial libraries that can be installed with `pip install`: gdal, rasterio, pystac, and databricks-mosaic, which we use for data download and manipulation. We use Microsoft [Planetary Computer](https://planetarycomputer.microsoft.com/) (MPC) as the [STAC](https://stacspec.org/en) source of the raster data. __Note: Because we are using the free tier of MPC, downloads might be throttled.__

---
__Last Update:__ 18 JAN 2024 [Mosaic 0.3.14]

## Setup

<p/>

1. Import Databricks columnar functions (including H3) for DBR with `from pyspark.databricks.sql.functions import *`
2. To use Databricks Labs [Mosaic](https://databrickslabs.github.io/mosaic/index.html) library for geospatial data engineering, analysis, and visualization functionality:
  * Configure Init Script to install GDAL on your cluster [[1](https://databrickslabs.github.io/mosaic/usage/install-gdal.html)]
  * Install with `%pip install databricks-mosaic`
  * Import and use with the following:
  ```
  import mosaic as mos
  mos.enable_mosaic(spark, dbutils)
  mos.enable_gdal(spark)
  ```
<p/>

3. To use [KeplerGl](https://kepler.gl/) OSS library for map layer rendering:
  * Already installed with Mosaic, use `%%mosaic_kepler` magic [[Mosaic Docs](https://databrickslabs.github.io/mosaic/usage/kepler.html)]
  * Import with `from keplergl import KeplerGl` to use directly

If you have trouble with Volume access:

* For Mosaic 0.3 series (< DBR 13)     - you can copy resources to DBFS as a workaround
* For Mosaic 0.4 series (DBR 13.3 LTS) - you will need to either copy resources to DBFS or set up Unity Catalog + Shared Access, which will involve your workspace admin. Instructions, as updated, will be [here](https://databrickslabs.github.io/mosaic/usage/install-gdal.html).

The search and download phase was run on AWS [m5d.xlarge](https://www.databricks.com/product/pricing/product-pricing/instance-types) instances (2-16 workers auto-scaling for up to 64 concurrent downloads).

### Imports + Config

In [0]:
%pip install --quiet 'databricks-mosaic<0.4,>=0.3'
%pip install --quiet rasterio==1.3.5 gdal==3.4.3 pystac pystac_client planetary_computer tenacity rich pandas==1.5.3

In [0]:
# -- configure AQE for more compute heavy operations
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", False)
spark.conf.set("spark.sql.shuffle.partitions", 512)

# -- import databricks + delta + spark functions
from delta.tables import *
from pyspark.databricks.sql import functions as dbf
from pyspark.sql import functions as F
from pyspark.sql.functions import col
from pyspark.sql.types import *

# -- setup mosaic
import mosaic as mos

mos.enable_mosaic(spark, dbutils)
mos.enable_gdal(spark)

# -- other imports
from datetime import datetime
import library
import os
import pathlib
import planetary_computer
import pystac_client
import requests
import warnings

warnings.simplefilter("ignore")

In [0]:
mos.__version__

In [0]:
%reload_ext autoreload
%autoreload 2
%reload_ext library

### Databricks Catalog + Schema

> This is for writing out table(s).

In [0]:
# adjust to your preferred catalog + schema
catalog_name = "geospatial_docs"
db_name = "eo_alaska"

sql(f"""USE CATALOG {catalog_name}""")

# uncomment to cleanup prior
# sql(f"""DROP DATABASE IF EXISTS {db_name} CASCADE""")

sql(f"""CREATE DATABASE IF NOT EXISTS {db_name}""")
sql(f"""USE DATABASE {db_name}""")

### Data `ETL_DIR`

In [0]:
# Adjust this path to suit your needs...
user_name = dbutils.notebook.entry_point.getDbutils().notebook().getContext().userName().get()

ETL_DIR = f"/home/{user_name}/stac/eo-series"
ETL_DIR_FUSE = f"/dbfs{ETL_DIR}"  # <- ETL_DIR already starts with '/'

os.environ['ETL_DIR'] = ETL_DIR
os.environ['ETL_DIR_FUSE'] = ETL_DIR_FUSE

# dbutils.fs.rm(ETL_DIR, True) # <- uncomment to clean out
dbutils.fs.mkdirs(ETL_DIR)
print(f"...ETL_DIR: '{ETL_DIR}', ETL_DIR_FUSE: '{ETL_DIR_FUSE}' (create)")

## Alaska STAC Asset Download

> We can easily extract the download links for items of interest. In this case, we will grab data within Alaska. Note: due to limitations in the [free tier for Planetary Computer](https://planetarycomputer.microsoft.com/docs/concepts/sas/), we will not attempt to get all available data within our time range. 

In [0]:
# Set EO_DIR to data/alaska subfolder of ETL_DIR
EO_DIR = f"{ETL_DIR}/data/alaska"
EO_DIR_FUSE = f"/dbfs{EO_DIR}"

os.environ['EO_DIR'] = EO_DIR
os.environ['EO_DIR_FUSE'] = EO_DIR_FUSE

# dbutils.fs.rm(EO_DIR, True) # <- uncomment to clean out
dbutils.fs.mkdirs(EO_DIR)
print(f"...EO_DIR: '{EO_DIR}', EO_DIR_FUSE: '{EO_DIR_FUSE}' (create)")

@udf(returnType=IntegerType())
def file_size(file_path):
  """
  Return file_size or null.
  - must exist and be a file
  """
  import os

  if os.path.exists(file_path) and os.path.isfile(file_path):
    return os.path.getsize(file_path)
  else:
    return None


@udf(returnType=StringType())
def timestamp_filename(dt):
  """
  Convert a passed timestamp to a filename friendly output.
  - return looks like 20230131-092030
  """
  from datetime import datetime

  if dt is None:
    return None
  return dt.strftime(library.FILENAME_TIMESTAMP_FORMAT)

def get_now_formatted():
  """
  Use for last update.
  - this is same as used in `timestamp_filename`
  """
  return datetime.now().strftime(library.FILENAME_TIMESTAMP_FORMAT)

def download_band(
    eod_items, band_name, is_append_mode, tbl_prefix="band", eo_dir=EO_DIR, repartition_factor=5,
    do_clean_files=False, do_download=True, do_table_write=True
  ):
  """
  Download band into table.
  - sets the 'last_update'
  - assumes databricks catalog and schema already set
  - 'is_append_mode' selects append (True) vs overwrite (False) for the table
  - default is 'do_table_write=True'
  - default is 'do_download=True'
  - default is 'do_clean_files=False'
  - filenames are '{band_name}_{item_id}_{timestamp}.tif'
  Returns dataframe either from table or generated.
  !!! If you do not write, the returned dataframe will have not yet been executed !!!

  Notes:
  [a] It can take some time to clean files; this will be everything in the <band_name> dir
  [b] It can take some time to download files per band, especially in MPC free tier
  [c] If not doing table write, then the dataframe will not have forced execution on every row,
      you will have to manage that as the caller
  [d] You can change the table prefix to be more isolated for a given analytic / time filter
  """
  _eod_items = eod_items.filter(f"asset.name == '{band_name}'")
  orig_repart_num = spark.conf.get("spark.sql.shuffle.partitions")
  repart_num = max(1, round(_eod_items.count() / repartition_factor)) # <- never fewer than 1 partition
  spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", False) # <- option-2: just tweak partition management
  spark.conf.set("spark.sql.shuffle.partitions", repart_num)
  print(f"\t...shuffle partitions to {repart_num} for this operation.")
  try:
    last_updated = get_now_formatted()

    tbl_name = f"{tbl_prefix}_{band_name}"
    eo_dir_band = f"{eo_dir}/{band_name}"

    # [1] what are we downloading?
    #     - see that asset_name is static
    #     - band name will be used below
    to_download = (
      _eod_items
        .repartition(repart_num)
        .groupBy("item_id", "timestamp")
          .agg(
            F.sort_array(F.collect_set("h3")).alias("h3_set"),
            *[F.first(cn).alias(cn) for cn in eod_items.columns if cn not in ["item_id", "timestamp", "h3", "geojson"]]
          )
          .withColumn(
            "band_name",
            F.lit(band_name)
          )
          .withColumn(
            "out_dir_fuse",
            F.lit(f"/dbfs{eo_dir_band}")
          )
          .withColumn(
            "out_filename",
            F.concat(
              col("band_name"), F.lit("_"), col("item_id"),
              F.lit("_"), timestamp_filename("timestamp"), F.lit(".tif")
            )
        )
        .withColumn(
          "last_update",
          F.lit(last_updated)
        )  
    )

    # [3] clean files?
    if do_clean_files:
      dbutils.fs.rm(eo_dir_band, True)

    # [4] do download?
    if do_download:
      to_download = (
        to_download
          .withColumn(
            "out_file_path", 
            library.download_asset(
              "item_id",
              "band_name",
              "out_dir_fuse",
              "out_filename"
              )
          )
          .withColumn(
            "out_file_sz",
            file_size("out_file_path")
          )
          .withColumn(
            "is_out_file_valid",
            F
              .when(F.isnull("out_file_sz"), F.lit(False))
              .when(col("out_file_sz") > F.lit(library.FILE_SIZE_THRESHOLD), F.lit(True))
              .otherwise(F.lit(False))
          )
      )
    else:
      to_download = (
        to_download
          .withColumn(
            "out_file_path", 
            F.concat(col("out_dir_fuse"), F.lit("/"), col("out_filename")) # <- path set manually
          )
          .withColumn(
            "out_file_sz",
            F.lit(None).cast("integer") # <- sz unknown
          )
          .withColumn(
            "is_out_file_valid",
            F.lit(None).cast("boolean") # <- validity unknown
          )
      )
  
    # [5] write table?
    # [6] append mode?
    if do_table_write:
      write_mode = "append"
      if not is_append_mode:
        write_mode = "overwrite"
        spark.sql(f"DROP TABLE IF EXISTS {tbl_name}")

      (
        to_download
          .write
            .mode(write_mode)
          .saveAsTable(tbl_name)
      )

      # - return dataframe of written table
      #   !!! this is fully executed !!!
      return spark.table(tbl_name)
    else:
      # - return unwritten dataframe
      #   !!! this is not executed yet !!!
      return to_download
  finally:
    # print(f"...setting shuffle partitions back to {orig_repart_num}")
    spark.conf.set("spark.sql.shuffle.partitions", orig_repart_num)

In [0]:
CELL_ASSET_DIR_FUSE = None
CELL_ASSET_DIR = None
LAST_UPDATED = None

# - previous write?
for r in os.listdir(EO_DIR_FUSE):
  if r.startswith("cell_assets_"):
    LAST_UPDATED = r.split('_')[-1]
    CELL_ASSET_DIR_FUSE = f"{EO_DIR_FUSE}/{r}"
    CELL_ASSET_DIR = CELL_ASSET_DIR_FUSE.replace('/dbfs','')
    break

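# - only export when a previous write was found (os.environ values must be str)
if CELL_ASSET_DIR is not None:
  os.environ['CELL_ASSET_DIR'] = CELL_ASSET_DIR
  os.environ['CELL_ASSET_DIR_FUSE'] = CELL_ASSET_DIR_FUSE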

print(f"LAST_UPDATED: '{LAST_UPDATED}' [from working location] ->")
print(f"\tCELL_ASSET_DIR: '{CELL_ASSET_DIR}'")
print(f"\tCELL_ASSET_DIR_FUSE: '{CELL_ASSET_DIR_FUSE}'")

In [0]:
ls -ls $EO_DIR_FUSE

In [0]:
eod_item_df = spark.read.load(CELL_ASSET_DIR)
print(f"count? {eod_item_df.count():,}")
eod_item_df.limit(10).display()

_Notice that some assets overlap more than one h3 cellid._ __Function `download_band` consolidates to unique 'item_id' values, rather than working per h3 cell, to avoid repeated download requests.__
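
The consolidation described here can be illustrated without Spark. The following is a minimal plain-Python sketch, not the notebook's implementation: the row layout, the `consolidate_by_item` helper, and the sample ids and URLs are all hypothetical. It shows the same idea as the `groupBy("item_id", ...)` with `collect_set("h3")` inside `download_band`: many (item, cell) rows collapse to one download per item.

```python
from collections import defaultdict

# Hypothetical rows shaped like (item_id, h3 cell id, asset href); values are made up.
rows = [
    ("item_A", 604189641121202175, "https://example.com/A_B02.tif"),
    ("item_A", 604189641255419903, "https://example.com/A_B02.tif"),
    ("item_B", 604189641121202175, "https://example.com/B_B02.tif"),
]

def consolidate_by_item(rows):
    """Group rows by item_id, collecting the sorted set of h3 cells per item."""
    grouped = defaultdict(lambda: {"h3_set": set(), "href": None})
    for item_id, h3, href in rows:
        entry = grouped[item_id]
        entry["h3_set"].add(h3)
        # Keep the first href seen, similar in spirit to F.first(...) in the Spark code.
        if entry["href"] is None:
            entry["href"] = href
    return {
        item_id: {"h3_set": sorted(e["h3_set"]), "href": e["href"]}
        for item_id, e in grouped.items()
    }

to_download = consolidate_by_item(rows)
print(len(rows), "rows ->", len(to_download), "downloads")
```

Each item is requested once even though it covers several cells, and the cell ids are kept alongside it so the mapping back to h3 is not lost.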

In [0]:
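# - bands of interest for this series (see the download section below)
bands = ['B02', 'B03', 'B04', 'B08']

# - notice multiple h3 cells for some item ids
display(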
  eod_item_df
    .filter("item_id == 'S2A_MSIL2A_20210601T204021_R014_T07VEK_20210602T071624'")
    .filter(
        f"asset.name == '{bands[0]}'"
      ) 
    .orderBy("h3")
)

### Download bands for items 

> This will generate a table per band in the specified catalog and schema (set earlier in the notebook). __Note: We are invoking with `do_clean_files=False` to avoid wiping out already downloaded files; also, passing `False` for 'is_append_mode' param on the table generation side, but you can pass `True` to change.__

#### First - a Dry-Run

> Add all the columns that are added in the "live" execution, but we are specifying no actual execution (sanity check). __Note: we have nulls for `out_file_sz` (size) and `is_out_file_valid` since the files were not actually downloaded or checked.__

In [0]:
# - First, a do-nothing dry-run for sanity check...
# !!! NOTICE: do_clean_files, do_download, and do_table_write all FALSE !!!
b_df_dry = download_band(eod_item_df, bands[0], False, do_clean_files=False, do_download=False, do_table_write=False)
display(b_df_dry.limit(10))

#### Download bands of interest

> For this example series, we focus on B04 (red), B03 (green), B02 (blue), and B08 (nir). __You can easily download all / more.__

__Download Just 'B02'__

> We have `do_clean_files=False` to not overwrite any existing data (for repeated execution). The 'band_b02' metadata table is set to be overwritten with `is_append_mode` set to `False`. __Note: you can adjust this to append vs overwrite.__ Also, it is ok to interrupt and restart execution: files are first checked to see whether they have already been downloaded, which avoids unnecessary IOPS.

In [0]:
download_band(eod_item_df, 'B02', False, do_clean_files=False, do_download=True, do_table_write=True)

_Look at the band table generated for B02 [blue]._

In [0]:
%sql SELECT * from band_b02 limit 10

_Look at a couple of the band GeoTIFFs._

In [0]:
ex_bands = [t[0] for t in spark.table('band_b02').select('out_file_path').limit(2).collect()]
display(
  spark.table("band_b02")
  .where(
    col("out_file_path").isin(ex_bands)
  )
)

In [0]:
for b in ex_bands:
  library.plot_file(b)

In [0]:
print(f"""b02 total count? {sql("select format_number(count(1),0) from band_b02").first()[0]}""")
print(f"""b02 valid count? {sql("select format_number(count(1),0) from band_b02 where is_out_file_valid").first()[0]}""")
print(f"""b02 false count? {sql("select format_number(count(1),0) from band_b02 where is_out_file_valid = False").first()[0]}""")
print(f"""b02 null count?  {sql("select format_number(count(1),0) from band_b02 where is_out_file_valid is null").first()[0]}""")

#### Optional: Fix Missing Data

> As a result of being gated in the free tier of Planetary Computer, a number of attempts to download band data might have resulted in an error message instead of the actual data (no failure condition provided). Here is what that might look like:

```
<?xml version="1.0" encoding="utf-8"?><Error><Code>AuthenticationFailed</Code><Message>Server failed to authenticate the request. Make sure the value of Authorization header is formed correctly including the signature.
RequestId:bf21b919-d01e-002f-4e00-2ae765000000
Time:2023-12-08T18:04:30.9365583Z</Message><AuthenticationErrorDetail>Signature not valid in the specified time frame: Start [Thu, 07 Dec 2023 17:18:15 GMT] - Expiry [Fri, 08 Dec 2023 18:03:15 GMT] - Current [Fri, 08 Dec 2023 18:04:30 GMT]</AuthenticationErrorDetail></Error>
```

The size is around 550 bytes, so we can test for this and smartly retry.
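
As a rough illustration of that test, here is a minimal sketch in plain Python. It is not the check used by the notebook's `library` module: the threshold value `ASSUMED_SIZE_THRESHOLD` and the helper `looks_like_tiff` are hypothetical, and the extra TIFF signature check goes beyond the size test described above.

```python
import os
import tempfile

# Assumption: any threshold comfortably above ~550 bytes works; the notebook's
# real value lives in library.FILE_SIZE_THRESHOLD, which is not shown here.
ASSUMED_SIZE_THRESHOLD = 1024

# TIFF and BigTIFF files start with one of these 4-byte signatures.
TIFF_MAGIC = (b"II*\x00", b"MM\x00*", b"II+\x00", b"MM\x00+")

def looks_like_tiff(path, min_size=ASSUMED_SIZE_THRESHOLD):
    """Return True only if the file exists, exceeds min_size, and starts with a TIFF signature."""
    if not os.path.isfile(path):
        return False
    if os.path.getsize(path) <= min_size:
        return False
    with open(path, "rb") as f:
        return f.read(4) in TIFF_MAGIC

# Demo with two synthetic files: an XML error payload and a fake TIFF header.
tmp_dir = tempfile.mkdtemp()
error_path = os.path.join(tmp_dir, "error.tif")
tiff_path = os.path.join(tmp_dir, "ok.tif")
with open(error_path, "wb") as f:
    f.write(b'<?xml version="1.0" encoding="utf-8"?><Error><Code>AuthenticationFailed</Code></Error>')
with open(tiff_path, "wb") as f:
    f.write(b"II*\x00" + b"\x00" * 4096)

print(looks_like_tiff(error_path))
print(looks_like_tiff(tiff_path))
```

Files that fail such a check are the candidates for a retry; the size test alone is what the band tables record in 'is_out_file_valid'.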

__Note:__ We are using Delta Lake MERGE support to update a given band table [[1](https://docs.databricks.com/en/delta/merge.html#language-python)].

__Since 'B02' (blue) now has all data, let's turn to 'B03' (green).__

> Initially, we set `do_download=False` and `do_table_write=True` to demonstrate how this table can be filled in with a subsequent call to `download_missing_assets(...)` or `update_assets(...)`. __Note: 'out_file_sz' and 'is_out_file_valid' are both set to `Null` because we have not yet calculated this information (e.g. there may be some pre-existing files).__

In [0]:
# - example of a table write without file download
display(
  download_band(eod_item_df, 'B03', False, do_clean_files=False, do_download=False, do_table_write=True)
    .limit(10)
)

In [0]:
%sql 
-- notice the table was written
-- but no files downloaded
select * from band_b03 limit 1

_Notice that all values of field 'is_out_file_valid' are `null`._

In [0]:
print(f"""b03 total count? {sql("select format_number(count(1),0) from band_b03").first()[0]}""")
print(f"""b03 valid count? {sql("select format_number(count(1),0) from band_b03 where is_out_file_valid = True").first()[0]}""")
print(f"""b03 false count? {sql("select format_number(count(1),0) from band_b03 where is_out_file_valid = False").first()[0]}""")
print(f"""b03 null count?  {sql("select format_number(count(1),0) from band_b03 where is_out_file_valid is null").first()[0]}""")

In [0]:
def update_assets(update_df, band_tbl_name):
  """
  (Re)download assets for the passed rows, then merge the results back into the band table.
  - This expects an existing band tbl_name generated from `download_band(...)`
  - This expects an update_df conforming to what is generated by `download_band(...)`
    to include from `download_missing_assets(...)`
  Returns a dataframe filtered to the merged data (from the band table).
  """
  from datetime import datetime
  
  last_updated = get_now_formatted()
  
  # [1] udf for download_asset
  # [2] re-calc size
  # [3] re-calc valid
  df = (
    update_df
      .drop(
        "last_update",
        "out_file_path", 
        "out_file_sz",
        "is_out_file_valid"
      )
      .withColumn(
        "last_update",
        F.lit(last_updated)
      )
      .withColumn(
        "out_file_path", 
        library.download_asset(
          "item_id",
          "band_name",
          "out_dir_fuse",
          "out_filename"
        )
      )
      .withColumn(
        "out_file_sz",
        file_size("out_file_path")
      )
      .withColumn(
        "is_out_file_valid",
        F
          .when(F.isnull("out_file_sz"), F.lit(False)) # <- null to False
          .when(col("out_file_sz") > F.lit(library.FILE_SIZE_THRESHOLD), F.lit(True))
          .otherwise(F.lit(False))
      )
  )

  # [4] merge changes back to original table
  delta_tbl_eod = DeltaTable.forName(spark, band_tbl_name)

  (
    delta_tbl_eod.alias("eod")
      .merge(
        df.alias("updates"),
        "eod.item_id = updates.item_id"
      ) 
      .whenMatchedUpdate(
        set = {
          "last_update": "updates.last_update",
          "out_file_path": "updates.out_file_path", 
          "out_file_sz": "updates.out_file_sz",
          "is_out_file_valid": "updates.is_out_file_valid"
        }
      )
    .execute()
  ) 
  
  # [5] return the changes
  # - from the 'band_tbl_name'
  return (
    spark.table(band_tbl_name)
      .filter(f"last_update == '{last_updated}'")
  )

def download_missing_assets(
    band_tbl_name, where_clause=None, do_dry_run=False
  ):
  """
  Download missing assets for band (from pre-existing table) and update the table.
  - Columns updated are 'out_file_path', 'out_file_sz', 'is_out_file_valid',
      and 'last_update'
  - Optional: 'where_clause' can be used to filter the table
  - Missing is based on 'is_out_file_valid' being False or Null
  - Assumes databricks catalog and schema already set
  - default is to update vs dry-run
  Returns the updated or dry run dataframe.
  """
  # - df from table name
  #   filter?
  df = spark.table(band_tbl_name)
  if where_clause is not None:
    df = df.filter(where_clause)
  
  # - filter 'is_out_file_valid' is False or Null
  df = df.filter(
    (col("is_out_file_valid") == False) |
    (F.isnull("is_out_file_valid"))
  )
  
  if not do_dry_run:
    # - handle missing
    return update_assets(df, band_tbl_name)
  else:
    # - just df for dry-run
    #   useful to test 'where_clause'
    return df

_A dry-run of a single item..._

In [0]:
display(
  download_missing_assets(
    "band_b03", 
    where_clause="item_id = 'S2A_MSIL2A_20210628T221531_R115_T03VXC_20210630T063438'", 
    do_dry_run=True
  )
)

_Actual run of a single item..._

In [0]:
display(
  download_missing_assets(
    "band_b03", 
    where_clause="item_id = 'S2A_MSIL2A_20210628T221531_R115_T03VXC_20210630T063438'", 
    do_dry_run=False 
  )
)

_The `where_clause` param is optional in `download_missing_assets(...)`; when it is not specified, all data where 'is_out_file_valid' is not True will be tested and (re)downloaded as needed._ __Note: merges can be more expensive, so test to see which option (including the one below) meets your needs.__
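
The selection rule can be sketched in plain Python. This is an illustration only: `select_missing` and its `predicate` argument are hypothetical stand-ins for the DataFrame filters and the SQL `where_clause` used by `download_missing_assets(...)`.

```python
def select_missing(rows, predicate=None):
    """
    Return rows whose 'is_out_file_valid' is False or None.
    - 'predicate' is a hypothetical stand-in for the SQL 'where_clause'
    """
    if predicate is not None:
        rows = [r for r in rows if predicate(r)]
    # Anything that is not explicitly True counts as missing.
    return [r for r in rows if r["is_out_file_valid"] is not True]

# Made-up rows covering the three possible states.
rows = [
    {"item_id": "item_A", "is_out_file_valid": True},
    {"item_id": "item_B", "is_out_file_valid": False},
    {"item_id": "item_C", "is_out_file_valid": None},
]

all_missing = select_missing(rows)
only_c = select_missing(rows, predicate=lambda r: r["item_id"] == "item_C")
print([r["item_id"] for r in all_missing])
print([r["item_id"] for r in only_c])
```

Treating `None` the same as `False` matters because tables written with `do_download=False` start with every row in the unknown state.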

### Download Any / All Other Bands

> This call will get all the bands; it can also be rerun to download any missing files. __Note: `is_append_mode` is `False`, meaning it will overwrite the current band table, but we set `do_clean_files=False` so it will just download new / missing files.__

_This downloads ~10K GeoTIFFs per band for the state of Alaska. The operation can take some time, depending on (1) the size of your cluster, (2) whether you are starting fresh, and (3) how much you are affected by the free tier limits / throttling._ Also, an executor on the cluster might sometimes hang due to the nature of these longer running jobs (and the accumulated delays from throttling). In the event that a task is hung for, say, 15+ minutes with, e.g., 511/512 tasks completed, it is ok to interrupt and restart execution. Files are first checked to see whether they have already been downloaded, so the operation can recover without too much duplicated time / processing.
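
The "check first, then download" behavior that makes restarts cheap can be sketched as follows. This is not the code of `library.download_asset`, which is not shown in this notebook: `download_if_missing`, the injected `fetch` callable, and the `min_size` value are all hypothetical.

```python
import os
import tempfile

def download_if_missing(out_path, fetch, min_size=1024):
    """
    Write fetch() to out_path unless a large-enough file is already there.
    - 'fetch' is a hypothetical zero-arg callable returning bytes
    - 'min_size' is an assumed threshold, not library.FILE_SIZE_THRESHOLD
    Returns True if a download happened, False if it was skipped.
    """
    if os.path.isfile(out_path) and os.path.getsize(out_path) > min_size:
        return False
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # Write to a temporary name first so an interrupted run never leaves
    # a partial file under the final name.
    tmp_path = out_path + ".part"
    with open(tmp_path, "wb") as f:
        f.write(fetch())
    os.replace(tmp_path, out_path)
    return True

# Fake fetcher that records how often it is called (no network involved).
calls = []
def fake_fetch():
    calls.append(1)
    return b"\x00" * 4096

out_path = os.path.join(tempfile.mkdtemp(), "B02", "B02_item_A.tif")
first = download_if_missing(out_path, fake_fetch)
second = download_if_missing(out_path, fake_fetch)
print(first, second, len(calls))
```

With a helper of this shape, rerunning a band only pays for the files that are absent or too small, which is the recovery property relied on above.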

In [0]:
# - run when ready (this can take a while)
for band in ['B02', 'B03', 'B04', 'B08']:
    print(f"working on band '{band}'...")
    download_band(eod_item_df, band, False, do_clean_files=False, do_download=True, do_table_write=True)

display(dbutils.fs.ls(f"{EO_DIR}"))

_Verify we have all data for bands._

In [0]:
for band in ['B02', 'B03', 'B04', 'B08']:
  print(f"::: '{band}' :::")
  df_band = spark.table(f"band_{band}")
  print(f"""\ttotal?            {df_band.count():,}""")
  print(f"""\tis valid?         {df_band.filter(F.expr("is_out_file_valid = True")).count():,}""")
  print(f"""\tnot valid?        {df_band.filter(F.expr("is_out_file_valid = False")).count():,}""")
  print(f"""\tunknown validity? {df_band.filter(F.expr("is_out_file_valid is null")).count():,}""")
  print("")

_See all the tables generated._

In [0]:
%sql show tables

_Look at the directories where the bands were downloaded._

In [0]:
display(dbutils.fs.ls(EO_DIR))