In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import h5py
import os

# Global seaborn theme for the notebook: "poster" plotting context (large fonts/lines).
sns.set_theme(context='poster')

def GetTemplateFiles(config):
    template_path = '/storage/home/adz6/group/project/datasets/data/dense_template_grid'
    
    radii, angles = np.meshgrid(config['r'], config['pa'])
    
    template_list = []
    for i, param_grp in enumerate(zip(radii.flatten(), angles.flatten())):
        for j, template_file in enumerate(os.listdir(template_path)):
            #scores_name = score_path.split('/')[-1]

            angle = template_file.split('grid_')[-1].split('_')[0]
            rad = template_file.split('cm.h5')[0].split('_')[-1]
            
            if (param_grp[0] == int(rad)) and (param_grp[1] == float(angle)):
                #print(template_file)
                template_list.append(os.path.join(template_path, template_file))

    return template_list

config = {
    'r': [1,],
    'pa': [86.0,87.0, 88.0],
}

def SliceData(data, nslice, slicesize):
    """Cut a (channel, sample) array into ``nslice`` consecutive time windows.

    Parameters
    ----------
    data : ndarray
        2-D array indexed as (channel, sample).
    nslice : int
        Number of windows to cut.
    slicesize : int
        Samples per window.

    Returns
    -------
    ndarray of complex128
        Shape (slice, channel, slice_samples) with
        ``out[n, c, k] == data[c, n * slicesize + k]``; samples beyond
        ``nslice * slicesize`` are discarded. If ``data`` has fewer than
        ``nslice * slicesize`` samples, a message is printed and the whole of
        ``data`` is returned as a single slice of shape (1, channel, sample).
    """
    if data.shape[-1] < nslice * slicesize:
        print('Not enough samples for requested slicing parameters')
        # Prepend a length-1 slice axis. The shape must be spread, not nested:
        # ``reshape((1, data.shape))`` raises TypeError.
        return np.array(data, dtype=np.complex128)[np.newaxis, ...]

    nchannel = data.shape[0]
    # (channel, nslice * slicesize) -> (channel, slice, k) -> (slice, channel, k)
    windows = data[:, :nslice * slicesize].reshape(nchannel, nslice, slicesize)
    # order='C' forces a fresh contiguous copy, never a view into ``data``.
    return np.array(windows.transpose(1, 0, 2), dtype=np.complex128, order='C')
    
def FFTSignal(signal):
    """FFT along the last (sample) axis, scaled by 1 / number of samples."""
    n_samples = signal.shape[-1]
    spectrum = np.fft.fft(signal, axis=-1)
    return spectrum / n_samples
        
    
def GenerateTemplate(template_signal, length, fft=True, norm=True):
    """Build a matched-filter template from the first ``length`` samples of each row.

    Parameters
    ----------
    template_signal : ndarray
        2-D array; only columns ``[0, length)`` are used.
    length : int
        Number of samples kept (also the FFT scaling factor).
    fft : bool
        If True, FFT each row and divide by ``length``.
    norm : bool
        If True, scale the result by ``1 / sqrt(scale * <t, t>)``, where
        ``scale`` is ``var`` in the FFT case and ``length * var`` otherwise.

    Returns
    -------
    ndarray
        The (optionally transformed and normalized) template window.
    """
    # NOTE(review): 1.38e-23 matches Boltzmann's constant; 50, 10 and 200e6 are
    # presumably impedance (ohm), temperature (K) and bandwidth (Hz) -- TODO confirm.
    var = 1.38e-23 * 50 * 10 * 200e6 / length
    window = template_signal[:, 0:length]

    if fft:
        window = np.fft.fft(window, axis=-1) / length
        energy_scale = var
    else:
        energy_scale = length * var

    flat = window.ravel()
    normalization = 1 / np.sqrt(energy_scale * np.vdot(flat, flat))

    return normalization * window if norm else window

def CalculateSliceScores(signal, template, ):
    """Score every slice of ``signal`` against ``template``.

    For slice ``k`` the score is ``sum(|conj(signal[k]) * template|)``, i.e. the
    sum of element-wise magnitudes of the product.
    NOTE(review): this differs from the coherent inner-product magnitude
    ``|sum(conj(signal[k]) * template)|`` -- confirm which one is intended.

    Returns
    -------
    ndarray of float64, shape (signal.shape[0],)
    """
    per_slice = [
        abs(signal[k, :].conjugate() * template).sum()
        for k in range(signal.shape[0])
    ]
    return np.array(per_slice, dtype=np.float64)

