In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from dask.distributed import Client, LocalCluster
# Connect to a Dask cluster created on this machine: a single worker process
# running 7 threads, with an 18 GB memory limit for that worker.
client = Client(
    n_workers=1,
    threads_per_worker=7,
    memory_limit='18GB',
)
client

In [None]:
import copy
import sys
import xarray as xr
import numpy as np
import dask.array as da
import time
import os

import dask

import matplotlib.pyplot as plt
import hvplot.xarray
import holoviews as hv
import scipy.constants
import scipy

sys.path.append("..")
import processing_dask as pr
import plot_dask

sys.path.append("../../preprocessing/")
from generate_chirp import generate_chirp

In [None]:
# --- Processing constants used by the cells below ---
# Sample index handed to pr.pulse_compress as `zero_sample_idx`
# (presumably the sample treated as zero delay -- confirm in processing_dask).
zero_sample_idx = 159
# Propagation speed used to convert delay to distance: two thirds of the speed of light
# (presumably the velocity factor of the loopback cable -- TODO confirm).
sig_speed = scipy.constants.speed_of_light * (2/3)

# --- Dataset selection: exactly one `prefix` should be left uncommented ---
#prefix = "/home/thomas/Documents/StanfordGrad/RadioGlaciology/sdr/data/20230711_115449"

# For these: 50 m loopback cable, 30 dB attenuation before the cable, lab SDR
#prefix = "/home/thomas/Documents/StanfordGrad/RadioGlaciology/sdr/data/20231206_170356" # 100k pulses, seemed like this one was maybe leveling out?
#prefix = "/home/thomas/Documents/StanfordGrad/RadioGlaciology/sdr/data/20231206_173558" # 1 M pulses, 10 MHz BW, 10 us chirp duration
#prefix = "/home/thomas/Documents/StanfordGrad/RadioGlaciology/sdr/data/20231206_174958" # 1 M pulses, 40 MHz BW, 10 us chirp duration
#prefix = "/home/thomas/Documents/StanfordGrad/RadioGlaciology/sdr/data/20231209_150916" # 1 M pulses, 40 MHz BW, 10 us chirp duration (same setup, different day)
# 150 m of loopback cable, 0 dB attenuation before the cable, lab SDR
prefix = "/home/thomas/Documents/StanfordGrad/RadioGlaciology/sdr/data/20231209_151613" # 1 M pulses, 40 MHz BW, 10 us chirp duration (150 m of cable)

# Convert the raw recording to zarr; an already-cached conversion is reused.
zarr_path = pr.save_radar_data_to_zarr(
    prefix,
    zarr_base_location="/home/thomas/Documents/StanfordGrad/RadioGlaciology/test_tmp_zarr_cache/",
    skip_if_cached=True,
)

zarr_path

In [None]:
# Open the zarr store written (or found in the cache) by the previous cell.
raw = xr.open_zarr(zarr_path)

In [None]:
# Work on a private copy of the acquisition config so that tweaks made here
# (e.g. the window override below) never modify the attributes stored on `raw`.
config = copy.deepcopy(raw.attrs['config'])
#config['GENERATE']['window'] = 'blackman'

# Regenerate the reference chirp from the (possibly modified) config.
chirp_ts, chirp = generate_chirp(config)

# # Filter chirp
# chirp_freq_sweep_mhz = np.linspace(-28, 28, len(chirp))
# keep_mask = (chirp_freq_sweep_mhz > -10) & (chirp_freq_sweep_mhz < -1)
# chirp_filtered = np.zeros_like(chirp)
# chirp_filtered[keep_mask] = chirp[keep_mask]
# chirp = chirp_filtered
# # / Filter chirp

# Pulse-compress the raw data against the reference chirp.
# `sig_speed` and `zero_sample_idx` come from the dataset-selection cell, so the
# propagation speed is defined in exactly one place.
# persist() starts the computation on the cluster and keeps the result as a dask-backed object.
compressed = pr.pulse_compress(raw, chirp,
                               fs=raw.attrs['config']['GENERATE']['sample_rate'],
                               zero_sample_idx=zero_sample_idx,
                               signal_speed=sig_speed).persist()

