# Imports and Functions

In [2]:
from datacube import Datacube
cdc = Datacube(config='/g/data/u46/users/ext547/ewater/cambodia_cube/cambodia.conf', app = "Polygon drill")
from datacube_stats.statistics import GeoMedian
from datacube.storage import masking
from datacube.storage.masking import mask_to_dict
from datacube.storage.storage import write_dataset_to_netcdf
from datacube.utils import geometry

import fiona
import rasterio.features
import geopandas as gpd

import numpy as np
import xarray as xr
import pickle

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.animation as animation
import matplotlib.patheffects as PathEffects
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
%matplotlib inline

from skimage import exposure

import calendar

#Import external dea-notebooks functions using relative link to Scripts directory
import sys
import os.path
sys.path.append('/g/data/u46/users/sc0554/dea-notebooks/10_Scripts/')
import DEAPlotting

In [None]:
def LoadAreaOfInterest(study_area):
    '''
    Load cloud-masked surface-reflectance data for a named study area.

    If ``<output_folder><study_area>.pkl`` exists, it is loaded and returned.
    Otherwise the bounding box is looked up with ``AreaOfInterest``. Then, for
    each sensor in ``sensors``:

    - the ``<sensor>_usgs_sr_scene`` product is loaded from the cube;
    - pixels flagged as cloud or cloud shadow in ``pixel_qa`` are masked;
    - scenes whose cloud-free fraction is below ``cloud_free_threshold`` are
      dropped.

    The per-sensor results are concatenated, sorted by time and pickled to
    ``output_folder`` for re-use.

    Relies on the notebook globals ``cdc``, ``output_folder``, ``sensors`` and
    ``cloud_free_threshold``. ``start_of_epoch`` and ``end_of_epoch`` must also
    be defined before calling.

    Last modified: March 2018

    Author: Erin Telfer

    Inputs:
    study_area - the name of the study area (case and spaces are ignored)

    Returns:
    The masked xarray Dataset. Freshly built data carries ``crs`` and
    ``affine`` attrs. Returns None (after printing a hint) if no bounding box
    is defined for ``study_area``.
    '''
    study_area = study_area.lower().replace(" ", "")
    pickle_location = '{0}{1}.pkl'.format(output_folder, study_area)

    # Re-use a cached result if present. NOTE: pickle.load can execute
    # arbitrary code, so only load pickles this workflow wrote itself.
    try:
        with open(pickle_location, "rb") as pickle_file:
            nbar_clean = pickle.load(pickle_file)
    except FileNotFoundError:
        print("No {0}.pkl file found on file".format(study_area))
    else:
        print("Nbar pickle has been found on file")
        print("Nbar pickle has been loaded")
        return nbar_clean

    # Guard only the bounding-box lookup, so unrelated errors during loading
    # are not misreported as missing coordinates. An unknown area surfaces as:
    # - ValueError or TypeError from unpacking a bad result, or
    # - UnboundLocalError from older versions of AreaOfInterest.
    try:
        lat_min, lat_max, lon_min, lon_max = AreaOfInterest(study_area)
    except (TypeError, ValueError, UnboundLocalError):
        print("please add lat/lon details to AreaOfInterest function")
        return None
    print("Location information from the AreaOfInterest function has been read")

    print("Loading Cambodia Cube data")

    # Bands to load; pixel_qa is only needed for masking and is dropped later.
    bands_of_interest = ['green', 'swir1', 'nir', 'pixel_qa']

    query = {
        'time': (start_of_epoch, end_of_epoch),
        'x': (lon_min, lon_max),
        'y': (lat_max, lat_min),
        'crs': 'EPSG:4326',
    }

    sensor_clean = {}
    crs = affine = None
    for sensor in sensors:
        # As before, the georeferencing of the last sensor loaded is kept.
        sensor_clean[sensor], crs, affine = _load_clean_sensor(
            sensor, bands_of_interest, query)
        print('loaded %s' % sensor)
    print('ls load complete')

    if not sensor_clean:
        raise ValueError("no sensors to load; set the `sensors` list first")

    # Join the sensors, then order observations by time rather than by sensor.
    nbar_clean = xr.concat(list(sensor_clean.values()), 'time').sortby('time')
    nbar_clean.attrs['crs'] = crs
    nbar_clean.attrs['affine'] = affine

    print("saving nbar data as {0}.pkl".format(study_area))
    # The context manager guarantees the pickle is flushed and closed.
    with open(pickle_location, "wb") as pickle_file:
        pickle.dump(nbar_clean, pickle_file)
    return nbar_clean


