In [None]:
import pathlib
import numpy as np
import more_itertools as mit
import re

from matplotlib import pyplot as plt
%matplotlib inline
#plt.style.use('dark_background')

from tqdm.notebook import trange, tqdm

from six.moves import cPickle

import os

from scipy import signal as spsig
from scipy.signal import argrelextrema
from scipy.signal import argrelmin

#### Utility Functions

In [None]:
def glob_files(path, pattern=r'**/*.dat'):
    """Yield every file under ``path`` that matches the glob ``pattern``.

    Parameters
    ----------
    path : str or pathlib.Path
        Root directory to search.
    pattern : str
        Glob pattern relative to ``path``; the default matches all ``.dat``
        files recursively (including ``path`` itself).
    """
    # Use the arguments rather than the notebook-level ``data_dir`` so the
    # function does not depend on hidden kernel state.
    yield from pathlib.Path(path).glob(pattern)
    

def _extract_digit(pattern, text):
    """ helper for extract_bw_and_sf """
    try:
        match = int(re.search(pattern, text)[1])
    except:
        print('ERROR! unable to extract bw or sf')
        return 
    else:
        return match 
    
    
def extract_bw_and_sf(filename, BW_val):
    """Parse the bandwidth code, spreading factor and attenuation from a filename.

    Searches ``filename`` for ``BW<digit>``, ``SF<digits>`` and ``Att<digits>``.
    The BW digit is translated through ``BW_val`` with a 1-based lookup
    (``BW_val[code - 1]``). Any field that cannot be found comes back as ``None``.

    Returns
    -------
    tuple
        ``(bw, sf, att)``
    """
    bw_code = _extract_digit(r'BW(\d)', filename)
    # NOTE(review): a BW code of 0 would index BW_val[-1]; confirm codes start at 1.
    bw = None if bw_code is None else BW_val[bw_code - 1]

    sf = _extract_digit(r'SF(\d{1,})', filename)
    att = _extract_digit(r'Att(\d{1,})', filename)
    return bw, sf, att

def check_and_load_file(filepath):
    """Return ``filepath`` as a ``pathlib.Path`` if it exists, otherwise ``None``.

    Despite the name, nothing is loaded; this only validates the path.
    Prints an error message when ``filepath`` is ``None`` or does not exist.
    """
    if filepath is None:
        print(f'ERROR! unable to find input file at:\n{filepath}')
        return None

    filepath = pathlib.Path(filepath)
    if not filepath.exists():
        print(f'ERROR! unable to find input file at:\n{filepath}')
        return None

    return filepath


def _average_packet_length():
    sampPerSym = np.round(((2**SF)/BW)*Fs)
    pcktLen = 30.25*sampPerSym
    return pcktLen



### get encoding parameters

def get_and_set_encoding_params(filepath, Fs=int(1e6)):
    """Derive LoRa encoding parameters from ``filepath``'s name and print them.

    Parameters
    ----------
    filepath : pathlib.Path
        File whose ``.name`` contains ``BW<d>``, ``SF<n>`` and ``Att<n>``.
    Fs : int
        Sample rate in Hz used to convert symbol duration into samples.

    Returns
    -------
    tuple
        ``(BW, SF, sampPerSym, Att)``. ``BW`` is the BW *code* from the filename
        (not Hz); ``sampPerSym`` is a rounded float.

    Raises
    ------
    ValueError
        If BW or SF cannot be parsed, or the BW code maps to a zero bandwidth.
    """
    # 1-based lookup from filename BW code to an index into BW_val2
    # (codes 3-6 map to 0, i.e. an unsupported bandwidth).
    BW_val = np.array([1, 2, 0, 0, 0, 0, 7, 8, 9])
    # Bandwidth in Hz indexed by the value from BW_val; 0 marks unsupported.
    BW_val2 = np.array([0, 10.4e3, 15.6e3, 0, 0, 0, 0, 125e3, 250e3, 500e3])
    BW, SF, Att = extract_bw_and_sf(filepath.name, BW_val)

    # Without these checks, BW_val2[None] silently yields an array and 2**None
    # raises an opaque TypeError.
    if BW is None or SF is None:
        raise ValueError(f'unable to parse BW/SF from filename: {filepath.name}')

    bw_hz = BW_val2[BW]
    if bw_hz == 0:
        raise ValueError(f'unsupported BW code {BW} in filename: {filepath.name}')

    sampPerSym = np.round(((2**SF) / bw_hz) * Fs)
    pcktLen = 30.25 * sampPerSym

    print(f'BW: {BW} | SF: {SF} | Fs: {Fs} | Samples Per Symbol: {sampPerSym} | Packet Length: {pcktLen}')
    return BW, SF, sampPerSym, Att


