In [None]:
import numpy as np
from astropy.io import fits
import matplotlib.pyplot as plt
import wget
import re
import csv
import urllib.request

## Use this to get data sets from Legacy Surveys
### Selects which Tractor catalog files to download, pulls info on the objects they contain, gets FITS cutouts from the sky viewer using the object coordinates, writes a csv file in the same format as the training classifications.csv, writes a second csv file containing the object info from Tractor, and returns the object dictionaries.
### NOTE: The directories 'tractor_folders', 'tractor_fits' and 'fits_cutouts' must exist in the same folder as this notebook to run.

In [None]:
def fluxToMag(f):
    """Convert a flux to a magnitude with a zero point of 22.5.

    Computes ``-2.5 * (log10(f) - 9)``, i.e. ``22.5 - 2.5 * log10(f)``.

    Parameters
    ----------
    f : float or array-like
        Flux value(s). Arrays are converted element-wise.
        (presumably in nanomaggies, which is what a 22.5 zero point implies
        -- TODO confirm against the catalog documentation)

    Returns
    -------
    float or numpy.ndarray
        Magnitude(s). Following numpy's log10 semantics, a zero flux gives
        ``inf`` and a negative flux gives ``nan`` (numpy emits a
        RuntimeWarning in both cases), so callers should guard against
        non-positive fluxes if they need finite values.
    """
    # np.log10 is more accurate than np.log(f) / np.log(10.), which can be off
    # by one ulp on exact powers of ten (e.g. f = 1000).
    return -2.5 * (np.log10(f) - 9)

In [None]:
# !!! NOT USING THIS VERSION ANYMORE !!!
# - this one does cutouts last, but it's better to get a cutout after each object is pulled and then write csv
#   files afterwards

# t_folder = '000'-'359'
# n_files = int or 'all'
# n_objects = int or 'all'
# min_passes = int

