<a href="https://colab.research.google.com/github/phytometrics/arabidopsis_leaf_stomata_quantification/blob/main/Easy_custom_inference_from_google_drive_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Easy code execution version with Google Drive

# Google Drive Activation

In [1]:
# Authorization window popsup. Must finish the procedure to continue the process.
from google.colab import drive
drive.mount('gdrive')

Mounted at gdrive


# Preparation (do not unfold this cell block. Press execute button below once.)

In [2]:
!pip install onnxruntime-gpu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting onnxruntime-gpu
  Downloading onnxruntime_gpu-1.13.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (115.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 MB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting coloredlogs
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 KB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
Collecting humanfriendly>=9.1
  Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 KB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: humanfriendly, coloredlogs, onnxruntime-gpu
Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 onnxruntime-gpu-1.13.1


In [22]:
import seaborn as sns
import os
import numpy as np
import cv2
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from skimage.measure import label, regionprops, find_contours
import onnxruntime
import gdown
import copy
from tqdm import tqdm
import time

## Get model weights

In [4]:
det = "1HjZJzRjs5NXlXchbRqOGybdSXx--HZNF"
seg = "1Hhr68lZsycMmFkQlZ7WphK6cru29oZgN"
!gdown {det}
!gdown {seg}

Downloading...
From: https://drive.google.com/uc?id=1HjZJzRjs5NXlXchbRqOGybdSXx--HZNF
To: /content/221121_micro_yolox_s1920.onnx
100% 35.8M/35.8M [00:00<00:00, 85.9MB/s]
Downloading...
From: https://drive.google.com/uc?id=1Hhr68lZsycMmFkQlZ7WphK6cru29oZgN
To: /content/221121_micro_seg.onnx
100% 26.7M/26.7M [00:00<00:00, 32.1MB/s]


## Pipeline


In [5]:
# the codes in this cell are adopted/customized based on the following sites. Authors do not claim any rights regarding this code cell
# https://github.com/Kazuhito00/YOLOX-ONNX-TFLite-Sample
# https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/visualize.py
# both of the original repository are licenced with APACHE LICENSE 2.0

class YoloxONNX(object):
    def __init__(
        self,
        model_path="",
        input_shape=(1920, 1920),
        nms_th=0.45,
        score_th=0.1,
        with_p6=False,
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    ):
        self.input_shape = input_shape
        self.model_path = model_path
        self.nms_th = nms_th
        self.score_th = score_th
        self.with_p6 = with_p6
        self.onnx_session = onnxruntime.InferenceSession(
            self.model_path,
            providers=[providers[0]],
        )

        self.input_name = self.onnx_session.get_inputs()[0].name
        self.output_name = self.onnx_session.get_outputs()[0].name

    def inference(self, image):
        temp_image = copy.deepcopy(image)
        image_height, image_width = image.shape[0], image.shape[1]
        image, ratio = self._preprocess(temp_image, self.input_shape)
        results = self.onnx_session.run(
            None,
            {self.input_name: image[None, :, :, :]},
        )
        bboxes, scores, class_ids = self._postprocess(
            results[0],
            self.input_shape,
            ratio,
            self.nms_th,
            self.score_th,
            image_width,
            image_height,
            p6=self.with_p6,
        )

        return bboxes, scores, class_ids

    def _preprocess(self, image, input_size, swap=(2, 0, 1)):
        if len(image.shape) == 3:
            padded_image = np.ones(
                (input_size[0], input_size[1], 3), dtype=np.uint8) * 114
        else:
            padded_image = np.ones(input_size, dtype=np.uint8) * 114

        ratio = min(input_size[0] / image.shape[0],
                    input_size[1] / image.shape[1])
        resized_image = cv2.resize(
            image,
            (int(image.shape[1] * ratio), int(image.shape[0] * ratio)),
            interpolation=cv2.INTER_LINEAR,
        )
        resized_image = resized_image.astype(np.uint8)

        padded_image[:int(image.shape[0] * ratio), :int(image.shape[1] *
                                                        ratio)] = resized_image
        padded_image = padded_image.transpose(swap)
        padded_image = np.ascontiguousarray(padded_image, dtype=np.float32)

        return padded_image, ratio

    def _postprocess(
        self,
        outputs,
        img_size,
        ratio,
        nms_th,
        score_th,
        max_width,
        max_height,
        p6=False,
    ):
        grids = []
        expanded_strides = []

        if not p6:
            strides = [8, 16, 32]
        else:
            strides = [8, 16, 32, 64]

        hsizes = [img_size[0] // stride for stride in strides]
        wsizes = [img_size[1] // stride for stride in strides]

        for hsize, wsize, stride in zip(hsizes, wsizes, strides):
            xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
            grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            expanded_strides.append(np.full((*shape, 1), stride))

        grids = np.concatenate(grids, 1)
        expanded_strides = np.concatenate(expanded_strides, 1)
        outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
        outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides

        predictions = outputs[0]
        boxes = predictions[:, :4]
        scores = predictions[:, 4:5] * predictions[:, 5:]

        boxes_xyxy = np.ones_like(boxes)
        boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
        boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
        boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
        boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
        boxes_xyxy /= ratio

        dets = self._multiclass_nms(
            boxes_xyxy,
            scores,
            nms_thr=nms_th,
            score_thr=score_th,
        )

        bboxes, scores, class_ids = [], [], []
        if dets is not None:
            bboxes, scores, class_ids = dets[:, :4], dets[:, 4], dets[:, 5]
            for bbox in bboxes:
                bbox[0] = max(0, bbox[0])
                bbox[1] = max(0, bbox[1])
                bbox[2] = min(bbox[2], max_width)
                bbox[3] = min(bbox[3], max_height)

        return bboxes, scores, class_ids

    def _nms(self, boxes, scores, nms_thr):
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]

        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])

            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)

            inds = np.where(ovr <= nms_thr)[0]
            order = order[inds + 1]

        return keep

    def _multiclass_nms(
        self,
        boxes,
        scores,
        nms_thr,
        score_thr,
        class_agnostic=True,
    ):
        if class_agnostic:
            nms_method = self._multiclass_nms_class_agnostic
        else:
            nms_method = self._multiclass_nms_class_aware

        return nms_method(boxes, scores, nms_thr, score_thr)

    def _multiclass_nms_class_aware(self, boxes, scores, nms_thr, score_thr):
        final_dets = []
        num_classes = scores.shape[1]

        for cls_ind in range(num_classes):
            cls_scores = scores[:, cls_ind]
            valid_score_mask = cls_scores > score_thr

            if valid_score_mask.sum() == 0:
                continue
            else:
                valid_scores = cls_scores[valid_score_mask]
                valid_boxes = boxes[valid_score_mask]
                keep = self._nms(valid_boxes, valid_scores, nms_thr)
                if len(keep) > 0:
                    cls_inds = np.ones((len(keep), 1)) * cls_ind
                    dets = np.concatenate(
                        [
                            valid_boxes[keep], valid_scores[keep, None],
                            cls_inds
                        ],
                        1,
                    )
                    final_dets.append(dets)

        if len(final_dets) == 0:
            return None

        return np.concatenate(final_dets, 0)

    def _multiclass_nms_class_agnostic(self, boxes, scores, nms_thr,
                                       score_thr):
        cls_inds = scores.argmax(1)
        cls_scores = scores[np.arange(len(cls_inds)), cls_inds]

        valid_score_mask = cls_scores > score_thr

        if valid_score_mask.sum() == 0:
            return None

        valid_scores = cls_scores[valid_score_mask]
        valid_boxes = boxes[valid_score_mask]
        valid_cls_inds = cls_inds[valid_score_mask]
        keep = self._nms(valid_boxes, valid_scores, nms_thr)

        dets = None
        if keep:
            dets = np.concatenate([
                valid_boxes[keep],
                valid_scores[keep, None],
                valid_cls_inds[keep, None],
            ], 1)

        return dets
