# Run prediction pipeline for radio-detectable AGN

Jupyter Notebook to run the prediction pipeline presented in Carvajal et al., 2023.

The example code uses data from IR-detected sources in either the HETDEX Spring field or Stripe 82.

In [1]:
%matplotlib inline
# Static plots
import numpy as np
from pycaret import classification as pyc
from pycaret import regression as pyr
from joblib import load
import pandas as pd
import global_variables as gv
import global_functions as gf
import os
import subprocess

### Reading data

Select the field whose sources will be run through the pipeline.

In [2]:
# Field whose catalogue is run through the pipeline: 'HETDEX' or 'S82'.
used_field = 'HETDEX'

# Fail here with a readable message, rather than with a bare KeyError when the
# field is later used to look up its catalogue file name.
if used_field not in ('HETDEX', 'S82'):
    raise ValueError(f"used_field must be 'HETDEX' or 'S82', got {used_field!r}")

In [3]:
# Catalogue file name (no directory) for each supported field, taken from
# global_variables.
file_name_dict = {'S82': gv.file_S82, 'HETDEX': gv.file_HETDEX}

# Full path of the selected catalogue. os.path.join only inserts a separator
# when gv.model_path does not already end with one, so the result is valid
# whether or not the configured directory has a trailing slash.
file_name = os.path.join(gv.model_path, file_name_dict[used_field])

Check that data files have been downloaded.

In [4]:
# Make sure the catalogue of the selected field is present in gv.model_path.
# The helper prints a message when the file has already been downloaded.
catalogue_name = file_name_dict[used_field]
gf.download_from_zenodo(catalogue_name, gv.model_path)

File predicted_rAGN_HETDEX.parquet has been already downloaded


For this notebook, not all columns will be needed. Select those that will be used.

In [5]:
# Columns read from the catalogue. The order is kept as is, since it becomes
# the column order of the loaded data frame.
used_cols = [
    'Z', 'band_num', 'class',
    # Magnitude columns; colours are later built as differences of these.
    'W1mproPM', 'W2mproPM',
    'gmag', 'rmag', 'imag', 'zmag', 'ymag',
    'W3mag', 'W4mag',
    'Jmag', 'Hmag', 'Kmag',
    # Flag columns.
    'LOFAR_detect', 'radio_AGN',
]

Load file

In [6]:
# Read only the columns listed in used_cols from the selected catalogue.
data_df = pd.read_parquet(file_name, columns=used_cols, engine='fastparquet')

Create new columns with colours. Only the colours used by the models in the pipeline are created; they are listed in the article.

In [7]:
# Colours required by each model of the pipeline. A colour 'a_b' stands for
# (magnitude in band a) - (magnitude in band b).
colours_AGN = [
    'g_r', 'r_i', 'r_J', 'i_z', 'i_y', 'z_y', 'z_W2', 'y_J',
    'y_W1', 'y_W2', 'J_H', 'H_K', 'H_W3', 'W1_W2', 'W1_W3', 'W3_W4',
]
colours_radio = [
    'g_r', 'g_i', 'r_i', 'r_z', 'i_z', 'z_y', 'z_W1', 'y_J',
    'y_W1', 'J_H', 'H_K', 'K_W3', 'K_W4', 'W1_W2', 'W2_W3',
]
colours_z = [
    'g_r', 'g_W3', 'r_i', 'r_z', 'i_z', 'i_y', 'z_y', 'y_J',
    'y_W1', 'J_H', 'H_K', 'K_W3', 'K_W4', 'W1_W2', 'W2_W3',
]

# Sorted union of the three lists: np.unique drops the colours shared by
# several models, so each one is computed only once.
new_colours = list(np.unique([*colours_AGN, *colours_radio, *colours_z]))

In [8]:
# Short band label (as used in the colour names) -> magnitude column in data_df.
mag_names = {
    'g': 'gmag',
    'r': 'rmag',
    'i': 'imag',
    'z': 'zmag',
    'y': 'ymag',
    'J': 'Jmag',
    'H': 'Hmag',
    'K': 'Kmag',
    'W1': 'W1mproPM',
    'W2': 'W2mproPM',
    'W3': 'W3mag',
    'W4': 'W4mag',
}

In [9]:
# Add one column per colour: 'a_b' = (magnitude of band a) - (magnitude of band b).
for colour in new_colours:
    band_a, band_b = colour.split('_')
    data_df[colour] = data_df[mag_names[band_a]] - data_df[mag_names[band_b]]

Load models

Check that model files have been downloaded.

In [10]:
# Files needed by the pipeline. The three base model names are stored without
# extension, so '.pkl' is appended; the calibrated models already carry their
# full file name.
model_files = [
    gv.AGN_gal_model + '.pkl',
    gv.cal_AGN_gal_model,
    gv.radio_model + '.pkl',
    gv.cal_radio_model,
    gv.full_z_model + '.pkl',
]