In [None]:
# Target coherent integration times [s]: 10 log-spaced values from 20 ms to 10 s.
# Longer / denser sweeps are kept below as commented-out alternatives.
#ts = np.logspace(np.log10(2e-2), np.log10(300), 10)
ts = np.logspace(np.log10(2e-2), np.log10(10), 10)
#ts = np.logspace(np.log10(2e-2), np.log10(300), 20)
#ts = np.logspace(np.log10(2e-2), np.log10(1000), 20)

## Noise Floor Variance

In [None]:
# Result buffers, one slot per entry of `ts`. Float buffers start out as NaN so the
# sweep loop can tell which entries have not been computed yet.
actual_stack_t = np.full_like(ts, np.nan)
actual_stack_n = np.zeros_like(ts, dtype=int)
stack_noise_var = np.full_like(ts, np.nan)
stack_noise_mean = np.full_like(ts, np.nan)
stack_signal_mean = np.full_like(ts, np.nan)
stack_signal_var = np.full_like(ts, np.nan)

# reflection_distance windows (exclusive bounds) used for the noise floor and the signal peak
noise_start_m, noise_end_m = 2000, 4000
signal_start_m, signal_end_m = 70, 80

In [None]:
# Sweep over the target integration times: for each one, stack the pulse-compressed
# data and record noise-floor and signal-peak statistics of the stacked magnitude.
pulse_rep_int = raw.attrs['config']['CHIRP']['pulse_rep_int']

for t_idx, t in enumerate(ts):
    # Entries stay NaN until computed, so re-running this cell resumes where it stopped
    if not np.isnan(stack_noise_mean[t_idx]):
        continue

    timestamp = time.time()
    # Pulses per stack (at least one) and the integration time that count actually spans
    actual_stack_n[t_idx] = max(1, int(t / pulse_rep_int))
    n_stack = actual_stack_n[t_idx]
    actual_stack_t[t_idx] = n_stack * pulse_rep_int
    print(f"[{t_idx+1}/{len(ts)}] \tt={actual_stack_t[t_idx]} \tn_stack={actual_stack_n[t_idx]}")

    with dask.config.set(**{'array.slicing.split_large_chunks': False}):
        # Only the first 100 stacks' worth of pulses is processed
        compressed_subset = compressed[{'pulse_idx': slice(0, n_stack*100)}]
        stacked = pr.stack(compressed_subset, n_stack)
        compressed_mag = xr.apply_ufunc(np.abs, stacked, dask='parallelized').chunk("auto")
        distance = compressed_mag.reflection_distance

        # Noise floor
        in_noise_window = (distance > noise_start_m) & (distance < noise_end_m)
        vs = compressed_mag["radar_data"].where(in_noise_window).dropna('travel_time').chunk("auto")
        # Keep at most the first 20 entries along the leading dimension
        vs = vs[:20]
        stack_noise_var[t_idx] = vs.var(dim="travel_time").mean().compute().item()
        stack_noise_mean[t_idx] = vs.mean().compute().item()

        # Signal peak
        in_signal_window = (distance > signal_start_m) & (distance < signal_end_m)
        ss = compressed_mag["radar_data"].where(in_signal_window).dropna('travel_time').chunk("auto")
        # Keep at most the first 100 entries along the leading dimension
        ss = ss[:100]
        peak_per_stack = ss.max(dim="travel_time")
        stack_signal_mean[t_idx] = peak_per_stack.mean().compute().item()
        stack_signal_var[t_idx] = peak_per_stack.var().compute().item()

    print(f"Completed in {time.time() - timestamp} seconds from {len(vs)} computed variances and {len(ss)} computed signal peaks")

In [None]:
# output_base_stack = os.path.join("20230628-outputs/", raw.attrs["basename"]+"-stack")

# d = xr.Dataset({"noise_var": ("t", stack_noise_var)}, coords={"t": actual_stack_t, "n_stack": ("t", actual_stack_n)})
# d.to_netcdf(output_base_stack + ".nc")

In [None]:
# Quick check: number of pulses per stack used for the first entry of `ts`.
actual_stack_n[0]

In [None]:
# Save actual_stack_t, actual_stack_n, stack_noise_var, stack_noise_mean, stack_signal_mean, stack_signal_var to a pickle file
# (together with the acquisition config), named after the dataset's directory name.
import pickle