def annotate(img, boxes, scores, cls_ids, conf=0.5, class_names=None, text=True, CROP_SIZE=64):
    for i in range(len(boxes)):
        box = boxes[i]
        cls_id = int(cls_ids[i])
        score = scores[i]
        if score < conf:
            continue
        x0 = int(box[0])
        y0 = int(box[1])
        x1 = int(box[2])
        y1 = int(box[3])

        xmin, ymin, xmax, ymax = x0, y0, x1, y1
        xcenter = int((xmin + xmax) / 2)
        ycenter = int((ymin + ymax) / 2)
        new_xmin = xcenter - CROP_SIZE // 2
        new_ymin = ycenter - CROP_SIZE // 2
        new_xmax = xcenter + CROP_SIZE // 2
        new_ymax = ycenter + CROP_SIZE // 2
        new_xmin = np.max([0, new_xmin]) 
        new_ymin = np.max([0, new_ymin])
        new_xmax = np.min([img.shape[1], new_xmax])
        new_ymax = np.min([img.shape[0], new_ymax])

        x0, y0, x1, y1 = new_xmin, new_ymin, new_xmax, new_ymax

        color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
        text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100)
        txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)
        font = cv2.FONT_HERSHEY_SIMPLEX

        txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
        cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)

        txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
        if text:
            cv2.rectangle(
                img,
                (x0, y0 + 1),
                (x0 + txt_size[0] + 1, y0 + int(1.5*txt_size[1])),
                txt_bk_color,
                -1
            )
            cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)

    return img


