# Toki - Backend Example

In [1]:
from __future__ import annotations
from typing import NewType, Union
import dataclasses

import metadsl
import metadsl_rewrite
import metadsl_core
import pandas as pd

from toki import datatypes as dtypes
from toki.types import Expr
from toki.backend import Backend, BackendTranslator

In [2]:
class RegisterStrategy:
    """Collect rewrite rules and compose them into one metadsl strategy.

    Rules are appended in order with :meth:`register`; :meth:`get_rules`
    wraps them as ``StrategyRepeat(StrategyFold(StrategySequence(*rules)))``.
    """

    def __init__(self):
        # Per-instance list. A class-level ``[]`` would be shared by every
        # RegisterStrategy instance, so rules registered on one registry
        # would silently appear in all others.
        self._inner_strategy = []

    def register(self, fn):
        """Append ``fn`` to this registry's rules and return it unchanged.

        Returning ``fn`` allows ``register`` to be used as a decorator as
        well as a plain call.
        """
        self._inner_strategy.append(fn)
        return fn

    def get_rules(self):
        """Return the registered rules composed into a single strategy.

        The rules are placed in a ``StrategySequence`` in registration
        order, wrapped in ``StrategyFold`` and then ``StrategyRepeat``.
        """
        inner_strategy = metadsl_rewrite.StrategySequence(*self._inner_strategy)
        return metadsl_rewrite.StrategyRepeat(
            metadsl_rewrite.StrategyFold(
                inner_strategy
            )
        )


# Module-level registry that the binary-operation rules below are added to.
strategy = RegisterStrategy()

def compile(expr: Expr, strategies):
    """Rewrite ``expr`` by running ``metadsl_rewrite.execute`` with ``strategies``.

    NOTE: this module-level name shadows the builtin ``compile``.
    """
    result = metadsl_rewrite.execute(expr, strategies)
    return result


def translate(expr: Expr, strategies):
    """Translate ``expr`` by delegating to this module's ``compile``."""
    translated = compile(expr, strategies)
    return translated


# Numeric binary operations handled by the rewrite rules:
# operation name -> format template for the generated request string.
# Key order matters: it is the order in which rules are registered.
BIN_OPS = dict(
    add='{} + {}',
    truediv='{} / {}',
    floordiv='{} // {}',
    mod='{} % {}',
    mul='{} * {}',
    sub='{} - {}',
    pow='{} ** {}',
)

# Operation name -> rule function; the registration loop overwrites each
# entry once per dtype pair, so only the last-built function remains.
FN_MAP = dict()

def _primitive_for(tp):
    """Return ``int`` for ``int*`` names, ``float`` for ``float*`` names, else ``None``."""
    if tp.startswith('int'):
        return int
    if tp.startswith('float'):
        return float
    return None


def op_num_builder(op, tp_x, tp_y):
    """Build the rule function for one binary operation and one dtype pair.

    Parameters
    ----------
    op : str
        A key of ``BIN_OPS`` (e.g. ``'add'``). The dunder ``__<op>__`` is
        looked up on the left-hand toki dtype value.
    tp_x, tp_y : str
        Attribute names on ``toki.datatypes`` (e.g. ``'int8'``, ``'float64'``).
        Only names starting with ``'int'`` or ``'float'`` are accepted.

    Returns
    -------
    callable
        A function ``(x, y)`` named ``__<op>__`` whose ``x``/``y`` parameters
        are annotated ``int`` or ``float`` according to ``tp_x``/``tp_y``.
        Calling it returns a 2-tuple: the expression
        ``dtypes.<tp_x>(x).__<op>__(dtypes.<tp_y>(y))`` and a zero-argument
        lambda that renders ``BIN_OPS[op]`` with ``x`` and ``y``.

    Raises
    ------
    ValueError
        If ``tp_x`` or ``tp_y`` is neither an ``int*`` nor a ``float*`` name.
    """
    dunder_op = '__{}__'.format(op)
    prim_x = _primitive_for(tp_x)
    prim_y = _primitive_for(tp_y)

    if prim_x is None:
        raise ValueError('X types not recognized: {!r}.'.format(tp_x))
    if prim_y is None:
        # Report the right-hand operand so the failing argument is obvious.
        raise ValueError('Y types not recognized: {!r}.'.format(tp_y))

    def _fn(x, y):
        # dtypes are resolved lazily, at the time the rule function is called.
        dtype_x = getattr(dtypes, tp_x)
        dtype_y = getattr(dtypes, tp_y)
        return (
            getattr(dtype_x(x), dunder_op)(dtype_y(y)),
            lambda: BIN_OPS[op].format(x, y),
        )

    # The parameter annotations vary per dtype pair, so they are attached
    # explicitly rather than via one hand-written ``def`` per combination.
    # NOTE(review): presumably ``metadsl_rewrite.rule`` reads these hints to
    # type its wildcards — confirm. The ``str`` return hint does not match
    # the tuple actually returned; it is kept as declared until confirmed
    # unused.
    _fn.__annotations__ = {'x': prim_x, 'y': prim_y, 'return': str}
    _fn.__name__ = dunder_op
    return _fn

