# Load modules

In [None]:
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns

from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

from keras.layers import Input, Dense, Dropout
from keras.models import Model
from keras.optimizers import adadelta

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MaxAbsScaler

from scipy.stats import percentileofscore

from livelossplot import PlotLossesKeras

from helper import ae_errors

In [None]:
from IPython.core.display import HTML
# CSS for output areas selected by `div.output_area`: always show a vertical
# scrollbar and let images keep their natural width (`max-width: unset`)
# instead of being shrunk to fit. Used by `make_cell_scrollable()`.
# NOTE(review): `div.output_area` is the classic Notebook class name; presumably
# this has no effect in JupyterLab -- confirm if needed.
style = """
<style>
div.output_area {
    overflow-y: scroll;
}
div.output_area img {
    max-width: unset;
}
</style>
"""

def make_cell_scrollable():
    """Inject the CSS held in the module-level `style` string into the notebook.

    An ``HTML`` object that is merely constructed inside a function is never
    rendered (only a cell's last expression is auto-displayed), so it has to be
    passed to ``display`` explicitly.

    Returns
    -------
    None
    """
    # Local import keeps the helper usable regardless of where `display` is
    # imported elsewhere in the notebook.
    from IPython.display import HTML, display
    display(HTML(style))
    
from IPython.display import Markdown, display

def printmd(string):
    """Render `string` as Markdown in the cell output (a Markdown-aware `print`)."""
    rendered = Markdown(string)
    display(rendered)

# Input preparation

In [None]:
%%time 
df_orig = pd.read_csv('data/trending_merged_LHC18q_withGraphs.csv')

target_col = 'alias_global_Warning'
#----------

df = df_orig[[c for c in df_orig.columns if 
              ('gr' not in c and 'alias' not in c and 'Unnamed' not in c)
              and c != 'dataType.fString'
              or c == target_col
             ]]
rename = lambda c: c if c != target_col else 'bad'
df.columns = [rename(c) for c in df.columns]

In [None]:
# Columns treated as non-physical (identifiers, time stamps, metadata strings,
# chunk timing, ...); they are excluded from the model inputs.
nonphysical_cols = ['run', 'chunkID', 'time', 
                   'year', 'period.fString', 'pass.fString', 'runType.fString', 
                   'startTimeGRP', 'stopTimeGRP', 'duration', 
                   'iroc_A_side', 'oroc_A_side', 'iroc_C_side', 'oroc_C_side',
                   'chunkStart', 'chunkStop', 'chunkMean', 'chunkMedian', 'chunkRMS', 'chunkDuration']

# (Near-)constant columns carry no information for the autoencoder.
# `numeric_only=True`: `df` still contains non-numeric columns (presumably the
# '*.fString' ones). pandas < 2.0 silently skipped them in `std()`, but
# pandas >= 2.0 raises TypeError. Computing the std once also avoids a second pass.
column_std = df.std(numeric_only=True)
no_variance_cols = column_std.index[column_std < 1e-6].tolist()
cols_exclude = nonphysical_cols + no_variance_cols

# List the columns that remain as candidate model inputs.
for c in df.columns:
    if c not in cols_exclude:
        print(c)

In [None]:
# Model inputs: every retained column except the excluded ones and the label 'bad'.
input_data = df[[c for c in df.columns if c not in cols_exclude]].drop(columns='bad')

# Seed every source of randomness in this cell (noise columns and both data
# splits) so the split and the model inputs are identical on every re-run.
RANDOM_STATE = 42
rng = np.random.RandomState(RANDOM_STATE)

# Pure-noise columns -- presumably a reference level for the per-column error
# and permutation-importance results further down.
for i in range(5):
    input_data[f'random{i}'] = rng.randn(len(input_data))

x = input_data.to_numpy()
y = df['bad'].to_numpy()

# The AE is trained on GOOD chunks only; all BAD chunks go to the test set.
x_test_bad = x[y == 1]
x_train_val_good, x_test_good = train_test_split(x[y == 0], test_size=0.1, random_state=RANDOM_STATE)
x_train, x_val = train_test_split(x_train_val_good, test_size=0.1, random_state=RANDOM_STATE)  # x_val are GOOD samples used to monitor overfitting

# Fit the scaler on the training split only, then apply the same transform to
# every split and to the full data set.
scaler = StandardScaler()
# scaler = MaxAbsScaler()
scaler.fit(x_train)
x_train     = scaler.transform(x_train)
x_val       = scaler.transform(x_val)
x_test_good = scaler.transform(x_test_good)
x_test_bad  = scaler.transform(x_test_bad)
x_all       = scaler.transform(x)

