AVX 256, MOE and Distributed kinda done
tjake committed Jun 4, 2024
1 parent 5aa01f4 commit e2ea5fb
Showing 13 changed files with 413 additions and 194 deletions.
166 changes: 90 additions & 76 deletions jlama-core/src/main/java/com/github/tjake/jlama/model/MoEBlock.java
@@ -70,91 +70,105 @@ public MoEBlock(
 
     @Override
     public AbstractTensor forward(AbstractTensor lnemb, Optional<Consumer<List<AbstractTensor>>> tensorReducer) {
+        int batchSize = lnemb.shape().first();
 
         int hiddenLength = model.c.hiddenLength;
-        AbstractTensor result = model.makeTensor(model.c.embeddingLength);
+        AbstractTensor result = model.makeTensor(batchSize, model.c.embeddingLength);
 
-        try (AbstractTensor buf = model.makeTensor(hiddenLength);
-                AbstractTensor buf2 = model.makeTensor(hiddenLength);
-                AbstractTensor moeResult = model.makeTensor(model.c.embeddingLength)) {
+        try (AbstractTensor buf = model.makeTensor(1, hiddenLength);
+                AbstractTensor buf2 = model.makeTensor(1, hiddenLength);
+                AbstractTensor moeResult = model.makeTensor(1, model.c.embeddingLength)) {
 
-            // Apply each experts gate to the input
-            VectorMath.pfor(0, numberOfExperts, i -> {
-                expertResults.set(
-                        TensorOperationsProvider.get()
-                                .dotProduct(
-                                        lnemb,
-                                        moeGateWeight.slice(true, i),
-                                        model.c.embeddingSegmentStart(),
-                                        model.c.embeddingSegmentStart(),
-                                        model.c.embeddingSegmentLength()),
-                        i);
-            });
-
-            tensorReducer.ifPresent(func -> {
-                func.accept(Collections.singletonList(expertResults));
-            });
-
-            // Pick the top experts for this token
-            VectorMath.softMax(expertResults, 0, numberOfExperts);
-            topk(expertResults);
-
-            // Apply the selected experts to the input
-            for (int i = 0; i < numberOfExpertsPerToken; i++) {
-                batchWeights[0] = fullyConnectedWeights[selectedExperts[i]];
-                batchWeights[1] = upProjectionWeights[selectedExperts[i]];
-                AbstractTensor projectionWeight = projectionWeights[selectedExperts[i]];
-                batchResults[0] = buf;
-                batchResults[1] = buf2;
-
-                VectorMath.pchunk(0, hiddenLength, (chunkStart, chunkSize) -> {
-                    TensorOperationsProvider.get()
-                            .dotProductBatchChunk(
-                                    batchResults,
-                                    lnemb,
-                                    batchWeights,
-                                    model.c.embeddingSegmentStart(),
-                                    model.c.embeddingSegmentLength(),
-                                    chunkStart,
-                                    chunkSize);
-                });
-
-                tensorReducer.ifPresent(func -> {
-                    tmpTensors1.clear();
-                    tmpTensors1.add(buf);
-                    tmpTensors1.add(buf2);
-                    func.accept(tmpTensors1);
-                });
-
-                VectorMath.pfor(0, hiddenLength, iv -> {
-                    float w1 = buf.get(iv);
-                    float w1a = ActivationFunction.eval(activationFunction, w1);
-                    buf.set(w1a, iv);
-                });
-
-                TensorOperationsProvider.get().maccumulate(buf, buf2, 0, hiddenLength);
-
-                // matmul the projection and sum into result
-                VectorMath.pchunk(
-                        model.c.embeddingSegmentStart(), model.c.embeddingSegmentLength(), (chunkStart, chunkSize) -> {
-                            TensorOperationsProvider.get()
-                                    .dotProductChunk(
-                                            moeResult, buf, projectionWeight, 0, hiddenLength, chunkStart, chunkSize);
-                        });
-
-                if (i == 0) {
-                    result.copyFrom(
-                            moeResult,
-                            moeResult.getOffset(model.c.embeddingSegmentStart()),
-                            result.getOffset(model.c.embeddingSegmentStart()),
-                            model.c.embeddingSegmentLength());
-                } else {
-                    TensorOperationsProvider.get()
-                            .accumulate(
-                                    result,
-                                    moeResult,
-                                    model.c.embeddingSegmentStart(),
-                                    model.c.embeddingSegmentLength());
-                }
-            }
+            for (int b = 0; b < batchSize; b++) {
+                AbstractTensor lnembSlice = lnemb.slice(true, b);
+                // Apply each experts gate to the input
+                VectorMath.pfor(0, numberOfExperts, i -> {
+                    expertResults.set(
+                            TensorOperationsProvider.get()
+                                    .dotProduct(
+                                            lnembSlice,
+                                            moeGateWeight.slice(true, i),
+                                            model.c.embeddingSegmentStart(),
+                                            model.c.embeddingSegmentStart(),
+                                            model.c.embeddingSegmentLength()),
+                            0, i);
+                });
+
+                tensorReducer.ifPresent(func -> {
+                    func.accept(Collections.singletonList(expertResults));
+                });
+
+                // Pick the top experts for this token
+                VectorMath.softMax(expertResults, 0, numberOfExperts);
+                topk(expertResults);
+
+                // Apply the selected experts to the input
+                for (int i = 0; i < numberOfExpertsPerToken; i++) {
+                    batchWeights[0] = fullyConnectedWeights[selectedExperts[i]];
+                    batchWeights[1] = upProjectionWeights[selectedExperts[i]];
+                    AbstractTensor projectionWeight = projectionWeights[selectedExperts[i]];
+                    batchResults[0] = buf;
+                    batchResults[1] = buf2;
+
+                    VectorMath.pchunk(0, hiddenLength, (chunkStart, chunkSize) -> {
+                        TensorOperationsProvider.get()
+                                .dotProductBatchChunk(
+                                        batchResults,
+                                        lnembSlice,
+                                        batchWeights,
+                                        model.c.embeddingSegmentStart(),
+                                        model.c.embeddingSegmentLength(),
+                                        chunkStart,
+                                        chunkSize);
+                    });
+
+                    tensorReducer.ifPresent(func -> {
+                        tmpTensors1.clear();
+                        tmpTensors1.add(buf);
+                        tmpTensors1.add(buf2);
+                        func.accept(tmpTensors1);
+                    });
+
+                    VectorMath.pfor(0, hiddenLength, iv -> {
+                        float w1 = buf.get(0, iv);
+                        float w1a = ActivationFunction.eval(activationFunction, w1);
+                        buf.set(w1a, 0, iv);
+                    });
+
+                    TensorOperationsProvider.get().maccumulate(buf, buf2, 0, hiddenLength);
+
+                    // matmul the projection and sum into result
+                    try (AbstractTensor bufq = model.maybeQuantize(buf)) {
+                        VectorMath.pchunk(
+                                model.c.embeddingSegmentStart(),
+                                model.c.embeddingSegmentLength(),
+                                (chunkStart, chunkSize) -> {
+                                    TensorOperationsProvider.get()
+                                            .dotProductChunk(
+                                                    moeResult,
+                                                    bufq,
+                                                    projectionWeight,
+                                                    0,
+                                                    hiddenLength,
+                                                    chunkStart,
+                                                    chunkSize);
+                                });
+                    }
+
+                    if (i == 0) {
+                        result.copyFrom(
+                                moeResult,
+                                moeResult.getOffset(0, model.c.embeddingSegmentStart()),
+                                result.getOffset(b, model.c.embeddingSegmentStart()),
+                                model.c.embeddingSegmentLength());
+                    } else {
+                        TensorOperationsProvider.get()
+                                .accumulate(
+                                        result.slice(b),
+                                        moeResult,
+                                        model.c.embeddingSegmentStart(),
+                                        model.c.embeddingSegmentLength());
+                    }
+                }
+            }
 

@@ -170,11 +184,11 @@ private int[] topk(FloatBufferTensor probs) {
         for (int i = numberOfExpertsPerToken; i < length; i++) {
             int min = 0;
             for (int j = 1; j < numberOfExpertsPerToken; j++) {
-                if (probs.get(selectedExperts[j]) < probs.get(selectedExperts[min])) {
+                if (probs.get(0, selectedExperts[j]) < probs.get(0, selectedExperts[min])) {
                     min = j;
                 }
             }
-            if (probs.get(i) > probs.get(selectedExperts[min])) {
+            if (probs.get(0, i) > probs.get(0, selectedExperts[min])) {
                 selectedExperts[min] = i;
            }
         }
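A note on the topk hunk above: it keeps a running set of the numberOfExpertsPerToken highest gate probabilities by repeatedly evicting the current minimum, an O(n*k) partial selection. A minimal standalone sketch of the same idea, written over a plain float[] rather than jlama's FloatBufferTensor (the method name and signature here are illustrative, not jlama API):

    // Sketch: return the indices of the k largest values in probs.
    // Assumes k <= probs.length.
    static int[] topk(float[] probs, int k) {
        int[] selected = new int[k];
        for (int i = 0; i < k; i++) {
            selected[i] = i; // seed with the first k indices
        }
        for (int i = k; i < probs.length; i++) {
            // find the smallest of the currently selected probabilities
            int min = 0;
            for (int j = 1; j < k; j++) {
                if (probs[selected[j]] < probs[selected[min]]) {
                    min = j;
                }
            }
            // replace it if the candidate is larger
            if (probs[i] > probs[selected[min]]) {
                selected[min] = i;
            }
        }
        return selected;
    }

For the small k used per token this beats a full sort, and it is why the commit only had to add a row index (0) to the probs.get calls when the tensor became two-dimensional.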
@@ -52,9 +52,9 @@ protected EmbedInput loadInputWeights() {
         final AbstractTensor wpe = weights.load("wpe.weight");
 
         return (inputToken, position) -> {
-            AbstractTensor embedding = makeTensor(c.embeddingLength);
+            AbstractTensor embedding = makeTensor(1, c.embeddingLength);
 
-            for (int i = 0; i < c.embeddingLength; i++) {
+            for (int i = c.embeddingSegmentStart(); i < c.embeddingSegmentLength(); i++) {
                 float v = wte.get(inputToken, i) + wpe.get(position, i);
                 embedding.set(v, 0, i);
             }
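The lambda above computes a GPT-2-style input embedding: the token row of wte plus the position row of wpe, now written into an explicit 1 x embeddingLength tensor and restricted to this node's segment. A minimal array-based sketch of the computation (illustrative names, not jlama API):

    // Sketch: input embedding = token embedding + position embedding.
    // wte is [vocab][emb], wpe is [maxPosition][emb].
    static float[] embed(float[][] wte, float[][] wpe, int token, int position) {
        int emb = wte[0].length;
        float[] out = new float[emb];
        for (int i = 0; i < emb; i++) {
            out[i] = wte[token][i] + wpe[position][i];
        }
        return out;
    }

In the distributed case only the slice covered by embeddingSegmentStart/embeddingSegmentLength would be filled on each node.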
@@ -62,7 +62,7 @@ protected EmbedInput loadInputWeights() {
                 .quantize(workingDType); // Don't quantize this, it's used for the embedding layer
 
         return (inputToken, position) -> {
-            AbstractTensor embedding = makeTensor(c.embeddingLength);
+            AbstractTensor embedding = makeTensor(1, c.embeddingLength);
             embedding.copyFrom(
                     wte,
                     wte.getOffset(inputToken, c.embeddingSegmentStart()),
@@ -81,7 +81,7 @@ public final int getOffset(int... pdims)
         switch (pdims.length)
         {
             case 1:
-                return sparseLength * pdims[0];
+                return sparseLength * pdims[0] - sparseOffset;
             case 2:
                 return sparseLength * pdims[0] + pdims[1] - sparseOffset; //Most common case
             case 3:
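The getOffset fix makes the one-dimensional case consistent with the two-dimensional one: a sparse tensor physically stores only columns [sparseOffset, sparseOffset + sparseLength), so a logical index must be shifted by sparseOffset before addressing the backing buffer. A hedged sketch of the addressing rule with illustrative numbers (not jlama code):

    // Sketch: physical offset = sparseLength * row + col - sparseOffset.
    // Example: sparseLength = 4096, sparseOffset = 8192, so this shard
    // holds columns 8192..12287. Logical (row = 2, col = 8200) maps to
    // 4096 * 2 + 8200 - 8192 = 8200 in the local buffer.
    static int offset(int sparseLength, int sparseOffset, int row, int col) {
        return sparseLength * row + col - sparseOffset;
    }

Before the fix, the single-index case returned an offset into the full (unsharded) coordinate space, which only agreed with the two-index case when sparseOffset was zero.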
@@ -89,7 +89,7 @@ public void batchDotProduct(AbstractTensor result, AbstractTensor a, AbstractTen
         Preconditions.checkArgument(a.dims() == 2 && b.dims() == 2 && result.dims() == 2);
         Preconditions.checkArgument(a.shape().dim(0) == result.shape().dim(0), "BAD M");
         Preconditions.checkArgument(b.shape().dim(0) == result.shape().dim(1), "BAD N");
-        Preconditions.checkArgument(a.shape().dim(1) == b.shape().dim(1), "BAD K");
+        //Preconditions.checkArgument(a.shape().dim(1) == b.shape().dim(1), "BAD K");
 
         int M = a.shape().dim(0);
         int N = rowChunkSize; //b.shape().dim(0);
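For context on the shape checks above: batchDotProduct is GEMM-shaped, with M the number of rows of a (the batch), N the number of rows of b (one weight row per output feature), and K the shared inner dimension; the commit relaxes the K check but leaves M and N enforced. A naive reference version as a sketch (plain arrays, nothing like jlama's vectorized path):

    // Naive reference: result[m][n] = sum over k of a[m][k] * b[n][k].
    // Both operands are row-major with K as the inner dimension, so each
    // output element is a dot product of a row of a with a row of b.
    static void batchDotProduct(float[][] result, float[][] a, float[][] b) {
        int M = a.length, N = b.length, K = a[0].length;
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < N; n++) {
                float acc = 0f;
                for (int k = 0; k < K; k++) {
                    acc += a[m][k] * b[n][k];
                }
                result[m][n] = acc;
            }
        }
    }

The commit itself does not say why the K precondition is commented out, so that motivation is left as-is here.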
2 changes: 1 addition & 1 deletion jlama-native/pom.xml
@@ -216,7 +216,7 @@
                 <env key="LIB_DIR" value="${nativeLibOnlyDir}" />
                 <env key="OBJ_DIR" value="${nativeObjsOnlyDir}" />
                 <env key="JNI_PLATFORM" value="${jni.platform}" />
-                <env key="CFLAGS" value="-O3 -mavx512f -march=native -std=c11 -Werror -Wno-attributes -fPIC -fno-omit-frame-pointer -Wunused-variable" />
+                <env key="CFLAGS" value="-g -O3 -mavx2 -march=native -Werror -Wno-attributes -fPIC -fno-omit-frame-pointer -Wunused-variable" />
                 <env key="LDFLAGS" value="-shared" />
                 <env key="LIB_NAME" value="${nativeLibName}" />
                 <env KEY="LIB_EXT" value="so" />
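The CFLAGS change swaps -mavx512f for -mavx2 (and adds debug symbols with -g), matching the commit title's "AVX 256": the native kernels now target 256-bit vectors, which far more CPUs support than AVX-512. On the Java side, the Panama Vector API reports the best SIMD width at runtime; a small probe using the standard jdk.incubator.vector API (not part of this commit) shows which width a given machine gets:

    import jdk.incubator.vector.FloatVector;
    import jdk.incubator.vector.VectorSpecies;

    // Prints the JVM's preferred float vector width: typically 256 bits on
    // AVX2-only hardware, 512 where AVX-512 is available and enabled.
    // Run with: java --add-modules jdk.incubator.vector VectorWidth
    public class VectorWidth {
        public static void main(String[] args) {
            VectorSpecies<Float> species = FloatVector.SPECIES_PREFERRED;
            System.out.println("Preferred float vector width: " + species.vectorBitSize() + " bits");
        }
    }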