# input the above values. downloads files from Tractor, gets ra/dec of objects, downloads fits cutouts of those objects
# using the legacysurvey skyviewer. also generates a csv file to match "classifications.csv" (from the lens challenge), and a
# csv file containing all the information in the generated object dictionaries
def download_Tractor(csvfile, objectcsvfile, t_folder='000', n_objects='all', min_passes=2, counter_init=0,
                     data_dir='/Users/mac/Desktop/LBNL/DR5'):
    """Download Tractor catalogs, keep well-covered objects, write two CSVs and fetch FITS cutouts.

    Steps, in order:
      1. Download the HTML listing of the DR5 Tractor folder ``t_folder`` and
         extract the catalog file names from it.
      2. Download each catalog file and keep every object whose g, r and z
         observation counts are all ``>= min_passes``.
      3. Write ``<csvfile>_<t_folder>.csv`` (one all-"not a lens" row per
         object) and ``<objectcsvfile>_<t_folder>.csv`` (per-object info).
      4. Download a 101-pixel grz cutout for each object from the legacysurvey
         viewer and store the array under the object's ``'image'`` key.

    The CSV files are written *before* the cutouts are requested, so they
    exist even if the cutout stage raises.

    Parameters
    ----------
    csvfile : str
        Path prefix of the classification CSV.
    objectcsvfile : str
        Path prefix of the object-info CSV.
    t_folder : str
        Tractor sub-folder name ('000' .. '359').
    n_objects : int or 'all'
        Stop once this many objects have been collected; 'all' means no limit.
    min_passes : int
        Minimum number of observations required in each of g, r and z.
    counter_init : int
        Number used for the first object's ID / cutout file name; later
        objects count up from it.
    data_dir : str
        Folder that holds the 'tractor_folders', 'tractor_fits' and
        'fits_cutouts' sub-folders (created if missing).

    Returns
    -------
    list of dict
        One dictionary per kept object with keys 'image', 'brickname',
        'objid', 'filename', 'ra', 'dec', 'flux_*', 'mag_*' and 'nobs_*'
        (for g, r, z). If a cutout could not be fetched, 'image' holds an
        error string instead of an array.

    Raises
    ------
    Exception
        The 20th cutout failure is re-raised instead of being skipped.
    """
    import os  # local import: only needed to make sure the output folders exist

    # Positional column indices inside a Tractor catalog row. They were found by
    # inspecting catalog rows by hand, so they are tied to this catalog layout.
    col_brickname, col_objid = 2, 3
    col_ra, col_dec = 6, 7
    col_flux = {'g': 17, 'r': 18, 'z': 20}
    col_nobs = {'g': 65, 'r': 66, 'z': 68}

    folder_dir = data_dir + '/tractor_folders'
    catalog_dir = data_dir + '/tractor_fits/'
    cutout_dir = data_dir + '/fits_cutouts'
    for needed_dir in (folder_dir, catalog_dir, cutout_dir):
        os.makedirs(needed_dir, exist_ok=True)

    objects = []  # list of object dictionaries

    def enough_objects():
        # n_objects='all' means "no limit", so it can never be reached
        return n_objects != 'all' and len(objects) >= n_objects

    def magnitude(flux):
        # a flux of exactly 0 has no finite magnitude; 0 is used as the placeholder
        return fluxToMag(flux) if flux != 0 else 0

    # download folder within tractor catalog from nersc portal, get html of folder webpage
    print('downloading Tractor files...')
    t_url = 'http://portal.nersc.gov/project/cosmo/data/legacysurvey/dr5/tractor/{}/'.format(t_folder)
    t_folder_file = wget.download(t_url, folder_dir + '/tractor_' + t_folder)
    with open(t_folder_file) as wgetfile:
        wgetdata = wgetfile.read()

    # search html code for file names: the 21 characters following each '.fits">'
    # (raw string so that '\.' is a regex escape, not an invalid string escape)
    catalog_names = re.findall(r'(?<=\.fits">).{21}', wgetdata)

    # create dictionaries for each object containing image, ra/dec, flux (mag), nobs (exposures)
    for f in catalog_names:
        # open tractor file, get object info
        print('opening file {}...'.format(f))
        t_file = wget.download(t_url + f, catalog_dir)
        with fits.open(t_file) as fit:
            filedata = fit[1].data

        goodobj = 0
        for i in range(filedata.shape[0]):
            if enough_objects():
                print('\nOKAY GOT ENOUGH BREAK NOWWWWW\n')
                break

            row = filedata[i]
            # determine coverage to see if object should be kept
            nobs = {band: row[col_nobs[band]] for band in 'grz'}
            if not all(count >= min_passes for count in nobs.values()):
                continue

            ra = row[col_ra]
            dec = row[col_dec]
            print('\tgetting object at {} {}'.format(ra, dec))

            flux = {band: row[col_flux[band]] for band in 'grz'}
            object_dict = {
                'image': 'not cut yet',
                'brickname': row[col_brickname],
                'objid': row[col_objid],
                # The cutout name only depends on the object's position in the list,
                # so it is filled in now; otherwise the object-info CSV (written
                # before the cutouts are fetched) would only contain a placeholder.
                'filename': 'cutout_{:06d}.fits'.format(counter_init + len(objects)),
                'ra': ra,
                'dec': dec,
                'flux_g': flux['g'], 'flux_r': flux['r'], 'flux_z': flux['z'],
                'mag_g': magnitude(flux['g']),
                'mag_r': magnitude(flux['r']),
                'mag_z': magnitude(flux['z']),
                'nobs_g': nobs['g'], 'nobs_r': nobs['r'], 'nobs_z': nobs['z'],
            }
            objects.append(object_dict)
            goodobj += 1

        print('got {} good objects from file {}'.format(goodobj, f))
        if enough_objects():
            print('enough objects gathered.')
            break

    # now create csv file to match fits file images, imitating the lensfinder challenge format
    print('writing classification csv file...')
    # newline='' is what the csv module expects; without it some platforms get blank rows
    with open(csvfile + '_' + t_folder + '.csv', 'w', newline='') as myFile:
        # NOTE: can't get these to have double quotes (e.g. "ID") in the csv file. can only get none and """ID"""
        # not sure if this will affect how it works (lensfinder challenge csv headers have double quotes)
        myFields = ["ID", "is_lens", "Einstein_area", "numb_pix_lensed_image", "flux_lensed_image_in_sigma"]
        writer = csv.DictWriter(myFile, fieldnames=myFields)
        writer.writeheader()
        for counter in range(counter_init, counter_init + len(objects)):
            # might want to change the ID for the training process if it makes it easier to iterate
            writer.writerow({"ID": '{:06d}'.format(counter), "is_lens": 0, "Einstein_area": 'nan',
                             "numb_pix_lensed_image": 'nan', "flux_lensed_image_in_sigma": 'nan'})

    # write csv file to contain information from object dictionaries
    print('writing object info csv file...')
    with open(objectcsvfile + '_' + t_folder + '.csv', 'w', newline='') as oFile:
        # omitting 'image' bc it's a lot to put in the csv file and we already have it
        oFields = ['brickname', 'objid', 'filename', 'ra', 'dec', 'flux_g', 'flux_r', 'flux_z',
                   'mag_g', 'mag_r', 'mag_z', 'nobs_g', 'nobs_r', 'nobs_z']
        writer = csv.DictWriter(oFile, fieldnames=oFields)
        writer.writeheader()
        for o in objects:
            writer.writerow({field: o[field] for field in oFields})

    # get fits cutouts from ra/dec and add images to object dictionaries
    print('getting fits cutouts from skyviewer...')
    fails = 0
    for counter, o in enumerate(objects, start=counter_init):
        # url specifies ra/dec as well as size (101), pixscale (0.262 is native) and layer (decals-dr5)
        url = 'http://legacysurvey.org/viewer/fits-cutout?ra={:5f}&dec={:5f}&size=101&layer=decals-dr5&pixscale=0.262&bands=grz'.format(o['ra'], o['dec'])
        try:
            filename = wget.download(url, '{}/cutout_{:06d}.fits'.format(cutout_dir, counter))
            with fits.open(filename) as fit:
                o['image'] = fit[0].data
                if fails > 0:
                    print('a fits file was just succesfully cut from the viewer after an error on a previous file. nice!')
        except Exception:
            # `Exception` rather than a bare except, so Ctrl-C still stops the run
            fails += 1
            print('some issue happened getting cutout_{:06d}\n- object info -\nra, dec = {}, {}\nbrickname, objid = {}, {}'.format(counter, o['ra'], o['dec'], o['brickname'], o['objid']))
            o['image'] = 'ERROR: failed to load image from viewer during fits cutout download'
            if fails == 20:
                print('20 fails. Raising error')
                raise

        o['filename'] = 'cutout_{:06d}.fits'.format(counter)

    print('done.')
    return objects

