-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #93 from vladimirvivien/command-func-copy
Starlark - Implement Command Function `copy_from()`
- Loading branch information
Showing
17 changed files
with
1,053 additions
and
111 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Copyright (c) 2020 VMware, Inc. All Rights Reserved. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package ssh | ||
|
||
import ( | ||
"os" | ||
"testing" | ||
|
||
"github.com/sirupsen/logrus" | ||
|
||
testcrashd "github.com/vmware-tanzu/crash-diagnostics/testing" | ||
) | ||
|
||
var ( | ||
testSSHPort = testcrashd.NextSSHPort() | ||
testMaxRetries = 30 | ||
) | ||
|
||
func TestMain(m *testing.M) { | ||
testcrashd.Init() | ||
|
||
sshSvr := testcrashd.NewSSHServer(testcrashd.NextSSHContainerName(), testSSHPort) | ||
logrus.Debug("Attempting to start SSH server") | ||
if err := sshSvr.Start(); err != nil { | ||
logrus.Error(err) | ||
os.Exit(1) | ||
} | ||
|
||
testResult := m.Run() | ||
|
||
logrus.Debug("Stopping SSH server...") | ||
if err := sshSvr.Stop(); err != nil { | ||
logrus.Error(err) | ||
os.Exit(1) | ||
} | ||
|
||
os.Exit(testResult) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
// Copyright (c) 2020 VMware, Inc. All Rights Reserved. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package ssh | ||
|
||
import ( | ||
"fmt" | ||
"os" | ||
"path/filepath" | ||
"strings" | ||
"time" | ||
|
||
"github.com/sirupsen/logrus" | ||
"github.com/vladimirvivien/echo" | ||
"k8s.io/apimachinery/pkg/util/wait" | ||
) | ||
|
||
// CopyFrom copies one or more files using SCP from remote host | ||
// and returns the paths of files that were successfully copied. | ||
func CopyFrom(args SSHArgs, rootDir string, sourcePath string) error { | ||
e := echo.New() | ||
prog := e.Prog.Avail("scp") | ||
if len(prog) == 0 { | ||
return fmt.Errorf("scp program not found") | ||
} | ||
|
||
targetPath := filepath.Join(rootDir, sourcePath) | ||
targetDir := filepath.Dir(targetPath) | ||
pathDir, pathFile := filepath.Split(sourcePath) | ||
if strings.Index(pathFile, "*") != -1 { | ||
targetPath = filepath.Join(rootDir, pathDir) | ||
targetDir = targetPath | ||
} | ||
|
||
if err := os.MkdirAll(targetDir, 0744); err != nil && !os.IsExist(err) { | ||
return err | ||
} | ||
|
||
sshCmd, err := makeSCPCmdStr(prog, args, sourcePath) | ||
if err != nil { | ||
logrus.Debug() | ||
} | ||
|
||
effectiveCmd := fmt.Sprintf(`%s "%s"`, sshCmd, targetPath) | ||
logrus.Debug("scp: ", effectiveCmd) | ||
|
||
maxRetries := args.MaxRetries | ||
if maxRetries == 0 { | ||
maxRetries = 10 | ||
} | ||
retries := wait.Backoff{Steps: maxRetries, Duration: time.Millisecond * 80, Jitter: 0.1} | ||
if err := wait.ExponentialBackoff(retries, func() (bool, error) { | ||
p := e.RunProc(effectiveCmd) | ||
if p.Err() != nil { | ||
logrus.Warn(fmt.Sprintf("scp: failed to connect to %s: error '%s %s': retrying connection", args.Host, p.Err(), p.Result())) | ||
return false, nil | ||
} | ||
return true, nil // worked | ||
}); err != nil { | ||
logrus.Debugf("scp failed after %d tries", maxRetries) | ||
return fmt.Errorf("scp: failed after %d attempt(s): %s", maxRetries, err) | ||
} | ||
|
||
logrus.Debugf("scp: copied %s", sourcePath) | ||
return nil | ||
} | ||
|
||
func makeSCPCmdStr(progName string, args SSHArgs, sourcePath string) (string, error) { | ||
if args.User == "" { | ||
return "", fmt.Errorf("scp: user is required") | ||
} | ||
if args.Host == "" { | ||
return "", fmt.Errorf("scp: host is required") | ||
} | ||
|
||
if args.ProxyJump != nil { | ||
if args.ProxyJump.User == "" || args.ProxyJump.Host == "" { | ||
return "", fmt.Errorf("scp: jump user and host are required") | ||
} | ||
} | ||
|
||
scpCmdPrefix := func() string { | ||
return fmt.Sprintf("%s -rpq -o StrictHostKeyChecking=no", progName) | ||
} | ||
|
||
pkPath := func() string { | ||
if args.PrivateKeyPath != "" { | ||
return fmt.Sprintf("-i %s", args.PrivateKeyPath) | ||
} | ||
return "" | ||
} | ||
|
||
port := func() string { | ||
if args.Port == "" { | ||
return "-P 22" | ||
} | ||
return fmt.Sprintf("-P %s", args.Port) | ||
} | ||
|
||
proxyJump := func() string { | ||
if args.ProxyJump != nil { | ||
return fmt.Sprintf("-J %s@%s", args.ProxyJump.User, args.ProxyJump.Host) | ||
} | ||
return "" | ||
} | ||
// build command as | ||
// scp -i <pkpath> -P <port> -J <proxyjump> user@host:path | ||
cmd := fmt.Sprintf( | ||
`%s %s %s %s %s@%s:%s`, | ||
scpCmdPrefix(), pkPath(), port(), proxyJump(), args.User, args.Host, sourcePath, | ||
) | ||
return cmd, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
// Copyright (c) 2020 VMware, Inc. All Rights Reserved. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package ssh | ||
|
||
import ( | ||
"io/ioutil" | ||
"os" | ||
"os/user" | ||
"path/filepath" | ||
"strings" | ||
"testing" | ||
) | ||
|
||
func TestCopy(t *testing.T) { | ||
homeDir, err := os.UserHomeDir() | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
usr, err := user.Current() | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
pkPath := filepath.Join(homeDir, ".ssh/id_rsa") | ||
sshArgs := SSHArgs{User: usr.Username, PrivateKeyPath: pkPath, Host: "127.0.0.1", Port: testSSHPort, MaxRetries: testMaxRetries} | ||
tests := []struct { | ||
name string | ||
sshArgs SSHArgs | ||
rootDir string | ||
remoteFiles map[string]string | ||
srcFile string | ||
fileContent string | ||
}{ | ||
{ | ||
name: "copy single file", | ||
sshArgs: sshArgs, | ||
rootDir: "/tmp/crashd", | ||
remoteFiles: map[string]string{"foo.txt": "FooBar"}, | ||
srcFile: "foo.txt", | ||
fileContent: "FooBar", | ||
}, | ||
{ | ||
name: "copy single file in dir", | ||
sshArgs: sshArgs, | ||
rootDir: "/tmp/crashd", | ||
remoteFiles: map[string]string{"foo/bar.txt": "FooBar"}, | ||
srcFile: "foo/bar.txt", | ||
fileContent: "FooBar", | ||
}, | ||
{ | ||
name: "copy dir", | ||
sshArgs: sshArgs, | ||
rootDir: "/tmp/crashd", | ||
remoteFiles: map[string]string{"bar/foo.csv": "FooBar", "bar/bar.txt": "BarBar"}, | ||
srcFile: "bar/", | ||
fileContent: "FooBar", | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
t.Run(test.name, func(t *testing.T) { | ||
defer func() { | ||
for file, _ := range test.remoteFiles { | ||
RemoveTestSSHFile(t, test.sshArgs, file) | ||
} | ||
|
||
if err := os.RemoveAll(test.rootDir); err != nil { | ||
t.Fatal(err) | ||
} | ||
}() | ||
|
||
// setup remote files | ||
for file, content := range test.remoteFiles { | ||
MakeTestSSHFile(t, test.sshArgs, file, content) | ||
} | ||
|
||
if err := CopyFrom(test.sshArgs, test.rootDir, test.srcFile); err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
expectedPath := filepath.Join(test.rootDir, test.srcFile) | ||
finfo, err := os.Stat(expectedPath) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
if finfo.IsDir() { | ||
finfos, err := ioutil.ReadDir(expectedPath) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
if len(finfos) < len(test.remoteFiles) { | ||
t.Errorf("expecting %d copied files, got %d", len(finfos), len(test.remoteFiles)) | ||
} | ||
} else { | ||
if getTestFileContent(t, expectedPath) != test.fileContent { | ||
t.Error("unexpected file content") | ||
} | ||
} | ||
|
||
}) | ||
} | ||
} | ||
|
||
func TestMakeSCPCmdStr(t *testing.T) { | ||
tests := []struct { | ||
name string | ||
args SSHArgs | ||
cmdStr string | ||
source string | ||
shouldFail bool | ||
}{ | ||
{ | ||
name: "user and host", | ||
args: SSHArgs{User: "sshuser", Host: "local.host"}, | ||
source: "/tmp/any", | ||
cmdStr: "scp -rpq -o StrictHostKeyChecking=no -P 22 sshuser@local.host:/tmp/any", | ||
}, | ||
{ | ||
name: "user host and pkpath", | ||
args: SSHArgs{User: "sshuser", Host: "local.host", PrivateKeyPath: "/pk/path"}, | ||
source: "/foo/bar", | ||
cmdStr: "scp -rpq -o StrictHostKeyChecking=no -i /pk/path -P 22 sshuser@local.host:/foo/bar", | ||
}, | ||
{ | ||
name: "user host pkpath and proxy", | ||
args: SSHArgs{User: "sshuser", Host: "local.host", PrivateKeyPath: "/pk/path", ProxyJump: &ProxyJumpArgs{User: "juser", Host: "jhost"}}, | ||
source: "userFile", | ||
cmdStr: "scp -rpq -o StrictHostKeyChecking=no -i /pk/path -P 22 -J juser@jhost sshuser@local.host:userFile", | ||
}, | ||
{ | ||
name: "missing host", | ||
args: SSHArgs{User: "sshuser"}, | ||
shouldFail: true, | ||
}, | ||
} | ||
|
||
for _, test := range tests { | ||
t.Run(test.name, func(t *testing.T) { | ||
result, err := makeSCPCmdStr("scp", test.args, test.source) | ||
if err != nil && !test.shouldFail { | ||
t.Fatal(err) | ||
} | ||
cmdFields := strings.Fields(test.cmdStr) | ||
resultFields := strings.Fields(result) | ||
|
||
for i := range cmdFields { | ||
if cmdFields[i] != resultFields[i] { | ||
t.Fatalf("unexpected command string element: %s vs. %s", cmdFields, resultFields) | ||
} | ||
} | ||
}) | ||
} | ||
} |
Oops, something went wrong.