In [None]:
import sys
sys.path.insert(0, '..')

from paus_utils import w_central, z_NB

from jpasLAEs.utils import flux_to_mag

import pickle

import numpy as np

from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn import model_selection
from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 12})

In [None]:
# Field to process and directory holding the pre-computed mock catalogues
field_name = 'W3'
savedir = '/home/alberto/almacen/PAUS_data/LF_corrections'

# These two numbers only enter the mock file name below
# (presumably the range of NB indices covered by the mocks -- TODO confirm)
nb_min, nb_max = 0, 18

# NOTE(review): pickle.load can execute arbitrary code; only load trusted files
mock_path = f'{savedir}/mock_dict_{field_name}_nb{nb_min}-{nb_max}.pkl'
with open(mock_path, 'rb') as f:
    mock_dict = pickle.load(f)

# Discard the 'SFG' mock (raises KeyError if the file does not contain it)
mock_dict.pop('SFG')
# mock_dict.pop('GAL_artifact')

# print(mock_dict['GAL_artifact'].keys())
print(mock_dict.keys())

In [None]:
# The mock with the fewest usable candidates fixes how many objects are drawn
# from every mock (set_len).
N_candidates_list = []
for mock_name, mock in mock_dict.items():
    # Mask of sources whose NB-derived redshift is within 0.12 of zspec
    z_phot = z_NB(mock['lya_NB'])
    nice_z = np.abs(mock['zspec'] - z_phot) < 0.12

    # For the two QSO LAE mocks only candidates passing the redshift cut count
    candidates = mock['nice_lya_0']
    if mock_name in ['QSO_LAEs_loL', 'QSO_LAEs_hiL']:
        candidates = candidates[nice_z]
    N_candidates_list.append(sum(candidates))

set_len = np.min(N_candidates_list)
print(N_candidates_list)
print(f'{set_len=}')

In [None]:
# Make the set for each class.
# `set_len` Lya candidates are drawn from every mock; the per-mock pieces are
# collected in lists and stacked once after the loop.
labels = None  # class labels are assigned after the feature set is built

feature_blocks = []
rmag_blocks = []
zspec_blocks = []
zphot_blocks = []
L_blocks = []
nice_z_list = []

for mock_name, mock in mock_dict.items():
    mock_len = len(mock['zspec'])
    nice_lya = mock['nice_lya_0']

    # Redshift from the selected NB, computed once per mock and reused below
    z_phot = z_NB(mock['lya_NB'])
    nice_z = np.abs(np.array(mock['zspec']) - z_phot) < 0.12

    # The seed is reset before every draw, so each mock's selection is
    # reproducible on its own
    np.random.seed(299792458)
    selection = np.random.choice(np.arange(mock_len)[nice_lya], set_len,
                                 replace=False)

    # Feature columns: 0-39 NB fluxes, 40 lya_NB, 41 r_mag, 42-46 BB fluxes
    feature_blocks.append(np.hstack([
        mock['flx'][:40, selection].T * 1e17, # NBs
        mock['lya_NB'][selection].reshape(-1, 1),
        mock['r_mag'][selection].reshape(-1, 1),
        mock['flx'][40:45, selection].T * 1e17, # BBs
    ]))

    # Auxiliary per-source quantities, kept in the same row order as the features
    rmag_blocks.append(flux_to_mag(mock['flx'][-4, selection], w_central[-4]))
    zspec_blocks.append(mock['zspec'][selection])
    L_blocks.append(mock['L_lya'][selection])
    zphot_blocks.append(z_phot[selection])

    nice_z_list.append(nice_z[selection])

tt_set = np.vstack(feature_blocks)
rmag = np.concatenate(rmag_blocks)
zspec = np.concatenate(zspec_blocks)
L_Arr = np.concatenate(L_blocks)
zphot = np.concatenate(zphot_blocks)


# Class label given to each mock, following the iteration order of mock_dict
label_names = []
labels_list = [1, 2, 2, 4]
labels_per_mock = []
for j, mock_name in enumerate(mock_dict):
    i = labels_list[j]
    print(f'{i} for {mock_name}')

    this_labels = np.full(set_len, i, dtype=int)

    # QSO LAE sources that fail the redshift cut are relabelled as class 1
    if mock_name in ['QSO_LAEs_loL', 'QSO_LAEs_hiL']:
        this_labels[~nice_z_list[j]] = 1

    labels_per_mock.append(this_labels)
    label_names.append(mock_name)
