Skip to content

Commit

Permalink
add context to agent client, handle cancellations, option to validate…
Browse files Browse the repository at this point in the history
… on establish
  • Loading branch information
michaeldwan committed Aug 13, 2021
1 parent d84357d commit 9f6e14e
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 51 deletions.
17 changes: 10 additions & 7 deletions cmd/fly_agent.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"context"
"fmt"
"log"
"os"
Expand Down Expand Up @@ -55,28 +56,30 @@ func runFlyAgentDaemonStart(ctx *cmdctx.CmdContext) error {
return nil
}

func runFlyAgentStart(ctx *cmdctx.CmdContext) error {
api := ctx.Client.API()
func runFlyAgentStart(cc *cmdctx.CmdContext) error {
api := cc.Client.API()
ctx := context.Background()

c, err := agent.DefaultClient(api)
if err == nil {
c.Kill()
c.Kill(ctx)
}

_, err = agent.Establish(api)
_, err = agent.Establish(ctx, api, true)
if err != nil {
fmt.Fprintf(os.Stderr, "can't start agent: %s", err)
}

return err
}

func runFlyAgentStop(ctx *cmdctx.CmdContext) error {
api := ctx.Client.API()
func runFlyAgentStop(cc *cmdctx.CmdContext) error {
api := cc.Client.API()
ctx := context.Background()

c, err := agent.DefaultClient(api)
if err == nil {
c.Kill()
c.Kill(ctx)
}

return err
Expand Down
35 changes: 18 additions & 17 deletions cmd/ssh_terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,39 @@ import (
"github.com/superfly/flyctl/terminal"
)

func runSSHConsole(ctx *cmdctx.CmdContext) error {
client := ctx.Client.API()
func runSSHConsole(cc *cmdctx.CmdContext) error {
client := cc.Client.API()
ctx := createCancellableContext()

terminal.Debugf("Retrieving app info for %s\n", ctx.AppName)
terminal.Debugf("Retrieving app info for %s\n", cc.AppName)

app, err := client.GetApp(ctx.AppName)
app, err := client.GetApp(cc.AppName)
if err != nil {
return fmt.Errorf("get app: %w", err)
}

agentclient, err := agent.Establish(client)
agentclient, err := agent.Establish(ctx, client, true)
if err != nil {
return fmt.Errorf("can't establish agent: %s\n", err)
}

dialer, err := agentclient.Dialer(&app.Organization)
dialer, err := agentclient.Dialer(ctx, &app.Organization)
if err != nil {
return fmt.Errorf("ssh: can't build tunnel for %s: %s\n", app.Organization.Slug, err)
}

if ctx.Config.GetBool("probe") {
if err = agentclient.Probe(&app.Organization); err != nil {
if cc.Config.GetBool("probe") {
if err = agentclient.Probe(ctx, &app.Organization); err != nil {
return fmt.Errorf("probe wireguard: %w", err)
}
}

var addr string

if ctx.Config.GetBool("select") {
instances, err := agentclient.Instances(&app.Organization, ctx.AppName)
if cc.Config.GetBool("select") {
instances, err := agentclient.Instances(ctx, &app.Organization, cc.AppName)
if err != nil {
return fmt.Errorf("look up %s: %w", ctx.AppName, err)
return fmt.Errorf("look up %s: %w", cc.AppName, err)
}

selected := 0
Expand All @@ -62,18 +63,18 @@ func runSSHConsole(ctx *cmdctx.CmdContext) error {
}

addr = fmt.Sprintf("[%s]", instances.Addresses[selected])
} else if len(ctx.Args) != 0 {
addr = ctx.Args[0]
} else if len(cc.Args) != 0 {
addr = cc.Args[0]
} else {
addr = fmt.Sprintf("%s.internal", ctx.AppName)
addr = fmt.Sprintf("%s.internal", cc.AppName)
}

return sshConnect(&SSHParams{
Ctx: ctx,
Ctx: cc,
Org: &app.Organization,
Dialer: dialer,
App: ctx.AppName,
Cmd: ctx.Config.GetString("command"),
App: cc.AppName,
Cmd: cc.Config.GetString("command"),
}, addr)
}

Expand Down
4 changes: 2 additions & 2 deletions internal/build/imgsrc/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,12 @@ func newRemoteDockerClient(ctx context.Context, apiClient *api.Client, appName s
return errors.Wrap(err, "error fetching target app")
}

agentclient, err := agent.Establish(apiClient)
agentclient, err := agent.Establish(errCtx, apiClient, true)
if err != nil {
return errors.Wrap(err, "error establishing agent")
}

dialer, err := agentclient.Dialer(&app.Organization)
dialer, err := agentclient.Dialer(errCtx, &app.Organization)
if err != nil {
return errors.Wrapf(err, "error establishing wireguard connection for %s organization", app.Organization.Slug)
}
Expand Down
14 changes: 10 additions & 4 deletions pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (s *Server) handle(c net.Conn) {

func NewServer(path string, cmdCtx *cmdctx.CmdContext) (*Server, error) {
if c, err := NewClient(path); err == nil {
c.Kill()
c.Kill(context.Background())
}

if err := removeSocket(path); err != nil {
Expand Down Expand Up @@ -515,16 +515,22 @@ func captureWireguardConnErr(err error, org string) {
}

/// Establish starts the daemon if necessary and returns a client
func Establish(apiClient *api.Client) (*Client, error) {
func Establish(ctx context.Context, apiClient *api.Client, validate bool) (*Client, error) {
if validate {
if err := wireguard.PruneInvalidPeers(apiClient); err != nil {
return nil, err
}
}

c, err := DefaultClient(apiClient)
if err == nil {
_, err := c.Ping()
_, err := c.Ping(ctx)
if err == nil {
return c, nil
}
}

fmt.Println("command", os.Args[0])

return StartDaemon(apiClient, os.Args[0])
return StartDaemon(ctx, apiClient, os.Args[0])
}
72 changes: 53 additions & 19 deletions pkg/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,26 +61,38 @@ func (c *Client) connect() (net.Conn, error) {
return conn, nil
}

func (c *Client) withConnection(f func(conn net.Conn) error) error {
conn, err := c.connect()
if err != nil {
func (c *Client) withConnection(ctx context.Context, f func(conn net.Conn) error) error {
errCh := make(chan error, 1)

go func() {
conn, err := c.connect()
if err != nil {
errCh <- err
}
defer conn.Close()

errCh <- f(conn)
}()

select {
case <-ctx.Done():
<-errCh
return ctx.Err()
case err := <-errCh:
return err
}
defer conn.Close()

return f(conn)
}

func (c *Client) Kill() error {
return c.withConnection(func(conn net.Conn) error {
func (c *Client) Kill(ctx context.Context) error {
return c.withConnection(ctx, func(conn net.Conn) error {
return writef(conn, "kill")
})
}

func (c *Client) Ping() (int, error) {
func (c *Client) Ping(ctx context.Context) (int, error) {
var pid int

err := c.withConnection(func(conn net.Conn) error {
err := c.withConnection(ctx, func(conn net.Conn) error {
writef(conn, "ping")

conn.SetReadDeadline(time.Now().Add(defaultTimeout))
Expand All @@ -106,8 +118,8 @@ func (c *Client) Ping() (int, error) {
return pid, err
}

func (c *Client) Establish(slug string) error {
return c.withConnection(func(conn net.Conn) error {
func (c *Client) Establish(ctx context.Context, slug string) error {
return c.withConnection(ctx, func(conn net.Conn) error {
writef(conn, "establish %s", slug)

// this goes out to the API; don't time it out aggressively
Expand All @@ -124,8 +136,30 @@ func (c *Client) Establish(slug string) error {
})
}

func (c *Client) Probe(o *api.Organization) error {
return c.withConnection(func(conn net.Conn) error {
func (c *Client) WaitForTunnel(ctx context.Context, o *api.Organization) error {
for {
err := c.Probe(ctx, o)
switch {
case err == nil:
return nil
case err == context.Canceled || err == context.DeadlineExceeded:
return err
case errors.Is(err, &ErrProbeFailed{}):
continue
}
}
}

type ErrProbeFailed struct {
Msg string
}

func (e *ErrProbeFailed) Error() string {
return fmt.Sprintf("probe failed: %s", e.Msg)
}

func (c *Client) Probe(ctx context.Context, o *api.Organization) error {
return c.withConnection(ctx, func(conn net.Conn) error {
writef(conn, "probe %s", o.Slug)

reply, err := read(conn)
Expand All @@ -134,17 +168,17 @@ func (c *Client) Probe(o *api.Organization) error {
}

if string(reply) != "ok" {
return fmt.Errorf("probe failed: %s", string(reply))
return &ErrProbeFailed{Msg: string(reply)}
}

return nil
})
}

func (c *Client) Instances(o *api.Organization, app string) (*Instances, error) {
func (c *Client) Instances(ctx context.Context, o *api.Organization, app string) (*Instances, error) {
var instances *Instances

err := c.withConnection(func(conn net.Conn) error {
err := c.withConnection(ctx, func(conn net.Conn) error {
writef(conn, "instances %s %s", o.Slug, app)

// this goes out to the network; don't time it out aggressively
Expand Down Expand Up @@ -180,8 +214,8 @@ type Dialer struct {
client *Client
}

func (c *Client) Dialer(o *api.Organization) (*Dialer, error) {
if err := c.Establish(o.Slug); err != nil {
func (c *Client) Dialer(ctx context.Context, o *api.Organization) (*Dialer, error) {
if err := c.Establish(ctx, o.Slug); err != nil {
return nil, err
}

Expand Down
5 changes: 3 additions & 2 deletions pkg/agent/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package agent

import (
"context"
"fmt"
"os/exec"
"syscall"
Expand All @@ -11,7 +12,7 @@ import (
"github.com/superfly/flyctl/api"
)

func StartDaemon(api *api.Client, command string) (*Client, error) {
func StartDaemon(ctx context.Context, api *api.Client, command string) (*Client, error) {
cmd := exec.Command(command, "agent", "daemon-start")
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
Expand All @@ -29,7 +30,7 @@ func StartDaemon(api *api.Client, command string) (*Client, error) {

c, err := DefaultClient(api)
if err == nil {
_, err := c.Ping()
_, err := c.Ping(ctx)
if err == nil {
return c, nil
}
Expand Down

0 comments on commit 9f6e14e

Please sign in to comment.