# Author: Trevor Dorn-Wallenstein
# 11/15/17
# Design survey fields for $h+\chi$ Persei that cover as many stars as possible from a given set.

In [1]:
import numpy as np, astropy.io.fits as fits, matplotlib.pyplot as plt
from astropy.table import Table
from scipy.optimize import minimize
import emcee as mc
from survey_tools import *
%matplotlib inline

In [2]:
# Load the cluster-member catalogue; the context manager closes the FITS file handle
# instead of leaving it open for the rest of the session.
with fits.open('cluster_members.fits') as member_hdu:
    # NOTE(review): relies on astropy Table's default copy=True so the table stays valid
    # after the file is closed — confirm if Table construction options are changed.
    member_table = Table(member_hdu[1].data)
# Keep stars with SpT <= 20 (the OB_ prefix suggests this selects O/B types — confirm the SpT encoding).
OB_table = member_table[member_table['SpT'] <= 20]

  return getattr(self.data, oper)(other)


In [3]:
# One survey_tools Star per OB table row, built from its RAJ2000/DEJ2000 columns.
OB_list = list(map(Star, OB_table['RAJ2000'], OB_table['DEJ2000']))

In [4]:
def make_survey_and_score(theta,args):
    """
    Wrapper for score_survey. Builds a new field list from theta, then scores the
    stars in star_list against it.

    Parameters
    ----------
    theta : sequence of float
        Sequence of length 2 * N_fields: the first N_fields entries are field RAs,
        the next N_fields entries are field Decs.
    args : tuple
        (star_list, field_size, overlap_bonus): the list of Star objects in your
        survey, the size of the camera's field of view in arcminutes, and the
        overlap bonus passed on to score_survey.

    Returns
    -------
    survey_score : float
        Score of this particular survey, as returned by score_survey.

    Raises
    ------
    AssertionError
        If len(theta) is odd. Raised explicitly rather than via an ``assert``
        statement so the check is not stripped when Python runs with ``-O``.
    """
    if len(theta) % 2 != 0:
        raise AssertionError("Length of theta should be N_fields x 2!")

    n_fields = len(theta) // 2
    ras = theta[:n_fields]
    decs = theta[n_fields:]

    star_list = args[0]
    field_size = args[1]
    overlap_bonus = args[2]

    field_list = [Field(ra, dec, field_size) for ra, dec in zip(ras, decs)]

    return score_survey(star_list, field_list, overlap_bonus)

In [5]:
def f_min(theta,args):
    """
    Objective for a minimizer: the survey score with its sign flipped.

    Parameters
    ----------
    theta, args :
        Forwarded unchanged to make_survey_and_score.

    Returns
    -------
    float
        -1 times the survey score, so minimizing this value maximizes the score.
    """
    survey_score = make_survey_and_score(theta, args)
    return -1 * survey_score

In [6]:
def design_survey(star_list, N_fields, N_start = None, overlap_bonus = 0.1, field_size = 3.0, verbose = True):
    """
    Designs a survey by progressively eliminating the field that contributes to the
    survey value the least (greedy backward elimination).

    Parameters
    ----------
    star_list : list
        list of Star objects
    N_fields : int
        The target number of fields in the survey. Must satisfy
        0 <= N_fields <= number of initial fields.
    N_start : int, optional
        If given, starts the survey with N_start fields from
        Field.from_star_list in its default mode (presumably random draws from
        star_list — confirm in survey_tools). If None, starts with one field per
        star (random=False, i=j for each star index j).
    overlap_bonus : float
        bonus you want to give to stars that appear in multiple fields
    field_size : float
        Size passed to Field.from_star_list (default 3.0, the previously
        hard-coded value).
    verbose : bool
        If True (default), print the number of remaining fields after each
        elimination step.

    Returns
    -------
    field_list : numpy.ndarray
        array of the Field objects that remain.

    Raises
    ------
    ValueError
        If N_fields is negative or larger than the number of initial fields
        (previously this only failed after deleting every field).
    """
    n_initial = N_start if N_start is not None else len(star_list)
    if not 0 <= N_fields <= n_initial:
        raise ValueError(
            "N_fields must be between 0 and the number of initial fields "
            "({}), got {}".format(n_initial, N_fields))

    #Initialize fields from star list
    if N_start is not None:
        field_list = np.array([Field.from_star_list(star_list=star_list,size=field_size) for _ in range(N_start)])
    else:
        field_list = np.array([Field.from_star_list(star_list=star_list,size=field_size,random=False,i=j) for j in range(len(star_list))])

    while len(field_list) > N_fields:

        scores = []

        for i in range(len(field_list)):
            # Drop exactly the field at position i. Filtering by equality would also
            # drop any duplicate fields that compare equal to it.
            trial_list = np.delete(field_list, i)
            #Score the survey without that field
            scores.append(score_survey(star_list, trial_list, overlap_bonus))

        # The highest score without a field means that field contributes least; remove it.
        field_list = np.delete(field_list, np.argmax(scores))
        if verbose:
            print(len(field_list))

    return field_list

