Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TCP/TLS server address filter #87

Merged
merged 6 commits into from
Aug 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions ffi/bindings/c/server_example.c
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ int run_tcp_channel(rodbus_runtime_t* runtime)
rodbus_server_t* server = NULL;
rodbus_device_map_t* map = build_device_map();
rodbus_decode_level_t decode_level = rodbus_decode_level_nothing();
rodbus_param_error_t err = rodbus_server_create_tcp(runtime, "127.0.0.1", 502, 100, map, decode_level, &server);
rodbus_address_filter_t* filter = rodbus_address_filter_any();
rodbus_param_error_t err = rodbus_server_create_tcp(runtime, "127.0.0.1", 502, filter, 100, map, decode_level, &server);
rodbus_address_filter_destroy(filter);
rodbus_device_map_destroy(map);

if (err) {
Expand Down Expand Up @@ -278,7 +280,9 @@ int run_tls_channel(rodbus_runtime_t* runtime, rodbus_tls_server_config_t tls_co
rodbus_device_map_t* map = build_device_map();
rodbus_authorization_handler_t auth_handler = get_auth_handler();
rodbus_decode_level_t decode_level = rodbus_decode_level_nothing();
rodbus_param_error_t err = rodbus_server_create_tls_with_authz(runtime, "127.0.0.1", 802, 100, map, tls_config, auth_handler, decode_level, &server);
rodbus_address_filter_t* filter = rodbus_address_filter_any();
rodbus_param_error_t err = rodbus_server_create_tls_with_authz(runtime, "127.0.0.1", 802, filter, 100, map, tls_config, auth_handler, decode_level, &server);
rodbus_address_filter_destroy(filter);
rodbus_device_map_destroy(map);

if (err) {
Expand Down
6 changes: 4 additions & 2 deletions ffi/bindings/c/server_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ int run_tcp_server(rodbus::Runtime& runtime)
auto device_map = create_device_map();

// ANCHOR: tcp_server_create
auto server = rodbus::Server::create_tcp(runtime, "127.0.0.1", 502, 100, device_map, rodbus::DecodeLevel::nothing());
auto filter = rodbus::AddressFilter::any();
auto server = rodbus::Server::create_tcp(runtime, "127.0.0.1", 502, filter, 100, device_map, rodbus::DecodeLevel::nothing());
// ANCHOR_END: tcp_server_create

return run_server(server);
Expand All @@ -233,7 +234,8 @@ int run_tls_server(rodbus::Runtime& runtime, const rodbus::TlsServerConfig& tls_
auto device_map = create_device_map();

// ANCHOR: tls_server_create
auto server = rodbus::Server::create_tls_with_authz(runtime, "127.0.0.1", 802, 100, device_map, tls_config, std::make_unique<AuthorizationHandler>(), rodbus::DecodeLevel::nothing());
auto filter = rodbus::AddressFilter::any();
auto server = rodbus::Server::create_tls_with_authz(runtime, "127.0.0.1", 802, filter, 100, device_map, tls_config, std::make_unique<AuthorizationHandler>(), rodbus::DecodeLevel::nothing());
// ANCHOR_END: tls_server_create

return run_server(server);
Expand Down
4 changes: 2 additions & 2 deletions ffi/bindings/dotnet/examples/server/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ private static Server CreateServer(string type, Runtime runtime, DeviceMap map)
private static Server CreateTcpServer(Runtime runtime, DeviceMap map)
{
// ANCHOR: tcp_server_create
var server = Server.CreateTcp(runtime, "127.0.0.1", 502, 100, map, DecodeLevel.Nothing());
var server = Server.CreateTcp(runtime, "127.0.0.1", 502, AddressFilter.Any(), 100, map, DecodeLevel.Nothing());
// ANCHOR_END: tcp_server_create

return server;
Expand All @@ -205,7 +205,7 @@ private static Server CreateRtuServer(Runtime runtime, DeviceMap map)
private static Server CreateTlsServer(Runtime runtime, DeviceMap map, TlsServerConfig tlsConfig)
{
// ANCHOR: tls_server_create
var server = Server.CreateTlsWithAuthz(runtime, "127.0.0.1", 802, 10, map, tlsConfig, new AuthorizationHandler(), DecodeLevel.Nothing());
var server = Server.CreateTlsWithAuthz(runtime, "127.0.0.1", 802, AddressFilter.Any(), 10, map, tlsConfig, new AuthorizationHandler(), DecodeLevel.Nothing());
// ANCHOR_END: tls_server_create

return server;
Expand Down
2 changes: 1 addition & 1 deletion ffi/bindings/dotnet/rodbus-tests/IntegrationTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ public void ClientAndServerCommunication()
}
});

var server = Server.CreateTcp(runtime, ENDPOINT, PORT, 100, map, DecodeLevel.Nothing());
var server = Server.CreateTcp(runtime, ENDPOINT, PORT, AddressFilter.Any(), 100, map, DecodeLevel.Nothing());
var client = ClientChannel.CreateTcp(runtime, ENDPOINT, PORT, 10, new RetryStrategy(), DecodeLevel.Nothing(), new ClientStateListener());

client.Enable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ private static Server createServer(String type, Runtime runtime, DeviceMap map)

private static Server createTcpServer(Runtime runtime, DeviceMap map) {
// ANCHOR: tcp_server_create
Server server = Server.createTcp(runtime, "127.0.0.1", ushort(502), ushort(100), map, DecodeLevel.nothing());
Server server = Server.createTcp(runtime, "127.0.0.1", ushort(502), AddressFilter.any(), ushort(100), map, DecodeLevel.nothing());
// ANCHOR_END: tcp_server_create

return server;
Expand All @@ -177,7 +177,7 @@ private static Server createRtuServer(Runtime runtime, DeviceMap map) {

private static Server createTlsServer(Runtime runtime, DeviceMap map, TlsServerConfig tlsConfig) {
// ANCHOR: tls_server_create
Server server = Server.createTlsWithAuthz(runtime, "127.0.0.1", ushort(802), ushort(10), map, tlsConfig, new TestAuthorizationHandler(), DecodeLevel.nothing());
Server server = Server.createTlsWithAuthz(runtime, "127.0.0.1", ushort(802), AddressFilter.any(), ushort(10), map, tlsConfig, new TestAuthorizationHandler(), DecodeLevel.nothing());
// ANCHOR_END: tls_server_create

return server;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ void clientAndServerCommunication() throws ExecutionException, InterruptedExcept
}
});

final Server server = Server.createTcp(runtime, ENDPOINT, PORT, ushort(100), deviceMap, DecodeLevel.nothing());
final Server server = Server.createTcp(runtime, ENDPOINT, PORT, AddressFilter.any(), ushort(100), deviceMap, DecodeLevel.nothing());
final ClientChannel client = ClientChannel.createTcp(runtime, ENDPOINT, PORT, ushort(10), new RetryStrategy(), DecodeLevel.nothing(), new NullClientStateListener());

client.enable();
Expand Down
91 changes: 90 additions & 1 deletion ffi/rodbus-ffi/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rodbus::server::ServerHandle;
use rodbus::AddressRange;
use rodbus::{ExceptionCode, Indexed, UnitId};
use std::collections::HashMap;
use std::ffi::CString;
use std::ffi::{CStr, CString};
use std::net::{IpAddr, SocketAddr};
use std::path::Path;

Expand Down Expand Up @@ -274,11 +274,13 @@ pub(crate) unsafe fn server_create_tcp(
runtime: *mut crate::Runtime,
ip_addr: &std::ffi::CStr,
port: u16,
filter: *mut crate::AddressFilter,
max_sessions: u16,
endpoints: *mut crate::DeviceMap,
decode_level: ffi::DecodeLevel,
) -> Result<*mut crate::Server, ffi::ParamError> {
let runtime = runtime.as_ref().ok_or(ffi::ParamError::NullParameter)?;
let filter = filter.as_ref().ok_or(ffi::ParamError::NullParameter)?;
let address = get_socket_addr(ip_addr, port)?;
let endpoints = endpoints.as_mut().ok_or(ffi::ParamError::NullParameter)?;

Expand All @@ -287,6 +289,7 @@ pub(crate) unsafe fn server_create_tcp(
max_sessions as usize,
address,
handler_map.clone(),
filter.into(),
decode_level.into(),
);

Expand Down Expand Up @@ -339,6 +342,7 @@ pub(crate) unsafe fn server_create_tls(
runtime: *mut crate::Runtime,
ip_addr: &std::ffi::CStr,
port: u16,
filter: *mut crate::AddressFilter,
max_sessions: u16,
endpoints: *mut crate::DeviceMap,
tls_config: ffi::TlsServerConfig,
Expand All @@ -348,6 +352,7 @@ pub(crate) unsafe fn server_create_tls(
runtime,
ip_addr,
port,
filter,
max_sessions,
endpoints,
tls_config,
Expand All @@ -361,6 +366,7 @@ pub(crate) unsafe fn server_create_tls_with_authz(
runtime: *mut crate::Runtime,
ip_addr: &std::ffi::CStr,
port: u16,
filter: *mut crate::AddressFilter,
max_sessions: u16,
endpoints: *mut crate::DeviceMap,
tls_config: ffi::TlsServerConfig,
Expand All @@ -371,6 +377,7 @@ pub(crate) unsafe fn server_create_tls_with_authz(
runtime,
ip_addr,
port,
filter,
max_sessions,
endpoints,
tls_config,
Expand All @@ -384,13 +391,15 @@ pub(crate) unsafe fn server_create_tls_impl(
runtime: *mut crate::Runtime,
ip_addr: &std::ffi::CStr,
port: u16,
filter: *mut crate::AddressFilter,
max_sessions: u16,
endpoints: *mut crate::DeviceMap,
tls_config: ffi::TlsServerConfig,
auth_handler: Option<ffi::AuthorizationHandler>,
decode_level: ffi::DecodeLevel,
) -> Result<*mut crate::Server, ffi::ParamError> {
let runtime = runtime.as_ref().ok_or(ffi::ParamError::NullParameter)?;
let filter = filter.as_ref().ok_or(ffi::ParamError::NullParameter)?;
let address = get_socket_addr(ip_addr, port)?;
let endpoints = endpoints.as_mut().ok_or(ffi::ParamError::NullParameter)?;

Expand Down Expand Up @@ -423,6 +432,7 @@ pub(crate) unsafe fn server_create_tls_impl(
handler_map.clone(),
AuthorizationHandlerWrapper::new(auth).wrap(),
tls_config,
filter.into(),
decode_level.into(),
);

Expand All @@ -437,6 +447,7 @@ pub(crate) unsafe fn server_create_tls_impl(
address,
handler_map.clone(),
tls_config,
rodbus::server::AddressFilter::Any,
decode_level.into(),
);

Expand Down Expand Up @@ -491,3 +502,81 @@ pub(crate) unsafe fn server_set_decode_level(
.block_on(server.inner.set_decode_level(level.into()))??;
Ok(())
}

pub enum AddressFilter {
Any,
WildcardIpv4(WildcardIPv4),
AnyOf(std::collections::HashSet<std::net::IpAddr>),
}

pub fn address_filter_any() -> *mut AddressFilter {
Box::into_raw(Box::new(AddressFilter::Any))
}

fn parse_address_filter(s: &str) -> Result<AddressFilter, ffi::ParamError> {
// first try to parse it as a normal IP
match s.parse::<IpAddr>() {
Ok(x) => {
let mut set = std::collections::HashSet::new();
set.insert(x);
Ok(AddressFilter::AnyOf(set))
}
Err(_) => {
// now try to parse as a wildcard
let wc: WildcardIPv4 = s.parse()?;
Ok(AddressFilter::WildcardIpv4(wc))
}
}
}

impl From<BadIpv4Wildcard> for ffi::ParamError {
fn from(_: BadIpv4Wildcard) -> Self {
ffi::ParamError::InvalidIpAddress
}
}

pub fn address_filter_create(address: &CStr) -> Result<*mut AddressFilter, ffi::ParamError> {
let address = parse_address_filter(address.to_string_lossy().as_ref())?;
Ok(Box::into_raw(Box::new(address)))
}

pub unsafe fn address_filter_add(
address_filter: *mut AddressFilter,
address: &CStr,
) -> Result<(), ffi::ParamError> {
let address_filter = address_filter
.as_mut()
.ok_or(ffi::ParamError::NullParameter)?;
let address = address.to_string_lossy().parse()?;

match address_filter {
AddressFilter::Any => {
// can't add addresses to an "any" specification
return Err(ffi::ParamError::InvalidIpAddress);
}
AddressFilter::AnyOf(set) => {
set.insert(address);
}
AddressFilter::WildcardIpv4(_) => {
// can't add addresses to a wildcard specification
return Err(ffi::ParamError::InvalidIpAddress);
}
}
Ok(())
}

pub unsafe fn address_filter_destroy(address_filter: *mut AddressFilter) {
if !address_filter.is_null() {
Box::from_raw(address_filter);
}
}

impl From<&AddressFilter> for rodbus::server::AddressFilter {
fn from(from: &AddressFilter) -> Self {
match from {
AddressFilter::Any => rodbus::server::AddressFilter::Any,
AddressFilter::AnyOf(set) => rodbus::server::AddressFilter::AnyOf(set.clone()),
AddressFilter::WildcardIpv4(wc) => rodbus::server::AddressFilter::WildcardIpv4(*wc),
}
}
}
48 changes: 48 additions & 0 deletions ffi/rodbus-schema/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub(crate) fn build_server(
)?;
let tls_server_config = build_tls_server_config(lib, common)?;
let authorization_handler = build_authorization_handler(lib, common)?;
let address_filter = define_address_filter(lib, common)?;

let server = lib.declare_class("server")?;

Expand All @@ -51,6 +52,7 @@ pub(crate) fn build_server(
)?
.param("address", StringType, address_doc)?
.param("port", Primitive::U16, port_doc)?
.param("filter", address_filter.declaration(), "Filter used to limit which IP address(es) can connect")?
.param("max_sessions", Primitive::U16, "Maximum number of concurrent sessions")?
.param(
"endpoints",
Expand Down Expand Up @@ -106,6 +108,7 @@ pub(crate) fn build_server(
)?
.param("address", StringType, address_doc)?
.param("port", Primitive::U16, port_doc)?
.param("filter", address_filter.declaration(), "Filter used to limit which IP address(es) can connect")?
.param("max_sessions", Primitive::U16, "Maximum number of concurrent sessions")?
.param(
"endpoints",
Expand Down Expand Up @@ -140,6 +143,7 @@ pub(crate) fn build_server(
)?
.param("address", StringType, address_doc)?
.param("port", Primitive::U16, port_doc)?
.param("filter", address_filter.declaration(), "Filter used to limit which IP address(es) can connect")?
.param("max_sessions", Primitive::U16, "Maximum number of concurrent sessions")?
.param(
"endpoints",
Expand Down Expand Up @@ -710,3 +714,47 @@ fn build_write_result_struct(

Ok(write_result)
}

fn define_address_filter(
lib: &mut LibraryBuilder,
common: &CommonDefinitions,
) -> BackTraced<ClassHandle> {
let address_filter = lib.declare_class("address_filter")?;

let address_filter_any_fn = lib
.define_function("address_filter_any")?
.returns(address_filter.clone(), "Address filter")?
.doc("Create an address filter that accepts any IP address")?
.build_static("any")?;

let constructor = lib
.define_constructor(address_filter.clone())?
.param("address", StringType, "IP address to accept")?
.fails_with(common.error_type.clone())?
.doc(
doc("Create an address filter that matches a specific address or wildcards")
.details("Examples: 192.168.1.26, 192.168.0.*, *.*.*.*")
.details("Wildcards are only supported for IPv4 addresses"),
)?
.build()?;

let add = lib
.define_method("add", address_filter.clone())?
.param("address", StringType, "IP address to add")?
.fails_with(common.error_type.clone())?
.doc("Add an accepted IP address to the filter")?
.build()?;

let destructor = lib.define_destructor(address_filter.clone(), "Destroy an address filter")?;

let address_filter = lib
.define_class(&address_filter)?
.constructor(constructor)?
.destructor(destructor)?
.static_method(address_filter_any_fn)?
.method(add)?
.doc("Server address filter")?
.build()?;

Ok(address_filter)
}
1 change: 1 addition & 0 deletions rodbus/examples/perf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
num_sessions,
addr,
ServerHandlerMap::single(UnitId::new(1), handler),
AddressFilter::Any,
DecodeLevel::new(
AppDecodeLevel::Nothing,
FrameDecodeLevel::Nothing,
Expand Down
2 changes: 2 additions & 0 deletions rodbus/examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ async fn run_tcp() -> Result<(), Box<dyn std::error::Error>> {
1,
"127.0.0.1:502".parse()?,
map,
AddressFilter::Any,
DecodeLevel::default(),
)
.await?;
Expand Down Expand Up @@ -209,6 +210,7 @@ async fn run_tls(tls_config: TlsServerConfig) -> Result<(), Box<dyn std::error::
map,
ReadOnlyAuthorizationHandler::create(),
tls_config,
AddressFilter::Any,
DecodeLevel::default(),
)
.await?;
Expand Down
1 change: 1 addition & 0 deletions rodbus/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
//! 1,
//! SocketAddr::from_str("127.0.0.1:502")?,
//! map,
//! AddressFilter::Any,
//! DecodeLevel::default(),
//! ).await?;
//!
Expand Down