Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposal: Include Client Address in PSK Validation for Brute Force Detection #596

Closed
wants to merge 7 commits into from
3 changes: 2 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"crypto/tls"
"crypto/x509"
"io"
"net"
"time"

"github.com/pion/dtls/v2/pkg/crypto/elliptic"
Expand Down Expand Up @@ -219,7 +220,7 @@ var defaultCurves = []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P3

// PSKCallback is called once we have the remote's PSKIdentityHint.
// If the remote provided none it will be nil
type PSKCallback func([]byte) ([]byte, error)
type PSKCallback func([]byte, net.Addr) ([]byte, error)

// ClientAuthType declares the policy the server will follow for
// TLS Client Authentication.
Expand Down
7 changes: 4 additions & 3 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"crypto/rsa"
"crypto/tls"
"errors"
"net"
"testing"

"github.com/pion/dtls/v2/pkg/crypto/selfsign"
Expand Down Expand Up @@ -47,7 +48,7 @@ func TestValidateConfig(t *testing.T) {
"PSK and Certificate, valid cipher suites": {
config: &Config{
CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
return nil, nil
},
Certificates: []tls.Certificate{cert},
Expand All @@ -56,7 +57,7 @@ func TestValidateConfig(t *testing.T) {
"PSK and Certificate, no PSK cipher suite": {
config: &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
return nil, nil
},
Certificates: []tls.Certificate{cert},
Expand All @@ -66,7 +67,7 @@ func TestValidateConfig(t *testing.T) {
"PSK and Certificate, no non-PSK cipher suite": {
config: &Config{
CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
return nil, nil
},
Certificates: []tls.Certificate{cert},
Expand Down
24 changes: 12 additions & 12 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ func TestPSK(t *testing.T) {
ca, cb := dpipe.Pipe()
go func() {
conf := &Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
if !bytes.Equal(test.ServerIdentity, hint) {
return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) //nolint:goerr113
}
Expand All @@ -557,7 +557,7 @@ func TestPSK(t *testing.T) {
}()

config := &Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
if !bytes.Equal(clientIdentity, hint) {
return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, clientIdentity, hint)
}
Expand Down Expand Up @@ -620,7 +620,7 @@ func TestPSKHintFail(t *testing.T) {
ca, cb := dpipe.Pipe()
go func() {
conf := &Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
return nil, pskRejected
},
PSKIdentityHint: []byte{},
Expand All @@ -632,7 +632,7 @@ func TestPSKHintFail(t *testing.T) {
}()

config := &Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
return nil, pskRejected
},
PSKIdentityHint: []byte{},
Expand Down Expand Up @@ -1556,7 +1556,7 @@ func TestCertificateAndPSKServer(t *testing.T) {
go func() {
config := &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}}
if test.ClientPSK {
config.PSK = func([]byte) ([]byte, error) {
config.PSK = func([]byte, net.Addr) ([]byte, error) {
return []byte{0x00, 0x01, 0x02}, nil
}
config.PSKIdentityHint = []byte{0x00}
Expand All @@ -1569,7 +1569,7 @@ func TestCertificateAndPSKServer(t *testing.T) {

config := &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_PSK_WITH_AES_128_GCM_SHA256},
PSK: func([]byte) ([]byte, error) {
PSK: func([]byte, net.Addr) ([]byte, error) {
return []byte{0x00, 0x01, 0x02}, nil
},
}
Expand Down Expand Up @@ -1614,8 +1614,8 @@ func TestPSKConfiguration(t *testing.T) {
Name: "PSK and no certificate specified",
ClientHasCertificate: false,
ServerHasCertificate: false,
ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ClientPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ClientPSKIdentity: []byte{0x00},
ServerPSKIdentity: []byte{0x00},
WantClientError: errNoAvailablePSKCipherSuite,
Expand All @@ -1625,8 +1625,8 @@ func TestPSKConfiguration(t *testing.T) {
Name: "PSK and certificate specified",
ClientHasCertificate: true,
ServerHasCertificate: true,
ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ClientPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ClientPSKIdentity: []byte{0x00},
ServerPSKIdentity: []byte{0x00},
WantClientError: errNoAvailablePSKCipherSuite,
Expand All @@ -1636,8 +1636,8 @@ func TestPSKConfiguration(t *testing.T) {
Name: "PSK and no identity specified",
ClientHasCertificate: false,
ServerHasCertificate: false,
ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ClientPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte, net.Addr) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ClientPSKIdentity: nil,
ServerPSKIdentity: nil,
WantClientError: errPSKAndIdentityMustBeSetForClient,
Expand Down
2 changes: 1 addition & 1 deletion e2e/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ func testPionE2ESimplePSK(t *testing.T, server, client func(*comm), opts ...dtls
defer cancel()

cfg := &dtls.Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
return []byte{0xAB, 0xC1, 0x23}, nil
},
PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
Expand Down
2 changes: 1 addition & 1 deletion examples/dial/cid/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func main() {

// Prepare the configuration of the DTLS connection
config := &dtls.Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
fmt.Printf("Server's hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x23}, nil
},
Expand Down
32 changes: 31 additions & 1 deletion examples/dial/psk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"context"
"fmt"
"net"
"sync"
"time"

"github.com/pion/dtls/v2"
Expand All @@ -22,10 +23,39 @@
// Everything below is the pion-DTLS API! Thanks for using it ❤️.
//

// *************** Variables only used to implement a basic Brute Force Attack protection ***************
var attempts = make(map[string]int) // Map of attempts for each IP address

Check failure on line 27 in examples/dial/psk/main.go

View workflow job for this annotation

GitHub Actions / lint / Go

File is not `gofumpt`-ed (gofumpt)
var attemptsMutex sync.Mutex // Mutex for the map of attempts
var attemptsCleaner = time.Now() // Time to be able to clean the map of attempts every X minutes

// Prepare the configuration of the DTLS connection
config := &dtls.Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
fmt.Printf("Server's hint: %s \n", hint)
// *************** Brute Force Attack protection ***************
// Check if the IP address is in the map, and the IP address has exceeded the limit
attemptsMutex.Lock()
defer attemptsMutex.Unlock()
// Here I implement a time cleaner for the map of attempts, every 5 minutes I will decrement by 1 the number of attempts for each IP address
if time.Now().After(attemptsCleaner.Add(time.Minute * 5)) {
attemptsCleaner = time.Now()
for k, v := range attempts {
if v > 0 {
attempts[k]--
}
if attempts[k] == 0 {
delete(attempts, k)
}
}
}
// Check if the IP address is in the map, and the IP address has exceeded the limit (Brute Force Attack protection)
if attempts[addr.(*net.UDPAddr).IP.String()] > 5 {

Check failure on line 52 in examples/dial/psk/main.go

View workflow job for this annotation

GitHub Actions / lint / Go

type assertion must be checked (forcetypeassert)
return nil, fmt.Errorf("too many attempts from this IP address")

Check failure on line 53 in examples/dial/psk/main.go

View workflow job for this annotation

GitHub Actions / lint / Go

err113: do not define dynamic errors, use wrapped static errors instead: "fmt.Errorf(\"too many attempts from this IP address\")" (goerr113)
}
// Here I increment the number of attempts for this IP address (Brute Force Attack protection)
attempts[addr.(*net.UDPAddr).IP.String()]++

Check failure on line 56 in examples/dial/psk/main.go

View workflow job for this annotation

GitHub Actions / lint / Go

type assertion must be checked (forcetypeassert)
// *************** END Brute Force Attack protection END ***************
// I return the PSK
return []byte{0xAB, 0xC1, 0x23}, nil
},
PSKIdentityHint: []byte("Pion DTLS Client"),
Expand Down
2 changes: 1 addition & 1 deletion examples/listen/cid/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func main() {

// Prepare the configuration of the DTLS connection
config := &dtls.Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
fmt.Printf("Client's hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x23}, nil
},
Expand Down
2 changes: 1 addition & 1 deletion examples/listen/psk/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func main() {

// Prepare the configuration of the DTLS connection
config := &dtls.Config{
PSK: func(hint []byte) ([]byte, error) {
PSK: func(hint []byte, addr net.Addr) ([]byte, error) {
fmt.Printf("Client's hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x23}, nil
},
Expand Down
4 changes: 2 additions & 2 deletions flight3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,14 @@ func handleResumption(ctx context.Context, c flightConn, state *State, cache *ha
return flight5b, nil, nil
}

func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) {
func handleServerKeyExchange(c flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) {
var err error
if state.cipherSuite == nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite
}
if cfg.localPSKCallback != nil {
var psk []byte
if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil {
if psk, err = cfg.localPSKCallback(h.IdentityHint, c.(*Conn).rAddr); err != nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
state.IdentityHint = h.IdentityHint
Expand Down
2 changes: 1 addition & 1 deletion flight4handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handsh
var preMasterSecret []byte
if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey {
var psk []byte
if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint); err != nil {
if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint, c.(*Conn).rAddr); err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
state.IdentityHint = clientKeyExchange.IdentityHint
Expand Down
Loading