Skip to content
Draft
37 changes: 27 additions & 10 deletions cassandra/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,24 +718,41 @@ def recv_results_rows(self, f, protocol_version, user_type_map, result_metadata,
rows = [self.recv_row(f, len(column_metadata)) for _ in range(rowcount)]
self.column_names = [c[2] for c in column_metadata]
self.column_types = [c[3] for c in column_metadata]
col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata]

def decode_val(val, col_md, col_desc):
uses_ce = column_encryption_policy and column_encryption_policy.contains_column(col_desc)
col_type = column_encryption_policy.column_type(col_desc) if uses_ce else col_md[3]
raw_bytes = column_encryption_policy.decrypt(col_desc, val) if uses_ce else val
return col_type.from_binary(raw_bytes, protocol_version)

def decode_row(row):
return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs))
# Optimize by checking column_encryption_policy once per result message.
# This avoids checking if the policy exists for every single value decoded.
if column_encryption_policy:
col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata]

def decode_val(val, col_md, col_desc):
uses_ce = column_encryption_policy.contains_column(col_desc)
if uses_ce:
col_type = column_encryption_policy.column_type(col_desc)
raw_bytes = column_encryption_policy.decrypt(col_desc, val)
return col_type.from_binary(raw_bytes, protocol_version)
else:
return col_md[3].from_binary(val, protocol_version)

def decode_row(row):
return tuple(decode_val(val, col_md, col_desc) for val, col_md, col_desc in zip(row, column_metadata, col_descs))
else:
# Simple path without encryption - just decode raw bytes directly
def decode_row(row):
return tuple(col_md[3].from_binary(val, protocol_version) for val, col_md in zip(row, column_metadata))

