# X-class Day 3: Prototyping Source Classification with Random Forest

Using RandomForestClassifier from scikit-learn to train a classifier on photon event metadata.

Input files: 
 - event list 
 - region files for training sets (bkg and src)
 - region files for testing sets (bkg and src)

In [1]:
# %load class2.py
import pandas as pd
import numpy as np
import copy
#import seaborn as sns
import matplotlib as mpl
import matplotlib.cm as cm
from astropy.table import Table, vstack
from sklearn.ensemble import RandomForestClassifier

In [9]:
def parse_region_file(reg_file):
    """Parse a region file into arrays of centre coordinates and radii.

    Expected region file syntax::

        # Region file format: CIAO version 1.0
        circle(3448.493,4095.1661,20.325203)
        circle(3811.8263,4301.8327,20.325203)
        ...

    i.e. ``generic_shape(x,y,radius)``, one shape per line.  Lines whose
    first character is ``#``, a space or a newline are skipped.

    Parameters
    ----------
    reg_file : str
        Path of the region file to read.

    Returns
    -------
    x, y, r : numpy.ndarray
        Centre x positions, centre y positions and radii, in file order.
        All three arrays are empty when the file holds no shape lines.

    Raises
    ------
    IndexError, ValueError
        If a non-skipped line does not match ``shape(x,y,radius)``.
    """
    x, y, r = [], [], []
    # The context manager closes the handle even when a malformed line
    # raises part-way through the file.
    with open(reg_file, 'r') as f:
        for line in f:
            if line[0] in ('#', '\n', ' '):
                continue
            # Keep only the text between the first '(' and the next ')'.
            xt, yt, rt = line.split('(')[1].split(')')[0].split(',')
            x.append(float(xt))
            y.append(float(yt))
            r.append(float(rt))
    return np.asarray(x), np.asarray(y), np.asarray(r)

def get_events(evt, x, y, r=10.0):
    """Select the events around a position and express them as offsets.

    Events of ``evt`` lying within a distance ``r`` of the point (x, y),
    boundary included, are copied into a new event list.  The copy gains
    ``xoff`` and ``yoff`` columns holding the position relative to (x, y),
    and its absolute ``x`` and ``y`` columns are dropped.  ``evt`` itself
    keeps its columns.
    """
    dx = evt['x'] - x
    dy = evt['y'] - y
    inside = np.sqrt(dx**2 + dy**2) <= r

    selected = copy.copy(evt[inside])
    # Add both offset columns first, then remove the absolute ones.
    for axis, centre in (('x', x), ('y', y)):
        selected[axis + 'off'] = selected[axis] - centre
    for axis in ('x', 'y'):
        del selected[axis]
    return selected

def merge_pos(evt,x,y,lab,r=10.0):
    """Merge the events of several regions of interest with their labels.

    The full event list is queried once per (x[i], y[i]) position with
    ``get_events`` and the per-region event lists are stacked into a single
    table.  Every event inherits the label of the region it came from.

    Parameters
    ----------
    evt : table
        Full event list, passed unchanged to ``get_events``.
    x, y : sequence of float
        Region centre positions; one region per entry of ``x``.
    lab : sequence of float
        Label of each region.
    r : float, optional
        Selection radius forwarded to ``get_events``.

    Returns
    -------
    mevt : table
        Stacked events of all regions, in region order.
    levt : numpy.ndarray
        Float label of every row of ``mevt``.

    Raises
    ------
    ValueError
        If ``x`` is empty, or ``y`` / ``lab`` hold fewer entries than ``x``.
    """
    n_regions = len(x)
    if n_regions == 0:
        raise ValueError("merge_pos needs at least one region position")
    if len(y) < n_regions or len(lab) < n_regions:
        raise ValueError("y and lab must provide one entry per x position")

    tables, labels = [], []
    # zip stops after len(x) regions; surplus y/lab entries are ignored.
    for xi, yi, li in zip(x, y, lab):
        region_evt = get_events(evt, xi, yi, r)
        region_lab = np.zeros(len(region_evt))
        region_lab[:] = li
        tables.append(region_evt)
        labels.append(region_lab)
    # Stack once at the end rather than re-copying the growing table on
    # every iteration.
    return vstack(tables), np.hstack(labels)

def build_rfc(evt, lab, rfc=None):
    """Fit a random forest classifier to a labelled event list.

    When ``rfc`` is None a RandomForestClassifier with 200 estimators and
    out-of-bag scoring enabled is created; otherwise the supplied classifier
    is fitted.  The events are converted with ``evt.to_pandas()`` and the
    labels are copied before fitting.

    Returns the fitted classifier, the feature matrix and the label copy.
    """
    classifier = rfc
    if classifier is None:
        classifier = RandomForestClassifier(n_estimators=200, oob_score=True)

    frame = copy.copy(evt.to_pandas())
    Y = copy.copy(lab)
    classifier.fit(frame.values, Y)
    return classifier, frame.values, Y

