Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions cmd/account/link_key/link_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ type initiateLinkingResponse struct {
FunctionArgs []string `json:"functionArgs"`
}

func Exec(ctx *runtime.Context, in Inputs) error {
func Exec(parentCtx context.Context, ctx *runtime.Context, in Inputs) error {
h := newHandler(ctx, nil)

if err := h.ValidateInputs(in); err != nil {
return err
}
return h.Execute(in)
return h.Execute(parentCtx, in)
}

func New(runtimeContext *runtime.Context) *cobra.Command {
Expand All @@ -83,7 +83,7 @@ func New(runtimeContext *runtime.Context) *cobra.Command {
return err
}

return h.Execute(inputs)
return h.Execute(cmd.Context(), inputs)
},
}
settings.AddTxnTypeFlags(cmd)
Expand All @@ -101,6 +101,7 @@ type handler struct {
stdin io.Reader
environmentSet *environments.EnvironmentSet
wrc *client.WorkflowRegistryV2Client
execCtx context.Context

validated bool

Expand All @@ -109,28 +110,28 @@ type handler struct {
}

func newHandler(ctx *runtime.Context, stdin io.Reader) *handler {
h := handler{
return &handler{
settings: ctx.Settings,
credentials: ctx.Credentials,
clientFactory: ctx.ClientFactory,
log: ctx.Logger,
environmentSet: ctx.EnvironmentSet,
stdin: stdin,
wg: sync.WaitGroup{},
wrcErr: nil,
}
}

func (h *handler) initWorkflowRegistryClient() error {
h.wg.Add(1)
go func() {
defer h.wg.Done()
wrc, err := h.clientFactory.NewWorkflowRegistryV2Client()
wrc, err := h.clientFactory.NewWorkflowRegistryV2Client(h.execCtx)
if err != nil {
h.wrcErr = fmt.Errorf("failed to create workflow registry client: %w", err)
return
}
h.wrc = wrc
}()

return &h
return nil
}

func (h *handler) ResolveInputs(v *viper.Viper) (Inputs, error) {
Expand All @@ -154,11 +155,16 @@ func (h *handler) ValidateInputs(in Inputs) error {
return nil
}

func (h *handler) Execute(in Inputs) error {
func (h *handler) Execute(ctx context.Context, in Inputs) error {
if !h.validated {
return fmt.Errorf("inputs not validated")
}

h.execCtx = ctx
if err := h.initWorkflowRegistryClient(); err != nil {
return err
}

h.displayDetails()

if in.WorkflowOwnerLabel == "" {
Expand Down Expand Up @@ -191,7 +197,7 @@ func (h *handler) Execute(in Inputs) error {

ui.Dim(fmt.Sprintf("Starting linking: owner=%s, label=%s", in.WorkflowOwner, in.WorkflowOwnerLabel))

resp, err := h.callInitiateLinking(context.Background(), in)
resp, err := h.callInitiateLinking(h.execCtx, in)
if err != nil {
return err
}
Expand Down Expand Up @@ -296,10 +302,10 @@ func (h *handler) linkOwner(resp initiateLinkingResponse) error {
}

ownerAddr := common.HexToAddress(h.settings.Workflow.UserWorkflowSettings.WorkflowOwnerAddress)
if err := h.wrc.CanLinkOwner(ownerAddr, ts, proofBytes, sigBytes); err != nil {
if err := h.wrc.CanLinkOwner(h.execCtx, ownerAddr, ts, proofBytes, sigBytes); err != nil {
return fmt.Errorf("link request verification failed: %w", err)
}
txOut, err := h.wrc.LinkOwner(ts, proofBytes, sigBytes)
txOut, err := h.wrc.LinkOwner(h.execCtx, ts, proofBytes, sigBytes)
if err != nil {
return fmt.Errorf("LinkOwner failed: %w", err)
}
Expand Down Expand Up @@ -388,7 +394,7 @@ func (h *handler) checkIfAlreadyLinked() (bool, error) {
ownerAddr := common.HexToAddress(h.settings.Workflow.UserWorkflowSettings.WorkflowOwnerAddress)
ui.Dim("Checking existing registrations...")

linked, err := h.wrc.IsOwnerLinked(ownerAddr)
linked, err := h.wrc.IsOwnerLinked(h.execCtx, ownerAddr)
if err != nil {
return false, fmt.Errorf("failed to check owner link status: %w", err)
}
Expand Down
30 changes: 18 additions & 12 deletions cmd/account/unlink_key/unlink_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type handler struct {
stdin io.Reader
environmentSet *environments.EnvironmentSet
wrc *client.WorkflowRegistryV2Client
execCtx context.Context

validated bool

Expand All @@ -83,7 +84,7 @@ func New(runtimeContext *runtime.Context) *cobra.Command {
if err := h.ValidateInputs(in); err != nil {
return err
}
return h.Execute(in)
return h.Execute(cmd.Context(), in)
},
}
settings.AddTxnTypeFlags(cmd)
Expand All @@ -92,28 +93,28 @@ func New(runtimeContext *runtime.Context) *cobra.Command {
}

func newHandler(ctx *runtime.Context, stdin io.Reader) *handler {
h := handler{
return &handler{
settings: ctx.Settings,
credentials: ctx.Credentials,
clientFactory: ctx.ClientFactory,
log: ctx.Logger,
environmentSet: ctx.EnvironmentSet,
stdin: stdin,
wg: sync.WaitGroup{},
wrcErr: nil,
}
}

func (h *handler) initWorkflowRegistryClient() error {
h.wg.Add(1)
go func() {
defer h.wg.Done()
wrc, err := h.clientFactory.NewWorkflowRegistryV2Client()
wrc, err := h.clientFactory.NewWorkflowRegistryV2Client(h.execCtx)
if err != nil {
h.wrcErr = fmt.Errorf("failed to create workflow registry client: %w", err)
return
}
h.wrc = wrc
}()

return &h
return nil
}

func (h *handler) ResolveInputs(v *viper.Viper) (Inputs, error) {
Expand All @@ -137,11 +138,16 @@ func (h *handler) ValidateInputs(in Inputs) error {
return nil
}

func (h *handler) Execute(in Inputs) error {
func (h *handler) Execute(ctx context.Context, in Inputs) error {
if !h.validated {
return fmt.Errorf("inputs not validated")
}

h.execCtx = ctx
if err := h.initWorkflowRegistryClient(); err != nil {
return err
}

h.displayDetails()

ui.Dim(fmt.Sprintf("Starting unlinking: owner=%s", in.WorkflowOwner))
Expand Down Expand Up @@ -181,7 +187,7 @@ func (h *handler) Execute(in Inputs) error {
}
}

resp, err := h.callInitiateUnlinking(context.Background(), in)
resp, err := h.callInitiateUnlinking(h.execCtx, in)
if err != nil {
return err
}
Expand Down Expand Up @@ -255,10 +261,10 @@ func (h *handler) unlinkOwner(owner string, resp initiateUnlinkingResponse) erro
}

addr := common.HexToAddress(owner)
if err := h.wrc.CanUnlinkOwner(addr, ts, sigBytes); err != nil {
if err := h.wrc.CanUnlinkOwner(h.execCtx, addr, ts, sigBytes); err != nil {
return fmt.Errorf("unlink request verification failed: %w", err)
}
txOut, err := h.wrc.UnlinkOwner(addr, ts, sigBytes)
txOut, err := h.wrc.UnlinkOwner(h.execCtx, addr, ts, sigBytes)
if err != nil {
return fmt.Errorf("UnlinkOwner failed: %w", err)
}
Expand Down Expand Up @@ -346,7 +352,7 @@ func (h *handler) unlinkOwner(owner string, resp initiateUnlinkingResponse) erro
func (h *handler) checkIfAlreadyLinked() (bool, error) {
ownerAddr := common.HexToAddress(h.settings.Workflow.UserWorkflowSettings.WorkflowOwnerAddress)

linked, err := h.wrc.IsOwnerLinked(ownerAddr)
linked, err := h.wrc.IsOwnerLinked(h.execCtx, ownerAddr)
if err != nil {
return false, fmt.Errorf("failed to check owner link status: %w", err)
}
Expand Down
7 changes: 4 additions & 3 deletions cmd/client/client_factory.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"fmt"
"strings"

Expand All @@ -15,7 +16,7 @@ import (
)

type Factory interface {
NewWorkflowRegistryV2Client() (*WorkflowRegistryV2Client, error)
NewWorkflowRegistryV2Client(ctx context.Context) (*WorkflowRegistryV2Client, error)
GetTxType() TxType
GetSkipConfirmation() bool
}
Expand All @@ -32,7 +33,7 @@ func NewFactory(logger *zerolog.Logger, viper *viper.Viper) Factory {
}
}

func (f *factoryImpl) NewWorkflowRegistryV2Client() (*WorkflowRegistryV2Client, error) {
func (f *factoryImpl) NewWorkflowRegistryV2Client(ctx context.Context) (*WorkflowRegistryV2Client, error) {
environmentSet, err := environments.New()
if err != nil {
return nil, fmt.Errorf("failed to load environment details: %w", err)
Expand Down Expand Up @@ -60,7 +61,7 @@ func (f *factoryImpl) NewWorkflowRegistryV2Client() (*WorkflowRegistryV2Client,
txcConfig,
)

typeAndVersion, err := workflowRegistryV2Client.TypeAndVersion()
typeAndVersion, err := workflowRegistryV2Client.TypeAndVersion(ctx)
if err != nil {
return workflowRegistryV2Client, fmt.Errorf("failed to get type and version of workflow registry contract at %s: %w", environmentSet.WorkflowRegistryAddress, err)
}
Expand Down
6 changes: 4 additions & 2 deletions cmd/client/eth_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,16 @@ func readSethConfigFromFile(configPath string) (*seth.Config, error) {
}

func getChainID(rpcURL string) (uint64, error) {
client, err := rpc.DialContext(context.Background(), rpcURL)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
client, err := rpc.DialContext(ctx, rpcURL)
if err != nil {
return 0, err
}
defer client.Close()

var chainID string
err = client.CallContext(context.Background(), &chainID, "eth_chainId")
err = client.CallContext(ctx, &chainID, "eth_chainId")
if err != nil {
return 0, err
}
Expand Down
11 changes: 7 additions & 4 deletions cmd/client/tx.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -119,7 +120,7 @@ type RawTx struct {
// return txOpts, nil
//}

func (c *TxClient) executeTransactionByTxType(txFn func(opts *bind.TransactOpts) (*types.Transaction, error), funName string, validationEvent string, args ...any) (TxOutput, error) {
func (c *TxClient) executeTransactionByTxType(ctx context.Context, txFn func(opts *bind.TransactOpts) (*types.Transaction, error), funName string, validationEvent string, args ...any) (TxOutput, error) {
switch c.config.TxType {
case Regular:
simulateTx, err := txFn(cmdCommon.SimTransactOpts())
Expand All @@ -138,7 +139,7 @@ func (c *TxClient) executeTransactionByTxType(txFn func(opts *bind.TransactOpts)
Value: simulateTx.Value(),
Data: simulateTx.Data(),
}
estimatedGas, gasErr := c.EthClient.Client.EstimateGas(c.EthClient.Context, msg)
estimatedGas, gasErr := c.EthClient.Client.EstimateGas(ctx, msg)
if gasErr != nil {
c.Logger.Warn().Err(gasErr).Msg("Failed to estimate gas usage")
}
Expand All @@ -159,7 +160,7 @@ func (c *TxClient) executeTransactionByTxType(txFn func(opts *bind.TransactOpts)

// Calculate and print total cost for sending the transaction on-chain
if gasErr == nil {
gasPriceWei, gasPriceErr := c.EthClient.Client.SuggestGasPrice(c.EthClient.Context)
gasPriceWei, gasPriceErr := c.EthClient.Client.SuggestGasPrice(ctx)
if gasPriceErr != nil {
c.Logger.Warn().Err(gasPriceErr).Msg("Failed to fetch gas price")
} else {
Expand Down Expand Up @@ -189,7 +190,9 @@ func (c *TxClient) executeTransactionByTxType(txFn func(opts *bind.TransactOpts)
spinner := ui.NewSpinner()
spinner.Start("Submitting transaction...")

decodedTx, err := c.EthClient.Decode(txFn(c.EthClient.NewTXOpts()))
txOpts := c.EthClient.NewTXOpts()
txOpts.Context = ctx
decodedTx, err := c.EthClient.Decode(txFn(txOpts))
if err != nil {
spinner.Stop()
return TxOutput{Type: Regular}, err
Expand Down
Loading
Loading