In [None]:
import numpy as np
import mayfly as mf
import h5py
import pandas as pd
import scipy
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import os 
import sys
import json
import scipy.signal
import scipy.stats
import scipy.interpolate
import pickle as pkl
import torch
import scipy.optimize

# Project root. Set the MAYFLY_PROJECT_PATH environment variable to run from a
# different location instead of editing this absolute, user-specific default.
PATH = os.environ.get('MAYFLY_PROJECT_PATH', '/storage/home/adz6/group/project/')
RESULTPATH = os.path.join(PATH, 'results/mayfly')
PLOTPATH = os.path.join(PATH, 'plots/mayfly')
DATAPATH = os.path.join(PATH, 'datasets/data')

def GetScoreFiles(path2scores, params, random=True):
    """Collect score files that match the requested (radius, angle, nsample) combinations.

    Each sub-directory of ``path2scores`` encodes its pitch angle right after
    ``'grid_'`` and its radius right before ``'cm'``
    (e.g. ``..._grid_87.0_1cm_template_scores_nslice2_random``). In a matching
    directory, a file whose name ends in ``nsample<N>.npy`` is returned when
    ``N`` equals the requested sample count.

    Args:
        path2scores: Directory holding one sub-directory per score set.
        params: Mapping with equal-length sequences under ``'r'`` (radius),
            ``'pa'`` (pitch angle) and ``'N'`` (number of samples). The i-th
            entries together form one requested combination.
        random: If True, use only directories whose name contains ``'random'``;
            if False, use only directories whose name does not.

    Returns:
        List of full paths to matching score files. Directories and files are
        visited in sorted order, so the result is deterministic.
    """
    score_file_list = []
    for dir_name in sorted(os.listdir(path2scores)):
        dir_path = os.path.join(path2scores, dir_name)
        if not os.path.isdir(dir_path):
            continue

        try:
            dir_angle = float(dir_name.split('grid_')[-1].split('_')[0])
            dir_rad = int(dir_name.split('cm')[0].split('_')[-1])
        except ValueError as err:
            # Directory name does not follow the grid naming scheme.
            print(err, dir_name)
            continue

        # Whether the directory is a random-signal set must agree with `random`
        # in both directions, so that random=False selects the non-random sets.
        if ('random' in dir_name) != random:
            continue

        for rad, angle, nsample in zip(params['r'], params['pa'], params['N']):
            if dir_rad != rad or dir_angle != angle:
                continue
            print(dir_name)
            for subfile in sorted(os.listdir(dir_path)):
                try:
                    sub_nsample = int(subfile.split('nsample')[-1].split('.npy')[0])
                except ValueError:
                    continue  # not a score file carrying an nsample tag
                if sub_nsample == nsample:
                    score_file_list.append(os.path.join(dir_path, subfile))

    return score_file_list

def GetTemplateFiles(score_list, template_path='/storage/home/adz6/group/project/datasets/data/dense_template_grid'):
    """Find the template dataset file(s) matching each score file.

    The pitch angle is read from the score file name right after ``'grid_'``,
    and the radius token (e.g. ``'1cm'``) from just before ``'_template'``. A
    template file matches when both appear as whole underscore-separated tokens
    of its name with the extension removed, so ``'1cm'`` no longer matches
    ``'11cm'`` and a token at the very start of the name is not missed.

    Args:
        score_list: Paths to score files (as returned by ``GetScoreFiles``).
        template_path: Directory to search; defaults to the dense template grid.

    Returns:
        Full paths of matching template files, grouped in ``score_list`` order
        and sorted within each group. A warning is printed for any score file
        that does not have exactly one match, because the result is paired
        positionally with ``score_list`` later on.
    """
    template_names = sorted(os.listdir(template_path))

    template_list = []
    for score_path in score_list:
        scores_name = score_path.split('/')[-1]
        angle = scores_name.split('grid_')[-1].split('_')[0]
        rad = scores_name.split('_template')[0].split('_')[-1]

        matches = []
        for template_file in template_names:
            tokens = os.path.splitext(template_file)[0].split('_')
            if angle in tokens and rad in tokens:
                print(template_file)
                matches.append(os.path.join(template_path, template_file))
        if len(matches) != 1:
            print(f'Warning: {len(matches)} template files match {scores_name}')
        template_list.extend(matches)

    return template_list
                
