/
http.go
131 lines (117 loc) · 3.71 KB
/
http.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
package config
import (
"context"
"crypto/tls"
"net"
"net/http"
"sync"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/tripper"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
// NewHTTPTransport returns a clone of http.DefaultTransport whose TLS
// connections are dialed with a root CA pool built from the current
// configuration's Options.CA / Options.CAFile via cryptutil.GetCertPool.
//
// The pool is rebuilt whenever src reports a configuration change, and every
// new TLS dial uses the most recently built configuration (TLS 1.2 minimum).
// If building the pool fails, the error is logged and the previously built
// configuration, if any, remains in effect. Until a build succeeds, dials use
// a nil tls.Config, which crypto/tls treats as the zero configuration.
//
// NOTE(review): net/http documents DialTLSContext as applying only to
// non-proxied HTTPS requests, so requests routed through a proxy do not use
// this CA pool — confirm that is acceptable.
func NewHTTPTransport(src Source) *http.Transport {
	var (
		mu      sync.Mutex
		current *tls.Config
	)

	// rebuild recomputes the dialer TLS config from cfg and publishes it.
	rebuild := func(ctx context.Context, cfg *Config) {
		pool, err := cryptutil.GetCertPool(cfg.Options.CA, cfg.Options.CAFile)
		if err != nil {
			log.Error(ctx).Err(err).Msg("config: error getting cert pool")
			return
		}
		next := &tls.Config{
			RootCAs:    pool,
			MinVersion: tls.VersionTLS12,
		}
		mu.Lock()
		current = next
		mu.Unlock()
	}

	// snapshot returns the currently published TLS config. The pointer is
	// replaced wholesale on change, never mutated, so handing it out is safe.
	snapshot := func() *tls.Config {
		mu.Lock()
		defer mu.Unlock()
		return current
	}

	// Register for changes first, then apply the config that is current now.
	src.OnConfigChange(context.Background(), rebuild)
	rebuild(context.Background(), src.GetConfig())

	transport := http.DefaultTransport.(*http.Transport).Clone()
	transport.ForceAttemptHTTP2 = true
	transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
		dialer := &tls.Dialer{Config: snapshot()}
		return dialer.DialContext(ctx, network, addr)
	}
	return transport
}
// NewPolicyHTTPTransport creates a new http RoundTripper for a policy.
//
// The transport starts as a clone of http.DefaultTransport and is returned
// through a tripper chain. When disableHTTP2 is true, HTTP/2 is turned off.
// A custom TLS client configuration is attached only if at least one of the
// following applies (later items win where they touch the same field):
//   - policy.TLSSkipVerify disables certificate verification;
//   - options.CA / options.CAFile supply the root CA pool (TLS 1.2 minimum);
//   - policy.TLSCustomCA / policy.TLSCustomCAFile supply a root CA pool that
//     replaces the one from options (TLS 1.2 minimum);
//   - policy.ClientCertificate is presented as the client certificate;
//   - policy.TLSServerName, then policy.TLSUpstreamServerName, set
//     tls.Config.ServerName (TLSUpstreamServerName wins if both are set).
//
// A CA pool that fails to load is logged and skipped rather than treated as fatal.
func NewPolicyHTTPTransport(options *Options, policy *Policy, disableHTTP2 bool) http.RoundTripper {
	cloner := http.DefaultTransport.(interface {
		Clone() *http.Transport
	})
	transport := cloner.Clone()
	chain := tripper.NewChain()

	// Per the net/http docs, a client disables HTTP/2 by setting
	// Transport.TLSNextProto to a non-nil, empty map.
	if disableHTTP2 {
		transport.TLSNextProto = map[string]func(authority string, c *tls.Conn) http.RoundTripper{}
		transport.ForceAttemptHTTP2 = false
	}

	cfg := new(tls.Config)
	customized := false

	if policy.TLSSkipVerify {
		cfg.InsecureSkipVerify = true
		customized = true
	}

	// useRootCAs loads a CA pool and, on success, installs it with a TLS 1.2
	// floor; on failure it logs failMsg and leaves cfg untouched.
	useRootCAs := func(ca, caFile, failMsg string) {
		pool, err := cryptutil.GetCertPool(ca, caFile)
		if err != nil {
			log.Error(context.TODO()).Err(err).Msg(failMsg)
			return
		}
		cfg.RootCAs = pool
		cfg.MinVersion = tls.VersionTLS12
		customized = true
	}
	if options.CA != "" || options.CAFile != "" {
		useRootCAs(options.CA, options.CAFile, "config: error getting ca cert pool")
	}
	if policy.TLSCustomCA != "" || policy.TLSCustomCAFile != "" {
		useRootCAs(policy.TLSCustomCA, policy.TLSCustomCAFile, "config: error getting custom ca cert pool")
	}

	if cert := policy.ClientCertificate; cert != nil {
		cfg.Certificates = []tls.Certificate{*cert}
		customized = true
	}

	// Applied in order, so a non-empty TLSUpstreamServerName takes precedence.
	for _, name := range []string{policy.TLSServerName, policy.TLSUpstreamServerName} {
		if name != "" {
			cfg.ServerName = name
			customized = true
		}
	}

	// Leave TLSClientConfig nil unless something was customized, so the
	// transport falls back to its default TLS configuration.
	if customized {
		transport.DialTLSContext = nil
		transport.TLSClientConfig = cfg
	}
	return chain.Then(transport)
}
// GetTLSClientTransport returns an *http.Transport that uses the TLS client
// configuration from cfg.GetTLSClientConfig (accounting for custom CAs from
// config) and always attempts HTTP/2. Any error from GetTLSClientConfig is
// returned unchanged with a nil transport.
//
// NOTE(review): the transport is built from scratch rather than cloned from
// http.DefaultTransport, so it has no Proxy function and no dial, TLS
// handshake or idle-connection timeouts — confirm this is intentional.
func GetTLSClientTransport(cfg *Config) (*http.Transport, error) {
	tlsConfig, err := cfg.GetTLSClientConfig()
	if err != nil {
		return nil, err
	}
	transport := new(http.Transport)
	transport.TLSClientConfig = tlsConfig
	transport.ForceAttemptHTTP2 = true
	return transport, nil
}