In [None]:
# Load modules
import glob
import os
import pdb
import sys
import warnings

import datacube
import fiona
import geopandas as gp
import graphviz
import matplotlib.image as mpimg
import numpy as np
import pandas as pd
import rasterio
import sklearn
import xarray as xr
from datacube import helpers
from datacube.utils import geometry
from datacube.utils.geometry import CRS
from matplotlib import pyplot as plt

# Import external functions from dea-notebooks using relative link to 10_Scripts
sys.path.append('/g/data/u46/users/sc0554/dea-notebooks/Scripts')
from dea_classificationtools import get_training_data_for_shp
from dea_classificationtools import predict_xr
from dea_plotting import display_map

## Extract training data

In [None]:
# Collect training samples from every 2015 training shapefile and stack them
# into one 2-D array. Later cells use column 0 as the class label and the
# remaining columns as features.
shp_list = glob.glob('/g/data1a/r78/LCCS_Aberystwyth/training_data/2015/*.shp')
out_train = []
for shp_num, path in enumerate(shp_list):
    print("[{:02}/{:02}]: {}".format(shp_num+1, len(shp_list), path))
    try:
        # Use the function name imported at the top of the notebook. The module
        # itself is never bound to `dea_classificationtools`, so the previous
        # module-qualified call raised NameError for every shapefile.
        # presumably the helper appends its samples to `out_train` - confirm
        column_names = get_training_data_for_shp(path, out_train, product='ls8_nbart_tmad_annual',
                                                 time=('2015-01-01', '2015-12-31'),
                                                 crs='EPSG:3577', field='classnum')
    except Exception as e:
        # Best effort: report the failure and carry on with the next shapefile.
        print("Failed to extract data: {}".format(e))
    else:
        # Only report success when the extraction did not raise.
        print("\n extracted pixels")

# Fail with an explicit message instead of numpy's generic
# "need at least one array to concatenate" when nothing was extracted.
if not out_train:
    raise ValueError("No training data could be extracted from {} shapefile(s)".format(len(shp_list)))

model_input = np.vstack(out_train)
print(model_input.shape)
# np.savetxt("train_input_tmad.txt", model_input, header = ' '.join(column_names), fmt = '%.4f')

## Train model

In [None]:
# Read the training table from a space-delimited text file, skipping its first
# line. This replaces any `model_input` built by the extraction cell.
# NOTE(review): the commented-out export above writes "train_input_tmad.txt",
# whereas this reads 'train_input.txt' - confirm which file is intended.
model_input = np.loadtxt(
    'train_input.txt',
    delimiter=" ",
    skiprows=1,
)

In [None]:
### RANDOM FOREST
from sklearn.ensemble import RandomForestClassifier
# Initialise classifier. The seed is fixed so that re-running the notebook
# rebuilds the same forest, consistent with the decision tree cell which also
# uses random_state=0.
model = RandomForestClassifier(n_estimators=100, random_state=0, verbose=2, n_jobs=-1)
# Fit classifier: column 0 of model_input holds the class labels, the remaining
# columns are the features. To train a single-class (binary) model instead,
# compare the label column to one class code, e.g. model_input[:,0] == 215.
model = model.fit(model_input[:,1:], model_input[:,0])

In [None]:
### Decision tree
from sklearn import tree

# Build a seeded tree limited to depth 5 and train it in one step on the
# training table (column 0 = class labels, remaining columns = features).
# This rebinds `model`, replacing any classifier fitted in an earlier cell.
# To train a single-class (binary) model instead, compare the label column to
# one class code, e.g. model_input[:, 0] == 215.
model = tree.DecisionTreeClassifier(max_depth=5, random_state=0).fit(
    model_input[:, 1:],
    model_input[:, 0],
)

## Evaluate model

In [None]:
# Score the fitted model against the same rows it was trained on, so the
# displayed value is a training score rather than a held-out estimate.
model.score(
    model_input[:, 1:],
    model_input[:, 0],
)

