In [None]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
import cv2
import pytorch_lightning as pl
import time

from torchsummary import summary
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.models.detection.rpn import AnchorGenerator
from torch.utils.data.sampler import SequentialSampler
from PIL import Image
from sklearn.model_selection import train_test_split

import transforms as T
import utils

In [None]:
# TORCH_DEVICE = 'mps' # there is currently a bug: https://github.com/pytorch/pytorch/issues/78915
# Device string used for the model, the sampled batches and checkpoint loading.
TORCH_DEVICE = 'cpu'
# Directory and file name of the checkpoint that get_model hands to
# copy_zoobot_weights.copy_Zoobot_weights_to_Resnet (joined as CKPT_PATH + CKPT_NAME).
CKPT_PATH = './pre_trained_models/Zoobot_Resnet_Torchvision/'
CKPT_NAME = 'epoch=20-step=6552.ckpt'

# Root of the metadata pickles; images are read from the 'real_pngs/' sub-folder.
DATA_PATH = '../RPN_Backbone_GZ2/Data/'
IMAGE_PATH = DATA_PATH + 'real_pngs/'

# Output directory for the TensorBoard SummaryWriter.
LOG_DIR = './models/Zoobot_Resnet/train'

# using typical split of 80:10:10
# NOTE(review): the two fractions are summed into a single hold-out split in the
# data-preparation cell; no separate test set is created in the visible code.
SIZE_OF_VALIDATION_SET = 0.1
SIZE_OF_TEST_SET = 0.1

# Batch size shared by the training and validation DataLoaders.
BATCH_SIZE = 32
# Cutout passed to SDSSGalaxyDataset.
# NOTE(review): presumably ordered (x1, y1, x2, y2) - confirm in SDSSGalaxyDataset.
CUTOUT = (100, 100, 300, 300)
# Same cutout ordered (x_min, x_max, y_min, y_max): indices 0/1 are used for x
# and 2/3 for y when the annotation coordinates are filtered and rescaled.
CUTOUT_ARRAY = np.array([100, 300, 100, 300])

In [None]:
# initialise Tensorboard
# The writer logs under LOG_DIR and is later passed to train_one_epoch as
# `tb_writer` in the training loop.
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir=LOG_DIR)

In [None]:
# create dataframe for image annotations
# loading metadata
file_path1 = DATA_PATH + 'combined_cat.pkl'
file_path2 = DATA_PATH + 'zoo2LocalIdMap.pkl'

# NOTE(review): read_pickle unpickles arbitrary objects - only load trusted files.
df_combined_cat = pd.read_pickle(file_path1)

# Look up the local id of every annotation row through the first index level.
zooToLocal = pd.read_pickle(file_path2)
df_combined_cat['local_ids'] = zooToLocal.loc[df_combined_cat.index.get_level_values(0)].to_numpy()

df_combined_cat = df_combined_cat.reset_index()

# Filter out any bulge markings that snuck through:
# a box counts as central when its midpoint lies within 0.02 (normalised units)
# of the image centre on both axes.
df_combined_cat['is_central'] = (
    np.abs(0.5*(df_combined_cat['x2_normed'] + df_combined_cat['x1_normed']) - 0.5) < 0.02
    ) & (
    np.abs(0.5*(df_combined_cat['y2_normed'] + df_combined_cat['y1_normed']) - 0.5) < 0.02
    )

# Keep every non-central box plus all rows flagged as empty.
# .copy() detaches the result from the unfiltered frame so that the column
# assignments below do not trigger pandas' SettingWithCopyWarning.
df_combined_cat = df_combined_cat[(~df_combined_cat['is_central'] | df_combined_cat['empty'])].copy()

# stick to sizes used for Zoobot training
cutout = CUTOUT_ARRAY
# NOTE(review): 400 is presumably the side length of the full image in pixels - confirm.
cutout_normed = CUTOUT_ARRAY/400

