In [28]:
import rfpipe
# Updated for rfpipe version 1.3.1
from rfpipe import candidates
import numpy as np 
import pylab as plt
import matplotlib
import sys
import logging
from matplotlib import gridspec
from skimage.transform import resize
import scipy.signal as s

import h5py
import glob 
# Configure the root logger first, then keep a handle to it.
# logging.basicConfig() returns None, so its result must not be bound to
# `logger` (doing so would leave `logger` unusable).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(threadName)s - %(levelname)s -'
                                               ' %(message)s')
logger = logging.getLogger()
%matplotlib inline

In [12]:
# Notebook-wide matplotlib defaults, applied once through rcParams.
params = {
    # figure geometry, base font and text rendering
    'figure.figsize': [10, 5],
    'font.size': 9,
    'text.usetex': False,
    # axis-label, tick-label and legend text sizes
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
}
matplotlib.rcParams.update(params)

In [13]:
import os

# Expose only GPU index 1 to CUDA-based libraries started from this kernel.
# This must be set before any such library initialises CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [4]:
# Candidate pickle to inspect. The commented-out line is an alternative that
# derives the file name from a candcollection's scan id instead of a fixed path.
# pklname = 'cands_'+cc.metadata.scanId +'.pkl'
pklname = '/lustre/aoc/observers/nm-10222/fetchrf/cands_16A-459_TEST_1hr_000.57633.66130137732.scan7.cut.7.1.pkl'

In [5]:
# Materialise both views of the pickle: the entries selected as 'canddata'
# and the entries selected as 'candcollection'.
cds = [*rfpipe.candidates.iter_cands(pklname, select='canddata')]
cc = [*rfpipe.candidates.iter_cands(pklname, select='candcollection')]

In [14]:
def _delay_bins(chan_freqs, tsamp, dms):
    """Return the per-channel dispersion delay, in whole samples.

    Delays are measured relative to the first channel (``chan_freqs[0]``), so
    the first entry is always 0.  The constant ``4148808.0 / 1000`` is the
    cold-plasma dispersion constant; with it the delay comes out in seconds
    when frequencies are in MHz (assumes ``tsamp`` is in seconds and ``dms``
    in pc/cc -- TODO confirm against callers).

    Raises
    ------
    TypeError
        If ``dms`` is None (the signature default), since no DM was supplied.
    """
    if dms is None:
        raise TypeError("dms must be a number (dispersion measure), not None")
    chan_freqs = np.asarray(chan_freqs)
    delay_time = 4148808.0 * dms * (1 / (chan_freqs[0]) ** 2 - 1 / (chan_freqs) ** 2) / 1000
    return np.round(delay_time / tsamp).astype('int64')


def dedisperse(data, chan_freqs, tsamp, dms=None):
    """Circularly shift every channel of ``data`` by its dispersion delay.

    Parameters
    ----------
    data : 2-D array, shape (nf, nt)
        Frequency-time array; row ``ii`` belongs to ``chan_freqs[ii]``.
    chan_freqs : array-like of length nf
        Channel frequencies; delays are relative to ``chan_freqs[0]``.
    tsamp : float
        Sampling time used to convert delays to samples.
    dms : float
        Dispersion measure to remove.  A negative value applies the opposite
        shift (i.e. re-disperses the data).  Must be given despite the
        ``None`` default.

    Returns
    -------
    numpy.ndarray, float32, shape (nf, nt)
        The shifted array.

    Notes
    -----
    The shift is a true circular roll: a delay of ``nt`` samples or more wraps
    around modulo ``nt``.  (Building the shift from two slices silently left
    such channels unshifted, because both slices degenerate once the delay
    reaches the row length.)
    """
    nf, nt = data.shape
    assert nf == len(chan_freqs)
    shifts = _delay_bins(chan_freqs, tsamp, dms)
    ft = np.zeros(data.shape, dtype=np.float32)
    for ii in range(nf):
        ft[ii, :] = np.roll(data[ii], shifts[ii])
    return ft

def make_dmt(ft, dmi, dmf, dmsteps, chan_freqs, tsamp):
    """Build a DM-time image from a frequency-time array.

    ``dmsteps`` trial DMs are spaced linearly from ``dmi`` to ``dmf``
    (both inclusive); row ``ii`` of the result is the frequency-summed time
    series of ``ft`` dedispersed at the ``ii``-th trial DM.

    Returns
    -------
    numpy.ndarray, float32, shape (dmsteps, nt)
    """
    dm_list = np.linspace(dmi, dmf, dmsteps)
    dmt = np.zeros((dmsteps, ft.shape[1]), dtype=np.float32)
    for ii, dm in enumerate(dm_list):
        dmt[ii, :] = dedispersedts(ft, chan_freqs, tsamp, dms=dm)
    return dmt