## InteractionRate

In [None]:
jitter_y = 1500
#-------------
# Interaction rate vs. chunk time. A uniform vertical jitter in [0, jitter_y)
# spreads out points that would otherwise overlap.
fig, ax = plt.subplots(figsize=(25,5))
rate_jittered = df['interactionRate'] + np.random.random(len(df)) * jitter_y
ax.plot(df['chunkMean'], rate_jittered, '.', ms=3)
ax.set_xlabel('chunk mean time');
ax.set_ylabel('interactionRate');

In [None]:
# Distribution of the interaction rate over all chunks, normalised to unit area.
plt.hist(df['interactionRate'], bins=50, histtype='step', color='b', density=True);

# Model training

In [None]:
# Symmetric fully-connected autoencoder:
#   input -> coding_layers_sizes... -> bottleneck -> reversed(coding_layers_sizes)... -> input
input_size = x_train.shape[1]
coding_layers_sizes = [64,32]
bottleneck_size = 16

# this is our input placeholder
ae_input = Input(shape=(input_size,))

# Encoder: the coding layers followed by the bottleneck, all ReLU.
# (Also works with an empty `coding_layers_sizes`, i.e. a bottleneck-only AE.)
encoded = ae_input
for lsize in [*coding_layers_sizes, bottleneck_size]:
    encoded = Dense(lsize, activation='relu')(encoded)

# Decoder: mirror of the encoder's coding layers.
decoded = encoded
for lsize in reversed(coding_layers_sizes):
    decoded = Dense(lsize, activation='relu')(decoded)
# Linear output layer: the targets are StandardScaler-transformed features,
# which can be negative.
decoded = Dense(input_size, activation='linear')(decoded)

# this model maps an input to its reconstruction
autoencoder = Model(ae_input, decoded)

In [None]:
# Wrapping the existing graph creates no new layers: the Dense layers (and
# their current weights) from the model-definition cell are reused. Re-running
# this cell therefore continues training; re-run the model-definition cell
# first for a fresh initialisation.
autoencoder = Model(ae_input, decoded)
autoencoder.compile(optimizer=adadelta(lr=0.2), loss='mean_squared_error')

# Learn the identity x -> x; x_val (GOOD samples only) monitors overfitting.
fit = autoencoder.fit(x_train, x_train,
                      epochs=20,
                      batch_size=32,
                      verbose=2,
                      shuffle=True,
                      validation_data=(x_val, x_val),
                      callbacks=[PlotLossesKeras()])

In [None]:
# Training vs. validation loss per epoch, from the Keras History object.
loss = fit.history['loss']
val_loss = fit.history['val_loss']
epochs = fit.epoch

# Colour is given only via `color=`. A colour letter in the fmt string as well
# (e.g. 'rx--' together with color='green') is contradictory and triggers a
# matplotlib warning.
plt.plot(epochs, loss, 'x--', label='train loss', color='blue')
plt.plot(epochs, val_loss, 'x--', label='val loss', color='green')
plt.xlabel('epoch')
plt.ylabel('loss (MSE)')
plt.legend()
plt.show()

In [None]:
# Layer-by-layer summary of the model that was just trained.
trained_model = fit.model
trained_model.summary()

# Compute predictions and errors

In [None]:
from sklearn.metrics import mean_squared_error

# Reconstructions of every split (the model is trained to map x -> x).
x_pred_train, x_pred_val, x_pred_test_good, x_pred_test_bad, x_pred_all = (
    autoencoder.predict(split)
    for split in (x_train, x_val, x_test_good, x_test_bad, x_all)
)

# One overall reconstruction MSE (all rows, all columns) per split.
mse_train, mse_val, mse_test_good, mse_test_bad, mse_all = (
    mean_squared_error(truth, reco)
    for truth, reco in (
        (x_train, x_pred_train),
        (x_val, x_pred_val),
        (x_test_good, x_pred_test_good),
        (x_test_bad, x_pred_test_bad),
        (x_all, x_pred_all),
    )
)

In [None]:
# Overall reconstruction MSE: all data first, then per split.
report = f'MSE:\n\t all = {mse_all:.3f}\n\t {"-"*10}\n\t train = {mse_train:.3f}\n\t val = {mse_val:.3f}\n\t test_good = {mse_test_good:.3f}\n\t test_bad = {mse_test_bad:.3f}'
print(report)

