Skip to content

Commit

Permalink
make get_cluster_ids_by_domain a method of ConfigState
Browse files Browse the repository at this point in the history
  • Loading branch information
Keksoj committed May 4, 2023
1 parent 8038181 commit 146dd32
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 50 deletions.
9 changes: 3 additions & 6 deletions bin/src/command/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use sozu_command_lib::{
},
request::WorkerRequest,
scm_socket::Listeners,
state::get_cluster_ids_by_domain,
};

use sozu::metrics::METRICS;
Expand Down Expand Up @@ -1079,11 +1078,9 @@ impl CommandServer {
})),
}),
Some(RequestType::QueryClustersByDomain(domain)) => {
let cluster_ids = get_cluster_ids_by_domain(
&self.state,
domain.hostname.clone(),
domain.path.clone(),
);
let cluster_ids = self
.state
.get_cluster_ids_by_domain(domain.hostname.clone(), domain.path.clone());
let vec = cluster_ids
.iter()
.map(|cluster_id| self.state.cluster_state(cluster_id))
Expand Down
70 changes: 32 additions & 38 deletions command/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,32 @@ impl ConfigState {
+ self.https_fronts.values().count()
+ self.tcp_fronts.values().fold(0, |acc, v| acc + v.len())
}

pub fn get_cluster_ids_by_domain(
&self,
hostname: String,
path: Option<String>,
) -> HashSet<ClusterId> {
let mut cluster_ids: HashSet<ClusterId> = HashSet::new();

self.http_fronts.values().for_each(|front| {
if domain_check(&front.hostname, &front.path, &hostname, &path) {
if let Some(id) = &front.cluster_id {
cluster_ids.insert(id.to_string());
}
}
});

self.https_fronts.values().for_each(|front| {
if domain_check(&front.hostname, &front.path, &hostname, &path) {
if let Some(id) = &front.cluster_id {
cluster_ids.insert(id.to_string());
}
}
});

cluster_ids
}
}

fn domain_check(
Expand All @@ -1222,32 +1248,6 @@ fn domain_check(
true
}

pub fn get_cluster_ids_by_domain(
state: &ConfigState,
hostname: String,
path: Option<String>,
) -> HashSet<ClusterId> {
let mut cluster_ids: HashSet<ClusterId> = HashSet::new();

state.http_fronts.values().for_each(|front| {
if domain_check(&front.hostname, &front.path, &hostname, &path) {
if let Some(id) = &front.cluster_id {
cluster_ids.insert(id.to_string());
}
}
});

state.https_fronts.values().for_each(|front| {
if domain_check(&front.hostname, &front.path, &hostname, &path) {
if let Some(id) = &front.cluster_id {
cluster_ids.insert(id.to_string());
}
}
});

cluster_ids
}

pub fn get_certificate(state: &ConfigState, fingerprint: &[u8]) -> Option<CertificateWithNames> {
state
.certificates
Expand Down Expand Up @@ -1698,27 +1698,21 @@ mod tests {

let empty: HashSet<ClusterId> = HashSet::new();
assert_eq!(
get_cluster_ids_by_domain(&config, String::from("lolcatho.st"), None),
config.get_cluster_ids_by_domain(String::from("lolcatho.st"), None),
cluster1_cluster2
);
assert_eq!(
get_cluster_ids_by_domain(
&config,
String::from("lolcatho.st"),
Some(String::from("/api"))
),
config
.get_cluster_ids_by_domain(String::from("lolcatho.st"), Some(String::from("/api"))),
cluster2
);
assert_eq!(
get_cluster_ids_by_domain(&config, String::from("lolcathost"), None),
config.get_cluster_ids_by_domain(String::from("lolcathost"), None),
empty
);
assert_eq!(
get_cluster_ids_by_domain(
&config,
String::from("lolcathost"),
Some(String::from("/sozu"))
),
config
.get_cluster_ids_by_domain(String::from("lolcathost"), Some(String::from("/sozu"))),
empty
);
}
Expand Down
10 changes: 4 additions & 6 deletions lib/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use sozu_command::{
request::WorkerRequest,
response::{MessageId, WorkerResponse},
scm_socket::{Listeners, ScmSocket},
state::{get_certificate, get_cluster_ids_by_domain, ConfigState},
state::{get_certificate, ConfigState},
};
use time::{Duration, Instant};

Expand Down Expand Up @@ -907,11 +907,9 @@ impl Server {
));
}
Some(RequestType::QueryClustersByDomain(domain)) => {
let cluster_ids = get_cluster_ids_by_domain(
&self.config_state,
domain.hostname.clone(),
domain.path.clone(),
);
let cluster_ids = self
.config_state
.get_cluster_ids_by_domain(domain.hostname.clone(), domain.path.clone());
let vec = cluster_ids
.iter()
.map(|cluster_id| self.config_state.cluster_state(cluster_id))
Expand Down

0 comments on commit 146dd32

Please sign in to comment.