def GetSignalFiles(score_list, signal_path='/storage/home/adz6/group/project/datasets/data/dense_template_random'):
    """Find the random-signal dataset file(s) matching each score file.

    Matching works exactly as in ``GetTemplateFiles``. The angle (after
    ``'grid_'``) and the radius token (before ``'_template'``) from the score
    file name must both appear as whole underscore-separated tokens of the
    signal file name, with the extension removed.

    Args:
        score_list: Paths to score files (as returned by ``GetScoreFiles``).
        signal_path: Directory to search; defaults to the random-signal set.

    Returns:
        Full paths of matching signal files, grouped in ``score_list`` order
        and sorted within each group. A warning is printed for any score file
        that does not have exactly one match, because the result is paired
        positionally with ``score_list`` later on.
    """
    signal_names = sorted(os.listdir(signal_path))

    signal_list = []
    for score_path in score_list:
        scores_name = score_path.split('/')[-1]
        angle = scores_name.split('grid_')[-1].split('_')[0]
        rad = scores_name.split('_template')[0].split('_')[-1]

        matches = []
        for signal_file in signal_names:
            tokens = os.path.splitext(signal_file)[0].split('_')
            if angle in tokens and rad in tokens:
                print(signal_file)
                matches.append(os.path.join(signal_path, signal_file))
        if len(matches) != 1:
            print(f'Warning: {len(matches)} signal files match {scores_name}')
        signal_list.extend(matches)

    return signal_list


def SortScores(scores, energy_grid, angle_grid, random_index, norm_array, signal_metadata, template_metadata):
    """Arrange one signal's normalized template scores onto an (energy, angle) grid.

    Args:
        scores: 2-D array indexed ``[template_index, signal_index]``.
        energy_grid, angle_grid: Same-shape arrays giving the template energy
            and pitch angle (``theta_min``) of each output pixel.
        random_index: Column of ``scores`` (i.e. which signal) to map.
        norm_array: Per-template normalization factor, indexed like the rows
            of ``scores``.
        signal_metadata: Not used; kept for interface compatibility.
        template_metadata: DataFrame with ``'energy'`` and ``'theta_min'``
            columns. Its index labels are used as row indices into ``scores``
            and ``norm_array``.

    Returns:
        Flat array of length ``energy_grid.size`` in the grid's row-major order.
        A pixel with no template whose energy and angle exactly equal the grid
        point stays 0 (some simulations failed). Other errors, such as a missing
        column or an out-of-range index, are raised rather than swallowed.
    """
    # Pull the columns out once instead of re-indexing the DataFrame per pixel.
    template_energy = template_metadata['energy'].to_numpy()
    template_angle = template_metadata['theta_min'].to_numpy()

    sorted_scores = np.zeros(energy_grid.size)
    for j, (energy, angle) in enumerate(zip(energy_grid.flatten(), angle_grid.flatten())):
        matches = np.flatnonzero((template_energy == energy) & (template_angle == angle))
        if matches.size == 0:
            continue  # failed simulation: leave a zero pixel
        # First matching row, translated to its index label.
        _index = template_metadata.index[matches[0]]
        sorted_scores[j] = scores[_index, random_index] * norm_array[_index]

    return sorted_scores

