In [None]:
import os

import numpy as np

# The two environment variables below are assigned before the mrcnn imports that
# follow, so they are already in place when those packages are loaded.
# NOTE(review): "-1" presumably hides every CUDA device (CPU-only run) — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# NOTE(review): presumably selects the pure-Python protobuf backend — confirm it
# is still required by the installed protobuf/TensorFlow versions.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"]  = "python"
import mrcnn.visualize
import mrcnn.config
import mrcnn.model
from datasets import load_dataset
import warnings
import cv2
import pathlib

In [None]:
# --- Weight locations -------------------------------------------------------
# Both directories are created on first run; re-running the cell is harmless.
WEIGHTS_PATH = pathlib.Path("weights")
WEIGHTS_PATH.mkdir(exist_ok=True)

WEIGHTS_CHECKPOINT_PATH = WEIGHTS_PATH / "checkpoints"
WEIGHTS_CHECKPOINT_PATH.mkdir(exist_ok=True)

COCO_WEIGHTS = WEIGHTS_PATH / "mask_rcnn_coco.h5"
# No FROG_WEIGHTS constant is defined here. Candidates that were tried before:
#   - the newest "*.h5" file inside WEIGHTS_CHECKPOINT_PATH
#   - WEIGHTS_PATH / "mask_rcnn_frog_students.h5"

# --- Class and dataset names ------------------------------------------------
CLASS_NAME = "frog_stomach"
CLASS_NAMES = ["BG", CLASS_NAME]

FROG_IMAGES = "./frog_photos"
# Local dataset directory; the alternative identifier is "perara/pelophylax_lessonae".
FROG_DATASET = "./pelophylax_lessonae"


In [None]:
from numpy import asarray
from PIL import ImageDraw, Image
import mrcnn.utils

class FrogStomachDataset(mrcnn.utils.Dataset):
    """Mask R-CNN dataset backed by the frog-stomach polygon annotations.

    ``load_dataset`` registers one foreground class and one entry per sample of
    the requested split; ``load_mask`` rasterises the stored polygons into
    boolean instance masks.
    """

    def load_dataset(self, images_dir: str, dataset_dir: str = None, is_train=True):
        """Register the foreground class and every image of one dataset split.

        Args:
            images_dir: Directory holding the photos; forwarded to the dataset
                loader as ``image_dir`` (converted to ``pathlib.Path``).
            dataset_dir: Forwarded unchanged to the dataset loader as
                ``dataset_dir``.
            is_train: ``True`` selects the "train" split, ``False`` selects
                "validation".
        """
        self.add_class("dataset", 1, CLASS_NAME)

        dataset_split = "train" if is_train else "validation"
        # Inside a method body the bare name resolves to the module-level
        # `load_dataset` imported from `datasets`, not to this method.
        ds = load_dataset(
            FROG_DATASET,
            name="default",
            splits=[dataset_split],
            image_dir=pathlib.Path(images_dir),
            dataset_dir=dataset_dir,
        )

        for sample in ds[dataset_split]:
            self.add_image(
                'dataset',
                image_id=sample["image_id"],
                path=sample["image_path"],
                annotation=sample["segmentation"],
                width=sample["width"],
                height=sample["height"],
                category_id=sample["category_id"],
            )

    def load_mask(self, image_id):
        """Rasterise the polygons of one image into per-instance boolean masks.

        Each polygon is drawn on its own blank canvas, so a mask contains only
        that polygon. (Drawing all polygons on one shared canvas would make
        every later mask also contain the earlier polygons.)
        NOTE(review): this treats every polygon as a separate instance, which
        matches the one-class-id-per-polygon bookkeeping — confirm against the
        dataset schema.

        Polygons that PIL cannot draw are skipped with a warning instead of
        being silently turned into an empty or duplicated mask.

        Args:
            image_id: Index into ``self.image_info``.

        Returns:
            A tuple ``(masks, class_ids)``: ``masks`` is a bool array of shape
            ``(height, width, n_instances)`` and ``class_ids`` is an int32
            array of length ``n_instances``. When no polygon could be drawn,
            ``n_instances`` is 0.
        """
        image_info = self.image_info[image_id]
        annotations = image_info['annotation']
        width = image_info["width"]
        height = image_info["height"]
        category_id = image_info["category_id"]

        if annotations is None:
            annotations = []

        instance_masks = []
        for segmentation in annotations:
            # Fresh canvas per polygon keeps the instances independent.
            canvas = Image.new('1', (width, height))
            try:
                ImageDraw.Draw(canvas).polygon(segmentation, fill=1)
            except Exception as exc:  # best effort: one bad polygon must not abort loading
                warnings.warn(
                    "Skipping undrawable polygon for image {!r}: {}".format(
                        image_info.get("id", image_id), exc
                    )
                )
                continue
            instance_masks.append(np.array(canvas) > 0)

        if not instance_masks:
            # np.dstack([]) raises ValueError, so build the empty result explicitly.
            return (
                np.zeros((height, width, 0), dtype=bool),
                np.zeros((0,), dtype=np.int32),
            )

        masks = np.dstack(instance_masks).astype(bool)
        class_ids = np.array([category_id] * len(instance_masks), dtype=np.int32)
        return masks, class_ids

