# Dependencies

In [2]:
! pip install starknet-py

Looking in indexes: https://pypi.douban.com/simple/


# Converter

In [3]:
import numpy as np
from starknet_py.hash.utils import pedersen_hash

def float_to_fixed_point(value, integer_bits, fractional_bits):
    """Render ``value`` as an Orion fixed-point struct literal for Cairo.

    The magnitude is ``value * 2**fractional_bits`` truncated toward zero via
    ``int`` and made non-negative; the sign is carried separately. For example
    ``float_to_fixed_point(-0.5, 16, 16)`` returns
    ``"FP16x16 { mag: 32768, sign: true }"``. No range check is performed.

    Args:
        value: Real number to encode (Python or NumPy scalar).
        integer_bits: Integer-part width; only used to build the type name.
        fractional_bits: Number of fractional bits (the scaling exponent).

    Returns:
        The Cairo struct literal as a string.
    """
    scale_factor = 2**fractional_bits
    mag = abs(int(value * scale_factor))
    # A tiny negative value can truncate to magnitude 0; emitting sign=true there would
    # produce a second encoding of zero ("-0"), so zero is always emitted as +0.
    negative = bool(value < 0) and mag != 0
    return f"FP{integer_bits}x{fractional_bits} {{ mag: {mag}, sign: {str(negative).lower()} }}"

def cairo_array(arr, type_name="usize", fixed_point_type="FP16x16"):
    """Format ``arr`` as the comma-separated body of a Cairo ``array![...]`` literal.

    Args:
        arr: Iterable of values (e.g. a list or NumPy array).
        type_name: Cairo element type. A name starting with ``"FP"`` selects
            fixed-point encoding through ``float_to_fixed_point``; any other name
            emits each element with ``str()``.
        fixed_point_type: Fallback fixed-point spec of the form ``"FP<int>x<frac>"``.
            Used only when ``type_name`` starts with ``"FP"`` but does not itself
            spell out the bit widths.

    Returns:
        The joined element string (empty when ``arr`` is empty).
    """
    if not type_name.startswith("FP"):
        return ", ".join(map(str, arr))

    # Prefer the widths of the type actually requested; previously they were always
    # taken from fixed_point_type, silently mis-encoding on a mismatch.
    spec = type_name[2:].split("x")
    if len(spec) != 2 or not all(part.isdecimal() for part in spec):
        spec = fixed_point_type[2:].split("x")
    integer_bits, fractional_bits = map(int, spec)
    return ", ".join(float_to_fixed_point(x, integer_bits, fractional_bits) for x in arr)

class TreeEnsembleAttributes:
    """Mutable bag of tree-ensemble attributes, filled one at a time via ``add``."""

    def __init__(self):
        # Names given to ``add`` in insertion order, except those ending in "_as_tensor".
        self._names = []

    def add(self, name, value):
        """Store ``value`` as the attribute ``name`` on this object.

        List values are converted to NumPy arrays: float32 for base_values,
        class_weights, nodes_values and nodes_hitrates, NumPy's inferred dtype for
        any other name. Non-list values are stored unchanged.
        """
        tracked = not name.endswith("_as_tensor")
        if tracked:
            self._names.append(name)

        if isinstance(value, list):
            float_attrs = ("base_values", "class_weights", "nodes_values", "nodes_hitrates")
            if name in float_attrs:
                value = np.array(value, dtype=np.float32)
            else:
                value = np.array(value)

        setattr(self, name, value)

