/
__init__.py
199 lines (144 loc) · 6.61 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# Authors:
# * Harshal Shende 04/2022
# * Lorenzo Moneta 04/2022
################################################################################
# Copyright (C) 1995-2020, Rene Brun and Fons Rademakers. #
# All rights reserved. #
# #
# For the licensing terms see $ROOTSYS/LICENSE. #
# For the list of contributors see $ROOTSYS/README/CREDITS. #
################################################################################
import sys
import cppyy
from cppyy.gbl import gSystem
from .. import pythonization
from ._factory import Factory
from ._dataloader import DataLoader
from ._crossvalidation import CrossValidation
from ._rbdt import Compute, pythonize_rbdt
if sys.version_info >= (3, 8):
    # The batch-generator helpers are only imported on Python 3.8 or newer.
    from ._batchgenerator import (
        CreateNumPyGenerators,
        CreateTFDatasets,
        CreatePyTorchGenerators,
    )

    # Functions that inject_rbatchgenerator attaches to a namespace.
    python_batchgenerator_functions = [
        CreateNumPyGenerators,
        CreateTFDatasets,
        CreatePyTorchGenerators,
    ]

    def inject_rbatchgenerator(ns):
        """
        Attach every function listed in `python_batchgenerator_functions` to
        `ns.Experimental`, using the function's own `__name__` as the
        attribute name.

        Parameters:
            ns: object exposing an `Experimental` attribute that accepts new
                attributes.

        Returns:
            The same `ns` object that was passed in.
        """
        for generator_factory in python_batchgenerator_functions:
            setattr(ns.Experimental, generator_factory.__name__, generator_factory)
        return ns
from ._gnn import RModel_GNN, RModel_GraphIndependent
# The RTensor helpers are imported only when `root-config --has-dataframe`
# answers "yes".
hasRDF = gSystem.GetFromPipe("root-config --has-dataframe") == "yes"
if hasRDF:
    from ._rtensor import (
        get_array_interface,
        add_array_interface_property,
        RTensorGetitem,
        pythonize_rtensor,
    )
# TODO: should this import be guarded so that it only happens when xgboost is available?
# A guard is probably unnecessary, since this code is only run when xgboost is present.
from ._tree_inference import SaveXGBoost, pythonize_tree_inference
# List of Python classes that are used to pythonize the TMVA classes of the
# same name.
python_classes = [Factory, DataLoader, CrossValidation]

# Dictionary for convenient access to the Python classes by class name.
# Built with a comprehension so that no loop variable is left behind in the
# module namespace.
python_classes_dict = {python_class.__name__: python_class for python_class in python_classes}
def get_defined_attributes(klass, consider_base_classes=False):
    """
    Get all class attributes that are defined in a given class or optionally in
    any of its base classes (except for `object`).

    Parameters:
        klass: the class to inspect.
        consider_base_classes: if False (the default), only names found in
            `klass.__dict__` are reported. If True, names found in the
            `__dict__` of any class in the method resolution order of
            `klass`, except `object`, are reported as well.

    Returns:
        A sorted list of attribute names. Names that the interpreter itself
        places in a class namespace are left out.
    """
    # Entries that the interpreter adds to a class namespace on its own and
    # that must not be treated as user-defined attributes. The last two are
    # added to every class body by Python 3.13 and later.
    blacklist = frozenset(
        (
            "__dict__",
            "__doc__",
            "__hash__",
            "__module__",
            "__weakref__",
            "__firstlineno__",
            "__static_attributes__",
        )
    )

    if not consider_base_classes:
        return sorted(attr for attr in klass.__dict__ if attr not in blacklist)

    # this class and all its base classes, excluding `object`
    lookup_classes = [mro_class for mro_class in klass.mro() if mro_class is not object]

    def is_defined(funcname):
        # A name counts as defined when at least one of the considered
        # classes carries it in its own namespace.
        if funcname in blacklist:
            return False
        return any(funcname in mro_class.__dict__ for mro_class in lookup_classes)

    return sorted(attr for attr in dir(klass) if is_defined(attr))
