/
sshclient.go
229 lines (205 loc) · 6.72 KB
/
sshclient.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
// Copyright 2020 Platform9 Systems Inc.
package ssh
// The contents of this file were shamelessly copied from the SSH provider code base of cctl.
// The cctl ssh-provider can't handle large files, hence the copy; arguably the
// original source should have been modified instead.
import (
"bufio"
"fmt"
"io"
"io/ioutil"
"os"
"github.com/pkg/sftp"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
)
// Client runs commands on, and transfers files to and from, a remote host.
type Client interface {
	// RunCommand executes cmd on the remote host and returns its stdout, its
	// stderr and any error associated with running it.
	RunCommand(cmd string) ([]byte, []byte, error)
	// UploadFile uploads the local file srcFilePath to remoteDstFilePath and
	// changes the remote file's mode to mode. cb, if non-nil, is called with
	// the number of bytes read so far and the total size of the source file.
	UploadFile(srcFilePath, remoteDstFilePath string, mode os.FileMode, cb func(read int64, total int64)) error
	// DownloadFile downloads the remote file remoteFile to localPath and
	// changes the local file's mode to mode. cb, if non-nil, is called with
	// the number of bytes read so far and the total size of the remote file.
	DownloadFile(remoteFile, localPath string, mode os.FileMode, cb func(read int64, total int64)) error
}
// client is the Client implementation returned by NewClient. Commands run
// over sshClient, file transfers go through sftpClient (which NewClient
// layers on top of the same SSH connection), and proxyURL, when set, is
// exported as https_proxy on the command line of every RunCommand call.
type client struct {
	// sshClient is the authenticated connection used to open command sessions.
	sshClient *ssh.Client
	// sftpClient performs the UploadFile and DownloadFile transfers.
	sftpClient *sftp.Client
	// proxyURL, if non-empty, is prepended to commands as https_proxy=<url>.
	proxyURL string
}
// SudoPassword, when non-empty, is piped to `sudo -S` by RunCommand before the
// requested command is run with sudo. It is package-level state and therefore
// shared by every Client created by this package.
var SudoPassword string

