Imports

In [1]:
from PIL import Image
# Sanity check: show which Pillow installation this kernel actually imports
# (expected to be the copy inside the project's virtualenv).
pil_image_module_path = Image.__file__
print(pil_image_module_path)

D:\University\FYP\cDCGAN\venv\lib\site-packages\PIL\Image.py


In [2]:
import os
import numpy as np
from collections import Counter
from tqdm import tqdm
import shutil
import sys



import tensorflow as tf
from keras import layers
from keras.models import Model, Sequential
from keras_preprocessing.image import load_img
from keras.applications.xception import preprocess_input 
from keras.applications.xception import Xception 

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from k_means_constrained import KMeansConstrained

In [3]:
def blockPrint():
    sys.stdout = open(os.devnull, 'w')
blockPrint()

In [4]:
# Source tree: one sub-folder of images per class (only read and copied from).
ROOT_DIR = "data/crc7k/norm"
# Destination tree: receives a <class>_CLUSTER_<i> folder per class and cluster.
TARGET_DIR = "data/crc7k/decomposed"
# (height, width, channels) of the images fed to Xception.
INPUT_SHAPE = (299, 299, 3)

In [5]:
def get_model():
    """Build ImageNet-pretrained Xception with its classification head included.

    The input shape comes from INPUT_SHAPE. The caller is responsible for
    slicing the model down to a feature extractor (main() uses layers[-2]).
    """
    return Xception(
        include_top=True,
        weights='imagenet',
        input_shape=INPUT_SHAPE,
    )

In [6]:
def extract_features(img_path, extractor):
    """Embed one image file as a flat feature vector.

    The image is loaded and resized to INPUT_SHAPE[:2], scaled with Xception's
    preprocess_input, and passed through `extractor` as a batch of one.

    Args:
        img_path: path to an image file readable by load_img.
        extractor: Keras model producing 2048 values per image (main() builds
            it from Xception's layers[-2]).

    Returns:
        1-D numpy array of shape (2048,).
    """
    img = load_img(img_path, target_size=INPUT_SHAPE[:2])
    batch = np.expand_dims(preprocess_input(np.asarray(img)), axis=0)
    assert batch.shape == (1, *INPUT_SHAPE), batch.shape
    # verbose=0 keeps predict() from printing progress output on every
    # single-image call.
    features = extractor.predict(batch, verbose=0).reshape(-1)
    assert features.shape == (2048,), features.shape
    return features

In [7]:
def get_cluster_labels(folder_path, extractor,
                       n_components, n_clusters, random_state):
    """Assign every image in `folder_path` to one of `n_clusters` clusters.

    Each file is embedded with extract_features(), the embeddings are projected
    with PCA, and KMeans partitions the projected points.

    Args:
        folder_path: directory containing the images of one class.
        extractor: model passed through to extract_features().
        n_components: PCA `n_components` (an int, or a float in (0, 1) meaning
            the fraction of explained variance to keep).
        n_clusters: number of KMeans clusters.
        random_state: seed passed to both PCA and KMeans.

    Returns:
        dict mapping each file name (not full path) to its cluster label, an int
        in range(n_clusters). Empty if the folder contains no files.
    """
    # os.listdir order is filesystem-dependent; sorting makes the row order of
    # the feature matrix -- and therefore the seeded PCA/KMeans result --
    # reproducible across machines. Sub-directories are skipped because
    # load_img cannot open them.
    fnames = sorted(
        name for name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, name))
    )
    if not fnames:
        return {}

    features = np.array([
        extract_features(os.path.join(folder_path, name), extractor)
        for name in tqdm(fnames)
    ])

    pca = PCA(n_components=n_components, random_state=random_state)
    pca.fit(features)
    features_t = pca.transform(features)

    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    kmeans.fit(features_t)
    # Alternative: KMeansConstrained (imported at the top), e.g. with
    # size_min=250, enforces a minimum cluster size instead.

    # tolist() yields plain Python ints, usable directly as list indices.
    return dict(zip(fnames, kmeans.labels_.tolist()))

