# Mask R-CNN Training Script

This notebook demonstrates how to train a Mask R-CNN model using custom training code. It includes:
1. Model creation with optional pretrained weights.
2. A main function to handle training logic, such as:
   - Data loading (COCO or VOC format)
   - Distributed training
   - Mixed-precision training
   - Learning rate scheduling
   - Evaluation after each epoch
   - Saving checkpoints

## Contents
1. **Imports**: Necessary libraries and modules.
2. **Create Model Function**: Builds Mask R-CNN with a ResNet50-FPN backbone.
3. **Main Function**: The core training loop.
4. **Argument Parsing**: Command-line arguments for customizing training behavior.

### Notes
1. Make sure you have the required environment with `torch`, `torchvision`, and other dependencies installed.
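2. Ensure the supporting modules (`transforms`, `my_dataset_coco`, `my_dataset_voc`, `backbone`, `network_files`, `train_utils`, and `plot_curve`) are importable — on your Python path or in the same directory — and that the weight files referenced by `create_model` (`resnet50.pth`, and `./maskrcnn_resnet50_fpn_coco.pth` when COCO-pretrained weights are loaded) are present in the working directory.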
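3. If running directly in Jupyter, `parser.parse_args()` reads the kernel's own command-line arguments and will usually fail; pass an explicit list instead (e.g. `args = parser.parse_args([])` for the defaults, or `parser.parse_args(['--epochs', '1'])`) and then call `main(args)`.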


In [ ]:
import datetime
import os
import time

import torch
from torchvision.ops.misc import FrozenBatchNorm2d

import transforms
from my_dataset_coco import CocoDetection
from my_dataset_voc import VOCInstances
from backbone import resnet50_fpn_backbone
from network_files import MaskRCNN
import train_utils.train_eval_utils as utils
from train_utils import (
    GroupedBatchSampler,
    create_aspect_ratio_groups,
    init_distributed_mode,
    save_on_master,
    mkdir,
)


## Create Model Function
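This function builds a Mask R-CNN model with a ResNet50-FPN backbone. By default it also loads COCO-pretrained weights from `./maskrcnn_resnet50_fpn_coco.pth`, skipping the `box_predictor` and `mask_fcn_logits` layers so the prediction heads keep the sizes implied by `num_classes`; pass `load_pretrain_weights=False` to skip this step.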

In [ ]:
def create_model(num_classes, load_pretrain_weights=True):
    """
    Build a Mask R-CNN model on top of a ResNet50-FPN backbone.

    Args:
        num_classes (int): Total number of output classes, i.e. the number of
            object classes plus one for the background.
        load_pretrain_weights (bool): When True, initialise the model from the
            COCO checkpoint at ``./maskrcnn_resnet50_fpn_coco.pth``.
    Returns:
        model (torch.nn.Module): The constructed Mask R-CNN model.
    Note:
        If GPU memory is limited, consider using FrozenBatchNorm2d instead of
        nn.BatchNorm2d in the backbone.
    """
    # The backbone is given the path "resnet50.pth" and asked for 3 trainable
    # layers; which layers that selects is decided by resnet50_fpn_backbone
    # (presumably layer2..layer4 -- TODO confirm against the backbone module).
    fpn_backbone = resnet50_fpn_backbone(pretrain_path="resnet50.pth", trainable_layers=3)
    model = MaskRCNN(fpn_backbone, num_classes=num_classes)

    if not load_pretrain_weights:
        return model

    # Read the COCO checkpoint onto the CPU; load_state_dict copies the values
    # into the model's own parameters afterwards.
    weights_dict = torch.load("./maskrcnn_resnet50_fpn_coco.pth", map_location="cpu")

    # Drop the final prediction layers so their COCO-sized tensors are not
    # loaded into heads that were built for `num_classes`.
    head_markers = ("box_predictor", "mask_fcn_logits")
    stale_keys = [
        key for key in weights_dict
        if any(marker in key for marker in head_markers)
    ]
    for key in stale_keys:
        del weights_dict[key]

    # strict=False tolerates the keys removed above; the printed result lists
    # the missing/unexpected keys for inspection.
    print(model.load_state_dict(weights_dict, strict=False))

    return model


## Main Training Function
This function:
1. Initializes distributed training (if applicable).
2. Creates data transforms for training and validation.
3. Loads COCO or VOC datasets.
4. Sets up data loaders with possible aspect ratio grouping.
5. Creates the Mask R-CNN model and loads pretrained weights if requested.
6. Handles optimizer, learning rate scheduler, and mixed-precision training (AMP).
7. Optionally resumes training from a checkpoint.
8. Trains for the specified number of epochs, evaluating after each epoch.
9. Logs and saves model checkpoints.


In [ ]:
def main(args):
    """
    Train a Mask R-CNN model, or only evaluate it when ``args.test_only`` is set.

    The function builds the train/val datasets and loaders, creates the model,
    optimizer, LR scheduler and optional AMP scaler, optionally resumes from a
    checkpoint, then runs the epoch loop: train, step the scheduler, evaluate,
    append metrics to the result files and save a checkpoint.

    Args:
        args (argparse.Namespace): Parsed options. Reads ``device``,
            ``data_path``, ``aspect_ratio_group_factor``, ``batch_size``,
            ``workers``, ``num_classes``, ``pretrain``, ``sync_bn``, ``lr``,
            ``momentum``, ``weight_decay``, ``amp``, ``lr_steps``,
            ``lr_gamma``, ``resume``, ``start_epoch``, ``test_only``,
            ``epochs``, ``print_freq`` and ``output_dir``. ``distributed``,
            ``gpu`` and ``rank`` are not defined by the argument parser in
            this file; they are presumably filled in by
            ``init_distributed_mode`` -- TODO confirm.
    Returns:
        None. ``args.start_epoch`` is overwritten when resuming.
    """
    # Initialize distributed training if needed
    init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Only the main process records metrics and plots. `rank` may be absent
    # when the run is not distributed, so fall back to -1 (single process)
    # instead of raising AttributeError.
    is_main_process = getattr(args, "rank", -1) in (-1, 0)

    # File names for saving COCO evaluation results
    now = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    det_results_file = f"det_results{now}.txt"
    seg_results_file = f"seg_results{now}.txt"

    # Create data transforms
    print("Loading data")
    data_transform = {
        "train": transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(0.5)
        ]),
        "val": transforms.Compose([
            transforms.ToTensor()
        ])
    }

    COCO_root = args.data_path  # Path to COCO dataset root
    # Example for VOC:
    # data_root = "/path/to/VOCdevkit"

    # Load training dataset (COCO or VOC)
    train_dataset = CocoDetection(COCO_root, "train", data_transform["train"])
    # Alternatively:
    # train_dataset = VOCInstances(data_root, year="2012", txt_name="train.txt")

    # Load validation dataset (COCO or VOC)
    val_dataset = CocoDetection(COCO_root, "val", data_transform["val"])
    # Alternatively:
    # val_dataset = VOCInstances(data_root, year="2012", txt_name="val.txt")

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    else:
        train_sampler = torch.utils.data.RandomSampler(train_dataset)
        test_sampler = torch.utils.data.SequentialSampler(val_dataset)

    # Aspect ratio grouping (a negative factor disables it)
    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(train_dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(
            train_sampler, args.batch_size, drop_last=True)

    data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=train_batch_sampler,
        num_workers=args.workers,
        collate_fn=train_dataset.collate_fn
    )

    # The validation loader collates with the validation dataset's own
    # collate_fn, so it stays correct if the two datasets ever differ.
    data_loader_test = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        sampler=test_sampler,
        num_workers=args.workers,
        collate_fn=val_dataset.collate_fn
    )

    print("Creating model")
    # Create model; num_classes includes background, so we add 1
    model = create_model(num_classes=args.num_classes + 1, load_pretrain_weights=args.pretrain)
    model.to(device)

    # Convert BatchNorm to SyncBatchNorm if requested (for multi-GPU)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # Keep a handle on the unwrapped model so checkpoints are saved/loaded
    # without the DistributedDataParallel wrapper.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.gpu]
        )
        model_without_ddp = model.module

    # Collect trainable parameters
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(
        params,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # AMP scaler for mixed precision
    scaler = torch.cuda.amp.GradScaler() if args.amp else None

    # MultiStepLR: multiply the LR by lr_gamma at each milestone epoch
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=args.lr_steps,
        gamma=args.lr_gamma
    )

    # Resume from a checkpoint if provided
    if args.resume:
        # NOTE(review): checkpoints written below also store the argparse
        # namespace; PyTorch versions whose torch.load defaults to
        # weights_only=True may refuse to unpickle it -- confirm against the
        # installed version.
        checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        # The checkpoint stores the last finished epoch; continue with the next.
        args.start_epoch = checkpoint['epoch'] + 1
        if args.amp and "scaler" in checkpoint:
            scaler.load_state_dict(checkpoint["scaler"])

    # If test_only is specified, run evaluation and exit
    if args.test_only:
        utils.evaluate(model, data_loader_test, device=device)
        return

    # Lists for tracking metrics
    train_loss = []
    learning_rate = []
    val_map = []

    print("Start training")
    start_time = time.time()

    for epoch in range(args.start_epoch, args.epochs):
        # For distributed training, set epoch for the sampler
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # Train for one epoch
        mean_loss, lr = utils.train_one_epoch(
            model,
            optimizer,
            data_loader,
            device,
            epoch,
            args.print_freq,
            warmup=True,
            scaler=scaler
        )

        # Step the scheduler to update the learning rate
        lr_scheduler.step()

        # Evaluate model after each epoch
        det_info, seg_info = utils.evaluate(model, data_loader_test, device=device)

        # Only record and save on the main process
        if is_main_process:
            loss_value = mean_loss.item()
            train_loss.append(loss_value)
            learning_rate.append(lr)
            # det_info[1] often corresponds to mAP (depending on how evaluate is implemented)
            val_map.append(det_info[1])

            # Append one line per epoch to the detection and segmentation
            # result files: metrics, then mean loss, then learning rate.
            for results_file, info in ((det_results_file, det_info),
                                       (seg_results_file, seg_info)):
                result_info = [f"{i:.4f}" for i in [*info, loss_value]] + [f"{lr:.6f}"]
                txt = "epoch:{} {}".format(epoch, '  '.join(result_info))
                with open(results_file, "a") as f:
                    f.write(txt + "\n")

        # Save checkpoints if an output directory is specified
        if args.output_dir:
            save_files = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'args': args,
                'epoch': epoch
            }
            if args.amp:
                save_files["scaler"] = scaler.state_dict()

            save_on_master(
                save_files,
                os.path.join(args.output_dir, f'model_{epoch}.pth')
            )

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    # Plot curves only on the main process
    if is_main_process:
        if train_loss and learning_rate:
            from plot_curve import plot_loss_and_lr
            plot_loss_and_lr(train_loss, learning_rate)

        if val_map:
            from plot_curve import plot_map
            plot_map(val_map)


## Argument Parsing and Entry Point
Here we define command-line arguments for training. In typical use, you'd run this script from a terminal, e.g.:
```
python train_script.py --data-path /data/coco2017 --epochs 30 --batch-size 8
```
In a notebook environment, you can simulate arguments by providing `args` manually or adjusting them as needed.

In [ ]:
def str2bool(value):
    """
    Convert a command-line token into a bool, for use as argparse ``type=``.

    ``type=bool`` treats every non-empty string (including "False") as True;
    this converter interprets the text instead.

    Args:
        value (str | bool): Token from the command line, or an actual bool.
    Returns:
        bool: True for true/t/yes/y/1/on, False for false/f/no/n/0/off or the
            empty string (case-insensitive, surrounding whitespace ignored).
    Raises:
        ValueError: If the token is not one of the recognised spellings;
            argparse reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays False, as it was under the old `type=bool`.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Mask R-CNN Training")

    # Path to the root of your dataset, e.g. COCO
    parser.add_argument('--data-path', default='/data/coco2017', help='Path to dataset root')
    # Device (e.g. 'cuda' or 'cpu')
    parser.add_argument('--device', default='cuda', help='Device (cuda or cpu)')
    # Number of classes (excluding background). The default of 90 presumably
    # matches COCO's sparse category ids (1..90) rather than its 80 used
    # categories -- confirm against the dataset class.
    parser.add_argument('--num-classes', default=90, type=int, help='Number of object classes (excl. background)')
    # Batch size per GPU
    parser.add_argument('-b', '--batch-size', default=4, type=int,
                        help='Images per GPU; total batch size is num_GPU x batch_size')
    # Starting epoch if resuming training; the dashed spelling matches the
    # other options, the underscore spelling is kept for existing commands.
    parser.add_argument('--start-epoch', '--start_epoch', default=0, type=int,
                        dest='start_epoch', help='Start epoch')
    # Total training epochs
    parser.add_argument('--epochs', default=26, type=int, metavar='N',
                        help='Number of total epochs to run')
    # Number of data loading workers
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='Number of data loading workers (default: 4)')
    # Initial learning rate
    parser.add_argument('--lr', default=0.005, type=float,
                        help='Initial learning rate (adjust according to GPU count and batch size)')
    # Momentum for SGD
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='Momentum')
    # Weight decay for SGD
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='Weight decay (default: 1e-4)',
                        dest='weight_decay')
    # LR scheduler step size (only relevant for StepLR; main() uses
    # MultiStepLR, so this option is currently not read there)
    parser.add_argument('--lr-step-size', default=8, type=int,
                        help='Decrease LR every step-size epochs (StepLR)')
    # Milestones for MultiStepLR
    parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int,
                        help='Decrease LR at these milestone epochs (MultiStepLR)')
    # LR gamma for the scheduler
    parser.add_argument('--lr-gamma', default=0.1, type=float,
                        help='LR multiply factor at each milestone (default=0.1)')
    # Print frequency
    parser.add_argument('--print-freq', default=50, type=int, help='Print frequency')
    # Output directory for saving checkpoints
    parser.add_argument('--output-dir', default='./multi_train', help='Directory to save outputs')
    # Resume from a checkpoint
    parser.add_argument('--resume', default='', help='Path to the checkpoint to resume from')
    # Aspect ratio group factor
    parser.add_argument('--aspect-ratio-group-factor', default=3, type=int,
                        help='Aspect ratio group factor; set <0 to disable')
    # If true, only evaluate and exit
    parser.add_argument('--test-only', action='store_true', help='Only run evaluation')

    # Distributed training parameters
    parser.add_argument('--world-size', default=4, type=int,
                        help='Number of distributed processes')
    parser.add_argument('--dist-url', default='env://',
                        help='URL used to set up distributed training')
    # Boolean options take an explicit value (`--pretrain False`) or, given
    # bare (`--sync-bn`), mean True. str2bool makes "False" actually disable
    # the option, which `type=bool` did not.
    parser.add_argument('--sync-bn', dest='sync_bn', type=str2bool, nargs='?',
                        const=True, default=False,
                        help='Use synchronized batch norm')
    parser.add_argument('--pretrain', type=str2bool, nargs='?',
                        const=True, default=True,
                        help='Whether to load COCO pretrained weights')
    parser.add_argument('--amp', default=False, action='store_true',
                        help='Use torch.cuda.amp for mixed precision training')

    args = parser.parse_args()

    # Create output directory if it does not exist
    if args.output_dir:
        mkdir(args.output_dir)

    # Run the main function
    main(args)
