Skip to content

Commit

Permalink
Handle error properly during port forwarding initialization. (#2550)
Browse files Browse the repository at this point in the history
  • Loading branch information
cedric-appdirect committed Jun 7, 2024
1 parent e4f3811 commit f4ce8c5
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions port_forwarding.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package testcontainers

import (
"context"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -231,20 +232,22 @@ func (sshdC *sshdContainer) exposeHostPort(ctx context.Context, ports ...int) er
go pw.Forward(ctx)
}

var err error

// continue when all port forwarders have created the connection
for _, pfw := range sshdC.portForwarders {
<-pfw.connectionCreated
err = errors.Join(err, <-pfw.connectionCreated)
}

return nil
return err
}

type PortForwarder struct {
sshDAddr string
sshConfig *ssh.ClientConfig
remotePort int
localPort int
connectionCreated chan struct{} // used to signal that the connection has been created, so the caller can proceed
connectionCreated chan error // used to signal that the connection has been created, so the caller can proceed
terminateChan chan struct{} // used to signal that the connection has been terminated
}

Expand All @@ -254,7 +257,7 @@ func NewPortForwarder(sshDAddr string, sshConfig *ssh.ClientConfig, remotePort,
sshConfig: sshConfig,
remotePort: remotePort,
localPort: localPort,
connectionCreated: make(chan struct{}),
connectionCreated: make(chan error),
terminateChan: make(chan struct{}),
}
}
Expand All @@ -267,18 +270,22 @@ func (pf *PortForwarder) Close(ctx context.Context) {
func (pf *PortForwarder) Forward(ctx context.Context) error {
client, err := ssh.Dial("tcp", pf.sshDAddr, pf.sshConfig)
if err != nil {
return fmt.Errorf("error dialing ssh server: %w", err)
err = fmt.Errorf("error dialing ssh server: %w", err)
pf.connectionCreated <- err
return err
}
defer client.Close()

listener, err := client.Listen("tcp", fmt.Sprintf("localhost:%d", pf.remotePort))
if err != nil {
return fmt.Errorf("error listening on remote port: %w", err)
err = fmt.Errorf("error listening on remote port: %w", err)
pf.connectionCreated <- err
return err
}
defer listener.Close()

// signal that the connection has been created
pf.connectionCreated <- struct{}{}
pf.connectionCreated <- nil

// check if the context or the terminateChan has been closed
select {
Expand Down

0 comments on commit f4ce8c5

Please sign in to comment.