In [None]:
import warnings

import numpy as np
import matplotlib.pyplot as plt

In [None]:
!curl -O https://raw.githubusercontent.com/DUNE/larnd-sim/refs/heads/develop/larndsim/bin/response_44_v2a_full.npz

In [None]:
def gauss(x, mu, sigma):
    """Normal probability density with mean ``mu`` and std ``sigma``, evaluated at ``x``."""
    coeff = 1 / (sigma * np.sqrt(2 * np.pi))
    exponent = -(x - mu) ** 2 / sigma / sigma / 2
    return coeff * np.exp(exponent)

In [None]:
def field_response(path='response_44_v2a_full.npz'):
    """Open the field-response ``.npz`` archive.

    Parameters
    ----------
    path : str or os.PathLike, optional
        Location of the archive. The default is the file name fetched by the
        ``curl`` cell into the working directory.

    Returns
    -------
    numpy.lib.npyio.NpzFile
        Archive indexed by key (e.g. ``['response']``). It keeps the file open,
        so use it as a context manager (``with field_response() as f: ...``)
        or call ``.close()`` when done.
    """
    return np.load(path)

In [None]:
# Read the first response entry; indexing the NpzFile loads the array into memory,
# so the archive can be closed right away.
with field_response() as response_file:
    fr0 = response_file['response'][0, 0]
# Sum adjacent sample pairs (halving the number of samples) and keep the last 1801.
fr0 = fr0.reshape(-1, 2).sum(axis=-1)[-1801:]
# Number of trailing samples before the last strictly positive one (0 if none are positive).
print(np.argmax(fr0[::-1] > 0))

In [None]:
# Plot the (pair-summed, cropped) field response loaded above.
plt.plot(fr0)
plt.xlabel('time tick [50ns]')
plt.ylabel('dq/dt')
plt.title('Field response')
print(len(fr0))

In [None]:
# True charge: a narrow Gaussian sampled at 200 points spaced 0.1 apart on [-10, 10).
q = gauss(x=np.arange(-10, 10, 0.1), mu=0, sigma=0.3)

In [None]:
# Plot the true charge against its sample index.
plt.plot(q)
plt.title('Charge')
plt.ylabel('q per tick')
plt.xlabel('tick [50ns]')

In [None]:
# Simulated waveform: full linear convolution of the true charge with the field response.
wf = np.convolve(q, fr0, mode='full')

In [None]:
def deconv(wf, k):
    """Deconvolve kernel ``k`` out of waveform ``wf`` by plain FFT division.

    The division is circular over ``len(wf)`` points. If ``wf`` is the full
    linear convolution of a signal with ``k`` (what ``np.convolve`` returns),
    this recovers that signal zero-padded to ``len(wf)``. There is no
    regularization: frequencies where the kernel spectrum is small amplify
    noise, and an exact zero gives inf/nan.

    Parameters
    ----------
    wf : array_like
        1-D waveform.
    k : array_like
        1-D kernel, not longer than ``wf``.

    Returns
    -------
    numpy.ndarray
        Real part of the inverse FFT, length ``len(wf)``.

    Raises
    ------
    ValueError
        If ``k`` is longer than ``wf``; ``np.fft.fft(k, n=len(wf))`` would
        otherwise silently crop the kernel.
    """
    wf = np.asarray(wf)
    k = np.asarray(k)
    n = len(wf)
    if len(k) > n:
        raise ValueError(f'kernel length {len(k)} exceeds waveform length {n}')
    deconvq = np.fft.ifft(np.fft.fft(wf, n=n) / np.fft.fft(k, n=n)).real
    return deconvq

In [None]:
# Reference: unregularized deconvolution of the full-resolution waveform.
deconvq = deconv(wf=wf, k=fr0)

In [None]:
# Compare the plain FFT deconvolution with the true charge over the first len(q) ticks.
plt.plot(deconvq[:len(q)], '--', label='qhat')
plt.plot(q, '.-', label='true')
plt.xlabel('tick [50ns]')
plt.ylabel('q per tick')
plt.title('Charge: plain FFT deconvolution vs truth')
plt.legend();

