Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented poll_* family functions on UDP ProxySocket #987

Merged
merged 5 commits into from
Oct 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
43 changes: 42 additions & 1 deletion crates/shadowsocks/src/relay/socks5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::{
vec,
};

use bytes::{BufMut, BytesMut};
use bytes::{Buf, BufMut, BytesMut};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

pub use self::consts::{
Expand Down Expand Up @@ -213,6 +213,47 @@ pub enum Address {
}

impl Address {
/// read from a cursor
pub fn read_cursor<T: AsRef<[u8]>>(cur: &mut io::Cursor<T>) -> Result<Address, Error> {
zonyitoo marked this conversation as resolved.
Show resolved Hide resolved
if cur.remaining() < 2 {
return Err(io::Error::new(io::ErrorKind::Other, "invalid buf").into());
}

let atyp = cur.get_u8();
match atyp {
consts::SOCKS5_ADDR_TYPE_IPV4 => {
if cur.remaining() < 4 + 2 {
return Err(io::Error::new(io::ErrorKind::Other, "invalid buf").into());
}
let addr = Ipv4Addr::from(cur.get_u32());
let port = cur.get_u16();
Ok(Address::SocketAddress(SocketAddr::V4(SocketAddrV4::new(addr, port))))
}
consts::SOCKS5_ADDR_TYPE_IPV6 => {
if cur.remaining() < 16 + 2 {
return Err(io::Error::new(io::ErrorKind::Other, "invalid buf").into());
}
let addr = Ipv6Addr::from(cur.get_u128());
let port = cur.get_u16();
Ok(Address::SocketAddress(SocketAddr::V6(SocketAddrV6::new(
addr, port, 0, 0,
))))
}
consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME => {
let domain_len = cur.get_u8() as usize;
if cur.remaining() < domain_len {
return Err(Error::AddressDomainInvalidEncoding);
}
let mut buf = vec![0u8; domain_len];
cur.copy_to_slice(&mut buf);
let port = cur.get_u16();
let addr = String::from_utf8(buf).map_err(|_| Error::AddressDomainInvalidEncoding)?;
Ok(Address::DomainNameAddress(addr, port))
}
_ => Err(Error::AddressTypeNotSupported(atyp)),
}
}

/// Parse from a `AsyncRead`
pub async fn read_from<R>(stream: &mut R) -> Result<Address, Error>
where
Expand Down
9 changes: 5 additions & 4 deletions crates/shadowsocks/src/relay/udprelay/aead.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub fn encrypt_payload_aead(
}

/// Decrypt UDP AEAD protocol packet
pub async fn decrypt_payload_aead(
pub fn decrypt_payload_aead(
_context: &Context,
method: CipherKind,
key: &[u8],
Expand Down Expand Up @@ -109,7 +109,7 @@ pub async fn decrypt_payload_aead(
let data_len = data.len() - tag_len;
let data = &mut data[..data_len];

let (dn, addr) = parse_packet(data).await?;
let (dn, addr) = parse_packet(data)?;

let data_length = data_len - dn;
let data_start_idx = salt_len + dn;
Expand All @@ -120,9 +120,10 @@ pub async fn decrypt_payload_aead(
Ok((data_length, addr))
}

async fn parse_packet(buf: &[u8]) -> ProtocolResult<(usize, Address)> {
#[inline]
fn parse_packet(buf: &[u8]) -> ProtocolResult<(usize, Address)> {
let mut cur = Cursor::new(buf);
match Address::read_from(&mut cur).await {
match Address::read_cursor(&mut cur) {
Ok(address) => {
let pos = cur.position() as usize;
Ok((pos, address))
Expand Down
8 changes: 4 additions & 4 deletions crates/shadowsocks/src/relay/udprelay/aead_2022.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ pub fn encrypt_client_payload_aead_2022(
}

/// Decrypt `Client -> Server` UDP AEAD protocol packet
pub async fn decrypt_client_payload_aead_2022(
pub fn decrypt_client_payload_aead_2022(
context: &Context,
method: CipherKind,
key: &[u8],
Expand Down Expand Up @@ -581,7 +581,7 @@ pub async fn decrypt_client_payload_aead_2022(
user,
};

let addr = match Address::read_from(&mut cursor).await {
let addr = match Address::read_cursor(&mut cursor) {
Ok(a) => a,
Err(err) => return Err(ProtocolError::InvalidAddress(err)),
};
Expand Down Expand Up @@ -641,7 +641,7 @@ pub fn encrypt_server_payload_aead_2022(
}

/// Decrypt `Server -> Client` UDP AEAD protocol packet
pub async fn decrypt_server_payload_aead_2022(
pub fn decrypt_server_payload_aead_2022(
context: &Context,
method: CipherKind,
key: &[u8],
Expand Down Expand Up @@ -687,7 +687,7 @@ pub async fn decrypt_server_payload_aead_2022(
user: None,
};

let addr = match Address::read_from(&mut cursor).await {
let addr = match Address::read_cursor(&mut cursor) {
Ok(a) => a,
Err(err) => return Err(ProtocolError::InvalidAddress(err)),
};
Expand Down
14 changes: 4 additions & 10 deletions crates/shadowsocks/src/relay/udprelay/crypto_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub fn encrypt_server_payload(
}

/// Decrypt `Client -> Server` payload from ShadowSocks UDP encrypted packet
pub async fn decrypt_client_payload(
pub fn decrypt_client_payload(
context: &Context,
method: CipherKind,
key: &[u8],
Expand All @@ -143,7 +143,7 @@ pub async fn decrypt_client_payload(
CipherCategory::None => {
let _ = user_manager;
let mut cur = Cursor::new(payload);
match Address::read_from(&mut cur).await {
match Address::read_cursor(&mut cur) {
Ok(address) => {
let pos = cur.position() as usize;
let payload = cur.into_inner();
Expand All @@ -157,27 +157,24 @@ pub async fn decrypt_client_payload(
CipherCategory::Stream => {
let _ = user_manager;
decrypt_payload_stream(context, method, key, payload)
.await
.map(|(n, a)| (n, a, None))
.map_err(Into::into)
}
CipherCategory::Aead => {
let _ = user_manager;
decrypt_payload_aead(context, method, key, payload)
.await
.map(|(n, a)| (n, a, None))
.map_err(Into::into)
}
#[cfg(feature = "aead-cipher-2022")]
CipherCategory::Aead2022 => decrypt_client_payload_aead_2022(context, method, key, payload, user_manager)
.await
.map(|(n, a, c)| (n, a, Some(c)))
.map_err(Into::into),
}
}

/// Decrypt `Server -> Client` payload from ShadowSocks UDP encrypted packet
pub async fn decrypt_server_payload(
pub fn decrypt_server_payload(
context: &Context,
method: CipherKind,
key: &[u8],
Expand All @@ -186,7 +183,7 @@ pub async fn decrypt_server_payload(
match method.category() {
CipherCategory::None => {
let mut cur = Cursor::new(payload);
match Address::read_from(&mut cur).await {
match Address::read_cursor(&mut cur) {
Ok(address) => {
let pos = cur.position() as usize;
let payload = cur.into_inner();
Expand All @@ -198,16 +195,13 @@ pub async fn decrypt_server_payload(
}
#[cfg(feature = "stream-cipher")]
CipherCategory::Stream => decrypt_payload_stream(context, method, key, payload)
.await
.map(|(n, a)| (n, a, None))
.map_err(Into::into),
CipherCategory::Aead => decrypt_payload_aead(context, method, key, payload)
.await
.map(|(n, a)| (n, a, None))
.map_err(Into::into),
#[cfg(feature = "aead-cipher-2022")]
CipherCategory::Aead2022 => decrypt_server_payload_aead_2022(context, method, key, payload)
.await
.map(|(n, a, c)| (n, a, Some(c)))
.map_err(Into::into),
}
Expand Down