Skip to content

Commit

Permalink
Merge pull request #150 from scrapli/refactor/housekeeping
Browse files Browse the repository at this point in the history
refactor/channel-and-misc-improvements
  • Loading branch information
carlmontanari committed Aug 26, 2023
2 parents 3689f12 + 7d74411 commit 7cb3012
Show file tree
Hide file tree
Showing 13 changed files with 188 additions and 82 deletions.
11 changes: 6 additions & 5 deletions .github/workflows/commit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
unit-test:
runs-on: ${{ matrix.os }}
strategy:
max-parallel: 8
max-parallel: 10
matrix:
os:
- ubuntu-latest
Expand All @@ -19,19 +19,20 @@ jobs:
- "1.17"
- "1.18"
- "1.19"
- "1.20"
steps:
- name: checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
fetch-depth: 1
- name: set up go ${{ matrix.go-version }}
uses: actions/setup-go@v4
with:
go-version: ${{ matrix.go-version }}
- name: lint
uses: golangci/golangci-lint-action@v3
with:
version: v1.52
version: v1.54
args: --timeout 5m
- name: install gotestsum
run: go install gotest.tools/gotestsum@latest
Expand All @@ -45,7 +46,7 @@ jobs:
os:
- ubuntu-latest
go-version:
- "1.19"
- "1.20"
runtime:
- "docker"
needs:
Expand All @@ -54,7 +55,7 @@ jobs:
- name: checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
fetch-depth: 1
- name: set up go ${{ matrix.go-version }}
uses: actions/setup-go@v4
with:
Expand Down
4 changes: 2 additions & 2 deletions .golangci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,12 @@ issues:
text: "package-comments"

run:
go: '1.19'
go: '1.20'
skip-dirs:
- private

output:
uniq-by-line: false

service:
golangci-lint-version: 1.52.x
golangci-lint-version: 1.54.x
95 changes: 88 additions & 7 deletions channel/auth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package channel

