Skip to content

Commit

Permalink
Clean up dead code and fix sparse handling
Browse files (browse the repository at this point in the history)
  • Loading branch information
tjake committed Jun 16, 2024
1 parent fde3fb8 commit ead0f23
Show file tree
Hide file tree
Showing 15 changed files with 64 additions and 1,422 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -58,7 +58,7 @@ protected AbstractTensor(DType dType, TensorShape shape, boolean cacheSlices) {
this.shape = shape;
this.requiresOffHeapTensor = TensorOperationsProvider.get().requiresOffHeapTensor();
this.sliceCache = cacheSlices ? new AbstractTensor[shape.first()] : null;
this.stride = shape.first() > 1 && dims() == 2 ? getOffset(1, 0) : 0;
this.stride = shape.first() > 1 && dims() == 2 ? getOffset(1, shape.sparseOffset()) : 0;
}

/** Create a new tensor with the given shape of the same Tensor implementation */
Expand Down

Large diffs are not rendered by default.

29 changes: 16 additions & 13 deletions jlama-native/src/main/c/vector_simd.c
Original file line number | Diff line number | Diff line change
Expand Up @@ -151,7 +151,7 @@ void __attribute__((noinline)) gemm(int m0, int m, int n0, int n, void (*gemmPtr
mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
gemm(mp, m, n0, np, gemmPtr, params);
gemm(m0, m, np, n, gemmPtr, params);
gemm(m0, mp, np, n, gemmPtr, params);
}

#if defined(__ARM_NEON__)
Expand Down Expand Up @@ -223,7 +223,7 @@ void __attribute__((noinline)) gemm_q8_q4_128_arm(int m0, int m, int n0, int n,

for (int mi = 0; mi < RM; ++mi) {
for (int ni = 0; ni < RN; ++ni) {
params.r[(params.ldc * (ii + mi)) + (jj + ni)] = vaddvq_f32(sums[mi][ni]);
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = vaddvq_f32(sums[mi][ni]);
}
}
}
Expand Down Expand Up @@ -360,7 +360,7 @@ void __attribute__((noinline)) gemm_q8_q4_128(int m0, int m, int n0, int n, int
dot += result[i];
}
//fprintf(stderr, "ii: %d, ni: %d, jj: %d, mi: %d, ldc: %d\n", ii, ni, jj, mi, params.ldc);
params.r[(params.ldc * (ii + mi)) + (jj + ni)] = dot;
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = dot;
}
}
}
Expand Down Expand Up @@ -436,7 +436,6 @@ void __attribute__((noinline)) gemm_q8_q4_256(int m0, int m, int n0, int n, int
}
}


for (int mi = 0; mi < RM; ++mi) {
for (int ni = 0; ni < RN; ++ni) {
__attribute__((aligned(16))) float result[8];
Expand All @@ -446,8 +445,10 @@ void __attribute__((noinline)) gemm_q8_q4_256(int m0, int m, int n0, int n, int
for(int i = 0; i < 8; ++i) {
dot += result[i];
}
//fprintf(stderr, "ii: %d, ni: %d, jj: %d, mi: %d, ldc: %d\n", ii, ni, jj, mi, params.ldc);
params.r[(params.ldc * (ii + mi)) + (jj + ni)] = dot;
//int idx = (params.ldc * (ii + mi)) + (jj + ni);
//if (idx > params.roffset)
// fprintf(stderr, "ii: %d, ni: %d, jj: %d, mi: %d, ldc: %d, idx: %d, lim: %d\n", ii, ni, jj, mi, params.ldc, idx, params.roffset);
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = dot;
}
}
}
Expand Down Expand Up @@ -533,7 +534,7 @@ void __attribute__((noinline)) gemm_q8_q4_512(int m0, int m, int n0, int n, int
for (int ni = 0; ni < RN; ++ni) {
// Horizontal sum of the vector to get dot product
float dot = _mm512_reduce_add_ps(_mm512_castps256_ps512(sums[mi][ni]));
params.r[(params.ldc * (ii + mi)) + (jj + ni)] = dot;
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = dot;
}
}
}
Expand Down Expand Up @@ -566,6 +567,8 @@ void gemm_q8_q4(int flags, const float * restrict af, const char * restrict a, i
.ldc = ldc
};

//fprintf(stderr, "m: %d, n0: %d, n: %d, k: %d, lda: %d, ldaf: %d, ldb: %d, ldbf: %d, ldc: %d\n", m, n0, n, k, lda, ldaf, ldb, ldbf, ldc);

