In [None]:
from torch.utils.tensorboard import SummaryWriter
import MinkowskiEngine as ME
from torch import nn
import torch
import numpy as np

import matplotlib.pyplot as plt
import matplotlib as mpl

%matplotlib inline
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

## Pick the GPU when one is available, otherwise fall back to the CPU.
## This only builds a device handle: tensors and modules still have to be moved with .to(device)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Seed numpy and torch so the random choices made below are repeatable
SEED=12345
_=np.random.seed(SEED)
_=torch.manual_seed(SEED)

## SummaryWriter is pointed at the "log" directory
writer = SummaryWriter("log")

In [None]:
## Use the common dataset loader
from ME_dataset_libs import SingleModuleImage2D_MultiHDF5_ME, triple_ME_collate_fn
from ME_dataset_libs import make_dense, make_dense_from_tensor, make_dense_array

In [None]:
import h5py
from scipy.sparse import coo_matrix

## This function just pulls an image directly from the file, without going through a pytorch dataloader
## You would need to have a file open: f = h5py.File(input_file, 'r')
def show_image(i, f, shape=(800, 256)):
    """Draw event ``i`` of an open HDF5 file as a dense 2D image.

    Parameters
    ----------
    i : int or str
        Event key; the group ``f[str(i)]`` is read.
    f : mapping of groups (e.g. an open ``h5py.File``)
        Each group must provide the ``'data'``, ``'row'`` and ``'col'`` datasets.
    shape : tuple of int, optional
        (rows, columns) of the dense image. Defaults to the (800, 256) size
        this notebook uses for a single image.

    Returns
    -------
    None
        The image is shown with ``plt.imshow`` plus a colour bar.
    """
    group = f[str(i)]
    data = group['data'][:]
    row = group['row'][:]
    col = group['col'][:]

    ## The event is stored sparsely as (data, row, col) triplets; densify it for imshow
    this_sparse = coo_matrix((data, (row, col)), dtype=np.float32, shape=shape)
    this_image = this_sparse.toarray()

    gr = plt.imshow(this_image, origin='lower')
    plt.colorbar(gr)
    plt.show()

In [None]:
import torchvision.transforms.v2 as transforms