def _load_clean_sensor(sensor, bands, query):
    '''Load one sensor's scenes, mask cloud/shadow, return (data, crs, affine).'''
    sensor_nbar = cdc.load(product=sensor + '_usgs_sr_scene',
                           measurements=bands, group_by='solar_day',
                           **query)  # optionally: dask_chunks={'x': 200, 'y': 200}
    # Capture the projection information before masking.
    crs = sensor_nbar.crs
    affine = sensor_nbar.affine

    # True where the pixel is neither cloud nor cloud shadow. This one mask
    # serves both the per-pixel masking and the per-scene cloud-free fraction.
    cloud_free = masking.make_mask(sensor_nbar.pixel_qa,
                                   cloud_shadow='no_cloud_shadow',
                                   cloud='no_cloud')
    masked = sensor_nbar.where(cloud_free.loc[start_of_epoch:end_of_epoch])
    del sensor_nbar  # release the unmasked data early

    # Keep only scenes whose cloud-free fraction meets the threshold.
    mostly_cloud_free = cloud_free.mean(dim=('x', 'y')) >= cloud_free_threshold
    del cloud_free
    mostly_good = masked.where(mostly_cloud_free).dropna(dim='time', how='all')
    del masked

    # Drop scenes whose spatial mean is below -9998.
    # Presumably this catches -9999 nodata fill; TODO confirm.
    nodata_mask = mostly_good.mean(dim=('x', 'y')) >= -9998
    mostly_good = mostly_good.drop('pixel_qa')
    mostly_good = mostly_good.where(nodata_mask).dropna(dim='time', how='all')
    return mostly_good, crs, affine

In [4]:
# Bounding boxes for each named study area, in decimal degrees
# (queried as EPSG:4326 by LoadAreaOfInterest).
# Each entry is (lat_min, lat_max, lon_min, lon_max), i.e. down, up, left, right.
_STUDY_AREA_BOUNDS = {
    'phumsrahkaev': (13.000, 13.100, 103.300, 103.400),
    'outapaong': (12.600, 12.800, 103.600, 103.800),
    'mondulkiri': (12.863, 13.663, 106.350, 107.236),
    'krongstungtreng': (13.181, 13.681, 105.781, 106.381),
    'kaohnheaek': (13.000, 13.100, 107.000, 107.100),
    'neakleoang': (11.246, 11.532, 105.141, 105.380),
    'tonlesaplake': (13.020, 13.120, 103.740, 103.840),
    'maximum_extent': (9.25, 15.25, 101.75, 108.25),
    'kampongchhnang': (12, 12.25, 104.75, 105.0),
}


def AreaOfInterest(study_area):
    '''
    Return the bounding box of a named study area.

    The name is matched case-insensitively with spaces removed, the same
    normalisation LoadAreaOfInterest applies.

    Inputs:
    study_area - the name of the study area

    Returns:
    (lat_min, lat_max, lon_min, lon_max) in decimal degrees.

    Raises:
    ValueError - if no bounds are defined for ``study_area``. Previously this
    printed 'FileNotFoundError' and then crashed with UnboundLocalError.
    '''
    key = study_area.lower().replace(" ", "")
    try:
        return _STUDY_AREA_BOUNDS[key]
    except KeyError:
        raise ValueError(
            "No lat/lon bounds defined for study area {0!r}; add it to "
            "_STUDY_AREA_BOUNDS".format(study_area)) from None

