In [1]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
import cv2
import pytorch_lightning as pl
import time

from torchsummary import summary
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.models.detection.rpn import AnchorGenerator
from torch.utils.data.sampler import SequentialSampler
from PIL import Image, ImageDraw
from sklearn.model_selection import train_test_split

import transforms as T
import utils

In [2]:
# Compute device. 'mps' is not used because of a PyTorch bug:
# https://github.com/pytorch/pytorch/issues/78915
TORCH_DEVICE = 'cpu'

# Zoobot clump-classifier checkpoint whose weights get_model copies into the detector's ResNet.
CKPT_PATH = './pre_trained_models/Zoobot_Clumps_Resnet/'
CKPT_NAME = 'Zoobot_Clump_Classifier_36.pth'

DATA_PATH = '../RPN_Backbone_GZ2/Data/'
IMAGE_PATH = DATA_PATH + 'real_pngs/'

# Per-epoch checkpoints are written into MODEL_DIR; TensorBoard event files into LOG_DIR.
MODEL_DIR = './models/Pytorch_Resnet_Zoobot_Clumps/'
LOG_DIR = MODEL_DIR + 'train'
MODEL_NAME = 'FRCNN_Resnet_Zoobot_Clumps'

BATCH_SIZE = 4

# Image crop passed to the dataset as a 4-tuple (start_a, start_b, end_a, end_b).
CUTOUT = (100, 100, 300, 300)
# The same crop regrouped per axis as [CUTOUT[0], CUTOUT[2], CUTOUT[1], CUTOUT[3]],
# derived from CUTOUT so the two can never drift apart.
# NOTE(review): the axis order expected by consumers of CUTOUT_ARRAY is not visible
# here; the current square crop is symmetric, so confirm before using a non-square one.
CUTOUT_ARRAY = np.array([CUTOUT[0], CUTOUT[2], CUTOUT[1], CUTOUT[3]])

# Seed the global NumPy and PyTorch RNGs (DataLoader shuffling and new-layer
# initialisation draw from torch's) so training runs are repeatable.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# TensorBoard event writer used by the training loop; event files are written under LOG_DIR.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(LOG_DIR)

In [3]:
# Load the train and validation tables that feed the two datasets below.
# NOTE(review): read_pickle can execute arbitrary code while unpickling; only load trusted files.
_split_pickles = ('./imageGroups_train.pkl', './imageGroups_valid.pkl')
imageGroups_train, imageGroups_valid = [pd.read_pickle(pickle_path) for pickle_path in _split_pickles]

In [4]:
def get_transform(train):
    """Build the transform pipeline applied by the galaxy datasets.

    Args:
        train: when truthy, a random horizontal flip with probability 0.5 is
            appended as training-time augmentation.

    Returns:
        A ``T.Compose`` that converts the PIL image to a tensor, casts it to
        ``torch.float`` and, for training, randomly flips it horizontally.
    """
    pipeline = [
        T.PILToTensor(),
        T.ConvertImageDtype(torch.float),
    ]
    if train:
        # Augment only the training data; validation stays deterministic.
        pipeline.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(pipeline)

In [5]:
# Datasets (project-local SDSSGalaxyDataset) and their DataLoaders.
import SDSSGalaxyDataset


def _make_galaxy_dataset(frame, train):
    """Wrap one image-group table in an SDSSGalaxyDataset with the shared crop/colour settings."""
    return SDSSGalaxyDataset.SDSSGalaxyDataset(
        dataframe=frame,
        image_dir=IMAGE_PATH,
        cutout=CUTOUT,
        colour=True,
        transforms=get_transform(train=train),
    )


dataset_train = _make_galaxy_dataset(imageGroups_train, train=True)
dataset_validation = _make_galaxy_dataset(imageGroups_valid, train=False)

# Settings shared by both loaders; only shuffling differs (training data only).
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, collate_fn=utils.collate_fn)
train_data_loader = torch.utils.data.DataLoader(dataset_train, shuffle=True, **_loader_kwargs)
valid_data_loader = torch.utils.data.DataLoader(dataset_validation, shuffle=False, **_loader_kwargs)

print("Count: {} are training and {} validation".format(len(dataset_train), len(dataset_validation)))

Count: 35542 are training and 8886 validation


