# Mask R-CNN
The original Mask R-CNN implementation is too old to run on a system with modern CUDA drivers etc. (even some forks which purport to have updated it for TensorFlow 2.x are not working for me), so I opted to use the `torchvision` version. The goal here is to make it as easy as possible for someone to reproduce this work, and having to deal with the nightmare that is out-of-date TensorFlow and CUDA is not tenable. 

This script adapts a couple key functions (like the data loader and remove_overlaps) from the original train_maskrcnn script from Cellpose. 

In [None]:
# Notebook setup: enable the autoreload extension, configure matplotlib for
# inline dark-themed figures, and set an environment variable before the
# heavier imports in the following cells.
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
plt.style.use('dark_background')  # dark theme for every figure below
import matplotlib as mpl
%matplotlib inline
mpl.rcParams['figure.dpi'] = 300  # high-resolution inline figures
import time, os, sys
from tifffile import imread

# os.environ['MKL_DISABLE_FAST_MM'] = '0'
# NOTE(review): presumably a workaround that limits a PyTorch CPU-side cache
# to curb memory growth during inference - confirm it is still needed.
os.environ['LRU_CACHE_CAPACITY'] = '1'

### Define data loader

In [None]:
import os
import numpy as np
import torch
import torch.utils.data
from PIL import Image
from cellpose import io, transforms
from omnipose.utils import format_labels
import skimage.io
from tifffile import imread
import omnipose

from pathlib import Path
def getname(path,suffix='_masks'):
    """Return the base name of `path` without its final extension and with
    every occurrence of `suffix` removed.

    Used as the sort key that pairs image files with their mask files.
    """
    filename = Path(path).name
    stem, _extension = os.path.splitext(filename)
    return stem.replace(suffix,'')

