In [None]:
from pathlib import Path
import os
# Interim data is expected in <parent of the working directory>/data/interim.
DATA_PATH = Path.cwd().parent.joinpath("data", "interim")

In [None]:
import pandas

# Locations of the defect-encoding table and the training images inside DATA_PATH.
DEFECTS_ENCODING_PATH = DATA_PATH / "train.csv"
TRAINING_IMAGES_PATH = DATA_PATH / "train_images"

defects = pandas.read_csv(DEFECTS_ENCODING_PATH)
# os.listdir returns entries in arbitrary order; sort them so that anything
# walking this list (e.g. picking the first matching sample) is reproducible.
all_images_names = sorted(os.listdir(TRAINING_IMAGES_PATH))
print(defects.columns)
defects.head(6)

In [None]:
import cv2, matplotlib.pyplot as plt, numpy as np

def load_image(name):
    """Read one training image by file name.

    Parameters
    ----------
    name : str or path-like
        File name relative to ``TRAINING_IMAGES_PATH``.

    Returns
    -------
    Whatever ``cv2.imread`` returns for that path.
    NOTE(review): OpenCV reports an unreadable file by returning ``None``
    instead of raising -- confirm callers are fine with that.
    """
    return cv2.imread(str(TRAINING_IMAGES_PATH / name))

def get_image_defects(image_name, defects):
    """Collect the per-class defect encodings of one image.

    Parameters
    ----------
    image_name : str
        Image file name; every row whose ``ImageId_ClassId`` starts with it
        is taken as belonging to the image.
    defects : pandas.DataFrame
        Table with an ``ImageId_ClassId`` column (``<image>_<class id>``)
        followed by exactly one more column holding the encoding string
        (whitespace-separated integers) or a missing value.

    Returns
    -------
    dict
        ``{class_id: None | [int, ...]}`` -- ``None`` when the class has no
        encoding for this image, otherwise the encoding parsed to ints.

    Raises
    ------
    AssertionError
        If the image does not have exactly 4 rows.
    """
    image_rows = defects[defects['ImageId_ClassId'].str.startswith(image_name)]
    # NOTE there are 4 defects types, there MUST be 4 rows
    assert len(image_rows) == 4

    result = {}
    for name, defect_encoding in image_rows.values:
        # The class id is the text after the LAST underscore, so underscores
        # inside the file name itself do not break the parsing.
        key = int(name.rsplit('_', 1)[1])
        # A missing cell may be np.nan, another NaN float, None or pandas.NA;
        # an identity test against np.nan would only catch the first of these.
        if pandas.isna(defect_encoding):
            result[key] = None
        else:
            # split() (no argument) tolerates repeated / trailing whitespace.
            result[key] = [int(e) for e in defect_encoding.split()]
    return result

def convert_defects_to_bitmaps(image, defect_map):
    """Turn run-length defect encodings into per-class masks of the image size.

    Parameters
    ----------
    image : numpy.ndarray
        Only ``image.shape[:2]`` (height, width) is used, so colour
        (H, W, C) and single-channel (H, W) arrays are both accepted.
    defect_map : dict
        Exactly 4 entries ``{class_id: encoding}``. An encoding is a flat
        list ``[start, length, start, length, ...]``; a falsy value
        (``None`` or an empty list) means "no defect of this class".

    Returns
    -------
    dict
        Same keys as ``defect_map``. The value is ``None`` for a class without
        defect, otherwise a float array of shape (H, W) holding 1 on defect
        pixels and 0 elsewhere.

    Raises
    ------
    AssertionError
        If ``defect_map`` does not have exactly 4 entries.
    IndexError
        If an encoding has an odd number of entries or a run does not fit
        inside the image.
    """
    assert len(defect_map) == 4
    result = {}

    im_h, im_w = image.shape[:2]
    for defect_id, encoded_mask in defect_map.items():
        if not encoded_mask:
            result[defect_id] = None
            continue

        bitmap_flat = np.zeros(im_w * im_h)
        for i in range(0, len(encoded_mask), 2):
            # NOTE(review): the start value is used directly as a 0-based flat
            # offset; if the encoding in the CSV is 1-based, every run lands
            # one pixel late -- confirm against the data description.
            start_index = encoded_mask[i]
            num_ones = encoded_mask[i + 1]
            if num_ones <= 0:
                # An empty run marks nothing.
                continue

            end_index = start_index + num_ones
            # Slice assignment silently clips, so check the bounds explicitly
            # to keep reporting runs that do not fit the image.
            if start_index < 0 or end_index > bitmap_flat.size:
                raise IndexError(
                    "run [%d, %d) of class %r is outside an image of %d pixels"
                    % (start_index, end_index, defect_id, bitmap_flat.size)
                )
            bitmap_flat[start_index:end_index] = 1

        # NOTE ordering of pixels in encoding is top -> down, then left -> right, hence 'F'
        result[defect_id] = np.reshape(bitmap_flat, (im_h, im_w), order='F')

    return result

def get_sample(image_names, defects, defect_id):
    """Return the first image that has a defect of class ``defect_id``.

    Walks ``image_names`` in order; for the first image whose encoding for
    ``defect_id`` is truthy it returns ``(image, mask)``, where ``image`` comes
    from ``load_image`` and ``mask`` is that class's entry from
    ``convert_defects_to_bitmaps``. Returns ``None`` when no image qualifies.
    """
    for candidate in image_names:
        encodings = get_image_defects(candidate, defects)
        if not encodings[defect_id]:
            continue
        picture = load_image(candidate)
        masks = convert_defects_to_bitmaps(picture, encodings)
        return picture, masks[defect_id]
    return None
        
# Show the first image that has a class-4 defect, together with its mask.
img, bmp = get_sample(all_images_names, defects, 4)

plt.figure(figsize=(20, 20))
plt.title("Defect mask (class 4)")
plt.imshow(bmp)

plt.figure(figsize=(20, 20))
plt.title("Training image")
# cv2.imread yields channels in BGR order while matplotlib expects RGB, so
# convert before display to avoid swapped colours.
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB));