Skip to content

Commit 93594e4

Browse files
committed
refactor: merge transport serializer and deserializer into Transport class
This allows state that is shared between both directions to be encapsulated into a single object. Specifically the v2 transport protocol introduced by BIP324 has sending state (the encryption keys) that depends on received messages (the DH key exchange). Having a single object for both means it can hide logic from callers related to that key exchange and other interactions.
1 parent 23f3f40 commit 93594e4

File tree

4 files changed

+37
-42
lines changed

4 files changed

+37
-42
lines changed

Diff for: src/net.cpp

+10-11
Original file line numberDiff line numberDiff line change
@@ -681,16 +681,16 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete)
681681
nRecvBytes += msg_bytes.size();
682682
while (msg_bytes.size() > 0) {
683683
// absorb network data
684-
int handled = m_deserializer->Read(msg_bytes);
684+
int handled = m_transport->Read(msg_bytes);
685685
if (handled < 0) {
686686
// Serious header problem, disconnect from the peer.
687687
return false;
688688
}
689689

690-
if (m_deserializer->Complete()) {
690+
if (m_transport->Complete()) {
691691
// decompose a transport agnostic CNetMessage from the deserializer
692692
bool reject_message{false};
693-
CNetMessage msg = m_deserializer->GetMessage(time, reject_message);
693+
CNetMessage msg = m_transport->GetMessage(time, reject_message);
694694
if (reject_message) {
695695
// Message deserialization failed. Drop the message but don't disconnect the peer.
696696
// store the size of the corrupt message
@@ -717,7 +717,7 @@ bool CNode::ReceiveMsgBytes(Span<const uint8_t> msg_bytes, bool& complete)
717717
return true;
718718
}
719719

720-
int V1TransportDeserializer::readHeader(Span<const uint8_t> msg_bytes)
720+
int V1Transport::readHeader(Span<const uint8_t> msg_bytes)
721721
{
722722
// copy data to temporary parsing buffer
723723
unsigned int nRemaining = CMessageHeader::HEADER_SIZE - nHdrPos;
@@ -757,7 +757,7 @@ int V1TransportDeserializer::readHeader(Span<const uint8_t> msg_bytes)
757757
return nCopy;
758758
}
759759

760-
int V1TransportDeserializer::readData(Span<const uint8_t> msg_bytes)
760+
int V1Transport::readData(Span<const uint8_t> msg_bytes)
761761
{
762762
unsigned int nRemaining = hdr.nMessageSize - nDataPos;
763763
unsigned int nCopy = std::min<unsigned int>(nRemaining, msg_bytes.size());
@@ -774,15 +774,15 @@ int V1TransportDeserializer::readData(Span<const uint8_t> msg_bytes)
774774
return nCopy;
775775
}
776776

777-
const uint256& V1TransportDeserializer::GetMessageHash() const
777+
const uint256& V1Transport::GetMessageHash() const
778778
{
779779
assert(Complete());
780780
if (data_hash.IsNull())
781781
hasher.Finalize(data_hash);
782782
return data_hash;
783783
}
784784

785-
CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, bool& reject_message)
785+
CNetMessage V1Transport::GetMessage(const std::chrono::microseconds time, bool& reject_message)
786786
{
787787
// Initialize out parameter
788788
reject_message = false;
@@ -819,7 +819,7 @@ CNetMessage V1TransportDeserializer::GetMessage(const std::chrono::microseconds
819819
return msg;
820820
}
821821

822-
void V1TransportSerializer::prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const
822+
void V1Transport::prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const
823823
{
824824
// create dbl-sha256 checksum
825825
uint256 hash = Hash(msg.data);
@@ -2822,8 +2822,7 @@ CNode::CNode(NodeId idIn,
28222822
ConnectionType conn_type_in,
28232823
bool inbound_onion,
28242824
CNodeOptions&& node_opts)
2825-
: m_deserializer{std::make_unique<V1TransportDeserializer>(V1TransportDeserializer(Params(), idIn, SER_NETWORK, INIT_PROTO_VERSION))},
2826-
m_serializer{std::make_unique<V1TransportSerializer>(V1TransportSerializer())},
2825+
: m_transport{std::make_unique<V1Transport>(Params(), idIn, SER_NETWORK, INIT_PROTO_VERSION)},
28272826
m_permission_flags{node_opts.permission_flags},
28282827
m_sock{sock},
28292828
m_connected{GetTime<std::chrono::seconds>()},
@@ -2908,7 +2907,7 @@ void CConnman::PushMessage(CNode* pnode, CSerializedNetMsg&& msg)
29082907

29092908
// make sure we use the appropriate network transport format
29102909
std::vector<unsigned char> serializedHeader;
2911-
pnode->m_serializer->prepareForTransport(msg, serializedHeader);
2910+
pnode->m_transport->prepareForTransport(msg, serializedHeader);
29122911
size_t nTotalSize = nMessageSize + serializedHeader.size();
29132912

29142913
size_t nBytesSent = 0;

Diff for: src/net.h

+18-23
Original file line numberDiff line numberDiff line change
@@ -253,24 +253,31 @@ class CNetMessage {
253253
}
254254
};
255255

256-
/** The TransportDeserializer takes care of holding and deserializing the
257-
* network receive buffer. It can deserialize the network buffer into a
258-
* transport protocol agnostic CNetMessage (message type & payload)
259-
*/
260-
class TransportDeserializer {
256+
/** The Transport converts one connection's sent messages to wire bytes, and received bytes back. */
257+
class Transport {
261258
public:
259+
virtual ~Transport() {}
260+
261+
// 1. Receiver side functions, for decoding bytes received on the wire into transport protocol
262+
// agnostic CNetMessage (message type & payload) objects. Callers must guarantee that none of
263+
// these functions are called concurrently w.r.t. one another.
264+
262265
// returns true if the current deserialization is complete
263266
virtual bool Complete() const = 0;
264-
// set the serialization context version
267+
// set the deserialization context version
265268
virtual void SetVersion(int version) = 0;
266269
/** read and deserialize data, advances msg_bytes data pointer */
267270
virtual int Read(Span<const uint8_t>& msg_bytes) = 0;
268271
// decomposes a message from the context
269272
virtual CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) = 0;
270-
virtual ~TransportDeserializer() {}
273+
274+
// 2. Sending side functions:
275+
276+
// prepare message for transport (header construction, error-correction computation, payload encryption, etc.)
277+
virtual void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const = 0;
271278
};
272279

273-
class V1TransportDeserializer final : public TransportDeserializer
280+
class V1Transport final : public Transport
274281
{
275282
private:
276283
const CChainParams& m_chain_params;
@@ -300,7 +307,7 @@ class V1TransportDeserializer final : public TransportDeserializer
300307
}
301308

302309
public:
303-
V1TransportDeserializer(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn)
310+
V1Transport(const CChainParams& chain_params, const NodeId node_id, int nTypeIn, int nVersionIn)
304311
: m_chain_params(chain_params),
305312
m_node_id(node_id),
306313
hdrbuf(nTypeIn, nVersionIn),
@@ -331,19 +338,7 @@ class V1TransportDeserializer final : public TransportDeserializer
331338
return ret;
332339
}
333340
CNetMessage GetMessage(std::chrono::microseconds time, bool& reject_message) override;
334-
};
335341

336-
/** The TransportSerializer prepares messages for the network transport
337-
*/
338-
class TransportSerializer {
339-
public:
340-
// prepare message for transport (header construction, error-correction computation, payload encryption, etc.)
341-
virtual void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const = 0;
342-
virtual ~TransportSerializer() {}
343-
};
344-
345-
class V1TransportSerializer : public TransportSerializer {
346-
public:
347342
void prepareForTransport(CSerializedNetMsg& msg, std::vector<unsigned char>& header) const override;
348343
};
349344

@@ -359,8 +354,8 @@ struct CNodeOptions
359354
class CNode
360355
{
361356
public:
362-
const std::unique_ptr<TransportDeserializer> m_deserializer; // Used only by SocketHandler thread
363-
const std::unique_ptr<const TransportSerializer> m_serializer;
357+
/** Transport serializer/deserializer. The receive side functions are only called under cs_vRecv. */
358+
const std::unique_ptr<Transport> m_transport;
364359

365360
const NetPermissionFlags m_permission_flags;
366361

Diff for: src/test/fuzz/p2p_transport_serialization.cpp

+8-7
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ void initialize_p2p_transport_serialization()
2424

2525
FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serialization)
2626
{
27-
// Construct deserializer, with a dummy NodeId
28-
V1TransportDeserializer deserializer{Params(), NodeId{0}, SER_NETWORK, INIT_PROTO_VERSION};
29-
V1TransportSerializer serializer{};
27+
// Construct transports for both sides, with dummy NodeIds.
28+
V1Transport recv_transport{Params(), NodeId{0}, SER_NETWORK, INIT_PROTO_VERSION};
29+
V1Transport send_transport{Params(), NodeId{1}, SER_NETWORK, INIT_PROTO_VERSION};
30+
3031
FuzzedDataProvider fuzzed_data_provider{buffer.data(), buffer.size()};
3132

3233
auto checksum_assist = fuzzed_data_provider.ConsumeBool();
@@ -63,22 +64,22 @@ FUZZ_TARGET(p2p_transport_serialization, .init = initialize_p2p_transport_serial
6364
mutable_msg_bytes.insert(mutable_msg_bytes.end(), payload_bytes.begin(), payload_bytes.end());
6465
Span<const uint8_t> msg_bytes{mutable_msg_bytes};
6566
while (msg_bytes.size() > 0) {
66-
const int handled = deserializer.Read(msg_bytes);
67+
const int handled = recv_transport.Read(msg_bytes);
6768
if (handled < 0) {
6869
break;
6970
}
70-
if (deserializer.Complete()) {
71+
if (recv_transport.Complete()) {
7172
const std::chrono::microseconds m_time{std::numeric_limits<int64_t>::max()};
7273
bool reject_message{false};
73-
CNetMessage msg = deserializer.GetMessage(m_time, reject_message);
74+
CNetMessage msg = recv_transport.GetMessage(m_time, reject_message);
7475
assert(msg.m_type.size() <= CMessageHeader::COMMAND_SIZE);
7576
assert(msg.m_raw_message_size <= mutable_msg_bytes.size());
7677
assert(msg.m_raw_message_size == CMessageHeader::HEADER_SIZE + msg.m_message_size);
7778
assert(msg.m_time == m_time);
7879

7980
std::vector<unsigned char> header;
8081
auto msg2 = CNetMsgMaker{msg.m_recv.GetVersion()}.Make(msg.m_type, Span{msg.m_recv});
81-
serializer.prepareForTransport(msg2, header);
82+
send_transport.prepareForTransport(msg2, header);
8283
}
8384
}
8485
}

Diff for: src/test/util/net.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ void ConnmanTestMsg::NodeReceiveMsgBytes(CNode& node, Span<const uint8_t> msg_by
7373
bool ConnmanTestMsg::ReceiveMsgFrom(CNode& node, CSerializedNetMsg& ser_msg) const
7474
{
7575
std::vector<uint8_t> ser_msg_header;
76-
node.m_serializer->prepareForTransport(ser_msg, ser_msg_header);
76+
node.m_transport->prepareForTransport(ser_msg, ser_msg_header);
7777

7878
bool complete;
7979
NodeReceiveMsgBytes(node, ser_msg_header, complete);

0 commit comments

Comments
 (0)