class BacteriaDataset(torch.utils.data.Dataset):
    """Images paired with integer-labelled instance masks, served as
    (image, target) samples in the torchvision detection format.

    Args:
        root: directory searched (one level down as well) for image files
            and their ``_masks`` label files.
        transforms: optional callable ``(img, target) -> (img, target)``
            applied to every sample.
    """

    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        # Collect image files and their label files, then sort both lists by
        # base name (mask suffix stripped) so that index i of one list
        # corresponds to index i of the other.
        mask_filter = '_masks'
        img_filter = ''
        img_names = io.get_image_files(root,mask_filter,img_filter=img_filter,look_one_level_down=True)
        mask_names = io.get_label_files(img_names, mask_filter, img_filter=img_filter)
        self.imgs = sorted(img_names,key=getname)
        self.masks = sorted(mask_names,key=getname)

    @staticmethod
    def _build_target(mask, idx):
        """Convert an integer label image into a detection target dict.

        Args:
            mask: label image in which 0 is background and every other value
                marks one object (assumes a 2-D array - TODO confirm for
                other inputs).
            idx: dataset index, stored as ``image_id``.

        Returns:
            dict with ``boxes`` (N, 4) float32 as [xmin, ymin, xmax, ymax],
            ``labels`` (N,) int64 ones, ``masks`` (N, H, W) bool,
            ``image_id``, ``area`` (N,) and ``iscrowd`` (N,) int64 zeros.
            Objects whose box has zero width or height are reported and left
            out of *every* field, so all per-object tensors stay aligned.
        """
        mask = np.asarray(mask)

        # Instances are encoded as different numbers; 0 is the background.
        # Filtering on the value (rather than dropping the first unique id)
        # keeps every object when an image happens to contain no background.
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]

        # Split the integer-encoded mask into a stack of binary masks.
        masks = mask == obj_ids[:, None, None]

        # Bounding box of each object; remember which objects are usable.
        boxes = []
        keep = []
        for i, obj_mask in enumerate(masks):
            ys, xs = np.nonzero(obj_mask)
            xmin, xmax = xs.min(), xs.max()
            ymin, ymax = ys.min(), ys.max()
            if xmax != xmin and ymax != ymin:
                boxes.append([xmin, ymin, xmax, ymax])
                keep.append(i)
            else:
                print('uh oh',idx,obj_ids[i])

        # Drop the masks of rejected objects so masks/labels/iscrowd have the
        # same length as boxes.
        masks = masks[np.asarray(keep, dtype=np.intp)]
        num_objs = len(keep)

        # reshape(-1, 4) keeps a well-formed (0, 4) tensor when nothing is
        # left, so the area computation below cannot fail on empty images.
        boxes = torch.as_tensor(np.array(boxes, dtype=np.float32).reshape(-1, 4))
        # There is only one foreground class.
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.bool)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # Suppose all instances are not crowd.
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        return {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

    def __getitem__(self, idx):
        """Load sample `idx` and return ``(img, target)``.

        The image is normalized, scaled to 8 bit and replicated into three
        channels before being wrapped in a PIL image; `self.transforms`, when
        set, is applied to the pair last.
        """
        # NOTE(review): if io.get_image_files already returns paths that
        # contain `root`, this join is only correct when they are absolute
        # (os.path.join then discards `root`) - confirm for relative roots.
        img_path = os.path.join(self.root, self.imgs[idx])
        mask_path = os.path.join(self.root, self.masks[idx])

        img = skimage.io.imread(img_path)
        # Clipping guards against uint8 wrap-around in case normalize99
        # returns values outside [0, 1]; it is a no-op otherwise.
        scaled = np.clip(omnipose.utils.normalize99(img), 0, 1) * (2**8 - 1)
        img = np.stack([scaled] * 3, axis=-1).astype(np.uint8)
        img = Image.fromarray(img)  # PIL needs uint8 data here

        mask = np.array(imread(mask_path))
        target = self._build_target(mask, idx)

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        """Number of image files found under `root`."""
        return len(self.imgs)

### Define model

In [None]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

def get_instance_segmentation_model(num_classes, hidden_layer=256,
                                    detections_per_img=1000, nms_thresh=0.7):
    """Build a torchvision Mask R-CNN (ResNet-50 FPN) with heads sized for
    `num_classes`.

    Only the backbone is requested pre-trained (``pretrained_backbone=True``);
    the detection model itself is not (``pretrained=False``), so the box and
    mask heads are trained from scratch.

    Args:
        num_classes: number of output classes, background included.
        hidden_layer: channel width of the mask predictor's hidden layer.
        detections_per_img: value assigned to ``roi_heads.detections_per_img``.
        nms_thresh: value assigned to ``roi_heads.nms_thresh``.

    Returns:
        The configured model.
    """
    # NOTE(review): newer torchvision releases replace `pretrained` /
    # `pretrained_backbone` with `weights` / `weights_backbone`; the legacy
    # keywords are kept here for compatibility with the version this
    # notebook was written against.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False,pretrained_backbone=True)

    # Swap the box head for one with the requested number of classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Likewise for the mask head.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    # Detection settings on the ROI heads (defaults match the original
    # hard-coded values).
    model.roi_heads.detections_per_img = detections_per_img
    model.roi_heads.nms_thresh = nms_thresh
    return model

### Define training augmentations
Notably, only random flipping is implemented here, as I believe was the case for the original Mask R-CNN TensorFlow implementation. 

In [None]:
# The vision/references/detection/ folder needs to be available in the same directory as this notebook

from engine import train_one_epoch, evaluate
import utils
import transforms as T

def get_transform(train):
    """Compose the sample transforms.

    Every sample is converted from a PIL image to a tensor; when `train` is
    truthy a random horizontal flip (probability 0.5) of image and target is
    appended as the only augmentation.
    """
    steps = [T.ToTensor()]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)

### Initialize training and test datasets

In [None]:
# a higher worker count raised an error here, so load in the main process
num_workers = 0

# dataset locations (the commented-out pair is an alternative dataset)
traindir = '/home/kcutler/DataDrive/omnipose_all/phase/train_sorted'
testdir = '/home/kcutler/DataDrive/omnipose_all/phase/test_sorted'
# traindir = '/home/kcutler/DataDrive/omnipose_maskrcnn/train'
# testdir = '/home/kcutler/DataDrive/omnipose_maskrcnn/test'

