In [None]:
{-# LANGUAGE ScopedTypeVariables #-}

import Data.Hashable
import qualified Data.HashMap.Strict as HM

import qualified Data.BitVector.LittleEndian as BV (rank, select)
import Data.BitVector.LittleEndian hiding (rank, select)
import Data.Bits

import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV

import Control.Monad.ST
import Data.Traversable
import Data.Foldable
import Data.STRef
import Data.Maybe (fromJust)

In [None]:
-- | Number of set bits at indices @0 .. w@, including index @w@ itself.
--
-- The index is shifted up by one before delegating to 'BV.rank'. This
-- assumes 'BV.rank' counts bits strictly below its index, as the bv-little
-- documentation describes.
rank :: BitVector -> Word -> Word
rank bv = BV.rank bv . (+ 1)

-- | Position of the @w@-th set bit, counting set bits from 1.
--
-- The ordinal is shifted down by one before delegating to 'BV.select',
-- which counts from 0. Note that @w = 0@ wraps around to 'maxBound'.
select :: BitVector -> Word -> Maybe Word
select bv = BV.select bv . subtract 1

In [None]:
-- | Positions of every set bit in @bv@, in ascending order.
--
-- The k-th set bit is located with the 1-based 'select' wrapper for
-- k = 1 .. popCount bv. Every such k names an existing set bit, so no
-- lookup is dropped. Using 'V.mapMaybe' avoids the partial
-- @Just is = ...@ pattern the previous version relied on.
getIndices :: BitVector -> V.Vector Word
getIndices bv = V.mapMaybe (select bv) ordinals
    where
        -- Typed as Word so the range does not default to Integer and need
        -- converting back for 'select'.
        ordinals :: V.Vector Word
        ordinals = V.enumFromN 1 (popCount bv)

-- | Gather the elements of @vector@ found at each position listed in
-- @indices@, keeping the order of @indices@. Positions are 0-based and are
-- looked up with the bounds-checked 'V.!'.
extract :: V.Vector a -> V.Vector Word -> V.Vector a
extract vector indices = V.map (\i -> vector V.! fromIntegral i) indices

-- | Remove every key listed in the vector from the hash map.
--
-- A strict left fold deletes one key at a time. The previous right fold
-- nested one pending 'HM.delete' per key before any of them could run.
-- Keys absent from the map are ignored, as with 'HM.delete'.
pluck :: (Hashable k, Eq k) => HM.HashMap k v -> V.Vector k -> HM.HashMap k v
pluck = foldl' (flip HM.delete)

In [None]:
-- | Run one level of the construction.
--
-- Every key of @hashmap@ is hashed with salt @level@ into a table of
-- @max 1 (floor (size * gamma))@ slots. A slot hit by exactly one key keeps
-- that key. A slot hit by two or more keys is marked as a collision and
-- keeps none.
--
-- Returns a triple of:
--
--   * a bit vector whose bit @i@ is set iff slot @i@ received exactly one key;
--   * the values of those keys, ordered by ascending slot index;
--   * the entries whose keys were not placed, for the next level to handle.
step :: (Hashable k, Eq k) => HM.HashMap k v -> Int -> Double -> (BitVector, V.Vector v, HM.HashMap k v)
step hashmap level gamma = runST $ do
    -- Use at least one slot. A zero-length table would make the `mod` below
    -- divide by zero whenever size * gamma < 1 (e.g. one key with gamma 0.5).
    let vectorSize = max 1 (floor (fromIntegral (HM.size hashmap) * gamma))
    hashVector <- MV.replicate vectorSize False
    collisionVector <- MV.replicate vectorSize False
    -- Placeholder keys: only slots whose hash bit ends up set are read back
    -- (via getIndices), and those slots always hold a real key.
    keysVector <- MV.replicate vectorSize (undefined :: k)
    for_ (HM.keys hashmap) $ \key -> do
        -- `mod` with a positive divisor is non-negative even for negative hashes.
        let position = hashWithSalt level key `mod` vectorSize
        present <- MV.read hashVector position
        collision <- MV.read collisionVector position
        case (present, collision) of
            -- First key in this slot: claim it.
            (False, False) -> do
                MV.write hashVector position True
                MV.write keysVector position key
            -- Second key: the slot becomes a collision. The stale key left in
            -- keysVector is never read because its hash bit is now False.
            (True, False) -> do
                MV.write hashVector position False
                MV.write collisionVector position True
            -- Slot already collided: nothing further to record.
            (False, True) -> pure ()
            (True, True) -> error "this should never happen"
    bitVector <- fromBits <$> V.freeze hashVector
    finalKeys <- V.freeze keysVector
    let uniqueKeys = extract finalKeys (getIndices bitVector)
        valuesVector = V.map (hashmap HM.!) uniqueKeys
        leftover = pluck hashmap uniqueKeys
    pure (bitVector, valuesVector, leftover)

-- | Put every remaining entry into the fallback table.
--
-- Entries are taken in 'HM.toList' order. The i-th entry's key maps to the
-- 1-based index i, and its value sits at position i - 1 of the returned
-- vector.
finalise :: (Hashable k, Eq k) => HM.HashMap k v -> (HM.HashMap k Int, V.Vector v)
finalise hashmap = (HM.fromList (zip entryKeys [1 ..]), V.fromList entryValues)
    where
        (entryKeys, entryValues) = unzip (HM.toList hashmap)

In [None]:
-- | A minimal perfect hash over keys of type @k@.
--
-- It is produced by 'generate' together with a value vector, and is looked
-- up against that vector with 'query'.
data MinimalPerfectHash k = MinimalPerfectHash
    { mphBitVectors :: V.Vector BitVector
      -- ^ One bit vector per hashing level. Bit @i@ of a level is set when
      -- slot @i@ at that level received exactly one key.
    , mphLeftovers :: Maybe (HM.HashMap k Int)
      -- ^ Keys no level could place, mapped to 1-based positions after all
      -- level values. 'Nothing' when every key was placed by some level.
    }
    deriving (Eq, Show)

-- | Build a minimal perfect hash for the keys of @hashmap@.
--
-- Runs 'step' at levels 0, 1, ... with load factor @gamma@. It stops when
-- every entry is placed or when @maxLevel@ levels have run. Any entries still
-- unplaced go into the leftover table built by 'finalise'.
--
-- The returned value vector holds each level's values in level order,
-- followed by the leftover values. This is the layout 'query' indexes into.
generate :: forall k v. (Hashable k, Eq k) => HM.HashMap k v -> Int -> Double -> (MinimalPerfectHash k, V.Vector v)
generate hashmap maxLevel gamma = build hashmap 0 [] []
    where
        -- Per-level results are consed onto the front (newest first) and
        -- reversed once when the final structure is assembled.
        build :: HM.HashMap k v -> Int -> [BitVector] -> [V.Vector v] -> (MinimalPerfectHash k, V.Vector v)
        build pending level revBits revValues
            | HM.null pending =
                ( MinimalPerfectHash (levelsOf revBits) Nothing
                , V.concat (reverse revValues)
                )
            | level >= maxLevel =
                let (leftoverIndices, leftoverValues) = finalise pending
                in  ( MinimalPerfectHash (levelsOf revBits) (Just leftoverIndices)
                    , V.concat (reverse (leftoverValues : revValues))
                    )
            | otherwise =
                let (bits, placed, rest) = step pending level gamma
                in  build rest (level + 1) (bits : revBits) (placed : revValues)

        -- Restore level order (level 0 first).
        levelsOf :: [BitVector] -> V.Vector BitVector
        levelsOf = V.fromList . reverse

In [None]:
-- | Look up the value stored for @key@ in a structure built by 'generate'.
--
-- Levels are probed in order. At level @l@ the key is hashed with salt @l@
-- into that level's bit vector. If that bit is set, the answer is at index
-- (count of values owned by earlier levels) + rank - 1. Otherwise the key
-- falls through to the next level, and finally to the leftover table.
--
-- Errors:
--
--   * "key not in table" if every level misses and there is no leftover table;
--   * via 'HM.!' if the leftover table lacks the key.
--
-- No key comparison happens at the levels. A key that was never inserted may
-- therefore hit an occupied slot and return another key's value.
query :: forall k v. (Hashable k, Eq k) => MinimalPerfectHash k -> V.Vector v -> k -> v
query mph values key = descend 0 0
    where
        levels = mphBitVectors mph

        -- Walk the levels. `offset` counts the values owned by levels
        -- already passed.
        descend :: Int -> Int -> v
        descend level offset
            | level < V.length levels = probe level offset (levels V.! level)
            | otherwise = case mphLeftovers mph of
                Nothing -> error "key not in table"
                Just leftovers -> values V.! (offset + (leftovers HM.! key) - 1)

        -- Check the key's slot at one level; descend further on a miss.
        probe :: Int -> Int -> BitVector -> v
        probe level offset bits
            | testBit bits slot = values V.! (offset + fromIntegral (rank bits (fromIntegral slot)) - 1)
            | otherwise = descend (level + 1) (offset + popCount bits)
          where
            slot = hashWithSalt level key `mod` fromIntegral (dimension bits)

In [None]:
-- | A three-entry sample table with single-character String keys.
example :: HM.HashMap String String
example = HM.fromList
    [ ("f", "foo")
    , ("b", "bar")
    , ("q", "quux")
    ]

-- At most 2 hashing levels. gamma = 1 gives one slot per remaining key.
(mph, values) = generate example 2 1

query mph values "b"