# Some statistics to choose the input resolution for your InChI predictor

Analysis of train and test images (random sample) in the [Bristol-Myers Squibb – Molecular Translation Competition](https://www.kaggle.com/c/bms-molecular-translation) to find a suitable input image ratio and resolution.

**Credits: I adapted the crop function from https://www.kaggle.com/markwijkhuizen/advanced-image-cleaning-and-tfrecord-generation (great TFRecord kernel!)**

Hi everyone!

I've seen different choices of image resolution and w/h ratio so far: some use squares, some rectangles. I did this analysis to learn more about the images we are given, especially after they are cropped. Note that image width and height are swapped if height > width, for original versions as well as cropped versions.

**In the end you can find a summary with the fractions of images that need to be 'shrinked' after cropping for different input resolutions together with the mean 'shrink factor' and more statistics for each resolution.**

In this summary, width / height ratios of around 2 seem to work best. What ratio and image resolution did you choose as input for your InChI prediction model? What were the reasons?

Feel free to comment below and / or leave a vote if you find this kernel helpful :)

In [None]:
# DEBUG switches the notebook to a small, fast run.
DEBUG = False
# Number of image ids sampled from each of the train and test lists.
IMAGE_NUM = 1000 if DEBUG else 1_300_000
# Number of example images plotted by each visual check.
EXAMPLE_NUM = 2 if DEBUG else 5

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
tqdm.pandas()

import cv2
import imageio
import os
import sys
import re
import seaborn as sns
import time
import random
import pickle
       
# Time-based seed: every run draws a different random sample; the value is
# printed for reference.
SEED = round(time.time())
print(f'SEED: {SEED}')
# NOTE(review): PYTHONHASHSEED is presumably only read at interpreter start-up,
# so setting it here may not affect the running kernel - confirm.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# sorting again after sampling seems to make the file access slightly faster for large IMAGE_NUM
# only the image_id column is used below, so the InChI labels are not parsed at all
train_ids = (
    pd.read_csv('/kaggle/input/bms-molecular-translation/train_labels.csv', usecols=['image_id'], dtype={'image_id': 'string'})
    .sample(n=IMAGE_NUM)
    .sort_values(by='image_id', ignore_index=True)
    .image_id
)
test_ids = (
    pd.read_csv('/kaggle/input/bms-molecular-translation/sample_submission.csv', usecols=['image_id'], dtype={'image_id': 'string'})
    .sample(n=IMAGE_NUM)
    .sort_values(by='image_id', ignore_index=True)
    .image_id
)

# Adapted crop function

The crop function from the original source above was adapted to ignore noise pixels and thus crop the real molecule structure only without removing the noise first.

