Skip to content

Commit

Permalink
Add num_threads argument for cache_prompt.lua script
Browse files Browse the repository at this point in the history
  • Loading branch information
ufownl committed Jun 28, 2024
1 parent 903caa3 commit 3d5a0f6
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion tools/cache_prompt.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ if args["help"] then
print("Usage: cat /path/to/prompt.txt | resty cache_prompt.lua [options]")
print()
print("Available options:")
print(" --num_threads: Number of threads in scheduler. (default: hardware concurrency)")
print(" --tokenizer: Path of tokenizer model file. (default: tokenizer.spm)")
print(" --model: Model type (default: 2b-pt)")
print(" 2b-it = 2B parameters, instruction-tuned")
Expand All @@ -35,13 +36,21 @@ if args["help"] then
return
end

-- Create a scheduler instance
local sched, err = require("cgemma").scheduler(tonumber(args["num_threads"]))
if not sched then
print("Opoos! ", err)
return
end

print("Loading model ...")
-- Create a Gemma instance
local gemma, err = require("cgemma").new({
tokenizer = args["tokenizer"] or "tokenizer.spm",
model = args["model"] or "2b-pt",
weights = args["weights"] or "2b-it-sfp.sbs",
weight_type = args["weight_type"]
weight_type = args["weight_type"],
scheduler = sched
})
if not gemma then
print("Opoos! ", err)
Expand Down

0 comments on commit 3d5a0f6

Please sign in to comment.