def ScoreMap(score_file, template_file, signal_file, number_of_maps, seed=None):
    """Compute normalized matched-filter score maps for randomly chosen signals.

    The pitch angle, ``nslice`` and ``nsample`` are parsed from the score file
    name. The template data's last axis is reshaped to ``(60, -1)``, and only
    the first ``nslice * nsample`` samples of each of the 60 blocks are used to
    compute the per-template normalization.

    Args:
        score_file: Path to a ``.npy`` score array indexed ``[template, signal]``.
        template_file: Template dataset path (opened with ``mf.data.MFDataset``).
        signal_file: Random-signal dataset path (opened the same way).
        number_of_maps: Number of randomly chosen signals to map.
        seed: Optional seed for choosing the signals. ``None`` (the default)
            gives unseeded, non-reproducible choices, as before.

    Returns:
        Array of shape ``(number_of_maps, 101, 101)``. Rows span pitch angles
        ``central +/- 0.005`` and columns span energies 18575-18580.
    """
    # NOTE(review): presumably k_B (1.38e-23) * bandwidth (200 MHz) * T (10 K)
    # * R (50 ohm) -- confirm against the noise model used for the scores.
    var = 1.38e-23 * 200e6 * 10 * 50

    score_name = score_file.split('/')[-1]
    scores = np.load(score_file)

    angle = score_name.split('grid_')[-1].split('_')[0]
    nsample = int(score_name.split('_nsample')[-1].split('.npy')[0])
    nslice = int(score_name.split('_nslice')[-1].split('_')[0])

    energy_array = np.linspace(18575, 18580, 101)
    central_pitch_angle = float(angle)
    angle_array = np.linspace(central_pitch_angle - 0.005, central_pitch_angle + 0.005, 101)
    energy_grid, angle_grid = np.meshgrid(energy_array, angle_array)

    _template_data = mf.data.MFDataset(template_file)
    template_metadata = pd.DataFrame(_template_data.metadata)

    _signal_data = mf.data.MFDataset(signal_file)
    signal_metadata = pd.DataFrame(_signal_data.metadata)

    print('Loading Data')
    n_template = _template_data.shape[0]
    n_keep = nslice * nsample
    template_data = (
        _template_data.data[:]
        .reshape(n_template, 60, _template_data.shape[-1] // 60)[:, :, 0:n_keep]
        .reshape(n_template, n_keep * 60)
    )
    print('Calculating Normalization')
    norm_array = 1 / np.sqrt(var * np.abs(template_data * template_data.conjugate()).sum(-1))
    # Only the norms are needed from here on; release the large array early.
    del template_data

    # A single generator for all maps: reproducible when `seed` is given.
    rng = np.random.default_rng(seed)

    score_map_list = []
    for i in range(number_of_maps):
        print(f'Creating Map {i + 1}')
        random_index = rng.integers(0, scores.shape[-1], 1)[0]
        sorted_scores = SortScores(scores, energy_grid, angle_grid, random_index,
                                   norm_array, signal_metadata, template_metadata)
        score_map_list.append(sorted_scores.reshape(energy_grid.shape))

    return np.array(score_map_list)
    
    


# Get a list of score files

In [None]:
#os.listdir(os.path.join(RESULTPATH, 'scores'))

In [None]:
#print(score_list)

In [None]:
path2scores = os.path.join(RESULTPATH, 'scores')

rads = [1, 2, 3, 4]
samples = [8192]
angles = [86.0, 87.0, 88.0]

# Every (radius, angle, nsample) combination, flattened into parallel arrays.
rads, angles, samples = np.meshgrid(rads, angles, samples)

params = {'r': rads.flatten(), 'pa': angles.flatten(), 'N': samples.flatten()}

score_list = GetScoreFiles(path2scores, params)
template_list = GetTemplateFiles(score_list)
signal_list = GetSignalFiles(score_list)

# The next cell pairs these lists positionally with zip; a length mismatch
# would silently pair scores with the wrong template/signal files.
if not len(score_list) == len(template_list) == len(signal_list):
    raise ValueError(
        f'File lists are misaligned: {len(score_list)} score, '
        f'{len(template_list)} template and {len(signal_list)} signal files'
    )


# compute score maps

In [None]:
number_of_maps = 1

# One stacked array of score maps per (score, template, signal) file triple.
score_map_list = []
for score_file, template_file, signal_file in zip(score_list, template_list, signal_list):
    print(score_file.split('/')[-1])
    maps = ScoreMap(score_file, template_file, signal_file, number_of_maps)
    score_map_list.append(maps)
    
    

In [None]:
# Maps for the first (score, template, signal) triple: shape (number_of_maps, 101, 101).
score_map = score_map_list[0]
print(score_map.shape)

In [None]:
# First map of the first file triple (rows: pitch angle, columns: energy).
plt.imshow(score_map_list[0][0])
plt.colorbar()

In [None]:
# NOTE: uses `score_list_1`, `template_list` and `signal_list` from the
# file-list cell further down; run that cell first on a fresh kernel.
var = 1.38e-23 * 200e6 * 10 * 50
n_example = 5
n_samples = 2 * 8192  # samples kept per block (matches the nslice2 score files)
sns.set_theme(context = 'poster', style='ticks')

cmap = sns.color_palette('mako_r', as_cmap=True)
rng = np.random.default_rng()

for i, file_name in enumerate(score_list_1):
    scores = np.load(os.path.join(RESULTPATH, 'scores', file_name))

    central_pitch_angle = float(file_name.split('grid_')[-1].split('_')[0])
    radial_position = int(file_name.split('cm')[0].split('_')[-1])

    energy_array = np.linspace(18575, 18580, 101)
    angle_array = np.linspace(central_pitch_angle - 0.005, central_pitch_angle + 0.005, 101)
    energy_grid, angle_grid = np.meshgrid(energy_array, angle_array)

    _template_data = mf.data.MFDataset(os.path.join(PATH, 'datasets/data', 'dense_template_grid', template_list[i]))
    template_metadata = pd.DataFrame(_template_data.metadata)

    _signal_data = mf.data.MFDataset(os.path.join(PATH, 'datasets/data', 'dense_template_random', signal_list[i]))
    signal_metadata = pd.DataFrame(_signal_data.metadata)

    n_template = _template_data.shape[0]
    template_data = (
        _template_data.data[:]
        .reshape(n_template, 60, _template_data.shape[-1] // 60)[:, :, 0:n_samples]
        .reshape(n_template, n_samples * 60)
    )
    norm_array = 1 / np.sqrt(var * np.abs(template_data * template_data.conjugate()).sum(-1))
    del template_data  # only the norms are needed below

    for k in range(n_example):
        # Pick a random signal and lay its normalized scores out on the grid.
        random_index = rng.integers(0, scores.shape[-1], 1)[0]
        sorted_scores = SortScores(scores, energy_grid, angle_grid, random_index,
                                   norm_array, signal_metadata, template_metadata)

        fig, ax = plt.subplots(figsize=(13, 8))
        img = ax.imshow(sorted_scores.reshape(energy_grid.shape),
                        aspect='auto',
                        interpolation='none',
                        extent=(18575, 18580, central_pitch_angle + 0.005, central_pitch_angle - 0.005),
                        cmap=cmap)
        fig.colorbar(img, label='Score')
        ax.set_yticks(np.linspace(central_pitch_angle - 0.005, central_pitch_angle + 0.005, 5))
        ax.set_xlabel('Energy (eV)')
        ax.set_ylabel('Pitch Angle (deg)')
        ax.set_title(f'Matched Filter Scores\n T = 10K, Samples = {n_samples}')

        name = f'220112_mf_score_map_random_{central_pitch_angle}_{radial_position}cm_example_{k}.png'
        fig.savefig(os.path.join(PATH, 'plots', 'mayfly', name))
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the loop
    



# compute match using all templates, plot match histograms

In [None]:
# Hand-picked 0 cm runs, one row per pitch angle:
# (score file, template-grid file, random-signal file).
# NOTE: this rebinds `template_list` and `signal_list`, replacing the
# full-path lists built by GetTemplateFiles/GetSignalFiles earlier.
score_list_1, template_list, signal_list = (
    list(column)
    for column in zip(
        ('211129_sens_est_dense_grid_84.5_0cm_template_scores_nslice2_random.npy',
         '211129_sens_est_dense_grid_84.5_0cm.h5',
         '211221_sens_est_dense_grid_84.5_0cm_random.h5'),
        ('220107_sens_est_dense_grid_87.0_0cm_template_scores_nslice2_random.npy',
         '220107_sens_est_dense_grid_87.0_0cm.h5',
         '211221_sens_est_dense_grid_87.0_0cm_random.h5'),
        ('211129_sens_est_dense_grid_89.5_0cm_template_scores_nslice2_random.npy',
         '211129_sens_est_dense_grid_89.5_0cm.h5',
         '211221_sens_est_dense_grid_89.5_0cm_random.h5'),
    )
)

In [None]:
sns.set_theme(context='poster')

n_samples = 2 * 8192  # samples kept per block (matches the nslice2 score files)

for i, file_name in enumerate(score_list_1):
    scores = np.load(os.path.join(RESULTPATH, 'scores', file_name))
    print(scores.shape)

    central_pitch_angle = float(file_name.split('grid_')[-1].split('_')[0])
    radial_position = int(file_name.split('cm')[0].split('_')[-1])

    _signal_data = mf.data.MFDataset(os.path.join(PATH, 'datasets/data', 'dense_template_random', signal_list[i]))
    signal_data = _signal_data.data[:]
    n_signal = signal_data.shape[0]
    signal_data = (
        signal_data.reshape(n_signal, 60, signal_data.shape[-1] // 60)[:, :, 0:n_samples]
        .reshape(n_signal, 60 * n_samples)
    )

    # Signal power sum(|s|^2), used as the "ideal score" denominator of the match.
    ideal_scores = np.abs(signal_data * signal_data.conjugate()).sum(-1)
    # Best score over all templates (axis 0) for each signal.
    match = scores.max(0) / ideal_scores

    fig, ax = plt.subplots(figsize=(13, 8))
    ax.hist(match, 50)
    ax.set_xlabel('Match (Best Score / Ideal Score)')
    ax.set_ylabel('N')
    ax.set_title(f'Pitch angle {central_pitch_angle} deg, {radial_position} cm')
    fig.tight_layout()

    name = f'220112_match_histogram_all_templates_{central_pitch_angle}_{radial_position}cm.png'
    fig.savefig(os.path.join(PATH, 'plots', 'mayfly', name))
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop


# Down-select template grids

In [None]:


energies_base = np.linspace(18575, 18580, 101)
angles_base = np.linspace(89.5 - 0.005, 89.5 + 0.005, 101)

n_point = 2  # keep every n_point-th grid point, in row-major order

energy_grid, angle_grid = np.meshgrid(energies_base, angles_base)

# Indices of the retained points, computed once and shared by x, y and the count.
keep = np.arange(0, energy_grid.size, n_point)
energy_kept = energy_grid.flatten()[keep]
angle_kept = angle_grid.flatten()[keep]

fig = plt.figure(figsize=(13, 13))
ax = fig.add_subplot(1, 1, 1)
ax.plot(energy_kept, angle_kept, '.', markersize=10)

ax.set_xlabel('Energy (eV)')
ax.set_ylabel('Pitch Angle (deg)')
ax.set_title(f'Example Down-selected Grid, N = {energy_kept.size}')

name = '220111_example_down_select_grid_4'

# plt.savefig(os.path.join(PATH, 'plots', 'mayfly', name))

# animate grids

In [None]:
sns.set_theme(context='poster')

energies_base = np.linspace(18575, 18580, 101)
angles_base = np.linspace(89.5 - 0.005, 89.5 + 0.005, 101)

energy_grid, angle_grid = np.meshgrid(energies_base, angles_base)

def animate(k):
    """Animation frame: show every (k + 1)-th grid point and title the frame with k."""
    stride = k + 1
    keep = np.arange(0, 101 * 101, stride)
    line.set_xdata(energy_grid.flatten()[keep])
    line.set_ydata(angle_grid.flatten()[keep])
    ax.set_title(k)
    return (line,)


fig = plt.figure(figsize=(13, 13))
ax = fig.add_subplot(1, 1, 1)

# Start from the full grid; `animate` thins it out frame by frame.
line, = ax.plot(energy_grid.flatten(), angle_grid.flatten(), '.', markersize=10)
ax.set_title(1)

# NOTE(review): `import matplotlib` alone may not load the animation submodule;
# add `import matplotlib.animation` before enabling the lines below.
#ani = matplotlib.animation.FuncAnimation(fig, animate, frames = np.arange(1, 1000, 20), interval=700)
#ani.save(os.path.join(PATH, 'plots', 'mayfly', 'test.gif', ))

# Down-select template grid, plot several histograms
