In [None]:
import time
start = time.time()
#import psutil
import h5py
import numpy as np
import os
import damselfly as df
import matplotlib.pyplot as plt
import scipy.signal as signal
import argparse

#parser = argparse.ArgumentParser()
#parser.add_argument('pc_comp')
#parser.add_argument('signal_batch')
#args = parser.parse_args()

def shift_signal(signal, x_new, y_new, r_array=0.10, f_lo=25.86e9):
    """Phase-shift each channel as if the source moved from the origin to (x_new, y_new).

    The channels are modelled as antennas evenly spaced on a ring of radius
    ``r_array`` (metres): row k of ``signal`` belongs to the antenna at angle
    2*pi*k/nch. Each row is multiplied by ``exp(-1j * phase)`` with
    ``phase = 2*pi*(d_new - d_old) / wavelength``. Here ``d_old`` is the
    antenna's distance to the origin, ``d_new`` its distance to
    (x_new, y_new), and ``wavelength = 3e8 / f_lo``.

    Parameters
    ----------
    signal : ndarray, shape (nch, nsample)
        Complex channel data, channel axis first.
    x_new, y_new : float or array broadcastable to (nch, 1)
        New source position in metres.
    r_array : float, optional
        Ring radius in metres (default 0.10, the previously hard-coded value).
    f_lo : float, optional
        Frequency in Hz that sets the wavelength (default 25.86e9, the
        previously hard-coded value; presumably the LO frequency).

    Returns
    -------
    ndarray
        ``signal`` with a per-channel phase applied (same phase for every sample
        of a channel).
    """
    nch = signal.shape[0]
    angles = np.radians(np.arange(0, nch, 1) * 360 / nch)
    wavelength_lo = 3e8 / f_lo

    # Antenna positions as (nch, 1) columns: the phase broadcasts over the
    # sample axis instead of being materialised nsample times with repeat().
    x_antenna = (r_array * np.cos(angles))[:, np.newaxis]
    y_antenna = (r_array * np.sin(angles))[:, np.newaxis]

    d_old = np.sqrt(x_antenna ** 2 + y_antenna ** 2)
    d_new = np.sqrt((x_new - x_antenna) ** 2 + (y_new - y_antenna) ** 2)

    phase_shift = 2 * np.pi * (d_new - d_old) / wavelength_lo

    return np.exp(-1j * phase_shift) * signal

def track_length_weights(rng, nslice, mean=3, invert=False):
    """Draw a random track length and convert it into per-slice weights.

    A length ``L`` is drawn with ``rng.exponential(mean)``:

    * ``L >= nslice``: every slice gets weight 1.
    * otherwise the leading ``nslice - ceil(L)`` slices get 0, the next slice
      gets the fractional part ``1 - ceil(L) + L``, and the remaining slices get 1.
      The weights therefore sum to ``L``. ``L == 0`` yields all zeros; the old
      code raised IndexError there.

    If ``invert`` is true the weights are reversed with probability 1/2.

    Parameters
    ----------
    rng : numpy.random.Generator (or anything with ``exponential`` and ``integers``)
    nslice : int
        Number of slices.
    mean : float, optional
        Mean of the exponential length distribution, in units of slices.
    invert : bool, optional
        Allow a random reversal of the weight vector.

    Returns
    -------
    ndarray of float, shape (nslice,)
    """
    est_track_length = rng.exponential(mean)

    slice_weights = np.ones(nslice)

    if est_track_length >= nslice:
        return slice_weights

    N_empty_slice = int(nslice - np.ceil(est_track_length))
    slice_weights[:N_empty_slice] = 0.0

    # Only a strictly positive length leaves a (partially) filled slice; for
    # L == 0 every slice is empty and there is no index to write.
    if N_empty_slice < nslice:
        slice_weights[N_empty_slice] = 1 - np.ceil(est_track_length) + est_track_length

    # The coin is drawn before ``invert`` is checked (unchanged), so each call
    # consumes the same rng draws regardless of ``invert``.
    if rng.integers(0, 2) == 1 and invert:
        slice_weights = np.flip(slice_weights)

    return slice_weights

