Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions api/field_validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package api

import (
	"fmt"
	"regexp"
	"sort"
	"strings"
)

// EntityColumns defines the valid columns for each entity type.
// These lists act as a whitelist for user-supplied group_by / sort_by fields;
// presumably they mirror the storage-layer table schemas — verify against the
// storage definitions whenever a column is added or removed there.
var EntityColumns = map[string][]string{
	// Block header fields.
	"blocks": {
		"chain_id", "block_number", "block_timestamp", "hash", "parent_hash", "sha3_uncles",
		"nonce", "mix_hash", "miner", "state_root", "transactions_root", "receipts_root",
		"logs_bloom", "size", "extra_data", "difficulty", "total_difficulty", "transaction_count",
		"gas_limit", "gas_used", "withdrawals_root", "base_fee_per_gas", "insert_timestamp", "sign",
	},
	// Transaction fields, including EIP-1559/EIP-4844 fee fields and receipt data.
	"transactions": {
		"chain_id", "hash", "nonce", "block_hash", "block_number", "block_timestamp",
		"transaction_index", "from_address", "to_address", "value", "gas", "gas_price",
		"data", "function_selector", "max_fee_per_gas", "max_priority_fee_per_gas",
		"max_fee_per_blob_gas", "blob_versioned_hashes", "transaction_type", "r", "s", "v",
		"access_list", "authorization_list", "contract_address", "gas_used", "cumulative_gas_used",
		"effective_gas_price", "blob_gas_used", "blob_gas_price", "logs_bloom", "status",
		"insert_timestamp", "sign",
	},
	// Event log fields; topics are split into four indexed columns.
	"logs": {
		"chain_id", "block_number", "block_hash", "block_timestamp", "transaction_hash",
		"transaction_index", "log_index", "address", "data", "topic_0", "topic_1", "topic_2", "topic_3",
		"insert_timestamp", "sign",
	},
	// Token transfer fields (fungible and non-fungible).
	"transfers": {
		"token_type", "chain_id", "token_address", "from_address", "to_address", "block_number",
		"block_timestamp", "transaction_hash", "token_id", "amount", "log_index", "insert_timestamp", "sign",
	},
	// Token balance fields; note this entity has no insert_timestamp/sign columns.
	"balances": {
		"token_type", "chain_id", "owner", "address", "token_id", "balance",
	},
	// Execution trace fields.
	"traces": {
		"chain_id", "block_number", "block_hash", "block_timestamp", "transaction_hash",
		"transaction_index", "subtraces", "trace_address", "type", "call_type", "error",
		"from_address", "to_address", "gas", "gas_used", "input", "output", "value",
		"author", "reward_type", "refund_address", "insert_timestamp", "sign",
	},
}

// ValidateGroupByAndSortBy validates that GroupBy and SortBy fields are valid
// for the given entity. A field is accepted when it is either:
//  1. a known column of the entity (see EntityColumns), or
//  2. an alias introduced by one of the aggregate expressions
//     (e.g. "count" from "COUNT(*) AS count").
//
// An empty sortBy is always accepted; an unknown entity is an error.
func ValidateGroupByAndSortBy(entity string, groupBy []string, sortBy string, aggregates []string) error {
	columns, ok := EntityColumns[entity]
	if !ok {
		return fmt.Errorf("unknown entity: %s", entity)
	}

	// Build the set of accepted fields: entity columns plus aggregate aliases.
	allowed := make(map[string]bool, len(columns))
	for _, column := range columns {
		allowed[column] = true
	}
	for _, alias := range extractAggregateAliases(aggregates) {
		allowed[alias] = true
	}

	// Shared error builder so both rejection paths produce identical wording.
	reject := func(kind, field string) error {
		return fmt.Errorf("invalid %s field '%s' for entity '%s'. Valid fields are: %s",
			kind, field, entity, strings.Join(getValidFieldsList(allowed), ", "))
	}

	for _, field := range groupBy {
		if !allowed[field] {
			return reject("group_by", field)
		}
	}

	if sortBy != "" && !allowed[sortBy] {
		return reject("sort_by", sortBy)
	}

	return nil
}

