Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 219 additions & 12 deletions src/proxy_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,6 @@ where
None => return Ok(false),
};

// Read body if content-length is set.
let body = read_body(stream, &leftover, &headers).await?;

let default_port = if scheme == "https" { 443 } else { 80 };
Expand Down Expand Up @@ -1006,36 +1005,178 @@ where
}
}

fn header_value<'a>(headers: &'a [(String, String)], name: &str) -> Option<&'a str> {
headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case(name))
.map(|(_, v)| v.as_str())
}

fn expects_100_continue(headers: &[(String, String)]) -> bool {
header_value(headers, "expect")
.map(|v| {
v.split(',')
.any(|part| part.trim().eq_ignore_ascii_case("100-continue"))
})
.unwrap_or(false)
}

fn invalid_body(msg: impl Into<String>) -> std::io::Error {
std::io::Error::new(std::io::ErrorKind::InvalidData, msg.into())
}

async fn read_body<S>(
stream: &mut S,
leftover: &[u8],
headers: &[(String, String)],
) -> std::io::Result<Vec<u8>>
where
S: tokio::io::AsyncRead + Unpin,
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
let cl: Option<usize> = headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("content-length"))
.and_then(|(_, v)| v.parse().ok());
let transfer_encoding = header_value(headers, "transfer-encoding");
let is_chunked = transfer_encoding
.map(|v| {
v.split(',')
.any(|part| part.trim().eq_ignore_ascii_case("chunked"))
})
.unwrap_or(false);

let content_length = match header_value(headers, "content-length") {
Some(v) => Some(
v.parse::<usize>()
.map_err(|_| invalid_body(format!("invalid Content-Length: {}", v)))?,
),
None => None,
};

if transfer_encoding.is_some() && !is_chunked {
return Err(invalid_body(format!(
"unsupported Transfer-Encoding: {}",
transfer_encoding.unwrap_or_default()
)));
}

if is_chunked && content_length.is_some() {
return Err(invalid_body(
"both Transfer-Encoding: chunked and Content-Length are present",
));
}

let Some(cl) = cl else {
if expects_100_continue(headers) && (is_chunked || content_length.is_some()) {
stream.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
stream.flush().await?;
}

if is_chunked {
return read_chunked_request_body(stream, leftover.to_vec()).await;
}

let Some(content_length) = content_length else {
return Ok(Vec::new());
};
let mut body = Vec::with_capacity(cl);
body.extend_from_slice(&leftover[..leftover.len().min(cl)]);

let mut body = Vec::with_capacity(content_length);
body.extend_from_slice(&leftover[..leftover.len().min(content_length)]);
let mut tmp = [0u8; 8192];
while body.len() < cl {
while body.len() < content_length {
let n = stream.read(&mut tmp).await?;
if n == 0 {
break;
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"EOF mid-body",
));
}
let need = cl - body.len();
let need = content_length - body.len();
body.extend_from_slice(&tmp[..n.min(need)]);
}
Ok(body)
}

async fn read_chunked_request_body<S>(stream: &mut S, mut buf: Vec<u8>) -> std::io::Result<Vec<u8>>
where
S: tokio::io::AsyncRead + Unpin,
{
let mut out = Vec::new();
let mut tmp = [0u8; 8192];

loop {
let line = read_crlf_line(stream, &mut buf, &mut tmp).await?;
if line.is_empty() {
continue;
}

let line_str = std::str::from_utf8(&line)
.map_err(|_| invalid_body("non-utf8 chunk size line"))?
.trim();
let size_hex = line_str.split(';').next().unwrap_or("");
let size = usize::from_str_radix(size_hex, 16)
.map_err(|_| invalid_body(format!("bad chunk size '{}'", line_str)))?;

if size == 0 {
loop {
let trailer = read_crlf_line(stream, &mut buf, &mut tmp).await?;
if trailer.is_empty() {
return Ok(out);
}
}
}

fill_buffer(stream, &mut buf, &mut tmp, size + 2).await?;
if &buf[size..size + 2] != b"\r\n" {
return Err(invalid_body("chunk missing trailing CRLF"));
}
out.extend_from_slice(&buf[..size]);
buf.drain(..size + 2);
}
}

async fn read_crlf_line<S>(
stream: &mut S,
buf: &mut Vec<u8>,
tmp: &mut [u8],
) -> std::io::Result<Vec<u8>>
where
S: tokio::io::AsyncRead + Unpin,
{
loop {
if let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
let line = buf[..idx].to_vec();
buf.drain(..idx + 2);
return Ok(line);
}
let n = stream.read(tmp).await?;
if n == 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"EOF in chunked body",
));
}
buf.extend_from_slice(&tmp[..n]);
}
}

async fn fill_buffer<S>(
stream: &mut S,
buf: &mut Vec<u8>,
tmp: &mut [u8],
want: usize,
) -> std::io::Result<()>
where
S: tokio::io::AsyncRead + Unpin,
{
while buf.len() < want {
let n = stream.read(tmp).await?;
if n == 0 {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"EOF in chunked body",
));
}
buf.extend_from_slice(&tmp[..n]);
}
Ok(())
}

// ---------- Plain HTTP proxy ----------

async fn do_plain_http(
Expand Down Expand Up @@ -1068,3 +1209,69 @@ async fn do_plain_http(
sock.flush().await?;
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};

fn headers(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
pairs
.iter()
.map(|(k, v)| ((*k).to_string(), (*v).to_string()))
.collect()
}

#[tokio::test(flavor = "current_thread")]
async fn read_body_decodes_chunked_request() {
let (mut client, mut server) = duplex(1024);
let writer = tokio::spawn(async move {
client
.write_all(b"llo\r\n6\r\n world\r\n0\r\nFoo: bar\r\n\r\n")
.await
.unwrap();
});

let body = read_body(
&mut server,
b"5\r\nhe",
&headers(&[("Transfer-Encoding", "chunked")]),
)
.await
.unwrap();

writer.await.unwrap();
assert_eq!(body, b"hello world");
}

#[tokio::test(flavor = "current_thread")]
async fn read_body_sends_100_continue_before_waiting_for_body() {
let (mut client, mut server) = duplex(1024);
let client_task = tokio::spawn(async move {
let mut got = Vec::new();
let mut tmp = [0u8; 64];
loop {
let n = client.read(&mut tmp).await.unwrap();
assert!(n > 0, "proxy closed before sending 100 Continue");
got.extend_from_slice(&tmp[..n]);
if got.windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
}
assert_eq!(got, b"HTTP/1.1 100 Continue\r\n\r\n");
client.write_all(b"hello").await.unwrap();
});

let body = read_body(
&mut server,
&[],
&headers(&[("Content-Length", "5"), ("Expect", "100-continue")]),
)
.await
.unwrap();

client_task.await.unwrap();
assert_eq!(body, b"hello");
}

}