# Analysis of ATLAS-GAN results

Let's take a look at the results of the DCGAN trained to generate ATLAS RPV SUSY events.

In [1]:
# Convenient fudge for python path
import sys
sys.path.append('..')

In [25]:
# Compatibility
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

# System
import os
import json

# Externals
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import torch
from torch.autograd import Variable

# Locals
from atlasgan import gan
from atlasgan.dataset import RPVImages, inverse_transform_data, generate_noise
from atlasgan.reco import compute_physics_variables

# Magics
%matplotlib notebook

In [3]:
def draw_image(x, vmin=None, vmax=None, xlabel=r'$\eta$', ylabel=r'$\phi$',
               draw_cbar=True, ax=None, figsize=(5,4), **kwargs):
    """Draw a 2D image with a logarithmic colour scale.

    Parameters
    ----------
    x : 2D array
        Image to draw. It is transposed before plotting, so its first axis
        runs along the horizontal axis of the plot.
    vmin, vmax : float, optional
        Colour-scale limits forwarded to `LogNorm`.
    xlabel, ylabel : str or None
        Axis labels; pass None to leave an axis unlabelled.
    draw_cbar : bool
        Whether to add a colorbar labelled 'Energy'.
    ax : matplotlib axes, optional
        Axes to draw on. A new figure of size `figsize` is created if omitted.
    figsize : tuple
        Size of the new figure; only used when `ax` is None.
    **kwargs
        Forwarded to `ax.imshow`.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(x.T, extent=[-2.5, 2.5, -3.15, 3.15],
                      norm=LogNorm(vmin, vmax), aspect='auto',
                      origin='lower', **kwargs)
    if draw_cbar:
        # Use the figure that owns `ax` rather than the "current" figure,
        # which may be a different one when the caller supplies `ax`.
        cbar = ax.figure.colorbar(image, ax=ax)
        cbar.set_label('Energy')
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

In [4]:
def draw_reco(sample):
    """Apply reconstruction to one sample and visualize the clusters.

    The left panel shows the input image. The right panel shows every jet
    returned by `build_jets(sample)` as a large '+' at its (eta, phi), with
    the jet constituents drawn as small markers in the same colour.

    NOTE(review): `build_jets` is neither defined nor imported in this
    notebook -- presumably it lives in `atlasgan.reco`; confirm and import
    it before calling this function.
    """
    jets = build_jets(sample)

    # Draw the provided jet image on the left
    _, (ax_image, ax_jets) = plt.subplots(1, 2, figsize=(8,4))
    draw_image(sample, vmin=1e-4, ax=ax_image, draw_cbar=False)

    # Draw the jets on the right
    for jet in jets:
        jet_marker = ax_jets.scatter(jet.eta, jet.phi, s=200, marker='+')
        constituents = jet.constituents_array()
        ax_jets.scatter(constituents['eta'], constituents['phi'],
                        s=5, marker=',', c=jet_marker.get_edgecolor())

    # Configure the axes once, outside the loop, so that the panel is still
    # correctly framed and labelled when no jets were found.
    ax_jets.set_xlim((-2.5, 2.5))
    ax_jets.set_ylim((-np.pi, np.pi))
    ax_jets.set_xlabel(r'$\eta$')
    ax_jets.set_ylabel(r'$\phi$')

    plt.tight_layout()

## Load validation results for all HP sets

In [5]:
def load_model_config(train_dir):
    """Read `config.json` from a training directory and return the parsed JSON."""
    with open(os.path.join(train_dir, 'config.json'), 'r') as config_fh:
        return json.load(config_fh)

def load_validation_metrics(train_dir):
    """Load `validation_metrics.npz` from a training directory as a DataFrame.

    Each array stored in the archive becomes one column, named after its key.
    The archive is opened in a `with` block so the underlying file handle is
    released as soon as the arrays have been read.
    """
    metrics_file = os.path.join(train_dir, 'validation_metrics.npz')
    with np.load(metrics_file) as archive:
        # Arrays are read lazily on item access, so materialise them all
        # before the archive is closed.
        return pd.DataFrame({name: archive[name] for name in archive.files})

In [6]:
# Parent directory holding one training sub-directory per hyper-parameter set
results_dir = (
    '/global/cscratch1/sd/sfarrell/atlas_gan'
    '/AtlasDCGAN_180629'
)

In [7]:
# Find all the training directories that have validation results available.
# os.listdir returns entries in arbitrary order, so the listing is sorted to
# keep the `hp` index assigned below stable from one run to the next.
train_dirs = [os.path.join(results_dir, d) for d in sorted(os.listdir(results_dir))]
train_dirs = [d for d in train_dirs if os.path.exists(os.path.join(d, 'validation_metrics.npz'))]

# Load the results for each HP set. `hp` is the position of the training
# directory in `train_dirs`, and the config values are broadcast to every row.
hp_frames = []
for i, train_dir in enumerate(train_dirs):
    config = load_model_config(train_dir)
    metrics = load_validation_metrics(train_dir)
    hp_frames.append(metrics.assign(hp=i, **config))
# Combined into single DataFrame
results = pd.concat(hp_frames, ignore_index=True)

In [8]:
# Combined KS metric: sum of the jet-pT, jet-multiplicity and summed-mass KS values
ks_comb = results['ks_jetPt'] + results['ks_nJet'] + results['ks_sumMass']
results = results.assign(ks_comb=ks_comb)

## Find the best models

In [9]:
# One row per distinct combination of the explored hyper-parameter values
results.loc[:, ['hp', 'n_filters', 'noise_dim', 'lr', 'flip_rate']].drop_duplicates()

Unnamed: 0,hp,n_filters,noise_dim,lr,flip_rate
0,0,32,128,0.0002,0.021051
64,1,8,128,0.005,0.186247
128,2,8,16,0.001,0.047855
192,3,64,128,1e-05,0.012984
256,4,8,64,0.001,0.067979
320,5,32,32,0.005,0.175582
384,6,8,64,1e-05,0.07765
448,7,16,16,0.005,0.087252
512,8,64,128,0.001,0.02606
576,9,128,16,1e-05,0.115164


In [10]:
# Rank every (HP set, epoch) row by the combined KS metric, lowest first,
# and keep the top_k best rows.
top_k = 10
idx = results['ks_comb'].sort_values().head(top_k).index
topk_results = results.loc[idx]
topk_results

Unnamed: 0,ks_nJet,ks_sumMass,ks_jetPt,ks_jetEta,ks_jetPhi,epoch,hp,flip_rate,noise_dim,image_norm,n_filters,lr,beta2,beta1,threshold,ks_comb
1141,5.434601,9.926581,1.342309,23.171104,57.07768,53,17,0.038879,128,4000000.0,32,0.0001,0.999,0.5,0.000125,16.703491
1372,5.512413,1.962132,12.543727,41.040549,4.455274,28,21,0.181805,128,4000000.0,128,0.0002,0.999,0.5,0.000125,20.018272
1682,1.477262,8.355615,11.968152,94.938016,24.188031,18,26,0.013889,64,4000000.0,32,0.0001,0.999,0.5,0.000125,21.801029
1110,4.186911,15.772782,2.311557,22.182108,18.764115,22,17,0.038879,128,4000000.0,32,0.0001,0.999,0.5,0.000125,22.27125
1695,8.450119,3.142228,12.002732,46.099971,55.722384,31,26,0.013889,64,4000000.0,32,0.0001,0.999,0.5,0.000125,23.595079
1130,2.667043,18.551183,6.26724,25.096557,49.877545,42,17,0.038879,128,4000000.0,32,0.0001,0.999,0.5,0.000125,27.485465
622,9.125388,12.518674,7.221948,18.255958,107.365506,46,9,0.115164,16,4000000.0,128,1e-05,0.999,0.5,0.000125,28.86601
1704,5.35728,6.484448,17.073575,26.512215,38.368916,40,26,0.013889,64,4000000.0,32,0.0001,0.999,0.5,0.000125,28.915303
1384,16.156442,4.467721,9.130697,4.888904,49.275106,40,21,0.181805,128,4000000.0,128,0.0002,0.999,0.5,0.000125,29.75486
1702,4.905969,7.664468,17.671288,53.083806,44.267865,38,26,0.013889,64,4000000.0,32,0.0001,0.999,0.5,0.000125,30.241725


In [11]:
# For each HP set, keep the n_best epochs with the lowest combined KS metric
n_best = 1

def _lowest_ks_rows(group):
    """Return the n_best rows of one HP group, ordered by increasing ks_comb."""
    ranked_index = group.ks_comb.sort_values().index
    return group.loc[ranked_index[:n_best]]

hpbest_results = results.groupby('hp', as_index=False).apply(_lowest_ks_rows)
hpbest_results

Unnamed: 0,Unnamed: 1,ks_nJet,ks_sumMass,ks_jetPt,ks_jetEta,ks_jetPhi,epoch,hp,flip_rate,noise_dim,image_norm,n_filters,lr,beta2,beta1,threshold,ks_comb
0,36,7.084115,16.28531,9.779407,73.201768,36.23299,36,0,0.021051,128,4000000.0,32,0.0002,0.999,0.5,0.000125,33.148832
1,64,621.083835,224.043975,18.914444,72.031426,700.0,0,1,0.186247,128,4000000.0,8,0.005,0.999,0.5,0.000125,864.042254
2,133,13.210881,44.904377,10.640436,56.41069,175.553686,5,2,0.047855,16,4000000.0,8,0.001,0.999,0.5,0.000125,68.755694
3,242,8.075051,18.141013,26.55388,166.577253,133.77991,50,3,0.012984,128,4000000.0,64,1e-05,0.999,0.5,0.000125,52.769944
4,294,290.941905,3.08111,158.479666,194.008035,506.080419,38,4,0.067979,64,4000000.0,8,0.001,0.999,0.5,0.000125,452.502682
5,356,18.551183,700.0,700.0,233.961082,156.337861,36,5,0.175582,32,4000000.0,32,0.005,0.999,0.5,0.000125,1418.551183
6,440,15.519464,236.495426,42.393115,452.327167,441.34594,56,6,0.07765,64,4000000.0,8,1e-05,0.999,0.5,0.000125,294.408005
7,472,177.908224,700.0,700.0,122.041101,81.109943,24,7,0.087252,16,4000000.0,16,0.005,0.999,0.5,0.000125,1577.908224
8,516,29.033451,168.406582,9.94256,26.344795,40.131058,4,8,0.02606,128,4000000.0,64,0.001,0.999,0.5,0.000125,207.382593
9,622,9.125388,12.518674,7.221948,18.255958,107.365506,46,9,0.115164,16,4000000.0,128,1e-05,0.999,0.5,0.000125,28.86601


### Scatter plots of metrics vs. HP values

In [12]:
%matplotlib notebook

In [13]:
# Combined KS metric vs. each hyper-parameter, for every model epoch
fig, axs = plt.subplots(2, 2, figsize=(9, 7))
hp_panels = [
    ('noise_dim', 'Noise dim'),
    ('n_filters', 'Number of filters'),
    ('flip_rate', 'Label flip rate'),
    ('lr', 'Learning rate'),
]
# axs.flat walks the 2x2 grid row by row: [0,0], [0,1], [1,0], [1,1]
for ax, (column, label) in zip(axs.flat, hp_panels):
    ax.scatter(results[column], results.ks_comb, s=2)
    ax.set_xlabel(label)
plt.tight_layout()

<IPython.core.display.Javascript object>

In [14]:
# Combined KS metric vs. each hyper-parameter, for the best epoch(s) of each HP set
fig, axs = plt.subplots(2, 2, figsize=(9, 7))
hp_panels = [
    ('noise_dim', 'Noise dim'),
    ('n_filters', 'Number of filters'),
    ('flip_rate', 'Label flip rate'),
    ('lr', 'Learning rate'),
]
# axs.flat walks the 2x2 grid row by row: [0,0], [0,1], [1,0], [1,1]
for ax, (column, label) in zip(axs.flat, hp_panels):
    ax.scatter(hpbest_results[column], hpbest_results.ks_comb, s=10)
    ax.set_xlabel(label)
plt.tight_layout()

<IPython.core.display.Javascript object>

## Evaluate best model on test set

In [15]:
def load_model(train_dir, checkpoint_id, model_config):
    """Rebuild the generator and discriminator and restore their weights.

    The checkpoint is read from
    `<train_dir>/checkpoints/model_checkpoint_<checkpoint_id>.pth.tar`, with
    the id zero-padded to three digits. The network sizes come from the
    'noise_dim', 'threshold' and 'n_filters' entries of `model_config`.

    Returns the pair `(generator, discriminator)`, both switched to eval mode.
    """
    checkpoint_name = 'model_checkpoint_%03i.pth.tar' % checkpoint_id
    checkpoint_path = os.path.join(train_dir, 'checkpoints', checkpoint_name)

    def keep_storage(storage, loc):
        # Returning the storage unchanged leaves the loaded tensors on the CPU
        return storage

    state = torch.load(checkpoint_path, map_location=keep_storage)

    # Build both networks with the architecture recorded in the config
    generator = gan.Generator(model_config['noise_dim'],
                              threshold=model_config['threshold'],
                              n_filters=model_config['n_filters'])
    discriminator = gan.Discriminator(n_filters=model_config['n_filters'])

    # Restore the trained weights
    generator.load_state_dict(state['generator'])
    discriminator.load_state_dict(state['discriminator'])

    # Hand back both networks in eval mode
    return generator.eval(), discriminator.eval()

### Load the test set

In [16]:
# Test-set configuration
n_test = 4096   # number of real images to load and of fake images to generate
scale = 4e6     # passed to RPVImages and to inverse_transform_data below
input_data = (
    '/global/cscratch1/sd/sfarrell/atlas_gan/data'
    '/RPV10_1400_850_01.npz'
)

In [17]:
# Build the test dataset of n_test images from the input file.
# NOTE(review): from_back=True presumably takes the samples from the end of
# the file so they do not overlap the training set -- confirm in atlasgan.dataset.
dataset = RPVImages(
    input_data,
    n_samples=n_test,
    scale=scale,
    from_back=True,
)

### Load the best model

In [18]:
# The first row of topk_results has the lowest combined KS metric overall
best_epoch = topk_results['epoch'].iloc[0]
best_hp = topk_results['hp'].iloc[0]
# `hp` was assigned as the position of each training directory in train_dirs
best_train_dir = train_dirs[best_hp]
best_config = load_model_config(best_train_dir)
generator, discriminator = load_model(
    train_dir=best_train_dir,
    checkpoint_id=best_epoch,
    model_config=best_config,
)

### Evaluate the test set

In [19]:
# Real images and reconstructed physics variables.
# The former `Variable(..., volatile=True)` wrapper is dropped: since
# PyTorch 0.4 tensors and Variables are merged and `volatile` is ignored
# with a warning. No forward pass happens in this cell, so the data tensor
# is used as is.
test_real = dataset.data
real_images = inverse_transform_data(test_real.numpy().squeeze(1), scale)
real_vars = compute_physics_variables(real_images)

In [20]:
# Generated images and reconstructed physics variables.
# This is inference only, so the generator runs under torch.no_grad().
# The old `volatile=True` flag no longer disables autograd (PyTorch >= 0.4
# ignores it with a warning), which made this forward pass build a graph
# for all n_test images.
test_noise = generate_noise(n_test, best_config['noise_dim'])
with torch.no_grad():
    test_fake = generator(test_noise)
fake_images = inverse_transform_data(test_fake.numpy().squeeze(1), scale)
fake_vars = compute_physics_variables(fake_images)

In [21]:
def _compare_hist(ax, real, fake, xlabel, **hist_args):
    """Overlay step histograms of a real and a fake sample on one axes."""
    ax.hist(real, label='real', histtype='step', **hist_args)
    ax.hist(fake, label='fake', histtype='step', **hist_args)
    ax.set_xlabel(xlabel)
    ax.legend(loc=0)


def plot_physics(real_pixels, fake_pixels,
                 real_vars, fake_vars,
                 figsize=(9, 6), linewidth=1):
    """Compare real and generated samples in a 2x3 grid of histograms.

    Parameters
    ----------
    real_pixels, fake_pixels : array
        Image arrays; they are flattened to histogram individual pixel values.
    real_vars, fake_vars : mapping
        Must provide the keys 'nJet', 'sumMass', 'jetPt', 'jetEta' and
        'jetPhi', each holding an array of values.
    figsize : tuple
        Size of the figure (previously accepted but ignored).
    linewidth : float
        Line width of the histogram outlines.

    Pixel energies, summed jet mass and jet pT are multiplied by 1e-3 before
    plotting and labelled in GeV -- presumably the inputs are in MeV; confirm
    against the reconstruction code.
    """
    _, axs = plt.subplots(2, 3, figsize=figsize)
    (ax0, ax1, ax2), (ax3, ax4, ax5) = axs

    # Pixel energy (log-scale counts)
    _compare_hist(ax0, real_pixels.flatten()*1e-3, fake_pixels.flatten()*1e-3,
                  'Pixel energy [GeV]',
                  bins=100, range=(0, 1e3), log=True, linewidth=linewidth)

    # Jet multiplicity
    _compare_hist(ax1, real_vars['nJet'], fake_vars['nJet'],
                  'Number of jets',
                  bins=10, range=(0, 10), log=False, linewidth=linewidth)

    # Sum of jet mass
    _compare_hist(ax2, real_vars['sumMass']*1e-3, fake_vars['sumMass']*1e-3,
                  'Sum of jet mass [GeV]',
                  bins=20, range=(0, 2e3), log=False, linewidth=linewidth)

    # Jet PT
    _compare_hist(ax3, real_vars['jetPt']*1e-3, fake_vars['jetPt']*1e-3,
                  'Jet $p_T$ [GeV]',
                  bins=20, range=(0, 2e3), log=False, linewidth=linewidth)

    # Jet eta
    _compare_hist(ax4, real_vars['jetEta'], fake_vars['jetEta'],
                  r'Jet $\eta$',
                  bins=20, range=(-2, 2), linewidth=linewidth)

    # Jet phi
    _compare_hist(ax5, real_vars['jetPhi'], fake_vars['jetPhi'],
                  r'Jet $\phi$',
                  bins=20, range=(-np.pi, np.pi), linewidth=linewidth)

    plt.tight_layout()

In [22]:
# Compare the real test images with the generated ones
plot_physics(real_pixels=real_images, fake_pixels=fake_images,
             real_vars=real_vars, fake_vars=fake_vars)

<IPython.core.display.Javascript object>

## Average images

In [39]:
# Average real image: per-pixel sum over the test images, scaled by 1e-3
# and divided by the number of test images
draw_image(x=real_images.sum(axis=0)*1e-3 / n_test)

<IPython.core.display.Javascript object>

In [40]:
# Average generated (fake) image: per-pixel sum over the generated images,
# scaled by 1e-3 and divided by the number of generated images
draw_image(x=fake_images.sum(axis=0)*1e-3 / n_test)

<IPython.core.display.Javascript object>

## Compute metrics for every epoch

Loop over epochs, load the checkpoint, generate the samples, compute the metrics.

The real-sample images and physics variables do not depend on the epoch, so compute them once before the loop.

## Plot kinematics for select epochs

## View sample images from select epochs

## Average images

## Discussion