stats_output_dir = "20231208-outputs/"
# open() does not create missing directories, so make sure the output folder exists first.
os.makedirs(stats_output_dir, exist_ok=True)
stats_path = os.path.join(stats_output_dir, os.path.basename(prefix) + "-stats.pickle")

with open(stats_path, "wb") as f:
    pickle.dump({"actual_stack_t": actual_stack_t,
                 "actual_stack_n": actual_stack_n,
                 "stack_noise_var": stack_noise_var,
                 "stack_noise_mean": stack_noise_mean,
                 "stack_signal_mean": stack_signal_mean,
                 "stack_signal_var": stack_signal_var,
                 "config": raw.attrs["config"]}, f)

In [None]:
# Mean signal peak (in dB) versus integration time, with a band of +/- one standard
# deviation, and a secondary top axis expressed in number of stacked pulses.
pulse_rep_int = raw.attrs['config']['CHIRP']['pulse_rep_int']
signal_std = np.sqrt(stack_signal_var)
signal_mean_db = 20*np.log10(stack_signal_mean)
band_lower_db = 20*np.log10(stack_signal_mean - signal_std)
band_upper_db = 20*np.log10(stack_signal_mean + signal_std)

fig, ax = plt.subplots()
ax.set_xscale('log')
#ax.scatter(actual_stack_t, stack_signal_var, label="Variance")
ax.scatter(actual_stack_t, signal_mean_db, label="Mean")
ax.fill_between(actual_stack_t, band_lower_db, band_upper_db, alpha=0.2)
ax.set_xlabel('Total coherent integration time [s]')
ax.set_ylabel('Signal peak')
ax.set_title(f"pulse_rep_int = {pulse_rep_int} s\n{os.path.basename(prefix)}")
ax.legend()
ax.grid()

# Top axis: the same x-range converted from seconds to pulse count
ax_n = ax.twiny()
ax_n.set_xscale('log')
xmin, xmax = ax.get_xlim()
ax_n.set_xlim(xmin / pulse_rep_int, xmax / pulse_rep_int)
ax_n.set_xlabel('n_stack')
# Pulse counts matching the time-axis ticks (inspected in the next cell; the ticks below are fixed by hand)
ax_n_ticks = (ax.get_xticks()/pulse_rep_int).round(1).astype(int)
n_stack_ticks = [500, 5000, 50000]
ax_n.set_xticks(n_stack_ticks)
ax_n.set_xticklabels(n_stack_ticks)

plt.show()

In [None]:
# Inspect the pulse counts computed from the time-axis tick positions of the last plot.
ax_n_ticks

In [None]:
# Reciprocal of the pulse repetition interval, i.e. the number of pulses per unit of time.
1/raw.attrs['config']['CHIRP']['pulse_rep_int']

In [None]:
# Noise-floor variance versus integration time, on log-log axes.
fig, ax = plt.subplots()
ax.set_xscale('log')
ax.set_yscale('log')
ax.scatter(actual_stack_t, stack_noise_var, label="Variance")
#ax.scatter(actual_stack_t, stack_noise_mean, label="Mean")
ax.set(
    xlabel='Time [s]',
    ylabel='Noise floor (2-4km)',
    title=f"pulse_rep_int = {raw.attrs['config']['CHIRP']['pulse_rep_int']} s",
)
ax.legend()
ax.grid()
#fig.savefig(output_base_stack + ".png")

In [None]:
# Noise-floor variance versus number of stacked pulses, on log-log axes.
fig, ax = plt.subplots()
ax.set_xscale('log')
ax.set_yscale('log')
ax.scatter(actual_stack_n, stack_noise_var)
ax.set(
    xlabel='n_stack',
    ylabel='Variance of noise floor (2-4km)',
    title=f"pulse_rep_int = {raw.attrs['config']['CHIRP']['pulse_rep_int']} s",
)
ax.grid()
#fig.savefig(output_base_stack + ".png")

## Plotting

In [None]:
# Overlay the signal-peak trend of several previously saved runs. Each run is
# normalised to its own first data point so that only the change with integration
# time is compared, not the absolute level.
import pickle  # local import: this section must not depend on the save cell having been run

