In [1]:
import xgboost as xgb
import pickle
from sklearn.model_selection import cross_val_score
import pandas as pd
import numpy as np
import cv2
import sys
import os
from skimage.metrics import structural_similarity, normalized_root_mse, adapted_rand_error, hausdorff_distance, peak_signal_noise_ratio
from compare_images import compare_images

In [2]:
csv_path = 'train_data.csv'
template_path = 'template'

In [None]:
templates = [(1+i, f, cv2.imread(os.path.join(template_path, f))) for i, f in enumerate(os.listdir(template_path)) if os.path.isfile(os.path.join(template_path, f))]
def get_row(path, label):
    cols = []
    after = cv2.imread(path)
    for _, template_name, template_image in templates:
        comparison = compare_images(template_image, after)
        cols.extend((
            comparison['similarity'], 
            comparison['mse'],
            comparison['adapted_rand_error_are'],
            comparison['adapted_rand_error_prec'],
            comparison['adapted_rand_error_rec'],
            comparison['hausdorff_distance'],
            comparison['psnr'],
        ))
    cols.append(label)
    return cols

rows = []
for root, dirs, files in os.walk('train'):
    if root == 'train':
        continue
    if os.path.basename(root) == 'garbage':
        label = 0
    else:
        label = [t[0] for t in templates if os.path.basename(root) == t[1]][0]

    for f in files:
        rows.append(get_row(os.path.join(root, f), label))

columns = []
for t in templates:
    columns.extend((
        t[1] + '_similarity',
        t[1] + '_mse',
        t[1] + '_adapted_rand_error_are',
        t[1] + '_adapted_rand_error_prec',
        t[1] + '_adapted_rand_error_rec',
        t[1] + '_hausdorff_distance',
        t[1] + '_hpsnr',
    ))
columns.append('label')
df = pd.DataFrame(rows, columns=columns)
df.to_csv(csv_path)

In [None]:
df = pd.read_csv(csv_path)

In [None]:
X = df.drop(['label', 'Unnamed: 0'], axis=1)
y = df['label'].astype(int)

In [None]:
classifier = xgb.XGBClassifier(use_label_encoder=False)
print(cross_val_score(classifier, X, y.values, cv=3))

In [None]:
classifier.fit(X, y)
pickle.dump(classifier, open(f'classifier-{csv_path}.pickle', "wb"))