diff --git a/cassandra/protocol.py b/cassandra/protocol.py index d8716f4eeb..b5cdfb764b 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -718,24 +718,41 @@ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata, rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)] self.column_names = [c[2] for c in column_metadata] self.column_types = [c[3] for c in column_metadata] - col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] - def decode_val(val, col_md, col_desc): - uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc) - col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3] - raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val - return col_type.from_binary(raw_bytes, protocol_version) - - def decode_row(row): - return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs)) + # Optimize by checking column_encryption_policy once per result message. + # This avoids checking if the policy exists for every single value decoded. + if column_encryption_policy: + col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] + + def decode_val(val, col_md, col_desc): + uses_ce = column_encryption_policy.contains_column(col_desc) + if uses_ce: + col_type = column_encryption_policy.column_type(col_desc) + raw_bytes = column_encryption_policy.decrypt(col_desc, val) + return col_type.from_binary(raw_bytes, protocol_version) + else: + return col_md[3].from_binary(val, protocol_version) + + def decode_row(row): + return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs)) + else: + # Simple path without encryption - just decode raw bytes directly + def decode_row(row): + return tuple(col_md[3].from_binary(val, protocol_version) for val, col_md in zip(row, column_metadata)) try: self.parsed_rows = [decode_row(row) for row in rows] except Exception: + # Create col_descs only if needed for error reporting + if not column_encryption_policy: + col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata] for row in rows: for val, col_md, col_desc in zip(row, column_metadata, col_descs): try: - decode_val(val, col_md, col_desc) + if column_encryption_policy: + decode_val(val, col_md, col_desc) + else: + col_md[3].from_binary(val, protocol_version) except Exception as e: raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2], col_md[3].cql_parameterized_type(), diff --git a/tests/unit/test_protocol_decode_optimization.py b/tests/unit/test_protocol_decode_optimization.py new file mode 100644 index 0000000000..e0fd81fe3e --- /dev/null +++ b/tests/unit/test_protocol_decode_optimization.py @@ -0,0 +1,155 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import unittest +from unittest.mock import Mock + +from cassandra import ProtocolVersion +from cassandra.cqltypes import Int32Type, UTF8Type +from cassandra.marshal import int32_pack +from cassandra.policies import ColDesc +from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS + + +class DecodeOptimizationTest(unittest.TestCase): + """ + Tests to verify the optimization of column_encryption_policy checks + in recv_results_rows. The optimization checks if the policy exists once + per result message, avoiding the redundant 'column_encryption_policy and ...' + check for every value. + """ + + def _create_mock_result_metadata(self): + """Create mock result metadata for testing""" + return [ + ('keyspace1', 'table1', 'col1', Int32Type), + ('keyspace1', 'table1', 'col2', UTF8Type), + ] + + def _create_mock_result_message(self): + """Create a mock result message with data""" + msg = ResultMessage(kind=RESULT_KIND_ROWS) + msg.column_metadata = self._create_mock_result_metadata() + msg.recv_results_metadata = Mock() + msg.recv_row = Mock(side_effect=[ + [int32_pack(42), b'hello'], + [int32_pack(100), b'world'], + ]) + return msg + + def _create_mock_stream(self): + """Create a mock stream for reading rows""" + # Pack rowcount (2 rows) + data = int32_pack(2) + return io.BytesIO(data) + + def test_decode_without_encryption_policy(self): + """ + Test that decoding works correctly without column encryption policy. + This should use the optimized simple path. + """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + self.assertEqual(msg.parsed_rows[1][0], 100) + self.assertEqual(msg.parsed_rows[1][1], 'world') + + def test_decode_with_encryption_policy_no_encrypted_columns(self): + """ + Test that decoding works with encryption policy when no columns are encrypted. + """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + # Create mock encryption policy that has no encrypted columns + mock_policy = Mock() + mock_policy.contains_column = Mock(return_value=False) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + + # Verify contains_column was called for each value (but policy existence check happens once) + # Should be called 4 times (2 rows × 2 columns) + self.assertEqual(mock_policy.contains_column.call_count, 4) + + def test_decode_with_encryption_policy_with_encrypted_column(self): + """ + Test that decoding works with encryption policy when one column is encrypted. + """ + msg = self._create_mock_result_message() + f = self._create_mock_stream() + + # Create mock encryption policy where first column is encrypted + mock_policy = Mock() + def contains_column_side_effect(col_desc): + return col_desc.col == 'col1' + mock_policy.contains_column = Mock(side_effect=contains_column_side_effect) + mock_policy.column_type = Mock(return_value=Int32Type) + mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # Verify results + self.assertEqual(len(msg.parsed_rows), 2) + self.assertEqual(msg.parsed_rows[0][0], 42) + self.assertEqual(msg.parsed_rows[0][1], 'hello') + + # Verify contains_column was called for each value (but policy existence check happens once) + # Should be called 4 times (2 rows × 2 columns) + self.assertEqual(mock_policy.contains_column.call_count, 4) + + # Verify decrypt was called for each encrypted value (2 rows * 1 encrypted column) + self.assertEqual(mock_policy.decrypt.call_count, 2) + + def test_optimization_efficiency(self): + """ + Verify that the optimization checks policy existence once per result message. + The key optimization is checking 'if column_encryption_policy:' once, + rather than 'column_encryption_policy and ...' for every value. + """ + msg = self._create_mock_result_message() + + # Create more rows to make the check pattern clear + msg.recv_row = Mock(side_effect=[ + [int32_pack(i), f'text{i}'.encode()] for i in range(100) + ]) + + # Create mock stream with 100 rows + f = io.BytesIO(int32_pack(100)) + + mock_policy = Mock() + mock_policy.contains_column = Mock(return_value=False) + + msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy) + + # With optimization: policy existence checked once, contains_column called per value + # = 100 rows * 2 columns = 200 calls to contains_column + # The key is we avoid checking 'column_encryption_policy and ...' 200 times + self.assertEqual(mock_policy.contains_column.call_count, 200, + "contains_column should be called for each value when policy exists") + + +if __name__ == '__main__': + unittest.main()