Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wait for blockwise requests to complete when connecting to the cloud #2560

Draft
wants to merge 3 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 6 additions & 7 deletions communication/inc/dtls_message_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ namespace particle
namespace protocol
{

class Protocol;

/**
* Please centralize this somewhere else!
*/
Expand Down Expand Up @@ -79,7 +81,6 @@ class DTLSMessageChannel: public BufferMessageChannel<PROTOCOL_BUFFER_SIZE>
int (*restore)(void* data, size_t max_length, uint8_t type, void* reserved);

uint32_t (*calculate_crc)(const uint8_t* data, uint32_t length);
void (*notify_client_messages_processed)(void* reserved);
};

private:
Expand All @@ -92,6 +93,7 @@ class DTLSMessageChannel: public BufferMessageChannel<PROTOCOL_BUFFER_SIZE>
mbedtls_pk_context pkey;
mbedtls_timing_delay_context timer;
Callbacks callbacks;
Protocol* protocol;
uint8_t* server_public;
uint16_t server_public_len;
uint32_t keys_checksum;
Expand Down Expand Up @@ -123,13 +125,14 @@ class DTLSMessageChannel: public BufferMessageChannel<PROTOCOL_BUFFER_SIZE>
void reset_session();

public:
DTLSMessageChannel() :
explicit DTLSMessageChannel(Protocol* protocol) :
ssl_context(),
conf(),
clicert(),
pkey(),
timer(),
callbacks(),
protocol(protocol),
server_public(nullptr),
server_public_len(0),
keys_checksum(0),
Expand Down Expand Up @@ -167,11 +170,7 @@ class DTLSMessageChannel: public BufferMessageChannel<PROTOCOL_BUFFER_SIZE>

virtual ProtocolError command(Command cmd, void* arg=nullptr) override;

virtual void notify_client_messages_processed() override {
if (callbacks.notify_client_messages_processed) {
callbacks.notify_client_messages_processed(nullptr);
}
}
virtual void notify_client_messages_processed() override;

virtual AppStateDescriptor cached_app_state_descriptor() const override;

Expand Down
7 changes: 6 additions & 1 deletion communication/inc/message_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ struct MessageChannel : public Channel
*/
virtual void notify_client_messages_processed()=0;

/**
* Returns `true` if there are client messages being processed at the moment, or `false` otherwise.
*/
virtual bool has_pending_client_messages() const = 0;

/**
* Get a descriptor of the cached application state.
*
Expand All @@ -262,7 +267,7 @@ struct MessageChannel : public Channel
class AbstractMessageChannel : public MessageChannel
{
public:
void set_debug_enabled(bool enabled) override {
void set_debug_enabled(bool /* enabled */) override {
}
};

Expand Down
4 changes: 3 additions & 1 deletion communication/inc/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -624,10 +624,12 @@ class Protocol

virtual int get_describe_data(spark_protocol_describe_data* data, void* reserved);

virtual int get_status(protocol_status* status) const = 0;
int get_status(protocol_status* status) const;

void notify_message_complete(message_id_t msg_id, CoAPCode::Enum responseCode);

virtual void notify_client_messages_processed(); // Declared as virtual for mocking in unit tests

/**
* Retrieves the next token.
*/
Expand Down
26 changes: 19 additions & 7 deletions communication/src/coap_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
#include "service_debug.h"

#include "communication_diagnostic.h"

#include <limits>
#include <utility>

namespace particle
{
Expand All @@ -52,8 +54,10 @@ class CoAPChannel : public T
}

public:
CoAPChannel(message_id_t msg_seed=0) : message_id(msg_seed)
{
template<typename... ArgsT>
explicit CoAPChannel(ArgsT&&... args) :
T(std::forward<ArgsT>(args)...),
message_id(0) {
}

/**
Expand Down Expand Up @@ -554,8 +558,15 @@ class CoAPReliableChannel : public T
DelegateChannel delegateChannel;

public:
template<typename... ArgsT>
explicit CoAPReliableChannel(ArgsT&&... args) :
CoAPReliableChannel(M(), std::forward<ArgsT>(args)...) {
}

CoAPReliableChannel(M m=0) : millis(m) {
template<typename... ArgsT>
explicit CoAPReliableChannel(M m, ArgsT&&... args) :
T(std::forward<ArgsT>(args)...),
millis(m) {
delegateChannel.init(this);
}

Expand Down Expand Up @@ -622,6 +633,11 @@ class CoAPReliableChannel : public T
return receive(msg, true);
}

bool has_pending_client_messages() const override
{
return client.has_messages();
}

/**
* Pulls messages from the message channel
*/
Expand All @@ -637,10 +653,6 @@ class CoAPReliableChannel : public T
return client.has_messages() || server.has_unacknowledged_requests();
}

bool has_unacknowledged_client_requests() const {
return client.has_messages();
}

