# In [None]:
import keyword
import os
import re
from dataclasses import dataclass
from typing import List, Optional

import jinja2

# Base distribution names (no ``_lpdf`` suffix) that
# maybe_add_sampling_distribution_for_function consults when deciding whether
# to also emit left-, right- and left-and-right-censored wrappers
# (for survival analysis).
# NOTE(review): 'digamma' does not look like a Stan distribution (there is
# presumably no digamma_lpdf), so that entry would never match — confirm.
SURVIVAL_ANALYSIS_LIST = """
    normal skew_normal cauchy student_t logistic gumbel
    inv_chi_square scaled_inv_chi_square inv_gamma lognormal
    weibull frechet rayleigh wiener pareto pareto_type_2
    exponential exp_mod_normal double_exponential gamma digamma
""".split()

@dataclass
class StanFunction:
    """One Stan function signature plus the Python method name exposing it.

    Instances are passed to the code-generation template, which renders each
    one as a method of the generated library class.
    """

    py_name: str  # name of the generated Python method
    stan_name: str  # Stan function name passed to the generated FunctionCall
    arguments: List[str]  # argument names, in call order
    argument_types: List[str]  # Stan type of each argument, parallel to `arguments`
    # Stan return type. Derived sampling-statement wrappers are built with
    # None here, so the annotation must allow it.
    return_type: Optional[str]

def full_paths_in_dir(directory):
    """Return every entry of *directory* joined onto *directory*.

    Nothing is filtered out, and entries appear in whatever order
    ``os.listdir`` yields them.
    """
    joined_paths = []
    for entry_name in os.listdir(directory):
        joined_paths.append(os.path.join(directory, entry_name))
    return joined_paths


def rename_builtins(name):
    """Return *name* made safe to use as a parameter of a generated Python method.

    Argument names are copied from the Stan docs into generated ``def``
    statements. Any Python keyword (``lambda``, ``in``, ...) would make that
    code a syntax error, and ``self`` would duplicate the method's own first
    parameter. Such names get a trailing underscore; every other name is
    returned unchanged.
    """
    if keyword.iskeyword(name) or name == 'self':
        return name + '_'
    return name


def extract_functions_from_page_text(text):
    """Parse the function signatures annotated in a functions-reference page.

    Signatures are read from HTML comments of the form
    ``<!-- return_type; name; (type arg, type arg | type arg); -->``. Arguments
    are split on ``,`` and ``|``, and each is split on its last space into
    ``(type, name)``. An empty argument list ``()`` yields a zero-argument
    function.

    A signature is skipped when it cannot become a plain Python method:
    ``map_rect``, or any argument that does not split into exactly a type and a
    name. The latter covers variadic ``...`` arguments and types containing
    commas such as ``array[,] real``, which the comma split breaks apart.

    Returns:
        A list of StanFunction, in the order the signatures appear in *text*.
    """
    # `[^)]*` (not `+`) so that zero-argument signatures like `()` also match.
    function_tuples = re.findall(
        r'<!--\s*([^;]+)\s*;\s*([^;]+)\s*;\s*\(([^)]*)\);\s*-->', text)
    functions = []

    for raw_return_type, raw_name, args_text in function_tuples:
        # `[^;]+` is greedy and can capture whitespace before the `;`.
        return_type = raw_return_type.strip()
        fun = raw_name.strip()
        if fun == 'map_rect':
            continue

        if args_text.strip():
            typed_args = [a.strip().rsplit(' ', 1) for a in re.split(r"[,\|]", args_text)]
        else:
            typed_args = []

        # Every argument must be exactly (type, name); anything else would have
        # raised IndexError below or produced an unusable method.
        if any(len(typed_arg) != 2 for typed_arg in typed_args):
            continue

        arg_types = [arg_type for arg_type, _ in typed_args]
        args = [rename_builtins(arg_name) for _, arg_name in typed_args]

        functions.append(StanFunction(
            py_name=fun,
            stan_name=fun,
            arguments=args,
            argument_types=arg_types,
            return_type=return_type,
        ))

    return functions
    


