# Creation of RGB Sentinel-2 Chips for Model Build

This notebook creates the Sentinel-2 RGB image chips for cement, steel, or landcover sites from the EarthAI catalog.

* Site locations in China
* Sentinel-2, red, green, and blue bands
* Chips are 3 km on a side (300 × 300 pixels of the 10 m bands)

## Import required libraries

In [None]:
from earthai.all import *
import earthai.chipping.strategy as chp
import pyspark.sql.functions as pys

import geopandas as gpd
import pandas as pd

import os
import shutil
import boto3

## Create Spark Session

* Set number of partitions on par with the number of catalog items per scene

In [None]:
# Use the same partition count for RDD parallelism and SQL shuffles.
partitions = 250
spark = create_earthai_spark_session(
    **dict.fromkeys(
        ("spark.default.parallelism", "spark.sql.shuffle.partitions"),
        partitions,
    )
)

## Define input and output files and parameters

### Parameters

* `site_type` should be set to `'cement'`, `'steel'`, or `'landcover'`
* `year` defines the year of selected scenes
* `chip_size` is the size of chips (length) to create (in pixels)
* `max_cc` is the maximum scene cloud cover used in the Earth OnDemand (EOD) catalog query filter (in percent)

In [None]:
site_type = 'landcover'  # one of 'cement', 'steel', or 'landcover'
year = '2020'            # year of the May-August scene search window

chip_size = 300  # chip side length in pixels; 3 km for Sentinel-2 10 m bands
max_cc = 5       # maximum scene cloud cover for the catalog query, in percent

# Fail fast on a mistyped site type: create_chips only special-cases
# 'landcover', so any other string would silently be processed like
# cement/steel and written out under a misleading file name.
if site_type not in ('cement', 'steel', 'landcover'):
    raise ValueError(
        "site_type must be 'cement', 'steel', or 'landcover', got {!r}".format(site_type))

# year is concatenated into file names and catalog date strings below, so keep
# it a str even when entered as an int (e.g. year = 2020).
year = str(year)

### Input files

* `site_geojson` is a GeoJSON specifying the chip centers

In [None]:
# GeoJSON of chip centers for the selected site type.
site_geojson = ''.join(['../../resources/macro-loc-model-build4/', site_type,
                        '_chip_cntr_china_v4.1_s2.geojson'])

### Output files and paths

* `output_path` names the directory under `/scratch` where chip GeoTIFFs are written, and the tar archive (`output_path.tar`) that is uploaded to S3
* `s3_path` defines the top-level S3 folder for the Sentinel-2 RGB macro-localization data
* `filename_append` is appended to each chip file name
* `chip_extents_gjson` is an output GeoJSON file with chip metadata and tile extents

In [None]:
# Scratch sub-directory and tar archive name for this run's chips.
output_path = ''.join(['ALD_S2_RGB_', site_type, '_chips_v4p1_', year, '_train4'])
# Top-level folder in the S3 bucket that receives the tar archive.
s3_path = 'S2-RGB-macro-localization-model-build4'

# Suffix appended to every chip file name.
filename_append = ''.join(['v4p1_', year, '_S2_RGB'])
# GeoJSON written at the end with chip extents and metadata.
chip_extents_gjson = ''.join(['../../resources/macro-loc-model-build4/',
                              output_path, '.geojson'])

In [None]:
# Remove chips left in scratch by an earlier run so they are not re-tarred.
if os.path.exists('/scratch/{}'.format(output_path)):
    shutil.rmtree('/scratch/{}'.format(output_path))

## Define EOD Catalog Read and Chipping Functions

### Get catalog of Sentinel-2 scenes that intersect with chip centroids

Queries the EarthAI (Earth OnDemand) catalog for Sentinel-2 L2A scenes that intersect the chip centroids. Returns scenes from May 1 through August 31 of the specified `year`, filtered by the `max_cc` cloud-cover threshold (in percent).

In [None]:
def eod_read_catalog(geom, year, max_cc=100):
    """Query Earth OnDemand for Sentinel-2 L2A scenes over ``geom``.

    The search window runs from May 1 through August 31 of ``year``.

    Parameters
    ----------
    geom :
        Geometry passed as ``geo`` to ``earth_ondemand.read_catalog``.
        This notebook passes the GeoDataFrame of chip centers.
    year : str or int
        Calendar year of the search window. It is converted to ``str``,
        so both ``'2020'`` and ``2020`` are accepted.
    max_cc : int or float, optional
        Passed as ``max_cloud_cover``. The notebook documents this as
        percent cloud cover. The default of 100 presumably means no cloud
        filtering; confirm against the EOD docs.

    Returns
    -------
    The catalog returned by ``earth_ondemand.read_catalog``. It is returned
    unchanged even when empty, so callers always get the same frame type.
    The previous behavior returned ``[]`` here, which broke attribute
    access such as ``.eod_grid_id``. ``len(result)`` still works for both.
    """
    year = str(year)

    # Start/end date formatting: May through August of the requested year
    start_date = year + '-05-01'
    end_date = year + '-08-31'

    # Query catalog
    site_cat = earth_ondemand.read_catalog(
        geo=geom,
        start_datetime=start_date,
        end_datetime=end_date,
        max_cloud_cover=max_cc,
        collections='sentinel2_l2a'
    )
    return site_cat

### Create Image Chips

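* Read and create image chips for the specified chip centers
* Keep only full-size chips without NoData cells, and rescale the red, green, and blue bands to uint16 (the least-cloudy scene per chip is selected later, in the chipping loop)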

In [None]:
def _rescale_to_uint16(band):
    """Return a column expression stretching ``band`` to the full uint16 range.

    The band is cast to uint16, passed through ``rf_rescale``, multiplied by
    65535 and cast back to uint16. ``rf_rescale`` with no min/max
    presumably uses each tile's own min/max to map into [0, 1]; confirm
    against the RasterFrames docs.
    """
    return rf_convert_cell_type(
        rf_local_multiply(rf_rescale(rf_convert_cell_type(band, 'uint16')), 65535),
        'uint16')


def create_chips(site_cat, chip_size=300, site_type=site_type):
    """Read centroid-centered Sentinel-2 RGB chips for the rows of ``site_cat``.

    Parameters
    ----------
    site_cat :
        Chip-center sites spatially joined to catalog items (built by the
        chipping loop with ``gpd.sjoin``). It must contain ``tile_id``.
        For site types other than ``'landcover'`` it must also contain
        ``uid`` and ``dist_m``, presumably from the site GeoJSON; confirm.
    chip_size : int
        Side length of the square chips, in pixels.
    site_type : str
        Only ``'landcover'`` changes behavior: it skips the ``uid`` and
        ``dist_m`` columns. NOTE: the default is the global ``site_type``
        value at the time this cell was executed.

    Returns
    -------
    A cached Spark DataFrame with one row per full-size chip that has no
    NoData cells. Its columns are ``[uid,] tile_id, [dist_m,] scene_id, id,
    datetime, scene_cloud_pnt, Red, Green, Blue``. Red, Green and Blue are
    rescaled uint16 tiles.
    """
    # Per-site identifier columns; landcover sites have no 'uid'/'dist_m'.
    if site_type == 'landcover':
        id_cols = ['tile_id']
    else:
        id_cols = ['uid', 'tile_id', 'dist_m']
    meta_cols = ['eod_grid_id', 'id', 'datetime', 'eo_cloud_cover']
    band_cols = ['B04_10m', 'B03_10m', 'B02_10m']  # red, green, blue

    site_chip_all = (
        # Centroid-centered chipping yields same-size chips around each center
        spark.read.chip(site_cat, band_cols,
                        chipping_strategy=chp.CentroidCentered(chip_size, chip_size))
        .select(*id_cols, *meta_cols, *band_cols)
        # Filter out chips smaller than chip_size x chip_size
        .withColumn('tile_dims', rf_dimensions('B04_10m'))
        .filter((pys.col('tile_dims').rows == chip_size) &
                (pys.col('tile_dims').cols == chip_size))
        # Rename columns
        .withColumnRenamed('eod_grid_id', 'scene_id')
        .withColumnRenamed('eo_cloud_cover', 'scene_cloud_pnt')
        # Filter out chips with NoData cells (counted on the red band as uint16)
        .withColumn('Red_uint16', rf_convert_cell_type('B04_10m', 'uint16'))
        .withColumn('nodata_cell_cnt', rf_no_data_cells('Red_uint16'))
        .filter(pys.col('nodata_cell_cnt') == 0)
        # Normalize and convert data bands to uint16
        .withColumn('Red', _rescale_to_uint16('B04_10m'))
        .withColumn('Green', _rescale_to_uint16('B03_10m'))
        .withColumn('Blue', _rescale_to_uint16('B02_10m'))
        .drop('B04_10m', 'B03_10m', 'B02_10m', 'tile_dims', 'Red_uint16', 'nodata_cell_cnt')
        .cache()
    )

    return site_chip_all

## Load site location point data

In [None]:
# One row per chip center.
site_gdf = gpd.read_file(site_geojson)
print("Total count of sites: ", site_gdf.shape[0])

## Get Catalog Covering All Chips

* Covers all chip locations in the specified year, from May to August, within the specified cloud-cover limit
* Determine unique scene ids

In [None]:
# Single catalog query covering every chip center.
site_cat_all = eod_read_catalog(geom=site_gdf, year=year, max_cc=max_cc)

In [None]:
# Unique scene (grid) ids in ascending order; the chipping loop iterates over these.
scene_ids = sorted(site_cat_all['eod_grid_id'].unique().tolist())
print("Total Number of Unique Scene Ids: ", len(scene_ids))

## Create Chips

* Loops over scene IDs to speed up processing
* For each chip location (`tile_id`), keeps the catalog item with the lowest cloud cover, so each scene yields at most one chip per location
* Writes chips to GeoTIFFs
* Creates GeoJSON file with chip extents and metadata

In [None]:
# Loop over scene ids
for scene_id in scene_ids:
    
    # Limit catalog to scenes matching scene id
    # Join to chip sites
    site_cat = site_cat_all[site_cat_all.eod_grid_id == scene_id]
    site_cat = gpd.sjoin(site_gdf, site_cat)
    
    # Create chips for all scenes
    site_chips = create_chips(site_cat, chip_size=chip_size)
    chp_cnt = site_chips.count()
    
    if (chp_cnt > 0):
        
        # For each tile_id, find the scene with the least cloud coverage
        chpinf_pdf = site_chips.select('tile_id', 'id', 'scene_cloud_pnt').toPandas()
        site_mincc_pdf = chpinf_pdf.sort_values('scene_cloud_pnt') \
                                   .groupby(['tile_id']).first() \
                                   .drop('scene_cloud_pnt', axis=1) \
                                   .reset_index()
        
        # Join to RasterFrame to find unique chip per tile_id
        site_mincc_sdf = spark.createDataFrame(site_mincc_pdf) \
                              .withColumnRenamed('tile_id', 'tile_id2') \
                              .withColumnRenamed('id', 'id2')
        site_chips_unq = site_chips.join(site_mincc_sdf, 
                                         (site_chips.tile_id == site_mincc_sdf.tile_id2) & \
                                         (site_chips.id == site_mincc_sdf.id2)) \
                                   .drop('tile_id2', 'id2') \
                                   .withColumn('file_path_name', 
                                               pys.concat_ws('_', pys.col('scene_id'), pys.col('tile_id'), 
                                                             lit(site_type), lit(filename_append))) \
                                   .cache()
        
        # Write chips to GeoTIFFs
        site_chips_unq.write.chip('/scratch/'+output_path, filenameCol='file_path_name', 
                                  catalog=False)
        
        # Write out Vector File of Tile Extents and Metadata
        site_chips_pdf = site_chips_unq.withColumn('tile_extent',
                                                   st_reproject(st_geometry(rf_extent('Red')),
                                                                rf_crs('Red'), lit('EPSG:4326'))) \
                                       .drop('Red', 'Green', 'Blue') \
                                       .toPandas()
        site_chips_gdf = gpd.GeoDataFrame(site_chips_pdf.drop('tile_extent', axis=1),
                                          geometry=site_chips_pdf.tile_extent,
                                          crs='EPSG:4326')
        
        # Append to growing GeoDataFrame
        if 'site_chip_ext_gdf' in locals():
            site_chip_ext_gdf = pd.concat([site_chip_ext_gdf, site_chips_gdf], 
                                          ignore_index=True)
        else:
            site_chip_ext_gdf = site_chips_gdf
            
    print('Done creating chips for scene ', scene_id, '(', \
          scene_ids.index(scene_id)+1, ' out of ', len(scene_ids), ')')

## Write out tile extents to GeoJSON

In [None]:
# Persist chip extents and metadata as GeoJSON next to the site inputs.
site_chip_ext_gdf.to_file(filename=chip_extents_gjson, driver='GeoJSON')

## Tar GeoTIFFs and Upload to S3 bucket

In [None]:
# Bundle the chip GeoTIFFs into <output_path>.tar in the working directory,
# with members stored under "<output_path>/...". This is equivalent to
# `tar -C /scratch -cf <output_path>.tar <output_path>`. The tarfile module
# replaces the IPython-only `!tar` magic: it keeps the cell valid Python and
# raises on failure, whereas a failing `!tar` lets the notebook continue to
# the S3 upload.
import tarfile

with tarfile.open(output_path + '.tar', 'w') as chip_tar:
    chip_tar.add('/scratch/' + output_path, arcname=output_path)

In [None]:
# Upload the chip archive to <s3_path>/<output_path>.tar in the shared bucket.
s3 = boto3.resource('s3')
bucket = s3.Bucket('sfi-shared-assets')

tar_name = output_path + '.tar'
bucket.upload_file(Filename=tar_name, Key='/'.join([s3_path, tar_name]))

## Clean Up Temporary Files

In [None]:
# The chips are archived and uploaded, so drop the scratch copy.
if os.path.exists('/scratch/{}'.format(output_path)):
    shutil.rmtree('/scratch/{}'.format(output_path))

In [None]:
# Delete the local tar archive now that it is on S3.
os.unlink(output_path + '.tar')