In [None]:
import bootstrap
import MinkowskiEngine as ME
from torch import nn
import torch
import numpy as np

import matplotlib.pyplot as plt
import matplotlib as mpl

%matplotlib inline
## Notebook-wide matplotlib defaults: figure size, font size and grid lines on every axes
mpl.rcParams.update({
    'figure.figsize': [8, 6],
    'font.size': 16,
    'axes.grid': True,
})

## Pick the compute device: the GPU when CUDA is available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Seed the numpy and torch random number generators so reruns are repeatable
SEED = 12345
np.random.seed(SEED)
## torch.manual_seed returns a Generator; bind it so the cell prints nothing
_ = torch.manual_seed(SEED)

In [None]:
## Use the common dataset loader
from core.data.datasets import paired_2d_dataset_ME, triple_ME_collate_fn
from core.analysis.image_utils import make_dense, make_dense_from_tensor, make_dense_array

In [None]:
import h5py
from scipy.sparse import coo_matrix

## This function pulls a single image directly from the file, without going through a pytorch dataloader
## It needs an already-open file handle: f = h5py.File(input_file, 'r')
def show_image(i, f, shape=(800, 256)):
    """Draw one sparse image stored in an open file as a dense 2D plot.

    Args:
        i: Index of the image; ``str(i)`` is used as the key into ``f``.
        f: Mapping of groups (e.g. an open ``h5py.File``). ``f[str(i)]`` must
            provide the entries ``'data'``, ``'row'`` and ``'col'``.
        shape: ``(n_rows, n_cols)`` of the dense image. Defaults to the
            previously hard-coded ``(800, 256)``; pass the real image size
            for files with a different geometry.

    Raises:
        KeyError: If ``f`` has no group ``str(i)`` or an entry is missing.
        ValueError: Raised by ``coo_matrix`` if an index lies outside ``shape``.
    """
    group = f[str(i)]
    data = group['data'][:]
    row = group['row'][:]
    col = group['col'][:]

    ## The file stores COO triplets (value, row, col); expand them to a dense array for plotting
    this_sparse = coo_matrix((data, (row, col)), dtype=np.float32, shape=tuple(shape))
    this_image = this_sparse.toarray()

    gr = plt.imshow(this_image, origin='lower')
    plt.colorbar(gr)
    plt.show()

In [None]:
import torchvision.transforms.v2 as transforms

# Force reload so I can play with changes outside jupyter...
# Order matters: each module is imported, then reloaded from disk, and only
# then bound to its short alias, so re-running this cell picks up edits made
# to the source files without restarting the kernel.
import importlib
import core.data.augmentations_2d
importlib.reload(core.data.augmentations_2d)
import core.data.augmentations_2d as aug  # short alias `aug`
import datasets.nularbox.augmentations_2d
importlib.reload(datasets.nularbox.augmentations_2d)
import datasets.nularbox.augmentations_2d as nu_aug  # short alias `nu_aug`

## Size of the augmented images: used as the crop target and when plotting them
x_max = 256
y_max = 256

## Size handed to the augmentations as the extent of the input images
x_orig = 512
y_orig = 512

## Passed as `p` to the stochastic augmentations (presumably the probability of applying each one)
aug_prob = 1

## Augmentation chain, applied in the order listed and finishing with the crop to (y_max, x_max)
mod_transform = transforms.Compose([
    aug.RandomBlockZeroImproved([5, 20], [5, 10], [0, x_orig], [0, y_orig], p=aug_prob),
    aug.RandomBlockZeroImproved([50, 200], [1, 3], [0, x_orig], [0, y_orig], p=aug_prob),
    aug.RandomVerticalFlip(y_max=y_orig, p=0.5),
    aug.GridJitter(),
    aug.JitterCoords(),
    nu_aug.RandomCentralRotation2D(30, img_size=[y_orig, x_orig], frac=0.2, p=aug_prob),
    nu_aug.RandomCentralShear2D(0.2, 0.2, img_size=[y_orig, x_orig], frac=0.4, p=aug_prob),
    nu_aug.RandomCentralStretch2D(0.1, 0.1, img_size=[y_orig, x_orig], frac=0.4, p=aug_prob),
    aug.RandomGridDistortion2D(50, 4, 2, 10, p=aug_prob),
    aug.RandomScaleCharge(0.05, p=aug_prob),
    aug.RandomJitterCharge(0.05, p=aug_prob),
    aug.BilinearSplatMod(0.2, 0.3, p=aug_prob),
    nu_aug.RandomCenterCrop([y_orig, x_orig], [y_max, x_max], 10),
])

aug_transform = mod_transform

## Build the paired dataset from the simulation files
sim_dir = "/pscratch/sd/c/cwilk/NULARBOX/GENIE10a_512x512"

## The nominal branch is passed through untouched
nom_transform = transforms.Compose([
    aug.DoNothing(),
])

sim_dataset = paired_2d_dataset_ME(sim_dir, nom_transform=nom_transform, aug_transform=aug_transform, max_events=100000)
print("Found", len(sim_dataset), "events")

In [None]:
import numpy as np
import time

