In [None]:
from torch.utils.tensorboard import SummaryWriter
import MinkowskiEngine as ME
from torch import nn
import torch
import numpy as np

import matplotlib.pyplot as plt
import matplotlib as mpl

%matplotlib inline
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

## Tell pytorch we have a GPU if we do
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.device(device)

SEED=12345
_=np.random.seed(SEED)
_=torch.manual_seed(SEED)
writer = SummaryWriter("log")

In [None]:
## Use the common dataset loader
from ME_dataset_libs import SingleModuleImage2D_MultiHDF5_ME, triple_ME_collate_fn
from ME_dataset_libs import make_dense, make_dense_from_tensor, make_dense_array

In [None]:
import h5py
from scipy.sparse import coo_matrix

## This function pulls an image directly from the file, without going through a pytorch dataloader
## You need to have the file open first: f = h5py.File(input_file, 'r')
def show_image(i, f, shape=(800, 256)):
    """Plot one event image read directly from an open HDF5 file.

    The event is looked up as the group ``f[str(i)]``, which must hold three
    datasets: ``data`` (values), ``row`` and ``col`` (coordinates). These are
    assembled into a sparse COO matrix, expanded to a dense float32 array and
    drawn with a colour bar.

    Parameters
    ----------
    i : int or str
        Event key; converted with ``str(i)`` before the lookup.
    f : h5py.File
        Open file containing the events (any mapping with the same layout works).
    shape : tuple of int, optional
        Shape of the dense image handed to ``coo_matrix``. Defaults to
        ``(800, 256)``, the size that was previously hard-coded.

    Returns
    -------
    None
        The image is shown with ``plt.show()``.
    """
    group = f[str(i)]
    data = group['data'][:]
    row = group['row'][:]
    col = group['col'][:]

    ## Rebuild the image from its (value, row, col) triplets and densify it for plotting
    this_sparse = coo_matrix((data, (row, col)), dtype=np.float32, shape=shape)
    this_image = this_sparse.toarray()

    gr = plt.imshow(this_image, origin='lower')
    plt.colorbar(gr)
    plt.show()

In [None]:
import torchvision.transforms.v2 as transforms
from ME_dataset_libs import MaxRegionCrop, RandomGridDistortion2D, RandomShear2D, RandomRotation2D, RandomHorizontalFlip, \
    RandomBlockZeroImproved, RandomScaleCharge, RandomJitterCharge, ConstantCharge, SemiRandomCrop, DoNothing

## Image extents handed to the block-zeroing ranges and to the final crop
x_max = 256
y_max = 512

## Augmentation chain used to build the augmented views of each event
aug_transform = transforms.Compose([
    RandomGridDistortion2D(50, 3),
    RandomShear2D(0.1, 0.1),
    RandomRotation2D(-10, 10),
    RandomHorizontalFlip(),
    RandomBlockZeroImproved([0, 10], [5, 10], [0, x_max], [0, y_max]),
    RandomScaleCharge(0.02),
    RandomJitterCharge(0.02),
    SemiRandomCrop(x_max, y_max),
])

## Load some images into a data loader
## NOTE(review): absolute, user-specific paths; consider making these configurable
sim_dir = "/pscratch/sd/c/cwilk/FSD/SIMULATION"
data_dir = "/pscratch/sd/c/cwilk/FSD/DATA"

## The simulated sample is read from sim_dir. It used to be built from data_dir,
## which left sim_dir unused and made both datasets the same data sample.
sim_dataset = SingleModuleImage2D_MultiHDF5_ME(sim_dir, nom_transform=DoNothing(), aug_transform=aug_transform)
print("Found", len(sim_dataset), "simulated events")
data_dataset = SingleModuleImage2D_MultiHDF5_ME(data_dir, nom_transform=DoNothing(), aug_transform=aug_transform)
print("Found", len(data_dataset), "data events")

In [None]:
## Visualise data
def make_aug_comp_plot(dataset, n=0, orig_shape=(800, 256), aug_shape=(512, 256)):
    """Draw the original image of event ``n`` beside four augmented versions of it.

    ``dataset[n]`` is unpacked as six items: coordinates and features of two
    augmented views, followed by coordinates and features of the original
    image. The dataset is indexed twice so that four augmented views are
    available; the original from the second lookup is discarded.

    Parameters
    ----------
    dataset : indexable
        Dataset whose items unpack as described above.
    n : int, optional
        Index of the event to draw. Defaults to 0.
    orig_shape : tuple of int, optional
        The two size arguments passed to ``make_dense_array`` for the original
        image. Defaults to ``(800, 256)``.
    aug_shape : tuple of int, optional
        The two size arguments passed to ``make_dense_array`` for every
        augmented image. Defaults to ``(512, 256)``.

    Returns
    -------
    None
        A new 1x5 figure is created and left as the current figure.
    """
    ## The dataset works with pairs, so it is indexed twice to get four augmented examples
    aug1_bcoords, aug1_bfeats, aug2_bcoords, aug2_bfeats, orig_bcoords, orig_bfeats = dataset[n]
    aug3_bcoords, aug3_bfeats, aug4_bcoords, aug4_bfeats, _, _ = dataset[n]

    ## Panels from left to right: the original first, then the four augmentations
    panels = [
        (orig_bcoords, orig_bfeats, orig_shape),
        (aug1_bcoords, aug1_bfeats, aug_shape),
        (aug2_bcoords, aug2_bfeats, aug_shape),
        (aug3_bcoords, aug3_bfeats, aug_shape),
        (aug4_bcoords, aug4_bfeats, aug_shape),
    ]

    # Visualize the images side by side, hiding the axis ticks and labels
    _, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (bcoords, bfeats, shape) in zip(axes, panels):
        ax.imshow(make_dense_array(bcoords, bfeats.squeeze(), *shape), origin='lower')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)


In [None]:
## Dump a few events!
## For each of the first five indices, draw the augmentation comparison plot
## for the data sample and then for the simulated sample.
for n in range(5):
    for this_dataset in (data_dataset, sim_dataset):
        make_aug_comp_plot(this_dataset, n)
