diff --git a/cassandra/cluster.py b/cassandra/cluster.py index fcf4a0e440..66bf7c7049 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -685,7 +685,7 @@ class Cluster(object): Used for testing new protocol features incrementally before the new version is complete. """ - compression: Union[bool, str] = True + compression: Union[bool, str, None] = True """ Controls compression for communications between the driver and Cassandra. If left as the default of :const:`True`, either lz4 or snappy compression @@ -695,7 +695,7 @@ class Cluster(object): You may also set this to 'snappy' or 'lz4' to request that specific compression type. - Setting this to :const:`False` disables compression. + Setting this to :const:`False` or :const:`None` disables compression. """ _application_info: Optional[ApplicationInfoBase] = None @@ -1172,7 +1172,7 @@ def token_metadata_enabled(self, enabled): def __init__(self, contact_points=_NOT_SET, port=9042, - compression: Union[bool, str] = True, + compression: Union[bool, str, None] = True, auth_provider=None, load_balancing_policy=None, reconnection_policy=None, @@ -1285,7 +1285,8 @@ def __init__(self, self._resolve_hostnames() - if isinstance(compression, bool): + if isinstance(compression, bool) or compression is None: + compression = bool(compression) if compression and not locally_supported_compressions: log.error( "Compression is enabled, but no compression libraries are available. " diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 383a4de7c8..f3efed9f54 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -122,6 +122,51 @@ def test_port_range(self): with pytest.raises(ValueError): cluster = Cluster(contact_points=['127.0.0.1'], port=invalid_port) + def test_compression_autodisabled_without_libraries(self): + with patch.dict('cassandra.cluster.locally_supported_compressions', {}, clear=True): + with patch('cassandra.cluster.log') as patched_logger: + cluster = Cluster(compression=True) + + patched_logger.error.assert_called_once() + assert cluster.compression is False + + def test_compression_validates_requested_algorithm(self): + with patch.dict('cassandra.cluster.locally_supported_compressions', {}, clear=True): + with pytest.raises(ValueError): + Cluster(compression='lz4') + + with patch.dict('cassandra.cluster.locally_supported_compressions', {'lz4': ('c', 'd')}, clear=True): + with patch('cassandra.cluster.log') as patched_logger: + cluster = Cluster(compression='lz4') + + patched_logger.error.assert_not_called() + assert cluster.compression == 'lz4' + + def test_compression_type_validation(self): + with pytest.raises(TypeError): + Cluster(compression=123) + + def test_connection_factory_passes_compression_kwarg(self): + endpoint = Mock(address='127.0.0.1') + scenarios = [ + ({}, True, False), + ({'snappy': ('c', 'd')}, True, True), + ({'lz4': ('c', 'd')}, 'lz4', 'lz4'), + ({'lz4': ('c', 'd'), 'snappy': ('c', 'd')}, False, False), + ({'lz4': ('c', 'd'), 'snappy': ('c', 'd')}, None, False), + ] + + for supported, configured, expected in scenarios: + with patch.dict('cassandra.cluster.locally_supported_compressions', supported, clear=True): + with patch.object(Cluster.connection_class, 'factory', autospec=True, return_value='connection') as factory: + cluster = Cluster(compression=configured) + conn = cluster.connection_factory(endpoint) + + assert conn == 'connection' + assert factory.call_count == 1 + assert factory.call_args.kwargs['compression'] == expected + assert cluster.compression == expected + class SchedulerTest(unittest.TestCase): # TODO: this suite could be expanded; for now just adding a test covering a ticket