Skip to content

Commit

Permalink
Merge pull request #189 from vladimirvivien/copy_to-command
Browse files Browse the repository at this point in the history
Implements support for copy_to() starlark function
  • Loading branch information
vladimirvivien committed Nov 6, 2020
2 parents 7e95157 + f6492f4 commit 3e9c3f5
Show file tree
Hide file tree
Showing 11 changed files with 769 additions and 121 deletions.
3 changes: 3 additions & 0 deletions ssh/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func (agent *agent) Stop() error {

logrus.Debugf("stopping the ssh-agent with Pid: %s", agent.Pid)
p := echo.New().Env(agent.GetEnvVariables()).RunProc("ssh-agent -k")
logrus.Debugf("ssh-agent stopped: %s", p.Result())

return p.Err()
}
Expand All @@ -106,6 +107,7 @@ func StartAgent() (Agent, error) {
return nil, fmt.Errorf("ssh-agent not found")
}

logrus.Debugf("starting %s", sshAgentCmd)
p := e.RunProc(fmt.Sprintf("%s -s", sshAgentCmd))
if p.Err() != nil {
return nil, errors.Wrap(p.Err(), "failed to start ssh agent")
Expand All @@ -119,6 +121,7 @@ func StartAgent() (Agent, error) {
return nil, err
}

logrus.Debugf("ssh-agent started %v", agentInfo)
return agentFromInfo(agentInfo), nil
}

Expand Down
81 changes: 69 additions & 12 deletions ssh/scp.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,16 @@ func CopyFrom(args SSHArgs, agent Agent, rootDir string, sourcePath string) erro
return err
}

sshCmd, err := makeSCPCmdStr(prog, args, sourcePath)
sshCmd, err := makeSCPCmdStr(prog, args)
if err != nil {
return fmt.Errorf("scp: failed to build command string: %s", err)
return fmt.Errorf("scp: copyFrom: failed to build command string: %s", err)
}

effectiveCmd := fmt.Sprintf(`%s %s`, sshCmd, targetPath)
logrus.Debug("scp: ", effectiveCmd)
effectiveCmd := fmt.Sprintf(`%s %s`, sshCmd, getCopyFromSourceTarget(args, sourcePath, targetPath))
logrus.Debugf("scp: copFrom: cmd: [%s]", effectiveCmd)

if agent != nil {
logrus.Debugf("Adding agent info: %s", agent.GetEnvVariables())
logrus.Debugf("scp: copyFrom: adding agent info: %s", agent.GetEnvVariables())
e = e.Env(agent.GetEnvVariables())
}

Expand All @@ -57,20 +57,69 @@ func CopyFrom(args SSHArgs, agent Agent, rootDir string, sourcePath string) erro
if err := wait.ExponentialBackoff(retries, func() (bool, error) {
p := e.RunProc(effectiveCmd)
if p.Err() != nil {
logrus.Warn(fmt.Sprintf("scp: failed to connect to %s: error '%s %s': retrying connection", args.Host, p.Err(), p.Result()))
logrus.Warn(fmt.Sprintf("scp: copyFrom: failed to connect to %s: '%s %s': retrying connection", args.Host, p.Err(), p.Result()))
return false, nil
}
return true, nil // worked
}); err != nil {
logrus.Debugf("scp failed after %d tries", maxRetries)
return fmt.Errorf("scp: failed after %d attempt(s): %s", maxRetries, err)
return fmt.Errorf("scp: copyFrom: failed after %d attempt(s): %s", maxRetries, err)
}

logrus.Debugf("scp: copied %s", sourcePath)
logrus.Debugf("scp: copyFrom: copied %s", sourcePath)
return nil
}