def dedispersedts(data, chan_freqs, tsamp, dms=None):
    """Dedisperse ``data`` at ``dms`` and sum over frequency.

    Takes the same arguments as ``dedisperse`` and applies the same circular
    per-channel shift (wrapping modulo ``nt``), but accumulates the shifted
    rows instead of storing them.

    Returns
    -------
    numpy.ndarray, float32, shape (nt,)
        The frequency-summed time series.
    """
    nf, nt = data.shape
    assert nf == len(chan_freqs)
    shifts = _delay_bins(chan_freqs, tsamp, dms)
    ts = np.zeros(nt, dtype=np.float32)
    for ii in range(nf):
        ts += np.roll(data[ii], shifts[ii])
    return ts

In [15]:
def pkl_to_h5(pklfile, save_png = True, outdir = None, show = False):
    """Convert every candidate stored in one rfpipe pickle to HDF5.

    Reads all entries selected as 'canddata' from ``pklfile`` into a list and
    hands them, together with the remaining arguments, to ``cds_to_h5``.
    """
    canddata_list = [*rfpipe.candidates.iter_cands(pklfile, select='canddata')]
    cds_to_h5(canddata_list, save_png, outdir, show)

In [16]:
def cds_to_h5(cds, save_png = True, outdir = None, show = False):
    """Run ``cd_to_h5`` on each candidate in ``cds`` that has non-zero data.

    A candidate whose ``data.any()`` is false is skipped with a warning;
    every other candidate is passed on with ``save_png``, ``outdir`` and
    ``show`` unchanged.
    """
    for candidate in cds:
        logging.info(f'Processing candidate at candloc {candidate.loc}')
        if not candidate.data.any():
            logging.warning('Canddata is empty. Skipping Candidate')
            continue
        cd_to_h5(candidate, save_png, outdir, show)

