# Use the model to get the cleaned frame

In [None]:
import json

import scipy
import scipy.fft
import scipy.signal
import numpy as np
import matplotlib.pyplot as plt

import gwpy
from gwpy.timeseries import TimeSeries
from gwpy.frequencyseries import FrequencySeries

import tensorflow as tf

## Load and prep the data

In [None]:
# load dataset params
# The JSON file sits next to the saved model; it is read as UTF-8 explicitly
# so the result does not depend on the platform's default locale encoding.
with open('output/model/dataset_params.json', encoding='utf-8') as dset_file:
    dset_params = json.load(dset_file)

# fs          : sampling rate, used below to turn seconds into sample counts
# rec_size    : length (in samples) of one model input window
# rec_future  : how many samples of that window lie after the predicted sample
# norm_values : per-channel normalization factors of the training dataset
fs = dset_params['fs']
rec_size = dset_params['rec_size']
rec_future = dset_params['rec_future']
norm_factors_orig = dset_params['norm_values']

In [None]:
# NOTE(review): this replaces whatever `fs` was read from dataset_params.json
# with a fixed 512 -- confirm the two values are meant to agree.
fs = 512

In [None]:
def load_data(channels, sample_rate=None):
    """Load, edge-trim and peak-normalize the whitened channel files.

    Every file ``data/<channel>`` is read with ``np.load`` and indexed as a
    2-D array: column 0 is taken as the GPS time stamps, column 1 as the
    samples.  Four seconds are dropped from both ends of each series
    (whitening artifacts) and the remainder is divided by its largest
    absolute value.

    Parameters
    ----------
    channels : iterable of str
        File names inside the ``data/`` directory, one per channel.
    sample_rate : int, optional
        Samples per second, used to convert the 4 s trim into a sample
        count.  When omitted, the module-level ``fs`` is used.

    Returns
    -------
    dset : numpy.ndarray
        float32 array holding the normalized series, one row per channel
        (singleton axes are squeezed out).
    gps_times : numpy.ndarray
        Time stamps of the LAST channel loaded; the channels are assumed to
        share one time axis, which is not checked here.
    norm_factors : list
        Peak absolute value of each trimmed series, in channel order.

    Raises
    ------
    ValueError
        If ``channels`` is empty (previously this surfaced as an
        ``UnboundLocalError`` on ``gps_times``).
    """
    channels = list(channels)
    if not channels:
        raise ValueError('load_data needs at least one channel file name')

    rate = fs if sample_rate is None else sample_rate
    # number of samples to discard on each side (4 seconds)
    edge = int(4 * rate)

    series = []
    norm_factors = []
    gps_times = None

    for channel in channels:

        # load data
        data = np.load(f'data/{channel}')

        # remove sides due to whitening artifacts
        trimmed = data[edge:-edge]
        gps_times = trimmed[:, 0]
        samples = trimmed[:, 1].reshape(-1, 1)

        # normalize and append tseries
        norm_factor = np.max(np.abs(samples))
        norm_factors.append(norm_factor)
        series.append(samples / norm_factor)

    # (n_channels, n_samples, 1) -> (n_channels, n_samples), stored as float32
    dset = np.squeeze(series).astype(np.float32)

    return dset, np.array(gps_times), norm_factors


# Whitened 512 Hz files to load from data/.  Order matters: row 0 of the
# resulting dataset (the strain channel) is used as the target series and
# rows 1-3 as the model inputs further down.
channels = [
    'DCS-CALIB_STRAIN_CLEAN_C01_512Hz_event_4096s_whitened.npy',
    'LSC-POP_A_RF45_I_ERR_DQ_512Hz_event_4096s_whitened.npy',
    'LSC-POP_A_RF45_Q_ERR_DQ_512Hz_event_4096s_whitened.npy',
    'LSC-POP_A_RF9_I_ERR_DQ_512Hz_event_4096s_whitened.npy',
]

dset, gps_times, norm_factors = load_data(channels)

In [None]:
# re-normalize the data w.r.t. the dataset used in the training
# Each row of `dset` was divided by its own peak (norm_factors); multiply that
# back in and divide by the corresponding training-time factor instead.
# The loop runs over however many channels were loaded rather than a fixed 4.
dset_norm = dset.copy()
for i in range(len(dset_norm)):
    dset_norm[i,:] = dset_norm[i,:] * norm_factors[i] / norm_factors_orig[i]

## Get the predicted noise in steps due to limited memory
Run this step only if `original_tseries.npy` and `noise_tseries.npy` do not already exist in the `output/` folder.

