In [1]:
# Imports and path setup for comparing predicted particle masks with ground truth.
import json  # NOTE(review): not used in the visible cells

import pickle
import pycocotools.mask as RLE  # NOTE(review): not used in the visible cells
import pathlib

import numpy as np

# Location of the ampis sources, relative to the notebook's working directory.
ampis_root = pathlib.Path('../src/')
assert ampis_root.is_dir()

# Local helper module providing get_data_dicts / get_metadata (used below);
# presumably sits next to this notebook — TODO confirm.
import explore_data_powder

import sys
# Make the ampis package importable; must run before the `from ampis ...` lines below.
sys.path.append(str(ampis_root))
from ampis import analyze
from ampis.structures import instance_set  # NOTE(review): not used in the visible cells
from ampis.visualize import quick_visualize_iset  # NOTE(review): not used in the visible cells

In [2]:
dataset_name_ = 'particle'

# Inputs: ground-truth mask annotations (JSON) and pickled model predictions.
gt_root, pred_root = (
    pathlib.Path('../data/raw/via_2.0.8/via_powder_particle_masks.json'),
    pathlib.Path('../data/interim/powder_results/particle_predictions_outputs_compressed.pickle'),
)

# Both inputs must exist before later cells try to read them.
assert all(p.is_file() for p in (gt_root, pred_root))


In [3]:
# Load the pickled model predictions first. Per the original note, they are
# organized as training images followed by validation images.
# NOTE(security): pickle.load can execute arbitrary code; only load prediction
# files produced by a trusted run.
with open(pred_root, 'rb') as f:
    pred_data = pickle.load(f)

# Ground-truth dataset dicts built by explore_data_powder from the JSON at gt_root.
gt_ddicts = explore_data_powder.get_data_dicts(gt_root)
    


In [4]:
# Convert each raw model output and each ground-truth dataset dict into an
# instance set (the second positional argument is True for both readers).
pred_instances = [analyze.instance_set().read_from_model_out(x, True)
                  for x in pred_data.values()]

# Reuse the dataset dicts loaded in the previous cell rather than re-reading
# and re-parsing the annotation file a second time.
gt_instances = [analyze.instance_set().read_from_ddict(x, True)
                for x in gt_ddicts]

metadata = explore_data_powder.get_metadata(gt_root)

# Presumably reorders both lists so that index i refers to the same image in
# each — TODO confirm against analyze.align_instance_sets.
pred_instances, gt_instances = analyze.align_instance_sets(pred_instances, gt_instances)

In [5]:
# Take the last aligned (ground truth, prediction) pair for detailed inspection.
gt, pred = gt_instances[-1], pred_instances[-1]

In [6]:
from ampis.structures import masks_to_bitmask_array

In [7]:
# Display the ground-truth mask container; its repr reports the mask type and instance count.
gt.instances.masks

PolygonMasks(num_instances=238)

In [8]:
# Rasterize the ground-truth masks into a bitmask array. The image size is
# passed explicitly here; presumably needed because the ground truth is stored
# as polygons (see the PolygonMasks repr above) — TODO confirm.
gt_masks = masks_to_bitmask_array(gt.instances.masks, gt.instances.image_size)

In [9]:
# Rasterize the predicted masks. Unlike the ground-truth call, no image size is
# passed; presumably the predicted masks already carry their raster shape — TODO confirm.
pred_masks = masks_to_bitmask_array(pred.instances.masks)

In [26]:
def size_thresh(x, min_area=100):
    """Return a boolean keep-mask selecting instances larger than ``min_area`` pixels.

    Parameters
    ----------
    x : array
        Stacked per-instance masks; axes 1 and 2 are summed to get each
        instance's pixel area (assumes shape (n_instances, H, W)).
    min_area : int, optional
        Strict lower bound on pixel area. Defaults to 100, the value
        previously hard-coded here.

    Returns
    -------
    Boolean array of length n_instances, True where the area exceeds ``min_area``.
    """
    return x.sum((1, 2)) > min_area

In [29]:
# Drop instances whose pixel area does not exceed the size_thresh cutoff,
# for both the ground-truth and the predicted mask stacks.
gt_masks_filt, pred_masks_filt = (m[size_thresh(m)] for m in (gt_masks, pred_masks))

