In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import torch
import pathlib
import numpy as np
import pandas as pd

# Global seed shared by every shuffle / split / training run in this notebook.
seed_ = 42

# NOTE(review): presumably the per-event input feature count (matches the 16 entries of
# `features` defined in the training cell) — not referenced in this cell.
dim = 16
directory = pathlib.Path("../events/MG3")
dataframe_dir = directory / "dataframes"


def _load_sample(file_name, is_signal):
    """Read one symmetrized HDF5 sample and tag every row with a constant `signal` flag."""
    frame = pd.read_hdf(dataframe_dir / file_name)
    frame["signal"] = is_signal
    return frame


def _shuffle(frame):
    """Return a row-shuffled copy of `frame` (seeded by `seed_`) with a fresh 0..n-1 index."""
    return frame.sample(frac=1, random_state=seed_).reset_index(drop=True)


df_3b = _load_sample("symmetrized_bbbj.h5", False)
df_bg4b = _load_sample("symmetrized_bbbb_large.h5", False)
df_hh4b = _load_sample("symmetrized_HH4b.h5", True)

for count_label, count_frame in (("3b-jet events: ", df_3b),
                                 ("4b-jet events: ", df_bg4b),
                                 ("HH4b-jet events: ", df_hh4b)):
    print(count_label, len(count_frame))

# Shuffle once up front so that later head/tail `iloc` slices are random subsets.
df_3b, df_bg4b, df_hh4b = (_shuffle(sample) for sample in (df_3b, df_bg4b, df_hh4b))

3b-jet events:  275508
4b-jet events:  382108
HH4b-jet events:  28656


In [3]:
import pytorch_lightning as pl
from torch.utils.data import TensorDataset

from fvt_classifier import FvTClassifier

n_3b = 250000
n_all4b = 250000
# signal_ratio = 0.05

# Fraction of each sample held out for the test split (the rest is train + validation).
test_ratio = 0.5

# Per-event input features: pt / eta / phi / m of the four symmetrized candidate jets.
features = ["sym_canJet0_pt", "sym_canJet1_pt", "sym_canJet2_pt", "sym_canJet3_pt",
            "sym_canJet0_eta", "sym_canJet1_eta", "sym_canJet2_eta", "sym_canJet3_eta",
            "sym_canJet0_phi", "sym_canJet1_phi", "sym_canJet2_phi", "sym_canJet3_phi",
            "sym_canJet0_m", "sym_canJet1_m", "sym_canJet2_m", "sym_canJet3_m"]

# FvT classifier / training hyper-parameters, shared by every signal_ratio in the scan.
num_classes = 2
dim_input_jet_features = 4
dim_dijet_features = 6
dim_quadjet_features = 6
max_epochs = 30
lr = 1e-3
batch_size = 1024


def reweight_signal(df_bg4b_part, df_hh4b_part, signal_ratio):
    """Return a reweighted copy of ``df_hh4b_part``.

    All HH4b weights are multiplied by one common factor so that signal carries a
    fraction ``signal_ratio`` of the combined (bg4b + HH4b) weight, i.e.
    sum(w_hh) / (sum(w_bg) + sum(w_hh)) == signal_ratio.

    The input is never written to: ``df_hh4b_part`` is an ``iloc`` slice of the shared
    ``df_hh4b`` frame, and assigning into the slice can write through to ``df_hh4b``
    (SettingWithCopy). That would leak rescaled weights into later iterations of the scan.
    """
    reweighted = df_hh4b_part.copy()
    hh_weight_sum = reweighted["weight"].sum()
    if hh_weight_sum == 0:  # e.g. signal_ratio == 0 -> no signal rows, nothing to rescale
        return reweighted
    scale = (signal_ratio / (1 - signal_ratio)) * (df_bg4b_part["weight"].sum() / hh_weight_sum)
    reweighted["weight"] = scale * reweighted["weight"]
    return reweighted


