In [96]:
import rootpy.stl as stl
# Generate a ROOT dictionary for the C++ class `Peak` declared in Peak.h;
# presumably needed so that ROOT.Peak and vector<Peak> can be used from Python below.
# NOTE(review): the third positional argument is presumably rootpy.stl.generate's
# `has_iterators` flag — confirm against the installed rootpy version.
stl.generate("Peak","Peak.h", True) # can be done for all elements of the data_model

In [371]:
import types
import inspect
import traceback
from rootpy.tree import IntCol, FloatCol, BoolCol, ObjectCol, LongCol, IntArrayCol, BoolArrayCol
import numpy as np
from collections import OrderedDict
from rootpy.tree import TreeBuffer
from pax.data_model import StrictModel,Model,ListField
from rootpy.tree import TreeModel
from pax.utils import Memoize
import json
import bson
import rootpy.stl as stl
import cppyy
import ROOT

# ROOT.gROOT.ProcessLine('.L pax_event_class.cpp+')
# ROOT.gInterpreter.GenerateDictionary("Peak","pax_event_class.h")
# ROOT.gInterpreter.GenerateDictionary("Interaction","pax_event_class.h")

# Implicit conversions accepted by PaxTreeBuffer.__setattr__.
# Key: the type of a field's current value. Value: class names
# (value.__class__.__name__) of values that may be assigned and converted to it.
_INTEGER_TYPE_NAMES = ('int16', 'int32', 'int64', 'Int64', 'Int32')