## load data
def load_data(filepath):
    """Read ``filepath`` as a flat array of ``complex64`` samples.

    Returns the array, or ``None`` after printing the error if reading fails.
    """
    try:
        samples = np.fromfile(filepath, dtype=np.complex64)
    except Exception as err:
        print(f'unable to load file:\n{err}')
        return None
    return samples

    
## normalize signal
def normalize_signal(signal, th=0.001):
    """Threshold the magnitude of the real part of ``signal``.

    Parameters
    ----------
    signal : array_like
        Complex (or real) samples.
    th : float
        Values of ``|Re(signal)|`` below ``th`` become 0; values at or above
        it are rounded up with ``ceil``.

    Returns
    -------
    real_s : ndarray
        ``|Re(signal)|``.
    norm_s : ndarray of float64
        Thresholded copy; its zeros mark the gaps between packets.
    """
    real_s = np.abs(np.real(signal))
    # Vectorized form of the former per-sample Python loop. Signals can be
    # millions of samples long, so the loop dominated runtime.
    # NaN compares False against th, so NaNs map to 0, exactly as before.
    norm_s = np.where(real_s >= th, np.ceil(real_s), 0.0).astype(float)
    return real_s, norm_s


### locate zero indices
def _locate_zero_indices(norm_signal):
    indices = np.where(norm_signal == 0.0)[0]
    
    return indices

### find consecutive groups 
def _find_consecutive_groups(zero_indexes):
    groups = [
        list(j) for j in 
        mit.consecutive_groups(sorted(list(set(zero_indexes))))
    ]
    
    lengths = np.array(
        [len(item) for item in groups]
    )
    
    #print(f'found {len(groups)} groups')
    return groups, lengths

### locate endpoints
def _locate_endpoints(groups, threshold=15_000):
    all_endpoints = [
        endpoints for group in groups
            if len(group) > threshold
        for endpoints in (group[0], group[-1])
    ]
    
    endpoint_pairs = [
        (start, stop) 
        for start, stop in zip(all_endpoints[1::2], all_endpoints[2::2])
    ]
    
    #print(f'Extracted {len(endpoint_pairs)} packets from signal')
    return endpoint_pairs

## extract endpoints
def plot_groups(indexes, lengths, filename):
    """Draw a two-panel validation figure and return it.

    The top panel shows the zero-sample indices; the bottom panel shows the
    lengths of the consecutive zero-runs. The figure is titled with ``filename``.
    """
    fig, (ax_idx, ax_len) = plt.subplots(2)
    fig.suptitle(f'{filename}')
    ax_idx.plot(indexes)
    ax_len.plot(lengths)
    return fig

def extract_endpoints(normalized_signal, filename, threshold=15_000):
    """Find packet ``(start, stop)`` spans in a thresholded signal.

    Locates zero samples and groups them into consecutive runs. It then draws
    a validation plot of both and returns the spans between runs longer than
    ``threshold``.
    """
    zeros_at = _locate_zero_indices(normalized_signal)
    runs, run_lengths = _find_consecutive_groups(zeros_at)
    plot_groups(zeros_at, run_lengths, filename)  # visual sanity check
    return _locate_endpoints(runs, threshold)

def extract_indices(signal, scale_factor=101, window_length=17_500):
    """Return local-minimum indices (in decimated samples) of the smoothed envelope.

    The signal is rectified, decimated by ``scale_factor`` and smoothed by
    ``preprocess``. Minima are then searched with a neighbourhood of one
    packet length converted to decimated samples. A comparison plot is drawn.
    """
    real_s, down_s = preprocess(signal, scale_factor, window_length)
    # NOTE(review): _average_packet_length() is called without arguments here,
    # so it depends on notebook-level globals.
    search_window = _average_packet_length() // scale_factor
    minima = find_minima(down_s, window=search_window)
    _plot_comparison(real_s, down_s, minima)
    return minima


def extract_endpoints_od(signal, scale_factor=101, window_length=17_500, samp_per_sym=None):
    """Estimate packet ``(start, stop)`` spans from minima of the smoothed envelope.

    Parameters
    ----------
    signal : array_like
        Raw samples.
    scale_factor : int
        Decimation factor; minima indices are scaled back by this.
    window_length : int
        Smoothing window, in full-rate samples.
    samp_per_sym : int, optional
        Samples per symbol, used to set the fixed packet length. When omitted,
        the notebook global ``sampPerSym`` is used (the previous, implicit
        behaviour).
    """
    minima = extract_indices(signal, scale_factor, window_length)

    if samp_per_sym is None:
        # Backward-compatible fallback to hidden notebook state.
        samp_per_sym = sampPerSym

    return _generate_endpoints(minima, scale_factor, samp_per_sym)


