Skip to content
This repository was archived by the owner on Jan 9, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PATH
remote: .
specs:
ruby_snowflake_client (1.0.2)
ruby_snowflake_client (1.1.0)

GEM
remote: https://rubygems.org/
Expand Down
4 changes: 2 additions & 2 deletions ext/c-decl.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ func goobj_log(obj unsafe.Pointer) {
}

//export goobj_retain
func goobj_retain(obj unsafe.Pointer) {
func goobj_retain(obj unsafe.Pointer, x *C.char) {
if LOG_LEVEL > 0 {
fmt.Printf("retain obj %v - currently keeping %d\n", obj, len(objects))
fmt.Printf("retain obj [%v] %v - currently keeping %d\n", C.GoString(x), obj, len(objects))
}
objects[obj] = true
marked[obj] = 0
Expand Down
109 changes: 50 additions & 59 deletions ext/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@ VALUE funcall0param(VALUE obj, ID id);
import "C"

import (
"errors"
"fmt"
"io"
"math/big"
"strings"
"time"

gopointer "github.com/mattn/go-pointer"
Expand All @@ -28,18 +25,44 @@ func wrapRbRaise(err error) {
}

func getResultStruct(self C.VALUE) *SnowflakeResult {
ivar := C.rb_ivar_get(self, RESULT_IDENTIFIER)
return resultMap[self]
}

str := GetGoStruct(ivar)
ptr := gopointer.Restore(str)
sr, ok := ptr.(*SnowflakeResult)
if !ok || sr.rows == nil {
err := errors.New("Empty result; please run a query via `client.fetch(\"SQL\")`")
wrapRbRaise(err)
return nil
//export GetRowsNoEnum
func GetRowsNoEnum(self C.VALUE) C.VALUE {
res := getResultStruct(self)
rows := res.rows

i := 0
t1 := time.Now()
var arr []C.VALUE

for rows.Next() {
if i%5000 == 0 {
if LOG_LEVEL > 0 {
fmt.Println("scanning row: ", i)
}
}
x := res.ScanNextRow(false)
objects[x] = true
gopointer.Save(x)
if LOG_LEVEL > 1 {
// This is VERY noisy
fmt.Printf("alloced %v\n", &x)
}
arr = append(arr, x)
i = i + 1
}
if LOG_LEVEL > 0 {
fmt.Printf("done with rows.next: %s\n", time.Now().Sub(t1))
}

return sr
rbArr := C.rb_ary_new2(C.long(len(arr)))
for idx, elem := range arr {
C.rb_ary_store(rbArr, C.long(idx), elem)
}

return rbArr
}

//export GetRows
Expand Down Expand Up @@ -69,11 +92,6 @@ func GetRows(self C.VALUE) C.VALUE {
fmt.Printf("done with rows.next: %s\n", time.Now().Sub(t1))
}

//empty for GC
res.rows = nil
res.keptHash = C.Qnil
res.cols = []C.VALUE{}

return self
}

Expand All @@ -89,10 +107,6 @@ func ObjNextRow(self C.VALUE) C.VALUE {
if rows.Next() {
r := res.ScanNextRow(false)
return r
} else if rows.Err() == io.EOF {
res.rows = nil // free up for gc
res.keptHash = C.Qnil // free up for gc
res.cols = []C.VALUE{}
}
return C.Qnil
}
Expand All @@ -104,8 +118,8 @@ func (res SnowflakeResult) ScanNextRow(debug bool) C.VALUE {
fmt.Printf("column types: %+v; %+v\n", cts[0], cts[0].ScanType())
}

rawResult := make([]any, len(res.cols))
rawData := make([]any, len(res.cols))
rawResult := make([]any, len(res.columns))
rawData := make([]any, len(res.columns))
for i := range rawResult {
rawData[i] = &rawResult[i]
}
Expand All @@ -117,10 +131,15 @@ func (res SnowflakeResult) ScanNextRow(debug bool) C.VALUE {
}

// trick from postgres; keep hash: pg_result.c:1088
hash := C.rb_hash_dup(res.keptHash)
//hash := C.rb_hash_dup(res.keptHash)
hash := C.rb_hash_new()
if LOG_LEVEL > 1 {
// This is very noisy
fmt.Println("alloc'ed new hash", &hash)
}

for idx, raw := range rawResult {
raw := raw
col_name := res.cols[idx]

var rbVal C.VALUE

Expand Down Expand Up @@ -151,40 +170,12 @@ func (res SnowflakeResult) ScanNextRow(debug bool) C.VALUE {
wrapRbRaise(err)
}
}
C.rb_hash_aset(hash, col_name, rbVal)
}
return hash
}

func SafeMakeHash(lenght int, cols []C.VALUE) C.VALUE {
var hash C.VALUE
hash = C.rb_hash_new()

if LOG_LEVEL > 0 {
fmt.Println("starting make hash")
}
for _, col := range cols {
C.rb_hash_aset(hash, col, C.Qnil)
}
if LOG_LEVEL > 0 {
fmt.Println("end make hash", hash)
colstr := C.rb_str_new2(C.CString(res.columns[idx]))
if LOG_LEVEL > 1 {
// This is very noisy
fmt.Printf("alloc string: %+v; rubyVal: %+v\n", &colstr, &rbVal)
}
C.rb_hash_aset(hash, colstr, rbVal)
}
return hash
}

func (res *SnowflakeResult) Initialize() {
columns, _ := res.rows.Columns()
rbArr := C.rb_ary_new2(C.long(len(columns)))

cols := make([]C.VALUE, len(columns))
for idx, colName := range columns {
str := strings.ToLower(colName)
sym := C.rb_str_new2(C.CString(str))
sym = C.rb_str_freeze(sym)
cols[idx] = sym
C.rb_ary_store(rbArr, C.long(idx), sym)
}

res.cols = cols
res.keptHash = SafeMakeHash(len(columns), cols)
}
72 changes: 30 additions & 42 deletions ext/ruby_snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ VALUE ObjFetch(VALUE,VALUE);
VALUE ObjNextRow(VALUE);
VALUE Inspect(VALUE);
VALUE GetRows(VALUE);
VALUE GetRowsNoEnum(VALUE);

VALUE NewGoStruct(VALUE klass, void *p);
VALUE NewGoStruct(VALUE klass, char* reason, void *p);
VALUE GoRetEnum(VALUE,int,VALUE);
void* GetGoStruct(VALUE obj);
void RbGcGuard(VALUE ptr);
Expand All @@ -21,18 +22,18 @@ import "C"
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"

gopointer "github.com/mattn/go-pointer"
sf "github.com/snowflakedb/gosnowflake"
)