In [None]:
# integrate every k ticks
def integrate_k(wf, k, offset=0):
    """Sum consecutive groups of ``k`` samples of ``wf`` (re-binning by ``k``).

    Parameters
    ----------
    wf : array_like
        1-D waveform.
    k : int
        Number of samples summed into each output bin.
    offset : int, optional
        Number of leading samples dropped before binning (shifts the bin phase).

    Returns
    -------
    numpy.ndarray
        One sum per bin. If the remaining waveform is not a multiple of ``k``
        long, it is zero-padded at the end (so the last bin covers fewer real
        samples) and a ``UserWarning`` is issued.
    """
    wf = np.asarray(wf)[offset:]
    remainder = wf.shape[0] % k
    if remainder:
        wf = np.pad(wf, (0, k - remainder))
        warnings.warn(f'length of wf is not a multiple of k={k}; zero-padded to {len(wf)}',
                      stacklevel=2)
    return wf.reshape(-1, k).sum(axis=-1)

In [None]:
kticks = 30  # number of ticks summed into each coarse bin
wf_k = integrate_k(wf, k=kticks)

In [None]:
# Deconvolve the k-tick-binned waveform with an equally binned field response,
# scanning the offset at which fr0 is binned. After the loop, `i`, `fr0_k` and
# `qdeconv2` hold the values from the last offset.
for i in range(0, kticks, 4):
    fr0_k = integrate_k(fr0[:-1], kticks, offset=i)
    qdeconv2 = deconv(wf_k, fr0_k)
    print(np.sum(qdeconv2), np.sum(q), np.sum(deconvq))

    # One point per bin, placed at the bin's first tick; dividing by kticks gives a per-tick value.
    plt.plot(np.arange(0, len(q)+1, kticks), qdeconv2[:len(q)//kticks+1]/kticks, label=f'deconv, fr offset={i}, dq={np.sum(qdeconv2)-np.sum(q):.2f}')
plt.plot(q, label='true')
plt.xlim(0, 250)
plt.legend(loc='upper right')
plt.xlabel('Time tick [50ns]')
plt.ylabel('q per tick')
plt.title('Charge');

In [None]:
# linear interpolation
def interpolate(wf_k, k, include_end=False):
    """Spread k-tick bin sums back onto single ticks via the interpolated cumulative sum.

    The cumulative sum of ``wf_k`` is known at bin edges ``0, k, 2k, ..., len(wf_k)*k``.
    It is linearly interpolated at integer ticks, so ``np.diff`` of the result
    spreads each bin's sum uniformly over its ``k`` ticks.

    Parameters
    ----------
    wf_k : array_like
        Per-bin sums (e.g. the output of ``integrate_k``).
    k : int
        Ticks per bin.
    include_end : bool, optional
        If False (default), evaluate at ticks ``0 .. len(wf_k)*k - 1``. The final
        edge is then omitted, so ``np.diff`` of the result drops the last tick's
        share of the last bin. If True, also evaluate at ``len(wf_k)*k``, so that
        ``np.diff`` returns ``len(wf_k)*k`` values summing to ``sum(wf_k)``.

    Returns
    -------
    numpy.ndarray
        Interpolated cumulative sum at integer ticks.
    """
    n_ticks = len(wf_k) * k
    edges = np.arange(len(wf_k) + 1) * k
    cumulative = np.concatenate(([0.0], np.cumsum(wf_k)))
    ticks = np.arange(n_ticks + 1 if include_end else n_ticks)
    return np.interp(ticks, edges, cumulative)

In [None]:
# Per-tick waveform recovered from the bin sums (first difference of the interpolated cumulative sum).
wf_interpo = np.diff(interpolate(wf_k, kticks), n=1)

In [None]:
# Check the interpolated waveform against the full-resolution waveform it was binned from.
plt.plot(wf_interpo, label='interpolated')
plt.plot(wf, label='true')
plt.xlabel('tick [50ns]')
plt.ylabel('waveform (q * fr0)')
plt.title('Waveform: interpolated vs full resolution')
plt.legend();

In [None]:
qdeconv3 = deconv(wf_interpo, fr0)
# Compare total charge, peak position and spread of the deconvolved interpolated
# waveform with the truth and with the full-resolution deconvolution.
print(np.sum(qdeconv3), np.sum(q), np.sum(deconvq))
print(np.argmax(qdeconv3), np.argmax(q), np.argmax(deconvq))
print(np.std(qdeconv3[:len(q)]), np.std(q), np.std(deconvq[:len(q)]))

plt.plot(qdeconv3[:len(q)], label='interpolation')
plt.plot(deconvq[:len(q)], label='qhat')
plt.plot(q, label='true')
plt.xlabel('tick [50ns]')
plt.ylabel('q per tick')
plt.title('Charge: deconvolution of interpolated waveform')
plt.legend();

In [None]:
def deconv_filter(wf, kernel, lam0, lam_hf, lam_exp):
    """Regularized (Wiener-like) FFT deconvolution of ``kernel`` out of ``wf``.

    Each frequency bin of ``wf`` is multiplied by ``conj(K) / (|K|**2 + lam)``,
    where ``K`` is the kernel spectrum over ``len(wf)`` points and::

        lam = lam0 + lam_hf * (f / f_nyq) ** lam_exp    if lam_hf > 0
        lam = lam0                                      otherwise

    with ``f`` the folded frequency index (0 .. len(wf)//2) and
    ``f_nyq = max(1, len(wf)//2)``. Larger ``lam`` suppresses bins where the
    kernel response is weak. With ``lam0 == 0`` and ``lam_hf <= 0`` this
    reduces to plain spectral division (no guard against zeros of K).

    Parameters
    ----------
    wf : array_like
        1-D waveform.
    kernel : array_like
        1-D kernel, not longer than ``wf``.
    lam0 : float
        Constant regularization added to every bin.
    lam_hf : float
        Scale of the frequency-dependent term; ignored unless > 0.
    lam_exp : float
        Exponent of the frequency-dependent term.

    Returns
    -------
    numpy.ndarray
        Real part of the inverse FFT, length ``len(wf)``.

    Raises
    ------
    ValueError
        If ``kernel`` is longer than ``wf`` (the FFT would silently crop it).
    """
    wf = np.asarray(wf)
    kernel = np.asarray(kernel)
    n = len(wf)
    if len(kernel) > n:
        raise ValueError(f'kernel length {len(kernel)} exceeds waveform length {n}')
    wf_fft = np.fft.fft(wf, n=n)
    kernel_fft = np.fft.fft(kernel, n=n)
    if lam_hf > 0.0:
        kt = np.arange(n, dtype=np.float64)
        # Fold indices above Nyquist back so positive/negative frequencies get the same weight.
        k_fold = np.minimum(kt, n - kt)
        fnyq = max(1, n // 2)
        lamvec = lam0 + lam_hf * (k_fold / float(fnyq)) ** float(lam_exp)
    else:
        # Constant regularization across all bins (same name as the branch above).
        lamvec = np.full((n,), lam0, dtype=np.float64)
    wiener_like = np.conj(kernel_fft) / (np.absolute(kernel_fft) ** 2 + lamvec)

    return np.fft.ifft(wf_fft * wiener_like).real

In [None]:
# Regularized deconvolution of the interpolated waveform.
qdeconv4 = deconv_filter(wf_interpo, fr0, lam0=50, lam_hf=10., lam_exp=1.)
print(np.sum(qdeconv4), np.sum(q), np.sum(deconvq))
print(np.argmax(qdeconv4), np.argmax(q), np.argmax(deconvq))
print(np.std(qdeconv4[:len(q)]), np.std(q), np.std(deconvq[:len(q)]))

plt.plot(qdeconv4[:len(q)], label='filtered (interpolated)')
plt.plot(deconvq[:len(q)], label='qhat')
plt.plot(q, label='true')
plt.xlabel('tick [50ns]')
plt.ylabel('q per tick')
plt.title('Charge: regularized deconvolution of interpolated waveform')
plt.legend();