Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions extensions/tn_local/create_stream_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package tn_local

import (
"context"
"fmt"
"testing"

"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/require"
jsonrpc "github.com/trufnetwork/kwil-db/core/rpc/json"
kwilsql "github.com/trufnetwork/kwil-db/node/types/sql"
"github.com/trufnetwork/node/tests/utils"
)

func TestCreateStream_NilRequest(t *testing.T) {
ext := newTestExtension(&utils.MockDB{})

resp, rpcErr := ext.CreateStream(context.Background(), nil)
require.Nil(t, resp)
require.NotNil(t, rpcErr)
require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInvalidParams), rpcErr.Code)
require.Contains(t, rpcErr.Message, "missing request")
}

func TestCreateStream_Success(t *testing.T) {
var capturedStmt string
var capturedArgs []any
mockDB := &utils.MockDB{
ExecuteFn: func(ctx context.Context, stmt string, args ...any) (*kwilsql.ResultSet, error) {
capturedStmt = stmt
capturedArgs = args
return &kwilsql.ResultSet{}, nil
},
}
ext := newTestExtension(mockDB)

resp, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{
DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832",
StreamID: "st00000000000000000000000000test",
StreamType: "primitive",
})

require.Nil(t, rpcErr, "expected no error")
require.NotNil(t, resp)
require.Contains(t, capturedStmt, "INSERT INTO "+SchemaName+".streams")
require.Len(t, capturedArgs, 4, "INSERT should have 4 parameters")
// data_provider should be lowercased (matching consensus behavior)
require.Equal(t, "0xec36224a679218ae28fcece8d3c68595b87dd832", capturedArgs[0])
require.Equal(t, "st00000000000000000000000000test", capturedArgs[1])
require.Equal(t, "primitive", capturedArgs[2])
// created_at should be a non-zero unix timestamp
createdAt, ok := capturedArgs[3].(int64)
require.True(t, ok, "created_at should be int64")
require.NotZero(t, createdAt, "created_at should be non-zero")
}

func TestCreateStream_ComposedType(t *testing.T) {
mockDB := &utils.MockDB{
ExecuteFn: func(ctx context.Context, stmt string, args ...any) (*kwilsql.ResultSet, error) {
return &kwilsql.ResultSet{}, nil
},
}
ext := newTestExtension(mockDB)

resp, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{
DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832",
StreamID: "st00000000000000000000000000test",
StreamType: "composed",
})

require.Nil(t, rpcErr)
require.NotNil(t, resp)
}

func TestCreateStream_InvalidStreamID(t *testing.T) {
ext := newTestExtension(&utils.MockDB{})

tests := []struct {
name string
streamID string
wantMsg string
}{
{"too short", "st00", "must be exactly 32 characters"},
{"too long", "st000000000000000000000000000test1", "must be exactly 32 characters"},
{"wrong prefix", "xx00000000000000000000000000test", "must start with 'st'"},
{"empty", "", "must be exactly 32 characters"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{
DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832",
StreamID: tt.streamID,
StreamType: "primitive",
})
require.NotNil(t, rpcErr)
require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInvalidParams), rpcErr.Code)
require.Contains(t, rpcErr.Message, tt.wantMsg)
})
}
}

func TestCreateStream_InvalidStreamType(t *testing.T) {
ext := newTestExtension(&utils.MockDB{})

_, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{
DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832",
StreamID: "st00000000000000000000000000test",
StreamType: "invalid",
})
require.NotNil(t, rpcErr)
require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInvalidParams), rpcErr.Code)
require.Contains(t, rpcErr.Message, "must be 'primitive' or 'composed'")
}

func TestCreateStream_InvalidDataProvider(t *testing.T) {
ext := newTestExtension(&utils.MockDB{})

tests := []struct {
name string
dataProvider string
}{
{"no 0x prefix", "EC36224A679218Ae28FCeCe8d3c68595B87Dd832"},
{"too short", "0xEC36224A679218Ae28"},
{"too long", "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832FF"},
{"invalid chars", "0xGG36224A679218Ae28FCeCe8d3c68595B87Dd832"},
{"empty", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{
DataProvider: tt.dataProvider,
StreamID: "st00000000000000000000000000test",
StreamType: "primitive",
})
require.NotNil(t, rpcErr)
require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInvalidParams), rpcErr.Code)
require.Contains(t, rpcErr.Message, "data_provider must be a valid Ethereum address")
})
}
}

