In [None]:
#@title Imports
!pip install -q patchify

from patchify import patchify
import numpy as np
from astropy.io import fits
from astropy.visualization import ZScaleInterval
import matplotlib.pyplot as plt
from tqdm import tqdm
import shutil
import torch
from scipy import ndimage
import random
import json
import os
import random
from PIL import Image
import matplotlib.image
from skimage.measure import label, regionprops
from sklearn.model_selection import train_test_split
import cv2

# Z-scale interval used by generate_dataset to normalise each image patch.
# Its output is multiplied by 255 and cast to uint8 there, so it is presumably
# expected to lie in [0, 1].
norm = ZScaleInterval()

# Placeholder paths -- edit before running.
# IMG_PATH / MASK_PATH are read with fits.getdata: the image and its
# segmentation mask (NOTE(review): mask background presumably 0 -- confirm).
IMG_PATH = 'path/to/images'
MASK_PATH = 'path/to/masks'
# OUT_PATH receives images/ and density_maps/.
# WARNING: generate_dataset deletes this folder with shutil.rmtree if it already exists.
OUT_PATH = 'path/to/output/folder'


In [None]:
#@title Dataset Generating Function
def _is_cwd_or_ancestor(path):
    """Return True if deleting ``path`` would also delete the current working directory."""
    target = os.path.realpath(path)
    cwd = os.path.realpath(os.getcwd())
    try:
        return os.path.commonpath([target, cwd]) == target
    except ValueError:  # e.g. paths on different drives (Windows): cannot overlap
        return False


def _save_rgb_patch(patch, path):
    """Z-scale a single-channel tile with ``norm``, replicate it to 3 channels and save it as an 8-bit PNG."""
    gray = np.asarray(norm(patch))
    # Guard the uint8 cast: NaN (common in FITS data) or out-of-range values
    # would otherwise wrap around / produce undefined pixel values.
    gray = np.clip(np.nan_to_num(gray, nan=0.0), 0.0, 1.0)
    # Same result as cv2.COLOR_GRAY2RGB: identical values in the three channels.
    rgb = np.repeat(gray[..., np.newaxis], 3, axis=-1)
    Image.fromarray((rgb * 255).astype(np.uint8)).save(path)


def _object_boxes(mask_patch):
    """Return ``(coords, points)`` for every connected foreground region of a mask tile.

    Every non-zero mask value counts as foreground (0 is background), and
    touching labelled objects are merged by connected-component labelling.
    ``coords`` holds the 4 bbox corners ``[[x0, y0], [x0, y1], [x1, y1], [x1, y0]]``
    per region; ``points`` holds the bbox centre ``[x, y]`` per region.
    """
    values = np.unique(mask_patch)
    # Drop the background value explicitly rather than "the first unique value",
    # which removed a real object label whenever a tile had no background pixels.
    foreground = values[values != 0]
    binary = np.isin(mask_patch, foreground)  # vectorised; was a per-pixel Python loop
    regions = regionprops(label(binary.astype(int)))

    coords, points = [], []
    for region in regions:
        # regionprops bbox for 2-D input: (min_row, min_col, max_row, max_col)
        min_y, min_x, max_y, max_x = region.bbox
        coords.append([[min_x, min_y], [min_x, max_y], [max_x, max_y], [max_x, min_y]])
        points.append([(min_x + max_x) / 2, (min_y + max_y) / 2])
    return coords, points


def _density_map(points, shape, sigma=1, scale=60):
    """Build a float32 density map of ``shape`` with one Gaussian blob per ``[x, y]`` point.

    Each point adds 1 at its (clamped) pixel. Accumulating rather than
    assigning keeps two objects that share a centre pixel from collapsing into
    a single count. The map is then blurred with a Gaussian of ``sigma`` and
    multiplied by ``scale`` (60 = average number of pixels of a dot, per the
    original notebook).
    """
    rows, cols = shape
    density = np.zeros(shape, dtype='float32')
    for x, y in points:
        density[min(rows - 1, int(y)), min(cols - 1, int(x))] += 1
    density = ndimage.gaussian_filter(density, sigma=(sigma, sigma), order=0)
    return density * scale


