In [8]:
import os
import sys
import math
import random 
import numpy as np
import torch
import pickle
import utils
import cv2
import pandas as pd

In [2]:
# '__file__' is passed as a plain string here, so realpath() resolves it
# relative to the current working directory; root_dir is that directory.
root_dir = os.path.dirname(os.path.realpath('__file__'))
sys.path.append(root_dir)
print(root_dir)

# Experiment layout: <root_dir>/<exp_name>/train and <root_dir>/<exp_name>/val.
exp_name = 'shapes'
image_height = image_width = 320
train_dir, val_dir = (os.path.join(root_dir, exp_name, split)
                      for split in ('train', 'val'))
print(train_dir, val_dir, sep='\n')

/Users/jdeguzman/Documents/MRCNN-3D/Mask_RCNN-3D/mrcnn
/Users/jdeguzman/Documents/MRCNN-3D/Mask_RCNN-3D/mrcnn/shapes/train
/Users/jdeguzman/Documents/MRCNN-3D/Mask_RCNN-3D/mrcnn/shapes/val


# Shapes Dataset from Matterport

In [3]:
class ShapesDataset(utils.Dataset):
    """Generates the shapes synthetic dataset. The dataset consists of simple
    shapes (triangles, squares, circles) placed randomly on a blank surface.
    The images are generated on the fly; no file access is required to load
    them. save_image_and_mask() can additionally write them to ``out_dir``.
    """

    def __init__(self, out_dir):
        """Create the dataset and make sure the output directory exists.

        out_dir: directory that save_image_and_mask() writes into.
        """
        super(ShapesDataset, self).__init__()
        self.out_dir = out_dir
        if not os.path.exists(self.out_dir):
            os.makedirs(self.out_dir)

    def load_shapes(self, count, height, width):
        """Generate the requested number of synthetic images.
        count: number of images to generate.
        height, width: the size of the generated images.
        """
        # Add classes
        self.add_class("shapes", 1, "square")
        self.add_class("shapes", 2, "circle")
        self.add_class("shapes", 3, "triangle")

        # Add images
        # Generate random specifications of images (i.e. color and
        # list of shapes sizes and locations). This is more compact than
        # actual images. Images are generated on the fly in load_image().
        for i in range(count):
            bg_color, shapes = self.random_image(height, width)
            self.add_image("shapes", image_id=i, path=None,
                           width=width, height=height,
                           bg_color=bg_color, shapes=shapes)

    def image_reference(self, image_id):
        """Return the shapes data of the image.

        For images whose source is not "shapes" the call is forwarded to the
        base class and its result is returned.
        """
        info = self.image_info[image_id]
        if info["source"] == "shapes":
            return info["shapes"]
        # Use a bound super() and return its result: the previous unbound
        # super(self.__class__) form could not resolve the parent method and
        # the value was never returned to the caller.
        return super(ShapesDataset, self).image_reference(image_id)

    def load_image(self, image_id):
        """Generate an image from the specs of the given image ID.
        Typically this function loads the image from a file, but
        in this case it generates the image on the fly from the
        specs in image_info.
        """
        info = self.image_info[image_id]
        bg_color = np.array(info['bg_color']).reshape([1, 1, 3])
        # Start from an all-ones uint8 canvas and scale it to the background
        # color, then paint every shape on top in list order.
        image = np.ones([info['height'], info['width'], 3], dtype=np.uint8)
        image = image * bg_color.astype(np.uint8)
        for shape, color, dims in info['shapes']:
            image = self.draw_shape(image, shape, dims, color)
        return image

    def load_mask(self, image_id):
        """Generate instance masks for shapes of the given image ID.

        Returns:
        mask: bool array [height, width, instance_count], one channel per shape.
        class_ids: int32 array [instance_count] of class indices.
        """
        info = self.image_info[image_id]
        shapes = info['shapes']
        count = len(shapes)
        mask = np.zeros([info['height'], info['width'], count], dtype=np.uint8)
        for i, (shape, _, dims) in enumerate(shapes):
            # Draw on a contiguous copy of the single-channel slice, then
            # write the result back into the mask stack.
            mask[:, :, i:i+1] = self.draw_shape(mask[:, :, i:i+1].copy(),
                                                shape, dims, 1)
        # Handle occlusions: later shapes are drawn on top, so walk backwards
        # and remove from each earlier mask the pixels covered by later ones.
        # Skipped when there are no shapes (indexing channel -1 would fail).
        if count > 0:
            occlusion = np.logical_not(mask[:, :, -1]).astype(np.uint8)
            for i in range(count-2, -1, -1):
                mask[:, :, i] = mask[:, :, i] * occlusion
                occlusion = np.logical_and(occlusion, np.logical_not(mask[:, :, i]))
        # Map class names to class IDs.
        class_ids = np.array([self.class_names.index(s[0]) for s in shapes])
        # Use the builtin bool: the np.bool alias was removed from NumPy.
        return mask.astype(bool), class_ids.astype(np.int32)

    def draw_shape(self, image, shape, dims, color):
        """Draws a shape from the given specs.

        image: array to draw on (modified in place by OpenCV and returned).
        shape: 'square', 'circle' or 'triangle'; any other value leaves the
            image unchanged.
        dims: (x, y, s) center coordinates and size.
        color: fill color passed to OpenCV.
        """
        # Get the center x, y and the size s
        x, y, s = dims
        if shape == 'square':
            cv2.rectangle(image, (x-s, y-s), (x+s, y+s), color, -1)
        elif shape == "circle":
            cv2.circle(image, (x, y), s, color, -1)
        elif shape == "triangle":
            points = np.array([[(x, y-s),
                                (x-s/math.sin(math.radians(60)), y+s),
                                (x+s/math.sin(math.radians(60)), y+s),
                                ]], dtype=np.int32)
            cv2.fillPoly(image, points, color)
        return image

    def random_shape(self, height, width):
        """Generates specifications of a random shape that lies within
        the given height and width boundaries.
        Returns a tuple of three values:
        * The shape name (square, circle, ...)
        * Shape color: a tuple of 3 values, RGB.
        * Shape dimensions: A tuple of values that define the shape size
                            and location. Differs per shape type.
        """
        # Shape
        shape = random.choice(["square", "circle", "triangle"])
        # Color
        color = tuple([random.randint(0, 255) for _ in range(3)])
        # Center x, y (kept at least `buffer` pixels away from the border)
        buffer = 20
        y = random.randint(buffer, height - buffer - 1)
        x = random.randint(buffer, width - buffer - 1)
        # Size
        s = random.randint(buffer, height//4)
        return shape, color, (x, y, s)

    def random_image(self, height, width):
        """Creates random specifications of an image with multiple shapes.
        Returns the background color of the image and a list of shape
        specifications that can be used to draw the image.
        """
        # Pick random background color
        bg_color = np.array([random.randint(0, 255) for _ in range(3)])
        # Generate a few random shapes and record their
        # bounding boxes
        shapes = []
        boxes = []
        N = random.randint(1, 4)
        for _ in range(N):
            shape, color, dims = self.random_shape(height, width)
            shapes.append((shape, color, dims))
            x, y, s = dims
            boxes.append([y-s, x-s, y+s, x+s])
        # Apply non-max suppression with 0.3 threshold to avoid
        # shapes covering each other
        keep_ixs = utils.non_max_suppression(np.array(boxes), np.arange(N), 0.3)
        shapes = [s for i, s in enumerate(shapes) if i in keep_ixs]
        return bg_color, shapes

    def save_image_and_mask(self, image_id):
        """Write one generated image and its masks to ``out_dir``.

        The image and the instance masks are concatenated along the channel
        axis and saved as ``<image_id>.npy``; the saved path is recorded in
        image_info. A companion ``meta_info_<image_id>.pickle`` stores
        [path, class_ids, str(image_id)].
        """
        img = self.load_image(image_id)
        seg, class_id = self.load_mask(image_id)
        out = np.concatenate((img, seg), axis=2)
        out_path = os.path.join(self.out_dir, '{}.npy'.format(image_id))
        self.image_info[image_id]['path'] = out_path
        np.save(out_path, out)

        with open(os.path.join(self.out_dir, 'meta_info_{}.pickle'.format(image_id)), 'wb') as handle:
            pickle.dump([out_path, class_id, str(image_id)], handle)
        
def aggregate_meta_info(exp_dir):
    """Merge the per-image meta-info pickles found in ``exp_dir`` into one table.

    Every directory entry whose name contains 'meta_info' is unpickled and
    appended as one row (columns: path, class_id, pid). The table is then
    pickled to ``<exp_dir>/info_df.pickle`` and its length is printed.
    """
    df = pd.DataFrame(columns=['path', 'class_id', 'pid'])

    for entry in os.listdir(exp_dir):
        if 'meta_info' not in entry:
            continue
        with open(os.path.join(exp_dir, entry), 'rb') as handle:
            # Append at the next free integer label.
            df.loc[len(df)] = pickle.load(handle)

    out_file = os.path.join(exp_dir, 'info_df.pickle')
    df.to_pickle(out_file)
    print("aggregated meta info to df with length", len(df))


In [4]:
# Training split: 500 randomly specified images written under train_dir.
dataset_train = ShapesDataset(out_dir=train_dir)
dataset_train.load_shapes(count=500, height=image_height, width=image_width)
dataset_train.prepare()

# Validation split: 50 randomly specified images written under val_dir.
dataset_val = ShapesDataset(out_dir=val_dir)
dataset_val.load_shapes(count=50, height=image_height, width=image_width)
dataset_val.prepare()

In [6]:
# Pick four training image ids at random (np.random.choice samples with
# replacement by default, so an id may repeat) and show, for each one, the
# image shape, the mask shape and the stored image specification.
image_ids = np.random.choice(dataset_train.image_ids, size=4)
for image_id in image_ids:
    image = dataset_train.load_image(image_id)
    mask, class_ids = dataset_train.load_mask(image_id)
    print(image.shape, mask.shape, sep='\n')
    print(dataset_train.image_info[image_id], '\n')


(320, 320, 3)
(320, 320, 3)
{'id': 300, 'source': 'shapes', 'path': None, 'width': 320, 'height': 320, 'bg_color': array([151, 230, 249]), 'shapes': [('triangle', (181, 84, 60), (97, 46, 50)), ('circle', (24, 46, 238), (115, 251, 47)), ('circle', (147, 95, 181), (263, 60, 56))]} 

(320, 320, 3)
(320, 320, 3)
{'id': 330, 'source': 'shapes', 'path': None, 'width': 320, 'height': 320, 'bg_color': array([185,  68, 132]), 'shapes': [('circle', (255, 68, 42), (263, 41, 78)), ('circle', (128, 120, 22), (291, 151, 61)), ('square', (25, 247, 219), (264, 253, 55))]} 

(320, 320, 3)
(320, 320, 1)
{'id': 124, 'source': 'shapes', 'path': None, 'width': 320, 'height': 320, 'bg_color': array([165, 150,  23]), 'shapes': [('triangle', (145, 114, 70), (276, 265, 73))]} 

(320, 320, 3)
(320, 320, 1)
{'id': 322, 'source': 'shapes', 'path': None, 'width': 320, 'height': 320, 'bg_color': array([164, 156, 140]), 'shapes': [('square', (98, 220, 215), (190, 57, 56))]} 



In [7]:
# Persist every sampled image/mask pair, then print its image_info record,
# whose 'path' entry is filled in by save_image_and_mask().
for image_id in image_ids:
    dataset_train.save_image_and_mask(image_id)
    record = dataset_train.image_info[image_id]
    print(record)

{'id': 300, 'source': 'shapes', 'path': '/Users/jdeguzman/Documents/MRCNN-3D/Mask_RCNN-3D/mrcnn/shapes/train/300.npy', 'width': 320, 'height': 320, 'bg_color': array([151, 230, 249]), 'shapes': [('triangle', (181, 84, 60), (97, 46, 50)), ('circle', (24, 46, 238), (115, 251, 47)), ('circle', (147, 95, 181), (263, 60, 56))]}
{'id': 330, 'source': 'shapes', 'path': '/Users/jdeguzman/Documents/MRCNN-3D/Mask_RCNN-3D/mrcnn/shapes/train/330.npy', 'width': 320, 'height': 320, 'bg_color': array([185,  68, 132]), 'shapes': [('circle', (255, 68, 42), (263, 41, 78)), ('circle', (128, 120, 22), (291, 151, 61)), ('square', (25, 247, 219), (264, 253, 55))]}
{'id': 124, 'source': 'shapes', 'path': '/Users/jdeguzman/Documents/MRCNN-3D/Mask_RCNN-3D/mrcnn/shapes/train/124.npy', 'width': 320, 'height': 320, 'bg_color': array([165, 150,  23]), 'shapes': [('triangle', (145, 114, 70), (276, 265, 73))]}
{'id': 322, 'source': 'shapes', 'path': '/Users/jdeguzman/Documents/MRCNN-3D/Mask_RCNN-3D/mrcnn/shapes/trai

# Exec.py

In [None]:
"""execution script."""
import argparse
import os
import time
import torch

# import utils.exp_utils as utils
import utils
from evaluator import Evaluator
from predictor import Predictor
from plotting import plot_batch_prediction

def load_image_gt(dataset, config, image_id, augment=False, use_mini_mask=False):
    """Load and return ground truth data for an image (image, class IDs, boxes, masks).

    dataset: object providing load_image(), load_mask(), num_classes,
        source_class_ids and image_info.
    config: object providing IMAGE_MIN_DIM, IMAGE_MIN_SCALE, IMAGE_MAX_DIM,
        IMAGE_RESIZE_MODE and, when use_mini_mask is set, MINI_MASK_SHAPE.
    image_id: ID forwarded to dataset.load_image() and dataset.load_mask().
    augment: If true, apply random image augmentation. Currently, only horizontal
        flipping is offered.
    use_mini_mask: If False, returns full-size masks that are the same height
        and width as the original image. These can be big, for example
        1024x1024x100 (for 100 instances). Mini masks are smaller, typically,
        224x224 and are generated by extracting the bounding box of the
        object and resizing it to MINI_MASK_SHAPE.

    Returns:
    image: [height, width, 3]
    class_ids: [instance_count] Integer class IDs
    bbox: [instance_count, (y1, x1, y2, x2)]
    mask: [height, width, instance_count]. The height and width are those
        of the image unless use_mini_mask is True, in which case they are
        defined in MINI_MASK_SHAPE.
    """
    # Imported locally because no module-level `import logging` is visible in
    # this file; without it the augment branch raised NameError.
    import logging

    # Load image and mask
    image = dataset.load_image(image_id)
    mask, class_ids = dataset.load_mask(image_id)
    # original_shape and window are only consumed by the disabled image_meta
    # code at the bottom of this function.
    original_shape = image.shape
    image, window, scale, padding, crop = utils.resize_image(
        image,
        min_dim=config.IMAGE_MIN_DIM,
        min_scale=config.IMAGE_MIN_SCALE,
        max_dim=config.IMAGE_MAX_DIM,
        mode=config.IMAGE_RESIZE_MODE)
    # Apply the same scale/padding/crop to the masks so they stay aligned.
    mask = utils.resize_mask(mask, scale, padding, crop)

    # Random horizontal flips.
    if augment:
        logging.warning("'augment' is deprecated. Use 'augmentation' instead.")
        if random.randint(0, 1):
            image = np.fliplr(image)
            mask = np.fliplr(mask)

    # Some masks might be all zeros if the corresponding instance got cropped
    # out; drop those instances and their class IDs.
    _idx = np.sum(mask, axis=(0, 1)) > 0
    mask = mask[:, :, _idx]
    class_ids = class_ids[_idx]

    # Bounding boxes, computed from the remaining (non-empty) masks.
    # bbox: [num_instances, (y1, x1, y2, x2)]
    bbox = utils.extract_bboxes(mask)

    # Active classes
    # Different datasets have different classes, so track the
    # classes supported in the dataset of this image.
    # NOTE: currently only used by the disabled image_meta code below.
    active_class_ids = np.zeros([dataset.num_classes], dtype=np.int32)
    source_class_ids = dataset.source_class_ids[dataset.image_info[image_id]["source"]]
    active_class_ids[source_class_ids] = 1

    # Resize masks to smaller size to reduce memory usage
    if use_mini_mask:
        mask = utils.minimize_mask(bbox, mask, config.MINI_MASK_SHAPE)

    # Image meta data
#     image_meta = compose_image_meta(image_id, original_shape, image.shape,
#                                     window, scale, active_class_ids)

    return image, class_ids, bbox, mask

def train(logger, dataset_train):
    """
    perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs.

    :param logger: logger used for all progress messages.
    :param dataset_train: dataset providing image_ids; every id is loaded through load_image_gt() and fed to
        the network as one training batch per epoch.

    Relies on the module-level globals cf (experiment config) and model (imported model module), which are
    assigned in the __main__ section of this file.
    """
    logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
        cf.dim, cf.fold, cf.exp_dir, cf.model))

    net = model.net(cf, logger).cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay)
    model_selector = utils.ModelSelector(cf, logger)
    train_evaluator = Evaluator(cf, logger, mode='train')
    val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)

    starting_epoch = 1

    # prepare monitoring
    monitor_metrics, TrainingPlot = utils.prepare_monitoring(cf)

    if cf.resume_to_checkpoint:
        starting_epoch, monitor_metrics = utils.load_checkpoint(cf.resume_to_checkpoint, net, optimizer)
        logger.info('resumed to checkpoint {} at epoch {}'.format(cf.resume_to_checkpoint, starting_epoch))

    logger.info('loading dataset and initializing batch generators...')
