Spotless and javadoc clean

tjake committed Jun 23, 2024
1 parent 2876b93 commit 13cc8c3
Showing 19 changed files with 123 additions and 154 deletions.
@@ -20,16 +20,12 @@
import com.diogonunes.jcolor.AnsiFormat;
import com.diogonunes.jcolor.Attribute;
import com.github.tjake.jlama.model.AbstractModel;

import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport;
import java.io.PrintWriter;
import java.nio.charset.Charset;
import java.util.Optional;
import java.util.Scanner;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;

import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport;
import picocli.CommandLine.*;

@Command(name = "chat", description = "Interact with the specified model")
@@ -95,20 +91,13 @@ public void run() {
builder.addUserMessage(prompt);
String builtPrompt = builder.build();

m.generate(
session,
builtPrompt,
temperature,
Integer.MAX_VALUE,
false,
makeOutHandler());
m.generate(session, builtPrompt, temperature, Integer.MAX_VALUE, false, makeOutHandler());

first = false;
}
}

protected BiConsumer<String, Float> makeOutHandler()
{
protected BiConsumer<String, Float> makeOutHandler() {
PrintWriter out;
BiConsumer<String, Float> outCallback;

@@ -16,18 +16,15 @@
package com.github.tjake.jlama.cli.serve;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport;
import java.io.IOException;
import java.util.Optional;
import java.util.UUID;
import javax.validation.constraints.NotNull;
import javax.ws.rs.*;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.StreamingOutput;

import com.github.tjake.jlama.safetensors.tokenizer.PromptSupport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -49,25 +46,23 @@ public Response generate(@NotNull GenerateParams params) {
UUID sessionId = params.sessionId == null ? UUID.randomUUID() : params.sessionId;

if (params.prompt == null) {
return Response.status(Response.Status.BAD_REQUEST).entity("prompt is required").build();
return Response.status(Response.Status.BAD_REQUEST)
.entity("prompt is required")
.build();
}

String prompt = params.prompt;
if (model.getTokenizer().promptSupport().isPresent()) {
PromptSupport.Builder builder = model.getTokenizer().promptSupport().get().newBuilder();
PromptSupport.Builder builder =
model.getTokenizer().promptSupport().get().newBuilder();
builder.addUserMessage(prompt);
prompt = builder.build();
}

final String finalPrompt = prompt;

StreamingOutput so = os -> model.generate(
sessionId,
finalPrompt,
0.7f,
Integer.MAX_VALUE,
false,
(s, timing) -> {
StreamingOutput so =
os -> model.generate(sessionId, finalPrompt, 0.7f, Integer.MAX_VALUE, false, (s, timing) -> {
try {
logger.info("'{}' took {}ms", s, timing);
os.write(om.writeValueAsBytes(new GenerateResponse(s, false)));
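
The hunk above only reflows a JAX-RS streaming response. For reference, a minimal self-contained sketch of the same StreamingOutput pattern, with jlama's model call and GenerateResponse type replaced by a stand-in chunk source:

import java.nio.charset.StandardCharsets;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.StreamingOutput;

public class StreamingSketch {
    // The StreamingOutput lambda writes to the response OutputStream as data
    // is produced, so chunks reach the client before generation finishes.
    static Response stream(Iterable<String> chunks) {
        StreamingOutput so = os -> {
            for (String chunk : chunks) {
                os.write(chunk.getBytes(StandardCharsets.UTF_8));
                os.flush(); // push each chunk immediately instead of buffering
            }
        };
        return Response.ok(so).build();
    }
}
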
@@ -54,7 +54,8 @@ public static void pchunk(int offset, int length, BiIntConsumer action) {
PhysicalCoreExecutor.instance.get().execute(() -> IntStream.range(0, fsplits)
.parallel()
.forEach(i -> action.accept(
offset + (i * fchunkSize), fremainder > 0 && i == fsplits - 1 ? fchunkSize + fremainder : fchunkSize)));
offset + (i * fchunkSize),
fremainder > 0 && i == fsplits - 1 ? fchunkSize + fremainder : fchunkSize)));
}

public static void softMax(AbstractTensor x, int offset, int length) {
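
The pchunk change above only rewraps arguments; the splitting arithmetic is untouched. A rough standalone sketch of that arithmetic (fsplits, fchunkSize and fremainder are computed outside this hunk, so deriving them from a splits parameter below is an assumption):

import java.util.stream.IntStream;

public class ChunkSketch {
    interface BiIntConsumer { void accept(int offset, int length); } // assumed shape of jlama's BiIntConsumer

    // Divide [offset, offset + length) into `splits` contiguous chunks,
    // folding any remainder into the last chunk, and run them in parallel.
    static void pchunk(int offset, int length, int splits, BiIntConsumer action) {
        int chunkSize = length / splits;
        int remainder = length % splits;
        IntStream.range(0, splits)
                .parallel()
                .forEach(i -> action.accept(
                        offset + i * chunkSize,
                        remainder > 0 && i == splits - 1 ? chunkSize + remainder : chunkSize));
    }

    public static void main(String[] args) {
        // 10 elements split 3 ways -> (0,3) (3,3) (6,4); output order is nondeterministic.
        pchunk(0, 10, 3, (off, len) -> System.out.println("offset=" + off + " length=" + len));
    }
}
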
@@ -73,17 +73,8 @@ public enum ModelType {
}
}

public static AbstractModel loadModel(
File model,
DType workingMemoryType,
DType workingQuantizationType) {
return loadModel(
model,
null,
workingMemoryType,
workingQuantizationType,
Optional.empty(),
Optional.empty());
public static AbstractModel loadModel(File model, DType workingMemoryType, DType workingQuantizationType) {
return loadModel(model, null, workingMemoryType, workingQuantizationType, Optional.empty(), Optional.empty());
}

public static AbstractModel loadModel(
@@ -27,10 +27,8 @@
import com.github.tjake.jlama.tensor.AbstractTensor;
import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;

import java.util.Arrays;
import java.util.Optional;
import java.util.stream.StreamSupport;

public class BertModel extends AbstractModel {

@@ -128,11 +126,14 @@ protected SampleOutput loadOutputWeights() {
}

public float[] embed(String input) {
int[] encoded = Arrays.stream(tokenizer.encode(input)).mapToInt(Ints::checkedCast).toArray();
int[] encoded = Arrays.stream(tokenizer.encode(input))
.mapToInt(Ints::checkedCast)
.toArray();
Preconditions.checkArgument(encoded.length < c.contextLength);
float[] outputEmbedding = new float[c.embeddingLength];

try (AbstractTensor kvmem = makeTensor(c.getNumberOfLayers(), 2, encoded.length, c.embeddingLength)) { // 2 for key and value
try (AbstractTensor kvmem =
makeTensor(c.getNumberOfLayers(), 2, encoded.length, c.embeddingLength)) { // 2 for key and value

int promptLength = encoded.length;
float avgp = 1.0f / promptLength;
@@ -142,8 +143,7 @@ public float[] embed(String input) {
AbstractTensor output = r.slice(i);

// Average Pooling
for (int ii = 0; ii < c.embeddingLength; ii++)
outputEmbedding[ii] += output.get(0, ii) * avgp;
for (int ii = 0; ii < c.embeddingLength; ii++) outputEmbedding[ii] += output.get(0, ii) * avgp;
}
r.close();
VectorMath.l2normalize(outputEmbedding);
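
For context, the average pooling plus VectorMath.l2normalize performed in embed() reduces to the following sketch over a plain float[][], with jlama's tensor types swapped out:

public class PoolingSketch {
    // Mean-pool per-token embeddings into one vector, then scale it to unit length.
    static float[] meanPoolL2(float[][] tokenEmbeddings) {
        int dim = tokenEmbeddings[0].length;
        float avgp = 1.0f / tokenEmbeddings.length; // same 1/promptLength weight as above
        float[] out = new float[dim];
        for (float[] token : tokenEmbeddings)
            for (int i = 0; i < dim; i++) out[i] += token[i] * avgp;

        double norm = 0; // l2normalize: divide by the Euclidean norm
        for (float v : out) norm += v * v;
        float inv = (float) (1.0 / Math.sqrt(norm));
        for (int i = 0; i < dim; i++) out[i] *= inv;
        return out;
    }

    public static void main(String[] args) {
        float[] e = meanPoolL2(new float[][] {{1f, 0f}, {0f, 1f}});
        System.out.println(e[0] + ", " + e[1]); // ~0.7071, 0.7071
    }
}
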
@@ -16,8 +16,6 @@
package com.github.tjake.jlama.model.functions;

import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;

import java.util.Optional;
import java.util.UUID;
import java.util.function.BiConsumer;

@@ -18,7 +18,6 @@
import com.github.tjake.jlama.model.LayerNorm;
import com.github.tjake.jlama.tensor.AbstractTensor;


public interface SampleOutput {

LayerNorm getOutputLayerNorm();
@@ -127,11 +127,7 @@ protected EmbedInput loadInputWeights() {

// This is important for Gemma, but not for Llama
TensorOperationsProvider.get()
.scale(
embeddingScalingFactor,
embedding,
c.embeddingSegmentStart(),
c.embeddingSegmentLength());
.scale(embeddingScalingFactor, embedding, c.embeddingSegmentStart(), c.embeddingSegmentLength());

return embedding;
};
@@ -16,8 +16,6 @@
package com.github.tjake.jlama.model.llama;

import com.github.tjake.jlama.safetensors.tokenizer.BPETokenizer;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import java.nio.file.Path;
import java.util.Optional;
import java.util.stream.Collectors;
@@ -50,11 +48,10 @@ protected Optional<Character> maybeDecodeTokenAsCharacter(long id) {

@Override
protected String preProcess(String sentence) {
if (model.normalizer() != null)
sentence = model.normalizer().normalize(sentence);
if (model.normalizer() != null) sentence = model.normalizer().normalize(sentence);

if (model.isLegacy() && !model.byteFallback) {
sentence = sentence.codePoints()
sentence = sentence.codePoints()
.map(c -> alteredBytes.getOrDefault(c, c))
.mapToObj(Character::toString)
.collect(Collectors.joining());
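
The preProcess() change above is formatting only; the legacy path still remaps code points through the alteredBytes table. A minimal sketch of that remapping, using a made-up one-entry table in place of alteredBytes (a SentencePiece-style space-to-'▁' substitution is one plausible use):

import java.util.Map;
import java.util.stream.Collectors;

public class RemapSketch {
    public static void main(String[] args) {
        // Hypothetical stand-in for alteredBytes: map the space code point to '▁' (U+2581).
        Map<Integer, Integer> alteredBytes = Map.of((int) ' ', (int) '\u2581');
        String remapped = "hello world".codePoints()
                .map(c -> alteredBytes.getOrDefault(c, c)) // substitute remapped code points
                .mapToObj(Character::toString)
                .collect(Collectors.joining());
        System.out.println(remapped); // hello▁world
    }
}
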
@@ -139,7 +139,7 @@ public static TokenizerModel loadTokenizer(Path modelRoot) throws IOException {
Map<String, String> promptTemplates = new HashMap<>();
if (chatTemplateNode.isTextual()) {
promptTemplates.put("default", chatTemplateNode.asText());
} else if (chatTemplateNode.isArray()){
} else if (chatTemplateNode.isArray()) {
List<Map<String, String>> chatTemplates = om.convertValue(chatTemplateNode, List.class);
for (Map<String, String> chatTemplate : chatTemplates) {
if (chatTemplate.containsKey("name") && chatTemplate.containsKey("template")) {
@@ -309,7 +309,8 @@ public static File maybeDownloadModel(String modelDir, String fullModelName) thr
name = parts[1];
}

return maybeDownloadModel(modelDir, Optional.ofNullable(owner), name, Optional.empty(), Optional.empty(), Optional.empty());
return maybeDownloadModel(
modelDir, Optional.ofNullable(owner), name, Optional.empty(), Optional.empty(), Optional.empty());
}

/**
@@ -17,16 +17,15 @@

import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableBiMap;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.util.*;
import java.util.stream.Collectors;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableBiMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -72,8 +71,7 @@ public List<String> tokenize(String sentence) {

if (sentence.isEmpty()) return Collections.emptyList();

if (model.preTokenizer() == null && model.addedTokenPattern() == null)
Collections.singletonList(sentence);
if (model.preTokenizer() == null && model.addedTokenPattern() == null) Collections.singletonList(sentence);

List<String> sentencePieces = new ArrayList<>();
if (model.addedTokenPattern() != null) {
@@ -82,15 +80,13 @@
String[] pieces = model.addedTokenPattern().splitWithDelimiters(sentence, 0);
for (String piece : pieces) {
if (!piece.isEmpty()) {
if (model.addedTokens().containsKey(piece))
sentencePieces.add(piece);
if (model.addedTokens().containsKey(piece)) sentencePieces.add(piece);
else if (model.preTokenizer() != null)
sentencePieces.addAll(model.preTokenizer().pretokenize(piece));
else
sentencePieces.add(piece);
else sentencePieces.add(piece);
}
}
} else if (model.preTokenizer() != null){
} else if (model.preTokenizer() != null) {
sentencePieces.addAll(model.preTokenizer().pretokenize(sentence));
} else {
sentencePieces.add(sentence);
@@ -172,8 +168,7 @@ public long[] encode(String rawSentence) {
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
tokens.set((int) bestIdx, bestId);
// delete token at position best_idx+1, shift the entire sequence back 1
for (int j = n; j > 0; j--)
tokens.remove((int) bestIdx + j);
for (int j = n; j > 0; j--) tokens.remove((int) bestIdx + j);
}
}

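
Behaviorally, the encode() hunk above is the classic BPE merge step, just collapsed onto one line. A small sketch of that step in isolation (token values and the merged id are made up):

import java.util.ArrayList;
import java.util.List;

public class MergeSketch {
    // Replace the pair starting at bestIdx with the merged id, then remove the
    // following n entries so the sequence shifts back, as in the loop above.
    static void merge(List<Long> tokens, int bestIdx, long bestId, int n) {
        tokens.set(bestIdx, bestId);
        for (int j = n; j > 0; j--) tokens.remove(bestIdx + j);
    }

    public static void main(String[] args) {
        List<Long> tokens = new ArrayList<>(List.of(104L, 101L, 108L, 108L, 111L));
        merge(tokens, 2, 9000L, 1); // merge tokens[2] and tokens[3] into hypothetical id 9000
        System.out.println(tokens); // [104, 101, 9000, 111]
    }
}
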
@@ -1,9 +1,22 @@
/*
* Copyright 2024 T Jake Luciani
*
* The Jlama Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package com.github.tjake.jlama.safetensors.tokenizer;


import com.hubspot.jinjava.Jinjava;
import com.hubspot.jinjava.JinjavaConfig;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -13,8 +26,11 @@
* @see <a href="https://huggingface.co/docs/transformers/main/en/chat_templating#templates-for-chat-models">Chat Templating</a>
*/
public class PromptSupport {
//This matches the jinja config in huggingface
private static final Jinjava jinjava = new Jinjava(JinjavaConfig.newBuilder().withLstripBlocks(true).withTrimBlocks(true).build());
// This matches the jinja config in huggingface
private static final Jinjava jinjava = new Jinjava(JinjavaConfig.newBuilder()
.withLstripBlocks(true)
.withTrimBlocks(true)
.build());

private final TokenizerModel m;

@@ -105,12 +121,22 @@ public String build() {
throw new UnsupportedOperationException("Prompt templates are not available for this model");
}

String template = m.promptTemplates().map(t -> t.get(type.name().toLowerCase())).orElseThrow(() -> new UnsupportedOperationException("Prompt template not available for type: " + type));

return jinjava.render(template, Map.of("messages", messages,
"add_generation_prompt", addGenerationPrompt,
"eos_token", m.eosToken(),
"bos_token", m.bosToken()));
String template = m.promptTemplates()
.map(t -> t.get(type.name().toLowerCase()))
.orElseThrow(
() -> new UnsupportedOperationException("Prompt template not available for type: " + type));

return jinjava.render(
template,
Map.of(
"messages",
messages,
"add_generation_prompt",
addGenerationPrompt,
"eos_token",
m.eosToken(),
"bos_token",
m.bosToken()));
}
}
}
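
The build() rewrite above is where template rendering happens. A self-contained sketch of the same Jinjava call with a toy inline template (real templates ship with the model's tokenizer config; the binding keys match the Map passed above):

import com.hubspot.jinjava.Jinjava;
import com.hubspot.jinjava.JinjavaConfig;
import java.util.List;
import java.util.Map;

public class TemplateSketch {
    public static void main(String[] args) {
        // Same Jinjava configuration as PromptSupport: lstrip/trim blocks mirror
        // the jinja2 settings HuggingFace uses for chat templates.
        Jinjava jinjava = new Jinjava(JinjavaConfig.newBuilder()
                .withLstripBlocks(true)
                .withTrimBlocks(true)
                .build());

        // Toy template, not a real model's chat template.
        String template = "{% for m in messages %}<|{{ m.role }}|>{{ m.content }}\n{% endfor %}"
                + "{% if add_generation_prompt %}<|assistant|>{% endif %}";

        String rendered = jinjava.render(template, Map.of(
                "messages", List.of(Map.of("role", "user", "content", "Hello!")),
                "add_generation_prompt", true,
                "eos_token", "</s>",
                "bos_token", "<s>"));
        System.out.println(rendered); // <|user|>Hello! then <|assistant|>
    }
}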