In [None]:
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
# TensorBoard summary writer for this notebook, constructed with "log" as its
# log-directory argument.
writer = SummaryWriter("log")

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from torchvision import datasets, transforms

%matplotlib inline
# Notebook-wide matplotlib defaults: figure size, font size and grid lines.
mpl.rcParams.update({
    'figure.figsize': [8, 6],
    'font.size': 16,
    'axes.grid': True,
})

import random

import numpy as np
import torch

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every random number generator the notebook draws from. The augmentation
# classes defined further down use Python's `random` module, so it must be
# seeded alongside numpy and torch for runs to be reproducible.
SEED = 12345
random.seed(SEED)
np.random.seed(SEED)
# Bind the returned Generator so the cell does not echo its repr.
_ = torch.manual_seed(SEED)

In [None]:
from torch.utils.data import Dataset
import os
import time
import numpy as np
import joblib

class SingleModuleImage2D_augpair(Dataset):
    """Dataset that returns two views of every stored image.

    The images are read once from a joblib file. Each stored item must expose
    ``toarray()`` (it is converted to a dense tensor on access). When a
    transform is supplied it is called twice per item, once for each view.

    Parameters
    ----------
    infilename : path-like
        File passed to ``joblib.load``; the loaded object must support
        ``len()`` and integer indexing.
    transform : callable, optional
        Callable applied to the dense tensor to build each view. When it is
        ``None`` both views are plain copies of the dense image.
    """

    def __init__(self, infilename, transform=None):
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code -- only point this at trusted files.
        self._data = joblib.load(infilename)
        self._length = len(self._data)
        self._transform = transform

    def __len__(self):
        """Return the number of stored images."""
        return self._length

    def __getitem__(self, idx):
        """Return the pair ``(img1, img2)`` for the image at ``idx``."""

        ## Convert the raw data to a dense pytorch tensor...
        raw_data = torch.Tensor(self._data[idx].toarray())

        ## Without a transform both views are just the dense image
        if self._transform is None:
            return raw_data, raw_data.clone()

        ## Apply transforms to augment the data. Each call receives its own
        ## copy so that a transform which edits its input in place cannot
        ## leak the first view's changes into the second view.
        img1 = self._transform(raw_data.clone())
        img2 = self._transform(raw_data.clone())

        return img1, img2

def collate_pair(batch):
    """Collate ``(img1, img2)`` samples into two stacked tensors.

    Element 0 of every sample is stacked into the first returned tensor and
    element 1 into the second, both along a new leading dimension.
    """
    first_views = []
    second_views = []
    for sample in batch:
        first_views.append(sample[0])
        second_views.append(sample[1])
    return torch.stack(first_views), torch.stack(second_views)

In [None]:
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as F
import random

## Random rotation transform that works directly on image Tensors
class RandomTensorRotation:
    """Rotate an image tensor by an angle drawn uniformly from a fixed range.

    Parameters
    ----------
    min_angle, max_angle : float
        Bounds of the uniform distribution the rotation angle is drawn from.
        The angle is handed to ``F.rotate`` unchanged.
    """

    def __init__(self, min_angle, max_angle):
        self.min_angle = min_angle
        self.max_angle = max_angle

    def __call__(self, img):
        """Return a rotated copy of ``img`` with the same shape as the input."""
        # Draw from torch's RNG so the angle follows torch.manual_seed.
        angle = torch.empty(1).uniform_(self.min_angle, self.max_angle).item()
        # A leading dimension is added for F.rotate; squeeze(0) removes only
        # that dimension again, so inputs that themselves contain a size-1
        # dimension keep their original rank.
        return F.rotate(img.unsqueeze(0), angle).squeeze(0)

## A transform that randomly zeroes some number of square, grid-aligned blocks
class RandomBlockZero:
    """Zero out a random number of ``block_size`` x ``block_size`` blocks.

    Parameters
    ----------
    max_blocks : int
        Upper bound (inclusive) on the number of blocks zeroed per call; the
        actual number is drawn uniformly from ``0..max_blocks``.
    block_size : int
        Side length of each block in pixels. Blocks are placed on a grid with
        this spacing, so the same block may be picked more than once.
    """

    def __init__(self, max_blocks=5, block_size=4):
        self.max_blocks = max_blocks
        self.block_size = block_size

    def __call__(self, img):
        """Return a copy of ``img`` with the chosen blocks set to zero.

        The input tensor is left untouched. An image smaller than one block
        in either direction is returned as an unmodified copy.
        """
        this_size = self.block_size
        n_rows = img.size(0) // this_size
        n_cols = img.size(1) // this_size

        # Work on a copy: the caller may reuse the same tensor to build a
        # second, independent augmentation of the image.
        out = img.clone()
        if n_rows < 1 or n_cols < 1:
            return out

        num_blocks_removed = random.randint(0, self.max_blocks)
        for _ in range(num_blocks_removed):
            block_x = random.randint(0, n_cols - 1) * this_size
            block_y = random.randint(0, n_rows - 1) * this_size
            # Use the configured block size for the zeroed extent as well
            out[block_y:block_y + this_size, block_x:block_x + this_size] = 0
        return out