def find_minima(signal, window=8_000):
    """Return indices of relative minima of ``signal``.

    ``window`` is truncated to ``int`` and passed as ``order`` to
    ``scipy.signal.argrelmin``: a point must be strictly smaller than that many
    neighbours on each side. For N-D input, only the first index array is
    returned.
    """
    neighbourhood = int(window)
    index_arrays = spsig.argrelmin(signal, order=neighbourhood)
    return index_arrays[0]


def _generate_minima_plot(minima, size):
    minima_plot = np.zeros(size)
    
    for val in minima:
        minima_plot[val] = 1 
        
    return minima_plot


def _moving_average(a, n=3) :
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n


def _plot_comparison(real, down, minima):
    """Draw three stacked panels: raw ``real``, decimated ``down`` and minima markers.

    The first 100_000 raw samples and the first 1000 decimated samples are
    omitted from the plots.
    """
    fig, (ax_raw, ax_down, ax_min) = plt.subplots(3)
    fig.set_figwidth(10)
    fig.set_figheight(12)
    markers = _generate_minima_plot(minima, down.size)
    ax_raw.plot(real[100_000:])
    ax_down.plot(down[1000:])
    ax_min.plot(markers[1000:])




def preprocess(signal, scale_factor, window_length=17_500):
    """Rectify, decimate and smooth ``signal``.

    Returns
    -------
    real_s : ndarray
        ``|Re(signal)|`` at the original rate.
    down_s : ndarray
        ``real_s`` decimated by ``scale_factor`` (IIR anti-alias filter), then
        moving-averaged over ``window_length // scale_factor`` samples.
    """
    real_s = np.abs(np.real(signal))
    decimated = spsig.decimate(real_s, scale_factor, ftype='iir')
    smoothing = window_length // scale_factor
    down_s = _moving_average(decimated, n=smoothing)
    print(f'decimated by a factor of {scale_factor} | size: {down_s.size}')
    return real_s, down_s
    


## packet slicing
def _slice_and_pad(signal, endpoints, length):
    start, stop = endpoints
    sliced = signal[start:stop] 
    
    if len(sliced) < length:
        sliced = np.concatenate((
            sliced, np.zeros(length - len(sliced))
        ))
        
    return sliced


def slice_all_packets(signal, endpoints):
    """Stack each ``(start, stop)`` span of ``signal`` as one row of a 2-D array.

    Shorter spans are zero-padded on the right to the longest span's length.
    ``endpoints`` must be non-empty; ``max`` raises ``ValueError`` otherwise.
    """
    max_length = max(stop - start for start, stop in endpoints)
    print(f'got max packet length: {max_length}')
    rows = [_slice_and_pad(signal, pair, max_length) for pair in endpoints]
    return np.vstack(rows)


def _generate_endpoints(minima, scale, sampPerSym):
    packet_length = 30.25*sampPerSym
    return [
        (int(val*scale), int(val*scale+packet_length))
        for val in minima
    ]

def reshape_symbol_sets(d, num_samples):
    """Join consecutive symbols into rows of ``num_samples`` samples.

    ``d`` is 2-D, ``(num_symbols, sym_len)``. When ``sym_len < num_samples``,
    every ``R = ceil(num_samples / sym_len)`` consecutive rows are flattened
    end-to-end into one row, which is truncated to ``num_samples``. Leftover
    rows that do not fill a complete set are dropped. Otherwise ``d`` is
    returned unchanged.

    Returns
    -------
    ndarray
        Complex array of shape ``(num_symbols // R, num_samples)``, or ``d``.
    """
    num_symbols, sym_len = d.shape
    if sym_len >= num_samples:
        return d

    R = int(np.ceil(num_samples / sym_len))
    N = num_symbols // R

    # The previous loop iterated range(R, N, R), i.e. up to N instead of N*R.
    # It therefore filled at most N/R rows and returned the rest of an
    # np.empty buffer uninitialized. Reshaping the first N*R rows covers
    # every set.
    grouped = np.asarray(d[:N * R]).reshape(N, R * sym_len)
    # NOTE(review): truncation when num_samples is not a multiple of sym_len
    # is an assumption about intent (the old code raised a shape error there).
    return grouped[:, :num_samples].astype(complex)




#### Define Data Directory

