Add run-cli.sh script; modify readme
tjake committed Oct 2, 2023
1 parent 7dd7aa9 commit 56365b9
Showing 15 changed files with 112 additions and 36 deletions.
22 changes: 10 additions & 12 deletions README.md
@@ -29,23 +29,21 @@ Helpful for anyone who wants to understand how LLMs work, or wants to use LLMs i
CPU-based inference needs to be pushed to the limit to see if it can be a viable alternative to GPU-based inference.

## How to use
As of now the best way to use this is to look at the [TestModels](https://github.com/tjake/Jlama/blob/main/src/test/java/com/github/tjake/jlama/models/TestModels.java) Unit Tests.
Jlama includes a CLI tool to run models via the `run-cli.sh` command.
Before you do that, first download one or more models from Hugging Face.

Use the `download_hf_models.sh` script in the data directory to download models from huggingface.
Use the `download-hf-model.sh` script to download models from Hugging Face.

```shell
cd data
./download_hf_model.sh gpt2-medium
./download_hf_model.sh -a XXXXXXXX meta-llama/Llama-2-7b-chat-hf
./download_hf_model.sh intfloat/e5-small-v2
./download-hf-model.sh gpt2-medium
./download-hf-model.sh -a XXXXXXXX meta-llama/Llama-2-7b-chat-hf
./download-hf-model.sh intfloat/e5-small-v2
```
Then run the tests with:

Then run the CLI:
```shell
cd ..
./mvnw package -DskipTests
./mvnw test -Dtest=TestModels#GPT2Run
./mvnw test -Dtest=TestModels#LlamaRun
./mvnw test -Dtest=TestModels#BertRun
./run-cli.sh complete -p "The best part of waking up is " -t 0.7 models/Llama-2-7b-chat-hf
./run-cli.sh chat -p "Tell me a joke about cats." models/Llama-2-7b-chat-hf
```
## Caveats

11 changes: 11 additions & 0 deletions conf/logback.xml
@@ -0,0 +1,11 @@
<configuration>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
</encoder>
</appender>

<root level="info">
<appender-ref ref="STDOUT" />
</root>
</configuration>
4 changes: 3 additions & 1 deletion models/download_hf_model.sh → download-hf-model.sh
@@ -50,6 +50,8 @@ if [ "$#" -ne 1 ]; then
exit 1
fi

cd models

MODEL="$1"
HTTP_HEADER=
if [ ! -z $HF_ACCESS_TOKEN ]; then
@@ -108,4 +110,4 @@ done
echo "Downloading ${MODEL_DIR}/tokenizer.model (if exists)"
eval curl -fL --progress-bar $HTTP_HEADER "https://huggingface.co/${MODEL}/resolve/main/tokenizer.model" -o "${MODEL_DIR}/tokenizer.model"

echo "Done!"
echo "Done! Model in ./models/$MODEL_DIR"
6 changes: 6 additions & 0 deletions jlama-cli/pom.xml
@@ -23,6 +23,12 @@
<artifactId>jlama-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.github.tjake</groupId>
<artifactId>jlama-native</artifactId>
<version>${project.version}</version>
<classifier>${jni.classifier}</classifier>
</dependency>
</dependencies>
<build>
<plugins>
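The added `jlama-native` dependency is resolved per platform through the `${jni.classifier}` property, which is not defined anywhere in this diff. A hypothetical illustration of what it might resolve to (the real value would be computed by the build from the host OS and architecture, not hard-coded):

```xml
<!-- Hypothetical value for illustration only; the actual classifier
     is derived elsewhere in the build. -->
<properties>
  <jni.classifier>linux-x86_64</jni.classifier>
</properties>
```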
JlamaCli.java
@@ -13,14 +13,18 @@
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.ModelSupport.ModelType;
import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
import picocli.CommandLine;
import picocli.CommandLine.*;

@Command(name="jlama", helpCommand = true)
@Command(name="jlama")
public class JlamaCli implements Runnable {

@Parameters(index = "0", description = "The model location")
protected File model;
static {
System.setProperty("java.util.concurrent.ForkJoinPool.common.parallelism", "" + Math.max(4, Runtime.getRuntime().availableProcessors() / 2));
System.setProperty("jdk.incubator.vector.VECTOR_ACCESS_OOB_CHECK", "0");
TensorOperationsProvider.get();
}

@Option(names = { "-h", "--help" }, usageHelp = true, hidden = true)
boolean helpRequested = false;
@@ -36,8 +40,7 @@ public static void main(String[] args) {
}

@Override
public void run()
{
public void run() {

}

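For orientation, here is a minimal sketch of how the picocli subcommands touched by this commit might be registered in `JlamaCli.main` — the actual registration code falls outside the lines shown in this diff, so treat this wiring as an assumption:

```java
// Hypothetical wiring of the subcommands from this commit; the real
// main() body is not visible in this hunk.
public static void main(String[] args) {
    CommandLine cli = new CommandLine(new JlamaCli());
    cli.addSubcommand("chat", new ChatCommand());         // interactive chat
    cli.addSubcommand("complete", new CompleteCommand()); // one-shot completion
    cli.addSubcommand("serve", new ServeCommand());       // REST API
    System.exit(cli.execute(args));
}
```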
BaseCommand.java
@@ -0,0 +1,11 @@
package com.github.tjake.jlama.cli.commands;

import java.io.File;

import com.github.tjake.jlama.cli.JlamaCli;
import picocli.CommandLine;

public class BaseCommand extends JlamaCli {
@CommandLine.Parameters(index = "0", arity = "1", description = "The model location")
protected File model;
}
ChatCommand.java
@@ -14,6 +14,6 @@ public class ChatCommand extends ModelBaseCommand {
public void run() {
AbstractModel m = loadModel(model);

m.generate(m.wrapPrompt(prompt, Optional.of(systemPrompt)), temperature, tokens, true, makeOutHandler());
m.generate(m.wrapPrompt(prompt, Optional.of(systemPrompt)), prompt, temperature, tokens, true, makeOutHandler());
}
}
CompleteCommand.java
@@ -3,7 +3,7 @@
import com.github.tjake.jlama.model.AbstractModel;
import picocli.CommandLine.*;

@Command(name = "complete", description = "Completes a prompt using the specified model")
@Command(name = "complete", description = "Completes a prompt using the specified model", mixinStandardHelpOptions = true)
public class CompleteCommand extends ModelBaseCommand {

@Override
ModelBaseCommand.java
@@ -11,7 +11,6 @@

import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.tjake.jlama.cli.JlamaCli;
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.safetensors.Config;
@@ -21,7 +20,7 @@
import com.github.tjake.jlama.safetensors.WeightLoader;
import picocli.CommandLine.*;

public class ModelBaseCommand extends JlamaCli {
public class ModelBaseCommand extends BaseCommand {

private static final ObjectMapper om = new ObjectMapper()
.configure(DeserializationFeature.FAIL_ON_IGNORED_PROPERTIES, false)
ServeCommand.java
@@ -1,8 +1,7 @@
package com.github.tjake.jlama.cli.commands;

import com.github.tjake.jlama.cli.JlamaCli;
import picocli.CommandLine;

@CommandLine.Command(name = "serve", description = "Starts a rest api for interacting with this model")
public class ServeCommand extends JlamaCli {
public class ServeCommand extends BaseCommand {
}
AbstractModel.java
@@ -34,7 +34,6 @@ public abstract class AbstractModel {
private static final ThreadLocal<AbstractTensor[]> tmpArray = new ThreadLocal<>();
private static final ThreadLocal<AbstractTensor[]> tmpArray2 = new ThreadLocal<>();


protected AbstractModel(Config c, WeightLoader w, Tokenizer t)
{
this.c = c;
@@ -143,6 +142,10 @@ protected int sample(AbstractTensor output, float temperature, float uniformSamp
}

public void generate(String prompt, float temperature, int ntokens, boolean useEOS, BiConsumer<String, Float> onTokenWithTimings) {
generate(prompt, null, temperature, ntokens, useEOS, onTokenWithTimings);
}

public void generate(String prompt, String cleanPrompt, float temperature, int ntokens, boolean useEOS, BiConsumer<String, Float> onTokenWithTimings) {
long[] encoded = tokenizer.encode(prompt);
Preconditions.checkArgument(encoded.length < c.contextLength);

@@ -171,7 +174,7 @@ public void generate(String prompt, float temperature, int ntokens, boolean useE
AbstractTensor batch[] = batchForward(promptTokens, 0, kvmem);

long promptBatchTime = System.currentTimeMillis() - start;
logger.info("{} prompt tokens in {}ms {} tokens/sec", promptLength, promptBatchTime, Math.round((((double)promptBatchTime)/(double)promptLength)));
logger.debug("{} prompt tokens in {}ms {} tokens/sec", promptLength, promptBatchTime, Math.round((((double)promptBatchTime)/(double)promptLength)));

int tokensGenerated = 0;
AbstractTensor last = batch[batch.length - 1];
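The new overload separates the wrapped prompt used for inference from the raw user prompt (compare the `ChatCommand` change above, which now passes both). A minimal usage sketch, assuming an already-loaded `AbstractModel m` and the imports from this file:

```java
// Sketch only: m is assumed to be a loaded AbstractModel.
String userPrompt = "Tell me a joke about cats.";
String wrapped = m.wrapPrompt(userPrompt, Optional.of("You are a helpful assistant."));

// The wrapped prompt drives generation; the raw prompt travels alongside it
// (presumably so downstream handling can distinguish the two).
m.generate(wrapped, userPrompt, 0.7f, 256, true,
        (token, timing) -> System.out.print(token));
```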
LlamaModel.java
@@ -33,7 +33,7 @@ public LlamaModel(Config config, WeightLoader weights, Tokenizer tokenizer) {

DType qType = DType.Q4;

logger.info("Loading model with {} quantization...", qType);
logger.info("Quantizing model with {} - Please hold...", qType);

//LLama doesn't use bias, will optimize this away later
this.noBias = makeTensor(c.hiddenLength);
TensorOperationsProvider.java
@@ -1,8 +1,8 @@
package com.github.tjake.jlama.tensor.operations;


import com.github.tjake.jlama.tensor.FloatBufferTensor;
import com.github.tjake.jlama.util.MachineSpec;
import com.github.tjake.jlama.util.RuntimeSupport;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -36,12 +36,14 @@ private TensorOperations pickFastestImplementaion() {

TensorOperations pick = null;

try {
Class<? extends TensorOperations> nativeClazz = (Class<? extends TensorOperations>) Class.forName("com.github.tjake.jlama.tensor.operations.NativeTensorOperations");
pick = nativeClazz.getConstructor().newInstance();
//This should break if no shared lib found
} catch (Throwable t) {
logger.info("Error loading native operations", t);
if (RuntimeSupport.isLinux()) {
try {
Class<? extends TensorOperations> nativeClazz = (Class<? extends TensorOperations>) Class.forName("com.github.tjake.jlama.tensor.operations.NativeTensorOperations");
pick = nativeClazz.getConstructor().newInstance();
//This should break if no shared lib found
} catch (Throwable t) {
logger.warn("Error loading native operations", t);
}
}

if (pick == null)
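After this change the native implementation is attempted only on Linux, and a load failure is logged as a warning rather than info before falling through to the `pick == null` branch. A condensed sketch of the resulting selection order (the non-native fallback is an assumption taken from the class names used elsewhere in this commit, not verbatim from the source):

```java
// Condensed, illustrative version of the selection logic.
TensorOperations pick = null;
if (RuntimeSupport.isLinux()) {
    try {
        pick = (TensorOperations) Class
                .forName("com.github.tjake.jlama.tensor.operations.NativeTensorOperations")
                .getConstructor().newInstance();
    } catch (Throwable t) {
        // no shared library on this machine; fall through to pure-Java ops
    }
}
if (pick == null)
    pick = new PanamaTensorOperations(); // assumed fallback
```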
@@ -47,7 +47,7 @@ public static void init() {
opTypes.add(new NaiveTensorOperations());
opTypes.add(new PanamaTensorOperations());

if (globalOps instanceof NaiveTensorOperations) {
if (globalOps instanceof NativeTensorOperations) {
opTypes.add(new NativeTensorOperations());
opTypes.add(new NativeTensorOperations(0));

42 changes: 42 additions & 0 deletions run-cli.sh
@@ -0,0 +1,42 @@
#!/bin/bash

# Function to extract the major version of Java
get_java_major_version() {
local version=$(java -version 2>&1 | awk -F '"' '/version/ {print $2}')
echo ${version%%.*}
}

# Verify Java version is JDK 21
JAVA_MAJOR_VERSION=$(get_java_major_version)

if [[ "$JAVA_MAJOR_VERSION" != "21" ]]; then
echo "Error: JDK 21 is required to run this application."
exit 1
fi

# Define the path of the relative JAR
JLAMA_RELATIVE_JAR="./jlama-cli/target/jlama-cli.jar"
# Path to the logback.xml
LOGBACK_CONFIG="./conf/logback.xml"


JLAMA_JVM_ARGS="--add-modules=jdk.incubator.vector --add-exports java.base/sun.nio.ch=ALL-UNNAMED --enable-preview --enable-native-access=ALL-UNNAMED \
-XX:+PreserveFramePointer -XX:+UnlockDiagnosticVMOptions -XX:CompilerDirectivesFile=./inlinerules.json -XX:+AlignVector"

# Check if PREINSTALLED_JAR environment variable is set
if [[ -z "$JLAMA_PREINSTALLED_JAR" ]]; then
# If the relative JAR doesn't exist, build it
if [[ ! -f $JLAMA_RELATIVE_JAR ]]; then
echo "The JAR $JLAMA_RELATIVE_JAR is missing. Attempting to build..."
./mvnw clean package -DskipTests
if [[ $? -ne 0 ]]; then
echo "Error building the JAR. Please check your build setup."
exit 1
fi
fi
# Run the JAR in a relative directory
java $JLAMA_JVM_ARGS -Dlogback.configurationFile=$LOGBACK_CONFIG -jar $JLAMA_RELATIVE_JAR "$@"
else
# If PREINSTALLED_JAR is set, run the JAR specified by the variable
java $JLAMA_JVM_ARGS -jar $JLAMA_PREINSTALLED_JAR "$@"
fi
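
The launcher also honors a `JLAMA_PREINSTALLED_JAR` environment variable: when it is set, the in-repo build is skipped and the given jar is run directly. For example (jar path is illustrative):

```shell
# Run against a prebuilt CLI jar instead of ./jlama-cli/target/jlama-cli.jar
JLAMA_PREINSTALLED_JAR=/opt/jlama/jlama-cli.jar ./run-cli.sh chat -p "Tell me a joke." models/Llama-2-7b-chat-hf
```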