In [None]:
# Per-sample (row-wise) reconstruction MSE for every split.
mse_distr_train     = np.mean(np.square(x_train - x_pred_train), axis=1)
mse_distr_val       = np.mean(np.square(x_val - x_pred_val), axis=1)
mse_distr_test_good = np.mean(np.square(x_test_good - x_pred_test_good), axis=1)
mse_distr_test_bad  = np.mean(np.square(x_test_bad - x_pred_test_bad), axis=1)
mse_distr           = np.mean(np.square(x_all - x_pred_all), axis=1)

# Histograms of log10(MSE) on a common binning spanning all chunks (min..max).
log_mse_distr = np.log10(mse_distr)
bins = np.linspace(log_mse_distr.min(), log_mse_distr.max(), 20)
hist_specs = [
    (mse_distr_train,     dict(ls='-.', label='train', color='y')),
    (mse_distr_test_good, dict(label='test good', color='blue')),
    (mse_distr_test_bad,  dict(label='test bad', color='red')),
]
for errors, kw in hist_specs:
    plt.hist(np.log10(errors), bins=bins, density=True, lw=2, histtype='step', **kw)
plt.legend()
plt.xlabel('log (MSE)');

**MSE by column:**

In [None]:
def _column_mse(x_true, x_pred):
    """Mean squared reconstruction error of each column, averaged over rows."""
    return ((x_true - x_pred)**2).mean(axis=0)

mse_columns_train     = _column_mse(x_train, x_pred_train)
mse_columns_val       = _column_mse(x_val, x_pred_val)
mse_columns_test_good = _column_mse(x_test_good, x_pred_test_good)
mse_columns_test_bad  = _column_mse(x_test_bad, x_pred_test_bad)
mse_columns_all       = _column_mse(x_all, x_pred_all)

# Per column: train, test-good and test-bad errors, then the bad/good ratio.
rows = zip(input_data.columns, mse_columns_train, mse_columns_test_good, mse_columns_test_bad)
for i_c, (c, train, test_g, test_b) in enumerate(rows):
    print(f'{i_c:3.0f}. {c:<30s}: {train:.3f}, \t {test_g:.3f}, {test_b:6.3f}, \t {test_b/test_g:.2f}')


# Visualization

In [None]:
# Average squared reconstruction error of every input column over all chunks.
# NOTE(review): the other `plot_AE_error` calls pass `columns=` by keyword after
# `ylabels=`. If the helper's second positional parameter is `ylabels`, this
# call hands the column names to `ylabels` -- confirm against helper/ae_errors.
ae_errors.plot_AE_error(mse_columns_all, input_data.columns);

## Column errors per data split

In [None]:
# Sanity check: exactly one aggregated reconstruction error per model input column.
assert len(mse_columns_all) == len(input_data.columns), 'expected one MSE value per input column'

In [None]:
# One panel per data split: per-column mean squared reconstruction error.
split_errors = [mse_columns_train, mse_columns_val, mse_columns_test_good, mse_columns_test_bad]
split_names = ['train', 'val', 'test_good', 'test_bad']
ae_errors.plot_AE_error(split_errors, ylabels=split_names, columns=input_data.columns);

## Single instance

In [None]:
LOG_MSE_ARROW_THRESHOLD = 0.5  # log10(MSE) above which a GOOD chunk is marked

# Print log10(MSE) for every GOOD chunk (bad == 0) and mark those above the
# threshold. Access is positional throughout: row i of `df` corresponds to
# row i of `x_all` / `mse_distr`, whatever labels `df.index` carries.
# A plain array also avoids building a full row Series per iteration.
is_bad = df['bad'].to_numpy()
for index in range(len(df)):
    if is_bad[index] == 1: continue
    log_mse = np.log10(mse_distr[index])
    arrow = '\t\t<------' if log_mse > LOG_MSE_ARROW_THRESHOLD else ''
    print(f'{index:5d}: {log_mse:7.4f} {arrow}')

In [None]:
instance_index = 145  # positional row index of the chunk to inspect

row_orig = df_orig.iloc[instance_index]
global_warning_flag = df_orig[target_col].iloc[instance_index]
mse_instance_number = mse_distr[instance_index]
# Percentile rank (0-100) of this chunk's MSE among all chunks
# (scipy `percentileofscore`, default kind).
mse_percentile = percentileofscore(mse_distr, mse_instance_number)