In [6]:
def get_model(num_classes=2, trainable_layers=0):
    """Build a Faster R-CNN (ResNet-50 FPN) detector carrying Zoobot clump weights.

    Starts from torchvision's COCO-pretrained ``fasterrcnn_resnet50_fpn``. It then
    copies the Zoobot clump-classifier checkpoint (``CKPT_PATH + CKPT_NAME``) into
    it with the project-local ``copy_zoobot_weights`` helper. Finally it swaps the
    box predictor for a freshly initialised ``FastRCNNPredictor``.

    Args:
        num_classes: number of output classes of the new box predictor
            (torchvision detection models count the background as a class).
        trainable_layers: forwarded to ``copy_Zoobot_clumps_weights_to_Resnet``;
            presumably the number of backbone layers left unfrozen -- confirm there.

    Returns:
        The model returned by the weight-copy helper, with its
        ``roi_heads.box_predictor`` replaced.
    """
    import copy_zoobot_weights

    coco_detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights='COCO_V1',
        # NOTE(review): fixed at 3 regardless of ``trainable_layers``. The printed
        # trainable-parameter list shows no backbone body weights train, so the
        # final freezing appears to be decided by the weight-copy helper.
        trainable_backbone_layers=3,
    )
    detector = copy_zoobot_weights.copy_Zoobot_clumps_weights_to_Resnet(
        model=coco_detector,
        ckpt_path=CKPT_PATH + CKPT_NAME,
        device=TORCH_DEVICE,
        trainable_layers=trainable_layers,
    )

    # The new head must accept the feature width produced by the RoI box head.
    head_width = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_width, num_classes)
    return detector

In [7]:
# Training
NUM_EPOCHS = 120

# Build the detector. trainable_layers=0 keeps all pretrained backbone layers frozen;
# the parameter listing in the next cell shows only FPN, RPN and RoI-head weights train.
frcnn_model = get_model(num_classes=5, trainable_layers=0)

# move model to the right device
frcnn_model = frcnn_model.to(TORCH_DEVICE)

# Hand only the still-trainable parameters to the optimizer.
# (SGD with lr=0.001, momentum=0.9, weight_decay=0.0005 was the previously tried alternative.)
params = [p for p in frcnn_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params, lr=0.001, weight_decay=0.0005)