type SnowflakeResult struct {
rows *sql.Rows
keptHash C.VALUE
cols []C.VALUE
rows *sql.Rows
//keptHash C.VALUE
columns []string
//cols []C.VALUE
}
type SnowflakeClient struct {
db *sql.DB
Expand All @@ -42,12 +43,13 @@ var rbSnowflakeClientClass C.VALUE
var rbSnowflakeResultClass C.VALUE
var rbSnowflakeModule C.VALUE

var DB_IDENTIFIER = C.rb_intern(C.CString("db"))
var RESULT_IDENTIFIER = C.rb_intern(C.CString("rows"))
var RESULT_DURATION = C.rb_intern(C.CString("@query_duration"))
var ERROR_IDENT = C.rb_intern(C.CString("@error"))

var objects = make(map[interface{}]bool)
var resultMap = make(map[C.VALUE]*SnowflakeResult)
var clientRef = make(map[C.VALUE]*SnowflakeClient)

var LOG_LEVEL = 0
var empty C.VALUE = C.Qnil
Expand Down Expand Up @@ -78,13 +80,7 @@ func Connect(self C.VALUE, account C.VALUE, warehouse C.VALUE, database C.VALUE,
C.rb_ivar_set(self, ERROR_IDENT, RbString(errStr))
}
rs := SnowflakeClient{db}
ptr := gopointer.Save(&rs)
rbStruct := C.NewGoStruct(
rbSnowflakeClientClass,
ptr,
)

C.rb_ivar_set(self, DB_IDENTIFIER, rbStruct)
clientRef[self] = &rs
}

func (x SnowflakeClient) Fetch(statement C.VALUE) C.VALUE {
Expand Down Expand Up @@ -113,46 +109,28 @@ func (x SnowflakeClient) Fetch(statement C.VALUE) C.VALUE {
}

result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass)
rs := SnowflakeResult{rows, C.Qnil, []C.VALUE{}}
rs.Initialize()
ptr := gopointer.Save(&rs)
rbStruct := C.NewGoStruct(
rbSnowflakeClientClass,
ptr,
)
C.RbGcGuard(rbStruct)
C.RbGcGuard(rbSnowflakeResultClass)
C.rb_ivar_set(result, RESULT_IDENTIFIER, rbStruct)
cols, _ := rows.Columns()
for idx, col := range cols {
col := col
cols[idx] = strings.ToLower(col)
}
rs := SnowflakeResult{rows, cols}
resultMap[result] = &rs
C.rb_ivar_set(result, RESULT_DURATION, RbNumFromDouble(C.double(duration)))
return result
}

