In [226]:
import csv
from collections import defaultdict
import re
import itertools
import math
import numpy as np
from sklearn import model_selection, ensemble, preprocessing, tree, metrics, externals

In [7]:
def iterate_rows():
    with open('simbadresult.csv') as infile:
        reader = csv.DictReader(infile)
        for row in reader:
            yield row

In [146]:
unique_objects = set()
all_filters = set()
unique_sptypes = set()

# sptype_regex = re.compile(r'[OBAFGKM][0-9](I|II|III|IV|V)')
# For now only handle main sequence stars
sptype_regex = re.compile(r'(?P<typ>[OBAFGKM])(?P<cls>[0-9](\.[0-9]+)?)V')

sp_type_mapping = {value: index for (index, value) in enumerate('OBAFGKM')}

def parse_sptype(sptype):
    match = sptype_regex.match(sptype)
    if match:
        return match.group(0)

def sptype_float(sptype):
    match = sptype_regex.match(sptype)
    if match:
        int_component = sp_type_mapping[match.group('typ')]
        decimal_component = float(match.group('cls')) / 9.
        return int_component + decimal_component
        
for i, row in enumerate(iterate_rows()):
    unique_objects.add(row['main_id'])
    all_filters.add(row['filter'])
    sptype = parse_sptype(row['sp_type'])
    if sptype is None:
        continue
    unique_sptypes.add(sptype)

In [147]:
all_filters = [item for item in sorted(all_filters) if item.isupper()]
print("{} unique objects".format(len(unique_objects)))
print('Including {} filters: {}'.format(len(all_filters), all_filters ))

534451 unique objects
Including 8 filters: ['B', 'H', 'I', 'J', 'K', 'R', 'U', 'V']


In [148]:
filter_ordering = ['U', 'u', 'B', 'V', 'g', 'R', 'r', 'I', 'i', 'z', 'J', 'H', 'K']
all_filters.sort(key=lambda f: filter_ordering.index(f))
print(all_filters)

['U', 'B', 'V', 'R', 'I', 'J', 'H', 'K']


In [149]:
valid_colours = []
for start_band, end_band in itertools.product(all_filters, all_filters):
    if start_band.lower() == end_band.lower():
        continue
        
    if filter_ordering.index(start_band) >= filter_ordering.index(end_band):
        continue

    valid_colours.append((start_band, end_band))

In [150]:
rows = {}
for i, row in enumerate(iterate_rows()):
    sp_type = parse_sptype(row['sp_type'])
    if not sp_type:
        continue
        
    mag_label = row['filter']
    mag_value = float(row['flux'])
    obj_id = row['main_id']
        
    if obj_id in rows:
        if 'sp_type' not in rows[obj_id]:
            rows[obj_id]['sp_type'] = sp_type
        rows[obj_id][mag_label] = mag_value
    else:
        rows[obj_id] = {'sp_type': sp_type}
        for filt in all_filters:
            rows[obj_id][filt] = float('nan')
        rows[obj_id][mag_label] = mag_value

rows = list(rows.values())
print(rows[0])

{'R': nan, 'r': 18.997, 'g': 20.265, 'i': 18.508, 'U': nan, 'I': nan, 'K': nan, 'J': nan, 'u': 22.652, 'sp_type': 'M0V', 'V': nan, 'B': nan, 'H': nan, 'z': 18.124}


In [151]:
X, y = [], []
for row in rows:
    entry = []
    for start_band, end_band in valid_colours:
        colour_value = row[start_band] - row[end_band]
        entry.append(colour_value)
    if all(math.isnan(value) for value in entry):
        continue
    X.append(entry)
    y.append(row['sp_type'])
X, y = [np.array(data) for data in [X, y]]

Get rid of the nan values

In [200]:
imp = preprocessing.Imputer()
X_valid = imp.fit_transform(X)

In [198]:
def sptype_error(a, b):
    a = np.array([sptype_float(val) for val in a])
    b = np.array([sptype_float(val) for val in b])
    
    return np.sum((a - b) ** 2)

sptype_score = metrics.make_scorer(sptype_error, greater_is_better=False)

In [221]:
params = {
    'n_estimators': [3, 10, 50],
}
clf = model_selection.GridSearchCV(ensemble.RandomForestClassifier(), params,
                                   scoring=sptype_score, n_jobs=-1)
clf.fit(X_valid, y)



GridSearchCV(cv=None, error_score='raise',
       estimator=RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            n_estimators=10, n_jobs=1, oob_score=False, random_state=None,
            verbose=0, warm_start=False),
       fit_params={}, iid=True, n_jobs=-1,
       param_grid={'n_estimators': [3, 10, 50]}, pre_dispatch='2*n_jobs',
       refit=True, return_train_score=True,
       scoring=make_scorer(sptype_error, greater_is_better=False),
       verbose=0)

In [222]:
clf_best = clf.best_estimator_
X_train, X_test, y_train, y_test = model_selection.train_test_split(X_valid, y)
clf_best.fit(X_train, y_train)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            n_estimators=50, n_jobs=1, oob_score=False, random_state=None,
            verbose=0, warm_start=False)

In [223]:
prediction = clf_best.predict(X_test).ravel()
total_error = sptype_error(prediction, y_test)
print(total_error)

4578.67237654


In [224]:
list(zip(prediction, y_test))[:10]

[('F3V', 'F5V'),
 ('A0V', 'A0V'),
 ('F5V', 'K0V'),
 ('F6V', 'F6V'),
 ('A8V', 'B9V'),
 ('F0V', 'F2V'),
 ('A1V', 'B9V'),
 ('A2V', 'A2V'),
 ('M3V', 'M6V'),
 ('G0V', 'G0V')]