// aggregateAliasRegex matches a SQL "AS <identifier>" clause (case-insensitive)
// and captures the alias identifier. Compiled once at package init so callers
// do not pay regex compilation on every request.
var aggregateAliasRegex = regexp.MustCompile(`(?i)\s+AS\s+([a-zA-Z_][a-zA-Z0-9_]*)`)

// extractAggregateAliases extracts column aliases from aggregate functions.
// Examples:
//   - "COUNT(*) AS count" -> "count"
//   - "SUM(amount) AS total_amount" -> "total_amount"
//   - "AVG(value) as avg_value" -> "avg_value"
//
// Expressions without an AS clause contribute no alias; only the first AS
// clause in each expression is considered. Returns nil when no aliases are
// found.
func extractAggregateAliases(aggregates []string) []string {
	var aliases []string
	for _, aggregate := range aggregates {
		if matches := aggregateAliasRegex.FindStringSubmatch(aggregate); len(matches) > 1 {
			aliases = append(aliases, matches[1])
		}
	}
	return aliases
}

// getValidFieldsList converts the validFields set to a sorted slice.
// Sorting makes validation error messages deterministic: Go map iteration
// order is randomized, so without it the "Valid fields are: ..." list would
// change between requests (the original implementation left a TODO for this).
func getValidFieldsList(validFields map[string]bool) []string {
	fields := make([]string, 0, len(validFields))
	for field := range validFields {
		fields = append(fields, field)
	}
	sort.Strings(fields)
	return fields
}
193 changes: 193 additions & 0 deletions api/field_validation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
package api

import (
"strings"
"testing"
)

// TestValidateGroupByAndSortBy covers the validator's happy path for each
// entity type, aggregate-alias acceptance, and the error paths (unknown
// entity, invalid group_by field, invalid sort_by field). Error checks use
// substring matching so they do not depend on the ordering of the
// "Valid fields are: ..." list in the message.
func TestValidateGroupByAndSortBy(t *testing.T) {
	tests := []struct {
		name       string
		entity     string
		groupBy    []string
		sortBy     string
		aggregates []string
		wantErr    bool   // whether an error is expected
		errMsg     string // substring the error must contain; checked only when non-empty
	}{
		{
			name:       "valid blocks fields",
			entity:     "blocks",
			groupBy:    []string{"block_number", "hash"},
			sortBy:     "block_timestamp",
			aggregates: nil,
			wantErr:    false,
		},
		{
			name:       "valid transactions fields",
			entity:     "transactions",
			groupBy:    []string{"from_address", "to_address"},
			sortBy:     "value",
			aggregates: nil,
			wantErr:    false,
		},
		{
			name:       "valid logs fields",
			entity:     "logs",
			groupBy:    []string{"address", "topic_0"},
			sortBy:     "block_number",
			aggregates: nil,
			wantErr:    false,
		},
		{
			name:       "valid transfers fields",
			entity:     "transfers",
			groupBy:    []string{"token_address", "from_address"},
			sortBy:     "amount",
			aggregates: nil,
			wantErr:    false,
		},
		{
			name:       "valid balances fields",
			entity:     "balances",
			groupBy:    []string{"owner", "token_id"},
			sortBy:     "balance",
			aggregates: nil,
			wantErr:    false,
		},
		{
			// sort_by may reference an alias declared in aggregates rather
			// than a real column.
			name:       "valid with aggregate aliases",
			entity:     "transactions",
			groupBy:    []string{"from_address"},
			sortBy:     "total_value",
			aggregates: []string{"SUM(value) AS total_value", "COUNT(*) AS count"},
			wantErr:    false,
		},
		{
			name:       "invalid entity",
			entity:     "invalid_entity",
			groupBy:    []string{"field"},
			sortBy:     "field",
			aggregates: nil,
			wantErr:    true,
			errMsg:     "unknown entity: invalid_entity",
		},
		{
			name:       "invalid group_by field",
			entity:     "blocks",
			groupBy:    []string{"invalid_field"},
			sortBy:     "block_number",
			aggregates: nil,
			wantErr:    true,
			errMsg:     "invalid group_by field 'invalid_field' for entity 'blocks'",
		},
		{
			name:       "invalid sort_by field",
			entity:     "transactions",
			groupBy:    []string{"hash"},
			sortBy:     "invalid_field",
			aggregates: nil,
			wantErr:    true,
			errMsg:     "invalid sort_by field 'invalid_field' for entity 'transactions'",
		},
		{
			// An alias not declared in aggregates must be rejected even when
			// other aggregates are present.
			name:       "invalid aggregate alias",
			entity:     "logs",
			groupBy:    []string{"address"},
			sortBy:     "invalid_alias",
			aggregates: []string{"COUNT(*) AS count"},
			wantErr:    true,
			errMsg:     "invalid sort_by field 'invalid_alias' for entity 'logs'",
		},
		{
			name:       "empty sort_by is valid",
			entity:     "blocks",
			groupBy:    []string{"block_number"},
			sortBy:     "",
			aggregates: nil,
			wantErr:    false,
		},
		{
			name:       "empty group_by is valid",
			entity:     "transactions",
			groupBy:    []string{},
			sortBy:     "hash",
			aggregates: nil,
			wantErr:    false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := ValidateGroupByAndSortBy(tt.entity, tt.groupBy, tt.sortBy, tt.aggregates)

			if tt.wantErr {
				if err == nil {
					t.Errorf("ValidateGroupByAndSortBy() expected error but got none")
					return
				}
				if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) {
					t.Errorf("ValidateGroupByAndSortBy() error = %v, want error containing %v", err, tt.errMsg)
				}
			} else {
				if err != nil {
					t.Errorf("ValidateGroupByAndSortBy() unexpected error = %v", err)
				}
			}
		})
	}
}