## Modified splat function in which the threshold is varied randomly in some range
## + the probability that a threshold is applied at all is variable.
class BilinearSplatModSpeedy:
    """Bilinear splat of fractional (y, x) points onto the integer pixel grid.

    Every input point shares its feature between the four surrounding integer
    pixels using bilinear weights, and contributions that land on the same
    pixel are summed. With probability ``p`` the pixels whose summed feature is
    below a threshold are then dropped; the threshold is drawn uniformly from
    ``[threshold_min, threshold_max]`` on every call.

    Parameters
    ----------
    threshold_min, threshold_max : float
        Range the per-call threshold is drawn from.
    p : float
        Probability that the threshold is applied at all.
    """

    def __init__(self, threshold_min=0.04, threshold_max=0.04, p=0.5):
        self.threshold_min=threshold_min
        self.threshold_max=threshold_max
        self.p = p

    @staticmethod
    def _sum_duplicates(pixel_coords, pixel_feats):
        """Sum the features of rows of ``pixel_coords`` that name the same pixel.

        Returns the unique (y, x) pixels sorted by y then x, and their summed features.
        """
        if pixel_coords.shape[0] == 0:
            return pixel_coords.reshape(0, 2), np.zeros(0, dtype=pixel_feats.dtype)

        ## Build the 1D key relative to the bounding box of the points. A fixed
        ## "y * 10000 + x" key decodes to the wrong pixel as soon as x is negative
        ## (x = -1 comes back as column 9999 of the previous row) or x >= 10000.
        origin = pixel_coords.min(axis=0)
        shifted = pixel_coords - origin
        width = int(shifted[:, 1].max()) + 1

        keys = shifted[:, 0] * width + shifted[:, 1]
        unique_keys, inverse = np.unique(keys, return_inverse=True)

        ## np.add.at accumulates repeated indices, unlike summed[inverse] += ...
        summed = np.zeros(len(unique_keys), dtype=pixel_feats.dtype)
        np.add.at(summed, inverse.reshape(-1), pixel_feats)

        unique_coords = np.stack([unique_keys // width, unique_keys % width], axis=-1) + origin
        return unique_coords, summed

    def __call__(self, coords, feats):
        """Splat ``feats`` at fractional ``coords`` onto integer pixels.

        Parameters
        ----------
        coords : ndarray, shape (N, 2)
            Point positions; column 0 is y and column 1 is x.
        feats : ndarray, shape (N,) or (N, 1)
            One feature value per point.

        Returns
        -------
        unique_coords : ndarray of int, shape (M, 2)
            Integer (y, x) pixels that received a contribution (and passed the
            threshold when it was applied).
        summed_feats : ndarray, shape (M, 1)
            Summed feature of each returned pixel.
        """
        feats = np.squeeze(feats)  # Remove single-dimensional entries from shape

        # Floor and ceil coordinates for each point
        x0 = np.floor(coords[:, 1]).astype(int)
        y0 = np.floor(coords[:, 0]).astype(int)
        x1, y1 = x0 + 1, y0 + 1

        # Weights for bilinear interpolation
        wx1 = coords[:, 1] - x0
        wx0 = 1 - wx1
        wy1 = coords[:, 0] - y0
        wy0 = 1 - wy1

        # The four corners of every point, and the share of the feature each one receives
        # (the two lists must stay in the same corner order)
        coords_combined = np.vstack([
            np.stack([y0, x0], axis=-1),
            np.stack([y0, x1], axis=-1),
            np.stack([y1, x0], axis=-1),
            np.stack([y1, x1], axis=-1),
        ])
        features_combined = np.concatenate([
            feats * (wx0 * wy0),
            feats * (wx1 * wy0),
            feats * (wx0 * wy1),
            feats * (wx1 * wy1),
        ])

        # Consolidate features at unique coordinates
        unique_coords, summed_feats = self._sum_duplicates(coords_combined, features_combined)

        ## Get the threshold (drawn on every call, even if it ends up unused)
        threshold = np.random.uniform(self.threshold_min, self.threshold_max)

        if np.random.rand() < self.p:
            # Keep only the pixels at or above the threshold
            mask = summed_feats >= threshold
            unique_coords = unique_coords[mask]
            summed_feats = summed_feats[mask]

        # Reshape summed_feats to (M, 1)
        summed_feats = summed_feats.reshape(-1, 1)

        return unique_coords, summed_feats

In [None]:
import torchvision.transforms.v2 as transforms
from ME_dataset_libs import get_transform, DoNothing, FirstRegionCrop

# Force reload so I can play with changes outside jupyter...
import importlib
import ME_dataset_libs
importlib.reload(ME_dataset_libs)
from ME_dataset_libs import MaxRegionCrop, RandomGridDistortion2D, RandomShear2D, RandomRotation2D, RandomHorizontalFlip, \
    RandomBlockZeroImproved, RandomScaleCharge, RandomJitterCharge, ConstantCharge, DoNothing, SemiRandomCrop, ConstantCharge, \
    RandomPixelNoise2D, BilinearSplat, RandomStretch2D, RandomVerticalFlip, RandomInPlaceHorizontalFlip, RandomInPlaceVerticalFlip, \
    SimpleCrop, RandomDropout, JitterCoords, UnlogCharge, RelogCharge, GridJitter, SplitJitterCoords, BilinearSplatMod

## Original image size
x_orig=256
y_orig=800

## 768 is chosen to fit with the current encoder architecture, but for these
## augmentation tests the output is kept at the original image size
x_max = x_orig
y_max = y_orig

## Augmentation chain under test
mod_transform = transforms.Compose([
            RandomBlockZeroImproved([50,100], [5,10], [0,x_orig], [0,y_orig]), ## Does very little
            RandomBlockZeroImproved([500,2000], [1,3], [0,x_orig], [0,y_orig]), ## Looks good
            RandomInPlaceHorizontalFlip(), ## Good
            RandomInPlaceVerticalFlip(), ## Good
            RandomHorizontalFlip(x_max=x_orig), ## Good
            RandomVerticalFlip(y_max=y_orig), ## Good
            RandomPixelNoise2D(10),
            UnlogCharge(),
            GridJitter(),
            SplitJitterCoords(10),
            RandomShear2D(0.1, 0.1), ## Good
            RandomRotation2D(6), ## Good
            RandomStretch2D(0.1, 0.1),
            RandomGridDistortion2D(100, 5, 2, 25), ## Good
            RandomScaleCharge(0.05),
            RandomJitterCharge(0.05),
            BilinearSplatMod(0.3, 0.5, 0.5),
            RelogCharge(),
            RandomScaleCharge(0.02),
            RandomJitterCharge(0.02),
            SemiRandomCrop(x_max, y_max, 20),
            ])

## Use the chain above; swap in the named preset below to compare against it
aug_transform = mod_transform
# aug_transform = get_transform('fsd', "vbigaugbilinfixnostretch")

## Load some images into a data loader
sim_dir = "/pscratch/sd/c/cwilk/FSD/SIMULATIONv2"
data_dir = "/pscratch/sd/c/cwilk/FSD/DATA"
nom_transform = transforms.Compose([
            FirstRegionCrop((800, 256), (768, 256)),
            # ConstantCharge(),
            ])

sim_dataset = SingleModuleImage2D_MultiHDF5_ME(sim_dir, nom_transform=nom_transform, aug_transform=aug_transform, max_events=100000)
data_dataset = SingleModuleImage2D_MultiHDF5_ME(data_dir, nom_transform=nom_transform, aug_transform=aug_transform, max_events=100000)
print("Found", len(data_dataset), "data events")

In [None]:
def set_ops(A, B):
    """Compare two 2D arrays row by row.

    Parameters
    ----------
    A, B : array-like, shape (N, k) and (M, k)
        Arrays whose rows are the items being compared (e.g. pixel coordinates).

    Returns
    -------
    both : ndarray
        Rows of ``A`` that also appear in ``B``.
    only_A : ndarray
        Rows of ``A`` that do not appear in ``B``.
    only_B : ndarray
        Rows of ``B`` that do not appear in ``A``.

    Row order and duplicates of the inputs are preserved, and empty results keep
    the column dimension (shape ``(0, k)``).
    """
    A = np.asarray(A)
    B = np.asarray(B)

    ## Hash every row once instead of scanning the other array for each row
    rows_A = [tuple(r) for r in A.tolist()]
    rows_B = [tuple(r) for r in B.tolist()]
    lookup_A = set(rows_A)
    lookup_B = set(rows_B)

    A_in_B = np.array([r in lookup_B for r in rows_A], dtype=bool)
    B_in_A = np.array([r in lookup_A for r in rows_B], dtype=bool)

    both = A[A_in_B]
    only_A = A[~A_in_B]
    only_B = B[~B_in_A]

    return both, only_A, only_B

In [None]:
import numpy as np
import time

def profile_transforms_avg(dataset, transform_block, n_events=10):
    """Average wall-clock time of each transform in a composed chain.

    The last two elements of every dataset item (coords, feats) are passed
    through ``transform_block.transforms`` in order, each transform receiving
    the output of the previous one, and every call is timed separately.

    Parameters
    ----------
    dataset : indexable with ``len()``
        Items must unpack into six values; only the last two are used.
    transform_block : object with a ``transforms`` list
        Each entry is called as ``coords, feats = t(coords, feats)``.
    n_events : int, optional
        Number of events to average over; capped at ``len(dataset)``.

    Returns
    -------
    dict
        Mean time in seconds keyed by transform class name. Transforms of the
        same class that appear more than once in the chain share one entry.
        Values are NaN when no event was processed.
    """
    times_accum = {t.__class__.__name__: [] for t in transform_block.transforms}

    ## Never index past the end of the dataset
    n_events = min(n_events, len(dataset))

    for idx in range(n_events):
        _, _, _, _, coords, feats = dataset[idx]
        for t in transform_block.transforms:
            start = time.perf_counter()
            coords, feats = t(coords, feats)
            end = time.perf_counter()
            times_accum[t.__class__.__name__].append(end - start)

    ## np.mean of an empty list would warn, so give NaN explicitly in that case
    times_avg = {k: (np.mean(v) if v else np.nan) for k, v in times_accum.items()}
    return times_avg

In [None]:
## BilinearSplat starts at 8.911 ms averaged over 1000 events
## Time every transform of mod_transform on (up to) 1000 data events
times = profile_transforms_avg(data_dataset, mod_transform, n_events=1000)
## Slowest transform first
times_sorted = sorted(times.items(), key=lambda x: x[1], reverse=True)

## profile_transforms_avg returns seconds; report milliseconds
for name, t in times_sorted:
    print(f"{name:30s}: {t*1000:.3f} ms") 

In [None]:
def simple_test(dataset, n=0):
    """Print a quick summary of event ``n``: both augmented views and the nominal one.

    For each view the first coordinate/feature entry and the array shapes are
    printed, followed by the number of coordinates shared between the nominal
    view and the first augmentation, only in the nominal view, and only in the
    augmentation (as computed by ``set_ops``).
    """
    first_coords, first_feats, second_coords, second_feats, nominal_coords, nominal_feats = dataset[n]

    ## First entry of every view
    print("AUG1 #1:", first_coords[0], first_feats[0])
    print("AUG2 #1:", second_coords[0], second_feats[0])
    print("ORIG #1:", nominal_coords[0], nominal_feats[0])

    ## Number of entries in every view
    print("AUG1 SHAPE:", first_coords.shape, first_feats.shape)
    print("AUG2 SHAPE:", second_coords.shape, second_feats.shape)
    print("ORIG SHAPE:", nominal_coords.shape, nominal_feats.shape)

    ## Bilinear splat should really have a threshold of ~0.5 to be close to the data distribution...
    ## Overlap between the nominal coordinates and the first augmentation
    shared, nominal_only, first_only = set_ops(nominal_coords, first_coords)
    print(len(shared), len(nominal_only), len(first_only))
    

In [None]:
## Print the summary for a single simulated event
simple_test(sim_dataset, 69861)

In [None]:
## Visualise data
from matplotlib import cm
from mpl_toolkits.axes_grid1 import make_axes_locatable

def make_aug_comp_plot(dataset, ids=(0,), save_name=None):
    """Plot the nominal image next to four augmented versions, one row per event.

    Parameters
    ----------
    dataset : indexable
        ``dataset[i]`` must unpack into (aug1 coords, aug1 feats, aug2 coords,
        aug2 feats, nominal coords, nominal feats).
    ids : sequence of int, optional
        Events to draw; each one fills a row of five panels.
    save_name : str, optional
        When given, the figure is also written to this path.

    Notes
    -----
    The dense image size is taken from the notebook-level ``y_max`` and ``x_max``.
    The smallest non-zero nominal value and each panel's mean non-zero value
    are printed as the panels are drawn.
    """
    cmap = cm.turbo.copy()
    cmap.set_under("#F0F0F0")

    n_augs = 5
    n_ids = len(ids)

    ## Set up the canvas
    plt.figure(figsize=(n_augs*2.1, n_ids*6))

    ## To keep track of the subplot
    running_n = 0

    ## Loop over images
    for i in ids:

        ## The dataset works with pairs, so fetch the event twice to get four augmentations
        aug1_bcoords, aug1_bfeats, aug2_bcoords, aug2_bfeats, orig_bcoords, orig_bfeats = dataset[i]
        aug3_bcoords, aug3_bfeats, aug4_bcoords, aug4_bfeats, _, _ = dataset[i]

        sparse_views = [(orig_bcoords, orig_bfeats),
                        (aug1_bcoords, aug1_bfeats),
                        (aug2_bcoords, aug2_bfeats),
                        (aug3_bcoords, aug3_bfeats),
                        (aug4_bcoords, aug4_bfeats)]
        augs = [make_dense_array(c, f.squeeze(), y_max, x_max) for c, f in sparse_views]

        ## Guard against an event with no non-zero pixels (np.min would raise)
        orig = augs[0]
        orig_nonzero = orig[orig != 0]
        print("orig_min =", np.min(orig_nonzero) if orig_nonzero.size else np.nan)

        ## Loop over augmentations
        for aug in augs:
            ax = plt.subplot(n_ids,n_augs,running_n+1)
            nonzero_vals = aug[aug != 0]
            mean_val = np.mean(nonzero_vals) if nonzero_vals.size else np.nan
            print(running_n, mean_val)

            ## Colour scale tops out at the 80th percentile of the positive pixels;
            ## fall back to 1.0 when the image has none (np.percentile would fail)
            positive_vals = aug[aug > 0]
            vmax = np.percentile(positive_vals, 80) if positive_vals.size else 1.0
            ax.imshow(aug, origin='lower', cmap=cmap, vmin=1e-6, vmax=vmax)
            ax.axis('off')
            running_n += 1

    plt.tight_layout()
    if save_name: plt.savefig(save_name, dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
import numpy as np
def find_biggest_events(dataset, ret_num=1):
    """Return the indices of the ``ret_num`` largest events in ``dataset``.

    The size of an event is the number of rows (``shape[0]``) of the last
    element of its dataset item. Every event is loaded once, with a progress
    line printed every 1000 events. Indices come back largest event first.
    """
    total = len(dataset)
    n_entries = np.empty(total, dtype=int)

    for event_id in range(total):
        if event_id % 1000 == 0:
            print("Processed", event_id, "/", total)
        _, _, _, _, _, nominal_feats = dataset[event_id]
        n_entries[event_id] = nominal_feats.shape[0]

    ## argsort is ascending, so flip it to put the biggest events first
    ranked = np.argsort(n_entries)[::-1]
    return ranked[:ret_num]

In [None]:
## Indices of the 50 largest data events (this loads every event in the dataset once)
big_list = find_biggest_events(data_dataset, 50)

In [None]:
## Event indices, largest event first
print(big_list)

In [None]:
## Visualise data
def make_aug_diff_plot(dataset, images=(0,), shape=(800, 256)):
    """Plot the nominal image and its difference to four augmentations, one row per event.

    Parameters
    ----------
    dataset : indexable
        ``dataset[n]`` must unpack into (aug1 coords, aug1 feats, aug2 coords,
        aug2 feats, nominal coords, nominal feats).
    images : int or sequence of int, optional
        Event index, or indices, to draw.
    shape : tuple of int, optional
        The two size arguments handed to ``make_dense_array``.

    Each row shows the nominal image followed by ``augmented - nominal`` for
    four augmentations, all on a fixed colour scale of [-1, 1].
    """
    ## Accept a bare event index as well as a sequence of them
    if isinstance(images, (int, np.integer)):
        event_ids = [images]
    else:
        event_ids = list(images)

    if not event_ids:
        return

    n_cols = 5  ## nominal + four difference images
    n_rows = len(event_ids)

    ## Fixed colour scale, symmetric around zero
    vmin = -1
    vmax = 1

    plt.figure(figsize=(n_cols*1.8, 6*n_rows))

    for row, n in enumerate(event_ids):

        ## The dataset works with pairs, so fetch the event twice to get four augmentations
        aug1_bcoords, aug1_bfeats, aug2_bcoords, aug2_bfeats, orig_bcoords, orig_bfeats = dataset[n]
        aug3_bcoords, aug3_bfeats, aug4_bcoords, aug4_bfeats, _, _ = dataset[n]

        nom_dense = make_dense_array(orig_bcoords, orig_bfeats.squeeze(), shape[0], shape[1])
        sparse_augs = [(aug1_bcoords, aug1_bfeats),
                       (aug2_bcoords, aug2_bfeats),
                       (aug3_bcoords, aug3_bfeats),
                       (aug4_bcoords, aug4_bfeats)]
        diffs = [make_dense_array(c, f.squeeze(), shape[0], shape[1]) - nom_dense
                 for c, f in sparse_augs]

        for col, panel in enumerate([nom_dense] + diffs, 1):
            ax = plt.subplot(n_rows, n_cols, row*n_cols + col)
            ax.imshow(panel, origin='lower', cmap='seismic', vmin=vmin, vmax=vmax)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    plt.tight_layout()
    plt.show()

In [None]:
## Dump a few events!
## Interesting numbers: 3, 28, 50, 75, 77, 95, 179, 237, 239, 272, 303, 347, 69861, 73664, 16498
#for n in range(400, 450): 
## Events to draw, one row of the figure each
test_list = [16498, 179, 69861]
# test_list = [3, 28, 50, 75, 77, 95, 179, 237, 239, 272, 303, 347, 69861, 73664, 16498]
## Draw the comparison figure and also save it to example_augmentations.jpg
make_aug_comp_plot(data_dataset, test_list, save_name="example_augmentations.jpg")

## Alternative: one figure per event
#for n in test_list:
    # print(n)
    # make_aug_comp_plot(data_dataset, n)
    #make_aug_diff_plot(data_dataset, n)
