In [None]:
import moth
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import csv
import scipy
import scipy.signal
import h5py

# Project root on the cluster; every data location below is built from it.
PATH = '/storage/home/adz6/group/project'

# Dataset directories: VNA sweeps and combined digitizer (DAQ) HDF5 files.
VNADATAPATH = os.path.join(PATH, 'datasets/data/lab/vna')
DAQDATAPATH = os.path.join(PATH, 'datasets/data/lab/dig')

# Raw lab measurement tree (the per-measurement CSV directories live under here).
LABDATAPATH = os.path.join(PATH, 'labdata')


def CombineMeasurements(name, path, lo_f=25.8e9, start_sample=8192, nch=4):
    """Combine a directory tree of digitizer CSV traces into one HDF5 file.

    Directory layout walked by this function:
        path/<meas_file>/<rad_sweep_file>/<angle>.csv

    For every measurement directory ``meas_file`` an HDF5 group of the same
    name is created holding the scalar datasets ``step``, ``z`` (taken from
    ``ExtractSweepParameters``) and ``lo_f``. For every radial-sweep
    directory inside it, one dataset named after that directory is written,
    containing the LO-corrected traces for all angles (see
    ``_LoadRadialSweep``).

    Args:
        name: output HDF5 file path; opened in 'w' mode, so it is overwritten.
        path: root directory of the measurement tree.
        lo_f: LO frequency stored in every group (default 25.8e9; presumably
            Hz -- confirm).
        start_sample: index of the first sample kept from every trace.
        nch: channel count every trace is zero-padded up to.
    """
    # Gather parameters before creating the output file so a parse failure
    # here does not leave a truncated HDF5 file behind.
    params = ExtractSweepParameters(path)

    # ExtractSweepParameters fills its lists in os.listdir(path) order; the
    # index ``i`` below relies on this listing having the same order.
    meas_files = os.listdir(path)

    # The context manager guarantees the file is closed even if a trace fails
    # to parse part-way through the (long) conversion.
    with h5py.File(name, 'w') as h5file:
        for i, meas_file in enumerate(meas_files):
            print(meas_file)
            meas_dir = os.path.join(path, meas_file)

            group = h5file.create_group(meas_file)
            group['step'] = params['step'][i]
            group['z'] = params['z'][i]
            group['lo_f'] = lo_f

            rad_sweep_files = os.listdir(meas_dir)
            for j, rad_sweep_file in enumerate(rad_sweep_files):
                traces = _LoadRadialSweep(
                    os.path.join(meas_dir, rad_sweep_file), start_sample, nch
                )
                group.create_dataset(rad_sweep_file, data=traces)

                # progress report every 5 radial sweeps
                if j % 5 == 4:
                    print(f'{j + 1} / {len(rad_sweep_files)}')

    print(params)


def _LoadRadialSweep(rad_dir, start_sample, nch):
    """Parse, pad, calibrate and LO-correct every angle CSV in one radial-sweep directory.

    Each file in ``rad_dir`` is expected to be named '<integer angle>.csv'.
    Traces are zero-padded to ``nch`` channels, converted with
    ``ConvertToVolt``, DC-removed with ``RemoveDC``, stacked in ascending angle
    order, truncated to samples ``start_sample:`` and passed to ``CorrectLO``.

    Returns:
        The array returned by ``CorrectLO`` for the stacked
        (Nangle, Nch, Nsample) traces.
    """
    angles = []
    traces_by_angle = []
    for ang_sweep_file in os.listdir(rad_dir):
        angles.append(int(ang_sweep_file.split('.csv')[0]))

        traces = ParseDigTrace(os.path.join(rad_dir, ang_sweep_file))

        # append zeros to the trace if the digitizer skips channels (known issue)
        if traces.shape[0] < nch:
            traces = np.concatenate(
                (traces, np.zeros((nch - traces.shape[0], traces.shape[-1])))
            )
        traces_by_angle.append(RemoveDC(ConvertToVolt(traces)))

    # os.listdir order is arbitrary; stack acquisitions in ascending angle order
    sort_inds = np.argsort(angles)
    return CorrectLO(np.asarray(traces_by_angle)[sort_inds, :, start_sample:])
    