labels = np.concatenate(labels_per_mock)
label_names.append('?')


# Number of sources carrying each label
print()
for lb in [1, 2, 4, 5]:
    print(f'{lb}: {sum(labels == lb)}')

In [None]:
# Train/Test split
split_seed = 299792458
x_train, x_test, y_train, y_test =\
    model_selection.train_test_split(tt_set, labels, test_size=0.2,
                                     random_state=split_seed)


def _normalize_features(features):
    """Normalize a feature matrix in place and return it.

    Column layout (as assembled in tt_set): 0-39 NB fluxes, 40 lya_NB,
    41 r_mag, 42-46 BB fluxes.

    - NB fluxes are divided by their per-source sum.
    - BB fluxes are divided by their per-source sum.
    - The lya_NB column is divided by 100.
    """
    features[:, :40] /= np.sum(features[:, :40], axis=1, keepdims=True)
    features[:, 42:47] /= np.sum(features[:, 42:47], axis=1, keepdims=True)
    # Column 40, not row 40: `features[40]` would rescale one whole sample
    # and leave the lya_NB column untouched.
    features[:, 40] /= 100.
    return features


## Pre-processing ##
# NOTE(review): any code applying the saved classifier must use this same
# pre-processing.
_normalize_features(x_train)
_normalize_features(x_test)


## Scaler
scaler = MinMaxScaler()
# The slice [:, :47] covers all 47 feature columns (fluxes, lya_NB and r_mag).
# The scaler is fitted on the training set only and then applied to both sets.
scaler.fit(x_train[:, :47])
x_train[:, :47] = scaler.transform(x_train[:, :47])
x_test[:, :47] = scaler.transform(x_test[:, :47])

# # PCA
# pca = PCA(n_components=0.99, svd_solver='full')

# pca.fit(x_train)
# x_train = pca.transform(x_train)
# x_test = pca.transform(x_test)

print(x_train.shape)

In [None]:
def do_grid_search(algorithm, search_mode='random'):
    """Search the hyper-parameters of a classifier with 3-fold cross-validation.

    The search is fitted on the notebook-level ``x_train`` / ``y_train``
    arrays, which must exist before this function is called.

    Parameters
    ----------
    algorithm : {'nn', 'rf'}
        'nn' tunes an ``MLPClassifier``, 'rf' a ``RandomForestClassifier``.
    search_mode : {'random', 'grid'}, optional
        'grid' evaluates every combination with ``GridSearchCV``; 'random'
        samples combinations with ``RandomizedSearchCV`` (default settings).

    Returns
    -------
    dict
        The ``best_params_`` of the fitted search.

    Raises
    ------
    ValueError
        If ``algorithm`` or ``search_mode`` is not one of the accepted values.
    """
    # Validate both options first so a typo fails before anything is built
    if algorithm not in ('nn', 'rf'):
        raise ValueError(
            f"Model not known: {algorithm!r} (expected 'nn' or 'rf')")
    if search_mode not in ('grid', 'random'):
        raise ValueError(
            f"Unknown search_mode: {search_mode!r} "
            "(expected 'grid' or 'random')")

    # Candidate hyper-parameter values and the base estimator to tune
    if algorithm == 'nn':
        param_grid = {
            'hidden_layer_sizes': [(60, 60), (60, 60, 60),
                                   (50, 50, 20), (40, 40, 20),
                                   (40, 30, 15)],
            'solver': ['adam'],
            'alpha': [1e-2, 1e-3, 1e-4],
            'batch_size': [50, 100, 250, 300],
            'learning_rate': ['adaptive', 'constant'],
            'max_iter': [10000]
        }
        model = MLPClassifier()
    else:
        param_grid = {
            'random_state': [22],
            'n_estimators': [1000],
            'bootstrap': [True, False],
            'max_depth': [20, 50, 70, 100],
            'min_samples_split': [2, 5, 10, 20],
            'min_samples_leaf': [1, 2, 4]
        }
        model = RandomForestClassifier()

    # Settings shared by both search strategies
    search_kwargs = dict(cv=3, n_jobs=-1, pre_dispatch='2*n_jobs', verbose=3)
    if search_mode == 'grid':
        grid_search = GridSearchCV(
            estimator=model, param_grid=param_grid, **search_kwargs)
    else:
        grid_search = RandomizedSearchCV(
            estimator=model, param_distributions=param_grid, **search_kwargs)

    grid_search.fit(x_train, y_train)

    return grid_search.best_params_

