diff --git a/src/toxcore/dht_new/packet.rs b/src/toxcore/dht_new/packet.rs index efcc17c27..2619ca60c 100644 --- a/src/toxcore/dht_new/packet.rs +++ b/src/toxcore/dht_new/packet.rs @@ -642,16 +642,11 @@ impl ToBytes for NatPingResponse { #[cfg(test)] mod test { use super::*; - use byteorder::{ByteOrder, BigEndian, WriteBytesExt}; use std::net::SocketAddr; use toxcore::dht_new::codec::*; -// use toxcore::dht_new::packet_kind::*; use quickcheck::{Arbitrary, Gen, quickcheck}; - const NAT_PING_REQUEST: PacketKind = PacketKind::PingRequest; - const NAT_PING_RESPONSE: PacketKind = PacketKind::PingResponse; - impl DhtPacket { pub fn new(shared_secret: &PrecomputedKey, pk: &PublicKey, dp: DhtPacketPayload) -> DhtPacket { let nonce = &gen_nonce(); @@ -698,24 +693,36 @@ mod test { } } - impl SendNodes { - /** - Create new `SendNodes`. Returns `None` if 0 or more than 4 nodes are - supplied. - - Created as a response to `GetNodes` request. - */ - pub fn with_nodes(request: &GetNodes, nodes: Vec) -> Option { - debug!(target: "SendNodes", "Creating SendNodes from GetNodes."); - trace!(target: "SendNodes", "With GetNodes: {:?}", request); - trace!("With nodes: {:?}", &nodes); - - if nodes.is_empty() || nodes.len() > 4 { - warn!(target: "SendNodes", "Wrong number of nodes supplied!"); - return None - } + impl PingRequest { + /// Create new ping request with a randomly generated `request id`. + pub fn new() -> Self { + trace!("Creating new Ping."); + PingRequest { id: random_u64() } + } + + /// An ID of the request / response. + pub fn id(&self) -> u64 { + self.id + } + } + + impl PingResponse { + /// An ID of the request / response. + pub fn id(&self) -> u64 { + self.id + } + } + + impl NatPingRequest { + /// Create new ping request with a randomly generated `request id`. + pub fn new() -> Self { + trace!("Creating new Ping."); + NatPingRequest { id: random_u64() } + } - Some(SendNodes { nodes: nodes, id: request.id }) + /// An ID of the request / response. + pub fn id(&self) -> u64 { + self.id } } @@ -725,6 +732,12 @@ mod test { } } + impl From for NatPingResponse { + fn from(p: NatPingRequest) -> Self { + NatPingResponse { id: p.id } + } + } + impl Arbitrary for DhtBase { fn arbitrary(g: &mut G) -> Self { let choice = g.gen_range(0, 2); @@ -741,17 +754,22 @@ mod test { let (pk, sk) = gen_keypair(); // "sender" keypair let (r_pk, _) = gen_keypair(); // receiver PK let precomputed = encrypt_precompute(&r_pk, &sk); + DhtPacket::new(&precomputed, &pk, DhtPacketPayload::arbitrary(g)) + } + } + impl Arbitrary for DhtPacketPayload { + fn arbitrary(g: &mut G) -> Self { let choice = g.gen_range(0, 4); match choice { 0 => - DhtPacket::new(&precomputed, &pk, DhtPacketPayload::PingRequest(PingRequest::arbitrary(g))), + DhtPacketPayload::PingRequest(PingRequest::arbitrary(g)), 1 => - DhtPacket::new(&precomputed, &pk, DhtPacketPayload::PingResponse(PingResponse::arbitrary(g))), + DhtPacketPayload::PingResponse(PingResponse::arbitrary(g)), 2 => - DhtPacket::new(&precomputed, &pk, DhtPacketPayload::GetNodes(GetNodes::arbitrary(g))), + DhtPacketPayload::GetNodes(GetNodes::arbitrary(g)), 3 => - DhtPacket::new(&precomputed, &pk, DhtPacketPayload::SendNodes(SendNodes::arbitrary(g))), + DhtPacketPayload::SendNodes(SendNodes::arbitrary(g)), _ => unreachable!("Arbitrary for DhtPacket - should not have happened!") } } @@ -762,136 +780,45 @@ mod test { let (pk, sk) = gen_keypair(); // "sender" keypair let (r_pk, _) = gen_keypair(); // receiver PK let precomputed = encrypt_precompute(&r_pk, &sk); + DhtRequest::new(&precomputed, &r_pk, &pk, DhtRequestPayload::arbitrary(g)) + } + } + impl Arbitrary for DhtRequestPayload { + fn arbitrary(g: &mut G) -> Self { let choice = g.gen_range(0, 2); if choice == 0 { - DhtRequest::new(&precomputed, &r_pk, &pk,DhtRequestPayload::NatPingRequest(NatPingRequest::arbitrary(g))) + DhtRequestPayload::NatPingRequest(NatPingRequest::arbitrary(g)) } else { - DhtRequest::new(&precomputed, &r_pk, &pk, DhtRequestPayload::NatPingResponse(NatPingResponse::arbitrary(g))) + DhtRequestPayload::NatPingResponse(NatPingResponse::arbitrary(g)) } } } - // PingRequest:: impl Arbitrary for PingRequest { fn arbitrary(_g: &mut G) -> Self { PingRequest::new() } } - - // PingResponse:: + impl Arbitrary for PingResponse { fn arbitrary(_g: &mut G) -> Self { PingRequest::new().into() } } - impl PingRequest { - /// Create new ping request with a randomly generated `request id`. - pub fn new() -> Self { - trace!("Creating new Ping."); - PingRequest { id: random_u64() } - } - - /// An ID of the request / response. - pub fn id(&self) -> u64 { - self.id - } - } - - impl PingResponse { - /// An ID of the request / response. - pub fn id(&self) -> u64 { - self.id - } - } - - // PingRequest:: impl Arbitrary for NatPingRequest { fn arbitrary(_g: &mut G) -> Self { NatPingRequest::new() } } - - // PingResponse:: + impl Arbitrary for NatPingResponse { fn arbitrary(_g: &mut G) -> Self { NatPingRequest::new().into() } } - impl NatPingRequest { - /// Create new ping request with a randomly generated `request id`. - pub fn new() -> Self { - trace!("Creating new Ping."); - NatPingRequest { id: random_u64() } - } - - /// An ID of the request / response. - pub fn id(&self) -> u64 { - self.id - } - } - - impl NatPingResponse { - /// An ID of the request / response. - pub fn id(&self) -> u64 { - self.id - } - } - - impl From for NatPingResponse { - fn from(p: NatPingRequest) -> Self { - NatPingResponse { id: p.id } - } - } - - macro_rules! tests_for_pings { - ($($p:ident $b_t:ident $f_t:ident)+) => ($( - - // ::to_bytes() - #[test] - fn $b_t() { - fn with_ping(p: $p) { - let mut _buf = [0; 1024]; - let pb = p.to_bytes((&mut _buf, 0)).ok().unwrap(); - assert_eq!(PING_SIZE, pb.1); - assert_eq!(PacketKind::$p as u8, pb.0[0]); - } - quickcheck(with_ping as fn($p)); - } - - // ::from_bytes() - #[test] - fn $f_t() { - fn with_bytes(bytes: Vec) { - if bytes.len() < PING_SIZE || - bytes[0] != PacketKind::$p as u8 { - assert!(!($p::from_bytes(&bytes)).is_done()); - } else { - let p = $p::from_bytes(&bytes).unwrap(); - // `id` should not differ - assert_eq!(p.1.id(), BigEndian::read_u64(&bytes[1..PING_SIZE])); - } - } - quickcheck(with_bytes as fn(Vec)); - - // just in case - let mut ping = vec![PacketKind::$p as u8]; - ping.write_u64::(random_u64()).unwrap(); - with_bytes(ping); - } - )+) - } - tests_for_pings!(PingRequest - packet_ping_req_to_bytes_test - packet_ping_req_from_bytes_test - PingResponse - packet_ping_resp_to_bytes_test - packet_ping_resp_from_bytes_test - ); - - // GetNodes:: impl Arbitrary for GetNodes { fn arbitrary(g: &mut G) -> Self { let mut a: [u8; PUBLICKEYBYTES] = [0; PUBLICKEYBYTES]; @@ -900,225 +827,64 @@ mod test { } } - impl Arbitrary for DhtPacketPayload { + impl Arbitrary for SendNodes { fn arbitrary(g: &mut G) -> Self { - let mut a: [u8; PUBLICKEYBYTES] = [0; PUBLICKEYBYTES]; - g.fill_bytes(&mut a); - DhtPacketPayload::GetNodes(GetNodes { pk: PublicKey(a), id: g.gen() }) - } - } - - // GetNodes::to_bytes() - #[test] - fn packet_get_nodes_to_bytes_test() { - fn with_gn(gn: GetNodes) { - let mut _buf = [0;1024]; - let g_bytes = gn.to_bytes((&mut _buf, 0)).ok().unwrap().0; - let PublicKey(pk_bytes) = gn.pk; - assert_eq!(&pk_bytes, &g_bytes[..PUBLICKEYBYTES]); - assert_eq!(gn.id, BigEndian::read_u64(&g_bytes[PUBLICKEYBYTES..])); - } - quickcheck(with_gn as fn(GetNodes)); - } - - /// Size of serialized [`GetNodes`](./struct.GetNodes.html) in bytes. - pub const GET_NODES_SIZE: usize = PUBLICKEYBYTES + 8; - - // GetNodes::from_bytes() - #[test] - fn packet_get_nodes_from_bytes_test() { - fn with_bytes(bytes: Vec) { - if bytes.len() < GET_NODES_SIZE { - assert!(!GetNodes::from_bytes(&bytes).is_done()); - } else { - let gn = GetNodes::from_bytes(&bytes).unwrap().1; - // ping_id as bytes should match "original" bytes - assert_eq!(BigEndian::read_u64(&bytes[PUBLICKEYBYTES..GET_NODES_SIZE]), gn.id); - - let PublicKey(ref pk) = gn.pk; - assert_eq!(pk, &bytes[..PUBLICKEYBYTES]); - } + let nodes = vec![Arbitrary::arbitrary(g); g.gen_range(1, 4)]; + let id = g.gen(); + SendNodes { nodes: nodes, id: id } } - quickcheck(with_bytes as fn(Vec)); } - // DhtPacketPayload::GetNodes::to_bytes() #[test] - fn dht_packet_get_nodes_to_bytes_test() { - fn with_gn(gn: DhtPacketPayload) { - let mut _buf = [0;1024]; - let g_bytes = gn.to_bytes((&mut _buf, 0)).ok().unwrap().0; - if let DhtPacketPayload::GetNodes(gp) = gn { - let PublicKey(pk_bytes) = gp.pk; - assert_eq!(&pk_bytes, &g_bytes[..PUBLICKEYBYTES]); - assert_eq!(gp.id, BigEndian::read_u64(&g_bytes[PUBLICKEYBYTES..])); - } + fn dht_packet_payload_check() { + fn with_payload(payload: DhtPacketPayload) { + let packet_kind = payload.kind(); + let mut buf = [0; MAX_DHT_PACKET_SIZE]; + let (_, len) = payload.to_bytes((&mut buf, 0)).ok().unwrap(); + let (_, decoded) = DhtPacketPayload::from_bytes(&buf[..len], packet_kind).unwrap(); + assert_eq!(decoded, payload); } - quickcheck(with_gn as fn(DhtPacketPayload)); + quickcheck(with_payload as fn(DhtPacketPayload)); } - // DhtPacketPayload::GetNodes::from_bytes() #[test] - fn dht_packet_get_nodes_from_bytes_test() { - fn with_bytes(bytes: Vec) { - if bytes.len() < GET_NODES_SIZE { - assert!(!GetNodes::from_bytes(&bytes).is_done()); - } else { - let gp = GetNodes::from_bytes(&bytes).unwrap().1; - // ping_id as bytes should match "original" bytes - assert_eq!(BigEndian::read_u64(&bytes[PUBLICKEYBYTES..GET_NODES_SIZE]), gp.id); - - let PublicKey(ref pk) = gp.pk; - assert_eq!(pk, &bytes[..PUBLICKEYBYTES]); - } - } - quickcheck(with_bytes as fn(Vec)); - } - - // SendNodes:: - impl Arbitrary for SendNodes { - fn arbitrary(g: &mut G) -> Self { - let nodes = vec![Arbitrary::arbitrary(g); g.gen_range(1,4)]; - let id = g.gen(); - SendNodes { nodes: nodes, id: id } + fn dht_request_payload_check() { + fn with_payload(payload: DhtRequestPayload) { + let mut buf = [0; MAX_DHT_PACKET_SIZE]; + let (_, len) = payload.to_bytes((&mut buf, 0)).ok().unwrap(); + let (_, decoded) = DhtRequestPayload::from_bytes(&buf[..len]).unwrap(); + assert_eq!(decoded, payload); } + quickcheck(with_payload as fn(DhtRequestPayload)); } - // SendNodes::to_bytes() #[test] - fn packet_send_nodes_to_bytes_test() { - // there should be at least 1 valid node; there can be up to 4 nodes - fn with_nodes(req: GetNodes, n1: PackedNode, n2: Option, - n3: Option, n4: Option) { - - let mut _buf = [0;1024]; - let mut nodes = vec![n1]; - if let Some(n) = n2 { nodes.push(n); } - if let Some(n) = n3 { nodes.push(n); } - if let Some(n) = n4 { nodes.push(n); } - let sn_bytes = SendNodes::with_nodes(&req, nodes.clone()) - .unwrap().to_bytes((&mut _buf, 0)).ok().unwrap().0; - - // number of nodes should match - assert_eq!(nodes.len(), sn_bytes[0] as usize); - - // bytes before current PackedNode in serialized SendNodes - // starts from `1` since first byte of serialized SendNodes is number of - // nodes - let mut len_before = 1; - for node in &nodes { - let mut _buf = [0; 1024]; - let cur_len = node.to_bytes((&mut _buf, 0)).ok().unwrap().1; - assert_eq!(&_buf[..cur_len], - &sn_bytes[len_before..(len_before + cur_len)]); - len_before += cur_len; - } - // ping id should be the same as in request - assert_eq!(req.id, BigEndian::read_u64(&sn_bytes[len_before..])); + fn dht_packet_check() { + fn with_packet(packet: DhtPacket) { + let mut buf = [0; MAX_DHT_PACKET_SIZE]; + let (_, len) = packet.to_bytes((&mut buf, 0)).ok().unwrap(); + let (_, decoded) = DhtPacket::from_bytes(&buf[..len]).unwrap(); + assert_eq!(decoded, packet); } - quickcheck(with_nodes as fn(GetNodes, PackedNode, Option, - Option, Option)); + quickcheck(with_packet as fn(DhtPacket)); } - // SendNodes::from_bytes() #[test] - fn packet_send_nodes_from_bytes_test() { - fn with_nodes(nodes: Vec, r_u64: u64) { - let mut bytes = vec![nodes.len() as u8]; - let mut _buf = [0; 1024]; - for node in &nodes { - let buf = node.to_bytes((&mut _buf, 0)).ok().unwrap(); - bytes.extend_from_slice(&buf.0[..buf.1]); - } - // and ping id - bytes.write_u64::(r_u64).unwrap(); - - if nodes.len() > 4 || nodes.is_empty() { - assert!(!SendNodes::from_bytes(&bytes).is_done()); - } else { - let nodes2 = SendNodes::from_bytes(&bytes).unwrap().1; - assert_eq!(&nodes, &nodes2.nodes); - assert_eq!(r_u64, nodes2.id); - } + fn dht_request_check() { + fn with_packet(packet: DhtRequest) { + let mut buf = [0; MAX_DHT_PACKET_SIZE]; + let (_, len) = packet.to_bytes((&mut buf, 0)).ok().unwrap(); + let (_, decoded) = DhtRequest::from_bytes(&buf[..len]).unwrap(); + assert_eq!(decoded, packet); } - quickcheck(with_nodes as fn(Vec, u64)); - } - - /** `NatPing` type byte for [`NatPingRequest`] and [`NatPingResponse`]. - [./struct.PingRequest.html] [./struct.PingResponse.html] - */ - pub const NAT_PING_TYPE: u8 = 0xfe; - - /** Length in bytes of NatPings when serialized into bytes. - */ - pub const NAT_PING_SIZE: usize = PING_SIZE + 1; - - macro_rules! impls_tests_for_nat_pings { - ($($np:ident $b_t:ident $f_t:ident)+) => ($( - // impl Arbitrary for $np { - // fn arbitrary(g: &mut G) -> Self { - // $np(Arbitrary::arbitrary(g)) - // } - // } - - #[test] - fn $b_t() { - fn with_np(p: $np) { - let mut _buf = [0; 1024]; - let pb = p.to_bytes((&mut _buf, 0)).ok().unwrap(); - assert_eq!(NAT_PING_SIZE, pb.1); - assert_eq!(NAT_PING_TYPE as u8, pb.0[0]); - if stringify!($np) == "NatPingRequest" { - assert_eq!(NAT_PING_REQUEST as u8, pb.0[1]); - } else { - assert_eq!(NAT_PING_RESPONSE as u8, pb.0[1]); - } - } - quickcheck(with_np as fn($np)); - } - - // ::from_bytes() - #[test] - fn $f_t() { - fn with_bytes(bytes: Vec) { - if bytes.len() < NAT_PING_SIZE || - bytes[0] != NAT_PING_TYPE as u8 { - assert!(!($np::from_bytes(&bytes)).is_done()); - } else { - let p = $np::from_bytes(&bytes).unwrap().1; - // `id` should not differ - assert_eq!(p.id(), BigEndian::read_u64(&bytes[2..NAT_PING_SIZE])); - } - } - quickcheck(with_bytes as fn(Vec)); - - // just in case - let ping_kind = match stringify!($np) { - "NatPingRequest" => NAT_PING_REQUEST as u8, - "NatPingResponse" => NAT_PING_RESPONSE as u8, - e => unreachable!("can not occur {:?}", e) - }; - let mut ping = vec![NAT_PING_TYPE, ping_kind]; - ping.write_u64::(random_u64()) - .unwrap(); - with_bytes(ping); - } - )+) + quickcheck(with_packet as fn(DhtRequest)); } - impls_tests_for_nat_pings!( - NatPingRequest - packet_nat_ping_req_to_bytes_test - packet_nat_ping_req_from_bytes_test - NatPingResponse - packet_nat_ping_resp_to_bytes_test - packet_nat_ping_resp_from_bytes_test - ); - #[test] fn dht_packet_encode_decode() { let (alice_pk, alice_sk) = gen_keypair(); let (bob_pk, bob_sk) = gen_keypair(); + let (_eve_pk, eve_sk) = gen_keypair(); let shared_secret = encrypt_precompute(&bob_pk, &alice_sk); let packed_node = PackedNode::new(false, SocketAddr::V4("5.6.7.8:12345".parse().unwrap()), &alice_pk); let test_payloads = vec![ @@ -1146,10 +912,112 @@ mod test { }; // packets should be equal assert_eq!(decoded_dht_packet, dht_packet); + // try to decode payload with eve's secret key + let decoded_payload = decoded_dht_packet.get_payload(&eve_sk); + assert!(decoded_payload.is_err()); // decode payload with bob's secret key let decoded_payload = decoded_dht_packet.get_payload(&bob_sk).unwrap(); // payloads should be equal assert_eq!(decoded_payload, payload); } } + + #[test] + fn dht_request_encode_decode() { + let (alice_pk, alice_sk) = gen_keypair(); + let (bob_pk, bob_sk) = gen_keypair(); + let (_eve_pk, eve_sk) = gen_keypair(); + let shared_secret = encrypt_precompute(&bob_pk, &alice_sk); + let test_payloads = vec![ + DhtRequestPayload::NatPingRequest(NatPingRequest { id: 42 }), + DhtRequestPayload::NatPingResponse(NatPingResponse { id: 42 }) + ]; + for payload in test_payloads { + // encode payload with shared secret + let dht_request = DhtRequest::new(&shared_secret, &bob_pk, &alice_pk, payload.clone()); + // create dht_base + let dht_base = DhtBase::DhtRequest(dht_request.clone()); + // serialize dht base to bytes + let mut buf = [0; MAX_DHT_PACKET_SIZE]; + let (_, size) = dht_base.to_bytes((&mut buf, 0)).unwrap(); + // deserialize dht base from bytes + let (_, decoded_dht_base) = DhtBase::from_bytes(&buf[..size]).unwrap(); + // bases should be equal + assert_eq!(decoded_dht_base, dht_base); + // get packet from base + let decoded_dht_request = match decoded_dht_base { + DhtBase::DhtRequest(decoded_dht_request) => decoded_dht_request, + _ => unreachable!("should be DhtRequest") + }; + // requests should be equal + assert_eq!(decoded_dht_request, dht_request); + // try to decode payload with eve's secret key + let decoded_payload = decoded_dht_request.get_payload(&eve_sk); + assert!(decoded_payload.is_err()); + // decode payload with bob's secret key + let decoded_payload = decoded_dht_request.get_payload(&bob_sk).unwrap(); + // payloads should be equal + assert_eq!(decoded_payload, payload); + } + } + + #[test] + fn dht_packet_decode_invalid() { + let (alice_pk, alice_sk) = gen_keypair(); + let (bob_pk, bob_sk) = gen_keypair(); + let shared_secret = encrypt_precompute(&bob_pk, &alice_sk); + let nonce = gen_nonce(); + // Try long invalid array + let invalid_payload = [42; 123]; + let invalid_payload_encoded = seal_precomputed(&invalid_payload, &nonce, &shared_secret); + let invalid_packet = DhtPacket { + packet_kind: PacketKind::PingRequest, + pk: alice_pk, + nonce: nonce, + payload: invalid_payload_encoded + }; + let decoded_payload = invalid_packet.get_payload(&bob_sk); + assert!(decoded_payload.is_err()); + // Try short incomplete + let invalid_payload = [0x00]; + let invalid_payload_encoded = seal_precomputed(&invalid_payload, &nonce, &shared_secret); + let invalid_packet = DhtPacket { + packet_kind: PacketKind::PingRequest, + pk: alice_pk, + nonce: nonce, + payload: invalid_payload_encoded + }; + let decoded_payload = invalid_packet.get_payload(&bob_sk); + assert!(decoded_payload.is_err()); + } + + #[test] + fn dht_request_decode_invalid() { + let (alice_pk, alice_sk) = gen_keypair(); + let (bob_pk, bob_sk) = gen_keypair(); + let shared_secret = encrypt_precompute(&bob_pk, &alice_sk); + let nonce = gen_nonce(); + // Try long invalid array + let invalid_payload = [42; 123]; + let invalid_payload_encoded = seal_precomputed(&invalid_payload, &nonce, &shared_secret); + let invalid_packet = DhtRequest { + rpk: bob_pk, + spk: alice_pk, + nonce: nonce, + payload: invalid_payload_encoded + }; + let decoded_payload = invalid_packet.get_payload(&bob_sk); + assert!(decoded_payload.is_err()); + // Try short incomplete + let invalid_payload = [0xfe]; + let invalid_payload_encoded = seal_precomputed(&invalid_payload, &nonce, &shared_secret); + let invalid_packet = DhtRequest { + rpk: bob_pk, + spk: alice_pk, + nonce: nonce, + payload: invalid_payload_encoded + }; + let decoded_payload = invalid_packet.get_payload(&bob_sk); + assert!(decoded_payload.is_err()); + } }