# Creation of TIR Landsat 8 Chips for Model Build

This notebook creates the Landsat 8 TIR Band 10 image chips for cement, steel, or landcover sites from the EarthAI catalog.

For final model build, we created:

1. 4 sets of steel plant chips for years 2020, 2019, 2018, and 2017
2. 3 sets of cement plant chips for years 2020, 2019, and 2018
3. 1 set of landcover chips for year 2020

Documentation on Landsat 8 L1TP:
https://prd-wret.s3.us-west-2.amazonaws.com/assets/palladium/production/atoms/files/LSDS-1656_%20Landsat_Collection1_L1_Product_Definition-v2.pdf

## Import required libraries

In [None]:
from earthai.all import *
import earthai.chipping.strategy as chp
import pyspark.sql.functions as F
import geopandas as gpd
import pandas as pd
import os
import shutil
import boto3

## Define input and output files and parameters

### Parameters

* `site_type` should be set to `'cement'`, `'steel'`, or `'landcover'`
* `chip_size` is the side length, in pixels, of the square chips to create
* `unmsk_frac` is the minimum fraction of unmasked cells a chip must have to be kept in the sample
* `year2` defines the year for layer 1 (thermal band, in January)
* `year1` defines the year for layers 2 and 3 (thermal band, in January and April, respectively)

The original model was hardcoded for `year2 = '2018'` and `year1 = '2017'`. Without further testing, it is recommended that `year2` be the year immediately after `year1` (both are strings here, so e.g. `year2 = str(int(year1) + 1)`).

In [None]:
site_type = 'landcover'

chip_size = 35 # 35 px x 30 m = 1.05 km per side for Landsat 8
unmsk_frac = 0.75

year2 = '2018'
year1 = '2017'
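As a quick check on the parameters above, the sketch below derives the chip footprint and the minimum number of unmasked cells implied by `chip_size` and `unmsk_frac`, and warns if the recommended year relationship is not met. It assumes the 30 m grid on which Landsat 8 Band 10 is delivered; `check_params` is a hypothetical helper, not part of EarthAI.

```python
import math
import warnings

L8_PIXEL_M = 30  # assumption: Band 10 is delivered resampled to a 30 m grid
VALID_SITE_TYPES = {'cement', 'steel', 'landcover'}

def check_params(site_type, chip_size, unmsk_frac, year2, year1):
    """Hypothetical helper: validate the notebook parameters and derive chip statistics."""
    if site_type not in VALID_SITE_TYPES:
        raise ValueError(f"site_type must be one of {sorted(VALID_SITE_TYPES)}")
    if not 0.0 <= unmsk_frac <= 1.0:
        raise ValueError("unmsk_frac must be between 0 and 1")
    if int(year2) != int(year1) + 1:
        warnings.warn(f"year2={year2} is not the year after year1={year1}")
    total_cells = chip_size * chip_size
    return {
        'chip_side_km': chip_size * L8_PIXEL_M / 1000,
        'total_cells': total_cells,
        # chips are kept when unmsk_cell_count >= unmsk_frac * total_cells,
        # so the smallest passing integer count is the ceiling of that product
        'min_unmasked_cells': math.ceil(unmsk_frac * total_cells),
    }

print(check_params('landcover', 35, 0.75, '2018', '2017'))
```

With the defaults above, a 35-pixel chip spans 1.05 km per side and contains 1225 cells, so a chip needs at least 919 unmasked cells to pass the 0.75 threshold.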

### Input files

* `site_geojson` is a GeoJSON specifying the site locations (Points)

In [None]:
if site_type == 'landcover':
    gjson_prefix = 'lc'
else:
    gjson_prefix = site_type
site_geojson = "../../resources/macro-loc-model-build/"+gjson_prefix+"_exact_china_v4.1.geojson"

### Output files and paths

* `output_path` names the scratch directory that chip GeoTIFFs are written to (`/scratch/<output_path>`) and the tar file (`<output_path>.tar`) uploaded to S3
* `s3_path` is the top-level S3 folder (key prefix) for the L8 TIR macro-localization data
* `filename_append` is appended to each chip file name
* `chip_extents_gjson` is an output GeoJSON file with chip metadata and tile extents

In [None]:
output_path = 'ALD_L8_TIR_'+site_type+'_chips_v4p1_'+year2+'_train3'
s3_path = 'L8-TIR-macro-localization-model-build3'

filename_append = 'v4p1_'+year2+'_L8_TIR'
chip_extents_gjson = '../../resources/macro-loc-model-build/'+output_path+'.geojson'
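For reference, the sketch below reproduces the string concatenations above (plus the S3 key and per-chip file name built later in the notebook) for one set of parameters, so the resulting names can be checked before a long run. `build_names` is a hypothetical helper, and the `uid` value 42 is illustrative only.

```python
def build_names(site_type, year2, uid):
    """Hypothetical helper mirroring the notebook's path and file-name construction."""
    gjson_prefix = 'lc' if site_type == 'landcover' else site_type
    resources = '../../resources/macro-loc-model-build/'
    output_path = 'ALD_L8_TIR_' + site_type + '_chips_v4p1_' + year2 + '_train3'
    s3_path = 'L8-TIR-macro-localization-model-build3'
    filename_append = 'v4p1_' + year2 + '_L8_TIR'
    return {
        'site_geojson': resources + gjson_prefix + '_exact_china_v4.1.geojson',
        'output_path': output_path,
        'chip_extents_gjson': resources + output_path + '.geojson',
        's3_key': s3_path + '/' + output_path + '.tar',
        # mirrors F.concat_ws('_', uid, site_type, filename_append)
        'chip_file_name': '_'.join([str(uid), site_type, filename_append]),
    }

for key, value in build_names('landcover', '2018', 42).items():
    print(f"{key}: {value}")
```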

## Load site location point data

In [None]:
site_gdf = gpd.read_file(site_geojson)
print("Total count of sites: ", len(site_gdf))

## Get catalog of Landsat 8 scenes that intersect with sites

* Queries EarthAI Catalog to find L8 scenes that intersect with sites
* Returns all scenes for:
    * January Year 2
    * January Year 1
    * April Year 1
* Joins the results back to the site location data for chipping

Below, we do NOT impose a maximum cloud cover filter. Because the chips are small, a scene with high overall cloud
cover may still be clear over the region we need. The highest-quality scenes are selected after the masking
steps below.

### January Year 2

In [None]:
site_cat_year2_01 = earth_ondemand.read_catalog(
    geo=site_gdf,
    start_datetime=year2+'-01-01', 
    end_datetime=year2+'-01-31',
    max_cloud_cover=100,
    collections='landsat8_l1tp')
site_cat_year2_01 = gpd.sjoin(site_gdf, site_cat_year2_01)

### January Year 1

In [None]:
site_cat_year1_01 = earth_ondemand.read_catalog(
    site_gdf,
    start_datetime=year1+'-01-01', 
    end_datetime=year1+'-01-31',
    max_cloud_cover=100,
    collections='landsat8_l1tp'
)
site_cat_year1_01 = gpd.sjoin(site_gdf, site_cat_year1_01)

### April Year 1

In [None]:
site_cat_year1_04 = earth_ondemand.read_catalog(
    site_gdf,
    start_datetime=year1+'-04-01', 
    end_datetime=year1+'-04-30',
    max_cloud_cover=100,
    collections='landsat8_l1tp'
)
site_cat_year1_04 = gpd.sjoin(site_gdf, site_cat_year1_04)
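The three catalog queries above differ only in their date window. A small helper such as the hypothetical `month_window` below (standard library only) computes the first and last day of any month, avoiding hand-typed end dates such as `'-04-30'`; its two return values could be passed as `start_datetime` and `end_datetime`.

```python
import calendar

def month_window(year, month):
    """Return ISO date strings for the first and last day of a month."""
    year = int(year)
    last_day = calendar.monthrange(year, month)[1]  # (weekday of day 1, days in month)
    return f"{year:04d}-{month:02d}-01", f"{year:04d}-{month:02d}-{last_day:02d}"

# The three windows used in this notebook, for year2 = '2018' and year1 = '2017'
for label, (yr, mo) in {'JY2': ('2018', 1), 'JY1': ('2017', 1), 'AY1': ('2017', 4)}.items():
    start, end = month_window(yr, mo)
    print(label, start, end)
```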

## Initialize Spark

Set the number of partitions proportional to the size of the January Year 2 catalog (about one partition per four site-scene pairs).

In [None]:
partitions = max(1, round(len(site_cat_year2_01) / 4))  # at least 1 partition for small catalogs
spark = create_earthai_spark_session(**{
    "spark.default.parallelism": partitions,
    "spark.sql.shuffle.partitions": partitions,
})

## Read and create image chips for sites

* Uses the chip reader to create uniform, same-sized BQA chips centered on each site
* Filters out chips smaller than the specified size, a rare edge case that occurs when a site falls at the edge of a scene
* Filters out chips containing fill (BQA value of 1), such as blank areas at scene edges
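The two pre-filters in the cells below can be illustrated on plain arrays. This NumPy sketch is an approximation, not the RasterFrames code path: it models a chip as a 2-D array of BQA values in which `NaN` stands in for NoData, and it treats a BQA value of 1 as designated fill, which is why the notebook keeps only chips whose minimum BQA value exceeds 1. `passes_prefilters` is a hypothetical name.

```python
import numpy as np

def passes_prefilters(bqa_chip, chip_size):
    """Approximate the notebook's two chip pre-filters on a 2-D array of BQA values.

    Assumption: NaN marks NoData cells, standing in for rf_data_cells, which
    counts only cells that are not NoData.
    """
    chip = np.asarray(bqa_chip, dtype=float)
    tot_cell_count = np.count_nonzero(~np.isnan(chip))
    if tot_cell_count != chip_size * chip_size:  # e.g. chip truncated at a scene edge
        return False
    return bool(np.nanmin(chip) > 1.0)  # a single fill cell (BQA == 1) rejects the chip

clear = np.full((3, 3), 2720.0)  # 2720: a typical clear-sky BQA value
with_fill = clear.copy()
with_fill[0, 0] = 1.0            # one designated-fill cell
truncated = np.full((2, 3), 2720.0)
print(passes_prefilters(clear, 3), passes_prefilters(with_fill, 3), passes_prefilters(truncated, 3))
# True False False
```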

### January Year 2

In [None]:
site_chip_year2_01 = spark.read.chip(site_cat_year2_01, ['BQA'],
                                    chipping_strategy=chp.CentroidCentered(chip_size)) \
                         .select('uid', 'id', 'BQA') \
                         .withColumn('mask', rf_make_constant_tile(1, chip_size, chip_size, 'uint16')) \
                         .withColumn('tot_cell_count', rf_data_cells('BQA')) \
                         .filter(F.col('tot_cell_count') == chip_size*chip_size) \
                         .withColumn('BQA_min', rf_tile_min('BQA')) \
                         .filter(F.col('BQA_min') > 1.0) \
                         .repartition(partitions, 'uid', 'id')

### January Year 1

In [None]:
site_chip_year1_01 = spark.read.chip(site_cat_year1_01, ['BQA'],
                                    chipping_strategy=chp.CentroidCentered(chip_size)) \
                         .select('uid', 'id', 'BQA') \
                         .withColumn('mask', rf_make_constant_tile(1, chip_size, chip_size, 'uint16')) \
                         .withColumn('tot_cell_count', rf_data_cells('BQA')) \
                         .filter(F.col('tot_cell_count') == chip_size*chip_size) \
                         .withColumn('BQA_min', rf_tile_min('BQA')) \
                         .filter(F.col('BQA_min') > 1.0) \
                         .repartition(partitions, 'uid', 'id')

### April Year 1

In [None]:
site_chip_year1_04 = spark.read.chip(site_cat_year1_04, ['BQA'],
                                    chipping_strategy=chp.CentroidCentered(chip_size)) \
                         .select('uid', 'id', 'BQA') \
                         .withColumn('mask', rf_make_constant_tile(1, chip_size, chip_size, 'uint16')) \
                         .withColumn('tot_cell_count', rf_data_cells('BQA')) \
                         .filter(F.col('tot_cell_count') == chip_size*chip_size) \
                         .withColumn('BQA_min', rf_tile_min('BQA')) \
                         .filter(F.col('BQA_min') > 1.0) \
                         .repartition(partitions, 'uid', 'id')

## Select highest quality chips per site

* Mask chips by QA band and compute count of unmasked cells
* Remove chips with less than a minimum fraction of unmasked cells
* For each site, keep the chip with the highest number of unmasked cells

### Mask by QA band

* Landsat 8 Collection 1 Tier 1 QA band description: https://www.usgs.gov/land-resources/nli/landsat/landsat-collection-1-level-1-quality-assessment-band?qt-science_support_page_related_con=0#qt-science_support_page_related_con
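* In order to apply a mask, the tile being masked must have a NoData value defined. Landsat 8 measurement bands have a cell type of `uint16raw`, which indicates that there is no NoData value defined. For this reason, the `mask` column created above is a constant tile of ones with cell type `uint16`, whose NoData value is 0; each masking step sets flagged cells to NoData, so `rf_data_cells('mask')` counts the unmasked cells. The B10 band is likewise converted to `uint16` before normalization below, which causes any zero-valued cells in the measurement band to be considered NoData. In Landsat 8, these areas correspond to the BQA fill areas.
* Remove chips with fewer than `unmsk_frac * chip_size * chip_size` unmasked cells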
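To make the bit arithmetic concrete, the NumPy sketch below reproduces the masking rules used in the cells below on raw BQA values: designated fill (bit 0), cloud (bit 4), and medium or high confidence (binary `10` or `11`, i.e. values 2 or 3) in the two-bit cloud-shadow (bits 7-8) and cirrus (bits 11-12) fields. It illustrates the logic assuming the Collection 1 BQA bit layout in the USGS description linked above; it is not a replacement for `rf_mask_by_bit`/`rf_mask_by_bits`, and `bqa_mask` and `keeps_chip` are hypothetical names.

```python
import numpy as np

def extract_bits(qa, start_bit, num_bits):
    """Integer value of `num_bits` bits beginning at `start_bit`."""
    return (qa >> start_bit) & ((1 << num_bits) - 1)

def bqa_mask(qa):
    """True where a Landsat 8 Collection 1 BQA value would be masked."""
    qa = np.asarray(qa, dtype=np.uint16)
    fill = extract_bits(qa, 0, 1) == 1                 # designated fill = yes
    cloud = extract_bits(qa, 4, 1) == 1                # cloud = yes
    shadow = np.isin(extract_bits(qa, 7, 2), [2, 3])   # cloud shadow conf medium/high
    cirrus = np.isin(extract_bits(qa, 11, 2), [2, 3])  # cirrus conf medium/high
    return fill | cloud | shadow | cirrus

def keeps_chip(bqa_chip, unmsk_frac):
    """Apply the unmasked-fraction threshold to a chip of BQA values."""
    bqa_chip = np.asarray(bqa_chip)
    unmsk_cell_count = np.count_nonzero(~bqa_mask(bqa_chip))
    return bool(unmsk_cell_count >= unmsk_frac * bqa_chip.size)

# clear, fill, cloud, high-confidence cloud shadow, high-confidence cirrus
print(bqa_mask([2720, 1, 2800, 2976, 6816]))
```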

#### January Year 2

In [None]:
site_chip_year2_01 = site_chip_year2_01.withColumn('mask', # designated fill = yes
                                                 rf_mask_by_bit('mask', 'BQA', 0, 1)) \
                                     .withColumn('mask', # cloud = yes
                                                 rf_mask_by_bit('mask', 'BQA', 4, 1)) \
                                     .withColumn('mask', # cloud shadow conf is medium or high
                                                 rf_mask_by_bits('mask', 'BQA', 7, 2, [2, 3])) \
                                     .withColumn('mask', # cirrus conf is medium or high
                                                 rf_mask_by_bits('mask', 'BQA', 11, 2, [2, 3])) \
                                     .withColumn('unmsk_cell_count', rf_data_cells('mask')) \
                                     .filter(F.col('unmsk_cell_count') >= unmsk_frac*chip_size*chip_size)

#### January Year 1

In [None]:
site_chip_year1_01 = site_chip_year1_01.withColumn('mask', # designated fill = yes
                                                 rf_mask_by_bit('mask', 'BQA', 0, 1)) \
                                     .withColumn('mask', # cloud = yes
                                                 rf_mask_by_bit('mask', 'BQA', 4, 1)) \
                                     .withColumn('mask', # cloud shadow conf is medium or high
                                                 rf_mask_by_bits('mask', 'BQA', 7, 2, [2, 3])) \
                                     .withColumn('mask', # cirrus conf is medium or high
                                                 rf_mask_by_bits('mask', 'BQA', 11, 2, [2, 3])) \
                                     .withColumn('unmsk_cell_count', rf_data_cells('mask')) \
                                     .filter(F.col('unmsk_cell_count') >= unmsk_frac*chip_size*chip_size)

#### April Year 1

In [None]:
site_chip_year1_04 = site_chip_year1_04.withColumn('mask', # designated fill = yes
                                                 rf_mask_by_bit('mask', 'BQA', 0, 1)) \
                                     .withColumn('mask', # cloud = yes
                                                 rf_mask_by_bit('mask', 'BQA', 4, 1)) \
                                     .withColumn('mask', # cloud shadow conf is medium or high
                                                 rf_mask_by_bits('mask', 'BQA', 7, 2, [2, 3])) \
                                     .withColumn('mask', # cirrus conf is medium or high
                                                 rf_mask_by_bits('mask', 'BQA', 11, 2, [2, 3])) \
                                     .withColumn('unmsk_cell_count', rf_data_cells('mask')) \
                                     .filter(F.col('unmsk_cell_count') >= unmsk_frac*chip_size*chip_size)

### Select best-quality chips

For each date:
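* Find the chip(s) with the highest number of unmasked cells
* If there is more than one (a tie), keep the first record after sorting; the order among tied chips is not guaranteed, so any one of them may be kept
* Read in the thermal band (B10) for the highest-quality chip
* Normalize each chip to a min/max range of 0 to 65535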
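The selection step (sort by unmasked count, group by site, keep the first row) can be tried on a toy table. This pandas sketch uses made-up `uid`/`id` values and shows why the notebook groups on a copy of `uid` (`grpid`): grouping moves the key into the index, so the copy keeps `uid` as an ordinary column for the later merge on `['uid', 'id']`. Passing `kind='mergesort'` (a stable sort) is a change from the notebook that makes ties resolve to the earlier row.

```python
import pandas as pd

# Toy chip table: one row per (site uid, scene id) candidate chip
chpinf_pdf = pd.DataFrame({
    'uid': [1, 1, 1, 2, 2],
    'id': ['sceneA', 'sceneB', 'sceneC', 'sceneA', 'sceneD'],
    'unmsk_cell_count': [900, 1225, 1225, 1000, 950],
})
chpinf_pdf['grpid'] = chpinf_pdf['uid']  # copy so 'uid' survives as a column

best = (chpinf_pdf.sort_values('unmsk_cell_count', ascending=False, kind='mergesort')
                  .groupby('grpid').first()
                  .drop('unmsk_cell_count', axis=1))
print(best)
```

Site 1 has a tie between `sceneB` and `sceneC`; with the stable sort, `sceneB` (the earlier row) is kept.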

#### January Year 2

In [None]:
chpinf_year2_01_pdf = site_chip_year2_01.select('uid', 'id', 'unmsk_cell_count').toPandas()
chpinf_year2_01_pdf['grpid'] = chpinf_year2_01_pdf['uid']  # group on a copy so 'uid' stays a column for the merge

In [None]:
site_year2_01_maxcnt = chpinf_year2_01_pdf.sort_values('unmsk_cell_count', ascending=False) \
                                        .groupby(['grpid']).first() \
                                        .drop('unmsk_cell_count', axis=1)
site_cat_year2_01 = site_cat_year2_01.merge(site_year2_01_maxcnt, on=['uid', 'id'], how='inner')

In [None]:
site_chip_year2_01_unq = spark.read.chip(site_cat_year2_01, ['B10'],
                                        chipping_strategy=chp.CentroidCentered(chip_size)) \
                             .select('uid', 'id', 'datetime', 'B10') \
                             .withColumn('B10_JY2', 
                                         rf_convert_cell_type(rf_local_multiply(rf_rescale(rf_convert_cell_type('B10', 'uint16')), 
                                                                                65535), 'uint16')) \
                             .drop('B10') \
                             .withColumnRenamed('id', 'id_JY2') \
                             .withColumnRenamed('datetime', 'datetime_JY2') \
                             .repartition(partitions, 'uid')
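The normalization expression above converts B10 to `uint16` (so zero-valued fill cells become NoData), rescales each chip to 0-1 with `rf_rescale`, multiplies by 65535, and converts back to `uint16`. The NumPy sketch below mirrors that per-chip min-max stretch under stated assumptions: `rf_rescale` uses each chip's own minimum and maximum when none are given, zeros are excluded as NoData, and values are rounded to the nearest integer (the exact rounding of the RasterFrames cell-type conversion is not confirmed here). `stretch_to_uint16` is a hypothetical name.

```python
import numpy as np

def stretch_to_uint16(b10_chip):
    """Per-chip min-max stretch of raw B10 DNs to 0..65535, treating 0 as NoData."""
    chip = np.asarray(b10_chip, dtype=np.float64)
    valid = chip != 0                      # 0 is NoData once the cell type is uint16
    out = np.zeros(chip.shape, dtype=np.uint16)
    if not valid.any():
        return out
    lo, hi = chip[valid].min(), chip[valid].max()
    if hi == lo:                           # flat chip: this sketch returns zeros
        return out
    scaled = (chip[valid] - lo) / (hi - lo) * 65535
    out[valid] = np.round(scaled).astype(np.uint16)
    return out

print(stretch_to_uint16([[20000, 22000], [24000, 0]]))
```

One consequence worth checking in the real output: the chip minimum is stretched to 0, which is also the NoData value for `uint16`.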

#### January Year 1

In [None]:
chpinf_year1_01_pdf = site_chip_year1_01.select('uid', 'id', 'unmsk_cell_count').toPandas()
chpinf_year1_01_pdf['grpid'] = chpinf_year1_01_pdf['uid']  # group on a copy so 'uid' stays a column for the merge

In [None]:
site_year1_01_maxcnt = chpinf_year1_01_pdf.sort_values('unmsk_cell_count', ascending=False) \
                                        .groupby(['grpid']).first() \
                                        .drop('unmsk_cell_count', axis=1)
site_cat_year1_01 = site_cat_year1_01.merge(site_year1_01_maxcnt, on=['uid', 'id'], how='inner')

In [None]:
site_chip_year1_01_unq = spark.read.chip(site_cat_year1_01, ['B10'],
                                        chipping_strategy=chp.CentroidCentered(chip_size)) \
                             .select('uid', 'id', 'datetime', 'B10') \
                             .withColumn('B10_JY1', 
                                         rf_convert_cell_type(rf_local_multiply(rf_rescale(rf_convert_cell_type('B10', 'uint16')), 
                                                                                65535), 'uint16')) \
                             .drop('B10') \
                             .withColumnRenamed('id', 'id_JY1') \
                             .withColumnRenamed('datetime', 'datetime_JY1') \
                             .repartition(partitions, 'uid')

#### April Year 1

In [None]:
chpinf_year1_04_pdf = site_chip_year1_04.select('uid', 'id', 'unmsk_cell_count').toPandas()
chpinf_year1_04_pdf['grpid'] = chpinf_year1_04_pdf['uid']  # group on a copy so 'uid' stays a column for the merge

In [None]:
site_year1_04_maxcnt = chpinf_year1_04_pdf.sort_values('unmsk_cell_count', ascending=False) \
                                        .groupby(['grpid']).first() \
                                        .drop('unmsk_cell_count', axis=1)
site_cat_year1_04 = site_cat_year1_04.merge(site_year1_04_maxcnt, on=['uid', 'id'], how='inner')

In [None]:
site_chip_year1_04_unq = spark.read.chip(site_cat_year1_04, ['B10'],
                                        chipping_strategy=chp.CentroidCentered(chip_size)) \
                             .select('uid', 'id', 'datetime', 'B10') \
                             .withColumn('B10_AY1', 
                                         rf_convert_cell_type(rf_local_multiply(rf_rescale(rf_convert_cell_type('B10', 'uint16')), 
                                                                                65535), 'uint16')) \
                             .drop('B10') \
                             .withColumnRenamed('id', 'id_AY1') \
                             .withColumnRenamed('datetime', 'datetime_AY1') \
                             .repartition(partitions, 'uid')

## Join TIR chips

* Join TIR chips at different dates into a single RasterFrame
* Keep only sites where all three dates are included

In [None]:
site_chips_joined = site_chip_year2_01_unq.join(site_chip_year1_01_unq, on=['uid'], how='inner') \
                                         .join(site_chip_year1_04_unq, on=['uid'], how='inner')
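The inner joins keep a site only if it has a selected chip on all three dates. The pandas sketch below shows the same behaviour on toy tables with made-up `uid` values, chaining the joins with `functools.reduce` so that the list of dates could be extended without repeating code. The same `reduce` pattern applies to Spark DataFrames using `.join(other, on=['uid'], how='inner')`.

```python
from functools import reduce

import pandas as pd

jy2 = pd.DataFrame({'uid': [1, 2, 3], 'id_JY2': ['a', 'b', 'c']})
jy1 = pd.DataFrame({'uid': [1, 3], 'id_JY1': ['d', 'e']})  # site 2 has no January Year 1 chip
ay1 = pd.DataFrame({'uid': [1, 2, 3], 'id_AY1': ['f', 'g', 'h']})

# Inner-join all dates on the site id; sites missing any date are dropped
joined = reduce(lambda left, right: left.merge(right, on='uid', how='inner'),
                [jy2, jy1, ay1])
print(joined)
```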

## Write chips out as GeoTIFFs

* Writes chips to scratch directory
* Bundles output into tar file

In [None]:
if os.path.exists('/scratch/'+output_path):
    shutil.rmtree('/scratch/'+output_path)

In [None]:
site_chips_joined = site_chips_joined.withColumn('file_path_name', 
                                                 F.concat_ws('_', F.col('uid'), F.lit(site_type), F.lit(filename_append))) \
                                     .cache()

In [None]:
site_chips_joined.write.chip('/scratch/'+output_path, filenameCol='file_path_name', 
                             catalog=False)

In [None]:
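import subprocess
# Bundle the chip directory into an uncompressed tar file; check=True raises if tar fails
subprocess.run(['tar', '-C', '/scratch', '-cvf', output_path+'.tar', output_path], check=True)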
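As an alternative to calling the external `tar` binary, Python's standard `tarfile` module can build the same uncompressed archive and raises an exception if it cannot read or write a file. This sketch assumes the same layout as the command above (archive members stored relative to `/scratch`, i.e. under `output_path/`); `make_tar` is a hypothetical helper.

```python
import os
import tarfile

def make_tar(parent_dir, folder_name, tar_path):
    """Equivalent of `tar -C parent_dir -cf tar_path folder_name` (no compression)."""
    with tarfile.open(tar_path, 'w') as tar:
        # arcname keeps member paths relative to parent_dir, like `tar -C`
        tar.add(os.path.join(parent_dir, folder_name), arcname=folder_name)
    return tar_path

# In this notebook: make_tar('/scratch', output_path, output_path + '.tar')
```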

## Upload tar files to S3 bucket

In [None]:
s3 = boto3.resource('s3')
bucket = s3.Bucket('sfi-shared-assets')

bucket.upload_file(output_path+'.tar', 
                   s3_path+'/'+output_path+'.tar')

## Write out Vector File of Tile Extents and Metadata

* Serves as metadata catalog for chips
* Only the January Year 2 tile extent is included, for visualization

In [None]:
site_chips_geom_pdf = site_chips_joined.withColumn('tile_extent',
                                                   st_reproject(st_geometry(rf_extent('B10_JY2')),
                                                                rf_crs('B10_JY2'), F.lit('EPSG:4326'))) \
                                      .drop('B10_JY2', 'B10_JY1', 'B10_AY1') \
                                      .toPandas()

In [None]:
site_chips_geom_gdf = gpd.GeoDataFrame(site_chips_geom_pdf.drop('tile_extent', axis=1),
                                       geometry=site_chips_geom_pdf.tile_extent,
                                       crs='EPSG:4326')

In [None]:
site_chips_geom_gdf.to_file(chip_extents_gjson, driver='GeoJSON')

## Clean Up Temporary Files

In [None]:
if os.path.exists('/scratch/'+output_path):
    shutil.rmtree('/scratch/'+output_path)

In [None]:
os.remove(output_path+'.tar')