func makeSCPCmdStr(progName string, args SSHArgs, sourcePath string) (string, error) {
// CopyTo copies one or more files using SCP from local machine to
// remote host.
func CopyTo(args SSHArgs, agent Agent, sourcePath, targetPath string) error {
e := echo.New()
prog := e.Prog.Avail("scp")
if len(prog) == 0 {
return fmt.Errorf("scp program not found")
}

if len(sourcePath) == 0 {
return fmt.Errorf("scp: copyTo: missing source path")
}

if len(targetPath) == 0 {
return fmt.Errorf("scp: copyTo: missing target path")
}

sshCmd, err := makeSCPCmdStr(prog, args)
if err != nil {
return fmt.Errorf("scp: copyTo: failed to build command string: %s", err)
}

effectiveCmd := fmt.Sprintf(`%s %s`, sshCmd, getCopyToSourceTarget(args, sourcePath, targetPath))
logrus.Debugf("scp: copyTo: cmd: [%s]", effectiveCmd)

if agent != nil {
logrus.Debugf("scp: adding agent info: %s", agent.GetEnvVariables())
e = e.Env(agent.GetEnvVariables())
}

maxRetries := args.MaxRetries
if maxRetries == 0 {
maxRetries = 10
}
retries := wait.Backoff{Steps: maxRetries, Duration: time.Millisecond * 80, Jitter: 0.1}
if err := wait.ExponentialBackoff(retries, func() (bool, error) {
p := e.RunProc(effectiveCmd)
if p.Err() != nil {
logrus.Warn(fmt.Sprintf("scp: failed to connect to %s: '%s %s': retrying connection", args.Host, p.Err(), p.Result()))
return false, nil
}
return true, nil // worked
}); err != nil {
return fmt.Errorf("scp: copyTo: failed after %d attempt(s): %s", maxRetries, err)
}

logrus.Debugf("scp: copyTo: copied %s -> %s", sourcePath, targetPath)
return nil
}

func makeSCPCmdStr(progName string, args SSHArgs) (string, error) {
if args.User == "" {
return "", fmt.Errorf("scp: user is required")
}
Expand Down Expand Up @@ -111,8 +160,16 @@ func makeSCPCmdStr(progName string, args SSHArgs, sourcePath string) (string, er
// build command as
// scp -i <pkpath> -P <port> -J <proxyjump> user@host:path
cmd := fmt.Sprintf(
`%s %s %s %s %s@%s:%s`,
scpCmdPrefix(), pkPath(), port(), proxyJump(), args.User, args.Host, sourcePath,
`%s %s %s %s`,
scpCmdPrefix(), pkPath(), port(), proxyJump(),
)
return cmd, nil
}

func getCopyFromSourceTarget(args SSHArgs, sourcePath, targetPath string) string {
return fmt.Sprintf("%s@%s:%s %s", args.User, args.Host, sourcePath, targetPath)
}

func getCopyToSourceTarget(args SSHArgs, sourcePath, targetPath string) string {
return fmt.Sprintf("%s %s@%s:%s", sourcePath, args.User, args.Host, targetPath)
}
171 changes: 116 additions & 55 deletions ssh/scp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ import (
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
)

func TestCopy(t *testing.T) {
func TestCopyFrom(t *testing.T) {
tests := []struct {
name string
sshArgs SSHArgs
Expand Down Expand Up @@ -45,13 +46,13 @@ func TestCopy(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
defer func() {
for file := range test.remoteFiles {
RemoveTestSSHFile(t, test.sshArgs, file)
RemoveRemoteTestSSHFile(t, test.sshArgs, file)
}
}()

// setup fake remote files
// setup fake files
for file, content := range test.remoteFiles {
MakeTestSSHFile(t, test.sshArgs, file, content)
MakeRemoteTestSSHFile(t, test.sshArgs, file, content)
}

if err := CopyFrom(test.sshArgs, nil, support.TmpDirRoot(), test.srcFile); err != nil {
Expand Down Expand Up @@ -82,54 +83,114 @@ func TestCopy(t *testing.T) {
}
}

//
//func TestMakeSCPCmdStr(t *testing.T) {
// tests := []struct {
// name string
// args SSHArgs
// cmdStr string
// source string
// shouldFail bool
// }{
// {
// name: "user and host",
// args: SSHArgs{User: "sshuser", Host: "local.host"},
// source: "/tmp/any",
// cmdStr: "scp -rpq -o StrictHostKeyChecking=no -P 22 sshuser@local.host:/tmp/any",
// },
// {
// name: "user host and pkpath",
// args: SSHArgs{User: "sshuser", Host: "local.host", PrivateKeyPath: "/pk/path"},
// source: "/foo/bar",
// cmdStr: "scp -rpq -o StrictHostKeyChecking=no -i /pk/path -P 22 sshuser@local.host:/foo/bar",
// },
// {
// name: "user host pkpath and proxy",
// args: SSHArgs{User: "sshuser", Host: "local.host", PrivateKeyPath: "/pk/path", ProxyJump: &ProxyJumpArgs{User: "juser", Host: "jhost"}},
// source: "userFile",
// cmdStr: "scp -rpq -o StrictHostKeyChecking=no -i /pk/path -P 22 -J juser@jhost sshuser@local.host:userFile",
// },
// {
// name: "missing host",
// args: SSHArgs{User: "sshuser"},
// shouldFail: true,
// },
// }
//
// for _, test := range tests {
// t.Run(test.name, func(t *testing.T) {
// result, err := makeSCPCmdStr("scp", test.args, test.source)
// if err != nil && !test.shouldFail {
// t.Fatal(err)
// }
// cmdFields := strings.Fields(test.cmdStr)
// resultFields := strings.Fields(result)
//
// for i := range cmdFields {
// if cmdFields[i] != resultFields[i] {
// t.Fatalf("unexpected command string element: %s vs. %s", cmdFields, resultFields)
// }
// }
// })
// }
//}
func TestCopyTo(t *testing.T) {
tests := []struct {
name string
sshArgs SSHArgs
localFiles map[string]string
file string
fileContent string
}{
{
name: "copy single file to remote",
sshArgs: testSSHArgs,
localFiles: map[string]string{"local-foo.txt": "FooBar"},
file: "local-foo.txt",
fileContent: "FooBar",
},
{
name: "copy single file in dir to remote",
sshArgs: testSSHArgs,
localFiles: map[string]string{"local-foo/local-bar.txt": "FooBar"},
file: "local-foo/local-bar.txt",
fileContent: "FooBar",
},
{
name: "copy dir entire dir to remote",
sshArgs: testSSHArgs,
localFiles: map[string]string{"local-bar/local-foo.csv": "FooBar", "local-bar/local-bar.txt": "BarBar"},
file: "local-bar/",
fileContent: "FooBar",
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
defer func() {
for file := range test.localFiles {
RemoveLocalTestFile(t, filepath.Join(support.TmpDirRoot(), file))
RemoveRemoteTestSSHFile(t, test.sshArgs, file)
}
}()

// setup fake local files
for file, content := range test.localFiles {
MakeLocalTestFile(t, filepath.Join(support.TmpDirRoot(), file), content)
}

// create remote dir if needed
// setup remote dir if needed
MakeRemoteTestSSHDir(t, test.sshArgs, test.file)

sourceFile := filepath.Join(support.TmpDirRoot(), test.file)
if err := CopyTo(test.sshArgs, nil, sourceFile, test.file); err != nil {
t.Fatal(err)
}

// validate copied files/dir
AssertRemoteTestSSHFile(t, test.sshArgs, test.file)

})
}
}

func TestMakeSCPCmdStr(t *testing.T) {
tests := []struct {
name string
args SSHArgs
cmdStr string
source string
shouldFail bool
}{
{
name: "default",
args: SSHArgs{User: "sshuser", Host: "local.host"},
source: "/tmp/any",
cmdStr: "scp -rpq -o StrictHostKeyChecking=no -P 22",
},
{
name: "pkpath",
args: SSHArgs{User: "sshuser", Host: "local.host", PrivateKeyPath: "/pk/path"},
source: "/foo/bar",
cmdStr: "scp -rpq -o StrictHostKeyChecking=no -i /pk/path -P 22",
},
{
name: "pkpath and proxy",
args: SSHArgs{User: "sshuser", Host: "local.host", PrivateKeyPath: "/pk/path", ProxyJump: &ProxyJumpArgs{User: "juser", Host: "jhost"}},
source: "userFile",
cmdStr: "scp -rpq -o StrictHostKeyChecking=no -i /pk/path -P 22 -J juser@jhost",
},
{
name: "missing host",
args: SSHArgs{User: "sshuser"},
shouldFail: true,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result, err := makeSCPCmdStr("scp", test.args)
if err != nil && !test.shouldFail {
t.Fatal(err)
}
cmdFields := strings.Fields(test.cmdStr)
resultFields := strings.Fields(result)

for i := range cmdFields {
if cmdFields[i] != resultFields[i] {
t.Fatalf("unexpected command string element: %s vs. %s", cmdFields, resultFields)
}
}
})
}
}
Loading

0 comments on commit 3e9c3f5

Please sign in to comment.