In [79]:
import os
from os import listdir

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_sample_image
from sklearn.feature_extraction import image
from skimage import io, color

In [80]:
######################################## FUNCTIONS ########################################

##### Import / Filters
def extract_patches(path, size, patch_count, random_state=None) :
    """Read one image from disk and sample square patches from it.

    Parameters
    ----------
    path : str
        Path of the image file, read with ``skimage.io.imread``.
    size : int
        Side length (in pixels) of each square patch.
    patch_count : int or float
        Forwarded as ``max_patches`` to ``extract_patches_2d`` (maximum number
        of randomly sampled patches when given as an int).
    random_state : int, numpy.random.RandomState or None, optional
        Forwarded to ``extract_patches_2d`` so the random sampling can be made
        reproducible. ``None`` (the default) keeps the previous behaviour.

    Returns
    -------
    numpy.ndarray
        The sampled patches stacked along the first axis; for a single-channel
        image this is (n_patches, size, size).
    """
    img = io.imread(path)
    return image.extract_patches_2d(
        img, (size, size), max_patches=patch_count, random_state=random_state
    )

def folder_patches(path, size = 100, patch_count = 50) :
    """Sample patches from every image file in a folder and stack them.

    Parameters
    ----------
    path : str
        Folder containing the source images.
    size : int
        Side length of each square patch.
    patch_count : int
        Maximum number of patches sampled per image.

    Returns
    -------
    numpy.ndarray
        float64 array of shape (n_total_patches, size, size); shape
        (0, size, size) when the folder holds no usable file.
    """
    # Seed the list with an empty float64 stack so the result dtype and the
    # empty-folder shape are the same as with the previous implementation.
    chunks = [np.ndarray((0, size, size))]
    # listdir order is arbitrary; sorting makes the patch order deterministic.
    for file in sorted(listdir(path)) :
        file_path = os.path.join(path, file)
        # Skip sub-directories (e.g. .ipynb_checkpoints) and hidden files
        # (e.g. .DS_Store) that are not images imread could decode.
        if file.startswith(".") or not os.path.isfile(file_path) :
            continue
        chunks.append(extract_patches(file_path, size, patch_count))
    # One concatenate at the end instead of one per file (avoids O(n^2) copying).
    return np.concatenate(chunks)

def filter_patches(patches,seuil,size) :
    """Keep only the patches whose mean pixel value is strictly above a threshold.

    Parameters
    ----------
    patches : array-like
        Stack of patches, shape (n, size, size).
    seuil : float
        Threshold ("seuil" is French for threshold) on the per-patch mean.
    size : int
        Patch side length; used to shape the result when the input is empty.

    Returns
    -------
    numpy.ndarray
        float64 array of shape (k, size, size) holding the k retained patches
        in their original order.
    """
    if len(patches) == 0 :
        return np.ndarray((0, size, size))
    stack = np.asarray(patches, dtype=np.float64)
    # One vectorised mean per patch replaces the Python loop that grew the
    # output with np.concatenate (copying the whole array on every hit).
    keep = stack.mean(axis=(1, 2)) > seuil
    return stack[keep]

##### Display / Export
def patch_show(patch) :
    """Display one patch, then print and plot its pixel-value histogram (log y-axis).

    Parameters
    ----------
    patch : numpy.ndarray
        A single patch with 16-bit pixel values.
    """
    io.imshow(patch)
    # The source images are unsigned 16-bit, so pixels lie in 0..65535.
    # Pin the histogram range to that span so each of the 65536 bins covers
    # exactly one integer value; without `range`, numpy spreads the bins over
    # [patch.min(), patch.max()] and they are not unit-width.
    histogramme, classes = np.histogram(patch, bins=65536, range=(0, 65536))
    print(histogramme)

    plt.figure()
    plt.xlabel("pixel value")
    plt.ylabel("pixel count")
    plt.yscale('log')
    # Left bin edges are the integer pixel values.
    plt.plot(classes[:-1], histogramme)
    plt.show()
    
def patches_stats(patches) :
    """Print the number of patches and plot a histogram of per-patch mean values.

    Parameters
    ----------
    patches : numpy.ndarray
        Stack of patches, shape (n, height, width), with 16-bit pixel values.
    """
    print("Il y a %s patches" %(len(patches)))

    averages = np.mean(patches, axis=(1, 2))
    # Candidate selection criteria:
    #   np.mean -> keep patches with a high enough average?
    #   np.var  -> keep patches with high variance, i.e. not uniformly white/black?
    # Same 16-bit value span as the pixels, one unit-width bin per value;
    # without `range` the bins would span [averages.min(), averages.max()].
    histogramme, classes = np.histogram(averages, bins=65536, range=(0, 65536))

    plt.figure()
    plt.xlabel("average pixel value of a patch")
    # The histogram counts patches, not pixels.
    plt.ylabel("patch count")
    plt.yscale('log')
    plt.plot(classes[:-1], histogramme)
    plt.show()

def save_patches(patches, path, name) :
    """Write each patch as an 8-bit PNG named ``<name>_<i>.png`` inside ``path``.

    Parameters
    ----------
    patches : array-like
        Stack of patches, values expected in the 16-bit range 0..65535.
    path : str
        Output folder; created (including parents) if it does not exist.
    name : str
        Filename prefix; the patch index is appended.
    """
    # Create the output folder up front so writing into a fresh directory
    # does not fail.
    os.makedirs(path, exist_ok=True)
    for i, patch in enumerate(patches) :
        true_path = os.path.join(path, name + "_" + str(i) + ".png")
        # 16-bit -> 8-bit by dividing by 256 (the uint8 cast truncates).
        # The clip keeps out-of-range values from wrapping in the cast.
        io.imsave(true_path, np.clip(patch / 256, 0, 255).astype('uint8'))



In [82]:
##### Pipeline management (extract -> filter -> save)
def gestion(category, name, size, threshold, v = False,
            patch_count = 1000, raw_dir = './Data/Raw', out_dir = './Data/Patches') :
    """Run the patch pipeline for one image category: extract, filter, save.

    Parameters
    ----------
    category : str
        Sub-folder name (e.g. "Good" / "Bad") under both ``raw_dir`` and ``out_dir``.
    name : str
        Filename prefix for the saved patches.
    size : int
        Side length of the square patches.
    threshold : float
        A patch is kept only if its mean pixel value is strictly above this.
    v : bool
        Verbose: show patch statistics before and after filtering.
    patch_count : int
        Maximum number of patches sampled per source image.
    raw_dir : str
        Root folder of the source images.
    out_dir : str
        Root folder where the filtered patches are written.
    """
    raw_patches = folder_patches(os.path.join(raw_dir, category),
                                 size=size, patch_count=patch_count)
    if v :
        patches_stats(raw_patches)
    kept_patches = filter_patches(raw_patches, threshold, size)
    if v :
        patches_stats(kept_patches)
    save_patches(kept_patches, os.path.join(out_dir, category), name)

######################################## MAIN ########################################
# Patch sampling is random. With random_state=None, sklearn's extract_patches_2d
# draws from NumPy's global RandomState, so seeding it makes runs reproducible.
SEED = 0
SIZE = 50         # patch side length, in pixels
THRESHOLD = 5000  # keep patches whose mean pixel value (16-bit scale) exceeds this

np.random.seed(SEED)
for category, name in (("Good", "good_patch"), ("Bad", "bad_patch")) :
    gestion(category, name, SIZE, THRESHOLD)