Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement checksum computation using DynASM
- Loading branch information
Showing
1 changed file
with
159 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
module(..., package.seeall) | ||
|
||
local dasm = require("dasm") | ||
local lib = require("core.lib") | ||
local ffi = require('ffi') | ||
local C = ffi.C | ||
|
||
-- DynASM prelude. | ||
|
||
debug = false | ||
|
||
|.arch x64 | ||
|.actionlist actions | ||
|
||
__anchor = {} | ||
mcode = {} | ||
size = 0 | ||
|
||
function assemble (name, prototype, generator) | ||
local Dst = dasm.new(actions) | ||
generator(Dst) | ||
local mcode, size = Dst:build() | ||
table.insert(__anchor, mcode) | ||
if debug then | ||
print("mcode dump: "..name) | ||
dasm.dump(mcode, size) | ||
end | ||
return ffi.cast(prototype, mcode) | ||
end | ||
|
||
local function gen_checksum () | ||
return function (Dst) | ||
-- Prologue. | ||
| push rbp | ||
| mov rbp, rsp | ||
-- Accumulative sum. | ||
| xor rax, rax -- Clear out rax. Stores accumulated sum. | ||
| xor r9, r9 -- Clear out r9. Stores value of array. | ||
| xor r8, r8 -- Clear out r8. Stores array index. | ||
| mov rcx, rsi -- Rsi (2nd argument; size). Assign rsi to rcx. | ||
| 1: | ||
| cmp rcx, 8 -- If index is less than 8. | ||
| jl >2 -- Jump to branch '2'. | ||
| mov r9, [rdi + r8] -- Fetch 64-bit from data + r8 into r9. | ||
| add rax, r9 -- Sum acc with r9. | ||
| adc rax, 0 -- Sum carry-bit into acc. | ||
| sub rcx, 8 -- Decrease index by 8. | ||
| add r8, 8 -- Next 64-bit. | ||
| jmp <1 -- Go to beginning of loop. | ||
| 2: | ||
| cmp rcx, 4 -- If index is less than 4. | ||
| jl >3 -- Jump to branch '3'. | ||
| mov r9d, dword [rdi + r8] -- Fetch 32-bit from data + r8 into r9d. | ||
| add rax, r9 -- Sum acc with r9. Accumulate carry. | ||
| sub rcx, 4 -- Decrease index by 4. | ||
| add r8, 4 -- Next 32-bit. | ||
| jmp <2 -- Go to beginning of loop. | ||
| 3: | ||
| cmp rcx, 2 -- If index is less than 2. | ||
| jl >4 -- Jump to branch '4'. | ||
| movzx r9, word [rdi + r8] -- Fetch 16-bit from data + r8 into r9. | ||
| add rax, r9 -- Sum acc with r9. Accumulate carry. | ||
| sub rcx, 2 -- Decrease index by 2. | ||
| add r8, 2 -- Next 16-bit. | ||
| jmp <3 -- Go to beginning of loop. | ||
| 4: | ||
| cmp rcx, 1 -- If index is less than 1. | ||
| jl >5 -- Jump to branch '5'. | ||
| movzx r9, byte [rdi + r8] -- Fetch 8-bit from data + r8 into r9. | ||
| add rax, r9 -- Sum acc with r9. Accumulate carry. | ||
-- Fold 64-bit into 16-bit. | ||
| 5: | ||
| mov r9, rax -- Assign acc to r9. | ||
| shr r9, 32 -- Shift r9 32-bit. Stores higher part of acc. | ||
| and rax, 0x00000000ffffffff -- Clear out higher-part of rax. Stores lower part of acc. | ||
| add eax, r9d -- 32-bit sum of acc and r9. | ||
| adc eax, 0 -- Sum carry to acc. | ||
| mov r9d, eax -- Repeat for 16-bit. | ||
| shr r9d, 16 | ||
| and eax, 0x0000ffff | ||
| add ax, r9w | ||
| adc ax, 0 | ||
-- One's complement. | ||
| not rax -- One-complement of rax. | ||
| and rax, 0xffff -- Clear out higher part of rax. | ||
-- Epilogue. | ||
| 6: | ||
| mov rsp, rbp | ||
| pop rbp | ||
-- Return. | ||
| ret | ||
end | ||
end | ||
|
||
local newchecksum = assemble("newchecksum", "uint32_t(*)(uint8_t*, uint32_t)", gen_checksum()) | ||
|
||
function selftest () | ||
require("lib.checksum_h") | ||
local function create_packet (size) | ||
local pkt = { | ||
data = ffi.new("uint8_t[?]", size), | ||
length = size | ||
} | ||
for i=0,size-1 do | ||
pkt.data[i] = math.random(255) | ||
end | ||
return pkt | ||
end | ||
local function benchmark (fn, times) | ||
local now = os.clock() | ||
local temp | ||
for i=1,times do | ||
temp = fn() | ||
end | ||
local ret = {os.clock() - now, temp} | ||
return ret[1] | ||
end | ||
local function hex (num) | ||
return ("0x%.2x"):format(num) | ||
end | ||
local ntohs = lib.ntohs | ||
print("selftest: newchecksum") | ||
|
||
local size = 44 | ||
print("14.4M; "..size.." bytes") | ||
local pkt = create_packet(size) | ||
local times = 14.4*10^6 | ||
-- Verify checksum is correct. | ||
assert(hex(C.cksum_generic(pkt.data, pkt.length, 0)) == hex(ntohs(newchecksum(pkt.data, pkt.length)))) | ||
-- Benchmark for different architectures. | ||
print("Gen: ", benchmark(function() return C.cksum_generic(pkt.data, pkt.length, 0) end, times)) | ||
print("SSE2: ", benchmark(function() return C.cksum_sse2(pkt.data, pkt.length, 0) end, times)) | ||
print("AVX2: ", benchmark(function() return C.cksum_avx2(pkt.data, pkt.length, 0) end, times)) | ||
print("New: ", benchmark(function() return newchecksum(pkt.data, pkt.length) end, times)) | ||
|
||
size = 550 | ||
print("2M; "..size.." bytes") | ||
local pkt = create_packet(size) | ||
local times = 2*10^6 | ||
-- Verify checksum is correct. | ||
assert(hex(C.cksum_generic(pkt.data, pkt.length, 0)) == hex(ntohs(newchecksum(pkt.data, pkt.length)))) | ||
-- Benchmark for different architectures. | ||
print("Gen: ", benchmark(function() return C.cksum_generic(pkt.data, pkt.length, 0) end, times)) | ||
print("SSE2: ", benchmark(function() return C.cksum_sse2(pkt.data, pkt.length, 0) end, times)) | ||
print("AVX2: ", benchmark(function() return C.cksum_avx2(pkt.data, pkt.length, 0) end, times)) | ||
print("New: ", benchmark(function() return newchecksum(pkt.data, pkt.length) end, times)) | ||
|
||
size = 1500 | ||
print("1M; "..size.." bytes") | ||
local pkt = create_packet(size) | ||
local times = 1*10^6 | ||
-- Verify checksum is correct. | ||
assert(hex(C.cksum_generic(pkt.data, pkt.length, 0)) == hex(ntohs(newchecksum(pkt.data, pkt.length)))) | ||
-- Benchmark for different architectures. | ||
print("Gen: ", benchmark(function() return C.cksum_generic(pkt.data, pkt.length, 0) end, times)) | ||
print("SSE2: ", benchmark(function() return C.cksum_sse2(pkt.data, pkt.length, 0) end, times)) | ||
print("AVX2: ", benchmark(function() return C.cksum_avx2(pkt.data, pkt.length, 0) end, times)) | ||
print("New: ", benchmark(function() return newchecksum(pkt.data, pkt.length) end, times)) | ||
end |