# Test application of ML classifier to data

Takes a pickled, trained scikit-learn model and applies it to data from Digital Earth Australia (DEA).

The model is trained using data extracted to a delimited text file, which can be downloaded (gzipped) from: https://rsg.pml.ac.uk/shared_files/dac/train_input_geomedian_tmad_subset.txt.gz

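The model is trained using something like the code below. Note that this example trains on only seven features (red, nir, swir1, swir2, sdev, edev, bcdev — columns 3–6 and 13–15). The variables passed to the model at prediction time must match, in number and order, the features the pickled model was actually trained on.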

```python
import pickle
import numpy
from sklearn.ensemble import RandomForestClassifier

# Read in text file
model_input = numpy.loadtxt("train_input_geomedian_tmad.txt", skiprows=1)

# Headers are
# classnum blue green red nir swir1 swir2 BUI BSI NBI EVI NDWI MSAVI sdev edev bcdev
column_names = 'classnum blue green red nir swir1 swir2 BUI BSI NBI EVI NDWI MSAVI sdev edev bcdev'.split()

# Set up model
model = RandomForestClassifier(n_estimators=100, n_jobs=-1, max_depth=10, max_features=3, verbose=2, oob_score=True)

# Train model
classifier = model.fit(model_input[:,[3, 4, 5, 6, 13, 14, 15]], model_input[:,0])

# Pickle model
with open('model_pickle.pickle', 'wb') as f:
    pickle.dump(classifier,f)

```


In [None]:
import os
import pickle
import sys

import datacube
from datacube import helpers
from datacube.utils import geometry
from matplotlib import pyplot
import numpy
import sklearn
import xarray
import yaml


# Load in modules from repos
sys.path.append('/home/jovyan/development/LCCS/decision_tree')
import dea_classificationtools

sys.path.append('/home/jovyan/development/dea-notebooks/Scripts')
from dea_bandindices import calculate_indices
from dea_plotting import display_map

sys.path.append('/home/jovyan/development/livingearth_lccs')
from le_lccs.le_classification import lccs_l3

In [None]:
# Directory that holds the pickled model and receives the exported GeoTIFFs
working_dir = "/home/jovyan/cultivated_classification"

In [None]:
# Load in pickled data.
# NOTE(review): pickle.load can execute arbitrary code -- only load model files
# from a trusted source (here, the file written by the training step).
with open(os.path.join(working_dir, 'model_pickle.pickle'), 'rb') as f:
    classifier = pickle.load(f)

# Variables (in this order) passed to the classifier at prediction time. They
# must match the features the pickled model was trained on -- note that the
# example training code in the introduction fits only 7 features
# (red nir swir1 swir2 sdev edev bcdev).
model_variable_names = 'blue green red nir swir1 swir2 BUI BSI NBI EVI NDWI MSAVI sdev edev bcdev'.split()

# Fail early with a clear message if the feature count disagrees with the
# model, instead of failing later inside the per-site prediction loop.
n_model_features = getattr(classifier, 'n_features_in_', None)
if n_model_features is None:
    # Older scikit-learn releases only expose n_features_
    n_model_features = getattr(classifier, 'n_features_', None)
if n_model_features is not None and n_model_features != len(model_variable_names):
    raise ValueError(
        'Classifier expects {} features but {} model variables are listed: {}'.format(
            n_model_features, len(model_variable_names), model_variable_names))

# Prediction

In [None]:
def run_classification_for_site(site_name, model_variable_names=None,
                                model=None, config=None):
    """
    Run the classification for a given site.

    Gets the bounds of the site from ``au_test_sites.yaml`` (unless an
    already-parsed ``config`` is supplied), loads the
    ``ls8_nbart_geomedian_annual`` and ``ls8_nbart_tmad_annual`` products
    for that area at 100 m resolution in EPSG:3577, adds the BUI, BSI, NBI,
    EVI, NDWI and MSAVI indices to the geomedian data, merges both products
    and passes the result to ``dea_classificationtools.predict_xr``.

    Parameters
    ----------
    site_name : str
        Key of the site in the config. Its entry must provide ``min_x``,
        ``max_x``, ``min_y`` and ``max_y`` in EPSG:3577 coordinates.
    model_variable_names : list of str, optional
        If given, only these variables are passed to the classifier. They
        must match the features the model was trained on.
    model : fitted scikit-learn estimator, optional
        Classifier to apply. Defaults to the notebook-level ``classifier``
        loaded from the pickle.
    config : dict, optional
        Already-parsed site config. Defaults to reading
        ``au_test_sites.yaml`` from the current working directory.

    Returns
    -------
    The output of ``dea_classificationtools.predict_xr`` (presumably an
    xarray object of predicted classes -- confirm).

    Raises
    ------
    KeyError
        If ``site_name`` is not in the config.
    """
    if model is None:
        # Fall back to the classifier loaded at notebook level
        model = classifier

    if config is None:
        # Read in config file with site bounds
        with open("au_test_sites.yaml", "r") as f:
            config = yaml.safe_load(f)

    try:
        site = config[site_name]
    except KeyError:
        raise KeyError('Site {!r} not found in site config'.format(site_name)) from None

    query = {
        'time': ('2015-01-01', '2015-02-01'),
        'x': (site["min_x"], site["max_x"]),
        'y': (site["max_y"], site["min_y"]),
        'crs': 'EPSG:3577',
        'resolution': (-100, 100),
    }

    dc = datacube.Datacube(app='classifiers')

    geomedian_data = dc.load(product='ls8_nbart_geomedian_annual', group_by='solar_day',
                             dask_chunks={'x': 1000, 'y': 1000}, **query)

    # Add the spectral indices used as model features
    for index_name in ('BUI', 'BSI', 'NBI', 'EVI', 'NDWI', 'MSAVI'):
        geomedian_data = calculate_indices(geomedian_data, index_name, collection='ga_ls_2')

    mads_data = dc.load(product='ls8_nbart_tmad_annual', group_by='solar_day',
                        dask_chunks={'x': 1000, 'y': 1000}, **query)

    # Join geomedian + MADs into a single dataset
    new_data = xarray.merge([geomedian_data, mads_data])

    # Subset to just the variables the model uses
    if model_variable_names is not None:
        new_data = new_data[model_variable_names]

    return dea_classificationtools.predict_xr(model, new_data, progress=True)

### Loop through all sites and export a GeoTIFF

In [None]:
# Parse the site definitions (one entry of bounds per site name)
with open("au_test_sites.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

In [None]:
# Classify every site in the config and export its cultivated pixels as a GeoTIFF.
# geomedian_data only exists inside run_classification_for_site, so take the
# output CRS from the query CRS used there (EPSG:3577); keep the two in sync.
output_crs = geometry.CRS('EPSG:3577')
n_sites = len(config)

for site_number, site_name in enumerate(config, start=1):

    print("[{:02}/{:02}] {}".format(site_number, n_sites, site_name))

    predicted = run_classification_for_site(site_name, model_variable_names=model_variable_names)

    # Get only cultivated layer (class 111); other pixels are masked out by where()
    cultivated = predicted.where(predicted == 111)

    # Take the first time step and wrap it as a Dataset carrying a CRS for export
    out = cultivated.isel(time=0).to_dataset(name="cultivated")
    out.attrs['crs'] = output_crs

    out_name = '{}_cultivated.tif'.format(site_name.lower().replace(' ', '_'))
    helpers.write_geotiff(os.path.join(working_dir, out_name), out)

## Colour and plot the classified data (for the last site processed in the loop above)

In [None]:
# Colour the first slice along the leading dimension of the last site's prediction
first_slice = predicted[0]
red, green, blue, alpha = lccs_l3.colour_lccs_level3(first_slice)

In [None]:
# Stack the four colour bands depth-wise into an RGBA image and display it
rgba_image = numpy.dstack((red, green, blue, alpha))
pyplot.imshow(rgba_image)