In [None]:
# Three runs of the first downloader, one per Tractor folder. Each run asks for
# 100 objects, so the ID counter of each run starts 100 after the previous one.
legacy_run_kwargs = dict(csvfile='class_test', objectcsvfile='objectinfo', n_objects=100, min_passes=4)
test1 = download_Tractor(t_folder='000', counter_init=0, **legacy_run_kwargs)
test2 = download_Tractor(t_folder='001', counter_init=100, **legacy_run_kwargs)
test3 = download_Tractor(t_folder='002', counter_init=200, **legacy_run_kwargs)

In [None]:
# !!! STILL NOT USING THIS VERSION - SWITCHED TO TERMINAL AND .PY FILE

# NEW VERSION, TRYING TO FIX WHEN OBJECTS ARE SKIPPED
# this version will continuously append to files and not do each step all at once, meaning each time an object is found we
# will immediately try to get cutout and then disregard if we can't

def download_Tractor2(csvfile, objectcsvfile, t_folder='000', n_objects='all', min_passes=2, counter_init=0,
                      data_dir='/Users/mac/Desktop/LBNL/DR5', max_attempts=50, max_total_fails=15):
    """Download Tractor catalogs and fetch a FITS cutout for each well-covered object as it is found.

    Unlike ``download_Tractor``, an object is only kept once its cutout has
    been downloaded and opened successfully, so the returned list, the cutout
    files and both CSV files stay aligned.

    Parameters
    ----------
    csvfile : str
        Path prefix of the classification CSV (``<csvfile>_<t_folder>.csv``).
    objectcsvfile : str
        Path prefix of the object-info CSV (``<objectcsvfile>_<t_folder>.csv``).
    t_folder : str
        Tractor sub-folder name ('000' .. '359').
    n_objects : int or 'all'
        Stop once this many objects have been collected; 'all' means no limit.
    min_passes : int
        Minimum number of observations required in each of g, r and z.
    counter_init : int
        Number used for the first object's ID / cutout file name.
    data_dir : str
        Folder that holds the 'tractor_folders', 'tractor_fits' and
        'fits_cutouts_<t_folder>' sub-folders (created if missing).
    max_attempts : int
        Download attempts per object before the object is given up on.
    max_total_fails : int
        Number of given-up objects after which the whole run stops early.

    Returns
    -------
    list of dict
        One dictionary per retrieved object with keys 'image', 'brickname',
        'objid', 'filename', 'ra', 'dec', 'flux_*', 'mag_*' and 'nobs_*'
        (for g, r, z).
    """
    import os  # local import: only needed to make sure the output folders exist

    # Positional column indices inside a Tractor catalog row (tied to this catalog layout).
    col_brickname, col_objid = 2, 3
    col_ra, col_dec = 6, 7
    col_flux = {'g': 17, 'r': 18, 'z': 20}
    col_nobs = {'g': 65, 'r': 66, 'z': 68}

    folder_dir = data_dir + '/tractor_folders'
    catalog_dir = data_dir + '/tractor_fits/'
    cutout_dir = data_dir + '/fits_cutouts_{}'.format(t_folder)
    # A missing folder would otherwise make every download attempt fail and be retried.
    for needed_dir in (folder_dir, catalog_dir, cutout_dir):
        os.makedirs(needed_dir, exist_ok=True)

    objects = []  # list of object dictionaries
    counter = counter_init
    total_fails = 0

    def enough_objects():
        # n_objects='all' means "no limit", so it can never be reached
        return n_objects != 'all' and len(objects) >= n_objects

    def magnitude(flux):
        # a flux of exactly 0 has no finite magnitude; 0 is used as the placeholder
        return fluxToMag(flux) if flux != 0 else 0

    # download folder within tractor catalog from nersc portal, get html of folder webpage
    print('downloading Tractor files...')
    t_url = 'http://portal.nersc.gov/project/cosmo/data/legacysurvey/dr5/tractor/{}/'.format(t_folder)
    t_folder_file = wget.download(t_url, folder_dir + '/tractor_' + t_folder)
    with open(t_folder_file) as wgetfile:
        wgetdata = wgetfile.read()

    # search html code for file names: the 21 characters following each '.fits">'
    # (raw string so that '\.' is a regex escape, not an invalid string escape)
    catalog_names = re.findall(r'(?<=\.fits">).{21}', wgetdata)

    # iterate through tractor filenames found
    for f in catalog_names:
        # open the file, get object info
        print('opening file {}...'.format(f))
        t_file = wget.download(t_url + f, catalog_dir)
        with fits.open(t_file) as fit:
            filedata = fit[1].data

        goodobj = 0
        for i in range(filedata.shape[0]):
            if enough_objects():
                break
            # stop if the cutout service keeps failing, keeping what was recovered so far
            if total_fails >= max_total_fails:
                print('!!! failed to get cutout 15 times - quitting early with recovered objects !!!')
                break

            row = filedata[i]
            # determine coverage to see if the object should be kept
            nobs = {band: row[col_nobs[band]] for band in 'grz'}
            if not all(count >= min_passes for count in nobs.values()):
                continue

            # now see if we can get it from the viewer cutout
            ra = row[col_ra]
            dec = row[col_dec]
            print('\t attempting to get cutout for object at {} {}'.format(ra, dec))
            url = 'http://legacysurvey.org/viewer/fits-cutout?ra={}&dec={}&size=101&layer=decals-dr5&pixscale=0.262&bands=grz'.format(ra, dec)
            cutout_path = '{}/cutout_{:06d}.fits'.format(cutout_dir, counter)

            # A bounded for-loop makes "retrieved" and "gave up" mutually exclusive;
            # the previous while-loop tried once more after giving up and could
            # record the same object as both a failure and a success.
            retrieved = False
            for attempt in range(1, max_attempts + 1):
                try:
                    downloaded = wget.download(url, cutout_path)
                    with fits.open(downloaded) as fit:
                        image = fit[0].data
                    retrieved = True
                    break
                except Exception:
                    # `Exception` rather than a bare except, so Ctrl-C still interrupts the retries
                    if attempt % 10 == 0:
                        print('\t\tfailed attempt {}'.format(attempt))

            if not retrieved:
                # don't update object list or counter
                total_fails += 1
                print('\t\tfailed to retrieve cutout for object - moving on to next object...')
                continue

            flux = {band: row[col_flux[band]] for band in 'grz'}
            object_dict = {
                'image': image,
                'brickname': row[col_brickname],
                'objid': row[col_objid],
                # NOTE(review): recorded without the '.fits' extension although the
                # file on disk is saved with it - confirm what downstream code expects.
                'filename': 'cutout_{:06d}'.format(counter),
                'ra': ra,
                'dec': dec,
                'flux_g': flux['g'], 'flux_r': flux['r'], 'flux_z': flux['z'],
                'mag_g': magnitude(flux['g']),
                'mag_r': magnitude(flux['r']),
                'mag_z': magnitude(flux['z']),
                'nobs_g': nobs['g'], 'nobs_r': nobs['r'], 'nobs_z': nobs['z'],
            }
            objects.append(object_dict)
            counter += 1
            goodobj += 1

        print('got {} good objects from {}'.format(goodobj, f))
        if enough_objects():
            print('enough objects gathered.')
            break
        if total_fails >= max_total_fails:
            break

    # check to make sure there were enough objects found in the folder of files
    # (skipped for n_objects='all': comparing an int with a str raises TypeError)
    if n_objects != 'all' and len(objects) < n_objects:
        print('\tonly able to retrieve {} good objects from folder {}'.format(len(objects), t_folder))

    # now create csv file to match fits file images, imitating the lensfinder challenge format
    print('writing classification csv file...')
    # newline='' is what the csv module expects; without it some platforms get blank rows
    with open(csvfile + '_' + t_folder + '.csv', 'w', newline='') as myFile:
        # NOTE: can't get fields to have double quotes (e.g. "ID") in the csv file. can only get none and """ID"""
        # not sure if this will affect how it works (lensfinder challenge csv headers have double quotes)
        myFields = ["ID", "is_lens", "Einstein_area", "numb_pix_lensed_image", "flux_lensed_image_in_sigma"]
        writer = csv.DictWriter(myFile, fieldnames=myFields)
        writer.writeheader()
        for object_id in range(counter_init, counter_init + len(objects)):
            # might want to change the ID for the training process if it makes it easier to iterate
            writer.writerow({"ID": '{:06d}'.format(object_id), "is_lens": 0, "Einstein_area": 'nan',
                             "numb_pix_lensed_image": 'nan', "flux_lensed_image_in_sigma": 'nan'})

    # write csv file to contain information from object dictionaries
    print('writing object info csv file...')
    with open(objectcsvfile + '_' + t_folder + '.csv', 'w', newline='') as oFile:
        # omitting 'image' bc it's a lot to put in the csv file and we already have it
        oFields = ['brickname', 'objid', 'filename', 'ra', 'dec', 'flux_g', 'flux_r', 'flux_z',
                   'mag_g', 'mag_r', 'mag_z', 'nobs_g', 'nobs_r', 'nobs_z']
        writer = csv.DictWriter(oFile, fieldnames=oFields)
        writer.writeheader()
        for o in objects:
            writer.writerow({field: o[field] for field in oFields})

    print('done.')
    return objects

