# VII. Image classification with an ML algorithm

---
**Author(s):** Kenji Ose, Dino Ienco, Quentin Yeche - [UMR TETIS](https://umr-tetis.fr) / [INRAE](https://www.inrae.fr/)

---

## 1. Introduction

We will present a way to classify satellite images according to a land cover typology. Here we present only a few basic principles. This is not a course on machine learning techniques or best practices.

## 2. Import libraries

As usual, we import all the required Python libraries. The new one is `sklearn` (*scikit-learn*), a package that provides simple and efficient tools for predictive data analysis.

In [None]:
import pystac_client
import planetary_computer
import rasterio as rio
import stackstac
import geogif
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import colors as mplcol

from IPython.display import Image
import rich.table

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

import numpy as np

## 3. Getting a Sentinel-2 image

Here is some code you already know. We want to load one Sentinel-2 image.


In [None]:
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)

sel_item = catalog.get_collection("sentinel-2-l2a").get_item("S2B_MSIL2A_20230226T103919_R008_T31TEJ_20230228T091155")

In [None]:
table = rich.table.Table("Asset Key", "Description")
for asset_key, asset in sel_item.assets.items():
    table.add_row(asset_key, asset.title)

table

In [None]:
print("RGB preview of Sentinel-2 image")
print("---")
Image(url=sel_item.assets["rendered_preview"].href, width=250)

## 4. Looking for training dataset

For this example, we use a land-use/land-cover map generated by [Impact Observatory](https://www.impactobservatory.com/) as the training dataset. A true reference, with information acquired in the field or from expert knowledge (e.g. photo-interpretation), would be preferable.


In [None]:
bbox = [3.698959,43.501749,4.015503,43.687239]

search = catalog.search(
                collections=["io-lulc-9-class"],
                bbox=bbox,
                sortby="datetime"
                )

# convert stac catalog into item collection
items_lulc = search.item_collection()
for i in items_lulc: print(i)

We get a list of six items for our area of interest. We keep the last item, which corresponds to the year 2022. It is the closest in time to the Sentinel-2 image we want to classify, acquired on 26 February 2023.

In [None]:
sel_lulc = catalog.get_collection("io-lulc-9-class").get_item("31T-2022")
Image(url=sel_lulc.assets["rendered_preview"].href, width=250)

## 5. Data preparation

### 5.1. Converting `item` to `xarray`

We convert the satellite image and the labeled dataset to `xarray`.

In [None]:
#import rasterio as rio

band_list = ['B01','B02','B03','B04','B05','B06','B07','B08','B09','B11','B12']
ds_sentinel = stackstac.stack(sel_item, 
                     assets=band_list, 
                     epsg=2154, 
                     resolution=20,
                     bounds_latlon=bbox)

ds_lulc = stackstac.stack(sel_lulc, 
                     assets=['data'],
                     epsg=2154, 
                     resolution=20,
                     bounds_latlon=bbox)
labels = list(np.unique(ds_lulc.values))
print(f"lulc unique label values: {labels}")

### 5.2. Displaying the land use/cover (lulc) layer

In [None]:
# labels and related colors
legend = {1:'Water', 2:'Trees', 4:'Flooded vegetation', 
          5:'Crops', 7:'Built area', 8:'Bare ground', 11:'Rangeland'}
colors = {1:'blue', 2:'darkgreen', 4:'lightgreen', 
          5:'gold', 7:'darkred', 8:'tan', 11:'tomato'}
colors = list(colors.values()) 
legend = list(legend.values())

# plot
fig = plt.figure()
# colorbar configuration
cmap = mplcol.ListedColormap(colors)
# class boundaries: the label values plus an upper bound for the last class
bounds = labels + [12.0]
# ticks position: middle of each class interval
pos = [(lo + hi) / 2 for lo, hi in zip(bounds[:-1], bounds[1:])]
norm = mplcol.BoundaryNorm(bounds, cmap.N)
# plotting
img = plt.imshow(ds_lulc.squeeze(), cmap=cmap, norm=norm, 
                 aspect='equal', interpolation='nearest')
cbar = fig.colorbar(img)
cbar.set_ticks(pos)
cbar.set_ticklabels(legend)
# plot title
plt.title('io-lulc-9-class')
plt.show()
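
The class codes of the legend are not contiguous (1, 2, 4, 5, ...), which is why the plot relies on a `BoundaryNorm`. The sketch below illustrates the idea on its own, assuming one bin per class and an upper bound of 12. The names `class_codes`, `class_colors`, `demo_bounds`, `cmap_demo` and `norm_demo` are hypothetical and only serve this example.

```python
from matplotlib import colors as mplcol

# Class codes and colours of the legend used above (codes are not contiguous)
class_codes = [1, 2, 4, 5, 7, 8, 11]
class_colors = ['blue', 'darkgreen', 'lightgreen', 'gold',
                'darkred', 'tan', 'tomato']

# One bin per class: the class codes plus an upper bound for the last class
demo_bounds = class_codes + [12]
cmap_demo = mplcol.ListedColormap(class_colors)
norm_demo = mplcol.BoundaryNorm(demo_bounds, cmap_demo.N)

# Each class code falls in its own bin and gets its own colour index
indices = [int(norm_demo(code)) for code in class_codes]
print(indices)
```

Each class code should map to a distinct colour index, so every class is drawn with the colour listed at the same position in `class_colors`.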

## 6. Classification with ML methods

To classify satellite images, we need to split the reference dataset into two sub-datasets: 

- training dataset: to train the ML model,
- test dataset: to assess the performance of the trained model.

Some of the code below can be tricky to understand. Its purpose is to reshape the input data into the format expected by the scikit-learn functions.

### 6.1. Reshaping input data

Here, it's like playing with a Rubik's Cube... We move the axes of our arrays so that their shape conforms to what the scikit-learn functions expect: one row per pixel and one column per band.

<img src="https://upload.wikimedia.org/wikipedia/commons/1/10/Rubiks_cube.jpg?20051109150753" width="30%" alt="datacube nightmare">


In [None]:
# loading reference data
arr_lulc = ds_lulc.to_numpy()
y_data = arr_lulc

# loading and reshaping of satellite data
print("-----\nReshaping of satellite data array\n------")
arr_sentinel = ds_sentinel.to_numpy()
print(f"shape_step01: {arr_sentinel.shape}")
x = np.moveaxis(arr_sentinel.squeeze(), 0, -1)
print(f"shape_step02: {x.shape}")
X_data = x.reshape(-1, 11)
print(f"shape_step03: {X_data.shape}")
print("===\n")
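
To make the axis shuffling easier to follow, here is a minimal sketch on a tiny synthetic cube. It assumes the same axis order as the stacked image (date, band, row, column). The array `cube` and the names `pixels_last` and `pixel_table` are hypothetical.

```python
import numpy as np

# Hypothetical tiny cube: 1 date, 3 bands, 2 rows, 4 columns
n_time, n_bands, n_rows, n_cols = 1, 3, 2, 4
cube = np.arange(n_time * n_bands * n_rows * n_cols).reshape(
    n_time, n_bands, n_rows, n_cols)

# Drop the single-date axis, then move the band axis to the end
pixels_last = np.moveaxis(cube.squeeze(axis=0), 0, -1)  # (rows, cols, bands)
pixel_table = pixels_last.reshape(-1, n_bands)          # (rows * cols, bands)

# Row k of the table holds the spectrum of pixel (k // n_cols, k % n_cols)
row, col = 1, 2
k = row * n_cols + col
print(pixel_table.shape)
print(np.array_equal(pixel_table[k], cube[0, :, row, col]))
```

Because the pixels are flattened row by row, a vector of per-pixel values can later be reshaped to `(n_rows, n_cols)` to get an image back.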


### 6.2. Data normalization

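The scikit-learn library provides several ready-to-use methods to perform data normalization. Here, we choose `StandardScaler`, which standardizes each band by removing its mean and scaling it to unit variance. The resulting values are centered on 0 but are not bounded to a fixed range such as [-1, +1].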

In [None]:
# Satellite data normalization
scaler = StandardScaler().fit(X_data)
X_scaled = scaler.transform(X_data)
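
As a quick check of what `StandardScaler` does, the sketch below applies it to synthetic, skewed values standing in for reflectances. The data and the names `X_skewed` and `X_skewed_scaled` are hypothetical.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
# Hypothetical reflectance-like values: 1000 pixels, 3 bands, skewed distribution
X_skewed = rng.gamma(shape=2.0, scale=500.0, size=(1000, 3))

X_skewed_scaled = StandardScaler().fit_transform(X_skewed)

# Per-band mean is close to 0 and per-band standard deviation close to 1
print(X_skewed_scaled.mean(axis=0).round(6))
print(X_skewed_scaled.std(axis=0).round(6))
# The extreme values show that the output is not limited to a fixed interval
print(X_skewed_scaled.min(), X_skewed_scaled.max())
```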

### 6.3. Training and test datasets

In [None]:
# Split data
print("-----\nSplitting input data into training/test datasets\n-----")
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_data.ravel(), 
                                                    test_size=0.70, stratify = y_data.ravel())
print(f"X_train Shape: {X_train.shape}\nX_test Shape: {X_test.shape}")
print(f"y_train Shape: {y_train.shape}\ny_test Shape:{y_test.shape}")
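
The `stratify` argument keeps the class proportions of the labels in both subsets, which matters when some land cover classes are rare. The sketch below shows this on hypothetical imbalanced labels. The arrays `y_imb` and `X_imb` and the fixed `random_state` are assumptions made for the example.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced labels: 90% class 1, 10% class 2
y_imb = np.array([1] * 900 + [2] * 100)
X_imb = np.arange(y_imb.size).reshape(-1, 1)

X_tr, X_te, y_tr, y_te = train_test_split(
    X_imb, y_imb, test_size=0.70, stratify=y_imb, random_state=0)

# Share of the rare class in the training and test subsets
print(np.mean(y_tr == 2), np.mean(y_te == 2))
```

Both shares should stay close to 10%, whereas a purely random split gives no such guarantee for rare classes.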

### 6.4. Training a Random Forest model

Once the training and test datasets are ready, we can train the ML model with the `fit` method of the *scikit-learn* estimator. Here we choose the **Random Forest** algorithm.

In [None]:
from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier(n_estimators=10)
rf.fit(X_train, y_train)
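
Once fitted, a random forest also exposes `feature_importances_`, which gives a rough idea of which bands drive the decisions. The sketch below uses synthetic data from `make_classification` as a stand-in for pixel spectra. The names `X_syn`, `y_syn` and `rf_demo`, as well as the fixed `random_state`, are assumptions made for the example.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic stand-in for pixel spectra: 500 samples, 11 features, 3 classes
X_syn, y_syn = make_classification(
    n_samples=500, n_features=11, n_informative=5,
    n_classes=3, random_state=0)

# Fixing random_state makes the trained forest reproducible
rf_demo = RandomForestClassifier(n_estimators=10, random_state=0)
rf_demo.fit(X_syn, y_syn)

# One importance value per feature (band); the values sum to 1
importances = rf_demo.feature_importances_
print(importances.round(3))
```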

### 6.5. Inference

Here, we use the `predict` method of the trained model to predict the land cover type of the pixels in the test dataset. The prediction for the whole satellite image is computed in section 7.


In [None]:
# Predict the labels of test data
rf_pred = rf.predict(X_test)

### 6.6. Assessment

Now, we assess the predictions against the test dataset. Here, the scores are not very good. It would be interesting to improve the model by changing the training sampling strategy or by testing other model parameters.

In [None]:
print(f"Accuracy: {accuracy_score(y_test, rf_pred) * 100:.2f} %")
print(classification_report(y_test, rf_pred))
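
A confusion matrix is a common complement to these scores, since it shows which classes are mistaken for one another. The sketch below builds one from ten hypothetical reference and predicted labels. The arrays `y_true_demo` and `y_pred_demo` are made up for the example.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix

# Hypothetical reference and predicted labels for 10 pixels
y_true_demo = np.array([1, 1, 1, 2, 2, 2, 2, 5, 5, 5])
y_pred_demo = np.array([1, 1, 2, 2, 2, 2, 5, 5, 5, 1])

# Rows are reference classes, columns are predicted classes
cm = confusion_matrix(y_true_demo, y_pred_demo, labels=[1, 2, 5])
print(cm)

# Overall accuracy is the diagonal sum divided by the number of samples
acc = np.trace(cm) / cm.sum()
print(acc, accuracy_score(y_true_demo, y_pred_demo))
```

The same call could be applied to `y_test` and `rf_pred` to see which land cover classes the model confuses most often.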


## 7. Plotting results

In [None]:
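# image height and width, taken from the reference raster
height, width = arr_lulc.shape[-2:]
output_img = X_scaled.reshape((height, width, -1))[:, :, 0:3]
output = rf.predict(X_scaled).reshape((height, width))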

f = plt.figure(figsize=[12.8, 9.6])

f.add_subplot(2,2, 1)
plt.imshow((output_img))
plt.title('S2 image')

f.add_subplot(2,2, 2)
plt.imshow(output, cmap=cmap, norm=norm, 
           aspect='equal', interpolation='nearest')
plt.title('prediction')

for ax in f.get_axes():
    ax.label_outer()

plt.show(block=True)
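
The image panel above shows the first three standardized bands. One possible way to get a more natural-looking preview is to build a composite from the red, green and blue bands (`B04`, `B03`, `B02`) and apply a percentile stretch. The sketch below does this on a synthetic cube. The names `demo_bands`, `demo_cube` and `rgb_stretched`, and the 2nd/98th percentile bounds, are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical reflectance cube (rows, cols, bands), bands in demo_bands order
demo_bands = ['B01', 'B02', 'B03', 'B04']
demo_cube = rng.uniform(0, 4000, size=(4, 6, len(demo_bands)))

# Select the red, green and blue bands, in that order
rgb_idx = [demo_bands.index(b) for b in ('B04', 'B03', 'B02')]
rgb = demo_cube[:, :, rgb_idx]

# Linear stretch between the 2nd and 98th percentiles, clipped to [0, 1]
lo, hi = np.nanpercentile(rgb, (2, 98))
rgb_stretched = np.clip((rgb - lo) / (hi - lo), 0, 1)
print(rgb_stretched.shape, rgb_stretched.min(), rgb_stretched.max())
```

The resulting array lies in [0, 1] and can be passed directly to `plt.imshow`.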