pickles = [
    "20231208-outputs/20231206_173558-stats.pickle",
    "20231208-outputs/20231206_174958-stats.pickle",
    "20231208-outputs/20231209_150916-stats.pickle",
    "20231208-outputs/20231209_151613-stats.pickle",
]

# NOTE(review): the secondary axis and the title use the pulse repetition interval of the
# currently loaded dataset; each pickle also stores its own "config" -- confirm they agree.
pulse_rep_int = raw.attrs['config']['CHIRP']['pulse_rep_int']

fig, ax = plt.subplots()
ax.semilogx()

for pickle_path in pickles:
    # NOTE: pickle.load can execute arbitrary code -- only load stats files written by this notebook.
    with open(pickle_path, "rb") as f:
        d = pickle.load(f)
    sig_mean_db = 20*np.log10(d["stack_signal_mean"])
    # Legend label is the run's timestamped directory name (file name up to the first "-")
    ax.scatter(d["actual_stack_t"], sig_mean_db - sig_mean_db[0], label=os.path.basename(pickle_path).split("-")[0])

#ax.scatter(actual_stack_t, 20*np.log10(stack_signal_mean), label="Mean")
#ax.fill_between(actual_stack_t, 20*np.log10(stack_signal_mean - np.sqrt(stack_signal_var)), 20*np.log10(stack_signal_mean + np.sqrt(stack_signal_var)), alpha=0.2)
ax.set_xlabel('Total coherent integration time [s]')
# The plotted values are differences to each run's first point, so say so on the axis.
ax.set_ylabel('Signal peak relative to first point [dB]')
# Several runs are shown (identified in the legend), so the title no longer names a single dataset.
ax.set_title(f"pulse_rep_int = {pulse_rep_int} s")
ax.legend()
ax.grid()

# Top axis: the same x-range converted from seconds to pulse count
ax_n = ax.twiny()
ax_n.semilogx()
xmin, xmax = ax.get_xlim()
ax_n.set_xlim(xmin / pulse_rep_int, xmax / pulse_rep_int)
ax_n.set_xlabel('n_stack')
ax_n_ticks = (ax.get_xticks()/pulse_rep_int).round(1).astype(int)
ax_n.set_xticks([500, 5000, 50000])
ax_n.set_xticklabels([500, 5000, 50000])

plt.show()

## Signal peak phase

In [None]:
# Signal
# Locate, for every pulse, the sample with the largest magnitude within a 10-sample
# window around the sample whose reflection_distance is closest to the expected value.
reflector_distance_expected = 25
expected_peak_idx = (np.abs(compressed.reflection_distance - reflector_distance_expected)).argmin().item()

# The lambda slices axis 1, so it assumes radar_data is laid out as
# (pulse_idx, travel_time) -- TODO confirm. The window offset is added back so the
# result is an index into the full travel_time axis.
peak_idxs = compressed["radar_data"].reduce(
    lambda x, axis: (np.abs((x[:, expected_peak_idx-5:expected_peak_idx+5]))).argmax(axis=axis) + expected_peak_idx-5,
    dim='travel_time')
# persist() returns a new object instead of modifying in place; without re-binding
# the name the persisted result is thrown away and every later use recomputes the graph.
peak_idxs = peak_idxs.persist()
true_peak_idx = peak_idxs[0].compute().item()
if not (peak_idxs == true_peak_idx).all().compute().item():
    print("WARNING: Peak indices are not all the same!")

In [None]:
# Index of the sample whose reflection_distance is closest to zero
# (going by the variable name, this is where the internal path is expected).
expected_internal_path_idx = (np.abs(compressed.reflection_distance)).argmin().item()
expected_internal_path_idx

In [None]:
def _phase_at_index(trace, idx):
    """Return the phase angle of the complex sample ``trace[idx]`` (as given by ``np.angle``)."""
    return np.angle(trace[idx])


