Skip to content

Commit

Permalink
Implement cli-args parsing for examples
Browse files Browse the repository at this point in the history
  • Loading branch information
ufownl committed Jun 28, 2024
1 parent f472366 commit 85b8e33
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 70 deletions.
16 changes: 13 additions & 3 deletions examples/ai_function.lua
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
local args = require("argparse").parse(arg)
if args["help"] then
require("argparse").help(
"AI function demo.",
"resty ai_function.lua [options]"
)
return
end

-- Create a Gemma instance
local gemma, err = require("cgemma").new({
tokenizer = "tokenizer.spm",
model = "2b-it",
weights = "2b-it-sfp.sbs"
tokenizer = args["tokenizer"] or "tokenizer.spm",
model = args["model"] or "2b-it",
weights = args["weights"] or "2b-it-sfp.sbs",
weight_type = args["weight_type"]
})
if not gemma then
print("Opoos! ", err)
Expand Down
36 changes: 36 additions & 0 deletions examples/argparse.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
local _M = {}

function _M.parse(cli_args)
local args = {}
for i, v in ipairs(cli_args) do
if string.sub(v, 1, 2) == "--" then
if cli_args[i + 1] and string.sub(cli_args[i + 1], 1, 2) ~= "--" then
args[string.sub(v, 3)] = cli_args[i + 1]
else
args[string.sub(v, 3)] = true
end
end
end
return args
end

function _M.help(description, usage)
require("cgemma").info()
print()
print(description)
print()
print(string.format("Usage: %s", usage))
print()
print("Available options:")
print(" --tokenizer: Path of tokenizer model file. (default: tokenizer.spm)")
print(" --model: Model type (default: 2b-it)")
print(" 2b-it = 2B parameters, instruction-tuned")
print(" 7b-it = 7B parameters instruction-tuned")
print(" 9b-it = 9B parameters instruction-tuned")
print(" 27b-it = 27B parameters instruction-tuned")
print(" gr2b-it = griffin 2B parameters, instruction-tuned")
print(" --weights: Path of model weights file. (default: 2b-it-sfp.sbs)")
print(" --weight_type: Weight type (default: sfp)")
end

return _M
71 changes: 71 additions & 0 deletions examples/normal_mode.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
local args = require("argparse").parse(arg)
if args["help"] then
require("argparse").help(
"Normal mode chatbot demo.",
"resty normal_mode.lua [options]"
)
print(" --kv_cache: Path of KV cache file.")
return
end

-- Create a Gemma instance
local gemma, err = require("cgemma").new({
tokenizer = args["tokenizer"] or "tokenizer.spm",
model = args["model"] or "2b-it",
weights = args["weights"] or "2b-it-sfp.sbs",
weight_type = args["weight_type"]
})
if not gemma then
print("Opoos! ", err)
return
end

-- Create a chat session
local session, seed = gemma:session()
if not session then
print("Opoos! ", seed)
return
end

print("Random seed of session: ", seed)
while true do
if args["kv_cache"] then
-- Restore the previous session
local ok, err = session:load(args["kv_cache"])
if ok then
print("Previous conversation restored")
else
print("New conversation started")
end
else
print("New conversation started")
end

-- Multi-turn chat
while session:ready() do
io.write("> ")
local text = io.read()
if not text then
if args["kv_cache"] then
print("End of file, dumping current session ...")
-- Dump the current session
local ok, err = session:dump(args["kv_cache"])
if not ok then
print("Opoos! ", err)
return
end
end
print("Done")
return
end
local reply, err = session(text)
if not reply then
print("Opoos! ", err)
return
end
print("reply: ", reply)
end

print("Exceed the maximum number of tokens")
session:reset()
end
43 changes: 30 additions & 13 deletions examples/stream_mode.lua
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
local args = require("argparse").parse(arg)
if args["help"] then
require("argparse").help(
"Stream mode chatbot demo.",
"resty stream_mode.lua [options]"
)
print(" --kv_cache: Path of KV cache file.")
return
end

-- Create a Gemma instance
local gemma, err = require("cgemma").new({
tokenizer = "tokenizer.spm",
model = "2b-it",
weights = "2b-it-sfp.sbs"
tokenizer = args["tokenizer"] or "tokenizer.spm",
model = args["model"] or "2b-it",
weights = args["weights"] or "2b-it-sfp.sbs",
weight_type = args["weight_type"]
})
if not gemma then
print("Opoos! ", err)
Expand All @@ -18,10 +29,14 @@ end

print("Random seed of session: ", seed)
while true do
-- Restore the previous session from "dump.bin"
local ok, err = session:load("dump.bin")
if ok then
print("Previous conversation restored")
if args["kv_cache"] then
-- Restore the previous session
local ok, err = session:load(args["kv_cache"])
if ok then
print("Previous conversation restored")
else
print("New conversation started")
end
else
print("New conversation started")
end
Expand All @@ -31,12 +46,14 @@ while true do
io.write("> ")
local text = io.read()
if not text then
print("End of file, dumping current session ...")
-- Dump the current session to "dump.bin"
local ok, err = session:dump("dump.bin")
if not ok then
print("Opoos! ", err)
return
if args["kv_cache"] then
print("End of file, dumping current session ...")
-- Dump the current session
local ok, err = session:dump(args["kv_cache"])
if not ok then
print("Opoos! ", err)
return
end
end
print("Done")
return
Expand Down
54 changes: 0 additions & 54 deletions examples/synopsis.lua

This file was deleted.

0 comments on commit 85b8e33

Please sign in to comment.