# 数梯子之关键区域识别 

In [1]:
%reload_ext watermark
%reload_ext autoreload
# Reload imported modules automatically before each cell executes.
%autoreload 2
# Print interpreter and package versions for reproducibility (see output below).
%watermark -v -p numpy,sklearn,pandas
%watermark -v -p cv2,PIL,matplotlib
%watermark -v -p torch,torchvision,torchaudio
# Inline, high-DPI ("retina") matplotlib figures.
%matplotlib inline
%config InlineBackend.figure_format='retina'
# Disable Jedi-based tab completion.
%config IPCompleter.use_jedi = False

from IPython.display import display, Markdown, HTML, Javascript
# Widen the notebook's `.container` element to 80% of the page width.
display(HTML('<style>.container { width:%d%% !important; }</style>' % 80))

import sys, os, io, time, random, math
import json, base64, requests
import os.path as osp

def _IMPORT_(x):
    try:
        segs = x.split(' ')
        g = globals()
        if 'github.com' in segs[1]:
            uri = segs[1].replace('github.com', 'raw.githubusercontent.com')
            mod = uri.split('/')
            for s in ['main', 'master']:
                uri = 'https://' + '/'.join(mod[:-1]) + '/main/' + mod[-1] + '.py'
                x = requests.get(uri).text
                if x.status == 200:
                    break
        elif 'gitee.com' in segs[1]:
            mod = segs[1].split('/')
            for s in ['/raw/main/', '/raw/master/']:
                uri = 'https://' + '/'.join(mod[:3]) + s + '/'.join(mod[3:]) + '.py'
                x = requests.get(uri).text
                if x.status == 200:
                    break
        elif segs[1][0] == '/':
            with open(segs[1] + '.py') as fr:
                x = fr.read()
        exec(x, g)
    except:
        pass