#     batch_gen = data_loader.get_train_generators(cf, logger)
    # NOTE(review): with the line above disabled, batch_gen is not assigned anywhere in this function, yet
    # the validation/plotting code below still reads it - confirm it is defined elsewhere or restore it.


    for epoch in range(starting_epoch, cf.num_epochs + 1):
        logger.info('starting training epoch {}'.format(epoch))
        # cf.learning_rate is indexed per epoch (1-based epoch -> 0-based index).
        for param_group in optimizer.param_groups:
            param_group['lr'] = cf.learning_rate[epoch - 1]

        start_time = time.time()

        net.train() # function built into nn.module
        train_results_list = []
        num_train_batches = len(dataset_train.image_ids)
        for bix, image_id in enumerate(dataset_train.image_ids):
#             batch = next(batch_gen['train'])
            batch = {}
            # NOTE(review): `config` is not defined in the visible part of this file (the experiment config
            # is named cf) - confirm which object is intended here.
            image, class_id, bbox, mask = load_image_gt(dataset_train, config, image_id)

            batch['data'] = image
            batch['roi_labels'] = class_id
            batch['bb_target'] = bbox
            batch['roi_masks'] = mask
            # Tag the batch with the id of the image it was built from; previously every batch carried
            # image_ids[0], so all results were attributed to the same id.
            batch['pid'] = image_id

            tic_fw = time.time()
            results_dict = net.train_forward(batch) ### TODO: ensure this works with mods
            tic_bw = time.time()
            optimizer.zero_grad()
            results_dict['torch_loss'].backward()
            optimizer.step()
            logger.info('tr. batch {0}/{1} (ep. {2}) fw {3:.3f}s / bw {4:.3f}s / total {5:.3f}s || '
                        .format(bix + 1, num_train_batches, epoch, tic_bw - tic_fw,
                                time.time() - tic_bw, time.time() - tic_fw) + results_dict['logger_string'])
            train_results_list.append([results_dict['boxes'], batch['pid']])
            monitor_metrics['train']['monitor_values'][epoch].append(results_dict['monitor_values'])

        _, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train'])
        train_time = time.time() - start_time

        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        with torch.no_grad():
            net.eval()
            if cf.do_validation:
                val_results_list = []
                val_predictor = Predictor(cf, net, logger, mode='val')
                for _ in range(batch_gen['n_val']):
                    batch = next(batch_gen[cf.val_mode])
                    if cf.val_mode == 'val_patient':
                        results_dict = val_predictor.predict_patient(batch)
                    elif cf.val_mode == 'val_sampling':
                        results_dict = net.train_forward(batch, is_validation=True)
                    val_results_list.append([results_dict['boxes'], batch['pid']])
                    monitor_metrics['val']['monitor_values'][epoch].append(results_dict['monitor_values'])

                _, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val'])
                model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)

            # update monitoring and prediction plots
            TrainingPlot.update_and_save(monitor_metrics, epoch)
            epoch_time = time.time() - start_time
            logger.info('trained epoch {}: took {} sec. ({} train / {} val)'.format(
                epoch, epoch_time, train_time, epoch_time-train_time))
            batch = next(batch_gen['val_sampling'])
            results_dict = net.train_forward(batch, is_validation=True)
            logger.info('plotting predictions from validation sampling.')
            plot_batch_prediction(batch, results_dict, cf)


