In [None]:
import MinkowskiEngine as ME
from torch import nn
import torch
import numpy as np

import matplotlib.pyplot as plt
import matplotlib as mpl

%matplotlib inline
## Notebook-wide plot styling: figure size, font size and grid lines on every axes
mpl.rcParams.update({
    'figure.figsize': [8, 6],
    'font.size': 16,
    'axes.grid': True,
})

## Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
## Nothing is configured globally here; later cells pass `device` explicitly
## when they move tensors or build sparse tensors.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Seed numpy and pytorch so the notebook gives the same results on every run
SEED=12345
np.random.seed(SEED)
## manual_seed returns a Generator; bind it so the cell does not echo it
_ = torch.manual_seed(SEED)

In [None]:
## Use the common dataset loader
from ME_dataset_libs import SingleModuleImage2D_solo_ME, solo_ME_collate_fn
from ME_dataset_libs import make_dense, make_dense_from_tensor, make_dense_array

In [None]:
import h5py
from scipy.sparse import coo_matrix

def show_image(i, f, shape=(800, 256)):
    """Display one stored image straight from the file, bypassing the pytorch dataloader.

    Parameters
    ----------
    i : int or str
        Event identifier; the group ``f[str(i)]`` is read.
    f : mapping
        An already opened input file, e.g. ``f = h5py.File(input_file, 'r')``.
        Each event group must provide the entries ``'data'``, ``'row'`` and ``'col'``.
    shape : tuple of int, optional
        (rows, columns) of the dense image. Defaults to (800, 256).

    Returns
    -------
    None
        The image is drawn with a colorbar and shown via ``plt.show()``.
    """
    group = f[str(i)]
    data = group['data'][:]
    row = group['row'][:]
    col = group['col'][:]

    ## The file stores each image as (value, row, col) triplets; rebuild the
    ## sparse matrix from them and expand it to a dense array for plotting
    this_sparse = coo_matrix((data, (row, col)), dtype=np.float32, shape=shape)
    this_image = this_sparse.toarray()

    gr = plt.imshow(this_image, origin='lower')
    plt.colorbar(gr)
    plt.show()

In [None]:
## Setup the dataloader
import time
from ME_dataset_libs import FirstRegionCrop, DoNothing
## Wall-clock timer: loading is dominated by file access, which a CPU-time
## clock (time.process_time) would not count
start = time.perf_counter()

## Modify the nominal transform
#nom_transform = transforms.Compose([
#            FirstRegionCrop((800, 256), (768, 256)),
#            ConstantCharge(),
#            ])
nom_transform=DoNothing()

data_dir = "/pscratch/sd/c/cwilk/FSD/DATA"
sim_dir = "/pscratch/sd/c/cwilk/FSD/SIMULATIONv2"
max_data_events=200000
max_sim_events=200000
single_sim_dataset = SingleModuleImage2D_solo_ME(sim_dir, transform=nom_transform, max_events=max_sim_events)
single_data_dataset = SingleModuleImage2D_solo_ME(data_dir, transform=nom_transform, max_events=max_data_events)

print("Time taken to load", len(single_data_dataset),"data and", len(single_sim_dataset), "images:", time.perf_counter() - start)

## Both loaders use identical settings. shuffle must stay False: the arrays
## built from these loaders are later used as positional indices back into
## the datasets, so the loader order has to match the dataset order.
loader_kwargs = dict(collate_fn=solo_ME_collate_fn,
                     batch_size=1024,
                     shuffle=False,
                     num_workers=4)

data_loader   = torch.utils.data.DataLoader(single_data_dataset, **loader_kwargs)

sim_loader    = torch.utils.data.DataLoader(single_sim_dataset, **loader_kwargs)

In [None]:
def raw_image_loop(loader):
    """Collect simple per-image summary quantities from every batch of a loader.

    Parameters
    ----------
    loader : iterable
        Yields batches whose first three items are (coordinates, features,
        labels); any further items are ignored.

    Returns
    -------
    dict of numpy.ndarray
        One entry per image, in loader order, under the keys:
        ``"nhits"``  - number of rows in the image's feature tensor,
        ``"sumQ"``   - sum of the image's features,
        ``"maxQ"``   - largest feature value in the image,
        ``"labels"`` - the label delivered by the loader,
        ``"yrange"`` - max minus min of coordinate column 0,
        ``"xrange"`` - max minus min of coordinate column 1.
    """
    nhits = []
    maxQ = []
    sumQ = []
    labels = []
    y_range = []
    x_range = []

    ## Loop over the images (discard any extra info returned by loader)
    for batch_coords, batch_feats, batch_labels, *_ in loader:

        batch_coords = batch_coords.to(device)
        batch_feats = batch_feats.to(device)
        orig_batch = ME.SparseTensor(batch_feats, batch_coords, device=device)

        ## Split the batch into per-image pieces once and reuse them below
        ## NOTE(review): these look like computed properties in ME, so reading
        ## them once should avoid repeated work - confirm
        per_image_coords = orig_batch.decomposed_coordinates
        per_image_feats = orig_batch.decomposed_features

        for coords in per_image_coords:
            y_range.append(torch.max(coords[:,0]).item()-torch.min(coords[:,0]).item())
            x_range.append(torch.max(coords[:,1]).item()-torch.min(coords[:,1]).item())

        for feats in per_image_feats:
            nhits.append(feats.shape[0])
            sumQ.append(feats.sum().item())
            maxQ.append(feats.max().item())

        labels.extend(batch_labels)

    ## Return a dictionary to make my life easier
    return {
        "nhits": np.array(nhits),
        "sumQ": np.array(sumQ),
        "maxQ": np.array(maxQ),
        "labels": np.array(labels),
        "yrange": np.array(y_range),
        "xrange":np.array(x_range),
    }


