In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import h5py
import os
import scipy.stats as stats
from scipy.interpolate import griddata
from scipy.integrate import simpson

# Global seaborn styling for every figure in this notebook.
sns.set_theme(context='poster')

def GetTemplateFiles(
    config,
    template_path='/storage/home/adz6/group/project/datasets/data/dense_template_grid',
):
    """Return paths of the template files matching every (r, pa) pair in ``config``.

    File names are parsed as follows: the angle is the token right after
    ``'grid_'`` (up to the next ``'_'``), and the radius is the last
    ``'_'``-separated token before ``'cm.h5'``.
    Files that do not parse that way are skipped.

    Parameters
    ----------
    config : dict
        Must contain ``'r'`` (radii, compared as ints) and ``'pa'`` (angles,
        compared as floats). Every combination of the two is looked up.
    template_path : str
        Directory holding the template files.

    Returns
    -------
    list of str
        Matching paths, grouped by (r, pa) pair in ``np.meshgrid`` order.
        Within a pair they are sorted by file name, so the order is reproducible.
    """
    # List and parse the directory once. The original did this once per
    # parameter pair. Sorting makes e.g. ``paths[0]`` deterministic across
    # filesystems; bare ``os.listdir`` order is arbitrary.
    parsed_files = []
    for template_file in sorted(os.listdir(template_path)):
        angle = template_file.split('grid_')[-1].split('_')[0]
        rad = template_file.split('cm.h5')[0].split('_')[-1]
        try:
            parsed_files.append((int(rad), float(angle), template_file))
        except ValueError:
            # Not a template following the naming scheme (e.g. a stray README).
            continue

    radii, angles = np.meshgrid(config['r'], config['pa'])

    template_list = []
    for radius, angle in zip(radii.flatten(), angles.flatten()):
        for file_rad, file_angle, template_file in parsed_files:
            if radius == file_rad and angle == file_angle:
                template_list.append(os.path.join(template_path, template_file))

    return template_list

# Template grid selection used by GetTemplateFiles:
#   'r'  -> compared to the radius token in the file name (the part before 'cm.h5')
#   'pa' -> compared to the angle token that follows 'grid_'
config = dict(
    r=[1],
    pa=[86.0, 87.0, 88.0],
)


def ModifySignalLength(signal, sig_length, length):
    """Truncate ``signal`` to ``length`` columns and zero everything from ``sig_length`` on.

    Parameters
    ----------
    signal : array_like, 2-D
        Each row is processed independently along the last axis.
    sig_length : int
        Number of leading samples to keep. Values <= 0 zero the whole window.
        Values >= ``length`` leave the window unchanged.
    length : int
        Number of columns to keep.

    Returns
    -------
    ndarray
        A new array. ``signal`` itself is not modified.
    """
    signal_array = np.asarray(signal)[:, :length].copy()
    # Clamp at 0: a negative start index would otherwise zero only the *last*
    # |sig_length| columns instead of the whole window. Slicing past the end
    # is a no-op, so no separate ``sig_length <= length`` check is needed.
    signal_array[:, max(sig_length, 0):] = 0
    return signal_array


def ModifySignalStartTime(signal, diff, length):
    """Place ``signal`` in a ``length``-sample complex64 window, shifted by ``diff`` samples.

    Window sample ``k`` holds signal sample ``k - diff``:
    - ``diff >= 0`` delays the signal (the window starts with ``diff`` zeros).
    - ``diff < 0`` advances it (its first ``-diff`` samples fall before the window).

    Window samples with no corresponding signal sample are zero. This covers a
    signal too short to reach the end of the window, or shifted entirely out
    of it. Previously those cases raised a broadcast error.

    Parameters
    ----------
    signal : array_like, 2-D
        Rows are shifted independently along the last axis.
    diff : int
        Shift in samples.
    length : int
        Window length in samples.

    Returns
    -------
    ndarray of complex64, shape ``(signal.shape[0], length)``
    """
    signal_array = np.zeros((signal.shape[0], length), dtype=np.complex64)

    dst_start = max(diff, 0)   # first window sample that receives data
    src_start = max(-diff, 0)  # first signal sample that lands in the window
    n_copy = min(length - dst_start, signal.shape[1] - src_start)

    if n_copy > 0:
        signal_array[:, dst_start:dst_start + n_copy] = signal[:, src_start:src_start + n_copy]

    return signal_array