## A transform that randomly shifts the image by some number of pixels and crops it
class RandomShiftTensor:
    """Translate a 2D image by a random whole-pixel offset.

    Parameters
    ----------
    max_shift : int
        The x and y offsets are each drawn uniformly from
        ``-max_shift..max_shift`` (inclusive).
    """

    def __init__(self, max_shift=10):
        self.max_shift = max_shift

    def __call__(self, img):
        """Return a shifted copy of ``img`` with the same shape.

        Pixels shifted outside the frame are dropped and the uncovered area
        is filled with zeros. ``img`` must be two-dimensional.
        """
        height, width = img.shape

        shift_x = random.randint(-self.max_shift, self.max_shift)
        shift_y = random.randint(-self.max_shift, self.max_shift)

        # Clamp to the image size: a larger offset would make the slice ends
        # below negative, which wraps around and breaks the assignment. A
        # shift of a full width/height simply yields an empty overlap.
        shift_x = max(-width, min(width, shift_x))
        shift_y = max(-height, min(height, shift_y))

        new_img = torch.zeros_like(img)

        # Region of the source image that stays inside the frame
        src_x1 = max(0, -shift_x)
        src_y1 = max(0, -shift_y)
        src_x2 = min(width, width - shift_x)
        src_y2 = min(height, height - shift_y)

        # Where that region lands in the output
        tgt_x1 = max(0, shift_x)
        tgt_y1 = max(0, shift_y)
        tgt_x2 = tgt_x1 + (src_x2 - src_x1)
        tgt_y2 = tgt_y1 + (src_y2 - src_y1)

        new_img[tgt_y1:tgt_y2, tgt_x1:tgt_x2] = img[src_y1:src_y2, src_x1:src_x2]

        return new_img

    
# Augmentation pipeline, in application order:
# horizontal flip -> block dropout -> small rotation -> pixel shift.
_AUGMENTATION_STEPS = [
    transforms.RandomHorizontalFlip(),
    RandomBlockZero(),
    RandomTensorRotation(-10, 10),
    RandomShiftTensor(),
]
aug_transform = transforms.Compose(_AUGMENTATION_STEPS)

## Load a concrete dataset
# inFile = "/global/cfs/cdirs/dune/users/cwilk/single_module_images/sparse_joblib_fixdupes_pluscuts_noneg_transform/training_images_200k.joblib"
inFile = "/global/cfs/cdirs/dune/users/cwilk/single_module_images/sparse_joblib_fixdupes_pluscuts_noneg_transform/packet_2022_02_11_11_39_26_CET_0cd913fb_20220211_113926.data.module1_flow_images.joblib"
# Time with a wall clock: process_time() only counts CPU time of this process,
# so time spent waiting on the file system would not show up in the figure.
start = time.perf_counter()
train_dataset = SingleModuleImage2D_augpair(inFile, transform=aug_transform)
print("Time taken to load", len(train_dataset), "images:", time.perf_counter() - start)

In [None]:
## Randomly chosen batching
# Loader settings gathered in one place; incomplete final batches are dropped.
loader_options = dict(
    batch_size=128,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    collate_fn=collate_pair,
)
train_loader = torch.utils.data.DataLoader(train_dataset, **loader_options)

In [None]:
# One full pass over the loader to find the range of pixel values across both
# augmented views.
# Wall-clock timing: the batches are produced in worker processes, whose CPU
# time process_time() in this process would not include.
start = time.perf_counter()
# Start from +/-inf so the result reflects the data only (starting both at 0
# would report a minimum of 0 even when every pixel is positive).
minVal = float('inf')
maxVal = float('-inf')
for img1_batch, img2_batch in train_loader:
    for view_batch in (img1_batch, img2_batch):
        # .item() keeps the running extrema as plain Python floats
        minVal = min(minVal, view_batch.min().item())
        maxVal = max(maxVal, view_batch.max().item())

print("Time taken to loop:", time.perf_counter() - start)
print("Found a minimum value of:", minVal)
print("Found a maximum value of:", maxVal)