In [None]:
## Build the per-image summary dictionaries: real data first, then simulation
data_raw = raw_image_loop(loader=data_loader)
sim_raw = raw_image_loop(loader=sim_loader)

In [None]:
## Visualise data
from matplotlib import cm
from ME_dataset_libs import Label
def _select_example_index(labels, nhits, label_value, label_name, min_hits):
    """Return the index of the event that illustrates one label, or None if no event qualifies."""
    candidates = np.where(labels == label_value)[0]

    ## Through-going and stopping tracks: take the first event above the hit threshold
    if "THROUGH" in label_name or "STOPPING" in label_name:
        candidates = candidates[nhits[candidates] > min_hits]
        return candidates[0] if candidates.size else None

    ## Every other label: take the event with the most hits
    if candidates.size == 0:
        return None
    return candidates[np.argmax(nhits[candidates])]


def make_label_comp_plot(sim_dataset, sim_raw, min_hits=20, save_name=None, image_shape=(800, 256)):
    """Draw one example event per truth label, side by side, with the label fractions.

    Parameters
    ----------
    sim_dataset : indexable dataset
        ``sim_dataset[idx]`` must return (coords, feats, ...) for event ``idx``.
    sim_raw : dict
        Per-event arrays with at least the keys ``'labels'`` and ``'nhits'``,
        ordered like ``sim_dataset``.
    min_hits : int, optional
        Hit threshold applied only to labels whose name contains "THROUGH" or
        "STOPPING"; the first event above the threshold is shown. For all other
        labels the event with the most hits is shown.
    save_name : str, optional
        If given, the figure is saved to this path before being shown.
    image_shape : tuple of int, optional
        Two sizes forwarded to ``make_dense_array``. Defaults to (800, 256).

    Returns
    -------
    None

    Notes
    -----
    A label without any qualifying event gets an empty panel (caption only)
    and a warning, rather than aborting the whole figure.
    """
    import warnings

    cmap = cm.turbo.copy()
    cmap.set_under("#F0F0F0")

    ## All label values and names, in the order the Label enum defines them
    label_values = [m.value for m in Label]
    label_names  = [m.name for m in Label]

    ## OLD
    # new_names = [r"EM", r"Neutron", r"Proton", r"External $\mu$", r"Multi-$\mu$", r"$\mu$-capture", r"$\mu$-decay", "Clean MIP", "Messy MIP"]
    # labels_to_show = [1, 2, 3, 6, 5, 7, 8, 10, 11]

    ## NEW
    new_names = ["Clean MIP", "Messy MIP", r"$\mu$-capture", r"$\mu$-decay", r"Multi-$\mu$", r"EM", r"Neutron", r"Proton", r"External $\mu$"]
    ## NOTE(review): these numbers are used as positions into label_values and
    ## label_names, which only equals the label value if the Label enum is
    ## numbered 0, 1, 2, ... without gaps - confirm against ME_dataset_libs
    labels_to_show = [10, 11, 7, 8, 5, 1, 2, 3, 6]

    nlabels = len(labels_to_show)

    ## Set up the figure so there's one subfigure per label
    fig, axes = plt.subplots(1, nlabels, figsize=(nlabels*1.8,6), squeeze=False)

    ## Loop over labels
    for ax, lval, title in zip(axes[0], labels_to_show, new_names):
        label_value = label_values[lval]
        label_name = label_names[lval]

        ## Calculate fractions
        this_frac = np.sum(sim_raw['labels'] == label_value)/float(len(sim_dataset))*100

        idx = _select_example_index(sim_raw['labels'], sim_raw['nhits'], label_value, label_name, min_hits)

        if idx is None:
            warnings.warn(f"No simulated event available for label {label_name}; leaving its panel empty")
        else:
            coords, feats, *_ = sim_dataset[idx]
            inputs = make_dense_array(coords, feats.squeeze(), *image_shape)

            ## Cap the colour scale at the 80th percentile of the non-zero pixels;
            ## pixels below vmin are drawn in the colormap's "under" colour
            nonzero_vals = inputs[inputs > 0]
            vmax = np.percentile(nonzero_vals, 80)
            ax.imshow(inputs, origin='lower', cmap=cmap, vmin=1e-6, vmax=vmax)

        ## The axes are hidden, so the caption is placed as text below the panel
        ax.axis('off')
        ax.text(0.5, -0.01, f"{title}\n({this_frac:.2f}%)", ha='center', va='top', transform=ax.transAxes, fontsize=12)

    if save_name: fig.savefig(save_name, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)



In [None]:
## One example event per label from the simulation, saved as a JPEG
make_label_comp_plot(
    single_sim_dataset,
    sim_raw,
    min_hits=600,
    save_name='example_labels.jpg',
)