Skip to content

Commit

Permalink
create separate queue file
Browse files Browse the repository at this point in the history
  • Loading branch information
sgsullivan committed Aug 3, 2022
1 parent a58bd01 commit 42065d2
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 36 deletions.
72 changes: 36 additions & 36 deletions befehl.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"log"
"os"
"sync"
"sync/atomic"
"time"

"golang.org/x/crypto/ssh"
Expand All @@ -21,15 +20,6 @@ import (
"github.com/sgsullivan/befehl/helpers/waitgroup"
)

type queue struct {
count int64
}

func (q *queue) signifyComplete(total int) {
remaining := atomic.AddInt64(&q.count, -1)
color.Magenta(fmt.Sprintf("Remaining: %d / %d\n", remaining, total))
}

type Options struct {
PrivateKeyFile string
SshUser string
Expand Down Expand Up @@ -148,6 +138,17 @@ func (instance *Instance) buildHostLists(hostsFilePath string) ([]string, error)
return victims, scanner.Err()
}

func (instance *Instance) getSshClientConfig() *ssh.ClientConfig {
return &ssh.ClientConfig{
User: instance.getSshUser(),
Auth: []ssh.AuthMethod{
ssh.PublicKeys(instance.sshKey),
},
Timeout: time.Duration(10) * time.Second,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
}

func (instance *Instance) executePayloadOnHosts(payload []byte, hostsFilePath string, routines int) error {
hostsList, err := instance.buildHostLists(hostsFilePath)
if err != nil {
Expand All @@ -157,28 +158,19 @@ func (instance *Instance) executePayloadOnHosts(payload []byte, hostsFilePath st
var wg sync.WaitGroup
hostCnt := len(hostsList)
wg.Add(hostCnt)
var sem = make(chan int, routines)
hostsChan := make(chan int, routines)
queue := new(queue).New(int64(hostCnt))

sshEntryUser := instance.getSshUser()
sshConfig := instance.getSshClientConfig()

sshConfig := &ssh.ClientConfig{
User: sshEntryUser,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(instance.sshKey),
},
Timeout: time.Duration(10) * time.Second,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}

queue := new(queue)
queue.count = int64(hostCnt)
for _, host := range hostsList {
host := host
sem <- 1
hostsChan <- 1
go func() {
instance.runPayload(&wg, host, payload, sshConfig)
<-sem
queue.signifyComplete(hostCnt)
<-hostsChan
remaining := queue.decrementCounter(hostCnt)
color.Magenta(fmt.Sprintf("Remaining: %d / %d\n", remaining, hostCnt))
}()
}

Expand Down Expand Up @@ -252,27 +244,35 @@ func (instance *Instance) runPayload(wg *sync.WaitGroup, host string, payload []
}
}

func (instance *Instance) logPayloadRun(host string, output string) error {
logDir := os.Getenv("HOME") + "/befehl/logs"
func (instance *Instance) getLogDir() string {
if instance.options.LogDir != "" {
logDir = instance.options.LogDir
return instance.options.LogDir
}
logFile := logDir + "/" + host
return os.Getenv("HOME") + "/befehl/logs"
}

func (instance *Instance) getLogFilePath(host string) string {
return instance.getLogDir() + "/" + host
}

func (instance *Instance) logPayloadRun(host string, output string) error {
logDir := instance.getLogDir()
logFilePath := instance.getLogFilePath(host)
if !filesystem.PathExists(logDir) {
if err := os.MkdirAll(logDir, os.FileMode(0700)); err != nil {
return fmt.Errorf("failed creating [%s]: %s", logDir, err)
}
}
f, err := os.Create(logFile)
logFile, err := os.Create(logFilePath)
if err != nil {
return fmt.Errorf("error creating [%s]: %s", logFile, err)
return fmt.Errorf("error creating [%s]: %s", logFilePath, err)
}
defer f.Close()
defer logFile.Close()

if _, err = f.WriteString(output); err != nil {
return fmt.Errorf("error writing to [%s]: %s", logFile, err)
if _, err = logFile.WriteString(output); err != nil {
return fmt.Errorf("error writing to [%s]: %s", logFilePath, err)
}

log.Printf("payload completed on %s! logfile at: %s\n", host, logFile)
log.Printf("payload completed on %s! logfile at: %s\n", host, logFilePath)
return nil
}
19 changes: 19 additions & 0 deletions queue.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package befehl

import (
"sync/atomic"
)

type queue struct {
count int64
}

func (q *queue) New(hostCnt int64) *queue {
return &queue{
count: hostCnt,
}
}

func (q *queue) decrementCounter(total int) int64 {
return atomic.AddInt64(&q.count, -1)
}

0 comments on commit 42065d2

Please sign in to comment.