casting_allowed_for = {
    int: list(_INTEGER_TYPE_NAMES),
    float: ['int', 'float32', 'float64'] + list(_INTEGER_TYPE_NAMES),
    bool: list(_INTEGER_TYPE_NAMES),
    stl.string: ['str'],
}
# STL vectors of every supported element type accept a plain Python list.
casting_allowed_for.update(
    (ROOT.vector(element_type), ['list'])
    for element_type in (int, float, bool, stl.string, ROOT.Peak)
)

    
class PaxTreeBuffer(TreeBuffer):
    """rootpy TreeBuffer that also offers the pax Model API.

    ROOTModel.__new__ creates one of these, stores one entry per model column in it
    (``treebuffer[name] = attr()``) and then calls ``__post_init__``. On top of
    TreeBuffer this class adds pax-style initialisation, type-checked attribute
    assignment (driven by the module-level ``casting_allowed_for`` table) and
    conversion to dict / JSON / BSON.
    """

    def __post_init__(self, kwargs_dict=None, quick_init=False, **kwargs):
        """Initialise field values from keyword arguments, pax-Model style.

        :param kwargs_dict: optional dict of field name -> initial value. It is merged
            into ``kwargs``; on a key clash the ``kwargs_dict`` value wins.
        :param quick_init: if True, copy ``kwargs_dict`` and then ``kwargs`` straight
            into the instance ``__dict__`` (so ``kwargs`` wins on clashes), without
            any conversion or type checking.
        :raises ValueError: if a list field receives elements that are neither of the
            declared element type nor dicts, or a numpy-array field receives a value
            that cannot be turned into an array.
        """
        if quick_init:
            # `or {}`: kwargs_dict defaults to None, which dict.update() rejects.
            self.__dict__.update(kwargs_dict or {})
            self.__dict__.update(kwargs)
            return

        # Initialize the collection fields to empty lists.
        # super() is used to bypass the type checking in PaxTreeBuffer.__setattr__.
        list_field_info = self.get_list_field_info()
        for field_name in list_field_info:
            super().__setattr__(field_name, [])

        # Initialize all attributes from kwargs and kwargs_dict
        kwargs.update(kwargs_dict or {})
        for k, v in kwargs.items():
            if k in list_field_info:
                # User gave a value to initialize a list field. Hopefully an iterable!
                # Check that the element types are correct.
                desired_type = list_field_info[k]
                temp_list = []
                for el in v:
                    if isinstance(el, desired_type):
                        # Good, pass through unmolested
                        temp_list.append(el)
                    elif isinstance(el, dict):
                        # Dicts are fine too, we can use them to init the desired type
                        temp_list.append(desired_type(**el))
                    else:
                        raise ValueError("Attempt to initialize list field %s with type %s, "
                                         "but you promised type %s in class declaration." % (
                                             k, type(el), desired_type))
                setattr(self, k, temp_list)
            else:
                # Raises AttributeError for names that are not fields.
                default_value = getattr(self, k)
                if type(default_value) == np.ndarray:
                    if isinstance(v, np.ndarray):
                        pass
                    elif isinstance(v, bytes):
                        # Numpy arrays can also be initialized from a 'string' of bytes.
                        # np.frombuffer replaces the deprecated binary mode of
                        # np.fromstring; .copy() keeps the array writable, as before.
                        v = np.frombuffer(v, dtype=default_value.dtype).copy()
                    elif hasattr(v, '__iter__'):
                        # ... or an iterable
                        # TODO: rootpy.tree.treetypes.IntArray should pass here
                        v = np.array(v, dtype=default_value.dtype)
                    else:
                        raise ValueError("Can't initialize field %s: "
                                         "don't know how to make a numpy array from a %s" % (k, type(v)))
                elif isinstance(default_value, Model):
                    v = default_value.__class__(**v)
                elif isinstance(v, stl.string):
                    v = str(v)
                elif isinstance(v, stl.vector(stl.string)):
                    v = [str(el) for el in v] if not v.empty() else []
                elif isinstance(v, (stl.vector(int),
                                    stl.vector(float),
                                    stl.vector(bool),
                                    stl.vector(Peak))):
                    v = list(v) if not v.empty() else []

                # __setattr__ converts lists back into the field's STL vector type.
                setattr(self, k, v)

    def to_json(self, fields_to_ignore=None):
        """Return the fields as a JSON string (numpy arrays become lists)."""
        # TODO: a custom JSON encoder (cls=...) could handle remaining ROOT types.
        return json.dumps(self.to_dict(convert_numpy_arrays_to='list',
                                       fields_to_ignore=fields_to_ignore))

    def to_bson(self, fields_to_ignore=None):
        """Return the fields encoded as BSON (numpy arrays become raw bytes)."""
        return bson.BSON.encode(self.to_dict(convert_numpy_arrays_to='bytes',
                                             fields_to_ignore=fields_to_ignore))

    def __setattr__(self, key, value):
        """Assign a field value, refusing silent type changes.

        Names starting with ``_`` are TreeBuffer internals and go straight to
        ``TreeBuffer.__setattr__``. For any other name the type of the current value
        is authoritative: a value of another type is accepted only if its class name
        is listed for that type in ``casting_allowed_for``, and is then converted
        (STL vector fields are rebuilt element by element). numpy-array fields must
        also keep their dtype.

        :raises AttributeError: presumably, if ``key`` is not an existing field
            (propagated from ``getattr``) — confirm against TreeBuffer.__getattr__.
        :raises TypeError: on a disallowed type change or numpy dtype change.
        """
        if key.startswith('_'):  # TreeBuffer attributes for the ROOT tree
            super().__setattr__(key, value)
            # Internal attributes are not model fields: stop here instead of falling
            # through to the field checks below and setting the value a second time.
            return

        # Model fields: get the old value (unknown names raise AttributeError,
        # which is what we want).
        old_val = getattr(self, key)
        old_type = type(old_val)
        new_type = type(value)
        # Check for attempted type change
        if old_type != new_type:
            # Are we allowed to cast the type?
            if old_type in casting_allowed_for \
                    and value.__class__.__name__ in casting_allowed_for[old_type]:
                if old_type in [stl.vector(int), stl.vector(bool), stl.vector(float),
                                stl.vector(stl.string), stl.vector(Peak)]:
                    # Rebuild the STL vector from the given iterable; string vectors
                    # need every element converted with str().
                    # TODO: more efficient way to populate the stl.vector
                    convert = str if old_type == stl.vector(stl.string) else (lambda el: el)
                    new_vector = old_type()
                    for el in value:
                        new_vector.push_back(convert(el))
                    value = new_vector
                else:
                    value = old_type(value)
            else:
                raise TypeError('Attribute %s of class %s should be a %s, not a %s. '
                                % (key,
                                   self.__class__.__name__,
                                   old_val.__class__.__name__,
                                   value.__class__.__name__))

        # Check for attempted dtype change
        if isinstance(old_val, np.ndarray):
            if old_val.dtype != value.dtype:
                raise TypeError('Attribute %s of class %s should have numpy dtype %s, not %s' % (
                    key, self.__class__.__name__, old_val.dtype, value.dtype))

        super().__setattr__(key, value)

    # NOTE: this could become a memoized classmethod (pax.utils.Memoize) if the set
    # of fields were fixed per class, as it is for StrictModel.
    def get_list_field_info(self):
        """Return dict with fieldname => type of elements in collection fields in this class."""
        list_field_info = {}
        for k, v in self.get_fields_data():
            if isinstance(v, ListField):
                list_field_info[k] = v.element_type
        return list_field_info

    def __str__(self):  # TODO: settle on a final representation
        """Return the str() of a dict mapping every user field to its current value."""
        return str(dict(self.get_fields_data()))

    def get_fields_data(self):
        """Iterator over (key, value) tuples of all user-specified fields.

        Keys are taken from the buffer's own mapping and returned in lexical order.
        Values set directly on the instance ``__dict__`` (e.g. by quick_init) take
        precedence over the buffer's values; private names, callables, properties
        and classmethods are skipped.
        """
        # TODO: increase performance by pre-sorting keys?
        # self.keys() works for both the pure-Python and the C implementation of
        # OrderedDict; the C one (CPython >= 3.5) has no '_OrderedDict__map' entry.
        self_dict = self.__dict__

        for field_name in sorted(self.keys()):
            if field_name in self_dict:
                # The instance has a value for this field: return it
                yield (field_name, self_dict[field_name])
                continue
            # ... it doesn't. Should we return the buffer's value?
            if field_name.startswith('_'):
                continue    # No, is internal
            value_in_class = self.__getattr__(field_name)  # TODO: or __getitem__ (returns wrapper object)
            if callable(value_in_class):
                continue    # No, is a method
            if isinstance(value_in_class, (property, classmethod)):
                continue    # No, is a property or classmethod
            # Yes, yield it
            yield (field_name, value_in_class)

    def to_dict(self, convert_numpy_arrays_to=None, fields_to_ignore=None):
        """Return a plain-dict representation of the user fields.

        :param convert_numpy_arrays_to: None (leave numpy arrays as they are),
            'list' (``ndarray.tolist``) or 'bytes' (``ndarray.tobytes``).
        :param fields_to_ignore: field names to leave out; also passed down to nested
            buffers and list elements.
        :raises ValueError: for any other ``convert_numpy_arrays_to`` value (only
            checked when a numpy-array field is present).

        ROOT objects that are not STL containers (no ``value_type``) are left out.
        """
        # TODO deal with rootpy.tree.treetypes.IntArray etc types
        result = {}
        if fields_to_ignore is None:
            fields_to_ignore = tuple()
        for k, v in self.get_fields_data():
            if k in fields_to_ignore:
                continue
            if isinstance(v, PaxTreeBuffer):
                result[k] = v.to_dict(convert_numpy_arrays_to=convert_numpy_arrays_to,
                                      fields_to_ignore=fields_to_ignore)
            elif isinstance(v, list):
                # Model elements are converted recursively; plain values pass through.
                result[k] = [el.to_dict(convert_numpy_arrays_to=convert_numpy_arrays_to,
                                        fields_to_ignore=fields_to_ignore)
                             if hasattr(el, 'to_dict') else el
                             for el in v]
            # dealing with ROOT stl.string fields
            elif isinstance(v, stl.string):
                result[k] = str(v)
            # dealing with ROOT.vectors of builtin types
            elif isinstance(v, (stl.vector(int),
                                stl.vector(float),
                                stl.vector(bool))):
                result[k] = list(v) if not v.empty() else []
            elif isinstance(v, stl.vector(stl.string)):
                result[k] = [str(el) for el in v] if not v.empty() else []
            elif isinstance(v, ROOT.ObjectProxy):
                if hasattr(v, 'value_type'):  # It's an STL container of ROOT objects
                    elements = list(v)
                    list_result = []
                    if elements:
                        # The element's field names come from whatever the element type
                        # name (presumably e.g. 'Peak') evaluates to in this module,
                        # looked up once per container instead of once per element.
                        # NOTE(review): eval() of a type-name string. It is not external
                        # data here, but a plain name lookup would be more robust.
                        element_cls = eval(v.value_type)
                        element_fields = [key for key in element_cls.__dict__.keys()
                                          if key not in ['__doc__', '__module__']]
                        list_result = [{key: getattr(el, key) for key in element_fields}
                                       for el in elements]
                    # Assign outside the loop so empty containers still appear in the
                    # result (previously an empty vector dropped the key entirely).
                    result[k] = list_result
            elif isinstance(v, np.ndarray) and convert_numpy_arrays_to is not None:
                if convert_numpy_arrays_to == 'list':
                    result[k] = v.tolist()
                elif convert_numpy_arrays_to == 'bytes':
                    # tobytes() is the non-deprecated name of tostring(); same output.
                    result[k] = v.tobytes()
                else:
                    raise ValueError('convert_numpy_arrays_to must be "list" or "bytes"')
            else:
                result[k] = v
        return result
    
