In [None]:
import numpy as np
import pickle as pk
import math
import sys, os
import tqdm
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import Tensor
from torch.autograd import Variable
from torch.utils.data import random_split
from torch.utils.data import Dataset, DataLoader
from glob import glob
from sklearn import preprocessing

# Manually specify the GPUs to use. Both variables are written to the process
# environment before the first CUDA query made in this notebook
# (torch.cuda.is_available() in the configuration cell).
# NOTE(review): presumably they only take effect if set before CUDA is
# initialised — confirm before moving this cell.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
# Load the reference arrays from which every normalisation bound is derived.
# NOTE: pickle can execute arbitrary code on load; only open trusted files.
with open("data.p", "rb") as fd:
    charges, ini_pos, ini_P, Es, total_charges = pk.load(fd)

# Extrema of each quantity, later used as interpolation bounds
min_charge, max_charge = charges.min(), charges.max()
min_tCharge, max_tCharge = total_charges.min(), total_charges.max()
min_pos, max_pos = ini_pos.min(), ini_pos.max()
min_P, max_P = ini_P.min(), ini_P.max()
min_E, max_E = Es.min(), Es.max()
range_pos, range_P = max_pos - min_pos, max_P - min_P

# Output intervals of the normalisation: labels (position, momentum, energy)
# are mapped onto source_range, charges onto target_range.
source_range = (-1, 1)
target_range = (0, 1)

# Charge interval used for the images: lower bound pinned to zero
min_charge_new, max_charge_new = 0, charges.max()

# Proton mass in MeV
m_proton = 938.27208816

# dataset
class VAdataset(Dataset):
    """Dataset of events stored one per ``.npz`` file under ``root``.

    Every file is read for the arrays ``sparse_image``, ``initPosition``,
    ``finalPosition`` and ``initMomentum``.  Normalisation uses the
    module-level bounds ``min_charge_new``/``max_charge_new``,
    ``min_pos``/``max_pos``, ``min_P``/``max_P``, ``min_E``/``max_E``,
    ``min_tCharge``/``max_tCharge`` and the intervals ``source_range`` and
    ``target_range``, which must be defined before items are requested.
    """

    def __init__(self, root, shuffle=False, **kwargs):
        """Collect the event files.

        Args:
            root: directory that contains the ``*.npz`` event files.
            shuffle: if True, shuffle the (otherwise sorted) file list once.
            **kwargs: accepted for call compatibility; not used.
        """
        self.root = root
        self.data_files = self.processed_file_names
        if shuffle:
            # Imported locally because `random` is not among the notebook's
            # imports; without it shuffle=True raised NameError.
            import random
            random.shuffle(self.data_files)
        self.total_events = len(self.data_files)

    @property
    def processed_dir(self):
        """Directory that holds the event files (same as ``root``)."""
        return f'{self.root}'

    @property
    def processed_file_names(self):
        """Sorted list of all ``*.npz`` paths found in ``processed_dir``."""
        return sorted(glob(f'{self.processed_dir}/*.npz'))

    def __len__(self):
        """Number of event files found at construction time."""
        return self.total_events

    def __getitem__(self, idx):
        """Load and normalise one event.

        Returns:
            dict with keys ``full_image`` (flat array of 7*7*7 charges mapped
            onto ``target_range``), ``initPosition`` and ``initMomentum``
            (mapped onto ``source_range``), ``finalPosition`` (as stored) and
            ``totalCharge`` (mapped onto ``target_range``).  If the event has
            no hits, every value is ``None``.
        """
        # The context manager closes the archive as soon as the arrays are
        # read instead of relying on garbage collection.
        with np.load(self.data_files[idx]) as data:
            sparse_image = data['sparse_image']    # (N, 2): flat voxel index, charge
            initPosition = data['initPosition']    # initial position (x1, y1, z1)
            finalPosition = data['finalPosition']  # final position (xN, yN, zN)
            initMomentum = data['initMomentum']    # momentum vector (Px, Py, Pz, E)

        if sparse_image.shape[0] == 0:
            # Empty event: flagged with None values so collate_fn can drop it
            return { 'full_image': None,
                     'initPosition': None,
                     'finalPosition': None,
                     'initMomentum': None,
                     'totalCharge': None}

        # Rebuild the flattened 7x7x7 volume from the sparse (index, charge) pairs
        full_image = np.zeros(shape=(7*7*7))
        full_image[sparse_image[:,0].astype(int)] = sparse_image[:,1]
        totalCharge = full_image.sum()  # computed before normalisation

        # Normalise
        full_image = np.interp(full_image.ravel(), (min_charge_new, max_charge_new), target_range).reshape(full_image.shape)
        initPosition = np.interp(initPosition.ravel(), (min_pos, max_pos), source_range).reshape(initPosition.shape)
        # Momentum components and energy have separate bounds
        initMomentum[:3] = np.interp(initMomentum[:3].ravel(), (min_P, max_P), source_range).reshape(initMomentum[:3].shape)
        initMomentum[3] = np.interp(initMomentum[3].ravel(), (min_E, max_E), source_range).reshape(initMomentum[3].shape)
        totalCharge = np.interp(totalCharge.ravel(), (min_tCharge, max_tCharge), target_range).reshape(totalCharge.shape)

        return { 'full_image': full_image,
                 'initPosition': initPosition,
                 'finalPosition': finalPosition,
                 'initMomentum': initMomentum,
                 'totalCharge': totalCharge }

