# Testing and Comparing USGS Collection 1 & Collection 2

## Import modules

In [None]:
%matplotlib inline
import datacube.utils.rio
import datacube
from odc.ui import DcViewer
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from datacube.storage.masking import make_mask, describe_flags_def, describe_variable_flags
from datacube_stats.statistics import GeoMedian
from datacube.helpers import write_geotiff

from Scripts import dea_plotting
from Scripts import dea_datahandling
from Scripts import dea_bandindices

## Connect to the datacube

In [None]:
# Datacube handle used by every query and load in this notebook.
dc = datacube.Datacube(app='test_collections')

# Set the default rasterio configuration (aws='auto', cloud_defaults=True);
# this speeds up loading data.
datacube.utils.rio.set_default_rio_config(
    aws='auto',
    cloud_defaults=True,
)


### Inspect/view datasets

In [None]:
# Show every product indexed in this datacube (e.g. to check the product
# names used in the loads below).
dc.list_products()

In [None]:
# Interactive viewer for the Collection 2 product.
# NOTE(review): time='201' looks truncated (perhaps '2019'?) — confirm.
DcViewer(
    dc=dc,
    time='201',
    width='500px',
    products=['usgs_ls8c_level2_2'],
    zoom=3,
)

## User inputs

In [None]:
# Centre of the area of interest and the half-width of the box around it
# (same units as lat/lon).
lat, lon = 24.50, 33.0
buff = 0.15

# Date range to load, and the output resolution (y, x) passed to dc.load.
time = ('2013-01', '2019-06')
res = (-30, 30)

### View the selected location

In [None]:
# Preview the area of interest. The eastern edge is pushed out by 0.275 so
# the map covers the wider Collection 2 query.
dea_plotting.display_map(
    x=(lon - buff, lon + buff + 0.275),
    y=(lat + buff, lat - buff),
)

## Load data from both collections

In [None]:
# Query objects for dc.load / mostcommon_crs.
# Collection 1: a box of half-width `buff` around (lat, lon).
queryC1 = {
    'x': (lon - buff, lon + buff),
    'y': (lat + buff, lat - buff),
    'time': time,
    'resolution': res,
    'group_by': 'solar_day',
}

# Collection 2: identical to queryC1 except that the eastern x edge is
# extended by 0.275 (the same extension used by the map preview).
# NOTE(review): the two extents differ, so C1/C2 comparisons only cover the
# overlap — confirm this is intended.
queryC2 = {**queryC1, 'x': (lon - buff, lon + buff + 0.275)}

In [None]:
# Use the CRS shared by most Collection 2 datasets that match queryC2, so the
# data can be loaded in its native projection (passed as output_crs below).
crs = dea_datahandling.mostcommon_crs(
    dc=dc,
    product='usgs_ls8c_level2_2',
    query=queryC2,
)

print(crs)

In [None]:
# col1 = dc.load(product='ls8_usgs_sr_scene',
#                **queryC1,
#                output_crs=crs,
#                align=(15, 15),
#                dask_chunks={})

col2 = dc.load(product='usgs_ls8c_level2_2',
               **queryC2,
               output_crs=crs,
               align=(15, 15),
               dask_chunks={})


### Cloud mask

In [None]:
col2

In [None]:
# valid_data_C1 = make_mask(col1.pixel_qa,
#                           cloud="no_cloud",
#                           cloud_shadow="no_cloud_shadow",
#                           nodata=False)

valid_data_C2 = make_mask(col2["quality_l2_aerosol"],
                          cloud_shadow="not_cloud_shadow",
                          cloud_or_cirrus="not_cloud_or_cirrus")

col1 = col1.where(valid_data_C1)
col2 = col2.where(valid_data_C2)

#drop the variables we don't care about
col1 = col1.drop(list(col1.data_vars)[7:])
col2 = col2.drop(list(col2.data_vars)[7:])

## Plot RGB

In [None]:
bands = ['red', 'green', 'blue']
t = 3

# Take the t-th Collection 1 acquisition and the Collection 2 acquisition
# closest to it in time. The two collections are separate loads (with
# different extents), so the same integer index can refer to different dates.
c1_scene = col1[bands].isel(time=t)
c2_scene = col2[bands].sel(time=c1_scene.time.values, method='nearest')

print("USGS Collection 1 left, USGS Collection 2 right")
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
c1_scene.squeeze().to_array().plot.imshow(robust=True, ax=ax[0])
c2_scene.squeeze().to_array().plot.imshow(robust=True, ax=ax[1])
plt.show()


## Scale Collection 2 values

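Convert the Collection 2 surface-reflectance values to the Collection 1 scale (reflectance × 10,000) and rename the SWIR bands to the Collection 1 names.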

In [None]:
# Collection 2 Level-2 surface reflectance is stored as scaled integers:
# reflectance = DN * 2.75e-5 - 0.2. Multiplying by 10000 puts the values on the
# same scale as Collection 1.
C2_SCALE = 2.75e-5
C2_OFFSET = -0.2
C1_SCALE = 10000

# Collection 2 band name -> output name. The SWIR bands are renamed to the
# Collection 1 names so the datasets can be subtracted band-by-band later.
c2_to_c1_bands = {
    'blue': 'blue',
    'red': 'red',
    'green': 'green',
    'nir': 'nir',
    'swir_1': 'swir1',
    'swir_2': 'swir2',
}

col2_scaled = xr.Dataset(
    {c1_name: (col2[c2_name] * C2_SCALE + C2_OFFSET) * C1_SCALE
     for c2_name, c1_name in c2_to_c1_bands.items()},
    coords={'time': col2.time, 'y': col2.y, 'x': col2.x},
    attrs=col2.attrs,
)

## Calculate geomedians over the full time range

Compute a geomedian for each collection, then subtract the C2 geomedian from the C1 geomedian.

#### Compute the geomedians

This step is slow because the lazily loaded dask arrays are computed at the same time.

In [None]:
# One geomedian per collection, each from a fresh GeoMedian instance (C1 first,
# then the rescaled C2). This computes the lazily loaded (dask) data, so it is
# slow.
geomedian_C1, geomedian_C2 = (
    GeoMedian().compute(dataset) for dataset in (col1, col2_scaled)
)

### Plot the C1 and C2 geomedians

In [None]:
bands = ['red', 'green', 'blue']

# Side-by-side true-colour geomedians. No time step is selected here.
# robust=True stretches each image using its 2nd-98th percentiles.
print("USGS Collection 1 left, USGS Collection 2 right")
fig, ax = plt.subplots(1, 2, figsize=(12, 5))
geomedian_C1[bands].to_array().plot.imshow(robust=True, ax=ax[0])
geomedian_C2[bands].to_array().plot.imshow(robust=True, ax=ax[1])
plt.show()

### Subtract to find the difference

In [None]:
# C1 minus C2, band by band. The band names already match because the
# Collection 2 SWIR bands were renamed to 'swir1'/'swir2' when col2_scaled
# was built.
geomedian_difference = geomedian_C1 - geomedian_C2
# Subtraction drops the dataset attributes (treated here as an xarray bug), so
# copy them back from the C2 geomedian.
# NOTE(review): presumably needed by write_geotiff (e.g. for the CRS) — confirm.
geomedian_difference.attrs = geomedian_C2.attrs

#### Write to GeoTIFF

The kernel often crashes and the difference is slow to compute, so export it to a GeoTIFF for easy reloading.

In [None]:
# Export the geomedian difference so it can be reloaded without recomputing.
# NOTE(review): presumably relies on the attrs copied onto
# geomedian_difference above — confirm.
write_geotiff('geomedian_difference.tif',geomedian_difference)

### Plot the difference

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(25, 14))

# One panel per band, on a fixed colour range symmetric about zero
# ('bwr': blue where C1 < C2, red where C1 > C2).
for panel, band_name in zip(axes.ravel(), geomedian_difference.data_vars):
    geomedian_difference[band_name].plot(
        ax=panel, cmap='bwr', vmin=-750, vmax=750
    )
    panel.set_title(band_name)