In [2]:
import logging
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
import cv2
import pytorch_lightning as pl
import time

from torchsummary import summary
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.models.detection.rpn import AnchorGenerator
from torch.utils.data.sampler import SequentialSampler
from torch.utils.tensorboard import SummaryWriter
from PIL import Image

import transforms as T
import utils

In [3]:
# TORCH_DEVICE = 'mps' # there is currently a bug: https://github.com/pytorch/pytorch/issues/78915
TORCH_DEVICE = 'cpu'
# TORCH_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

DATA_PATH = './Data/'
IMAGE_PATH = '../RPN_Backbone_GZ2/Data/real_pngs/'
PRE_TRAINED_MODELS_PATH = '../Faster_R-CNN_GZ2/pre_trained_models/'

NUM_CLASSES = 3 # 2 classes (clump, odd clump) + background
NUM_EPOCHS = 40

BATCH_SIZE = 8
# NOTE(review): presumably the (x1, y1, x2, y2) crop box passed to SDSSGalaxyDataset — confirm
CUTOUT = (100, 100, 300, 300)
# NOTE(review): looks like the same bounds in (x1, x2, y1, y2) order (matching the
# dataframe's box columns); keep in sync with CUTOUT by hand
CUTOUT_ARRAY = np.array([100, 300, 100, 300])

# Seed torch and numpy so that random weight initialisation of the new heads,
# DataLoader shuffling and torch-based augmentations are reproducible
# when the notebook is re-run from a fresh kernel.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

## Model definitions

In [4]:
def get_model(model_name, num_classes=3, trainable_layers=0):
    """
    Creates the Faster R-CNN model object with a ResNet50-FPN backbone.

    Args:
      model_name (str): 'Zoobot_pre_trained', 'Zoobot_fine_tuned', 'Resnet_Imagenet'
      num_classes (int): number of classes the detector should output,
        must include a class for the background
      trainable_layers (int): number of blocks of the classification backbone,
        counted from top, that should be made trainable
        e.g. 0 - all blocks fixed, 1 - 'backbone.body.conv1' trainable.
        Only forwarded to the Zoobot weight-copy helpers; the
        'Resnet_Imagenet' variant is built with trainable_backbone_layers=0.

    Returns:
      FasterRCNN model whose box predictor is a new FastRCNNPredictor
      sized for num_classes

    Raises:
      ValueError: if model_name is not one of the names listed above.
      Errors raised while loading a Zoobot checkpoint (e.g. a missing file)
      are propagated rather than printed and ignored, so a run can no longer
      silently fall back to the ImageNet-only backbone while being logged
      and saved under a Zoobot name.
    """
    valid_model_names = ('Zoobot_pre_trained', 'Zoobot_fine_tuned', 'Resnet_Imagenet')
    # Validate before the (expensive) model construction below.
    if model_name not in valid_model_names:
        raise ValueError(
            'Invalid model_name {!r}, expected one of {}'.format(model_name, valid_model_names)
        )

    import copy_zoobot_weights

    # Faster R-CNN whose ResNet50 backbone is initialised with ImageNet weights.
    # No detection (COCO) weights are requested, and trainable_backbone_layers=0
    # freezes the backbone body layers.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights_backbone='IMAGENET1K_V1',
        trainable_backbone_layers=0
    )

    if model_name == 'Zoobot_pre_trained':
        # copy the weights of the Zoobot checkpoint into the ResNet backbone
        zoobot_ckpt_path = PRE_TRAINED_MODELS_PATH + 'Zoobot_Resnet_Torchvision/epoch=20-step=6552.ckpt'
        model = copy_zoobot_weights.copy_Zoobot_weights_to_Resnet(
            model=model,
            ckpt_path=zoobot_ckpt_path,
            device=TORCH_DEVICE,
            trainable_layers=trainable_layers
        )
        print('Zoobot pre-trained loaded.')

    elif model_name == 'Zoobot_fine_tuned':
        # copy the weights of the clump-fine-tuned Zoobot classifier into the backbone
        zoobot_ckpt_path = PRE_TRAINED_MODELS_PATH + 'Zoobot_Clumps_Resnet/Zoobot_Clump_Classifier_36.pth'
        model = copy_zoobot_weights.copy_Zoobot_clumps_weights_to_Resnet(
            model=model,
            ckpt_path=zoobot_ckpt_path,
            device=TORCH_DEVICE,
            trainable_layers=trainable_layers
        )
        print('Zoobot fine-tuned for clumps loaded.')

    else:  # 'Resnet_Imagenet': keep the ImageNet backbone built above
        print('ResNet initialised with Imagenet weights loaded.')

    # get the number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # replace the default box predictor head with a new one for num_classes
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

