Skip to content

Commit

Permalink
SNOW-832885 Add arrow_batches example
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Jun 14, 2023
1 parent 4730ed9 commit 9d30810
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 0 deletions.
1 change: 1 addition & 0 deletions cmd/arrow/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
arrow_batches
16 changes: 16 additions & 0 deletions cmd/arrow/batches/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Makefile for the arrow_batches example command.
# Each target below only delegates to a c-prefixed target (cinstall, crun,
# clint, cfmt); those are presumably defined in the included gosnowflake.mak
# and keyed off CMD_TARGET -- confirm against that file.
include ../../../gosnowflake.mak
CMD_TARGET=arrow_batches

## Install
install: cinstall

## Run
run: crun

## Lint
lint: clint

## Format source code
fmt: cfmt

.PHONY: install run lint fmt
156 changes: 156 additions & 0 deletions cmd/arrow/batches/arrow_batches.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package main

import (
"context"
"database/sql"
"database/sql/driver"
"flag"
"fmt"
"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/apache/arrow/go/v12/arrow/memory"
"log"
"os"
"strconv"
"sync"

sf "github.com/snowflakedb/gosnowflake"
)

// getDSN builds a Snowflake DSN, and the Config it was derived from, out of the
// SNOWFLAKE_TEST_* environment variables.
//
// SNOWFLAKE_TEST_ACCOUNT, SNOWFLAKE_TEST_USER and SNOWFLAKE_TEST_PASSWORD are
// required: the process exits through log.Fatalf when any of them is unset or
// empty. SNOWFLAKE_TEST_HOST, SNOWFLAKE_TEST_PORT and SNOWFLAKE_TEST_PROTOCOL are
// optional; the port falls back to 443 when not provided.
//
// It returns the DSN, the Config, and any error produced while parsing the port
// or while rendering the DSN. When the port cannot be parsed the Config is nil.
func getDSN() (string, *sf.Config, error) {
	env := func(key string, failOnMissing bool) string {
		// The looked-up value must be tested here, not the key: the key is a
		// non-empty literal, so checking it would make the fatal branch below
		// unreachable and silently accept missing required variables.
		if value := os.Getenv(key); value != "" {
			return value
		}
		if failOnMissing {
			log.Fatalf("%v environment variable is not set", key)
		}
		return ""
	}

	account := env("SNOWFLAKE_TEST_ACCOUNT", true)
	user := env("SNOWFLAKE_TEST_USER", true)
	password := env("SNOWFLAKE_TEST_PASSWORD", true)
	host := env("SNOWFLAKE_TEST_HOST", false)
	portStr := env("SNOWFLAKE_TEST_PORT", false)
	protocol := env("SNOWFLAKE_TEST_PROTOCOL", false)

	port := 443
	if portStr != "" {
		parsed, err := strconv.Atoi(portStr)
		if err != nil {
			return "", nil, fmt.Errorf("parsing SNOWFLAKE_TEST_PORT %q: %w", portStr, err)
		}
		port = parsed
	}

	cfg := &sf.Config{
		Account:  account,
		User:     user,
		Password: password,
		Host:     host,
		Port:     port,
		Protocol: protocol,
	}

	dsn, err := sf.DSN(cfg)
	return dsn, cfg, err
}

// sampleRecord is one result row in row-oriented form, tagged with the batch it
// came from and the worker that converted it.
type sampleRecord struct {
	batchID, workerID int    // index of the source batch / of the converting worker
	number            int32  // value taken from the first result column
	string            string // value taken from the second result column
}

// String renders the record as a single "name: value" line, which is the form
// fmt.Println prints for it.
func (s sampleRecord) String() string {
	const layout = "batchID: %v, workerID: %v, number: %v, string: %v"
	return fmt.Sprintf(layout, s.batchID, s.workerID, s.number, s.string)
}

