-- Summary of the patch below (target: tools/cache_prompt.lua).
--
-- Hunk 1 (@@ -17,6 +17,7 @@): adds one line to the --help output describing
-- a new "--num_threads" option ("Number of threads in scheduler.").
--
-- Hunk 2 (@@ -35,13 +36,21 @@):
--   * Before "Loading model ..." is printed, creates a scheduler with
--     require("cgemma").scheduler(tonumber(args["num_threads"])).
--     If the call returns a falsy first value, the script prints
--     "Opoos! " followed by the returned error and returns.
--   * Adds a trailing comma after the weight_type field and passes the new
--     scheduler to require("cgemma").new({...}) as the `scheduler` field.
--
-- NOTE(review): tonumber() yields nil when --num_threads is absent or not
-- numeric; the help text says the default is hardware concurrency, so
-- presumably cgemma.scheduler(nil) picks that default -- confirm against
-- the cgemma API, and decide whether a non-numeric value should be rejected.
-- NOTE(review): "Opoos!" looks like a typo for "Oops!", but the same string
-- already appears in an unchanged context line, so the added line is
-- consistent with the file's existing message.
-- NOTE(review): `local sched, err` and the later `local gemma, err` both
-- declare `err`; the second declaration shadows the first.
-- NOTE(review): this patch text has been collapsed onto a single line; the
-- original line breaks and indentation are not recoverable from here.
diff --git a/tools/cache_prompt.lua b/tools/cache_prompt.lua index 364e3f1..e104f1e 100644 --- a/tools/cache_prompt.lua +++ b/tools/cache_prompt.lua @@ -17,6 +17,7 @@ if args["help"] then print("Usage: cat /path/to/prompt.txt | resty cache_prompt.lua [options]") print() print("Available options:") + print(" --num_threads: Number of threads in scheduler. (default: hardware concurrency)") print(" --tokenizer: Path of tokenizer model file. (default: tokenizer.spm)") print(" --model: Model type (default: 2b-pt)") print(" 2b-it = 2B parameters, instruction-tuned") @@ -35,13 +36,21 @@ if args["help"] then return end +-- Create a scheduler instance +local sched, err = require("cgemma").scheduler(tonumber(args["num_threads"])) +if not sched then + print("Opoos! ", err) + return +end + print("Loading model ...") -- Create a Gemma instance local gemma, err = require("cgemma").new({ tokenizer = args["tokenizer"] or "tokenizer.spm", model = args["model"] or "2b-pt", weights = args["weights"] or "2b-it-sfp.sbs", - weight_type = args["weight_type"] + weight_type = args["weight_type"], + scheduler = sched }) if not gemma then print("Opoos! ", err)