In [None]:
def crop(img, contour_min_size=2, small_stuff_size=2, small_stuff_dist=5, pad_pixels=1, debug=False, my_figsize=(12,6), horizontal=True):
    """Crop an image to the bounding box of its structure, ignoring isolated noise.

    The image is rotated counter clockwise first if it is taller than wide.
    Contours with fewer than ``contour_min_size`` points are ignored. Contours
    whose bounding box is at most ``small_stuff_size`` pixels in both directions
    do not define the crop box on their own; they only extend it when they lie
    at most ``small_stuff_dist`` pixels outside of it. If no contour defines a
    crop box at all, the (possibly rotated) image is returned uncropped.

    Args:
        img: 2-D image array. Assumes an 8-bit grayscale image with a bright
            structure on a dark background (callers in this file pass
            ``255 - image``) - TODO confirm.
        contour_min_size: minimum number of contour points for a contour to count.
        small_stuff_size: maximum bounding-box width and height of a "small" contour.
        small_stuff_dist: maximum distance (pixels) of a small contour to the
            crop box for the box to be enlarged to include it.
        pad_pixels: padding added around the crop box, clipped to the image.
        debug: if True, plot the original and the cropped image.
        my_figsize: figure size of the debug plot.
        horizontal: debug plot layout, side by side (True) or stacked (False).

    Returns:
        The cropped image as a slice of the (possibly rotated) input.
    """
    # rotate counter clockwise to get horizontal images
    img_h, img_w = img.shape
    if img_h > img_w:
        img = np.rot90(img)

    if debug:
        if horizontal:
            fig, ax = plt.subplots(1, 2, figsize=my_figsize)
        else:
            fig, ax = plt.subplots(2, 1, figsize=my_figsize)
        ax[0].imshow(img)
        ax[0].set_title(f'original image, shape: {img.shape}', size=16)

    _, thresh = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # [-2:] presumably keeps the unpacking valid for OpenCV versions whose
    # findContours returns three values instead of two
    contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2:]

    small_stuff = []
    found_structure = False

    x_min0, y_min0, x_max0, y_max0 = np.inf, np.inf, 0, 0
    for cnt in contours:
        if len(cnt) < contour_min_size:  # ignore contours under contour_min_size pixels
            continue
        x, y, box_w, box_h = cv2.boundingRect(cnt)
        if box_w <= small_stuff_size and box_h <= small_stuff_size:
            # collect position of small contours starting with contour_min_size pixels
            small_stuff.append((x, y, x + box_w, y + box_h))
            continue
        found_structure = True
        x_min0 = min(x_min0, x)
        y_min0 = min(y_min0, y)
        x_max0 = max(x_max0, x + box_w)
        y_max0 = max(y_max0, y + box_h)

    if found_structure:
        x_min, y_min, x_max, y_max = x_min0, y_min0, x_max0, y_max0

        # enlarge the found crop box if it cuts out small stuff that is very close by;
        # min/max make sure a later small contour never shrinks the box again
        for left, top, right, bottom in small_stuff:
            if left < x_min0 and left + small_stuff_dist >= x_min0:
                x_min = min(x_min, left)
            if top < y_min0 and top + small_stuff_dist >= y_min0:
                y_min = min(y_min, top)
            if right > x_max0 and right - small_stuff_dist <= x_max0:
                x_max = max(x_max, right)
            if bottom > y_max0 and bottom - small_stuff_dist <= y_max0:
                y_max = max(y_max, bottom)

        # idea: pad with pad_pixels pixels just in case we cut off
        #       a small part of the structure that is separated by a missing pixel
        if pad_pixels > 0:  # make sure we get the crop within a valid range
            y_min = max(0, y_min - pad_pixels)
            y_max = min(img.shape[0], y_max + pad_pixels)
            x_min = max(0, x_min - pad_pixels)
            x_max = min(img.shape[1], x_max + pad_pixels)
    else:
        # no contour defined a crop box (e.g. blank image): the box would stay at
        # +inf and break the slicing below, so keep the whole image instead
        x_min, y_min, x_max, y_max = 0, 0, img.shape[1], img.shape[0]

    img_cropped = img[y_min:y_max, x_min:x_max]

    if debug:
        ax[1].imshow(img_cropped)
        ax[1].set_title(f'cropped image, shape: {img_cropped.shape}', size=16)
        plt.show()

    return img_cropped

In [None]:
def check_cropping(image_id, folder='train', my_figsize=(12,6), horizontal=True):
    """Plot one competition image next to its cropped version.

    Args:
        image_id: image id; its first three characters select the sub-folders.
        folder: dataset folder below the competition input directory.
        my_figsize: figure size passed on to ``crop``.
        horizontal: plot layout passed on to ``crop``.

    Raises:
        FileNotFoundError: if the image file cannot be read.
    """
    print(f'{folder}/{image_id}')
    file_path =  f'/kaggle/input/bms-molecular-translation/{folder}/{image_id[0]}/{image_id[1]}/{image_id[2]}/{image_id}.png'
    raw = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
        # cv2.imread signals failure by returning None instead of raising
        raise FileNotFoundError(f'could not read image: {file_path}')
    # the image is inverted ('255 -') because the cropping needs it that way
    crop(255 - raw, debug=True, my_figsize=my_figsize, horizontal=horizontal)

# Check cropped train images

In [None]:
# Plot the first EXAMPLE_NUM sampled train images; the list of None results is only kept in `dummy`.
dummy = [check_cropping(image_id, folder='train') for image_id in train_ids[:EXAMPLE_NUM]]

# Check cropped test images

In [None]:
# Plot the first EXAMPLE_NUM sampled test images; the list of None results is only kept in `dummy`.
dummy = [check_cropping(image_id, folder='test') for image_id in test_ids[:EXAMPLE_NUM]]

# Image analysis function

Image width and height are swapped if height > width, for original versions as well as cropped versions.

In [None]:
# Format floats in displayed tables with two decimals.
pd.set_option('display.float_format', lambda x: '%.2f' % x)