def generate_dataset(width=384, height=384, seed=None,
                     annotation_path="./annotation.json"):
    """Tile the FITS image and its segmentation mask into an object-counting dataset.

    Reads the image from ``IMG_PATH`` and the mask from ``MASK_PATH`` and cuts
    both into non-overlapping ``height x width`` tiles (``patchify`` with
    ``step`` equal to the tile size; a remainder smaller than a tile is
    presumably dropped -- see the patchify docs). For every tile it writes:

    * ``OUT_PATH/images/NNNNNN.png``: the z-scaled tile as an RGB PNG;
    * ``OUT_PATH/density_maps/NNNNNN.npy``: a float32 Gaussian density map
      with one blob per connected object in the mask tile.

    It also writes a JSON annotation file keyed by PNG filename, holding up to
    3 randomly chosen example boxes, all object centre points, and both paths.

    WARNING: ``OUT_PATH`` is deleted with ``shutil.rmtree`` if it already exists.

    Args:
        width: tile width in pixels (also the horizontal step).
        height: tile height in pixels (also the vertical step); 384 as before.
        seed: optional seed for picking the example boxes. ``None`` uses the
            module-level ``random`` generator, as the original code did.
        annotation_path: where the annotation JSON is written.

    Raises:
        ValueError: if ``OUT_PATH`` is the working directory or one of its
            parents, which would otherwise be deleted.
    """
    output_folder = OUT_PATH
    if _is_cwd_or_ancestor(output_folder):
        raise ValueError(
            f"OUT_PATH={output_folder!r} is the working directory or one of its "
            "parents; refusing to delete it with shutil.rmtree.")

    img = fits.getdata(IMG_PATH)
    mask = fits.getdata(MASK_PATH)
    window = (height, width)
    img_patches = patchify(img, window, step=window)
    mask_patches = patchify(mask, window, step=window)
    n_rows, n_cols = img_patches.shape[:2]

    if os.path.isdir(output_folder):
        shutil.rmtree(output_folder)
    os.makedirs(output_folder + '/images')
    os.makedirs(output_folder + '/density_maps')

    rng = random if seed is None else random.Random(seed)

    annot = {}
    count = 0
    for i in tqdm(range(n_rows)):
        for j in range(n_cols):
            stem = str(count).zfill(6)
            img_path = output_folder + '/images/' + stem + '.png'
            density_path = output_folder + '/density_maps/' + stem + '.npy'

            _save_rgb_patch(img_patches[i, j], img_path)

            coords, points = _object_boxes(mask_patches[i, j])
            # Up to 3 randomly chosen boxes serve as exemplars for this tile.
            examples = rng.sample(coords, 3) if len(coords) >= 3 else coords

            annot[stem + '.png'] = {'box_examples_coordinates': examples,
                                    'points': points,
                                    'density_path': density_path,
                                    'img_path': img_path,
                                    }

            np.save(density_path, _density_map(points, window))
            count += 1

    with open(annotation_path, "w") as outfile:
        json.dump(annot, outfile, indent=4)

In [None]:
#@title Generate Dataset
# Tile width in pixels (Colab form field). Tile height is not set here;
# generate_dataset uses 384. Running this cell deletes and recreates OUT_PATH
# and writes ./annotation.json.
width = 384 #@param {type:"integer"}
generate_dataset(width)

100%|██████████| 39/39 [20:20<00:00, 31.30s/it]


In [None]:
#@title Train, Val, Test Split
# Split the generated tiles into train / val / test and save the filename lists
# as JSON. Tiles are listed from OUT_PATH/images, the folder generate_dataset
# writes to. Only .png files are kept (skips e.g. .ipynb_checkpoints), and the
# list is sorted so the seeded split below is reproducible, since os.listdir
# order is arbitrary.
images_dir = os.path.join(OUT_PATH, 'images')
images = sorted(name for name in os.listdir(images_dir) if name.endswith('.png'))
print('Number of images:', len(images))
assert images, f'no .png tiles found in {images_dir}'

train_split = 0.8 #@param {type: "number"}
val_test_split = 0.5 #@param {type: "number"}
split_seed = 42 #@param {type: "integer"}

# First carve off the hold-out share (1 - train_split), then divide it into
# val and test; val_test_split is the fraction of the hold-out that goes to test.
train, holdout = train_test_split(images, test_size=1 - train_split,
                                  shuffle=True, random_state=split_seed)
val, test = train_test_split(holdout, test_size=val_test_split,
                             shuffle=True, random_state=split_seed)
assert len(train) + len(val) + len(test) == len(images)

# store split in a json dictionary
train_test_val = {'train': train,
                  'test': test,
                  'val': val}

with open("./train_test_val.json", "w") as outfile:
    json.dump(train_test_val, outfile, indent=4)

Number of images: 1521
