Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

control/controlbase: make the protocol version number selectable. #4370

Merged
merged 1 commit into from
Apr 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 6 additions & 4 deletions control/controlbase/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"tailscale.com/types/key"
)

const testProtocolVersion = 1

func TestMessageSize(t *testing.T) {
// This test is a regression guard against someone looking at
// maxCiphertextSize, going "huh, we could be more efficient if it
Expand Down Expand Up @@ -204,10 +206,10 @@ func TestConnStd(t *testing.T) {
serverErr := make(chan error, 1)
go func() {
var err error
c2, err = Server(context.Background(), s2, controlKey, nil)
c2, err = Server(context.Background(), s2, controlKey, testProtocolVersion, nil)
serverErr <- err
}()
c1, err = Client(context.Background(), s1, machineKey, controlKey.Public())
c1, err = Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
if err != nil {
s1.Close()
s2.Close()
Expand Down Expand Up @@ -396,11 +398,11 @@ func pairWithConns(t *testing.T, clientConn, serverConn net.Conn) (*Conn, *Conn)
)
go func() {
var err error
server, err = Server(context.Background(), serverConn, controlKey, nil)
server, err = Server(context.Background(), serverConn, controlKey, testProtocolVersion, nil)
serverErr <- err
}()

client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public())
client, err := Client(context.Background(), clientConn, machineKey, controlKey.Public(), testProtocolVersion)
if err != nil {
t.Fatalf("client connection failed: %v", err)
}
Expand Down
37 changes: 25 additions & 12 deletions control/controlbase/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ const (
protocolName = "Noise_IK_25519_ChaChaPoly_BLAKE2s"
// protocolVersion is the version of the control protocol that
// Client will use when initiating a handshake.
protocolVersion uint16 = 1
//protocolVersion uint16 = 1
// protocolVersionPrefix is the name portion of the protocol
// name+version string that gets mixed into the handshake as a
// prologue.
Expand Down Expand Up @@ -66,7 +66,7 @@ type HandshakeContinuation func(context.Context, net.Conn) (*Conn, error)
// protocol switching. By splitting the handshake into an initial
// message and a continuation, we can embed the handshake initiation
// into the HTTP protocol switching request and avoid a bit of delay.
func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (initialHandshake []byte, continueHandshake HandshakeContinuation, err error) {
var s symmetricState
s.Initialize()

Expand All @@ -78,7 +78,7 @@ func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic)
s.MixHash(controlKey.UntypedBytes())

// -> e, es, s, ss
init := mkInitiationMessage()
init := mkInitiationMessage(protocolVersion)
machineEphemeral := key.NewMachine()
machineEphemeralPub := machineEphemeral.Public()
copy(init.EphemeralPub(), machineEphemeralPub.UntypedBytes())
Expand All @@ -96,7 +96,7 @@ func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic)
s.EncryptAndHash(cipher, init.Tag(), nil) // empty message payload

cont := func(ctx context.Context, conn net.Conn) (*Conn, error) {
return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey)
return continueClientHandshake(ctx, conn, &s, machineKey, machineEphemeral, controlKey, protocolVersion)
}
return init[:], cont, nil
}
Expand All @@ -107,8 +107,8 @@ func ClientDeferred(machineKey key.MachinePrivate, controlKey key.MachinePublic)
// This is a helper for when you don't need the fancy
// continuation-style handshake, and just want to synchronously
// upgrade a net.Conn to a secure transport.
func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) {
init, cont, err := ClientDeferred(machineKey, controlKey)
func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
init, cont, err := ClientDeferred(machineKey, controlKey, protocolVersion)
if err != nil {
return nil, err
}
Expand All @@ -118,7 +118,7 @@ func Client(ctx context.Context, conn net.Conn, machineKey key.MachinePrivate, c
return cont(ctx, conn)
}

func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic) (*Conn, error) {
func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricState, machineKey, machineEphemeral key.MachinePrivate, controlKey key.MachinePublic, protocolVersion uint16) (*Conn, error) {
// No matter what, this function can only run once per s. Ensure
// attempted reuse causes a panic.
defer func() {
Expand Down Expand Up @@ -193,13 +193,19 @@ func continueClientHandshake(ctx context.Context, conn net.Conn, s *symmetricSta
// Server initiates a control server handshake, returning the resulting
// control connection.
//
// maxSupportedVersion is the highest handshake version the server is
// willing to handshake with. The server will handshake with any
// version from 0 to maxSupportedVersion inclusive, the caller should
// inspect conn.Version() to determine what version of the handshake
// was executed.
//
// optionalInit can be the client's initial handshake message as
// returned by ClientDeferred, or nil in which case the initial
// message is read from conn.
//
// The context deadline, if any, covers the entire handshaking
// process.
func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, optionalInit []byte) (*Conn, error) {
func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, maxSupportedVersion uint16, optionalInit []byte) (*Conn, error) {
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("setting conn deadline: %w", err)
Expand Down Expand Up @@ -239,9 +245,16 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, o
} else if _, err := io.ReadFull(conn, init.Header()); err != nil {
return nil, err
}
if init.Version() != protocolVersion {
return nil, sendErr("unsupported protocol version")
// Currently, these versions exclusively indicate what the upper
// RPC protocol understands, the Noise handshake is exactly the
// same in all versions. If that ever changes, this check will
// need to become more complex to handle different kinds of
// handshake.
if init.Version() > maxSupportedVersion {
return nil, sendErr("unsupported handshake version")
}
// Just a rename to make it more obvious what the value is
clientVersion := init.Version()
if init.Type() != msgTypeInitiation {
return nil, sendErr("unexpected handshake message type")
}
Expand All @@ -257,7 +270,7 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, o

// prologue. Can only do this once we at least think the client is
// handshaking using a supported version.
s.MixHash(protocolVersionPrologue(protocolVersion))
s.MixHash(protocolVersionPrologue(clientVersion))

// <- s
// ...
Expand Down Expand Up @@ -310,7 +323,7 @@ func Server(ctx context.Context, conn net.Conn, controlKey key.MachinePrivate, o

c := &Conn{
conn: conn,
version: protocolVersion,
version: clientVersion,
peer: machineKey,
handshakeHash: s.h,
tx: txState{
Expand Down
28 changes: 14 additions & 14 deletions control/controlbase/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ func TestHandshake(t *testing.T) {
)
go func() {
var err error
server, err = Server(context.Background(), serverConn, serverKey, nil)
server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
serverErr <- err
}()

client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
if err != nil {
t.Fatalf("client connection failed: %v", err)
}
Expand All @@ -42,8 +42,8 @@ func TestHandshake(t *testing.T) {
t.Fatal("client and server disagree on handshake hash")
}

if client.ProtocolVersion() != int(protocolVersion) {
t.Fatalf("client reporting wrong protocol version %d, want %d", client.ProtocolVersion(), protocolVersion)
if client.ProtocolVersion() != int(testProtocolVersion) {
t.Fatalf("client reporting wrong protocol version %d, want %d", client.ProtocolVersion(), testProtocolVersion)
}
if client.ProtocolVersion() != server.ProtocolVersion() {
t.Fatalf("peers disagree on protocol version, client=%d server=%d", client.ProtocolVersion(), server.ProtocolVersion())
Expand Down Expand Up @@ -78,11 +78,11 @@ func TestNoReuse(t *testing.T) {
)
go func() {
var err error
server, err = Server(context.Background(), serverConn, serverKey, nil)
server, err = Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
serverErr <- err
}()

client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
if err != nil {
t.Fatalf("client connection failed: %v", err)
}
Expand Down Expand Up @@ -172,7 +172,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1)
)
go func() {
_, err := Server(context.Background(), serverConn, serverKey, nil)
_, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
// If the server failed, we have to close the Conn to
// unblock the client.
if err != nil {
Expand All @@ -181,7 +181,7 @@ func TestTampering(t *testing.T) {
serverErr <- err
}()

_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
if err == nil {
t.Fatal("client connection succeeded despite tampering")
}
Expand All @@ -200,11 +200,11 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1)
)
go func() {
_, err := Server(context.Background(), serverConn, serverKey, nil)
_, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
serverErr <- err
}()

_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
_, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
if err == nil {
t.Fatal("client connection succeeded despite tampering")
}
Expand All @@ -225,13 +225,13 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1)
)
go func() {
server, err := Server(context.Background(), serverConn, serverKey, nil)
server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
serverErr <- err
_, err = io.WriteString(server, strings.Repeat("a", 14))
serverErr <- err
}()

client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
if err != nil {
t.Fatalf("client handshake failed: %v", err)
}
Expand Down Expand Up @@ -266,7 +266,7 @@ func TestTampering(t *testing.T) {
serverErr = make(chan error, 1)
)
go func() {
server, err := Server(context.Background(), serverConn, serverKey, nil)
server, err := Server(context.Background(), serverConn, serverKey, testProtocolVersion, nil)
serverErr <- err
var bs [100]byte
// The server needs a timeout if the tampering is hitting the length header.
Expand All @@ -281,7 +281,7 @@ func TestTampering(t *testing.T) {
}
}()

client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public())
client, err := Client(context.Background(), clientConn, clientKey, serverKey.Public(), testProtocolVersion)
if err != nil {
t.Fatalf("client handshake failed: %v", err)
}
Expand Down
10 changes: 5 additions & 5 deletions control/controlbase/interop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestInteropClient(t *testing.T) {
)

go func() {
server, err := Server(context.Background(), s2, controlKey, nil)
server, err := Server(context.Background(), s2, controlKey, testProtocolVersion, nil)
serverErr <- err
if err != nil {
return
Expand Down Expand Up @@ -77,7 +77,7 @@ func TestInteropServer(t *testing.T) {
)

go func() {
client, err := Client(context.Background(), s1, machineKey, controlKey.Public())
client, err := Client(context.Background(), s1, machineKey, controlKey.Public(), testProtocolVersion)
clientErr <- err
if err != nil {
return
Expand Down Expand Up @@ -121,11 +121,11 @@ func noiseExplorerClient(conn net.Conn, controlKey key.MachinePublic, machineKey
copy(mk.public_key[:], machineKey.Public().UntypedBytes())
var peerKey [32]byte
copy(peerKey[:], controlKey.UntypedBytes())
session := InitSession(true, protocolVersionPrologue(protocolVersion), mk, peerKey)
session := InitSession(true, protocolVersionPrologue(testProtocolVersion), mk, peerKey)

_, msg1 := SendMessage(&session, nil)
var hdr [initiationHeaderLen]byte
binary.BigEndian.PutUint16(hdr[:2], protocolVersion)
binary.BigEndian.PutUint16(hdr[:2], testProtocolVersion)
hdr[2] = msgTypeInitiation
binary.BigEndian.PutUint16(hdr[3:5], 96)
if _, err := conn.Write(hdr[:]); err != nil {
Expand Down Expand Up @@ -193,7 +193,7 @@ func noiseExplorerServer(conn net.Conn, controlKey key.MachinePrivate, wantMachi
var mk keypair
copy(mk.private_key[:], controlKey.UntypedBytes())
copy(mk.public_key[:], controlKey.Public().UntypedBytes())
session := InitSession(false, protocolVersionPrologue(protocolVersion), mk, [32]byte{})
session := InitSession(false, protocolVersionPrologue(testProtocolVersion), mk, [32]byte{})

var buf [1024]byte
if _, err := io.ReadFull(conn, buf[:101]); err != nil {
Expand Down
4 changes: 2 additions & 2 deletions control/controlbase/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ const (
// 16b: message tag (authenticates the whole message)
type initiationMessage [101]byte

func mkInitiationMessage() initiationMessage {
func mkInitiationMessage(protocolVersion uint16) initiationMessage {
var ret initiationMessage
binary.BigEndian.PutUint16(ret[:2], uint16(protocolVersion))
binary.BigEndian.PutUint16(ret[:2], protocolVersion)
ret[2] = msgTypeInitiation
binary.BigEndian.PutUint16(ret[3:5], uint16(len(ret.Payload())))
return ret
Expand Down
9 changes: 8 additions & 1 deletion control/controlclient/noise.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"crypto/tls"
"fmt"
"math"
"net"
"net/http"
"net/url"
Expand All @@ -17,6 +18,7 @@ import (
"golang.org/x/net/http2"
"tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/util/multierr"
)
Expand Down Expand Up @@ -146,7 +148,12 @@ func (nc *noiseClient) dial(_, _ string, _ *tls.Config) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

conn, err := controlhttp.Dial(ctx, nc.serverHost, nc.privKey, nc.serverPubKey)
if tailcfg.CurrentCapabilityVersion > math.MaxUint16 {
// Panic, because a test should have started failing several
// thousand version numbers before getting to this point.
panic("capability version is too high to fit in the wire protocol")
}
conn, err := controlhttp.Dial(ctx, nc.serverHost, nc.privKey, nc.serverPubKey, uint16(tailcfg.CurrentCapabilityVersion))
if err != nil {
return nil, err
}
Expand Down
28 changes: 28 additions & 0 deletions control/controlclient/noise_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package controlclient

import (
"math"
"testing"

"tailscale.com/tailcfg"
)

// maxAllowedNoiseVersion is the highest we expect the Tailscale
// capability version to ever get. It's a value close to 2^16, but
// with enough leeway that we get a very early warning that it's time
// to rework the wire protocol to allow larger versions, while still
// giving us headroom to bump this test and fix the build.
//
// Code elsewhere in the client will panic() if the tailcfg capability
// version exceeds 16 bits, so take a failure of this test seriously.
const maxAllowedNoiseVersion = math.MaxUint16 - 5000

func TestNoiseVersion(t *testing.T) {
if tailcfg.CurrentCapabilityVersion > maxAllowedNoiseVersion {
t.Fatalf("tailcfg.CurrentCapabilityVersion is %d, want <=%d", tailcfg.CurrentCapabilityVersion, maxAllowedNoiseVersion)
}
}