In [104]:
import torch
from models.inference_unet import *
import matplotlib.pyplot as plt
import cv2
from torchvision.io import read_image
from torchvision import transforms
from scipy import ndimage
import numpy as np
from skimage.color import rgb2lab

# Load the trained U-Net once when the notebook starts (weights read from ./weights/).
model = get_trained_unet(
    params='unet_params_104_lr_1e-05_h2o2.pt',
    path='./weights/',
    model_name='unet_vgg11_upsampling',
)

In [105]:
def get_inference(model,
                  image_path,
                  transform_hr=transforms.Compose([transforms.ToTensor(),transforms.Resize((512,512), antialias=True)]),
                  crop = 0.2,
                  device = None
                  ):
    """Run the segmentation model on a single image file.

    Args:
        model: segmentation network (torch module). It is called as
            ``model(batch)``, and the class is taken as the argmax over dim 1
            of its output.
        image_path: path of the image, opened with ``Image.open``.
        transform_hr: transform applied to the (cropped) PIL image; must
            return a CHW tensor. Default: ToTensor + resize to 512x512.
        crop: fraction of the width/height removed from *each* border before
            inference. A falsy value (0 / None) disables cropping.
        device: torch device to run on. ``None`` picks ``'cuda'`` when
            available, otherwise ``'cpu'``. The model is moved to this device.

    Returns:
        (image, mask): the pre-processed image as an HWC float numpy array
        (in [0, 1] for the default transform), and the per-pixel argmax
        label map multiplied by 255 as uint8. This is 0/255 for a 2-class
        model; with more than 2 classes the uint8 cast wraps around.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # The context manager closes the file handle. convert('RGB') guarantees
    # 3 channels for grayscale / RGBA / CMYK inputs.
    with Image.open(image_path) as img:
        img = img.convert('RGB')

    if crop:
        w, h = img.size
        img = img.crop((int(w * crop), int(h * crop),
                        int(w - w * crop), int(h - h * crop)))

    # NOTE(review): ToTensor() already scales uint8 pixels to [0, 1], so this
    # extra /255 feeds the network values in [0, 1/255]. It is kept unchanged
    # because the weights may have been trained with the same pipeline.
    # Confirm before removing it.
    img_t = (transform_hr(img) / 255).unsqueeze(0)
    del img

    model = model.to(device)  # make model and input live on the same device
    model.eval()
    with torch.no_grad():
        logits = model(img_t.to(device))

    _, labels = torch.max(logits, dim=1)
    mask = labels.cpu().squeeze(0).numpy()
    del logits, labels
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return (img_t.squeeze(0).permute(1, 2, 0).numpy() * 255,
            (mask * 255).astype('uint8'))

In [107]:
def squeeze_mask(image, mask, fac):
    '''
    Replace every blob in a binary mask by a smaller, centred filled ellipse.

    For each contour found in ``mask`` whose bounding box (w x h) covers more
    than 30 px, an axis-aligned filled ellipse is drawn at the box centre.
    Its semi-axes are ``(1 - fac) * w / 2`` and ``(1 - fac) * h / 2``, truncated
    to int. ``fac`` = 0 keeps the full bounding-box extent; ``fac`` = 0.4
    shortens both axes by 40 %.

    Args:
        image: array whose first two dimensions give the output size (only
            ``image.shape`` is used).
        mask: single-channel uint8 mask as accepted by ``cv2.findContours``.
        fac: shrink factor for the ellipse axes.

    Returns:
        (mask_orig, mask): the input mask unchanged, and a new uint8 mask
        (0 / 255) with the same height/width as ``image``. If no contour
        passes the size filter, the new mask is all zeros.
    '''
    contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns (img, contours, hierarchy).
    contours = contours[0] if len(contours) == 2 else contours[1]
    # NOTE(review): RETR_TREE also returns the contours of holes inside blobs,
    # and each of those gets its own ellipse. RETR_EXTERNAL may be intended.

    squeezed = np.zeros(image.shape[:2], dtype=np.uint8)
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        if w * h <= 30:  # ignore tiny specks
            continue
        # Plain Python ints: cv2 is strict about the element types of these tuples.
        center = (int(x + w / 2), int(y + h / 2))
        axes = (int((1 - fac) * w / 2), int((1 - fac) * h / 2))
        cv2.ellipse(squeezed, center, axes, 0, 0, 360, 255, -1)

    return mask, squeezed

In [101]:
def get_instance_masks(img, mask, wells_per_col=8) -> "tuple[np.ndarray, list[int]]":
    '''
    Split a binary mask into one channel per connected component.

    Args:
        img: array whose first two dimensions give the output height/width
            (must match ``mask.shape``).
        mask: 2-D array; non-zero pixels are foreground for ``ndimage.label``
            (default 4-connectivity).
        wells_per_col: size of each column group. The number of components
            must be a multiple of it (default 8).

    Returns:
        (instance_mask, col_list):
        - ``instance_mask`` is a float64 array of shape (H, W, n_labels).
          Channel i is 1.0 on component i+1 and 0.0 elsewhere.
        - ``col_list`` is ``[wells_per_col] * (n_labels // wells_per_col)``.

    Raises:
        AssertionError: if n_labels is not a multiple of ``wells_per_col``.
    '''
    label_im, nb_labels = ndimage.label(mask)

    # Broadcast compare (H, W, 1) == (nb_labels,) -> one boolean channel per
    # label, cast to float64 on assignment.
    instance_mask = np.zeros((*img.shape[:2], nb_labels))
    instance_mask[:] = label_im[..., np.newaxis] == np.arange(1, nb_labels + 1)

    assert nb_labels % wells_per_col == 0, \
        'instance masks are not integer multiple of {}!'.format(wells_per_col)

    col_list = [wells_per_col] * (nb_labels // wells_per_col)

    print('col list modified : {}'.format(col_list))

    return instance_mask, col_list

def sort_instance_masks(instance_mask, col_list, img) :
    '''
    Order instance masks spatially and group them according to ``col_list``.

    Each channel of ``instance_mask`` is reduced to bounding boxes with
    ``get_bboxes``. A channel is kept only if it yields exactly one box
    (4 numbers) whose ``box[2] * box[3]`` area exceeds 100 px; all other
    channels are dropped. The kept masks are then processed in three steps:
    1. Sort them by ``box[0]`` (left to right).
    2. Split them into consecutive groups of sizes ``col_list``.
    3. Sort each group by ``box[1]`` (top to bottom).

    The box layout is assumed to be ``[x, y, w, h]`` (indices 2 and 3 are
    multiplied as an area). TODO confirm against ``get_bboxes``.

    Args:
        instance_mask: (H, W, N) array, one mask per channel.
        col_list: group sizes, e.g. ``[8] * 12``.
        img: image passed through to ``get_bboxes``.

    Returns:
        instance_mask_sorted: list of (H, W, k) arrays, one per group.
        bboxes_sorted: list of (k, 4) box arrays matching the masks.

    Raises:
        ValueError: if no mask survives the bounding-box filter.
    '''
    import warnings

    kept_boxes = []
    kept_ids = []
    for i in range(instance_mask.shape[2]):
        box = np.array(get_bboxes(img, instance_mask[:, :, i].astype('uint8'))).ravel()
        # A single box of area > 100 px; masks split into several boxes are discarded.
        if box.size == 4 and box[2] * box[3] > 100:
            kept_boxes.append(box)
            kept_ids.append(i)

    if not kept_boxes:
        raise ValueError('no instance mask produced a single bounding box with area > 100 px')

    n_expected = sum(col_list)
    if len(kept_boxes) != n_expected:
        # col_list is usually derived from the unfiltered mask count; a mismatch
        # shifts every following group by the number of dropped wells.
        warnings.warn('{} instance masks survived filtering but col_list expects {}; '
                      'column groups may be misaligned'.format(len(kept_boxes), n_expected),
                      stacklevel=2)

    lst_bbox = np.array(kept_boxes)
    instance_mask = instance_mask[:, :, kept_ids]

    # x-sort
    order = lst_bbox[:, 0].argsort()
    lst_bbox = lst_bbox[order]
    instance_mask = instance_mask[:, :, order]

    # split into the user-provided groups, then y-sort inside each group
    bboxes_sorted = []
    instance_mask_sorted = []
    start = 0
    for num in col_list:
        group_boxes = lst_bbox[start:start + num]
        group_masks = instance_mask[:, :, start:start + num]
        order = group_boxes[:, 1].argsort()
        bboxes_sorted.append(group_boxes[order])
        instance_mask_sorted.append(group_masks[:, :, order])
        start += num

    return instance_mask_sorted, bboxes_sorted

In [102]:
def get_colors_from_patches(
                            img,
                            instance_mask_sorted,
                            mode = 'rgb',
                            background_rgb = [255,255,255],
                            background_std = np.array([1e-8,1e-8,1e-8]),
                            verbose = False):
    '''
    Compute a colour statistic for every well mask.

    Args:
        img: (H, W, 3) RGB image scaled to [0, 1]. It is multiplied by 255
            and cast (truncated) to uint8 before use.
        instance_mask_sorted: list of (H, W, k) mask stacks, e.g. as returned
            by ``sort_instance_masks``. A pixel belongs to well j where
            ``mask[:, :, j] > 0``; masks are assumed to be binary (0/1).
        mode: selects the statistic.
            - ``'lab'``: per-well mean and std of the CIE-Lab values from
              ``rgb2lab`` over the well's pixels.
            - Any other value: per-channel ``log(background_rgb / mean_rgb)``
              and its first-order propagated uncertainty
              ``sqrt((std_rgb / mean_rgb)**2 + (background_std / background_rgb)**2)``.
              Here ``std_rgb`` is the pixel spread, not the standard error of
              the mean.
        background_rgb: reference (blank) RGB intensity on the 0-255 scale.
        background_std: uncertainty of ``background_rgb``.
        verbose: print which mode is running.

    Returns:
        (color_list, err_list): nested lists indexed ``[group][well]``, each
        entry a 3-element list (value per channel / spread per channel).
    '''
    epsi = 1e-12
    background_rgb = np.asarray(background_rgb, dtype=np.float64)
    background_std = np.asarray(background_std, dtype=np.float64)

    img = (img * 255).astype('uint8')

    lab_mode = mode == 'lab'
    if verbose:
        print('running lab mode' if lab_mode else 'running rgb-resolved mode')

    color_list = []
    err_list = []
    for group in instance_mask_sorted:
        group_colors = []
        group_errs = []
        for j in range(group.shape[2]):
            # Pixels of well j only, shape (n_pixels, 3).
            pixels = img[group[:, :, j] > 0]
            if lab_mode:
                lab = rgb2lab(pixels)  # convert once, reuse for mean and std
                group_colors.append(lab.mean(axis=0).tolist())
                group_errs.append(lab.std(axis=0).tolist())
            else:
                pixels = pixels.astype(np.float64)
                rgb = pixels.mean(axis=0)
                err = pixels.std(axis=0)
                # Same epsilon in the value and in its error term: a channel
                # with zero intensity would otherwise give inf * 0 = nan.
                denom = rgb + epsi
                group_colors.append(np.log(background_rgb / denom).tolist())
                group_errs.append(np.sqrt((err / denom) ** 2
                                          + (background_std / background_rgb) ** 2).tolist())
        color_list.append(group_colors)
        err_list.append(group_errs)

    return color_list, err_list

In [103]:
# Note: only suitable for the Lab colour model; under mode 'rgb' some RGB information would be lost.
def analysis(image_path, model=None, squeeze_fac=0.4):
    """Segment the filled wells in one RGB plate image and return per-well Lab statistics.

    Pipeline: inference -> shrink each blob to an ellipse -> split into
    instance masks -> spatial sort/grouping -> Lab mean/std per well.

    Args:
        image_path: path of the image file.
        model: an already loaded segmentation network. ``None`` (default)
            loads the U-Net weights from ./weights/ on every call; pass a
            model to avoid reloading it for each image.
        squeeze_fac: ellipse shrink factor forwarded to ``squeeze_mask``.

    Returns:
        (mean, std): numpy arrays of shape (n_groups, wells_per_group, 3).
        Axis 0 runs over groups of wells ordered left-to-right (columns);
        axis 1 runs over wells within a group ordered top-to-bottom (rows);
        axis 2 is L, a, b. If the groups have unequal lengths the lists are
        ragged, and ``np.array`` then fails (numpy >= 1.24) or gives an
        object array.
    """
    if model is None:
        model = get_trained_unet(model_name='unet_vgg11_upsampling',
                                 path='./weights/',
                                 params='unet_params_104_lr_1e-05_h2o2.pt')

    img, mask = get_inference(model, image_path)

    _, squeezed_mask = squeeze_mask(img, mask, squeeze_fac)

    instance_mask, col_list = get_instance_masks(img, squeezed_mask)

    im_sorted, _ = sort_instance_masks(instance_mask, col_list, img)

    color_list, err_list = get_colors_from_patches(img, im_sorted, 'lab')

    return np.array(color_list), np.array(err_list)

    

In [97]:
# mostly operation cell for chemists
import os
import glob
from pathlib import Path

# Folder containing the plate photographs: edit this for each experiment.
IMAGE_DIR = Path(r'C:/Users/scrc112/Desktop/work/yuan/20240718/first_trial')

# Sorted so that the processing order is reproducible between runs.
image_paths = sorted(IMAGE_DIR.glob('*.jpg'))

images = []         # raw BGR arrays as returned by cv2.imread
loaded_paths = []   # paths matching `images`, index for index
for path in image_paths:
    img = cv2.imread(str(path))
    if img is None:  # cv2.imread returns None for missing/unreadable files
        print('skipping unreadable image: {}'.format(path))
        continue
    images.append(img)
    loaded_paths.append(path)

    # One figure per image; a single shared imshow would only show the last one.
    fig, ax = plt.subplots()
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    ax.set_title(path.name)
    ax.axis('off')
    plt.show()

# Per-image results keyed by path: mean[x, y], x = column, y = row; std has the same layout.
results = {}
for path in loaded_paths:
    mean, std = analysis(path)
    results[path] = (mean, std)

    ### more code, e.g.
    # print(mean[7, 1:4])
    # print(mean[3, 0:4])
    # a, b, c, d = mean[5, 0], mean[5, 2], mean[5, 4], mean[5, 6]
    # new_array = (a + b + c + d) / 4
    # print(new_array[2])

### save results, e.g. CSV