# training samples get the flip augmentation, test samples do not
dataset = BacteriaDataset(traindir, transforms=get_transform(train=True))
dataset_test = BacteriaDataset(testdir, transforms=get_transform(train=False))

# alternative: carve a random train/test split out of one dataset
# torch.manual_seed(1)
# indices = torch.randperm(len(dataset)).tolist()
# dataset = torch.utils.data.Subset(dataset, indices[:-50])
# dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

# options common to both loaders
_loader_options = dict(num_workers=num_workers, collate_fn=utils.collate_fn)

# shuffled batches of two for training
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, **_loader_options)

# one image at a time, in order, for testing
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, **_loader_options)

In [None]:
# Sanity check: fetch a couple of training samples and display the first
# channel of one image. Only the target `t` of sample 52 is kept; `im` is
# taken from sample 2.
im, t = dataset[52]
im = dataset[2][0]
plt.imshow(im[0])
plt.axis('off')

### Define model and training parameters

In [None]:
# run on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# two classes: background plus the single foreground (cell) class
num_classes = 2

# build the model with the helper and move it to the chosen device
model = get_instance_segmentation_model(num_classes)
model.to(device)

# SGD over every trainable parameter
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# multiply the learning rate by 0.1 every 3 scheduler steps
# NOTE(review): stepped once per epoch over a long run this makes the
# learning rate vanishingly small early on - confirm that is intended.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

### Train
Using recommended/default parameters. 

In [None]:
# Training switch: a falsy value skips training (and the save in the next
# cell) so an existing model file can be loaded instead.
clean = 0
if clean:
    num_epochs = 200
    for epoch in range(num_epochs):
        # train for one epoch, printing every 20 iterations
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=20)
        # update the learning rate
        lr_scheduler.step()
        # evaluation on the test dataset is disabled
    #     evaluate(model, data_loader_test, device=device)

### Save model

In [None]:
# File the trained model is written to (and later loaded from).
modeldir = '/home/kcutler/DataDrive/maskrcnn/bacterialtrain200epochs_1000detections_per_img_v2'
if clean:
    # Saves the whole model object, not just a state_dict; later cells load
    # it back with torch.load, so the two must stay in sync.
    torch.save(model, modeldir)

### Evaluate model
Test it out on a single image first

In [None]:
# device =  torch.device('cpu')
# Load the saved model onto the current device. map_location remaps the
# stored tensors, so a model saved on a GPU can also be loaded on a machine
# without CUDA.
# NOTE: the file is a whole pickled model - only load files you created
# yourself. Recent PyTorch releases default torch.load to weights_only=True,
# which rejects such files; pass weights_only=False there if needed.
model = torch.load(modeldir, map_location=device).to(device)

# pick one image from the test set
img, _ = dataset_test[0]
# put the model in evaluation mode
model.eval()
# model.roi_heads.detections_per_img = 1000 #not needed
# an IoU threshold of 1 effectively disables box non-maximum suppression, so
# overlapping detections are kept for the mask reconstruction further below
model.roi_heads.nms_thresh = 1
# model.roi_heads.score_thresh
with torch.no_grad():
    prediction = model([img.to(device)])
    

### Batch process and reconstruct cell masks
Mask R-CNN predicts bounding boxes and the cell probability within those bounding boxes, along with a confidence score from 0 to 1. The output is sorted from high to low scores. My first approach was the basic one: loop over all masks and append them to a label matrix with incrementing labels. I optimized this a little bit by appending the highest scores last, thereby overwriting low-confidence labels with the higher-confidence ones, but there are a lot of issues with this. After much fiddling, I came up with the following mask reconstruction algorithm: 

