In [None]:
import pickle
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import tensorflow_addons as tfa
from matplotlib import pyplot as plt

from freedom.neural_nets.transformations import hitnet_trafo, chargenet_trafo
from freedom.utils.i3cols_dataloader import load_events

In [None]:
def imshow_zero_center(image, **kwargs):
    """Show a 2-D array with the 'seismic' colormap, colour limits symmetric about zero.

    The limits are ``[-lim, lim]`` where ``lim`` is the largest *finite* absolute
    entry of ``image``, so equal positive and negative magnitudes get equal
    colour intensity.

    Parameters
    ----------
    image : array-like (numpy array or tf.Tensor), 2-D
        Matrix to display (e.g. a Hessian or covariance matrix).
    **kwargs
        Forwarded to ``plt.imshow``.
    """
    data = np.asarray(image)
    magnitudes = np.abs(data[np.isfinite(data)])
    # NaN/inf entries (e.g. from inverting an ill-conditioned matrix) must not
    # set the colour scale; fall back to 1 if nothing finite and non-zero remains.
    lim = float(magnitudes.max()) if magnitudes.size and magnitudes.max() > 0 else 1.0
    plt.figure(figsize=(12,9))
    plt.imshow(data, vmin=-lim, vmax=lim, cmap='seismic', **kwargs)
    plt.colorbar()

In [None]:
model_path = "../../freedom/resources/models/DeepCore/"

# HitNet (evaluated on individual hits further below). The output activation is
# swapped for the identity, then the model is re-compiled (presumably so Keras
# rebuilds its cached predict function with the modified layer).
model = keras.models.load_model(
    model_path + 'HitNet_ranger_08_Dec_2021-10h53/epoch_50_model.hdf5',
    custom_objects=dict(hitnet_trafo=hitnet_trafo),
)
model.layers[-1].activation = keras.activations.linear
model.compile()

In [None]:
# A single synthetic evaluation point for HitNet:
#   x: one row of 10 hit-level input features (presumably the same layout as a row
#      of event['hits'] — TODO confirm),
#   t: the 8-element parameter vector, presumably in the same order as `bf` below
#      (x, y, z, time, azimuth, zenith, cascade energy, track energy).
x = tf.Variable(initial_value=[[1, 1, -350, 11000, 1, 0, 0, 0, np.pi, 0]],
                dtype=tf.float32, name='x_var')
t = tf.Variable(initial_value=[[1, 50, -350, 10770, 4, 1, 10, 10]],
                dtype=tf.float32, name='t_var')

Gradients

In [None]:
# Gradient of the HitNet output w.r.t. the parameter vector t only (x is not watched).
with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(t)
    model_vals = model([x, t])
model_grad = tape.gradient(target=model_vals, sources=t)

model_grad.numpy()

Check gradients against a finite-difference estimate

In [None]:
# Scan the time parameter (index 3 of t) over [10770, 10771] with every other input
# fixed, then compare a forward finite difference at the scan start (t=10770, the
# point where model_grad was evaluated) with the autodiff gradient.
n_scan = 100
X = np.tile([1, 1, -350, 11000, 1, 0, 0, 0, np.pi, 0], n_scan).reshape((n_scan, 10))
o = np.ones(n_scan)
r = np.linspace(10770, 10771, n_scan)
T = np.stack([o, 50*o, -350*o, r, 4*o, o, 10*o, 10*o]).T

pred = model.predict([X, T])

plt.figure(figsize=(12,9))
plt.plot(r, pred, label='LLH')
plt.xlabel('time')
plt.ylabel('model output')
plt.legend()

# linspace with n_scan points over a unit interval has spacing 1/(n_scan-1),
# not 1/n_scan, so divide by the actual grid step.
finite_diff = (pred[1, 0] - pred[0, 0]) / (r[1] - r[0])
finite_diff, model_grad.numpy()[0][3]

Hessian

In [None]:
# Nested tapes: the inner tape gives d(output)/dt, the outer tape records that
# computation so its Jacobian w.r.t. t is the Hessian.
with tf.GradientTape(watch_accessed_variables=False) as tape2:
    tape2.watch(t)
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(t)
        model_vals = model([x, t])
    # taken inside tape2 so the gradient itself is differentiable
    model_grad = tape.gradient(model_vals, t)

model_hess = tape2.jacobian(model_grad, t)
# the Jacobian of a (1, 8) gradient w.r.t. a (1, 8) variable is (1, 8, 1, 8); show its diagonal
tf.linalg.diag_part(tf.reshape(model_hess, [8, 8])).numpy()

In [None]:
# Drop the singleton batch axes of the Jacobian to get the plain 8x8 Hessian.
hess_mat = tf.reshape(model_hess, shape=(8, 8))
imshow_zero_center(image=hess_mat)

Fisher information and covariance

