# Mask R-CNN - Train on Shapes Dataset


This notebook shows how to train Mask RCNN sequentially so that one can train starting with synthetic data, then retrain models on manual segmentations for better performance.

In [None]:
import os

# Set which GPU to use dynamically (e.g., "4" for GPU 4)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

%load_ext tensorboard
import tensorflow as tf


import sys
import random
import math
import re
import time
import numpy as np
import cv2
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm

# Root directory of the project
ROOT_DIR = os.path.abspath("../")
DATA_DIR = os.path.abspath("../../../Data")
print("ROOT DIR:", ROOT_DIR)
print("Data Dir:",DATA_DIR)
# Import Mask RCNN
sys.path.append(ROOT_DIR)  # To find local version of the library
from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.model import log


print("sys.prefix:",sys.prefix)
print("sys.executable:",sys.executable)
print("sys.path:", sys.path)
%matplotlib inline 

# Show which copy of the mrcnn package was imported, so it is obvious
# whether the local version under ROOT_DIR or an installed one is in use.
mrcnn_dir = os.path.dirname(modellib.__file__)
model_file_path = os.path.join(mrcnn_dir,'model.py')
print("mrcnn directory:",mrcnn_dir)
print("Path to model.py:", model_file_path)

# Directory to save logs and trained model (a "logs" folder inside DATA_DIR)
MODEL_DIR = os.path.join(DATA_DIR, "logs")
print("MODEL DIRECTORY:", MODEL_DIR)

# Local path to the COCO-trained weights file, expected directly under ROOT_DIR
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Download the COCO weights only when the file is not already on disk
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)
    
print("COCO MODEL PATH:", COCO_MODEL_PATH)


#import tensorboard stuff
#%load_ext tensorboard
#import tensorflow as tf
import datetime



## Configurations

In [None]:
class SAGEConfig(Config):
    """Configuration for training on the SAGE particle/cluster datasets.

    Derives from the base Config class and overrides the values listed
    below; every other setting keeps the base-class default.
    """
    # Give the configuration a recognizable name
    NAME = "SAGE"

    # Train on 1 GPU with 2 images per GPU, i.e. 2 images per batch
    # (GPUs * images/GPU).
    GPU_COUNT = 1
    IMAGES_PER_GPU = 2

    # Number of classes (including background)
    NUM_CLASSES = 1 + 1  # background + one foreground class; use 1 + 2 when loading both particle and cluster masks

    # Input image resizing
    # Generally, use the "square" resizing mode for training and predicting
    # and it should work well in most cases. In this mode, images are scaled
    # up such that the small side is = IMAGE_MIN_DIM, but ensuring that the
    # scaling doesn't make the long side > IMAGE_MAX_DIM. Then the image is
    # padded with zeros to make it a square so multiple images can be put
    # in one batch.
    # Available resizing modes:
    # none:   No resizing or padding. Return the image unchanged.
    # square: Resize and pad with zeros to get a square image
    #         of size [max_dim, max_dim].
    # pad64:  Pads width and height with zeros to make them multiples of 64.
    #         If IMAGE_MIN_DIM or IMAGE_MIN_SCALE are not None, then it scales
    #         up before padding. IMAGE_MAX_DIM is ignored in this mode.
    #         The multiple of 64 is needed to ensure smooth scaling of feature
    #         maps up and down the 6 levels of the FPN pyramid (2**6=64).
    # crop:   Picks random crops from the image. First, scales the image based
    #         on IMAGE_MIN_DIM and IMAGE_MIN_SCALE, then picks a random crop of
    #         size IMAGE_MIN_DIM x IMAGE_MIN_DIM. Can be used in training only.
    #         IMAGE_MAX_DIM is not used in this mode.
    IMAGE_RESIZE_MODE = "square"
    IMAGE_MIN_DIM = 1024
    IMAGE_MAX_DIM = 1024
    # Minimum scaling ratio. Checked after IMAGE_MIN_DIM and can force further
    # up scaling. For example, if set to 2 then images are scaled up to double
    # the width and height, or more, even if IMAGE_MIN_DIM doesn't require it.
    # However, in 'square' mode, it can be overruled by IMAGE_MAX_DIM.
    IMAGE_MIN_SCALE = 0
    # Number of color channels per image. RGB = 3, grayscale = 1, RGB-D = 4
    # Changing this requires other changes in the code. See the WIKI for more
    # details: https://github.com/matterport/Mask_RCNN/wiki
    IMAGE_CHANNEL_COUNT = 3  # source images are 8-bit grayscale; switching to 1 would also require changing MEAN_PIXEL below

    MEAN_PIXEL = np.array([123.7, 116.8, 103.9])  # one value per channel; must be a single value if IMAGE_CHANNEL_COUNT becomes 1

    # Anchor side lengths in pixels, one per pyramid level.
    RPN_ANCHOR_SCALES = (32, 64, 128, 256,512)  # alternative considered: (16, 32, 64, 128, 256) to pick up smaller particles

    # Number of ROIs per image fed to the classifier/mask heads during
    # training. Raised from 128 to 256 here.
    TRAIN_ROIS_PER_IMAGE = 256
    # Maximum number of final detections per image (raised from 100 to 200)
    DETECTION_MAX_INSTANCES = 200
    # Training steps per epoch. Values used for other dataset sizes:
    # 76 for 200 images, 20 for 60 images, 188 for 750 images.
    STEPS_PER_EPOCH =7

    # Minimum confidence for a detection to be kept
    # (this is not the non-maximum-suppression threshold).
    DETECTION_MIN_CONFIDENCE = 0.7
    # Validation steps run at the end of each epoch:
    # number of validation images / batch size
    VALIDATION_STEPS = 2

    # Early-stopping settings.
    # NOTE(review): these are not among the options documented above;
    # presumably they are read by the local mrcnn training code — confirm.
    EARLY_STOPPING_MONITOR = 'val_loss'
    EARLY_STOPPING_PATIENCE = 10  # number of epochs with no improvement required to stop
    ES_RESTORE_BEST_WEIGHTS =True
    ES_MODE= "min"
    ES_VERBOSE = 0