_COLORS = np.array(
    [
        1., 0., 0., # open red
        0., 0., 1., # close blue
        # R, G, B
    ]
).astype(np.float32).reshape(-1, 3)

In [6]:
class Segmentation(object):
    def __init__(self,
                 model_path="",
                 input_shape=(64, 64),
                 providers=['CPUExecutionProvider'],
                 ):
        self.input_shape = input_shape
        self.onnx_session = onnxruntime.InferenceSession(
            model_path,
            providers=providers,
        )

        self.input_name = self.onnx_session.get_inputs()[0].name
        self.output_name = self.onnx_session.get_outputs()[0].name

    def inference(self, image):
        image = image.transpose((2, 0, 1)).astype("float32")
        result = self.onnx_session.run(
            None,
            {self.input_name: image[None, :, :, :]},
        )[0][0][0]
        return result


def crop_stomata(image, dets, CROP_SIZE):
    stomatas = []
    offsets = []

    for det in dets:
        xmin, ymin, xmax, ymax = det[0], det[1], det[2], det[3]
        xcenter = int((xmin + xmax) / 2)
        ycenter = int((ymin + ymax) / 2)
        new_xmin = xcenter - CROP_SIZE // 2
        new_ymin = ycenter - CROP_SIZE // 2
        new_xmax = xcenter + CROP_SIZE // 2
        new_ymax = ycenter + CROP_SIZE // 2
        new_xmin = np.max([0, new_xmin])
        new_ymin = np.max([0, new_ymin])
        new_xmax = np.min([image.shape[1], new_xmax])
        new_ymax = np.min([image.shape[0], new_ymax])
        stomata = image[new_ymin:new_ymax, new_xmin:new_xmax, :]
        if stomata.shape != (CROP_SIZE, CROP_SIZE, 3):
            daishi = np.zeros((CROP_SIZE, CROP_SIZE, 3), dtype=np.uint8)
            daishi[0: stomata.shape[0], 0:stomata.shape[1], :] = stomata
            stomata = daishi
        else:
            pass
        stomatas.append(stomata)
        offsets.append([new_xmin, new_ymin])
    stomatas = np.array(stomatas)
    offsets = np.array(offsets)
    return stomatas, offsets

In [20]:
def analyze(yolox, unet, folder, file):
    _results = []
    path = os.path.join(folder,file)
    orig_image = cv2.imread(path)
    image = orig_image.copy()
    # detection requires BGR image
    bboxes, scores, class_ids = yolox.inference(image)
    orig_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = orig_image.copy()
    # segmentation requires  RGB image
    stomatas, offsets = crop_stomata(image, bboxes, 64)
    
    annotated1 = annotate(image, bboxes, scores, class_ids, conf=det_score_thresh, class_names=["open", "close"],text=False)
    annotated2 = annotated1.copy()
    if len(bboxes) != 0:
        for i, (stomata, bbox, score, class_id, offset) in enumerate(zip(stomatas, bboxes, scores, class_ids,offsets)):
            if score < det_score_thresh:
                continue
            if class_id == 1:  # 0 open, 1 close
                _results.append([folder, file,0,0,0])  # area, aperture, ratio
                continue
            else:
                mask = unet.inference(stomata)
                mask = mask > 0.5
                lbl = label(mask)
                props = regionprops(lbl)
                if len(props):
                    idx = np.argmax([x.area for x in props])
                    prop = props[idx]
                    area = prop.area
                    try:
                        ratio = prop.minor_axis_length / prop.major_axis_length
                        width = prop.minor_axis_length
                        length = prop.major_axis_length
                    except Exception as e:
                        ratio = np.nan
                        width = np.nan
                        length = np.nan
                    mask = lbl == idx + 1
                    contour = find_contours(mask)[0]
                    contour = np.array([[y+offset[1],x+offset[0]] for y,x in contour])
                    points = list(zip(contour[:,1],contour[:,0]))
                    points = np.array(points).reshape((-1, 1, 2)).astype(np.int32)
                    cv2.fillPoly(annotated2, pts=[points], color=(255,0,0))
                    _results.append([folder, file,area,width,length,ratio])
    return _results, orig_image, annotated1, annotated2

