/
unix.go
224 lines (189 loc) · 5.91 KB
/
unix.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
package netx
import (
"context"
"fmt"
"io"
"net"
"os"
"sync/atomic"
"syscall"
)
// SendUnixConn sends a file descriptor embedded in conn over the unix domain
// socket.
// On success conn is closed because the owner is now the process that received
// the file descriptor.
//
// conn must be a *net.TCPConn or similar (providing a File method) or the
// function will panic.
func SendUnixConn(socket *net.UnixConn, conn net.Conn) (err error) {
return sendUnixFileConn(socket, BaseConn(conn).(fileConn), conn)
}
// SendUnixPacketConn sends a file descriptor embedded in conn over the unix
// domain socket.
// On success conn is closed because the owner is now the process that received
// the file descriptor.
//
// conn must be a *net.UDPConn or similar (providing a File method) or the
// function will panic.
func SendUnixPacketConn(socket *net.UnixConn, conn net.PacketConn) (err error) {
return sendUnixFileConn(socket, BasePacketConn(conn).(fileConn), conn)
}
// sendUnixFileConn obtains an *os.File for conn through its File method and
// sends that descriptor over socket. Only if the send succeeds is closer (the
// connection as the caller sees it) closed.
func sendUnixFileConn(socket *net.UnixConn, conn fileConn, closer io.Closer) (err error) {
	f, err := conn.File()
	if err != nil {
		return err
	}
	// SendUnixFile closes f itself on success; this deferred Close covers the
	// failure path, and its (redundant) result on success is ignored.
	defer f.Close()

	if err = SendUnixFile(socket, f); err != nil {
		return err
	}
	// The descriptor already belongs to the peer at this point, so a failure
	// to close the local side is deliberately not reported as a send error.
	_ = closer.Close()
	return nil
}
// SendUnixFile passes the descriptor of file to the peer of the unix domain
// socket as a rights control message (built with syscall.UnixRights) with no
// payload of its own.
//
// When the send succeeds file is closed, because the receiving process now
// owns the descriptor. When it fails, file is left open for the caller.
func SendUnixFile(socket *net.UnixConn, file *os.File) (err error) {
	rights := syscall.UnixRights(int(file.Fd()))
	if _, _, err = socket.WriteMsgUnix(nil, rights, nil); err != nil {
		return err
	}
	file.Close()
	return nil
}
// RecvUnixConn receives a network connection from a unix domain socket.
func RecvUnixConn(socket *net.UnixConn) (conn net.Conn, err error) {
var f *os.File
if f, err = RecvUnixFile(socket); err != nil {
return
}
defer f.Close()
return net.FileConn(f)
}
// RecvUnixPacketConn receives a packet oriented network connection from a unix
// domain socket.
func RecvUnixPacketConn(socket *net.UnixConn) (conn net.PacketConn, err error) {
var f *os.File
if f, err = RecvUnixFile(socket); err != nil {
return
}
defer f.Close()
return net.FilePacketConn(f)
}
// RecvUnixFile receives exactly one file descriptor from the unix domain
// socket and returns it wrapped in an *os.File owned by the caller.
//
// io.EOF is returned when a message arrives with no control data at all.
// Presumably this is also how a peer that closed its end shows up here.
//
// Control data must consist of exactly one SCM_RIGHTS message carrying
// exactly one descriptor. Anything else is rejected with an error, and any
// descriptors found in it are closed first so they do not leak.
func RecvUnixFile(socket *net.UnixConn) (file *os.File, err error) {
	// Room for a single control message carrying one 4-byte descriptor.
	oob := make([]byte, syscall.CmsgSpace(4))

	_, oobn, _, _, err := socket.ReadMsgUnix(nil, oob)
	if err != nil {
		return nil, err
	}
	if oobn == 0 {
		return nil, io.EOF
	}

	// Only the first oobn bytes hold control data; whatever follows in the
	// buffer was never written by the kernel and must not be parsed.
	msgs, err := syscall.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return nil, os.NewSyscallError("ParseSocketControlMessage", err)
	}
	if len(msgs) != 1 {
		// Release descriptors from every rights message before rejecting the
		// data; messages of other types fail to parse and are skipped.
		for i := range msgs {
			if extra, perr := syscall.ParseUnixRights(&msgs[i]); perr == nil {
				for _, fd := range extra {
					syscall.Close(fd)
				}
			}
		}
		return nil, fmt.Errorf("invalid number of socket control messages, expected 1 but found %d", len(msgs))
	}

	fds, err := syscall.ParseUnixRights(&msgs[0])
	if err != nil {
		return nil, os.NewSyscallError("ParseUnixRights", err)
	}
	switch len(fds) {
	case 1:
		return os.NewFile(uintptr(fds[0]), ""), nil
	case 0:
		return nil, fmt.Errorf("no file descriptor found in socket control message")
	default:
		for _, fd := range fds {
			syscall.Close(fd)
		}
		return nil, fmt.Errorf("too many file descriptors found in a single control message, %d were closed", len(fds))
	}
}
// NewRecvUnixListener builds a listener whose Accept method produces
// connections by reading file descriptors from socket.
//
// The *socket value is copied into the listener, but the copy still refers to
// the same underlying socket. The listener should therefore be treated as the
// new owner: closing the listener closes the original socket, and vice versa.
func NewRecvUnixListener(socket *net.UnixConn) *RecvUnixListener {
	l := &RecvUnixListener{socket: *socket}
	return l
}
// RecvUnixListener is a listener whose connections arrive as file
// descriptors read from a unix domain socket. It does not accept them from a
// listening socket.
type RecvUnixListener struct {
	socket net.UnixConn
}