# Same helper call for every file, in the order listed above.
for model_file in model_files:
    gf.download_from_zenodo(model_file, gv.model_path)

File classification_AGN_galaxy.pkl has been already downloaded
File cal_classification_AGN_galaxy.joblib has been already downloaded
File classification_radio_detection.pkl has been already downloaded
File cal_classification_radio_detection.joblib has been already downloaded
File regression_redshift.pkl has been already downloaded


In [11]:
# NOTE(security): pycaret's load_model and joblib's load both unpickle the file
# they are given, which can execute arbitrary code. Only load model files that
# come from a trusted source.

# os.path.join only adds a separator when gv.model_path lacks a trailing one,
# so the paths are valid whether or not the configured directory ends in '/'.

# PyCaret models: the name is passed without extension (the '.pkl' files were
# fetched in the download step).
AGN_gal_clf       = pyc.load_model(os.path.join(gv.model_path, gv.AGN_gal_model))
radio_det_AGN_clf = pyc.load_model(os.path.join(gv.model_path, gv.radio_model))
redshift_reg_rAGN = pyr.load_model(os.path.join(gv.model_path, gv.full_z_model))

# Calibrated versions of the two classifiers, stored as joblib files.
cal_AGN_gal_clf       = load(os.path.join(gv.model_path, gv.cal_AGN_gal_model))
cal_radio_det_AGN_clf = load(os.path.join(gv.model_path, gv.cal_radio_model))

Transformation Pipeline and Model Successfully Loaded
Transformation Pipeline and Model Successfully Loaded
Transformation Pipeline and Model Successfully Loaded


#### Run prediction models

Run the models over all sources in the dataset. Afterwards, the user can select the sources predicted to be radio-detectable AGN (or any other combination of predictions).
The following steps take some time to run.

Classify between AGN and galaxies.

In [12]:
# AGN/galaxy classification step. Inputs: the classifier, its calibrated
# version, and the threshold of each one.
# NOTE(review): presumably adds the pred_class / Score_AGN / Prob_AGN /
# pred_class_cal columns seen in the final table; confirm in global_functions.
data_df = gf.predict_AGN_gal(
    data_df,
    AGN_gal_clf,
    cal_AGN_gal_clf,
    gv.AGN_thresh,
    gv.cal_AGN_thresh,
)

Classify between radio-detectable and non radio-detectable sources.

In [13]:
# Radio-detection classification step. Inputs: the classifier, its calibrated
# version, and the threshold of each one.
# NOTE(review): presumably adds the pred_radio_AGN / Score_radio_AGN /
# Prob_radio_AGN / pred_radio_cal_AGN columns seen in the final table; confirm
# in global_functions.
data_df = gf.predict_radio_det(
    data_df,
    radio_det_AGN_clf,
    cal_radio_det_AGN_clf,
    gv.radio_thresh,
    gv.cal_radio_thresh,
)

Predict photometric redshifts.

In [14]:
# Photometric-redshift regression step.
# NOTE(review): presumably adds the pred_Z_rAGN column seen in the final table;
# confirm in global_functions.
data_df = gf.predict_z(data_df, redshift_reg_rAGN)

Display the first ten radio-detected AGN that were correctly predicted as such in the data frame.

In [18]:
# Boolean mask of the sources for which every one of these flags equals 1:
# the catalogue labels ('class', 'LOFAR_detect') and the matching predictions
# ('pred_class_cal', 'pred_radio_cal_AGN').
flag_cols = ['class', 'pred_class_cal', 'LOFAR_detect', 'pred_radio_cal_AGN']
flag_masks = [np.array(data_df.loc[:, flag_col] == 1) for flag_col in flag_cols]
filter_data = np.logical_and.reduce(flag_masks)

In [19]:
# Show the first ten rows that pass the mask.
display(data_df.loc[filter_data].head(10))

