In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import matplotlib.ticker as ticker
from trainlib.ConfigFileHandler import ConfigFileHandler
from trainlib.ConfigFileUtils import ConfigFileUtils

In [2]:
# Plot which input variables were finally selected for each model, irrespective of their weights.

In [3]:
def convert_model_label(raw):
    """Turn a discriminant/model identifier into a compact plot label.

    Removes every ``'D_'`` substring, then every ``'_ML'`` substring, and
    finally turns the remaining underscores into hyphens, e.g.
    ``'D_VBF_ggH_2j_ML'`` -> ``'VBF-ggH-2j'``.
    """
    # Applied strictly in this order; later substitutions see earlier results.
    substitutions = (('D_', ''), ('_ML', ''), ('_', '-'))
    label = raw
    for old, new in substitutions:
        label = label.replace(old, new)
    return label

In [16]:
def make_input_plot(input_file):
    """Plot which input variables each model (configuration section) uses.

    Loads ``input_file`` with ``ConfigFileHandler`` and draws a matrix with one
    column per section ("model") and one row per input variable taken from the
    section's ``nonperiodic_columns`` and ``periodic_columns`` entries. A cell
    is 1 if the model uses the variable and 0 otherwise; weights are ignored.

    Parameters
    ----------
    input_file : str
        Path to the configuration file to load.

    Returns
    -------
    matplotlib.figure.Figure
        The figure. It is created with ``plt.figure`` and so is also pyplot's
        current figure.

    Raises
    ------
    ValueError
        If the configuration file contains no sections.
    """
    confhandler = ConfigFileHandler()
    confhandler.load_configuration(input_file)
    models = confhandler.get_sections()

    # One record per model, turned into a DataFrame once at the end
    # (concatenating inside the loop would be quadratic).
    records = []
    for model in models:
        cur_sect = confhandler.get_section(model)

        # List comprehensions instead of filter(): on Python 3 filter() returns
        # an iterator, and two iterators cannot be joined with "+".
        # "if var" drops empty entries, exactly like filter(None, ...).
        used_nonperiodic_vars = [var for var in ConfigFileUtils.parse_list(cur_sect["nonperiodic_columns"], lambda x: x) if var]
        used_periodic_vars = [var for var in ConfigFileUtils.parse_list(cur_sect["periodic_columns"], lambda x: x) if var]

        record = {col: 1.0 for col in used_nonperiodic_vars + used_periodic_vars}
        record["model"] = model
        records.append(record)

    if not records:
        raise ValueError("no models (sections) found in configuration file '{}'".format(input_file))

    # A variable missing from a model's record becomes NaN -> 0 (not used).
    df = pd.DataFrame(records).fillna(0.0)

    # "!=" compares by value; "is not" compared object identity and only
    # worked because of string interning.
    datacols = [col for col in df.columns if col != "model"]
    # .values replaces DataFrame.as_matrix(), which was removed in pandas 1.0.
    plot_data = df[datacols].values

    y_labels = datacols
    x_labels = [convert_model_label(label) for label in df["model"]]

    fig = plt.figure(figsize = (12, 10))
    ax = fig.add_subplot(111)

    # Rows = variables, columns = models.
    ax.matshow(plot_data.transpose(), cmap = 'Blues', vmin = 0, vmax = 1)

    # Pin exactly one tick per matrix cell before labelling. The previous
    # version prepended an empty label and relied on MultipleLocator(1)
    # emitting an off-screen tick at -1 to shift the labels into place.
    ax.set_xticks(np.arange(len(x_labels)))
    ax.set_xticklabels(x_labels, rotation = 'vertical')
    ax.set_yticks(np.arange(len(y_labels)))
    ax.set_yticklabels(y_labels)
    ax.xaxis.set_label_position("top")

    fig.tight_layout()

    return fig