In [7]:
# Starting field configuration from survey_tools.initialize_fields.
# NOTE(review): presumably (n_fields=30, field size=3.0, stars=OB_list) — confirm the argument order.
initial_list = initialize_fields(30, 3.0, OB_list)

In [18]:
def lnprob(theta, field_size=3.0, overlap_bonus=0.1):
    """
    Log-probability of a field layout: the natural log of its survey score over OB_list.

    Presumably used as the emcee log-probability (emcee is imported as mc).

    Parameters
    ----------
    theta : array-like
        Length 2 * N_fields: the first half are field RAs, the second half field Decs.
        The split is derived from len(theta) rather than hard-coded to 30 fields.
    field_size : float
        Size passed to each Field (default 3.0, the previously hard-coded value).
    overlap_bonus : float
        Overlap bonus passed to score_survey (default 0.1, the previously hard-coded value).

    Returns
    -------
    float
        log(score), or -inf when score <= 0. The log is undefined there; -inf lets a
        sampler reject the point instead of receiving nan or a RuntimeWarning.
    """
    n_fields = len(theta) // 2
    ras = theta[:n_fields]
    decs = theta[n_fields:]

    field_list = [Field(ra, dec, field_size) for ra, dec in zip(ras, decs)]

    score = score_survey(OB_list, field_list, overlap_bonus)

    if score <= 0:
        return -np.inf
    return np.log(score)

In [19]:
# One parameter per field RA plus one per field Dec. This was hard-coded to 80, which did
# not match the length of p0 (2 * len(initial_list)).
ndim, nwalkers = 2 * len(initial_list), 800

# Starting parameter vector in the layout lnprob expects: all RAs, then all Decs.
p0 = np.array([field.ra for field in initial_list] + [field.dec for field in initial_list])

print(lnprob(p0))

6.67152648812


In [None]:
# DS9 region-file header: format line, global display settings, then the 'icrs' coordinate-system line.
start_string = (
    '# Region file format: DS9 version 4.1 \n'
    'global color=green dashlist=8 3 width=1 font="helvetica 10 normal roman" '
    'select=1 highlite=1 dash=0 fixed=0 edit=1 move=1 delete=1 include=1 source=1 \n'
    'icrs'
)

# One region line per field, separated by ' \n'. A single join replaces repeated += on a
# growing string; the output is the same.
# NOTE(review): the_survey is not assigned in the visible cells; its first element is
# presumably the list of Field objects — confirm.
region_lines = [start_string]
for field in the_survey[0]:
    region_lines.append(field.to_region_string())
start_string = ' \n'.join(region_lines)

with open('field_reg.reg', 'w', encoding='utf-8') as f:
    f.write(start_string)

start_string

In [None]:
# Display the survey result in the notebook.
# NOTE(review): the_survey is not assigned in the visible cells — confirm where it is defined.
the_survey