diff --git a/src/internal/task/queue.go b/src/internal/task/queue.go index c86bc596cb..4b0a22bf9c 100644 --- a/src/internal/task/queue.go +++ b/src/internal/task/queue.go @@ -68,8 +68,8 @@ func (s *Stack) Pop() *Task { t := s.top if t != nil { s.top = t.Next + t.Next = nil } - t.Next = nil return t } diff --git a/src/sync/cond.go b/src/sync/cond.go new file mode 100644 index 0000000000..e392bc6e2c --- /dev/null +++ b/src/sync/cond.go @@ -0,0 +1,80 @@ +package sync + +import "internal/task" + +type Cond struct { + L Locker + + unlocking *earlySignal + blocked task.Stack +} + +// earlySignal is a type used to implement a stack for signalling waiters while they are unlocking. +type earlySignal struct { + next *earlySignal + + signaled bool +} + +func (c *Cond) trySignal() bool { + // Pop a blocked task off of the stack, and schedule it if applicable. + t := c.blocked.Pop() + if t != nil { + scheduleTask(t) + return true + } + + // If there are any tasks which are currently unlocking, signal one. + if c.unlocking != nil { + c.unlocking.signaled = true + c.unlocking = c.unlocking.next + return true + } + + // There was nothing to signal. + return false +} + +func (c *Cond) Signal() { + c.trySignal() +} + +func (c *Cond) Broadcast() { + // Signal everything. + for c.trySignal() { + } +} + +func (c *Cond) Wait() { + // Add an earlySignal frame to the stack so we can be signalled while unlocking. + early := earlySignal{ + next: c.unlocking, + } + c.unlocking = &early + + // Temporarily unlock L. + c.L.Unlock() + + // Re-acquire the lock before returning. + defer c.L.Lock() + + // If we were signaled while unlocking, immediately complete. + if early.signaled { + return + } + + // Remove the earlySignal frame. + prev := c.unlocking + for prev != nil && prev.next != &early { + prev = prev.next + } + if prev != nil { + prev.next = early.next + } else { + c.unlocking = early.next + } + + // Wait for a signal. 
+ c.blocked.Push(task.Current()) + task.Pause() +} diff --git a/src/sync/mutex.go b/src/sync/mutex.go index 4e89285f93..acc02de3a3 100644 --- a/src/sync/mutex.go +++ b/src/sync/mutex.go @@ -1,16 +1,29 @@ package sync +import ( + "internal/task" + _ "unsafe" +) + // These mutexes assume there is only one thread of operation: no goroutines, // interrupts or anything else. type Mutex struct { - locked bool + locked bool + blocked task.Stack } +//go:linkname scheduleTask runtime.runqueuePushBack +func scheduleTask(*task.Task) + func (m *Mutex) Lock() { if m.locked { - panic("todo: block on locked mutex") + // Push self onto stack of blocked tasks, and wait to be resumed. + m.blocked.Push(task.Current()) + task.Pause() + return } + m.locked = true } @@ -18,7 +31,13 @@ func (m *Mutex) Unlock() { if !m.locked { panic("sync: unlock of unlocked Mutex") } - m.locked = false + + // Wake up a blocked task, if applicable. + if t := m.blocked.Pop(); t != nil { + scheduleTask(t) + } else { + m.locked = false + } } type RWMutex struct { @@ -50,3 +69,8 @@ func (rw *RWMutex) RUnlock() { rw.m.Unlock() } } + +type Locker interface { + Lock() + Unlock() +} diff --git a/src/sync/waitgroup.go b/src/sync/waitgroup.go new file mode 100644 index 0000000000..72ef24c809 --- /dev/null +++ b/src/sync/waitgroup.go @@ -0,0 +1,54 @@ +package sync + +import "internal/task" + +type WaitGroup struct { + counter uint + waiters task.Stack +} + +func (wg *WaitGroup) Add(delta int) { + if delta > 0 { + // Check for overflow. + if uint(delta) > (^uint(0))-wg.counter { + panic("sync: WaitGroup counter overflowed") + } + + // Add to the counter. + wg.counter += uint(delta) + } else { + // Check for underflow. + if uint(-delta) > wg.counter { + panic("sync: negative WaitGroup counter") + } + + // Subtract from the counter. + wg.counter -= uint(-delta) + + // If the counter is zero, everything is done and the waiters should be resumed. 
+ // This code assumes that the waiters cannot wake up until after this function returns. + // In the current implementation, this is always correct. + if wg.counter == 0 { + for t := wg.waiters.Pop(); t != nil; t = wg.waiters.Pop() { + scheduleTask(t) + } + } + } +} + +func (wg *WaitGroup) Done() { + wg.Add(-1) +} + +func (wg *WaitGroup) Wait() { + if wg.counter == 0 { + // Everything already finished. + return + } + + // Push the current goroutine onto the waiter stack. + wg.waiters.Push(task.Current()) + + // Pause until the waiters are awoken by Add/Done. + task.Pause() +} diff --git a/testdata/channel.go b/testdata/channel.go index db5a86317b..fa4c131f82 100644 --- a/testdata/channel.go +++ b/testdata/channel.go @@ -2,45 +2,17 @@ package main import ( "runtime" + "sync" "time" ) -// waitGroup is a small type reimplementing some of the behavior of sync.WaitGroup -type waitGroup uint - -func (wg *waitGroup) wait() { - n := 0 - for *wg != 0 { - // pause and wait to be rescheduled - runtime.Gosched() - - if n > 100 { - // if something is using the sleep queue, this may be necessary - time.Sleep(time.Millisecond) - } - - n++ - } -} - -func (wg *waitGroup) add(n uint) { - *wg += waitGroup(n) -} - -func (wg *waitGroup) done() { - if *wg == 0 { - panic("wait group underflow") - } - *wg-- -} - -var wg waitGroup +var wg sync.WaitGroup func main() { ch := make(chan int) println("len, cap of channel:", len(ch), cap(ch), ch == nil) - wg.add(1) + wg.Add(1) go sender(ch) n, ok := <-ch @@ -50,7 +22,7 @@ func main() { println("received num:", n) } - wg.wait() + wg.Wait() n, ok = <-ch println("recv from closed channel:", n, ok) @@ -66,55 +38,55 @@ func main() { // Test bigger values ch2 := make(chan complex128) - wg.add(1) + wg.Add(1) go sendComplex(ch2) println("complex128:", <-ch2) - wg.wait() + wg.Wait() // Test multi-sender. 
ch = make(chan int) - wg.add(3) + wg.Add(3) go fastsender(ch, 10) go fastsender(ch, 23) go fastsender(ch, 40) slowreceiver(ch) - wg.wait() + wg.Wait() // Test multi-receiver. ch = make(chan int) - wg.add(3) + wg.Add(3) go fastreceiver(ch) go fastreceiver(ch) go fastreceiver(ch) slowsender(ch) - wg.wait() + wg.Wait() // Test iterator style channel. ch = make(chan int) - wg.add(1) + wg.Add(1) go iterator(ch, 100) sum := 0 for i := range ch { sum += i } - wg.wait() + wg.Wait() println("sum(100):", sum) // Test simple selects. go selectDeadlock() // cannot use waitGroup here - never terminates - wg.add(1) + wg.Add(1) go selectNoOp() - wg.wait() + wg.Wait() // Test select with a single send operation (transformed into chan send). ch = make(chan int) - wg.add(1) + wg.Add(1) go fastreceiver(ch) select { case ch <- 5: } close(ch) - wg.wait() + wg.Wait() println("did send one") // Test select with a single recv operation (transformed into chan recv). @@ -125,11 +97,11 @@ func main() { // Test select recv with channel that has one entry. ch = make(chan int) - wg.add(1) + wg.Add(1) go func(ch chan int) { runtime.Gosched() ch <- 55 - wg.done() + wg.Done() }(ch) select { case make(chan int) <- 3: @@ -139,7 +111,7 @@ func main() { case n := <-make(chan int): println("unreachable:", n) } - wg.wait() + wg.Wait() // Test select recv with closed channel. close(ch) @@ -154,7 +126,7 @@ func main() { // Test select send. 
ch = make(chan int) - wg.add(1) + wg.Add(1) go fastreceiver(ch) select { case ch <- 235: @@ -163,7 +135,7 @@ func main() { println("unreachable:", n) } close(ch) - wg.wait() + wg.Wait() // test non-concurrent buffered channels ch = make(chan int, 2) @@ -181,7 +153,7 @@ func main() { println("closed buffered channel recieve:", <-ch) // test using buffered channels as regular channels with special properties - wg.add(6) + wg.Add(6) ch = make(chan int, 2) go send(ch) go send(ch) @@ -189,7 +161,7 @@ func main() { go send(ch) go receive(ch) go receive(ch) - wg.wait() + wg.Wait() close(ch) var count int for range ch { @@ -202,19 +174,19 @@ func main() { sch1 := make(chan int) sch2 := make(chan int) sch3 := make(chan int) - wg.add(3) + wg.Add(3) go func() { - defer wg.done() + defer wg.Done() time.Sleep(time.Millisecond) sch1 <- 1 }() go func() { - defer wg.done() + defer wg.Done() time.Sleep(time.Millisecond) sch2 <- 2 }() go func() { - defer wg.done() + defer wg.Done() // merge sch2 and sch3 into ch for i := 0; i < 2; i++ { var v int @@ -238,18 +210,18 @@ func main() { sum += v } } - wg.wait() + wg.Wait() println("blocking select sum:", sum) } func send(ch chan<- int) { ch <- 1 - wg.done() + wg.Done() } func receive(ch <-chan int) { <-ch - wg.done() + wg.Done() } func sender(ch chan int) { @@ -261,18 +233,18 @@ func sender(ch chan int) { ch <- i } close(ch) - wg.done() + wg.Done() } func sendComplex(ch chan complex128) { ch <- 7 + 10.5i - wg.done() + wg.Done() } func fastsender(ch chan int, n int) { ch <- n ch <- n + 1 - wg.done() + wg.Done() } func slowreceiver(ch chan int) { @@ -298,7 +270,7 @@ func fastreceiver(ch chan int) { sum += n } println("sum:", sum) - wg.done() + wg.Done() } func iterator(ch chan int, top int) { @@ -306,7 +278,7 @@ func iterator(ch chan int, top int) { ch <- i } close(ch) - wg.done() + wg.Done() } func selectDeadlock() { @@ -321,5 +293,5 @@ func selectNoOp() { default: } println("after no-op") - wg.done() + wg.Done() } diff --git 
a/testdata/coroutines.go b/testdata/coroutines.go index 77e14d0e0a..49fdfc2870 100644 --- a/testdata/coroutines.go +++ b/testdata/coroutines.go @@ -1,6 +1,9 @@ package main -import "time" +import ( + "sync" + "time" +) func main() { println("main 1") @@ -51,6 +54,29 @@ func main() { println("closure go call result:", x) time.Sleep(2 * time.Millisecond) + + var m sync.Mutex + m.Lock() + println("pre-acquired mutex") + go acquire(&m) + time.Sleep(2 * time.Millisecond) + println("releasing mutex") + m.Unlock() + time.Sleep(2 * time.Millisecond) + m.Lock() + println("re-acquired mutex") + m.Unlock() + println("done") + + time.Sleep(2 * time.Millisecond) +} + +func acquire(m *sync.Mutex) { + m.Lock() + println("acquired mutex from goroutine") + time.Sleep(2 * time.Millisecond) + m.Unlock() + println("released mutex from goroutine") } func sub() { diff --git a/testdata/coroutines.txt b/testdata/coroutines.txt index e296f8e0ce..1e29558afc 100644 --- a/testdata/coroutines.txt +++ b/testdata/coroutines.txt @@ -14,3 +14,9 @@ async interface method call slept inside func pointer 8 slept inside closure, with value: 20 8 closure go call result: 1 +pre-acquired mutex +releasing mutex +acquired mutex from goroutine +released mutex from goroutine +re-acquired mutex +done