# Instantiate the configuration and print every resolved value.
config = SAGEConfig()
config.display()

## Notebook Preferences

In [None]:
#def get_ax(rows=1, cols=1, size=8):
#    """Return a Matplotlib Axes array to be used in
#    all visualizations in the notebook. Provide a
#    central point to control graph sizes.   
#    Change the default size attribute to control the size
#    of rendered images
#    """
#    _, ax = plt.subplots(rows, cols, figsize=(size*cols, size*rows))
#    return ax
    

## Dataset

Create a synthetic dataset

Extend the `Dataset` class, add a method that registers the images and their masks, `load_dataset()`, and override the following methods:

* load_image()
* load_mask()
* image_reference()

In [None]:

class SAGEDataset(utils.Dataset):
    """Dataset of SAGE images with per-instance particle and/or cluster masks.

    File layout expected by the loaders in this class:

    * images:         ``<images_dir>/<prefix>_<number>.png``
    * particle masks: ``<particle_masks_dir>/mask_<number>_<index:06d>.png``
                      (one file per instance, numbered from 0 or from 1)
    * cluster mask:   ``<cluster_masks_dir>/mask_<number>.png`` (one per image)

    Class IDs: when both kinds are loaded, particle is 1 and cluster is 2;
    when only one kind is loaded, that kind is class 1.
    """

    def __init__(self, images_dir, particle_masks_dir, cluster_masks_dir, load_particle=True, load_cluster=True):
        """Store the directories and register the foreground classes.

        Args:
            images_dir: directory containing the ``.png`` images.
            particle_masks_dir: directory containing per-instance particle masks.
            cluster_masks_dir: directory containing per-image cluster masks.
            load_particle: whether particle masks are loaded (class "particle").
            load_cluster: whether cluster masks are loaded (class "cluster").
        """
        super().__init__()
        self.images_dir = images_dir
        self.particle_masks_dir = particle_masks_dir
        self.cluster_masks_dir = cluster_masks_dir
        self.load_particle = load_particle
        self.load_cluster = load_cluster

        if load_particle and load_cluster:
            self.class_names = ["particle", "cluster"]
            self.add_class("SAGE", 1, "particle")
            self.add_class("SAGE", 2, "cluster")
        elif load_particle:
            self.class_names = ["particle"]
            self.add_class("SAGE", 1, "particle")
        elif load_cluster:
            self.class_names = ["cluster"]
            self.add_class("SAGE", 1, "cluster")

    def load_dataset(self):
        """Register every ``.png`` in ``images_dir`` and touch all of its masks.

        Images are registered in ascending order of the number that follows
        the first underscore in the file name.
        """
        image_filenames = [f for f in os.listdir(self.images_dir) if f.endswith('.png')]
        # Sort numerically so that e.g. "img_10.png" comes after "img_2.png".
        image_filenames.sort(key=lambda x: int(x.split('_')[1].split('.')[0]))

        for image_id, filename in enumerate(tqdm(image_filenames, "Adding images", dynamic_ncols=True)):
            image_path = os.path.join(self.images_dir, filename)
            basename = os.path.splitext(filename)[0]
            self.add_image("SAGE", image_id=image_id, path=image_path, basename=basename, width=None, height=None)

        # The returned masks are discarded: this pass only reads every mask
        # once up front, so a missing cluster mask is reported now rather
        # than in the middle of training.
        for image_id in tqdm(range(len(self.image_info)), desc="Loading masks for images", dynamic_ncols=True):
            self.load_mask(image_id)

    def load_image(self, image_id):
        """Load the image registered under ``image_id`` as an RGB array.

        Raises:
            FileNotFoundError: if OpenCV cannot read the registered path.
        """
        info = self.image_info[image_id]
        image = cv2.imread(info['path'])
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {info['path']}")
        # OpenCV decodes to BGR channel order; convert to RGB. For grayscale
        # source images all three channels are equal, so this is a no-op.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def image_reference(self, image_id):
        """Return the file path of a SAGE image, else defer to the base class."""
        info = self.image_info[image_id]
        if info["source"] == "SAGE":
            return info["path"]
        return super().image_reference(image_id)

    def load_mask(self, image_id):
        """Load the instance masks of the image registered under ``image_id``.

        Returns:
            A tuple ``(masks, class_ids)``: a bool array with the instances
            stacked along the last axis, and an int32 array with one class ID
            per instance. When no mask file exists, both arrays are empty.
        """
        info = self.image_info[image_id]
        masks = []
        class_ids = []

        if self.load_particle:
            particle_masks, particle_ids = self._load_class_masks(
                info, self.particle_masks_dir, class_id=1, pattern='particle')
            masks.extend(particle_masks)
            class_ids.extend(particle_ids)

        if self.load_cluster:
            # Cluster is class 2 next to particles, class 1 when loaded alone
            # (matches the add_class() calls in __init__).
            cluster_class_id = 2 if self.load_particle else 1
            cluster_masks, cluster_ids = self._load_class_masks(
                info, self.cluster_masks_dir, class_id=cluster_class_id, pattern='cluster')
            masks.extend(cluster_masks)
            class_ids.extend(cluster_ids)

        if masks:
            return np.stack(masks, axis=-1), np.array(class_ids, dtype=np.int32)

        # No instances: keep the mask three-dimensional so that code indexing
        # the instance axis still works.
        return np.zeros((0, 0, 0), dtype=np.bool_), np.zeros((0,), dtype=np.int32)

    def _load_class_masks(self, info, masks_dir, class_id, pattern):
        """Load the masks of one class for one image.

        Args:
            info: the ``image_info`` entry of the image.
            masks_dir: directory that holds the mask files.
            class_id: class ID recorded for every mask that is loaded.
            pattern: ``'particle'`` (many numbered masks per image) or
                ``'cluster'`` (a single mask per image).

        Returns:
            ``(masks, class_ids)`` as two parallel lists; both are empty when
            no mask file is found or ``pattern`` is not recognised.
        """
        masks = []
        class_ids = []

        # The image number is the part of the file name after the first "_".
        image_filename = os.path.basename(info['path'])
        image_no = image_filename.split('_')[1].replace('.png', '')

        if pattern == 'particle':
            def particle_mask_path(index):
                return os.path.join(masks_dir, f"mask_{image_no}_{index:06d}.png")

            # Numbering may start at 0 or at 1. If neither file exists the
            # image has no particle masks; return empty lists instead of
            # probing index 1 forever.
            start = next((s for s in (0, 1) if os.path.exists(particle_mask_path(s))), None)
            if start is not None:
                index = start
                while os.path.exists(particle_mask_path(index)):
                    mask = cv2.imread(particle_mask_path(index), cv2.IMREAD_GRAYSCALE)
                    if mask is not None:  # skip files OpenCV cannot decode
                        masks.append(mask.astype(np.bool_))
                        class_ids.append(class_id)
                    index += 1

        elif pattern == 'cluster':
            mask_path = os.path.join(masks_dir, f"mask_{image_no}.png")
            if os.path.exists(mask_path):
                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                if mask is not None:
                    masks.append(mask.astype(np.bool_))
                    class_ids.append(class_id)
            else:
                print(f"Mask file not found: {mask_path}")

        return masks, class_ids
                                 
        
                                
                                 
        

