Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional node validation to routing table #435

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 5 additions & 2 deletions eth/p2p/discoveryv5/protocol.nim
Expand Up @@ -153,6 +153,7 @@ proc addNode*(d: Protocol, node: Node): bool =
##
## Returns true only when `Node` was added as a new entry to a bucket in the
## routing table.

if d.routingTable.addNode(node) == Added:
return true
else:
Expand Down Expand Up @@ -888,7 +889,8 @@ proc newProtocol*(privKey: PrivateKey,
bindIp = IPv4_any(),
enrAutoUpdate = false,
tableIpLimits = DefaultTableIpLimits,
rng = newRng()):
rng = newRng(),
nodeValidator = none(NodeValidator)):
Protocol =
# TODO: Tried adding bindPort = udpPort as parameter but that gave
# "Error: internal error: environment misses: udpPort" in nim-beacon-chain.
Expand Down Expand Up @@ -927,7 +929,8 @@ proc newProtocol*(privKey: PrivateKey,
bootstrapRecords: @bootstrapRecords,
ipVote: IpVote.init(),
enrAutoUpdate: enrAutoUpdate,
routingTable: RoutingTable.init(node, DefaultBitsPerHop, tableIpLimits, rng),
routingTable: RoutingTable.init(node, DefaultBitsPerHop, tableIpLimits, rng,
nodeValidator = nodeValidator),
rng: rng)

template listeningAddress*(p: Protocol): Address =
Expand Down
13 changes: 11 additions & 2 deletions eth/p2p/discoveryv5/routing_table.nim
Expand Up @@ -22,6 +22,7 @@ type
DistanceProc* = proc(a, b: NodeId): NodeId {.raises: [Defect], gcsafe, noSideEffect.}
LogDistanceProc* = proc(a, b: NodeId): uint16 {.raises: [Defect], gcsafe, noSideEffect.}
IdAtDistanceProc* = proc (id: NodeId, dist: uint16): NodeId {.raises: [Defect], gcsafe, noSideEffect.}
NodeValidator* = proc(node: Node): bool {.gcsafe, raises: [Defect].}

DistanceCalculator* = object
calculateDistance*: DistanceProc
Expand All @@ -43,6 +44,7 @@ type
## replacement caches.
distanceCalculator: DistanceCalculator
rng: ref BrHmacDrbgContext
nodeValidator: Option[NodeValidator] ## Optional validation of nodes

KBucket = ref object
istart, iend: NodeId ## Range of NodeIds this KBucket covers. This is not a
Expand Down Expand Up @@ -91,6 +93,7 @@ type
ReplacementAdded
ReplacementExisting
NoAddress
Invalid

# xor distance functions
func distance*(a, b: NodeId): Uint256 =
Expand Down Expand Up @@ -261,7 +264,7 @@ proc computeSharedPrefixBits(nodes: openarray[NodeId]): int =

proc init*(T: type RoutingTable, localNode: Node, bitsPerHop = DefaultBitsPerHop,
ipLimits = DefaultTableIpLimits, rng: ref BrHmacDrbgContext,
distanceCalculator = XorDistanceCalculator): T =
distanceCalculator = XorDistanceCalculator, nodeValidator = none(NodeValidator)): T =
## Initialize the routing table for provided `Node` and bitsPerHop value.
## `bitsPerHop` is default set to 5 as recommended by original Kademlia paper.
RoutingTable(
Expand All @@ -270,7 +273,8 @@ proc init*(T: type RoutingTable, localNode: Node, bitsPerHop = DefaultBitsPerHop
bitsPerHop: bitsPerHop,
ipLimits: IpLimits(limit: ipLimits.tableIpLimit),
distanceCalculator: distanceCalculator,
rng: rng)
rng: rng,
nodeValidator: nodeValidator)

proc splitBucket(r: var RoutingTable, index: int) =
let bucket = r.buckets[index]
Expand Down Expand Up @@ -329,6 +333,11 @@ proc addNode*(r: var RoutingTable, n: Node): NodeStatus =
## When the IP of the node has reached the IP limits for the bucket or the
## total routing table, the node will not be added to the bucket, nor its
## replacement cache.

## If we have the optional validator set, only add
## node if it passes validation
if r.nodeValidator.isSome() and not r.nodeValidator.get()(n):
return Invalid

# Don't allow nodes without an address field in the ENR to be added.
# This could also be reworked by having another Node type that always has an
Expand Down
28 changes: 28 additions & 0 deletions tests/p2p/test_discoveryv5.nim
Expand Up @@ -387,6 +387,34 @@ procSuite "Discovery v5 Tests":

await mainNode.closeWait()
await lookupNode.closeWait()

asyncTest "Discovery with node validator":
let validPort = 20302

proc validator(node: Node): bool {.gcsafe, raises: [Defect].} =
# Simple validation on UDP port
let tr = node.record.toTypedRecord.get()
return tr.udp.isSome and tr.udp.get() == validPort

let
lookupNode = newProtocol(PrivateKey.random(rng[]), some(ValidIpAddress.init("127.0.0.1")),
some(Port(validPort)), some(Port(validPort)), bindPort = Port(validPort),
rng = rng, nodeValidator = some(validator.NodeValidator))
validNode1 = generateNode(PrivateKey.random(rng[]), port = validPort)
validNode2 = generateNode(PrivateKey.random(rng[]), port = validPort)
invalidNode1 = generateNode(PrivateKey.random(rng[]), port = validPort - 1)

check:
lookupNode.addNode(validNode1)
lookupNode.addNode(validNode2)
lookupNode.addNode(invalidNode1) == false

let discovered = lookupNode.randomNodes(10)
check:
discovered.len == 2
discovered.contains(validNode1)
discovered.contains(validNode2)
discovered.contains(invalidNode1) == false

asyncTest "Random nodes with enr field filter":
let
Expand Down
28 changes: 28 additions & 0 deletions tests/p2p/test_routing_table.nim
Expand Up @@ -562,3 +562,31 @@ suite "Routing Table Tests":
# there may be more than one node at provided distance
check len(neighboursAtLogDist) >= 1
check neighboursAtLogDist.contains(n)

test "Node validation in routing table":
let validPort = 20302

proc validator(node: Node): bool {.gcsafe, raises: [Defect].} =
# Simple validation on UDP port
let tr = node.record.toTypedRecord.get()
return tr.udp.isSome and tr.udp.get() == validPort

let local = generateNode(PrivateKey.random(rng[]))
var table = RoutingTable.init(local, 1, ipLimits, rng = rng,
distanceCalculator = customDistanceCalculator,
nodeValidator = some(validator.NodeValidator)
)

let
validNode = generateNode(PrivateKey.random(rng[]), port = validPort)
invalidNode = generateNode(PrivateKey.random(rng[]), port = validPort + 1)

# Valid nodes get added
check:
table.addNode(validNode) == Added
table.len == 1

# Invalid nodes get rejected
check:
table.addNode(invalidNode) == Invalid
table.len == 1