-
Notifications
You must be signed in to change notification settings - Fork 623
/
server.go
200 lines (173 loc) · 4.09 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
package server
import (
"context"
"errors"
"fmt"
"log/slog"
"net"
"os"
"sync"
"sync/atomic"
"time"
)
const (
	// BufferSize is an exported size constant.
	// NOTE(review): not referenced in this part of the file; presumably used
	// by the connection code — confirm its unit and consumer.
	BufferSize = 32

	// IntervalCheckParentPidMilliseconds is the polling period, in
	// milliseconds, of the parent-pid watcher started by Server.Start.
	IntervalCheckParentPidMilliseconds = 100
)

// defaultLoggerPath holds the most recent non-empty path handed to
// Server.SetDefaultLoggerPath. It is package-level, not per-Server.
var defaultLoggerPath atomic.Value

// ServerParams configures NewServer.
type ServerParams struct {
	// ListenIPAddress is the address given to net.Listen("tcp", ...);
	// PortFilename is the file the bound port is published to.
	ListenIPAddress, PortFilename string

	// ParentPid, when non-zero, is the parent process id to watch; the
	// process exits once its parent is no longer this pid.
	ParentPid int
}
// Server is the core server: a TCP listener together with the goroutines
// that accept and handle its connections.
type Server struct {
	// ctx is cancelled to signal shutdown. It is also handed to every
	// accepted connection via NewConnection.
	ctx context.Context

	// cancel cancels ctx.
	cancel context.CancelFunc

	// listener is the bound listener that serve accepts connections from.
	listener net.Listener

	// wg counts the serve goroutine, every connection handler and the
	// optional parent-pid watcher, so Close can wait for all of them.
	wg sync.WaitGroup

	// parentPid is the parent process id to watch; the process exits when
	// its parent is no longer this pid. Zero disables the watcher.
	parentPid int
}
// NewServer creates a Server bound over TCP to params.ListenIPAddress and
// publishes the bound port to params.PortFilename (see writePortFile).
//
// The server's context is derived from ctx, so cancelling ctx also shuts the
// server down. Connections are not served until Start is called.
//
// It returns an error if params is nil, if listening fails, or if the port
// file cannot be written. In every error case the listener and the derived
// context are released before returning, so a failed construction leaks
// neither the bound socket nor the context.
func NewServer(
	ctx context.Context,
	params *ServerParams,
) (*Server, error) {
	if params == nil {
		return nil, errors.New("unconfigured params")
	}

	ctx, cancel := context.WithCancel(ctx)
	listener, err := net.Listen("tcp", params.ListenIPAddress)
	if err != nil {
		cancel()
		return nil, err
	}

	// A "tcp" listener always reports a *net.TCPAddr.
	port := listener.Addr().(*net.TCPAddr).Port
	if err := writePortFile(params.PortFilename, port); err != nil {
		slog.Error("failed to write port file", "error", err)
		// Release the socket and context: the caller gets no *Server to Close.
		_ = listener.Close()
		cancel()
		return nil, err
	}

	return &Server{
		ctx:       ctx,
		cancel:    cancel,
		listener:  listener,
		parentPid: params.ParentPid,
	}, nil
}
// loopCheckIfParentGone polls os.Getppid every
// IntervalCheckParentPidMilliseconds and reports whether the parent changed.
//
// It returns true as soon as the current parent pid differs from pid. That
// typically means the original parent exited and this process was
// re-parented. It returns false once the server context is cancelled.
func (s *Server) loopCheckIfParentGone(pid int) bool {
	// One ticker for the whole loop instead of a fresh timer per iteration.
	ticker := time.NewTicker(IntervalCheckParentPidMilliseconds * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-s.ctx.Done():
			return false
		case <-ticker.C:
		}
		if os.Getppid() != pid {
			return true
		}
	}
}
// SetDefaultLoggerPath records path in the package-level defaultLoggerPath.
// The value is shared by every Server, not stored on s. An empty path is
// ignored, leaving any previously stored value in place.
func (s *Server) SetDefaultLoggerPath(path string) {
	if path != "" {
		defaultLoggerPath.Store(path)
	}
}
// Start launches the server's background goroutines and returns immediately.
//
// If a parent pid was configured, a watcher goroutine polls for the parent's
// exit and terminates the whole process with os.Exit(1) when that happens.
// The accept loop (serve) always runs in its own goroutine. Both goroutines
// are tracked by s.wg; use Wait and Close to shut the server down.
func (s *Server) Start() {
	// Watch for parent process exit in the background (if specified).
	if s.parentPid != 0 {
		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			if s.loopCheckIfParentGone(s.parentPid) {
				slog.Info("Parent process exited, terminating core process")
				// Forcefully exit the server process because our controlling user process
				// has exited so there is no need to sync uncommitted data.
				// Exit code is arbitrary as parent process is gone.
				os.Exit(1)
			}
		}()
	}

	// Run the accept loop in the background.
	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		s.serve()
	}()
}
// serve accepts connections on s.listener until the server context is
// cancelled or the listener is closed. Each accepted connection is handled
// on its own goroutine, tracked by s.wg so Close can wait for it.
func (s *Server) serve() {
	slog.Info("server is running", "addr", s.listener.Addr())

	for {
		conn, err := s.listener.Accept()
		if err != nil {
			// A closed listener returns net.ErrClosed on every call. Without this
			// check, closing the listener without cancelling ctx would spin here
			// forever logging the same error, and Close's wg.Wait would never return.
			if s.ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
				slog.Debug("server shutting down...")
				return
			}
			slog.Error("failed to accept conn.", "error", err)
			continue
		}

		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			nc := NewConnection(s.ctx, s.cancel, conn)
			nc.HandleConnection()
		}()
	}
}
// Wait blocks until the server's context is cancelled, then logs that
// shutdown has begun. It releases nothing itself; call Close afterwards to
// stop the listener and wait for background goroutines.
func (s *Server) Wait() {
	done := s.ctx.Done()
	<-done
	slog.Info("server is shutting down")
}
// Close shuts the server down and blocks until its goroutines have finished.
//
// It cancels the server context first, which stops the parent-pid watcher and
// the accept loop. Connection handlers also receive this context and
// presumably stop on cancellation (confirm in NewConnection). Close then
// closes the listener so a pending Accept returns, and waits on s.wg. Without
// the cancel, calling Close on a server whose context was still live would
// block forever in wg.Wait.
func (s *Server) Close() {
	s.cancel()
	if err := s.listener.Close(); err != nil {
		slog.Error("failed to Close listener", "error", err)
	}
	s.wg.Wait()
	slog.Info("server is closed")
}
// writePortFile publishes port to portFile atomically.
//
// It writes "sock=<port>\n" followed by "EOF" (no trailing newline) to
// "<portFile>.tmp", fsyncs and closes it, then renames it over portFile, so a
// reader never observes a partially written file. If any step after creating
// the temporary file fails, the temporary file is closed and removed, and the
// wrapped error is returned.
func writePortFile(portFile string, port int) (err error) {
	tempFile := fmt.Sprintf("%s.tmp", portFile)
	f, err := os.Create(tempFile)
	if err != nil {
		return fmt.Errorf("fail create temp file: %w", err)
	}

	closed := false
	defer func() {
		if err == nil {
			return
		}
		// Best-effort cleanup; the original error is what the caller needs.
		if !closed {
			_ = f.Close()
		}
		_ = os.Remove(tempFile)
	}()

	if _, err = fmt.Fprintf(f, "sock=%d\n", port); err != nil {
		return fmt.Errorf("fail write port: %w", err)
	}
	if _, err = f.WriteString("EOF"); err != nil {
		return fmt.Errorf("fail write EOF: %w", err)
	}
	if err = f.Sync(); err != nil {
		return fmt.Errorf("fail sync: %w", err)
	}
	// Mark closed before calling Close so a failing Close isn't retried in the defer.
	closed = true
	if err = f.Close(); err != nil {
		return fmt.Errorf("fail close: %w", err)
	}
	if err = os.Rename(tempFile, portFile); err != nil {
		return fmt.Errorf("fail rename: %w", err)
	}
	return nil
}