# Select boxes that lie inside the cutout with a margin of `pad` on every side
# (cutout_normed is ordered x_min, x_max, y_min, y_max).
pad = 0.05
df_combined_cat['pad_selector'] = (
    df_combined_cat['x1_normed'] > cutout_normed[0] + pad
    ) & (
    df_combined_cat['x2_normed'] < cutout_normed[1] - pad
    ) & (
    df_combined_cat['y1_normed'] > cutout_normed[2] + pad
    ) & (
    df_combined_cat['y2_normed'] < cutout_normed[3] - pad
    )

# Again keep the empty rows; copy for the same reason as above.
df_combined_cat = df_combined_cat[(df_combined_cat['pad_selector'] | df_combined_cat['empty'])].copy()

def convert_x_normed(x_normed, lower=None, upper=None):
    """Rescale a normalised x coordinate from full-image to cutout space.

    The mapping is linear: ``lower`` maps to 0 and ``upper`` maps to 1.

    Parameters
    ----------
    x_normed : scalar, numpy array or pandas Series
        Coordinate(s) to rescale; only subtraction and division are applied,
        so anything supporting element-wise arithmetic works.
    lower, upper : float, optional
        Bounds of the cutout on the x axis. When omitted they are read from
        the notebook-level ``cutout_normed`` (entries 0 and 1), which keeps
        existing single-argument calls working unchanged.

    Returns
    -------
    Same kind of object as ``x_normed`` with the rescaled value(s).
    """
    if lower is None:
        lower = cutout_normed[0]
    if upper is None:
        upper = cutout_normed[1]
    return (x_normed - lower) / (upper - lower)

def convert_y_normed(y_normed, lower=None, upper=None):
    """Rescale a normalised y coordinate from full-image to cutout space.

    The mapping is linear: ``lower`` maps to 0 and ``upper`` maps to 1.

    Parameters
    ----------
    y_normed : scalar, numpy array or pandas Series
        Coordinate(s) to rescale; only subtraction and division are applied,
        so anything supporting element-wise arithmetic works.
    lower, upper : float, optional
        Bounds of the cutout on the y axis. When omitted they are read from
        the notebook-level ``cutout_normed`` (entries 2 and 3), which keeps
        existing single-argument calls working unchanged.

    Returns
    -------
    Same kind of object as ``y_normed`` with the rescaled value(s).
    """
    if lower is None:
        lower = cutout_normed[2]
    if upper is None:
        upper = cutout_normed[3]
    return (y_normed - lower) / (upper - lower)

# Rescale the normalised box corners from full-image to cutout coordinates.
# The converters are plain arithmetic, so they are applied to whole columns
# instead of row by row. `assign` returns a fresh frame, so the later column
# writes do not target a filtered slice.
df_combined_cat = df_combined_cat.assign(
    x1_normed=convert_x_normed(df_combined_cat['x1_normed']),
    x2_normed=convert_x_normed(df_combined_cat['x2_normed']),
    y1_normed=convert_y_normed(df_combined_cat['y1_normed']),
    y2_normed=convert_y_normed(df_combined_cat['y2_normed']),
)

# Pixel coordinates inside the cutout (cutout is ordered x_min, x_max, y_min, y_max).
cutout_width = cutout[1] - cutout[0]
cutout_height = cutout[3] - cutout[2]
df_combined_cat['x1'] = df_combined_cat['x1_normed'] * cutout_width
df_combined_cat['x2'] = df_combined_cat['x2_normed'] * cutout_width
df_combined_cat['y1'] = df_combined_cat['y1_normed'] * cutout_height
df_combined_cat['y2'] = df_combined_cat['y2_normed'] * cutout_height

# Check, if image exists
df_combined_cat['filename'] = IMAGE_PATH + df_combined_cat['local_ids'].apply(str) + '.png'
df_combined_cat['file_exists'] = (df_combined_cat['filename']).apply(os.path.exists)

# labels
# 0 - background
# 1 - Clump
# 2 - Odd Clump
# 3 - Improbable Clump      (currently merged into 2)
# 4 - Odd Improbable Clump  (currently merged into 2)
df_combined_cat['is_odd'] = df_combined_cat['mean_tool'] > 0.5
df_combined_cat['is_improbable'] = df_combined_cat['false_pos_prob'] > 0.7

