# Dependencies

In [3]:
! pip install starknet-py

Looking in indexes: https://pypi.douban.com/simple/


# Converter

In [2]:
import numpy as np
from starknet_py.hash.utils import pedersen_hash

def float_to_fixed_point(value, integer_bits, fractional_bits):
    """Render a number as an Orion fixed-point struct literal in Cairo syntax.

    The magnitude is ``|value * 2**fractional_bits|`` truncated toward zero
    (Python ``int`` semantics); the sign flag is set for negative inputs whose
    magnitude is non-zero.

    Args:
        value: Number to encode (Python or NumPy scalar).
        integer_bits: Integer-part width; only used to build the type name
            (the ``16`` before the ``x`` in ``FP16x16``).
        fractional_bits: Number of fractional bits; determines the scale factor.

    Returns:
        A string such as ``"FP16x16 { mag: 65536, sign: false }"``.
    """
    scale_factor = 2**fractional_bits
    magnitude = abs(int(value * scale_factor))
    # A tiny negative value truncates to magnitude 0. Keeping sign=true would
    # emit a "negative zero" literal (mag 0, sign true), which a comparison that
    # checks mag and sign field-by-field treats as different from zero.
    # NOTE(review): confirm against Orion's FP eq/lt implementation.
    negative = bool(value < 0) and magnitude != 0
    return f"FP{integer_bits}x{fractional_bits} {{ mag: {magnitude}, sign: {str(negative).lower()} }}"

def cairo_array(arr, type_name="usize", fixed_point_type="FP16x16"):
    if type_name.startswith("FP"):
        integer_bits, fractional_bits = map(int, fixed_point_type[2:].split('x'))
        return ", ".join([float_to_fixed_point(x, integer_bits, fractional_bits) for x in arr])
    return ", ".join(map(str, arr))

class TreeEnsembleAttributes:
    """Dynamic attribute bag for tree-ensemble parameters.

    Every ``add`` call stores its value as an instance attribute. List values are
    first converted to NumPy arrays: float32 for the floating-point parameters
    named in ``add``, NumPy's inferred dtype for everything else. Attribute names
    are recorded in ``_names`` unless they end with ``"_as_tensor"``.
    """

    def __init__(self):
        self._names = []

    def add(self, name, value):
        """Store ``value`` under attribute ``name``, converting lists to arrays."""
        if name.endswith("_as_tensor") is False:
            self._names.append(name)

        if isinstance(value, list):
            float32_names = ("base_values", "class_weights", "nodes_values", "nodes_hitrates")
            target_dtype = np.float32 if name in float32_names else None
            value = np.array(value, dtype=target_dtype)

        setattr(self, name, value)

class TreeEnsemble:
    """Flattened tree-ensemble model plus the Cairo code that indexes it.

    Every keyword argument is stored on a ``TreeEnsembleAttributes`` instance
    (``self.atts``). ``nodes_treeids`` and ``nodes_nodeids`` must be among them,
    because the lookup tables are derived from those two columns.
    ``fixed_point_type`` is kept on the instance but not read by this class.
    """

    def __init__(self, fixed_point_type="FP16x16", **kwargs):
        self.atts = TreeEnsembleAttributes()
        self.fixed_point_type = fixed_point_type
        for attr_name in kwargs:
            self.atts.add(attr_name, kwargs[attr_name])

        # Distinct tree ids, ascending.
        self.tree_ids = sorted(set(self.atts.nodes_treeids))
        self._initialize_indices()

    def _initialize_indices(self):
        """Build ``root_index`` and ``node_index``.

        ``root_index`` maps tree id to the first row it occupies. ``node_index``
        maps ``(tree id, node id)`` to its row; a repeated pair keeps its last row.
        """
        tree_col = self.atts.nodes_treeids
        node_col = self.atts.nodes_nodeids

        # Start each tree at one-past-the-end, then lower it to the first row seen.
        first_row = dict.fromkeys(self.tree_ids, len(tree_col))
        self.root_index = first_row
        for row, tid in enumerate(tree_col):
            if row < first_row[tid]:
                first_row[tid] = row

        row_of = {}
        for row, pair in enumerate(zip(tree_col, node_col)):
            row_of[pair] = row
        self.node_index = row_of

    def generate_cairo_code(self):
        """Return Cairo statements declaring ``tree_ids``, ``root_index`` and ``node_index``."""
        sections = [
            "let tree_ids: Span<usize> = array![" + cairo_array(self.tree_ids) + "].span();",
            self._generate_root_index_cairo(),
            self._generate_node_index_cairo(),
        ]
        return "\n".join(sections)

    def _generate_root_index_cairo(self):
        """Cairo statements filling a ``Felt252Dict`` that maps tree id to root row."""
        inserts = []
        for tid in self.tree_ids:
            inserts.append(f"    root_index.insert({tid}, {self.root_index[tid]});")
        return "let mut root_index: Felt252Dict<usize> = Default::default();\n" + "\n".join(inserts)

    def _generate_node_index_cairo(self):
        """Cairo statements filling a ``Felt252Dict`` keyed by ``pedersen_hash(tree id, node id)``."""
        inserts = []
        for (tid, nid), row in self.node_index.items():
            key = pedersen_hash(int(tid), int(nid))
            inserts.append(f"    node_index.insert({key}, {row});")
        return "let mut node_index: Felt252Dict<usize> = Default::default();\n" + "\n".join(inserts)

