# IV. Zonal statistics, spectral signatures and vegetation index

---
**Author(s):** Quentin Yeche, Kenji Ose, Dino Ienco - [UMR TETIS](https://umr-tetis.fr) / [INRAE](https://www.inrae.fr/)

---

## 1. Introduction

This notebook presents two solutions for computing zonal statistics on a Sentinel-2 image from a polygon vector file. We will try two libraries: `xrspatial` and `rasterstats`.

## 2. Import libraries

As usual, we import all the required Python libraries.

In [None]:
# STAC access
import pystac_client
import planetary_computer

# (geo)dataframes
import pandas as pd
import geopandas as gpd

# xarrays
import rioxarray
import xarray as xr

from rasterio import features

# library for turning STAC objects into xarrays
import stackstac

# visualization
from matplotlib import pyplot as plt

# miscellaneous
import numpy as np
from IPython.display import display

## 3. Creating a `DataArray` from a STAC object

### 3.1. Getting a Sentinel-2 STAC Item 

As a practical use case let's consider that we have identified the STAC Item we're interested in (see [this notebook](Joensuu_01-STAC.ipynb) for a refresher), and we also have an area of interest defined as a bounding box.

In [None]:
# Access to Planetary Computer API
root_catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)

item_id = 'S2A_MSIL2A_20201213T104441_R008_T31TEJ_20201214T083443'
item = root_catalog.get_collection("sentinel-2-l2a").get_item(item_id)

### 3.2. Loading Sentinel-2 image

We stack the item's assets, which correspond to spectral bands, into an `xarray.DataArray`, resampled to a spatial resolution of 10 m.

We also collect information that will be useful for the following processing steps.

In [None]:
# bounding box expressed in Lat/Lon
aoi_bounds = (3.875107329166124, 43.48641456618909, 4.118824575734205, 43.71739887308995)

# bands of interest
boi = ['B02','B03','B04','B05','B06','B07','B08','B11','B12']

FILL_VALUE = 2**16-1

ds = stackstac.stack(
                item,
                assets = boi,
                resolution=10,
                dtype="uint16",
                fill_value=FILL_VALUE,
                bounds_latlon=aoi_bounds,
                    )

s2_ref_crs = ds.crs
s2_ref_trf = ds.transform
# rasterio expects array shapes as (rows, columns), i.e. (y, x)
s2_ref_shape = (ds['y'].size, ds['x'].size)

print(f'- S2 CRS: {s2_ref_crs}')
print(f'- S2 affine transform: \n{s2_ref_trf}')
print(f'- S2 (rows, cols) dimension: {s2_ref_shape}')
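
The relationship between array shape and affine transform can be illustrated with a small sketch. Raster arrays are indexed as (rows, columns), that is (y, x), while a north-up transform maps a (column, row) pair to map coordinates (x, y). The grid size, origin and the helper `pixel_to_world` below are made-up values and names for illustration only; they are not taken from the Sentinel-2 item.

```python
import numpy as np

# Hypothetical grid: 3 rows (y) by 5 columns (x), with 10 m pixels.
n_rows, n_cols = 3, 5
pixel_size = 10.0
x_origin, y_origin = 500000.0, 4800000.0  # upper-left corner (assumed values)

# Array shape is (rows, cols), not (x, y).
grid = np.zeros((n_rows, n_cols), dtype="uint16")

def pixel_to_world(row, col):
    """Return the map coordinates of the upper-left corner of a pixel."""
    x = x_origin + col * pixel_size
    y = y_origin - row * pixel_size  # y decreases as the row index grows
    return x, y

print(grid.shape)            # (3, 5)
print(pixel_to_world(2, 4))  # (500040.0, 4799980.0)
```

Any label raster meant to be overlaid on the image has to share this (rows, columns) shape and the same transform.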

## 4. Loading the polygon vector file

### 4.1. Conversion into data array

In order to compute zonal statistics, we first have to convert the vector file into a labeled raster. Labels must be of integer type.

The vector file, named `sample.geojson`, has an attribute table with the following information:
- **fid**: unique ID [integer]
- **geometry**: coordinates of entity's polygon [list]
- **landcover**: label [string]

In [None]:
field = gpd.read_file('sample.geojson')
field_to_raster_crs = field.to_crs(s2_ref_crs)
geom = field_to_raster_crs[['geometry', 'fid']].values.tolist()

field_cropped_raster = features.rasterize(geom, out_shape=s2_ref_shape, fill=0, transform=s2_ref_trf)
field_cropped_raster_xarr = xr.DataArray(field_cropped_raster)

### 4.2. Displaying the labeled image

In [None]:
import matplotlib.pyplot as plt

plt.imshow(field_cropped_raster_xarr)

## 5. Computing zonal statistics

### 5.1. Solution 1: with `xrspatial`

Here, we compute statistics for each Sentinel-2 band and merge the results into the attribute table of the vector file.

#### 5.1.1. Creating a dedicated function

First, we create a function, named `s2_zonal`, that calls `xrspatial.zonal_stats`.

In [None]:
from xrspatial import zonal_stats as xrspatial_zs
import time

def s2_zonal(ds, band, field_cropped_raster_xarr):
    s2_band = ds.sel(band=band).squeeze('time').values
    s2_band_xarr = xr.DataArray(s2_band)
    sign_spectral = xrspatial_zs(field_cropped_raster_xarr, 
                                 s2_band_xarr, 
                                 stats_funcs = ['count','min','mean','max'], nodata_values = 0)
    
    out_names = dict()
    for i in sign_spectral.columns:
        out_names[i] = '{}_{}'.format(band, i)
    sign_spectral.rename(columns = out_names, inplace = True)

    return sign_spectral
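
To make explicit what a zonal statistic is, here is a minimal NumPy sketch. It is not how `xrspatial` is implemented; it only reproduces the idea on toy arrays: for each label in a zone raster, the pixels of the value raster that fall in that zone are summarized. The arrays and the helper `zonal_summary` are hypothetical, and label `0` is assumed to be background.

```python
import numpy as np

# Toy label raster (0 = background) and toy band values, same shape.
zones = np.array([[0, 1, 1],
                  [2, 2, 1],
                  [2, 0, 0]])
values = np.array([[5.0, 10.0, 20.0],
                   [1.0, 3.0, 30.0],
                   [5.0, 7.0, 9.0]])

def zonal_summary(zones, values, background=0):
    """Compute count, min, mean and max of `values` for each zone label."""
    summary = {}
    for zone_id in np.unique(zones):
        if zone_id == background:
            continue  # skip pixels that belong to no polygon
        pixels = values[zones == zone_id]
        summary[int(zone_id)] = {
            "count": int(pixels.size),
            "min": float(pixels.min()),
            "mean": float(pixels.mean()),
            "max": float(pixels.max()),
        }
    return summary

summary = zonal_summary(zones, values)
print(summary[1])  # {'count': 3, 'min': 10.0, 'mean': 20.0, 'max': 30.0}
```

The boolean mask `zones == zone_id` only works when both rasters have exactly the same shape, which is why the label raster must be built on the image grid.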

#### 5.1.2. Calculating summary statistics of each vector entity

In [None]:
final = field.copy()

start_time = time.time()
#for band in boi:
#    df = s2_zonal(ds, band, field_cropped_raster_xarr)
#    final = final.merge(df, left_on='fid', right_on='{}_zone'.format(band))
end_time = time.time()

duration = (end_time - start_time)
minutes, seconds =  divmod(duration, 60)
print(f'duration: {int(minutes)} min {seconds:.2f} sec')

#### 5.1.3. Displaying head of output table

In [None]:
cols = [c for c in final.columns if c.lower()[3:] != '_zone']
finalb = final[cols]
finalb.head()

### 5.2. Solution 2: with `rasterstats`

#### 5.2.1. Installation of `rasterstats` library

Here, we will use another package, named `rasterstats`. As it is not installed by default, we add it to our environment.

In [None]:
!pip install rasterstats

#### 5.2.2. Creating a dedicated function

First, we create a function, named `s2_zonal2`, that calls `rasterstats.zonal_stats`.

In [None]:
from rasterstats import zonal_stats as rasterstats_zs

def s2_zonal2(stac_item, band, geodf, geodf_id):
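    # use the ID column of the GeoDataFrame passed as argument, not a global variable
    fid = geodf[geodf_id]
    # geodf must be in the same CRS as the raster asset
    zs = rasterstats_zs(geodf, stac_item.assets[band].href, stats="count min mean max median")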
    sign_spectral = pd.DataFrame(zs)
    sign_spectral = pd.concat([fid, sign_spectral], axis=1)
        
    out_names = dict()
    for i in sign_spectral.columns:
        out_names[i] = '{}_{}'.format(band, i)
    sign_spectral.rename(columns = out_names, inplace = True)

    return sign_spectral

#### 5.2.3. Calculating summary statistics of each vector entity

In [None]:
final2 = field.copy()

start_time = time.time()
for band in boi:
    df = s2_zonal2(item, band, field_to_raster_crs, "fid")
    final2 = final2.merge(df, left_on='fid', right_on='{}_fid'.format(band))
end_time = time.time()

duration = (end_time - start_time)
minutes, seconds =  divmod(duration, 60)
print(f'duration: {int(minutes)} min {seconds:.2f} sec')

#### 5.2.4. Displaying head of output table

In [None]:
cols = [c for c in final2.columns if c.lower()[3:] != '_fid']
final2b = final2[cols]
final2b.head()
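
The merge logic used in this section can be tried without any raster access. The sketch below replaces the zonal statistics with hand-written numbers (the helper `fake_band_stats` and all values are hypothetical) and shows how band-prefixed columns are joined on the ID and how the redundant `*_fid` columns are dropped afterwards.

```python
import pandas as pd

# Stand-in for the attribute table of the vector file.
fields = pd.DataFrame({"fid": [1, 2], "landcover": ["forest01", "urban01"]})

def fake_band_stats(band, means):
    """Return a per-polygon table with band-prefixed column names."""
    stats = pd.DataFrame({"fid": [1, 2], "mean": means})
    return stats.rename(columns={c: f"{band}_{c}" for c in stats.columns})

merged = fields.copy()
for band, means in {"B04": [300.0, 1500.0], "B08": [3000.0, 2000.0]}.items():
    merged = merged.merge(
        fake_band_stats(band, means), left_on="fid", right_on=f"{band}_fid"
    )

# Keep the original 'fid' but drop the per-band copies of it.
merged = merged[[c for c in merged.columns if not c.endswith("_fid")]]
print(merged.columns.tolist())
# ['fid', 'landcover', 'B04_mean', 'B08_mean']
```

Prefixing every column with the band name avoids name collisions when the same statistics are merged for several bands.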

## 6. Spectral signatures

Now that we have the summary statistics for several landcover types, we can plot their spectral signatures.

In [None]:
df_spectral = final2b.set_index('landcover').T
df_spectral = df_spectral.reset_index()

df_spectral['band'] = df_spectral['index'].str[:3]
df_spectral['stat'] = df_spectral['index'].str[4:]

df_spectral2 = df_spectral.set_index('band')
df_spectral2 = df_spectral2.drop(['fid', 'geo'])
df_spectral2 = df_spectral2.drop(['index'], axis=1)

test = pd.concat([df_spectral2.urban01, df_spectral2.urban02, df_spectral2.urban03])
print(test.shape)

#test2['radio'] = ['blue', 'green', 'red', 'rededge1', 'rededge2', 'rededge3', 'nir', 'swir1', 'swir2' ]

s2_mean = df_spectral2[df_spectral2['stat']=='mean'].drop(['stat'], axis=1)
s2_min = df_spectral2[df_spectral2['stat']=='min'].drop(['stat'], axis=1)
s2_max = df_spectral2[df_spectral2['stat']=='max'].drop(['stat'], axis=1)

ax = s2_mean['forest03'].plot()
s2_min['forest03'].plot(ax=ax, c='gray')
s2_max['forest03'].plot(ax=ax, c='gray')

ax = s2_mean.plot()
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
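
As an alternative way to look at the reshaping step, the sketch below builds the band-by-landcover table with `melt` and `pivot` rather than the transpose used above. It is only an illustration on made-up values; the column naming `<band>_<stat>` is assumed to match the tables produced earlier.

```python
import pandas as pd

# Toy wide table: one row per polygon, one column per band/statistic pair.
wide = pd.DataFrame({
    "landcover": ["forest01", "urban01"],
    "B04_mean": [300.0, 1500.0],
    "B08_mean": [3000.0, 2000.0],
    "B04_max": [450.0, 2100.0],
    "B08_max": [3600.0, 2600.0],
})

# Long format: one row per (landcover, band, stat) combination.
long = wide.melt(id_vars="landcover", var_name="band_stat", value_name="value")
long[["band", "stat"]] = long["band_stat"].str.split("_", expand=True)

# One row per band, one column per landcover: ready for signatures.plot().
signatures = long[long["stat"] == "mean"].pivot(
    index="band", columns="landcover", values="value"
)
print(signatures)
```

Each column of `signatures` is then one spectral signature, with bands along the index.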



## 7. Band math and vegetation index

### 7.1. NDVI principles

The NDVI (Normalized Difference Vegetation Index) is a vegetation index based on the difference between red and near infrared (nIR) values. Its formula is as follows:

$$NDVI = {nIR - Red \over nIR + Red}$$

This index exploits the distinctive spectral signature of vegetation, which shows a marked peak in the near infrared and a much lower reflectance in the red. It is very effective for detecting the presence of vegetation, and it can also be used to estimate vegetation biomass and the intensity of photosynthetic activity.

In [None]:
nir, red = ds.sel(band="B08").astype('float'), ds.sel(band="B04").astype('float')
ndvi = (nir-red)/(nir+red)
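
The formula can be checked by hand on a few pixels. The sketch below uses made-up reflectance values and a hypothetical helper `ndvi_safe`. It also shows why the bands are cast to float first: with `uint16` data, `nir - red` would wrap around whenever red exceeds nIR. Pixels where both bands are zero are set to NaN here, which is one possible convention and not something the notebook itself prescribes.

```python
import numpy as np

# Toy red and near-infrared bands stored as unsigned 16-bit integers.
red = np.array([[300, 1500], [0, 800]], dtype="uint16")
nir = np.array([[3000, 2000], [0, 800]], dtype="uint16")

def ndvi_safe(nir, red):
    """Compute NDVI in float64, returning NaN where nir + red equals zero."""
    nir = nir.astype("float64")
    red = red.astype("float64")
    total = nir + red
    out = np.full(total.shape, np.nan)
    # Divide only where the denominator is non-zero; other cells stay NaN.
    np.divide(nir - red, total, out=out, where=total != 0)
    return out

ndvi_demo = ndvi_safe(nir, red)
print(np.round(ndvi_demo, 3))
```

The first pixel gives (3000 - 300) / (3000 + 300), about 0.818, a typical value for dense vegetation, while equal red and nIR values give 0.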

### 7.2. Plotting NDVI

In [None]:
plt.imshow(ndvi.squeeze(), cmap="RdYlGn")
plt.colorbar()
plt.title('NDVI')
plt.show()

### 7.3. Using Jupyter widgets

It is possible to add elements (*slider*, *progress bar*, *checkbox*, *radio buttons*, etc.) to interact with the data visualization. To do this, load the `ipywidgets` library.

In [None]:
import ipywidgets as widgets
from ipywidgets import interact, interact_manual

def threshold(seuil):
    # binary mask: 1 where NDVI exceeds the threshold (seuil), 0 elsewhere
    mask = np.where(ndvi.squeeze() > seuil, 1, 0)
    plt.imshow(mask, cmap="viridis", interpolation='nearest')
    plt.colorbar()
    plt.show()

interact(threshold, seuil = widgets.FloatSlider(min=-1, max=1, step=0.001))
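
The thresholding done inside the widget callback can also be tested on its own, without any display. This is a minimal sketch on made-up NDVI values; the helper name `threshold_mask` is hypothetical.

```python
import numpy as np

# Toy NDVI values, including a NaN for a pixel without valid data.
ndvi_toy = np.array([[0.82, 0.14], [np.nan, -0.20]])

def threshold_mask(ndvi, threshold):
    """Return 1 where NDVI is strictly above the threshold, 0 elsewhere."""
    # Comparisons with NaN are False, so NaN pixels end up in class 0.
    return np.where(ndvi > threshold, 1, 0)

mask = threshold_mask(ndvi_toy, 0.3)
print(mask.tolist())  # [[1, 0], [0, 0]]
```

Keeping the computation in a small function separate from the plotting makes it easier to check the effect of a given threshold before wiring it to a slider.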