In [None]:
from torch.utils.tensorboard import SummaryWriter
import MinkowskiEngine as ME
from torch import nn
import torch
import numpy as np

import matplotlib.pyplot as plt
import matplotlib as mpl

%matplotlib inline
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

## Tell pytorch we have a GPU if we do
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.device(device)

SEED=12345
_=np.random.seed(SEED)
_=torch.manual_seed(SEED)
writer = SummaryWriter("log")

In [None]:
## Use the common dataset loader
from ME_dataset_libs import SingleModuleImage2D_MultiHDF5_ME, triple_ME_collate_fn
from ME_dataset_libs import make_dense, make_dense_from_tensor, make_dense_array

In [None]:
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F
import random
import time

## Import all of the pre-defined transforms
from ME_dataset_libs import CenterCrop, MaxNonZeroCrop, MaxRegionCrop, \
    RandomCrop, RandomHorizontalFlip, RandomRotation2D, RandomShear2D, \
    RandomBlockZero, RandomBlockZeroImproved, RandomJitterCharge, \
    RandomScaleCharge, RandomElasticDistortion2D, RandomGridDistortion2D, \
    BilinearInterpolation

## Chain of random augmentations (all defined in ME_dataset_libs); Compose applies them
## sequentially in the order listed below.
## NOTE(review): the meaning of each positional argument lives in ME_dataset_libs — not restated here.
aug_transform = transforms.Compose(
    [
        RandomGridDistortion2D(5, 5),
        RandomShear2D(0.1, 0.1),
        RandomHorizontalFlip(),
        RandomRotation2D(-10, 10),
        RandomBlockZeroImproved([0, 10], [5, 10]),
        RandomScaleCharge(0.02),
        RandomJitterCharge(0.02),
        RandomCrop(),
    ]
)


## Build the training dataset (no DataLoader is created in this cell).
## NOTE(review): the class name suggests it reads several HDF5 files under inDir — confirm in ME_dataset_libs.
inDir = "/pscratch/sd/c/cwilk/h5_inputs_v9/"
## Use wall-clock time: process_time() only counts this process's CPU time and so
## ignores time spent waiting on file I/O, under-reporting how long loading takes.
start = time.perf_counter()
train_dataset = SingleModuleImage2D_MultiHDF5_ME(inDir, nom_transform=MaxRegionCrop(), aug_transform=aug_transform)
print("Time taken to load", len(train_dataset), "images:", time.perf_counter() - start)

In [None]:
## Visualise data
def make_aug_comp_plot(n=0, dataset=None):
    """Plot one dataset entry next to four augmented views of it.

    Draws a single row of five panels: the original (nominal) image followed by
    four augmented versions, each densified with ``make_dense_array``.

    Args:
        n: index of the entry to plot.
        dataset: dataset to draw from; defaults to the global ``train_dataset``.
            Indexing it must return the 6-tuple
            ``(aug1_coords, aug1_feats, aug2_coords, aug2_feats, orig_coords, orig_feats)``.
    """
    if dataset is None:
        dataset = train_dataset

    ## Each item yields a pair of augmented views plus the original; indexing the same
    ## entry a second time (relying on the augmentations being re-drawn on each access)
    ## gives two more views, for four in total.
    aug1_bcoords, aug1_bfeats, aug2_bcoords, aug2_bfeats, orig_bcoords, orig_bfeats = dataset[n]
    aug3_bcoords, aug3_bfeats, aug4_bcoords, aug4_bfeats, _, _ = dataset[n]

    panels = [
        ("Original", orig_bcoords, orig_bfeats),
        ("Aug. 1", aug1_bcoords, aug1_bfeats),
        ("Aug. 2", aug2_bcoords, aug2_bfeats),
        ("Aug. 3", aug3_bcoords, aug3_bfeats),
        ("Aug. 4", aug4_bcoords, aug4_bfeats),
    ]

    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (title, coords, feats) in zip(axes, panels):
        ax.imshow(make_dense_array(coords, feats.squeeze()), origin='lower')
        ax.set_title(title)
        ## Pixel indices carry no information here, so hide both axes
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    ## Render immediately; with the inline backend this also closes the figure, so
    ## calling this in a loop does not pile up open figures.
    plt.show()


In [None]:
## Dump a few events! Each row shows the original image followed by four augmented views.
N_EXAMPLES = 20  # number of events to visualise
## Clamp to the dataset size so a small input directory doesn't raise IndexError
for n in range(min(N_EXAMPLES, len(train_dataset))):
    make_aug_comp_plot(n)