# Names of ``toki.datatypes`` attributes covered by the binary-op rules.
int_types = ('int8', 'int16', 'int32', 'int64')
# Each width listed exactly once: a duplicate entry would register the same
# rules twice, and a missing one (float32) would leave that dtype uncovered.
float_types = ('float16', 'float32', 'float64')

number_types = int_types + float_types

# Register one rewrite rule per (left dtype, right dtype, operation) triple,
# in number_types x number_types x BIN_OPS order.
for tp_x in number_types:
    for tp_y in number_types:
        for op in BIN_OPS:
            rule_fn = op_num_builder(op, tp_x, tp_y)
            FN_MAP[op] = rule_fn
            strategy.register(metadsl_rewrite.rule(rule_fn))
        


class TokiExample(Backend):
    """Example backend that renders toki expressions as request strings.

    ``compile`` rewrites an expression with the rules held by ``strategy``;
    ``execute`` prints the compiled string and returns an empty DataFrame.
    """

    # TODO: attach a ``translator: BackendTranslator`` (presumably a
    # ``TokiExampleTranslator``) once one is implemented.

    def __init__(self, strategy):
        # Rule registry; ``get_rules()`` is called again on every compile.
        self.strategy = strategy

    def connect(self) -> None:
        """Do nothing: this example backend has no connection to open."""

    def compile(self, expr) -> str:
        """Rewrite ``expr`` using the strategy's currently registered rules."""
        rules = self.strategy.get_rules()
        return translate(expr, rules)

    def execute(self, expr) -> pd.DataFrame:
        """Print the compiled request for ``expr`` and return an empty DataFrame."""
        compiled = self.compile(expr)
        print(compiled)
        return pd.DataFrame()


con = TokiExample(strategy)

# (operator, expected template) pairs, checked in this order for every
# dtype combination.
_SMOKE_CASES = (
    (lambda a, b: a + b, '{} + {}'),
    (lambda a, b: a - b, '{} - {}'),
    (lambda a, b: a * b, '{} * {}'),
    (lambda a, b: a / b, '{} / {}'),
    (lambda a, b: a // b, '{} // {}'),
    (lambda a, b: a ** b, '{} ** {}'),
    (lambda a, b: a % b, '{} % {}'),
)

# Smoke test: each operator on each (x dtype, y dtype) pair must compile to
# its source-like string, e.g. ``1 + 2`` for int/int or ``1.0 + 2`` for
# float/int.
for tp_x in number_types:
    _dtype_x = getattr(dtypes, tp_x)
    _dtype_primitive_x = int if tp_x.startswith('int') else float
    x = _dtype_primitive_x(1)
    x_expr = _dtype_x(x)

    for tp_y in number_types:
        _dtype_y = getattr(dtypes, tp_y)
        _dtype_primitive_y = int if tp_y.startswith('int') else float
        y = _dtype_primitive_y(2)
        y_expr = _dtype_y(y)

        for apply_op, template in _SMOKE_CASES:
            assert con.compile(apply_op(x_expr, y_expr)) == template.format(x, y)
print('[II] Done!')

[II] Done!
