In [None]:
%load_ext autoreload
%autoreload 2
import paltas
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from astropy.io import fits
from pathlib import Path
import matplotlib.colors as colors
import sys
sys.path.insert(0, '/Users/smericks/Desktop/CS236/sl_project/')
import visualization_utils


Use paltas package to import COSMOS catalog & prepare cutouts

In [None]:
# Average AB magnitude zeropoint for the COSMOS run.
output_ab_zeropoint = 25.95

# Source parameters handed to the paltas COSMOS catalog. Every selection
# parameter is left as None and smoothing / random rotation are switched off
# (presumably so the catalog is loaded without any cuts -- confirm against
# the paltas COSMOSCatalog documentation).
source_params_dict = dict(
    z_source=None,
    cosmos_folder='COSMOS_23.5_training_sample/',
    max_z=None,
    minimum_size_in_pixels=None,
    faintest_apparent_mag=None,
    smoothing_sigma=0.00,
    random_rotation=False,
    output_ab_zeropoint=output_ab_zeropoint,
    min_flux_radius=None,
    center_x=None,
    center_y=None,
)

cosmos_source_galaxies = paltas.Sources.cosmos.COSMOSCatalog(
    cosmology_parameters='planck18',
    source_parameters=source_params_dict,
)

In [None]:
# Adapted from paltas.Sources.cosmos.COSMOSCatalog and rewritten as a free
# function so that it does not depend on instance attributes.
def iter_image_and_metadata_bulk(folder, catalog, message=''):
    """Lazily walk every image stored in the COSMOS fits files of a folder.

    Args:
        folder (str or Path): Directory searched for files matching
            ``real_galaxy_images_23.5_n*.fits``.
        catalog: Indexable catalog; entry ``i`` is paired with the i-th
            image encountered across all files.
        message (str): Description shown next to the tqdm progress bar.

    Yields:
        tuple: ``(image_data, catalog_row)`` for each HDU of each file, in
        the order given by sorting the files with
        ``COSMOSCatalog._file_number``.

    Notes:
        The fits files are opened and read while iterating.
    """
    image_files = sorted(
        Path(folder).glob('real_galaxy_images_23.5_n*.fits'),
        key=paltas.Sources.cosmos.COSMOSCatalog._file_number)

    # Running index into the catalog, shared across all files.
    row = 0
    for fits_path in tqdm(image_files, desc=message):
        with fits.open(fits_path) as hdu_list:
            for hdu in hdu_list:
                yield hdu.data, catalog[row]
                row += 1

In [None]:
def prepare_cutout(im, cutout_size, rng=None):
    """Center-crop and/or noise-pad a 2D image to a fixed square cutout.

    Each axis is handled independently: an axis longer than ``cutout_size``
    is cropped around its center, an axis shorter than ``cutout_size`` is
    embedded at the center of a Gaussian-noise canvas. Pixels below 1e-2 are
    then set to zero, and finally the last row and column are dropped.

    Args:
        im (np.ndarray): 2D input image. It is never modified in place.
        cutout_size (int): Target side length before the final trim.
        rng (np.random.Generator or np.random.RandomState, optional): Source
            of the padding noise. Defaults to numpy's global random state,
            so seeding with ``np.random.seed`` keeps working.

    Returns:
        np.ndarray: Array of shape ``(cutout_size - 1, cutout_size - 1)``.

    Notes:
        The final one-pixel trim is a known workaround (see TODO below), so
        the output is one pixel smaller than ``cutout_size`` on both axes.
    """
    im = np.asarray(im)

    # Drop one trailing row/column so both dimensions are even (arbitrary
    # choice of which edge to drop).
    n_rows, n_cols = im.shape[0], im.shape[1]
    im = im[:n_rows - n_rows % 2, :n_cols - n_cols % 2]
    n_rows, n_cols = im.shape[0], im.shape[1]

    needs_padding = n_rows < cutout_size or n_cols < cutout_size
    if needs_padding:
        # Estimate the noise level from the four 10x10 corners of the
        # (uncropped) image.
        std_dev = np.mean([np.std(im[:10, :10]), np.std(im[-10:, :10]),
                           np.std(im[:10, -10:]), np.std(im[-10:, -10:])])

    # Crop any axis that is too long, keeping the central region.
    if n_rows > cutout_size:
        row_start = (n_rows - cutout_size) // 2
        im = im[row_start:row_start + cutout_size, :]
    if n_cols > cutout_size:
        col_start = (n_cols - cutout_size) // 2
        im = im[:, col_start:col_start + cutout_size]

    if needs_padding:
        # Embed the image at the center of a pure-noise canvas.
        sampler = np.random if rng is None else rng
        new_im = sampler.normal(loc=0, scale=std_dev,
                                size=(cutout_size, cutout_size))
        row_start = (cutout_size - im.shape[0]) // 2
        col_start = (cutout_size - im.shape[1]) // 2
        new_im[row_start:row_start + im.shape[0],
               col_start:col_start + im.shape[1]] = im
    else:
        # Copy so the thresholding below cannot modify the caller's array.
        new_im = im.copy()

    # make some cut on small pixel values so that noise floor is the same for
    # synthetic and real noise
    new_im[new_im < 1e-2] = 0

    # TODO: Fix this: janky fix!! (drops the last row and column)
    new_im = new_im[:-1, :-1]

    return new_im