def ModifySignal(signal, sig_length, diff, length):
    """Place a ``sig_length``-sample signal in a ``length``-sample window starting at sample ``diff``.

    The signal is first shifted by ``diff`` samples with ``ModifySignalStartTime``.
    It is delayed when ``diff >= 0`` and advanced when ``diff < 0``. Afterwards it
    occupies window samples ``[diff, diff + sig_length)``, so everything from
    ``diff + sig_length`` on is zeroed with ``ModifySignalLength``.

    Parameters
    ----------
    signal : ndarray, 2-D
        Source samples; must be at least as long as the part placed in the window.
    sig_length : int
        Length of the signal in samples.
    diff : int
        Start of the signal relative to the start of the window, in samples.
    length : int
        Window length in samples.

    Returns
    -------
    ndarray, shape ``(signal.shape[0], length)``
    """
    mod_signal = ModifySignalStartTime(signal, diff, length)

    signal_end = sig_length + diff
    if signal_end >= length:
        # The signal reaches (or passes) the end of the window: nothing to cut.
        return mod_signal

    # The signal ends inside the window. Clamp at 0 so that a signal ending
    # before the window even starts (diff < 0 and sig_length < -diff)
    # yields an all-zero window rather than a negative slice index.
    return ModifySignalLength(mod_signal, max(signal_end, 0), length)

    