// TestExtractAggregateAliases covers alias extraction: simple aliases,
// case-insensitive AS keywords, expressions without aliases, and complex
// nested expressions. The comparison checks length plus element-by-element
// equality, which deliberately treats a nil result and the declared
// []string{} expectations as equivalent (the function returns nil when no
// aliases are found).
func TestExtractAggregateAliases(t *testing.T) {
	tests := []struct {
		name       string
		aggregates []string
		want       []string
	}{
		{
			name:       "simple aliases",
			aggregates: []string{"COUNT(*) AS count", "SUM(value) AS total_value"},
			want:       []string{"count", "total_value"},
		},
		{
			// "as" and "As" must match because the regex is case-insensitive.
			name:       "case insensitive AS",
			aggregates: []string{"AVG(amount) as avg_amount", "MAX(price) As max_price"},
			want:       []string{"avg_amount", "max_price"},
		},
		{
			name:       "no aliases",
			aggregates: []string{"COUNT(*)", "SUM(value)"},
			want:       []string{},
		},
		{
			// Expressions without AS clauses are silently skipped.
			name:       "mixed with and without aliases",
			aggregates: []string{"COUNT(*) AS count", "SUM(value)", "AVG(price) as avg_price"},
			want:       []string{"count", "avg_price"},
		},
		{
			name:       "empty aggregates",
			aggregates: []string{},
			want:       []string{},
		},
		{
			// Aliases are captured even when the expression itself contains
			// keywords such as DISTINCT or CASE/WHEN.
			name:       "complex aliases",
			aggregates: []string{"COUNT(DISTINCT address) AS unique_addresses", "SUM(CASE WHEN value > 0 THEN 1 ELSE 0 END) AS positive_transactions"},
			want:       []string{"unique_addresses", "positive_transactions"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := extractAggregateAliases(tt.aggregates)
			if len(got) != len(tt.want) {
				t.Errorf("extractAggregateAliases() length = %v, want %v", len(got), len(tt.want))
				return
			}
			for i, alias := range got {
				if alias != tt.want[i] {
					t.Errorf("extractAggregateAliases()[%d] = %v, want %v", i, alias, tt.want[i])
				}
			}
		})
	}
}
6 changes: 6 additions & 0 deletions internal/handlers/blocks_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ func handleBlocksRequest(c *gin.Context) {
return
}