In [None]:
# Root of the raw capture files; every matching *.dat file below it is listed.
data_dir = pathlib.Path('Raw_Data/Benchtop/Attenuated/New/')
all_data_files = list(glob_files(data_dir))
print(f"# of files: {len(all_data_files)}")

In [None]:
# Display the discovered raw-data file paths (the cell's last expression is rendered).
all_data_files

### Functions to extract packets and symbols, and to save data

In [None]:
def extract_packets(filepath, norm_thresh=0.001, packet_thresh=15_000):
    """Load a raw capture and cut it into zero-padded packets.

    Packets are the spans between runs of more than ``packet_thresh``
    consecutive samples that ``normalize_signal`` maps to zero, i.e. samples
    whose ``|Re(x)|`` is below ``norm_thresh``.

    Returns
    -------
    ndarray or list
        2-D array ``(n_packets, max_packet_len)`` on success. Otherwise the
        sentinel ``[0]``: when the file is missing or unreadable, or when no
        packets are found.
    """
    filepath = check_and_load_file(filepath)
    if filepath is None:
        # Previously None was passed on and crashed on ``filepath.name``.
        return [0]

    # Called only for its printed parameter summary; the values are unused here.
    get_and_set_encoding_params(filepath)

    signal = load_data(filepath)
    if signal is None:
        return [0]

    _, norm_s = normalize_signal(signal, norm_thresh)

    endpoint_pairs = extract_endpoints(norm_s, filepath.name, packet_thresh)
    print("Endpoint Pairs Extracted. Total: " + str(np.shape(endpoint_pairs)))

    if len(endpoint_pairs) == 0:
        return [0]

    all_packets = slice_all_packets(signal, endpoint_pairs)
    print("Total Packets: " + str(np.shape(all_packets)))
    return all_packets

def extract_symbols(packets, numPackets, sampPerSym, numSymbols=8):
    """Cut the first ``numSymbols`` symbols out of each of the first ``numPackets`` packets.

    Parameters
    ----------
    packets : array_like
        2-D array ``(n_packets, packet_len)``, as returned by ``extract_packets``.
    numPackets : int
        Number of leading packets (rows) to process.
    sampPerSym : int
        Samples per symbol.
    numSymbols : int
        Symbols taken from the start of every packet (default 8).

    Returns
    -------
    ndarray or list
        Complex array ``(numPackets * numSymbols, sampPerSym)`` with rows in
        packet-major order. Returns the sentinel ``[0]`` when ``packets`` is
        not 2-D or is too short to hold ``numSymbols`` symbols.

    Raises
    ------
    IndexError
        If ``numPackets`` exceeds the number of available packets.
    """
    packets = np.asarray(packets)
    needed = numSymbols * sampPerSym
    # Require room for ALL symbols, not just one. Otherwise the slice is
    # shorter than a symbol and the copy fails with a broadcast error.
    if packets.ndim != 2 or needed > packets.shape[1]:
        return [0]
    if numPackets > packets.shape[0]:
        raise IndexError(
            f'numPackets={numPackets} exceeds {packets.shape[0]} available packets'
        )

    # Row i of the result is packet (i // numSymbols), symbol (i % numSymbols).
    symbols = (
        packets[:numPackets, :needed]
        .reshape(numPackets * numSymbols, sampPerSym)
        .astype(complex)
    )

    print("# of symbols extract: " + str(symbols.shape[0]))
    return symbols

def save_data(symbols, dir_, filename):
    """Pickle ``symbols`` to ``dir_ + filename``, appending if the file exists.

    If the file already exists, its contents are loaded, ``symbols`` is
    concatenated after them along axis 0, and the combined array is written
    back. ``dir_`` is joined by plain string concatenation, so it must end with
    a path separator.

    NOTE: unpickling can execute arbitrary code; only use this with files this
    notebook produced.
    """
    path = dir_ + filename

    if os.path.exists(path):
        with open(path, mode='rb') as fh:
            d = cPickle.load(fh)

        d = np.concatenate([np.asarray(d), np.asarray(symbols)])

        print('Appending to file...')
        print('Total Symbols:' + str(d.shape))
        # Context managers guarantee the handle is flushed and closed.
        with open(path, 'wb') as fh:
            cPickle.dump(d, fh)
    else:
        with open(path, 'wb') as fh:
            cPickle.dump(symbols, fh)

    print("Saved data to:" + str(filename))

In [None]:
# Inspect one capture: parse its encoding parameters from the filename.
file = all_data_files[5]
print(file.name)
BW, SF, sampPerSym, L = get_and_set_encoding_params(file)
BW, SF, sampPerSym = map(int, (BW, SF, sampPerSym))
print(BW, SF, L)