def extract_functions_from_page(path):
    """Read the page at *path* and return the StanFunction signatures it annotates.

    The file is decoded as UTF-8 explicitly, so the result does not depend on
    the platform's default locale encoding.
    """
    with open(path, encoding='utf-8') as page_file:
        contents = page_file.read()

    return extract_functions_from_page_text(contents)

# The functions-reference ``.Rmd`` sources. They are sorted because
# ``os.listdir`` order is arbitrary, and later duplicate signatures are
# resolved first-wins, so page order decides which signature is generated.
pages = sorted(
    path
    for path in full_paths_in_dir('stan-docs/functions-reference')
    if path.endswith('.Rmd')
)

# Jinja2 template for the generated module. It defines a
# ``StanFunctionsLibrary`` class whose ``_function`` helper does four things:
# converts each argument with ``python_to_ast``, wraps them in a
# ``FunctionCall``, counts calls per Stan function in ``self.functions``, and
# returns a ``PyExpr``. One method is rendered per entry of ``functions``:
# named ``py_name``, taking ``self`` plus ``arguments``, and calling
# ``stan_name``. Trailing whitespace is kept out of the template so the
# generated file is lint-clean.
# NOTE(review): nothing here initialises ``self.functions``; presumably a
# subclass or caller sets it — confirm.
template = jinja2.Template("""
from py_expr import PyExpr, python_to_ast
from stan_ast import FunctionCall

class StanFunctionsLibrary:
    def _function(self, function_name, args):
        arguments_as_ast = [python_to_ast(arg) for arg in args]
        function_call_ast = FunctionCall(function_name, arguments_as_ast)
        self.functions[function_name] = self.functions.get(function_name, 0) + 1

        return PyExpr(self, function_call_ast)

{% for f in functions %}
    def {{ f.py_name }}({% for arg in ['self'] + f.arguments %}{{ arg }}{{ ", " if not loop.last else "" }}{% endfor %}):
        return self._function('{{ f.stan_name }}', [{% for arg in f.arguments %}{{ arg }}{{ ", " if not loop.last else "" }}{% endfor %}])
{% endfor %}
""")

def to_camel_case(snake_str):
    """Convert ``snake_case`` text to ``CamelCase``.

    The text is lower-cased and split on underscores. Each piece has its first
    character upper-cased, and the pieces are concatenated with the underscores
    dropped.
    """
    pieces = snake_str.lower().split("_")
    return "".join(map(str.capitalize, pieces))

def _censored_variants(function, base_name, dist_name, suffix, unnormalized_suffix,
                       censoring, event_args):
    """Build the four censored-data wrappers of one distribution for one scheme.

    Returns, in order: the sampling-statement wrapper, the normalized log
    density/mass, the unnormalized log density/mass, and the RNG. All are
    Stan-named ``<base_name>_with_<censoring>...``.

    Args:
        function: the ``*_lpdf`` / ``*_lpmf`` signature being expanded. Its first
            argument (the variate in Stan's density signatures) is dropped for
            the sampling and RNG wrappers.
        base_name: ``function.stan_name`` without its density suffix.
        dist_name: CamelCase form of ``base_name``.
        suffix: ``'_lpdf'`` or ``'_lpmf'``.
        unnormalized_suffix: ``'_lupdf'`` or ``'_lupmf'``, matching ``suffix``.
        censoring: scheme label, e.g. ``'right_censoring'``.
        event_args: names of the extra event arguments; each is typed ``'reals'``.
    """
    stem = base_name + "_with_" + censoring
    event_types = ['reals'] * len(event_args)

    # Sampling statements and RNGs take only the distribution parameters;
    # the density functions also take the first (variate) argument.
    param_names = function.arguments[1:] + event_args
    param_types = function.argument_types[1:] + event_types
    density_names = function.arguments + event_args
    density_types = function.argument_types + event_types

    return [
        StanFunction(
            # NOTE(review): unlike every other wrapper, the sampling wrapper's
            # Python name starts with the CamelCase distribution name
            # (e.g. ``Normal_with_right_censoring``); kept as-is — confirm intent.
            py_name=dist_name + "_with_" + censoring,
            stan_name=stem,
            arguments=list(param_names),
            argument_types=list(param_types),
            return_type=None,
        ),
        StanFunction(
            py_name=stem + suffix,
            stan_name=stem + suffix,
            arguments=list(density_names),
            argument_types=list(density_types),
            return_type='reals',
        ),
        StanFunction(
            py_name=stem + unnormalized_suffix,
            stan_name=stem + unnormalized_suffix,
            arguments=list(density_names),
            argument_types=list(density_types),
            return_type='reals',
        ),
        StanFunction(
            py_name=stem + "_rng",
            stan_name=stem + "_rng",
            arguments=list(param_names),
            argument_types=list(param_types),
            return_type='reals',
        ),
    ]


