-
Notifications
You must be signed in to change notification settings - Fork 62
/
relay.go
159 lines (136 loc) · 3.8 KB
/
relay.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
// SPDX-License-Identifier: BSD-3-Clause
// Copyright (c) 2024, Unikraft GmbH and The KraftKit Authors.
// Licensed under the BSD-3-Clause License (the "License").
// You may not use this file except in compliance with the License.
package tunnel
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"time"
"kraftkit.sh/log"
)
// Relay accepts TCP connections on a local address and forwards each of them
// to a remote host over TLS.
type Relay struct {
	// lAddr is the local address to listen on; rAddr is the remote address
	// dialed for every accepted connection.
	lAddr, rAddr string
}
// Up listens on the relay's local address and forwards every accepted
// connection to the remote address over TLS, handling each one in its own
// goroutine.
//
// Up blocks until ctx is cancelled, in which case it returns nil. It also
// returns early with an error when the listener cannot be opened or a
// connection cannot be accepted. It does not wait for in-flight connections
// to finish.
func (r *Relay) Up(ctx context.Context) error {
	l, err := r.listenLocal(ctx)
	if err != nil {
		return err
	}
	defer l.Close()

	// Closing the listener is what unblocks Accept on cancellation. The stop
	// channel lets this watcher exit when Up returns for any other reason;
	// without it, the goroutine would leak until ctx is done, possibly never.
	stop := make(chan struct{})
	defer close(stop)
	go func() {
		select {
		case <-ctx.Done():
			l.Close()
		case <-stop:
		}
	}()

	log.G(ctx).Info("Tunnelling ", l.Addr(), " to ", r.rAddr)

	for {
		conn, err := l.Accept()
		if err != nil {
			// While the loop runs, the listener is only closed by the
			// cancellation watcher above, so this is a clean shutdown.
			if errors.Is(err, net.ErrClosed) {
				return nil
			}
			return fmt.Errorf("accepting incoming connection: %w", err)
		}
		c := r.newConnection(conn)
		go c.handle(ctx)
	}
}
// newConnection wraps a connection accepted by r's listener so that it can be
// relayed to r's remote host.
func (r *Relay) newConnection(conn net.Conn) *connection {
	c := new(connection)
	c.relay = r
	c.conn = conn
	return c
}
func (r *Relay) dialRemote(ctx context.Context) (net.Conn, error) {
var d tls.Dialer
return d.DialContext(ctx, "tcp4", r.rAddr)
}
func (r *Relay) listenLocal(ctx context.Context) (net.Listener, error) {
var lc net.ListenConfig
return lc.Listen(ctx, "tcp4", r.lAddr)
}
// connection is the server side of one client connection accepted from the
// relay's local TCP listener.
type connection struct {
	// relay is the Relay whose listener accepted this connection; its
	// dialRemote method is used to reach the remote host.
	relay *Relay
	// conn is the accepted client network connection.
	conn net.Conn
}
// handle relays data in both directions between the client connection and a
// newly dialed TLS connection to the relay's remote host, until either
// direction stops. Both connections are closed before handle returns.
//
// ctx is passed to the remote dialer and to log.G. Nothing here watches ctx
// once the remote connection is established.
func (c *connection) handle(ctx context.Context) {
	log.G(ctx).Info("Accepted client connection ", c.conn.RemoteAddr())
	defer func() {
		c.conn.Close()
		log.G(ctx).Info("Closed client connection ", c.conn.RemoteAddr())
	}()

	rc, err := c.relay.dialRemote(ctx)
	if err != nil {
		log.G(ctx).WithError(err).Error("Failed to connect to remote host")
		return
	}
	defer rc.Close()

	// NOTE(antoineco): these calls are critical as they allow reads/writes to be
	// later cancelled, because the deadline applies to all future and pending
	// I/O and can be dynamically extended or reduced.
	_ = rc.SetDeadline(noNetTimeout)
	_ = c.conn.SetDeadline(noNetTimeout)

	const bufSize = 32 * 1024 // same as io.Copy

	// pipe copies src to dst until a read or write fails. Two read errors are
	// expected ways for a direction to end, so they are not logged: io.EOF
	// (peer closed) and timeouts (our own cancellation via immediateNetCancel).
	pipe := func(dst, src net.Conn, readFailMsg, writeFailMsg string) {
		buf := make([]byte, bufSize)
		for {
			n, rerr := src.Read(buf)
			// io.Reader allows data and an error in the same call (tls.Conn can
			// return its final bytes together with io.EOF), so forward the data
			// before looking at the error.
			if n > 0 {
				if _, werr := dst.Write(buf[:n]); werr != nil {
					log.G(ctx).WithError(werr).Error(writeFailMsg)
					return
				}
			}
			if rerr != nil {
				if !errors.Is(rerr, io.EOF) && !isNetTimeoutError(rerr) {
					log.G(ctx).WithError(rerr).Error(readFailMsg)
				}
				return
			}
		}
	}

	// client -> remote
	writerDone := make(chan struct{})
	go func() {
		defer func() {
			// Unblock the remote -> client direction, which may be parked in rc.Read.
			_ = rc.SetDeadline(immediateNetCancel)
			close(writerDone)
		}()
		pipe(rc, c.conn, "Failed to read from client", "Failed to write to remote host")
	}()

	// remote -> client
	pipe(c.conn, rc, "Failed to read from remote host", "Failed to write to client")

	// Unblock the client -> remote goroutine, which may be parked in
	// c.conn.Read, BEFORE waiting for it. If this ran only after the wait, a
	// remote-side close would leave handle stuck until the client sent more
	// data or disconnected.
	_ = c.conn.SetDeadline(immediateNetCancel)
	<-writerDone
}
// Deadline sentinels passed to net.Conn.SetDeadline by the relay.
var (
	// noNetTimeout is the zero time.Time. As a deadline it disables timeouts,
	// so network operations never time out.
	noNetTimeout time.Time
	// immediateNetCancel is a non-zero instant far in the past (one second
	// after the Unix epoch). As a deadline it makes pending and future network
	// operations fail right away.
	immediateNetCancel = time.Unix(1, 0)
)
// isNetTimeoutError reports whether err is a network timeout error.
func isNetTimeoutError(err error) bool {
if neterr := net.Error(nil); errors.As(err, &neterr) {
return neterr.Timeout()
}
return false
}