This repository has been archived by the owner on Mar 11, 2020. It is now read-only.
forked from froozen/kademlia
/
Tree.hs
293 lines (259 loc) · 12 KB
/
Tree.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
{-|
Module : Network.Kademlia.Tree
Description : Implementation of the Node Storage Tree
Network.Kademlia.Tree implements the Node Storage Tree used to store
and look up the known nodes.
This module is designed to be used as a qualified import.
-}
module Network.Kademlia.Tree
( NodeTree
, create
, insert
, lookup
, delete
, handleTimeout
, pickupRandom
, findClosest
, extractId
, toView
, toList
, fold
) where
import Prelude hiding (lookup)
import Control.Monad.Random (evalRand)
import Data.Binary (Binary)
import qualified Data.List as L (delete, find, genericTake)
import GHC.Generics (Generic)
import System.Random (StdGen)
import System.Random.Shuffle (shuffleM)
import Network.Kademlia.Config (WithConfig, getConfig, k, bucketSize)
import Network.Kademlia.Types (ByteStruct, Node (..), Serialize (..),
fromByteStruct, sortByDistanceTo, toByteStruct)
-- | The node storage tree: the 'ByteStruct' of the id the tree was created
-- for (see 'create' / 'extractId'), paired with the root element of the tree.
data NodeTree i = NodeTree ByteStruct (NodeTreeElem i)
deriving (Generic)
-- | One element of the tree: either an inner @Split@ with a left and a right
-- child (the traversal functions in this module go left on a @False@ bit and
-- right on a @True@ bit), or a leaf @Bucket@.
--
-- A @Bucket@ carries two lists: the stored nodes, each paired with an 'Int'
-- that 'handleTimeout' increments and 'refresh' resets to 0, and a second
-- list of nodes that 'insert' uses as a replacement cache once the first
-- list is full.
data NodeTreeElem i = Split (NodeTreeElem i) (NodeTreeElem i)
| Bucket ([(Node i, Int)], [Node i])
deriving (Generic)
-- | Callback run on a bucket by 'modifyAt', 'bothAt' and 'applyAt'.
--
-- Arguments, in order: the depth of the bucket (the root is at depth 0), a
-- flag that is 'True' while every bit followed on the way down equalled the
-- corresponding bit of the tree's own id, and the bucket contents.
type NodeTreeFunction i a = Int -> Bool -> ([(Node i, Int)], [Node i]) -> WithConfig a
-- Serialisation instances. Both bodies are empty, so they rely on the
-- class's default methods together with the derived 'Generic' instances.
instance Binary i => Binary (NodeTree i)
instance Binary i => Binary (NodeTreeElem i)
-- | Rebuild the tree, replacing the bucket the supplied id falls into with
-- the element produced by the given function.
--
-- The function is handed the depth of the bucket, a flag telling whether the
-- path down to it matched the tree's own id bit for bit, and the bucket
-- contents.
modifyAt :: (Serialize i) =>
       NodeTree i -> i -> NodeTreeFunction i (NodeTreeElem i)
    -> WithConfig (NodeTree i)
modifyAt (NodeTree ownBits root) nid f = do
    targetBits <- toByteStruct nid
    NodeTree ownBits <$> go ownBits targetBits 0 True root
  where
    -- This helper is partial, but every path through the tree ends in a
    -- bucket, so running out of bits before reaching one is not expected.
    --
    -- Reached the bucket: hand it to the callback.
    go _ _ depth onOwnPath (Bucket contents) = f depth onOwnPath contents
    -- Inner element: a False target bit selects the left child, a True bit
    -- the right one; only the selected child is rebuilt.
    go (own:owns) (target:targets) depth onOwnPath (Split left right)
        | target    = Split left <$> descend right
        | otherwise = (`Split` right) <$> descend left
      where
        descend = go owns targets (depth + 1) (onOwnPath && own == target)
    go _ _ _ _ _ = error "Fundamental error in @go@ function at 'modifyAt'"
-- | Like 'modifyAt', but the function additionally yields a value that is
-- returned next to the rebuilt tree.
bothAt :: (Serialize i) =>
       NodeTree i -> i -> NodeTreeFunction i (NodeTreeElem i, a)
    -> WithConfig (NodeTree i, a)