def maybe_add_sampling_distribution_for_function(functions, function):
    """Append *function* to *functions*, plus any wrappers derived from it.

    A function whose Stan name does not end in ``_lpdf``/``_lpmf`` is appended
    unchanged. For a density/mass function, two items are appended:

    - the function itself;
    - a sampling-statement wrapper named after the base distribution
      (``normal_lpdf`` -> ``normal``), taking every argument except the first.

    If the base name is in ``SURVIVAL_ANALYSIS_LIST``, left-, right- and
    left-and-right-censored wrappers follow. Each scheme adds four items:
    sampling, ``_lpdf``/``_lpmf``, ``_lupdf``/``_lupmf`` and ``_rng``.

    Mutates *functions* in place and returns None.
    """
    stan_name = function.stan_name
    if not stan_name.endswith(('_lpdf', '_lpmf')):
        functions.append(function)
        return

    suffix = stan_name[-5:]
    # Give the unnormalized variants their own suffix. Reusing `suffix` made
    # them identical to the normalized ones, so de-duplication dropped them.
    unnormalized_suffix = '_lupdf' if suffix == '_lpdf' else '_lupmf'
    base_name = stan_name[:-len(suffix)]

    f_sampling = StanFunction(
        py_name=base_name,
        stan_name=base_name,
        arguments=function.arguments[1:],
        argument_types=function.argument_types[1:],
        return_type=None,
    )
    new_functions = [function, f_sampling]

    # SURVIVAL_ANALYSIS_LIST holds base names such as 'normal', never the
    # suffixed 'normal_lpdf', so the base name is what must be looked up.
    if base_name in SURVIVAL_ANALYSIS_LIST:
        dist_name = to_camel_case(base_name)
        censoring_schemes = (
            ('left_censoring', ['event']),
            ('right_censoring', ['event']),
            ('left_and_right_censoring', ['event_left', 'event_right']),
        )
        for censoring, event_args in censoring_schemes:
            new_functions.extend(_censored_variants(
                function, base_name, dist_name, suffix, unnormalized_suffix,
                censoring, event_args,
            ))

    functions.extend(new_functions)

# Collect every signature from every page, then drop Stan operator
# signatures (names starting with 'operator').
all_functions = []
for page in pages:
    all_functions.extend(extract_functions_from_page(page))

all_functions = [
    function for function in all_functions
    if not function.stan_name.startswith('operator')
]

# Expand density/mass functions with their derived sampling wrappers.
functions = []
for function in all_functions:
    maybe_add_sampling_distribution_for_function(functions, function)

# Python methods cannot be overloaded, so keep only the first signature seen
# for each Stan name.
function_dict = {}
for function in functions:
    function_dict.setdefault(function.stan_name, function)

functions = sorted(function_dict.values(), key=lambda f: f.stan_name.lower())

# `with` guarantees the generated module is flushed and closed.
with open('stan_functions_library.py', 'w', encoding='utf-8') as output_file:
    output_file.write(template.render(functions=functions))