In [None]:
from torch.utils.tensorboard import SummaryWriter
import MinkowskiEngine as ME
from torch import nn
import torch
import numpy as np

import matplotlib.pyplot as plt
import matplotlib as mpl

%matplotlib inline
mpl.rcParams['figure.figsize'] = [8, 6]
mpl.rcParams['font.size'] = 16
mpl.rcParams['axes.grid'] = True

## Tell pytorch we have a GPU if we do
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Seed every random number generator the augmentations draw from
## (python's `random`, numpy and torch) so that a fresh run is reproducible
import random

SEED=12345
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## TensorBoard event files are written to ./log
writer = SummaryWriter("log")

In [None]:
## Use the common dataset loader
from ME_dataset_libs import SingleModuleImage2D_MultiHDF5_ME, triple_ME_collate_fn
from ME_dataset_libs import make_dense, make_dense_from_tensor

In [None]:
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F
import random

## This is a transformation for the nominal image
class CenterCrop:
    """Keep the central 256x128 window of a 280x140 sparse image.

    Coordinates are shifted so that the window's first row/column becomes 0,
    and every hit that falls outside the window is dropped.
    Column 0 of `coords` is the y (row) axis, column 1 the x (column) axis.
    """
    def __init__(self):
        self.orig_y = 280
        self.orig_x = 140
        self.new_y = 256
        self.new_x = 128
        # True division is kept on purpose: the returned coordinates are
        # floats, exactly as before
        self.pad_y = (self.orig_y - self.new_y)/2
        self.pad_x = (self.orig_x - self.new_x)/2
        
    def __call__(self, coords, feats):
        """Return the (coords, feats) rows that lie inside the central window.

        coords: array of shape (N, 2) holding (y, x) per hit.
        feats: array with N rows, one per hit.
        """
        coords = coords - np.array([self.pad_y, self.pad_x])

        # The window is the half-open range [0, new) on both axes, so a hit
        # that lands exactly on row/column 0 has to be kept
        mask = (coords[:,0] >= 0) & (coords[:,0] < (self.new_y)) \
             & (coords[:,1] >= 0) & (coords[:,1] < (self.new_x))
        
        return coords[mask], feats[mask]

    
## This just takes a 256x128 subimage from the original 280x140 block
class RandomCrop:
    """Cut a randomly placed 256x128 window out of a 280x140 sparse image.

    Column 0 of `coords` is the y (row) axis, column 1 the x (column) axis.
    """
    def __init__(self):
        self.orig_y = 280
        self.orig_x = 140
        self.new_y = 256
        self.new_x = 128       

    def __call__(self, coords, feats):
        """Return new (coords, feats) arrays restricted to a random window.

        The inputs are left untouched; the returned coordinates are relative
        to the window origin.
        """
        # Offsets of the window origin; randint is inclusive on both ends
        shift_y = random.randint(0, self.orig_y - self.new_y)
        shift_x = random.randint(0, self.orig_x - self.new_x)
        
        # The shift has to follow the (y, x) column order of coords. The
        # subtraction allocates a new array, so no explicit copy is needed
        new_coords = coords - np.array([shift_y, shift_x])

        # Half-open window [0, new) on both axes: row/column 0 is part of it
        mask = (new_coords[:,0] >= 0) & (new_coords[:,0] < (self.new_y)) \
             & (new_coords[:,1] >= 0) & (new_coords[:,1] < (self.new_x))
        
        # Boolean indexing returns copies
        return new_coords[mask], feats[mask]
    
    
class RandomHorizontalFlip:
    """Mirror the second coordinate column with probability `p`.

    p: probability of applying the flip.
    ncols: width, in pixels, of the frame the coordinates live in when the
        transform is called; column c is mapped to ncols - 1 - c.
    """
    def __init__(self, p=0.5, ncols=128):
        self.p = p
        self.ncols = ncols
        
    def __call__(self, coords, feats):
        """Return (possibly flipped copy of coords, feats unchanged)."""
        
        ## Need to copy the array so the caller's coordinates are not modified
        new_coords = coords.copy()
        
        if torch.rand(1) < self.p:
            new_coords[:,1] = self.ncols - 1 - new_coords[:,1]
        return new_coords,feats
    
    
