# Analysis of ATLAS-GAN results

Let's take a look at the results of the DCGAN trained to generate ATLAS RPV SUSY events.

In [2]:
# Convenient fudge for python path
import sys
sys.path.append('..')

In [3]:
# System
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import os
from datetime import datetime

# Externals
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import torch
from torch.autograd import Variable

# Locals
from atlasgan import gan
from atlasgan.dataset import RPVImages, inverse_transform_data, generate_noise
from atlasgan.reco import compute_physics_variables

%matplotlib notebook

In [4]:
def draw_image(x, vmin=None, vmax=None, xlabel=r'$\eta$', ylabel=r'$\phi$',
               draw_cbar=True, ax=None, figsize=(5,4), **kwargs):
    """Draw a 2D image with a logarithmic color scale.

    Parameters
    ----------
    x : 2D array
        Image to draw. It is transposed before plotting, so its first axis
        runs along the horizontal axis of the plot.
    vmin, vmax : float, optional
        Limits handed to `LogNorm`; `None` leaves the choice to matplotlib.
    xlabel, ylabel : str or None
        Axis labels; pass `None` to leave an axis unlabeled.
    draw_cbar : bool
        Whether to attach a colorbar labeled 'Energy'.
    ax : matplotlib axes, optional
        Axes to draw on. A new figure of size `figsize` is created if omitted.
    figsize : tuple
        Size of the newly created figure; unused when `ax` is given.
    **kwargs
        Forwarded to `ax.imshow`.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(x.T, extent=[-2.5, 2.5, -3.15, 3.15],
                      norm=LogNorm(vmin, vmax), aspect='auto',
                      origin='lower', **kwargs)
    if draw_cbar:
        # Use the figure that owns `ax` rather than plt.gcf(): the current
        # figure is not necessarily the one the caller's axes belong to.
        cbar = ax.figure.colorbar(image, ax=ax)
        cbar.set_label('Energy')
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

In [5]:
def draw_reco(sample):
    """Apply reconstruction to one sample and visualize the clusters.

    The left panel shows the input image; the right panel shows one '+'
    marker per reconstructed jet together with that jet's constituents,
    drawn in the same color as the jet marker.
    """
    # NOTE(review): `build_jets` is not among the names imported in the
    # notebook's import cell; presumably it lives in `atlasgan.reco` --
    # confirm and import it, otherwise this call raises NameError.
    jets = build_jets(sample)

    # Draw the provided jet image on the left
    fig, (ax_image, ax_jets) = plt.subplots(1, 2, figsize=(8,4))
    draw_image(sample, vmin=1e-4, ax=ax_image, draw_cbar=False)

    # Draw the jets on the right
    for j in jets:
        s = ax_jets.scatter(j.eta, j.phi, s=200, marker='+')
        cl = j.constituents_array()
        ax_jets.scatter(cl['eta'], cl['phi'], s=5, marker=',', c=s.get_edgecolor())

    # Configure the axes once, outside the loop, so the panel keeps the same
    # range and labels even when no jets were reconstructed.
    ax_jets.set_xlim((-2.5, 2.5))
    ax_jets.set_ylim((-np.pi, np.pi))
    ax_jets.set_xlabel(r'$\eta$')
    ax_jets.set_ylabel(r'$\phi$')

    plt.tight_layout()

## Load training results and test data

We don't currently have separate validation and test datasets, so for now we do the model selection here on the test dataset.

In [6]:
# Directory holding the training outputs (checkpoints/, summaries.npz, ...).
# Set ATLASGAN_RESULTS_DIR to analyze another run without editing the notebook;
# the default is the run this notebook was originally written for.
results_dir = os.environ.get(
    'ATLASGAN_RESULTS_DIR',
    '/data0/sfarrell/atlas_gan/AtlasDCGAN_008_n64_h64/')

In [7]:
ls $results_dir

[0m[01;34mcheckpoints[0m/  out.log  README  summaries.npz


In [8]:
cat $results_dir/README

Testing some new HPs


In [9]:
# Read the per-epoch training summaries written next to the checkpoints
summaries_path = os.path.join(results_dir, 'summaries.npz')
summaries = np.load(summaries_path)
print(summaries.keys())

['gen_samples', 'd_train_loss', 'd_train_output_real', 'epoch', 'd_train_output_fake', 'g_train_loss']


In [10]:
# File with the real images. Set ATLASGAN_INPUT_DATA to point at a different
# copy without editing the notebook; the default is the original location.
input_data = os.environ.get(
    'ATLASGAN_INPUT_DATA',
    '/data0/sfarrell/atlas_rpv_data/RPV10_1400_850_01.npz')

# Scale handed to the dataset and to inverse_transform_data below.
# (The trailing number is presumably an earlier choice of scale -- TODO confirm.)
scale = 4e6 #6072947
# Threshold handed to the generator and to inverse_transform_data,
# expressed in units of `scale`.
threshold = 500/scale
# Number of real images loaded and of noise vectors generated
n_samples = 4096
# Model hyper-parameters used when rebuilding the networks from checkpoints
noise_dim = 64
n_filters = 64

In [11]:
# Load the real images used as the reference sample throughout the notebook
dataset = RPVImages(
    input_data,
    n_samples=n_samples,
    scale=scale,
    from_back=True,
)

## Training loss and mean outputs

In [12]:
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(9, 4))

# Left panel: training loss of both networks
loss_curves = [('d_train_loss', 'discriminator'),
               ('g_train_loss', 'generator')]
for summary_key, curve_label in loss_curves:
    ax0.semilogy(summaries[summary_key], label=curve_label)
ax0.set_xlabel('Epochs')
ax0.set_ylabel('Loss')
ax0.legend(loc=0)

# Right panel: average discriminator output on fake and real samples
output_curves = [('d_train_output_fake', 'fake samples'),
                 ('d_train_output_real', 'real samples')]
for summary_key, curve_label in output_curves:
    ax1.semilogy(summaries[summary_key], label=curve_label)
ax1.legend(loc=0)
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Average discriminator output')

plt.tight_layout()

<IPython.core.display.Javascript object>

## Compute metrics for every epoch

Loop over epochs, load the checkpoint, generate the samples, compute the metrics.

First, though, compute the reference quantities for the real samples, since every epoch is compared against them.

In [13]:
def load_model(results_dir, checkpoint_id, noise_dim=64, threshold=0, n_filters=16):
    """Rebuild the generator and discriminator stored in one checkpoint.

    The checkpoint is read from
    `<results_dir>/checkpoints/model_checkpoint_<checkpoint_id>.pth.tar`
    (id zero-padded to three digits). Returns `(generator, discriminator)`.
    """
    file_name = 'model_checkpoint_%03i.pth.tar' % checkpoint_id
    checkpoint_file = os.path.join(results_dir, 'checkpoints', file_name)

    def _keep_storage(storage, loc):
        # Returning the storage unchanged maps every tensor onto the CPU
        return storage

    state = torch.load(checkpoint_file, map_location=_keep_storage)

    # Build both networks first, then restore their saved weights
    generator = gan.Generator(noise_dim, threshold=threshold, n_filters=n_filters)
    discriminator = gan.Discriminator(n_filters=n_filters)
    for network, state_key in ((generator, 'generator'),
                               (discriminator, 'discriminator')):
        network.load_state_dict(state[state_key])
    return generator, discriminator

def ks_metric(x, y):
    """Return the negative log of the two-sample KS test p-value for x and y.

    Larger values mean the two samples are less compatible.
    """
    ks_result = scipy.stats.ks_2samp(x, y)
    return -np.log(ks_result.pvalue)

In [14]:
# Real images, wrapped as network input
test_real = Variable(dataset.data, volatile=True)

# One fixed batch of random noise, reused as generator input for every epoch
noise_batch = generate_noise(n_samples, noise_dim)
test_noise = Variable(noise_batch, volatile=True)

# Real images in numpy form, with the dataset scaling undone
real_array = dataset.data.numpy().squeeze(1)
test_images = inverse_transform_data(real_array, scale)

# Reconstructed physics variables of the real images (the reference values)
real_reco_vars = compute_physics_variables(test_images)

In [15]:
def evaluate_model(epoch, train_dir, noise_dim, n_filters,
                   test_real, test_noise, real_reco_vars,
                   test_images, scale, threshold):
    """Evaluate the checkpoint of one epoch against the real test sample.

    Parameters
    ----------
    epoch : int
        Checkpoint id to load.
    train_dir : str
        Training output directory that contains the `checkpoints` folder.
    noise_dim, n_filters, threshold
        Model settings forwarded to `load_model`; `threshold` is also used
        when converting the generated images back with
        `inverse_transform_data`.
    test_real, test_noise
        Network inputs: the real images and the generator noise.
    real_reco_vars : dict
        Physics variables of the real images, keyed like the output of
        `compute_physics_variables`.
    test_images : array
        Real images in numpy form, used for the pixel-intensity metric.
    scale
        Scale forwarded to `inverse_transform_data`.

    Returns
    -------
    dict
        'epoch', 'gen_images', 'dis_real', 'dis_fake', every physics variable
        computed on the generated images, one 'ks_<variable>' entry per
        physics variable, and 'ks_pixel'.
    """
    print(str(datetime.now()), 'Epoch', epoch)

    # Dictionary of results for this epoch
    results = {}
    results['epoch'] = epoch

    # Load the model from the directory the caller asked for (previously the
    # `train_dir` argument was ignored in favor of the global `results_dir`)
    generator, discriminator = load_model(train_dir, checkpoint_id=epoch,
                                          noise_dim=noise_dim, n_filters=n_filters,
                                          threshold=threshold)
    generator.eval()
    discriminator.eval()

    # Generated images
    test_fake = generator(test_noise)
    gen_images = inverse_transform_data(test_fake.data.numpy().squeeze(1), scale, threshold)
    results['gen_images'] = gen_images

    # Apply discriminator to real and fake samples
    dis_real = discriminator(test_real)
    dis_fake = discriminator(test_fake)

    # Save the average discriminator scores.
    # `.data[0]` matches the Variable-based torch API used in this notebook.
    results['dis_real'] = dis_real.mean().data[0]
    results['dis_fake'] = dis_fake.mean().data[0]

    # Compute physics variables on the generated images
    fake_reco_vars = compute_physics_variables(gen_images)
    results.update(fake_reco_vars)

    # Compute KS metrics, one per physics variable
    for key in fake_reco_vars.keys():
        results['ks_' + key] = ks_metric(real_reco_vars[key], fake_reco_vars[key])
    # Pixel intensity KS metric
    results['ks_pixel'] = ks_metric(test_images.flatten(), gen_images.flatten())

    return results

In [16]:
%%time

# Evaluate every checkpointed epoch; this yields one result dict per epoch
results = [
    evaluate_model(epoch=i, train_dir=results_dir,
                   noise_dim=noise_dim, n_filters=n_filters,
                   test_real=test_real, test_noise=test_noise,
                   real_reco_vars=real_reco_vars, test_images=test_images,
                   scale=scale, threshold=threshold)
    for i in summaries['epoch']
]

# Turn the list of per-epoch dicts into one dict of per-epoch arrays
final_results = {
    key: np.array([epoch_results[key] for epoch_results in results])
    for key in results[0].keys()
}
results = final_results

2018-06-28 15:25:07.391721 Epoch 0




2018-06-28 15:26:01.836177 Epoch 1
2018-06-28 15:26:56.674892 Epoch 2
2018-06-28 15:27:48.925437 Epoch 3
2018-06-28 15:28:42.791522 Epoch 4
2018-06-28 15:29:36.151875 Epoch 5
2018-06-28 15:30:30.258110 Epoch 6
2018-06-28 15:31:24.117108 Epoch 7
2018-06-28 15:32:17.983890 Epoch 8
2018-06-28 15:33:13.395573 Epoch 9
2018-06-28 15:34:09.545081 Epoch 10
2018-06-28 15:35:01.938633 Epoch 11
2018-06-28 15:35:55.664373 Epoch 12
2018-06-28 15:36:49.321474 Epoch 13
2018-06-28 15:37:44.035423 Epoch 14
2018-06-28 15:38:38.400261 Epoch 15
2018-06-28 15:39:32.575025 Epoch 16
2018-06-28 15:40:24.845446 Epoch 17
2018-06-28 15:41:20.194892 Epoch 18
2018-06-28 15:42:12.903038 Epoch 19
2018-06-28 15:43:05.469155 Epoch 20
2018-06-28 15:43:59.574181 Epoch 21
2018-06-28 15:44:53.516296 Epoch 22
2018-06-28 15:45:48.504767 Epoch 23
2018-06-28 15:46:43.276889 Epoch 24
2018-06-28 15:47:35.938687 Epoch 25
2018-06-28 15:48:31.985449 Epoch 26
2018-06-28 15:49:26.385400 Epoch 27
2018-06-28 15:50:19.371623 Epoch 28
2

In [17]:
# Combined metric: per-epoch sum of the -log(KS p-value) of selected variables
combined_keys = ['ks_nJet', 'ks_sumMass', 'ks_jetPt']
results['ks_sum'] = sum(results[k] for k in combined_keys)

In [19]:
# KS metrics per epoch: individual variables (left), combined metric (right)
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(9,4))

# Result key -> legend label for the individual metrics
keys_labels = {
    'ks_pixel': 'pixel',
    'ks_nJet': 'nJet',
    'ks_sumMass': 'sumMass',
    'ks_jetPt': 'jetPt',
    #'ks_jetEta': 'jetEta',
    #'ks_jetPhi': 'jetPhi',
}
for key, label in keys_labels.items():
    metric_values = results[key]
    ax0.plot(metric_values, label=label)
ax0.set(xlabel='Epoch', ylabel='-log(KS p-value)', title='Individual metrics')
ax0.legend(loc=0)

# The summed metric gets its own panel
ax1.plot(results['ks_sum'])
ax1.set(xlabel='Epoch', ylabel='-log(KS p-value)', title='Combined metric')

plt.tight_layout()

<IPython.core.display.Javascript object>

## Plot kinematics for select epochs

In [20]:
def plot_physics(real_pixels, fake_pixels,
                 real_nJet, fake_nJet,
                 real_sumMass, fake_sumMass,
                 real_pt, fake_pt,
                 real_eta, fake_eta,
                 real_phi, fake_phi,
                 figsize=(9, 6), linewidth=1):
    """Overlay real and fake distributions in a 2x3 grid of histograms.

    Panels, in order: pixel energy (log-scaled counts), number of jets,
    sum of jet mass, jet pT, jet eta and jet phi. Pixel energies, jet mass
    sums and jet pT are multiplied by 1e-3 before plotting and labeled in
    GeV (so the inputs are presumably in MeV -- TODO confirm).

    Parameters
    ----------
    real_*, fake_* : arrays
        Values of each quantity for the real and the generated sample.
        The pixel arrays are flattened before histogramming.
    figsize : tuple
        Size of the created figure.
    linewidth : float
        Line width of the step histograms.
    """
    # Honor the caller's figsize (it used to be hard-coded to (9, 6))
    fig, axs = plt.subplots(2, 3, figsize=figsize)

    # One entry per panel: real values, fake values, x label, binning options
    panels = [
        (real_pixels.flatten()*1e-3, fake_pixels.flatten()*1e-3,
         'Pixel energy [GeV]', dict(bins=100, range=(0, 1e3), log=True)),
        (real_nJet, fake_nJet,
         'Number of jets', dict(bins=10, range=(0, 10), log=False)),
        (real_sumMass*1e-3, fake_sumMass*1e-3,
         'Sum of jet mass [GeV]', dict(bins=20, range=(0, 2e3), log=False)),
        (real_pt*1e-3, fake_pt*1e-3,
         'Jet $p_T$ [GeV]', dict(bins=20, range=(0, 2e3), log=False)),
        (real_eta, fake_eta,
         r'Jet $\eta$', dict(bins=20, range=(-2, 2))),
        (real_phi, fake_phi,
         r'Jet $\phi$', dict(bins=20, range=(-np.pi, np.pi))),
    ]

    for ax, (real, fake, xlabel, binning) in zip(axs.flatten(), panels):
        # Real and fake share the exact same binning so they are comparable
        hist_args = dict(binning, histtype='step', linewidth=linewidth)
        ax.hist(real, label='real', **hist_args)
        ax.hist(fake, label='fake', **hist_args)
        ax.set_xlabel(xlabel)
        ax.legend(loc=0)

    plt.tight_layout()

In [33]:
# For each metric, print the epoch index with the smallest value and that value
metric_keys = ['ks_nJet', 'ks_pixel', 'ks_sumMass', 'ks_jetPt',
               'ks_jetEta', 'ks_jetPhi', 'ks_sum']
for key in metric_keys:
    metric_values = results[key]
    ibest = metric_values.argmin()
    print('%10s  %2i  %.3f' % (key, ibest, metric_values[ibest]))

   ks_nJet   9  3.142
  ks_pixel  53  216.276
ks_sumMass  63  1.298
  ks_jetPt  47  1.428
 ks_jetEta   3  8.790
 ks_jetPhi  15  7.871
    ks_sum  47  20.661


In [33]:
# Hand-picked epochs, meant to include the 'best' ones for the variables above.
# NOTE(review): the list is maintained manually; re-check it against the
# table of best epochs whenever the metrics are recomputed.
select_epochs = [0, 14, 15, 41, 63, 47]

# Physics variables in the positional order expected by plot_physics
reco_keys = ['nJet', 'sumMass', 'jetPt', 'jetEta', 'jetPhi']

for i in select_epochs:
    print('Epoch', i)
    # Arguments come in (real, fake) pairs: pixels first, then each variable
    plot_args = [test_images, results['gen_images'][i]]
    for reco_key in reco_keys:
        plot_args.append(real_reco_vars[reco_key])
        plot_args.append(results[reco_key][i])
    plot_physics(*plot_args)

Epoch 0


<IPython.core.display.Javascript object>

Epoch 14


<IPython.core.display.Javascript object>

Epoch 15


<IPython.core.display.Javascript object>

Epoch 41


<IPython.core.display.Javascript object>

Epoch 63


<IPython.core.display.Javascript object>

Epoch 47


<IPython.core.display.Javascript object>

## View sample images from select epochs

In [34]:
# First six real images, one per panel of a 2x3 grid
n_rows, n_cols = 2, 3
fig, axs = plt.subplots(n_rows, n_cols, figsize=(9.5, 5))
for i, ax in enumerate(axs.flat):
    scaled_image = test_images[i]*1e-3
    draw_image(scaled_image, ax=ax)
plt.tight_layout()

<IPython.core.display.Javascript object>

In [35]:
# First six generated images of each selected epoch, one 2x3 grid per epoch
for epoch in select_epochs:
    print('Epoch', epoch)
    gen_images = results['gen_images'][epoch]

    fig, axs = plt.subplots(2, 3, figsize=(9.5, 5))
    for i, ax in enumerate(axs.flat):
        scaled_image = gen_images[i]*1e-3
        draw_image(scaled_image, ax=ax)
    plt.tight_layout()

Epoch 0


<IPython.core.display.Javascript object>

Epoch 14


<IPython.core.display.Javascript object>

Epoch 15


<IPython.core.display.Javascript object>

Epoch 41


<IPython.core.display.Javascript object>

Epoch 63


<IPython.core.display.Javascript object>

Epoch 47


<IPython.core.display.Javascript object>

## Average images

In [36]:
# Plot the average real image

In [39]:
# Sanity check: summing over the sample axis leaves a single 2D image
np.sum(test_images, axis=0).shape

(64, 64)

In [40]:
# Per-pixel sum over all real images (proportional to the average image)
summed_real = test_images.sum(axis=0)
draw_image(summed_real*1e-3)

<IPython.core.display.Javascript object>

In [42]:
# Per-pixel sum over all generated images, one figure per selected epoch
for epoch in select_epochs:
    print('Epoch', epoch)
    gen_images = results['gen_images'][epoch]
    summed_fake = gen_images.sum(axis=0)
    draw_image(summed_fake*1e-3)

Epoch 0


<IPython.core.display.Javascript object>

Epoch 14


<IPython.core.display.Javascript object>

Epoch 15


<IPython.core.display.Javascript object>

Epoch 41


<IPython.core.display.Javascript object>

Epoch 63




<IPython.core.display.Javascript object>

Epoch 47


<IPython.core.display.Javascript object>

## Discussion

The KS metric seems to be a reasonable way to pick out the best distributions. And the summed metric can probably be used to pick out the best model epoch, but one has to choose which distributions to include in it.

The GAN clearly struggles to capture certain things, especially the geometric distribution of the jets.

Also, one can observe funny artifacts in the generated images such as common pixel patterns.