diff --git a/ext/symengine/ruby_basic.c b/ext/symengine/ruby_basic.c index ea3f7cd..2fbc403 100644 --- a/ext/symengine/ruby_basic.c +++ b/ext/symengine/ruby_basic.c @@ -18,12 +18,13 @@ VALUE cbasic_alloc(VALUE klass) return Data_Wrap_Struct(klass, NULL, cbasic_free_heap, struct_ptr); } -VALUE cbasic_binary_op(VALUE self, VALUE operand2, - void (*cwfunc_ptr)(basic_struct *, const basic_struct *, - const basic_struct *)) +VALUE cbasic_binary_op( + VALUE self, VALUE operand2, + symengine_exceptions_t (*cwfunc_ptr)(basic_struct *, const basic_struct *, + const basic_struct *)) { basic_struct *this, *cresult; - VALUE result; + VALUE result = Qnil; basic cbasic_operand2; basic_new_stack(cbasic_operand2); @@ -32,27 +33,37 @@ VALUE cbasic_binary_op(VALUE self, VALUE operand2, sympify(operand2, cbasic_operand2); cresult = basic_new_heap(); - cwfunc_ptr(cresult, this, cbasic_operand2); - result = Data_Wrap_Struct(Klass_of_Basic(cresult), NULL, cbasic_free_heap, - cresult); - basic_free_stack(cbasic_operand2); + symengine_exceptions_t error_code + = cwfunc_ptr(cresult, this, cbasic_operand2); + if (error_code == SYMENGINE_NO_EXCEPTION) { + result = Data_Wrap_Struct(Klass_of_Basic(cresult), NULL, + cbasic_free_heap, cresult); + basic_free_stack(cbasic_operand2); + } else { + basic_free_stack(cbasic_operand2); + raise_exception(error_code); + } return result; } -VALUE cbasic_unary_op(VALUE self, - void (*cwfunc_ptr)(basic_struct *, const basic_struct *)) +VALUE cbasic_unary_op(VALUE self, symengine_exceptions_t (*cwfunc_ptr)( + basic_struct *, const basic_struct *)) { basic_struct *this, *cresult; - VALUE result; + VALUE result = Qnil; Data_Get_Struct(self, basic_struct, this); cresult = basic_new_heap(); - cwfunc_ptr(cresult, this); - result = Data_Wrap_Struct(Klass_of_Basic(cresult), NULL, cbasic_free_heap, - cresult); + symengine_exceptions_t error_code = cwfunc_ptr(cresult, this); + if (error_code == SYMENGINE_NO_EXCEPTION) { + result = Data_Wrap_Struct(Klass_of_Basic(cresult), NULL, + cbasic_free_heap, cresult); + } else { + raise_exception(error_code); + } return result; } @@ -93,8 +104,8 @@ VALUE cbasic_diff(VALUE self, VALUE operand2) sympify(operand2, cbasic_operand2); cresult = basic_new_heap(); - int status = basic_diff(cresult, this, cbasic_operand2); - if (status == 0) { + symengine_exceptions_t status = basic_diff(cresult, this, cbasic_operand2); + if (status == SYMENGINE_RUNTIME_ERROR) { basic_free_stack(cbasic_operand2); basic_free_heap(cresult); return Qnil; diff --git a/ext/symengine/ruby_basic.h b/ext/symengine/ruby_basic.h index dc5304e..3653387 100644 --- a/ext/symengine/ruby_basic.h +++ b/ext/symengine/ruby_basic.h @@ -12,12 +12,13 @@ void cbasic_free_heap(void *ptr); VALUE cbasic_alloc(VALUE klass); -VALUE cbasic_binary_op(VALUE self, VALUE operand2, - void (*cwfunc_ptr)(basic_struct *, const basic_struct *, - const basic_struct *)); +VALUE cbasic_binary_op( + VALUE self, VALUE operand2, + symengine_exceptions_t (*cwfunc_ptr)(basic_struct *, const basic_struct *, + const basic_struct *)); -VALUE cbasic_unary_op(VALUE self, - void (*cwfunc_ptr)(basic_struct *, const basic_struct *)); +VALUE cbasic_unary_op(VALUE self, symengine_exceptions_t (*cwfunc_ptr)( + basic_struct *, const basic_struct *)); VALUE cbasic_add(VALUE self, VALUE operand2); diff --git a/ext/symengine/ruby_utils.c b/ext/symengine/ruby_utils.c index 9973eb1..7fc948a 100644 --- a/ext/symengine/ruby_utils.c +++ b/ext/symengine/ruby_utils.c @@ -17,15 +17,20 @@ VALUE cutils_sympify(VALUE self, VALUE operand) VALUE cutils_evalf(VALUE self, VALUE operand, VALUE prec, VALUE real) { - VALUE result; + VALUE result = Qnil; basic_struct *cresult; cresult = basic_new_heap(); sympify(operand, cresult); - basic_evalf(cresult, cresult, NUM2INT(prec), (real == Qtrue ? 1 : 0)); - result = Data_Wrap_Struct(Klass_of_Basic(cresult), NULL, cbasic_free_heap, - cresult); - + symengine_exceptions_t error_code + = basic_evalf(cresult, cresult, NUM2INT(prec), (real == Qtrue ? 1 : 0)); + + if (error_code == SYMENGINE_NO_EXCEPTION) { + result = Data_Wrap_Struct(Klass_of_Basic(cresult), NULL, + cbasic_free_heap, cresult); + } else { + raise_exception(error_code); + } return result; } diff --git a/ext/symengine/symengine_utils.c b/ext/symengine/symengine_utils.c index 43dfdb3..60c8d92 100644 --- a/ext/symengine/symengine_utils.c +++ b/ext/symengine/symengine_utils.c @@ -199,31 +199,37 @@ VALUE Klass_of_Basic(const basic_struct *basic_ptr) } } -VALUE function_onearg(void (*cwfunc_ptr)(basic_struct *, const basic_struct *), - VALUE operand1) +VALUE function_onearg( + symengine_exceptions_t (*cwfunc_ptr)(basic_struct *, const basic_struct *), + VALUE operand1) { basic_struct *cresult; - VALUE result; + VALUE result = Qnil; basic cbasic_operand1; basic_new_stack(cbasic_operand1); sympify(operand1, cbasic_operand1); cresult = basic_new_heap(); - cwfunc_ptr(cresult, cbasic_operand1); - result = Data_Wrap_Struct(Klass_of_Basic(cresult), NULL, cbasic_free_heap, - cresult); - basic_free_stack(cbasic_operand1); - + symengine_exceptions_t error_code = cwfunc_ptr(cresult, cbasic_operand1); + if (error_code == SYMENGINE_NO_EXCEPTION) { + result = Data_Wrap_Struct(Klass_of_Basic(cresult), NULL, + cbasic_free_heap, cresult); + basic_free_stack(cbasic_operand1); + } else { + basic_free_stack(cbasic_operand1); + raise_exception(error_code); + } return result; } -VALUE function_twoarg(void (*cwfunc_ptr)(basic_struct *, const basic_struct *, +VALUE function_twoarg( + symengine_exceptions_t (*cwfunc_ptr)(basic_struct *, const basic_struct *, const basic_struct *), - VALUE operand1, VALUE operand2) + VALUE operand1, VALUE operand2) { basic_struct *cresult; - VALUE result; + VALUE result = Qnil; basic cbasic_operand1; basic_new_stack(cbasic_operand1); @@ -234,11 +240,43 @@ VALUE function_twoarg(void (*cwfunc_ptr)(basic_struct *, const basic_struct *, sympify(operand2, cbasic_operand2); cresult = basic_new_heap(); - cwfunc_ptr(cresult, cbasic_operand1, cbasic_operand2); - result = Data_Wrap_Struct(Klass_of_Basic(cresult), NULL, cbasic_free_heap, - cresult); - basic_free_stack(cbasic_operand1); - basic_free_stack(cbasic_operand2); + symengine_exceptions_t error_code + = cwfunc_ptr(cresult, cbasic_operand1, cbasic_operand2); + if (error_code == SYMENGINE_NO_EXCEPTION) { + result = Data_Wrap_Struct(Klass_of_Basic(cresult), NULL, + cbasic_free_heap, cresult); + basic_free_stack(cbasic_operand1); + basic_free_stack(cbasic_operand2); + } else { + basic_free_stack(cbasic_operand1); + basic_free_stack(cbasic_operand2); + raise_exception(error_code); + } return result; } + +void raise_exception(symengine_exceptions_t error_code) +{ + char *str = ""; + switch (error_code) { + case SYMENGINE_NO_EXCEPTION: + return; + case SYMENGINE_RUNTIME_ERROR: + str = "Runtime Error"; + break; + case SYMENGINE_DIV_BY_ZERO: + str = "Division by Zero"; + break; + case SYMENGINE_NOT_IMPLEMENTED: + str = "Not Implemented"; + break; + case SYMENGINE_UNDEFINED: + str = "Undefined"; + break; + case SYMENGINE_PARSE_ERROR: + str = "Parse Error"; + break; + } + rb_raise(rb_eRuntimeError, "%s", str); +} diff --git a/ext/symengine/symengine_utils.h b/ext/symengine/symengine_utils.h index b2baeac..6924f34 100644 --- a/ext/symengine/symengine_utils.h +++ b/ext/symengine/symengine_utils.h @@ -17,12 +17,15 @@ void get_symintfromval(VALUE operand2, basic_struct *cbasic_operand2); VALUE Klass_of_Basic(const basic_struct *basic_ptr); // Returns the result from the function pointed by cwfunc_ptr: for one argument // functions -VALUE function_onearg(void (*cwfunc_ptr)(basic_struct *, const basic_struct *), - VALUE operand1); +VALUE function_onearg( + symengine_exceptions_t (*cwfunc_ptr)(basic_struct *, const basic_struct *), + VALUE operand1); // Returns the result from the function pointed by cwfunc_ptr: for two argument // functions -VALUE function_twoarg(void (*cwfunc_ptr)(basic_struct *, const basic_struct *, +VALUE function_twoarg( + symengine_exceptions_t (*cwfunc_ptr)(basic_struct *, const basic_struct *, const basic_struct *), - VALUE operand1, VALUE operand2); + VALUE operand1, VALUE operand2); +void raise_exception(symengine_exceptions_t error_code); #endif // SYMENGINE_UTILS_H_ diff --git a/spec/functions_spec.rb b/spec/functions_spec.rb index 24f4e10..adf9dae 100644 --- a/spec/functions_spec.rb +++ b/spec/functions_spec.rb @@ -66,8 +66,11 @@ expect(SymEngine::cos(pi)).to eq(-1) expect(SymEngine::tan(pi)).to eq(0) expect(SymEngine::csc(pi/2)).to eq(1) + expect { SymEngine::csc(pi) }.to raise_error(RuntimeError) expect(SymEngine::sec(pi)).to eq(-1) + expect { SymEngine::sec(pi/2) }.to raise_error(RuntimeError) expect(SymEngine::cot(pi/4)).to eq(1) + expect { SymEngine::cot(pi) }.to raise_error(RuntimeError) expect(SymEngine::asin(1)).to eq(pi/2) expect(SymEngine::acos(1)).to eq(0) diff --git a/spec/integer_spec.rb b/spec/integer_spec.rb index 850bb1f..455613b 100644 --- a/spec/integer_spec.rb +++ b/spec/integer_spec.rb @@ -67,4 +67,10 @@ end end end + + describe 'errors' do + it 'raises an exception on division by zero' do + expect { SymEngine(1)/SymEngine(0) }.to raise_error(RuntimeError) + end + end end diff --git a/symengine_version.txt b/symengine_version.txt index ade392a..1f95871 100644 --- a/symengine_version.txt +++ b/symengine_version.txt @@ -1 +1 @@ -ab5f5da3a9aa501a7a5bf7889a68f67c21a3b9c8 +d674459efdfb2871643a2ae4616715ba98323ada