def CalculateSliceMatches(signal, template, template_signal):
    """Per-slice scores of ``signal`` divided by the template's own score.

    Both numerator and denominator use the same measure,
    ``sum(|conj(x) * template|)``; the denominator uses ``x = template_signal``
    (the un-normalized template). A value of 1 means the slice scores exactly
    like the template itself.

    Returns
    -------
    ndarray of float64, shape (signal.shape[0],)
    """
    raw_scores = np.array(
        [abs(signal[k, :].conjugate() * template).sum() for k in range(signal.shape[0])],
        dtype=np.float64,
    )
    reference = abs(template_signal.conjugate() * template).sum()
    return raw_scores / reference

def ModifySignalLength(signal, sig_length, sig_start_offset,):
    """Zero every sample of a sliced signal that lies after the signal's end.

    Parameters
    ----------
    signal : ndarray
        3-D array indexed (slice, channel, sample); one slice is one FFT window.
    sig_length : float
        Signal duration, in FFT windows.
    sig_start_offset : float
        Signal start time, in FFT windows.

    Returns
    -------
    ndarray
        A copy of ``signal``. If the signal ends inside the data
        (``sig_length <= n_slices - sig_start_offset``), everything from the end
        time ``sig_length + sig_start_offset`` onwards is zeroed. Otherwise the
        copy is returned untouched.
    """
    truncated = np.copy(signal)
    n_windows = signal.shape[0]

    # Signal runs past the last window: nothing to cut.
    if not (sig_length <= (n_windows - sig_start_offset)):
        return truncated

    sig_end = sig_length + sig_start_offset
    end_slice = np.int32(np.floor(sig_end))
    # Fractional position of the end inside ``end_slice``, converted to a sample index.
    divisor = end_slice if end_slice != 0 else 1
    end_sample = np.int32(signal.shape[-1] * (sig_end % divisor))

    for islice in range(end_slice, n_windows):
        if islice == end_slice:
            truncated[islice, :, end_sample:] = 0  # partial window holding the end
        else:
            truncated[islice, :, :] = 0  # windows entirely after the end

    return truncated


def ModifySignalStartTime(signal, diff):
    """Delay a sliced signal by ``diff`` samples across slice boundaries.

    Slices along axis 0 are treated as consecutive pieces of one continuous
    record. After the shift, slice ``i`` contains the last ``diff`` samples of
    slice ``i - 1`` followed by its own first ``samples - diff`` samples. The
    first slice is zero-padded at the front, and the final ``diff`` samples of
    the last slice are dropped.

    Parameters
    ----------
    signal : ndarray
        3-D array indexed (slice, channel, sample), e.g. the output of ``SliceData``.
    diff : int
        Delay in samples; must satisfy ``0 <= diff <= signal.shape[-1]``.

    Returns
    -------
    ndarray
        New array with the same shape. The dtype is the input dtype promoted to
        at least complex64, so complex128 input (as produced by ``SliceData``)
        keeps its full precision.

    Raises
    ------
    ValueError
        If ``diff`` is outside ``[0, signal.shape[-1]]``.
    """
    nsamp = signal.shape[-1]
    if not 0 <= diff <= nsamp:
        raise ValueError(f'diff must be in [0, {nsamp}], got {diff}')
    keep = nsamp - diff

    shifted = np.zeros(signal.shape, dtype=np.result_type(signal.dtype, np.complex64))
    # Each slice's own head moves right by ``diff`` samples...
    shifted[:, :, diff:] = signal[:, :, :keep]
    # ...and, for every slice but the first, the opened gap is filled with the
    # tail of the previous slice (the first slice keeps zeros there).
    shifted[1:, :, :diff] = signal[:-1, :, keep:]

    return shifted


def ModifySignal(signal, sig_length, diff, length):
    """Delay a sliced signal by ``diff`` samples and, if it is short, truncate it.

    Parameters
    ----------
    signal : ndarray
        Sliced data indexed (slice, channel, sample), as produced by ``SliceData``.
    sig_length : float
        Signal duration in samples. It is compared directly against
        ``length - diff``, so it must be in the same units as ``length`` and ``diff``.
    diff : int
        Start-time delay in samples (``ModifySignalStartTime`` requires
        ``0 <= diff <= signal.shape[-1]``).
    length : int
        Samples per template window. Presumably equal to ``signal.shape[-1]``
        (the FFT window length) -- TODO confirm.

    Returns
    -------
    ndarray
        The shifted (and possibly truncated) signal.
    """
    # Always delay the start; ModifySignalStartTime takes only (signal, diff).
    mod_signal = ModifySignalStartTime(signal, diff)

    # If the signal ends before the remainder of the window after the delay,
    # zero everything past its end. ModifySignalLength measures both the length
    # and the start offset in FFT windows, hence the division by ``length``.
    if sig_length < (length - diff):
        mod_signal = ModifySignalLength(mod_signal, sig_length / length, diff / length)

    return mod_signal
        

