# Creating the blended sample
By Shooby, Sep 16th <br>
last edited: Aug 29, 2020

In [1]:
import numpy as np
import astropy.io.fits as pyfits
import matplotlib.pyplot as plt
import astropy.wcs as wcs
from astropy import units as u
from astropy.coordinates import SkyCoord
from scipy import ndimage
from PIL import Image


def radec2xy(ra, dec, wc):
    '''Convert a sky position to pixel coordinates of the image described by `wc`.

    ra, dec : right ascension and declination, interpreted in degrees.
    wc      : the astropy WCS object of the target image.

    Returns the pair (x, y) as produced by astropy's skycoord_to_pixel with
    origin=0 (zero-based pixel indices) and mode='wcs'.
    '''
    sky_position = SkyCoord(ra, dec, unit='deg')
    pixel_position = wcs.utils.skycoord_to_pixel(
        sky_position,
        wc,
        origin=0,
        mode='wcs',
    )
    return pixel_position[0], pixel_position[1]
    
def cut(ra,dec,andaze,filename):
    '''Return a square cutout (postage stamp) of a FITS image around a sky position.

    ra, dec  : sky coordinates of the stamp centre, in degrees.
    andaze   : half-size of the stamp in pixels; the slice spans
               [centre - andaze, centre + andaze) along both axes, i.e. at most
               2*andaze x 2*andaze pixels.
    filename : path of the FITS file; its extension-0 header supplies the WCS.

    Returns the sliced image array.

    No bounds checking is performed: when the stamp reaches past the image
    edge, plain slice semantics apply, so the result can be smaller than
    requested or even empty.
    '''
    hdr = pyfits.getheader(filename, ext=0)
    w = wcs.WCS(hdr)
    x, y = radec2xy(ra, dec, w)
    # Built-in int() truncates exactly as the old np.int alias did; np.int itself
    # was deprecated in NumPy 1.20 and removed in 1.24.
    x, y = int(x), int(y)
    # Image arrays are indexed [row, column], i.e. [y, x].
    im = pyfits.getdata(filename)[y-andaze:y+andaze, x-andaze:x+andaze]
    return im

%matplotlib inline

In [2]:
def brightest_center(im, r = 10):
    '''Check whether the brightest pixel of a 2-D image lies in its central region.

    im : 2-D array (the cutout).
    r  : full width, in pixels, of the central window. The brightest pixel
         counts as central when, along each axis, its index lies strictly
         between (size - r)/2 and (size + r)/2 of that axis.

    Returns True when the maximum of `im` falls inside that window, False
    otherwise. For ties (e.g. an all-zero image) np.argmax picks the first
    maximum, which is pixel (0, 0) for a constant image.
    '''
    row, col = np.unravel_index(np.argmax(im, axis=None), im.shape)
    n_rows, n_cols = im.shape

    row_is_central = (n_rows - r) / 2 < row < (n_rows + r) / 2
    # The column limits must come from the column count (shape[1]); using
    # shape[0] for the upper limit gave wrong answers on non-square images.
    col_is_central = (n_cols - r) / 2 < col < (n_cols + r) / 2

    return bool(row_is_central and col_is_central)
    

In [24]:
def galblend(gals=1, lim_hmag=25, plot_it=True):
    '''Build one 64x64, 8-bit postage stamp of a randomly chosen GOODS-S galaxy.

    A catalog object is drawn at random from the selection below and an 80x80
    cutout is taken from the F775W mosaic. Draws are repeated until the
    brightest pixel of the cutout is central (see brightest_center), so the
    target galaxy dominates after rescaling. The cutout is rotated by a random
    angle in [0, 180) degrees, trimmed by 8 pixels per side to 64x64,
    arcsinh-stretched and rescaled to uint8.

    gals     : currently unused; kept so existing calls keep working.
               NOTE(review): as written only ONE galaxy is placed per stamp,
               whatever the value of `gals`.
    lim_hmag : upper limit applied to the catalog 'Hmag' column.
    plot_it  : currently unused; kept so existing calls keep working.

    Returns (final_im, rescaled): a PIL image and the uint8 array it was
    built from.

    Raises ValueError when no catalog object passes the selection.
    '''
    catalog_path = '/Users/shemmati/Dropbox/WFIRST_WPS/CANDELS_fits/gds.fits'
    image_path = '/Users/shemmati/Desktop/GOODS/goodss_all_acs_wfc_f775w_060mas_v1.5_drz.fits'

    ## reading GOODS-S catalog and initial selection on objects
    gs = pyfits.getdata(catalog_path)
    sel1 = (gs['zbest']>0.05)&(gs['zbest']<1.8)&(gs['CLASS_STAR']<0.95)&(gs['Hmag']<lim_hmag)&(gs['FWHM_IMAGE']>10)&(gs['DECdeg']<-27.8)

    ra, dec = gs['RA_1'][sel1], gs['DEC_1'][sel1]
    if len(ra) == 0:
        raise ValueError('no catalog object passes the selection (lim_hmag=%s)' % lim_hmag)

    # Start from an all-zero stamp: its argmax is pixel (0, 0), which is never
    # central, so the loop body always runs at least once.
    data1 = np.zeros([80,80])
    while not brightest_center(data1):
        # randint(len) covers every index 0..len-1; truncating
        # uniform(0, len-1) could never return the last entry.
        n = np.random.randint(len(ra))
        data1 = cut(ra[n], dec[n], 40, image_path)

    angle = np.random.uniform(0,180)
    s = ndimage.rotate(data1, angle, mode='nearest', reshape=False)

    # Trim 8 pixels per side (80 -> 64). Adding into a fixed 64x64 array makes
    # a wrongly sized cutout fail loudly rather than propagate.
    im = np.zeros([64,64])
    im += s[8:-8,8:-8]

    dada = np.arcsinh(im)
    scaled = 255.0 / (dada.max()+0.05) * (dada - dada.min())
    # Clip before the cast: if the minimum is below -0.05 the scaled maximum
    # exceeds 255 and a bare astype(uint8) would wrap around.
    rescaled = np.clip(scaled, 0, 255).astype(np.uint8)
    final_im = Image.fromarray(rescaled)

    return final_im,rescaled