import (
"bytes"
"fmt"
"regexp"
"sync"
Expand All @@ -21,9 +22,17 @@ type authPatterns struct {
passphrase *regexp.Regexp
}

type sshErrorMessagePatterns struct {
offeredOptions *regexp.Regexp
badConfig *regexp.Regexp
}

var (
authPatternsInstance *authPatterns //nolint:gochecknoglobals
authPatternsInstanceOnce sync.Once //nolint:gochecknoglobals

sshErrorMessagePatternsInstance *sshErrorMessagePatterns //nolint:gochecknoglobals
sshErrorMessagePatternsOnce sync.Once //nolint:gochecknoglobals
)

func getAuthPatterns() *authPatterns {
Expand All @@ -38,6 +47,17 @@ func getAuthPatterns() *authPatterns {
return authPatternsInstance
}

func getSSHErrorMessagePatterns() *sshErrorMessagePatterns {
sshErrorMessagePatternsOnce.Do(func() {
sshErrorMessagePatternsInstance = &sshErrorMessagePatterns{
offeredOptions: regexp.MustCompile(`(?im)their offer: ([a-z0-9\-,]*)`),
badConfig: regexp.MustCompile(`(?im)bad configuration option: ([a-z0-9+=,]*)`),
}
})

return sshErrorMessagePatternsInstance
}

func (c *Channel) authenticateSSH(p, pp []byte) *result {
pCount := 0

Expand All @@ -46,22 +66,23 @@ func (c *Channel) authenticateSSH(p, pp []byte) *result {
var b []byte

for {
nb, err := c.ReadUntilAnyPrompt(
[]*regexp.Regexp{c.PromptPattern, c.PasswordPattern, c.PassphrasePattern},
)
nb, err := c.Read()
if err != nil {
return &result{nil, err}
}

b = append(b, nb...)

err = c.sshMessageHandler(b)
if err != nil {
return &result{nil, err}
}

if c.PromptPattern.Match(b) {
return &result{b, nil}
}

if c.PasswordPattern.Match(b) { //nolint:nestif
b = []byte{}

pCount++

if pCount > passwordSeenMax {
Expand All @@ -80,9 +101,10 @@ func (c *Channel) authenticateSSH(p, pp []byte) *result {
if err != nil {
return &result{nil, err}
}
} else if c.PassphrasePattern.Match(b) {
b = []byte{}

// reset the buffer so we don't re-read things and so we can find the prompt (hopefully)
b = []byte{}
} else if c.PassphrasePattern.Match(b) {
ppCount++

if ppCount > passphraseSeenMax {
Expand All @@ -105,6 +127,8 @@ func (c *Channel) authenticateSSH(p, pp []byte) *result {
if err != nil {
return &result{nil, err}
}

b = []byte{}
}
}
}
Expand Down Expand Up @@ -225,3 +249,60 @@ func (c *Channel) AuthenticateTelnet(u, p []byte) ([]byte, error) {
)
}
}

func (c *Channel) sshMessageHandler(b []byte) error { //nolint:gocyclo
var errorMessage string

normalizedB := bytes.ToLower(b)

switch {
case bytes.Contains(normalizedB, []byte("host key verification failed")):
errorMessage = "host key verification failed"
case bytes.Contains(normalizedB, []byte("operation timed out")) ||
bytes.Contains(normalizedB, []byte("connection timed out")):
errorMessage = "timed out connecting to host"
case bytes.Contains(normalizedB, []byte("no route to host")):
errorMessage = "no route to host"
case bytes.Contains(normalizedB, []byte("no matching")):
switch {
case bytes.Contains(normalizedB, []byte("no matching host key")):
errorMessage = "no matching host key found for host"
case bytes.Contains(normalizedB, []byte("no matching key exchange")):
errorMessage = "no matching key exchange found for host"
case bytes.Contains(normalizedB, []byte("no matching cipher")):
errorMessage = "no matching cipher found for host"
}

patterns := getSSHErrorMessagePatterns()

theirOffer := patterns.offeredOptions.FindSubmatch(b)
if len(theirOffer) > 0 {
errorMessage += fmt.Sprintf(", their offer: %s", theirOffer[0])
}
case bytes.Contains(normalizedB, []byte("bad configuration")):
errorMessage = "bad ssh configuration option(s) for host"

patterns := getSSHErrorMessagePatterns()

badOption := patterns.offeredOptions.FindSubmatch(b)
if len(badOption) > 0 {
errorMessage += fmt.Sprintf(", bad configuration option: %s", badOption[0])
}
case bytes.Contains(normalizedB, []byte("warning: unprotected private key file")):
errorMessage = "permissions for private key are too open"
case bytes.Contains(normalizedB, []byte("could not resolve hostname")):
errorMessage = "could not resolve hostname"
case bytes.Contains(normalizedB, []byte("permission denied")):
errorMessage = "permission denied"
}

if errorMessage != "" {
return fmt.Errorf(
"%w: encountered error output during in channel ssh authentication, error: '%s'",
util.ErrConnectionError,
errorMessage,
)
}

return nil
}
61 changes: 32 additions & 29 deletions channel/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@ const (
// DefaultTimeoutOpsSeconds is the default time value for operations -- 60 seconds.
DefaultTimeoutOpsSeconds = 60
// DefaultReadDelayMicroSeconds is the default value for the delay between reads of the
// transport -- 100 microseconds. Going very low is likely to lead to very high cpu and not
// transport -- 250 microseconds. Going very low is likely to lead to very high cpu and not
// yield any recognizable gains, so be careful changing this!
DefaultReadDelayMicroSeconds = 250
// DefaultReturnChar is the character used to send an "enter" key to the device, "\n".
DefaultReturnChar = "\n"
// DefaultPromptSearchDepth -- is the default depth to search for the prompt in the received
// bytes.
DefaultPromptSearchDepth = 1_000
redacted = "redacted"
readDelayDivisor = 1_000

redacted = "redacted"
readDelayDivisor = 1_000
)

var (
Expand Down Expand Up @@ -136,39 +137,41 @@ func (c *Channel) Open() (reterr error) {

go c.read()

if !c.AuthBypass {
var b []byte
if c.AuthBypass {
c.l.Debug("auth bypass is enabled, skipping in channel auth check")

return nil
}

authData := c.t.InChannelAuthData()
var b []byte

switch authData.Type {
case transport.InChannelAuthSSH:
c.l.Debug("transport requests in channel ssh auth, starting...")
authData := c.t.InChannelAuthData()

b, err = c.AuthenticateSSH(
[]byte(authData.Password),
[]byte(authData.PrivateKeyPassPhrase),
)
if err != nil {
return err
}
case transport.InChannelAuthTelnet:
c.l.Debug("transport requests in channel telnet auth, starting...")
switch authData.Type {
case transport.InChannelAuthSSH:
c.l.Debug("transport requests in channel ssh auth, starting...")

b, err = c.AuthenticateTelnet([]byte(authData.User), []byte(authData.Password))
if err != nil {
return err
}
b, err = c.AuthenticateSSH(
[]byte(authData.Password),
[]byte(authData.PrivateKeyPassPhrase),
)
if err != nil {
return err
}
case transport.InChannelAuthTelnet:
c.l.Debug("transport requests in channel telnet auth, starting...")

if len(b) > 0 {
// requeue any buffer data we get during in channel authentication back onto the
// read buffer. mostly this should only be relevant for netconf where we need to
// read the server capabilities.
c.Q.Requeue(b)
b, err = c.AuthenticateTelnet([]byte(authData.User), []byte(authData.Password))
if err != nil {
return err
}
} else {
c.l.Debug("auth bypass is enabled, skipping in channel auth check")
}

if len(b) > 0 {
// requeue any buffer data we get during in channel authentication back onto the
// read buffer. mostly this should only be relevant for netconf where we need to
// read the server capabilities.
c.Q.Requeue(b)
}

return nil
Expand Down
5 changes: 4 additions & 1 deletion channel/getprompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ func (c *Channel) GetPrompt() ([]byte, error) {

b, err = c.ReadUntilPrompt()

cr <- &result{b: b, err: err}
// we already know the pattern is in the buf, we just want ot re to yoink it out without
// any newlines or extra stuff we read (which shouldn't happen outside the initial
// connection but...)
cr <- &result{b: c.PromptPattern.Find(b), err: err}
}()

timer := time.NewTimer(c.TimeoutOps)
Expand Down
Loading

0 comments on commit 7cb3012

Please sign in to comment.