In [None]:
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import urllib.request

# When True, the padding helpers print a notice each time they zero-pad a waveform.
use_warning = False

# Seed NumPy's global RNG so the random draws made via np.random are repeatable run to run.
np.random.seed(20)

def gauss(x, mu, sigma):
    """Evaluate the normal probability density with mean ``mu`` and width ``sigma`` at ``x``.

    ``x`` may be a scalar or a NumPy array; the result has the same shape.
    """
    prefactor = 1/(sigma * np.sqrt(2*np.pi))
    exponent = -(x-mu)**2/sigma/sigma/2
    return prefactor * np.exp(exponent)

def ensure_response_file(npz_path="response_44_v2a_full.npz"):
    """Return the local path of the field-response file, downloading it if absent.

    Args:
        npz_path: Destination path (str or Path) of the ``.npz`` file.

    Returns:
        ``pathlib.Path`` pointing at the (now existing) file.

    Raises:
        Whatever ``urllib.request.urlretrieve`` raises when the download fails;
        in that case no file is left at ``npz_path``.
    """
    npz_path = Path(npz_path)
    if npz_path.exists():
        return npz_path

    url = "https://raw.githubusercontent.com/DUNE/larnd-sim/develop/larndsim/bin/response_44_v2a_full.npz"
    print(f"Downloading {url} -> {npz_path} ...")
    # Download into a sibling ".part" file and rename it only once the transfer
    # finished: an interrupted download must not leave a truncated file behind,
    # because the exists() check above would accept it on the next call.
    partial_path = npz_path.with_name(npz_path.name + ".part")
    try:
        urllib.request.urlretrieve(url, partial_path.as_posix())
        partial_path.replace(npz_path)
    finally:
        # After a successful rename the partial file is gone already; after a
        # failure this removes the incomplete download.
        partial_path.unlink(missing_ok=True)
    return npz_path

def deconv(wf, k):
    """Undo a convolution with kernel ``k`` by plain division in Fourier space.

    Both inputs are transformed at the length of ``wf``; the real part of the
    inverse transform is returned, cut to ``len(wf) + 1 - len(k)`` samples
    (the length of the signal that, fully convolved with ``k``, gives ``wf``).
    No regularisation is applied.
    """
    n = len(wf)
    spectrum_ratio = np.fft.fft(wf, n=n) / np.fft.fft(k, n=n)
    recovered = np.fft.ifft(spectrum_ratio).real
    return recovered[:n + 1 - len(k)]