class TreeEnsemble:
    """Node lookup tables for a tree ensemble and their Cairo serialization.

    Every keyword argument is stored on ``self.atts`` (a ``TreeEnsembleAttributes``);
    ``nodes_treeids`` and ``nodes_nodeids`` must be among them to build the indices.
    """

    def __init__(self, fixed_point_type="FP16x16", **kwargs):
        self.atts = TreeEnsembleAttributes()
        self.fixed_point_type = fixed_point_type  # stored only; not read by this class
        for attr_name, attr_value in kwargs.items():
            self.atts.add(attr_name, attr_value)

        # Distinct tree ids, ascending.
        self.tree_ids = sorted(set(self.atts.nodes_treeids))
        self._initialize_indices()

    def _initialize_indices(self):
        """Build ``root_index`` and ``node_index`` from the node arrays.

        ``root_index`` maps each tree id to the position of the first node listed for
        that tree (presumably its root -- depends on the node ordering of the input).
        ``node_index`` maps ``(tree id, node id)`` to that node's position.
        """
        tree_col = self.atts.nodes_treeids
        node_col = self.atts.nodes_nodeids

        # Start every tree at "past the end", then keep the smallest position seen.
        self.root_index = dict.fromkeys(self.tree_ids, len(tree_col))
        for position, tree_id in enumerate(tree_col):
            if position < self.root_index[tree_id]:
                self.root_index[tree_id] = position

        self.node_index = {}
        for position, key in enumerate(zip(tree_col, node_col)):
            self.node_index[key] = position

    def generate_cairo_code(self):
        """Return Cairo statements declaring ``tree_ids``, ``root_index`` and ``node_index``."""
        sections = [
            f"let tree_ids: Span<usize> = array![{cairo_array(self.tree_ids)}].span();",
            self._generate_root_index_cairo(),
            self._generate_node_index_cairo(),
        ]
        return "\n".join(sections)

    def _generate_root_index_cairo(self):
        """Cairo code filling a ``Felt252Dict`` with tree id -> first-node position."""
        header = "let mut root_index: Felt252Dict<usize> = Default::default();\n"
        inserts = []
        for tid in self.tree_ids:
            inserts.append(f"    root_index.insert({tid}, {self.root_index[tid]});")
        return header + "\n".join(inserts)

    def _generate_node_index_cairo(self):
        """Cairo code filling a ``Felt252Dict`` keyed by pedersen_hash(tree id, node id)."""
        header = "let mut node_index: Felt252Dict<usize> = Default::default();\n"
        inserts = []
        for (tid, nid), position in self.node_index.items():
            key = pedersen_hash(int(tid), int(nid))
            inserts.append(f"    node_index.insert({key}, {position});")
        return header + "\n".join(inserts)

