In [56]:
import glob
import os
import json
import numpy as np
import torch
import pandas as pd

import utils.plotting as plot_utils

from operator import itemgetter
from copy import deepcopy
import matplotlib.pyplot as plt

In [2]:
# Set directories for loading data and saving plots
from data.data_loaders import load_curtains_pd
from data.physics_datasets import preprocess_method
data_directory = '/srv/beegfs/scratch/groups/dpnc/atlas/CURTAINS/'
# save_directory = '/srv/beegfs/scratch/groups/dpnc/atlas/CURTAINS/images'
save_directory = '/home/users/k/kleins/MLproject/CURTAINS/images/test_images_paper/'
os.makedirs(save_directory, exist_ok=True)


def get_files(directory_wildcard):
    """
    Returns a list of directories under which data can be loaded and a list of dictionaries for each directory in each
    file.
    """
    directories = glob.glob(os.path.join(data_directory, f'{directory_wildcard}*'))
    info = [glob.glob(os.path.join(dd, '*.json'))[0] for dd in directories]
    loaded_info = []
    for exp_info in info:
        with open(exp_info, "r") as file_name:
            json_dict = json.load(file_name)
        loaded_info += [json.loads(json_dict)]
    return directories, loaded_info

In [3]:
CATHODE_directories, CATHODE_info = get_files('fixed_widths_cathode_features_CATHODE_features_')
CURTAINS_directories, CURTAINS_info = get_files('fixed_widths_cathode_features_CATHODE_features_')

In [27]:
%%capture
# Load the data and dope appropriately, load feature type 12 as that encompasses all the others
feature_type = 16
sm = load_curtains_pd(feature_type=feature_type)
sm = sm.sample(frac=1).dropna() 
ad = load_curtains_pd(sm='WZ_allhad_pT', feature_type=feature_type)
ad = ad.sample(frac=1).dropna()

In [None]:
# This dictionary can be used to access the corresponding figures
feature_keys = {3: ['mj1', 'mj2-mj1', '$\\tau_{21}^{j_1}$', '$\\tau_{21}^{j_2}$', '$dR_{jj}$', 'mjj'],
                10: ['mj1', 'mj2-mj1', '$\\tau_{21}^{j_1}$', '$\\tau_{21}^{j_2}$', '$dR_{jj}$', '$p_t^{j_1}$', 
                     '$p_t^{j_2}$', 'mjj'],
                11:['mj1', 'mj2-mj1', '$\\tau_{21}^{j_1}$', '$\\tau_{21}^{j_2}$', '$dR_{jj}$', 
                    '$p_t^{j_1}$', '$p_t^{j_2}$', 'delEta', 'mjj'], 
                12: ['mj1', 'mj2', 'mj2-mj1', '$\\tau_{21}^{j_1}$', '$\\tau_{32}^{j_1}$', '$\\tau_{21}^{j_2}$', 
                     '$\\tau_{32}^{j_2}$', '$p_t^{j_1}$', '$p_t^{j_2}$', '$dR_{jj}$', 'delPhi', 'delEta', 'mjj']}

# Assign the current headers to the correct latex version
proper_keys = {'mjj': r'$m_{JJ}$',
               'mj1': r'$m_{J_1}$',
               'mj2': r'$m_{J_2}$',
               'mj2-mj1': r'$\Delta m_J$', 
               '$\\tau_{21}^{j_1}$': r'$\tau_{21}^{J_1}$', 
               '$\\tau_{21}^{j_2}$': r'$\tau_{21}^{J_2}$', 
               '$dR_{jj}$': r'$\Delta R_{JJ}$',
               '$\\tau_{32}^{j_1}$': r'$\tau_{32}^{J_1}$',
               '$\\tau_{32}^{j_2}$': r'$\tau_{32}^{J_2}$',
               '$p_t^{j_1}$': r'$p_\mathrm{T}^{J_1}$',
               '$p_t^{j_2}$': r'$p_\mathrm{T}^{J_2}$',
               'delPhi': r'$\Delta \phi $',
               'delEta': r'$\Delta \eta$',
              }

In [60]:
from utils.plotting import get_bins, add_hist, add_error_hist, add_contour, add_off_diagonal, get_weights
from utils.torch_utils import shuffle_tensor, tensor2numpy