try:
self.parsed_rows = [decode_row(row) for row in rows]
except Exception:
# Create col_descs only if needed for error reporting
if not column_encryption_policy:
col_descs = [ColDesc(md[0], md[1], md[2]) for md in column_metadata]
for row in rows:
for val, col_md, col_desc in zip(row, column_metadata, col_descs):
try:
decode_val(val, col_md, col_desc)
if column_encryption_policy:
decode_val(val, col_md, col_desc)
else:
col_md[3].from_binary(val, protocol_version)
except Exception as e:
raise DriverException('Failed decoding result column "%s" of type %s: %s' % (col_md[2],
col_md[3].cql_parameterized_type(),
Expand Down
155 changes: 155 additions & 0 deletions tests/unit/test_protocol_decode_optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import unittest
from unittest.mock import Mock

from cassandra import ProtocolVersion
from cassandra.cqltypes import Int32Type, UTF8Type
from cassandra.marshal import int32_pack
from cassandra.policies import ColDesc
from cassandra.protocol import ResultMessage, RESULT_KIND_ROWS


class DecodeOptimizationTest(unittest.TestCase):
    """
    Tests for the column_encryption_policy handling in
    ResultMessage.recv_results_rows.

    recv_results_rows tests whether a policy was supplied once per result
    message and then picks a decode path, instead of evaluating
    'column_encryption_policy and ...' for every value. These tests check
    that rows are decoded correctly with and without a policy and that the
    policy is consulted the expected number of times.
    """

    # Decoded form of the rows served by _create_mock_result_message() when
    # it is called without arguments.
    EXPECTED_ROWS = [(42, 'hello'), (100, 'world')]

    def _create_mock_result_metadata(self):
        """Return column metadata tuples for an Int32Type and a UTF8Type column."""
        return [
            ('keyspace1', 'table1', 'col1', Int32Type),
            ('keyspace1', 'table1', 'col2', UTF8Type),
        ]

    def _create_mock_result_message(self, rows=None):
        """
        Create a ResultMessage whose metadata and row reads are mocked.

        :param rows: iterable of serialized rows (one list of bytes values
            per row) that successive recv_row calls return. Defaults to
            the two rows matching EXPECTED_ROWS.
        :return: the prepared ResultMessage.
        """
        if rows is None:
            rows = [
                [int32_pack(42), b'hello'],
                [int32_pack(100), b'world'],
            ]
        msg = ResultMessage(kind=RESULT_KIND_ROWS)
        msg.column_metadata = self._create_mock_result_metadata()
        # Replace metadata parsing with a no-op Mock; the metadata is set
        # directly on the message above.
        msg.recv_results_metadata = Mock()
        # Each recv_row call hands back the next prepared row.
        msg.recv_row = Mock(side_effect=rows)
        return msg

    def _create_mock_stream(self, rowcount=2):
        """
        Create a stream holding only the packed row count.

        :param rowcount: number of rows announced in the stream.
        :return: an io.BytesIO positioned at the start of the row count.
        """
        return io.BytesIO(int32_pack(rowcount))

    def test_decode_without_encryption_policy(self):
        """
        Test that decoding works correctly without column encryption policy.
        This should use the simple path that decodes raw bytes directly.
        """
        msg = self._create_mock_result_message()
        f = self._create_mock_stream()

        msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, None)

        # Compare complete rows so a regression in any row or column is caught.
        self.assertEqual(msg.parsed_rows, self.EXPECTED_ROWS)

    def test_decode_with_encryption_policy_no_encrypted_columns(self):
        """
        Test that decoding works with encryption policy when no columns are encrypted.
        """
        msg = self._create_mock_result_message()
        f = self._create_mock_stream()

        # Create mock encryption policy that has no encrypted columns
        mock_policy = Mock()
        mock_policy.contains_column = Mock(return_value=False)

        msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)

        self.assertEqual(msg.parsed_rows, self.EXPECTED_ROWS)

        # contains_column is consulted for every value: 2 rows x 2 columns.
        self.assertEqual(mock_policy.contains_column.call_count, 4)

        # No column is encrypted, so nothing may be decrypted or re-typed.
        mock_policy.decrypt.assert_not_called()
        mock_policy.column_type.assert_not_called()

    def test_decode_with_encryption_policy_with_encrypted_column(self):
        """
        Test that decoding works with encryption policy when one column is encrypted.
        """
        msg = self._create_mock_result_message()
        f = self._create_mock_stream()

        # Create mock encryption policy where only the first column is encrypted
        encrypted_col = ColDesc('keyspace1', 'table1', 'col1')

        def contains_column_side_effect(col_desc):
            return col_desc.col == 'col1'

        mock_policy = Mock()
        mock_policy.contains_column = Mock(side_effect=contains_column_side_effect)
        mock_policy.column_type = Mock(return_value=Int32Type)
        # "Decryption" is the identity, so the decoded values stay predictable.
        mock_policy.decrypt = Mock(side_effect=lambda col_desc, val: val)

        msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)

        self.assertEqual(msg.parsed_rows, self.EXPECTED_ROWS)

        # contains_column is consulted for every value: 2 rows x 2 columns.
        self.assertEqual(mock_policy.contains_column.call_count, 4)

        # Only the encrypted column is decrypted: 2 rows x 1 encrypted column,
        # each time with that column's descriptor and the raw value.
        self.assertEqual(mock_policy.decrypt.call_count, 2)
        mock_policy.decrypt.assert_any_call(encrypted_col, int32_pack(42))
        mock_policy.decrypt.assert_any_call(encrypted_col, int32_pack(100))

        # The column type is looked up once per encrypted value as well.
        self.assertEqual(mock_policy.column_type.call_count, 2)

    def test_optimization_efficiency(self):
        """
        Verify that the optimization checks policy existence once per result message.
        The key optimization is checking 'if column_encryption_policy:' once,
        rather than 'column_encryption_policy and ...' for every value.
        """
        row_total = 100

        # Create more rows to make the check pattern clear
        msg = self._create_mock_result_message(
            rows=[[int32_pack(i), f'text{i}'.encode()] for i in range(row_total)]
        )
        f = self._create_mock_stream(rowcount=row_total)

        mock_policy = Mock()
        # Count every truthiness test of the policy object itself; this is
        # what distinguishes a single up-front check from a per-value one.
        mock_policy.__bool__ = Mock(return_value=True)
        mock_policy.contains_column = Mock(return_value=False)

        msg.recv_results_rows(f, ProtocolVersion.V4, {}, None, mock_policy)

        self.assertEqual(msg.parsed_rows,
                         [(i, f'text{i}') for i in range(row_total)])

        # contains_column is still called per value: 100 rows * 2 columns.
        self.assertEqual(mock_policy.contains_column.call_count, 200,
                         "contains_column should be called for each value when policy exists")

        # The policy's existence must be tested once, not once per value.
        self.assertEqual(mock_policy.__bool__.call_count, 1,
                         "policy existence should be checked once per result message")


# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Loading