def add_noise(signal, var, rng=None):
    """Return ``signal`` plus circular complex white Gaussian noise.

    Real and imaginary parts are independent N(0, var/2), so each noise sample
    has E|n|^2 = var. This is the same distribution the previous
    ``multivariate_normal`` call with covariance ``eye(2) * var / 2`` produced.

    Parameters
    ----------
    signal : ndarray
        Array of any shape; the noise has the same shape.
    var : float
        Total (complex) noise variance per sample; must be >= 0.
    rng : numpy.random.Generator, optional
        Source of randomness. Defaults to a fresh unseeded ``default_rng()``
        (the previous behaviour); pass a seeded generator for reproducible noise.

    Returns
    -------
    ndarray (complex)
        ``signal + noise``.
    """
    if rng is None:
        rng = np.random.default_rng()

    scale = np.sqrt(var / 2)
    # Two independent normal draws replace the multivariate_normal call, which
    # factorised a diagonal covariance and required 2-D input.
    noise = rng.normal(0.0, scale, signal.shape) + 1j * rng.normal(0.0, scale, signal.shape)

    return signal + noise
    
# Complex noise variance per sample passed to add_noise(). 1.38e-23 is
# Boltzmann's constant. NOTE(review): the other factors look like a 10 K
# temperature, a 50 ohm load and a 200 MHz bandwidth spread over the
# 8192-sample record — confirm before reusing.
var = 1.38e-23 * 10 * 50 * 200e6 / 8192

# Index along axis 1 of pc_matrix and index of the batch processed by this run.
# These were meant to come from the (disabled) argparse block above.
IPCA = 0
IBATCH = 0

PATH = '/storage/home/adz6/group/project'
DAMSELPATH = os.path.join(PATH, 'damselfly')
SIMDATAPATH = os.path.join(DAMSELPATH, 'data/sim_data')
DATAPATH = os.path.join(DAMSELPATH, 'data/datasets')

## load pc ##
pc_mat_path = os.path.join(DAMSELPATH, 'data', '211019_spatial_shift_pc_energy_range.h5')
# Slicing an h5py dataset copies the selection into a NumPy array, so the file
# can be closed as soon as the slice is taken. Rows of select_pc are the
# templates projected against later (one per idist).
with h5py.File(pc_mat_path, 'r') as h5pca:
    select_pc = h5pca['pc_matrix'][:, IPCA, :]
## ##

## load data ##
data = os.path.join(SIMDATAPATH, '211019_84_100_1d2sl_nosum_fft.h5')
# Kept open: individual signals are read from it while the batch is built.
h5train_signals = h5py.File(data, 'r')
Ntrain_signal = len(h5train_signals['signal'])

Nslice = 2
Nsample = 8192
Nch = 60
## ##

## define parameters ##

signal_shape = (Nch, Nsample)
Ncopies_train = 20
# Size the noise-only set so that Nnoise / (Nsignal + Nnoise) == noise_frac.
noise_frac = 0.2
Nnoise = int(Ncopies_train * Ntrain_signal / (1 / noise_frac - 1))

rng = np.random.default_rng()

# Signal copies take indices [0, Ntrain_total); noise-only examples follow them.
Ntrain_total = Ncopies_train * Ntrain_signal
train_signal_indices = np.arange(Ntrain_total)
noise_signal_indices = np.arange(Ntrain_total, Ntrain_total + Nnoise)

# Split both index ranges into nsplit batches; this run handles batch IBATCH.
nsplit = 25000
train_signal_indicies_batch_list = np.array_split(train_signal_indices, nsplit)
noise_signal_indicies_batch_list = np.array_split(noise_signal_indices, nsplit)

train_batch = train_signal_indicies_batch_list[IBATCH]
noise_batch = noise_signal_indicies_batch_list[IBATCH]

# train_keys[k] is the integer key of the stored signal that copy k is made
# from: key 0 repeated Ncopies_train times, then key 1, and so on.
train_keys = np.repeat(np.arange(Ntrain_signal, dtype=np.int32), Ncopies_train)

## ##

## loop ##

## load signal batch ##
# One row per augmented copy in train_batch; each slice is stored flattened to
# Nch * Nsample samples.
signal_batch = np.zeros((train_batch.size, Nslice, Nch * Nsample), np.complex64)
for batch_ctr, isignal in enumerate(train_batch):
    # Named raw_signal (not `signal`) so the imported scipy.signal module is not shadowed.
    raw_signal = h5train_signals['signal'][f'{train_keys[isignal]}'][:]

    # NOTE(review): r_shift is drawn but unused — every copy is shifted by a
    # fixed (0.01, 0.00) below. Pass r_shift to shift_signal if a random
    # shift is intended.
    r_shift = rng.random(1,) * 5e-2 # random distance between 0 and 5 cm

    # Uniform weights; the random alternative is
    # track_length_weights(rng, Nslice, mean=2000, invert=True).
    slice_weights = np.ones(Nslice)

    for nslice in range(Nslice):
        shifted_signal = shift_signal(raw_signal[nslice, :], 0.01, 0.00)

        noise = add_noise(np.zeros(shifted_signal.shape), var)

        signal_batch[batch_ctr, nslice, :] = noise.flatten() + shifted_signal.flatten() * slice_weights[nslice]

