# MODIS - DTW (Dynamic Time Warping) Clustering

## Install packages

In [None]:
# Install the third-party packages this notebook relies on (run once per fresh runtime).
!pip install zkyhaxpy rasterio utm geopandas
!pip install tslearn
!pip install gcsfs
!pip install ipython-autotime
# Load the autotime IPython extension (presumably shipped by the ipython-autotime package installed above).
%load_ext autotime

## Import libraries

In [None]:
## for all ##
from zkyhaxpy import io_tools, pd_tools, np_tools, console_tools, timer_tools, json_tools, dict_tools, colab_tools, gis_tools
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import os
import matplotlib.pyplot as plt
import seaborn as sns
import logging
import re
import rasterio
import geopandas as gpd
import statsmodels.api as sm

from tslearn.clustering import TimeSeriesKMeans

In [None]:
# Colab session setup through the zkyhaxpy helpers.
# NOTE(review): judging by their names these mount Google Drive and authenticate to GCP;
# confirm against the zkyhaxpy documentation.
colab_tools.mount_drive()
colab_tools.authen_gcp()

## Define paths

In [None]:
# Drive folder that receives the per-cluster figures saved by the plotting cell.
folder_dtw_viz = '/content/drive/MyDrive/!UNBDH2022-Multiverse-Of-Data/unbdh2022_multiverse_of_data/viz'
# Drive folder that receives the exported cluster-share table.
folder_dtw_result = '/content/drive/MyDrive/!UNBDH2022-Multiverse-Of-Data/unbdh2022_multiverse_of_data/result'
# Make sure both folders exist before anything is written to them.
io_tools.create_folders(folder_dtw_viz, folder_dtw_result)

In [None]:
# Local working folder for the MODIS NDVI pixel-value files.
folder_modis_ndvi_pixval = '/temp/modis/ndvi_pixval'
io_tools.create_folders(folder_modis_ndvi_pixval)
# Copy the bucket folder into /temp/modis (-m: parallel, -r: recursive, -n: do not overwrite existing files).
!gsutil -m cp -r -n gs://unbdh2022-multiverseofdata-dev/modis/ndvi_pixval /temp/modis


In [None]:
# Draw a fixed-size, seeded random sample of 5000 rows from every file found
# under folder_modis_ndvi_pixval.
list_files = io_tools.get_list_files_re(folder_modis_ndvi_pixval)
list_df_pixval = [
    pd.read_parquet(path_file).sample(5000, random_state=0)
    for path_file in tqdm(list_files)
]

# Stack the per-file samples and keep only rows without missing values.
df_pixval = pd.concat(list_df_pixval).dropna().copy()

In [None]:
# (rows, columns) of the stacked sample after rows with missing values were dropped.
df_pixval.shape

## Find the best n_clusters using the elbow method

In [None]:

# df_pixval_tmp = df_pixval.sample(1000, random_state=0).copy()
# X = df_pixval_tmp.values
# dict_km = {}
# dict_inertia = {}
# list_inertia = []
# for n_clusters in range(2, 31):
#     print(f'Fitting for {n_clusters}')
#     km = TimeSeriesKMeans(n_clusters=n_clusters, metric="dtw", max_iter=5,
#                           max_iter_barycenter=5, n_jobs=os.cpu_count() - 1,
#                           random_state=0, verbose=0, dtw_inertia=True).fit(X)
                          
#     dict_km[n_clusters] = km
#     list_inertia.append({'n_clusters':n_clusters, 'inertia':km.inertia_})

# df_inertia = pd.DataFrame(list_inertia)    
# df_inertia = df_inertia.set_index('n_clusters')

# df_inertia.plot()
# plt.ylim(0, 0.15)

## Fit DTW

In [None]:
# Local folder for fitted models.
folder_model = '/temp/models'
io_tools.create_folders(folder_model)
# Pull models already stored in the bucket into /temp/models (-r: recursive, -n: do not overwrite existing files).
!gsutil cp -r -n gs://unbdh2022-multiverseofdata-dev/models /temp


In [None]:

best_n_clusters = 8
model_id = f'dtw_all_years_k{best_n_clusters}'
path_model = f'/temp/models/{model_id}.pkl'
if os.path.exists(path_model):
    km, params = io_tools.read_pickle(path_model)
else:

    print(f'Fitting for {model_id}')
    X = df_pixval.values
    km = TimeSeriesKMeans(n_clusters=best_n_clusters, metric="dtw", max_iter=5,
                            max_iter_barycenter=5, n_jobs=os.cpu_count() - 1,
                            random_state=0, verbose=0, dtw_inertia=False).fit(X)

    io_tools.write_pickle((km, km.get_params()), path_model)
       
!gsutil cp -r -n /temp/models gs://unbdh2022-multiverseofdata-dev


## Get cluster_id

In [None]:
from google.cloud import storage

def get_bucket_and_file_name_from_uri(file_uri):
    """Split a ``gs://bucket/path/to/object`` URI into bucket and object name.

    Args:
        file_uri: URI string that must start with ``gs://``.

    Returns:
        Tuple ``(bucket_name, file_name)``. ``file_name`` is everything after the
        first ``/`` that follows the bucket, or ``''`` when the URI names only a bucket.

    Raises:
        AssertionError: if ``file_uri`` does not start with ``gs://``.
    """
    scheme = 'gs://'
    assert file_uri.startswith(scheme), f'Expected a gs:// URI, got: {file_uri!r}'
    # Strip only the leading scheme. A blanket str.replace would also remove any
    # later "gs://" that happens to be part of the object name.
    bucket_name, _, file_name = file_uri[len(scheme):].partition('/')
    return (bucket_name, file_name)

