From f8652daad3aa2bcdffe4f1c6a86492d127afb96d Mon Sep 17 00:00:00 2001 From: MicBun Date: Thu, 21 Mar 2024 01:42:44 +0700 Subject: [PATCH] test: increase unit test coverage on extensions --- go.sum | 3 + .../extensions/basestream/basestream_test.go | 183 +++++++++++++++++- .../compose_streams/compose_streams.go | 6 +- .../compose_streams/compose_streams_test.go | 172 ++++++++++++++++ internal/extensions/mathutil/mathutil_test.go | 66 ++++++- internal/extensions/stream/stream_test.go | 4 +- .../extensions/whitelist/whitelist_test.go | 30 +++ internal/utils/utils.go | 8 +- internal/utils/utils_test.go | 8 +- 9 files changed, 460 insertions(+), 20 deletions(-) diff --git a/go.sum b/go.sum index 8c0edea86..a65d4f1ea 100644 --- a/go.sum +++ b/go.sum @@ -287,6 +287,7 @@ github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3v github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= @@ -503,6 +504,7 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= @@ -526,6 +528,7 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA= +golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/internal/extensions/basestream/basestream_test.go b/internal/extensions/basestream/basestream_test.go index 7bae256dd..035f74516 100644 --- a/internal/extensions/basestream/basestream_test.go +++ b/internal/extensions/basestream/basestream_test.go @@ -32,7 +32,8 @@ func Test_Index(t *testing.T) { b.sqlGetRangeValue("2024-01-01", "2024-01-02"): mockDateScalar("value", []utils.ValueWithDate{ {Date: "2024-01-01", Value: 150000}, {Date: "2024-01-02", Value: 300000}, - }), // 150.000, 300.000 + }), // 150.000, 300.000 + b.sqlGetLastBefore("2024-01-01"): mockDateScalar("value", []utils.ValueWithDate{{Date: "2024-01-01", Value: 266666}}), // 266.666 } app := &common.App{ @@ -52,6 +53,9 @@ func Test_Index(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []utils.ValueWithDate{{Date: "2024-01-01", Value: 200000}, {Date: "2024-01-02", Value: 400000}}, returned) // 200.000 * 1000, 400.000 * 1000 + + returned, err = b.index(scope, app, "2024-01-01", nil) + assert.NoError(t, err) } func Test_Value(t *testing.T) { @@ -127,3 +131,180 @@ func (m *mockQuerier) Execute(ctx context.Context, tx sql.DB, dbid, query string } return res, nil } + +type baseStreamTest struct { + ctx *precompiles.DeploymentContext + scope *precompiles.ProcedureContext + app *common.App + baseStream *BaseStreamExt +} + +func newBaseStreamTest() *baseStreamTest { + return &baseStreamTest{ + ctx: &precompiles.DeploymentContext{ + Schema: &common.Schema{ + Tables: []*common.Table{ + { + Name: "price", + Columns: []*common.Column{ + { + Name: "date", + Type: common.TEXT, + }, + { + Name: "value", + Type: common.INT, + }, + }, + }, + }, + }, + }, + scope: &precompiles.ProcedureContext{}, + app: &common.App{}, + baseStream: &BaseStreamExt{}, + } +} + +func TestInitializeBasestream(t *testing.T) { + metadata := map[string]string{ + "table_name": "price", + "date_column": "date", + "value_column": "value", + } + + instance := newBaseStreamTest() + t.Run("success - it should initialize the basestream", func(t *testing.T) { + _, err := InitializeBasestream(instance.ctx, nil, metadata) + assert.NoError(t, err) + }) + + t.Run("validation - it should return an error if the table does not exist", func(t *testing.T) { + wrongMetadata := map[string]string{ + "wrong_table_name": "price", + } + _, err := InitializeBasestream(instance.ctx, nil, wrongMetadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing table") + }) + + t.Run("validation - it should return date type must be text", func(t *testing.T) { + wrongInstance := newBaseStreamTest() + wrongInstance.ctx.Schema.Tables[0].Columns[0].Type = common.INT + _, err := InitializeBasestream(wrongInstance.ctx, nil, metadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "date column date must be of type TEXT") + }) + + t.Run("validation - it should return value type must be int", func(t *testing.T) { + wrongInstance := newBaseStreamTest() + wrongInstance.ctx.Schema.Tables[0].Columns[1].Type = common.TEXT + _, err := InitializeBasestream(wrongInstance.ctx, nil, metadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "value column value must be of type INT") + }) + + t.Run("validation - it should return an error if the date column does not exist", func(t *testing.T) { + wrongMetadata := map[string]string{ + "table_name": "price", + "date_column": "wrong_date", + "value_column": "value", + } + _, err := InitializeBasestream(instance.ctx, nil, wrongMetadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("validation - it should return an error if the value column does not exist", func(t *testing.T) { + wrongMetadata := map[string]string{ + "table_name": "price", + "date_column": "date", + "value_column": "wrong_value", + } + _, err := InitializeBasestream(instance.ctx, nil, wrongMetadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("validation - it should return an error if the table does not exist", func(t *testing.T) { + wrongMetadata := map[string]string{ + "table_name": "wrong_table", + "date_column": "date", + } + _, err := InitializeBasestream(instance.ctx, nil, wrongMetadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) +} + +func TestBaseStreamExt_Call(t *testing.T) { + instance := newBaseStreamTest() + mockEngine := mocks.NewEngine(t) + instance.app.Engine = mockEngine + //instance.scope.SetValue("caller", "caller") + //instance.scope.SetValue("args", "args") + + t.Run("success - it should return the index", func(t *testing.T) { + mockEngine.ExpectedCalls = nil + mockEngine.EXPECT().Execute(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(mockDateScalar("value", []utils.ValueWithDate{{Date: "2024-01-01", Value: 200000}}), nil) + _, err := instance.baseStream.Call(instance.scope, instance.app, "get_index", []any{"2024-01-01", "2024-01-02"}) + assert.NoError(t, err) + }) + + t.Run("success - it should return the value", func(t *testing.T) { + mockEngine.ExpectedCalls = nil + mockEngine.EXPECT().Execute(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(mockDateScalar("value", []utils.ValueWithDate{{Date: "2024-01-01", Value: 150000}}), nil) + _, err := instance.baseStream.Call(instance.scope, instance.app, "get_value", []any{"2024-01-01", "2024-01-02"}) + assert.NoError(t, err) + }) + + t.Run("validation - it should return an error if the method is unknown", func(t *testing.T) { + _, err := instance.baseStream.Call(instance.scope, instance.app, "unknown", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown method") + }) + + t.Run("validation - it should return expected 2 inputs when args are not 2", func(t *testing.T) { + _, err := instance.baseStream.Call(instance.scope, instance.app, "get_index", []any{"2024-01-01"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expected 2 arguments") + }) + + t.Run("validation - it should return expected string when date is not a string", func(t *testing.T) { + _, err := instance.baseStream.Call(instance.scope, instance.app, "get_index", []any{1, "2024-01-02"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expected string") + }) + + t.Run("validation - it should return invalid date_to when date_to is not a valid date", func(t *testing.T) { + _, err := instance.baseStream.Call(instance.scope, instance.app, "get_index", []any{"2024-01-01", 1}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "expected string for date_to") + }) + + t.Run("validation - it should return invalid date when date is not a valid date", func(t *testing.T) { + _, err := instance.baseStream.Call(instance.scope, instance.app, "get_index", []any{"wrong_date", "2024-01-02"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid date") + }) + + t.Run("validation - it should return invalid date when date_to is not a valid date", func(t *testing.T) { + _, err := instance.baseStream.Call(instance.scope, instance.app, "get_index", []any{"2024-01-01", "wrong_date"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid date") + }) + + t.Run("validation - it should return is before date when date_to is before date", func(t *testing.T) { + _, err := instance.baseStream.Call(instance.scope, instance.app, "get_index", []any{"2024-01-02", "2024-01-01"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "is before date") + }) + + t.Run("error - it should return error when the engine returns an error", func(t *testing.T) { + mockEngine.ExpectedCalls = nil + mockEngine.EXPECT().Execute(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, assert.AnError) + _, err := instance.baseStream.Call(instance.scope, instance.app, "get_index", []any{"2024-01-01", "2024-01-02"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "error getting current value on db execute") + }) +} diff --git a/internal/extensions/compose_streams/compose_streams.go b/internal/extensions/compose_streams/compose_streams.go index 519386d0c..500561ec5 100644 --- a/internal/extensions/compose_streams/compose_streams.go +++ b/internal/extensions/compose_streams/compose_streams.go @@ -41,10 +41,7 @@ func InitializeStream(ctx *precompiles.DeploymentContext, service *common.Servic if err != nil { return nil, err } - DBID, err := utils.GetDBIDFromPath(ctx, dbIdOrPath) - if err != nil { - return nil, err - } + DBID := utils.GetDBIDFromPath(ctx, dbIdOrPath) totalWeight += weightInt weightMap[DBID] = weightInt } @@ -123,7 +120,6 @@ func (s *Stream) CalculateWeightedResultsWithFn(fn func(string) ([]utils.ValueWi // for each database, get the value and multiply by the weight for dbId, weight := range s.weightMap { results, err := fn(dbId) - //results, err := CallOnTargetDBID(scoper, method, dbId, date, dateTo) if err != nil { fmt.Println("error getting results from dbid", dbId, ":", err) return nil, err diff --git a/internal/extensions/compose_streams/compose_streams_test.go b/internal/extensions/compose_streams/compose_streams_test.go index d0e6b9ec4..8096fbab4 100644 --- a/internal/extensions/compose_streams/compose_streams_test.go +++ b/internal/extensions/compose_streams/compose_streams_test.go @@ -1,7 +1,13 @@ package compose_streams import ( + "errors" + "github.com/kwilteam/kwil-db/common" + "github.com/kwilteam/kwil-db/common/sql" + "github.com/kwilteam/kwil-db/extensions/precompiles" + "github.com/stretchr/testify/mock" "github.com/truflation/tsn-db/internal/utils" + "github.com/truflation/tsn-db/mocks" "reflect" "testing" @@ -83,6 +89,18 @@ func TestCalculateWeightedResultsWithFn(t *testing.T) { expected: []utils.ValueWithDate{{Date: "2024-01-01", Value: 1}, {Date: "2024-01-02", Value: 1}}, expectedError: nil, }, + { + name: "zero denominator", + weightMap: map[string]int64{ + "abc": 0, + "def": 0, + }, + fn: func(s string) ([]utils.ValueWithDate, error) { + return []utils.ValueWithDate{{Date: "2024-01-01", Value: 10}, {Date: "2024-01-02", Value: 20}}, nil + }, + expected: nil, + expectedError: errors.New("denominator cannot be zero"), + }, } for _, test := range tests { @@ -183,3 +201,157 @@ func TestFillForwardWithLatestFromCols(t *testing.T) { }) } } + +type composeStreamsTest struct { + mock.Mock + scoper *precompiles.ProcedureContext + app *common.App + stream *Stream +} + +func newComposeStreamsTest() composeStreamsTest { + return composeStreamsTest{ + Mock: mock.Mock{}, + scoper: &precompiles.ProcedureContext{}, + app: &common.App{}, + stream: &Stream{ + weightMap: map[string]int64{"dbId": 1}, + totalWeight: 1, + }, + } +} + +func TestInitializeStream(t *testing.T) { + //instance := newComposeStreamsTest() + t.Run("success - it should return Stream instance", func(t *testing.T) { + metadata := map[string]string{"key_id": "dbId", "key_weight": "1"} + _, err := InitializeStream(nil, nil, metadata) + assert.NoError(t, err, "InitializeStream returned an error") + }) + + t.Run("validation - missing weightStr for stream", func(t *testing.T) { + metadata := map[string]string{"key_id": "dbId"} + _, err := InitializeStream(nil, nil, metadata) + assert.EqualError(t, err, "missing weightStr for stream dbId") + }) + + t.Run("error - it should return error when weightStr is not a number", func(t *testing.T) { + metadata := map[string]string{"key_id": "dbId", "key_weight": "not_a_number"} + _, err := InitializeStream(nil, nil, metadata) + assert.Error(t, err, "InitializeStream did not return an error") + }) +} + +func TestCallOnTargetDBID(t *testing.T) { + instance := newComposeStreamsTest() + mockEngine := mocks.NewEngine(t) + instance.app.Engine = mockEngine + expectedResultSet := &sql.ResultSet{ + Columns: []string{"date", "value"}, + Rows: [][]interface{}{{"2023-12-31", int64(1)}}, + } + + t.Run("success - it should return nil when method is get_index", func(t *testing.T) { + mockEngine.ExpectedCalls = nil + mockEngine.EXPECT().Procedure(mock.Anything, mock.Anything, mock.Anything).Return(expectedResultSet, nil) + _, err := CallOnTargetDBID(instance.scoper, instance.app, "get_index", "targetDBID", "2023-11-01", "2023-12-31") + assert.NoError(t, err, "stream.Call returned an error") + }) + + t.Run("success - it should return nil when method is get_value", func(t *testing.T) { + mockEngine.ExpectedCalls = nil + mockEngine.EXPECT().Procedure(mock.Anything, mock.Anything, mock.Anything).Return(expectedResultSet, nil) + _, err := CallOnTargetDBID(instance.scoper, instance.app, "get_value", "targetDBID", "2023-11-01", "2023-12-31") + assert.NoError(t, err, "stream.Call returned an error") + }) + + t.Run("error - it should return error when app.Engine.Procedure returns an error", func(t *testing.T) { + mockEngine.ExpectedCalls = nil + mockEngine.EXPECT().Procedure(mock.Anything, mock.Anything, mock.Anything).Return(nil, assert.AnError) + _, err := CallOnTargetDBID(instance.scoper, instance.app, "get_value", "targetDBID", "2023-11-01", "2023-12-31") + assert.Error(t, err, "stream.Call did not return an error") + assert.Contains(t, err.Error(), assert.AnError.Error()) + }) + + t.Run("validation - it should return stream returned nil error when app.Engine.Procedure returns nil", func(t *testing.T) { + mockEngine.ExpectedCalls = nil + mockEngine.EXPECT().Procedure(mock.Anything, mock.Anything, mock.Anything).Return(nil, nil) + _, err := CallOnTargetDBID(instance.scoper, instance.app, "get_value", "targetDBID", "2023-11-01", "2023-12-31") + assert.Error(t, err, "stream.Call did not return an error") + assert.Contains(t, err.Error(), "stream returned nil") + }) + + t.Run("validation - it should return error getting scalar", func(t *testing.T) { + mockEngine.ExpectedCalls = nil + mockEngine.EXPECT().Procedure(mock.Anything, mock.Anything, mock.Anything).Return(&sql.ResultSet{}, nil) + _, err := CallOnTargetDBID(instance.scoper, instance.app, "get_value", "targetDBID", "wrongDate", "2023-12-31") + assert.Error(t, err, "stream.Call did not return an error") + assert.Contains(t, err.Error(), "error getting scalar") + }) +} + +func TestStream_Call(t *testing.T) { + instance := newComposeStreamsTest() + mockEngine := mocks.NewEngine(t) + instance.app.Engine = mockEngine + + t.Run("success - it should return nil when method is get_index", func(t *testing.T) { + mockEngine.ExpectedCalls = nil + expectedResultSet := &sql.ResultSet{ + Columns: []string{"date", "value"}, + Rows: [][]interface{}{{"2023-12-30", int64(1)}, {"2023-12-31", int64(2)}}, + } + mockEngine.EXPECT().Procedure(mock.Anything, mock.Anything, mock.Anything).Return(expectedResultSet, nil) + _, err := instance.stream.Call(instance.scoper, instance.app, "get_index", []interface{}{"2023-11-01", "2023-12-31"}) + assert.NoError(t, err, "stream.Call returned an error") + }) + + t.Run("success - it should return nil when method is get_value", func(t *testing.T) { + mockEngine.ExpectedCalls = nil + expectedResultSet := &sql.ResultSet{ + Columns: []string{"date", "value"}, + Rows: [][]interface{}{{"2023-12-31", int64(1)}}, + } + mockEngine.EXPECT().Procedure(mock.Anything, mock.Anything, mock.Anything).Return(expectedResultSet, nil) + _, err := instance.stream.Call(instance.scoper, instance.app, "get_value", []interface{}{"2023-11-01", "2023-12-31"}) + assert.NoError(t, err, "stream.Call returned an error") + }) + + t.Run("error - it should return error when Engine.Procedure returns error", func(t *testing.T) { + mockEngine.ExpectedCalls = nil + mockEngine.EXPECT().Procedure(mock.Anything, mock.Anything, mock.Anything).Return(nil, assert.AnError) + _, err := instance.stream.Call(instance.scoper, instance.app, "get_index", []interface{}{"2023-11-01", "2023-12-31"}) + assert.Error(t, err, "stream.Call did not return an error") + assert.Contains(t, err.Error(), assert.AnError.Error()) + }) + + t.Run("validation - it should return unknown method error", func(t *testing.T) { + _, err := instance.stream.Call(nil, nil, "unknown", nil) + assert.Contains(t, err.Error(), "unknown method") + }) + + t.Run("validation - it should return error when inputs length is less than 2", func(t *testing.T) { + _, err := instance.stream.Call(nil, nil, "get_index", []interface{}{}) + assert.Contains(t, err.Error(), "expected 2 inputs") + }) + + t.Run("validation - it should return error when inputs[0] is not string", func(t *testing.T) { + _, err := instance.stream.Call(nil, nil, "get_index", []interface{}{1, "2023-12-31"}) + assert.Contains(t, err.Error(), "expected string") + }) + + t.Run("validation - it should return error when inputs[1] is not string", func(t *testing.T) { + _, err := instance.stream.Call(nil, nil, "get_index", []interface{}{"2023-11-01", 1}) + assert.Contains(t, err.Error(), "expected string") + }) + + t.Run("validation - it should return error when inputs[0] is not valid date", func(t *testing.T) { + _, err := instance.stream.Call(nil, nil, "get_index", []interface{}{"2023-11-01", "not_a_date"}) + assert.Contains(t, err.Error(), "invalid date") + }) + + t.Run("validation - it should return error when inputs[1] is not valid date", func(t *testing.T) { + _, err := instance.stream.Call(nil, nil, "get_index", []interface{}{"not_a_date", "2023-12-31"}) + assert.Contains(t, err.Error(), "invalid date") + }) +} diff --git a/internal/extensions/mathutil/mathutil_test.go b/internal/extensions/mathutil/mathutil_test.go index 798a4a773..ec23e5782 100644 --- a/internal/extensions/mathutil/mathutil_test.go +++ b/internal/extensions/mathutil/mathutil_test.go @@ -1,6 +1,9 @@ package mathutil -import "testing" +import ( + "github.com/stretchr/testify/assert" + "testing" +) func Test_Fraction(t *testing.T) { type testcase struct { @@ -47,18 +50,79 @@ func Test_Fraction(t *testing.T) { number: 9223372036854775807, want: 9223372036854775807, }, + { + name: "zero denominator", + numerator: 1, + denominator: 0, + number: 1, + want: 0, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := fraction(tt.number, tt.numerator, tt.denominator) if err != nil { + if err.Error() == "denominator cannot be zero" { + return + } t.Errorf("fraction() error = %v", err) return } + if got[0] != tt.want { t.Errorf("fraction() = %v, want %v", got, tt.want) } }) } } + +func TestInitializeMathUtil(t *testing.T) { + t.Run("success - it should return mathUtilExt instance", func(t *testing.T) { + _, err := InitializeMathUtil(nil, nil, nil) + assert.NoError(t, err, "InitializeMathUtil returned an error") + }) + + t.Run("error - it should return error when metadata is not empty", func(t *testing.T) { + _, err := InitializeMathUtil(nil, nil, map[string]string{"key": "value"}) + assert.EqualError(t, err, "mathutil does not take any configs") + }) +} + +func TestMathUtilExt_Call(t *testing.T) { + t.Run("success - it should return nil when method is fraction", func(t *testing.T) { + instance := &mathUtilExt{} + _, err := instance.Call(nil, nil, "fraction", []any{int64(1), int64(2), int64(2)}) + assert.NoError(t, err, "mathUtilExt.Call returned an error") + }) + + t.Run("validation - it should return error when method is unknown", func(t *testing.T) { + instance := &mathUtilExt{} + _, err := instance.Call(nil, nil, "unknown", nil) + assert.Contains(t, err.Error(), "unknown method") + }) + + t.Run("validation - it should return error when inputs length is less than 3", func(t *testing.T) { + instance := &mathUtilExt{} + _, err := instance.Call(nil, nil, "fraction", []any{}) + assert.Contains(t, err.Error(), "expected 3 inputs") + }) + + t.Run("validation - it should return error when inputs[0] is not int64", func(t *testing.T) { + instance := &mathUtilExt{} + _, err := instance.Call(nil, nil, "fraction", []any{"string", int64(2), int64(2)}) + assert.Contains(t, err.Error(), "expected int64 for arg 1") + }) + + t.Run("validation - it should return error when inputs[1] is not int64", func(t *testing.T) { + instance := &mathUtilExt{} + _, err := instance.Call(nil, nil, "fraction", []any{int64(1), "string", int64(2)}) + assert.Contains(t, err.Error(), "expected int64 for arg 2") + }) + + t.Run("validation - it should return error when inputs[2] is not int64", func(t *testing.T) { + instance := &mathUtilExt{} + _, err := instance.Call(nil, nil, "fraction", []any{int64(1), int64(2), "string"}) + assert.Contains(t, err.Error(), "expected int64 for arg 3") + }) +} diff --git a/internal/extensions/stream/stream_test.go b/internal/extensions/stream/stream_test.go index 50bc15eaa..b0e2a481f 100644 --- a/internal/extensions/stream/stream_test.go +++ b/internal/extensions/stream/stream_test.go @@ -17,7 +17,7 @@ type streamTest struct { app *common.App } -func newStreamTest(t *testing.T) streamTest { +func newStreamTest() streamTest { return streamTest{ stream: &stream.Stream{}, scoper: &precompiles.ProcedureContext{}, @@ -39,7 +39,7 @@ func TestInitializeStream(t *testing.T) { } func TestStream_Call(t *testing.T) { - instance := newStreamTest(t) + instance := newStreamTest() mockEngine := mocks.NewEngine(t) instance.app.Engine = mockEngine diff --git a/internal/extensions/whitelist/whitelist_test.go b/internal/extensions/whitelist/whitelist_test.go index acb92ab11..ccfcb478c 100644 --- a/internal/extensions/whitelist/whitelist_test.go +++ b/internal/extensions/whitelist/whitelist_test.go @@ -4,6 +4,7 @@ import ( "encoding/hex" "github.com/kwilteam/kwil-db/common" "github.com/kwilteam/kwil-db/extensions/precompiles" + "github.com/stretchr/testify/assert" "reflect" "sort" "testing" @@ -116,6 +117,7 @@ func TestInitializeExtension(t *testing.T) { } invalidAddress := "notgood" + invalidAddress2 := "000000000000000000000000000000000000000001" var ctx = &precompiles.DeploymentContext{Schema: &common.Schema{Owner: byteOwner}} tests := []struct { @@ -154,6 +156,18 @@ func TestInitializeExtension(t *testing.T) { false, []string{ownerAddress, validAddress, validAddress2}, }, + { + "Address not start with 0x", + map[string]string{"whitelist_wallets": invalidAddress2}, + true, + nil, + }, + { + "wrong metadata", + map[string]string{"wrong": "wrong"}, + true, + nil, + }, } for _, tt := range tests { @@ -177,3 +191,19 @@ func TestInitializeExtension(t *testing.T) { }) } } + +func TestWhitelistExt_Call(t *testing.T) { + t.Run("success - it should return nil when method is check", func(t *testing.T) { + instance := &WhitelistExt{ + whitelistedWallets: []string{"wallet1", "wallet2"}, + } + _, err := instance.Call(nil, nil, "check", []interface{}{"wallet1"}) + assert.NoError(t, err, "WhitelistExt.Call returned an error") + }) + + t.Run("validation - it should return error when method is unknown", func(t *testing.T) { + instance := &WhitelistExt{} + _, err := instance.Call(nil, nil, "unknown", nil) + assert.Contains(t, err.Error(), "unknown method") + }) +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index ac372a9eb..c05178603 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -16,10 +16,10 @@ import ( // - xac760c4d5332844f0da28c01adb53c6c369be0a2c4bf530a0f3366bd (DBID) // - / // - / (will use the wallet address from the scoper) -func GetDBIDFromPath(ctx *precompiles.DeploymentContext, pathOrDBID string) (string, error) { +func GetDBIDFromPath(ctx *precompiles.DeploymentContext, pathOrDBID string) string { // if the path does not contain a "/", we assume it is a DBID if !strings.Contains(pathOrDBID, "/") { - return pathOrDBID, nil + return pathOrDBID } var walletAddress []byte @@ -41,7 +41,7 @@ func GetDBIDFromPath(ctx *precompiles.DeploymentContext, pathOrDBID string) (str DBID := utils.GenerateDBID(dbName, walletAddress) - return DBID, nil + return DBID } func Fraction(number int64, numerator int64, denominator int64) (int64, error) { @@ -76,7 +76,7 @@ type ValueWithDate struct { // Else, it will return an error. func GetScalarWithDate(res *sql.ResultSet) ([]ValueWithDate, error) { if len(res.Columns) != 2 { - return nil, fmt.Errorf("stream expected one column, got %d", len(res.Columns)) + return nil, fmt.Errorf("stream expected two column, got %d", len(res.Columns)) } if len(res.Rows) == 0 { return []ValueWithDate{}, nil diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index f784b2cde..4e57e84a7 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -41,13 +41,7 @@ func TestGetDBIDFromPath(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - dbID, err := GetDBIDFromPath(tt.ctx, tt.pathOrDBID) - if err != nil && !tt.expectedError { - t.Fatalf("unexpected error: %v", err) - } - if tt.expectedError && err == nil { - t.Fatal("expected an error but got nil") - } + dbID := GetDBIDFromPath(tt.ctx, tt.pathOrDBID) if dbID != tt.expectedDBID { t.Errorf("DBID mismatch - want: %v, got: %v", tt.expectedDBID, dbID) }