diff --git a/dtls/client.go b/dtls/client.go index 96ac267d..4fec20de 100644 --- a/dtls/client.go +++ b/dtls/client.go @@ -90,8 +90,8 @@ func Client(conn *dtls.Conn, opts ...udp.Option) *udpClient.Conn { v, cfg.BlockwiseTransferTimeout, cfg.Errors, - func(token message.Token) (*pool.Message, bool) { - return v.GetObservationRequest(token) + func(hash uint64) (*pool.Message, bool) { + return v.GetObservationRequest(hash) }, ) } diff --git a/dtls/server/server.go b/dtls/server/server.go index 1f879a67..56fb6beb 100644 --- a/dtls/server/server.go +++ b/dtls/server/server.go @@ -198,8 +198,8 @@ func (s *Server) createConn(connection *coapNet.Conn, inactivityMonitor udpClien v, s.cfg.BlockwiseTransferTimeout, s.cfg.Errors, - func(token message.Token) (*pool.Message, bool) { - return v.GetObservationRequest(token) + func(hash uint64) (*pool.Message, bool) { + return v.GetObservationRequest(hash) }, ) } diff --git a/message/option.go b/message/option.go index d8b8dba8..c31e5c2c 100644 --- a/message/option.go +++ b/message/option.go @@ -47,6 +47,7 @@ type OptionID uint16 | 35 | x | x | - | | Proxy-Uri | string | 1-1034 | (none) | | 39 | x | x | - | | Proxy-Scheme | string | 1-255 | (none) | | 60 | | | x | | Size1 | uint | 0-4 | (none) | + | 292 | | | | x | Request-Tag | opaque | 0-8 | (none) | +-----+----+---+---+---+----------------+--------+--------+---------+ C=Critical, U=Unsafe, N=NoCacheKey, R=Repeatable */ @@ -73,6 +74,7 @@ const ( ProxyScheme OptionID = 39 Size1 OptionID = 60 NoResponse OptionID = 258 + RequestTag OptionID = 292 ) var optionIDToString = map[OptionID]string{ @@ -96,6 +98,7 @@ var optionIDToString = map[OptionID]string{ ProxyScheme: "ProxyScheme", Size1: "Size1", NoResponse: "NoResponse", + RequestTag: "RequestTag", } func (o OptionID) String() string { @@ -153,6 +156,7 @@ var CoapOptionDefs = map[OptionID]OptionDef{ ProxyScheme: {ValueFormat: ValueString, MinLen: 1, MaxLen: 255}, Size1: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 4}, NoResponse: {ValueFormat: ValueUint, MinLen: 0, MaxLen: 1}, + RequestTag: {ValueFormat: ValueOpaque, MinLen: 0, MaxLen: 8}, } // MediaType specifies the content format of a message. diff --git a/net/blockwise/blockwise.go b/net/blockwise/blockwise.go index 3c795e33..222e8048 100644 --- a/net/blockwise/blockwise.go +++ b/net/blockwise/blockwise.go @@ -5,7 +5,9 @@ import ( "context" "errors" "fmt" + "hash/crc64" "io" + "net" "time" "github.com/dsnet/golib/memfile" @@ -131,6 +133,9 @@ type Client interface { AcquireMessage(ctx context.Context) *pool.Message // return back the message to the pool for next use ReleaseMessage(m *pool.Message) + + // The remote address for determining the endpoint pair + RemoteAddr() net.Addr } type BlockWise[C Client] struct { @@ -138,7 +143,7 @@ type BlockWise[C Client] struct { receivingMessagesCache *cache.Cache[uint64, *messageGuard] sendingMessagesCache *cache.Cache[uint64, *pool.Message] errors func(error) - getSentRequestFromOutside func(token message.Token) (*pool.Message, bool) + getSentRequestFromOutside func(hash uint64) (*pool.Message, bool) expiration time.Duration } @@ -160,10 +165,10 @@ func New[C Client]( cc C, expiration time.Duration, errors func(error), - getSentRequestFromOutside func(token message.Token) (*pool.Message, bool), + getSentRequestFromOutside func(hash uint64) (*pool.Message, bool), ) *BlockWise[C] { if getSentRequestFromOutside == nil { - getSentRequestFromOutside = func(message.Token) (*pool.Message, bool) { return nil, false } + getSentRequestFromOutside = func(uint64) (*pool.Message, bool) { return nil, false } } return &BlockWise[C]{ cc: cc, @@ -214,11 +219,12 @@ func (b *BlockWise[C]) Do(r *pool.Message, maxSzx SZX, maxMessageSize uint32, do if !ok { expire = time.Now().Add(b.expiration) } - _, loaded := b.sendingMessagesCache.LoadOrStore(r.Token().Hash(), cache.NewElement(r, expire, nil)) + matchableHash := generateMatchableHash(r.Options(), b.cc.RemoteAddr(), r.Code()) + _, loaded := b.sendingMessagesCache.LoadOrStore(matchableHash, cache.NewElement(r, expire, nil)) if loaded { return nil, errors.New("invalid token") } - defer b.sendingMessagesCache.Delete(r.Token().Hash()) + defer b.sendingMessagesCache.Delete(matchableHash) if r.Body() == nil { return do(r) } @@ -282,9 +288,9 @@ func (b *BlockWise[C]) WriteMessage(request *pool.Message, maxSZX SZX, maxMessag if err != nil { return fmt.Errorf("cannot encode start sending message block option(%v,%v,%v): %w", maxSZX, 0, true, err) } - + matchableHash := generateMatchableHash(request.Options(), b.cc.RemoteAddr(), request.Code()) w := newWriteRequestResponse(b.cc, request) - err = b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock) + err = b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock, matchableHash) if err != nil { return fmt.Errorf("cannot start writing request: %w", err) } @@ -333,8 +339,8 @@ func wantsToBeReceived(r *pool.Message) bool { return true } -func (b *BlockWise[C]) getSendingMessageCode(token uint64) (codes.Code, bool) { - v := b.sendingMessagesCache.Load(token) +func (b *BlockWise[C]) getSendingMessageCode(hash uint64) (codes.Code, bool) { + v := b.sendingMessagesCache.Load(hash) if v == nil { return codes.Empty, false } @@ -348,19 +354,20 @@ func (b *BlockWise[C]) Handle(w *responsewriter.ResponseWriter[C], r *pool.Messa } token := r.Token() + matchableHash := generateMatchableHash(r.Options(), w.Conn().RemoteAddr(), r.Code()) + if len(token) == 0 { - err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next) + err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next, matchableHash) if err != nil { b.sendEntityIncomplete(w, token) b.errors(fmt.Errorf("handleReceivedMessage(%v): %w", r, err)) } return } - tokenStr := token.Hash() - sendingMessageCode, sendingMessageExist := b.getSendingMessageCode(tokenStr) + sendingMessageCode, sendingMessageExist := b.getSendingMessageCode(matchableHash) if !sendingMessageExist || wantsToBeReceived(r) { - err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next) + err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next, matchableHash) if err != nil { b.sendEntityIncomplete(w, token) b.errors(fmt.Errorf("handleReceivedMessage(%v): %w", r, err)) @@ -369,17 +376,17 @@ func (b *BlockWise[C]) Handle(w *responsewriter.ResponseWriter[C], r *pool.Messa } more, err := b.continueSendingMessage(w, r, maxSZX, maxMessageSize, sendingMessageCode) if err != nil { - b.sendingMessagesCache.Delete(tokenStr) + b.sendingMessagesCache.Delete(matchableHash) b.errors(fmt.Errorf("continueSendingMessage(%v): %w", r, err)) return } // For codes GET,POST,PUT,DELETE, we want them to wait for pairing response and then delete them when the full response comes in or when timeout occurs. if !more && sendingMessageCode > codes.DELETE { - b.sendingMessagesCache.Delete(tokenStr) + b.sendingMessagesCache.Delete(matchableHash) } } -func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) error { +func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, next func(w *responsewriter.ResponseWriter[C], r *pool.Message), rxHash uint64) error { startSendingMessageBlock, err := EncodeBlockOption(maxSZX, 0, true) if err != nil { return fmt.Errorf("cannot encode start sending message block option(%v,%v,%v): %w", maxSZX, 0, true, err) @@ -411,7 +418,7 @@ func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C] return errP } } - return b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock) + return b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock, rxHash) } func (b *BlockWise[C]) createSendingMessage(sendingMessage *pool.Message, maxSZX SZX, maxMessageSize uint32, block uint32) (sendMessage *pool.Message, more bool, err error) { @@ -504,7 +511,8 @@ func (b *BlockWise[C]) continueSendingMessage(w *responsewriter.ResponseWriter[C } var sendMessage *pool.Message var more bool - b.sendingMessagesCache.LoadWithFunc(r.Token().Hash(), func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] { + matchableHash := generateMatchableHash(r.Options(), w.Conn().RemoteAddr(), r.Code()) + b.sendingMessagesCache.LoadWithFunc(matchableHash, func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] { sendMessage, more, err = b.createSendingMessage(value.Data(), maxSZX, maxMessageSize, block) if err != nil { err = fmt.Errorf("cannot create sending message: %w", err) @@ -529,7 +537,7 @@ func isObserveResponse(msg *pool.Message) bool { return msg.Code() >= codes.Created } -func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], maxSZX SZX, maxMessageSize uint32, block uint32) error { +func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], maxSZX SZX, maxMessageSize uint32, block uint32, rxHash uint64) error { payloadSize, err := w.Message().BodySize() if err != nil { return payloadSizeError(err) @@ -552,7 +560,7 @@ func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], if !ok { expire = time.Now().Add(b.expiration) } - el, loaded := b.sendingMessagesCache.LoadOrStore(sendingMessage.Token().Hash(), cache.NewElement(originalSendingMessage, expire, nil)) + el, loaded := b.sendingMessagesCache.LoadOrStore(rxHash, cache.NewElement(originalSendingMessage, expire, nil)) if loaded { defer b.cc.ReleaseMessage(originalSendingMessage) return fmt.Errorf("cannot add message (%v) to sending message cache: message(%v) with token(%v) already exist", originalSendingMessage, el.Data(), sendingMessage.Token()) @@ -560,8 +568,8 @@ func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], return nil } -func (b *BlockWise[C]) getSentRequest(token message.Token) *pool.Message { - data, ok := b.sendingMessagesCache.LoadWithFunc(token.Hash(), func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] { +func (b *BlockWise[C]) getSentRequest(hash uint64) *pool.Message { + data, ok := b.sendingMessagesCache.LoadWithFunc(hash, func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] { if value == nil { return nil } @@ -576,7 +584,7 @@ func (b *BlockWise[C]) getSentRequest(token message.Token) *pool.Message { if ok { return data.Data() } - globalRequest, ok := b.getSentRequestFromOutside(token) + globalRequest, ok := b.getSentRequestFromOutside(hash) if ok { return globalRequest } @@ -595,7 +603,8 @@ func (b *BlockWise[C]) handleObserveResponse(sentRequest *pool.Message) (message validUntil := time.Now().Add(b.expiration) // context of observation can be expired. bwSentRequest := b.cloneMessage(sentRequest) bwSentRequest.SetToken(token) - _, loaded := b.sendingMessagesCache.LoadOrStore(token.Hash(), cache.NewElement(bwSentRequest, validUntil, nil)) + matchableHash := generateMatchableHash(sentRequest.Options(), b.cc.RemoteAddr(), sentRequest.Code()) + _, loaded := b.sendingMessagesCache.LoadOrStore(matchableHash, cache.NewElement(bwSentRequest, validUntil, nil)) if loaded { return nil, time.Time{}, errors.New("cannot process message: message with token already exist") } @@ -674,7 +683,7 @@ func copyToPayloadFromOffset(r *pool.Message, payloadFile *memfile.File, offset return payloadSize, nil } -func (b *BlockWise[C]) getCachedReceivedMessage(mg *messageGuard, r *pool.Message, tokenStr uint64, validUntil time.Time) (*pool.Message, func(), error) { +func (b *BlockWise[C]) getCachedReceivedMessage(mg *messageGuard, r *pool.Message, matchableHash uint64, validUntil time.Time) (*pool.Message, func(), error) { cannotLockError := func(err error) error { return fmt.Errorf("processReceivedMessage: cannot lock message: %w", err) } @@ -708,11 +717,11 @@ func (b *BlockWise[C]) getCachedReceivedMessage(mg *messageGuard, r *pool.Messag return nil, nil, cannotLockError(errA) } appendToClose(mg) - element, loaded := b.receivingMessagesCache.LoadOrStore(tokenStr, cache.NewElement(mg, validUntil, func(d *messageGuard) { + element, loaded := b.receivingMessagesCache.LoadOrStore(matchableHash, cache.NewElement(mg, validUntil, func(d *messageGuard) { if d == nil { return } - b.sendingMessagesCache.Delete(tokenStr) + b.sendingMessagesCache.Delete(matchableHash) })) // request was already stored in cache, silently if loaded { @@ -732,6 +741,38 @@ func (b *BlockWise[C]) getCachedReceivedMessage(mg *messageGuard, r *pool.Messag return mg.Message, closeFn, nil } +/* +RFC9175 1.1: +Two request messages are said to be "matchable" if they occur between +the same endpoint pair, have the same code, and have the same set of +options, with the exception that elective NoCacheKey options and +options involved in block-wise transfer (Block1, Block2, and Request- +Tag) need not be the same. Two blockwise request operations are said +to be matchable if their request messages are matchable. + +This function concatenates the IDs and values of relevant options, the string representation of the remote address, +and the code of the message to generate a hash that can be used to match requests. +*/ +func generateMatchableHash(options message.Options, remoteAddr net.Addr, code codes.Code) uint64 { + input := make([]byte, 0, 512) + + for _, opt := range options { + switch opt.ID { + // Skip Blockwise Options and NoCacheKey Options + case message.Block1, message.Block2, message.Size1, message.Size2, message.RequestTag: + continue + } + input = append(input, byte(opt.ID)) + input = append(input, opt.Value...) + } + + input = append(input, []byte(remoteAddr.Network())...) + input = append(input, []byte(remoteAddr.String())...) + input = append(input, byte(code)) + + return crc64.Checksum(input, crc64.MakeTable(crc64.ISO)) +} + //nolint:gocyclo,gocognit func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSzx SZX, next func(w *responsewriter.ResponseWriter[C], r *pool.Message), blockType message.OptionID, sizeType message.OptionID) error { token := r.Token() @@ -755,7 +796,8 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C if err != nil { return fmt.Errorf("cannot decode block option: %w", err) } - sentRequest := b.getSentRequest(token) + matchableHash := generateMatchableHash(r.Options(), w.Conn().RemoteAddr(), r.Code()) + sentRequest := b.getSentRequest(matchableHash) if sentRequest != nil { defer b.cc.ReleaseMessage(sentRequest) } @@ -770,9 +812,8 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C } } - tokenStr := token.Hash() var cachedReceivedMessageGuard *messageGuard - if e := b.receivingMessagesCache.Load(tokenStr); e != nil { + if e := b.receivingMessagesCache.Load(matchableHash); e != nil { cachedReceivedMessageGuard = e.Data() } if cachedReceivedMessageGuard == nil { @@ -783,7 +824,7 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C return nil } } - cachedReceivedMessage, closeCachedReceivedMessage, err := b.getCachedReceivedMessage(cachedReceivedMessageGuard, r, tokenStr, validUntil) + cachedReceivedMessage, closeCachedReceivedMessage, err := b.getCachedReceivedMessage(cachedReceivedMessageGuard, r, matchableHash, validUntil) if err != nil { return err } @@ -791,7 +832,7 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C defer func(err *error) { if *err != nil { - b.receivingMessagesCache.Delete(tokenStr) + b.receivingMessagesCache.Delete(matchableHash) } }(&err) payloadFile, payloadSize, err := b.getPayloadFromCachedReceivedMessage(r, cachedReceivedMessage) @@ -805,12 +846,12 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C return fmt.Errorf("cannot copy data to payload: %w", err) } if !more { - b.receivingMessagesCache.Delete(tokenStr) + b.receivingMessagesCache.Delete(matchableHash) cachedReceivedMessage.Remove(blockType) cachedReceivedMessage.Remove(sizeType) cachedReceivedMessage.SetType(r.Type()) if !bytes.Equal(cachedReceivedMessage.Token(), token) { - b.sendingMessagesCache.Delete(tokenStr) + b.sendingMessagesCache.Delete(matchableHash) } _, errS := cachedReceivedMessage.Body().Seek(0, io.SeekStart) if errS != nil { diff --git a/net/blockwise/blockwise_test.go b/net/blockwise/blockwise_test.go index 70d173ff..39e7fb0b 100644 --- a/net/blockwise/blockwise_test.go +++ b/net/blockwise/blockwise_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "net" "testing" "time" @@ -54,6 +55,10 @@ type testClient struct { p *pool.Pool } +func (c *testClient) RemoteAddr() net.Addr { + return &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} +} + func newTestClient() *testClient { return &testClient{ p: pool.New(100, 1024), diff --git a/net/client/client.go b/net/client/client.go index b2cdca9e..836fb3aa 100644 --- a/net/client/client.go +++ b/net/client/client.go @@ -105,8 +105,8 @@ func (c *Client[C]) Observe(ctx context.Context, path string, observeFunc func(r return c.DoObserve(req, observeFunc) } -func (c *Client[C]) GetObservationRequest(token message.Token) (*pool.Message, bool) { - return c.observationHandler.GetObservationRequest(token) +func (c *Client[C]) GetObservationRequest(hash uint64) (*pool.Message, bool) { + return c.observationHandler.GetObservationRequest(hash) } // NewPostRequest creates post request. diff --git a/net/observation/handler.go b/net/observation/handler.go index 0b539e90..734893f6 100644 --- a/net/observation/handler.go +++ b/net/observation/handler.go @@ -108,8 +108,8 @@ func (h *Handler[C]) GetObservation(key uint64) (*Observation[C], bool) { } // GetObservationRequest returns observation request for token -func (h *Handler[C]) GetObservationRequest(token message.Token) (*pool.Message, bool) { - obs, ok := h.GetObservation(token.Hash()) +func (h *Handler[C]) GetObservationRequest(hash uint64) (*pool.Message, bool) { + obs, ok := h.GetObservation(hash) if !ok { return nil, false } diff --git a/tcp/client.go b/tcp/client.go index 54349433..c0374e22 100644 --- a/tcp/client.go +++ b/tcp/client.go @@ -6,7 +6,6 @@ import ( "net" "time" - "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/pool" coapNet "github.com/plgd-dev/go-coap/v3/net" "github.com/plgd-dev/go-coap/v3/net/blockwise" @@ -79,8 +78,8 @@ func Client(conn net.Conn, opts ...Option) *client.Conn { v, cfg.BlockwiseTransferTimeout, cfg.Errors, - func(token message.Token) (*pool.Message, bool) { - return v.GetObservationRequest(token) + func(hash uint64) (*pool.Message, bool) { + return v.GetObservationRequest(hash) }, ) } diff --git a/tcp/server/server.go b/tcp/server/server.go index 08f6e079..3ce92a4a 100644 --- a/tcp/server/server.go +++ b/tcp/server/server.go @@ -193,7 +193,7 @@ func (s *Server) createConn(connection *coapNet.Conn, inactivityMonitor client.I cc, s.cfg.BlockwiseTransferTimeout, s.cfg.Errors, - func(message.Token) (*pool.Message, bool) { + func(uint64) (*pool.Message, bool) { return nil, false }, ) diff --git a/udp/client.go b/udp/client.go index 4a68b358..9a393294 100644 --- a/udp/client.go +++ b/udp/client.go @@ -6,7 +6,6 @@ import ( "net" "time" - "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/pool" coapNet "github.com/plgd-dev/go-coap/v3/net" "github.com/plgd-dev/go-coap/v3/net/blockwise" @@ -78,8 +77,8 @@ func Client(conn *net.UDPConn, opts ...Option) *client.Conn { v, cfg.BlockwiseTransferTimeout, cfg.Errors, - func(token message.Token) (*pool.Message, bool) { - return v.GetObservationRequest(token) + func(hash uint64) (*pool.Message, bool) { + return v.GetObservationRequest(hash) }, ) } diff --git a/udp/server/server.go b/udp/server/server.go index beffe0ea..d1773f99 100644 --- a/udp/server/server.go +++ b/udp/server/server.go @@ -278,12 +278,12 @@ func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) ( v, s.cfg.BlockwiseTransferTimeout, s.cfg.Errors, - func(token message.Token) (*pool.Message, bool) { - msg, ok := v.GetObservationRequest(token) + func(hash uint64) (*pool.Message, bool) { + msg, ok := v.GetObservationRequest(hash) if ok { return msg, ok } - return s.multicastRequests.LoadWithFunc(token.Hash(), func(m *pool.Message) *pool.Message { + return s.multicastRequests.LoadWithFunc(hash, func(m *pool.Message) *pool.Message { msg := v.AcquireMessage(m.Context()) msg.ResetOptionsTo(m.Options()) msg.SetCode(m.Code())