In [None]:
import argparse
import torch
import os, sys
import math
from tqdm import tqdm
import numpy as np

from model import BERT
from dataset import VAdatasetFinal
from torch.utils.data import DataLoader
from torch import nn
from trainer import plot_event

# Pin the GPU selection for this notebook: enumerate devices in PCI bus order
# and expose only device "0" to CUDA.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# Device string handed to the model: CUDA when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# parameters

def _str2bool(value):
    """Convert a command-line value to a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` and
    ``bool("0")`` are both ``True`` because any non-empty string is truthy.
    This converter understands the usual textual spellings instead.

    :param value: string from the argument list (non-string values are
        accepted too and converted with ``bool``)
    :return: the parsed boolean
    :raises argparse.ArgumentTypeError: if a string is not a recognised boolean
    """
    if not isinstance(value, str):
        return bool(value)
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser()
parser.add_argument("-c", "--dataset_path", required=True, type=str, help="train dataset path")
parser.add_argument("-im", "--img_size", type=int, default=5, help="image size (one dimension)")
parser.add_argument("-is", "--input_size", type=int, default=1, help="input dimension (per cube)")
parser.add_argument("-ls", "--label_size", type=int, default=7, help="number of labels")
parser.add_argument("-ns", "--noise_size", type=int, default=100, help="size of the noise")
parser.add_argument("-hs", "--hidden", type=int, default=256, help="hidden size of transformer model")
parser.add_argument("-dr", "--dropout", type=float, default=0.1, help="dropout of the model")
parser.add_argument("-l", "--layers", type=int, default=8, help="number of layers")
parser.add_argument("-a", "--attn_heads", type=int, default=8, help="number of attention heads")
parser.add_argument("-b", "--batch_size", type=int, default=64, help="number of batch_size")
parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
parser.add_argument("-w", "--num_workers", type=int, default=5, help="dataloader worker size")
parser.add_argument("-lr", "--learning_rate", type=float, default=1e-3, help="learning rate of adam")
parser.add_argument("-wd", "--weight_decay", type=float, default=0.01, help="weight_decay of adam")
parser.add_argument("-b1", "--beta1", type=float, default=0.9, help="adam first beta value")
parser.add_argument("-b2", "--beta2", type=float, default=0.999, help="adam second beta value")
parser.add_argument("--eps", type=float, default=1e-9, help="value to prevent division by zero")
parser.add_argument("-dis", "--disable", type=_str2bool, default=False, help="whether to disable the training progress display")
parser.add_argument("-lo", "--load", type=_str2bool, default=False, help="whether to load a pretrained model")
parser.add_argument("-s", "--save", type=_str2bool, default=False, help="whether to save the model after every epoch")
parser.add_argument("-sp", "--save_path", type=str, default=".", help="path (template) where model checkpoints are stored")
parser.add_argument("-es", "--early_stopping", type=int, default=10, help="early stopping count")
parser.add_argument('-sr','--source_range', nargs='+', type=int, default=[-1,1], help='source range')
parser.add_argument('-tr','--target_range', nargs='+', type=int, default=[0,1], help='target range')
parser.add_argument('-tc','--total_charge', type=_str2bool, default=False, help='Use total charge in discriminator')
parser.add_argument('-cr','--crit_repeats', type=int, default=5, help='Critic repetitions')

# The argument list is given explicitly because, inside a notebook, sys.argv
# belongs to the kernel. Every entry is a string, exactly as on a command line.
args = parser.parse_args(["-c", "/scratch2/salonso/vertex_activity/images",
                          "-lo", "False",
                          "-b", "32",
                          "-ls", "6",
                          "-wd", "0",
                          "-dis", "True",
                          "-s", "True",
                          "-sp", "pretrained/gan_tran_rms_final_{}_{}",
                          "-e", "100",
                          "-ns", "512",
                          "-l", "2",
                          "-w", "8",
                          "-hs", "64",
                          "-is", "1",
                          "-lr", "0.00005",
                          "--eps", "1e-8",
                          "-tr", "-1", "1",
                          "-tc", "False",
                          "-cr", "10",
                         ]
                        )

In [None]:
# generate datasets

# Every dataset is built with the same source/target normalisation ranges.
_range_kwargs = dict(source_range=args.source_range, target_range=args.target_range)

# Main dataset, read from the configured --dataset_path instead of a second
# hard-coded copy of the same path.
dataset_original = VAdatasetFinal(args.dataset_path, **_range_kwargs)

# Test datasets, one directory per energy label (10MeV ... 50MeV).
TEST_DATASET_ROOT = "/scratch2/salonso/vertex_activity/images_test"
TEST_ENERGIES_MEV = (10, 20, 30, 40, 50)
datasets_by_energy = {
    energy: VAdatasetFinal("{}/{}MeV".format(TEST_DATASET_ROOT, energy), **_range_kwargs)
    for energy in TEST_ENERGIES_MEV
}

# Individual names kept so other cells that refer to them keep working.
dataset_10Mev = datasets_by_energy[10]
dataset_20Mev = datasets_by_energy[20]
dataset_30Mev = datasets_by_energy[30]
dataset_40Mev = datasets_by_energy[40]
dataset_50Mev = datasets_by_energy[50]

# Test datasets in ascending energy order.
datasets = [datasets_by_energy[energy] for energy in TEST_ENERGIES_MEV]