## Load Datasets:

Here, load the image sets that will be used for training. Multiple can be loaded for easy access and transfer learning


In [None]:
def load_and_register_dataset(dataset_name, results_dir, load_particle=True, load_cluster=False, create_dirs=True):
    """Build a prepared SAGEDataset and record it in the global ``datasets`` dict.

    The images are read from ``ROOT_DIR/SAGE/data/<dataset_name>``, with the
    particle and cluster masks in its ``particle`` and ``cluster`` sub-folders.

    Args:
        dataset_name: folder name of the dataset; also the key used in ``datasets``.
        results_dir: passed to ``create_dataset_results_dirs`` when ``create_dirs``
            is true; a falsy value skips that call.
        load_particle: forwarded to ``SAGEDataset``.
        load_cluster: forwarded to ``SAGEDataset``.
        create_dirs: whether to call ``create_dataset_results_dirs``.
            NOTE(review): that helper is not defined in the visible part of the
            notebook — confirm it exists before calling with ``create_dirs=True``.

    Returns:
        The loaded and prepared ``SAGEDataset``.
    """
    dataset_root = os.path.join(ROOT_DIR, 'SAGE', 'data', dataset_name)

    if create_dirs and results_dir:
        create_dataset_results_dirs(dataset_name, results_dir)

    loaded = SAGEDataset(
        dataset_root,
        os.path.join(dataset_root, 'particle'),
        os.path.join(dataset_root, 'cluster'),
        load_particle=load_particle,
        load_cluster=load_cluster,
    )
    loaded.load_dataset()
    loaded.prepare()

    # Register under its name so later cells can look it up with datasets.get().
    datasets[dataset_name] = loaded
    return loaded
    
