In [5]:
import os
import numpy as np
import json 
import glob
from itertools import chain

import matplotlib.pyplot as plt
from PIL import Image

import cv2 
import easyocr

from patchify import patchify, unpatchify

In [6]:
ocrReader = easyocr.Reader(['en']) # this needs to run only once to load the model into memory

In [7]:
filePath = '/home/shared/DARPA/training'

legendFolder = '/home/shared/DARPA/cutted_legend_label'

mapPath = os.path.join(filePath, 'CA_Dubakella.tif')
jsonPath = os.path.join(filePath, 'CA_Dubakella.json')

In [17]:
def cut_legend_from_patched_img(mapPath, jsonPath):
    """
    mapPath = os.path.join(filePath, 'VA_Lahore_bm.tif')
    jsonPath = os.path.join(filePath, 'VA_Lahore_bm.json')
    """
    # read-in map image and then patchify
    map_im =  cv2.imread(mapPath)
    map_im_dims = map_im.shape

    patch_dims = (256,256)
    patch_overlap = 32
    patch_step = patch_dims[1]-patch_overlap

    # To patchify, the (width - patch_width) mod step_size = 0
    shift_x = (map_im.shape[0]-patch_dims[0])%patch_step
    shift_y = (map_im.shape[1]-patch_dims[1])%patch_step
    shift_x_left = shift_x//2
    shift_x_right = shift_x - shift_x_left
    shift_y_left = shift_y//2
    shift_y_right = shift_y - shift_y_left

    shift_coord =  [shift_x_left, shift_x_right, shift_y_left, shift_y_right] # the number of pixels that are cutted at 4 directions

    map_im_cut = map_im[shift_x_left:map_im.shape[0]-shift_x_right, shift_y_left:map_im.shape[1]-shift_y_right,:]
    map_patchs = patchify(map_im_cut, (*patch_dims,3), patch_step)

    # read-in json legend 
    with open(jsonPath, 'r') as f:
        jsonData = json.load(f)

    polyLegendList = [x['label'].split('_')[0] for x in jsonData['shapes'] if x['label'].endswith('_poly')]

    # low and upper bound for black color in HSV
    lower_val = np.array([0,0,10])
    upper_val = np.array([256,256,100])

    found_Legend = []

    for i in range(map_patchs.shape[0]):
        for j in range(map_patchs.shape[1]):

            if len(found_Legend) == len(polyLegendList):
                break 

            # filter image with only black color
            hsv = cv2.cvtColor(map_patchs[i][j][0], cv2.COLOR_BGR2HSV)

            mask = cv2.inRange(hsv, lower_val, upper_val) # Threshold the HSV image to get only black colors
            res = cv2.bitwise_not(mask) # invert the mask to get black letters on white background

            ocrResult = ocrReader.readtext(res)

            if not(ocrResult):
                continue

            for corner_coord, legend_name, _ in ocrResult:
                if legend_name in polyLegendList and legend_name not in found_Legend:
                    found_Legend.append(legend_name)

                    shift_pixel  = 10
                    x_low, x_hi = int(max(0,corner_coord[0][0]-shift_pixel)), int(min(corner_coord[1][0]+shift_pixel, patch_dims[0]))
                    y_low, y_hi = int(max(0,corner_coord[0][1]-shift_pixel)), int(min(corner_coord[2][1]+shift_pixel, patch_dims[1]))

                    # print(y_low, y_hi, x_low, x_hi)
                    im_crop = map_patchs[i][j][0][y_low:y_hi, x_low:x_hi] # need to resize

                    im_crop_resize = cv2.resize(im_crop, dsize=patch_dims, interpolation=cv2.INTER_CUBIC)
                    im = Image.fromarray(im_crop_resize)


                    fileName = mapPath.split('/')[-1].split('.')[0]
                    file = fileName+'_'+legend_name+'_poly.png'

                    cv2.imwrite(os.path.join(legendFolder, file), im_crop_resize)

    # if not all legend can be recognised in OCR, then cut the legend from legend
    if len(found_Legend) != len(polyLegendList):
        diffLegend = list(set(polyLegendList) - set(found_Legend))

        for label_dict in jsonData['shapes']:
            legendname = label_dict['label'].split('_')[0]
            if legendname in diffLegend:

                found_Legend.append(legendname)

                point_coord = label_dict['points']

                if not point_coord: raise Exception("The provided legend does not exist: ", filename)
                flatten_list = list(chain.from_iterable(point_coord))

                if point_coord[0][0] >= point_coord[1][0] or point_coord[0][1] >= point_coord[1][1]:
                    print("Coordinate right is less than left:  ", legendname, point_coord)
                    x_low = int(min([x[0] for x in point_coord]))
                    x_hi = int(max([x[0] for x in point_coord]))
                    y_low = int(min([x[1] for x in point_coord]))
                    y_hi = int(max([x[1] for x in point_coord]))
                elif (len(flatten_list)!=4):
                    x_coord = [x[0] for x in point_coord]
                    y_coord = [x[1] for x in point_coord]
                    x_low, y_low, x_hi, y_hi = int(min(x_coord)), int(min(y_coord)), int(max(x_coord)), int(max(y_coord))
                    # print("Point Coordinates number is not 4: ", filename, legend)
                else: x_low, y_low, x_hi, y_hi = [int(x) for x in flatten_list]
                legend_coor =  [(x_low, y_low), (x_hi, y_hi)]
                shift_pixel  = 3
                im_crop = map_im[y_low+shift_pixel:y_hi-shift_pixel, x_low+shift_pixel:x_hi-shift_pixel] # need to resize
                im_crop_resize = cv2.resize(im_crop, dsize=patch_dims, interpolation=cv2.INTER_CUBIC)

                fileName = mapPath.split('/')[-1].split('.')[0]
                file = fileName+'_'+label_dict['label']+'.png'
                cv2.imwrite(os.path.join(legendFolder, file), im_crop_resize)

    if len(found_Legend) != len(polyLegendList):        
        print('something is wrong')

In [14]:
allJson = glob.glob(os.path.join(filePath,"*.json"))

In [17]:
for jsonPath in allJson[2:4]:
    mapPath = jsonPath.split('.')[0]+'.tif'
    print(mapPath)
    cut_legend_from_patched_img(mapPath, jsonPath)

/home/shared/DARPA/training/CO_HandiesPeak_451002_1955_24000_geo_mosaic.tif
/home/shared/DARPA/training/AK_Christian.tif
