Skip to content

Commit

Permalink
use TryFrom in prost::Decode
Browse files Browse the repository at this point in the history
  • Loading branch information
Keksoj committed Sep 13, 2023
1 parent bf7b492 commit 8c164f9
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 73 deletions.
6 changes: 3 additions & 3 deletions bin/src/command/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ impl CommandServer {
self.upgrade_worker(client_id, worker_id).await
}
Some(RequestType::ConfigureMetrics(config)) => {
match MetricsConfiguration::from_i32(config) {
Some(config) => self.configure_metrics(client_id, config).await,
None => Err(anyhow::Error::msg("wrong i32 for metrics configuration")),
match MetricsConfiguration::try_from(config) {
Ok(config) => self.configure_metrics(client_id, config).await,
Err(_) => Err(anyhow::Error::msg("wrong i32 for metrics configuration")),
}
}
Some(RequestType::Logging(logging_filter)) => self.set_logging_level(logging_filter),
Expand Down
6 changes: 3 additions & 3 deletions command/src/proto/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ impl Display for CertificateAndKey {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let versions = self.versions.iter().fold(String::new(), |acc, tls_v| {
acc + " "
+ match TlsVersion::from_i32(*tls_v) {
Some(v) => v.as_str_name(),
None => "",
+ match TlsVersion::try_from(*tls_v) {
Ok(v) => v.as_str_name(),
Err(_) => "",
}
});
write!(
Expand Down
18 changes: 10 additions & 8 deletions command/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,11 @@ impl RequestHttpFrontend {
hostname: self.hostname,
path: self.path,
method: self.method,
position: RulePosition::from_i32(self.position).ok_or(RequestError::InvalidValue {
name: "position".to_string(),
value: self.position,
position: RulePosition::try_from(self.position).map_err(|_| {
RequestError::InvalidValue {
name: "position".to_string(),
value: self.position,
}
})?,
tags: Some(self.tags),
})
Expand All @@ -156,17 +158,17 @@ impl RequestHttpFrontend {
impl Display for RequestHttpFrontend {
/// Used to create a unique summary of the frontend, used as a key in maps
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match &PathRuleKind::from_i32(self.path.kind) {
Some(PathRuleKind::Prefix) => {
let s = match &PathRuleKind::try_from(self.path.kind) {
Ok(PathRuleKind::Prefix) => {
format!("{};{};P{}", self.address, self.hostname, self.path.value)
}
Some(PathRuleKind::Regex) => {
Ok(PathRuleKind::Regex) => {
format!("{};{};R{}", self.address, self.hostname, self.path.value)
}
Some(PathRuleKind::Equals) => {
Ok(PathRuleKind::Equals) => {
format!("{};{};={}", self.address, self.hostname, self.path.value)
}
None => String::from("Wrong variant of PathRuleKind"),
Err(e) => format!("Wrong variant of PathRuleKind: {e}"),
};

match &self.method {
Expand Down
12 changes: 6 additions & 6 deletions command/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,16 @@ impl PathRule {
}

pub fn is_default_path_rule(p: &PathRule) -> bool {
PathRuleKind::from_i32(p.kind) == Some(PathRuleKind::Prefix) && p.value.is_empty()
PathRuleKind::try_from(p.kind) == Ok(PathRuleKind::Prefix) && p.value.is_empty()
}

impl std::fmt::Display for PathRule {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match PathRuleKind::from_i32(self.kind) {
Some(PathRuleKind::Prefix) => write!(f, "prefix '{}'", self.value),
Some(PathRuleKind::Regex) => write!(f, "regexp '{}'", self.value),
Some(PathRuleKind::Equals) => write!(f, "equals '{}'", self.value),
None => write!(f, ""),
match PathRuleKind::try_from(self.kind) {
Ok(PathRuleKind::Prefix) => write!(f, "prefix '{}'", self.value),
Ok(PathRuleKind::Regex) => write!(f, "regexp '{}'", self.value),
Ok(PathRuleKind::Equals) => write!(f, "equals '{}'", self.value),
Err(_) => write!(f, ""),
}
}
}
Expand Down
41 changes: 20 additions & 21 deletions command/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use std::{
net::SocketAddr,
};

use prost::DecodeError;

use crate::{
certificate::{self, calculate_fingerprint, Fingerprint},
proto::{
Expand Down Expand Up @@ -56,6 +58,12 @@ pub enum StateError {
FrontendConversion { frontend: String, error: String },
}

impl From<DecodeError> for StateError {
fn from(decode_error: DecodeError) -> Self {
Self::WrongRequest(format!("Wrong field value: {decode_error}"))
}
}

/// The `ConfigState` represents the state of Sōzu's business, which is to forward traffic
/// from frontends to backends. Hence, it contains all details about:
///
Expand Down Expand Up @@ -213,13 +221,10 @@ impl ConfigState {
}

fn remove_listener(&mut self, remove: &RemoveListener) -> Result<(), StateError> {
match ListenerType::from_i32(remove.proxy) {
Some(ListenerType::Http) => self.remove_http_listener(&remove.address),
Some(ListenerType::Https) => self.remove_https_listener(&remove.address),
Some(ListenerType::Tcp) => self.remove_tcp_listener(&remove.address),
None => Err(StateError::WrongRequest(
"Wrong ListenerType on RemoveListener request".to_string(),
)),
match ListenerType::try_from(remove.proxy)? {
ListenerType::Http => self.remove_http_listener(&remove.address),
ListenerType::Https => self.remove_https_listener(&remove.address),
ListenerType::Tcp => self.remove_tcp_listener(&remove.address),
}
}

Expand All @@ -245,66 +250,60 @@ impl ConfigState {
}

fn activate_listener(&mut self, activate: &ActivateListener) -> Result<(), StateError> {
match ListenerType::from_i32(activate.proxy) {
Some(ListenerType::Http) => self
match ListenerType::try_from(activate.proxy)? {
ListenerType::Http => self
.http_listeners
.get_mut(&activate.address)
.map(|listener| listener.active = true)
.ok_or(StateError::NotFound {
kind: ObjectKind::HttpListener,
id: activate.address.to_owned(),
}),
Some(ListenerType::Https) => self
ListenerType::Https => self
.https_listeners
.get_mut(&activate.address)
.map(|listener| listener.active = true)
.ok_or(StateError::NotFound {
kind: ObjectKind::HttpsListener,
id: activate.address.to_owned(),
}),
Some(ListenerType::Tcp) => self
ListenerType::Tcp => self
.tcp_listeners
.get_mut(&activate.address)
.map(|listener| listener.active = true)
.ok_or(StateError::NotFound {
kind: ObjectKind::TcpListener,
id: activate.address.to_owned(),
}),
None => Err(StateError::WrongRequest(
"Wrong variant for ListenerType on request".to_string(),
)),
}
}

fn deactivate_listener(&mut self, deactivate: &DeactivateListener) -> Result<(), StateError> {
match ListenerType::from_i32(deactivate.proxy) {
Some(ListenerType::Http) => self
match ListenerType::try_from(deactivate.proxy)? {
ListenerType::Http => self
.http_listeners
.get_mut(&deactivate.address)
.map(|listener| listener.active = false)
.ok_or(StateError::NotFound {
kind: ObjectKind::HttpListener,
id: deactivate.address.to_owned(),
}),
Some(ListenerType::Https) => self
ListenerType::Https => self
.https_listeners
.get_mut(&deactivate.address)
.map(|listener| listener.active = false)
.ok_or(StateError::NotFound {
kind: ObjectKind::HttpsListener,
id: deactivate.address.to_owned(),
}),
Some(ListenerType::Tcp) => self
ListenerType::Tcp => self
.tcp_listeners
.get_mut(&deactivate.address)
.map(|listener| listener.active = false)
.ok_or(StateError::NotFound {
kind: ObjectKind::TcpListener,
id: deactivate.address.to_owned(),
}),
None => Err(StateError::WrongRequest(
"Wrong variant for ListenerType on request".to_string(),
)),
}
}

Expand Down
10 changes: 5 additions & 5 deletions lib/src/https.rs
Original file line number Diff line number Diff line change
Expand Up @@ -745,11 +745,11 @@ impl HttpsListener {
let versions = config
.versions
.iter()
.filter_map(|version| match TlsVersion::from_i32(*version) {
Some(TlsVersion::TlsV12) => Some(&rustls::version::TLS12),
Some(TlsVersion::TlsV13) => Some(&rustls::version::TLS13),
_other_version => {
error!("unsupported TLS version: {:?}", _other_version);
.filter_map(|version| match TlsVersion::try_from(*version) {
Ok(TlsVersion::TlsV12) => Some(&rustls::version::TLS12),
Ok(TlsVersion::TlsV13) => Some(&rustls::version::TLS13),
Ok(_) | Err(_) => {
error!("unsupported TLS version");
None
}
})
Expand Down
10 changes: 5 additions & 5 deletions lib/src/router/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,11 +523,11 @@ impl PathRule {
}

pub fn from_config(rule: CommandPathRule) -> Option<Self> {
match PathRuleKind::from_i32(rule.kind) {
Some(PathRuleKind::Prefix) => Some(PathRule::Prefix(rule.value)),
Some(PathRuleKind::Regex) => Regex::new(&rule.value).ok().map(PathRule::Regex),
Some(PathRuleKind::Equals) => Some(PathRule::Equals(rule.value)),
_ => None,
match PathRuleKind::try_from(rule.kind) {
Ok(PathRuleKind::Prefix) => Some(PathRule::Prefix(rule.value)),
Ok(PathRuleKind::Regex) => Regex::new(&rule.value).ok().map(PathRule::Regex),
Ok(PathRuleKind::Equals) => Some(PathRule::Equals(rule.value)),
Err(_) => None,
}
}
}
Expand Down
38 changes: 20 additions & 18 deletions lib/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ impl Server {

fn notify(&mut self, message: WorkerRequest) {
if let Some(RequestType::ConfigureMetrics(configuration)) = &message.content.request_type {
if let Some(metrics_config) = MetricsConfiguration::from_i32(*configuration) {
if let Ok(metrics_config) = MetricsConfiguration::try_from(*configuration) {
METRICS.with(|metrics| {
(*metrics.borrow_mut()).configure(&metrics_config);

Expand Down Expand Up @@ -1008,11 +1008,11 @@ impl Server {
Some(RequestType::RemoveListener(ref remove)) => {
debug!("{} remove {:?} listener {:?}", req_id, remove.proxy, remove);
self.base_sessions_count -= 1;
let response = match ListenerType::from_i32(remove.proxy) {
Some(ListenerType::Http) => self.http.borrow_mut().notify(request.clone()),
Some(ListenerType::Https) => self.https.borrow_mut().notify(request.clone()),
Some(ListenerType::Tcp) => self.tcp.borrow_mut().notify(request.clone()),
None => WorkerResponse::error(req_id, "Wrong variant ListenerType"),
let response = match ListenerType::try_from(remove.proxy) {
Ok(ListenerType::Http) => self.http.borrow_mut().notify(request.clone()),
Ok(ListenerType::Https) => self.https.borrow_mut().notify(request.clone()),
Ok(ListenerType::Tcp) => self.tcp.borrow_mut().notify(request.clone()),
Err(_) => WorkerResponse::error(req_id, "Wrong variant ListenerType"),
};
push_queue(response);
}
Expand All @@ -1031,8 +1031,10 @@ impl Server {
.borrow_mut()
.set_load_balancing_policy_for_cluster(
&cluster.cluster_id,
LoadBalancingAlgorithms::from_i32(cluster.load_balancing).unwrap_or_default(),
cluster.load_metric.and_then(LoadMetric::from_i32),
LoadBalancingAlgorithms::try_from(cluster.load_balancing).unwrap_or_default(),
cluster
.load_metric
.and_then(|n| LoadMetric::try_from(n).ok()),
);
}

Expand Down Expand Up @@ -1178,8 +1180,8 @@ impl Server {
Err(e) => return WorkerResponse::error(req_id, format!("Wrong socket address: {e}")),
};

match ListenerType::from_i32(activate.proxy) {
Some(ListenerType::Http) => {
match ListenerType::try_from(activate.proxy) {
Ok(ListenerType::Http) => {
let listener = self
.scm_listeners
.as_mut()
Expand All @@ -1198,7 +1200,7 @@ impl Server {
}
}
}
Some(ListenerType::Https) => {
Ok(ListenerType::Https) => {
let listener = self
.scm_listeners
.as_mut()
Expand All @@ -1220,7 +1222,7 @@ impl Server {
}
}
}
Some(ListenerType::Tcp) => {
Ok(ListenerType::Tcp) => {
let listener = self
.scm_listeners
.as_mut()
Expand All @@ -1239,7 +1241,7 @@ impl Server {
}
}
}
None => WorkerResponse::error(req_id, "Wrong variant for ListenerType on request"),
Err(_) => WorkerResponse::error(req_id, "Wrong variant for ListenerType on request"),
}
}

Expand All @@ -1258,8 +1260,8 @@ impl Server {
Err(e) => return WorkerResponse::error(req_id, format!("Wrong socket address: {e}")),
};

match ListenerType::from_i32(deactivate.proxy) {
Some(ListenerType::Http) => {
match ListenerType::try_from(deactivate.proxy) {
Ok(ListenerType::Http) => {
let (token, mut listener) = match self.http.borrow_mut().give_back_listener(address)
{
Some((token, listener)) => (token, listener),
Expand Down Expand Up @@ -1302,7 +1304,7 @@ impl Server {
}
WorkerResponse::ok(req_id)
}
Some(ListenerType::Https) => {
Ok(ListenerType::Https) => {
let (token, mut listener) =
match self.https.borrow_mut().give_back_listener(address) {
Some((token, listener)) => (token, listener),
Expand Down Expand Up @@ -1343,7 +1345,7 @@ impl Server {
}
WorkerResponse::ok(req_id)
}
Some(ListenerType::Tcp) => {
Ok(ListenerType::Tcp) => {
let (token, mut listener) = match self.tcp.borrow_mut().give_back_listener(address)
{
Some((token, listener)) => (token, listener),
Expand Down Expand Up @@ -1382,7 +1384,7 @@ impl Server {
}
WorkerResponse::ok(req_id)
}
None => WorkerResponse::error(req_id, "Wrong variant for ListenerType on request"),
Err(_) => WorkerResponse::error(req_id, "Wrong variant for ListenerType on request"),
}
}

Expand Down
2 changes: 1 addition & 1 deletion lib/src/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1372,7 +1372,7 @@ impl ProxyConfiguration for TcpProxy {
let config = ClusterConfiguration {
proxy_protocol: cluster
.proxy_protocol
.and_then(ProxyProtocolConfig::from_i32),
.and_then(|n| ProxyProtocolConfig::try_from(n).ok()),
//load_balancing: cluster.load_balancing,
};
self.configs.insert(cluster.cluster_id, config);
Expand Down
6 changes: 3 additions & 3 deletions lib/src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ impl CertificateResolverHelper for GenericCertificateResolver {
let versions = certificate_and_key
.versions
.iter()
.filter_map(|v| TlsVersion::from_i32(*v))
.filter_map(|v| TlsVersion::try_from(*v).ok())
.collect();
return Ok(ParsedCertificateAndKey {
certificate,
Expand All @@ -339,7 +339,7 @@ impl CertificateResolverHelper for GenericCertificateResolver {
let versions = certificate_and_key
.versions
.iter()
.filter_map(|v| TlsVersion::from_i32(*v))
.filter_map(|v| TlsVersion::try_from(*v).ok())
.collect();
return Ok(ParsedCertificateAndKey {
certificate,
Expand All @@ -354,7 +354,7 @@ impl CertificateResolverHelper for GenericCertificateResolver {
let versions = certificate_and_key
.versions
.iter()
.filter_map(|v| TlsVersion::from_i32(*v))
.filter_map(|v| TlsVersion::try_from(*v).ok())
.collect();

return Ok(ParsedCertificateAndKey {
Expand Down

0 comments on commit 8c164f9

Please sign in to comment.