# Classifier to tune ('nn' or 'rf') and search strategy ('grid' or 'random')
model, search_mode = 'nn', 'grid'

best_params = do_grid_search(model, search_mode=search_mode)
# Hard-coded parameter sets that can replace the search above:
# if model == 'nn':
#     best_params = {'activation': 'relu',
#     'alpha': 0.001,
#     'batch_size': 50,
#     'beta_1': 0.9,
#     'beta_2': 0.999,
#     'early_stopping': False,
#     'epsilon': 1e-08,
#     'hidden_layer_sizes': (60, 60, 60),
#     'learning_rate': 'constant',
#     'learning_rate_init': 0.001,
#     'max_fun': 15000,
#     'max_iter': 10000,
#     'momentum': 0.9,
#     'n_iter_no_change': 10,
#     'nesterovs_momentum': True,
#     'power_t': 0.5,
#     'random_state': None,
#     'shuffle': True,
#     'solver': 'adam',
#     'tol': 0.0001,
#     'validation_fraction': 0.1,
#     'verbose': False,
#     'warm_start': False}
# elif model == 'rf':
#     best_params = {'random_state': 22, 'n_estimators': 1000, 'min_samples_split': 5, 'min_samples_leaf': 2, 'max_depth': 20, 'bootstrap': False, 'n_jobs': -1}
print(best_params)

In [None]:
# Build the classifier selected by `model` with the chosen hyper-parameters
if model == 'nn':
    cl_best = MLPClassifier(**best_params)
elif model == 'rf':
    # Merge n_jobs into the dict: passing it as a separate keyword raises
    # TypeError when best_params already contains an 'n_jobs' entry.
    cl_best = RandomForestClassifier(**{**best_params, 'n_jobs': -1})
else:
    # Fail here rather than with a NameError on cl_best further down
    raise ValueError(f"Model not known: {model!r} (expected 'nn' or 'rf')")

# Fit on the training set and report the mean accuracy on both sets
cl_best.fit(x_train, y_train)
test_score = cl_best.score(x_test, y_test)
train_score = cl_best.score(x_train, y_train)
print(f'Score\n\nTrain: {train_score:0.3f}\nTest: {test_score:0.3f}')

In [None]:
# Predict test
pred_test = cl_best.predict(x_test)
log_p = cl_best.predict_log_proba(x_test)

# The columns of log_p follow the order of cl_best.classes_, so the column of
# each predicted label is looked up there instead of being hard-coded.
class_col = {label: col for col, label in enumerate(cl_best.classes_)}
pred_col = np.array([class_col[label] for label in pred_test], dtype=int)

# Log-probability assigned to the predicted class of every test source
class_log_p = log_p[np.arange(len(pred_test)), pred_col]

# Sources whose predicted class has probability below 1e-8 get the label 6
pred_test[class_log_p < np.log(0.00000001)] = 6

In [None]:
# Persist the fitted classifier and scaler so they can be loaded elsewhere
save_dir = '/home/alberto/almacen/PAUS_data/ML_classifier'
artifacts = {
    f'{save_dir}/source_classifier.sav': cl_best,
    # f'{save_dir}/source_pca.sav': pca,
    f'{save_dir}/source_scaler.sav': scaler,
}
for artifact_path, artifact in artifacts.items():
    with open(artifact_path, 'wb') as file:
        pickle.dump(artifact, file)

In [None]:
# Split the auxiliary per-source arrays with the same test fraction and seed
# as the feature split. All arrays are split in one call with identical indices.
(rmag_train, rmag_test,
 zspec_train, zspec_test,
 L_Arr_train, L_Arr_test,
 zphot_train, zphot_test) = model_selection.train_test_split(
    rmag, zspec, L_Arr, zphot,
    test_size=0.2, random_state=split_seed)

In [None]:
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Confusion matrix of the test sources with r < 23 and L_Arr >= 43.5,
# restricted to the three training labels
r_mask = (rmag_test < 23) & (L_Arr_test >= 43.5)
cm = confusion_matrix(y_test[r_mask], pred_test[r_mask],
                      labels=[1, 2, 4])

# Normalize every column (predicted label) to unit sum
cm = cm.astype('float') / cm.sum(axis=0, keepdims=True)