In [5]:
# Experiment registry: one entry per backbone initialisation that is trained.
# 'model_name' is passed to get_model(), 'description' is informational and
# 'log_dir' is the root directory for the TensorBoard logs and checkpoints.
model_dict = dict(
    resnet=dict(
        model_name='Resnet_Imagenet',
        description='ResNet50 initialised with default weights IMAGENET1K_V1',
        log_dir='./models/FRCNN_Resnet_Imagenet/',
    ),
    zoobot_clumps=dict(
        model_name='Zoobot_fine_tuned',
        description='ResNet50 initialised with weights from a Zoobot classifier fine-tuned for clumps',
        log_dir='./models/FRCNN_Resnet_Zoobot_Clumps/',
    ),
    zoobot=dict(
        model_name='Zoobot_pre_trained',
        description='ResNet50 initialised with weights from Zoobot, all layers kept fix for training',
        log_dir='./models/FRCNN_Resnet_Zoobot/',
    ),
)

## Datasets

In [6]:
def get_transform(train):
    """
    Builds the transform pipeline applied to each (image, target) sample.

    Args:
      train (bool): when True, a random horizontal flip with probability 0.5
        is appended as data augmentation.

    Returns:
      T.Compose wrapping the selected transforms.
    """
    # PIL image -> tensor, then cast to float
    pipeline = [T.PILToTensor(), T.ConvertImageDtype(torch.float)]
    if train:
        pipeline += [T.RandomHorizontalFlip(0.5)]
    return T.Compose(pipeline)


def get_dataloader_dict(train_df, val_df, image_dir, cutout, is_colour,
                        batch_size=None, num_workers=4):
    """
    Creates the training and validation DataLoaders for one data split.

    Args:
      train_df (pd.DataFrame): rows (with boxes) of the training images
      val_df (pd.DataFrame): rows (with boxes) of the validation images
      image_dir (str): directory containing the image files
      cutout: crop passed through to SDSSGalaxyDataset
      is_colour (bool): passed through to SDSSGalaxyDataset as `colour`
      batch_size (int, optional): batch size; defaults to BATCH_SIZE
      num_workers (int): number of DataLoader worker processes

    Returns:
      dict with keys 'train' and 'val' mapping to DataLoaders. Only the
      training loader is shuffled; the validation loader iterates in a fixed
      order so evaluation does not consume the global RNG.
    """
    import SDSSGalaxyDataset

    if batch_size is None:
        batch_size = BATCH_SIZE

    def _make_dataset(dataframe, train):
        # one dataset per split; augmentation only for the training split
        return SDSSGalaxyDataset.SDSSGalaxyDataset(
            dataframe=dataframe,
            image_dir=image_dir,
            cutout=cutout,
            colour=is_colour,
            transforms=get_transform(train=train)
        )

    image_datasets = {
        'train': _make_dataset(train_df, train=True),
        'val': _make_dataset(val_df, train=False),
    }

    return {
        split: torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(split == 'train'),
            num_workers=num_workers,
            # detection targets differ in size per image, so batches are kept as tuples
            collate_fn=utils.collate_fn
        )
        for split, dataset in image_datasets.items()
    }

In [7]:
# Run table: assigns each zoo_id to a 'run' and a split 'group' (used below).
# NOTE(review): pd.read_pickle can execute arbitrary code; only load trusted local files.
df_runs = pd.read_pickle(DATA_PATH + 'image_ids_for_runs.pkl')

# Full Clump Scout set with bounding boxes: rename local_id -> local_ids,
# keep only the columns used downstream and cast the id/label columns to int.
df = (
    pd.read_pickle(DATA_PATH + 'clump_scout_full_set.pkl')
    .rename(columns={'local_id': 'local_ids'})
    .loc[:, ['zoo_id', 'local_ids', 'label', 'label_text', 'x1', 'x2', 'y1', 'y2']]
    .astype({'local_ids': int, 'label': int})
)

In [8]:
groups = ['Training', 'Validation']  # split groups used for training/validation
runs = df_runs['run'].unique()       # distinct runs, in order of first appearance

In [9]:
# Restrict the run table to the Training/Validation groups and attach the
# annotation rows of each image via zoo_id; the inner join drops rows whose
# zoo_id has no match on the other side.
df_data = pd.merge(
    df_runs.loc[df_runs['group'].isin(groups)],
    df,
    on='zoo_id',
    how='inner',
)

In [None]:
# Sanity check: draw the ground-truth boxes on a few training images of one run.
run = 1
n_preview = 4

dataloader_dict = get_dataloader_dict(
    train_df=df_data[(df_data['run']==run) & (df_data['group']=='Training')],
    val_df=df_data[(df_data['run']==run) & (df_data['group']=='Validation')],
    image_dir=IMAGE_PATH,
    cutout=CUTOUT,
    is_colour=True
)

