In [1]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import glob
import pathlib
%matplotlib inline

In [2]:
class ColorReduction:
    """Quantize intensities into four levels: 32, 96, 160 and 224.

    Every value in ``[0, 256)`` is replaced by the centre of the 64-wide
    bucket it falls into; any value outside that range becomes ``-1``.
    """

    BUCKET_WIDTH = 64    # width of one quantization bucket
    BUCKET_CENTRE = 32   # offset from a bucket's lower edge to its output level

    def __call__(self, img):
        """Return a reduced copy of ``img``.

        Args:
            img: Array with a ``shape`` of rank 3 (H, W, channels) or rank 2 (H, W).

        Returns:
            A new array with the same shape and dtype as ``img``, or ``None``
            when ``img`` has any other rank.
        """
        if len(img.shape) == 3:
            return self.apply_3(img)
        if len(img.shape) == 2:
            return self.apply_2(img)
        return None

    # NOTE: the reference solution for problem 84 appears to get this step wrong.
    def reduction_onepixel(self, value):
        """Return the reduced level for one scalar, or ``-1`` if it is outside [0, 256)."""
        if 0 <= value < 256:
            # int() turns NumPy scalars into a plain Python int, as the
            # literal return values of a lookup chain would be.
            return int(value // self.BUCKET_WIDTH) * self.BUCKET_WIDTH + self.BUCKET_CENTRE
        return -1

    def _reduce_array(self, img):
        """Vectorized form of ``reduction_onepixel`` applied to a whole array."""
        values = np.asarray(img)
        reduced = (values // self.BUCKET_WIDTH) * self.BUCKET_WIDTH + self.BUCKET_CENTRE
        if values.dtype != np.uint8:
            # uint8 can only hold values in [0, 256), so the mask is needed
            # for other dtypes only. NaN fails both comparisons and maps to -1.
            in_range = (values >= 0) & (values < 256)
            reduced = np.where(in_range, reduced, -1)
        # Keep the input dtype, like writing into a copy of the input would.
        return reduced.astype(values.dtype, copy=False)

    def apply_3(self, img):
        """Reduce a rank-3 (H, W, channels) array; raises ValueError for any other rank."""
        H, W, ch = img.shape  # unpacking doubles as the rank check
        return self._reduce_array(img)

    def apply_2(self, img):
        """Reduce a rank-2 (H, W) array; raises ValueError for any other rank."""
        H, W = img.shape  # unpacking doubles as the rank check
        return self._reduce_array(img)

In [16]:
class TinyImageRecognition:
    """Nearest-neighbour image classifier on reduced-colour histograms.

    Training images are the ``train_*.jpg`` files found under ``gt_path``.
    Each image is described by a 12-bin histogram (four reduced colour levels
    for each of three channels). A query image is matched to the training
    image whose histogram has the smallest sum of absolute differences.
    """

    def __init__(self, gt_path, parse_func):
        """Load the training images and precompute their histograms.

        Args:
            gt_path: Directory containing the ``train_*.jpg`` files.
            parse_func: Callable that maps a file name (without its
                directory) to the class label of that image.
        """
        self.color_reduction = ColorReduction()
        # Reduced colour level -> bin offset inside one channel's 4 bins.
        self.reduced_valuemap = {
            32: 0,
            96: 1,
            160: 2,
            224: 3
        }
        self.gt_path = gt_path

        self.images, self.names, self.classes = self._get_images(parse_func)
        self.hists = self._get_hists()

    def _get_images(self, parse_func):
        """Read every training image, in sorted file-name order.

        Returns:
            Three parallel lists: images, file paths and class labels.

        Raises:
            OSError: If a matched file cannot be decoded by ``cv2.imread``.
        """
        images, names, classes = [], [], []
        file_list = sorted(glob.glob(self.gt_path + "/train_*.jpg"))
        for file in file_list:
            img = cv2.imread(file)
            if img is None:
                # cv2.imread signals failure with None instead of raising;
                # fail here with a clear message rather than later on
                # ``img.shape``.
                raise OSError("could not read training image: {}".format(file))
            images.append(img)
            names.append(file)
            classes.append(parse_func(pathlib.Path(file).name))
        return images, names, classes

    def _get_hist(self, img):
        """Return the 12-bin reduced-colour histogram of a rank-3 image.

        Bin ``4 * c + k`` holds the number of pixels of channel ``c`` whose
        reduced level maps to offset ``k`` in ``reduced_valuemap``. Levels
        that are not keys of ``reduced_valuemap`` are not counted.
        """
        assert len(img.shape) == 3, "invalid img dimension: expected: 3, got: {}".format(img.shape)
        ch = img.shape[2]

        # Quantize the whole image once, then count each level per channel
        # with array operations instead of visiting every pixel in Python.
        reduced = self.color_reduction(img)
        hist = np.zeros((12))
        for c in range(ch):
            channel = reduced[:, :, c]
            for level, offset in self.reduced_valuemap.items():
                hist[4 * c + offset] = np.count_nonzero(channel == level)
        return hist

    def _get_hists(self):
        """Return an array of shape (number of training images, 12) of histograms."""
        hists = np.zeros((len(self.images), 12))
        for i, image in enumerate(self.images):
            hists[i] = self._get_hist(image)
        return hists

    def nearest_neighbour(self, img):
        """Return the index of the training image closest to ``img``.

        Closeness is the sum of absolute histogram differences. Returns
        ``-1`` when there are no training histograms to compare against.
        """
        if len(self.hists) == 0:
            # np.argmin raises on an empty array; report "no match" instead.
            return -1
        hist_test = self._get_hist(img)
        distances = np.sum(np.abs(self.hists - hist_test), axis=1)
        return np.argmin(distances)

    def problem_84(self):
        """Plot and print the histogram of every training image."""
        cols = 5
        # Keep the 2x5 layout for up to ten images and add rows beyond
        # that, so the subplot index never exceeds the grid size.
        rows = max(2, -(-len(self.images) // cols))
        plt.figure(figsize=(20, 5 * rows))
        for i in range(len(self.images)):
            plt.subplot(rows, cols, i + 1)
            plt.title(pathlib.Path(self.names[i]).name)
            plt.bar(np.arange(0, 12) + 1, self.hists[i])
            print(self.hists[i])
        plt.show()

    def problem_85(self, test_path):
        """Classify every ``test_*.jpg`` under ``test_path`` and print the result.

        Files that cannot be read are reported and skipped so that one bad
        file does not stop the remaining predictions.
        """
        file_list = sorted(glob.glob(test_path + "/test_*.jpg"))
        for file in file_list:
            img = cv2.imread(file)
            if img is None:
                print("{} could not be read; skipped".format(pathlib.Path(file).name))
                continue
            nearest = self.nearest_neighbour(img)
            if nearest == -1:
                print("{} has no training image to compare with".format(pathlib.Path(file).name))
                continue
            print("{} is similar >> {} Pred >> {}".format(
                    pathlib.Path(file).name,
                    pathlib.Path(self.names[nearest]).name,
                    self.classes[nearest]
                )
            )

In [17]:
def parse_func(file_name):
    """Return the class label: the text between the first and second underscore.

    For ``"train_akahara_3.jpg"`` this is ``"akahara"``. Raises IndexError
    when ``file_name`` contains no underscore.
    """
    # Only the second field is needed, so stop splitting after it.
    fields = file_name.split("_", 2)
    return fields[1]

# Build the classifier from the train_*.jpg images, then classify every
# test_*.jpg image stored in the same directory.
recog = TinyImageRecognition(gt_path="../dataset", parse_func=parse_func)
recog.problem_85(test_path="../dataset")

test_akahara_1.jpg is similar >> train_akahara_3.jpg Pred >> akahara
test_akahara_2.jpg is similar >> train_akahara_1.jpg Pred >> akahara
test_madara_1.jpg is similar >> train_madara_2.jpg Pred >> madara
test_madara_2.jpg is similar >> train_akahara_2.jpg Pred >> akahara