class ROOTModel(StrictModel):
    """Factory whose constructor yields a PaxTreeBuffer instead of an instance of cls.

    In this notebook only ``__new__`` is used, borrowed by ``EventModel``: calling a
    model class fills a fresh buffer with one freshly instantiated entry per model
    attribute, then applies pax-style initialisation via ``__post_init__`` (which
    supplies to_dict, to_json, ...).
    """

    def __new__(cls, kwargs_dict=None, quick_init=False, **kwargs):
        buffer = PaxTreeBuffer()
        attributes = cls.get_attrs()
        for attr_name, column_type in attributes:
            # Each declared attribute is instantiated to obtain its storage object.
            buffer[attr_name] = column_type()

        buffer.__post_init__(kwargs_dict=kwargs_dict, quick_init=quick_init, **kwargs)
        return buffer
        
# Base class for event models: a rootpy TreeModel whose __new__ is replaced by
# ROOTModel.__new__, so instantiating a subclass (e.g. Event(...)) returns a
# PaxTreeBuffer rather than an instance of the subclass. The class itself can
# still be passed as `model=` to rootpy's Tree.
EventModel = type("EventModel", (TreeModel,), {'__new__':ROOTModel.__new__})

# Presumably a sentinel for "not set" in integer columns, which cannot hold NaN.
INT_NAN = -99999