def is_staticmethod_py2(klass, func_name):
    """
    Check if the function with name `func_name` of a class is a static method in Python 2.

    The check looks the attribute up on `klass` and compares the name of its
    type with "function".
    """
    attribute = getattr(klass, func_name)
    type_name = type(attribute).__name__
    return type_name == "function"
def is_classmethod(klass, func):
    """
    Tell whether `func` is bound to `klass` itself.

    Returns False when `func` has no `__self__` attribute, otherwise the
    result of comparing `func.__self__` with `klass`.
    """
    return hasattr(func, "__self__") and func.__self__ == klass
def rebind_attribute(to_class, from_class, func_name):
    """
    Make the attribute `from_class.func_name` available on `to_class` under
    the same name.

    Parameters:
        to_class: class that receives the attribute.
        from_class: class the attribute is looked up on.
        func_name: name of the attribute.
    """
    import sys

    source_attr = getattr(from_class, func_name)

    if is_classmethod(from_class, source_attr):
        # @classmethod: wrap the underlying function again so that it binds
        # to the receiving class
        rebound = classmethod(source_attr.__func__)
    elif sys.version_info >= (3, 0) or isinstance(source_attr, property):
        # everything else in Python 3, and properties in Python 2, can be
        # assigned as they are
        rebound = source_attr
    elif is_staticmethod_py2(from_class, func_name):
        # @staticmethod in Python 2
        rebound = staticmethod(source_attr)
    else:
        # instance method in Python 2
        import new

        rebound = new.instancemethod(source_attr.__func__, None, to_class)

    setattr(to_class, func_name, rebound)
def make_func_name_orig(func_name):
    """
    Return the name under which the original cppyy function is kept.

    The result is the given name with an underscore in front. Names that both
    start and end with a double underscore lose their first two and last two
    characters first, e.g.: __getitem__ > _getitem
    """
    is_magic = func_name.startswith("__") and func_name.endswith("__")
    stem = func_name[2:-2] if is_magic else func_name
    return "_" + stem
@pythonization(class_name=["Factory", "DataLoader", "CrossValidation"], ns="TMVA")
def pythonize_tmva(klass, name):
    """
    Rebind the attributes of the matching Python class onto a TMVA class.

    Parameters:
        klass: class to pythonize
        name: string containing the name of the class; the leading
            len("TMVA::") characters are dropped before the name is looked up
            in `python_classes_dict`

    If the stripped name is not in `python_classes_dict`, an error message is
    printed and nothing else happens. Otherwise every attribute reported by
    `get_defined_attributes` for the Python class is rebound onto `klass`.
    When `klass` already has an attribute of that name, the original is kept
    under the name given by `make_func_name_orig`, and the docstrings of the
    two are combined.
    """
    import inspect
    import sys

    # need to strip the TMVA namespace
    ns_prefix = "TMVA::"
    name = name[len(ns_prefix):]

    if name not in python_classes_dict:
        print("Error - class ", name, "is not in the pythonization list")
        return

    python_klass = python_classes_dict[name]

    # The attributes to pythonize are all those that are manually defined in
    # the Python class itself.
    for func_name in get_defined_attributes(python_klass):
        if hasattr(klass, func_name):
            # The TMVA class already has an attribute with this name: keep it
            # reachable under a new name prefixed with an underscore.
            func_orig = getattr(klass, func_name)
            func_new = getattr(python_klass, func_name)

            if sys.version_info < (3, 0):
                func_new = func_new.__func__

            python_docstring = func_new.__doc__
            if python_docstring is None:
                # no Python docstring: reuse the one of the original
                func_new.__doc__ = func_orig.__doc__
            elif func_orig.__doc__ is not None:
                # both exist: Python docstring first, then the original one
                func_new.__doc__ = (
                    "Pythonization info\n"
                    "==============\n\n"
                    + inspect.cleandoc(python_docstring)
                    + "\n\n"
                    + "Documentation of original cppyy.CPPOverload object\n"
                    + "==================================================\n\n"
                    + func_orig.__doc__
                )

            setattr(klass, make_func_name_orig(func_name), func_orig)

        rebind_attribute(klass, python_klass, func_name)