In [None]:
!pip install ../input/detectron-05/whls/pycocotools-2.0.2/dist/pycocotools-2.0.2.tar --no-index --find-links ../input/detectron-05/whls 
!pip install ../input/detectron-05/whls/fvcore-0.1.5.post20211019/fvcore-0.1.5.post20211019 --no-index --find-links ../input/detectron-05/whls 
!pip install ../input/detectron-05/whls/antlr4-python3-runtime-4.8/antlr4-python3-runtime-4.8 --no-index --find-links ../input/detectron-05/whls 
!pip install ../input/detectron-05/whls/detectron2-0.5/detectron2 --no-index --find-links ../input/detectron-05/whls 

In [None]:
import detectron2
import torch
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.config import CfgNode as CN
from detectron2.modeling import build_model, DatasetMapperTTA, detector_postprocess
from detectron2.modeling import GeneralizedRCNNWithTTA as _GeneralizedRCNNWithTTA
from detectron2.data.detection_utils import read_image
from detectron2.data.transforms import (
    RandomFlip,
    ResizeShortestEdge,
    ResizeTransform,
    apply_augmentations,
)
from detectron2.structures import ImageList, Instances, Boxes
from detectron2.checkpoint import DetectionCheckpointer
from fvcore.transforms import HFlipTransform, NoOpTransform
from PIL import Image
import cv2
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from fastcore.all import *
from torch import nn
import copy
from itertools import count
import warnings
warnings.simplefilter('ignore')
import gc

In [None]:
import sys
sys.path.append('../input/ensemble-boxes/')
sys.path.append('../input/sartorius-utils/')
sys.path.append('../input/yolov5-v6/yolov5-master/')
from ensemble_boxes import weighted_boxes_fusion, weighted_masks_fusion
from postprocess import detector_postprocess
from models.common import DetectMultiBackend
from utils.general import non_max_suppression

In [None]:
# Root folder of the competition data; the test images are read from its `test` sub-folder.
dataDir=Path('../input/sartorius-cell-instance-segmentation')

In [None]:
# From https://www.kaggle.com/stainsby/fast-tested-rle
def rle_decode(mask_rle, shape=(520, 704)):
    '''
    Turn a run-length string into a binary mask.

    mask_rle: space-separated "start length start length ..." string, where
              each start is a 1-based offset into the flattened mask
    shape: (height, width) of the array to return
    Returns a uint8 numpy array of that shape: 1 inside a run, 0 elsewhere
    '''
    tokens = mask_rle.split()
    # Even positions hold the 1-based starts, odd positions the run lengths.
    first = np.asarray(tokens[0::2], dtype=int) - 1
    run = np.asarray(tokens[1::2], dtype=int)

    flat = np.zeros(shape[0] * shape[1], dtype=np.uint8)
    for begin, length in zip(first, run):
        flat[begin:begin + length] = 1
    # Runs were written on the flattened array, so fold it back to 2-D last.
    return flat.reshape(shape)