def ComputeScoreMaps(templates, length = 8192, mapsize=51, interleave_ind = 2, rng=None, n_channels=60):
    """Score one randomly chosen template against shifted / truncated copies of itself.

    A template is drawn at random from ``templates['data']`` and reshaped into
    ``n_channels`` rows. Its first ``length`` samples are FFT'd and normalised
    to form the matched-filter template. The same samples also serve as the
    test signal, which ``ModifySignal`` places into a ``length``-sample window
    for every combination of start offset and signal length.

    Parameters
    ----------
    templates : mapping with a ``'data'`` entry (e.g. an open ``h5py.File``)
        Assumes ``templates['data']`` is 2-D (n_templates, n_samples), with
        n_samples divisible by ``n_channels`` -- TODO confirm.
    length : int
        Window / FFT length in samples.
    mapsize : int
        Number of start offsets (rows) and of signal lengths (columns).
    interleave_ind : int or float
        Start offsets span ``[0, length / interleave_ind]`` samples.
    rng : numpy.random.Generator, int or None
        Generator or seed used to pick the template. ``None`` draws fresh
        entropy, as before; pass a seed to make the map reproducible.
    n_channels : int
        Number of rows each flat template is reshaped into.

    Returns
    -------
    (score_grid, test_score_grid) : tuple of ndarray, each ``(mapsize, mapsize)``
        ``score_grid[n, i]`` scores start offset ``n`` / signal length ``i``
        against the normalised template. ``test_score_grid[n, i]`` scores the
        same windowed signal against itself, ``sqrt(energy / var)``.
    """
    rng = np.random.default_rng(rng)

    # Rows: start-time offsets in samples. Columns: signal lengths in samples.
    start_time_diff_list = np.int32(np.linspace(0, length / interleave_ind, mapsize))
    signal_lengths = np.int32(np.linspace(1, 2 * length, mapsize))

    # NOTE(review): looks like k_B * T * R * bandwidth / length with T = 10,
    # R = 50, bandwidth = 200e6 -- confirm units before reusing elsewhere.
    var = 1.38e-23 * 10 * 50 * 200e6 / length

    rand_int = rng.integers(0, templates['data'].shape[0], 1)
    raw_template = np.asarray(templates['data'][rand_int, :])
    signal = raw_template.reshape(n_channels, raw_template.shape[-1] // n_channels)

    # Matched-filter template: FFT of the first `length` samples (cropped, or
    # zero-padded if shorter), scaled so that var * <template, template> == 1.
    template = np.fft.fft(signal, n=length, axis=-1) / length
    template_norm = 1 / np.sqrt(var * np.vdot(template, template))
    template = template_norm * template

    score_grid = np.zeros((mapsize, mapsize))
    test_score_grid = np.zeros((mapsize, mapsize))

    for n, start_time_diff in enumerate(start_time_diff_list):
        for i, sig_length in enumerate(signal_lengths):
            window = ModifySignal(signal, sig_length, start_time_diff, length)
            window = np.fft.fft(window, axis=-1) / length

            # NOTE(review): abs() is applied per element before summing, so the
            # phases of the bins are not combined coherently -- confirm intent.
            score_grid[n, i] = np.abs(window.conjugate() * template).sum()

            # Self-match score: sum|s|^2 / sqrt(var * sum|s|^2) == sqrt(energy / var).
            # Computed directly, an empty window (e.g. offset == length) scores 0
            # instead of 1/0 * 0 = NaN with a divide-by-zero warning.
            energy = np.vdot(window, window).real
            test_score_grid[n, i] = np.sqrt(energy / var)

    return (score_grid, test_score_grid)


def AnalyticMatchMap(template_length, shape, interleave_ind = 2):
    """Build the (work-in-progress) analytic match map.

    The layout follows ``EstimateMeanTemporalMatch``:
    - Rows are start offsets.
    - Columns are signal lengths in units of the template length, running from
      ``1 / template_length`` (one sample) to 30.

    So far only the length dependence for signals no longer than one template
    is filled in, as ``sqrt(signal_length)``. Every other entry is 0.

    Parameters
    ----------
    template_length : int
        Template length in samples. It sets the shortest signal length.
    shape : tuple of int
        ``(n_start_offsets, n_signal_lengths)``.
    interleave_ind : int or float
        Currently unused; kept for signature compatibility.

    Returns
    -------
    ndarray of shape ``shape``
    """
    max_sig_length = 30
    # Signal lengths run along the columns (axis 1), so size them with shape[1].
    signal_lengths = np.linspace(1 / template_length, max_sig_length, shape[1])  # units of template lengths

    match_map = np.zeros(shape)

    # Boolean mask: fills *every* length <= 1, including the last one. It also
    # works when only a single length qualifies.
    short = signal_lengths <= 1
    match_map[:, short] = np.sqrt(signal_lengths[short])

    return match_map
    



def EstimateMeanTemporalMatch(templates, mean_track_length, template_length=8192, temporal_mapsize=51, interleave_ind = 2):
    """Average the temporal match over start offsets and an exponential track-length pdf.

    Steps:
    1. Take the score maps from ``ComputeScoreMaps`` and divide the score by the
       self-score to get a match map (rows: start offsets, columns: signal lengths).
    2. Extend the map past 2 template lengths with its last column, held constant.
    3. Resample it (nearest neighbour) onto ``n_pdf`` track lengths up to
       ``max_track_len``.
    4. Weight it with an exponential pdf of mean ``mean_track_length``.
    5. Integrate over track length with Simpson's rule and average over start offsets.

    Track lengths, including ``mean_track_length``, are in units of
    ``template_length``.

    Parameters
    ----------
    templates : mapping with a ``'data'`` entry
        Passed through to ``ComputeScoreMaps``.
    mean_track_length : float
        Scale (mean) of the exponential track-length pdf.
    template_length : int
        Template / window length in samples.
    temporal_mapsize : int
        Size of the score map (>= 3).
    interleave_ind : int or float
        Start offsets span ``[0, template_length / interleave_ind]`` samples.

    Returns
    -------
    float
        Mean over start offsets of the pdf-weighted integral of the match.
    """
    n_pdf = 601
    max_track_len = 30  # in units of template lengths
    x_pdf = np.linspace(2 / template_length, max_track_len, n_pdf)

    # Axes of the score map, built with the same recipe ComputeScoreMaps uses.
    signal_lengths = np.int32(np.linspace(1, 2 * template_length, temporal_mapsize)) / template_length
    signal_start_offsets = np.int32(np.linspace(0, template_length / interleave_ind, temporal_mapsize))

    # Beyond 2 template lengths, continue on a uniform grid up to max_track_len.
    length_step = signal_lengths[2] - signal_lengths[1]
    signal_length_extension = np.linspace(2 + length_step, max_track_len, n_pdf - temporal_mapsize)
    signal_length_nodes = np.concatenate([signal_lengths, signal_length_extension])

    score_grid, self_score_grid = ComputeScoreMaps(
        templates, length=template_length, mapsize=temporal_mapsize, interleave_ind=interleave_ind
    )
    # A window that holds no signal can occur, e.g. offset == template_length
    # when interleave_ind == 1. Its self-score is 0 or NaN. Define its match as 0
    # instead of letting a NaN propagate into the final mean.
    match_grid = np.zeros_like(score_grid)
    np.divide(score_grid, self_score_grid, out=match_grid, where=self_score_grid > 0)

    # Hold the match constant for tracks longer than 2 template lengths.
    extended_match_grid = np.zeros((temporal_mapsize, n_pdf))
    extended_match_grid[:, :temporal_mapsize] = match_grid
    extended_match_grid[:, temporal_mapsize:] = match_grid[:, -1:]

    node_lengths, node_offsets = np.meshgrid(signal_length_nodes, signal_start_offsets)
    pdf_lengths, pdf_offsets = np.meshgrid(x_pdf, signal_start_offsets)
    match_interpolation = griddata(
        (node_lengths.ravel(), node_offsets.ravel()),
        extended_match_grid.ravel(),
        (pdf_lengths, pdf_offsets),
        method='nearest',
    )

    expon_probabilities = stats.expon.pdf(x_pdf, scale=mean_track_length)

    # Integrate each start-offset row over track length. `x` is passed by
    # keyword because it is keyword-only in recent SciPy releases.
    weighted_match = simpson(match_interpolation * expon_probabilities, x=x_pdf, axis=-1)

    return float(np.mean(weighted_match))


In [None]:
template_file_paths = GetTemplateFiles(config)
# Fail with a clear message instead of an opaque IndexError on `[0]`.
if not template_file_paths:
    raise FileNotFoundError(
        f"no template files match r={config['r']} and pa={config['pa']}"
    )
# Only the first matching file is opened (read-only); later cells use this handle.
templates = h5py.File(template_file_paths[0], 'r')

In [None]:
# Draw one template index at random (unseeded, so it changes run to run).
rng = np.random.default_rng()
rand_int = rng.integers(low=0, high=templates['data'].shape[0], size=1)

# Fetch that template and split its flat sample axis into 60 equal rows.
template = np.reshape(templates['data'][rand_int, :], (60, -1))
signal = template.copy()

# Compute score grids

In [None]:
# (score_grid, self_score_grid) for a 51 x 51 grid of start offsets x signal lengths.
score_maps = ComputeScoreMaps(templates, length=8192, mapsize=51, interleave_ind=2)

# Estimate mean temporal match

In [None]:
# NOTE(review): this was 8192 + 4069, presumably a transposition typo for
# 4096 (= 2**12), the value the interleave_ind=1 run below uses -- confirm.
EstimateMeanTemporalMatch(templates, 5, template_length = 8192 + 4096, temporal_mapsize=21)

In [None]:
# Same estimate, with start offsets spanning a full template length.
EstimateMeanTemporalMatch(
    templates,
    5,
    template_length=8192 + 4096,
    temporal_mapsize=21,
    interleave_ind=1,
)

In [None]:
# Analytic match map on a 500 x 500 grid for an 8192-sample template.
AnalyticMatchMap(8192, shape=(500, 500))