def getFeaturePlot(original, sampled, nm, savedir, region, feature_names, input_sample=None, nbins=20, contour=True,
                   n_sample_for_plot=-1, summary_writer=None, x_bounds=None, show=False, include_mass=True):
    if x_bounds is None:
        x_bounds = [-1.2, 1.2]
    if n_sample_for_plot > 0:
        original = shuffle_tensor(original)
        sampled = shuffle_tensor(sampled)

    nfeatures = len(feature_names)
    if not include_mass:
        nfeatures -= 1
    fig, axes = plt.subplots(nfeatures, nfeatures, figsize=(2 * nfeatures + 2, 2 * nfeatures + 1),
                             gridspec_kw={'wspace': 0.03, 'hspace': 0.03})
    for i in range(nfeatures):
        if (i != 0) or (not contour):
            axes[i, 0].set_ylabel(feature_names[i])
        else:
            axes[0, 0].set_ylabel('Normalised Entries', horizontalalignment='right', y=1.0)
        for j in range(nfeatures):
            if not contour:
                axes[0, j].set_title(feature_names[j])
            else:
                axes[-1, j].set_xlabel(feature_names[j])
                if i != nfeatures - 1:
                    axes[i, j].tick_params(axis='x', which='both', direction='in', labelbottom=False)

                axes[i, j].set_yticks([-1, 0, 1])
                if i == j == 0:
                    axes[i, j].tick_params(axis='y', colors='w')
                elif j > 0:
                    axes[i, j].tick_params(axis='y', which='both', direction='in', labelleft=False)

            if i == j:
                og = original[:, i]
                bin = get_bins(og[(og > x_bounds[0]) & (og < x_bounds[1])], nbins=nbins)
                add_hist(axes[i, j], tensor2numpy(original[:, i]), bin, 'red', 'Original')
                add_hist(axes[i, j], tensor2numpy(sampled[:, i]), bin, 'blue', 'Transformed')
                add_error_hist(axes[i, j], tensor2numpy(original[:, i]), bins=bin, color='red')
                add_error_hist(axes[i, j], tensor2numpy(sampled[:, i]), bins=bin, color='blue')
                if input_sample is not None:
                    data = tensor2numpy(input_sample[:, i])
                    axes[i, j].hist(data, density=False, bins=bin, histtype='step',
                                    color='black', linestyle='dashed', label='Input Sample', weights=get_weights(data))
                axes[i, j].set_xlim(x_bounds)

            if contour:
                if i > j:
                    add_contour(axes, i, j, original[:n_sample_for_plot], sampled[:n_sample_for_plot],
                                x_bounds=x_bounds)
                elif i < j:
                    axes[i, j].set_visible(False)
            else:
                if i < j:
                    add_off_diagonal(axes, i, j, original, 'Reds')

                if i > j:
                    add_off_diagonal(axes, i, j, sampled, 'Blues')

    fig.suptitle(region)
    handles, labels = axes[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper left', bbox_to_anchor=(0.32, 0.89), frameon=False)
    if summary_writer is not None:
        summary_writer.add_figure(tag=f'featurespread_{region}', figure=fig)
    fig.savefig(savedir + '/featurespread_{}_{}_{}.png'.format(region, nm, 'transformed_data'), bbox_inches="tight")
    if show:
        return fig
    else:
        fig.clf()

In [61]:
# Bin the data
def mx_data(data, bins, features):
    context_df = data['mjj']
    mx = (context_df >= bins[0]) & (context_df < bins[1])
    if 'mjj' != features[-1]:
        features = deepcopy(features) + ['mjj']
    return data[features].loc[mx], data.loc[~mx]

# Plot a feature spread
n_sample_for_plot = 1000

to_plot = {'Cathode': [CATHODE_directories, CATHODE_info],
          'Curtains': [CURTAINS_directories, CURTAINS_info]}

for nm in to_plot.keys():

    directory, info  = to_plot[nm]
    for ind in range(len(info)):
    
        info_dict = info[ind]

        # First grab the true data for the signal region
        bins = [int(b) for b in info_dict['bins'].split(',')]
        # Here we only want the SR
        bins = [bins[2], bins[3]]
        doping = info_dict['doping']
        feature_nms = feature_keys[info_dict['feature_type']]

        target_qcd, _ = mx_data(sm, bins, feature_nms)
        ad_to_use = ad[:doping] 
        target_ad, _ = mx_data(ad_to_use, bins, feature_nms)
        target_sample = torch.tensor(pd.concat((target_qcd, target_ad)).sample(frac=1).to_numpy(), dtype=torch.float32)

        samples = np.load(os.path.join(directory[ind], 'SB2_to_SR_samples.npy'))
        samples = torch.tensor(samples, dtype=torch.float32)

        # Apply the data preprocessing to have it plotted on an easy to view range
        ts, facts, _ = preprocess_method(target_sample)
        ss, _, _ = preprocess_method(samples, facts)

        # Set the tags for saving
        plt_nm = f'SB2_to_SR_{nm}'
        tag = f"{plt_nm}_{doping}_{info_dict['feature_type']}"
        x_bounds= [-1.2, 1.2]

        fig = getFeaturePlot(ts, ss, nm, save_directory, tag, itemgetter(*feature_nms)(proper_keys), 
                                        input_sample=None, n_sample_for_plot=n_sample_for_plot, 
                                        x_bounds=x_bounds, show=False)

<Figure size 1296x1224 with 0 Axes>

<Figure size 2016x1944 with 0 Axes>

<Figure size 864x792 with 0 Axes>

<Figure size 1440x1368 with 0 Axes>

<Figure size 1296x1224 with 0 Axes>

<Figure size 1440x1368 with 0 Axes>

<Figure size 2016x1944 with 0 Axes>

<Figure size 864x792 with 0 Axes>

<Figure size 1296x1224 with 0 Axes>

<Figure size 2016x1944 with 0 Axes>

<Figure size 864x792 with 0 Axes>

<Figure size 1440x1368 with 0 Axes>

<Figure size 1296x1224 with 0 Axes>

<Figure size 1440x1368 with 0 Axes>

<Figure size 2016x1944 with 0 Axes>

<Figure size 864x792 with 0 Axes>