def print_loaded_datasets(loaded_datasets):
    """Print one line per registered dataset name.

    Args:
        loaded_datasets (dict): mapping whose keys are dataset names. An empty
            or ``None`` value prints "No datasets loaded." instead.
    """
    if loaded_datasets:
        print("Loaded Datasets:")
        for registered_name in loaded_datasets.keys():
            print(f"- {registered_name}")
    else:
        print("No datasets loaded.")

In [None]:
#create dict to store datasets
datasets = {}

Load the first training and validation datasets.
Here we use NS40_train and NS40_val to train the first model on synthetically generated images.

Since we do not need to create a results directory for these, set results_dir = None and create_dirs = False


In [None]:
NS40_train = load_and_register_dataset('NS40_train', results_dir=None, create_dirs=False)
NS40_val = load_and_register_dataset('NS40_val', results_dir=None, create_dirs=False)

We will also load datasets for any subsequent training, such as if we want to train using manual segmentations after synthetic.

After synthetic model (SAGE$_0$) is trained, we will train that model again using a set of manual segmentations, (D1e1_train and D1e1_val). This will give us SAGE$_1$.

We can train again on another set of the same images, where each image was annotated by a different analyst than in D1e1. These are D2e1_train and D2e1_val. Training SAGE$_1$ on this set gives SAGE$_2$

In [None]:
#First set of manual segmentations
D1e1_train = load_and_register_dataset('D1e1_train', results_dir=None, create_dirs=False)
D1e1_val = load_and_register_dataset('D1e1_val', results_dir=None, create_dirs=False)

#Second Set

D2e1_train = load_and_register_dataset('D2e1_train', results_dir=None, create_dirs=False)
D2e1_val = load_and_register_dataset('D2e1_val', results_dir=None, create_dirs=False)


