Skip to content

Commit

Permalink
refactor: optimize errors handler
Browse files Browse the repository at this point in the history
  • Loading branch information
windvalley committed Jan 10, 2024
1 parent d33dc07 commit 5c31ad0
Show file tree
Hide file tree
Showing 15 changed files with 65 additions and 58 deletions.
4 changes: 2 additions & 2 deletions internal/cmd/command.go
Expand Up @@ -73,7 +73,7 @@ Execute commands on target hosts.`,
Example: commandCmdExamples,
PreRun: func(cmd *cobra.Command, args []string) {
if errs := configflags.Config.Validate(); len(errs) != 0 {
util.CheckErr(errs)
util.PrintErrExit(errs)
}

if noSafeCheck {
Expand All @@ -87,7 +87,7 @@ Execute commands on target hosts.`,
}

if err := checkCommand(shellCommand, configflags.Config.Run.CommandBlacklist); err != nil {
util.CheckErr(err)
util.PrintErrExit(err)
}
}
},
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/fetch.go
Expand Up @@ -58,7 +58,7 @@ Copy files and dirs from target hosts to local.`,
Find more examples at: https://github.com/windvalley/gossh/blob/main/docs/fetch.md`,
PreRun: func(cmd *cobra.Command, args []string) {
if errs := configflags.Config.Validate(); len(errs) != 0 {
util.CheckErr(errs)
util.PrintErrExit(errs)
}
},
Run: func(cmd *cobra.Command, args []string) {
Expand Down
10 changes: 5 additions & 5 deletions internal/cmd/push.go
Expand Up @@ -64,14 +64,14 @@ Copy local files and dirs to target hosts.`,
Find more examples at: https://github.com/windvalley/gossh/blob/main/docs/push.md`,
PreRun: func(cmd *cobra.Command, args []string) {
if errs := configflags.Config.Validate(); len(errs) != 0 {
util.CheckErr(errs)
util.PrintErrExit(errs)
}

if len(files) != 0 {
for _, f := range files {
_, err := os.Stat(f)
if err != nil {
util.CheckErr(err)
util.PrintErrExit(err)
}
}
}
Expand All @@ -88,7 +88,7 @@ Copy local files and dirs to target hosts.`,

workDir, err := os.Getwd()
if err != nil {
util.CheckErr(err)
util.PrintErrExit(err)
}

for _, f := range files {
Expand All @@ -97,12 +97,12 @@ Copy local files and dirs to target hosts.`,
zipFile := path.Join(workDir, zipName)

if err := util.Zip(strings.TrimSuffix(f, string(os.PathSeparator)), zipFile); err != nil {
util.CheckErr(err)
util.PrintErrExit(err)
}

stat, err := os.Stat(zipFile)
if err != nil {
util.CheckErr(err)
util.PrintErrExit(err)
}
//nolint:gomnd
log.Debugf("zip file '%s' size: %d MB", zipFile, stat.Size()/1024/1024)
Expand Down
13 changes: 8 additions & 5 deletions internal/cmd/root.go
Expand Up @@ -90,7 +90,9 @@ func initConfig() {
} else {
// Find home directory.
home, err := os.UserHomeDir()
util.CheckErr(err)
if err != nil {
util.PrintErrExit(err)
}

// Search the default configuration file.
viper.AddConfigPath(".")
Expand All @@ -99,21 +101,22 @@ func initConfig() {
viper.SetConfigName(".gossh")
}

viper.AutomaticEnv() // read in environment variables that match
// Read in environment variables that match.
viper.AutomaticEnv()

// If a config file is found, read it in.
_ = viper.ReadInConfig()

if err := viper.BindPFlags(rootCmd.PersistentFlags()); err != nil {
util.CheckErr(err)
util.PrintErrExit(err)
}

if err := viper.Unmarshal(&configflags.Config); err != nil {
util.CheckErr(err)
util.PrintErrExit(err)
}

if err := configflags.Config.Complete(); err != nil {
util.CheckErr(err)
util.PrintErrExit(err)
}
}

Expand Down
6 changes: 3 additions & 3 deletions internal/cmd/script.go
Expand Up @@ -57,11 +57,11 @@ Execute a local shell script on target hosts.`,
Find more examples at: https://github.com/windvalley/gossh/blob/main/docs/script.md`,
PreRun: func(cmd *cobra.Command, args []string) {
if errs := configflags.Config.Validate(); len(errs) != 0 {
util.CheckErr(errs)
util.PrintErrExit(errs)
}

if scriptFile != "" && !util.FileExists(scriptFile) {
util.CheckErr(fmt.Sprintf("script '%s' not found", scriptFile))
util.PrintErrExit(fmt.Sprintf("script '%s' not found", scriptFile))
}

if noSafeCheck {
Expand All @@ -75,7 +75,7 @@ Execute a local shell script on target hosts.`,
}

if err := checkScript(scriptFile, configflags.Config.Run.CommandBlacklist); err != nil {
util.CheckErr(err)
util.PrintErrExit(err)
}
}
},
Expand Down
4 changes: 2 additions & 2 deletions internal/cmd/vault/decrypt.go
Expand Up @@ -55,7 +55,7 @@ Decrypt content encrypted by vault.`,
}

if !aes.IsAES256CipherText(args[0]) {
util.CheckErr(fmt.Sprintf("'%s' is not vault encrypted content", args[0]))
util.PrintErrExit(fmt.Sprintf("'%s' is not vault encrypted content", args[0]))
}

return nil
Expand All @@ -65,8 +65,8 @@ Decrypt content encrypted by vault.`,
plainText, err := aes.AES256Decode(args[0], vaultPass)
if err != nil {
err = fmt.Errorf("decrypt failed: %w", err)
util.PrintErrExit(err)
}
util.CheckErr(err)

fmt.Printf("\n%s\n", plainText)
},
Expand Down
8 changes: 4 additions & 4 deletions internal/cmd/vault/decrypt_file.go
Expand Up @@ -35,8 +35,6 @@ import (
var deOutputFile string

// decryptFileCmd represents the vault decrypt-file command
//
//nolint:dupl
var decryptFileCmd = &cobra.Command{
Use: "decrypt-file FILENAME",
Short: "Decrypt vault encrypted file",
Expand Down Expand Up @@ -66,7 +64,7 @@ Decrypt vault encrypted file.`,
}

if !util.FileExists(args[0]) {
util.CheckErr(fmt.Sprintf("file '%s' not found", args[0]))
util.PrintErrExit(fmt.Sprintf("file '%s' not found", args[0]))
}

return nil
Expand All @@ -77,7 +75,9 @@ Decrypt vault encrypted file.`,
file := args[0]

content, err := decryptFile(file, vaultPass)
util.CheckErr(err)
if err != nil {
util.PrintErrExit(err)
}

handleOutput(content, file, deOutputFile)

Expand Down
4 changes: 2 additions & 2 deletions internal/cmd/vault/encrypt.go
Expand Up @@ -61,14 +61,14 @@ Encrypt sensitive content.`,
plainPassword, err := getPlainPassword(args)
if err != nil {
err = fmt.Errorf("get plaintext to be encrypted failed: %s", err)
util.PrintErrExit(err)
}
util.CheckErr(err)

encryptContent, err := aes.AES256Encode(plainPassword, vaultPass)
if err != nil {
err = fmt.Errorf("encrypt failed: %w", err)
util.PrintErrExit(err)
}
util.CheckErr(err)

fmt.Printf("\n%s\n", encryptContent)
},
Expand Down
10 changes: 7 additions & 3 deletions internal/cmd/vault/encrypt_file.go
Expand Up @@ -67,7 +67,7 @@ Encrypt a file.`,
}

if !util.FileExists(args[0]) {
util.CheckErr(fmt.Sprintf("file '%s' not found", args[0]))
util.PrintErrExit(fmt.Sprintf("file '%s' not found", args[0]))
}

return nil
Expand All @@ -78,7 +78,9 @@ Encrypt a file.`,
file := args[0]

content, err := encryptFile(file, vaultPass)
util.CheckErr(err)
if err != nil {
util.PrintErrExit(err)
}

handleOutput(content, file, outputFile)

Expand Down Expand Up @@ -108,7 +110,9 @@ func handleOutput(content, originalFile, newFile string) {
err = writeContentToOriFile(originalFile, content)
}

util.CheckErr(err)
if err != nil {
util.PrintErrExit(err)
}
}