# The same four mutually exclusive conditions drive both the numeric label
# and its text form, so they are built once.
has_clump = ~df_combined_cat['empty']
is_odd = df_combined_cat['is_odd']
is_improbable = df_combined_cat['is_improbable']
label_conditions = [
    has_clump & (~is_odd) & (~is_improbable),
    has_clump & is_odd & (~is_improbable),
    has_clump & (~is_odd) & is_improbable,
    has_clump & is_odd & is_improbable,
]

# Rows flagged as empty match none of the conditions and receive the default.
df_combined_cat['label'] = np.select(
    label_conditions,
    [
        1,
        2,
        2, #3
        2, #4
    ],
    default = None
)

df_combined_cat['label_text'] = np.select(
    label_conditions,
    [
        b'clumpy',
        b'clumpy, odd',
        b'clumpy, odd', # b'clumpy, improbable',
        b'clumpy, odd', # b'clumpy, odd and improbable',
    ],
    default = ''
)

# get train and validation samples
# Only rows whose image file exists take part; the split is done per image id
# so that all boxes of one image end up in the same subset.
df_combined_cat = df_combined_cat[df_combined_cat['file_exists']]
unique_ids = df_combined_cat['image_id'].unique()
# unique_ids = unique_ids[:500] # for prototyping

# NOTE(review): validation and test fractions are merged into one hold-out set
# that is used as the validation set below; no separate test split is created.
train_ids, val_ids = train_test_split(unique_ids, test_size=SIZE_OF_VALIDATION_SET + SIZE_OF_TEST_SET, random_state=42)

df_combined_cat = df_combined_cat[
    ['image_id', 'local_ids', 'filename', 'label', 'label_text',
    # 'x1_normed', 'x2_normed', 'y1_normed', 'y2_normed']
    'x1', 'x2', 'y1', 'y2']
]

imageGroups_train = df_combined_cat[df_combined_cat['image_id'].isin(train_ids)]
imageGroups_valid = df_combined_cat[df_combined_cat['image_id'].isin(val_ids)]

# The set_index/reset_index round-trip of the original pipeline is kept as is:
# it yields a fresh RangeIndex and the frames downstream code was written against.
key_columns = ['image_id', 'local_ids', 'filename', 'label']
imageGroups_train = imageGroups_train.set_index(key_columns).reset_index()
imageGroups_valid = imageGroups_valid.set_index(key_columns).reset_index()

# NOTE(review): this estimate uses 80 epochs while the training cell sets
# NUM_EPOCHS = 120, and it counts annotation rows, not images.
epochs = 80
n_train_rows = len(imageGroups_train)
n_valid_rows = len(imageGroups_valid)
print('Size of train-set: {}, Size of validation-set: {}'.format(n_train_rows, n_valid_rows))
# The sum is parenthesised so that the whole sample count is divided by the
# batch size (previously only the validation term was scaled).
print('So, for {} epochs we need {} steps.'.format(epochs, (n_train_rows + n_valid_rows) / BATCH_SIZE * epochs))

In [None]:
# Persist both annotation frames, each as a pickle and as a CSV, in the
# current working directory.
for frame, pickle_name, csv_name in (
    (imageGroups_train, 'imageGroups_train.pkl', 'imageGroups_train.csv'),
    (imageGroups_valid, 'imageGroups_valid.pkl', 'imageGroups_valid.csv'),
):
    frame.to_pickle(pickle_name)
    frame.to_csv(csv_name)

In [None]:
def get_transform(train):
    """Compose the image/target transforms for a dataset.

    The pipeline always starts with ``T.PILToTensor()`` followed by
    ``T.ConvertImageDtype(torch.float)``; when ``train`` is truthy a
    ``T.RandomHorizontalFlip(0.5)`` is appended as augmentation.

    ``T`` is the local ``transforms`` module imported at the top of the
    notebook (presumably the detection-aware reference transforms - confirm).

    Parameters
    ----------
    train : bool
        Whether to include the random horizontal flip.

    Returns
    -------
    A ``T.Compose`` wrapping the selected transforms.
    """
    pipeline = [
        T.PILToTensor(),
        T.ConvertImageDtype(torch.float),
    ]
    if train:
        pipeline.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(pipeline)

