# S3VT Landsat and Sentinel 2 validation of hotspots - working

## Description
This notebook demonstrates how to:
 
From a candidate latitude, longitude and solar day:
* determine whether intersecting Landsat or Sentinel 2 ARD exists
* apply the platform-specific tests to determine whether hotspots were detected within 5 km of the hotspot
* return the number of pixels identified as hotspots
* save a boolean file labelled with the solar date of acquisition
* as a secondary test, compute the Normalized Burn Ratio (NBR) change and return it as a binary with the solar date of acquisition
    * find candidate dates within a time range of the source hotspot
        * find the closest date before the hotspot, within tolerance (dNBR A)
        * find the closest date after the hotspot, within tolerance (dNBR B)
        * the candidate closest in time to the source hotspot is used for hotspot matching, i.e. as the high-resolution hotspot

Assumptions:
* reflectance values are scaled by 10000 i.e. 100% reflectance = 10000
 

### Load packages

In [None]:
%matplotlib inline
from pathlib import Path
import datacube
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
import sys
import xarray as xr
import geopandas as gpd
sys.path.insert(0, '..')
import src.hotspot_utils as util
import random
import logging
import os
from shapely import wkt
from geopy.distance import distance

import rioxarray

In [None]:
# Route INFO-and-above records to stdout with a timestamped, named format.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(name)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
# NOTE: this silences every record at CRITICAL and below, i.e. all _LOG output in
# this notebook. Remove this line to see the progress messages.
logging.disable(logging.CRITICAL)

_LOG = logging.getLogger(__name__)

### Connect to the datacube

In [None]:
# Datacube connection used for every product/dataset query and load below.
dc = datacube.Datacube(app='validating_hotspots')

### Be ignorant of the sensor (sensor-agnostic band mapping)

In [None]:
# configure sensor bands - #TODO implement sensor ignorance code here
# Each sensor key maps a spectral window (micrometres, or 'fmask') to the band name
# that sensor uses in the datacube, so the tests can be written once per wavelength.
sensor_ignorance = dict(
    s2msi={
        '0.433-0.453': 'nbart_coastal_aerosol',
        '0.450-0.515': 'nbart_blue',
        '0.525-0.600': 'nbart_green',
        '0.630-0.680': 'nbart_red',
        '0.845-0.885': 'nbart_nir_1',
        '1.560-1.660': 'nbart_swir_2',
        '2.100-2.300': 'nbart_swir_3',
        'fmask': 'fmask',
    },
    lsoli={
        '0.433-0.453': 'nbart_band01',
        '0.450-0.515': 'nbart_band02',
        '0.525-0.600': 'nbart_band03',
        '0.630-0.680': 'nbart_band04',
        '0.845-0.885': 'nbart_band05',
        '1.560-1.660': 'nbart_band06',
        '2.100-2.300': 'nbart_band07',
        'fmask': 'fmask',
    },
)

In [None]:
# Swath footprints for the study period; keep only rows with valid geometry.
swaths = pd.read_pickle('../workdir_test1/swaths_154_113_20191101_20201008.pkl')
swath_gdf = swaths[swaths['geometry'].is_valid]

# UTC bounds of the 20:00-03:00 local solar-time window
# (154.0 / 113.0 are presumably the bounding longitudes - confirm against util).
start_time_utc, end_time_utc = util.convert_solar_time_to_utc(154.0, 113.0, "20:00", "03:00")

# Hotspot records (per the file name: same area, period and solar-time window).
hotspots_pkl_file = '../workdir_test1/all_hotspots_154_113_20191101_20201008_2000_0300.pkl'
hotspots_gdf = pd.read_pickle(hotspots_pkl_file)

In [None]:
# Index by acquisition time so between_time can keep only the night-time UTC window.
hotspots_gdf = hotspots_gdf.set_index('datetime').between_time(start_time_utc, end_time_utc)

In [None]:
# One record per (solar-night day, product) holding that day's hotspot count for the
# product; products with no hotspots on a day get a count of 0.
results_list = []
product_availability = {}
for index, row in hotspots_gdf.resample("D", on='solar_night'):
    row_products = row['satellite_sensor_product']
    for product in hotspots_gdf['satellite_sensor_product'].unique():
        results_list.append({
            'datetime': index,
            'satellite_sensor_product': product,
            'count': row.loc[row_products == product, 'geometry'].count(),
        })

In [None]:
# Long-format table: one row per (day, product) with the hotspot count.
daily_hotspot_count = pd.DataFrame(results_list)

In [None]:
# Give each product a fixed y position (1, 2, ... in sorted product order) for the
# availability plot. count / count is 1 on days with hotspots and NaN (0/0) on days
# without, so empty days are not plotted.
sorted_products = np.sort(daily_hotspot_count.satellite_sensor_product.unique(), axis=None)
for index, product in enumerate(sorted_products):
    print(index, product)
    is_product = daily_hotspot_count.satellite_sensor_product == product
    day_counts = daily_hotspot_count['count']
    daily_hotspot_count.loc[is_product, 'available'] = (day_counts / day_counts) * (index + 1)