In [None]:
def TikhonovCorrection(FisherMatrix, threshold=0.001):
    """Regularise a matrix with slightly negative eigenvalues by a Tikhonov shift.

    If the matrix has negative eigenvalues and all eigenvalues are greater than
    ``-threshold`` (i.e. the negative ones are small), ``c * I`` is added with
    ``c = 2 * max|lambda_neg|``. This shifts every eigenvalue up by ``c``, so the
    smallest one becomes ``max|lambda_neg| > 0``. If any eigenvalue is
    ``<= -threshold``, or none is negative, the matrix is returned unchanged.

    Parameters
    ----------
    FisherMatrix : array-like, square
        Input matrix. It is NOT modified.
    threshold : float
        Largest magnitude of a negative eigenvalue that will still be corrected.

    Returns
    -------
    numpy.ndarray
        A (possibly corrected) copy of ``FisherMatrix``. Prints '...worked!' when a
        correction was applied and all eigenvalues are then non-negative.
    """
    corrected = np.array(FisherMatrix, copy=True)
    if not np.issubdtype(corrected.dtype, np.inexact):
        # integer input: work in floating point so the shift is not truncated
        corrected = corrected.astype(float)

    eigvals = np.linalg.eigvals(corrected)
    negative = eigvals[eigvals < 0]
    if negative.size == 0 or not (eigvals > -threshold).all():
        return corrected

    # Twice the largest negative magnitude: a shift of exactly max|lambda_neg| would
    # leave the smallest eigenvalue at 0 (a singular matrix). The shift is applied
    # to the whole diagonal so it moves every eigenvalue by the same amount.
    correction = 2 * np.abs(negative).max()
    corrected[np.diag_indices_from(corrected)] += correction
    if (np.linalg.eigvals(corrected) >= 0).all():
        print('...worked!')
    return corrected

In [None]:
# Fisher matrix = negative Hessian of the model output (the 'LLH' scanned above), regularised.
fisher_mat = TikhonovCorrection(np.negative(hess_mat.numpy()))

In [None]:
# Covariance estimate = inverse of the Fisher matrix.
cov_mat = np.linalg.inv(a=fisher_mat)

In [None]:
# 1-sigma uncertainties from the covariance diagonal (negative variances show up as NaN).
np.sqrt(cov_mat.diagonal())

Real event

In [None]:
# ChargeNet (evaluated on the event's total charge in calculate_hess). As for HitNet,
# the output activation is replaced by the identity and the model re-compiled.
model_c = keras.models.load_model(
    model_path + 'ChargeNet_ranger_23_Nov_2021-13h51/epoch_2000_model.hdf5',
    custom_objects=dict(chargenet_trafo=chargenet_trafo),
)
model_c.layers[-1].activation = keras.activations.linear
model_c.compile()

In [None]:
def calculate_hess(event, h_model, c_model, theta, return_grad=False):
    """Hessian (and optionally gradient) of the combined network output w.r.t. theta.

    The differentiated quantity is

        sum_k hits[k, 4] * h_model([hits, theta])[k, 0]  +  c_model([total_charge, theta])

    i.e. the per-hit HitNet outputs weighted by column 4 of the hit array
    (presumably the hit charge — TODO confirm), plus the ChargeNet output.

    Parameters
    ----------
    event : mapping
        Must provide 'hits' (2-D array-like, one row per hit) and 'total_charge'
        (the ChargeNet input features for the event).
    h_model, c_model : Keras models
        Called as ``model([features, params])``.
    theta : 1-D array-like
        Parameter vector at which the derivatives are evaluated.
    return_grad : bool
        If True, also return the gradient.

    Returns
    -------
    numpy.ndarray of shape (len(theta), len(theta)) — the Hessian; and, if
    ``return_grad``, the gradient as a tf.Tensor of shape (len(theta),).
    """
    t = tf.Variable(theta)
    # The event data are constants; no need to wrap them in fresh tf.Variables per call.
    x_h = tf.convert_to_tensor(event['hits'])
    x_c = tf.convert_to_tensor([event['total_charge']])
    n_hits = len(event['hits'])

    with tf.GradientTape(watch_accessed_variables=False) as tape2:
        tape2.watch(t)
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(t)
            # One row of theta for ChargeNet and one per hit for HitNet; the row
            # length follows theta instead of a hard-coded 8.
            t_c = tf.reshape(t, (1, -1))
            t_h = tf.tile(t_c, [n_hits, 1])
            model_vals = tf.add(
                tf.math.reduce_sum(tf.multiply(x_h[:, 4], h_model([x_h, t_h])[:, 0])),
                c_model([x_c, t_c]),
            )
        # taken inside tape2 so the gradient itself is recorded for the second derivative
        model_grad = tape.gradient(model_vals, t)
    model_hess = tape2.jacobian(model_grad, t).numpy()

    if return_grad:
        return model_hess, model_grad
    return model_hess

In [None]:
# First 1000 test events. NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open('../../freedom/resources/test_data/test_events.pkl', mode='rb') as f:
    events = pickle.load(f)[:1000]

# Reconstruction results, truncated to one row per loaded event.
# NOTE(review): hard-coded absolute path on the analysis machine.
df = pd.read_pickle('/tf/localscratch/weldert/freeDOM/recos/OscNext/numu_only_noSmall_50.pkl')[:len(events)]