//export ObjFetch
func ObjFetch(self C.VALUE, statement C.VALUE) C.VALUE {
var q C.VALUE
q = C.rb_ivar_get(self, DB_IDENTIFIER)

req := C.GetGoStruct(q)
f := gopointer.Restore(req)
x, ok := f.(*SnowflakeClient)
if !ok {
wrapRbRaise((errors.New("cannot convert SnowflakeClient pointer in ObjFetch")))
}
x, _ := clientRef[self]

return x.Fetch(statement)
}

//export Inspect
func Inspect(self C.VALUE) C.VALUE {
q := C.rb_ivar_get(self, DB_IDENTIFIER)
if q == C.Qnil {
return RbString("Object is not instantiated")
}

req := C.GetGoStruct(q)
f := gopointer.Restore(req)
x := f.(*SnowflakeClient)
return RbString(fmt.Sprintf("%+v", x))
x := clientRef[self]
return RbString(fmt.Sprintf("Snowflake::Client <%+v>", x))
}

//export Init_ruby_snowflake_client_ext
Expand All @@ -161,10 +139,20 @@ func Init_ruby_snowflake_client_ext() {
rbSnowflakeClientClass = C.rb_define_class_under(rbSnowflakeModule, C.CString("Client"), C.rb_cObject)
rbSnowflakeResultClass = C.rb_define_class_under(rbSnowflakeModule, C.CString("Result"), C.rb_cObject)

objects[rbSnowflakeResultClass] = true
objects[rbSnowflakeClientClass] = true
objects[rbSnowflakeModule] = true
objects[RESULT_DURATION] = true
objects[ERROR_IDENT] = true
C.RbGcGuard(RESULT_DURATION)
//C.RbGcGuard(RESULT_IDENTIFIER)
C.RbGcGuard(ERROR_IDENT)

C.rb_define_method(rbSnowflakeResultClass, C.CString("next_row"), (*[0]byte)(C.ObjNextRow), 0)
// `get_rows` is private as this can lead to SEGFAULT errors if not invoked
// with GC.disable due to undetermined issues caused by the Ruby GC.
C.rb_define_private_method(rbSnowflakeResultClass, C.CString("_get_rows"), (*[0]byte)(C.GetRows), 0)
C.rb_define_method(rbSnowflakeResultClass, C.CString("get_rows_no_enum"), (*[0]byte)(C.GetRowsNoEnum), 0)

C.rb_define_private_method(rbSnowflakeClientClass, C.CString("_connect"), (*[0]byte)(C.Connect), 7)
C.rb_define_method(rbSnowflakeClientClass, C.CString("inspect"), (*[0]byte)(C.Inspect), 0)
Expand Down
9 changes: 5 additions & 4 deletions ext/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ VALUE RbNumFromLong(long v) {
return LONG2NUM(v);
}

void goobj_retain(void *);
void goobj_retain(void *, char*);
void goobj_free(void *);
void goobj_log(void *);
void goobj_mark(void *);
Expand All @@ -42,9 +42,9 @@ static const rb_data_type_t go_type = {
};

VALUE
NewGoStruct(VALUE klass, void *p)
NewGoStruct(VALUE klass, char* reason, void *p)
{
goobj_retain(p);
goobj_retain(p, reason);
return TypedData_Wrap_Struct(klass, &go_type, p);
}

Expand Down Expand Up @@ -125,7 +125,8 @@ func RbString(str string) C.VALUE {
if len(str) == 0 {
return C.rb_utf8_str_new(nil, C.long(0))
}
cstr := (*C.char)(unsafe.Pointer(&(*(*[]byte)(unsafe.Pointer(&str)))[0]))
//cstr := (*C.char)(unsafe.Pointer(&(*(*[]byte)(unsafe.Pointer(&str)))[0]))
cstr := C.CString(str)
return C.rb_utf8_str_new(cstr, C.long(len(str)))
}

Expand Down
18 changes: 15 additions & 3 deletions lib/ruby_snowflake_client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

module Snowflake
require "ruby_snowflake_client_ext" # build bundle of the go files
LOG_LEVEL = 0

class Error < StandardError
attr_reader :details
Expand Down Expand Up @@ -51,13 +52,24 @@ def valid?
def get_all_rows(&blk)
GC.disable
if blk
_get_rows(&blk)
while r = next_row do
yield r
end
else
_get_rows.to_a
get_rows_array
end
ensure
GC.enable
GC.start
end

private
def get_rows_array
arr = []
while r = next_row do
puts "at #{arr.length}" if arr.length % 15000 == 0 && LOG_LEVEL > 0
arr << r
end
arr
end
end
end
Loading