In [None]:
# Dataset class and defined transformations
import SDSSGalaxyDataset

# Arguments shared by the training and the validation variant.
dataset_kwargs = dict(image_dir=IMAGE_PATH, cutout=CUTOUT, colour=True)
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, collate_fn=utils.collate_fn)

# Training data: augmentation enabled, batches shuffled.
dataset_train = SDSSGalaxyDataset.SDSSGalaxyDataset(
    dataframe=imageGroups_train,
    transforms=get_transform(train=True),
    **dataset_kwargs
)
# Validation data: no augmentation, fixed order.
dataset_validation = SDSSGalaxyDataset.SDSSGalaxyDataset(
    dataframe=imageGroups_valid,
    transforms=get_transform(train=False),
    **dataset_kwargs
)

train_data_loader = torch.utils.data.DataLoader(dataset_train, shuffle=True, **loader_kwargs)
valid_data_loader = torch.utils.data.DataLoader(dataset_validation, shuffle=False, **loader_kwargs)

n_train_samples = len(dataset_train)
n_validation_samples = len(dataset_validation)
print("Count: {} are training and {} validation".format(n_train_samples, n_validation_samples))

In [None]:
# Pull one batch from the training loader and move every tensor to TORCH_DEVICE.
images, targets = next(iter(train_data_loader))
images = [img.to(TORCH_DEVICE) for img in images]
targets = [
    {key: tensor.to(TORCH_DEVICE) for key, tensor in target.items()}
    for target in targets
]

In [None]:
# Show the first two images of the sampled batch with their target boxes.
for i in range(2):
    # channels-first tensor -> channels-last numpy image
    sample = images[i].permute(1, 2, 0).cpu().numpy()
    boxes = targets[i]['boxes'].cpu().numpy().astype(np.int32)

    # the image is in RGB, convert to BGR for cv2 annotations
    sample = cv2.cvtColor(sample, cv2.COLOR_RGB2BGR)
    for x_min, y_min, x_max, y_max in boxes:
        cv2.rectangle(sample, (x_min, y_min), (x_max, y_max), (0, 0, 255), 1)

    plt.figure(figsize=(5, 5))
    # flip the channel axis back to RGB for matplotlib
    plt.imshow(sample[:, :, ::-1])
    # plt.axis('off')

In [None]:
def get_model(num_classes=2, trainable_layers=0):
    """Build a Faster R-CNN (ResNet-50 FPN) detector seeded with Zoobot weights.

    Steps:
      1. create torchvision's ``fasterrcnn_resnet50_fpn`` with ``'COCO_V1'`` weights,
      2. pass it through ``copy_zoobot_weights.copy_Zoobot_weights_to_Resnet``
         together with the checkpoint at ``CKPT_PATH + CKPT_NAME``,
      3. swap the box predictor for a new ``FastRCNNPredictor`` with
         ``num_classes`` outputs.

    Parameters
    ----------
    num_classes : int
        Number of outputs of the new box predictor.
    trainable_layers : int
        Forwarded unchanged to ``copy_Zoobot_weights_to_Resnet``.
        NOTE(review): ``trainable_backbone_layers`` of the torchvision
        constructor is fixed at 3 independently of this argument - confirm
        that the copy helper applies the intended freezing.

    Returns
    -------
    The assembled detection model.
    """
    import copy_zoobot_weights

    # COCO-pretrained detector as the starting point
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights='COCO_V1',
        trainable_backbone_layers=3,
    )

    # load the Zoobot checkpoint into the detector
    detector = copy_zoobot_weights.copy_Zoobot_weights_to_Resnet(
        model=detector,
        ckpt_path=CKPT_PATH + CKPT_NAME,
        device=TORCH_DEVICE,
        trainable_layers=trainable_layers,
    )

    # replace the pre-trained head with a new one of matching input width
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in_features, num_classes)

    return detector

In [None]:
# Training
NUM_EPOCHS = 120

# Build the detector with three output classes and trainable_layers=0,
# then place it on the configured device.
frcnn_model = get_model(num_classes=3, trainable_layers=0).to(TORCH_DEVICE)

