Add prompt support via HF tokenizer_config.json
tjake committed Jun 22, 2024
1 parent debd64a commit e49f436
Showing 23 changed files with 550 additions and 149 deletions.
@@ -29,6 +29,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;

import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport;
import picocli.CommandLine.*;

@Command(name = "chat", description = "Interact with the specified model")
@@ -54,16 +55,6 @@ public class ChatCommand extends BaseCommand {
defaultValue = ".9")
protected Float topp;

@Option(
names = {"--override-prompt-start"},
description = "Override of prompt instruction format before the prompt (example: [INST] )")
protected String promptStart;

@Option(
names = {"--override-prompt-end"},
description = "Override of prompt instruction format after the prompt (example: [/INST] )")
protected String promptEnd;

@Override
public void run() {
AbstractModel m = loadModel(
@@ -74,7 +65,13 @@ public void run() {
Optional.ofNullable(modelQuantization),
Optional.ofNullable(threadCount));

if (m.promptSupport().isEmpty()) {
System.err.println("This model does not support chat prompting");
System.exit(1);
}

UUID session = UUID.randomUUID();
PromptSupport promptSupport = m.promptSupport().get();
PrintWriter out = System.console().writer();

out.println("Chatting with " + model + "...\n");
@@ -91,14 +88,19 @@ public void run() {
break;
}

String wrappedPrompt = wrap(m, prompt, first ? Optional.ofNullable(systemPrompt) : Optional.empty());
PromptSupport.Builder builder = promptSupport.newBuilder();
if (first && systemPrompt != null) {
builder.addSystemMessage(systemPrompt);
}
builder.addUserMessage(prompt);
String builtPrompt = builder.build();

m.generate(
session,
wrappedPrompt,
prompt,
builtPrompt,
temperature,
Integer.MAX_VALUE,
true,
false,
makeOutHandler());

first = false;
@@ -120,13 +122,4 @@ protected BiConsumer<String, Float> makeOutHandler()

return outCallback;
}

protected String wrap(AbstractModel model, String prompt, Optional<String> systemPrompt) {
if (promptStart == null && promptEnd == null) {
return model.wrapPrompt(prompt, systemPrompt);
}

return promptStart + prompt + promptEnd;
}

}
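
Not part of the commit, but as a reference for how the new API reads end to end: a minimal sketch of building a chat prompt through PromptSupport and handing it to generate, mirroring the ChatCommand flow above. The AbstractModel package in the import, the sampling settings, and the way the model is obtained are assumptions, not code from this diff.

import java.util.UUID;

import com.github.tjake.jlama.model.AbstractModel;                 // assumed package
import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport;

public class PromptSupportSketch {
    // Caller supplies an already-loaded model (e.g. from the CLI's loadModel helper).
    static void chatOnce(AbstractModel m, String userMessage) {
        if (m.promptSupport().isEmpty()) {
            throw new IllegalStateException("This model ships no chat template");
        }
        PromptSupport.Builder builder = m.promptSupport().get().newBuilder();
        builder.addSystemMessage("You are a helpful assistant."); // first turn only
        builder.addUserMessage(userMessage);
        String builtPrompt = builder.build();

        // Same generate() shape the command uses: session id, prompt, temperature,
        // max new tokens, useEOS flag, and a per-token callback with timings.
        m.generate(UUID.randomUUID(), builtPrompt, 0.7f, 256, false,
                (token, msPerToken) -> System.out.print(token));
    }
}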
@@ -46,8 +46,7 @@ public Response generate(@NotNull GenerateParams params) {
UUID sessionId = params.sessionId == null ? UUID.randomUUID() : params.sessionId;
StreamingOutput so = os -> model.generate(
sessionId,
model.wrapPrompt(params.prompt, Optional.empty()),
"",
params.prompt,
0.7f,
Integer.MAX_VALUE,
false,
6 changes: 6 additions & 0 deletions jlama-core/pom.xml
@@ -30,5 +30,11 @@
<artifactId>jctools-core</artifactId>
<version>4.0.1</version>
</dependency>

<dependency>
<groupId>com.hubspot.jinjava</groupId>
<artifactId>jinjava</artifactId>
<version>2.7.2</version>
</dependency>
</dependencies>
</project>
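
The new Jinjava dependency is presumably what renders the Hugging Face chat_template strings picked up from tokenizer_config.json. As a rough standalone illustration of that rendering step (not code from this commit; the template string and the role/content message shape are assumptions in the Hugging Face style):

import java.util.List;
import java.util.Map;

import com.hubspot.jinjava.Jinjava;

public class ChatTemplateRenderSketch {
    public static void main(String[] args) {
        // Made-up template; real ones come from the model's tokenizer_config.json.
        String template = "{% for message in messages %}"
                + "[{{ message.role }}]: {{ message.content }}\n"
                + "{% endfor %}";

        List<Map<String, String>> messages = List.of(
                Map.of("role", "system", "content", "You are a helpful assistant."),
                Map.of("role", "user", "content", "Why is the sky blue?"));

        // Jinjava expands the Jinja template against the supplied bindings.
        String prompt = new Jinjava().render(template, Map.of("messages", messages));
        System.out.println(prompt);
    }
}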
@@ -22,6 +22,7 @@
import com.github.tjake.jlama.safetensors.Config;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.WeightLoader;
import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.github.tjake.jlama.tensor.KvBufferCache;
@@ -150,8 +151,8 @@ public Tokenizer getTokenizer() {
return tokenizer;
}

public String wrapPrompt(String prompt, Optional<String> systemPrompt) {
return prompt;
public Optional<PromptSupport> promptSupport() {
return tokenizer.promptSupport();
}

public AbstractTensor makeTensor(int... shape) {
@@ -268,12 +269,12 @@ public int sample(AbstractTensor output, float temperature, float uniformSample,
public void generate(
UUID sessionId,
String prompt,
String cleanPrompt,
float temperature,
int ntokens,
boolean useEOS,
BiConsumer<String, Float> onTokenWithTimings) {
long[] encoded = tokenizer.encode(prompt);
System.out.println("COPY: " + tokenizer.decode(encoded));
Preconditions.checkArgument(encoded.length < c.contextLength);

AbstractTensor kvmem = kvBufferCache.getKvBuffer(sessionId); // k and v for context window
@@ -285,12 +286,9 @@ public void generate(
if (ntokens > c.contextLength) ntokens = c.contextLength;

try (AbstractTensor logits = makeTensor(c.vocabularySize)) {

int[] promptTokens = new int[useEOS ? (1 + encoded.length + 1) : (1 + encoded.length)];

promptTokens[0] = c.bosToken;
for (int i = 1; i <= encoded.length; i++) promptTokens[i] = Ints.checkedCast(encoded[i - 1]);

int promptLength = encoded.length;

if (useEOS) {
Expand Down Expand Up @@ -322,7 +320,7 @@ public void generate(
}

start = System.currentTimeMillis();
for (int i = startPos + promptTokens.length - 1; i < ntokens; i++) {
for (int i = startPos + promptTokens.length; i < ntokens; i++) {
AbstractTensor output = forward(next, i, kvmem);
tokensGenerated++;

@@ -43,9 +43,7 @@ default AbstractTensor batchInputsToEmbeddings(int[] inputTokens, int startPos)

VectorMath.pfor(1, inputTokens.length, i -> {
AbstractTensor ti = inputTokenToEmbedding(inputTokens[i], startPos + i);

tb.copyFrom(ti, 0, i * ti.shape().sparseLength(), ti.shape().sparseLength());

ti.close();
});

@@ -23,24 +23,11 @@
* Used to define a function that generates tokens from a prompt
*/
public interface Generator {
default void generate(
UUID session,
String prompt,
float temperature,
int ntokens,
boolean useEOS,
BiConsumer<String, Float> onTokenWithTimings) {
generate(session, prompt, null, temperature, ntokens, useEOS, onTokenWithTimings);
}

void generate(
UUID session,
String prompt,
String cleanPrompt,
float temperature,
int ntokens,
boolean useEOS,
BiConsumer<String, Float> onTokenWithTimings);

String wrapPrompt(String prompt, Optional<String> systemPrompt);
}
@@ -37,14 +37,16 @@
public class GemmaModel extends LlamaModel {
private static final Logger logger = LoggerFactory.getLogger(GemmaModel.class);

private final float embeddingScalingFactor;

public GemmaModel(
Config config,
WeightLoader weights,
Tokenizer tokenizer,
DType workingDType,
DType workingQType,
Optional<DType> modelQType) {
super(config, weights, tokenizer, workingDType, workingQType, modelQType);
this(InferenceType.FULL_GENERATION, config, weights, tokenizer, workingDType, workingQType, modelQType);
}

public GemmaModel(
@@ -56,6 +58,7 @@ public GemmaModel(
DType workingQType,
Optional<DType> modelQType) {
super(inferenceType, config, weights, tokenizer, workingDType, workingQType, modelQType);
this.embeddingScalingFactor = (float) Math.pow(c.embeddingLength, 0.5);
}

private AbstractTensor wte;
@@ -125,7 +128,7 @@ protected EmbedInput loadInputWeights() {
// This is important for Gemma, but not for Llama
TensorOperationsProvider.get()
.scale(
(float) Math.pow(c.embeddingLength, 0.5),
embeddingScalingFactor,
embedding,
c.embeddingSegmentStart(),
c.embeddingSegmentLength());
@@ -139,18 +139,6 @@ public AbstractTensor getOutputLogitsWeights() {
};
}

@Override
public String wrapPrompt(String prompt, Optional<String> systemPrompt) {
StringBuilder b = new StringBuilder();
b.append("[INST] ");
if (systemPrompt.isPresent()) {
b.append("<<SYS>> \n").append(systemPrompt.get()).append("\n<</SYS>> \n\n");
}
b.append(prompt).append(" [/INST]");

return b.toString();
}

@Override
protected AbstractTensor maybeQuantize(AbstractTensor t) {
Preconditions.checkArgument(t.dims() == 2, "Unexpected shape");
@@ -25,25 +25,10 @@
public class LlamaTokenizer extends BPETokenizer {
static final String SPIECE_UNDERLINE = "▁";

private static BiMap<Integer, Integer> alteredBytes; // Codepoint and Token mapping needed for legacy mode

static {
// https://github.com/openai/gpt-2/blob/master/src/encoder.py#L19
alteredBytes = HashBiMap.create();
int i = 0;
for (int c = 0; c < 256; c++) {
if ((c < '!' || c > '~') && (c < '¡' || c > '¬') && (c < '®' || c > 'ÿ')) {
int codepoint = (i++ + 256);
alteredBytes.put(c, codepoint);
}
}
}

private final int byteFallbackEncodingOffset;

public LlamaTokenizer(Path modelRoot) {
super(modelRoot);

this.byteFallbackEncodingOffset = 3;
}

@@ -65,17 +50,17 @@ protected Optional<Character> maybeDecodeTokenAsCharacter(long id) {

@Override
protected String preProcess(String sentence) {
if (!model.isLegacy()) {
if (!sentence.isEmpty() && !sentence.startsWith(" ")) {
sentence = " " + sentence;
}
return sentence.replace(" ", SPIECE_UNDERLINE);
} else {
return sentence.codePoints()
if (model.normalizer() != null)
sentence = model.normalizer().normalize(sentence);

if (model.isLegacy() && !model.byteFallback) {
sentence = sentence.codePoints()
.map(c -> alteredBytes.getOrDefault(c, c))
.mapToObj(Character::toString)
.collect(Collectors.joining());
}

return sentence;
}

@Override
Expand All @@ -90,8 +75,8 @@ protected String postProcessToken(String decoded) {
decoded = decoded.replaceAll("</?s>", "");
decoded = decoded.replaceAll(SPIECE_UNDERLINE, " ");

if (model.isLegacy()) {
return decoded.codePoints()
if (model.isLegacy() && !model.byteFallback) {
decoded = decoded.codePoints()
.map(c -> alteredBytes.inverse().getOrDefault(c, c))
.mapToObj(Character::toString)
.collect(Collectors.joining());
@@ -44,17 +44,4 @@ public MistralModel(
Optional<DType> modelQType) {
super(inferenceType, config, weights, tokenizer, workingDType, workingQType, modelQType);
}

@Override
public String wrapPrompt(String prompt, Optional<String> systemPrompt) {
StringBuilder b = new StringBuilder();

if (systemPrompt.isPresent()) {
b.append(systemPrompt.get()).append("\n\n");
}

b.append("[INST] ").append(prompt).append(" [/INST]");

return b.toString();
}
}
@@ -117,14 +117,51 @@ public static TokenizerModel loadTokenizer(Path modelRoot) throws IOException {

TokenizerModel model = om.treeToValue(rootNode.get("model"), TokenizerModel.class);

if (rootNode.has("added_tokens") && rootNode.get("added_tokens") != null) {
List<Map<String, Object>> addedTokens = om.convertValue(rootNode.get("added_tokens"), List.class);
model.setAddedTokens(addedTokens);
}

if (rootNode.has("pre_tokenizer") && rootNode.get("pre_tokenizer") != null)
model.setPreTokenizer(om.treeToValue(rootNode.get("pre_tokenizer"), TokenizerModel.PreTokenizer.class));

if (rootNode.has("normalizer") && rootNode.get("normalizer") != null)
model.setNormalizer(om.treeToValue(rootNode.get("normalizer"), TokenizerModel.Normalizer.class));

File tokenizerConfigJson = modelRoot.resolve("tokenizer_config.json").toFile();
if (tokenizerConfigJson.exists()) {
JsonNode configNode = om.readTree(tokenizerConfigJson);
if (configNode.has("legacy"))
model.setLegacy(configNode.get("legacy").asBoolean());

if (configNode.has("chat_template")) {
JsonNode chatTemplateNode = configNode.get("chat_template");
Map<String, String> promptTemplates = new HashMap<>();
if (chatTemplateNode.isTextual()) {
promptTemplates.put("default", chatTemplateNode.asText());
} else if (chatTemplateNode.isArray()){
List<Map<String, String>> chatTemplates = om.convertValue(chatTemplateNode, List.class);
for (Map<String, String> chatTemplate : chatTemplates) {
if (chatTemplate.containsKey("name") && chatTemplate.containsKey("template")) {
promptTemplates.put(chatTemplate.get("name"), chatTemplate.get("template"));
} else {
throw new IllegalArgumentException("Invalid chat_template format");
}
}
} else {
throw new IllegalArgumentException("Invalid chat_template format");
}

model.setPromptTemplates(promptTemplates);
}

if (configNode.has("eos_token")) {
model.setEosToken(configNode.get("eos_token").asText());
}

if (configNode.has("bos_token")) {
model.setBosToken(configNode.get("bos_token").asText());
}
}

return model;
@@ -218,6 +255,9 @@ public static Path quantizeModel(
Files.copy(modelRoot.resolve("config.json"), qPath.resolve("config.json"));
Files.copy(modelRoot.resolve("tokenizer.json"), qPath.resolve("tokenizer.json"));

if (Files.exists(modelRoot.resolve("tokenizer_config.json")))
Files.copy(modelRoot.resolve("tokenizer_config.json"), qPath.resolve("tokenizer_config.json"));

try (RandomAccessFile raf =
new RandomAccessFile(qPath.resolve("model.safetensors").toFile(), "rw")) {
FileChannel chan = raf.getChannel();
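
For reference (not part of the diff): chat_template in tokenizer_config.json may be either a single template string or an array of named templates, which is why loadTokenizer above branches on the node type. A self-contained sketch of that branching against a made-up config fragment:

import java.util.HashMap;
import java.util.Map;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

public class ChatTemplateConfigSketch {
    public static void main(String[] args) throws Exception {
        // Made-up tokenizer_config.json fragment; real models ship their own.
        String json = "{\"chat_template\": \"{% for m in messages %}{{ m.content }}{% endfor %}\"}";

        JsonNode configNode = new ObjectMapper().readTree(json);
        JsonNode chatTemplateNode = configNode.get("chat_template");

        Map<String, String> promptTemplates = new HashMap<>();
        if (chatTemplateNode.isTextual()) {
            // Single template string: keyed as "default", matching loadTokenizer above.
            promptTemplates.put("default", chatTemplateNode.asText());
        } else if (chatTemplateNode.isArray()) {
            // Array form: entries like {"name": "...", "template": "..."}.
            for (JsonNode t : chatTemplateNode) {
                promptTemplates.put(t.get("name").asText(), t.get("template").asText());
            }
        }
        System.out.println(promptTemplates.keySet()); // prints [default]
    }
}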