Skip to content

Commit

Permalink
Wrap tsnet Listener with tsnetServerListener
Browse files Browse the repository at this point in the history
Everytime we get a listener we increment the refernce on a caddy
UsagePool. With this change, wrap the listener in a custom type that
will decrement the UsagePool when we close the listener. When all
listeners are closed we will call Destruct on the tsnet Server which
will Close the server and should cleanup potential references in the
tailscale control plane

Signed-off-by: Connor Kelly <connor.r.kelly@gmail.com>
  • Loading branch information
clly committed May 1, 2024
1 parent ffef646 commit 9fc8b14
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 2 deletions.
1 change: 1 addition & 0 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func (t *TSApp) Start() error {
}

func (t *TSApp) Stop() error {
app = nil
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion caddyfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ func init() {
}

func parseApp(d *caddyfile.Dispenser, _ any) (any, error) {
app := &TSApp{
app = &TSApp{
Servers: make(map[string]TSServer),
}
if !d.Next() {
Expand Down
31 changes: 30 additions & 1 deletion module.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ var (
)

func init() {
app = &TSApp{
Servers: map[string]TSServer{},
}
caddy.RegisterModule(TailscaleAuth{})
httpcaddyfile.RegisterHandlerDirective("tailscale_auth", parseCaddyfile)
caddy.RegisterNetwork("tailscale", getPlainListener)
Expand All @@ -48,7 +51,10 @@ func getPlainListener(_ context.Context, _ string, addr string, _ net.ListenConf
network = "tcp"
}

return s.Listen(network, ":"+port)
ln := &tsnetServerDestructor{
Server: s.Server,
}
return ln.Listen(network, ":"+port)
}

func getTLSListener(_ context.Context, _ string, addr string, _ net.ListenConfig) (any, error) {
Expand Down Expand Up @@ -258,3 +264,26 @@ type tsnetServerDestructor struct {
func (t tsnetServerDestructor) Destruct() error {
return t.Close()
}

func (t *tsnetServerDestructor) Listen(network string, addr string) (net.Listener, error) {
ln, err := t.Server.Listen(network, addr)
if err != nil {
return nil, err
}
serverListener := &tsnetServerListener{
hostname: t.Hostname,
Listener: ln,
}
return serverListener, nil
}

type tsnetServerListener struct {
hostname string
net.Listener
}

func (t *tsnetServerListener) Close() error {
fmt.Println("Delete", t.hostname)
_, err := servers.Delete(t.hostname)
return err
}
29 changes: 29 additions & 0 deletions module_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package tscaddy

import (
"errors"
"fmt"
"net"
"strings"
"testing"
)
Expand Down Expand Up @@ -55,3 +57,30 @@ func Test_GetAuthKey(t *testing.T) {
})
}
}

func Test_Listen(t *testing.T) {
svr, err := getServer("", "testhost")
if err != nil {
t.Fatal("failed to get server", err)
}

ln, err := svr.Listen("tcp", ":80")
if err != nil {
t.Fatal("failed to listen", err)
}
count, exists := servers.References("testhost")
if !exists && count != 1 {
t.Fatal("reference doesn't exist")
}
ln.Close()

count, exists = servers.References("testhost")
if exists && count != 0 {
t.Fatal("reference exists when it shouldn't")
}

err = svr.Close()
if !errors.Is(err, net.ErrClosed) {
t.Fatal("unexpected error", err)
}
}

0 comments on commit 9fc8b14

Please sign in to comment.