In [None]:
# Colour for each known hotspot product; unknown products are drawn in black.
colours = {
    'AQUA_MODIS_LANDGATE': 'lightsteelblue',
    'AQUA_MODIS_NASA6.03': 'lavender',
    'NOAA 20_VIIRS_LANDGATE': 'indianred',
    'NOAA 20_VIIRS_NASA2.0NRT': 'lightcoral',
    'NOAA-19_AVHRR_LANDGATE': 'grey',
    'SENTINEL_3A_SLSTR_ESA': 'darkgreen',
    'SENTINEL_3A_SLSTR_EUMETSAT': 'seagreen',
    'SENTINEL_3B_SLSTR_ESA': 'lime',
    'SENTINEL_3B_SLSTR_EUMETSAT': 'limegreen',
    'SUOMI NPP_VIIRS_LANDGATE': 'firebrick',
    'SUOMI NPP_VIIRS_NASA1': 'darkred',
    'TERRA_MODIS_LANDGATE': 'navy',
    'TERRA_MODIS_NASA6.03': 'blue',
}

# The 'available' column places product i (in sorted order) at y = i + 1, so the
# tick labels must be the sorted products actually present - not a fixed list of 13,
# which mislabels the axis whenever some products are missing.
plotted_products = sorted(daily_hotspot_count.satellite_sensor_product.unique())

fig, ax = plt.subplots()
ax.set_title('Hotspot data availability per product')
for product in plotted_products:
    daily_hotspot_count[daily_hotspot_count['satellite_sensor_product'] == product].plot(
        kind='scatter', x='datetime', y='available', ax=ax,
        color=colours.get(product, 'black'))

ax.set(ylabel='product', xlabel='date')
ax.set_yticks(range(1, len(plotted_products) + 1))
ax.set_yticklabels(plotted_products)