def generate_full_cairo_code(params, fixed_point_type="FP16x16"):
    """Generate a Cairo ``main`` that builds and runs an Orion TreeEnsembleClassifier.

    Args:
        params: Mapping of classifier attributes. Required keys: ``class_ids``,
            ``class_nodeids``, ``class_treeids``, ``class_weights``, ``classlabels``,
            ``nodes_falsenodeids``, ``nodes_featureids``,
            ``nodes_missing_value_tracks_true``, ``nodes_modes`` (each entry is
            pasted in as a ``NODE_MODES`` variant name), ``nodes_nodeids``,
            ``nodes_treeids``, ``nodes_truenodeids``, ``nodes_values`` and
            ``post_transform`` (pasted in as a ``POST_TRANSFORM`` variant name).
            ``base_values`` is optional: when it is missing, ``None`` or empty the
            classifier gets ``Option::None``. Every entry is also forwarded to
            ``TreeEnsemble``.
        fixed_point_type: Orion fixed-point type name (e.g. ``"FP16x16"``) used for
            ``class_weights``, ``nodes_values``, ``base_values`` and the input tensor.

    Returns:
        The generated Cairo source as a string.
    """
    ensemble = TreeEnsemble(fixed_point_type=fixed_point_type, **params)
    tree_specific_code = ensemble.generate_cairo_code()

    # base_values is optional. Test emptiness with len() rather than truthiness so a
    # NumPy array also works (bool() of a multi-element array raises ValueError).
    base_values = params.get('base_values')
    if base_values is not None and len(base_values) > 0:
        base_values_cairo = f"let base_values: Option<Span<{fixed_point_type}>> = Option::Some(array![{cairo_array(base_values, fixed_point_type, fixed_point_type)}].span());"
    else:
        base_values_cairo = f"let base_values: Option<Span<{fixed_point_type}>> = Option::None;"

    return f"""
use orion::numbers::{fixed_point_type};
use orion::operators::tensor::{{Tensor, TensorTrait, {fixed_point_type}Tensor, U32Tensor}};
use orion::operators::ml::tree_ensemble::core::{{NODE_MODES, TreeEnsembleAttributes, TreeEnsemble}};
use orion::operators::ml::tree_ensemble::tree_ensemble_classifier::{{TreeEnsembleClassifier, POST_TRANSFORM, TreeEnsembleClassifierTrait}};
use orion::operators::matrix::{{MutMatrix, MutMatrixImpl}};

fn main(X: Tensor<{fixed_point_type}>) {{
    let class_ids: Span<usize> = array![{cairo_array(params['class_ids'])}].span();
    let class_nodeids: Span<usize> = array![{cairo_array(params['class_nodeids'])}].span();
    let class_treeids: Span<usize> = array![{cairo_array(params['class_treeids'])}].span();
    let class_weights: Span<{fixed_point_type}> = array![{cairo_array(params['class_weights'], fixed_point_type, fixed_point_type)}].span();
    let classlabels: Span<usize> = array![{cairo_array(params['classlabels'])}].span();
    let nodes_falsenodeids: Span<usize> = array![{cairo_array(params['nodes_falsenodeids'])}].span();
    let nodes_featureids: Span<usize> = array![{cairo_array(params['nodes_featureids'])}].span();
    let nodes_missing_value_tracks_true: Span<usize> = array![{cairo_array(params['nodes_missing_value_tracks_true'])}].span();
    let nodes_modes: Span<NODE_MODES> = array![{', '.join(['NODE_MODES::' + x for x in params['nodes_modes']])}].span();
    let nodes_nodeids: Span<usize> = array![{cairo_array(params['nodes_nodeids'])}].span();
    let nodes_treeids: Span<usize> = array![{cairo_array(params['nodes_treeids'])}].span();
    let nodes_truenodeids: Span<usize> = array![{cairo_array(params['nodes_truenodeids'])}].span();
    let nodes_values: Span<{fixed_point_type}> = array![{cairo_array(params['nodes_values'], fixed_point_type, fixed_point_type)}].span();
    {base_values_cairo}
    let post_transform = POST_TRANSFORM::{params['post_transform']};

    {tree_specific_code}

    let atts = TreeEnsembleAttributes {{
        nodes_falsenodeids,
        nodes_featureids,
        nodes_missing_value_tracks_true,
        nodes_modes,
        nodes_nodeids,
        nodes_treeids,
        nodes_truenodeids,
        nodes_values
    }};

    let mut ensemble: TreeEnsemble<{fixed_point_type}> = TreeEnsemble {{
        atts, tree_ids, root_index, node_index
    }};

    let mut classifier: TreeEnsembleClassifier<{fixed_point_type}> = TreeEnsembleClassifier {{
        ensemble,
        class_ids,
        class_nodeids,
        class_treeids,
        class_weights,
        classlabels,
        base_values,
        post_transform
    }};

    let (labels, mut scores) = TreeEnsembleClassifierTrait::predict(ref classifier, X);
}}
    """

# Usage