func TestCreateStream_DuplicateStream(t *testing.T) {
mockDB := &utils.MockDB{
ExecuteFn: func(ctx context.Context, stmt string, args ...any) (*kwilsql.ResultSet, error) {
return nil, fmt.Errorf("duplicate key value violates unique constraint")
},
}
ext := newTestExtension(mockDB)

_, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{
DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832",
StreamID: "st00000000000000000000000000test",
StreamType: "primitive",
})
require.NotNil(t, rpcErr)
require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInvalidParams), rpcErr.Code)
require.Contains(t, rpcErr.Message, "stream already exists")
}

func TestCreateStream_DuplicateStream_PgError(t *testing.T) {
mockDB := &utils.MockDB{
ExecuteFn: func(ctx context.Context, stmt string, args ...any) (*kwilsql.ResultSet, error) {
return nil, &pgconn.PgError{Code: pgUniqueViolation, Message: "unique_violation"}
},
}
ext := newTestExtension(mockDB)

_, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{
DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832",
StreamID: "st00000000000000000000000000test",
StreamType: "primitive",
})
require.NotNil(t, rpcErr)
require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInvalidParams), rpcErr.Code)
require.Contains(t, rpcErr.Message, "stream already exists")
}

func TestCreateStream_DBError(t *testing.T) {
mockDB := &utils.MockDB{
ExecuteFn: func(ctx context.Context, stmt string, args ...any) (*kwilsql.ResultSet, error) {
return nil, fmt.Errorf("connection refused")
},
}
ext := newTestExtension(mockDB)

_, rpcErr := ext.CreateStream(context.Background(), &CreateStreamRequest{
DataProvider: "0xEC36224A679218Ae28FCeCe8d3c68595B87Dd832",
StreamID: "st00000000000000000000000000test",
StreamType: "primitive",
})
require.NotNil(t, rpcErr)
require.Equal(t, jsonrpc.ErrorCode(jsonrpc.ErrorInternal), rpcErr.Code)
require.Contains(t, rpcErr.Message, "failed to create stream")
}
47 changes: 47 additions & 0 deletions extensions/tn_local/db_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,53 @@ func (ext *Extension) dbCreateStream(ctx context.Context, dataProvider, streamID
return err
}

// dbLookupStreamRef looks up a stream by data_provider and stream_id.
// Returns (id, stream_type, nil) if found, or (0, "", nil) if not found.
func (ext *Extension) dbLookupStreamRef(ctx context.Context, dataProvider, streamID string) (int64, string, error) {
rs, err := ext.db.Execute(ctx, fmt.Sprintf(
`SELECT id, stream_type FROM %s.streams WHERE data_provider = $1 AND stream_id = $2`, SchemaName),
dataProvider, streamID)
if err != nil {
return 0, "", err
}
if len(rs.Rows) == 0 {
return 0, "", nil
}
id, ok := rs.Rows[0][0].(int64)
if !ok {
return 0, "", fmt.Errorf("unexpected id type: %T", rs.Rows[0][0])
}
streamType, ok := rs.Rows[0][1].(string)
if !ok {
return 0, "", fmt.Errorf("unexpected stream_type type: %T", rs.Rows[0][1])
}
return id, streamType, nil
}

// dbInsertRecords batch-inserts resolved records into ext_tn_local.primitive_events
// within a transaction. Mirrors the consensus INSERT in 003-primitive-insertion.sql.
func (ext *Extension) dbInsertRecords(ctx context.Context, streamRefs []int64, eventTimes []int64, values []string) error {
createdAt := time.Now().Unix()

tx, err := ext.db.BeginTx(ctx)
if err != nil {
return fmt.Errorf("begin transaction: %w", err)
}
defer func() { _ = tx.Rollback(ctx) }()

for i := range streamRefs {
_, err := tx.Execute(ctx, fmt.Sprintf(
`INSERT INTO %s.primitive_events (stream_ref, event_time, value, created_at)
VALUES ($1, $2, $3, $4)`, SchemaName),
streamRefs[i], eventTimes[i], values[i], createdAt)
if err != nil {
return err
}
}

return tx.Commit(ctx)
}