func encryptFile(file, vaultPass string) (string, error) {
Expand Down
18 changes: 10 additions & 8 deletions internal/cmd/vault/vault.go
Expand Up @@ -103,7 +103,7 @@ func getVaultConfirmPassword() string {
prompt := "New Vault password: "
password, err := getConfirmPasswordFromPrompt(prompt)
if err != nil {
util.CheckErr(fmt.Sprintf("get vault password from terminal prompt failed: %s", err))
util.PrintErrExit(fmt.Sprintf("get vault password from terminal prompt failed: %s", err))
}

log.Debugf("Vault: confirmed vault password that from terminal prompt")
Expand All @@ -124,7 +124,7 @@ func GetVaultPassword() string {
for {
password, err = getPasswordFromPrompt(prompt)
if err != nil {
util.CheckErr(fmt.Sprintf("get vault password from terminal prompt '%s' failed: %s", prompt, err))
util.PrintErrExit(fmt.Sprintf("get vault password from terminal prompt '%s' failed: %s", prompt, err))
}
if password != "" {
break
Expand All @@ -142,20 +142,22 @@ func getVaultPasswordFromFile() string {
vaultPassFile := configflags.Config.Auth.VaultPassFile
if vaultPassFile != "" {
ok, err := isExectuable(vaultPassFile)
util.CheckErr(err)
if err != nil {
util.PrintErrExit(err)
}

if ok {
bin := fmt.Sprintf("./%s", vaultPassFile)
out, err1 := exec.Command(bin).Output()
if err1 != nil {
util.CheckErr(fmt.Errorf(
util.PrintErrExit(fmt.Errorf(
"problem executing file '%s': %s, if this is not a executable file, "+
"remove the executable bit from the file", vaultPassFile, err1))
}

vaultPass := strings.TrimSpace(string(out))
if vaultPass == "" {
util.CheckErr(fmt.Sprintf(
util.PrintErrExit(fmt.Sprintf(
"problem executing file '%s': output cannot be empty, if this is not a script, "+
"remove the executable bit from the file", vaultPassFile))
}
Expand All @@ -168,16 +170,16 @@ func getVaultPasswordFromFile() string {
passwordContent, err := os.ReadFile(vaultPassFile)
if err != nil {
err = fmt.Errorf("read vault password file '%s' failed: %w", vaultPassFile, err)
util.PrintErrExit(err)
}
util.CheckErr(err)

vaultPass := strings.TrimSpace(string(passwordContent))
if vaultPass == "" {
util.CheckErr("vault password file cannot be empty")
util.PrintErrExit("vault password file cannot be empty")
}

if strings.HasPrefix(vaultPass, "#!/") {
util.CheckErr(fmt.Sprintf(
util.PrintErrExit(fmt.Sprintf(
"'%s' looks like a script file, please add the executable bit to this file",
vaultPassFile,
))
Expand Down
11 changes: 7 additions & 4 deletions internal/cmd/vault/view.go
Expand Up @@ -54,7 +54,7 @@ View vault encrypted file.`,
}

if !util.FileExists(args[0]) {
util.CheckErr(fmt.Sprintf("file '%s' not found", args[0]))
util.PrintErrExit(fmt.Sprintf("file '%s' not found", args[0]))
}

return nil
Expand All @@ -65,9 +65,12 @@ View vault encrypted file.`,
file := args[0]

decryptContent, err := decryptFile(file, vaultPass)
util.CheckErr(err)
if err != nil {
util.PrintErrExit(err)
}

err = util.LessContent(decryptContent)
util.CheckErr(err)
if err := util.LessContent(decryptContent); err != nil {
util.PrintErrExit(err)
}
},
}
2 changes: 1 addition & 1 deletion internal/pkg/aes/aes.go
Expand Up @@ -53,7 +53,7 @@ func AES256Encode(plainText, key string) (string, error) {
func AES256Decode(hexCipherText, key string) (string, error) {
defer func() {
if err := recover(); err != nil {
util.CheckErr("decryption failed: wrong vault password")
util.PrintErrExit("decryption failed: wrong vault password")
}
}()

Expand Down
11 changes: 6 additions & 5 deletions internal/pkg/sshtask/sshtask.go
Expand Up @@ -231,8 +231,9 @@ func (t *Task) batchRunSSH() {
t.err = errors.New("need flag '-d/--dest-path' or '-l/--hosts.list'")
} else {
if !util.DirExists(t.dstDir) {
err := os.MkdirAll(t.dstDir, os.ModePerm)
util.CheckErr(err)
if err := os.MkdirAll(t.dstDir, os.ModeDir); err != nil {
util.PrintErrExit(err)
}
}
}
}
Expand Down Expand Up @@ -694,8 +695,8 @@ func getDefaultPassword(auth *configflags.Auth) string {
passwordContent, err := os.ReadFile(authFile)
if err != nil {
err = fmt.Errorf("read password file '%s' failed: %w", authFile, err)
util.PrintErrExit(err)
}
util.CheckErr(err)

password = strings.TrimSpace(string(passwordContent))

Expand Down Expand Up @@ -784,8 +785,8 @@ func getPasswordFromPrompt(loginUser string) string {
passwordByte, err := term.ReadPassword(0)
if err != nil {
err = fmt.Errorf("get password from terminal failed: %s", err)
util.PrintErrExit(err)
}
util.CheckErr(err)

password := string(passwordByte)

Expand All @@ -803,7 +804,7 @@ func getRealPass(pass string, host, objectType string) string {
realPass, err := aes.AES256Decode(pass, vaultPass)
if err != nil {
log.Debugf("Vault: decrypt %s for '%s' failed: %s", objectType, host, err)
util.CheckErr(err)
util.PrintErrExit(err)
}

log.Debugf("Vault: decrypt %s for '%s' success", objectType, host)
Expand Down
4 changes: 2 additions & 2 deletions pkg/util/cobra.go
Expand Up @@ -38,15 +38,15 @@ func CobraCheckErrWithHelp(cmd *cobra.Command, errMsg interface{}) {

fmt.Println()

CheckErr(errMsg)
PrintErrExit(errMsg)
}
}

// CobraMarkHiddenGlobalFlags that from params.
func CobraMarkHiddenGlobalFlags(command *cobra.Command, flags ...string) {
for _, v := range flags {
if err := command.Flags().MarkHidden(v); err != nil {
CheckErr(fmt.Sprintf("cannot mark hidden flag: %s", err))
PrintErrExit(fmt.Sprintf("cannot mark hidden flag: %s", err))
}
}
}
Expand Down

0 comments on commit 5c31ad0

Please sign in to comment.