In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import skimage.io

%matplotlib inline

# Root directory of the project (resolved relative to the kernel's working
# directory, i.e. the notebook's folder when launched normally).
ROOT_DIR = os.path.abspath("../")

# Make the project root importable so the local `mrcnn` package and the
# `project` package (imported in the next cell) can be found even when they
# are not pip-installed; presumably both live directly under ROOT_DIR.
# Guarded so re-running this cell does not keep growing sys.path.
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

# Root directory of the dataset
DATA_DIR = os.path.join(ROOT_DIR, "dataset", "wad")

# Directory to save logs and trained model
LOGS_DIR = os.path.join(ROOT_DIR, "logs")

# Import Mask RCNN (must come after the sys.path update above)
from mrcnn.config import Config
import mrcnn.model as modellib
from mrcnn import visualize

## Configuration and Dataset

In [None]:
from project import wad_data

# CoCo Configuration
class CocoInferenceConfig(Config):
    """Inference configuration for a Mask R-CNN model trained on MS COCO.

    Derives from the base Config class and overrides the values specific
    to the COCO dataset (80 object classes plus background), running one
    image at a time on a single GPU.
    """
    # Give the configuration a recognizable name
    NAME = "coco"

    # Number of classes (including background)
    NUM_CLASSES = 1 + 80  # COCO has 80 classes

    # One image on one GPU; mrcnn derives the effective batch size from these.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # 0 disables the minimum-score filter, so every detection is returned.
    DETECTION_MIN_CONFIDENCE = 0

# NOTE(review): coco_config is not passed to the model in the cells shown
# (the model is built from wad_config); it appears to be kept for reference.
coco_config = CocoInferenceConfig()

# WAD Configuration
class WADInferenceConfig(wad_data.WADConfig):
    """Single-image inference overrides on top of the project's WADConfig.

    The inference cell calls model.detect() with a list holding exactly one
    image, so the effective batch must be one. GPU_COUNT is pinned here
    instead of being inherited, because WADConfig (defined in
    project.wad_data, not visible here) may set it higher for training.
    """
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1
    # NOTE: upstream mrcnn's Config.__init__ recomputes BATCH_SIZE as
    # IMAGES_PER_GPU * GPU_COUNT, so this class value is informational;
    # kept for backward compatibility with code that reads it.
    BATCH_SIZE = 1

wad_config = WADInferenceConfig()

# Dataset
# Index the "train" split found under DATA_DIR; prepare() finalizes the
# index before any image is looked up by id (mrcnn Dataset convention).
dataset = wad_data.WADDataset()
dataset.load_data(DATA_DIR, "train")
dataset.prepare()

print('Number of Images in Dataset:', dataset.num_images)

## Load Model

In [None]:
# Checkpoint from one specific training run under LOGS_DIR. The run folder and
# file name are joined as separate components so the path uses the platform's
# separator; edit these two names to evaluate a different checkpoint.
WEIGHTS_RUN_DIR = 'wad20180620T1548'
WEIGHTS_FILE = 'mask_rcnn_wad_0001.h5'
model_weights_path = os.path.join(LOGS_DIR, WEIGHTS_RUN_DIR, WEIGHTS_FILE)

# Fail fast with a clear message before building the (potentially slow)
# inference graph, rather than erroring later inside load_weights.
if not os.path.isfile(model_weights_path):
    raise FileNotFoundError('Model weights not found: {}'.format(model_weights_path))

# Create model in inference mode
model = modellib.MaskRCNN(mode="inference", config=wad_config,
                          model_dir=LOGS_DIR)

# by_name=True loads weights into layers whose names match those stored in the
# checkpoint file.
model.load_weights(model_weights_path, by_name=True)

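## Inference

Run the model on one randomly chosen image from the training split loaded above, and overlay the predicted boxes, masks, class labels and scores on it.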

In [None]:
# Load a random image from the dataset
# Set IMAGE_SEED to an int to make the random pick reproducible across
# re-runs; None keeps the behaviour of drawing a fresh image each time.
IMAGE_SEED = None

# randint(0, 0) would raise an opaque ValueError; say what is wrong instead.
assert dataset.num_images > 0, "Dataset is empty; check DATA_DIR and the subset name"

# Draw a uniform random index in [0, num_images) and load that image.
image_id = np.random.RandomState(IMAGE_SEED).randint(0, dataset.num_images)
image = dataset.load_image(image_id)
# NOTE(review): in mrcnn's Dataset API, load_mask conventionally returns
# (masks, class_ids); confirm for WADDataset. Not used below yet.
gt_masks = dataset.load_mask(image_id)

# detect() takes a list of images and returns one result dict per image;
# wad_config is set up for a single image per batch.
results = model.detect([image], verbose=1)[0]

visualize.display_instances(image, results['rois'], results['masks'], results['class_ids'],
                            wad_data.index_to_classes, results['scores'])