In [8]:
def write_new_classes(source_folder_path, extractor,
                      n_components=0.95, n_clusters=2, random_state=123):
    """Split one class folder into `n_clusters` sub-class folders by image features.

    Every file in `source_folder_path` is clustered with get_cluster_labels()
    and copied (not moved) to `<TARGET_DIR>/<class>_CLUSTER_<i>`, where
    `<class>` is the last path component of `source_folder_path` and `i` is
    the file's cluster label.

    Args:
        source_folder_path: folder holding the images of a single class.
        extractor: model used by extract_features() to embed each image.
        n_components: forwarded to PCA (a float in (0, 1) is a variance fraction).
        n_clusters: number of clusters, and hence output folders.
        random_state: seed forwarded to PCA and KMeans.
    """
    # basename/normpath are separator-aware; splitting on '\\' only worked on
    # Windows and joined the whole source path under TARGET_DIR elsewhere.
    source_folder = os.path.basename(os.path.normpath(source_folder_path))
    target_folder_path = os.path.join(TARGET_DIR, source_folder)
    new_folders = [f'{target_folder_path}_CLUSTER_{i}' for i in range(n_clusters)]
    for folder in new_folders:
        # makedirs also creates TARGET_DIR itself if it is missing.
        os.makedirs(folder, exist_ok=True)

    cluster_labels = get_cluster_labels(
        source_folder_path, extractor, n_components, n_clusters, random_state)
    for fname, label in cluster_labels.items():
        shutil.copyfile(os.path.join(source_folder_path, fname),
                        os.path.join(new_folders[label], fname))

In [9]:
def main():
    """Decompose every class folder under ROOT_DIR into clusters under TARGET_DIR."""
    model = get_model()
    # get_model() keeps the ImageNet classifier as the last layer, so
    # layers[-2] should be the pooled features; extract_features() asserts
    # they are 2048 wide.
    extractor = Model(inputs=model.inputs, outputs=model.layers[-2].output)
    extractor.summary()
    os.makedirs(TARGET_DIR, exist_ok=True)
    # Only sub-directories are classes; stray files (desktop.ini, .DS_Store, ...)
    # would otherwise be passed on and crash. Sorted for a deterministic order.
    for folder in sorted(os.listdir(ROOT_DIR)):
        folder_path = os.path.join(ROOT_DIR, folder)
        if not os.path.isdir(folder_path):
            continue
        write_new_classes(folder_path, extractor)
    print("Finished")

In [10]:
# Cluster every class folder under ROOT_DIR and copy the results into TARGET_DIR.
main()

100%|██████████| 1338/1338 [01:16<00:00, 17.45it/s]
100%|██████████| 847/847 [00:46<00:00, 18.25it/s]
100%|██████████| 339/339 [00:18<00:00, 17.95it/s]
100%|██████████| 634/634 [00:35<00:00, 17.86it/s]
100%|██████████| 1035/1035 [00:56<00:00, 18.41it/s]
100%|██████████| 592/592 [00:32<00:00, 18.41it/s]
100%|██████████| 741/741 [00:37<00:00, 19.73it/s]
100%|██████████| 421/421 [00:23<00:00, 17.72it/s]
100%|██████████| 1233/1233 [01:08<00:00, 18.10it/s]


In [11]:
# Number of files written into each cluster folder.
# sys.stdout is still redirected to os.devnull by blockPrint() earlier in this
# notebook, so print() would show nothing here (as in the original run). The
# cell instead ends with an expression, which Jupyter renders through its
# display hook rather than through stdout.
cluster_sizes = {
    folder: len(os.listdir(os.path.join(TARGET_DIR, folder)))
    for folder in sorted(os.listdir(TARGET_DIR))
    if os.path.isdir(os.path.join(TARGET_DIR, folder))
}
cluster_sizes