def print_progress_bar(x):
    """Redraw an in-place console progress bar for percentage `x`.

    Output starts with a carriage return so each call overwrites the previous
    line. The bar is always 50 cells wide: one '▋' per 2% of `x` (clamped to
    0-100%) followed by '.' padding.
    """
    filled = max(0, min(50, int(x) // 2))
    print('\r', end='')
    # Pad from `filled` so odd percentages keep the full 50-cell width.
    print('Progress: {}%:'.format(x), '▋' * filled + '.' * (50 - filled), end='')
    sys.stdout.flush()


CPython 3.6.9
IPython 7.16.1

numpy 1.19.4
sklearn 0.24.0
pandas 1.1.5
CPython 3.6.9
IPython 7.16.1

cv2 4.5.1
PIL 6.2.2
matplotlib 3.3.3
CPython 3.6.9
IPython 7.16.1

torch 1.8.0.dev20210103+cu101
torchvision 0.9.0.dev20210103+cu101
torchaudio not installed


In [2]:
# --- Common -----------------------------------------------------------------
# _IMPORT_ does not raise when an import fails, so a missing package only
# shows up later as a NameError where the name is first used.
for _stmt in (
    'import numpy as np',
    'import pandas as pd',
    'from tqdm.notebook import tqdm',
):
    _IMPORT_(_stmt)

# --- Display ----------------------------------------------------------------
for _stmt in (
    'import cv2',
    'from PIL import Image',
    'from torchvision.utils import make_grid',
    'import matplotlib.pyplot as plt',
    'import plotly',
    'import plotly.graph_objects as go',
    'import ipywidgets as widgets',
):
    _IMPORT_(_stmt)
del _stmt

plt.rcParams['figure.figsize'] = (12.0, 8.0)

def show_table(headers, data, width=900):
    """Render `data` as a Markdown table for display in the notebook.

    Args:
        headers: column titles. A leading and/or trailing ':' sets Markdown
            column alignment (':Name' left, 'Name:' right, ':Name:' centre)
            and is stripped from the caption.
        data: rows, each either a sequence of cell values in column order or a
            dict keyed by the stripped captions. May be empty.
        width: total width in pixels, split evenly across the columns by a
            first row of empty ``<img width=...>`` placeholders.

    Returns:
        ``Markdown`` wrapping the table source.
    """
    ncols = len(headers)
    col_width = int(width / ncols)
    aligns, captions = [], []
    for item in headers:
        align = '---'
        # startswith/endswith instead of item[0]/item[-1]: safe for '' or ':'.
        if item.startswith(':'):
            align = ':' + align
            item = item[1:]
        if item.endswith(':'):
            align += ':'
            item = item[:-1]
        aligns.append(align)
        captions.append(item)

    def cell(value):
        # An unescaped '|' inside a value would start a new column.
        return f'{value}'.replace('|', '\\|')

    lines = [
        '|'.join(captions),
        '|'.join(aligns),
        '|'.join(['<img width=%d/>' % col_width] * ncols),
    ]
    for row in data:
        values = [row[c] for c in captions] if isinstance(row, dict) else row
        lines.append('|'.join(cell(v) for v in values))
    return Markdown(''.join(line + '\n' for line in lines))

def show_video(vidsrc, width=None, height=None):
    """Return an HTML ``<video>`` element that plays `vidsrc` inline.

    Args:
        vidsrc: an ``http...`` URL, used directly as ``src``, or a local file
            path whose bytes are embedded as a base64 ``data:video/mp4`` URL.
        width, height: optional pixel-size attributes; falsy values are omitted.
    """
    W = 'width=%d' % width if width else ''
    H = 'height=%d' % height if height else ''
    if vidsrc.startswith('http'):
        data_url = vidsrc
    else:
        # Context manager so the file handle is closed right after reading.
        with open(vidsrc, 'rb') as fr:
            mp4 = fr.read()
        data_url = 'data:video/mp4;base64,' + base64.b64encode(mp4).decode()
    return HTML('<video %s %s controls src="%s" type="video/mp4"/>' % (W, H, data_url))

def show_image(imgsrc, width=None, height=None):
    """Display an image inline.

    * ``numpy.ndarray``: drawn with matplotlib (gray colormap for 2-D arrays).
      If `width` and/or `height` is given, the array is first resized with
      ``cv2.resize`` (keeping the aspect ratio when only one is given) and a new
      figure sized from the target size is opened. Returns None.
    * ``str``: an ``http...`` URL, a base64 payload (any string longer than
      2048 characters is assumed to be one), or a local file path. Returns an
      HTML ``<img>`` tag; non-URL sources are embedded as ``data:image/jpg``.
    """
    if isinstance(imgsrc, np.ndarray):
        img = imgsrc
        if width or height:
            if width and height:
                size = (width, height)
            else:
                rate = img.shape[1] / img.shape[0]  # width / height
                size = (width, int(width / rate)) if width else (int(height * rate), height)
            img = cv2.resize(img, size)
            plt.figure(figsize=(3 * int(size[0] / 80 + 1), 3 * int(size[1] / 80 + 1)), dpi=80)
        plt.axis('off')
        if len(img.shape) > 2:
            plt.imshow(img)
        else:
            plt.imshow(img, cmap='gray')
        return

    W = 'width=%d' % width if width else ''
    H = 'height=%d' % height if height else ''
    if imgsrc.startswith('http'):
        data_url = imgsrc
    elif len(imgsrc) > 2048:
        data_url = 'data:image/jpg;base64,' + imgsrc
    else:
        # Context manager so the file handle is closed right after reading.
        with open(imgsrc, 'rb') as fr:
            raw = fr.read()
        data_url = 'data:image/jpg;base64,' + base64.b64encode(raw).decode()
    return HTML('<img %s %s src="%s"/>' % (W, H, data_url))

def im_read(url, rgb=True, size=None):
    """Load an image from an ``http...`` URL or a local path as a numpy array.

    Args:
        url: URL (fetched with ``requests`` and decoded with ``cv2.imdecode``)
            or local file path (read with ``cv2.imread``).
        rgb: convert OpenCV's BGR channel order to RGB (default True).
        size: optional target size passed to ``cv2.resize``; an int means a
            square ``(size, size)``, otherwise a ``(width, height)`` pair.

    Returns:
        The image array, or None if it could not be fetched, read or decoded.
    """
    if url.startswith('http'):
        response = requests.get(url)
        if not response:
            return None
        imgmat = np.frombuffer(response.content, dtype=np.uint8)
        img = cv2.imdecode(imgmat, cv2.IMREAD_COLOR)
    else:
        img = cv2.imread(url)

    # imread/imdecode return None for missing files or undecodable data.
    if img is None:
        return None

    if rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if size:
        if isinstance(size, int):
            size = (size, size)
        img = cv2.resize(img, size, interpolation=cv2.INTER_AREA)
    return img

def img2bytes(x, width=None, height=None):
    """Encode an image-like object as bytes.

    Supported inputs:
      * ``bytes``: returned unchanged.
      * ``str``: path to an existing file (opened with PIL, converted to RGB,
        PNG-encoded); any other string is treated as SVG markup and rasterised
        to PNG by ``cairosvg`` at `width` x `height`.
      * matplotlib ``Figure``: that figure rendered with ``Figure.savefig``.
      * ``numpy.ndarray``: cast to uint8, converted to an RGB PIL image.
      * torch ``Tensor``: converted with ``torchvision.transforms.ToPILImage``.
      * PIL ``Image``: PNG-encoded.

    PIL-backed inputs are resized only when both `width` and `height` are given.

    Raises:
        NotImplementedError: for any other input type.
    """
    if isinstance(x, bytes):
        return x

    if isinstance(x, str):
        if os.path.isfile(x):
            # The bare name `PIL` is not bound by the setup cells (only `Image`
            # is), so import the submodule here.
            from PIL import Image
            x = Image.open(x).convert('RGB')
        else:
            import cairosvg
            with io.BytesIO() as fw:
                cairosvg.svg2png(bytestring=x, write_to=fw,
                                 output_width=width, output_height=height)
                return fw.getvalue()

    try:
        from matplotlib.figure import Figure
    except ImportError:  # without matplotlib, x cannot be a Figure
        Figure = None
    if Figure is not None and isinstance(x, Figure):
        with io.BytesIO() as fw:
            # Save *this* figure, not whatever pyplot treats as current.
            x.savefig(fw)
            return fw.getvalue()

    from PIL import Image
    if not isinstance(x, Image.Image):
        if isinstance(x, np.ndarray):
            x = Image.fromarray(x.astype('uint8')).convert('RGB')
        else:
            # torch is optional: ndarray / PIL inputs work without it installed.
            try:
                from torch import Tensor
            except ImportError:
                Tensor = None
            if Tensor is not None and isinstance(x, Tensor):
                from torchvision import transforms
                x = transforms.ToPILImage()(x)

    if isinstance(x, Image.Image):
        if width and height:
            x = x.resize((width, height))
        with io.BytesIO() as fw:
            x.save(fw, format='PNG')
            return fw.getvalue()
    raise NotImplementedError(type(x))

def img2b64(x):
    """Return `x` (anything img2bytes accepts) as a base64-encoded text string."""
    raw = img2bytes(x)
    encoded = base64.b64encode(raw)
    return encoded.decode()


In [3]:
from sklearn.cluster import KMeans
from collections import Counter

# Sample images offered in the interactive demo's dropdown (paths relative to
# the notebook's working directory).
test_samples = [
    './ladder_39.png',
    './632306671_40.jpg',
]

In [4]:
def _paint_boxes(img, boxes):
    """Fill each ``(x, y, w, h, ...)`` row of `boxes` as a solid white rectangle on `img`, in place."""
    for box in boxes:
        x, y, w, h = (int(v) for v in box[:4])
        cv2.rectangle(img, (x, y), (x + w, y + h), (255, 255, 255), thickness=-1)


def extract_blackhole_rect(imgpath, thresh, kernel, iterations, iqr):
    """Locate dark rectangular regions in an image and plot each analysis stage.

    Plots a 4x3 grid:
      * row 0: inverse-binary threshold, its dilation, and the gray image with
        every detected contour filled white;
      * row 1: violin plots of bounding-box width, height and area, overlaid
        with the median, the Q1-Q3 bar and the whisker limits;
      * row 2: for each of those measures, the RGB image with the boxes that
        fall outside the whisker limits painted white;
      * row 3: non-outlier boxes clustered by their left x coordinate (KMeans,
        up to 6 clusters); the 3 most populated clusters are shown, with
        everything outside each cluster's column span greyed out.

    Args:
        imgpath: image path or URL passed to ``im_read``.
        thresh: gray level; pixels at or below it become foreground.
        kernel: dilation element height; its width is ``5 * kernel``.
        iterations: number of dilation iterations.
        iqr: whisker length as a multiple of the interquartile range.

    Side effects:
        Stores the ``[x, y, w, h, w*h]`` box array in the module global ``g_data``.
        Prints a message and returns early if the image cannot be read or no
        contour is found.
    """
    img_rgb = im_read(imgpath)
    if img_rgb is None:
        print(f'Cannot read image: {imgpath}')
        return

    # Gray + inverse binary threshold: pixels <= thresh become 255.
    img_gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
    img_bin = cv2.threshold(img_gray, thresh, 255, cv2.THRESH_BINARY_INV)[1]

    # Dilate with a flat rectangle (5*kernel wide, kernel tall) so horizontally
    # neighbouring fragments merge into one blob.
    structure = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel * 5, kernel))
    img_dilate = cv2.dilate(img_bin, structure, iterations=iterations)

    # [-2] is the contour list under both OpenCV 3 (3-tuple) and OpenCV 4 (2-tuple).
    contours = cv2.findContours(img_dilate, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        print(f'No contours found in {imgpath}')
        return

    # One row per contour: bounding box x, y, w, h plus its area w*h.
    np_data = np.array([cv2.boundingRect(c) for c in contours])
    np_data = np.column_stack((np_data, np_data[:, 2] * np_data[:, 3]))
    global g_data
    g_data = np_data  # NOTE(review): presumably kept for inspection from other cells

    # Box-plot statistics of (width, height, area); whisker limits are
    # Q1 - iqr*IQR / Q3 + iqr*IQR clipped to the observed min / max.
    measures = np_data[:, 2:]
    Q1, medians, Q3 = np.percentile(measures, [25, 50, 75], axis=0)
    IQR = Q3 - Q1
    upper_adjacent = np.clip(Q3 + IQR * iqr, Q3, measures.max(axis=0))
    lower_adjacent = np.clip(Q1 - IQR * iqr, measures.min(axis=0), Q1)
    outliers_mask = np.zeros(len(np_data), dtype=bool)

    _, axes = plt.subplots(nrows=4, ncols=3, figsize=(20, 32))

    # Row 0: one drawContours call fills every contour (index -1 = all).
    img_filled = img_gray.copy()
    cv2.drawContours(img_filled, contours, -1, (255, 255, 255), thickness=-1)
    for ax, stage in zip(axes[0], (img_bin, img_dilate, img_filled)):
        ax.imshow(stage, cmap='gray')

    for i, col, title in zip((0, 1, 2), (2, 3, 4), ('Width', 'Height', 'Area')):
        # Row 1: violin plot with a box-plot overlay (median dot, Q1-Q3 bar, whiskers).
        ax = axes[1][i]
        ax.yaxis.grid(True)
        ax.violinplot(np_data[:, col], showmeans=True, showmedians=False, showextrema=True)
        ax.set_title(f'Violin Plot of {title}')
        ax.scatter(1, medians[i], marker='o', color='white', s=40, zorder=3)
        ax.vlines(1, Q1[i], Q3[i], color='k', linestyle='-', lw=20)
        ax.vlines(1, lower_adjacent[i], upper_adjacent[i], color='k', linestyle='-', lw=3)
        ax.text(1, lower_adjacent[i], f'lower:{lower_adjacent[i]}', horizontalalignment='center')
        ax.text(1, upper_adjacent[i], f'upper:{upper_adjacent[i]}', horizontalalignment='center')

        # Row 2: boxes outside this measure's whisker limits, painted white.
        ax = axes[2][i]
        ax.axis('off')
        ax.set_title(f'{lower_adjacent[i]} >) {title} (> {upper_adjacent[i]}')
        mask = (np_data[:, col] < lower_adjacent[i]) | (np_data[:, col] > upper_adjacent[i])
        outliers_mask |= mask
        img = img_rgb.copy()
        _paint_boxes(img, np_data[mask])
        ax.imshow(img)

    # Row 3: cluster the remaining boxes by left x coordinate.
    valid_data = np_data[~outliers_mask]
    if len(valid_data) == 0:
        return
    # KMeans needs at least as many samples as clusters; a fixed seed makes the
    # same slider settings give the same grouping.
    model = KMeans(n_clusters=min(6, len(valid_data)), max_iter=100, random_state=0)
    clusters = model.fit_predict(valid_data[:, 0].reshape(-1, 1))
    img_w = img_rgb.shape[1]
    pad = 2 * medians[0]  # two median box widths
    for i, (label, count) in enumerate(Counter(clusters).most_common(3)):
        cluster_data = valid_data[clusters == label]
        # Clamp to the image: a negative start would wrap the slice below.
        xmin = max(0, int(cluster_data[:, 0].min() - pad))
        xmax = min(img_w, int(cluster_data[:, 0].max() + pad))

        img = np.zeros(img_rgb.shape, dtype=np.uint8)
        img[:] = [158, 164, 158]
        img[:, xmin:xmax, :] = img_rgb[:, xmin:xmax, :]
        _paint_boxes(img, cluster_data)

        axes[3][i].set_title(f'Count: {count}')
        axes[3][i].imshow(img)
        axes[3][i].axis('off')

# Interactive demo: adjust the parameters, then press the run button to redraw.
widgets.interact_manual(
    extract_blackhole_rect,
    # Dropdown label: file name without directory or extension (any length).
    imgpath=widgets.Dropdown(options=[(osp.splitext(osp.basename(p))[0], p) for p in test_samples]),
    thresh=widgets.IntSlider(min=1, max=60, value=15),
    kernel=widgets.IntSlider(min=1, max=16, value=3),
    iterations=widgets.IntSlider(min=1, max=5, value=1),
    iqr=widgets.FloatSlider(min=0.5, max=2.5, value=1.5)
);

interactive(children=(Dropdown(description='imgpath', options=(('ladder_39', './ladder_39.png'), ('632306671_4…