Skip to content

Commit

Permalink
Merge pull request #93 from vladimirvivien/command-func-copy
Browse files Browse the repository at this point in the history
Starlark - Implement Command Function `copy_from()`
  • Loading branch information
vladimirvivien committed Jun 30, 2020
2 parents a57f337 + d17fad8 commit 7ea48e8
Show file tree
Hide file tree
Showing 17 changed files with 1,053 additions and 111 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/compile-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ jobs:

- name: test
run: |
sudo ufw allow 2222
sudo ufw allow 2424
sudo ufw allow 2200:2300/tcp
sudo ufw enable
sudo ufw status verbose
mkdir -p ~/.ssh
chmod 765 ~/.ssh
cp testing/keys/* ~/.ssh/
GO111MODULE=on go get sigs.k8s.io/kind@v0.7.0
GO111MODULE=on go test -timeout 600s -v ./...
GO111MODULE=on go test -timeout 600s -v -p 1 ./...
28 changes: 1 addition & 27 deletions ssh/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,10 @@ import (
"path/filepath"
"strings"
"testing"

"github.com/sirupsen/logrus"
testcrashd "github.com/vmware-tanzu/crash-diagnostics/testing"
)

const (
testSSHPort = "2424"
)

func TestMain(m *testing.M) {
testcrashd.Init()

sshSvr := testcrashd.NewSSHServer("test-sshd-sshclient", testSSHPort)
logrus.Debug("Attempting to start SSH server")
if err := sshSvr.Start(); err != nil {
logrus.Error(err)
os.Exit(1)
}

testResult := m.Run()

logrus.Debug("Stopping SSH server...")
if err := sshSvr.Stop(); err != nil {
logrus.Error(err)
os.Exit(1)
}

os.Exit(testResult)
}
func TestSSHClient(t *testing.T) {
t.Skip("Skipping ssh client tests")
sshHost := fmt.Sprintf("127.0.0.1:%s", testSSHPort)
homeDir, err := os.UserHomeDir()
if err != nil {
Expand Down
39 changes: 39 additions & 0 deletions ssh/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) 2020 VMware, Inc. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package ssh

import (
"os"
"testing"

"github.com/sirupsen/logrus"

testcrashd "github.com/vmware-tanzu/crash-diagnostics/testing"
)

var (
testSSHPort = testcrashd.NextSSHPort()
testMaxRetries = 30
)

func TestMain(m *testing.M) {
testcrashd.Init()

sshSvr := testcrashd.NewSSHServer(testcrashd.NextSSHContainerName(), testSSHPort)
logrus.Debug("Attempting to start SSH server")
if err := sshSvr.Start(); err != nil {
logrus.Error(err)
os.Exit(1)
}

testResult := m.Run()

logrus.Debug("Stopping SSH server...")
if err := sshSvr.Stop(); err != nil {
logrus.Error(err)
os.Exit(1)
}

os.Exit(testResult)
}
113 changes: 113 additions & 0 deletions ssh/scp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (c) 2020 VMware, Inc. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package ssh

import (
"fmt"
"os"
"path/filepath"
"strings"
"time"

"github.com/sirupsen/logrus"
"github.com/vladimirvivien/echo"
"k8s.io/apimachinery/pkg/util/wait"
)

// CopyFrom copies one or more files using SCP from remote host
// and returns the paths of files that were successfully copied.
func CopyFrom(args SSHArgs, rootDir string, sourcePath string) error {
e := echo.New()
prog := e.Prog.Avail("scp")
if len(prog) == 0 {
return fmt.Errorf("scp program not found")
}

targetPath := filepath.Join(rootDir, sourcePath)
targetDir := filepath.Dir(targetPath)
pathDir, pathFile := filepath.Split(sourcePath)
if strings.Index(pathFile, "*") != -1 {
targetPath = filepath.Join(rootDir, pathDir)
targetDir = targetPath
}

if err := os.MkdirAll(targetDir, 0744); err != nil && !os.IsExist(err) {
return err
}

sshCmd, err := makeSCPCmdStr(prog, args, sourcePath)
if err != nil {
logrus.Debug()
}

effectiveCmd := fmt.Sprintf(`%s "%s"`, sshCmd, targetPath)
logrus.Debug("scp: ", effectiveCmd)

maxRetries := args.MaxRetries
if maxRetries == 0 {
maxRetries = 10
}
retries := wait.Backoff{Steps: maxRetries, Duration: time.Millisecond * 80, Jitter: 0.1}
if err := wait.ExponentialBackoff(retries, func() (bool, error) {
p := e.RunProc(effectiveCmd)
if p.Err() != nil {
logrus.Warn(fmt.Sprintf("scp: failed to connect to %s: error '%s %s': retrying connection", args.Host, p.Err(), p.Result()))
return false, nil
}
return true, nil // worked
}); err != nil {
logrus.Debugf("scp failed after %d tries", maxRetries)
return fmt.Errorf("scp: failed after %d attempt(s): %s", maxRetries, err)
}

logrus.Debugf("scp: copied %s", sourcePath)
return nil
}

func makeSCPCmdStr(progName string, args SSHArgs, sourcePath string) (string, error) {
if args.User == "" {
return "", fmt.Errorf("scp: user is required")
}
if args.Host == "" {
return "", fmt.Errorf("scp: host is required")
}

if args.ProxyJump != nil {
if args.ProxyJump.User == "" || args.ProxyJump.Host == "" {
return "", fmt.Errorf("scp: jump user and host are required")
}
}

scpCmdPrefix := func() string {
return fmt.Sprintf("%s -rpq -o StrictHostKeyChecking=no", progName)
}

pkPath := func() string {
if args.PrivateKeyPath != "" {
return fmt.Sprintf("-i %s", args.PrivateKeyPath)
}
return ""
}

port := func() string {
if args.Port == "" {
return "-P 22"
}
return fmt.Sprintf("-P %s", args.Port)
}

proxyJump := func() string {
if args.ProxyJump != nil {
return fmt.Sprintf("-J %s@%s", args.ProxyJump.User, args.ProxyJump.Host)
}
return ""
}
// build command as
// scp -i <pkpath> -P <port> -J <proxyjump> user@host:path
cmd := fmt.Sprintf(
`%s %s %s %s %s@%s:%s`,
scpCmdPrefix(), pkPath(), port(), proxyJump(), args.User, args.Host, sourcePath,
)
return cmd, nil
}
155 changes: 155 additions & 0 deletions ssh/scp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright (c) 2020 VMware, Inc. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package ssh

import (
"io/ioutil"
"os"
"os/user"
"path/filepath"
"strings"
"testing"
)

func TestCopy(t *testing.T) {
homeDir, err := os.UserHomeDir()
if err != nil {
t.Fatal(err)
}

usr, err := user.Current()
if err != nil {
t.Fatal(err)
}
pkPath := filepath.Join(homeDir, ".ssh/id_rsa")
sshArgs := SSHArgs{User: usr.Username, PrivateKeyPath: pkPath, Host: "127.0.0.1", Port: testSSHPort, MaxRetries: testMaxRetries}
tests := []struct {
name string
sshArgs SSHArgs
rootDir string
remoteFiles map[string]string
srcFile string
fileContent string
}{
{
name: "copy single file",
sshArgs: sshArgs,
rootDir: "/tmp/crashd",
remoteFiles: map[string]string{"foo.txt": "FooBar"},
srcFile: "foo.txt",
fileContent: "FooBar",
},
{
name: "copy single file in dir",
sshArgs: sshArgs,
rootDir: "/tmp/crashd",
remoteFiles: map[string]string{"foo/bar.txt": "FooBar"},
srcFile: "foo/bar.txt",
fileContent: "FooBar",
},
{
name: "copy dir",
sshArgs: sshArgs,
rootDir: "/tmp/crashd",
remoteFiles: map[string]string{"bar/foo.csv": "FooBar", "bar/bar.txt": "BarBar"},
srcFile: "bar/",
fileContent: "FooBar",
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
defer func() {
for file, _ := range test.remoteFiles {
RemoveTestSSHFile(t, test.sshArgs, file)
}

if err := os.RemoveAll(test.rootDir); err != nil {
t.Fatal(err)
}
}()

// setup remote files
for file, content := range test.remoteFiles {
MakeTestSSHFile(t, test.sshArgs, file, content)
}

if err := CopyFrom(test.sshArgs, test.rootDir, test.srcFile); err != nil {
t.Fatal(err)
}

expectedPath := filepath.Join(test.rootDir, test.srcFile)
finfo, err := os.Stat(expectedPath)
if err != nil {
t.Fatal(err)
}

if finfo.IsDir() {
finfos, err := ioutil.ReadDir(expectedPath)
if err != nil {
t.Fatal(err)
}
if len(finfos) < len(test.remoteFiles) {
t.Errorf("expecting %d copied files, got %d", len(finfos), len(test.remoteFiles))
}
} else {
if getTestFileContent(t, expectedPath) != test.fileContent {
t.Error("unexpected file content")
}
}

})
}
}

func TestMakeSCPCmdStr(t *testing.T) {
tests := []struct {
name string
args SSHArgs
cmdStr string
source string
shouldFail bool
}{
{
name: "user and host",
args: SSHArgs{User: "sshuser", Host: "local.host"},
source: "/tmp/any",
cmdStr: "scp -rpq -o StrictHostKeyChecking=no -P 22 sshuser@local.host:/tmp/any",
},
{
name: "user host and pkpath",
args: SSHArgs{User: "sshuser", Host: "local.host", PrivateKeyPath: "/pk/path"},
source: "/foo/bar",
cmdStr: "scp -rpq -o StrictHostKeyChecking=no -i /pk/path -P 22 sshuser@local.host:/foo/bar",
},
{
name: "user host pkpath and proxy",
args: SSHArgs{User: "sshuser", Host: "local.host", PrivateKeyPath: "/pk/path", ProxyJump: &ProxyJumpArgs{User: "juser", Host: "jhost"}},
source: "userFile",
cmdStr: "scp -rpq -o StrictHostKeyChecking=no -i /pk/path -P 22 -J juser@jhost sshuser@local.host:userFile",
},
{
name: "missing host",
args: SSHArgs{User: "sshuser"},
shouldFail: true,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result, err := makeSCPCmdStr("scp", test.args, test.source)
if err != nil && !test.shouldFail {
t.Fatal(err)
}
cmdFields := strings.Fields(test.cmdStr)
resultFields := strings.Fields(result)

for i := range cmdFields {
if cmdFields[i] != resultFields[i] {
t.Fatalf("unexpected command string element: %s vs. %s", cmdFields, resultFields)
}
}
})
}
}

0 comments on commit 7ea48e8

Please sign in to comment.