def deconv_filter(wf, kernel, lam0, lam_hf, lam_exp):
    """Regularised (Wiener-like) Fourier deconvolution of ``wf`` by ``kernel``.

    The inverse filter is ``conj(K) / (|K|**2 + lam)`` where the regulariser is
    ``lam0 + lam_hf * (f / f_nyq) ** lam_exp`` per frequency bin when
    ``lam_hf > 0`` and the constant ``lam0`` otherwise.

    Args:
        wf: Waveform to deconvolve.
        kernel: Response kernel; transformed at the length of ``wf``.
        lam0: Frequency-independent regularisation term.
        lam_hf: Weight of the frequency-dependent term.
        lam_exp: Exponent of the frequency-dependent term.

    Returns:
        Real part of the filtered inverse transform, same length as ``wf``.
    """
    n = len(wf)
    bins = np.arange(n, dtype=np.float64)
    # Fold the bin index about the middle so mirrored bins get the same penalty.
    folded = np.minimum(bins, n - bins)
    nyquist_bin = max(1, n // 2)

    if lam_hf > 0.0:
        reg = lam0 + lam_hf * (folded / float(nyquist_bin)) ** float(lam_exp)
    else:
        reg = np.full((n,), lam0, dtype=np.float64)

    kernel_spec = np.fft.fft(kernel, n=n)
    wiener_like = np.conjugate(kernel_spec)/(np.absolute(kernel_spec)**2 + reg)
    return np.fft.ifft(np.fft.fft(wf, n=n) * wiener_like).real

def field_response(npz_path="response_44_v2a_full.npz"):
    """Load the field-response file with ``np.load``, fetching it first if it is not on disk."""
    return np.load(ensure_response_file(npz_path))

def trigger(wf, thres, noise=0):
    """Index of the first sample where the running sum of ``wf`` exceeds ``thres``.

    Independent Gaussian jitter of standard deviation ``noise`` is added to every
    running-sum sample before the comparison (a draw from the global NumPy RNG is
    made even when ``noise`` is 0). If the threshold is never exceeded the result
    is 0, which is what ``np.argmax`` yields for an all-False mask.
    """
    jitter = noise * np.random.normal(0, 1, wf.shape)
    crossed = (np.cumsum(wf) + jitter) > thres
    return np.argmax(crossed)

def integrate_k(wf, k):
    """Sum ``wf`` over consecutive groups of ``k`` samples.

    When the length is not a multiple of ``k`` the waveform is zero-padded at the
    end up to the next multiple (optionally announced when ``use_warning`` is set).
    Returns an array with one entry per group.
    """
    leftover = wf.shape[0] % k
    if leftover:
        wf = np.pad(wf, (0, k - leftover))
        if use_warning:
            print('warning: length of wf is now', len(wf))
    return wf.reshape(-1, k).sum(axis=-1)

def add_noise(wf, amp):
    """Add differenced Gaussian noise to ``wf``.

    ``len(wf) + 1`` normal samples scaled by ``amp`` are drawn from the global
    NumPy RNG and differenced, giving one noise value per waveform sample.

    Returns:
        ``(noisy_waveform, total_noise)`` where ``total_noise`` is the sum of the
        added noise values.
    """
    draws = amp * np.random.normal(0, 1, len(wf) + 1)
    steps = np.diff(draws)
    return wf + steps, np.sum(steps)

# integrate every k ticks
def trigger_integrate_k(wf, k, start_idx):
    """Collapse a waveform into one pre-trigger sum followed by ``k``-tick sums.

    Args:
        wf: Waveform samples.
        k: Number of samples summed into each post-trigger bin.
        start_idx: Index of the first post-trigger sample.

    Returns:
        ``(wf_trunc, start_idx)``. ``wf_trunc[0]`` is the sum of ``wf[:start_idx]``
        and ``wf_trunc[1:]`` holds the sums of consecutive groups of ``k`` samples
        of ``wf[start_idx:]``, zero-padded at the end to a multiple of ``k``.
        ``start_idx`` is passed through unchanged.
    """
    head = wf[:start_idx]
    tail = wf[start_idx:]
    remainder = tail.shape[0] % k
    if remainder:
        tail = np.pad(tail, (0, k - remainder))
        if use_warning:
            # Report the effective padded length (pre-trigger samples plus the
            # padded tail). The input array itself is never modified, so printing
            # len(wf) here would always show the old, unpadded length.
            print('warning: length of wf is now', len(head) + len(tail))
    binned = tail.reshape(-1, k).sum(axis=-1)
    wf_trunc = np.zeros(len(binned)+1)
    wf_trunc[0] += np.sum(head)
    wf_trunc[1:] += binned
    return wf_trunc, start_idx

def lost_waveform(wf_k_squeeze, kticks, fr_fg, thres, start_idx_fg):
    ''' Rebuild a k-tick waveform that includes an estimate of the pre-trigger part.

    assume fr is full length without downsampling
    wf_k_squeeze is 0 to trigger, trigger to end, and burst

    The pre-trigger samples are modelled by the leading part of ``fr_fg``, scaled
    so that they sum to ``thres``; how much of ``fr_fg`` is used is chosen by
    matching ``thres / qmax`` against the normalised cumulative response.

    Args:
        wf_k_squeeze: Output of the trigger-and-integrate step. Entry 0 (the
            measured pre-trigger sum) is not copied into the result; entries
            1.. are appended after the estimated pre-trigger bins.
        kticks: Number of fine ticks per coarse bin.
        fr_fg: Response at fine granularity.
        thres: Trigger threshold.
        start_idx_fg: Number of fine-granularity pre-trigger samples.

    Returns:
        Array of ``len(lost) // kticks`` estimated pre-trigger bins followed by
        ``wf_k_squeeze[1:]``.
    '''
    qmax = np.max(np.cumsum(wf_k_squeeze))
    # Fraction of the response integrated when the threshold is reached.
    fr_cdf = np.cumsum(fr_fg)/np.sum(fr_fg)
    idx = np.argmin(np.abs(thres/qmax - fr_cdf))
    lost = np.zeros(start_idx_fg) # at threshold
    if idx < start_idx_fg:
        # NOTE(review): idx == 0 still fails here (lost[-0:] is the whole array
        # and fr_fg[:0] is empty); the intended result for that case is not
        # evident from this code, so it is left untouched - confirm.
        lost[-idx:] = fr_fg[:idx] * thres / np.sum(fr_fg[:idx])
    else:
        print('not good')
        lost = fr_fg[idx-start_idx_fg:idx] * thres / fr_fg[idx]
    ncycles = len(lost) // kticks
    wf_full = np.zeros(ncycles + len(wf_k_squeeze)-1)
    # With fewer than kticks pre-trigger samples there is no complete coarse bin.
    # lost[-0:] would select the *whole* array and make the reshape raise, so the
    # assignment is skipped in that case.
    if ncycles:
        wf_full[:ncycles] = lost[-ncycles * kticks:].reshape(ncycles, kticks).sum(axis=-1)
    wf_full[ncycles:] = wf_k_squeeze[1:]
    return wf_full

def scale_q_test(qscale, shift=0, kticks=30, thres=5, noise=1, use_filter: None|dict = None, show: bool =True):
    """Simulate one charge deposit, run it through trigger + k-tick readout, and deconvolve.

    A Gaussian charge profile of total charge ``qscale`` is convolved with the
    field response, triggered, summed in ``kticks``-sample bins with noise added,
    extended with an estimated pre-trigger part, and finally deconvolved with the
    equally binned response.

    Args:
        qscale: Total charge of the input profile.
        shift: Number of samples by which the profile is rolled (``np.roll``).
        kticks: Number of fine ticks summed per readout bin.
        thres: Trigger threshold on the running sum of the waveform.
        noise: Noise amplitude, used both for the trigger jitter and for the
            noise added to the binned waveform.
        use_filter: If given (and non-empty), a dict with keys ``'lam0'``,
            ``'lam_hf'`` and ``'lam_exp'`` selecting the regularised
            deconvolution; otherwise the plain Fourier division is used.
        show: When True, plot the input charge, the waveforms and the result.

    Returns:
        ``(delta_q, delta_t)``: reconstructed minus true total charge, and the
        offset between the reconstructed peak position (bin index times
        ``kticks``) and the true peak position, in fine ticks.
    """
    # True charge profile: unit-area Gaussian scaled to qscale, optionally shifted.
    q = gauss(np.arange(-10, 10, 0.1), 0, 0.5)
    q /= np.sum(q)
    q *= qscale
    q = np.roll(q, shift)
    if show:
        plt.figure()
        plt.title('Input Charge distribution')
        plt.plot(q)

    # Read the response array and close the .npz archive right away instead of
    # leaving an open file handle behind on every call.
    with field_response() as response_file:
        fr0 = response_file['response'][0,0]
    # Sum pairs of samples, keep the last 1801 entries and rescale.
    fr0 = fr0.reshape(-1, 2).sum(axis=-1)[-1801:]
    fr0 *= 0.05

    wf = np.convolve(q, fr0, mode='full')

    # Readout: one pre-trigger sum followed by noisy kticks-wide bins. The order
    # of the two random draws (trigger, then add_noise) matters for reproducibility.
    wf_k_squeeze, start_idx_fg = trigger_integrate_k(wf, kticks, trigger(wf, thres, noise=noise)+1)
    wf_k_squeeze[1:], _ = add_noise(wf_k_squeeze[1:], noise)
    fr0_k = integrate_k(fr0, kticks)

    wf_full_k = lost_waveform(wf_k_squeeze, kticks, fr0, thres, start_idx_fg)
    if use_filter:
        qdeconv2 = deconv_filter(wf_full_k, fr0_k, lam0=use_filter['lam0'],
                                 lam_hf=use_filter['lam_hf'], lam_exp=use_filter['lam_exp'])
        # Trim to the same length the unfiltered deconvolution returns.
        qdeconv2 = qdeconv2[:len(wf_full_k)-len(fr0_k)+1]
    else:
        qdeconv2 = deconv(wf_full_k, fr0_k)

    if show:
        plt.figure()
        plt.plot(np.arange(len(wf_full_k))*kticks, wf_full_k, label='guess')
        plt.plot(wf*kticks, label='wf true')
        plt.xlabel('time tick[50ns]')
        plt.legend()
        plt.show()

        plt.figure()
        plt.plot(np.arange(len(qdeconv2))*kticks, qdeconv2/kticks, label='average')
        plt.plot(q, label='true')
        plt.xlabel('time tick[50ns]')
        plt.ylabel('charge')
        plt.title('Charge')
        plt.legend()
        plt.show()

    return np.sum(qdeconv2) - np.sum(q), np.argmax(qdeconv2)*kticks - np.argmax(q)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def fit_line_band(x, y, k=2.0, use_vertical_residual=True):
    """Fit a straight line and estimate a +/- k*sigma band around it.

    Args:
        x, y: 1D array-likes of equal size (inputs are flattened).
        k: Band half-width in units of sigma (for roughly Gaussian residuals
            k=1 covers about 68 % of the points and k=2 about 95 %).
        use_vertical_residual: If True, sigma is the RMS of ``y - y_hat``;
            if False, the RMS of the perpendicular point-to-line distance.

    Returns:
        ``(a, b, sigma, band_half)`` where ``y = a*x + b`` is the least-squares
        line, ``sigma`` the RMS residual and ``band_half = k * sigma``.

    Raises:
        ValueError: If ``x`` and ``y`` differ in size.
    """
    x = np.asarray(x).ravel()
    y = np.asarray(y).ravel()
    if x.size != y.size:
        raise ValueError("len(x) must be equal to len(y)")

    # Least-squares line y = a*x + b.
    a, b = np.polyfit(x, y, deg=1)

    if use_vertical_residual:
        # Residual measured along the y axis.
        deviations = y - (a * x + b)
    else:
        # Distance from each point to the line: |a*x - y + b| / sqrt(a^2 + 1).
        deviations = np.abs(a * x - y + b) / np.sqrt(a**2 + 1.0)

    sigma = np.sqrt(np.mean(deviations**2))
    return a, b, sigma, k * sigma

def plot_line_band(x, y, a, b, band_half, title=None):
    """Scatter the data and overlay the line ``y = a*x + b`` with a shaded band.

    Args:
        x, y: Data points (flattened before plotting).
        a, b: Slope and intercept of the line to draw.
        band_half: Half-width of the band shaded around the line.
        title: Optional figure title.
    """
    x = np.asarray(x).ravel()
    y = np.asarray(y).ravel()

    # Evaluate the line on an evenly spaced grid covering the data range.
    grid = np.linspace(x.min(), x.max(), 300)
    fitted = a * grid + b
    lower, upper = fitted - band_half, fitted + band_half

    plt.figure(figsize=(7, 5))
    plt.scatter(x, y, s=25, alpha=0.8, label="data")
    plt.plot(grid, fitted, linewidth=2, label=f"fit: y={a:.3f}x+{b:.3f}")
    plt.fill_between(grid, lower, upper, alpha=0.2, label=f"band: ±{band_half:.3f}")
    plt.xlabel("x")
    plt.ylabel("y")
    if title:
        plt.title(title)
    plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Single run at qscale=24 with the regularised deconvolution and plots enabled.
# delta_q is the reconstructed-minus-true total charge returned by scale_q_test.
delta_q, delta_t = scale_q_test(24, show=True, use_filter={'lam0' : 0.1, 'lam_hf' : 0.1, 'lam_exp' : 4.5})
print('delta_q -----------------------:', delta_q)

In [None]:
# Scan qscale from 6 up to (but excluding) 100 in steps of 0.5 with the
# regularised deconvolution and plots off, collecting the charge bias and the
# peak-position offset of every run.
delta_qs2 = []
delta_ts2 = []
for qscale in np.arange(6, 100, 0.5):
    delta_q, delta_t = scale_q_test(qscale, show=False, use_filter={'lam0' : 0.2, 'lam_hf' : 0.1, 'lam_exp' : 2.5})
    delta_qs2.append(delta_q)
    delta_ts2.append(delta_t)

In [None]:
# Fit a straight line to the filtered-scan charge bias versus qscale and draw it
# with a 1-sigma band (k=1.0). The x values repeat the qscale grid of the scan.
a2, b2, sigma, band_half = fit_line_band(np.arange(6, 100, 0.5), delta_qs2, k=1.0, use_vertical_residual=True)
plot_line_band(np.arange(6, 100, 0.5), delta_qs2, a2, b2, band_half=band_half)

In [None]:
# Same qscale scan (6 up to 100, step 0.5, plots off) but without use_filter, so
# scale_q_test falls back to the plain Fourier-division deconvolution.
delta_qs = []
delta_ts = []
for qscale in np.arange(6, 100, 0.5):
    delta_q, delta_t = scale_q_test(qscale, show=False)
    delta_qs.append(delta_q)
    delta_ts.append(delta_t)

In [None]:
# Line fit with a 1-sigma band (k=1.0) for the unfiltered-scan charge bias versus qscale.
a, b, sigma, band_half = fit_line_band(np.arange(6, 100, 0.5), delta_qs, k=1.0, use_vertical_residual=True)
plot_line_band(np.arange(6, 100, 0.5), delta_qs, a, b, band_half=band_half)

# Second view of the same points with the y axis limited to [-15, 15].
plt.scatter(np.arange(6, 100, 0.5), delta_qs)
plt.ylim(-15, 15)

In [None]:
# plt.hist(delta_qs, bins=20, alpha=0.5)
# plt.hist(delta_qs2, bins=20, alpha=0.5)

In [None]:
# Distribution of the peak-position offsets from the unfiltered scan, in 100 bins.
plt.hist(np.array(delta_ts), bins=100)

In [None]:
# Peak-position offset of the unfiltered scan as a function of qscale.
plt.scatter(np.arange(6, 100, 0.5), delta_ts)

In [None]:
# Single run at qscale=24 without use_filter (plain deconvolution), plots enabled.
delta_q, delta_t = scale_q_test(24, show=True)
print('delta_q -----------------------:', delta_q)

In [None]:
# Peak-position offset of the filtered scan as a function of qscale.
plt.scatter(np.arange(6, 100, 0.5), delta_ts2)