def ExtractSweepParameters(path):
    """Collect per-measurement sweep metadata from a directory tree.

    For every entry ``meas_file`` of ``os.listdir(path)`` (in that order) this
    appends to each list in the returned dict:
        'step'    -- integer parsed between 'step' and the next 'mm' in the name
        'z'       -- integer parsed between the last 'z' and the next 'mm'
        'nrad'    -- number of entries in ``path/meas_file``
        'nangle'  -- number of entries in ``path/meas_file/0``
        'nsample' -- last-axis length of one example trace from ``.../0``
        'nch'     -- first-axis length (channel rows) of that example trace

    Directory names are expected to look like 'range3cm_step1mm_z0mm'; a name
    without integer 'step...mm' / 'z...mm' fields raises ValueError.
    """
    params = {'step': [], 'z': [], 'nrad': [], 'nangle': [], 'nsample': [], 'nch': []}

    for meas_file in os.listdir(path):
        meas_dir = os.path.join(path, meas_file)

        params['step'].append(int(meas_file.split('step')[-1].split('mm')[0]))  # mm
        params['z'].append(int(meas_file.split('z')[-1].split('mm')[0]))  # mm

        params['nrad'].append(len(os.listdir(meas_dir)))

        # Radial sweep '0' is used as the representative for angle count and
        # trace shape; list it once and reuse the listing.
        first_rad_dir = os.path.join(meas_dir, '0')
        angle_files = os.listdir(first_rad_dir)
        params['nangle'].append(len(angle_files))

        example_trace = ParseDigTrace(os.path.join(first_rad_dir, angle_files[0]))
        params['nsample'].append(example_trace.shape[-1])
        params['nch'].append(example_trace.shape[0])

    return params

def ParseDigTrace(path):
    """Read one digitizer CSV file into a 2-D float32 array, one row per CSV line.

    Rows whose values sum to 1 or less are dropped. Presumably this skips
    channels the digitizer left empty/zeroed (TODO confirm), which is why the
    caller zero-pads back up to the expected channel count.

    Args:
        path: path to the CSV file.

    Returns:
        float32 array of shape (n_kept_rows, n_samples). If no row is kept
        the result has shape (0, 0), so it stays 2-D and callers can still
        concatenate padding rows onto it.
    """
    traces = []
    # newline='' is required by the csv module so it handles line endings
    # itself (see the csv.reader documentation).
    with open(path, 'r', newline='') as infile:
        for item in csv.reader(infile):
            row = np.float32(item)
            if row.sum() > 1:
                traces.append(row)

    if not traces:
        return np.empty((0, 0), np.float32)
    return np.array(traces, np.float32)
                
    
def ConvertToVolt(raw_trace, dig_range=(2, 0.5, 0.5, 0.5), nbit=14):
    """Scale raw ADC counts row by row: row i is multiplied by dig_range[i] / 2**nbit.

    Args:
        raw_trace: array-like of shape (n_channels, n_samples); n_channels
            must equal len(dig_range).
        dig_range: per-channel scale (presumably the full-scale input range
            of each digitizer channel in volts -- confirm). A tuple default
            avoids a shared mutable default argument.
        nbit: ADC resolution in bits (default 14, matching the previously
            hard-coded 2**14).

    Returns:
        Array of the same shape as raw_trace, scaled per channel.
    """
    conversion_factor = np.asarray(dig_range) / (2 ** nbit)

    # column vector so each factor broadcasts across that channel's samples
    return conversion_factor.reshape((-1, 1)) * np.asarray(raw_trace)