# Draw the normalized matrix as an annotated heatmap
label_names_cm = ['QSO cont.', 'QSO LAE', r'Low-$z$ galaxy']
heatmap_kwargs = dict(
    annot=True, cmap='Blues', fmt='.2f',
    xticklabels=label_names_cm, yticklabels=label_names_cm,
    cbar=False, annot_kws={"fontsize":17},
)
sns.heatmap(cm, **heatmap_kwargs)
plt.xlabel('Predicted Labels', fontsize=17)
plt.ylabel('True Labels', fontsize=17)

# Save the figure before showing it
plt.savefig('../figures/NN_class_confusion_matrix.pdf',
            bbox_inches='tight', pad_inches=0.1, facecolor='w')
plt.show(block=False)

In [None]:
# import seaborn as sns
# from sklearn.metrics import confusion_matrix

# # Compute confusion matrix
# r_mask = (rmag_test >= 23)
# cm = confusion_matrix(y_test[r_mask], pred_test[r_mask])

# # Plot confusion matrix
# cm = cm.astype('float') / cm.sum(axis=0)[np.newaxis, :]
# label_names_cm = ['QSO_cont', 'QSO_LAE', 'GAL', '?']
# sns.heatmap(cm, annot=True, cmap="Blues", fmt='.2f',
#             xticklabels=label_names_cm, yticklabels=label_names_cm,
#             cbar=False)
# plt.xlabel("Predicted Labels")
# plt.ylabel("True Labels")
# plt.title('r $\geq$ 22')
# plt.show(block=False)

In [None]:
# # Compute confusion matrix
# r_mask = (rmag_test > 22)
# cm = confusion_matrix(y_test[r_mask], pred_test[r_mask])

# # Plot confusion matrix
# cm = cm.astype('float') / cm.sum(axis=0)[np.newaxis, :]
# sns.heatmap(cm, annot=True, cmap="Blues", fmt='.2f',
#             xticklabels=label_names, yticklabels=label_names,
#             cbar=False)
# plt.xlabel("Predicted Labels")
# plt.ylabel("True Labels")
# plt.title('$r > 22$')
# plt.show(block=False)

In [None]:
# from jpasLAEs.utils import bin_centers

# extra_mask = (rmag_test < 24)
# laes_as_laes = ((pred_test == 2) | (pred_test == 3)) & ((y_test == 2) | (y_test == 3)) & extra_mask
# laes_as_cont = ((pred_test == 1) | (pred_test == 4)) & ((y_test == 2) | (y_test == 3)) & extra_mask
# laes_as_gal = ((pred_test == 4)) & ((y_test == 2) | (y_test == 3)) & extra_mask
# laes_as_badqso = (pred_test == 1) & ((y_test == 2) | (y_test == 3)) & extra_mask
# badqso_as_badqso = (pred_test == 1) & (y_test == 1)
# gal_as_cont = (y_test == 4) & ((pred_test == 4) | (pred_test == 1))
# gal_as_lae = (y_test == 4) & ((pred_test == 2) | (pred_test == 3))

# z_bins = np.linspace(2.7, 4, 15)
# z_bins_c = bin_centers(z_bins)
# h_good_laes, _ = np.histogram(zphot_test[laes_as_laes], z_bins)
# h_bad_laes, _ = np.histogram(zphot_test[laes_as_cont], z_bins)
# h_laes_as_gal, _ = np.histogram(zphot_test[laes_as_gal], z_bins)
# h_laes_as_badqso, _ = np.histogram(zphot_test[laes_as_badqso], z_bins)
# h_gal_as_cont, _ = np.histogram(zphot_test[gal_as_cont], z_bins)
# h_gal_as_lae, _ = np.histogram(zphot_test[gal_as_lae], z_bins)
# h_all_gal, _ = np.histogram(zphot_test[pred_test == 4], z_bins)
# h_badqso_as_badqso, _ = np.histogram(zphot_test[badqso_as_badqso], z_bins)
# h_all_badqso, _ = np.histogram(zphot_test[pred_test == 1], z_bins)


# fig, ax = plt.subplots(figsize=(6, 4))

# ax.plot(z_bins_c, h_good_laes / (h_good_laes + h_bad_laes), label='LAEs as LAEs')
# ax.plot(z_bins_c, h_laes_as_gal / (h_good_laes + h_bad_laes), label='LAEs as GAL')
# ax.plot(z_bins_c, h_laes_as_badqso / (h_good_laes + h_bad_laes), label='LAEs as low-z QSO')
# ax.plot(z_bins_c, h_gal_as_cont / h_all_gal, label='GAL as GAL')
# ax.plot(z_bins_c, h_badqso_as_badqso / h_all_badqso, label='low-z QSO as low-z QSO')