In [5]:
def write_your_netcdf(data, dataset_name, filename, crs):
    """
    Write an xarray object to a spatially-referenced netCDF file.

    A DataArray is first wrapped in a Dataset under the variable name
    ``dataset_name``; a Dataset is used as-is. ``crs`` is stored in the
    dataset's ``attrs['crs']`` and the result is written with datacube's
    ``write_dataset_to_netcdf``. For a Dataset input, this sets the attribute
    on the caller's own object.

    Inputs:
    data - xarray.DataArray or xarray.Dataset to write
    dataset_name - variable name to use when ``data`` is a DataArray
    filename - output path of the netCDF file
    crs - coordinate reference system to attach (e.g. ``data.crs``)

    Raises:
    TypeError - if ``data`` is neither a DataArray nor a Dataset

    Last modified: May 2018
    Author: Bex Dunn
    """
    # Turn an array into a dataset so it can be written as netCDF.
    if isinstance(data, xr.DataArray):
        dataset = data.to_dataset(name=dataset_name)
    elif isinstance(data, xr.Dataset):
        dataset = data
    else:
        # The old message did `str + type`, which itself raised an unrelated
        # TypeError before the intended message could be shown.
        raise TypeError(
            'your data might be the wrong type, it is: {0}'.format(type(data)))

    # Attach the CRS so the written file is spatially referenced.
    dataset.attrs['crs'] = crs

    try:
        write_dataset_to_netcdf(dataset, filename)
    except RuntimeError as err:
        # Report write failures instead of aborting the notebook.
        print("RuntimeError: {0}".format(err))

In [6]:
#Define function to plot every time step of a one-band dataset as subplots
def one_band_image_subplots(ds, num_cols, figsize=(10, 40), left=0.125,
                            right=0.9, bottom=0.1, top=0.9,
                            wspace=0.2, hspace=0.4, band='rainfall'):
    '''
    Plot each time step of one band of a dataset as an image in a grid.

    The grid has ``num_cols`` columns and as many rows as needed to hold every
    time step. Grid cells left over after the last time step are removed.

    Last modified: March 2018
    Author: Mike Barnes
    Modified by: Claire Krause and Erin Telfer

    Inputs:
    ds - Dataset with a ``time`` dimension containing the band to plot
    num_cols - number of columns for the subplot grid

    Optional:
    figsize - dimensions for the output figure
    left  - the space on the left side of the subplots of the figure
    right - the space on the right side of the subplots of the figure
    bottom - the space on the bottom of the subplots of the figure
    top - the space on the top of the subplots of the figure
    wspace - the amount of width reserved for blank space between subplots
    hspace - the amount of height reserved for white space between subplots
    band - name of the data variable to plot (default 'rainfall')

    Returns:
    fig - the matplotlib Figure containing the subplots
    '''
    # Number of rows needed to fit every time step in num_cols columns.
    timesteps = ds.time.size
    num_rows = int(np.ceil(timesteps / num_cols))
    # plt.subplots creates the figure itself; no separate plt.figure() call
    # is needed (that only produced an extra blank figure).
    fig, _ = plt.subplots(num_rows, num_cols, figsize=figsize)
    fig.subplots_adjust(left=left, right=right, bottom=bottom, top=top,
                        wspace=wspace, hspace=hspace)

    image_band = ds[band]
    removed_axes = False
    # Iterate over a snapshot because spare axes are deleted during the loop.
    for i, ax in enumerate(list(fig.axes)):
        if i >= timesteps:
            # Remove every spare grid cell. The old IndexError handler only
            # removed the first one.
            fig.delaxes(ax)
            removed_axes = True
            continue
        image_ds = image_band.isel(time=i)
        ax.set_title(str(image_ds.time.values)[0:10])
        ax.imshow(image_ds, interpolation='nearest')  # plot image as subplot
    if removed_axes:
        plt.draw()
    return fig