In [None]:
# Build the event dataset from the "images" directory and flatten the
# reference position / momentum arrays to one dimension.
dataset = VAdataset("images")
ini_pos, ini_P = ini_pos.reshape(-1), ini_P.reshape(-1)

In [None]:
# function to collate data samples into batch tensors
def collate_fn(batch):
    """Merge event dicts into float32 batch tensors, skipping None entries.

    Args:
        batch: list of dicts with the keys 'full_image', 'initPosition',
            'initMomentum' and 'totalCharge'; a value of None marks an
            event to be left out.

    Returns:
        (images, params, charges): the stacked images, the per-event
        concatenation of initPosition followed by initMomentum, and the
        total charges, each as a float32 tensor.
    """
    images = [sample['full_image'] for sample in batch
              if sample['full_image'] is not None]
    params = [np.concatenate([sample['initPosition'], sample['initMomentum']])
              for sample in batch if sample['initPosition'] is not None]
    totals = [sample['totalCharge'] for sample in batch
              if sample['totalCharge'] is not None]

    def as_float_tensor(values):
        # list of arrays -> one ndarray -> float32 tensor
        return torch.tensor(np.array(values)).float()

    return as_float_tensor(images), as_float_tensor(params), as_float_tensor(totals)

In [None]:
batch_size = 64

# 60 % training, 10 % validation, remainder for testing.  The explicitly
# seeded generator keeps the partition identical from run to run.
fulllen = len(dataset)
train_len = int(0.6 * fulllen)
val_len = int(0.1 * fulllen)
test_len = fulllen - (train_len + val_len)
train_set, val_set, test_set = random_split(
    dataset,
    [train_len, val_len, test_len],
    generator=torch.Generator().manual_seed(7),
)

# Only the training loader reshuffles its samples every epoch
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                          num_workers=2, collate_fn=collate_fn)
valid_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False,
                          num_workers=2, collate_fn=collate_fn)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                         num_workers=2, collate_fn=collate_fn)

In [None]:
from torch.nn import TransformerEncoder, TransformerEncoderLayer

# (7, 7, 7, 3) grid in which entry [i, j, k] holds its own index (i, j, k),
# linearly rescaled from [0, 6] onto source_range and stored as a float tensor.
vol_idx = np.stack(np.indices((7, 7, 7)), axis=-1)
vol_idx = torch.FloatTensor(
    np.interp(vol_idx.ravel(), (0, 6), source_range).reshape(vol_idx.shape)
)

from torch.nn import TransformerEncoder, TransformerEncoderLayer

