In [None]:
import matplotlib.pyplot as plt
import os
import random
import cv2
from scipy import ndimage
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import seaborn as sns

DATA_REPO_URL = "https://github.com/srvanderplas/jellybean_data.git"
datadir = "jellybean-data"
if not os.path.exists(datadir):
    # GitPython only wraps the `git` executable, so call git directly instead of
    # pip-installing a package at runtime. The HTTPS URL (unlike the SSH
    # `git@github.com:` form) works without a configured SSH key.
    import subprocess
    subprocess.run(["git", "clone", DATA_REPO_URL, datadir], check=True)

In [None]:
## list the images in here
images = os.listdir(datadir)

In [None]:
# take a random flavor
sample_path = os.path.join(datadir, random.sample(images,1)[0])
sample_path

In [None]:
## list all images in the folder
samp_images = os.listdir(sample_path)

In [None]:
# sample a random image
flavor = random.sample(samp_images,1)[0]
sample_path = os.path.join(sample_path, flavor)
sample_path

In [None]:
sample_path
sample_image = plt.imread(sample_path)

In [None]:
plt.imshow(sample_image)

In [None]:
sample_image.shape

In [None]:
from skimage.color import rgb2gray

In [None]:
gray_mask = rgb2gray(sample_image)

In [None]:
plt.imshow(gray_mask, "gray")
plt.show()

In [None]:
gray_mask = cv2.convertScaleAbs(gray_mask*255)

In [None]:
from skimage import exposure

img_eq = exposure.equalize_hist(gray_mask)

In [None]:
plt.imshow(img_eq, "gray")
plt.show()

In [None]:
from skimage.filters import threshold_otsu

In [None]:
# get the otsu thresholding
img_threshold = threshold_otsu(img_eq)

In [None]:
# threshold the image
binary = img_eq > img_threshold

In [None]:
plt.imshow(binary, "gray")
plt.show()

In [None]:
binary = 1- binary

In [None]:
plt.imshow(binary, "gray")
plt.show()

In [None]:
binary = cv2.convertScaleAbs(binary*255.0)

In [None]:
plt.imshow(binary, "gray")

In [None]:
fill_holes = ndimage.binary_fill_holes(binary)

In [None]:
plt.imshow(fill_holes, "gray")

In [None]:
fill_holes = cv2.convertScaleAbs(fill_holes*255.0)

In [None]:
plt.imshow(fill_holes, "gray")

In [None]:
ret, markers = cv2.connectedComponents(fill_holes)