# Decay the learning rate 10x at 1/3 and 2/3 of the run.
# A fixed step_size=3 (10x every 3 epochs) would drive the LR to ~1e-42 over
# 120 epochs, effectively stopping learning after about a dozen epochs.
LR_STEP_SIZE = max(1, NUM_EPOCHS // 3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=0.1)



In [8]:
# Sanity check: print, one per line, the name of every parameter the optimizer
# will update (requires_grad=True).
trainable_names = [
    param_name
    for param_name, param in frcnn_model.named_parameters()
    if param.requires_grad
]
for trainable_name in trainable_names:
    print(trainable_name)

backbone.fpn.inner_blocks.0.0.weight
backbone.fpn.inner_blocks.0.0.bias
backbone.fpn.inner_blocks.1.0.weight
backbone.fpn.inner_blocks.1.0.bias
backbone.fpn.inner_blocks.2.0.weight
backbone.fpn.inner_blocks.2.0.bias
backbone.fpn.inner_blocks.3.0.weight
backbone.fpn.inner_blocks.3.0.bias
backbone.fpn.layer_blocks.0.0.weight
backbone.fpn.layer_blocks.0.0.bias
backbone.fpn.layer_blocks.1.0.weight
backbone.fpn.layer_blocks.1.0.bias
backbone.fpn.layer_blocks.2.0.weight
backbone.fpn.layer_blocks.2.0.bias
backbone.fpn.layer_blocks.3.0.weight
backbone.fpn.layer_blocks.3.0.bias
rpn.head.conv.0.0.weight
rpn.head.conv.0.0.bias
rpn.head.cls_logits.weight
rpn.head.cls_logits.bias
rpn.head.bbox_pred.weight
rpn.head.bbox_pred.bias
roi_heads.box_head.fc6.weight
roi_heads.box_head.fc6.bias
roi_heads.box_head.fc7.weight
roi_heads.box_head.fc7.bias
roi_heads.box_predictor.cls_score.weight
roi_heads.box_predictor.cls_score.bias
roi_heads.box_predictor.bbox_pred.weight
roi_heads.box_predictor.bbox_pred.bia

In [23]:
from engine import train_one_epoch, evaluate

# TensorBoard tags for the 12 entries of COCOeval.stats, in index order.
COCO_STAT_TAGS = (
    "AP/IoU/0.50-0.95/all/100",
    "AP/IoU/0.50/all/100",
    "AP/IoU/0.75/all/100",
    "AP/IoU/0.50-0.95/small/100",
    "AP/IoU/0.50-0.95/medium/100",
    "AP/IoU/0.50-0.95/large/100",
    "AR/IoU/0.50-0.95/all/1",
    "AR/IoU/0.50-0.95/all/10",
    "AR/IoU/0.50-0.95/all/100",
    "AR/IoU/0.50-0.95/small/100",
    "AR/IoU/0.50-0.95/medium/100",
    "AR/IoU/0.50-0.95/large/100",
)

# torch.save does not create missing parent directories. Create MODEL_DIR up front
# so the first checkpoint write cannot fail after a full epoch of training.
os.makedirs(MODEL_DIR, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(
        frcnn_model,
        optimizer,
        train_data_loader,
        TORCH_DEVICE,
        epoch,
        print_freq=10,
        scaler=None,
        tb_writer=writer,
    )

    # update the learning rate (once per epoch)
    lr_scheduler.step()

    # evaluate on the validation set
    coco_evaluator = evaluate(
        frcnn_model,
        valid_data_loader,
        device=TORCH_DEVICE
    )
    # NOTE(review): tags do not include iou_type, so several IoU types would share
    # tags. The log shows only "IoU metric: bbox" for this model.
    for iou_type, coco_eval in coco_evaluator.coco_eval.items():
        for stat_index, tag in enumerate(COCO_STAT_TAGS):
            writer.add_scalar(tag, coco_eval.stats[stat_index], epoch)
    # Push this epoch's scalars to disk rather than waiting for the periodic flush.
    writer.flush()

    model_save_path = MODEL_DIR + MODEL_NAME + '_' + str(epoch + 1) + '.pth'
    torch.save(frcnn_model.state_dict(), model_save_path)

Epoch: [0]  [ 0/20]  eta: 0:02:35  lr: 0.000054  loss: 72.0398 (72.0398)  loss_classifier: 18.8892 (18.8892)  loss_box_reg: 0.0179 (0.0179)  loss_objectness: 53.0913 (53.0913)  loss_rpn_box_reg: 0.0414 (0.0414)  time: 7.7972  data: 1.4673
Epoch: [0]  [10/20]  eta: 0:01:00  lr: 0.000579  loss: 5.0756 (18.4802)  loss_classifier: 0.2344 (3.4928)  loss_box_reg: 0.0331 (0.0449)  loss_objectness: 4.7253 (14.8405)  loss_rpn_box_reg: 0.1058 (0.1022)  time: 6.0178  data: 0.1348
Epoch: [0]  [19/20]  eta: 0:00:05  lr: 0.001000  loss: 5.0756 (15.2605)  loss_classifier: 0.2127 (2.5701)  loss_box_reg: 0.0309 (0.1189)  loss_objectness: 4.7253 (12.3517)  loss_rpn_box_reg: 0.0876 (0.2198)  time: 5.7130  data: 0.0749
Epoch: [0] Total time: 0:02:14 (6.7135 s / it)
creating index...
index created!




Test:  [0/5]  eta: 0:00:21  model_time: 3.1385 (3.1385)  evaluator_time: 0.0011 (0.0011)  time: 4.3370  data: 1.1971
Test:  [4/5]  eta: 0:00:03  model_time: 2.9361 (2.8143)  evaluator_time: 0.0006 (0.0007)  time: 3.0552  data: 0.2398
Test: Total time: 0:00:35 (7.0575 s / it)
Averaged stats: model_time: 2.9361 (2.8143)  evaluator_time: 0.0006 (0.0007)
Accumulating evaluation results...
DONE (t=0.01s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.000
 Average Recall  



Test:  [0/5]  eta: 0:00:22  model_time: 3.1973 (3.1973)  evaluator_time: 0.0012 (0.0012)  time: 4.4230  data: 1.2243
Test:  [4/5]  eta: 0:00:03  model_time: 3.1226 (2.9662)  evaluator_time: 0.0006 (0.0007)  time: 3.2124  data: 0.2452
Test: Total time: 0:00:36 (7.2143 s / it)
Averaged stats: model_time: 3.1226 (2.9662)  evaluator_time: 0.0006 (0.0007)
Accumulating evaluation results...
DONE (t=0.01s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.000
 Average Recall  