# Analysis of ATLAS-GAN results

Let's take a look at the results of the DCGAN trained to generate ATLAS RPV SUSY events.

In [1]:
# Convenient fudge for python path
import sys
sys.path.append('..')

In [2]:
# System
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import
import os
from datetime import datetime

# Externals
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import torch
from torch.autograd import Variable

# Locals
from atlasgan import gan
from atlasgan.dataset import RPVImages, inverse_transform_data, generate_noise
from atlasgan.reco import compute_physics_variables

%matplotlib notebook

In [3]:
def draw_image(x, vmin=None, vmax=None, xlabel=r'$\eta$', ylabel=r'$\phi$',
               draw_cbar=True, ax=None, figsize=(5,4), **kwargs):
    """Draw a 2D array as a log-normalised image on eta/phi axes.

    Parameters
    ----------
    x : 2D array
        Image to draw. It is transposed before plotting, so its first
        index runs along the horizontal axis.
    vmin, vmax : float, optional
        Limits handed to `LogNorm`.
    xlabel, ylabel : str or None
        Axis labels; pass None to leave an axis unlabelled.
    draw_cbar : bool
        Attach a colorbar labelled 'Energy' when True.
    ax : matplotlib Axes, optional
        Axes to draw on. A new figure of size `figsize` is created if None.
    figsize : tuple
        Only used when a new figure has to be created.
    **kwargs
        Forwarded to `ax.imshow`.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(x.T, extent=[-2.5, 2.5, -3.15, 3.15],
                      norm=LogNorm(vmin, vmax), aspect='auto',
                      origin='lower', **kwargs)
    if draw_cbar:
        # Use the figure that owns `ax`; the "current" figure may be a
        # different one when the caller passes in its own axes.
        cbar = ax.figure.colorbar(image, ax=ax)
        cbar.set_label('Energy')
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)

In [4]:
def draw_reco(sample):
    """Apply reconstruction to one sample and visualize the clusters.

    The left panel shows the input image; the right panel shows one '+'
    marker per reconstructed jet together with that jet's constituents,
    drawn in the same colour.
    """
    # NOTE(review): build_jets is not imported in any visible cell of this
    # notebook -- it has to be brought into scope before calling this.
    jets = build_jets(sample)

    # Draw the provided jet image on the left
    fig, (image_ax, jet_ax) = plt.subplots(1, 2, figsize=(8,4))
    draw_image(sample, vmin=1e-4, ax=image_ax, draw_cbar=False)

    # Draw the jets on the right
    for j in jets:
        s = jet_ax.scatter(j.eta, j.phi, s=200, marker='+')
        cl = j.constituents_array()
        jet_ax.scatter(cl['eta'], cl['phi'], s=5, marker=',', c=s.get_edgecolor())

    # Configure the axes once, outside the loop, so the panel is framed and
    # labelled even when no jets were found.
    jet_ax.set_xlim((-2.5, 2.5))
    jet_ax.set_ylim((-np.pi, np.pi))
    jet_ax.set_xlabel(r'$\eta$')
    jet_ax.set_ylabel(r'$\phi$')

    plt.tight_layout()

## Load training results and test data

We don't currently have a separate validation and test dataset, so for now we will do the model selection here on a test dataset.

In [5]:
# Directory holding the training outputs (checkpoints/ and summaries.npz).
# The default is a machine-specific path; set ATLASGAN_RESULTS_DIR to point
# the notebook at another location without editing this cell.
results_dir = os.environ.get('ATLASGAN_RESULTS_DIR',
                             '/data0/sfarrell/atlas_gan/AtlasDCGAN_003')

In [6]:
ls $results_dir

[0m[01;34mcheckpoints[0m/  out.log  README  summaries.npz


In [7]:
# Open the per-epoch training summaries and show which arrays they contain
summaries_file = os.path.join(results_dir, 'summaries.npz')
summaries = np.load(summaries_file)
print(summaries.keys())

['dis_output_fake', 'gen_samples', 'dis_output_real', 'epoch', 'gen_loss', 'dis_loss']


In [11]:
input_data = '/data0/sfarrell/atlas_rpv_data/RPV10_1400_850_01.npz'

scale = 4e6 #6072947 #FIXME
threshold = 500
n_samples = 4096
# Size of the generator's noise input. Later cells pass this to
# generate_noise() and load_model(), so it has to be defined here for the
# notebook to run from a fresh kernel.
# NOTE(review): 64 is load_model's default -- confirm it matches the value
# this model was trained with.
noise_dim = 64

In [12]:
# Build the evaluation dataset from the configured input file
dataset = RPVImages(
    input_data,
    n_samples=n_samples,
    scale=scale,
    from_back=True,
)

## Training loss and mean outputs

In [8]:
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(9, 4))

# Left panel: training loss of each network
for name, label in [('dis_loss', 'discriminator'), ('gen_loss', 'generator')]:
    ax0.semilogy(summaries[name], label=label)
ax0.set(xlabel='Epochs', ylabel='Loss')
ax0.legend(loc=0)

# Right panel: average discriminator output on each kind of sample
for name, label in [('dis_output_fake', 'fake samples'),
                    ('dis_output_real', 'real samples')]:
    ax1.semilogy(summaries[name], label=label)
ax1.legend(loc=0)
ax1.set(xlabel='Epochs', ylabel='Average discriminator output')

plt.tight_layout()

<IPython.core.display.Javascript object>

## Compute metrics for every epoch

Loop over epochs, load the checkpoint, generate the samples, compute the metrics.

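But first, compute the quantities for the real samples (the test images and their reconstructed physics variables), since they are the same for every epoch.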

In [9]:
def load_model(results_dir, checkpoint_id, noise_dim=64):
    """Load generator and discriminator from checkpoint.

    Parameters
    ----------
    results_dir : str
        Training output directory; checkpoints are read from its
        'checkpoints' sub-directory.
    checkpoint_id : int
        Number used to build the file name 'model_checkpoint_NNN.pth.tar'.
    noise_dim : int
        Noise dimension the generator is constructed with.

    Returns
    -------
    (generator, discriminator) with the checkpointed weights loaded.
    """
    checkpoint_file = os.path.join(
        results_dir, 'checkpoints', 'model_checkpoint_%03i.pth.tar' % checkpoint_id
    )
    # Keep every tensor on the CPU while deserialising, so a checkpoint
    # written from a GPU can still be opened on a CPU-only machine. The
    # models built below live on the CPU anyway.
    checkpoint = torch.load(checkpoint_file,
                            map_location=lambda storage, loc: storage)
    generator = gan.Generator(noise_dim)
    discriminator = gan.Discriminator()
    generator.load_state_dict(checkpoint['generator'])
    discriminator.load_state_dict(checkpoint['discriminator'])
    return generator, discriminator

def ks_metric(x, y):
    """Return minus the natural log of the two-sample KS test p-value for x and y."""
    ks_result = scipy.stats.ks_2samp(x, y)
    return -np.log(ks_result.pvalue)

In [13]:
# Real images wrapped as a network input
test_real = Variable(dataset.data, volatile=True)

# Noise batch fed to the generator: one draw, reused for every checkpoint
noise = generate_noise(n_samples, noise_dim)
test_noise = Variable(noise, volatile=True)

# The same real images as a numpy array, mapped back through the inverse transform
real_array = dataset.data.numpy().squeeze(1)
test_images = inverse_transform_data(real_array, scale, threshold)

# Reconstructed physics variables of the real images
real_reco_vars = compute_physics_variables(test_images)

In [None]:
%%time

# Initially store all results in an epoch list
results = []

# Loop over epoch checkpoints
for i in summaries['epoch']:
    print(str(datetime.now()), 'Epoch', i)
    
    # Dictionary of results for this epoch
    epoch_results = {}
    epoch_results['epoch'] = i
    
    # Load the model
    generator, discriminator = load_model(results_dir, checkpoint_id=i, noise_dim=noise_dim)
    generator.eval(), discriminator.eval()
    
    # Generated images
    test_fake = generator(test_noise)
    gen_images = inverse_transform_data(test_fake.data.numpy().squeeze(1), scale, threshold)
    epoch_results['gen_images'] = gen_images

    # Apply discriminator to real and fake samples
    dis_real = discriminator(test_real)
    dis_fake = discriminator(test_fake)

    # Save the average discriminator scores
    epoch_results['dis_real'] = dis_real.mean().data[0]
    epoch_results['dis_fake'] = dis_fake.mean().data[0]

    # Compute physics variables on the generated images
    fake_reco_vars = compute_physics_variables(gen_images)
    epoch_results.update(fake_reco_vars)
    
    # Compute KS metrics
    for key in fake_reco_vars.keys():
        epoch_results['ks_' + key] = ks_metric(real_reco_vars[key], fake_reco_vars[key])    
    # Pixel intensity KS metric
    epoch_results['ks_pixel'] = ks_metric(test_images.flatten(), gen_images.flatten())
    
    results.append(epoch_results)

# Pack final results into one dict of arrays
final_results = {}
for key in results[0].keys():
    final_results[key] = np.array([r[key] for r in results])
results = final_results

In [None]:
# Combined metric: add up the -log(KS p-value) arrays of the chosen variables
combined_keys = ['ks_nJet', 'ks_sumMass', 'ks_jetPt']
ks_total = 0
for metric_name in combined_keys:
    ks_total = ks_total + results[metric_name]
results['ks_sum'] = ks_total

In [36]:
# Two panels: individual KS metrics on the left, their combination on the right
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(9,4))

# Metrics drawn individually. The jetPt / jetEta / jetPhi metrics are also
# stored in `results` but are left out of this panel.
keys_labels = {
    'ks_pixel': 'pixel',
    'ks_nJet': 'nJet',
    'ks_sumMass': 'sumMass',
}
for metric_name, legend_text in keys_labels.items():
    ax0.plot(results[metric_name], label=legend_text)
ax0.set(xlabel='Epoch', ylabel='-log(KS p-value)', title='Individual metrics')
ax0.legend(loc=0)

# The summed metric gets its own panel
ax1.plot(results['ks_sum'])
ax1.set(xlabel='Epoch', ylabel='-log(KS p-value)', title='Combined metric')

plt.tight_layout()

<IPython.core.display.Javascript object>

### Which epochs were the best for certain metrics?

In [37]:
# Position of the smallest value of each metric, i.e. its best checkpoint
metric_names = ['ks_nJet', 'ks_pixel', 'ks_sumMass',
                'ks_jetPt', 'ks_jetEta', 'ks_jetPhi', 'ks_sum']
for metric_name in metric_names:
    best_index = results[metric_name].argmin()
    print(metric_name, best_index)

ks_nJet 74
ks_pixel 94
ks_sumMass 17
ks_jetPt 10
ks_jetEta 18
ks_jetPhi 28
ks_sum 24


## Plot kinematics for select epochs

In [23]:
def plot_physics(real_pixels, fake_pixels,
                 real_nJet, fake_nJet,
                 real_sumMass, fake_sumMass,
                 real_pt, fake_pt,
                 real_eta, fake_eta,
                 real_phi, fake_phi,
                 figsize=(9, 6), linewidth=1):
    """Overlay real and fake distributions of six quantities on a 2x3 grid.

    Each panel holds two step histograms, labelled 'real' and 'fake', of
    one quantity: pixel energy, number of jets, sum of jet mass, jet pT,
    jet eta and jet phi.

    Parameters
    ----------
    real_* / fake_* : array-like
        Values computed on real and on generated images. Pixel arrays are
        flattened before histogramming. Pixel energies, summed jet masses
        and jet pT are multiplied by 1e-3 to match the GeV axis labels
        (presumably the inputs are in MeV -- confirm).
    figsize : tuple
        Size of the figure that is created.
    linewidth : float
        Line width of the histogram outlines.
    """
    _, axs = plt.subplots(2, 3, figsize=figsize)

    # One row per panel: real values, fake values, x label, binning options
    panels = [
        (real_pixels.flatten()*1e-3, fake_pixels.flatten()*1e-3,
         'Pixel energy [GeV]', dict(bins=100, range=(0, 1e3), log=True)),
        (real_nJet, fake_nJet,
         'Number of jets', dict(bins=10, range=(0, 10), log=False)),
        (real_sumMass*1e-3, fake_sumMass*1e-3,
         'Sum of jet mass [GeV]', dict(bins=20, range=(0, 2e3), log=False)),
        (real_pt*1e-3, fake_pt*1e-3,
         'Jet $p_T$ [GeV]', dict(bins=20, range=(0, 2e3), log=False)),
        (real_eta, fake_eta,
         r'Jet $\eta$', dict(bins=20, range=(-2, 2))),
        (real_phi, fake_phi,
         r'Jet $\phi$', dict(bins=20, range=(-np.pi, np.pi))),
    ]

    for ax, (real, fake, xlabel, binning) in zip(axs.flatten(), panels):
        hist_args = dict(histtype='step', linewidth=linewidth, **binning)
        ax.hist(real, label='real', **hist_args)
        ax.hist(fake, label='fake', **hist_args)
        ax.set_xlabel(xlabel)
        ax.legend(loc=0)

    plt.tight_layout()

In [49]:
# Epochs picked by hand: epoch 0 plus the per-metric best epochs printed above
select_epochs = [0, 10, 17, 18, 24, 28, 74, 94]
reco_keys = ['nJet', 'sumMass', 'jetPt', 'jetEta', 'jetPhi']
for epoch_idx in select_epochs:
    print('Epoch', epoch_idx)
    # Interleave (real, fake) for each variable in plot_physics' argument order
    reco_pairs = []
    for reco_key in reco_keys:
        reco_pairs.append(real_reco_vars[reco_key])
        reco_pairs.append(results[reco_key][epoch_idx])
    plot_physics(test_images, results['gen_images'][epoch_idx], *reco_pairs)

Epoch 0


<IPython.core.display.Javascript object>

Epoch 10


<IPython.core.display.Javascript object>

Epoch 17


<IPython.core.display.Javascript object>

Epoch 18


<IPython.core.display.Javascript object>

Epoch 24


<IPython.core.display.Javascript object>

Epoch 28


<IPython.core.display.Javascript object>

Epoch 74


<IPython.core.display.Javascript object>

Epoch 94


<IPython.core.display.Javascript object>

## View sample images from select epochs

In [47]:
# First six real images on a 2x3 grid, each scaled by 1e-3 before drawing
fig, axs = plt.subplots(2, 3, figsize=(9.5, 5))
for idx in range(axs.size):
    draw_image(test_images[idx]*1e-3, ax=axs.flat[idx])
plt.tight_layout()

<IPython.core.display.Javascript object>

In [54]:
# First six generated images of each chosen epoch, one 2x3 figure per epoch
select_epochs = [0, 1, 2, 24, 127]
for epoch in select_epochs:
    print('Epoch', epoch)
    gen_images = results['gen_images'][epoch]

    fig, axs = plt.subplots(2, 3, figsize=(9.5, 5))
    for idx in range(axs.size):
        draw_image(gen_images[idx]*1e-3, ax=axs.flat[idx])
    plt.tight_layout()

Epoch 0


<IPython.core.display.Javascript object>

Epoch 1


<IPython.core.display.Javascript object>

Epoch 2


<IPython.core.display.Javascript object>

Epoch 24


<IPython.core.display.Javascript object>

Epoch 127


<IPython.core.display.Javascript object>

## Discussion

The KS metric seems to be a reasonable way to pick out the best distributions. And the summed metric can probably be used to pick out the best model epoch, but one has to choose which distributions to include in it.

The GAN clearly struggles to capture certain things, especially the geometric distribution of the jets.

Also, one can observe funny artifacts in the generated images such as common pixel patterns.