-
Notifications
You must be signed in to change notification settings - Fork 26
/
scp.go
156 lines (127 loc) · 3.76 KB
/
scp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
package commands
import (
"errors"
"fmt"
"os"
"os/exec"
"strings"
"github.com/docker/machine/cli"
"github.com/docker/machine/libmachine/host"
"github.com/docker/machine/libmachine/log"
"github.com/docker/machine/libmachine/persist"
)
// errMalformedInput is returned by getInfoForScpArg when an scp operand
// splits into more than two ':'-separated parts.
var errMalformedInput = errors.New("The input was malformed")

// errWrongNumberArguments is returned by cmdScp when it is not given exactly
// two arguments.
var errWrongNumberArguments = errors.New("Improper number of arguments")
// baseSSHArgs are the "-o" ssh client options that every scp invocation
// starts from: restrict authentication to the configured identities, skip
// strict host-key checking, send known-host entries to /dev/null, and quiet
// ssh's own logging.
// TODO: possibly move this to ssh package
var baseSSHArgs = []string{
	"-o", "IdentitiesOnly=yes",
	"-o", "StrictHostKeyChecking=no",
	"-o", "UserKnownHostsFile=/dev/null",
	// LogLevel=quiet suppresses messages such as
	// "Warning: Permanently added '[localhost]:2022' (ECDSA) to the list of known hosts."
	"-o", "LogLevel=quiet",
}

// hostLoader resolves machine names to hosts for getInfoForScpArg. cmdScp
// installs the real implementation before use; it is package-level so that a
// different implementation can be substituted.
var hostLoader HostLoader
// HostLoader looks up a machine by name in a store. It exists only to wrap
// the loadHost() function so the scp helpers are easier to test.
//
// TODO: Remove this hack in favor of better strategy.
type HostLoader interface {
	LoadHost(persist.Store, string) (*host.Host, error)
}

// ScpHostLoader is the HostLoader used by the real scp command; it simply
// delegates to loadHost.
type ScpHostLoader struct{}

// Compile-time check that *ScpHostLoader satisfies HostLoader.
var _ HostLoader = (*ScpHostLoader)(nil)