In [None]:
template_file_paths = GetTemplateFiles(config)
if not template_file_paths:
    raise FileNotFoundError(f'No template files match config {config}')
# Left open on purpose: later cells read templates['data'] lazily from this file.
templates = h5py.File(template_file_paths[0], 'r')

In [None]:
# List the template files matched by ``config`` (the first one is opened above).
print(template_file_paths)

In [None]:
# Fixed seed so that Restart & Run All selects the same template every time.
SEED = 2024
# The flat template row is split into this many equal-length rows
# (assumed channel count -- TODO confirm against the dataset).
N_CHANNELS = 60

rng = np.random.default_rng(SEED)
# Scalar index -> a single 1-D row (an array index would return a (1, N) block).
rand_int = int(rng.integers(0, templates['data'].shape[0]))

template = templates['data'][rand_int, :]
template = template.reshape(N_CHANNELS, template.shape[-1] // N_CHANNELS)

# The noiseless test signal is the template itself.
signal = np.copy(template)

In [None]:
# Single-point check: delay the signal, keep only part of it, and score each
# slice against the template before running the full grid scan.
signal_start_offset = 1024  # samples
signal_length = 1.2  # FFT windows
fft_length = 8192
nslice = 3

template_signal = GenerateTemplate(template, fft_length, norm=False)
template_norm = GenerateTemplate(template, fft_length)

sliced_signal = SliceData(signal, nslice, fft_length)
mod_signal = FFTSignal(
    ModifySignalLength(
        ModifySignalStartTime(sliced_signal, signal_start_offset),
        signal_length,
        signal_start_offset / fft_length,
    )
)

print(CalculateSliceScores(mod_signal, template_norm))
print(CalculateSliceMatches(mod_signal, template_norm, template_signal))

In [None]:
# Slicing parameters for the grid scan below; the start offset and signal
# length are swept there rather than fixed here.
fft_length, nslice = 8192, 3


In [None]:
grid_size = 201

# score_matrix[i, j, s] / match_matrix[i, j, s]: start offset i, length j, slice s.
score_matrix = np.zeros((grid_size, grid_size, nslice))
match_matrix = np.zeros((grid_size, grid_size, nslice))

# Signal lengths in FFT windows (1 sample up to all nslice windows).
signal_length_array = np.linspace(1, nslice * fft_length, grid_size ) / fft_length
# Start offsets in samples (0 up to one full window).
signal_start_offset_array = np.int32(np.linspace(0, fft_length, grid_size))

# These do not depend on the grid point, so compute them once rather than
# grid_size**2 times (each GenerateTemplate call runs a full FFT).
sliced_signal = SliceData(signal, nslice, fft_length)
template_signal = GenerateTemplate(template, fft_length, norm=False)
template_norm = GenerateTemplate(template, fft_length)

for i, signal_offset in enumerate(signal_start_offset_array):
    # The start-time shift depends only on the offset; ModifySignalLength works
    # on a copy, so the shifted array can be reused across the inner loop.
    shifted_signal = ModifySignalStartTime(sliced_signal, signal_offset)

    for j, signal_length in enumerate(signal_length_array):
        mod_signal = ModifySignalLength(shifted_signal, signal_length, signal_offset / fft_length)
        mod_signal = FFTSignal(mod_signal)

        score_matrix[i, j, :] = CalculateSliceScores(mod_signal, template_norm)
        match_matrix[i, j, :] = CalculateSliceMatches(mod_signal, template_norm, template_signal)

        if (i % 10 == 9) and (j % 10 == 9):
            print(f'( {i + 1}, {j + 1} )')

        

In [None]:
sns.set_theme(style='ticks')
fig, ax = plt.subplots(figsize=(13, 8))
# Rows follow signal_start_offset_array (i), columns follow signal_length_array (j).
img = ax.imshow(score_matrix[:, :, 0], aspect='auto')
ax.set(
    title='Score vs. start offset and signal length (slice 0)',
    xlabel='signal length index j',
    ylabel='start offset index i',
)
cbar = fig.colorbar(img, ax=ax, label='score')

In [None]:
sns.set_theme(style='ticks')
fig, ax = plt.subplots(figsize=(13, 8))
# Rows follow signal_start_offset_array (i), columns follow signal_length_array (j).
img = ax.imshow(match_matrix[:, :, 0], aspect='auto')
ax.set(
    title='Match vs. start offset and signal length (slice 0)',
    xlabel='signal length index j',
    ylabel='start offset index i',
)
cbar = fig.colorbar(img, ax=ax, label='match')