Commit

Merge pull request #398 from projectdiscovery/maint-adaptive-group-with-sem

use semaphore instead of deprecated resizablechannel
Mzack9999 committed Apr 30, 2024
2 parents 3dbe79d + f4bcedb commit 1e40ad2
Showing 5 changed files with 257 additions and 20 deletions.
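
The change replaces the deprecated eapache/channels ResizableChannel with a semaphore from projectdiscovery/utils/sync/semaphore. As orientation before the diff, here is a minimal usage sketch of the resulting AdaptiveWaitGroup API, assembled from the tests added in this PR; the main package, the import alias, and the worker body are illustrative only, and the import path is assumed to be github.com/projectdiscovery/utils/sync:

package main

import (
	"context"
	"fmt"

	syncutil "github.com/projectdiscovery/utils/sync" // assumed import path for this repo's sync package
)

func main() {
	// Sizes below 1 are now rejected via validateSize, so New can return an error.
	awg, err := syncutil.New(syncutil.WithSize(10))
	if err != nil {
		panic(err)
	}

	for i := 0; i < 100; i++ {
		awg.Add() // blocks while 10 workers are already in flight
		go func(n int) {
			defer awg.Done()
			fmt.Println("worker", n)
		}(i)
	}

	// Resize now takes a context and returns an error (the old signature was Resize(size int) with no return).
	if err := awg.Resize(context.Background(), 3); err != nil {
		panic(err)
	}

	awg.Wait()
}
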
3 changes: 1 addition & 2 deletions go.mod
@@ -10,7 +10,7 @@ require (
github.com/cespare/xxhash v1.1.0
github.com/charmbracelet/glamour v0.6.0
github.com/docker/go-units v0.5.0
github.com/eapache/channels v1.1.0
github.com/fortytw2/leaktest v1.3.0
github.com/google/go-github/v30 v30.1.0
github.com/google/uuid v1.3.1
github.com/hdm/jarm-go v0.0.7
@@ -50,7 +50,6 @@ require (
github.com/dimchansky/utfbom v1.1.1 // indirect
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect
github.com/eapache/queue v1.1.0 // indirect
github.com/fatih/color v1.15.0 // indirect
github.com/gaukas/godicttls v0.0.4 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect
6 changes: 2 additions & 4 deletions go.sum
@@ -51,14 +51,12 @@ github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDD
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L7HYpRu/0lE3e0BaElwnNO1qkNQxBY=
github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s=
github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY=
github.com/eapache/channels v1.1.0 h1:F1taHcn7/F0i8DYqKXJnyhJcVpp2kgFcNePxXtnyu4k=
github.com/eapache/channels v1.1.0/go.mod h1:jMm2qB5Ubtg9zLd+inMZd2/NUvXgzmWXsDaLyQIGfH0=
github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc=
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
github.com/ebitengine/purego v0.4.0 h1:RQVuMIxQPQ5iCGEJvjQ17YOK+1tMKjVau2FUMvXH4HE=
github.com/ebitengine/purego v0.4.0/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ=
github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs=
github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw=
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
65 changes: 51 additions & 14 deletions sync/adaptivewaitgroup.go
@@ -6,29 +6,45 @@ import (
"context"
"errors"
"sync"
"sync/atomic"

"github.com/eapache/channels"
"github.com/projectdiscovery/utils/sync/semaphore"
)

type AdaptiveGroupOption func(*AdaptiveWaitGroup) error

type AdaptiveWaitGroup struct {
Size int
Size int
current *atomic.Int64

current *channels.ResizableChannel
wg sync.WaitGroup
sem *semaphore.Semaphore
wg sync.WaitGroup
mu sync.Mutex // Mutex to protect access to the Size and semaphore
}

// WithSize sets the initial size of the waitgroup ()
func WithSize(size int) AdaptiveGroupOption {
return func(wg *AdaptiveWaitGroup) error {
if size < 0 {
return errors.New("size must be positive")
if err := validateSize(size); err != nil {
return err
}
sem, err := semaphore.New(int64(size))
if err != nil {
return err
}
wg.sem = sem
wg.Size = size
return nil
}
}

func validateSize(size int) error {
if size < 1 {
return errors.New("size must be at least 1")
}
return nil
}

func New(options ...AdaptiveGroupOption) (*AdaptiveWaitGroup, error) {
wg := &AdaptiveWaitGroup{}
for _, option := range options {
@@ -37,9 +53,8 @@ func New(options ...AdaptiveGroupOption) (*AdaptiveWaitGroup, error) {
}
}

wg.current = channels.NewResizableChannel()
wg.current.Resize(channels.BufferCap(wg.Size))
wg.wg = sync.WaitGroup{}
wg.current = &atomic.Int64{}
return wg, nil
}

@@ -51,23 +66,45 @@ func (s *AdaptiveWaitGroup) AddWithContext(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
case s.current.In() <- struct{}{}:
break
default:
// Attempt to acquire a semaphore slot, handle error if acquisition fails
if err := s.sem.Acquire(ctx, 1); err != nil {
return err
}
}

// Safely add to the waitgroup only after acquiring the semaphore
s.wg.Add(1)
s.current.Add(1)
return nil
}

func (s *AdaptiveWaitGroup) Done() {
<-s.current.Out()
s.sem.Release(1)
s.wg.Done()
s.current.Add(-1)
}

func (s *AdaptiveWaitGroup) Wait() {
s.wg.Wait()
}

func (s *AdaptiveWaitGroup) Resize(size int) {
s.current.Resize(channels.BufferCap(size))
s.Size = int(s.current.Cap())
func (s *AdaptiveWaitGroup) Resize(ctx context.Context, size int) error {
s.mu.Lock()
defer s.mu.Unlock()

if err := validateSize(size); err != nil {
return err
}

// Resize the semaphore with the provided context and handle any errors
if err := s.sem.Resize(ctx, int64(size)); err != nil {
return err
}
s.Size = size
return nil
}

func (s *AdaptiveWaitGroup) Current() int {
return int(s.current.Load())
}
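
The behavioral core of the change is AddWithContext: instead of sending into a resizable channel, it now acquires a semaphore slot and surfaces context errors from the acquisition. A small hypothetical sketch (not part of this PR, same assumed import path as above) showing how a caller can bound the wait when the group is full:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	syncutil "github.com/projectdiscovery/utils/sync" // assumed import path
)

func main() {
	awg, err := syncutil.New(syncutil.WithSize(1))
	if err != nil {
		panic(err)
	}

	// Occupy the only slot so the next acquisition has to wait.
	if err := awg.AddWithContext(context.Background()); err != nil {
		panic(err)
	}

	// With the group full, a deadline turns an indefinite block into an error
	// returned by sem.Acquire inside AddWithContext.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	if err := awg.AddWithContext(ctx); err != nil {
		fmt.Println("timed out:", errors.Is(err, context.DeadlineExceeded)) // prints "timed out: true"
	}

	awg.Done()
	awg.Wait()
}
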
180 changes: 180 additions & 0 deletions sync/adaptivewaitgroup_test.go
@@ -0,0 +1,180 @@
package sync

import (
"context"
"sync/atomic"
"testing"

"github.com/fortytw2/leaktest"
"github.com/stretchr/testify/require"
)

// tests from https://github.com/remeh/sizedwaitgroup/blob/master/sizedwaitgroup_test.go

func TestWait(t *testing.T) {
swg, err := New(WithSize(10))
require.Nil(t, err)

var c uint32

for i := 0; i < 10000; i++ {
swg.Add()
go func(c *uint32) {
defer swg.Done()
atomic.AddUint32(c, 1)
}(&c)
}

swg.Wait()

if c != 10000 {
t.Fatalf("%d, not all routines have been executed.", c)
}
}

func TestThrottling(t *testing.T) {
var c atomic.Uint32

swg, err := New(WithSize(4))
require.Nil(t, err)

if swg.Current() != 0 {
t.Fatalf("the SizedWaitGroup should start with zero.")
}

for i := 0; i < 10000; i++ {
swg.Add()
go func() {
defer swg.Done()

c.Add(1)
require.False(t, swg.Current() > 5, "unexpected number of goroutines in flight", swg.Current())
}()
}

swg.Wait()
}

func TestNoThrottling(t *testing.T) {
var c atomic.Int32
swg, err := New(WithSize(1))
require.Nil(t, err)

if swg.Current() != 0 {
t.Fatalf("the SizedWaitGroup should start with zero.")
}
for i := 0; i < 10000; i++ {
swg.Add()
go func() {
defer swg.Done()
c.Add(1)
}()
}
swg.Wait()
if c.Load() != 10000 {
t.Fatalf("%d, not all routines have been executed.", c.Load())
}
}

func TestAddWithContext(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.TODO())

swg, err := New(WithSize(1))
require.Nil(t, err)

if err := swg.AddWithContext(ctx); err != nil {
t.Fatalf("AddContext returned error: %v", err)
}

cancelFunc()
if err := swg.AddWithContext(ctx); err != context.Canceled {
t.Fatalf("AddContext returned non-context.Canceled error: %v", err)
}

}

func TestMultipleResizes(t *testing.T) {
var c atomic.Int32
swg, err := New(WithSize(2)) // Start with a size of 2
require.Nil(t, err)

for i := 0; i < 10000; i++ {
if i == 250 {
err := swg.Resize(context.TODO(), 5) // Increase size at the 250th iteration
require.Nil(t, err)
}
if i == 500 {
err := swg.Resize(context.TODO(), 1) // Decrease size at the 500th iteration
require.Nil(t, err)
}
if i == 750 {
err := swg.Resize(context.TODO(), 3) // Increase size again at the 750th iteration
require.Nil(t, err)
}

swg.Add()
go func() {
defer swg.Done()
c.Add(1)
}()
}

swg.Wait()
if c.Load() != 10000 {
t.Fatalf("%d, not all routines have been executed.", c.Load())
}
}

func Test_AdaptiveWaitGroup_Leak(t *testing.T) {
defer leaktest.Check(t)()

for j := 0; j < 1000; j++ {
wg, err := New(WithSize(10))
if err != nil {
t.Fatal(err)
}

for i := 0; i < 10000; i++ {
wg.Add()
go func(awg *AdaptiveWaitGroup) {
defer awg.Done()
}(wg)
}
wg.Wait()
}
}

func Test_AdaptiveWaitGroup_ContinuousResizeAndCheck(t *testing.T) {
defer leaktest.Check(t)() // Ensure no goroutines are leaked

var c atomic.Int32

wg, err := New(WithSize(1)) // Start with a size of 1
if err != nil {
t.Fatal(err)
}

// Perform continuous resizing and goroutine execution
for j := 0; j < 100; j++ {
for i := 0; i < 1000; i++ {
wg.Add()
go func(awg *AdaptiveWaitGroup) {
defer awg.Done()
c.Add(1)
}(wg)
}

// Increase or decrease size
newSize := (j % 10) + 1 // Cycle sizes between 1 and 10
err := wg.Resize(context.TODO(), newSize)
if err != nil {
t.Fatalf("Resize returned error: %v", err)
}

wg.Wait() // Wait at each step to ensure all routines finish before resizing again
}

if c.Load() != 100000 {
t.Fatalf("%d, not all routines have been executed.", c.Load())
}
}
23 changes: 23 additions & 0 deletions sync/semaphore/semaphore.go
@@ -56,6 +56,29 @@ func (s *Semaphore) Vary(ctx context.Context, x int64) error {
}
}

func (s *Semaphore) Resize(ctx context.Context, newSize int64) error {
currentSize := s.currentSize.Load()
difference := newSize - currentSize

if difference == 0 {
return nil // No resizing needed if the new size is the same as the current size
}

if difference > 0 {
// Increase capacity
s.sem.Release(difference)
} else {
// Decrease capacity
err := s.sem.Acquire(ctx, -difference) // Acquire takes a positive number, so negate difference
if err != nil {
return err
}
}

s.currentSize.Store(newSize)
return nil
}

// Current size of the semaphore
func (s *Semaphore) Size() int64 {
return s.currentSize.Load()
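
For context on the Resize method above: capacity grows by releasing permits and shrinks by (re)acquiring them, which works when the underlying golang.org/x/sync/semaphore.Weighted is created with a large fixed maximum and the unused capacity is parked up front. The sketch below is a hypothetical standalone illustration of that pattern, not the actual projectdiscovery/utils semaphore, whose construction may differ:

package resizable

import (
	"context"
	"math"
	"sync/atomic"

	"golang.org/x/sync/semaphore"
)

// ResizableSemaphore illustrates the resize-by-Release/Acquire pattern:
// the Weighted semaphore has a huge fixed capacity, and everything beyond
// the requested size is held ("parked") so only `size` permits are usable.
type ResizableSemaphore struct {
	sem         *semaphore.Weighted
	currentSize atomic.Int64
}

func New(size int64) (*ResizableSemaphore, error) {
	s := &ResizableSemaphore{sem: semaphore.NewWeighted(math.MaxInt64)}
	// Park all permits beyond the requested size.
	if err := s.sem.Acquire(context.Background(), math.MaxInt64-size); err != nil {
		return nil, err
	}
	s.currentSize.Store(size)
	return s, nil
}

func (s *ResizableSemaphore) Acquire(ctx context.Context) error { return s.sem.Acquire(ctx, 1) }
func (s *ResizableSemaphore) Release()                          { s.sem.Release(1) }

func (s *ResizableSemaphore) Resize(ctx context.Context, newSize int64) error {
	diff := newSize - s.currentSize.Load()
	switch {
	case diff > 0:
		s.sem.Release(diff) // hand back parked permits to grow capacity
	case diff < 0:
		if err := s.sem.Acquire(ctx, -diff); err != nil { // park permits to shrink capacity
			return err
		}
	}
	s.currentSize.Store(newSize)
	return nil
}
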