fig.savefig('data_availability.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Save the per-day, per-product hotspot counts.
daily_hotspot_count.to_csv('dailyhotspotcount.csv')

In [None]:
# Inspect the hotspot attributes available for the validation below.
hotspots_gdf.columns

In [None]:
# Placeholder ("") result columns, one set per sensor, filled per hotspot by the
# validation loop. Created in this order so the column order is unchanged.
for result_column in [
    "s2msi_rdnbr_gt_200", "s2msi_pre_burn_time", "s2msi_post_burn_time",
    "s2msi_pre_burn_timedelta", "s2msi_post_burn_timedelta",
    "s2msi_pre_percent", "s2msi_post_percent",
    "s2msi_pre_hotspots", "s2msi_post_hotspots",
    "lsoli_rdnbr_gt_200", "lsoli_pre_burn_time", "lsoli_post_burn_time",
    "lsoli_pre_burn_timedelta", "lsoli_post_burn_timedelta",
    "lsoli_pre_percent", "lsoli_post_percent", "lsoli_hotspot_percent",
    "lsoli_pre_hotspots", "lsoli_post_hotspots",
]:
    hotspots_gdf[result_column] = ""

# Functions

In [None]:
# Buffer candidate hotspot with a ~5 kilometre square box (0.05 degrees will do)
def buffer_hotspot(lon, lat, buffer_degrees=0.05):
    """Return a square lon/lat search box centred on a hotspot.

    The box extends ``buffer_degrees`` either side of the point on both axes. The
    default of 0.05 degrees is about 5.5 km in latitude (less in longitude away
    from the equator), which approximates a 5 km buffer.

    Parameters
    ----------
    lon, lat : float
        Hotspot position in decimal degrees.
    buffer_degrees : float, default 0.05
        Half-width of the box in degrees.

    Returns
    -------
    ((west_lon, east_lon), (north_lat, south_lat))
        Ranges passed to the datacube query as ``x`` and ``y``.
    """
    lon_range = (lon - buffer_degrees, lon + buffer_degrees)
    lat_range = (lat + buffer_degrees, lat - buffer_degrees)
    return lon_range, lat_range

In [None]:
def buffer_date(firetime, days, post_days=1):
    """Return the (start, end) time strings of the image search window.

    Parameters
    ----------
    firetime : numpy.datetime64
        Hotspot time.
    days : int
        Days before ``firetime`` to search for a pre-fire image.
    post_days : int, default 1
        Days after ``firetime`` to search for a post-fire image.

    Returns
    -------
    (str, str)
        ISO-formatted start and end of the window.
    """
    prefire_date = (firetime - np.timedelta64(days, "D")).astype(str)
    postfire_date = (firetime + np.timedelta64(post_days, "D")).astype(str)
    return prefire_date, postfire_date

In [None]:
def get_measurement_list(product):
    """List the measurements the datacube index reports for ``product``.

    Uses the module-level ``dc`` connection.

    Returns
    -------
    list of [product, '--', measurement_name]
        Empty when ``product`` is not in ``dc.list_products()``.
    """
    if product not in set(dc.list_products().name):
        return []
    # Query the measurement table once instead of once per indexed product.
    all_measurements = dc.list_measurements()
    return [[product, '--', name]
            for name in all_measurements.query('product == @product').name]

In [None]:
# Candidate active-fire pixels (tests 1-3)
def get_candidates(ds):
    """Flag candidate active-fire pixels with absolute reflectance tests 1-3.

    Band names come from ``sensor_ignorance[sensor]`` (global ``sensor``); thresholds
    are in reflectance scaled by 10000. The thresholds appear to follow the Landsat-8
    active-fire algorithm (Schroeder et al. 2016) - confirm.

    Returns the union of the three tests, clipped to [0, 1].
    """
    bands = sensor_ignorance[sensor]
    coastal = ds[bands['0.433-0.453']]
    nir = ds[bands['0.845-0.885']]
    swir1 = ds[bands['1.560-1.660']]
    swir2 = ds[bands['2.100-2.300']]

    # potentially unambiguous active fire pixels
    test1 = ((swir2 / nir) > 2.5) * ((swir2 - nir) > 3000) * (swir2 > 5000)
    # Unambiguous fire pixels
    nir_or_swir2 = ((nir > 4000) + (swir2 < 1000)).clip(min=0, max=1)
    test2 = ((swir1 > 8000) * (coastal < 2000)) * nir_or_swir2
    # other candidate fire pixels
    test3 = ((swir2 / nir) > 1.8) * ((swir2 - nir) > 1700)

    return (test1 + test2 + test3).clip(min=0, max=1)

In [None]:
def get_context_kernel_array(y, x, array, radius=60):
    """Extract the background window around pixel (y, x) of the first time slice.

    Parameters
    ----------
    y, x : int
        Row and column of the centre pixel in ``array[0]``.
    array : numpy.ndarray
        3-D array indexed (time, y, x); only ``array[0]`` is sampled.
    radius : int, default 60
        Half-width in pixels. The window covers rows ``y - radius .. y + radius`` and
        columns ``x - radius .. x + radius`` (inclusive), clipped to the array.

    Returns
    -------
    (window, (ymin, ymax, xmin, xmax))
        ``window`` equals ``array[0][ymin:ymax, xmin:xmax]``; the bounds are
        half-open slice limits. If the array has no time steps, ``window`` is a
        NaN array of the clipped window shape.
    """
    n_times, n_rows, n_cols = array.shape

    # +1 on the upper bound keeps the window centred on (y, x).
    ymin = max(y - radius, 0)
    ymax = min(y + radius + 1, n_rows)
    xmin = max(x - radius, 0)
    xmax = min(x + radius + 1, n_cols)

    if n_times == 0:
        return np.full((ymax - ymin, xmax - xmin), np.nan), (ymin, ymax, xmin, xmax)

    # Index rows with the y bounds and columns with the x bounds in a single step;
    # the previous chained `[:, ymin:ymax][xmin:xmax]` applied them to the wrong axes.
    window = array[0][ymin:ymax, xmin:xmax]
    return window, (ymin, ymax, xmin, xmax)

In [None]:
def run_test6(ds):
    """Test 6: SWIR2 / SWIR1 ratio above 1.6 (b7 / b6 > 1.6); returns a boolean array."""
    bands = sensor_ignorance[sensor]
    swir_ratio = ds[bands['2.100-2.300']] / ds[bands['1.560-1.660']]
    return swir_ratio > 1.6
    

In [None]:
# Oceans test
#7. {b4 > b5 AND b5 > b6 AND b6 > b7 AND b1 - b7 < 0.2}
def run_test7(ds):
    """Test 7 (oceans): red > NIR > SWIR1 > SWIR2 and coastal - SWIR2 < 2000.

    2000 is the 0.2 reflectance threshold scaled by 10000. Returns the result
    clipped to [0, 1].
    """
    bands = sensor_ignorance[sensor]
    coastal = ds[bands['0.433-0.453']]
    red = ds[bands['0.630-0.680']]
    nir = ds[bands['0.845-0.885']]
    swir1 = ds[bands['1.560-1.660']]
    swir2 = ds[bands['2.100-2.300']]

    decreasing = (red > nir) * (nir > swir1) * (swir1 > swir2)
    return (decreasing * ((coastal - swir2) < 2000)).clip(min=0, max=1)


In [None]:
# Water bodies test - comment - seems like a bad test / smoke complications?
#AND
#8. {(b3 > b2)
def run_test8(ds):
    """Test 8 (water bodies): green above blue, clipped to [0, 1]."""
    bands = sensor_ignorance[sensor]
    green_over_blue = ds[bands['0.525-0.600']] > ds[bands['0.450-0.515']]
    return green_over_blue.clip(min=0, max=1)

In [None]:
#OR
#9. (b1 > b2 AND b2 > b3 AND b3 < b4)}.

def run_test9(ds):
    """Test 9 (water bodies): coastal > blue > green and green < red, clipped to [0, 1]."""
    bands = sensor_ignorance[sensor]
    coastal = ds[bands['0.433-0.453']]
    blue = ds[bands['0.450-0.515']]
    green = ds[bands['0.525-0.600']]
    red = ds[bands['0.630-0.680']]

    test9 = (coastal > blue) * (blue > green) * (green < red)
    return test9.clip(min=0, max=1)


In [None]:
def get_watermasks(ds):
    """Union of the ocean (7) and water-body (8, 9) tests, clipped to [0, 1]."""
    combined = run_test7(ds) + run_test8(ds) + run_test9(ds)
    return combined.clip(min=0, max=1)

In [None]:
def get_hotspots(ds, reflectance_scale=10000):
    """Confirm candidate fire pixels with the contextual tests 4-5 and ratio test 6.

    Candidates come from ``get_candidates``. For each candidate, the background is
    the window returned by ``get_context_kernel_array`` on the SWIR2 and NIR bands
    (first time step only). Band names come from ``sensor_ignorance[sensor]``.

    Parameters
    ----------
    ds : xarray.Dataset
        Loaded datacube data; candidates are evaluated on ``.data[0]``.
    reflectance_scale : number, default 10000
        Value representing 100% reflectance. Test 5's absolute floor of 0.08
        reflectance is expressed in these units (800 by default); previously the
        unscaled 0.08 was used, which had no effect on scaled data.

    Returns
    -------
    (int, xarray.DataArray)
        Number of confirmed hotspot pixels and the confirmation mask
        ``candidates * test4 * test5 * test6``.
    """
    swir_data = ds[sensor_ignorance[sensor]['2.100-2.300']].data
    nir_data = ds[sensor_ignorance[sensor]['0.845-0.885']].data
    swir_floor = 0.08 * reflectance_scale

    candidates = get_candidates(ds)
    # NOTE: the context window includes fire and water pixels; get_watermasks(ds)
    # could be used to exclude water if that variant is wanted again.
    test4 = candidates * 0
    test5 = candidates * 0

    _, candidate_rows, candidate_cols = np.where(candidates.data == 1)
    for y, x in zip(candidate_rows, candidate_cols):
        #4. ratio b7/b5 > mean ratio b7/b5 + max[3x std ratio b7/b5, 0.8]
        #5. b7 > mean b7 + max[3x std b7, 0.08 reflectance]
        swirkernel = get_context_kernel_array(y, x, swir_data)[0]
        nirkernel = get_context_kernel_array(y, x, nir_data)[0]
        kernel_ratio = swirkernel / nirkernel

        swir = swir_data[0][y][x]
        nir = nir_data[0][y][x]

        test4.data[0][y][x] = ((swir / nir) > (np.nanmean(kernel_ratio) + max(3 * np.nanstd(kernel_ratio), 0.8)))
        test5.data[0][y][x] = (swir > (np.nanmean(swirkernel) + max(3 * np.nanstd(swirkernel), swir_floor)))

    test6 = run_test6(ds)
    confirmed = candidates * (test4 * test5 * test6)
    hotspots = int(np.count_nonzero(confirmed.data == 1))
    return hotspots, confirmed

In [None]:
def get_nbr(ds):
    """Normalised Burn Ratio: (NIR - SWIR2) / (NIR + SWIR2).

    Band names come from ``sensor_ignorance[sensor]``. Both bands are cast to
    float64 before the arithmetic. If they are stored as small integers (ARD
    loads are presumably int16 - confirm), ``swir + nir`` could otherwise
    overflow. For non-overflowing inputs the values are the same as integer
    inputs with true division. Pixels where NIR + SWIR2 == 0 (e.g. masked to 0)
    give NaN or inf.
    """
    swir = ds[sensor_ignorance[sensor]['2.100-2.300']].astype('float64')
    nir = ds[sensor_ignorance[sensor]['0.845-0.885']].astype('float64')
    return (nir - swir) / (swir + nir)

In [None]:
def get_rdnbr(pre_fire_image, post_fire_image):
    """Relativised differenced NBR (RdNBR) between a pre- and post-fire image.

    RdNBR = 1000 * (NBR_pre - NBR_post) / sqrt(|NBR_pre|), using the first time
    step of each image.

    The second parameter was previously misspelt ``post_fire_imag``. Because of
    that, the body read a global ``post_fire_image`` instead of the argument.

    Parameters
    ----------
    pre_fire_image, post_fire_image : xarray.Dataset
        Images accepted by ``get_nbr``.

    Returns
    -------
    xarray.DataArray
        RdNBR for the first time step (time dimension removed).
    """
    # Revitalising dNBR from NSW Govt
    prefire_nbr = get_nbr(pre_fire_image)[0]
    postfire_nbr = get_nbr(post_fire_image)[0]
    dnbr = prefire_nbr - postfire_nbr
    # Scaling and offset as per NSW Govt algorithm
    return (dnbr / np.sqrt(np.abs(prefire_nbr))) * 1000

In [None]:
def plot_list_rgb(image_list, label_list, fake_saturation):
    """Show each image as an RGB panel, all in one row.

    Each image is an xarray Dataset whose data variables become the colour channels
    (in variable order). Its first time step is shown after rescaling with
    ``(value - 500) / fake_saturation``. The first three panels are titled with the
    date from ``image.time`` plus their label; later panels use the label only.
    """
    fig, axes = plt.subplots(nrows=1, ncols=len(image_list), figsize=(10, 10), squeeze=False)

    for column, (image, axis) in enumerate(zip(image_list, axes[0])):
        rgb = image.to_array(dim='color')
        # make 'color' the last dimension, as imshow expects
        rgb = rgb.transpose(*(rgb.dims[1:] + rgb.dims[:1]))
        # scale to roughly [0, 1] for imshow
        rgb = (rgb - 500) / fake_saturation

        axis.axis('off')
        axis.imshow(rgb[0])
        if column < 3:
            title = str(image.time.values)[2:12] + " \n " + label_list[column]
        else:
            title = label_list[column]
        axis.set_title(title, wrap=True)

    plt.show()
    return

In [None]:
# Testing RGB
#plot_list_rgb([ds.isel(time=0), ds.isel(time=0), ds.isel(time=0), ds.isel(time=0)], ['label', 'label','label', 'label'], 3500)

In [None]:
# testing
#plot_list_rgb([pre_fire_image, post_fire_image, post_fire_image, post_fire_image], 3500)
#plot_list_rgb([post_fire_image], 3500)

In [None]:
def get_timedelta_from_ds(datacubeds, target_time):
    """Tabulate each time step of ``datacubeds`` and its offset from ``target_time``.

    Returns a DataFrame indexed 0..n-1 (time-step order) with columns ``time`` and
    ``delta``, where ``delta = target_time - time`` (positive for steps before the target).
    """
    per_step = {
        position: {"time": step_time, "delta": target_time - step_time}
        for position, step_time in enumerate(list(datacubeds.time.data))
    }
    return pd.DataFrame.from_dict(per_step).transpose()

In [None]:
def get_portion_imaged(image, band=None):
    """Fraction and count of pixels with a positive value in one band.

    Parameters
    ----------
    image : xarray.Dataset or mapping of band name -> array
        Image to assess.
    band : str, optional
        Band to test; defaults to the current sensor's 2.100-2.300 um band
        (``sensor_ignorance[sensor]``).

    Returns
    -------
    (float, int)
        ``(valid / total, valid)``, where ``valid`` counts pixels > 0 and ``total``
        counts every pixel, including zero, negative (nodata) and NaN pixels.
        Previously, negative pixels were left out of the total. Returns
        ``(0.0, 0)`` for an empty image instead of dividing by zero.
    """
    if band is None:
        band = sensor_ignorance[sensor]['2.100-2.300']
    values = np.asarray(image[band])
    valid = int(np.count_nonzero(values > 0))
    total = values.size
    return (valid / total if total else 0.0), valid

In [None]:
def get_portion_above_threshold(image, threshold):
    """Fraction and count of non-NaN pixels whose value exceeds ``threshold``.

    Previously the denominator was (non-zero pixels + above-threshold pixels),
    which counted above-threshold pixels twice and ignored zero-valued pixels.
    For example, a fully burnt image reported 0.5.

    Parameters
    ----------
    image : array-like
        Values to test (e.g. an RdNBR DataArray).
    threshold : number
        Strict lower bound.

    Returns
    -------
    (float, int)
        ``(above / valid, above)``, where ``valid`` counts non-NaN pixels;
        ``(0.0, 0)`` when every pixel is NaN or the image is empty.
    """
    values = np.asarray(image, dtype=float)
    valid_values = values[~np.isnan(values)]
    above = int(np.count_nonzero(valid_values > threshold))
    total = valid_values.size
    return (above / total if total else 0.0), above

In [None]:
# Move 'datetime' back from the index to a column; the validation loop reads it per hotspot.
hotspots_gdf = hotspots_gdf.reset_index()

In [None]:
# Check you're getting the UTC times in the correct interval
#for index, row in hotspots_gdf.resample("H", on='solar_night'):
#    print(index, len(row), row.datetime.max(), row.solar_night.max())

In [None]:
# Datacube products searched for each sensor key of sensor_ignorance (Landsat first).
sensors_products = dict(
    lsoli=['ga_ls8c_ard_3'],
    s2msi=['s2a_ard_granule', 's2b_ard_granule'],
)

In [None]:
# Download the simplified world borders shapefile once (IPython shell magic), then
# keep only the Australia polygon for use as a basemap in the plots below.
if not os.path.exists('TM_WORLD_BORDERS_SIMPL-0.3.zip'):
    !wget 'http://thematicmapping.org/downloads/TM_WORLD_BORDERS_SIMPL-0.3.zip'
australia = gpd.GeoDataFrame.from_file('zip://TM_WORLD_BORDERS_SIMPL-0.3.zip')
australia = australia[australia.NAME=='Australia']

In [None]:
# On a given day, where geometry of Landsat or Sentinel 2 from datacube intersect
# for a random set of hotspots, produce pre fire imagery, post fire imagery
# hotspots and relativised normalised difference burnt area imagery.
# Image selection is based on whether or not eligible pixels are above a percentage
# threshold of the sample area
join_results_list = []

# Days before the hotspot searched for a pre-fire image, per sensor
# (presumably sized to each sensor's revisit interval - confirm).
buffer_days_by_sensor = {'lsoli': 17, 's2msi': 12}

for sensor in sensors_products.keys():
    # NOTE: `sensor` is also read as a global by the detection / NBR helper functions.
    band_names = sensor_ignorance[sensor]
    measurements = list(band_names.values())
    # Bands kept for the composites (SWIR2, NIR, blue), in this order.
    rgb_bands = [band_names['2.100-2.300'], band_names['0.845-0.885'], band_names['0.450-0.515']]
    buffer_days = buffer_days_by_sensor[sensor]

    output_path = Path(sensor)
    output_path.mkdir(parents=True, exist_ok=True)

    # Intersect hotspots with sensor footprint from datacube daily query
    for index_a, gdf_ra in hotspots_gdf.resample("D", on='solar_night'):
        # resample() also yields empty groups for nights without hotspots
        if gdf_ra.empty:
            continue

        # Get the most recent hotspot UTC time in the resampled period
        image_utc_time = gdf_ra.datetime.max()

        # Search datacube for a period 24hrs extending beyond the end of the last sample
        image_time_tuple = (str(image_utc_time), str(image_utc_time + np.timedelta64(24, 'h')))

        query = {
            'x': (113, 154),
            'y': (-10, -44),
            'time': image_time_tuple,
            'measurements': measurements,
            'output_crs': 'EPSG:3577',
            'resolution': (-30, 30),
            'group_by': 'solar_day'
        }

        # Collect the footprints of *every* product for this sensor on this day
        # (e.g. both Sentinel-2A and 2B), not only those of the last product queried.
        day_datasets = []
        for product in sensors_products[sensor]:
            datasets = dc.find_datasets(product=product, **query)
            _LOG.info(f"{product} product selected, {len(datasets)} datasets found")
            day_datasets.extend(datasets)
        if not day_datasets:
            continue

        geometry_list = [dataset.extent.to_crs('epsg:4326') for dataset in day_datasets]
        shapes = gpd.GeoDataFrame(gpd.GeoSeries(geometry_list))
        shapes.columns = ['geometry']
        shapes = shapes.set_crs('epsg:4326')
        _LOG.info(f"Shapes produced")

        # spatial join returning hotspots that intersect with datacube image bounds
        join = gdf_ra.sjoin(shapes, how="inner")
        join = join.drop_duplicates()
        # A hotspot inside overlapping footprints is joined once per footprint, so its
        # index label repeats; keep one row per hotspot so join.loc[hotspot_index]
        # below is a single record.
        join = join[~join.index.duplicated(keep='first')]
        _LOG.info(f"Join successful")
        _LOG.info(f"Hotspot time span from {join.datetime.min()} to {join.datetime.max()}")
        _LOG.info(f"Datacube shapes time span from {image_time_tuple[0]} to {image_time_tuple[1]}")
        if len(join) == 0:
            continue

        fig, ax = plt.subplots(figsize=(2, 2))
        ax.set_aspect('equal')
        ax.axis('off')
        australia.plot(ax=ax, color='white', edgecolor='black')
        shapes.plot(ax=ax, color='black')
        join.plot(ax=ax, marker='o', color='red', markersize=5)
        plt.show()

        # Randomly pick up to 20 distinct hotspots (20 draws with replacement,
        # repeats discarded) to validate against high-resolution imagery.
        hotspot_index_list = []
        for _ in range(20):
            random_hotspot_index = random.choice(join.index)
            if random_hotspot_index not in hotspot_index_list:
                hotspot_index_list.append(random_hotspot_index)
        _LOG.info(f"{hotspot_index_list} random indexes selected")

        for hotspot_index in hotspot_index_list:
            # Construct datacube query parameters from hotspot geometry
            hotspot_row = join[join.index == hotspot_index]
            hotspot_lat = hotspot_row.latitude.values[0]
            hotspot_lon = hotspot_row.longitude.values[0]
            xtuple, ytuple = buffer_hotspot(hotspot_lon, hotspot_lat)
            hotspot_utc_time = hotspot_row.datetime.values[0]
            hotspot_time_tuple = buffer_date(hotspot_utc_time, buffer_days)
            _LOG.info(f"image UTC time {image_utc_time}")
            _LOG.info(f"{[output_path, hotspot_index, xtuple, ytuple, hotspot_time_tuple, hotspot_utc_time]}")

            # Reset per-hotspot state so a failed load cannot silently reuse the
            # previous hotspot's images and indices.
            pre_fire_image = None
            post_fire_image = None
            pre_index = None
            post_index = None

            try:
                _LOG.info(f"Attempting run_hotspot for candidate datacube ....")
                query = {
                    'x': xtuple,
                    'y': ytuple,
                    'time': hotspot_time_tuple,
                    'measurements': measurements,
                    'output_crs': 'EPSG:3577',
                    'resolution': (-30, 30),
                    'group_by': 'solar_day'
                }

                try:
                    # Query datacube to find intersecting images for each hotspot
                    dataset_list = []
                    for product in sensors_products[sensor]:
                        dataset_list.extend(dc.find_datasets(product=product, **query))
                    ds = dc.load(datasets=dataset_list, **query)

                    # delta = hotspot time - image time: >= 0 for images at or before
                    # the hotspot, < 0 for images after it.
                    pd_timediff = get_timedelta_from_ds(ds, hotspot_utc_time)
                    abs_delta = pd_timediff.delta.abs()
                    closest_indices = pd_timediff[abs_delta == abs_delta.min()].index
                    pre_indices = pd_timediff[pd_timediff.delta == abs_delta].index
                    post_indices = pd_timediff[pd_timediff.delta < abs_delta].index
                    pre_fire_candidate_delta = None
                    post_fire_candidate_delta = None
                    _LOG.info(f"{pd_timediff}")
                    _LOG.info(f"Hotspot UTC date {hotspot_utc_time} and index {hotspot_index}")

                    for image_index in range(len(ds.time)):
                        _LOG.info(f"{image_index+1} of {len(ds.time)} images being assessed ...")

                        # Load each image, keeping only pixels fmask flags as clear (1)
                        image_slice = ds.isel(time=[image_index])
                        image = image_slice[rgb_bands] * (image_slice['fmask'] == 1)

                        portion, valid_count = get_portion_imaged(image)
                        _LOG.info(f"{portion}, portion reported")

                        # Only use an image if more than 50% valid
                        if portion <= 0.50:
                            continue

                        image_delta = abs_delta.iloc[image_index]
                        if image_index in closest_indices:
                            # The image closest in time to the hotspot is the post-fire image
                            _LOG.info(f"{image_index} - Post fire image found on hotspot day")
                            post_fire_candidate_delta = image_delta
                            post_fire_image = image
                            post_index = image_index
                            continue

                        # Otherwise keep the closest valid image on each side of the hotspot.
                        if image_index in pre_indices and (
                                pre_fire_candidate_delta is None or image_delta < pre_fire_candidate_delta):
                            _LOG.info(f"Pre fire candidate found")
                            pre_fire_candidate_delta = image_delta
                            pre_fire_image = image
                            pre_index = image_index
                        if image_index in post_indices and (
                                post_fire_candidate_delta is None or image_delta < post_fire_candidate_delta):
                            post_fire_candidate_delta = image_delta
                            post_fire_image = image
                            post_index = image_index

                except Exception as err:
                    # Keep whatever images were selected before the failure.
                    _LOG.warning(f"Image selection failed for hotspot {hotspot_index}: {err!r}")

                # Write results to file
                if pre_fire_image is not None:
                    _LOG.info(f"Pre fire index = {pre_index}")
                    pre_fire_image.attrs.pop('grid_mapping', None)
                    pre_time = str(pre_fire_image.time[0].data)
                    pre_fire_image.isel(time=0).rio.to_raster(output_path.joinpath(f"{hotspot_index}_{pre_time}rgb.tif"))
                    pre_hotspots, pre_hotspot_array = get_hotspots(ds.isel(time=[pre_index]))
                    pre_hotspot_array.astype('int8').rio.to_raster(output_path.joinpath(f"{hotspot_index}_{pre_time}hotspots.tif"))
                    _, pre_valid_count = get_portion_imaged(pre_fire_image)
                else:
                    _LOG.info(f"No suitable Pre fire image found")

                if post_fire_image is not None:
                    _LOG.info(f"Post fire index = {post_index}")
                    post_fire_image.attrs.pop('grid_mapping', None)
                    post_time = str(post_fire_image.time[0].data)
                    post_fire_image.isel(time=0).rio.to_raster(output_path.joinpath(f"{hotspot_index}_{post_time}rgb.tif"))
                    post_hotspots, post_hotspot_array = get_hotspots(ds.isel(time=[post_index]))
                    post_hotspot_array.astype('int8').rio.to_raster(output_path.joinpath(f"{hotspot_index}_{post_time}hotspots.tif"))
                    _, post_valid_count = get_portion_imaged(post_fire_image)
                else:
                    _LOG.info(f"No suitable Post fire image found")

                if pre_fire_image is None or post_fire_image is None:
                    print('either pre, post or both images not available for: ', hotspot_index)
                    continue

                _LOG.info(f"Hotspot result returned")
                rdnbr = get_rdnbr(pre_fire_image, post_fire_image)
                # NOTE(review): NaN RdNBR pixels have no defined int16 value - consider a nodata fill.
                rdnbr.astype('int16').rio.to_raster(output_path.joinpath(f"{hotspot_index}_RdNBR.tif"))

                xr_post_hotspots = xr.Dataset({'1': post_hotspot_array,
                                               '2': post_hotspot_array * 0,
                                               '3': post_hotspot_array * 0}) * 3500

                xr_rdnbr = xr.Dataset({'1': rdnbr * 3, '2': rdnbr * 3, '3': rdnbr * 3})
                xr_rdnbr = xr_rdnbr.expand_dims({'time': 1})
                xr_rdnbr = xr_rdnbr.transpose('time', 'y', 'x')

                plot_list_rgb([pre_fire_image, post_fire_image, xr_post_hotspots, xr_rdnbr],
                              ['pre fire RGB', 'post fire RGB', 'hotspots', 'RdNBR'], 3500)

                portion, valid_count = get_portion_above_threshold(rdnbr, 200)
                _LOG.info(f"{portion} rdnbr portion result")
                print(f"{hotspot_index}, {portion} rdnbr portion result")

                # Result columns are prefixed with the sensor key (s2msi_ / lsoli_).
                sensor_results = {
                    "rdnbr_gt_200": portion,
                    "pre_burn_time": pre_fire_image.time[0].data,
                    "post_burn_time": post_fire_image.time[0].data,
                    "pre_burn_timedelta": get_timedelta_from_ds(pre_fire_image, hotspot_utc_time).delta[0],
                    "post_burn_timedelta": get_timedelta_from_ds(post_fire_image, hotspot_utc_time).delta[0],
                    "pre_percent": pre_hotspots / pre_valid_count,
                    "post_percent": post_hotspots / post_valid_count,
                    "pre_hotspots": pre_hotspots,
                    "post_hotspots": post_hotspots,
                }
                for column, value in sensor_results.items():
                    join.loc[hotspot_index, f"{sensor}_{column}"] = value

                # Collect results
                join_results_list.append(join.loc[hotspot_index])

                print(f"{hotspot_index} with {post_hotspots} {sensor} hotspots based on {str(join.loc[hotspot_index, 'satellite_sensor_product'])}")
            except Exception as err:
                # Exception (not a bare except) so KeyboardInterrupt still stops the loop.
                print(f'processing failed for hotspot {hotspot_index}: {err!r}')

In [None]:
# Gather results as one-row DataFrames. Each entry is normally a Series (one hotspot
# record). If a hotspot index was repeated in the spatial join, join.loc returned a
# multi-row DataFrame; only its first row is kept. Dispatching on type replaces the
# old `len(i) < 31` check, which mishandled one-row DataFrames and short Series.
fixed_join_results_list = []
for record in join_results_list:
    if isinstance(record, pd.DataFrame):
        fixed_join_results_list.append(record.iloc[[0]])
    else:
        fixed_join_results_list.append(record.to_frame().transpose())

In [None]:
# Rows whose Sentinel-2 pre-burn time is still the "" placeholder came from the
# Landsat pass; all others came from the Sentinel-2 pass.
s2msi_results_list = []
lsoli_results_list = []
for result_frame in fixed_join_results_list:
    destination = lsoli_results_list if result_frame.s2msi_pre_burn_time.values == "" else s2msi_results_list
    destination.append(result_frame)

In [None]:
# Combine the Sentinel-2 result rows into one GeoDataFrame (NOTE: pd.concat raises if the list is empty).
s2msi_results_gpd = gpd.GeoDataFrame(pd.concat(s2msi_results_list))

In [None]:
# Combine the Landsat result rows into one GeoDataFrame (NOTE: pd.concat raises if the list is empty).
lsoli_results_gpd = gpd.GeoDataFrame(pd.concat(lsoli_results_list))

In [None]:
import pickle

# Persist the gathered per-hotspot records. The context manager closes (and
# flushes) the file; previously the handle was left open.
with open('fixed_join_results_list.pkl', 'wb') as filehandler:
    pickle.dump(fixed_join_results_list, filehandler)

In [None]:
# Per source hotspot product: how many sampled hotspots had at least one hotspot
# pixel detected in the Landsat post-fire image.
for index, product in lsoli_results_gpd.groupby('satellite_sensor_product'):
    candidate_hotspots = len(product)
    is_confirmed = product['lsoli_post_hotspots'] > 0
    highres_hotspots = len(product[is_confirmed])
    percent_confirmed = "{}%".format(str(highres_hotspots / candidate_hotspots * 100)[:5])
    print(index, candidate_hotspots, highres_hotspots, percent_confirmed)

In [None]:
# Per source hotspot product: how many sampled hotspots had at least one hotspot
# pixel detected in the Sentinel-2 post-fire image.
for index, product in s2msi_results_gpd.groupby('satellite_sensor_product'):
    candidate_hotspots = len(product)
    is_confirmed = product['s2msi_post_hotspots'] > 0
    highres_hotspots = len(product[is_confirmed])
    percent_confirmed = "{}%".format(str(highres_hotspots / candidate_hotspots * 100)[:5])
    print(index, candidate_hotspots, highres_hotspots, percent_confirmed)

In [None]:
# Save both result tables as pickles and as CSV. The context managers close and
# flush each pickle file; previously both handles were left open.
with open('lsoli_results_gpd.pkl', 'wb') as filehandler:
    pickle.dump(lsoli_results_gpd, filehandler)
lsoli_results_gpd.to_csv('lsoli_results_gpd.csv')

with open('s2msi_results_gpd.pkl', 'wb') as filehandler:
    pickle.dump(s2msi_results_gpd, filehandler)
s2msi_results_gpd.to_csv('s2msi_results_gpd.csv')

In [None]:
# Map of the processed hotspots: Sentinel-2 pass in blue, Landsat pass in red.
fig, ax = plt.subplots(figsize=(8, 8))
ax.set_aspect('equal')
ax.axis('off')
australia.plot(ax=ax, color='white', edgecolor='black')
for results_gdf, marker_colour in ((s2msi_results_gpd, 'blue'), (lsoli_results_gpd, 'red')):
    results_gdf.plot(ax=ax, marker='o', color=marker_colour, markersize=2)
plt.savefig('random_ls_s2.png')
plt.show()