In [None]:
class Generator(nn.Module):
    """
    Generator made of a BERT backbone followed by a single-output Decoder head.
    """

    def __init__(self, bert: BERT):
        """
        :param bert: BERT model used as backbone; its ``hidden`` attribute
                     gives the input size of the decoder
        """

        super().__init__()
        self.bert = bert
        hidden_size = self.bert.hidden
        self.decoder = Decoder(hidden_size, 1, activation=True)

    def forward(self, label, noise):
        """
        :param label: conditioning labels, forwarded to the BERT model
        :param noise: noise input, forwarded to the BERT model
        :return: decoder output reshaped to (first dimension of the BERT output, -1)
        """
        # the backbone is called without an explicit input (input=None)
        features = self.bert(input=None, label=label, noise=noise)
        decoded = self.decoder(features)
        n_samples = features.shape[0]
        return decoded.view(n_samples, -1)

class Decoder(nn.Module):
    """
    Linear projection optionally followed by an output activation that
    matches the target range (Sigmoid for [0, 1], Tanh for [-1, 1]).
    """

    def __init__(self, hidden, outsize, activation="sigmoid", target_range=None):
        """
        :param hidden: input size of the linear layer (output size of the BERT model)
        :param outsize: output size of the linear layer
        :param activation: ``None`` disables the output activation; any other
            value enables it. The value itself does NOT select the function,
            the target range does.
        :param target_range: two-element range of the targets, ``[0, 1]`` or
            ``[-1, 1]``. Defaults to the notebook-level ``args.target_range``.
        :raises ValueError: if an activation is requested for an unsupported range
        """
        super().__init__()
        self.linear = nn.Linear(hidden, outsize)

        if activation is None:
            self.activation = None
            return

        if target_range is None:
            # fall back to the global configuration, as the original code did
            target_range = args.target_range

        # compare as a tuple so that lists and tuples are both accepted
        bounds = tuple(target_range)
        if bounds == (0, 1):
            self.activation = nn.Sigmoid()
        elif bounds == (-1, 1):
            self.activation = nn.Tanh()
        else:
            # an explicit exception instead of `assert False`, which is
            # stripped under `python -O` and gives no hint about the cause
            raise ValueError(
                "unsupported target_range {!r}: expected [0, 1] or [-1, 1]".format(target_range)
            )

    def forward(self, x):
        """
        :param x: input whose last dimension has size ``hidden``
        :return: linear projection of ``x``, passed through the activation if one is set
        """
        x = self.linear(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

In [None]:
# ini model

print("Building BERT model")
bert_gen = BERT(input_size=args.input_size, label_size=args.label_size, noise_size=args.noise_size, hidden=args.hidden,
            n_layers=args.layers, attn_heads=args.attn_heads, total_charge=args.total_charge, dropout=args.dropout, device=device)
model = Generator(bert_gen).to(device)

model_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(model)
print("total trainable params: {} (model).".format(model_total_params))

# load saved model
epoch, iteration = 40, 12000
checkpoint_path = args.save_path.format(epoch, iteration)
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint written on a GPU machine can still be loaded on a CPU-only host.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)
# strict=False tolerates key mismatches; report them instead of hiding them,
# otherwise a partially initialised generator would be used without notice.
load_result = model.load_state_dict(checkpoint['g_state_dict'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("WARNING: checkpoint {} does not fully match the model.".format(checkpoint_path))
    print("  missing keys:    {}".format(list(load_result.missing_keys)))
    print("  unexpected keys: {}".format(list(load_result.unexpected_keys)))
# inference mode (e.g. disables dropout); trailing ';' hides the cell output
model.eval();

In [None]:
# Run the network on a single image!

# physics parameters (arbitrary)
ini_x = 3.9
ini_y = 2.1
ini_z = -4.0
ke = 12.1
theta = 2.14
phi = -1.03
params = np.array([ini_x, ini_y, ini_z, ke, theta, phi])

# normalise each parameter from its dataset range to the dataset's source range
# (the datasets are built with args.source_range, which defaults to [-1, 1])
source_range = dataset_original.source_range
params[:3] = np.interp(params[:3], (dataset_original.min_pos, dataset_original.max_pos), source_range)
# scalar in -> scalar out: assigned directly, because writing a size-1 array
# into a single element is deprecated in recent NumPy versions
params[3] = np.interp(params[3], (dataset_original.min_KE, dataset_original.max_KE), source_range)
params[4] = np.interp(params[4], (dataset_original.min_theta, dataset_original.max_theta), source_range)
params[5] = np.interp(params[5], (dataset_original.min_phi, dataset_original.max_phi), source_range)

# tensors needed to run the network
# add the batch axis on the NumPy side; torch.tensor([ndarray]) would take the
# slow list-of-arrays conversion path
params = torch.from_numpy(params[np.newaxis, :]).float().to(device)
noise = torch.normal(0, 1, size=(len(params), 1, args.noise_size)).to(device) # normal noise!

# run the network! (inference only, so no autograd graph is recorded)
with torch.no_grad():
    sample_image = model(params, noise).cpu().numpy()

# sample image to standard range
sample_image = np.interp(sample_image.ravel(), dataset_original.target_range,
                         (dataset_original.min_charge_new,
                          dataset_original.max_charge_new)).reshape(sample_image.shape)

# plot the event!
plot_event(sample_image, elev=20, azim=30, img_size=args.img_size)