forked from ionorg/ion-sfu
/
datachannel.go
81 lines (65 loc) · 2.09 KB
/
datachannel.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
package sfu
import (
"context"
"github.com/pion/webrtc/v3"
)
type (
// Datachannel is a wrapper used to define middlewares that are executed on
// the data channel with the given label. The datachannels created are
// negotiated on join with every peer that joins the SFU.
Datachannel struct {
// Label names the data channel this definition applies to.
// NOTE(review): presumably matched against the WebRTC data channel
// label; confirm against the code that consumes Datachannel.
Label string
// middlewares is the list of middleware constructors registered via Use,
// in registration order.
middlewares []func(MessageProcessor) MessageProcessor
// onMessage is the callback stored by OnMessage; it is nil until set.
onMessage func(ctx context.Context, args ProcessArgs)
}
// ProcessArgs bundles the values handed to a MessageProcessor for a single
// call to Process.
ProcessArgs struct {
// Peer is the peer associated with the message.
// NOTE(review): presumably the peer that owns the data channel; confirm
// against the caller that builds ProcessArgs.
Peer Peer
// Message is the pion/webrtc data channel message being processed.
Message webrtc.DataChannelMessage
// DataChannel is the pion/webrtc data channel associated with Message.
DataChannel *webrtc.DataChannel
}
// Middlewares is an ordered list of middleware constructors. Each one takes
// the next MessageProcessor and returns a MessageProcessor wrapping it.
Middlewares []func(MessageProcessor) MessageProcessor
// MessageProcessor is implemented by anything that can handle a data
// channel message described by ProcessArgs.
MessageProcessor interface {
Process(ctx context.Context, args ProcessArgs)
}
// ProcessFunc is a function type with the same signature as
// MessageProcessor.Process; its Process method calls the function itself,
// so a plain function can be used as a MessageProcessor.
ProcessFunc func(ctx context.Context, args ProcessArgs)
// chainHandler is the MessageProcessor produced by Middlewares.Process and
// Middlewares.ProcessFunc. Those methods build it with a positional
// composite literal, so the field order below must not change.
chainHandler struct {
// middlewares is the list the chain was built from.
middlewares Middlewares
// Last is the innermost processor that the middlewares wrap.
Last MessageProcessor
// current is the fully composed processor returned by chain; Process
// delegates to it.
current MessageProcessor
}
)
// Use registers additional middlewares on the Datachannel, after any that
// were registered by earlier calls. Middlewares are intended to run before the
// OnMessage callback fires. Calling Use with no arguments is a no-op.
func (dc *Datachannel) Use(middlewares ...func(MessageProcessor) MessageProcessor) {
	if len(middlewares) == 0 {
		// Nothing to add; the stored list stays exactly as it was.
		return
	}
	dc.middlewares = append(dc.middlewares, middlewares...)
}
// OnMessage installs fn as the message callback of the Datachannel, replacing
// any callback set previously. The callback is intended to fire once all the
// middlewares have processed the message.
func (dc *Datachannel) OnMessage(fn func(ctx context.Context, args ProcessArgs)) {
	dc.onMessage = fn
}
// Process implements MessageProcessor by invoking the underlying function
// with the same context and arguments.
func (f ProcessFunc) Process(ctx context.Context, args ProcessArgs) {
	f(ctx, args)
}
// Process composes the middlewares in mws around h and returns a
// MessageProcessor that runs the composed chain. The returned handler also
// keeps the middleware list and h itself.
func (mws Middlewares) Process(h MessageProcessor) MessageProcessor {
	return &chainHandler{
		middlewares: mws,
		Last:        h,
		current:     chain(mws, h),
	}
}
// ProcessFunc composes the middlewares in mws around h and returns a
// MessageProcessor that runs the composed chain. It behaves exactly like
// Process; h may be any MessageProcessor, including a ProcessFunc value.
func (mws Middlewares) ProcessFunc(h MessageProcessor) MessageProcessor {
	return mws.Process(h)
}
// newDCChain exposes a plain slice of middleware constructors as Middlewares.
// No copy is made: the result shares the backing array of m.
func newDCChain(m []func(p MessageProcessor) MessageProcessor) Middlewares {
	// The unnamed slice type is assignable to Middlewares, so no explicit
	// conversion is needed.
	return m
}
// Process implements MessageProcessor by handing the call to the composed
// chain stored in the handler.
func (ch *chainHandler) Process(ctx context.Context, args ProcessArgs) {
	composed := ch.current
	composed.Process(ctx, args)
}
// chain composes mws around last and returns the outermost processor.
//
// Middlewares are applied from the end of the slice towards the front, so
// mws[0] ends up outermost and is the first to see a message, while last is
// the innermost processor. With no middlewares, last is returned unchanged.
func chain(mws []func(processor MessageProcessor) MessageProcessor, last MessageProcessor) MessageProcessor {
	wrapped := last
	for idx := len(mws) - 1; idx >= 0; idx-- {
		wrapped = mws[idx](wrapped)
	}
	return wrapped
}