In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from PIL import Image
import matplotlib.pyplot as plt
import ast
import os

In [None]:
# Root folder of the competition dataset
DATA_DIR = "/kaggle/input/hpa-single-cell-image-classification"

# Image folders for the train and test splits
TRAIN_IMG_DIR = os.path.join(DATA_DIR, "train")
TEST_IMG_DIR = os.path.join(DATA_DIR, "test")

# Sorted full path of every file found in each image folder
TRAIN_IMG_PATHS = sorted(
    os.path.join(TRAIN_IMG_DIR, f_name) for f_name in os.listdir(TRAIN_IMG_DIR)
)
TEST_IMG_PATHS = sorted(
    os.path.join(TEST_IMG_DIR, f_name) for f_name in os.listdir(TEST_IMG_DIR)
)

print(f"\n... Recall that 4 training images compose one example (R,G,B,Y) ...")
print(f"... \t– i.e. The first 4 training files are:")
# Print only the file name (the text after the last "/") of the first four paths
for path in (full_path.rsplit('/', 1)[1] for full_path in TRAIN_IMG_PATHS[:4]):
    print(f"... \t\t– {path}")
print(f"\n... The number of training images is {len(TRAIN_IMG_PATHS)} i.e. {len(TRAIN_IMG_PATHS)//4} 4-channel images ...")
print(f"... The number of testing images is {len(TEST_IMG_PATHS)} i.e. {len(TEST_IMG_PATHS)//4} 4-channel images ...")

# Test dataframe; each cell_masks entry is read as text and parsed back into
# a Python object with ast.literal_eval
PATH_TO_NEW_TEST_CSV = "/kaggle/input/hpa-test-dataframe-with-cellwise-rle/test_data_w_rles (2).csv"
test_df = pd.read_csv(PATH_TO_NEW_TEST_CSV)
test_df.cell_masks = test_df.cell_masks.apply(ast.literal_eval)
test_df.head()

In [None]:
def rle_encode(img, mask_val=1):
    """
    Run-length encode the pixels of a segmentation array equal to ``mask_val``.
    https://en.wikipedia.org/wiki/Run-length_encoding

    The array is transposed and flattened before encoding, and positions in
    the output are 1-based indices into that flattened sequence.

    Args:
        img (np.array): Segmentation array
        mask_val (int): Pixel value whose positions are encoded

    Returns:
        str: Space-separated ``start length`` pairs, one pair per run of
        consecutive matching pixels. Empty string when nothing matches.

    """
    hits = np.flatnonzero(img.T.flatten() == mask_val)
    if hits.size == 0:
        return ''

    # A hit opens a new run when it is not directly after the previous hit.
    # Prepending -2 makes the very first hit always open a run.
    opens_run = np.diff(hits, prepend=-2) > 1
    run_starts = hits[opens_run] + 1

    # Each run's length is the distance (counted in hits) to the next run's
    # opening hit, or to the end of the hits for the last run.
    run_lengths = np.diff(np.flatnonzero(opens_run), append=hits.size)

    return ' '.join(
        str(value)
        for pair in zip(run_starts, run_lengths)
        for value in pair
    )


def rle_decode(rle_string, height, width):
    """ Convert an RLE string into a mask array

    The string holds whitespace-separated ``start length`` pairs, where
    ``start`` is a 1-based position in the transposed, flattened image
    (the layout written by ``rle_encode``).

    Args:
        rle_string (str): Run length encoding containing
            segmentation mask information. An empty or whitespace-only
            string decodes to an all-zero mask.
        height (int): Height of the original image the map comes from
        width (int): Width of the original image the map comes from

    Returns:
        Numpy uint8 array of shape (height, width) holding 255 on the
        encoded pixels and 0 everywhere else

    Raises:
        ValueError: If a token is not an integer or the tokens do not
            form complete ``start length`` pairs.
    """
    rows, cols = height, width
    flat_mask = np.zeros(rows * cols, dtype=np.uint8)

    # split() without an argument ignores repeated/leading/trailing whitespace
    # and yields no tokens for an empty string, so an empty encoding (what
    # rle_encode produces for an empty mask) decodes instead of failing
    # on int('').
    tokens = rle_string.split()
    if tokens:
        rle_pairs = np.array([int(token) for token in tokens]).reshape(-1, 2)
        for start, length in rle_pairs:
            start -= 1  # RLE positions are 1-based
            flat_mask[start:start + length] = 255

    # Positions index the transposed image, so rebuild it as (width, height)
    # and transpose back to (height, width).
    return flat_mask.reshape(cols, rows).T

def load_image(img_id, img_dir):
    """ Compose one 4-channel array from the per-channel PNG files of an image ID

    Opens ``<img_id>_red.png``, ``<img_id>_green.png``, ``<img_id>_blue.png``
    and ``<img_id>_yellow.png`` inside ``img_dir``, converts each to a uint8
    array and stacks them, in that order, along a new last axis.
    """
    channel_arrays = []
    for colour in ("red", "green", "blue", "yellow"):
        png_path = os.path.join(img_dir, img_id+f"_{colour}.png")
        channel_arrays.append(np.asarray(Image.open(png_path), np.uint8))
    return np.stack(channel_arrays, axis=-1)


