In [1]:
import cv2
import numpy as np
from matplotlib import pyplot as plt
from glob import glob

In [2]:
# Decrease color
def dic_color(img):
    """Quantize every channel value to one of four levels: 32, 96, 160, 224.

    The 0-255 range is split into four equal bins of width 64
    (0-63, 64-127, 128-191, 192-255) and each value is replaced by the
    centre of its bin.

    Args:
        img: NumPy array of 8-bit channel values (e.g. an image loaded
            with cv2.imread).

    Returns:
        A new array with the same shape holding the quantized values.
        The input array is left untouched.
    """
    # The bin width must be 64. Dividing by 63 would create a fifth bin for
    # values 252-255 whose centre (4 * 64 + 32 = 288) wraps around to 32 in
    # uint8, and would also shift every bin boundary down.
    # Use a non in-place expression so the caller's array is not modified.
    return img // 64 * 64 + 32

In [17]:
# Database
def get_DB(pattern="../dataset/train_*"):
    """Build the color-histogram database from the training images.

    Every image matching ``pattern`` is color-reduced with ``dic_color`` and
    described by a 12-bin histogram: 4 levels (32, 96, 160, 224) for each of
    the 3 channels. The class is derived from the file path.

    Args:
        pattern: glob pattern selecting the training images.

    Returns:
        A tuple ``(db, pdb)``:
        db: int32 array of shape (number of images, 13). Columns 0-3, 4-7
            and 8-11 hold the histograms of channels 0, 1 and 2; column 12
            holds the class index (0 = 'akahara', 1 = 'madara').
        pdb: list of image paths, in the same order as the rows of ``db``.

    Raises:
        ValueError: if a path contains neither 'akahara' nor 'madara'.
        IOError: if an image cannot be read.
    """
    # (substring expected in the path, class index), checked in this order
    class_names = (('akahara', 0), ('madara', 1))

    # get image paths
    train = sorted(glob(pattern))

    # prepare database
    db = np.zeros((len(train), 13), dtype=np.int32)
    pdb = []

    # each image
    for i, path in enumerate(train):
        # get class; resolved per image so that an unlabeled file can never
        # silently inherit the label of the previous image
        cls = next((idx for name, idx in class_names if name in path), None)
        if cls is None:
            raise ValueError("cannot infer class from path: {}".format(path))

        # cv2.imread signals failure by returning None instead of raising
        img = cv2.imread(path)
        if img is None:
            raise IOError("cannot read image: {}".format(path))
        img = dic_color(img)

        # get histogram: count the pixels at each quantized level per channel
        for j in range(4):
            level = 64 * j + 32
            for c in range(3):
                db[i, 4 * c + j] = np.count_nonzero(img[..., c] == level)

        # store class label
        db[i, -1] = cls

        pdb.append(path)

    return db, pdb

In [18]:
# test
def test_DB(db, pdb, N=3):
    # get test image path
    test = glob("../dataset/test_*")
    test.sort()

    accuracy_N = 0.

    # each image
    for path in test:
        # read image
        img = dic_color(cv2.imread(path))

        # get histogram
        hist = np.zeros(12, dtype=np.int32)
        for j in range(4):
            hist[j] = len(np.where(img[..., 0] == (64 * j + 32))[0])
            hist[j+4] = len(np.where(img[..., 1] == (64 * j + 32))[0])
            hist[j+8] = len(np.where(img[..., 2] == (64 * j + 32))[0])

        # get histogram difference
        difs = np.abs(db[:, :12] - hist)
        difs = np.sum(difs, axis=1)

        # get top N
        pred_i = np.argsort(difs)[:N]

        # predict class index
        pred = db[pred_i, -1]

        # get class label
        if len(pred[pred == 0]) > len(pred[pred == 1]):
            pl = "akahara"
        else:
            pl = 'madara'

        print(path, "is similar >> ", end='')
        for i in pred_i:
            print(pdb[i], end=', ')
        print("|Pred >>", pl)

        # count accuracy
        gt = "akahara" if "akahara" in path else "madara"
        if gt == pl:
            accuracy_N += 1.

    accuracy = accuracy_N / len(test)
    print("Accuracy >>", accuracy, "({}/{})".format(int(accuracy_N), len(test)))

In [19]:
# Build the histogram database from the training images, then classify the
# test images against it; test_DB prints each prediction and the accuracy.
db, pdb = get_DB()
test_DB(db, pdb)

../dataset/test_akahara_1.jpg is similar >> ../dataset/train_akahara_3.jpg, ../dataset/train_akahara_2.jpg, ../dataset/train_akahara_4.jpg, |Pred >> akahara
../dataset/test_akahara_2.jpg is similar >> ../dataset/train_akahara_1.jpg, ../dataset/train_akahara_2.jpg, ../dataset/train_akahara_4.jpg, |Pred >> akahara
../dataset/test_madara_1.jpg is similar >> ../dataset/train_madara_2.jpg, ../dataset/train_madara_4.jpg, ../dataset/train_madara_3.jpg, |Pred >> madara
../dataset/test_madara_2.jpg is similar >> ../dataset/train_akahara_2.jpg, ../dataset/train_madara_3.jpg, ../dataset/train_madara_2.jpg, |Pred >> madara
Accuracy >> 1.0 (4/4)
