In [1]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import glob
import pathlib
import statistics
%matplotlib inline

In [2]:
class ColorReduction:
    """Reduce pixel intensities to four levels: 32, 96, 160 and 224.

    Every value in [0, 256) is replaced by the centre of the 64-wide band it
    falls into. Values outside that range are mapped to the sentinel -1.
    Calling the instance dispatches on the rank of the array: rank 3 goes to
    ``apply_3``, rank 2 to ``apply_2``, and any other rank yields ``None``.
    """

    def __call__(self, img):
        """Return a reduced copy of ``img``, or ``None`` if its rank is not 2 or 3."""
        if len(img.shape) == 3:
            return self.apply_3(img)
        if len(img.shape) == 2:
            return self.apply_2(img)
        return None

    # NOTE: the reference solution for problem 84 appears to get this step wrong.
    def reduction_onepixel(self, value):
        """Map a single intensity to its reduced level.

        Args:
            value: a scalar intensity.

        Returns:
            32, 96, 160 or 224 for values in [0, 256); -1 otherwise.
        """
        for lower in (0, 64, 128, 192):
            if lower <= value < lower + 64:
                return lower + 32
        return -1

    def _reduce_array(self, img):
        """Vectorised form of ``reduction_onepixel`` applied to a whole array.

        Returns a new array with the dtype of ``img``; ``img`` is not modified.
        """
        # Band index of every element; 0..3 for values inside [0, 256).
        levels = img // 64
        output_img = levels * 64 + 32
        # Written as a negated "inside" test so that NaN counts as invalid,
        # matching the scalar version which falls through to -1.
        invalid = ~((levels >= 0) & (levels <= 3))
        if invalid.any():
            output_img[invalid] = -1
        return output_img

    def apply_3(self, img):
        """Reduce a rank-3 array (e.g. H x W x channels) element-wise."""
        H, W, ch = img.shape  # unpacking rejects arrays of any other rank
        return self._reduce_array(img)

    def apply_2(self, img):
        """Reduce a rank-2 array (e.g. H x W) element-wise."""
        H, W = img.shape  # unpacking rejects arrays of any other rank
        return self._reduce_array(img)