In [None]:
# Configuration shared by the cells below.
HUBBLE_ACS_PIXEL_WIDTH = 0.03   # Arcsec
cutout_size = 102  # side length requested from prepare_cutout
cosmos_folder = 'COSMOS_23.5_training_sample/'  # trailing slash: used in string concatenation

In [None]:
# based on code in paltas
# Merge the partial COSMOS catalogs into one structured array, write one
# normalized cutout per sufficiently bright galaxy, and save the catalog
# restricted to those galaxies.

# np.lib.recfunctions is a submodule that `import numpy` alone does not
# guarantee to load, so import it explicitly.
import numpy.lib.recfunctions as rfn

# Cutouts whose summed flux falls below this value are discarded.
MIN_CUTOUT_FLUX = 14
# 1-based image numbers for which a diagnostic plot and noise estimate are shown.
DEBUG_IMAGE_NUMBERS = {5, 12, 17}

catalog_path = cosmos_folder+'custom_cutouts_round3/paltas_catalog.npy'
npy_files_path = Path(cosmos_folder+'custom_cutouts_round3/npy_files/')

# np.save does not create missing directories.
npy_files_path.mkdir(parents=True, exist_ok=True)

# Combine all partial catalog files
catalogs = [paltas.Sources.cosmos.unfits(cosmos_folder + fn) for fn in [
    'real_galaxy_catalog_23.5.fits',
    'real_galaxy_catalog_23.5_fits.fits'
]]

# Duplicate IDENT field crashes numpy's silly merge function.
catalogs[1] = rfn.drop_fields(catalogs[1], 'IDENT')

# Custom fields
catalogs += [
    np.zeros(len(catalogs[0]),
        dtype=[('size_x', int),('size_y', int),('z', float),
        ('pixel_width', float)])]

catalog = rfn.merge_arrays(catalogs, flatten=True)

catalog['pixel_width'] = HUBBLE_ACS_PIXEL_WIDTH
catalog['z'] = catalog['zphot']
# NOTE(review): prepare_cutout returns arrays one pixel smaller than
# cutout_size on each axis -- confirm which size downstream code expects.
catalog['size_x'] = cutout_size
catalog['size_y'] = cutout_size

