Skip to content
This repository was archived by the owner on Jan 9, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
### limits ssh access and adds the ssh public key for the user which triggered the workflow
#limit-access-to-actor: true
- name: Install gem
run: cd pkg && gem install --local *.gem --verbose
run: cd pkg && gem install --local *.gem
- name: Run tests
run: ruby -rruby_snowflake_client -S rspec spec/**/*_spec.rb -cfdoc
env: # Or as an environment variable
Expand Down
54 changes: 0 additions & 54 deletions ext/c-decl.go

This file was deleted.

98 changes: 98 additions & 0 deletions ext/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package main

/*
#include <stdlib.h>
#include "ruby/ruby.h"

void RbGcGuard(VALUE ptr);
VALUE ReturnEnumerator(VALUE cls);
VALUE RbNumFromDouble(double v);
*/
import "C"

import (
"context"
"database/sql"
"fmt"
"strings"
"time"

sf "github.com/snowflakedb/gosnowflake"
)

type SnowflakeClient struct {
db *sql.DB
}

func (x SnowflakeClient) Fetch(statement C.VALUE) C.VALUE {
t1 := time.Now()

if LOG_LEVEL > 0 {
fmt.Println("statement", RbGoString(statement))
}
rows, err := x.db.QueryContext(sf.WithHigherPrecision(context.Background()), RbGoString(statement))
if err != nil {
result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass)
errStr := fmt.Sprintf("Query error: '%s'", err.Error())
C.rb_ivar_set(result, ERROR_IDENT, RbString(errStr))
return result
}

duration := time.Now().Sub(t1).Seconds()
if LOG_LEVEL > 0 {
fmt.Printf("Query duration: %s\n", time.Now().Sub(t1))
}
if err != nil {
result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass)
errStr := fmt.Sprintf("Query error: '%s'", err.Error())
C.rb_ivar_set(result, ERROR_IDENT, RbString(errStr))
return result
}

result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass)
cols, _ := rows.Columns()
for idx, col := range cols {
col := col
cols[idx] = strings.ToLower(col)
}
rs := SnowflakeResult{rows, cols}
resultMap[result] = &rs
C.rb_ivar_set(result, RESULT_DURATION, RbNumFromDouble(C.double(duration)))
return result
}

//export Connect
func Connect(self C.VALUE, account C.VALUE, warehouse C.VALUE, database C.VALUE, schema C.VALUE, user C.VALUE, password C.VALUE, role C.VALUE) {
// other optional parms: Application, Host, and alt auth schemes
cfg := &sf.Config{
Account: RbGoString(account),
Warehouse: RbGoString(warehouse),
Database: RbGoString(database),
Schema: RbGoString(schema),
User: RbGoString(user),
Password: RbGoString(password),
Role: RbGoString(role),
Port: int(443),
}

dsn, err := sf.DSN(cfg)
if err != nil {
errStr := fmt.Sprintf("Snowflake Config Creation Error: '%s'", err.Error())
C.rb_ivar_set(self, ERROR_IDENT, RbString(errStr))
}

db, err := sql.Open("snowflake", dsn)
if err != nil {
errStr := fmt.Sprintf("Connection Error: '%s'", err.Error())
C.rb_ivar_set(self, ERROR_IDENT, RbString(errStr))
}
rs := SnowflakeClient{db}
clientRef[self] = &rs
}

//export ObjFetch
func ObjFetch(self C.VALUE, statement C.VALUE) C.VALUE {
x, _ := clientRef[self]

return x.Fetch(statement)
}
8 changes: 6 additions & 2 deletions ext/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,23 @@ package main
#include "ruby/ruby.h"

VALUE ReturnEnumerator(VALUE cls);
VALUE createRbString(char* str);
VALUE funcall0param(VALUE obj, ID id);
*/
import "C"

import (
"database/sql"
"fmt"
"math/big"
"time"

gopointer "github.com/mattn/go-pointer"
)

type SnowflakeResult struct {
rows *sql.Rows
columns []string
}