In [None]:
# Collect 150 objects (at least 4 passes per band) from Tractor folder '000'.
test1 = download_Tractor2(
    csvfile='classifications',
    objectcsvfile='objectinfo',
    t_folder='000',
    n_objects=150,
    min_passes=4,
    counter_init=0,
)

In [None]:
# this function is for Chris to use on CMU DeepLens

# opens csv file made in download_Tractor, also uses object dictionary from the above cell to create an hdf5 file with
# the csv info and image array
def make_hdf5(csvfile, objects, imgs, export_path=''):
    """Bundle a classification CSV and the matching cutout images into an HDF5 catalog.

    Reads ``csvfile`` into a table, attaches the ``'image'`` arrays of
    ``objects`` as an extra column, writes the table under the HDF5 path
    '/ground' of ``<export_path>catalogs_<imgs>.hdf5``, then reads the file
    back and prints the shapes of the image and label arrays as a check.

    Parameters
    ----------
    csvfile : str
        Full path of the classification CSV to read.
    objects : iterable of dict
        Object dictionaries whose 'image' entry can be stored in a
        (3, 101, 101) slot, in the same order as the CSV rows.
    imgs : int
        Maximum number of catalog rows / images to write; also used in the
        output file name.
    export_path : str
        Prefix prepended to the output file name (include the trailing
        separator). The default writes into the current directory.

    Returns
    -------
    None
    """
    # Imported before first use: the import used to sit at the bottom of the
    # function, which made `Table` an unbound local on the first line.
    from astropy.table import Table

    cat = Table.read(csvfile)
    if imgs < len(cat):
        cat = cat[0:imgs]

    # One (3, 101, 101) slot per catalog row. Sizing by the catalog (not by
    # `imgs`) keeps the column assignment valid when the CSV has fewer rows.
    n_rows = len(cat)
    ims = np.zeros((n_rows, 3, 101, 101))

    # Loads the images from the object dictionaries; rows without a matching
    # object keep their all-zero image.
    for i, o in zip(range(n_rows), objects):
        ims[i] = o['image']

    # Concatenate images to catalog
    cat['image'] = ims

    # Export catalog as HDF5; the same path is used for the read-back below
    hdf5_file = export_path + 'catalogs_' + str(imgs) + '.hdf5'
    cat.write(hdf5_file, path='/ground', append=True)

    # Read the table back and show the array shapes
    d = Table.read(hdf5_file, path='/ground')
    x = np.array(d['image']).reshape((-1, 3, 101, 101))
    print(x.shape)
    y = np.array(d['is_lens']).reshape((-1, 1))
    print(y.shape)