#if !defined(__ARM_NEON__)
((flags & HAS_AVX2) != 0)
? gemm(0, m, n0, n0 + n, gemm_q8_q4_512, p)
Expand Down Expand Up @@ -628,7 +631,7 @@ void gemm_f32_128(int m0, int m, int n0, int n, int RM, int RN, struct gemm_para
for(int i = 0; i < 4; ++i) {
dot += result[i];
}
params.r[(params.ldc * (ii + mi)) + (jj + ni)] = dot;
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = dot;
}
}
}
Expand Down Expand Up @@ -681,7 +684,7 @@ void gemm_f32_256(int m0, int m, int n0, int n, int RM, int RN, struct gemm_para
for(int i = 0; i < 8; ++i) {
dot += result[i];
}
params.r[(params.ldc * (ii + mi)) + (jj + ni)] = dot;
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = dot;
}
}
}
Expand Down Expand Up @@ -727,7 +730,7 @@ void gemm_f32_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_para
for (int ni = 0; ni < RN; ++ni) {
// Horizontal sum of the vector to get dot product
float r = _mm512_reduce_add_ps(sums[mi][ni]);
params.r[(params.ldc * (ii + mi)) + (jj + ni)] = r;
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = r;
}
}
}
Expand Down Expand Up @@ -912,7 +915,7 @@ void gemm_f32_q4_128(int m0, int m, int n0, int n, int RM, int RN, struct gemm_p
for(int i = 0; i < 4; ++i) {
dot += result[i];
}
params.r[(params.ldc * (ii + mi)) + (jj + ni)] = dot;
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = dot;
}
}
}
Expand Down Expand Up @@ -1023,7 +1026,7 @@ void gemm_f32_q4_256(int m0, int m, int n0, int n, int RM, int RN, struct gemm_p
}
//if (params.roffset > 0)
// fprintf(stderr, "ii: %d, ni: %d, jj: %d, mi: %d, ldc: %d, roffset: %d\n", ii, ni, jj, mi, params.ldc, params.roffset);
params.r[(params.ldc * (ii + mi)) + (jj + ni)] = dot;
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = dot;
}
}
}
Expand Down Expand Up @@ -1106,7 +1109,7 @@ void gemm_f32_q4_512(int m0, int m, int n0, int n, int RM, int RN, struct gemm_p
for (int ni = 0; ni < RN; ++ni) {
// Horizontal sum of the vector to get dot product
float r = _mm512_reduce_add_ps(sums[mi][ni]);
params.r[(params.ldc * (ii + mi)) + (jj + ni)] = r;
params.r[(params.ldc * (ii + mi)) + (jj + ni) - params.roffset] = r;
}
}
}
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -76,7 +76,7 @@ public String name() {
}

private void checkLib() {
NativeSimd.dot_product_f32$MH();
NativeSimd.gemm_f32$MH();
}

@Override
Expand Down Expand Up @@ -107,14 +107,14 @@ public void batchDotProduct(AbstractTensor result, AbstractTensor at, AbstractTe
bt.getMemorySegment(),
bt.getOffset(0, bColumnOffset),
result.getMemorySegment(),
result.getOffset(0, aColumnOffset),
result.shape().sparseOffset(),
M,
bRowOffset,
N,
K,
at.getOffset(1, 0),
bt.getOffset(1, 0),
result.getOffset(1, 0));
at.getStride(),
bt.getStride(),
result.getStride());
break;
case Q4:
Q4ByteBufferTensor b = (Q4ByteBufferTensor) bt;
Expand All @@ -126,15 +126,15 @@ public void batchDotProduct(AbstractTensor result, AbstractTensor at, AbstractTe
b.getMemorySegment(),
b.getMemorySegmentOffset(b.getOffset(0, bColumnOffset)),
result.getMemorySegment(),
result.getOffset(0, aColumnOffset),
result.shape().sparseOffset(),
M,
bRowOffset,
N,
K,
at.getOffset(1, 0),
b.getMemorySegmentOffset(b.getOffset(1, 0)),
b.getBlockF().getOffset(1, 0),
result.getOffset(1, 0));
at.getStride(),
b.getMemorySegmentOffset(b.getStride()),
b.getBlockF().getStride(),
result.getStride());
break;
default:
throw new UnsupportedOperationException(at.dType().name() + " " + bt.dType().name());
Expand All @@ -155,7 +155,7 @@ public void batchDotProduct(AbstractTensor result, AbstractTensor at, AbstractTe
b.getMemorySegment(),
b.getMemorySegmentOffset(b.getOffset(0, bColumnOffset)),
result.getMemorySegment(),
result.getOffset(0, aColumnOffset),
result.shape().sparseOffset(),
M,
bRowOffset,
N,
Expand Down Expand Up @@ -214,13 +214,13 @@ public void dotProductBatchChunk(
rb,
b[0].getOffset(0, columnOffset),
ra,
r[0].getOffset(0, columnOffset),
r[0].shape().sparseOffset(),
M,
bRowOffset,
N,
K,
a.getOffset(1, 0),
b[0].getOffset(1, 0),
a.getStride(),
b[0].getStride(),
r[0].getStride());
break;
case Q4:
Expand All @@ -239,14 +239,14 @@ public void dotProductBatchChunk(
rb,
b[0].getMemorySegmentOffset(b[0].getOffset(0, columnOffset)),
ra,
r[0].getOffset(0, columnOffset),
r[0].shape().sparseOffset(),
M,
bRowOffset,
N,
K,
a.getOffset(1, 0),
b[0].getMemorySegmentOffset(b[0].getOffset(1, 0)),
bt.getBlockF().getOffset(1, 0),
a.getStride(),
b[0].getMemorySegmentOffset(b[0].getStride()),
bt.getBlockF().getStride(),
r[0].getStride());
break;
default:
Expand Down Expand Up @@ -275,15 +275,15 @@ public void dotProductBatchChunk(
rb,
bt.getMemorySegmentOffset(bt.getOffset(0, columnOffset)),
ra,
r[0].getOffset(0, columnOffset),
r[0].shape().sparseOffset(),
M,
bRowOffset,
N,
K,
a.getOffset(1, 0),
at.getBlockF().getOffset(1, 0),
bt.getMemorySegmentOffset(bt.getOffset(1, 0)),
bt.getBlockF().getOffset(1, 0),
a.getStride(),
at.getBlockF().getStride(),
bt.getMemorySegmentOffset(bt.getStride()),
bt.getBlockF().getStride(),
r[0].getStride());
break;
default:
Expand Down

This file was deleted.

Loading

0 comments on commit ead0f23

Please sign in to comment.