In [None]:
# Names of the model's input features and human-readable class labels.
# NOTE(review): `data` is not defined in any visible cell of this notebook; it
# is presumably a previously loaded xarray Dataset - confirm, or define it
# before running this cell.
feature_names = list(data.data_vars)
print(feature_names)
# Label spelling corrected ('Artificail' -> 'Artificial').
target_names = np.array(('Natural Terrestrial Vegetated', 'Artificial Surface', 'Natural Surface', 'Artificial Water', 'Natural Water'))
print(target_names)

In [None]:
# Draw the structure of the fitted tree on a wide (25 x 8) figure.
plt.figure(figsize=(25, 8))
sklearn.tree.plot_tree(decision_tree=model)

In [None]:
# Run the fitted model over the feature columns of the training table.
predict_out = model.predict(
    model_input[:, 1:]
)

## Prediction

In [None]:
# Define the area and time range to predict land cover for.
# x and y are coordinate ranges expressed in the CRS stored in query['crs'].
#
# Alternative areas - copy one pair over the active x/y below to switch:
#
# Lake Eyre
# x = (550000, 600000)
# y = (-3000000, -2950000)
# x = (-1000000, -950000)
# y = (-3400000, -3350000)
# x = (-1200000, -1299850)
# y = (-3600000, -3500125)
#
# Coorong
# x = (600000, 700000)
# y = (-3950000, -3850000)

# Kakadu
x = (0, 100000)
y = (-1350000, -1250000)

# Build the whole query in a single literal: time range, spatial extent, CRS.
query = {
    'time': ('2015-01-01', '2015-02-01'),
    'x': tuple(x[:2]),
    'y': tuple(y[:2]),
    'crs': 'EPSG:3577',
}

In [None]:
# Preview the selected x/y extent on a map before loading any data.
display_map(
    x,
    y,
    crs="EPSG:3577",
)

In [None]:
# Connect to the datacube. No visible cell creates `dc`, so on a fresh kernel
# the load below failed with NameError.
dc = datacube.Datacube(app='LCCS_classification')

# Load the imagery to classify for the area and period described by `query`,
# grouping observations by solar day.
# NOTE(review): `product` is not defined in any visible cell - set it to the
# product name whose bands match the features the model was trained on.
new_data = dc.load(product=product, group_by='solar_day', **query)

In [None]:
# Quick visual check: plot the blue band of the first time step.
new_data['blue'].isel(time=0).plot()

In [None]:
# Run calculate_indices once per index, feeding each result into the next call.
# The original cell computed 'BSI' twice in a row; the repeated call is dropped
# here. If the second 'BSI' was meant to be a different index, add it to the
# tuple below.
# NOTE(review): `calculate_indices` is not imported in any visible cell -
# confirm where it comes from and import it in the first cell.
for index_name in ('BUI', 'BSI', 'NBI', 'EVI', 'NDWI', 'MSAVI'):
    new_data = calculate_indices(new_data, index_name, collection='ga_ls_2')
# new_data = new_data.drop(bands)

In [None]:
# Apply the fitted model to the prepared dataset.
# `predict_xr` is called through the name imported in the first cell: the
# module itself is never bound to `dea_classificationtools`, so the previous
# module-qualified call raised NameError.
predicted = predict_xr(model, new_data)

In [None]:
# Take the first time step of the prediction, transpose it, and wrap it in a
# Dataset with a single variable named "LCCS_L3".
out = predicted.isel(time=0).transpose()
out = out.to_dataset(name="LCCS_L3")
# Tag the output with the CRS of the dataset the prediction was made from.
# The original read it from `data`, which no visible cell defines.
# assumes `new_data` still exposes `.crs` after the index calculation - confirm
out.attrs['crs'] = geometry.CRS(new_data.crs)
# out = out.isel(time=0)

In [None]:
# Export the classified dataset as a GeoTIFF named 'dtreekak.tif' (relative
# path, so it is written to the current working directory).
helpers.write_geotiff(
    'dtreekak.tif',
    out,
)