/**
* Pulls messages from the channel and stores it in a message store for
* reliable receipt and retransmission.
Expand Down
3 changes: 3 additions & 0 deletions communication/src/description.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,9 @@ ProtocolError Description::receiveAckOrRst(const Message& msg, int* descFlags) {
if (!reqQueue_.isEmpty()) {
CHECK_PROTOCOL(sendNextRequest(reqQueue_.takeFirst()));
}
if (!activeReq_.has_value() && reqQueue_.isEmpty()) {
proto_->notify_client_messages_processed();
}
*descFlags = flags;
}
} else {
Expand Down
6 changes: 6 additions & 0 deletions communication/src/description.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class Description {

ProtocolError serialize(Appender* appender, int descFlags);

bool hasPendingClientRequests() const;

void reset();

private:
Expand Down Expand Up @@ -96,6 +98,10 @@ class Description {
system_tick_t millis() const;
};

inline bool Description::hasPendingClientRequests() const {
return activeReq_.has_value() || !reqQueue_.isEmpty();
}

} // namespace protocol

} // namespace particle
4 changes: 4 additions & 0 deletions communication/src/dtls_message_channel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,10 @@ ProtocolError DTLSMessageChannel::command(Command command, void* arg)
return NO_ERROR;
}

void DTLSMessageChannel::notify_client_messages_processed() {
protocol->notify_client_messages_processed();
}

AppStateDescriptor DTLSMessageChannel::cached_app_state_descriptor() const
{
return sessionPersist.app_state_descriptor();
Expand Down
3 changes: 0 additions & 3 deletions communication/src/dtls_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ void DTLSProtocol::init(const char *id,
channelCallbacks.save = callbacks.save;
channelCallbacks.restore = callbacks.restore;
}
if (offsetof(SparkCallbacks, notify_client_messages_processed) + sizeof(SparkCallbacks::notify_client_messages_processed) <= callbacks.size) {
channelCallbacks.notify_client_messages_processed = callbacks.notify_client_messages_processed;
}

// TODO: Ideally, the next token value should be stored in the session data
mbedtls_default_rng(nullptr, &next_token, sizeof(next_token));
Expand Down
16 changes: 5 additions & 11 deletions communication/src/dtls_protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ class DTLSProtocol : public Protocol

// todo - this a duplicate of LightSSLProtocol - factor out

DTLSProtocol() : Protocol(channel) {}
DTLSProtocol() :
Protocol(channel),
channel(this),
device_id() {
}

void init(const char *id,
const SparkKeys &keys,
Expand Down Expand Up @@ -120,16 +124,6 @@ class DTLSProtocol : public Protocol
}
}

int get_status(protocol_status* status) const override {
SPARK_ASSERT(status);
status->flags = 0;
if (channel.has_unacknowledged_client_requests()) {
status->flags |= PROTOCOL_STATUS_HAS_PENDING_CLIENT_MESSAGES;
}
return NO_ERROR;
}


/**
* Ensures that all outstanding sent coap messages have been acknowledged.
*/
Expand Down
16 changes: 16 additions & 0 deletions communication/src/protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,22 @@ int Protocol::get_describe_data(spark_protocol_describe_data* data, void* reserv
return 0;
}

int Protocol::get_status(protocol_status* status) const {
SPARK_ASSERT(status);
status->flags = 0;
if (channel.has_pending_client_messages() || description.hasPendingClientRequests()) {
status->flags |= PROTOCOL_STATUS_HAS_PENDING_CLIENT_MESSAGES;
}
return ProtocolError::NO_ERROR;
}

void Protocol::notify_client_messages_processed() {
if (callbacks.notify_client_messages_processed && !channel.has_pending_client_messages() &&
!description.hasPendingClientRequests()) { // Ensure there's no pending blockwise requests
callbacks.notify_client_messages_processed(nullptr /* reserved */);
}
}