def analyse_img_sizes(image_ids, folder='train', plots=False, w_large=500, h_large=250, very_large_factor=1.5):
    """Collect size statistics of the given images before and after cropping.

    For every image the file size and the original and cropped dimensions are
    recorded; the longer side always counts as width. Summary tables are
    displayed for all images, for 'large' images (cropped width > w_large or
    cropped height > h_large) and for 'very large' images (same thresholds
    multiplied by very_large_factor).

    Args:
        image_ids: image ids to analyse.
        folder: dataset folder below the competition input directory.
        plots: if True, additionally draw joint histograms of the 'large' images.
        w_large: cropped-width threshold for 'large' images.
        h_large: cropped-height threshold for 'large' images.
        very_large_factor: multiplier applied to both thresholds for 'very large' images.

    Returns:
        DataFrame with one row per image: image_id, file_size, width, width_crop,
        height, height_crop, area, area_crop, ratio, ratio_crop.
    """
    file_sizes = []
    widths, heights = [], []
    widths_crop, heights_crop = [], []

    for image_id in tqdm(image_ids):
        file_path =  f'/kaggle/input/bms-molecular-translation/{folder}/{image_id[0]}/{image_id[1]}/{image_id[2]}/{image_id}.png'
        file_sizes.append(os.path.getsize(file_path))
        img = 255 - cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)  # '255 -' need for cropping to work

        # sorting the two shape entries puts the longer side last (= width)
        short_side, long_side = sorted(img.shape)
        widths.append(long_side)
        heights.append(short_side)

        short_side_crop, long_side_crop = sorted(crop(img).shape)
        widths_crop.append(long_side_crop)
        heights_crop.append(short_side_crop)

    img_info = pd.DataFrame({
        'image_id': image_ids,
        'file_size': file_sizes,
        'width': widths,
        'width_crop': widths_crop,
        'height': heights,
        'height_crop': heights_crop,
    })

    img_info['area'] = img_info['width'] * img_info['height']
    img_info['area_crop'] = img_info['width_crop'] * img_info['height_crop']
    img_info['ratio'] = img_info['width'] / img_info['height']
    img_info['ratio_crop'] = img_info['width_crop'] / img_info['height_crop']

    w_very_large = very_large_factor * w_large
    h_very_large = very_large_factor * h_large

    is_large = (img_info['width_crop'] > w_large) | (img_info['height_crop'] > h_large)
    is_very_large = (img_info['width_crop'] > w_very_large) | (img_info['height_crop'] > h_very_large)
    img_info_large = img_info.loc[is_large, :]
    img_info_very_large = img_info.loc[is_very_large, :]

    print('statistics for all images')
    display(img_info.describe())
    print()
    print(f"statistics for 'large' images with cropped width > {w_large} or height > {h_large} ({len(img_info_large) / len(img_info) * 100:.3}%):")
    display(img_info_large.describe())
    print()
    print(f"statistics for 'very large' images with cropped width > {w_very_large} or height > {h_very_large} ({len(img_info_very_large) / len(img_info) * 100:.3}%):")
    display(img_info_very_large.describe())

    if plots:
        print()
        print("plots for 'large' and 'very large' images only")
        column_pairs = (
            ('width', 'height'),
            ('file_size', 'area'),
            ('width', 'width_crop'),
            ('height', 'height_crop'),
            ('width_crop', 'height_crop'),
            ('ratio', 'ratio_crop'),
            ('area_crop', 'ratio_crop'),
        )
        for x_col, y_col in column_pairs:
            sns.jointplot(data=img_info_large, x=x_col, y=y_col, kind='hist')

    return img_info

# Train image statistics 

In [None]:
# Analyse the sampled train images (plots are skipped in DEBUG mode) ...
train_img_info = analyse_img_sizes(train_ids, folder='train', plots=not DEBUG)

# ... and store the per-image table in the working directory.
with open('train_img_info.pkl', 'wb') as handle:
    pickle.dump(train_img_info, handle)

# Test image statistics 

In [None]:
# Analyse the sampled test images (plots are skipped in DEBUG mode) ...
test_img_info = analyse_img_sizes(test_ids, folder='test', plots=not DEBUG)

# ... and store the per-image table in the working directory.
with open('test_img_info.pkl', 'wb') as handle:
    pickle.dump(test_img_info, handle)

# Images with extremely low width or height

There are some images with extremely low height after cropping. Checking whether the crop function made a mistake... Seems legit.

In [None]:
def plot_extreme_images(img_info, folder='train', my_figsize = (20, 10)):
    """Plot the EXAMPLE_NUM images with the smallest cropped height and cropped width.

    Args:
        img_info: table with 'image_id', 'width_crop' and 'height_crop' columns.
        folder: dataset folder the images are read from.
        my_figsize: figure size of each plot.
    """
    narrowest = img_info.sort_values(by='width_crop', ignore_index=True).head(EXAMPLE_NUM)
    flattest = img_info.sort_values(by='height_crop', ignore_index=True).head(EXAMPLE_NUM)

    print('very low height images (after swapping if height > width)')
    for image_id in flattest.image_id:
        # flat images are easier to compare when stacked vertically
        check_cropping(image_id, folder=folder, my_figsize=my_figsize, horizontal=False)

    print('very low width images (after swapping if height > width)')
    for image_id in narrowest.image_id:
        check_cropping(image_id, folder=folder, my_figsize=my_figsize)

