Skip to content

Commit

Permalink
Shadowsocks2022 Client Implementation Improvements (again)
Browse files Browse the repository at this point in the history
1. Replaced Mutex with RWMutex for reduced lock contention
2. Added per server session tracking of decryption cache and anti-replay window
  • Loading branch information
xiaokangwang committed Nov 21, 2023
1 parent bc27c9d commit 0a5e223
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 31 deletions.
126 changes: 107 additions & 19 deletions proxy/shadowsocks2022/client_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ import (

func NewClientUDPSession(ctx context.Context, conn io.ReadWriteCloser, packetProcessor UDPClientPacketProcessor) *ClientUDPSession {
session := &ClientUDPSession{
locker: &sync.Mutex{},
locker: &sync.RWMutex{},
conn: conn,
packetProcessor: packetProcessor,
sessionMap: make(map[string]*ClientUDPSessionConn),
sessionMapAlias: make(map[string]string),
}
session.ctx, session.finish = context.WithCancel(ctx)

Expand All @@ -29,19 +30,21 @@ func NewClientUDPSession(ctx context.Context, conn io.ReadWriteCloser, packetPro
}

type ClientUDPSession struct {
locker *sync.Mutex
locker *sync.RWMutex

conn io.ReadWriteCloser
packetProcessor UDPClientPacketProcessor
sessionMap map[string]*ClientUDPSessionConn

sessionMapAlias map[string]string

ctx context.Context
finish func()
}

func (c *ClientUDPSession) GetCachedState(sessionID string) UDPClientPacketProcessorCachedState {
c.locker.Lock()
defer c.locker.Unlock()
c.locker.RLock()
defer c.locker.RUnlock()

state, ok := c.sessionMap[sessionID]
if !ok {
Expand All @@ -50,9 +53,37 @@ func (c *ClientUDPSession) GetCachedState(sessionID string) UDPClientPacketProce
return state.cachedProcessorState
}

func (c *ClientUDPSession) GetCachedServerState(serverSessionID string) UDPClientPacketProcessorCachedState {
c.locker.RLock()
defer c.locker.RUnlock()

clientSessionID := c.getCachedStateAlias(serverSessionID)
if clientSessionID == "" {
return nil
}
state, ok := c.sessionMap[clientSessionID]
if !ok {
return nil
}

if serverState, ok := state.trackedServerSessionID[serverSessionID]; !ok {
return nil
} else {
return serverState.cachedRecvProcessorState
}
}

func (c *ClientUDPSession) getCachedStateAlias(serverSessionID string) string {
state, ok := c.sessionMapAlias[serverSessionID]
if !ok {
return ""
}
return state
}

func (c *ClientUDPSession) PutCachedState(sessionID string, cache UDPClientPacketProcessorCachedState) {
c.locker.Lock()
defer c.locker.Unlock()
c.locker.RLock()
defer c.locker.RUnlock()

state, ok := c.sessionMap[sessionID]
if !ok {
Expand All @@ -61,6 +92,25 @@ func (c *ClientUDPSession) PutCachedState(sessionID string, cache UDPClientPacke
state.cachedProcessorState = cache
}

func (c *ClientUDPSession) PutCachedServerState(serverSessionID string, cache UDPClientPacketProcessorCachedState) {
c.locker.RLock()
defer c.locker.RUnlock()

clientSessionID := c.getCachedStateAlias(serverSessionID)
if clientSessionID == "" {
return
}
state, ok := c.sessionMap[clientSessionID]
if !ok {
return
}

if serverState, ok := state.trackedServerSessionID[serverSessionID]; ok {
serverState.cachedRecvProcessorState = cache
return
}
}

func (c *ClientUDPSession) Close() error {
c.finish()
return c.conn.Close()
Expand Down Expand Up @@ -107,8 +157,9 @@ func (c *ClientUDPSession) KeepReading() {
}
}

c.locker.Lock()
c.locker.RLock()
session, ok := c.sessionMap[string(udpResp.ClientSessionID[:])]
c.locker.RUnlock()
if ok {
select {
case session.readChan <- udpResp:
Expand All @@ -117,7 +168,6 @@ func (c *ClientUDPSession) KeepReading() {
} else {
newError("misbehaving server: unknown client session ID").Base(err).WriteToLog()
}
c.locker.Unlock()
}
}
}
Expand All @@ -132,27 +182,33 @@ func (c *ClientUDPSession) NewSessionConn() (internet.AbstractPacketConn, error)
connctx, connfinish := context.WithCancel(c.ctx)

sessionConn := &ClientUDPSessionConn{
sessionID: string(sessionID),
readChan: make(chan *UDPResponse, 16),
parent: c,
ctx: connctx,
finish: connfinish,
nextWritePacketID: 0,
rxReplayDetector: replaydetector.New(128, ^uint64(0)),
sessionID: string(sessionID),
readChan: make(chan *UDPResponse, 16),
parent: c,
ctx: connctx,
finish: connfinish,
nextWritePacketID: 0,
trackedServerSessionID: make(map[string]*ClientUDPSessionServerTracker),
}
c.locker.Lock()
c.sessionMap[sessionConn.sessionID] = sessionConn
c.locker.Unlock()
return sessionConn, nil
}

type ClientUDPSessionServerTracker struct {
cachedRecvProcessorState UDPClientPacketProcessorCachedState
rxReplayDetector replaydetector.ReplayDetector
lastSeen time.Time
}

type ClientUDPSessionConn struct {
sessionID string
readChan chan *UDPResponse
parent *ClientUDPSession

nextWritePacketID uint64
rxReplayDetector replaydetector.ReplayDetector
nextWritePacketID uint64
trackedServerSessionID map[string]*ClientUDPSessionServerTracker

cachedProcessorState UDPClientPacketProcessorCachedState

Expand All @@ -161,7 +217,12 @@ type ClientUDPSessionConn struct {
}

func (c *ClientUDPSessionConn) Close() error {
c.parent.locker.Lock()
delete(c.parent.sessionMap, c.sessionID)
for k := range c.trackedServerSessionID {
delete(c.parent.sessionMapAlias, k)
}
c.parent.locker.Unlock()
c.finish()
return nil
}
Expand Down Expand Up @@ -195,14 +256,41 @@ func (c *ClientUDPSessionConn) ReadFrom(p []byte) (n int, addr net.Addr, err err
case resp := <-c.readChan:
n = copy(p, resp.Payload.Bytes())
resp.Payload.Release()
if accept, ok := c.rxReplayDetector.Check(resp.PacketID); ok {

var trackedState *ClientUDPSessionServerTracker
if trackedStateReceived, ok := c.trackedServerSessionID[string(resp.SessionID[:])]; !ok {
expiredServerSessionID := make([]string, 0)
for key, value := range c.trackedServerSessionID {
if time.Since(value.lastSeen) > 125*time.Second {
expiredServerSessionID = append(expiredServerSessionID, key)
}
}
for _, key := range expiredServerSessionID {
delete(c.trackedServerSessionID, key)
}

state := &ClientUDPSessionServerTracker{
rxReplayDetector: replaydetector.New(1024, ^uint64(0)),
}
c.trackedServerSessionID[string(resp.SessionID[:])] = state
c.parent.locker.RLock()
c.parent.sessionMapAlias[string(resp.SessionID[:])] = string(resp.ClientSessionID[:])
c.parent.locker.RUnlock()
trackedState = state
} else {
trackedState = trackedStateReceived
}

if accept, ok := trackedState.rxReplayDetector.Check(resp.PacketID); ok {
accept()
} else {
newError("misbehaving server: replayed packet").Base(err).WriteToLog()
continue
}
trackedState.lastSeen = time.Now()

addr = &net.UDPAddr{IP: resp.Address.IP(), Port: resp.Port}
}
return
return n, addr, nil
}
}
2 changes: 2 additions & 0 deletions proxy/shadowsocks2022/ss2022.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ const (
type UDPClientPacketProcessorCachedStateContainer interface {
GetCachedState(sessionID string) UDPClientPacketProcessorCachedState
PutCachedState(sessionID string, cache UDPClientPacketProcessorCachedState)
GetCachedServerState(serverSessionID string) UDPClientPacketProcessorCachedState
PutCachedServerState(serverSessionID string, cache UDPClientPacketProcessorCachedState)
}

type UDPClientPacketProcessorCachedState interface{}
Expand Down
14 changes: 2 additions & 12 deletions proxy/shadowsocks2022/udp_aes.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,26 +147,16 @@ func (p *AESUDPClientPacketProcessor) DecodeUDPResp(input []byte, resp *UDPRespo
resp.PacketID = separateHeaderStruct.PacketID
resp.SessionID = separateHeaderStruct.SessionID
{
// Since we need to decrypt the main packet to see client session id, we
// have no way to know where should this key be found in the cache indexed
// by client session id.
// Luckily, V2Ray's implementation of shadowsocks2022-aes will always generate
// server session id based on client session id to allow client map it back
// to client session id.
var generatedCacheBytes [8]byte
copy(generatedCacheBytes[:], separateHeaderStruct.SessionID[:])
generatedCacheBytes[7] = ^generatedCacheBytes[7]
cacheKey := string(separateHeaderBuffer.Bytes()[0:8])

receivedCacheInterface := cache.GetCachedState(cacheKey)
receivedCacheInterface := cache.GetCachedServerState(cacheKey)
cachedState := &cachedUDPState{}
if receivedCacheInterface != nil {
cachedState = receivedCacheInterface.(*cachedUDPState)
}

if cachedState.sessionRecvAEAD == nil {
cachedState.sessionRecvAEAD = p.mainPacketAEAD(separateHeaderBuffer.Bytes()[0:8])
cache.PutCachedState(cacheKey, cachedState)
cache.PutCachedServerState(cacheKey, cachedState)
}

mainPacketAEADMaterialized := cachedState.sessionRecvAEAD
Expand Down

0 comments on commit 0a5e223

Please sign in to comment.