In [None]:
N = 0
# Reconstructed best fit of event N as the 8-element parameter vector. Rows of df
# are matched to `events` by position, so use positional (.iloc) rather than
# label lookup, which breaks if df's index is not 0..n-1.
bf = df[['x', 'y', 'z', 'time', 'azimuth', 'zenith',
         'cascade energy', 'track energy']].iloc[N].to_numpy(dtype=float)

In [None]:
%%time
hess_mat, grads = calculate_hess(events[N], model, model_c, bf, True)
np.diag(hess_mat)

In [None]:
# Fisher matrix = negative Hessian, with small negative eigenvalues regularised.
fisher_mat = TikhonovCorrection(np.negative(hess_mat))

In [None]:
# Covariance estimate = inverse of the Fisher matrix.
cov_mat = np.linalg.inv(a=fisher_mat)

In [None]:
# Visualise the covariance matrix with a zero-centred colour scale.
imshow_zero_center(image=cov_mat)

In [None]:
# 1-sigma uncertainties from the covariance diagonal (negative variances show up as NaN).
np.sqrt(cov_mat.diagonal())

In [None]:
# Residual: event 'params' (presumably the true parameters — TODO confirm) minus the best fit.
np.subtract(events[N]['params'], bf)

Minimize (damped Newton iterations starting from the reconstructed best fit)

In [None]:
# Refine the best fit with damped Newton iterations on the combined model output.
# Newton steps head for any stationary point; if the Hessian is not negative
# definite, a step need not increase the output.
N_NEWTON_STEPS = 10
NEWTON_DAMPING = 0.1

p = np.array(bf, dtype=float)  # float copy so bf is left untouched
for _ in range(N_NEWTON_STEPS):
    hess_mat, grads = calculate_hess(events[N], model, model_c, p, return_grad=True)
    # Solve H @ step = -g directly instead of forming inv(H): cheaper and more stable.
    newton_step = -np.linalg.solve(hess_mat, np.asarray(grads))
    p += NEWTON_DAMPING * newton_step
p

In [None]:
# Residual of the Newton-refined parameters against the event's 'params'.
np.subtract(events[N]['params'], p)

Many events

In [None]:
# Per-event uncertainty study over all loaded events. It is expensive (one Hessian
# per event), so it only runs when RECOMPUTE is set; it then writes resi.npy,
# std.npy and curv.npy, which the next cell loads.
RECOMPUTE = False

if RECOMPUTE:
    param_columns = ['x', 'y', 'z', 'time', 'azimuth', 'zenith', 'cascade energy', 'track energy']
    # positional rows, aligned with the order of `events`
    best_fits = df[param_columns].to_numpy(dtype=float)

    resi, std, curv = [], [], []
    for i, e in enumerate(events):
        bf = best_fits[i]

        hess_mat = calculate_hess(e, model, model_c, bf)
        fisher_mat = TikhonovCorrection(-hess_mat)
        try:
            cov_mat = np.linalg.inv(fisher_mat)
        except np.linalg.LinAlgError:
            # singular Fisher matrix: record NaNs instead of aborting the whole run
            cov_mat = np.full(fisher_mat.shape, np.nan)

        std.append(np.sqrt(np.diag(cov_mat)))
        resi.append(e['params'] - bf)
        curv.append(-np.diag(hess_mat))

    np.save('resi', resi)
    np.save('std', std)
    np.save('curv', curv)

In [None]:
# Cached per-event results of the 'Many events' study:
#   resi = e['params'] - bf, std = sqrt(diag(covariance)), curv = -diag(Hessian).
resi = np.load('resi.npy')
std = np.load('std.npy')
curv = np.load('curv.npy')
# number of NaN uncertainties (e.g. from negative variances) and the array shape
np.isnan(std).sum(), std.shape

In [None]:
# |residual| vs. Fisher-based std for parameter i (index 5 = zenith in the `bf`
# ordering, presumably), with the identity line for reference.
i = 5
r = np.abs(resi[:, i])
s = std[:, i]
ok = np.isfinite(s)
diag_end = np.max(r)

fig, ax = plt.subplots(figsize=(12, 9))
ax.scatter(r[ok], s[ok])
ax.plot([0, diag_end], [0, diag_end], c='black')
ax.set_xlabel('|residuum|')
ax.set_ylabel('std from covariance from fisher from hessian')
ax.set_ylim(0, 3)
ax.set_xlim(0, 3)
np.corrcoef(r[ok], s[ok])[0, 1]

In [None]:
# |residual| vs. inverse curvature for parameter i (index 7 = track energy in the
# `bf` ordering, presumably). curv holds -diag(Hessian), so s = 1 / curv.
i = 7
r = np.abs(resi[:, i])
with np.errstate(divide='ignore'):
    # zero curvature gives inf; such points are dropped by the finiteness mask below
    s = 1 / curv[:, i]
ok = np.isfinite(r) & np.isfinite(s)

fig, ax = plt.subplots(figsize=(12, 9))
ax.scatter(r[ok], s[ok])
ax.set_xlabel('|residuum|')
# the plotted quantity is the reciprocal of -H_ii, not "-curvature"
ax.set_ylabel('1 / curvature')
np.corrcoef(r[ok], s[ok])[0, 1]