Skip to content

Commit

Permalink
Implement checksum computation using DynASM
Browse files Browse the repository at this point in the history
  • Loading branch information
dpino committed Aug 23, 2018
1 parent 732eae9 commit f0c6011
Showing 1 changed file with 159 additions and 0 deletions.
159 changes: 159 additions & 0 deletions src/lib/newchecksum.dasl
@@ -0,0 +1,159 @@
module(..., package.seeall)

local dasm = require("dasm")
local lib = require("core.lib")
local ffi = require('ffi')
local C = ffi.C

-- DynASM prelude.

debug = false

|.arch x64
|.actionlist actions

__anchor = {}
mcode = {}
size = 0

function assemble (name, prototype, generator)
local Dst = dasm.new(actions)
generator(Dst)
local mcode, size = Dst:build()
table.insert(__anchor, mcode)
if debug then
print("mcode dump: "..name)
dasm.dump(mcode, size)
end
return ffi.cast(prototype, mcode)
end

local function gen_checksum ()
return function (Dst)
-- Prologue.
| push rbp
| mov rbp, rsp
-- Accumulative sum.
| xor rax, rax -- Clear out rax. Stores accumulated sum.
| xor r9, r9 -- Clear out r9. Stores value of array.
| xor r8, r8 -- Clear out r8. Stores array index.
| mov rcx, rsi -- Rsi (2nd argument; size). Assign rsi to rcx.
| 1:
| cmp rcx, 8 -- If index is less than 8.
| jl >2 -- Jump to branch '2'.
| mov r9, [rdi + r8] -- Fetch 64-bit from data + r8 into r9.
| add rax, r9 -- Sum acc with r9.
| adc rax, 0 -- Sum carry-bit into acc.
| sub rcx, 8 -- Decrease index by 8.
| add r8, 8 -- Next 64-bit.
| jmp <1 -- Go to beginning of loop.
| 2:
| cmp rcx, 4 -- If index is less than 4.
| jl >3 -- Jump to branch '3'.
| mov r9d, dword [rdi + r8] -- Fetch 32-bit from data + r8 into r9d.
| add rax, r9 -- Sum acc with r9. Accumulate carry.
| sub rcx, 4 -- Decrease index by 4.
| add r8, 4 -- Next 32-bit.
| jmp <2 -- Go to beginning of loop.
| 3:
| cmp rcx, 2 -- If index is less than 2.
| jl >4 -- Jump to branch '4'.
| movzx r9, word [rdi + r8] -- Fetch 16-bit from data + r8 into r9.
| add rax, r9 -- Sum acc with r9. Accumulate carry.
| sub rcx, 2 -- Decrease index by 2.
| add r8, 2 -- Next 16-bit.
| jmp <3 -- Go to beginning of loop.
| 4:
| cmp rcx, 1 -- If index is less than 1.
| jl >5 -- Jump to branch '5'.
| movzx r9, byte [rdi + r8] -- Fetch 8-bit from data + r8 into r9.
| add rax, r9 -- Sum acc with r9. Accumulate carry.
-- Fold 64-bit into 16-bit.
| 5:
| mov r9, rax -- Assign acc to r9.
| shr r9, 32 -- Shift r9 32-bit. Stores higher part of acc.
| and rax, 0x00000000ffffffff -- Clear out higher-part of rax. Stores lower part of acc.
| add eax, r9d -- 32-bit sum of acc and r9.
| adc eax, 0 -- Sum carry to acc.
| mov r9d, eax -- Repeat for 16-bit.
| shr r9d, 16
| and eax, 0x0000ffff
| add ax, r9w
| adc ax, 0
-- One's complement.
| not rax -- One-complement of rax.
| and rax, 0xffff -- Clear out higher part of rax.
-- Epilogue.
| 6:
| mov rsp, rbp
| pop rbp
-- Return.
| ret
end
end

local newchecksum = assemble("newchecksum", "uint32_t(*)(uint8_t*, uint32_t)", gen_checksum())

function selftest ()
require("lib.checksum_h")
local function create_packet (size)
local pkt = {
data = ffi.new("uint8_t[?]", size),
length = size
}
for i=0,size-1 do
pkt.data[i] = math.random(255)
end
return pkt
end
local function benchmark (fn, times)
local now = os.clock()
local temp
for i=1,times do
temp = fn()
end
local ret = {os.clock() - now, temp}
return ret[1]
end
local function hex (num)
return ("0x%.2x"):format(num)
end
local ntohs = lib.ntohs
print("selftest: newchecksum")

local size = 44
print("14.4M; "..size.." bytes")
local pkt = create_packet(size)
local times = 14.4*10^6
-- Verify checksum is correct.
assert(hex(C.cksum_generic(pkt.data, pkt.length, 0)) == hex(ntohs(newchecksum(pkt.data, pkt.length))))
-- Benchmark for different architectures.
print("Gen: ", benchmark(function() return C.cksum_generic(pkt.data, pkt.length, 0) end, times))
print("SSE2: ", benchmark(function() return C.cksum_sse2(pkt.data, pkt.length, 0) end, times))
print("AVX2: ", benchmark(function() return C.cksum_avx2(pkt.data, pkt.length, 0) end, times))
print("New: ", benchmark(function() return newchecksum(pkt.data, pkt.length) end, times))

size = 550
print("2M; "..size.." bytes")
local pkt = create_packet(size)
local times = 2*10^6
-- Verify checksum is correct.
assert(hex(C.cksum_generic(pkt.data, pkt.length, 0)) == hex(ntohs(newchecksum(pkt.data, pkt.length))))
-- Benchmark for different architectures.
print("Gen: ", benchmark(function() return C.cksum_generic(pkt.data, pkt.length, 0) end, times))
print("SSE2: ", benchmark(function() return C.cksum_sse2(pkt.data, pkt.length, 0) end, times))
print("AVX2: ", benchmark(function() return C.cksum_avx2(pkt.data, pkt.length, 0) end, times))
print("New: ", benchmark(function() return newchecksum(pkt.data, pkt.length) end, times))

size = 1500
print("1M; "..size.." bytes")
local pkt = create_packet(size)
local times = 1*10^6
-- Verify checksum is correct.
assert(hex(C.cksum_generic(pkt.data, pkt.length, 0)) == hex(ntohs(newchecksum(pkt.data, pkt.length))))
-- Benchmark for different architectures.
print("Gen: ", benchmark(function() return C.cksum_generic(pkt.data, pkt.length, 0) end, times))
print("SSE2: ", benchmark(function() return C.cksum_sse2(pkt.data, pkt.length, 0) end, times))
print("AVX2: ", benchmark(function() return C.cksum_avx2(pkt.data, pkt.length, 0) end, times))
print("New: ", benchmark(function() return newchecksum(pkt.data, pkt.length) end, times))
end

0 comments on commit f0c6011

Please sign in to comment.