In [26]:
# Fill the class folders '1' and '2' of the training set (2000 stamps each)
# and then of the test set (200 stamps each). Stamps are numbered 0..N-1
# inside every class folder, and the target directories must already exist.
# NOTE(review): both class folders receive stamps from the very same call
# (gals=1), so the two labels are drawn from one distribution - confirm
# this is intended.
for split_dir, n_per_class in (('images/training_images/', 2000),
                               ('images/test_images/', 200)):
    for boz in range(n_per_class):
        for class_dir in ('1', '2'):
            im, g = galblend(gals=1, lim_hmag=26, plot_it=False)
            im.save(split_dir + class_dir + '/' + str(boz) + '.jpg')
    
 

In [28]:
import os
from PIL import Image
from array import *
from random import shuffle
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

# Convert the class folders of JPEG stamps into MNIST-style IDX files.
# Each entry of Names is [source directory, output file prefix]; the source
# directory holds one sub-directory per class, named with the integer label.
Names = [['images/training_images','train'], ['images/test_images','test']]

for name in Names:

    data_image = array('B')
    data_label = array('B')

    # Collect <source>/<label>/<file>.jpg. Hidden entries and plain files
    # (e.g. .DS_Store on macOS) are skipped explicitly; os.listdir returns
    # names in arbitrary order, so dropping "the first entry" is not reliable.
    FileList = []
    for dirname in sorted(os.listdir(name[0])):
        path = os.path.join(name[0],dirname)
        if dirname.startswith('.') or not os.path.isdir(path):
            continue
        for filename in sorted(os.listdir(path)):
            if filename.endswith(".jpg"):
                FileList.append(os.path.join(path,filename))

    if not FileList:
        raise ValueError('No .jpg images found under ' + name[0])

    shuffle(FileList) # Useful for further segmenting the validation set

    width = height = None
    for filename in FileList:
        # The label is the name of the directory that contains the file.
        label = int(os.path.basename(os.path.dirname(filename)))

        with Image.open(filename) as Im:
            if width is None:
                width, height = Im.size
            elif Im.size != (width, height):
                raise ValueError('Image %s is %dx%d, expected %dx%d'
                                 % (filename, Im.size[0], Im.size[1], width, height))

            pixel = Im.load()
            # PIL pixel access is pixel[column, row]; emit the pixels in
            # row-major order (one unsigned byte each).
            for row in range(height):
                for col in range(width):
                    data_image.append(pixel[col,row])

        data_label.append(label) # one unsigned byte per label

    # header for label array: magic 0x00000801 + item count (4 bytes, big-endian)

    header = array('B', [0,0,8,1])
    header.extend(len(FileList).to_bytes(4, 'big'))

    data_label = header + data_label

    # additional header for images array: magic 0x00000803, then the number
    # of rows and the number of columns, each as a 4-byte big-endian integer

    header[3] = 3 # Changing MSB for image data (0x00000803)
    header.extend(height.to_bytes(4, 'big'))
    header.extend(width.to_bytes(4, 'big'))

    data_image = header + data_image

    with open(name[1]+'-images-idx3-ubyte', 'wb') as output_file:
        data_image.tofile(output_file)

    with open(name[1]+'-labels-idx1-ubyte', 'wb') as output_file:
        data_label.tofile(output_file)

# gzip resulting files
# Done with the standard library rather than an external `gzip` command, so a
# failure raises here and the step needs no gzip binary. As with `gzip -f`,
# each file is replaced by <file>.gz and a stale .gz from an earlier run is
# overwritten.
import gzip
import shutil

for name in Names:
    for suffix in ('-images-idx3-ubyte', '-labels-idx1-ubyte'):
        raw_path = name[1] + suffix
        with open(raw_path, 'rb') as raw_file, gzip.open(raw_path + '.gz', 'wb') as gz_file:
            shutil.copyfileobj(raw_file, gz_file)
        # Remove the uncompressed file only after the archive was written.
        os.remove(raw_path)