# Three band training optim sample

by Shooby on Sep 19th <br>
last edited: Sep 19th

In [8]:
import numpy as np
import astropy.io.fits as pyfits
import matplotlib.pyplot as plt
import astropy.wcs as wcs
from astropy import units as u
from astropy.coordinates import SkyCoord
from scipy import ndimage
from PIL import Image


%matplotlib inline

In [9]:
# Load the catalogue table from the FITS file.
gs = pyfits.getdata('../../WFIRST_WPS/CANDELS_fits/gds.fits')

# Selection mask: every cut below must hold for an object to be kept.
sel1 = np.logical_and.reduce([
    gs['zbest'] > 0.1,                          # lower bound on zbest
    gs['zbest'] < 0.7,                          # upper bound on zbest
    gs['CLASS_STAR'] < 0.8,                     # CLASS_STAR below 0.8
    gs['Hmag'] < 25.5,                          # Hmag below 25.5
    gs['ISOAREA_IMAGE_F160W'] >= 300,           # isophotal area of at least 300
    gs['B_IMAGE_1'] / gs['A_IMAGE_1'] > 0.5,    # B/A ratio above 0.5
])

# Number of catalogue rows passing the selection.
print(len(gs[sel1]))


500


In [10]:
def radec2xy(ra,dec,wc):
    '''Convert a sky position (in degrees) to pixel coordinates.

    ra, dec : coordinates passed to SkyCoord with unit='deg'
    wc      : WCS object handed to astropy's skycoord_to_pixel

    Returns the first two elements of the skycoord_to_pixel result
    (computed with origin=0 and mode='wcs') as a 2-tuple.'''
    sky_position = SkyCoord(ra, dec, unit='deg')
    pixel_position = wcs.utils.skycoord_to_pixel(
        sky_position,
        wc,
        origin=0,
        mode='wcs',
    )
    return pixel_position[0], pixel_position[1]
    
def cut(ra,dec,andaze,filename):
    '''Return a square cutout (postage stamp) of a FITS image around a
    sky position.

    ra, dec  : sky coordinates of the target, forwarded to radec2xy
    andaze   : half-size of the cutout in pixels; the slice spans
               [centre-andaze, centre+andaze) along both axes, i.e. up to
               2*andaze pixels per side
    filename : path of the FITS image; its header supplies the WCS

    Returns the sliced image array.

    NOTE: there is no bounds checking. If the target lies closer than
    `andaze` pixels to the low edge of an axis, the slice start becomes
    negative and NumPy counts it from the end of the axis, so the result
    is not the intended stamp (often an empty array).'''
    hdr = pyfits.getheader(filename)
    w = wcs.WCS(hdr)
    x,y=radec2xy(ra,dec,w)
    # Builtin int() truncates exactly as the old np.int alias did; np.int
    # itself was removed in NumPy 1.24 and raises AttributeError there.
    x,y=int(x),int(y)
    # Image arrays are indexed [row, column], i.e. [y, x].
    im=pyfits.getdata(filename)[y-andaze:y+andaze,x-andaze:x+andaze]
    return im

def segment(a1,above_mean=1.1,add_image_back=0.3,random_noise=0.2):
    '''Return a boolean segmentation mask of an image.

    The image is smoothed, thresholded relative to its mean, perturbed with
    Gaussian noise and cleaned up morphologically, so that the fit can be
    restricted to the galaxy itself and the surroundings masked.

    a1             : 2-D image array
    above_mean     : a smoothed pixel is flagged when it exceeds
                     above_mean * (mean of the smoothed image)
    add_image_back : fraction of the smoothed image added onto the 0/1 mask
    random_noise   : standard deviation of the Gaussian noise added to the
                     mask; the same value is the final threshold

    Returns a boolean array with the same shape as `a1`. The result depends
    on NumPy's global random state unless random_noise is 0.'''
    # Gaussian smoothing (sigma = 2 pixels) suppresses pixel-scale noise.
    im2=ndimage.gaussian_filter(a1, 2)
    # Builtin float replaces np.float, which was removed in NumPy 1.24.
    mask = (im2 > above_mean*im2.mean()).astype(float)
    mask += add_image_back * im2
    img = mask + random_noise*np.random.randn(*mask.shape)
    # Keep the pixels that stand above the injected noise level.
    binary_im = img > random_noise
    # Opening removes small isolated detections; closing fills small holes.
    open_img = ndimage.binary_opening(binary_im)
    close_img = ndimage.binary_closing(open_img)
    return close_img

def brightest_center(im, r = 30):
    '''Check whether the brightest pixel of a cutout lies in its centre.

    "Centre" is an r x r pixel box around the middle of the image
    (r = 30 by default); the box edges themselves are excluded.

    im : 2-D image array
    r  : side length of the central box in pixels

    Returns True when the position of the maximum of `im` falls strictly
    inside the central box along both axes, False otherwise.'''
    # Row and column index of the (first) maximum of the image.
    a0,a1 = np.unravel_index(np.argmax(im, axis=None), im.shape)
    n_rows, n_cols = im.shape[0], im.shape[1]
    row_ok = (n_rows - r) / 2 < a0 < (n_rows + r) / 2
    # The column test uses the column count; the original used im.shape[0]
    # for the upper bound, which is only right for square cutouts.
    col_ok = (n_cols - r) / 2 < a1 < (n_cols + r) / 2
    return bool(row_ok and col_ok)