In [None]:
class FrogStomachConfig(mrcnn.config.Config):
    """Base Mask R-CNN settings for the frog-stomach model."""

    # Configuration name, shared with the single foreground class.
    NAME = CLASS_NAME

    # Background plus the one frog-stomach class.
    NUM_CLASSES = 1 + 1

    # Batch layout and epoch length.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 16
    STEPS_PER_EPOCH = 100

    # Both image dimension bounds are set to the same value.
    IMAGE_MIN_DIM = 128
    IMAGE_MAX_DIM = 128

    # Detection confidence threshold.
    DETECTION_MIN_CONFIDENCE = 0.9
# NOTE(review): presumably the number of training epochs; nothing in the visible
# cells reads it.
EPOCHS = 100


class InferenceConfig(FrogStomachConfig):
    """FrogStomachConfig variant that processes a single image per GPU."""

    IMAGES_PER_GPU = 1
    GPU_COUNT = 1

    # Mini masks are switched off for inference.
    USE_MINI_MASK = False


In [None]:
config = InferenceConfig()

# Warnings emitted while the network is built and its weights are read are muted;
# the filter is restored when the block exits.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    model_inference = mrcnn.model.MaskRCNN(
        mode="inference",
        config=config,
        model_dir=WEIGHTS_PATH,
    )

    # find_last() supplies the checkpoint path.
    # NOTE(review): presumably the newest checkpoint under model_dir — confirm.
    model_inference.load_weights(
        filepath=model_inference.find_last(),
        by_name=True,
    )


In [None]:


# Build the validation split: register its images, then finalise the dataset.
val_dataset = FrogStomachDataset()
val_dataset.load_dataset(images_dir=str(FROG_IMAGES), dataset_dir=str(FROG_DATASET), is_train=False)
val_dataset.prepare()

In [None]:
from mrcnn import visualize
import random

dataset = val_dataset

# Pick one validation image at random and load it together with its ground truth.
image_id = random.choice(dataset.image_ids)
image, image_meta, gt_class_id, gt_bbox, gt_mask = mrcnn.model.load_image_gt(
    dataset, config, image_id
)

print(image.shape)

info = dataset.image_info[image_id]
print(f"image ID: {info['source']}.{info['id']} ({image_id}) {dataset.image_reference(image_id)}")

# Run detection on the single image; the batch therefore has one result.
results = model_inference.detect([image], verbose=1)
r = results[0]

# Overlay the predicted boxes, masks, classes and scores on the image.
visualize.display_instances(
    image,
    r['rois'],
    r['masks'],
    r['class_ids'],
    dataset.class_names,
    r['scores'],
    title="Predictions",
)