# ax.legend(fontsize=11)
# ax.set_ylim(0, 1)

# ax.set_xlabel('zphot')

# plt.show(block=False)

In [None]:
# extra_mask = (rmag_test < 24)
# laes_as_laes = ((pred_test == 2) | (pred_test == 3)) & ((y_test == 2) | (y_test == 3)) & extra_mask
# laes_as_cont = ((pred_test == 1) | (pred_test == 4)) & ((y_test == 2) | (y_test == 3)) & extra_mask
# laes_as_gal = ((pred_test == 4)) & ((y_test == 2) | (y_test == 3)) & extra_mask
# laes_as_badqso = (pred_test == 1) & ((y_test == 2) | (y_test == 3)) & extra_mask
# badqso_as_badqso = (pred_test == 1) & (y_test == 1)
# gal_as_cont = (y_test == 4) & ((pred_test == 4) | (pred_test == 1))
# gal_as_lae = (y_test == 4) & ((pred_test == 2) | (pred_test == 3))

# L_lya_bins = np.linspace(42, 46, 15)
# L_lya_bins_c = bin_centers(L_lya_bins)
# h_good_laes, _ = np.histogram(L_Arr_test[laes_as_laes], L_lya_bins)
# h_bad_laes, _ = np.histogram(L_Arr_test[laes_as_cont], L_lya_bins)
# h_laes_as_gal, _ = np.histogram(L_Arr_test[laes_as_gal], L_lya_bins)
# h_laes_as_badqso, _ = np.histogram(L_Arr_test[laes_as_badqso], L_lya_bins)
# h_gal_as_cont, _ = np.histogram(L_Arr_test[gal_as_cont], L_lya_bins)
# h_gal_as_lae, _ = np.histogram(L_Arr_test[gal_as_lae], L_lya_bins)
# h_all_gal, _ = np.histogram(L_Arr_test[pred_test == 4], L_lya_bins)
# h_badqso_as_badqso, _ = np.histogram(L_Arr_test[badqso_as_badqso], L_lya_bins)
# h_all_badqso, _ = np.histogram(L_Arr_test[pred_test == 1], L_lya_bins)


# fig, ax = plt.subplots(figsize=(6, 4))

# ax.plot(L_lya_bins_c, h_good_laes / (h_good_laes + h_bad_laes), label='LAEs as LAEs')
# ax.plot(L_lya_bins_c, h_laes_as_gal / (h_good_laes + h_bad_laes), label='LAEs as GAL')
# ax.plot(L_lya_bins_c, h_laes_as_badqso / (h_good_laes + h_bad_laes), label='LAEs as low-z QSO')
# ax.plot(L_lya_bins_c, h_gal_as_cont / h_all_gal, label='GAL as GAL')
# ax.plot(L_lya_bins_c, h_badqso_as_badqso / h_all_badqso, label='low-z QSO as low-z QSO')


# ax.legend(fontsize=11)
# ax.set_ylim(0, 1)

# ax.set_xlabel('L_lya')

# plt.show(block=False)

In [None]:
# # ROC curves
# from sklearn.metrics import roc_curve

# # TRUE MEANS CONTAMINANT HERE

# y_binary = np.zeros_like(y_test).astype(bool)
# y_binary[y_test != 2] = True

# cont_p = 1 - np.exp(log_p[:, 1])

# fpr, tpr, thresholds = roc_curve(y_binary, cont_p)

# # Compute nice threshold
# fpr_thresh_Arr = np.array([0.1, 0.05, 0.01])
# thresh_Arr = np.empty_like(fpr_thresh_Arr)
# for i, this_fpr in enumerate(fpr_thresh_Arr):
#     thresh_Arr[i] = thresholds[fpr >= this_fpr][0]

# print(thresh_Arr)

# # Represent the ROC curve
# fig, ax = plt.subplots(figsize=(6, 4))

# ax.plot(fpr, tpr, lw=2)
# for thr in thresh_Arr:
#     # ax.axvline(fpr[thresholds == thr], ls='--', c='k')
#     ax.axvline(fpr[thresholds == thr], ls='--', c='k')

# ax.set_xlabel('FALSE POSITIVE RATE')
# ax.set_ylabel('TRUE POSITIVE RATE')
# ax.set_ylim(0, 1)
# ax.set_xlim(0, 1)

# plt.show(block=False)