In [5]:
# Tree-ensemble classifier attributes, keyed by the attribute names of the ONNX
# ai.onnx.ml TreeEnsembleClassifier operator (presumably exported from an ONNX model --
# TODO confirm the source). Facts readable from the data itself:
#   * 40 trees (nodes_treeids run 0-39) and 4 class labels (0-3);
#   * nodes_featureids go up to 5, so the input X presumably needs >= 6 feature columns.
# The nodes_* lists are index-aligned: the generator zips nodes_treeids with
# nodes_nodeids, and the other nodes_* lists presumably follow the same order (likewise
# the class_* lists among themselves -- verify against the exporter).
# generate_full_cairo_code emits class_weights, nodes_values and base_values as
# fixed-point literals; nodes_hitrates is not emitted at all. nodes_modes entries and
# post_transform are pasted in as NODE_MODES / POST_TRANSFORM variant names, so they must
# match the Orion enum variants exactly.
params = {
    "class_ids": [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3],
    "class_nodeids": [2, 3, 4, 2, 3, 4, 1, 3, 5, 7, 8, 1, 3, 4, 2, 3, 4, 2, 3, 4, 1, 3, 5, 7, 8, 1, 3, 4, 2, 3, 4, 2, 3, 4, 1, 3, 5, 7, 8, 1, 3, 4, 2, 3, 4, 2, 3, 4, 1, 3, 5, 7, 8, 1, 3, 4, 2, 3, 4, 2, 3, 4, 1, 3, 5, 7, 8, 1, 3, 4, 2, 3, 4, 2, 3, 4, 1, 3, 5, 7, 8, 1, 3, 4, 2, 3, 4, 2, 3, 4, 1, 3, 4, 1, 3, 4, 2, 3, 4, 2, 3, 4, 1, 3, 4, 1, 3, 4, 2, 3, 4, 2, 3, 4, 1, 3, 5, 7, 8, 1, 3, 4, 2, 3, 4, 2, 3, 4, 1, 3, 5, 7, 8, 1, 4, 5, 6],
    "class_treeids": [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 14, 14, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17, 18, 18, 18, 18, 18, 19, 19, 19, 20, 20, 20, 21, 21, 21, 22, 22, 22, 22, 22, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29, 30, 30, 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 34, 34, 35, 35, 35, 36, 36, 36, 37, 37, 37, 38, 38, 38, 38, 38, 39, 39, 39, 39],
    "class_weights": [0.529411792755127, -0.15789474546909332, -0.19223301112651825, 0.5596638917922974, -0.14482758939266205, -0.1902439147233963, -0.18987342715263367, -0.14482758939266205, 0.5186440944671631, 0.47368425130844116, -0.16363637149333954, -0.19304348528385162, -0.15789474546909332, 0.4909090995788574, 0.3363681435585022, -0.1447599232196808, -0.17823059856891632, 0.3431580364704132, -0.13070012629032135, -0.17663049697875977, -0.17551077902317047, -0.13070012629032135, 0.3307386338710785, 0.31947022676467896, -0.15036650002002716, -0.17892931401729584, -0.1447599232196808, 0.32325631380081177, 0.25804343819618225, -0.13300007581710815, -0.16763946413993835, 0.26184943318367004, -0.11850599944591522, -0.16595004498958588, -0.16456960141658783, -0.11850599944591522, 0.25308361649513245, 0.2477485090494156, -0.1386074423789978, -0.16838385164737701, -0.13300007581710815, 0.2489330917596817, 0.21561512351036072, -0.12224949151277542, -0.15908876061439514, 0.21918702125549316, -0.10746564716100693, -0.15714383125305176, -0.15556851029396057, -0.10746564716100693, 0.2106853425502777, 0.20655255019664764, -0.12789729237556458, -0.15995457768440247, -0.12224949151277542, 0.2071279138326645, 0.18905676901340485, -0.11225060373544693, -0.15180641412734985, 0.19316790997982025, -0.09727886319160461, -0.1495160609483719, -0.14771580696105957, -0.09727886319160461, 0.1839185208082199, 0.17971396446228027, -0.11792486160993576, -0.15284207463264465, -0.11225060373544693, 0.18020886182785034, 0.1706760823726654, -0.10285065323114395, -0.14523299038410187, 0.17563529312610626, -0.0878157690167427, -0.14254821836948395, -0.14045031368732452, -0.0878157690167427, 0.16516728699207306, 0.16054150462150574, -0.10849703848361969, -0.1464727818965912, -0.10285065323114395, 0.1610797941684723, 0.1568574607372284, -0.09398943930864334, -0.13895311951637268, 0.16103516519069672, 0.013162195682525635, -0.13586124777793884, -0.13336969912052155, 0.052140023559331894, 
0.11902859807014465, -0.14041899144649506, -0.09398943930864334, 0.14634716510772705, 0.14717331528663635, -0.0889815092086792, -0.1330275535583496, 0.15070605278015137, 0.011658085510134697, -0.12982763350009918, -0.12655934691429138, 0.046255260705947876, 0.10517136007547379, -0.13469882309436798, -0.0889815092086792, 0.1374071091413498, 0.13890844583511353, -0.0842740461230278, -0.12701822817325592, 0.1418359875679016, 0.010371334850788116, -0.12379498034715652, -0.11960335820913315, 0.10921790450811386, 0.07832330465316772, 0.10457883030176163, -0.09656156599521637, -0.12887467443943024, -0.0842740461230278, 0.12980103492736816, 0.10798902064561844, 0.03683459386229515, -0.12080138921737671, 0.036573074758052826, 0.1315067857503891, -0.11818119138479233, -0.1146732047200203, -0.026138588786125183, 0.12369491904973984, 0.11829245090484619, -0.08825768530368805, -0.12276258319616318, 0.03258060663938522, 0.0030335404444485903, 0.08653463423252106],
    "classlabels": [0, 1, 2, 3],
    "nodes_falsenodeids": [4, 3, 0, 0, 0, 4, 3, 0, 0, 0, 2, 0, 4, 0, 6, 0, 8, 0, 0, 2, 0, 4, 0, 0, 4, 3, 0, 0, 0, 4, 3, 0, 0, 0, 2, 0, 4, 0, 6, 0, 8, 0, 0, 2, 0, 4, 0, 0, 4, 3, 0, 0, 0, 4, 3, 0, 0, 0, 2, 0, 4, 0, 6, 0, 8, 0, 0, 2, 0, 4, 0, 0, 4, 3, 0, 0, 0, 4, 3, 0, 0, 0, 2, 0, 4, 0, 6, 0, 8, 0, 0, 2, 0, 4, 0, 0, 4, 3, 0, 0, 0, 4, 3, 0, 0, 0, 2, 0, 4, 0, 6, 0, 8, 0, 0, 2, 0, 4, 0, 0, 4, 3, 0, 0, 0, 4, 3, 0, 0, 0, 2, 0, 4, 0, 6, 0, 8, 0, 0, 2, 0, 4, 0, 0, 4, 3, 0, 0, 0, 4, 3, 0, 0, 0, 2, 0, 4, 0, 0, 2, 0, 4, 0, 0, 4, 3, 0, 0, 0, 4, 3, 0, 0, 0, 2, 0, 4, 0, 0, 2, 0, 4, 0, 0, 4, 3, 0, 0, 0, 4, 3, 0, 0, 0, 2, 0, 4, 0, 6, 0, 8, 0, 0, 2, 0, 4, 0, 0, 4, 3, 0, 0, 0, 4, 3, 0, 0, 0, 2, 0, 4, 0, 6, 0, 8, 0, 0, 2, 0, 6, 5, 0, 0, 0],
    "nodes_featureids": [3, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 3, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 3, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 3, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 3, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 3, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 3, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 5, 0, 0, 1, 0, 3, 0, 0, 3, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 5, 0, 0, 1, 0, 3, 0, 0, 3, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 1, 0, 0, 1, 0, 3, 0, 0, 3, 0, 0, 0, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 3, 0, 1, 0, 0, 1, 0, 5, 0, 0, 0, 0],
    "nodes_hitrates": [],
    "nodes_missing_value_tracks_true": [1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0],
    "nodes_modes": ['BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 
'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'LEAF', 'LEAF', 'BRANCH_LT', 'LEAF', 'BRANCH_LT', 'BRANCH_LT', 'LEAF', 'LEAF', 'LEAF'],
    "nodes_nodeids": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7, 8, 0, 1, 2, 3, 4, 5, 6],
    "nodes_treeids": [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 30, 30, 30, 30, 30, 31, 31, 31, 31, 31, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 34, 34, 34, 34, 34, 34, 34, 34, 34, 35, 35, 35, 35, 35, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39],
    "nodes_truenodeids": [1, 2, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 3, 0, 5, 0, 7, 0, 0, 1, 0, 3, 0, 0, 1, 2, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 3, 0, 5, 0, 7, 0, 0, 1, 0, 3, 0, 0, 1, 2, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 3, 0, 5, 0, 7, 0, 0, 1, 0, 3, 0, 0, 1, 2, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 3, 0, 5, 0, 7, 0, 0, 1, 0, 3, 0, 0, 1, 2, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 3, 0, 5, 0, 7, 0, 0, 1, 0, 3, 0, 0, 1, 2, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 3, 0, 5, 0, 7, 0, 0, 1, 0, 3, 0, 0, 1, 2, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 2, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 3, 0, 0, 1, 0, 3, 0, 0, 1, 2, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 3, 0, 5, 0, 7, 0, 0, 1, 0, 3, 0, 0, 1, 2, 0, 0, 0, 1, 2, 0, 0, 0, 1, 0, 3, 0, 5, 0, 7, 0, 0, 1, 0, 3, 4, 0, 0, 0],
    "nodes_values": [0.5, 0.5, 0.0, 0.0, 0.0, 0.5, 0.574999988079071, 0.0, 0.0, 0.0, 0.40833333134651184, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.5, 0.574999988079071, 0.0, 0.0, 0.0, 0.40833333134651184, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.5, 0.574999988079071, 0.0, 0.0, 0.0, 0.40833333134651184, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.5, 0.574999988079071, 0.0, 0.0, 0.0, 0.40833333134651184, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.5, 0.574999988079071, 0.0, 0.0, 0.0, 0.40833333134651184, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.5, 0.574999988079071, 0.0, 0.0, 0.0, 0.40833333134651184, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.5, 0.3916666507720947, 0.0, 0.0, 0.0, 0.40833333134651184, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.5, 0.3916666507720947, 0.0, 0.0, 0.0, 0.40833333134651184, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.5, 0.3916666507720947, 0.0, 0.0, 0.0, 0.40833333134651184, 0.0, 0.7250000238418579, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.5666666626930237, 0.0, 0.0, 0.0, 0.40833333134651184, 0.5, 0.0, 0.0, 0.0, 0.5, 0.0, 0.42500001192092896, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0, 0.5, 0.8416666984558105, 0.0, 0.0, 0.0],
    "base_values": [0.5, 0.5, 0.5, 0.5],
    "post_transform": "SOFTMAX",
}

# Orion fixed-point format used for every float literal in the generated Cairo code
# and for the input tensor type. Other choices include FP8x23, FP16x16 and FP32x32.
fixed_point_type = "FP64x64"
full_cairo_code = generate_full_cairo_code(params, fixed_point_type=fixed_point_type)
print(full_cairo_code)


use orion::numbers::FP64x64;
use orion::operators::tensor::{Tensor, TensorTrait, FP64x64Tensor, U32Tensor};
use orion::operators::ml::tree_ensemble::core::{NODE_MODES, TreeEnsembleAttributes, TreeEnsemble};
use orion::operators::ml::tree_ensemble::tree_ensemble_classifier::{TreeEnsembleClassifier, POST_TRANSFORM, TreeEnsembleClassifierTrait};
use orion::operators::matrix::{MutMatrix, MutMatrixImpl};

fn main(X: Tensor<FP64x64>) {
    let class_ids: Span<usize> = array![0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3].span();
    let class_nodeids: Span<usize> = array![2, 3, 4, 2, 3, 4, 1, 3, 5, 7, 8, 1, 3, 4, 2, 3, 4, 2, 3, 4, 1,