Skip to content

Commit

Permalink
transport&improved ssh/tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kellerza committed May 25, 2021
1 parent 618d547 commit 1072d21
Show file tree
Hide file tree
Showing 6 changed files with 407 additions and 169 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ containerlab
.vscode/
.DS_Store
__rd*
tests/out
tests/out

284 changes: 202 additions & 82 deletions clab/config/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,125 +10,245 @@ import (
"golang.org/x/crypto/ssh"
)

type Session struct {
type SshSession struct {
In io.Reader
Out io.WriteCloser
Session *ssh.Session
}

func NewSession(username, password, host string) (*Session, error) {

sshConfig := &ssh.ClientConfig{
User: username,
Auth: []ssh.AuthMethod{
ssh.Password(password),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
// The reply the execute command and the prompt.
type SshReply struct{ result, prompt string }

// SshTransport setting needs to be set before calling Connect()
// SshTransport implement the Transport interface
type SshTransport struct {
// Channel used to read. Can use Expect to Write & read wit timeout
in chan SshReply
// SSH Session
ses *SshSession
// Contains the first read after connecting
BootMsg SshReply

// SSH parameters used in connect
// defualt: 22
Port int
// SSH Options
// required!
SshConfig *ssh.ClientConfig
// Character to split the incoming stream (#/$/>)
// default: #
PromptChar string
// Prompt parsing function. Default return the last line of the #
// default: DefaultPrompParse
PromptParse func(in *string) *SshReply
}

connection, err := ssh.Dial("tcp", host, sshConfig)
if err != nil {
return nil, fmt.Errorf("failed to connect: %s", err)
}
session, err := connection.NewSession()
if err != nil {
return nil, err
}
sshIn, err := session.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("session stdout: %s", err)
}
sshOut, err := session.StdinPipe()
if err != nil {
return nil, fmt.Errorf("session stdin: %s", err)
// This is the default prompt parse function used by SSH transport
func DefaultPrompParse(in *string) *SshReply {
n := strings.LastIndex(*in, "\n")
res := (*in)[:n]
n = strings.LastIndex(res, "\n")
if n < 0 {
n = 0
}

if err := session.Shell(); err != nil {
session.Close()
return nil, fmt.Errorf("session shell: %s", err)
return &SshReply{
result: (*in)[:n],
prompt: (*in)[n:] + "#",
}

return &Session{
Session: session,
In: sshIn,
Out: sshOut,
}, nil
}

func (ses *Session) Close() {
log.Debugf("Closing sesison")
ses.Session.Close()
}

func (ses *Session) Expect(send, expect string, timeout int) string {
rChan := make(chan string)
// The channel does
func (t *SshTransport) InChannel() {
// Ensure we have one working channel
t.in = make(chan SshReply)

// setup a buffered string channel
go func() {
buf := make([]byte, 1024)
n, err := ses.In.Read(buf) //this reads the ssh terminal
tmpStr := ""
tmpS := ""
n, err := t.ses.In.Read(buf) //this reads the ssh terminal
if err == nil {
tmpStr = string(buf[:n])
tmpS = string(buf[:n])
}
for (err == nil) && (!strings.Contains(tmpStr, expect)) {
n, err = ses.In.Read(buf)
tmpStr += string(buf[:n])
for err == nil {

if strings.Contains(tmpS, "#") {
parts := strings.Split(tmpS, "#")
li := len(parts) - 1
for i := 0; i < li; i++ {
t.in <- *t.PromptParse(&parts[i])
}
tmpS = parts[li]
}
n, err = t.ses.In.Read(buf)
tmpS += string(buf[:n])
}
log.Debugf("In Channel closing: %v", err)
t.in <- SshReply{
result: tmpS,
prompt: "",
}
rChan <- tmpStr
}()

time.Sleep(10 * time.Millisecond)
t.BootMsg = t.Run("", 15)
log.Infof("%s\n", t.BootMsg.result)
log.Debugf("%s\n", t.BootMsg.prompt)
}

if send != "" {
ses.Write(send)
// Run a single command and wait for the reply
func (t *SshTransport) Run(command string, timeout int) SshReply {
if command != "" {
t.ses.Writeln(command)
}

// Read from the channel with a timeout
select {
case ret := <-rChan:
case ret := <-t.in:
if ret.result != "" {
rr := strings.Trim(ret.result, " \n")

if strings.HasPrefix(rr, command) {
rr = rr[len(command):]
fmt.Println(rr)
} else {
log.Errorf("'%s' != '%s'\n--", rr, command)
if !strings.Contains(rr, command) {
log.Errorln("YY")
t.Run("", 10)
}
}
}
return ret
case <-time.After(time.Duration(timeout) * time.Second):
log.Warnf("timeout waiting for %s", expect)
log.Warnf("timeout waiting for prompt: %s", command)
}
return ""
return SshReply{}
}

func (ses *Session) Write(command string) (int, error) {
returnCode, err := ses.Out.Write([]byte(command + "\r"))
return returnCode, err
// Write a config snippet (a set of commands)
// Session NEEDS to be configurable for other kinds
// Part of the Transport interface
func (t *SshTransport) Write(snip *ConfigSnippet) error {
t.Run("/configure global", 2)
t.Run("discard", 2)

c, b := 0, 0
for _, l := range snip.Lines() {
l = strings.TrimSpace(l)
if l == "" || strings.HasPrefix(l, "#") {
continue
}
c += 1
b += len(l)
t.Run(l, 3)
}

// Commit
commit := t.Run("commit", 10)
//commit += t.Run("", 10)
log.Infof("COMMIT %s - %d lines %d bytes\n%s", snip, c, b, commit)
return nil
}

// send multiple config to a device
func SendConfig(cs []*ConfigSnippet) error {
host := fmt.Sprintf("%s:22", cs[0].TargetNode.LongName)
// Connect to a host
// Part of the Transport interface
func (t *SshTransport) Connect(host string) error {
// Assign Default Values
if t.PromptParse == nil {
t.PromptParse = DefaultPrompParse
}
if t.PromptChar == "" {
t.PromptChar = "#"
}
if t.Port == 0 {
t.Port = 22
}
if t.SshConfig == nil {
return fmt.Errorf("require auth credentials in SshConfig")
}

ses, err := NewSession("admin", "admin", host)
if err != nil {
// Start some client config
host = fmt.Sprintf("%s:%d", host, t.Port)
//sshConfig := &ssh.ClientConfig{}
//SshConfigWithUserNamePassword(sshConfig, "admin", "admin")

ses_, err := NewSshSession(host, t.SshConfig)
if err != nil || ses_ == nil {
return fmt.Errorf("cannot connect to %s: %s", host, err)
}
defer ses.Close()
t.ses = ses_

log.Infof("Connected to %s\n", host)
t.InChannel()
//Read to first prompt
ses.Expect("", "#", 1)
// Enter config mode
ses.Expect("/configure global", "#", 10)
ses.Expect("discard", "#", 10)

for _, snip := range cs {
for _, l := range snip.Config {
l = strings.TrimSpace(l)
if l == "" || strings.HasPrefix(l, "#") {
continue
}
ses.Expect(l, "#", 3)
// fmt.Write("((%s))", res)
}
return nil
}

// Commit
commit := ses.Expect("commit", "commit", 10)
commit += ses.Expect("", "#", 10)
log.Infof("COMMIT %s\n%s", snip, commit)
// Close the Session and channels
// Part of the Transport interface
func (t *SshTransport) Close() {
// if t.in != nil {
// close(t.in)
// t.in = nil
// }
//t.ses.Close()
}

// Add a basic username & password to a config.
// Will initilize the config if required
func SshConfigWithUserNamePassword(config *ssh.ClientConfig, username, password string) {
if config == nil {
config = &ssh.ClientConfig{}
}
config.User = username
if config.Auth == nil {
config.Auth = []ssh.AuthMethod{}
}
config.Auth = append(config.Auth, ssh.Password(password))
config.HostKeyCallback = ssh.InsecureIgnoreHostKey()
}

return nil
// Create a new SSH session (Dial, open in/out pipes and start the shell)
// pass the authntication details in sshConfig
func NewSshSession(host string, sshConfig *ssh.ClientConfig) (*SshSession, error) {
if !strings.Contains(host, ":") {
return nil, fmt.Errorf("include the port in the host: %s", host)
}

connection, err := ssh.Dial("tcp", host, sshConfig)
if err != nil {
return nil, fmt.Errorf("failed to connect: %s", err)
}
session, err := connection.NewSession()
if err != nil {
return nil, err
}
sshIn, err := session.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("session stdout: %s", err)
}
sshOut, err := session.StdinPipe()
if err != nil {
return nil, fmt.Errorf("session stdin: %s", err)
}
if err := session.Shell(); err != nil {
session.Close()
return nil, fmt.Errorf("session shell: %s", err)
}

return &SshSession{
Session: session,
In: sshIn,
Out: sshOut,
}, nil
}

func (ses *SshSession) Writeln(command string) (int, error) {
return ses.Out.Write([]byte(command + "\r"))
}

func (ses *SshSession) Close() {
log.Debugf("Closing sesison")
ses.Session.Close()
}
Loading

0 comments on commit 1072d21

Please sign in to comment.