#### Also, in order to run this code on other computers, tractor_folders/, tractor_fits/, and fits_cutouts/ must be created and their paths added to the notebook

In [None]:
# just plotting images
# `test` is never defined in this notebook; `test1` (the list of object
# dictionaries returned by the download cell) is used instead so the cell
# runs on a fresh kernel. Swap in test2 / test3 to look at the other runs.
for obj in test1:
    image = obj['image']
    # one row of three panels, one per leading-axis slice of the cutout
    fig, axes = plt.subplots(1, 3, figsize=(10, 7))
    fig.suptitle('IMAGES FOR OBJECT AT RA={}, DEC={}'.format(obj['ra'], obj['dec']))
    for ax, band_image in zip(axes, image[:3]):
        ax.imshow(band_image)
    plt.show()
    plt.close(fig)  # avoid piling up open figures when there are many objects
    
    

In [None]:
# Show the array shape of `image`, i.e. of the last cutout handled by the
# plotting loop in the previous cell.
image.shape

In [None]:
# `test` is never defined in this notebook; `test1` (the list of object
# dictionaries returned by the download cell) is used instead so the cell
# runs on a fresh kernel.
for obj in test1:
    image = obj['image']
    # .T reverses the axis order, so the leading axis becomes the last one and
    # imshow can treat it as a colour channel
    # (assumes a 3-D cutout array with 3 leading slices - TODO confirm)
    plt.imshow(image.T)
    plt.show()