images, targets = next(iter(dataloader_dict['train']))
images = [image.to(TORCH_DEVICE) for image in images]
targets = [{k: v.to(TORCH_DEVICE) for k, v in t.items()} for t in targets]

# A batch may hold fewer than n_preview images, so slice instead of range(n_preview).
for image, target in list(zip(images, targets))[:n_preview]:
    boxes = target['boxes'].cpu().numpy().astype(np.int32)
    # CHW tensor -> HWC array; .copy() gives a contiguous buffer for cv2 and
    # keeps the drawing from writing into the image tensor itself.
    sample = image.permute(1, 2, 0).cpu().numpy().copy()

    for x1, y1, x2, y2 in boxes:
        # The image is RGB float (ConvertImageDtype(torch.float) in get_transform),
        # so pure red is (1.0, 0.0, 0.0) — no BGR round trip or clipping needed.
        cv2.rectangle(sample, (int(x1), int(y1)), (int(x2), int(y2)), (1.0, 0.0, 0.0), 1)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(sample)
    ax.set_title('run {}: {} box(es)'.format(run, len(boxes)))

In [None]:
# Training: for every run, train and evaluate each backbone initialisation in model_dict.
from engine import train_one_epoch, evaluate

# TensorBoard tags for the 12 COCO summary statistics, in the order of coco_eval.stats.
COCO_STAT_TAGS = [
    "AP/IoU/0.50-0.95/all/100",
    "AP/IoU/0.50/all/100",
    "AP/IoU/0.75/all/100",
    "AP/IoU/0.50-0.95/small/100",
    "AP/IoU/0.50-0.95/medium/100",
    "AP/IoU/0.50-0.95/large/100",
    "AR/IoU/0.50-0.95/all/1",
    "AR/IoU/0.50-0.95/all/10",
    "AR/IoU/0.50-0.95/all/100",
    "AR/IoU/0.50-0.95/small/100",
    "AR/IoU/0.50-0.95/medium/100",
    "AR/IoU/0.50-0.95/large/100",
]
SAVE_EVERY_N_EPOCHS = 20  # checkpoint frequency

for run in runs:
    print('Executing run: {}'.format(run))

    # load data
    dataloader_dict = get_dataloader_dict(
        train_df=df_data[(df_data['run']==run) & (df_data['group']=='Training')],
        val_df=df_data[(df_data['run']==run) & (df_data['group']=='Validation')],
        image_dir=IMAGE_PATH,
        cutout=CUTOUT,
        is_colour=True
    )

    for model_key, model_data in model_dict.items():
        run_dir = model_data['log_dir'] + 'run={}/'.format(run)

        # initialise Tensorboard writer; closed in `finally` so event files are
        # flushed and file handles are not leaked across models and runs
        writer = SummaryWriter(log_dir=run_dir + 'train')
        try:
            # get the model and move it to the right device
            frcnn_model = get_model(
                model_name=model_data['model_name'],
                num_classes=NUM_CLASSES,
                trainable_layers=0
            ).to(TORCH_DEVICE)

            # construct an optimizer over the trainable parameters only
            params = [p for p in frcnn_model.parameters() if p.requires_grad]
            optimizer = torch.optim.Adam(params, lr=0.0001, weight_decay=0.00005)

            # learning rate scheduler which decreases the learning rate by 10x every 3 epochs
            # NOTE(review): with NUM_EPOCHS=40 the lr during epoch e is
            # 1e-4 * 0.1**(e // 3), i.e. <= 1e-8 from epoch 12 on — confirm intended.
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

            for epoch in range(NUM_EPOCHS):
                # train for one epoch, printing every 10 iterations
                train_one_epoch(
                    frcnn_model,
                    optimizer,
                    dataloader_dict['train'],
                    TORCH_DEVICE,
                    epoch,
                    print_freq=10,
                    scaler=None,
                    tb_writer=writer
                )

                # update the learning rate
                lr_scheduler.step()

                # evaluate on the validation dataset and log the COCO metrics
                coco_evaluator = evaluate(
                    frcnn_model,
                    dataloader_dict['val'],
                    device=TORCH_DEVICE
                )
                for coco_eval in coco_evaluator.coco_eval.values():
                    for tag, value in zip(COCO_STAT_TAGS, coco_eval.stats):
                        writer.add_scalar(tag, value, epoch)

                if (epoch + 1) % SAVE_EVERY_N_EPOCHS == 0:
                    os.makedirs(run_dir, exist_ok=True)
                    model_save_path = run_dir + model_data['model_name'] + '_{}.pth'.format(epoch + 1)
                    torch.save(frcnn_model.state_dict(), model_save_path)
        finally:
            writer.close()