
add api route for llm query #488


Merged (13 commits) on Sep 11, 2023
19 changes: 18 additions & 1 deletion server/src/handlers/http.rs
@@ -34,6 +34,7 @@ use self::middleware::{DisAllowRootUser, RouteExt};
 mod about;
 mod health_check;
 mod ingest;
+mod llm;
 mod logstream;
 mod middleware;
 mod query;
@@ -229,6 +230,21 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) {
                     .wrap(DisAllowRootUser),
             ),
     );
+
+    let llm_query_api = web::scope("/llm")
+        .service(
+            web::resource("").route(
+                web::post()
+                    .to(llm::make_llm_request)
+                    .authorize(Action::Query),
+            ),
+        )
+        .service(
+            // to check if the API key for an LLM has been set up as env var
+            web::resource("isactive")
+                .route(web::post().to(llm::is_llm_active).authorize(Action::Query)),
+        );
+
     // Deny request if username is same as the env variable P_USERNAME.
     cfg.service(
         // Base path "{url}/api/v1"
@@ -266,7 +282,8 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) {
                     logstream_api,
                 ),
             )
-            .service(user_api),
+            .service(user_api)
+            .service(llm_query_api),
     )
     // GET "/" ==> Serve the static frontend directory
     .service(ResourceFiles::new("/", generated).resolve_not_found_to_root());
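
With these routes registered under the "{url}/api/v1" base path, POST /api/v1/llm asks the LLM to generate SQL for a stream, and POST /api/v1/llm/isactive reports whether an OpenAI key is configured. A minimal client sketch, not part of this PR: the host/port, the admin/admin credentials, and the "demo" stream name are assumptions for a default local deployment, and it needs tokio, serde_json, and reqwest with its "json" feature.

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let client = reqwest::Client::new();

    // POST /api/v1/llm => generate SQL for the "demo" stream (hypothetical stream name)
    let sql = client
        .post("http://localhost:8000/api/v1/llm")
        .basic_auth("admin", Some("admin")) // assumed default credentials
        .json(&json!({
            "prompt": "count of records received in the last hour",
            "stream": "demo"
        }))
        .send()
        .await?
        .text()
        .await?;
    println!("generated SQL: {sql}");

    // POST /api/v1/llm/isactive => responds with {"is_active": true|false}
    let status = client
        .post("http://localhost:8000/api/v1/llm/isactive")
        .basic_auth("admin", Some("admin"))
        .json(&json!({ "prompt": "", "stream": "demo" }))
        .send()
        .await?
        .text()
        .await?;
    println!("llm status: {status}");

    Ok(())
}
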
176 changes: 176 additions & 0 deletions server/src/handlers/http/llm.rs
@@ -0,0 +1,176 @@
/*
 * Parseable Server (C) 2022 - 2023 Parseable, Inc.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 *
 */

use actix_web::{http::header::ContentType, web, HttpResponse, Result};
use http::{header, StatusCode};
use itertools::Itertools;
use reqwest;
use serde_json::{json, Value};

use crate::{
    metadata::{error::stream_info::MetadataError, STREAM_INFO},
    option::CONFIG,
};

const OPEN_AI_URL: &str = "https://api.openai.com/v1/chat/completions";

// Deserialize types for OpenAI Response
#[derive(serde::Deserialize, Debug)]
struct ResponseData {
    choices: Vec<Choice>,
}

#[derive(serde::Deserialize, Debug)]
struct Choice {
    message: Message,
}

#[derive(serde::Deserialize, Debug)]
struct Message {
    content: String,
}

// Request body
#[derive(serde::Deserialize, Debug)]
pub struct AiPrompt {
    prompt: String,
    stream: String,
}

// Temporary type
#[derive(Debug, serde::Serialize)]
struct Field {
    name: String,
    data_type: String,
}

impl From<&arrow_schema::Field> for Field {
    fn from(field: &arrow_schema::Field) -> Self {
        Self {
            name: field.name().clone(),
            data_type: field.data_type().to_string(),
        }
    }
}

fn build_prompt(stream: &str, prompt: &str, schema_json: &str) -> String {
    format!(
        r#"I have a table called {}.
It has the columns:\n{}
Based on this, generate valid SQL for the query: "{}"
Generate only SQL as output. Also add comments in SQL syntax to explain your actions.
Don't output anything else.
If it is not possible to generate valid SQL, output an SQL comment saying so."#,
        stream, schema_json, prompt
    )
}

fn build_request_body(ai_prompt: String) -> impl serde::Serialize {
    json!({
        "model": "gpt-3.5-turbo",
        "messages": [{ "role": "user", "content": ai_prompt}],
        "temperature": 0.6,
    })
}

pub async fn make_llm_request(body: web::Json<AiPrompt>) -> Result<HttpResponse, LLMError> {
    let api_key = match &CONFIG.parseable.open_ai_key {
        Some(api_key) if api_key.len() > 3 => api_key,
        _ => return Err(LLMError::InvalidAPIKey),
    };

    let stream_name = &body.stream;
    let schema = STREAM_INFO.schema(stream_name)?;
    let filtered_schema = schema
        .all_fields()
        .into_iter()
        .map(Field::from)
        .collect_vec();

    let schema_json =
        serde_json::to_string(&filtered_schema).expect("always converted to valid json");

    let prompt = build_prompt(stream_name, &body.prompt, &schema_json);
    let body = build_request_body(prompt);

    let client = reqwest::Client::new();
    let response = client
        .post(OPEN_AI_URL)
        .header(header::CONTENT_TYPE, "application/json")
        .bearer_auth(api_key)
        .json(&body)
        .send()
        .await?;

    if response.status().is_success() {
        let body: ResponseData = response
            .json()
            .await
            .expect("OpenAI response is always the same");
        Ok(HttpResponse::Ok()
            .content_type("application/json")
            .json(&body.choices[0].message.content))
    } else {
        let body: Value = response.json().await?;
        let message = body
            .as_object()
            .and_then(|body| body.get("error"))
            .and_then(|error| error.as_object())
            .and_then(|error| error.get("message"))
            .map(|message| message.to_string())
            .unwrap_or_else(|| "Error from OpenAI".to_string());

        Err(LLMError::APIError(message))
    }
}

pub async fn is_llm_active(_body: web::Json<AiPrompt>) -> HttpResponse {
    let is_active = matches!(&CONFIG.parseable.open_ai_key, Some(api_key) if api_key.len() > 3);
    HttpResponse::Ok()
        .content_type("application/json")
        .json(json!({"is_active": is_active}))
}

#[derive(Debug, thiserror::Error)]
pub enum LLMError {
    #[error("Either OpenAI key was not provided or was invalid")]
    InvalidAPIKey,
    #[error("Failed to call OpenAI endpoint: {0}")]
    FailedRequest(#[from] reqwest::Error),
    #[error("{0}")]
    APIError(String),
    #[error("{0}")]
    StreamDoesNotExist(#[from] MetadataError),
}

impl actix_web::ResponseError for LLMError {
    fn status_code(&self) -> http::StatusCode {
        match self {
            Self::InvalidAPIKey => StatusCode::INTERNAL_SERVER_ERROR,
            Self::FailedRequest(_) => StatusCode::INTERNAL_SERVER_ERROR,
            Self::APIError(_) => StatusCode::INTERNAL_SERVER_ERROR,
            Self::StreamDoesNotExist(_) => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }

    fn error_response(&self) -> actix_web::HttpResponse<actix_web::body::BoxBody> {
        actix_web::HttpResponse::build(self.status_code())
            .insert_header(ContentType::plaintext())
            .body(self.to_string())
    }
}
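
Because build_prompt is pure string formatting, the prompt shape can be verified without a network call or an API key. A test sketch, not part of this PR, that could sit in a #[cfg(test)] module at the bottom of this file; the stream name, schema JSON, and query text are made up:

#[cfg(test)]
mod tests {
    use super::build_prompt;

    #[test]
    fn prompt_embeds_stream_schema_and_query() {
        let schema_json = r#"[{"name":"status","data_type":"Int64"}]"#;
        let prompt = build_prompt("app_logs", "count rows where status is 500", schema_json);

        // The prompt must name the table, list its columns, and quote the user query.
        assert!(prompt.contains("I have a table called app_logs."));
        assert!(prompt.contains(schema_json));
        assert!(prompt.contains(r#""count rows where status is 500""#));
    }
}
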
33 changes: 23 additions & 10 deletions server/src/option.rs
@@ -184,6 +184,9 @@ pub struct Server {
     /// Server should send anonymous analytics or not
     pub send_analytics: bool,

+    /// Open AI access key
+    pub open_ai_key: Option<String>,
+
     /// Rows in Parquet Rowgroup
     pub row_group_size: usize,

@@ -232,6 +235,7 @@ impl FromArgMatches for Server {
             .get_one::<bool>(Self::SEND_ANALYTICS)
             .cloned()
             .expect("default for send analytics");
+        self.open_ai_key = m.get_one::<String>(Self::OPEN_AI_KEY).cloned();
         // converts Gib to bytes before assigning
         self.query_memory_pool_size = m
             .get_one::<u8>(Self::QUERY_MEM_POOL_SIZE)
@@ -271,6 +275,7 @@ impl Server {
     pub const PASSWORD: &str = "password";
     pub const CHECK_UPDATE: &str = "check-update";
     pub const SEND_ANALYTICS: &str = "send-analytics";
+    pub const OPEN_AI_KEY: &str = "open-ai-key";
     pub const QUERY_MEM_POOL_SIZE: &str = "query-mempool-size";
     pub const ROW_GROUP_SIZE: &str = "row-group-size";
     pub const PARQUET_COMPRESSION_ALGO: &str = "compression-algo";
@@ -351,6 +356,24 @@ impl Server {
                 .required(true)
                 .help("Password for the basic authentication on the server"),
         )
+        .arg(
+            Arg::new(Self::SEND_ANALYTICS)
+                .long(Self::SEND_ANALYTICS)
+                .env("P_SEND_ANONYMOUS_USAGE_DATA")
+                .value_name("BOOL")
+                .required(false)
+                .default_value("true")
+                .value_parser(value_parser!(bool))
+                .help("Disable/Enable sending anonymous user data"),
+        )
+        .arg(
+            Arg::new(Self::OPEN_AI_KEY)
+                .long(Self::OPEN_AI_KEY)
+                .env("OPENAI_API_KEY")
+                .value_name("STRING")
+                .required(false)
+                .help("Set OpenAI key to enable llm feature"),
+        )
         .arg(
             Arg::new(Self::CHECK_UPDATE)
                 .long(Self::CHECK_UPDATE)
@@ -380,16 +403,6 @@ impl Server {
                 .value_parser(value_parser!(usize))
                 .help("Number of rows in a row groups"),
         )
-        .arg(
-            Arg::new(Self::SEND_ANALYTICS)
-                .long(Self::SEND_ANALYTICS)
-                .env("P_SEND_ANONYMOUS_USAGE_DATA")
-                .value_name("BOOL")
-                .required(false)
-                .default_value("true")
-                .value_parser(value_parser!(bool))
-                .help("Disable/Enable sending anonymous user data"),
-        )
         .arg(
             Arg::new(Self::PARQUET_COMPRESSION_ALGO)
                 .long(Self::PARQUET_COMPRESSION_ALGO)
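
The new --open-ai-key flag follows the same pattern as the relocated --send-analytics argument: optional, env-var backed, and surfaced on Server as an Option<String>. A stripped-down sketch of that resolution order, assuming clap 4.x with its "env" feature; the "server" command name here is illustrative:

use clap::{Arg, Command};

fn main() {
    let matches = Command::new("server")
        .arg(
            // Mirrors OPEN_AI_KEY above: CLI flag first, OPENAI_API_KEY env var as fallback.
            Arg::new("open-ai-key")
                .long("open-ai-key")
                .env("OPENAI_API_KEY")
                .value_name("STRING")
                .required(false)
                .help("Set OpenAI key to enable llm feature"),
        )
        .get_matches();

    // None when neither --open-ai-key nor OPENAI_API_KEY is set -- the state
    // that is_llm_active reports as inactive.
    let open_ai_key: Option<String> = matches.get_one::<String>("open-ai-key").cloned();
    println!("llm feature configured: {}", open_ai_key.is_some());
}

Starting the server with --open-ai-key <key>, or with OPENAI_API_KEY set, enables the /llm routes; otherwise make_llm_request rejects requests with InvalidAPIKey.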