In [None]:
import logging
import os
import sys
import time
from pathlib import Path

import dask.array as da
import matplotlib.pyplot as plt
import numpy as np
from osgeo import gdal

import cgc
from cgc.coclustering import Coclustering
from cgc.kmeans import Kmeans

In [None]:
# Show which clustering-geodata-cubes (cgc) version this notebook ran with
cgc_version = cgc.__version__
cgc_version

In [None]:
# Route log records of level INFO and above to stdout so they appear in the cell output
LOG_FORMAT = '%(asctime)s | %(levelname)s : %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=LOG_FORMAT,
)

In [None]:
# Manual input -- every tunable setting of this notebook lives here.
# The default directories are machine-specific absolute paths; set the
# CGC_DIR_TIFF / CGC_DIR_OUTPUT environment variables to run elsewhere
# without editing the notebook.
dir_tiff = Path(os.environ.get(
    'CGC_DIR_TIFF', '/mnt/c/Users/OuKu/Developments/phenology/data/Bloom/Europe/'))
dir_output = Path(os.environ.get(
    'CGC_DIR_OUTPUT', '/mnt/c/Users/OuKu/Developments/phenology/test_cocluster/'))
load_pattern = '195[0-2].tif'  # glob pattern selecting the GeoTIFFs to load
# load_pattern = '*.tif'
band_id = 3  # 0-based index into the array returned by ReadAsArray, i.e. the 4th band
k = 10  # num clusters in rows (pixels)
l = 2  # num clusters in columns (loaded files)
# Co-clustering settings, passed positionally as
# Coclustering(Z, k, l, errobj, niters, nruns, epsilon)
errobj = 0.00001
niters = 1
nruns = 1
epsilon = 10e-8  # NOTE(review): 10e-8 == 1e-7; confirm 1e-8 was not intended
# Kmeans settings
kmean_n_clusters = 3
kmean_max_iter = 500
kmean_max_iter = 500

In [None]:
# Load the first GeoTIFF matching `load_pattern` and build the mask of
# non-NaN pixels. `mask` holds flat pixel indices and is applied to every
# image loaded in the next cell.
# NOTE(review): the mask comes from this one image only; pixels that are NaN
# in other images will still end up as NaN in Z.
tiff_files = sorted(dir_tiff.glob(load_pattern))
if not tiff_files:
    raise FileNotFoundError(f'No files matching {load_pattern!r} in {dir_tiff}')
# Use a file that matches the pattern (not an arbitrary directory entry)
# so the mask is built from the same kind of image as the data
h_tif = gdal.Open(tiff_files[0].as_posix())
if h_tif is None:
    raise RuntimeError(f'GDAL could not open {tiff_files[0]}')
img = h_tif.ReadAsArray(0, 0, h_tif.RasterXSize, h_tif.RasterYSize)[band_id]
img = img.reshape(-1, 1)
mask = np.flatnonzero(~np.isnan(img))


In [None]:
# Load the geotiffs: the selected band of every file matching `load_pattern`
# becomes one column of Z (rows: pixels selected by `mask`, columns: files).
# Path.glob yields files in arbitrary order, so sort them to make the column
# order deterministic (chronological for year-named files such as 1950.tif).
columns = []
raster_shapes = set()
for f_tiff in sorted(dir_tiff.glob(load_pattern)):
    h_tif = gdal.Open(f_tiff.as_posix())
    raster_shapes.add((h_tif.RasterYSize, h_tif.RasterXSize))
    img = h_tif.ReadAsArray(0, 0, h_tif.RasterXSize, h_tif.RasterYSize)[band_id]
    img = img.reshape(-1, 1)
    columns.append(img[mask])
if not columns:
    raise FileNotFoundError(f'No files matching {load_pattern!r} in {dir_tiff}')