In [17]:
def cd_to_h5(cd, save_png = True, outdir = None, show = False):
    """Save one candidate's frequency-time and DM-time images as an HDF5 file.

    The frequency-time image is the sum of the absolute values of entries 0
    and 1 along the last axis of ``cd.data`` (presumably the two
    polarisations -- TODO confirm), transposed and flipped so that frequency
    runs from high to low.  That image is re-dispersed at the candidate DM and
    then dedispersed over 256 trial DMs to build the DM-time image.  Both
    images are resized to 256x256 and written to
    ``cands_<fileroot>_seg<..>-i<..>-dm<..>-dt<..>.h5``.

    Parameters
    ----------
    cd : rfpipe canddata object
        Must provide ``loc``, ``data`` and ``state`` (``dtarr``, ``dmarr``,
        ``inttime``, ``freq``, ``fileroot``, ``prefs.timewindow``).
    save_png : bool
        Also write a two-panel PNG next to the HDF5 file.
    outdir : str or None
        Output directory; the current directory is used when None.
    show : bool
        Display the figure instead of closing it (only with ``save_png``).

    Raises
    ------
    AssertionError
        If the number of channels in the data differs from the length of the
        frequency list.
    """
    segment, candint, dmind, dtind, _ = cd.loc
    width_m = cd.state.dtarr[dtind]
    timewindow = cd.state.prefs.timewindow
    tsamp = cd.state.inttime*width_m
    dm = cd.state.dmarr[dmind]
    ft_dedisp = np.flip(np.abs(cd.data[:,:,0].T) + np.abs(cd.data[:,:,1].T), axis=0)
    chan_freqs = np.flip(cd.state.freq*1000) #from high to low, MHz
    nf, nt = np.shape(ft_dedisp)

    logging.info(f'Size of the FT array is {(nf, nt)}')

    # If timewindow is not set during search, set it equal to the number of time bins of candidate
    if nt != timewindow:
        logging.info(f'Setting timewindow equal to nt = {nt}')
        timewindow = nt

    # Explicit check (rather than a bare assert) so it still runs under -O.
    if nf != len(chan_freqs):
        msg = "Number of frequency channel in data should match the frequency list"
        logging.error(msg)
        raise AssertionError(msg)

    # Re-disperse the candidate at its own DM so make_dmt can scan trial DMs.
    dispersed = dedisperse(ft_dedisp, chan_freqs, tsamp, -1*dm)

    # Value comparison: `dm is not 0` is an identity test and is always true
    # for float / numpy DMs, which would collapse the DM range to [0, 0].
    if dm != 0:
        dm_start = 0
        dm_end = 2*dm
    else:
        dm_start = -10
        dm_end = 10

    logging.info(f'Generating DM-time for DM range {dm_start:.2f} pc/cc to {dm_end:.2f} pc/cc')
    dmt = make_dmt(dispersed, dm_start, dm_end, 256, chan_freqs, tsamp)

    reshaped_ft = resize(ft_dedisp, (256, 256), anti_aliasing=True)
    reshaped_dmt = resize(dmt, (256, 256), anti_aliasing=True)

    basename = 'cands_{0}_seg{1}-i{2}-dm{3}-dt{4}'.format(cd.state.fileroot, segment, candint, dmind, dtind)
    # os.path.join works whether or not outdir carries a trailing separator.
    fnout = os.path.join(outdir, basename) if outdir is not None else basename

    with h5py.File(fnout+'.h5', 'w') as f:
        # NOTE(review): reshaped_ft is built as (frequency, time) but the axis
        # labels below say (time, frequency) -- confirm which layout the
        # downstream reader expects before changing either.
        freq_time_dset = f.create_dataset('data_freq_time', data=reshaped_ft)
        freq_time_dset.dims[0].label = b"time"
        freq_time_dset.dims[1].label = b"frequency"

        dm_time_dset = f.create_dataset('data_dm_time', data=reshaped_dmt)
        dm_time_dset.dims[0].label = b"dm"
        dm_time_dset.dims[1].label = b"time"

    logging.info(f'Saved h5 as {fnout}.h5')

    if save_png:
        ts = np.arange(timewindow)*tsamp
        fig, ax = plt.subplots(nrows=2, ncols=1, figsize = (8,10), sharex=True)
        ax[0].imshow(reshaped_ft, aspect='auto', extent=[ts[0], ts[-1], np.min(chan_freqs), np.max(chan_freqs)])
        ax[0].set_ylabel('Freq')
        ax[0].title.set_text('Dedispersed FT')
        # Row 0 of the DM-time image is dm_start and is drawn at the top, so
        # the extent runs from dm_end (bottom) to dm_start (top).
        ax[1].imshow(reshaped_dmt, aspect='auto', extent=[ts[0], ts[-1], dm_end, dm_start])
        ax[1].set_ylabel('DM')
        ax[1].title.set_text('DM-Time')
        ax[1].set_xlabel('Time (s)')
        plt.tight_layout()
        plt.savefig(fnout+'.png')
        logging.info(f'Saved png as {fnout}.png')
        if show:
            plt.show()
        else:
            # Close this specific figure so batch runs do not accumulate figures.
            plt.close(fig)

In [None]:
def prepare_to_classify(ft, dmt):
    """Turn a frequency-time and a DM-time image into a JSON-ready payload.

    NaNs in both inputs are replaced first.  The frequency-time image is then
    detrended with ``scipy.signal.detrend`` (default arguments); after that
    both images are divided by their standard deviation and have their median
    subtracted.  Each image must hold exactly 256*256 values; it is reshaped
    to a single-item batch of shape (1, 256, 256, 1), any NaN produced along
    the way (e.g. by a zero standard deviation) is set to 0, and the result is
    converted to nested lists.

    Returns
    -------
    dict
        ``{"data_freq_time": ..., "data_dm_time": ...}`` with nested lists.
    """
    def _standardise(image):
        # In place: unit standard deviation first, then zero median.
        image /= np.std(image)
        image -= np.median(image)
        return image

    def _as_batch_list(image):
        cube = np.reshape(image, (256, 256, 1))
        cube[np.isnan(cube)] = 0.0
        batch = cube.reshape(-1, 256, 256, 1)
        return batch.copy(order='C').tolist()

    freq_time = _standardise(s.detrend(np.nan_to_num(ft)))
    dm_time = _standardise(np.nan_to_num(dmt))

    freq_time_list = _as_batch_list(freq_time)
    dm_time_list = _as_batch_list(dm_time)
    return {"data_freq_time": freq_time_list, "data_dm_time": dm_time_list}

