In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import h5py
import os
import sys
import scipy
import damselfly as df
import mayfly as mf
import scipy.signal
import scipy.stats
import scipy.interpolate

# Root of the project tree; the locations below are all derived from it.
PATH = '/storage/home/adz6/group/project'
# Result and plot directories (presumably output locations; not referenced
# by the cells visible in this notebook).
RESULTPATH = os.path.join(PATH, 'results/damselfly')
PLOTPATH = os.path.join(PATH, 'plots/damselfly')
# Dataset directory: the input signal file is read from here and the
# generated deep-learning file is written to its 'dl' sub-directory.
DATAPATH = os.path.join(PATH, 'datasets/data')
#SIMDATAPATH = os.path.join(PATH, 'damselfly/data/sim_data')

"""
Date: 6/25/2021
Description: template
"""

def GetFreqLabel(signal_metadata, index, kass_data, kass_metadata):
    """Return the starting cyclotron frequency for one signal.

    The row at position ``index`` of ``signal_metadata`` is matched against
    ``kass_metadata`` by exact equality on the 'energy', 'theta_min' and
    'x_min' columns. The index label of the first matching row selects the
    entry ``kass_data['fc'][kass_index, 0]``, which is returned.

    Raises:
        IndexError: if no row of ``kass_metadata`` matches.
    """
    signal_row = signal_metadata.iloc[index]

    # The beamforming process mixes up the indices, so the Kass entry has to
    # be located by its parameters rather than by position.
    is_match = (
        (kass_metadata['energy'] == signal_row['energy'])
        & (kass_metadata['theta_min'] == signal_row['theta_min'])
        & (kass_metadata['x_min'] == signal_row['x_min'])
    )
    kass_index = kass_metadata[is_match].index[0]

    # The label is the starting cyclotron frequency (column 0 of 'fc').
    return kass_data['fc'][kass_index, 0]

def CreateGroups(h5file, config):
    """Create the empty 'train', 'val' and 'test' groups in ``h5file``.

    Every group receives a float32 'data' dataset, a 1-D float32 'label'
    dataset and a 'meta' sub-group holding 1-D float32 'energy', 'x_min' and
    'theta_min' datasets. 'train' is sized from ``config['train_shape']``;
    'val' and 'test' are sized from ``config['test_shape']``. The 1-D
    datasets have one entry per element of the first axis of that shape.
    """
    meta_keys = ('energy', 'x_min', 'theta_min')

    for split in ('train', 'val', 'test'):
        h5file.create_group(split)
        group = h5file[split]
        group.create_group('meta')

        # Only the training group uses the training shape.
        shape = config['train_shape'] if split == 'train' else config['test_shape']

        group.create_dataset('data', shape, dtype='f4')
        group.create_dataset('label', (shape[0],), dtype='f4')

        for key in meta_keys:
            group['meta'].create_dataset(key, (shape[0],), dtype='f4')
        

def GetEnergyLabel(signal_metadata, index):
    """Return the 'energy' entry of the metadata row at position ``index``."""
    row = signal_metadata.iloc[index]
    return row['energy']

