Permalink
Fetching contributors…
Cannot retrieve contributors at this time
696 lines (657 sloc) 20.3 KB
-- -*- lua -*-
--
-- SipHash is a family of hash functions, parameterized by the number of
-- rounds that run per 8-byte input block and the number of rounds that
-- run at the end. SipHash was designed to be good for short inputs and
-- to resist "hash flooding" attacks, where users attack a hash table
-- with many inputs that map to the same area of the hash table.
--
-- We provide implementations of SipHash with a normal scalar DynASM
-- backend as well as parallel SSE and AVX2 backends.
--
-- Because we use SipHash as a hash function for fixed-sized inputs, we
-- can simplify processing of the tail word. This simplification is
-- enabled unless a true value for "as_specified" is passed.
--
-- This implementation was based on the reference implementation. See
-- https://131002.net/siphash/ for more details.
module(..., package.seeall)
local bit = require("bit")
local dasm = require("dasm")
local ffi = require("ffi")
local S = require("syscall")
local lib = require("core.lib")
-- When true, finish() prints a dump of every piece of generated machine
-- code. Note that this local shadows Lua's standard `debug` library
-- for the rest of the file.
local debug = false
-- CPU feature detection: the parallel backends are selected by
-- searching the text of /proc/cpuinfo for feature names.
local cpuinfo = lib.readfile("/proc/cpuinfo", "*a")
assert(cpuinfo, "failed to read /proc/cpuinfo for hardware check")
local have_avx2 = cpuinfo:match("avx2")
local have_sse2 = cpuinfo:match("sse2")
-- DynASM directives: target x86-64 and name the generated action list
-- `actions`, which is what the backends pass to dasm.new().
|.arch x64
|.actionlist actions
-- Generated machine code and the cdata it refers to are inserted into
-- this table, presumably so that they stay reachable and are never
-- garbage collected. It is assigned without `local`.
__anchor = {}
-- Turn a finished DynASM state into a callable C function pointer.
--
-- name:      label printed in front of the debug dump.
-- prototype: FFI function-pointer type to cast the machine code to.
-- Dst:       DynASM state holding the emitted code.
--
-- The machine code buffer is appended to __anchor before it is cast.
local function finish (name, prototype, Dst)
   local mcode, size = Dst:build()
   __anchor[#__anchor + 1] = mcode
   if debug then
      print("mcode dump: "..name)
      dasm.dump(mcode, size)
   end
   local callable = ffi.cast(prototype, mcode)
   return callable
end
-- Return a fresh random SipHash key: 16 bytes obtained from
-- lib.random_bytes.
function random_sip_hash_key()
   local key_size = 16
   return lib.random_bytes(key_size)
end
-- Build a deterministic 16-byte SipHash key from a numeric seed.
--
-- The seed (anything tonumber() accepts) is stored as a uint32_t in the
-- first four bytes of the key; the remaining twelve bytes stay zero.
-- Raises an error if the seed cannot be converted to a number.
function sip_hash_key_from_seed(seed)
   local key = ffi.new('uint8_t[16]')
   local words = ffi.cast('uint32_t*', key)
   words[0] = assert(tonumber(seed))
   return key
end
-- Compute the four 64-bit words of SipHash's initial state (v0..v3).
--
-- key: 16 bytes of key material (anything ffi.cast can view as a
--      uint64_t pointer), or a false value to use a fresh random key.
--
-- Returns a uint64_t[4] holding the standard initial constants XORed
-- with the two key words: v0 and v2 are mixed with the first key word,
-- v1 and v3 with the second.
local function load_initial_state(key)
   -- Initial state constants.
   local state = ffi.new('uint64_t[4]',
                         { 0x736f6d6570736575ULL,
                           0x646f72616e646f6dULL,
                           0x6c7967656e657261ULL,
                           0x7465646279746573ULL })
   -- Keep the key buffer in its own local while we read from it. A
   -- pointer produced by ffi.cast does not keep the underlying cdata
   -- alive, so a freshly generated random key must stay referenced
   -- until the last load below.
   local key_bytes = key or random_sip_hash_key()
   local key_words = ffi.cast('uint64_t*', key_bytes)
   -- Mix key into state constants.
   state[0] = bit.bxor(state[0], key_words[0])
   state[1] = bit.bxor(state[1], key_words[1])
   state[2] = bit.bxor(state[2], key_words[0])
   state[3] = bit.bxor(state[3], key_words[1])
   return state
end
-- Scalar x86-64 backend for the SipHash implementation.
--
-- X86_64() returns a table of emitter functions (init, add, shl, ior,
-- xor, rol, copy_argument, load_*_and_advance, load_imm8, ret) plus an
-- assemble method. The emitters take small integer "value indexes"
-- rather than register names; regnum() below translates an index into
-- a hardware register number. The assembled function is cast to
-- fn_ptr_t: it takes a pointer to the input bytes and returns a
-- uint32_t.
local fn_ptr_t = ffi.typeof("uint32_t (*)(void *)")
local function X86_64()
local asm = {}
-- DynASM state of the function being assembled; set by asm:assemble.
-- The `|` lines are expanded by the DynASM preprocessor and presumably
-- emit into this variable.
local Dst
-- General-purpose registers in hardware encoding order, so that a
-- register's number is its position in this list minus one.
local allregs = {"rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15"}
local regnums = {}
for i,reg in ipairs(allregs) do regnums[reg] = i-1 end
-- Map value indexes from the program to registers. Don't allocate
-- to rdi; it's the input data.
-- Value index n (0-based) uses scratchregs[n+1]: 0-3 are r8-r11,
-- 4 is rax, 5 is rcx, 6 is rdx and 7 is rsi.
local scratchregs = {"r8", "r9", "r10", "r11", "rax", "rcx", "rdx", "rsi"}
-- Translate a value index into a register number; asserts that the
-- index is within scratchregs.
local function regnum(n) return assert(regnums[assert(scratchregs[n+1])]) end
-- Emit code loading the four key-mixed initial state words into value
-- indexes 0-3 as 64-bit immediates.
function asm.init(key)
local initial_state = load_initial_state(key)
table.insert(__anchor, initial_state)
| mov64 Rq(regnum(0)), initial_state[0]
| mov64 Rq(regnum(1)), initial_state[1]
| mov64 Rq(regnum(2)), initial_state[2]
| mov64 Rq(regnum(3)), initial_state[3]
end
-- dst = dst + other (64-bit).
function asm.add(dst, other)
| add Rq(regnum(dst)), Rq(regnum(other))
end
-- reg = reg << bits.
function asm.shl(reg, bits)
| shl Rq(regnum(reg)), bits
end
-- dst = dst | other.
function asm.ior(dst, other)
| or Rq(regnum(dst)), Rq(regnum(other))
end
-- dst = dst ~ other (exclusive or).
function asm.xor(dst, other)
| xor Rq(regnum(dst)), Rq(regnum(other))
end
-- Rotate reg left by bits.
function asm.rol(reg, bits)
| rol Rq(regnum(reg)), bits
end
-- Copy function argument number `arg` (1 is rdi, 2 is rsi) into value
-- index dst; nothing is emitted when they are already the same
-- register.
function asm.copy_argument(dst, arg)
-- For use by make_sip_hash_u64.
local arg = assert(({"rdi", "rsi"})[arg])
if regnum(dst) ~= regnums[arg] then
| mov Rq(regnum(dst)), Rq(regnums[arg])
end
end
-- The load_*_and_advance emitters read an unsigned value of the given
-- width from [rdi] into dst and then bump rdi past it.
function asm.load_u64_and_advance(dst)
| mov Rq(regnum(dst)), [rdi]
| add rdi, 8
end
function asm.load_u32_and_advance(dst)
-- Automatically zero-extending.
| mov Rd(regnum(dst)), [rdi]
| add rdi, 4
end
function asm.load_u16_and_advance(dst)
| movzx Rq(regnum(dst)), word [rdi]
| add rdi, 2
end
function asm.load_u8_and_advance(dst)
| movzx Rq(regnum(dst)), byte [rdi]
| add rdi, 1
end
-- Set dst to the 8-bit immediate imm: clear the whole register, then
-- write the low byte when imm is non-zero.
function asm.load_imm8(dst, imm)
| xor Rq(regnum(dst)), Rq(regnum(dst))
if imm ~= 0 then
| mov Rb(regnum(dst)), imm
end
end
-- Return the hash held in value index reg: its upper 32 bits, shifted
-- left by one bit, end up in eax.
function asm.ret(reg)
if regnum(reg) ~= regnums["rax"] then
| mov rax, Rq(regnum(reg))
end
| shr rax, 32
-- The ctable needs the hash to not be 0xFFFFFFFF. If we did the
-- lshift there instead of here, we'd have to compensate for that
-- weird thing where LuaJIT's bitops can produce negative numbers.
| shl eax, 1
| ret
end
-- Build the machine code and cast it to fn_ptr_t. asm:assemble goes
-- through self.finish, so a caller may replace this function to use a
-- different prototype.
function asm.finish(name, Dst)
return finish(name.."_x1",
fn_ptr_t,
Dst)
end
-- Create a fresh DynASM state, run the generator `gen` against this
-- backend and return the finished function.
function asm:assemble(name, gen)
Dst = dasm.new(actions)
gen(self)
return self.finish(name, Dst)
end
return asm
end
-- Parallel SSE backend for the SipHash implementation that can hash two
-- inputs at once.
--
-- stride is the distance in bytes from the first input to the second:
-- every load reads one value at [rdi] and one at [rdi+stride]. Value
-- index n lives in register xmm<n>, with the first input in the low
-- 64-bit lane and the second in the high lane. xmm5 and rax are used
-- as temporaries by some emitters.
-- NOTE(review): several instructions used here (movddup, pinsrb,
-- pinsrd, pinsrq, pextrd) are, as far as I know, SSE3/SSE4.1 rather
-- than SSE2 -- confirm that the CPU feature check used to select this
-- backend covers them.
local function SSE(stride)
local asm = {}
-- DynASM state of the function being assembled; set by asm:assemble.
local Dst
-- Emit code that loads each of the four initial state words into both
-- lanes of xmm0-xmm3, going through rax for the address.
function asm.init(key)
local initial_state = load_initial_state(key)
table.insert(__anchor, initial_state)
| mov64 rax, initial_state
| movddup xmm0, [rax]
| movddup xmm1, [rax+8]
| movddup xmm2, [rax+16]
| movddup xmm3, [rax+24]
end
-- Lane-wise 64-bit add: dst = dst + other.
function asm.add(dst, other)
| paddq xmm(dst), xmm(other)
end
-- Lane-wise 64-bit shift left.
function asm.shl(reg, bits)
| psllq xmm(reg), bits
end
-- dst = dst | other.
function asm.ior(dst, other)
| por xmm(dst), xmm(other)
end
-- dst = dst ~ other (exclusive or).
function asm.xor(dst, other)
| pxor xmm(dst), xmm(other)
end
-- Lane-wise 64-bit rotate left, built from two shifts and an or, with
-- xmm5 as the temporary.
function asm.rol(reg, bits)
| movupd xmm5, xmm(reg)
| psllq xmm5, bits
| psrlq xmm(reg), (64 - bits)
| por xmm(reg), xmm5
end
-- The load_*_and_advance emitters read one value per input (low lane
-- from [rdi], high lane from [rdi+stride]) and then bump rdi.
function asm.load_u64_and_advance(dst)
| movlpd xmm(dst), qword [rdi]
| movhpd xmm(dst), qword [rdi+stride]
| add rdi, 8
end
-- The second input's dword goes to 32-bit element 2, i.e. the low half
-- of the high lane. (movss from memory presumably clears the rest of
-- the register, which is why there is no explicit pxor here.)
function asm.load_u32_and_advance(dst)
| movss xmm(dst), dword [rdi]
| pinsrd xmm(dst), dword [rdi + stride], 2
| add rdi, 4
end
-- Clear dst, then insert at 16-bit elements 0 and 4: the bottom of
-- each 64-bit lane.
function asm.load_u16_and_advance(dst)
| pxor xmm(dst), xmm(dst)
| pinsrw xmm(dst), word [rdi], 0
| pinsrw xmm(dst), word [rdi + stride], 4
| add rdi, 2
end
-- Clear dst, then insert at bytes 0 and 8: the bottom of each 64-bit
-- lane.
function asm.load_u8_and_advance(dst)
| pxor xmm(dst), xmm(dst)
| pinsrb xmm(dst), byte [rdi], 0
| pinsrb xmm(dst), byte [rdi + stride], 8
| add rdi, 1
end
-- Set both lanes of dst to the 8-bit immediate imm, built in rax.
function asm.load_imm8(dst, imm)
| xor rax, rax
if imm ~= 0 then
| mov al, imm
end
| pinsrq xmm(dst), rax, 0
| pinsrq xmm(dst), rax, 1
end
function asm.ret(reg)
-- Extract high 31 bits from each 64-bit value to uint32_t[2] in
-- rsi.
-- Every 32-bit element is shifted left by one first; elements 1 and
-- 3 are the upper halves of the two lanes.
| pslld xmm(reg), 1
| pextrd dword [rsi], xmm(reg), 1
| pextrd dword [rsi+4], xmm(reg), 3
| ret
end
-- Create a fresh DynASM state, run the generator `gen` against this
-- backend and return the result as a function taking the input
-- pointer and the output pointer.
function asm:assemble(name, gen)
Dst = dasm.new(actions)
gen(self)
return finish(name.."_x2",
ffi.typeof("void (*)(uint8_t *, uint32_t *)"),
Dst)
end
return asm
end
-- Parallel AVX2 backend for the SipHash implementation that can hash
-- four inputs at once.
--
-- stride is the distance in bytes between consecutive inputs: loads
-- read from [rdi], [rdi+stride], [rdi+stride*2] and [rdi+stride*3].
-- Value index n lives in register ymm<n>, one input per 64-bit lane.
-- ymm5/xmm5, rax and rdx are used as temporaries by some emitters.
local function AVX2(stride)
local asm = {}
-- DynASM state of the function being assembled; set by asm:assemble.
local Dst
-- Emit code that broadcasts each of the four initial state words to
-- all lanes of ymm0-ymm3, going through rax for the address.
function asm.init(key)
local initial_state = load_initial_state(key)
table.insert(__anchor, initial_state)
| vzeroupper
| mov64 rax, initial_state
| vbroadcastsd ymm0, qword [rax]
| vbroadcastsd ymm1, qword [rax+8]
| vbroadcastsd ymm2, qword [rax+16]
| vbroadcastsd ymm3, qword [rax+24]
end
-- Lane-wise 64-bit add: dst = dst + other.
function asm.add(dst, other)
| vpaddq ymm(dst), ymm(dst), ymm(other)
end
-- Lane-wise 64-bit shift left.
function asm.shl(reg, bits)
| vpsllq ymm(reg), ymm(reg), bits
end
-- dst = dst | other.
function asm.ior(dst, other)
| vpor ymm(dst), ymm(dst), ymm(other)
end
-- dst = dst ~ other (exclusive or).
function asm.xor(dst, other)
| vpxor ymm(dst), ymm(dst), ymm(other)
end
-- Lane-wise 64-bit rotate left, built from two shifts and an or, with
-- ymm5 as the temporary.
function asm.rol(reg, bits)
| vpsllq ymm5, ymm(reg), bits
| vpsrlq ymm(reg), ymm(reg), (64 - bits)
| vpor ymm(reg), ymm(reg), ymm5
end
-- The load_*_and_advance emitters gather one value per input: inputs
-- 0 and 1 go into the low half of dst, inputs 2 and 3 are collected
-- in xmm5 and then inserted as the upper half of dst. rdi is bumped
-- afterwards. The assert presumably exists because xmm5 is the
-- temporary (so dst must not be 5) and lower indexes hold live state.
function asm.load_u64_and_advance(dst)
-- TODO: Use a parallel load here.
assert(dst > 5)
| vpinsrq xmm(dst), xmm(dst), qword [rdi], 0
| vpinsrq xmm(dst), xmm(dst), qword [rdi+stride*1], 1
| vpinsrq xmm5, xmm5, qword [rdi+stride*2], 0
| vpinsrq xmm5, xmm5, qword [rdi+stride*3], 1
| vinsertf128 ymm(dst), ymm(dst), xmm5, 1
| add rdi, 8
end
-- Narrower loads go through rax/rdx and are inserted as full 64-bit
-- values.
function asm.load_u32_and_advance(dst)
assert(dst > 5)
| mov eax, [rdi]
| mov edx, [rdi+stride]
| vpinsrq xmm(dst), xmm(dst), rax, 0
| vpinsrq xmm(dst), xmm(dst), rdx, 1
| mov eax, [rdi+stride*2]
| mov edx, [rdi+stride*3]
| vpinsrq xmm5, xmm5, rax, 0
| vpinsrq xmm5, xmm5, rdx, 1
| vinsertf128 ymm(dst), ymm(dst), xmm5, 1
| add rdi, 4
end
function asm.load_u16_and_advance(dst)
assert(dst > 5)
| movzx rax, word [rdi]
| movzx rdx, word [rdi+stride]
| vpinsrq xmm(dst), xmm(dst), rax, 0
| vpinsrq xmm(dst), xmm(dst), rdx, 1
| movzx rax, word [rdi+stride*2]
| movzx rdx, word [rdi+stride*3]
| vpinsrq xmm5, xmm5, rax, 0
| vpinsrq xmm5, xmm5, rdx, 1
| vinsertf128 ymm(dst), ymm(dst), xmm5, 1
| add rdi, 2
end
function asm.load_u8_and_advance(dst)
assert(dst > 5)
| movzx rax, byte [rdi]
| movzx rdx, byte [rdi+stride]
| vpinsrq xmm(dst), xmm(dst), rax, 0
| vpinsrq xmm(dst), xmm(dst), rdx, 1
| movzx rax, byte [rdi+stride*2]
| movzx rdx, byte [rdi+stride*3]
| vpinsrq xmm5, xmm5, rax, 0
| vpinsrq xmm5, xmm5, rdx, 1
| vinsertf128 ymm(dst), ymm(dst), xmm5, 1
| add rdi, 1
end
-- Set all four lanes of dst to the 8-bit immediate imm: build it in
-- rax, insert it into both low lanes, then copy the low half of dst
-- into its upper half.
function asm.load_imm8(dst, imm)
| xor rax, rax
if imm ~= 0 then
| mov al, imm
end
| vpinsrq xmm(dst), xmm(dst), rax, 0
| vpinsrq xmm(dst), xmm(dst), rax, 1
| vinsertf128 ymm(dst), ymm(dst), xmm(dst), 1
end
-- Pack four 2-bit selectors (each 0-3) into one immediate byte, a in
-- the lowest bits. NOTE(review): not referenced anywhere in the
-- visible code.
local function imm8_control(a, b, c, d)
assert(bit.band(bit.bor(a,b,c,d), 3) == bit.bor(a,b,c,d))
return a + b*4 + c*4*4 + d*4*4*4
end
function asm.ret(reg)
-- Extract high 32 bits from each 64-bit value. For ridiculous
-- reasons, the way we have to do this is to read in a set of
-- indexes into the 32-bit words from memory, then permute, then
-- write out.
-- Indexes 1, 3, 5 and 7 select the upper half of each lane; after
-- the permute they occupy the low 128 bits, where each is shifted
-- left by one bit before the 16-byte store to [rsi].
local control = ffi.new('uint32_t[8]', 1, 3, 5, 7, 0, 0, 0, 0)
table.insert(__anchor, control)
| mov64 rax, control
| vmovdqu ymm5, [rax]
| vpermd ymm(reg), ymm5, ymm(reg)
| vpslld xmm(reg), xmm(reg), 1
| vmovdqu oword [rsi], xmm(reg)
| vzeroupper
| ret
end
-- Create a fresh DynASM state, run the generator `gen` against this
-- backend and return the result as a function taking the input
-- pointer and the output pointer.
function asm:assemble(name, gen)
Dst = dasm.new(actions)
gen(self)
return finish(name.."_x4",
ffi.typeof("void (*)(uint8_t *, uint32_t *)"),
Dst)
end
return asm
end
-- Portable backend for the SipHash implementation to use as a
-- reference.
--
-- Instead of emitting machine code, every operation is carried out
-- immediately on an array of eight uint64_t "registers". The function
-- returned by asm:assemble therefore re-runs the whole generator each
-- time it is called, reading from the pointer it is given and
-- returning the hash as a number.
local function Simulator()
   local asm = {}
   local cursor, result   -- read position in the input; value set by ret
   local regs             -- the eight simulated 64-bit registers

   -- Read one value through the given pointer type at the cursor, store
   -- it in register dst and move the cursor forward by nbytes.
   local function load_and_advance(dst, pointer_type, nbytes)
      regs[dst] = ffi.cast(pointer_type, cursor)[0]
      cursor = cursor + nbytes
   end

   -- Start from a zeroed register file with the key-mixed initial state
   -- in registers 0-3.
   function asm.init(key)
      local initial_state = load_initial_state(key)
      regs = ffi.new('uint64_t[8]')
      for v = 0, 3 do
         regs[v] = initial_state[v]
      end
   end

   function asm.add(dst, other)
      regs[dst] = regs[dst] + regs[other]
   end

   function asm.shl(reg, bits)
      regs[reg] = bit.lshift(regs[reg], bits)
   end

   function asm.ior(dst, other)
      regs[dst] = bit.bor(regs[dst], regs[other])
   end

   function asm.xor(dst, other)
      regs[dst] = bit.bxor(regs[dst], regs[other])
   end

   function asm.rol(reg, bits)
      regs[reg] = bit.rol(regs[reg], bits)
   end

   function asm.load_u64_and_advance(dst)
      load_and_advance(dst, 'uint64_t*', 8)
   end

   function asm.load_u32_and_advance(dst)
      load_and_advance(dst, 'uint32_t*', 4)
   end

   function asm.load_u16_and_advance(dst)
      load_and_advance(dst, 'uint16_t*', 2)
   end

   function asm.load_u8_and_advance(dst)
      load_and_advance(dst, 'uint8_t*', 1)
   end

   function asm.load_imm8(dst, imm)
      regs[dst] = imm
   end

   -- The result is the upper 32 bits of the register, shifted left by
   -- one bit and reduced to an unsigned 32-bit value.
   function asm.ret(reg)
      local upper = tonumber(bit.rshift(regs[reg], 32))
      result = ffi.new('uint32_t[1]', bit.lshift(upper, 1))[0]
   end

   -- `name` is accepted for interface compatibility but unused.
   function asm:assemble(name, gen)
      return function(ptr)
         cursor, result = ffi.cast('uint8_t*', ptr), nil
         gen(self)
         cursor = nil
         return result
      end
   end

   return asm
end
-- Option schema handed to lib.parse by make_sip_hash.
local sip_hash_config = {
   -- Length of each input in bytes; must be supplied.
   size={required=true},
   -- Distance in bytes between consecutive inputs for the parallel
   -- backends; make_sip_hash substitutes size when it is absent.
   stride={},
   -- Key material; with a false value a random key is generated.
   key={default=false},
   -- Rounds run per 8-byte input block.
   c={default=2},
   -- Rounds run at the end.
   d={default=4},
   -- When true, process the tail word exactly as SipHash specifies
   -- instead of using the fixed-size simplification.
   as_specified={default=false},
   -- Number of inputs hashed per call; read by make_multi_hash.
   width={default=1}
}
-- Build a SipHash-c-d function for fixed-size inputs.
--
-- assembler: backend constructor (X86_64, SSE, AVX2, Simulator, ...).
--            It is called with the stride and must return a table that
--            provides the emitter functions used below plus an
--            assemble(name, gen) method.
-- opts:      table checked against sip_hash_config.
--
-- Returns whatever the backend's assemble method returns.
local function make_sip_hash(assembler, opts)
   -- The program generator. It is deliberately `local`: declared
   -- without it, every call would overwrite a `siphash` entry in the
   -- module's environment.
   local function siphash(asm)
      -- Arguments:
      --  rdi: packed keys as pointer
      --  rsi: for parallel implementations, an output uint32_t[2]
      --       (SSE) or uint32_t[4] (AVX2)
      -- Working registers:
      --  Registers 0-3 map to SipHash variables v0-v3
      --  Registers 6 and 7 used as scratch.
      local function sipround()
         asm.add(0, 1)
         asm.rol(1, 13)
         asm.xor(1, 0)
         asm.rol(0, 32)
         asm.add(2, 3)
         asm.rol(3, 16)
         asm.xor(3, 2)
         asm.add(0, 3)
         asm.rol(3, 21)
         asm.xor(3, 0)
         asm.add(2, 1)
         asm.rol(1, 17)
         asm.xor(1, 2)
         asm.rol(2, 32)
      end
      -- Absorb one message word: v3 ^= m, c rounds, v0 ^= m.
      local function process(input)
         asm.xor(3, input)
         for i=1,opts.c do sipround() end
         asm.xor(0, input)
      end
      -- Initialization phase.
      asm.init(opts.key)
      -- Compression phase.
      for i=1,opts.size/8 do
         asm.load_u64_and_advance(6)
         process(6)
      end
      -- Load tail word and process it.
      if opts.as_specified then
         -- The top byte of the tail word is the input length; the
         -- remaining input bytes fill it from the least significant
         -- byte upwards.
         asm.load_imm8(6, bit.band(opts.size, 0xff))
         asm.shl(6, 56)
         for i=1,opts.size%8 do
            asm.load_u8_and_advance(7)
            if i > 1 then asm.shl(7, (i - 1) * 8) end
            asm.ior(6, 7)
         end
         process(6)
      elseif opts.size%8 ~= 0 then
         -- Fixed-size simplification: no need to add in size byte, we
         -- can use different byte orders if it's more convenient, and
         -- we don't have to do anything at all if the size is a
         -- multiple of 8.
         if opts.size%8 >= 4 then
            asm.load_u32_and_advance(6)
         else
            asm.xor(6, 6)
         end
         if opts.size%4 >= 2 then
            asm.load_u16_and_advance(7)
            asm.shl(6, 16)
            asm.ior(6, 7)
         end
         if opts.size%2 ~= 0 then
            asm.load_u8_and_advance(7)
            asm.shl(6, 8)
            asm.ior(6, 7)
         end
         process(6)
      end
      -- Finalization.
      asm.load_imm8(6, 0xff)
      asm.xor(2, 6)
      for i=1,opts.d do sipround() end
      asm.xor(0, 1)
      asm.xor(2, 3)
      asm.xor(0, 2)
      asm.ret(0)
   end
   -- siphash reads `opts` as an upvalue, so it sees the parsed options
   -- by the time the backend runs it.
   opts = lib.parse(opts, sip_hash_config)
   if not opts.stride then opts.stride = opts.size end
   local asm = assembler(opts.stride)
   return asm:assemble("siphash_"..opts.c.."_"..opts.d, siphash)
end
-- A special implementation to hash immediate values; requires our
-- fixed-size simplification.
--
-- The returned function takes the 64-bit value itself as its argument
-- (prototype "uint32_t (*)(uint64_t)") instead of a pointer to it. The
-- options are copied before size is forced to 8, and as_specified must
-- not be set.
function make_u64_hash(opts)
   local imm_opts = lib.deepcopy(opts)
   imm_opts.size = 8
   assert(not imm_opts.as_specified)

   -- Scalar backend whose 64-bit "load" copies the first function
   -- argument instead of reading memory.
   local function ImmX86_64()
      local asm = X86_64()
      function asm.load_u64_and_advance(dst)
         asm.copy_argument(dst, 1)
      end
      function asm.finish(name, Dst)
         local prototype = ffi.typeof("uint32_t (*)(uint64_t)")
         return finish(name.."_u64", prototype, Dst)
      end
      -- With an 8-byte input there is no tail, so the narrower loads
      -- must never be requested; make them raise if they are.
      for _, emitter in ipairs({ "load_u32_and_advance",
                                 "load_u16_and_advance",
                                 "load_u8_and_advance" }) do
         asm[emitter] = error
      end
      return asm
   end

   return make_sip_hash(ImmX86_64, imm_opts)
end
-- Build a hash function that processes one input per call, using the
-- scalar x86-64 backend.
local function make_hash1(opts)
   local backend = X86_64
   return make_sip_hash(backend, opts)
end
-- Build a function hash(ptr, result) that hashes two inputs, located at
-- ptr and ptr + stride, into result[0] and result[1].
--
-- The SSE backend is used when the CPU supports it; otherwise the
-- scalar hash is simply called twice.
local function make_hash2(opts)
   -- Despite the flag's name, SSE2 alone is not enough: the SSE backend
   -- emits movddup (SSE3) as well as pinsrb, pinsrd, pinsrq and pextrd
   -- (SSE4.1). Require the sse4_1 flag too, so that a CPU without those
   -- instructions gets the scalar fallback instead of an illegal
   -- instruction fault.
   if have_sse2 and cpuinfo:match("sse4_1") then
      return make_sip_hash(SSE, opts)
   end
   local hash = make_hash1(opts)
   local stride = opts.stride or opts.size
   return function(ptr, result)
      ptr = ffi.cast('uint8_t*', ptr)
      result = ffi.cast('uint32_t*', result)
      result[0] = hash(ptr)
      result[1] = hash(ptr + stride)
   end
end
-- Build a function hash(ptr, result) that hashes four inputs, spaced
-- stride bytes apart starting at ptr, into result[0..3].
--
-- The AVX2 backend is used when /proc/cpuinfo advertises avx2;
-- otherwise the two-input hash is applied to each half.
local function make_hash4(opts)
   if have_avx2 then return make_sip_hash(AVX2, opts) end
   local hash_pair = make_hash2(opts)
   local stride = opts.stride or opts.size
   return function(ptr, result)
      local inputs = ffi.cast('uint8_t*', ptr)
      local outputs = ffi.cast('uint32_t*', result)
      hash_pair(inputs, outputs)
      hash_pair(inputs + stride*2, outputs + 2)
   end
end
-- Build the portable reference hash, which runs on the Simulator
-- backend rather than on generated machine code.
local function make_reference_sip_hash(opts)
   local backend = Simulator
   return make_sip_hash(backend, opts)
end
-- Non-local name for the one-input-per-call constructor, make_hash1.
make_hash = make_hash1
-- Build a function hash(input, output) that hashes opts.width inputs
-- per call (default 1).
--
-- The inputs are spaced opts.stride bytes apart (default opts.size),
-- starting at `input`; the hash of input i is stored in output[i].
-- Groups of four, two and one inputs are dispatched to the four-wide,
-- two-wide and scalar hash functions respectively.
function make_multi_hash(opts)
   -- All three component hashes must use the same key, otherwise
   -- different lanes of one call would be computing different hash
   -- functions. When the caller supplies no key, each component would
   -- otherwise draw its own random key, so pick one here and share it.
   -- Work on a shallow copy to leave the caller's table untouched.
   local shared_opts = {}
   for k, v in pairs(opts) do shared_opts[k] = v end
   if not shared_opts.key then
      shared_opts.key = random_sip_hash_key()
   end
   local hash1 = make_hash1(shared_opts)
   local hash2 = make_hash2(shared_opts)
   local hash4 = make_hash4(shared_opts)
   local width = opts.width or 1
   local stride = (opts.stride or opts.size)
   if width == 1 then
      return function(input, output) output[0] = hash1(input) end
   end
   if width == 2 then return hash2 end
   if width == 4 then return hash4 end
   if width % 4 == 0 then
      -- A whole number of four-wide groups.
      return function(input, output)
         input = ffi.cast('uint8_t*', input)
         output = ffi.cast('uint32_t*', output)
         for i=1,width,4 do
            hash4(input, output)
            input = input + stride*4
            output = output + 4
         end
      end
   end
   -- General case: as many four-wide groups as fit, then a pair and/or
   -- a single input according to the low two bits of width.
   return function(input, output)
      input = ffi.cast('uint8_t*', input)
      output = ffi.cast('uint32_t*', output)
      for i=1,bit.rshift(width, 2) do
         hash4(input, output)
         input = input + stride*4
         output = output + 4
      end
      if bit.band(width, 2) ~= 0 then
         hash2(input, output)
         input = input + stride*2
         output = output + 2
      end
      if bit.band(width, 1) ~= 0 then
         output[0] = hash1(input)
      end
   end
end
-- Self-test: compare the generated hash functions against the portable
-- reference implementation.
--
-- For every input size from 0 to 32 bytes, every c in 0..2, every d in
-- 0..4 and both tail-word modes, the scalar hash and the multi-hashes
-- of width 1, 2, 4 and 8 must agree with the reference. Afterwards the
-- immediate-value hash is compared against the memory-based one.
-- Raises an error on the first mismatch.
function selftest()
   -- Copy of `a` with the entries of `b` layered on top.
   local function union(a, b)
      local merged = lib.deepcopy(a)
      for key, value in pairs(b) do merged[key] = value end
      return merged
   end

   -- Print one progress dot.
   local function progress()
      io.stdout:write(".")
      io.stdout:flush()
   end

   local function test(opts)
      local size = opts.size
      local reference_hash = make_reference_sip_hash(opts)
      local scalar_hash = make_hash(opts)
      local test_input = lib.random_bytes(size)
      local expected = reference_hash(test_input)

      -- Fail, after dumping the options, unless result is the expected
      -- hash of the test input.
      local function check(result, what)
         if expected == result then return end
         for k,v in pairs(opts) do print(k,v) end
         error('got '..what..'='..result..'; expected '..expected)
      end

      check(scalar_hash(test_input), 'scalar')

      -- Put the test input into each lane in turn, with every other
      -- lane zeroed: the chosen lane must produce the expected hash
      -- and all other lanes the hash of an all-zero input.
      local zero_hash = reference_hash(ffi.new("uint8_t[?]", size))
      for _,width in ipairs({1,2,4,8}) do
         local total = size*width
         local mbuf = ffi.new("uint8_t[?]", total)
         local result = ffi.new("uint32_t[?]", width)
         local mhash = make_multi_hash(union(opts, {width=width}))
         for lane=0,width-1 do
            ffi.fill(mbuf, total)
            ffi.copy(mbuf+lane*size, test_input, size)
            mhash(mbuf, result)
            for elt=0,width-1 do
               if elt == lane then
                  check(result[elt], string.format('x%d[%d]', width, elt))
               elseif result[elt] ~= zero_hash then
                  error('got hash(0)='..result[elt]..'; expected '..zero_hash)
               end
            end
         end
      end
   end

   io.stdout:write("selftest: ")
   io.stdout:flush()
   local opts = { key=random_sip_hash_key() }
   for size=0,32 do
      opts.size = size
      progress()
      for c=0,2 do
         opts.c = c
         for d=0,4 do
            opts.d = d
            for _, as_specified in ipairs({true, false}) do
               opts.as_specified = as_specified
               test(opts)
            end
         end
      end
   end

   -- We need a hash function for immediates as well; test that here.
   opts.size, opts.as_specified = 8, false
   local val = ffi.new('uint64_t[1]', 0x12345678ULL)
   for c=0,2 do
      opts.c = c
      for d=0,4 do
         progress()
         opts.d = d
         val[0] = val[0] * 257
         local hash_mem = make_hash(opts)
         local hash_u64 = make_u64_hash(opts)
         assert(hash_mem(val) == hash_u64(val[0]))
      end
   end
   print("\nselftest ok")
end