Gemm support for batch processing (#30)
* F32 GEMM

* fix tests

* Adding matrix operations to other methods

* initial avx512 batch code

* Finish AVX512 Native Gemm

* AVX 256, MOE and Distributed kinda done

* use arm style api and 128 bit versions

* Clean up dead code and fix sparse handling

* Add missing files

* fix compile
tjake committed Jun 16, 2024
1 parent 8d3b5ba commit 31cab51
Showing 47 changed files with 4,143 additions and 2,982 deletions.
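The headline change is an F32 GEMM (general matrix multiply) path used for batch processing. As a point of reference for what a GEMM computes, here is a minimal single-threaded sketch over row-major float[] matrices; the method name and signature are illustrative only and do not appear in this commit, which uses vectorized (Panama/AVX) kernels instead.

    // Reference F32 GEMM sketch: C = A * B with A (m x k), B (k x n), C (m x n),
    // all stored row-major. Illustrative only; not the committed kernel.
    static void gemmF32(float[] a, float[] b, float[] c, int m, int k, int n) {
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                float acc = 0f;
                for (int p = 0; p < k; p++) {
                    acc += a[i * k + p] * b[p * n + j];
                }
                c[i * n + j] = acc;
            }
        }
    }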
1 change: 0 additions & 1 deletion .github/workflows/unit-tests.yaml
@@ -2,7 +2,6 @@ name: Unit Test CI

on:
  workflow_dispatch:
  pull_request:
  push:
    paths:
      - .github/workflows/unit-tests.yaml
5 changes: 3 additions & 2 deletions inlinerules.json
@@ -4,8 +4,9 @@
    inline: [
      "+com.github.tjake.jlama.tensor.operations.PanamaTensorOperations::dotProduct*",
      "+com.github.tjake.jlama.tensor.operations.PanamaTensorOperations::quantize*",
      "+com.github.tjake.jlama.tensor.*::getVector*",
      "+com.github.tjake.jlama.tensor.*::intoTensor*"
      "+com.github.tjake.jlama.tensor.operations.PanamaTensorOperations*::mpack*",
      "+com.github.tjake.jlama.tensor.AbstractTensor::getOffset*",
      "+com.github.tjake.jlama.tensor.*::getVector*"
    ]
  }
]
@@ -38,41 +38,45 @@ public static void pfor(int start, int end, IntConsumer action) {
public static void pchunk(int offset, int length, BiIntConsumer action) {
int splits = Math.min(length, TensorOperationsProvider.get().parallelSplitSize());
int chunkSize = length / splits;
int remainder = 0;

// Non optimal case, just run in parallel
if (splits == 1 || splits % length != 0) {
if (splits == 1) {
splits = length;
chunkSize = 1;
} else if (length % chunkSize != 0) {
remainder = length % chunkSize;
}

int fsplits = splits;
int fchunkSize = chunkSize;
int fremainder = remainder;

PhysicalCoreExecutor.instance.get().execute(() -> IntStream.range(0, fsplits)
.parallel()
.forEach(i -> action.accept(offset + (i * fchunkSize), fchunkSize)));
.forEach(i -> action.accept(offset + (i * fchunkSize), fremainder > 0 ? fchunkSize + fremainder : fchunkSize)));

}

public static void softMax(FloatBufferTensor x) {
int offset = 0;
long size = x.size();
public static void softMax(AbstractTensor x, int offset, int length) {
long size = offset + length;

// find max value (for numerical stability)
float max_val = x.get(offset);
float max_val = x.get(0, offset);
for (int i = offset + 1; i < size; i++) {
if (x.get(i) > max_val) {
max_val = x.get(i);
if (x.get(0, i) > max_val) {
max_val = x.get(0, i);
}
}
// exp and sum
float sum = 0.0f;
for (int i = offset; i < size; i++) {
x.set((float) StrictMath.exp(x.get(i) - max_val), i);
sum += x.get(i);
x.set((float) StrictMath.exp(x.get(0, i) - max_val), 0, i);
sum += x.get(0, i);
}
// normalize
for (int i = 0; i < size; i++) {
x.set(x.get(i) / sum, i);
x.set(x.get(0, i) / sum, 0, i);
}
}

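softMax now takes an explicit offset and length, so it can normalize a window of a row rather than assuming the whole tensor. A minimal sketch of the same computation over a plain float[], assuming the window is x[offset .. offset+length); the array stand-in and method name are illustrative, not the project's API.

    // Softmax over x[offset .. offset+length) with the usual max-subtraction for
    // numerical stability, mirroring the tensor version above. Illustrative only.
    static void softMaxWindow(float[] x, int offset, int length) {
        int end = offset + length;
        float max = x[offset];
        for (int i = offset + 1; i < end; i++) {
            if (x[i] > max) max = x[i];
        }
        float sum = 0f;
        for (int i = offset; i < end; i++) {
            x[i] = (float) StrictMath.exp(x[i] - max);
            sum += x[i];
        }
        for (int i = offset; i < end; i++) {
            x[i] /= sum;
        }
    }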
@@ -200,50 +200,52 @@ public AbstractTensor forward(
}

protected AbstractTensor batchForward(int[] token_ids, int startPos, AbstractTensor kvbuf) {
AbstractTensor last = null;
for (int i = 0; i < token_ids.length; i++) {
if (last != null) last.close();

last = forward(token_ids[i], startPos + i, kvbuf);
AbstractTensor embedding = embedInput.batchInputsToEmbeddings(token_ids, startPos);
for (int i = c.layerStart(); i < c.layerEnd(); i++) {
AbstractTensor kvlayer = kvbuf.slice(true, i);
AbstractTensor ref = embedding; // reference so we can free
embedding = transformerBlocks[i].forward(embedding, startPos, kvlayer, Optional.empty(), Optional.empty());
ref.close();
}

return last;
return embedding;
}

public int sample(AbstractTensor output, float temperature, float uniformSample, AbstractTensor logits) {
try (AbstractTensor embedding = sampleOutput.getOutputLayerNorm().forward(output)) {
AtomicReference<Double> maxv = new AtomicReference<>(Double.NEGATIVE_INFINITY);
AtomicInteger maxi = new AtomicInteger(Integer.MIN_VALUE);
//AtomicReference<Double> maxv = new AtomicReference<>(Double.NEGATIVE_INFINITY);
//AtomicInteger maxi = new AtomicInteger(Integer.MIN_VALUE);

// This is a mix of argmax and sampling with softmax
VectorMath.pfor(0, c.vocabularySize, i -> {
float v = TensorOperationsProvider.get()
.dotProduct(
embedding, sampleOutput.getOutputLogitsWeights().slice(i), c.embeddingLength);
logits.set(v, i);
maxv.getAndUpdate(x -> {
if (v > x) {
maxi.set(i);
return (double) v;
}
return x;
});
VectorMath.pchunk(0, c.vocabularySize, (chunkStart, chunkSize) -> {
TensorOperationsProvider.get().dotProductChunk(logits, embedding, sampleOutput.getOutputLogitsWeights(), 0, c.embeddingLength, chunkStart, chunkSize);
});

int maxi = Integer.MIN_VALUE;
double maxv = Double.NEGATIVE_INFINITY;
for (int i = 0; i < c.vocabularySize; i++) {
float v = logits.get(0, i);
if (v > maxv) {
maxi = i;
maxv = v;
}
}

if (temperature == 0.0) {
return maxi.get();
return maxi;
}

float sum = 0;
for (int i = 0; i < c.vocabularySize; i++) {
float v = (float) Math.exp((logits.get(i) - maxv.get()) / temperature);
float v = (float) Math.exp((logits.get(0, i) - maxv) / temperature);
sum += v;
logits.set(v, i);
logits.set(v, 0, i);
}

float acc = 0;
for (int i = 0; i < c.vocabularySize; i++) {
float v = logits.get(i) / sum;
float v = logits.get(0, i) / sum;
acc += v;
if (acc >= uniformSample) return i;
}
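sample() now fills the logits in chunks via dotProductChunk instead of issuing one dotProduct per vocabulary row, which is what lets the tensor backend batch the work. A naive sketch of what one chunked call computes, with plain arrays standing in for the tensors; the signature below is illustrative and is not the TensorOperations API.

    // For each vocabulary row in [chunkStart, chunkStart + chunkSize), take the dot
    // product of the embedding with that row of the output weight matrix over the
    // first `limit` elements and store it in logits[row]. Illustrative only.
    static void dotProductChunkSketch(float[] logits, float[] embedding,
                                      float[][] weights, int limit,
                                      int chunkStart, int chunkSize) {
        for (int row = chunkStart; row < chunkStart + chunkSize; row++) {
            float acc = 0f;
            for (int j = 0; j < limit; j++) {
                acc += embedding[j] * weights[row][j];
            }
            logits[row] = acc;
        }
    }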
@@ -265,7 +267,7 @@ public void generate(

if (ntokens > c.contextLength) ntokens = c.contextLength;

AbstractTensor kvmem = makeTensor(c.getNumberOfLayers(), ntokens, 2, c.kvLength); // k and v are last 2 dims
AbstractTensor kvmem = makeTensor(c.getNumberOfLayers(), 2, ntokens, c.kvLength); // k and v for context window
AbstractTensor logits = makeTensor(c.vocabularySize);

int[] promptTokens = new int[useEOS ? (1 + encoded.length + 1) : (1 + encoded.length)];
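The kv cache shape changes from (layers, ntokens, 2, kvLength) to (layers, 2, ntokens, kvLength), so kvbuf.slice(true, i) hands each transformer block a contiguous key block followed by a contiguous value block for the whole context window. Assuming a dense row-major layout, the flat index math would look like the sketch below; the helper is hypothetical and only illustrates the ordering.

    // Hypothetical flat index into a (layers, 2, ntokens, kvLength) cache laid out
    // row-major: kv = 0 selects keys, kv = 1 selects values for position `pos`.
    static int kvIndex(int layer, int kv, int pos, int dim, int ntokens, int kvLength) {
        return ((layer * 2 + kv) * ntokens + pos) * kvLength + dim;
    }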
@@ -287,39 +289,43 @@
AbstractTensor last = batchForward(promptTokens, 0, kvmem);

long promptBatchTime = System.currentTimeMillis() - start;
float avgTime = Math.round((((double) promptBatchTime) / (double) promptLength));
logger.debug("{} prompt tokens in {}ms | {}ms per token", promptLength, promptBatchTime, avgTime);
float batchMsPerToken = Math.round((((double) promptBatchTime) / (double) promptLength));
logger.debug("{} prompt tokens in {}ms | {}ms per token", promptLength, promptBatchTime, batchMsPerToken);

float genMsPerToken = 0;
int tokensGenerated = 0;
int next = sample(last, temperature, ThreadLocalRandom.current().nextFloat(), logits);
int next = sample(last.slice(promptTokens.length - 1), temperature, ThreadLocalRandom.current().nextFloat(), logits);
last.close();
try {
String c = tokenizer.decode(next);
onTokenWithTimings.accept(c, avgTime);
onTokenWithTimings.accept(c, batchMsPerToken);
} catch (Exception e) {
logger.error("Failed to decode token {}", next, e);
}

start = System.currentTimeMillis();
for (int i = promptTokens.length - 1; i < ntokens; i++) {
AbstractTensor output = forward(next, i, kvmem);
tokensGenerated++;
next = sample(output, temperature, ThreadLocalRandom.current().nextFloat(), logits);

if (logger.isTraceEnabled()) logger.trace("Sampled token {} with temperature {}", next, temperature);

output.close();
// Model may tell us it's done
if (next == c.eosToken) break;

try {
String c = tokenizer.decode(next);
onTokenWithTimings.accept(c, (System.currentTimeMillis() - start) / (float) (i + 1));
genMsPerToken = (System.currentTimeMillis() - start) / (float) (tokensGenerated);
onTokenWithTimings.accept(c, genMsPerToken);
} catch (Exception e) {
logger.error("Failed to decode token {}", next, e);
}
}

long end = System.currentTimeMillis();
System.out.printf(
"\n\nelapsed: %ds, %fms per token\n",
TimeUnit.MILLISECONDS.toSeconds(end - start), ((end - start) / (float) tokensGenerated));
"\n\nelapsed: %ds, prompt %.1fms per token, gen %.1fms per token\n",
TimeUnit.MILLISECONDS.toSeconds(end - start), batchMsPerToken, genMsPerToken);
}
}
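For reference, the sampling path in sample() above is an argmax at temperature zero and otherwise a temperature-scaled softmax followed by inverse-CDF selection. A compact sketch over a plain float[] of logits, with an illustrative method name:

    // Temperature sampling sketch: argmax when temperature == 0, otherwise softmax
    // over (logit - max) / temperature and return the first index whose cumulative
    // probability reaches the uniform sample. Illustrative only.
    static int sampleToken(float[] logits, float temperature, float uniformSample) {
        int maxi = 0;
        float maxv = logits[0];
        for (int i = 1; i < logits.length; i++) {
            if (logits[i] > maxv) { maxv = logits[i]; maxi = i; }
        }
        if (temperature == 0.0f) return maxi;

        float[] probs = new float[logits.length];
        float sum = 0f;
        for (int i = 0; i < logits.length; i++) {
            probs[i] = (float) Math.exp((logits[i] - maxv) / temperature);
            sum += probs[i];
        }
        float acc = 0f;
        for (int i = 0; i < probs.length; i++) {
            acc += probs[i] / sum;
            if (acc >= uniformSample) return i;
        }
        return probs.length - 1; // guard against rounding leaving acc just below 1
    }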