def DataSlicer(data, islice, inds, nch, slicesize):
    """Return one time slice of the selected signals, split per channel.

    Rows ``inds`` of the 2-D ``data`` are selected and each row is reshaped
    to ``(nch, data.shape[-1] // nch)``. Samples
    ``islice * slicesize`` to ``(islice + 1) * slicesize`` of every channel
    are returned, giving an array of shape ``(inds.size, nch, <=slicesize)``.

    Args:
        data: 2-D array-like supporting ``data[inds, :]`` and ``.shape``.
        islice: zero-based index of the slice to extract.
        inds: integer index array of the rows to select (``.size`` is used).
        nch: number of channels each row is split into.
        slicesize: number of samples per channel in one slice.

    The shape of the result is printed as a progress/debug aid.
    """
    # Select and reshape once; the previous version evaluated the whole
    # expression twice (once just to print its shape).
    selected = data[inds, :]
    per_channel = selected.reshape((inds.size, nch, data.shape[-1] // nch))

    start = islice * slicesize
    window = per_channel[:, :, start:start + slicesize]

    print(window.shape)

    return window

#def AddNoise(data, var):
#    rng = np.random.default_rng()
    
#    noise = rng.multivariate_normal([0,0], np.eye(2) * var / 2, data.size)
    
#    noise = noise[:, 0] + 1j * noise[:, 1]
    
#    return data + noise.reshape(data.shape)

# same signals in each dataset. Noise to be added during training to save storage space.

def CreateDLDataset(config, data, metadata):
    """Write the train/val/test deep-learning dataset to an HDF5 file.

    For every group, every slice and every chunk of signals, the time-domain
    slice is taken with ``DataSlicer``, Fourier transformed along the last
    axis (fft-shifted and divided by ``config['slicesize']``) and stored as
    two channels: index ``2 * islice`` holds the real part and
    ``2 * islice + 1`` the imaginary part. Labels and the 'energy', 'x_min'
    and 'theta_min' metadata columns are stored alongside.

    Args:
        config: dict providing 'name' (output file path), 'label'
            ('energy' or 'class'), 'nchunk', 'nslice', 'nch', 'slicesize',
            plus the 'train_shape'/'test_shape' entries read by
            ``CreateGroups``.
        data: object whose ``.data`` attribute is the 2-D signal array.
        metadata: DataFrame with 'energy', 'x_min' and 'theta_min' columns,
            one row per signal.

    The output path is taken from ``config['name']`` (previously an implicit
    module-level global), and the file is closed even if writing fails.
    """
    nsignal = data.data.shape[0]

    # Build the labels before touching the output file so that a failure
    # here does not leave a truncated file behind.
    # Any label type other than 'energy' or 'class' (e.g. the disabled
    # 'freq' option) leaves the labels at zero.
    labels = np.zeros(nsignal)
    if config['label'] == 'energy':
        for i in range(nsignal):
            labels[i] = GetEnergyLabel(metadata, i)
    elif config['label'] == 'class':
        labels = np.ones(nsignal)

    # Signals are processed in chunks to bound the memory used per FFT.
    chunk_inds = np.array_split(np.arange(0, nsignal, 1), config['nchunk'])

    with h5py.File(config['name'], 'w') as h5file:

        CreateGroups(h5file, config)

        # NOTE: every group receives the same signals, so the same spectra
        # are recomputed for each group; noise is added at training time.
        for grp in ('train', 'val', 'test'):
            print(f'Starting {grp}')
            for islice in range(config['nslice']):

                for inds in chunk_inds:

                    # array_split yields empty chunks when nchunk > nsignal.
                    if inds.size == 0:
                        continue

                    data_slice = DataSlicer(data.data, islice, inds, config['nch'], config['slicesize'])

                    data_fft = np.fft.fftshift(np.fft.fft(data_slice, axis=-1), axes=(-1)) / config['slicesize']

                    h5file[grp]['data'][inds, 2 * islice, :, :] = data_fft.real
                    h5file[grp]['data'][inds, 2 * islice + 1, :, :] = data_fft.imag

            h5file[grp]['label'][:] = labels

            h5file[grp]['meta']['energy'][:] = np.array(metadata['energy'].array)
            h5file[grp]['meta']['x_min'][:] = np.array(metadata['x_min'].array)
            h5file[grp]['meta']['theta_min'][:] = np.array(metadata['theta_min'].array)
            print(f'Done with {grp}')
    
    


In [None]:
# Show the files available in the dataset directory.
os.listdir(DATAPATH)

In [None]:
# Show the files available in the project's 'datasets/kass' directory.
os.listdir(os.path.join(PATH, 'datasets', 'kass'))

# load data

In [None]:
# signal data
# Open the signal file from DATAPATH; `data.data` is the signal array used
# by the dataset builder below.
data = mf.data.MFDataset(os.path.join(DATAPATH, '211027_84_25_2cm.h5'))
# Wrap the metadata in a DataFrame so its columns ('energy', 'x_min',
# 'theta_min') can be looked up per signal.
metadata = pd.DataFrame(data.metadata)

# kass data
# h5kass_data = h5py.File(os.path.join(PATH, 'datasets', 'kass', '211129_sens_est_dense_grid_84.5_0cm_kass.h5'), 'r')

#kass_data = h5kass_data['kass']
#kass_metadata = {}
#for i, key in enumerate(h5kass_data['meta'].keys()):
#    kass_metadata[key] = h5kass_data['meta'][key][:]
    
#kass_metadata = pd.DataFrame.from_dict(kass_metadata)

In [None]:
# Inspect the shape of the raw signal array (first axis indexes signals).
data.data.shape

In [None]:
# Scratch check: 1474560 // 60 == 24576. Presumably the flattened sample
# count of one signal divided by the 60 channels - TODO confirm.
1474560 // 60


# define output dataset parameters

In [None]:
# Dimensions of the raw signal array. Read through `data.data`, as every
# other cell and CreateDLDataset do, instead of relying on `data.shape`.
nsignal = data.data.shape[0]
nsample = data.data.shape[1]

# Number of chunks the signals are split into while computing the FFTs.
nchunk = 4

# same signals in train, test, val sets. Different noise samples added to signals at run time
#ncopies_train = 10
#ncopies_test = 4

nslice = 1
ninput_ch = 2 # real, imag
slicesize = 2 * 8192
nantenna = 60

# Layout of one stored example: (real/imag per slice, antenna, frequency bin).
train_shape = (nsignal, nslice * ninput_ch, nantenna, slicesize)
test_shape = (nsignal, nslice * ninput_ch, nantenna, slicesize)

# Noise model parameters (units presumably K, Hz and ohm - TODO confirm).
noise_temp = 10
fsample = 200e6
system_z = 50
# One channel per antenna; kept in sync with `nantenna` instead of repeating 60.
nch = nantenna
kB = 1.38e-23  # Boltzmann constant

noise_var = kB * nch * noise_temp * system_z * fsample
noise_var_per_bin = noise_var / slicesize

# Output file, written below DATAPATH/dl.
name = os.path.join(DATAPATH, 'dl', '211203_dl_classification_84_25_2cm_slice1_sample2x8192_no_sum.h5')
label = 'class'

config = {
    'train_shape': train_shape,
    'test_shape': test_shape,
    'nsignal': nsignal,
    'nsample': nsample,
    'nslice': nslice,
    'ninput_ch': ninput_ch,
    'slicesize': slicesize,
    'noise_temp': noise_temp,
    'fsample': fsample,
    'system_z': system_z,
    'nch': nch,
    'noise_var': noise_var,
    'noise_var_per_bin': noise_var_per_bin,
    'name': name,
    'label': label,
    'nchunk': nchunk,
}


CreateDLDataset(config, data, metadata)