In [80]:
def make_fscore_plot_combined(input_file_inclusive, input_file_exclusive):
    """Plot an input-variable score table merged from an inclusive and an exclusive CSV.

    Rows are taken from the inclusive table for every discriminant whose name
    does not contain ``"D_VBF_ggH"``. Rows for ``D_VBF_ggH_2j_ML``,
    ``D_VBF_ggH_1j_ML`` and ``D_VBF_ggH_0j_ML`` are taken from the exclusive
    table. Every column except ``"discriminant"`` is drawn as a matrix row,
    and every selected discriminant becomes a matrix column.

    Parameters
    ----------
    input_file_inclusive : str
        CSV path; first column is the index, must contain a ``discriminant`` column.
    input_file_exclusive : str
        CSV path with the same layout as ``input_file_inclusive``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure. It is created with ``plt.figure`` and so is also pyplot's
        current figure.
    """
    # read_csv(..., index_col=0) replaces DataFrame.from_csv (removed in
    # pandas 1.0), which used the first column as the index by default.
    df_exclusive = pd.read_csv(input_file_exclusive, index_col = 0)
    df_inclusive = pd.read_csv(input_file_inclusive, index_col = 0)

    # The inclusive VBF-vs-ggH discriminants are replaced by the jet-binned
    # versions from the exclusive table.
    discriminants_inclusive = [disc for disc in df_inclusive["discriminant"] if "D_VBF_ggH" not in disc]
    discriminants_exclusive = ["D_VBF_ggH_2j_ML", "D_VBF_ggH_1j_ML", "D_VBF_ggH_0j_ML"]

    df_exclusive = df_exclusive.loc[df_exclusive["discriminant"].isin(discriminants_exclusive)]
    df_inclusive = df_inclusive.loc[df_inclusive["discriminant"].isin(discriminants_inclusive)]

    # Columns present in only one of the tables become NaN for the other's rows.
    df = pd.concat([df_exclusive, df_inclusive])

    # Plot the table contents to get a global picture of the relevant input variables.
    variable_labels = [col for col in df.columns.tolist() if col != "discriminant"]
    # .values replaces DataFrame.as_matrix(), which was removed in pandas 1.0.
    variable_data = df[variable_labels].values.astype(float).transpose()

    discriminant_labels = [convert_model_label(disc) for disc in df["discriminant"]]

    fig = plt.figure(figsize = (12, 10))
    ax = fig.add_subplot(111)
    # nanmin/nanmax: np.min/np.max return NaN as soon as any cell is missing,
    # which would make the colour range itself NaN.
    ax.matshow(variable_data, interpolation = 'nearest', cmap = 'Blues',
               vmin = np.nanmin(variable_data), vmax = np.nanmax(variable_data))

    # One tick per matrix cell, labelled directly, instead of prepending an
    # empty label and relying on MultipleLocator(1) to emit a tick at -1.
    ax.set_xticks(np.arange(len(discriminant_labels)))
    ax.set_xticklabels(discriminant_labels, rotation = 'vertical')
    ax.set_yticks(np.arange(len(variable_labels)))
    ax.set_yticklabels(variable_labels)

    fig.tight_layout()

    return fig

In [20]:
%%capture
fig = make_input_plot("/data_CMS/cms/wind/InputConfigurations/inclusive_99.conf")

In [21]:
plt.savefig("/data_CMS/cms/wind/InputConfigurations/inclusive_variables_selected.pdf")

In [22]:
%%capture
fig = make_input_plot("/data_CMS/cms/wind/InputConfigurations/exclusive_99.conf")

In [23]:
plt.savefig("/data_CMS/cms/wind/InputConfigurations/exclusive_variables_selected.pdf")

In [24]:
%%capture
fig = make_input_plot("/data_CMS/cms/wind/InputConfigurations/inclusive_99_fullmassrange.conf")

In [25]:
plt.savefig("/data_CMS/cms/wind/InputConfigurations/inclusive_variables_selected_fullmassrange.pdf")

In [26]:
%%capture
fig = make_input_plot("/data_CMS/cms/wind/InputConfigurations/exclusive_99_fullmassrange.conf")

In [27]:
plt.savefig("/data_CMS/cms/wind/InputConfigurations/exclusive_variables_selected_fullmassrange.pdf")

In [28]:
%%capture
fig = make_input_plot("/data_CMS/cms/wind/InputConfigurations/combined_99.conf")

In [29]:
plt.savefig("/data_CMS/cms/wind/InputConfigurations/combined_variables_selected.pdf")

In [30]:
%%capture
fig = make_input_plot("/data_CMS/cms/wind/InputConfigurations/combined_99_fullmassrange.conf")

In [31]:
plt.savefig("/data_CMS/cms/wind/InputConfigurations/combined_variables_selected_fullmassrange.pdf")

In [17]:
%%capture
fig = make_input_plot("/data_CMS/cms/wind/InputConfigurations/combined_99_fullmassrange_ZZMask.conf")

In [18]:
plt.savefig("/data_CMS/cms/wind/InputConfigurations/combined_variables_selected_fullmassrange_ZZMask.pdf")

In [81]:
%%capture
fig = make_fscore_plot_combined("/data_CMS/cms/wind/InputConfigurations/input_parameters_table_inclusive_ZZMask.csv",
                               "/data_CMS/cms/wind/InputConfigurations/input_parameters_table_ZZMask.csv")

In [82]:
plt.savefig("/data_CMS/cms/wind/InputConfigurations/combined_variables_fullmassrange_ZZMask.pdf")