Skip to content

Commit

Permalink
perf(fetch): increase files transfer efficiency (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
windvalley committed Jan 3, 2024
1 parent 84e4866 commit b96b055
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 120 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Expand Up @@ -7,6 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [1.14.0]

### Added

Add flag `-z, --zip` for subcommand `fetch`.

### Changed

Improve the files transfer efficiency of the subcommand `fetch`.
The subcommand `fetch` no longer uses zip compression by default. If you want to continue using zip compression, you can add the `-z` flag to the command line.

## [1.13.0]

### Added
Expand Down
28 changes: 20 additions & 8 deletions internal/cmd/fetch.go
Expand Up @@ -33,9 +33,10 @@ import (
)

var (
srcFiles []string
localDstDir string
tmpDir string
srcFiles []string
localDstDir string
tmpDir string
enableZipFiles bool
)

// fetchCmd represents the fetch command
Expand All @@ -45,11 +46,14 @@ var fetchCmd = &cobra.Command{
Long: `
Copy files and dirs from target hosts to local.`,
Example: `
# Copy host1:/path/foo to local /tmp/backup/host1/path/foo.
$ gossh fetch host1 -f /path/foo -d /tmp/backup
Copy host1:/path/foo to local dir /tmp/backup/.
$ gossh fetch host1 -f /path/foo -d /tmp/backup -k
# Copy files and dirs from target hosts to local dir /tmp/backup/.
$ gossh fetch host[1-2] -f /path1/foo.txt,/path2/bar/ -d /tmp/backup
Copy files and dirs from target hosts to local dir /tmp/backup/.
$ gossh fetch host[1-2] -f /path1/foo.txt,/path2/bar/ -d /tmp/backup -k
Enable zip files feature (zip first, then fetch).
$ gossh fetch host[1-2] -f /path1/foo.txt,/path2/bar/ -d /tmp/backup -z -k
Find more examples at: https://github.com/windvalley/gossh/blob/main/docs/fetch.md`,
PreRun: func(cmd *cobra.Command, args []string) {
Expand All @@ -66,7 +70,7 @@ Copy files and dirs from target hosts to local.`,
if tmpDir == "$HOME" {
tmpDir = path.Join("/home", configflags.Config.Auth.User)
}
task.SetFetchOptions(localDstDir, tmpDir)
task.SetFetchOptions(localDstDir, tmpDir, enableZipFiles)

task.Start()

Expand All @@ -86,4 +90,12 @@ func init() {
fetchCmd.Flags().StringVarP(&tmpDir, "tmp-dir", "t", "$HOME",
"directory of target hosts for storing temporary zip file",
)

fetchCmd.Flags().BoolVarP(
&enableZipFiles,
"zip",
"z",
false,
"enable zip files ('zip' must be installed on target hosts)",
)
}
14 changes: 9 additions & 5 deletions internal/pkg/sshtask/sshtask.go
Expand Up @@ -25,7 +25,6 @@ package sshtask
import (
"errors"
"fmt"
"io/ioutil"
"net"
"os"
"regexp"
Expand Down Expand Up @@ -212,9 +211,10 @@ func (t *Task) SetPushOptions(destPath string, allowOverwrite, enableZip bool) {
}

// SetFetchOptions ...
func (t *Task) SetFetchOptions(destPath, tmpDir string) {
func (t *Task) SetFetchOptions(destPath, tmpDir string, enableZipFiles bool) {
t.dstDir = destPath
t.tmpDir = tmpDir
t.enableZip = enableZipFiles
}

// RunSSH implements batchssh.Task
Expand All @@ -231,7 +231,11 @@ func (t *Task) RunSSH(host *batchssh.Host) (string, error) {
case PushTask:
return t.sshClient.PushFiles(host, t.pushFiles.files, t.pushFiles.zipFiles, t.dstDir, t.allowOverwrite, t.enableZip)
case FetchTask:
return t.sshClient.FetchFiles(host, t.fetchFiles, t.dstDir, t.tmpDir, sudo, runAs)
hosts, err := t.getAllHosts()
if err != nil {
return "", err
}
return t.sshClient.FetchFiles(host, t.fetchFiles, t.dstDir, t.tmpDir, sudo, runAs, t.enableZip, len(hosts))
default:
return "", fmt.Errorf("unknown task type: %v", t.taskType)
}
Expand Down Expand Up @@ -703,7 +707,7 @@ func getDefaultPassword(auth *configflags.Auth) string {
if authFile != "" {
var passwordContent []byte

passwordContent, err := ioutil.ReadFile(authFile)
passwordContent, err := os.ReadFile(authFile)
if err != nil {
err = fmt.Errorf("read password file '%s' failed: %w", authFile, err)
}
Expand Down Expand Up @@ -766,7 +770,7 @@ func getSigners(keyfiles []string, passphrase string, authKind string) []ssh.Sig
}

func getSigner(keyfile, passphrase string) (ssh.Signer, string) {
buf, err := ioutil.ReadFile(keyfile)
buf, err := os.ReadFile(keyfile)
if err != nil {
return nil, fmt.Sprintf("read identity file '%s' failed: %s", keyfile, err)
}
Expand Down
129 changes: 24 additions & 105 deletions pkg/batchssh/batchssh.go
Expand Up @@ -28,7 +28,6 @@ import (
"io"
"net"
"os"
"path"
"path/filepath"
"strconv"
"strings"
Expand All @@ -39,7 +38,6 @@ import (
"golang.org/x/crypto/ssh"

"github.com/windvalley/gossh/pkg/log"
"github.com/windvalley/gossh/pkg/util"
)

const (
Expand Down Expand Up @@ -350,6 +348,8 @@ func (c *Client) FetchFiles(
dstDir, tmpDir string,
sudo bool,
runAs string,
enableZip bool,
hostCount int,
) (string, error) {
client, err := c.getClient(host)
if err != nil {
Expand All @@ -375,7 +375,7 @@ func (c *Client) FetchFiles(
continue
}

if !sudo {
if !sudo || !enableZip {
if err, ok := err1.(*sftp.StatusError); ok && err.Code == uint32(sftp.ErrSshFxPermissionDenied) {
noPermSrcFiles = append(noPermSrcFiles, f)
continue
Expand All @@ -402,73 +402,30 @@ func (c *Client) FetchFiles(
return "", err2
}

session, err := client.NewSession()
if err != nil {
return "", err
}
defer session.Close()

zippedFileTmpDir := path.Join(tmpDir, ".gossh-tmp-"+host.Host)
tmpZipFile := fmt.Sprintf("%s.%d", host.Host, time.Now().UnixMicro())
zippedFileFullpath := path.Join(zippedFileTmpDir, tmpZipFile)
_, err = c.executeCmd(
session,
fmt.Sprintf(
`if which zip &>/dev/null;then
sudo -u %s -H bash -c '[[ ! -d %s ]] && { mkdir -p %s;chmod 777 %s;};zip -r %s %s'
else
echo "need install 'zip' command"
exit 1
fi`,
runAs,
zippedFileTmpDir,
zippedFileTmpDir,
zippedFileTmpDir,
zippedFileFullpath,
strings.Join(validSrcFiles, " "),
),
host.Password,
)
if err != nil {
log.Debugf("zip %s of %s failed: %s", strings.Join(validSrcFiles, ","), host.Host, err)
return "", err
}

file, err := c.fetchZipFile(ftpC, zippedFileFullpath, dstDir)
if err == nil {
file.Close()
}
if err != nil {
log.Debugf("fetch zip file '%s' from %s failed: %s", zippedFileFullpath, host.Host, err)
return "", err
}

session2, err := client.NewSession()
if err != nil {
return "", err
}
defer session2.Close()

_, err = c.executeCmd(
session2,
fmt.Sprintf("sudo -u %s -H bash -c 'rm -f %s'", runAs, zippedFileFullpath),
host.Password,
)
if err != nil {
log.Debugf("remove '%s:%s' failed: %s", host.Host, zippedFileFullpath, err)
return "", err
if hostCount > 1 {
dstDir = filepath.Join(dstDir, host.Host)
err = os.MkdirAll(dstDir, os.ModePerm)
if err != nil {
log.Errorf("make local dir '%s' failed: %v", dstDir, err)
return "", err
}
log.Debugf("make local dir '%s'", dstDir)
}

finalDstDir := path.Join(dstDir, host.Host)
localZippedFileFullpath := path.Join(dstDir, tmpZipFile)
defer func() {
if err := os.Remove(localZippedFileFullpath); err != nil {
log.Debugf("remove '%s' failed: %s", localZippedFileFullpath, err)
if enableZip {
for _, f := range validSrcFiles {
err = c.fetchFileWithZip(client, ftpC, f, dstDir, tmpDir, runAs, host)
if err != nil {
return "", err
}
}
} else {
for _, f := range validSrcFiles {
err = c.fetchFileOrDir(ftpC, f, dstDir, host.Host)
if err != nil {
return "", err
}
}
}()
if err := util.Unzip(localZippedFileFullpath, finalDstDir); err != nil {
log.Debugf("unzip '%s' to '%s' failed: %s", localZippedFileFullpath, finalDstDir, err)
return "", err
}

hasOrHave := "has"
Expand Down Expand Up @@ -562,44 +519,6 @@ func (c *Client) executeCmd(session *ssh.Session, command, password string) (str
return outputStr, nil
}

func (c *Client) fetchZipFile(
ftpC *sftp.Client,
srcZipFile, dstDir string,
) (*sftp.File, error) {
homeDir := os.Getenv("HOME")
if strings.HasPrefix(dstDir, "~/") {
srcZipFile = strings.Replace(dstDir, "~", homeDir, 1)
}

srcZipFileName := filepath.Base(srcZipFile)
dstZipFile := path.Join(dstDir, srcZipFileName)

file, err := ftpC.Open(srcZipFile)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("'%s' not exist", srcZipFile)
}

if err, ok := err.(*sftp.StatusError); ok && err.Code == uint32(sftp.ErrSshFxPermissionDenied) {
return nil, fmt.Errorf("no permission to open '%s'", srcZipFile)
}

return nil, err
}

zipFile, err := os.Create(dstZipFile)
if err != nil {
return nil, fmt.Errorf("open local '%s' failed: %w", dstZipFile, err)
}

_, err = file.WriteTo(zipFile)
if err != nil {
return nil, err
}

return file, nil
}

func (c *Client) getClient(host *Host) (*ssh.Client, error) {
var (
client *ssh.Client
Expand Down

0 comments on commit b96b055

Please sign in to comment.