<a href="https://colab.research.google.com/github/scardenol/DecisionTrees/blob/main/symbolic_taylor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Functions

In [125]:
import sympy as sp
from sympy import symbols, init_printing, Function, sympify, Derivative, simplify, ordered
from sympy import factorial, Matrix, prod
init_printing(use_latex='mathjax') # Printing preferences
import itertools


def Taylor_polynomial_sympy(function_expression: 'symbolic',
                            variable_list: list,
                            evaluation_point: list,
                            degree: int) -> 'symbolic':
  """Build the multivariate Taylor polynomial of a sympy expression.

  The result is the sum, over every multi-index (k_1, ..., k_n) whose total
  order k_1 + ... + k_n is at most ``degree``, of

      d^(k_1+...+k_n) f / (dx_1^k_1 ... dx_n^k_n) evaluated at the point
      / (k_1! * ... * k_n!)
      * (x_1 - c_1)^k_1 * ... * (x_n - c_n)^k_n

  Args:
    function_expression: sympy expression to expand; it must provide
      ``.diff`` and the derivative must provide ``.subs``.
    variable_list: sequence of the expansion variables x_1..x_n.
    evaluation_point: sequence with the coordinates c_1..c_n of the expansion
      point, in the same order as ``variable_list``.
    degree: maximum total order of the terms kept in the polynomial.

  Returns:
    The Taylor polynomial as a sympy expression (the plain integer 0 when no
    multi-index qualifies, i.e. for a negative ``degree``).

  Note:
    ``variable_list`` and ``evaluation_point`` must have the same length; the
    Matrix subtraction below fails on a size mismatch.
  """
  # (variable, coordinate) pairs used to evaluate each derivative at the point
  point_coordinates = list(zip(variable_list, evaluation_point))

  # Displacements (x_k - c_k). They do not depend on the term, so they are
  # built once here rather than once per term and per variable.
  displacements = Matrix(variable_list) - Matrix(evaluation_point)

  # Every tuple of per-variable derivative orders with total order <= degree
  deriv_orders = [orders
                  for orders in itertools.product(range(degree + 1), repeat=len(variable_list))
                  if sum(orders) <= degree]

  polynomial = 0
  for orders in deriv_orders:
    # Flatten to the (var, order, var, order, ...) argument form of .diff
    diff_arguments = list(itertools.chain.from_iterable(zip(variable_list, orders)))
    partial_derivative_at_point = function_expression.diff(*diff_arguments).subs(point_coordinates)  # e.g. df/(dx*dy**2)
    denominator = prod([factorial(order) for order in orders])  # e.g. (1! * 2!)
    distances_powered = prod([displacements[j] ** order for j, order in enumerate(orders)])  # e.g. (x-x0)*(y-y0)**2
    polynomial += partial_derivative_at_point / denominator * distances_powered
  return polynomial


def Model_taylor_sympy(taylor_serie: 'symbolic') -> '(symbolic, symbolic)':
  """Turn a symbolic Taylor series into a parametric model.

  The series is split into its additive terms. In every term that contains a
  ``Derivative``, that derivative is replaced by a fresh parameter symbol
  (a0, a1, ... in term order) and the term is simplified. Afterwards the
  derivative-free term (the f(c) value of a Taylor series) is replaced by one
  more parameter, which therefore carries the highest index.

  Args:
    taylor_serie: the series, as a sympy expression or anything ``sympify``
      accepts.

  Returns:
    A tuple ``(M, A)``: ``M`` is the list of model terms (``sum(M)`` is the
    model) and ``A`` is the list of parameter symbols in creation order.
    If the series contains no derivative at all, ``M`` is the unchanged list
    of terms and ``A`` is empty.

  Notes:
    * Only the first derivative of a term (in sympy's canonical ordering) is
      substituted; terms holding several distinct derivatives keep the rest.
    * If several terms are derivative-free, only the last one is replaced; if
      none is, no constant parameter is created.
  """
  expression = sympify(taylor_serie, evaluate=False)
  # Add.make_args returns the summands of a sum and wraps any other
  # expression in a 1-tuple. Plain `.args` would instead split a single-term
  # series into its factors.
  M = list(sp.Add.make_args(expression))

  # Derivatives found in each term, in a deterministic order
  derivatives_per_term = [list(ordered(term.atoms(Derivative))) for term in M]

  A = []
  if not any(derivatives_per_term):  # no derivatives anywhere in the series
    return M, A

  constant_index = None
  for index, derivatives in enumerate(derivatives_per_term):
    if not derivatives:
      constant_index = index  # remember the (last) derivative-free term
      continue
    parameter = symbols('a' + str(len(A)))
    A.append(parameter)
    M[index] = simplify(M[index].subs(derivatives[0], parameter))

  # Replace the f(c) term wherever it sits in the argument order, instead of
  # blindly overwriting the last term (which may be a derivative term).
  if constant_index is not None:
    parameter = symbols('a' + str(len(A)))
    A.append(parameter)
    M[constant_index] = parameter

  return M, A

# Implementation

In [124]:
# Build the model F(x0, x1, ..., x_{n-1}), where n is n_vars

n_vars = 2  # Num of variables = num of columns
deg = 4  # max degree of the Taylor polynomial
vars = sp.symbols(f'x0:{n_vars}')  # symbols x0 .. x_{n_vars-1}
center = [0 for _ in vars]  # expand around the origin

f = sp.Function('f')
fun = f(sum(vars))  # f(x0 + x1 + ... + x_{n-1})
F = Taylor_polynomial_sympy(fun, vars, center, deg)

# Replace the derivatives by parameters to obtain the model terms
M, A = Model_taylor_sympy(F)
model = sum(M)
model

$$
a_0 x_0 + a_1 x_1 + \frac{a_{10} x_0^2 x_1}{2} + \frac{a_{11} x_0^2 x_1^2}{4} + \frac{a_{12} x_0 x_1^3}{6} + \frac{a_{13} x_0^3 x_1}{6} + a_{14} + \frac{a_2 x_0^2}{2} + \frac{a_3 x_1^2}{2} + \frac{a_4 x_0^3}{6} + \frac{a_5 x_1^3}{6} + \frac{a_6 x_0^4}{24} + \frac{a_7 x_1^4}{24} + a_8 x_0 x_1 + \frac{a_9 x_0 x_1^2}{2}
$$

In [121]:
# Argument order of the compiled function: the variables first, then the parameters
par = [*vars, *A]
display(par)
# Compile the symbolic model into a numeric function (this rebinds the name `fun`)
fun = sp.lambdify(par, model, 'numpy')

[x₀, x₁, a₀, a₁, a₂, a₃, a₄, a₅, a₆, a₇, a₈, a₉, a₁₀, a₁₁, a₁₂, a₁₃, a₁₄]