status_str = (
    f"chunk {instance_index} [ {row_orig['period.fString']} / {row_orig['run']} / chunk {row_orig['chunkID']} ]:  \n"
    f" - _globalWarning_ flag set to: **{bool(global_warning_flag)}**  \n"
    f"  MSE = **{mse_instance_number:.3f}**  \n"
    f"  log(MSE) = **{np.log10(mse_instance_number):.3f}**  \n"
    f"  MSE percentile = **{mse_percentile:.1f}**"
)
printmd(status_str)

In [None]:
# Squared reconstruction error of the selected chunk, per input column.
mse_instance = (x_all[instance_index,:]-x_pred_all[instance_index,:])**2
log_mse_instance = np.log10(mse_instance.mean())

# Same common binning as the overview histogram: full range of log10(MSE).
log_mse_distr = np.log10(mse_distr)
bins = np.linspace(log_mse_distr.min(), log_mse_distr.max(), 20)

for errors, label, colour in ((mse_distr_test_good, 'test good', 'blue'),
                              (mse_distr_test_bad, 'test bad', 'red')):
    plt.hist(np.log10(errors), bins=bins, density=True, lw=2, histtype='step', label=label, color=colour)
plt.legend()
plt.xlabel('log (MSE)');

# Arrow and text marking the selected chunk. The y positions are fractions of
# the y range, i.e. they assume the y axis starts at 0.
x_lo, x_hi = plt.xlim()
y_lo, y_hi = plt.ylim()
xrange = x_hi - x_lo
yrange = y_hi - y_lo
plt.arrow(log_mse_instance, yrange*0.95, 0, -0.2*yrange, width=0.01*xrange, fc='k')
plt.text(log_mse_instance-0.1, yrange*0.9, f'{mse_instance.mean():.2f}', horizontalalignment='right', fontdict=dict(fontsize=14));

In [None]:
# Squared error of each column for the selected chunk, divided by that column's
# mean squared error over all chunks (> 1: worse than usual for this column).
mse_instance_relative = np.divide(mse_instance, mse_columns_all)

panel_labels = [f'squared errors\ninstance={instance_index}', 'squared errors relative\nto aver. (column) error']
ae_errors.plot_AE_error([mse_instance, mse_instance_relative], ylabels=panel_labels, columns=input_data.columns);

In [None]:
from html import escape
from IPython.display import Markdown as md

# Display name -> image file name of the available QA plots.
available_plots = [('Event Information', 'TPC_event_info.png'), ('Cluster Occupancy', 'cluster_occupancy.png'), ('#eta, #phi and pt', 'eta_phi_pt.png'), ('Number of clusters in #eta and #phi', 'cluster_in_detail.png'), ('DCAs vs #eta', 'dca_in_detail.png'), ('TPC dEdx', 'TPC_dEdx_track_info.png'), ('DCAs vs #phi', 'dca_and_phi.png'), ('TPC-ITS matching', 'TPC-ITS.png'), ('dcar vs pT', 'dcar_pT.png'), ('Tracking parameter phi', 'pullPhiConstrain.png'), ('Raw QA Information', 'rawQAInformation.png'), ('Canvas ROC Status OCDB', 'canvasROCStatusOCDB.png'), ('Resolution vs pT and 1/pT', 'res_pT_1overpT.png'), ('Efficiency all charged + findable', 'eff_all+all_findable.png'), ('Efficiency #pi, K, p', 'eff_Pi_K_P.png'), ('Efficiency findable #pi, K, p', 'eff_Pi_K_P_findable.png')]
file_names_mapping = dict(available_plots)

# URL path of the selected chunk: <year>/<period>/<pass>/<run>.
row_orig = df_orig.iloc[instance_index]
path = '/'.join([str(el) for el in row_orig[['year', 'period.fString', 'pass.fString', 'run']].to_list()])
# Prefix '000' to the run number after 'pass1/' -- presumably the QA server
# stores those runs zero-padded. TODO confirm for other passes.
path = path.replace('pass1/', 'pass1/000')