# Phase of each pulse at its detected peak sample.
peak_phases = xr.apply_ufunc(
    _phase_at_index,
    compressed["radar_data"],
    peak_idxs,
    # Core dimensions per input: the helper receives a full travel_time trace and a scalar index
    input_core_dims=[['travel_time'], []],
    # The helper returns a scalar, so the output has no core dimensions
    output_core_dims=[[]],
    # travel_time is excluded from alignment/broadcasting between the inputs
    exclude_dims={"travel_time"},
    # Loop the helper over all remaining (non-core) dimensions via np.vectorize
    vectorize=True,
    # Allow dask to chunk and parallelize the computation
    dask="parallelized",
    # Needed for dask: explicitly provide the output dtype
    output_dtypes=[np.float32],
    #dask_gufunc_kwargs={"output_sizes": {'travel_time': 1}} # Also needed for dask:
    # explicitly provide the output size of the lambda function. See
    # https://docs.dask.org/en/stable/generated/dask.array.gufunc.apply_gufunc.html
).persist()

In [None]:
# Two-sample variance of the peak phase: for each target averaging time, average the
# phase over windows of that many pulses and take the mean squared difference between
# averages that lie one window length apart.
fs = raw.attrs['config']['GENERATE']['sample_rate']
pulse_rep_int = raw.attrs['config']['CHIRP']['pulse_rep_int']

actual_dt = np.zeros_like(ts)
var = np.zeros_like(ts)

for t_idx, t in enumerate(ts):
    print(f"[{t_idx}/{len(ts)}] \tt={t}")
    # Window length in pulses (at least one) and the time it actually spans
    pulses = max(1, int(t / pulse_rep_int))
    actual_dt[t_idx] = pulses * pulse_rep_int
    ph_group_mean = peak_phases.rolling(pulse_idx=pulses).mean()
    # Drop the pulse_idx index on both operands so the subtraction pairs them by
    # position (offset by `pulses`) rather than by matching pulse_idx labels.
    earlier_means = ph_group_mean[:-pulses].drop_indexes("pulse_idx")
    later_means = ph_group_mean[pulses:].drop_indexes("pulse_idx")
    var[t_idx] = ((earlier_means - later_means)**2).mean().compute().item()

In [None]:
# Write the two-sample phase variance results to a netCDF file named after the dataset.
output_dir_2svar = "20230628-outputs/"
# Neither to_netcdf nor savefig creates missing directories, so create the folder up front.
os.makedirs(output_dir_2svar, exist_ok=True)
output_base_2svar = os.path.join(output_dir_2svar, raw.attrs["basename"]+"-2svar")

d = xr.Dataset({"var_2s": ("dt", var)}, coords={"dt": actual_dt})
d.to_netcdf(output_base_2svar + ".nc")

In [None]:
# Two-sample phase variance versus averaging time, on log-log axes; also saved as PNG.
fig, ax = plt.subplots()
ax.set_xscale('log')
ax.set_yscale('log')
ax.scatter(actual_dt, var)
ax.set(
    xlabel='Time [s]',
    ylabel='Two sample phase variance',
    title=f"pulse_rep_int = {raw.attrs['config']['CHIRP']['pulse_rep_int']} s",
)
ax.grid()
fig.savefig(output_base_2svar + ".png")

In [None]:
# Build (but do not yet save) the interactive plots of peak index and peak phase per pulse.
output_base_phase = os.path.join("20230628-outputs/", raw.attrs["basename"]+"-phase")

# Peak phase smoothed with a 100-pulse rolling mean
peak_phases_rolling_mean = peak_phases.rolling(pulse_idx=100).mean()

peak_idx_plot = peak_idxs.hvplot.scatter(x='pulse_idx')
peak_phase_plot = peak_phases.hvplot.scatter(x='pulse_idx', datashade=True)
peak_phase_rolling_plot = peak_phases_rolling_mean.hvplot.scatter(x='pulse_idx', datashade=True)

In [None]:
# Save every phase plot first as a static PNG, then as an HTML page, and finally
# display all three plots as the cell output.
phase_plots = [
    (peak_idx_plot, "-peak-idx"),
    (peak_phase_plot, "-peak-phase"),
    (peak_phase_rolling_plot, "-peak-phase-rolling"),
]

for extension, fmt in ((".png", 'png'), (".html", 'widgets')):
    for plot, suffix in phase_plots:
        hv.save(plot, output_base_phase + suffix + extension, fmt=fmt)

peak_idx_plot, peak_phase_plot, peak_phase_rolling_plot