# Inference

In [8]:
# if you want the detection to be more sensitive, change the det_score_thresh from 0.5 to ex. 0.1
det_score_thresh = 0.5

yolox = YoloxONNX(model_path="/content/221121_micro_yolox_s1920.onnx",
                  input_shape=(1920,1920), 
                  nms_th=0.45,
                  score_th=det_score_thresh)
unet = Segmentation(model_path="/content/221121_micro_seg.onnx",
                            input_shape=(64,64))

In [9]:
# define the path of the folder containing images you want to analyze in google folder
# /content/gdrive/MyDrive is the path to the google drive top directory. change the FOLDERNAME to your foldername containing images
folder = "/content/gdrive/MyDrive/FOLDERNAME"

# change the EXT to other format in case you have a different extension image such as tiff or JPEG or .etc
EXT = ".jpg"

# filter files that are only image files/
files = os.listdir(folder)
files = [x for x in files if x.endswith(EXT)]
for file in files:
    print(file)

20190507_Disk_Light_ABA_18.jpg
20190507_Disk_Dark_DMSO_18.jpg
20190507_Disk_Light_ABA_05.jpg
20190507_Disk_Dark_DMSO_20.jpg
20190507_Disk_Light_DMSO_10.jpg
20190507_Disk_Light_DMSO_05.jpg
20190507_Disk_Dark_FC_09.jpg
20190507_Disk_Dark_FC_11.jpg


if the above code does not output the filename of the images you want to anlayze, something is wrong. double check before proceeding

In [28]:
%config InlineBackend.figure_formats = {'png', 'retina'}

timestr = time.strftime("%Y%m%d-%H%M%S")
out_folder = os.path.join("/content",timestr)

isExist = os.path.exists(out_folder)
if not isExist:
   os.makedirs(out_folder)

print("Result Saving in", out_folder)

results = []
for file in tqdm(files):
    _results, image, annotated1, annotated2 = analyze(yolox,unet,folder,file)
    results.extend(_results)
    plt.figure(figsize=(45,15))
    plt.subplot(1,3,1)
    plt.imshow(image)
    plt.subplot(1,3,2)
    plt.imshow(annotated1)
    plt.subplot(1,3,3)
    plt.imshow(annotated2)
    plt.savefig(os.path.join(out_folder,file))
    plt.close()
df = pd.DataFrame(results, columns=["folder","file","area","width","length","ratio"])
df.to_csv(os.path.join(out_folder,"result.csv"),header=True)
!zip -r -j {timestr+".zip"} {out_folder}
print("zip file saved to",timestr,".zip","right click and download at the left navigation")

Result Saving in /content/20230107-024337


100%|██████████| 8/8 [00:17<00:00,  2.24s/it]


  adding: 20190507_Disk_Dark_FC_09.jpg (deflated 4%)
  adding: 20190507_Disk_Light_ABA_05.jpg (deflated 4%)
  adding: 20190507_Disk_Dark_DMSO_20.jpg (deflated 4%)
  adding: result.csv (deflated 86%)
  adding: 20190507_Disk_Light_DMSO_10.jpg (deflated 4%)
  adding: 20190507_Disk_Dark_DMSO_18.jpg (deflated 4%)
  adding: 20190507_Disk_Dark_FC_11.jpg (deflated 4%)
  adding: 20190507_Disk_Light_ABA_18.jpg (deflated 4%)
  adding: 20190507_Disk_Light_DMSO_05.jpg (deflated 4%)
zip file saved to 20230107-024337 .zip right click and download at the left navigation


In [None]:
group = df.groupby(["folder","file"])
print("mean")
print(group.mean())
print("___")
print("std")
print(group.std())

In [None]:
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
palette = sns.color_palette([ '#F5AF98','#8CC1FF',], desat=1)
sns.set_palette(palette)
sns.boxplot(data=df, x="file",y="width", hue="folder", showfliers=False, palette=palette)
sns.stripplot(data=df,x="file",y="width", hue="folder", dodge=True, jitter=True, size=4, color='black', alpha=0.4)
handles, labels = ax.get_legend_handles_labels()
#l = plt.legend(handles[0:3], labels[0:3], bbox_to_anchor=(0,1), loc='upper left', borderaxespad=1)
l = plt.legend(handles[0:2], labels[0:2], bbox_to_anchor=(0,1), loc='upper left', borderaxespad=1)

plt.xticks(rotation=90)