diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 60dc212..461d28c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,7 +32,7 @@ jobs: ### limits ssh access and adds the ssh public key for the user which triggered the workflow #limit-access-to-actor: true - name: Install gem - run: cd pkg && gem install --local *.gem --verbose + run: cd pkg && gem install --local *.gem - name: Run tests run: ruby -rruby_snowflake_client -S rspec spec/**/*_spec.rb -cfdoc env: # Or as an environment variable diff --git a/ext/c-decl.go b/ext/c-decl.go deleted file mode 100644 index f3126e9..0000000 --- a/ext/c-decl.go +++ /dev/null @@ -1,54 +0,0 @@ -package main - -/* -#include -#include "ruby/ruby.h" -*/ -import "C" - -import ( - "fmt" - "unsafe" -) - -var marked = make(map[unsafe.Pointer]int) - -//export goobj_mark -func goobj_mark(obj unsafe.Pointer) { - if LOG_LEVEL > 0 { - marked[obj] = marked[obj] + 1 - fmt.Printf("MARK log obj %v; counter: %d; total number of MARKED objects: %d\n", obj, marked[obj], len(marked)) - } -} - -//export goobj_log -func goobj_log(obj unsafe.Pointer) { - if LOG_LEVEL > 0 { - fmt.Println("log obj", obj) - } -} - -//export goobj_retain -func goobj_retain(obj unsafe.Pointer, x *C.char) { - if LOG_LEVEL > 0 { - fmt.Printf("retain obj [%v] %v - currently keeping %d\n", C.GoString(x), obj, len(objects)) - } - objects[obj] = true - marked[obj] = 0 -} - -//export goobj_free -func goobj_free(obj unsafe.Pointer) { - if LOG_LEVEL > 0 { - fmt.Printf("CALLED GOOBJ FREE %v - CURRENTLY %d objects left\n", obj, len(objects)) - } - delete(objects, obj) - delete(marked, obj) -} - -//export goobj_compact -func goobj_compact(obj unsafe.Pointer) { - if LOG_LEVEL > 0 { - fmt.Printf("CALLED GOOBJ COMPACT %v", obj) - } -} diff --git a/ext/client.go b/ext/client.go new file mode 100644 index 0000000..7fb1b6d --- /dev/null +++ b/ext/client.go @@ -0,0 +1,98 @@ +package main + +/* +#include +#include "ruby/ruby.h" + +void RbGcGuard(VALUE ptr); +VALUE ReturnEnumerator(VALUE cls); +VALUE RbNumFromDouble(double v); +*/ +import "C" + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + sf "github.com/snowflakedb/gosnowflake" +) + +type SnowflakeClient struct { + db *sql.DB +} + +func (x SnowflakeClient) Fetch(statement C.VALUE) C.VALUE { + t1 := time.Now() + + if LOG_LEVEL > 0 { + fmt.Println("statement", RbGoString(statement)) + } + rows, err := x.db.QueryContext(sf.WithHigherPrecision(context.Background()), RbGoString(statement)) + if err != nil { + result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass) + errStr := fmt.Sprintf("Query error: '%s'", err.Error()) + C.rb_ivar_set(result, ERROR_IDENT, RbString(errStr)) + return result + } + + duration := time.Now().Sub(t1).Seconds() + if LOG_LEVEL > 0 { + fmt.Printf("Query duration: %s\n", time.Now().Sub(t1)) + } + if err != nil { + result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass) + errStr := fmt.Sprintf("Query error: '%s'", err.Error()) + C.rb_ivar_set(result, ERROR_IDENT, RbString(errStr)) + return result + } + + result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass) + cols, _ := rows.Columns() + for idx, col := range cols { + col := col + cols[idx] = strings.ToLower(col) + } + rs := SnowflakeResult{rows, cols} + resultMap[result] = &rs + C.rb_ivar_set(result, RESULT_DURATION, RbNumFromDouble(C.double(duration))) + return result +} + +//export Connect +func Connect(self C.VALUE, account C.VALUE, warehouse C.VALUE, database C.VALUE, schema C.VALUE, user C.VALUE, password C.VALUE, role C.VALUE) { + // other optional parms: Application, Host, and alt auth schemes + cfg := &sf.Config{ + Account: RbGoString(account), + Warehouse: RbGoString(warehouse), + Database: RbGoString(database), + Schema: RbGoString(schema), + User: RbGoString(user), + Password: RbGoString(password), + Role: RbGoString(role), + Port: int(443), + } + + dsn, err := sf.DSN(cfg) + if err != nil { + errStr := fmt.Sprintf("Snowflake Config Creation Error: '%s'", err.Error()) + C.rb_ivar_set(self, ERROR_IDENT, RbString(errStr)) + } + + db, err := sql.Open("snowflake", dsn) + if err != nil { + errStr := fmt.Sprintf("Connection Error: '%s'", err.Error()) + C.rb_ivar_set(self, ERROR_IDENT, RbString(errStr)) + } + rs := SnowflakeClient{db} + clientRef[self] = &rs +} + +//export ObjFetch +func ObjFetch(self C.VALUE, statement C.VALUE) C.VALUE { + x, _ := clientRef[self] + + return x.Fetch(statement) +} diff --git a/ext/result.go b/ext/result.go index 46bd56a..dc85569 100644 --- a/ext/result.go +++ b/ext/result.go @@ -5,12 +5,11 @@ package main #include "ruby/ruby.h" VALUE ReturnEnumerator(VALUE cls); -VALUE createRbString(char* str); -VALUE funcall0param(VALUE obj, ID id); */ import "C" import ( + "database/sql" "fmt" "math/big" "time" @@ -18,6 +17,11 @@ import ( gopointer "github.com/mattn/go-pointer" ) +type SnowflakeResult struct { + rows *sql.Rows + columns []string +} + func wrapRbRaise(err error) { fmt.Printf("[ruby-snowflake-client] Error encountered: %s\n", err.Error()) fmt.Printf("[ruby-snowflake-client] Will call `rb_raise`\n") diff --git a/ext/ruby_snowflake.go b/ext/ruby_snowflake.go index c5cf388..66a0c44 100644 --- a/ext/ruby_snowflake.go +++ b/ext/ruby_snowflake.go @@ -10,9 +10,6 @@ VALUE Inspect(VALUE); VALUE GetRows(VALUE); VALUE GetRowsNoEnum(VALUE); -VALUE NewGoStruct(VALUE klass, char* reason, void *p); -VALUE GoRetEnum(VALUE,int,VALUE); -void* GetGoStruct(VALUE obj); void RbGcGuard(VALUE ptr); VALUE ReturnEnumerator(VALUE cls); VALUE RbNumFromDouble(double v); @@ -20,25 +17,9 @@ VALUE RbNumFromDouble(double v); import "C" import ( - "context" - "database/sql" "fmt" - "strings" - "time" - - sf "github.com/snowflakedb/gosnowflake" ) -type SnowflakeResult struct { - rows *sql.Rows - //keptHash C.VALUE - columns []string - //cols []C.VALUE -} -type SnowflakeClient struct { - db *sql.DB -} - var rbSnowflakeClientClass C.VALUE var rbSnowflakeResultClass C.VALUE var rbSnowflakeModule C.VALUE @@ -54,79 +35,6 @@ var clientRef = make(map[C.VALUE]*SnowflakeClient) var LOG_LEVEL = 0 var empty C.VALUE = C.Qnil -//export Connect -func Connect(self C.VALUE, account C.VALUE, warehouse C.VALUE, database C.VALUE, schema C.VALUE, user C.VALUE, password C.VALUE, role C.VALUE) { - // other optional parms: Application, Host, and alt auth schemes - cfg := &sf.Config{ - Account: RbGoString(account), - Warehouse: RbGoString(warehouse), - Database: RbGoString(database), - Schema: RbGoString(schema), - User: RbGoString(user), - Password: RbGoString(password), - Role: RbGoString(role), - Port: int(443), - } - - dsn, err := sf.DSN(cfg) - if err != nil { - errStr := fmt.Sprintf("Snowflake Config Creation Error: '%s'", err.Error()) - C.rb_ivar_set(self, ERROR_IDENT, RbString(errStr)) - } - - db, err := sql.Open("snowflake", dsn) - if err != nil { - errStr := fmt.Sprintf("Connection Error: '%s'", err.Error()) - C.rb_ivar_set(self, ERROR_IDENT, RbString(errStr)) - } - rs := SnowflakeClient{db} - clientRef[self] = &rs -} - -func (x SnowflakeClient) Fetch(statement C.VALUE) C.VALUE { - t1 := time.Now() - - if LOG_LEVEL > 0 { - fmt.Println("statement", RbGoString(statement)) - } - rows, err := x.db.QueryContext(sf.WithHigherPrecision(context.Background()), RbGoString(statement)) - if err != nil { - result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass) - errStr := fmt.Sprintf("Query error: '%s'", err.Error()) - C.rb_ivar_set(result, ERROR_IDENT, RbString(errStr)) - return result - } - - duration := time.Now().Sub(t1).Seconds() - if LOG_LEVEL > 0 { - fmt.Printf("Query duration: %s\n", time.Now().Sub(t1)) - } - if err != nil { - result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass) - errStr := fmt.Sprintf("Query error: '%s'", err.Error()) - C.rb_ivar_set(result, ERROR_IDENT, RbString(errStr)) - return result - } - - result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass) - cols, _ := rows.Columns() - for idx, col := range cols { - col := col - cols[idx] = strings.ToLower(col) - } - rs := SnowflakeResult{rows, cols} - resultMap[result] = &rs - C.rb_ivar_set(result, RESULT_DURATION, RbNumFromDouble(C.double(duration))) - return result -} - -//export ObjFetch -func ObjFetch(self C.VALUE, statement C.VALUE) C.VALUE { - x, _ := clientRef[self] - - return x.Fetch(statement) -} - //export Inspect func Inspect(self C.VALUE) C.VALUE { x := clientRef[self] diff --git a/ext/wrapper.go b/ext/wrapper.go index d91e8dc..11581bb 100644 --- a/ext/wrapper.go +++ b/ext/wrapper.go @@ -24,57 +24,14 @@ VALUE RbNumFromLong(long v) { return LONG2NUM(v); } -void goobj_retain(void *, char*); -void goobj_free(void *); -void goobj_log(void *); -void goobj_mark(void *); -void goobj_compact(void *); - -static const rb_data_type_t go_type = { - "GoStruct", - { - goobj_mark, - goobj_free, - NULL, - (goobj_compact), - }, - 0, 0, RUBY_TYPED_FREE_IMMEDIATELY -}; - -VALUE -NewGoStruct(VALUE klass, char* reason, void *p) -{ - goobj_retain(p, reason); - return TypedData_Wrap_Struct(klass, &go_type, p); -} - VALUE ReturnEnumerator(VALUE cls) { RETURN_ENUMERATOR(cls, 0, NULL); return Qnil; } -void * -GetGoStruct(VALUE obj) -{ - void *val; - return TypedData_Get_Struct(obj, void *, &go_type, val); -} - void RbGcGuard(VALUE ptr) { RB_GC_GUARD(ptr); } - -VALUE createRbString(char* str) { - volatile VALUE rbStr; - rbStr = rb_tainted_str_new_cstr(str); - return rbStr; -} - -VALUE funcall0param(VALUE obj, ID id) { - RB_GC_GUARD(obj); - return rb_funcall(obj, id, 0); -} - */ import "C" import ( @@ -90,10 +47,6 @@ func RbNumFromDouble(v C.double) C.VALUE { return C.RbNumFromDouble(v) } -func GetGoStruct(obj C.VALUE) unsafe.Pointer { - return C.GetGoStruct(obj) -} - func returnEnum(cls C.VALUE) C.VALUE { return C.ReturnEnumerator(cls) } diff --git a/lib/ruby_snowflake_client/version.rb b/lib/ruby_snowflake_client/version.rb index d092f16..19b8a02 100644 --- a/lib/ruby_snowflake_client/version.rb +++ b/lib/ruby_snowflake_client/version.rb @@ -1,3 +1,3 @@ module RubySnowflakeClient - VERSION = '1.1.1' + VERSION = '1.2.0' end diff --git a/spec/snowflake/client_spec.rb b/spec/snowflake/client_spec.rb index c2acd9b..8ad3524 100644 --- a/spec/snowflake/client_spec.rb +++ b/spec/snowflake/client_spec.rb @@ -164,7 +164,6 @@ let(:limit) { 150_000 } it "should work" do 100.times do |idx| - puts "on #{idx}" client = described_class.new client.connect( account: ENV["SNOWFLAKE_ACCOUNT"], @@ -174,7 +173,6 @@ ) result = client.fetch(query) rows = result.get_all_rows - puts "Done with get all rows" GC.start expect(rows.length).to eq 150000 expect((-50000...50000)).to include(rows[0]["id"].to_i) @@ -188,7 +186,6 @@ t = [] 10.times do |idx| t << Thread.new do - puts "on #{idx}" client = described_class.new client.connect( account: ENV["SNOWFLAKE_ACCOUNT"], @@ -198,7 +195,6 @@ ) result = client.fetch(query) rows = result.get_all_rows - puts "Done with get all rows" expect(rows.length).to eq 150000 expect((-50000...50000)).to include(rows[0]["id"].to_i) end