Skip to content

Commit

Permalink
Finish initial arm support and remove sse2neon
Browse files Browse the repository at this point in the history
  • Loading branch information
tjake committed Jun 16, 2024
1 parent 35ffb08 commit 4c75d68
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 9,444 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ public void batchDotProduct(
case Q4 -> switch (vectorType) {
case AVX_256 -> new GemmerI8Q4_256(K, a, b, result, aColumnOffset, bColumnOffset);
case AVX_512 -> new GemmerI8Q4_512(K, a, b, result, aColumnOffset, bColumnOffset);
case ARM_128 -> new GemmerI8Q4_arm(K, a, b, result, aColumnOffset, bColumnOffset);
default -> throw new UnsupportedOperationException(vectorType.name());
};
default -> throw new UnsupportedOperationException(
Expand Down Expand Up @@ -678,6 +679,130 @@ protected BiIntConsumer initMatmul1x4() {
}
}

/**
 * Gemmer for a {@link Q8ByteBufferTensor} left operand and a {@link Q4ByteBufferTensor} right
 * operand, written against 128-bit vector species (this is the class dispatched for
 * {@code ARM_128}).
 *
 * <p>Only the 1x1 kernel is implemented. The 1x4, 3x4 and 4x1 kernel slots are kept as fields
 * but are {@code null}, and {@code pickKernel} never dispatches to them.
 */
private class GemmerI8Q4_arm extends Gemmer {
    final BiIntConsumer matmul1x1;
    final BiIntConsumer matmul1x4;
    final BiIntConsumer matmul3x4;
    final BiIntConsumer matmul4x1;

    final Q8ByteBufferTensor a;
    final Q4ByteBufferTensor b;

    /**
     * Creates the gemmer and builds its single 1x1 kernel.
     *
     * @param k forwarded to {@code Gemmer}; the kernel divides it by the Q8 block size to get
     *     the number of blocks to process
     * @param ta left operand; must be a {@link Q8ByteBufferTensor}
     * @param tb right operand; must be a {@link Q4ByteBufferTensor}
     * @param c tensor the kernel writes each result into
     * @param aColumnOffset starting offset used when reading {@code ta}
     * @param bColumnOffset starting offset used when reading {@code tb}
     * @throws ClassCastException if {@code ta} or {@code tb} is not of the required tensor type
     */
    GemmerI8Q4_arm(
            int k, AbstractTensor ta, AbstractTensor tb, AbstractTensor c, int aColumnOffset, int bColumnOffset) {
        super(k, ta, tb, c, aColumnOffset, bColumnOffset);

        this.a = (Q8ByteBufferTensor) ta;
        this.b = (Q4ByteBufferTensor) tb;

        // The wider tile kernels have not been written for this target yet.
        this.matmul1x4 = null;
        this.matmul3x4 = null;
        this.matmul4x1 = null;
        this.matmul1x1 = initMatmul1x1();
    }

    /**
     * Runs the 1x1 kernel over the given range and reports the tile shape that was used.
     *
     * <p>The dispatch sketched for when the other kernels exist is, in order: 3x4 when at least
     * 3 remain in the m range and 4 in the n range, then 4x1 when at least 4 remain in m, then
     * 1x4 when at least 4 remain in n, and finally 1x1. Today only the last case is active.
     *
     * @return the tile shape packed as {@code (mc << 4) | nc}; always {@code (1 << 4) | 1}
     */
    @Override
    protected int pickKernel(int m0, int m, int n0, int n) {
        final int tileM = 1;
        final int tileN = 1;

        kernel(m0, m, tileM, n0, n, tileN, matmul1x1);

        return (tileM << 4) | tileN;
    }

    /**
     * Builds the 1x1 kernel: for a pair {@code (i, j)} it accumulates the scaled dot product of
     * row {@code i} of {@code a} with row {@code j} of {@code b} and stores the lane-sum into
     * {@code c} at {@code (i, j)}.
     */
    protected BiIntConsumer initMatmul1x1() {
        return (i, j) -> {
            final int elementsPerBlock = Q8ByteBufferTensor.BLOCK_SIZE;
            final int blockCount = k / elementsPerBlock;
            final int scalesPerLoad = FloatVector.SPECIES_128.length();

            int aPos = aColumnOffset;
            int bPos = bColumnOffset;

            FloatVector total = FloatVector.zero(FloatVector.SPECIES_128);

            // Each outer step loads one vector of per-block scales from both tensors and
            // multiplies them lane-wise, yielding the combined scale for the next few blocks.
            // NOTE(review): the inner loop always consumes a full vector of blocks, so
            // blockCount is presumably a multiple of the species length - confirm with callers.
            for (int block = 0; block < blockCount; block += scalesPerLoad) {
                // NOTE(review): the Q8 I_BLOCK_SIZE constant is used for both tensors; this
                // assumes the Q4 tensor shares the same block size - confirm.
                final var aScales = a.getBlockF()
                        .getVector(FloatVector.SPECIES_128, i, (int) (Q8ByteBufferTensor.I_BLOCK_SIZE * aPos));
                final var bScales = b.getBlockF()
                        .getVector(FloatVector.SPECIES_128, j, (int) (Q8ByteBufferTensor.I_BLOCK_SIZE * bPos));
                final var combinedScales = aScales.mul(bScales);

                // One block of quantized data per scale lane.
                for (int lane = 0; lane < scalesPerLoad; lane++) {
                    final var scale = FloatVector.broadcast(FloatVector.SPECIES_128, combinedScales.lane(lane));

                    // Left operand: two 16-byte loads, each widened to two vectors of 8 shorts.
                    final var aFront = a.getVector(ByteVector.SPECIES_128, i, aPos);
                    final var aBack = a.getVector(ByteVector.SPECIES_128, i, aPos + 16);

                    final var a0 = aFront.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
                    final var a1 = aFront.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 1);
                    final var a2 = aBack.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
                    final var a3 = aBack.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 1);

                    // Right operand: two 8-byte loads of packed 4-bit values. Each byte is split
                    // into its low nibble and its high nibble (shift, then mask), the bias
                    // constant is subtracted, and the 8 results are widened to shorts.
                    final var bFront = b.getVector(ByteVector.SPECIES_64, j, bPos);
                    final var bBack = b.getVector(ByteVector.SPECIES_64, j, bPos + 16);

                    final var frontLow = bFront.lanewise(VectorOperators.AND, Q4_BYTE_MASK_64)
                            .sub(Q4_BYTE_SUB_64)
                            .castShape(ShortVector.SPECIES_128, 0);
                    final var frontHigh = bFront.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_64)
                            .lanewise(VectorOperators.AND, Q4_BYTE_MASK_64)
                            .sub(Q4_BYTE_SUB_64)
                            .castShape(ShortVector.SPECIES_128, 0);

                    final var backLow = bBack.lanewise(VectorOperators.AND, Q4_BYTE_MASK_64)
                            .sub(Q4_BYTE_SUB_64)
                            .castShape(ShortVector.SPECIES_128, 0);
                    final var backHigh = bBack.lanewise(VectorOperators.ASHR, Q4_BYTE_SHIFT_64)
                            .lanewise(VectorOperators.AND, Q4_BYTE_MASK_64)
                            .sub(Q4_BYTE_SUB_64)
                            .castShape(ShortVector.SPECIES_128, 0);

                    // Low nibbles pair with the first 16 values of a, high nibbles with the
                    // last 16. The four products are summed in 16-bit lanes.
                    final ShortVector blockDot = ShortVector.zero(ShortVector.SPECIES_128)
                            .add(a0.mul(frontLow))
                            .add(a1.mul(backLow))
                            .add(a2.mul(frontHigh))
                            .add(a3.mul(backHigh));

                    // Widen both halves of the 8 short sums to float, apply the block scale,
                    // and fold them into the running total (part 0 first, then part 1).
                    total = total.add(blockDot
                                    .convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 0)
                                    .mul(scale))
                            .add(blockDot
                                    .convertShape(VectorOperators.S2F, FloatVector.SPECIES_128, 1)
                                    .mul(scale));

                    aPos += elementsPerBlock;
                    bPos += elementsPerBlock;
                }
            }

            c.set(total.reduceLanes(VectorOperators.ADD), i, j);
        };
    }
}


private class GemmerI8Q4_256 extends Gemmer {
final BiIntConsumer matmul1x1;
final BiIntConsumer matmul1x4;
Expand Down
Loading

0 comments on commit 4c75d68

Please sign in to comment.