In [11]:
# Build the training set: for every iteration, try to write one RGB stamp
# into each of the class folders '1' and '2'. Both folders are filled from
# the same selection `sel1`; a stamp is only written when the brightest
# pixel is central in all three bands, so the file numbering can have gaps.
ra1,dec1=gs['RA_1'][sel1],gs['DEC_1'][sel1]

# Mosaics feeding the R, G and B channels of the output image, in order.
band_files = (
    '/Users/shemmati/Desktop/GOODS/goodss_all_acs_wfc_f435w_060mas_v1.5_drz.fits',
    '/Users/shemmati/Desktop/GOODS/goodss_all_acs_wfc_f850l_060mas_v1.5_drz.fits',
    '/Users/shemmati/Desktop/GOODS/goodss_all_wfc3_ir_f160w_060mas_v1.0_drz.fits',
)
half_size = 42                      # cut() returns 2*half_size pixels per side
margin = 10                         # border trimmed after the rotation
stamp_size = 2*(half_size - margin) # 64-pixel final stamps

for boz in range(3000):
    for label in ('1', '2'):
        # randint draws from 0..len(ra1)-1 inclusive. The previous
        # np.int(np.random.uniform(0, len(ra1)-1)) could never pick the last
        # object, and np.int no longer exists in NumPy >= 1.24.
        n = np.random.randint(len(ra1))
        stamps = [cut(ra1[n],dec1[n],half_size,path) for path in band_files]

        if not all(brightest_center(stamp) for stamp in stamps):
            continue

        # One random angle shared by the three bands keeps them aligned.
        angle = np.random.uniform(0,180)
        rgbArray = np.zeros((stamp_size,stamp_size,3), 'uint8')
        for channel, stamp in enumerate(stamps):
            rotated = ndimage.rotate(stamp,angle,mode='nearest',reshape=False)
            da = np.arcsinh(rotated[margin:-margin,margin:-margin])
            scaled = 255.0 / (da.max()+0.1) * (da - da.min())
            # Clip before the cast: when da.min() < -0.1 the scaled maximum
            # exceeds 255 and a bare astype(uint8) would wrap it around.
            rgbArray[..., channel] = np.clip(scaled, 0, 255).astype(np.uint8)
        im = Image.fromarray(rgbArray)
        im.save('images/training_images/'+label+'/'+str(boz)+'.jpg')

    

    

In [12]:
# Build the test set: for every iteration, try to write one RGB stamp into
# each of the class folders '1' and '2'. Both folders are filled from the
# same selection `sel1`; a stamp is only written when the brightest pixel is
# central in all three bands, so the file numbering can have gaps.
# NOTE(review): objects are drawn at random from the same selection that
# feeds the training images, so the two sets can share galaxies - confirm
# that this is intended.
ra1,dec1=gs['RA_1'][sel1],gs['DEC_1'][sel1]

# Mosaics feeding the R, G and B channels of the output image, in order.
band_files = (
    '/Users/shemmati/Desktop/GOODS/goodss_all_acs_wfc_f435w_060mas_v1.5_drz.fits',
    '/Users/shemmati/Desktop/GOODS/goodss_all_acs_wfc_f850l_060mas_v1.5_drz.fits',
    '/Users/shemmati/Desktop/GOODS/goodss_all_wfc3_ir_f160w_060mas_v1.0_drz.fits',
)
half_size = 42                      # cut() returns 2*half_size pixels per side
margin = 10                         # border trimmed after the rotation
stamp_size = 2*(half_size - margin) # 64-pixel final stamps

for boz in range(1000):
    for label in ('1', '2'):
        # randint draws from 0..len(ra1)-1 inclusive. The previous
        # np.int(np.random.uniform(0, len(ra1)-1)) could never pick the last
        # object, and np.int no longer exists in NumPy >= 1.24.
        n = np.random.randint(len(ra1))
        stamps = [cut(ra1[n],dec1[n],half_size,path) for path in band_files]

        if not all(brightest_center(stamp) for stamp in stamps):
            continue

        # One random angle shared by the three bands keeps them aligned.
        angle = np.random.uniform(0,180)
        rgbArray = np.zeros((stamp_size,stamp_size,3), 'uint8')
        for channel, stamp in enumerate(stamps):
            rotated = ndimage.rotate(stamp,angle,mode='nearest',reshape=False)
            da = np.arcsinh(rotated[margin:-margin,margin:-margin])
            scaled = 255.0 / (da.max()+0.1) * (da - da.min())
            # Clip before the cast: when da.min() < -0.1 the scaled maximum
            # exceeds 255 and a bare astype(uint8) would wrap it around.
            rgbArray[..., channel] = np.clip(scaled, 0, 255).astype(np.uint8)
        im = Image.fromarray(rgbArray)
        im.save('images/test_images/'+label+'/'+str(boz)+'.jpg')

    

    