# from ROOT import TObject,TClonesArray,AddressOf, TRef
# from rootpy.tree.model import TreeModelMeta

# Python-side declaration of the C++ Peak class generated from Peak.h.
# The class attributes are rootpy column declarations. PaxTreeBuffer.to_dict
# reads this class's __dict__ keys (skipping __doc__/__module__) to decide
# which attributes of each Peak element to export.
class Peak(ROOT.Peak):
    n_hits = IntCol(0)
    hits_fraction_top = FloatCol(1.1)  # NOTE(review): default 1.1 looks arbitrary — confirm
    
# Event-level model. Each attribute becomes one branch of a ROOT tree built with
# `Tree(..., model=Event)` (see the branch listing printed further below).
# Calling Event(...) goes through ROOTModel.__new__ and returns a PaxTreeBuffer.
class Event(EventModel):
    dataset_name = stl.string  # STL string column; ROOTModel.__new__ instantiates it
    event_number = IntCol(0)
    n_channels = IntCol(INT_NAN)  # defaults to the INT_NAN sentinel
    start_time = LongCol(0)
    stop_time = LongCol(0)
    sample_duration = IntCol(0)
    peaks = stl.vector(Peak)  # STL vector column holding Peak objects
#     is_channel_suspicious = BoolArrayCol(2)

In [372]:
# Two standalone C++ Peak objects with distinct hit counts.
peak1, peak2 = ROOT.Peak(), ROOT.Peak()
peak1.n_hits, peak2.n_hits = 33, 44

