<a href="https://colab.research.google.com/github/yeqinghuang516/UCSD-ECE285-Object-Detection-Using-Deep-Learning/blob/master/Mask%20RCNN/MaskRCNN_Train_Sageband.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## This notebook is intended to be run on Google Colab

## Download the GitHub repositories

In [0]:
%%shell
# Stop at the first failing command so the copy steps never run against a missing or wrong checkout.
set -e

# Download the TorchVision repo to reuse some helper files from references/detection.
# Each clone is skipped when its directory already exists, so this cell can be re-run safely.
[ -d vision ] || git clone https://github.com/pytorch/vision.git
cd vision
# Pin the reference helpers to the torchvision v0.3.0 tag.
git checkout v0.3.0

cp references/detection/utils.py ../
cp references/detection/transforms.py ../
cp references/detection/coco_eval.py ../
cp references/detection/engine.py ../
cp references/detection/coco_utils.py ../

# Download files from our own project repo
cd ..
[ -d UCSD-ECE285-Object-Detection-Using-Deep-Learning ] || git clone https://github.com/yeqinghuang516/UCSD-ECE285-Object-Detection-Using-Deep-Learning.git

In [0]:
import os
import sys
# Make the project's "Mask RCNN" folder importable (presumably where dataset.py / model.py live);
# the absolute path assumes the project repo was cloned into /content on Colab.
sys.path.append('/content/UCSD-ECE285-Object-Detection-Using-Deep-Learning/Mask RCNN/')
# /content is presumably the working directory the shell cell copied the torchvision
# reference helpers (engine.py, utils.py, ...) into.
sys.path.append('/content/')
import torch
from PIL import Image  # NOTE(review): not referenced by the code cells in this notebook
import torchvision as tv  # NOTE(review): not referenced by the code cells in this notebook
from engine import train_one_epoch, evaluate
import utils
# Star imports: expected to supply VOCDataset, get_transform and MaskRCNN used below.
# TODO: replace with explicit imports once the exported names are confirmed.
from dataset import *
from model import *

## Define some parameters for training

In [0]:
# total number of training epochs (resuming from a checkpoint continues up to this value)
num_epochs = 200
# run evaluation on the validation loader every this many epochs
evaluation_interval = 2
# the 20 PASCAL VOC object classes; the background class is not listed here
class_names = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed the global torch RNG (used e.g. by DataLoader shuffling) so runs are more repeatable.
# NOTE: some cuDNN kernels are still nondeterministic, so results may differ slightly on GPU.
seed = 42
torch.manual_seed(seed)

## Download and initialize the datasets and data loaders

In [0]:
# Local directory the VOC datasets read from (created up front so it always exists).
root = 'data'
os.makedirs(root, exist_ok=True)

# Training split with train-time transforms; validation split with get_transform(train=False).
dataset = VOCDataset(
    root,
    image_set='train',
    transforms=get_transform(train=True),
)
testset = VOCDataset(
    root,
    image_set='val',
    transforms=get_transform(train=False),
)

1999642624it [02:50, 13235573.42it/s]                                

download =  False


In [0]:
# Cap worker processes at the CPUs actually available (small runtimes such as Colab may have only
# a couple); os.cpu_count() can return None, hence the fallback to 1.
num_workers = min(8, os.cpu_count() or 1)

# Settings shared by both loaders; utils.collate_fn is the torchvision detection reference helper.
loader_kwargs = dict(num_workers=num_workers, collate_fn=utils.collate_fn)

# Shuffled training batches of 8 images.
train_loader = torch.utils.data.DataLoader(
    dataset, batch_size=8, shuffle=True, **loader_kwargs)

# Deterministic-order validation batches of 2 images.
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=2, shuffle=False, **loader_kwargs)

## Initialize the model and load a previous checkpoint (if available)

In [0]:
os.makedirs("checkpoints", exist_ok=True)

# 20 object classes + 1 background class; derived from `class_names` so the two cannot drift apart.
num_classes = len(class_names) + 1

# get the model using our helper function (MaskRCNN comes from the project's model module)
model = MaskRCNN(num_classes)
# move model to the right device
model.to(device)

# construct an optimizer over the trainable parameters only (parameters with requires_grad=False are skipped)
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=1e-3, momentum=0.9, weight_decay=0.0005)

# and a learning rate scheduler which multiplies the learning rate by 0.5 every 10 epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
cur_epoch = 0

# Load a previous checkpoint if this file exists; otherwise training starts from scratch.
# NOTE(review): this path is on Google Drive, which must be mounted beforehand, while the
# training loop below writes checkpoints to the local 'checkpoints/' directory.
pretrained_weight = '/content/gdrive/My Drive/MaskRCNN/maskrcnn_99.pth'

if os.path.isfile(pretrained_weight):
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only runtime (and vice versa).
    checkpoint = torch.load(pretrained_weight, map_location=device)
    model.load_state_dict(checkpoint['net'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    lr_scheduler.load_state_dict(checkpoint['scheduler'])
    # the checkpoint stores the last completed epoch, so resume from the next one
    cur_epoch = checkpoint['epoch'] + 1
    print('load state dict')
else:
    print(f"no checkpoint found at {pretrained_weight!r}; training from scratch")

## Start Training

In [0]:
for epoch in range(cur_epoch, num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, train_loader, device, epoch, print_freq=10)
    # update the learning rate once per epoch, after that epoch's optimizer steps
    lr_scheduler.step()
    # evaluate on the test dataset every `evaluation_interval` epochs, and always after the
    # final epoch so the finished model's metrics are reported
    is_last_epoch = epoch == num_epochs - 1
    if epoch % evaluation_interval == 0 or is_last_epoch:
        evaluate(model, test_loader, device=device)
    # save everything needed to resume; 'epoch' is the last completed epoch
    state_dict = {'net': model.state_dict(), 'optimizer': optimizer.state_dict(), 'scheduler': lr_scheduler.state_dict(), 'epoch': epoch}
    # NOTE(review): one file is written per epoch, each holding weights and optimizer state,
    # so disk usage grows linearly with num_epochs; confirm the runtime has enough space.
    torch.save(state_dict, f"checkpoints/maskrcnn_{epoch}.pth")