Skip to content

Commit

Permalink
CLI: use new CLI arg parser
Browse files Browse the repository at this point in the history
  • Loading branch information
ohsayan committed May 19, 2024
1 parent c74368f commit 2a61b46
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 76 deletions.
46 changes: 24 additions & 22 deletions cli/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ use {
event::{self, Event, KeyCode, KeyEvent},
terminal,
},
libsky::{env_vars, CliAction},
libsky::{
cli_utils::{CliCommand, CliCommandData, CommandLineArgs, SingleOption},
env_vars,
},
std::{
collections::HashMap,
env, fs,
io::{self, Write},
process::exit,
Expand All @@ -43,13 +45,13 @@ const TXT_HELP: &str = include_str!(concat!(env!("OUT_DIR"), "/skysh"));

#[derive(Debug)]
pub struct ClientConfig {
pub kind: ClientConfigKind,
pub kind: EndpointConfig,
pub username: String,
pub password: String,
}

impl ClientConfig {
pub fn new(kind: ClientConfigKind, username: String, password: String) -> Self {
pub fn new(kind: EndpointConfig, username: String, password: String) -> Self {
Self {
kind,
username,
Expand All @@ -59,7 +61,7 @@ impl ClientConfig {
}

#[derive(Debug)]
pub enum ClientConfigKind {
pub enum EndpointConfig {
Tcp(String, u16),
Tls(String, u16, String),
}
Expand All @@ -73,15 +75,15 @@ pub enum Task {

enum TaskInner {
HelpMsg(String),
OpenShell(HashMap<String, String>),
OpenShell(CliCommandData<SingleOption>),
}

fn load_env() -> CliResult<TaskInner> {
let action = libsky::parse_cli_args_disallow_duplicate()?;
let action = CliCommand::<SingleOption>::from_cli()?;
match action {
CliAction::Help => Ok(TaskInner::HelpMsg(TXT_HELP.into())),
CliAction::Version => Ok(TaskInner::HelpMsg(libsky::version_msg("skysh"))),
CliAction::Action(a) => Ok(TaskInner::OpenShell(a)),
CliCommand::Help(_) => Ok(TaskInner::HelpMsg(TXT_HELP.into())),
CliCommand::Version(_) => Ok(TaskInner::HelpMsg(libsky::version_msg("skysh"))),
CliCommand::Run(a) => Ok(TaskInner::OpenShell(a)),
}
}

Expand All @@ -90,8 +92,8 @@ pub fn parse() -> CliResult<Task> {
TaskInner::HelpMsg(msg) => return Ok(Task::HelpMessage(msg)),
TaskInner::OpenShell(args) => args,
};
let endpoint = match args.remove("--endpoint") {
None => ClientConfigKind::Tcp("127.0.0.1".into(), 2003),
let endpoint = match args.take_option("endpoint") {
None => EndpointConfig::Tcp("127.0.0.1".into(), 2003),
Some(ep) => {
// should be in the format protocol@host:port
let proto_host_port: Vec<&str> = ep.split("@").collect();
Expand All @@ -112,18 +114,18 @@ pub fn parse() -> CliResult<Task> {
)))
}
};
let tls_cert = args.remove("--tls-cert");
let tls_cert = args.take_option("tls-cert");
match protocol {
"tcp" => {
// TODO(@ohsayan): warn!
ClientConfigKind::Tcp(host.into(), port)
EndpointConfig::Tcp(host.into(), port)
}
"tls" => {
// we need a TLS cert
match tls_cert {
Some(path) => {
let cert = fs::read_to_string(path)?;
ClientConfigKind::Tls(host.into(), port, cert)
let cert = fs::read_to_string(path.as_ref())?;
EndpointConfig::Tls(host.into(), port, cert)
}
None => {
return Err(CliError::ArgsErr(format!(
Expand All @@ -140,15 +142,15 @@ pub fn parse() -> CliResult<Task> {
}
}
};
let username = match args.remove("--user") {
let username = match args.take_option("user") {
Some(u) => u,
None => {
// default
"root".into()
}
};
let password = match args.remove("--password") {
Some(p) => check_password(p, "cli arguments")?,
let password = match args.take_option("password") {
Some(p) => check_password(p.into(), "cli arguments")?,
None => {
// let us check the environment variable to see if anything was set
match env::var(env_vars::SKYDB_PASSWORD) {
Expand All @@ -157,11 +159,11 @@ pub fn parse() -> CliResult<Task> {
}
}
};
let eval = args.remove("--eval").or_else(|| args.remove("-e"));
let eval = args.take_option("eval").or_else(|| args.take_option("e"));
if args.is_empty() {
let client = ClientConfig::new(endpoint, username, password);
let client = ClientConfig::new(endpoint, username.into(), password);
match eval {
Some(query) => Ok(Task::ExecOnce(client, query)),
Some(query) => Ok(Task::ExecOnce(client, query.into())),
None => Ok(Task::OpenShell(client)),
}
} else {
Expand Down
13 changes: 3 additions & 10 deletions cli/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,9 @@ pub enum CliError {
OtherError(&'static str),
}

impl From<libsky::ArgParseError> for CliError {
fn from(e: libsky::ArgParseError) -> Self {
match e {
libsky::ArgParseError::Duplicate(d) => {
Self::ArgsErr(format!("duplicate value for `{d}`"))
}
libsky::ArgParseError::MissingValue(m) => {
Self::ArgsErr(format!("missing value for `{m}`"))
}
}
impl From<libsky::cli_utils::CliArgsError> for CliError {
fn from(value: libsky::cli_utils::CliArgsError) -> Self {
Self::ArgsErr(value.to_string())
}
}

Expand Down
6 changes: 3 additions & 3 deletions cli/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

use {
crate::{
args::{ClientConfig, ClientConfigKind},
args::{ClientConfig, EndpointConfig},
error::{CliError, CliResult},
},
skytable::{
Expand All @@ -42,7 +42,7 @@ pub fn connect<T>(
tls_f: impl Fn(ConnectionTls) -> CliResult<T>,
) -> CliResult<T> {
match cfg.kind {
ClientConfigKind::Tcp(host, port) => {
EndpointConfig::Tcp(host, port) => {
let c = Config::new(&host, port, &cfg.username, &cfg.password).connect()?;
if print_con_info {
println!(
Expand All @@ -52,7 +52,7 @@ pub fn connect<T>(
}
tcp_f(c)
}
ClientConfigKind::Tls(host, port, cert) => {
EndpointConfig::Tls(host, port, cert) => {
let c = Config::new(&host, port, &cfg.username, &cfg.password).connect_tls(&cert)?;
if print_con_info {
println!(
Expand Down
129 changes: 88 additions & 41 deletions libsky/src/cli_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use std::{
collections::{hash_map::Entry, HashMap, HashSet},
error::Error,
fmt,
str::FromStr,
};

/*
Expand All @@ -43,6 +44,7 @@ pub enum CliArgsError {
DuplicateOption(String),
SubcommandDisallowed,
ArgParseError(String),
Other(String),
}

impl fmt::Display for CliArgsError {
Expand All @@ -53,6 +55,7 @@ impl fmt::Display for CliArgsError {
Self::DuplicateOption(opt) => write!(f, "found duplicate option `--{opt}`"),
Self::SubcommandDisallowed => write!(f, "subcommands are disallowed in this context"),
Self::ArgParseError(arg) => write!(f, "failed to parse value assigned to `--{arg}`"),
Self::Other(e) => write!(f, "{e}"),
}
}
}
Expand All @@ -75,6 +78,7 @@ pub trait CliArgsDecode: Sized {
) -> CliResult<Self>;
fn yield_command(data: Self::Data) -> CliResult<Self>;
fn yield_help(data: Self::Data) -> CliResult<Self>;
fn yield_version(data: Self::Data) -> CliResult<Self>;
}

pub trait CommandLineArgs: Sized + CliArgsDecode {
Expand Down Expand Up @@ -162,53 +166,65 @@ fn decode_args<C: CliArgsDecode, const HAS_BINARY_NAME: bool>(
let mut cli_data = C::initialize::<HAS_BINARY_NAME>(&mut args);
while let Some(arg) = args.next() {
let arg = arg.as_str();
if arg == "-h" || arg == "--help" {
let arg = if arg == "-h" || arg == "--help" {
return C::yield_help(cli_data);
}
if arg.starts_with("--") {
// option or flag
let arg = &arg[2..];
if arg.is_empty() {
return Err(CliArgsError::ArgFmtError(format!("invalid argument")));
}
// is this arg in the --x=y format?
let mut arg_split = arg.split("=");
let (arg_split_name_, arg_split_value_) = (arg_split.next(), arg_split.next());
match (arg_split_name_, arg_split_value_) {
(Some(name_), Some(value_)) => {
if name_.is_empty() || value_.is_empty() {
return Err(CliArgsError::ArgFmtError(arg.to_string()));
}
// yes, it was formatted this way
C::push_option(&mut cli_data, name_.boxed_str(), value_.boxed_str())?;
continue;
} else if arg == "-v" || arg == "--version" {
return C::yield_version(cli_data);
} else {
if arg.starts_with("--") {
// option or flag
&arg[2..]
} else if arg.starts_with("-") {
if arg.len() != 2 {
// invalid shorthand
return Err(CliArgsError::Other(format!(
"the argument `{arg}` is formatted incorrectly"
)));
}
(Some(_), None) => {}
_ => unreachable!(),
// option or flag
&arg[1..]
} else {
// this is subcommand
return C::yield_subcommand(cli_data, arg.boxed_str(), args);
}
// no, probably in the --x y format
match args.peek() {
Some(arg_) => {
if arg_.as_str().starts_with("--") || arg_.as_str().starts_with("-") {
// flag
C::push_flag(&mut cli_data, arg.boxed_str())?;
} else {
// option
C::push_option(
&mut cli_data,
arg.boxed_str(),
args.next().unwrap().boxed_str(),
)?;
}
};
if arg.is_empty() {
return Err(CliArgsError::ArgFmtError(format!("invalid argument")));
}
// is this arg in the --x=y format?
let mut arg_split = arg.split("=");
let (arg_split_name_, arg_split_value_) = (arg_split.next(), arg_split.next());
match (arg_split_name_, arg_split_value_) {
(Some(name_), Some(value_)) => {
if name_.is_empty() || value_.is_empty() {
return Err(CliArgsError::ArgFmtError(arg.to_string()));
}
None => {
// yes, it was formatted this way
C::push_option(&mut cli_data, name_.boxed_str(), value_.boxed_str())?;
continue;
}
(Some(_), None) => {}
_ => unreachable!(),
}
// no, probably in the --x y format
match args.peek() {
Some(arg_) => {
if arg_.as_str().starts_with("--") || arg_.as_str().starts_with("-") {
// flag
C::push_flag(&mut cli_data, arg.boxed_str())?;
} else {
// option
C::push_option(
&mut cli_data,
arg.boxed_str(),
args.next().unwrap().boxed_str(),
)?;
}
}
} else {
// this is subcommand
return C::yield_subcommand(cli_data, arg.boxed_str(), args);
None => {
// flag
C::push_flag(&mut cli_data, arg.boxed_str())?;
}
}
}
C::yield_command(cli_data)
Expand All @@ -222,12 +238,31 @@ fn decode_args<C: CliArgsDecode, const HAS_BINARY_NAME: bool>(
pub enum CliCommand<Opt: CliArgsOptions> {
Help(CliCommandData<Opt>),
Run(CliCommandData<Opt>),
Version(CliCommandData<Opt>),
}

#[derive(Debug, PartialEq, Clone)]
pub struct CliCommandData<Opt: CliArgsOptions> {
pub options: Opt,
pub flags: HashSet<Box<str>>,
options: Opt,
flags: HashSet<Box<str>>,
}

impl CliCommandData<SingleOption> {
pub fn take_option(&mut self, option: &str) -> Option<Box<str>> {
self.options.remove(option)
}
pub fn parse_take_option<T: FromStr>(&mut self, option: &str) -> CliResult<Option<T>> {
match self.options.remove(option) {
Some(opt) => match opt.parse() {
Ok(opt) => Ok(Some(opt)),
Err(_) => Err(CliArgsError::ArgParseError(option.to_owned())),
},
None => Ok(None),
}
}
pub fn is_empty(&self) -> bool {
self.options.is_empty() && self.flags.is_empty()
}
}

impl<Opt: CliArgsOptions> CliArgsDecode for CliCommand<Opt> {
Expand Down Expand Up @@ -269,6 +304,9 @@ impl<Opt: CliArgsOptions> CliArgsDecode for CliCommand<Opt> {
fn yield_help(data: Self::Data) -> CliResult<Self> {
Ok(CliCommand::Help(data))
}
fn yield_version(data: Self::Data) -> CliResult<Self> {
Ok(CliCommand::Version(data))
}
}

/*
Expand All @@ -279,8 +317,10 @@ impl<Opt: CliArgsOptions> CliArgsDecode for CliCommand<Opt> {
pub enum CliMultiCommand<OptR: CliArgsOptions, OptS: CliArgsOptions> {
Run(CliCommandData<OptR>),
Help(CliCommandData<OptR>),
Version(CliCommandData<OptR>),
Subcommand(CliCommandData<OptR>, Subcommand<OptS>),
SubcommandHelp(CliCommandData<OptR>, Subcommand<OptS>),
SubcommandVersion(CliCommandData<OptR>, Subcommand<OptS>),
}

#[derive(Debug, PartialEq, Clone)]
Expand Down Expand Up @@ -340,8 +380,15 @@ impl<OptR: CliArgsOptions, OptS: CliArgsOptions> CliArgsDecode for CliMultiComma
data,
Subcommand::new(subcommand, subcommand_data),
)),
CliCommand::Version(subcommand_data) => Ok(CliMultiCommand::SubcommandVersion(
data,
Subcommand::new(subcommand, subcommand_data),
)),
}
}
fn yield_version(data: Self::Data) -> CliResult<Self> {
Ok(Self::Version(data))
}
}

/*
Expand Down

0 comments on commit 2a61b46

Please sign in to comment.