In [None]:
# get start/end idx for each step
# 32 evenly spaced break points over the series give 31 consecutive chunks.
break_points = np.linspace(0, len(dset[1,:]), num=32, endpoint=True)
idx_end = break_points[1:]
# Every chunk except the first starts rec_size samples early, so the
# windows built for it reach back into the preceding chunk.
idx_start = break_points[:-1].copy()
idx_start[1:] -= rec_size

In [None]:
def get_arrays(dset, input_values, output_values, box_start, box_end, rec_size, rec_future):
    """Collect model input windows and the matching target samples.

    For every index ``i`` in ``[box_start + rec_past, box_end - rec_future)``,
    with ``rec_past = rec_size - rec_future``, a ``(3, rec_size)`` window
    covering samples ``i - rec_past`` to ``i + rec_future`` of rows 1, 2 and 3
    of ``dset`` is appended to ``input_values``.  The row-0 samples over the
    same index range are appended to ``output_values`` as a single array.

    Both lists are modified in place and also returned.
    """
    rec_past = rec_size - rec_future
    first = box_start + rec_past
    last = box_end - rec_future

    # one window per predicted sample, taken from the three input rows
    input_values.extend(
        dset[[1, 2, 3], centre - rec_past:centre + rec_future]
        for centre in range(first, last)
    )

    # target samples for exactly the same centres
    output_values.append(dset[0, first:last])

    return input_values, output_values

In [None]:
# Load the trained model from the output folder.
model = tf.keras.models.load_model('output/model')

# Time stamps of the samples that receive a prediction: the first
# rec_size - rec_future and the last rec_future samples have no complete
# window.  The end index is written out explicitly because a plain
# `-rec_future` would yield an empty slice when rec_future is 0.
output_gps_times = gps_times[rec_size-rec_future:len(gps_times)-rec_future]

# Per-chunk results are collected in lists and joined once at the end,
# instead of growing an array with a dummy first row on every iteration.
prediction_chunks = []
output_chunks = []

for chunk_start, chunk_end in zip(idx_start, idx_end):

    # break points are floats; truncate them to integer sample indices
    input_start_idx = int(chunk_start)
    input_end_idx = int(chunk_end)

    # create input/output arrays
    input_values, output_values = get_arrays(dset_norm, [], [], input_start_idx,
                                             input_end_idx, rec_size, rec_future)

    input_values = np.array(input_values)
    output_values = np.hstack(output_values).reshape(-1,1)

    # predict one chunk at a time to keep memory usage bounded
    step_prediction = model.predict(input_values)

    prediction_chunks.append(step_prediction)
    output_chunks.append(output_values)

# The original accumulation started from a float64 array, so the joined
# results were float64; keep that dtype for downstream code.
prediction = np.concatenate(prediction_chunks).astype(np.float64)
output = np.concatenate(output_chunks).astype(np.float64)

In [None]:
# save original and noisy tseries
# Each saved array has two rows: row 0 the GPS time stamps, row 1 the
# samples (target series and model prediction respectively).
original_tseries, noise_tseries = (
    np.array([np.squeeze(output_gps_times), np.squeeze(values)])
    for values in (output, prediction)
)
np.save('output/original_tseries', original_tseries)
np.save('output/noise_tseries', noise_tseries)

## Color, upsample the data and get the cleaned frame

In [None]:
# load tseries
# The file is read once: row 0 is the GPS time axis, row 1 the predicted noise.
data = np.load('output/noise_tseries.npy')
gps_times = data[0,:]
noise_tseries = TimeSeries(data[1,:], times=gps_times)

In [None]:
# load ASD
# Read as a gwpy FrequencySeries; it is multiplied with the filter response
# when the prediction is coloured.
asd = FrequencySeries.read('data/DCS-CALIB_STRAIN_CLEAN_C01_512Hz_ASD.txt')

In [None]:
fs = 512
f_low = 10
asd_win = 4

# create a filter in time domain
# (high-pass FIR with cutoff f_low, asd_win seconds long plus one tap)
firwin = scipy.signal.firwin(asd_win*fs+1, [f_low], pass_zero=False, window='hann', fs=fs)

# convert it to freq domain and remove phase
ffirwin = np.abs(scipy.fft.rfft(firwin))

# multiply the filter with the asd
# NOTE(review): this requires the ASD to have the same number of frequency
# bins as ffirwin -- confirm against the ASD file.
asd_filtered = ffirwin * asd

# convert the asd to tseries
time_asd = scipy.fft.irfft(asd_filtered)