def profile_transforms_avg(dataset, transform_block, n_events=10):
    """Measure the mean wall-clock time of each transform in a composed chain.

    For each profiled event the last two elements of ``dataset[idx]``
    (coords, feats) are pushed through ``transform_block.transforms`` in
    order, every transform receiving the output of the previous one, and each
    call is timed with ``time.perf_counter``.

    Args:
        dataset: Sized, indexable collection whose items unpack into six
            values; only the final two (coords, feats) are used.
        transform_block: Object exposing a ``transforms`` sequence of
            callables with the signature ``t(coords, feats) -> (coords, feats)``.
        n_events: Number of leading events to profile. Clamped to
            ``len(dataset)`` so a short dataset does not raise ``IndexError``.

    Returns:
        dict mapping each transform's class name to its mean time per call in
        seconds. Transforms that share a class name are pooled into a single
        entry. The value is NaN when no event was profiled.
    """
    times_accum = {t.__class__.__name__: [] for t in transform_block.transforms}

    # Never index past the end of the dataset
    n_events = min(n_events, len(dataset))

    for idx in range(n_events):
        _, _, _, _, coords, feats = dataset[idx]
        for t in transform_block.transforms:
            start = time.perf_counter()
            coords, feats = t(coords, feats)
            end = time.perf_counter()
            times_accum[t.__class__.__name__].append(end - start)

    # np.mean of an empty list warns and yields NaN; return NaN without the warning
    times_avg = {k: (np.mean(v) if v else np.nan) for k, v in times_accum.items()}
    return times_avg

In [None]:
## Time every transform of the augmentation chain over the first 1000 events
times = profile_transforms_avg(sim_dataset, mod_transform, n_events=1000)

## Order the (name, seconds) pairs from slowest to fastest
times_sorted = list(times.items())
times_sorted.sort(key=lambda entry: entry[1], reverse=True)

## Report the mean time per call in milliseconds
for name, t in times_sorted:
    print(f"{name:30s}: {t*1000:.3f} ms")

In [None]:
## Visualise data
from matplotlib import cm
from mpl_toolkits.axes_grid1 import make_axes_locatable

def make_aug_comp_plot(dataset, ids=(0,), save_name=None):
    """Plot each selected event's original image next to four augmented versions.

    One row is drawn per entry of ``ids``. The first panel is the original
    image, densified with ``make_dense_array`` using ``(y_orig, x_orig)``; the
    other four are augmented images densified using ``(y_max, x_max)``. Each
    panel has its own colour scale, capped at the 80th percentile of its
    positive pixels, and pixels below ``1e-6`` are drawn in light grey.

    Args:
        dataset: Indexable dataset whose items unpack as
            (aug1_coords, aug1_feats, aug2_coords, aug2_feats, orig_coords, orig_feats).
        ids: Iterable of dataset indices to draw, one figure row each.
        save_name: Optional output path; when given, the figure is also saved
            there at 300 dpi.
    """
    ids = list(ids)
    n_ids = len(ids)
    if n_ids == 0:
        # Nothing to draw
        return

    ## Panels per row: the original plus four augmented versions
    n_augs = 5
    ## Lower edge of the colour scale; anything below uses the "under" colour
    vmin = 1e-6

    cmap = cm.turbo.copy()
    cmap.set_under("#F0F0F0")

    ## Set up the canvas
    fig = plt.figure(figsize=(n_augs*2.1, n_ids*6))

    ## To keep track of the subplot
    running_n = 0

    ## Loop over images
    for i in ids:

        ## The dataset works with pairs, so fetch the same event twice to get four augmentations
        aug1_bcoords, aug1_bfeats, aug2_bcoords, aug2_bfeats, orig_bcoords, orig_bfeats = dataset[i]
        aug3_bcoords, aug3_bfeats, aug4_bcoords, aug4_bfeats, _, _ = dataset[i]

        images = [make_dense_array(orig_bcoords, orig_bfeats.squeeze(), y_orig, x_orig),
                  make_dense_array(aug1_bcoords, aug1_bfeats.squeeze(), y_max, x_max),
                  make_dense_array(aug2_bcoords, aug2_bfeats.squeeze(), y_max, x_max),
                  make_dense_array(aug3_bcoords, aug3_bfeats.squeeze(), y_max, x_max),
                  make_dense_array(aug4_bcoords, aug4_bfeats.squeeze(), y_max, x_max)]

        ## Loop over the original and its augmentations
        for image in images:
            running_n += 1
            ax = fig.add_subplot(n_ids, n_augs, running_n)

            ## An augmentation can leave no positive pixels; np.percentile would
            ## fail on the empty selection, so fall back to the lower edge
            positive_vals = image[image > 0]
            if positive_vals.size:
                vmax = max(np.percentile(positive_vals, 80), vmin)
            else:
                vmax = vmin

            ax.imshow(image, origin='lower', cmap=cmap, vmin=vmin, vmax=vmax)
            ax.axis('off')

    fig.tight_layout()
    if save_name:
        fig.savefig(save_name, dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
## Draw the augmentation comparison for the first events (at most 50) of the dataset
for evt in range(min(50, len(sim_dataset))):
    make_aug_comp_plot(sim_dataset, ids=[evt])