Unnamed: 0_level_0,Z,band_num,class,W1mproPM,W2mproPM,gmag,rmag,imag,zmag,ymag,W3mag,W4mag,Jmag,Hmag,Kmag,LOFAR_detect,radio_AGN,H_K,H_W3,J_H,K_W3,K_W4,W1_W2,W1_W3,W2_W3,W3_W4,g_W3,g_i,g_r,i_y,i_z,r_J,r_i,r_z,y_J,y_W1,y_W2,z_W1,z_W2,z_y,pred_class,Score_AGN,Prob_AGN,pred_class_cal,pred_radio_AGN,Score_radio_AGN,Prob_radio_AGN,pred_radio_cal_AGN,pred_Z_rAGN
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1
16456,0.846,8,1.0,18.337,18.363001,23.299999,21.8901,21.287201,20.9596,20.5193,16.67,14.62,17.450001,17.24,16.59,1,1,0.65,0.57,0.210001,-0.08,1.97,-0.026001,1.667,1.693001,2.05,6.629999,2.012798,1.409899,0.7679,0.3276,4.4401,0.6029,0.9305,3.0693,2.182301,2.1563,2.622601,2.5966,0.4403,1,0.500016,0.657091,1,0,0.438512,0.410227,1,0.7586
27470,0.736,9,1.0,17.676001,17.408001,20.004499,19.639299,19.607401,19.268499,19.1992,16.67,14.62,17.450001,17.24,16.59,1,1,0.65,0.57,0.210001,-0.08,1.97,0.268,1.006001,0.738001,2.05,3.334499,0.397099,0.3652,0.408201,0.338902,2.189299,0.031898,0.3708,1.749199,1.523199,1.791199,1.592499,1.860498,0.0693,1,0.500121,0.988672,1,0,0.310507,0.312719,1,0.7748
40804,1.078,9,1.0,17.311001,16.582001,18.525999,18.507601,18.4573,18.3627,18.2085,15.373,14.18,17.450001,17.24,16.59,1,1,0.65,1.867,0.210001,1.217,2.41,0.729,1.938001,1.209001,1.193,3.152999,0.068699,0.018398,0.2488,0.094601,1.0576,0.050301,0.144901,0.758499,0.897499,1.626499,1.051699,1.780699,0.1542,1,0.500122,0.989167,1,0,0.400925,0.38264,1,1.0918
44825,0.877,9,1.0,18.306,18.175001,19.906601,19.838699,19.6185,19.524599,19.334,16.67,14.62,17.450001,17.24,16.59,1,1,0.65,0.57,0.210001,-0.08,1.97,0.130999,1.636,1.505001,2.05,3.236601,0.288101,0.067902,0.2845,0.093901,2.388699,0.2202,0.3141,1.883999,1.028,1.158998,1.218599,1.349598,0.190599,1,0.500121,0.988471,1,0,0.501761,0.455192,1,0.9805
45506,0.643,12,1.0,15.579,15.04,17.124001,17.214199,17.213301,17.3904,17.359301,14.250999,13.386,17.23,17.105,16.527,1,1,0.577999,2.854,0.125,2.276001,3.141001,0.539001,1.328001,0.789001,0.865,2.873001,-0.0893,-0.090199,-0.146,-0.177099,-0.0158,0.000898,-0.176201,0.129301,1.7803,2.319301,1.811399,2.3504,0.031099,1,0.500122,0.988938,1,0,0.875211,0.72647,1,0.6672
53799,0.712,9,1.0,18.17,18.677999,21.766199,21.3116,20.481701,19.868799,19.8018,16.67,14.62,17.450001,17.24,16.59,1,1,0.65,0.57,0.210001,-0.08,1.97,-0.507999,1.5,2.007999,2.05,5.096199,1.284498,0.454599,0.679901,0.612902,3.861599,0.829899,1.442801,2.351799,1.6318,1.1238,1.698799,1.1908,0.066999,0,0.499981,0.349309,1,0,0.508161,0.459659,1,0.7047
58134,1.639,9,1.0,19.097,18.287001,21.1991,21.056299,20.6693,20.6933,20.603001,16.67,14.62,17.450001,17.24,16.59,1,1,0.65,0.57,0.210001,-0.08,1.97,0.809999,2.427,1.617001,2.05,4.5291,0.5298,0.142801,0.066299,-0.024,3.606298,0.386999,0.362999,3.153,1.506001,2.316,1.5963,2.4063,0.0903,1,0.50012,0.9881,1,0,0.453618,0.421117,1,1.6791
62617,3.193,9,1.0,19.868,19.752001,21.003599,20.839701,20.9,20.7612,20.644501,16.67,14.62,17.450001,17.24,16.59,1,1,0.65,0.57,0.210001,-0.08,1.97,0.115999,3.198,3.082001,2.05,4.333599,0.1036,0.163898,0.255499,0.1388,3.3897,-0.060299,0.078501,3.1945,0.776501,0.8925,0.8932,1.009199,0.116699,1,0.500119,0.987949,1,0,0.306487,0.309471,1,2.8276
63673,0.906,9,1.0,18.917,18.932001,22.117001,22.021601,21.5578,20.8374,20.2924,16.67,14.62,17.450001,17.24,16.59,1,1,0.65,0.57,0.210001,-0.08,1.97,-0.015001,2.247,2.262001,2.05,5.447001,0.5592,0.0954,1.2654,0.7204,4.5716,0.4638,1.1842,2.8424,1.375401,1.360399,1.920401,1.905399,0.545,1,0.500105,0.979736,1,0,0.6061,0.526861,1,0.932
69989,0.15,12,1.0,16.379999,16.229,18.2211,18.0998,17.5788,17.939301,17.516899,14.873,13.917,17.027,16.59,16.445,1,1,0.145,1.717,0.437,1.572,2.528,0.150999,1.506999,1.356,0.956,3.3481,0.6423,0.1213,0.061901,-0.3605,1.0728,0.521,0.1605,0.489899,1.1369,1.287899,1.559301,1.7103,0.422401,1,0.500117,0.986725,1,0,0.484179,0.442849,1,0.1523