class Generator(nn.Module):
    """Label-conditioned transformer generator of flattened 7x7x7 volumes.

    The 7 scalar labels are embedded as 7 tokens, expanded to 7*7 and then to
    7*7*7 tokens by two linear maps; each stage concatenates voxel coordinates
    taken from ``vol_idx`` and runs a transformer encoder.  Gaussian noise is
    injected before each expansion.

    Depends on two module-level names that must exist beforehand:
    ``dropout`` (dropout probability, read in ``__init__``) and ``vol_idx``
    (the (7, 7, 7, 3) coordinate grid, read in ``forward``).
    """

    def __init__(self, label_size, noise_size):
        """
        Args:
            label_size: stored on the instance only; the layer layout is
                fixed to 7 labels.
            noise_size: stored on the instance only; not used by any layer.
        """
        super(Generator, self).__init__()
        self.label_size = label_size
        self.noise_size = noise_size

        self.dropout = nn.Dropout(dropout)

        self.d_model1 = 32
        self.d_model2 = self.d_model1//2
        self.d_model3 = self.d_model2//2

        # One scalar -> (d_model1 - 3) embedding per label; the 3 missing
        # channels are the voxel coordinates concatenated in forward().
        self.project_px = nn.Linear(1, self.d_model1-3)
        self.project_py = nn.Linear(1, self.d_model1-3)
        self.project_pz = nn.Linear(1, self.d_model1-3)
        self.project_e = nn.Linear(1, self.d_model1-3)
        self.project_x = nn.Linear(1, self.d_model1-3)
        self.project_y = nn.Linear(1, self.d_model1-3)
        self.project_z = nn.Linear(1, self.d_model1-3)
        # Token expansion 7 -> 7*7 -> 7*7*7; input is [features, noise]
        self.map1 = nn.Linear(self.d_model1*2, 7*(self.d_model2-3))
        self.map2 = nn.Linear(self.d_model1*2, 7*(self.d_model2-3))

        # Bring [coordinates, features] to the encoder width
        self.map_enc1 = nn.Linear(self.d_model1, self.d_model1)
        self.map_enc2 = nn.Linear(self.d_model2, self.d_model1)
        self.map_enc3 = nn.Linear(self.d_model2, self.d_model1)

        # Transformer encoders with 2, 4 and 8 layers respectively
        self.encoder1 = self._build_encoder(2)
        self.encoder2 = self._build_encoder(4)
        self.encoder3 = self._build_encoder(8)

        # Final decoder: one value in (0, 1) per voxel token
        self.decoder = nn.Sequential(
                                     nn.Linear(self.d_model1, 1),
                                     nn.Sigmoid()
                                    )

    def _build_encoder(self, num_layers):
        """Return a batch-first TransformerEncoder of width ``d_model1``."""
        layer = TransformerEncoderLayer(d_model=self.d_model1,
                                        nhead=self.d_model1//8,
                                        dim_feedforward=self.d_model1*2,
                                        dropout=dropout,
                                        batch_first=True # inputs are (batch, tokens, features)
                                        )
        return TransformerEncoder(layer, num_layers)

    def forward(self, labels, noise):
        """Generate one volume per label vector.

        Args:
            labels: tensor that can be viewed as (batch, 7).
            noise: not used; noise is drawn inside this method.  Kept so
                existing call sites keep working.

        Returns:
            Tensor of shape (batch, 7*7*7, 1) with values in (0, 1).
        """
        labels = labels.view(-1, 7, 1)
        # Follow the input's device instead of a module-level `device`
        dev = labels.device

        # Each label goes through its own projection layer (previously all
        # seven were passed through project_px).
        # NOTE(review): collate_fn concatenates position before momentum, so
        # index 0 is actually x rather than px; the layer names follow the
        # original index assignment -- confirm the intended naming.
        projections = (self.project_px, self.project_py, self.project_pz,
                       self.project_e, self.project_x, self.project_y,
                       self.project_z)
        x = torch.stack([project(labels[:, i])
                         for i, project in enumerate(projections)], dim=1)

        # Transformer 1: 7 tokens along the central line of the volume
        x = self.dropout(x)
        pos = vol_idx[3,3,:,:].repeat(x.shape[0],1,1).view(-1,7,3).to(dev)
        x = torch.cat([pos, x], dim=2) # add pos
        x = self.map_enc1(x)
        x = self.encoder1(x)

        # Add noise and expand to 7*7 tokens
        z = torch.normal(0, 0.5, size=x.shape).to(dev)
        xz = torch.cat([x,z], dim=2)
        x = self.map1(xz).view(-1, 7*7, self.d_model2-3)

        # Transformer 2: 7*7 tokens on the central plane
        x = self.dropout(x)
        pos = vol_idx[3,:,:,:].repeat(x.shape[0],1,1,1).view(-1,7*7,3).to(dev)
        x = torch.cat([pos, x], dim=2) # add pos
        x = self.map_enc2(x)
        x = self.encoder2(x)

        # Add noise and expand to 7*7*7 tokens
        z = torch.normal(0, 0.5, size=x.shape).to(dev)
        xz = torch.cat([x,z], dim=2)
        x = self.map2(xz).view(-1, 7*7*7, self.d_model2-3)

        # Transformer 3: one token per voxel
        x = self.dropout(x)
        pos = vol_idx.repeat(x.shape[0],1,1,1,1).view(-1,7*7*7,3).to(dev)
        x = torch.cat([pos, x], dim=2) # add pos
        x = self.map_enc3(x)
        x = self.encoder3(x)

        # The result was computed but never returned before
        return self.decoder(x)
    