1. Use hysteresis thresholding on cell probability to get a candidate mask
2. Check to see if it overlaps with any cells 50% or more and is above a minimum area; if so, just add to that existing mask.
3. Otherwise, set its pixels (any that don't overlap with existing masks, which are at higher confidence and should not be overwritten) to a new label value.

Comparing the raw summed mask output to my generated masks, I think this gives us pretty much all we can get from the really poor Mask R-CNN output. Note that for some of the really large 2k×2k images, the RAM usage is obscene - above 40GB. In a previous iteration, I never noticed this because I was running on a machine with 128GB of RAM. This resource usage comes a little from the neural network prediction on CPU (not enough VRAM to do it on GPU for those images either) but mostly from the intermediate post-processing steps. 

In [None]:
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Batch processing runs on the CPU. map_location makes torch.load place the
# stored tensors on that device directly, so a model that was saved on a GPU
# loads even when no CUDA device is available.
device =  torch.device('cpu')
model = torch.load(modeldir, map_location=device).to(device)

In [None]:
import fastremap 

# folder that receives the reconstructed label images
savedir = '/home/kcutler/DataDrive/omnipose_all/bact_phase_comparison/maskrcnn'
io.check_dir(savedir)

# one result slot per test image; each starts out as an empty list and is
# replaced by a label array once that image has been processed
nimg = len(dataset_test)
final_masks = [[] for _ in range(nimg)]


In [None]:
from skimage import filters
from cellpose.metrics import _label_overlap

# Batch-process the test images: run the model on each image, rebuild a
# single label image from the per-detection soft masks, and save it.

min_area = 100       # candidates must exceed this pixel count to start a new label on low overlap
mask_threshold = .8  # upper hysteresis threshold; the lower one is mask_threshold-.25

# First test-set index to process. Kept at 56 as in the original cell (whose
# full-range loop was commented out); set to 0 to process the whole test set.
start_index = 56

# Model configuration does not depend on the image, so do it once.
# put the model in evaluation mode
model.eval()
# an IoU threshold of 1 effectively disables box non-maximum suppression
model.roi_heads.nms_thresh = 1
# model.roi_heads.score_thresh

for i in range(start_index, nimg):
    entry, path = dataset_test[i], dataset_test.masks[i]
    print(getname(path))
    img = entry[0]
    with torch.no_grad():
        prediction = model([img.to(device)])

    labels = np.zeros(img.shape[-2:],dtype=np.uint32)

    scores = prediction[0]['scores'].detach().cpu().numpy() # outputs in descending order

    # Keep detections scoring strictly above the 25th percentile. With no
    # detections there is nothing to take a percentile of, so skip straight
    # to saving an empty label image.
    if scores.size:
        cutoff = np.percentile(scores,25)
        inds = np.where(scores>cutoff)[0]
    else:
        inds = np.array([], dtype=np.intp)

    l = 1
    for j in inds:
        pred = np.array((prediction[0]['masks'][j, 0].cpu()))
        # candidate mask from hysteresis thresholding of the soft mask
        m = filters.apply_hysteresis_threshold(pred, mask_threshold-.25, mask_threshold)
        if not np.any(labels):
            # nothing has been placed yet: the candidate becomes the first label
            labels[m] = l
        elif np.any(m):
            # overlap[k, 1] is the pixel count shared by label k and the candidate
            overlap = _label_overlap(labels,np.uint(m))
            match = np.argmax(overlap[:,1])
            area = np.sum(m)
            pix = np.logical_and(m,labels==0) #only add to areas where there are no cell pixels yet
            # Parentheses spell out the precedence Python already applies
            # (`and` binds tighter than `or`): a new label starts when the
            # candidate mostly covers background, or when it is larger than
            # min_area and shares under half its area with the best match.
            starts_new_label = match==0 or (overlap[match,1]/area < 0.5 and area>min_area)
            if starts_new_label:
                l+=1 # only increment if there is not significant overlap
                labels[pix] = l
            else:
                # otherwise extend the best-matching existing label
                labels[pix] = match

    # del prediction
    # shrink the array to the smallest dtype that holds the label values
    labels = fastremap.refit(labels)

    final_masks[i] = labels
    io.imsave(os.path.join(savedir,getname(path)+'_masks.tif'),labels)
    print(i,'{}% done'.format((i+1)/nimg*100))

In [None]:
import ncolor
# Display the most recent label image after passing it through ncolor.label
# (presumably so that neighbouring cells get distinct colours - confirm).
plt.imshow(ncolor.label(labels,max_depth=20),interpolation='None')
# plt.imshow(labels)
# Cell output: the largest label value in `labels`.
np.max(labels)

In [None]:
# Comparison with the overlap-removal approach: binarize each selected soft
# mask at 0.9 and stack them into one boolean array (one slice per detection).
overlapping_masks_test = np.stack([np.array((prediction[0]['masks'][j, 0].cpu())>0.9) for j in inds],axis=0)
print(len(overlapping_masks_test),overlapping_masks_test.shape)

# Centre of every binary mask as (mean row, mean column). Despite the name
# these are means, not medians; an all-False mask yields NaN here.
medians_test = []
for mask in overlapping_masks_test:

    ypix, xpix = np.nonzero(mask)
    medians_test.append((np.array([ypix.mean(), xpix.mean()])))

# NOTE(review): `remove_overlaps` is not defined in the visible part of this
# notebook - it must be defined or imported elsewhere before this cell runs.
labels_test = np.int32(remove_overlaps(overlapping_masks_test, overlapping_masks_test.sum(axis=0), np.array(medians_test)))


In [None]:
# Raw network output for reference: the sum of all selected soft masks.
plt.imshow(np.stack([np.array((prediction[0]['masks'][j, 0].cpu())) for j in inds],axis=0).sum(axis=0))

In [None]:
# new algo: loop through all, generate hysteresis mask, compare overlap. If the overlap is high, add to existing label. 
# If overlap is low, make new label. 

from skimage import filters
from cellpose.metrics import _label_overlap
# start from an empty label image the size of the current image
labels = np.zeros(img.shape[-2:],dtype=np.uint32)


# for j in range(0,10):
# for j in [28]:

# keep only the detections scoring strictly above the median
cutoff = np.percentile(scores,50)
inds = np.where(scores>cutoff)[0]
min_area = 100
mask_threshold = .8
l = 1
for j in inds:
    pred = np.array((prediction[0]['masks'][j, 0].cpu()))
    # candidate mask from hysteresis thresholding of the soft mask
    m = filters.apply_hysteresis_threshold(pred, mask_threshold-.25, mask_threshold)

    if not np.any(labels):
        # nothing placed yet: the candidate becomes the first label
        labels[m] = l
        print('adding first label')
        continue
    if not np.any(m):
        # empty candidate, nothing to add
        continue

    overlap = _label_overlap(labels,np.uint(m))
    match = np.argmax(overlap[:,1])
    area = np.sum(m)
    pix = np.logical_and(m,labels==0)
    # same precedence as the unparenthesized form: `and` binds before `or`
    starts_new_label = match==0 or (overlap[match,1]/area < 0.5 and area>min_area)
    if starts_new_label:
        l+=1 # only increment if there is not significant overlap
        labels[pix] = l
    else:
        labels[pix] = match
    # fig = plt.figure(figsize=(2,2))
    # plt.imshow(np.hstack((labels,m,pred)),interpolation='none')
    # plt.axis('off')
    # plt.show()

In [None]:
# Display the label image built in the previous cell; the commented lines
# are alternative views (a single soft mask, the input image, ncolor output).
# plt.imshow(np.array((prediction[0]['masks'][1, 0].cpu())))
# plt.imshow(img.numpy().transpose(1,2,0))
# plt.imshow(ncolor.label(labels,max_depth=20))
plt.imshow(labels)

In [None]:
# Display the current input image, moving the leading axis of the tensor to
# the end, as matplotlib expects for colour images.
plt.imshow(img.numpy().transpose(1,2,0))


In [None]:
# Browse the reconstructed masks stored at indices 59-99. Slicing stops at
# the end of the list instead of raising IndexError when there are fewer
# than 100 test images, and np.ndim also copes with the empty-list
# placeholders of images that were never processed (they are skipped).
for mask in final_masks[59:100]:
    if np.ndim(mask)==2:
        plt.imshow(mask)
        plt.axis('off')
        plt.show()

In [None]:
# Cell output: the stored result for test image 60.
final_masks[60]

In [None]:
# Cell output: the raw prediction for the most recently processed image.
prediction[0]