# Export of DNN models trained with Lumin to ONNX for the NBA DNNInference task
All 10 models of an ensemble, together with their ensemble weights, are merged into a single ONNX model, and the input feature scaling is embedded as well. The two ensembles (even/odd) are still exported as two separate models.

In [1]:
import torch
from torch import nn
import lumin
from lumin.nn.ensemble.ensemble import Ensemble
from lumin.nn.models.model import Model
import pickle
from sklearn.preprocessing import StandardScaler
import shutil, os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class LuminEnsembleLoader:
    """Load one saved Lumin ensemble and its pickled input pipeline from disk.

    Files read (``<n>`` = ``ensemble_nb``, ``<r>`` = ``run_name``):
        ``<input_dir>/selected_set_<n>_<r>``                 -> ``Ensemble.from_save``
        ``<input_dir>/selected_set_<n>_<r>_input_pipe.pkl``  -> pickled input pipeline

    Attributes set by :meth:`load`:
        ensemble: the loaded lumin ``Ensemble``.
        n_continuous_features: continuous inputs of the first member model.
        n_categorical_features: categorical inputs of the first member model
            (0 when the model builder has no categorical embedder).
        input_pipe: the unpickled input pipeline.
    """

    def __init__(self, input_dir: str, ensemble_nb: int, run_name: int = 0):
        """
        Args:
            input_dir: directory holding the saved ensemble files.
            ensemble_nb: index of the ensemble to load (this notebook uses 0 and 1).
            run_name: run index used in the file names (defaults to 0, the
                value that was previously hard-coded).
        """
        self.input_dir = input_dir
        self.ensemble_nb = ensemble_nb
        self.run_name = run_name
        self.load()

    def load(self):
        """(Re)load the ensemble, its input feature counts and its input pipeline."""
        prefix = os.path.join(self.input_dir, f"selected_set_{self.ensemble_nb}_{self.run_name}")
        self.ensemble = Ensemble.from_save(prefix)
        self.n_continuous_features, self.n_categorical_features = self._count_features(
            self.ensemble.models[0].model_builder)

        # NOTE: pickle can execute arbitrary code while loading; only point this
        # at trusted training outputs.
        with open(prefix + "_input_pipe.pkl", "rb") as f:
            self.input_pipe = pickle.load(f)

    @staticmethod
    def _count_features(model_builder):
        """Return ``(n_continuous, n_categorical)`` input counts of a lumin model builder.

        ``cat_embedder`` may be ``None`` (presumably when the model was built
        without categorical inputs); that is counted as zero categorical features
        instead of failing with an AttributeError.
        """
        cat_embedder = model_builder.cat_embedder
        n_categorical = 0 if cat_embedder is None else cat_embedder.n_cat_in
        return model_builder.n_cont_in, n_categorical

    def makeDummyInput(self):
        """Random placeholder input for ``torch.onnx.export``.

        Shape is ``(5, n_continuous_features + n_categorical_features)``. The batch
        size of 5 is arbitrary; the exporter in this notebook declares the batch
        axis as dynamic.
        """
        return torch.rand(5, self.n_continuous_features + self.n_categorical_features)