# Optimise only the parameters that still require gradients.
params = list(filter(lambda param: param.requires_grad, frcnn_model.parameters()))
# Earlier alternative: torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)
optimizer = torch.optim.Adam(params, lr=0.001, weight_decay=0.0005)

# Learning-rate schedule: multiply the rate by 0.1 every 3 scheduler steps
# (the training loop steps it once per epoch).
# NOTE(review): with 120 epochs this shrinks the rate by 10x forty times,
# e.g. to 1e-13 after 30 epochs - confirm this is intended.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [None]:
from engine import train_one_epoch, evaluate

# NOTE(review): the loop does not save a checkpoint; the inference cell loads
# './models/FasterRCNN_Zoobot_weights.pth', which must be produced elsewhere.
for epoch in range(NUM_EPOCHS):
    # one pass over the training loader (print_freq=10), logging to TensorBoard
    train_one_epoch(
        frcnn_model, optimizer, train_data_loader, TORCH_DEVICE, epoch,
        print_freq=10, scaler=None, tb_writer=writer
    )

    # advance the learning-rate schedule once per epoch
    lr_scheduler.step()

    # evaluate the current weights on the validation loader
    evaluate(frcnn_model, valid_data_loader, device=TORCH_DEVICE)

In [None]:
# Inference
# Rebuild the architecture with the same number of classes used for training,
# then restore the stored weights.
loaded_model = get_model(num_classes = 3)
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
# No cell visible in this notebook writes this file; it has to exist beforehand.
loaded_model.load_state_dict(torch.load(
    './models/FasterRCNN_Zoobot_weights.pth',
    map_location=torch.device(TORCH_DEVICE)
))

# Mirror the training cell: place the model on the same device the batches are
# moved to, so that inference keeps working when TORCH_DEVICE is not 'cpu'.
loaded_model = loaded_model.to(TORCH_DEVICE)

#put the model in evaluation mode
loaded_model.eval()

In [None]:
# Pull the first batch from the validation loader and move every tensor to TORCH_DEVICE.
images, targets = next(iter(valid_data_loader))
images = [img.to(TORCH_DEVICE) for img in images]
targets = [
    {key: tensor.to(TORCH_DEVICE) for key, tensor in target.items()}
    for target in targets
]

In [None]:
# Run the loaded model on the sampled validation batch without tracking
# gradients. The plotting cell reads 'boxes' and 'scores' from each entry of
# `prediction`, one entry per input image.
with torch.no_grad():
    prediction = loaded_model(images)

In [None]:
# Minimum score a predicted box needs in order to be drawn.
SCORE_THRESHOLD = 0.5

# Iterate over what the batch actually contains rather than over BATCH_SIZE:
# a loader's batch holds fewer images when the dataset is smaller than (or not
# a multiple of) the batch size, and indexing up to BATCH_SIZE would then fail.
for image, target, image_prediction in zip(images, targets, prediction):
    boxes = target['boxes'].cpu().numpy().astype(np.int32)
    pred_boxes = image_prediction['boxes'].cpu().numpy().astype(np.int32)
    scores = image_prediction['scores'].cpu().numpy()
    # channels-first tensor -> channels-last numpy image
    sample = image.permute(1, 2, 0).cpu().numpy()
    plt.figure(figsize=(5, 5))
    # the image is in RGB, convert to BGR for cv2 annotations
    sample = cv2.cvtColor(sample, cv2.COLOR_RGB2BGR)

    # target boxes in red (BGR colour order)
    for box in boxes:
        cv2.rectangle(sample,
                      (box[0], box[1]),
                      (box[2], box[3]),
                      (0, 0, 255), 1)

    # sufficiently confident predictions in yellow (BGR colour order)
    for box, score in zip(pred_boxes, scores):
        if score > SCORE_THRESHOLD:
            cv2.rectangle(sample,
                          (box[0], box[1]),
                          (box[2], box[3]),
                          (0, 255, 255), 1)

    # flip the channel axis back to RGB for matplotlib
    plt.imshow(sample[:, :, ::-1])
    plt.axis('off')