## Need to define some fairly standard functions that work for ME tensors
class RandomRotation2D:
    """Rotate 2D coordinates by a random angle.

    min_angle, max_angle: bounds, in degrees, of the uniformly drawn angle.
    center: optional pivot with one value per coordinate column. With the
        default (None) the rotation is about the coordinate origin, so hits
        far from (0, 0) are also displaced noticeably.
    """
    def __init__(self, min_angle, max_angle, center=None):
        self.min_angle = min_angle
        self.max_angle = max_angle
        self.center = center

    def _M(self, theta):
        """Generate a 2D rotation matrix for a given angle theta (radians)."""
        return np.array([
            [np.cos(theta), -np.sin(theta)],
            [np.sin(theta),  np.cos(theta)]
        ])

    def __call__(self, coords, feats):
        """Apply a random rotation to 2D coordinates and return the rotated coordinates with features."""
        # Generate a random rotation angle (drawn in degrees, converted to radians)
        angle = np.deg2rad(torch.FloatTensor(1).uniform_(self.min_angle, self.max_angle).item())

        # Get the 2D rotation matrix
        R = self._M(angle)

        # Apply the rotation; coords are row vectors, hence the right-multiplication
        if self.center is None:
            return coords @ R, feats

        # Rotate about the pivot: move it to the origin, rotate, move it back
        pivot = np.asarray(self.center, dtype=float)
        return (coords - pivot) @ R + pivot, feats
    
class RandomShear2D:
    """Apply a random shear to 2D coordinates.

    max_shear_x, max_shear_y: the two off-diagonal shear factors are drawn
        uniformly from [-max_shear_x, max_shear_x] and
        [-max_shear_y, max_shear_y] respectively.
    center: optional pivot with one value per coordinate column. With the
        default (None) the shear is about the coordinate origin, so the
        displacement grows with the distance from (0, 0).
    """
    def __init__(self, max_shear_x, max_shear_y, center=None):
        self.max_shear_x = max_shear_x
        self.max_shear_y = max_shear_y
        self.center = center

    def __call__(self, coords, feats):
        """Apply a random shear to 2D coordinates and return the sheared coordinates with features."""
        # Generate the two random shear factors
        shear_x = np.random.uniform(-self.max_shear_x, self.max_shear_x)
        shear_y = np.random.uniform(-self.max_shear_y, self.max_shear_y)

        # coords are row vectors, so this adds shear_y * column 1 to column 0
        # and shear_x * column 0 to column 1
        shear_matrix = np.array([
            [1, shear_x],
            [shear_y, 1]
        ])
        
        if self.center is None:
            return coords @ shear_matrix, feats

        # Shear about the pivot: move it to the origin, shear, move it back
        pivot = np.asarray(self.center, dtype=float)
        return (coords - pivot) @ shear_matrix + pivot, feats
    
    
## Randomly remove up to max_blocks square blocks of side block_size from the image
## This has to be called before the cropping as it uses the original image size
class RandomBlockZero:
    """Drop every hit that falls inside a few randomly placed square blocks.

    max_blocks: the number of blocks per call is drawn uniformly from
        0..max_blocks (inclusive).
    block_size: side length of each block, in pixels.

    Blocks are placed inside the uncropped 280x140 frame.
    Column 0 of `coords` is the y (row) axis, column 1 the x (column) axis.
    """
    def __init__(self, max_blocks=4, block_size=6):
        self.max_blocks = max_blocks
        self.block_size = block_size
        self.xmax = 140
        self.ymax = 280

    def __call__(self,  coords, feats):
        """Return new (coords, feats) arrays without the hits inside the blocks."""

        # One flag per hit. Sized from coords rather than feats.size so that
        # features with more than one channel per hit are handled too
        combined_mask = np.ones(len(coords), dtype=bool)
        
        num_blocks_removed = random.randint(0, self.max_blocks)
        for _ in range(num_blocks_removed):
            this_size = self.block_size
            # NOTE: randint is inclusive, so the largest start is
            # max - size - 1 and the very last row/column is never covered
            block_x = random.randint(0, self.xmax - this_size - 1)
            block_y = random.randint(0, self.ymax - this_size - 1)
            
            # Half-open block [start, start + size): on integer coordinates
            # this removes exactly this_size pixels along each axis
            in_block = (coords[:,0] >= block_y) & (coords[:,0] < (block_y+this_size)) \
                     & (coords[:,1] >= block_x) & (coords[:,1] < (block_x+this_size))
            combined_mask &= ~in_block
            
        ## Boolean indexing returns copies, so the inputs are not modified
        return coords[combined_mask], feats[combined_mask]

