# Library

In [1]:
import torch
import numpy as np
import cv2

from model.model import TESeg_ResNet50
from inference.inference import Get_Instances_from_BinaryMask_DistMap
from inference.inference import convert_Tensor_normalize_image, ColorMapNumpyArray, Overlay

In [2]:
sigmoid_op = torch.nn.Sigmoid()

# Process

In [None]:
# Load Model
# download pretrained model at https://drive.google.com/file/d/18NPePHqafuJVyY9JjAzZJU1OKHKTxaNR/view?usp=sharing
model_state_dict = torch.load("teseg_res50_modelStateDict.pth")

model = TESeg_ResNet50()
model.load_state_dict( model_state_dict )
model.cuda()
model.eval()

TESeg_ResNet50(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
   

In [4]:
# Load Image
raw_img = np.load("test_img.npy")
raw_img_color = cv2.cvtColor(raw_img, cv2.COLOR_GRAY2BGR)
input_tensor = convert_Tensor_normalize_image(raw_img_color).unsqueeze(0) # 1,3,512,512

In [5]:
# Inference
with torch.no_grad():
    pred_binary_mask, pred_ellipse_dist = model(input_tensor.cuda()) # pred_binary_mask (1,1,512,512); pred_ellipse_dist (1,1,512,512)

In [6]:
# Post-processing

pred_binary_mask_cpu_numpy = (sigmoid_op(pred_binary_mask) >= 0.5).float().cpu().numpy()[0,0,:]
pred_ellipse_dist_cpu_numpy = pred_ellipse_dist.float().cpu().numpy()[0,0,:]
watershed_result, _ = Get_Instances_from_BinaryMask_DistMap(pred_binary_mask_cpu_numpy, pred_ellipse_dist_cpu_numpy)

In [7]:
# overlay
color_inst, tmp_mask = ColorMapNumpyArray(watershed_result)
overlay_inst = Overlay( color_inst, tmp_mask, raw_img_color )

# colorful ellipse dist:
ellipse_dist_color, _ = ColorMapNumpyArray(pred_ellipse_dist_cpu_numpy)

# 3-channel binary mask
pred_binary_mask_cpu_numpy_color = cv2.cvtColor( (pred_binary_mask_cpu_numpy * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR )

# Save result: raw image, onverlay instance, binary mask, ellipse dist, instance map, 

strip = np.zeros( (512,5,3), dtype= np.uint8)
strip[:, :, :] = 255
    
concat = np.concatenate([raw_img_color, strip, 
                             overlay_inst, strip,
                             pred_binary_mask_cpu_numpy_color, strip,
                             ellipse_dist_color, strip, 
                             color_inst], axis=1)
cv2.imwrite("result.jpg", concat)

True