In [6]:
import torchvision
import torch
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Number of categories the new detection head must score:
# the single "cloud" class plus the implicit background class.
num_classes = 2

# Start from a Faster R-CNN (ResNet-50 + FPN) detector with its default pre-trained weights.
# NOTE(review): `model` is rebound to a Mask R-CNN in a later cell of this notebook,
# so this detector is not the one that ends up being trained.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

# The existing classification layer reports the feature width that
# any replacement box head has to accept.
in_features = model.roi_heads.box_predictor.cls_score.in_features

# Swap the pre-trained box head for a fresh one sized for our classes.
model.roi_heads.box_predictor = FastRCNNPredictor(
    in_channels=in_features,
    num_classes=num_classes,
)

In [2]:
def get_model_instance_segmentation(num_classes: int, hidden_layer: int = 256, weights="DEFAULT"):
    """
    Build a Mask R-CNN (ResNet-50 + FPN) whose prediction heads are replaced with untrained ones.

    Both the box predictor and the mask predictor of the loaded model are
    swapped for freshly initialised heads sized for ``num_classes``; the rest
    of the network keeps whatever weights were loaded.

    Args:
        num_classes: Number of classes the new heads predict. This count
            includes the background class (the notebook passes 2 for
            "cloud" + background).
        hidden_layer: Channel width of the hidden layer inside the new mask
            predictor. Defaults to 256, the value previously hard-coded here.
        weights: Forwarded unchanged to
            ``torchvision.models.detection.maskrcnn_resnet50_fpn`` to select
            the pre-trained weights. The default ``"DEFAULT"`` loads the
            COCO pre-trained model.

    Returns:
        The Mask R-CNN model with new ``box_predictor`` and ``mask_predictor``
        heads attached to ``model.roi_heads``.
    """
    # load an instance segmentation model (pre-trained on COCO by default)
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights=weights)

    # the current classification layer tells us how many features the box head receives
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # replace the pre-trained box head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # likewise, read the channel count feeding the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels

    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes
    )

    return model

In [3]:
from torchvision.transforms import v2 as T

def get_transform(train):
    """
    Build the transform pipeline applied to each dataset sample.

    Args:
        train: When true, a random horizontal flip (probability 0.5) is
            placed at the front of the pipeline as data augmentation.

    Returns:
        A ``T.Compose`` that runs, in order: the optional flip,
        ``T.ToDtype(torch.float, scale=True)`` to convert to float with value
        scaling, and ``T.ToPureTensor()``.
    """
    # augmentation is only wanted while training; evaluation sees unmodified samples
    augmentations = [T.RandomHorizontalFlip(0.5)] if train else []
    return T.Compose([
        *augmentations,
        T.ToDtype(torch.float, scale=True),
        T.ToPureTensor(),
    ])

In [9]:
from engine import train_one_epoch, evaluate
from dataset import CloudsDataset
import utils

# train on the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# the model distinguishes two classes only - background and cloud
num_classes = 2

# two views of the same data: one with training augmentation, one without
dataset = CloudsDataset('dataset', get_transform(train=True))
dataset_test = CloudsDataset('dataset', get_transform(train=False))

# ! ! ! ! ! ! ! ! !
#TODO: Change these numbers to reflect the actual split
# NOTE(review): the recorded run logged 33 training batches of 2 samples, so the
# held-out set appears to be larger than the training set - confirm this is intended.
TEST_SIZE = 150   # number of samples held out for evaluation
SPLIT_SEED = 0    # fixed so that the same samples are held out on every run

# fail early with a clear message instead of building an empty training subset
if len(dataset) <= TEST_SIZE:
    raise ValueError(
        f"Dataset has {len(dataset)} samples; need more than {TEST_SIZE} "
        "to hold out a test set and still have training data."
    )

# split the dataset in train and test set; a dedicated seeded generator keeps the
# split reproducible (important when a saved model is evaluated in a later session)
# without touching the global random state
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(len(dataset), generator=split_generator).tolist()

# both subsets index the same underlying samples, so slicing one permutation
# guarantees the train and test sets do not overlap
dataset = torch.utils.data.Subset(dataset, indices[:-TEST_SIZE])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-TEST_SIZE:])

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=4,
    collate_fn=utils.collate_fn
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    collate_fn=utils.collate_fn
)

# get the model using our helper function
model = get_model_instance_segmentation(num_classes)

# move model to the selected device (GPU if available, else CPU)
model.to(device)

# construct an optimizer over the trainable parameters only
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

# and a learning rate scheduler that multiplies the rate by 0.1 every 3 epochs
# NOTE: with only 2 epochs below, this decay step is never reached
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)

# let's train it just for 2 epochs
num_epochs = 2

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)

Epoch: [0]  [ 0/33]  eta: 0:09:10  lr: 0.000161  loss: 5.5140 (5.5140)  loss_classifier: 0.5161 (0.5161)  loss_box_reg: 0.0041 (0.0041)  loss_mask: 4.7781 (4.7781)  loss_objectness: 0.1811 (0.1811)  loss_rpn_box_reg: 0.0346 (0.0346)  time: 16.6818  data: 3.8273
Epoch: [0]  [10/33]  eta: 0:05:59  lr: 0.001722  loss: 1.9034 (2.5347)  loss_classifier: 0.1730 (0.2318)  loss_box_reg: 0.0253 (0.0253)  loss_mask: 1.2423 (1.9968)  loss_objectness: 0.2014 (0.2505)  loss_rpn_box_reg: 0.0140 (0.0304)  time: 15.6266  data: 0.3497
Epoch: [0]  [20/33]  eta: 0:03:28  lr: 0.003283  loss: 0.9116 (1.7191)  loss_classifier: 0.0883 (0.1573)  loss_box_reg: 0.0253 (0.0258)  loss_mask: 0.6374 (1.3444)  loss_objectness: 0.1037 (0.1717)  loss_rpn_box_reg: 0.0085 (0.0200)  time: 15.9700  data: 0.0024
Epoch: [0]  [30/33]  eta: 0:00:47  lr: 0.004844  loss: 0.7660 (1.4043)  loss_classifier: 0.0681 (0.1303)  loss_box_reg: 0.0310 (0.0365)  loss_mask: 0.5631 (1.0735)  loss_objectness: 0.0621 (0.1428)  loss_rpn_box_re

KeyboardInterrupt: 

In [None]:
# Persist the model's parameters (its state_dict) to disk
PATH = "model1.pt"
torch.save(obj=model.state_dict(), f=PATH)

In [None]:
#Load the model for inference
# Rebuild the same architecture, then restore the saved parameters into it.
model = get_model_instance_segmentation(num_classes)

# map_location="cpu" lets a checkpoint written during a GPU run load on a
# CPU-only machine; the model is not moved to another device in this cell.
# weights_only=True limits unpickling to tensors and plain containers instead
# of arbitrary pickled objects.
state_dict = torch.load(PATH, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# switch to evaluation mode; as the cell's last expression this also displays the model
model.eval()