## ##

## do signal_projection ##

def _max_circular_projection(batch, pcs, nslice):
    """Peak response of every batch row to every PC template, per slice.

    For slice ``islice`` and template ``pcs[idist]`` this computes
    ``ifft(fft(batch[:, islice, :]) * fft(flip(pcs[idist])))`` along the last
    axis. That is the circular convolution with the flipped template, i.e. an
    unconjugated circular cross-correlation with the template up to a lag
    relabelling. It stores, per batch row, the complex value at the lag of
    largest magnitude.
    NOTE(review): if the templates are complex, a matched filter would usually
    correlate with conj(pc) — confirm which is intended.

    Parameters
    ----------
    batch : ndarray, shape (nbatch, >= nslice, nsamp)
    pcs : ndarray, shape (ndist, nsamp)
    nslice : int
        Number of slices to project.

    Returns
    -------
    ndarray, complex64, shape (nbatch, nslice, ndist)
    """
    nbatch = batch.shape[0]
    ndist = pcs.shape[0]
    proj = np.zeros((nbatch, nslice, ndist), np.complex64)
    if nbatch == 0:
        return proj

    rows = np.arange(nbatch)
    for islice in range(nslice):
        # The batch FFT does not depend on the template: compute it once per slice.
        batch_fft = np.fft.fft(batch[:, islice, :], axis=-1)
        for idist in range(ndist):
            circ_conv = np.fft.ifft(batch_fft * np.fft.fft(np.flip(pcs[idist, :])), axis=-1)
            imax = np.argmax(np.abs(circ_conv), axis=-1)
            proj[:, islice, idist] = circ_conv[rows, imax]
    return proj


signal_proj = _max_circular_projection(signal_batch, select_pc, Nslice)

## ##

## load noise batch ##
# Noise-only examples: every slice is pure noise of shape signal_shape, flattened.
noise_only_batch = np.zeros((noise_batch.size, Nslice, Nch * Nsample), np.complex64)
for batch_ctr in range(noise_batch.size):
    for nslice in range(Nslice):
        noise = add_noise(np.zeros(signal_shape), var)
        noise_only_batch[batch_ctr, nslice, :] = noise.flatten()

## ##

## do noise projection ##
# Same filter as the signal branch (flipped template). Previously this branch
# used fft(temp_pc) without the flip, so the noise and signal projections were
# not computed the same way.
noise_proj = _max_circular_projection(noise_only_batch, select_pc, Nslice)

## ##

## ##

## save projections and report run time ##
# One archive per (batch, PC component) pair.
out_name = f'batch{IBATCH}_comp{IPCA}'
np.savez(out_name, noise_proj=noise_proj, signal_proj=signal_proj)

# Wall-clock seconds since `start` was recorded in the first cell.
end = time.time()
elapsed = end - start
print(f'Total time = {elapsed}')


In [None]:
# |noise-only projection| for one slice: rows are noise examples, columns are
# rows of select_pc. The printed peak is taken from the same slice that is
# plotted (previously slice 1 was printed under a slice-0 plot).
plot_slice = 0
fig, ax = plt.subplots()
im = ax.imshow(np.abs(noise_proj[:, plot_slice, :]), aspect='auto')
fig.colorbar(im, ax=ax)
ax.set(title=f'|noise projection|, slice {plot_slice}', xlabel='PC index (idist)', ylabel='noise example')

print(np.max(np.abs(noise_proj[:, plot_slice, :])))

In [None]:
# |signal projection| for one slice: rows are signal copies, columns are rows
# of select_pc. The printed peak is taken from the plotted slice.
plot_slice = 0
fig, ax = plt.subplots()
im = ax.imshow(np.abs(signal_proj[:, plot_slice, :]), aspect='auto')
fig.colorbar(im, ax=ax)
ax.set(title=f'|signal projection|, slice {plot_slice}', xlabel='PC index (idist)', ylabel='signal copy')

print(np.max(np.abs(signal_proj[:, plot_slice, :])))