// Validate GroupBy and SortBy fields
if err := api.ValidateGroupByAndSortBy("blocks", queryParams.GroupBy, queryParams.SortBy, queryParams.Aggregates); err != nil {
api.BadRequestErrorHandler(c, err)
return
}

mainStorage, err := getMainStorage()
if err != nil {
log.Error().Err(err).Msg("Error getting main storage")
Expand Down
6 changes: 6 additions & 0 deletions internal/handlers/logs_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ func handleLogsRequest(c *gin.Context) {
return
}

// Validate GroupBy and SortBy fields
if err := api.ValidateGroupByAndSortBy("logs", queryParams.GroupBy, queryParams.SortBy, queryParams.Aggregates); err != nil {
api.BadRequestErrorHandler(c, err)
return
}

var eventABI *abi.Event
signatureHash := ""
if signature != "" {
Expand Down
28 changes: 26 additions & 2 deletions internal/handlers/token_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ func GetTokenIdsByType(c *gin.Context) {
// We only care about token_id and token_type
columns := []string{"token_id", "token_type"}
groupBy := []string{"token_id", "token_type"}
sortBy := c.Query("sort_by")

// Validate GroupBy and SortBy fields
if err := api.ValidateGroupByAndSortBy("balances", groupBy, sortBy, nil); err != nil {
api.BadRequestErrorHandler(c, err)
return
}

tokenIds, err := getTokenIdsFromReq(c)
if err != nil {
Expand All @@ -100,7 +107,7 @@ func GetTokenIdsByType(c *gin.Context) {
ZeroBalance: hideZeroBalances,
TokenIds: tokenIds,
GroupBy: groupBy,
SortBy: c.Query("sort_by"),
SortBy: sortBy,
SortOrder: c.Query("sort_order"),
Page: api.ParseIntQueryParam(c.Query("page"), 0),
Limit: api.ParseIntQueryParam(c.Query("limit"), 0),
Expand Down Expand Up @@ -189,6 +196,14 @@ func GetTokenBalancesByType(c *gin.Context) {
groupBy = []string{"address", "token_id", "token_type"}
}

sortBy := c.Query("sort_by")

// Validate GroupBy and SortBy fields
if err := api.ValidateGroupByAndSortBy("balances", groupBy, sortBy, nil); err != nil {
api.BadRequestErrorHandler(c, err)
return
}

qf := storage.BalancesQueryFilter{
ChainId: chainId,
Owner: owner,
Expand All @@ -197,7 +212,7 @@ func GetTokenBalancesByType(c *gin.Context) {
ZeroBalance: hideZeroBalances,
TokenIds: tokenIds,
GroupBy: groupBy,
SortBy: c.Query("sort_by"),
SortBy: sortBy,
SortOrder: c.Query("sort_order"),
Page: api.ParseIntQueryParam(c.Query("page"), 0),
Limit: api.ParseIntQueryParam(c.Query("limit"), 0),
Expand Down Expand Up @@ -280,6 +295,15 @@ func GetTokenHoldersByType(c *gin.Context) {
api.BadRequestErrorHandler(c, fmt.Errorf("invalid token ids '%s'", err))
return
}

sortBy := c.Query("sort_by")

// Validate GroupBy and SortBy fields
if err := api.ValidateGroupByAndSortBy("balances", groupBy, sortBy, nil); err != nil {
api.BadRequestErrorHandler(c, err)
return
}

qf := storage.BalancesQueryFilter{
ChainId: chainId,
TokenTypes: tokenTypes,
Expand Down
Loading