// main runs a generated 30000-row query in Arrow batch mode, fetches every batch
// concurrently, converts the columnar data into sampleRecord rows, and prints all
// rows followed by the row count of each batch.
//
// Any failure terminates the process through log.Fatalf. Fatal messages do not
// include the DSN or the Config, because both carry the account password.
func main() {
	if !flag.Parsed() {
		flag.Parse()
	}

	dsn, _, err := getDSN()
	if err != nil {
		log.Fatalf("failed to create DSN from Config. err: %v", err)
	}

	ctx := sf.WithArrowAllocator(sf.WithArrowBatches(context.Background()), memory.DefaultAllocator)
	query := "SELECT SEQ4(), 'example ' || (SEQ4() * 2) FROM TABLE(GENERATOR(ROWCOUNT=>30000))"

	db, err := sql.Open("snowflake", dsn)
	if err != nil {
		log.Fatalf("failed to connect. err: %v", err)
	}
	defer db.Close()

	// A failed Conn returns a nil *sql.Conn; using it below would panic, so the
	// error has to be checked rather than discarded.
	conn, err := db.Conn(ctx)
	if err != nil {
		log.Fatalf("failed to obtain a connection. err: %v", err)
	}
	defer conn.Close()

	// The query goes through the raw driver connection so that the returned
	// driver.Rows can be asserted to sf.SnowflakeRows further down.
	var rows driver.Rows
	err = conn.Raw(func(x interface{}) error {
		rows, err = x.(driver.QueryerContext).QueryContext(ctx, query, nil)
		return err
	})
	if err != nil {
		log.Fatalf("unable to run the query. err: %v", err)
	}
	defer rows.Close()

	batches, err := rows.(sf.SnowflakeRows).GetArrowBatches()
	if err != nil {
		log.Fatalf("unable to get arrow batches. err: %v", err)
	}

	batchIDs := make(chan int, 1)
	maxWorkers := len(batches) // one worker per batch
	sampleRecordsPerBatch := make([][]sampleRecord, len(batches))

	var waitGroup sync.WaitGroup
	for workerID := 0; workerID < maxWorkers; workerID++ {
		waitGroup.Add(1)
		go func(workerID int) {
			defer waitGroup.Done()

			// Each batch ID is received by exactly one worker, so every worker
			// writes to a distinct element of sampleRecordsPerBatch.
			for batchID := range batchIDs {
				records, err := batches[batchID].Fetch()
				if err != nil {
					log.Fatalf("Error while fetching batch %v: %v", batchID, err)
				}
				sampleRecordsPerBatch[batchID] = make([]sampleRecord, batches[batchID].GetRowCount())
				convertFromColumnsToRows(records, sampleRecordsPerBatch, batchID, workerID, 0)
			}
		}(workerID)
	}

	for batchID := range batches {
		batchIDs <- batchID
	}
	close(batchIDs)
	waitGroup.Wait()

	for _, batchSampleRecords := range sampleRecordsPerBatch {
		for _, record := range batchSampleRecords {
			fmt.Println(record)
		}
	}
	for batchID, batch := range batches {
		fmt.Printf("BatchId: %v, number of records: %v\n", batchID, batch.GetRowCount())
	}
}

// convertFromColumnsToRows turns the columnar Arrow records of one batch into
// sampleRecord rows.
//
// For every record it reads column 0 as *array.Int32 and column 1 as
// *array.String, and stores one sampleRecord per row into
// sampleRecordsPerBatch[batchID], starting at index totalRowID and advancing by
// one per row across all records. batchID and workerID are copied into every
// produced row.
//
// The destination slice must already be long enough to hold all rows. The
// column type assertions are unchecked and panic on any other column type.
func convertFromColumnsToRows(records *[]arrow.Record, sampleRecordsPerBatch [][]sampleRecord, batchID int,
	workerID int, totalRowID int) {
	for _, rec := range *records {
		numbers := rec.Column(0).(*array.Int32).Int32Values()
		for i := 0; i < len(numbers); i++ {
			row := sampleRecord{
				batchID:  batchID,
				workerID: workerID,
				number:   numbers[i],
				string:   rec.Column(1).(*array.String).Value(i),
			}
			sampleRecordsPerBatch[batchID][totalRowID] = row
			totalRowID++
		}
	}
}

0 comments on commit 9d30810

Please sign in to comment.