Skip to content

Commit

Permalink
fix(security): handling of unsafe characters in outbound header names…
Browse files Browse the repository at this point in the history
… and values
  • Loading branch information
jbr committed Jan 24, 2024
1 parent 71ddfc5 commit 8d468f8
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 47 deletions.
17 changes: 14 additions & 3 deletions client/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,9 +599,20 @@ impl Conn {

write!(buf, " HTTP/1.1\r\n")?;

for (header, values) in self.request_headers.iter() {
for value in values.iter() {
write!(buf, "{header}: {value}\r\n")?;
for (name, values) in &self.request_headers {
if !name.is_valid() {
return Err(Error::MalformedHeader(name.to_string().into()));
}

for value in values {
if !value.is_valid() {
return Err(Error::MalformedHeader(
format!("value for {name}: {value:?}").into(),
));
}
write!(buf, "{name}: ")?;
buf.extend_from_slice(value.as_ref());
write!(buf, "\r\n")?;
}
}

Expand Down
24 changes: 24 additions & 0 deletions client/tests/unsafe_headers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use test_harness::test;
use trillium_client::{Client, KnownHeaderName};
use trillium_testing::{connector, harness};

#[test(harness)]
async fn bad_characters_in_header_value() {
assert!(Client::new(connector(()))
.get("http://example.com")
.with_header(
KnownHeaderName::Referer,
"x\r\nConnection: keep-alive\r\n\r\nGET / HTTP/1.1\r\nHost: example.com\r\n\r\n"
)
.await
.is_err());
}

#[test(harness)]
async fn bad_characters_in_header_name() {
assert!(Client::new(connector(()))
.get("http://example.com")
.with_header("dnt: 1\r\nConnection", "keep-alive")
.await
.is_err());
}
20 changes: 14 additions & 6 deletions http/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ where
}
}

fn write_headers(&mut self, output_buffer: &mut Vec<u8>) -> std::io::Result<()> {
fn write_headers(&mut self, output_buffer: &mut Vec<u8>) -> Result<()> {
use std::io::Write;
let status = self.status().unwrap_or(Status::NotFound);

Expand All @@ -801,11 +801,19 @@ where
&self.response_headers
);

for (header, values) in &self.response_headers {
for value in values {
write!(output_buffer, "{header}: ")?;
output_buffer.extend_from_slice(value.as_ref());
write!(output_buffer, "\r\n")?;
for (name, values) in &self.response_headers {
if name.is_valid() {
for value in values {
if value.is_valid() {
write!(output_buffer, "{name}: ")?;
output_buffer.extend_from_slice(value.as_ref());
write!(output_buffer, "\r\n")?;
} else {
log::error!("skipping invalid header value {value:?} for header {name}");
}
}
} else {
log::error!("skipping invalid header with name {name:?}");
}
}

Expand Down
33 changes: 19 additions & 14 deletions http/src/headers/header_name.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use smartcow::SmartCow;
use smartstring::alias::String as SmartString;
use std::{
fmt::{self, Debug, Display, Formatter},
hash::Hash,
str::FromStr,
};

use super::{KnownHeaderName, UnknownHeaderName};
use crate::Error;
use HeaderNameInner::{KnownHeader, UnknownHeader};

/// The name of a http header. This can be either a
/// [`KnownHeaderName`] or a string representation of an unknown
Expand All @@ -30,8 +30,6 @@ pub(super) enum HeaderNameInner<'a> {
KnownHeader(KnownHeaderName),
UnknownHeader(UnknownHeaderName<'a>),
}
use crate::Error;
use HeaderNameInner::{KnownHeader, UnknownHeader};

impl<'a> HeaderName<'a> {
/// Convert a potentially-borrowed headername to a static
Expand All @@ -40,9 +38,7 @@ impl<'a> HeaderName<'a> {
pub fn into_owned(self) -> HeaderName<'static> {
HeaderName(match self.0 {
KnownHeader(known) => KnownHeader(known),
UnknownHeader(UnknownHeaderName(smartcow)) => {
UnknownHeader(UnknownHeaderName(smartcow.into_owned()))
}
UnknownHeader(uhn) => UnknownHeader(uhn.into_owned()),
})
}

Expand All @@ -55,6 +51,14 @@ impl<'a> HeaderName<'a> {
pub fn to_owned(&self) -> HeaderName<'static> {
self.clone().into_owned()
}

/// Determine if this header name contains only the appropriate characters
pub fn is_valid(&self) -> bool {
match &self.0 {
KnownHeader(_) => true,
UnknownHeader(uh) => uh.is_valid(),
}
}
}

impl PartialEq<KnownHeaderName> for HeaderName<'_> {
Expand All @@ -79,7 +83,7 @@ impl From<String> for HeaderName<'static> {
fn from(s: String) -> Self {
Self(match s.parse::<KnownHeaderName>() {
Ok(khn) => KnownHeader(khn),
Err(()) => UnknownHeader(UnknownHeaderName(SmartCow::Owned(s.into()))),
Err(()) => UnknownHeader(UnknownHeaderName::from(s)),
})
}
}
Expand All @@ -88,7 +92,7 @@ impl<'a> From<&'a str> for HeaderName<'a> {
fn from(s: &'a str) -> Self {
Self(match s.parse::<KnownHeaderName>() {
Ok(khn) => KnownHeader(khn),
Err(_e) => UnknownHeader(UnknownHeaderName(SmartCow::Borrowed(s))),
Err(_e) => UnknownHeader(UnknownHeaderName::from(s)),
})
}
}
Expand All @@ -97,11 +101,12 @@ impl FromStr for HeaderName<'static> {
type Err = Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.is_ascii() {
Ok(Self(match s.parse::<KnownHeaderName>() {
Ok(known) => KnownHeader(known),
Err(()) => UnknownHeader(UnknownHeaderName(SmartCow::Owned(SmartString::from(s)))),
}))
if let Ok(known) = s.parse::<KnownHeaderName>() {
return Ok(known.into());
}
let uhn = UnknownHeaderName::from(s.to_string());
if uhn.is_valid() {
Ok(uhn.into())
} else {
Err(Error::MalformedHeader(s.to_string().into()))
}
Expand Down
43 changes: 25 additions & 18 deletions http/src/headers/header_value.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
use smallvec::SmallVec;
use smartcow::SmartCow;

use std::{
borrow::Cow,
fmt::{Debug, Display, Formatter},
};
use HeaderValueInner::{Bytes, Utf8};

/// A `HeaderValue` represents the right hand side of a single `name:
/// value` pair.
#[derive(Eq, PartialEq, Clone)]
pub struct HeaderValue(HeaderValueInner);

impl HeaderValue {
/// determine if this header contains no unsafe characters (\r, \n, \0)
pub fn is_valid(&self) -> bool {
memchr::memchr3(b'\r', b'\n', 0, self.as_ref()).is_none()
}
}

#[derive(Eq, PartialEq, Clone)]
pub(crate) enum HeaderValueInner {
Utf8(SmartCow<'static>),
Expand All @@ -24,17 +31,17 @@ impl serde::Serialize for HeaderValue {
S: serde::Serializer,
{
match &self.0 {
HeaderValueInner::Utf8(s) => serializer.serialize_str(s),
HeaderValueInner::Bytes(bytes) => serializer.serialize_bytes(bytes),
Utf8(s) => serializer.serialize_str(s),
Bytes(bytes) => serializer.serialize_bytes(bytes),
}
}
}

impl Debug for HeaderValue {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match &self.0 {
HeaderValueInner::Utf8(s) => Debug::fmt(s, f),
HeaderValueInner::Bytes(b) => Debug::fmt(&String::from_utf8_lossy(b), f),
Utf8(s) => Debug::fmt(s, f),
Bytes(b) => Debug::fmt(&String::from_utf8_lossy(b), f),
}
}
}
Expand All @@ -47,8 +54,8 @@ impl HeaderValue {
/// whether it's utf8, use the `AsRef<[u8]>` impl
pub fn as_str(&self) -> Option<&str> {
match &self.0 {
HeaderValueInner::Utf8(utf8) => Some(utf8),
HeaderValueInner::Bytes(_) => None,
Utf8(utf8) => Some(utf8),
Bytes(_) => None,
}
}

Expand All @@ -66,53 +73,53 @@ impl HeaderValue {
impl Display for HeaderValue {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match &self.0 {
HeaderValueInner::Utf8(s) => f.write_str(s),
HeaderValueInner::Bytes(b) => f.write_str(&String::from_utf8_lossy(b)),
Utf8(s) => f.write_str(s),
Bytes(b) => f.write_str(&String::from_utf8_lossy(b)),
}
}
}

impl From<Vec<u8>> for HeaderValue {
fn from(v: Vec<u8>) -> Self {
match String::from_utf8(v) {
Ok(s) => Self(HeaderValueInner::Utf8(SmartCow::Owned(s.into()))),
Err(e) => Self(HeaderValueInner::Bytes(e.into_bytes().into())),
Ok(s) => Self(Utf8(SmartCow::Owned(s.into()))),
Err(e) => Self(Bytes(e.into_bytes().into())),
}
}
}

impl From<Cow<'static, str>> for HeaderValue {
fn from(c: Cow<'static, str>) -> Self {
Self(HeaderValueInner::Utf8(SmartCow::from(c)))
Self(Utf8(SmartCow::from(c)))
}
}

impl From<&'static [u8]> for HeaderValue {
fn from(b: &'static [u8]) -> Self {
match std::str::from_utf8(b) {
Ok(s) => Self(HeaderValueInner::Utf8(SmartCow::Borrowed(s))),
Err(_) => Self(HeaderValueInner::Bytes(b.into())),
Ok(s) => Self(Utf8(SmartCow::Borrowed(s))),
Err(_) => Self(Bytes(b.into())),
}
}
}

impl From<String> for HeaderValue {
fn from(s: String) -> Self {
Self(HeaderValueInner::Utf8(SmartCow::Owned(s.into())))
Self(Utf8(SmartCow::Owned(s.into())))
}
}

impl From<&'static str> for HeaderValue {
fn from(s: &'static str) -> Self {
Self(HeaderValueInner::Utf8(SmartCow::Borrowed(s)))
Self(Utf8(SmartCow::Borrowed(s)))
}
}

impl AsRef<[u8]> for HeaderValue {
fn as_ref(&self) -> &[u8] {
match &self.0 {
HeaderValueInner::Utf8(utf8) => utf8.as_bytes(),
HeaderValueInner::Bytes(b) => b,
Utf8(utf8) => utf8.as_bytes(),
Bytes(b) => b,
}
}
}
Expand Down
34 changes: 28 additions & 6 deletions http/src/headers/unknown_header_name.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use super::{HeaderName, HeaderNameInner::UnknownHeader};
use hashbrown::Equivalent;
use smartcow::SmartCow;
use std::{
fmt::{self, Debug, Display, Formatter},
hash::{Hash, Hasher},
ops::Deref,
};

use hashbrown::Equivalent;
use smartcow::SmartCow;

use super::{HeaderName, HeaderNameInner::UnknownHeader};

#[derive(Clone)]
pub(super) struct UnknownHeaderName<'a>(pub(super) SmartCow<'a>);
pub(super) struct UnknownHeaderName<'a>(SmartCow<'a>);