# Loop over the images, imposing the same cutout size on every image.
catalog_i = 0    # index of the next image file to write
sum_ims = []     # summed flux of every cutout, kept or not (inspected in a later cell)
kept_rows = []   # catalog rows whose cutout was written to disk
image_iter = iter_image_and_metadata_bulk(folder=cosmos_folder,catalog=catalog)
for counter, (img, meta) in enumerate(image_iter, start=1):
    if counter in DEBUG_IMAGE_NUMBERS:
        # Diagnostic: show the raw image and its corner-based noise estimate.
        print(np.shape(img))
        plt.figure()
        plt.imshow(img)
        std = np.mean([np.std(img[:10,:10]),np.std(img[-10:,:10]),np.std(img[:10,-10:]),np.std(img[-10:,-10:])])
        print(std)

    img = prepare_cutout(img,cutout_size)

    # REMOVE IMAGES WITHOUT ENOUGH FLUX
    total_flux = np.sum(img)
    sum_ims.append(total_flux)
    if total_flux < MIN_CUTOUT_FLUX:
        continue

    # NORMALIZE TO SUM TO ONE
    img = (img / total_flux).astype(np.float64)
    # Save the image as its own file.
    np.save(npy_files_path / ('image_%07d.npy'%(catalog_i)),img)
    kept_rows.append(counter - 1)
    catalog_i += 1

# Keep only the rows of galaxies that were actually written, so that catalog
# row k describes image_k.npy (skipped images previously shifted the
# numbering out of sync with the catalog).
catalog = catalog[np.asarray(kept_rows, dtype=int)]
np.save(catalog_path,catalog)

In [None]:
# Inspect the summed flux of all cutouts: histogram, median, the number
# falling below the flux cut of 14, and the minimum.
plt.hist(sum_ims)
print(np.median(sum_ims))
print((np.asarray(sum_ims) < 14).sum())
np.min(sum_ims)

#4822/56000  (presumably rejected / total images -- unverified)


In [None]:
# Plot the round-2 cutouts via the project helper; second argument is
# presumably the output file name (helper is defined outside this notebook).
visualization_utils.matrix_plot_from_folder(
    'COSMOS_23.5_training_sample/custom_cutouts_round2/npy_files/',
    'training_grid.pdf',
)

In [None]:
# Compare the same image index between the first and second cutout rounds:
# print each cutout's total flux and display it.
i = 38

im = np.load('COSMOS_23.5_training_sample/custom_cutouts/npy_files/image_%07d.npy' % i)
print(im.sum())
plt.figure()
plt.imshow(im)

im2 = np.load('COSMOS_23.5_training_sample/custom_cutouts_round2/npy_files/image_%07d.npy' % i)
print(im2.sum())
plt.figure()
plt.imshow(im2)

In [None]:
# Rebinds `catalog` to the catalog saved under custom_cutouts/ (not the
# custom_cutouts_round3/ one written earlier in this notebook).
catalog = np.load(
    'COSMOS_23.5_training_sample/custom_cutouts/paltas_catalog.npy'
)

In [None]:
# Quick sanity check of the loaded catalog: its type and the length of the
# pixel_width column.
print(type(catalog), catalog['pixel_width'].shape, sep='\n')

### Make a tfrecord file ###

In [None]:
from paltas.Analysis.dataset_generation import generate_tf_record, generate_tf_dataset

# Write the round-2 training cutouts to a TFRecord file. An empty list and a
# placeholder csv name are passed for the second and third arguments
# (presumably because no labels are needed -- confirm against paltas docs).
train_npy_dir = 'COSMOS_23.5_training_sample/custom_cutouts_round2/npy_files/'
train_tfrecord = 'COSMOS_23.5_training_sample/custom_cutouts_round2/training_data.tfrecord'
generate_tf_record(train_npy_dir, [], 'placeholder.csv', train_tfrecord)

In [None]:
# Same as the training record, but for the round-2 validation cutouts.
val_npy_dir = 'COSMOS_23.5_training_sample/custom_cutouts_round2/val_npy_files/'
val_tfrecord = 'COSMOS_23.5_training_sample/custom_cutouts_round2/validation_data.tfrecord'
generate_tf_record(val_npy_dir, [], 'placeholder.csv', val_tfrecord)