In [373]:
# Build an event (a PaxTreeBuffer); populating `peaks` is optional.
event = Event(dataset_name="dataset1", event_number=999)
# push_back presumably stores copies (std::vector holds Peak by value), so later
# changes to peak1/peak2 need not show up in event.peaks — compare the outputs below.
event.peaks.push_back(peak1)
event.peaks.push_back(peak2)

In [377]:
# Plain-dict view of the event.
# NOTE(review): the recorded output below already shows n_hits 11/22 and
# hits_fraction_top 11.1/22.2, which are only assigned in the next cell
# (In [378]); it appears to be stale output from out-of-order execution.
# Restart & Run All to regenerate it.
event.to_dict()

{'dataset_name': 'dataset1',
 'event_number': 999,
 'n_channels': -99999,
 'peaks': [{'hits_fraction_top': 11.100000381469727, 'n_hits': 11},
  {'hits_fraction_top': 22.200000762939453, 'n_hits': 22}],
 'sample_duration': 0,
 'start_time': 0,
 'stop_time': 0}

In [378]:
# Modify the peaks stored inside the event via indexing into event.peaks.
# hits_fraction_top appears to be stored in single precision: 11.1 reads back
# as 11.100000381469727 in the to_dict output.
event.peaks[0].hits_fraction_top = 11.1
event.peaks[0].n_hits = 11
event.peaks[1].hits_fraction_top = 22.2
event.peaks[1].n_hits = 22

In [379]:
# Dict view after modifying the peaks.
event.to_dict()

{'dataset_name': 'dataset1',
 'event_number': 999,
 'n_channels': -99999,
 'peaks': [{'hits_fraction_top': 11.100000381469727, 'n_hits': 11},
  {'hits_fraction_top': 22.200000762939453, 'n_hits': 22}],
 'sample_duration': 0,
 'start_time': 0,
 'stop_time': 0}

In [367]:
# NOTE(review): assigns an attribute that Peak does not declare. Per the original
# note this should not be allowed; the PyROOT proxy presumably accepts it silently
# — confirm, and demonstrate the desired failure with try/except once enforced.
peak1.something = 4 # NOPE, should not be allowed

In [383]:
# Read back one of the values assigned to the event's peaks above.
event.peaks[1].n_hits

22

In [368]:
# Use the model to write a ROOT tree
from rootpy.tree import Tree, Ntuple, TreeModel, TreeChain
from rootpy.tree.treetypes import *
from rootpy.tree.tree import *
from rootpy.tree.model import *
import ROOT
import rootpy.stl as stl
from rootpy.io import root_open, TemporaryFile
# NOTE: `from ROOT import Peak` was removed on purpose. It rebound the module-level
# name `Peak` (the Python subclass declared earlier, which PaxTreeBuffer.to_dict
# resolves by name via eval(v.value_type)) to the bare C++ class, changing
# to_dict's output in every later cell. ROOT.Peak is used explicitly instead.

# The context manager closes (and thereby finalises) the output file. The tree is
# presumably owned by the file (ROOT attaches new trees to the current directory),
# so everything that touches it stays inside the block.
with root_open("test_pax_new.root", "recreate") as f:
    tree_event = Tree("events", "events", model=Event)
    tree_event.event_number = 55
    tree_event.dataset_name = "sdsd"

    peak1 = ROOT.Peak()
    peak1.n_hits = 22
    peak1.hits_fraction_top = 22.22
    peak2 = ROOT.Peak()
    peak2.n_hits = 33
    peak2.hits_fraction_top = 33.33
    tree_event.peaks.push_back(peak1)
    tree_event.peaks.push_back(peak2)
    tree_event.fill()

    # Writes all objects of the file, including the tree.
    f.write()

    for branch in tree_event.GetListOfBranches():
        print(branch)



<ROOT.TBranchElement object ("dataset_name") at 0x765fce0>
<ROOT.TBranchElement object ("peaks") at 0x7514560>
<ROOT.TBranch object ("event_number") at 0x74b9480>
<ROOT.TBranch object ("n_channels") at 0x7d97910>
<ROOT.TBranch object ("start_time") at 0x7599970>
<ROOT.TBranch object ("stop_time") at 0x75ba4c0>
<ROOT.TBranch object ("sample_duration") at 0x74fe020>