# def test(logger):
#     """
#     perform testing for a given fold (or hold out set). save stats in evaluator.
#     """
#     logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir))
#     net = model.net(cf, logger).cuda()
#     test_predictor = Predictor(cf, net, logger, mode='test')
#     test_evaluator = Evaluator(cf, logger, mode='test')
#     batch_gen = data_loader.get_test_generator(cf, logger)
#     test_results_list = test_predictor.predict_test_set(batch_gen, return_results=True)
#     test_evaluator.evaluate_predictions(test_results_list)
#     test_evaluator.score_test_df()


if __name__ == '__main__':

    # Command-line entry point: parse the run mode and experiment locations, then dispatch to
    # training, testing, analysis of stored predictions, or experiment-folder creation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str,  default='train_test',
                        help='one out of: train / test / train_test / analysis / create_exp')
    parser.add_argument('--folds', nargs='+', type=int, default=None,
                        help='None runs over all folds in CV. otherwise specify list of folds.')
    parser.add_argument('--exp_dir', type=str, default='/path/to/experiment/directory',
                        help='path to experiment dir. will be created if non existent.')
    parser.add_argument('--server_env', default=False, action='store_true',
                        help='change IO settings to deploy models on a cluster.')
    parser.add_argument('--slurm_job_id', type=str, default=None, help='job scheduler info')
    parser.add_argument('--use_stored_settings', default=False, action='store_true',
                        help='load configs from existing exp_dir instead of source dir. always done for testing, '
                             'but can be set to true to do the same for training. useful in job scheduler environment, '
                             'where source code might change before the job actually runs.')
    parser.add_argument('--resume_to_checkpoint', type=str, default=None,
                        help='if resuming to checkpoint, the desired fold still needs to be parsed via --folds.')
    parser.add_argument('--exp_source', type=str, default='experiments/toy_exp',
                        help='specifies, from which source experiment to load configs and data_loader.')

    args = parser.parse_args()
    folds = args.folds
    resume_to_checkpoint = args.resume_to_checkpoint

    if args.mode == 'train' or args.mode == 'train_test':

        # cf and model become module-level globals that train() reads.
        cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, args.use_stored_settings)
        cf.slurm_job_id = args.slurm_job_id
        model = utils.import_module('model', cf.model_path)
        data_loader = utils.import_module('dl', os.path.join(args.exp_source, 'data_loader.py'))
        if folds is None:
            folds = range(cf.n_cv_splits)

        for fold in folds:
            cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
            cf.fold = fold
            cf.resume_to_checkpoint = resume_to_checkpoint
            if not os.path.exists(cf.fold_dir):
                os.mkdir(cf.fold_dir)
            logger = utils.get_logger(cf.fold_dir)
            # train() requires the training dataset as its second argument; pass the module-level
            # dataset_train built earlier in this file (calling train(logger) alone raised TypeError).
            train(logger, dataset_train)
            # only the first processed fold resumes from the given checkpoint.
            cf.resume_to_checkpoint = None
            if args.mode == 'train_test':
                # NOTE(review): the only definition of test() visible in this file is commented out, so
                # this call raises NameError unless test is defined elsewhere.
                test(logger)

    elif args.mode == 'test':

        cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, is_training=False, use_stored_settings=True)
        cf.slurm_job_id = args.slurm_job_id
        model = utils.import_module('model', cf.model_path)
        data_loader = utils.import_module('dl', os.path.join(args.exp_source, 'data_loader.py'))
        if folds is None:
            folds = range(cf.n_cv_splits)

        for fold in folds:
            cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
            logger = utils.get_logger(cf.fold_dir)
            cf.fold = fold
            # NOTE(review): see above - test() has no active definition in the visible file.
            test(logger)

    # load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation.
    elif args.mode == 'analysis':
        cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, is_training=False, use_stored_settings=True)
        logger = utils.get_logger(cf.exp_dir)

        if cf.hold_out_test_set:
            cf.folds = args.folds
            predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
            results_list = predictor.load_saved_predictions(apply_wbc=True)
            utils.create_csv_output(results_list, cf, logger)

        else:
            if folds is None:
                folds = range(cf.n_cv_splits)
            for fold in folds:
                cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
                cf.fold = fold
                predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
                results_list = predictor.load_saved_predictions(apply_wbc=True)
                logger.info('starting evaluation...')
                evaluator = Evaluator(cf, logger, mode='test')
                evaluator.evaluate_predictions(results_list)
                evaluator.score_test_df()

    # create experiment folder and copy scripts without starting job.
    # useful for cloud deployment where configs might change before job actually runs.
    elif args.mode == 'create_exp':
        cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, use_stored_settings=True)
        logger = utils.get_logger(cf.exp_dir)
        logger.info('created experiment directory at {}'.format(args.exp_dir))

    else:
        raise RuntimeError('mode specified in args is not implemented...')