// runAsSudo makes RunCommand prefix every remote command with sudo.
const runAsSudo = true
// NewClient dials host:port over TCP, authenticates as username and returns a
// Client that can run commands and transfer files on that machine.
//
// When privateKey is non-empty it is parsed and used for public-key
// authentication; otherwise password authentication with password is used.
// proxyURL, when non-empty, is exported as https_proxy for commands executed
// through RunCommand.
//
// If the SFTP subsystem cannot be started, the SSH connection is closed and an
// error is returned, so a returned Client always has a working SFTP session.
func NewClient(host string, port int, username string, privateKey []byte, password, proxyURL string) (Client, error) {
	// Give preference to the private key; an empty key means "not provided".
	var auth ssh.AuthMethod
	if len(privateKey) > 0 {
		signer, err := ssh.ParsePrivateKey(privateKey)
		if err != nil {
			return nil, fmt.Errorf("error parsing private key: %w", err)
		}
		auth = ssh.PublicKeys(signer)
	} else {
		auth = ssh.Password(password)
	}

	sshConfig := &ssh.ClientConfig{
		User: username,
		Auth: []ssh.AuthMethod{auth},
		// By default ignore host key checks.
		// NOTE(review): this accepts any host key, so the connection is open to
		// man-in-the-middle attacks; consider a known_hosts based callback.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	// NOTE(review): "%s:%d" does not bracket IPv6 literals; net.JoinHostPort would.
	sshClient, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", host, port), sshConfig)
	if err != nil {
		return nil, fmt.Errorf("unable to dial %s:%d: %w", host, port, err)
	}

	sftpClient, err := sftp.NewClient(sshClient)
	if err != nil {
		// Don't leak the SSH connection; the sftp error is the one worth reporting.
		_ = sshClient.Close()
		return nil, fmt.Errorf("unable to start sftp session on %s:%d: %w", host, port, err)
	}

	return &client{
		sshClient:  sshClient,
		sftpClient: sftpClient,
		proxyURL:   proxyURL,
	}, nil
}
// RunCommand runs cmd on the remote machine and returns its stdout and stderr
// separately.
//
// Because runAsSudo is true the command is run through sudo; when SudoPassword
// is set it is first fed to `sudo -S`. When the client has a proxy URL, an
// https_proxy assignment is prepended to the command line.
//
// When the command exits unsuccessfully (or without an exit status) the
// captured stdout and stderr are returned together with an error wrapping the
// *ssh.ExitError / *ssh.ExitMissingError, so callers can use errors.As on it.
// Error messages mention cmd as passed in, never the wrapped command line,
// because that line may contain SudoPassword.
func (c *client) RunCommand(cmd string) ([]byte, []byte, error) {
	session, err := c.sshClient.NewSession()
	if err != nil {
		return nil, nil, fmt.Errorf("unable to create session: %w", err)
	}
	// Each session owns an SSH channel; close it so repeated calls don't leak
	// channels on the shared connection. After Wait the channel is usually
	// already closed, so the error is deliberately ignored.
	defer session.Close()

	stdOutPipe, err := session.StdoutPipe()
	if err != nil {
		return nil, nil, fmt.Errorf("unable to pipe stdout: %w", err)
	}
	stdErrPipe, err := session.StderrPipe()
	if err != nil {
		return nil, nil, fmt.Errorf("unable to pipe stderr: %w", err)
	}

	if err := session.Start(buildRemoteCommand(cmd, SudoPassword, c.proxyURL)); err != nil {
		return nil, nil, fmt.Errorf("unable to run command: %w", err)
	}

	// stdout and stderr share a fixed amount of channel buffering, so both must
	// be drained concurrently: reading one to EOF while the other fills up can
	// stall the remote command and deadlock this call.
	type readResult struct {
		data []byte
		err  error
	}
	stdErrCh := make(chan readResult, 1)
	go func() {
		data, err := ioutil.ReadAll(stdErrPipe)
		stdErrCh <- readResult{data: data, err: err}
	}()
	stdOut, stdOutErr := ioutil.ReadAll(stdOutPipe)
	stdErrRes := <-stdErrCh
	stdErr := stdErrRes.data

	if err := session.Wait(); err != nil {
		zap.L().Debug("Error ", zap.String("stdout", string(stdOut)), zap.String("stderr", string(stdErr)))
		if _, ok := err.(*ssh.ExitMissingError); ok {
			return stdOut, stdErr, fmt.Errorf("command %s failed (no exit status): %w", cmd, err)
		}
		return stdOut, stdErr, fmt.Errorf("command %s failed: %w", cmd, err)
	}
	if stdOutErr != nil {
		return stdOut, stdErr, fmt.Errorf("unable to read stdout of command %s: %w", cmd, stdOutErr)
	}
	if stdErrRes.err != nil {
		return stdOut, stdErr, fmt.Errorf("unable to read stderr of command %s: %w", cmd, stdErrRes.err)
	}
	return stdOut, stdErr, nil
}

// buildRemoteCommand returns the shell command line RunCommand executes for
// cmd: wrapped in sudo when runAsSudo is set (feeding sudoPassword to
// `sudo -S` first when it is non-empty), then prefixed with an https_proxy
// assignment when proxyURL is non-empty.
func buildRemoteCommand(cmd, sudoPassword, proxyURL string) string {
	if runAsSudo {
		if sudoPassword != "" {
			// `sudo -S su` reads the password from stdin, presumably so that
			// the following `sudo cmd` can reuse the cached credential. The
			// password is single-quoted so shell metacharacters in it are not
			// interpreted, and printf avoids echo's option/escape handling.
			cmd = fmt.Sprintf(`printf '%%s\n' %s | sudo -S su ; sudo %s`, shellQuote(sudoPassword), cmd)
		} else {
			cmd = fmt.Sprintf("sudo %s", cmd)
		}
	}
	if proxyURL != "" {
		// NOTE(review): a prefix assignment only applies to the first simple
		// command of the line (with a sudo password that is printf, not cmd),
		// and sudo's env_reset may drop it too; confirm the intended scope.
		cmd = fmt.Sprintf("https_proxy=%s %s", proxyURL, cmd)
	}
	return cmd
}

// shellQuote wraps s in single quotes for a POSIX shell, rewriting every
// embedded ' as '\'' so the result is always a single literal word.
func shellQuote(s string) string {
	quoted := make([]byte, 0, len(s)+2)
	quoted = append(quoted, '\'')
	for i := 0; i < len(s); i++ {
		if s[i] == '\'' {
			quoted = append(quoted, `'\''`...)
			continue
		}
		quoted = append(quoted, s[i])
	}
	quoted = append(quoted, '\'')
	return string(quoted)
}
// UploadFile copies the local file localFile to remoteFilePath on the remote
// machine over SFTP and then sets the remote file's mode to mode.
//
// cb, when non-nil, is invoked after every read from the local file with the
// number of bytes read so far and the local file's total size.
//
// If writing, changing the mode, or closing the remote file fails, the partial
// remote file is removed on a best-effort basis before the error is returned.
func (c *client) UploadFile(localFile string, remoteFilePath string, mode os.FileMode, cb func(read int64, total int64)) error {
	// First check that the local file exists and can be opened.
	localFp, err := os.Open(localFile)
	if err != nil {
		return fmt.Errorf("unable to read localFile: %w", err)
	}
	defer localFp.Close()

	fInfo, err := localFp.Stat()
	if err != nil {
		return fmt.Errorf("unable to find size of the file %s: %w", localFile, err)
	}

	// Wrap the buffered reader so cb is called after each read.
	progressReader := newProgressCBReader(fInfo.Size(), bufio.NewReader(localFp), cb)

	remoteFile, err := c.sftpClient.Create(remoteFilePath)
	if err != nil {
		return fmt.Errorf("unable to create file: %w", err)
	}
	// cleanup closes and removes the partial remote file. Errors are ignored:
	// the original failure is what gets reported and we can't do much else.
	cleanup := func() {
		_ = remoteFile.Close()
		_ = c.sftpClient.Remove(remoteFilePath)
	}

	// ReadFrom is io.ReaderFrom: remoteFile pulls from progressReader and
	// writes what it reads to the remote path.
	if _, err := remoteFile.ReadFrom(progressReader); err != nil {
		cleanup()
		return fmt.Errorf("write failed: %w", err)
	}
	if err := remoteFile.Chmod(mode); err != nil {
		cleanup()
		return fmt.Errorf("chmod failed: %w", err)
	}
	// Close explicitly so a failure to finalize the upload is reported.
	if err := remoteFile.Close(); err != nil {
		_ = c.sftpClient.Remove(remoteFilePath)
		return fmt.Errorf("unable to close remote file: %w", err)
	}
	return nil
}
// DownloadFile copies remoteFile from the remote machine to localFilePath over
// SFTP and then sets the local file's mode to mode.
//
// cb, when non-nil, is invoked after every read from the remote file with the
// number of bytes read so far and the remote file's total size.
//
// If copying, changing the mode, or closing the local file fails, the partial
// local file is removed before the error is returned.
func (c *client) DownloadFile(remoteFile string, localFilePath string, mode os.FileMode, cb func(read int64, total int64)) error {
	// Check that the remote file exists and can be opened.
	remoteFP, err := c.sftpClient.Open(remoteFile)
	if err != nil {
		return fmt.Errorf("unable to read remoteFile: %w", err)
	}
	defer remoteFP.Close()

	fInfo, err := remoteFP.Stat()
	if err != nil {
		return fmt.Errorf("unable to find size of remoteFile: %w", err)
	}
	progressReader := newProgressCBReader(fInfo.Size(), bufio.NewReader(remoteFP), cb)

	localFile, err := os.Create(localFilePath)
	if err != nil {
		return fmt.Errorf("unable to create local file: %w", err)
	}
	// cleanup closes, then deletes, the partial local file. Closing first
	// matters because some platforms (e.g. Windows) refuse to remove open
	// files. Errors are ignored; the original failure is what gets reported.
	cleanup := func() {
		_ = localFile.Close()
		_ = os.Remove(localFilePath)
	}

	if _, err := io.Copy(localFile, progressReader); err != nil {
		cleanup()
		return fmt.Errorf("unable to copy data: %w", err)
	}
	if err := localFile.Chmod(mode); err != nil {
		cleanup()
		return fmt.Errorf("chmod failed: %w", err)
	}
	// Close explicitly: close can surface deferred write errors (see close(2)),
	// which a deferred, ignored Close would silently drop.
	if err := localFile.Close(); err != nil {
		_ = os.Remove(localFilePath)
		return fmt.Errorf("unable to close local file: %w", err)
	}
	return nil
}
// ProgressCBReader wraps an io.Reader and, after every Read on it, reports
// the cumulative number of bytes read so far together with TotalSize to
// ProgressCB (when ProgressCB is non-nil). The callback fires on every Read
// call, including ones that return zero bytes.
type ProgressCBReader struct {
	TotalSize  int64
	ReadCount  int64
	ProgressCB func(read int64, total int64)
	OrigReader io.Reader
}

// newProgressCBReader returns a *ProgressCBReader, as an io.Reader, that reads
// from orig and reports progress against totalSize to cb. cb may be nil, in
// which case reads are simply passed through.
func newProgressCBReader(totalSize int64, orig io.Reader, cb func(read int64, total int64)) io.Reader {
	return &ProgressCBReader{
		OrigReader: orig,
		ProgressCB: cb,
		TotalSize:  totalSize,
	}
}

// Read implements io.Reader: it forwards to OrigReader, adds the byte count to
// ReadCount, notifies ProgressCB, and returns OrigReader's results unchanged.
func (r *ProgressCBReader) Read(p []byte) (int, error) {
	n, err := r.OrigReader.Read(p)
	r.ReadCount += int64(n)
	if notify := r.ProgressCB; notify != nil {
		notify(r.ReadCount, r.TotalSize)
	}
	return n, err
}