In [None]:
# Sanity check: for every registered dataset, show the first image together
# with its masks so that a wrong path or naming scheme is noticed early.
for name, dataset in datasets.items():
    print(f"\n--- Dataset: {name} ---")
    if len(dataset.image_ids) > 0:
        image_id = dataset.image_ids[0]
        print(f"Image ID:{image_id}")

        image = dataset.load_image(image_id)
        mask, class_ids = dataset.load_mask(image_id)

        print(f"Mask shape for Image ID {image_id}: {mask.shape}")
        visualize.display_top_masks(image, mask, class_ids, dataset.class_names)
    else:
        print("No images loaded")


## Create Model

Now we will create and train our first model, trained on the synthetic images. 



In [None]:
print(MODEL_DIR)

config.STEPS_PER_EPOCH =2
config.VALIDATION_STEPS = 1

config.display()

In [None]:
# Create model in training mode
model = modellib.MaskRCNN(mode="training", config=config,
                          model_dir=MODEL_DIR)



Here we will tell it which weights to start with. 

Since we are starting from 'scratch', we will choose COCO weights as a base

In [None]:
# Which weights to start with?

# Checkpoint that is loaded only when init_with == "manual".
manual_path = os.path.join(DATA_DIR, "logs/coco_M1e1_part/mask_rcnn_spectra_0034.h5")

init_with = "coco"  # one of: "imagenet", "coco", "manual", "last"

if init_with == "imagenet":
    model.load_weights(model.get_imagenet_weights(), by_name=True)
elif init_with == "coco":
    # Load weights trained on MS COCO, but skip layers that
    # are different due to the different number of classes
    # See README for instructions to download the COCO weights
    model.load_weights(COCO_MODEL_PATH, by_name=True,
                       exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
                                "mrcnn_bbox", "mrcnn_mask"])
elif init_with == "manual":
    # Load the weights from the path given above
    model.load_weights(manual_path, by_name=True)
elif init_with == "last":
    # Load the last model you trained and continue training
    model.load_weights(model.find_last(), by_name=True)
else:
    # Fail loudly: a typo here would otherwise leave the model with
    # randomly initialised weights without any warning.
    raise ValueError(f"Unknown init_with value: {init_with!r}")

In [None]:
config.display()

## Training - Synthetic Model (SAGE$_0$)



Train in two stages:
1. Only the heads. Here we're freezing all the backbone layers and training only the randomly initialized layers (i.e. the ones that we didn't use pre-trained weights from MS COCO). To train only the head layers, pass `layers='heads'` to the `train()` function.

2. Fine-tune all layers. For this simple example it's not necessary, but we're including it to show the process. Simply pass `layers="all"` to train all layers.



In [None]:
# Train the head branches
# Passing layers="heads" freezes all layers except the head
# layers. You can also pass a regular expression to select
# which layers to train by name pattern.

# Datasets used for this stage. datasets.get() returns None when the name
# was never registered, so make sure the loading cells above ran first.
dataset_train = datasets.get('NS40_train', None)
dataset_val = datasets.get('NS40_val', None)

# Stage 1 of 2: heads only, at the base learning rate from the config.
model.train(dataset_train,dataset_val, 
            learning_rate=config.LEARNING_RATE, 
            epochs=100, 
            layers='heads')

Now that the heads are trained, train full layers

In [None]:
# Fine tune all layers
# Passing layers="all" trains all layers. You can also 
# pass a regular expression to select which layers to
# train by name pattern.

# NOTE(review): in upstream Matterport Mask R-CNN `epochs` is the final epoch
# number, not a count of additional epochs — confirm for the local mrcnn copy.
model.train(dataset_train, dataset_val, 
            learning_rate=config.LEARNING_RATE / 10,  # one tenth of the base rate (an earlier note said /5; the code uses /10)
            epochs=200, 
            layers="all")