def show_qa_plot(plot_name, path=path, file_names_mapping=file_names_mapping):
    """Print the URL of one QA plot of the selected run and return it as an inline image.

    Parameters
    ----------
    plot_name : str
        Key of `file_names_mapping` (display name of the plot).
    path : str
        Run directory below the QA web root (bound when the function is defined).
    file_names_mapping : dict
        Display name -> image file name.

    Returns
    -------
    IPython.display.Markdown
        Wraps an ``<img>`` tag pointing at the plot.
    """
    plot_file_name = file_names_mapping[plot_name]
    src = f"http://aliqatpceos.web.cern.ch/aliqatpceos/data/{path}/{plot_file_name}"
    # Quote and escape the attribute value. It is built from CSV-derived values,
    # and unquoted, any space, quote or '>' would break or alter the tag.
    html = f'<img src="{escape(src, quote=True)}" width="1200" height="1200">'
    print(src)
    return md(html)
    

interact(show_qa_plot, plot_name=file_names_mapping.keys(), 
         path=fixed(path), file_names_mapping=fixed(file_names_mapping));

# Permutation Importance

In [None]:
%%time

def permutation_importances_custom(score_func, X, rng=None):
    """Permutation feature importance for a scalar score of a whole DataFrame.

    For each column of ``X``, shuffle that column's values, re-evaluate
    ``score_func`` and record the change relative to the unshuffled baseline.
    ``X`` is modified temporarily during the loop, but every column is restored
    afterwards -- also when ``score_func`` raises.

    Parameters
    ----------
    score_func : callable
        Maps a DataFrame with the columns of ``X`` to a scalar score.
    X : pandas.DataFrame
        Data to permute (restored on return).
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of the shuffles. ``None`` uses the global ``np.random`` state.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``importances[i] = score(X with column i shuffled) - score(X)`` and the
        matching column names.
    """
    permute = np.random.permutation if rng is None else rng.permutation
    baseline = score_func(X)
    importances = []
    for col in X.columns:
        saved = X[col].copy()
        try:
            X[col] = permute(saved.to_numpy())
            importances.append(score_func(X) - baseline)
        finally:
            # Always undo the shuffle so X is unchanged for the caller.
            X[col] = saved
    return np.array(importances), X.columns.to_numpy()



def score(X):
    """Reconstruction MSE of the trained `autoencoder` on `X` (lower = better)."""
    return mean_squared_error(X, autoencoder.predict(X))
    

X = pd.DataFrame(x_all, columns=input_data.columns)

# Repeat the random shuffling to estimate the mean and spread of each importance.
n_repeats = 3
fimps_multirun = []
for _ in range(n_repeats):
    fimps, fnames = permutation_importances_custom(score, X)
    fimps_multirun.append(fimps)
fimps_multirun = np.stack(fimps_multirun)
fimps_means = np.mean(fimps_multirun, axis=0)
fimps_stds  = np.std(fimps_multirun, axis=0)

In [None]:
# Columns ranked by mean importance (largest first), values scaled by 100.
order = np.argsort(fimps_means)[::-1]
ranked = zip(fnames[order], fimps_means[order], fimps_stds[order])
for name, imp_mean, imp_std in ranked:
    print(f'{name:>25s}:  {imp_mean*100:>10.5f} +/- {imp_std*100:>6.5f}')

In [None]:
# Distribution of the mean permutation importances over all input columns.
plt.hist(fimps_means, bins=30, histtype='step', lw=2);

In [None]:
# Mean permutation importance (increase of the reconstruction MSE) per input column.
ae_errors.plot_AE_error(fimps_means, ylabels='feat. importances', columns=input_data.columns);

# TODO:

1. Train a **basic** AE (feature scaling, training only on good time intervals?) - DONE
2. try to viz. it - ?
3. check dependence: 
    - on overall performance w/ and w/o _bad_ timeIntervals, 
    - performance on _bad_ and _good_ timeIntervals - DONE
__________________
4. Compare with AE trained on both _good_ and _bad_
5. Check correlations of MSE of columns
6. Try to viz. columns sq. errors as a ratio to aver. sq. error of this column (of train / test_all / test_bad) - DONE
    - double barplot - upper just sq. errors, lower - divided by column aver.

* SHARE GROUPING FUNC. WITH WARSAW

____________________
As a func. of IRate:  
IRate binning: 0-2-4-6-7-8k OR  0-4-7-8k
1. Compare AE's scores to: 1) whole distr. (of course) 2) distr. of similar (in terms of IRate) chunks
2. Add a third barplot - divided by the column average among similar chunks

____________________
Find a justification for why the bad chunks were rejected!  
Then also look at the appropriate QA control plots and see what could be wrong
____________________
Train AE without matching eff. and use it as flags