// Accept reads the next file descriptor from the listener's unix domain
// socket and returns it as a net.Conn.
func (l *RecvUnixListener) Accept() (net.Conn, error) {
	sock := &l.socket
	return RecvUnixConn(sock)
}

// Addr reports the local address of the listener's unix domain socket.
func (l *RecvUnixListener) Addr() net.Addr {
	sock := &l.socket
	return sock.LocalAddr()
}

// Close closes the unix domain socket the listener reads from.
func (l *RecvUnixListener) Close() error {
	sock := &l.socket
	return sock.Close()
}

// UnixConn exposes the listener's unix domain socket by pointer.
func (l *RecvUnixListener) UnixConn() *net.UnixConn {
	sock := &l.socket
	return sock
}
// NewSendUnixHandler wraps handler so that each connection it serves is sent
// back over socket when handler returns without having closed it.
//
// The returned handler stores a copy of the *socket value rather than the
// pointer itself.
func NewSendUnixHandler(socket *net.UnixConn, handler Handler) *SendUnixHandler {
	h := new(SendUnixHandler)
	h.handler = handler
	h.socket = *socket
	return h
}
// SendUnixHandler wraps another Handler. Once the wrapped handler returns, it
// passes the served connection back through a unix domain socket, unless that
// connection was closed or failed in the meantime.
type SendUnixHandler struct {
	handler Handler
	socket  net.UnixConn
}

// ServeConn satisfies the Handler interface. It runs the wrapped handler on a
// tracking wrapper around conn, then sends conn back over the unix domain
// socket if it is still usable. A failure to send it back causes a panic.
func (h *SendUnixHandler) ServeConn(ctx context.Context, conn net.Conn) {
	tracked := &sendUnixConn{Conn: conn}
	h.handler.ServeConn(ctx, tracked)

	// The flag is set by Close, or by Read/Write failing with an error that
	// IsTemporary rejects. Such a connection is not handed back.
	if atomic.LoadUint32(&tracked.closed) != 0 {
		return
	}
	if err := SendUnixConn(&h.socket, conn); err != nil {
		panic(fmt.Errorf("sending connection back over unix domain socket: %s", err))
	}
}

// UnixConn exposes the handler's unix domain socket by pointer.
func (h *SendUnixHandler) UnixConn() *net.UnixConn {
	sock := &h.socket
	return sock
}
// sendUnixConn wraps a connection passed to the inner handler of a
// SendUnixHandler.
//
// closed is only accessed through sync/atomic:
//   - 0 means the connection is still usable.
//   - 1 means it was closed, or a Read/Write failed with an error that
//     IsTemporary does not report as temporary.
type sendUnixConn struct {
	net.Conn
	closed uint32
}

// Base returns the wrapped connection.
func (c *sendUnixConn) Base() net.Conn {
	return c.Conn
}

// Close sets the closed flag, then closes the wrapped connection.
func (c *sendUnixConn) Close() error {
	atomic.StoreUint32(&c.closed, 1)
	return c.Conn.Close()
}

// Read forwards to the wrapped connection and sets the closed flag when the
// read fails with a non-temporary error.
func (c *sendUnixConn) Read(b []byte) (int, error) {
	n, err := c.Conn.Read(b)
	c.markFailed(err)
	return n, err
}

// Write forwards to the wrapped connection and sets the closed flag when the
// write fails with a non-temporary error.
func (c *sendUnixConn) Write(b []byte) (int, error) {
	n, err := c.Conn.Write(b)
	c.markFailed(err)
	return n, err
}

// markFailed sets the closed flag when err is non-nil and not temporary.
func (c *sendUnixConn) markFailed(err error) {
	if err == nil || IsTemporary(err) {
		return
	}
	atomic.StoreUint32(&c.closed, 1)
}