def generate_full_cairo_code(params, fixed_point_type="FP16x16"):
    """Generate a Cairo ``main`` that builds and runs an Orion TreeEnsembleClassifier.

    Args:
        params: Mapping of tree-ensemble classifier attributes (``class_ids``,
            ``class_weights``, ``classlabels``, ``nodes_*``, ``post_transform``,
            ...). Every entry is also forwarded to ``TreeEnsemble``.
            ``base_values`` is optional; when it is absent, ``None`` or empty,
            ``Option::None`` is emitted for it.
        fixed_point_type: Orion fixed-point type name (e.g. ``"FP16x16"``) used for
            the input tensor and for every floating-point attribute.

    Returns:
        The generated Cairo source as a string.

    Raises:
        KeyError: if a required attribute read via ``params[...]`` below is
            missing. ``TreeEnsemble`` construction additionally needs
            ``nodes_treeids`` and ``nodes_nodeids``.
    """
    ensemble = TreeEnsemble(fixed_point_type=fixed_point_type, **params)
    tree_specific_code = ensemble.generate_cairo_code()

    # base_values is optional. Check it explicitly rather than by truthiness:
    # a missing key must not raise, and bool() of a multi-element NumPy array
    # raises ValueError.
    base_values = params.get('base_values')
    if base_values is not None and len(base_values) > 0:
        base_values_cairo = f"let base_values: Option<Span<{fixed_point_type}>> = Option::Some(array![{cairo_array(base_values, fixed_point_type, fixed_point_type)}].span());"
    else:
        base_values_cairo = f"let base_values: Option<Span<{fixed_point_type}>> = Option::None;"

    return f"""
use orion::numbers::{fixed_point_type};
use orion::operators::tensor::{{Tensor, TensorTrait, {fixed_point_type}Tensor, U32Tensor}};
use orion::operators::ml::tree_ensemble::core::{{NODE_MODES, TreeEnsembleAttributes, TreeEnsemble}};
use orion::operators::ml::tree_ensemble::tree_ensemble_classifier::{{TreeEnsembleClassifier, POST_TRANSFORM, TreeEnsembleClassifierTrait}};
use orion::operators::matrix::{{MutMatrix, MutMatrixImpl}};

fn main(X: Tensor<{fixed_point_type}>) {{
    let class_ids: Span<usize> = array![{cairo_array(params['class_ids'])}].span();
    let class_nodeids: Span<usize> = array![{cairo_array(params['class_nodeids'])}].span();
    let class_treeids: Span<usize> = array![{cairo_array(params['class_treeids'])}].span();
    let class_weights: Span<{fixed_point_type}> = array![{cairo_array(params['class_weights'], fixed_point_type, fixed_point_type)}].span();
    let classlabels: Span<usize> = array![{cairo_array(params['classlabels'])}].span();
    let nodes_falsenodeids: Span<usize> = array![{cairo_array(params['nodes_falsenodeids'])}].span();
    let nodes_featureids: Span<usize> = array![{cairo_array(params['nodes_featureids'])}].span();
    let nodes_missing_value_tracks_true: Span<usize> = array![{cairo_array(params['nodes_missing_value_tracks_true'])}].span();
    let nodes_modes: Span<NODE_MODES> = array![{', '.join(['NODE_MODES::' + x for x in params['nodes_modes']])}].span();
    let nodes_nodeids: Span<usize> = array![{cairo_array(params['nodes_nodeids'])}].span();
    let nodes_treeids: Span<usize> = array![{cairo_array(params['nodes_treeids'])}].span();
    let nodes_truenodeids: Span<usize> = array![{cairo_array(params['nodes_truenodeids'])}].span();
    let nodes_values: Span<{fixed_point_type}> = array![{cairo_array(params['nodes_values'], fixed_point_type, fixed_point_type)}].span();
    {base_values_cairo}
    let post_transform = POST_TRANSFORM::{params['post_transform']};

    {tree_specific_code}

    let atts = TreeEnsembleAttributes {{
        nodes_falsenodeids,
        nodes_featureids,
        nodes_missing_value_tracks_true,
        nodes_modes,
        nodes_nodeids,
        nodes_treeids,
        nodes_truenodeids,
        nodes_values
    }};

    let mut ensemble: TreeEnsemble<{fixed_point_type}> = TreeEnsemble {{
        atts, tree_ids, root_index, node_index
    }};

    let mut classifier: TreeEnsembleClassifier<{fixed_point_type}> = TreeEnsembleClassifier {{
        ensemble,
        class_ids,
        class_nodeids,
        class_treeids,
        class_weights,
        classlabels,
        base_values,
        post_transform
    }};

    let (labels, mut scores) = TreeEnsembleClassifierTrait::predict(ref classifier, X);
}}
    """

