In [1]:
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}

import Data.Bits             (Bits (bit, complement, popCount, shiftR, (.&.), (.|.)),
                              FiniteBits (finiteBitSize))
import Data.ByteArray.Hash   (FnvHash32 (..), fnv1Hash)
import Data.ByteString.Char8 (pack)
import Data.Char             (intToDigit)
import Data.Semigroup        ((<>))
import Data.Vector           (Vector, drop, singleton, take, (!), (//))
import Data.Word             (Word16, Word32)
import Numeric               (showIntAtBase)
import Prelude               hiding (drop, lookup, take)
import System.TimeIt         (timeIt)
import Text.Show.Pretty      (pPrint)

In [2]:
-- | Wrapper whose 'Show' instance renders the wrapped value as a fixed-width
-- string of binary digits. All numeric and bitwise classes are forwarded to
-- the underlying type through newtype deriving.
newtype Binary a = Binary a deriving (Enum, Ord, Real, Integral, Eq, Num, Bits, FiniteBits)

-- | Shows all 'finiteBitSize' bits, most significant first, zero-padded.
--
-- Each digit is read directly off the bit pattern rather than produced by
-- @showIntAtBase@, which raises an error for negative numbers. Unsigned
-- values print exactly as before; signed values now print their raw bit
-- pattern instead of crashing.
instance (FiniteBits a, Show a, Integral a) => Show (Binary a) where
    show (Binary a) = map digit [size - 1, size - 2 .. 0]
      where
        size = finiteBitSize a
        -- Shift bit @i@ into the lowest position and test it.
        digit i
            | shiftR a i .&. 1 == 1 = '1'
            | otherwise = '0'

In [3]:
-- | 32-bit hash of a key; shown as binary digits via 'Binary'.
type Hash = Binary Word32
-- | 16-bit bitmap with one bit per subkey slot of a node.
type Bitmap = Binary Word16
-- | Number of bits a hash is shifted right before a subkey is extracted;
-- grows by 'bitsPerSubkey' on each level of descent.
type Shift = Int

-- | Types that can be used as keys of a 'HAMT'.
class Hashable a where
    -- | Map a value to its 32-bit 'Hash'.
    hash :: a -> Hash

In [4]:
-- | Strings are hashed by packing them into a Char8 'ByteString' and taking
-- the 32-bit FNV-1 hash of the bytes.
--
-- NOTE(review): Char8 @pack@ is documented to truncate characters above
-- '\255' to 8 bits, so such strings would hash like their truncated
-- counterparts; lookups stay correct (keys are still compared with '=='),
-- only collisions become more likely -- confirm if non-Latin-1 keys matter.
instance Hashable String where
    hash = toHash . fnv1Hash . pack
      where
        toHash (FnvHash32 digest) = Binary digest

In [5]:
-- | Hash array mapped trie from @key@ to @value@.
data HAMT key value
    = None
    -- ^ The empty map.
    | Leaf Hash key value
    -- ^ A single entry together with the full hash of its key.
    | Many Bitmap (Vector (HAMT key value))
    -- ^ Sparse inner node: a set bit in the bitmap marks an occupied subkey
    -- slot; children are stored densely, ordered by slot number.
    | Full (Vector (HAMT key value))
    -- ^ Inner node with every slot occupied; children are indexed directly
    -- by subkey.
    | Coll Hash (Vector (key, value))
    -- ^ Entries whose distinct keys share one and the same full hash.
    deriving (Show)

-- | The map containing no entries.
empty :: HAMT key value
empty = None

In [6]:
-- | Number of hash bits consumed per trie level; each inner node therefore
-- has @2 ^ bitsPerSubkey@ slots.
bitsPerSubkey :: Int
bitsPerSubkey = 4

-- | Mask selecting the lowest 'bitsPerSubkey' bits.
subkeyMask :: Bitmap
subkeyMask = bit bitsPerSubkey - 1

-- | Bitmap in which every slot of a node is marked occupied.
--
-- With 4 bits per subkey this asks for bit 16 of a 16-bit word, which is 0;
-- subtracting 1 then wraps around to all ones. The construction therefore
-- depends on 'Bitmap' being exactly @2 ^ bitsPerSubkey@ bits wide (or wider).
fullMask :: Bitmap
fullMask = bit slotCount - 1
  where
    slotCount = 2 ^ bitsPerSubkey

-- | Slot number of a hash at a given depth: the hash is shifted right by
-- @shift@ bits and the lowest 'bitsPerSubkey' bits of the result are kept.
subkey :: Hash -> Shift -> Int
subkey keyHash shift = fromIntegral (narrowed .&. subkeyMask)
  where
    -- Truncate to the bitmap width first; only the low bits survive the mask.
    narrowed :: Bitmap
    narrowed = fromIntegral (keyHash `shiftR` shift)

-- | Position of a slot inside the dense child vector of a 'Many' node:
-- the number of occupied slots below the one selected by @mask@.
maskIndex :: Bitmap -> Bitmap -> Int
maskIndex bitmap mask = popCount lowerSlots
  where
    -- @mask@ has a single bit set, so @mask - 1@ covers all lower slots.
    lowerSlots = bitmap .&. (mask - 1)

-- | One-bit bitmap marking the slot a hash falls into at the given depth.
bitMask :: Hash -> Shift -> Bitmap
bitMask keyHash shift = bit slot
  where
    slot = subkey keyHash shift

In [7]:
-- | Copy of the vector with @a@ placed at @index@; elements from @index@
-- onwards move one position to the right.
insertAt :: Vector a -> Int -> a -> Vector a
insertAt vector index a = mconcat [before, singleton a, after]
  where
    before = take index vector
    after = drop index vector

-- | Copy of the vector with the element at @index@ replaced by @a@.
updateAt :: Vector a -> Int -> a -> Vector a
updateAt vector index a = vector // patch
  where
    patch = [(index, a)]

-- | Copy of the vector without the element at @index@.
deleteAt :: Vector a -> Int -> Vector a
deleteAt vector index = before <> after
  where
    before = take index vector
    after = drop (index + 1) vector

In [8]:
-- | Insert a key/value pair, replacing the value of an already-present key.
insert :: (Hashable key, Eq key) => key -> value -> HAMT key value -> HAMT key value
insert key value = insert' 0 (hash key) key value

-- | Worker for 'insert'.
--
-- @shift@ is the number of hash bits already consumed on the way down to
-- @node@, and @keyHash@ is the precomputed hash of @key@.
insert' :: Eq key => Shift -> Hash -> key -> value -> HAMT key value -> HAMT key value
insert' shift keyHash key value node = case node of
    None -> fresh
    Leaf leafHash leafKey leafValue
        -- Different hash: turn the leaf into a one-child node and retry.
        | keyHash /= leafHash -> pushDown leafHash
        -- Same key: overwrite.
        | key == leafKey -> fresh
        -- Same hash, different key: start a collision bucket, new entry first.
        | otherwise -> Coll keyHash (singleton (key, value) <> singleton (leafKey, leafValue))
    Many bitmap children
        -- Slot already occupied: recurse into the child living there.
        | bitmap .&. slot /= 0 -> Many bitmap (descend children position)
        -- Slot was free and filling it occupies every slot: promote to Full.
        | widened == fullMask -> Full grown
        | otherwise -> Many widened grown
      where
        position = maskIndex bitmap slot
        widened = bitmap .|. slot
        grown = insertAt children position fresh
    Full children -> Full (descend children (subkey keyHash shift))
    Coll collHash entries
        -- Different hash: wrap the bucket in a one-child node and retry.
        | keyHash /= collHash -> pushDown collHash
        | otherwise -> Coll collHash (upsert 0)
      where
        -- Replace the first entry with an equal key, or prepend a new entry
        -- when the scan reaches the end without a match.
        upsert i
            | i == length entries = insertAt entries 0 (key, value)
            | fst (entries ! i) == key = updateAt entries i (key, value)
            | otherwise = upsert (i + 1)
  where
    fresh = Leaf keyHash key value
    slot = bitMask keyHash shift
    -- Re-home the current node under a Many keyed by the resident hash and
    -- insert again at the same depth.
    pushDown residentHash =
        insert' shift keyHash key value (Many (bitMask residentHash shift) (singleton node))
    -- Insert into child @i@ one level deeper and store the result back.
    descend children i =
        updateAt children i (insert' (shift + bitsPerSubkey) keyHash key value (children ! i))

In [9]:
-- | Build a map from a list of pairs.
--
-- The right fold inserts the last pair first and the first pair last, so
-- when a key occurs more than once the leftmost occurrence determines the
-- stored value.
fromList :: (Hashable key, Eq key) => [(key, value)] -> HAMT key value
fromList pairs = foldr (\(key, value) acc -> insert key value acc) empty pairs

In [10]:
-- | Value stored under @key@, or 'Nothing' when the key is absent.
lookup :: (Hashable key, Eq key) => key -> HAMT key value -> Maybe value
lookup key = lookup' 0 (hash key) key

-- | Worker for 'lookup'.
--
-- @shift@ is the number of hash bits already consumed on the way down to
-- @node@, and @keyHash@ is the precomputed hash of @key@.
lookup' :: Eq key => Shift -> Hash -> key -> HAMT key value -> Maybe value
lookup' shift keyHash key node = case node of
    None -> Nothing
    Leaf leafHash leafKey leafValue
        | keyHash == leafHash && key == leafKey -> Just leafValue
        | otherwise -> Nothing
    Many bitmap children
        -- Slot not occupied: the key cannot be below this node.
        | bitmap .&. slot == 0 -> Nothing
        | otherwise -> descend (children ! maskIndex bitmap slot)
    Full children -> descend (children ! subkey keyHash shift)
    Coll collHash entries
        | keyHash == collHash -> scan 0
        | otherwise -> Nothing
      where
        -- Linear search through the collision bucket.
        scan i
            | i == length entries = Nothing
            | otherwise = case entries ! i of
                (entryKey, entryValue)
                    | entryKey == key -> Just entryValue
                    | otherwise -> scan (i + 1)
  where
    slot = bitMask keyHash shift
    descend child = lookup' (shift + bitsPerSubkey) keyHash key child

In [11]:
-- | Naive doubly-recursive Fibonacci with @fibSlow 0 == fibSlow 1 == 1@.
--
-- Deliberately unmemoised; it serves as the slow baseline in 'main'.
-- A negative argument used to recurse forever because neither base case
-- could ever be reached; it is now rejected with an 'error'.
fibSlow :: Int -> Int
fibSlow n
    | n < 0 = error "fibSlow: negative index"
    | n < 2 = 1
    | otherwise = fibSlow (n - 1) + fibSlow (n - 2)

In [12]:
-- | An 'Int' is its own hash, converted to the 32-bit hash type.
instance Hashable Int where
    hash = Binary . fromIntegral

-- | Memoised Fibonacci step: returns the @n@-th number together with the
-- memo table extended by every value computed along the way.
--
-- Expects a non-negative @n@; for a negative one neither base case is ever
-- reached and the recursion keeps counting down.
fib' :: HAMT Int Integer -> Int -> (Integer, HAMT Int Integer)
fib' table n
    | n == 0 || n == 1 = (1, insert n 1 table)
    | otherwise = maybe compute (\known -> (known, table)) (lookup n table)
  where
    -- Thread the table through both recursive calls so the second one can
    -- reuse what the first one stored.
    compute =
        let (prev1, afterFirst) = fib' table (n - 1)
            (prev2, afterSecond) = fib' afterFirst (n - 2)
            total = prev1 + prev2
        in (total, insert n total afterSecond)

-- | Fibonacci via 'fib'' starting from an empty memo table, with
-- @fibFast 0 == fibFast 1 == 1@.
--
-- A negative argument used to diverge inside 'fib'' (its base cases are
-- never reached when counting down from below zero); it is now rejected
-- with an 'error' before any work is done.
fibFast :: Int -> Integer
fibFast n
    | n < 0 = error "fibFast: negative index"
    | otherwise = fst (fib' empty n)

In [13]:
-- | Remove @key@ and its value; a map without the key is returned unchanged.
delete :: (Hashable key, Eq key) => key -> HAMT key value -> HAMT key value
delete key = delete' 0 (hash key) key

-- | Worker for 'delete'.
--
-- @shift@ is the number of hash bits already consumed on the way down to
-- @node@, and @keyHash@ is the precomputed hash of @key@.
delete' :: Eq key => Shift -> Hash -> key -> HAMT key value -> HAMT key value
delete' shift keyHash key node = case node of
    None -> None
    Leaf leafHash leafKey _
        | keyHash == leafHash && key == leafKey -> None
        | otherwise -> node
    Many bitmap children
        -- Slot not occupied: nothing to delete below this node.
        | bitmap .&. slot == 0 -> node
        | otherwise -> case pruned of
            -- The child vanished: drop its slot, or the whole node when it
            -- was the only child.
            None
                | onlyChild -> None
                | otherwise -> Many (bitmap .&. complement slot) (deleteAt children position)
            -- A lone leaf needs no parent node; hand it to the caller.
            Leaf{}
                | onlyChild -> pruned
            _ -> Many bitmap (updateAt children position pruned)
      where
        position = maskIndex bitmap slot
        pruned = below (children ! position)
        onlyChild = length children == 1
    Full children -> case below (children ! position) of
        -- Losing a child means the node is no longer full.
        None -> Many (fullMask .&. complement slot) (deleteAt children position)
        pruned -> Full (updateAt children position pruned)
      where
        position = subkey keyHash shift
    Coll collHash entries
        | keyHash /= collHash -> node
        -- A bucket with a single survivor degenerates into a leaf.
        | length remaining == 1 -> case remaining ! 0 of
            (soleKey, soleValue) -> Leaf collHash soleKey soleValue
        | otherwise -> Coll collHash remaining
      where
        remaining = dropMatch 0
        -- Remove the first entry with an equal key; leave the bucket
        -- untouched when there is none.
        dropMatch i
            | i == length entries = entries
            | fst (entries ! i) == key = deleteAt entries i
            | otherwise = dropMatch (i + 1)
  where
    slot = bitMask keyHash shift
    below child = delete' (shift + bitsPerSubkey) keyHash key child

In [14]:
-- | Demo: build a small map, look a key up, time both Fibonacci variants
-- and show the structure after one and after two deletions.
main :: IO ()
main = do
    pPrint example
    print (lookup "100" example)
    timeIt (print (fibSlow 30))
    timeIt (print (fibFast 30))
    pPrint without1000
    pPrint (delete "10" without1000)
  where
    example = fromList [("1", 1), ("10", 2), ("100", 3), ("1000", 4)]
    without1000 = delete "1000" example

In [15]:
-- Sixteen string keys, each mapped to the empty string.
ls = fromList [(show i, "") | i <- [0..9] <> [18,19] <> [22,23] <> [26,27]]

In [16]:
-- Pretty-print the map; the output below shows a single Full node holding
-- sixteen leaves.
pPrint ls

Full
  [ Leaf 00011111011101101010110111100000 "27" ""
  , Leaf 00011111011101101010110111100001 "26" ""
  , Leaf 00100000011101101010111101010010 "18" ""
  , Leaf 00100000011101101010111101010011 "19" ""
  , Leaf 00011111011101101010110111100100 "23" ""
  , Leaf 00011111011101101010110111100101 "22" ""
  , Leaf 00000101000011000101110100100110 "9" ""
  , Leaf 00000101000011000101110100100111 "8" ""
  , Leaf 00000101000011000101110100101000 "7" ""
  , Leaf 00000101000011000101110100101001 "6" ""
  , Leaf 00000101000011000101110100101010 "5" ""
  , Leaf 00000101000011000101110100101011 "4" ""
  , Leaf 00000101000011000101110100101100 "3" ""
  , Leaf 00000101000011000101110100101101 "2" ""
  , Leaf 00000101000011000101110100101110 "1" ""
  , Leaf 00000101000011000101110100101111 "0" ""
  ]

In [17]:
-- Removing one key from the Full node turns it back into a Many node.
pPrint (delete "4" ls)

Many
  1111011111111111
  [ Leaf 00011111011101101010110111100000 "27" ""
  , Leaf 00011111011101101010110111100001 "26" ""
  , Leaf 00100000011101101010111101010010 "18" ""
  , Leaf 00100000011101101010111101010011 "19" ""
  , Leaf 00011111011101101010110111100100 "23" ""
  , Leaf 00011111011101101010110111100101 "22" ""
  , Leaf 00000101000011000101110100100110 "9" ""
  , Leaf 00000101000011000101110100100111 "8" ""
  , Leaf 00000101000011000101110100101000 "7" ""
  , Leaf 00000101000011000101110100101001 "6" ""
  , Leaf 00000101000011000101110100101010 "5" ""
  , Leaf 00000101000011000101110100101100 "3" ""
  , Leaf 00000101000011000101110100101101 "2" ""
  , Leaf 00000101000011000101110100101110 "1" ""
  , Leaf 00000101000011000101110100101111 "0" ""
  ]

In [18]:
-- | Test key whose hash is always zero, so all values collide.
newtype Colliding = Colliding Int
    deriving (Eq, Show)

instance Hashable Colliding where
    hash _ = 0

-- | Test key whose hash is the parity of the wrapped number, so all even
-- values collide with each other and all odd values collide with each other.
newtype CollidingHalf = CollidingHalf Int
    deriving (Eq, Show)

instance Hashable CollidingHalf where
    -- 'mod' rather than 'rem': for a negative odd number 'rem' yields -1,
    -- which would convert to a third hash value instead of 1.
    hash (CollidingHalf i) = fromIntegral (i `mod` 2)

In [19]:
-- Two keys with identical hashes end up in one collision bucket; deleting
-- one of them leaves a plain leaf.
ls = fromList [(Colliding i, "") | i <- [0, 1]]

pPrint ls
pPrint (delete (Colliding 0) ls)

Coll
  00000000000000000000000000000000
  [ ( Colliding 0 , "" ) , ( Colliding 1 , "" ) ]

Leaf 00000000000000000000000000000000 (Colliding 1) ""

In [21]:
-- Four keys falling into two collision buckets; deleting one key from each
-- bucket turns both buckets into leaves.
ls = fromList [(CollidingHalf i, "") | i <- [0 .. 3]]

pPrint ls
pPrint (delete (CollidingHalf 1) (delete (CollidingHalf 0) ls))

Many
  0000000000000011
  [ Coll
      00000000000000000000000000000000
      [ ( CollidingHalf 0 , "" ) , ( CollidingHalf 2 , "" ) ]
  , Coll
      00000000000000000000000000000001
      [ ( CollidingHalf 1 , "" ) , ( CollidingHalf 3 , "" ) ]
  ]

Many
  0000000000000011
  [ Leaf 00000000000000000000000000000000 (CollidingHalf 2) ""
  , Leaf 00000000000000000000000000000001 (CollidingHalf 3) ""
  ]