func wrapRbRaise(err error) {
fmt.Printf("[ruby-snowflake-client] Error encountered: %s\n", err.Error())
fmt.Printf("[ruby-snowflake-client] Will call `rb_raise`\n")
Expand Down
92 changes: 0 additions & 92 deletions ext/ruby_snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,16 @@ VALUE Inspect(VALUE);
VALUE GetRows(VALUE);
VALUE GetRowsNoEnum(VALUE);

VALUE NewGoStruct(VALUE klass, char* reason, void *p);
VALUE GoRetEnum(VALUE,int,VALUE);
void* GetGoStruct(VALUE obj);
void RbGcGuard(VALUE ptr);
VALUE ReturnEnumerator(VALUE cls);
VALUE RbNumFromDouble(double v);
*/
import "C"

import (
"context"
"database/sql"
"fmt"
"strings"
"time"

sf "github.com/snowflakedb/gosnowflake"
)

type SnowflakeResult struct {
rows *sql.Rows
//keptHash C.VALUE
columns []string
//cols []C.VALUE
}
type SnowflakeClient struct {
db *sql.DB
}

var rbSnowflakeClientClass C.VALUE
var rbSnowflakeResultClass C.VALUE
var rbSnowflakeModule C.VALUE
Expand All @@ -54,79 +35,6 @@ var clientRef = make(map[C.VALUE]*SnowflakeClient)
var LOG_LEVEL = 0
var empty C.VALUE = C.Qnil

//export Connect
func Connect(self C.VALUE, account C.VALUE, warehouse C.VALUE, database C.VALUE, schema C.VALUE, user C.VALUE, password C.VALUE, role C.VALUE) {
// other optional parms: Application, Host, and alt auth schemes
cfg := &sf.Config{
Account: RbGoString(account),
Warehouse: RbGoString(warehouse),
Database: RbGoString(database),
Schema: RbGoString(schema),
User: RbGoString(user),
Password: RbGoString(password),
Role: RbGoString(role),
Port: int(443),
}

dsn, err := sf.DSN(cfg)
if err != nil {
errStr := fmt.Sprintf("Snowflake Config Creation Error: '%s'", err.Error())
C.rb_ivar_set(self, ERROR_IDENT, RbString(errStr))
}

db, err := sql.Open("snowflake", dsn)
if err != nil {
errStr := fmt.Sprintf("Connection Error: '%s'", err.Error())
C.rb_ivar_set(self, ERROR_IDENT, RbString(errStr))
}
rs := SnowflakeClient{db}
clientRef[self] = &rs
}

func (x SnowflakeClient) Fetch(statement C.VALUE) C.VALUE {
t1 := time.Now()

if LOG_LEVEL > 0 {
fmt.Println("statement", RbGoString(statement))
}
rows, err := x.db.QueryContext(sf.WithHigherPrecision(context.Background()), RbGoString(statement))
if err != nil {
result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass)
errStr := fmt.Sprintf("Query error: '%s'", err.Error())
C.rb_ivar_set(result, ERROR_IDENT, RbString(errStr))
return result
}

duration := time.Now().Sub(t1).Seconds()
if LOG_LEVEL > 0 {
fmt.Printf("Query duration: %s\n", time.Now().Sub(t1))
}
if err != nil {
result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass)
errStr := fmt.Sprintf("Query error: '%s'", err.Error())
C.rb_ivar_set(result, ERROR_IDENT, RbString(errStr))
return result
}

result := C.rb_class_new_instance(0, &empty, rbSnowflakeResultClass)
cols, _ := rows.Columns()
for idx, col := range cols {
col := col
cols[idx] = strings.ToLower(col)
}
rs := SnowflakeResult{rows, cols}
resultMap[result] = &rs
C.rb_ivar_set(result, RESULT_DURATION, RbNumFromDouble(C.double(duration)))
return result
}

//export ObjFetch
func ObjFetch(self C.VALUE, statement C.VALUE) C.VALUE {
x, _ := clientRef[self]

return x.Fetch(statement)
}

//export Inspect
func Inspect(self C.VALUE) C.VALUE {
x := clientRef[self]
Expand Down
47 changes: 0 additions & 47 deletions ext/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,57 +24,14 @@ VALUE RbNumFromLong(long v) {
return LONG2NUM(v);
}

void goobj_retain(void *, char*);
void goobj_free(void *);
void goobj_log(void *);
void goobj_mark(void *);
void goobj_compact(void *);

static const rb_data_type_t go_type = {
"GoStruct",
{
goobj_mark,
goobj_free,
NULL,
(goobj_compact),
},
0, 0, RUBY_TYPED_FREE_IMMEDIATELY
};

VALUE
NewGoStruct(VALUE klass, char* reason, void *p)
{
goobj_retain(p, reason);
return TypedData_Wrap_Struct(klass, &go_type, p);
}

VALUE ReturnEnumerator(VALUE cls) {
RETURN_ENUMERATOR(cls, 0, NULL);
return Qnil;
}

void *
GetGoStruct(VALUE obj)
{
void *val;
return TypedData_Get_Struct(obj, void *, &go_type, val);
}

void RbGcGuard(VALUE ptr) {
RB_GC_GUARD(ptr);
}

VALUE createRbString(char* str) {
volatile VALUE rbStr;
rbStr = rb_tainted_str_new_cstr(str);
return rbStr;
}

VALUE funcall0param(VALUE obj, ID id) {
RB_GC_GUARD(obj);
return rb_funcall(obj, id, 0);
}

*/
import "C"
import (
Expand All @@ -90,10 +47,6 @@ func RbNumFromDouble(v C.double) C.VALUE {
return C.RbNumFromDouble(v)
}

func GetGoStruct(obj C.VALUE) unsafe.Pointer {
return C.GetGoStruct(obj)
}

func returnEnum(cls C.VALUE) C.VALUE {
return C.ReturnEnumerator(cls)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/ruby_snowflake_client/version.rb
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module RubySnowflakeClient
VERSION = '1.1.1'
VERSION = '1.2.0'
end
Loading