-
Notifications
You must be signed in to change notification settings - Fork 4.5k
/
server.go
137 lines (118 loc) · 3.41 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package simple
import (
"bytes"
"context"
"sync"
"github.com/v2fly/v2ray-core/v5/common"
"github.com/v2fly/v2ray-core/v5/transport/internet/request"
)
// newServer constructs the server half of the simple session assembler.
// The supplied config is currently not consulted; every server starts from the
// zero value and is wired up later via OnTransportServerAssemblyReady.
func newServer(config *ServerConfig) request.SessionAssemblerServer {
	server := new(simpleAssemblerServer)
	return server
}
// simpleAssemblerServer routes incoming round trips to per-connection sessions,
// creating a session the first time a connection tag is seen.
type simpleAssemblerServer struct {
	// assembly is recorded by OnTransportServerAssemblyReady; its session
	// receiver is told about every newly created session.
	assembly request.TransportServerAssembly
	// sessions maps string(req.ConnectionTag) to *simpleAssemblerServerSession.
	sessions sync.Map
}
// OnTransportServerAssemblyReady stores the transport assembly. OnRoundTrip
// later uses its session receiver to announce newly created sessions.
func (s *simpleAssemblerServer) OnTransportServerAssemblyReady(assembly request.TransportServerAssembly) {
	// Plain field assignment; no synchronization is performed here.
	s.assembly = assembly
}
// OnRoundTrip dispatches a request to the session identified by its
// ConnectionTag.
//
// The first request for a tag creates a session, registers it in s.sessions,
// and announces it to the assembly's session receiver. Later requests reuse the
// registered session. If the receiver rejects a new session, the session is
// unregistered and cancelled, and the error is returned wrapped. A later request
// with the same tag then starts over with a fresh session.
func (s *simpleAssemblerServer) OnRoundTrip(ctx context.Context, req request.Request, opts ...request.RoundTripperOption,
) (resp request.Response, err error) {
	key := string(req.ConnectionTag)

	// Fast path: reuse an existing session without allocating a candidate
	// (and its cancellable context) that would immediately be thrown away.
	if existing, ok := s.sessions.Load(key); ok {
		return existing.(*simpleAssemblerServerSession).OnRoundTrip(ctx, req, opts...)
	}

	session := newSimpleAssemblerServerSession(ctx)
	if existing, loaded := s.sessions.LoadOrStore(key, session); loaded {
		// A concurrent round trip registered this tag first. Release the
		// context of our unused candidate so it is not leaked into ctx.
		session.finish()
		return existing.(*simpleAssemblerServerSession).OnRoundTrip(ctx, req, opts...)
	}

	if err := s.assembly.SessionReceiver().OnNewSession(ctx, session); err != nil {
		// Do not leave a session nobody consumes registered under this tag;
		// otherwise every later round trip for it would be routed there.
		s.sessions.Delete(key)
		session.finish()
		return request.Response{}, newError("failed to create new session").Base(err)
	}
	return session.OnRoundTrip(ctx, req, opts...)
}
// newSimpleAssemblerServerSession returns an empty session whose context is a
// cancellable child of ctx. The session ends when its finish function runs
// (see Close) or when ctx itself is cancelled.
//
// Pending output is capped at 4096 bytes per response. readChan buffers up to
// 16 uploaded chunks before OnRoundTrip blocks.
func newSimpleAssemblerServerSession(ctx context.Context) *simpleAssemblerServerSession {
	session := &simpleAssemblerServerSession{
		maxWriteSize:     4096,
		readChan:         make(chan []byte, 16),
		readBuffer:       bytes.NewBuffer(nil),
		writeBuffer:      bytes.NewBuffer(nil),
		writeLock:        new(sync.Mutex),
		requestProcessed: make(chan struct{}),
	}
	session.ctx, session.finish = context.WithCancel(ctx)
	return session
}
// simpleAssemblerServerSession adapts a sequence of request/response round
// trips into a Read/Write/Close stream for a single connection tag.
type simpleAssemblerServerSession struct {
	// ctx is the session lifetime; finish cancels it (called by Close).
	ctx    context.Context
	finish func()

	// readChan carries uploaded request payloads from OnRoundTrip to Read.
	// readBuffer holds the remainder of a chunk that Read has not returned yet.
	readChan   chan []byte
	readBuffer *bytes.Buffer

	// writeBuffer accumulates Write output until OnRoundTrip drains it, at
	// most maxWriteSize bytes per response. writeLock guards writeBuffer.
	writeLock   *sync.Mutex
	writeBuffer *bytes.Buffer
	// maxWriteSize is both the per-response drain size and the buffered
	// length above which Write waits for a round trip.
	maxWriteSize int

	// requestProcessed is signalled, without blocking, by OnRoundTrip after it
	// drains writeBuffer. A Write waiting on backpressure can then return.
	requestProcessed chan struct{}
}
// Read returns upload data previously queued by OnRoundTrip.
//
// Leftover bytes from an earlier chunk are served first. When nothing is
// buffered, Read blocks until either the next chunk arrives on readChan or the
// session context is done. In the latter case it returns the context's error.
func (s *simpleAssemblerServerSession) Read(p []byte) (n int, err error) {
	if s.readBuffer.Len() > 0 {
		return s.readBuffer.Read(p)
	}

	select {
	case chunk := <-s.readChan:
		s.readBuffer.Write(chunk)
	case <-s.ctx.Done():
		return 0, s.ctx.Err()
	}
	return s.readBuffer.Read(p)
}
// Write appends p to the pending output that OnRoundTrip hands back to the
// client.
//
// When the pending output grows beyond maxWriteSize, Write applies
// backpressure. It waits for one requestProcessed signal from OnRoundTrip, or
// returns (0, ctx error) if the session context finishes first. Note that p has
// already been buffered in that case.
func (s *simpleAssemblerServerSession) Write(p []byte) (n int, err error) {
	s.writeLock.Lock()
	n, err = s.writeBuffer.Write(p)
	buffered := s.writeBuffer.Len()
	s.writeLock.Unlock()

	if err != nil {
		return 0, err
	}
	if buffered <= s.maxWriteSize {
		return n, nil
	}

	select {
	case <-s.ctx.Done():
		return 0, s.ctx.Err()
	case <-s.requestProcessed:
		return n, nil
	}
}
// Close cancels the session context. This releases any Read, Write, or
// OnRoundTrip currently waiting on it. Close always returns nil.
//
// NOTE(review): Close does not remove the session from the server's sessions
// map; confirm whether entries are cleaned up elsewhere.
func (s *simpleAssemblerServerSession) Close() error {
	defer s.finish()
	return nil
}
// OnRoundTrip performs one request/response exchange for this session.
//
// A non-empty request payload is queued for Read, blocking while readChan is
// full. The call fails with the context error if the session ends first.
// Then up to maxWriteSize bytes of pending Write output are drained and
// returned as the response body, and a Write blocked on backpressure is woken
// if one is waiting.
//
// Output that has been drained is always returned, even if the session has
// finished meanwhile. A finished session reports its context error only on an
// exchange that has nothing left to deliver.
func (s *simpleAssemblerServerSession) OnRoundTrip(ctx context.Context, req request.Request, opts ...request.RoundTripperOption,
) (resp request.Response, err error) {
	// len of a nil slice is 0, so this also covers req.Data == nil.
	if len(req.Data) > 0 {
		// NOTE(review): req.Data is queued without copying, which assumes the
		// caller does not reuse the slice after this returns — confirm.
		select {
		case <-s.ctx.Done():
			return request.Response{}, s.ctx.Err()
		case s.readChan <- req.Data:
		}
	}

	s.writeLock.Lock()
	// Next's result aliases writeBuffer's storage and is only valid until the
	// next buffer modification, so copy it out while still holding the lock.
	pending := s.writeBuffer.Next(s.maxWriteSize)
	data := make([]byte, len(pending))
	copy(data, pending)
	s.writeLock.Unlock()

	// Wake a Write waiting on backpressure, if any, without blocking. A blocked
	// Write also watches the session context, so no ctx case is needed here.
	select {
	case s.requestProcessed <- struct{}{}:
	default:
	}

	// The drained bytes are gone from writeBuffer; returning an error here
	// would silently drop them. Report closure only when nothing is left.
	if len(data) == 0 {
		if ctxErr := s.ctx.Err(); ctxErr != nil {
			return request.Response{}, ctxErr
		}
	}
	return request.Response{Data: data}, nil
}
func init() {
common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
serverConfig, ok := config.(*ServerConfig)
if !ok {
return nil, newError("not a SimpleServerConfig")
}
return newServer(serverConfig), nil
}))
}