In [49]:
def classify_cd(cd, KERAS_REST_API_URL):
    """Classify one candidate by posting its images to a prediction endpoint.

    The frequency-time and DM-time images are built from ``cd`` (sum of the
    absolute values of entries 0 and 1 along the last axis of ``cd.data``,
    frequency flipped to run high to low, re-dispersed at the candidate DM and
    dedispersed over 256 trial DMs), resized to 256x256, packed with
    ``prepare_to_classify`` and sent as JSON to ``KERAS_REST_API_URL``.

    Parameters
    ----------
    cd : rfpipe canddata object
        Must provide ``loc``, ``data`` and ``state`` (``dtarr``, ``dmarr``,
        ``inttime``, ``freq``).
    KERAS_REST_API_URL : str
        URL the payload is POSTed to.

    Returns
    -------
    float or int
        ``response["predictions"][0][1]`` from the endpoint's JSON reply, or
        -1 when ``cd.data`` contains no non-zero value.

    Raises
    ------
    AssertionError
        If the number of channels in the data differs from the length of the
        frequency list.
    """
    if not cd.data.any():
        logging.warning('Canddata is empty')
        return -1

    width_m = cd.state.dtarr[cd.loc[3]]
    tsamp = cd.state.inttime*width_m
    dm = cd.state.dmarr[cd.loc[2]]
    ft_dedisp = np.flip(np.abs(cd.data[:,:,0].T) + np.abs(cd.data[:,:,1].T), axis=0)
    chan_freqs = np.flip(cd.state.freq*1000) #from high to low, MHz
    nf, nt = np.shape(ft_dedisp)

    logging.info(f'Size of the FT array is {(nf, nt)}')

    # Explicit check (rather than a bare assert) so it still runs under -O.
    if nf != len(chan_freqs):
        msg = "Number of frequency channel in data should match the frequency list"
        logging.error(msg)
        raise AssertionError(msg)

    # Re-disperse the candidate at its own DM so make_dmt can scan trial DMs.
    dispersed = dedisperse(ft_dedisp, chan_freqs, tsamp, -1*dm)

    # Value comparison: `dm is not 0` is an identity test and is always true
    # for float / numpy DMs, which would collapse the DM range to [0, 0].
    if dm != 0:
        dm_start = 0
        dm_end = 2*dm
    else:
        dm_start = -10
        dm_end = 10

    logging.info(f'Generating DM-time for DM range {dm_start:.2f} pc/cc to {dm_end:.2f} pc/cc')
    dmt = make_dmt(dispersed, dm_start, dm_end, 256, chan_freqs, tsamp)

    # TODO: Test if median padding in time is better than resize
    reshaped_ft = resize(ft_dedisp, (256, 256), anti_aliasing=True)
    reshaped_dmt = resize(dmt, (256, 256), anti_aliasing=True)

    payload = prepare_to_classify(reshaped_ft, reshaped_dmt)

    response = requests.post(KERAS_REST_API_URL, json=payload).json()
    # Second entry of the first prediction is reported as the FRB probability.
    frb_prob = response["predictions"][0][1]
    logging.info(f'FRB probability of this candidate is: {frb_prob}')
    return frb_prob

In [47]:
# glob.glob() returns matches in arbitrary, filesystem-dependent order; sort
# them so that pkllist[0] (and the order of any loop over pkllist) is the same
# on every run.
pkllist = sorted(glob.glob('/lustre/aoc/observers/nm-10222/rfgpu/Refinement/*pkl'))
cds = list(rfpipe.candidates.iter_cands(pkllist[0], select='canddata'))
# Fourth candidate of the first pickle, used as the single test case below.
cd = cds[3]

In [55]:
import requests
# Endpoint that classify_cd POSTs the candidate payload to; a server must be
# listening on this address for the call below to succeed.
KERAS_REST_API_URL = "http://localhost:5000/predict"
# Classify the single candidate selected above; classify_cd returns -1 when
# the candidate data are empty.
frb_prob = classify_cd(cd, KERAS_REST_API_URL)

2019-05-15 22:13:36,616 - root - INFO - Size of the FT array is (256, 66)
2019-05-15 22:13:36,618 - root - INFO - Generating DM-time for DM range 0.00 pc/cc to 1173.80 pc/cc
2019-05-15 22:13:37,265 - root - INFO - FRB probability of this candidate is: 0.9996829032897949


 

In [None]:
# Convert every pickle in pkllist to HDF5 (plus PNG) files in the h5s directory,
# without displaying the figures.
for pkl in pkllist:
    pkl_to_h5(
        pklfile=pkl,
        save_png=True,
        outdir='/lustre/aoc/observers/nm-10222/fetchrf/h5s/',
        show=False,
    )