## Apply a Gaussian jitter to all values
class RandomJitterCharge:
    """Multiply every feature value by its own Gaussian factor.

    width: standard deviation of the factors, which are centred on 1.
    """
    def __init__(self, width=0.1):
        self.width = width

    def __call__(self,  coords, feats):
        """Return (coords unchanged, independently rescaled copy of feats)."""
        # One independent gain per feature entry
        per_hit_gain = np.random.normal(1.0, self.width, feats.shape)
        return coords, feats*per_hit_gain
    
# Scale the entire feature vector by a single scaling factor
class RandomScaleCharge:
    """Multiply the whole feature array by one shared Gaussian factor.

    width: standard deviation of the factor, which is centred on 1.
    """
    def __init__(self, width=0.1):
        self.width = width

    def __call__(self,  coords, feats):
        """Return (coords unchanged, globally rescaled copy of feats)."""
        # A single gain shared by every hit of this image
        global_gain = np.random.normal(1.0, self.width)
        return coords, feats*global_gain
    
    
from scipy.ndimage import gaussian_filter, map_coordinates

class RandomElasticDistortion2D:
    """Displace sparse 2D coordinates with a smooth random displacement field.

    alpha_range: a float for a fixed value, or [lower, upper] for a value drawn
        from a uniform distribution. Controls the intensity of the deformation.
    sigma: sigma of the gaussian filter that smooths the displacement fields.

    The fields are generated on the uncropped 280x140 grid.
    Column 0 of `coords` is the y (row) axis, column 1 the x (column) axis.
    """
    def __init__(self, alpha_range, sigma):
        self.alpha_range = alpha_range
        self.sigma = sigma
        self.height = 280
        self.width = 140
        
    def __call__(self, coords, feats):
        """Return (distorted coords, feats unchanged).

        coords: array of shape (N, 2) holding (y, x) per hit.
        feats: per-hit features, passed through untouched.
        """
        # Accept both the scalar and the [lower, upper] form; a scalar still
        # consumes one draw so the random stream does not depend on the form
        if np.ndim(self.alpha_range) == 0:
            alpha_low = alpha_high = self.alpha_range
        else:
            alpha_low, alpha_high = self.alpha_range[0], self.alpha_range[1]
        alpha = np.random.uniform(low=alpha_low, high=alpha_high)

        # Create random displacement fields, indexed as [row (y), column (x)]
        displacement_shape = (self.height, self.width)
        dx = gaussian_filter((np.random.rand(*displacement_shape) * 2 - 1), self.sigma) * alpha
        dy = gaussian_filter((np.random.rand(*displacement_shape) * 2 - 1), self.sigma) * alpha

        # Normalize coords to the grid size: y is scaled by the height and
        # x by the width, so lookups stay within the field on both axes
        norm_y = coords[:, 0] / self.height * (displacement_shape[0] - 1)
        norm_x = coords[:, 1] / self.width * (displacement_shape[1] - 1)

        # Interpolate displacement fields at coordinate positions
        sample_at = [norm_y, norm_x]
        distorted_y = norm_y + map_coordinates(dy, sample_at, order=1, mode='reflect')
        distorted_x = norm_x + map_coordinates(dx, sample_at, order=1, mode='reflect')

        # Denormalize back to original coordinate scale, keeping the (y, x) column order
        new_coords = np.stack((distorted_y * self.height / (displacement_shape[0] - 1),
                               distorted_x * self.width / (displacement_shape[1] - 1)), axis=-1)
        return new_coords, feats

    
