In [50]:
import numpy as np
import sympy as sp
from sympy import Symbol
from IPython.display import display



# Render SymPy objects passed to display() as MathJax LaTeX in the notebook.
sp.init_printing(use_latex="mathjax")



In [51]:
def partition1(max_range, S):
    """Enumerate exponent vectors bounded per coordinate and in total degree.

    Returns every integer vector ``e`` with ``0 <= e[i] <= max_range[i]`` for
    each coordinate ``i`` and ``sum(e) <= S``, one vector per row, in
    lexicographic order (last coordinate varies fastest).

    Parameters
    ----------
    max_range : array_like of int
        Inclusive per-coordinate upper bounds; its length is the number of
        columns of the result.
    S : int
        Inclusive upper bound on the sum of each row.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(n_rows, len(max_range))``.

    Notes
    -----
    The full box ``prod(max_range + 1)`` is materialised before filtering,
    so memory grows with the box size, not with the number of rows kept.
    """
    upper = np.asarray(max_range, dtype=int)
    # grid[i] holds coordinate i at every point of [0, upper[0]] x ... x [0, upper[-1]].
    grid = np.indices(upper + 1)
    within_degree = grid.sum(axis=0) <= S
    # Masking the trailing axes yields (k, n_rows); transpose to one vector per row.
    return grid[:, within_degree].T

# Exponent rows for the action-weighted monomials built in the next cell.
# Columns follow the state order (z, zd, thetad, h): per-variable caps
# [1, 2, 3, 4], and a total-degree cap of 2 across all four.
poly_basis = partition1(max_range=np.array([1, 2, 3, 4]), S=2)
print(poly_basis)

[[0 0 0 0]
 [0 0 0 1]
 [0 0 0 2]
 [0 0 1 0]
 [0 0 1 1]
 [0 0 2 0]
 [0 1 0 0]
 [0 1 0 1]
 [0 1 1 0]
 [0 2 0 0]
 [1 0 0 0]
 [1 0 0 1]
 [1 0 1 0]
 [1 1 0 0]]


In [52]:
# Symbolic variables. The SymPy symbol names (not the Python names) are what
# ccode() prints, so they become identifiers in the generated C++ below.
# NOTE(review): Python names `rot` / `drot` are bound to symbols 'thetad' / 'h';
# confirm this mapping is intended - the symbol strings are what matter downstream.
z = Symbol('z')
dz = Symbol('zd')
rot = Symbol('thetad')
drot = Symbol('h')
action = Symbol('a')

state = [z, dz, rot, drot]
display(state)

# One action-weighted monomial per exponent row p of poly_basis:
#     a * prod_i state[i]**p[i]
# Works for any state length (the product is no longer hard-coded to 4 terms),
# and the entries are collected in a single pass instead of repeated np.append,
# which copies the whole array on every iteration.
assert poly_basis.shape[1] == len(state), "poly_basis columns must match len(state)"
extra_basis = np.array(
    [sp.Mul(*[var**int(exp) for var, exp in zip(state, p)], action) for p in poly_basis],
    dtype=object,
)

# Feature vector: the raw state followed by the action-weighted monomials.
basis = np.hstack((state, extra_basis))

basis_sp = sp.Matrix(basis)
state_sp = sp.Matrix(state)
display(basis_sp)
display(state_sp)

# d(basis)/d(state): one row per basis entry, one column per state variable.
dbasis_dx = basis_sp.jacobian(state_sp)
display(dbasis_dx)

# d(basis)/d(action): a single column.
dbasis_du = basis_sp.jacobian(sp.Matrix([action]))
display(dbasis_du)

[z, zd, thetad, h]

