diff --git a/ascii/ascii.go b/ascii/ascii.go index 17a1a60..9770ac2 100644 --- a/ascii/ascii.go +++ b/ascii/ascii.go @@ -54,3 +54,22 @@ func hasMore32(x, n uint32) bool { func unsafeString(b []byte) string { return *(*string)(unsafe.Pointer(&b)) } + +var lower = [256]byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, + 0x40, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, + 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, + 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, + 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf, + 0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf, + 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf, + 0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf, + 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, +} diff --git a/ascii/ascii_test.go b/ascii/ascii_test.go index cb0a239..396ff3b 100644 --- a/ascii/ascii_test.go +++ b/ascii/ascii_test.go @@ -1,6 +1,8 @@ 
package ascii import ( + "bytes" + "fmt" "strings" "testing" "unicode/utf8" @@ -160,6 +162,37 @@ func TestHasSuffixFold(t *testing.T) { } } +func TestEqualFoldASCII(t *testing.T) { + pairs := [...][2]byte{ + {0, ' '}, + {'@', '`'}, + {'[', '{'}, + {'_', 127}, + } + + for _, pair := range pairs { + t.Run(fmt.Sprintf("0x%02x=0x%02x", pair[0], pair[1]), func(t *testing.T) { + for i := 1; i <= 256; i++ { + a := bytes.Repeat([]byte{'x'}, i) + b := bytes.Repeat([]byte{'X'}, i) + + if !EqualFold(a, b) { + t.Errorf("%q does not match %q", a, b) + break + } + + a[0] = pair[0] + b[0] = pair[1] + + if EqualFold(a, b) { + t.Errorf("%q matches %q", a, b) + break + } + } + }) + } +} + func TestEqualFold(t *testing.T) { // Only test valid UTF-8 otherwise ToUpper/ToLower will convert invalid // characters to UTF-8 placeholders, which breaks the case-insensitive diff --git a/ascii/equal_fold.go b/ascii/equal_fold.go index c1dd49b..ace1811 100644 --- a/ascii/equal_fold.go +++ b/ascii/equal_fold.go @@ -1,9 +1,7 @@ //go:generate go run equal_fold_asm.go -out equal_fold_amd64.s -stubs equal_fold_amd64.go package ascii -import ( - "unsafe" -) +import "unsafe" // EqualFold is a version of bytes.EqualFold designed to work on ASCII input // instead of UTF-8. @@ -35,13 +33,15 @@ func EqualFoldString(a, b string) bool { n := uintptr(len(a)) p := *(*unsafe.Pointer)(unsafe.Pointer(&a)) q := *(*unsafe.Pointer)(unsafe.Pointer(&b)) + c := byte(0) + // Pre-check to avoid the other tests that would all evaluate to false. // For very small strings, this helps reduce the processing overhead. if n >= 8 { - // If there is more than 32 bytes to copy, use the AVX optimized version, + // If there are at least 32 bytes to compare, use the AVX optimized version, // otherwise the overhead of the function call tends to be greater than // looping 2 or 3 times over 8 bytes. 
- if n > 32 && asm.equalFoldAVX2 != nil { + if n >= 32 && asm.equalFoldAVX2 != nil { if asm.equalFoldAVX2((*byte)(p), (*byte)(q), n) == 0 { return false } @@ -50,11 +50,17 @@ func EqualFoldString(a, b string) bool { q = unsafe.Pointer(uintptr(q) + k) n -= k } - - for n > 8 { - const mask = 0xDFDFDFDFDFDFDFDF - - if (*(*uint64)(p) & mask) != (*(*uint64)(q) & mask) { + for n >= 8 { + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 0))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 0))] + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 1))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 1))] + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 2))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 2))] + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 3))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 3))] + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 4))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 4))] + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 5))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 5))] + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 6))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 6))] + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 7))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 7))] + + if c != 0 { return false } @@ -62,39 +68,33 @@ func EqualFoldString(a, b string) bool { q = unsafe.Pointer(uintptr(q) + 8) n -= 8 } - - if n == 8 { - const mask = 0xDFDFDFDFDFDFDFDF - return (*(*uint64)(p) & mask) == (*(*uint64)(q) & mask) - } - } - - if n > 4 { - const mask = 0xDFDFDFDF - - if (*(*uint32)(p) & mask) != (*(*uint32)(q) & mask) { - return false - } - - p = unsafe.Pointer(uintptr(p) + 4) - q = unsafe.Pointer(uintptr(q) + 4) - n -= 4 } switch n { + case 7: + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 6))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 6))] + fallthrough + case 6: + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 5))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 5))] + fallthrough + case 5: + c |= 
lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 4))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 4))] + fallthrough case 4: - return (*(*uint32)(p) & 0xDFDFDFDF) == (*(*uint32)(q) & 0xDFDFDFDF) + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 3))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 3))] + fallthrough case 3: - x := uint32(*(*uint16)(p)) | uint32(*(*uint8)(unsafe.Pointer(uintptr(p) + 2)))<<16 - y := uint32(*(*uint16)(q)) | uint32(*(*uint8)(unsafe.Pointer(uintptr(q) + 2)))<<16 - return (x & 0xDFDFDF) == (y & 0xDFDFDF) + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 2))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 2))] + fallthrough case 2: - return (*(*uint16)(p) & 0xDFDF) == (*(*uint16)(q) & 0xDFDF) + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 1))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 1))] + fallthrough case 1: - return (*(*uint8)(p) & 0xDF) == (*(*uint8)(q) & 0xDF) - default: - return true + c |= lower[*(*uint8)(unsafe.Pointer(uintptr(p) + 0))] ^ lower[*(*uint8)(unsafe.Pointer(uintptr(q) + 0))] } + + return c == 0 + } func HasPrefixFoldString(s, prefix string) bool { diff --git a/ascii/equal_fold_amd64.s b/ascii/equal_fold_amd64.s index 30f0182..b100c3e 100644 --- a/ascii/equal_fold_amd64.s +++ b/ascii/equal_fold_amd64.s @@ -5,60 +5,99 @@ // func equalFoldAVX2(a *byte, b *byte, n uintptr) int // Requires: AVX, AVX2, SSE4.1 TEXT ·equalFoldAVX2(SB), NOSPLIT, $0-32 - MOVQ a+0(FP), AX - MOVQ b+8(FP), CX - MOVQ n+16(FP), DX - SHRQ $0x04, DX - MOVQ $0x0000000000000000, BX - MOVQ $0xdfdfdfdfdfdfdfdf, BP - PINSRQ $0x00, BP, X0 - PINSRQ $0x01, BP, X0 - VPBROADCASTQ X0, Y1 + MOVQ a+0(FP), CX + MOVQ b+8(FP), DX + MOVQ n+16(FP), BX + XORQ AX, AX + SHRQ $0x04, BX + XORQ SI, SI + MOVB $0x20, DI + PINSRB $0x00, DI, X6 + VPBROADCASTB X6, Y6 + MOVB $0x1f, DI + PINSRB $0x00, DI, X7 + VPBROADCASTB X7, Y7 + MOVB $0x9a, DI + PINSRB $0x00, DI, X8 + VPBROADCASTB X8, Y8 + MOVB $0x01, DI + PINSRB $0x00, DI, X9 + VPBROADCASTB X9, Y9 loop64: - CMPQ DX, $0x04 - 
JL loop32 - VPAND (AX), Y1, Y2 - VPAND (CX), Y1, Y3 - VPCMPEQB Y3, Y2, Y2 - VPAND 32(AX), Y1, Y3 - VPAND 32(CX), Y1, Y4 + CMPQ BX, $0x04 + JB cmp32 + VMOVDQU (CX)(AX*1), Y0 + VMOVDQU 32(CX)(AX*1), Y3 + VMOVDQU (DX)(AX*1), Y1 + VMOVDQU 32(DX)(AX*1), Y4 + VXORPD Y0, Y1, Y1 + VPCMPEQB Y6, Y1, Y2 + VORPD Y6, Y0, Y0 + VPADDB Y7, Y0, Y0 + VPCMPGTB Y0, Y8, Y0 + VPAND Y2, Y0, Y0 + VPAND Y9, Y0, Y0 + VPSLLW $0x05, Y0, Y0 + VPCMPEQB Y1, Y0, Y0 + VXORPD Y3, Y4, Y4 + VPCMPEQB Y6, Y4, Y5 + VORPD Y6, Y3, Y3 + VPADDB Y7, Y3, Y3 + VPCMPGTB Y3, Y8, Y3 + VPAND Y5, Y3, Y3 + VPAND Y9, Y3, Y3 + VPSLLW $0x05, Y3, Y3 VPCMPEQB Y4, Y3, Y3 - VPAND Y3, Y2, Y2 - VPMOVMSKB Y2, BP - CMPL BP, $0xffffffff - JNE done + VPAND Y3, Y0, Y0 ADDQ $0x40, AX - ADDQ $0x40, CX - SUBQ $0x04, DX + SUBQ $0x04, BX + VPMOVMSKB Y0, DI + CMPL DI, $0xffffffff + JNE done JMP loop64 -loop32: - CMPQ DX, $0x02 - JL loop16 - VPAND (AX), Y1, Y2 - VPAND (CX), Y1, Y3 - VPCMPEQB Y3, Y2, Y2 - VPMOVMSKB Y2, BP - CMPL BP, $0xffffffff - JNE done +cmp32: + CMPQ BX, $0x02 + JB cmp16 + VMOVDQU (CX)(AX*1), Y0 + VMOVDQU (DX)(AX*1), Y1 + VXORPD Y0, Y1, Y1 + VPCMPEQB Y6, Y1, Y2 + VORPD Y6, Y0, Y0 + VPADDB Y7, Y0, Y0 + VPCMPGTB Y0, Y8, Y0 + VPAND Y2, Y0, Y0 + VPAND Y9, Y0, Y0 + VPSLLW $0x05, Y0, Y0 + VPCMPEQB Y1, Y0, Y0 ADDQ $0x20, AX - ADDQ $0x20, CX - SUBQ $0x02, DX + SUBQ $0x02, BX + VPMOVMSKB Y0, DI + CMPL DI, $0xffffffff + JNE done -loop16: - CMPQ DX, $0x00 - JE equal - VPAND (AX), X0, X1 - VPAND (CX), X0, X0 - VPCMPEQB X0, X1, X1 - VPMOVMSKB X1, BP - CMPL BP, $0x0000ffff +cmp16: + CMPQ BX, $0x01 + JB equal + VMOVDQU (CX)(AX*1), X0 + VMOVDQU (DX)(AX*1), X1 + VXORPD X0, X1, X1 + VPCMPEQB X6, X1, X2 + VORPD X6, X0, X0 + VPADDB X7, X0, X0 + VPCMPGTB X0, X8, X0 + VPAND X2, X0, X0 + VPAND X9, X0, X0 + VPSLLW $0x05, X0, X0 + VPCMPEQB X1, X0, X0 + VPMOVMSKB X0, DI + CMPL DI, $0x0000ffff JNE done equal: - MOVQ $0x0000000000000001, BX + MOVQ $0x0000000000000001, SI done: - MOVQ BX, ret+24(FP) + MOVQ SI, ret+24(FP) RET diff --git 
a/ascii/equal_fold_asm.go b/ascii/equal_fold_asm.go index 65f0bce..5d55593 100644 --- a/ascii/equal_fold_asm.go +++ b/ascii/equal_fold_asm.go @@ -5,83 +5,88 @@ package main import ( . "github.com/mmcloughlin/avo/build" . "github.com/mmcloughlin/avo/operand" + . "github.com/mmcloughlin/avo/reg" ) func main() { TEXT("equalFoldAVX2", NOSPLIT, "func(a *byte, b *byte, n uintptr) int") Doc("Case-insensitive comparison of two ASCII strings (equality).") - p := Load(Param("a"), GP64()) - q := Load(Param("b"), GP64()) + i := GP64() + p := Mem{Base: Load(Param("a"), GP64()), Index: i, Scale: 1} + q := Mem{Base: Load(Param("b"), GP64()), Index: i, Scale: 1} n := Load(Param("n"), GP64()) + XORQ(i, i) SHRQ(Imm(4), n) // n /= 16 eq := GP64() - MOVQ(U64(0), eq) + XORQ(eq, eq) - mask64 := GP64() - MOVQ(U64(0xDFDFDFDFDFDFDFDF), mask64) + cmpk := GP32() + mask256 := [4]Register{} + mask128 := [4]Register{} - mask128 := XMM() - mask256 := YMM() - PINSRQ(Imm(0), mask64, mask128) - PINSRQ(Imm(1), mask64, mask128) - VPBROADCASTQ(mask128, mask256) + for i, b := range [4]byte{0x20, 0x1F, 0x9A, 0x01} { + y := YMM() + g := GP32() - cmpk := GP32() - xmm0 := XMM() - xmm1 := XMM() - ymm0 := YMM() - ymm1 := YMM() - ymm2 := YMM() - ymm3 := YMM() + MOVB(U8(b), g.As8()) + PINSRB(U8(0), g, y.AsX()) + VPBROADCASTB(y.AsX(), y) + + mask256[i] = y + mask128[i] = y.AsX() + } Label("loop64") CMPQ(n, Imm(4)) - JL(LabelRef("loop32")) + JB(LabelRef("cmp32")) - VPAND(Mem{Base: p}, mask256, ymm0) - VPAND(Mem{Base: q}, mask256, ymm1) - VPCMPEQB(ymm1, ymm0, ymm0) + VMOVDQU(p, Y0) + VMOVDQU(p.Offset(32), Y3) + VMOVDQU(q, Y1) + VMOVDQU(q.Offset(32), Y4) - VPAND((Mem{Base: p}).Offset(32), mask256, ymm2) - VPAND((Mem{Base: q}).Offset(32), mask256, ymm3) - VPCMPEQB(ymm3, ymm2, ymm2) + gen(Y0, Y1, Y2, mask256) + gen(Y3, Y4, Y5, mask256) + VPAND(Y3, Y0, Y0) // merge results together - VPAND(ymm2, ymm0, ymm0) - VPMOVMSKB(ymm0, cmpk) + ADDQ(Imm(64), i) + SUBQ(Imm(4), n) + + VPMOVMSKB(Y0, cmpk) CMPL(cmpk, 
U32(0xFFFFFFFF)) JNE(LabelRef("done")) - ADDQ(Imm(64), p) - ADDQ(Imm(64), q) - SUBQ(Imm(4), n) JMP(LabelRef("loop64")) - Label("loop32") + Label("cmp32") CMPQ(n, Imm(2)) - JL(LabelRef("loop16")) + JB(LabelRef("cmp16")) + + VMOVDQU(p, Y0) + VMOVDQU(q, Y1) - VPAND(Mem{Base: p}, mask256, ymm0) - VPAND(Mem{Base: q}, mask256, ymm1) - VPCMPEQB(ymm1, ymm0, ymm0) - VPMOVMSKB(ymm0, cmpk) + gen(Y0, Y1, Y2, mask256) + + ADDQ(Imm(32), i) + SUBQ(Imm(2), n) + + VPMOVMSKB(Y0, cmpk) CMPL(cmpk, U32(0xFFFFFFFF)) JNE(LabelRef("done")) - ADDQ(Imm(32), p) - ADDQ(Imm(32), q) - SUBQ(Imm(2), n) + Label("cmp16") + CMPQ(n, Imm(1)) + JB(LabelRef("equal")) - Label("loop16") - CMPQ(n, Imm(0)) - JE(LabelRef("equal")) + VMOVDQU(p, X0) + VMOVDQU(q, X1) - VPAND(Mem{Base: p}, mask128, xmm0) - VPAND(Mem{Base: q}, mask128, xmm1) - VPCMPEQB(xmm1, xmm0, xmm0) - VPMOVMSKB(xmm0, cmpk) - CMPL(cmpk, U32(0xFFFF)) + gen(X0, X1, X2, mask128) + + VPMOVMSKB(X0, cmpk) + CMPL(cmpk, U32(0x0000FFFF)) JNE(LabelRef("done")) Label("equal") @@ -92,3 +97,15 @@ func main() { RET() Generate() } + +func gen(v0, v1, v2 Register, mask [4]Register) { + VXORPD(v0, v1, v1) // v1 = v0 XOR v1: per-byte difference between the two inputs + VPCMPEQB(mask[0], v1, v2) // v2 = 0xff where the difference is exactly the case bit (0x20) + VORPD(mask[0], v0, v0) // set the case bit (0x20) in every byte of v0 + VPADDB(mask[1], v0, v0) // add 0x1f so that letters (0x61..0x7a after the OR) land on 0x80..0x99 + VPCMPGTB(v0, mask[2], v0) // v0 = 0xff where the byte is a letter: signed 0x9a > v0 holds only for 0x80..0x99 + VPAND(v2, v0, v0) // keep only letters whose difference is exactly the case bit + VPAND(mask[3], v0, v0) // reduce each 0xff flag to 0x01 + VPSLLW(Imm(5), v0, v0) // shift 0x01 up to 0x20, the case bit position + VPCMPEQB(v1, v0, v0) // v0 = 0xff where the actual difference equals the permitted one (0x00 or 0x20) +}