diff --git a/src/cli/sql.rs b/src/cli/sql.rs index adf35c48437..44627fa0650 100644 --- a/src/cli/sql.rs +++ b/src/cli/sql.rs @@ -5,6 +5,27 @@ use reqwest::header::ACCEPT; use rustyline::error::ReadlineError; use rustyline::Editor; use serde_json::Value; +use surrealdb::sql::parse; +use surrealdb::sql::statements::UseStatement; +use surrealdb::sql::Statement; + +struct SqlContext { + ns: Option, + db: Option, +} + +impl SqlContext { + fn using(&mut self, statement: &UseStatement) { + match statement.ns.clone() { + Some(ns) => self.ns = Some(ns), + None => {} + }; + match statement.db.clone() { + Some(db) => self.db = Some(db), + None => {} + }; + } +} pub fn init(matches: &clap::ArgMatches) -> Result<(), Error> { // Set the default logging level @@ -13,8 +34,10 @@ pub fn init(matches: &clap::ArgMatches) -> Result<(), Error> { let user = matches.value_of("user").unwrap(); let pass = matches.value_of("pass").unwrap(); let conn = matches.value_of("conn").unwrap(); - let ns = matches.value_of("ns"); - let db = matches.value_of("db"); + let mut ctx = SqlContext { + ns: matches.value_of("ns").map(|s| s.to_string()), + db: matches.value_of("db").map(|s| s.to_string()), + }; // If we should pretty-print responses let pretty = matches.is_present("pretty"); @@ -44,20 +67,33 @@ pub fn init(matches: &clap::ArgMatches) -> Result<(), Error> { .header(ACCEPT, "application/json") .basic_auth(user, Some(pass)); // Add NS header if specified - let res = match ns { + let res = match ctx.ns.clone() { Some(ns) => res.header("NS", ns), None => res, }; // Add DB header if specified - let res = match db { + let res = match ctx.db.clone() { Some(db) => res.header("DB", db), None => res, }; + // Compile the request for later + let ast = parse(line.as_str()); // Complete request let res = res.body(line).send(); // Get the request response match process(pretty, res) { - Ok(v) => println!("{}", v), + Ok(v) => { + println!("{}", v); + if let Ok(ok_ast) = ast { + for stmt in ok_ast.iter() { + // If there's a use statement, update the context + match stmt { + Statement::Use(using) => ctx.using(using), + _ => {} + } + } + }; + } Err(e) => eprintln!("{}", e), } }