size_t Protocol::get_max_transmit_message_size() const
{
if (!max_transmit_message_size) {
Expand Down
6 changes: 3 additions & 3 deletions test/unit_tests/communication/coap_reliability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,7 @@ SCENARIO("notify_client_messages_processed() is invoked when all client messages
THEN("the callback is invoked only once")
{
Verify(Method(channelMock, notify_client_messages_processed)).Once();
REQUIRE_FALSE(channel.has_unacknowledged_client_requests());
REQUIRE_FALSE(channel.has_pending_client_messages());
}
}

Expand All @@ -1229,7 +1229,7 @@ SCENARIO("notify_client_messages_processed() is invoked when all client messages
THEN("the callback is not invoked")
{
Verify(Method(channelMock, notify_client_messages_processed)).Never();
REQUIRE(channel.has_unacknowledged_client_requests());
REQUIRE(channel.has_pending_client_messages());
}
}

Expand All @@ -1247,7 +1247,7 @@ SCENARIO("notify_client_messages_processed() is invoked when all client messages
THEN("the callback is invoked only once")
{
Verify(Method(channelMock, notify_client_messages_processed)).Once();
REQUIRE_FALSE(channel.has_unacknowledged_client_requests());
REQUIRE_FALSE(channel.has_pending_client_messages());
}
}
}
Expand Down
35 changes: 35 additions & 0 deletions test/unit_tests/communication/description.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,4 +541,39 @@ TEST_CASE("Description") {
CHECK(m.option(CoapOption::BLOCK2).toUInt() == BlockOption().index(0).more(true));
d.sendMessage(CoapMessage().type(CoapType::ACK).code(CoapCode::EMPTY).id(m.id()));
}

SECTION("notifies the protocol layer when all client requests have been processed") {
Mock<Protocol> proto(*d.protocol());
When(Method(proto, notify_client_messages_processed)).AlwaysReturn();
When(Method(cb, appendSystemInfo)).Do([](appender_fn append, void* arg, void* reserved) {
auto s = std::string(PROTOCOL_BUFFER_SIZE, 'a');
append(arg, (const uint8_t*)s.data(), s.size());
return true;
});
When(Method(cb, appendAppInfo)).Do([](appender_fn append, void* arg, void* reserved) {
auto s = std::string(BLOCK_SIZE, 'b');
append(arg, (const uint8_t*)s.data(), s.size());
return true;
});
CHECK(!d.get()->hasPendingClientRequests());
// Send a blockwise request to the server
d.get()->sendRequest(DescriptionType::DESCRIBE_SYSTEM);
CHECK(d.get()->hasPendingClientRequests());
// Receive and acknowledge the first block
auto m = d.receiveMessage();
d.sendMessage(CoapMessage().type(CoapType::ACK).code(CoapCode::EMPTY).id(m.id()));
// Send another request to the server (a regular one)
d.get()->sendRequest(DescriptionType::DESCRIBE_APPLICATION);
// Receive the second block of the first request
m = d.receiveMessage();
CHECK(d.get()->hasPendingClientRequests());
Verify(Method(proto, notify_client_messages_processed)).Never();
// Acknowledge the second block
d.sendMessage(CoapMessage().type(CoapType::ACK).code(CoapCode::EMPTY).id(m.id()));
CHECK(!d.get()->hasPendingClientRequests());
Verify(Method(proto, notify_client_messages_processed)).Once();
// Receive and acknowledge the second request
m = d.receiveMessage();
d.sendMessage(CoapMessage().type(CoapType::ACK).code(CoapCode::EMPTY).id(m.id()));
}
}
5 changes: 5 additions & 0 deletions test/unit_tests/communication/forward_message_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ class ForwardMessageChannel : public MessageChannel
channel->notify_client_messages_processed();
}

virtual bool has_pending_client_messages() const override
{
return channel->has_pending_client_messages();
}

virtual AppStateDescriptor cached_app_state_descriptor() const override
{
return AppStateDescriptor();
Expand Down
7 changes: 6 additions & 1 deletion test/unit_tests/communication/util/coap_message_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,15 @@ class CoapMessageChannel: public BufferMessageChannel<PROTOCOL_BUFFER_SIZE> {
// Returns true if there's a message received from the device
bool hasMessages() const;

// Reimplemented from AbstractMessageChannel
// Reimplemented from MessageChannel
ProtocolError send(Message& msg) override;
ProtocolError receive(Message& msg) override;
ProtocolError command(Command cmd, void* arg) override;
bool is_unreliable() override;
ProtocolError establish() override;
ProtocolError notify_established() override;
void notify_client_messages_processed() override;
bool has_pending_client_messages() const override;
AppStateDescriptor cached_app_state_descriptor() const override;
void reset() override;

Expand Down Expand Up @@ -112,6 +113,10 @@ inline ProtocolError CoapMessageChannel::notify_established() {
inline void CoapMessageChannel::notify_client_messages_processed() {
}

inline bool CoapMessageChannel::has_pending_client_messages() const {
return false;
}

inline AppStateDescriptor CoapMessageChannel::cached_app_state_descriptor() const {
return AppStateDescriptor();
}
Expand Down
5 changes: 0 additions & 5 deletions test/unit_tests/communication/util/protocol_stub.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class ProtocolStub: public Protocol {
void init(const char* id, const SparkKeys& keys, const SparkCallbacks& cb, const SparkDescriptor& desc) override;
int command(ProtocolCommands::Enum cmd, uint32_t val, const void* data) override;
size_t build_hello(Message& msg, uint16_t flags) override;
int get_status(protocol_status* status) const override;

private:
DescriptorCallbacks desc_;
Expand Down Expand Up @@ -72,10 +71,6 @@ inline size_t ProtocolStub::build_hello(Message& msg, uint16_t flags) {
return 0;
}

inline int ProtocolStub::get_status(protocol_status* status) const {
return 0;
}

} // namespace test

} // namespace protocol
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,5 @@ test(06_register_many_functions) {
}
Particle.connect();
waitUntil(Particle.connected);
delay(6000); // Give the system some time to send a blockwise Describe message
delay(3000);
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,5 +133,5 @@ test(07_register_many_variables) {
}
Particle.connect();
waitUntil(Particle.connected);
delay(6000); // Give the system some time to send a blockwise Describe message
delay(3000);
}