Creating Training Set Data

Code based on https://towardsdatascience.com/board-game-image-recognition-using-neural-networks-116fc876dafa

In [None]:
# Download the render archive from Google Drive by its file id.
!gdown --id 1F_goGWD-dFSEtUlKxC7Y8GbKx7K7cKRM
# Extract quietly into the current working directory.
# NOTE(review): later cells read from ./test-1k/ -- presumably the folder this archive unpacks to; confirm.
!unzip -q renders-1k.zip
# Create the output tree: one folder per piece class (b*/w*), plus 'blank' for empty squares and 'all' for every crop.
!mkdir -p -- training_renders_cropped/all training_renders_cropped/bb training_renders_cropped/bk training_renders_cropped/bn training_renders_cropped/bp training_renders_cropped/bq training_renders_cropped/br training_renders_cropped/wb training_renders_cropped/wk training_renders_cropped/wn training_renders_cropped/wp training_renders_cropped/wq training_renders_cropped/wr training_renders_cropped/blank

Downloading...
From: https://drive.google.com/uc?id=1F_goGWD-dFSEtUlKxC7Y8GbKx7K7cKRM
To: /content/renders-1k.zip
100% 939M/939M [00:10<00:00, 88.0MB/s]


In [None]:
import glob
import math
import cv2
import numpy as np
import scipy.spatial as spatial
import scipy.cluster as cluster
from collections import defaultdict
from statistics import mean
from google.colab.patches import cv2_imshow
from numpy.polynomial import polynomial as P
import os
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.cluster import AgglomerativeClustering, DBSCAN
import typing
import matplotlib.pyplot as plt
import json

In [None]:
def print_lines(img, lines, color=(0,0,255)):
    """Draw every (rho, theta) line onto ``img`` and display the result.

    Args:
        img: Image to draw on; it is modified in place by ``cv2.line``.
        lines: Iterable of entries whose first two items are rho and theta.
        color: Colour tuple passed to ``cv2.line``.
    """
    for entry in lines:
        rho, theta = entry[0], entry[1]
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        # Point on the line closest to the origin.
        x0, y0 = cos_t * rho, sin_t * rho
        # Offset of 1000 px along the line direction (-sin, cos).
        dx, dy = 1000 * (-sin_t), 1000 * (cos_t)
        start = (int(x0 + dx), int(y0 + dy))
        end = (int(x0 - dx), int(y0 - dy))
        cv2.line(img, start, end, color, 2)
    cv2_imshow(img)

def print_points(img, points, c=(255,0,0)):
    """Draw a filled radius-5 circle at every point and display the image.

    Args:
        img: Image to draw on.
        points: Iterable of circle centres, passed unchanged to ``cv2.circle``.
        c: Colour tuple for the circles.
    """
    canvas = img
    for center in points:
        canvas = cv2.circle(canvas, center, radius=5, color=c, thickness=-1)
    cv2_imshow(canvas)

# Read image and do lite image processing
def read_img(file):
    """Load an image, rescale it to 1000 px wide and build an inverted binary mask.

    Args:
        file: Path of the image; converted with ``str()`` before loading.

    Returns:
        ``(img, im_bw)``: the resized colour image and the result of an
        inverse binary threshold (cut-off 80) applied to its grayscale version.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None`` for the path
            (missing or undecodable file).
    """
    img = cv2.imread(str(file), 1)
    if img is None:
        # Fail with the offending path instead of an AttributeError on `.shape`.
        raise FileNotFoundError("cv2.imread could not read image: " + str(file))

    target_width = 1000
    height, width = img.shape[:2]
    scale = target_width / width
    img = cv2.resize(img, (int(width * scale), int(height * scale)))

    # The previous bilateral-filter pass was removed: its result was never
    # used (thresholding always ran on the unblurred image), so it only cost time.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, im_bw = cv2.threshold(gray, 80, 255, cv2.THRESH_BINARY_INV)

    return img, im_bw