⎡     z     ⎤
⎢           ⎥
⎢    zd     ⎥
⎢           ⎥
⎢  thetad   ⎥
⎢           ⎥
⎢     h     ⎥
⎢           ⎥
⎢     a     ⎥
⎢           ⎥
⎢    a⋅h    ⎥
⎢           ⎥
⎢      2    ⎥
⎢   a⋅h     ⎥
⎢           ⎥
⎢ a⋅thetad  ⎥
⎢           ⎥
⎢a⋅h⋅thetad ⎥
⎢           ⎥
⎢         2 ⎥
⎢ a⋅thetad  ⎥
⎢           ⎥
⎢   a⋅zd    ⎥
⎢           ⎥
⎢  a⋅h⋅zd   ⎥
⎢           ⎥
⎢a⋅thetad⋅zd⎥
⎢           ⎥
⎢       2   ⎥
⎢   a⋅zd    ⎥
⎢           ⎥
⎢    a⋅z    ⎥
⎢           ⎥
⎢   a⋅h⋅z   ⎥
⎢           ⎥
⎢a⋅thetad⋅z ⎥
⎢           ⎥
⎣  a⋅z⋅zd   ⎦

⎡  z   ⎤
⎢      ⎥
⎢  zd  ⎥
⎢      ⎥
⎢thetad⎥
⎢      ⎥
⎣  h   ⎦

⎡   1         0          0          0    ⎤
⎢                                        ⎥
⎢   0         1          0          0    ⎥
⎢                                        ⎥
⎢   0         0          1          0    ⎥
⎢                                        ⎥
⎢   0         0          0          1    ⎥
⎢                                        ⎥
⎢   0         0          0          0    ⎥
⎢                                        ⎥
⎢   0         0          0          a    ⎥
⎢                                        ⎥
⎢   0         0          0        2⋅a⋅h  ⎥
⎢                                        ⎥
⎢   0         0          a          0    ⎥
⎢                                        ⎥
⎢   0         0         a⋅h      a⋅thetad⎥
⎢                                        ⎥
⎢   0         0      2⋅a⋅thetad     0    ⎥
⎢                                        ⎥
⎢   0         a          0          0    ⎥
⎢                                        ⎥
⎢   0        a⋅h         0         a⋅zd  ⎥
⎢          

⎡    0    ⎤
⎢         ⎥
⎢    0    ⎥
⎢         ⎥
⎢    0    ⎥
⎢         ⎥
⎢    0    ⎥
⎢         ⎥
⎢    1    ⎥
⎢         ⎥
⎢    h    ⎥
⎢         ⎥
⎢    2    ⎥
⎢   h     ⎥
⎢         ⎥
⎢ thetad  ⎥
⎢         ⎥
⎢h⋅thetad ⎥
⎢         ⎥
⎢       2 ⎥
⎢ thetad  ⎥
⎢         ⎥
⎢   zd    ⎥
⎢         ⎥
⎢  h⋅zd   ⎥
⎢         ⎥
⎢thetad⋅zd⎥
⎢         ⎥
⎢     2   ⎥
⎢   zd    ⎥
⎢         ⎥
⎢    z    ⎥
⎢         ⎥
⎢   h⋅z   ⎥
⎢         ⎥
⎢thetad⋅z ⎥
⎢         ⎥
⎣  z⋅zd   ⎦

In [53]:
from sympy.printing import ccode
def cpp_generator(mat, name, label):
    """Emit a C++ member-function definition that assigns every entry of ``mat``.

    The generated text has the form::

        void <label>::<name>_update(){
        <name>(0, 0) = <expr>;
        ...
        };

    Parameters
    ----------
    mat : sympy.Matrix
        Matrix whose entries are converted to C expressions with ``ccode``.
    name : str
        Identifier of the C++ object being assigned via ``name(m, n)``; also
        used to form the method name ``<name>_update``.
    label : str
        Qualifier placed before ``::`` in the method definition.

    Returns
    -------
    str
        The C++ source, starting with a newline. Entries are emitted column by
        column (column index outer, row index inner).
    """
    n_rows, n_cols = mat.shape
    # Collect pieces and join once, rather than growing a string with += in a nested loop.
    pieces = ['\nvoid {0}::{1}_update()'.format(label, name) + '{']
    for n in range(n_cols):
        for m in range(n_rows):
            pieces.append('\n{0}({1}, {2}) = {3};'.format(name, m, n, ccode(mat[m, n])))
    pieces.append('\n};')
    return ''.join(pieces)

In [54]:

# Generate the C++ routine that fills `extra_basis(i, 0)` from basis_sp.
# NOTE(review): basis_sp also contains the raw state entries, not only the
# action-weighted terms - confirm the C++ name 'extra_basis' is intended.
mat, name = basis_sp, 'extra_basis'
code = cpp_generator(mat, name, label=name)
print(code)


void extra_basis::extra_basis_update(){
extra_basis(0, 0) = z;
extra_basis(1, 0) = zd;
extra_basis(2, 0) = thetad;
extra_basis(3, 0) = h;
extra_basis(4, 0) = a;
extra_basis(5, 0) = a*h;
extra_basis(6, 0) = a*pow(h, 2);
extra_basis(7, 0) = a*thetad;
extra_basis(8, 0) = a*h*thetad;
extra_basis(9, 0) = a*pow(thetad, 2);
extra_basis(10, 0) = a*zd;
extra_basis(11, 0) = a*h*zd;
extra_basis(12, 0) = a*thetad*zd;
extra_basis(13, 0) = a*pow(zd, 2);
extra_basis(14, 0) = a*z;
extra_basis(15, 0) = a*h*z;
extra_basis(16, 0) = a*thetad*z;
extra_basis(17, 0) = a*z*zd;
};


In [55]:
# Generate the C++ routine that fills the Jacobian d(basis)/d(state) as `dbdx(m, n)`.
mat, name = dbasis_dx, 'dbdx'
code = cpp_generator(mat, name, label=name)
print(code)


void dbdx::dbdx_update(){
dbdx(0, 0) = 1;
dbdx(1, 0) = 0;
dbdx(2, 0) = 0;
dbdx(3, 0) = 0;
dbdx(4, 0) = 0;
dbdx(5, 0) = 0;
dbdx(6, 0) = 0;
dbdx(7, 0) = 0;
dbdx(8, 0) = 0;
dbdx(9, 0) = 0;
dbdx(10, 0) = 0;
dbdx(11, 0) = 0;
dbdx(12, 0) = 0;
dbdx(13, 0) = 0;
dbdx(14, 0) = a;
dbdx(15, 0) = a*h;
dbdx(16, 0) = a*thetad;
dbdx(17, 0) = a*zd;
dbdx(0, 1) = 0;
dbdx(1, 1) = 1;
dbdx(2, 1) = 0;
dbdx(3, 1) = 0;
dbdx(4, 1) = 0;
dbdx(5, 1) = 0;
dbdx(6, 1) = 0;
dbdx(7, 1) = 0;
dbdx(8, 1) = 0;
dbdx(9, 1) = 0;
dbdx(10, 1) = a;
dbdx(11, 1) = a*h;
dbdx(12, 1) = a*thetad;
dbdx(13, 1) = 2*a*zd;
dbdx(14, 1) = 0;
dbdx(15, 1) = 0;
dbdx(16, 1) = 0;
dbdx(17, 1) = a*z;
dbdx(0, 2) = 0;
dbdx(1, 2) = 0;
dbdx(2, 2) = 1;
dbdx(3, 2) = 0;
dbdx(4, 2) = 0;
dbdx(5, 2) = 0;
dbdx(6, 2) = 0;
dbdx(7, 2) = a;
dbdx(8, 2) = a*h;
dbdx(9, 2) = 2*a*thetad;
dbdx(10, 2) = 0;
dbdx(11, 2) = 0;
dbdx(12, 2) = a*zd;
dbdx(13, 2) = 0;
dbdx(14, 2) = 0;
dbdx(15, 2) = 0;
dbdx(16, 2) = a*z;
dbdx(17, 2) = 0;
dbdx(0, 3) = 0;
dbdx(1, 3) = 0;
dbdx(2, 3