In [None]:
## detect contours
## and remove the biggest one
contours,hierarchy =  cv2.findContours(fill_holes,cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

In [None]:
largest_contours = sorted(contours, key=cv2.contourArea)[-20:-1]

In [None]:
len(largest_contours)

In [None]:
# temp1 = np.zeros((fill_holes.shape[0], fill_holes.shape[1]))

In [None]:
# cv2.drawContours(temp1, largest_contours, -1, (255,255,255), -1)

In [None]:
# plt.imshow(temp1, "gray")


In [None]:
# img = mask_sample_bounded.reshape(mask_sample_bounded.shape[0]*mask_sample_bounded.shape[1],3)

In [None]:
temp1 = np.zeros((fill_holes.shape[0], fill_holes.shape[1]))
catch_img = []
for cnt in tqdm(largest_contours): 
    area = cv2.contourArea(cnt)
#     temp = markers == cnt
#     temp = cv2.convertScaleAbs(temp*255.0)
    if (area > 10**5) & (area < 10**6):
        print(area)
        x,y,w,h = cv2.boundingRect(cnt)
        mask_sample_bounded = sample_image[y:y+h,x:x+w, :]
        plt.imshow(mask_sample_bounded)
        plt.show()
        cv2.drawContours(temp1, [cnt], -1, (255,255,255), -1)
        img = mask_sample_bounded.reshape(mask_sample_bounded.shape[0]*mask_sample_bounded.shape[1],3)
        mean_rgb = img.mean(0)
        std_rgb = img.std(0)
        mean_by_std = mean_rgb/std_rgb
        all_catch = [mean_rgb, std_rgb, mean_by_std]
        all_catch = [it for item in all_catch for it in item]
        catch_img.append(all_catch)
catch_img_df = pd.DataFrame(catch_img)
catch_img_df.columns = ["r_mean", "g_mean", "b_mean", "r_std", "g_std", "b_std", "r_mean_by_std", 
                       "g_mean_by_std", "b_mean_by_std"]
catch_img_df["flavor"] = flavor

In [None]:
plt.imshow(temp1, "gray")
plt.show()

In [None]:
def get_features_image(sample_path, n_beans=19, min_area=10**5, max_area=10**6):
    """Segment one jellybean photo and return colour statistics per region.

    Pipeline: grayscale -> 8-bit -> histogram equalisation -> Otsu threshold
    (pixels at or below the threshold become foreground) -> hole filling ->
    contour detection. The single largest contour is dropped and the next
    ``n_beans`` largest are examined. Each one whose area lies strictly
    between ``min_area`` and ``max_area`` contributes one feature row.

    Parameters
    ----------
    sample_path : str
        Path to the image file. Its parent folder name is used as the flavor.
    n_beans : int, default 19
        Number of contours examined after dropping the largest one. Images
        with fewer than ``n_beans + 1`` contours are skipped.
    min_area, max_area : number, default 10**5 and 10**6
        Exclusive bounds on ``cv2.contourArea`` for a contour to be kept.

    Returns
    -------
    pandas.DataFrame
        One row per kept contour, with columns ``r_mean, g_mean, b_mean,
        r_std, g_std, b_std, r_mean_by_std, g_mean_by_std, b_mean_by_std,
        flavor``. The frame is empty, but has the same columns, when the image
        is skipped or no contour passes the area filter.
    """
    feature_columns = [
        "r_mean", "g_mean", "b_mean",
        "r_std", "g_std", "b_std",
        "r_mean_by_std", "g_mean_by_std", "b_mean_by_std",
    ]
    # The parent folder name is the flavor. Using os.path (instead of
    # split("/")) stays correct for paths built with os.path.join on Windows.
    flavor = os.path.basename(os.path.dirname(sample_path))

    sample_image = plt.imread(sample_path)
    gray_mask = cv2.convertScaleAbs(rgb2gray(sample_image) * 255)
    img_eq = exposure.equalize_hist(gray_mask)
    # Foreground = pixels at or below the Otsu threshold.
    foreground = img_eq <= threshold_otsu(img_eq)
    binary = cv2.convertScaleAbs(foreground * 255.0)
    fill_holes = cv2.convertScaleAbs(ndimage.binary_fill_holes(binary) * 255.0)

    contours, _ = cv2.findContours(fill_holes, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # Need the largest contour plus n_beans more; otherwise skip the image.
    if len(contours) < n_beans + 1:
        return pd.DataFrame(columns=feature_columns + ["flavor"])
    # Ascending by area: drop the very largest, keep the n_beans just below it.
    largest_contours = sorted(contours, key=cv2.contourArea)[-(n_beans + 1):-1]

    rows = []
    for cnt in largest_contours:
        area = cv2.contourArea(cnt)
        if not (min_area < area < max_area):
            continue
        x, y, w, h = cv2.boundingRect(cnt)
        # NOTE(review): statistics cover the whole bounding box, so pixels
        # around the contour are included too.
        box = sample_image[y:y + h, x:x + w, :]
        pixels = box.reshape(box.shape[0] * box.shape[1], 3)
        mean_rgb = pixels.mean(0)
        std_rgb = pixels.std(0)
        # A channel with zero spread yields inf for its mean/std ratio.
        rows.append([*mean_rgb, *std_rgb, *(mean_rgb / std_rgb)])

    # Giving the columns to the constructor also works when `rows` is empty;
    # assigning .columns on an empty frame raised "Length mismatch".
    catch_img_df = pd.DataFrame(rows, columns=feature_columns)
    catch_img_df["flavor"] = flavor
    return catch_img_df

In [None]:
catch_img_df = get_features_image(sample_path)

In [None]:
catch_img_df.head()

In [None]:
out_path =  "./"

In [None]:
# out_path

In [None]:
flavors = os.listdir(datadir)
flavors.remove(".git")
flavors.remove("LICENSE")
flavors

In [None]:
import random

In [None]:
# flavors = random.sample(flavors, 20)

In [None]:
sample_paths = []

for flv in flavors:
    sample_path1 = os.path.join(datadir, flv)
    # print(sample_path1)
    samp_images = os.listdir(sample_path1)
    for imgs in samp_images:
        sample_path = os.path.join(sample_path1, imgs)
        sample_paths.append(sample_path)

In [None]:
len(sample_paths)

In [None]:
from joblib import Parallel, delayed

In [None]:
catch_all_dfs = Parallel(n_jobs=7, verbose = 6, 
                        backend = "loky")(delayed(get_features_image)(i) for i in sample_paths)

In [None]:
# catch_all_dfs = []

# for sample_path in tqdm(sample_paths): 
#     if len(get_features_image(sample_path)) > 0:
#         catch_all_dfs.append(get_features_image(sample_path))

In [None]:
catch_all_dfs_1 = [item for item in catch_all_dfs if len(item) > 0]

In [None]:
catch_all_dfs_1 = pd.concat(catch_all_dfs_1)

In [None]:
current_dir

In [None]:
out_path =  os.path.join(current_dir, "../", "Case_Study_and_Misc")

In [None]:
catch_all_dfs_1.to_csv(os.path.join(out_path, "all.csv"), index = False)

In [None]:
catch_all_dfs_1.isnull().sum()

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
top_k = catch_all_dfs_1["flavor"].value_counts()[:20].index

In [None]:
catch_all_dfs_1 = catch_all_dfs_1[catch_all_dfs_1["flavor"].isin(top_k)]

In [None]:
x_train, x_test, y_train, y_test = train_test_split(catch_all_dfs_1.iloc[:,:-1], catch_all_dfs_1.iloc[:,-1], 
                                    test_size = 0.3, stratify = catch_all_dfs_1.iloc[:,-1])

In [None]:
count_train = pd.DataFrame(y_train).value_counts().reset_index()
count_train.columns = ["flavor", "counts"]

In [None]:
count_train

In [None]:
plt.rcParams["font.weight"] = "bold"
plt.figure(figsize = (10,5))
sns.barplot(data = count_train, x = "flavor", y = "counts")
plt.xticks(rotation = 90)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.xlabel("Flavor", weight = "bold", fontsize = 20)
plt.ylabel("Frequency", weight = "bold", fontsize = 20)
# plt.legend(prop={'size': 15})
plt.show()

In [None]:
from sklearn.model_selection import GridSearchCV

from sklearn.ensemble import RandomForestClassifier


In [None]:
rf = RandomForestClassifier(n_jobs=6)

parameters = {'n_estimators':[100]}

In [None]:
gs = GridSearchCV(estimator=rf, cv=5, n_jobs=6, scoring="accuracy", param_grid = parameters, verbose = 5)

In [None]:
gs.fit(x_train, y_train)

gs.best_score_

In [None]:
gs.best_estimator_

In [None]:
y_test_pred = gs.predict(x_test)

In [None]:
np.mean(y_test_pred == y_test)

In [None]:
from sklearn.metrics import confusion_matrix

In [None]:
cf = confusion_matrix(y_test, y_test_pred, normalize = "true", labels = top_k)

In [None]:
df_cf = pd.DataFrame(cf, columns=top_k, index = top_k)

In [None]:
plt.rcParams["font.weight"] = "bold"
plt.figure(figsize = (10,5))
sns.barplot(data = count_train, x = "flavor", y = "counts")
plt.xticks(rotation = 90)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.xlabel("Flavor", weight = "bold", fontsize = 20)
plt.ylabel("Frequency", weight = "bold", fontsize = 20)
# plt.legend(prop={'size': 15})
plt.show()

In [None]:
plt.rcParams["font.weight"] = "bold"
plt.figure(figsize = (10,10))
sns.heatmap(df_cf, cmap = "RdBu_r")
plt.xlabel("Predicted Flavor", weight = "bold", fontsize = 20)
plt.ylabel("True Flavor", weight = "bold", fontsize = 20)
plt.show()

In [None]:
rf = gs.best_estimator_

In [None]:
rf.fit(x_train, y_train)

In [None]:
rf.feature_importances_

In [None]:
rf.feature_names_in_

In [None]:
feat_score = zip(rf.feature_names_in_, rf.feature_importances_)

In [None]:
feat_score_df = pd.DataFrame(list(feat_score))

In [None]:
feat_score_df.columns = ["feature_name", "importance_score"]

In [None]:
imp_df = feat_score_df.sort_values("importance_score", ascending = False).reset_index(drop = True)

In [None]:
plt.rcParams["font.weight"] = "bold"
plt.figure(figsize = (10,5))
sns.barplot(data = imp_df, x = "feature_name", y = "importance_score")
plt.xticks(rotation = 90)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.xlabel("Feature Name", weight = "bold", fontsize = 20)
plt.ylabel("Importance Score", weight = "bold", fontsize = 20)
plt.show()