In [30]:
class KmeansImageRecognition:
    """Cluster images with k-means on 12-bin reduced-colour histograms.

    Each image is described by a histogram with 4 bins per channel for 3
    channels (bins are the reduced levels 32/96/160/224). The histograms are
    then clustered into ``len(class_list)`` groups. Cluster ids are arbitrary
    and are not matched against the class names.
    """

    def __init__(self, parse_func, class_list):
        """Store the configuration.

        Args:
            parse_func: callable stored on the instance; no method of this
                class calls it.
            class_list: sequence of class names; only its length (the number
                of clusters) is used here.
        """
        self.color_reduction = ColorReduction()
        # Reduced intensity level -> bin offset within one channel.
        self.reduced_valuemap = {
            32: 0,
            96: 1,
            160: 2,
            224: 3
        }
        self.parse_func = parse_func
        self.class_list = class_list
        self.num_classes = len(class_list)

    def _get_images(self, test_path):
        """Load every image matching the glob pattern, in sorted path order.

        Args:
            test_path: glob pattern of image files.

        Returns:
            ``(images, names)``: ``images`` is an array of the decoded images
            (an object array when the images do not all share one shape) and
            ``names`` is an array of the matching file paths.

        Raises:
            IOError: if a matched file cannot be decoded by ``cv2.imread``.
        """
        images, names = [], []
        for file in sorted(glob.glob(test_path)):
            img = cv2.imread(file)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise IOError("failed to read image: {}".format(file))
            images.append(img)
            names.append(file)

        if len({img.shape for img in images}) > 1:
            # Images of different sizes cannot be stacked into one regular
            # array; keep them side by side in an object array instead.
            stacked = np.empty(len(images), dtype=object)
            for k, img in enumerate(images):
                stacked[k] = img
            images = stacked
        else:
            images = np.array(images)
        names = np.array(names)
        return images, names

    def _get_hist(self, img):
        """Return the 12-bin reduced-colour histogram of one H x W x ch image.

        Bin ``4*c + k`` counts the pixels of channel ``c`` whose value falls
        into the k-th of the four reduced levels.
        """
        assert len(img.shape) == 3, "invalid img dimension: expected: 3, got: {}".format(img.shape)
        H, W, ch = img.shape

        hist = np.zeros((12))
        for i in range(H):
            for j in range(W):
                for c in range(ch):
                    cls = 4*c + self.reduced_valuemap[self.color_reduction.reduction_onepixel(img[i, j, c])]
                    hist[cls] += 1
        return hist

    def _get_hists(self, images):
        """Return an ``(len(images), 12)`` array with one histogram per image."""
        hists = np.zeros((len(images), 12))
        for i in range(len(images)):
            hists[i] = self._get_hist(images[i])
        return hists

    def _initial_classes(self, n, th):
        """Draw the random starting cluster id for ``n`` samples.

        One ``np.random.random()`` draw is consumed per sample. A draw below
        ``th`` selects cluster 0; the remaining probability mass is spread
        evenly over clusters ``1 .. num_classes - 1``. With two clusters this
        is exactly ``0 if u < th else 1``.
        """
        last = self.num_classes - 1
        classes = []
        for _ in range(n):
            u = np.random.random()
            if u < th or last < 1:
                classes.append(0)
            elif last == 1:
                classes.append(1)
            else:
                # Here th <= u < 1, so the denominator is strictly positive.
                frac = (u - th) / (1.0 - th)
                # min() guards against frac rounding up to exactly 1.0.
                classes.append(min(1 + int(frac * last), last))
        return np.array(classes, dtype=int)

    def _cluster(self, hists, classes):
        """Run k-means on ``hists`` starting from the given assignment.

        Args:
            hists: ``(N, D)`` array of feature vectors.
            classes: length-N sequence of initial cluster ids.

        Returns:
            Array of length N with the final cluster id of every sample.
        """
        classes = np.array(classes)  # work on a copy of the caller's labels
        N = len(hists)

        # A cluster that has never had a member keeps an infinitely distant
        # centroid, so it can never win the argmin below. Dividing by a zero
        # member count would instead give a NaN centroid, which argmin treats
        # as the minimum and which keeps the loop from ever converging.
        gs = np.full((self.num_classes, hists.shape[1]), np.inf, dtype=np.float32)
        ns = np.zeros((self.num_classes))

        iteration = 0
        while True:
            iteration += 1
            print("iteration: {}, assigned: {}".format(iteration, classes))
            for i in range(self.num_classes):
                members = classes == i
                ns[i] = np.sum(members)
                # An emptied cluster keeps its previous centroid.
                if ns[i] > 0:
                    gs[i] = np.sum(hists[members], axis=0).astype(np.float32) / ns[i]

            cont = False
            for i in range(N):
                new_class = np.argmin([np.linalg.norm(hists[i] - gs[j]) for j in range(self.num_classes)])
                cont |= (new_class != classes[i])
                classes[i] = new_class

            # Converged once a full pass changes no assignment.
            if not cont:
                break
        return classes

    def recognition(self, test_path, th=0.5, seed=4):
        """Cluster the images matching ``test_path`` and print the result.

        Args:
            test_path: glob pattern of the images to cluster.
            th: probability that a sample starts in cluster 0.
            seed: value passed to ``np.random.seed`` before the initial draw.

        Prints the assignment at the start of every iteration and, at the
        end, one ``"<file name> Pred: <cluster id>"`` line per image.
        """
        np.random.seed(seed)
        images, names = self._get_images(test_path)
        hists = self._get_hists(images)
        classes = self._initial_classes(len(images), th)
        classes = self._cluster(hists, classes)

        for i, file_name in enumerate(names):
            print("{} Pred: {}".format(
                    pathlib.Path(file_name).name,
                    classes[i]
                )
            )

    def problem_89(self, test_path, th=0.5, seed=4):
        """Entry point for problem 89; forwards to ``recognition``."""
        self.recognition(test_path, th=th, seed=seed)

In [31]:
def parse_func(file_name):
    """Return the second underscore-separated field of the file's base name.

    For ``"../dataset/train_akahara_1.jpg"`` this is ``"akahara"``. Only the
    final path component is split, so underscores in directory names do not
    shift the field.

    Raises:
        IndexError: if the base name contains no underscore.
    """
    return pathlib.Path(file_name).name.split("_")[1]
# Class names; the recogniser only uses how many there are (the cluster count).
class_list = np.array(["akahara", "madara"])

# Cluster every image matching the glob. th is the probability that an image
# starts in cluster 0; seed fixes the random initial assignment.
recog = KmeansImageRecognition(parse_func, class_list)
recog.problem_89("../dataset/train_*.jpg", th=0.3, seed=4)

iteration: 1, assigned: [1 1 1 1 1 0 1 0 0 1]
iteration: 2, assigned: [1 1 1 1 1 0 0 0 0 0]
train_akahara_1.jpg Pred: 1
train_akahara_2.jpg Pred: 1
train_akahara_3.jpg Pred: 1
train_akahara_4.jpg Pred: 1
train_akahara_5.jpg Pred: 1
train_madara_1.jpg Pred: 0
train_madara_2.jpg Pred: 0
train_madara_3.jpg Pred: 0
train_madara_4.jpg Pred: 0
train_madara_5.jpg Pred: 0