## Apply distortions in a regular grid, with random strength at each point up to some maximum, smoothed by some amount
class RandomGridDistortion2D:
    """Warp sparse 2D coordinates with a coarse, randomly perturbed control grid.

    A grid_size x grid_size lattice of control points spans the uncropped
    280x140 frame. Each control point is moved by a uniform random offset and
    every hit is mapped to the bilinear interpolation of the moved lattice.
    Column 0 of `coords` is the y (row) axis, column 1 the x (column) axis.
    """
    def __init__(self, grid_size, distortion_strength, sigma=1):
        """
        :param grid_size: Number of control points along each axis.
        :param distortion_strength: Maximum displacement of a control point.
        :param sigma: Stored as `smoothing_sigma`; not used by `__call__`.
        """
        self.grid_size = grid_size
        self.distortion_strength = distortion_strength
        self.smoothing_sigma = sigma
        self.height = 280
        self.width = 140

    def __call__(self, coords, feats):
        """Return (warped coords, feats unchanged)."""
        last_node = self.grid_size - 1
        strength = self.distortion_strength

        # Undistorted positions of the control points, in pixels
        node_x, node_y = np.meshgrid(
            np.linspace(0, self.width, self.grid_size),
            np.linspace(0, self.height, self.grid_size)
        )

        # Random offset per control point (x offsets are drawn first, then y)
        jitter_x = np.random.uniform(-strength, strength, node_x.shape)
        jitter_y = np.random.uniform(-strength, strength, node_y.shape)

        # Moved control points, expressed in grid-index units
        warped_x = (node_x + jitter_x)/ self.width * last_node
        warped_y = (node_y + jitter_y)/ self.height * last_node

        # Position of every hit in grid-index units, ordered [row, column]
        sample_at = [coords[:, 0] / self.height * last_node,
                     coords[:, 1] / self.width * last_node]

        def _lookup(field):
            # Bilinear interpolation of a control-point map at the hit positions
            return map_coordinates(field, sample_at, order=1, mode='reflect')

        # Back to pixel units, keeping the (y, x) column order
        new_y = _lookup(warped_y) * self.height / last_node
        new_x = _lookup(warped_x) * self.width / last_node

        return np.stack((new_y, new_x), axis=-1), feats
    
class BilinearInterpolation:
    """Re-bin hits with fractional coordinates onto the integer pixel grid.

    Each hit is shared between the (up to) four pixels surrounding it using
    bilinear weights, contributions landing on the same pixel are summed, and
    pixels whose total is below `threshold` are discarded.
    The first coordinate column is bounded by the height (280) and the second
    by the width (140).
    """
    def __init__(self, threshold):
        self.height=280
        self.width=140
        self.threshold=threshold
        
    def __call__(self, coords, feats):
        """
        Apply bilinear interpolation to sparse image data represented by coordinates and features.
    
        Arguments:
        coords: Numpy array of shape (N, 2) with (possibly fractional) coordinates.
        feats: Numpy array of shape (N,) or (N, 1), one feature value per coordinate.

        Returns:
        interpolated_coords: Numpy array of shape (M, 2), with unique integer coordinates.
        interpolated_feats: Numpy array of shape (M, 1), containing the summed feature values.
        """
        coords = np.asarray(coords, dtype=float)
        # Flatten to (N,); unlike np.squeeze this also works for a single hit
        feats = np.asarray(feats).reshape(-1)
        
        # Lower and upper integer neighbours along both axes
        lower = np.floor(coords).astype(int)
        upper = np.ceil(coords).astype(int)
    
        # Bilinear weights; for an exactly integer coordinate lower == upper,
        # the "upper" weight is 0 and the "lower" weight is 1
        w_upper = coords - lower
        w_lower = 1 - w_upper
        
        # Coordinates of the four corners, stacked as one (4N, 2) array
        coords_combined = np.vstack([
            np.stack([lower[:, 0], lower[:, 1]], axis=-1),
            np.stack([lower[:, 0], upper[:, 1]], axis=-1),
            np.stack([upper[:, 0], lower[:, 1]], axis=-1),
            np.stack([upper[:, 0], upper[:, 1]], axis=-1),
        ])

        # Share of every hit's feature assigned to each corner, in the same order
        features_combined = np.concatenate([
            feats * (w_lower[:, 0] * w_lower[:, 1]),
            feats * (w_lower[:, 0] * w_upper[:, 1]),
            feats * (w_upper[:, 0] * w_lower[:, 1]),
            feats * (w_upper[:, 0] * w_upper[:, 1]),
        ])

        # Crop (rather than clip) corners outside the image; the valid pixel
        # indices are 0..size-1 inclusive on both axes
        mask = (coords_combined[:,0] >= 0) \
             & (coords_combined[:,0] <= (self.height-1)) \
             & (coords_combined[:,1] >= 0) \
             & (coords_combined[:,1] <= (self.width-1))
        coords_combined = coords_combined[mask]
        features_combined = features_combined[mask]

        # Nothing left inside the image: return correctly shaped empty arrays
        if coords_combined.shape[0] == 0:
            return np.empty((0, 2), dtype=int), np.empty((0, 1))
    
        # Consolidate features at unique coordinates
        unique_coords, indices = np.unique(coords_combined, axis=0, return_inverse=True)
        # The inverse is not 1-D in every NumPy release when axis is given
        indices = np.reshape(indices, -1)
        summed_feats = np.zeros(len(unique_coords))    
        np.add.at(summed_feats, indices, features_combined)

        # Keep only the pixels at or above the threshold
        mask = summed_feats >= self.threshold
        unique_coords = unique_coords[mask]
        summed_feats = summed_feats[mask]
        
        # Reshape summed_feats to (M, 1)
        summed_feats = summed_feats.reshape(-1, 1)
        
        return unique_coords, summed_feats