def rle_encode(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background
    Returns a space-separated "start length ..." string with 1-based starts
    computed on the flattened array
    '''
    # Pad with zeros so a run touching either end still produces a transition.
    padded = np.concatenate([[0], img.flatten(), [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Every second edge is the end of a run: turn it into a length.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

#predict one image.
def one_predictor(im, predictor, yolo_model, pred_class):
    """Run one TTA predictor (seeded with YOLO boxes) on a single image.

    Args:
        im: HWC image array; it is transposed to CHW before being handed to
            the predictor.
        predictor: callable taking ``(batched_inputs, yolo_boxes)`` and
            returning a list of dicts with an ``'instances'`` entry.
        yolo_model: model passed to ``yolo_get_boxes`` for this image.
        pred_class: class index used to pick the score threshold from
            ``WMF_THRESHOLDS``.

    Returns:
        ``(pred_masks, scores, pred_boxes)`` as numpy arrays, restricted to
        instances whose score is strictly above the threshold. For
        ``pred_class == 1`` the masks are binarised at 0.5.
    """
    boxes_from_yolo = yolo_get_boxes(im, yolo_model, pred_class)
    chw = torch.from_numpy(np.ascontiguousarray(im.transpose(2, 0, 1)))
    # The predictor is always told the output size is 520x704.
    batch = [{'image': chw, 'height': 520, 'width': 704}]

    predictor.eval()
    with torch.no_grad():
        instances = predictor(batch, boxes_from_yolo)[0]['instances']

    scores = instances.scores.cpu().numpy()
    pred_masks = instances.pred_masks.cpu().numpy()
    pred_boxes = instances.pred_boxes.tensor.cpu().numpy()

    # One boolean selector shared by masks, boxes and scores.
    keep = scores > WMF_THRESHOLDS[pred_class]
    pred_masks, pred_boxes, scores = pred_masks[keep], pred_boxes[keep], scores[keep]

    if pred_class == 1:
        pred_masks = pred_masks >= 0.5
    return pred_masks, scores, pred_boxes

#merge predictions of each models with WMF
def merge_predict(im, predictors, yolo_models, pred_class):
    """Predict with every model of an ensemble and fuse the results with WMF.

    Predictor ``k`` is paired with ``yolo_models[k % len(yolo_models)]``, so
    the YOLO list is cycled when there are more predictors than YOLO models.

    Args:
        im: HWC image array passed to ``one_predictor``.
        predictors: non-empty sequence of TTA predictors.
        yolo_models: non-empty sequence of YOLO models.
        pred_class: class index; selects ``MODEL_WEIGHTS[pred_class]`` and is
            forwarded to ``one_predictor``.

    Returns:
        ``(pred_masks, scores, pred_boxes)`` numpy arrays produced by
        ``weighted_masks_fusion``, sorted by descending score.

    Raises:
        ValueError: if ``predictors`` is empty.
    """
    if len(predictors) == 0:
        raise ValueError("merge_predict needs at least one predictor")

    # Collect per-model outputs and concatenate once at the end instead of
    # re-copying the growing mask stack after every model.
    mask_parts, score_parts, box_parts, model_ids = [], [], [], []
    n_yolo = len(yolo_models)
    for idx, predictor in enumerate(predictors):
        pm, sc, pb = one_predictor(im, predictor, yolo_models[idx % n_yolo], pred_class)
        mask_parts.append(pm)
        score_parts.append(sc)
        box_parts.append(pb)
        # Remember which model produced each instance (used as model index by WMF).
        model_ids.extend([idx] * len(sc))

    pred_masks = np.concatenate(mask_parts, axis=0)
    scores = np.concatenate(score_parts, axis=0)
    pred_boxes = np.concatenate(box_parts, axis=0)

    sort_idx = np.argsort(-scores)
    pred_masks = pred_masks[sort_idx]
    pred_boxes = pred_boxes[sort_idx]
    models = np.array(model_ids)[sort_idx]
    scores = scores[sort_idx]

    # ensemble models using WMF. WMF is WBF applied to mask ensemble.
    pred_masks, scores, pred_boxes = weighted_masks_fusion(pred_masks, pred_boxes, scores, models,
                                               skip_mask_thr=0,
                                               conf_type='model_weight',
                                               soft_weight=np.sum(MODEL_WEIGHTS[pred_class]),
                                               num_models=len(predictors),
                                               model_weights = MODEL_WEIGHTS[pred_class])
    pred_masks = np.array(pred_masks)
    pred_boxes = np.array(pred_boxes)
    scores = np.array(scores)

    # Fusion may change the scores, so order by score again.
    sort_idx = np.argsort(-scores)
    pred_masks = pred_masks[sort_idx]
    pred_boxes = pred_boxes[sort_idx]
    scores = scores[sort_idx]

    return pred_masks, scores, pred_boxes

# If the image is shsy5y, split the image and predict each images.
def second_predict(im, predictors_shsy5y, yolo_models):
    """Tiled prediction for shsy5y images.

    The image is upscaled 2x and covered by a 3x3 grid of overlapping
    520x704 tiles (stride 260 rows / 352 columns). Each tile is predicted
    with ``merge_predict`` (class 0). An instance is discarded when its mask
    touches a 5-pixel band along a tile edge that lies inside the upscaled
    image; edges that coincide with the real image border are not banned.
    Surviving masks are pasted onto a full-size canvas, scaled back to the
    tile size and fused with ``weighted_masks_fusion``.

    Args:
        im: HWC image array. The tile arithmetic assumes 520x704.
        predictors_shsy5y: predictors forwarded to ``merge_predict``.
        yolo_models: YOLO models forwarded to ``merge_predict``.

    Returns:
        ``(pred_masks, scores)`` numpy arrays sorted by descending score.
        When no instance survives, an empty ``(0, H, W)`` mask array and an
        empty score array are returned.
    """
    scaled_im = cv2.resize(im, (im.shape[1]*2, im.shape[0]*2))
    mask_scales = []
    box_scales = []
    scores_all = []

    for i in range(3):
        for j in range(3):
            # Build the border band from scratch for every tile. Reusing one
            # array across the columns of a row leaves the right-hand band set
            # for the last column, which wrongly drops instances touching the
            # true right image border.
            del_mask = np.zeros((520, 704), dtype=np.uint8)
            if i < 2:
                del_mask[-5:, :] = 1   # bottom tile edge is inside the image
            if i > 0:
                del_mask[:5, :] = 1    # top tile edge is inside the image
            if j < 2:
                del_mask[:, -5:] = 1   # right tile edge is inside the image
            if j > 0:
                del_mask[:, :5] = 1    # left tile edge is inside the image

            img = scaled_im[260*i:260*(i+2), 352*j:352*(j+2)]
            pred_masks, scores, pred_boxes = merge_predict(img, predictors_shsy5y, yolo_models, 0)
            keep = scores >= THRESHOLDS[0]
            pred_masks = pred_masks[keep]
            pred_boxes = pred_boxes[keep]
            scores = scores[keep]

            for mask, score, box in zip(pred_masks, scores, pred_boxes):
                # Instances cut by an interior tile edge are left to the
                # neighbouring tile that contains them.
                if np.sum(mask * del_mask) > 0:
                    continue
                # Paste the tile mask at its position in the 2x canvas, then
                # shrink the canvas back to the tile (= original) size.
                _mask_scale = np.zeros((img.shape[0] * 2, img.shape[1] * 2))
                _mask_scale[260*i:260*(i+2), 352*j:352*(j+2)] = mask
                _mask_scale = cv2.resize(_mask_scale, (img.shape[1], img.shape[0]))

                # Shift the box by the tile offset and undo the 2x upscale.
                box[[0, 2]] = (box[[0, 2]] + 352 * j) // 2
                box[[1, 3]] = (box[[1, 3]] + 260 * i) // 2
                mask_scales.append(_mask_scale)
                box_scales.append(box)
                scores_all.append(score)

    if not mask_scales:
        # Nothing survived in any tile: return well-formed empty results
        # instead of handing empty arrays to the fusion routine.
        return np.zeros((0,) + tuple(im.shape[:2])), np.zeros(0)

    #merge predicts with WMF
    pred_masks, scores, _ = weighted_masks_fusion(np.array(mask_scales), np.array(box_scales),
                                                  np.array(scores_all), np.zeros(len(scores_all)),
                                                  skip_mask_thr=0,
                                                  conf_type='max')
    pred_masks = np.array(pred_masks)
    scores = np.array(scores)

    sort_idx = np.argsort(-scores)
    pred_masks = pred_masks[sort_idx]
    scores = scores[sort_idx]

    return pred_masks, scores

def get_masks(fn, predictors, yolo_models, pred_class):
    """Predict, clean up and RLE-encode the instance masks of one image.

    Args:
        fn: path of the image; read with ``cv2.imread``.
        predictors: ensemble forwarded to the prediction routine.
        yolo_models: YOLO models forwarded to the prediction routine.
        pred_class: 0 uses the tiled ``second_predict``; 1 uses
            ``merge_predict`` as class 1; any other value is treated as
            class 2.

    Returns:
        List of RLE strings, one per kept instance, in descending score
        order. The kept masks do not overlap each other.
    """
    im = cv2.imread(str(fn))

    if pred_class == 0:
        pred_class = 0
        pred_masks, score = second_predict(im, predictors, yolo_models)
    else:
        pred_class = 1 if pred_class == 1 else 2
        pred_masks, score, _ = merge_predict(im, predictors, yolo_models, pred_class)

    # Binarise the masks, then drop low-score instances.
    confident = score >= THRESHOLDS[pred_class]
    pred_masks = (pred_masks >= MASK_THRESHOLDS[pred_class])[confident]
    score = score[confident]

    # `occupied` marks pixels already claimed by a higher-scoring instance.
    occupied = np.zeros(im.shape[:2], dtype=int)
    encoded = []
    for full_mask, _unused_score in zip(pred_masks, score):
        visible = full_mask * (1 - occupied)
        big_enough = visible.sum() >= MIN_PIXELS[pred_class]
        # Keep only if enough pixels remain AND enough of the original
        # mask is still unclaimed.
        if big_enough and np.sum(visible) / np.sum(full_mask) >= DUPL_THRESHOLDS[pred_class]:
            occupied += visible
            encoded.append(rle_encode(visible))

    return encoded

#classify a image
def classify_image(fn, predictor):
    """Return the most frequent predicted class among an image's instances.

    Args:
        fn: path of the image; read with ``cv2.imread``.
        predictor: callable taking the image and returning a dict with an
            ``'instances'`` entry exposing ``pred_classes``.

    Returns:
        The mode of ``pred_classes`` as a Python int.
    """
    # NOTE(review): if the predictor finds no instance, `pred_classes` is
    # empty - confirm how torch.mode behaves in that case.
    instances = predictor(cv2.imread(str(fn)))['instances']
    most_common, _ = torch.mode(instances.pred_classes)
    return int(most_common.cpu())

#get boxes predicted by yolov5x
def yolo_get_boxes(im, model, pred_class, img_sz=640, augment=[True, False, True], 
                   weight=[0.95, 0.9, 0.85]):
    """Detect boxes with a YOLO model and return them in normalised form.

    Args:
        im: HWC image array; resized to ``img_sz`` x ``img_sz`` before inference.
        model: YOLO model exposing ``warmup`` and a call taking ``augment``.
        pred_class: index into ``augment`` and ``weight``.
        img_sz: square input size fed to the model.
        augment: per-class flag passed to the model as ``augment``.
        weight: per-class factor applied to the detection scores.

    Returns:
        Numpy array of the detections of the first (only) image after NMS.
        Columns 0-3 are divided by ``img_sz`` and column 4 is multiplied by
        ``weight[pred_class]``; remaining columns are returned unchanged.
    """
    with torch.no_grad():
        # NOTE(review): `imgsz` nests a tuple, (1, 3, (img_sz, img_sz)), and
        # warmup is invoked on every call - confirm against
        # DetectMultiBackend.warmup that this is what is intended.
        model.warmup(imgsz=(1, 3, (img_sz, img_sz)), half=False)
        resized = cv2.resize(im, (img_sz, img_sz))
        # HWC -> 1xCxHxW, moved to the notebook-level `device`, scaled to [0, 1].
        batch = torch.from_numpy(resized.transpose((2, 0, 1))[np.newaxis]).to(device)
        batch = batch / 255
        raw_pred = model(batch, augment=augment[pred_class])
        dets = non_max_suppression(raw_pred, conf_thres=0.25, iou_thres=0.45, max_det=2000)

    bboxes = dets[0].cpu().numpy()
    bboxes[:, :4] = bboxes[:, :4] / img_sz
    bboxes[:, 4] = bboxes[:, 4] * weight[pred_class]
    return bboxes

class GeneralizedRCNNWithTTA(_GeneralizedRCNNWithTTA):
    """Test-time-augmentation wrapper that also fuses externally supplied boxes.

    Differences from the parent class that are visible in this file:
      * ``__call__`` takes an extra ``yolo_boxes`` array; its boxes are added
        to the boxes detected on the augmented inputs;
      * all boxes are merged with ``weighted_boxes_fusion`` in coordinates
        normalised by the output image size;
      * masks are post-processed by the ``detector_postprocess`` imported
        from the local ``postprocess`` module (the later import rebinds the
        detectron2 name), called with ``mask_threshold=-1``.
    """

    def __call__(self, batched_inputs, yolo_boxes):
        """
        Same input/output format as :meth:`GeneralizedRCNN.forward`, with one
        extra argument.

        Args:
            batched_inputs: list of dataset dicts (with "image" or "file_name").
            yolo_boxes: 2-D array; columns 0-3 are read as box coordinates,
                column 4 as score and column 5 as class. The same array is
                used for every element of ``batched_inputs``.
        """

        def _maybe_read_image(dataset_dict):
            ret = copy.copy(dataset_dict)
            if "image" not in ret:
                image = read_image(ret.pop("file_name"), self.model.input_format)
                image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))  # CHW
                ret["image"] = image
            if "height" not in ret and "width" not in ret:
                # Take the tensor from the dict: a local `image` only exists
                # when the file was read above, so relying on it raised
                # NameError for dicts that already carried "image".
                image = ret["image"]
                ret["height"] = image.shape[1]
                ret["width"] = image.shape[2]
            return ret

        return [self._inference_one_image(_maybe_read_image(x), yolo_boxes) for x in batched_inputs]

    def _inference_one_image(self, input, yolo_boxes):
        """
        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor
            yolo_boxes: extra detections appended to the TTA detections
                before fusion (see :meth:`__call__`)
        Returns:
            dict: one output dict
        """
        orig_shape = (input["height"], input["width"])
        augmented_inputs, tfms = self._get_augmented_inputs(input)
        # Remembered for _merge_detections, which builds tensors on this device.
        self.device = self.model.device
        # Detect boxes from all augmented versions
        with self._turn_off_roi_heads(["mask_on", "keypoint_on"]):
            # temporarily disable roi heads
            all_boxes, all_scores, all_classes = self._get_augmented_boxes(
                augmented_inputs, tfms, orig_shape
            )
        # Add the external detections as one more "model" for the fusion.
        all_boxes.append(yolo_boxes[:, :4])
        all_scores.append(yolo_boxes[:, 4].tolist())
        all_classes.append(yolo_boxes[:, 5].tolist())
        # merge all detected boxes to obtain final predictions for boxes
        merged_instances = self._merge_detections(all_boxes, all_scores, all_classes, orig_shape)

        if self.cfg.MODEL.MASK_ON:
            # Use the detected boxes to obtain masks
            augmented_instances = self._rescale_detected_boxes(
                augmented_inputs, merged_instances, tfms
            )
            # run forward on the detected boxes
            outputs = self._batch_inference(augmented_inputs, augmented_instances)
            # Delete now useless variables to avoid being out of memory
            del augmented_inputs, augmented_instances
            # average the predictions
            merged_instances.pred_masks = self._reduce_pred_masks(outputs, tfms)
            # NOTE(review): mask_threshold=-1 presumably keeps the masks
            # un-thresholded; confirm in the local `postprocess` module.
            merged_instances = detector_postprocess(merged_instances, *orig_shape, mask_threshold=-1)
            return {"instances": merged_instances}
        else:
            return {"instances": merged_instances}

    def _get_augmented_boxes(self, augmented_inputs, tfms, shape_hw=(520, 704)):
        """Run the detector on every augmented input and normalise its boxes.

        Args:
            augmented_inputs: augmented dataset dicts.
            tfms: transform applied to each augmented input; its inverse
                maps boxes back before normalisation.
            shape_hw: (height, width) used to normalise the boxes. Defaults
                to the 520x704 size this notebook works with.

        Returns:
            ``(all_boxes, all_scores, all_classes)``: one entry per
            augmented input.
        """
        # 1: forward with all augmented images
        outputs = self._batch_inference(augmented_inputs)
        # 2: union the results
        all_boxes = []
        all_scores = []
        all_classes = []
        height, width = shape_hw
        for output, tfm in zip(outputs, tfms):
            # Need to inverse the transforms on boxes, to obtain results on original image
            pred_boxes = output.pred_boxes.tensor
            pred_boxes = tfm.inverse().apply_box(pred_boxes.cpu().numpy())
            # Normalise x by width and y by height before box fusion.
            pred_boxes[:, [0, 2]] = pred_boxes[:, [0, 2]] / width
            pred_boxes[:, [1, 3]] = pred_boxes[:, [1, 3]] / height
            all_boxes.append(pred_boxes)

            all_scores.append(output.scores.tolist())
            all_classes.append(output.pred_classes.tolist())
        return all_boxes, all_scores, all_classes

    # merge detections with WBF
    def _merge_detections(self, all_boxes, all_scores, all_classes, shape_hw):
        """Fuse normalised boxes with WBF and wrap the result as ``Instances``.

        Args:
            all_boxes, all_scores, all_classes: per-source lists as built by
                ``_inference_one_image``.
            shape_hw: (height, width) of the output image; used to scale the
                fused boxes back to pixels and as the ``Instances`` size.

        Returns:
            ``Instances`` with ``pred_boxes``, ``scores`` and ``pred_classes``,
            truncated to ``cfg.TEST.DETECTIONS_PER_IMAGE`` entries.
        """
        boxes, scores, labels = weighted_boxes_fusion(all_boxes,
                                                     all_scores,
                                                     all_classes,
                                                     iou_thr=self.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST,
                                                     skip_box_thr=1e-8,
                                                     conf_type='max')
        keepk = self.cfg.TEST.DETECTIONS_PER_IMAGE
        boxes = boxes[:keepk, :]
        # Undo the normalisation applied in _get_augmented_boxes.
        boxes[:, [0, 2]] = boxes[:, [0, 2]] * shape_hw[1]
        boxes[:, [1, 3]] = boxes[:, [1, 3]] * shape_hw[0]
        scores = scores[:keepk]
        labels = labels[:keepk]
        result = Instances(shape_hw)
        result.pred_boxes = Boxes(torch.from_numpy(boxes).to(self.device))
        result.scores = torch.from_numpy(scores).to(self.device)
        result.pred_classes = torch.from_numpy(labels).to(self.device)
        return result

In [None]:
# Paths of everything inside <dataDir>/test. `.ls()` is not a pathlib method; presumably it is the
# Path extension that comes with `from fastcore.all import *` - confirm.
test_names = (dataDir/'test').ls()

### Initialize a predictor from our trained model

In [None]:
# Per-cell-type settings. Every list is indexed by the class id used throughout the
# notebook (0 = shsy5y, 1 = astro, 2 = cort, matching get_predictors).

# Minimum fused score for an instance to be kept (get_masks; entry 0 is also applied per tile in second_predict).
THRESHOLDS = [.65, .6, .8]
# Minimum single-model score for an instance to enter mask fusion (one_predictor).
WMF_THRESHOLDS = [.65, .6, .55]
# Minimum number of pixels a mask must keep after overlap removal (get_masks).
MIN_PIXELS = [75, 150, 75]
# Cut-off used to binarise the fused masks (get_masks).
MASK_THRESHOLDS = [.5, .5, .5]
# Minimum fraction of a mask that must remain after overlap removal (get_masks).
DUPL_THRESHOLDS = [.7, .7, .7]
# Per-model weights passed to weighted_masks_fusion; list lengths match the ensemble sizes (6, 10, 10).
MODEL_WEIGHTS = [[1] * 6, [2] * 5 + [1] * 5, [2] * 5 + [1] * 5]
#MODEL_WEIGHTS = [[1] * 5, [1] * 10, [1] * 5]
# Device used for the YOLO models and for the tensors fed to them.
device = 'cuda'

In [None]:
# Three-class Mask R-CNN whose predicted classes are used by classify_image
# to decide the cell type of each test image.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"))
cfg.INPUT.MASK_FORMAT = 'bitmask'
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 3
# Literal path, like every other loader cell. `os` is never imported
# explicitly in this notebook and was only in scope through a star import.
cfg.MODEL.WEIGHTS = '../input/sartorius-pretrained-model/best_detectron2_R101_FPN.pth'
cfg.TEST.DETECTIONS_PER_IMAGE = 1000
predictor = DefaultPredictor(cfg)

In [None]:
def get_shsy5y_model():
    """Build the shsy5y ensemble.

    Returns:
        ``(predictors_shsy5y, yolo_models_shsy5y)``: six TTA predictors (four
        fold models, one non-fold model, and one "1class" model built with
        extra anchor aspect ratios) and five YOLO models ordered fold1..fold4
        followed by the non-fold model.
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml"))
    cfg.INPUT.MASK_FORMAT = 'bitmask'
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
    cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.7
    cfg.TEST.DETECTIONS_PER_IMAGE = 1000
    cfg.TEST.AUG = CN({"ENABLED": True})
    cfg.TEST.AUG.MIN_SIZES = (640, 750, 860)
    cfg.TEST.AUG.MAX_SIZE = 1440
    cfg.TEST.AUG.FLIP = True

    def _load(weights_path):
        # Point the shared cfg at the checkpoint, build, load, wrap with TTA.
        cfg.MODEL.WEIGHTS = weights_path
        net = build_model(cfg)
        DetectionCheckpointer(net).load(cfg.MODEL.WEIGHTS)
        return GeneralizedRCNNWithTTA(cfg, net)

    default_anchor_weights = (
        '../input/sartorius-model-shsy5y-pseudo/best_detectron2_X152_scaled_shsy5y_pseudo_fold1.pth',
        '../input/sartorius-model-shsy5y-pseudo/best_detectron2_X152_scaled_shsy5y_pseudo_fold2.pth',
        '../input/sartorius-model-shsy5y-pseudo/best_detectron2_X152_scaled_shsy5y_pseudo_fold3.pth',
        '../input/sartorius-model-shsy5y-pseudo/best_detectron2_X152_scaled_shsy5y_pseudo_fold4.pth',
        '../input/sartorius-model-shsy5y-pseudo/best_detectron2_X152_scaled_shsy5y_pseudo.pth',
    )
    predictors_shsy5y = [_load(path) for path in default_anchor_weights]

    # The last model is built with a wider set of anchor aspect ratios.
    cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.33, 0.5, 1.0, 2.0, 3.0]]
    predictors_shsy5y.append(_load('../input/sartorius-1class-model/best_detectron2_X152_1class.pth'))

    yolo_weights = (
        '../input/sartorius-yolo-models/best_scaled_shsy5y.pt',
        '../input/sartorius-yolo-models/best_scaled_shsy5y_fold1.pt',
        '../input/sartorius-yolo-models/best_scaled_shsy5y_fold2.pt',
        '../input/sartorius-yolo-models/best_scaled_shsy5y_fold3.pt',
        '../input/sartorius-yolo-models/best_scaled_shsy5y_fold4.pt',
    )
    loaded = [DetectMultiBackend(weights=path, device=device) for path in yolo_weights]
    # Returned order is fold1..fold4 first, non-fold model last.
    yolo_models_shsy5y = loaded[1:] + loaded[:1]

    return predictors_shsy5y, yolo_models_shsy5y

In [None]:
def get_astro_model():
    """Build the astro ensemble.

    Returns:
        ``(predictors_astro, yolo_models_astro)``: ten TTA predictors (five
        built with the default anchors, then five "freeze2" models built with
        extra anchor aspect ratios) and five YOLO models ordered
        fold1..fold4 followed by the non-fold model.
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml"))
    cfg.INPUT.MASK_FORMAT = 'bitmask'
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
    cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.7
    cfg.TEST.DETECTIONS_PER_IMAGE = 500
    cfg.TEST.AUG = CN({"ENABLED": True})
    cfg.TEST.AUG.MIN_SIZES = (640, 750, 860)
    cfg.TEST.AUG.MAX_SIZE = 1440
    cfg.TEST.AUG.FLIP = True

    def _load(weights_path):
        # Point the shared cfg at the checkpoint, build, load, wrap with TTA.
        cfg.MODEL.WEIGHTS = weights_path
        net = build_model(cfg)
        DetectionCheckpointer(net).load(cfg.MODEL.WEIGHTS)
        return GeneralizedRCNNWithTTA(cfg, net)

    #freeze1
    freeze1_weights = (
        '../input/sartorius-model-astro-pseudo/best_detectron2_X152_astro_step_freeze1_pseudo_fold1.pth',
        '../input/sartorius-model-astro-pseudo/best_detectron2_X152_astro_step_freeze1_pseudo_fold2.pth',
        '../input/sartorius-model-astro-pseudo/best_detectron2_X152_astro_step_freeze1_pseudo_fold3.pth',
        '../input/sartorius-model-astro-pseudo/best_detectron2_X152_astro_pseudo_fold4.pth',
        '../input/sartorius-model-astro-pseudo/best_detectron2_X152_astro_cereb_step_freeze1_pseudo.pth',
    )
    predictors_astro = [_load(path) for path in freeze1_weights]

    #freeze2: built with a wider set of anchor aspect ratios
    cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.33, 0.5, 1.0, 2.0, 3.0]]
    freeze2_weights = (
        '../input/sartorius-model-astro-freeze2-pseudo/best_detectron2_X152_astro_step_freeze2_pseudo_fold1.pth',
        '../input/sartorius-model-astro-freeze2-pseudo/best_detectron2_X152_astro_step_freeze2_pseudo_fold2.pth',
        '../input/sartorius-model-astro-freeze2-pseudo/best_detectron2_X152_astro_step_freeze2_pseudo_fold3.pth',
        '../input/sartorius-model-astro-freeze2-pseudo/best_detectron2_X152_astro_step_freeze2_pseudo_fold4.pth',
        '../input/sartorius-model-astro-freeze2-pseudo/best_detectron2_X152_astro_step_freeze2_pseudo.pth',
    )
    predictors_astro.extend(_load(path) for path in freeze2_weights)

    yolo_weights = (
        '../input/sartorius-yolo-models/best_astro.pt',
        '../input/sartorius-yolo-models/best_astro_fold1.pt',
        '../input/sartorius-yolo-models/best_astro_fold2.pt',
        '../input/sartorius-yolo-models/best_astro_fold3.pt',
        '../input/sartorius-yolo-models/best_astro_fold4.pt',
    )
    loaded = [DetectMultiBackend(weights=path, device=device) for path in yolo_weights]
    # Returned order is fold1..fold4 first, non-fold model last.
    yolo_models_astro = loaded[1:] + loaded[:1]

    return predictors_astro, yolo_models_astro

In [None]:
def get_cort_model():
    """Build the cort ensemble.

    Returns:
        ``(predictors_cort, yolo_models_cort)``: ten TTA predictors (five
        "freeze2" models followed by five "freeze1" models, all built from
        the same config) and five YOLO models ordered fold1..fold4 followed
        by the non-fold model.
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml"))
    cfg.INPUT.MASK_FORMAT = 'bitmask'
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
    cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.7
    cfg.MODEL.ANCHOR_GENERATOR.SIZES = [[16, 32, 64, 128]]
    cfg.TEST.DETECTIONS_PER_IMAGE = 650
    cfg.TEST.AUG = CN({"ENABLED": True})
    cfg.TEST.AUG.MIN_SIZES = (640, 750, 860)
    cfg.TEST.AUG.MAX_SIZE = 1440
    cfg.TEST.AUG.FLIP = True

    def _load(weights_path):
        # Point the shared cfg at the checkpoint, build, load, wrap with TTA.
        cfg.MODEL.WEIGHTS = weights_path
        net = build_model(cfg)
        DetectionCheckpointer(net).load(cfg.MODEL.WEIGHTS)
        return GeneralizedRCNNWithTTA(cfg, net)

    detector_weights = (
        #freeze2 models
        '../input/sartorius-model-cort-pseudo-freeze2/best_detectron2_X152_cort_step_freeze2_pseudo_fold1.pth',
        '../input/sartorius-model-cort-pseudo-freeze2/best_detectron2_X152_cort_step_freeze2_pseudo_fold2.pth',
        '../input/sartorius-model-cort-pseudo-freeze2/best_detectron2_X152_cort_step_freeze2_pseudo_fold3.pth',
        '../input/sartorius-model-cort-pseudo-freeze2/best_detectron2_X152_cort_step_freeze2_pseudo_fold4.pth',
        '../input/sartorius-model-cort-pseudo-freeze2/best_detectron2_X152_cort_step_freeze2_pseudo.pth',
        #freeze1 models
        '../input/sartorius-model-cort-pseudo/best_detectron2_X152_cort_step_freeze1_pseudo_fold1_2.pth',
        '../input/sartorius-model-cort-pseudo/best_detectron2_X152_cort_step_freeze1_pseudo_fold2_2.pth',
        '../input/sartorius-model-cort-pseudo/best_detectron2_X152_cort_step_freeze1_pseudo_fold3_2.pth',
        '../input/sartorius-model-cort-pseudo/best_detectron2_X152_cort_step_freeze1_pseudo_fold4_2.pth',
        '../input/sartorius-model-cort-pseudo/best_detectron2_X152_cort_step_freeze1_pseudo2.pth',
    )
    predictors_cort = [_load(path) for path in detector_weights]

    yolo_weights = (
        '../input/sartorius-yolo-models/best_cort2.pt',
        '../input/sartorius-yolo-models/best_cort_fold1.pt',
        '../input/sartorius-yolo-models/best_cort_fold2.pt',
        '../input/sartorius-yolo-models/best_cort_fold3.pt',
        '../input/sartorius-yolo-models/best_cort_fold4.pt',
    )
    loaded = [DetectMultiBackend(weights=path, device=device) for path in yolo_weights]
    # Returned order is fold1..fold4 first, non-fold model last.
    yolo_models_cort = loaded[1:] + loaded[:1]

    return predictors_cort, yolo_models_cort

In [None]:
def get_predictors(cls):
    """Load the ensemble that matches a class id.

    Args:
        cls: 0 selects the shsy5y ensemble, 1 the astro ensemble; every
            other value selects the cort ensemble.

    Returns:
        Whatever the selected loader returns (predictors and YOLO models).
    """
    specialised = ((0, get_shsy5y_model), (1, get_astro_model))
    for class_id, loader in specialised:
        if cls == class_id:
            return loader()
    return get_cort_model()

In [None]:
# Bucket the test images by predicted cell type so that each specialised
# ensemble can later be run over its own images only.
img_classes = {class_id: [] for class_id in range(3)}
img_ids = {class_id: [] for class_id in range(3)}   # not filled in this notebook
for fn in test_names:
    cls = classify_image(fn, predictor)
    img_classes[cls].append(fn)

# The classifier is no longer needed: drop it and release cached GPU memory.
del predictor
torch.cuda.empty_cache()

In [None]:
# Run each specialised ensemble over the images assigned to its class and
# collect one (image id, RLE mask) pair per predicted instance.
ids = []
masks = []
for cls in range(3):
    filenames = img_classes[cls]
    if not filenames:
        # No test image of this cell type: skip loading the ensemble,
        # which is by far the most expensive step of the loop.
        continue
    predictors, yolo_models = get_predictors(cls)
    for i, fn in enumerate(filenames):
        encoded_masks = get_masks(fn, predictors, yolo_models, cls)
        ids.extend([fn.stem] * len(encoded_masks))
        masks.extend(encoded_masks)
        print(cls, i)
        torch.cuda.empty_cache()
    # Free this ensemble before the next class loads its own models.
    del predictors, yolo_models
    torch.cuda.empty_cache()

In [None]:
# One row per predicted instance: image id and its RLE-encoded mask.
submission = pd.DataFrame({'id': ids, 'predicted': masks})
submission.to_csv('submission.csv', index=False)
# Read the file back as a quick check of what was written.
pd.read_csv('submission.csv').head()