Skip to content

Commit

Permalink
override tscert.TailscaledTransport with muxing transport
Browse files Browse the repository at this point in the history
This provides an http.RoundTripper implementation that dynamically
routes requests to the correct tsnet server's LocalAPI based on the
ClientHelloInfo in the context.

Updates #19
Updates #53
Updates #66
  • Loading branch information
willnorris committed Jun 7, 2024
1 parent 5cc2140 commit dd09a6e
Showing 1 changed file with 48 additions and 31 deletions.
79 changes: 48 additions & 31 deletions module.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@ import (
"crypto/tls"
"fmt"
"net"
"net/http"
"net/netip"
"os"
"path/filepath"
"strconv"
"strings"
"sync"

"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/caddyserver/certmagic"
"github.com/tailscale/tscert"
"go.uber.org/zap"
"tailscale.com/client/tailscale"
"tailscale.com/hostinfo"
"tailscale.com/tsnet"
)
Expand All @@ -36,7 +39,8 @@ func init() {
// Caddy uses tscert to get certificates for Tailscale hostnames.
// Update the tscert dialer to dial the LocalAPI of the correct tsnet node,
// rather than just always dialing the local tailscaled.
tscert.TailscaledDialer = localAPIDialer
//tscert.TailscaledDialer = localAPIDialer
tscert.TailscaledTransport = &tsnetMuxTransport{}

Check failure on line 43 in module.go

View workflow job for this annotation

GitHub Actions / lint

undefined: tscert.TailscaledTransport) (typecheck)

Check failure on line 43 in module.go

View workflow job for this annotation

GitHub Actions / lint

undefined: tscert.TailscaledTransport (typecheck)

Check failure on line 43 in module.go

View workflow job for this annotation

GitHub Actions / tests

undefined: tscert.TailscaledTransport
hostinfo.SetApp("caddy")
}

Expand Down Expand Up @@ -317,40 +321,53 @@ func (t *tsnetServerListener) Close() error {
return err
}

// localAPIDialer finds the node that matches the requested certificate in ctx
// and dials that node's local API.
// If no matching node is found, the default dialer is used,
// which tries to connect to a local tailscaled on the machine.
func localAPIDialer(ctx context.Context, network, addr string) (net.Conn, error) {
if addr != "local-tailscaled.sock:80" {
return nil, fmt.Errorf("unexpected URL address %q", addr)
}
// localAPITransport is an [http.RoundTripper] that sends requests to a [tailscale.LocalClient]'s LocalAPI.
type localAPITransport struct {
*tailscale.LocalClient
}

func (t *localAPITransport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.DoLocalRequest(req)
}

// tsnetMuxTransport is an [http.RoundTripper] that sends requests to the LocalAPI
// for the tsnet server that matches the ClientHelloInfo server name.
// If no tsnet server matches, a default Transport is used.
type tsnetMuxTransport struct {
defaultTransport *http.Transport
defaultTransportOnce sync.Once
}

func (t *tsnetMuxTransport) RoundTrip(req *http.Request) (*http.Response, error) {
ctx := req.Context()
var rt http.RoundTripper

clientHello, ok := ctx.Value(certmagic.ClientHelloInfoCtxKey).(*tls.ClientHelloInfo)
if !ok || clientHello == nil {
return tscert.DialLocalAPI(ctx, network, addr)
}

var tn *tailscaleNode
nodes.Range(func(key, value any) bool {
if n, ok := value.(*tailscaleNode); ok && n != nil {
for _, d := range n.CertDomains() {
// Tailscale doesn't do wildcard certs, but caddy uses MatchWildcard
// for the built-in Tailscale cert manager, so we do so here as well.
if certmagic.MatchWildcard(clientHello.ServerName, d) {
tn = n
return false
if ok && clientHello != nil {
nodes.Range(func(key, value any) bool {
if n, ok := value.(*tailscaleNode); ok && n != nil {
for _, d := range n.CertDomains() {
// Tailscale doesn't do wildcard certs, but caddy uses MatchWildcard
// for the built-in Tailscale cert manager, so we do so here as well.
if certmagic.MatchWildcard(clientHello.ServerName, d) {
if lc, err := n.LocalClient(); err == nil {
rt = &localAPITransport{lc}
}
return false
}
}
}
}
return true
})

if tn != nil {
if lc, err := tn.LocalClient(); err == nil {
return lc.Dial(ctx, network, addr)
}
return true
})
}

return tscert.DialLocalAPI(ctx, network, addr)
if rt == nil {
t.defaultTransportOnce.Do(func() {
t.defaultTransport = &http.Transport{
DialContext: tscert.TailscaledDialer,
}
})
rt = t.defaultTransport
}
return rt.RoundTrip(req)
}

0 comments on commit dd09a6e

Please sign in to comment.