In [None]:
import time

## The flip runs before the crop, i.e. in the uncropped 280x140 frame, so it has to
## mirror about a 140-column width rather than the 128 columns it uses by default;
## with 128, flipped images end up shifted by 12 columns relative to unflipped ones
uncropped_flip = RandomHorizontalFlip()
uncropped_flip.ncols = 140

## Augmentation chain; geometric transforms act on the uncropped frame and the
## crop to 256x128 comes last.
## NOTE(review): the shear and rotation are applied about the coordinate origin,
## not the image centre, so they also translate the image - confirm this is intended
aug_transform = transforms.Compose([
    # RandomElasticDistortion2D([0,20],5),
    RandomGridDistortion2D(5,5,1),
    RandomShear2D(0.1, 0.1),
    uncropped_flip,
    RandomRotation2D(-10,10),
    RandomBlockZero(5, 6),
    # BilinearInterpolation(0.05),
    RandomScaleCharge(0.05),
    RandomJitterCharge(0.05),
    CenterCrop()
])

#aug_transform2 = transforms.Compose([
#    RandomGridDistortion2D(5,5,1),
#    CenterCrop()
#])


## Get a concrete dataset and data loader
inDir = "/pscratch/sd/c/cwilk/h5_inputs/"
## perf_counter measures elapsed wall-clock time, which includes the time spent
## waiting on file reads (process_time only counts CPU time)
start = time.perf_counter()
train_dataset = SingleModuleImage2D_MultiHDF5_ME(inDir, nom_transform=CenterCrop(), aug_transform=aug_transform, max_events=10)
print("Time taken to load", len(train_dataset),"images:", time.perf_counter() - start)

In [None]:
## Randomly chosen batching
## One image per batch, shuffled, loaded by a single worker process that
## prefetches one batch ahead
loader_options = dict(
    collate_fn=triple_ME_collate_fn,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    drop_last=True,
    pin_memory=False,
    prefetch_factor=1,
)
train_loader = torch.utils.data.DataLoader(train_dataset, **loader_options)

In [None]:
## Visualise data
# Draw one batch from the loader: two augmented views plus the original image
aug1_bcoords, aug1_bfeats, aug2_bcoords, aug2_bfeats, orig_bcoords, orig_bfeats = next(iter(train_loader))

# One panel per view, each with its own colour bar
panels = [
    ("Augmentation 1", aug1_bcoords, aug1_bfeats),
    ("Augmentation 2", aug2_bcoords, aug2_bfeats),
    ("Original", orig_bcoords, orig_bfeats),
]

fig, axes = plt.subplots(1, 3, figsize=(15,5))
for ax, (title, bcoords, bfeats) in zip(axes, panels):
    image = ax.imshow(make_dense(bcoords, bfeats, 'cpu'), origin='lower')
    fig.colorbar(image, ax=ax)
    ax.set_title(title)
