/
gateway.go
136 lines (121 loc) · 3.33 KB
/
gateway.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package gateway
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"strings"
"text/template"
"github.com/zllovesuki/t/multiplexer"
"github.com/zllovesuki/t/multiplexer/alpn"
"github.com/zllovesuki/t/profiler"
"github.com/zllovesuki/t/server"
"github.com/zllovesuki/t/shared"
"go.uber.org/zap"
)
// GatewayConfig holds the dependencies and settings that New uses to build a
// Gateway.
type GatewayConfig struct {
	// Logger receives the gateway's diagnostic output.
	Logger *zap.Logger
	// Multiplexer forwards raw tunnel connections to peers and supplies this
	// node's peer ID as the source of each forwarded link.
	Multiplexer *server.Server
	// Listener yields the gateway's incoming connections. Start expects every
	// accepted connection to be a *tls.Conn.
	Listener net.Listener
	// RootDomain is the apex domain: connections whose TLS SNI equals it are
	// routed to the main page instead of into a tunnel.
	RootDomain string
	// GatewayPort is the public port of the gateway. When it is not 443 it is
	// appended to RootDomain ("domain:port") to form the apex server's host.
	GatewayPort int
}
// Gateway accepts public TLS connections and routes them either to the apex
// (root domain) page or into client tunnels. Construct it with New.
type Gateway struct {
	GatewayConfig
	// apexServer provides the HTTP handler for the root-domain page.
	apexServer *apexServer
	// apexAcceptor queues connections whose SNI matches RootDomain for the
	// HTTP server backed by apexServer.
	apexAcceptor *httpAccepter
	// httpTunnelAcceptor queues HTTP (or unknown-ALPN) connections destined
	// for a client tunnel.
	httpTunnelAcceptor *httpAccepter
}
// New builds a Gateway from conf.
//
// It parses the apex page templates (the package-level tmpl and index
// sources) and creates the two in-memory acceptors that Start feeds with
// connections. An error is returned only when a template fails to parse.
func New(conf GatewayConfig) (*Gateway, error) {
	contentTmpl, err := template.New("content").Parse(tmpl)
	if err != nil {
		return nil, fmt.Errorf("reading markdown for apex template: %w", err)
	}
	pageTmpl, err := template.New("index").Parse(index)
	if err != nil {
		return nil, fmt.Errorf("reading index for apex template: %w", err)
	}

	// Only include the port in the apex host when it is not the HTTPS default.
	host := conf.RootDomain
	if conf.GatewayPort != 443 {
		host = fmt.Sprintf("%s:%d", conf.RootDomain, conf.GatewayPort)
	}

	// Both acceptors share the gateway listener as parent and buffer up to
	// 1024 queued connections.
	newAcceptor := func() *httpAccepter {
		return &httpAccepter{
			parent: conf.Listener,
			ch:     make(chan net.Conn, 1024),
		}
	}

	gw := &Gateway{
		GatewayConfig:      conf,
		apexAcceptor:       newAcceptor(),
		httpTunnelAcceptor: newAcceptor(),
		apexServer: &apexServer{
			clientPort: conf.GatewayPort,
			hostname:   conf.RootDomain,
			host:       host,
			mdTmpl:     contentTmpl,
			indexTmpl:  pageTmpl,
		},
	}
	return gw, nil
}
// Start runs the gateway's serving loop. It blocks until g.Listener.Accept
// returns an error (for example because the listener was closed); that error
// is logged and Start returns.
//
// Two HTTP servers are started in the background: one serves the apex page
// from g.apexAcceptor, the other serves HTTP tunnel traffic from
// g.httpTunnelAcceptor. Each accepted connection is then classified by
// handleConnection in its own goroutine, with ctx passed through.
//
// NOTE(review): cancelling ctx does not by itself stop the accept loop or the
// background HTTP servers; presumably the caller closes g.Listener — confirm.
func (g *Gateway) Start(ctx context.Context) {
	// serve runs one background HTTP server. http.Serve always returns a
	// non-nil error, so report it instead of silently losing the server.
	serve := func(name string, l net.Listener, h http.Handler) {
		if err := http.Serve(l, h); err != nil {
			g.Logger.Error("serving gateway http", zap.String("server", name), zap.Error(err))
		}
	}
	go serve("apex", g.apexAcceptor, g.apexServer.Handler())
	go serve("http_tunnel", g.httpTunnelAcceptor, g.httpHandler())

	for {
		conn, err := g.Listener.Accept()
		if err != nil {
			g.Logger.Error("accepting gateway connection", zap.Error(err))
			return
		}
		// The listener is expected to yield TLS connections. Guard the
		// assertion so a misconfigured listener drops the offending
		// connection instead of panicking and taking the gateway down.
		tconn, ok := conn.(*tls.Conn)
		if !ok {
			g.Logger.Error("gateway connection is not TLS", zap.String("type", fmt.Sprintf("%T", conn)))
			conn.Close()
			continue
		}
		go g.handleConnection(ctx, tconn)
	}
}
// handleConnection completes the TLS handshake on conn, then routes it by the
// SNI server name and the negotiated ALPN protocol:
//
//   - SNI equal to g.RootDomain: queued for the apex page server.
//   - ALPN unknown or HTTP: queued for the HTTP tunnel server.
//   - ALPN raw: forwarded through g.Multiplexer to the peer derived from the
//     SNI by link.
//   - ALPN multiplexer, or any other value: rejected.
//
// conn is closed here whenever it is not handed off or forwarded.
func (g *Gateway) handleConnection(ctx context.Context, conn *tls.Conn) {
	// crypto/tls runs the handshake lazily on the first Read/Write. Until it
	// has completed, ConnectionState reports an empty ServerName and
	// NegotiatedProtocol, so routing on them would be wrong. Handshake does
	// nothing further if the handshake has already run.
	if err := conn.Handshake(); err != nil {
		g.Logger.Debug("gateway tls handshake", zap.Error(err))
		conn.Close()
		return
	}
	cs := conn.ConnectionState()

	// handoff queues conn for an in-process HTTP server. If the queue stays
	// full and ctx is cancelled, drop the connection instead of blocking
	// this goroutine forever.
	handoff := func(ch chan<- net.Conn) {
		select {
		case ch <- conn:
		case <-ctx.Done():
			conn.Close()
		}
	}

	if cs.ServerName == g.RootDomain {
		// route to main page
		handoff(g.apexAcceptor.ch)
		return
	}

	// maybe tunnel it, depending on the negotiated ALPN protocol
	switch cs.NegotiatedProtocol {
	case alpn.Unknown.String(), alpn.HTTP.String():
		g.Logger.Debug("forward http connection")
		profiler.GatewayReqsType.WithLabelValues("http").Inc()
		handoff(g.httpTunnelAcceptor.ch)
	case alpn.Raw.String():
		g.Logger.Debug("forward raw connection")
		profiler.GatewayReqsType.WithLabelValues("raw").Inc()
		_, err := g.Multiplexer.Forward(ctx, conn, g.link(cs.ServerName, cs.NegotiatedProtocol))
		switch {
		case errors.Is(err, multiplexer.ErrDestinationNotFound):
			profiler.GatewayRequests.WithLabelValues("not_found", "forward").Add(1)
			conn.Close()
		case err != nil:
			g.Logger.Error("establish raw link error", zap.Error(err))
			conn.Close()
		}
	case alpn.Multiplexer.String():
		g.Logger.Warn("received alpn proposal for multiplexer on gateway")
		profiler.GatewayReqsType.WithLabelValues("multiplexer").Inc()
		conn.Close()
	default:
		g.Logger.Warn("unknown alpn proposal", zap.String("proposal", cs.NegotiatedProtocol))
		profiler.GatewayReqsType.WithLabelValues("error").Inc()
		conn.Close()
	}
}
// link builds the multiplexer.Link used to forward a raw tunnel connection.
//
// The destination is shared.PeerHash of the first DNS label of sni (the text
// before the first '.', or all of sni when it contains no dot). The source is
// this node's multiplexer peer ID. The ALPN value is proto translated through
// alpn.ReverseMap. NOTE(review): if ReverseMap is a map, an unrecognised proto
// yields its zero value — confirm that is acceptable.
func (g *Gateway) link(sni, proto string) multiplexer.Link {
	label := sni
	if dot := strings.IndexByte(sni, '.'); dot >= 0 {
		label = sni[:dot]
	}
	// Hash the destination before querying our own peer ID, matching the
	// original evaluation order.
	dest := shared.PeerHash(label)
	return multiplexer.Link{
		Source:      g.Multiplexer.PeerID(),
		Destination: dest,
		ALPN:        alpn.ReverseMap[proto],
	}
}