class MergingModelModule(nn.Module):
    """Single torch module wrapping a whole Lumin ensemble, ready for ONNX export.

    ``forward`` standardises the raw input with the ensemble's fitted scaler
    (continuous features only), runs every member model and returns the
    ensemble-weighted sum of their predictions.
    """

    def __init__(self, loader: "LuminEnsembleLoader") -> None:
        """
        Args:
            loader: object exposing ``ensemble`` (with ``models`` whose ``.model`` is a
                torch module, and ``weights``), ``input_pipe`` (first step: a fitted
                scaler with ``scale_`` and ``mean_``), ``n_continuous_features`` and
                ``n_categorical_features``.

        Raises:
            ValueError: if the number of ensemble weights differs from the number of
                models, or the scaler size differs from the number of continuous
                features.
        """
        super().__init__()
        self.models = nn.ModuleList([lumin_model.model for lumin_model in loader.ensemble.models])

        # Registered as buffers (not plain attributes) so they follow .to(device)
        # together with the module. The weights keep their loaded dtype; float64
        # numpy weights make the merged output float64.
        self.register_buffer("weights", torch.tensor(loader.ensemble.weights))
        if len(self.models) != self.weights.shape[0]:
            raise ValueError(
                f"ensemble has {len(self.models)} models but {self.weights.shape[0]} weights")

        scaler = loader.input_pipe[0]
        n_categorical = loader.n_categorical_features
        # Only continuous features are scaled; they come first in the input.
        if len(scaler.scale_) != loader.n_continuous_features:
            raise ValueError(
                f"scaler covers {len(scaler.scale_)} features but the model has "
                f"{loader.n_continuous_features} continuous features")
        # Categorical features pass through unchanged: mean 0, scale 1.
        self.register_buffer(
            "scaler_scale",
            torch.tensor(list(scaler.scale_) + [1.] * n_categorical, dtype=torch.float32))
        self.register_buffer(
            "scaler_mean",
            torch.tensor(list(scaler.mean_) + [0.] * n_categorical, dtype=torch.float32))

        # Inference-only wrapper: keep dropout / batch-norm layers in eval mode.
        self.eval()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the weighted ensemble prediction, one value per batch row.

        Args:
            input: raw (unscaled) features, continuous first then categorical;
                presumably shaped ``(batch, n_cont + n_cat)`` — TODO confirm.
        """
        input_scaled = (input - self.scaler_mean) / self.scaler_scale
        # Each member is expected to return (batch, 1); squeeze to (batch,) and
        # stack into (batch, n_models).
        preds = torch.stack(
            [torch.squeeze(model(input_scaled), dim=1) for model in self.models], dim=1)
        return torch.sum(preds * self.weights, dim=1)


def doExport(input_dir, output_dir, ensemble_nbs=(0, 1)):
    """Export each Lumin ensemble of ``input_dir`` as one merged ONNX model.

    For every ensemble number ``n`` in ``ensemble_nbs`` (default: the two
    even/odd ensembles 0 and 1), the ensemble is loaded with
    ``LuminEnsembleLoader``, wrapped in ``MergingModelModule`` and written to
    ``<output_dir>/model_merged_<n>.onnx``.

    Args:
        input_dir: directory holding the saved ensemble files read by the loader.
        output_dir: destination directory; created if missing.
        ensemble_nbs: ensemble numbers to export, in order.

    Returns:
        List of the written ONNX file paths, in export order.
    """
    os.makedirs(output_dir, exist_ok=True)
    written = []
    for ensemble_nb in ensemble_nbs:
        loader = LuminEnsembleLoader(input_dir, ensemble_nb)
        merged_model = MergingModelModule(loader)
        onnx_path = os.path.join(output_dir, f"model_merged_{ensemble_nb}.onnx")
        torch.onnx.export(
            merged_model,
            loader.makeDummyInput(),
            onnx_path,
            verbose=False,
            input_names=["input"],
            output_names=["output"],
            # Only the batch dimension varies at inference time.
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
            do_constant_folding=True,
            opset_version=11,
        )
        written.append(onnx_path)
    return written

    

In [None]:
# April 2025 non-resonant DNN, 2025-04-17 version (v3), split into resolved and boosted networks
# resolved (disabled)
#doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2025_04_17/DNNWeight_ZZbbtt_FullRun2_Res/ensemble/", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework9/nanoaod_base_analysis/data/cmssw/CMSSW_15_0_3/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_04_17/ZZbbtt_Resolved-0")
#shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/hhbbtt-analysis/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025-04-17/ZZbbttRes-0/features.txt ", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework9/nanoaod_base_analysis/data/cmssw/CMSSW_15_0_3/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_04_17/ZZbbtt_Resolved-0/features.txt")

# boosted: export first (doExport creates the output directory), then copy the feature list next to the models
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2025_04_17/DNNWeight_ZZbbtt_FullRun2_Boosted/ensemble/", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework9/nanoaod_base_analysis/data/cmssw/CMSSW_15_0_3/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_04_17/ZZbbtt_Boosted-0")
# NOTE(review): the source path below ends with a trailing space ("features.txt ");
# confirm a file with that exact name exists, otherwise copyfile raises FileNotFoundError.
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/hhbbtt-analysis/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025-04-17/ZZbbttBoost-0/features.txt ", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework9/nanoaod_base_analysis/data/cmssw/CMSSW_15_0_3/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_04_17/ZZbbtt_Boosted-0/features.txt")


'/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework9/nanoaod_base_analysis/data/cmssw/CMSSW_15_0_3/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_04_17/ZZbbtt_Boosted-0/features.txt'

In [None]:
# April 2025 non-resonant DNN trained with more statistics (full Run 2): v1

# ZZ non-resonant (disabled)
# doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2025_04_07/DNNWeight_ZZbbtt_FullRun2_0/ensemble/", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework9/nanoaod_base_analysis/data/cmssw/CMSSW_15_0_3/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_04_07/ZZbbtt-0")
# shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/hhbbtt-analysis/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt//2025-04-07/ZZbbtt-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework9/nanoaod_base_analysis/data/cmssw/CMSSW_15_0_3/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_04_07/ZZbbtt-0/features.txt")

# New version (2025_04_10) with separate resolved and boosted DNNs
# resolved (disabled)
# doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2025_04_10/DNNWeight_ZZbbtt_FullRun2_Res/ensemble/", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework9/nanoaod_base_analysis/data/cmssw/CMSSW_15_0_3/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_04_10/ZZbbtt_Resolved-0")
# shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2025_04_10/DNNWeight_ZZbbtt_FullRun2_Res/ZZbbtt_Resolved-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework9/nanoaod_base_analysis/data/cmssw/CMSSW_15_0_3/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_04_10/ZZbbtt_Resolved-0/features.txt")

# boosted: export first (doExport creates the output directory), then copy the feature list next to the models
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2025_04_10/DNNWeight_ZZbbtt_FullRun2_Boosted/ensemble/", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework9/nanoaod_base_analysis/data/cmssw/CMSSW_15_0_3/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_04_10/ZZbbtt_Boosted-0")
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2025_04_10/DNNWeight_ZZbbtt_FullRun2_Boosted/ZZbbtt_Boosted-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework9/nanoaod_base_analysis/data/cmssw/CMSSW_15_0_3/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_04_10/ZZbbtt_Boosted-0/features.txt")


'/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework9/nanoaod_base_analysis/data/cmssw/CMSSW_15_0_3/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_04_10/ZZbbtt_Boosted-0/features.txt'

In [3]:
# January 2025 new DNN with boostedTau (full Run 2); model path: /grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/hhbbtt-analysis/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt//2025-01-09/ResZbbHtt-0

# ZbbHtt resonant: export the merged ONNX models, then copy the feature list next to them
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2025_01_09/DNNWeight_ZbbHtt_0/ensemble/", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_01_09/ResZbbHtt-0")
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2025_01_09/DNNWeight_ZbbHtt_0/ResZbbHtt-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_01_09/ResZbbHtt-0/features.txt")




'/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_01_09/ResZbbHtt-0/features.txt'

In [4]:
# ZZbbtt resonant (January 2025): export the merged ONNX models, then copy the feature list next to them
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2025_01_09/DNNWeight_ZZbbtt_0/ensemble/", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_01_09/ResZZbbtt-0")
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2025_01_09/DNNWeight_ZZbbtt_0/ResZZbbtt-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_01_09/ResZZbbtt-0/features.txt")


'/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_01_09/ResZZbbtt-0/features.txt'

In [5]:
# ZttHbb resonant (January 2025): export the merged ONNX models, then copy the feature list next to them
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2025_01_09/DNNWeight_ZttHbb_0/ensemble/", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_01_09/ResZttHbb-0")
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2025_01_09/DNNWeight_ZttHbb_0/ResZttHbb-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_01_09/ResZttHbb-0/features.txt")


'/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_01_09/ResZttHbb-0/features.txt'

In [5]:
# Non-resonant networks, March 2025 (one doExport + features.txt copy per channel)
# ZZbbtt
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2025_03_04/DNNWeight_ZZbbtt_FullRun2_0/ensemble", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_03_04/ZZbbtt-0")
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2025_03_04/DNNWeight_ZZbbtt_FullRun2_0/ZZbbtt-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_03_04/ZZbbtt-0/features.txt")

# ZbbHtt
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2025_03_04/DNNWeight_ZbbHtt_FullRun2_0/ensemble", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_03_04/ZbbHtt-0")
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2025_03_04/DNNWeight_ZbbHtt_FullRun2_0/ZbbHtt-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_03_04/ZbbHtt-0/features.txt")

# ZttHbb
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2025_03_04/DNNWeight_ZttHbb_FullRun2_0/ensemble", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_03_04/ZttHbb-0")
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2025_03_04/DNNWeight_ZttHbb_FullRun2_0/ZttHbb-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_03_04/ZttHbb-0/features.txt")

'/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2025_03_04/ZttHbb-0/features.txt'

In [6]:
# Resonant ZZbbtt DNN with boostedTau (2024-11-16); model path: /grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/hhbbtt-analysis/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-11-16/ResZZbbtt-0
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2024_11_16/DNNWeight_ZZbbtt_0/ensemble/", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-11-16/ResZZbbtt-0")
# NOTE(review): features.txt is taken from ".../ZZbbtt-0/" here, whereas the other resonant
# exports read it from ".../ResZZbbtt-0/"; confirm this is the intended source.
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2024_11_16/DNNWeight_ZZbbtt_0/ZZbbtt-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-11-16/ResZZbbtt-0/features.txt")


'/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/frameworkJobs/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-11-16/ResZZbbtt-0/features.txt'

In [16]:
# Non-resonant networks (trained 2024_03_26, exported to 2024-05-10), one doExport + features.txt copy per channel
# ZZbbtt
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2024_03_26/DNNWeight_ZZbbtt_FullRun2_0/ensemble", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ZZbbtt-0")
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2024_03_26/DNNWeight_ZZbbtt_FullRun2_0/ZZbbtt-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ZZbbtt-0/features.txt")

# ZbbHtt
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2024_03_26/DNNWeight_ZbbHtt_FullRun2_0/ensemble", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ZbbHtt-0")
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2024_03_26/DNNWeight_ZbbHtt_FullRun2_0/ZbbHtt-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ZbbHtt-0/features.txt")

# ZttHbb
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2024_03_26/DNNWeight_ZttHbb_FullRun2_0/ensemble", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ZttHbb-0")
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/NonResDNN/2024_03_26/DNNWeight_ZttHbb_FullRun2_0/ZttHbb-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ZttHbb-0/features.txt")

'/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ZttHbb-0/features.txt'

In [15]:
# Resonant networks (trained 2024_04_29, exported to 2024-05-10), one doExport + features.txt copy per channel
# ZZbbtt
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2024_04_29/DNNWeight_ZZbbtt_0/ensemble", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ResZZbbtt-0")
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2024_04_29/DNNWeight_ZZbbtt_0/ResZZbbtt-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ResZZbbtt-0/features.txt")

# ZbbHtt
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2024_04_29/DNNWeight_ZbbHtt_0/ensemble", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ResZbbHtt-0")
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2024_04_29/DNNWeight_ZbbHtt_0/ResZbbHtt-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ResZbbHtt-0/features.txt")

# ZttHbb
doExport("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2024_04_29/DNNWeight_ZttHbb_0/ensemble", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ResZttHbb-0")
shutil.copyfile("/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2024_04_29/DNNWeight_ZttHbb_0/ResZttHbb-0/features.txt", "/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ResZttHbb-0/features.txt")

'/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ResZttHbb-0/features.txt'

In [34]:
# Debug: load ensemble 0 of the resonant ZbbHtt training by hand and build the merged
# model, so the cells below can inspect it and compare it with lumin's own prediction.
weight_dir = "/grid_mnt/data__data.polcms/cms/vernazza/FrameworkNanoAOD/DNNTraining/ResDNN/2024_04_29/DNNWeight_ZbbHtt_0/ensemble"
ensemble_nb = 0
run_name = 0
ensemble = Ensemble.from_save(weight_dir + f'/selected_set_{ensemble_nb}_{run_name}')
n_continuous_features = ensemble.models[0].model_builder.n_cont_in
n_categorical_features = ensemble.models[0].model_builder.cat_embedder.n_cat_in
# `mergedModel` is used by the following cells, so define it here to make the
# notebook runnable on a fresh kernel. LuminEnsembleLoader reads run 0, matching `run_name`.
mergedModel = MergingModelModule(LuminEnsembleLoader(weight_dir, ensemble_nb))



In [35]:
# Load the pickled input pipeline saved next to the ensemble and display it.
# NOTE: pickle can execute arbitrary code; only load trusted training outputs.
with open(weight_dir + f"/selected_set_{ensemble_nb}_{run_name}_input_pipe.pkl", "rb") as f:
    input_pipe = pickle.load(f)
input_pipe

Pipeline(memory=None,
         steps=[('norm_in',
                 StandardScaler(copy=True, with_mean=True, with_std=True))],
         verbose=False)

In [36]:
# Per-feature variances of the first pipeline step (the StandardScaler shown above)
input_pipe[0].var_

array([1.09056358e+06, 9.73338957e+03, 3.04601092e+04, 2.27463657e+03,
       6.81578084e+03, 2.43383514e+03, 1.10336170e+03, 5.28646168e-01,
       8.82251517e-01, 1.12800290e+03, 8.15063767e-02, 2.27849065e+04,
       6.73152745e-01, 6.09149199e+03, 5.41055765e-01, 1.24320763e+03,
       3.59210958e+03, 6.36534991e-01, 2.87261347e-01, 1.85255106e+00,
       1.17523499e+00])

In [38]:
# Random placeholder input (batch of 5, continuous + categorical features), same as LuminEnsembleLoader.makeDummyInput
dummy_input = (torch.rand(5, ensemble.models[0].model_builder.n_cont_in+ensemble.models[0].model_builder.cat_embedder.n_cat_in))
dummy_input.dtype

torch.float32

In [39]:
# NOTE(review): `mergedModel` is expected to be a MergingModelModule built for `weight_dir` / ensemble 0;
# make sure it is defined before running the cells that use it.
# Continuous-feature means followed by 0 for each (unscaled) categorical feature.
mergedModel.scaler_mean

tensor([1.9224e+03, 4.4228e+01, 3.5362e+02, 1.2861e+02, 1.9908e+02, 6.8297e+01,
        4.8398e+01, 1.8598e+00, 1.3292e+00, 9.4287e+01, 8.6428e-01, 3.5160e+02,
        2.2462e+00, 1.0110e+02, 2.5600e+00, 5.8949e+01, 8.9182e+01, 1.5841e+00,
        2.7105e-02, 1.4117e+00, 8.2376e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00])

In [40]:
# Merged-model prediction on the dummy input (recorded output is float64, like the ensemble weights)
mergedModel(dummy_input)

tensor([0.0037, 0.0018, 0.0016, 0.0009, 0.0032], dtype=torch.float64,
       grad_fn=<SumBackward1>)

In [41]:
# Export the hand-built merged model with the same settings as doExport (dynamic batch axis, opset 11)
torch.onnx.export(mergedModel, dummy_input, f"/grid_mnt/data__data.polcms/cms/cuisset/ZHbbtautau/framework/nanoaod_base_analysis/data/cmssw/CMSSW_12_3_0_pre6/src/cms_runII_dnn_models/models/arc_checks/zz_bbtt/2024-05-10/ResZbbHtt-0/model_merged_{ensemble_nb}.onnx",
    verbose=False, input_names=["input"], output_names=["output"], dynamic_axes={"input":{0: "batch_size"}, "output":{0:"batch_size"}}, do_constant_folding=True, opset_version=11)

In [15]:
# Ensemble weights of the member models
mergedModel.weights

tensor([0.1154, 0.1116, 0.1082, 0.1036, 0.1023, 0.1009, 0.1006, 0.0996, 0.0967,
        0.0611], dtype=torch.float64)

In [24]:
# Per-member predictions on the dummy input.
# NOTE(review): unlike MergingModelModule.forward, the input is NOT standardised here, so these are not
# the values mergedModel combines; use (dummy_input - mergedModel.scaler_mean) / mergedModel.scaler_scale to reproduce them.
preds = []
for model in mergedModel.models:
    preds.append(torch.squeeze(model(dummy_input), dim=1))
preds

[tensor([5.2964e-10, 2.6302e-07, 4.3536e-08, 2.7058e-09, 3.4302e-06],
        grad_fn=<SqueezeBackward1>),
 tensor([0.0251, 0.0741, 0.0160, 0.0055, 0.1671], grad_fn=<SqueezeBackward1>),
 tensor([9.7916e-07, 4.9248e-08, 7.9337e-08, 2.1923e-09, 4.2870e-08],
        grad_fn=<SqueezeBackward1>),
 tensor([3.3172e-11, 3.7592e-09, 3.8905e-11, 7.9050e-10, 3.8310e-08],
        grad_fn=<SqueezeBackward1>),
 tensor([9.4735e-07, 7.5788e-08, 5.7191e-08, 8.2962e-08, 2.1950e-06],
        grad_fn=<SqueezeBackward1>),
 tensor([2.6189e-07, 1.6012e-10, 6.3453e-10, 3.8667e-11, 1.5070e-08],
        grad_fn=<SqueezeBackward1>),
 tensor([2.5753e-11, 5.7733e-13, 1.1907e-12, 8.2175e-14, 1.8263e-09],
        grad_fn=<SqueezeBackward1>),
 tensor([0.0032, 0.0098, 0.0062, 0.0007, 0.0077], grad_fn=<SqueezeBackward1>),
 tensor([1.5932e-09, 4.4555e-07, 1.0649e-07, 4.6203e-09, 1.4602e-07],
        grad_fn=<SqueezeBackward1>),
 tensor([2.7580e-04, 2.4793e-03, 1.4698e-01, 8.8709e-05, 8.3632e-03],
        grad_fn=<Squeez

In [23]:
# First member model called directly (NOTE(review): on the unscaled dummy input)
ensemble.models[0].model(dummy_input)

tensor([[1.2324e-06],
        [6.2032e-08],
        [2.0313e-08],
        [3.7742e-07],
        [2.3359e-06]], grad_fn=<SigmoidBackward0>)

In [72]:
# Merged-model prediction as a numpy array
mergedModel(dummy_input).detach().numpy()

array([0.11136092, 0.09851962, 0.10070108, 0.06736492, 0.11155293])

In [75]:
# Cross-check: lumin ensemble prediction minus merged-model prediction (recorded differences are ~1e-9)
ensemble.predict(dummy_input)[:, 0] - mergedModel(dummy_input).detach().numpy()

array([1.46670333e-09, 2.83946335e-09, 3.35595635e-10, 5.69570668e-09,
       2.51254241e-09])