In [None]:
## Visualise data
# Fetch the two augmented views of one specific entry
img1, img2 = train_dataset[11]
print(img1.size(), img2.size())

# Draw each view in its own figure, with a colour bar
for view in (img1, img2):
    gr = plt.imshow(view, origin='lower')
    plt.colorbar(gr)
    plt.show()

In [None]:
def plot_ae_outputs(encoder, decoder, n=10, data_loader=None):
    """Plot inputs, reconstructions and their difference for up to ``n`` images.

    One batch is drawn from the loader and passed through
    ``decoder(encoder(batch))``. The figure has three rows: the input images,
    the reconstructed images, and input minus reconstruction.

    Parameters
    ----------
    encoder, decoder : callable
        Both are switched to eval mode; called as ``decoder(encoder(batch))``.
    n : int
        Maximum number of images (columns) to draw. Fewer are drawn when the
        batch holds fewer than ``n`` images.
    data_loader : iterable, optional
        Source of batches. Defaults to the notebook-level ``train_loader``.
        A batch may be a tensor or a tuple/list of tensors; in the latter
        case the first element is used.
    """
    loader = train_loader if data_loader is None else data_loader

    encoder.eval()
    decoder.eval()

    ## Draw a single batch and take the images from it. Building a new
    ## iterator for every image restarts the loader each time, which always
    ## yields the first batch when shuffling is off.
    batch = next(iter(loader))
    ## The pair collate function yields (img1_batch, img2_batch); use view 1
    if isinstance(batch, (tuple, list)):
        batch = batch[0]

    with torch.no_grad():
        batch = batch.to(device)
        rec_batch = decoder(encoder(batch))

    n_shown = min(n, batch.shape[0])

    plt.figure(figsize=(12,6))

    ## Loop over figures
    for i in range(n_shown):
        this_input  = batch[i].cpu().numpy().squeeze()
        this_output = rec_batch[i].cpu().numpy().squeeze()

        ## One (image, title) entry per row of the figure
        rows = [
            (this_input, 'Original images'),
            (this_output, 'Reconstructed images'),
            (this_input-this_output, 'Input - reco images'),
        ]
        for row, (panel, title) in enumerate(rows):
            ax = plt.subplot(3, n_shown, i + 1 + row*n_shown)
            plt.imshow(panel, cmap='viridis', origin='lower')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            ## Put the row title on the middle column only
            if i == n_shown//2: ax.set_title(title)

    plt.show()

In [None]:
def plot_distribution_from_dataloader(data_loader, encoder, decoder):
    """Compare the value distributions of inputs and reconstructions.

    Every batch of ``data_loader`` is passed through
    ``decoder(encoder(batch))``; input and output values are histogrammed
    into 50 linear bins spanning 0.1 to 3.1 and both histograms are drawn on
    one figure. Values outside that range are not counted.

    Parameters
    ----------
    data_loader : iterable
        Source of batches. A batch may be a tensor or a tuple/list of
        tensors; in the latter case the first element is used.
    encoder, decoder : callable
        Both are switched to eval mode before the loop.
    """
    encoder.eval()
    decoder.eval()

    # Running histogram counts, accumulated batch by batch
    num_bins = 50
    input_hist = np.zeros(num_bins, dtype=int)
    output_hist = np.zeros(num_bins, dtype=int)

    # Linearly spaced bin edges (log-spaced alternative kept for reference)
    # bins = np.logspace(0, np.log10(1000), num=num_bins+1)
    bins = np.linspace(0.1, 3.1, num=num_bins+1)
    with torch.no_grad():
        for image_batch in data_loader:
            # A pair-yielding loader gives (img1_batch, img2_batch); use view 1
            if isinstance(image_batch, (tuple, list)):
                image_batch = image_batch[0]

            image_batch = image_batch.to(device)
            # Encode then decode
            decoded_batch = decoder(encoder(image_batch))

            # np.histogram flattens its input, so no explicit reshape is needed
            input_hist += np.histogram(image_batch.cpu().numpy(), bins=bins)[0]
            output_hist += np.histogram(decoded_batch.cpu().numpy(), bins=bins)[0]

    # Draw each count at the centre of its bin rather than at its left edge
    bin_centres = 0.5 * (bins[:-1] + bins[1:])

    # Plot distribution of input and output values
    plt.figure(figsize=(10, 5))
    plt.plot(bin_centres, input_hist, color='blue', label='Input')
    plt.plot(bin_centres, output_hist, color='orange', label='Output')
    plt.title('Distribution of Input and Output Values')
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    # plt.xscale('log')
    plt.legend()
    plt.show()