Skip to content
This repository has been archived by the owner on Jan 31, 2024. It is now read-only.

Commit

Permalink
add Accept0RTT to Config callback to decide if 0-RTT should be accepted
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed May 29, 2023
1 parent 2796c8a commit 3be08c2
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 28 deletions.
6 changes: 6 additions & 0 deletions common.go
Expand Up @@ -726,13 +726,19 @@ type ExtraConfig struct {
//
// It has no meaning on the client.
GetAppDataForSessionTicket func() []byte

// The Accept0RTT callback is called when the client offers 0-RTT.
// The server then has to decide if it wants to accept or reject 0-RTT.
// It is only used for servers.
Accept0RTT func(appData []byte) bool
}

// Clone clones.
func (c *ExtraConfig) Clone() *ExtraConfig {
return &ExtraConfig{
Enable0RTT: c.Enable0RTT,
GetAppDataForSessionTicket: c.GetAppDataForSessionTicket,
Accept0RTT: c.Accept0RTT,
}
}

Expand Down
59 changes: 33 additions & 26 deletions handshake_server_tls13.go
Expand Up @@ -22,23 +22,24 @@ import (
const maxClientPSKIdentities = 5

type serverHandshakeStateTLS13 struct {
c *Conn
ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
sentDummyCCS bool
usingPSK bool
suite *cipherSuiteTLS13
cert *Certificate
sigAlg SignatureScheme
earlySecret []byte
sharedKey []byte
handshakeSecret []byte
masterSecret []byte
trafficSecret []byte // client_application_traffic_secret_0
transcript hash.Hash
clientFinished []byte
earlyData bool
c *Conn
ctx context.Context
clientHello *clientHelloMsg
hello *serverHelloMsg
encryptedExtensions *encryptedExtensionsMsg
sentDummyCCS bool
usingPSK bool
suite *cipherSuiteTLS13
cert *Certificate
sigAlg SignatureScheme
earlySecret []byte
sharedKey []byte
handshakeSecret []byte
masterSecret []byte
trafficSecret []byte // client_application_traffic_secret_0
transcript hash.Hash
clientFinished []byte
earlyData bool
}

func (hs *serverHandshakeStateTLS13) handshake() error {
Expand Down Expand Up @@ -90,6 +91,7 @@ func (hs *serverHandshakeStateTLS13) processClientHello() error {
c := hs.c

hs.hello = new(serverHelloMsg)
hs.encryptedExtensions = new(encryptedExtensionsMsg)

// TLS 1.3 froze the ServerHello.legacy_version field, and uses
// supported_versions instead. See RFC 8446, sections 4.1.3 and 4.2.1.
Expand Down Expand Up @@ -273,9 +275,16 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
continue
}

if hs.clientHello.earlyData && sessionState.maxEarlyData == 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: client sent unexpected early data")
if hs.clientHello.earlyData {
if sessionState.maxEarlyData == 0 {
c.sendAlert(alertUnsupportedExtension)
return errors.New("tls: client sent unexpected early data")
}

if c.extraConfig != nil && c.extraConfig.Enable0RTT &&
c.extraConfig.Accept0RTT != nil && c.extraConfig.Accept0RTT(sessionState.appData) {
hs.encryptedExtensions.earlyData = true
}
}

createdAt := time.Unix(int64(sessionState.createdAt), 0)
Expand Down Expand Up @@ -326,7 +335,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
return errors.New("tls: invalid PSK binder")
}

if c.quic != nil && hs.clientHello.earlyData && i == 0 &&
if c.quic != nil && hs.clientHello.earlyData && hs.encryptedExtensions.earlyData && i == 0 &&
sessionState.maxEarlyData > 0 && sessionState.cipherSuite == hs.suite.id {
hs.earlyData = true

Expand Down Expand Up @@ -595,25 +604,23 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
return err
}

encryptedExtensions := new(encryptedExtensionsMsg)

selectedProto, err := negotiateALPN(c.config.NextProtos, hs.clientHello.alpnProtocols, c.quic != nil)
if err != nil {
c.sendAlert(alertNoApplicationProtocol)
return err
}
encryptedExtensions.alpnProtocol = selectedProto
hs.encryptedExtensions.alpnProtocol = selectedProto
c.clientProtocol = selectedProto

if c.quic != nil {
p, err := c.quicGetTransportParameters()
if err != nil {
return err
}
encryptedExtensions.quicTransportParameters = p
hs.encryptedExtensions.quicTransportParameters = p
}

if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil {
if _, err := hs.c.writeHandshakeRecord(hs.encryptedExtensions, hs.transcript); err != nil {
return err
}

Expand Down
9 changes: 7 additions & 2 deletions tls_test.go
Expand Up @@ -851,18 +851,23 @@ func TestCloneNilConfig(t *testing.T) {
}

func TestExtraConfigCloneFuncField(t *testing.T) {
const expectedCount = 1
const expectedCount = 2
called := 0

c1 := ExtraConfig{
GetAppDataForSessionTicket: func() []byte {
called |= 1
return nil
},
Accept0RTT: func([]byte) bool {
called |= 1 << 1
return true
},
}

c2 := c1.Clone()
c2.GetAppDataForSessionTicket()
c2.Accept0RTT(nil)
if called != (1<<expectedCount)-1 {
t.Fatalf("expected %d calls but saw calls %b", expectedCount, called)
}
Expand All @@ -880,7 +885,7 @@ func TestExtraConfigCloneNonFuncFields(t *testing.T) {
switch fn := typ.Field(i).Name; fn {
case "Enable0RTT":
f.Set(reflect.ValueOf(true))
case "GetAppDataForSessionTicket":
case "GetAppDataForSessionTicket", "Accept0RTT":
// DeepEqual can't compare functions. If you add a
// function field to this list, you must also change
// TestCloneFuncFields to ensure that the func field is
Expand Down

0 comments on commit 3be08c2

Please sign in to comment.