Fix bert model, improve chat cmd, rm js libs
tjake committed Jun 19, 2024
1 parent 83b9f86 commit debd64a
Showing 13 changed files with 56 additions and 69 deletions.

@@ -54,6 +54,16 @@ public class ChatCommand extends BaseCommand {
defaultValue = ".9")
protected Float topp;

+    @Option(
+            names = {"--override-prompt-start"},
+            description = "Override of prompt instruction format before the prompt (example: [INST] )")
+    protected String promptStart;
+
+    @Option(
+            names = {"--override-prompt-end"},
+            description = "Override of prompt instruction format after the prompt (example: [/INST] )")
+    protected String promptEnd;
+
@Override
public void run() {
AbstractModel m = loadModel(
@@ -67,20 +77,21 @@ public void run() {
UUID session = UUID.randomUUID();
PrintWriter out = System.console().writer();

-        out.println("Chatting with " + model + "...\n\n");
+        out.println("Chatting with " + model + "...\n");
out.flush();
Scanner sc = new Scanner(System.in);
boolean first = true;
while (true) {
-            out.print("You: ");
+            out.print("\nYou: ");
out.flush();
String prompt = sc.nextLine();
out.println();
out.flush();
if (prompt.isEmpty()) {
break;
}
-            String wrappedPrompt = m.wrapPrompt(prompt, first ? Optional.ofNullable(systemPrompt) : Optional.empty());
+
+            String wrappedPrompt = wrap(m, prompt, first ? Optional.ofNullable(systemPrompt) : Optional.empty());
m.generate(
session,
wrappedPrompt,
@@ -110,4 +121,12 @@ protected BiConsumer<String, Float> makeOutHandler()
return outCallback;
}

+    protected String wrap(AbstractModel model, String prompt, Optional<String> systemPrompt) {
+        if (promptStart == null && promptEnd == null) {
+            return model.wrapPrompt(prompt, systemPrompt);
+        }
+
+        return promptStart + prompt + promptEnd;
+    }
+
}
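
The two new --override-prompt-start / --override-prompt-end options let the chat command bypass a model's built-in instruction template: with neither flag set, the prompt still goes through the model's own wrapPrompt(); with both set, the raw prompt is simply bracketed by the supplied strings. A minimal stand-alone sketch of that decision (PromptWrapSketch and the placeholder template are illustrative, not part of the commit):

public class PromptWrapSketch {

    // Mirrors the decision in ChatCommand.wrap(): no overrides -> use the model's template,
    // otherwise bracket the raw prompt with the user-supplied strings.
    static String wrap(String promptStart, String promptEnd, String prompt) {
        if (promptStart == null && promptEnd == null) {
            // Stand-in for model.wrapPrompt(prompt, systemPrompt) in the real command.
            return "<model default template> " + prompt;
        }
        // If only one flag is set, plain string concatenation would insert the literal "null",
        // so the two override flags are meant to be passed together.
        return promptStart + prompt + promptEnd;
    }

    public static void main(String[] args) {
        System.out.println(wrap("[INST] ", " [/INST]", "What is the capital of France?"));
        // [INST] What is the capital of France? [/INST]
    }
}

For example, passing --override-prompt-start "[INST] " and --override-prompt-end " [/INST]" wraps the prompt as [INST] ... [/INST] instead of using the model's template.
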
3 changes: 0 additions & 3 deletions jlama-cli/src/main/resources/background.js

This file was deleted.

16 changes: 5 additions & 11 deletions jlama-cli/src/main/resources/index.html

@@ -4,16 +4,11 @@
<head>
<title>Chat with Jlama</title>
<link rel="shortcut icon" type="image/x-icon" href="favicon.ico">
-    <link href="resources/bootstrap.min.css" rel="stylesheet"
-          integrity="sha384-4bw+/aepP/YC94hEpVNVgiZdgIC5+VKNBQNGCHeKRQN+PtmoHDEXuppvnDJzQIu9" crossorigin="anonymous">
-    <script src="resources/bootstrap.bundle.min.js"
-            integrity="sha384-HwwvtgBNo3bZJJLYd8oVXjrBZt8cqVSpeBNS5n7C8IVInixGAoxmnlMuBnhbgrkm"
-            crossorigin="anonymous"></script>
-    <script src="resources/marked.min.js"
-            integrity="sha384-dZulhREgb+hCgQMhZ2VG0l37VQj5pJBq2w0h7Jn3tdMn36aXNepF1+FMLBB4O649"
-            crossorigin="anonymous"></script>
-    <script src="resources/purify.min.js" integrity="sha256-QigBQMy2be3IqJD2ezKJUJ5gycSmyYlRHj2VGBuITpU="
-            crossorigin="anonymous"></script>
+    <link href="https://cdnjs.cloudflare.com/ajax/libs/bootstrap/5.3.3/css/bootstrap.min.css" rel="stylesheet"
+          crossorigin="anonymous">
+    <script src="https://cdnjs.cloudflare.com/ajax/libs/bootstrap/5.3.3/js/bootstrap.bundle.min.js" integrity="sha512-7Pi/otdlbbCR+LnW+F7PwFcSDJOuUJB3OxtEHbg4vSMvzvJjde4Po1v4BR9Gdc9aXNUNFVUY+SK51wWT8WF0Gg==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
+    <script src="https://cdnjs.cloudflare.com/ajax/libs/marked/13.0.0/marked.min.js" integrity="sha512-NlNyxz9EmQt8NOeczUXqghpfmRIHlCfE5qRKftWYA44tf8sveWGZhSHxVtGtyHlmqdt89f66F26aWi+kTDz8RQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
+    <script src="https://cdnjs.cloudflare.com/ajax/libs/dompurify/3.1.5/purify.min.js" integrity="sha512-JatFEe90fJU2nrgf27fUz2hWRvdYrSlTEV8esFuqCtfiqWN8phkS1fUl/xCfYyrLDQcNf3YyS0V9hG7U4RHNmQ==" crossorigin="anonymous" referrerpolicy="no-referrer"></script>
<link rel="stylesheet" href="chat.css">
</head>

@@ -42,7 +37,6 @@ <h1>Chat with Jlama</h1>

<script src="api.js"></script>
<script src="chat.js"></script>

</body>

</html>

This file was deleted.

6 changes: 0 additions & 6 deletions jlama-cli/src/main/resources/resources/bootstrap.min.css

This file was deleted.

15 changes: 0 additions & 15 deletions jlama-cli/src/main/resources/resources/marked.min.js

This file was deleted.

3 changes: 0 additions & 3 deletions jlama-cli/src/main/resources/resources/purify.min.js

This file was deleted.

@@ -105,7 +105,7 @@ protected AbstractModel(
this.kvBufferCache = new KvBufferCache(this);

// FIXME: This is a hack to support Avoid Q8F32 evals
-        if (modelDType == DType.F32 && workingMemoryQType != DType.F32) {
+        if (modelDType == DType.F32 && workingMemoryQType != DType.F32 && modelQType.isEmpty()) {
workingMemoryQType = DType.F32;
}

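The added modelQType.isEmpty() condition narrows the F32 working-memory fallback: an F32 model now keeps a quantized working type whenever a model quantization was explicitly requested. A rough sketch of the resulting decision, using simplified stand-in types rather than the actual jlama DType/AbstractModel classes:

// Illustrative only; the enum values here are stand-ins, not the real jlama DType constants.
enum DType { F32, I8 }

class WorkingMemorySketch {
    static DType pickWorkingMemoryType(DType modelDType, DType requestedWorkingType, boolean modelQuantizationRequested) {
        // Before this commit an F32 model always forced F32 working memory;
        // now the fallback is skipped when a model quantization (modelQType) was requested.
        if (modelDType == DType.F32 && requestedWorkingType != DType.F32 && !modelQuantizationRequested) {
            return DType.F32;
        }
        return requestedWorkingType;
    }

    public static void main(String[] args) {
        System.out.println(pickWorkingMemoryType(DType.F32, DType.I8, false)); // F32 (fallback still applies)
        System.out.println(pickWorkingMemoryType(DType.F32, DType.I8, true));  // I8 (quantized run keeps I8)
    }
}
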
@@ -26,7 +26,11 @@
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.google.common.base.Preconditions;
+import com.google.common.primitives.Ints;

+import java.util.Arrays;
import java.util.Optional;
+import java.util.stream.StreamSupport;

public class BertModel extends AbstractModel {

@@ -37,7 +41,7 @@ public BertModel(
DType workingDType,
DType workingQType,
Optional<DType> modelQType) {
-        super(InferenceType.FULL_GENERATION, c, w, tokenizer, workingDType, workingQType, modelQType);
+        super(InferenceType.FORWARD_PASS, c, w, tokenizer, workingDType, workingQType, modelQType);
}

public BertModel(
@@ -65,7 +69,7 @@ protected EmbedInput loadInputWeights() {

for (int i = 0; i < c.embeddingLength; i++) {
float v = we.get(inputToken, i) + wte.get(0, i) + wpe.get(position, i);
-            embedding.set(v, i);
+            embedding.set(v, 0, i);
}

AbstractTensor lnemb = inputLayerNorm.forward(embedding);
@@ -124,29 +128,26 @@ protected SampleOutput loadOutputWeights() {
    }

    public float[] embed(String input) {
-        long[] encoded = tokenizer.encode(input);
+        int[] encoded = Arrays.stream(tokenizer.encode(input)).mapToInt(Ints::checkedCast).toArray();
        Preconditions.checkArgument(encoded.length < c.contextLength);

-        AbstractTensor kvmem =
-                makeTensor(c.getNumberOfLayers(), encoded.length, 2, c.embeddingLength); // 2 for key and value
-
-        int promptLength = encoded.length;
-        float avgp = 1.0f / promptLength;
-
        float[] outputEmbedding = new float[c.embeddingLength];

-        for (int i = 0; i < promptLength; i++) {
-            int next = (int) encoded[i];
-            AbstractTensor output = forward(next, i, kvmem);
+        try (AbstractTensor kvmem = makeTensor(c.getNumberOfLayers(), 2, encoded.length, c.embeddingLength)) { // 2 for key and value

-            // Average Pooling
-            for (int ii = 0; ii < c.embeddingLength; ii++) outputEmbedding[ii] += output.get(ii) * avgp;
+            int promptLength = encoded.length;
+            float avgp = 1.0f / promptLength;

-            output.close();
-        }
+            AbstractTensor r = batchForward(encoded, 0, kvmem);
+            for (int i = 0; i < promptLength; i++) {
+                AbstractTensor output = r.slice(i);

-        VectorMath.l2normalize(outputEmbedding);
-        kvmem.close();
+                // Average Pooling
+                for (int ii = 0; ii < c.embeddingLength; ii++)
+                    outputEmbedding[ii] += output.get(0, ii) * avgp;
+            }
+            r.close();
+            VectorMath.l2normalize(outputEmbedding);
+        }
        return outputEmbedding;
    }
}
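
embed() now converts the token ids to int[], runs one batchForward over the whole prompt inside try-with-resources, average-pools the per-token outputs, and L2-normalizes the pooled vector. The plain-array sketch below shows only the pooling and normalization step; it uses no jlama types and is for illustration:

// Plain-Java illustration (not the jlama API): mean pooling followed by L2 normalization,
// mirroring the averaging loop and the VectorMath.l2normalize call in the new embed().
public class MeanPoolSketch {

    static float[] meanPoolAndNormalize(float[][] tokenEmbeddings) {
        int dim = tokenEmbeddings[0].length;
        float avgp = 1.0f / tokenEmbeddings.length; // same per-token weight as above

        float[] pooled = new float[dim];
        for (float[] token : tokenEmbeddings) {
            for (int i = 0; i < dim; i++) pooled[i] += token[i] * avgp; // average pooling
        }

        double norm = 0;
        for (float v : pooled) norm += v * v;
        float inv = (float) (1.0 / Math.sqrt(norm)); // L2 normalization
        for (int i = 0; i < dim; i++) pooled[i] *= inv;
        return pooled;
    }

    public static void main(String[] args) {
        float[] e = meanPoolAndNormalize(new float[][] { { 1, 0 }, { 0, 1 } });
        System.out.println(e[0] + ", " + e[1]); // ~0.7071, 0.7071
    }
}
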
@@ -21,6 +21,9 @@
import com.github.tjake.jlama.tensor.TensorShape;
import com.google.common.base.Preconditions;

+/**
+ * Used to define a function that maps input tokens to embeddings
+ */
public interface EmbedInput {
AbstractTensor inputTokenToEmbedding(int inputToken, int position);

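For context, the EmbedInput that BertModel.loadInputWeights() returns sums the word, token-type and position embeddings for each token and then applies a layer norm. A conceptual analogue with plain arrays and hypothetical names, avoiding the AbstractTensor API entirely:

// Hypothetical, simplified analogue of the EmbedInput interface above (layer norm omitted).
interface SimpleEmbedInput {
    float[] inputTokenToEmbedding(int inputToken, int position);
}

class BertStyleEmbed {
    static SimpleEmbedInput of(float[][] wordEmb, float[] tokenTypeEmb, float[][] posEmb) {
        return (inputToken, position) -> {
            float[] out = new float[tokenTypeEmb.length];
            for (int i = 0; i < out.length; i++) {
                // per-dimension sum, like we.get(inputToken, i) + wte.get(0, i) + wpe.get(position, i)
                out[i] = wordEmb[inputToken][i] + tokenTypeEmb[i] + posEmb[position][i];
            }
            return out;
        };
    }

    public static void main(String[] args) {
        SimpleEmbedInput e = of(new float[][] { { 1f, 2f } }, new float[] { 0.1f, 0.1f }, new float[][] { { 0f, 1f } });
        float[] v = e.inputTokenToEmbedding(0, 0);
        System.out.println(v[0] + ", " + v[1]); // 1.1, 3.1
    }
}
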
@@ -19,6 +19,9 @@
import java.util.UUID;
import java.util.function.BiConsumer;

+/**
+ * Used to define a function that generates tokens from a prompt
+ */
public interface Generator {
default void generate(
UUID session,

@@ -18,6 +18,7 @@
import com.github.tjake.jlama.model.LayerNorm;
import com.github.tjake.jlama.tensor.AbstractTensor;


public interface SampleOutput {

LayerNorm getOutputLayerNorm();

@@ -227,7 +227,7 @@ public void TinyLlamaRun() throws Exception {

@Test
public void BertRun() throws Exception {
-        String modelPrefix = "models/e5-small-v2";
+        String modelPrefix = "../models/e5-small-v2";
Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));

try (RandomAccessFile sc = new RandomAccessFile(modelPrefix + "/model.safetensors", "r")) {