class PoissonLikelihood_loss(nn.Module):
    """Poisson likelihood loss.

    Evaluates ``mean(y_pred - y_true * log(y_pred + eps))`` over all elements,
    where ``eps`` is read from module scope.  The constructor takes no
    arguments.

    Original implementation credited to jchen245@jhmi.edu (02/21/2021).
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        """Return the scalar loss for a batch.

        Both tensors are flattened to (batch, -1) first, so they must hold
        the same number of elements per sample.
        """
        pred_flat = y_pred.view(y_pred.size(0), -1)
        true_flat = y_true.view(y_true.size(0), -1)

        # eps keeps the logarithm finite when a prediction is exactly zero
        log_term = true_flat * (pred_flat + eps).log()
        return (pred_flat - log_term).mean()

In [None]:
# Run on the GPU when one is visible, otherwise on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print('torch version:',torch.__version__)
print('device:', device)

# Model / data configuration
img_size = 7
# collate_fn yields 3 position + 4 momentum values and Generator.forward views
# its labels as 7 scalars, so the label count is 7 (was declared as 6).
label_size = 7
noise_size = 3
lambda_gp = 10
dropout = 0.2       # read by Generator.__init__
eps = 1e-6          # read by PoissonLikelihood_loss.forward
noise_type = "normal"

generator = Generator(label_size, noise_size).to(device)

generator_total_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
print(generator)
print("total trainable params: {} (generator).".format(generator_total_params))

lossf = PoissonLikelihood_loss()

# Training
epochs = 10  # Train epochs
# The optimizer used to hard-code lr=0.0001 while this variable said 0.0002.
# The value that was actually in effect is kept, and the optimizer now reads
# the named hyper-parameters so the two cannot drift apart again.
learning_rate = 0.0001
betas = (0.9, 0.98)

# Optimiser
g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=betas, eps=1e-9)

In [None]:
# One optimisation step of the generator against the module-level loss `lossf`
def generator_train_step(batch_size, generator, g_optimizer, real_images, labels, charges):
    """Run forward, backward and optimizer step for one batch.

    Args:
        batch_size: number of events in the batch (sizes the noise tensor).
        generator: callable taking (labels, noise) and returning fake data.
        g_optimizer: optimizer whose gradients are reset and then stepped.
        real_images: target images for the batch.
        labels: conditioning labels for the batch.
        charges: moved to the device but not used in the loss.

    Returns:
        The loss tensor returned by ``lossf`` for this batch.
    """
    g_optimizer.zero_grad()

    # Move every input to the module-level device
    real_labels = labels.to(device)
    charges = charges.to(device)
    real_data = real_images.to(device)

    # Noise of shape (batch, img_size, img_size, img_size, noise_size)
    noise_shape = (batch_size,) + (img_size,) * 3 + (noise_size,)
    z = torch.normal(0, 0.5, size=noise_shape).to(device)

    w_loss = lossf(generator(real_labels, z), real_data)
    w_loss.backward()
    g_optimizer.step()

    return w_loss

def generator_test_step(batch_size, generator, real_images, labels):
    """Evaluate the generator on one batch without updating it.

    The generator is switched to eval mode (and left there).  The forward
    pass runs under ``torch.no_grad()``, so no autograd graph is built and
    the returned loss cannot be back-propagated.  The module-level optimizer
    is no longer touched, since evaluation has no gradients to reset.

    Args:
        batch_size: number of events in the batch (sizes the noise tensor).
        generator: ``nn.Module`` taking (labels, noise).
        real_images: target images for the batch.
        labels: conditioning labels for the batch.

    Returns:
        The loss tensor returned by the module-level ``lossf``.
    """
    generator.eval()

    with torch.no_grad():
        # labels
        real_labels = labels.to(device)

        # fake
        z = torch.normal(0, 0.5, size=(batch_size, img_size, img_size, img_size, noise_size)).to(device)
        fake_data = generator(real_labels, z)

        # real
        real_data = real_images.to(device)

        # Generator loss
        w_loss = lossf(fake_data, real_data)

    return w_loss

In [None]:
def plot_event(X, Y, labels, elev=20, azim=20, add_projs=False):
    """Show true and generated events side by side as 3-D voxel displays.

    The figure has two rows and one column per event: row 0 shows
    ``X[event]``, row 1 shows ``Y[event]``.  Voxels below the module-level
    ``min_charge_plot`` are hidden, and colours map charge through the
    ``min_charge``..``max_charge`` range.

    Args:
        X: sequence of arrays with ``img_size**3`` elements (first row).
        Y: sequence of arrays with ``img_size**3`` elements (second row).
        labels: 2-D indexable whose last column holds the energy normalised
            to ``source_range``; it is mapped back to (min_E, max_E) and
            ``m_proton`` is subtracted for the title.
        elev, azim: camera elevation and azimuth passed to ``view_init``.
        add_projs: not used.

    Returns:
        None; the figure is displayed with ``plt.show()``.
    """
    nevents = len(X)
    if nevents == 0:
        # Nothing to draw (the colour bar would have no mappable either)
        return

    # Two rows (X and Y), one column per event.  The previous
    # nevents x nevents grid could not place the second row for one event.
    fig = plt.figure(figsize=(nevents*18, 2*17))
    fig.patch.set_facecolor('white')

    edgecolor = '1.0'

    # The charge -> colour mapping is the same for every panel
    norm = matplotlib.colors.Normalize(vmin=min_charge, vmax=max_charge)
    m = cm.ScalarMappable(norm=norm, cmap=cm.YlOrRd)

    tick_positions = np.arange(0.5, 7, 1.)
    tick_names = [str(n) for n in range(1, 8)]

    for event in range(nevents):
        # Work on copies so thresholding does not alter the caller's arrays
        volumes = []
        for source in (X[event], Y[event]):
            volume = np.array(source, dtype=float).reshape(img_size, img_size, img_size)
            volume[volume < min_charge_plot] = 0
            volumes.append(volume)

        ini_KE = np.interp(labels[event,-1], source_range, (min_E, max_E)) - m_proton

        for row, volume in enumerate(volumes):
            ax = fig.add_subplot(2, nevents, row*nevents + event + 1, projection='3d')

            # Each panel is coloured by its own charges (both rows used to
            # take the colours of the first volume).
            ax.voxels(volume != 0, facecolors=m.to_rgba(volume),
                      edgecolor=edgecolor, alpha=1.0)

            ax.set_xlabel('X [cube]', labelpad=50, fontsize=45)
            ax.set_ylabel('Z [cube]', labelpad=50, fontsize=45)
            ax.set_zlabel('Y [cube]', labelpad=50, fontsize=45)

            ax.set_title("Total charge: {0:.2f} p.e.\n[KE of {1:.2f} MeV]".format(volume.sum(),\
                                                                                 ini_KE), fontsize=50)

            # Label cube centres with minor ticks and blank the major labels.
            # Tick appearance goes through tick_params: set_*ticks only takes
            # text properties, and only together with labels.
            ax.set_xticks(tick_positions, minor=True)
            ax.set_xticklabels(tick_names, minor=True, size=25)
            ax.set_xticklabels([], minor=False)
            ax.tick_params(axis='x', which='minor', length=0, width=0, grid_alpha=0)
            ax.set_yticklabels([], minor=False)
            ax.set_yticks(tick_positions, minor=True)
            ax.set_yticklabels(tick_names, minor=True, size=25)
            ax.tick_params(axis='y', which='minor', length=0, width=0)
            ax.set_zticklabels([], minor=False)
            ax.set_zticks(tick_positions, minor=True)
            ax.set_zticklabels(tick_names, minor=True, size=25)

            # change camera angle
            ax.view_init(elev=elev, azim=azim)

            ax.grid(False)

    # Shared colour bar in its own axes at the right edge
    fig.subplots_adjust(right=0.925)
    cbar_ax = fig.add_axes([0.95, 0.70, 0.01, 0.15]) # left, bottom, width, height
    cbar = fig.colorbar(m, cax=cbar_ax)
    cbar.set_label('# of p.e.', rotation=90, labelpad=19, fontsize=40)
    cbar.ax.tick_params(labelsize=35)

    plt.show()

In [None]:
disable = False   # True silences the progress bar and the periodic plots
load = False      # resume from a saved checkpoint
save = True       # write a checkpoint every 500 iterations

min_charge_plot = min_charge

g_losses = []
g_losses_real = []
g_losses_fake = []

epoch = 0

np.set_printoptions(suppress=True)

if load:
    print("Loading saved model...")
    epoch, iteration = 1, 500
    # map_location lets a checkpoint written on one device be restored on another
    checkpoint = torch.load("models/gen_{}_{}".format(epoch, iteration), map_location=device)
    generator.load_state_dict(checkpoint['g_state_dict'])
    g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
    epoch = checkpoint['epoch']+1
    g_losses = checkpoint['g_losses'].tolist()

if save:
    # torch.save does not create missing directories
    os.makedirs("models", exist_ok=True)

for epoch in range(epoch, epochs):

    print('Starting epoch {}...'.format(epoch))

    train_loss, val_loss = 0., 0.

    batch_size = train_loader.batch_size
    n_batches = int(math.ceil(len(train_loader.dataset)/batch_size))
    t = tqdm.tqdm(enumerate(train_loader), total=n_batches, disable=disable)

    # `batch_charges` keeps the loop from overwriting the module-level
    # `charges` array loaded from data.p
    for ite, (images, labels, batch_charges) in t:

        # Train data
        real_images = images

        # Training progress in [0, 1) and a sigmoid ramp derived from it;
        # neither enters the loss in this cell.
        p = float(ite + (epoch) * len(train_loader)) / (epochs) / len(train_loader)
        lambda_pe = 2. / (1. + np.exp(-10 * p)) - 1

        # The periodic sampling below switches to eval mode, so switch back
        generator.train()

        # Train generator
        g_loss = generator_train_step(len(real_images),
                                      generator, g_optimizer,
                                      real_images, labels, batch_charges)
        g_loss_value = g_loss.item()

        t.set_description("g_loss = {0:.5f}".format(g_loss_value))

        train_loss += g_loss_value

        if ite%50==0:
            g_losses.append(g_loss_value)

        if ite%500==0 and ite>0:
            if not disable:
                generator.eval()

                true_images = []
                pred_images = []
                shown_labels = None

                # Only the first validation batch is sampled
                for val_images, val_labels, _ in valid_loader:

                    # Sampling needs no gradients
                    with torch.no_grad():
                        z = torch.normal(0, 0.5, size=(len(val_images), img_size, img_size, img_size, noise_size)).to(device)
                        sample_images = generator(val_labels.to(device), z).cpu()

                    # Map the first 7 events (all that is plotted) back to charge units
                    for j in range(min(7, len(val_images))):
                        true_image = val_images[j].numpy()
                        pred_image = sample_images[j].numpy()
                        true_image = np.interp(true_image.ravel(), target_range, (min_charge_new, max_charge_new)).reshape(true_image.shape)
                        pred_image = np.interp(pred_image.ravel(), target_range, (min_charge_new, max_charge_new)).reshape(pred_image.shape)
                        true_images.append(true_image)
                        pred_images.append(pred_image)

                    shown_labels = val_labels[:7]
                    break

                if true_images:
                    plot_event(true_images, pred_images, shown_labels)

            if save:
                torch.save({
                           'epoch': epoch,
                           'g_state_dict': generator.state_dict(),
                           'g_optimizer_state_dict': g_optimizer.state_dict(),
                           'g_losses': np.array(g_losses),
                           }, "models/gen_{}_{}".format(epoch, ite))

    # Report the accumulated loss, which was previously computed but never shown
    print('Epoch {}: mean g_loss = {:.5f}'.format(epoch, train_loss / max(1, n_batches)))