# Canny edge detection
def canny_edge(img, sigma=0.33):
    """Run Canny with thresholds derived from the image median.

    Args:
        img: Input image handed to ``cv2.Canny``.
        sigma: Fractional band around the median used for the two thresholds.

    Returns:
        The edge map produced by ``cv2.Canny``.
    """
    median_intensity = np.median(img)
    # Thresholds sit (1 -/+ sigma) around the median, clamped to [0, 255].
    low = max(0, (1.0 - sigma) * median_intensity)
    high = min(255, (1.0 + sigma) * median_intensity)
    return cv2.Canny(img, int(low), int(high))

# Hough line detection
def hough_line(edges, img, min_line_length=100, max_line_gap=30, threshold = 90):
    """Detect lines in an edge map and return them as an (N, 2) array.

    Args:
        edges: Edge image passed to ``cv2.HoughLines``.
        img: Unused; kept so existing callers keep working.
        min_line_length: Forwarded as the 5th positional argument.
        max_line_gap: Forwarded as the 6th positional argument.
        threshold: Accumulator threshold for ``cv2.HoughLines``.

    Returns:
        Array of shape (N, 2) holding ``[rho, theta]`` rows; shape (0, 2)
        when no line is detected.

    NOTE(review): ``cv2.HoughLines`` has no min-length / max-gap parameters
    (those belong to ``cv2.HoughLinesP``). Per the OpenCV signature the 5th
    and 6th positional arguments are ``lines`` and ``srn``, so
    ``max_line_gap`` appears to act as ``srn``. The call is left untouched
    because downstream thresholds were tuned against it -- confirm intent.
    """
    lines = cv2.HoughLines(edges, 1.5, np.pi / 110, threshold, min_line_length, max_line_gap)

    if lines is None:
        # OpenCV returns None when nothing is found; reshaping None would raise.
        return np.empty((0, 2), dtype=np.float32)

    return np.reshape(lines, (-1, 2))

def merge_lines(lines, img, rho_thresh=14):
    """Keep one representative per group of near-duplicate lines.

    Each line is first normalised so rho is non-negative (theta is shifted
    by -pi when rho was negative). Lines whose theta is outside both the
    near-horizontal band [4pi/9, 5pi/9] and the near-vertical band
    [-pi/9, pi/9] are dropped. A remaining line is kept only if its rho is
    not within ``rho_thresh`` (``np.isclose`` with ``atol``) of a line that
    was already kept. Theta is not compared when merging.

    Args:
        lines: Sequence of ``[rho, theta]`` pairs. It is copied, so the
            caller's data is never modified.
        img: Unused; kept so existing callers keep working.
        rho_thresh: Absolute rho tolerance for treating two lines as the same.

    Returns:
        Float array of shape (K, 2) with the kept ``[rho, theta]`` rows, in
        input order; shape (0, 2) when nothing is kept.
    """
    # Work on a private copy: normalising rho/theta must not leak to the caller.
    candidates = np.array(lines, dtype=float).reshape(-1, 2)
    strong_lines = np.zeros_like(candidates)
    kept = 0

    for rho, theta in candidates:
        if rho < 0:
            rho = -rho
            theta -= np.pi

        near_horizontal = (4/9*np.pi) <= theta <= (5/9*np.pi)
        near_vertical = (-np.pi/9) <= theta <= (np.pi/9)
        if not (near_horizontal or near_vertical):
            continue

        # With nothing kept yet the comparison is against an empty slice,
        # so the first admissible line is always accepted.
        if np.any(np.isclose(rho, strong_lines[:kept, 0], atol=rho_thresh)):
            continue

        strong_lines[kept] = (rho, theta)
        kept += 1

    return strong_lines[:kept]


# Separate line into horizontal and vertical
def h_v_lines(lines):
    """Split (rho, theta) lines into horizontal and vertical groups.

    A line counts as vertical when theta is below pi/4 or above 3pi/4;
    every other line counts as horizontal.

    Args:
        lines: Iterable of ``(rho, theta)`` pairs.

    Returns:
        ``(h_lines, v_lines)``, each a list of ``[rho, theta]`` lists.
    """
    horizontal, vertical = [], []
    for rho, theta in lines:
        is_vertical = theta < np.pi / 4 or theta > np.pi - np.pi / 4
        bucket = vertical if is_vertical else horizontal
        bucket.append([rho, theta])
    return horizontal, vertical