In [None]:
import io
def save_displayed_config(config, log_dir, train_path, val_path):
    """Write the output of ``config.display()`` plus the data paths to ``config.txt``.

    Args:
        config: object with a ``display()`` method that prints the configuration.
        log_dir: existing directory in which ``config.txt`` is created or
            overwritten.
        train_path: path recorded as "Training Data Path".
        val_path: path recorded as "Validation Data Path".
    """
    import contextlib  # local import: only this function needs it

    captured_output = io.StringIO()
    # redirect_stdout restores whatever stream was active before, even if
    # display() raises. Assigning sys.__stdout__ instead would detach the
    # notebook's own output stream, and later cells would print nothing.
    with contextlib.redirect_stdout(captured_output):
        config.display()
        print(f"Training Data Path: {train_path}")
        print(f"Validation Data Path: {val_path}")

    config_path_txt = os.path.join(log_dir, "config.txt")
    with open(config_path_txt, "w") as f:
        f.write(captured_output.getvalue())

# Locate the most recent checkpoint and the run folder that contains it.
name = model.find_last()
print(name)
# Use os.path for both steps so the result does not depend on the path
# separator (the previous '/logs/' split only worked with forward slashes).
folder_name = os.path.basename(os.path.dirname(name))
model_name = os.path.join(folder_name, os.path.basename(name))
print(model_name)
print(folder_name)

path = os.path.join(MODEL_DIR, folder_name)
print(path)

# Store the configuration next to the checkpoints of this run.
save_displayed_config(config, log_dir=path,
                      train_path=dataset_train.particle_masks_dir,
                      val_path=dataset_val.particle_masks_dir)

Now, we can use that last model that was trained, and retrain it with our first set of manual segmentations.

We can tell it to start with the weights from that model, either by using `find_last` or by specifying the path.

In [None]:
# Checkpoint that is loaded only when init_with == "manual".
manual_path = os.path.join(DATA_DIR, "logs/coco_M1e1_part/mask_rcnn_spectra_0034.h5")

init_with = "last"  # this cell only handles "manual" and "last"

# NOTE(review): a later cell constructs a new MaskRCNN object and rebinds
# `model`; weights loaded here only affect the object bound to `model` now.
if init_with =="manual":
    # Load the weights from the path given above
    model.load_weights(manual_path, by_name=True)

elif init_with == "last":
    # Load the last model you trained and continue training
    model.load_weights(model.find_last(), by_name=True)

In [None]:
print_loaded_datasets(datasets)

In [None]:
#assign new training and validaton datasets

dataset_train = datasets.get('D1e1_train', None)
dataset_val = datasets.get('D1e1_val', None)

In [None]:
# Create a fresh model in training mode for this training stage.
model = modellib.MaskRCNN(mode="training", config=config,
                          model_dir=MODEL_DIR)

# The constructor above replaces the previous `model` object, including any
# weights loaded into it earlier. Load the chosen starting weights into the
# new object, otherwise this stage would not build on the previous model.
if init_with == "manual":
    model.load_weights(manual_path, by_name=True)
elif init_with == "last":
    model.load_weights(model.find_last(), by_name=True)
else:
    raise ValueError(f"Unknown init_with value: {init_with!r}")

# NOTE(review): in upstream Matterport Mask R-CNN, load_weights() resumes the
# epoch counter and log folder from the checkpoint path. Calling set_log_dir()
# without arguments starts a new run (epoch 0, new timestamped folder) so the
# train() calls below are not skipped — confirm the local mrcnn copy matches.
model.set_log_dir()



In [None]:
# Train the head branches
# Passing layers="heads" freezes all layers except the head
# layers. You can also pass a regular expression to select
# which layers to train by name pattern.
#model.train(cluster_dataset_train, cluster_dataset_val, 
#            learning_rate=config.LEARNING_RATE, 
#            epochs=100, 
#            layers='heads')

#get epoch that it stopped at
model.train(dataset_train, dataset_val, 
            learning_rate=config.LEARNING_RATE, 
            epochs=100, 
            layers='heads',
            )

In [None]:
# Fine tune all layers
# Passing layers="all" trains all layers. You can also 
# pass a regular expression to select which layers to
# train by name pattern.

# NOTE(review): in upstream Matterport Mask R-CNN `epochs` is the final epoch
# number, not a count of additional epochs — confirm for the local mrcnn copy.
model.train(dataset_train, dataset_val, 
            learning_rate=config.LEARNING_RATE / 10,  # one tenth of the base rate (an earlier note said /5; the code uses /10)
            epochs=200, 
            layers="all")