def plot_rgb(arr, figsize=(12,12)):
    """ Show an image array in a new figure with a bold title and no axes """
    fig = plt.figure(figsize=figsize)
    ax = fig.gca()
    ax.set_title(f"RGB Composite Image", fontweight="bold")
    ax.imshow(arr)
    ax.axis(False)
    plt.show()
    
def plot_rgby(arr, figsize=(20,6), title=None, plot_merged=True):
    """ Plot the 4 channels side by side, optionally followed by the merged view

    Args:
        arr (numpy array): 4 channel (RGBY) image, channels on the last axis
        figsize (tuple): Size passed to the new figure
        title (str): Figure-level title; only drawn when it is a str
        plot_merged (bool): When truthy, add a fifth panel showing
            ``convert_rgby_to_rgb(arr)``
    """
    n_images = 5 if plot_merged else 4
    plt.figure(figsize=figsize)
    if type(title) is str:
        plt.suptitle(title, fontsize=20, fontweight="bold")

    panel_titles = (
        "Red Channel – Microtubles",
        "Green Channel – Protein of Interest",
        "Blue - Nucleus",
        "Yellow – Endoplasmic Reticulum",
    )
    for idx, label in enumerate(panel_titles):
        # Start from an all-zero 3-channel canvas with the input's dtype
        ch_arr = np.zeros_like(arr[..., :-1])
        if idx < 3:
            # Red, green and blue are drawn in their own colour slot
            ch_arr[..., idx] = arr[..., idx]
        else:
            # Yellow is drawn by writing the channel into both red and green
            for slot in (0, 1):
                ch_arr[..., slot] = arr[..., idx]
        plt.subplot(1, n_images, idx+1)
        plt.title(label.title(), fontweight="bold")
        plt.imshow(ch_arr)
        plt.axis(False)

    if plot_merged:
        # The merged view always takes the last panel
        plt.subplot(1, n_images, n_images)
        plt.title(f"Merged RGBY into RGB", fontweight="bold")
        plt.imshow(convert_rgby_to_rgb(arr))
        plt.axis(False)

    plt.tight_layout(rect=[0, 0.2, 1, 0.97])
    plt.show()
    
    
def grab_contours(cnts):
    """ Pick the contours out of the tuple returned by cv2.findContours

    Args:
        cnts (tuple): Return value of cv2.findContours

    Returns:
        The contours element: item 0 of a 2-tuple, item 1 of a 3-tuple

    Raises:
        Exception: If the tuple has any length other than 2 or 3
    """
    n_items = len(cnts)

    # A 2-tuple comes from OpenCV v2.4, v4-beta or v4-official:
    # the contours are the first element
    if n_items == 2:
        return cnts[0]

    # A 3-tuple comes from OpenCV v3, v4-pre or v4-alpha:
    # the contours are the second element
    if n_items == 3:
        return cnts[1]

    # Any other length means the cv2.findContours return signature is one
    # this helper does not know about
    raise Exception(("Contours tuple must have length 2 or 3, "
        "otherwise OpenCV changed their cv2.findContours return "
        "signature yet again. Refer to OpenCV's documentation "
        "in that case"))

    
def convert_rgby_to_rgb(arr):
    """ Convert a 4 channel (RGBY) image to a 3 channel RGB image.
    
    Advice From Competition Host/User: lnhtrang

    For annotation (by experts) and for the model, I guess we agree that individual 
    channels with full range px values are better. 
    In annotation, we toggled the channels. 
    For visualization purpose only, you can try blending the channels. 
    For example, 
        - red = red + yellow
        - green = green + yellow/2
        - blue=blue.

    NOTE: this implementation blends yellow into green only; red and blue
    are copied through unchanged (red does not receive yellow, unlike the
    example quoted above).

    For integer images the blended green channel is clipped to the dtype's
    range before it is stored, so bright pixels saturate at the maximum
    value instead of overflowing during the cast back to the integer dtype.
        
    Args:
        arr (numpy array): The RGBY, 4 channel numpy array for a given image
    
    Returns:
        RGB Image with the same dtype as ``arr`` and one channel fewer
    """
    
    rgb_arr = np.zeros_like(arr[..., :-1])
    rgb_arr[..., 0] = arr[..., 0]

    # The division yields floats, so this sum can exceed what an integer
    # output dtype can hold (e.g. 200 + 200/2 = 300 for uint8).
    green = arr[..., 1] + arr[..., 3] / 2
    if np.issubdtype(rgb_arr.dtype, np.integer):
        limits = np.iinfo(rgb_arr.dtype)
        green = np.clip(green, limits.min, limits.max)
    rgb_arr[..., 1] = green

    rgb_arr[..., 2] = arr[..., 2]
    
    return rgb_arr

In [None]:
# Image ID used for this demonstration
demo_id = "0040581b-f1f2-4fbe-b043-b6bfea5404bb"

# Load the four channels of the demo image from the test folder and plot them
rgby_image = load_image(demo_id, TEST_IMG_DIR)
plot_rgby(rgby_image, title="DEMO IMAGE")

# cell_masks entry of the first row whose ID matches the demo image
demo_rles = test_df.loc[test_df.ID == demo_id, "cell_masks"].to_list()[0]

# Decode every RLE at the image's height/width and add the masks together
master_mask = sum(
    rle_decode(rle, rgby_image.shape[0], rgby_image.shape[1])
    for rle in demo_rles
)

plt.figure(figsize=(7,7))
plt.imshow(master_mask, cmap='gray')
plt.show()