diff --git a/docs/modules/ROOT/partials/rust/errors/ConnectionError.adoc b/docs/modules/ROOT/partials/rust/errors/ConnectionError.adoc index e7a0621b55..829dd3a536 100644 --- a/docs/modules/ROOT/partials/rust/errors/ConnectionError.adoc +++ b/docs/modules/ROOT/partials/rust/errors/ConnectionError.adoc @@ -8,6 +8,7 @@ [options="header"] |=== |Variant +a| `AbsentTlsConfigForTlsConnection` a| `AddressTranslationMismatch` a| `BrokenPipe` a| `ClusterAllNodesFailed` @@ -23,6 +24,7 @@ a| `InvalidResponseField` a| `ListsNotImplemented` a| `MissingPort` a| `MissingResponseField` +a| `NonTlsConnectionWithHttps` a| `QueryStreamNoResponse` a| `RPCMethodUnavailable` a| `SSLCertificateNotValidated` @@ -30,6 +32,7 @@ a| `ServerConnectionFailed` a| `ServerConnectionFailedStatusError` a| `ServerConnectionFailedWithError` a| `ServerConnectionIsClosed` +a| `TlsConnectionWithoutHttps` a| `TokenCredentialInvalid` a| `TransactionIsClosed` a| `TransactionIsClosedWithErrors` diff --git a/python/README.md b/python/README.md index 130c42aaa1..dfa3e86edf 100644 --- a/python/README.md +++ b/python/README.md @@ -40,7 +40,7 @@ class TypeDBExample: def typedb_example(self): # Open a driver connection. Specify your parameters if needed # The connection will be automatically closed on the "with" block exit - with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions()) as driver: + with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions(is_tls_enabled=False)) as driver: # Create a database driver.databases.create("typedb") database = driver.databases.get("typedb") diff --git a/python/example.py b/python/example.py index b11cb462b7..3f806395a7 100644 --- a/python/example.py +++ b/python/example.py @@ -8,7 +8,7 @@ class TypeDBExample: def typedb_example(self): # Open a driver connection. Specify your parameters if needed # The connection will be automatically closed on the "with" block exit - with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions()) as driver: + with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions(is_tls_enabled=False)) as driver: # Create a database driver.databases.create("typedb") database = driver.databases.get("typedb") diff --git a/python/tests/behaviour/background/cluster/environment.py b/python/tests/behaviour/background/cluster/environment.py index 258caecbae..62df61d775 100644 --- a/python/tests/behaviour/background/cluster/environment.py +++ b/python/tests/behaviour/background/cluster/environment.py @@ -53,7 +53,7 @@ def create_driver(context, host="localhost", port=None, username=None, password= if password is None: password = "password" credentials = Credentials(username, password) - return TypeDB.driver(address=f"{host}:{port}", credentials=credentials, driver_options=DriverOptions()) + return TypeDB.driver(address=f"{host}:{port}", credentials=credentials, driver_options=DriverOptions(is_tls_enabled=False)) def after_scenario(context: Context, scenario): diff --git a/python/tests/behaviour/background/community/environment.py b/python/tests/behaviour/background/community/environment.py index d1c3ff838c..aa55a1c1ec 100644 --- a/python/tests/behaviour/background/community/environment.py +++ b/python/tests/behaviour/background/community/environment.py @@ -50,7 +50,7 @@ def create_driver(context, host="localhost", port=None, username=None, password= if password is None: password = "password" credentials = Credentials(username, password) - return TypeDB.driver(address=f"{host}:{port}", credentials=credentials, driver_options=DriverOptions()) + return TypeDB.driver(address=f"{host}:{port}", credentials=credentials, driver_options=DriverOptions(is_tls_enabled=False)) def after_scenario(context: Context, scenario): diff --git a/python/tests/deployment/test.py b/python/tests/deployment/test.py index d8d5327007..9f90999d6c 100644 --- a/python/tests/deployment/test.py +++ b/python/tests/deployment/test.py @@ -35,7 +35,7 @@ class TestDeployedPythonDriver(TestCase): def setUpClass(cls): super(TestDeployedPythonDriver, cls).setUpClass() global driver - driver = TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions()) + driver = TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions(is_tls_enabled=False)) @classmethod def tearDownClass(cls): diff --git a/python/tests/integration/test_debug.py b/python/tests/integration/test_debug.py index d52c7ce8f3..c6cfa73b67 100644 --- a/python/tests/integration/test_debug.py +++ b/python/tests/integration/test_debug.py @@ -29,7 +29,7 @@ class TestDebug(TestCase): def setUp(self): - with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions()) as driver: + with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions(is_tls_enabled=False)) as driver: if TYPEDB not in [db.name for db in driver.databases.all()]: driver.databases.create(TYPEDB) diff --git a/python/tests/integration/test_example.py b/python/tests/integration/test_example.py index 681e63e003..23b62baf91 100644 --- a/python/tests/integration/test_example.py +++ b/python/tests/integration/test_example.py @@ -27,7 +27,7 @@ class TestExample(TestCase): # EXAMPLE END MARKER def setUp(self): - with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions()) as driver: + with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions(is_tls_enabled=False)) as driver: if driver.databases.contains("typedb"): driver.databases.get("typedb").delete() @@ -36,7 +36,7 @@ def setUp(self): def test_example(self): # Open a driver connection. Specify your parameters if needed # The connection will be automatically closed on the "with" block exit - with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions()) as driver: + with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions(is_tls_enabled=False)) as driver: # Create a database driver.databases.create("typedb") database = driver.databases.get("typedb") diff --git a/python/tests/integration/test_values.py b/python/tests/integration/test_values.py index 0a48f29fc3..f296c4b0cd 100644 --- a/python/tests/integration/test_values.py +++ b/python/tests/integration/test_values.py @@ -34,7 +34,7 @@ class TestValues(TestCase): def setUp(self): - with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions()) as driver: + with TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions(is_tls_enabled=False)) as driver: if driver.databases.contains(TYPEDB): driver.databases.get(TYPEDB).delete() driver.databases.create(TYPEDB) @@ -67,7 +67,7 @@ def test_values(self): "expiration": "P1Y10M7DT15H44M5.00394892S" } - with (TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions()) as driver): + with (TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions(is_tls_enabled=False)) as driver): database = driver.databases.get(TYPEDB) with driver.transaction(database.name, SCHEMA) as tx: @@ -184,7 +184,7 @@ def test_datetime(self): Datetime.fromstring("2024-09-21", tz_name="Asia/Calcutta", datetime_fmt="%Y-%m-%d") Datetime.fromstring("21/09/24 18:34", tz_name="Africa/Cairo", datetime_fmt="%d/%m/%y %H:%M") - with (TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions()) as driver): + with (TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions(is_tls_enabled=False)) as driver): database = driver.databases.get(TYPEDB) with driver.transaction(database.name, SCHEMA) as tx: @@ -374,7 +374,7 @@ def test_duration(self): Duration.fromstring("P1Y10M7DT15H44M5.00394892S") Duration.fromstring("P55W") - with (TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions()) as driver): + with (TypeDB.driver(TypeDB.DEFAULT_ADDRESS, Credentials("admin", "password"), DriverOptions(is_tls_enabled=False)) as driver): database = driver.databases.get(TYPEDB) with driver.transaction(database.name, SCHEMA) as tx: diff --git a/rust/src/common/address.rs b/rust/src/common/address.rs index a258f29f35..ea30a0956a 100644 --- a/rust/src/common/address.rs +++ b/rust/src/common/address.rs @@ -26,15 +26,25 @@ use crate::{ error::ConnectionError, }; -#[derive(Clone, Hash, PartialEq, Eq)] +#[derive(Clone, Hash, PartialEq, Eq, Default)] pub struct Address { uri: Uri, } impl Address { + const DEFAULT_SCHEME: &'static str = "http"; + pub(crate) fn into_uri(self) -> Uri { self.uri } + + pub(crate) fn uri_scheme(&self) -> Option<&http::uri::Scheme> { + self.uri.scheme() + } + + pub(crate) fn is_https(&self) -> bool { + self.uri_scheme().map_or(false, |scheme| scheme == &http::uri::Scheme::HTTPS) + } } impl FromStr for Address { @@ -44,7 +54,7 @@ impl FromStr for Address { let uri = if address.contains("://") { address.parse::()? } else { - format!("http://{address}").parse::()? + format!("{}://{}", Self::DEFAULT_SCHEME, address).parse::()? }; if uri.port().is_none() { return Err(Error::Connection(ConnectionError::MissingPort { address: address.to_owned() })); @@ -61,6 +71,6 @@ impl fmt::Display for Address { impl fmt::Debug for Address { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(self, f) + write!(f, "{:?}", self.uri) } } diff --git a/rust/src/common/error.rs b/rust/src/common/error.rs index a6493869ee..60242f2af3 100644 --- a/rust/src/common/error.rs +++ b/rust/src/common/error.rs @@ -184,6 +184,12 @@ error_messages! { ConnectionError 32: "The database export channel is closed and no further operation is allowed.", DatabaseExportStreamNoResponse = 33: "Didn't receive any server responses for the database export command.", + AbsentTlsConfigForTlsConnection = + 34: "Could not establish a TLS connection without a TLS config specified. Please verify your driver options.", + TlsConnectionWithoutHttps = + 35: "TLS connections can only be enabled when connecting to HTTPS endpoints, for example using 'https://:port'. Please modify the address, or disable TLS (WARNING: this will send passwords over plaintext).", + NonTlsConnectionWithHttps = + 36: "Connecting to an https endpoint requires enabling TLS in driver options.", } error_messages! { ConceptError diff --git a/rust/src/connection/server_connection.rs b/rust/src/connection/server_connection.rs index 721a13083a..8e805dd34f 100644 --- a/rust/src/connection/server_connection.rs +++ b/rust/src/connection/server_connection.rs @@ -66,6 +66,8 @@ impl ServerConnection { driver_lang: &str, driver_version: &str, ) -> crate::Result<(Self, Vec)> { + Self::validate_tls(&address, &driver_options)?; + let username = credentials.username().to_string(); let request_transmitter = Arc::new(RPCTransmitter::start(address, credentials.clone(), driver_options, &background_runtime)?); @@ -324,6 +326,25 @@ impl ServerConnection { other => Err(InternalError::UnexpectedResponseType { response_type: format!("{other:?}") }.into()), } } + + fn validate_tls(address: &Address, driver_options: &DriverOptions) -> crate::Result { + match driver_options.is_tls_enabled() { + true => { + if driver_options.tls_config().is_none() { + return Err(ConnectionError::AbsentTlsConfigForTlsConnection {}.into()); + } + if !address.is_https() { + return Err(ConnectionError::TlsConnectionWithoutHttps {}.into()); + } + } + false => { + if address.is_https() { + return Err(ConnectionError::NonTlsConnectionWithHttps {}.into()); + } + } + } + Ok(()) + } } impl fmt::Debug for ServerConnection {