# Extreme cases among the sampled train images.
plot_extreme_images(train_img_info)

In [None]:
# Extreme cases among the sampled test images.
plot_extreme_images(test_img_info, folder='test')

# Find best input resolution

In [None]:
# Candidate width / height ratios of the model input.
input_ratios = [1, 1.25, 1.5, 1.75, 1.9, 2, 2.1, 2.25, 2.5]

def get_res(pixels, ratio):
    """Return a (width, height) tuple with about `pixels` pixels and width / height close to `ratio`.

    Both values are rounded to whole pixels, so the resulting pixel count and
    ratio only approximate the requested ones.
    """
    side = pixels ** 0.5      # side length of the square with the same area
    stretch = ratio ** 0.5    # the square is stretched by this factor in width ...
    width = round(side * stretch)
    height = round(side / stretch)  # ... and squeezed by it in height
    return (width, height)

# One candidate (width, height) per ratio for each pixel budget; the budgets
# are given as the familiar resolutions with the same number of pixels.
input_sizes = []
for base_pixels in (320*320, 448*256, 512*256, 384*384):
    input_sizes += [get_res(base_pixels, r) for r in input_ratios]

# Actual pixel counts and ratios after rounding; input_ratios is replaced by
# one effective ratio per candidate resolution.
pixels = [width * height for width, height in input_sizes]
input_ratios = [width / height for width, height in input_sizes]

def calc_shrink_factors(current_width, current_height, input_size):
    """Return the factor by which an image must be shrunk to fit into input_size.

    Args:
        current_width: image width in pixels.
        current_height: image height in pixels.
        input_size: (width, height) of the target input.

    Returns:
        1 if the image already fits, otherwise the larger of the two ratios
        current / input for width and height (always > 1 in that case).
    """
    if current_width <= input_size[0] and current_height <= input_size[1]:
        return 1
    # the dimension that exceeds the input the most decides how much to shrink
    return max(current_width / input_size[0], current_height / input_size[1])

def check_resolutions(img_info):
    """Summarise, for every candidate input resolution, how much the images must be shrunk.

    Args:
        img_info: table with 'width_crop' and 'height_crop' columns.

    Returns:
        DataFrame with one row per entry of the global ``input_sizes``:
        resolution, pixels, input_ratio, the fraction of images with a shrink
        factor > 1, mean and root mean square of the shrink factor over all
        images, and the same two statistics over the images with factor > 1
        (NaN if there are none).
    """
    # use the table that was passed in, not a global one, so that train and
    # test statistics really differ
    widths = img_info['width_crop'].to_numpy()
    heights = img_info['height_crop'].to_numpy()

    mean_shrink_factors = []  # mean shrink factor over all images (1 if image fits already)
    rms_shrink_factors = []  # root mean square
    mean_shrink_factors_over_1 = []
    rms_shrink_factors_over_1 = []
    fraction_shrinked = []

    for input_size in input_sizes:
        shrink_factors = np.array([calc_shrink_factors(w, h, input_size) for w, h in zip(widths, heights)])
        mean_shrink_factors.append(np.mean(shrink_factors))
        rms_shrink_factors.append(np.mean(shrink_factors**2)**0.5)

        shrunk = shrink_factors > 1
        fraction_shrinked.append(np.mean(shrunk))

        over_1 = shrink_factors[shrunk]
        if over_1.size:
            mean_shrink_factors_over_1.append(np.mean(over_1))
            rms_shrink_factors_over_1.append(np.mean(over_1**2)**0.5)
        else:
            # no image needs shrinking: report NaN without numpy's empty-mean warning
            mean_shrink_factors_over_1.append(np.nan)
            rms_shrink_factors_over_1.append(np.nan)

    return pd.DataFrame({
        'resolution': input_sizes,
        'pixels': pixels,
        'input_ratio': input_ratios,
        'frac_shrinked': fraction_shrinked,
        'mean_shr_factor': mean_shrink_factors,
        'rms_shr_factor': rms_shrink_factors,
        'mean_shr_fac_over_1': mean_shrink_factors_over_1,
        'rms_shr_fac_over_1': rms_shrink_factors_over_1,
    })

# Format floats in the summary tables with three decimals.
pd.set_option('display.float_format', lambda x: '%.3f' % x)

# Resolution summary for the sampled train images ...
print('train images')
display(check_resolutions(train_img_info))

# ... and for the sampled test images.
print('test images')
display(check_resolutions(test_img_info))