In [None]:
# Quick look at the magnitude of samples 1_000_000..2_000_000 of the capture
# (np.abs of complex data is already real, so no np.real is needed).
signal = load_data(file)
plt.plot(np.abs(signal[1_000_000:2_000_000]))

#done: 4, 5,1,2

### Process data files (zero-padding detection approach)

In [None]:
numfiles = len(all_data_files)
# Indices into all_data_files to process; use range(numfiles) for every file.
FILE_INDICES = [5]
# Set True to append the extracted symbols to dir_ + filename via save_data.
SAVE_SYMBOLS = False
# Reset so the plot below never shows stale symbols from an earlier run.
symbols = None

for idx in tqdm(FILE_INDICES):
    file = all_data_files[idx]
    print("Processing:" + str(file))
    # extract_packets loads the file itself, so no separate load_data call is needed.
    packets = extract_packets(file, norm_thresh=0.00015, packet_thresh=17_000)

    # Success is a 2-D array; failure is the sentinel [0]. A `len(packets) > 1`
    # check would also wrongly reject a capture containing exactly one packet.
    if np.ndim(packets) != 2:
        continue

    BW, SF, sampPerSym, L = get_and_set_encoding_params(file)
    BW, SF, sampPerSym = int(BW), int(SF), int(sampPerSym)
    numPackets = np.asarray(packets).shape[0]

    symbols = extract_symbols(packets, numPackets, sampPerSym)

    dir_ = 'Processed_Data/Indoor/Location5/'
    filename = "lora_symbols_BW" + str(BW) + "_SF" + str(SF) + "_L" + str(L) + ".p"
    if SAVE_SYMBOLS and len(symbols) > 2:
        save_data(symbols, dir_, filename)

# Sanity check on the first extracted symbol: spectrum (top), real part (bottom).
if symbols is not None and np.ndim(symbols) == 2:
    plt.subplot(2, 1, 1)
    plt.plot(np.fft.fftshift(np.abs(np.fft.fft(symbols[0, :]))))
    plt.subplot(2, 1, 2)
    plt.plot(np.real(symbols[0, :]))

In [None]:
# `symbols` is complex; plot the real part explicitly. Passing complex data to
# plt.plot would otherwise drop the imaginary part with a ComplexWarning.
print(file.name)
plt.plot(np.real(symbols[50, :]))

In [None]:
# NOTE(review): this filename is hard-coded and not derived from the BW/SF/Att
# of the file processed above; confirm it matches before saving.
with open('lora_symbols_BW1_SF12_Att40.p', 'wb') as fh:
    cPickle.dump(symbols, fh)

### Check number of symbols per class

In [None]:
def glob_files2(path, pattern=r'**/*.p'):
    """Yield every file under ``path`` matching ``pattern`` (default: all ``.p`` pickles, recursively).

    Parameters
    ----------
    path : str or pathlib.Path
        Root directory to search.
    pattern : str
        Glob pattern relative to ``path``.
    """
    # Honour the arguments instead of reading the notebook-level ``data_dir``.
    yield from pathlib.Path(path).glob(pattern)
    

In [None]:
# Re-point data_dir at the processed pickles and list them (overwrites the
# raw-data data_dir / all_data_files defined earlier).
# NOTE(review): this reads 'Indoors/Location0', while the processing cell writes
# to 'Indoor/Location5'; confirm the directory name ("Indoor" vs "Indoors").
data_dir = pathlib.Path('Processed_Data/Indoors/Location0/')
all_data_files = list(glob_files2(data_dir))
print(f"# of files: {len(all_data_files)}")
all_data_files

In [None]:

# Load one processed symbol set. Unpickling can execute arbitrary code, so only
# open files produced by this notebook.
with open(all_data_files[12], mode='rb') as fh:
    d1 = cPickle.load(fh)

    # Report the shape of what was just loaded (`d` is never defined at
    # notebook level, so `d.shape` raised NameError on a fresh kernel).
    print(fh.name, d1.shape)

In [None]:
# Join symbols 5 and 6 end-to-end into one longer segment
# (assumes d1 is 2-D: symbols x samples).
dt = np.concatenate([d1[row] for row in (5, 6)])
dt.shape

In [None]:
# Compare the spectrum of one loaded symbol with that of the joined segment.
# Uses d1 (loaded above); `d` is never defined at notebook level.
plt.plot(np.fft.fftshift(np.abs(np.fft.fft(d1[0, :]))))
plt.plot(np.fft.fftshift(np.abs(np.fft.fft(dt))))