diff --git a/ext/ruby_snowflake.go b/ext/ruby_snowflake.go index e9b6130..68c4db7 100644 --- a/ext/ruby_snowflake.go +++ b/ext/ruby_snowflake.go @@ -45,6 +45,7 @@ var rbSnowflakeModule C.VALUE var DB_IDENTIFIER = C.rb_intern(C.CString("db")) var RESULT_IDENTIFIER = C.rb_intern(C.CString("rows")) var RESULT_DURATION = C.rb_intern(C.CString("@query_duration")) +var RESULT_ERROR = C.rb_intern(C.CString("@error")) var objects = make(map[interface{}]bool) @@ -91,12 +92,22 @@ func (x SnowflakeClient) Fetch(statement C.VALUE) C.VALUE { fmt.Println("statement", RbGoString(statement)) } rows, err := x.db.QueryContext(sf.WithHigherPrecision(context.Background()), RbGoString(statement)) + if err != nil { + result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass) + errStr := fmt.Sprintf("Query error: '%s'", err.Error()) + C.rb_ivar_set(result, RESULT_ERROR, RbString(errStr)) + return result + } + duration := time.Now().Sub(t1).Seconds() if LOG_LEVEL > 0 { fmt.Printf("Query duration: %s\n", time.Now().Sub(t1)) } if err != nil { - rb_raise(C.rb_eArgError, "Query error: '%s'", err) + result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass) + errStr := fmt.Sprintf("Query error: '%s'", err.Error()) + C.rb_ivar_set(result, RESULT_ERROR, RbString(errStr)) + return result } result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass) @@ -156,7 +167,7 @@ func Init_ruby_snowflake_client_ext() { C.rb_define_private_method(rbSnowflakeClientClass, C.CString("_connect"), (*[0]byte)(C.Connect), 7) C.rb_define_method(rbSnowflakeClientClass, C.CString("inspect"), (*[0]byte)(C.Inspect), 0) C.rb_define_method(rbSnowflakeClientClass, C.CString("to_s"), (*[0]byte)(C.Inspect), 0) - C.rb_define_method(rbSnowflakeClientClass, C.CString("fetch"), (*[0]byte)(C.ObjFetch), 1) + C.rb_define_method(rbSnowflakeClientClass, C.CString("_fetch"), (*[0]byte)(C.ObjFetch), 1) if LOG_LEVEL > 0 { fmt.Println("init ruby snowflake client") diff --git a/lib/ruby_snowflake_client.rb b/lib/ruby_snowflake_client.rb index d5bf22b..d129032 100644 --- a/lib/ruby_snowflake_client.rb +++ b/lib/ruby_snowflake_client.rb @@ -11,11 +11,21 @@ def connect(account:"", warehouse:"", database:"", schema: "", user: "", passwor _connect(account, warehouse, database, schema, user, password, role) true end + + def fetch(sql) + result = _fetch(sql) + return result if result.valid? + raise(result.error) + end end class Result - attr_reader :query_duration + attr_reader :query_duration, :error + + def valid? + error == nil + end def get_all_rows(&blk) GC.disable