# Usage

In [4]:
# Hand-written single-tree classifier: 11 nodes, 6 leaves, 4 classes.
params = {
    # Each leaf contributes one weight per class, so class ids cycle 0..3 per leaf.
    "class_ids": [0, 1, 2, 3] * 6,
    "class_nodeids": [leaf for leaf in (2, 3, 5, 7, 9, 10) for _ in range(4)],
    "class_treeids": [0] * 24,
    "class_weights": [
        0.0, 1.0, 0.0, 0.0,  # leaf 2  -> class 1
        1.0, 0.0, 0.0, 0.0,  # leaf 3  -> class 0
        1.0, 0.0, 0.0, 0.0,  # leaf 5  -> class 0
        0.0, 0.0, 1.0, 0.0,  # leaf 7  -> class 2
        0.0, 0.0, 1.0, 0.0,  # leaf 9  -> class 2
        0.0, 0.0, 0.0, 1.0,  # leaf 10 -> class 3
    ],
    "classlabels": [0, 1, 2, 3],
    "nodes_falsenodeids": [4, 3, 0, 0, 6, 0, 8, 0, 10, 0, 0],
    "nodes_featureids": [2, 0, 0, 0, 0, 0, 4, 0, 3, 0, 0],
    "nodes_hitrates": [1.0] * 11,
    "nodes_missing_value_tracks_true": [0] * 11,
    "nodes_modes": ['BRANCH_LEQ', 'BRANCH_LEQ', 'LEAF', 'LEAF', 'BRANCH_LEQ', 'LEAF', 'BRANCH_LEQ', 'LEAF', 'BRANCH_LEQ', 'LEAF', 'LEAF'],
    "nodes_nodeids": list(range(11)),
    "nodes_treeids": [0] * 11,
    "nodes_truenodeids": [1, 2, 0, 0, 5, 0, 7, 0, 9, 0, 0],
    "nodes_values": [0.5, 0.574999988079071, 0.0, 0.0, 0.40833333134651184, 0.0, 0.5, 0.0, 0.5, 0.0, 0.0],
    "base_values": [],
    "post_transform": "NONE",
}

fixed_point_type = "FP16x16"  # Other supported choices: FP8x23, FP32x32, FP64x64
full_cairo_code = generate_full_cairo_code(params, fixed_point_type)
print(full_cairo_code)


use orion::numbers::FP16x16;
use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor};
use orion::operators::ml::tree_ensemble::core::{NODE_MODES, TreeEnsembleAttributes, TreeEnsemble};
use orion::operators::ml::tree_ensemble::tree_ensemble_classifier::{TreeEnsembleClassifier, POST_TRANSFORM, TreeEnsembleClassifierTrait};
use orion::operators::matrix::{MutMatrix, MutMatrixImpl};

fn main(X: Tensor<FP16x16>) {
    let class_ids: Span<usize> = array![0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3].span();
    let class_nodeids: Span<usize> = array![2, 2, 2, 2, 3, 3, 3, 3, 5, 5, 5, 5, 7, 7, 7, 7, 9, 9, 9, 9, 10, 10, 10, 10].span();
    let class_treeids: Span<usize> = array![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0].span();
    let class_weights: Span<FP16x16> = array![FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { m