diff --git a/ext/symengine/ruby_function.c b/ext/symengine/ruby_function.c index e8be33a..017a8b7 100644 --- a/ext/symengine/ruby_function.c +++ b/ext/symengine/ruby_function.c @@ -1,5 +1,7 @@ #include "ruby_function.h" +typedef struct CVecBasic CVecBasic; + #define IMPLEMENT_ONE_ARG_FUNC(func) \ VALUE cfunction_ ## func(VALUE self, VALUE operand1) { \ return function_onearg(basic_ ## func, operand1); \ @@ -37,3 +39,36 @@ IMPLEMENT_ONE_ARG_FUNC(gamma); #undef IMPLEMENT_ONE_ARG_FUNC +VALUE cfunction_functionsymbol_init(VALUE self, VALUE args) +{ + int argc = RARRAY_LEN(args); + if (argc == 0) { + rb_raise(rb_eTypeError, "Arguments Expected"); + } + + VALUE first = rb_ary_shift(args); + if (TYPE(first) != T_STRING) { + rb_raise(rb_eTypeError, "String expected as first argument"); + } + char *name = StringValueCStr(first); + + CVecBasic *cargs = vecbasic_new(); + + basic x; + basic_new_stack(x); + int i; + for (i = 0; i < argc-1; i++) { + sympify(rb_ary_shift(args), x); + vecbasic_push_back(cargs, x); + } + + basic_struct *this; + Data_Get_Struct(self, basic_struct, this); + function_symbol_set(this, name, cargs); + + vecbasic_free(cargs); + basic_free_stack(x); + + return self; +} + diff --git a/ext/symengine/ruby_function.h b/ext/symengine/ruby_function.h index e55b63b..b271353 100644 --- a/ext/symengine/ruby_function.h +++ b/ext/symengine/ruby_function.h @@ -37,4 +37,6 @@ VALUE cfunction_dirichlet_eta(VALUE self, VALUE operand1); VALUE cfunction_zeta(VALUE self, VALUE operand1); VALUE cfunction_gamma(VALUE self, VALUE operand1); +VALUE cfunction_functionsymbol_init(VALUE self, VALUE args); + #endif //RUBY_FUNCTION_H_ diff --git a/ext/symengine/symengine.c b/ext/symengine/symengine.c index 3be164b..f86969e 100644 --- a/ext/symengine/symengine.c +++ b/ext/symengine/symengine.c @@ -110,6 +110,9 @@ void Init_symengine() { rb_define_const(m_symengine, "I", cconstant_i()); rb_define_const(m_symengine, "HAVE_MPFR", cconstant_have_mpfr()); rb_define_const(m_symengine, "HAVE_MPC", cconstant_have_mpc()); + + //Subs class + c_subs = rb_define_class_under(m_symengine, "Subs", c_basic); //Add class c_add = rb_define_class_under(m_symengine, "Add", c_basic); @@ -130,10 +133,12 @@ void Init_symengine() { c_dirichlet_eta = rb_define_class_under(m_symengine, "Dirichlet_eta", c_function); c_zeta = rb_define_class_under(m_symengine, "Zeta", c_function); c_gamma = rb_define_class_under(m_symengine, "Gamma", c_function); - - //Abs Class c_abs = rb_define_class_under(m_symengine, "Abs", c_function); + //FunctionSymbol Class + c_function_symbol = rb_define_class_under(m_symengine, "FunctionSymbol", c_function); + rb_define_method(c_function_symbol, "initialize", cfunction_functionsymbol_init, -2); + //TrigFunction SubClasses c_sin = rb_define_class_under(m_symengine, "Sin", c_trig_function); c_cos = rb_define_class_under(m_symengine, "Cos", c_trig_function); diff --git a/ext/symengine/symengine.h b/ext/symengine/symengine.h index 20dfcc0..f850d57 100644 --- a/ext/symengine/symengine.h +++ b/ext/symengine/symengine.h @@ -21,10 +21,12 @@ VALUE c_real_mpfr; VALUE c_complex_mpc; #endif //HAVE_SYMENGINE_MPC VALUE c_constant; +VALUE c_subs; VALUE c_add; VALUE c_mul; VALUE c_pow; VALUE c_function; +VALUE c_function_symbol; VALUE c_trig_function; VALUE c_hyperbolic_function; VALUE c_lambertw; diff --git a/ext/symengine/symengine_utils.c b/ext/symengine/symengine_utils.c index cf4f015..e4385f1 100644 --- a/ext/symengine/symengine_utils.c +++ b/ext/symengine/symengine_utils.c @@ -8,7 +8,6 @@ void sympify(VALUE operand2, basic_struct *cbasic_operand2) { VALUE a, b; double f; char *c; - rb_cBigDecimal = CLASS_OF(rb_eval_string("BigDecimal.new('0.0001')")); switch(TYPE(operand2)) { case T_FIXNUM: @@ -63,7 +62,7 @@ void sympify(VALUE operand2, basic_struct *cbasic_operand2) { case T_DATA: c = rb_obj_classname(operand2); #ifdef HAVE_SYMENGINE_MPFR - if(CLASS_OF(operand2) == rb_cBigDecimal){ + if (strcmp(c, "BigDecimal") == 0) { c = RSTRING_PTR( rb_funcall(operand2, rb_intern("to_s"), 1, rb_str_new2("F")) ); real_mpfr_set_str(cbasic_operand2, c, 200); break; @@ -113,12 +112,16 @@ VALUE Klass_of_Basic(const basic_struct *basic_ptr) { #endif //HAVE_SYMENGINE_MPFR case SYMENGINE_CONSTANT: return c_constant; + case SYMENGINE_SUBS: + return c_subs; case SYMENGINE_ADD: return c_add; case SYMENGINE_MUL: return c_mul; case SYMENGINE_POW: return c_pow; + case SYMENGINE_FUNCTIONSYMBOL: + return c_function_symbol; case SYMENGINE_ABS: return c_abs; case SYMENGINE_SIN: diff --git a/lib/symengine.rb b/lib/symengine.rb index 21d4c7d..e4d3477 100644 --- a/lib/symengine.rb +++ b/lib/symengine.rb @@ -1,6 +1,5 @@ module SymEngine class << self - # Defines a shortcut for SymEngine::Symbol.new() allowing multiple symbols # to be created all at once. # @@ -27,7 +26,10 @@ def symbols ary_or_string, *params ary_or_string.map do |symbol_or_string| SymEngine::Symbol.new(symbol_or_string) end - end + end + def Function(n) + return SymEngine::UndefFunction.new(n) + end end end @@ -37,3 +39,4 @@ def symbols ary_or_string, *params require 'symengine/integer' require 'symengine/complex' require 'symengine/complex_double' +require 'symengine/undef_function' diff --git a/lib/symengine/undef_function.rb b/lib/symengine/undef_function.rb new file mode 100644 index 0000000..8982b8b --- /dev/null +++ b/lib/symengine/undef_function.rb @@ -0,0 +1,13 @@ +module SymEngine + class UndefFunction + + def initialize(n) + @name = n + end + + def call(*args) + SymEngine::FunctionSymbol.new(@name, *args) + end + + end +end diff --git a/spec/function_symbol_spec.rb b/spec/function_symbol_spec.rb new file mode 100644 index 0000000..f73d2eb --- /dev/null +++ b/spec/function_symbol_spec.rb @@ -0,0 +1,40 @@ +describe SymEngine::FunctionSymbol do + + let(:x) { sym('x') } + let(:y) { sym('y') } + let(:z) { sym('z') } + + describe '.new' do + context 'with symbols' do + subject { SymEngine::FunctionSymbol.new('f', x, y, z) } + it { is_expected.to be_a SymEngine::FunctionSymbol } + end + + context 'with compound arguments' do + subject { SymEngine::FunctionSymbol.new('f', 2*x, y, SymEngine::sin(z)) } + it { is_expected.to be_a SymEngine::FunctionSymbol } + end + end + + context '#diff' do + let(:fun) { (SymEngine::FunctionSymbol.new('f', 2*x, y, SymEngine::sin(z))) } + context 'by variable' do + subject { fun.diff(x)/2 } + it { is_expected.to be_a SymEngine::Subs } + end + end + + context 'Initializing with UndefFunctions' do + let(:fun) { SymEngine::Function('f') } + context 'UndefFunction' do + subject { fun } + it { is_expected.to be_a SymEngine::UndefFunction } + end + context 'using call method for UndefFunction' do + subject { fun.(x, y, z) } + it { is_expected.to be_a SymEngine::FunctionSymbol } + end + end +end + +