def do_rfc(evt,rfc):
    """Classify an event list and print a one-line class summary.

    The events are converted with ``evt.to_pandas()`` and passed to
    ``rfc.predict``.  One line is printed with the percentage of events
    predicted as class 0, the percentage predicted as class 1 and, in
    parentheses, the number of events.

    Parameters
    ----------
    evt : table
        Event list providing ``to_pandas()``.
    rfc : classifier
        Fitted classifier providing ``predict``.

    Returns
    -------
    Y : array
        Predicted label of every event, as returned by ``rfc.predict``.
    """
    X = copy.copy(evt.to_pandas())
    Y = rfc.predict(X.values)
    n_evt = len(Y)
    if n_evt:
        pct_bg = 100. * float(np.count_nonzero(Y == 0)) / n_evt
        pct_src = 100. * float(np.count_nonzero(Y == 1)) / n_evt
    else:
        # No predictions: report zeros rather than dividing by zero.
        pct_bg = pct_src = 0.0
    # print(...) with a single argument behaves the same on Python 2 and 3.
    print("{0:0.1f} {1:0.1f} ({2})".format(pct_bg, pct_src, n_evt))
    return Y

In [10]:
# Load the complete event list from HDU 1 of the event file.
e = Table.read('Data/evt_1229.fits', hdu=1)

# Work on a copy (ec) and strip the columns that should not be used as
# classifier features; the table as read stays available as e.
ec = copy.copy(e)
badcols = [
    'status', 'ccd_id', 'expno', 'node_id',
    'chipx', 'chipy', 'tdetx', 'tdety', 'detx', 'dety',
    'pi', 'pha',
]
for bc in badcols:
    del ec[bc]

In [11]:
# Region positions and radii: b* are background regions, src* are source
# regions; set 1 and set 2 are kept in separate files.
b1x, b1y, b1r = parse_region_file('Data/b1_1229.reg')
b2x, b2y, b2r = parse_region_file('Data/b2_1229.reg')
s1x, s1y, s1r = parse_region_file('Data/src1_1229.reg')
s2x, s2y, s2r = parse_region_file('Data/src2_1229.reg')

In [13]:
# Use b1 and s1 as the training set; b2 and s2 are kept for testing.
# Sources come first (label 1), followed by backgrounds (label 0).
trnx = np.hstack((s1x, b1x))
trny = np.hstack((s1y, b1y))
trnl = np.hstack((np.ones(len(s1x)), np.zeros(len(b1x))))

In [14]:
# Gather the events of every training region together with per-event labels.
trne, trnlab = merge_pos(ec, trnx, trny, trnl)

In [15]:
# Fit the classifier on the training events.
rfc, X, Y = build_rfc(trne, trnlab)

# print(...) with a single argument behaves the same on Python 2 and 3.
print("OOB Score: {0}".format(rfc.oob_score_))
# Feature importances paired with their column names, most important first.
print(sorted(zip(trne.colnames, rfc.feature_importances_),
             key=lambda q: q[1], reverse=True))

OOB Score: 0.681998864282
[('pha_ro', 0.2153833339454897), ('energy', 0.20705415666650728), ('yoff', 0.17866798299748449), ('xoff', 0.16223802147515634), ('time', 0.13131316254200606), ('fltgrade', 0.062291237039190282), ('grade', 0.043052105334165952)]


In [16]:
# Classify the events around each test source position (s2).
# print(...) with a single argument behaves the same on Python 2 and 3.
print("Source Tests")
print("BG%  S%   (N)")
for xi, yi in zip(s2x, s2y):
    lp = do_rfc(get_events(ec, xi, yi), rfc)

Source Tests
BG%  S%   (N)
31.7 68.3 (120)
37.5 62.5 (128)
19.2 80.8 (151)
32.5 67.5 (154)
39.6 60.4 (101)
23.7 76.3 (198)


In [17]:
# Classify the events around each test background position (b2).
# print(...) with a single argument behaves the same on Python 2 and 3.
print("Background Tests")
print("BG%  S%   (N)")
for xi, yi in zip(b2x, b2y):
    lp = do_rfc(get_events(ec, xi, yi), rfc)

Background Tests
BG%  S%   (N)
63.3 36.7 (49)
70.0 30.0 (60)
69.6 30.4 (79)
67.2 32.8 (61)
60.5 39.5 (76)
73.0 27.0 (63)
57.6 42.4 (59)
71.2 28.8 (52)


In [18]:
# Classify the events around one hand-picked position.
lp = do_rfc(
    get_events(ec, 4100.333, 4130.125),
    rfc,
)

54.4 45.6 (90)


The classifier works well for high-count sources, but is only so-so for low-count sources.

Todo: 
 - determine dependency on radius of region considered 
 - mess with RFC parameters 
 - run on a list of star positions (2MASS/USNO) 