Skip to content

Commit

Permalink
Add Scan method for ResultSet to scan the rows into the given slice (…
Browse files Browse the repository at this point in the history
…Part 1) (#298)

* Add Scan for ResultSet

* drop supports for Go 1.16 and 1.17

Update test.yaml

* Revert "drop supports for Go 1.16 and 1.17"

This reverts commit 2478f44.

* Supports Go 1.16+
  • Loading branch information
Xin Hao committed Jan 11, 2024
1 parent 6eb07e9 commit 3fe0bb1
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check_label.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:
- labeled
- unlabeled
- closed

env:
GH_PAT: ${{ secrets.GITHUB_TOKEN }}
EVENT: ${{ toJSON(github.event)}}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ on:
- cron: "0 6 * * *"

jobs:
lint:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Lint
run: |
make lint
go-client:
runs-on: ubuntu-latest
strategy:
Expand Down
74 changes: 74 additions & 0 deletions result_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ package nebula_go
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"reflect"
"regexp"
"sort"
"strings"
Expand Down Expand Up @@ -295,6 +297,78 @@ func (res ResultSet) GetRowValuesByIndex(index int) (*Record, error) {
}, nil
}

// Scan scans the rows into the given value.
func (res ResultSet) Scan(v interface{}) error {
size := res.GetRowSize()
if size == 0 {
return nil
}

rv := reflect.ValueOf(v)
switch {
case rv.Kind() != reflect.Ptr:
if t := reflect.TypeOf(v); t != nil {
return fmt.Errorf("scan: Scan(non-pointer %s)", t)
}
fallthrough
case rv.IsNil():
return fmt.Errorf("scan: Scan(nil)")
}
rv = reflect.Indirect(rv)
if k := rv.Kind(); k != reflect.Slice {
return fmt.Errorf("scan: invalid type %s. expected slice as an argument", k)
}

colNames := res.GetColNames()
rows := res.GetRows()

t := reflect.TypeOf(v).Elem().Elem()
for _, row := range rows {
vv, err := res.scanRow(row, colNames, t)
if err != nil {
return err
}
rv.Set(reflect.Append(rv, vv))
}

return nil
}

// Scan scans the rows into the given value.
func (res ResultSet) scanRow(row *nebula.Row, colNames []string, t reflect.Type) (reflect.Value, error) {
rowValues := row.GetValues()

val := reflect.New(t).Elem()

for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
tag := f.Tag.Get("nebula")

if tag == "" {
continue
}

i := IndexOf(colNames, tag)
if i == -1 {
// It is possible that the tag is not in the result set
continue
}

rowVal := rowValues[i]

switch f.Type.Kind() {
case reflect.Int64:
val.Field(i).SetInt(rowVal.GetIVal())
case reflect.String:
val.Field(i).SetString(string(rowVal.GetSVal()))
default:
return val, errors.New("scan: not support type")
}
}

return val, nil
}

// Returns the number of total rows
func (res ResultSet) GetRowSize() int {
if res.resp.Data == nil {
Expand Down
32 changes: 32 additions & 0 deletions result_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,38 @@ func TestAsStringTable(t *testing.T) {
}
}

func TestScan(t *testing.T) {
resp := &graph.ExecutionResponse{
nebula.ErrorCode_SUCCEEDED,
1000,
getDateset(),
[]byte("test_space"),
[]byte("test"),
graph.NewPlanDescription(),
[]byte("test_comment")}
resultSet, err := genResultSet(resp, testTimezone)
if err != nil {
t.Error(err)
}

type testStruct struct {
Col0 int64 `nebula:"col0_int"`
Col1 string `nebula:"col1_string"`
// Col2 Node `nebula:"col2_vertex"`
// Col3 Relationship `nebula:"col3_edge"`
// Col4 PathWrapper `nebula:"col4_path"`
}

var testStructList []testStruct
err = resultSet.Scan(&testStructList)
if err != nil {
t.Error(err)
}
assert.Equal(t, 1, len(testStructList))
assert.Equal(t, int64(1), testStructList[0].Col0)
assert.Equal(t, "value1", testStructList[0].Col1)
}

func TestIntVid(t *testing.T) {
vertex := getVertexInt(101, 3, 5)
node, err := genNode(vertex, testTimezone)
Expand Down
11 changes: 11 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package nebula_go

func IndexOf(collection []string, element string) int {
for i, item := range collection {
if item == element {
return i
}
}

return -1
}
15 changes: 15 additions & 0 deletions util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package nebula_go

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestUtil_IndexOf(t *testing.T) {
collection := []string{"a", "b", "c"}
assert.Equal(t, IndexOf(collection, "a"), 0)
assert.Equal(t, IndexOf(collection, "b"), 1)
assert.Equal(t, IndexOf(collection, "c"), 2)
assert.Equal(t, IndexOf(collection, "d"), -1)
}

0 comments on commit 3fe0bb1

Please sign in to comment.