# roll it and smooth the edges out
time_asd = np.roll(time_asd, len(time_asd)//2)
hann = scipy.signal.windows.hann(len(time_asd))
time_asd = time_asd * hann

# pad with zeros
# (up to the length of the series to be coloured, so both spectra below have
# the same number of bins)
zeros = np.zeros(len(noise_tseries))
zeros[0:len(time_asd)] = time_asd
time_asd = zeros

# convert back the ASD to freq series
# (magnitude only, so the multiplication below changes no phases)
freq_asd = np.abs(scipy.fft.rfft(time_asd))

# tseries in freq domain
fseries = scipy.fft.rfft(noise_tseries)

# color tseries
# n is passed explicitly: without it irfft returns 2*(m-1) samples, which is
# one short of the input when the series has an odd length and would no
# longer line up with gps_times.
colored = scipy.fft.irfft(fseries * freq_asd, n=len(noise_tseries)) * norm_factors_orig[0]
# drop 2048 samples (4 s at 512 Hz) from each end
colored_tseries = TimeSeries(colored[2048:-2048], times=gps_times[2048:-2048])

# upsample
colored_tseries_upsampled = colored_tseries.resample(4096)

In [None]:
# load 4096Hz original tseries and crop it
# column 0 of the file is used as the time axis, column 1 as the strain samples
data = np.load('data/DCS-CALIB_STRAIN_CLEAN_C01_4096Hz_event_4096s.npy')
orig_tseries = TimeSeries(data[:,1], times=data[:,0])
orig_tseries_cropped = orig_tseries.crop(colored_tseries_upsampled.times[0], colored_tseries_upsampled.times[-1])

# get the cleaned tseries and save it
# NOTE(review): the last upsampled sample is dropped, presumably because
# crop() excludes the sample at the end time -- confirm the lengths match.
cleaned_tseries = orig_tseries_cropped - colored_tseries_upsampled[:-1]
# saved layout: row 0 = times, row 1 = cleaned strain
cleaned_tseries_array = np.array([cleaned_tseries.times, cleaned_tseries])
np.save('output/DCS-CALIB_STRAIN_CLEAN_C01_4096Hz_event_cleaned', cleaned_tseries_array)

## Spectrograms of the cleaned data

In [None]:
# original 4096 Hz strain: column 0 = times, column 1 = samples
data = np.load('data/DCS-CALIB_STRAIN_CLEAN_C01_4096Hz_event_4096s.npy')
orig_tseries = TimeSeries(data[:,1], times=data[:,0])

# cleaned strain written by the previous step: row 0 = times, row 1 = samples
data = np.load('output/DCS-CALIB_STRAIN_CLEAN_C01_4096Hz_event_cleaned.npy')
cleaned_tseries = TimeSeries(data[1,:], times=data[0,:])

In [None]:
# plotting params
gps = 1264316116.5
crop_win, plot_win = 20, 2
start_crop, end_crop = gps - crop_win, gps + crop_win
start_plot, end_plot = gps - plot_win, gps + plot_win

# crop data for faster q transforms
orig_tseries_cropped = orig_tseries.crop(start_crop, end_crop)
cleaned_tseries_cropped = cleaned_tseries.crop(start_crop, end_crop)

# one panel per entry, drawn top to bottom in this order
dataset = ['orig', 'clean', 'diff']
q_trans = {
    'orig': orig_tseries_cropped.q_transform(outseg=(start_plot, end_plot), qrange=(10, 20)),
    'clean': cleaned_tseries_cropped.q_transform(outseg=(start_plot, end_plot), qrange=(10, 20)),
}
q_trans['diff'] = q_trans['orig'] - q_trans['clean']

# frequency-axis limits and colour-scale limits shared by all panels
ylim = (10, 512)
alim = (0, 25)

label = {
    'orig': 'Original data',
    'clean': 'Cleaned data',
    'diff': 'Original - Cleaned',
}

plot, axes = plt.subplots(nrows=3, sharex=True, figsize=(3.375*2.0, 3.375*3.0))

for key, ax in zip(dataset, axes):

    pcm = ax.imshow(q_trans[key], vmin=alim[0], vmax=alim[1])
    ax.set_ylim(*ylim)
    ax.set_xlabel('')
    ax.set_yscale('log')
    # invisible dummy line: it only exists so the legend can show the panel
    # title without a handle
    ax.plot([gps], 10, label=label[key], visible=False)
    ax.grid(alpha=0.6)
    ax.legend(loc='upper left', handlelength=0, handletextpad=0)

axes[1].set_ylabel(r"$\mathrm{Frequency \ (Hz)}$")
axes[-1].set_xlabel(r"$\mathrm{Time \ (seconds)}$")
# NOTE(review): Axes.colorbar is not a stock matplotlib method; this looks
# like it relies on the axes class gwpy installs -- confirm.
cbar = axes[0].colorbar(clim=alim, location='top')
cbar.set_label(r"$\mathrm{Normalized \ energy}$");