In [13]:
def brute_force_mask_match(gt, pred, thresh=0.5):
    """Greedily match each ground-truth mask to its highest-IoU predicted mask.

    For every ground-truth mask, the IoU with every predicted mask is computed
    and the best-scoring prediction is chosen (first one wins on ties). The
    pair is a true positive when that best IoU is strictly greater than
    ``thresh``; otherwise the ground-truth mask is a false negative.
    Predictions never chosen as a best match are false positives.

    Note: matching is not exclusive, so one prediction may be the best match
    for several ground-truth masks. With ``thresh >= 0.5`` and non-overlapping
    ground-truth masks this cannot happen. With a lower threshold, the
    true-positive count may exceed the number of predictions.

    Parameters
    ----------
    gt : array-like of shape (n_gt, H, W)
        Ground-truth instance masks (boolean, or nonzero = foreground).
    pred : array-like of shape (n_pred, H, W)
        Predicted instance masks with the same spatial size as ``gt``.
    thresh : float, optional
        The best IoU must be strictly greater than this to count as a match.

    Returns
    -------
    dict
        'tp'  : int array of shape (n_tp, 2), rows of [gt_index, pred_index]
        'fp'  : int array of unmatched prediction indices
        'fn'  : int array of unmatched ground-truth indices
        'IOU' : float array with the IoU of each row of 'tp'
    """
    gt = np.asarray(gt, dtype=bool)
    pred = np.asarray(pred, dtype=bool)

    # builtin bool/int replace np.bool/np.int, which were removed in NumPy 1.24.
    pred_matched = np.zeros(len(pred), dtype=bool)
    # Prediction areas are reused for every ground-truth mask, so compute them
    # once; the union then follows from |A| + |B| - |A & B| without a logical_or.
    pred_areas = [int(np.count_nonzero(p)) for p in pred]

    tp, fn, IOU = [], [], []
    for gt_idx, gt_mask in enumerate(gt):
        gt_area = int(np.count_nonzero(gt_mask))
        best_iou, best_idx = 0, -1
        for pred_idx, (pred_mask, pred_area) in enumerate(zip(pred, pred_areas)):
            inter = int(np.count_nonzero(np.logical_and(gt_mask, pred_mask)))
            union = gt_area + pred_area - inter
            if union == 0:
                # Both masks empty: IoU undefined. A 0/0 NaN would never beat
                # the running max anyway, so skip it without a RuntimeWarning.
                continue
            iou = inter / union
            if iou > best_iou:
                best_iou, best_idx = iou, pred_idx

        # best_idx >= 0 guards against negative thresholds when nothing was compared.
        if best_idx >= 0 and best_iou > thresh:
            tp.append([gt_idx, best_idx])
            IOU.append(best_iou)
            pred_matched[best_idx] = True
        else:
            fn.append(gt_idx)

    return {'tp': np.asarray(tp, dtype=int).reshape(-1, 2),
            'fp': np.flatnonzero(~pred_matched),
            'fn': np.asarray(fn, dtype=int),
            'IOU': np.asarray(IOU, dtype=float)}
    
    
            
            

In [30]:
# Match the size-filtered masks of the selected image at the default IoU threshold (0.5).
results = brute_force_mask_match(gt_masks_filt, pred_masks_filt)

In [37]:
# Number of predicted instances remaining after the size filter.
len(pred_masks_filt)

193

In [36]:
# Number of ground-truth instances remaining after the size filter.
len(gt_masks_filt)

213

In [33]:
def match_stats(results):
    """Print TP/FP/FN counts followed by precision and recall.

    Parameters
    ----------
    results : dict
        Output of ``brute_force_mask_match``; only the lengths of its
        'tp', 'fp' and 'fn' entries are used.

    Precision is TP / (TP + FP) and recall is TP / (TP + FN). When a
    denominator is zero (no predictions, or no ground truth), the metric is
    undefined and is printed as ``nan`` instead of raising ZeroDivisionError.
    """
    tp, fp, fn = (len(results[key]) for key in ('tp', 'fp', 'fn'))
    precision = tp / (tp + fp) if tp + fp else float('nan')
    recall = tp / (tp + fn) if tp + fn else float('nan')
    print(tp, fp, fn)
    print('precision: {}'.format(precision))
    print('recall: {}'.format(recall))

In [34]:
# Print TP/FP/FN counts plus precision and recall for the selected image.
match_stats(results)

178 15 35
precision: 0.9222797927461139
recall: 0.8356807511737089