# for signal_ratio in [0, 0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1]:
for signal_ratio in [0.03, 0.04, 0.06, 0.07, 0.08, 0.09]:
    assert 0 <= signal_ratio < 1, f"signal_ratio must be in [0, 1), got {signal_ratio}"

    # ---- event counts per split ------------------------------------------------------
    n_3b_train = int(n_3b * (1 - test_ratio))
    n_all4b_train = int(n_all4b * (1 - test_ratio))
    n_hh4b_train = int(n_all4b_train * signal_ratio)
    n_bg4b_train = n_all4b_train - n_hh4b_train

    n_3b_test = n_3b - n_3b_train
    n_all4b_test = n_all4b - n_all4b_train
    n_hh4b_test = int(n_all4b_test * signal_ratio)
    n_bg4b_test = n_all4b_test - n_hh4b_test

    # `iloc` silently truncates, so fail loudly if a sample is too small for this ratio.
    for sample_name, df_sample, n_needed in (("3b", df_3b, n_3b_train + n_3b_test),
                                            ("bg4b", df_bg4b, n_bg4b_train + n_bg4b_test),
                                            ("HH4b", df_hh4b, n_hh4b_train + n_hh4b_test)):
        assert n_needed <= len(df_sample), (
            f"{sample_name}: need {n_needed} events for signal_ratio={signal_ratio}, "
            f"only {len(df_sample)} available")

    # ---- train split: leading rows of each (pre-shuffled) sample ----------------------
    df_3b_train = df_3b.iloc[:n_3b_train]
    df_bg4b_train = df_bg4b.iloc[:n_bg4b_train]
    df_hh4b_train = reweight_signal(df_bg4b_train, df_hh4b.iloc[:n_hh4b_train], signal_ratio)
    df_train = pd.concat([df_3b_train, df_bg4b_train, df_hh4b_train])
    df_train = df_train.sample(frac=1, random_state=seed_).reset_index(drop=True)

    # ---- test split: the following rows, disjoint from the train rows ----------------
    df_3b_test = df_3b.iloc[n_3b_train:n_3b_train + n_3b_test]
    df_bg4b_test = df_bg4b.iloc[n_bg4b_train:n_bg4b_train + n_bg4b_test]
    df_hh4b_test = reweight_signal(df_bg4b_test,
                                   df_hh4b.iloc[n_hh4b_train:n_hh4b_train + n_hh4b_test],
                                   signal_ratio)
    df_test = pd.concat([df_3b_test, df_bg4b_test, df_hh4b_test])
    df_test = df_test.sample(frac=1, random_state=seed_).reset_index(drop=True)

    # Sanity check: weighted 4b fraction of the training set, and signal fraction within 4b.
    print("4b ratio: ", df_train.loc[df_train["fourTag"], "weight"].sum() / df_train["weight"].sum())
    print("Signal ratio: ", df_train.loc[df_train["signal"], "weight"].sum() / df_train.loc[df_train["fourTag"], "weight"].sum())

    # ---- train / validation tensors ---------------------------------------------------
    # The first ~2/3 of df_train is used for training, the rest for validation. Both cut
    # points are rounded down to a multiple of batch_size (presumably required by the
    # model's ghost batch norm — confirm in fvt_classifier).
    split_at = batch_size * (int((2 / 3) * len(df_train)) // batch_size)
    end_at = batch_size * (len(df_train) // batch_size)

    X_all = torch.tensor(df_train[features].values, dtype=torch.float32)
    w_all = torch.tensor(df_train["weight"].values, dtype=torch.float32)
    y_all = torch.tensor(df_train["fourTag"].values, dtype=torch.long)
    is_signal_all = torch.tensor(df_train["signal"].values, dtype=torch.long)

    X_train, w_train, y_train, is_signal_train = (t[:split_at] for t in (X_all, w_all, y_all, is_signal_all))
    X_val, w_val, y_val, is_signal_val = (t[split_at:end_at] for t in (X_all, w_all, y_all, is_signal_all))

    train_dataset = TensorDataset(X_train, y_train, w_train)
    val_dataset = TensorDataset(X_val, y_val, w_val)

    # ---- model --------------------------------------------------------------------------
    run_name = "_".join(["fvt_classifier_toy_signal_ratio",
                         f"signal_ratio={signal_ratio}",
                         f"dijet={dim_dijet_features}",
                         f"quadjet={dim_quadjet_features}",
                         f"n_3b={n_3b}",
                         f"n_all4b={n_all4b}",])

    # Re-seed right before model construction so every run in the scan starts from the
    # same RNG state (the data preparation above uses no global RNG).
    pl.seed_everything(seed_)
    model = FvTClassifier(num_classes,
                          dim_input_jet_features,
                          dim_dijet_features,
                          dim_quadjet_features,
                          run_name=run_name,
                          device=torch.device("cuda:0"),
                          lr=lr)

    model.fit(train_dataset, val_dataset, batch_size=batch_size, max_epochs=max_epochs)


Seed set to 42


4b ratio:  0.4989991
Signal ratio:  0.03


Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
/home/soheuny/miniconda3/envs/coffea_torch/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /home/soheuny/HH4bsim/playground/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | encoder  | FvTEncoder | 920   
1 | select_q | conv1d     | 8     
2 | out      | conv1d     | 16    
-----------------------------

Epoch 29: 100%|██████████| 162/162 [00:03<00:00, 49.61it/s, v_num=0, val_loss=0.692, train_loss=0.691]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 162/162 [00:03<00:00, 49.52it/s, v_num=0, val_loss=0.692, train_loss=0.691]


Seed set to 42
Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/soheuny/miniconda3/envs/coffea_torch/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /home/soheuny/HH4bsim/playground/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | encoder  | FvTEncoder | 920   
1 | select_q | conv1d     | 8     
2 | out      | conv1d     | 16    
----------------------------------------
895       Trainable params
49        Non-trainable params
944       Total params
0.004     Total estimated model params size (MB)


4b ratio:  0.49900398
Signal ratio:  0.04
Epoch 29: 100%|██████████| 162/162 [00:02<00:00, 63.50it/s, v_num=0, val_loss=0.692, train_loss=0.691]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 162/162 [00:02<00:00, 63.39it/s, v_num=0, val_loss=0.692, train_loss=0.691]


Seed set to 42


4b ratio:  0.4990017
Signal ratio:  0.060000002


Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/soheuny/miniconda3/envs/coffea_torch/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /home/soheuny/HH4bsim/playground/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | encoder  | FvTEncoder | 920   
1 | select_q | conv1d     | 8     
2 | out      | conv1d     | 16    
----------------------------------------
895       Trainable params
49        Non-trainable params
944       Total params
0.004     Total estimated model params size (MB)


Epoch 29: 100%|██████████| 162/162 [00:02<00:00, 60.98it/s, v_num=0, val_loss=0.692, train_loss=0.691]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 162/162 [00:02<00:00, 60.27it/s, v_num=0, val_loss=0.692, train_loss=0.691]


Seed set to 42


4b ratio:  0.4989975
Signal ratio:  0.07000001


Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/soheuny/miniconda3/envs/coffea_torch/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /home/soheuny/HH4bsim/playground/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | encoder  | FvTEncoder | 920   
1 | select_q | conv1d     | 8     
2 | out      | conv1d     | 16    
----------------------------------------
895       Trainable params
49        Non-trainable params
944       Total params
0.004     Total estimated model params size (MB)


Epoch 29: 100%|██████████| 162/162 [00:03<00:00, 51.39it/s, v_num=0, val_loss=0.691, train_loss=0.691]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 162/162 [00:03<00:00, 51.30it/s, v_num=0, val_loss=0.691, train_loss=0.691]

Seed set to 42





Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/soheuny/miniconda3/envs/coffea_torch/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /home/soheuny/HH4bsim/playground/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | encoder  | FvTEncoder | 920   
1 | select_q | conv1d     | 8     
2 | out      | conv1d     | 16    
----------------------------------------
895       Trainable params
49        Non-trainable params
944       Total params
0.004     Total estimated model params size (MB)


4b ratio:  0.49899468
Signal ratio:  0.07999999
Epoch 29: 100%|██████████| 162/162 [00:03<00:00, 51.33it/s, v_num=0, val_loss=0.691, train_loss=0.690]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 162/162 [00:03<00:00, 51.01it/s, v_num=0, val_loss=0.691, train_loss=0.690]


Seed set to 42
Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/soheuny/miniconda3/envs/coffea_torch/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /home/soheuny/HH4bsim/playground/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type       | Params
----------------------------------------
0 | encoder  | FvTEncoder | 920   
1 | select_q | conv1d     | 8     
2 | out      | conv1d     | 16    
----------------------------------------
895       Trainable params
49        Non-trainable params
944       Total params
0.004     Total estimated model params size (MB)


4b ratio:  0.49899265
Signal ratio:  0.08999999
Epoch 29: 100%|██████████| 162/162 [00:02<00:00, 55.49it/s, v_num=0, val_loss=0.690, train_loss=0.689]

`Trainer.fit` stopped: `max_epochs=30` reached.


Epoch 29: 100%|██████████| 162/162 [00:02<00:00, 55.12it/s, v_num=0, val_loss=0.690, train_loss=0.689]


In [None]:
## SvB

# pl.seed_everything(42)

# num_classes = 2
# dim_input_jet_features = 4
# dim_dijet_features = 6
# dim_quadjet_features = 6
# max_epochs = 80
# svb_run_name = "_".join(["svb_classifier_toy_signal_ratio", 
#                     f"signal_ratio={signal_ratio}", 
#                     f"dijet={dim_dijet_features}", 
#                     f"quadjet={dim_quadjet_features}", 
#                     f"n_3b={n_3b}",
#                     f"n_all4b={n_all4b}",])
# lr = 1e-3


# model = FvTClassifier(num_classes, 
#                        dim_input_jet_features, 
#                        dim_dijet_features, 
#                        dim_quadjet_features, 
#                        run_name=svb_run_name,
#                        device=torch.device("cuda:0"),
#                        lr=lr)


# svb_train_dataset = TensorDataset(X_train, is_signal_train, w_train)
# svb_val_dataset = TensorDataset(X_val, is_signal_val, w_val)

# model.fit(svb_train_dataset, svb_val_dataset, batch_size=1024, max_epochs=max_epochs)

In [None]:
from torch.utils.data import DataLoader

device = torch.device("cuda:0")

# Reload the best checkpoint of the most recent run (`run_name` holds whatever the last
# iteration of the training scan left behind) and switch to inference mode.
# model = FvTClassifier.load_from_checkpoint(f"./checkpoints/{svb_run_name}_best.ckpt")
model = FvTClassifier.load_from_checkpoint(f"./checkpoints/{run_name}_best.ckpt")
model = model.to(device)
model.eval()  # checkpoint loading does not switch off training-mode layers by itself
val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False)

# Collect the encoder's quadjet representations batch by batch; no autograd graph needed.
q_repr_val = []
with torch.no_grad():
    for x, _, _ in val_loader:
        q_repr_val.append(model.encoder(x.to(device)).cpu().numpy())
q_repr_val = np.concatenate(q_repr_val, axis=0)

labels_4b_val = y_val.cpu().numpy()
probs_4b_val = model.predict(X_val)[:, 1].cpu().numpy()
weights_val = w_val.cpu().numpy()
# Guard keeps the cell re-runnable: after the first run `is_signal_val` is already a
# NumPy array and has no `.cpu()`.
if torch.is_tensor(is_signal_val):
    is_signal_val = is_signal_val.cpu().numpy()

In [None]:
from plots import plot_prob_weighted_histogram1d, calibration_plot

# Labels the P(4b) output is compared against: `labels_4b_val` (fourTag) for the FvT model.
# Switch to `is_signal_val` to inspect a signal-vs-background (SvB) model instead.
calibration_target_val = labels_4b_val

plot_prob_weighted_histogram1d(probs_4b_val, probs_4b_val, calibration_target_val,
                               n_bins=50,
                               sample_weights=weights_val,
                               ylim=(0.5, 1.5))
calibration_plot(probs_4b_val, calibration_target_val,
                 bins=50,
                 sample_weights=weights_val)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Event-class masks over the validation split (`labels_4b_val` is the fourTag label: 0 = 3b, 1 = 4b).
is_3b_val = labels_4b_val == 0
is_4b_val = labels_4b_val == 1
is_bg4b_val = (labels_4b_val == 1) & (is_signal_val == 0)
is_hh4b_val = (labels_4b_val == 1) & (is_signal_val == 1)

# Each (feature, view) cell of the grid stacks four panels: the 3-way class split
# normalized and with raw weights, then 3b vs all 4b normalized and with raw weights.
groups_split_val = [("bg 3b", is_3b_val), ("bg 4b", is_bg4b_val), ("HH 4b", is_hh4b_val)]
groups_merged_val = [("3b", is_3b_val), ("all 4b", is_4b_val)]
panels_val = [(groups_split_val, True), (groups_split_val, False),
              (groups_merged_val, True), (groups_merged_val, False)]
n_views = 3  # number of views plotted per quadjet feature (last axis of q_repr_val)

fig = plt.figure(figsize=(10, 3.5 * 2 * dim_quadjet_features))
outer = gridspec.GridSpec(2 * dim_quadjet_features, n_views, hspace=1.5, wspace=0.3)

for i in range(dim_quadjet_features):
    for j in range(n_views):
        values = q_repr_val[:, i, j]
        lo, hi = np.min(values), np.max(values)
        if lo == hi:  # constant feature: widen the range so the bin edges are distinct
            lo, hi = lo - 0.5, hi + 0.5
        repr_bins = np.linspace(lo, hi, 50)

        inner = gridspec.GridSpecFromSubplotSpec(len(panels_val), 1, subplot_spec=outer[2*i:2*(i+1), j],
                                                 hspace=0.2, height_ratios=[1] * len(panels_val))
        for k, (groups, density) in enumerate(panels_val):
            ax = fig.add_subplot(inner[k])
            for label, mask in groups:
                ax.hist(values[mask], bins=repr_bins, label=label, linewidth=1, histtype="step",
                        density=density, weights=weights_val[mask])
            if k == 0:
                ax.set_title(f"View {j}, Feature {i}")
            if k < len(panels_val) - 1:  # only the bottom panel keeps its x ticks
                ax.set_xticks([])
            ax.legend()

plt.show()
plt.close(fig)

# FvT output (P(4b)) distributions on the validation split.
bins_range = np.linspace(0, 1, 50)
for groups in ([("bg 3b", is_3b_val), ("bg 4b", is_bg4b_val), ("HH 4b", is_hh4b_val)],
               [("bg 3b", is_3b_val), ("all 4b", is_4b_val)]):
    fig, ax = plt.subplots()
    for label, mask in groups:
        ax.hist(probs_4b_val[mask], bins=bins_range, label=label, linewidth=1, histtype="step",
                density=False, weights=weights_val[mask])
    ax.legend()
    ax.set_xlabel("FvT output")
    plt.show()
    plt.close(fig)

# Test dataset (held-out test split, not the validation split)

In [None]:
print("Test Data")

from torch.utils.data import DataLoader

# Keep a whole number of 1024-event batches, mirroring how the validation split was trimmed.
end_at = 1024 * (len(df_test) // 1024)

X_test = torch.tensor(df_test[features].values[:end_at], dtype=torch.float32)
w_test = torch.tensor(df_test["weight"].values[:end_at], dtype=torch.float32)
y_test = torch.tensor(df_test["fourTag"].values[:end_at], dtype=torch.long)
is_signal_test = torch.tensor(df_test["signal"].values[:end_at], dtype=torch.long)
# NOTE: despite the `svb_` prefix, the targets here are the fourTag (FvT) labels.
svb_test_dataset = TensorDataset(X_test, y_test, w_test)

device = torch.device("cuda:0")
fvt_model = FvTClassifier.load_from_checkpoint(f"./checkpoints/{run_name}_best.ckpt")
fvt_model = fvt_model.to(device)
fvt_model.eval()  # checkpoint loading does not switch off training-mode layers by itself
test_loader = DataLoader(svb_test_dataset, batch_size=1024, shuffle=False)

# Encoder quadjet representations for every test event; inference only, so no autograd.
q_repr_test = []
with torch.no_grad():
    for x, _, _ in test_loader:
        q_repr_test.append(fvt_model.encoder(x.to(device)).cpu().numpy())
q_repr_test = np.concatenate(q_repr_test, axis=0)

labels_4b_test = y_test.cpu().numpy()
probs_4b_test = fvt_model.predict(X_test)[:, 1].cpu().numpy()
weights_test = w_test.cpu().numpy()
is_signal_test = is_signal_test.cpu().numpy()

In [None]:
print("**Test Data**")

# Same two diagnostics as for the validation split, now on the held-out test split.
# Both plotting helpers come from the project-local `plots` module.
plot_prob_weighted_histogram1d(
    probs_4b_test,
    probs_4b_test,
    labels_4b_test,
    sample_weights=weights_test,
    n_bins=50,
    ylim=(0.5, 1.5),
)
calibration_plot(
    probs_4b_test,
    labels_4b_test,
    sample_weights=weights_test,
    bins=50,
)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Event-class masks over the test split (`labels_4b_test` is the fourTag label: 0 = 3b, 1 = 4b).
is_3b_test = labels_4b_test == 0
is_4b_test = labels_4b_test == 1
is_bg4b_test = (labels_4b_test == 1) & (is_signal_test == 0)
is_hh4b_test = (labels_4b_test == 1) & (is_signal_test == 1)

# Each (feature, view) cell of the grid stacks four panels: the 3-way class split
# normalized and with raw weights, then 3b vs all 4b normalized and with raw weights.
groups_split_test = [("bg 3b", is_3b_test), ("bg 4b", is_bg4b_test), ("HH 4b", is_hh4b_test)]
groups_merged_test = [("3b", is_3b_test), ("all 4b", is_4b_test)]
panels_test = [(groups_split_test, True), (groups_split_test, False),
               (groups_merged_test, True), (groups_merged_test, False)]
n_views = 3  # number of views plotted per quadjet feature (last axis of q_repr_test)

fig = plt.figure(figsize=(10, 3.5 * 2 * dim_quadjet_features))
outer = gridspec.GridSpec(2 * dim_quadjet_features, n_views, hspace=1.5, wspace=0.3)

for i in range(dim_quadjet_features):
    for j in range(n_views):
        values = q_repr_test[:, i, j]
        lo, hi = np.min(values), np.max(values)
        if lo == hi:  # constant feature: widen the range so the bin edges are distinct
            lo, hi = lo - 0.5, hi + 0.5
        repr_bins = np.linspace(lo, hi, 50)

        inner = gridspec.GridSpecFromSubplotSpec(len(panels_test), 1, subplot_spec=outer[2*i:2*(i+1), j],
                                                 hspace=0.2, height_ratios=[1] * len(panels_test))
        for k, (groups, density) in enumerate(panels_test):
            ax = fig.add_subplot(inner[k])
            for label, mask in groups:
                ax.hist(values[mask], bins=repr_bins, label=label, linewidth=1, histtype="step",
                        density=density, weights=weights_test[mask])
            if k == 0:
                ax.set_title(f"View {j}, Feature {i}")
            if k < len(panels_test) - 1:  # only the bottom panel keeps its x ticks
                ax.set_xticks([])
            ax.legend()

plt.show()
plt.close(fig)

# FvT output (P(4b)) distributions on the test split.
bins_range = np.linspace(0, 1, 50)
for groups in ([("bg 3b", is_3b_test), ("bg 4b", is_bg4b_test), ("HH 4b", is_hh4b_test)],
               [("bg 3b", is_3b_test), ("all 4b", is_4b_test)]):
    fig, ax = plt.subplots()
    for label, mask in groups:
        ax.hist(probs_4b_test[mask], bins=bins_range, label=label, linewidth=1, histtype="step",
                density=False, weights=weights_test[mask])
    ax.legend()
    ax.set_xlabel("FvT output")
    plt.show()
    plt.close(fig)


In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Stand-alone redraw of the test-split representation grid (uses the is_*_test masks
# defined in the previous cell). The per-cell bin edges live in `repr_bins` so that this
# cell does not overwrite the global `bins_range` that later FvT-output plots read.
repeat_groups_split = [("bg 3b", is_3b_test), ("bg 4b", is_bg4b_test), ("HH 4b", is_hh4b_test)]
repeat_groups_merged = [("3b", is_3b_test), ("all 4b", is_4b_test)]
# (class groups, normalize?) for the four stacked panels of each (feature, view) cell
repeat_panels = [(repeat_groups_split, True), (repeat_groups_split, False),
                 (repeat_groups_merged, True), (repeat_groups_merged, False)]

fig = plt.figure(figsize=(10, 3.5 * 2 * dim_quadjet_features))
outer = gridspec.GridSpec(2 * dim_quadjet_features, 3, hspace=1.5, wspace=0.3)

for i in range(dim_quadjet_features):
    for j in range(3):
        values = q_repr_test[:, i, j]
        lo, hi = np.min(values), np.max(values)
        if lo == hi:  # constant feature: widen the range so the bin edges are distinct
            lo, hi = lo - 0.5, hi + 0.5
        repr_bins = np.linspace(lo, hi, 50)

        inner = gridspec.GridSpecFromSubplotSpec(len(repeat_panels), 1, subplot_spec=outer[2*i:2*(i+1), j],
                                                 hspace=0.2, height_ratios=[1] * len(repeat_panels))
        for k, (groups, density) in enumerate(repeat_panels):
            ax = fig.add_subplot(inner[k])
            for label, mask in groups:
                ax.hist(values[mask], bins=repr_bins, label=label, linewidth=1, histtype="step",
                        density=density, weights=weights_test[mask])
            if k == 0:
                ax.set_title(f"View {j}, Feature {i}")
            if k < len(repeat_panels) - 1:  # only the bottom panel keeps its x ticks
                ax.set_xticks([])
            ax.legend()

plt.show()
plt.close(fig)

In [None]:
import plotly.graph_objects as go

# Draw the same number of test events from each class so no class dominates visually.
# A seeded generator makes the subsample (and thus the plots) reproducible.
rng_scatter = np.random.default_rng(seed_)
n_points = min(np.sum(is_3b_test), np.sum(is_bg4b_test), np.sum(is_hh4b_test))

# Note: these hold sampled row indices, not boolean masks.
is_3b_plot = rng_scatter.choice(np.flatnonzero(is_3b_test), n_points, replace=False)
is_bg4b_plot = rng_scatter.choice(np.flatnonzero(is_bg4b_test), n_points, replace=False)
is_hh4b_plot = rng_scatter.choice(np.flatnonzero(is_hh4b_test), n_points, replace=False)

# (trace name, sampled indices, marker style)
scatter_classes = [
    ("bg 3b", is_3b_plot, dict(size=3, color="blue", opacity=0.4)),
    ("bg 4b", is_bg4b_plot, dict(size=3, color="orange", opacity=0.4)),
    ("HH 4b", is_hh4b_plot, dict(size=3, color="green")),
]

# One 3D scatter per quadjet feature, with the three views as x / y / z.
for i in range(dim_quadjet_features):
    fig = go.Figure()
    fig.update_layout(width=600, height=600, title=f"Feature {i}", hovermode=False)
    for trace_name, idx, marker in scatter_classes:
        points = q_repr_test[idx, i]
        fig.add_trace(go.Scatter3d(x=points[:, 0], y=points[:, 1], z=points[:, 2],
                                   mode='markers', name=trace_name, marker=marker))
    fig.update_layout(scene=dict(xaxis_title="View 0", yaxis_title="View 1", zaxis_title="View 2"))
    fig.show()

In [None]:
# Pair plots with prob4b threshold

# FvT output bins are defined explicitly here instead of relying on whatever `bins_range`
# an earlier cell left behind (the representation-histogram cells rebind that name).
bins_range = np.linspace(0, 1, 50)

fig, ax = plt.subplots()
for label, mask in (("bg 3b", is_3b_test), ("bg 4b", is_bg4b_test), ("HH 4b", is_hh4b_test)):
    ax.hist(probs_4b_test[mask], bins=bins_range, label=label, linewidth=1, histtype="step",
            density=False, weights=weights_test[mask])
ax.legend()
ax.set_xlabel("FvT output")
plt.show()

probs_4b_threshold = 0.6
probs_4b_exceeded = probs_4b_test > probs_4b_threshold

# Equal-size, seeded subsample of each class among events above the threshold.
rng_scatter = np.random.default_rng(seed_)
candidates = [np.flatnonzero(mask & probs_4b_exceeded) for mask in (is_3b_test, is_bg4b_test, is_hh4b_test)]
n_points = min(len(c) for c in candidates)
print(f"Plotting {n_points} per class with prob4b > {probs_4b_threshold}")

# Note: these hold sampled row indices, not boolean masks.
is_3b_plot, is_bg4b_plot, is_hh4b_plot = (rng_scatter.choice(c, n_points, replace=False) for c in candidates)

# (trace name, sampled indices, marker style)
scatter_classes = [
    ("bg 3b", is_3b_plot, dict(size=3, color="blue")),
    ("bg 4b", is_bg4b_plot, dict(size=3, color="orange")),
    ("HH 4b", is_hh4b_plot, dict(size=3, color="green")),
]

for i in range(dim_quadjet_features):
    fig = go.Figure()
    fig.update_layout(width=600, height=600, title=f"Feature {i}", hovermode=False)
    for trace_name, idx, marker in scatter_classes:
        points = q_repr_test[idx, i]
        fig.add_trace(go.Scatter3d(x=points[:, 0], y=points[:, 1], z=points[:, 2],
                                   mode='markers', name=trace_name, marker=marker))
    fig.update_layout(scene=dict(xaxis_title="View 0", yaxis_title="View 1", zaxis_title="View 2"))
    fig.show()

In [None]:
# TSNE
# Embed the flattened quadjet representation (plus the FvT 4b probability) of
# up to `tsne_max_samples` test events with t-SNE, for every combination of
# embedding dimensionality and prob4b threshold. Embeddings and per-event class
# masks are saved under data/ for the plotting cells that follow.
from sklearn.manifold import TSNE

# Both dimensionalities are loaded downstream: the 3D scatter reads an
# n_components=3 file and the 1D histogram reads an n_components=1 file.
tsne_n_components_list = [1, 3]
tsne_probs_4b_thresholds = [0.0, 0.4, 0.5, 0.6, 0.7]
tsne_max_samples = 10000

tsne_dir = pathlib.Path("data")
tsne_dir.mkdir(parents=True, exist_ok=True)  # np.save does not create missing directories

q_repr_and_probs_4b = np.concatenate([q_repr_test.reshape(-1, 3 * dim_quadjet_features), probs_4b_test.reshape(-1, 1)], axis=1)

for n_components in tsne_n_components_list:
    for probs_4b_threshold in tsne_probs_4b_thresholds:
        probs_4b_exceeded = probs_4b_test > probs_4b_threshold
        n_sample_clustering = min(np.sum(probs_4b_exceeded), tsne_max_samples)

        tsne = TSNE(n_components=n_components, random_state=seed_)  # n_components is the dimension of the embedded space
        # t-SNE needs more samples than its perplexity (and at least one sample).
        if n_sample_clustering <= tsne.perplexity:
            print(f"Skipping t-SNE (n_components={n_components}, prob4b > {probs_4b_threshold}): only {n_sample_clustering} events")
            continue

        # Re-seed per configuration so each subsample does not depend on loop order.
        np.random.seed(seed_)
        idx_clustering = np.random.choice(np.where(probs_4b_exceeded)[0], n_sample_clustering, replace=False)

        is_3b_cluster = is_3b_test[idx_clustering]
        is_bg4b_cluster = is_bg4b_test[idx_clustering]
        is_hh4b_cluster = is_hh4b_test[idx_clustering]

        embedded_data = tsne.fit_transform(q_repr_and_probs_4b[idx_clustering])

        file_tag = f"n_components_{n_components}_probs_4b_thr_{probs_4b_threshold}_seed_{seed_}.npy"
        np.save(tsne_dir / f"tsne_embedding_{file_tag}", embedded_data)
        # Columns: (3b, bg 4b, HH 4b) membership of each embedded event.
        np.save(tsne_dir / f"tsne_labels_{file_tag}", np.stack([is_3b_cluster, is_bg4b_cluster, is_hh4b_cluster], axis=1))

In [None]:
# Load the saved 3-component t-SNE embedding (prob4b threshold 0.0) and its
# per-event class masks, then show it as an interactive 3D scatter.
n_components = 3
probs_4b_threshold = 0.0
seed_ = 42

tsne_tag = f"n_components_{n_components}_probs_4b_thr_{probs_4b_threshold}_seed_{seed_}"
embedded_data = np.load(f"data/tsne_embedding_{tsne_tag}.npy")
labels = np.load(f"data/tsne_labels_{tsne_tag}.npy")

# Columns of `labels`: 3b, bg 4b, HH 4b membership.
is_3b_cluster, is_bg4b_cluster, is_hh4b_cluster = labels.T

# 3d plot of TSNE
fig = go.Figure()
fig.update_layout(width=600, height=600, hovermode=False)
trace_specs = [
    (is_3b_cluster, "bg 3b", dict(size=3, color="blue", opacity=0.2)),
    (is_bg4b_cluster, "bg 4b", dict(size=3, color="orange", opacity=0.2)),
    (is_hh4b_cluster, "HH 4b", dict(size=3, color="green")),
]
for mask, trace_name, marker in trace_specs:
    points = embedded_data[mask]
    fig.add_trace(go.Scatter3d(x=points[:, 0], y=points[:, 1], z=points[:, 2], mode='markers', name=trace_name, marker=marker))
fig.show()

In [None]:
# One panel per t-SNE component, overlaying the per-class distributions.
fig, ax = plt.subplots(1, 3, figsize=(10, 5))

for k in range(3):
    component = embedded_data[:, k]
    # Share one set of bin edges across classes. With the default bins, each
    # class is binned over its own range, so the curves are not comparable bin by bin.
    component_bins = np.linspace(np.min(component), np.max(component), 30)
    ax[k].hist(component[is_3b_cluster], label="3b", histtype="step", bins=component_bins)
    ax[k].hist(component[is_bg4b_cluster], label="bg4b", histtype="step", bins=component_bins)
    ax[k].hist(component[is_hh4b_cluster], label="hh4b", histtype="step", bins=component_bins)
    ax[k].set_xlabel(f'Component {k + 1}')
    ax[k].legend()

plt.show()
plt.close()

In [None]:
# Load the saved 1-component t-SNE embedding (prob4b threshold 0.5) and compare
# the class distributions along its single axis with shared bin edges.
n_components = 1
probs_4b_threshold = 0.5
seed_ = 42

tsne_suffix = f"n_components_{n_components}_probs_4b_thr_{probs_4b_threshold}_seed_{seed_}.npy"
embedded_data = np.load("data/tsne_embedding_" + tsne_suffix)
labels = np.load("data/tsne_labels_" + tsne_suffix)

# Columns of `labels`: 3b, bg 4b, HH 4b membership.
is_3b_cluster, is_bg4b_cluster, is_hh4b_cluster = (labels[:, col] for col in range(3))

fig, ax = plt.subplots(figsize=(5, 5))

embedding_axis = embedded_data[:, 0]
bins_range = np.linspace(embedded_data.min(), embedded_data.max(), 30)
for mask, class_label in ((is_3b_cluster, "3b"), (is_bg4b_cluster, "bg4b"), (is_hh4b_cluster, "hh4b")):
    ax.hist(embedding_axis[mask], label=class_label, histtype="step", bins=bins_range)
ax.set_xlabel('Component 1')
ax.legend()
plt.show()
plt.close()

In [None]:
# Refit Classifier with thresholded data
# Keep only val/test events with FvT prob4b > threshold. Compare the FvT ROC on
# the thresholded test events with that of an MLP refit on the quadjet
# representation (trained on thresholded val, evaluated on thresholded test).

from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.neural_network import MLPClassifier


probs_4b_threshold = 0.6

probs_4b_exceeded_val = probs_4b_val > probs_4b_threshold
q_repr_thresholded_val = q_repr_val[probs_4b_exceeded_val]
labels_4b_thresholded_val = labels_4b_val[probs_4b_exceeded_val]
weights_thresholded_val = weights_val[probs_4b_exceeded_val]
probs_4b_thresholded_val = probs_4b_val[probs_4b_exceeded_val]

probs_4b_exceeded_test = probs_4b_test > probs_4b_threshold
q_repr_thresholded_test = q_repr_test[probs_4b_exceeded_test]
labels_4b_thresholded_test = labels_4b_test[probs_4b_exceeded_test]
weights_thresholded_test = weights_test[probs_4b_exceeded_test]
probs_4b_thresholded_test = probs_4b_test[probs_4b_exceeded_test]

# Baseline: the FvT output itself on the thresholded test events.
fpr, tpr, _ = roc_curve(labels_4b_thresholded_test, probs_4b_thresholded_test, sample_weight=weights_thresholded_test)
roc_auc = roc_auc_score(labels_4b_thresholded_test, probs_4b_thresholded_test, sample_weight=weights_thresholded_test)

# refit
# NOTE(review): the MLP is fit without event weights, while both AUCs below are weighted.
refit_clf = MLPClassifier(hidden_layer_sizes=(100, 100), max_iter=1000, random_state=seed_)
refit_clf.fit(q_repr_thresholded_val.reshape(-1, 3*dim_quadjet_features), labels_4b_thresholded_val)

probs_4b_refit = refit_clf.predict_proba(q_repr_thresholded_test.reshape(-1, 3*dim_quadjet_features))[:, 1]
auc_refit = roc_auc_score(labels_4b_thresholded_test, probs_4b_refit, sample_weight=weights_thresholded_test)
fpr_refit, tpr_refit, _ = roc_curve(labels_4b_thresholded_test, probs_4b_refit, sample_weight=weights_thresholded_test)

fig, ax = plt.subplots()
ax.plot(fpr, tpr, label=f"FvT ROC AUC: {roc_auc:.3f}")
ax.plot(fpr_refit, tpr_refit, label=f"Refit ROC AUC: {auc_refit:.3f}")
ax.plot([0, 1], [0, 1], linestyle="--")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend()
plt.show()
plt.close()

# Refit-output distributions on the thresholded test events. The class masks
# must be thresholded too: `probs_4b_refit` has one entry per event with
# prob4b > threshold, so the full-length test masks would not line up with it.
is_3b_thresholded_test = labels_4b_thresholded_test == 0
is_4b_thresholded_test = labels_4b_thresholded_test == 1
is_bg4b_thresholded_test = is_bg4b_test[probs_4b_exceeded_test]
is_hh4b_thresholded_test = is_hh4b_test[probs_4b_exceeded_test]

bins_range = np.linspace(0, 1, 50)

fig, ax = plt.subplots()
for mask, label in ((is_3b_thresholded_test, "bg 3b"), (is_bg4b_thresholded_test, "bg 4b"), (is_hh4b_thresholded_test, "HH 4b")):
    ax.hist(probs_4b_refit[mask], bins=bins_range, label=label, linewidth=1, histtype="step", density=False, weights=weights_thresholded_test[mask])
ax.legend()
ax.set_xlabel("Refit output")
plt.show()
plt.close(fig)

fig, ax = plt.subplots()
for mask, label in ((is_3b_thresholded_test, "bg 3b"), (is_4b_thresholded_test, "all 4b")):
    ax.hist(probs_4b_refit[mask], bins=bins_range, label=label, linewidth=1, histtype="step", density=False, weights=weights_thresholded_test[mask])
ax.legend()
ax.set_xlabel("Refit output")
plt.show()
plt.close(fig)

In [None]:
# Run the best FvT checkpoint on the HH4b sample. Collect the encoder's quadjet
# representation batch by batch, plus the model's 4b probability per event.
print("HH4B Data")
from torch.utils.data import DataLoader

# NOTE(review): truncates to whole 1024-event batches (the loader batch size),
# dropping up to 1023 trailing events; presumably intentional — confirm.
end_at = 1024 * (df_hh4b.index.size // 1024)

X_hh4b = torch.tensor(df_hh4b[features].values[:end_at], dtype=torch.float32)
w_hh4b = torch.tensor(df_hh4b["weight"].values[:end_at], dtype=torch.float32)
y_hh4b = torch.tensor(df_hh4b["fourTag"].values[:end_at], dtype=torch.long)
is_signal_hh4b = torch.tensor(df_hh4b["signal"].values[:end_at], dtype=torch.long)
hh4b_dataset = TensorDataset(X_hh4b, y_hh4b, w_hh4b)

fvt_model = FvTClassifier.load_from_checkpoint(f"./checkpoints/{run_name}_best.ckpt")
# Inference only: make sure dropout / batch-norm layers run in evaluation mode.
fvt_model.eval()
hh4b_loader = DataLoader(hh4b_dataset, batch_size=1024, shuffle=False)
# Fall back to CPU so the cell still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
fvt_model = fvt_model.to(device)

q_repr_hh4b = []

# No gradients are needed, so skip building autograd graphs.
with torch.no_grad():
    for x, _, _ in hh4b_loader:
        q = fvt_model.encoder(x.to(device))
        q_repr_hh4b.append(q.cpu().numpy())
    # NOTE(review): X_hh4b is still on the CPU; this assumes `predict` moves its
    # input to the model's device — confirm.
    probs_4b_hh4b = fvt_model.predict(X_hh4b)[:, 1].cpu().numpy()

q_repr_hh4b = np.concatenate(q_repr_hh4b, axis=0)
labels_4b_hh4b = y_hh4b.cpu().numpy()
weights_hh4b = w_hh4b.cpu().numpy()
is_signal_hh4b = is_signal_hh4b.cpu().numpy()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Test-set class masks: 3b background, all 4b, 4b background, HH4b signal.
is_3b_test = labels_4b_test == 0
is_4b_test = labels_4b_test == 1
is_bg4b_test = (labels_4b_test == 1) & (is_signal_test == 0)
is_hh4b_test = (labels_4b_test == 1) & (is_signal_test == 1)


def _hist_overlay(ax, series, bins, density):
    """Draw one weighted step histogram per ``(values, weights, label)`` in ``series`` on ``ax`` and add a legend."""
    for values, weights, label in series:
        ax.hist(values, bins=bins, label=label, linewidth=1, histtype="step", density=density, weights=weights)
    ax.legend()


# For each (feature i, view j): four stacked panels.
# Rows 0/1: bg 3b vs bg 4b vs HH 4b (normalised / raw weighted counts).
# Rows 2/3: 3b vs all 4b (normalised / raw weighted counts).
fig = plt.figure(figsize=(10, 3.5 * 2 * dim_quadjet_features))
outer = gridspec.GridSpec(2*dim_quadjet_features, 3, hspace=1.5, wspace=0.3)

for i in range(dim_quadjet_features):
    for j in range(3):
        # Bin over the union of the test set and the separate HH4b sample so HH4b
        # entries outside the test-set range are not silently dropped (which
        # would also distort their density normalisation).
        repr_min = min(np.min(q_repr_test[:, i, j]), np.min(q_repr_hh4b[:, i, j]))
        repr_max = max(np.max(q_repr_test[:, i, j]), np.max(q_repr_hh4b[:, i, j]))
        bins_range = np.linspace(repr_min, repr_max, 50)

        signal_vs_bg = [
            (q_repr_test[is_3b_test, i, j], weights_test[is_3b_test], "bg 3b"),
            (q_repr_test[is_bg4b_test, i, j], weights_test[is_bg4b_test], "bg 4b"),
            (q_repr_hh4b[:, i, j], weights_hh4b, "HH 4b"),
        ]
        three_vs_four = [
            (q_repr_test[is_3b_test, i, j], weights_test[is_3b_test], "3b"),
            (q_repr_test[is_4b_test, i, j], weights_test[is_4b_test], "all 4b"),
        ]
        panels = [(signal_vs_bg, True), (signal_vs_bg, False), (three_vs_four, True), (three_vs_four, False)]

        inner = gridspec.GridSpecFromSubplotSpec(4, 1, subplot_spec=outer[2*i:2*(i+1), j], hspace=0.2, height_ratios=[1, 1, 1, 1])
        for row, (series, density) in enumerate(panels):
            current_ax = fig.add_subplot(inner[row])
            _hist_overlay(current_ax, series, bins_range, density)
            if row == 0:
                current_ax.set_title(f"View {j}, Feature {i}")
            if row < len(panels) - 1:
                current_ax.set_xticks([])  # only the bottom panel keeps x ticks

plt.show()
plt.close()

# FvT output distributions: backgrounds vs HH4b signal, then 3b vs all 4b.
bins_range = np.linspace(0, 1, 50)

fig, ax = plt.subplots()
_hist_overlay(ax, [
    (probs_4b_test[is_3b_test], weights_test[is_3b_test], "bg 3b"),
    (probs_4b_test[is_bg4b_test], weights_test[is_bg4b_test], "bg 4b"),
    (probs_4b_hh4b, weights_hh4b, "HH 4b"),
], bins_range, density=False)
ax.set_xlabel("FvT output")
plt.show()
plt.close(fig)

fig, ax = plt.subplots()
_hist_overlay(ax, [
    (probs_4b_test[is_3b_test], weights_test[is_3b_test], "bg 3b"),
    (probs_4b_test[is_4b_test], weights_test[is_4b_test], "all 4b"),
], bins_range, density=False)
ax.set_xlabel("FvT output")
plt.show()
plt.close()


# Re-Fit Classifier

In [None]:
# fit a classifier on q_repr_test to predict 4b vs 3b
# Hold out 25% of the FvT test split, train a small gradient-boosted classifier
# on the flattened quadjet representation, and score it on the held-out part.

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split

# split the data (all arrays are split with the same permutation)
(q_repr_refit_train, q_repr_refit_val,
 labels_4b_refit_train, labels_4b_refit_val,
 is_3b_refit_train, is_3b_refit_val,
 is_bg4b_refit_train, is_bg4b_refit_val,
 is_hh4b_refit_train, is_hh4b_refit_val,
 weights_refit_train, weights_refit_val) = train_test_split(q_repr_test,
                                                            labels_4b_test,
                                                            is_3b_test,
                                                            is_bg4b_test,
                                                            is_hh4b_test,
                                                            weights_test,
                                                            test_size=0.25, random_state=seed_)

# random_state pins the per-split feature permutation so re-runs are reproducible.
refit_clf = GradientBoostingClassifier(n_estimators=10, max_depth=3, learning_rate=0.1, random_state=seed_)
refit_clf.fit(q_repr_refit_train.reshape(-1, 3*dim_quadjet_features), labels_4b_refit_train, sample_weight=weights_refit_train)

refit_probs_val = refit_clf.predict_proba(q_repr_refit_val.reshape(-1, 3*dim_quadjet_features))[:, 1]
# Score with the same event weights used for training (and for the other ROC
# curves in this notebook) so the AUCs are comparable.
refit_auc_val = roc_auc_score(labels_4b_refit_val, refit_probs_val, sample_weight=weights_refit_val)
fpr, tpr, _ = roc_curve(labels_4b_refit_val, refit_probs_val, sample_weight=weights_refit_val)

In [None]:
# ROC curve of the gradient-boosted refit on its held-out split.
fig, ax = plt.subplots()
ax.plot(fpr, tpr, label=f"ROC AUC: {refit_auc_val:.3f}")
ax.set_xlabel("FPR")
ax.set_ylabel("TPR")
ax.legend()
plt.show()
plt.close()

min_prob = refit_probs_val.min()
max_prob = refit_probs_val.max()
bins_range = np.linspace(min_prob, max_prob, 50)

# Weighted refit-output distributions per class on the held-out split.
fig, ax = plt.subplots()
for mask, label in ((is_3b_refit_val, "bg 3b"), (is_bg4b_refit_val, "bg 4b"), (is_hh4b_refit_val, "HH 4b")):
    ax.hist(refit_probs_val[mask], bins=bins_range, label=label, linewidth=1, histtype="step", density=False, weights=weights_refit_val[mask])
ax.legend()
ax.set_xlabel("Refit classifier output")
plt.show()
plt.close(fig)

# Weighted 4b (bg + HH) / 3b ratio per bin of the refit output.
refit_probs_val_3b_hist, _ = np.histogram(refit_probs_val[is_3b_refit_val], bins=bins_range, density=False, weights=weights_refit_val[is_3b_refit_val])
refit_probs_val_4b_hist, _ = np.histogram(refit_probs_val[~is_3b_refit_val], bins=bins_range, density=False, weights=weights_refit_val[~is_3b_refit_val])

# Bins with zero 3b weight would give inf/nan plus a RuntimeWarning; leave them NaN (not drawn).
refit_probs_val_hist_ratio = np.divide(
    refit_probs_val_4b_hist,
    refit_probs_val_3b_hist,
    out=np.full(refit_probs_val_4b_hist.shape, np.nan),
    where=refit_probs_val_3b_hist != 0,
)

fig, ax = plt.subplots()
# where="post" draws each ratio over its own bin [edge_k, edge_{k+1}); the
# default "pre" with left edges would shift every value one bin to the left.
# Repeating the last value extends the final step to the right-most edge.
ax.step(bins_range, np.append(refit_probs_val_hist_ratio, refit_probs_val_hist_ratio[-1]), where="post", label="4b / 3b", linewidth=1)
ax.legend()
ax.set_xlabel("Refit classifier output")
plt.show()
plt.close()