# Mask R-CNN (Image Segmentation) Inference
We start by importing the required packages and defining the paths to the dataset, the log directory, and the trained model weights (change these paths to wherever those files are on your system).

In [None]:
# external packages
import skimage.io
import matplotlib.pyplot as plt
import numpy as np

# mask r-cnn components
from mrcnn.config import Config
import mrcnn.model as modellib
from mrcnn import utils
from mrcnn import visualize

# Paths: adjust these to where the files live on your system.
DATASET_PATH = "../data/wad"  # passed to WadDataset.load_data below
LOGS_PATH = "./logs"  # used as the model_dir of the MaskRCNN model

MRCNN_MODEL_WEIGHTS = "./image_seg/mask_rcnn_wad.h5"  # weights file loaded into the model

## Configuration & Dataset

In [None]:
from train.wad_dataset import WadConfig, WadDataset


# WAD configuration for single-image inference.
class WADInferenceConfig(WadConfig):
    # mrcnn's Config recomputes BATCH_SIZE as GPU_COUNT * IMAGES_PER_GPU when it
    # is instantiated, so GPU_COUNT must be pinned to 1 too (WadConfig may set
    # more GPUs for training). model.detect() below is always given one image.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    BATCH_SIZE = 1
    # Keep every detection regardless of its confidence score.
    DETECTION_MIN_CONFIDENCE = 0


wad_config = WADInferenceConfig()

# Dataset
# NOTE(review): the 'train' subset is loaded here and is also the subset the
# Statistics cell evaluates, so the reported mAP is measured on training data.
# Switch to a held-out subset if one exists.
dataset = WadDataset()
dataset.load_data(DATASET_PATH, 'train', use_pickle=False)
dataset.prepare()

print('Number of Images in Dataset: {}'.format(dataset.num_images))

## Load Model

In [None]:
# Build the Mask R-CNN model in inference mode, with LOGS_PATH as its model_dir.
model = modellib.MaskRCNN(
    mode="inference",
    config=wad_config,
    model_dir=LOGS_PATH,
)
# Load the trained weights, matching layers by name.
model.load_weights(MRCNN_MODEL_WEIGHTS, by_name=True)

## Inference

In [None]:
from train.wad_dataset import index_to_class_names

# Pick a random image from the dataset and run detection on it.
image_id = np.random.randint(0, dataset.num_images)
image = dataset.load_image(image_id)

print(f'Running detection on image {image_id} (filename: {dataset.image_info[image_id]["path"]})')

# detect() takes a list of images and returns one result dict per image;
# only one image is passed, so take the first result.
results = model.detect([image], verbose=1)[0]

# Overlay the predicted boxes, masks, class labels and scores on the image.
visualize.display_instances(image, results['rois'], results['masks'], results['class_ids'],
                            index_to_class_names, results['scores'])

## Statistics

In [None]:
# Calculate Average Precision (AP) for each image, then the mean over images (mAP).
APs = []
skipped_image_ids = []  # images with no ground-truth instances
for image_id in range(dataset.num_images):
    # Load the image and its ground truth at full resolution (no mini masks),
    # so the GT masks can be compared with the full-size predicted masks.
    image, image_meta, gt_class_id, gt_bbox, gt_mask =\
        modellib.load_image_gt(dataset, wad_config,
                               image_id, use_mini_mask=False)

    # AP is undefined without ground-truth instances: compute_ap normalizes
    # recall by the GT count, so such an image yields NaN and would turn the
    # dataset-level mean into NaN. Skip and count these images instead.
    if len(gt_class_id) == 0:
        skipped_image_ids.append(image_id)
        continue

    # Run object detection on the raw image (one result dict per input image).
    r = model.detect([image], verbose=0)[0]

    # Compute AP at compute_ap's default IoU threshold.
    AP, precisions, recalls, overlaps =\
        utils.compute_ap(gt_bbox, gt_class_id, gt_mask,
                         r["rois"], r["class_ids"], r["scores"], r['masks'])
    APs.append(AP)

    print('Image: {0:3d}\tAP: {1:1.4f}'.format(image_id, AP))

if skipped_image_ids:
    print('Skipped {} image(s) without ground-truth instances'.format(len(skipped_image_ids)))

# Calculate mAP score for dataset
if APs:
    print("mAP: ", np.mean(APs))
else:
    print("mAP: undefined (no images with ground-truth instances)")