Skip to content

Commit

Permalink
feat: allow setting a base path for the database
Browse files Browse the repository at this point in the history
Also, don't double-arc the AppState.
  • Loading branch information
ctron committed Sep 14, 2023
1 parent cba2188 commit 832002b
Show file tree
Hide file tree
Showing 13 changed files with 104 additions and 94 deletions.
30 changes: 13 additions & 17 deletions collectorist/api/src/coordinator/collector.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::Arc;
use std::time::Duration;

use chrono::Utc;
Expand All @@ -13,7 +14,7 @@ use collector_client::{
};
use collectorist_client::{CollectorConfig, Interest};

use crate::SharedState;
use crate::state::AppState;

#[derive(Debug, thiserror::Error)]
#[error("No configuration for collector")]
Expand All @@ -32,8 +33,8 @@ pub struct Collector {
}

impl Collector {
pub fn new(state: SharedState, id: String, config: CollectorConfig) -> Self {
let update = tokio::spawn(Collector::update(state.clone(), id.clone()));
pub fn new(state: Arc<AppState>, id: String, config: CollectorConfig) -> Self {
let update = tokio::spawn(Collector::update(state, id.clone()));
Self {
id,
config: config.clone(),
Expand All @@ -43,7 +44,7 @@ impl Collector {

pub async fn collect_packages(
&self,
state: SharedState,
state: &AppState,
purls: Vec<String>,
) -> Result<CollectPackagesResponse, anyhow::Error> {
Self::collect_packages_internal(
Expand All @@ -57,7 +58,7 @@ impl Collector {
}

async fn collect_packages_internal(
state: SharedState,
state: &AppState,
id: String,
config: &CollectorConfig,
purls: Vec<String>,
Expand Down Expand Up @@ -87,14 +88,14 @@ impl Collector {

pub async fn collect_vulnerabilities(
&self,
state: SharedState,
state: &AppState,
vulnerability_ids: HashSet<String>,
) -> Result<CollectVulnerabilitiesResponse, anyhow::Error> {
Self::collect_vulnerabilities_internal(state, self.id.clone(), &self.config, vulnerability_ids).await
}

async fn collect_vulnerabilities_internal(
state: SharedState,
state: &AppState,
id: String,
config: &CollectorConfig,
vulnerability_ids: HashSet<String>,
Expand All @@ -116,7 +117,7 @@ impl Collector {
Ok(response)
}

pub async fn update(state: SharedState, id: String) {
pub async fn update(state: Arc<AppState>, id: String) {
loop {
if let Some(config) = state.collectors.read().await.collector_config(id.clone()) {
let collector_url = config.url.clone();
Expand All @@ -130,14 +131,9 @@ impl Collector {

if !purls.is_empty() {
log::debug!("polling packages for {} -> {}", id, collector_url);
if let Ok(response) = Self::collect_packages_internal(
state.clone(),
id.clone(),
&config,
purls,
RetentionMode::All,
)
.await
if let Ok(response) =
Self::collect_packages_internal(&state, id.clone(), &config, purls, RetentionMode::All)
.await
{
// during normal re-scan, we did indeed discover some vulns, make sure they are in the DB.
let vuln_ids: HashSet<_> = response.purls.values().flatten().collect();
Expand All @@ -159,7 +155,7 @@ impl Collector {

if !vuln_ids.is_empty() {
log::debug!("polling vulnerabilities for {} -> {}", id, collector_url);
Self::collect_vulnerabilities_internal(state.clone(), id.clone(), &config, vuln_ids)
Self::collect_vulnerabilities_internal(&state, id.clone(), &config, vuln_ids)
.await
.ok();
}
Expand Down
12 changes: 6 additions & 6 deletions collectorist/api/src/coordinator/collectors.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use crate::coordinator::collector::Collector;
use collector_client::{CollectPackagesResponse, CollectVulnerabilitiesResponse};
use collectorist_client::{CollectPackagesRequest, CollectorConfig, Interest};
use futures::future::join_all;

use crate::SharedState;
use crate::state::AppState;

#[derive(Default)]
pub struct Collectors {
collectors: HashMap<String, Collector>,
}

impl Collectors {
pub async fn register(&mut self, state: SharedState, id: String, config: CollectorConfig) -> Result<(), ()> {
self.collectors
.insert(id.clone(), Collector::new(state.clone(), id, config));
pub async fn register(&mut self, state: Arc<AppState>, id: String, config: CollectorConfig) -> Result<(), ()> {
self.collectors.insert(id.clone(), Collector::new(state, id, config));
Ok(())
}

Expand All @@ -34,7 +34,7 @@ impl Collectors {

pub async fn collect_packages(
&self,
state: SharedState,
state: &AppState,
request: CollectPackagesRequest,
) -> Vec<CollectPackagesResponse> {
let mut futures = Vec::new();
Expand All @@ -50,7 +50,7 @@ impl Collectors {

pub async fn collect_vulnerabilities(
&self,
state: SharedState,
state: &AppState,
vuln_ids: HashSet<String>,
) -> Vec<CollectVulnerabilitiesResponse> {
let mut futures = Vec::new();
Expand Down
10 changes: 5 additions & 5 deletions collectorist/api/src/coordinator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use log::{info, warn};
use reqwest::Url;
use tokio::time::{interval, sleep};

use crate::SharedState;
use crate::state::AppState;

pub struct Coordinator {
csub_url: Url,
Expand All @@ -33,7 +33,7 @@ impl Coordinator {
Self { csub_url }
}

pub async fn listen(&self, state: SharedState) {
pub async fn listen(&self, state: &AppState) {
let listener = async move {
loop {
if let Ok(mut csub) = CollectSubClient::new(self.csub_url.to_string()).await {
Expand Down Expand Up @@ -76,7 +76,7 @@ impl Coordinator {

pub async fn collect_packages(
&self,
state: SharedState,
state: &AppState,
request: CollectPackagesRequest,
) -> Vec<CollectPackagesResponse> {
let collectors = state.collectors.read().await;
Expand All @@ -92,14 +92,14 @@ impl Coordinator {
result
}

pub async fn collect_vulnerabilities(&self, state: SharedState, request: CollectVulnerabilitiesRequest) {
pub async fn collect_vulnerabilities(&self, state: &AppState, request: CollectVulnerabilitiesRequest) {
let collectors = state.collectors.read().await;
collectors
.collect_vulnerabilities(state.clone(), request.vuln_ids.iter().cloned().collect::<HashSet<_>>())
.await;
}

pub async fn add_purl(&self, state: SharedState, purl: &str) -> Result<(), anyhow::Error> {
pub async fn add_purl(&self, state: &AppState, purl: &str) -> Result<(), anyhow::Error> {
state.db.insert_purl(purl).await
}
}
9 changes: 5 additions & 4 deletions collectorist/api/src/db/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::path::Path;
use std::str::FromStr;

use chrono::{DateTime, Utc};
Expand All @@ -12,13 +13,13 @@ pub struct Db {
}

impl Db {
pub async fn new() -> Result<Self, anyhow::Error> {
pub async fn new(base: impl AsRef<Path>) -> Result<Self, anyhow::Error> {
let db = Self {
pool: SqlitePool::connect_with(if cfg!(test) {
SqliteConnectOptions::from_str(":memory:")?
} else {
SqliteConnectOptions::default()
.filename(DB_FILE_NAME)
.filename(base.as_ref().join(DB_FILE_NAME))
.create_if_missing(true)
})
.await?,
Expand Down Expand Up @@ -286,7 +287,7 @@ mod test {

#[actix_web::test]
async fn insert_purl() -> Result<(), anyhow::Error> {
let db = Db::new().await?;
let db = Db::new(".").await?;

db.insert_purl("bob").await?;
db.insert_purl("bob").await?;
Expand All @@ -309,7 +310,7 @@ mod test {

#[actix_web::test]
async fn update_purl_scan_time() -> Result<(), anyhow::Error> {
let db = Db::new().await?;
let db = Db::new(".").await?;

db.insert_purl("not-scanned").await?;
db.insert_purl("is-scanned").await?;
Expand Down
16 changes: 10 additions & 6 deletions collectorist/api/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::state::AppState;
use reqwest::Url;
use std::path::PathBuf;
use std::process::ExitCode;
use std::sync::Arc;
use trustification_auth::client::OpenIdTokenProviderConfigArguments;
Expand Down Expand Up @@ -48,6 +49,10 @@ pub struct Run {
)]
pub(crate) guac_url: Url,

/// Base path to the database store. Defaults to the local directory.
#[arg(env, short = 'b', long = "storage-base")]
pub(crate) storage_base: Option<PathBuf>,

#[command(flatten)]
pub auth: AuthConfigArguments,

Expand Down Expand Up @@ -80,7 +85,7 @@ impl Run {

Infrastructure::from(self.infra)
.run("collectorist-api", |metrics| async move {
let state = Self::configure(self.csub_url, self.guac_url).await?;
let state = Self::configure(self.storage_base, self.csub_url, self.guac_url).await?;
let server = server::run(
state.clone(),
self.api.socket_addr()?,
Expand All @@ -89,7 +94,7 @@ impl Run {
authorizer,
swagger_oidc,
);
let listener = state.coordinator.listen(state.clone());
let listener = state.coordinator.listen(&state);
tokio::select! {
_ = listener => { }
_ = server => { }
Expand All @@ -101,10 +106,9 @@ impl Run {
Ok(ExitCode::SUCCESS)
}

async fn configure(csub_url: Url, guac_url: Url) -> anyhow::Result<Arc<AppState>> {
let state = Arc::new(AppState::new(csub_url, guac_url).await?);
async fn configure(base: Option<PathBuf>, csub_url: Url, guac_url: Url) -> anyhow::Result<Arc<AppState>> {
let base = base.unwrap_or_else(|| ".".into());
let state = Arc::new(AppState::new(base, csub_url, guac_url).await?);
Ok(state)
}
}

pub(crate) type SharedState = Arc<AppState>;
13 changes: 5 additions & 8 deletions collectorist/api/src/server/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use collector_client::CollectPackagesResponse;
use collectorist_client::CollectPackagesRequest;
use collectorist_client::CollectVulnerabilitiesRequest;

use crate::SharedState;
use crate::state::AppState;

/// Post a list of purls to be "gathered"
#[utoipa::path(
Expand All @@ -20,11 +20,11 @@ use crate::SharedState;
)]
#[post("/packages")]
pub(crate) async fn collect_packages(
state: web::Data<SharedState>,
state: web::Data<AppState>,
input: web::Json<CollectPackagesRequest>,
) -> actix_web::Result<impl Responder> {
let purls = input.into_inner();
let results = state.coordinator.collect_packages(state.get_ref().clone(), purls).await;
let results = state.coordinator.collect_packages(&state, purls).await;
let mut purls = HashMap::<String, Vec<String>>::new();
for gr in results {
for k in gr.purls.keys() {
Expand All @@ -47,13 +47,10 @@ pub(crate) async fn collect_packages(
)]
#[post("/vulnerabilities")]
pub(crate) async fn collect_vulnerabilities(
state: web::Data<SharedState>,
state: web::Data<AppState>,
input: web::Json<CollectVulnerabilitiesRequest>,
) -> actix_web::Result<impl Responder> {
let request = input.into_inner();
state
.coordinator
.collect_vulnerabilities(state.get_ref().clone(), request)
.await;
state.coordinator.collect_vulnerabilities(&state, request).await;
Ok(HttpResponse::Ok().finish())
}

0 comments on commit 832002b

Please sign in to comment.