In [None]:
# Standard library
import os
import sys
from pathlib import Path

# Third-party
import numpy as np
import torch

# Make the project's local modules under ../model importable (imported in a later cell)
sys.path.append('../model')

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt

# Apply matplotlib's built-in 'classic' style globally, so every later figure uses it.
mpl.style.use('classic')

In [None]:
from models import SimpleCNN2Layer as Model
from collectdata import collect_data
from training import select_gpu
from plots import plot_ruiplot
from efficiency import pv_locations, efficiency

In [None]:
# Argument handed to the project's training.select_gpu helper (presumably a GPU index).
GPU_INDEX = 0
device = select_gpu(GPU_INDEX)

In [None]:
# Validation sample. Set the PVFINDER_DATA_DIR environment variable to point at a
# local copy of the data; the default is the original shared location.
DATA_DIR = Path(os.environ.get('PVFINDER_DATA_DIR', '/share/lazy/schreihf/PvFinder'))
valfile = DATA_DIR / 'Oct03_20K_val.npz'
if not valfile.exists():
    # Fail here with the offending path instead of somewhere inside collect_data.
    raise FileNotFoundError(f"Validation file not found: {valfile}")

validation = collect_data(valfile,
                          batch_size=1,
                          slice=slice(100),   # presumably keeps the first 100 events
                          masking=True,
                          device=device)

In [None]:
# Sanity check on the (masked) targets: NaN count summed over axis 1
# (presumably the bin axis, giving one count per event), printed space-separated.
nan_per_event = np.isnan(validation.dataset.tensors[1].cpu().numpy()).sum(axis=1)
print(*nan_per_event)

In [None]:
# Checkpoint (state_dict) for the SimpleCNN2Layer model evaluated below.
name = '../notebooks/Sep_18_mask_120000_2layer/Sep_18_mask_120000_2layer_5.pyt'
model = Model().to(device)
# map_location remaps stored tensors onto the selected device, so a checkpoint
# saved on a different GPU (or a GPU checkpoint on a CPU-only host) still loads.
model.load_state_dict(torch.load(name, map_location=device))
# Inference mode; the trailing ';' suppresses the (large) module repr in the notebook.
model.eval();

In [None]:
%%time
with torch.no_grad():
    outputs = model(validation.dataset.tensors[0]).cpu().numpy()
    labels = validation.dataset.tensors[1].cpu().numpy()

In [None]:
# Visual inspection of individual events: for each of the first N_EVENTS validation
# events, find true and predicted PV positions, then draw a zoomed plot around every
# point of interest labelled "PV found", "PV not found" or "False positive".

N_EVENTS = 10                 # how many validation events to inspect
MATCH_BINS = 5                # |truth - prediction| <= MATCH_BINS bins counts as "near"
EFFICIENCY_DIFFERENCE = 5.0   # `difference` argument passed to efficiency()
PV_PARAMETERS = {
    "threshold": 1e-2,
    "integral_threshold": .2,
    "min_width": 3,
}

inputs = validation.dataset.tensors[0].cpu().numpy().squeeze()
# Bin centres in mm: 4000 bins of 0.1 mm covering [-100, 300).
zvals = np.linspace(-100, 300, 4000, endpoint=False) + 0.05


def bins_to_mm(positions):
    """Convert (fractional) bin positions to z in mm: bin 0 -> -100 mm, 10 bins per mm.

    NOTE(review): this maps a bin index to the bin's lower edge, whereas ``zvals``
    holds bin centres (offset by +0.05 mm); confirm which convention
    ``pv_locations`` returns.
    """
    return positions / 10 - 100


def points_of_interest(ftruth, fcomputed, min_separation=MATCH_BINS):
    """Merge true and predicted PV positions into a sorted array of integer bins.

    Positions are rounded to the nearest bin. After sorting, any point within
    ``min_separation`` bins of the preceding sorted point is dropped, so one plot
    covers a nearby truth/prediction pair. Returns an empty array when the event
    has no PVs at all. An unguarded ``[True]`` mask would not match a
    zero-length array and would raise IndexError.
    """
    truth = np.around(ftruth).astype(np.int32)
    computed = np.around(fcomputed).astype(np.int32)
    poi = np.sort(np.concatenate([truth, computed]))
    if poi.size == 0:
        return poi
    # poi is sorted, so consecutive differences are already non-negative.
    keep = np.concatenate([[True], np.diff(poi) > min_separation])
    return poi[keep]


def classify(in_truth, in_comp):
    """Label a point of interest by whether a true and/or a predicted PV lies near it."""
    if in_truth and in_comp:
        return 'PV found'
    if in_truth:
        return 'PV not found'
    return 'False positive'


def annotate_centroids(ax, truth_mm, comp_mm, y_start=.8, y_step=.07):
    """Write true and predicted PV positions (mm) down the left side of ``ax``.

    When exactly one true and one predicted PV are present, also write their
    absolute difference in micrometres.
    """
    y = y_start
    for value in truth_mm:
        ax.text(.02, y, f"True: {value:.3f} mm", transform=ax.transAxes)
        y -= y_step
    for value in comp_mm:
        ax.text(.02, y, f"Pred: {value:.3f} mm", transform=ax.transAxes)
        y -= y_step
    if len(truth_mm) == 1 and len(comp_mm) == 1:
        diff = np.fabs(truth_mm[0] - comp_mm[0]) * 1_000
        ax.text(.02, y, f"∆: {diff:.0f} µm", transform=ax.transAxes)


finalmsg = ''
internal_count = 0

for n in range(min(N_EVENTS, len(outputs))):
    # `event_input` rather than `input`: assigning `input` at top level would shadow
    # the builtin for the rest of the kernel session.
    event_input = inputs[n]
    label = labels[n]
    output = outputs[n]

    ftruth = pv_locations(label, **PV_PARAMETERS)
    fcomputed = pv_locations(output, **PV_PARAMETERS)
    results = efficiency(label, output, difference=EFFICIENCY_DIFFERENCE, **PV_PARAMETERS)
    finalmsg += str(results) + '\n'

    for i in points_of_interest(ftruth, fcomputed):
        b_truth = np.fabs(ftruth - i) <= MATCH_BINS
        b_comp = np.fabs(fcomputed - i) <= MATCH_BINS
        msg = classify(np.any(b_truth), np.any(b_comp))

        with plt.style.context({'font.size': 18, 'font.weight': 'bold'}):
            ax1, _ = plot_ruiplot(zvals, i, event_input, label, output)
            ax1.set_title(f"Event {n}: {msg}", fontdict={'size': 18, 'weight': 'bold'})
            annotate_centroids(ax1,
                               bins_to_mm(ftruth[b_truth]),
                               bins_to_mm(fcomputed[b_comp]))

            # NOTE(review): the filename prefix says "3layer", but the model evaluated
            # in this notebook is SimpleCNN2Layer from a *_2layer checkpoint. Confirm
            # the intended name.
            plt.savefig(f'120000_3layer_{internal_count:02}.pdf')
            plt.show()
            # Release the figure explicitly so many plots don't pile up in memory.
            plt.close()
            internal_count += 1

print(finalmsg)