def RemoveDC(traces):
    """Subtract each trace's mean taken along the last (sample) axis.

    Works for input of any rank: every 1-D slice along the last axis ends up
    with zero mean. For the 2-D (n_channels, n_samples) input used in this
    file the result is identical to the previous reshape-based version.
    Unlike that version, 1-D input no longer gains a spurious leading axis.

    Args:
        traces: array-like whose last axis is the sample axis.

    Returns:
        Array of the same shape with the per-trace mean removed.
    """
    traces = np.asarray(traces)
    # keepdims leaves a length-1 sample axis so the mean broadcasts back
    # against traces for any number of leading axes.
    return traces - np.mean(traces, axis=-1, keepdims=True)

def CorrectLO(traces):
    """Align the LO phase of every acquisition to that of acquisition 0.

    Args:
        traces: real array of shape (Nangle, Nch, Nsample); channel 0 is the
            LO reference channel.

    Returns:
        Complex array of shape (Nangle, Nch, Nsample): the analytic signal
        (scipy.signal.hilbert along the sample axis) of every channel. Each
        acquisition k is multiplied by exp(-1j * (phi_k - phi_0)). Here phi_k is
        the phase of the strongest FFT bin of channel 0's analytic signal in
        acquisition k.
    """
    traces = np.asarray(traces)

    # Analytic signal of all channels, computed once and reused both for the
    # LO phase estimate and for the returned result.
    analytic = scipy.signal.hilbert(traces, axis=-1)

    lo_fft = np.fft.fft(analytic[:, 0, :], axis=-1)  # LO channel spectrum
    peak_bins = np.argmax(np.abs(lo_fft), axis=-1)

    # Pair row k with its own peak bin directly, rather than building an
    # (Nangle, Nangle) matrix and taking its diagonal.
    phase_at_max = np.angle(lo_fft[np.arange(lo_fft.shape[0]), peak_bins])

    phase_diff = phase_at_max - phase_at_max[0]  # relative to 0th acquisition
    correction = np.exp(-1j * phase_diff).reshape((-1, 1, 1))

    return analytic * correction


    

In [None]:
# Display the entries of the raw digitizer measurement tree (LABDATAPATH/dig);
# path2data in a later cell points at one of these.
os.listdir(os.path.join(LABDATAPATH, 'dig'))

In [None]:
# Display the combined HDF5 files already written to the DAQ dataset directory.
os.listdir(DAQDATAPATH)

In [None]:
# Input: raw CSV tree of the '211213_cres2_meas_3ch_10slot' digitizer measurement.
path2data = os.path.join(LABDATAPATH, 'dig', '211213_cres2_meas_3ch_10slot')

# Output: HDF5 file of the same name, written into the DAQ dataset directory.
name = '211213_cres2_meas_3ch_10slot.h5'
save_file = os.path.join(DAQDATAPATH, name)

In [None]:

# Parse every angle CSV under path2data and (over)write the combined HDF5 file at save_file.
CombineMeasurements(save_file, path2data)


# Sanity check: load a debug HDF5 file and plot a few traces to confirm the CSV data is parsed correctly.

In [None]:
# Small debug output file used only for the parsing sanity check below.
debug_data = os.path.join(DAQDATAPATH, '211213_cres2_meas_3ch_debug.h5')

In [None]:
# Open the debug file read-only; it is closed by data.close() in a later cell,
# so run that cell even if the read below fails.
data = h5py.File(debug_data, 'r')

In [None]:
# Load radial sweep '0' of measurement 'range3cm_step1mm_z0mm' fully into memory.
array1 = data['range3cm_step1mm_z0mm/0'][()]

In [None]:
# Release the HDF5 handle opened above; array1 is an in-memory copy and stays usable.
data.close()

In [None]:
# Overlay the real part of channel 1 for the first three acquisitions along
# axis 0 (angle order, assuming the file was written by CombineMeasurements),
# zoomed to the first 200 samples.
fig, ax = plt.subplots(figsize=(13, 8))

for acq in range(3):
    ax.plot(array1[acq, 1].real)

ax.set_xlim(0, 200)