Skip to content

Commit

Permalink
feat: added ability to specify upstream url in the request headers
Browse files Browse the repository at this point in the history
  • Loading branch information
talzion12 committed Jan 5, 2024
1 parent bf559a1 commit 6759853
Show file tree
Hide file tree
Showing 11 changed files with 245 additions and 85 deletions.
43 changes: 26 additions & 17 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ edition = "2021"
async-trait = "0.1.77"
clap = { version = "4.4.13", features = ["derive", "env"] }
color-eyre = "0.6"
thiserror = "1.0.56"
serde-error = "0.1.2"
dotenv = "0.15.0"
eyre = "0.6.11"
url = "2.5.0"
Expand All @@ -28,7 +30,7 @@ hex = "0.4.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
phf = { version = "0.11", features = ["macros"] }
opendal = { version = "0.44", features = [
opendal = { version = "0.41", features = [
"services-gcs",
"services-fs",
"rustls",
Expand Down
4 changes: 3 additions & 1 deletion charts/http-cache/templates/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ spec:
image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
imagePullPolicy: {{ .Values.image.pullPolicy }}
env:
{{ if .Values.upstream.url }}
- name: UPSTREAM_URL
value: {{ .Values.upstream.url | required "upstream.url is required" | quote }}
value: {{ .Values.upstream.url | quote }}
{{ end }}
- name: CACHE_URL
value: {{ .Values.cache.url | required "cache.url is required" | quote }}
- name: RUST_LOG
Expand Down
13 changes: 8 additions & 5 deletions src/cache/create_cache_storage_from_url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@ pub async fn create_cache_storage_from_url(url: &Url) -> color_eyre::Result<Box<
"file" => {
let mut builder = Fs::default();
builder.root(url.path());
tracing::info!("Using filesystem cache at root {}", url.path());
Box::new(OpendalStorage::new(Operator::new(builder)?.finish()))
}
"gs" => {
let mut builder = Gcs::default();
builder.bucket(
url.host_str()
.ok_or_else(|| color_eyre::eyre::eyre!("Must set url host as bucket"))?,
);
builder.root(url.path());
let bucket = url
.host_str()
.ok_or_else(|| color_eyre::eyre::eyre!("Must set url host as bucket"))?;
let root = url.path();
builder.bucket(bucket);
builder.root(root);
tracing::info!("Using google cloud cache at bucket {bucket} with root {root}");
Box::new(OpendalStorage::new(Operator::new(builder)?.finish()))
}
other => color_eyre::eyre::bail!("Scheme not supported {other}"),
Expand Down
63 changes: 30 additions & 33 deletions src/cache/layer.rs
Original file line number Diff line number Diff line change
@@ -1,75 +1,71 @@
use std::{
marker::PhantomData,
sync::Arc,
task::{Context, Poll},
};

use futures::{channel::mpsc::channel, future::BoxFuture, FutureExt, SinkExt, StreamExt};
use http::{Request, Response, StatusCode};
use hyper::{
body::{Bytes, HttpBody},
Body,
};
use http::{Request, Response, StatusCode, Uri};
use hyper::{body::Bytes, Body};
use phf::phf_set;
use tracing::Instrument;
use url::Url;

use crate::cache::metadata::CacheMetadata;
use crate::{cache::metadata::CacheMetadata, upstream_uri::layer::UpstreamUriExt};

use super::{create_cache_storage_from_url, storage::Cache, GetBody};

pub struct CachingLayer<C: ?Sized, B> {
pub struct CachingLayer<C: ?Sized> {
cache: Arc<C>,
phantom: PhantomData<B>,
}

impl<C: ?Sized, B> CachingLayer<C, B> {
impl<C: ?Sized> CachingLayer<C> {
pub fn new(cache: impl Into<Arc<C>>) -> Self {
Self {
cache: cache.into(),
phantom: PhantomData,
}
}
}

impl<B> CachingLayer<dyn Cache, B> {
impl CachingLayer<dyn Cache> {
pub async fn from_url(url: &Url) -> color_eyre::Result<Self> {
let storage = create_cache_storage_from_url(url).await?;
Ok(Self::new(storage))
}
}

impl<S: Clone, C: Cache + ?Sized, B> tower::Layer<S> for CachingLayer<C, B> {
type Service = CachingService<S, C, B>;
impl<S: Clone, C: Cache + ?Sized> tower::Layer<S> for CachingLayer<C> {
type Service = CachingService<S, C>;

fn layer(&self, inner: S) -> Self::Service {
CachingService {
inner,
cache: self.cache.clone(),
phantom: PhantomData,
}
}
}

pub struct CachingService<S, C: ?Sized, B> {
pub struct CachingService<S, C: ?Sized> {
inner: S,
cache: Arc<C>,
phantom: PhantomData<B>,
}

impl<S, C, B> CachingService<S, C, B>
impl<S, C> CachingService<S, C>
where
S: tower::Service<Request<B>, Response = Response<Body>, Error = hyper::Error> + Send + Sync,
S: tower::Service<Request<Body>, Response = Response<Body>, Error = hyper::Error> + Send + Sync,
C: Cache + 'static + ?Sized,
{
async fn on_request(&mut self, request: Request<B>) -> Result<Response<Body>, hyper::Error> {
let uri = request.uri();
tracing::debug!("Received request for {uri}");
let cache_result = self.cache.get(uri).await;
async fn on_request(&mut self, request: Request<Body>) -> Result<Response<Body>, hyper::Error> {
let UpstreamUriExt(upstream_uri) = request
.extensions()
.get()
.expect("Upstream uri extension is missing");

tracing::debug!("Received request for {upstream_uri}");
let cache_result = self.cache.get(upstream_uri).await;

match cache_result {
Ok(Some((metadata, body))) => self.on_cache_hit(metadata, body).await,
Ok(None) => self.on_cache_miss(request).await,
Ok(None) => self.on_cache_miss(upstream_uri.clone(), request).await,
Err(error) => {
tracing::error!(%error, "Failed to read from cache");
Ok(Response::builder()
Expand Down Expand Up @@ -107,10 +103,13 @@ where
Ok(body)
}

async fn on_cache_miss(&mut self, request: Request<B>) -> Result<Response<Body>, hyper::Error> {
async fn on_cache_miss(
&mut self,
upstream_uri: Uri,
request: Request<Body>,
) -> Result<Response<Body>, hyper::Error> {
tracing::info!("Cache miss");

let uri = request.uri().clone();
let response = self.inner.call(request).await?;

if !response.status().is_success() {
Expand Down Expand Up @@ -138,7 +137,7 @@ where

tokio::spawn(
async move {
match cache_cloned.set(&uri, receiver, metadata).await {
match cache_cloned.set(&upstream_uri, receiver, metadata).await {
Ok(()) => tracing::info!("Wrote to cache"),
Err(err) => tracing::error!("Failed to write to cache {err:?}"),
}
Expand All @@ -155,23 +154,21 @@ where
}
}

impl<S: Clone, C: ?Sized, B> Clone for CachingService<S, C, B> {
impl<S: Clone, C: ?Sized> Clone for CachingService<S, C> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
cache: self.cache.clone(),
phantom: PhantomData,
}
}
}

impl<S, C, B> tower::Service<Request<B>> for CachingService<S, C, B>
impl<S, C> tower::Service<Request<Body>> for CachingService<S, C>
where
S: tower::Service<Request<B>, Response = Response<Body>, Error = hyper::Error> + Send + Sync,
S: tower::Service<Request<Body>, Response = Response<Body>, Error = hyper::Error> + Send + Sync,
S: Clone + 'static,
S::Future: Send,
C: Cache + Send + Sync + 'static + ?Sized,
B: HttpBody + Send + Sync + 'static,
{
type Response = S::Response;
type Error = S::Error;
Expand All @@ -181,7 +178,7 @@ where
self.inner.poll_ready(cx)
}

fn call(&mut self, request: Request<B>) -> Self::Future {
fn call(&mut self, request: Request<Body>) -> Self::Future {
let mut c = self.clone();
async move { c.on_request(request).await }.boxed()
}
Expand Down
21 changes: 11 additions & 10 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@ use cache::CachingLayer;
use clap::Parser;
use eyre::Context;
use http::Request;
use hyper::{client::HttpConnector, Body};
use hyper_rustls::HttpsConnector;
use hyper::Body;
use proxy::ProxyService;
use tower::{make::Shared, ServiceBuilder};
use tower::{make::Shared, util::option_layer, ServiceBuilder};
use tracing_error::ErrorLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};

mod cache;
mod options;
mod proxy;
mod upstream_uri;

use options::{LogFormat, Options};
use upstream_uri::layer::ExtractUpstreamUriLayer;

#[tokio::main]
async fn main() -> eyre::Result<()> {
Expand All @@ -37,12 +38,11 @@ async fn main() -> eyre::Result<()> {

let cache_layer = CachingLayer::from_url(&options.cache_url).await?;

let cache_layer_2 =
tower::util::option_layer(if let Some(cache_url_2) = &options.cache_url_2 {
Some(CachingLayer::from_url(cache_url_2).await?)
} else {
None
});
let cache_layer_2 = option_layer(if let Some(cache_url_2) = &options.cache_url_2 {
Some(CachingLayer::from_url(cache_url_2).await?)
} else {
None
});

let client = hyper::Client::builder().build(
hyper_rustls::HttpsConnectorBuilder::new()
Expand All @@ -52,7 +52,7 @@ async fn main() -> eyre::Result<()> {
.enable_http2()
.build(),
);
let proxy = ProxyService::<HttpsConnector<HttpConnector>, Body>::new(options.upstream, client);
let proxy = ProxyService::new(client);

let service = ServiceBuilder::new()
.layer(
Expand All @@ -67,6 +67,7 @@ async fn main() -> eyre::Result<()> {
},
),
)
.layer(ExtractUpstreamUriLayer::new(options.upstream))
.layer(cache_layer_2)
.layer(cache_layer)
.service(proxy);
Expand Down
Loading

0 comments on commit 6759853

Please sign in to comment.