bothAt (NodeTree ownBits root) nid f = do
    targetBits <- toByteStruct nid
    (newRoot, val) <- go ownBits targetBits 0 True root
    return (NodeTree ownBits newRoot, val)
  where
    -- This helper is partial, but every path through the tree ends in a
    -- bucket, so running out of bits before reaching one is not expected.
    --
    -- Reached the bucket: hand it to the callback.
    go _ _ depth onOwnPath (Bucket contents) = f depth onOwnPath contents
    -- Inner element: a False target bit selects the left child, a True bit
    -- the right one; the untouched sibling is kept as it is.
    go (own:owns) (target:targets) depth onOwnPath (Split left right)
        | target    = rebuilt (Split left) right
        | otherwise = rebuilt (`Split` right) left
      where
        rebuilt wrap child = do
            (child', val) <-
                go owns targets (depth + 1) (onOwnPath && own == target) child
            return (wrap child', val)
    go _ _ _ _ _ = error "Fundamental error in @go@ function in 'bothAt'"
-- | Run a function on the bucket the supplied id would be located in and
-- return its result; the tree itself is left alone.
applyAt :: (Serialize i) => NodeTree i -> i -> NodeTreeFunction i a -> WithConfig a
applyAt (NodeTree ownBits root) nid f =
    toByteStruct nid >>= \targetBits -> go ownBits targetBits 0 True root
  where
    -- Partial for the same reason as the helper in 'modifyAt': a bucket is
    -- expected at the end of every path.
    --
    -- Reached the bucket: hand it to the callback.
    go _ _ depth onOwnPath (Bucket contents) = f depth onOwnPath contents
    -- A False target bit selects the left child, a True bit the right one.
    go (own:owns) (target:targets) depth onOwnPath (Split left right) =
        go owns targets (depth + 1) (onOwnPath && own == target) $
            if target then right else left
    go _ _ _ _ _ = error "Fundamental error in @go@ function in 'applyAt'"
-- | Build a fresh 'NodeTree' for the given id: a single bucket with no
-- stored nodes and an empty cache.
create :: (Serialize i) => i -> WithConfig (NodeTree i)
create nid = do
    ownBits <- toByteStruct nid
    return $ NodeTree ownBits (Bucket ([], []))
-- | Find the stored node with the given id, if there is one. Only the
-- node list of the matching bucket is searched, not its cache.
lookup :: (Serialize i, Eq i) => NodeTree i -> i -> WithConfig (Maybe (Node i))
lookup tree nid = applyAt tree nid $ \_ _ bucket ->
    return $ L.find (idMatches nid) [node | (node, _) <- fst bucket]
-- | Remove every stored node carrying the supplied id from its bucket. The
-- bucket's cache is left untouched.
delete :: (Serialize i, Eq i) => NodeTree i -> i -> WithConfig (NodeTree i)
delete tree nid = modifyAt tree nid dropMatching
  where
    dropMatching _ _ (nodes, cache) =
        return $ Bucket (survivors, cache)
      where
        survivors = [entry | entry@(node, _) <- nodes, not (idMatches nid node)]
-- | Handle a timed out node by incrementing its timeout count, or deleting
-- it once the count has reached the configured limit.
--
-- Returns the updated tree together with a flag telling whether it is
-- reasonable to ping the node again:
--
-- * node known, count at or above the limit: the node is removed, 'False'
-- * node known, count below the limit: the count is incremented and the
--   entry is moved to the front of its bucket, 'True'
-- * node unknown: the tree is unchanged, 'False'
--
-- The limit check uses @>=@ rather than @==@ so that an entry whose counter
-- is already past the limit (e.g. a tree restored through its 'Binary'
-- instance after the limit was lowered) is still evicted instead of being
-- re-pinged indefinitely.
--
-- NOTE(review): the limit is read from the @bucketSize@ config field; that
-- looks like an odd choice for a timeout limit -- confirm against
-- "Network.Kademlia.Config".
handleTimeout :: (Serialize i, Eq i) => NodeTree i -> i -> WithConfig (NodeTree i, Bool)
handleTimeout tree nid = do
    timeoutLimit <- bucketSize <$> getConfig
    let f _ _ (nodes, cache) = return $ case L.find (idMatches nid . fst) nodes of
            Just x@(n, timeoutCount)
                -- The node used up its allowance: drop it and don't
                -- contact it again, it is now considered dead
                | timeoutCount >= timeoutLimit ->
                    (Bucket (L.delete x nodes, cache), False)
                -- Otherwise record one more timeout
                | otherwise ->
                    (Bucket ((n, timeoutCount + 1) : L.delete x nodes, cache), True)
            -- Don't contact an unknown node a second time
            Nothing -> (Bucket (nodes, cache), False)
    bothAt tree nid f
-- | Refresh a node inside a bucket: the stored entry with the same id is
-- moved to the front of the node list and its timeout count is reset to 0.
-- If no entry matches, the node list is returned unchanged. The cache is
-- never touched. The result is wrapped as a @Bucket@ element.
refresh :: Eq i => Node i -> ([(Node i, Int)], [Node i]) -> NodeTreeElem i
refresh node (nodes, cache) = Bucket (reordered, cache)
  where
    reordered = maybe nodes moveToFront $
        L.find (idMatches (nodeId node) . fst) nodes
    -- Keep the stored node value, only the counter is replaced
    moveToFront hit@(known, _) = (known, 0) : L.delete hit nodes
-- | Insert a node into a NodeTree.
--
-- If the target bucket is full and may be split, it is split first and the
-- insertion is retried on the resulting tree. Otherwise the node is
-- refreshed, stored, or placed in the bucket's cache, in that order of
-- preference.
insert :: (Serialize i, Eq i) => NodeTree i -> Node i -> WithConfig (NodeTree i)
insert tree node = do
    maxNodes   <- k <$> getConfig
    cacheLimit <- bucketSize <$> getConfig
    let nid = nodeId node

        -- Decide whether the target bucket has to be split beforehand
        needsSplit depth valid (nodes, _) = do
            idBits <- toByteStruct nid
            let alreadyStored = node `elem` map fst nodes
                bucketFull    = length nodes >= maxNodes
                -- Buckets off the tree's own path are only split while
                -- they are shallower than 5 levels, and there has to be a
                -- bit left to split on
                splittable    = (depth < 5 || valid) && depth < length idBits
            return $ not alreadyStored && bucketFull && splittable

        -- Compute the new bucket contents when no split takes place
        place bucket@(nodes, cache)
            -- Already stored: refresh the existing entry
            | node `elem` map fst nodes = refresh node bucket
            -- Room left: store it with a fresh timeout count
            | length nodes < maxNodes   = Bucket ((node, 0) : nodes, cache)
            -- Already cached: move it to the front of the cache
            | node `elem` cache         = Bucket (nodes, node : L.delete node cache)
            -- Otherwise cache it, dropping older cache entries if necessary
            | otherwise                 = Bucket (nodes, node : take cacheLimit cache)

    splitFirst <- applyAt tree nid needsSplit
    if splitFirst
        then split tree nid >>= \grown -> insert grown node
        else modifyAt tree nid (\_ _ -> return . place)
-- | Replace the bucket the specified id would reside in by a @Split@ whose
-- two child buckets hold the former contents, divided by the id bit at the
-- bucket's depth: entries with a @False@ bit go left, those with a @True@
-- bit go right. Stored nodes and cached nodes are divided separately.
split :: (Serialize i) => NodeTree i -> i -> WithConfig (NodeTree i)
split tree splitId = modifyAt tree splitId divide
  where
    divide depth _ (nodes, cache) = do
        (leftNodes, rightNodes) <- splitBucket depth fst nodes
        (leftCache, rightCache) <- splitBucket depth id cache
        return $ Split
            (Bucket (leftNodes, leftCache))
            (Bucket (rightNodes, rightCache))

    -- Tag every entry with its id bit at the given index, then distribute
    -- the entries over the two sides, keeping their relative order
    splitBucket bitIx toNode entries = do
        bits <- mapM (fmap (!! bitIx) . toByteStruct . nodeId . toNode) entries
        let tagged = zip entries bits
        return ( [e | (e, goesRight) <- tagged, not goesRight]
               , [e | (e, goesRight) <- tagged, goesRight]
               )
-- | Pick up to @n@ nodes at random from the stored nodes of the tree,
-- skipping every node that occurs in @ignoreList@. The supplied generator
-- drives the shuffle.
pickupRandom
    :: (Eq i)
    => NodeTree i
    -> Int
    -> [Node i]
    -> StdGen
    -> [Node i]
pickupRandom _ 0 _ _ = []
pickupRandom tree n ignoreList randGen = L.genericTake n shuffled
  where
    candidates = [node | node <- toList tree, node `notElem` ignoreList]
    shuffled   = evalRand (shuffleM candidates) randGen
-- | Find the @n@ closest Nodes to a given Id.
--
-- The tree is walked along the bits of the target id. At every @Split@ the
-- child selected by the target bit is searched first; the sibling is only
-- searched when the first child could not supply @n@ nodes. Whatever was
-- collected is finally ordered with 'sortByDistanceTo' and cut down to @n@.
--
-- A child that is itself a @Split@ can return more than @n@ nodes (its two
-- halves are concatenated), so "enough" is tested with @>=@. Testing with
-- @==@ would miss that case and walk the sibling subtree for nothing.
findClosest
    :: (Serialize i)
    => NodeTree i
    -> i
    -> Int
    -> WithConfig [Node i]
findClosest (NodeTree idStruct treeElem) nid n = do
    targetStruct <- toByteStruct nid
    chooseClosest =<< go idStruct targetStruct treeElem
  where
    chooseClosest nodes = take n <$> sortByDistanceTo nodes nid

    -- This function is partial for the same reason as the helper in
    -- 'modifyAt': a bucket is expected at the end of every path.
    --
    -- Take the n closest nodes of a bucket; sorting is skipped when the
    -- bucket holds no more than n nodes anyway
    go _ _ (Bucket (nodes, _))
        | length nodes <= n = return $ map fst nodes
        | otherwise         = chooseClosest $ map fst nodes
    -- A False target bit makes the left child the near one, a True bit the
    -- right child
    go (_:is) (t:ts) (Split left right)
        | t         = nearThenFar right left
        | otherwise = nearThenFar left right
      where
        -- Take nodes from the near child first; if those aren't enough,
        -- add the ones from the far child
        nearThenFar near far = do
            result <- go is ts near
            if length result >= n
                then return result
                else (result ++) <$> go is ts far
    go _ _ _ = error "Fundamental error in @go@ function in 'findClosest'"
-- | Recover the id the tree was created for by decoding the 'ByteStruct'
-- stored alongside the tree.
extractId :: (Serialize i) => NodeTree i -> WithConfig i
extractId tree = case tree of
    NodeTree ownBits _ -> fromByteStruct ownBits
-- | Predicate used for bucket manipulation: does the node carry the given
-- id?
idMatches :: (Eq i) => i -> Node i -> Bool
idMatches nid = (nid ==) . nodeId
-- | Turn the NodeTree into a list of buckets (stored nodes only, without
-- timeout counts or caches), ordered by distance to the origin node: at
-- every @Split@ the child selected by the tree's own id bit is listed
-- before its sibling.
toView :: NodeTree i -> [[Node i]]
toView (NodeTree ownBits root) = collect ownBits root []
  where
    -- Own bit is 0: buckets of the left child come first, then the right
    collect (False:rest) (Split left right) acc =
        collect rest left (collect rest right acc)
    -- Own bit is 1: buckets of the right child come first
    collect (True:rest) (Split left right) acc =
        collect rest right (collect rest left acc)
    collect _ (Split _ _) _ = error "toView: unexpected Split"
    collect _ (Bucket (entries, _)) acc = map fst entries : acc
-- | Flatten the NodeTree into a single list of its stored nodes, in the
-- bucket order produced by 'toView'.
toList :: NodeTree i -> [Node i]
toList tree = concat (toView tree)
-- | Fold over the buckets of the tree, visiting the left child of every
-- @Split@ before the right one. The folding function receives the stored
-- nodes of a bucket (without timeout counts or cache) and the accumulator.
fold :: ([Node i] -> a -> a) -> a -> NodeTree i -> a
fold f start (NodeTree _ root) = walk root start
  where
    walk (Bucket contents) acc   = f [node | (node, _) <- fst contents] acc
    walk (Split left right) acc  = walk right (walk left acc)