// LoadHost returns whatever loadHost returns for the machine called name.
func (l *ScpHostLoader) LoadHost(store persist.Store, name string) (*host.Host, error) {
	h, err := loadHost(store, name)
	return h, err
}
// getInfoForScpArg interprets a single scp operand by splitting it on ':'.
//
//   - No ':' (e.g. "/tmp/foo"): a local path. It returns a nil host, the
//     operand unchanged, and no extra options.
//   - Exactly one ':' (e.g. "machinename:/usr/bin/cmatrix"): a remote path.
//     It loads the named machine through hostLoader and returns that host, the
//     part after the colon, and the "-i <ssh key path>" options scp needs. A
//     load failure is returned as "Error loading host: ...".
//   - More than one ':': errMalformedInput.
//
// NOTE(review): because the split is on every ':', a local path that itself
// contains a colon is either treated as remote or rejected.
// TODO: What to do about colon in filepath?
func getInfoForScpArg(hostAndPath string, store persist.Store) (*host.Host, string, []string, error) {
	parts := strings.Split(hostAndPath, ":")

	switch len(parts) {
	case 1:
		return nil, parts[0], nil, nil

	case 2:
		machineName, remotePath := parts[0], parts[1]
		h, err := hostLoader.LoadHost(store, machineName)
		if err != nil {
			return nil, "", nil, fmt.Errorf("Error loading host: %s", err)
		}
		keyOpts := []string{"-i", h.Driver.GetSSHKeyPath()}
		return h, remotePath, keyOpts, nil

	default:
		return nil, "", nil, errMalformedInput
	}
}
// generateLocationArg renders one scp operand. With a nil host, path is a
// local path and is returned unchanged. Otherwise the result is
// "<ssh username>@<ip>:<path>", using the host driver's SSH username and IP.
// An error from the driver's GetIP is returned unchanged.
func generateLocationArg(h *host.Host, path string) (string, error) {
	if h == nil {
		return path, nil
	}

	ip, err := h.Driver.GetIP()
	if err != nil {
		return "", err
	}
	return fmt.Sprintf("%s@%s:%s", h.Driver.GetSSHUsername(), ip, path), nil
}
// getScpCmd builds, but does not run, the scp command that copies src to
// dest. Each operand is either a local path or "machine:path". A remote
// operand is rewritten to "<ssh user>@<ip>:path" and contributes its
// "-i <key>" options.
//
// The argument list passed to scp is, in order: sshArgs, the source's
// options, the destination's options, the source location, and the
// destination location. sshArgs itself is never written to.
//
// It returns an error if no scp binary is found on PATH or if either operand
// cannot be resolved.
func getScpCmd(src, dest string, sshArgs []string, store persist.Store) (*exec.Cmd, error) {
	cmdPath, err := exec.LookPath("scp")
	if err != nil {
		return nil, errors.New("Error: You must have a copy of the scp binary locally to use the scp feature.")
	}

	srcHost, srcPath, srcOpts, err := getInfoForScpArg(src, store)
	if err != nil {
		return nil, err
	}
	destHost, destPath, destOpts, err := getInfoForScpArg(dest, store)
	if err != nil {
		return nil, err
	}

	// Actual scp operands, i.e. docker@<ip>:/path or a plain local path.
	srcLocation, err := generateLocationArg(srcHost, srcPath)
	if err != nil {
		return nil, err
	}
	destLocation, err := generateLocationArg(destHost, destPath)
	if err != nil {
		return nil, err
	}

	// Assemble into a freshly allocated slice. Appending straight onto sshArgs
	// would write into the caller's backing array whenever it has spare
	// capacity.
	args := make([]string, 0, len(sshArgs)+len(srcOpts)+len(destOpts)+2)
	args = append(args, sshArgs...)
	args = append(args, srcOpts...) // -i / private key flags
	args = append(args, destOpts...)
	args = append(args, srcLocation, destLocation)

	cmd := exec.Command(cmdPath, args...)
	// Log the argument vector instead of dereferencing (copying) the whole
	// exec.Cmd struct, whose unset fields and pointers only add noise.
	log.Debug(cmd.Args)
	return cmd, nil
}
// runCmdWithStdIo connects cmd to this process's stdin, stdout and stderr,
// runs it to completion, and returns the error from cmd.Run. cmd is received
// by value, so the stream assignments are made on this function's copy only.
func runCmdWithStdIo(cmd exec.Cmd) error {
	cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr
	return cmd.Run()
}
// cmdScp implements the "scp" command. It expects exactly two arguments, a
// source and a destination, each either a local path or "machine:path". With
// any other argument count it shows the command help and returns
// errWrongNumberArguments. The "recursive" flag adds scp's -r option. The
// resulting scp process runs with this process's standard streams attached.
func cmdScp(c *cli.Context) error {
	// Install the real loader; hostLoader is a package-level seam that exists
	// for easier testing.
	hostLoader = &ScpHostLoader{}

	args := c.Args()
	if len(args) != 2 {
		cli.ShowCommandHelp(c, "scp")
		return errWrongNumberArguments
	}

	// Start from a private copy of baseSSHArgs. Writing
	// `sshArgs := append(baseSSHArgs, ...)` only allocates a new array while
	// the package-level slice happens to have len == cap. Otherwise the two
	// slices would share (and overwrite) the same backing array.
	sshArgs := make([]string, 0, len(baseSSHArgs)+2)
	sshArgs = append(sshArgs, baseSSHArgs...)

	// TODO: Check that "-3" flag is available in user's version of scp.
	// It is on every system I've checked, but the manual mentioned it's "newer"
	sshArgs = append(sshArgs, "-3")
	if c.Bool("recursive") {
		sshArgs = append(sshArgs, "-r")
	}

	src, dest := args[0], args[1]
	cmd, err := getScpCmd(src, dest, sshArgs, getStore(c))
	if err != nil {
		return err
	}
	return runCmdWithStdIo(*cmd)
}