In [244]:
from sklearn import datasets
import os
import numpy as np
import cv2 # OpenCV
from sklearn.svm import SVC # SVM klasifikator
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier # KNN
import matplotlib.pyplot as plt
%matplotlib inline

In [245]:
def load_image(path):
    """Read an image file and return it as an RGB array.

    OpenCV decodes images in BGR channel order; this converts to RGB so the
    result displays correctly with matplotlib.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (missing file or
            unsupported format). ``cv2.imread`` signals this by returning
            None instead of raising, which would otherwise surface later as
            an unrelated-looking error inside ``cv2.cvtColor``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

def display_image(image, title=None):
    """Show ``image`` in a new matplotlib figure.

    Args:
        image: array accepted by ``plt.imshow`` (e.g. the RGB output of load_image).
        title: optional figure title; omitted when None.

    Returns:
        The AxesImage created by ``plt.imshow`` (callers may ignore it).
    """
    plt.figure()  # fresh figure so successive calls don't draw over each other
    shown = plt.imshow(image)
    if title is not None:
        plt.title(title)
    return shown

### HOG features and square classifier

In [246]:
def get_hog():
    """Build the HOG descriptor shared by training and inference.

    Geometry: 90x90 input, 8x8-pixel cells, blocks of 3x3 cells moved one
    cell at a time, 9 orientation bins. The detection window is the largest
    multiple of the cell size that fits in the input (88x88 here).
    OpenCV expects every size as (width, height).
    """
    image_h, image_w = 90, 90
    cell_h, cell_w = 8, 8
    block_rows, block_cols = 3, 3
    bins = 9

    window = ((image_w // cell_w) * cell_w, (image_h // cell_h) * cell_h)
    block = (block_cols * cell_w, block_rows * cell_h)
    cell = (cell_w, cell_h)

    return cv2.HOGDescriptor(_winSize=window,
                             _blockSize=block,
                             _blockStride=cell,
                             _cellSize=cell,
                             _nbins=bins)

In [247]:
def get_features_and_labels(hog, train_dir, repeats=100):
    """Compute HOG features and square labels for every image in ``train_dir``.

    The label is taken from the file name:
      * ``b?X...`` -> ``X``          (black piece, e.g. ``b_q0.png`` -> ``'q'``)
      * ``w?X...`` -> ``X.upper()``  (white piece, e.g. ``w_q0.png`` -> ``'Q'``)
      * ``e...``   -> ``'e'``        (empty square)
    Files with any other prefix (e.g. hidden files) are skipped. Previously
    their features were appended without a label, misaligning the two arrays.

    Each image is loaded and described once; the resulting dataset is then
    repeated ``repeats`` times. For a soft-margin SVM, duplicating every sample
    k times is equivalent to multiplying ``C`` by k. The default of 100 keeps
    the original training behaviour.

    Args:
        hog: cv2.HOGDescriptor used to describe each image.
        train_dir: directory of labelled square images.
        repeats: number of times the whole dataset is repeated.

    Returns:
        (features, labels) as numpy arrays with one row / entry per sample.
    """
    features = []
    labels = []
    # Sorted so the sample order does not depend on the filesystem.
    for img_name in sorted(os.listdir(train_dir)):
        prefix = img_name[:1]
        if prefix == 'b':
            label = img_name[2]
        elif prefix == 'w':
            label = img_name[2].upper()
        elif prefix == 'e':
            label = 'e'
        else:
            continue  # unknown naming scheme: no label, so no feature either
        img = load_image(os.path.join(train_dir, img_name))
        # ravel() keeps each descriptor 1-D regardless of the OpenCV version's output shape
        features.append(hog.compute(img).ravel())
        labels.append(label)

    return np.array(features * repeats), np.array(labels * repeats)
    
    

In [248]:
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score

def train_classifier(x_train, y_train):
    """Fit and return a linear-kernel SVC on the given features and labels.

    ``probability=True`` enables ``predict_proba``. scikit-learn fits it with
    an internal cross-validation, so it adds training cost.
    """
    print("Treniranje klasifikatora...")
    model = SVC(kernel='linear', probability=True)
    return model.fit(x_train, y_train)  # fit() returns the estimator itself

In [249]:
# Train the square classifier on HOG features of the labelled square images.
train_dir='../data/pictures/'

hog = get_hog()
x_train, y_train = get_features_and_labels(hog, train_dir)

classifier = train_classifier(x_train, y_train)
# Debug aid (disabled): print class probabilities for a single training image.
# print(classifier.predict_proba(hog.compute(load_image('../data/pictures/w_q0.png')).reshape(1, -1)))

Treniranje klasifikatora...


### Hough board-grid detection and move extraction

In [250]:
def detect_lines(gray_img):
    """Detect board grid lines with Canny edges and the probabilistic Hough transform.

    Despite the parameter name, the notebook passes the RGB start image
    returned by load_image.

    Returns:
        Integer array of shape (N, 1, 4) whose rows are [x1, y1, x2, y2].
        The y coordinates of detected segments are mirrored (``height - y``).
        Four fixed border segments are appended at coordinates 1 and 721,
        spanning 0..719.
        NOTE(review): those constants look tuned for a ~720 px board image.
        Derive them from ``gray_img.shape`` if other resolutions are used.
    """
    edges_img = cv2.Canny(gray_img, 50, 150, apertureSize=3)

    min_line_length = 200
    lines = cv2.HoughLinesP(image=edges_img, rho=1, theta=np.pi/180, threshold=10, lines=np.array([]),
                            minLineLength=min_line_length, maxLineGap=20)
    if lines is None:
        # HoughLinesP returns None (not an empty array) when nothing is found;
        # continue with just the border segments instead of crashing below.
        lines = np.empty((0, 1, 4), dtype=np.int32)

    height = gray_img.shape[0]
    lines[:, :, 1] = height - lines[:, :, 1]
    lines[:, :, 3] = height - lines[:, :, 3]

    lines = np.vstack([lines, [
        [[0, 1, 719, 1]],
        [[0, 721, 719, 721]],
        [[1, 0, 1, 719]],
        [[721, 0, 721, 719]]
    ]])
    return lines
    

In [251]:
def check_fields_similarity(frame_field, next_frame_field, classifier, hog):
    """Classify the same board square in two consecutive frames.

    Despite the name, this returns the two predicted labels
    ``(label_before, label_after)``; the caller compares them. Each crop is
    resized to 90x90 (the input size get_hog is built for) before HOG.
    """
    def predict_label(field):
        resized = cv2.resize(field, (90, 90), interpolation=cv2.INTER_NEAREST)
        descriptor = hog.compute(resized).reshape(1, -1)
        return classifier.predict(descriptor)[0]

    return (predict_label(frame_field), predict_label(next_frame_field))
    

In [252]:
def invert(similarities, size=8):
    """Transpose a row-major ``size`` x ``size`` grid stored as a flat sequence.

    Element (row, col) at index ``row * size + col`` moves to index
    ``col * size + row``. Raises IndexError if ``similarities`` has fewer than
    ``size * size`` items.

    Args:
        similarities: flat sequence of at least ``size * size`` items.
        size: grid side length (8 for a chess board).

    Returns:
        A new list with the transposed order.
    """
    return [similarities[row * size + col] for col in range(size) for row in range(size)]

In [253]:
def change_fem_row(row, idx, char):
    """Put ``char`` on square ``idx`` of one FEN rank and re-encode the rank.

    The digits 1-8 (runs of empty squares) are expanded into explicit ``'e'``
    placeholders. The square is replaced, then runs of ``'e'`` are collapsed
    back into digits. Passing ``char='e'`` therefore empties the square.

    Example: ``change_fem_row('8', 3, 'P') == '3P4'``.
    """
    expanded = ''.join('e' * int(ch) if ch in '12345678' else ch for ch in row)
    expanded = expanded[:idx] + char + expanded[idx + 1:]

    encoded = []
    pos = 0
    length = len(expanded)
    while pos < length:
        if expanded[pos] != 'e':
            encoded.append(expanded[pos])
            pos += 1
            continue
        run_end = pos
        while run_end < length and expanded[run_end] == 'e':
            run_end += 1
        encoded.append(str(run_end - pos))
        pos = run_end
    return ''.join(encoded)

In [254]:
def get_fem_format(fem, figure, start_iter, end_iter):
    """Apply one move to a FEN-like string and flip the side to move.

    ``fem`` holds only the piece placement and side to move, e.g.
    ``"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w"``.

    Square indices: file = idx // 8 (0 -> 'a'), rank row = idx % 8 (0 -> the
    first rank string, i.e. rank 8).

    Behaviour:
      * The start square is emptied and ``figure`` is placed on the end square.
      * A king moving more than one file is treated as castling, and the rook
        is moved as well.
      * Only these squares change, so e.g. a pawn captured en passant is not
        removed.
    """
    ranks = fem[:-2].split('/')
    # ord('w') + ord('b') == 217, so this swaps 'w' <-> 'b'.
    side_to_move = chr(ord('w') + ord('b') - ord(fem[-1]))

    start_file, start_row = divmod(start_iter, 8)
    end_file, end_row = divmod(end_iter, 8)

    ranks[start_row] = change_fem_row(ranks[start_row], start_file, 'e')
    ranks[end_row] = change_fem_row(ranks[end_row], end_file, figure)

    if figure in ('k', 'K') and abs(start_file - end_file) > 1:
        queenside = end_iter < start_iter
        rook_from = 0 if queenside else 7
        rook_to = end_file + 1 if queenside else end_file - 1
        rook = 'r' if figure == 'k' else 'R'
        ranks[start_row] = change_fem_row(ranks[start_row], rook_from, 'e')
        ranks[end_row] = change_fem_row(ranks[end_row], rook_to, rook)

    placement = '/'.join(ranks[i] for i in range(8))
    return placement + ' ' + side_to_move

In [255]:
def detect_moves(video_path, start_pos, lines, classifier, hog):
    """Replay a chess video and return the moves it shows.

    Every square of each video frame is classified and compared with the
    previous frame; the first video frame is compared with ``start_pos``.
    * Exactly 2 changed squares: a normal move or capture.
    * Exactly 4 changed squares: castling, decoded from the king's squares.
    * Any other count, or an undecodable pair/quad: no move is recorded for
      that frame. Previously, values left over from the previous move were
      re-applied in that case.

    Args:
        video_path: path opened with cv2.VideoCapture.
        start_pos: RGB image of the starting position. Assumed to have the
            same size as the video frames (TODO confirm), since both are
            sliced with the same coordinates.
        lines: grid lines from detect_lines sorted by (x1, y1). y-coordinates
            are read from the first half of the list and x-coordinates from
            the second half.
        classifier, hog: fitted square classifier and HOG descriptor.

    Returns:
        List of ``(move, fen)`` tuples, e.g. ``('e2e4', '<placement> b')``.
        ``fen`` is the position after the move with the side to move flipped.
    """
    grid_cells = _grid_cells(lines)  # loop-invariant: compute once, not per frame
    moves = []
    fem = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w"

    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # start from the first frame
        next_frame = start_pos
        while True:
            frame = next_frame
            grabbed, next_frame = cap.read()
            if not grabbed:
                break
            next_frame = cv2.cvtColor(next_frame, cv2.COLOR_BGR2RGB)

            differences = []
            for cell_index, (x1, x2, y1, y2) in enumerate(grid_cells):
                before, after = check_fields_similarity(frame[x1:x2, y1:y2], next_frame[x1:x2, y1:y2],
                                                        classifier, hog)
                if before != after:
                    differences.append((before, after, cell_index))

            move = _decode_move(differences)
            if move is None:
                continue  # no change, or a change that isn't a decodable move
            figure, start_iter, end_iter = move
            fem = get_fem_format(fem, figure, start_iter, end_iter)
            moves.append((_square_name(start_iter) + _square_name(end_iter), fem))
    finally:
        cap.release()  # always free the video handle, even if classification fails

    return moves


def _grid_cells(lines):
    """Return (x1, x2, y1, y2) slice bounds of every grid cell, in scan order."""
    half_size = len(lines) // 2
    cells = []
    for vert_line_ind in range(half_size - 1):
        for hor_line_ind in range(half_size - 1):
            x1 = lines[half_size + hor_line_ind][0][0]
            x2 = lines[half_size + hor_line_ind + 1][0][0]
            y1 = lines[vert_line_ind][0][1]
            y2 = lines[vert_line_ind + 1][0][1]
            cells.append((x1, x2, y1, y2))
    return cells


def _decode_move(differences):
    """Turn (before, after, square_index) changes into (figure, start, end), or None."""
    if len(differences) == 2:
        first, second = differences
        if first[1] == 'e':
            return first[0], first[2], second[2]
        if second[1] == 'e':
            return second[0], second[2], first[2]
        return None
    if len(differences) == 4:
        # Castling: the king leaves one square and appears on another.
        figure = start_iter = end_iter = None
        for before, after, square in differences:
            if before in ('k', 'K') and after == 'e':
                figure, start_iter = before, square
            elif before == 'e' and after in ('k', 'K'):
                end_iter = square
        if figure is None or end_iter is None:
            return None
        return figure, start_iter, end_iter
    return None


def _square_name(square_index):
    """Algebraic name of a square index (file = idx // 8, rank = 8 - idx % 8)."""
    return chr(ord('a') + square_index // 8) + str(8 - square_index % 8)

        

In [256]:
# Detect the board grid on the reference start-position image, then replay the
# video and decode each move as (squares, FEN-like string after the move).
start_pos_img=load_image('../data/start.png')
lines=detect_lines(start_pos_img)
# detect_moves reads y-coordinates from the first half of this sorted list and
# x-coordinates from the second half.
# NOTE(review): confirm that sorting by (x1, y1) still produces that split for new images.
lines = sorted(lines, key=lambda line: (line[0][0], line[0][1]))

detect_moves('../data/videos/game0.mp4',start_pos_img, lines, classifier, hog)





[('b2b3', 'rnbqkbnr/pppppppp/8/8/8/1P6/P1PPPPPP/RNBQKBNR b'),
 ('g8f6', 'rnbqkb1r/pppppppp/5n2/8/8/1P6/P1PPPPPP/RNBQKBNR w'),
 ('c1b2', 'rnbqkb1r/pppppppp/5n2/8/8/1P6/PBPPPPPP/RN1QKBNR b'),
 ('g7g6', 'rnbqkb1r/pppppp1p/5np1/8/8/1P6/PBPPPPPP/RN1QKBNR w'),
 ('e2e4', 'rnbqkb1r/pppppp1p/5np1/8/4P3/1P6/PBPP1PPP/RN1QKBNR b'),
 ('d7d6', 'rnbqkb1r/ppp1pp1p/3p1np1/8/4P3/1P6/PBPP1PPP/RN1QKBNR w'),
 ('f2f4', 'rnbqkb1r/ppp1pp1p/3p1np1/8/4PP2/1P6/PBPP2PP/RN1QKBNR b'),
 ('f8g7', 'rnbqk2r/ppp1ppbp/3p1np1/8/4PP2/1P6/PBPP2PP/RN1QKBNR w'),
 ('e4e5', 'rnbqk2r/ppp1ppbp/3p1np1/4P3/5P2/1P6/PBPP2PP/RN1QKBNR b'),
 ('d6e5', 'rnbqk2r/ppp1ppbp/5np1/4p3/5P2/1P6/PBPP2PP/RN1QKBNR w'),
 ('f4e5', 'rnbqk2r/ppp1ppbp/5np1/4P3/8/1P6/PBPP2PP/RN1QKBNR b'),
 ('f6d5', 'rnbqk2r/ppp1ppbp/6p1/3nP3/8/1P6/PBPP2PP/RN1QKBNR w'),
 ('f1c4', 'rnbqk2r/ppp1ppbp/6p1/3nP3/2B5/1P6/PBPP2PP/RN1QK1NR b'),
 ('e8g8', 'rnbq1rk1/ppp1ppbp/6p1/3nP3/2B5/1P6/PBPP2PP/RN1QK1NR w'),
 ('d1f3', 'rnbq1rk1/ppp1ppbp/6p1/3nP3/2B5/1P3Q2/PBPP2PP/RN2K1NR b'),
 (