## Demo with DECam data

This walkthrough uses the Burke et al. 2019 dataset, but supersedes the old Matterport Mask R-CNN implementation.

#### A few notes:

The data can be obtained by following the links in the [old repository](https://github.com/burke86/astro_rcnn). The dataset directories should be renamed "test", "train", and "val".

In [None]:
# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger

setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random

# from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data import build_detection_train_loader
from detectron2.engine import DefaultTrainer
from detectron2.engine import SimpleTrainer
from detectron2.engine import HookBase
from detectron2.structures import BoxMode
import detectron2.solver as solver
import detectron2.modeling as modeler
import detectron2.data as data
import detectron2.data.transforms as T
import detectron2.checkpoint as checkpointer
from detectron2.data import detection_utils as utils

import weakref
import copy
import torch
import time
from typing import Dict, List, Optional

import imgaug.augmenters as iaa


from astropy.io import fits
import glob

import deepdisc.astrodet.astrodet as toolkit
from deepdisc.astrodet import detectron as detectron_addons

from deepdisc.data_format.file_io import DDLoader
from deepdisc.data_format.annotation_functions.annotate_decam import annotate_decam

In [None]:
# Report the torch and detectron2 versions so we know which combination works
for version in (torch.__version__, detectron2.__version__):
    print(version)

In [None]:
# Prettify the plotting
from deepdisc.astrodet.astrodet import set_mpl_style

# Apply deepdisc's matplotlib style (set_mpl_style from deepdisc.astrodet.astrodet) to the plots below
set_mpl_style()

In [None]:
# Paths to the DECam dataset and the model output directory.
# Edit the defaults below, or set the DECAM_DATA_DIR / DECAM_OUTPUT_DIR
# environment variables to override them without editing the notebook.
dirpath = os.environ.get("DECAM_DATA_DIR", "/home/shared/hsc/decam/decam_data/")  # Path to dataset
# NOTE(review): output_dir does not appear to be used below (cfg.OUTPUT_DIR is set to './').
output_dir = os.environ.get("DECAM_OUTPUT_DIR", "/home/shared/hsc/decam/models/")

# Sub-directory names of the three dataset splits under dirpath
dataset_names = ["train", "test", "val"]

### Register Astro R-CNN dataset

For detectron2 to read the data, it must be in a dictionary format. The flexible `DDLoader` class can be used to load data from a generic directory 
structure into a user-defined output structure of metadata. Below, we iterate over each dataset and initialize a `DDLoader` instance. The `DDLoader.generate_filedict` function is used to read the directory structure and return a dictionary of file names. We can then use the `DDLoader.generate_dataset_dict` function, which
passes a user-defined annotation function along to the files, using the 
generated dictionary of filenames.

In this case, we have a pre-made annotation function for DECam data,
`annotate_decam`, which is passed along.

However, this step can take a few minutes, so we recommend running it only once and saving the dictionary data as a JSON file that can be 
read in at the beginning of your code.


In [None]:
# One DDLoader per split, keyed by split name, so each registered generator
# (and the eager loading below) reads its own files.
decam_loaders = {}
for d in dataset_names:
    filenames_dir = os.path.join(dirpath, d)

    # Generate the dictionary of filenames for this split
    decam_loaders[d] = DDLoader().generate_filedict(
        filenames_dir,
        ["g", "r", "z"],
        "img*.fits",
        "masks.fits",
        subdirs=True,
        filt_loc=-6,
        n_samples=20,
    )

    # Register the dataset generator function. The loader is bound as a
    # default argument: a plain closure would look the name up when the
    # generator is *called*, so every split would read the last loader ("val").
    DatasetCatalog.register(
        "astro_" + d,
        lambda loader=decam_loaders[d]: loader.generate_dataset_dict(annotate_decam, filters=False).get_dataset(),
    )
    # detectron2's Visualizer reads "thing_colors" as RGB tuples in the 0-255 range.
    MetadataCatalog.get("astro_" + d).set(
        thing_classes=["star", "galaxy"],
        thing_colors=[(0, 0, 255), (128, 128, 128)],  # star: blue, galaxy: gray
    )
astro_metadata = MetadataCatalog.get("astro_train")

# Eagerly build the dataset dictionaries, each split from its own loader.
dataset_dicts = {}
for d in dataset_names:
    print(f"Loading {d}")
    dataset_dicts[d] = decam_loaders[d].generate_dataset_dict(annotate_decam, filters=False).get_dataset()

In [None]:
# Helper for unregistering the datasets if you want to change something
# (e.g. re-register with other files or metadata). It is not called by default.


def unregister_astro_datasets(names=("astro_train", "astro_test", "astro_val")):
    """Remove the given datasets and their metadata from detectron2's catalogs.

    Args:
        names: registered dataset names to remove. Names that are not
            registered are skipped.
    """
    for name in names:
        if name in DatasetCatalog.list():
            print(f"removing {name}")
            DatasetCatalog.remove(name)
        # Clear the metadata too: MetadataCatalog.get(name).set(...) refuses to
        # overwrite an existing attribute with a different value.
        if name in MetadataCatalog.list():
            MetadataCatalog.remove(name)


# unregister_astro_datasets()

Run the following hidden cells if your data is already saved in dictionary (JSON) format. You will need to change the file paths. If you have already registered the data, first run the cell above to unregister it.

In [None]:
# Initialize a DDLoader class, which will just be used to load existing files
json_loader = DDLoader()

trainfile = os.path.join(dirpath, "train.json")
testfile = os.path.join(dirpath, "test.json")
valfile = os.path.join(dirpath, "val.json")

# Register each split with a generator that reads its JSON file when requested.
# The path is bound as a default argument so each generator keeps its own file
# even if the module-level names are reassigned later.
for dataset_name, json_path in (
    ("astro_train", trainfile),
    ("astro_test", testfile),
    ("astro_val", valfile),
):
    DatasetCatalog.register(
        dataset_name,
        lambda path=json_path: json_loader.load_coco_json_file(path).get_dataset(),
    )
    # Same class names/colors as the DDLoader-based registration; detectron2's
    # Visualizer reads "thing_colors" as 0-255 RGB tuples.
    MetadataCatalog.get(dataset_name).set(
        thing_classes=["star", "galaxy"],
        thing_colors=[(0, 0, 255), (128, 128, 128)],  # star: blue, galaxy: gray
    )

astrotrain_metadata = MetadataCatalog.get("astro_train")
astrotest_metadata = MetadataCatalog.get("astro_test")
# astroval_metadata = MetadataCatalog.get("astro_val")

# Later cells use `astro_metadata`; define it here too so this cell can replace
# the DDLoader-based registration.
astro_metadata = astrotrain_metadata

In [None]:
# Read the saved "<split>.json" file for every split into memory.
json_loader = DDLoader()
dataset_dicts = {}

for split in dataset_names:
    print(f"Loading {split}")
    json_path = os.path.join(dirpath, split) + ".json"
    dataset_dicts[split] = json_loader.load_coco_json_file(json_path).get_dataset()

### Visualize ground truth examples

In [None]:
nsample = 3  # Number of example images to plot
fig, axs = plt.subplots(1, nsample, figsize=(5 * nsample, 5))

for i, d in enumerate(random.sample(dataset_dicts["test"], nsample)):
    # Use the Lupton scaling for better visualization
    img = toolkit.read_image_decam(d["file_name"], normalize="astrolupton", stretch=100, Q=10)

    # Draw the ground-truth annotations stored in this record
    visualizer = Visualizer(img, metadata=astro_metadata)
    out = visualizer.draw_dataset_dict(d)
    axs[i].imshow(out.get_image())
    axs[i].axis("off")

# Lay out and show the figure once, after every panel has been drawn
fig.tight_layout()
fig.show()

### Data Augmentations

Below, we create the function `train_mapper`, which takes one of the metadata dictionaries, reads in the corresponding image, and applies custom
augmentations. It outputs a new dictionary that will be fed into the model. You can see an example of the augmentations working below.

In [None]:
import imgaug.augmenters.flip as flip
import imgaug.augmenters.blur as blur


def hflip(image):
    """Return *image* mirrored left-to-right using imgaug's ``flip.fliplr``."""
    mirrored = flip.fliplr(image)
    return mirrored


def gaussblur(image):
    """Blur *image* with a Gaussian kernel of random strength.

    On each call the upper bound of the sigma range is drawn uniformly from
    [2, 6); imgaug's GaussianBlur then samples sigma from (0, upper bound).
    """
    sigma_upper = 2 + 4 * np.random.random_sample()
    blurrer = iaa.GaussianBlur(sigma=(0.0, sigma_upper))
    return blurrer.augment_image(image)


def addelementwise(image):
    """Add per-pixel random offsets in [-40, 40] to *image* (imgaug AddElementwise)."""
    return iaa.AddElementwise((-40, 40)).augment_image(image)

The `KRandomAugmentationList` class takes a list of augmentations and randomly applies k of them.

In [None]:
def train_mapper(dataset_dict):
    """Read one dataset record's image, augment it, and build the model input.

    Args:
        dataset_dict: one record of the registered dataset. Its "file_name",
            "image_id" and "annotations" keys are read. The record is
            deep-copied first, so the caller's dict is not modified.

    Returns:
        dict with "image" (channels-first tensor), "image_shaped" (the
        augmented channels-last array, handy for plotting), "height"/"width"
        (size of the augmented image), "image_id", and "instances" (the
        ground truth transformed consistently with the image).
    """
    dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below

    image = toolkit.read_image_decam(dataset_dict["file_name"], normalize="astrolupton", stretch=100, Q=15)

    # NOTE(review): the meaning of k=-1 is defined by KRandomAugmentationList; confirm there.
    augs = detectron_addons.KRandomAugmentationList(
        [
            # my custom augs
            T.RandomRotation([-90, 90, 180], sample_style="choice"),
            T.RandomFlip(prob=0.5),
            T.RandomFlip(prob=0.5, horizontal=False, vertical=True),
            detectron_addons.CustomAug(gaussblur, prob=1.0),
            # detectron_addons.CustomAug(addelementwise, prob=1.0),
        ],
        k=-1,
    )

    # Data Augmentation: applies the augmentations in place on auginput and
    # returns the transform so the annotations can be moved the same way.
    auginput = T.AugInput(image)
    transform = augs(auginput)
    # copy() yields a contiguous array (flips can leave negative strides, which
    # torch.from_numpy does not accept); then move channels first.
    image = torch.from_numpy(auginput.image.copy().transpose(2, 0, 1))
    image_hw = image.shape[1:]
    annos = [
        utils.transform_instance_annotations(annotation, [transform], image_hw)
        for annotation in dataset_dict.pop("annotations")
    ]
    return {
        # create the format that the model expects
        "image": image,
        "image_shaped": auginput.image,
        # Use the real image size instead of assuming 512x512 cutouts
        "height": image_hw[0],
        "width": image_hw[1],
        "image_id": dataset_dict["image_id"],
        "instances": utils.annotations_to_instances(annos, image_hw),
    }

Plot the original and augmented images.

In [None]:
# Side-by-side comparison: ground truth on the original image (left) and on
# the same image after train_mapper's augmentations (right).
fig, axs = plt.subplots(1, 2, figsize=(10 * 2, 10))
left_ax, right_ax = axs[0], axs[1]

d = random.sample(dataset_dicts["train"], 1)[0]

# Left panel: original image with its ground-truth boxes
img = toolkit.read_image_decam(d["file_name"], normalize="astrolupton", stretch=100, Q=15)
visualizer = Visualizer(img, metadata=astro_metadata, scale=1)
# Annotation boxes are stored as XYWH; overlay_instances wants XYXY
xywh_boxes = np.array([anno["bbox"] for anno in d["annotations"]])
gt_boxes = BoxMode.convert(xywh_boxes, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
out = visualizer.overlay_instances(boxes=gt_boxes)
left_ax.imshow(out.get_image())
left_ax.axis("off")

# Right panel: augmented image with the boxes transformed by the mapper
aug_d = train_mapper(d)
img_aug = aug_d["image_shaped"]
visualizer = Visualizer(img_aug, metadata=astro_metadata, scale=1)
out = visualizer.overlay_instances(boxes=aug_d["instances"].gt_boxes)
right_ax.imshow(out.get_image())
right_ax.axis("off")

fig.tight_layout()
fig.show()

### Training

We prepare for training by initializing a config object. Then we can take the initial weights from the pre-trained models in the model zoo.
This setup is for demo purposes, so it does not follow a full training schedule.

### Prepare For Training

We prepare for training by initializing a config object and setting hyperparameters. Then we can take the initial weights from the pre-trained models in the model zoo. For a full list of available config options, check https://detectron2.readthedocs.io/en/latest/modules/config.html

This setup is for demo purposes, so it does not follow the full training schedule we use for the paper. You can check the `train_decam.py` script for the final training configurations.

The model used here is not as good at transfer learning to astronomical images, so the results may not look very good given the relatively small number of iterations used here.

In [None]:
# These names are imported again in a later cell, but this cell needs them
# first; importing them here lets the notebook run top to bottom.
import deepdisc.model.loaders as loaders
from deepdisc.model.models import return_lazy_model
from deepdisc.training.trainers import return_optimizer, return_savehook, return_schedulerhook
from detectron2.config import LazyConfig

cfgfile = "../configs/solo/demo_r50_hsc.py"  # The config file which contains information about the model
cfg = LazyConfig.load(cfgfile)  # Load in the config
model = return_lazy_model(cfg, freeze=False)  # Build the model from the config specifications
cfg.optimizer.params.model = model  # Set up the training optimizer
optimizer = return_optimizer(cfg)

# Set up the loader, which formats the data (via train_mapper) to be fed into the model
loader = loaders.return_train_loader(cfg, train_mapper)

schedulerHook = return_schedulerhook(optimizer)  # Create a "hook" which will set up the scheduler to control learning rates
saveHook = return_savehook("model_temp")  # Create a "hook" which will save the model
hookList = [saveHook, schedulerHook]

cfg.train.init_checkpoint = "detectron2://ImageNetPretrained/MSRA/R-50.pkl"  # Initialize the model weights from a pre-trained model

cfg.OUTPUT_DIR = "./"  # Set the output directory





Now we can train the model! We set up a trainer and tell it how often to report the loss and when to stop.

In [None]:
# return_lazy_trainer is also imported in a later cell; importing it here lets
# the notebook run top to bottom.
from deepdisc.training.trainers import return_lazy_trainer

# Build the trainer from the model, data loader, optimizer, config and hooks
# prepared in the previous cell.
trainer = return_lazy_trainer(model, loader, optimizer, cfg, hookList)
trainer.set_period(50)  # print loss every 50 iterations
# NOTE(review): presumably (start_iter, max_iter), i.e. 2000 iterations — confirm in deepdisc's trainer
trainer.train(0, 2000)

In [None]:
# Hack if you get an SSL certificate error (e.g. when downloading pretrained weights).
# WARNING: this disables HTTPS certificate verification for the whole Python
# process; only use it on a network you trust.
import ssl

ssl._create_default_https_context = ssl._create_unverified_context

import warnings

try:
    # ignore ShapelyDeprecationWarning from fvcore
    # This comes from the cropping
    from shapely.errors import ShapelyDeprecationWarning
except ImportError:
    # shapely is not installed or does not define this warning: nothing to silence.
    pass
else:
    warnings.filterwarnings("ignore", category=ShapelyDeprecationWarning)

In [None]:
from deepdisc.training.trainers import (
    return_evallosshook,
    return_lazy_trainer,
    return_optimizer,
    return_savehook,
    return_schedulerhook,
)
from detectron2.config import LazyConfig


from deepdisc.model.models import return_lazy_model

import deepdisc.model.loaders as loaders
from deepdisc.data_format.image_readers import HSCImageReader


### Plot The Loss

In [None]:
# Plot the total loss recorded by the trainer during training
fig, ax = plt.subplots(1, 1, figsize=(6, 5))
ax.plot(trainer.lossList, label=r"$L_{\rm{tot}}$")

ax.set_xlabel("training epoch", fontsize=20)
ax.set_ylabel("loss", fontsize=20)
ax.set_ylim(0, 10)
ax.legend(loc="upper right")
fig.tight_layout()

### Inference

Inference should use the same config parameters that were used in training, so we reload the config file and point `init_checkpoint` at the weights saved during training.
We then adjust a few settings for inference:

In [None]:
# Reload the training config and point it at the weights written by the save hook.
cfgfile = "../configs/solo/demo_r50_hsc.py"
cfg = LazyConfig.load(cfgfile)
cfg.OUTPUT_DIR = "./"
cfg.train.init_checkpoint = os.path.join(cfg.OUTPUT_DIR, "model_temp.pth")

predictor = toolkit.AstroPredictor(cfg)

# Change these to play with the detection sensitivity. They must be set on the
# predictor's own model (predictor.model): AstroPredictor is built from cfg, so
# changing the training `model` object does not affect the predictor.
predictor.model.roi_heads.box_predictor.test_score_thresh = 0.3
# predictor.model.roi_heads.box_predictor.test_nms_thresh = 0.5


In [None]:
from detectron2.utils.visualizer import ColorMode

nsample = 3
fig, axs = plt.subplots(1, nsample, figsize=(5 * nsample, 5))

for i, d in enumerate(random.sample(dataset_dicts["test"], nsample)):
    # Same contrast scaling (Q=15) as train_mapper, so inputs match training
    img = toolkit.read_image_decam(d["file_name"], normalize="astrolupton", stretch=100, Q=15)
    # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
    outputs = predictor(img)

    print("total instances:", len(d["annotations"]))
    print("detected instances:", len(outputs["instances"].pred_boxes))
    print("")

    v = Visualizer(
        img,
        metadata=astro_metadata,
        scale=1,
        # SEGMENTATION mode gives instances of the same category similar colors
        # (taken from the metadata's thing_colors when set)
        instance_mode=ColorMode.SEGMENTATION,
    )
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    axs[i].imshow(out.get_image())
    axs[i].axis("off")

# Lay out and show the figure once, after every panel has been drawn
fig.tight_layout()
fig.show()

### Evaluate

In [None]:
def test_mapper(dataset_dict):
    """Convert one dataset record into the model input format for evaluation.

    Like train_mapper, but without random augmentations. It also returns the
    (identity-transformed) annotations.

    Args:
        dataset_dict: one record of the registered dataset. Its "file_name",
            "image_id" and "annotations" keys are read. The record is
            deep-copied first, so the caller's dict is not modified.

    Returns:
        dict with "image" (channels-first tensor), "image_shaped"
        (channels-last array), "height"/"width" (actual image size),
        "image_id", "instances" and "annotations".
    """
    dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below

    # Use the same contrast scaling as train_mapper (Q=15). Evaluating on
    # images stretched differently from the training images skews the metrics.
    image = toolkit.read_image_decam(dataset_dict["file_name"], normalize="astrolupton", stretch=100, Q=15)
    # An empty augmentation list yields an identity transform, which is still
    # needed to pass the annotations through transform_instance_annotations.
    augs = T.AugmentationList([])
    auginput = T.AugInput(image)
    transform = augs(auginput)
    # Contiguous copy, then move channels first for the model
    image = torch.from_numpy(auginput.image.copy().transpose(2, 0, 1))
    image_hw = image.shape[1:]
    annos = [
        utils.transform_instance_annotations(annotation, [transform], image_hw)
        for annotation in dataset_dict.pop("annotations")
    ]
    return {
        # create the format that the model expects
        "image": image,
        "image_shaped": auginput.image,
        # Use the real image size instead of assuming 512x512 cutouts
        "height": image_hw[0],
        "width": image_hw[1],
        "image_id": dataset_dict["image_id"],
        "instances": utils.annotations_to_instances(annos, image_hw),
        "annotations": annos,
    }

In [None]:
from detectron2.evaluation import inference_on_dataset
from detectron2.data import build_detection_test_loader

# COCO-style evaluator (with recall) for the registered "astro_val" split.
# NOTE: newer detectron2 releases also accept max_dets_per_image in the stock COCOEvaluator.
evaluator = toolkit.COCOEvaluatorRecall(
    "astro_val",
    use_fast_impl=True,
    output_dir=cfg.OUTPUT_DIR,
)

# Feed the validation records through test_mapper (no augmentation)
test_loader = build_detection_test_loader(
    dataset_dicts["val"],
    mapper=test_mapper,
)

In [None]:
# Run the predictor's model over the validation loader and score it with the evaluator
results = inference_on_dataset(
    model=predictor.model,
    data_loader=test_loader,
    evaluator=evaluator,
)

In [None]:
# Box AP for the "star" class, read from the evaluator's results dict
results["bbox"]["AP-star"]

In [None]:
ap_type = "bbox"  # Which type of precision/recall to use? 'segm', or 'bbox'
cls_names = ["star", "galaxy"]

results_per_category = results[ap_type]["results_per_category"]

fig, axs = plt.subplots(1, 2, figsize=(15, 5))
axs = axs.flatten()

ious = np.linspace(0.50, 0.95, 10)
colors = plt.cm.viridis(np.linspace(0, 1, len(ious)))  # one viridis color per IoU threshold

# Plot one precision-recall curve per IoU threshold, one panel per class
for j, precision_class in enumerate(results_per_category):
    ax = axs[j]
    n_ious = np.shape(precision_class)[0]
    for i in range(n_ious):
        # precision has dims (iou, recall, cls, area range, max dets)
        # area range index 0: all area ranges
        # max dets index -1: typically 100 per image
        p_dat = precision_class[i, :, j, 0, -1]
        # Hide vanishing precisions
        mask = p_dat > 0
        # Only keep the first occurrence of a non-positive value in the array
        mask[np.cumsum(~mask) == 1] = True
        p = p_dat[mask]
        # Recall points for the kept values, spread evenly over [0, 1]
        # NOTE(review): assumes an evenly spaced recall grid — confirm against the evaluator
        r = np.linspace(0, 1, len(p))
        # Recall step of the full grid (one p_dat entry per recall threshold)
        dr = np.diff(np.linspace(0, 1, len(p_dat)))[0]
        iou = np.around(ious[i], 2)
        AP = 100 * np.sum(p * dr)
        ax.plot(r, p, label=r"${\rm{AP}}_{%.2f} = %.1f$" % (iou, AP), color=colors[i], lw=2)

    # Labels, limits and the legend only need to be set once per panel,
    # after all IoU curves for this class have been added.
    ax.set_xlabel("Recall", fontsize=20)
    ax.set_ylabel("Precision", fontsize=20)
    ax.set_xlim(0, 1.1)
    ax.set_ylim(0, 1.1)
    ax.legend(fontsize=10, title=f"{cls_names[j]}", bbox_to_anchor=(1.35, 1.0))

fig.tight_layout()

This demo is just to show how to set up the training.  We encourage you to add object classes, try different contrast scalings, and train for longer!  


    
  
  You can also inspect the contents of the predictor output below.
  
  

In [None]:
# Run the predictor on `img` — whatever image that name last held (the last
# image of the inference loop above, when the cells are run in order)
outputs = predictor(img)

In [None]:
# List the fields stored on the predicted Instances (e.g. pred_boxes, scores)
outputs['instances'].get_fields().keys()

In [None]:
# Confidence score of each detected instance
print(outputs['instances'].scores)