Skip to content

Commit

Permalink
public Fire method should return error not panic
Browse files Browse the repository at this point in the history
  • Loading branch information
sgsullivan committed Aug 3, 2022
1 parent 172d4f4 commit ac43774
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 30 deletions.
66 changes: 40 additions & 26 deletions befehl.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,17 @@ func New(options *Options) *Instance {
}
}

func (instance *Instance) Fire(targets, payload string, routines int) {
bytePayload := system.ReadFileUnsafe(payload)
if instance.sshKey != nil {
instance.populateSshKey()
func (instance *Instance) Fire(targets, payload string, routines int) error {
if bytePayload, readFileErr := system.ReadFile(payload); readFileErr == nil {
if instance.sshKey != nil {
if err := instance.populateSshKey(); err != nil {
return err
}
}
return instance.fireTorpedos(bytePayload, targets, routines)
} else {
return readFileErr
}
instance.fireTorpedos(bytePayload, targets, routines)
}

func (instance *Instance) getPrivKeyFile() string {
Expand Down Expand Up @@ -102,27 +107,26 @@ func (instance *Instance) populateSshKeyUnencrypted(rawKey []byte) error {
return nil
}

func (instance *Instance) populateSshKey() {
func (instance *Instance) populateSshKey() error {
privKeyFile := instance.getPrivKeyFile()

rawKey := system.ReadFileUnsafe(privKeyFile)
privKeyBytes, _ := pem.Decode(rawKey)
if rawKey, readFileError := system.ReadFile(privKeyFile); readFileError == nil {
privKeyBytes, _ := pem.Decode(rawKey)

if x509.IsEncryptedPEMBlock(privKeyBytes) {
if err := instance.populateSshKeyEncrypted(privKeyBytes); err != nil {
panic(err)
if x509.IsEncryptedPEMBlock(privKeyBytes) {
return instance.populateSshKeyEncrypted(privKeyBytes)
} else {
return instance.populateSshKeyUnencrypted(rawKey)
}
} else {
if err := instance.populateSshKeyUnencrypted(rawKey); err != nil {
panic(err)
}
return readFileError
}
}

func (instance *Instance) fireTorpedos(payload []byte, targets string, routines int) {
func (instance *Instance) fireTorpedos(payload []byte, targets string, routines int) error {
file, err := os.Open(targets)
if err != nil {
panic(err)
return err
}
defer file.Close()

Expand All @@ -136,7 +140,7 @@ func (instance *Instance) fireTorpedos(payload []byte, targets string, routines
}

if err := scanner.Err(); err != nil {
panic(err)
return err
}

var wg sync.WaitGroup
Expand Down Expand Up @@ -170,9 +174,10 @@ func (instance *Instance) fireTorpedos(payload []byte, targets string, routines
}

if waitgroup.WgTimeout(&wg, time.Duration(time.Duration(1800)*time.Second)) {
panic("hit timeout waiting for all routines to finish")
return fmt.Errorf("hit timeout waiting for all routines to finish")
}
color.Green("All routines completed!\n")
return nil
}

func (instance *Instance) runPayload(wg *sync.WaitGroup, host string, payload []byte, sshConfig *ssh.ClientConfig) {
Expand All @@ -184,7 +189,9 @@ func (instance *Instance) runPayload(wg *sync.WaitGroup, host string, payload []
if err != nil {
uhoh := fmt.Sprintf("ssh.Dial() to %s failed: %s\n", host, err)
color.Red(uhoh)
instance.logPayloadRun(host, uhoh)
if err := instance.logPayloadRun(host, uhoh); err != nil {
panic(err)
}
return
}
defer conn.Close()
Expand All @@ -194,7 +201,9 @@ func (instance *Instance) runPayload(wg *sync.WaitGroup, host string, payload []
if err != nil {
uhoh := fmt.Sprintf("ssh.NewSession() to %s failed: %s\n", host, err)
color.Red(uhoh)
instance.logPayloadRun(host, uhoh)
if err := instance.logPayloadRun(host, uhoh); err != nil {
panic(err)
}
return
}
defer session.Close()
Expand All @@ -214,7 +223,9 @@ func (instance *Instance) runPayload(wg *sync.WaitGroup, host string, payload []
if err := session.RequestPty("xterm", 24, 80, modes); err != nil {
uhoh := fmt.Sprintf("session.RequestPty() to %s failed: %s\n", host, err)
color.Red(uhoh)
instance.logPayloadRun(host, uhoh)
if err := instance.logPayloadRun(host, uhoh); err != nil {
panic(err)
}
return
}

Expand All @@ -226,29 +237,32 @@ func (instance *Instance) runPayload(wg *sync.WaitGroup, host string, payload []
}

cmdOutput := stdout.String() + stderr.String() + "\n" + sessionRunAttempt
instance.logPayloadRun(host, cmdOutput)
if err := instance.logPayloadRun(host, cmdOutput); err != nil {
panic(err)
}
}

func (instance *Instance) logPayloadRun(host string, output string) {
func (instance *Instance) logPayloadRun(host string, output string) error {
logDir := os.Getenv("HOME") + "/befehl/logs"
if instance.options.LogDir != "" {
logDir = instance.options.LogDir
}
logFile := logDir + "/" + host
if !system.PathExists(logDir) {
if err := os.MkdirAll(logDir, os.FileMode(0700)); err != nil {
panic(fmt.Sprintf("Failed creating [%s]: %s\n", logDir, err))
return fmt.Errorf("failed creating [%s]: %s", logDir, err)
}
}
f, err := os.Create(logFile)
if err != nil {
panic(fmt.Sprintf("Error creating [%s]: %s", logFile, err))
return fmt.Errorf("error creating [%s]: %s", logFile, err)
}
defer f.Close()

if _, err = f.WriteString(output); err != nil {
panic(fmt.Sprintf("Error writing to [%s]: %s", logFile, err))
return fmt.Errorf("error writing to [%s]: %s", logFile, err)
}

log.Printf("payload completed on %s! logfile at: %s\n", host, logFile)
return nil
}
4 changes: 3 additions & 1 deletion cmd/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ sshuser = "eingeben"
LogDir: Config.GetString("general.logdir"),
})

instance.Fire(hosts, payload, routines)
if err := instance.Fire(hosts, payload, routines); err != nil {
panic(err)
}
},
}

Expand Down
6 changes: 3 additions & 3 deletions helpers/system/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import (
"os"
)

func ReadFileUnsafe(file string) []byte {
func ReadFile(file string) ([]byte, error) {
read, err := ioutil.ReadFile(file)
if err != nil {
panic(err)
return nil, err
}
return read
return read, nil
}

func PathExists(path string) bool {
Expand Down

0 comments on commit ac43774

Please sign in to comment.