# 1. Initializations

## 1.1 General imports

In [None]:
### extra
from PIL import Image

### data management
# import pandas as pd
import numpy as np

### clustering
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift, estimate_bandwidth

### graphical matplotlib basics
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# for jupyter notebook management
%matplotlib inline

## 1.2 General dataframe functions

In [None]:
# import smartcheck.dataframe_common as dfc
import smartcheck.paths as pth

## 1.3 General clustering functions

In [None]:
def get_clusters_centroids(X, quantile, n_samples, *, bin_seeding=False, n_jobs=None):
    """Cluster the rows of X with Mean Shift and return (centroids, labels).

    The kernel bandwidth is not chosen by hand: it is estimated from the data
    with ``estimate_bandwidth`` and then passed to ``MeanShift``.

    Parameters
    ----------
    X : array-like of shape (n_points, n_features)
        Samples to cluster (in this notebook, one RGB pixel per row).
    quantile : float
        Quantile (between 0 and 1) forwarded to ``estimate_bandwidth``;
        smaller values give a smaller bandwidth and typically more clusters.
    n_samples : int or None
        Number of samples ``estimate_bandwidth`` uses; ``None`` uses all rows.
    bin_seeding : bool, default False
        Forwarded to ``MeanShift``. ``True`` seeds the kernels from a binned
        grid instead of from every point, which is much faster on large
        inputs such as images. The default keeps the original behaviour.
    n_jobs : int or None, default None
        Parallel jobs for both the bandwidth estimation and the clustering.
        The default keeps the original behaviour.

    Returns
    -------
    centroids : ndarray of shape (n_clusters, n_features)
        Cluster centres found by Mean Shift.
    labels : ndarray of shape (n_points,)
        For each row of X, the index of its cluster in ``centroids``.

    Raises
    ------
    ValueError
        If the estimated bandwidth is not strictly positive (e.g. all sampled
        points are identical or ``quantile`` is too small).
    """
    bandwidth = estimate_bandwidth(X, quantile=quantile, n_samples=n_samples, n_jobs=n_jobs)
    if bandwidth <= 0:
        # MeanShift requires a strictly positive bandwidth; fail early with an
        # actionable message instead of an opaque error from inside the fit.
        raise ValueError(
            f"Estimated bandwidth is {bandwidth!r}; increase `quantile` or "
            "`n_samples` (the sampled points may all be identical)."
        )
    model = MeanShift(bandwidth=bandwidth, bin_seeding=bin_seeding, n_jobs=n_jobs)
    model.fit(X)
    return model.cluster_centers_, model.labels_

# 2. Loading and Data Quality

## 2.1 Loading of data sets and general exploration

In [None]:
# Loading. The relative resource path is built with os.path.join instead of
# hard-coded Windows "\\" separators, so the notebook also runs on Linux/macOS.
# NOTE(review): assumes pth.get_full_path accepts an OS-native relative path — confirm.
import os

bird_img_path = pth.get_full_path(os.path.join("smartcheck", "resources", "bird_small.png"))
print("File Full Path:", bird_img_path)
# presumably an RGBA PNG (H, W, 4); the alpha channel is dropped further down.
bird_img = plt.imread(bird_img_path)

In [None]:
# Display: report the array shape, then render the loaded image.
print("Dimensions de l'image:", bird_img.shape)
plt.imshow(bird_img)
plt.show()  # explicit show; also avoids echoing the AxesImage repr

## 2.2 Data quality refinement

In [None]:
# Reshape the image into a 2-D table of pixels for clustering.
# Keep only the first three channels (drop alpha: RGBA -> RGB) so each pixel
# contributes fewer training features.
bird_img = bird_img[:, :, :3]
n_rows, n_cols, n_channels = bird_img.shape
bird_rs = bird_img.reshape(n_rows * n_cols, n_channels)  # one row per pixel

In [None]:
# Display again after refinement: the last dimension is now at most 3 channels.
print("Dimensions de l'image:", bird_img.shape)
plt.imshow(bird_img)
plt.show()  # explicit show; also avoids echoing the AxesImage repr

# 2. Data Clustering

## 2.1 General Analysis

In [None]:
# Raw view of the data: every pixel is a point in RGB space (R, G, B on the
# three axes), coloured with its own RGB value.
r, g, b = bird_rs.T
fig, ax = plt.subplots(figsize=(10, 10), subplot_kw={'projection': '3d'})
ax.scatter(r, g, b, c=bird_rs, marker='o')
ax.view_init(elev=10, azim=90)
ax.set(xlabel='Red', ylabel='Green', zlabel='Blue', title='RGB image Scatter 3D')
plt.show()

## 2.2 Mean-Shift Clustering (colour quantization)

In [None]:
# Compute the cluster centres and per-pixel labels. /!\ VERY EXPENSIVE
centroids, labels = get_clusters_centroids(bird_rs, quantile=0.1, n_samples=200)

In [None]:
# Apply the compression: replace every pixel with the centroid of its cluster.
print(centroids.shape, centroids)
print(labels.shape, labels)
# Integer-array indexing maps each label to its centroid row in a single
# vectorised step (replacing the per-pixel Python loop). The cast keeps the
# float64 dtype that the previous np.zeros buffer produced.
bird_rs_zip = centroids[labels].astype(np.float64)

In [None]:
# Rebuild the image from the compressed pixels and show it next to the original.
bird_img_zip = bird_rs_zip.reshape(bird_img.shape)
compare_fig, (ax_original, ax_rebuilt) = plt.subplots(1, 2)
ax_original.imshow(bird_img)
ax_original.set_title('Image originale')
ax_rebuilt.imshow(bird_img_zip)
ax_rebuilt.set_title('Image reconstruite')
plt.show()