diff --git a/tutorials/machine_learning/TMVA_CNN_Classification.py b/tutorials/machine_learning/TMVA_CNN_Classification.py index 658f373b517bd..501d8e4e36bd8 100644 --- a/tutorials/machine_learning/TMVA_CNN_Classification.py +++ b/tutorials/machine_learning/TMVA_CNN_Classification.py @@ -27,43 +27,13 @@ import os import importlib.util -useKerasCNN = False - -if ROOT.gSystem.GetFromPipe("root-config --has-tmva-pymva") == "yes": - useKerasCNN = True - opt = [1, 1, 1, 1, 1] useTMVACNN = opt[0] if len(opt) > 0 else False -useKerasCNN = opt[1] if len(opt) > 1 else useKerasCNN +useKerasCNN = opt[1] if len(opt) > 1 else False useTMVADNN = opt[2] if len(opt) > 2 else False useTMVABDT = opt[3] if len(opt) > 3 else False usePyTorchCNN = opt[4] if len(opt) > 4 else False -if useKerasCNN: - try: - import tensorflow - except: - ROOT.Warning("TMVA_CNN_Classification", "Skip using Keras since tensorflow cannot be imported") - useKerasCNN = False - -# PyTorch has to be imported before ROOT to avoid crashes because of clashing -# std::regexp symbols that are exported by cppyy. -# See also: https://github.com/wlav/cppyy/issues/227 -torch_spec = importlib.util.find_spec("torch") -if torch_spec is None: - usePyTorchCNN = False - print("TMVA_CNN_Classificaton","Skip using PyTorch since torch is not installed") -else: - try: - import torch - except: - ROOT.Warning("TMVA_CNN_Classification", "Skip using PyTorch since it cannot be imported") - usePyTorchCNN = False - - -import ROOT - - TMVA = ROOT.TMVA TFile = ROOT.TFile