# Query the datacube for Landsat data and calculate geomedians

**Read in SPEI data**

In [18]:
# Open the unmasked and masked SPEI outputs, in that order.
spei_paths = (
    '/g/data/u46/users/sc0554/drought_indices_cambodia/climate_indices_output/cambodia_spei.nc',
    '/g/data/u46/users/sc0554/drought_indices_cambodia/climate_indices_output/cambodia_masked_spei.nc',
)
cambodia_spei, cambodia_masked_spei = (xr.open_dataset(path) for path in spei_paths)

**Calculate quantiles**

In [19]:
# Average the masked SPEI fields over space, giving one value per time step.
spatial_dims = ('latitude', 'longitude')
cambodia_mean_spei_03 = cambodia_masked_spei['spei_gamma_03'].mean(dim=spatial_dims)
cambodia_mean_spei_06 = cambodia_masked_spei['spei_gamma_06'].mean(dim=spatial_dims)

In [22]:
# Minimum, lower quartile, median, upper quartile and maximum of the
# spatially-averaged SPEI-3 series, as a plain Python list.
quantile_levels = [0, 0.25, 0.5, 0.75, 1]
spei_quantiles = cambodia_mean_spei_03.quantile(
    quantile_levels, dim=['time'], keep_attrs=True,
).values.tolist()
print(spei_quantiles)

[-2.199103593826294, -0.48335041105747223, 0.03167547658085823, 0.495951808989048, 2.0509073734283447]


In [23]:
# Quartiles: split the spatially-averaged SPEI-3 series into four bins bounded
# by the quantiles above. Each bin is [lower, upper); the last bin also
# includes the maximum.
spei_series = cambodia_mean_spei_03
q_min, q_25, q_50, q_75, q_max = spei_quantiles

spei_q1 = spei_series.where((spei_series >= q_min) & (spei_series < q_25), drop=True)
spei_q2 = spei_series.where((spei_series >= q_25) & (spei_series < q_50), drop=True)
spei_q3 = spei_series.where((spei_series >= q_50) & (spei_series < q_75), drop=True)
spei_q4 = spei_series.where((spei_series >= q_75) & (spei_series <= q_max), drop=True)

In [115]:
# Hard boundaries

# spei_q1 = cambodia_mean_spei_03.where((cambodia_mean_spei_03 < -1) &
#                                  (cambodia_mean_spei_03 >= -2.199104), drop=True)
# spei_q2 = cambodia_mean_spei_03.where((cambodia_mean_spei_03 < 0) &
#                                  (cambodia_mean_spei_03 >= -1), drop=True)
# spei_q3 = cambodia_mean_spei_03.where((cambodia_mean_spei_03 < 1) &
#                                  (cambodia_mean_spei_03 >= 0), drop=True)
# spei_q4 = cambodia_mean_spei_03.where((cambodia_mean_spei_03 <= 2.050907) &
#                                  (cambodia_mean_spei_03 >= 1), drop=True)

In [116]:
# Write the time coordinate of each SPEI bin to its own netCDF file.
# NOTE(review): these files are named "_hard", but the hard-boundary cell above
# is commented out, so when run top-to-bottom the bins written here come from
# the quantile split. Confirm which split is intended.
for spei_bin, dates_path in ((spei_q1, "spei_q1_dates_hard.nc"),
                             (spei_q2, "spei_q2_dates_hard.nc"),
                             (spei_q3, "spei_q3_dates_hard.nc"),
                             (spei_q4, "spei_q4_dates_hard.nc")):
    spei_bin.time.to_netcdf(dates_path)

In [None]:
# Landsat sensors to load from the Cambodia cube (read by LoadAreaOfInterest).
sensors = ['ls5', 'ls8']

# Minimum fraction of cloud-free pixels a scene needs in order to be kept.
cloud_free_threshold = 0.1