# Mask RCNN

In [17]:
############################################################
#  Mask R-CNN Class
############################################################

class net(nn.Module):
    def __init__(self, cf, logger):
        """
        store config and logger, build the sub-networks and apply the configured weight initialisation.
        :param cf: config object; cf.weight_init selects custom initialisation (None keeps the pytorch default).
        :param logger: logger used for status messages.
        """
        super(net, self).__init__()
        self.cf, self.logger = cf, logger
        self.build()

        # guard clause: nothing more to do when no custom init is requested.
        if self.cf.weight_init is None:
            logger.info("using default pytorch weight init")
            return

        logger.info("using pytorch weight init of type {}".format(self.cf.weight_init))
        mutils.initialize_weights(self)


    def build(self):
        """Build Mask R-CNN architecture."""

        # Image size must be dividable by 2 multiple times.
        h, w = self.cf.patch_size[:2]
        if h / 2**5 != int(h / 2**5) or w / 2**5 != int(w / 2**5):
            raise Exception("Image size must be dividable by 2 at least 5 times "
                            "to avoid fractions when downscaling and upscaling."
                            "For example, use 256, 320, 384, 448, 512, ... etc. ")
        if len(self.cf.patch_size) == 3:
            d = self.cf.patch_size[2]
            if d / 2**3 != int(d / 2**3):
                raise Exception("Image z dimension must be dividable by 2 at least 3 times "
                                "to avoid fractions when downscaling and upscaling.")



        # instanciate abstract multi dimensional conv class and backbone class.
        conv = mutils.NDConvGenerator(self.cf.dim)
        backbone = utils.import_module('bbone', self.cf.backbone_path)

        # build Anchors, FPN, RPN, Classifier / Bbox-Regressor -head, Mask-head
        self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf)
        self.anchors = torch.from_numpy(self.np_anchors).float().cuda()
        self.fpn = backbone.FPN(self.cf, conv)
        self.rpn = RPN(self.cf, conv)
        self.classifier = Classifier(self.cf, conv)
        self.mask = Mask(self.cf, conv)


    def train_forward(self, batch, is_validation=False):
        """
        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
        for processing, computes losses, and stores outputs in a dictionary.
        :param batch: dictionary; the keys read here are 'data', 'roi_labels', 'bb_target' and 'roi_masks'.
        :param is_validation: if True, masks are returned according to cf.return_masks_in_val; otherwise no masks
                are requested from get_results.
        :return: results_dict: dictionary as produced by get_results (presumably containing 'boxes' and 'seg_preds'
                as described below - confirm against get_results), extended here with:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes].
                'torch_loss': summed loss tensor (rpn class + rpn bbox + mrcnn class + mrcnn bbox + mrcnn mask).
                'monitor_values': dict of values to be monitored ('loss' and 'class_loss' as python floats).
                'logger_string': formatted summary of the individual loss terms and the detection count.
        """
        img = batch['data']
        gt_class_ids = batch['roi_labels']
        gt_boxes = batch['bb_target']
        # move the second axis of every mask array to the end (2D: 4 axes, 3D: 5 axes).
        axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1)
        gt_masks = [np.transpose(batch['roi_masks'][ii], axes=axes) for ii in range(len(batch['roi_masks']))]


        img = torch.from_numpy(img).float().cuda()
        # running RPN losses, averaged over the batch inside the loop below.
        batch_rpn_class_loss = torch.FloatTensor([0]).cuda()
        batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda()

        # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
        box_results_list = [[] for _ in range(img.shape[0])]

        #forward passes. 1. general forward pass, where no activations are saved in second stage (for performance
        # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for backprop.
        rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, detection_masks = self.forward(img)
        mrcnn_class_logits, mrcnn_pred_deltas, mrcnn_pred_mask, target_class_ids, mrcnn_target_deltas, target_mask,  \
        sample_proposals = self.loss_samples_forward(gt_class_ids, gt_boxes, gt_masks)

        # loop over batch
        for b in range(img.shape[0]):
            if len(gt_boxes[b]) > 0:

                # add gt boxes to output list for monitoring.
                for ix in range(len(gt_boxes[b])):
                    box_results_list[b].append({'box_coords': batch['bb_target'][b][ix],
                                                'box_label': batch['roi_labels'][b][ix], 'box_type': 'gt'})

                # match gt boxes with anchors to generate targets for RPN losses.
                rpn_match, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b])

                # add positive anchors used for loss to output list for monitoring.
                pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:])
                for p in pos_anchors:
                    box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})

            else:
                # no ground truth in this element: mark every anchor with -1 and use a dummy delta target.
                rpn_match = np.array([-1]*self.np_anchors.shape[0])
                rpn_target_deltas = np.array([0])

            # from here on rpn_match / rpn_target_deltas are CUDA tensors, no longer numpy arrays.
            rpn_match = torch.from_numpy(rpn_match).cuda()
            rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda()

            # compute RPN losses.
            rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_match, rpn_class_logits[b], self.cf.shem_poolsize)
            rpn_bbox_loss = compute_rpn_bbox_loss(rpn_target_deltas, rpn_pred_deltas[b], rpn_match)
            # dividing by the batch size turns the accumulated sum into a batch mean.
            batch_rpn_class_loss += rpn_class_loss / img.shape[0]
            batch_rpn_bbox_loss += rpn_bbox_loss / img.shape[0]

            # add negative anchors used for loss to output list for monitoring.
            # NOTE(review): rpn_match is a torch tensor at this point (converted above) - confirm that
            # np.argwhere handles it as intended.
            neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == -1)][0, neg_anchor_ix], img.shape[2:])
            for n in neg_anchors:
                box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})

            # add highest scoring proposals to output list for monitoring.
            # (sorted by the last column of each proposal row, descending; that column is dropped when storing.)
            rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1]
            for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]:
                box_results_list[b].append({'box_coords': r, 'box_type': 'prop'})

        # add positive and negative roi samples used for mrcnn losses to output list for monitoring.
        if 0 not in sample_proposals.shape:
            rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy()
            for ix, r in enumerate(rois):
                # the last entry of each roi selects the batch element; the rest are coordinates scaled by cf.scale.
                box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale,
                                            'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'})

        # no-op self-assignments (the values are unchanged).
        batch_rpn_class_loss = batch_rpn_class_loss
        batch_rpn_bbox_loss = batch_rpn_bbox_loss

        # compute mrcnn losses.
        mrcnn_class_loss = compute_mrcnn_class_loss(target_class_ids, mrcnn_class_logits)
        mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_target_deltas, mrcnn_pred_deltas, target_class_ids)

        # mrcnn can be run without pixelwise annotations available (Faster R-CNN mode).
        # In this case, the mask_loss is taken out of training.
        if not self.cf.frcnn_mode:
            mrcnn_mask_loss = compute_mrcnn_mask_loss(target_mask, mrcnn_pred_mask, target_class_ids)
        else:
            mrcnn_mask_loss = torch.FloatTensor([0]).cuda()

        loss = batch_rpn_class_loss + batch_rpn_bbox_loss + mrcnn_class_loss + mrcnn_bbox_loss + mrcnn_mask_loss

        # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class.
        # (class 0 is skipped by the [1:] slice.)
        dcount = [list(target_class_ids.cpu().data.numpy()).count(c) for c in np.arange(self.cf.head_classes)[1:]]



        # run unmolding of predictions for monitoring and merge all results to one dictionary.
        return_masks = self.cf.return_masks_in_val if is_validation else False
        results_dict = get_results(self.cf, img.shape, detections, detection_masks,
                                   box_results_list, return_masks=return_masks)

        # 'torch_loss' keeps the tensor for backprop; 'monitor_values' holds plain floats.
        results_dict['torch_loss'] = loss
        results_dict['monitor_values'] = {'loss': loss.item(), 'class_loss': mrcnn_class_loss.item()}

        results_dict['logger_string'] =  \
            "loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, mrcnn_bbox: {4:.2f}, " \
            "mrcnn_mask: {5:.2f}, dcount {6}".format(loss.item(), batch_rpn_class_loss.item(),
                                                     batch_rpn_bbox_loss.item(), mrcnn_class_loss.item(),
                                                     mrcnn_bbox_loss.item(), mrcnn_mask_loss.item(), dcount)

        return results_dict


    def test_forward(self, batch, return_masks=True):
        """
        test method. wrapper around forward pass of network without usage of any ground truth information.
        prepares input data for processing and stores outputs in a dictionary.
        :param batch: dictionary containing 'data' (numpy array, converted to a float tensor here).
        :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
        :return: results_dict: dictionary with keys:
               'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                       [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
               'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
        """
        # use the GPU when one is present; otherwise stay on the CPU instead of raising inside .cuda().
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        img = torch.from_numpy(batch['data']).float().to(device)

        # no ground truth is involved here, so request the inference proposal count
        # (cf.post_nms_rois_inference) instead of relying on forward()'s training default.
        _, _, _, detections, detection_masks = self.forward(img, is_training=False)

        results_dict = get_results(self.cf, img.shape, detections, detection_masks, return_masks=return_masks)
        return results_dict


    def forward(self, img, is_training=True):
        """
        forward pass through FPN, RPN, proposal layer, classifier head and mask head.
        caches self.mrcnn_feature_maps, self.rpn_rois_batch_info and self.batch_mrcnn_class_scores on the
        instance; loss_samples_forward() reads them afterwards.
        :param img: input images (b, c, y, x, (z)).
        :param is_training: if True, cf.post_nms_rois_training proposals are kept, else cf.post_nms_rois_inference.
        :return: rpn_pred_logits: (b, n_anchors, 2)
        :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
        :return: batch_proposal_boxes: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting.
        :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score)
        :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
        """
        # extract features and keep only the configured pyramid levels.
        fpn_outs = self.fpn(img)
        rpn_feature_maps = [fpn_outs[i] for i in self.cf.pyramid_levels]
        self.mrcnn_feature_maps = rpn_feature_maps

        # apply the RPN to every pyramid level.
        layer_outputs = [self.rpn(p) for p in rpn_feature_maps]

        # regroup from per-level output lists to per-output lists across levels and concatenate them.
        # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]]
        outputs = [torch.cat(list(o), dim=1) for o in zip(*layer_outputs)]
        rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas = outputs

        # generate proposals: apply predicted deltas to anchors and filter by foreground scores from RPN classifier.
        proposal_count = self.cf.post_nms_rois_training if is_training else self.cf.post_nms_rois_inference
        batch_rpn_rois, batch_proposal_boxes = proposal_layer(rpn_pred_probs, rpn_pred_deltas, proposal_count,
                                                              self.anchors, self.cf)

        # merge batch dimension of proposals while storing allocation info in coordinate dimension.
        # the index column is created on the same device as the rois it is concatenated with,
        # rather than on whatever the current CUDA device happens to be.
        n_batch, n_rois = batch_rpn_rois.shape[0], batch_rpn_rois.shape[1]
        batch_ixs = torch.from_numpy(np.repeat(np.arange(n_batch), n_rois)).float().to(batch_rpn_rois.device)
        rpn_rois = batch_rpn_rois.view(-1, batch_rpn_rois.shape[2])
        self.rpn_rois_batch_info = torch.cat((rpn_rois, batch_ixs.unsqueeze(1)), dim=1)

        # this is the first of two forward passes in the second stage, where no activations are stored for backprop.
        # here, all proposals are forwarded (with virtual_batch_size = batch_size * post_nms_rois.)
        # for inference/monitoring as well as sampling of rois for the loss functions.
        # processed in chunks of roi_chunk_size to re-adjust to gpu-memory.
        chunked_rpn_rois = self.rpn_rois_batch_info.split(self.cf.roi_chunk_size)
        class_logits_list, bboxes_list = [], []
        with torch.no_grad():
            for chunk in chunked_rpn_rois:
                chunk_class_logits, chunk_bboxes = self.classifier(self.mrcnn_feature_maps, chunk)
                class_logits_list.append(chunk_class_logits)
                bboxes_list.append(chunk_bboxes)
        batch_mrcnn_class_logits = torch.cat(class_logits_list, 0)
        batch_mrcnn_bbox = torch.cat(bboxes_list, 0)
        self.batch_mrcnn_class_scores = F.softmax(batch_mrcnn_class_logits, dim=1)

        # refine classified proposals, filter and return final detections.
        detections = refine_detections(rpn_rois, self.batch_mrcnn_class_scores, batch_mrcnn_bbox, batch_ixs, self.cf)

        # forward remaining detections through mask-head to generate corresponding masks.
        # box coordinates are divided by the image extent; the trailing 1 leaves the batch_ix column unchanged.
        # NOTE(review): y and x are both divided by img.shape[2], i.e. equal in-plane extents are assumed -
        # confirm for non-square inputs.
        scale = [img.shape[2]] * 4 + [img.shape[-1]] * 2
        scale = torch.from_numpy(np.array(scale[:self.cf.dim * 2] + [1])[None]).float().to(detections.device)

        detection_boxes = detections[:, :self.cf.dim * 2 + 1] / scale
        with torch.no_grad():
            detection_masks = self.mask(self.mrcnn_feature_maps, detection_boxes)

        return [rpn_pred_logits, rpn_pred_deltas, batch_proposal_boxes, detections, detection_masks]


    def loss_samples_forward(self, batch_gt_class_ids, batch_gt_boxes, batch_gt_masks):
        """
        this is the second forward pass through the second stage (features from stage one are re-used).
        samples few rois in detection_target_layer and forwards only those for loss computation.
        relies on self.rpn_rois_batch_info, self.batch_mrcnn_class_scores and self.mrcnn_feature_maps,
        which forward() stores on the instance, so forward() has to be run first.
        :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
        :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
        :param batch_gt_masks: list over batch elements. Each element is binary mask of shape (n_gt_rois, y, x, (z), c)
        :return: sample_logits: (n_sampled_rois, n_classes) predicted class scores.
        :return: sample_boxes: (n_sampled_rois, n_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
        :return: sample_mask: (n_sampled_rois, n_classes, y, x, (z)) predicted masks per class and proposal.
        :return: sample_target_class_ids: (n_sampled_rois) target class labels of sampled proposals.
        :return: sample_target_deltas: (n_sampled_rois, 2 * dim) target deltas of sampled proposals for box refinement.
        :return: sample_target_masks: (n_sampled_rois, y, x, (z)) target masks of sampled proposals.
        :return: sample_proposals: (n_sampled_rois, 2 * dim) RPN output for sampled proposals. only for monitoring/plotting.
        """
        # sample rois for loss and get corresponding targets for all Mask R-CNN head network losses.
        sample_ix, sample_target_class_ids, sample_target_deltas, sample_target_mask = \
            detection_target_layer(self.rpn_rois_batch_info, self.batch_mrcnn_class_scores,
                                   batch_gt_class_ids, batch_gt_boxes, batch_gt_masks, self.cf)

        # re-use feature maps and RPN output from first forward pass.
        sample_proposals = self.rpn_rois_batch_info[sample_ix]
        if 0 not in sample_proposals.size():
            sample_logits, sample_boxes = self.classifier(self.mrcnn_feature_maps, sample_proposals)
            sample_mask = self.mask(self.mrcnn_feature_maps, sample_proposals)
        else:
            # nothing was sampled: hand back empty float placeholders that live on the same device
            # as the proposals instead of unconditionally requesting a CUDA device.
            device = sample_proposals.device
            sample_logits = torch.FloatTensor().to(device)
            sample_boxes = torch.FloatTensor().to(device)
            sample_mask = torch.FloatTensor().to(device)

        return [sample_logits, sample_boxes, sample_mask, sample_target_class_ids, sample_target_deltas,
                sample_target_mask, sample_proposals]

NameError: name 'nn' is not defined