From 5d12f07ec38c470b57c00125699cac33b562c8d3 Mon Sep 17 00:00:00 2001 From: aman035 Date: Wed, 18 Mar 2026 16:17:14 +0530 Subject: [PATCH 01/28] refactor: api package --- universalClient/api/handlers.go | 5 ++- universalClient/api/handlers_test.go | 1 + universalClient/api/server.go | 49 ++++++++++++-------------- universalClient/api/server_test.go | 51 +++++----------------------- 4 files changed, 35 insertions(+), 71 deletions(-) diff --git a/universalClient/api/handlers.go b/universalClient/api/handlers.go index 7a594e62..4f2cc0fa 100644 --- a/universalClient/api/handlers.go +++ b/universalClient/api/handlers.go @@ -4,6 +4,9 @@ import "net/http" // handleHealth handles GET /health func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte("OK")) + if _, err := w.Write([]byte("OK")); err != nil { + s.logger.Error().Err(err).Msg("Failed to write health response") + } } diff --git a/universalClient/api/handlers_test.go b/universalClient/api/handlers_test.go index 197c0cb0..63ee4603 100644 --- a/universalClient/api/handlers_test.go +++ b/universalClient/api/handlers_test.go @@ -23,5 +23,6 @@ func TestHandleHealth(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "OK", w.Body.String()) + assert.Equal(t, "text/plain", w.Header().Get("Content-Type")) }) } diff --git a/universalClient/api/server.go b/universalClient/api/server.go index a98a496c..7abe2723 100644 --- a/universalClient/api/server.go +++ b/universalClient/api/server.go @@ -1,6 +1,7 @@ package api import ( + "context" "fmt" "net" "net/http" @@ -11,8 +12,9 @@ import ( // Server provides HTTP endpoints type Server struct { - logger zerolog.Logger - server *http.Server + logger zerolog.Logger + server *http.Server + listener net.Listener } // NewServer creates a new Server instance @@ -37,24 +39,14 @@ func (s *Server) Start() error { return fmt.Errorf("query server is nil") } - // Channel to signal server startup result - startupChan := make(chan error, 1) + ln, err := net.Listen("tcp", s.server.Addr) + if err != nil { + return fmt.Errorf("failed to bind to address %s: %w", s.server.Addr, err) + } + s.listener = ln - // Start server in goroutine go func() { - // Create a test listener to verify the port is available - ln, err := net.Listen("tcp", s.server.Addr) - if err != nil { - startupChan <- fmt.Errorf("failed to bind to address %s: %w", s.server.Addr, err) - return - } - ln.Close() - - // Signal successful startup check - startupChan <- nil - - // Now start the actual server - err = s.server.ListenAndServe() + err := s.server.Serve(ln) switch err { case nil: s.logger.Info().Msg("Query server stopped normally") @@ -65,22 +57,23 @@ func (s *Server) Start() error { } }() - // Wait for startup result with timeout - select { - case err := <-startupChan: - if err != nil { - return err - } - return nil - case <-time.After(5 * time.Second): - return fmt.Errorf("server startup timeout") + return nil +} + +// Addr returns the listener address, useful when started on port 0 +func (s *Server) Addr() string { + if s.listener != nil { + return s.listener.Addr().String() } + return "" } // Stop gracefully shuts down the HTTP server func (s *Server) Stop() error { if s.server != nil { - return s.server.Close() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return s.server.Shutdown(ctx) } return nil } diff --git a/universalClient/api/server_test.go b/universalClient/api/server_test.go index 65c44b0b..09eeef1b 100644 --- a/universalClient/api/server_test.go +++ b/universalClient/api/server_test.go @@ -1,10 +1,9 @@ package api import ( + "fmt" "net/http" - "net/http/httptest" "testing" - "time" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" @@ -34,19 +33,13 @@ func TestServerStartStop(t *testing.T) { logger := zerolog.New(zerolog.NewTestWriter(t)) t.Run("Start and stop server", func(t *testing.T) { - // Use a random port to avoid conflicts server := NewServer(logger, 0) - // Start server err := server.Start() require.NoError(t, err) + defer server.Stop() - // Give server time to start - time.Sleep(200 * time.Millisecond) - - // Stop server - err = server.Stop() - assert.NoError(t, err) + assert.NotEmpty(t, server.Addr()) }) t.Run("Start with nil server", func(t *testing.T) { @@ -73,43 +66,17 @@ func TestServerIntegration(t *testing.T) { logger := zerolog.New(zerolog.NewTestWriter(t)) t.Run("Server lifecycle with HTTP client", func(t *testing.T) { - // Create server on a specific port - server := NewServer(logger, 18080) + server := NewServer(logger, 0) - // Start server err := server.Start() require.NoError(t, err) defer server.Stop() - // Wait for server to be ready - time.Sleep(200 * time.Millisecond) - - // Test health endpoint - resp, err := http.Get("http://localhost:18080/health") - if err == nil { - defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) - } - }) -} - -// Test handler functions directly using httptest -func TestHealthHandler(t *testing.T) { - logger := zerolog.New(zerolog.NewTestWriter(t)) - server := &Server{ - logger: logger, - } - - handler := http.HandlerFunc(server.handleHealth) - - t.Run("Health check returns OK", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/health", nil) - w := httptest.NewRecorder() - - handler(w, req) + resp, err := http.Get(fmt.Sprintf("http://%s/health", server.Addr())) + require.NoError(t, err) + defer resp.Body.Close() - assert.Equal(t, http.StatusOK, w.Code) - // handleHealth just returns "OK" as plain text - assert.Equal(t, "OK", w.Body.String()) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "text/plain", resp.Header.Get("Content-Type")) }) } From 93ae02fbd7a3a505552a906b4caf3c81b9294a43 Mon Sep 17 00:00:00 2001 From: aman035 Date: Wed, 18 Mar 2026 16:20:51 +0530 Subject: [PATCH 02/28] handle worng level passing --- universalClient/logger/logger.go | 7 ++++++- universalClient/logger/logger_test.go | 23 +++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/universalClient/logger/logger.go b/universalClient/logger/logger.go index ca97defe..827b0f81 100644 --- a/universalClient/logger/logger.go +++ b/universalClient/logger/logger.go @@ -11,6 +11,11 @@ import ( // New creates a new zerolog logger with the specified configuration. // Supports console/json format, level filtering, and optional sampling. func New(logLevel int, logFormat string, logSampler bool) zerolog.Logger { + level := zerolog.Level(logLevel) + if level < zerolog.TraceLevel || level > zerolog.Disabled { + level = zerolog.InfoLevel + } + var writer io.Writer = os.Stdout if logFormat != "json" { writer = zerolog.ConsoleWriter{ @@ -20,7 +25,7 @@ func New(logLevel int, logFormat string, logSampler bool) zerolog.Logger { } logger := zerolog.New(writer). - Level(zerolog.Level(logLevel)). + Level(level). With(). Timestamp(). Logger() diff --git a/universalClient/logger/logger_test.go b/universalClient/logger/logger_test.go index 6b1fdae5..7f1480fc 100644 --- a/universalClient/logger/logger_test.go +++ b/universalClient/logger/logger_test.go @@ -53,6 +53,29 @@ func TestNewVariants(t *testing.T) { require.Contains(t, logOutput, "env=test") }) + t.Run("invalid log level falls back to info", func(t *testing.T) { + r, w, _ := os.Pipe() + defer r.Close() + + stdout := os.Stdout + os.Stdout = w + defer func() { os.Stdout = stdout }() + + logger := New(99, "json", false) + + // Debug should be filtered out at info level + logger.Debug().Msg("should_not_appear") + logger.Info().Msg("should_appear") + + _ = w.Close() + buf := make([]byte, 1024) + n, _ := r.Read(buf) + + logOutput := string(buf[:n]) + require.NotContains(t, logOutput, "should_not_appear") + require.Contains(t, logOutput, "should_appear") + }) + t.Run("sampler reduces output frequency", func(t *testing.T) { r, w, _ := os.Pipe() defer r.Close() From 6896c0b3ce0499e9c90df8d9dd5a7c17d82fdfce Mon Sep 17 00:00:00 2001 From: aman035 Date: Wed, 18 Mar 2026 16:23:55 +0530 Subject: [PATCH 03/28] refactor: remove deprecated package & dead code --- universalClient/db/db.go | 57 +++++++++++++---------------------- universalClient/db/db_test.go | 44 +++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 39 deletions(-) diff --git a/universalClient/db/db.go b/universalClient/db/db.go index aa598fee..a2e5d83b 100644 --- a/universalClient/db/db.go +++ b/universalClient/db/db.go @@ -6,9 +6,9 @@ package db import ( "fmt" "os" + "path/filepath" "strings" - "github.com/pkg/errors" "github.com/pushchain/push-chain-node/universalClient/store" "gorm.io/driver/sqlite" "gorm.io/gorm" @@ -23,19 +23,18 @@ const ( dbDirPermissions = 0o750 ) -var ( - // gormConfig disables logging output for cleaner usage in validator processes. - gormConfig = &gorm.Config{ +func newGormConfig() *gorm.Config { + return &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), } +} - // schemaModels lists the structs to be auto-migrated into the database. - schemaModels = []any{ +func schemaModels() []any { + return []any{ &store.State{}, &store.Event{}, - // Add additional models here as needed. } -) +} // DB wraps a GORM client and provides simplified DB lifecycle management. type DB struct { @@ -47,7 +46,7 @@ type DB struct { func OpenFileDB(dir, filename string, migrateSchema bool) (*DB, error) { dsn, err := prepareFilePath(dir, filename) if err != nil { - return nil, errors.Wrap(err, "failed to prepare database path") + return nil, fmt.Errorf("failed to prepare database path: %w", err) } return openSQLite(dsn, migrateSchema) } @@ -67,21 +66,21 @@ func openSQLite(dsn string, migrateSchema bool) (*DB, error) { dsn += "?_journal_mode=WAL&_busy_timeout=5000&cache=shared&mode=rwc" } - db, err := gorm.Open(sqlite.Open(dsn), gormConfig) + db, err := gorm.Open(sqlite.Open(dsn), newGormConfig()) if err != nil { - return nil, errors.Wrap(err, "failed to open SQLite database") + return nil, fmt.Errorf("failed to open SQLite database: %w", err) } if migrateSchema { - if err := db.AutoMigrate(schemaModels...); err != nil { - return nil, errors.Wrap(err, "failed to auto-migrate database schema") + if err := db.AutoMigrate(schemaModels()...); err != nil { + return nil, fmt.Errorf("failed to auto-migrate database schema: %w", err) } } // Configure connection pool for better concurrent access sqlDB, err := db.DB() if err != nil { - return nil, errors.Wrap(err, "failed to get underlying sql.DB") + return nil, fmt.Errorf("failed to get underlying sql.DB: %w", err) } // Configure connection pool based on database type @@ -98,11 +97,9 @@ func openSQLite(dsn string, migrateSchema bool) (*DB, error) { // Set maximum lifetime of a connection sqlDB.SetConnMaxLifetime(0) // Connections don't expire - // Apply SQLite performance optimizations + // Apply SQLite performance optimizations for file-based databases if err := optimizeSQLiteSettings(db, dsn); err != nil { - // Log warning but don't fail startup - these are performance optimizations - // The database will still work with defaults - fmt.Printf("Warning: Failed to apply SQLite optimizations: %v\n", err) + return nil, fmt.Errorf("failed to apply SQLite optimizations: %w", err) } return &DB{client: db}, nil @@ -125,7 +122,7 @@ func optimizeSQLiteSettings(db *gorm.DB, dsn string) error { for _, pragma := range pragmas { if err := db.Exec(pragma).Error; err != nil { - return errors.Wrapf(err, "failed to execute %s", pragma) + return fmt.Errorf("failed to execute %s: %w", pragma, err) } } @@ -137,41 +134,29 @@ func (d *DB) Client() *gorm.DB { return d.client } -// SetupDBForTesting sets the internal GORM client for testing purposes. -// This should only be used in test files. -func (d *DB) SetupDBForTesting(client *gorm.DB) { - d.client = client -} - // Close safely closes the underlying database connection. func (d *DB) Close() error { sqlDB, err := d.client.DB() if err != nil { - return errors.Wrap(err, "failed to retrieve native sql.DB") + return fmt.Errorf("failed to retrieve native sql.DB: %w", err) } if err := sqlDB.Close(); err != nil { - return errors.Wrap(err, "failed to close database connection") + return fmt.Errorf("failed to close database connection: %w", err) } return nil } // prepareFilePath ensures the target directory exists and returns the full database file path. -// If the directory contains the in-memory DSN string, it is returned as-is. func prepareFilePath(dir, filename string) (string, error) { - if strings.Contains(dir, InMemorySQLiteDSN) { - return dir, nil - } - - // Ensure the directory exists if _, err := os.Stat(dir); os.IsNotExist(err) { if err := os.MkdirAll(dir, dbDirPermissions); err != nil { - return "", errors.Wrapf(err, "failed to create directory: %s", dir) + return "", fmt.Errorf("failed to create directory %s: %w", dir, err) } } else if err != nil { - return "", errors.Wrap(err, "error checking directory") + return "", fmt.Errorf("error checking directory: %w", err) } - return fmt.Sprintf("%s/%s", dir, filename), nil + return filepath.Join(dir, filename), nil } diff --git a/universalClient/db/db_test.go b/universalClient/db/db_test.go index 3e1f46ba..6e0561f6 100644 --- a/universalClient/db/db_test.go +++ b/universalClient/db/db_test.go @@ -41,10 +41,18 @@ func TestDB_OpenModes(t *testing.T) { runSampleInsertSelectTest(t, db) assert.NoError(t, db.Close()) + }) + + t.Run("file-based DB creates directory", func(t *testing.T) { + dir := filepath.Join(t.TempDir(), "nested", "dir") + dbName := "test.db" + + db, err := OpenFileDB(dir, dbName, true) + require.NoError(t, err) + require.NotNil(t, db) - t.Run("close twice", func(t *testing.T) { - assert.NoError(t, db.Close()) - }) + assert.FileExists(t, filepath.Join(dir, dbName)) + assert.NoError(t, db.Close()) }) t.Run("invalid path fails", func(t *testing.T) { @@ -54,6 +62,36 @@ func TestDB_OpenModes(t *testing.T) { }) } +func TestDB_PragmaOptimizations(t *testing.T) { + dir := t.TempDir() + db, err := OpenFileDB(dir, "pragma_test.db", true) + require.NoError(t, err) + defer db.Close() + + // Verify WAL mode is active + var journalMode string + err = db.Client().Raw("PRAGMA journal_mode").Scan(&journalMode).Error + require.NoError(t, err) + assert.Equal(t, "wal", journalMode) + + // Verify synchronous is NORMAL (1) + var syncMode int + err = db.Client().Raw("PRAGMA synchronous").Scan(&syncMode).Error + require.NoError(t, err) + assert.Equal(t, 1, syncMode) + + // Verify foreign keys are enabled + var fkEnabled int + err = db.Client().Raw("PRAGMA foreign_keys").Scan(&fkEnabled).Error + require.NoError(t, err) + assert.Equal(t, 1, fkEnabled) +} + +func TestDB_SchemaModels(t *testing.T) { + models := schemaModels() + assert.Len(t, models, 2) +} + func runSampleInsertSelectTest(t *testing.T, db *DB) { // Given a sample row entry := store.State{ From 0d62f80b4b477e79671ecdef95c630095363eef4 Mon Sep 17 00:00:00 2001 From: aman035 Date: Wed, 18 Mar 2026 16:36:40 +0530 Subject: [PATCH 04/28] removed unused fn, remove multiple initialization of authz accounts --- universalClient/pushcore/pushCore.go | 251 ++++------------------ universalClient/pushcore/pushCore_test.go | 93 ++++++-- 2 files changed, 114 insertions(+), 230 deletions(-) diff --git a/universalClient/pushcore/pushCore.go b/universalClient/pushcore/pushCore.go index 6721cef9..8554442c 100644 --- a/universalClient/pushcore/pushCore.go +++ b/universalClient/pushcore/pushCore.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "math/big" - "net/url" "strings" "sync/atomic" @@ -37,6 +36,8 @@ type Client struct { uexecutorClients []uexecutortypes.QueryClient // Executor query clients (for gas price queries) cmtClients []cmtservice.ServiceClient // CometBFT service clients txClients []tx.ServiceClient // Transaction service clients + authzClients []authz.QueryClient // AuthZ query clients + authClients []authtypes.QueryClient // Auth query clients conns []*grpc.ClientConn // Owned gRPC connections (for cleanup) rr uint32 // Round-robin counter for endpoint selection } @@ -51,14 +52,6 @@ type TxResult struct { // New creates a new Client by dialing the provided gRPC URLs. // It attempts to connect to all endpoints and skips any that fail to dial. // At least one endpoint must succeed, otherwise an error is returned. -// -// Parameters: -// - urls: List of gRPC endpoint URLs (schemes are automatically detected) -// - logger: Logger instance for client operations -// -// Returns: -// - *Client: A configured client instance, or nil on error -// - error: Error if all endpoints fail to connect func New(urls []string, logger zerolog.Logger) (*Client, error) { if len(urls) == 0 { return nil, errors.New("pushcore: at least one gRPC URL is required") @@ -69,8 +62,7 @@ func New(urls []string, logger zerolog.Logger) (*Client, error) { } for i, u := range urls { - // Use the local utility function - conn, err := CreateGRPCConnection(u) + conn, err := createGRPCConnection(u) if err != nil { c.logger.Warn().Str("url", u).Int("index", i).Err(err).Msg("dial failed; skipping endpoint") continue @@ -82,10 +74,11 @@ func New(urls []string, logger zerolog.Logger) (*Client, error) { c.uexecutorClients = append(c.uexecutorClients, uexecutortypes.NewQueryClient(conn)) c.cmtClients = append(c.cmtClients, cmtservice.NewServiceClient(conn)) c.txClients = append(c.txClients, tx.NewServiceClient(conn)) + c.authzClients = append(c.authzClients, authz.NewQueryClient(conn)) + c.authClients = append(c.authClients, authtypes.NewQueryClient(conn)) } if len(c.eps) == 0 { - // nothing usable _ = c.Close() return nil, fmt.Errorf("pushcore: all dials failed (%d urls)", len(urls)) } @@ -109,22 +102,13 @@ func (c *Client) Close() error { c.uexecutorClients = nil c.cmtClients = nil c.txClients = nil + c.authzClients = nil + c.authClients = nil return firstErr } // retryWithRoundRobin executes a function across multiple endpoints in round-robin order. // It tries each endpoint until one succeeds or all fail. -// -// Parameters: -// - numClients: Number of client endpoints available -// - rrCounter: Pointer to round-robin counter (atomic) -// - operation: Function to execute for each attempt, receives the endpoint index -// - operationName: Name of the operation (for logging and error messages) -// - logger: Logger for debug messages -// -// Returns: -// - T: Result from the operation if successful -// - error: Error if all endpoints fail func retryWithRoundRobin[T any]( numClients int, rrCounter *uint32, @@ -160,14 +144,6 @@ func retryWithRoundRobin[T any]( } // GetAllChainConfigs retrieves all chain configurations from Push Chain. -// It tries each endpoint in round-robin order until one succeeds. -// -// Parameters: -// - ctx: Context for the request -// -// Returns: -// - []*uregistrytypes.ChainConfig: List of chain configurations -// - error: Error if all endpoints fail func (c *Client) GetAllChainConfigs(ctx context.Context) ([]*uregistrytypes.ChainConfig, error) { return retryWithRoundRobin( len(c.eps), @@ -185,14 +161,6 @@ func (c *Client) GetAllChainConfigs(ctx context.Context) ([]*uregistrytypes.Chai } // GetLatestBlock retrieves the latest block from Push Chain. -// It tries each endpoint in round-robin order until one succeeds. -// -// Parameters: -// - ctx: Context for the request -// -// Returns: -// - uint64: Latest block height -// - error: Error if all endpoints fail func (c *Client) GetLatestBlock(ctx context.Context) (uint64, error) { return retryWithRoundRobin( len(c.cmtClients), @@ -213,14 +181,6 @@ func (c *Client) GetLatestBlock(ctx context.Context) (uint64, error) { } // GetAllUniversalValidators retrieves all universal validators from Push Chain. -// It tries each endpoint in round-robin order until one succeeds. -// -// Parameters: -// - ctx: Context for the request -// -// Returns: -// - []*uvalidatortypes.UniversalValidator: List of universal validators -// - error: Error if all endpoints fail func (c *Client) GetAllUniversalValidators(ctx context.Context) ([]*uvalidatortypes.UniversalValidator, error) { return retryWithRoundRobin( len(c.uvalidatorClients), @@ -238,14 +198,6 @@ func (c *Client) GetAllUniversalValidators(ctx context.Context) ([]*uvalidatorty } // GetCurrentKey retrieves the current TSS key from Push Chain. -// It tries each endpoint in round-robin order until one succeeds. -// -// Parameters: -// - ctx: Context for the request -// -// Returns: -// - *utsstypes.TssKey: TSS key -// - error: Error if all endpoints fail or no key exists func (c *Client) GetCurrentKey(ctx context.Context) (*utsstypes.TssKey, error) { return retryWithRoundRobin( len(c.utssClients), @@ -267,19 +219,7 @@ func (c *Client) GetCurrentKey(ctx context.Context) (*utsstypes.TssKey, error) { // GetTxsByEvents queries transactions matching the given event query. // The query should follow Cosmos SDK event query format, e.g., "tss_process_initiated.process_id EXISTS". -// -// Parameters: -// - ctx: Context for the request -// - eventQuery: Cosmos SDK event query string -// - minHeight: Minimum block height to search (0 means no minimum) -// - maxHeight: Maximum block height to search (0 means no maximum) -// - limit: Maximum number of results to return (0 defaults to 100) -// -// Returns: -// - []*TxResult: List of matching transaction results -// - error: Error if all endpoints fail func (c *Client) GetTxsByEvents(ctx context.Context, eventQuery string, minHeight, maxHeight uint64, limit uint64) ([]*TxResult, error) { - // Build the query events (same for all attempts) events := []string{eventQuery} if minHeight > 0 { events = append(events, fmt.Sprintf("tx.height>=%d", minHeight)) @@ -288,13 +228,11 @@ func (c *Client) GetTxsByEvents(ctx context.Context, eventQuery string, minHeigh events = append(events, fmt.Sprintf("tx.height<=%d", maxHeight)) } - // Set pagination limit pageLimit := limit if pageLimit == 0 { - pageLimit = 100 // default limit + pageLimit = 100 } - // Join events with AND to create query string (SDK v0.50+ uses Query field) queryString := strings.Join(events, " AND ") return retryWithRoundRobin( @@ -314,13 +252,17 @@ func (c *Client) GetTxsByEvents(ctx context.Context, eventQuery string, minHeigh return nil, err } + if len(resp.Txs) != len(resp.TxResponses) { + return nil, fmt.Errorf("pushcore: mismatched Txs (%d) and TxResponses (%d) lengths", len(resp.Txs), len(resp.TxResponses)) + } + results := make([]*TxResult, 0, len(resp.TxResponses)) - for _, txResp := range resp.TxResponses { + for i, txResp := range resp.TxResponses { results = append(results, &TxResult{ TxHash: txResp.TxHash, Height: txResp.Height, TxResponse: &tx.GetTxResponse{ - Tx: resp.Txs[len(results)], + Tx: resp.Txs[i], TxResponse: txResp, }, }) @@ -333,15 +275,6 @@ func (c *Client) GetTxsByEvents(ctx context.Context, eventQuery string, minHeigh } // GetGasPrice retrieves the median gas price for a specific chain from the on-chain oracle. -// The gas price is voted on by universal validators and stored on-chain. -// -// Parameters: -// - ctx: Context for the request -// - chainID: Chain identifier in CAIP-2 format (e.g., "eip155:84532" for Base Sepolia) -// -// Returns: -// - *big.Int: Median gas price in the chain's native unit (Wei for EVM chains, lamports for Solana) -// - error: Error if all endpoints fail or chainID is invalid func (c *Client) GetGasPrice(ctx context.Context, chainID string) (*big.Int, error) { if chainID == "" { return nil, errors.New("pushcore: chainID is required") @@ -361,14 +294,12 @@ func (c *Client) GetGasPrice(ctx context.Context, chainID string) (*big.Int, err return nil, errors.New("pushcore: GasPrice response is nil") } - // Get the median price using MedianIndex if len(resp.GasPrice.Prices) == 0 { return nil, fmt.Errorf("pushcore: no gas prices available for chain %s", chainID) } medianIdx := resp.GasPrice.MedianIndex if medianIdx >= uint64(len(resp.GasPrice.Prices)) { - // Fallback to first price if median index is out of bounds medianIdx = 0 } @@ -381,27 +312,12 @@ func (c *Client) GetGasPrice(ctx context.Context, chainID string) (*big.Int, err } // GetGranteeGrants queries AuthZ grants for a grantee using round-robin logic. -// This function only queries and returns raw grant data; it does not perform validation or processing. -// -// Parameters: -// - ctx: Context for the request -// - granteeAddr: Address of the grantee to query grants for -// -// Returns: -// - *authz.QueryGranteeGrantsResponse: Raw grant response from the chain -// - error: Error if all endpoints fail func (c *Client) GetGranteeGrants(ctx context.Context, granteeAddr string) (*authz.QueryGranteeGrantsResponse, error) { - // Create authz clients from existing connections - authzClients := make([]authz.QueryClient, len(c.conns)) - for i, conn := range c.conns { - authzClients[i] = authz.NewQueryClient(conn) - } - return retryWithRoundRobin( - len(authzClients), + len(c.authzClients), &c.rr, func(idx int) (*authz.QueryGranteeGrantsResponse, error) { - return authzClients[idx].GranteeGrants(ctx, &authz.QueryGranteeGrantsRequest{ + return c.authzClients[idx].GranteeGrants(ctx, &authz.QueryGranteeGrantsRequest{ Grantee: granteeAddr, }) }, @@ -411,27 +327,12 @@ func (c *Client) GetGranteeGrants(ctx context.Context, granteeAddr string) (*aut } // GetAccount retrieves account information for a given address. -// It tries each endpoint in round-robin order until one succeeds. -// -// Parameters: -// - ctx: Context for the request -// - address: Bech32 address of the account -// -// Returns: -// - *authtypes.QueryAccountResponse: Account response -// - error: Error if all endpoints fail func (c *Client) GetAccount(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { - // Create auth clients from existing connections - authClients := make([]authtypes.QueryClient, len(c.conns)) - for i, conn := range c.conns { - authClients[i] = authtypes.NewQueryClient(conn) - } - return retryWithRoundRobin( - len(authClients), + len(c.authClients), &c.rr, func(idx int) (*authtypes.QueryAccountResponse, error) { - return authClients[idx].Account(ctx, &authtypes.QueryAccountRequest{ + return c.authClients[idx].Account(ctx, &authtypes.QueryAccountRequest{ Address: address, }) }, @@ -440,26 +341,30 @@ func (c *Client) GetAccount(ctx context.Context, address string) (*authtypes.Que ) } -// CreateGRPCConnection creates a gRPC connection with appropriate transport security. -// It automatically detects whether to use TLS based on the URL scheme. -// -// The function handles: -// - https:// URLs: Uses TLS with default credentials -// - http:// or no scheme: Uses insecure connection -// - Automatically adds default port 9090 if no port is specified -// -// Parameters: -// - endpoint: gRPC endpoint URL (scheme is optional, port defaults to 9090) -// -// Returns: -// - *grpc.ClientConn: gRPC client connection -// - error: Error if connection fails -func CreateGRPCConnection(endpoint string) (*grpc.ClientConn, error) { +// BroadcastTx broadcasts a signed transaction to the chain. +func (c *Client) BroadcastTx(ctx context.Context, txBytes []byte) (*tx.BroadcastTxResponse, error) { + return retryWithRoundRobin( + len(c.txClients), + &c.rr, + func(idx int) (*tx.BroadcastTxResponse, error) { + return c.txClients[idx].BroadcastTx(ctx, &tx.BroadcastTxRequest{ + TxBytes: txBytes, + Mode: tx.BroadcastMode_BROADCAST_MODE_SYNC, + }) + }, + "BroadcastTx", + c.logger, + ) +} + +// createGRPCConnection creates a gRPC connection with appropriate transport security. +// It automatically detects whether to use TLS based on the URL scheme +// and adds default port 9090 if no port is specified. +func createGRPCConnection(endpoint string) (*grpc.ClientConn, error) { if endpoint == "" { return nil, fmt.Errorf("empty endpoint provided") } - // Determine if we should use TLS and process the endpoint processedEndpoint := endpoint useTLS := false @@ -475,16 +380,13 @@ func CreateGRPCConnection(endpoint string) (*grpc.ClientConn, error) { if !strings.Contains(processedEndpoint, ":") { processedEndpoint = processedEndpoint + ":9090" } else { - // Check if the port is valid (i.e., after the last colon is a number) lastColon := strings.LastIndex(processedEndpoint, ":") afterColon := processedEndpoint[lastColon+1:] if afterColon == "" || strings.Contains(afterColon, "/") { - // No valid port, add default processedEndpoint = strings.TrimSuffix(processedEndpoint, ":") + ":9090" } } - // Configure connection options var opts []grpc.DialOption if useTLS { opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(nil))) @@ -492,7 +394,6 @@ func CreateGRPCConnection(endpoint string) (*grpc.ClientConn, error) { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } - // Create the connection conn, err := grpc.NewClient(processedEndpoint, opts...) if err != nil { return nil, fmt.Errorf("failed to create gRPC connection to %s: %w", processedEndpoint, err) @@ -500,79 +401,3 @@ func CreateGRPCConnection(endpoint string) (*grpc.ClientConn, error) { return conn, nil } - -// BroadcastTx broadcasts a signed transaction to the chain. -// It tries each endpoint in round-robin order until one succeeds. -// -// Parameters: -// - ctx: Context for the request -// - txBytes: Signed transaction bytes -// -// Returns: -// - *tx.BroadcastTxResponse: Broadcast response containing tx hash and result -// - error: Error if all endpoints fail -func (c *Client) BroadcastTx(ctx context.Context, txBytes []byte) (*tx.BroadcastTxResponse, error) { - return retryWithRoundRobin( - len(c.txClients), - &c.rr, - func(idx int) (*tx.BroadcastTxResponse, error) { - return c.txClients[idx].BroadcastTx(ctx, &tx.BroadcastTxRequest{ - TxBytes: txBytes, - Mode: tx.BroadcastMode_BROADCAST_MODE_SYNC, - }) - }, - "BroadcastTx", - c.logger, - ) -} - -// ExtractHostnameFromURL extracts the hostname from a URL string. -// It handles various URL formats including full URLs with scheme, URLs without scheme, and plain hostnames. -// -// Parameters: -// - grpcURL: URL string in any format (with or without scheme/port) -// -// Returns: -// - string: Hostname without port or scheme -// - error: Error if hostname cannot be extracted -func ExtractHostnameFromURL(grpcURL string) (string, error) { - if grpcURL == "" { - return "", fmt.Errorf("empty URL provided") - } - - // Try to parse as a standard URL - parsedURL, err := url.Parse(grpcURL) - if err == nil && parsedURL.Hostname() != "" { - // Successfully parsed and has a hostname - return parsedURL.Hostname(), nil - } - - // Fallback: Handle cases where url.Parse fails or returns empty hostname - // This handles plain hostnames like "example.com" or "example.com:9090" - hostname := grpcURL - - // Remove common schemes if present - if strings.HasPrefix(hostname, "https://") { - hostname = strings.TrimPrefix(hostname, "https://") - } else if strings.HasPrefix(hostname, "http://") { - hostname = strings.TrimPrefix(hostname, "http://") - } - - // Remove port if present (but check that there's something before the colon) - if colonIndex := strings.Index(hostname, ":"); colonIndex >= 0 { - if colonIndex == 0 { - // URL starts with ":", no hostname - return "", fmt.Errorf("could not extract hostname from URL: %s", grpcURL) - } - hostname = hostname[:colonIndex] - } - - // Remove any trailing slashes - hostname = strings.TrimSuffix(hostname, "/") - - if hostname == "" { - return "", fmt.Errorf("could not extract hostname from URL: %s", grpcURL) - } - - return hostname, nil -} diff --git a/universalClient/pushcore/pushCore_test.go b/universalClient/pushcore/pushCore_test.go index abff6232..015a7d7a 100644 --- a/universalClient/pushcore/pushCore_test.go +++ b/universalClient/pushcore/pushCore_test.go @@ -9,6 +9,7 @@ import ( sdktypes "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/tx" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + "github.com/cosmos/cosmos-sdk/x/authz" uexecutortypes "github.com/pushchain/push-chain-node/x/uexecutor/types" uregistrytypes "github.com/pushchain/push-chain-node/x/uregistry/types" utsstypes "github.com/pushchain/push-chain-node/x/utss/types" @@ -89,6 +90,9 @@ func TestNew(t *testing.T) { } else { require.NotNil(t, client) assert.NotNil(t, client.logger) + // Verify authz and auth clients are initialized + assert.Equal(t, len(client.conns), len(client.authzClients)) + assert.Equal(t, len(client.conns), len(client.authClients)) _ = client.Close() } } @@ -571,8 +575,8 @@ func TestClient_GetGranteeGrants(t *testing.T) { t.Run("no endpoints configured", func(t *testing.T) { client := &Client{ - logger: logger, - conns: []*grpc.ClientConn{}, + logger: logger, + authzClients: []authz.QueryClient{}, } grants, err := client.GetGranteeGrants(context.Background(), "cosmos1abc...") @@ -582,17 +586,41 @@ func TestClient_GetGranteeGrants(t *testing.T) { }) t.Run("successful query with mock", func(t *testing.T) { - // Note: This test requires actual gRPC connections, so we'll test the error case - // For a full mock test, we'd need to set up a gRPC server + mockClient := &mockAuthzQueryClient{ + granteeGrantsResp: &authz.QueryGranteeGrantsResponse{ + Grants: []*authz.GrantAuthorization{ + {Granter: "push1granter"}, + }, + }, + } + client := &Client{ - logger: logger, - conns: []*grpc.ClientConn{}, + logger: logger, + authzClients: []authz.QueryClient{mockClient}, } grants, err := client.GetGranteeGrants(context.Background(), "cosmos1abc...") - require.Error(t, err) - assert.Contains(t, err.Error(), "no endpoints configured") - assert.Nil(t, grants) + require.NoError(t, err) + require.Len(t, grants.Grants, 1) + assert.Equal(t, "push1granter", grants.Grants[0].Granter) + }) + + t.Run("round robin failover", func(t *testing.T) { + failingClient := &mockAuthzQueryClient{err: assert.AnError} + successClient := &mockAuthzQueryClient{ + granteeGrantsResp: &authz.QueryGranteeGrantsResponse{ + Grants: []*authz.GrantAuthorization{}, + }, + } + + client := &Client{ + logger: logger, + authzClients: []authz.QueryClient{failingClient, successClient}, + } + + grants, err := client.GetGranteeGrants(context.Background(), "cosmos1abc...") + require.NoError(t, err) + assert.NotNil(t, grants) }) } @@ -602,8 +630,8 @@ func TestClient_GetAccount(t *testing.T) { t.Run("no endpoints configured", func(t *testing.T) { client := &Client{ - logger: logger, - conns: []*grpc.ClientConn{}, + logger: logger, + authClients: []authtypes.QueryClient{}, } account, err := client.GetAccount(ctx, "cosmos1abc123") @@ -612,15 +640,33 @@ func TestClient_GetAccount(t *testing.T) { assert.Nil(t, account) }) - t.Run("empty address", func(t *testing.T) { + t.Run("successful query with mock", func(t *testing.T) { + mockClient := &mockAuthAccountQueryClient{ + accountResp: &authtypes.QueryAccountResponse{}, + } + + client := &Client{ + logger: logger, + authClients: []authtypes.QueryClient{mockClient}, + } + + account, err := client.GetAccount(ctx, "cosmos1abc123") + require.NoError(t, err) + assert.NotNil(t, account) + }) + + t.Run("all endpoints fail", func(t *testing.T) { client := &Client{ logger: logger, - conns: []*grpc.ClientConn{}, + authClients: []authtypes.QueryClient{ + &mockAuthAccountQueryClient{err: assert.AnError}, + &mockAuthAccountQueryClient{err: assert.AnError}, + }, } - account, err := client.GetAccount(ctx, "") + account, err := client.GetAccount(ctx, "cosmos1abc123") require.Error(t, err) - assert.Contains(t, err.Error(), "no endpoints configured") + assert.Contains(t, err.Error(), "failed on all 2 endpoints") assert.Nil(t, account) }) } @@ -745,13 +791,26 @@ func (m *mockUExecutorQueryClient) AllGasPrices(ctx context.Context, req *uexecu return nil, nil } -type mockAuthQueryClient struct { +type mockAuthzQueryClient struct { + authz.QueryClient + granteeGrantsResp *authz.QueryGranteeGrantsResponse + err error +} + +func (m *mockAuthzQueryClient) GranteeGrants(ctx context.Context, req *authz.QueryGranteeGrantsRequest, opts ...grpc.CallOption) (*authz.QueryGranteeGrantsResponse, error) { + if m.err != nil { + return nil, m.err + } + return m.granteeGrantsResp, nil +} + +type mockAuthAccountQueryClient struct { authtypes.QueryClient accountResp *authtypes.QueryAccountResponse err error } -func (m *mockAuthQueryClient) Account(ctx context.Context, req *authtypes.QueryAccountRequest, opts ...grpc.CallOption) (*authtypes.QueryAccountResponse, error) { +func (m *mockAuthAccountQueryClient) Account(ctx context.Context, req *authtypes.QueryAccountRequest, opts ...grpc.CallOption) (*authtypes.QueryAccountResponse, error) { if m.err != nil { return nil, m.err } From 95f882f05f8d138491864077447433568fb50855 Mon Sep 17 00:00:00 2001 From: aman035 Date: Wed, 18 Mar 2026 17:12:24 +0530 Subject: [PATCH 05/28] refactor: keys package --- universalClient/pushsigner/keys/interfaces.go | 32 --- universalClient/pushsigner/keys/keyring.go | 184 ------------ .../pushsigner/keys/keyring_test.go | 199 ------------- universalClient/pushsigner/keys/keys.go | 98 +++++-- universalClient/pushsigner/keys/keys_test.go | 268 ++++++++---------- 5 files changed, 189 insertions(+), 592 deletions(-) delete mode 100644 universalClient/pushsigner/keys/interfaces.go delete mode 100644 universalClient/pushsigner/keys/keyring.go delete mode 100644 universalClient/pushsigner/keys/keyring_test.go diff --git a/universalClient/pushsigner/keys/interfaces.go b/universalClient/pushsigner/keys/interfaces.go deleted file mode 100644 index eafad58d..00000000 --- a/universalClient/pushsigner/keys/interfaces.go +++ /dev/null @@ -1,32 +0,0 @@ -package keysv2 - -import ( - "github.com/cosmos/cosmos-sdk/crypto/keyring" - sdk "github.com/cosmos/cosmos-sdk/types" -) - -// KeyringBackend represents the type of keyring backend to use -type KeyringBackend string - -const ( - // KeyringBackendTest is the test Cosmos keyring backend (unencrypted) - KeyringBackendTest KeyringBackend = "test" - - // KeyringBackendFile is the file Cosmos keyring backend (encrypted) - KeyringBackendFile KeyringBackend = "file" -) - -// UniversalValidatorKeys defines the interface for key management in Universal Validator -type UniversalValidatorKeys interface { - // GetAddress returns the hot key address - GetAddress() (sdk.AccAddress, error) - - // GetKeyName returns the name of the hot key in the keyring - GetKeyName() string - - // GetKeyring returns the underlying keyring for signing operations. - // It validates that the key exists before returning the keyring. - // For file backend, decryption happens automatically when signing via tx.Sign(). - // This allows signing without exposing the private key. - GetKeyring() (keyring.Keyring, error) -} diff --git a/universalClient/pushsigner/keys/keyring.go b/universalClient/pushsigner/keys/keyring.go deleted file mode 100644 index e37faaf0..00000000 --- a/universalClient/pushsigner/keys/keyring.go +++ /dev/null @@ -1,184 +0,0 @@ -package keysv2 - -import ( - "fmt" - "io" - "os" - "strings" - - "github.com/cosmos/cosmos-sdk/codec" - codectypes "github.com/cosmos/cosmos-sdk/codec/types" - cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec" - "github.com/cosmos/cosmos-sdk/crypto/keyring" - "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" - "github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1" - cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" - sdk "github.com/cosmos/cosmos-sdk/types" - evmcrypto "github.com/cosmos/evm/crypto/ethsecp256k1" - evmhd "github.com/cosmos/evm/crypto/hd" - cosmosevmkeyring "github.com/cosmos/evm/crypto/keyring" - "github.com/rs/zerolog/log" - - "github.com/pushchain/push-chain-node/universalClient/config" -) - -// KeyringConfig holds configuration for keyring initialization -type KeyringConfig struct { - HomeDir string - KeyringBackend KeyringBackend - HotkeyName string - HotkeyPassword string -} - -// GetKeyringKeybase creates and returns keyring and key info -func GetKeyringKeybase(cfg KeyringConfig) (keyring.Keyring, string, error) { - logger := log.Logger.With().Str("module", "GetKeyringKeybase").Logger() - - if len(cfg.HotkeyName) == 0 { - return nil, "", fmt.Errorf("hotkey name is empty") - } - - if len(cfg.HomeDir) == 0 { - return nil, "", fmt.Errorf("home directory is empty") - } - - // Prepare password reader for file backend - var reader io.Reader = strings.NewReader("") - if cfg.KeyringBackend == KeyringBackendFile { - if cfg.HotkeyPassword == "" { - return nil, "", fmt.Errorf("password is required for file backend") - } - // Keyring expects password twice, each followed by newline - passwordInput := fmt.Sprintf("%s\n%s\n", cfg.HotkeyPassword, cfg.HotkeyPassword) - reader = strings.NewReader(passwordInput) - } - - kb, err := CreateKeyring(cfg.HomeDir, reader, cfg.KeyringBackend) - if err != nil { - return nil, "", fmt.Errorf("failed to get keybase: %w", err) - } - - // Temporarily disable stdin to avoid prompts - oldStdIn := os.Stdin - defer func() { - os.Stdin = oldStdIn - }() - os.Stdin = nil - - logger.Debug(). - Msgf("Checking for Hotkey: %s \nFolder: %s\nBackend: %s", - cfg.HotkeyName, cfg.HomeDir, kb.Backend()) - - rc, err := kb.Key(cfg.HotkeyName) - if err != nil { - return nil, "", fmt.Errorf("key not present in backend %s with name (%s): %w", - kb.Backend(), cfg.HotkeyName, err) - } - - // Get public key in bech32 format - pubkeyBech32, err := getPubkeyBech32FromRecord(rc) - if err != nil { - return nil, "", fmt.Errorf("failed to get pubkey from record: %w", err) - } - - return kb, pubkeyBech32, nil -} - -// CreateNewKey creates a new key in the keyring and returns the record and mnemonic. -// If mnemonic is provided, it imports the key; otherwise, it generates a new one. -// The returned mnemonic will be empty if importing from an existing mnemonic. -func CreateNewKey(kr keyring.Keyring, name string, mnemonic string, passphrase string) (*keyring.Record, string, error) { - if mnemonic != "" { - // Import from mnemonic using EVM algorithm - record, err := kr.NewAccount(name, mnemonic, passphrase, sdk.FullFundraiserPath, evmhd.EthSecp256k1) - return record, mnemonic, err - } - - // Generate new key with mnemonic using EVM algorithm - record, generatedMnemonic, err := kr.NewMnemonic(name, keyring.English, sdk.FullFundraiserPath, passphrase, evmhd.EthSecp256k1) - if err != nil { - return nil, "", fmt.Errorf("failed to generate new key with mnemonic: %w", err) - } - - return record, generatedMnemonic, nil -} - -// CreateInterfaceRegistryWithEVMSupport creates an interface registry with EVM-compatible key types -func CreateInterfaceRegistryWithEVMSupport() codectypes.InterfaceRegistry { - registry := codectypes.NewInterfaceRegistry() - cryptocodec.RegisterInterfaces(registry) - - // Register all key types (both public and private) - registry.RegisterImplementations((*cryptotypes.PubKey)(nil), - &secp256k1.PubKey{}, - &ed25519.PubKey{}, - &evmcrypto.PubKey{}, - ) - registry.RegisterImplementations((*cryptotypes.PrivKey)(nil), - &secp256k1.PrivKey{}, - &ed25519.PrivKey{}, - &evmcrypto.PrivKey{}, - ) - - return registry -} - -// CreateKeyring creates a keyring with EVM compatibility -func CreateKeyring(homeDir string, reader io.Reader, keyringBackend KeyringBackend) (keyring.Keyring, error) { - if len(homeDir) == 0 { - return nil, fmt.Errorf("home directory is empty") - } - - // Create codec with EVM-compatible key types directly - registry := CreateInterfaceRegistryWithEVMSupport() - cdc := codec.NewProtoCodec(registry) - - // Determine backend type - var backend string - switch keyringBackend { - case KeyringBackendFile: - backend = "file" - case KeyringBackendTest: - backend = "test" - default: - backend = "test" // Default to test backend - } - - // Create keyring with appropriate backend and EVM compatibility - return keyring.New(sdk.KeyringServiceName(), backend, homeDir, reader, cdc, cosmosevmkeyring.Option()) -} - -// CreateKeyringFromConfig creates a keyring with EVM compatibility from config backend type -func CreateKeyringFromConfig(homeDir string, reader io.Reader, configBackend config.KeyringBackend) (keyring.Keyring, error) { - // Convert config types to keys types - var keysBackend KeyringBackend - switch configBackend { - case config.KeyringBackendFile: - keysBackend = KeyringBackendFile - case config.KeyringBackendTest: - keysBackend = KeyringBackendTest - default: - keysBackend = KeyringBackendTest - } - - return CreateKeyring(homeDir, reader, keysBackend) -} - -// getPubkeyBech32FromRecord extracts bech32 public key from key record -func getPubkeyBech32FromRecord(record *keyring.Record) (string, error) { - pubkey, err := record.GetPubKey() - if err != nil { - return "", fmt.Errorf("failed to get public key: %w", err) - } - - // Return hex representation of the public key with prefix - return fmt.Sprintf("pushpub%x", pubkey.Bytes()), nil -} - -// ValidateKeyExists checks if a key exists in the keyring -func ValidateKeyExists(kr keyring.Keyring, keyName string) error { - if _, err := kr.Key(keyName); err != nil { - return fmt.Errorf("key %s not found: %w", keyName, err) - } - return nil -} diff --git a/universalClient/pushsigner/keys/keyring_test.go b/universalClient/pushsigner/keys/keyring_test.go deleted file mode 100644 index 8cc607f9..00000000 --- a/universalClient/pushsigner/keys/keyring_test.go +++ /dev/null @@ -1,199 +0,0 @@ -package keysv2 - -import ( - "os" - "testing" - - "github.com/cosmos/cosmos-sdk/crypto/keyring" - sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -// KeyringTestSuite tests keyring operations -type KeyringTestSuite struct { - suite.Suite - tempDir string - config KeyringConfig - kb keyring.Keyring -} - -func (suite *KeyringTestSuite) SetupTest() { - // Initialize SDK config safely - check if already sealed - sdkConfig := sdk.GetConfig() - func() { - defer func() { - // Config already sealed, that's fine - ignore panic - _ = recover() - }() - sdkConfig.SetBech32PrefixForAccount("push", "pushpub") - sdkConfig.SetBech32PrefixForValidator("pushvaloper", "pushvaloperpub") - sdkConfig.SetBech32PrefixForConsensusNode("pushvalcons", "pushvalconspub") - sdkConfig.Seal() - }() - - // Create temporary directory - var err error - suite.tempDir, err = os.MkdirTemp("", "keyring-test") - require.NoError(suite.T(), err) - - // Create keyring config - suite.config = KeyringConfig{ - HomeDir: suite.tempDir, - KeyringBackend: KeyringBackendTest, - HotkeyName: "test-key", - HotkeyPassword: "", - } - - // Create keyring with EVM compatibility using our standard CreateKeyring function - suite.kb, err = CreateKeyring(suite.tempDir, nil, KeyringBackendTest) - require.NoError(suite.T(), err, "keyring creation should succeed") - require.NotNil(suite.T(), suite.kb, "keyring should be initialized") -} - -func (suite *KeyringTestSuite) TearDownTest() { - if suite.tempDir != "" { - _ = os.RemoveAll(suite.tempDir) - } -} - -// TestGetKeyringKeybase tests keyring creation -func (suite *KeyringTestSuite) TestGetKeyringKeybase() { - kb, record, err := GetKeyringKeybase(suite.config) - - // Should fail because the key doesn't exist yet - assert.Error(suite.T(), err) - assert.Nil(suite.T(), kb) - assert.Equal(suite.T(), "", record) - assert.Contains(suite.T(), err.Error(), "not found") -} - -// TestGetKeyringKeybaseWithExistingKey tests keyring with existing key -func (suite *KeyringTestSuite) TestGetKeyringKeybaseWithExistingKey() { - // First create a key in the test keyring - _, _, err := CreateNewKey(suite.kb, "test-key", "", "") - require.NoError(suite.T(), err) - - kb, record, err := GetKeyringKeybase(suite.config) - - // Should succeed now with proper keyring setup - assert.NoError(suite.T(), err) - assert.NotNil(suite.T(), kb) - assert.NotEqual(suite.T(), "", record) -} - -// TestCreateNewKey tests key creation -func (suite *KeyringTestSuite) TestCreateNewKey() { - record, _, err := CreateNewKey(suite.kb, "new-test-key", "", "") - - require.NoError(suite.T(), err) - assert.NotNil(suite.T(), record) - assert.Equal(suite.T(), "new-test-key", record.Name) - - // Verify key was created - retrievedRecord, err := suite.kb.Key("new-test-key") - require.NoError(suite.T(), err) - assert.Equal(suite.T(), record.Name, retrievedRecord.Name) -} - -// TestCreateNewKeyWithMnemonic tests key creation with mnemonic -func (suite *KeyringTestSuite) TestCreateNewKeyWithMnemonic() { - mnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" - - record, returnedMnemonic, err := CreateNewKey(suite.kb, "mnemonic-key", mnemonic, "") - - require.NoError(suite.T(), err) - assert.NotNil(suite.T(), record) - assert.Equal(suite.T(), "mnemonic-key", record.Name) - assert.Equal(suite.T(), mnemonic, returnedMnemonic) // Should return the provided mnemonic - - // Verify key was created - retrievedRecord, err := suite.kb.Key("mnemonic-key") - require.NoError(suite.T(), err) - assert.Equal(suite.T(), record.Name, retrievedRecord.Name) -} - -// TestCreateNewKeyWithInvalidMnemonic tests key creation with invalid mnemonic -func (suite *KeyringTestSuite) TestCreateNewKeyWithInvalidMnemonic() { - invalidMnemonic := "invalid mnemonic words" - - _, _, err := CreateNewKey(suite.kb, "invalid-key", invalidMnemonic, "") - - assert.Error(suite.T(), err) - assert.Contains(suite.T(), err.Error(), "Invalid mnenomic") -} - -// TestGetKeybase tests keybase creation with different backends -func (suite *KeyringTestSuite) TestGetKeybase() { - // Test with test backend - kb, err := CreateKeyring(suite.tempDir, nil, KeyringBackendTest) - - require.NoError(suite.T(), err) - assert.NotNil(suite.T(), kb) - assert.Equal(suite.T(), keyring.BackendTest, kb.Backend()) -} - -// TestGetKeybaseWithFileBackend tests keybase with file backend -func (suite *KeyringTestSuite) TestGetKeybaseWithFileBackend() { - // Create a mock input reader for password (though it won't be called for test) - kb, err := CreateKeyring(suite.tempDir, nil, KeyringBackendFile) - - require.NoError(suite.T(), err) - assert.NotNil(suite.T(), kb) - assert.Equal(suite.T(), keyring.BackendFile, kb.Backend()) -} - -// TestValidateKeyExists tests key existence validation -func (suite *KeyringTestSuite) TestValidateKeyExists() { - // Create a key first - _, _, err := CreateNewKey(suite.kb, "validation-test", "", "") - require.NoError(suite.T(), err) - - // Test existing key - err = ValidateKeyExists(suite.kb, "validation-test") - assert.NoError(suite.T(), err) - - // Test non-existent key - err = ValidateKeyExists(suite.kb, "non-existent") - assert.Error(suite.T(), err) - assert.Contains(suite.T(), err.Error(), "not found") -} - -// TestGetPubkeyBech32FromRecord tests public key extraction -func (suite *KeyringTestSuite) TestGetPubkeyBech32FromRecord() { - // Create a key - record, _, err := CreateNewKey(suite.kb, "pubkey-test", "", "") - require.NoError(suite.T(), err) - - // Get public key - pubkeyBech32, err := getPubkeyBech32FromRecord(record) - - require.NoError(suite.T(), err) - assert.NotEmpty(suite.T(), pubkeyBech32) - assert.Contains(suite.T(), pubkeyBech32, "pushpub") -} - -// TestKeyringConfigValidation tests keyring config validation -func (suite *KeyringTestSuite) TestKeyringConfigValidation() { - // Test valid config - validConfig := KeyringConfig{ - HomeDir: suite.tempDir, - KeyringBackend: KeyringBackendTest, - HotkeyName: "test-key", - HotkeyPassword: "", - } - - // Create a key for this config to work - _, _, err := CreateNewKey(suite.kb, validConfig.HotkeyName, "", "") - require.NoError(suite.T(), err) - - // Test the validation indirectly through GetKeyringKeybase - _, _, err = GetKeyringKeybase(validConfig) - assert.NoError(suite.T(), err) // Should succeed with proper keyring setup -} - -// Run the test suite -func TestKeyring(t *testing.T) { - suite.Run(t, new(KeyringTestSuite)) -} diff --git a/universalClient/pushsigner/keys/keys.go b/universalClient/pushsigner/keys/keys.go index 276021b2..9f460c10 100644 --- a/universalClient/pushsigner/keys/keys.go +++ b/universalClient/pushsigner/keys/keys.go @@ -1,35 +1,40 @@ -package keysv2 +// Package keys provides keyring management for the Push Universal Validator. +package keys import ( "fmt" + "io" + "github.com/cosmos/cosmos-sdk/codec" + codectypes "github.com/cosmos/cosmos-sdk/codec/types" + cryptocodec "github.com/cosmos/cosmos-sdk/crypto/codec" "github.com/cosmos/cosmos-sdk/crypto/keyring" + "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" + "github.com/cosmos/cosmos-sdk/crypto/keys/secp256k1" + cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" sdk "github.com/cosmos/cosmos-sdk/types" -) + evmcrypto "github.com/cosmos/evm/crypto/ethsecp256k1" + evmhd "github.com/cosmos/evm/crypto/hd" + cosmosevmkeyring "github.com/cosmos/evm/crypto/keyring" -var _ UniversalValidatorKeys = &Keys{} + "github.com/pushchain/push-chain-node/universalClient/config" +) -// Keys manages all the keys used by Universal Validator +// Keys wraps a Cosmos SDK keyring and a specific key name within it. type Keys struct { - keyName string // Hot key name in keyring - keyring keyring.Keyring // Cosmos SDK keyring - hotkeyPassword string // Password for file backend + keyName string + keyring keyring.Keyring } -// NewKeys creates a new instance of Keys -func NewKeys( - kr keyring.Keyring, - keyName string, - hotkeyPassword string, -) *Keys { +// NewKeys creates a new Keys instance. +func NewKeys(kr keyring.Keyring, keyName string) *Keys { return &Keys{ - keyName: keyName, - keyring: kr, - hotkeyPassword: hotkeyPassword, + keyName: keyName, + keyring: kr, } } -// GetAddress returns the hot key address +// GetAddress returns the address of the key. func (k *Keys) GetAddress() (sdk.AccAddress, error) { info, err := k.keyring.Key(k.keyName) if err != nil { @@ -44,18 +49,67 @@ func (k *Keys) GetAddress() (sdk.AccAddress, error) { return addr, nil } -// GetKeyName returns the name of the hot key in the keyring +// GetKeyName returns the name of the key in the keyring. func (k *Keys) GetKeyName() string { return k.keyName } -// GetKeyring returns the underlying keyring for signing operations. -// It validates that the key exists in the keyring before returning it. -// For file backend, the keyring handles decryption automatically when signing. +// GetKeyring validates the key exists and returns the underlying keyring for signing. func (k *Keys) GetKeyring() (keyring.Keyring, error) { - // Validate that the key exists in the keyring if _, err := k.keyring.Key(k.keyName); err != nil { return nil, fmt.Errorf("key %s not found in keyring: %w", k.keyName, err) } return k.keyring, nil } + +// CreateKeyring creates an EVM-compatible keyring. +func CreateKeyring(homeDir string, reader io.Reader, backend config.KeyringBackend) (keyring.Keyring, error) { + if homeDir == "" { + return nil, fmt.Errorf("home directory is empty") + } + + registry := NewInterfaceRegistryWithEVMSupport() + cdc := codec.NewProtoCodec(registry) + + backendStr := "test" + if backend == config.KeyringBackendFile { + backendStr = "file" + } + + return keyring.New(sdk.KeyringServiceName(), backendStr, homeDir, reader, cdc, cosmosevmkeyring.Option()) +} + +// CreateNewKey creates a new key in the keyring and returns the record and mnemonic. +// If mnemonic is provided, it imports the key; otherwise generates a new one. +func CreateNewKey(kr keyring.Keyring, name string, mnemonic string, passphrase string) (*keyring.Record, string, error) { + if mnemonic != "" { + record, err := kr.NewAccount(name, mnemonic, passphrase, sdk.FullFundraiserPath, evmhd.EthSecp256k1) + return record, mnemonic, err + } + + record, generatedMnemonic, err := kr.NewMnemonic(name, keyring.English, sdk.FullFundraiserPath, passphrase, evmhd.EthSecp256k1) + if err != nil { + return nil, "", fmt.Errorf("failed to generate new key with mnemonic: %w", err) + } + + return record, generatedMnemonic, nil +} + +// NewInterfaceRegistryWithEVMSupport creates an interface registry with EVM-compatible key types. +func NewInterfaceRegistryWithEVMSupport() codectypes.InterfaceRegistry { + registry := codectypes.NewInterfaceRegistry() + cryptocodec.RegisterInterfaces(registry) + + registry.RegisterImplementations((*cryptotypes.PubKey)(nil), + &secp256k1.PubKey{}, + &ed25519.PubKey{}, + &evmcrypto.PubKey{}, + ) + registry.RegisterImplementations((*cryptotypes.PrivKey)(nil), + &secp256k1.PrivKey{}, + &ed25519.PrivKey{}, + &evmcrypto.PrivKey{}, + ) + + return registry +} diff --git a/universalClient/pushsigner/keys/keys_test.go b/universalClient/pushsigner/keys/keys_test.go index a6c165e9..0ca3cc37 100644 --- a/universalClient/pushsigner/keys/keys_test.go +++ b/universalClient/pushsigner/keys/keys_test.go @@ -1,4 +1,4 @@ -package keysv2 +package keys import ( "os" @@ -8,209 +8,167 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/pushchain/push-chain-node/universalClient/config" ) func TestMain(m *testing.M) { - // Initialize SDK config for tests - config := sdk.GetConfig() - config.SetBech32PrefixForAccount("push", "pushpub") - config.SetBech32PrefixForValidator("pushvaloper", "pushvaloperpub") - config.SetBech32PrefixForConsensusNode("pushvalcons", "pushvalconspub") - config.Seal() + sdkConfig := sdk.GetConfig() + func() { + defer func() { _ = recover() }() + sdkConfig.SetBech32PrefixForAccount("push", "pushpub") + sdkConfig.SetBech32PrefixForValidator("pushvaloper", "pushvaloperpub") + sdkConfig.SetBech32PrefixForConsensusNode("pushvalcons", "pushvalconspub") + sdkConfig.Seal() + }() os.Exit(m.Run()) } func TestNewKeys(t *testing.T) { - // Create temporary directory for test keyring - tempDir, err := os.MkdirTemp("", "test-keyring") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tempDir) }() - - // Create test keyring - kb, err := CreateKeyring(tempDir, nil, KeyringBackendTest) + tempDir := t.TempDir() + kb, err := CreateKeyring(tempDir, nil, config.KeyringBackendTest) require.NoError(t, err) - // Create basic Keys instance - keys := NewKeys(kb, "test-hotkey", "") - - require.NotNil(t, keys) - require.Equal(t, "test-hotkey", keys.keyName) - require.NotNil(t, keys.keyring) + k := NewKeys(kb, "test-hotkey") - // Test methods that should work without requiring actual key - assert.NotNil(t, keys.keyring) - // Password is not exposed - signing uses keyring directly - assert.Equal(t, "test-hotkey", keys.GetKeyName()) + require.NotNil(t, k) + assert.Equal(t, "test-hotkey", k.GetKeyName()) } func TestKeyringBackends(t *testing.T) { - tests := []struct { - name string - backend KeyringBackend - wantErr bool - }{ - { - name: "test backend", - backend: KeyringBackendTest, - wantErr: false, - }, - { - name: "file backend", - backend: KeyringBackendFile, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tempDir, err := os.MkdirTemp("", "keyring-test") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tempDir) }() - - kb, err := CreateKeyring(tempDir, nil, tt.backend) - if tt.wantErr { - require.Error(t, err) - require.Nil(t, kb) - } else { - require.NoError(t, err) - require.NotNil(t, kb) - } - }) - } + t.Run("test backend", func(t *testing.T) { + kb, err := CreateKeyring(t.TempDir(), nil, config.KeyringBackendTest) + require.NoError(t, err) + assert.Equal(t, "test", kb.Backend()) + }) + + t.Run("file backend", func(t *testing.T) { + kb, err := CreateKeyring(t.TempDir(), nil, config.KeyringBackendFile) + require.NoError(t, err) + assert.Equal(t, "file", kb.Backend()) + }) + + t.Run("empty home dir", func(t *testing.T) { + _, err := CreateKeyring("", nil, config.KeyringBackendTest) + require.Error(t, err) + assert.Contains(t, err.Error(), "home directory is empty") + }) } -// TestPasswordFailureScenarios tests various password failure scenarios -func TestPasswordFailureScenarios(t *testing.T) { - tempDir, err := os.MkdirTemp("", "test-keyring") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tempDir) }() +func TestKeysWithFileBackend(t *testing.T) { + tempDir := t.TempDir() - // Test with file backend requiring password - // For file backend, we need a password reader passwordReader := strings.NewReader("testpass\ntestpass\n") - kb, err := CreateKeyring(tempDir, passwordReader, KeyringBackendFile) + kb, err := CreateKeyring(tempDir, passwordReader, config.KeyringBackendFile) require.NoError(t, err) - // Create a key first with password _, _, err = CreateNewKey(kb, "test-key", "", "testpass") require.NoError(t, err) - keys := NewKeys(kb, "test-key", "") + k := NewKeys(kb, "test-key") - // Test GetKeyring returns the keyring and validates key exists - kr, err := keys.GetKeyring() + kr, err := k.GetKeyring() require.NoError(t, err) - assert.NotNil(t, kr) - // Verify it's the same backend type assert.Equal(t, kb.Backend(), kr.Backend()) +} - // Test with test backend - kbTest, err := CreateKeyring(tempDir, nil, KeyringBackendTest) +func TestCreateNewKey(t *testing.T) { + tempDir := t.TempDir() + kb, err := CreateKeyring(tempDir, nil, config.KeyringBackendTest) require.NoError(t, err) - keysTest := NewKeys(kbTest, "test-key", "") - // Password is not exposed - signing uses keyring directly - // The keyring handles password internally when needed - assert.NotNil(t, keysTest) + t.Run("generate new key", func(t *testing.T) { + record, mnemonic, err := CreateNewKey(kb, "new-key", "", "") + require.NoError(t, err) + assert.NotNil(t, record) + assert.Equal(t, "new-key", record.Name) + assert.NotEmpty(t, mnemonic) + }) + + t.Run("import from mnemonic", func(t *testing.T) { + mnemonic := "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about" + record, returnedMnemonic, err := CreateNewKey(kb, "mnemonic-key", mnemonic, "") + require.NoError(t, err) + assert.Equal(t, "mnemonic-key", record.Name) + assert.Equal(t, mnemonic, returnedMnemonic) + }) + + t.Run("invalid mnemonic", func(t *testing.T) { + _, _, err := CreateNewKey(kb, "bad-key", "invalid mnemonic words", "") + require.Error(t, err) + }) } -// TestKeyringBackendSwitching tests switching between keyring backends -func TestKeyringBackendSwitching(t *testing.T) { - tempDir, err := os.MkdirTemp("", "test-keyring") +func TestGetAddress(t *testing.T) { + tempDir := t.TempDir() + kb, err := CreateKeyring(tempDir, nil, config.KeyringBackendTest) require.NoError(t, err) - defer func() { _ = os.RemoveAll(tempDir) }() - - tests := []struct { - name string - backend1 KeyringBackend - backend2 KeyringBackend - }{ - { - name: "test to file", - backend1: KeyringBackendTest, - backend2: KeyringBackendFile, - }, - { - name: "file to test", - backend1: KeyringBackendFile, - backend2: KeyringBackendTest, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create keyring with first backend - kb1, err := CreateKeyring(tempDir+"1", nil, tt.backend1) - require.NoError(t, err) - - // Create keyring with second backend - kb2, err := CreateKeyring(tempDir+"2", nil, tt.backend2) - require.NoError(t, err) - - // Both should be valid - assert.NotNil(t, kb1) - assert.NotNil(t, kb2) - assert.Equal(t, string(tt.backend1), kb1.Backend()) - assert.Equal(t, string(tt.backend2), kb2.Backend()) - }) - } + record, _, err := CreateNewKey(kb, "addr-test", "", "") + require.NoError(t, err) + + t.Run("valid key", func(t *testing.T) { + k := NewKeys(kb, record.Name) + addr, err := k.GetAddress() + require.NoError(t, err) + assert.NotEmpty(t, addr) + }) + + t.Run("non-existent key", func(t *testing.T) { + k := NewKeys(kb, "non-existent") + _, err := k.GetAddress() + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to get key") + }) } -// TestConcurrentKeyAccess tests concurrent access to keys -func TestConcurrentKeyAccess(t *testing.T) { - tempDir, err := os.MkdirTemp("", "test-keyring") +func TestGetKeyring(t *testing.T) { + tempDir := t.TempDir() + kb, err := CreateKeyring(tempDir, nil, config.KeyringBackendTest) require.NoError(t, err) - defer func() { _ = os.RemoveAll(tempDir) }() - // Create test keyring and key - kb, err := CreateKeyring(tempDir, nil, KeyringBackendTest) + record, _, err := CreateNewKey(kb, "kr-test", "", "") require.NoError(t, err) - keyName := "concurrent-test-key" - _, _, err = CreateNewKey(kb, keyName, "", "") + t.Run("valid key", func(t *testing.T) { + k := NewKeys(kb, record.Name) + kr, err := k.GetKeyring() + require.NoError(t, err) + assert.NotNil(t, kr) + }) + + t.Run("non-existent key", func(t *testing.T) { + k := NewKeys(kb, "non-existent") + kr, err := k.GetKeyring() + require.Error(t, err) + assert.Nil(t, kr) + assert.Contains(t, err.Error(), "not found in keyring") + }) +} + +func TestConcurrentKeyAccess(t *testing.T) { + tempDir := t.TempDir() + kb, err := CreateKeyring(tempDir, nil, config.KeyringBackendTest) require.NoError(t, err) - keys := NewKeys(kb, keyName, "") + _, _, err = CreateNewKey(kb, "concurrent-key", "", "") + require.NoError(t, err) + + k := NewKeys(kb, "concurrent-key") - // Test concurrent GetAddress calls const numGoroutines = 10 results := make(chan error, numGoroutines) - for i := 0; i < numGoroutines; i++ { + for i := range numGoroutines { + _ = i go func() { - _, err := keys.GetAddress() + _, err := k.GetAddress() results <- err }() } - // Collect results - for i := 0; i < numGoroutines; i++ { - err := <-results - assert.NoError(t, err) + for range numGoroutines { + assert.NoError(t, <-results) } } - -// TestErrorConditions tests various error conditions -func TestErrorConditions(t *testing.T) { - tempDir, err := os.MkdirTemp("", "test-keyring") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tempDir) }() - - // Create test keyring - kb, err := CreateKeyring(tempDir, nil, KeyringBackendTest) - require.NoError(t, err) - - keys := NewKeys(kb, "non-existent-key", "") - - // Test GetAddress with non-existent key - _, err = keys.GetAddress() - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to get key") - - // Test GetKeyring validates key exists and returns error for non-existent key - kr, err := keys.GetKeyring() - assert.Error(t, err) - assert.Nil(t, kr) - assert.Contains(t, err.Error(), "not found in keyring") -} From 46a3e21ce96a6fc4d0079067ae2d0e78d23d2a47 Mon Sep 17 00:00:00 2001 From: aman035 Date: Wed, 18 Mar 2026 17:13:21 +0530 Subject: [PATCH 06/28] fix: grant verifier --- universalClient/constant/constant.go | 10 -------- universalClient/pushsigner/grant_verifier.go | 23 ++++++++++++------- .../pushsigner/grant_verifier_test.go | 12 ++++------ 3 files changed, 20 insertions(+), 25 deletions(-) diff --git a/universalClient/constant/constant.go b/universalClient/constant/constant.go index ffd74d98..3a84c2cb 100644 --- a/universalClient/constant/constant.go +++ b/universalClient/constant/constant.go @@ -21,13 +21,3 @@ const ( ) var DefaultNodeHome = os.ExpandEnv("$HOME/") + NodeDir - -// RequiredMsgGrants contains all the required message type URLs -// that must be granted via AuthZ for the Universal Validator to function. -// These messages are executed on behalf of the core validator by the grantee (hotkey of the Universal Validator). -var RequiredMsgGrants = []string{ - "/uexecutor.v1.MsgVoteInbound", - "/uexecutor.v1.MsgVoteChainMeta", - "/uexecutor.v1.MsgVoteOutbound", - "/utss.v1.MsgVoteTssKeyProcess", -} diff --git a/universalClient/pushsigner/grant_verifier.go b/universalClient/pushsigner/grant_verifier.go index 6b1e0c2d..e410b5a8 100644 --- a/universalClient/pushsigner/grant_verifier.go +++ b/universalClient/pushsigner/grant_verifier.go @@ -14,12 +14,19 @@ import ( "github.com/cosmos/cosmos-sdk/x/authz" "github.com/pushchain/push-chain-node/universalClient/config" - "github.com/pushchain/push-chain-node/universalClient/constant" - "github.com/pushchain/push-chain-node/universalClient/pushcore" - keysv2 "github.com/pushchain/push-chain-node/universalClient/pushsigner/keys" + "github.com/pushchain/push-chain-node/universalClient/pushsigner/keys" uetypes "github.com/pushchain/push-chain-node/x/uexecutor/types" ) +// requiredMsgGrants contains all the AuthZ message type URLs +// that must be granted for the Universal Validator to function. +var requiredMsgGrants = []string{ + "/uexecutor.v1.MsgVoteInbound", + "/uexecutor.v1.MsgVoteChainMeta", + "/uexecutor.v1.MsgVoteOutbound", + "/utss.v1.MsgVoteTssKeyProcess", +} + // GrantInfo represents information about a single AuthZ grant. type grantInfo struct { Granter string @@ -37,8 +44,8 @@ type validationResult struct { } // ValidateKeysAndGrants validates hotkey and AuthZ grants against the specified granter. -func validateKeysAndGrants(keyringBackend config.KeyringBackend, keyringPassword string, nodeHome string, pushCore *pushcore.Client, granter string) (*validationResult, error) { - interfaceRegistry := keysv2.CreateInterfaceRegistryWithEVMSupport() +func validateKeysAndGrants(keyringBackend config.KeyringBackend, keyringPassword string, nodeHome string, pushCore chainClient, granter string) (*validationResult, error) { + interfaceRegistry := keys.NewInterfaceRegistryWithEVMSupport() authz.RegisterInterfaces(interfaceRegistry) uetypes.RegisterInterfaces(interfaceRegistry) cdc := codec.NewProtoCodec(interfaceRegistry) @@ -54,7 +61,7 @@ func validateKeysAndGrants(keyringBackend config.KeyringBackend, keyringPassword reader = strings.NewReader(passwordInput) } - kr, err := keysv2.CreateKeyringFromConfig(nodeHome, reader, keyringBackend) + kr, err := keys.CreateKeyring(nodeHome, reader, keyringBackend) if err != nil { return nil, fmt.Errorf("failed to create keyring: %w", err) } @@ -116,14 +123,14 @@ func verifyGrants(grants []grantInfo, granter string) ([]string, error) { } // Check if this grant is for a required message type - if slices.Contains(constant.RequiredMsgGrants, grant.MessageType) { + if slices.Contains(requiredMsgGrants, grant.MessageType) { authorized[grant.MessageType] = true } } // Verify all required grants are present var missing []string - for _, req := range constant.RequiredMsgGrants { + for _, req := range requiredMsgGrants { if !authorized[req] { missing = append(missing, req) } diff --git a/universalClient/pushsigner/grant_verifier_test.go b/universalClient/pushsigner/grant_verifier_test.go index aeb465b9..1ee4a955 100644 --- a/universalClient/pushsigner/grant_verifier_test.go +++ b/universalClient/pushsigner/grant_verifier_test.go @@ -9,8 +9,6 @@ import ( "github.com/cosmos/cosmos-sdk/x/authz" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/pushchain/push-chain-node/universalClient/constant" ) func TestVerifyGrants(t *testing.T) { @@ -28,10 +26,10 @@ func TestVerifyGrants(t *testing.T) { msgs, err := verifyGrants(grants, granter) require.NoError(t, err) - assert.Len(t, msgs, len(constant.RequiredMsgGrants)) + assert.Len(t, msgs, len(requiredMsgGrants)) // Verify all required messages are returned - for _, req := range constant.RequiredMsgGrants { + for _, req := range requiredMsgGrants { assert.Contains(t, msgs, req) } }) @@ -46,7 +44,7 @@ func TestVerifyGrants(t *testing.T) { msgs, err := verifyGrants(grants, granter) require.NoError(t, err) - assert.Len(t, msgs, len(constant.RequiredMsgGrants)) + assert.Len(t, msgs, len(requiredMsgGrants)) }) t.Run("missing required grant", func(t *testing.T) { @@ -111,7 +109,7 @@ func TestVerifyGrants(t *testing.T) { msgs, err := verifyGrants(grants, granter) require.NoError(t, err) - assert.Len(t, msgs, len(constant.RequiredMsgGrants)) + assert.Len(t, msgs, len(requiredMsgGrants)) }) t.Run("extra non-required grants are ignored", func(t *testing.T) { @@ -125,7 +123,7 @@ func TestVerifyGrants(t *testing.T) { msgs, err := verifyGrants(grants, granter) require.NoError(t, err) - assert.Len(t, msgs, len(constant.RequiredMsgGrants)) + assert.Len(t, msgs, len(requiredMsgGrants)) assert.NotContains(t, msgs, "/some.other.v1.MsgNotRequired") }) } From 0e32740ce0d2e77b8ce31a75a980dd006f529a6d Mon Sep 17 00:00:00 2001 From: aman035 Date: Wed, 18 Mar 2026 17:13:57 +0530 Subject: [PATCH 07/28] fix: push signer --- universalClient/pushsigner/pushsigner.go | 61 +-- universalClient/pushsigner/pushsigner_test.go | 512 ++++++++++++++---- 2 files changed, 441 insertions(+), 132 deletions(-) diff --git a/universalClient/pushsigner/pushsigner.go b/universalClient/pushsigner/pushsigner.go index 75f16c32..cca89fea 100644 --- a/universalClient/pushsigner/pushsigner.go +++ b/universalClient/pushsigner/pushsigner.go @@ -11,6 +11,7 @@ import ( "github.com/cosmos/cosmos-sdk/codec" cosmoskeyring "github.com/cosmos/cosmos-sdk/crypto/keyring" sdk "github.com/cosmos/cosmos-sdk/types" + sdktx "github.com/cosmos/cosmos-sdk/types/tx" "github.com/cosmos/cosmos-sdk/types/tx/signing" authtx "github.com/cosmos/cosmos-sdk/x/auth/tx" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" @@ -22,15 +23,23 @@ import ( "github.com/pushchain/push-chain-node/universalClient/config" "github.com/pushchain/push-chain-node/universalClient/pushcore" - keysv2 "github.com/pushchain/push-chain-node/universalClient/pushsigner/keys" + "github.com/pushchain/push-chain-node/universalClient/pushsigner/keys" uexecutortypes "github.com/pushchain/push-chain-node/x/uexecutor/types" ) +// chainClient defines the subset of pushcore.Client methods used by Signer. +// Defined here (consumer-side) so tests can provide mock implementations. +type chainClient interface { + BroadcastTx(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) + GetAccount(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) + GetGranteeGrants(ctx context.Context, granteeAddr string) (*cosmosauthz.QueryGranteeGrantsResponse, error) +} + // Signer provides the main public API for signing and voting operations. type Signer struct { - keys keysv2.UniversalValidatorKeys + keys *keys.Keys clientCtx client.Context - pushCore *pushcore.Client + pushCore chainClient granter string log zerolog.Logger sequenceMutex sync.Mutex // Mutex to synchronize transaction signing @@ -60,10 +69,9 @@ func New( return nil, fmt.Errorf("failed to parse key address: %w", err) } - universalKeys := keysv2.NewKeys( + universalKeys := keys.NewKeys( validationResult.Keyring, validationResult.KeyName, - "", ) derivedAddr, err := universalKeys.GetAddress() @@ -79,10 +87,7 @@ func New( return nil, fmt.Errorf("failed to validate keyring: %w", err) } - clientCtx, err := createClientContext(validationResult.Keyring, chainID) - if err != nil { - return nil, fmt.Errorf("failed to create client context: %w", err) - } + clientCtx := createClientContext(validationResult.Keyring, chainID) log.Info(). Str("key_name", validationResult.KeyName). @@ -172,15 +177,12 @@ func (s *Signer) signAndBroadcastAuthZTx( Uint64("current_sequence", s.lastSequence). Int("attempt", attempt). Msg("Sequence mismatch detected, forcing refresh and retrying") - // Force refresh sequence on next attempt - s.lastSequence = 0 // This will force a refresh from chain - continue // Retry + s.lastSequence = 0 + continue } - // For other errors or final attempt, increment and return error - s.lastSequence++ - s.log.Debug(). - Uint64("new_sequence", s.lastSequence). - Msg("Incremented sequence after broadcast error") + // Network/transport errors: sequence was NOT consumed, don't increment. + // Force refresh from chain on next use to reconcile. + s.lastSequence = 0 return nil, fmt.Errorf("failed to broadcast transaction: %w", err) } @@ -192,25 +194,19 @@ func (s *Signer) signAndBroadcastAuthZTx( // If chain responded with error code, handle sequence-mismatch specially if txResp != nil && txResp.Code != 0 { - // Retry immediately for account sequence mismatch responses if strings.Contains(strings.ToLower(txResp.RawLog), "account sequence mismatch") && attempt < maxAttempts { s.log.Warn(). Uint64("current_sequence", s.lastSequence). Int("attempt", attempt). Str("raw_log", txResp.RawLog). Msg("Sequence mismatch in response, refreshing and retrying") - // Force refresh from chain on next attempt s.lastSequence = 0 continue } - // Conservatively increment sequence since the sequence may have been consumed + // Chain accepted the tx into mempool but it failed — sequence was consumed s.lastSequence++ - s.log.Debug(). - Uint64("new_sequence", s.lastSequence). - Msg("Incremented sequence after on-chain error response") - // Log and return error s.log.Error(). Str("tx_hash", txResp.TxHash). Uint32("code", txResp.Code). @@ -220,18 +216,13 @@ func (s *Signer) signAndBroadcastAuthZTx( return txResp, fmt.Errorf("transaction failed with code %d: %s", txResp.Code, txResp.RawLog) } - // Success: increment sequence once and return + // Success: sequence was consumed s.lastSequence++ - s.log.Debug(). - Uint64("new_sequence", s.lastSequence). - Str("tx_hash", txResp.TxHash). - Msg("Incremented sequence after successful broadcast") - s.log.Debug(). Str("tx_hash", txResp.TxHash). Int64("gas_used", txResp.GasUsed). Uint64("sequence_used", s.lastSequence-1). - Msg("Transaction broadcasted and executed successfully") + Msg("Transaction broadcasted successfully") return txResp, nil } @@ -404,8 +395,8 @@ func (s *Signer) getAccountInfo(ctx context.Context) (client.Account, error) { return account, nil } -func createClientContext(kr cosmoskeyring.Keyring, chainID string) (client.Context, error) { - interfaceRegistry := keysv2.CreateInterfaceRegistryWithEVMSupport() +func createClientContext(kr cosmoskeyring.Keyring, chainID string) client.Context { + interfaceRegistry := keys.NewInterfaceRegistryWithEVMSupport() cosmosauthz.RegisterInterfaces(interfaceRegistry) authtypes.RegisterInterfaces(interfaceRegistry) banktypes.RegisterInterfaces(interfaceRegistry) @@ -416,12 +407,10 @@ func createClientContext(kr cosmoskeyring.Keyring, chainID string) (client.Conte cdc := codec.NewProtoCodec(interfaceRegistry) txConfig := authtx.NewTxConfig(cdc, []signing.SignMode{signing.SignMode_SIGN_MODE_DIRECT}) - clientCtx := client.Context{}. + return client.Context{}. WithCodec(cdc). WithInterfaceRegistry(interfaceRegistry). WithChainID(chainID). WithKeyring(kr). WithTxConfig(txConfig) - - return clientCtx, nil } diff --git a/universalClient/pushsigner/pushsigner_test.go b/universalClient/pushsigner/pushsigner_test.go index 5c59eb76..4e1958b6 100644 --- a/universalClient/pushsigner/pushsigner_test.go +++ b/universalClient/pushsigner/pushsigner_test.go @@ -1,26 +1,30 @@ package pushsigner import ( + "context" + "fmt" "os" "testing" + codectypes "github.com/cosmos/cosmos-sdk/codec/types" sdk "github.com/cosmos/cosmos-sdk/types" + sdktx "github.com/cosmos/cosmos-sdk/types/tx" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + cosmosauthz "github.com/cosmos/cosmos-sdk/x/authz" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/pushchain/push-chain-node/universalClient/config" "github.com/pushchain/push-chain-node/universalClient/pushcore" - keysv2 "github.com/pushchain/push-chain-node/universalClient/pushsigner/keys" - uexecutortypes "github.com/pushchain/push-chain-node/x/uexecutor/types" + "github.com/pushchain/push-chain-node/universalClient/pushsigner/keys" ) func TestMain(m *testing.M) { - // Initialize SDK config for tests sdkConfig := sdk.GetConfig() func() { defer func() { - _ = recover() // Ignore panic if already sealed + _ = recover() }() sdkConfig.SetBech32PrefixForAccount("push", "pushpub") sdkConfig.SetBech32PrefixForValidator("pushvaloper", "pushvaloperpub") @@ -31,13 +35,87 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -// createMockPushCoreClient creates a minimal pushcore.Client for testing. -// Since pushcore.Client is a concrete struct, we create an empty one -// and tests will need to handle the actual gRPC calls appropriately. +// --- mock chainClient --- + +type mockChainClient struct { + broadcastTxFn func(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) + getAccountFn func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) + getGranteeGrantFn func(ctx context.Context, addr string) (*cosmosauthz.QueryGranteeGrantsResponse, error) +} + +func (m *mockChainClient) BroadcastTx(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) { + if m.broadcastTxFn != nil { + return m.broadcastTxFn(ctx, txBytes) + } + return nil, fmt.Errorf("BroadcastTx not mocked") +} + +func (m *mockChainClient) GetAccount(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + if m.getAccountFn != nil { + return m.getAccountFn(ctx, address) + } + return nil, fmt.Errorf("GetAccount not mocked") +} + +func (m *mockChainClient) GetGranteeGrants(ctx context.Context, addr string) (*cosmosauthz.QueryGranteeGrantsResponse, error) { + if m.getGranteeGrantFn != nil { + return m.getGranteeGrantFn(ctx, addr) + } + return nil, fmt.Errorf("GetGranteeGrants not mocked") +} + +// --- helpers --- + func createMockPushCoreClient() *pushcore.Client { return &pushcore.Client{} } +// createTestSigner creates a Signer with a real keyring and mock chainClient for testing. +func createTestSigner(t *testing.T, mock *mockChainClient) *Signer { + t.Helper() + + tempDir, err := os.MkdirTemp("", "signer-test") + require.NoError(t, err) + t.Cleanup(func() { os.RemoveAll(tempDir) }) + + kr, err := keys.CreateKeyring(tempDir, nil, config.KeyringBackendTest) + require.NoError(t, err) + + record, _, err := keys.CreateNewKey(kr, "test-key", "", "") + require.NoError(t, err) + + k := keys.NewKeys(kr, record.Name) + clientCtx := createClientContext(kr, "test-chain") + + return &Signer{ + keys: k, + clientCtx: clientCtx, + pushCore: mock, + granter: "push1granter", + log: zerolog.New(zerolog.NewTestWriter(t)), + } +} + +// makeAccountResponse creates a mock QueryAccountResponse with the given sequence and account number. +func makeAccountResponse(t *testing.T, address sdk.AccAddress, seq, accNum uint64) *authtypes.QueryAccountResponse { + t.Helper() + + baseAccount := &authtypes.BaseAccount{ + Address: address.String(), + AccountNumber: accNum, + Sequence: seq, + } + + anyAccount, err := codectypes.NewAnyWithValue(baseAccount) + require.NoError(t, err) + + return &authtypes.QueryAccountResponse{ + Account: anyAccount, + } +} + +// --- New() tests --- + func TestNew(t *testing.T) { logger := zerolog.Nop() @@ -46,28 +124,18 @@ func TestNew(t *testing.T) { require.NoError(t, err) defer os.RemoveAll(tempDir) - cfg := &config.Config{ - KeyringBackend: config.KeyringBackendTest, - KeyringPassword: "", - } - mockCore := createMockPushCoreClient() - signer, err := New(logger, cfg.KeyringBackend, cfg.KeyringPassword, "", mockCore, "test-chain", "cosmos1granter") + signer, err := New(logger, config.KeyringBackendTest, "", "", mockCore, "test-chain", "cosmos1granter") require.Error(t, err) assert.Nil(t, signer) assert.Contains(t, err.Error(), "PushSigner validation failed") }) t.Run("validation failure - keyring creation fails", func(t *testing.T) { - cfg := &config.Config{ - KeyringBackend: config.KeyringBackendFile, - KeyringPassword: "", // Missing password for file backend - } - mockCore := createMockPushCoreClient() - signer, err := New(logger, cfg.KeyringBackend, cfg.KeyringPassword, "", mockCore, "test-chain", "cosmos1granter") + signer, err := New(logger, config.KeyringBackendFile, "", "", mockCore, "test-chain", "cosmos1granter") require.Error(t, err) assert.Nil(t, signer) assert.Contains(t, err.Error(), "keyring_password is required for file backend") @@ -78,126 +146,378 @@ func TestNew(t *testing.T) { require.NoError(t, err) defer os.RemoveAll(tempDir) - // Create keyring and add a key - kr, err := keysv2.CreateKeyring(tempDir, nil, keysv2.KeyringBackendTest) + kr, err := keys.CreateKeyring(tempDir, nil, config.KeyringBackendTest) require.NoError(t, err) - _, _, err = keysv2.CreateNewKey(kr, "test-key", "", "") + _, _, err = keys.CreateNewKey(kr, "test-key", "", "") require.NoError(t, err) - cfg := &config.Config{ - KeyringBackend: config.KeyringBackendTest, - KeyringPassword: "", - } - mockCore := createMockPushCoreClient() - // This will fail because GetGranteeGrants will fail (no real gRPC connection) - signer, err := New(logger, cfg.KeyringBackend, cfg.KeyringPassword, tempDir, mockCore, "test-chain", "cosmos1granter") + signer, err := New(logger, config.KeyringBackendTest, "", tempDir, mockCore, "test-chain", "cosmos1granter") require.Error(t, err) assert.Nil(t, signer) - // Error will be from GetGranteeGrants failing }) } +// --- Keys tests --- + func TestSigner_GetKeyring(t *testing.T) { tempDir, err := os.MkdirTemp("", "test-signer") require.NoError(t, err) defer os.RemoveAll(tempDir) - kr, err := keysv2.CreateKeyring(tempDir, nil, keysv2.KeyringBackendTest) + kr, err := keys.CreateKeyring(tempDir, nil, config.KeyringBackendTest) require.NoError(t, err) - record, _, err := keysv2.CreateNewKey(kr, "test-key", "", "") + record, _, err := keys.CreateNewKey(kr, "test-key", "", "") require.NoError(t, err) - keys := keysv2.NewKeys(kr, record.Name, "") + k := keys.NewKeys(kr, record.Name) t.Run("valid key", func(t *testing.T) { - keyring, err := keys.GetKeyring() + keyring, err := k.GetKeyring() require.NoError(t, err) assert.NotNil(t, keyring) }) t.Run("invalid key", func(t *testing.T) { - invalidKeys := keysv2.NewKeys(kr, "non-existent-key", "") - keyring, err := invalidKeys.GetKeyring() + invalidK := keys.NewKeys(kr, "non-existent-key") + keyring, err := invalidK.GetKeyring() require.Error(t, err) assert.Nil(t, keyring) assert.Contains(t, err.Error(), "not found in keyring") }) } -// TestSigner_VoteInbound tests the VoteInbound method signature. -// Full integration tests would require a complete setup with real keyring, pushcore client, etc. -func TestSigner_VoteInbound(t *testing.T) { - // This test verifies the method exists and has the correct signature. - // Full testing requires integration test setup with real dependencies. - t.Run("method exists", func(t *testing.T) { - // Verify the method signature by checking it compiles - var signer *Signer - var inbound *uexecutortypes.Inbound - _ = signer - _ = inbound - // Method signature: VoteInbound(ctx context.Context, inbound *uexecutortypes.Inbound) (string, error) - assert.True(t, true) +// --- AuthZ wrapping tests --- + +func TestWrapWithAuthZ(t *testing.T) { + mock := &mockChainClient{} + signer := createTestSigner(t, mock) + + t.Run("empty messages returns error", func(t *testing.T) { + msgs, err := signer.wrapWithAuthZ(nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no messages to wrap") + assert.Nil(t, msgs) + }) + + t.Run("wraps messages with MsgExec", func(t *testing.T) { + innerMsg := &cosmosauthz.MsgExec{} + msgs, err := signer.wrapWithAuthZ([]sdk.Msg{innerMsg}) + require.NoError(t, err) + require.Len(t, msgs, 1) + + exec, ok := msgs[0].(*cosmosauthz.MsgExec) + require.True(t, ok, "expected MsgExec wrapper") + assert.Len(t, exec.Msgs, 1) }) } -// TestSigner_VoteChainMeta tests the VoteChainMeta method signature. -func TestSigner_VoteChainMeta(t *testing.T) { - t.Run("method exists", func(t *testing.T) { - // Method signature: VoteChainMeta(ctx context.Context, chainID string, price uint64, chainHeight uint64) (string, error) - assert.True(t, true) +// --- TxBuilder tests --- + +func TestCreateTxBuilder(t *testing.T) { + mock := &mockChainClient{} + signer := createTestSigner(t, mock) + + t.Run("creates tx builder with params", func(t *testing.T) { + fee := sdk.NewCoins(sdk.NewInt64Coin("upc", 1000)) + txBuilder, err := signer.createTxBuilder( + []sdk.Msg{&cosmosauthz.MsgExec{}}, + "test memo", + 200000, + fee, + ) + require.NoError(t, err) + require.NotNil(t, txBuilder) + + builtTx := txBuilder.GetTx() + assert.Equal(t, "test memo", builtTx.GetMemo()) + assert.Equal(t, uint64(200000), builtTx.GetGas()) + assert.Equal(t, fee, builtTx.GetFee()) }) } -// TestSigner_VoteOutbound tests the VoteOutbound method signature. -func TestSigner_VoteOutbound(t *testing.T) { - t.Run("method exists with correct signature", func(t *testing.T) { - // Method signature: VoteOutbound(ctx context.Context, txID string, utxID string, observation *uexecutortypes.OutboundObservation) (string, error) - // Verify the method signature by checking it compiles with the correct parameters - var signer *Signer - var txID string = "tx-123" - var utxID string = "utx-456" - var observation *uexecutortypes.OutboundObservation - _ = signer - _ = txID - _ = utxID - _ = observation - assert.True(t, true) - }) - - t.Run("observation struct has required fields", func(t *testing.T) { - observation := &uexecutortypes.OutboundObservation{ - Success: true, - BlockHeight: 12345, - TxHash: "0xabc123", - ErrorMsg: "", +// --- Account info tests --- + +func TestGetAccountInfo(t *testing.T) { + t.Run("returns account from chain", func(t *testing.T) { + mock := &mockChainClient{ + getAccountFn: func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + addr, _ := sdk.AccAddressFromBech32(address) + return makeAccountResponse(t, addr, 42, 7), nil + }, } - assert.True(t, observation.Success) - assert.Equal(t, uint64(12345), observation.BlockHeight) - assert.Equal(t, "0xabc123", observation.TxHash) - assert.Equal(t, "", observation.ErrorMsg) - }) - - t.Run("observation for failed transaction", func(t *testing.T) { - observation := &uexecutortypes.OutboundObservation{ - Success: false, - BlockHeight: 0, - TxHash: "", - ErrorMsg: "transaction failed: insufficient funds", + + signer := createTestSigner(t, mock) + + account, err := signer.getAccountInfo(context.Background()) + require.NoError(t, err) + assert.Equal(t, uint64(42), account.GetSequence()) + assert.Equal(t, uint64(7), account.GetAccountNumber()) + }) + + t.Run("returns error on chain failure", func(t *testing.T) { + mock := &mockChainClient{ + getAccountFn: func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + return nil, fmt.Errorf("node unavailable") + }, } - assert.False(t, observation.Success) - assert.Equal(t, uint64(0), observation.BlockHeight) - assert.Equal(t, "transaction failed: insufficient funds", observation.ErrorMsg) + + signer := createTestSigner(t, mock) + + account, err := signer.getAccountInfo(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "node unavailable") + assert.Nil(t, account) }) } -// TestSigner_VoteTssKeyProcess tests the VoteTssKeyProcess method signature. -func TestSigner_VoteTssKeyProcess(t *testing.T) { - t.Run("method exists", func(t *testing.T) { - // Method signature: VoteTssKeyProcess(ctx context.Context, tssPubKey string, keyID string, processID uint64) (string, error) - assert.True(t, true) +// --- Sign and broadcast tests --- + +func TestSignAndBroadcast_SequenceManagement(t *testing.T) { + t.Run("successful broadcast increments sequence", func(t *testing.T) { + broadcastCalls := 0 + mock := &mockChainClient{ + getAccountFn: func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + addr, _ := sdk.AccAddressFromBech32(address) + return makeAccountResponse(t, addr, 5, 1), nil + }, + broadcastTxFn: func(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) { + broadcastCalls++ + return &sdktx.BroadcastTxResponse{ + TxResponse: &sdk.TxResponse{Code: 0, TxHash: "ABC123"}, + }, nil + }, + } + + signer := createTestSigner(t, mock) + assert.Equal(t, uint64(0), signer.lastSequence) + + fee := sdk.NewCoins(sdk.NewInt64Coin("upc", 1000)) + resp, err := signer.signAndBroadcastAuthZTx( + context.Background(), + []sdk.Msg{&cosmosauthz.MsgExec{}}, + "test", 200000, fee, + ) + + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, "ABC123", resp.TxHash) + assert.Equal(t, uint64(6), signer.lastSequence) + assert.Equal(t, 1, broadcastCalls) + }) + + t.Run("network error resets sequence to 0", func(t *testing.T) { + mock := &mockChainClient{ + getAccountFn: func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + addr, _ := sdk.AccAddressFromBech32(address) + return makeAccountResponse(t, addr, 10, 1), nil + }, + broadcastTxFn: func(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) { + return nil, fmt.Errorf("connection refused") + }, + } + + signer := createTestSigner(t, mock) + signer.lastSequence = 10 + + fee := sdk.NewCoins(sdk.NewInt64Coin("upc", 1000)) + _, err := signer.signAndBroadcastAuthZTx( + context.Background(), + []sdk.Msg{&cosmosauthz.MsgExec{}}, + "test", 200000, fee, + ) + + require.Error(t, err) + assert.Contains(t, err.Error(), "connection refused") + assert.Equal(t, uint64(0), signer.lastSequence) + }) + + t.Run("sequence mismatch error retries with refresh", func(t *testing.T) { + attempt := 0 + mock := &mockChainClient{ + getAccountFn: func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + addr, _ := sdk.AccAddressFromBech32(address) + if attempt > 0 { + return makeAccountResponse(t, addr, 7, 1), nil + } + return makeAccountResponse(t, addr, 5, 1), nil + }, + broadcastTxFn: func(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) { + attempt++ + if attempt == 1 { + return nil, fmt.Errorf("account sequence mismatch: expected 7, got 5") + } + return &sdktx.BroadcastTxResponse{ + TxResponse: &sdk.TxResponse{Code: 0, TxHash: "RETRY_OK"}, + }, nil + }, + } + + signer := createTestSigner(t, mock) + + fee := sdk.NewCoins(sdk.NewInt64Coin("upc", 1000)) + resp, err := signer.signAndBroadcastAuthZTx( + context.Background(), + []sdk.Msg{&cosmosauthz.MsgExec{}}, + "test", 200000, fee, + ) + + require.NoError(t, err) + assert.Equal(t, "RETRY_OK", resp.TxHash) + assert.Equal(t, 2, attempt) + assert.Equal(t, uint64(8), signer.lastSequence) + }) + + t.Run("on-chain error increments sequence", func(t *testing.T) { + mock := &mockChainClient{ + getAccountFn: func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + addr, _ := sdk.AccAddressFromBech32(address) + return makeAccountResponse(t, addr, 3, 1), nil + }, + broadcastTxFn: func(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) { + return &sdktx.BroadcastTxResponse{ + TxResponse: &sdk.TxResponse{Code: 5, TxHash: "FAILED_TX", RawLog: "insufficient funds"}, + }, nil + }, + } + + signer := createTestSigner(t, mock) + + fee := sdk.NewCoins(sdk.NewInt64Coin("upc", 1000)) + resp, err := signer.signAndBroadcastAuthZTx( + context.Background(), + []sdk.Msg{&cosmosauthz.MsgExec{}}, + "test", 200000, fee, + ) + + require.Error(t, err) + assert.Contains(t, err.Error(), "insufficient funds") + assert.NotNil(t, resp) + assert.Equal(t, uint64(4), signer.lastSequence) + }) + + t.Run("on-chain sequence mismatch retries", func(t *testing.T) { + attempt := 0 + mock := &mockChainClient{ + getAccountFn: func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + addr, _ := sdk.AccAddressFromBech32(address) + if attempt > 0 { + return makeAccountResponse(t, addr, 9, 1), nil + } + return makeAccountResponse(t, addr, 5, 1), nil + }, + broadcastTxFn: func(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) { + attempt++ + if attempt == 1 { + return &sdktx.BroadcastTxResponse{ + TxResponse: &sdk.TxResponse{Code: 32, RawLog: "account sequence mismatch: expected 9, got 5"}, + }, nil + } + return &sdktx.BroadcastTxResponse{ + TxResponse: &sdk.TxResponse{Code: 0, TxHash: "OK_AFTER_RETRY"}, + }, nil + }, + } + + signer := createTestSigner(t, mock) + + fee := sdk.NewCoins(sdk.NewInt64Coin("upc", 1000)) + resp, err := signer.signAndBroadcastAuthZTx( + context.Background(), + []sdk.Msg{&cosmosauthz.MsgExec{}}, + "test", 200000, fee, + ) + + require.NoError(t, err) + assert.Equal(t, "OK_AFTER_RETRY", resp.TxHash) + assert.Equal(t, 2, attempt) + }) +} + +// --- Sequence reconciliation tests --- + +func TestSequenceReconciliation(t *testing.T) { + t.Run("local=0 adopts chain sequence", func(t *testing.T) { + mock := &mockChainClient{ + getAccountFn: func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + addr, _ := sdk.AccAddressFromBech32(address) + return makeAccountResponse(t, addr, 15, 1), nil + }, + broadcastTxFn: func(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) { + return &sdktx.BroadcastTxResponse{ + TxResponse: &sdk.TxResponse{Code: 0, TxHash: "OK"}, + }, nil + }, + } + + signer := createTestSigner(t, mock) + signer.lastSequence = 0 + + fee := sdk.NewCoins(sdk.NewInt64Coin("upc", 1000)) + _, err := signer.signAndBroadcastAuthZTx( + context.Background(), + []sdk.Msg{&cosmosauthz.MsgExec{}}, + "test", 200000, fee, + ) + + require.NoError(t, err) + assert.Equal(t, uint64(16), signer.lastSequence) + }) + + t.Run("local behind chain adopts chain", func(t *testing.T) { + mock := &mockChainClient{ + getAccountFn: func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + addr, _ := sdk.AccAddressFromBech32(address) + return makeAccountResponse(t, addr, 20, 1), nil + }, + broadcastTxFn: func(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) { + return &sdktx.BroadcastTxResponse{ + TxResponse: &sdk.TxResponse{Code: 0, TxHash: "OK"}, + }, nil + }, + } + + signer := createTestSigner(t, mock) + signer.lastSequence = 10 + + fee := sdk.NewCoins(sdk.NewInt64Coin("upc", 1000)) + _, err := signer.signAndBroadcastAuthZTx( + context.Background(), + []sdk.Msg{&cosmosauthz.MsgExec{}}, + "test", 200000, fee, + ) + + require.NoError(t, err) + assert.Equal(t, uint64(21), signer.lastSequence) + }) + + t.Run("local ahead of chain keeps local", func(t *testing.T) { + mock := &mockChainClient{ + getAccountFn: func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + addr, _ := sdk.AccAddressFromBech32(address) + return makeAccountResponse(t, addr, 5, 1), nil + }, + broadcastTxFn: func(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) { + return &sdktx.BroadcastTxResponse{ + TxResponse: &sdk.TxResponse{Code: 0, TxHash: "OK"}, + }, nil + }, + } + + signer := createTestSigner(t, mock) + signer.lastSequence = 8 + + fee := sdk.NewCoins(sdk.NewInt64Coin("upc", 1000)) + _, err := signer.signAndBroadcastAuthZTx( + context.Background(), + []sdk.Msg{&cosmosauthz.MsgExec{}}, + "test", 200000, fee, + ) + + require.NoError(t, err) + assert.Equal(t, uint64(9), signer.lastSequence) }) } From 3caba4eddff9cb3a0bcad2baa1e952a6ed4c66b4 Mon Sep 17 00:00:00 2001 From: aman035 Date: Wed, 18 Mar 2026 17:14:09 +0530 Subject: [PATCH 08/28] remove unecessary tests --- universalClient/pushsigner/vote_test.go | 248 ------------------------ 1 file changed, 248 deletions(-) diff --git a/universalClient/pushsigner/vote_test.go b/universalClient/pushsigner/vote_test.go index 48981ab4..3a302880 100644 --- a/universalClient/pushsigner/vote_test.go +++ b/universalClient/pushsigner/vote_test.go @@ -7,9 +7,6 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - uexecutortypes "github.com/pushchain/push-chain-node/x/uexecutor/types" - utsstypes "github.com/pushchain/push-chain-node/x/utss/types" ) func TestVoteConstants(t *testing.T) { @@ -21,254 +18,9 @@ func TestVoteConstants(t *testing.T) { coins, err := sdk.ParseCoinsNormalized(defaultFeeAmount) require.NoError(t, err) assert.False(t, coins.IsZero()) - assert.Equal(t, "500000000000000upc", defaultFeeAmount) }) t.Run("default vote timeout", func(t *testing.T) { assert.Equal(t, 30*time.Second, defaultVoteTimeout) }) } - -func TestMsgVoteInboundConstruction(t *testing.T) { - t.Run("construct valid MsgVoteInbound", func(t *testing.T) { - granter := "push1granter123" - inbound := &uexecutortypes.Inbound{ - TxHash: "0x123abc", - SourceChain: "eip155:1", - Sender: "0xsender", - Recipient: "push1receiver", - Amount: "1000000", - } - - msg := &uexecutortypes.MsgVoteInbound{ - Signer: granter, - Inbound: inbound, - } - - assert.Equal(t, granter, msg.Signer) - assert.Equal(t, inbound, msg.Inbound) - assert.Equal(t, "0x123abc", msg.Inbound.TxHash) - }) - - t.Run("MsgVoteInbound with nil inbound", func(t *testing.T) { - msg := &uexecutortypes.MsgVoteInbound{ - Signer: "push1granter123", - Inbound: nil, - } - assert.Nil(t, msg.Inbound) - }) -} - -func TestMsgVoteChainMetaConstruction(t *testing.T) { - t.Run("construct valid MsgVoteChainMeta", func(t *testing.T) { - granter := "push1granter123" - chainID := "eip155:1" - price := uint64(20000000000) - chainHeight := uint64(18500000) - - msg := &uexecutortypes.MsgVoteChainMeta{ - Signer: granter, - ObservedChainId: chainID, - Price: price, - ChainHeight: chainHeight, - } - - assert.Equal(t, granter, msg.Signer) - assert.Equal(t, chainID, msg.ObservedChainId) - assert.Equal(t, price, msg.Price) - assert.Equal(t, chainHeight, msg.ChainHeight) - }) - - t.Run("MsgVoteChainMeta with zero values", func(t *testing.T) { - msg := &uexecutortypes.MsgVoteChainMeta{ - Signer: "push1granter123", - ObservedChainId: "eip155:1", - Price: 0, - ChainHeight: 0, - } - assert.Equal(t, uint64(0), msg.Price) - assert.Equal(t, uint64(0), msg.ChainHeight) - }) -} - -func TestMsgVoteOutboundConstruction(t *testing.T) { - t.Run("construct valid MsgVoteOutbound for successful tx", func(t *testing.T) { - granter := "push1granter123" - txID := "tx-123" - utxID := "utx-456" - observation := &uexecutortypes.OutboundObservation{ - Success: true, - BlockHeight: 18500000, - TxHash: "0xabc123def456", - ErrorMsg: "", - } - - msg := &uexecutortypes.MsgVoteOutbound{ - Signer: granter, - TxId: txID, - UtxId: utxID, - ObservedTx: observation, - } - - assert.Equal(t, granter, msg.Signer) - assert.Equal(t, txID, msg.TxId) - assert.Equal(t, utxID, msg.UtxId) - assert.True(t, msg.ObservedTx.Success) - assert.Equal(t, "0xabc123def456", msg.ObservedTx.TxHash) - }) - - t.Run("construct valid MsgVoteOutbound for failed tx", func(t *testing.T) { - observation := &uexecutortypes.OutboundObservation{ - Success: false, - BlockHeight: 0, - TxHash: "", - ErrorMsg: "execution reverted: insufficient balance", - } - - msg := &uexecutortypes.MsgVoteOutbound{ - Signer: "push1granter123", - TxId: "tx-789", - UtxId: "utx-101", - ObservedTx: observation, - } - - assert.False(t, msg.ObservedTx.Success) - assert.Empty(t, msg.ObservedTx.TxHash) - assert.Contains(t, msg.ObservedTx.ErrorMsg, "insufficient balance") - }) - - t.Run("MsgVoteOutbound with nil observation", func(t *testing.T) { - msg := &uexecutortypes.MsgVoteOutbound{ - Signer: "push1granter123", - TxId: "tx-123", - UtxId: "utx-456", - ObservedTx: nil, - } - assert.Nil(t, msg.ObservedTx) - }) -} - -func TestMsgVoteTssKeyProcessConstruction(t *testing.T) { - t.Run("construct valid MsgVoteTssKeyProcess", func(t *testing.T) { - granter := "push1granter123" - tssPubKey := "tsspub1abc123" - keyID := "key-001" - processID := uint64(42) - - msg := &utsstypes.MsgVoteTssKeyProcess{ - Signer: granter, - TssPubkey: tssPubKey, - KeyId: keyID, - ProcessId: processID, - } - - assert.Equal(t, granter, msg.Signer) - assert.Equal(t, tssPubKey, msg.TssPubkey) - assert.Equal(t, keyID, msg.KeyId) - assert.Equal(t, processID, msg.ProcessId) - }) - - t.Run("MsgVoteTssKeyProcess with empty strings", func(t *testing.T) { - msg := &utsstypes.MsgVoteTssKeyProcess{ - Signer: "", - TssPubkey: "", - KeyId: "", - ProcessId: 0, - } - assert.Empty(t, msg.Signer) - assert.Empty(t, msg.TssPubkey) - assert.Empty(t, msg.KeyId) - }) -} - -func TestOutboundObservation(t *testing.T) { - t.Run("successful observation fields", func(t *testing.T) { - obs := &uexecutortypes.OutboundObservation{ - Success: true, - BlockHeight: 12345678, - TxHash: "0x1234567890abcdef", - ErrorMsg: "", - } - - assert.True(t, obs.Success) - assert.Equal(t, uint64(12345678), obs.BlockHeight) - assert.Equal(t, "0x1234567890abcdef", obs.TxHash) - assert.Empty(t, obs.ErrorMsg) - }) - - t.Run("failed observation fields", func(t *testing.T) { - obs := &uexecutortypes.OutboundObservation{ - Success: false, - BlockHeight: 0, - TxHash: "", - ErrorMsg: "transaction failed: nonce too low", - } - - assert.False(t, obs.Success) - assert.Equal(t, uint64(0), obs.BlockHeight) - assert.Empty(t, obs.TxHash) - assert.NotEmpty(t, obs.ErrorMsg) - }) -} - -func TestInbound(t *testing.T) { - t.Run("inbound struct fields", func(t *testing.T) { - inbound := &uexecutortypes.Inbound{ - TxHash: "0xabc123", - SourceChain: "eip155:97", - Sender: "0x1234567890123456789012345678901234567890", - Recipient: "push1receiver123", - Amount: "1000000000000000000", - } - - assert.Equal(t, "0xabc123", inbound.TxHash) - assert.Equal(t, "eip155:97", inbound.SourceChain) - assert.NotEmpty(t, inbound.Sender) - assert.NotEmpty(t, inbound.Recipient) - assert.NotEmpty(t, inbound.Amount) - }) - - t.Run("inbound with zero amount", func(t *testing.T) { - inbound := &uexecutortypes.Inbound{ - TxHash: "0xdef456", - SourceChain: "eip155:1", - Sender: "0xsender", - Recipient: "push1receiver", - Amount: "0", - } - - assert.Equal(t, "0", inbound.Amount) - }) -} - -func TestVoteMemoFormats(t *testing.T) { - t.Run("inbound vote memo format", func(t *testing.T) { - inbound := &uexecutortypes.Inbound{ - TxHash: "0x123abc456def", - } - expectedMemo := "Vote inbound: 0x123abc456def" - actualMemo := "Vote inbound: " + inbound.TxHash - assert.Equal(t, expectedMemo, actualMemo) - }) - - t.Run("chain meta vote memo format", func(t *testing.T) { - chainID := "eip155:1" - expectedMemo := "Vote chain meta: eip155:1 @ price=25000000000 height=18500000" - actualMemo := "Vote chain meta: " + chainID + " @ price=25000000000 height=18500000" - assert.Equal(t, expectedMemo, actualMemo) - }) - - t.Run("outbound vote memo format", func(t *testing.T) { - txID := "tx-12345" - expectedMemo := "Vote outbound: tx-12345" - actualMemo := "Vote outbound: " + txID - assert.Equal(t, expectedMemo, actualMemo) - }) - - t.Run("tss key vote memo format", func(t *testing.T) { - keyID := "key-001" - expectedMemo := "Vote TSS key: key-001" - actualMemo := "Vote TSS key: " + keyID - assert.Equal(t, expectedMemo, actualMemo) - }) -} From 1f471e563a3de33a673e958fd27aa5269ab495ba Mon Sep 17 00:00:00 2001 From: aman035 Date: Wed, 18 Mar 2026 17:40:25 +0530 Subject: [PATCH 09/28] refactor: config package --- universalClient/config/config.go | 202 +++++++------ universalClient/config/config_test.go | 401 +++++++++++--------------- universalClient/config/types.go | 120 +++----- 3 files changed, 316 insertions(+), 407 deletions(-) diff --git a/universalClient/config/config.go b/universalClient/config/config.go index 009eab14..04625695 100644 --- a/universalClient/config/config.go +++ b/universalClient/config/config.go @@ -1,3 +1,16 @@ +// Package config provides configuration loading, validation, and persistence +// for the Push Universal Validator. +// +// Directory layout: +// +// / (default: ~/.puniversal) +// ├── config/ +// │ └── pushuv_config.json +// ├── databases/ +// │ ├── eip155_1.db +// │ └── eip155_97.db +// └── relayer/ +// └── .json package config import ( @@ -6,155 +19,134 @@ import ( "fmt" "os" "path/filepath" +) - "github.com/pushchain/push-chain-node/universalClient/constant" +const ( + NodeDir = ".puniversal" + ConfigSubdir = "config" + ConfigFileName = "pushuv_config.json" + DatabasesSubdir = "databases" + RelayerSubdir = "relayer" ) +// DefaultNodeHome returns the default node home directory (~/.puniversal). +func DefaultNodeHome() string { + return os.ExpandEnv("$HOME/") + NodeDir +} + //go:embed default_config.json var defaultConfigJSON []byte -// LoadDefaultConfig loads the default configuration from the embedded JSON +// LoadDefaultConfig loads the embedded default configuration. func LoadDefaultConfig() (Config, error) { var cfg Config if err := json.Unmarshal(defaultConfigJSON, &cfg); err != nil { return Config{}, fmt.Errorf("failed to unmarshal default config: %w", err) } - - // Validate the config (default config validates against itself) - if err := validateConfig(&cfg, nil); err != nil { + if err := validate(&cfg); err != nil { return Config{}, fmt.Errorf("invalid default config: %w", err) } - return cfg, nil } -func validateConfig(cfg *Config, defaultCfg *Config) error { - // Validate log level - if cfg.LogLevel < 0 || cfg.LogLevel > 5 { - return fmt.Errorf("log level must be between 0 and 5") - } - - // Validate log format - if cfg.LogFormat != "json" && cfg.LogFormat != "console" { - return fmt.Errorf("log format must be 'json' or 'console'") - } - - // Set defaults for registry config from default config - if cfg.ConfigRefreshIntervalSeconds == 0 && defaultCfg != nil { - cfg.ConfigRefreshIntervalSeconds = defaultCfg.ConfigRefreshIntervalSeconds - } - if cfg.MaxRetries == 0 && defaultCfg != nil { - cfg.MaxRetries = defaultCfg.MaxRetries - } - - // Set defaults for registry config from default config - if len(cfg.PushChainGRPCURLs) == 0 && defaultCfg != nil { - cfg.PushChainGRPCURLs = defaultCfg.PushChainGRPCURLs - } - - // Set defaults for query server from default config - if cfg.QueryServerPort == 0 && defaultCfg != nil { - cfg.QueryServerPort = defaultCfg.QueryServerPort - } - - // Set defaults and validate hot key management config - // Don't override if already set, just validate - if cfg.KeyringBackend != "" { - // Validate keyring backend - if cfg.KeyringBackend != KeyringBackendFile && cfg.KeyringBackend != KeyringBackendTest { - // Try to fix common case issues - if cfg.KeyringBackend == "test" || cfg.KeyringBackend == KeyringBackend("test") { - cfg.KeyringBackend = KeyringBackendTest - } else if cfg.KeyringBackend == "file" || cfg.KeyringBackend == KeyringBackend("file") { - cfg.KeyringBackend = KeyringBackendFile - } else { - return fmt.Errorf("keyring backend must be 'file' or 'test', got: %s", cfg.KeyringBackend) - } - } - } else if defaultCfg != nil { - cfg.KeyringBackend = defaultCfg.KeyringBackend - } - - // Initialize ChainConfigs if empty - if len(cfg.ChainConfigs) == 0 && defaultCfg != nil { - cfg.ChainConfigs = defaultCfg.ChainConfigs - } - - // Set NodeHome default - if cfg.NodeHome == "" { - cfg.NodeHome = constant.DefaultNodeHome +// Load reads the config from /config/pushuv_config.json, +// applies defaults for any missing fields, and validates. +func Load(basePath string) (Config, error) { + path := filepath.Join(basePath, ConfigSubdir, ConfigFileName) + data, err := os.ReadFile(filepath.Clean(path)) + if err != nil { + return Config{}, fmt.Errorf("failed to read config file: %w", err) } - // Set TSS defaults - if cfg.TSSP2PListen == "" { - cfg.TSSP2PListen = "/ip4/0.0.0.0/tcp/39000" + var cfg Config + if err := json.Unmarshal(data, &cfg); err != nil { + return Config{}, fmt.Errorf("failed to unmarshal config: %w", err) } - // Validate TSS config (TSS is always enabled) - // Skip TSS validation when defaultCfg is nil (validating default config itself) - if defaultCfg != nil { - // Set TSS defaults from default config if available - if cfg.TSSP2PPrivateKeyHex == "" && defaultCfg.TSSP2PPrivateKeyHex != "" { - cfg.TSSP2PPrivateKeyHex = defaultCfg.TSSP2PPrivateKeyHex - } - if cfg.TSSPassword == "" && defaultCfg.TSSPassword != "" { - cfg.TSSPassword = defaultCfg.TSSPassword - } + defaults, _ := LoadDefaultConfig() + applyDefaults(&cfg, &defaults) - // Validate required TSS fields - if cfg.TSSP2PPrivateKeyHex == "" { - return fmt.Errorf("tss_p2p_private_key_hex is required for TSS") - } - if cfg.TSSPassword == "" { - return fmt.Errorf("tss_password is required for TSS") - } + if err := validate(&cfg); err != nil { + return Config{}, fmt.Errorf("invalid config: %w", err) } - return nil + return cfg, nil } -// Save writes the given config to /config/pushuv_config.json. +// Save validates the config and writes it to /config/pushuv_config.json. func Save(cfg *Config, basePath string) error { - // Load default config for validation - defaultCfg, _ := LoadDefaultConfig() - if err := validateConfig(cfg, &defaultCfg); err != nil { + defaults, _ := LoadDefaultConfig() + applyDefaults(cfg, &defaults) + + if err := validate(cfg); err != nil { return fmt.Errorf("invalid config: %w", err) } - configDir := filepath.Join(basePath, constant.ConfigSubdir) - if err := os.MkdirAll(configDir, 0o750); err != nil { + dir := filepath.Join(basePath, ConfigSubdir) + if err := os.MkdirAll(dir, 0o750); err != nil { return fmt.Errorf("failed to create config directory: %w", err) } - configFile := filepath.Join(configDir, constant.ConfigFileName) data, err := json.MarshalIndent(cfg, "", " ") if err != nil { return fmt.Errorf("failed to marshal config: %w", err) } - if err := os.WriteFile(configFile, data, 0o600); err != nil { + if err := os.WriteFile(filepath.Join(dir, ConfigFileName), data, 0o600); err != nil { return fmt.Errorf("failed to write config file: %w", err) } return nil } -// Load reads and returns the config from /config/pushuv_config.json. -func Load(basePath string) (Config, error) { - configFile := filepath.Join(basePath, constant.ConfigSubdir, constant.ConfigFileName) - data, err := os.ReadFile(filepath.Clean(configFile)) - if err != nil { - return Config{}, fmt.Errorf("failed to read config file: %w", err) +// applyDefaults fills zero-valued fields in cfg from defaults. +func applyDefaults(cfg *Config, defaults *Config) { + if defaults == nil { + return } - var cfg Config - if err := json.Unmarshal(data, &cfg); err != nil { - return Config{}, fmt.Errorf("failed to unmarshal config: %w", err) + if cfg.NodeHome == "" { + cfg.NodeHome = DefaultNodeHome() } + if cfg.ConfigRefreshIntervalSeconds == 0 { + cfg.ConfigRefreshIntervalSeconds = defaults.ConfigRefreshIntervalSeconds + } + if cfg.MaxRetries == 0 { + cfg.MaxRetries = defaults.MaxRetries + } + if len(cfg.PushChainGRPCURLs) == 0 { + cfg.PushChainGRPCURLs = defaults.PushChainGRPCURLs + } + if cfg.QueryServerPort == 0 { + cfg.QueryServerPort = defaults.QueryServerPort + } + if cfg.KeyringBackend == "" { + cfg.KeyringBackend = defaults.KeyringBackend + } + if len(cfg.ChainConfigs) == 0 { + cfg.ChainConfigs = defaults.ChainConfigs + } + if cfg.TSSP2PListen == "" { + cfg.TSSP2PListen = "/ip4/0.0.0.0/tcp/39000" + } + if cfg.TSSP2PPrivateKeyHex == "" { + cfg.TSSP2PPrivateKeyHex = defaults.TSSP2PPrivateKeyHex + } + if cfg.TSSPassword == "" { + cfg.TSSPassword = defaults.TSSPassword + } +} - // Don't validate for now - let the config file values pass through - // if err := validateConfig(&cfg); err != nil { - // return Config{}, fmt.Errorf("invalid config: %w", err) - // } - - return cfg, nil +// validate checks that all config values are within acceptable ranges. +// It does NOT apply defaults — call applyDefaults first if needed. +func validate(cfg *Config) error { + if cfg.LogLevel < 0 || cfg.LogLevel > 5 { + return fmt.Errorf("log level must be between 0 and 5") + } + if cfg.LogFormat != "json" && cfg.LogFormat != "console" { + return fmt.Errorf("log format must be 'json' or 'console'") + } + if cfg.KeyringBackend != "" && cfg.KeyringBackend != KeyringBackendFile && cfg.KeyringBackend != KeyringBackendTest { + return fmt.Errorf("keyring backend must be 'file' or 'test', got: %s", cfg.KeyringBackend) + } + return nil } diff --git a/universalClient/config/config_test.go b/universalClient/config/config_test.go index f111b7e1..16f77c1f 100644 --- a/universalClient/config/config_test.go +++ b/universalClient/config/config_test.go @@ -6,211 +6,151 @@ import ( "path/filepath" "testing" - "github.com/pushchain/push-chain-node/universalClient/constant" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// Test constants for TSS config (required fields) const ( testTSSPrivateKeyHex = "0101010101010101010101010101010101010101010101010101010101010101" testTSSPassword = "testpassword" ) -func TestConfigValidation(t *testing.T) { - // Test default settings - cfg := &Config{ - LogLevel: 1, - LogFormat: "console", - TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, - TSSPassword: testTSSPassword, - } - - // Load default config for validation - defaultCfg, _ := LoadDefaultConfig() - - // This should set defaults and validate - err := validateConfig(cfg, &defaultCfg) - assert.NoError(t, err) +func applyDefaultsAndValidate(t *testing.T, cfg *Config) error { + t.Helper() + defaults, err := LoadDefaultConfig() + require.NoError(t, err) + applyDefaults(cfg, &defaults) + return validate(cfg) +} - // Check that defaults were set +func TestLoadDefaultConfig(t *testing.T) { + cfg, err := LoadDefaultConfig() + require.NoError(t, err) assert.NotZero(t, cfg.ConfigRefreshIntervalSeconds) assert.NotZero(t, cfg.MaxRetries) assert.NotZero(t, cfg.QueryServerPort) - assert.Equal(t, KeyringBackendTest, cfg.KeyringBackend) // Defaults to test when empty assert.NotEmpty(t, cfg.PushChainGRPCURLs) + assert.Equal(t, "console", cfg.LogFormat) } -func TestValidConfigScenarios(t *testing.T) { - tests := []struct { - name string - config Config - validate func(t *testing.T, cfg *Config) - }{ - { - name: "Valid config with all fields", - config: Config{ - LogLevel: 2, - LogFormat: "json", - ConfigRefreshIntervalSeconds: 30, - MaxRetries: 5, - PushChainGRPCURLs: []string{"localhost:9090"}, - QueryServerPort: 8080, - TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, - TSSPassword: testTSSPassword, - }, - validate: func(t *testing.T, cfg *Config) { - assert.Equal(t, 2, cfg.LogLevel) - assert.Equal(t, "json", cfg.LogFormat) - }, - }, - { - name: "Valid config with console log format", - config: Config{ - LogLevel: 1, - LogFormat: "console", - TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, - TSSPassword: testTSSPassword, - }, - validate: func(t *testing.T, cfg *Config) { - assert.Equal(t, 1, cfg.LogLevel) - assert.Equal(t, "console", cfg.LogFormat) - }, - }, - { - name: "Config with defaults applied", - config: Config{ - LogLevel: 2, - LogFormat: "json", - TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, - TSSPassword: testTSSPassword, - }, - validate: func(t *testing.T, cfg *Config) { - // These should match the default config values - assert.Equal(t, 60, cfg.ConfigRefreshIntervalSeconds) // Default is 60 - assert.Equal(t, 3, cfg.MaxRetries) - assert.Equal(t, []string{"localhost:9090"}, cfg.PushChainGRPCURLs) - assert.Equal(t, 8080, cfg.QueryServerPort) - }, - }, - { - name: "Empty PushChainGRPCURLs gets default", - config: Config{ - LogLevel: 2, - LogFormat: "json", - PushChainGRPCURLs: []string{}, - TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, - TSSPassword: testTSSPassword, - }, - validate: func(t *testing.T, cfg *Config) { - assert.Equal(t, []string{"localhost:9090"}, cfg.PushChainGRPCURLs) - }, - }, - { - name: "Zero QueryServerPort gets default", - config: Config{ - LogLevel: 2, - LogFormat: "json", - QueryServerPort: 0, - TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, - TSSPassword: testTSSPassword, - }, - validate: func(t *testing.T, cfg *Config) { - assert.Equal(t, 8080, cfg.QueryServerPort) - }, - }, - } +func TestApplyDefaults(t *testing.T) { + t.Run("fills zero-valued fields", func(t *testing.T) { + cfg := &Config{ + LogLevel: 2, + LogFormat: "json", + TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, + TSSPassword: testTSSPassword, + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - defaultCfg, _ := LoadDefaultConfig() - cfg := tt.config - err := validateConfig(&cfg, &defaultCfg) - assert.NoError(t, err) - if tt.validate != nil { - tt.validate(t, &cfg) - } - }) - } + err := applyDefaultsAndValidate(t, cfg) + require.NoError(t, err) + + defaults, _ := LoadDefaultConfig() + assert.Equal(t, defaults.ConfigRefreshIntervalSeconds, cfg.ConfigRefreshIntervalSeconds) + assert.Equal(t, defaults.MaxRetries, cfg.MaxRetries) + assert.Equal(t, defaults.PushChainGRPCURLs, cfg.PushChainGRPCURLs) + assert.Equal(t, defaults.QueryServerPort, cfg.QueryServerPort) + assert.Equal(t, defaults.KeyringBackend, cfg.KeyringBackend) + assert.NotEmpty(t, cfg.NodeHome) + assert.NotEmpty(t, cfg.TSSP2PListen) + }) + + t.Run("preserves explicit values", func(t *testing.T) { + cfg := &Config{ + LogLevel: 2, + LogFormat: "json", + ConfigRefreshIntervalSeconds: 30, + MaxRetries: 5, + PushChainGRPCURLs: []string{"custom:9090"}, + QueryServerPort: 9999, + KeyringBackend: KeyringBackendFile, + TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, + TSSPassword: testTSSPassword, + } + + err := applyDefaultsAndValidate(t, cfg) + require.NoError(t, err) + + assert.Equal(t, 30, cfg.ConfigRefreshIntervalSeconds) + assert.Equal(t, 5, cfg.MaxRetries) + assert.Equal(t, []string{"custom:9090"}, cfg.PushChainGRPCURLs) + assert.Equal(t, 9999, cfg.QueryServerPort) + assert.Equal(t, KeyringBackendFile, cfg.KeyringBackend) + }) + + t.Run("empty slice gets default", func(t *testing.T) { + cfg := &Config{ + LogLevel: 2, + LogFormat: "json", + PushChainGRPCURLs: []string{}, + TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, + TSSPassword: testTSSPassword, + } + + err := applyDefaultsAndValidate(t, cfg) + require.NoError(t, err) + + defaults, _ := LoadDefaultConfig() + assert.Equal(t, defaults.PushChainGRPCURLs, cfg.PushChainGRPCURLs) + }) } -func TestInvalidConfigValidation(t *testing.T) { +func TestValidate(t *testing.T) { tests := []struct { name string config Config errMsg string }{ { - name: "invalid log format", - config: Config{ - LogLevel: 1, - LogFormat: "invalid", - TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, - TSSPassword: testTSSPassword, - }, - errMsg: "log format must be 'json' or 'console'", + name: "valid minimal config", + config: Config{LogLevel: 1, LogFormat: "console"}, }, { - name: "invalid keyring backend", - config: Config{ - LogLevel: 1, - LogFormat: "console", - KeyringBackend: "invalid", - TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, - TSSPassword: testTSSPassword, - }, - errMsg: "keyring backend must be 'file' or 'test'", + name: "valid json format", + config: Config{LogLevel: 0, LogFormat: "json"}, }, { - name: "Invalid log level (too high)", - config: Config{ - LogLevel: 6, - LogFormat: "json", - TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, - TSSPassword: testTSSPassword, - }, + name: "valid file backend", + config: Config{LogLevel: 1, LogFormat: "console", KeyringBackend: KeyringBackendFile}, + }, + { + name: "log level too high", + config: Config{LogLevel: 6, LogFormat: "json"}, errMsg: "log level must be between 0 and 5", }, { - name: "Invalid log level (negative)", - config: Config{ - LogLevel: -1, - LogFormat: "json", - TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, - TSSPassword: testTSSPassword, - }, + name: "log level negative", + config: Config{LogLevel: -1, LogFormat: "json"}, errMsg: "log level must be between 0 and 5", }, { - name: "Invalid log format xml", - config: Config{ - LogLevel: 2, - LogFormat: "xml", - TSSP2PPrivateKeyHex: testTSSPrivateKeyHex, - TSSPassword: testTSSPassword, - }, + name: "invalid log format", + config: Config{LogLevel: 1, LogFormat: "xml"}, errMsg: "log format must be 'json' or 'console'", }, + { + name: "invalid keyring backend", + config: Config{LogLevel: 1, LogFormat: "console", KeyringBackend: "invalid"}, + errMsg: "keyring backend must be 'file' or 'test'", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - defaultCfg, _ := LoadDefaultConfig() - cfg := tt.config - err := validateConfig(&cfg, &defaultCfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), tt.errMsg) + err := validate(&tt.config) + if tt.errMsg != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } else { + assert.NoError(t, err) + } }) } } func TestSaveAndLoad(t *testing.T) { - // Create a temporary directory for testing - tempDir, err := os.MkdirTemp("", "config_test") - require.NoError(t, err) - defer os.RemoveAll(tempDir) - - t.Run("Save and load valid config", func(t *testing.T) { + t.Run("round trip", func(t *testing.T) { + dir := t.TempDir() cfg := &Config{ LogLevel: 3, LogFormat: "json", @@ -222,64 +162,47 @@ func TestSaveAndLoad(t *testing.T) { TSSPassword: testTSSPassword, } - // Save config - err := Save(cfg, tempDir) + err := Save(cfg, dir) require.NoError(t, err) - // Verify file exists - configPath := filepath.Join(tempDir, constant.ConfigSubdir, constant.ConfigFileName) - _, err = os.Stat(configPath) - assert.NoError(t, err) + assert.FileExists(t, filepath.Join(dir, ConfigSubdir, ConfigFileName)) - // Load config - loadedCfg, err := Load(tempDir) + loaded, err := Load(dir) require.NoError(t, err) - // Verify loaded config matches saved config - assert.Equal(t, cfg.LogLevel, loadedCfg.LogLevel) - assert.Equal(t, cfg.LogFormat, loadedCfg.LogFormat) - assert.Equal(t, cfg.ConfigRefreshIntervalSeconds, loadedCfg.ConfigRefreshIntervalSeconds) - assert.Equal(t, cfg.MaxRetries, loadedCfg.MaxRetries) - assert.Equal(t, cfg.PushChainGRPCURLs, loadedCfg.PushChainGRPCURLs) - assert.Equal(t, cfg.QueryServerPort, loadedCfg.QueryServerPort) + assert.Equal(t, cfg.LogLevel, loaded.LogLevel) + assert.Equal(t, cfg.LogFormat, loaded.LogFormat) + assert.Equal(t, cfg.ConfigRefreshIntervalSeconds, loaded.ConfigRefreshIntervalSeconds) + assert.Equal(t, cfg.MaxRetries, loaded.MaxRetries) + assert.Equal(t, cfg.PushChainGRPCURLs, loaded.PushChainGRPCURLs) + assert.Equal(t, cfg.QueryServerPort, loaded.QueryServerPort) }) - t.Run("Save invalid config", func(t *testing.T) { - cfg := &Config{ - LogLevel: -1, // Invalid - LogFormat: "json", - } - - err := Save(cfg, tempDir) - assert.Error(t, err) + t.Run("save invalid config fails", func(t *testing.T) { + err := Save(&Config{LogLevel: -1, LogFormat: "json"}, t.TempDir()) + require.Error(t, err) assert.Contains(t, err.Error(), "invalid config") }) - t.Run("Load from non-existent file", func(t *testing.T) { - nonExistentDir := filepath.Join(tempDir, "non_existent") - _, err := Load(nonExistentDir) - assert.Error(t, err) + t.Run("load non-existent file fails", func(t *testing.T) { + _, err := Load(filepath.Join(t.TempDir(), "nope")) + require.Error(t, err) assert.Contains(t, err.Error(), "failed to read config file") }) - t.Run("Load invalid JSON", func(t *testing.T) { - // Create config directory - configDir := filepath.Join(tempDir, "invalid", constant.ConfigSubdir) - err := os.MkdirAll(configDir, 0o750) - require.NoError(t, err) - - // Write invalid JSON - configPath := filepath.Join(configDir, constant.ConfigFileName) - err = os.WriteFile(configPath, []byte("{invalid json}"), 0o600) - require.NoError(t, err) + t.Run("load invalid JSON fails", func(t *testing.T) { + dir := t.TempDir() + configDir := filepath.Join(dir, ConfigSubdir) + require.NoError(t, os.MkdirAll(configDir, 0o750)) + require.NoError(t, os.WriteFile(filepath.Join(configDir, ConfigFileName), []byte("{bad}"), 0o600)) - _, err = Load(filepath.Join(tempDir, "invalid")) - assert.Error(t, err) + _, err := Load(dir) + require.Error(t, err) assert.Contains(t, err.Error(), "failed to unmarshal config") }) - t.Run("Save with directory creation", func(t *testing.T) { - newDir := filepath.Join(tempDir, "new_dir") + t.Run("save creates directory", func(t *testing.T) { + dir := filepath.Join(t.TempDir(), "nested") cfg := &Config{ LogLevel: 2, LogFormat: "json", @@ -287,42 +210,62 @@ func TestSaveAndLoad(t *testing.T) { TSSPassword: testTSSPassword, } - err := Save(cfg, newDir) - require.NoError(t, err) - - // Verify directory was created - configDir := filepath.Join(newDir, constant.ConfigSubdir) - _, err = os.Stat(configDir) - assert.NoError(t, err) + require.NoError(t, Save(cfg, dir)) + assert.DirExists(t, filepath.Join(dir, ConfigSubdir)) }) } -func TestConfigJSONMarshaling(t *testing.T) { - t.Run("Marshal and unmarshal config", func(t *testing.T) { +func TestConfigJSONRoundTrip(t *testing.T) { + cfg := &Config{ + LogLevel: 2, + LogFormat: "console", + ConfigRefreshIntervalSeconds: 15, + MaxRetries: 3, + PushChainGRPCURLs: []string{"host1:9090", "host2:9090"}, + QueryServerPort: 8080, + } + + data, err := json.MarshalIndent(cfg, "", " ") + require.NoError(t, err) + + var loaded Config + require.NoError(t, json.Unmarshal(data, &loaded)) + + assert.Equal(t, cfg.LogLevel, loaded.LogLevel) + assert.Equal(t, cfg.LogFormat, loaded.LogFormat) + assert.Equal(t, cfg.PushChainGRPCURLs, loaded.PushChainGRPCURLs) +} + +func TestGetChainCleanupSettings(t *testing.T) { + cleanup := 1800 + retention := 43200 + + t.Run("returns settings", func(t *testing.T) { cfg := &Config{ - LogLevel: 2, - LogFormat: "console", - ConfigRefreshIntervalSeconds: 15, - MaxRetries: 3, - PushChainGRPCURLs: []string{"host1:9090", "host2:9090"}, - QueryServerPort: 8080, + ChainConfigs: map[string]ChainSpecificConfig{ + "eip155:1": {CleanupIntervalSeconds: &cleanup, RetentionPeriodSeconds: &retention}, + }, } - - // Marshal to JSON - data, err := json.MarshalIndent(cfg, "", " ") + c, r, err := cfg.GetChainCleanupSettings("eip155:1") require.NoError(t, err) + assert.Equal(t, 1800, c) + assert.Equal(t, 43200, r) + }) - // Unmarshal back - var unmarshaledCfg Config - err = json.Unmarshal(data, &unmarshaledCfg) - require.NoError(t, err) + t.Run("missing chain", func(t *testing.T) { + cfg := &Config{ChainConfigs: map[string]ChainSpecificConfig{}} + _, _, err := cfg.GetChainCleanupSettings("eip155:1") + require.Error(t, err) + }) - // Compare - assert.Equal(t, cfg.LogLevel, unmarshaledCfg.LogLevel) - assert.Equal(t, cfg.LogFormat, unmarshaledCfg.LogFormat) - assert.Equal(t, cfg.ConfigRefreshIntervalSeconds, unmarshaledCfg.ConfigRefreshIntervalSeconds) - assert.Equal(t, cfg.MaxRetries, unmarshaledCfg.MaxRetries) - assert.Equal(t, cfg.PushChainGRPCURLs, unmarshaledCfg.PushChainGRPCURLs) - assert.Equal(t, cfg.QueryServerPort, unmarshaledCfg.QueryServerPort) + t.Run("missing cleanup interval", func(t *testing.T) { + cfg := &Config{ + ChainConfigs: map[string]ChainSpecificConfig{ + "eip155:1": {RetentionPeriodSeconds: &retention}, + }, + } + _, _, err := cfg.GetChainCleanupSettings("eip155:1") + require.Error(t, err) + assert.Contains(t, err.Error(), "cleanup_interval_seconds") }) } diff --git a/universalClient/config/types.go b/universalClient/config/types.go index 92c094de..0ba9bd83 100644 --- a/universalClient/config/types.go +++ b/universalClient/config/types.go @@ -2,103 +2,77 @@ package config import "fmt" -// KeyringBackend represents the type of keyring backend to use +// KeyringBackend represents the type of keyring backend to use. type KeyringBackend string const ( - // KeyringBackendTest is the test Cosmos keyring backend (unencrypted) KeyringBackendTest KeyringBackend = "test" - - // KeyringBackendFile is the file Cosmos keyring backend (encrypted) KeyringBackendFile KeyringBackend = "file" ) +// Config holds all configuration for the Universal Validator. type Config struct { - // Log Config - LogLevel int `json:"log_level"` // e.g., 0 = debug, 1 = info, etc. - LogFormat string `json:"log_format"` // "json" or "console" - LogSampler bool `json:"log_sampler"` // if true, samples logs (e.g., 1 in 5) - - // Node Config - NodeHome string `json:"node_home"` // Node home directory (default: ~/.puniversal) - - // Push Chain configuration - PushChainID string `json:"push_chain_id"` // Push Chain chain ID (default: localchain_9000-1) - PushChainGRPCURLs []string `json:"push_chain_grpc_urls"` // Push Chain gRPC endpoints (default: ["localhost:9090"]) - PushValoperAddress string `json:"push_valoper_address"` // Push Chain validator operator address (pushvaloper1...) - ConfigRefreshIntervalSeconds int `json:"config_refresh_interval_seconds"` // How often to refresh configs in seconds (default: 60) - MaxRetries int `json:"max_retries"` // Max retry attempts for registry queries (default: 3) - - // Query Server Config - QueryServerPort int `json:"query_server_port"` // Port for HTTP query server (default: 8080) - - // Keyring configuration - KeyringBackend KeyringBackend `json:"keyring_backend"` // Keyring backend type (file/test) - KeyringPassword string `json:"keyring_password"` // Password for file backend keyring encryption - - // Unified per-chain configuration - ChainConfigs map[string]ChainSpecificConfig `json:"chain_configs"` // Map of chain ID to all chain-specific settings - - // TSS Node configuration - TSSP2PPrivateKeyHex string `json:"tss_p2p_private_key_hex"` // Ed25519 private key in hex for libp2p identity - TSSP2PListen string `json:"tss_p2p_listen"` // libp2p listen address (default: /ip4/0.0.0.0/tcp/39000) - TSSPassword string `json:"tss_password"` // Encryption password for keyshares - TSSHomeDir string `json:"tss_home_dir"` // Keyshare storage directory (default: ~/.puniversal/tss) + // Logging + LogLevel int `json:"log_level"` + LogFormat string `json:"log_format"` + LogSampler bool `json:"log_sampler"` + + // Node + NodeHome string `json:"node_home"` + + // Push Chain + PushChainID string `json:"push_chain_id"` + PushChainGRPCURLs []string `json:"push_chain_grpc_urls"` + PushValoperAddress string `json:"push_valoper_address"` + ConfigRefreshIntervalSeconds int `json:"config_refresh_interval_seconds"` + MaxRetries int `json:"max_retries"` + + // Query Server + QueryServerPort int `json:"query_server_port"` + + // Keyring + KeyringBackend KeyringBackend `json:"keyring_backend"` + KeyringPassword string `json:"keyring_password"` + + // Per-chain settings (keyed by CAIP-2 chain ID) + ChainConfigs map[string]ChainSpecificConfig `json:"chain_configs"` + + // TSS + TSSP2PPrivateKeyHex string `json:"tss_p2p_private_key_hex"` + TSSP2PListen string `json:"tss_p2p_listen"` + TSSPassword string `json:"tss_password"` + TSSHomeDir string `json:"tss_home_dir"` } -// ChainSpecificConfig holds all chain-specific configuration in one place +// ChainSpecificConfig holds per-chain configuration. type ChainSpecificConfig struct { - // RPC Configuration - RPCURLs []string `json:"rpc_urls,omitempty"` // RPC endpoints for this chain - - // Transaction Cleanup Configuration - CleanupIntervalSeconds *int `json:"cleanup_interval_seconds,omitempty"` // How often to run cleanup for this chain (required) - RetentionPeriodSeconds *int `json:"retention_period_seconds,omitempty"` // How long to keep confirmed transactions for this chain (required) - - // Event Monitoring Configuration - EventPollingIntervalSeconds *int `json:"event_polling_interval_seconds,omitempty"` // How often to poll for new events for this chain (required) - - // Event Start Cursor - // If set to a non-negative value, gateway event watchers start from this - // block/slot for this chain. If set to -1 or not present, start from the - // latest block/slot (or from DB resume point when available). - EventStartFrom *int64 `json:"event_start_from,omitempty"` - - // Gas Oracle Configuration - GasPriceIntervalSeconds *int `json:"gas_price_interval_seconds,omitempty"` // How often to fetch and vote on gas price (default: 30 seconds) - - // Future chain-specific settings can be added here + RPCURLs []string `json:"rpc_urls,omitempty"` + CleanupIntervalSeconds *int `json:"cleanup_interval_seconds,omitempty"` + RetentionPeriodSeconds *int `json:"retention_period_seconds,omitempty"` + EventPollingIntervalSeconds *int `json:"event_polling_interval_seconds,omitempty"` + EventStartFrom *int64 `json:"event_start_from,omitempty"` + GasPriceIntervalSeconds *int `json:"gas_price_interval_seconds,omitempty"` } -// GetChainCleanupSettings returns cleanup settings for a specific chain -// Returns chain-specific settings (required per chain) +// GetChainCleanupSettings returns cleanup settings for a specific chain. func (c *Config) GetChainCleanupSettings(chainID string) (cleanupInterval, retentionPeriod int, err error) { - if c.ChainConfigs == nil { - return 0, 0, fmt.Errorf("no chain configs found") - } - - config, ok := c.ChainConfigs[chainID] + cc, ok := c.ChainConfigs[chainID] if !ok { return 0, 0, fmt.Errorf("no config found for chain %s", chainID) } - - if config.CleanupIntervalSeconds == nil { + if cc.CleanupIntervalSeconds == nil { return 0, 0, fmt.Errorf("cleanup_interval_seconds is required for chain %s", chainID) } - if config.RetentionPeriodSeconds == nil { + if cc.RetentionPeriodSeconds == nil { return 0, 0, fmt.Errorf("retention_period_seconds is required for chain %s", chainID) } - - return *config.CleanupIntervalSeconds, *config.RetentionPeriodSeconds, nil + return *cc.CleanupIntervalSeconds, *cc.RetentionPeriodSeconds, nil } -// GetChainConfig returns the complete configuration for a specific chain +// GetChainConfig returns the configuration for a specific chain, or an empty config if not found. func (c *Config) GetChainConfig(chainID string) *ChainSpecificConfig { - if c.ChainConfigs != nil { - if config, ok := c.ChainConfigs[chainID]; ok { - return &config - } + if cc, ok := c.ChainConfigs[chainID]; ok { + return &cc } - // Return empty config if not found return &ChainSpecificConfig{} } From 2706bf972adc02584f0acd97e8ee256db2f22202 Mon Sep 17 00:00:00 2001 From: aman035 Date: Wed, 18 Mar 2026 17:40:41 +0530 Subject: [PATCH 10/28] fix: use config --- cmd/puniversald/commands.go | 13 ++++++------- cmd/puniversald/root.go | 4 ++-- universalClient/chains/chains.go | 3 +-- universalClient/chains/svm/tx_builder.go | 4 ++-- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/cmd/puniversald/commands.go b/cmd/puniversald/commands.go index 4403e921..3bff8197 100644 --- a/cmd/puniversald/commands.go +++ b/cmd/puniversald/commands.go @@ -12,8 +12,7 @@ import ( "github.com/libp2p/go-libp2p/core/peer" sdkversion "github.com/cosmos/cosmos-sdk/version" - "github.com/pushchain/push-chain-node/universalClient/config" - "github.com/pushchain/push-chain-node/universalClient/constant" + uvconfig "github.com/pushchain/push-chain-node/universalClient/config" "github.com/pushchain/push-chain-node/universalClient/core" "github.com/spf13/cobra" @@ -24,7 +23,7 @@ func InitRootCmd(rootCmd *cobra.Command) { rootCmd.AddCommand(versionCmd()) rootCmd.AddCommand(startCmd()) rootCmd.AddCommand(initCmd()) - rootCmd.AddCommand(cosmosevmcmd.KeyCommands(constant.DefaultNodeHome, true)) + rootCmd.AddCommand(cosmosevmcmd.KeyCommands(uvconfig.DefaultNodeHome(), true)) rootCmd.AddCommand(tssPeerIDCmd()) } @@ -54,17 +53,17 @@ This command creates a default configuration file at: You can edit this file to customize your universal validator settings.`, RunE: func(cmd *cobra.Command, args []string) error { // Load default config - defaultCfg, err := config.LoadDefaultConfig() + defaultCfg, err := uvconfig.LoadDefaultConfig() if err != nil { return fmt.Errorf("failed to load default config: %w", err) } // Save to config directory - if err := config.Save(&defaultCfg, constant.DefaultNodeHome); err != nil { + if err := uvconfig.Save(&defaultCfg, uvconfig.DefaultNodeHome()); err != nil { return fmt.Errorf("failed to save config: %w", err) } - configPath := fmt.Sprintf("%s/%s/%s", constant.DefaultNodeHome, constant.ConfigSubdir, constant.ConfigFileName) + configPath := fmt.Sprintf("%s/%s/%s", uvconfig.DefaultNodeHome(), uvconfig.ConfigSubdir, uvconfig.ConfigFileName) fmt.Printf("✅ Configuration file initialized at: %s\n", configPath) fmt.Println("You can now edit this file to customize your settings.") return nil @@ -79,7 +78,7 @@ func startCmd() *cobra.Command { Short: "Start the universal message handler", RunE: func(cmd *cobra.Command, args []string) error { // --- Step 1: Load config --- - loadedCfg, err := config.Load(constant.DefaultNodeHome) + loadedCfg, err := uvconfig.Load(uvconfig.DefaultNodeHome()) if err != nil { return fmt.Errorf("failed to load config: %w", err) } diff --git a/cmd/puniversald/root.go b/cmd/puniversald/root.go index 6f82fba7..90857345 100644 --- a/cmd/puniversald/root.go +++ b/cmd/puniversald/root.go @@ -12,7 +12,7 @@ import ( cosmosevmkeyring "github.com/cosmos/evm/crypto/keyring" "github.com/pushchain/push-chain-node/app" "github.com/pushchain/push-chain-node/app/params" - "github.com/pushchain/push-chain-node/universalClient/constant" + uvconfig "github.com/pushchain/push-chain-node/universalClient/config" "github.com/spf13/cobra" wasmkeeper "github.com/CosmWasm/wasmd/x/wasm/keeper" "cosmossdk.io/log" @@ -39,7 +39,7 @@ func NewRootCmd() *cobra.Command { WithLegacyAmino(encodingConfig.Amino). WithInput(os.Stdin). WithAccountRetriever(authtypes.AccountRetriever{}). - WithHomeDir(constant.DefaultNodeHome). + WithHomeDir(uvconfig.DefaultNodeHome()). WithBroadcastMode(flags.FlagBroadcastMode). WithKeyringOptions(cosmosevmkeyring.Option()). WithLedgerHasProtobuf(true). diff --git a/universalClient/chains/chains.go b/universalClient/chains/chains.go index 8947de8a..61197681 100644 --- a/universalClient/chains/chains.go +++ b/universalClient/chains/chains.go @@ -13,7 +13,6 @@ import ( "github.com/pushchain/push-chain-node/universalClient/chains/push" "github.com/pushchain/push-chain-node/universalClient/chains/svm" "github.com/pushchain/push-chain-node/universalClient/config" - "github.com/pushchain/push-chain-node/universalClient/constant" "github.com/pushchain/push-chain-node/universalClient/db" "github.com/pushchain/push-chain-node/universalClient/pushcore" "github.com/pushchain/push-chain-node/universalClient/pushsigner" @@ -415,7 +414,7 @@ func (c *Chains) getChainDB(chainID string) (*db.DB, error) { dbFilename := sanitizedChainID + ".db" // Derive database base directory from NodeHome - baseDir := filepath.Join(c.config.NodeHome, constant.DatabasesSubdir) + baseDir := filepath.Join(c.config.NodeHome, config.DatabasesSubdir) database, err := db.OpenFileDB(baseDir, dbFilename, true) if err != nil { diff --git a/universalClient/chains/svm/tx_builder.go b/universalClient/chains/svm/tx_builder.go index 25a2ec13..4b97305a 100644 --- a/universalClient/chains/svm/tx_builder.go +++ b/universalClient/chains/svm/tx_builder.go @@ -75,7 +75,7 @@ import ( "github.com/rs/zerolog" "github.com/pushchain/push-chain-node/universalClient/chains/common" - "github.com/pushchain/push-chain-node/universalClient/constant" + "github.com/pushchain/push-chain-node/universalClient/config" uetypes "github.com/pushchain/push-chain-node/x/uexecutor/types" ) @@ -1039,7 +1039,7 @@ func (tb *TxBuilder) loadRelayerKeypair() (solana.PrivateKey, error) { return nil, fmt.Errorf("empty namespace in chain ID: %s", tb.chainID) } - keyPath := filepath.Join(tb.nodeHome, constant.RelayerSubdir, namespace+".json") + keyPath := filepath.Join(tb.nodeHome, config.RelayerSubdir, namespace+".json") keyData, err := os.ReadFile(keyPath) if err != nil { From 9c229a58d101891bdaa6253ada14a83279f84573 Mon Sep 17 00:00:00 2001 From: aman035 Date: Wed, 18 Mar 2026 17:40:53 +0530 Subject: [PATCH 11/28] fix: use config --- universalClient/core/client.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/universalClient/core/client.go b/universalClient/core/client.go index c0fbfa4b..c314938c 100644 --- a/universalClient/core/client.go +++ b/universalClient/core/client.go @@ -9,7 +9,6 @@ import ( "github.com/pushchain/push-chain-node/universalClient/api" "github.com/pushchain/push-chain-node/universalClient/chains" "github.com/pushchain/push-chain-node/universalClient/config" - "github.com/pushchain/push-chain-node/universalClient/constant" "github.com/pushchain/push-chain-node/universalClient/db" "github.com/pushchain/push-chain-node/universalClient/logger" "github.com/pushchain/push-chain-node/universalClient/pushcore" @@ -88,7 +87,7 @@ func NewUniversalClient(ctx context.Context, cfg *config.Config) (*UniversalClie sanitizedChainID := cfg.PushChainID // Replace colons and other special chars with underscores for filename dbFilename := sanitizedChainID + ".db" - baseDir := filepath.Join(cfg.NodeHome, constant.DatabasesSubdir) + baseDir := filepath.Join(cfg.NodeHome, config.DatabasesSubdir) pushDB, err := db.OpenFileDB(baseDir, dbFilename, true) if err != nil { return nil, fmt.Errorf("failed to create Push database: %w", err) From 4811111f0dffd835e767bc2a217528a36b76dd13 Mon Sep 17 00:00:00 2001 From: aman035 Date: Wed, 18 Mar 2026 17:41:02 +0530 Subject: [PATCH 12/28] remove constant --- universalClient/constant/constant.go | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 universalClient/constant/constant.go diff --git a/universalClient/constant/constant.go b/universalClient/constant/constant.go deleted file mode 100644 index 3a84c2cb..00000000 --- a/universalClient/constant/constant.go +++ /dev/null @@ -1,23 +0,0 @@ -package constant - -import "os" - -// / (e.g., /home/universal/.puniversal) -// └── config/ -// └── pushuv_config.json -// └── databases/ -// └── eip155_1.db -// └── eip155_97.db - -const ( - NodeDir = ".puniversal" - - ConfigSubdir = "config" - ConfigFileName = "pushuv_config.json" - - DatabasesSubdir = "databases" - - RelayerSubdir = "relayer" -) - -var DefaultNodeHome = os.ExpandEnv("$HOME/") + NodeDir From 6e8675d16cfd8d9fc2c58239d3d863b887e2d3a0 Mon Sep 17 00:00:00 2001 From: aman035 Date: Wed, 18 Mar 2026 17:56:16 +0530 Subject: [PATCH 13/28] refactor: client --- universalClient/core/client.go | 193 ++++++++++++++-------------- universalClient/core/client_test.go | 63 ++++----- 2 files changed, 126 insertions(+), 130 deletions(-) diff --git a/universalClient/core/client.go b/universalClient/core/client.go index c314938c..44313231 100644 --- a/universalClient/core/client.go +++ b/universalClient/core/client.go @@ -1,9 +1,13 @@ +// Package core provides the top-level orchestrator for the Push Universal Validator. +// It wires together all subsystems (pushcore, pushsigner, chains, tss, api) +// and manages their lifecycle. package core import ( "context" "fmt" "path/filepath" + "strings" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/pushchain/push-chain-node/universalClient/api" @@ -17,6 +21,7 @@ import ( "github.com/rs/zerolog" ) +// UniversalClient is the top-level orchestrator that owns all subsystems. type UniversalClient struct { ctx context.Context log zerolog.Logger @@ -28,34 +33,25 @@ type UniversalClient struct { tssNode *tss.Node } +// NewUniversalClient creates and initializes all subsystems. +// It validates config, connects to Push Chain, sets up signing, chain watchers, and TSS. func NewUniversalClient(ctx context.Context, cfg *config.Config) (*UniversalClient, error) { if cfg == nil { - return nil, fmt.Errorf("Config is nil") + return nil, fmt.Errorf("config is nil") } - // Initialize logger log := logger.New(cfg.LogLevel, cfg.LogFormat, cfg.LogSampler) - // Initialize pushCore client pushCore, err := pushcore.New(cfg.PushChainGRPCURLs, log) if err != nil { return nil, fmt.Errorf("failed to create pushcore client: %w", err) } - // Convert valoper address to account address for grant validation - var granterAddr string - if cfg.PushValoperAddress != "" { - valAddr, err := sdk.ValAddressFromBech32(cfg.PushValoperAddress) - if err != nil { - return nil, fmt.Errorf("failed to parse valoper address %s: %w", cfg.PushValoperAddress, err) - } - // Convert validator address to account address (they share the same bytes) - accAddr := sdk.AccAddress(valAddr) - granterAddr = accAddr.String() + granterAddr, err := valoperToAccountAddr(cfg.PushValoperAddress) + if err != nil { + return nil, err } - // Initialize pushSigner (includes key and AuthZ validation) - // Grant validation will check grants against the granter address derived from valoper pushSigner, err := pushsigner.New( log, cfg.KeyringBackend, @@ -69,58 +65,15 @@ func NewUniversalClient(ctx context.Context, cfg *config.Config) (*UniversalClie return nil, fmt.Errorf("failed to create pushsigner: %w", err) } - // Initialize chains manager (fetches chain configs periodically and manages chain clients) - chainsManager := chains.NewChains( - pushCore, - pushSigner, - cfg, - log, - ) - - // Initialize TSS node - var tssNode *tss.Node - if cfg.PushValoperAddress != "" && cfg.TSSP2PPrivateKeyHex != "" { - log.Info().Msg("🔑 Initializing TSS node...") - - // Create push chain database - // Use the same approach as chains manager - sanitizedChainID := cfg.PushChainID - // Replace colons and other special chars with underscores for filename - dbFilename := sanitizedChainID + ".db" - baseDir := filepath.Join(cfg.NodeHome, config.DatabasesSubdir) - pushDB, err := db.OpenFileDB(baseDir, dbFilename, true) - if err != nil { - return nil, fmt.Errorf("failed to create Push database: %w", err) - } + chainsManager := chains.NewChains(pushCore, pushSigner, cfg, log) - tssCfg := tss.Config{ - ValidatorAddress: cfg.PushValoperAddress, - P2PPrivateKeyHex: cfg.TSSP2PPrivateKeyHex, - LibP2PListen: cfg.TSSP2PListen, - HomeDir: cfg.NodeHome, - Password: cfg.TSSPassword, - Database: pushDB, - PushCore: pushCore, - Logger: log, - Chains: chainsManager, - PushSigner: pushSigner, - } - - tssNode, err = tss.NewNode(ctx, tssCfg) - if err != nil { - return nil, fmt.Errorf("failed to create TSS node: %w", err) - } - - log.Info(). - Str("valoper", cfg.PushValoperAddress). - Str("p2p_listen", cfg.TSSP2PListen). - Msg("✅ TSS node initialized") + tssNode, err := initTSS(ctx, cfg, pushCore, chainsManager, pushSigner, log) + if err != nil { + return nil, err } - // Initialize query server queryServer := api.NewServer(log, cfg.QueryServerPort) - // Create and return UniversalClient with all components initialized return &UniversalClient{ ctx: ctx, log: log, @@ -133,72 +86,122 @@ func NewUniversalClient(ctx context.Context, cfg *config.Config) (*UniversalClie }, nil } +// Start launches all subsystems, blocks until ctx is cancelled, then shuts down. func (uc *UniversalClient) Start() error { - uc.log.Info().Msg("🚀 Starting universal client...") + uc.log.Info().Msg("Starting universal client...") - // Start chains manager (fetches chain configs periodically and manages chain clients) - if uc.chains != nil { - if err := uc.chains.Start(uc.ctx); err != nil { - uc.log.Error().Err(err).Msg("failed to start chains manager") - } else { - uc.log.Info().Msg("✅ Chains manager started") - } + if err := uc.chains.Start(uc.ctx); err != nil { + return fmt.Errorf("failed to start chains manager: %w", err) } - // Start the TSS node if enabled if uc.tssNode != nil { if err := uc.tssNode.Start(uc.ctx); err != nil { - uc.log.Error().Err(err).Msg("failed to start TSS node") - // Don't fail startup, TSS can recover - } else { - uc.log.Info(). - Str("peer_id", uc.tssNode.PeerID()). - Strs("listen_addrs", uc.tssNode.ListenAddrs()). - Msg("✅ TSS node started") + return fmt.Errorf("failed to start TSS node: %w", err) } + uc.log.Info(). + Str("peer_id", uc.tssNode.PeerID()). + Strs("listen_addrs", uc.tssNode.ListenAddrs()). + Msg("TSS node started") } - // Start the query server - if uc.queryServer != nil { - uc.log.Info().Int("port", uc.config.QueryServerPort).Msg("Starting query server...") - if err := uc.queryServer.Start(); err != nil { - return fmt.Errorf("failed to start query server: %w", err) - } - } else { - uc.log.Warn().Msg("Query server is nil, skipping start") + if err := uc.queryServer.Start(); err != nil { + return fmt.Errorf("failed to start query server: %w", err) } - uc.log.Info().Msg("✅ Initialization complete. Entering main loop...") + uc.log.Info().Msg("Initialization complete. Entering main loop...") <-uc.ctx.Done() - uc.log.Info().Msg("🛑 Shutting down universal client...") + uc.shutdown() + return nil +} + +// shutdown stops all subsystems in reverse startup order. +func (uc *UniversalClient) shutdown() { + uc.log.Info().Msg("Shutting down universal client...") - // Stop query server if err := uc.queryServer.Stop(); err != nil { uc.log.Error().Err(err).Msg("error stopping query server") } - // Stop TSS node if uc.tssNode != nil { if err := uc.tssNode.Stop(); err != nil { uc.log.Error().Err(err).Msg("error stopping TSS node") - } else { - uc.log.Info().Msg("✅ TSS node stopped") } } - // Stop chains manager (stops all chains and closes databases) if uc.chains != nil { uc.chains.Stop() } - // Close pushcore client if uc.pushCore != nil { if err := uc.pushCore.Close(); err != nil { uc.log.Error().Err(err).Msg("error closing pushcore client") } } +} - return nil +// valoperToAccountAddr converts a validator operator address to its account address. +func valoperToAccountAddr(valoper string) (string, error) { + if valoper == "" { + return "", fmt.Errorf("push_valoper_address is required") + } + valAddr, err := sdk.ValAddressFromBech32(valoper) + if err != nil { + return "", fmt.Errorf("failed to parse valoper address %s: %w", valoper, err) + } + return sdk.AccAddress(valAddr).String(), nil +} + +// initTSS creates and returns a TSS node if the config has the required fields. +// Returns nil node (not an error) if TSS config is absent. +func initTSS( + ctx context.Context, + cfg *config.Config, + pushCore *pushcore.Client, + chainsManager *chains.Chains, + pushSigner *pushsigner.Signer, + log zerolog.Logger, +) (*tss.Node, error) { + if cfg.PushValoperAddress == "" || cfg.TSSP2PPrivateKeyHex == "" { + return nil, nil + } + + log.Info().Msg("Initializing TSS node...") + + // Sanitize chain ID for use as a database filename (e.g. "push_42101-1" → "push_42101-1.db") + dbFilename := sanitizeForFilename(cfg.PushChainID) + ".db" + baseDir := filepath.Join(cfg.NodeHome, config.DatabasesSubdir) + pushDB, err := db.OpenFileDB(baseDir, dbFilename, true) + if err != nil { + return nil, fmt.Errorf("failed to create push database: %w", err) + } + + node, err := tss.NewNode(ctx, tss.Config{ + ValidatorAddress: cfg.PushValoperAddress, + P2PPrivateKeyHex: cfg.TSSP2PPrivateKeyHex, + LibP2PListen: cfg.TSSP2PListen, + HomeDir: cfg.NodeHome, + Password: cfg.TSSPassword, + Database: pushDB, + PushCore: pushCore, + Logger: log, + Chains: chainsManager, + PushSigner: pushSigner, + }) + if err != nil { + return nil, fmt.Errorf("failed to create TSS node: %w", err) + } + + log.Info(). + Str("valoper", cfg.PushValoperAddress). + Str("p2p_listen", cfg.TSSP2PListen). + Msg("TSS node initialized") + + return node, nil +} + +// sanitizeForFilename replaces characters that are problematic in filenames. +func sanitizeForFilename(s string) string { + return strings.ReplaceAll(s, ":", "_") } diff --git a/universalClient/core/client_test.go b/universalClient/core/client_test.go index bf6a2698..c0710d57 100644 --- a/universalClient/core/client_test.go +++ b/universalClient/core/client_test.go @@ -10,49 +10,40 @@ import ( ) func TestNewUniversalClient(t *testing.T) { - t.Run("fails with nil config", func(t *testing.T) { - ctx := context.Background() - - client, err := NewUniversalClient(ctx, nil) + t.Run("nil config", func(t *testing.T) { + client, err := NewUniversalClient(context.Background(), nil) require.Error(t, err) assert.Nil(t, client) - assert.Contains(t, err.Error(), "Config is nil") + assert.Contains(t, err.Error(), "config is nil") }) - t.Run("fails with empty PushChainGRPCURLs", func(t *testing.T) { - ctx := context.Background() - + t.Run("empty valoper address", func(t *testing.T) { cfg := &config.Config{ - PushChainGRPCURLs: []string{}, + PushChainGRPCURLs: []string{"localhost:9090"}, LogLevel: 1, LogFormat: "console", } - client, err := NewUniversalClient(ctx, cfg) + client, err := NewUniversalClient(context.Background(), cfg) require.Error(t, err) assert.Nil(t, client) - assert.Contains(t, err.Error(), "failed to create pushcore client") - assert.Contains(t, err.Error(), "at least one gRPC URL is required") + assert.Contains(t, err.Error(), "push_valoper_address is required") }) - t.Run("fails with nil PushChainGRPCURLs", func(t *testing.T) { - ctx := context.Background() - + t.Run("empty gRPC URLs", func(t *testing.T) { cfg := &config.Config{ - PushChainGRPCURLs: nil, + PushChainGRPCURLs: []string{}, LogLevel: 1, LogFormat: "console", } - client, err := NewUniversalClient(ctx, cfg) + client, err := NewUniversalClient(context.Background(), cfg) require.Error(t, err) assert.Nil(t, client) - assert.Contains(t, err.Error(), "failed to create pushcore client") + assert.Contains(t, err.Error(), "at least one gRPC URL is required") }) - t.Run("fails with invalid valoper address", func(t *testing.T) { - ctx := context.Background() - + t.Run("invalid valoper address", func(t *testing.T) { cfg := &config.Config{ PushChainGRPCURLs: []string{"localhost:9090"}, PushValoperAddress: "invalid-valoper-address", @@ -60,26 +51,28 @@ func TestNewUniversalClient(t *testing.T) { LogFormat: "console", } - client, err := NewUniversalClient(ctx, cfg) + client, err := NewUniversalClient(context.Background(), cfg) require.Error(t, err) assert.Nil(t, client) assert.Contains(t, err.Error(), "failed to parse valoper address") }) } -func TestUniversalClientStruct(t *testing.T) { - t.Run("struct has expected fields", func(t *testing.T) { - // Verify the UniversalClient struct has all expected fields - uc := &UniversalClient{} - assert.Nil(t, uc.ctx) - assert.Nil(t, uc.config) - assert.Nil(t, uc.queryServer) - assert.Nil(t, uc.pushCore) - assert.Nil(t, uc.pushSigner) - assert.Nil(t, uc.chains) - assert.Nil(t, uc.tssNode) +func TestValoperToAccountAddr(t *testing.T) { + t.Run("empty valoper returns error", func(t *testing.T) { + _, err := valoperToAccountAddr("") + require.Error(t, err) + assert.Contains(t, err.Error(), "push_valoper_address is required") + }) + + t.Run("invalid valoper returns error", func(t *testing.T) { + _, err := valoperToAccountAddr("garbage") + require.Error(t, err) }) } -// Note: Factory tests removed as OutboundTxBuilderFactory has been replaced -// with direct chain client access via Chains.GetClient() and ChainClient.GetTxBuilder() +func TestSanitizeForFilename(t *testing.T) { + assert.Equal(t, "eip155_1", sanitizeForFilename("eip155:1")) + assert.Equal(t, "push_42101-1", sanitizeForFilename("push_42101-1")) + assert.Equal(t, "solana_EtWTRABZaYq6", sanitizeForFilename("solana:EtWTRABZaYq6")) +} From 39391cbe07ab34c94f145e4930b249713fee5c72 Mon Sep 17 00:00:00 2001 From: aman035 Date: Thu, 19 Mar 2026 15:55:51 +0530 Subject: [PATCH 14/28] fix: add types to store --- universalClient/store/models.go | 37 +++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/universalClient/store/models.go b/universalClient/store/models.go index 54353f39..61bdc59a 100644 --- a/universalClient/store/models.go +++ b/universalClient/store/models.go @@ -1,15 +1,40 @@ -// Package store contains GORM-backed SQLite models used by the Universal Validator. +// Package store contains data models and enum constants for the Universal Validator. +// All event status, type, and confirmation type constants are defined here +// as the single source of truth — import from here, not from individual packages. package store import ( "gorm.io/gorm" ) -// Database Structure: -// -// {CHAIN_CAIP2_FORMAT}.db (e.g., "eip155:1.db") -// ├── states -// └── events +// Event status values. +const ( + StatusPending = "PENDING" // Observed on external chain, awaiting confirmations + StatusConfirmed = "CONFIRMED" // Confirmed (ready for processing or voting) + StatusInProgress = "IN_PROGRESS" // TSS signing in progress + StatusSigned = "SIGNED" // TSS signing done, tx not yet broadcast + StatusBroadcasted = "BROADCASTED" // Transaction sent to external chain + StatusCompleted = "COMPLETED" // Successfully completed + StatusReverted = "REVERTED" // Failed (expiry, receipt failed, or vote failed) + StatusReorged = "REORGED" // Removed due to chain reorganization +) + +// Event type values. +const ( + EventTypeKeygen = "KEYGEN" + EventTypeKeyrefresh = "KEYREFRESH" + EventTypeQuorumChange = "QUORUM_CHANGE" + EventTypeSign = "SIGN" + EventTypeInbound = "INBOUND" + EventTypeOutbound = "OUTBOUND" +) + +// Confirmation type values. +const ( + ConfirmationStandard = "STANDARD" // Standard finality (multiple block confirmations) + ConfirmationFast = "FAST" // Fast finality (fewer confirmations) + ConfirmationInstant = "INSTANT" // Instant finality (Push Chain) +) // State tracks synchronization state for a chain. // There is exactly one State record per chain database, storing the last processed block height. From c599fecf18f866c7f20b2ee49d62749b647a3275 Mon Sep 17 00:00:00 2001 From: aman035 Date: Thu, 19 Mar 2026 15:59:16 +0530 Subject: [PATCH 15/28] fix: use types from store --- .../chains/common/chain_store_test.go | 4 +- .../chains/common/event_processor.go | 7 +- .../chains/common/event_processor_test.go | 24 +-- universalClient/chains/common/types.go | 10 -- universalClient/chains/evm/event_confirmer.go | 3 +- .../chains/evm/event_confirmer_test.go | 9 +- universalClient/chains/evm/event_parser.go | 4 +- .../chains/evm/event_parser_test.go | 11 +- universalClient/chains/push/event_parser.go | 9 +- .../chains/push/event_parser_test.go | 24 +-- .../chains/svm/event_confirmer_test.go | 9 +- universalClient/chains/svm/event_parser.go | 4 +- .../chains/svm/event_parser_test.go | 7 +- .../tss/coordinator/coordinator.go | 16 +- .../tss/coordinator/coordinator_test.go | 24 +-- universalClient/tss/coordinator/types.go | 10 -- universalClient/tss/eventstore/store.go | 34 ++--- universalClient/tss/eventstore/store_test.go | 139 +++++++++--------- universalClient/tss/expirysweeper/sweeper.go | 8 +- .../tss/expirysweeper/sweeper_test.go | 4 +- .../tss/sessionmanager/sessionmanager.go | 29 ++-- .../tss/sessionmanager/sessionmanager_test.go | 8 +- .../tss/txbroadcaster/broadcaster.go | 12 +- .../tss/txbroadcaster/broadcaster_test.go | 36 ++--- universalClient/tss/txresolver/evm.go | 3 +- universalClient/tss/txresolver/resolver.go | 2 +- .../tss/txresolver/resolver_test.go | 10 +- universalClient/tss/txresolver/svm.go | 3 +- 28 files changed, 213 insertions(+), 250 deletions(-) diff --git a/universalClient/chains/common/chain_store_test.go b/universalClient/chains/common/chain_store_test.go index e22c2c2a..4d0372de 100644 --- a/universalClient/chains/common/chain_store_test.go +++ b/universalClient/chains/common/chain_store_test.go @@ -5,6 +5,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + storemodels "github.com/pushchain/push-chain-node/universalClient/store" ) func TestNewChainStore(t *testing.T) { @@ -46,7 +48,7 @@ func TestChainStoreNilDatabase(t *testing.T) { }) t.Run("UpdateEventStatus returns error for nil database", func(t *testing.T) { - rowsAffected, err := store.UpdateEventStatus("event-1", "PENDING", "CONFIRMED") + rowsAffected, err := store.UpdateEventStatus("event-1", storemodels.StatusPending, storemodels.StatusConfirmed) require.Error(t, err) assert.Equal(t, int64(0), rowsAffected) assert.Contains(t, err.Error(), "database is nil") diff --git a/universalClient/chains/common/event_processor.go b/universalClient/chains/common/event_processor.go index 2b399436..b3c4618f 100644 --- a/universalClient/chains/common/event_processor.go +++ b/universalClient/chains/common/event_processor.go @@ -119,7 +119,7 @@ func (ep *EventProcessor) processConfirmedEvents(ctx context.Context) error { } for _, event := range events { - if event.Type == EventTypeInbound { + if event.Type == store.EventTypeInbound { if !ep.inboundEnabled { ep.logger.Warn().Str("event_id", event.EventID).Msg("inbound disabled, skipping inbound event processing") continue @@ -131,7 +131,7 @@ func (ep *EventProcessor) processConfirmedEvents(ctx context.Context) error { Msg("failed to vote on inbound event") continue } - } else if event.Type == EventTypeOutbound { + } else if event.Type == store.EventTypeOutbound { if !ep.outboundEnabled { ep.logger.Warn().Str("event_id", event.EventID).Msg("outbound disabled, skipping outbound event processing") continue @@ -299,7 +299,8 @@ func (ep *EventProcessor) constructInbound(event *store.Event) (*uexecutortypes. } // Set recipient for transactions that involve funds - if txType == uexecutortypes.TxType_FUNDS || txType == uexecutortypes.TxType_GAS { + if txType == uexecutortypes.TxType_FUNDS || txType == uexecutortypes.TxType_GAS || + (txType == uexecutortypes.TxType_FUNDS_AND_PAYLOAD && eventData.FromCEA) { inboundMsg.Recipient = eventData.Recipient } diff --git a/universalClient/chains/common/event_processor_test.go b/universalClient/chains/common/event_processor_test.go index 1deff96c..3dfe3391 100644 --- a/universalClient/chains/common/event_processor_test.go +++ b/universalClient/chains/common/event_processor_test.go @@ -384,14 +384,14 @@ func TestProcessConfirmedEventsEnabledFlags(t *testing.T) { return []store.Event{ { EventID: "0xaaa:0", - Status: "CONFIRMED", - Type: EventTypeInbound, + Status: store.StatusConfirmed, + Type: store.EventTypeInbound, EventData: inboundEventData, }, { EventID: "0xbbb:0", - Status: "CONFIRMED", - Type: EventTypeOutbound, + Status: store.StatusConfirmed, + Type: store.EventTypeOutbound, EventData: outboundEventData, }, } @@ -408,7 +408,7 @@ func TestProcessConfirmedEventsEnabledFlags(t *testing.T) { // Inbound event should still be CONFIRMED (skipped, not processed) var inboundEvt store.Event database.Client().Where("event_id = ?", "0xaaa:0").First(&inboundEvt) - assert.Equal(t, "CONFIRMED", inboundEvt.Status) + assert.Equal(t, store.StatusConfirmed, inboundEvt.Status) }) t.Run("outbound disabled skips outbound events, leaves them CONFIRMED", func(t *testing.T) { @@ -421,7 +421,7 @@ func TestProcessConfirmedEventsEnabledFlags(t *testing.T) { // Outbound event should still be CONFIRMED (skipped, not processed) var outboundEvt store.Event database.Client().Where("event_id = ?", "0xbbb:0").First(&outboundEvt) - assert.Equal(t, "CONFIRMED", outboundEvt.Status) + assert.Equal(t, store.StatusConfirmed, outboundEvt.Status) }) t.Run("inbound enabled but outbound disabled skips only outbound", func(t *testing.T) { @@ -429,8 +429,8 @@ func TestProcessConfirmedEventsEnabledFlags(t *testing.T) { database := setupDB(t, []store.Event{ { EventID: "0xbbb:0", - Status: "CONFIRMED", - Type: EventTypeOutbound, + Status: store.StatusConfirmed, + Type: store.EventTypeOutbound, EventData: outboundEventData, }, }) @@ -442,7 +442,7 @@ func TestProcessConfirmedEventsEnabledFlags(t *testing.T) { // Outbound event should still be CONFIRMED (skipped due to outbound disabled) var outboundEvt store.Event database.Client().Where("event_id = ?", "0xbbb:0").First(&outboundEvt) - assert.Equal(t, "CONFIRMED", outboundEvt.Status) + assert.Equal(t, store.StatusConfirmed, outboundEvt.Status) }) t.Run("outbound enabled but inbound disabled skips only inbound", func(t *testing.T) { @@ -450,8 +450,8 @@ func TestProcessConfirmedEventsEnabledFlags(t *testing.T) { database := setupDB(t, []store.Event{ { EventID: "0xaaa:0", - Status: "CONFIRMED", - Type: EventTypeInbound, + Status: store.StatusConfirmed, + Type: store.EventTypeInbound, EventData: inboundEventData, }, }) @@ -463,6 +463,6 @@ func TestProcessConfirmedEventsEnabledFlags(t *testing.T) { // Inbound event should still be CONFIRMED (skipped due to inbound disabled) var inboundEvt store.Event database.Client().Where("event_id = ?", "0xaaa:0").First(&inboundEvt) - assert.Equal(t, "CONFIRMED", inboundEvt.Status) + assert.Equal(t, store.StatusConfirmed, inboundEvt.Status) }) } diff --git a/universalClient/chains/common/types.go b/universalClient/chains/common/types.go index bf7c231c..bda327d1 100644 --- a/universalClient/chains/common/types.go +++ b/universalClient/chains/common/types.go @@ -87,13 +87,3 @@ type OutboundEvent struct { GasFeeUsed string `json:"gas_fee_used,omitempty"` // gas fee used in wei (decimal string) } -// Event type enum values for event classification. -// These constants define the types of events that can be processed. -const ( - EventTypeKeygen = "KEYGEN" - EventTypeKeyrefresh = "KEYREFRESH" - EventTypeQuorumChange = "QUORUM_CHANGE" - EventTypeSign = "SIGN" - EventTypeInbound = "INBOUND" - EventTypeOutbound = "OUTBOUND" -) diff --git a/universalClient/chains/evm/event_confirmer.go b/universalClient/chains/evm/event_confirmer.go index b38012df..581ac373 100644 --- a/universalClient/chains/evm/event_confirmer.go +++ b/universalClient/chains/evm/event_confirmer.go @@ -14,6 +14,7 @@ import ( chaincommon "github.com/pushchain/push-chain-node/universalClient/chains/common" "github.com/pushchain/push-chain-node/universalClient/db" + "github.com/pushchain/push-chain-node/universalClient/store" ) // EventConfirmer periodically checks pending events and marks them as CONFIRMED @@ -152,7 +153,7 @@ func (ec *EventConfirmer) processPendingEvents(ctx context.Context) error { var rowsAffected int64 // For outbound events, enrich with gas fee before confirming - if event.Type == chaincommon.EventTypeOutbound { + if event.Type == store.EventTypeOutbound { tx, _, txErr := ec.rpcClient.GetTransactionByHash(ctx, hash) if txErr != nil { ec.logger.Warn(). diff --git a/universalClient/chains/evm/event_confirmer_test.go b/universalClient/chains/evm/event_confirmer_test.go index baed46a3..77bc2838 100644 --- a/universalClient/chains/evm/event_confirmer_test.go +++ b/universalClient/chains/evm/event_confirmer_test.go @@ -3,6 +3,7 @@ package evm import ( "testing" + "github.com/pushchain/push-chain-node/universalClient/store" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -85,13 +86,13 @@ func TestEventConfirmerGetRequiredConfirmations(t *testing.T) { t.Run("FAST confirmation type", func(t *testing.T) { confirmer := NewEventConfirmer(nil, nil, "eip155:1", 5, 5, 12, logger) - confirmations := confirmer.getRequiredConfirmations("FAST") + confirmations := confirmer.getRequiredConfirmations(store.ConfirmationFast) assert.Equal(t, uint64(5), confirmations) }) t.Run("STANDARD confirmation type", func(t *testing.T) { confirmer := NewEventConfirmer(nil, nil, "eip155:1", 5, 5, 12, logger) - confirmations := confirmer.getRequiredConfirmations("STANDARD") + confirmations := confirmer.getRequiredConfirmations(store.ConfirmationStandard) assert.Equal(t, uint64(12), confirmations) }) @@ -109,13 +110,13 @@ func TestEventConfirmerGetRequiredConfirmations(t *testing.T) { t.Run("uses custom fast confirmations", func(t *testing.T) { confirmer := NewEventConfirmer(nil, nil, "eip155:1", 5, 3, 20, logger) - confirmations := confirmer.getRequiredConfirmations("FAST") + confirmations := confirmer.getRequiredConfirmations(store.ConfirmationFast) assert.Equal(t, uint64(3), confirmations) }) t.Run("uses custom standard confirmations", func(t *testing.T) { confirmer := NewEventConfirmer(nil, nil, "eip155:1", 5, 3, 20, logger) - confirmations := confirmer.getRequiredConfirmations("STANDARD") + confirmations := confirmer.getRequiredConfirmations(store.ConfirmationStandard) assert.Equal(t, uint64(20), confirmations) }) } diff --git a/universalClient/chains/evm/event_parser.go b/universalClient/chains/evm/event_parser.go index b5ae28f5..cf25f9fa 100644 --- a/universalClient/chains/evm/event_parser.go +++ b/universalClient/chains/evm/event_parser.go @@ -74,7 +74,7 @@ func parseSendFundsEvent(log *types.Log, chainID string, logger zerolog.Logger) event := &store.Event{ EventID: eventID, BlockHeight: log.BlockNumber, - Type: common.EventTypeInbound, // Gateway events from external chains are INBOUND + Type: store.EventTypeInbound, // Gateway events from external chains are INBOUND Status: "PENDING", ExpiryBlockHeight: 0, // 0 means no expiry } @@ -134,7 +134,7 @@ func parseOutboundObservationEvent(log *types.Log, chainID string, logger zerolo event := &store.Event{ EventID: eventID, BlockHeight: log.BlockNumber, - Type: common.EventTypeOutbound, // Outbound observation events + Type: store.EventTypeOutbound, // Outbound observation events Status: "PENDING", ConfirmationType: "STANDARD", // Use STANDARD confirmation for outbound events ExpiryBlockHeight: 0, // 0 means no expiry diff --git a/universalClient/chains/evm/event_parser_test.go b/universalClient/chains/evm/event_parser_test.go index 33023760..ee85a314 100644 --- a/universalClient/chains/evm/event_parser_test.go +++ b/universalClient/chains/evm/event_parser_test.go @@ -11,7 +11,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/pushchain/push-chain-node/universalClient/chains/common" "github.com/pushchain/push-chain-node/universalClient/store" uregistrytypes "github.com/pushchain/push-chain-node/x/uregistry/types" ) @@ -230,9 +229,9 @@ func TestParseOutboundObservationEvent(t *testing.T) { // TxHash.Hex() returns full 32-byte hex representation assert.Equal(t, "0x0000000000000000000000000000000000000000000000000000abc123def456:5", event.EventID) assert.Equal(t, uint64(98765), event.BlockHeight) - assert.Equal(t, common.EventTypeOutbound, event.Type) - assert.Equal(t, "PENDING", event.Status) - assert.Equal(t, "STANDARD", event.ConfirmationType) + assert.Equal(t, store.EventTypeOutbound, event.Type) + assert.Equal(t, store.StatusPending, event.Status) + assert.Equal(t, store.ConfirmationStandard, event.ConfirmationType) // Verify event data contains tx_id and universal_tx_id assert.NotNil(t, event.EventData) @@ -357,8 +356,8 @@ func TestParseGatewayEvent_OutboundObservation(t *testing.T) { event := ParseEvent(log, EventTypeFinalizeUniversalTx, config.Chain, logger) require.NotNil(t, event) - assert.Equal(t, common.EventTypeOutbound, event.Type) - assert.Equal(t, "STANDARD", event.ConfirmationType) + assert.Equal(t, store.EventTypeOutbound, event.Type) + assert.Equal(t, store.ConfirmationStandard, event.ConfirmationType) assert.Equal(t, uint64(77777), event.BlockHeight) var outboundData map[string]interface{} diff --git a/universalClient/chains/push/event_parser.go b/universalClient/chains/push/event_parser.go index 12bcb264..54598f9a 100644 --- a/universalClient/chains/push/event_parser.go +++ b/universalClient/chains/push/event_parser.go @@ -7,7 +7,6 @@ import ( "strconv" abci "github.com/cometbft/cometbft/abci/types" - "github.com/pushchain/push-chain-node/universalClient/chains/common" "github.com/pushchain/push-chain-node/universalClient/store" uexecutortypes "github.com/pushchain/push-chain-node/x/uexecutor/types" utsstypes "github.com/pushchain/push-chain-node/x/utss/types" @@ -197,7 +196,7 @@ func parseOutboundEvent(event abci.Event) (*store.Event, error) { return &store.Event{ EventID: txID, - Type: common.EventTypeSign, + Type: store.EventTypeSign, EventData: eventData, }, nil } @@ -229,11 +228,11 @@ func buildTSSEventData(processID uint64, participants []string) ([]byte, error) func convertProcessType(chainType string) string { switch chainType { case ChainProcessTypeKeygen: - return common.EventTypeKeygen + return store.EventTypeKeygen case ChainProcessTypeRefresh: - return common.EventTypeKeyrefresh + return store.EventTypeKeyrefresh case ChainProcessTypeQuorumChange: - return common.EventTypeQuorumChange + return store.EventTypeQuorumChange default: // Return as-is for unknown types to maintain forward compatibility return chainType diff --git a/universalClient/chains/push/event_parser_test.go b/universalClient/chains/push/event_parser_test.go index 8fcd7f45..a200e8fd 100644 --- a/universalClient/chains/push/event_parser_test.go +++ b/universalClient/chains/push/event_parser_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/pushchain/push-chain-node/universalClient/chains/common" + "github.com/pushchain/push-chain-node/universalClient/store" uexecutortypes "github.com/pushchain/push-chain-node/x/uexecutor/types" ) @@ -36,7 +36,7 @@ func TestParseEvent_TSSEvent(t *testing.T) { }, blockHeight: 500, wantEventID: "123", - wantType: common.EventTypeKeygen, + wantType: store.EventTypeKeygen, wantExpiry: 1000, wantErr: false, }, @@ -51,7 +51,7 @@ func TestParseEvent_TSSEvent(t *testing.T) { }, blockHeight: 600, wantEventID: "456", - wantType: common.EventTypeKeyrefresh, + wantType: store.EventTypeKeyrefresh, wantExpiry: 0, wantErr: false, }, @@ -67,7 +67,7 @@ func TestParseEvent_TSSEvent(t *testing.T) { }, blockHeight: 700, wantEventID: "789", - wantType: common.EventTypeQuorumChange, + wantType: store.EventTypeQuorumChange, wantExpiry: 2000, wantErr: false, }, @@ -157,8 +157,8 @@ func TestParseEvent_TSSEvent(t *testing.T) { assert.Equal(t, tt.wantType, result.Type) assert.Equal(t, tt.wantExpiry, result.ExpiryBlockHeight) assert.Equal(t, tt.blockHeight, result.BlockHeight) - assert.Equal(t, "CONFIRMED", result.Status) - assert.Equal(t, "INSTANT", result.ConfirmationType) + assert.Equal(t, store.StatusConfirmed, result.Status) + assert.Equal(t, store.ConfirmationInstant, result.ConfirmationType) }) } } @@ -242,11 +242,11 @@ func TestParseEvent_OutboundEvent(t *testing.T) { require.NotNil(t, result) assert.Equal(t, tt.wantEventID, result.EventID) - assert.Equal(t, common.EventTypeSign, result.Type) + assert.Equal(t, store.EventTypeSign, result.Type) assert.Equal(t, tt.wantExpiry, result.ExpiryBlockHeight) assert.Equal(t, tt.blockHeight, result.BlockHeight) - assert.Equal(t, "CONFIRMED", result.Status) - assert.Equal(t, "INSTANT", result.ConfirmationType) + assert.Equal(t, store.StatusConfirmed, result.Status) + assert.Equal(t, store.ConfirmationInstant, result.ConfirmationType) }) } } @@ -314,9 +314,9 @@ func TestConvertProcessType(t *testing.T) { input string expected string }{ - {ChainProcessTypeKeygen, common.EventTypeKeygen}, - {ChainProcessTypeRefresh, common.EventTypeKeyrefresh}, - {ChainProcessTypeQuorumChange, common.EventTypeQuorumChange}, + {ChainProcessTypeKeygen, store.EventTypeKeygen}, + {ChainProcessTypeRefresh, store.EventTypeKeyrefresh}, + {ChainProcessTypeQuorumChange, store.EventTypeQuorumChange}, {"UNKNOWN_TYPE", "UNKNOWN_TYPE"}, // Unknown types returned as-is {"", ""}, } diff --git a/universalClient/chains/svm/event_confirmer_test.go b/universalClient/chains/svm/event_confirmer_test.go index dafc14ed..005efbe9 100644 --- a/universalClient/chains/svm/event_confirmer_test.go +++ b/universalClient/chains/svm/event_confirmer_test.go @@ -3,6 +3,7 @@ package svm import ( "testing" + "github.com/pushchain/push-chain-node/universalClient/store" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -85,25 +86,25 @@ func TestEventConfirmerGetRequiredConfirmations(t *testing.T) { t.Run("FAST confirmation type with custom value", func(t *testing.T) { confirmer := NewEventConfirmer(nil, nil, "solana:mainnet", 5, 5, 12, logger) - confirmations := confirmer.getRequiredConfirmations("FAST") + confirmations := confirmer.getRequiredConfirmations(store.ConfirmationFast) assert.Equal(t, uint64(5), confirmations) }) t.Run("FAST confirmation type with zero uses default", func(t *testing.T) { confirmer := NewEventConfirmer(nil, nil, "solana:mainnet", 5, 0, 12, logger) - confirmations := confirmer.getRequiredConfirmations("FAST") + confirmations := confirmer.getRequiredConfirmations(store.ConfirmationFast) assert.Equal(t, uint64(5), confirmations) // Default is 5 }) t.Run("STANDARD confirmation type with custom value", func(t *testing.T) { confirmer := NewEventConfirmer(nil, nil, "solana:mainnet", 5, 5, 20, logger) - confirmations := confirmer.getRequiredConfirmations("STANDARD") + confirmations := confirmer.getRequiredConfirmations(store.ConfirmationStandard) assert.Equal(t, uint64(20), confirmations) }) t.Run("STANDARD confirmation type with zero uses default", func(t *testing.T) { confirmer := NewEventConfirmer(nil, nil, "solana:mainnet", 5, 5, 0, logger) - confirmations := confirmer.getRequiredConfirmations("STANDARD") + confirmations := confirmer.getRequiredConfirmations(store.ConfirmationStandard) assert.Equal(t, uint64(12), confirmations) // Default is 12 }) diff --git a/universalClient/chains/svm/event_parser.go b/universalClient/chains/svm/event_parser.go index c809607d..1ec4687b 100644 --- a/universalClient/chains/svm/event_parser.go +++ b/universalClient/chains/svm/event_parser.go @@ -90,7 +90,7 @@ func parseSendFundsEvent(log string, signature string, slot uint64, logIndex uin event := &store.Event{ EventID: eventID, BlockHeight: slot, - Type: common.EventTypeInbound, // Gateway events from external chains are INBOUND + Type: store.EventTypeInbound, // Gateway events from external chains are INBOUND Status: "PENDING", ExpiryBlockHeight: 0, // Will be set based on confirmation type if needed } @@ -176,7 +176,7 @@ func parseOutboundObservationEvent(log string, signature string, slot uint64, lo event := &store.Event{ EventID: eventID, BlockHeight: slot, - Type: common.EventTypeOutbound, // Outbound observation events + Type: store.EventTypeOutbound, // Outbound observation events Status: "PENDING", ConfirmationType: "STANDARD", // Use STANDARD confirmation for outbound events ExpiryBlockHeight: 0, // 0 means no expiry diff --git a/universalClient/chains/svm/event_parser_test.go b/universalClient/chains/svm/event_parser_test.go index 7a84656e..950a8be6 100644 --- a/universalClient/chains/svm/event_parser_test.go +++ b/universalClient/chains/svm/event_parser_test.go @@ -11,7 +11,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/pushchain/push-chain-node/universalClient/chains/common" "github.com/pushchain/push-chain-node/universalClient/store" ) @@ -63,9 +62,9 @@ func TestParseOutboundObservationEvent(t *testing.T) { validate: func(t *testing.T, event *store.Event) { assert.Contains(t, event.EventID, signature) assert.Equal(t, uint64(12345), event.BlockHeight) - assert.Equal(t, common.EventTypeOutbound, event.Type) - assert.Equal(t, "PENDING", event.Status) - assert.Equal(t, "STANDARD", event.ConfirmationType) + assert.Equal(t, store.EventTypeOutbound, event.Type) + assert.Equal(t, store.StatusPending, event.Status) + assert.Equal(t, store.ConfirmationStandard, event.ConfirmationType) // Verify event data contains tx_id, universal_tx_id, and gas_fee_used assert.NotNil(t, event.EventData) diff --git a/universalClient/tss/coordinator/coordinator.go b/universalClient/tss/coordinator/coordinator.go index fdd74b04..2af712bd 100644 --- a/universalClient/tss/coordinator/coordinator.go +++ b/universalClient/tss/coordinator/coordinator.go @@ -396,7 +396,7 @@ func (c *Coordinator) processConfirmedEvents(ctx context.Context) error { for _, event := range events { var assignedNonce *uint64 - if event.Type == string(ProtocolSign) { + if event.Type == store.EventTypeSign { chain := extractDestinationChain(event.EventData) if chain == "" { continue @@ -427,7 +427,7 @@ func (c *Coordinator) processConfirmedEvents(ctx context.Context) error { // A threshold subset suffices for signing and is more resilient when some nodes are offline. // For all other protocols (keygen, keyrefresh, quorum_change), all eligible must participate. var participants []*types.UniversalValidator - if event.Type == string(ProtocolSign) { + if event.Type == store.EventTypeSign { participants = getSignParticipants(allValidators) } else { participants = getEligibleForProtocol(event.Type, allValidators) @@ -487,12 +487,12 @@ func (c *Coordinator) processEventAsCoordinator(ctx context.Context, event store var unsignedTxReq *common.UnSignedOutboundTxReq var err error switch event.Type { - case string(ProtocolKeygen), string(ProtocolKeyrefresh): + case store.EventTypeKeygen, store.EventTypeKeyrefresh: // Keygen and keyrefresh use the same setup structure setupData, err = c.createKeygenSetup(threshold, partyIDs) - case string(ProtocolQuorumChange): + case store.EventTypeQuorumChange: setupData, err = c.createQcSetup(ctx, threshold, partyIDs, sortedParticipants) - case string(ProtocolSign): + case store.EventTypeSign: setupData, unsignedTxReq, err = c.createSignSetup(ctx, event.EventData, partyIDs, assignedNonce) default: err = errors.Errorf("unknown protocol type: %s", event.Type) @@ -842,13 +842,13 @@ func (c *Coordinator) createQcSetup(ctx context.Context, threshold int, partyIDs // For SIGN coordinator setup, use getSignParticipants instead (random threshold subset). func getEligibleForProtocol(protocolType string, allValidators []*types.UniversalValidator) []*types.UniversalValidator { switch protocolType { - case string(ProtocolKeygen), string(ProtocolQuorumChange): + case store.EventTypeKeygen, store.EventTypeQuorumChange: // Active + Pending Join return getQuorumChangeParticipants(allValidators) - case string(ProtocolKeyrefresh): + case store.EventTypeKeyrefresh: // Active + Pending Leave return getSignEligible(allValidators) - case string(ProtocolSign): + case store.EventTypeSign: // Active + Pending Leave return getSignEligible(allValidators) default: diff --git a/universalClient/tss/coordinator/coordinator_test.go b/universalClient/tss/coordinator/coordinator_test.go index 85f80e66..61b792f2 100644 --- a/universalClient/tss/coordinator/coordinator_test.go +++ b/universalClient/tss/coordinator/coordinator_test.go @@ -251,13 +251,13 @@ func TestCalculateThreshold(t *testing.T) { n int expected int }{ - {0, 1}, // edge: 0 or fewer → 1 + {0, 1}, // edge: 0 or fewer → 1 {1, 1}, - {3, 3}, // (2*3)/3+1 = 3 - {4, 3}, // (2*4)/3+1 = 3 - {5, 4}, // (2*5)/3+1 = 4 - {6, 5}, // (2*6)/3+1 = 5 - {9, 7}, // (2*9)/3+1 = 7 + {3, 3}, // (2*3)/3+1 = 3 + {4, 3}, // (2*4)/3+1 = 3 + {5, 4}, // (2*5)/3+1 = 4 + {6, 5}, // (2*6)/3+1 = 5 + {9, 7}, // (2*9)/3+1 = 7 } for _, tt := range tests { assert.Equal(t, tt.expected, CalculateThreshold(tt.n), "n=%d", tt.n) @@ -339,14 +339,14 @@ func TestGetInFlightSignCountPerChain(t *testing.T) { polyData := []byte(`{"destination_chain":"polygon"}`) // IN_PROGRESS and SIGNED both count as in-flight. - db.Create(&store.Event{EventID: "e1", Type: "SIGN", Status: eventstore.StatusInProgress, EventData: ethData}) - db.Create(&store.Event{EventID: "e2", Type: "SIGN", Status: eventstore.StatusInProgress, EventData: ethData}) - db.Create(&store.Event{EventID: "e3", Type: "SIGN", Status: eventstore.StatusSigned, EventData: polyData}) + db.Create(&store.Event{EventID: "e1", Type: "SIGN", Status: store.StatusInProgress, EventData: ethData}) + db.Create(&store.Event{EventID: "e2", Type: "SIGN", Status: store.StatusInProgress, EventData: ethData}) + db.Create(&store.Event{EventID: "e3", Type: "SIGN", Status: store.StatusSigned, EventData: polyData}) // These must NOT be counted. - db.Create(&store.Event{EventID: "e4", Type: "SIGN", Status: eventstore.StatusConfirmed, EventData: ethData}) // not yet in-flight - db.Create(&store.Event{EventID: "e5", Type: "SIGN", Status: eventstore.StatusBroadcasted, EventData: ethData}) // pending nonce RPC covers it - db.Create(&store.Event{EventID: "e6", Type: "KEYGEN", Status: eventstore.StatusInProgress}) // not a SIGN event + db.Create(&store.Event{EventID: "e4", Type: "SIGN", Status: store.StatusConfirmed, EventData: ethData}) // not yet in-flight + db.Create(&store.Event{EventID: "e5", Type: "SIGN", Status: store.StatusBroadcasted, EventData: ethData}) // pending nonce RPC covers it + db.Create(&store.Event{EventID: "e6", Type: "KEYGEN", Status: store.StatusInProgress}) // not a SIGN event perChain, err := coord.getInFlightSignCountPerChain() require.NoError(t, err) diff --git a/universalClient/tss/coordinator/types.go b/universalClient/tss/coordinator/types.go index 3882b8a8..3f87cd6e 100644 --- a/universalClient/tss/coordinator/types.go +++ b/universalClient/tss/coordinator/types.go @@ -11,16 +11,6 @@ import ( // data: The message bytes type SendFunc func(ctx context.Context, peerID string, data []byte) error -// ProtocolType enumerates the supported DKLS protocol flows. -type ProtocolType string - -const ( - ProtocolKeygen ProtocolType = "KEYGEN" - ProtocolKeyrefresh ProtocolType = "KEYREFRESH" - ProtocolQuorumChange ProtocolType = "QUORUM_CHANGE" - ProtocolSign ProtocolType = "SIGN" -) - // Message represents a simple message with type, eventId, payload, and participants. type Message struct { Type string `json:"type"` // "setup", "ack", "begin", "step" diff --git a/universalClient/tss/eventstore/store.go b/universalClient/tss/eventstore/store.go index 564d2696..13b816cc 100644 --- a/universalClient/tss/eventstore/store.go +++ b/universalClient/tss/eventstore/store.go @@ -8,20 +8,6 @@ import ( "github.com/pushchain/push-chain-node/universalClient/store" ) -// Event statuses for TSS operations. -// -// Lifecycle: CONFIRMED → IN_PROGRESS → SIGNED → BROADCASTED → COMPLETED -// -// ↘ REVERTED (on expiry, receipt failed, or key vote failed) -const ( - StatusConfirmed = "CONFIRMED" // Event confirmed on Push chain, ready for processing - StatusInProgress = "IN_PROGRESS" // TSS signing is in progress - StatusSigned = "SIGNED" // TSS signing done, tx not yet broadcast (sign events only) - StatusBroadcasted = "BROADCASTED" // Transaction sent to external chain (sign events only) - StatusReverted = "REVERTED" // Reverted (failure vote sent for sign events, or key vote failed) - StatusCompleted = "COMPLETED" // Successfully completed -) - // Store provides database access for TSS events. type Store struct { db *gorm.DB @@ -50,8 +36,8 @@ func (s *Store) GetEvent(eventID string) (*store.Event, error) { // // Example usage: // -// s.Update(id, map[string]any{"status": StatusInProgress}) -// s.Update(id, map[string]any{"status": StatusConfirmed, "block_height": newHeight}) +// s.Update(id, map[string]any{"status": store.StatusInProgress}) +// s.Update(id, map[string]any{"status": store.StatusConfirmed, "block_height": newHeight}) // s.Update(id, map[string]any{"broadcasted_tx_hash": txHash}) func (s *Store) Update(eventID string, fields map[string]any) error { result := s.db.Model(&store.Event{}). @@ -70,7 +56,7 @@ func (s *Store) Update(eventID string, fields map[string]any) error { // Used by the coordinator to cap how many new events to fetch. func (s *Store) CountInProgress() (int64, error) { var count int64 - if err := s.db.Model(&store.Event{}).Where("status = ?", StatusInProgress).Count(&count).Error; err != nil { + if err := s.db.Model(&store.Event{}).Where("status = ?", store.StatusInProgress).Count(&count).Error; err != nil { return 0, errors.Wrap(err, "failed to count IN_PROGRESS events") } return count, nil @@ -80,8 +66,8 @@ func (s *Store) CountInProgress() (int64, error) { // Called on node startup to recover from crashes mid-session. func (s *Store) ResetInProgressEventsToConfirmed() (int64, error) { result := s.db.Model(&store.Event{}). - Where("status = ?", StatusInProgress). - Update("status", StatusConfirmed) + Where("status = ?", store.StatusInProgress). + Update("status", store.StatusConfirmed) if result.Error != nil { return 0, errors.Wrap(result.Error, "failed to reset IN_PROGRESS events to CONFIRMED") } @@ -97,7 +83,7 @@ func (s *Store) GetNonExpiredConfirmedEvents(currentBlock, minBlockConfirmation } query := s.db.Where("status = ? AND block_height <= ? AND expiry_block_height > ?", - StatusConfirmed, minBlock, currentBlock). + store.StatusConfirmed, minBlock, currentBlock). Order("block_height ASC, created_at ASC") if limit > 0 { query = query.Limit(limit) @@ -117,7 +103,7 @@ func (s *Store) GetNonExpiredConfirmedEvents(currentBlock, minBlockConfirmation func (s *Store) GetInFlightSignEvents() ([]store.Event, error) { var events []store.Event if err := s.db.Where("type = ? AND status IN (?, ?)", - "SIGN", StatusInProgress, StatusSigned). + store.EventTypeSign, store.StatusInProgress, store.StatusSigned). Find(&events).Error; err != nil { return nil, errors.Wrap(err, "failed to query in-flight sign events") } @@ -130,7 +116,7 @@ func (s *Store) GetSignedSignEvents(limit int) ([]store.Event, error) { limit = 50 } var events []store.Event - if err := s.db.Where("type = ? AND status = ?", "SIGN", StatusSigned). + if err := s.db.Where("type = ? AND status = ?", store.EventTypeSign, store.StatusSigned). Order("block_height ASC, created_at ASC"). Limit(limit). Find(&events).Error; err != nil { @@ -145,7 +131,7 @@ func (s *Store) GetBroadcastedSignEvents(limit int) ([]store.Event, error) { limit = 50 } var events []store.Event - if err := s.db.Where("type = ? AND status = ? AND broadcasted_tx_hash != ?", "SIGN", StatusBroadcasted, ""). + if err := s.db.Where("type = ? AND status = ? AND broadcasted_tx_hash != ?", store.EventTypeSign, store.StatusBroadcasted, ""). Order("block_height ASC, created_at ASC"). Limit(limit). Find(&events).Error; err != nil { @@ -157,7 +143,7 @@ func (s *Store) GetBroadcastedSignEvents(limit int) ([]store.Event, error) { // GetExpiredConfirmedEvents returns CONFIRMED events past their expiry block. func (s *Store) GetExpiredConfirmedEvents(currentBlock uint64, limit int) ([]store.Event, error) { query := s.db.Where("status = ? AND expiry_block_height <= ?", - StatusConfirmed, currentBlock). + store.StatusConfirmed, currentBlock). Order("block_height ASC, created_at ASC") if limit > 0 { query = query.Limit(limit) diff --git a/universalClient/tss/eventstore/store_test.go b/universalClient/tss/eventstore/store_test.go index e15c1fdc..91754b6b 100644 --- a/universalClient/tss/eventstore/store_test.go +++ b/universalClient/tss/eventstore/store_test.go @@ -9,7 +9,6 @@ import ( "gorm.io/driver/sqlite" "gorm.io/gorm" - "github.com/pushchain/push-chain-node/universalClient/chains/common" "github.com/pushchain/push-chain-node/universalClient/store" ) @@ -36,7 +35,7 @@ func setupTestStore(t *testing.T) *Store { // createTestEvent creates a test TSS event in the database. func createTestEvent(t *testing.T, s *Store, eventID string, blockHeight uint64, status string, expiryHeight uint64) { - createTestEventWithType(t, s, eventID, blockHeight, status, expiryHeight, common.EventTypeKeygen) + createTestEventWithType(t, s, eventID, blockHeight, status, expiryHeight, store.EventTypeKeygen) } // createTestEventWithType creates a test event with a specific type. @@ -88,7 +87,7 @@ func TestGetNonExpiredConfirmedEvents(t *testing.T) { s := setupTestStore(t) // Create event at block 95, current block is 100, min confirmation is 10 // Event is only 5 blocks old, needs 10 blocks confirmation - createTestEvent(t, s, "event-1", 95, StatusConfirmed, 200) + createTestEvent(t, s, "event-1", 95, store.StatusConfirmed, 200) events, err := s.GetNonExpiredConfirmedEvents(100, 10, 0) if err != nil { @@ -103,8 +102,8 @@ func TestGetNonExpiredConfirmedEvents(t *testing.T) { s := setupTestStore(t) // Create event at block 80, current block is 100, min confirmation is 10 // Event is 20 blocks old, should be ready - createTestEvent(t, s, "event-1", 80, StatusConfirmed, 200) - createTestEvent(t, s, "event-2", 85, StatusConfirmed, 200) + createTestEvent(t, s, "event-1", 80, store.StatusConfirmed, 200) + createTestEvent(t, s, "event-2", 85, store.StatusConfirmed, 200) events, err := s.GetNonExpiredConfirmedEvents(100, 10, 0) if err != nil { @@ -123,10 +122,10 @@ func TestGetNonExpiredConfirmedEvents(t *testing.T) { t.Run("filters non-pending events", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "pending-1", 80, StatusConfirmed, 200) - createTestEvent(t, s, "in-progress-1", 80, StatusInProgress, 200) - createTestEvent(t, s, "success-1", 80, StatusCompleted, 200) - createTestEvent(t, s, "reverted-1", 80, StatusReverted, 200) + createTestEvent(t, s, "pending-1", 80, store.StatusConfirmed, 200) + createTestEvent(t, s, "in-progress-1", 80, store.StatusInProgress, 200) + createTestEvent(t, s, "success-1", 80, store.StatusCompleted, 200) + createTestEvent(t, s, "reverted-1", 80, store.StatusReverted, 200) events, err := s.GetNonExpiredConfirmedEvents(100, 10, 0) if err != nil { @@ -142,9 +141,9 @@ func TestGetNonExpiredConfirmedEvents(t *testing.T) { t.Run("excludes expired events", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "expired-1", 80, StatusConfirmed, 90) // expired (expiry 90 < current 100) - createTestEvent(t, s, "valid-1", 80, StatusConfirmed, 200) // not expired (expiry 200 > current 100) - createTestEvent(t, s, "valid-2", 80, StatusConfirmed, 101) // not expired (expiry 101 > current 100) + createTestEvent(t, s, "expired-1", 80, store.StatusConfirmed, 90) // expired (expiry 90 < current 100) + createTestEvent(t, s, "valid-1", 80, store.StatusConfirmed, 200) // not expired (expiry 200 > current 100) + createTestEvent(t, s, "valid-2", 80, store.StatusConfirmed, 101) // not expired (expiry 101 > current 100) events, err := s.GetNonExpiredConfirmedEvents(100, 10, 0) if err != nil { @@ -157,9 +156,9 @@ func TestGetNonExpiredConfirmedEvents(t *testing.T) { t.Run("respects limit", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "event-1", 80, StatusConfirmed, 200) - createTestEvent(t, s, "event-2", 85, StatusConfirmed, 200) - createTestEvent(t, s, "event-3", 88, StatusConfirmed, 200) + createTestEvent(t, s, "event-1", 80, store.StatusConfirmed, 200) + createTestEvent(t, s, "event-2", 85, store.StatusConfirmed, 200) + createTestEvent(t, s, "event-3", 88, store.StatusConfirmed, 200) events, err := s.GetNonExpiredConfirmedEvents(100, 10, 2) if err != nil { @@ -172,11 +171,11 @@ func TestGetNonExpiredConfirmedEvents(t *testing.T) { t.Run("orders by block number and created_at", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "event-1", 80, StatusConfirmed, 200) + createTestEvent(t, s, "event-1", 80, store.StatusConfirmed, 200) time.Sleep(10 * time.Millisecond) // Ensure different created_at - createTestEvent(t, s, "event-2", 80, StatusConfirmed, 200) + createTestEvent(t, s, "event-2", 80, store.StatusConfirmed, 200) time.Sleep(10 * time.Millisecond) - createTestEvent(t, s, "event-3", 75, StatusConfirmed, 200) // Earlier block + createTestEvent(t, s, "event-3", 75, store.StatusConfirmed, 200) // Earlier block events, err := s.GetNonExpiredConfirmedEvents(100, 10, 0) if err != nil { @@ -199,7 +198,7 @@ func TestGetNonExpiredConfirmedEvents(t *testing.T) { t.Run("handles current block less than min confirmation", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "event-1", 0, StatusConfirmed, 200) + createTestEvent(t, s, "event-1", 0, store.StatusConfirmed, 200) // Current block is 5, min confirmation is 10 events, err := s.GetNonExpiredConfirmedEvents(5, 10, 0) @@ -215,7 +214,7 @@ func TestGetNonExpiredConfirmedEvents(t *testing.T) { func TestGetEvent(t *testing.T) { t.Run("event exists", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "event-1", 100, StatusConfirmed, 200) + createTestEvent(t, s, "event-1", 100, store.StatusConfirmed, 200) event, err := s.GetEvent("event-1") if err != nil { @@ -230,8 +229,8 @@ func TestGetEvent(t *testing.T) { if event.BlockHeight != 100 { t.Errorf("GetEvent() block height = %d, want 100", event.BlockHeight) } - if event.Status != StatusConfirmed { - t.Errorf("GetEvent() status = %s, want %s", event.Status, StatusConfirmed) + if event.Status != store.StatusConfirmed { + t.Errorf("GetEvent() status = %s, want %s", event.Status, store.StatusConfirmed) } }) @@ -251,9 +250,9 @@ func TestGetEvent(t *testing.T) { func TestUpdate(t *testing.T) { t.Run("update status", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "event-1", 100, StatusConfirmed, 200) + createTestEvent(t, s, "event-1", 100, store.StatusConfirmed, 200) - err := s.Update("event-1", map[string]any{"status": StatusInProgress}) + err := s.Update("event-1", map[string]any{"status": store.StatusInProgress}) if err != nil { t.Fatalf("Update() error = %v, want nil", err) } @@ -262,17 +261,17 @@ func TestUpdate(t *testing.T) { if err != nil { t.Fatalf("GetEvent() error = %v, want nil", err) } - if event.Status != StatusInProgress { - t.Errorf("Update() status = %s, want %s", event.Status, StatusInProgress) + if event.Status != store.StatusInProgress { + t.Errorf("Update() status = %s, want %s", event.Status, store.StatusInProgress) } }) t.Run("update multiple fields", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "event-1", 100, StatusInProgress, 200) + createTestEvent(t, s, "event-1", 100, store.StatusInProgress, 200) err := s.Update("event-1", map[string]any{ - "status": StatusConfirmed, + "status": store.StatusConfirmed, "block_height": uint64(150), }) if err != nil { @@ -280,8 +279,8 @@ func TestUpdate(t *testing.T) { } event, _ := s.GetEvent("event-1") - if event.Status != StatusConfirmed { - t.Errorf("status = %s, want %s", event.Status, StatusConfirmed) + if event.Status != store.StatusConfirmed { + t.Errorf("status = %s, want %s", event.Status, store.StatusConfirmed) } if event.BlockHeight != 150 { t.Errorf("block_height = %d, want 150", event.BlockHeight) @@ -290,7 +289,7 @@ func TestUpdate(t *testing.T) { t.Run("update broadcasted tx hash", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "event-1", 100, StatusBroadcasted, 200) + createTestEvent(t, s, "event-1", 100, store.StatusBroadcasted, 200) err := s.Update("event-1", map[string]any{"broadcasted_tx_hash": "eip155:11155111:0xabc"}) if err != nil { @@ -306,7 +305,7 @@ func TestUpdate(t *testing.T) { t.Run("update non-existent event", func(t *testing.T) { s := setupTestStore(t) - err := s.Update("non-existent", map[string]any{"status": StatusCompleted}) + err := s.Update("non-existent", map[string]any{"status": store.StatusCompleted}) if err == nil { t.Fatal("Update() error = nil, want error") } @@ -314,24 +313,24 @@ func TestUpdate(t *testing.T) { t.Run("multiple sequential updates", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "event-1", 100, StatusConfirmed, 200) + createTestEvent(t, s, "event-1", 100, store.StatusConfirmed, 200) // CONFIRMED -> IN_PROGRESS - if err := s.Update("event-1", map[string]any{"status": StatusInProgress}); err != nil { + if err := s.Update("event-1", map[string]any{"status": store.StatusInProgress}); err != nil { t.Fatalf("Update() error = %v", err) } event, _ := s.GetEvent("event-1") - if event.Status != StatusInProgress { - t.Errorf("status = %s, want %s", event.Status, StatusInProgress) + if event.Status != store.StatusInProgress { + t.Errorf("status = %s, want %s", event.Status, store.StatusInProgress) } // IN_PROGRESS -> COMPLETED - if err := s.Update("event-1", map[string]any{"status": StatusCompleted}); err != nil { + if err := s.Update("event-1", map[string]any{"status": store.StatusCompleted}); err != nil { t.Fatalf("Update() error = %v", err) } event, _ = s.GetEvent("event-1") - if event.Status != StatusCompleted { - t.Errorf("status = %s, want %s", event.Status, StatusCompleted) + if event.Status != store.StatusCompleted { + t.Errorf("status = %s, want %s", event.Status, store.StatusCompleted) } }) } @@ -339,8 +338,8 @@ func TestUpdate(t *testing.T) { func TestCountInProgress(t *testing.T) { t.Run("no in-progress events", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "event-1", 100, StatusConfirmed, 200) - createTestEvent(t, s, "event-2", 100, StatusCompleted, 200) + createTestEvent(t, s, "event-1", 100, store.StatusConfirmed, 200) + createTestEvent(t, s, "event-2", 100, store.StatusCompleted, 200) count, err := s.CountInProgress() if err != nil { @@ -353,9 +352,9 @@ func TestCountInProgress(t *testing.T) { t.Run("some in-progress events", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "event-1", 100, StatusInProgress, 200) - createTestEvent(t, s, "event-2", 100, StatusInProgress, 200) - createTestEvent(t, s, "event-3", 100, StatusConfirmed, 200) + createTestEvent(t, s, "event-1", 100, store.StatusInProgress, 200) + createTestEvent(t, s, "event-2", 100, store.StatusInProgress, 200) + createTestEvent(t, s, "event-3", 100, store.StatusConfirmed, 200) count, err := s.CountInProgress() if err != nil { @@ -370,9 +369,9 @@ func TestCountInProgress(t *testing.T) { func TestResetInProgressEventsToConfirmed(t *testing.T) { t.Run("resets in-progress events", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "ip-1", 100, StatusInProgress, 200) - createTestEvent(t, s, "ip-2", 100, StatusInProgress, 200) - createTestEvent(t, s, "confirmed-1", 100, StatusConfirmed, 200) + createTestEvent(t, s, "ip-1", 100, store.StatusInProgress, 200) + createTestEvent(t, s, "ip-2", 100, store.StatusInProgress, 200) + createTestEvent(t, s, "confirmed-1", 100, store.StatusConfirmed, 200) count, err := s.ResetInProgressEventsToConfirmed() if err != nil { @@ -385,15 +384,15 @@ func TestResetInProgressEventsToConfirmed(t *testing.T) { // Verify all are now CONFIRMED for _, id := range []string{"ip-1", "ip-2", "confirmed-1"} { event, _ := s.GetEvent(id) - if event.Status != StatusConfirmed { - t.Errorf("event %s status = %s, want %s", id, event.Status, StatusConfirmed) + if event.Status != store.StatusConfirmed { + t.Errorf("event %s status = %s, want %s", id, event.Status, store.StatusConfirmed) } } }) t.Run("no in-progress events", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "event-1", 100, StatusConfirmed, 200) + createTestEvent(t, s, "event-1", 100, store.StatusConfirmed, 200) count, err := s.ResetInProgressEventsToConfirmed() if err != nil { @@ -406,9 +405,9 @@ func TestResetInProgressEventsToConfirmed(t *testing.T) { t.Run("does not affect other statuses", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "reverted-1", 100, StatusReverted, 200) - createTestEvent(t, s, "broadcasted-1", 100, StatusBroadcasted, 200) - createTestEvent(t, s, "ip-1", 100, StatusInProgress, 200) + createTestEvent(t, s, "reverted-1", 100, store.StatusReverted, 200) + createTestEvent(t, s, "broadcasted-1", 100, store.StatusBroadcasted, 200) + createTestEvent(t, s, "ip-1", 100, store.StatusInProgress, 200) count, _ := s.ResetInProgressEventsToConfirmed() if count != 1 { @@ -417,12 +416,12 @@ func TestResetInProgressEventsToConfirmed(t *testing.T) { // REVERTED and BROADCASTED should be unchanged reverted, _ := s.GetEvent("reverted-1") - if reverted.Status != StatusReverted { - t.Errorf("reverted event status = %s, want %s", reverted.Status, StatusReverted) + if reverted.Status != store.StatusReverted { + t.Errorf("reverted event status = %s, want %s", reverted.Status, store.StatusReverted) } broadcasted, _ := s.GetEvent("broadcasted-1") - if broadcasted.Status != StatusBroadcasted { - t.Errorf("broadcasted event status = %s, want %s", broadcasted.Status, StatusBroadcasted) + if broadcasted.Status != store.StatusBroadcasted { + t.Errorf("broadcasted event status = %s, want %s", broadcasted.Status, store.StatusBroadcasted) } }) } @@ -431,16 +430,16 @@ func TestGetExpiredConfirmedEvents(t *testing.T) { t.Run("returns only expired CONFIRMED events", func(t *testing.T) { s := setupTestStore(t) // Expired CONFIRMED (should be returned) - createTestEvent(t, s, "confirmed-expired", 50, StatusConfirmed, 90) + createTestEvent(t, s, "confirmed-expired", 50, store.StatusConfirmed, 90) // Expired non-CONFIRMED (should NOT be returned) - createTestEvent(t, s, "ip-expired", 50, StatusInProgress, 95) - createTestEvent(t, s, "signed-expired", 50, StatusSigned, 95) - createTestEvent(t, s, "broadcasted-expired", 50, StatusBroadcasted, 100) + createTestEvent(t, s, "ip-expired", 50, store.StatusInProgress, 95) + createTestEvent(t, s, "signed-expired", 50, store.StatusSigned, 95) + createTestEvent(t, s, "broadcasted-expired", 50, store.StatusBroadcasted, 100) // Not expired - createTestEvent(t, s, "confirmed-valid", 50, StatusConfirmed, 200) + createTestEvent(t, s, "confirmed-valid", 50, store.StatusConfirmed, 200) // Terminal statuses (should not be returned) - createTestEvent(t, s, "completed", 50, StatusCompleted, 90) - createTestEvent(t, s, "reverted", 50, StatusReverted, 90) + createTestEvent(t, s, "completed", 50, store.StatusCompleted, 90) + createTestEvent(t, s, "reverted", 50, store.StatusReverted, 90) events, err := s.GetExpiredConfirmedEvents(100, 100) if err != nil { @@ -456,7 +455,7 @@ func TestGetExpiredConfirmedEvents(t *testing.T) { t.Run("no expired events", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "event-1", 50, StatusConfirmed, 200) + createTestEvent(t, s, "event-1", 50, store.StatusConfirmed, 200) events, err := s.GetExpiredConfirmedEvents(100, 100) if err != nil { @@ -469,9 +468,9 @@ func TestGetExpiredConfirmedEvents(t *testing.T) { t.Run("respects limit", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "expired-1", 50, StatusConfirmed, 90) - createTestEvent(t, s, "expired-2", 60, StatusConfirmed, 95) - createTestEvent(t, s, "expired-3", 70, StatusConfirmed, 99) + createTestEvent(t, s, "expired-1", 50, store.StatusConfirmed, 90) + createTestEvent(t, s, "expired-2", 60, store.StatusConfirmed, 95) + createTestEvent(t, s, "expired-3", 70, store.StatusConfirmed, 99) events, err := s.GetExpiredConfirmedEvents(100, 2) if err != nil { @@ -484,8 +483,8 @@ func TestGetExpiredConfirmedEvents(t *testing.T) { t.Run("orders by block height", func(t *testing.T) { s := setupTestStore(t) - createTestEvent(t, s, "expired-high", 70, StatusConfirmed, 90) - createTestEvent(t, s, "expired-low", 50, StatusConfirmed, 90) + createTestEvent(t, s, "expired-high", 70, store.StatusConfirmed, 90) + createTestEvent(t, s, "expired-low", 50, store.StatusConfirmed, 90) events, err := s.GetExpiredConfirmedEvents(100, 100) if err != nil { diff --git a/universalClient/tss/expirysweeper/sweeper.go b/universalClient/tss/expirysweeper/sweeper.go index 7d724fc1..66b10d2b 100644 --- a/universalClient/tss/expirysweeper/sweeper.go +++ b/universalClient/tss/expirysweeper/sweeper.go @@ -19,8 +19,6 @@ import ( const ( defaultCheckInterval = 30 * time.Second sweepBatchSize = 100 - - statusSign = "SIGN" ) // Config holds configuration for the expiry sweeper. @@ -96,13 +94,13 @@ func (s *Sweeper) sweep(ctx context.Context) { swept := 0 for _, event := range events { - if event.Type == statusSign { + if event.Type == store.EventTypeSign { if err := s.voteFailureAndMarkReverted(ctx, &event, "event expired before TSS could start"); err != nil { s.logger.Error().Err(err).Str("event_id", event.EventID).Msg("failed to sweep expired SIGN event") continue } } else { - if err := s.eventStore.Update(event.EventID, map[string]any{"status": eventstore.StatusReverted}); err != nil { + if err := s.eventStore.Update(event.EventID, map[string]any{"status": store.StatusReverted}); err != nil { s.logger.Error().Err(err).Str("event_id", event.EventID).Msg("failed to revert expired key event") continue } @@ -124,7 +122,7 @@ func (s *Sweeper) voteFailureAndMarkReverted(ctx context.Context, event *store.E return errors.Wrapf(err, "failed to parse outbound event data for event %s", event.EventID) } - fields := map[string]any{"status": eventstore.StatusReverted} + fields := map[string]any{"status": store.StatusReverted} if s.pushSigner == nil { s.logger.Warn().Str("event_id", event.EventID).Msg("pushSigner not configured, skipping failure vote") diff --git a/universalClient/tss/expirysweeper/sweeper_test.go b/universalClient/tss/expirysweeper/sweeper_test.go index a17ba1d3..7bf30a4f 100644 --- a/universalClient/tss/expirysweeper/sweeper_test.go +++ b/universalClient/tss/expirysweeper/sweeper_test.go @@ -54,10 +54,10 @@ func runSweepBatch(t *testing.T, s *Sweeper, currentBlock uint64) { require.NoError(t, err) for _, event := range events { ev := event - if ev.Type == statusSign { + if ev.Type == store.EventTypeSign { require.NoError(t, s.voteFailureAndMarkReverted(ctx, &ev, "event expired before TSS could start")) } else { - require.NoError(t, s.eventStore.Update(ev.EventID, map[string]any{"status": eventstore.StatusReverted})) + require.NoError(t, s.eventStore.Update(ev.EventID, map[string]any{"status": store.StatusReverted})) } } } diff --git a/universalClient/tss/sessionmanager/sessionmanager.go b/universalClient/tss/sessionmanager/sessionmanager.go index 52fe59da..f1938bef 100644 --- a/universalClient/tss/sessionmanager/sessionmanager.go +++ b/universalClient/tss/sessionmanager/sessionmanager.go @@ -144,7 +144,7 @@ func (sm *SessionManager) handleSetupMessage(ctx context.Context, senderPeerID s } // 3. Validate event is CONFIRMED and not expired - if event.Status != eventstore.StatusConfirmed { + if event.Status != store.StatusConfirmed { return errors.Errorf("event %s is not in confirmed status (got %s)", msg.EventID, event.Status) } currentBlock, err := sm.coordinator.GetLatestBlockNum(ctx) @@ -170,7 +170,7 @@ func (sm *SessionManager) handleSetupMessage(ctx context.Context, senderPeerID s } // 6. For SIGN events, verify the signing hash independently - if event.Type == string(coordinator.ProtocolSign) { + if event.Type == store.EventTypeSign { if err := sm.verifySigningRequest(ctx, event, msg.UnSignedOutboundTxReq); err != nil { return errors.Wrap(err, "signing request verification failed") } @@ -195,7 +195,7 @@ func (sm *SessionManager) handleSetupMessage(ctx context.Context, senderPeerID s sm.mu.Unlock() // 9. Update event status to IN_PROGRESS - if err := sm.eventStore.Update(msg.EventID, map[string]any{"status": eventstore.StatusInProgress}); err != nil { + if err := sm.eventStore.Update(msg.EventID, map[string]any{"status": store.StatusInProgress}); err != nil { sm.logger.Warn().Err(err).Str("event_id", msg.EventID).Msg("failed to update event status") } @@ -403,7 +403,7 @@ func (sm *SessionManager) handleSessionFinished(ctx context.Context, eventID str } // SIGN sessions: broadcast the signed tx, then done (status managed by handleSigningComplete) - if state.protocolType == string(coordinator.ProtocolSign) { + if state.protocolType == store.EventTypeSign { return sm.handleSignFinished(ctx, eventID, result, state.signingReq) } @@ -467,7 +467,7 @@ func (sm *SessionManager) handleKeyFinished(ctx context.Context, eventID, protoc voteTxHash, err = sm.pushSigner.VoteTssKeyProcess(ctx, pubKeyHex, storageID, processID) if err != nil { // Vote failed after TSS signing — mark REVERTED directly (no RevertHandler needed for key events) - if updateErr := sm.eventStore.Update(eventID, map[string]any{"status": eventstore.StatusReverted}); updateErr != nil { + if updateErr := sm.eventStore.Update(eventID, map[string]any{"status": store.StatusReverted}); updateErr != nil { sm.logger.Error().Err(updateErr).Str("event_id", eventID).Msg("failed to mark event as REVERTED") } return errors.Wrapf(err, "TSS vote failed for event %s — marked REVERTED", eventID) @@ -476,7 +476,7 @@ func (sm *SessionManager) handleKeyFinished(ctx context.Context, eventID, protoc sm.logger.Info().Str("vote_tx_hash", voteTxHash).Str("event_id", eventID).Msg("TSS vote succeeded") } - if err := sm.eventStore.Update(eventID, map[string]any{"status": eventstore.StatusCompleted, "vote_tx_hash": voteTxHash}); err != nil { + if err := sm.eventStore.Update(eventID, map[string]any{"status": store.StatusCompleted, "vote_tx_hash": voteTxHash}); err != nil { return errors.Wrapf(err, "failed to update event status to completed") } @@ -489,7 +489,7 @@ func (sm *SessionManager) createSession(ctx context.Context, event *store.Event, threshold := coordinator.CalculateThreshold(len(msg.Participants)) switch event.Type { - case string(coordinator.ProtocolKeygen): + case store.EventTypeKeygen: return dkls.NewKeygenSession( msg.Payload, // setupData msg.EventID, // sessionID @@ -498,7 +498,7 @@ func (sm *SessionManager) createSession(ctx context.Context, event *store.Event, threshold, ) - case string(coordinator.ProtocolKeyrefresh): + case store.EventTypeKeyrefresh: // Get current keyID keyID, _, err := sm.coordinator.GetCurrentTSSKey(ctx) if err != nil { @@ -520,7 +520,7 @@ func (sm *SessionManager) createSession(ctx context.Context, event *store.Event, oldKeyshare, ) - case string(coordinator.ProtocolQuorumChange): + case store.EventTypeQuorumChange: // Get current keyID keyID, _, err := sm.coordinator.GetCurrentTSSKey(ctx) if err != nil { @@ -554,7 +554,7 @@ func (sm *SessionManager) createSession(ctx context.Context, event *store.Event, oldKeyshare, ) - case string(coordinator.ProtocolSign): + case store.EventTypeSign: // Get current keyID keyID, _, err := sm.coordinator.GetCurrentTSSKey(ctx) if err != nil { @@ -614,7 +614,7 @@ func (sm *SessionManager) validateParticipants(participants []string, event *sto // Protocol-specific validation switch event.Type { - case string(coordinator.ProtocolKeygen), string(coordinator.ProtocolKeyrefresh), string(coordinator.ProtocolQuorumChange): + case store.EventTypeKeygen, store.EventTypeKeyrefresh, store.EventTypeQuorumChange: // For keygen, keyrefresh, and quorumchange: participants must match exactly with eligible participants if len(participants) != len(eligibleList) { return errors.Errorf("participants count %d does not match eligible count %d for %s", len(participants), len(eligibleList), event.Type) @@ -626,7 +626,7 @@ func (sm *SessionManager) validateParticipants(participants []string, event *sto } } - case string(coordinator.ProtocolSign): + case store.EventTypeSign: // For SIGN the coordinator picks a random threshold subset (>2/3 of eligible) rather than // all eligible validators. Accept any subset as long as it meets the threshold minimum; // all participants are already verified eligible by the eligibleSet check above. @@ -703,7 +703,7 @@ func (sm *SessionManager) checkExpiredSessions(ctx context.Context, blockDelay u // Update event: mark as confimed and set new block height (current + delay) newBlockHeight := currentBlock + blockDelay - if err := sm.eventStore.Update(eventID, map[string]any{"status": eventstore.StatusConfirmed, "block_height": newBlockHeight}); err != nil { + if err := sm.eventStore.Update(eventID, map[string]any{"status": store.StatusConfirmed, "block_height": newBlockHeight}); err != nil { sm.logger.Warn(). Err(err). Str("event_id", eventID). @@ -799,7 +799,6 @@ func (sm *SessionManager) verifySigningRequest(ctx context.Context, event *store return nil } - // getTSSAddress gets the TSS ECDSA address from the current TSS public key // The TSS address is always the same ECDSA address derived from the TSS public key func (sm *SessionManager) getTSSAddress(ctx context.Context) (string, error) { @@ -838,7 +837,7 @@ func (sm *SessionManager) handleSigningComplete(_ context.Context, eventID strin // Persist enriched event data + mark SIGNED; txBroadcaster will pick it up if err := sm.eventStore.Update(eventID, map[string]any{ "event_data": newEventData, - "status": eventstore.StatusSigned, + "status": store.StatusSigned, }); err != nil { return errors.Wrap(err, "failed to update event with signing data") } diff --git a/universalClient/tss/sessionmanager/sessionmanager_test.go b/universalClient/tss/sessionmanager/sessionmanager_test.go index f98f7cf6..c4afc5bf 100644 --- a/universalClient/tss/sessionmanager/sessionmanager_test.go +++ b/universalClient/tss/sessionmanager/sessionmanager_test.go @@ -145,7 +145,7 @@ func setupTestSessionManager(t *testing.T) (*SessionManager, *coordinator.Coordi nil, // chains - nil for testing sendFn, "validator1", - 3*time.Minute, // sessionExpiryTime + 3*time.Minute, // sessionExpiryTime 30*time.Second, // sessionExpiryCheckInterval 60, // sessionExpiryBlockDelay zerolog.Nop(), @@ -186,7 +186,7 @@ func TestHandleSetupMessage_Validation(t *testing.T) { EventID: "event1", BlockHeight: 100, Type: "KEYGEN", - Status: eventstore.StatusConfirmed, + Status: store.StatusConfirmed, } require.NoError(t, testDB.Create(&event).Error) @@ -365,7 +365,7 @@ func TestSessionManager_Integration(t *testing.T) { EventID: "keygen-event", BlockHeight: 100, Type: "KEYGEN", - Status: eventstore.StatusConfirmed, + Status: store.StatusConfirmed, } require.NoError(t, testDB.Create(&event).Error) @@ -406,7 +406,7 @@ func TestVerifySigningRequest_OutboundDisabled(t *testing.T) { event := &store.Event{ EventID: "sign-event-1", Type: "SIGN", - Status: eventstore.StatusConfirmed, + Status: store.StatusConfirmed, EventData: eventDataBytes, } diff --git a/universalClient/tss/txbroadcaster/broadcaster.go b/universalClient/tss/txbroadcaster/broadcaster.go index f8d9fa4d..24729f26 100644 --- a/universalClient/tss/txbroadcaster/broadcaster.go +++ b/universalClient/tss/txbroadcaster/broadcaster.go @@ -32,11 +32,11 @@ type SignedEventData struct { // Config holds configuration for the broadcaster. type Config struct { - EventStore *eventstore.Store - Chains *chains.Chains - CheckInterval time.Duration - Logger zerolog.Logger - GetTSSAddress func(ctx context.Context) (string, error) + EventStore *eventstore.Store + Chains *chains.Chains + CheckInterval time.Duration + Logger zerolog.Logger + GetTSSAddress func(ctx context.Context) (string, error) } // Broadcaster polls SIGNED events and broadcasts them to external chains. @@ -147,7 +147,7 @@ func (b *Broadcaster) markBroadcasted(event *store.Event, chainID, txHash string caipTxHash := chainID + ":" + txHash if err := b.eventStore.Update(event.EventID, map[string]any{ "broadcasted_tx_hash": caipTxHash, - "status": eventstore.StatusBroadcasted, + "status": store.StatusBroadcasted, }); err != nil { b.logger.Warn().Err(err).Str("event_id", event.EventID).Msg("failed to update event to BROADCASTED") return diff --git a/universalClient/tss/txbroadcaster/broadcaster_test.go b/universalClient/tss/txbroadcaster/broadcaster_test.go index 860a16f5..eecfbf77 100644 --- a/universalClient/tss/txbroadcaster/broadcaster_test.go +++ b/universalClient/tss/txbroadcaster/broadcaster_test.go @@ -66,10 +66,10 @@ func (m *mockTxBuilder) GetGasFeeUsed(ctx context.Context, txHash string) (strin type mockChainClient struct{ builder *mockTxBuilder } -func (m *mockChainClient) Start(context.Context) error { return nil } -func (m *mockChainClient) Stop() error { return nil } -func (m *mockChainClient) IsHealthy() bool { return true } -func (m *mockChainClient) GetTxBuilder() (common.OutboundTxBuilder, error) { return m.builder, nil } +func (m *mockChainClient) Start(context.Context) error { return nil } +func (m *mockChainClient) Stop() error { return nil } +func (m *mockChainClient) IsHealthy() bool { return true } +func (m *mockChainClient) GetTxBuilder() (common.OutboundTxBuilder, error) { return m.builder, nil } // --------------------------------------------------------------------------- // Helpers @@ -139,7 +139,7 @@ func insertSignedEvent(t *testing.T, db *gorm.DB, eventID, destChain string, non ExpiryBlockHeight: 99999, Type: "SIGN", ConfirmationType: "STANDARD", - Status: eventstore.StatusSigned, + Status: store.StatusSigned, EventData: makeSignedEventData(t, destChain, nonce), } require.NoError(t, db.Create(&event).Error) @@ -184,7 +184,7 @@ func TestEVM_BroadcastError_NonceConsumed_MarksBroadcasted(t *testing.T) { b.processSigned(context.Background()) ev := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusBroadcasted, ev.Status) + require.Equal(t, store.StatusBroadcasted, ev.Status) require.Equal(t, "eip155:1:0xabc", ev.BroadcastedTxHash) } @@ -203,7 +203,7 @@ func TestEVM_BroadcastSuccess_MarksBroadcasted(t *testing.T) { b.processSigned(context.Background()) ev := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusBroadcasted, ev.Status) + require.Equal(t, store.StatusBroadcasted, ev.Status) require.Equal(t, "eip155:1:0xabc123", ev.BroadcastedTxHash) builder.AssertNotCalled(t, "GetNextNonce", mock.Anything, mock.Anything, mock.Anything) } @@ -225,7 +225,7 @@ func TestEVM_BroadcastFails_NonceConsumedOnRecheck_MarksBroadcasted(t *testing.T b.processSigned(context.Background()) ev := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusBroadcasted, ev.Status) + require.Equal(t, store.StatusBroadcasted, ev.Status) } func TestEVM_BroadcastFails_NonceNotConsumed_StaysSigned(t *testing.T) { @@ -244,7 +244,7 @@ func TestEVM_BroadcastFails_NonceNotConsumed_StaysSigned(t *testing.T) { b.processSigned(context.Background()) ev := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusSigned, ev.Status) // stays SIGNED + require.Equal(t, store.StatusSigned, ev.Status) // stays SIGNED builder.AssertNotCalled(t, "GetNextNonce", mock.Anything, mock.Anything, mock.Anything) } @@ -265,7 +265,7 @@ func TestEVM_BroadcastFails_WithTxHash_NonceNotConsumed_StaysSigned(t *testing.T b.processSigned(context.Background()) ev := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusSigned, ev.Status) // stays SIGNED + require.Equal(t, store.StatusSigned, ev.Status) // stays SIGNED } func TestEVM_GetTSSAddressNil_UsesEmptyAddress(t *testing.T) { @@ -291,7 +291,7 @@ func TestEVM_GetTSSAddressNil_UsesEmptyAddress(t *testing.T) { b.processSigned(context.Background()) ev := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusBroadcasted, ev.Status) + require.Equal(t, store.StatusBroadcasted, ev.Status) builder.AssertCalled(t, "GetNextNonce", mock.Anything, "", true) } @@ -315,7 +315,7 @@ func TestSVM_BroadcastSuccess_MarksBroadcasted(t *testing.T) { b.processSigned(context.Background()) ev := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusBroadcasted, ev.Status) + require.Equal(t, store.StatusBroadcasted, ev.Status) require.Equal(t, "solana:mainnet:solTxSig123", ev.BroadcastedTxHash) } @@ -336,7 +336,7 @@ func TestSVM_BroadcastFails_PDAExists_MarksBroadcasted(t *testing.T) { b.processSigned(context.Background()) ev := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusBroadcasted, ev.Status) + require.Equal(t, store.StatusBroadcasted, ev.Status) require.Equal(t, "solana:mainnet:", ev.BroadcastedTxHash) // empty tx hash } @@ -357,7 +357,7 @@ func TestSVM_BroadcastFails_PDANotFound_MarksBroadcasted(t *testing.T) { b.processSigned(context.Background()) ev := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusBroadcasted, ev.Status) + require.Equal(t, store.StatusBroadcasted, ev.Status) require.Equal(t, "solana:mainnet:", ev.BroadcastedTxHash) // empty tx hash } @@ -378,7 +378,7 @@ func TestSVM_BroadcastFails_PDACheckFails_StaysSigned(t *testing.T) { b.processSigned(context.Background()) ev := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusSigned, ev.Status) // stays SIGNED + require.Equal(t, store.StatusSigned, ev.Status) // stays SIGNED } // --------------------------------------------------------------------------- @@ -421,8 +421,8 @@ func TestProcessSigned_MultipleEvents(t *testing.T) { ev1 := getEvent(t, db, "ev-1") ev2 := getEvent(t, db, "ev-2") - require.Equal(t, eventstore.StatusBroadcasted, ev1.Status) - require.Equal(t, eventstore.StatusBroadcasted, ev2.Status) + require.Equal(t, store.StatusBroadcasted, ev1.Status) + require.Equal(t, store.StatusBroadcasted, ev2.Status) } // --------------------------------------------------------------------------- @@ -439,7 +439,7 @@ func TestMarkBroadcasted_FormatsCAIPTxHash(t *testing.T) { updated := getEvent(t, db, "ev-1") require.Equal(t, "eip155:1:0xdeadbeef", updated.BroadcastedTxHash) - require.Equal(t, eventstore.StatusBroadcasted, updated.Status) + require.Equal(t, store.StatusBroadcasted, updated.Status) } func TestMarkBroadcasted_EmptyTxHash(t *testing.T) { diff --git a/universalClient/tss/txresolver/evm.go b/universalClient/tss/txresolver/evm.go index 4aa09afd..77afc3a9 100644 --- a/universalClient/tss/txresolver/evm.go +++ b/universalClient/tss/txresolver/evm.go @@ -5,7 +5,6 @@ import ( "github.com/pushchain/push-chain-node/universalClient/chains/common" "github.com/pushchain/push-chain-node/universalClient/store" - "github.com/pushchain/push-chain-node/universalClient/tss/eventstore" ) // resolveEVM checks the on-chain receipt and moves the event to COMPLETED or REVERTED. @@ -73,7 +72,7 @@ func (r *Resolver) resolveEVM(ctx context.Context, event *store.Event, chainID, } // status == 1 (success) - if err := r.eventStore.Update(event.EventID, map[string]any{"status": eventstore.StatusCompleted}); err != nil { + if err := r.eventStore.Update(event.EventID, map[string]any{"status": store.StatusCompleted}); err != nil { r.logger.Warn().Err(err).Str("event_id", event.EventID).Msg("failed to mark event COMPLETED") return } diff --git a/universalClient/tss/txresolver/resolver.go b/universalClient/tss/txresolver/resolver.go index d76b514f..1416225e 100644 --- a/universalClient/tss/txresolver/resolver.go +++ b/universalClient/tss/txresolver/resolver.go @@ -153,7 +153,7 @@ func (r *Resolver) voteFailureAndMarkReverted(ctx context.Context, event *store. r.logger.Warn().Err(err).Str("event_id", event.EventID).Msg("failed to vote failure") return err } - if err := r.eventStore.Update(event.EventID, map[string]any{"status": eventstore.StatusReverted, "vote_tx_hash": voteTxHash}); err != nil { + if err := r.eventStore.Update(event.EventID, map[string]any{"status": store.StatusReverted, "vote_tx_hash": voteTxHash}); err != nil { return errors.Wrapf(err, "failed to mark event %s as reverted", event.EventID) } r.logger.Info(). diff --git a/universalClient/tss/txresolver/resolver_test.go b/universalClient/tss/txresolver/resolver_test.go index af7214e4..ae621365 100644 --- a/universalClient/tss/txresolver/resolver_test.go +++ b/universalClient/tss/txresolver/resolver_test.go @@ -124,7 +124,7 @@ func insertBroadcastedEvent(t *testing.T, db *gorm.DB, eventID, destChain, broad ExpiryBlockHeight: 99999, Type: "SIGN", ConfirmationType: "STANDARD", - Status: eventstore.StatusBroadcasted, + Status: store.StatusBroadcasted, EventData: eventData, BroadcastedTxHash: broadcastedTxHash, } @@ -247,7 +247,7 @@ func TestSVM_PDAExists_MarksCompleted(t *testing.T) { resolver.resolveSVM(context.Background(), &ev, "solana:mainnet") updated := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusCompleted, updated.Status) + require.Equal(t, store.StatusCompleted, updated.Status) } func TestSVM_PDANotFound_VotesFailureAndReverts(t *testing.T) { @@ -271,7 +271,7 @@ func TestSVM_PDANotFound_VotesFailureAndReverts(t *testing.T) { // With no push signer, voteFailureAndMarkReverted returns nil early (logs warning). // The event stays BROADCASTED because the vote+revert is skipped. updated := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusBroadcasted, updated.Status) + require.Equal(t, store.StatusBroadcasted, updated.Status) } func TestSVM_PDACheckFails_StaysBroadcasted(t *testing.T) { @@ -291,7 +291,7 @@ func TestSVM_PDACheckFails_StaysBroadcasted(t *testing.T) { resolver.resolveSVM(context.Background(), &ev, "solana:mainnet") updated := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusBroadcasted, updated.Status) // stays BROADCASTED + require.Equal(t, store.StatusBroadcasted, updated.Status) // stays BROADCASTED } func TestSVM_InvalidEventData_Skips(t *testing.T) { @@ -308,7 +308,7 @@ func TestSVM_InvalidEventData_Skips(t *testing.T) { resolver.resolveSVM(context.Background(), &ev, "solana:mainnet") updated := getEvent(t, db, "ev-1") - require.Equal(t, eventstore.StatusBroadcasted, updated.Status) // stays BROADCASTED + require.Equal(t, store.StatusBroadcasted, updated.Status) // stays BROADCASTED builder.AssertNotCalled(t, "IsAlreadyExecuted", mock.Anything, mock.Anything) } diff --git a/universalClient/tss/txresolver/svm.go b/universalClient/tss/txresolver/svm.go index fd6f4b97..f0e9c0de 100644 --- a/universalClient/tss/txresolver/svm.go +++ b/universalClient/tss/txresolver/svm.go @@ -4,7 +4,6 @@ import ( "context" "github.com/pushchain/push-chain-node/universalClient/store" - "github.com/pushchain/push-chain-node/universalClient/tss/eventstore" ) // resolveSVM checks the on-chain ExecutedTx PDA and moves the event to COMPLETED or REVERTED. @@ -45,7 +44,7 @@ func (r *Resolver) resolveSVM(ctx context.Context, event *store.Event, chainID s } if executed { - if err := r.eventStore.Update(event.EventID, map[string]any{"status": eventstore.StatusCompleted}); err != nil { + if err := r.eventStore.Update(event.EventID, map[string]any{"status": store.StatusCompleted}); err != nil { r.logger.Warn().Err(err).Str("event_id", event.EventID).Msg("failed to mark SVM event COMPLETED") return } From b99be8de0fa5dcc7568b2785e86480c9275149b6 Mon Sep 17 00:00:00 2001 From: aman035 Date: Fri, 20 Mar 2026 16:42:16 +0530 Subject: [PATCH 16/28] remove raw string usage --- universalClient/chains/common/chain_store.go | 50 +++++-------------- .../chains/common/event_processor.go | 4 +- universalClient/chains/evm/event_confirmer.go | 8 +-- universalClient/chains/evm/event_parser.go | 12 ++--- universalClient/chains/push/event_parser.go | 4 +- universalClient/chains/svm/event_confirmer.go | 7 +-- universalClient/chains/svm/event_parser.go | 12 ++--- 7 files changed, 37 insertions(+), 60 deletions(-) diff --git a/universalClient/chains/common/chain_store.go b/universalClient/chains/common/chain_store.go index 34dc1975..c708a236 100644 --- a/universalClient/chains/common/chain_store.go +++ b/universalClient/chains/common/chain_store.go @@ -21,58 +21,34 @@ func NewChainStore(database *db.DB) *ChainStore { } } -// GetChainHeight returns the last processed block height for the chain -// Creates a new entry with height 0 if it doesn't exist +// GetChainHeight returns the last processed block height for the chain. +// Creates a new entry with height 0 if one doesn't exist (atomic via FirstOrCreate). func (cs *ChainStore) GetChainHeight() (uint64, error) { if cs.database == nil { return 0, fmt.Errorf("database is nil") } var state store.State - result := cs.database.Client().First(&state) - - if result.Error != nil { - if result.Error == gorm.ErrRecordNotFound { - // Create new entry with height 0 - state = store.State{ - BlockHeight: 0, - } - if err := cs.database.Client().Create(&state).Error; err != nil { - return 0, fmt.Errorf("failed to create chain state: %w", err) - } - return 0, nil - } - return 0, fmt.Errorf("failed to get chain height: %w", result.Error) + if err := cs.database.Client().FirstOrCreate(&state, store.State{}).Error; err != nil { + return 0, fmt.Errorf("failed to get or create chain state: %w", err) } return state.BlockHeight, nil } -// UpdateChainHeight updates the last processed block height for the chain -// Creates a new entry if it doesn't exist +// UpdateChainHeight updates the last processed block height for the chain. +// Creates a new entry if one doesn't exist (atomic via FirstOrCreate). +// Only updates if the new height is greater than the current one. func (cs *ChainStore) UpdateChainHeight(blockHeight uint64) error { if cs.database == nil { return fmt.Errorf("database is nil") } var state store.State - result := cs.database.Client().First(&state) - - if result.Error != nil { - if result.Error == gorm.ErrRecordNotFound { - // Create new entry - state = store.State{ - BlockHeight: blockHeight, - } - if err := cs.database.Client().Create(&state).Error; err != nil { - return fmt.Errorf("failed to create chain state: %w", err) - } - return nil - } - return fmt.Errorf("failed to query chain state: %w", result.Error) + if err := cs.database.Client().FirstOrCreate(&state, store.State{}).Error; err != nil { + return fmt.Errorf("failed to get or create chain state: %w", err) } - // Update existing record only if new block is higher if blockHeight > state.BlockHeight { state.BlockHeight = blockHeight if err := cs.database.Client().Save(&state).Error; err != nil { @@ -91,7 +67,7 @@ func (cs *ChainStore) GetPendingEvents(limit int) ([]store.Event, error) { var events []store.Event if err := cs.database.Client(). - Where("status = ?", "PENDING"). + Where("status = ?", store.StatusPending). Order("created_at ASC"). Limit(limit). Find(&events).Error; err != nil { @@ -109,7 +85,7 @@ func (cs *ChainStore) GetConfirmedEvents(limit int) ([]store.Event, error) { var events []store.Event if err := cs.database.Client(). - Where("status = ?", "CONFIRMED"). + Where("status = ?", store.StatusConfirmed). Order("created_at ASC"). Limit(limit). Find(&events).Error; err != nil { @@ -197,14 +173,14 @@ func (cs *ChainStore) UpdateVoteTxHash(eventID string, voteTxHash string) error // DeleteTerminalEvents deletes events in terminal states (COMPLETED, REVERTED, EXPIRED) // that were updated before the given time -func (cs *ChainStore) DeleteTerminalEvents(updatedBefore interface{}) (int64, error) { +func (cs *ChainStore) DeleteTerminalEvents(updatedBefore any) (int64, error) { if cs.database == nil { return 0, fmt.Errorf("database is nil") } res := cs.database.Client(). Where("status IN ? AND updated_at < ?", - []string{"COMPLETED", "REORGED", "REVERTED"}, updatedBefore). + []string{store.StatusCompleted, store.StatusReorged, store.StatusReverted}, updatedBefore). Delete(&store.Event{}) if res.Error != nil { diff --git a/universalClient/chains/common/event_processor.go b/universalClient/chains/common/event_processor.go index b3c4618f..85808dc4 100644 --- a/universalClient/chains/common/event_processor.go +++ b/universalClient/chains/common/event_processor.go @@ -173,7 +173,7 @@ func (ep *EventProcessor) processOutboundEvent(ctx context.Context, event *store } // Atomically record vote hash and flip status in one DB write - rowsAffected, err := ep.chainStore.UpdateStatusAndVoteTxHash(event.EventID, "CONFIRMED", "COMPLETED", voteTxHash) + rowsAffected, err := ep.chainStore.UpdateStatusAndVoteTxHash(event.EventID, store.StatusConfirmed, store.StatusCompleted, voteTxHash) if err != nil { return fmt.Errorf("failed to update event status and vote_tx_hash: %w", err) } @@ -218,7 +218,7 @@ func (ep *EventProcessor) processInboundEvent(ctx context.Context, event *store. } // Atomically record vote hash and flip status in one DB write - rowsAffected, err := ep.chainStore.UpdateStatusAndVoteTxHash(event.EventID, "CONFIRMED", "COMPLETED", voteTxHash) + rowsAffected, err := ep.chainStore.UpdateStatusAndVoteTxHash(event.EventID, store.StatusConfirmed, store.StatusCompleted, voteTxHash) if err != nil { return fmt.Errorf("failed to update event status after successful vote: %w", err) } diff --git a/universalClient/chains/evm/event_confirmer.go b/universalClient/chains/evm/event_confirmer.go index 581ac373..9dbf8a89 100644 --- a/universalClient/chains/evm/event_confirmer.go +++ b/universalClient/chains/evm/event_confirmer.go @@ -187,9 +187,9 @@ func (ec *EventConfirmer) processPendingEvents(ctx context.Context) error { continue } - rowsAffected, err = ec.chainStore.UpdateStatusAndEventData(event.EventID, "PENDING", "CONFIRMED", updatedData) + rowsAffected, err = ec.chainStore.UpdateStatusAndEventData(event.EventID, store.StatusPending, store.StatusConfirmed, updatedData) } else { - rowsAffected, err = ec.chainStore.UpdateEventStatus(event.EventID, "PENDING", "CONFIRMED") + rowsAffected, err = ec.chainStore.UpdateEventStatus(event.EventID, store.StatusPending, store.StatusConfirmed) } if err != nil { @@ -237,12 +237,12 @@ func (ec *EventConfirmer) getTxHashFromEventID(eventID string) string { // getRequiredConfirmations returns the required number of confirmations based on confirmation type func (ec *EventConfirmer) getRequiredConfirmations(confirmationType string) uint64 { switch confirmationType { - case "FAST": + case store.ConfirmationFast: if ec.fastConfirmations >= 0 { return ec.fastConfirmations } return 5 - case "STANDARD": + case store.ConfirmationStandard: if ec.standardConfirmations >= 0 { return ec.standardConfirmations } diff --git a/universalClient/chains/evm/event_parser.go b/universalClient/chains/evm/event_parser.go index cf25f9fa..55d6ac39 100644 --- a/universalClient/chains/evm/event_parser.go +++ b/universalClient/chains/evm/event_parser.go @@ -75,7 +75,7 @@ func parseSendFundsEvent(log *types.Log, chainID string, logger zerolog.Logger) EventID: eventID, BlockHeight: log.BlockNumber, Type: store.EventTypeInbound, // Gateway events from external chains are INBOUND - Status: "PENDING", + Status: store.StatusPending, ExpiryBlockHeight: 0, // 0 means no expiry } @@ -135,9 +135,9 @@ func parseOutboundObservationEvent(log *types.Log, chainID string, logger zerolo EventID: eventID, BlockHeight: log.BlockNumber, Type: store.EventTypeOutbound, // Outbound observation events - Status: "PENDING", - ConfirmationType: "STANDARD", // Use STANDARD confirmation for outbound events - ExpiryBlockHeight: 0, // 0 means no expiry + Status: store.StatusPending, + ConfirmationType: store.ConfirmationStandard, // Use STANDARD confirmation for outbound events + ExpiryBlockHeight: 0, // 0 means no expiry EventData: eventData, } @@ -248,9 +248,9 @@ func finalizeEvent(event *store.Event, payload *common.UniversalTx, logger zerol } if payload.TxType == 0 || payload.TxType == 1 { - event.ConfirmationType = "FAST" + event.ConfirmationType = store.ConfirmationFast } else { - event.ConfirmationType = "STANDARD" + event.ConfirmationType = store.ConfirmationStandard } } diff --git a/universalClient/chains/push/event_parser.go b/universalClient/chains/push/event_parser.go index 54598f9a..a009864c 100644 --- a/universalClient/chains/push/event_parser.go +++ b/universalClient/chains/push/event_parser.go @@ -96,8 +96,8 @@ func ParseEvent(event abci.Event, blockHeight uint64) (*store.Event, error) { // Set common fields parsed.BlockHeight = blockHeight - parsed.ConfirmationType = "INSTANT" // push chain is a cosmos chain ie instant finality - parsed.Status = "CONFIRMED" // push chain is a cosmos chain ie instant finality + parsed.ConfirmationType = store.ConfirmationInstant // push chain is a cosmos chain ie instant finality + parsed.Status = store.StatusConfirmed // push chain is a cosmos chain ie instant finality // Set expiry for outbound events (block seen + 400) if event.Type == EventTypeOutboundCreated { diff --git a/universalClient/chains/svm/event_confirmer.go b/universalClient/chains/svm/event_confirmer.go index 3f37fe8b..825c4918 100644 --- a/universalClient/chains/svm/event_confirmer.go +++ b/universalClient/chains/svm/event_confirmer.go @@ -12,6 +12,7 @@ import ( chaincommon "github.com/pushchain/push-chain-node/universalClient/chains/common" "github.com/pushchain/push-chain-node/universalClient/db" + "github.com/pushchain/push-chain-node/universalClient/store" ) // EventConfirmer periodically checks pending events and marks them as CONFIRMED @@ -170,7 +171,7 @@ func (ec *EventConfirmer) processPendingEvents(ctx context.Context) error { if confirmations >= requiredConfirmations { // GasFeeUsed for outbound events is already set by the event parser from the on-chain event data - rowsAffected, err := ec.chainStore.UpdateEventStatus(event.EventID, "PENDING", "CONFIRMED") + rowsAffected, err := ec.chainStore.UpdateEventStatus(event.EventID, store.StatusPending, store.StatusConfirmed) if err != nil { ec.logger.Error(). Err(err). @@ -216,12 +217,12 @@ func (ec *EventConfirmer) getTxSignatureFromEventID(eventID string) string { // getRequiredConfirmations returns the required number of confirmations based on confirmation type func (ec *EventConfirmer) getRequiredConfirmations(confirmationType string) uint64 { switch confirmationType { - case "FAST": + case store.ConfirmationFast: if ec.fastConfirmations > 0 { return ec.fastConfirmations } return 5 - case "STANDARD": + case store.ConfirmationStandard: if ec.standardConfirmations > 0 { return ec.standardConfirmations } diff --git a/universalClient/chains/svm/event_parser.go b/universalClient/chains/svm/event_parser.go index 1ec4687b..4e50c863 100644 --- a/universalClient/chains/svm/event_parser.go +++ b/universalClient/chains/svm/event_parser.go @@ -91,7 +91,7 @@ func parseSendFundsEvent(log string, signature string, slot uint64, logIndex uin EventID: eventID, BlockHeight: slot, Type: store.EventTypeInbound, // Gateway events from external chains are INBOUND - Status: "PENDING", + Status: store.StatusPending, ExpiryBlockHeight: 0, // Will be set based on confirmation type if needed } @@ -177,9 +177,9 @@ func parseOutboundObservationEvent(log string, signature string, slot uint64, lo EventID: eventID, BlockHeight: slot, Type: store.EventTypeOutbound, // Outbound observation events - Status: "PENDING", - ConfirmationType: "STANDARD", // Use STANDARD confirmation for outbound events - ExpiryBlockHeight: 0, // 0 means no expiry + Status: store.StatusPending, + ConfirmationType: store.ConfirmationStandard, // Use STANDARD confirmation for outbound events + ExpiryBlockHeight: 0, // 0 means no expiry EventData: payloadData, } @@ -221,9 +221,9 @@ func parseUniversalTxEvent(event *store.Event, decoded []byte, logIndex uint, ch // if TxType is 0 or 1, use FAST else use STANDARD if payload.TxType == 0 || payload.TxType == 1 { - event.ConfirmationType = "FAST" + event.ConfirmationType = store.ConfirmationFast } else { - event.ConfirmationType = "STANDARD" + event.ConfirmationType = store.ConfirmationStandard } } From e21d71c02e74f9f6a4a220125c2ab4d08a3de6c1 Mon Sep 17 00:00:00 2001 From: aman035 Date: Fri, 20 Mar 2026 16:42:35 +0530 Subject: [PATCH 17/28] fix: remove usage of deprecated package --- universalClient/chains/chains.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/universalClient/chains/chains.go b/universalClient/chains/chains.go index 61197681..66b6dd52 100644 --- a/universalClient/chains/chains.go +++ b/universalClient/chains/chains.go @@ -7,7 +7,6 @@ import ( "sync" "time" - pkgerrors "github.com/pkg/errors" "github.com/pushchain/push-chain-node/universalClient/chains/common" "github.com/pushchain/push-chain-node/universalClient/chains/evm" "github.com/pushchain/push-chain-node/universalClient/chains/push" @@ -418,7 +417,7 @@ func (c *Chains) getChainDB(chainID string) (*db.DB, error) { database, err := db.OpenFileDB(baseDir, dbFilename, true) if err != nil { - return nil, pkgerrors.Wrapf(err, "failed to create database for chain %s", chainID) + return nil, fmt.Errorf("failed to create database for chain %s: %w", chainID, err) } c.logger.Info(). From fc49f01b5c8322006ff183835e5a1ac38ce69391 Mon Sep 17 00:00:00 2001 From: aman035 Date: Fri, 20 Mar 2026 16:55:02 +0530 Subject: [PATCH 18/28] deprecated package --- .../tss/coordinator/coordinator.go | 80 ++++++------- universalClient/tss/dkls/sign.go | 11 +- universalClient/tss/eventstore/store.go | 21 ++-- universalClient/tss/expirysweeper/sweeper.go | 8 +- .../tss/sessionmanager/sessionmanager.go | 108 +++++++++--------- universalClient/tss/tss.go | 15 ++- .../tss/txbroadcaster/broadcaster.go | 8 +- universalClient/tss/txresolver/resolver.go | 8 +- 8 files changed, 129 insertions(+), 130 deletions(-) diff --git a/universalClient/tss/coordinator/coordinator.go b/universalClient/tss/coordinator/coordinator.go index 2af712bd..1f522335 100644 --- a/universalClient/tss/coordinator/coordinator.go +++ b/universalClient/tss/coordinator/coordinator.go @@ -4,6 +4,7 @@ import ( "context" "encoding/hex" "encoding/json" + "fmt" "sort" "strings" "sync" @@ -11,7 +12,6 @@ import ( "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto/secp256k1" - "github.com/pkg/errors" "github.com/rs/zerolog" session "go-wrapper/go-dkls/sessions" @@ -126,7 +126,7 @@ func (c *Coordinator) GetPartyIDFromPeerID(ctx context.Context, peerID string) ( } } - return "", errors.Errorf("peerID %s not found in validators", peerID) + return "", fmt.Errorf("peerID %s not found in validators", peerID) } // GetPeerIDFromPartyID gets the peerID for a given partyID (validator address). @@ -153,7 +153,7 @@ func (c *Coordinator) GetPeerIDFromPartyID(ctx context.Context, partyID string) } } - return "", errors.Errorf("partyID %s not found in validators", partyID) + return "", fmt.Errorf("partyID %s not found in validators", partyID) } // GetMultiAddrsFromPeerID gets the multiaddrs for a given peerID. @@ -178,7 +178,7 @@ func (c *Coordinator) GetMultiAddrsFromPeerID(ctx context.Context, peerID string } } - return nil, errors.Errorf("peerID %s not found in validators", peerID) + return nil, fmt.Errorf("peerID %s not found in validators", peerID) } // GetLatestBlockNum gets the latest block number from pushCore. @@ -191,7 +191,7 @@ func (c *Coordinator) GetLatestBlockNum(ctx context.Context) (uint64, error) { func (c *Coordinator) IsPeerCoordinator(ctx context.Context, peerID string) (bool, error) { currentBlock, err := c.pushCore.GetLatestBlock(ctx) if err != nil { - return false, errors.Wrap(err, "failed to get latest block") + return false, fmt.Errorf("failed to get latest block: %w", err) } c.mu.RLock() @@ -235,22 +235,22 @@ func (c *Coordinator) GetCurrentTSSKey(ctx context.Context) (string, string, err func (c *Coordinator) GetTSSAddress(ctx context.Context) (string, error) { key, err := c.pushCore.GetCurrentKey(ctx) if err != nil { - return "", errors.Wrap(err, "failed to get current TSS key") + return "", fmt.Errorf("failed to get current TSS key: %w", err) } if key == nil || key.TssPubkey == "" { - return "", errors.New("no TSS key found") + return "", fmt.Errorf("no TSS key found") } pubkeyHex := strings.TrimPrefix(strings.TrimSpace(key.TssPubkey), "0x") pubkeyBytes, err := hex.DecodeString(pubkeyHex) if err != nil { - return "", errors.Wrap(err, "failed to decode TSS public key") + return "", fmt.Errorf("failed to decode TSS public key: %w", err) } if len(pubkeyBytes) != 33 { - return "", errors.Errorf("invalid TSS public key length: %d bytes (expected 33)", len(pubkeyBytes)) + return "", fmt.Errorf("invalid TSS public key length: %d bytes (expected 33)", len(pubkeyBytes)) } vkX, vkY := secp256k1.DecompressPubkey(pubkeyBytes) if vkX == nil || vkY == nil { - return "", errors.New("failed to decompress TSS public key") + return "", fmt.Errorf("failed to decompress TSS public key") } xBytes := vkX.FillBytes(make([]byte, 32)) yBytes := vkY.FillBytes(make([]byte, 32)) @@ -354,7 +354,7 @@ func (c *Coordinator) updateValidators(ctx context.Context) { func (c *Coordinator) processConfirmedEvents(ctx context.Context) error { currentBlock, err := c.pushCore.GetLatestBlock(ctx) if err != nil { - return errors.Wrap(err, "failed to get latest block") + return fmt.Errorf("failed to get latest block: %w", err) } // Use cached validators (updated at polling interval) @@ -377,12 +377,12 @@ func (c *Coordinator) processConfirmedEvents(ctx context.Context) error { events, err := c.eventStore.GetNonExpiredConfirmedEvents(currentBlock, 10, 0) if err != nil { - return errors.Wrap(err, "failed to get confirmed events") + return fmt.Errorf("failed to get confirmed events: %w", err) } inFlightPerChain, err := c.getInFlightSignCountPerChain() if err != nil { - return errors.Wrap(err, "failed to get in-flight sign count per chain") + return fmt.Errorf("failed to get in-flight sign count per chain: %w", err) } c.logger.Info(). @@ -495,11 +495,11 @@ func (c *Coordinator) processEventAsCoordinator(ctx context.Context, event store case store.EventTypeSign: setupData, unsignedTxReq, err = c.createSignSetup(ctx, event.EventData, partyIDs, assignedNonce) default: - err = errors.Errorf("unknown protocol type: %s", event.Type) + err = fmt.Errorf("unknown protocol type: %s", event.Type) } if err != nil { - return errors.Wrapf(err, "failed to create setup message for event %s", event.EventID) + return fmt.Errorf("failed to create setup message for event %s: %w", event.EventID, err) } // Create and send setup message to all participants @@ -512,7 +512,7 @@ func (c *Coordinator) processEventAsCoordinator(ctx context.Context, event store } setupMsgBytes, err := json.Marshal(setupMsg) if err != nil { - return errors.Wrapf(err, "failed to marshal setup message for event %s", event.EventID) + return fmt.Errorf("failed to marshal setup message for event %s: %w", event.EventID, err) } // Initialize ACK tracking for this event @@ -575,7 +575,7 @@ func (c *Coordinator) HandleACK(ctx context.Context, senderPeerID string, eventI // Verify sender is a participant senderPartyID, err := c.GetPartyIDFromPeerID(ctx, senderPeerID) if err != nil { - return errors.Wrapf(err, "failed to get partyID for sender peerID %s", senderPeerID) + return fmt.Errorf("failed to get partyID for sender peerID %s: %w", senderPeerID, err) } isParticipant := false @@ -586,7 +586,7 @@ func (c *Coordinator) HandleACK(ctx context.Context, senderPeerID string, eventI } } if !isParticipant { - return errors.Errorf("sender %s (partyID: %s) is not a participant for event %s", senderPeerID, senderPartyID, eventID) + return fmt.Errorf("sender %s (partyID: %s) is not a participant for event %s", senderPeerID, senderPartyID, eventID) } // Mark as ACKed @@ -617,7 +617,7 @@ func (c *Coordinator) HandleACK(ctx context.Context, senderPeerID string, eventI } beginMsgBytes, err := json.Marshal(beginMsg) if err != nil { - return errors.Wrap(err, "failed to marshal begin message") + return fmt.Errorf("failed to marshal begin message: %w", err) } // Send to all participants @@ -667,7 +667,7 @@ func (c *Coordinator) createKeygenSetup(threshold int, partyIDs []string) ([]byt setupData, err := session.DklsKeygenSetupMsgNew(threshold, nil, participantIDs) if err != nil { - return nil, errors.Wrap(err, "failed to create setup") + return nil, fmt.Errorf("failed to create setup: %w", err) } return setupData, nil } @@ -680,17 +680,17 @@ func (c *Coordinator) createSignSetup(ctx context.Context, eventData []byte, par // Get current TSS keyId from pushCore key, err := c.pushCore.GetCurrentKey(ctx) if err != nil { - return nil, nil, errors.Wrap(err, "failed to get current TSS keyId") + return nil, nil, fmt.Errorf("failed to get current TSS keyId: %w", err) } if key == nil { - return nil, nil, errors.New("no TSS key exists") + return nil, nil, fmt.Errorf("no TSS key exists") } keyIDStr := key.KeyId // Load keyshare to ensure it exists (validation) keyshareBytes, err := c.keyshareManager.Get(keyIDStr) if err != nil { - return nil, nil, errors.Wrapf(err, "failed to load keyshare for keyId %s", keyIDStr) + return nil, nil, fmt.Errorf("failed to load keyshare for keyId %s: %w", keyIDStr, err) } _ = keyshareBytes // Keyshare is loaded for validation, keyID is derived from string @@ -709,12 +709,12 @@ func (c *Coordinator) createSignSetup(ctx context.Context, eventData []byte, par // Build the transaction and get signing parameters (use coordinator-assigned nonce when provided) signingReq, err := c.buildSignTransaction(ctx, eventData, assignedNonce) if err != nil { - return nil, nil, errors.Wrap(err, "failed to build sign transaction") + return nil, nil, fmt.Errorf("failed to build sign transaction: %w", err) } setupData, err := session.DklsSignSetupMsgNew(keyIDBytes, nil, signingReq.SigningHash, participantIDs) if err != nil { - return nil, nil, errors.Wrap(err, "failed to create sign setup") + return nil, nil, fmt.Errorf("failed to create sign setup: %w", err) } return setupData, signingReq, nil @@ -723,45 +723,45 @@ func (c *Coordinator) createSignSetup(ctx context.Context, eventData []byte, par // buildSignTransaction builds the outbound transaction using the appropriate OutboundTxBuilder. func (c *Coordinator) buildSignTransaction(ctx context.Context, eventData []byte, assignedNonce *uint64) (*common.UnSignedOutboundTxReq, error) { if len(eventData) == 0 { - return nil, errors.New("event data is empty") + return nil, fmt.Errorf("event data is empty") } var data uexecutortypes.OutboundCreatedEvent if err := json.Unmarshal(eventData, &data); err != nil { - return nil, errors.Wrap(err, "failed to unmarshal outbound event data") + return nil, fmt.Errorf("failed to unmarshal outbound event data: %w", err) } if data.TxID == "" { - return nil, errors.New("outbound event missing tx_id") + return nil, fmt.Errorf("outbound event missing tx_id") } if data.DestinationChain == "" { - return nil, errors.New("outbound event missing destination_chain") + return nil, fmt.Errorf("outbound event missing destination_chain") } if c.chains == nil { - return nil, errors.New("chains manager not configured") + return nil, fmt.Errorf("chains manager not configured") } // Get the client for the destination chain client, err := c.chains.GetClient(data.DestinationChain) if err != nil { - return nil, errors.Wrapf(err, "failed to get client for chain %s", data.DestinationChain) + return nil, fmt.Errorf("failed to get client for chain %s: %w", data.DestinationChain, err) } // Get the builder from the client builder, err := client.GetTxBuilder() if err != nil { - return nil, errors.Wrapf(err, "failed to get tx builder for chain %s", data.DestinationChain) + return nil, fmt.Errorf("failed to get tx builder for chain %s: %w", data.DestinationChain, err) } // Get the signing request (nonce is required for SIGN) if assignedNonce == nil { - return nil, errors.New("assigned nonce is required for sign transaction") + return nil, fmt.Errorf("assigned nonce is required for sign transaction") } signingReq, err := builder.GetOutboundSigningRequest(ctx, &data, *assignedNonce) if err != nil { - return nil, errors.Wrap(err, "failed to get outbound signing request") + return nil, fmt.Errorf("failed to get outbound signing request: %w", err) } return signingReq, nil @@ -777,23 +777,23 @@ func (c *Coordinator) createQcSetup(ctx context.Context, threshold int, partyIDs // Get current TSS keyId from pushCore key, err := c.pushCore.GetCurrentKey(ctx) if err != nil { - return nil, errors.Wrap(err, "failed to get current TSS keyId") + return nil, fmt.Errorf("failed to get current TSS keyId: %w", err) } if key == nil { - return nil, errors.New("no TSS key exists") + return nil, fmt.Errorf("no TSS key exists") } keyIDStr := key.KeyId // Load old keyshare to get the key we're changing oldKeyshareBytes, err := c.keyshareManager.Get(keyIDStr) if err != nil { - return nil, errors.Wrapf(err, "failed to load keyshare for keyId %s", keyIDStr) + return nil, fmt.Errorf("failed to load keyshare for keyId %s: %w", keyIDStr, err) } // Load keyshare handle from bytes oldKeyshareHandle, err := session.DklsKeyshareFromBytes(oldKeyshareBytes) if err != nil { - return nil, errors.Wrap(err, "failed to load keyshare handle") + return nil, fmt.Errorf("failed to load keyshare handle: %w", err) } defer session.DklsKeyshareFree(oldKeyshareHandle) @@ -831,7 +831,7 @@ func (c *Coordinator) createQcSetup(ctx context.Context, threshold int, partyIDs setupData, err := session.DklsQcSetupMsgNew(oldKeyshareHandle, threshold, partyIDs, oldParticipantIndices, newParticipantIndices) if err != nil { - return nil, errors.Wrap(err, "failed to create quorumchange setup") + return nil, fmt.Errorf("failed to create quorumchange setup: %w", err) } return setupData, nil } @@ -1039,7 +1039,7 @@ func (c *Coordinator) assignSignNonce( // useFinalized: when true, uses finalized nonce (stuck nonce recovery); otherwise uses pending. func (c *Coordinator) getNextNonceForChain(ctx context.Context, chain string, useFinalized bool) (uint64, error) { if c.chains == nil { - return 0, errors.New("chains manager not configured") + return 0, fmt.Errorf("chains manager not configured") } client, err := c.chains.GetClient(chain) if err != nil { diff --git a/universalClient/tss/dkls/sign.go b/universalClient/tss/dkls/sign.go index 8e098186..d706d6b4 100644 --- a/universalClient/tss/dkls/sign.go +++ b/universalClient/tss/dkls/sign.go @@ -6,7 +6,6 @@ import ( "math/big" "github.com/ethereum/go-ethereum/crypto/secp256k1" - "github.com/pkg/errors" session "go-wrapper/go-dkls/sessions" ) @@ -168,7 +167,7 @@ func (s *signSession) GetResult() (*Result, error) { return nil, fmt.Errorf("signature verification error: %w", verifyErr) } if !verified { - return nil, errors.New("signature verification failed") + return nil, fmt.Errorf("signature verification failed") } // Return participants list (copy to avoid mutation) @@ -189,13 +188,13 @@ func (s *signSession) GetResult() (*Result, error) { // messageHash: SHA256 hash of the message (32 bytes) func (s *signSession) verifySignature(publicKey, signature, messageHash []byte) (bool, error) { if len(publicKey) != 33 { - return false, errors.Errorf("public key must be 33 bytes (compressed), got %d bytes", len(publicKey)) + return false, fmt.Errorf("public key must be 33 bytes (compressed), got %d bytes", len(publicKey)) } if len(signature) != 64 && len(signature) != 65 { - return false, errors.Errorf("signature must be 64 or 65 bytes (r || s [|| recovery_id]), got %d bytes", len(signature)) + return false, fmt.Errorf("signature must be 64 or 65 bytes (r || s [|| recovery_id]), got %d bytes", len(signature)) } if len(messageHash) != 32 { - return false, errors.Errorf("message hash must be 32 bytes, got %d bytes", len(messageHash)) + return false, fmt.Errorf("message hash must be 32 bytes, got %d bytes", len(messageHash)) } // Use only first 64 bytes (r || s), ignore recovery ID if present @@ -207,7 +206,7 @@ func (s *signSession) verifySignature(publicKey, signature, messageHash []byte) // Decompress public key vkX, vkY := secp256k1.DecompressPubkey(publicKey) if vkX == nil || vkY == nil { - return false, errors.New("failed to decompress public key") + return false, fmt.Errorf("failed to decompress public key") } // Create ECDSA public key diff --git a/universalClient/tss/eventstore/store.go b/universalClient/tss/eventstore/store.go index 13b816cc..a8c96830 100644 --- a/universalClient/tss/eventstore/store.go +++ b/universalClient/tss/eventstore/store.go @@ -1,7 +1,8 @@ package eventstore import ( - "github.com/pkg/errors" + "fmt" + "github.com/rs/zerolog" "gorm.io/gorm" @@ -44,10 +45,10 @@ func (s *Store) Update(eventID string, fields map[string]any) error { Where("event_id = ?", eventID). Updates(fields) if result.Error != nil { - return errors.Wrapf(result.Error, "failed to update event %s", eventID) + return fmt.Errorf("failed to update event %s: %w", eventID, result.Error) } if result.RowsAffected == 0 { - return errors.Errorf("event %s not found", eventID) + return fmt.Errorf("event %s not found", eventID) } return nil } @@ -57,7 +58,7 @@ func (s *Store) Update(eventID string, fields map[string]any) error { func (s *Store) CountInProgress() (int64, error) { var count int64 if err := s.db.Model(&store.Event{}).Where("status = ?", store.StatusInProgress).Count(&count).Error; err != nil { - return 0, errors.Wrap(err, "failed to count IN_PROGRESS events") + return 0, fmt.Errorf("failed to count IN_PROGRESS events: %w", err) } return count, nil } @@ -69,7 +70,7 @@ func (s *Store) ResetInProgressEventsToConfirmed() (int64, error) { Where("status = ?", store.StatusInProgress). Update("status", store.StatusConfirmed) if result.Error != nil { - return 0, errors.Wrap(result.Error, "failed to reset IN_PROGRESS events to CONFIRMED") + return 0, fmt.Errorf("failed to reset IN_PROGRESS events to CONFIRMED: %w", result.Error) } return result.RowsAffected, nil } @@ -91,7 +92,7 @@ func (s *Store) GetNonExpiredConfirmedEvents(currentBlock, minBlockConfirmation var events []store.Event if err := query.Find(&events).Error; err != nil { - return nil, errors.Wrap(err, "failed to query confirmed events") + return nil, fmt.Errorf("failed to query confirmed events: %w", err) } return events, nil } @@ -105,7 +106,7 @@ func (s *Store) GetInFlightSignEvents() ([]store.Event, error) { if err := s.db.Where("type = ? AND status IN (?, ?)", store.EventTypeSign, store.StatusInProgress, store.StatusSigned). Find(&events).Error; err != nil { - return nil, errors.Wrap(err, "failed to query in-flight sign events") + return nil, fmt.Errorf("failed to query in-flight sign events: %w", err) } return events, nil } @@ -120,7 +121,7 @@ func (s *Store) GetSignedSignEvents(limit int) ([]store.Event, error) { Order("block_height ASC, created_at ASC"). Limit(limit). Find(&events).Error; err != nil { - return nil, errors.Wrap(err, "failed to query signed sign events") + return nil, fmt.Errorf("failed to query signed sign events: %w", err) } return events, nil } @@ -135,7 +136,7 @@ func (s *Store) GetBroadcastedSignEvents(limit int) ([]store.Event, error) { Order("block_height ASC, created_at ASC"). Limit(limit). Find(&events).Error; err != nil { - return nil, errors.Wrap(err, "failed to query broadcasted sign events") + return nil, fmt.Errorf("failed to query broadcasted sign events: %w", err) } return events, nil } @@ -151,7 +152,7 @@ func (s *Store) GetExpiredConfirmedEvents(currentBlock uint64, limit int) ([]sto var events []store.Event if err := query.Find(&events).Error; err != nil { - return nil, errors.Wrap(err, "failed to query expired confirmed events") + return nil, fmt.Errorf("failed to query expired confirmed events: %w", err) } return events, nil } diff --git a/universalClient/tss/expirysweeper/sweeper.go b/universalClient/tss/expirysweeper/sweeper.go index 66b10d2b..4d72c77c 100644 --- a/universalClient/tss/expirysweeper/sweeper.go +++ b/universalClient/tss/expirysweeper/sweeper.go @@ -3,9 +3,9 @@ package expirysweeper import ( "context" "encoding/json" + "fmt" "time" - "github.com/pkg/errors" "github.com/rs/zerolog" uexecutortypes "github.com/pushchain/push-chain-node/x/uexecutor/types" @@ -119,7 +119,7 @@ func (s *Sweeper) sweep(ctx context.Context) { func (s *Sweeper) voteFailureAndMarkReverted(ctx context.Context, event *store.Event, errorMsg string) error { var data uexecutortypes.OutboundCreatedEvent if err := json.Unmarshal(event.EventData, &data); err != nil { - return errors.Wrapf(err, "failed to parse outbound event data for event %s", event.EventID) + return fmt.Errorf("failed to parse outbound event data for event %s: %w", event.EventID, err) } fields := map[string]any{"status": store.StatusReverted} @@ -135,13 +135,13 @@ func (s *Sweeper) voteFailureAndMarkReverted(ctx context.Context, event *store.E } voteTxHash, err := s.pushSigner.VoteOutbound(ctx, data.TxID, data.UniversalTxId, observation) if err != nil { - return errors.Wrapf(err, "failed to vote failure for event %s", event.EventID) + return fmt.Errorf("failed to vote failure for event %s: %w", event.EventID, err) } fields["vote_tx_hash"] = voteTxHash } if err := s.eventStore.Update(event.EventID, fields); err != nil { - return errors.Wrapf(err, "failed to mark event %s as reverted", event.EventID) + return fmt.Errorf("failed to mark event %s as reverted: %w", event.EventID, err) } s.logger.Info(). Str("event_id", event.EventID). diff --git a/universalClient/tss/sessionmanager/sessionmanager.go b/universalClient/tss/sessionmanager/sessionmanager.go index f1938bef..7e8e4cdb 100644 --- a/universalClient/tss/sessionmanager/sessionmanager.go +++ b/universalClient/tss/sessionmanager/sessionmanager.go @@ -6,11 +6,11 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "fmt" "strconv" "sync" "time" - "github.com/pkg/errors" "github.com/rs/zerolog" "github.com/pushchain/push-chain-node/universalClient/chains" @@ -103,7 +103,7 @@ func (sm *SessionManager) HandleIncomingMessage(ctx context.Context, peerID stri // Unmarshal message var msg coordinator.Message if err := json.Unmarshal(data, &msg); err != nil { - return errors.Wrap(err, "failed to unmarshal message") + return fmt.Errorf("failed to unmarshal message: %w", err) } sm.logger.Debug(). @@ -122,7 +122,7 @@ func (sm *SessionManager) HandleIncomingMessage(ctx context.Context, peerID stri case "step": return sm.handleStepMessage(ctx, peerID, &msg) default: - return errors.Errorf("unknown message type: %s", msg.Type) + return fmt.Errorf("unknown message type: %s", msg.Type) } } @@ -140,46 +140,46 @@ func (sm *SessionManager) handleSetupMessage(ctx context.Context, senderPeerID s // 2. Validate event exists in DB event, err := sm.eventStore.GetEvent(msg.EventID) if err != nil { - return errors.Wrapf(err, "event %s not found in database", msg.EventID) + return fmt.Errorf("event %s not found in database: %w", msg.EventID, err) } // 3. Validate event is CONFIRMED and not expired if event.Status != store.StatusConfirmed { - return errors.Errorf("event %s is not in confirmed status (got %s)", msg.EventID, event.Status) + return fmt.Errorf("event %s is not in confirmed status (got %s)", msg.EventID, event.Status) } currentBlock, err := sm.coordinator.GetLatestBlockNum(ctx) if err != nil { - return errors.Wrap(err, "failed to get current block for setup validation") + return fmt.Errorf("failed to get current block for setup validation: %w", err) } if event.ExpiryBlockHeight > 0 && event.ExpiryBlockHeight <= currentBlock { - return errors.Errorf("event %s has expired (expiry_block_height %d <= current_block %d)", msg.EventID, event.ExpiryBlockHeight, currentBlock) + return fmt.Errorf("event %s has expired (expiry_block_height %d <= current_block %d)", msg.EventID, event.ExpiryBlockHeight, currentBlock) } // 4. Validate sender is coordinator isCoord, err := sm.coordinator.IsPeerCoordinator(ctx, senderPeerID) if err != nil { - return errors.Wrap(err, "failed to check if sender is coordinator") + return fmt.Errorf("failed to check if sender is coordinator: %w", err) } if !isCoord { - return errors.Errorf("sender %s is not the coordinator", senderPeerID) + return fmt.Errorf("sender %s is not the coordinator", senderPeerID) } // 5. Validate participants list matches event protocol requirements if err := sm.validateParticipants(msg.Participants, event); err != nil { - return errors.Wrap(err, "participants validation failed") + return fmt.Errorf("participants validation failed: %w", err) } // 6. For SIGN events, verify the signing hash independently if event.Type == store.EventTypeSign { if err := sm.verifySigningRequest(ctx, event, msg.UnSignedOutboundTxReq); err != nil { - return errors.Wrap(err, "signing request verification failed") + return fmt.Errorf("signing request verification failed: %w", err) } } // 7. Create session based on protocol type session, err := sm.createSession(ctx, event, msg) if err != nil { - return errors.Wrapf(err, "failed to create session for event %s", msg.EventID) + return fmt.Errorf("failed to create session for event %s: %w", msg.EventID, err) } // 8. Store session state @@ -225,7 +225,7 @@ func (sm *SessionManager) handleStepMessage(ctx context.Context, senderPeerID st sm.mu.RUnlock() if !exists { - return errors.Errorf("session for event %s does not exist", msg.EventID) + return fmt.Errorf("session for event %s does not exist", msg.EventID) } session := state.session @@ -237,7 +237,7 @@ func (sm *SessionManager) handleStepMessage(ctx context.Context, senderPeerID st // Get sender's validator address from peerID senderPartyID, err := sm.coordinator.GetPartyIDFromPeerID(ctx, senderPeerID) if err != nil { - return errors.Wrapf(err, "failed to get partyID for sender peerID %s", senderPeerID) + return fmt.Errorf("failed to get partyID for sender peerID %s: %w", senderPeerID, err) } // Check if sender is in participants @@ -249,12 +249,12 @@ func (sm *SessionManager) handleStepMessage(ctx context.Context, senderPeerID st } } if !isParticipant { - return errors.Errorf("sender %s (partyID: %s) is not in session participants for event %s", senderPeerID, senderPartyID, msg.EventID) + return fmt.Errorf("sender %s (partyID: %s) is not in session participants for event %s", senderPeerID, senderPartyID, msg.EventID) } // 3. Route message to session if err := session.InputMessage(msg.Payload); err != nil { - return errors.Wrapf(err, "failed to input message to session %s", msg.EventID) + return fmt.Errorf("failed to input message to session %s: %w", msg.EventID, err) } // 4. Process step @@ -268,7 +268,7 @@ func (sm *SessionManager) processSessionStep(ctx context.Context, eventID string sm.mu.RUnlock() if !exists { - return errors.Errorf("session for event %s does not exist", eventID) + return fmt.Errorf("session for event %s does not exist", eventID) } session := state.session @@ -279,7 +279,7 @@ func (sm *SessionManager) processSessionStep(ctx context.Context, eventID string state.stepMu.Unlock() if err != nil { - return errors.Wrapf(err, "failed to step session %s", eventID) + return fmt.Errorf("failed to step session %s: %w", eventID, err) } // Send output messages @@ -350,12 +350,12 @@ func (sm *SessionManager) handleBeginMessage(ctx context.Context, senderPeerID s sm.mu.RUnlock() if !exists { - return errors.Errorf("session for event %s does not exist", msg.EventID) + return fmt.Errorf("session for event %s does not exist", msg.EventID) } // 2. Validate sender is the coordinator for this session if senderPeerID != state.coordinator { - return errors.Errorf("begin message must come from coordinator %s, but received from %s", state.coordinator, senderPeerID) + return fmt.Errorf("begin message must come from coordinator %s, but received from %s", state.coordinator, senderPeerID) } sm.logger.Info(). @@ -377,11 +377,11 @@ func (sm *SessionManager) sendACK(ctx context.Context, coordinatorPeerID string, } msgBytes, err := json.Marshal(ackMsg) if err != nil { - return errors.Wrap(err, "failed to marshal ACK message") + return fmt.Errorf("failed to marshal ACK message: %w", err) } if err := sm.send(ctx, coordinatorPeerID, msgBytes); err != nil { - return errors.Wrap(err, "failed to send ACK message") + return fmt.Errorf("failed to send ACK message: %w", err) } sm.logger.Debug(). @@ -399,7 +399,7 @@ func (sm *SessionManager) handleSessionFinished(ctx context.Context, eventID str result, err := state.session.GetResult() if err != nil { - return errors.Wrapf(err, "failed to get result for session %s", eventID) + return fmt.Errorf("failed to get result for session %s: %w", eventID, err) } // SIGN sessions: broadcast the signed tx, then done (status managed by handleSigningComplete) @@ -421,7 +421,7 @@ func (sm *SessionManager) handleSignFinished(ctx context.Context, eventID string event, err := sm.eventStore.GetEvent(eventID) if err != nil { - return errors.Wrapf(err, "failed to get event %s for broadcasting", eventID) + return fmt.Errorf("failed to get event %s for broadcasting: %w", eventID, err) } if err := sm.handleSigningComplete(ctx, eventID, event.EventData, result.Signature, signingReq); err != nil { @@ -442,7 +442,7 @@ func (sm *SessionManager) handleKeyFinished(ctx context.Context, eventID, protoc // Store keyshare if err := sm.keyshareManager.Store(result.Keyshare, storageID); err != nil { - return errors.Wrapf(err, "failed to store keyshare for event %s", eventID) + return fmt.Errorf("failed to store keyshare for event %s: %w", eventID, err) } keyshareHash := sha256.Sum256(result.Keyshare) @@ -461,7 +461,7 @@ func (sm *SessionManager) handleKeyFinished(ctx context.Context, eventID, protoc processID, err := strconv.ParseUint(eventID, 10, 64) if err != nil { - return errors.Wrapf(err, "failed to parse process id from %s", eventID) + return fmt.Errorf("failed to parse process id from %s: %w", eventID, err) } voteTxHash, err = sm.pushSigner.VoteTssKeyProcess(ctx, pubKeyHex, storageID, processID) @@ -470,14 +470,14 @@ func (sm *SessionManager) handleKeyFinished(ctx context.Context, eventID, protoc if updateErr := sm.eventStore.Update(eventID, map[string]any{"status": store.StatusReverted}); updateErr != nil { sm.logger.Error().Err(updateErr).Str("event_id", eventID).Msg("failed to mark event as REVERTED") } - return errors.Wrapf(err, "TSS vote failed for event %s — marked REVERTED", eventID) + return fmt.Errorf("TSS vote failed for event %s — marked REVERTED: %w", eventID, err) } sm.logger.Info().Str("vote_tx_hash", voteTxHash).Str("event_id", eventID).Msg("TSS vote succeeded") } if err := sm.eventStore.Update(eventID, map[string]any{"status": store.StatusCompleted, "vote_tx_hash": voteTxHash}); err != nil { - return errors.Wrapf(err, "failed to update event status to completed") + return fmt.Errorf("failed to update event status to completed: %w", err) } sm.logger.Info().Str("event_id", eventID).Msg("key session finished successfully") @@ -502,13 +502,13 @@ func (sm *SessionManager) createSession(ctx context.Context, event *store.Event, // Get current keyID keyID, _, err := sm.coordinator.GetCurrentTSSKey(ctx) if err != nil { - return nil, errors.Wrap(err, "failed to get current TSS keyId") + return nil, fmt.Errorf("failed to get current TSS keyId: %w", err) } // Load old keyshare oldKeyshare, err := sm.keyshareManager.Get(keyID) if err != nil { - return nil, errors.Wrapf(err, "failed to load keyshare for keyId %s", keyID) + return nil, fmt.Errorf("failed to load keyshare for keyId %s: %w", keyID, err) } return dkls.NewKeyrefreshSession( @@ -524,7 +524,7 @@ func (sm *SessionManager) createSession(ctx context.Context, event *store.Event, // Get current keyID keyID, _, err := sm.coordinator.GetCurrentTSSKey(ctx) if err != nil { - return nil, errors.Wrap(err, "failed to get current TSS keyId for quorumchange") + return nil, fmt.Errorf("failed to get current TSS keyId for quorumchange: %w", err) } // Load old keyshare - if not found, we're a new party (oldKeyshare will be nil) @@ -541,7 +541,7 @@ func (sm *SessionManager) createSession(ctx context.Context, event *store.Event, oldKeyshare = nil } else { // Other error (decryption failed, etc.) - return error - return nil, errors.Wrapf(err, "failed to load keyshare for keyId %s", keyID) + return nil, fmt.Errorf("failed to load keyshare for keyId %s: %w", keyID, err) } } @@ -558,13 +558,13 @@ func (sm *SessionManager) createSession(ctx context.Context, event *store.Event, // Get current keyID keyID, _, err := sm.coordinator.GetCurrentTSSKey(ctx) if err != nil { - return nil, errors.Wrap(err, "failed to get current TSS keyId") + return nil, fmt.Errorf("failed to get current TSS keyId: %w", err) } // Load keyshare keyshareBytes, err := sm.keyshareManager.Get(keyID) if err != nil { - return nil, errors.Wrapf(err, "failed to load keyshare for keyId %s", keyID) + return nil, fmt.Errorf("failed to load keyshare for keyId %s: %w", keyID, err) } return dkls.NewSignSession( @@ -578,7 +578,7 @@ func (sm *SessionManager) createSession(ctx context.Context, event *store.Event, ) default: - return nil, errors.Errorf("unknown protocol type: %s", event.Type) + return nil, fmt.Errorf("unknown protocol type: %s", event.Type) } } @@ -589,7 +589,7 @@ func (sm *SessionManager) validateParticipants(participants []string, event *sto // Get eligible validators for this protocol eligible := sm.coordinator.GetEligibleUV(string(event.Type)) if len(eligible) == 0 { - return errors.New("no eligible validators for protocol") + return fmt.Errorf("no eligible validators for protocol") } // Build set and list of eligible partyIDs @@ -607,7 +607,7 @@ func (sm *SessionManager) validateParticipants(participants []string, event *sto participantSet := make(map[string]bool) for _, partyID := range participants { if !eligibleSet[partyID] { - return errors.Errorf("participant %s is not eligible for protocol %s", partyID, event.Type) + return fmt.Errorf("participant %s is not eligible for protocol %s", partyID, event.Type) } participantSet[partyID] = true } @@ -617,12 +617,12 @@ func (sm *SessionManager) validateParticipants(participants []string, event *sto case store.EventTypeKeygen, store.EventTypeKeyrefresh, store.EventTypeQuorumChange: // For keygen, keyrefresh, and quorumchange: participants must match exactly with eligible participants if len(participants) != len(eligibleList) { - return errors.Errorf("participants count %d does not match eligible count %d for %s", len(participants), len(eligibleList), event.Type) + return fmt.Errorf("participants count %d does not match eligible count %d for %s", len(participants), len(eligibleList), event.Type) } // Check all eligible are in participants for _, eligibleID := range eligibleList { if !participantSet[eligibleID] { - return errors.Errorf("eligible participant %s is missing from participants list for %s", eligibleID, event.Type) + return fmt.Errorf("eligible participant %s is missing from participants list for %s", eligibleID, event.Type) } } @@ -632,12 +632,12 @@ func (sm *SessionManager) validateParticipants(participants []string, event *sto // all participants are already verified eligible by the eligibleSet check above. threshold := coordinator.CalculateThreshold(len(eligibleList)) if len(participants) < threshold { - return errors.Errorf("SIGN participants count %d is below required threshold %d (eligible: %d)", + return fmt.Errorf("SIGN participants count %d is below required threshold %d (eligible: %d)", len(participants), threshold, len(eligibleList)) } default: - return errors.Errorf("unknown protocol type: %s", event.Type) + return fmt.Errorf("unknown protocol type: %s", event.Type) } return nil @@ -721,27 +721,27 @@ func (sm *SessionManager) checkExpiredSessions(ctx context.Context, blockDelay u // verifySigningRequest validates the coordinator's signing request: hash verification (coordinator nonce is source of truth). func (sm *SessionManager) verifySigningRequest(ctx context.Context, event *store.Event, req *common.UnSignedOutboundTxReq) error { if req == nil { - return errors.New("unsigned transaction request is required for SIGN events") + return fmt.Errorf("unsigned transaction request is required for SIGN events") } if len(req.SigningHash) == 0 { - return errors.New("signing hash is missing in request") + return fmt.Errorf("signing hash is missing in request") } // Parse the event data to get outbound transaction details var outboundData uexecutortypes.OutboundCreatedEvent if err := json.Unmarshal(event.EventData, &outboundData); err != nil { - return errors.Wrap(err, "failed to parse outbound event data") + return fmt.Errorf("failed to parse outbound event data: %w", err) } chainID := outboundData.DestinationChain if chainID == "" { - return errors.New("destination chain is missing") + return fmt.Errorf("destination chain is missing") } // Reject signing if outbound is disabled for the destination chain if sm.chains != nil && !sm.chains.IsChainOutboundEnabled(chainID) { - return errors.Errorf("outbound disabled for chain %s, refusing to sign", chainID) + return fmt.Errorf("outbound disabled for chain %s, refusing to sign", chainID) } // Build with coordinator's nonce and compare hash @@ -770,14 +770,14 @@ func (sm *SessionManager) verifySigningRequest(ctx context.Context, event *store } else if finalizedNonce, nonceErr := builder.GetNextNonce(ctx, tssAddr, true /* useFinalized */); nonceErr != nil { sm.logger.Warn().Err(nonceErr).Str("chain", chainID).Msg("cannot get finalized nonce for check, skipping") } else if req.Nonce < finalizedNonce { - return errors.Errorf("coordinator assigned nonce %d is below chain finalized nonce %d for %s — nonce already used on chain", + return fmt.Errorf("coordinator assigned nonce %d is below chain finalized nonce %d for %s — nonce already used on chain", req.Nonce, finalizedNonce, chainID) } // Use coordinator's nonce so our computed hash matches signingReq, err := builder.GetOutboundSigningRequest(ctx, &outboundData, req.Nonce) if err != nil { - return errors.Wrap(err, "failed to get signing request for verification") + return fmt.Errorf("failed to get signing request for verification: %w", err) } // Compare hashes - must match exactly @@ -787,7 +787,7 @@ func (sm *SessionManager) verifySigningRequest(ctx context.Context, event *store Str("coordinator_hash", hex.EncodeToString(req.SigningHash)). Str("event_id", event.EventID). Msg("signing hash mismatch - rejecting signing request") - return errors.New("signing hash mismatch: our computed hash does not match coordinator's hash") + return fmt.Errorf("signing hash mismatch: our computed hash does not match coordinator's hash") } sm.logger.Debug(). @@ -803,7 +803,7 @@ func (sm *SessionManager) verifySigningRequest(ctx context.Context, event *store // The TSS address is always the same ECDSA address derived from the TSS public key func (sm *SessionManager) getTSSAddress(ctx context.Context) (string, error) { if sm.coordinator == nil { - return "", errors.New("coordinator not configured") + return "", fmt.Errorf("coordinator not configured") } return sm.coordinator.GetTSSAddress(ctx) } @@ -812,7 +812,7 @@ func (sm *SessionManager) getTSSAddress(ctx context.Context) (string, error) { // signingReq is the cached signing request from the coordinator setup message. func (sm *SessionManager) handleSigningComplete(_ context.Context, eventID string, eventData []byte, signature []byte, signingReq *common.UnSignedOutboundTxReq) error { if signingReq == nil { - return errors.New("signing request is nil - cannot persist signing data") + return fmt.Errorf("signing request is nil - cannot persist signing data") } // Build signing_data to persist alongside the original event data @@ -825,13 +825,13 @@ func (sm *SessionManager) handleSigningComplete(_ context.Context, eventID strin // Unmarshal original event data, add signing_data, re-marshal var raw map[string]any if err := json.Unmarshal(eventData, &raw); err != nil { - return errors.Wrap(err, "failed to parse event data for signing_data injection") + return fmt.Errorf("failed to parse event data for signing_data injection: %w", err) } raw["signing_data"] = signingData newEventData, err := json.Marshal(raw) if err != nil { - return errors.Wrap(err, "failed to marshal event data with signing_data") + return fmt.Errorf("failed to marshal event data with signing_data: %w", err) } // Persist enriched event data + mark SIGNED; txBroadcaster will pick it up @@ -839,7 +839,7 @@ func (sm *SessionManager) handleSigningComplete(_ context.Context, eventID strin "event_data": newEventData, "status": store.StatusSigned, }); err != nil { - return errors.Wrap(err, "failed to update event with signing data") + return fmt.Errorf("failed to update event with signing data: %w", err) } sm.logger.Info(). diff --git a/universalClient/tss/tss.go b/universalClient/tss/tss.go index 4629f3fe..be6911a3 100644 --- a/universalClient/tss/tss.go +++ b/universalClient/tss/tss.go @@ -12,7 +12,6 @@ import ( "time" "github.com/libp2p/go-libp2p/core/crypto" - "github.com/pkg/errors" "github.com/rs/zerolog" "github.com/pushchain/push-chain-node/universalClient/chains" @@ -275,7 +274,7 @@ func (n *Node) Start(ctx context.Context) error { n.mu.Lock() if n.running { n.mu.Unlock() - return errors.New("node is already running") + return fmt.Errorf("node is already running") } n.running = true n.mu.Unlock() @@ -419,7 +418,7 @@ func (n *Node) Stop() error { // If the peer is not registered, it will automatically register it from validators before sending. func (n *Node) Send(ctx context.Context, peerID string, data []byte) error { if n.network == nil { - return errors.New("network not initialized") + return fmt.Errorf("network not initialized") } // If sending to self, call onReceive directly @@ -436,20 +435,20 @@ func (n *Node) Send(ctx context.Context, peerID string, data []byte) error { // If not registered, register it using coordinator if !isRegistered { if n.coordinator == nil { - return errors.New("coordinator not initialized") + return fmt.Errorf("coordinator not initialized") } multiaddrs, err := n.coordinator.GetMultiAddrsFromPeerID(ctx, peerID) if err != nil { - return errors.Wrapf(err, "failed to get multiaddrs for peer %s", peerID) + return fmt.Errorf("failed to get multiaddrs for peer %s: %w", peerID, err) } if len(multiaddrs) == 0 { - return errors.Errorf("peer %s has no addresses", peerID) + return fmt.Errorf("peer %s has no addresses", peerID) } if err := n.network.EnsurePeer(peerID, multiaddrs); err != nil { - return errors.Wrapf(err, "failed to register peer %s", peerID) + return fmt.Errorf("failed to register peer %s: %w", peerID, err) } // Mark as registered @@ -502,7 +501,7 @@ func (n *Node) onReceive(peerID string, data []byte) { // This allows coordinator to track ACKs even when it's not a participant. func (n *Node) HandleACKMessage(ctx context.Context, senderPeerID string, msg *coordinator.Message) error { if n.coordinator == nil { - return errors.New("coordinator not initialized") + return fmt.Errorf("coordinator not initialized") } // Forward ACK to coordinator for tracking diff --git a/universalClient/tss/txbroadcaster/broadcaster.go b/universalClient/tss/txbroadcaster/broadcaster.go index 24729f26..b4e3faf0 100644 --- a/universalClient/tss/txbroadcaster/broadcaster.go +++ b/universalClient/tss/txbroadcaster/broadcaster.go @@ -4,9 +4,9 @@ import ( "context" "encoding/hex" "encoding/json" + "fmt" "time" - "github.com/pkg/errors" "github.com/rs/zerolog" uexecutortypes "github.com/pushchain/push-chain-node/x/uexecutor/types" @@ -134,10 +134,10 @@ func (b *Broadcaster) broadcastEvent(ctx context.Context, event *store.Event) { func parseSignedEventData(eventData []byte) (*SignedEventData, error) { var data SignedEventData if err := json.Unmarshal(eventData, &data); err != nil { - return nil, errors.Wrap(err, "failed to unmarshal signed event data") + return nil, fmt.Errorf("failed to unmarshal signed event data: %w", err) } if data.SigningData == nil { - return nil, errors.New("signing_data missing from event data") + return nil, fmt.Errorf("signing_data missing from event data") } return &data, nil } @@ -160,7 +160,7 @@ func (b *Broadcaster) markBroadcasted(event *store.Event, chainID, txHash string func reconstructSigningReq(sd *SigningData) (*common.UnSignedOutboundTxReq, error) { signingHash, err := hex.DecodeString(sd.SigningHash) if err != nil { - return nil, errors.Wrap(err, "failed to decode signing hash") + return nil, fmt.Errorf("failed to decode signing hash: %w", err) } return &common.UnSignedOutboundTxReq{ diff --git a/universalClient/tss/txresolver/resolver.go b/universalClient/tss/txresolver/resolver.go index 1416225e..fccd1df1 100644 --- a/universalClient/tss/txresolver/resolver.go +++ b/universalClient/tss/txresolver/resolver.go @@ -3,10 +3,10 @@ package txresolver import ( "context" "encoding/json" + "fmt" "strings" "time" - "github.com/pkg/errors" "github.com/rs/zerolog" uexecutortypes "github.com/pushchain/push-chain-node/x/uexecutor/types" @@ -154,7 +154,7 @@ func (r *Resolver) voteFailureAndMarkReverted(ctx context.Context, event *store. return err } if err := r.eventStore.Update(event.EventID, map[string]any{"status": store.StatusReverted, "vote_tx_hash": voteTxHash}); err != nil { - return errors.Wrapf(err, "failed to mark event %s as reverted", event.EventID) + return fmt.Errorf("failed to mark event %s as reverted: %w", event.EventID, err) } r.logger.Info(). Str("event_id", event.EventID).Str("tx_id", txID). @@ -165,7 +165,7 @@ func (r *Resolver) voteFailureAndMarkReverted(ctx context.Context, event *store. func extractOutboundIDs(event *store.Event) (txID, utxID string, err error) { var data uexecutortypes.OutboundCreatedEvent if err := json.Unmarshal(event.EventData, &data); err != nil { - return "", "", errors.Wrap(err, "failed to parse outbound event data") + return "", "", fmt.Errorf("failed to parse outbound event data: %w", err) } return data.TxID, data.UniversalTxId, nil } @@ -173,7 +173,7 @@ func extractOutboundIDs(event *store.Event) (txID, utxID string, err error) { func parseCAIPTxHash(caipTxHash string) (chainID, txHash string, err error) { lastColon := strings.LastIndex(caipTxHash, ":") if lastColon <= 0 || lastColon == len(caipTxHash)-1 { - return "", "", errors.Errorf("invalid CAIP tx hash format: %s", caipTxHash) + return "", "", fmt.Errorf("invalid CAIP tx hash format: %s", caipTxHash) } return caipTxHash[:lastColon], caipTxHash[lastColon+1:], nil } From 4407b812db66d7c941546297a1fbb00959e23490 Mon Sep 17 00:00:00 2001 From: aman035 Date: Fri, 20 Mar 2026 17:11:59 +0530 Subject: [PATCH 19/28] fix: lock --- universalClient/tss/coordinator/coordinator.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/universalClient/tss/coordinator/coordinator.go b/universalClient/tss/coordinator/coordinator.go index 1f522335..ebf85d35 100644 --- a/universalClient/tss/coordinator/coordinator.go +++ b/universalClient/tss/coordinator/coordinator.go @@ -990,13 +990,10 @@ func (c *Coordinator) assignSignNonce( if inFlightPerChain[chain] > 0 { c.chainWaitMu.Lock() consecutiveWait := c.consecutiveWaitPerChain[chain] - c.chainWaitMu.Unlock() - if consecutiveWait < ConsecutiveWaitThreshold { - // Still within patience — skip chain, let in-flight events clear - c.chainWaitMu.Lock() c.consecutiveWaitPerChain[chain]++ c.chainWaitMu.Unlock() + skippedChains[chain] = true c.logger.Debug(). Str("chain", chain). @@ -1005,6 +1002,7 @@ func (c *Coordinator) assignSignNonce( Msg("skipping chain — waiting for in-flight to clear") return 0, false } + c.chainWaitMu.Unlock() // Patience exhausted — recover with finalized nonce. // Cap is intentionally bypassed: stuck events have stale nonces and will From 5c3de6f7e03da2e4209cab224662e25a630a372d Mon Sep 17 00:00:00 2001 From: aman035 Date: Fri, 20 Mar 2026 17:12:22 +0530 Subject: [PATCH 20/28] ctx fixes --- universalClient/core/client.go | 1 + universalClient/pushsigner/grant_verifier.go | 4 ++-- universalClient/pushsigner/pushsigner.go | 3 ++- universalClient/pushsigner/pushsigner_test.go | 6 +++--- universalClient/tss/tss.go | 4 +++- 5 files changed, 11 insertions(+), 7 deletions(-) diff --git a/universalClient/core/client.go b/universalClient/core/client.go index 44313231..916fc23f 100644 --- a/universalClient/core/client.go +++ b/universalClient/core/client.go @@ -53,6 +53,7 @@ func NewUniversalClient(ctx context.Context, cfg *config.Config) (*UniversalClie } pushSigner, err := pushsigner.New( + ctx, log, cfg.KeyringBackend, cfg.KeyringPassword, diff --git a/universalClient/pushsigner/grant_verifier.go b/universalClient/pushsigner/grant_verifier.go index e410b5a8..ace5eb60 100644 --- a/universalClient/pushsigner/grant_verifier.go +++ b/universalClient/pushsigner/grant_verifier.go @@ -44,7 +44,7 @@ type validationResult struct { } // ValidateKeysAndGrants validates hotkey and AuthZ grants against the specified granter. -func validateKeysAndGrants(keyringBackend config.KeyringBackend, keyringPassword string, nodeHome string, pushCore chainClient, granter string) (*validationResult, error) { +func validateKeysAndGrants(ctx context.Context, keyringBackend config.KeyringBackend, keyringPassword string, nodeHome string, pushCore chainClient, granter string) (*validationResult, error) { interfaceRegistry := keys.NewInterfaceRegistryWithEVMSupport() authz.RegisterInterfaces(interfaceRegistry) uetypes.RegisterInterfaces(interfaceRegistry) @@ -81,7 +81,7 @@ func validateKeysAndGrants(keyringBackend config.KeyringBackend, keyringPassword } keyAddrStr := keyAddr.String() - grantResp, err := pushCore.GetGranteeGrants(context.Background(), keyAddrStr) + grantResp, err := pushCore.GetGranteeGrants(ctx, keyAddrStr) if err != nil { return nil, fmt.Errorf("failed to query grants: %w", err) } diff --git a/universalClient/pushsigner/pushsigner.go b/universalClient/pushsigner/pushsigner.go index cca89fea..8cb83988 100644 --- a/universalClient/pushsigner/pushsigner.go +++ b/universalClient/pushsigner/pushsigner.go @@ -48,6 +48,7 @@ type Signer struct { // New creates a new Signer instance with validation. func New( + ctx context.Context, log zerolog.Logger, keyringBackend config.KeyringBackend, keyringPassword string, @@ -58,7 +59,7 @@ func New( ) (*Signer, error) { log.Info().Msg("Validating hotkey and AuthZ permissions...") - validationResult, err := validateKeysAndGrants(keyringBackend, keyringPassword, nodeHome, pushCore, granter) + validationResult, err := validateKeysAndGrants(ctx, keyringBackend, keyringPassword, nodeHome, pushCore, granter) if err != nil { log.Error().Err(err).Msg("PushSigner validation failed") return nil, fmt.Errorf("PushSigner validation failed: %w", err) diff --git a/universalClient/pushsigner/pushsigner_test.go b/universalClient/pushsigner/pushsigner_test.go index 4e1958b6..de55127f 100644 --- a/universalClient/pushsigner/pushsigner_test.go +++ b/universalClient/pushsigner/pushsigner_test.go @@ -126,7 +126,7 @@ func TestNew(t *testing.T) { mockCore := createMockPushCoreClient() - signer, err := New(logger, config.KeyringBackendTest, "", "", mockCore, "test-chain", "cosmos1granter") + signer, err := New(context.Background(), logger, config.KeyringBackendTest, "", "", mockCore, "test-chain", "cosmos1granter") require.Error(t, err) assert.Nil(t, signer) assert.Contains(t, err.Error(), "PushSigner validation failed") @@ -135,7 +135,7 @@ func TestNew(t *testing.T) { t.Run("validation failure - keyring creation fails", func(t *testing.T) { mockCore := createMockPushCoreClient() - signer, err := New(logger, config.KeyringBackendFile, "", "", mockCore, "test-chain", "cosmos1granter") + signer, err := New(context.Background(), logger, config.KeyringBackendFile, "", "", mockCore, "test-chain", "cosmos1granter") require.Error(t, err) assert.Nil(t, signer) assert.Contains(t, err.Error(), "keyring_password is required for file backend") @@ -154,7 +154,7 @@ func TestNew(t *testing.T) { mockCore := createMockPushCoreClient() - signer, err := New(logger, config.KeyringBackendTest, "", tempDir, mockCore, "test-chain", "cosmos1granter") + signer, err := New(context.Background(), logger, config.KeyringBackendTest, "", tempDir, mockCore, "test-chain", "cosmos1granter") require.Error(t, err) assert.Nil(t, signer) }) diff --git a/universalClient/tss/tss.go b/universalClient/tss/tss.go index be6911a3..35b4f0ee 100644 --- a/universalClient/tss/tss.go +++ b/universalClient/tss/tss.go @@ -122,6 +122,7 @@ type Node struct { pushSigner *pushsigner.Signer // Optional - nil if voting disabled // Internal state + ctx context.Context mu sync.RWMutex running bool stopCh chan struct{} @@ -277,6 +278,7 @@ func (n *Node) Start(ctx context.Context) error { return fmt.Errorf("node is already running") } n.running = true + n.ctx = ctx n.mu.Unlock() n.logger.Info().Msg("starting TSS node") @@ -469,7 +471,7 @@ func (n *Node) Send(ctx context.Context, peerID string, data []byte) error { // onReceive handles incoming messages from p2p network. // It passes raw data directly to sessionManager. func (n *Node) onReceive(peerID string, data []byte) { - ctx := context.Background() + ctx := n.ctx // Unmarshal to check message type var msg coordinator.Message From f05e2b41f719294e466e0cbd6f03d4db4f060fce Mon Sep 17 00:00:00 2001 From: aman035 Date: Fri, 20 Mar 2026 17:21:47 +0530 Subject: [PATCH 21/28] chore: server tc --- universalClient/api/server_test.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/universalClient/api/server_test.go b/universalClient/api/server_test.go index 09eeef1b..10609230 100644 --- a/universalClient/api/server_test.go +++ b/universalClient/api/server_test.go @@ -62,6 +62,31 @@ func TestServerStartStop(t *testing.T) { }) } +func TestAddrBeforeStart(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + server := NewServer(logger, 0) + + // Before Start, listener is nil — Addr returns empty + assert.Empty(t, server.Addr()) +} + +func TestStartBindFailure(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + + // Start first server on a port + s1 := NewServer(logger, 0) + require.NoError(t, s1.Start()) + defer s1.Stop() + + // Try to start second server on the same port — should fail + addr := s1.Addr() + s2 := NewServer(logger, 0) + s2.server.Addr = addr // force same address + err := s2.Start() + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to bind") +} + func TestServerIntegration(t *testing.T) { logger := zerolog.New(zerolog.NewTestWriter(t)) From 79a99d8cc5add3f5cb5b1df0fe038fbbc9955452 Mon Sep 17 00:00:00 2001 From: aman035 Date: Fri, 20 Mar 2026 17:22:01 +0530 Subject: [PATCH 22/28] chore: db tc --- universalClient/db/db_test.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/universalClient/db/db_test.go b/universalClient/db/db_test.go index 6e0561f6..63ca0d22 100644 --- a/universalClient/db/db_test.go +++ b/universalClient/db/db_test.go @@ -87,6 +87,33 @@ func TestDB_PragmaOptimizations(t *testing.T) { assert.Equal(t, 1, fkEnabled) } +func TestDB_OpenWithoutMigration(t *testing.T) { + db, err := OpenInMemoryDB(false) + require.NoError(t, err) + require.NotNil(t, db) + assert.NoError(t, db.Close()) +} + +func TestDB_FileDBExistingDirectory(t *testing.T) { + dir := t.TempDir() + // Open twice — second time directory already exists + db1, err := OpenFileDB(dir, "test1.db", true) + require.NoError(t, err) + db1.Close() + + db2, err := OpenFileDB(dir, "test2.db", true) + require.NoError(t, err) + db2.Close() +} + +func TestDB_ClientReturnsGorm(t *testing.T) { + db, err := OpenInMemoryDB(true) + require.NoError(t, err) + defer db.Close() + + assert.NotNil(t, db.Client()) +} + func TestDB_SchemaModels(t *testing.T) { models := schemaModels() assert.Len(t, models, 2) From 4b342edb31a188e590e3fd70b14eb0c1c4eaf8ac Mon Sep 17 00:00:00 2001 From: aman035 Date: Fri, 20 Mar 2026 17:22:17 +0530 Subject: [PATCH 23/28] chore: config tc --- go.mod | 2 +- universalClient/config/config_test.go | 52 +++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index c16d9f95..a31b9e9c 100755 --- a/go.mod +++ b/go.mod @@ -76,7 +76,7 @@ require ( github.com/joho/godotenv v1.5.1 github.com/mr-tron/base58 v1.2.0 github.com/near/borsh-go v0.3.1 - github.com/pkg/errors v0.9.1 + github.com/pkg/errors v0.9.1 // indirect github.com/prometheus/client_golang v1.22.0 github.com/rs/zerolog v1.34.0 github.com/spf13/cast v1.9.2 diff --git a/universalClient/config/config_test.go b/universalClient/config/config_test.go index 16f77c1f..1efc379d 100644 --- a/universalClient/config/config_test.go +++ b/universalClient/config/config_test.go @@ -215,6 +215,58 @@ func TestSaveAndLoad(t *testing.T) { }) } +func TestGetChainConfig(t *testing.T) { + polling := 5 + cfg := &Config{ + ChainConfigs: map[string]ChainSpecificConfig{ + "eip155:1": {EventPollingIntervalSeconds: &polling}, + }, + } + + t.Run("existing chain", func(t *testing.T) { + cc := cfg.GetChainConfig("eip155:1") + require.NotNil(t, cc.EventPollingIntervalSeconds) + assert.Equal(t, 5, *cc.EventPollingIntervalSeconds) + }) + + t.Run("missing chain returns empty", func(t *testing.T) { + cc := cfg.GetChainConfig("eip155:999") + assert.Nil(t, cc.EventPollingIntervalSeconds) + }) +} + +func TestLoadAppliesDefaults(t *testing.T) { + dir := t.TempDir() + configDir := filepath.Join(dir, ConfigSubdir) + require.NoError(t, os.MkdirAll(configDir, 0o750)) + + // Write a minimal valid config with empty node_home + minimalCfg := `{"log_level": 1, "log_format": "console"}` + require.NoError(t, os.WriteFile(filepath.Join(configDir, ConfigFileName), []byte(minimalCfg), 0o600)) + + loaded, err := Load(dir) + require.NoError(t, err) + + // Defaults should have been applied + assert.NotEmpty(t, loaded.NodeHome) + assert.NotEmpty(t, loaded.PushChainGRPCURLs) + assert.NotZero(t, loaded.QueryServerPort) +} + +func TestLoadInvalidConfigFails(t *testing.T) { + dir := t.TempDir() + configDir := filepath.Join(dir, ConfigSubdir) + require.NoError(t, os.MkdirAll(configDir, 0o750)) + + // Valid JSON but invalid log level + badCfg := `{"log_level": 99, "log_format": "console"}` + require.NoError(t, os.WriteFile(filepath.Join(configDir, ConfigFileName), []byte(badCfg), 0o600)) + + _, err := Load(dir) + require.Error(t, err) + assert.Contains(t, err.Error(), "log level must be between 0 and 5") +} + func TestConfigJSONRoundTrip(t *testing.T) { cfg := &Config{ LogLevel: 2, From b39b66eeb67d593c9b866ea5716e3599d57bf6d1 Mon Sep 17 00:00:00 2001 From: aman035 Date: Fri, 20 Mar 2026 17:23:14 +0530 Subject: [PATCH 24/28] add: pushcore tc --- universalClient/pushcore/pushCore_test.go | 98 +++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/universalClient/pushcore/pushCore_test.go b/universalClient/pushcore/pushCore_test.go index 015a7d7a..3ffad3e2 100644 --- a/universalClient/pushcore/pushCore_test.go +++ b/universalClient/pushcore/pushCore_test.go @@ -671,6 +671,96 @@ func TestClient_GetAccount(t *testing.T) { }) } +func TestClient_BroadcastTx(t *testing.T) { + logger := zerolog.Nop() + ctx := context.Background() + + t.Run("no endpoints configured", func(t *testing.T) { + client := &Client{ + logger: logger, + txClients: []tx.ServiceClient{}, + } + + resp, err := client.BroadcastTx(ctx, []byte("txbytes")) + require.Error(t, err) + assert.Contains(t, err.Error(), "no endpoints configured") + assert.Nil(t, resp) + }) + + t.Run("successful broadcast", func(t *testing.T) { + mockClient := &mockTxServiceClient{ + broadcastResp: &tx.BroadcastTxResponse{ + TxResponse: &sdktypes.TxResponse{TxHash: "0xabc", Code: 0}, + }, + } + + client := &Client{ + logger: logger, + txClients: []tx.ServiceClient{mockClient}, + } + + resp, err := client.BroadcastTx(ctx, []byte("txbytes")) + require.NoError(t, err) + assert.Equal(t, "0xabc", resp.TxResponse.TxHash) + }) + + t.Run("failover on first endpoint failure", func(t *testing.T) { + failing := &mockTxServiceClient{err: assert.AnError} + success := &mockTxServiceClient{ + broadcastResp: &tx.BroadcastTxResponse{ + TxResponse: &sdktypes.TxResponse{TxHash: "0xdef"}, + }, + } + + client := &Client{ + logger: logger, + txClients: []tx.ServiceClient{failing, success}, + } + + resp, err := client.BroadcastTx(ctx, []byte("txbytes")) + require.NoError(t, err) + assert.Equal(t, "0xdef", resp.TxResponse.TxHash) + }) +} + +func TestClient_GetTxsByEvents_MismatchedLengths(t *testing.T) { + logger := zerolog.Nop() + mockClient := &mockTxServiceClient{ + getTxsEventResp: &tx.GetTxsEventResponse{ + Txs: []*tx.Tx{}, + TxResponses: []*sdktypes.TxResponse{{TxHash: "0x1"}}, + }, + } + + client := &Client{ + logger: logger, + txClients: []tx.ServiceClient{mockClient}, + } + + _, err := client.GetTxsByEvents(context.Background(), "test.event", 0, 0, 0) + require.Error(t, err) + assert.Contains(t, err.Error(), "mismatched Txs") +} + +func TestClient_GetGasPrice_NilResponse(t *testing.T) { + logger := zerolog.Nop() + mockClient := &mockUExecutorQueryClient{ + gasPriceResp: &uexecutortypes.QueryGasPriceResponse{ + GasPrice: nil, + }, + } + + client := &Client{ + logger: logger, + uexecutorClients: []uexecutortypes.QueryClient{mockClient}, + } + + price, err := client.GetGasPrice(context.Background(), "eip155:1") + require.Error(t, err) + assert.Contains(t, err.Error(), "GasPrice response is nil") + assert.Nil(t, price) +} + // Mock implementations type mockRegistryQueryClient struct { @@ -744,6 +834,7 @@ func (m *mockUTSSQueryClient) KeyById(ctx context.Context, req *utsstypes.QueryK type mockTxServiceClient struct { tx.ServiceClient getTxsEventResp *tx.GetTxsEventResponse + broadcastResp *tx.BroadcastTxResponse err error } @@ -754,6 +845,13 @@ func (m *mockTxServiceClient) GetTxsEvent(ctx context.Context, req *tx.GetTxsEve return m.getTxsEventResp, nil } +func (m *mockTxServiceClient) BroadcastTx(ctx context.Context, req *tx.BroadcastTxRequest, opts ...grpc.CallOption) (*tx.BroadcastTxResponse, error) { + if m.err != nil { + return nil, m.err + } + return m.broadcastResp, nil +} + func (m *mockTxServiceClient) GetTx(ctx context.Context, req *tx.GetTxRequest, opts ...grpc.CallOption) (*tx.GetTxResponse, error) { return nil, nil } From f2ddb5d9564c677e15b696eaba148988d8e9bb00 Mon Sep 17 00:00:00 2001 From: aman035 Date: Fri, 20 Mar 2026 17:28:34 +0530 Subject: [PATCH 25/28] chore: pushSigner tc --- universalClient/pushsigner/pushsigner_test.go | 98 +++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/universalClient/pushsigner/pushsigner_test.go b/universalClient/pushsigner/pushsigner_test.go index de55127f..03752bef 100644 --- a/universalClient/pushsigner/pushsigner_test.go +++ b/universalClient/pushsigner/pushsigner_test.go @@ -18,6 +18,7 @@ import ( "github.com/pushchain/push-chain-node/universalClient/config" "github.com/pushchain/push-chain-node/universalClient/pushcore" "github.com/pushchain/push-chain-node/universalClient/pushsigner/keys" + uexecutortypes "github.com/pushchain/push-chain-node/x/uexecutor/types" ) func TestMain(m *testing.M) { @@ -521,3 +522,100 @@ func TestSequenceReconciliation(t *testing.T) { assert.Equal(t, uint64(9), signer.lastSequence) }) } + +// --- Vote tests --- + +func successMock(t *testing.T) *mockChainClient { + return &mockChainClient{ + getAccountFn: func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + addr, _ := sdk.AccAddressFromBech32(address) + return makeAccountResponse(t, addr, 1, 1), nil + }, + broadcastTxFn: func(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) { + return &sdktx.BroadcastTxResponse{ + TxResponse: &sdk.TxResponse{Code: 0, TxHash: "VOTE_OK"}, + }, nil + }, + } +} + +func failMock(t *testing.T) *mockChainClient { + return &mockChainClient{ + getAccountFn: func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + addr, _ := sdk.AccAddressFromBech32(address) + return makeAccountResponse(t, addr, 1, 1), nil + }, + broadcastTxFn: func(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) { + return nil, fmt.Errorf("connection refused") + }, + } +} + +func TestVoteInbound(t *testing.T) { + t.Run("successful vote", func(t *testing.T) { + signer := createTestSigner(t, successMock(t)) + inbound := &uexecutortypes.Inbound{TxHash: "0xabc"} + + txHash, err := signer.VoteInbound(context.Background(), inbound) + require.NoError(t, err) + assert.Equal(t, "VOTE_OK", txHash) + }) + + t.Run("broadcast failure", func(t *testing.T) { + signer := createTestSigner(t, failMock(t)) + inbound := &uexecutortypes.Inbound{TxHash: "0xabc"} + + _, err := signer.VoteInbound(context.Background(), inbound) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to broadcast vote") + }) +} + +func TestVoteChainMeta(t *testing.T) { + signer := createTestSigner(t, successMock(t)) + + txHash, err := signer.VoteChainMeta(context.Background(), "eip155:1", 20000000000, 18500000) + require.NoError(t, err) + assert.Equal(t, "VOTE_OK", txHash) +} + +func TestVoteOutbound(t *testing.T) { + signer := createTestSigner(t, successMock(t)) + obs := &uexecutortypes.OutboundObservation{ + Success: true, + BlockHeight: 12345, + TxHash: "0xdef", + } + + txHash, err := signer.VoteOutbound(context.Background(), "tx-1", "utx-1", obs) + require.NoError(t, err) + assert.Equal(t, "VOTE_OK", txHash) +} + +func TestVoteTssKeyProcess(t *testing.T) { + signer := createTestSigner(t, successMock(t)) + + txHash, err := signer.VoteTssKeyProcess(context.Background(), "tsspub1abc", "key-001", 42) + require.NoError(t, err) + assert.Equal(t, "VOTE_OK", txHash) +} + +func TestVoteOnChainRejection(t *testing.T) { + mock := &mockChainClient{ + getAccountFn: func(ctx context.Context, address string) (*authtypes.QueryAccountResponse, error) { + addr, _ := sdk.AccAddressFromBech32(address) + return makeAccountResponse(t, addr, 1, 1), nil + }, + broadcastTxFn: func(ctx context.Context, txBytes []byte) (*sdktx.BroadcastTxResponse, error) { + return &sdktx.BroadcastTxResponse{ + TxResponse: &sdk.TxResponse{Code: 5, RawLog: "unauthorized"}, + }, nil + }, + } + + signer := createTestSigner(t, mock) + + _, err := signer.VoteInbound(context.Background(), &uexecutortypes.Inbound{TxHash: "0x1"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "transaction failed with code 5") +} From e9ea7920a63833cd3a8c28454c80d16af55535a9 Mon Sep 17 00:00:00 2001 From: aman035 Date: Fri, 20 Mar 2026 17:28:49 +0530 Subject: [PATCH 26/28] add: chainstore tests --- .../chains/common/chain_store_test.go | 211 +++++++++++++++++- 1 file changed, 207 insertions(+), 4 deletions(-) diff --git a/universalClient/chains/common/chain_store_test.go b/universalClient/chains/common/chain_store_test.go index 4d0372de..5b3b6070 100644 --- a/universalClient/chains/common/chain_store_test.go +++ b/universalClient/chains/common/chain_store_test.go @@ -1,11 +1,13 @@ package common import ( + "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/pushchain/push-chain-node/universalClient/db" storemodels "github.com/pushchain/push-chain-node/universalClient/store" ) @@ -68,9 +70,210 @@ func TestChainStoreNilDatabase(t *testing.T) { }) } -func TestChainStoreStruct(t *testing.T) { - t.Run("struct has database field", func(t *testing.T) { - store := &ChainStore{} - assert.Nil(t, store.database) +func newTestChainStore(t *testing.T) *ChainStore { + t.Helper() + testDB, err := db.OpenInMemoryDB(true) + require.NoError(t, err) + t.Cleanup(func() { testDB.Close() }) + return NewChainStore(testDB) +} + +func TestChainStore_GetChainHeight(t *testing.T) { + cs := newTestChainStore(t) + + t.Run("creates state on first call", func(t *testing.T) { + height, err := cs.GetChainHeight() + require.NoError(t, err) + assert.Equal(t, uint64(0), height) + }) + + t.Run("returns existing height", func(t *testing.T) { + require.NoError(t, cs.UpdateChainHeight(100)) + height, err := cs.GetChainHeight() + require.NoError(t, err) + assert.Equal(t, uint64(100), height) + }) +} + +func TestChainStore_UpdateChainHeight(t *testing.T) { + cs := newTestChainStore(t) + + t.Run("creates and updates", func(t *testing.T) { + require.NoError(t, cs.UpdateChainHeight(50)) + height, err := cs.GetChainHeight() + require.NoError(t, err) + assert.Equal(t, uint64(50), height) + }) + + t.Run("only updates if higher", func(t *testing.T) { + require.NoError(t, cs.UpdateChainHeight(100)) + require.NoError(t, cs.UpdateChainHeight(50)) // lower — ignored + height, err := cs.GetChainHeight() + require.NoError(t, err) + assert.Equal(t, uint64(100), height) + }) +} + +func TestChainStore_InsertAndQuery(t *testing.T) { + cs := newTestChainStore(t) + + event := &storemodels.Event{ + EventID: "evt-1", + BlockHeight: 10, + Type: storemodels.EventTypeInbound, + ConfirmationType: storemodels.ConfirmationStandard, + Status: storemodels.StatusPending, + } + + t.Run("insert new event", func(t *testing.T) { + inserted, err := cs.InsertEventIfNotExists(event) + require.NoError(t, err) + assert.True(t, inserted) + }) + + t.Run("duplicate insert returns false", func(t *testing.T) { + inserted, err := cs.InsertEventIfNotExists(event) + require.NoError(t, err) + assert.False(t, inserted) + }) + + t.Run("get pending events", func(t *testing.T) { + events, err := cs.GetPendingEvents(10) + require.NoError(t, err) + require.Len(t, events, 1) + assert.Equal(t, "evt-1", events[0].EventID) + }) + + t.Run("get confirmed events returns empty", func(t *testing.T) { + events, err := cs.GetConfirmedEvents(10) + require.NoError(t, err) + assert.Empty(t, events) + }) +} + +func TestChainStore_UpdateEventStatus(t *testing.T) { + cs := newTestChainStore(t) + + event := &storemodels.Event{ + EventID: "evt-2", + BlockHeight: 20, + Type: storemodels.EventTypeInbound, + ConfirmationType: storemodels.ConfirmationStandard, + Status: storemodels.StatusPending, + } + _, err := cs.InsertEventIfNotExists(event) + require.NoError(t, err) + + t.Run("updates matching status", func(t *testing.T) { + rows, err := cs.UpdateEventStatus("evt-2", storemodels.StatusPending, storemodels.StatusConfirmed) + require.NoError(t, err) + assert.Equal(t, int64(1), rows) }) + + t.Run("no-op if status mismatch", func(t *testing.T) { + rows, err := cs.UpdateEventStatus("evt-2", storemodels.StatusPending, storemodels.StatusCompleted) + require.NoError(t, err) + assert.Equal(t, int64(0), rows) + }) + + t.Run("confirmed events visible", func(t *testing.T) { + events, err := cs.GetConfirmedEvents(10) + require.NoError(t, err) + require.Len(t, events, 1) + assert.Equal(t, "evt-2", events[0].EventID) + }) +} + +func TestChainStore_UpdateStatusAndVoteTxHash(t *testing.T) { + cs := newTestChainStore(t) + + event := &storemodels.Event{ + EventID: "evt-3", + BlockHeight: 30, + Type: storemodels.EventTypeInbound, + ConfirmationType: storemodels.ConfirmationStandard, + Status: storemodels.StatusConfirmed, + } + _, err := cs.InsertEventIfNotExists(event) + require.NoError(t, err) + + rows, err := cs.UpdateStatusAndVoteTxHash("evt-3", storemodels.StatusConfirmed, storemodels.StatusCompleted, "0xvote123") + require.NoError(t, err) + assert.Equal(t, int64(1), rows) +} + +func TestChainStore_UpdateStatusAndEventData(t *testing.T) { + cs := newTestChainStore(t) + + event := &storemodels.Event{ + EventID: "evt-4", + BlockHeight: 40, + Type: storemodels.EventTypeInbound, + ConfirmationType: storemodels.ConfirmationStandard, + Status: storemodels.StatusPending, + EventData: []byte(`{"old":"data"}`), + } + _, err := cs.InsertEventIfNotExists(event) + require.NoError(t, err) + + newData := []byte(`{"new":"data"}`) + rows, err := cs.UpdateStatusAndEventData("evt-4", storemodels.StatusPending, storemodels.StatusConfirmed, newData) + require.NoError(t, err) + assert.Equal(t, int64(1), rows) +} + +func TestChainStore_UpdateVoteTxHash(t *testing.T) { + cs := newTestChainStore(t) + + event := &storemodels.Event{ + EventID: "evt-5", + BlockHeight: 50, + Type: storemodels.EventTypeOutbound, + ConfirmationType: storemodels.ConfirmationStandard, + Status: storemodels.StatusConfirmed, + } + _, err := cs.InsertEventIfNotExists(event) + require.NoError(t, err) + + err = cs.UpdateVoteTxHash("evt-5", "0xvotehash") + require.NoError(t, err) +} + +func TestChainStore_DeleteTerminalEvents(t *testing.T) { + cs := newTestChainStore(t) + + // Insert events in terminal states + for i, status := range []string{storemodels.StatusCompleted, storemodels.StatusReverted, storemodels.StatusReorged} { + evt := &storemodels.Event{ + EventID: fmt.Sprintf("term-%d", i), + BlockHeight: uint64(i), + Type: storemodels.EventTypeInbound, + ConfirmationType: storemodels.ConfirmationStandard, + Status: status, + } + _, err := cs.InsertEventIfNotExists(evt) + require.NoError(t, err) + } + + // Insert a non-terminal event + active := &storemodels.Event{ + EventID: "active-1", + BlockHeight: 100, + Type: storemodels.EventTypeInbound, + ConfirmationType: storemodels.ConfirmationStandard, + Status: storemodels.StatusPending, + } + _, err := cs.InsertEventIfNotExists(active) + require.NoError(t, err) + + // Delete terminal events updated before far future + deleted, err := cs.DeleteTerminalEvents("2099-01-01") + require.NoError(t, err) + assert.Equal(t, int64(3), deleted) + + // Active event still exists + events, err := cs.GetPendingEvents(10) + require.NoError(t, err) + assert.Len(t, events, 1) + assert.Equal(t, "active-1", events[0].EventID) } From 9146aa48430800492579e945b0bc1e05d394419c Mon Sep 17 00:00:00 2001 From: aman035 Date: Fri, 20 Mar 2026 17:33:00 +0530 Subject: [PATCH 27/28] add: store tests --- universalClient/tss/eventstore/store_test.go | 110 +++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/universalClient/tss/eventstore/store_test.go b/universalClient/tss/eventstore/store_test.go index 91754b6b..73cf65ee 100644 --- a/universalClient/tss/eventstore/store_test.go +++ b/universalClient/tss/eventstore/store_test.go @@ -495,3 +495,113 @@ func TestGetExpiredConfirmedEvents(t *testing.T) { } }) } + +func TestGetInFlightSignEvents(t *testing.T) { + s := setupTestStore(t) + + createTestEventWithType(t, s, "sign-inprogress", 10, store.StatusInProgress, 200, store.EventTypeSign) + createTestEventWithType(t, s, "sign-signed", 11, store.StatusSigned, 200, store.EventTypeSign) + createTestEventWithType(t, s, "sign-broadcasted", 12, store.StatusBroadcasted, 200, store.EventTypeSign) + createTestEventWithType(t, s, "keygen-inprogress", 13, store.StatusInProgress, 200, store.EventTypeKeygen) + + events, err := s.GetInFlightSignEvents() + if err != nil { + t.Fatalf("GetInFlightSignEvents() error = %v", err) + } + // Should return IN_PROGRESS + SIGNED sign events, NOT broadcasted, NOT keygen + if len(events) != 2 { + t.Fatalf("GetInFlightSignEvents() returned %d events, want 2", len(events)) + } +} + +func TestGetSignedSignEvents(t *testing.T) { + s := setupTestStore(t) + + createTestEventWithType(t, s, "signed-1", 10, store.StatusSigned, 200, store.EventTypeSign) + createTestEventWithType(t, s, "signed-2", 11, store.StatusSigned, 200, store.EventTypeSign) + createTestEventWithType(t, s, "inprogress-1", 12, store.StatusInProgress, 200, store.EventTypeSign) + + t.Run("returns only signed events", func(t *testing.T) { + events, err := s.GetSignedSignEvents(10) + if err != nil { + t.Fatalf("GetSignedSignEvents() error = %v", err) + } + if len(events) != 2 { + t.Errorf("GetSignedSignEvents() returned %d events, want 2", len(events)) + } + }) + + t.Run("respects limit", func(t *testing.T) { + events, err := s.GetSignedSignEvents(1) + if err != nil { + t.Fatalf("GetSignedSignEvents() error = %v", err) + } + if len(events) != 1 { + t.Errorf("GetSignedSignEvents(1) returned %d events, want 1", len(events)) + } + }) + + t.Run("zero limit defaults to 50", func(t *testing.T) { + events, err := s.GetSignedSignEvents(0) + if err != nil { + t.Fatalf("GetSignedSignEvents(0) error = %v", err) + } + if len(events) != 2 { + t.Errorf("GetSignedSignEvents(0) returned %d events, want 2", len(events)) + } + }) +} + +func TestGetBroadcastedSignEvents(t *testing.T) { + s := setupTestStore(t) + + // Create broadcasted event with tx hash + evt := store.Event{ + EventID: "bc-1", + BlockHeight: 10, + ExpiryBlockHeight: 200, + Type: store.EventTypeSign, + Status: store.StatusBroadcasted, + BroadcastedTxHash: "0xhash123", + } + if err := s.db.Create(&evt).Error; err != nil { + t.Fatalf("failed to create event: %v", err) + } + + // Create broadcasted event without tx hash — should be excluded + evt2 := store.Event{ + EventID: "bc-2", + BlockHeight: 11, + ExpiryBlockHeight: 200, + Type: store.EventTypeSign, + Status: store.StatusBroadcasted, + BroadcastedTxHash: "", + } + if err := s.db.Create(&evt2).Error; err != nil { + t.Fatalf("failed to create event: %v", err) + } + + t.Run("returns only events with tx hash", func(t *testing.T) { + events, err := s.GetBroadcastedSignEvents(10) + if err != nil { + t.Fatalf("GetBroadcastedSignEvents() error = %v", err) + } + if len(events) != 1 { + t.Errorf("got %d events, want 1", len(events)) + } + if events[0].EventID != "bc-1" { + t.Errorf("got event %s, want bc-1", events[0].EventID) + } + }) + + t.Run("zero limit defaults to 50", func(t *testing.T) { + events, err := s.GetBroadcastedSignEvents(0) + if err != nil { + t.Fatalf("error = %v", err) + } + if len(events) != 1 { + t.Errorf("got %d events, want 1", len(events)) + } + }) +} + From e5bb85a6dbc6c7b7a591cc21a0f853f3b5fdbc5b Mon Sep 17 00:00:00 2001 From: aman035 Date: Mon, 23 Mar 2026 16:08:44 +0530 Subject: [PATCH 28/28] fix: take home param, remove tss command --- cmd/puniversald/commands.go | 116 +++++++++--------------------------- 1 file changed, 29 insertions(+), 87 deletions(-) diff --git a/cmd/puniversald/commands.go b/cmd/puniversald/commands.go index 3bff8197..2e54b791 100644 --- a/cmd/puniversald/commands.go +++ b/cmd/puniversald/commands.go @@ -2,29 +2,35 @@ package main import ( "context" - "crypto/ed25519" - "encoding/hex" "encoding/json" "fmt" - "strings" - - "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/peer" + "path/filepath" sdkversion "github.com/cosmos/cosmos-sdk/version" + cosmosevmcmd "github.com/cosmos/evm/client" uvconfig "github.com/pushchain/push-chain-node/universalClient/config" "github.com/pushchain/push-chain-node/universalClient/core" "github.com/spf13/cobra" - - cosmosevmcmd "github.com/cosmos/evm/client" ) +const flagHome = "home" + func InitRootCmd(rootCmd *cobra.Command) { + rootCmd.PersistentFlags().String(flagHome, uvconfig.DefaultNodeHome(), "node home directory") + rootCmd.AddCommand(versionCmd()) rootCmd.AddCommand(startCmd()) rootCmd.AddCommand(initCmd()) rootCmd.AddCommand(cosmosevmcmd.KeyCommands(uvconfig.DefaultNodeHome(), true)) - rootCmd.AddCommand(tssPeerIDCmd()) +} + +// getHome reads the --home flag, falling back to DefaultNodeHome. +func getHome(cmd *cobra.Command) string { + home, _ := cmd.Flags().GetString(flagHome) + if home == "" { + home = uvconfig.DefaultNodeHome() + } + return home } func versionCmd() *cobra.Command { @@ -32,8 +38,7 @@ func versionCmd() *cobra.Command { Use: "version", Short: "Print universal validator version info", Run: func(cmd *cobra.Command, args []string) { - fmt.Printf("Name: %s\n", sdkversion.Name) - fmt.Printf("App Name: %s\n", sdkversion.AppName) + fmt.Printf("Name: puniversald\n") fmt.Printf("Version: %s\n", sdkversion.Version) fmt.Printf("Commit: %s\n", sdkversion.Commit) fmt.Printf("Build Tags: %s\n", sdkversion.BuildTags) @@ -42,55 +47,50 @@ func versionCmd() *cobra.Command { } func initCmd() *cobra.Command { - cmd := &cobra.Command{ + return &cobra.Command{ Use: "init", Short: "Initialize configuration file", Long: `Initialize the configuration file with default values. -This command creates a default configuration file at: - ~/.puniversal/config/pushuv_config.json - -You can edit this file to customize your universal validator settings.`, +By default creates the config at ~/.puniversal/config/pushuv_config.json. +Use --home to specify a different directory.`, RunE: func(cmd *cobra.Command, args []string) error { - // Load default config + home := getHome(cmd) + defaultCfg, err := uvconfig.LoadDefaultConfig() if err != nil { return fmt.Errorf("failed to load default config: %w", err) } - // Save to config directory - if err := uvconfig.Save(&defaultCfg, uvconfig.DefaultNodeHome()); err != nil { + if err := uvconfig.Save(&defaultCfg, home); err != nil { return fmt.Errorf("failed to save config: %w", err) } - configPath := fmt.Sprintf("%s/%s/%s", uvconfig.DefaultNodeHome(), uvconfig.ConfigSubdir, uvconfig.ConfigFileName) - fmt.Printf("✅ Configuration file initialized at: %s\n", configPath) - fmt.Println("You can now edit this file to customize your settings.") + configPath := filepath.Join(home, uvconfig.ConfigSubdir, uvconfig.ConfigFileName) + fmt.Printf("Configuration file initialized at: %s\n", configPath) return nil }, } - return cmd } func startCmd() *cobra.Command { - cmd := &cobra.Command{ + return &cobra.Command{ Use: "start", - Short: "Start the universal message handler", + Short: "Start the universal validator", RunE: func(cmd *cobra.Command, args []string) error { - // --- Step 1: Load config --- - loadedCfg, err := uvconfig.Load(uvconfig.DefaultNodeHome()) + home := getHome(cmd) + + loadedCfg, err := uvconfig.Load(home) if err != nil { return fmt.Errorf("failed to load config: %w", err) } - // Print loaded config as JSON configJSON, err := json.MarshalIndent(loadedCfg, "", " ") if err != nil { return fmt.Errorf("failed to marshal config: %w", err) } fmt.Printf("\n=== Loaded Configuration ===\n%s\n===========================\n\n", string(configJSON)) - // --- Step 2: Start client --- ctx := context.Background() client, err := core.NewUniversalClient(ctx, &loadedCfg) if err != nil { @@ -99,63 +99,5 @@ func startCmd() *cobra.Command { return client.Start() }, } - return cmd } -// tssPeerIDCmd computes and prints the libp2p peer ID from a TSS private key hex string. -// This is used during devnet setup to derive the peer ID for universal validator registration. -func tssPeerIDCmd() *cobra.Command { - var privateKeyHex string - - cmd := &cobra.Command{ - Use: "tss-peer-id", - Short: "Compute libp2p peer ID from TSS private key hex", - Long: `Compute the libp2p peer ID from a 32-byte hex-encoded Ed25519 seed. - -This is used during devnet setup to derive the peer ID that matches -what the TSS node will use, for universal validator registration. - -Example: - puniversald tss-peer-id --private-key 0101010101010101010101010101010101010101010101010101010101010101`, - RunE: func(cmd *cobra.Command, args []string) error { - privateKeyHex = strings.TrimSpace(privateKeyHex) - - // Decode hex to bytes - keyBytes, err := hex.DecodeString(privateKeyHex) - if err != nil { - return fmt.Errorf("invalid hex: %w", err) - } - if len(keyBytes) != 32 { - return fmt.Errorf("expected 32 bytes, got %d", len(keyBytes)) - } - - // Create Ed25519 key from seed - privKey := ed25519.NewKeyFromSeed(keyBytes) - pubKey := privKey.Public().(ed25519.PublicKey) - - // Convert to libp2p format (64 bytes: 32 priv seed + 32 pub) - libp2pKeyBytes := make([]byte, 64) - copy(libp2pKeyBytes[:32], privKey[:32]) - copy(libp2pKeyBytes[32:], pubKey) - - libp2pPrivKey, err := crypto.UnmarshalEd25519PrivateKey(libp2pKeyBytes) - if err != nil { - return fmt.Errorf("failed to unmarshal Ed25519 key: %w", err) - } - - // Get peer ID from public key - peerID, err := peer.IDFromPrivateKey(libp2pPrivKey) - if err != nil { - return fmt.Errorf("failed to derive peer ID: %w", err) - } - - fmt.Println(peerID.String()) - return nil - }, - } - - cmd.Flags().StringVar(&privateKeyHex, "private-key", "", "Hex-encoded 32-byte Ed25519 seed") - cmd.MarkFlagRequired("private-key") - - return cmd -}