def check_file_exists_gcs(file_uri):
    """Report whether the object addressed by a ``gs://`` URI exists.

    Args:
        file_uri: URI string of the form ``gs://bucket/path/to/object``.

    Returns:
        The result of ``Blob.exists`` for that object, queried with a newly
        created ``storage.Client``.
    """
    bucket_name, blob_name = get_bucket_and_file_name_from_uri(file_uri)
    client = storage.Client()
    blob = storage.Blob(name=blob_name, bucket=client.bucket(bucket_name))
    return blob.exists(client)

In [None]:
# NOTE(review): within this notebook df_pixval_cluster_all_years is first assigned in the
# next cell, so on a fresh kernel (Restart & Run All) this preview raises NameError.
# Run it after the cluster-assignment cell, or move it below that cell.
df_pixval_cluster_all_years.head()

In [None]:
# Label every yearly NDVI file with its DTW cluster and summarise the share of each
# cluster per year. Labelled frames are cached on GCS so reruns skip the prediction.
list_files.sort()

# Rows to sample per yearly file (an int), or 'full' to label every row.
# sample = 100000
sample = 'full'
list_df_pixval_cluster = []
list_df_pct_cluster = []

for path_file in tqdm(list_files):
    file_nm = os.path.basename(path_file)
    file_nm = file_nm.replace('pixval', 'pixval_cluster')

    # The first four-digit run in the file name is taken as the year.
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    year = int(re.findall(r'\d{4}', file_nm)[0])
    if year==2022:
        print('Skip year 2022 because there is only 9 months of NDVI.')
        continue

    if sample != 'full':
        path_out = f'gs://unbdh2022-multiverseofdata-dev/cluster/sample-{sample}/{file_nm}'
    else:
        path_out = f'gs://unbdh2022-multiverseofdata-dev/cluster/full/{file_nm}'

    if check_file_exists_gcs(path_out):
        # Cache hit: reuse the labelled frame and skip reading the source file.
        df_pixval_tmp = pd.read_parquet(path_out)
    else:
        df_pixval_tmp = pd.read_parquet(path_file)
        if sample != 'full':
            df_pixval_tmp = df_pixval_tmp.sample(sample, random_state=0)
        df_pixval_tmp['cluster_id'] = km.predict(df_pixval_tmp.values)
        df_pixval_tmp.to_parquet(path_out)

    # 'year' is added only after predicting/caching so it never reaches km.predict
    # and is not stored in the per-year cache files.
    df_pixval_tmp['year'] = year
    list_df_pixval_cluster.append(df_pixval_tmp)

    # One row per year: fraction of rows that fall into each cluster.
    df_pct_cluster = pd.DataFrame((df_pixval_tmp.cluster_id.value_counts().sort_index() / len(df_pixval_tmp))).T
    df_pct_cluster['year'] = year
    df_pct_cluster['sample'] = sample
    list_df_pct_cluster.append(df_pct_cluster)


df_pixval_cluster_all_years = pd.concat(list_df_pixval_cluster)
df_pixval_cluster_all_years.to_parquet(f'gs://unbdh2022-multiverseofdata-dev/cluster/df_pixval_cluster_all_years_sample-{sample}.parquet')


df_pct_cluster_all = pd.concat(list_df_pct_cluster)
# Give the integer cluster labels string names before writing to parquet. Accept NumPy
# integers as well as Python ints, since the label type can depend on the pandas version.
df_pct_cluster_all = df_pct_cluster_all.rename(
    columns={
        cluster_id: f'cluster_{cluster_id}'
        for cluster_id in df_pct_cluster_all.columns
        if isinstance(cluster_id, (int, np.integer))
    }
)
df_pct_cluster_all = df_pct_cluster_all.set_index('year')
df_pct_cluster_all.to_parquet(f'gs://unbdh2022-multiverseofdata-dev/cluster/df_pct_cluster_all_years_sample-{sample}.parquet')

# to_excel picks its writer from the file extension, so the workbook needs an Excel
# extension; with '.parquet' pandas raises "No engine for filetype".
df_pct_cluster_all.to_excel(os.path.join(folder_dtw_result, f'df_pct_cluster_all_years_sample-{sample}.xlsx'))

In [None]:
# For each cluster, draw up to 100 member curves with the fitted cluster centre on top,
# then save and show the figure.
n_cluster = len(df_pixval_cluster_all_years['cluster_id'].unique())
n_rows_total = len(df_pixval_cluster_all_years)

# Group by a scalar key. Grouping by a one-element list makes pandas >= 2 yield 1-tuples
# such as (0,), which would leak into the titles and the saved file names.
for cluster_id, df_tmp in df_pixval_cluster_all_years.groupby('cluster_id'):
    pct_cluster = len(df_tmp) / n_rows_total
    if len(df_tmp) > 100:
        df_tmp = df_tmp.sample(100, random_state=0).copy()

    fig, ax = plt.subplots()
    for _, s_pixval in df_tmp.iterrows():
        # Only the first 12 columns are plotted.
        # NOTE(review): presumably the 12 monthly NDVI values - confirm the column order.
        s_pixval.iloc[:12].plot(ax=ax, alpha=0.25)
    ax.plot(km.cluster_centers_[cluster_id], alpha=1, color='red')
    ax.set_title(f'cluster id: {cluster_id} ({pct_cluster * 100:0.1f}%)')
    ax.set_ylim(0, 1)
    fig.savefig(os.path.join(folder_dtw_viz, f'ndvi-curvature-cluster_{cluster_id}.jpg'))
    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)
