Skip to content

Commit

Permalink
feature: Allow RemoteIdentity to be nil
Browse files Browse the repository at this point in the history
Now it's allowed to create a session without bultin validation
of remote party by it's public key. So a custom validation
could be performed within OnConnect() function using method
GetRemoteIdentity().
  • Loading branch information
xaionaro committed Mar 14, 2020
1 parent 7568f8e commit 841a7b2
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 20 deletions.
4 changes: 4 additions & 0 deletions crypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
"github.com/xaionaro-go/slice"
)

var (
emptyIV = make([]byte, 24)
)

func decrypt(key []byte, iv []byte, dst, src []byte) {
slice.SetZeros(dst)
chacha20.XORKeyStream(dst, src, iv, key)
Expand Down
23 changes: 19 additions & 4 deletions identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestNewIdentity(t *testing.T) {
assert.Error(t, err)
}

func testIdentityMutualConfirmationOfIdentityWithPSKs(t *testing.T, shouldFail bool, psk0, psk1 []byte) {
func testIdentityMutualConfirmationOfIdentityWithPSKs(t *testing.T, remoteIsKnown, shouldFail bool, psk0, psk1 []byte) {
identity0, identity1, conn0, conn1 := testPair(t)

defer conn0.Close()
Expand Down Expand Up @@ -97,6 +97,10 @@ func testIdentityMutualConfirmationOfIdentityWithPSKs(t *testing.T, shouldFail b
wg.Add(1)
go func() {
defer wg.Done()
identity1 := identity1
if !remoteIsKnown {
identity1 = nil
}
keys0, err0 = identity0.MutualConfirmationOfIdentity(
ctx,
identity1,
Expand All @@ -111,6 +115,10 @@ func testIdentityMutualConfirmationOfIdentityWithPSKs(t *testing.T, shouldFail b
wg.Add(1)
go func() {
defer wg.Done()
identity0 := identity0
if !remoteIsKnown {
identity0 = nil
}
keys1, err1 = identity1.MutualConfirmationOfIdentity(
ctx,
identity0,
Expand Down Expand Up @@ -139,14 +147,14 @@ func testIdentityMutualConfirmationOfIdentityWithPSKs(t *testing.T, shouldFail b
}

func TestIdentityMutualConfirmationOfIdentityWithoutPSK(t *testing.T) {
testIdentityMutualConfirmationOfIdentityWithPSKs(t, false, nil, nil)
testIdentityMutualConfirmationOfIdentityWithPSKs(t, true, false, nil, nil)
}

func TestIdentityMutualConfirmationOfIdentityWithPSK(t *testing.T) {
psk := make([]byte, 64)
rand.Read(psk)

testIdentityMutualConfirmationOfIdentityWithPSKs(t, false, psk, psk)
testIdentityMutualConfirmationOfIdentityWithPSKs(t, true, false, psk, psk)
}

func TestIdentityMutualConfirmationOfIdentityWithWrongPSK(t *testing.T) {
Expand All @@ -157,5 +165,12 @@ func TestIdentityMutualConfirmationOfIdentityWithWrongPSK(t *testing.T) {
psk0[63] = 0
psk1[63] = 1

testIdentityMutualConfirmationOfIdentityWithPSKs(t, true, psk0, psk1)
testIdentityMutualConfirmationOfIdentityWithPSKs(t, true, true, psk0, psk1)
}

func TestIdentityMutualConfirmationOfIdentityByPSK(t *testing.T) {
psk := make([]byte, 64)
rand.Read(psk)

testIdentityMutualConfirmationOfIdentityWithPSKs(t, false, false, psk, psk)
}
29 changes: 22 additions & 7 deletions key_exchanger.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,11 @@ func (kx *keyExchanger) parseAndCheck(msg *keySeedUpdateMessage, b []byte) (err
if len(b) > keySeedUpdateMessageSignedSize {
kx.messenger.sess.debugf("[kx] ignored the tail of length %v", keySeedUpdateMessageSignedSize-len(b))
}
if err = kx.remoteIdentity.VerifySignature(signature, msgBytes); err != nil {
kx.messenger.sess.debugf("[kx] ignoring the message due to the wrong signature: %v", err)
return
if kx.remoteIdentity != nil {
if err = kx.remoteIdentity.VerifySignature(signature, msgBytes); err != nil {
kx.messenger.sess.debugf("[kx] ignoring the message from %+v due to the wrong signature: %v", kx.remoteIdentity, err)
return
}
}

err = binary.Read(bytes.NewBuffer(msgBytes), binaryOrderType, msg)
Expand All @@ -350,7 +352,7 @@ func (kx *keyExchanger) parseAndCheck(msg *keySeedUpdateMessage, b []byte) (err
}

var zeroKey [curve25519PublicKeySize]byte
if bytes.Compare(msg.PublicKey[:], zeroKey[:]) == 0 {
if bytes.Compare(msg.KXPublicKey[:], zeroKey[:]) == 0 {
err = newErrInvalidPublicKey()
kx.errFunc(err)
return
Expand All @@ -377,6 +379,18 @@ func (kx *keyExchanger) Handle(b []byte) (err error) {
}
defer func() { err = wrapError(err) }()

if kx.remoteIdentity == nil {
kx.messenger.sess.debugf("[kx] setting the remote identity to %+v", msg.KXPublicKey[:])
kx.remoteIdentity, err = NewRemoteIdentityFromPublicKey(msg.IdentityPublicKey[:])
if err != nil {
kx.errFunc(wrapError(err))
return
}
kx.messenger.sess.lockDo(func() {
kx.messenger.sess.remoteIdentity = kx.remoteIdentity
})
}

if kx.remoteSessionID == nil {
kx.remoteSessionID = &msg.SessionID
kx.messenger.sess.setRemoteSessionID(kx.remoteSessionID)
Expand All @@ -385,12 +399,12 @@ func (kx *keyExchanger) Handle(b []byte) (err error) {
nextRemoteHasChanged := true
kx.keyLocker.LockDo(func() {
if kx.nextRemotePublicKey != nil &&
bytes.Compare((*kx.nextRemotePublicKey)[:], msg.PublicKey[:]) == 0 {
bytes.Compare((*kx.nextRemotePublicKey)[:], msg.KXPublicKey[:]) == 0 {
nextRemoteHasChanged = false
return
}
kx.prevRemotePublicKey = kx.nextRemotePublicKey
kx.nextRemotePublicKey = &msg.PublicKey
kx.nextRemotePublicKey = &msg.KXPublicKey
})
if !nextRemoteHasChanged {
//kx.errFunc(newErrRemoteKeyHasNotChanged())
Expand Down Expand Up @@ -644,9 +658,10 @@ func (kx *keyExchanger) sendPublicKey(isAnswer bool) error {
}
kx.messenger.sess.debugf("[kx] kx.sendPublicKey(isAnswer: %v)", isAnswer)
msg := &keySeedUpdateMessage{}
copy(msg.IdentityPublicKey[:], kx.localIdentity.Keys.Public)
msg.SessionID = kx.messenger.sess.id
kx.keyLocker.RLockDo(func() {
copy(msg.PublicKey[:], (*kx.nextLocalPublicKey)[:])
copy(msg.KXPublicKey[:], (*kx.nextLocalPublicKey)[:])
})
msg.Flags.SetIsAnswer(isAnswer)
msg.AnswersMode = kx.options.AnswersMode
Expand Down
9 changes: 5 additions & 4 deletions key_seed_update_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ type keySeedUpdateMessageSigned struct {
}

type keySeedUpdateMessage struct {
SessionID SessionID
PublicKey [curve25519PublicKeySize]byte
AnswersMode KeyExchangeAnswersMode
Flags keySeedUpdateMessageFlags
SessionID SessionID
IdentityPublicKey [PublicKeySize]byte
KXPublicKey [curve25519PublicKeySize]byte
AnswersMode KeyExchangeAnswersMode
Flags keySeedUpdateMessageFlags
}

type keySeedUpdateMessageFlags uint8
Expand Down
19 changes: 14 additions & 5 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ type Session struct {
delayedWriteBufLocker spinlock.Locker
delayedSenderTimer *time.Timer
delayedSenderTimerLocker spinlock.Locker
sendDelayedNowChan chan *SendInfo
sendDelayedCond *sync.Cond
sendDelayedCondLocker sync.Mutex
sendDelayedNowChan chan *SendInfo
sendDelayedCond *sync.Cond
sendDelayedCondLocker sync.Mutex

lastSendInfoSendID uint64

Expand Down Expand Up @@ -962,7 +962,7 @@ func (sess *Session) decryptPacketIDBytes(decrypted *buffer, encrypted []byte) (
}

packetIDBytes = decrypted.Bytes[:len(encrypted)]
decrypt(sess.auxCipherKey, sess.identity.Keys.Public[:ivSize], packetIDBytes, encrypted)
decrypt(sess.auxCipherKey, emptyIV, packetIDBytes, encrypted)
decrypted.Offset += uint(len(encrypted))
sess.debugf("decrypted the PacketID from %v to %v using key %v",
encrypted, packetIDBytes, sess.auxCipherKey)
Expand Down Expand Up @@ -1600,7 +1600,7 @@ func (sess *Session) sendMessages(
if sess.auxCipherKey == nil {
copy(encryptedBytes[:len(containerHdr.PacketID)], containerHdr.PacketID[:]) // copying the plain IV
} else {
encrypt(sess.auxCipherKey, sess.remoteIdentity.Keys.Public[:ivSize], encryptedBytes[:len(containerHdr.PacketID)], containerHdr.PacketID[:])
encrypt(sess.auxCipherKey, emptyIV, encryptedBytes[:len(containerHdr.PacketID)], containerHdr.PacketID[:])
}
sess.ifDebug(func() {
if len(encryptedBytes) >= 200 {
Expand Down Expand Up @@ -1660,6 +1660,15 @@ func (sess *Session) error(err error) {
sess.eventHandler.Error(sess, err)
}

// GetRemoteIdentity returns the remote identity.
// It's not a copy, don't modify the content.
func (sess *Session) GetRemoteIdentity() (result *Identity) {
sess.rLockDo(func() {
result = sess.remoteIdentity
})
return
}

func (sess *Session) startKeyExchange() {
switch sess.setState(SessionStateKeyExchanging, SessionStateClosing, SessionStateClosed) {
case SessionStateKeyExchanging, SessionStateClosing, SessionStateClosed:
Expand Down

0 comments on commit 841a7b2

Please sign in to comment.