Skip to content
Merged
35 changes: 35 additions & 0 deletions ext/symengine/ruby_function.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "ruby_function.h"

typedef struct CVecBasic CVecBasic;

#define IMPLEMENT_ONE_ARG_FUNC(func) \
VALUE cfunction_ ## func(VALUE self, VALUE operand1) { \
return function_onearg(basic_ ## func, operand1); \
Expand Down Expand Up @@ -37,3 +39,36 @@ IMPLEMENT_ONE_ARG_FUNC(gamma);

#undef IMPLEMENT_ONE_ARG_FUNC

VALUE cfunction_functionsymbol_init(VALUE self, VALUE args)
{
int argc = RARRAY_LEN(args);
if (argc == 0) {
rb_raise(rb_eTypeError, "Arguments Expected");
}

VALUE first = rb_ary_shift(args);
if (TYPE(first) != T_STRING) {
rb_raise(rb_eTypeError, "String expected as first argument");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also raise error, if no arguments are passed.

}
char *name = StringValueCStr(first);

CVecBasic *cargs = vecbasic_new();

basic x;
basic_new_stack(x);
int i;
for (i = 0; i < argc-1; i++) {
sympify(rb_ary_shift(args), x);
vecbasic_push_back(cargs, x);
}

basic_struct *this;
Data_Get_Struct(self, basic_struct, this);
function_symbol_set(this, name, cargs);

vecbasic_free(cargs);
basic_free_stack(x);

return self;
}

2 changes: 2 additions & 0 deletions ext/symengine/ruby_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,6 @@ VALUE cfunction_dirichlet_eta(VALUE self, VALUE operand1);
VALUE cfunction_zeta(VALUE self, VALUE operand1);
VALUE cfunction_gamma(VALUE self, VALUE operand1);

VALUE cfunction_functionsymbol_init(VALUE self, VALUE args);

#endif //RUBY_FUNCTION_H_
9 changes: 7 additions & 2 deletions ext/symengine/symengine.c
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ void Init_symengine() {
rb_define_const(m_symengine, "I", cconstant_i());
rb_define_const(m_symengine, "HAVE_MPFR", cconstant_have_mpfr());
rb_define_const(m_symengine, "HAVE_MPC", cconstant_have_mpc());

//Subs class
c_subs = rb_define_class_under(m_symengine, "Subs", c_basic);

//Add class
c_add = rb_define_class_under(m_symengine, "Add", c_basic);
Expand All @@ -130,10 +133,12 @@ void Init_symengine() {
c_dirichlet_eta = rb_define_class_under(m_symengine, "Dirichlet_eta", c_function);
c_zeta = rb_define_class_under(m_symengine, "Zeta", c_function);
c_gamma = rb_define_class_under(m_symengine, "Gamma", c_function);

//Abs Class
c_abs = rb_define_class_under(m_symengine, "Abs", c_function);

//FunctionSymbol Class
c_function_symbol = rb_define_class_under(m_symengine, "FunctionSymbol", c_function);
rb_define_method(c_function_symbol, "initialize", cfunction_functionsymbol_init, -2);

//TrigFunction SubClasses
c_sin = rb_define_class_under(m_symengine, "Sin", c_trig_function);
c_cos = rb_define_class_under(m_symengine, "Cos", c_trig_function);
Expand Down
2 changes: 2 additions & 0 deletions ext/symengine/symengine.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ VALUE c_real_mpfr;
VALUE c_complex_mpc;
#endif //HAVE_SYMENGINE_MPC
VALUE c_constant;
VALUE c_subs;
VALUE c_add;
VALUE c_mul;
VALUE c_pow;
VALUE c_function;
VALUE c_function_symbol;
VALUE c_trig_function;
VALUE c_hyperbolic_function;
VALUE c_lambertw;
Expand Down
7 changes: 5 additions & 2 deletions ext/symengine/symengine_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ void sympify(VALUE operand2, basic_struct *cbasic_operand2) {
VALUE a, b;
double f;
char *c;
rb_cBigDecimal = CLASS_OF(rb_eval_string("BigDecimal.new('0.0001')"));

switch(TYPE(operand2)) {
case T_FIXNUM:
Expand Down Expand Up @@ -63,7 +62,7 @@ void sympify(VALUE operand2, basic_struct *cbasic_operand2) {
case T_DATA:
c = rb_obj_classname(operand2);
#ifdef HAVE_SYMENGINE_MPFR
if(CLASS_OF(operand2) == rb_cBigDecimal){
if (strcmp(c, "BigDecimal") == 0) {
c = RSTRING_PTR( rb_funcall(operand2, rb_intern("to_s"), 1, rb_str_new2("F")) );
real_mpfr_set_str(cbasic_operand2, c, 200);
break;
Expand Down Expand Up @@ -113,12 +112,16 @@ VALUE Klass_of_Basic(const basic_struct *basic_ptr) {
#endif //HAVE_SYMENGINE_MPFR
case SYMENGINE_CONSTANT:
return c_constant;
case SYMENGINE_SUBS:
return c_subs;
case SYMENGINE_ADD:
return c_add;
case SYMENGINE_MUL:
return c_mul;
case SYMENGINE_POW:
return c_pow;
case SYMENGINE_FUNCTIONSYMBOL:
return c_function_symbol;
case SYMENGINE_ABS:
return c_abs;
case SYMENGINE_SIN:
Expand Down
7 changes: 5 additions & 2 deletions lib/symengine.rb
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module SymEngine
class << self

# Defines a shortcut for SymEngine::Symbol.new() allowing multiple symbols
# to be created all at once.
#
Expand All @@ -27,7 +26,10 @@ def symbols ary_or_string, *params
ary_or_string.map do |symbol_or_string|
SymEngine::Symbol.new(symbol_or_string)
end
end
end
def Function(n)
return SymEngine::UndefFunction.new(n)
end
end
end

Expand All @@ -37,3 +39,4 @@ def symbols ary_or_string, *params
require 'symengine/integer'
require 'symengine/complex'
require 'symengine/complex_double'
require 'symengine/undef_function'
13 changes: 13 additions & 0 deletions lib/symengine/undef_function.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module SymEngine
class UndefFunction

def initialize(n)
@name = n
end

def call(*args)
SymEngine::FunctionSymbol.new(@name, *args)
end

end
end
40 changes: 40 additions & 0 deletions spec/function_symbol_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
describe SymEngine::FunctionSymbol do

let(:x) { sym('x') }
let(:y) { sym('y') }
let(:z) { sym('z') }

describe '.new' do
context 'with symbols' do
subject { SymEngine::FunctionSymbol.new('f', x, y, z) }
it { is_expected.to be_a SymEngine::FunctionSymbol }
end

context 'with compound arguments' do
subject { SymEngine::FunctionSymbol.new('f', 2*x, y, SymEngine::sin(z)) }
it { is_expected.to be_a SymEngine::FunctionSymbol }
end
end

context '#diff' do
let(:fun) { (SymEngine::FunctionSymbol.new('f', 2*x, y, SymEngine::sin(z))) }
context 'by variable' do
subject { fun.diff(x)/2 }
it { is_expected.to be_a SymEngine::Subs }
end
end

context 'Initializing with UndefFunctions' do
let(:fun) { SymEngine::Function('f') }
context 'UndefFunction' do
subject { fun }
it { is_expected.to be_a SymEngine::UndefFunction }
end
context 'using call method for UndefFunction' do
subject { fun.(x, y, z) }
it { is_expected.to be_a SymEngine::FunctionSymbol }
end
end
end