// SetupSchema creates the ext_tn_local schema and all tables within a single transaction.
func (l *LocalDB) SetupSchema(ctx context.Context) error {
l.logger.Info("setting up local storage schema")
Expand Down
89 changes: 87 additions & 2 deletions extensions/tn_local/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"context"
"errors"
"fmt"
"math"
"regexp"
"strconv"
"strings"

"github.com/jackc/pgx/v5/pgconn"
Expand Down Expand Up @@ -88,9 +90,92 @@ func (ext *Extension) CreateStream(ctx context.Context, req *CreateStreamRequest
return &CreateStreamResponse{}, nil
}

// InsertRecords inserts records into a local primitive stream. (Task 4)
// InsertRecords inserts records into local primitive streams.
// Mirrors the consensus insert_records action (003-primitive-insertion.sql):
// - Parallel arrays: data_provider[], stream_id[], event_time[], value[]
// - Zero values are silently filtered (WHERE value != 0)
// - Multiple rows per (stream_ref, event_time) allowed (created_at versioning)
// - Returns empty response (consensus returns nothing)
func (ext *Extension) InsertRecords(ctx context.Context, req *InsertRecordsRequest) (*InsertRecordsResponse, *jsonrpc.Error) {
return nil, jsonrpc.NewError(jsonrpc.ErrorInternal, "not implemented", nil)
if req == nil {
return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, "missing request", nil)
}

n := len(req.DataProvider)
if n == 0 {
return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, "records must not be empty", nil)
}
if n != len(req.StreamID) || n != len(req.EventTime) || n != len(req.Value) {
return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, "array lengths mismatch", nil)
}

// Normalize data_providers to lowercase (consensus uses LOWER() in 001-common-actions.sql).
for i := range req.DataProvider {
req.DataProvider[i] = strings.ToLower(req.DataProvider[i])
}

// Validate all inputs upfront.
for i := 0; i < n; i++ {
if err := validateDataProvider(req.DataProvider[i]); err != nil {
return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, fmt.Sprintf("record %d: %v", i, err), nil)
}
if err := validateStreamID(req.StreamID[i]); err != nil {
return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, fmt.Sprintf("record %d: %v", i, err), nil)
}
f, parseErr := strconv.ParseFloat(req.Value[i], 64)
if parseErr != nil {
return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, fmt.Sprintf("invalid record value at index %d: %v", i, parseErr), nil)
}
if math.IsNaN(f) || math.IsInf(f, 0) {
return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, fmt.Sprintf("invalid record value at index %d: must be a finite number", i), nil)
}
}

// Resolve stream refs for unique (data_provider, stream_id) pairs.
type streamKey struct{ dp, sid string }
streamRefMap := make(map[streamKey]int64)
for i := 0; i < n; i++ {
key := streamKey{req.DataProvider[i], req.StreamID[i]}
if _, ok := streamRefMap[key]; ok {
continue
}
ref, stype, err := ext.dbLookupStreamRef(ctx, key.dp, key.sid)
if err != nil {
ext.logger.Error("failed to look up stream", "error", err)
return nil, jsonrpc.NewError(jsonrpc.ErrorInternal, "failed to look up stream", nil)
}
if ref == 0 {
return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, fmt.Sprintf("stream not found: %s/%s", key.dp, key.sid), nil)
}
if stype != "primitive" {
return nil, jsonrpc.NewError(jsonrpc.ErrorInvalidParams, fmt.Sprintf("stream %s/%s is not a primitive stream", key.dp, key.sid), nil)
}
streamRefMap[key] = ref
}

// Build resolved records, filtering zero values (mirrors consensus WHERE value != 0).
streamRefs := make([]int64, 0, n)
eventTimes := make([]int64, 0, n)
values := make([]string, 0, n)
for i := 0; i < n; i++ {
f, _ := strconv.ParseFloat(req.Value[i], 64)
if f == 0 {
continue
}
key := streamKey{req.DataProvider[i], req.StreamID[i]}
streamRefs = append(streamRefs, streamRefMap[key])
eventTimes = append(eventTimes, req.EventTime[i])
values = append(values, req.Value[i])
}

if len(streamRefs) > 0 {
if err := ext.dbInsertRecords(ctx, streamRefs, eventTimes, values); err != nil {
ext.logger.Error("failed to insert records", "error", err)
return nil, jsonrpc.NewError(jsonrpc.ErrorInternal, "failed to insert records", nil)
}
}

return &InsertRecordsResponse{}, nil
}

// InsertTaxonomy adds a taxonomy entry to a local composed stream. (Task 5)
Expand Down
Loading
Loading