impl PartialEq for UnknownHeaderName<'_> {
fn eq(&self, other: &Self) -> bool {
Expand Down Expand Up @@ -46,6 +44,30 @@ impl<'a> From<UnknownHeaderName<'a>> for HeaderName<'a> {
}
}

impl UnknownHeaderName<'_> {
pub(crate) fn is_valid(&self) -> bool {
self.0
.chars()
.all(|c| matches!(c, 'a'..='z'|'A'..='Z'|'0'..='9'|'-'|'_'))
}

pub(crate) fn into_owned(self) -> UnknownHeaderName<'static> {
UnknownHeaderName(self.0.into_owned())
}
}

impl From<String> for UnknownHeaderName<'static> {
fn from(value: String) -> Self {
Self(value.into())
}
}

impl<'a> From<&'a str> for UnknownHeaderName<'a> {
fn from(value: &'a str) -> Self {
Self(value.into())
}
}

impl<'a> From<SmartCow<'a>> for UnknownHeaderName<'a> {
fn from(value: SmartCow<'a>) -> Self {
Self(value)
Expand Down
49 changes: 49 additions & 0 deletions http/tests/unsafe_headers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use indoc::{formatdoc, indoc};
use pretty_assertions::assert_eq;
use stopper::Stopper;
use test_harness::test;
use trillium_http::{Conn, KnownHeaderName, SERVER};
use trillium_testing::{harness, TestResult, TestTransport};

const TEST_DATE: &str = "Tue, 21 Nov 2023 21:27:21 GMT";

async fn handler(mut conn: Conn<TestTransport>) -> Conn<TestTransport> {
conn.set_status(200);
conn.set_response_body("response: 0123456789");
conn.response_headers_mut()
.insert(KnownHeaderName::Date, TEST_DATE);
conn.response_headers_mut().insert(
KnownHeaderName::Connection,
"close\r\nGET / HTTP/1.1\r\nHost: example.com\r\n\r\n",
);
conn.response_headers_mut().insert("Bad\r\nHeader", "true");
conn
}

#[test(harness)]
async fn bad_headers() -> TestResult {
let (client, server) = TestTransport::new();

trillium_testing::spawn(async move {
Conn::map(server, Stopper::new(), handler).await.unwrap();
});

client.write_all(indoc! {"
GET / HTTP/1.1\r
Host: example.com\r
\r
"});

let expected_response = formatdoc! {"
HTTP/1.1 200 OK\r
Server: {SERVER}\r
Date: {TEST_DATE}\r
Content-Length: 20\r
\r
response: 0123456789\
"};

assert_eq!(client.read_available_string().await, expected_response);

Ok(())
}

0 comments on commit 8d468f8

Please sign in to comment.