# Find the intersections of the lines
def line_intersections(h_lines, v_lines):
    """Intersect every horizontal line with every vertical line.

    Each pair is solved as the 2x2 linear system
    ``x*cos(theta) + y*sin(theta) = rho`` for both lines.

    Args:
        h_lines: Iterable of ``(rho, theta)`` pairs.
        v_lines: Iterable of ``(rho, theta)`` pairs.

    Returns:
        Array of ``[x, y]`` solutions ordered by horizontal line, then
        vertical line.
    """
    def _intersect(rho_a, theta_a, rho_b, theta_b):
        coeffs = np.array([[np.cos(theta_a), np.sin(theta_a)],
                           [np.cos(theta_b), np.sin(theta_b)]])
        rhs = np.array([rho_a, rho_b])
        return np.linalg.solve(coeffs, rhs)

    return np.array([
        _intersect(rho_h, theta_h, rho_v, theta_v)
        for rho_h, theta_h in h_lines
        for rho_v, theta_v in v_lines
    ])


# Hierarchical cluster (by euclidean distance) intersection points
def cluster_points(points, thresh=15):
    """Merge nearby points into cluster centroids.

    Points are grouped with single-linkage hierarchical clustering, cut at
    distance ``thresh``; each group is replaced by the mean of its x and y
    coordinates.

    Args:
        points: Array of ``[x, y]`` rows.
        thresh: Distance threshold passed to ``fcluster``.

    Returns:
        List of ``(x, y)`` centroid tuples sorted by y, then x.
    """
    linkage = cluster.hierarchy.single(spatial.distance.pdist(points))
    labels = cluster.hierarchy.fcluster(linkage, thresh, 'distance')

    members = defaultdict(list)
    for label, point in zip(labels, points):
        members[label].append(point)

    centroids = []
    for group in members.values():
        stacked = np.array(group)
        centroids.append((np.mean(stacked[:, 0]), np.mean(stacked[:, 1])))

    centroids.sort(key=lambda c: [c[1], c[0]])
    return centroids


