diff --git a/cmd/server.go b/cmd/server.go index 30b05c1..fb0f2f4 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -3,7 +3,6 @@ package cmd import ( "context" "errors" - "fmt" "os" "strings" "time" @@ -110,8 +109,6 @@ var ( opts = append(opts, grpcsvr.WithSkipRedfishVersions(versions)) } - fmt.Println("maxWorkers", maxWorkers) - if err := grpcsvr.RunServer(ctx, logger, grpcServer, port, httpServer, opts...); err != nil { logger.Error(err, "error running server") os.Exit(1) diff --git a/go.mod b/go.mod index 0e1d67e..c947741 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/tinkerbell/pbnj go 1.20 require ( - github.com/adrianbrad/queue v1.2.1 github.com/bmc-toolbox/bmclib v0.5.7 github.com/bmc-toolbox/bmclib/v2 v2.0.1-0.20230515164712-2714c7479477 github.com/cristalhq/jwt/v3 v3.1.0 diff --git a/go.sum b/go.sum index 26c555c..a39eb31 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,6 @@ github.com/VictorLowther/simplexml v0.0.0-20180716164440-0bff93621230 h1:t95Grn2 github.com/VictorLowther/simplexml v0.0.0-20180716164440-0bff93621230/go.mod h1:t2EzW1qybnPDQ3LR/GgeF0GOzHUXT5IVMLP2gkW1cmc= github.com/VictorLowther/soap v0.0.0-20150314151524-8e36fca84b22 h1:a0MBqYm44o0NcthLKCljZHe1mxlN6oahCQHHThnSwB4= github.com/VictorLowther/soap v0.0.0-20150314151524-8e36fca84b22/go.mod h1:/B7V22rcz4860iDqstGvia/2+IYWXf3/JdQCVd/1D2A= -github.com/adrianbrad/queue v1.2.1 h1:CEVsjFQyuR0s5Hty/HJGWBZHsJ3KMmii0kEgLeam/mk= -github.com/adrianbrad/queue v1.2.1/go.mod h1:wYiPC/3MPbyT45QHLrPR4zcqJWPePubM1oEP/xTwhUs= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= diff --git a/grpc/rpc/bmc_test.go b/grpc/rpc/bmc_test.go index 82f03a2..b1c3af9 100644 --- a/grpc/rpc/bmc_test.go +++ b/grpc/rpc/bmc_test.go @@ -42,7 +42,7 @@ func setup() { Ctx: ctx, } - tr = taskrunner.NewRunner(repo, 100, 100, time.Second) + tr = taskrunner.NewRunner(repo, 100, time.Second) tr.Start(ctx) bmcService = BmcService{ TaskRunner: tr, diff --git a/grpc/rpc/machine.go b/grpc/rpc/machine.go index d09837b..69c1afa 100644 --- a/grpc/rpc/machine.go +++ b/grpc/rpc/machine.go @@ -54,7 +54,7 @@ func (m *MachineService) BootDevice(ctx context.Context, in *v1.DeviceRequest) ( defer cancel() return mbd.BootDeviceSet(taskCtx, in.BootDevice.String(), in.Persistent, in.EfiBoot) } - go m.TaskRunner.Execute(ctx, l, "setting boot device", taskID, in.GetAuthn().GetDirectAuthn().GetHost().GetHost(), execFunc) + m.TaskRunner.Execute(ctx, l, "setting boot device", taskID, in.GetAuthn().GetDirectAuthn().GetHost().GetHost(), execFunc) return &v1.DeviceResponse{TaskId: taskID}, nil } @@ -64,14 +64,14 @@ func (m *MachineService) Power(ctx context.Context, in *v1.PowerRequest) (*v1.Po l := logging.ExtractLogr(ctx) taskID := xid.New().String() l = l.WithValues("taskID", taskID, "bmcIP", in.GetAuthn().GetDirectAuthn().GetHost().GetHost()) - /*l.Info( + l.Info( "start Power request", "username", in.GetAuthn().GetDirectAuthn().GetUsername(), "vendor", in.Vendor.GetName(), "powerAction", in.GetPowerAction().String(), "softTimeout", in.SoftTimeout, "OffDuration", in.OffDuration, - )*/ + ) execFunc := func(s chan string) (string, error) { mp, err := machine.NewPowerSetter( @@ -89,7 +89,7 @@ func (m *MachineService) Power(ctx context.Context, in *v1.PowerRequest) (*v1.Po defer cancel() return mp.PowerSet(taskCtx, in.PowerAction.String()) } - go m.TaskRunner.Execute(ctx, l, "power action: "+in.GetPowerAction().String(), taskID, in.GetAuthn().GetDirectAuthn().GetHost().GetHost(), execFunc) + m.TaskRunner.Execute(ctx, l, "power action: "+in.GetPowerAction().String(), taskID, in.GetAuthn().GetDirectAuthn().GetHost().GetHost(), execFunc) return &v1.PowerResponse{TaskId: taskID}, nil } diff --git a/grpc/rpc/task_test.go b/grpc/rpc/task_test.go index 46f11ed..f9b4faa 100644 --- a/grpc/rpc/task_test.go +++ b/grpc/rpc/task_test.go @@ -34,9 +34,10 @@ func TestTaskFound(t *testing.T) { t.Fatalf("expected no error, got: %v", err) } - time.Sleep(time.Second) + time.Sleep(time.Second * 3) taskReq := &v1.StatusRequest{TaskId: resp.TaskId} taskResp, _ := taskService.Status(context.Background(), taskReq) + t.Logf("Got response: %+v", taskResp) if taskResp.Id != resp.TaskId { t.Fatalf("got: %+v", taskResp) } diff --git a/grpc/server.go b/grpc/server.go index c6d7ac7..82fc72d 100644 --- a/grpc/server.go +++ b/grpc/server.go @@ -2,11 +2,9 @@ package grpc import ( "context" - "fmt" "net" "os" "os/signal" - "syscall" "time" "github.com/go-logr/logr" @@ -42,9 +40,6 @@ type Server struct { maxWorkers int // workerIdleTimeout is the idle timeout for workers. If no tasks are received within the timeout, the worker will exit. workerIdleTimeout time.Duration - // maxIngestionWorkers is the maximum number of concurrent workers that will be allowed. - // These are the workers that handle ingesting tasks from RPC endpoints and writing them to the map of per Host ID queues. - maxIngestionWorkers int } // ServerOption for setting optional values. @@ -77,12 +72,6 @@ func WithWorkerIdleTimeout(t time.Duration) ServerOption { return func(args *Server) { args.workerIdleTimeout = t } } -// WithMaxIngestionWorkers sets the max number of concurrent workers that will be allowed. -// These are the workers that handle ingesting tasks from RPC endpoints and writing them to the map of per Host ID queues. -func WithMaxIngestionWorkers(max int) ServerOption { - return func(args *Server) { args.maxIngestionWorkers = max } -} - // RunServer registers all services and runs the server. func RunServer(ctx context.Context, log logr.Logger, grpcServer *grpc.Server, port string, httpServer *http.Server, opts ...ServerOption) error { ctx, cancel := context.WithCancel(ctx) @@ -97,19 +86,17 @@ func RunServer(ctx context.Context, log logr.Logger, grpcServer *grpc.Server, po } defaultServer := &Server{ - Actions: repo, - bmcTimeout: oob.DefaultBMCTimeout, - maxWorkers: 1000, - workerIdleTimeout: time.Second * 30, - maxIngestionWorkers: 1000, + Actions: repo, + bmcTimeout: oob.DefaultBMCTimeout, + maxWorkers: 1000, + workerIdleTimeout: time.Second * 30, } for _, opt := range opts { opt(defaultServer) } - fmt.Printf("maxWorkers: %d\n", defaultServer.maxWorkers) - tr := taskrunner.NewRunner(repo, defaultServer.maxIngestionWorkers, defaultServer.maxWorkers, defaultServer.workerIdleTimeout) + tr := taskrunner.NewRunner(repo, defaultServer.maxWorkers, defaultServer.workerIdleTimeout) tr.Start(ctx) ms := rpc.MachineService{ @@ -165,16 +152,6 @@ func RunServer(ctx context.Context, log logr.Logger, grpcServer *grpc.Server, po } }() - msgChan := make(chan os.Signal) - signal.Notify(msgChan, syscall.SIGUSR1) - go func() { - for range msgChan { - fmt.Println("======") - tr.Print() - fmt.Println("======") - } - }() - go func() { <-ctx.Done() log.Info("ctx cancelled, shutting down PBnJ") diff --git a/grpc/taskrunner/queue.go b/grpc/taskrunner/queue.go deleted file mode 100644 index 0f490bc..0000000 --- a/grpc/taskrunner/queue.go +++ /dev/null @@ -1,128 +0,0 @@ -package taskrunner - -import ( - "context" - - "github.com/adrianbrad/queue" - "github.com/go-logr/logr" -) - -type IngestQueue struct { - q *queue.Blocking[*Task] -} - -type Task struct { - ID string `json:"id"` - Host string `json:"host"` - Description string `json:"description"` - Action func(chan string) (string, error) `json:"-"` - Log logr.Logger `json:"-"` -} - -func NewFIFOChannelQueue() *IngestQueue { - return &IngestQueue{ - q: queue.NewBlocking([]*Task{}), - } -} - -// Enqueue inserts the item into the queue. -func (i *IngestQueue) Enqueue2(item Task) { - i.q.OfferWait(&item) -} - -// Dequeue removes the oldest element from the queue. FIFO. -func (i *IngestQueue) Dequeue2(ctx context.Context, tChan chan Task) { - for { - select { - case <-ctx.Done(): - return - default: - item, err := i.q.Get() - if err != nil { - continue - } - tChan <- *item - } - } -} - -func NewIngestQueue() *IngestQueue { - return &IngestQueue{ - q: queue.NewBlocking([]*Task{}), - } -} - -// Enqueue inserts the item into the queue. -func (i *IngestQueue) Enqueue(item Task) { - i.q.OfferWait(&item) -} - -// Dequeue removes the oldest element from the queue. FIFO. -func (i *IngestQueue) Dequeue() (Task, error) { - item, err := i.q.Get() - if err != nil { - return Task{}, err - } - - return *item, nil -} - -func (i *IngestQueue) Size() int { - return i.q.Size() -} - -func newHostQueue() *hostQueue { - return &hostQueue{ - q: queue.NewBlocking[host]([]host{}), - ch: make(chan host), - } -} - -type host string - -func (h host) String() string { - return string(h) -} - -type hostQueue struct { - q *queue.Blocking[host] - ch chan host -} - -// Enqueue inserts the item into the queue. -func (i *hostQueue) Enqueue(item host) { - i.q.OfferWait(item) -} - -// Dequeue removes the oldest element from the queue. FIFO. -func (i *hostQueue) Dequeue() (host, error) { - item, err := i.q.Get() - if err != nil { - return "", err - } - - return item, nil -} - -// Dequeue removes the oldest element from the queue. FIFO. -func (i *hostQueue) Dequeue2(ctx context.Context) <-chan host { - go func() { - for { - select { - case <-ctx.Done(): - return - default: - item, err := i.q.Get() - if err != nil { - continue - } - i.ch <- item - } - } - }() - return i.ch -} - -func (i *hostQueue) Size() int { - return i.q.Size() -} diff --git a/grpc/taskrunner/run.go b/grpc/taskrunner/run.go index c3585eb..2dc8d7e 100644 --- a/grpc/taskrunner/run.go +++ b/grpc/taskrunner/run.go @@ -13,36 +13,15 @@ type orchestrator struct { workers sync.Map manager *concurrencyManager workerIdleTimeout time.Duration - - ingestManager *concurrencyManager - - fifoQueue *hostQueue - fifoChan chan host - ingestionQueue *IngestQueue + fifoChan chan string // perIDQueue is a map of hostID to a channel of tasks. perIDQueue sync.Map - - //testing new stuff ingestChan chan Task } -func (r *Runner) Print() { - one := r.orchestrator.ingestionQueue.Size() - two := r.orchestrator.fifoQueue.Size() - var three int - r.orchestrator.perIDQueue.Range(func(key, value interface{}) bool { - three++ - return true - }) - fmt.Printf("ingestion queue size: %d\n", one) - fmt.Printf("fcfs queue size: %d\n", two) - fmt.Printf("perID queue size: %d\n", three) -} - // ingest take a task off the ingestion queue and puts it on the perID queue // and adds the host ID to the fcfs queue. func (r *Runner) ingest(ctx context.Context) { - //func (o *orchestrator) ingest(ctx context.Context) { // dequeue from ingestion queue // enqueue to perID queue // enqueue to fcfs queue @@ -81,7 +60,6 @@ func (r *Runner) orchestrate(ctx context.Context) { // 2. start workers for { time.Sleep(time.Second * 2) - // r.orchestrator.perIDQueue.Range(func(key, value interface{}) bool { r.orchestrator.workers.Range(func(key, value interface{}) bool { // if worker id exists in o.workers, then move on because the worker is already running. if value.(bool) { @@ -92,7 +70,10 @@ func (r *Runner) orchestrate(ctx context.Context) { r.orchestrator.manager.Wait() r.orchestrator.workers.Store(key.(string), true) - v, _ := r.orchestrator.perIDQueue.Load(key.(string)) + v, found := r.orchestrator.perIDQueue.Load(key.(string)) + if !found { + return false + } go r.worker(ctx, key.(string), v.(chan Task)) return true }) @@ -101,25 +82,20 @@ func (r *Runner) orchestrate(ctx context.Context) { func (r *Runner) worker(ctx context.Context, hostID string, q chan Task) { defer r.orchestrator.manager.Done() - defer func() { - r.orchestrator.workers.Range(func(key, value interface{}) bool { - if key.(string) == hostID { //nolint:forcetypeassert // good - r.orchestrator.workers.Delete(key.(string)) - return true //nolint:revive // this is needed to satisfy the func parameter - } - return true //nolint:revive // this is needed to satisfy the func parameter - }) - - }() + defer r.orchestrator.workers.Delete(hostID) for { select { case <-ctx.Done(): + // TODO: check queue length before returning maybe? + // For 175000 tasks, i found there would occasionally be 1 or 2 that didnt get processed. + // still seemed to be in the queue/chan. return case t := <-q: r.process(ctx, t.Log, t.Description, t.ID, t.Action) metrics.PerIDQueue.WithLabelValues(hostID).Dec() case <-time.After(r.orchestrator.workerIdleTimeout): + // TODO: check queue length returning maybe? return } } diff --git a/grpc/taskrunner/taskrunner.go b/grpc/taskrunner/taskrunner.go index bc88a4d..c15269e 100644 --- a/grpc/taskrunner/taskrunner.go +++ b/grpc/taskrunner/taskrunner.go @@ -2,7 +2,6 @@ package taskrunner import ( "context" - "fmt" "net" "net/url" "sync" @@ -25,6 +24,14 @@ type Runner struct { orchestrator *orchestrator } +type Task struct { + ID string `json:"id"` + Host string `json:"host"` + Description string `json:"description"` + Action func(chan string) (string, error) `json:"-"` + Log logr.Logger `json:"-"` +} + // NewRunner returns a task runner that manages tasks, workers, queues, and persistence. // // maxIngestionWorkers is the maximum number of concurrent workers that will be allowed. @@ -33,19 +40,15 @@ type Runner struct { // maxWorkers is the maximum number of concurrent workers that will be allowed to handle bmc tasks. // // workerIdleTimeout is the idle timeout for workers. If no tasks are received within the timeout, the worker will exit. -func NewRunner(repo repository.Actions, maxIngestionWorkers, maxWorkers int, workerIdleTimeout time.Duration) *Runner { - fmt.Println("NewRunner", maxIngestionWorkers, maxWorkers, workerIdleTimeout) +func NewRunner(repo repository.Actions, maxWorkers int, workerIdleTimeout time.Duration) *Runner { o := &orchestrator{ - workers: sync.Map{}, - fifoQueue: newHostQueue(), - fifoChan: make(chan host, 5000), - ingestionQueue: NewIngestQueue(), - ingestManager: newManager(maxIngestionWorkers), + workers: sync.Map{}, + fifoChan: make(chan string, 10000), // perIDQueue is a map of hostID to a channel of tasks. perIDQueue: sync.Map{}, manager: newManager(maxWorkers), workerIdleTimeout: workerIdleTimeout, - ingestChan: make(chan Task, 5000), + ingestChan: make(chan Task, 10000), } return &Runner{ @@ -96,7 +99,6 @@ func (r *Runner) Execute(_ context.Context, l logr.Logger, description, taskID, Log: l, } - //r.orchestrator.ingestionQueue.Enqueue(i) r.orchestrator.ingestChan <- i metrics.IngestionQueue.Inc() metrics.Ingested.Inc() @@ -155,7 +157,6 @@ func (r *Runner) process(ctx context.Context, logger logr.Logger, description, t cctx, done := context.WithCancel(ctx) defer done() go r.updateMessages(cctx, taskID, messagesChan) - //logger.Info("worker start") resultRecord := repository.Record{ State: "complete", @@ -179,7 +180,6 @@ func (r *Runner) process(ctx context.Context, logger logr.Logger, description, t if errors.As(err, &foundErr) { resultRecord.Error = foundErr.StructuredError() } - //logger.Error(err, "task completed with an error") } record, err := r.Repository.Get(taskID) if err != nil { @@ -192,10 +192,8 @@ func (r *Runner) process(ctx context.Context, logger logr.Logger, description, t record.Error = resultRecord.Error if err := r.Repository.Update(taskID, record); err != nil { - //logger.Error(err, "failed to update record") + logger.Error(err, "failed to update record") } - - //logger.Info("worker complete", "complete", true) } // Status returns the status record of a task. diff --git a/grpc/taskrunner/taskrunner_test.go b/grpc/taskrunner/taskrunner_test.go index eab9352..e299db8 100644 --- a/grpc/taskrunner/taskrunner_test.go +++ b/grpc/taskrunner/taskrunner_test.go @@ -26,7 +26,7 @@ func TestRoundTrip(t *testing.T) { defer s.Close() repo := &persistence.GoKV{Store: s, Ctx: ctx} logger := logr.Discard() - runner := NewRunner(repo, 100, 100, time.Second) + runner := NewRunner(repo, 100, time.Second) runner.Start(ctx) time.Sleep(time.Millisecond * 100) @@ -39,7 +39,7 @@ func TestRoundTrip(t *testing.T) { }) // must be min of 3 because we sleep 2 seconds in worker function to allow final status messages to be written - time.Sleep(500 * time.Millisecond) + time.Sleep(time.Second * 2) record, err := runner.Status(ctx, taskID) if err != nil { t.Fatal(err) diff --git a/pkg/http/http.go b/pkg/http/http.go index 6ce9167..b05158e 100644 --- a/pkg/http/http.go +++ b/pkg/http/http.go @@ -3,6 +3,7 @@ package http import ( "context" "net/http" + "time" "github.com/go-logr/logr" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -35,16 +36,21 @@ func (h *Server) init() { } func (h *Server) Run(ctx context.Context) error { - svr := &http.Server{Addr: h.address, Handler: h.mux} - svr.ListenAndServe() + svr := &http.Server{ + Addr: h.address, + Handler: h.mux, + // Mitigate Slowloris attacks. 20 seconds is based on Apache's recommended 20-40 + // recommendation. Hegel doesn't really have many headers so 20s should be plenty of time. + // https://en.wikipedia.org/wiki/Slowloris_(computer_security) + ReadHeaderTimeout: 20 * time.Second, + } go func() { <-ctx.Done() - svr.Shutdown(ctx) + _ = svr.Shutdown(ctx) }() return svr.ListenAndServe() - // return http.ListenAndServe(h.address, h.mux) //nolint:gosec // TODO: add handle timeouts } func NewServer(addr string) *Server {