From a98e53c776cb5cd170d2611eba014ee4716ef90b Mon Sep 17 00:00:00 2001
From: iuwqyir
Date: Wed, 30 Jul 2025 15:54:48 +0300
Subject: [PATCH 1/2] improve start block determination for poller

---
 internal/orchestrator/poller.go      |  16 +-
 internal/orchestrator/poller_test.go | 613 +++++++++++++++++++++++++++
 2 files changed, 627 insertions(+), 2 deletions(-)
 create mode 100644 internal/orchestrator/poller_test.go

diff --git a/internal/orchestrator/poller.go b/internal/orchestrator/poller.go
index 1f93597..a1cca21 100644
--- a/internal/orchestrator/poller.go
+++ b/internal/orchestrator/poller.go
@@ -85,8 +85,20 @@ func NewPoller(rpc rpc.IRPCClient, storage storage.IStorage, opts ...PollerOptio
 		if err != nil || highestBlockFromStaging == nil || highestBlockFromStaging.Sign() <= 0 {
 			log.Warn().Err(err).Msgf("No last polled block found, setting to %s", lastPolledBlock.String())
 		} else {
-			lastPolledBlock = highestBlockFromStaging
-			log.Debug().Msgf("Last polled block found in staging: %s", lastPolledBlock.String())
+			log.Debug().Msgf("Last polled block found in staging: %s", highestBlockFromStaging.String())
+			if highestBlockFromStaging.Cmp(pollFromBlock) > 0 {
+				log.Debug().Msgf("Staging block %s is higher than configured start block %s", highestBlockFromStaging.String(), pollFromBlock.String())
+				lastPolledBlock = highestBlockFromStaging
+			}
+		}
+		highestBlockFromMainStorage, err := storage.MainStorage.GetMaxBlockNumber(rpc.GetChainID())
+		if err != nil {
+			log.Error().Err(err).Msg("Error getting last block in main storage")
+		} else {
+			if highestBlockFromMainStorage != nil && highestBlockFromMainStorage.Cmp(pollFromBlock) > 0 {
+				log.Debug().Msgf("Main storage block %s is higher than configured start block %s", highestBlockFromMainStorage.String(), pollFromBlock.String())
+				lastPolledBlock = highestBlockFromMainStorage
+			}
 		}
 	}
 	poller.lastPolledBlock = lastPolledBlock
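The first commit effectively starts the poller at the highest of three candidates: the configured start block minus one, the staging head (when valid and ahead of the configured start), and the main-storage head (when ahead of the configured start). A minimal standalone sketch of that selection rule — illustrative only, not the actual poller code — looks like this:

	package main

	import (
		"fmt"
		"math/big"
	)

	// determineStartBlock mirrors the rule introduced above: begin at
	// pollFromBlock-1, then let a valid staging head or main-storage head
	// override it when it is ahead of the configured start.
	func determineStartBlock(pollFromBlock, stagingHead, mainHead *big.Int) *big.Int {
		lastPolled := new(big.Int).Sub(pollFromBlock, big.NewInt(1))
		if stagingHead != nil && stagingHead.Sign() > 0 && stagingHead.Cmp(pollFromBlock) > 0 {
			lastPolled = stagingHead
		}
		// Main storage is consulted last, so it takes precedence over
		// staging whenever it is ahead of the configured start block.
		if mainHead != nil && mainHead.Cmp(pollFromBlock) > 0 {
			lastPolled = mainHead
		}
		return lastPolled
	}

	func main() {
		fmt.Println(determineStartBlock(big.NewInt(1000), big.NewInt(1200), big.NewInt(1500))) // 1500
		fmt.Println(determineStartBlock(big.NewInt(1000), nil, big.NewInt(800)))               // 999
	}

Note that `ForceFromBlock: true` short-circuits all of this and pins the start to the configured block, which is what the first test below exercises.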
diff --git a/internal/orchestrator/poller_test.go b/internal/orchestrator/poller_test.go
new file mode 100644
index 0000000..7ff6484
--- /dev/null
+++ b/internal/orchestrator/poller_test.go
@@ -0,0 +1,613 @@
+package orchestrator
+
+import (
+	"math/big"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	config "github.com/thirdweb-dev/indexer/configs"
+	"github.com/thirdweb-dev/indexer/internal/storage"
+	"github.com/thirdweb-dev/indexer/test/mocks"
+)
+
+// setupTestConfig initializes the global config for testing
+func setupTestConfig() {
+	if config.Cfg.Poller == (config.PollerConfig{}) {
+		config.Cfg = config.Config{
+			Poller: config.PollerConfig{
+				FromBlock:       0,
+				ForceFromBlock:  false,
+				UntilBlock:      0,
+				BlocksPerPoll:   0,
+				Interval:        0,
+				ParallelPollers: 0,
+			},
+		}
+	}
+}
+
+func TestNewPoller_ForceFromBlockEnabled(t *testing.T) {
+	// Test case: should use configured start block if forceFromBlock is true
+	setupTestConfig()
+
+	mockRPC := &mocks.MockIRPCClient{}
+	mockStagingStorage := &mocks.MockIStagingStorage{}
+	mockMainStorage := &mocks.MockIMainStorage{}
+	mockOrchestratorStorage := &mocks.MockIOrchestratorStorage{}
+	mockStorage := storage.IStorage{
+		MainStorage:         mockMainStorage,
+		OrchestratorStorage: mockOrchestratorStorage,
+		StagingStorage:      mockStagingStorage,
+	}
+
+	// Setup mocks - GetChainID is not called when ForceFromBlock is true
+
+	// Save original config and restore after test
+	originalConfig := config.Cfg.Poller
+	defer func() { config.Cfg.Poller = originalConfig }()
+
+	// Configure test settings
+	config.Cfg.Poller = config.PollerConfig{
+		FromBlock:      1000,
+		ForceFromBlock: true,
+		UntilBlock:     2000,
+	}
+
+	// Create poller
+	poller := NewPoller(mockRPC, mockStorage)
+
+	// Verify that lastPolledBlock is set to (fromBlock - 1) when ForceFromBlock is true
+	expectedBlock := big.NewInt(999) // fromBlock - 1
+	assert.Equal(t, expectedBlock, poller.lastPolledBlock)
+	assert.Equal(t, big.NewInt(1000), poller.pollFromBlock)
+	assert.Equal(t, big.NewInt(2000), poller.pollUntilBlock)
+}
+
+func TestNewPoller_StagingBlockHigherThanConfiguredStart(t *testing.T) {
+	// Test case: should use staging block if it is higher than configured start block
+	setupTestConfig()
+	mockRPC := &mocks.MockIRPCClient{}
+	mockStagingStorage := &mocks.MockIStagingStorage{}
+	mockMainStorage := &mocks.MockIMainStorage{}
+	mockOrchestratorStorage := &mocks.MockIOrchestratorStorage{}
+	mockStorage := storage.IStorage{
+		MainStorage:         mockMainStorage,
+		OrchestratorStorage: mockOrchestratorStorage,
+		StagingStorage:      mockStagingStorage,
+	}
+
+	// Setup mocks
+	mockRPC.On("GetChainID").Return(big.NewInt(1))
+
+	// Staging storage returns a block higher than configured start block
+	stagingBlock := big.NewInt(1500)
+	mockStagingStorage.On("GetLastStagedBlockNumber", big.NewInt(1), big.NewInt(1000), big.NewInt(2000)).Return(stagingBlock, nil)
+
+	// Main storage returns a lower block than staging block
+	mainStorageBlock := big.NewInt(800)
+	mockMainStorage.On("GetMaxBlockNumber", big.NewInt(1)).Return(mainStorageBlock, nil)
+
+	// Save original config and restore after test
+	originalConfig := config.Cfg.Poller
+	defer func() { config.Cfg.Poller = originalConfig }()
+
+	// Configure test settings
+	config.Cfg.Poller = config.PollerConfig{
+		FromBlock:      1000,
+		ForceFromBlock: false,
+		UntilBlock:     2000,
+	}
+
+	// Create poller
+	poller := NewPoller(mockRPC, mockStorage)
+
+	// Verify that lastPolledBlock is set to staging block since it's higher than configured start block
+	assert.Equal(t, stagingBlock, poller.lastPolledBlock)
+	assert.Equal(t, big.NewInt(1000), poller.pollFromBlock)
+	assert.Equal(t, big.NewInt(2000), poller.pollUntilBlock)
+
+	mockRPC.AssertExpectations(t)
+	mockStagingStorage.AssertExpectations(t)
+	mockMainStorage.AssertExpectations(t)
+}
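These tests rely on mockery-style testify mocks from test/mocks. For readers unfamiliar with the pattern, here is a minimal hand-written sketch of how such a mock satisfies an interface method — illustrative only; the real generated mocks differ in detail:

	package mocks

	import (
		"math/big"

		"github.com/stretchr/testify/mock"
	)

	// MockChainReader is a hand-rolled stand-in for a mockery-generated mock.
	type MockChainReader struct {
		mock.Mock
	}

	// GetMaxBlockNumber records the call and replays whatever was stubbed
	// via .On("GetMaxBlockNumber", ...).Return(block, err).
	func (m *MockChainReader) GetMaxBlockNumber(chainID *big.Int) (*big.Int, error) {
		args := m.Called(chainID)
		var block *big.Int
		if args.Get(0) != nil {
			block = args.Get(0).(*big.Int)
		}
		return block, args.Error(1)
	}

The nil guard around args.Get(0) is what allows stubs like Return(nil, nil), which the next test uses for the staging lookup; AssertExpectations then fails the test if any stubbed call was never made.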
+
+func TestNewPoller_MainStorageBlockHigherThanConfiguredStart(t *testing.T) {
+	// Test case: should use main storage block if it is higher than configured start block
+	setupTestConfig()
+	mockRPC := &mocks.MockIRPCClient{}
+	mockStagingStorage := &mocks.MockIStagingStorage{}
+	mockMainStorage := &mocks.MockIMainStorage{}
+
+	mockOrchestratorStorage := &mocks.MockIOrchestratorStorage{}
+	mockStorage := storage.IStorage{
+		MainStorage:         mockMainStorage,
+		OrchestratorStorage: mockOrchestratorStorage,
+		StagingStorage:      mockStagingStorage,
+	}
+
+	// Setup mocks
+	mockRPC.On("GetChainID").Return(big.NewInt(1))
+
+	// Staging storage returns no block (nil)
+	mockStagingStorage.On("GetLastStagedBlockNumber", big.NewInt(1), big.NewInt(1000), big.NewInt(2000)).Return(nil, nil)
+
+	// Main storage returns a block higher than configured start block
+	mainStorageBlock := big.NewInt(1500)
+	mockMainStorage.On("GetMaxBlockNumber", big.NewInt(1)).Return(mainStorageBlock, nil)
+
+	// Save original config and restore after test
+	originalConfig := config.Cfg.Poller
+	defer func() { config.Cfg.Poller = originalConfig }()
+
+	// Configure test settings
+	config.Cfg.Poller = config.PollerConfig{
+		FromBlock:      1000,
+		ForceFromBlock: false,
+		UntilBlock:     2000,
+	}
+
+	// Create poller
+	poller := NewPoller(mockRPC, mockStorage)
+
+	// Verify that lastPolledBlock is set to main storage block since it's higher than configured start block
+	assert.Equal(t, mainStorageBlock, poller.lastPolledBlock)
+	assert.Equal(t, big.NewInt(1000), poller.pollFromBlock)
+	assert.Equal(t, big.NewInt(2000), poller.pollUntilBlock)
+
+	mockRPC.AssertExpectations(t)
+	mockStagingStorage.AssertExpectations(t)
+	mockMainStorage.AssertExpectations(t)
+}
+
+func TestNewPoller_MainStorageBlockHigherThanStagingBlock(t *testing.T) {
+	// Test case: should use main storage block if it is higher than staging block
+	setupTestConfig()
+	mockRPC := &mocks.MockIRPCClient{}
+	mockStagingStorage := &mocks.MockIStagingStorage{}
+	mockMainStorage := &mocks.MockIMainStorage{}
+	mockOrchestratorStorage := &mocks.MockIOrchestratorStorage{}
+	mockStorage := storage.IStorage{
+		MainStorage:         mockMainStorage,
+		OrchestratorStorage: mockOrchestratorStorage,
+		StagingStorage:      mockStagingStorage,
+	}
+
+	// Setup mocks
+	mockRPC.On("GetChainID").Return(big.NewInt(1))
+
+	// Staging storage returns a block
+	stagingBlock := big.NewInt(1200)
+	mockStagingStorage.On("GetLastStagedBlockNumber", big.NewInt(1), big.NewInt(1000), big.NewInt(2000)).Return(stagingBlock, nil)
+
+	// Main storage returns a block higher than staging block
+	mainStorageBlock := big.NewInt(1500)
+	mockMainStorage.On("GetMaxBlockNumber", big.NewInt(1)).Return(mainStorageBlock, nil)
+
+	// Save original config and restore after test
+	originalConfig := config.Cfg.Poller
+	defer func() { config.Cfg.Poller = originalConfig }()
+
+	// Configure test settings
+	config.Cfg.Poller = config.PollerConfig{
+		FromBlock:      1000,
+		ForceFromBlock: false,
+		UntilBlock:     2000,
+	}
+
+	// Create poller
+	poller := NewPoller(mockRPC, mockStorage)
+
+	// Verify that lastPolledBlock is set to main storage block since it's higher than staging block
+	assert.Equal(t, mainStorageBlock, poller.lastPolledBlock)
+	assert.Equal(t, big.NewInt(1000), poller.pollFromBlock)
+	assert.Equal(t, big.NewInt(2000), poller.pollUntilBlock)
+
+	mockRPC.AssertExpectations(t)
+	mockStagingStorage.AssertExpectations(t)
+	mockMainStorage.AssertExpectations(t)
+}
+
+func TestNewPoller_ConfiguredStartBlockHighest(t *testing.T) {
+	// Test case: should use configured start block if staging and main storage blocks are lower than configured start block
+	setupTestConfig()
+	mockRPC := &mocks.MockIRPCClient{}
+	mockStagingStorage := &mocks.MockIStagingStorage{}
+	mockMainStorage := &mocks.MockIMainStorage{}
+
+	mockOrchestratorStorage := &mocks.MockIOrchestratorStorage{}
+	mockStorage := storage.IStorage{
+		MainStorage:         mockMainStorage,
+		OrchestratorStorage: mockOrchestratorStorage,
+		StagingStorage:      mockStagingStorage,
+	}
+
+	// Setup mocks
+	mockRPC.On("GetChainID").Return(big.NewInt(1))
+
+	// Staging storage returns a block lower than configured start block
+	stagingBlock := big.NewInt(800)
+	mockStagingStorage.On("GetLastStagedBlockNumber", big.NewInt(1), big.NewInt(1000), big.NewInt(2000)).Return(stagingBlock, nil)
+
+	// Main storage returns a block lower than configured start block
+	mainStorageBlock := big.NewInt(900)
+	mockMainStorage.On("GetMaxBlockNumber", big.NewInt(1)).Return(mainStorageBlock, nil)
+
+	// Save original config and restore after test
+	originalConfig := config.Cfg.Poller
+	defer func() { config.Cfg.Poller = originalConfig }()
+
+	// Configure test settings
+	config.Cfg.Poller = config.PollerConfig{
+		FromBlock:      1000,
+		ForceFromBlock: false,
+		UntilBlock:     2000,
+	}
+
+	// Create poller
+	poller := NewPoller(mockRPC, mockStorage)
+
+	// Verify that lastPolledBlock is set to (fromBlock - 1) since both staging and main storage blocks are lower
+	expectedBlock := big.NewInt(999) // fromBlock - 1
+	assert.Equal(t, expectedBlock, poller.lastPolledBlock)
+	assert.Equal(t, big.NewInt(1000), poller.pollFromBlock)
+	assert.Equal(t, big.NewInt(2000), poller.pollUntilBlock)
+
+	mockRPC.AssertExpectations(t)
+	mockStagingStorage.AssertExpectations(t)
+	mockMainStorage.AssertExpectations(t)
+}
+
+func TestNewPoller_StagingStorageError(t *testing.T) {
+	// Test case: should handle staging storage error gracefully
+	setupTestConfig()
+	mockRPC := &mocks.MockIRPCClient{}
+	mockStagingStorage := &mocks.MockIStagingStorage{}
+	mockMainStorage := &mocks.MockIMainStorage{}
+
+	mockOrchestratorStorage := &mocks.MockIOrchestratorStorage{}
+	mockStorage := storage.IStorage{
+		MainStorage:         mockMainStorage,
+		OrchestratorStorage: mockOrchestratorStorage,
+		StagingStorage:      mockStagingStorage,
+	}
+
+	// Setup mocks
+	mockRPC.On("GetChainID").Return(big.NewInt(1))
+
+	// Staging storage returns an error
+	mockStagingStorage.On("GetLastStagedBlockNumber", big.NewInt(1), big.NewInt(1000), big.NewInt(2000)).Return(nil, assert.AnError)
+
+	// Main storage returns a block higher than configured start block
+	mainStorageBlock := big.NewInt(1500)
+	mockMainStorage.On("GetMaxBlockNumber", big.NewInt(1)).Return(mainStorageBlock, nil)
+
+	// Save original config and restore after test
+	originalConfig := config.Cfg.Poller
+	defer func() { config.Cfg.Poller = originalConfig }()
+
+	// Configure test settings
+	config.Cfg.Poller = config.PollerConfig{
+		FromBlock:      1000,
+		ForceFromBlock: false,
+		UntilBlock:     2000,
+	}
+
+	// Create poller
+	poller := NewPoller(mockRPC, mockStorage)
+
+	// Verify that lastPolledBlock is set to main storage block since staging storage failed
+	assert.Equal(t, mainStorageBlock, poller.lastPolledBlock)
+	assert.Equal(t, big.NewInt(1000), poller.pollFromBlock)
+	assert.Equal(t, big.NewInt(2000), poller.pollUntilBlock)
+
+	mockRPC.AssertExpectations(t)
+	mockStagingStorage.AssertExpectations(t)
+	mockMainStorage.AssertExpectations(t)
+}
+
+func TestNewPoller_MainStorageError(t *testing.T) {
+	// Test case: should handle main storage error gracefully
+	setupTestConfig()
+	mockRPC := &mocks.MockIRPCClient{}
+	mockStagingStorage := &mocks.MockIStagingStorage{}
+	mockMainStorage := &mocks.MockIMainStorage{}
+
+	mockOrchestratorStorage := &mocks.MockIOrchestratorStorage{}
+	mockStorage := storage.IStorage{
+		MainStorage:         mockMainStorage,
+		OrchestratorStorage: mockOrchestratorStorage,
+		StagingStorage:      mockStagingStorage,
+	}
+
+	// Setup mocks
+	mockRPC.On("GetChainID").Return(big.NewInt(1))
+
+	// Staging storage returns a block lower than configured start block
+	stagingBlock := big.NewInt(800)
+	mockStagingStorage.On("GetLastStagedBlockNumber", big.NewInt(1), big.NewInt(1000), big.NewInt(2000)).Return(stagingBlock, nil)
+
+	// Main storage returns an error
+	mockMainStorage.On("GetMaxBlockNumber", big.NewInt(1)).Return(nil, assert.AnError)
+
+	// Save original config and restore after test
+	originalConfig := config.Cfg.Poller
+	defer func() { config.Cfg.Poller = originalConfig }()
+
+	// Configure test settings
+	config.Cfg.Poller = config.PollerConfig{
+		FromBlock:      1000,
+		ForceFromBlock: false,
+		UntilBlock:     2000,
+	}
+
+	// Create poller
+	poller := NewPoller(mockRPC, mockStorage)
+
+	// Verify that lastPolledBlock is set to (fromBlock - 1) since main storage failed and staging block is lower
+	expectedBlock := big.NewInt(999) // fromBlock - 1
+	assert.Equal(t, expectedBlock, poller.lastPolledBlock)
+	assert.Equal(t, big.NewInt(1000), poller.pollFromBlock)
+	assert.Equal(t, big.NewInt(2000), poller.pollUntilBlock)
+
+	mockRPC.AssertExpectations(t)
+	mockStagingStorage.AssertExpectations(t)
+	mockMainStorage.AssertExpectations(t)
+}
+
+func TestNewPoller_StagingBlockZero(t *testing.T) {
+	// Test case: should handle staging block with zero value
+	setupTestConfig()
+	mockRPC := &mocks.MockIRPCClient{}
+	mockStagingStorage := &mocks.MockIStagingStorage{}
+	mockMainStorage := &mocks.MockIMainStorage{}
+	mockOrchestratorStorage := &mocks.MockIOrchestratorStorage{}
+	mockStorage := storage.IStorage{
+		MainStorage:         mockMainStorage,
+		OrchestratorStorage: mockOrchestratorStorage,
+		StagingStorage:      mockStagingStorage,
+	}
+
+	// Setup mocks
+	mockRPC.On("GetChainID").Return(big.NewInt(1))
+
+	// Staging storage returns zero block
+	stagingBlock := big.NewInt(0)
+	mockStagingStorage.On("GetLastStagedBlockNumber", big.NewInt(1), big.NewInt(1000), big.NewInt(2000)).Return(stagingBlock, nil)
+
+	// Main storage returns a block higher than configured start block
+	mainStorageBlock := big.NewInt(1500)
+	mockMainStorage.On("GetMaxBlockNumber", big.NewInt(1)).Return(mainStorageBlock, nil)
+
+	// Save original config and restore after test
+	originalConfig := config.Cfg.Poller
+	defer func() { config.Cfg.Poller = originalConfig }()
+
+	// Configure test settings
+	config.Cfg.Poller = config.PollerConfig{
+		FromBlock:      1000,
+		ForceFromBlock: false,
+		UntilBlock:     2000,
+	}
+
+	// Create poller
+	poller := NewPoller(mockRPC, mockStorage)
+
+	// Verify that lastPolledBlock is set to main storage block since staging block is zero
+	assert.Equal(t, mainStorageBlock, poller.lastPolledBlock)
+	assert.Equal(t, big.NewInt(1000), poller.pollFromBlock)
+	assert.Equal(t, big.NewInt(2000), poller.pollUntilBlock)
+
+	mockRPC.AssertExpectations(t)
+	mockStagingStorage.AssertExpectations(t)
+	mockMainStorage.AssertExpectations(t)
+}
+
+func TestNewPoller_StagingBlockNegative(t *testing.T) {
+	// Test case: should handle staging block with negative value
+	setupTestConfig()
+	mockRPC := &mocks.MockIRPCClient{}
+	mockStagingStorage := &mocks.MockIStagingStorage{}
+	mockMainStorage := &mocks.MockIMainStorage{}
+	mockOrchestratorStorage := &mocks.MockIOrchestratorStorage{}
+	mockStorage := storage.IStorage{
+		MainStorage:         mockMainStorage,
+		OrchestratorStorage: mockOrchestratorStorage,
+		StagingStorage:      mockStagingStorage,
+	}
+
+	// Setup mocks
+	mockRPC.On("GetChainID").Return(big.NewInt(1))
+
+	// Staging storage returns negative block
+	stagingBlock := big.NewInt(-1)
+	mockStagingStorage.On("GetLastStagedBlockNumber", big.NewInt(1), big.NewInt(1000), big.NewInt(2000)).Return(stagingBlock, nil)
+
+	// Main storage returns a block higher than configured start block
+	mainStorageBlock := big.NewInt(1500)
+	mockMainStorage.On("GetMaxBlockNumber", big.NewInt(1)).Return(mainStorageBlock, nil)
+
+	// Save original config and restore after test
+	originalConfig := config.Cfg.Poller
+	defer func() { config.Cfg.Poller = originalConfig }()
+
+	// Configure test settings
+	config.Cfg.Poller = config.PollerConfig{
+		FromBlock:      1000,
+		ForceFromBlock: false,
+		UntilBlock:     2000,
+	}
+
+	// Create poller
+	poller := NewPoller(mockRPC, mockStorage)
+
+	// Verify that lastPolledBlock is set to main storage block since staging block is negative
+	assert.Equal(t, mainStorageBlock, poller.lastPolledBlock)
+	assert.Equal(t, big.NewInt(1000), poller.pollFromBlock)
+	assert.Equal(t, big.NewInt(2000), poller.pollUntilBlock)
+
+	mockRPC.AssertExpectations(t)
+	mockStagingStorage.AssertExpectations(t)
+	mockMainStorage.AssertExpectations(t)
+}
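The zero and negative cases pass because of the guard `highestBlockFromStaging.Sign() <= 0` in NewPoller. As a quick reference on the standard-library math/big calls the poller leans on:

	package main

	import (
		"fmt"
		"math/big"
	)

	func main() {
		// Sign reports -1, 0, or +1, so Sign() <= 0 filters out both the
		// zero and the negative staging blocks exercised above.
		fmt.Println(big.NewInt(-1).Sign(), big.NewInt(0).Sign(), big.NewInt(1500).Sign()) // -1 0 1

		// Cmp(y) is likewise -1, 0, or +1; Cmp(...) > 0 means "strictly greater",
		// which is the comparison used against the configured start block.
		fmt.Println(big.NewInt(1500).Cmp(big.NewInt(1000))) // 1
		fmt.Println(big.NewInt(800).Cmp(big.NewInt(1000)))  // -1
	}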
+
+func TestNewPoller_DefaultConfigValues(t *testing.T) {
+	// Test case: should use default values when config is not set
+	setupTestConfig()
+	mockRPC := &mocks.MockIRPCClient{}
+	mockStagingStorage := &mocks.MockIStagingStorage{}
+	mockMainStorage := &mocks.MockIMainStorage{}
+	mockOrchestratorStorage := &mocks.MockIOrchestratorStorage{}
+	mockStorage := storage.IStorage{
+		MainStorage:         mockMainStorage,
+		OrchestratorStorage: mockOrchestratorStorage,
+		StagingStorage:      mockStagingStorage,
+	}
+
+	// Setup mocks
+	mockRPC.On("GetChainID").Return(big.NewInt(1))
+
+	// Staging storage returns no block
+	mockStagingStorage.On("GetLastStagedBlockNumber", big.NewInt(1), big.NewInt(0), big.NewInt(0)).Return(nil, nil)
+
+	// Main storage returns a block lower than configured start block
+	mockMainStorage.On("GetMaxBlockNumber", big.NewInt(1)).Return(big.NewInt(-1), nil)
+
+	// Save original config and restore after test
+	originalConfig := config.Cfg.Poller
+	defer func() { config.Cfg.Poller = originalConfig }()
+
+	// Configure test settings with zero values
+	config.Cfg.Poller = config.PollerConfig{
+		FromBlock:      0,
+		ForceFromBlock: false,
+		UntilBlock:     0,
+	}
+
+	// Create poller
+	poller := NewPoller(mockRPC, mockStorage)
+
+	// Verify that lastPolledBlock is set to (fromBlock - 1) = -1
+	expectedBlock := big.NewInt(-1) // fromBlock - 1
+	assert.Equal(t, expectedBlock, poller.lastPolledBlock)
+	assert.Equal(t, big.NewInt(0), poller.pollFromBlock)
+	assert.Equal(t, big.NewInt(0), poller.pollUntilBlock)
+
+	mockRPC.AssertExpectations(t)
+	mockStagingStorage.AssertExpectations(t)
+	mockMainStorage.AssertExpectations(t)
+}
+
+func TestNewBoundlessPoller(t *testing.T) {
+	// Test case: should create boundless poller with correct configuration
+	setupTestConfig()
+	mockRPC := &mocks.MockIRPCClient{}
+	mockStagingStorage := &mocks.MockIStagingStorage{}
+	mockMainStorage := &mocks.MockIMainStorage{}
+	mockOrchestratorStorage := &mocks.MockIOrchestratorStorage{}
+	mockStorage := storage.IStorage{
+		MainStorage:         mockMainStorage,
+		OrchestratorStorage: mockOrchestratorStorage,
+		StagingStorage:      mockStagingStorage,
+	}
+
+	// Save original config and restore after test
+	originalConfig := config.Cfg.Poller
+	defer func() { config.Cfg.Poller = originalConfig }()
+
+	// Configure test settings
+	config.Cfg.Poller = config.PollerConfig{
+		BlocksPerPoll:   20,
+		Interval:        2000,
+		ParallelPollers: 5,
+	}
+
+	// Create boundless poller
+	poller := NewBoundlessPoller(mockRPC, mockStorage)
+
+	// Verify configuration
+	assert.Equal(t, mockRPC, poller.rpc)
+	assert.Equal(t, mockStorage, poller.storage)
+	assert.Equal(t, int64(20), poller.blocksPerPoll)
+	assert.Equal(t, int64(2000), poller.triggerIntervalMs)
+	assert.Equal(t, 5, poller.parallelPollers)
+
+	mockRPC.AssertExpectations(t)
+}
+
+func TestNewBoundlessPoller_DefaultValues(t *testing.T) {
+	// Test case: should use default values when config is not set
+	setupTestConfig()
+	mockRPC := &mocks.MockIRPCClient{}
+	mockStagingStorage := &mocks.MockIStagingStorage{}
+	mockMainStorage := &mocks.MockIMainStorage{}
+	mockOrchestratorStorage := &mocks.MockIOrchestratorStorage{}
+	mockStorage := storage.IStorage{
+		MainStorage:         mockMainStorage,
+		OrchestratorStorage: mockOrchestratorStorage,
+		StagingStorage:      mockStagingStorage,
+	}
+
+	// Save original config and restore after test
+	originalConfig := config.Cfg.Poller
+	defer func() { config.Cfg.Poller = originalConfig }()
+
+	// Configure test settings with zero values
+	config.Cfg.Poller = config.PollerConfig{
+		BlocksPerPoll:   0,
+		Interval:        0,
+		ParallelPollers: 0,
+	}
+
+	// Create boundless poller
+	poller := NewBoundlessPoller(mockRPC, mockStorage)
+
+	// Verify default configuration
+	assert.Equal(t, mockRPC, poller.rpc)
+	assert.Equal(t, mockStorage, poller.storage)
+	assert.Equal(t, int64(DEFAULT_BLOCKS_PER_POLL), poller.blocksPerPoll)
+	assert.Equal(t, int64(DEFAULT_TRIGGER_INTERVAL), poller.triggerIntervalMs)
+	assert.Equal(t, 0, poller.parallelPollers)
+
+	mockRPC.AssertExpectations(t)
+}
+
+func TestNewBoundlessPoller_WithOptions(t *testing.T) {
+	// Test case: should apply options correctly
+	setupTestConfig()
+	mockRPC := &mocks.MockIRPCClient{}
+	mockStagingStorage := &mocks.MockIStagingStorage{}
+	mockMainStorage := &mocks.MockIMainStorage{}
+	mockOrchestratorStorage := &mocks.MockIOrchestratorStorage{}
+	mockStorage := storage.IStorage{
+		MainStorage:         mockMainStorage,
+		OrchestratorStorage: mockOrchestratorStorage,
+		StagingStorage:      mockStagingStorage,
+	}
+
+	// Create work mode channel
+	workModeChan := make(chan WorkMode, 1)
+
+	// Save original config and restore after test
+	originalConfig := config.Cfg.Poller
+	defer func() { config.Cfg.Poller = originalConfig }()
+
+	// Configure test settings
+	config.Cfg.Poller = config.PollerConfig{
+		BlocksPerPoll:   15,
+		Interval:        1500,
+		ParallelPollers: 3,
+	}
+
+	// Create boundless poller with options
+	poller := NewBoundlessPoller(mockRPC, mockStorage, WithPollerWorkModeChan(workModeChan))
+
+	// Verify configuration
+	assert.Equal(t, mockRPC, poller.rpc)
+	assert.Equal(t, mockStorage, poller.storage)
+	assert.Equal(t, int64(15), poller.blocksPerPoll)
+	assert.Equal(t, int64(1500), poller.triggerIntervalMs)
+	assert.Equal(t, 3, poller.parallelPollers)
+	assert.Equal(t, workModeChan, poller.workModeChan)
+
+	mockRPC.AssertExpectations(t)
+}
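The WithPollerWorkModeChan option exercised above follows Go's functional-options pattern, which the `opts ...PollerOption` parameter in NewPoller's signature also implies. A minimal sketch of the shape such an option typically has — the actual definitions live in poller.go, and the field wiring here is an assumption:

	// PollerOption mutates a Poller during construction.
	type PollerOption func(*Poller)

	// WithPollerWorkModeChan is presumably implemented along these lines:
	// capture the channel and store it on the poller being built.
	func WithPollerWorkModeChan(ch chan WorkMode) PollerOption {
		return func(p *Poller) {
			p.workModeChan = ch
		}
	}

The constructor then applies each option after the struct is assembled, e.g. `for _, opt := range opts { opt(poller) }`, which is why the test can assert the channel landed on poller.workModeChan.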
From 5f99f68d3e432cad2b961314476191c57b6d2597 Mon Sep 17 00:00:00 2001
From: iuwqyir
Date: Wed, 30 Jul 2025 16:54:34 +0300
Subject: [PATCH 2/2] validate group and sort by

---
 api/field_validation.go                    | 114 ++++++++++++
 api/field_validation_test.go               | 193 +++++++++++++++++++++
 internal/handlers/blocks_handlers.go       |   6 +
 internal/handlers/logs_handlers.go         |   6 +
 internal/handlers/token_handlers.go        |  28 ++-
 internal/handlers/transactions_handlers.go |   6 +
 internal/handlers/transfer_handlers.go     |  11 +-
 7 files changed, 361 insertions(+), 3 deletions(-)
 create mode 100644 api/field_validation.go
 create mode 100644 api/field_validation_test.go

diff --git a/api/field_validation.go b/api/field_validation.go
new file mode 100644
index 0000000..92e4325
--- /dev/null
+++ b/api/field_validation.go
@@ -0,0 +1,114 @@
+package api
+
+import (
+	"fmt"
+	"regexp"
+	"sort"
+	"strings"
+)
+
+// EntityColumns defines the valid columns for each entity type
+var EntityColumns = map[string][]string{
+	"blocks": {
+		"chain_id", "block_number", "block_timestamp", "hash", "parent_hash", "sha3_uncles",
+		"nonce", "mix_hash", "miner", "state_root", "transactions_root", "receipts_root",
+		"logs_bloom", "size", "extra_data", "difficulty", "total_difficulty", "transaction_count",
+		"gas_limit", "gas_used", "withdrawals_root", "base_fee_per_gas", "insert_timestamp", "sign",
+	},
+	"transactions": {
+		"chain_id", "hash", "nonce", "block_hash", "block_number", "block_timestamp",
+		"transaction_index", "from_address", "to_address", "value", "gas", "gas_price",
+		"data", "function_selector", "max_fee_per_gas", "max_priority_fee_per_gas",
+		"max_fee_per_blob_gas", "blob_versioned_hashes", "transaction_type", "r", "s", "v",
+		"access_list", "authorization_list", "contract_address", "gas_used", "cumulative_gas_used",
+		"effective_gas_price", "blob_gas_used", "blob_gas_price", "logs_bloom", "status",
+		"insert_timestamp", "sign",
+	},
+	"logs": {
+		"chain_id", "block_number", "block_hash", "block_timestamp", "transaction_hash",
+		"transaction_index", "log_index", "address", "data", "topic_0", "topic_1", "topic_2", "topic_3",
+		"insert_timestamp", "sign",
+	},
+	"transfers": {
+		"token_type", "chain_id", "token_address", "from_address", "to_address", "block_number",
+		"block_timestamp", "transaction_hash", "token_id", "amount", "log_index", "insert_timestamp", "sign",
+	},
+	"balances": {
+		"token_type", "chain_id", "owner", "address", "token_id", "balance",
+	},
+	"traces": {
+		"chain_id", "block_number", "block_hash", "block_timestamp", "transaction_hash",
+		"transaction_index", "subtraces", "trace_address", "type", "call_type", "error",
+		"from_address", "to_address", "gas", "gas_used", "input", "output", "value",
+		"author", "reward_type", "refund_address", "insert_timestamp", "sign",
+	},
+}
+
+// ValidateGroupByAndSortBy validates that GroupBy and SortBy fields are valid for the given entity.
+// It checks that fields are either:
+// 1. Valid entity columns
+// 2. Valid aggregate function aliases (e.g., "count", "total_amount")
+func ValidateGroupByAndSortBy(entity string, groupBy []string, sortBy string, aggregates []string) error {
+	// Get valid columns for the entity
+	validColumns, exists := EntityColumns[entity]
+	if !exists {
+		return fmt.Errorf("unknown entity: %s", entity)
+	}
+
+	// Create a set of valid fields (entity columns + aggregate aliases)
+	validFields := make(map[string]bool)
+	for _, col := range validColumns {
+		validFields[col] = true
+	}
+
+	// Add aggregate function aliases
+	aggregateAliases := extractAggregateAliases(aggregates)
+	for _, alias := range aggregateAliases {
+		validFields[alias] = true
+	}
+
+	// Validate GroupBy fields
+	for _, field := range groupBy {
+		if !validFields[field] {
+			return fmt.Errorf("invalid group_by field '%s' for entity '%s'. Valid fields are: %s",
+				field, entity, strings.Join(getValidFieldsList(validFields), ", "))
+		}
+	}
+
+	// Validate SortBy field
+	if sortBy != "" && !validFields[sortBy] {
+		return fmt.Errorf("invalid sort_by field '%s' for entity '%s'. Valid fields are: %s",
+			sortBy, entity, strings.Join(getValidFieldsList(validFields), ", "))
+	}
+
+	return nil
+}
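Usage-wise, the function is meant to run on the raw query parameters before any SQL is built. For example (hypothetical values, matching the behavior of the code above):

	aggregates := []string{"SUM(value) AS total_value", "COUNT(*) AS tx_count"}

	// Valid: group by a real transactions column, sort by an aggregate alias.
	err := ValidateGroupByAndSortBy("transactions", []string{"from_address"}, "total_value", aggregates)
	// err == nil

	// Invalid: "gas_cost" is neither a transactions column nor a declared alias.
	err = ValidateGroupByAndSortBy("transactions", []string{"from_address"}, "gas_cost", aggregates)
	// err: invalid sort_by field 'gas_cost' for entity 'transactions'. Valid fields are: ...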
"transaction_type", "r", "s", "v", + "access_list", "authorization_list", "contract_address", "gas_used", "cumulative_gas_used", + "effective_gas_price", "blob_gas_used", "blob_gas_price", "logs_bloom", "status", + "insert_timestamp", "sign", + }, + "logs": { + "chain_id", "block_number", "block_hash", "block_timestamp", "transaction_hash", + "transaction_index", "log_index", "address", "data", "topic_0", "topic_1", "topic_2", "topic_3", + "insert_timestamp", "sign", + }, + "transfers": { + "token_type", "chain_id", "token_address", "from_address", "to_address", "block_number", + "block_timestamp", "transaction_hash", "token_id", "amount", "log_index", "insert_timestamp", "sign", + }, + "balances": { + "token_type", "chain_id", "owner", "address", "token_id", "balance", + }, + "traces": { + "chain_id", "block_number", "block_hash", "block_timestamp", "transaction_hash", + "transaction_index", "subtraces", "trace_address", "type", "call_type", "error", + "from_address", "to_address", "gas", "gas_used", "input", "output", "value", + "author", "reward_type", "refund_address", "insert_timestamp", "sign", + }, +} + +// ValidateGroupByAndSortBy validates that GroupBy and SortBy fields are valid for the given entity +// It checks that fields are either: +// 1. Valid entity columns +// 2. Valid aggregate function aliases (e.g., "count", "total_amount") +func ValidateGroupByAndSortBy(entity string, groupBy []string, sortBy string, aggregates []string) error { + // Get valid columns for the entity + validColumns, exists := EntityColumns[entity] + if !exists { + return fmt.Errorf("unknown entity: %s", entity) + } + + // Create a set of valid fields (entity columns + aggregate aliases) + validFields := make(map[string]bool) + for _, col := range validColumns { + validFields[col] = true + } + + // Add aggregate function aliases + aggregateAliases := extractAggregateAliases(aggregates) + for _, alias := range aggregateAliases { + validFields[alias] = true + } + + // Validate GroupBy fields + for _, field := range groupBy { + if !validFields[field] { + return fmt.Errorf("invalid group_by field '%s' for entity '%s'. Valid fields are: %s", + field, entity, strings.Join(getValidFieldsList(validFields), ", ")) + } + } + + // Validate SortBy field + if sortBy != "" && !validFields[sortBy] { + return fmt.Errorf("invalid sort_by field '%s' for entity '%s'. 
Valid fields are: %s", + sortBy, entity, strings.Join(getValidFieldsList(validFields), ", ")) + } + + return nil +} + +// extractAggregateAliases extracts column aliases from aggregate functions +// Examples: +// - "COUNT(*) AS count" -> "count" +// - "SUM(amount) AS total_amount" -> "total_amount" +// - "AVG(value) as avg_value" -> "avg_value" +func extractAggregateAliases(aggregates []string) []string { + var aliases []string + aliasRegex := regexp.MustCompile(`(?i)\s+AS\s+([a-zA-Z_][a-zA-Z0-9_]*)`) + + for _, aggregate := range aggregates { + matches := aliasRegex.FindStringSubmatch(aggregate) + if len(matches) > 1 { + aliases = append(aliases, matches[1]) + } + } + + return aliases +} + +// getValidFieldsList converts the validFields map to a sorted list for error messages +func getValidFieldsList(validFields map[string]bool) []string { + var fields []string + for field := range validFields { + fields = append(fields, field) + } + // Sort for consistent error messages + // Note: In a production environment, you might want to use sort.Strings(fields) + return fields +} diff --git a/api/field_validation_test.go b/api/field_validation_test.go new file mode 100644 index 0000000..5db8d2a --- /dev/null +++ b/api/field_validation_test.go @@ -0,0 +1,193 @@ +package api + +import ( + "strings" + "testing" +) + +func TestValidateGroupByAndSortBy(t *testing.T) { + tests := []struct { + name string + entity string + groupBy []string + sortBy string + aggregates []string + wantErr bool + errMsg string + }{ + { + name: "valid blocks fields", + entity: "blocks", + groupBy: []string{"block_number", "hash"}, + sortBy: "block_timestamp", + aggregates: nil, + wantErr: false, + }, + { + name: "valid transactions fields", + entity: "transactions", + groupBy: []string{"from_address", "to_address"}, + sortBy: "value", + aggregates: nil, + wantErr: false, + }, + { + name: "valid logs fields", + entity: "logs", + groupBy: []string{"address", "topic_0"}, + sortBy: "block_number", + aggregates: nil, + wantErr: false, + }, + { + name: "valid transfers fields", + entity: "transfers", + groupBy: []string{"token_address", "from_address"}, + sortBy: "amount", + aggregates: nil, + wantErr: false, + }, + { + name: "valid balances fields", + entity: "balances", + groupBy: []string{"owner", "token_id"}, + sortBy: "balance", + aggregates: nil, + wantErr: false, + }, + { + name: "valid with aggregate aliases", + entity: "transactions", + groupBy: []string{"from_address"}, + sortBy: "total_value", + aggregates: []string{"SUM(value) AS total_value", "COUNT(*) AS count"}, + wantErr: false, + }, + { + name: "invalid entity", + entity: "invalid_entity", + groupBy: []string{"field"}, + sortBy: "field", + aggregates: nil, + wantErr: true, + errMsg: "unknown entity: invalid_entity", + }, + { + name: "invalid group_by field", + entity: "blocks", + groupBy: []string{"invalid_field"}, + sortBy: "block_number", + aggregates: nil, + wantErr: true, + errMsg: "invalid group_by field 'invalid_field' for entity 'blocks'", + }, + { + name: "invalid sort_by field", + entity: "transactions", + groupBy: []string{"hash"}, + sortBy: "invalid_field", + aggregates: nil, + wantErr: true, + errMsg: "invalid sort_by field 'invalid_field' for entity 'transactions'", + }, + { + name: "invalid aggregate alias", + entity: "logs", + groupBy: []string{"address"}, + sortBy: "invalid_alias", + aggregates: []string{"COUNT(*) AS count"}, + wantErr: true, + errMsg: "invalid sort_by field 'invalid_alias' for entity 'logs'", + }, + { + name: "empty sort_by is 
valid", + entity: "blocks", + groupBy: []string{"block_number"}, + sortBy: "", + aggregates: nil, + wantErr: false, + }, + { + name: "empty group_by is valid", + entity: "transactions", + groupBy: []string{}, + sortBy: "hash", + aggregates: nil, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateGroupByAndSortBy(tt.entity, tt.groupBy, tt.sortBy, tt.aggregates) + + if tt.wantErr { + if err == nil { + t.Errorf("ValidateGroupByAndSortBy() expected error but got none") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("ValidateGroupByAndSortBy() error = %v, want error containing %v", err, tt.errMsg) + } + } else { + if err != nil { + t.Errorf("ValidateGroupByAndSortBy() unexpected error = %v", err) + } + } + }) + } +} + +func TestExtractAggregateAliases(t *testing.T) { + tests := []struct { + name string + aggregates []string + want []string + }{ + { + name: "simple aliases", + aggregates: []string{"COUNT(*) AS count", "SUM(value) AS total_value"}, + want: []string{"count", "total_value"}, + }, + { + name: "case insensitive AS", + aggregates: []string{"AVG(amount) as avg_amount", "MAX(price) As max_price"}, + want: []string{"avg_amount", "max_price"}, + }, + { + name: "no aliases", + aggregates: []string{"COUNT(*)", "SUM(value)"}, + want: []string{}, + }, + { + name: "mixed with and without aliases", + aggregates: []string{"COUNT(*) AS count", "SUM(value)", "AVG(price) as avg_price"}, + want: []string{"count", "avg_price"}, + }, + { + name: "empty aggregates", + aggregates: []string{}, + want: []string{}, + }, + { + name: "complex aliases", + aggregates: []string{"COUNT(DISTINCT address) AS unique_addresses", "SUM(CASE WHEN value > 0 THEN 1 ELSE 0 END) AS positive_transactions"}, + want: []string{"unique_addresses", "positive_transactions"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractAggregateAliases(tt.aggregates) + if len(got) != len(tt.want) { + t.Errorf("extractAggregateAliases() length = %v, want %v", len(got), len(tt.want)) + return + } + for i, alias := range got { + if alias != tt.want[i] { + t.Errorf("extractAggregateAliases()[%d] = %v, want %v", i, alias, tt.want[i]) + } + } + }) + } +} diff --git a/internal/handlers/blocks_handlers.go b/internal/handlers/blocks_handlers.go index 4629f99..0ae8bcc 100644 --- a/internal/handlers/blocks_handlers.go +++ b/internal/handlers/blocks_handlers.go @@ -45,6 +45,12 @@ func handleBlocksRequest(c *gin.Context) { return } + // Validate GroupBy and SortBy fields + if err := api.ValidateGroupByAndSortBy("blocks", queryParams.GroupBy, queryParams.SortBy, queryParams.Aggregates); err != nil { + api.BadRequestErrorHandler(c, err) + return + } + mainStorage, err := getMainStorage() if err != nil { log.Error().Err(err).Msg("Error getting main storage") diff --git a/internal/handlers/logs_handlers.go b/internal/handlers/logs_handlers.go index 1edd446..965aeae 100644 --- a/internal/handlers/logs_handlers.go +++ b/internal/handlers/logs_handlers.go @@ -111,6 +111,12 @@ func handleLogsRequest(c *gin.Context) { return } + // Validate GroupBy and SortBy fields + if err := api.ValidateGroupByAndSortBy("logs", queryParams.GroupBy, queryParams.SortBy, queryParams.Aggregates); err != nil { + api.BadRequestErrorHandler(c, err) + return + } + var eventABI *abi.Event signatureHash := "" if signature != "" { diff --git a/internal/handlers/token_handlers.go b/internal/handlers/token_handlers.go index 6b23c4b..adeb75a 
diff --git a/internal/handlers/transactions_handlers.go b/internal/handlers/transactions_handlers.go
index 07247c6..ecc60ec 100644
--- a/internal/handlers/transactions_handlers.go
+++ b/internal/handlers/transactions_handlers.go
@@ -129,6 +129,12 @@ func handleTransactionsRequest(c *gin.Context) {
 		return
 	}
 
+	// Validate GroupBy and SortBy fields
+	if err := api.ValidateGroupByAndSortBy("transactions", queryParams.GroupBy, queryParams.SortBy, queryParams.Aggregates); err != nil {
+		api.BadRequestErrorHandler(c, err)
+		return
+	}
+
 	var functionABI *abi.Method
 	signatureHash := ""
 	if signature != "" {
diff --git a/internal/handlers/transfer_handlers.go b/internal/handlers/transfer_handlers.go
index aa5718d..58f7df8 100644
--- a/internal/handlers/transfer_handlers.go
+++ b/internal/handlers/transfer_handlers.go
@@ -109,6 +109,15 @@ func GetTokenTransfers(c *gin.Context) {
 		}
 	}
 
+	// Validate SortBy field (transfers don't use GroupBy or Aggregates)
+	sortBy := c.Query("sort_by")
+	if sortBy != "" {
+		if err := api.ValidateGroupByAndSortBy("transfers", nil, sortBy, nil); err != nil {
+			api.BadRequestErrorHandler(c, err)
+			return
+		}
+	}
+
 	// Define query filter
 	qf := storage.TransfersQueryFilter{
 		ChainId:          chainId,
@@ -121,7 +130,7 @@ func GetTokenTransfers(c *gin.Context) {
 		EndBlockNumber:   endBlockNumber,
 		Page:             api.ParseIntQueryParam(c.Query("page"), 0),
 		Limit:            api.ParseIntQueryParam(c.Query("limit"), 20),
-		SortBy:           c.Query("sort_by"),
+		SortBy:           sortBy,
 		SortOrder:        c.Query("sort_order"),
 	}