# Average the y value in each row and augment original point
def augment_points(points, img):
    """Snap each row of 11 grid points to a common y and validate the grid size.

    ``points`` is consumed in consecutive chunks of 11; every point in a
    chunk keeps its x and receives the chunk's mean y. Leftover points that
    do not fill a whole chunk are ignored.

    Args:
        points: Sequence of ``(x, y)`` pairs, expected in row-major order.
        img: Unused; kept so existing callers keep working.

    Returns:
        ``(augmented_points, discard)``: the adjusted points sorted by y then
        x, and a bool that is True unless exactly 11 * 11 points were given.
    """
    grid_size = 11
    num_points = len(points)

    # Require the exact grid. The old `int(n / 11) != 11` test let 122-131
    # points through and silently dropped the surplus, which shifts rows and
    # mislabels every crop taken from that image.
    discard = num_points != grid_size * grid_size

    augmented_points = []
    for row in range(num_points // grid_size):
        row_points = points[row * grid_size:(row + 1) * grid_size]
        y_mean = mean([y for _, y in row_points])
        augmented_points.extend((x, y_mean) for x, _ in row_points)

    augmented_points = sorted(augmented_points, key=lambda k: [k[1], k[0]])
    return augmented_points, discard

def get_coord_distance(p1, p2):
    """Return the Euclidean distance between two points after truncating
    each coordinate to an integer.

    Args:
        p1: First point; items 0 and 1 are x and y.
        p2: Second point; items 0 and 1 are x and y.
    """
    dx = int(p1[0]) - int(p2[0])
    dy = int(p1[1]) - int(p2[1])
    return math.sqrt(dx * dx + dy * dy)

def augment_fen(file_name):
    """Turn the FEN encoded in an image file name into per-column label strings.

    The base name of ``file_name`` is cut at the first space. Letters and
    '-' are kept, each ASCII digit ``n`` expands to ``n`` '0' characters and
    every other character is ignored. The text is split on '-' into rows,
    the rows are transposed (output string ``i`` collects character ``i`` of
    every row) and each output string is reversed.

    Args:
        file_name: Path whose base name starts with a '-'-separated FEN.
            Any directory depth is accepted.

    Returns:
        List with one string per row found in the FEN.

    Raises:
        IndexError: If a row is shorter than the number of rows.
    """
    # basename instead of split("/")[2]: the latter only worked for paths
    # exactly two directories deep (e.g. "./folder/name").
    fen = os.path.basename(file_name).split(" ")[0]

    pieces = []
    for c in fen:
        if c.isalpha() or c == '-':
            pieces.append(c)
        elif c in "0123456789":
            # A digit is a run of that many empty squares.
            pieces.append('0' * int(c))

    rows = "".join(pieces).split("-")
    size = len(rows)

    # Transpose rows into columns, then flip each column.
    columns = ["".join(row[i] for row in rows) for i in range(size)]
    return [column[::-1] for column in columns]

def augment_fen2(file_name, json_dir='./test-1k/'):
    """Build per-row label strings from the JSON sidecar of an image.

    The JSON file is ``<json_dir>/<stem>.json`` where ``stem`` is the base
    name of ``file_name`` up to its first '.'. Its ``'fen'`` value is cut at
    the first space, split on '/', and in each row letters are kept while
    each ASCII digit ``n`` expands to ``n`` '0' characters.

    Args:
        file_name: Image path; only its base name is used.
        json_dir: Directory holding the JSON files. Defaults to the
            previously hard-coded './test-1k/'.

    Returns:
        The expanded rows as a list of strings. When the JSON
        ``'white_turn'`` value is truthy, every row is reversed and the row
        order is reversed as well; otherwise rows are returned as parsed.
    """
    num = os.path.basename(file_name).split(".")[0]

    # Context manager so the handle is closed even if parsing fails.
    with open(os.path.join(json_dir, str(num) + '.json')) as f:
        data = json.load(f)

    # Only the piece-placement field; trailing FEN fields (side to move,
    # castling, ...) would otherwise be appended to the last row.
    placement = data['fen'].split(" ")[0]

    expanded_fen = []
    for row in placement.split('/'):
        cells = []
        for c in row:
            if c.isalpha():
                cells.append(c)
            elif c in "0123456789":
                cells.append('0' * int(c))
        expanded_fen.append("".join(cells))

    if not data['white_turn']:
        return expanded_fen
    return [row[::-1] for row in expanded_fen][::-1]

def find_base_len(img, points):
    """Average horizontal spacing between neighbouring grid points.

    Rows are visited from index ``len(points) - 14`` upwards in steps of 11;
    in each row eight consecutive indices are walked in descending order
    and the absolute difference of the truncated x coordinates of
    ``points[s]`` and ``points[s + 1]`` is accumulated.

    Args:
        img: Unused; kept so existing callers keep working.
        points: Sequence of ``(x, y)`` grid points in row-major order.
            NOTE(review): the index arithmetic assumes 11 points per row.

    Returns:
        The rounded mean spacing as an int, or 0 when no pair was measured.
    """
    num_points = np.shape(points)[0]
    start_point = num_points - 14

    if num_points // 11 >= 8:
        range_num = 8
    else:
        range_num = int((num_points / 11) - 2)

    total_len = 0
    measured = 0
    for row in range(range_num):
        row_start = start_point - (row * 11)
        for s in range(row_start, row_start - 8, -1):
            total_len += abs(int(points[s][0]) - int(points[s + 1][0]))
            measured += 1

    if measured == 0:
        return 0

    # Divide by the number of squares actually measured; a fixed 64 skewed
    # the average whenever fewer than eight rows were available.
    return round(total_len / measured)

# Crop board into separate images
def write_crop_images(img, points, counts, aug_fen, avg_base_len):
    """Cut one tile per board square out of ``img`` and save it under its label.

    For each selected grid point (the tile's bottom-left corner) a region
    ``avg_base_len`` wide and twice as tall is cropped, resized to 46x92
    and written twice: into the label's own folder and into ``all/``.

    Args:
        img: Source image to crop from.
        points: Grid points in row-major order.
            NOTE(review): the index arithmetic assumes 11 points per row.
        counts: Dict of per-label counters; updated in place.
        aug_fen: Sequence of label strings, indexed by visited row and then
            by position inside that row.
        avg_base_len: Tile width in pixels.

    Returns:
        The same ``counts`` dict that was passed in.
    """
    num_points = list(np.shape(points))[0]
    anchor = num_points - 14

    if int(num_points / 11) >= 8:
        rows_to_crop = 8
    else:
        rows_to_crop = int((num_points / 11) - 2)

    # Per row: eight descending indices of bottom-left corner points.
    index_rows = [
        range(anchor - (r * 11), (anchor - 8) - (r * 11), -1)
        for r in range(rows_to_crop)
    ]

    for row_no, corner_indices in enumerate(index_rows):
        labels = aug_fen[row_no]
        for col_no, corner_idx in enumerate(corner_indices):
            corner = points[corner_idx]

            # Box spans one base length to the right and two upwards.
            left = int(corner[0])
            top = int(corner[1] - (avg_base_len * 2))
            right = int(left + avg_base_len)
            bottom = int(top + 2*avg_base_len)

            # Clamp only after the far edges were derived from the raw origin.
            top = max(top, 0)
            left = max(left, 0)

            tile = cv2.resize(img[top: bottom, left: right], (46, 92))

            symbol = labels[col_no]
            if symbol.islower():
                label = "b" + symbol
            elif symbol.isupper():
                label = "w" + symbol.lower()
            elif symbol == '0':
                label = "blank"
            else:
                label = symbol
            counts[label] += 1

            tile_name = str(label) + '-' + str(counts[label]) + '.jpeg'
            cv2.imwrite('./training_renders_cropped/' + str(label) + '/' + tile_name, tile)
            cv2.imwrite('./training_renders_cropped/all/' + tile_name, tile)
    return counts

In [None]:
# Create a list of image file names
folder_name = './test-1k/*'
# glob returns entries in a filesystem-dependent order; sorting keeps the
# per-class numbering of the written crops reproducible across runs.
# Full paths (directory included) are kept on purpose.
img_filename_list = sorted(glob.glob(folder_name))

In [None]:
# Per-label counters; they drive the numbering of the written crop files.
counts = dict.fromkeys(
    ['bp', 'br', 'bn', 'bb', 'bq', 'bk',
     'wp', 'wr', 'wn', 'wb', 'wq', 'wk',
     'blank'],
    0,
)

print_number = 0  # images that were actually cropped
image_number = 0  # images that went through the pipeline

# NOTE(review): labels come from augment_fen (FEN parsed from the file name);
# augment_fen2 (FEN read from a JSON sidecar) is defined but not used here --
# confirm which one matches the files in ./test-1k/.
for file_name in img_filename_list:
    if not file_name.endswith(('.png', '.jpg')):
        continue
    print(file_name)

    # Line detection on the thresholded image.
    img, gray_blur = read_img(file_name)
    edges = canny_edge(gray_blur)
    lines = hough_line(edges, img)

    # De-duplicate each orientation separately.
    h_lines, v_lines = h_v_lines(lines)
    h_lines = merge_lines(h_lines, img)
    v_lines = merge_lines(v_lines, img)

    # Grid corners from the line intersections.
    intersection_points = line_intersections(h_lines, v_lines)
    points = cluster_points(intersection_points)
    points, discard = augment_points(points, img)

    augmented_fen = augment_fen(file_name)

    if not discard:
        base_len = find_base_len(img, points)
        counts = write_crop_images(img, points, counts, augmented_fen, base_len)
        print_number += 1
        print('NUM_IMAGES_CROPPED: ' + str(print_number))

    image_number += 1
    print('NUM_IMAGES_PROCESSED: ' + str(image_number))

print(print_number)