# Folder where LoadAreaOfInterest caches its <study_area>.pkl file.
# NOTE(review): LoadAreaOfInterest also reads `study_area`, `start_of_epoch`
# and `end_of_epoch`, none of which are defined in this notebook; define them
# here before loading.
output_folder = '/g/data/u46/users/sc0554/spei_geomedians/'

In [None]:
# NOTE(review): `study_area` is not defined anywhere in this notebook. The
# commented-out path below suggests 'kampongchhnang'; confirm and define it.
nbar_clean = LoadAreaOfInterest(study_area=study_area)
# Alternative: load a previously cached result directly.
# nbar_clean= pickle.load( open("/g/data/u46/users/sc0554/drought_indices_cambodia/spei_geomedians/kampongchhnang.pkl", "rb" ) )

## Select Landsat data using SPEI quartile dates

For each SPEI quartile, select every Landsat observation acquired in one of that quartile's months.

In [None]:
# Truncate each Landsat acquisition time to its calendar month, for grouping.
months = nbar_clean['time'].astype('datetime64[M]')

In [None]:
# Month labels belonging to each SPEI quartile bin. They are converted once
# here rather than on every iteration of the loop below.
# The previous version read spei_q1_dates..spei_q4_dates, which are never
# defined in this notebook (NameError). The quartile DataArrays computed above
# carry the same time coordinate that was written to the *_dates_hard.nc files.
quartile_months = [quartile.time.values.astype('datetime64[M]')
                   for quartile in (spei_q1, spei_q2, spei_q3, spei_q4)]
quartile_nbar = [[], [], [], []]

# Assign each month of Landsat data to every quartile whose months include it.
for index, ds in nbar_clean.groupby(months):
    for months_in_quartile, bucket in zip(quartile_months, quartile_nbar):
        if index in months_in_quartile:
            bucket.append(ds)

# Re-assemble each quartile's months into one time series.
# xr.concat raises if a quartile matched no Landsat months.
spei_nbar_q1, spei_nbar_q2, spei_nbar_q3, spei_nbar_q4 = (
    xr.concat(bucket, dim='time') for bucket in quartile_nbar)

In [None]:
#SPEI geomedians
spei_q1_geomedian = GeoMedian().compute(spei_q1)
spei_q2_geomedian = GeoMedian().compute(spei_q2)
spei_q3_geomedian = GeoMedian().compute(spei_q3)
spei_q4_geomedian = GeoMedian().compute(spei_q4)
    

In [None]:
from datacube_stats.statistics import GeoMedian

# One Landsat geomedian composite per SPEI quartile, kept in quartile order.
geomedians = [GeoMedian().compute(nbar_subset)
              for nbar_subset in (spei_nbar_q1, spei_nbar_q2,
                                  spei_nbar_q3, spei_nbar_q4)]
(spei_nbar_q1_geomedian, spei_nbar_q2_geomedian,
 spei_nbar_q3_geomedian, spei_nbar_q4_geomedian) = geomedians

In [None]:
# Write the quartile-2 Landsat geomedian, tagged with its own CRS.
# NOTE(review): the output path has no .nc extension; confirm this is intended.
q2_output_path = '/g/data/u46/users/sc0554/drought_indices_cambodia/spei_geomedians/spei_nbar_q2'
write_your_netcdf(data=spei_nbar_q2_geomedian,
                  dataset_name='spei_nbar_q2',
                  filename=q2_output_path,
                  crs=spei_nbar_q2_geomedian.crs)

In [None]:
#spei_nbar_q1_geomedian.to_netcdf("spei_nbar_q1_geomedian.nc")
# spei_nbar_q2_geomedian.to_netcdf("spei_nbar_q2_geomedian.nc")
# spei_nbar_q3_geomedian.to_netcdf("spei_nbar_q3_geomedian.nc")
# spei_nbar_q4_geomedian.to_netcdf("spei_nbar_q4_geomedian.nc")