diff --git a/examples/codegen/serve.cpp b/examples/codegen/serve.cpp index 1e06e6173..a7c45aeb4 100644 --- a/examples/codegen/serve.cpp +++ b/examples/codegen/serve.cpp @@ -8,61 +8,12 @@ #include // generators #include // streaming operators etc. -int main(int argc, char** argv) { - - gpt_params params; - params.model = "models/gpt-j-6B/ggml-model.bin"; - - if (gpt_params_parse(argc, argv, params) == false) { - return 1; - } - - if (params.seed < 0) { - params.seed = time(NULL); - } - - printf("%s: seed = %d\n", __func__, params.seed); - +/** + * This function serves requests for autocompletion from crow + * +*/ +crow::response serve_response(gpt_params params, gptj_model &model, gpt_vocab &vocab, const crow::request& req){ - crow::SimpleApp app; - - gpt_vocab vocab; - gptj_model model; - - int64_t t_load_us = 0; - - // load the model - { - const int64_t t_start_us = ggml_time_us(); - - if (!gptj_model_load(params.model, model, vocab)) { - fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); - return 1; - } - - t_load_us = ggml_time_us() - t_start_us; - } - - - CROW_ROUTE(app, "/")([](){ - return "Hello world"; - }); - - CROW_ROUTE(app, "/copilot_internal/v2/token")([](){ - //return "Hello world"; - - crow::json::wvalue response = {{"token","1"}, {"expires_at", static_cast(2600000000)}, {"refresh_in",900}}; - - crow::response res; - res.code = 200; - res.set_header("Content-Type", "application/json"); - res.body = response.dump(); - return res; - }); - - - CROW_ROUTE(app, "/v1/engines/codegen/completions").methods(crow::HTTPMethod::POST) - ([&model, &vocab, ¶ms](const crow::request& req) { crow::json::rvalue data = crow::json::load(req.body); if(!data.has("prompt") && !data.has("input_ids")){ @@ -97,8 +48,6 @@ int main(int argc, char** argv) { std::string suffix = ""; float temperature = 0.6; - data["model"].s(); - if(data.has("suffix")){ suffix = data["suffix"].s(); } @@ -223,6 +172,75 @@ int main(int argc, char** argv) { res.body = response.dump(); //ss.str(); return res; +} + +int main(int argc, char** argv) { + + gpt_params params; + params.model = "models/gpt-j-6B/ggml-model.bin"; + + if (gpt_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.seed < 0) { + params.seed = time(NULL); + } + + printf("%s: seed = %d\n", __func__, params.seed); + + + crow::SimpleApp app; + + gpt_vocab vocab; + gptj_model model; + + int64_t t_load_us = 0; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (!gptj_model_load(params.model, model, vocab)) { + fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); + return 1; + } + + t_load_us = ggml_time_us() - t_start_us; + } + + + CROW_ROUTE(app, "/")([](){ + return "Hello world"; + }); + + CROW_ROUTE(app, "/copilot_internal/v2/token")([](){ + //return "Hello world"; + + crow::json::wvalue response = {{"token","1"}, {"expires_at", static_cast(2600000000)}, {"refresh_in",900}}; + + crow::response res; + res.code = 200; + res.set_header("Content-Type", "application/json"); + res.body = response.dump(); + return res; + }); + + + CROW_ROUTE(app, "/v1/completions").methods(crow::HTTPMethod::POST) + ([&model, &vocab, ¶ms](const crow::request& req) { + return serve_response(params, model, vocab, req); + }); + + CROW_ROUTE(app, "/v1/engines/codegen/completions").methods(crow::HTTPMethod::POST) + ([&model, &vocab, ¶ms](const crow::request& req) { + return serve_response(params, model, vocab, req); + }); + + + CROW_ROUTE(app, "/v1/engines/copilot-codex/completions").methods(crow::HTTPMethod::POST) + ([&model, &vocab, ¶ms](const crow::request& req) { + return serve_response(params, model, vocab, req); }); app.port(18080).multithreaded().run();