In [None]:
# Locate the most recent checkpoint and the run folder that contains it.
name = model.find_last()
print(name)
folder_name = os.path.basename(os.path.dirname(name))
model_name = os.path.join(folder_name, os.path.basename(name))
print(model_name)
print(folder_name)

# Recompute `path` for THIS run. Reusing the value left over from the first
# training stage would overwrite that run's config.txt instead.
path = os.path.join(MODEL_DIR, folder_name)
print(path)

save_displayed_config(config, log_dir=path,
                      train_path=dataset_train.particle_masks_dir,
                      val_path=dataset_val.particle_masks_dir)

Now, we repeat the same process with the 2nd set of manual segmentations, D2e1_train, and D2e1_val

In [None]:
# Checkpoint that is loaded only when init_with == "manual".
manual_path = os.path.join(DATA_DIR, "logs/coco_M1e1_part/mask_rcnn_spectra_0034.h5")

init_with = "last"  # this cell only handles "manual" and "last"

# NOTE(review): a later cell constructs a new MaskRCNN object and rebinds
# `model`; weights loaded here only affect the object bound to `model` now.
if init_with =="manual":
    # Load the weights from the path given above
    model.load_weights(manual_path, by_name=True)

elif init_with == "last":
    # Load the last model you trained and continue training
    model.load_weights(model.find_last(), by_name=True)

In [None]:
print_loaded_datasets(datasets)

# Switch to the second set of manual segmentations. Without this assignment
# the training cells below would reuse the D1e1 datasets selected earlier.
dataset_train = datasets.get('D2e1_train', None)
dataset_val = datasets.get('D2e1_val', None)

In [None]:
# Create a fresh model in training mode for this training stage.
model = modellib.MaskRCNN(mode="training", config=config,
                          model_dir=MODEL_DIR)

# The constructor above replaces the previous `model` object, including any
# weights loaded into it earlier. Load the chosen starting weights into the
# new object, otherwise this stage would not build on the previous model.
if init_with == "manual":
    model.load_weights(manual_path, by_name=True)
elif init_with == "last":
    model.load_weights(model.find_last(), by_name=True)
else:
    raise ValueError(f"Unknown init_with value: {init_with!r}")

# NOTE(review): in upstream Matterport Mask R-CNN, load_weights() resumes the
# epoch counter and log folder from the checkpoint path. Calling set_log_dir()
# without arguments starts a new run (epoch 0, new timestamped folder) so the
# train() calls below are not skipped — confirm the local mrcnn copy matches.
model.set_log_dir()

In [None]:
# Train the head branches
# Passing layers="heads" freezes all layers except the head
# layers. You can also pass a regular expression to select
# which layers to train by name pattern.
#model.train(cluster_dataset_train, cluster_dataset_val, 
#            learning_rate=config.LEARNING_RATE, 
#            epochs=100, 
#            layers='heads')

#get epoch that it stopped at
model.train(dataset_train, dataset_val, 
            learning_rate=config.LEARNING_RATE, 
            epochs=100, 
            layers='heads',
            )



In [None]:
# Fine tune all layers
# Passing layers="all" trains all layers. You can also 
# pass a regular expression to select which layers to
# train by name pattern.

# NOTE(review): in upstream Matterport Mask R-CNN `epochs` is the final epoch
# number, not a count of additional epochs — confirm for the local mrcnn copy.
model.train(dataset_train, dataset_val, 
            learning_rate=config.LEARNING_RATE / 10,  # one tenth of the base rate (an earlier note said /5; the code uses /10)
            epochs=200, 
            layers="all")

In [None]:
# Locate the most recent checkpoint and the run folder that contains it.
name = model.find_last()
print(name)
folder_name = os.path.basename(os.path.dirname(name))
model_name = os.path.join(folder_name, os.path.basename(name))
print(model_name)
print(folder_name)

# Recompute `path` for THIS run. Reusing the value left over from the first
# training stage would overwrite that run's config.txt instead.
path = os.path.join(MODEL_DIR, folder_name)
print(path)

save_displayed_config(config, log_dir=path,
                      train_path=dataset_train.particle_masks_dir,
                      val_path=dataset_val.particle_masks_dir)