# `mask` holds flat pixel indices, so all rasters must share one grid size
assert len(raster_shapes) == 1, f'GeoTIFFs differ in raster size: {sorted(raster_shapes)}'
# Stack once instead of np.append in the loop (which copied Z every iteration)
Z = np.hstack(columns).astype('float64')
# `h_tif` remains bound to the last opened dataset; the plotting cell reads its size.

In [None]:
# Co-cluster Z into k row clusters and l column clusters
# (arguments are positional, in the order Coclustering expects)
cc = Coclustering(
    Z, k, l,
    errobj, niters, nruns, epsilon,
)
# Run on a single thread
cc.run_with_threads(nthreads=1)

In [None]:
# Kmean: run cgc's Kmeans on Z using the row/column cluster assignments from `cc`
kmeans_kwargs = dict(
    Z=Z,
    row_clusters=cc.row_clusters,
    col_clusters=cc.col_clusters,
    n_row_clusters=k,
    n_col_clusters=l,
    kmean_n_clusters=kmean_n_clusters,
    kmean_max_iter=kmean_max_iter,
)
km = Kmeans(**kmeans_kwargs)
km.compute()
km.cl_mean_centroids  # displayed as the cell output

In [None]:
# Export Plots
# savefig does not create missing directories
dir_output.mkdir(parents=True, exist_ok=True)

# Temporal cluster: cluster label of each column of Z, plotted against the
# column index (presumably one label per loaded file, in load order).
# Each plot gets its own figure; sharing the pyplot "current axes" made the
# spatial plot also contain this line and overwrote its axis labels.
fig_temporal, ax_temporal = plt.subplots()
ax_temporal.plot(range(len(cc.col_clusters)), cc.col_clusters)
ax_temporal.set_ylabel('Cluster')
ax_temporal.set_xlabel('Years')
fig_temporal.savefig((dir_output / 'temporal_clusters.png').as_posix())

# Spatial cluster
# Reconstruct: scatter the per-pixel row-cluster labels back onto the raster
# grid; pixels excluded by `mask` stay NaN. The grid size comes from `h_tif`,
# the last GeoTIFF opened in the loading cell.
R = np.full(h_tif.RasterXSize * h_tif.RasterYSize, np.nan)
R[mask] = cc.row_clusters
R = R.reshape(h_tif.RasterYSize, h_tif.RasterXSize)
fig_spatial, ax_spatial = plt.subplots()
image = ax_spatial.imshow(R)
fig_spatial.colorbar(image, ax=ax_spatial, label='Cluster')
ax_spatial.set_ylabel('Yaxis')
ax_spatial.set_xlabel('Xaxis')
fig_spatial.savefig((dir_output / 'spatial_clusters.png').as_posix())

# #Plot Spatial Grid
# fig, ax = plt.subplots(int(np.floor(C.shape[1]/2)),
#                         int(C.shape[1]-np.floor(C.shape[1]/2)))
# empty_string_labels = ['']
# min_val = np.min(ircc)
# max_val = np.max(ircc)
# colorbar_th = np.round(np.linspace(np.ceil(min_val),
#                                     np.floor(max_val), 10))
# for c, a in zip(range(0, C.shape[1]), ax.flatten()):
#     temp_cl = np.where(C[:, c])[0]
#     pr = np.unique(ircc[:, temp_cl], axis=1)
#     spatial_group = np.reshape(pr, (block_ysize, block_xsize))
#     fig1 = a.imshow(spatial_group, interpolation="None", vmin=min_val,
#                     vmax=max_val)
#     a.set_title('Temp. cl ' + str(c+1), fontsize=16)
#     a.grid(True)
#     a.set_xticklabels(empty_string_labels)
#     a.set_yticklabels(empty_string_labels)
# cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
# cb = fig.colorbar(fig1, cax=cbar_ax, ticks=colorbar_th)
# cb.set_label('DOY', fontsize=16)
# cb.ax.tick_params(labelsize=18)
# plt.savefig(fig_outputdir + 'Spatial_info.png', format='png',
#             transparent=True, bbox_inches="tight")
# plt.close(fig)