diff --git a/tests/unit/v1/test__helpers.py b/tests/unit/v1/test__helpers.py index 710e9e8bc24e4..200f66d81e30d 100644 --- a/tests/unit/v1/test__helpers.py +++ b/tests/unit/v1/test__helpers.py @@ -13,2492 +13,2540 @@ # See the License for the specific language governing permissions and # limitations under the License. -import aiounittest import datetime -import unittest import mock import pytest -from typing import List -class AsyncMock(mock.MagicMock): - async def __call__(self, *args, **kwargs): - return super(AsyncMock, self).__call__(*args, **kwargs) +def _make_geo_point(lat, lng): + from google.cloud.firestore_v1._helpers import GeoPoint + return GeoPoint(lat, lng) -class AsyncIter: - """Utility to help recreate the effect of an async generator. Useful when - you need to mock a system that requires `async for`. - """ - def __init__(self, items): - self.items = items +def test_geopoint_constructor(): + lat = 81.25 + lng = 359.984375 + geo_pt = _make_geo_point(lat, lng) + assert geo_pt.latitude == lat + assert geo_pt.longitude == lng - async def __aiter__(self): - for i in self.items: - yield i +def test_geopoint_to_protobuf(): + from google.type import latlng_pb2 -class TestGeoPoint(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1._helpers import GeoPoint - - return GeoPoint - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def test_constructor(self): - lat = 81.25 - lng = 359.984375 - geo_pt = self._make_one(lat, lng) - self.assertEqual(geo_pt.latitude, lat) - self.assertEqual(geo_pt.longitude, lng) - - def test_to_protobuf(self): - from google.type import latlng_pb2 - - lat = 0.015625 - lng = 20.03125 - geo_pt = self._make_one(lat, lng) - result = geo_pt.to_protobuf() - geo_pt_pb = latlng_pb2.LatLng(latitude=lat, longitude=lng) - self.assertEqual(result, geo_pt_pb) - - def test___eq__(self): - lat = 0.015625 - lng = 20.03125 - geo_pt1 = 
self._make_one(lat, lng) - geo_pt2 = self._make_one(lat, lng) - self.assertEqual(geo_pt1, geo_pt2) - - def test___eq__type_differ(self): - lat = 0.015625 - lng = 20.03125 - geo_pt1 = self._make_one(lat, lng) - geo_pt2 = object() - self.assertNotEqual(geo_pt1, geo_pt2) - self.assertIs(geo_pt1.__eq__(geo_pt2), NotImplemented) - - def test___ne__same_value(self): - lat = 0.015625 - lng = 20.03125 - geo_pt1 = self._make_one(lat, lng) - geo_pt2 = self._make_one(lat, lng) - comparison_val = geo_pt1 != geo_pt2 - self.assertFalse(comparison_val) - - def test___ne__(self): - geo_pt1 = self._make_one(0.0, 1.0) - geo_pt2 = self._make_one(2.0, 3.0) - self.assertNotEqual(geo_pt1, geo_pt2) - - def test___ne__type_differ(self): - lat = 0.015625 - lng = 20.03125 - geo_pt1 = self._make_one(lat, lng) - geo_pt2 = object() - self.assertNotEqual(geo_pt1, geo_pt2) - self.assertIs(geo_pt1.__ne__(geo_pt2), NotImplemented) - - -class Test_verify_path(unittest.TestCase): - @staticmethod - def _call_fut(path, is_collection): - from google.cloud.firestore_v1._helpers import verify_path - - return verify_path(path, is_collection) - - def test_empty(self): - path = () - with self.assertRaises(ValueError): - self._call_fut(path, True) - with self.assertRaises(ValueError): - self._call_fut(path, False) - - def test_wrong_length_collection(self): - path = ("foo", "bar") - with self.assertRaises(ValueError): - self._call_fut(path, True) - - def test_wrong_length_document(self): - path = ("Kind",) - with self.assertRaises(ValueError): - self._call_fut(path, False) - - def test_wrong_type_collection(self): - path = (99, "ninety-nine", "zap") - with self.assertRaises(ValueError): - self._call_fut(path, True) - - def test_wrong_type_document(self): - path = ("Users", "Ada", "Candy", {}) - with self.assertRaises(ValueError): - self._call_fut(path, False) - - def test_success_collection(self): - path = ("Computer", "Magic", "Win") - ret_val = self._call_fut(path, True) - # NOTE: We are just checking that 
it didn't fail. - self.assertIsNone(ret_val) - - def test_success_document(self): - path = ("Tokenizer", "Seventeen", "Cheese", "Burger") - ret_val = self._call_fut(path, False) - # NOTE: We are just checking that it didn't fail. - self.assertIsNone(ret_val) - - -class Test_encode_value(unittest.TestCase): - @staticmethod - def _call_fut(value): - from google.cloud.firestore_v1._helpers import encode_value - - return encode_value(value) - - def test_none(self): - from google.protobuf import struct_pb2 - - result = self._call_fut(None) - expected = _value_pb(null_value=struct_pb2.NULL_VALUE) - self.assertEqual(result, expected) - - def test_boolean(self): - result = self._call_fut(True) - expected = _value_pb(boolean_value=True) - self.assertEqual(result, expected) - - def test_integer(self): - value = 425178 - result = self._call_fut(value) - expected = _value_pb(integer_value=value) - self.assertEqual(result, expected) - - def test_float(self): - value = 123.4453125 - result = self._call_fut(value) - expected = _value_pb(double_value=value) - self.assertEqual(result, expected) - - def test_datetime_with_nanos(self): - from google.api_core.datetime_helpers import DatetimeWithNanoseconds - from google.protobuf import timestamp_pb2 - - dt_seconds = 1488768504 - dt_nanos = 458816991 - timestamp_pb = timestamp_pb2.Timestamp(seconds=dt_seconds, nanos=dt_nanos) - dt_val = DatetimeWithNanoseconds.from_timestamp_pb(timestamp_pb) - - result = self._call_fut(dt_val) - expected = _value_pb(timestamp_value=timestamp_pb) - self.assertEqual(result, expected) - - def test_datetime_wo_nanos(self): - from google.protobuf import timestamp_pb2 - - dt_seconds = 1488768504 - dt_nanos = 458816000 - # Make sure precision is valid in microseconds too. 
- self.assertEqual(dt_nanos % 1000, 0) - dt_val = datetime.datetime.utcfromtimestamp(dt_seconds + 1e-9 * dt_nanos) - - result = self._call_fut(dt_val) - timestamp_pb = timestamp_pb2.Timestamp(seconds=dt_seconds, nanos=dt_nanos) - expected = _value_pb(timestamp_value=timestamp_pb) - self.assertEqual(result, expected) - - def test_string(self): - value = u"\u2018left quote, right quote\u2019" - result = self._call_fut(value) - expected = _value_pb(string_value=value) - self.assertEqual(result, expected) - - def test_bytes(self): - value = b"\xe3\xf2\xff\x00" - result = self._call_fut(value) - expected = _value_pb(bytes_value=value) - self.assertEqual(result, expected) - - def test_reference_value(self): - client = _make_client() - - value = client.document("my", "friend") - result = self._call_fut(value) - expected = _value_pb(reference_value=value._document_path) - self.assertEqual(result, expected) - - def test_geo_point(self): - from google.cloud.firestore_v1._helpers import GeoPoint - - value = GeoPoint(50.5, 88.75) - result = self._call_fut(value) - expected = _value_pb(geo_point_value=value.to_protobuf()) - self.assertEqual(result, expected) - - def test_array(self): - from google.cloud.firestore_v1.types.document import ArrayValue - - result = self._call_fut([99, True, 118.5]) - - array_pb = ArrayValue( - values=[ - _value_pb(integer_value=99), - _value_pb(boolean_value=True), - _value_pb(double_value=118.5), - ] - ) - expected = _value_pb(array_value=array_pb) - self.assertEqual(result, expected) + lat = 0.015625 + lng = 20.03125 + geo_pt = _make_geo_point(lat, lng) + result = geo_pt.to_protobuf() + geo_pt_pb = latlng_pb2.LatLng(latitude=lat, longitude=lng) + assert result == geo_pt_pb - def test_map(self): - from google.cloud.firestore_v1.types.document import MapValue - result = self._call_fut({"abc": 285, "def": b"piglatin"}) +def test_geopoint___eq__w_same_value(): + lat = 0.015625 + lng = 20.03125 + geo_pt1 = _make_geo_point(lat, lng) + geo_pt2 = 
_make_geo_point(lat, lng) + assert geo_pt1 == geo_pt2 - map_pb = MapValue( - fields={ - "abc": _value_pb(integer_value=285), - "def": _value_pb(bytes_value=b"piglatin"), - } - ) - expected = _value_pb(map_value=map_pb) - self.assertEqual(result, expected) - - def test_bad_type(self): - value = object() - with self.assertRaises(TypeError): - self._call_fut(value) - - -class Test_encode_dict(unittest.TestCase): - @staticmethod - def _call_fut(values_dict): - from google.cloud.firestore_v1._helpers import encode_dict - - return encode_dict(values_dict) - - def test_many_types(self): - from google.protobuf import struct_pb2 - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.types.document import ArrayValue - from google.cloud.firestore_v1.types.document import MapValue - - dt_seconds = 1497397225 - dt_nanos = 465964000 - # Make sure precision is valid in microseconds too. - self.assertEqual(dt_nanos % 1000, 0) - dt_val = datetime.datetime.utcfromtimestamp(dt_seconds + 1e-9 * dt_nanos) - - client = _make_client() - document = client.document("most", "adjective", "thing", "here") - - values_dict = { - "foo": None, - "bar": True, - "baz": 981, - "quux": 2.875, - "quuz": dt_val, - "corge": u"\N{snowman}", - "grault": b"\xe2\x98\x83", - "wibble": document, - "garply": [u"fork", 4.0], - "waldo": {"fred": u"zap", "thud": False}, - } - encoded_dict = self._call_fut(values_dict) - expected_dict = { - "foo": _value_pb(null_value=struct_pb2.NULL_VALUE), - "bar": _value_pb(boolean_value=True), - "baz": _value_pb(integer_value=981), - "quux": _value_pb(double_value=2.875), - "quuz": _value_pb( - timestamp_value=timestamp_pb2.Timestamp( - seconds=dt_seconds, nanos=dt_nanos - ) - ), - "corge": _value_pb(string_value=u"\N{snowman}"), - "grault": _value_pb(bytes_value=b"\xe2\x98\x83"), - "wibble": _value_pb(reference_value=document._document_path), - "garply": _value_pb( - array_value=ArrayValue( - values=[ - _value_pb(string_value=u"fork"), - 
_value_pb(double_value=4.0), - ] - ) - ), - "waldo": _value_pb( - map_value=MapValue( - fields={ - "fred": _value_pb(string_value=u"zap"), - "thud": _value_pb(boolean_value=False), - } - ) - ), - } - self.assertEqual(encoded_dict, expected_dict) +def test_geopoint___eq__w_type_differ(): + lat = 0.015625 + lng = 20.03125 + geo_pt1 = _make_geo_point(lat, lng) + geo_pt2 = object() + assert geo_pt1 != geo_pt2 + assert geo_pt1.__eq__(geo_pt2) is NotImplemented -class Test_reference_value_to_document(unittest.TestCase): - @staticmethod - def _call_fut(reference_value, client): - from google.cloud.firestore_v1._helpers import reference_value_to_document - return reference_value_to_document(reference_value, client) +def test_geopoint___ne__w_same_value(): + lat = 0.015625 + lng = 20.03125 + geo_pt1 = _make_geo_point(lat, lng) + geo_pt2 = _make_geo_point(lat, lng) + assert not geo_pt1 != geo_pt2 - def test_bad_format(self): - from google.cloud.firestore_v1._helpers import BAD_REFERENCE_ERROR - reference_value = "not/the/right/format" - with self.assertRaises(ValueError) as exc_info: - self._call_fut(reference_value, None) +def test_geopoint___ne__w_other_value(): + geo_pt1 = _make_geo_point(0.0, 1.0) + geo_pt2 = _make_geo_point(2.0, 3.0) + assert geo_pt1 != geo_pt2 - err_msg = BAD_REFERENCE_ERROR.format(reference_value) - self.assertEqual(exc_info.exception.args, (err_msg,)) - def test_same_client(self): - from google.cloud.firestore_v1.document import DocumentReference +def test_geopoint___ne__w_type_differ(): + lat = 0.015625 + lng = 20.03125 + geo_pt1 = _make_geo_point(lat, lng) + geo_pt2 = object() + assert geo_pt1 != geo_pt2 + assert geo_pt1.__ne__(geo_pt2) is NotImplemented - client = _make_client() - document = client.document("that", "this") - reference_value = document._document_path - new_document = self._call_fut(reference_value, client) - self.assertIsNot(new_document, document) +def test_verify_path_w_empty(): + from google.cloud.firestore_v1._helpers import 
verify_path - self.assertIsInstance(new_document, DocumentReference) - self.assertIs(new_document._client, client) - self.assertEqual(new_document._path, document._path) + path = () + with pytest.raises(ValueError): + verify_path(path, True) + with pytest.raises(ValueError): + verify_path(path, False) - def test_different_client(self): - from google.cloud.firestore_v1._helpers import WRONG_APP_REFERENCE - client1 = _make_client(project="kirk") - document = client1.document("tin", "foil") - reference_value = document._document_path +def test_verify_path_w_wrong_length_collection(): + from google.cloud.firestore_v1._helpers import verify_path - client2 = _make_client(project="spock") - with self.assertRaises(ValueError) as exc_info: - self._call_fut(reference_value, client2) + path = ("foo", "bar") + with pytest.raises(ValueError): + verify_path(path, True) - err_msg = WRONG_APP_REFERENCE.format(reference_value, client2._database_string) - self.assertEqual(exc_info.exception.args, (err_msg,)) +def test_verify_path_w_wrong_length_document(): + from google.cloud.firestore_v1._helpers import verify_path -class TestDocumentReferenceValue(unittest.TestCase): - @staticmethod - def _call(ref_value: str): - from google.cloud.firestore_v1._helpers import DocumentReferenceValue + path = ("Kind",) + with pytest.raises(ValueError): + verify_path(path, False) - return DocumentReferenceValue(ref_value) - def test_normal(self): - orig = "projects/name/databases/(default)/documents/col/doc" - parsed = self._call(orig) - self.assertEqual(parsed.collection_name, "col") - self.assertEqual(parsed.database_name, "(default)") - self.assertEqual(parsed.document_id, "doc") +def test_verify_path_w_wrong_type_collection(): + from google.cloud.firestore_v1._helpers import verify_path - self.assertEqual(parsed.full_path, orig) - parsed._reference_value = None # type: ignore - self.assertEqual(parsed.full_path, orig) + path = (99, "ninety-nine", "zap") + with pytest.raises(ValueError): + 
verify_path(path, True) - def test_nested(self): - parsed = self._call( - "projects/name/databases/(default)/documents/col/doc/nested" - ) - self.assertEqual(parsed.collection_name, "col") - self.assertEqual(parsed.database_name, "(default)") - self.assertEqual(parsed.document_id, "doc/nested") - def test_broken(self): - self.assertRaises( - ValueError, self._call, "projects/name/databases/(default)/documents/col", - ) +def test_verify_path_w_wrong_type_document(): + from google.cloud.firestore_v1._helpers import verify_path + path = ("Users", "Ada", "Candy", {}) + with pytest.raises(ValueError): + verify_path(path, False) -class Test_document_snapshot_to_protobuf(unittest.TestCase): - def test_real_snapshot(self): - from google.cloud.firestore_v1._helpers import document_snapshot_to_protobuf - from google.cloud.firestore_v1.types import Document - from google.cloud.firestore_v1.base_document import DocumentSnapshot - from google.cloud.firestore_v1.document import DocumentReference - from google.protobuf import timestamp_pb2 # type: ignore - - client = _make_client() - snapshot = DocumentSnapshot( - data={"hello": "world"}, - reference=DocumentReference("col", "doc", client=client), - exists=True, - read_time=timestamp_pb2.Timestamp(seconds=0, nanos=1), - update_time=timestamp_pb2.Timestamp(seconds=0, nanos=1), - create_time=timestamp_pb2.Timestamp(seconds=0, nanos=1), - ) - self.assertIsInstance(document_snapshot_to_protobuf(snapshot), Document) - - def test_non_existant_snapshot(self): - from google.cloud.firestore_v1._helpers import document_snapshot_to_protobuf - from google.cloud.firestore_v1.base_document import DocumentSnapshot - from google.cloud.firestore_v1.document import DocumentReference - - client = _make_client() - snapshot = DocumentSnapshot( - data=None, - reference=DocumentReference("col", "doc", client=client), - exists=False, - read_time=None, - update_time=None, - create_time=None, - ) - 
self.assertIsNone(document_snapshot_to_protobuf(snapshot)) +def test_verify_path_w_success_collection(): + from google.cloud.firestore_v1._helpers import verify_path + + path = ("Computer", "Magic", "Win") + ret_val = verify_path(path, True) + # NOTE: We are just checking that it didn't fail. + assert ret_val is None + + +def test_verify_path_w_success_document(): + from google.cloud.firestore_v1._helpers import verify_path + + path = ("Tokenizer", "Seventeen", "Cheese", "Burger") + ret_val = verify_path(path, False) + # NOTE: We are just checking that it didn't fail. + assert ret_val is None + + +def test_encode_value_w_none(): + from google.protobuf import struct_pb2 + from google.cloud.firestore_v1._helpers import encode_value + + result = encode_value(None) + expected = _value_pb(null_value=struct_pb2.NULL_VALUE) + assert result == expected + + +def test_encode_value_w_boolean(): + from google.cloud.firestore_v1._helpers import encode_value + + result = encode_value(True) + expected = _value_pb(boolean_value=True) + assert result == expected + + +def test_encode_value_w_integer(): + from google.cloud.firestore_v1._helpers import encode_value + + value = 425178 + result = encode_value(value) + expected = _value_pb(integer_value=value) + assert result == expected + + +def test_encode_value_w_float(): + from google.cloud.firestore_v1._helpers import encode_value + + value = 123.4453125 + result = encode_value(value) + expected = _value_pb(double_value=value) + assert result == expected + + +def test_encode_value_w_datetime_with_nanos(): + from google.api_core.datetime_helpers import DatetimeWithNanoseconds + from google.cloud.firestore_v1._helpers import encode_value + from google.protobuf import timestamp_pb2 + + dt_seconds = 1488768504 + dt_nanos = 458816991 + timestamp_pb = timestamp_pb2.Timestamp(seconds=dt_seconds, nanos=dt_nanos) + dt_val = DatetimeWithNanoseconds.from_timestamp_pb(timestamp_pb) -class Test_decode_value(unittest.TestCase): - @staticmethod - 
def _call_fut(value, client=mock.sentinel.client): - from google.cloud.firestore_v1._helpers import decode_value + result = encode_value(dt_val) + expected = _value_pb(timestamp_value=timestamp_pb) + assert result == expected - return decode_value(value, client) - def test_none(self): - from google.protobuf import struct_pb2 +def test_encode_value_w_datetime_wo_nanos(): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1._helpers import encode_value - value = _value_pb(null_value=struct_pb2.NULL_VALUE) - self.assertIsNone(self._call_fut(value)) + dt_seconds = 1488768504 + dt_nanos = 458816000 + # Make sure precision is valid in microseconds too. + assert dt_nanos % 1000 == 0 + dt_val = datetime.datetime.utcfromtimestamp(dt_seconds + 1e-9 * dt_nanos) - def test_bool(self): - value1 = _value_pb(boolean_value=True) - self.assertTrue(self._call_fut(value1)) - value2 = _value_pb(boolean_value=False) - self.assertFalse(self._call_fut(value2)) + result = encode_value(dt_val) + timestamp_pb = timestamp_pb2.Timestamp(seconds=dt_seconds, nanos=dt_nanos) + expected = _value_pb(timestamp_value=timestamp_pb) + assert result == expected - def test_int(self): - int_val = 29871 - value = _value_pb(integer_value=int_val) - self.assertEqual(self._call_fut(value), int_val) - def test_float(self): - float_val = 85.9296875 - value = _value_pb(double_value=float_val) - self.assertEqual(self._call_fut(value), float_val) +def test_encode_value_w_string(): + from google.cloud.firestore_v1._helpers import encode_value - def test_datetime(self): - from google.api_core.datetime_helpers import DatetimeWithNanoseconds - from google.protobuf import timestamp_pb2 + value = u"\u2018left quote, right quote\u2019" + result = encode_value(value) + expected = _value_pb(string_value=value) + assert result == expected - dt_seconds = 552855006 - dt_nanos = 766961828 - timestamp_pb = timestamp_pb2.Timestamp(seconds=dt_seconds, nanos=dt_nanos) - value = 
_value_pb(timestamp_value=timestamp_pb) +def test_encode_value_w_bytes(): + from google.cloud.firestore_v1._helpers import encode_value - expected_dt_val = DatetimeWithNanoseconds.from_timestamp_pb(timestamp_pb) - self.assertEqual(self._call_fut(value), expected_dt_val) + value = b"\xe3\xf2\xff\x00" + result = encode_value(value) + expected = _value_pb(bytes_value=value) + assert result == expected - def test_unicode(self): - unicode_val = u"zorgon" - value = _value_pb(string_value=unicode_val) - self.assertEqual(self._call_fut(value), unicode_val) - def test_bytes(self): - bytes_val = b"abc\x80" - value = _value_pb(bytes_value=bytes_val) - self.assertEqual(self._call_fut(value), bytes_val) +def test_encode_value_w_reference_value(): + from google.cloud.firestore_v1._helpers import encode_value - def test_reference(self): - from google.cloud.firestore_v1.document import DocumentReference + client = _make_client() - client = _make_client() - path = (u"then", u"there-was-one") - document = client.document(*path) - ref_string = document._document_path - value = _value_pb(reference_value=ref_string) + value = client.document("my", "friend") + result = encode_value(value) + expected = _value_pb(reference_value=value._document_path) + assert result == expected - result = self._call_fut(value, client) - self.assertIsInstance(result, DocumentReference) - self.assertIs(result._client, client) - self.assertEqual(result._path, path) - def test_geo_point(self): - from google.cloud.firestore_v1._helpers import GeoPoint +def test_encode_value_w_geo_point(): + from google.cloud.firestore_v1._helpers import encode_value + from google.cloud.firestore_v1._helpers import GeoPoint - geo_pt = GeoPoint(latitude=42.5, longitude=99.0625) - value = _value_pb(geo_point_value=geo_pt.to_protobuf()) - self.assertEqual(self._call_fut(value), geo_pt) + value = GeoPoint(50.5, 88.75) + result = encode_value(value) + expected = _value_pb(geo_point_value=value.to_protobuf()) + assert result == 
expected - def test_array(self): - from google.cloud.firestore_v1.types import document - sub_value1 = _value_pb(boolean_value=True) - sub_value2 = _value_pb(double_value=14.1396484375) - sub_value3 = _value_pb(bytes_value=b"\xde\xad\xbe\xef") - array_pb = document.ArrayValue(values=[sub_value1, sub_value2, sub_value3]) - value = _value_pb(array_value=array_pb) +def test_encode_value_w_array(): + from google.cloud.firestore_v1._helpers import encode_value + from google.cloud.firestore_v1.types.document import ArrayValue - expected = [ - sub_value1.boolean_value, - sub_value2.double_value, - sub_value3.bytes_value, + result = encode_value([99, True, 118.5]) + + array_pb = ArrayValue( + values=[ + _value_pb(integer_value=99), + _value_pb(boolean_value=True), + _value_pb(double_value=118.5), ] - self.assertEqual(self._call_fut(value), expected) + ) + expected = _value_pb(array_value=array_pb) + assert result == expected + - def test_map(self): - from google.cloud.firestore_v1.types import document +def test_encode_value_w_map(): + from google.cloud.firestore_v1._helpers import encode_value + from google.cloud.firestore_v1.types.document import MapValue - sub_value1 = _value_pb(integer_value=187680) - sub_value2 = _value_pb(string_value=u"how low can you go?") - map_pb = document.MapValue(fields={"first": sub_value1, "second": sub_value2}) - value = _value_pb(map_value=map_pb) + result = encode_value({"abc": 285, "def": b"piglatin"}) - expected = { - "first": sub_value1.integer_value, - "second": sub_value2.string_value, + map_pb = MapValue( + fields={ + "abc": _value_pb(integer_value=285), + "def": _value_pb(bytes_value=b"piglatin"), } - self.assertEqual(self._call_fut(value), expected) - - def test_nested_map(self): - from google.cloud.firestore_v1.types import document - - actual_value1 = 1009876 - actual_value2 = u"hey you guys" - actual_value3 = 90.875 - map_pb1 = document.MapValue( - fields={ - "lowest": _value_pb(integer_value=actual_value1), - "aside": 
_value_pb(string_value=actual_value2), - } - ) - map_pb2 = document.MapValue( - fields={ - "middle": _value_pb(map_value=map_pb1), - "aside": _value_pb(boolean_value=True), - } - ) - map_pb3 = document.MapValue( - fields={ - "highest": _value_pb(map_value=map_pb2), - "aside": _value_pb(double_value=actual_value3), - } - ) - value = _value_pb(map_value=map_pb3) - - expected = { - "highest": { - "middle": {"lowest": actual_value1, "aside": actual_value2}, - "aside": True, - }, - "aside": actual_value3, + ) + expected = _value_pb(map_value=map_pb) + assert result == expected + + +def test_encode_value_w_bad_type(): + from google.cloud.firestore_v1._helpers import encode_value + + value = object() + with pytest.raises(TypeError): + encode_value(value) + + +def test_encode_dict_w_many_types(): + from google.protobuf import struct_pb2 + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1._helpers import encode_dict + from google.cloud.firestore_v1.types.document import ArrayValue + from google.cloud.firestore_v1.types.document import MapValue + + dt_seconds = 1497397225 + dt_nanos = 465964000 + # Make sure precision is valid in microseconds too. 
+ assert dt_nanos % 1000 == 0 + dt_val = datetime.datetime.utcfromtimestamp(dt_seconds + 1e-9 * dt_nanos) + + client = _make_client() + document = client.document("most", "adjective", "thing", "here") + + values_dict = { + "foo": None, + "bar": True, + "baz": 981, + "quux": 2.875, + "quuz": dt_val, + "corge": u"\N{snowman}", + "grault": b"\xe2\x98\x83", + "wibble": document, + "garply": [u"fork", 4.0], + "waldo": {"fred": u"zap", "thud": False}, + } + encoded_dict = encode_dict(values_dict) + expected_dict = { + "foo": _value_pb(null_value=struct_pb2.NULL_VALUE), + "bar": _value_pb(boolean_value=True), + "baz": _value_pb(integer_value=981), + "quux": _value_pb(double_value=2.875), + "quuz": _value_pb( + timestamp_value=timestamp_pb2.Timestamp(seconds=dt_seconds, nanos=dt_nanos) + ), + "corge": _value_pb(string_value=u"\N{snowman}"), + "grault": _value_pb(bytes_value=b"\xe2\x98\x83"), + "wibble": _value_pb(reference_value=document._document_path), + "garply": _value_pb( + array_value=ArrayValue( + values=[_value_pb(string_value=u"fork"), _value_pb(double_value=4.0)] + ) + ), + "waldo": _value_pb( + map_value=MapValue( + fields={ + "fred": _value_pb(string_value=u"zap"), + "thud": _value_pb(boolean_value=False), + } + ) + ), + } + assert encoded_dict == expected_dict + + +def test_reference_value_to_document_w_bad_format(): + from google.cloud.firestore_v1._helpers import BAD_REFERENCE_ERROR + from google.cloud.firestore_v1._helpers import reference_value_to_document + + reference_value = "not/the/right/format" + with pytest.raises(ValueError) as exc_info: + reference_value_to_document(reference_value, None) + + err_msg = BAD_REFERENCE_ERROR.format(reference_value) + assert exc_info.value.args == (err_msg,) + + +def test_reference_value_to_document_w_same_client(): + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1._helpers import reference_value_to_document + + client = _make_client() + document = 
client.document("that", "this") + reference_value = document._document_path + + new_document = reference_value_to_document(reference_value, client) + + assert new_document is not document + assert isinstance(new_document, DocumentReference) + assert new_document._client is client + assert new_document._path == document._path + + +def test_reference_value_to_document_w_different_client(): + from google.cloud.firestore_v1._helpers import WRONG_APP_REFERENCE + from google.cloud.firestore_v1._helpers import reference_value_to_document + + client1 = _make_client(project="kirk") + document = client1.document("tin", "foil") + reference_value = document._document_path + client2 = _make_client(project="spock") + + with pytest.raises(ValueError) as exc_info: + reference_value_to_document(reference_value, client2) + + err_msg = WRONG_APP_REFERENCE.format(reference_value, client2._database_string) + assert exc_info.value.args == (err_msg,) + + +def test_documentreferencevalue_w_normal(): + from google.cloud.firestore_v1._helpers import DocumentReferenceValue + + orig = "projects/name/databases/(default)/documents/col/doc" + parsed = DocumentReferenceValue(orig) + assert parsed.collection_name == "col" + assert parsed.database_name == "(default)" + assert parsed.document_id == "doc" + + assert parsed.full_path == orig + parsed._reference_value = None # type: ignore + assert parsed.full_path == orig + + +def test_documentreferencevalue_w_nested(): + from google.cloud.firestore_v1._helpers import DocumentReferenceValue + + parsed = DocumentReferenceValue( + "projects/name/databases/(default)/documents/col/doc/nested" + ) + assert parsed.collection_name == "col" + assert parsed.database_name == "(default)" + assert parsed.document_id == "doc/nested" + + +def test_documentreferencevalue_w_broken(): + from google.cloud.firestore_v1._helpers import DocumentReferenceValue + + with pytest.raises(ValueError): + DocumentReferenceValue("projects/name/databases/(default)/documents/col") + 
+ +def test_document_snapshot_to_protobuf_w_real_snapshot(): + from google.cloud.firestore_v1._helpers import document_snapshot_to_protobuf + from google.cloud.firestore_v1.types import Document + from google.cloud.firestore_v1.base_document import DocumentSnapshot + from google.cloud.firestore_v1.document import DocumentReference + from google.protobuf import timestamp_pb2 # type: ignore + + client = _make_client() + snapshot = DocumentSnapshot( + data={"hello": "world"}, + reference=DocumentReference("col", "doc", client=client), + exists=True, + read_time=timestamp_pb2.Timestamp(seconds=0, nanos=1), + update_time=timestamp_pb2.Timestamp(seconds=0, nanos=1), + create_time=timestamp_pb2.Timestamp(seconds=0, nanos=1), + ) + assert isinstance(document_snapshot_to_protobuf(snapshot), Document) + + +def test_document_snapshot_to_protobuf_w_non_existant_snapshot(): + from google.cloud.firestore_v1._helpers import document_snapshot_to_protobuf + from google.cloud.firestore_v1.base_document import DocumentSnapshot + from google.cloud.firestore_v1.document import DocumentReference + + client = _make_client() + snapshot = DocumentSnapshot( + data=None, + reference=DocumentReference("col", "doc", client=client), + exists=False, + read_time=None, + update_time=None, + create_time=None, + ) + assert document_snapshot_to_protobuf(snapshot) is None + + +def test_decode_value_w_none(): + from google.protobuf import struct_pb2 + from google.cloud.firestore_v1._helpers import decode_value + + value = _value_pb(null_value=struct_pb2.NULL_VALUE) + assert decode_value(value, mock.sentinel.client) is None + + +def test_decode_value_w_bool(): + from google.cloud.firestore_v1._helpers import decode_value + + value1 = _value_pb(boolean_value=True) + assert decode_value(value1, mock.sentinel.client) + value2 = _value_pb(boolean_value=False) + assert not decode_value(value2, mock.sentinel.client) + + +def test_decode_value_w_int(): + from google.cloud.firestore_v1._helpers import 
decode_value + + int_val = 29871 + value = _value_pb(integer_value=int_val) + assert decode_value(value, mock.sentinel.client) == int_val + + +def test_decode_value_w_float(): + from google.cloud.firestore_v1._helpers import decode_value + + float_val = 85.9296875 + value = _value_pb(double_value=float_val) + assert decode_value(value, mock.sentinel.client) == float_val + + +def test_decode_value_w_datetime(): + from google.cloud.firestore_v1._helpers import decode_value + from google.api_core.datetime_helpers import DatetimeWithNanoseconds + from google.protobuf import timestamp_pb2 + + dt_seconds = 552855006 + dt_nanos = 766961828 + + timestamp_pb = timestamp_pb2.Timestamp(seconds=dt_seconds, nanos=dt_nanos) + value = _value_pb(timestamp_value=timestamp_pb) + + expected_dt_val = DatetimeWithNanoseconds.from_timestamp_pb(timestamp_pb) + assert decode_value(value, mock.sentinel.client) == expected_dt_val + + +def test_decode_value_w_unicode(): + from google.cloud.firestore_v1._helpers import decode_value + + unicode_val = u"zorgon" + value = _value_pb(string_value=unicode_val) + assert decode_value(value, mock.sentinel.client) == unicode_val + + +def test_decode_value_w_bytes(): + from google.cloud.firestore_v1._helpers import decode_value + + bytes_val = b"abc\x80" + value = _value_pb(bytes_value=bytes_val) + assert decode_value(value, mock.sentinel.client) == bytes_val + + +def test_decode_value_w_reference(): + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1._helpers import decode_value + + client = _make_client() + path = (u"then", u"there-was-one") + document = client.document(*path) + ref_string = document._document_path + value = _value_pb(reference_value=ref_string) + + result = decode_value(value, client) + assert isinstance(result, DocumentReference) + assert result._client is client + assert result._path == path + + +def test_decode_value_w_geo_point(): + from google.cloud.firestore_v1._helpers import 
GeoPoint + from google.cloud.firestore_v1._helpers import decode_value + + geo_pt = GeoPoint(latitude=42.5, longitude=99.0625) + value = _value_pb(geo_point_value=geo_pt.to_protobuf()) + assert decode_value(value, mock.sentinel.client) == geo_pt + + +def test_decode_value_w_array(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1._helpers import decode_value + + sub_value1 = _value_pb(boolean_value=True) + sub_value2 = _value_pb(double_value=14.1396484375) + sub_value3 = _value_pb(bytes_value=b"\xde\xad\xbe\xef") + array_pb = document.ArrayValue(values=[sub_value1, sub_value2, sub_value3]) + value = _value_pb(array_value=array_pb) + + expected = [ + sub_value1.boolean_value, + sub_value2.double_value, + sub_value3.bytes_value, + ] + assert decode_value(value, mock.sentinel.client) == expected + + +def test_decode_value_w_map(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1._helpers import decode_value + + sub_value1 = _value_pb(integer_value=187680) + sub_value2 = _value_pb(string_value=u"how low can you go?") + map_pb = document.MapValue(fields={"first": sub_value1, "second": sub_value2}) + value = _value_pb(map_value=map_pb) + + expected = { + "first": sub_value1.integer_value, + "second": sub_value2.string_value, + } + assert decode_value(value, mock.sentinel.client) == expected + + +def test_decode_value_w_nested_map(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1._helpers import decode_value + + actual_value1 = 1009876 + actual_value2 = u"hey you guys" + actual_value3 = 90.875 + map_pb1 = document.MapValue( + fields={ + "lowest": _value_pb(integer_value=actual_value1), + "aside": _value_pb(string_value=actual_value2), } - self.assertEqual(self._call_fut(value), expected) - - def test_unset_value_type(self): - with self.assertRaises(ValueError): - self._call_fut(_value_pb()) - - def test_unknown_value_type(self): - value_pb = mock.Mock() 
- value_pb._pb.WhichOneof.return_value = "zoob_value" - - with self.assertRaises(ValueError): - self._call_fut(value_pb) - - value_pb._pb.WhichOneof.assert_called_once_with("value_type") - - -class Test_decode_dict(unittest.TestCase): - @staticmethod - def _call_fut(value_fields, client=mock.sentinel.client): - from google.cloud.firestore_v1._helpers import decode_dict - - return decode_dict(value_fields, client) - - def test_many_types(self): - from google.protobuf import struct_pb2 - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.types.document import ArrayValue - from google.cloud.firestore_v1.types.document import MapValue - from google.cloud._helpers import UTC - from google.cloud.firestore_v1.field_path import FieldPath - - dt_seconds = 1394037350 - dt_nanos = 667285000 - # Make sure precision is valid in microseconds too. - self.assertEqual(dt_nanos % 1000, 0) - dt_val = datetime.datetime.utcfromtimestamp( - dt_seconds + 1e-9 * dt_nanos - ).replace(tzinfo=UTC) - - value_fields = { - "foo": _value_pb(null_value=struct_pb2.NULL_VALUE), - "bar": _value_pb(boolean_value=True), - "baz": _value_pb(integer_value=981), - "quux": _value_pb(double_value=2.875), - "quuz": _value_pb( - timestamp_value=timestamp_pb2.Timestamp( - seconds=dt_seconds, nanos=dt_nanos - ) - ), - "corge": _value_pb(string_value=u"\N{snowman}"), - "grault": _value_pb(bytes_value=b"\xe2\x98\x83"), - "garply": _value_pb( - array_value=ArrayValue( - values=[ - _value_pb(string_value=u"fork"), - _value_pb(double_value=4.0), - ] - ) - ), - "waldo": _value_pb( - map_value=MapValue( - fields={ - "fred": _value_pb(string_value=u"zap"), - "thud": _value_pb(boolean_value=False), - } - ) - ), - FieldPath("a", "b", "c").to_api_repr(): _value_pb(boolean_value=False), + ) + map_pb2 = document.MapValue( + fields={ + "middle": _value_pb(map_value=map_pb1), + "aside": _value_pb(boolean_value=True), } - expected = { - "foo": None, - "bar": True, - "baz": 981, - "quux": 2.875, - 
"quuz": dt_val, - "corge": u"\N{snowman}", - "grault": b"\xe2\x98\x83", - "garply": [u"fork", 4.0], - "waldo": {"fred": u"zap", "thud": False}, - "a.b.c": False, + ) + map_pb3 = document.MapValue( + fields={ + "highest": _value_pb(map_value=map_pb2), + "aside": _value_pb(double_value=actual_value3), } - self.assertEqual(self._call_fut(value_fields), expected) + ) + value = _value_pb(map_value=map_pb3) + expected = { + "highest": { + "middle": {"lowest": actual_value1, "aside": actual_value2}, + "aside": True, + }, + "aside": actual_value3, + } + assert decode_value(value, mock.sentinel.client) == expected -class Test_get_doc_id(unittest.TestCase): - @staticmethod - def _call_fut(document_pb, expected_prefix): - from google.cloud.firestore_v1._helpers import get_doc_id - return get_doc_id(document_pb, expected_prefix) +def test_decode_value_w_unset_value_type(): + from google.cloud.firestore_v1._helpers import decode_value - @staticmethod - def _dummy_ref_string(collection_id): - from google.cloud.firestore_v1.client import DEFAULT_DATABASE + with pytest.raises(ValueError): + decode_value(_value_pb(), mock.sentinel.client) - project = u"bazzzz" - return u"projects/{}/databases/{}/documents/{}".format( - project, DEFAULT_DATABASE, collection_id - ) - def test_success(self): - from google.cloud.firestore_v1.types import document +def test_decode_value_w_unknown_value_type(): + from google.cloud.firestore_v1._helpers import decode_value - prefix = self._dummy_ref_string("sub-collection") - actual_id = "this-is-the-one" - name = "{}/{}".format(prefix, actual_id) + value_pb = mock.Mock() + value_pb._pb.WhichOneof.return_value = "zoob_value" - document_pb = document.Document(name=name) - document_id = self._call_fut(document_pb, prefix) - self.assertEqual(document_id, actual_id) + with pytest.raises(ValueError): + decode_value(value_pb, mock.sentinel.client) - def test_failure(self): - from google.cloud.firestore_v1.types import document + 
value_pb._pb.WhichOneof.assert_called_once_with("value_type") - actual_prefix = self._dummy_ref_string("the-right-one") - wrong_prefix = self._dummy_ref_string("the-wrong-one") - name = "{}/{}".format(actual_prefix, "sorry-wont-works") - document_pb = document.Document(name=name) - with self.assertRaises(ValueError) as exc_info: - self._call_fut(document_pb, wrong_prefix) +def test_decode_dict_w_many_types(): + from google.protobuf import struct_pb2 + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1.types.document import ArrayValue + from google.cloud.firestore_v1.types.document import MapValue + from google.cloud._helpers import UTC + from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1._helpers import decode_dict - exc_args = exc_info.exception.args - self.assertEqual(len(exc_args), 4) - self.assertEqual(exc_args[1], name) - self.assertEqual(exc_args[3], wrong_prefix) + dt_seconds = 1394037350 + dt_nanos = 667285000 + # Make sure precision is valid in microseconds too. 
+ assert dt_nanos % 1000 == 0 + dt_val = datetime.datetime.utcfromtimestamp(dt_seconds + 1e-9 * dt_nanos).replace( + tzinfo=UTC + ) + value_fields = { + "foo": _value_pb(null_value=struct_pb2.NULL_VALUE), + "bar": _value_pb(boolean_value=True), + "baz": _value_pb(integer_value=981), + "quux": _value_pb(double_value=2.875), + "quuz": _value_pb( + timestamp_value=timestamp_pb2.Timestamp(seconds=dt_seconds, nanos=dt_nanos) + ), + "corge": _value_pb(string_value=u"\N{snowman}"), + "grault": _value_pb(bytes_value=b"\xe2\x98\x83"), + "garply": _value_pb( + array_value=ArrayValue( + values=[_value_pb(string_value=u"fork"), _value_pb(double_value=4.0)] + ) + ), + "waldo": _value_pb( + map_value=MapValue( + fields={ + "fred": _value_pb(string_value=u"zap"), + "thud": _value_pb(boolean_value=False), + } + ) + ), + FieldPath("a", "b", "c").to_api_repr(): _value_pb(boolean_value=False), + } + expected = { + "foo": None, + "bar": True, + "baz": 981, + "quux": 2.875, + "quuz": dt_val, + "corge": u"\N{snowman}", + "grault": b"\xe2\x98\x83", + "garply": [u"fork", 4.0], + "waldo": {"fred": u"zap", "thud": False}, + "a.b.c": False, + } + assert decode_dict(value_fields, mock.sentinel.client) == expected + + +def _dummy_ref_string(collection_id): + from google.cloud.firestore_v1.client import DEFAULT_DATABASE + + project = u"bazzzz" + return u"projects/{}/databases/{}/documents/{}".format( + project, DEFAULT_DATABASE, collection_id + ) -class Test_extract_fields(unittest.TestCase): - @staticmethod - def _call_fut(document_data, prefix_path, expand_dots=False): - from google.cloud.firestore_v1 import _helpers - return _helpers.extract_fields( - document_data, prefix_path, expand_dots=expand_dots - ) +def test_get_doc_id_w_success(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1._helpers import get_doc_id - def test_w_empty_document(self): - from google.cloud.firestore_v1._helpers import _EmptyDict + prefix = _dummy_ref_string("sub-collection") 
+ actual_id = "this-is-the-one" + name = "{}/{}".format(prefix, actual_id) - document_data = {} - prefix_path = _make_field_path() - expected = [(_make_field_path(), _EmptyDict)] + document_pb = document.Document(name=name) + document_id = get_doc_id(document_pb, prefix) + assert document_id == actual_id - iterator = self._call_fut(document_data, prefix_path) - self.assertEqual(list(iterator), expected) - def test_w_invalid_key_and_expand_dots(self): - document_data = {"b": 1, "a~d": 2, "c": 3} - prefix_path = _make_field_path() +def test_get_doc_id_w_failure(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1._helpers import get_doc_id - with self.assertRaises(ValueError): - list(self._call_fut(document_data, prefix_path, expand_dots=True)) + actual_prefix = _dummy_ref_string("the-right-one") + wrong_prefix = _dummy_ref_string("the-wrong-one") + name = "{}/{}".format(actual_prefix, "sorry-wont-works") - def test_w_shallow_keys(self): - document_data = {"b": 1, "a": 2, "c": 3} - prefix_path = _make_field_path() - expected = [ - (_make_field_path("a"), 2), - (_make_field_path("b"), 1), - (_make_field_path("c"), 3), - ] + document_pb = document.Document(name=name) + with pytest.raises(ValueError) as exc_info: + get_doc_id(document_pb, wrong_prefix) - iterator = self._call_fut(document_data, prefix_path) - self.assertEqual(list(iterator), expected) + exc_args = exc_info.value.args + assert len(exc_args) == 4 + assert exc_args[1] == name + assert exc_args[3] == wrong_prefix - def test_w_nested(self): - from google.cloud.firestore_v1._helpers import _EmptyDict - document_data = {"b": {"a": {"d": 4, "c": 3, "g": {}}, "e": 7}, "f": 5} - prefix_path = _make_field_path() - expected = [ - (_make_field_path("b", "a", "c"), 3), - (_make_field_path("b", "a", "d"), 4), - (_make_field_path("b", "a", "g"), _EmptyDict), - (_make_field_path("b", "e"), 7), - (_make_field_path("f"), 5), - ] +def test_extract_fields_w_empty_document(): + from 
google.cloud.firestore_v1._helpers import extract_fields + from google.cloud.firestore_v1._helpers import _EmptyDict - iterator = self._call_fut(document_data, prefix_path) - self.assertEqual(list(iterator), expected) + document_data = {} + prefix_path = _make_field_path() + expected = [(_make_field_path(), _EmptyDict)] - def test_w_expand_dotted(self): - from google.cloud.firestore_v1._helpers import _EmptyDict + iterator = extract_fields(document_data, prefix_path) + assert list(iterator) == expected - document_data = { - "b": {"a": {"d": 4, "c": 3, "g": {}, "k.l.m": 17}, "e": 7}, - "f": 5, - "h.i.j": 9, - } - prefix_path = _make_field_path() - expected = [ - (_make_field_path("b", "a", "c"), 3), - (_make_field_path("b", "a", "d"), 4), - (_make_field_path("b", "a", "g"), _EmptyDict), - (_make_field_path("b", "a", "k.l.m"), 17), - (_make_field_path("b", "e"), 7), - (_make_field_path("f"), 5), - (_make_field_path("h", "i", "j"), 9), - ] - iterator = self._call_fut(document_data, prefix_path, expand_dots=True) - self.assertEqual(list(iterator), expected) +def test_extract_fields_w_invalid_key_and_expand_dots(): + from google.cloud.firestore_v1._helpers import extract_fields + document_data = {"b": 1, "a~d": 2, "c": 3} + prefix_path = _make_field_path() -class Test_set_field_value(unittest.TestCase): - @staticmethod - def _call_fut(document_data, field_path, value): - from google.cloud.firestore_v1 import _helpers + with pytest.raises(ValueError): + list(extract_fields(document_data, prefix_path, expand_dots=True)) - return _helpers.set_field_value(document_data, field_path, value) - def test_normal_value_w_shallow(self): - document = {} - field_path = _make_field_path("a") - value = 3 +def test_extract_fields_w_shallow_keys(): + from google.cloud.firestore_v1._helpers import extract_fields - self._call_fut(document, field_path, value) + document_data = {"b": 1, "a": 2, "c": 3} + prefix_path = _make_field_path() + expected = [ + (_make_field_path("a"), 2), + 
(_make_field_path("b"), 1), + (_make_field_path("c"), 3), + ] - self.assertEqual(document, {"a": 3}) + iterator = extract_fields(document_data, prefix_path) + assert list(iterator) == expected - def test_normal_value_w_nested(self): - document = {} - field_path = _make_field_path("a", "b", "c") - value = 3 - self._call_fut(document, field_path, value) +def test_extract_fields_w_nested(): + from google.cloud.firestore_v1._helpers import _EmptyDict + from google.cloud.firestore_v1._helpers import extract_fields - self.assertEqual(document, {"a": {"b": {"c": 3}}}) + document_data = {"b": {"a": {"d": 4, "c": 3, "g": {}}, "e": 7}, "f": 5} + prefix_path = _make_field_path() + expected = [ + (_make_field_path("b", "a", "c"), 3), + (_make_field_path("b", "a", "d"), 4), + (_make_field_path("b", "a", "g"), _EmptyDict), + (_make_field_path("b", "e"), 7), + (_make_field_path("f"), 5), + ] - def test_empty_dict_w_shallow(self): - from google.cloud.firestore_v1._helpers import _EmptyDict + iterator = extract_fields(document_data, prefix_path) + assert list(iterator) == expected - document = {} - field_path = _make_field_path("a") - value = _EmptyDict - self._call_fut(document, field_path, value) +def test_extract_fields_w_expand_dotted(): + from google.cloud.firestore_v1._helpers import _EmptyDict + from google.cloud.firestore_v1._helpers import extract_fields - self.assertEqual(document, {"a": {}}) + document_data = { + "b": {"a": {"d": 4, "c": 3, "g": {}, "k.l.m": 17}, "e": 7}, + "f": 5, + "h.i.j": 9, + } + prefix_path = _make_field_path() + expected = [ + (_make_field_path("b", "a", "c"), 3), + (_make_field_path("b", "a", "d"), 4), + (_make_field_path("b", "a", "g"), _EmptyDict), + (_make_field_path("b", "a", "k.l.m"), 17), + (_make_field_path("b", "e"), 7), + (_make_field_path("f"), 5), + (_make_field_path("h", "i", "j"), 9), + ] - def test_empty_dict_w_nested(self): - from google.cloud.firestore_v1._helpers import _EmptyDict + iterator = extract_fields(document_data, 
prefix_path, expand_dots=True) + assert list(iterator) == expected - document = {} - field_path = _make_field_path("a", "b", "c") - value = _EmptyDict - self._call_fut(document, field_path, value) +def test_set_field_value_normal_value_w_shallow(): + from google.cloud.firestore_v1._helpers import set_field_value - self.assertEqual(document, {"a": {"b": {"c": {}}}}) + document = {} + field_path = _make_field_path("a") + value = 3 + set_field_value(document, field_path, value) -class Test_get_field_value(unittest.TestCase): - @staticmethod - def _call_fut(document_data, field_path): - from google.cloud.firestore_v1 import _helpers + assert document == {"a": 3} - return _helpers.get_field_value(document_data, field_path) - def test_w_empty_path(self): - document = {} +def test_set_field_value_normal_value_w_nested(): + from google.cloud.firestore_v1._helpers import set_field_value - with self.assertRaises(ValueError): - self._call_fut(document, _make_field_path()) + document = {} + field_path = _make_field_path("a", "b", "c") + value = 3 - def test_miss_shallow(self): - document = {} + set_field_value(document, field_path, value) - with self.assertRaises(KeyError): - self._call_fut(document, _make_field_path("nonesuch")) + assert document == {"a": {"b": {"c": 3}}} - def test_miss_nested(self): - document = {"a": {"b": {}}} - with self.assertRaises(KeyError): - self._call_fut(document, _make_field_path("a", "b", "c")) +def test_set_field_value_empty_dict_w_shallow(): + from google.cloud.firestore_v1._helpers import _EmptyDict + from google.cloud.firestore_v1._helpers import set_field_value - def test_hit_shallow(self): - document = {"a": 1} + document = {} + field_path = _make_field_path("a") + value = _EmptyDict - self.assertEqual(self._call_fut(document, _make_field_path("a")), 1) + set_field_value(document, field_path, value) - def test_hit_nested(self): - document = {"a": {"b": {"c": 1}}} + assert document == {"a": {}} - self.assertEqual(self._call_fut(document, 
_make_field_path("a", "b", "c")), 1) +def test_set_field_value_empty_dict_w_nested(): + from google.cloud.firestore_v1._helpers import _EmptyDict + from google.cloud.firestore_v1._helpers import set_field_value -class TestDocumentExtractor(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1 import _helpers + document = {} + field_path = _make_field_path("a", "b", "c") + value = _EmptyDict - return _helpers.DocumentExtractor + set_field_value(document, field_path, value) - def _make_one(self, document_data): - return self._get_target_class()(document_data) + assert document == {"a": {"b": {"c": {}}}} - def test_ctor_w_empty_document(self): - document_data = {} - inst = self._make_one(document_data) +def test__get_field_value_w_empty_path(): + from google.cloud.firestore_v1._helpers import get_field_value - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, {}) - self.assertTrue(inst.empty_document) - self.assertFalse(inst.has_transforms) - self.assertEqual(inst.transform_paths, []) + document = {} - def test_ctor_w_delete_field_shallow(self): - from google.cloud.firestore_v1.transforms import DELETE_FIELD + with pytest.raises(ValueError): + get_field_value(document, _make_field_path()) - document_data = {"a": DELETE_FIELD} - inst = self._make_one(document_data) +def test__get_field_value_miss_shallow(): + from google.cloud.firestore_v1._helpers import get_field_value - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, [_make_field_path("a")]) - 
self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertFalse(inst.has_transforms) - self.assertEqual(inst.transform_paths, []) - - def test_ctor_w_delete_field_nested(self): - from google.cloud.firestore_v1.transforms import DELETE_FIELD - - document_data = {"a": {"b": {"c": DELETE_FIELD}}} - - inst = self._make_one(document_data) - - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, [_make_field_path("a", "b", "c")]) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertFalse(inst.has_transforms) - self.assertEqual(inst.transform_paths, []) - - def test_ctor_w_server_timestamp_shallow(self): - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP - - document_data = {"a": SERVER_TIMESTAMP} - - inst = self._make_one(document_data) - - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, [_make_field_path("a")]) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertTrue(inst.has_transforms) - self.assertEqual(inst.transform_paths, 
[_make_field_path("a")]) - - def test_ctor_w_server_timestamp_nested(self): - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP - - document_data = {"a": {"b": {"c": SERVER_TIMESTAMP}}} - - inst = self._make_one(document_data) - - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, [_make_field_path("a", "b", "c")]) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertTrue(inst.has_transforms) - self.assertEqual(inst.transform_paths, [_make_field_path("a", "b", "c")]) - - def test_ctor_w_array_remove_shallow(self): - from google.cloud.firestore_v1.transforms import ArrayRemove - - values = [1, 3, 5] - document_data = {"a": ArrayRemove(values)} - - inst = self._make_one(document_data) - - expected_array_removes = {_make_field_path("a"): values} - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, expected_array_removes) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertTrue(inst.has_transforms) - self.assertEqual(inst.transform_paths, [_make_field_path("a")]) - - def test_ctor_w_array_remove_nested(self): - from google.cloud.firestore_v1.transforms import ArrayRemove - - values = [2, 4, 8] - document_data = {"a": {"b": {"c": ArrayRemove(values)}}} - - inst = self._make_one(document_data) - - expected_array_removes = 
{_make_field_path("a", "b", "c"): values} - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, expected_array_removes) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertTrue(inst.has_transforms) - self.assertEqual(inst.transform_paths, [_make_field_path("a", "b", "c")]) - - def test_ctor_w_array_union_shallow(self): - from google.cloud.firestore_v1.transforms import ArrayUnion - - values = [1, 3, 5] - document_data = {"a": ArrayUnion(values)} - - inst = self._make_one(document_data) - - expected_array_unions = {_make_field_path("a"): values} - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, expected_array_unions) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertTrue(inst.has_transforms) - self.assertEqual(inst.transform_paths, [_make_field_path("a")]) - - def test_ctor_w_array_union_nested(self): - from google.cloud.firestore_v1.transforms import ArrayUnion - - values = [2, 4, 8] - document_data = {"a": {"b": {"c": ArrayUnion(values)}}} - - inst = self._make_one(document_data) - - expected_array_unions = {_make_field_path("a", "b", "c"): values} - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, expected_array_unions) - 
self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertTrue(inst.has_transforms) - self.assertEqual(inst.transform_paths, [_make_field_path("a", "b", "c")]) - - def test_ctor_w_increment_shallow(self): - from google.cloud.firestore_v1.transforms import Increment - - value = 1 - document_data = {"a": Increment(value)} - - inst = self._make_one(document_data) - - expected_increments = {_make_field_path("a"): value} - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, expected_increments) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertTrue(inst.has_transforms) - self.assertEqual(inst.transform_paths, [_make_field_path("a")]) - - def test_ctor_w_increment_nested(self): - from google.cloud.firestore_v1.transforms import Increment - - value = 2 - document_data = {"a": {"b": {"c": Increment(value)}}} - - inst = self._make_one(document_data) - - expected_increments = {_make_field_path("a", "b", "c"): value} - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, expected_increments) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertTrue(inst.has_transforms) - 
self.assertEqual(inst.transform_paths, [_make_field_path("a", "b", "c")]) - - def test_ctor_w_maximum_shallow(self): - from google.cloud.firestore_v1.transforms import Maximum - - value = 1 - document_data = {"a": Maximum(value)} - - inst = self._make_one(document_data) - - expected_maximums = {_make_field_path("a"): value} - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, expected_maximums) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertTrue(inst.has_transforms) - self.assertEqual(inst.transform_paths, [_make_field_path("a")]) - - def test_ctor_w_maximum_nested(self): - from google.cloud.firestore_v1.transforms import Maximum - - value = 2 - document_data = {"a": {"b": {"c": Maximum(value)}}} - - inst = self._make_one(document_data) - - expected_maximums = {_make_field_path("a", "b", "c"): value} - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, expected_maximums) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertTrue(inst.has_transforms) - self.assertEqual(inst.transform_paths, [_make_field_path("a", "b", "c")]) - - def test_ctor_w_minimum_shallow(self): - from google.cloud.firestore_v1.transforms import Minimum - - value = 1 - document_data = {"a": Minimum(value)} - - inst = self._make_one(document_data) - - 
expected_minimums = {_make_field_path("a"): value} - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, expected_minimums) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertTrue(inst.has_transforms) - self.assertEqual(inst.transform_paths, [_make_field_path("a")]) - - def test_ctor_w_minimum_nested(self): - from google.cloud.firestore_v1.transforms import Minimum - - value = 2 - document_data = {"a": {"b": {"c": Minimum(value)}}} - - inst = self._make_one(document_data) - - expected_minimums = {_make_field_path("a", "b", "c"): value} - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, []) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, expected_minimums) - self.assertEqual(inst.set_fields, {}) - self.assertFalse(inst.empty_document) - self.assertTrue(inst.has_transforms) - self.assertEqual(inst.transform_paths, [_make_field_path("a", "b", "c")]) - - def test_ctor_w_empty_dict_shallow(self): - document_data = {"a": {}} - - inst = self._make_one(document_data) - - expected_field_paths = [_make_field_path("a")] - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, expected_field_paths) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - 
self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, document_data) - self.assertFalse(inst.empty_document) - self.assertFalse(inst.has_transforms) - self.assertEqual(inst.transform_paths, []) - - def test_ctor_w_empty_dict_nested(self): - document_data = {"a": {"b": {"c": {}}}} - - inst = self._make_one(document_data) - - expected_field_paths = [_make_field_path("a", "b", "c")] - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, expected_field_paths) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, document_data) - self.assertFalse(inst.empty_document) - self.assertFalse(inst.has_transforms) - self.assertEqual(inst.transform_paths, []) - - def test_ctor_w_normal_value_shallow(self): - document_data = {"b": 1, "a": 2, "c": 3} - - inst = self._make_one(document_data) - - expected_field_paths = [ - _make_field_path("a"), - _make_field_path("b"), - _make_field_path("c"), - ] - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, expected_field_paths) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.set_fields, document_data) - self.assertFalse(inst.empty_document) - self.assertFalse(inst.has_transforms) - - def test_ctor_w_normal_value_nested(self): - document_data = {"b": {"a": {"d": 4, "c": 3}, "e": 7}, "f": 5} - - inst = self._make_one(document_data) - - expected_field_paths = [ - _make_field_path("b", "a", "c"), - _make_field_path("b", "a", "d"), - _make_field_path("b", 
"e"), - _make_field_path("f"), - ] - self.assertEqual(inst.document_data, document_data) - self.assertEqual(inst.field_paths, expected_field_paths) - self.assertEqual(inst.deleted_fields, []) - self.assertEqual(inst.server_timestamps, []) - self.assertEqual(inst.array_removes, {}) - self.assertEqual(inst.array_unions, {}) - self.assertEqual(inst.increments, {}) - self.assertEqual(inst.maximums, {}) - self.assertEqual(inst.minimums, {}) - self.assertEqual(inst.set_fields, document_data) - self.assertFalse(inst.empty_document) - self.assertFalse(inst.has_transforms) - - def test_get_update_pb_w_exists_precondition(self): - from google.cloud.firestore_v1.types import write - - document_data = {} - inst = self._make_one(document_data) - document_path = ( - "projects/project-id/databases/(default)/" "documents/document-id" - ) + document = {} - update_pb = inst.get_update_pb(document_path, exists=False) + with pytest.raises(KeyError): + get_field_value(document, _make_field_path("nonesuch")) - self.assertIsInstance(update_pb, write.Write) - self.assertEqual(update_pb.update.name, document_path) - self.assertEqual(update_pb.update.fields, document_data) - self.assertTrue(update_pb._pb.HasField("current_document")) - self.assertFalse(update_pb.current_document.exists) - def test_get_update_pb_wo_exists_precondition(self): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1._helpers import encode_dict +def test__get_field_value_miss_nested(): + from google.cloud.firestore_v1._helpers import get_field_value - document_data = {"a": 1} - inst = self._make_one(document_data) - document_path = ( - "projects/project-id/databases/(default)/" "documents/document-id" - ) + document = {"a": {"b": {}}} - update_pb = inst.get_update_pb(document_path) + with pytest.raises(KeyError): + get_field_value(document, _make_field_path("a", "b", "c")) - self.assertIsInstance(update_pb, write.Write) - self.assertEqual(update_pb.update.name, document_path) - 
self.assertEqual(update_pb.update.fields, encode_dict(document_data)) - self.assertFalse(update_pb._pb.HasField("current_document")) - def test_get_field_transform_pbs_miss(self): - document_data = {"a": 1} - inst = self._make_one(document_data) - document_path = ( - "projects/project-id/databases/(default)/" "documents/document-id" - ) +def test__get_field_value_hit_shallow(): + from google.cloud.firestore_v1._helpers import get_field_value - field_transform_pbs = inst.get_field_transform_pbs(document_path) + document = {"a": 1} - self.assertEqual(field_transform_pbs, []) + assert get_field_value(document, _make_field_path("a")) == 1 - def test_get_field_transform_pbs_w_server_timestamp(self): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP - from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM - document_data = {"a": SERVER_TIMESTAMP} - inst = self._make_one(document_data) - document_path = ( - "projects/project-id/databases/(default)/" "documents/document-id" - ) +def test__get_field_value_hit_nested(): + from google.cloud.firestore_v1._helpers import get_field_value - field_transform_pbs = inst.get_field_transform_pbs(document_path) + document = {"a": {"b": {"c": 1}}} - self.assertEqual(len(field_transform_pbs), 1) - field_transform_pb = field_transform_pbs[0] - self.assertIsInstance( - field_transform_pb, write.DocumentTransform.FieldTransform - ) - self.assertEqual(field_transform_pb.field_path, "a") - self.assertEqual(field_transform_pb.set_to_server_value, REQUEST_TIME_ENUM) - - def test_get_transform_pb_w_server_timestamp_w_exists_precondition(self): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP - from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM - - document_data = {"a": SERVER_TIMESTAMP} - inst = self._make_one(document_data) - document_path = ( - 
"projects/project-id/databases/(default)/" "documents/document-id" - ) + assert get_field_value(document, _make_field_path("a", "b", "c")) == 1 - transform_pb = inst.get_transform_pb(document_path, exists=False) - - self.assertIsInstance(transform_pb, write.Write) - self.assertEqual(transform_pb.transform.document, document_path) - transforms = transform_pb.transform.field_transforms - self.assertEqual(len(transforms), 1) - transform = transforms[0] - self.assertEqual(transform.field_path, "a") - self.assertEqual(transform.set_to_server_value, REQUEST_TIME_ENUM) - self.assertTrue(transform_pb._pb.HasField("current_document")) - self.assertFalse(transform_pb.current_document.exists) - - def test_get_transform_pb_w_server_timestamp_wo_exists_precondition(self): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP - from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM - - document_data = {"a": {"b": {"c": SERVER_TIMESTAMP}}} - inst = self._make_one(document_data) - document_path = ( - "projects/project-id/databases/(default)/" "documents/document-id" - ) - transform_pb = inst.get_transform_pb(document_path) +def _make_document_extractor(document_data): + from google.cloud.firestore_v1._helpers import DocumentExtractor - self.assertIsInstance(transform_pb, write.Write) - self.assertEqual(transform_pb.transform.document, document_path) - transforms = transform_pb.transform.field_transforms - self.assertEqual(len(transforms), 1) - transform = transforms[0] - self.assertEqual(transform.field_path, "a.b.c") - self.assertEqual(transform.set_to_server_value, REQUEST_TIME_ENUM) - self.assertFalse(transform_pb._pb.HasField("current_document")) + return DocumentExtractor(document_data) - @staticmethod - def _array_value_to_list(array_value): - from google.cloud.firestore_v1._helpers import decode_value - return [decode_value(element, client=None) for element in array_value.values] +def 
test_documentextractor_ctor_w_empty_document(): + document_data = {} - def test_get_transform_pb_w_array_remove(self): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import ArrayRemove + inst = _make_document_extractor(document_data) - values = [2, 4, 8] - document_data = {"a": {"b": {"c": ArrayRemove(values)}}} - inst = self._make_one(document_data) - document_path = ( - "projects/project-id/databases/(default)/" "documents/document-id" - ) + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == {} + assert inst.minimums == {} + assert inst.set_fields == {} + assert inst.empty_document + assert not inst.has_transforms + assert inst.transform_paths == [] - transform_pb = inst.get_transform_pb(document_path) - - self.assertIsInstance(transform_pb, write.Write) - self.assertEqual(transform_pb.transform.document, document_path) - transforms = transform_pb.transform.field_transforms - self.assertEqual(len(transforms), 1) - transform = transforms[0] - self.assertEqual(transform.field_path, "a.b.c") - removed = self._array_value_to_list(transform.remove_all_from_array) - self.assertEqual(removed, values) - self.assertFalse(transform_pb._pb.HasField("current_document")) - - def test_get_transform_pb_w_array_union(self): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import ArrayUnion - - values = [1, 3, 5] - document_data = {"a": {"b": {"c": ArrayUnion(values)}}} - inst = self._make_one(document_data) - document_path = ( - "projects/project-id/databases/(default)/" "documents/document-id" - ) - transform_pb = inst.get_transform_pb(document_path) - - self.assertIsInstance(transform_pb, write.Write) - self.assertEqual(transform_pb.transform.document, 
document_path) - transforms = transform_pb.transform.field_transforms - self.assertEqual(len(transforms), 1) - transform = transforms[0] - self.assertEqual(transform.field_path, "a.b.c") - added = self._array_value_to_list(transform.append_missing_elements) - self.assertEqual(added, values) - self.assertFalse(transform_pb._pb.HasField("current_document")) - - def test_get_transform_pb_w_increment_int(self): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import Increment - - value = 1 - document_data = {"a": {"b": {"c": Increment(value)}}} - inst = self._make_one(document_data) - document_path = ( - "projects/project-id/databases/(default)/" "documents/document-id" - ) +def test_documentextractor_ctor_w_delete_field_shallow(): + from google.cloud.firestore_v1.transforms import DELETE_FIELD - transform_pb = inst.get_transform_pb(document_path) - - self.assertIsInstance(transform_pb, write.Write) - self.assertEqual(transform_pb.transform.document, document_path) - transforms = transform_pb.transform.field_transforms - self.assertEqual(len(transforms), 1) - transform = transforms[0] - self.assertEqual(transform.field_path, "a.b.c") - added = transform.increment.integer_value - self.assertEqual(added, value) - self.assertFalse(transform_pb._pb.HasField("current_document")) - - def test_get_transform_pb_w_increment_float(self): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import Increment - - value = 3.1415926 - document_data = {"a": {"b": {"c": Increment(value)}}} - inst = self._make_one(document_data) - document_path = ( - "projects/project-id/databases/(default)/" "documents/document-id" - ) + document_data = {"a": DELETE_FIELD} - transform_pb = inst.get_transform_pb(document_path) - - self.assertIsInstance(transform_pb, write.Write) - self.assertEqual(transform_pb.transform.document, document_path) - transforms = transform_pb.transform.field_transforms - 
self.assertEqual(len(transforms), 1) - transform = transforms[0] - self.assertEqual(transform.field_path, "a.b.c") - added = transform.increment.double_value - self.assertEqual(added, value) - self.assertFalse(transform_pb._pb.HasField("current_document")) - - def test_get_transform_pb_w_maximum_int(self): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import Maximum - - value = 1 - document_data = {"a": {"b": {"c": Maximum(value)}}} - inst = self._make_one(document_data) - document_path = ( - "projects/project-id/databases/(default)/" "documents/document-id" - ) + inst = _make_document_extractor(document_data) - transform_pb = inst.get_transform_pb(document_path) - - self.assertIsInstance(transform_pb, write.Write) - self.assertEqual(transform_pb.transform.document, document_path) - transforms = transform_pb.transform.field_transforms - self.assertEqual(len(transforms), 1) - transform = transforms[0] - self.assertEqual(transform.field_path, "a.b.c") - added = transform.maximum.integer_value - self.assertEqual(added, value) - self.assertFalse(transform_pb._pb.HasField("current_document")) - - def test_get_transform_pb_w_maximum_float(self): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import Maximum - - value = 3.1415926 - document_data = {"a": {"b": {"c": Maximum(value)}}} - inst = self._make_one(document_data) - document_path = ( - "projects/project-id/databases/(default)/" "documents/document-id" - ) + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [_make_field_path("a")] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == {} + assert inst.minimums == {} + assert inst.set_fields == {} + assert not inst.empty_document + assert not inst.has_transforms + assert inst.transform_paths == [] - 
transform_pb = inst.get_transform_pb(document_path) - - self.assertIsInstance(transform_pb, write.Write) - self.assertEqual(transform_pb.transform.document, document_path) - transforms = transform_pb.transform.field_transforms - self.assertEqual(len(transforms), 1) - transform = transforms[0] - self.assertEqual(transform.field_path, "a.b.c") - added = transform.maximum.double_value - self.assertEqual(added, value) - self.assertFalse(transform_pb._pb.HasField("current_document")) - - def test_get_transform_pb_w_minimum_int(self): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import Minimum - - value = 1 - document_data = {"a": {"b": {"c": Minimum(value)}}} - inst = self._make_one(document_data) - document_path = ( - "projects/project-id/databases/(default)/" "documents/document-id" - ) - transform_pb = inst.get_transform_pb(document_path) - - self.assertIsInstance(transform_pb, write.Write) - self.assertEqual(transform_pb.transform.document, document_path) - transforms = transform_pb.transform.field_transforms - self.assertEqual(len(transforms), 1) - transform = transforms[0] - self.assertEqual(transform.field_path, "a.b.c") - added = transform.minimum.integer_value - self.assertEqual(added, value) - self.assertFalse(transform_pb._pb.HasField("current_document")) - - def test_get_transform_pb_w_minimum_float(self): - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1.transforms import Minimum - - value = 3.1415926 - document_data = {"a": {"b": {"c": Minimum(value)}}} - inst = self._make_one(document_data) - document_path = ( - "projects/project-id/databases/(default)/" "documents/document-id" +def test_documentextractor_ctor_w_delete_field_nested(): + from google.cloud.firestore_v1.transforms import DELETE_FIELD + + document_data = {"a": {"b": {"c": DELETE_FIELD}}} + + inst = _make_document_extractor(document_data) + + assert inst.document_data == document_data + assert 
inst.field_paths == [] + assert inst.deleted_fields == [_make_field_path("a", "b", "c")] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == {} + assert inst.minimums == {} + assert inst.set_fields == {} + assert not inst.empty_document + assert not inst.has_transforms + assert inst.transform_paths == [] + + +def test_documentextractor_ctor_w_server_timestamp_shallow(): + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + + document_data = {"a": SERVER_TIMESTAMP} + + inst = _make_document_extractor(document_data) + + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [] + assert inst.server_timestamps == [_make_field_path("a")] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == {} + assert inst.minimums == {} + assert inst.set_fields == {} + assert not inst.empty_document + assert inst.has_transforms + assert inst.transform_paths == [_make_field_path("a")] + + +def test_documentextractor_ctor_w_server_timestamp_nested(): + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + + document_data = {"a": {"b": {"c": SERVER_TIMESTAMP}}} + + inst = _make_document_extractor(document_data) + + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [] + assert inst.server_timestamps == [_make_field_path("a", "b", "c")] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == {} + assert inst.minimums == {} + assert inst.set_fields == {} + assert not inst.empty_document + assert inst.has_transforms + assert inst.transform_paths == [_make_field_path("a", "b", "c")] + + +def test_documentextractor_ctor_w_array_remove_shallow(): + from google.cloud.firestore_v1.transforms import ArrayRemove + + 
values = [1, 3, 5] + document_data = {"a": ArrayRemove(values)} + + inst = _make_document_extractor(document_data) + + expected_array_removes = {_make_field_path("a"): values} + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == expected_array_removes + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == {} + assert inst.minimums == {} + assert inst.set_fields == {} + assert not inst.empty_document + assert inst.has_transforms + assert inst.transform_paths == [_make_field_path("a")] + + +def test_documentextractor_ctor_w_array_remove_nested(): + from google.cloud.firestore_v1.transforms import ArrayRemove + + values = [2, 4, 8] + document_data = {"a": {"b": {"c": ArrayRemove(values)}}} + + inst = _make_document_extractor(document_data) + + expected_array_removes = {_make_field_path("a", "b", "c"): values} + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == expected_array_removes + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == {} + assert inst.minimums == {} + assert inst.set_fields == {} + assert not inst.empty_document + assert inst.has_transforms + assert inst.transform_paths == [_make_field_path("a", "b", "c")] + + +def test_documentextractor_ctor_w_array_union_shallow(): + from google.cloud.firestore_v1.transforms import ArrayUnion + + values = [1, 3, 5] + document_data = {"a": ArrayUnion(values)} + + inst = _make_document_extractor(document_data) + + expected_array_unions = {_make_field_path("a"): values} + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == 
expected_array_unions + assert inst.set_fields == {} + assert not inst.empty_document + assert inst.has_transforms + assert inst.transform_paths == [_make_field_path("a")] + + +def test_documentextractor__documentextractor_ctor_w_array_union_nested(): + from google.cloud.firestore_v1.transforms import ArrayUnion + + values = [2, 4, 8] + document_data = {"a": {"b": {"c": ArrayUnion(values)}}} + + inst = _make_document_extractor(document_data) + + expected_array_unions = {_make_field_path("a", "b", "c"): values} + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == expected_array_unions + assert inst.increments == {} + assert inst.maximums == {} + assert inst.minimums == {} + assert inst.set_fields == {} + assert not inst.empty_document + assert inst.has_transforms + assert inst.transform_paths == [_make_field_path("a", "b", "c")] + + +def test_documentextractor_ctor_w_increment_shallow(): + from google.cloud.firestore_v1.transforms import Increment + + value = 1 + document_data = {"a": Increment(value)} + + inst = _make_document_extractor(document_data) + + expected_increments = {_make_field_path("a"): value} + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == expected_increments + assert inst.maximums == {} + assert inst.minimums == {} + assert inst.set_fields == {} + assert not inst.empty_document + assert inst.has_transforms + assert inst.transform_paths == [_make_field_path("a")] + + +def test_documentextractor_ctor_w_increment_nested(): + from google.cloud.firestore_v1.transforms import Increment + + value = 2 + document_data = {"a": {"b": {"c": Increment(value)}}} + + inst = _make_document_extractor(document_data) 
+ + expected_increments = {_make_field_path("a", "b", "c"): value} + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == expected_increments + assert inst.maximums == {} + assert inst.minimums == {} + assert inst.set_fields == {} + assert not inst.empty_document + assert inst.has_transforms + assert inst.transform_paths == [_make_field_path("a", "b", "c")] + + +def test_documentextractor_ctor_w_maximum_shallow(): + from google.cloud.firestore_v1.transforms import Maximum + + value = 1 + document_data = {"a": Maximum(value)} + + inst = _make_document_extractor(document_data) + + expected_maximums = {_make_field_path("a"): value} + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == expected_maximums + assert inst.minimums == {} + assert inst.set_fields == {} + assert not inst.empty_document + assert inst.has_transforms + assert inst.transform_paths == [_make_field_path("a")] + + +def test_documentextractor_ctor_w_maximum_nested(): + from google.cloud.firestore_v1.transforms import Maximum + + value = 2 + document_data = {"a": {"b": {"c": Maximum(value)}}} + + inst = _make_document_extractor(document_data) + + expected_maximums = {_make_field_path("a", "b", "c"): value} + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == expected_maximums + assert inst.minimums == {} + assert inst.set_fields == {} + assert not inst.empty_document + assert 
inst.has_transforms + assert inst.transform_paths == [_make_field_path("a", "b", "c")] + + +def test_documentextractor_ctor_w_minimum_shallow(): + from google.cloud.firestore_v1.transforms import Minimum + + value = 1 + document_data = {"a": Minimum(value)} + + inst = _make_document_extractor(document_data) + + expected_minimums = {_make_field_path("a"): value} + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == {} + assert inst.minimums == expected_minimums + assert inst.set_fields == {} + assert not inst.empty_document + assert inst.has_transforms + assert inst.transform_paths == [_make_field_path("a")] + + +def test_documentextractor_ctor_w_minimum_nested(): + from google.cloud.firestore_v1.transforms import Minimum + + value = 2 + document_data = {"a": {"b": {"c": Minimum(value)}}} + + inst = _make_document_extractor(document_data) + + expected_minimums = {_make_field_path("a", "b", "c"): value} + assert inst.document_data == document_data + assert inst.field_paths == [] + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == {} + assert inst.minimums == expected_minimums + assert inst.set_fields == {} + assert not inst.empty_document + assert inst.has_transforms + assert inst.transform_paths == [_make_field_path("a", "b", "c")] + + +def test_documentextractor_ctor_w_empty_dict_shallow(): + document_data = {"a": {}} + + inst = _make_document_extractor(document_data) + + expected_field_paths = [_make_field_path("a")] + assert inst.document_data == document_data + assert inst.field_paths == expected_field_paths + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert 
inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == {} + assert inst.minimums == {} + assert inst.set_fields == document_data + assert not inst.empty_document + assert not inst.has_transforms + assert inst.transform_paths == [] + + +def test_documentextractor_ctor_w_empty_dict_nested(): + document_data = {"a": {"b": {"c": {}}}} + + inst = _make_document_extractor(document_data) + + expected_field_paths = [_make_field_path("a", "b", "c")] + assert inst.document_data == document_data + assert inst.field_paths == expected_field_paths + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == {} + assert inst.minimums == {} + assert inst.set_fields == document_data + assert not inst.empty_document + assert not inst.has_transforms + assert inst.transform_paths == [] + + +def test_documentextractor_ctor_w_normal_value_shallow(): + document_data = {"b": 1, "a": 2, "c": 3} + + inst = _make_document_extractor(document_data) + + expected_field_paths = [ + _make_field_path("a"), + _make_field_path("b"), + _make_field_path("c"), + ] + assert inst.document_data == document_data + assert inst.field_paths == expected_field_paths + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.set_fields == document_data + assert not inst.empty_document + assert not inst.has_transforms + + +def test_documentextractor_ctor_w_normal_value_nested(): + document_data = {"b": {"a": {"d": 4, "c": 3}, "e": 7}, "f": 5} + + inst = _make_document_extractor(document_data) + + expected_field_paths = [ + _make_field_path("b", "a", "c"), + _make_field_path("b", "a", "d"), + _make_field_path("b", "e"), + _make_field_path("f"), + ] + assert inst.document_data == document_data + assert inst.field_paths == 
expected_field_paths + assert inst.deleted_fields == [] + assert inst.server_timestamps == [] + assert inst.array_removes == {} + assert inst.array_unions == {} + assert inst.increments == {} + assert inst.maximums == {} + assert inst.minimums == {} + assert inst.set_fields == document_data + assert not inst.empty_document + assert not inst.has_transforms + + +def test_documentextractor_get_update_pb_w_exists_precondition(): + from google.cloud.firestore_v1.types import write + + document_data = {} + inst = _make_document_extractor(document_data) + document_path = "projects/project-id/databases/(default)/documents/document-id" + + update_pb = inst.get_update_pb(document_path, exists=False) + + assert isinstance(update_pb, write.Write) + assert update_pb.update.name == document_path + assert update_pb.update.fields == document_data + assert update_pb._pb.HasField("current_document") + assert not update_pb.current_document.exists + + +def test_documentextractor_get_update_pb_wo_exists_precondition(): + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1._helpers import encode_dict + + document_data = {"a": 1} + inst = _make_document_extractor(document_data) + document_path = "projects/project-id/databases/(default)/documents/document-id" + + update_pb = inst.get_update_pb(document_path) + + assert isinstance(update_pb, write.Write) + assert update_pb.update.name == document_path + assert update_pb.update.fields == encode_dict(document_data) + assert not update_pb._pb.HasField("current_document") + + +def test_documentextractor_get_field_transform_pbs_miss(): + document_data = {"a": 1} + inst = _make_document_extractor(document_data) + document_path = "projects/project-id/databases/(default)/documents/document-id" + + field_transform_pbs = inst.get_field_transform_pbs(document_path) + + assert field_transform_pbs == [] + + +def test_documentextractor_get_field_transform_pbs_w_server_timestamp(): + from google.cloud.firestore_v1.types 
import write + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM + + document_data = {"a": SERVER_TIMESTAMP} + inst = _make_document_extractor(document_data) + document_path = "projects/project-id/databases/(default)/documents/document-id" + + field_transform_pbs = inst.get_field_transform_pbs(document_path) + + assert len(field_transform_pbs) == 1 + field_transform_pb = field_transform_pbs[0] + assert isinstance(field_transform_pb, write.DocumentTransform.FieldTransform) + assert field_transform_pb.field_path == "a" + assert field_transform_pb.set_to_server_value == REQUEST_TIME_ENUM + + +def test_documentextractor_get_transform_pb_w_server_timestamp_w_exists_precondition(): + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM + + document_data = {"a": SERVER_TIMESTAMP} + inst = _make_document_extractor(document_data) + document_path = "projects/project-id/databases/(default)/documents/document-id" + + transform_pb = inst.get_transform_pb(document_path, exists=False) + + assert isinstance(transform_pb, write.Write) + assert transform_pb.transform.document == document_path + transforms = transform_pb.transform.field_transforms + assert len(transforms) == 1 + transform = transforms[0] + assert transform.field_path == "a" + assert transform.set_to_server_value == REQUEST_TIME_ENUM + assert transform_pb._pb.HasField("current_document") + assert not transform_pb.current_document.exists + + +def test_documentextractor_get_transform_pb_w_server_timestamp_wo_exists_precondition(): + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + from google.cloud.firestore_v1._helpers import REQUEST_TIME_ENUM + + document_data = {"a": {"b": {"c": SERVER_TIMESTAMP}}} + inst = 
_make_document_extractor(document_data) + document_path = "projects/project-id/databases/(default)/documents/document-id" + + transform_pb = inst.get_transform_pb(document_path) + + assert isinstance(transform_pb, write.Write) + assert transform_pb.transform.document == document_path + transforms = transform_pb.transform.field_transforms + assert len(transforms) == 1 + transform = transforms[0] + assert transform.field_path == "a.b.c" + assert transform.set_to_server_value == REQUEST_TIME_ENUM + assert not transform_pb._pb.HasField("current_document") + + +def _array_value_to_list(array_value): + from google.cloud.firestore_v1._helpers import decode_value + + return [decode_value(element, client=None) for element in array_value.values] + + +def test_documentextractor_get_transform_pb_w_array_remove(): + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.transforms import ArrayRemove + + values = [2, 4, 8] + document_data = {"a": {"b": {"c": ArrayRemove(values)}}} + inst = _make_document_extractor(document_data) + document_path = "projects/project-id/databases/(default)/documents/document-id" + + transform_pb = inst.get_transform_pb(document_path) + + assert isinstance(transform_pb, write.Write) + assert transform_pb.transform.document == document_path + transforms = transform_pb.transform.field_transforms + assert len(transforms) == 1 + transform = transforms[0] + assert transform.field_path == "a.b.c" + removed = _array_value_to_list(transform.remove_all_from_array) + assert removed == values + assert not transform_pb._pb.HasField("current_document") + + +def test_documentextractor_get_transform_pb_w_array_union(): + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.transforms import ArrayUnion + + values = [1, 3, 5] + document_data = {"a": {"b": {"c": ArrayUnion(values)}}} + inst = _make_document_extractor(document_data) + document_path = 
"projects/project-id/databases/(default)/documents/document-id" + + transform_pb = inst.get_transform_pb(document_path) + + assert isinstance(transform_pb, write.Write) + assert transform_pb.transform.document == document_path + transforms = transform_pb.transform.field_transforms + assert len(transforms) == 1 + transform = transforms[0] + assert transform.field_path == "a.b.c" + added = _array_value_to_list(transform.append_missing_elements) + assert added == values + assert not transform_pb._pb.HasField("current_document") + + +def test_documentextractor_get_transform_pb_w_increment_int(): + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.transforms import Increment + + value = 1 + document_data = {"a": {"b": {"c": Increment(value)}}} + inst = _make_document_extractor(document_data) + document_path = "projects/project-id/databases/(default)/documents/document-id" + + transform_pb = inst.get_transform_pb(document_path) + + assert isinstance(transform_pb, write.Write) + assert transform_pb.transform.document == document_path + transforms = transform_pb.transform.field_transforms + assert len(transforms) == 1 + transform = transforms[0] + assert transform.field_path == "a.b.c" + added = transform.increment.integer_value + assert added == value + assert not transform_pb._pb.HasField("current_document") + + +def test_documentextractor_get_transform_pb_w_increment_float(): + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.transforms import Increment + + value = 3.1415926 + document_data = {"a": {"b": {"c": Increment(value)}}} + inst = _make_document_extractor(document_data) + document_path = "projects/project-id/databases/(default)/documents/document-id" + + transform_pb = inst.get_transform_pb(document_path) + + assert isinstance(transform_pb, write.Write) + assert transform_pb.transform.document == document_path + transforms = transform_pb.transform.field_transforms + assert len(transforms) == 1 + 
def test_documentextractor_get_transform_pb_w_maximum_int():
    """Maximum(int) produces one field transform carrying ``integer_value``."""
    from google.cloud.firestore_v1.types import write
    from google.cloud.firestore_v1.transforms import Maximum

    value = 1
    document_data = {"a": {"b": {"c": Maximum(value)}}}
    inst = _make_document_extractor(document_data)
    document_path = "projects/project-id/databases/(default)/documents/document-id"

    transform_pb = inst.get_transform_pb(document_path)

    assert isinstance(transform_pb, write.Write)
    assert transform_pb.transform.document == document_path
    transforms = transform_pb.transform.field_transforms
    assert len(transforms) == 1
    transform = transforms[0]
    assert transform.field_path == "a.b.c"
    added = transform.maximum.integer_value
    assert added == value
    # No precondition should be attached to a bare transform write.
    assert not transform_pb._pb.HasField("current_document")


def test_documentextractor_get_transform_pb_w_maximum_float():
    """Maximum(float) produces one field transform carrying ``double_value``."""
    from google.cloud.firestore_v1.types import write
    from google.cloud.firestore_v1.transforms import Maximum

    value = 3.1415926
    document_data = {"a": {"b": {"c": Maximum(value)}}}
    inst = _make_document_extractor(document_data)
    document_path = "projects/project-id/databases/(default)/documents/document-id"

    transform_pb = inst.get_transform_pb(document_path)

    assert isinstance(transform_pb, write.Write)
    assert transform_pb.transform.document == document_path
    transforms = transform_pb.transform.field_transforms
    assert len(transforms) == 1
    transform = transforms[0]
    assert transform.field_path == "a.b.c"
    added = transform.maximum.double_value
    assert added == value
    assert not transform_pb._pb.HasField("current_document")


def test_documentextractor_get_transform_pb_w_minimum_int():
    """Minimum(int) produces one field transform carrying ``integer_value``."""
    from google.cloud.firestore_v1.types import write
    from google.cloud.firestore_v1.transforms import Minimum

    value = 1
    document_data = {"a": {"b": {"c": Minimum(value)}}}
    inst = _make_document_extractor(document_data)
    document_path = "projects/project-id/databases/(default)/documents/document-id"

    transform_pb = inst.get_transform_pb(document_path)

    assert isinstance(transform_pb, write.Write)
    assert transform_pb.transform.document == document_path
    transforms = transform_pb.transform.field_transforms
    assert len(transforms) == 1
    transform = transforms[0]
    assert transform.field_path == "a.b.c"
    added = transform.minimum.integer_value
    assert added == value
    assert not transform_pb._pb.HasField("current_document")


def test_documentextractor_get_transform_pb_w_minimum_float():
    """Minimum(float) produces one field transform carrying ``double_value``."""
    from google.cloud.firestore_v1.types import write
    from google.cloud.firestore_v1.transforms import Minimum

    value = 3.1415926
    document_data = {"a": {"b": {"c": Minimum(value)}}}
    inst = _make_document_extractor(document_data)
    document_path = "projects/project-id/databases/(default)/documents/document-id"

    transform_pb = inst.get_transform_pb(document_path)

    assert isinstance(transform_pb, write.Write)
    assert transform_pb.transform.document == document_path
    transforms = transform_pb.transform.field_transforms
    assert len(transforms) == 1
    transform = transforms[0]
    assert transform.field_path == "a.b.c"
    added = transform.minimum.double_value
    assert added == value
    assert not transform_pb._pb.HasField("current_document")


def _make_write_w_document_for_create(document_path, **data):
    """Build the expected ``Write`` proto for a create: fields + exists=False."""
    from google.cloud.firestore_v1.types import document
    from google.cloud.firestore_v1.types import write
    from google.cloud.firestore_v1._helpers import encode_dict
    from google.cloud.firestore_v1.types import common

    return write.Write(
        update=document.Document(name=document_path, fields=encode_dict(data)),
        current_document=common.Precondition(exists=False),
    )


def _add_field_transforms_for_create(update_pb, fields):
    """Append a REQUEST_TIME server-value transform for each field path."""
    from google.cloud.firestore_v1 import DocumentTransform

    server_val = DocumentTransform.FieldTransform.ServerValue
    for field in fields:
        update_pb.update_transforms.append(
            DocumentTransform.FieldTransform(
                field_path=field, set_to_server_value=server_val.REQUEST_TIME
            )
        )


def __pbs_for_create_helper(do_transform=False, empty_val=False):
    """Exercise ``pbs_for_create`` and compare against hand-built expected pbs."""
    from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
    from google.cloud.firestore_v1._helpers import pbs_for_create

    document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
    document_data = {"cheese": 1.5, "crackers": True}

    if do_transform:
        document_data["butter"] = SERVER_TIMESTAMP

    if empty_val:
        document_data["mustard"] = {}

    write_pbs = pbs_for_create(document_path, document_data)

    if empty_val:
        update_pb = _make_write_w_document_for_create(
            document_path, cheese=1.5, crackers=True, mustard={}
        )
    else:
        update_pb = _make_write_w_document_for_create(
            document_path, cheese=1.5, crackers=True
        )
    expected_pbs = [update_pb]

    if do_transform:
        _add_field_transforms_for_create(update_pb, fields=["butter"])

    assert write_pbs == expected_pbs


def test__pbs_for_create_wo_transform():
    __pbs_for_create_helper()


def test__pbs_for_create_w_transform():
    __pbs_for_create_helper(do_transform=True)


def test__pbs_for_create_w_transform_and_empty_value():
    __pbs_for_create_helper(do_transform=True, empty_val=True)


def _make_write_w_document_for_set_no_merge(document_path, **data):
    """Build the expected ``Write`` proto for a no-merge set (no precondition)."""
    from google.cloud.firestore_v1.types import document
    from google.cloud.firestore_v1.types import write
    from google.cloud.firestore_v1._helpers import encode_dict

    return write.Write(
        update=document.Document(name=document_path, fields=encode_dict(data))
    )


def _add_field_transforms_for_set_no_merge(update_pb, fields):
    """Append a REQUEST_TIME server-value transform for each field path."""
    from google.cloud.firestore_v1 import DocumentTransform

    server_val = DocumentTransform.FieldTransform.ServerValue
    for field in fields:
        update_pb.update_transforms.append(
            DocumentTransform.FieldTransform(
                field_path=field, set_to_server_value=server_val.REQUEST_TIME
            )
        )


def test__pbs_for_set_w_empty_document():
    """An empty document still yields one (empty) update write."""
    from google.cloud.firestore_v1._helpers import pbs_for_set_no_merge

    document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
    document_data = {}

    write_pbs = pbs_for_set_no_merge(document_path, document_data)

    update_pb = _make_write_w_document_for_set_no_merge(document_path)
    expected_pbs = [update_pb]
    assert write_pbs == expected_pbs


def test__pbs_for_set_w_only_server_timestamp():
    """A document of only SERVER_TIMESTAMP yields an empty update + transform."""
    from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
    from google.cloud.firestore_v1._helpers import pbs_for_set_no_merge

    document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
    document_data = {"butter": SERVER_TIMESTAMP}

    write_pbs = pbs_for_set_no_merge(document_path, document_data)

    update_pb = _make_write_w_document_for_set_no_merge(document_path)
    _add_field_transforms_for_set_no_merge(update_pb, fields=["butter"])
    expected_pbs = [update_pb]
    assert write_pbs == expected_pbs


def _pbs_for_set_no_merge_helper(do_transform=False, empty_val=False):
    """Exercise ``pbs_for_set_no_merge`` against hand-built expected pbs."""
    from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
    from google.cloud.firestore_v1._helpers import pbs_for_set_no_merge

    document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
    document_data = {"cheese": 1.5, "crackers": True}

    if do_transform:
        document_data["butter"] = SERVER_TIMESTAMP

    if empty_val:
        document_data["mustard"] = {}

    write_pbs = pbs_for_set_no_merge(document_path, document_data)

    if empty_val:
        update_pb = _make_write_w_document_for_set_no_merge(
            document_path, cheese=1.5, crackers=True, mustard={}
        )
    else:
        update_pb = _make_write_w_document_for_set_no_merge(
            document_path, cheese=1.5, crackers=True
        )
    expected_pbs = [update_pb]

    if do_transform:
        _add_field_transforms_for_set_no_merge(update_pb, fields=["butter"])

    assert write_pbs == expected_pbs


def test__pbs_for_set_defaults():
    _pbs_for_set_no_merge_helper()


def test__pbs_for_set_w_transform():
    _pbs_for_set_no_merge_helper(do_transform=True)
def test__pbs_for_set_w_transform_and_empty_value():
    # Exercise https://github.com/googleapis/google-cloud-python/issues/5944
    _pbs_for_set_no_merge_helper(do_transform=True, empty_val=True)


def _make_document_extractor_for_merge(document_data):
    """Instance-factory for ``DocumentExtractorForMerge``."""
    from google.cloud.firestore_v1 import _helpers

    return _helpers.DocumentExtractorForMerge(document_data)


def test_documentextractorformerge_ctor_w_empty_document():
    """Fresh extractor over an empty document has no merge state."""
    document_data = {}

    inst = _make_document_extractor_for_merge(document_data)

    assert inst.data_merge == []
    assert inst.transform_merge == []
    assert inst.merge == []


def test_documentextractorformerge_apply_merge_all_w_empty_document():
    """``apply_merge(True)`` on an empty document leaves merge state empty."""
    document_data = {}
    inst = _make_document_extractor_for_merge(document_data)

    inst.apply_merge(True)

    assert inst.data_merge == []
    assert inst.transform_merge == []
    assert inst.merge == []


def test_documentextractorformerge_apply_merge_all_w_delete():
    """DELETE_FIELD sentinels count as data (not transform) merges."""
    from google.cloud.firestore_v1.transforms import DELETE_FIELD

    document_data = {"write_me": "value", "delete_me": DELETE_FIELD}
    inst = _make_document_extractor_for_merge(document_data)

    inst.apply_merge(True)

    expected_data_merge = [
        _make_field_path("delete_me"),
        _make_field_path("write_me"),
    ]
    assert inst.data_merge == expected_data_merge
    assert inst.transform_merge == []
    assert inst.merge == expected_data_merge


def test_documentextractorformerge_apply_merge_all_w_server_timestamp():
    """SERVER_TIMESTAMP fields land in ``transform_merge``; plain data in ``data_merge``."""
    from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP

    document_data = {"write_me": "value", "timestamp": SERVER_TIMESTAMP}
    inst = _make_document_extractor_for_merge(document_data)

    inst.apply_merge(True)

    expected_data_merge = [_make_field_path("write_me")]
    expected_transform_merge = [_make_field_path("timestamp")]
    expected_merge = [_make_field_path("timestamp"), _make_field_path("write_me")]
    assert inst.data_merge == expected_data_merge
    assert inst.transform_merge == expected_transform_merge
    assert inst.merge == expected_merge


def test_documentextractorformerge_apply_merge_list_fields_w_empty_document():
    """Merging field paths absent from the document raises ValueError."""
    document_data = {}
    inst = _make_document_extractor_for_merge(document_data)

    with pytest.raises(ValueError):
        inst.apply_merge(["nonesuch", "or.this"])


def test_documentextractorformerge_apply_merge_list_fields_w_unmerged_delete():
    """A DELETE_FIELD left out of the merge list raises ValueError."""
    from google.cloud.firestore_v1.transforms import DELETE_FIELD

    document_data = {
        "write_me": "value",
        "delete_me": DELETE_FIELD,
        "ignore_me": 123,
        "unmerged_delete": DELETE_FIELD,
    }
    inst = _make_document_extractor_for_merge(document_data)

    with pytest.raises(ValueError):
        inst.apply_merge(["write_me", "delete_me"])


def test_documentextractorformerge_apply_merge_list_fields_w_delete():
    """A merged DELETE_FIELD is tracked in ``deleted_fields``, not ``set_fields``."""
    from google.cloud.firestore_v1.transforms import DELETE_FIELD

    document_data = {
        "write_me": "value",
        "delete_me": DELETE_FIELD,
        "ignore_me": 123,
    }
    inst = _make_document_extractor_for_merge(document_data)

    inst.apply_merge(["write_me", "delete_me"])

    expected_set_fields = {"write_me": "value"}
    expected_deleted_fields = [_make_field_path("delete_me")]
    assert inst.set_fields == expected_set_fields
    assert inst.deleted_fields == expected_deleted_fields


def test_documentextractorformerge_apply_merge_list_fields_w_prefixes():
    """Overlapping merge paths (one a prefix of another) raise ValueError."""
    document_data = {"a": {"b": {"c": 123}}}
    inst = _make_document_extractor_for_merge(document_data)

    with pytest.raises(ValueError):
        inst.apply_merge(["a", "a.b"])


def test_documentextractorformerge_apply_merge_lists_w_missing_data_paths():
    """Merge paths not present in the document data raise ValueError."""
    document_data = {"write_me": "value", "ignore_me": 123}
    inst = _make_document_extractor_for_merge(document_data)

    with pytest.raises(ValueError):
        inst.apply_merge(["write_me", "nonesuch"])


def test_documentextractorformerge_apply_merge_list_fields_w_non_merge_field():
    """Fields outside the merge list are simply ignored."""
    document_data = {"write_me": "value", "ignore_me": 123}
    inst = _make_document_extractor_for_merge(document_data)

    inst.apply_merge([_make_field_path("write_me")])

    expected_set_fields = {"write_me": "value"}
    assert inst.set_fields == expected_set_fields


def test_documentextractorformerge_apply_merge_list_fields_w_server_timestamp():
    """Only the merged SERVER_TIMESTAMP field appears in ``server_timestamps``."""
    from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP

    document_data = {
        "write_me": "value",
        "timestamp": SERVER_TIMESTAMP,
        "ignored_stamp": SERVER_TIMESTAMP,
    }
    inst = _make_document_extractor_for_merge(document_data)

    inst.apply_merge([_make_field_path("write_me"), _make_field_path("timestamp")])

    expected_data_merge = [_make_field_path("write_me")]
    expected_transform_merge = [_make_field_path("timestamp")]
    expected_merge = [_make_field_path("timestamp"), _make_field_path("write_me")]
    assert inst.data_merge == expected_data_merge
    assert inst.transform_merge == expected_transform_merge
    assert inst.merge == expected_merge
    expected_server_timestamps = [_make_field_path("timestamp")]
    assert inst.server_timestamps == expected_server_timestamps
def test_documentextractorformerge_apply_merge_list_fields_w_array_remove():
    """Only the merged ArrayRemove field appears in ``array_removes``."""
    from google.cloud.firestore_v1.transforms import ArrayRemove

    values = [2, 4, 8]
    document_data = {
        "write_me": "value",
        "remove_me": ArrayRemove(values),
        "ignored_remove_me": ArrayRemove((1, 3, 5)),
    }
    inst = _make_document_extractor_for_merge(document_data)

    inst.apply_merge([_make_field_path("write_me"), _make_field_path("remove_me")])

    expected_data_merge = [_make_field_path("write_me")]
    expected_transform_merge = [_make_field_path("remove_me")]
    expected_merge = [_make_field_path("remove_me"), _make_field_path("write_me")]
    assert inst.data_merge == expected_data_merge
    assert inst.transform_merge == expected_transform_merge
    assert inst.merge == expected_merge
    expected_array_removes = {_make_field_path("remove_me"): values}
    assert inst.array_removes == expected_array_removes


def test_documentextractorformerge_apply_merge_list_fields_w_array_union():
    """Only the merged ArrayUnion field appears in ``array_unions``."""
    from google.cloud.firestore_v1.transforms import ArrayUnion

    values = [1, 3, 5]
    document_data = {
        "write_me": "value",
        "union_me": ArrayUnion(values),
        "ignored_union_me": ArrayUnion((2, 4, 8)),
    }
    inst = _make_document_extractor_for_merge(document_data)

    inst.apply_merge([_make_field_path("write_me"), _make_field_path("union_me")])

    expected_data_merge = [_make_field_path("write_me")]
    expected_transform_merge = [_make_field_path("union_me")]
    expected_merge = [_make_field_path("union_me"), _make_field_path("write_me")]
    assert inst.data_merge == expected_data_merge
    assert inst.transform_merge == expected_transform_merge
    assert inst.merge == expected_merge
    expected_array_unions = {_make_field_path("union_me"): values}
    assert inst.array_unions == expected_array_unions


def _make_write_w_document_for_set_w_merge(document_path, **data):
    """Build the expected ``Write`` proto for a merged set (no precondition)."""
    from google.cloud.firestore_v1.types import document
    from google.cloud.firestore_v1.types import write
    from google.cloud.firestore_v1._helpers import encode_dict

    return write.Write(
        update=document.Document(name=document_path, fields=encode_dict(data))
    )


def _add_field_transforms_for_set_w_merge(update_pb, fields):
    """Append a REQUEST_TIME server-value transform for each field path."""
    from google.cloud.firestore_v1 import DocumentTransform

    server_val = DocumentTransform.FieldTransform.ServerValue
    for field in fields:
        update_pb.update_transforms.append(
            DocumentTransform.FieldTransform(
                field_path=field, set_to_server_value=server_val.REQUEST_TIME
            )
        )


def _update_document_mask(update_pb, field_paths):
    """Set the write's ``update_mask`` to the sorted field paths."""
    from google.cloud.firestore_v1.types import common

    update_pb._pb.update_mask.CopyFrom(
        common.DocumentMask(field_paths=sorted(field_paths))._pb
    )


def test__pbs_for_set_with_merge_w_merge_true_wo_transform():
    """merge=True masks every top-level field."""
    from google.cloud.firestore_v1._helpers import pbs_for_set_with_merge

    document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
    document_data = {"cheese": 1.5, "crackers": True}

    write_pbs = pbs_for_set_with_merge(document_path, document_data, merge=True)

    update_pb = _make_write_w_document_for_set_w_merge(document_path, **document_data)
    _update_document_mask(update_pb, field_paths=sorted(document_data))
    expected_pbs = [update_pb]
    assert write_pbs == expected_pbs


def test__pbs_for_set_with_merge_w_merge_field_wo_transform():
    """An explicit merge list writes and masks only the listed fields."""
    from google.cloud.firestore_v1._helpers import pbs_for_set_with_merge

    document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
    document_data = {"cheese": 1.5, "crackers": True}

    write_pbs = pbs_for_set_with_merge(document_path, document_data, merge=["cheese"])

    update_pb = _make_write_w_document_for_set_w_merge(
        document_path, cheese=document_data["cheese"]
    )
    _update_document_mask(update_pb, field_paths=["cheese"])
    expected_pbs = [update_pb]
    assert write_pbs == expected_pbs


def test__pbs_for_set_with_merge_w_merge_true_w_only_transform():
    """A transform-only document yields an empty update with an empty mask."""
    from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
    from google.cloud.firestore_v1._helpers import pbs_for_set_with_merge

    document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
    document_data = {"butter": SERVER_TIMESTAMP}

    write_pbs = pbs_for_set_with_merge(document_path, document_data, merge=True)

    update_pb = _make_write_w_document_for_set_w_merge(document_path)
    _update_document_mask(update_pb, field_paths=())
    _add_field_transforms_for_set_w_merge(update_pb, fields=["butter"])
    expected_pbs = [update_pb]
    assert write_pbs == expected_pbs


def test__pbs_for_set_with_merge_w_merge_true_w_transform():
    """merge=True masks data fields and carries the transform separately."""
    from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
    from google.cloud.firestore_v1._helpers import pbs_for_set_with_merge

    document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
    update_data = {"cheese": 1.5, "crackers": True}
    document_data = update_data.copy()
    document_data["butter"] = SERVER_TIMESTAMP

    write_pbs = pbs_for_set_with_merge(document_path, document_data, merge=True)

    update_pb = _make_write_w_document_for_set_w_merge(document_path, **update_data)
    _update_document_mask(update_pb, field_paths=sorted(update_data))
    _add_field_transforms_for_set_w_merge(update_pb, fields=["butter"])
    expected_pbs = [update_pb]
    assert write_pbs == expected_pbs


def test__pbs_for_set_with_merge_w_merge_field_w_transform():
    """A merge list may mix data fields and transform fields."""
    from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
    from google.cloud.firestore_v1._helpers import pbs_for_set_with_merge

    document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
    update_data = {"cheese": 1.5, "crackers": True}
    document_data = update_data.copy()
    document_data["butter"] = SERVER_TIMESTAMP

    write_pbs = pbs_for_set_with_merge(
        document_path, document_data, merge=["cheese", "butter"]
    )

    update_pb = _make_write_w_document_for_set_w_merge(
        document_path, cheese=document_data["cheese"]
    )
    _update_document_mask(update_pb, ["cheese"])
    _add_field_transforms_for_set_w_merge(update_pb, fields=["butter"])
    expected_pbs = [update_pb]
    assert write_pbs == expected_pbs


def test__pbs_for_set_with_merge_w_merge_field_w_transform_masking_simple():
    """Merging only a nested transform field yields an empty update + mask."""
    from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
    from google.cloud.firestore_v1._helpers import pbs_for_set_with_merge

    document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
    update_data = {"cheese": 1.5, "crackers": True}
    document_data = update_data.copy()
    document_data["butter"] = {"pecan": SERVER_TIMESTAMP}

    write_pbs = pbs_for_set_with_merge(
        document_path, document_data, merge=["butter.pecan"]
    )

    update_pb = _make_write_w_document_for_set_w_merge(document_path)
    _update_document_mask(update_pb, field_paths=())
    _add_field_transforms_for_set_w_merge(update_pb, fields=["butter.pecan"])
    expected_pbs = [update_pb]
    assert write_pbs == expected_pbs


def test__pbs_for_set_with_merge_w_merge_field_w_transform_parent():
    """Merging a parent field keeps its plain children and transforms the rest."""
    from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP
    from google.cloud.firestore_v1._helpers import pbs_for_set_with_merge

    document_path = _make_ref_string(u"little", u"town", u"of", u"ham")
    update_data = {"cheese": 1.5, "crackers": True}
    document_data = update_data.copy()
    document_data["butter"] = {"popcorn": "yum", "pecan": SERVER_TIMESTAMP}

    write_pbs = pbs_for_set_with_merge(
        document_path, document_data, merge=["cheese", "butter"]
    )

    update_pb = _make_write_w_document_for_set_w_merge(
        document_path, cheese=update_data["cheese"], butter={"popcorn": "yum"}
    )
    _update_document_mask(update_pb, ["cheese", "butter"])
    _add_field_transforms_for_set_w_merge(update_pb, fields=["butter.pecan"])
    expected_pbs = [update_pb]
    assert write_pbs == expected_pbs


def _make_document_extractor_for_update(document_data):
    """Instance-factory for ``DocumentExtractorForUpdate``."""
    from google.cloud.firestore_v1._helpers import DocumentExtractorForUpdate

    return DocumentExtractorForUpdate(document_data)


def test_documentextractorforupdate_ctor_w_empty_document():
    """An empty document has no top-level paths."""
    document_data = {}

    inst = _make_document_extractor_for_update(document_data)
    assert inst.top_level_paths == []
_make_document_extractor_for_update(document_data) + assert inst.top_level_paths == expected_paths + + +def test_documentextractorforupdate_ctor_w_nested_keys(): + document_data = {"a": {"d": {"e": 1}}, "b": {"f": 7}, "c": 3} + + expected_paths = [ + _make_field_path("a"), + _make_field_path("b"), + _make_field_path("c"), + ] + inst = _make_document_extractor_for_update(document_data) + assert inst.top_level_paths == expected_paths + + +def test_documentextractorforupdate_ctor_w_dotted_keys(): + document_data = {"a.d.e": 1, "b.f": 7, "c": 3} - def test_ctor_w_nested_dotted_keys(self): - document_data = {"a.d.e": 1, "b.f": {"h.i": 9}, "c": 3} + expected_paths = [ + _make_field_path("a", "d", "e"), + _make_field_path("b", "f"), + _make_field_path("c"), + ] + inst = _make_document_extractor_for_update(document_data) + assert inst.top_level_paths == expected_paths - expected_paths = [ - _make_field_path("a", "d", "e"), - _make_field_path("b", "f"), - _make_field_path("c"), + +def test_documentextractorforupdate_ctor_w_nested_dotted_keys(): + document_data = {"a.d.e": 1, "b.f": {"h.i": 9}, "c": 3} + + expected_paths = [ + _make_field_path("a", "d", "e"), + _make_field_path("b", "f"), + _make_field_path("c"), + ] + expected_set_fields = {"a": {"d": {"e": 1}}, "b": {"f": {"h.i": 9}}, "c": 3} + inst = _make_document_extractor_for_update(document_data) + assert inst.top_level_paths == expected_paths + assert inst.set_fields == expected_set_fields + + +def _pbs_for_update_helper(option=None, do_transform=False, **write_kwargs): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.field_path import FieldPath + from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP + from google.cloud.firestore_v1 import DocumentTransform + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1._helpers import 
pbs_for_update + + document_path = _make_ref_string(u"toy", u"car", u"onion", u"garlic") + field_path1 = "bitez.yum" + value = b"\x00\x01" + field_path2 = "blog.internet" + + field_updates = {field_path1: value} + if do_transform: + field_updates[field_path2] = SERVER_TIMESTAMP + + write_pbs = pbs_for_update(document_path, field_updates, option) + + map_pb = document.MapValue(fields={"yum": _value_pb(bytes_value=value)}) + + field_paths = [field_path1] + + expected_update_pb = write.Write( + update=document.Document( + name=document_path, fields={"bitez": _value_pb(map_value=map_pb)} + ), + update_mask=common.DocumentMask(field_paths=field_paths), + **write_kwargs + ) + if isinstance(option, _helpers.ExistsOption): + precondition = common.Precondition(exists=False) + expected_update_pb._pb.current_document.CopyFrom(precondition._pb) + + if do_transform: + transform_paths = FieldPath.from_string(field_path2) + server_val = DocumentTransform.FieldTransform.ServerValue + field_transform_pbs = [ + write.DocumentTransform.FieldTransform( + field_path=transform_paths.to_api_repr(), + set_to_server_value=server_val.REQUEST_TIME, + ) ] - expected_set_fields = {"a": {"d": {"e": 1}}, "b": {"f": {"h.i": 9}}, "c": 3} - inst = self._make_one(document_data) - self.assertEqual(inst.top_level_paths, expected_paths) - self.assertEqual(inst.set_fields, expected_set_fields) + expected_update_pb.update_transforms.extend(field_transform_pbs) + assert write_pbs == [expected_update_pb] -class Test_pbs_for_update(unittest.TestCase): - @staticmethod - def _call_fut(document_path, field_updates, option): - from google.cloud.firestore_v1._helpers import pbs_for_update - return pbs_for_update(document_path, field_updates, option) +def test__pbs_for_update_wo_option(): + from google.cloud.firestore_v1.types import common - def _helper(self, option=None, do_transform=False, **write_kwargs): - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.field_path import 
FieldPath - from google.cloud.firestore_v1.transforms import SERVER_TIMESTAMP - from google.cloud.firestore_v1 import DocumentTransform - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write + precondition = common.Precondition(exists=True) + _pbs_for_update_helper(current_document=precondition) - document_path = _make_ref_string(u"toy", u"car", u"onion", u"garlic") - field_path1 = "bitez.yum" - value = b"\x00\x01" - field_path2 = "blog.internet" - field_updates = {field_path1: value} - if do_transform: - field_updates[field_path2] = SERVER_TIMESTAMP +def test__pbs_for_update_w__exists_option(): + from google.cloud.firestore_v1 import _helpers - write_pbs = self._call_fut(document_path, field_updates, option) + option = _helpers.ExistsOption(False) + _pbs_for_update_helper(option=option) - map_pb = document.MapValue(fields={"yum": _value_pb(bytes_value=value)}) - field_paths = [field_path1] +def test__pbs_for_update_w_update_and_transform(): + from google.cloud.firestore_v1.types import common - expected_update_pb = write.Write( - update=document.Document( - name=document_path, fields={"bitez": _value_pb(map_value=map_pb)} - ), - update_mask=common.DocumentMask(field_paths=field_paths), - **write_kwargs - ) - if isinstance(option, _helpers.ExistsOption): - precondition = common.Precondition(exists=False) - expected_update_pb._pb.current_document.CopyFrom(precondition._pb) + precondition = common.Precondition(exists=True) + _pbs_for_update_helper(current_document=precondition, do_transform=True) + + +def _pb_for_delete_helper(option=None, **write_kwargs): + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1._helpers import pb_for_delete + + document_path = _make_ref_string(u"chicken", u"philly", u"one", u"two") + write_pb = pb_for_delete(document_path, option) + + expected_pb = write.Write(delete=document_path, **write_kwargs) + 
assert write_pb == expected_pb + + +def test__pb_for_delete_wo_option(): + _pb_for_delete_helper() + + +def test__pb_for_delete_w_option(): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1 import _helpers + + update_time = timestamp_pb2.Timestamp(seconds=1309700594, nanos=822211297) + option = _helpers.LastUpdateOption(update_time) + precondition = common.Precondition(update_time=update_time) + _pb_for_delete_helper(option=option, current_document=precondition) + + +def test_get_transaction_id_w_no_transaction(): + from google.cloud.firestore_v1._helpers import get_transaction_id + + ret_val = get_transaction_id(None) + assert ret_val is None - if do_transform: - transform_paths = FieldPath.from_string(field_path2) - server_val = DocumentTransform.FieldTransform.ServerValue - field_transform_pbs = [ - write.DocumentTransform.FieldTransform( - field_path=transform_paths.to_api_repr(), - set_to_server_value=server_val.REQUEST_TIME, - ) - ] - expected_update_pb.update_transforms.extend(field_transform_pbs) - self.assertEqual(write_pbs, [expected_update_pb]) +def test_get_transaction_id_w_invalid_transaction(): + from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1._helpers import get_transaction_id - def test_without_option(self): - from google.cloud.firestore_v1.types import common + transaction = Transaction(mock.sentinel.client) + assert not transaction.in_progress + with pytest.raises(ValueError): + get_transaction_id(transaction) - precondition = common.Precondition(exists=True) - self._helper(current_document=precondition) - def test_with_exists_option(self): - from google.cloud.firestore_v1 import _helpers +def test_get_transaction_id_w_after_writes_not_allowed(): + from google.cloud.firestore_v1._helpers import ReadAfterWriteError + from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1._helpers 
import get_transaction_id - option = _helpers.ExistsOption(False) - self._helper(option=option) + transaction = Transaction(mock.sentinel.client) + transaction._id = b"under-hook" + transaction._write_pbs.append(mock.sentinel.write) - def test_update_and_transform(self): - from google.cloud.firestore_v1.types import common + with pytest.raises(ReadAfterWriteError): + get_transaction_id(transaction) - precondition = common.Precondition(exists=True) - self._helper(current_document=precondition, do_transform=True) +def test_get_transaction_id_w_after_writes_allowed(): + from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1._helpers import get_transaction_id -class Test_pb_for_delete(unittest.TestCase): - @staticmethod - def _call_fut(document_path, option): - from google.cloud.firestore_v1._helpers import pb_for_delete + transaction = Transaction(mock.sentinel.client) + txn_id = b"we-are-0fine" + transaction._id = txn_id + transaction._write_pbs.append(mock.sentinel.write) - return pb_for_delete(document_path, option) + ret_val = get_transaction_id(transaction, read_operation=False) + assert ret_val == txn_id - def _helper(self, option=None, **write_kwargs): - from google.cloud.firestore_v1.types import write - document_path = _make_ref_string(u"chicken", u"philly", u"one", u"two") - write_pb = self._call_fut(document_path, option) +def test_get_transaction_id_w_good_transaction(): + from google.cloud.firestore_v1.transaction import Transaction + from google.cloud.firestore_v1._helpers import get_transaction_id - expected_pb = write.Write(delete=document_path, **write_kwargs) - self.assertEqual(write_pb, expected_pb) + transaction = Transaction(mock.sentinel.client) + txn_id = b"doubt-it" + transaction._id = txn_id + assert transaction.in_progress - def test_without_option(self): - self._helper() + assert get_transaction_id(transaction) == txn_id - def test_with_option(self): - from google.protobuf import timestamp_pb2 - from 
google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1 import _helpers - update_time = timestamp_pb2.Timestamp(seconds=1309700594, nanos=822211297) - option = _helpers.LastUpdateOption(update_time) - precondition = common.Precondition(update_time=update_time) - self._helper(option=option, current_document=precondition) +def test_metadata_with_prefix(): + from google.cloud.firestore_v1._helpers import metadata_with_prefix + database_string = u"projects/prahj/databases/dee-bee" + metadata = metadata_with_prefix(database_string) -class Test_get_transaction_id(unittest.TestCase): - @staticmethod - def _call_fut(transaction, **kwargs): - from google.cloud.firestore_v1._helpers import get_transaction_id + assert metadata == [("google-cloud-resource-prefix", database_string)] - return get_transaction_id(transaction, **kwargs) - def test_no_transaction(self): - ret_val = self._call_fut(None) - self.assertIsNone(ret_val) +def test_writeoption_modify_write(): + from google.cloud.firestore_v1._helpers import WriteOption - def test_invalid_transaction(self): - from google.cloud.firestore_v1.transaction import Transaction + option = WriteOption() + with pytest.raises(NotImplementedError): + option.modify_write(None) - transaction = Transaction(mock.sentinel.client) - self.assertFalse(transaction.in_progress) - with self.assertRaises(ValueError): - self._call_fut(transaction) - def test_after_writes_not_allowed(self): - from google.cloud.firestore_v1._helpers import ReadAfterWriteError - from google.cloud.firestore_v1.transaction import Transaction +def test_lastupdateoption_constructor(): + from google.cloud.firestore_v1._helpers import LastUpdateOption - transaction = Transaction(mock.sentinel.client) - transaction._id = b"under-hook" - transaction._write_pbs.append(mock.sentinel.write) + option = LastUpdateOption(mock.sentinel.timestamp) + assert option._last_update_time is mock.sentinel.timestamp - with self.assertRaises(ReadAfterWriteError): - 
self._call_fut(transaction) - def test_after_writes_allowed(self): - from google.cloud.firestore_v1.transaction import Transaction +def test_lastupdateoption___eq___different_type(): + from google.cloud.firestore_v1._helpers import LastUpdateOption - transaction = Transaction(mock.sentinel.client) - txn_id = b"we-are-0fine" - transaction._id = txn_id - transaction._write_pbs.append(mock.sentinel.write) + option = LastUpdateOption(mock.sentinel.timestamp) + other = object() + assert not option == other - ret_val = self._call_fut(transaction, read_operation=False) - self.assertEqual(ret_val, txn_id) - def test_good_transaction(self): - from google.cloud.firestore_v1.transaction import Transaction +def test_lastupdateoption___eq___different_timestamp(): + from google.cloud.firestore_v1._helpers import LastUpdateOption - transaction = Transaction(mock.sentinel.client) - txn_id = b"doubt-it" - transaction._id = txn_id - self.assertTrue(transaction.in_progress) + option = LastUpdateOption(mock.sentinel.timestamp) + other = LastUpdateOption(mock.sentinel.other_timestamp) + assert not option == other - self.assertEqual(self._call_fut(transaction), txn_id) +def test_lastupdateoption___eq___same_timestamp(): + from google.cloud.firestore_v1._helpers import LastUpdateOption -class Test_metadata_with_prefix(unittest.TestCase): - @staticmethod - def _call_fut(database_string): - from google.cloud.firestore_v1._helpers import metadata_with_prefix + option = LastUpdateOption(mock.sentinel.timestamp) + other = LastUpdateOption(mock.sentinel.timestamp) + assert option == other - return metadata_with_prefix(database_string) - def test_it(self): - database_string = u"projects/prahj/databases/dee-bee" - metadata = self._call_fut(database_string) +def test_lastupdateoption_modify_write_update_time(): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import write + from 
google.cloud.firestore_v1._helpers import LastUpdateOption - self.assertEqual(metadata, [("google-cloud-resource-prefix", database_string)]) + timestamp_pb = timestamp_pb2.Timestamp(seconds=683893592, nanos=229362000) + option = LastUpdateOption(timestamp_pb) + write_pb = write.Write() + ret_val = option.modify_write(write_pb) + assert ret_val is None + expected_doc = common.Precondition(update_time=timestamp_pb) + assert write_pb.current_document == expected_doc -class TestWriteOption(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1._helpers import WriteOption - return WriteOption +def test_existsoption_constructor(): + from google.cloud.firestore_v1._helpers import ExistsOption - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) + option = ExistsOption(mock.sentinel.totes_bool) + assert option._exists is mock.sentinel.totes_bool - def test_modify_write(self): - option = self._make_one() - with self.assertRaises(NotImplementedError): - option.modify_write(None) +def test_existsoption___eq___different_type(): + from google.cloud.firestore_v1._helpers import ExistsOption -class TestLastUpdateOption(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1._helpers import LastUpdateOption + option = ExistsOption(mock.sentinel.timestamp) + other = object() + assert not option == other - return LastUpdateOption - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) +def test_existsoption___eq___different_exists(): + from google.cloud.firestore_v1._helpers import ExistsOption - def test_constructor(self): - option = self._make_one(mock.sentinel.timestamp) - self.assertIs(option._last_update_time, mock.sentinel.timestamp) + option = ExistsOption(True) + other = ExistsOption(False) + assert not option == other - def test___eq___different_type(self): - option = 
self._make_one(mock.sentinel.timestamp) - other = object() - self.assertFalse(option == other) - def test___eq___different_timestamp(self): - option = self._make_one(mock.sentinel.timestamp) - other = self._make_one(mock.sentinel.other_timestamp) - self.assertFalse(option == other) +def test_existsoption___eq___same_exists(): + from google.cloud.firestore_v1._helpers import ExistsOption - def test___eq___same_timestamp(self): - option = self._make_one(mock.sentinel.timestamp) - other = self._make_one(mock.sentinel.timestamp) - self.assertTrue(option == other) + option = ExistsOption(True) + other = ExistsOption(True) + assert option == other - def test_modify_write_update_time(self): - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import write - timestamp_pb = timestamp_pb2.Timestamp(seconds=683893592, nanos=229362000) - option = self._make_one(timestamp_pb) +def test_existsoption_modify_write(): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1._helpers import ExistsOption + + for exists in (True, False): + option = ExistsOption(exists) write_pb = write.Write() ret_val = option.modify_write(write_pb) - self.assertIsNone(ret_val) - expected_doc = common.Precondition(update_time=timestamp_pb) - self.assertEqual(write_pb.current_document, expected_doc) - + assert ret_val is None + expected_doc = common.Precondition(exists=exists) + assert write_pb.current_document == expected_doc -class TestExistsOption(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1._helpers import ExistsOption - return ExistsOption +def test_make_retry_timeout_kwargs_default(): + from google.api_core.gapic_v1.method import DEFAULT + from google.cloud.firestore_v1._helpers import make_retry_timeout_kwargs - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() 
- return klass(*args, **kwargs) + kwargs = make_retry_timeout_kwargs(DEFAULT, None) + expected = {} + assert kwargs == expected - def test_constructor(self): - option = self._make_one(mock.sentinel.totes_bool) - self.assertIs(option._exists, mock.sentinel.totes_bool) - def test___eq___different_type(self): - option = self._make_one(mock.sentinel.timestamp) - other = object() - self.assertFalse(option == other) +def test_make_retry_timeout_kwargs_retry_None(): + from google.cloud.firestore_v1._helpers import make_retry_timeout_kwargs - def test___eq___different_exists(self): - option = self._make_one(True) - other = self._make_one(False) - self.assertFalse(option == other) + kwargs = make_retry_timeout_kwargs(None, None) + expected = {"retry": None} + assert kwargs == expected - def test___eq___same_exists(self): - option = self._make_one(True) - other = self._make_one(True) - self.assertTrue(option == other) - def test_modify_write(self): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import write +def test_make_retry_timeout_kwargs_retry_only(): + from google.api_core.retry import Retry + from google.cloud.firestore_v1._helpers import make_retry_timeout_kwargs - for exists in (True, False): - option = self._make_one(exists) - write_pb = write.Write() - ret_val = option.modify_write(write_pb) + retry = Retry(predicate=object()) + kwargs = make_retry_timeout_kwargs(retry, None) + expected = {"retry": retry} + assert kwargs == expected - self.assertIsNone(ret_val) - expected_doc = common.Precondition(exists=exists) - self.assertEqual(write_pb.current_document, expected_doc) +def test_make_retry_timeout_kwargs_timeout_only(): + from google.api_core.gapic_v1.method import DEFAULT + from google.cloud.firestore_v1._helpers import make_retry_timeout_kwargs -class Test_make_retry_timeout_kwargs(unittest.TestCase): - @staticmethod - def _call_fut(retry, timeout): - from google.cloud.firestore_v1._helpers import 
make_retry_timeout_kwargs + timeout = 123.0 + kwargs = make_retry_timeout_kwargs(DEFAULT, timeout) + expected = {"timeout": timeout} + assert kwargs == expected - return make_retry_timeout_kwargs(retry, timeout) - def test_default(self): - from google.api_core.gapic_v1.method import DEFAULT +def test_make_retry_timeout_kwargs_retry_and_timeout(): + from google.api_core.retry import Retry + from google.cloud.firestore_v1._helpers import make_retry_timeout_kwargs - kwargs = self._call_fut(DEFAULT, None) - expected = {} - self.assertEqual(kwargs, expected) + retry = Retry(predicate=object()) + timeout = 123.0 + kwargs = make_retry_timeout_kwargs(retry, timeout) + expected = {"retry": retry, "timeout": timeout} + assert kwargs == expected - def test_retry_None(self): - kwargs = self._call_fut(None, None) - expected = {"retry": None} - self.assertEqual(kwargs, expected) - def test_retry_only(self): - from google.api_core.retry import Retry +@pytest.mark.asyncio +async def test_asyncgenerator_async_iter(): + from typing import List - retry = Retry(predicate=object()) - kwargs = self._call_fut(retry, None) - expected = {"retry": retry} - self.assertEqual(kwargs, expected) + consumed: List[int] = [] + async for el in AsyncIter([1, 2, 3]): + consumed.append(el) + assert consumed == [1, 2, 3] - def test_timeout_only(self): - from google.api_core.gapic_v1.method import DEFAULT - timeout = 123.0 - kwargs = self._call_fut(DEFAULT, timeout) - expected = {"timeout": timeout} - self.assertEqual(kwargs, expected) +class AsyncMock(mock.MagicMock): + async def __call__(self, *args, **kwargs): + return super(AsyncMock, self).__call__(*args, **kwargs) - def test_retry_and_timeout(self): - from google.api_core.retry import Retry - retry = Retry(predicate=object()) - timeout = 123.0 - kwargs = self._call_fut(retry, timeout) - expected = {"retry": retry, "timeout": timeout} - self.assertEqual(kwargs, expected) +class AsyncIter: + """Utility to help recreate the effect of an async 
generator. Useful when + you need to mock a system that requires `async for`. + """ + def __init__(self, items): + self.items = items -class TestAsyncGenerator(aiounittest.AsyncTestCase): - @pytest.mark.asyncio - async def test_async_iter(self): - consumed: List[int] = [] - async for el in AsyncIter([1, 2, 3]): - consumed.append(el) - self.assertEqual(consumed, [1, 2, 3]) + async def __aiter__(self): + for i in self.items: + yield i def _value_pb(**kwargs): diff --git a/tests/unit/v1/test_async_batch.py b/tests/unit/v1/test_async_batch.py index 39f0d539141d4..6bed2351b3314 100644 --- a/tests/unit/v1/test_async_batch.py +++ b/tests/unit/v1/test_async_batch.py @@ -12,155 +12,150 @@ # See the License for the specific language governing permissions and # limitations under the License. +import mock import pytest -import aiounittest -import mock from tests.unit.v1.test__helpers import AsyncMock -class TestAsyncWriteBatch(aiounittest.AsyncTestCase): - """Tests the AsyncWriteBatch.commit method""" - - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.async_batch import AsyncWriteBatch - - return AsyncWriteBatch - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def test_constructor(self): - batch = self._make_one(mock.sentinel.client) - self.assertIs(batch._client, mock.sentinel.client) - self.assertEqual(batch._write_pbs, []) - self.assertIsNone(batch.write_results) - self.assertIsNone(batch.commit_time) - - async def _commit_helper(self, retry=None, timeout=None): - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write - - # Create a minimal fake GAPIC with a dummy result. 
- firestore_api = AsyncMock(spec=["commit"]) - timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) - commit_response = firestore.CommitResponse( - write_results=[write.WriteResult(), write.WriteResult()], - commit_time=timestamp, - ) - firestore_api.commit.return_value = commit_response - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - # Attach the fake GAPIC to a real client. - client = _make_client("grand") - client._firestore_api_internal = firestore_api - - # Actually make a batch with some mutations and call commit(). - batch = self._make_one(client) - document1 = client.document("a", "b") - batch.create(document1, {"ten": 10, "buck": "ets"}) - document2 = client.document("c", "d", "e", "f") - batch.delete(document2) +def _make_async_write_batch(client): + from google.cloud.firestore_v1.async_batch import AsyncWriteBatch + + return AsyncWriteBatch(client) + + +def test_constructor(): + batch = _make_async_write_batch(mock.sentinel.client) + assert batch._client is mock.sentinel.client + assert batch._write_pbs == [] + assert batch.write_results is None + assert batch.commit_time is None + + +async def _commit_helper(retry=None, timeout=None): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = AsyncMock(spec=["commit"]) + timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) + commit_response = firestore.CommitResponse( + write_results=[write.WriteResult(), write.WriteResult()], commit_time=timestamp, + ) + firestore_api.commit.return_value = commit_response + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Attach the fake GAPIC to a real client. 
+ client = _make_client("grand") + client._firestore_api_internal = firestore_api + + # Actually make a batch with some mutations and call commit(). + batch = _make_async_write_batch(client) + document1 = client.document("a", "b") + batch.create(document1, {"ten": 10, "buck": "ets"}) + document2 = client.document("c", "d", "e", "f") + batch.delete(document2) + write_pbs = batch._write_pbs[::] + + write_results = await batch.commit(**kwargs) + + assert write_results == list(commit_response.write_results) + assert batch.write_results == write_results + assert batch.commit_time.timestamp_pb() == timestamp + # Make sure batch has no more "changes". + assert batch._write_pbs == [] + + # Verify the mocks. + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_commit(): + await _commit_helper() + + +@pytest.mark.asyncio +async def test_commit_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + + await _commit_helper(retry=retry, timeout=timeout) + + +@pytest.mark.asyncio +async def test_as_context_mgr_wo_error(): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + + firestore_api = AsyncMock(spec=["commit"]) + timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) + commit_response = firestore.CommitResponse( + write_results=[write.WriteResult(), write.WriteResult()], commit_time=timestamp, + ) + firestore_api.commit.return_value = commit_response + client = _make_client() + client._firestore_api_internal = firestore_api + batch = _make_async_write_batch(client) + document1 = client.document("a", "b") + document2 = client.document("c", "d", "e", "f") + + async with batch as ctx_mgr: + assert ctx_mgr is batch 
+ ctx_mgr.create(document1, {"ten": 10, "buck": "ets"}) + ctx_mgr.delete(document2) write_pbs = batch._write_pbs[::] - write_results = await batch.commit(**kwargs) - - self.assertEqual(write_results, list(commit_response.write_results)) - self.assertEqual(batch.write_results, write_results) - self.assertEqual(batch.commit_time.timestamp_pb(), timestamp) - # Make sure batch has no more "changes". - self.assertEqual(batch._write_pbs, []) - - # Verify the mocks. - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": write_pbs, - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) - - @pytest.mark.asyncio - async def test_commit(self): - await self._commit_helper() - - @pytest.mark.asyncio - async def test_commit_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - - await self._commit_helper(retry=retry, timeout=timeout) - - @pytest.mark.asyncio - async def test_as_context_mgr_wo_error(self): - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write - - firestore_api = AsyncMock(spec=["commit"]) - timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) - commit_response = firestore.CommitResponse( - write_results=[write.WriteResult(), write.WriteResult()], - commit_time=timestamp, - ) - firestore_api.commit.return_value = commit_response - client = _make_client() - client._firestore_api_internal = firestore_api - batch = self._make_one(client) - document1 = client.document("a", "b") - document2 = client.document("c", "d", "e", "f") - + assert batch.write_results == list(commit_response.write_results) + assert batch.commit_time.timestamp_pb() == timestamp + # Make sure batch has no more "changes". + assert batch._write_pbs == [] + + # Verify the mocks. 
+ firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_as_context_mgr_w_error(): + firestore_api = AsyncMock(spec=["commit"]) + client = _make_client() + client._firestore_api_internal = firestore_api + batch = _make_async_write_batch(client) + document1 = client.document("a", "b") + document2 = client.document("c", "d", "e", "f") + + with pytest.raises(RuntimeError): async with batch as ctx_mgr: - self.assertIs(ctx_mgr, batch) ctx_mgr.create(document1, {"ten": 10, "buck": "ets"}) ctx_mgr.delete(document2) - write_pbs = batch._write_pbs[::] - - self.assertEqual(batch.write_results, list(commit_response.write_results)) - self.assertEqual(batch.commit_time.timestamp_pb(), timestamp) - # Make sure batch has no more "changes". - self.assertEqual(batch._write_pbs, []) - - # Verify the mocks. - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": write_pbs, - "transaction": None, - }, - metadata=client._rpc_metadata, - ) - - @pytest.mark.asyncio - async def test_as_context_mgr_w_error(self): - firestore_api = AsyncMock(spec=["commit"]) - client = _make_client() - client._firestore_api_internal = firestore_api - batch = self._make_one(client) - document1 = client.document("a", "b") - document2 = client.document("c", "d", "e", "f") - - with self.assertRaises(RuntimeError): - async with batch as ctx_mgr: - ctx_mgr.create(document1, {"ten": 10, "buck": "ets"}) - ctx_mgr.delete(document2) - raise RuntimeError("testing") - - # batch still has its changes, as _aexit_ (and commit) is not invoked - # changes are preserved so commit can be retried - self.assertIsNone(batch.write_results) - self.assertIsNone(batch.commit_time) - self.assertEqual(len(batch._write_pbs), 2) - - firestore_api.commit.assert_not_called() + raise RuntimeError("testing") + + # 
batch still has its changes, as _aexit_ (and commit) is not invoked + # changes are preserved so commit can be retried + assert batch.write_results is None + assert batch.commit_time is None + assert len(batch._write_pbs) == 2 + + firestore_api.commit.assert_not_called() def _make_credentials(): diff --git a/tests/unit/v1/test_async_client.py b/tests/unit/v1/test_async_client.py index 6d8c57c389c8c..3af0ef6d38fb3 100644 --- a/tests/unit/v1/test_async_client.py +++ b/tests/unit/v1/test_async_client.py @@ -12,495 +12,523 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import datetime import types -import aiounittest import mock -from google.cloud.firestore_v1.types.document import Document -from google.cloud.firestore_v1.types.firestore import RunQueryResponse -from tests.unit.v1.test__helpers import AsyncIter, AsyncMock +import pytest +from tests.unit.v1.test__helpers import AsyncIter +from tests.unit.v1.test__helpers import AsyncMock -class TestAsyncClient(aiounittest.AsyncTestCase): - PROJECT = "my-prahjekt" +PROJECT = "my-prahjekt" - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.async_client import AsyncClient - return AsyncClient +def _make_async_client(*args, **kwargs): + from google.cloud.firestore_v1.async_client import AsyncClient - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) + return AsyncClient(*args, **kwargs) - def _make_default_one(self): - credentials = _make_credentials() - return self._make_one(project=self.PROJECT, credentials=credentials) - def test_constructor(self): - from google.cloud.firestore_v1.async_client import _CLIENT_INFO - from google.cloud.firestore_v1.async_client import DEFAULT_DATABASE +def _make_default_async_client(): + credentials = _make_credentials() + return _make_async_client(project=PROJECT, credentials=credentials) - credentials = _make_credentials() - client = 
self._make_one(project=self.PROJECT, credentials=credentials) - self.assertEqual(client.project, self.PROJECT) - self.assertEqual(client._credentials, credentials) - self.assertEqual(client._database, DEFAULT_DATABASE) - self.assertIs(client._client_info, _CLIENT_INFO) - def test_constructor_explicit(self): - from google.api_core.client_options import ClientOptions +def test_asyncclient_constructor(): + from google.cloud.firestore_v1.async_client import _CLIENT_INFO + from google.cloud.firestore_v1.async_client import DEFAULT_DATABASE - credentials = _make_credentials() - database = "now-db" - client_info = mock.Mock() - client_options = ClientOptions("endpoint") - client = self._make_one( - project=self.PROJECT, - credentials=credentials, - database=database, - client_info=client_info, - client_options=client_options, - ) - self.assertEqual(client.project, self.PROJECT) - self.assertEqual(client._credentials, credentials) - self.assertEqual(client._database, database) - self.assertIs(client._client_info, client_info) - self.assertIs(client._client_options, client_options) - - def test_constructor_w_client_options(self): - credentials = _make_credentials() - client = self._make_one( - project=self.PROJECT, - credentials=credentials, - client_options={"api_endpoint": "foo-firestore.googleapis.com"}, - ) - self.assertEqual(client._target, "foo-firestore.googleapis.com") + credentials = _make_credentials() + client = _make_async_client(project=PROJECT, credentials=credentials) + assert client.project == PROJECT + assert client._credentials == credentials + assert client._database == DEFAULT_DATABASE + assert client._client_info is _CLIENT_INFO - def test_collection_factory(self): - from google.cloud.firestore_v1.async_collection import AsyncCollectionReference - collection_id = "users" - client = self._make_default_one() - collection = client.collection(collection_id) +def test_asyncclient_constructor_explicit(): + from google.api_core.client_options import 
ClientOptions - self.assertEqual(collection._path, (collection_id,)) - self.assertIs(collection._client, client) - self.assertIsInstance(collection, AsyncCollectionReference) + credentials = _make_credentials() + database = "now-db" + client_info = mock.Mock() + client_options = ClientOptions("endpoint") + client = _make_async_client( + project=PROJECT, + credentials=credentials, + database=database, + client_info=client_info, + client_options=client_options, + ) + assert client.project == PROJECT + assert client._credentials == credentials + assert client._database == database + assert client._client_info is client_info + assert client._client_options is client_options + + +def test_asyncclient_constructor_w_client_options(): + credentials = _make_credentials() + client = _make_async_client( + project=PROJECT, + credentials=credentials, + client_options={"api_endpoint": "foo-firestore.googleapis.com"}, + ) + assert client._target == "foo-firestore.googleapis.com" - def test_collection_factory_nested(self): - from google.cloud.firestore_v1.async_collection import AsyncCollectionReference - client = self._make_default_one() - parts = ("users", "alovelace", "beep") - collection_path = "/".join(parts) - collection1 = client.collection(collection_path) +def test_asyncclient_collection_factory(): + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference - self.assertEqual(collection1._path, parts) - self.assertIs(collection1._client, client) - self.assertIsInstance(collection1, AsyncCollectionReference) + collection_id = "users" + client = _make_default_async_client() + collection = client.collection(collection_id) - # Make sure using segments gives the same result. 
- collection2 = client.collection(*parts) - self.assertEqual(collection2._path, parts) - self.assertIs(collection2._client, client) - self.assertIsInstance(collection2, AsyncCollectionReference) + assert collection._path == (collection_id,) + assert collection._client is client + assert isinstance(collection, AsyncCollectionReference) - def test__get_collection_reference(self): - from google.cloud.firestore_v1.async_collection import AsyncCollectionReference - client = self._make_default_one() - collection = client._get_collection_reference("collectionId") +def test_asyncclient_collection_factory_nested(): + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference - self.assertIs(collection._client, client) - self.assertIsInstance(collection, AsyncCollectionReference) + client = _make_default_async_client() + parts = ("users", "alovelace", "beep") + collection_path = "/".join(parts) + collection1 = client.collection(collection_path) - def test_collection_group(self): - client = self._make_default_one() - query = client.collection_group("collectionId").where("foo", "==", "bar") + assert collection1._path == parts + assert collection1._client is client + assert isinstance(collection1, AsyncCollectionReference) - self.assertTrue(query._all_descendants) - self.assertEqual(query._field_filters[0].field.field_path, "foo") - self.assertEqual(query._field_filters[0].value.string_value, "bar") - self.assertEqual( - query._field_filters[0].op, query._field_filters[0].Operator.EQUAL - ) - self.assertEqual(query._parent.id, "collectionId") - - def test_collection_group_no_slashes(self): - client = self._make_default_one() - with self.assertRaises(ValueError): - client.collection_group("foo/bar") - - def test_document_factory(self): - from google.cloud.firestore_v1.async_document import AsyncDocumentReference - - parts = ("rooms", "roomA") - client = self._make_default_one() - doc_path = "/".join(parts) - document1 = client.document(doc_path) - - 
self.assertEqual(document1._path, parts) - self.assertIs(document1._client, client) - self.assertIsInstance(document1, AsyncDocumentReference) - - # Make sure using segments gives the same result. - document2 = client.document(*parts) - self.assertEqual(document2._path, parts) - self.assertIs(document2._client, client) - self.assertIsInstance(document2, AsyncDocumentReference) - - def test_document_factory_w_absolute_path(self): - from google.cloud.firestore_v1.async_document import AsyncDocumentReference - - parts = ("rooms", "roomA") - client = self._make_default_one() - doc_path = "/".join(parts) - to_match = client.document(doc_path) - document1 = client.document(to_match._document_path) - - self.assertEqual(document1._path, parts) - self.assertIs(document1._client, client) - self.assertIsInstance(document1, AsyncDocumentReference) - - def test_document_factory_w_nested_path(self): - from google.cloud.firestore_v1.async_document import AsyncDocumentReference - - client = self._make_default_one() - parts = ("rooms", "roomA", "shoes", "dressy") - doc_path = "/".join(parts) - document1 = client.document(doc_path) - - self.assertEqual(document1._path, parts) - self.assertIs(document1._client, client) - self.assertIsInstance(document1, AsyncDocumentReference) - - # Make sure using segments gives the same result. 
- document2 = client.document(*parts) - self.assertEqual(document2._path, parts) - self.assertIs(document2._client, client) - self.assertIsInstance(document2, AsyncDocumentReference) - - async def _collections_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1.async_collection import AsyncCollectionReference - from google.cloud.firestore_v1 import _helpers - - collection_ids = ["users", "projects"] - - class Pager(object): - async def __aiter__(self, **_): - for collection_id in collection_ids: - yield collection_id - - firestore_api = AsyncMock() - firestore_api.mock_add_spec(spec=["list_collection_ids"]) - firestore_api.list_collection_ids.return_value = Pager() - - client = self._make_default_one() - client._firestore_api_internal = firestore_api - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - collections = [c async for c in client.collections(**kwargs)] - - self.assertEqual(len(collections), len(collection_ids)) - for collection, collection_id in zip(collections, collection_ids): - self.assertIsInstance(collection, AsyncCollectionReference) - self.assertEqual(collection.parent, None) - self.assertEqual(collection.id, collection_id) - - base_path = client._database_string + "/documents" - firestore_api.list_collection_ids.assert_called_once_with( - request={"parent": base_path}, metadata=client._rpc_metadata, **kwargs, - ) + # Make sure using segments gives the same result. 
+ collection2 = client.collection(*parts) + assert collection2._path == parts + assert collection2._client is client + assert isinstance(collection2, AsyncCollectionReference) - @pytest.mark.asyncio - async def test_collections(self): - await self._collections_helper() - - @pytest.mark.asyncio - async def test_collections_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - await self._collections_helper(retry=retry, timeout=timeout) - - async def _invoke_get_all(self, client, references, document_pbs, **kwargs): - # Create a minimal fake GAPIC with a dummy response. - firestore_api = AsyncMock(spec=["batch_get_documents"]) - response_iterator = AsyncIter(document_pbs) - firestore_api.batch_get_documents.return_value = response_iterator - - # Attach the fake GAPIC to a real client. - client._firestore_api_internal = firestore_api - - # Actually call get_all(). - snapshots = client.get_all(references, **kwargs) - self.assertIsInstance(snapshots, types.AsyncGeneratorType) - - return [s async for s in snapshots] - - async def _get_all_helper( - self, num_snapshots=2, txn_id=None, retry=None, timeout=None - ): - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.async_document import DocumentSnapshot - - client = self._make_default_one() - - data1 = {"a": "cheese"} - document1 = client.document("pineapple", "lamp1") - document_pb1, read_time = _doc_get_info(document1._document_path, data1) - response1 = _make_batch_response(found=document_pb1, read_time=read_time) - - data2 = {"b": True, "c": 18} - document2 = client.document("pineapple", "lamp2") - document, read_time = _doc_get_info(document2._document_path, data2) - response2 = _make_batch_response(found=document, read_time=read_time) - - document3 = client.document("pineapple", "lamp3") - response3 = _make_batch_response(missing=document3._document_path) - - 
expected_data = [data1, data2, None][:num_snapshots] - documents = [document1, document2, document3][:num_snapshots] - responses = [response1, response2, response3][:num_snapshots] - field_paths = [ - field_path for field_path in ["a", "b", None][:num_snapshots] if field_path - ] - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - if txn_id is not None: - transaction = client.transaction() - transaction._id = txn_id - kwargs["transaction"] = transaction - - snapshots = await self._invoke_get_all( - client, documents, responses, field_paths=field_paths, **kwargs, - ) - self.assertEqual(len(snapshots), num_snapshots) - - for data, document, snapshot in zip(expected_data, documents, snapshots): - self.assertIsInstance(snapshot, DocumentSnapshot) - self.assertIs(snapshot._reference, document) - if data is None: - self.assertFalse(snapshot.exists) - else: - self.assertEqual(snapshot._data, data) - - # Verify the call to the mock. - doc_paths = [document._document_path for document in documents] - mask = common.DocumentMask(field_paths=field_paths) - - kwargs.pop("transaction", None) - - client._firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": doc_paths, - "mask": mask, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - **kwargs, - ) +def test_asyncclient__get_collection_reference(): + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference - @pytest.mark.asyncio - async def test_get_all(self): - await self._get_all_helper() - - @pytest.mark.asyncio - async def test_get_all_with_transaction(self): - txn_id = b"the-man-is-non-stop" - await self._get_all_helper(num_snapshots=1, txn_id=txn_id) - - @pytest.mark.asyncio - async def test_get_all_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - await self._get_all_helper(retry=retry, timeout=timeout) - - @pytest.mark.asyncio - 
async def test_get_all_wrong_order(self): - await self._get_all_helper(num_snapshots=3) - - @pytest.mark.asyncio - async def test_get_all_unknown_result(self): - from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE - - client = self._make_default_one() - - expected_document = client.document("pineapple", "lamp1") - - data = {"z": 28.5} - wrong_document = client.document("pineapple", "lamp2") - document_pb, read_time = _doc_get_info(wrong_document._document_path, data) - response = _make_batch_response(found=document_pb, read_time=read_time) - - # Exercise the mocked ``batch_get_documents``. - with self.assertRaises(ValueError) as exc_info: - await self._invoke_get_all(client, [expected_document], [response]) - - err_msg = _BAD_DOC_TEMPLATE.format(response.found.name) - self.assertEqual(exc_info.exception.args, (err_msg,)) - - # Verify the call to the mock. - doc_paths = [expected_document._document_path] - client._firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": doc_paths, - "mask": None, - "transaction": None, - }, - metadata=client._rpc_metadata, - ) + client = _make_default_async_client() + collection = client._get_collection_reference("collectionId") - def test_bulk_writer(self): - """BulkWriter is opaquely async and thus does not have a dedicated - async variant.""" - from google.cloud.firestore_v1.bulk_writer import BulkWriter - - client = self._make_default_one() - bulk_writer = client.bulk_writer() - self.assertIsInstance(bulk_writer, BulkWriter) - self.assertIs(bulk_writer._client, client._sync_copy) - - def test_sync_copy(self): - client = self._make_default_one() - # Multiple calls to this method should return the same cached instance. 
- self.assertIs(client._to_sync_copy(), client._to_sync_copy()) - - @pytest.mark.asyncio - async def test_recursive_delete(self): - client = self._make_default_one() - client._firestore_api_internal = AsyncMock(spec=["run_query"]) - collection_ref = client.collection("my_collection") - - results = [] - for index in range(10): - results.append( - RunQueryResponse(document=Document(name=f"{collection_ref.id}/{index}")) - ) + assert collection._client is client + assert isinstance(collection, AsyncCollectionReference) - chunks = [ - results[:3], - results[3:6], - results[6:9], - results[9:], - ] - def _get_chunk(*args, **kwargs): - return AsyncIter(items=chunks.pop(0)) +def test_asyncclient_collection_group(): + client = _make_default_async_client() + query = client.collection_group("collectionId").where("foo", "==", "bar") - client._firestore_api_internal.run_query.side_effect = _get_chunk + assert query._all_descendants + assert query._field_filters[0].field.field_path == "foo" + assert query._field_filters[0].value.string_value == "bar" + assert query._field_filters[0].op == query._field_filters[0].Operator.EQUAL + assert query._parent.id == "collectionId" - bulk_writer = mock.MagicMock() - bulk_writer.mock_add_spec(spec=["delete", "close"]) - num_deleted = await client.recursive_delete( - collection_ref, bulk_writer=bulk_writer, chunk_size=3 - ) - self.assertEqual(num_deleted, len(results)) +def test_asyncclient_collection_group_no_slashes(): + client = _make_default_async_client() + with pytest.raises(ValueError): + client.collection_group("foo/bar") - @pytest.mark.asyncio - async def test_recursive_delete_from_document(self): - client = self._make_default_one() - client._firestore_api_internal = mock.Mock( - spec=["run_query", "list_collection_ids"] - ) - collection_ref = client.collection("my_collection") - collection_1_id: str = "collection_1_id" - collection_2_id: str = "collection_2_id" +def test_asyncclient_document_factory(): + from 
google.cloud.firestore_v1.async_document import AsyncDocumentReference - parent_doc = collection_ref.document("parent") + parts = ("rooms", "roomA") + client = _make_default_async_client() + doc_path = "/".join(parts) + document1 = client.document(doc_path) - collection_1_results = [] - collection_2_results = [] + assert document1._path == parts + assert document1._client is client + assert isinstance(document1, AsyncDocumentReference) - for index in range(10): - collection_1_results.append( - RunQueryResponse(document=Document(name=f"{collection_1_id}/{index}"),), - ) + # Make sure using segments gives the same result. + document2 = client.document(*parts) + assert document2._path == parts + assert document2._client is client + assert isinstance(document2, AsyncDocumentReference) + + +def test_asyncclient_document_factory_w_absolute_path(): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + parts = ("rooms", "roomA") + client = _make_default_async_client() + doc_path = "/".join(parts) + to_match = client.document(doc_path) + document1 = client.document(to_match._document_path) + + assert document1._path == parts + assert document1._client is client + assert isinstance(document1, AsyncDocumentReference) + + +def test_asyncclient_document_factory_w_nested_path(): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + + client = _make_default_async_client() + parts = ("rooms", "roomA", "shoes", "dressy") + doc_path = "/".join(parts) + document1 = client.document(doc_path) + + assert document1._path == parts + assert document1._client is client + assert isinstance(document1, AsyncDocumentReference) + + # Make sure using segments gives the same result. 
+ document2 = client.document(*parts) + assert document2._path == parts + assert document2._client is client + assert isinstance(document2, AsyncDocumentReference) + + +async def _collections_helper(retry=None, timeout=None): + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + from google.cloud.firestore_v1 import _helpers + + collection_ids = ["users", "projects"] + + class Pager(object): + async def __aiter__(self, **_): + for collection_id in collection_ids: + yield collection_id + + firestore_api = AsyncMock() + firestore_api.mock_add_spec(spec=["list_collection_ids"]) + firestore_api.list_collection_ids.return_value = Pager() + + client = _make_default_async_client() + client._firestore_api_internal = firestore_api + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + collections = [c async for c in client.collections(**kwargs)] + + assert len(collections) == len(collection_ids) + for collection, collection_id in zip(collections, collection_ids): + assert isinstance(collection, AsyncCollectionReference) + assert collection.parent is None + assert collection.id == collection_id + + base_path = client._database_string + "/documents" + firestore_api.list_collection_ids.assert_called_once_with( + request={"parent": base_path}, metadata=client._rpc_metadata, **kwargs, + ) + + +@pytest.mark.asyncio +async def test_asyncclient_collections(): + await _collections_helper() + + +@pytest.mark.asyncio +async def test_asyncclient_collections_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _collections_helper(retry=retry, timeout=timeout) + + +async def _invoke_get_all(client, references, document_pbs, **kwargs): + # Create a minimal fake GAPIC with a dummy response. 
+ firestore_api = AsyncMock(spec=["batch_get_documents"]) + response_iterator = AsyncIter(document_pbs) + firestore_api.batch_get_documents.return_value = response_iterator + + # Attach the fake GAPIC to a real client. + client._firestore_api_internal = firestore_api + + # Actually call get_all(). + snapshots = client.get_all(references, **kwargs) + assert isinstance(snapshots, types.AsyncGeneratorType) + + return [s async for s in snapshots] + + +async def _get_all_helper(num_snapshots=2, txn_id=None, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.async_document import DocumentSnapshot + + client = _make_default_async_client() + + data1 = {"a": "cheese"} + document1 = client.document("pineapple", "lamp1") + document_pb1, read_time = _doc_get_info(document1._document_path, data1) + response1 = _make_batch_response(found=document_pb1, read_time=read_time) + + data2 = {"b": True, "c": 18} + document2 = client.document("pineapple", "lamp2") + document, read_time = _doc_get_info(document2._document_path, data2) + response2 = _make_batch_response(found=document, read_time=read_time) + + document3 = client.document("pineapple", "lamp3") + response3 = _make_batch_response(missing=document3._document_path) + + expected_data = [data1, data2, None][:num_snapshots] + documents = [document1, document2, document3][:num_snapshots] + responses = [response1, response2, response3][:num_snapshots] + field_paths = [ + field_path for field_path in ["a", "b", None][:num_snapshots] if field_path + ] + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + if txn_id is not None: + transaction = client.transaction() + transaction._id = txn_id + kwargs["transaction"] = transaction + + snapshots = await _invoke_get_all( + client, documents, responses, field_paths=field_paths, **kwargs, + ) + + assert len(snapshots) == num_snapshots + + for data, document, snapshot in 
zip(expected_data, documents, snapshots): + assert isinstance(snapshot, DocumentSnapshot) + assert snapshot._reference is document + if data is None: + assert not snapshot.exists + else: + assert snapshot._data == data + + # Verify the call to the mock. + doc_paths = [document._document_path for document in documents] + mask = common.DocumentMask(field_paths=field_paths) + + kwargs.pop("transaction", None) + + client._firestore_api.batch_get_documents.assert_called_once_with( + request={ + "database": client._database_string, + "documents": doc_paths, + "mask": mask, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_asyncclient_get_all(): + await _get_all_helper() - collection_2_results.append( - RunQueryResponse(document=Document(name=f"{collection_2_id}/{index}"),), - ) - col_1_chunks = [ - collection_1_results[:3], - collection_1_results[3:6], - collection_1_results[6:9], - collection_1_results[9:], - ] - - col_2_chunks = [ - collection_2_results[:3], - collection_2_results[3:6], - collection_2_results[6:9], - collection_2_results[9:], - ] - - async def _get_chunk(*args, **kwargs): - start_at = ( - kwargs["request"]["structured_query"].start_at.values[0].reference_value +@pytest.mark.asyncio +async def test_asyncclient_get_all_with_transaction(): + txn_id = b"the-man-is-non-stop" + await _get_all_helper(num_snapshots=1, txn_id=txn_id) + + +@pytest.mark.asyncio +async def test_asyncclient_get_all_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _get_all_helper(retry=retry, timeout=timeout) + + +@pytest.mark.asyncio +async def test_asyncclient_get_all_wrong_order(): + await _get_all_helper(num_snapshots=3) + + +@pytest.mark.asyncio +async def test_asyncclient_get_all_unknown_result(): + from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE + + client = _make_default_async_client() + + expected_document = 
client.document("pineapple", "lamp1") + + data = {"z": 28.5} + wrong_document = client.document("pineapple", "lamp2") + document_pb, read_time = _doc_get_info(wrong_document._document_path, data) + response = _make_batch_response(found=document_pb, read_time=read_time) + + # Exercise the mocked ``batch_get_documents``. + with pytest.raises(ValueError) as exc_info: + await _invoke_get_all(client, [expected_document], [response]) + + err_msg = _BAD_DOC_TEMPLATE.format(response.found.name) + assert exc_info.value.args == (err_msg,) + + # Verify the call to the mock. + doc_paths = [expected_document._document_path] + client._firestore_api.batch_get_documents.assert_called_once_with( + request={ + "database": client._database_string, + "documents": doc_paths, + "mask": None, + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +def test_asyncclient_bulk_writer(): + """BulkWriter is opaquely async and thus does not have a dedicated + async variant.""" + from google.cloud.firestore_v1.bulk_writer import BulkWriter + + client = _make_default_async_client() + bulk_writer = client.bulk_writer() + assert isinstance(bulk_writer, BulkWriter) + assert bulk_writer._client is client._sync_copy + + +def test_asyncclient_sync_copy(): + client = _make_default_async_client() + # Multiple calls to this method should return the same cached instance. 
+ assert client._to_sync_copy() is client._to_sync_copy() + + +@pytest.mark.asyncio +async def test_asyncclient_recursive_delete(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import firestore + + client = _make_default_async_client() + client._firestore_api_internal = AsyncMock(spec=["run_query"]) + collection_ref = client.collection("my_collection") + + results = [] + for index in range(10): + results.append( + firestore.RunQueryResponse( + document=document.Document(name=f"{collection_ref.id}/{index}") ) + ) + + chunks = [ + results[:3], + results[3:6], + results[6:9], + results[9:], + ] - if collection_1_id in start_at: - return AsyncIter(col_1_chunks.pop(0)) - return AsyncIter(col_2_chunks.pop(0)) + def _get_chunk(*args, **kwargs): + return AsyncIter(items=chunks.pop(0)) - async def _get_collections(*args, **kwargs): - return AsyncIter([collection_1_id, collection_2_id]) + client._firestore_api_internal.run_query.side_effect = _get_chunk - client._firestore_api_internal.run_query.side_effect = _get_chunk - client._firestore_api_internal.list_collection_ids.side_effect = ( - _get_collections + bulk_writer = mock.MagicMock() + bulk_writer.mock_add_spec(spec=["delete", "close"]) + + num_deleted = await client.recursive_delete( + collection_ref, bulk_writer=bulk_writer, chunk_size=3 + ) + assert num_deleted == len(results) + + +@pytest.mark.asyncio +async def test_asyncclient_recursive_delete_from_document(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import firestore + + client = _make_default_async_client() + client._firestore_api_internal = mock.Mock( + spec=["run_query", "list_collection_ids"] + ) + collection_ref = client.collection("my_collection") + + collection_1_id: str = "collection_1_id" + collection_2_id: str = "collection_2_id" + + parent_doc = collection_ref.document("parent") + + collection_1_results = [] + collection_2_results = [] + + for index in 
range(10): + collection_1_results.append( + firestore.RunQueryResponse( + document=document.Document(name=f"{collection_1_id}/{index}"), + ), ) - bulk_writer = mock.MagicMock() - bulk_writer.mock_add_spec(spec=["delete", "close"]) + collection_2_results.append( + firestore.RunQueryResponse( + document=document.Document(name=f"{collection_2_id}/{index}"), + ), + ) - num_deleted = await client.recursive_delete( - parent_doc, bulk_writer=bulk_writer, chunk_size=3 + col_1_chunks = [ + collection_1_results[:3], + collection_1_results[3:6], + collection_1_results[6:9], + collection_1_results[9:], + ] + + col_2_chunks = [ + collection_2_results[:3], + collection_2_results[3:6], + collection_2_results[6:9], + collection_2_results[9:], + ] + + async def _get_chunk(*args, **kwargs): + start_at = ( + kwargs["request"]["structured_query"].start_at.values[0].reference_value ) - expected_len = len(collection_1_results) + len(collection_2_results) + 1 - self.assertEqual(num_deleted, expected_len) - - @pytest.mark.asyncio - async def test_recursive_delete_raises(self): - client = self._make_default_one() - with self.assertRaises(TypeError): - await client.recursive_delete(object()) - - def test_batch(self): - from google.cloud.firestore_v1.async_batch import AsyncWriteBatch - - client = self._make_default_one() - batch = client.batch() - self.assertIsInstance(batch, AsyncWriteBatch) - self.assertIs(batch._client, client) - self.assertEqual(batch._write_pbs, []) - - def test_transaction(self): - from google.cloud.firestore_v1.async_transaction import AsyncTransaction - - client = self._make_default_one() - transaction = client.transaction(max_attempts=3, read_only=True) - self.assertIsInstance(transaction, AsyncTransaction) - self.assertEqual(transaction._write_pbs, []) - self.assertEqual(transaction._max_attempts, 3) - self.assertTrue(transaction._read_only) - self.assertIsNone(transaction._id) + if collection_1_id in start_at: + return AsyncIter(col_1_chunks.pop(0)) + return 
AsyncIter(col_2_chunks.pop(0)) + + async def _get_collections(*args, **kwargs): + return AsyncIter([collection_1_id, collection_2_id]) + + client._firestore_api_internal.run_query.side_effect = _get_chunk + client._firestore_api_internal.list_collection_ids.side_effect = _get_collections + + bulk_writer = mock.MagicMock() + bulk_writer.mock_add_spec(spec=["delete", "close"]) + + num_deleted = await client.recursive_delete( + parent_doc, bulk_writer=bulk_writer, chunk_size=3 + ) + + expected_len = len(collection_1_results) + len(collection_2_results) + 1 + assert num_deleted == expected_len + + +@pytest.mark.asyncio +async def test_asyncclient_recursive_delete_raises(): + client = _make_default_async_client() + with pytest.raises(TypeError): + await client.recursive_delete(object()) + + +def test_asyncclient_batch(): + from google.cloud.firestore_v1.async_batch import AsyncWriteBatch + + client = _make_default_async_client() + batch = client.batch() + assert isinstance(batch, AsyncWriteBatch) + assert batch._client is client + assert batch._write_pbs == [] + + +def test_asyncclient_transaction(): + from google.cloud.firestore_v1.async_transaction import AsyncTransaction + + client = _make_default_async_client() + transaction = client.transaction(max_attempts=3, read_only=True) + assert isinstance(transaction, AsyncTransaction) + assert transaction._write_pbs == [] + assert transaction._max_attempts == 3 + assert transaction._read_only + assert transaction._id is None def _make_credentials(): diff --git a/tests/unit/v1/test_async_collection.py b/tests/unit/v1/test_async_collection.py index 1955ca52defa9..69a33d11224da 100644 --- a/tests/unit/v1/test_async_collection.py +++ b/tests/unit/v1/test_async_collection.py @@ -12,412 +12,425 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from google.cloud.firestore_v1.types.document import Document -from google.cloud.firestore_v1.types.firestore import RunQueryResponse -import pytest import types -import aiounittest import mock -from tests.unit.v1.test__helpers import AsyncIter, AsyncMock - - -class TestAsyncCollectionReference(aiounittest.AsyncTestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.async_collection import AsyncCollectionReference - - return AsyncCollectionReference - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - @staticmethod - def _get_public_methods(klass): - return set().union( - *( - ( - name - for name, value in class_.__dict__.items() - if ( - not name.startswith("_") - and isinstance(value, types.FunctionType) - ) - ) - for class_ in (klass,) + klass.__bases__ - ) - ) +import pytest - def test_query_method_matching(self): - from google.cloud.firestore_v1.async_query import AsyncQuery - - query_methods = self._get_public_methods(AsyncQuery) - klass = self._get_target_class() - collection_methods = self._get_public_methods(klass) - # Make sure every query method is present on - # ``AsyncCollectionReference``. 
- self.assertLessEqual(query_methods, collection_methods) - - def test_document_name_default(self): - client = _make_client() - document = client.collection("test").document() - # name is random, but assert it is not None - self.assertTrue(document.id is not None) - - def test_constructor(self): - collection_id1 = "rooms" - document_id = "roomA" - collection_id2 = "messages" - client = mock.sentinel.client - - collection = self._make_one( - collection_id1, document_id, collection_id2, client=client - ) - self.assertIs(collection._client, client) - expected_path = (collection_id1, document_id, collection_id2) - self.assertEqual(collection._path, expected_path) - - @pytest.mark.asyncio - async def test_add_auto_assigned(self): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.async_document import AsyncDocumentReference - from google.cloud.firestore_v1 import SERVER_TIMESTAMP - from google.cloud.firestore_v1._helpers import pbs_for_create - - # Create a minimal fake GAPIC add attach it to a real client. 
- firestore_api = AsyncMock(spec=["create_document", "commit"]) - write_result = mock.Mock( - update_time=mock.sentinel.update_time, spec=["update_time"] - ) - commit_response = mock.Mock( - write_results=[write_result], - spec=["write_results", "commit_time"], - commit_time=mock.sentinel.commit_time, +from tests.unit.v1.test__helpers import AsyncIter +from tests.unit.v1.test__helpers import AsyncMock + + +def _make_async_collection_reference(*args, **kwargs): + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + + return AsyncCollectionReference(*args, **kwargs) + + +def _get_public_methods(klass): + return set().union( + *( + ( + name + for name, value in class_.__dict__.items() + if (not name.startswith("_") and isinstance(value, types.FunctionType)) + ) + for class_ in (klass,) + klass.__bases__ ) - firestore_api.commit.return_value = commit_response - create_doc_response = document.Document() - firestore_api.create_document.return_value = create_doc_response - client = _make_client() - client._firestore_api_internal = firestore_api - - # Actually make a collection. - collection = self._make_one("grand-parent", "parent", "child", client=client) - - # Actually call add() on our collection; include a transform to make - # sure transforms during adds work. - document_data = {"been": "here", "now": SERVER_TIMESTAMP} - - patch = mock.patch("google.cloud.firestore_v1.base_collection._auto_id") - random_doc_id = "DEADBEEF" - with patch as patched: - patched.return_value = random_doc_id - update_time, document_ref = await collection.add(document_data) - - # Verify the response and the mocks. 
- self.assertIs(update_time, mock.sentinel.update_time) - self.assertIsInstance(document_ref, AsyncDocumentReference) - self.assertIs(document_ref._client, client) - expected_path = collection._path + (random_doc_id,) - self.assertEqual(document_ref._path, expected_path) - - write_pbs = pbs_for_create(document_ref._document_path, document_data) - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": write_pbs, - "transaction": None, - }, - metadata=client._rpc_metadata, + ) + + +def test_asynccollectionreference_constructor(): + collection_id1 = "rooms" + document_id = "roomA" + collection_id2 = "messages" + client = mock.sentinel.client + + collection = _make_async_collection_reference( + collection_id1, document_id, collection_id2, client=client + ) + assert collection._client is client + expected_path = (collection_id1, document_id, collection_id2) + assert collection._path == expected_path + + +def test_asynccollectionreference_query_method_matching(): + from google.cloud.firestore_v1.async_query import AsyncQuery + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference + + query_methods = _get_public_methods(AsyncQuery) + collection_methods = _get_public_methods(AsyncCollectionReference) + # Make sure every query method is present on + # ``AsyncCollectionReference``. 
+ assert query_methods <= collection_methods + + +def test_asynccollectionreference_document_name_default(): + client = _make_client() + document = client.collection("test").document() + # name is random, but assert it is not None + assert document.id is not None + + +@pytest.mark.asyncio +async def test_asynccollectionreference_add_auto_assigned(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1 import SERVER_TIMESTAMP + from google.cloud.firestore_v1._helpers import pbs_for_create + + # Create a minimal fake GAPIC add attach it to a real client. + firestore_api = AsyncMock(spec=["create_document", "commit"]) + write_result = mock.Mock( + update_time=mock.sentinel.update_time, spec=["update_time"] + ) + commit_response = mock.Mock( + write_results=[write_result], + spec=["write_results", "commit_time"], + commit_time=mock.sentinel.commit_time, + ) + firestore_api.commit.return_value = commit_response + create_doc_response = document.Document() + firestore_api.create_document.return_value = create_doc_response + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a collection. + collection = _make_async_collection_reference( + "grand-parent", "parent", "child", client=client + ) + + # Actually call add() on our collection; include a transform to make + # sure transforms during adds work. + document_data = {"been": "here", "now": SERVER_TIMESTAMP} + + patch = mock.patch("google.cloud.firestore_v1.base_collection._auto_id") + random_doc_id = "DEADBEEF" + with patch as patched: + patched.return_value = random_doc_id + update_time, document_ref = await collection.add(document_data) + + # Verify the response and the mocks. 
+ assert update_time is mock.sentinel.update_time + assert isinstance(document_ref, AsyncDocumentReference) + assert document_ref._client is client + expected_path = collection._path + (random_doc_id,) + assert document_ref._path == expected_path + + write_pbs = pbs_for_create(document_ref._document_path, document_data) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + # Since we generate the ID locally, we don't call 'create_document'. + firestore_api.create_document.assert_not_called() + + +def _write_pb_for_create(document_path, document_data): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1 import _helpers + + return write.Write( + update=document.Document( + name=document_path, fields=_helpers.encode_dict(document_data) + ), + current_document=common.Precondition(exists=False), + ) + + +async def _add_helper(retry=None, timeout=None): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1 import _helpers + + # Create a minimal fake GAPIC with a dummy response. + firestore_api = AsyncMock(spec=["commit"]) + write_result = mock.Mock( + update_time=mock.sentinel.update_time, spec=["update_time"] + ) + commit_response = mock.Mock( + write_results=[write_result], + spec=["write_results", "commit_time"], + commit_time=mock.sentinel.commit_time, + ) + firestore_api.commit.return_value = commit_response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a collection and call add(). 
+ collection = _make_async_collection_reference("parent", client=client) + document_data = {"zorp": 208.75, "i-did-not": b"know that"} + doc_id = "child" + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + update_time, document_ref = await collection.add( + document_data, document_id=doc_id, **kwargs, + ) + + # Verify the response and the mocks. + assert update_time is mock.sentinel.update_time + assert isinstance(document_ref, AsyncDocumentReference) + assert document_ref._client is client + assert document_ref._path == (collection.id, doc_id) + + write_pb = _write_pb_for_create(document_ref._document_path, document_data) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_asynccollectionreference_add_explicit_id(): + await _add_helper() + + +@pytest.mark.asyncio +async def test_asynccollectionreference_add_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _add_helper(retry=retry, timeout=timeout) + + +@pytest.mark.asyncio +async def test_asynccollectionreference_chunkify(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import firestore + + client = _make_client() + col = client.collection("my-collection") + + client._firestore_api_internal = mock.Mock(spec=["run_query"]) + + results = [] + for index in range(10): + name = ( + f"projects/project-project/databases/(default)/" + f"documents/my-collection/{index}" ) - # Since we generate the ID locally, we don't call 'create_document'. 
- firestore_api.create_document.assert_not_called() - - @staticmethod - def _write_pb_for_create(document_path, document_data): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1 import _helpers - - return write.Write( - update=document.Document( - name=document_path, fields=_helpers.encode_dict(document_data) - ), - current_document=common.Precondition(exists=False), + results.append( + firestore.RunQueryResponse(document=document.Document(name=name),), ) - async def _add_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1.async_document import AsyncDocumentReference - from google.cloud.firestore_v1 import _helpers + chunks = [ + results[:3], + results[3:6], + results[6:9], + results[9:], + ] + + async def _get_chunk(*args, **kwargs): + return AsyncIter(chunks.pop(0)) + + client._firestore_api_internal.run_query.side_effect = _get_chunk + + counter = 0 + expected_lengths = [3, 3, 3, 1] + async for chunk in col._chunkify(3): + msg = f"Expected chunk of length {expected_lengths[counter]} at index {counter}. Saw {len(chunk)}." 
+ assert len(chunk) == expected_lengths[counter], msg + counter += 1 + + +@pytest.mark.asyncio +async def _list_documents_helper(page_size=None, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + from google.api_core.page_iterator_async import AsyncIterator + from google.api_core.page_iterator import Page + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1.types.document import Document + + class _AsyncIterator(AsyncIterator): + def __init__(self, pages): + super(_AsyncIterator, self).__init__(client=None) + self._pages = pages + + async def _next_page(self): + if self._pages: + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + client = _make_client() + template = client._database_string + "/documents/{}" + document_ids = ["doc-1", "doc-2"] + documents = [ + Document(name=template.format(document_id)) for document_id in document_ids + ] + iterator = _AsyncIterator(pages=[documents]) + firestore_api = AsyncMock() + firestore_api.mock_add_spec(spec=["list_documents"]) + firestore_api.list_documents.return_value = iterator + client._firestore_api_internal = firestore_api + collection = _make_async_collection_reference("collection", client=client) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + if page_size is not None: + documents = [ + i async for i in collection.list_documents(page_size=page_size, **kwargs,) + ] + else: + documents = [i async for i in collection.list_documents(**kwargs)] - # Create a minimal fake GAPIC with a dummy response. 
- firestore_api = AsyncMock(spec=["commit"]) - write_result = mock.Mock( - update_time=mock.sentinel.update_time, spec=["update_time"] - ) - commit_response = mock.Mock( - write_results=[write_result], - spec=["write_results", "commit_time"], - commit_time=mock.sentinel.commit_time, - ) - firestore_api.commit.return_value = commit_response + # Verify the response and the mocks. + assert len(documents) == len(document_ids) + for document, document_id in zip(documents, document_ids): + assert isinstance(document, AsyncDocumentReference) + assert document.parent == collection + assert document.id == document_id - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api + parent, _ = collection._parent_info() + firestore_api.list_documents.assert_called_once_with( + request={ + "parent": parent, + "collection_id": collection.id, + "page_size": page_size, + "show_missing": True, + "mask": {"field_paths": None}, + }, + metadata=client._rpc_metadata, + **kwargs, + ) - # Actually make a collection and call add(). - collection = self._make_one("parent", client=client) - document_data = {"zorp": 208.75, "i-did-not": b"know that"} - doc_id = "child" - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - update_time, document_ref = await collection.add( - document_data, document_id=doc_id, **kwargs, - ) +@pytest.mark.asyncio +async def test_asynccollectionreference_list_documents_wo_page_size(): + await _list_documents_helper() - # Verify the response and the mocks. 
- self.assertIs(update_time, mock.sentinel.update_time) - self.assertIsInstance(document_ref, AsyncDocumentReference) - self.assertIs(document_ref._client, client) - self.assertEqual(document_ref._path, (collection.id, doc_id)) - - write_pb = self._write_pb_for_create(document_ref._document_path, document_data) - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": [write_pb], - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) - @pytest.mark.asyncio - async def test_add_explicit_id(self): - await self._add_helper() - - @pytest.mark.asyncio - async def test_add_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - await self._add_helper(retry=retry, timeout=timeout) - - @pytest.mark.asyncio - async def test_chunkify(self): - client = _make_client() - col = client.collection("my-collection") - - client._firestore_api_internal = mock.Mock(spec=["run_query"]) - - results = [] - for index in range(10): - results.append( - RunQueryResponse( - document=Document( - name=f"projects/project-project/databases/(default)/documents/my-collection/{index}", - ), - ), - ) +@pytest.mark.asyncio +async def test_asynccollectionreference_list_documents_w_retry_timeout(): + from google.api_core.retry import Retry - chunks = [ - results[:3], - results[3:6], - results[6:9], - results[9:], - ] + retry = Retry(predicate=object()) + timeout = 123.0 + await _list_documents_helper(retry=retry, timeout=timeout) - async def _get_chunk(*args, **kwargs): - return AsyncIter(chunks.pop(0)) - - client._firestore_api_internal.run_query.side_effect = _get_chunk - - counter = 0 - expected_lengths = [3, 3, 3, 1] - async for chunk in col._chunkify(3): - msg = f"Expected chunk of length {expected_lengths[counter]} at index {counter}. Saw {len(chunk)}." 
- self.assertEqual(len(chunk), expected_lengths[counter], msg) - counter += 1 - - @pytest.mark.asyncio - async def _list_documents_helper(self, page_size=None, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - from google.api_core.page_iterator_async import AsyncIterator - from google.api_core.page_iterator import Page - from google.cloud.firestore_v1.async_document import AsyncDocumentReference - from google.cloud.firestore_v1.types.document import Document - - class _AsyncIterator(AsyncIterator): - def __init__(self, pages): - super(_AsyncIterator, self).__init__(client=None) - self._pages = pages - - async def _next_page(self): - if self._pages: - page, self._pages = self._pages[0], self._pages[1:] - return Page(self, page, self.item_to_value) - - client = _make_client() - template = client._database_string + "/documents/{}" - document_ids = ["doc-1", "doc-2"] - documents = [ - Document(name=template.format(document_id)) for document_id in document_ids - ] - iterator = _AsyncIterator(pages=[documents]) - firestore_api = AsyncMock() - firestore_api.mock_add_spec(spec=["list_documents"]) - firestore_api.list_documents.return_value = iterator - client._firestore_api_internal = firestore_api - collection = self._make_one("collection", client=client) - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - if page_size is not None: - documents = [ - i - async for i in collection.list_documents(page_size=page_size, **kwargs,) - ] - else: - documents = [i async for i in collection.list_documents(**kwargs)] - - # Verify the response and the mocks. 
- self.assertEqual(len(documents), len(document_ids)) - for document, document_id in zip(documents, document_ids): - self.assertIsInstance(document, AsyncDocumentReference) - self.assertEqual(document.parent, collection) - self.assertEqual(document.id, document_id) - - parent, _ = collection._parent_info() - firestore_api.list_documents.assert_called_once_with( - request={ - "parent": parent, - "collection_id": collection.id, - "page_size": page_size, - "show_missing": True, - "mask": {"field_paths": None}, - }, - metadata=client._rpc_metadata, - **kwargs, - ) - @pytest.mark.asyncio - async def test_list_documents_wo_page_size(self): - await self._list_documents_helper() +@pytest.mark.asyncio +async def test_asynccollectionreference_list_documents_w_page_size(): + await _list_documents_helper(page_size=25) - @pytest.mark.asyncio - async def test_list_documents_w_retry_timeout(self): - from google.api_core.retry import Retry - retry = Retry(predicate=object()) - timeout = 123.0 - await self._list_documents_helper(retry=retry, timeout=timeout) +@mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) +@pytest.mark.asyncio +async def test_asynccollectionreference_get(query_class): + collection = _make_async_collection_reference("collection") + get_response = await collection.get() - @pytest.mark.asyncio - async def test_list_documents_w_page_size(self): - await self._list_documents_helper(page_size=25) + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value - @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) - @pytest.mark.asyncio - async def test_get(self, query_class): - collection = self._make_one("collection") - get_response = await collection.get() + assert get_response is query_instance.get.return_value + query_instance.get.assert_called_once_with(transaction=None) - query_class.assert_called_once_with(collection) - query_instance = query_class.return_value - 
self.assertIs(get_response, query_instance.get.return_value) - query_instance.get.assert_called_once_with(transaction=None) +@mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) +@pytest.mark.asyncio +async def test_asynccollectionreference_get_w_retry_timeout(query_class): + from google.api_core.retry import Retry - @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) - @pytest.mark.asyncio - async def test_get_w_retry_timeout(self, query_class): - from google.api_core.retry import Retry + retry = Retry(predicate=object()) + timeout = 123.0 + collection = _make_async_collection_reference("collection") + get_response = await collection.get(retry=retry, timeout=timeout) - retry = Retry(predicate=object()) - timeout = 123.0 - collection = self._make_one("collection") - get_response = await collection.get(retry=retry, timeout=timeout) + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value - query_class.assert_called_once_with(collection) - query_instance = query_class.return_value + assert get_response is query_instance.get.return_value + query_instance.get.assert_called_once_with( + transaction=None, retry=retry, timeout=timeout, + ) - self.assertIs(get_response, query_instance.get.return_value) - query_instance.get.assert_called_once_with( - transaction=None, retry=retry, timeout=timeout, - ) - @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) - @pytest.mark.asyncio - async def test_get_with_transaction(self, query_class): - collection = self._make_one("collection") - transaction = mock.sentinel.txn - get_response = await collection.get(transaction=transaction) +@mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) +@pytest.mark.asyncio +async def test_asynccollectionreference_get_with_transaction(query_class): + collection = _make_async_collection_reference("collection") + transaction = mock.sentinel.txn + get_response = 
await collection.get(transaction=transaction) - query_class.assert_called_once_with(collection) - query_instance = query_class.return_value + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value - self.assertIs(get_response, query_instance.get.return_value) - query_instance.get.assert_called_once_with(transaction=transaction) + assert get_response is query_instance.get.return_value + query_instance.get.assert_called_once_with(transaction=transaction) - @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) - @pytest.mark.asyncio - async def test_stream(self, query_class): - query_class.return_value.stream.return_value = AsyncIter(range(3)) - collection = self._make_one("collection") - stream_response = collection.stream() +@mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) +@pytest.mark.asyncio +async def test_asynccollectionreference_stream(query_class): + query_class.return_value.stream.return_value = AsyncIter(range(3)) - async for _ in stream_response: - pass + collection = _make_async_collection_reference("collection") + stream_response = collection.stream() - query_class.assert_called_once_with(collection) - query_instance = query_class.return_value - query_instance.stream.assert_called_once_with(transaction=None) + async for _ in stream_response: + pass - @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) - @pytest.mark.asyncio - async def test_stream_w_retry_timeout(self, query_class): - from google.api_core.retry import Retry + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + query_instance.stream.assert_called_once_with(transaction=None) - retry = Retry(predicate=object()) - timeout = 123.0 - query_class.return_value.stream.return_value = AsyncIter(range(3)) - collection = self._make_one("collection") - stream_response = collection.stream(retry=retry, timeout=timeout) 
+@mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) +@pytest.mark.asyncio +async def test_asynccollectionreference_stream_w_retry_timeout(query_class): + from google.api_core.retry import Retry - async for _ in stream_response: - pass + retry = Retry(predicate=object()) + timeout = 123.0 + query_class.return_value.stream.return_value = AsyncIter(range(3)) - query_class.assert_called_once_with(collection) - query_instance = query_class.return_value - query_instance.stream.assert_called_once_with( - transaction=None, retry=retry, timeout=timeout, - ) + collection = _make_async_collection_reference("collection") + stream_response = collection.stream(retry=retry, timeout=timeout) + + async for _ in stream_response: + pass + + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + query_instance.stream.assert_called_once_with( + transaction=None, retry=retry, timeout=timeout, + ) + + +@mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) +@pytest.mark.asyncio +async def test_asynccollectionreference_stream_with_transaction(query_class): + query_class.return_value.stream.return_value = AsyncIter(range(3)) - @mock.patch("google.cloud.firestore_v1.async_query.AsyncQuery", autospec=True) - @pytest.mark.asyncio - async def test_stream_with_transaction(self, query_class): - query_class.return_value.stream.return_value = AsyncIter(range(3)) + collection = _make_async_collection_reference("collection") + transaction = mock.sentinel.txn + stream_response = collection.stream(transaction=transaction) - collection = self._make_one("collection") - transaction = mock.sentinel.txn - stream_response = collection.stream(transaction=transaction) + async for _ in stream_response: + pass - async for _ in stream_response: - pass + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + query_instance.stream.assert_called_once_with(transaction=transaction) - 
query_class.assert_called_once_with(collection) - query_instance = query_class.return_value - query_instance.stream.assert_called_once_with(transaction=transaction) - def test_recursive(self): - from google.cloud.firestore_v1.async_query import AsyncQuery +def test_asynccollectionreference_recursive(): + from google.cloud.firestore_v1.async_query import AsyncQuery - col = self._make_one("collection") - self.assertIsInstance(col.recursive(), AsyncQuery) + col = _make_async_collection_reference("collection") + assert isinstance(col.recursive(), AsyncQuery) def _make_credentials(): diff --git a/tests/unit/v1/test_async_document.py b/tests/unit/v1/test_async_document.py index 701ef5a59dada..7d8558fe8ded9 100644 --- a/tests/unit/v1/test_async_document.py +++ b/tests/unit/v1/test_async_document.py @@ -12,561 +12,586 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import collections -import aiounittest import mock +import pytest + from tests.unit.v1.test__helpers import AsyncIter, AsyncMock -class TestAsyncDocumentReference(aiounittest.AsyncTestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.async_document import AsyncDocumentReference +def _make_async_document_reference(*args, **kwargs): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference - return AsyncDocumentReference + return AsyncDocumentReference(*args, **kwargs) - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - def test_constructor(self): - collection_id1 = "users" - document_id1 = "alovelace" - collection_id2 = "platform" - document_id2 = "*nix" - client = mock.MagicMock() - client.__hash__.return_value = 1234 +def test_asyncdocumentreference_constructor(): + collection_id1 = "users" + document_id1 = "alovelace" + collection_id2 = "platform" + document_id2 = "*nix" + client = mock.MagicMock() + client.__hash__.return_value = 
1234 + + document = _make_async_document_reference( + collection_id1, document_id1, collection_id2, document_id2, client=client + ) + assert document._client is client + expected_path = "/".join( + (collection_id1, document_id1, collection_id2, document_id2) + ) + assert document.path == expected_path - document = self._make_one( - collection_id1, document_id1, collection_id2, document_id2, client=client - ) - self.assertIs(document._client, client) - expected_path = "/".join( - (collection_id1, document_id1, collection_id2, document_id2) - ) - self.assertEqual(document.path, expected_path) - - @staticmethod - def _make_commit_repsonse(write_results=None): - from google.cloud.firestore_v1.types import firestore - - response = mock.create_autospec(firestore.CommitResponse) - response.write_results = write_results or [mock.sentinel.write_result] - response.commit_time = mock.sentinel.commit_time - return response - - @staticmethod - def _write_pb_for_create(document_path, document_data): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1 import _helpers - - return write.Write( - update=document.Document( - name=document_path, fields=_helpers.encode_dict(document_data) - ), - current_document=common.Precondition(exists=False), - ) - async def _create_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - - # Create a minimal fake GAPIC with a dummy response. - firestore_api = AsyncMock() - firestore_api.commit.mock_add_spec(spec=["commit"]) - firestore_api.commit.return_value = self._make_commit_repsonse() - - # Attach the fake GAPIC to a real client. - client = _make_client("dignity") - client._firestore_api_internal = firestore_api - - # Actually make a document and call create(). 
- document = self._make_one("foo", "twelve", client=client) - document_data = {"hello": "goodbye", "count": 99} - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - write_result = await document.create(document_data, **kwargs) - - # Verify the response and the mocks. - self.assertIs(write_result, mock.sentinel.write_result) - write_pb = self._write_pb_for_create(document._document_path, document_data) - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": [write_pb], - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, +def _make_commit_repsonse(write_results=None): + from google.cloud.firestore_v1.types import firestore + + response = mock.create_autospec(firestore.CommitResponse) + response.write_results = write_results or [mock.sentinel.write_result] + response.commit_time = mock.sentinel.commit_time + return response + + +def _write_pb_for_create(document_path, document_data): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1 import _helpers + + return write.Write( + update=document.Document( + name=document_path, fields=_helpers.encode_dict(document_data) + ), + current_document=common.Precondition(exists=False), + ) + + +async def _create_helper(retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + + # Create a minimal fake GAPIC with a dummy response. + firestore_api = AsyncMock() + firestore_api.commit.mock_add_spec(spec=["commit"]) + firestore_api.commit.return_value = _make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("dignity") + client._firestore_api_internal = firestore_api + + # Actually make a document and call create(). 
+ document = _make_async_document_reference("foo", "twelve", client=client) + document_data = {"hello": "goodbye", "count": 99} + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + write_result = await document.create(document_data, **kwargs) + + # Verify the response and the mocks. + assert write_result is mock.sentinel.write_result + write_pb = _write_pb_for_create(document._document_path, document_data) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_create(): + await _create_helper() + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_create_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _create_helper(retry=retry, timeout=timeout) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_create_empty(): + # Create a minimal fake GAPIC with a dummy response. + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1.async_document import DocumentSnapshot + + firestore_api = AsyncMock(spec=["commit"]) + document_reference = mock.create_autospec(AsyncDocumentReference) + snapshot = mock.create_autospec(DocumentSnapshot) + snapshot.exists = True + document_reference.get.return_value = snapshot + firestore_api.commit.return_value = _make_commit_repsonse( + write_results=[document_reference] + ) + + # Attach the fake GAPIC to a real client. + client = _make_client("dignity") + client._firestore_api_internal = firestore_api + client.get_all = mock.MagicMock() + client.get_all.exists.return_value = True + + # Actually make a document and call create(). 
+ document = _make_async_document_reference("foo", "twelve", client=client) + document_data = {} + write_result = await document.create(document_data) + assert (await write_result.get()).exists + + +def _write_pb_for_set(document_path, document_data, merge): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1 import _helpers + + write_pbs = write.Write( + update=document.Document( + name=document_path, fields=_helpers.encode_dict(document_data) ) + ) + if merge: + field_paths = [ + field_path + for field_path, value in _helpers.extract_fields( + document_data, _helpers.FieldPath() + ) + ] + field_paths = [field_path.to_api_repr() for field_path in sorted(field_paths)] + mask = common.DocumentMask(field_paths=sorted(field_paths)) + write_pbs._pb.update_mask.CopyFrom(mask._pb) + return write_pbs + + +@pytest.mark.asyncio +async def _set_helper(merge=False, retry=None, timeout=None, **option_kwargs): + from google.cloud.firestore_v1 import _helpers + + # Create a minimal fake GAPIC with a dummy response. + firestore_api = AsyncMock(spec=["commit"]) + firestore_api.commit.return_value = _make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("db-dee-bee") + client._firestore_api_internal = firestore_api + + # Actually make a document and call create(). + document = _make_async_document_reference("User", "Interface", client=client) + document_data = {"And": 500, "Now": b"\xba\xaa\xaa \xba\xaa\xaa"} + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + write_result = await document.set(document_data, merge, **kwargs) + + # Verify the response and the mocks. 
+ assert write_result is mock.sentinel.write_result + write_pb = _write_pb_for_set(document._document_path, document_data, merge) + + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_set(): + await _set_helper() + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_set_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _set_helper(retry=retry, timeout=timeout) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_set_merge(): + await _set_helper(merge=True) + + +def _write_pb_for_update(document_path, update_values, field_paths): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1 import _helpers + + return write.Write( + update=document.Document( + name=document_path, fields=_helpers.encode_dict(update_values) + ), + update_mask=common.DocumentMask(field_paths=field_paths), + current_document=common.Precondition(exists=True), + ) + + +@pytest.mark.asyncio +async def _update_helper(retry=None, timeout=None, **option_kwargs): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.transforms import DELETE_FIELD + + # Create a minimal fake GAPIC with a dummy response. + firestore_api = AsyncMock(spec=["commit"]) + firestore_api.commit.return_value = _make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("potato-chip") + client._firestore_api_internal = firestore_api + + # Actually make a document and call create(). 
+ document = _make_async_document_reference("baked", "Alaska", client=client) + # "Cheat" and use OrderedDict-s so that iteritems() is deterministic. + field_updates = collections.OrderedDict( + (("hello", 1), ("then.do", False), ("goodbye", DELETE_FIELD)) + ) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + if option_kwargs: + option = client.write_option(**option_kwargs) + write_result = await document.update(field_updates, option=option, **kwargs) + else: + option = None + write_result = await document.update(field_updates, **kwargs) + + # Verify the response and the mocks. + assert write_result is mock.sentinel.write_result + update_values = { + "hello": field_updates["hello"], + "then": {"do": field_updates["then.do"]}, + } + field_paths = list(field_updates.keys()) + write_pb = _write_pb_for_update( + document._document_path, update_values, sorted(field_paths) + ) + if option is not None: + option.modify_write(write_pb) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_update_with_exists(): + with pytest.raises(ValueError): + await _update_helper(exists=True) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_update(): + await _update_helper() + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_update_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _update_helper(retry=retry, timeout=timeout) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_update_with_precondition(): + from google.protobuf import timestamp_pb2 + + timestamp = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) + await _update_helper(last_update_time=timestamp) + + +@pytest.mark.asyncio +async def 
test_asyncdocumentreference_empty_update(): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = AsyncMock(spec=["commit"]) + firestore_api.commit.return_value = _make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("potato-chip") + client._firestore_api_internal = firestore_api + + # Actually make a document and call create(). + document = _make_async_document_reference("baked", "Alaska", client=client) + # "Cheat" and use OrderedDict-s so that iteritems() is deterministic. + field_updates = {} + with pytest.raises(ValueError): + await document.update(field_updates) + + +@pytest.mark.asyncio +async def _delete_helper(retry=None, timeout=None, **option_kwargs): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import write + + # Create a minimal fake GAPIC with a dummy response. + firestore_api = AsyncMock(spec=["commit"]) + firestore_api.commit.return_value = _make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("donut-base") + client._firestore_api_internal = firestore_api + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Actually make a document and call delete(). + document = _make_async_document_reference("where", "we-are", client=client) + if option_kwargs: + option = client.write_option(**option_kwargs) + delete_time = await document.delete(option=option, **kwargs) + else: + option = None + delete_time = await document.delete(**kwargs) + + # Verify the response and the mocks. 
+ assert delete_time is mock.sentinel.commit_time + write_pb = write.Write(delete=document._document_path) + if option is not None: + option.modify_write(write_pb) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_delete(): + await _delete_helper() + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_delete_with_option(): + from google.protobuf import timestamp_pb2 + + timestamp_pb = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) + await _delete_helper(last_update_time=timestamp_pb) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_delete_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _delete_helper(retry=retry, timeout=timeout) + + +@pytest.mark.asyncio +async def _get_helper( + field_paths=None, + use_transaction=False, + not_found=False, + # This should be an impossible case, but we test against it for + # completeness + return_empty=False, + retry=None, + timeout=None, +): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.transaction import Transaction + + # Create a minimal fake GAPIC with a dummy response. 
+ create_time = 123 + update_time = 234 + read_time = 345 + firestore_api = AsyncMock(spec=["batch_get_documents"]) + response = mock.create_autospec(firestore.BatchGetDocumentsResponse) + response.read_time = 345 + response.found = mock.create_autospec(document.Document) + response.found.fields = {} + response.found.create_time = create_time + response.found.update_time = update_time + + client = _make_client("donut-base") + client._firestore_api_internal = firestore_api + document_reference = _make_async_document_reference( + "where", "we-are", client=client + ) + response.found.name = None if not_found else document_reference._document_path + response.missing = document_reference._document_path if not_found else None + + def WhichOneof(val): + return "missing" if not_found else "found" + + response._pb = response + response._pb.WhichOneof = WhichOneof + firestore_api.batch_get_documents.return_value = AsyncIter( + [response] if not return_empty else [] + ) + + if use_transaction: + transaction = Transaction(client) + transaction_id = transaction._id = b"asking-me-2" + else: + transaction = None + + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + snapshot = await document_reference.get( + field_paths=field_paths, transaction=transaction, **kwargs, + ) + + assert snapshot.reference is document_reference + if not_found or return_empty: + assert snapshot._data is None + assert not snapshot.exists + assert snapshot.read_time is not None + assert snapshot.create_time is None + assert snapshot.update_time is None + else: + assert snapshot.to_dict() == {} + assert snapshot.exists + assert snapshot.read_time is read_time + assert snapshot.create_time is create_time + assert snapshot.update_time is update_time + + # Verify the request made to the API + if field_paths is not None: + mask = common.DocumentMask(field_paths=sorted(field_paths)) + else: + mask = None + + if use_transaction: + expected_transaction_id = transaction_id + else: + 
expected_transaction_id = None + + firestore_api.batch_get_documents.assert_called_once_with( + request={ + "database": client._database_string, + "documents": [document_reference._document_path], + "mask": mask, + "transaction": expected_transaction_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_get_not_found(): + await _get_helper(not_found=True) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_get_default(): + await _get_helper() + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_get_return_empty(): + await _get_helper(return_empty=True) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_get_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _get_helper(retry=retry, timeout=timeout) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_get_w_string_field_path(): + with pytest.raises(ValueError): + await _get_helper(field_paths="foo") + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_get_with_field_path(): + await _get_helper(field_paths=["foo"]) - @pytest.mark.asyncio - async def test_create(self): - await self._create_helper() - - @pytest.mark.asyncio - async def test_create_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - await self._create_helper(retry=retry, timeout=timeout) - - @pytest.mark.asyncio - async def test_create_empty(self): - # Create a minimal fake GAPIC with a dummy response. 
- from google.cloud.firestore_v1.async_document import AsyncDocumentReference - from google.cloud.firestore_v1.async_document import DocumentSnapshot - - firestore_api = AsyncMock(spec=["commit"]) - document_reference = mock.create_autospec(AsyncDocumentReference) - snapshot = mock.create_autospec(DocumentSnapshot) - snapshot.exists = True - document_reference.get.return_value = snapshot - firestore_api.commit.return_value = self._make_commit_repsonse( - write_results=[document_reference] - ) - # Attach the fake GAPIC to a real client. - client = _make_client("dignity") - client._firestore_api_internal = firestore_api - client.get_all = mock.MagicMock() - client.get_all.exists.return_value = True - - # Actually make a document and call create(). - document = self._make_one("foo", "twelve", client=client) - document_data = {} - write_result = await document.create(document_data) - self.assertTrue((await write_result.get()).exists) - - @staticmethod - def _write_pb_for_set(document_path, document_data, merge): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1 import _helpers - - write_pbs = write.Write( - update=document.Document( - name=document_path, fields=_helpers.encode_dict(document_data) - ) - ) - if merge: - field_paths = [ - field_path - for field_path, value in _helpers.extract_fields( - document_data, _helpers.FieldPath() - ) - ] - field_paths = [ - field_path.to_api_repr() for field_path in sorted(field_paths) - ] - mask = common.DocumentMask(field_paths=sorted(field_paths)) - write_pbs._pb.update_mask.CopyFrom(mask._pb) - return write_pbs - - @pytest.mark.asyncio - async def _set_helper(self, merge=False, retry=None, timeout=None, **option_kwargs): - from google.cloud.firestore_v1 import _helpers - - # Create a minimal fake GAPIC with a dummy response. 
- firestore_api = AsyncMock(spec=["commit"]) - firestore_api.commit.return_value = self._make_commit_repsonse() - - # Attach the fake GAPIC to a real client. - client = _make_client("db-dee-bee") - client._firestore_api_internal = firestore_api - - # Actually make a document and call create(). - document = self._make_one("User", "Interface", client=client) - document_data = {"And": 500, "Now": b"\xba\xaa\xaa \xba\xaa\xaa"} - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - write_result = await document.set(document_data, merge, **kwargs) - - # Verify the response and the mocks. - self.assertIs(write_result, mock.sentinel.write_result) - write_pb = self._write_pb_for_set(document._document_path, document_data, merge) - - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": [write_pb], - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) +@pytest.mark.asyncio +async def test_asyncdocumentreference_get_with_multiple_field_paths(): + await _get_helper(field_paths=["foo", "bar.baz"]) - @pytest.mark.asyncio - async def test_set(self): - await self._set_helper() - - @pytest.mark.asyncio - async def test_set_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - await self._set_helper(retry=retry, timeout=timeout) - - @pytest.mark.asyncio - async def test_set_merge(self): - await self._set_helper(merge=True) - - @staticmethod - def _write_pb_for_update(document_path, update_values, field_paths): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1 import _helpers - - return write.Write( - update=document.Document( - name=document_path, fields=_helpers.encode_dict(update_values) - ), - update_mask=common.DocumentMask(field_paths=field_paths), - 
current_document=common.Precondition(exists=True), - ) - @pytest.mark.asyncio - async def _update_helper(self, retry=None, timeout=None, **option_kwargs): - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.transforms import DELETE_FIELD +@pytest.mark.asyncio +async def test_asyncdocumentreference_get_with_transaction(): + await _get_helper(use_transaction=True) - # Create a minimal fake GAPIC with a dummy response. - firestore_api = AsyncMock(spec=["commit"]) - firestore_api.commit.return_value = self._make_commit_repsonse() - # Attach the fake GAPIC to a real client. - client = _make_client("potato-chip") - client._firestore_api_internal = firestore_api +@pytest.mark.asyncio +async def _collections_helper(page_size=None, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.async_collection import AsyncCollectionReference - # Actually make a document and call create(). - document = self._make_one("baked", "Alaska", client=client) - # "Cheat" and use OrderedDict-s so that iteritems() is deterministic. - field_updates = collections.OrderedDict( - (("hello", 1), ("then.do", False), ("goodbye", DELETE_FIELD)) - ) - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - if option_kwargs: - option = client.write_option(**option_kwargs) - write_result = await document.update(field_updates, option=option, **kwargs) - else: - option = None - write_result = await document.update(field_updates, **kwargs) - - # Verify the response and the mocks. 
- self.assertIs(write_result, mock.sentinel.write_result) - update_values = { - "hello": field_updates["hello"], - "then": {"do": field_updates["then.do"]}, - } - field_paths = list(field_updates.keys()) - write_pb = self._write_pb_for_update( - document._document_path, update_values, sorted(field_paths) - ) - if option is not None: - option.modify_write(write_pb) - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": [write_pb], - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) + collection_ids = ["coll-1", "coll-2"] - @pytest.mark.asyncio - async def test_update_with_exists(self): - with self.assertRaises(ValueError): - await self._update_helper(exists=True) - - @pytest.mark.asyncio - async def test_update(self): - await self._update_helper() - - @pytest.mark.asyncio - async def test_update_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - await self._update_helper(retry=retry, timeout=timeout) - - @pytest.mark.asyncio - async def test_update_with_precondition(self): - from google.protobuf import timestamp_pb2 - - timestamp = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) - await self._update_helper(last_update_time=timestamp) - - @pytest.mark.asyncio - async def test_empty_update(self): - # Create a minimal fake GAPIC with a dummy response. - firestore_api = AsyncMock(spec=["commit"]) - firestore_api.commit.return_value = self._make_commit_repsonse() - - # Attach the fake GAPIC to a real client. - client = _make_client("potato-chip") - client._firestore_api_internal = firestore_api - - # Actually make a document and call create(). - document = self._make_one("baked", "Alaska", client=client) - # "Cheat" and use OrderedDict-s so that iteritems() is deterministic. 
- field_updates = {} - with self.assertRaises(ValueError): - await document.update(field_updates) - - @pytest.mark.asyncio - async def _delete_helper(self, retry=None, timeout=None, **option_kwargs): - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import write - - # Create a minimal fake GAPIC with a dummy response. - firestore_api = AsyncMock(spec=["commit"]) - firestore_api.commit.return_value = self._make_commit_repsonse() - - # Attach the fake GAPIC to a real client. - client = _make_client("donut-base") - client._firestore_api_internal = firestore_api - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - # Actually make a document and call delete(). - document = self._make_one("where", "we-are", client=client) - if option_kwargs: - option = client.write_option(**option_kwargs) - delete_time = await document.delete(option=option, **kwargs) - else: - option = None - delete_time = await document.delete(**kwargs) - - # Verify the response and the mocks. 
- self.assertIs(delete_time, mock.sentinel.commit_time) - write_pb = write.Write(delete=document._document_path) - if option is not None: - option.modify_write(write_pb) - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": [write_pb], - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) + class Pager(object): + async def __aiter__(self, **_): + for collection_id in collection_ids: + yield collection_id - @pytest.mark.asyncio - async def test_delete(self): - await self._delete_helper() - - @pytest.mark.asyncio - async def test_delete_with_option(self): - from google.protobuf import timestamp_pb2 - - timestamp_pb = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) - await self._delete_helper(last_update_time=timestamp_pb) - - @pytest.mark.asyncio - async def test_delete_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - await self._delete_helper(retry=retry, timeout=timeout) - - @pytest.mark.asyncio - async def _get_helper( - self, - field_paths=None, - use_transaction=False, - not_found=False, - # This should be an impossible case, but we test against it for - # completeness - return_empty=False, - retry=None, - timeout=None, - ): - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.transaction import Transaction - - # Create a minimal fake GAPIC with a dummy response. 
- create_time = 123 - update_time = 234 - read_time = 345 - firestore_api = AsyncMock(spec=["batch_get_documents"]) - response = mock.create_autospec(firestore.BatchGetDocumentsResponse) - response.read_time = 345 - response.found = mock.create_autospec(document.Document) - response.found.fields = {} - response.found.create_time = create_time - response.found.update_time = update_time - - client = _make_client("donut-base") - client._firestore_api_internal = firestore_api - document_reference = self._make_one("where", "we-are", client=client) - response.found.name = None if not_found else document_reference._document_path - response.missing = document_reference._document_path if not_found else None - - def WhichOneof(val): - return "missing" if not_found else "found" - - response._pb = response - response._pb.WhichOneof = WhichOneof - firestore_api.batch_get_documents.return_value = AsyncIter( - [response] if not return_empty else [] - ) + firestore_api = AsyncMock() + firestore_api.mock_add_spec(spec=["list_collection_ids"]) + firestore_api.list_collection_ids.return_value = Pager() - if use_transaction: - transaction = Transaction(client) - transaction_id = transaction._id = b"asking-me-2" - else: - transaction = None + client = _make_client() + client._firestore_api_internal = firestore_api + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + # Actually make a document and call delete(). + document = _make_async_document_reference("where", "we-are", client=client) + if page_size is not None: + collections = [ + c async for c in document.collections(page_size=page_size, **kwargs) + ] + else: + collections = [c async for c in document.collections(**kwargs)] - snapshot = await document_reference.get( - field_paths=field_paths, transaction=transaction, **kwargs, - ) + # Verify the response and the mocks. 
+ assert len(collections) == len(collection_ids) + for collection, collection_id in zip(collections, collection_ids): + assert isinstance(collection, AsyncCollectionReference) + assert collection.parent == document + assert collection.id == collection_id - self.assertIs(snapshot.reference, document_reference) - if not_found or return_empty: - self.assertIsNone(snapshot._data) - self.assertFalse(snapshot.exists) - self.assertIsNotNone(snapshot.read_time) - self.assertIsNone(snapshot.create_time) - self.assertIsNone(snapshot.update_time) - else: - self.assertEqual(snapshot.to_dict(), {}) - self.assertTrue(snapshot.exists) - self.assertIs(snapshot.read_time, read_time) - self.assertIs(snapshot.create_time, create_time) - self.assertIs(snapshot.update_time, update_time) - - # Verify the request made to the API - if field_paths is not None: - mask = common.DocumentMask(field_paths=sorted(field_paths)) - else: - mask = None - - if use_transaction: - expected_transaction_id = transaction_id - else: - expected_transaction_id = None - - firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": [document_reference._document_path], - "mask": mask, - "transaction": expected_transaction_id, - }, - metadata=client._rpc_metadata, - **kwargs, - ) + firestore_api.list_collection_ids.assert_called_once_with( + request={"parent": document._document_path, "page_size": page_size}, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_asyncdocumentreference_collections(): + await _collections_helper() - @pytest.mark.asyncio - async def test_get_not_found(self): - await self._get_helper(not_found=True) - - @pytest.mark.asyncio - async def test_get_default(self): - await self._get_helper() - - @pytest.mark.asyncio - async def test_get_return_empty(self): - await self._get_helper(return_empty=True) - - @pytest.mark.asyncio - async def test_get_w_retry_timeout(self): - from 
google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - await self._get_helper(retry=retry, timeout=timeout) - - @pytest.mark.asyncio - async def test_get_w_string_field_path(self): - with self.assertRaises(ValueError): - await self._get_helper(field_paths="foo") - - @pytest.mark.asyncio - async def test_get_with_field_path(self): - await self._get_helper(field_paths=["foo"]) - - @pytest.mark.asyncio - async def test_get_with_multiple_field_paths(self): - await self._get_helper(field_paths=["foo", "bar.baz"]) - - @pytest.mark.asyncio - async def test_get_with_transaction(self): - await self._get_helper(use_transaction=True) - - @pytest.mark.asyncio - async def _collections_helper(self, page_size=None, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.async_collection import AsyncCollectionReference - - collection_ids = ["coll-1", "coll-2"] - - class Pager(object): - async def __aiter__(self, **_): - for collection_id in collection_ids: - yield collection_id - - firestore_api = AsyncMock() - firestore_api.mock_add_spec(spec=["list_collection_ids"]) - firestore_api.list_collection_ids.return_value = Pager() - - client = _make_client() - client._firestore_api_internal = firestore_api - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - # Actually make a document and call delete(). - document = self._make_one("where", "we-are", client=client) - if page_size is not None: - collections = [ - c async for c in document.collections(page_size=page_size, **kwargs) - ] - else: - collections = [c async for c in document.collections(**kwargs)] - - # Verify the response and the mocks. 
- self.assertEqual(len(collections), len(collection_ids)) - for collection, collection_id in zip(collections, collection_ids): - self.assertIsInstance(collection, AsyncCollectionReference) - self.assertEqual(collection.parent, document) - self.assertEqual(collection.id, collection_id) - - firestore_api.list_collection_ids.assert_called_once_with( - request={"parent": document._document_path, "page_size": page_size}, - metadata=client._rpc_metadata, - **kwargs, - ) - @pytest.mark.asyncio - async def test_collections(self): - await self._collections_helper() +@pytest.mark.asyncio +async def test_asyncdocumentreference_collections_w_retry_timeout(): + from google.api_core.retry import Retry - @pytest.mark.asyncio - async def test_collections_w_retry_timeout(self): - from google.api_core.retry import Retry + retry = Retry(predicate=object()) + timeout = 123.0 + await _collections_helper(retry=retry, timeout=timeout) - retry = Retry(predicate=object()) - timeout = 123.0 - await self._collections_helper(retry=retry, timeout=timeout) - @pytest.mark.asyncio - async def test_collections_w_page_size(self): - await self._collections_helper(page_size=10) +@pytest.mark.asyncio +async def test_asyncdocumentreference_collections_w_page_size(): + await _collections_helper(page_size=10) def _make_credentials(): diff --git a/tests/unit/v1/test_async_query.py b/tests/unit/v1/test_async_query.py index 392d7e7a7982f..c7f01608da61d 100644 --- a/tests/unit/v1/test_async_query.py +++ b/tests/unit/v1/test_async_query.py @@ -12,666 +12,693 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from google.cloud.firestore_v1.types.document import Document -from google.cloud.firestore_v1.types.firestore import RunQueryResponse -import pytest import types -import aiounittest import mock -from tests.unit.v1.test__helpers import AsyncIter, AsyncMock -from tests.unit.v1.test_base_query import ( - _make_credentials, - _make_query_response, - _make_cursor_pb, -) - - -class TestAsyncQuery(aiounittest.AsyncTestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.async_query import AsyncQuery - - return AsyncQuery - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def test_constructor(self): - query = self._make_one(mock.sentinel.parent) - self.assertIs(query._parent, mock.sentinel.parent) - self.assertIsNone(query._projection) - self.assertEqual(query._field_filters, ()) - self.assertEqual(query._orders, ()) - self.assertIsNone(query._limit) - self.assertIsNone(query._offset) - self.assertIsNone(query._start_at) - self.assertIsNone(query._end_at) - self.assertFalse(query._all_descendants) - - async def _get_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - - # Create a minimal fake GAPIC. - firestore_api = AsyncMock(spec=["run_query"]) - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("dee") - - # Add a dummy response to the minimal fake GAPIC. - _, expected_prefix = parent._parent_info() - name = "{}/sleep".format(expected_prefix) - data = {"snooze": 10} - - response_pb = _make_query_response(name=name, data=data) - firestore_api.run_query.return_value = AsyncIter([response_pb]) - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - # Execute the query and check the response. 
- query = self._make_one(parent) - returned = await query.get(**kwargs) - - self.assertIsInstance(returned, list) - self.assertEqual(len(returned), 1) - - snapshot = returned[0] - self.assertEqual(snapshot.reference._path, ("dee", "sleep")) - self.assertEqual(snapshot.to_dict(), data) - - # Verify the mock call. - parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) +import pytest - @pytest.mark.asyncio - async def test_get(self): - await self._get_helper() +from tests.unit.v1.test__helpers import AsyncIter +from tests.unit.v1.test__helpers import AsyncMock +from tests.unit.v1.test_base_query import _make_credentials +from tests.unit.v1.test_base_query import _make_query_response +from tests.unit.v1.test_base_query import _make_cursor_pb - @pytest.mark.asyncio - async def test_get_w_retry_timeout(self): - from google.api_core.retry import Retry - retry = Retry(predicate=object()) - timeout = 123.0 - await self._get_helper(retry=retry, timeout=timeout) +def _make_async_query(*args, **kwargs): + from google.cloud.firestore_v1.async_query import AsyncQuery - @pytest.mark.asyncio - async def test_get_limit_to_last(self): - from google.cloud import firestore - from google.cloud.firestore_v1.base_query import _enum_from_direction + return AsyncQuery(*args, **kwargs) - # Create a minimal fake GAPIC. - firestore_api = AsyncMock(spec=["run_query"]) - # Attach the fake GAPIC to a real client. 
- client = _make_client() - client._firestore_api_internal = firestore_api +def test_asyncquery_constructor(): + query = _make_async_query(mock.sentinel.parent) + assert query._parent is mock.sentinel.parent + assert query._projection is None + assert query._field_filters == () + assert query._orders == () + assert query._limit is None + assert query._offset is None + assert query._start_at is None + assert query._end_at is None + assert not query._all_descendants - # Make a **real** collection reference as parent. - parent = client.collection("dee") - # Add a dummy response to the minimal fake GAPIC. - _, expected_prefix = parent._parent_info() - name = "{}/sleep".format(expected_prefix) - data = {"snooze": 10} - data2 = {"snooze": 20} +async def _get_helper(retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers - response_pb = _make_query_response(name=name, data=data) - response_pb2 = _make_query_response(name=name, data=data2) + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) - firestore_api.run_query.return_value = AsyncIter([response_pb2, response_pb]) + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api - # Execute the query and check the response. - query = self._make_one(parent) - query = query.order_by( - "snooze", direction=firestore.AsyncQuery.DESCENDING - ).limit_to_last(2) - returned = await query.get() + # Make a **real** collection reference as parent. 
+ parent = client.collection("dee") - self.assertIsInstance(returned, list) - self.assertEqual( - query._orders[0].direction, - _enum_from_direction(firestore.AsyncQuery.ASCENDING), - ) - self.assertEqual(len(returned), 2) - - snapshot = returned[0] - self.assertEqual(snapshot.reference._path, ("dee", "sleep")) - self.assertEqual(snapshot.to_dict(), data) - - snapshot2 = returned[1] - self.assertEqual(snapshot2.reference._path, ("dee", "sleep")) - self.assertEqual(snapshot2.to_dict(), data2) - - # Verify the mock call. - parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - ) + # Add a dummy response to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/sleep".format(expected_prefix) + data = {"snooze": 10} + + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = AsyncIter([response_pb]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - @pytest.mark.asyncio - async def test_chunkify_w_empty(self): - client = _make_client() - firestore_api = AsyncMock(spec=["run_query"]) - firestore_api.run_query.return_value = AsyncIter([]) - client._firestore_api_internal = firestore_api - query = client.collection("asdf")._query() - - chunks = [] - async for chunk in query._chunkify(10): - chunks.append(chunk) - - assert chunks == [[]] - - @pytest.mark.asyncio - async def test_chunkify_w_chunksize_lt_limit(self): - client = _make_client() - firestore_api = AsyncMock(spec=["run_query"]) - doc_ids = [ - f"projects/project-project/databases/(default)/documents/asdf/{index}" - for index in range(5) - ] - responses1 = [ - RunQueryResponse(document=Document(name=doc_id),) for doc_id in doc_ids[:2] - ] - responses2 = [ - RunQueryResponse(document=Document(name=doc_id),) for doc_id in doc_ids[2:4] - ] - responses3 = [ - 
RunQueryResponse(document=Document(name=doc_id),) for doc_id in doc_ids[4:] - ] - firestore_api.run_query.side_effect = [ - AsyncIter(responses1), - AsyncIter(responses2), - AsyncIter(responses3), - ] - client._firestore_api_internal = firestore_api - query = client.collection("asdf")._query() - - chunks = [] - async for chunk in query._chunkify(2): - chunks.append(chunk) - - self.assertEqual(len(chunks), 3) - expected_ids = [str(index) for index in range(5)] - self.assertEqual([snapshot.id for snapshot in chunks[0]], expected_ids[:2]) - self.assertEqual([snapshot.id for snapshot in chunks[1]], expected_ids[2:4]) - self.assertEqual([snapshot.id for snapshot in chunks[2]], expected_ids[4:]) - - @pytest.mark.asyncio - async def test_chunkify_w_chunksize_gt_limit(self): - client = _make_client() - - firestore_api = AsyncMock(spec=["run_query"]) - responses = [ - RunQueryResponse( - document=Document( - name=f"projects/project-project/databases/(default)/documents/asdf/{index}", + # Execute the query and check the response. + query = _make_async_query(parent) + returned = await query.get(**kwargs) + + assert isinstance(returned, list) + assert len(returned) == 1 + + snapshot = returned[0] + assert snapshot.reference._path == ("dee", "sleep") + assert snapshot.to_dict() == data + + # Verify the mock call. 
+ parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_asyncquery_get(): + await _get_helper() + + +@pytest.mark.asyncio +async def test_asyncquery_get_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _get_helper(retry=retry, timeout=timeout) + + +@pytest.mark.asyncio +async def test_asyncquery_get_limit_to_last(): + from google.cloud import firestore + from google.cloud.firestore_v1.base_query import _enum_from_direction + + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + + # Add a dummy response to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/sleep".format(expected_prefix) + data = {"snooze": 10} + data2 = {"snooze": 20} + + response_pb = _make_query_response(name=name, data=data) + response_pb2 = _make_query_response(name=name, data=data2) + + firestore_api.run_query.return_value = AsyncIter([response_pb2, response_pb]) + + # Execute the query and check the response. 
+ query = _make_async_query(parent) + query = query.order_by( + "snooze", direction=firestore.AsyncQuery.DESCENDING + ).limit_to_last(2) + returned = await query.get() + + assert isinstance(returned, list) + assert query._orders[0].direction == _enum_from_direction( + firestore.AsyncQuery.ASCENDING + ) + assert len(returned) == 2 + + snapshot = returned[0] + assert snapshot.reference._path == ("dee", "sleep") + assert snapshot.to_dict() == data + + snapshot2 = returned[1] + assert snapshot2.reference._path == ("dee", "sleep") + assert snapshot2.to_dict() == data2 + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asyncquery_chunkify_w_empty(): + client = _make_client() + firestore_api = AsyncMock(spec=["run_query"]) + firestore_api.run_query.return_value = AsyncIter([]) + client._firestore_api_internal = firestore_api + query = client.collection("asdf")._query() + + chunks = [] + async for chunk in query._chunkify(10): + chunks.append(chunk) + + assert chunks == [[]] + + +@pytest.mark.asyncio +async def test_asyncquery_chunkify_w_chunksize_lt_limit(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import firestore + + client = _make_client() + firestore_api = AsyncMock(spec=["run_query"]) + doc_ids = [ + f"projects/project-project/databases/(default)/documents/asdf/{index}" + for index in range(5) + ] + responses1 = [ + firestore.RunQueryResponse(document=document.Document(name=doc_id),) + for doc_id in doc_ids[:2] + ] + responses2 = [ + firestore.RunQueryResponse(document=document.Document(name=doc_id),) + for doc_id in doc_ids[2:4] + ] + responses3 = [ + firestore.RunQueryResponse(document=document.Document(name=doc_id),) + for doc_id in doc_ids[4:] + ] + 
firestore_api.run_query.side_effect = [ + AsyncIter(responses1), + AsyncIter(responses2), + AsyncIter(responses3), + ] + client._firestore_api_internal = firestore_api + query = client.collection("asdf")._query() + + chunks = [] + async for chunk in query._chunkify(2): + chunks.append(chunk) + + assert len(chunks) == 3 + expected_ids = [str(index) for index in range(5)] + assert [snapshot.id for snapshot in chunks[0]] == expected_ids[:2] + assert [snapshot.id for snapshot in chunks[1]] == expected_ids[2:4] + assert [snapshot.id for snapshot in chunks[2]] == expected_ids[4:] + + +@pytest.mark.asyncio +async def test_asyncquery_chunkify_w_chunksize_gt_limit(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import firestore + + client = _make_client() + + firestore_api = AsyncMock(spec=["run_query"]) + responses = [ + firestore.RunQueryResponse( + document=document.Document( + name=( + f"projects/project-project/databases/(default)/" + f"documents/asdf/{index}" ), - ) - for index in range(5) - ] - firestore_api.run_query.return_value = AsyncIter(responses) - client._firestore_api_internal = firestore_api - - query = client.collection("asdf")._query() - - chunks = [] - async for chunk in query.limit(5)._chunkify(10): - chunks.append(chunk) - - self.assertEqual(len(chunks), 1) - expected_ids = [str(index) for index in range(5)] - self.assertEqual([snapshot.id for snapshot in chunks[0]], expected_ids) - - async def _stream_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - - # Create a minimal fake GAPIC. - firestore_api = AsyncMock(spec=["run_query"]) - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("dee") - - # Add a dummy response to the minimal fake GAPIC. 
- _, expected_prefix = parent._parent_info() - name = "{}/sleep".format(expected_prefix) - data = {"snooze": 10} - response_pb = _make_query_response(name=name, data=data) - firestore_api.run_query.return_value = AsyncIter([response_pb]) - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - # Execute the query and check the response. - query = self._make_one(parent) - - get_response = query.stream(**kwargs) - - self.assertIsInstance(get_response, types.AsyncGeneratorType) - returned = [x async for x in get_response] - self.assertEqual(len(returned), 1) - snapshot = returned[0] - self.assertEqual(snapshot.reference._path, ("dee", "sleep")) - self.assertEqual(snapshot.to_dict(), data) - - # Verify the mock call. - parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, + ), ) + for index in range(5) + ] + firestore_api.run_query.return_value = AsyncIter(responses) + client._firestore_api_internal = firestore_api - @pytest.mark.asyncio - async def test_stream_simple(self): - await self._stream_helper() - - @pytest.mark.asyncio - async def test_stream_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - await self._stream_helper(retry=retry, timeout=timeout) - - @pytest.mark.asyncio - async def test_stream_with_limit_to_last(self): - # Attach the fake GAPIC to a real client. - client = _make_client() - # Make a **real** collection reference as parent. - parent = client.collection("dee") - # Execute the query and check the response. 
- query = self._make_one(parent) - query = query.limit_to_last(2) - - stream_response = query.stream() - - with self.assertRaises(ValueError): - [d async for d in stream_response] - - @pytest.mark.asyncio - async def test_stream_with_transaction(self): - # Create a minimal fake GAPIC. - firestore_api = AsyncMock(spec=["run_query"]) - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Create a real-ish transaction for this client. - transaction = client.transaction() - txn_id = b"\x00\x00\x01-work-\xf2" - transaction._id = txn_id - - # Make a **real** collection reference as parent. - parent = client.collection("declaration") - - # Add a dummy response to the minimal fake GAPIC. - parent_path, expected_prefix = parent._parent_info() - name = "{}/burger".format(expected_prefix) - data = {"lettuce": b"\xee\x87"} - response_pb = _make_query_response(name=name, data=data) - firestore_api.run_query.return_value = AsyncIter([response_pb]) - - # Execute the query and check the response. - query = self._make_one(parent) - get_response = query.stream(transaction=transaction) - self.assertIsInstance(get_response, types.AsyncGeneratorType) - returned = [x async for x in get_response] - self.assertEqual(len(returned), 1) - snapshot = returned[0] - self.assertEqual(snapshot.reference._path, ("declaration", "burger")) - self.assertEqual(snapshot.to_dict(), data) - - # Verify the mock call. - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) + query = client.collection("asdf")._query() - @pytest.mark.asyncio - async def test_stream_no_results(self): - # Create a minimal fake GAPIC with a dummy response. 
- firestore_api = AsyncMock(spec=["run_query"]) - empty_response = _make_query_response() - run_query_response = AsyncIter([empty_response]) - firestore_api.run_query.return_value = run_query_response - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("dah", "dah", "dum") - query = self._make_one(parent) - - get_response = query.stream() - self.assertIsInstance(get_response, types.AsyncGeneratorType) - self.assertEqual([x async for x in get_response], []) - - # Verify the mock call. - parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - ) + chunks = [] + async for chunk in query.limit(5)._chunkify(10): + chunks.append(chunk) - @pytest.mark.asyncio - async def test_stream_second_response_in_empty_stream(self): - # Create a minimal fake GAPIC with a dummy response. - firestore_api = AsyncMock(spec=["run_query"]) - empty_response1 = _make_query_response() - empty_response2 = _make_query_response() - run_query_response = AsyncIter([empty_response1, empty_response2]) - firestore_api.run_query.return_value = run_query_response - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("dah", "dah", "dum") - query = self._make_one(parent) - - get_response = query.stream() - self.assertIsInstance(get_response, types.AsyncGeneratorType) - self.assertEqual([x async for x in get_response], []) - - # Verify the mock call. 
- parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - ) + assert len(chunks) == 1 + expected_ids = [str(index) for index in range(5)] + assert [snapshot.id for snapshot in chunks[0]] == expected_ids - @pytest.mark.asyncio - async def test_stream_with_skipped_results(self): - # Create a minimal fake GAPIC. - firestore_api = AsyncMock(spec=["run_query"]) - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("talk", "and", "chew-gum") - - # Add two dummy responses to the minimal fake GAPIC. - _, expected_prefix = parent._parent_info() - response_pb1 = _make_query_response(skipped_results=1) - name = "{}/clock".format(expected_prefix) - data = {"noon": 12, "nested": {"bird": 10.5}} - response_pb2 = _make_query_response(name=name, data=data) - firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) - - # Execute the query and check the response. - query = self._make_one(parent) - get_response = query.stream() - self.assertIsInstance(get_response, types.AsyncGeneratorType) - returned = [x async for x in get_response] - self.assertEqual(len(returned), 1) - snapshot = returned[0] - self.assertEqual(snapshot.reference._path, ("talk", "and", "chew-gum", "clock")) - self.assertEqual(snapshot.to_dict(), data) - - # Verify the mock call. - parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - ) - @pytest.mark.asyncio - async def test_stream_empty_after_first_response(self): - # Create a minimal fake GAPIC. 
- firestore_api = AsyncMock(spec=["run_query"]) - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("charles") - - # Add two dummy responses to the minimal fake GAPIC. - _, expected_prefix = parent._parent_info() - name = "{}/bark".format(expected_prefix) - data = {"lee": "hoop"} - response_pb1 = _make_query_response(name=name, data=data) - response_pb2 = _make_query_response() - firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) - - # Execute the query and check the response. - query = self._make_one(parent) - get_response = query.stream() - self.assertIsInstance(get_response, types.AsyncGeneratorType) - returned = [x async for x in get_response] - self.assertEqual(len(returned), 1) - snapshot = returned[0] - self.assertEqual(snapshot.reference._path, ("charles", "bark")) - self.assertEqual(snapshot.to_dict(), data) - - # Verify the mock call. - parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - ) +async def _stream_helper(retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers - @pytest.mark.asyncio - async def test_stream_w_collection_group(self): - # Create a minimal fake GAPIC. - firestore_api = AsyncMock(spec=["run_query"]) - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("charles") - other = client.collection("dora") - - # Add two dummy responses to the minimal fake GAPIC. 
- _, other_prefix = other._parent_info() - name = "{}/bark".format(other_prefix) - data = {"lee": "hoop"} - response_pb1 = _make_query_response(name=name, data=data) - response_pb2 = _make_query_response() - firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) - - # Execute the query and check the response. - query = self._make_one(parent) - query._all_descendants = True - get_response = query.stream() - self.assertIsInstance(get_response, types.AsyncGeneratorType) - returned = [x async for x in get_response] - self.assertEqual(len(returned), 1) - snapshot = returned[0] - to_match = other.document("bark") - self.assertEqual(snapshot.reference._document_path, to_match._document_path) - self.assertEqual(snapshot.to_dict(), data) - - # Verify the mock call. - parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - ) + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + + # Add a dummy response to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/sleep".format(expected_prefix) + data = {"snooze": 10} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = AsyncIter([response_pb]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Execute the query and check the response. 
+ query = _make_async_query(parent) + + get_response = query.stream(**kwargs) + + assert isinstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + assert len(returned) == 1 + snapshot = returned[0] + assert snapshot.reference._path == ("dee", "sleep") + assert snapshot.to_dict() == data + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_asyncquery_stream_simple(): + await _stream_helper() + + +@pytest.mark.asyncio +async def test_asyncquery_stream_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _stream_helper(retry=retry, timeout=timeout) + + +@pytest.mark.asyncio +async def test_asyncquery_stream_with_limit_to_last(): + # Attach the fake GAPIC to a real client. + client = _make_client() + # Make a **real** collection reference as parent. + parent = client.collection("dee") + # Execute the query and check the response. + query = _make_async_query(parent) + query = query.limit_to_last(2) + + stream_response = query.stream() + + with pytest.raises(ValueError): + [d async for d in stream_response] + + +@pytest.mark.asyncio +async def test_asyncquery_stream_with_transaction(): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Create a real-ish transaction for this client. + transaction = client.transaction() + txn_id = b"\x00\x00\x01-work-\xf2" + transaction._id = txn_id + + # Make a **real** collection reference as parent. + parent = client.collection("declaration") + + # Add a dummy response to the minimal fake GAPIC. 
+ parent_path, expected_prefix = parent._parent_info() + name = "{}/burger".format(expected_prefix) + data = {"lettuce": b"\xee\x87"} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = AsyncIter([response_pb]) + + # Execute the query and check the response. + query = _make_async_query(parent) + get_response = query.stream(transaction=transaction) + assert isinstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + assert len(returned) == 1 + snapshot = returned[0] + assert snapshot.reference._path == ("declaration", "burger") + assert snapshot.to_dict() == data + + # Verify the mock call. + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asyncquery_stream_no_results(): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = AsyncMock(spec=["run_query"]) + empty_response = _make_query_response() + run_query_response = AsyncIter([empty_response]) + firestore_api.run_query.return_value = run_query_response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dah", "dah", "dum") + query = _make_async_query(parent) + + get_response = query.stream() + assert isinstance(get_response, types.AsyncGeneratorType) + assert [x async for x in get_response] == [] + + # Verify the mock call. 
+ parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asyncquery_stream_second_response_in_empty_stream(): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = AsyncMock(spec=["run_query"]) + empty_response1 = _make_query_response() + empty_response2 = _make_query_response() + run_query_response = AsyncIter([empty_response1, empty_response2]) + firestore_api.run_query.return_value = run_query_response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dah", "dah", "dum") + query = _make_async_query(parent) + + get_response = query.stream() + assert isinstance(get_response, types.AsyncGeneratorType) + assert [x async for x in get_response] == [] + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asyncquery_stream_with_skipped_results(): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("talk", "and", "chew-gum") + + # Add two dummy responses to the minimal fake GAPIC. 
+ _, expected_prefix = parent._parent_info() + response_pb1 = _make_query_response(skipped_results=1) + name = "{}/clock".format(expected_prefix) + data = {"noon": 12, "nested": {"bird": 10.5}} + response_pb2 = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) + + # Execute the query and check the response. + query = _make_async_query(parent) + get_response = query.stream() + assert isinstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + assert len(returned) == 1 + snapshot = returned[0] + assert snapshot.reference._path == ("talk", "and", "chew-gum", "clock") + assert snapshot.to_dict() == data + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asyncquery_stream_empty_after_first_response(): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("charles") + + # Add two dummy responses to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/bark".format(expected_prefix) + data = {"lee": "hoop"} + response_pb1 = _make_query_response(name=name, data=data) + response_pb2 = _make_query_response() + firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) + + # Execute the query and check the response. 
+ query = _make_async_query(parent) + get_response = query.stream() + assert isinstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + assert len(returned) == 1 + snapshot = returned[0] + assert snapshot.reference._path == ("charles", "bark") + assert snapshot.to_dict() == data + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asyncquery_stream_w_collection_group(): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("charles") + other = client.collection("dora") + + # Add two dummy responses to the minimal fake GAPIC. + _, other_prefix = other._parent_info() + name = "{}/bark".format(other_prefix) + data = {"lee": "hoop"} + response_pb1 = _make_query_response(name=name, data=data) + response_pb2 = _make_query_response() + firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) + + # Execute the query and check the response. + query = _make_async_query(parent) + query._all_descendants = True + get_response = query.stream() + assert isinstance(get_response, types.AsyncGeneratorType) + returned = [x async for x in get_response] + assert len(returned) == 1 + snapshot = returned[0] + to_match = other.document("bark") + assert snapshot.reference._document_path == to_match._document_path + assert snapshot.to_dict() == data + + # Verify the mock call. 
+ parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +def _make_async_collection_group(*args, **kwargs): + from google.cloud.firestore_v1.async_query import AsyncCollectionGroup + + return AsyncCollectionGroup(*args, **kwargs) + + +def test_asynccollectiongroup_constructor(): + query = _make_async_collection_group(mock.sentinel.parent) + assert query._parent is mock.sentinel.parent + assert query._projection is None + assert query._field_filters == () + assert query._orders == () + assert query._limit is None + assert query._offset is None + assert query._start_at is None + assert query._end_at is None + assert query._all_descendants + + +def test_asynccollectiongroup_constructor_all_descendents_is_false(): + with pytest.raises(ValueError): + _make_async_collection_group(mock.sentinel.parent, all_descendants=False) + + +@pytest.mark.asyncio +async def _get_partitions_helper(retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["partition_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("charles") + + # Make two **real** document references to use as cursors + document1 = parent.document("one") + document2 = parent.document("two") + + # Add cursor pb's to the minimal fake GAPIC. + cursor_pb1 = _make_cursor_pb(([document1], False)) + cursor_pb2 = _make_cursor_pb(([document2], False)) + firestore_api.partition_query.return_value = AsyncIter([cursor_pb1, cursor_pb2]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Execute the query and check the response. 
+ query = _make_async_collection_group(parent) + get_response = query.get_partitions(2, **kwargs) + + assert isinstance(get_response, types.AsyncGeneratorType) + returned = [i async for i in get_response] + assert len(returned) == 3 + + # Verify the mock call. + parent_path, _ = parent._parent_info() + partition_query = _make_async_collection_group( + parent, orders=(query._make_order("__name__", query.ASCENDING),), + ) + firestore_api.partition_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": partition_query._to_protobuf(), + "partition_count": 2, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +@pytest.mark.asyncio +async def test_asynccollectiongroup_get_partitions(): + await _get_partitions_helper() + + +@pytest.mark.asyncio +async def test_asynccollectiongroup_get_partitions_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _get_partitions_helper(retry=retry, timeout=timeout) + + +@pytest.mark.asyncio +async def test_asynccollectiongroup_get_partitions_w_filter(): + # Make a **real** collection reference as parent. 
+ client = _make_client() + parent = client.collection("charles") + + # Make a query that fails to partition + query = _make_async_collection_group(parent).where("foo", "==", "bar") + with pytest.raises(ValueError): + [i async for i in query.get_partitions(2)] -class TestCollectionGroup(aiounittest.AsyncTestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.async_query import AsyncCollectionGroup - - return AsyncCollectionGroup - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def test_constructor(self): - query = self._make_one(mock.sentinel.parent) - self.assertIs(query._parent, mock.sentinel.parent) - self.assertIsNone(query._projection) - self.assertEqual(query._field_filters, ()) - self.assertEqual(query._orders, ()) - self.assertIsNone(query._limit) - self.assertIsNone(query._offset) - self.assertIsNone(query._start_at) - self.assertIsNone(query._end_at) - self.assertTrue(query._all_descendants) - - def test_constructor_all_descendents_is_false(self): - with pytest.raises(ValueError): - self._make_one(mock.sentinel.parent, all_descendants=False) - - @pytest.mark.asyncio - async def _get_partitions_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - - # Create a minimal fake GAPIC. - firestore_api = AsyncMock(spec=["partition_query"]) - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("charles") - - # Make two **real** document references to use as cursors - document1 = parent.document("one") - document2 = parent.document("two") - - # Add cursor pb's to the minimal fake GAPIC. 
- cursor_pb1 = _make_cursor_pb(([document1], False)) - cursor_pb2 = _make_cursor_pb(([document2], False)) - firestore_api.partition_query.return_value = AsyncIter([cursor_pb1, cursor_pb2]) - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - # Execute the query and check the response. - query = self._make_one(parent) - get_response = query.get_partitions(2, **kwargs) - - self.assertIsInstance(get_response, types.AsyncGeneratorType) - returned = [i async for i in get_response] - self.assertEqual(len(returned), 3) - - # Verify the mock call. - parent_path, _ = parent._parent_info() - partition_query = self._make_one( - parent, orders=(query._make_order("__name__", query.ASCENDING),), - ) - firestore_api.partition_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": partition_query._to_protobuf(), - "partition_count": 2, - }, - metadata=client._rpc_metadata, - **kwargs, - ) +@pytest.mark.asyncio +async def test_asynccollectiongroup_get_partitions_w_projection(): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") + + # Make a query that fails to partition + query = _make_async_collection_group(parent).select("foo") + with pytest.raises(ValueError): + [i async for i in query.get_partitions(2)] + + +@pytest.mark.asyncio +async def test_asynccollectiongroup_get_partitions_w_limit(): + # Make a **real** collection reference as parent. 
+ client = _make_client() + parent = client.collection("charles") + + # Make a query that fails to partition + query = _make_async_collection_group(parent).limit(10) + with pytest.raises(ValueError): + [i async for i in query.get_partitions(2)] + - @pytest.mark.asyncio - async def test_get_partitions(self): - await self._get_partitions_helper() - - @pytest.mark.asyncio - async def test_get_partitions_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - await self._get_partitions_helper(retry=retry, timeout=timeout) - - async def test_get_partitions_w_filter(self): - # Make a **real** collection reference as parent. - client = _make_client() - parent = client.collection("charles") - - # Make a query that fails to partition - query = self._make_one(parent).where("foo", "==", "bar") - with pytest.raises(ValueError): - [i async for i in query.get_partitions(2)] - - async def test_get_partitions_w_projection(self): - # Make a **real** collection reference as parent. - client = _make_client() - parent = client.collection("charles") - - # Make a query that fails to partition - query = self._make_one(parent).select("foo") - with pytest.raises(ValueError): - [i async for i in query.get_partitions(2)] - - async def test_get_partitions_w_limit(self): - # Make a **real** collection reference as parent. - client = _make_client() - parent = client.collection("charles") - - # Make a query that fails to partition - query = self._make_one(parent).limit(10) - with pytest.raises(ValueError): - [i async for i in query.get_partitions(2)] - - async def test_get_partitions_w_offset(self): - # Make a **real** collection reference as parent. 
- client = _make_client() - parent = client.collection("charles") - - # Make a query that fails to partition - query = self._make_one(parent).offset(10) - with pytest.raises(ValueError): - [i async for i in query.get_partitions(2)] +@pytest.mark.asyncio +async def test_asynccollectiongroup_get_partitions_w_offset(): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") + + # Make a query that fails to partition + query = _make_async_collection_group(parent).offset(10) + with pytest.raises(ValueError): + [i async for i in query.get_partitions(2)] def _make_client(project="project-project"): diff --git a/tests/unit/v1/test_async_transaction.py b/tests/unit/v1/test_async_transaction.py index 2e0f572b074d9..81c7bdc08a7b2 100644 --- a/tests/unit/v1/test_async_transaction.py +++ b/tests/unit/v1/test_async_transaction.py @@ -12,1014 +12,1005 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import mock import pytest -import aiounittest -import mock from tests.unit.v1.test__helpers import AsyncMock -class TestAsyncTransaction(aiounittest.AsyncTestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.async_transaction import AsyncTransaction +def _make_async_transaction(*args, **kwargs): + from google.cloud.firestore_v1.async_transaction import AsyncTransaction - return AsyncTransaction + return AsyncTransaction(*args, **kwargs) - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - def test_constructor_defaults(self): - from google.cloud.firestore_v1.async_transaction import MAX_ATTEMPTS +def test_asynctransaction_constructor_defaults(): + from google.cloud.firestore_v1.async_transaction import MAX_ATTEMPTS - transaction = self._make_one(mock.sentinel.client) - self.assertIs(transaction._client, mock.sentinel.client) - self.assertEqual(transaction._write_pbs, []) - self.assertEqual(transaction._max_attempts, MAX_ATTEMPTS) - self.assertFalse(transaction._read_only) - self.assertIsNone(transaction._id) + transaction = _make_async_transaction(mock.sentinel.client) + assert transaction._client is mock.sentinel.client + assert transaction._write_pbs == [] + assert transaction._max_attempts == MAX_ATTEMPTS + assert not transaction._read_only + assert transaction._id is None - def test_constructor_explicit(self): - transaction = self._make_one( - mock.sentinel.client, max_attempts=10, read_only=True - ) - self.assertIs(transaction._client, mock.sentinel.client) - self.assertEqual(transaction._write_pbs, []) - self.assertEqual(transaction._max_attempts, 10) - self.assertTrue(transaction._read_only) - self.assertIsNone(transaction._id) - def test__add_write_pbs_failure(self): - from google.cloud.firestore_v1.base_transaction import _WRITE_READ_ONLY +def test_asynctransaction_constructor_explicit(): + transaction = _make_async_transaction( + mock.sentinel.client, max_attempts=10, 
read_only=True + ) + assert transaction._client is mock.sentinel.client + assert transaction._write_pbs == [] + assert transaction._max_attempts == 10 + assert transaction._read_only + assert transaction._id is None - batch = self._make_one(mock.sentinel.client, read_only=True) - self.assertEqual(batch._write_pbs, []) - with self.assertRaises(ValueError) as exc_info: - batch._add_write_pbs([mock.sentinel.write]) - self.assertEqual(exc_info.exception.args, (_WRITE_READ_ONLY,)) - self.assertEqual(batch._write_pbs, []) +def test_asynctransaction__add_write_pbs_failure(): + from google.cloud.firestore_v1.base_transaction import _WRITE_READ_ONLY - def test__add_write_pbs(self): - batch = self._make_one(mock.sentinel.client) - self.assertEqual(batch._write_pbs, []) + batch = _make_async_transaction(mock.sentinel.client, read_only=True) + assert batch._write_pbs == [] + with pytest.raises(ValueError) as exc_info: batch._add_write_pbs([mock.sentinel.write]) - self.assertEqual(batch._write_pbs, [mock.sentinel.write]) - - def test__clean_up(self): - transaction = self._make_one(mock.sentinel.client) - transaction._write_pbs.extend( - [mock.sentinel.write_pb1, mock.sentinel.write_pb2] - ) - transaction._id = b"not-this-time-my-friend" - - ret_val = transaction._clean_up() - self.assertIsNone(ret_val) - - self.assertEqual(transaction._write_pbs, []) - self.assertIsNone(transaction._id) - - @pytest.mark.asyncio - async def test__begin(self): - from google.cloud.firestore_v1.types import firestore - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = AsyncMock() - txn_id = b"to-begin" - response = firestore.BeginTransactionResponse(transaction=txn_id) - firestore_api.begin_transaction.return_value = response - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Actually make a transaction and ``begin()`` it. 
- transaction = self._make_one(client) - self.assertIsNone(transaction._id) - - ret_val = await transaction._begin() - self.assertIsNone(ret_val) - self.assertEqual(transaction._id, txn_id) - - # Verify the called mock. - firestore_api.begin_transaction.assert_called_once_with( - request={"database": client._database_string, "options": None}, - metadata=client._rpc_metadata, - ) - - @pytest.mark.asyncio - async def test__begin_failure(self): - from google.cloud.firestore_v1.base_transaction import _CANT_BEGIN - - client = _make_client() - transaction = self._make_one(client) - transaction._id = b"not-none" - - with self.assertRaises(ValueError) as exc_info: - await transaction._begin() - - err_msg = _CANT_BEGIN.format(transaction._id) - self.assertEqual(exc_info.exception.args, (err_msg,)) - - @pytest.mark.asyncio - async def test__rollback(self): - from google.protobuf import empty_pb2 - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = AsyncMock() - firestore_api.rollback.return_value = empty_pb2.Empty() - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Actually make a transaction and roll it back. - transaction = self._make_one(client) - txn_id = b"to-be-r\x00lled" - transaction._id = txn_id - ret_val = await transaction._rollback() - self.assertIsNone(ret_val) - self.assertIsNone(transaction._id) - - # Verify the called mock. 
- firestore_api.rollback.assert_called_once_with( - request={"database": client._database_string, "transaction": txn_id}, - metadata=client._rpc_metadata, - ) - - @pytest.mark.asyncio - async def test__rollback_not_allowed(self): - from google.cloud.firestore_v1.base_transaction import _CANT_ROLLBACK - - client = _make_client() - transaction = self._make_one(client) - self.assertIsNone(transaction._id) - - with self.assertRaises(ValueError) as exc_info: - await transaction._rollback() - - self.assertEqual(exc_info.exception.args, (_CANT_ROLLBACK,)) - - @pytest.mark.asyncio - async def test__rollback_failure(self): - from google.api_core import exceptions - - # Create a minimal fake GAPIC with a dummy failure. - firestore_api = AsyncMock() - exc = exceptions.InternalServerError("Fire during rollback.") - firestore_api.rollback.side_effect = exc - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Actually make a transaction and roll it back. - transaction = self._make_one(client) - txn_id = b"roll-bad-server" - transaction._id = txn_id - - with self.assertRaises(exceptions.InternalServerError) as exc_info: - await transaction._rollback() - - self.assertIs(exc_info.exception, exc) - self.assertIsNone(transaction._id) - self.assertEqual(transaction._write_pbs, []) - - # Verify the called mock. - firestore_api.rollback.assert_called_once_with( - request={"database": client._database_string, "transaction": txn_id}, - metadata=client._rpc_metadata, - ) - - @pytest.mark.asyncio - async def test__commit(self): - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = AsyncMock() - commit_response = firestore.CommitResponse(write_results=[write.WriteResult()]) - firestore_api.commit.return_value = commit_response - - # Attach the fake GAPIC to a real client. 
- client = _make_client("phone-joe") - client._firestore_api_internal = firestore_api - - # Actually make a transaction with some mutations and call _commit(). - transaction = self._make_one(client) - txn_id = b"under-over-thru-woods" - transaction._id = txn_id - document = client.document("zap", "galaxy", "ship", "space") - transaction.set(document, {"apple": 4.5}) - write_pbs = transaction._write_pbs[::] - - write_results = await transaction._commit() - self.assertEqual(write_results, list(commit_response.write_results)) - # Make sure transaction has no more "changes". - self.assertIsNone(transaction._id) - self.assertEqual(transaction._write_pbs, []) - - # Verify the mocks. - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - - @pytest.mark.asyncio - async def test__commit_not_allowed(self): - from google.cloud.firestore_v1.base_transaction import _CANT_COMMIT - - transaction = self._make_one(mock.sentinel.client) - self.assertIsNone(transaction._id) - with self.assertRaises(ValueError) as exc_info: - await transaction._commit() - - self.assertEqual(exc_info.exception.args, (_CANT_COMMIT,)) - - @pytest.mark.asyncio - async def test__commit_failure(self): - from google.api_core import exceptions - - # Create a minimal fake GAPIC with a dummy failure. - firestore_api = AsyncMock() - exc = exceptions.InternalServerError("Fire during commit.") - firestore_api.commit.side_effect = exc - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Actually make a transaction with some mutations and call _commit(). 
- transaction = self._make_one(client) - txn_id = b"beep-fail-commit" - transaction._id = txn_id - transaction.create(client.document("up", "down"), {"water": 1.0}) - transaction.delete(client.document("up", "left")) - write_pbs = transaction._write_pbs[::] - - with self.assertRaises(exceptions.InternalServerError) as exc_info: - await transaction._commit() - - self.assertIs(exc_info.exception, exc) - self.assertEqual(transaction._id, txn_id) - self.assertEqual(transaction._write_pbs, write_pbs) - - # Verify the called mock. - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - - async def _get_all_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - - client = AsyncMock(spec=["get_all"]) - transaction = self._make_one(client) - ref1, ref2 = mock.Mock(), mock.Mock() - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - result = await transaction.get_all([ref1, ref2], **kwargs) - - client.get_all.assert_called_once_with( - [ref1, ref2], transaction=transaction, **kwargs, - ) - self.assertIs(result, client.get_all.return_value) - - @pytest.mark.asyncio - async def test_get_all(self): - await self._get_all_helper() - - @pytest.mark.asyncio - async def test_get_all_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - await self._get_all_helper(retry=retry, timeout=timeout) - - async def _get_w_document_ref_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1.async_document import AsyncDocumentReference - from google.cloud.firestore_v1 import _helpers - - client = AsyncMock(spec=["get_all"]) - transaction = self._make_one(client) - ref = AsyncDocumentReference("documents", "doc-id") - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - result = await transaction.get(ref, **kwargs) - - 
client.get_all.assert_called_once_with([ref], transaction=transaction, **kwargs) - self.assertIs(result, client.get_all.return_value) - - @pytest.mark.asyncio - async def test_get_w_document_ref(self): - await self._get_w_document_ref_helper() - - @pytest.mark.asyncio - async def test_get_w_document_ref_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - await self._get_w_document_ref_helper(retry=retry, timeout=timeout) - - async def _get_w_query_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1.async_query import AsyncQuery - from google.cloud.firestore_v1 import _helpers - - client = AsyncMock(spec=[]) - transaction = self._make_one(client) - query = AsyncQuery(parent=AsyncMock(spec=[])) - query.stream = AsyncMock() - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - result = await transaction.get(query, **kwargs,) - - query.stream.assert_called_once_with( - transaction=transaction, **kwargs, - ) - self.assertIs(result, query.stream.return_value) - - @pytest.mark.asyncio - async def test_get_w_query(self): - await self._get_w_query_helper() - - @pytest.mark.asyncio - async def test_get_w_query_w_retry_timeout(self): - await self._get_w_query_helper() - - @pytest.mark.asyncio - async def test_get_failure(self): - client = _make_client() - transaction = self._make_one(client) - ref_or_query = object() - with self.assertRaises(ValueError): - await transaction.get(ref_or_query) - - -class Test_Transactional(aiounittest.AsyncTestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.async_transaction import _AsyncTransactional - - return _AsyncTransactional - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def test_constructor(self): - wrapped = self._make_one(mock.sentinel.callable_) - self.assertIs(wrapped.to_wrap, mock.sentinel.callable_) - 
self.assertIsNone(wrapped.current_id) - self.assertIsNone(wrapped.retry_id) - - @pytest.mark.asyncio - async def test__pre_commit_success(self): - to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) - wrapped = self._make_one(to_wrap) - - txn_id = b"totes-began" - transaction = _make_transaction(txn_id) - result = await wrapped._pre_commit(transaction, "pos", key="word") - self.assertIs(result, mock.sentinel.result) - - self.assertEqual(transaction._id, txn_id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. - to_wrap.assert_called_once_with(transaction, "pos", key="word") - firestore_api = transaction._client._firestore_api - firestore_api.begin_transaction.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "options": None, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_not_called() - - @pytest.mark.asyncio - async def test__pre_commit_retry_id_already_set_success(self): - from google.cloud.firestore_v1.types import common - - to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) - wrapped = self._make_one(to_wrap) - txn_id1 = b"already-set" - wrapped.retry_id = txn_id1 - - txn_id2 = b"ok-here-too" - transaction = _make_transaction(txn_id2) - result = await wrapped._pre_commit(transaction) - self.assertIs(result, mock.sentinel.result) - - self.assertEqual(transaction._id, txn_id2) - self.assertEqual(wrapped.current_id, txn_id2) - self.assertEqual(wrapped.retry_id, txn_id1) - - # Verify mocks. 
- to_wrap.assert_called_once_with(transaction) - firestore_api = transaction._client._firestore_api - options_ = common.TransactionOptions( - read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id1) - ) - firestore_api.begin_transaction.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "options": options_, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_not_called() - - @pytest.mark.asyncio - async def test__pre_commit_failure(self): - exc = RuntimeError("Nope not today.") - to_wrap = AsyncMock(side_effect=exc, spec=[]) - wrapped = self._make_one(to_wrap) - - txn_id = b"gotta-fail" - transaction = _make_transaction(txn_id) - with self.assertRaises(RuntimeError) as exc_info: - await wrapped._pre_commit(transaction, 10, 20) - self.assertIs(exc_info.exception, exc) - - self.assertIsNone(transaction._id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. 
- to_wrap.assert_called_once_with(transaction, 10, 20) - firestore_api = transaction._client._firestore_api - firestore_api.begin_transaction.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "options": None, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.rollback.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.commit.assert_not_called() - - @pytest.mark.asyncio - async def test__pre_commit_failure_with_rollback_failure(self): - from google.api_core import exceptions - - exc1 = ValueError("I will not be only failure.") - to_wrap = AsyncMock(side_effect=exc1, spec=[]) - wrapped = self._make_one(to_wrap) - - txn_id = b"both-will-fail" - transaction = _make_transaction(txn_id) - # Actually force the ``rollback`` to fail as well. - exc2 = exceptions.InternalServerError("Rollback blues.") - firestore_api = transaction._client._firestore_api - firestore_api.rollback.side_effect = exc2 - - # Try to ``_pre_commit`` - with self.assertRaises(exceptions.InternalServerError) as exc_info: - await wrapped._pre_commit(transaction, a="b", c="zebra") - self.assertIs(exc_info.exception, exc2) - - self.assertIsNone(transaction._id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. 
- to_wrap.assert_called_once_with(transaction, a="b", c="zebra") - firestore_api.begin_transaction.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "options": None, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.rollback.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.commit.assert_not_called() - - @pytest.mark.asyncio - async def test__maybe_commit_success(self): - wrapped = self._make_one(mock.sentinel.callable_) - - txn_id = b"nyet" - transaction = _make_transaction(txn_id) - transaction._id = txn_id # We won't call ``begin()``. - succeeded = await wrapped._maybe_commit(transaction) - self.assertTrue(succeeded) - - # On success, _id is reset. - self.assertIsNone(transaction._id) - - # Verify mocks. - firestore_api = transaction._client._firestore_api - firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - - @pytest.mark.asyncio - async def test__maybe_commit_failure_read_only(self): - from google.api_core import exceptions - - wrapped = self._make_one(mock.sentinel.callable_) - - txn_id = b"failed" - transaction = _make_transaction(txn_id, read_only=True) - transaction._id = txn_id # We won't call ``begin()``. - wrapped.current_id = txn_id # We won't call ``_pre_commit()``. - wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. - - # Actually force the ``commit`` to fail (use ABORTED, but cannot - # retry since read-only). 
- exc = exceptions.Aborted("Read-only did a bad.") - firestore_api = transaction._client._firestore_api - firestore_api.commit.side_effect = exc - - with self.assertRaises(exceptions.Aborted) as exc_info: - await wrapped._maybe_commit(transaction) - self.assertIs(exc_info.exception, exc) - - self.assertEqual(transaction._id, txn_id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. - firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - - @pytest.mark.asyncio - async def test__maybe_commit_failure_can_retry(self): - from google.api_core import exceptions - - wrapped = self._make_one(mock.sentinel.callable_) - - txn_id = b"failed-but-retry" - transaction = _make_transaction(txn_id) - transaction._id = txn_id # We won't call ``begin()``. - wrapped.current_id = txn_id # We won't call ``_pre_commit()``. - wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. - - # Actually force the ``commit`` to fail. - exc = exceptions.Aborted("Read-write did a bad.") - firestore_api = transaction._client._firestore_api - firestore_api.commit.side_effect = exc - - succeeded = await wrapped._maybe_commit(transaction) - self.assertFalse(succeeded) - - self.assertEqual(transaction._id, txn_id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. 
- firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - - @pytest.mark.asyncio - async def test__maybe_commit_failure_cannot_retry(self): - from google.api_core import exceptions - - wrapped = self._make_one(mock.sentinel.callable_) - - txn_id = b"failed-but-not-retryable" - transaction = _make_transaction(txn_id) - transaction._id = txn_id # We won't call ``begin()``. - wrapped.current_id = txn_id # We won't call ``_pre_commit()``. - wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. - - # Actually force the ``commit`` to fail. - exc = exceptions.InternalServerError("Real bad thing") - firestore_api = transaction._client._firestore_api - firestore_api.commit.side_effect = exc - - with self.assertRaises(exceptions.InternalServerError) as exc_info: - await wrapped._maybe_commit(transaction) - self.assertIs(exc_info.exception, exc) - - self.assertEqual(transaction._id, txn_id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. 
- firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - - @pytest.mark.asyncio - async def test___call__success_first_attempt(self): - to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) - wrapped = self._make_one(to_wrap) - - txn_id = b"whole-enchilada" - transaction = _make_transaction(txn_id) - result = await wrapped(transaction, "a", b="c") - self.assertIs(result, mock.sentinel.result) - - self.assertIsNone(transaction._id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. - to_wrap.assert_called_once_with(transaction, "a", b="c") - firestore_api = transaction._client._firestore_api - firestore_api.begin_transaction.assert_called_once_with( - request={"database": transaction._client._database_string, "options": None}, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - - @pytest.mark.asyncio - async def test___call__success_second_attempt(self): - from google.api_core import exceptions - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write - - to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) - wrapped = self._make_one(to_wrap) - - txn_id = b"whole-enchilada" - transaction = _make_transaction(txn_id) - - # Actually force the ``commit`` to fail on first / succeed on second. 
- exc = exceptions.Aborted("Contention junction.") - firestore_api = transaction._client._firestore_api - firestore_api.commit.side_effect = [ - exc, - firestore.CommitResponse(write_results=[write.WriteResult()]), - ] - - # Call the __call__-able ``wrapped``. - result = await wrapped(transaction, "a", b="c") - self.assertIs(result, mock.sentinel.result) - - self.assertIsNone(transaction._id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. - wrapped_call = mock.call(transaction, "a", b="c") - self.assertEqual(to_wrap.mock_calls, [wrapped_call, wrapped_call]) - firestore_api = transaction._client._firestore_api - db_str = transaction._client._database_string - options_ = common.TransactionOptions( - read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id) - ) - self.assertEqual( - firestore_api.begin_transaction.mock_calls, - [ - mock.call( - request={"database": db_str, "options": None}, - metadata=transaction._client._rpc_metadata, - ), - mock.call( - request={"database": db_str, "options": options_}, - metadata=transaction._client._rpc_metadata, - ), - ], - ) - firestore_api.rollback.assert_not_called() - commit_call = mock.call( - request={"database": db_str, "writes": [], "transaction": txn_id}, - metadata=transaction._client._rpc_metadata, - ) - self.assertEqual(firestore_api.commit.mock_calls, [commit_call, commit_call]) - - @pytest.mark.asyncio - async def test___call__failure(self): - from google.api_core import exceptions - from google.cloud.firestore_v1.async_transaction import ( - _EXCEED_ATTEMPTS_TEMPLATE, - ) - - to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) - wrapped = self._make_one(to_wrap) - - txn_id = b"only-one-shot" - transaction = _make_transaction(txn_id, max_attempts=1) - - # Actually force the ``commit`` to fail. 
- exc = exceptions.Aborted("Contention just once.") - firestore_api = transaction._client._firestore_api - firestore_api.commit.side_effect = exc - - # Call the __call__-able ``wrapped``. - with self.assertRaises(ValueError) as exc_info: - await wrapped(transaction, "here", there=1.5) - - err_msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) - self.assertEqual(exc_info.exception.args, (err_msg,)) - - self.assertIsNone(transaction._id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. - to_wrap.assert_called_once_with(transaction, "here", there=1.5) - firestore_api.begin_transaction.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "options": None, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.rollback.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "transaction": txn_id, - }, + + assert exc_info.value.args == (_WRITE_READ_ONLY,) + assert batch._write_pbs == [] + + +def test_asynctransaction__add_write_pbs(): + batch = _make_async_transaction(mock.sentinel.client) + assert batch._write_pbs == [] + batch._add_write_pbs([mock.sentinel.write]) + assert batch._write_pbs == [mock.sentinel.write] + + +def test_asynctransaction__clean_up(): + transaction = _make_async_transaction(mock.sentinel.client) + transaction._write_pbs.extend([mock.sentinel.write_pb1, mock.sentinel.write_pb2]) + transaction._id = b"not-this-time-my-friend" + + ret_val = transaction._clean_up() + assert ret_val is None + + assert transaction._write_pbs == [] + assert transaction._id is None + + +@pytest.mark.asyncio +async def test_asynctransaction__begin(): + from google.cloud.firestore_v1.types import firestore + + # Create a minimal fake GAPIC with a dummy result. 
+ firestore_api = AsyncMock() + txn_id = b"to-begin" + response = firestore.BeginTransactionResponse(transaction=txn_id) + firestore_api.begin_transaction.return_value = response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a transaction and ``begin()`` it. + transaction = _make_async_transaction(client) + assert transaction._id is None + + ret_val = await transaction._begin() + assert ret_val is None + assert transaction._id == txn_id + + # Verify the called mock. + firestore_api.begin_transaction.assert_called_once_with( + request={"database": client._database_string, "options": None}, + metadata=client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asynctransaction__begin_failure(): + from google.cloud.firestore_v1.base_transaction import _CANT_BEGIN + + client = _make_client() + transaction = _make_async_transaction(client) + transaction._id = b"not-none" + + with pytest.raises(ValueError) as exc_info: + await transaction._begin() + + err_msg = _CANT_BEGIN.format(transaction._id) + assert exc_info.value.args == (err_msg,) + + +@pytest.mark.asyncio +async def test_asynctransaction__rollback(): + from google.protobuf import empty_pb2 + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = AsyncMock() + firestore_api.rollback.return_value = empty_pb2.Empty() + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a transaction and roll it back. + transaction = _make_async_transaction(client) + txn_id = b"to-be-r\x00lled" + transaction._id = txn_id + ret_val = await transaction._rollback() + assert ret_val is None + assert transaction._id is None + + # Verify the called mock. 
+ firestore_api.rollback.assert_called_once_with( + request={"database": client._database_string, "transaction": txn_id}, + metadata=client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asynctransaction__rollback_not_allowed(): + from google.cloud.firestore_v1.base_transaction import _CANT_ROLLBACK + + client = _make_client() + transaction = _make_async_transaction(client) + assert transaction._id is None + + with pytest.raises(ValueError) as exc_info: + await transaction._rollback() + + assert exc_info.value.args == (_CANT_ROLLBACK,) + + +@pytest.mark.asyncio +async def test_asynctransaction__rollback_failure(): + from google.api_core import exceptions + + # Create a minimal fake GAPIC with a dummy failure. + firestore_api = AsyncMock() + exc = exceptions.InternalServerError("Fire during rollback.") + firestore_api.rollback.side_effect = exc + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a transaction and roll it back. + transaction = _make_async_transaction(client) + txn_id = b"roll-bad-server" + transaction._id = txn_id + + with pytest.raises(exceptions.InternalServerError) as exc_info: + await transaction._rollback() + + assert exc_info.value is exc + assert transaction._id is None + assert transaction._write_pbs == [] + + # Verify the called mock. + firestore_api.rollback.assert_called_once_with( + request={"database": client._database_string, "transaction": txn_id}, + metadata=client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asynctransaction__commit(): + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = AsyncMock() + commit_response = firestore.CommitResponse(write_results=[write.WriteResult()]) + firestore_api.commit.return_value = commit_response + + # Attach the fake GAPIC to a real client. 
+ client = _make_client("phone-joe") + client._firestore_api_internal = firestore_api + + # Actually make a transaction with some mutations and call _commit(). + transaction = _make_async_transaction(client) + txn_id = b"under-over-thru-woods" + transaction._id = txn_id + document = client.document("zap", "galaxy", "ship", "space") + transaction.set(document, {"apple": 4.5}) + write_pbs = transaction._write_pbs[::] + + write_results = await transaction._commit() + assert write_results == list(commit_response.write_results) + # Make sure transaction has no more "changes". + assert transaction._id is None + assert transaction._write_pbs == [] + + # Verify the mocks. + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asynctransaction__commit_not_allowed(): + from google.cloud.firestore_v1.base_transaction import _CANT_COMMIT + + transaction = _make_async_transaction(mock.sentinel.client) + assert transaction._id is None + with pytest.raises(ValueError) as exc_info: + await transaction._commit() + + assert exc_info.value.args == (_CANT_COMMIT,) + + +@pytest.mark.asyncio +async def test_asynctransaction__commit_failure(): + from google.api_core import exceptions + + # Create a minimal fake GAPIC with a dummy failure. + firestore_api = AsyncMock() + exc = exceptions.InternalServerError("Fire during commit.") + firestore_api.commit.side_effect = exc + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Actually make a transaction with some mutations and call _commit(). 
+ transaction = _make_async_transaction(client) + txn_id = b"beep-fail-commit" + transaction._id = txn_id + transaction.create(client.document("up", "down"), {"water": 1.0}) + transaction.delete(client.document("up", "left")) + write_pbs = transaction._write_pbs[::] + + with pytest.raises(exceptions.InternalServerError) as exc_info: + await transaction._commit() + + assert exc_info.value is exc + assert transaction._id == txn_id + assert transaction._write_pbs == write_pbs + + # Verify the called mock. + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + + +async def _get_all_helper(retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + + client = AsyncMock(spec=["get_all"]) + transaction = _make_async_transaction(client) + ref1, ref2 = mock.Mock(), mock.Mock() + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + result = await transaction.get_all([ref1, ref2], **kwargs) + + client.get_all.assert_called_once_with( + [ref1, ref2], transaction=transaction, **kwargs, + ) + assert result is client.get_all.return_value + + +@pytest.mark.asyncio +async def test_asynctransaction_get_all(): + await _get_all_helper() + + +@pytest.mark.asyncio +async def test_asynctransaction_get_all_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _get_all_helper(retry=retry, timeout=timeout) + + +async def _get_w_document_ref_helper(retry=None, timeout=None): + from google.cloud.firestore_v1.async_document import AsyncDocumentReference + from google.cloud.firestore_v1 import _helpers + + client = AsyncMock(spec=["get_all"]) + transaction = _make_async_transaction(client) + ref = AsyncDocumentReference("documents", "doc-id") + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + result = await transaction.get(ref, **kwargs) + + 
client.get_all.assert_called_once_with([ref], transaction=transaction, **kwargs) + assert result is client.get_all.return_value + + +@pytest.mark.asyncio +async def test_asynctransaction_get_w_document_ref(): + await _get_w_document_ref_helper() + + +@pytest.mark.asyncio +async def test_asynctransaction_get_w_document_ref_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _get_w_document_ref_helper(retry=retry, timeout=timeout) + + +async def _get_w_query_helper(retry=None, timeout=None): + from google.cloud.firestore_v1.async_query import AsyncQuery + from google.cloud.firestore_v1 import _helpers + + client = AsyncMock(spec=[]) + transaction = _make_async_transaction(client) + query = AsyncQuery(parent=AsyncMock(spec=[])) + query.stream = AsyncMock() + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + result = await transaction.get(query, **kwargs,) + + query.stream.assert_called_once_with( + transaction=transaction, **kwargs, + ) + assert result is query.stream.return_value + + +@pytest.mark.asyncio +async def test_asynctransaction_get_w_query(): + await _get_w_query_helper() + + +@pytest.mark.asyncio +async def test_asynctransaction_get_w_query_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + await _get_w_query_helper(retry=retry, timeout=timeout) + + +@pytest.mark.asyncio +async def test_asynctransaction_get_failure(): + client = _make_client() + transaction = _make_async_transaction(client) + ref_or_query = object() + with pytest.raises(ValueError): + await transaction.get(ref_or_query) + + +def _make_async_transactional(*args, **kwargs): + from google.cloud.firestore_v1.async_transaction import _AsyncTransactional + + return _AsyncTransactional(*args, **kwargs) + + +def test_asynctransactional_constructor(): + wrapped = _make_async_transactional(mock.sentinel.callable_) + assert wrapped.to_wrap is mock.sentinel.callable_ + assert wrapped.current_id is None + assert wrapped.retry_id is None + + +@pytest.mark.asyncio +async def 
test_asynctransactional__pre_commit_success(): + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make_async_transactional(to_wrap) + + txn_id = b"totes-began" + transaction = _make_transaction(txn_id) + result = await wrapped._pre_commit(transaction, "pos", key="word") + assert result is mock.sentinel.result + + assert transaction._id == txn_id + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, "pos", key="word") + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_called_once_with( + request={"database": transaction._client._database_string, "options": None}, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_not_called() + + +@pytest.mark.asyncio +async def test_asynctransactional__pre_commit_retry_id_already_set_success(): + from google.cloud.firestore_v1.types import common + + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make_async_transactional(to_wrap) + txn_id1 = b"already-set" + wrapped.retry_id = txn_id1 + + txn_id2 = b"ok-here-too" + transaction = _make_transaction(txn_id2) + result = await wrapped._pre_commit(transaction) + assert result is mock.sentinel.result + + assert transaction._id == txn_id2 + assert wrapped.current_id == txn_id2 + assert wrapped.retry_id == txn_id1 + + # Verify mocks. 
+ to_wrap.assert_called_once_with(transaction) + firestore_api = transaction._client._firestore_api + options_ = common.TransactionOptions( + read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id1) + ) + firestore_api.begin_transaction.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "options": options_, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_not_called() + + +@pytest.mark.asyncio +async def test_asynctransactional__pre_commit_failure(): + exc = RuntimeError("Nope not today.") + to_wrap = AsyncMock(side_effect=exc, spec=[]) + wrapped = _make_async_transactional(to_wrap) + + txn_id = b"gotta-fail" + transaction = _make_transaction(txn_id) + with pytest.raises(RuntimeError) as exc_info: + await wrapped._pre_commit(transaction, 10, 20) + assert exc_info.value is exc + + assert transaction._id is None + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. 
+ to_wrap.assert_called_once_with(transaction, 10, 20) + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_called_once_with( + request={"database": transaction._client._database_string, "options": None}, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_not_called() + + +@pytest.mark.asyncio +async def test_asynctransactional__pre_commit_failure_with_rollback_failure(): + from google.api_core import exceptions + + exc1 = ValueError("I will not be only failure.") + to_wrap = AsyncMock(side_effect=exc1, spec=[]) + wrapped = _make_async_transactional(to_wrap) + + txn_id = b"both-will-fail" + transaction = _make_transaction(txn_id) + # Actually force the ``rollback`` to fail as well. + exc2 = exceptions.InternalServerError("Rollback blues.") + firestore_api = transaction._client._firestore_api + firestore_api.rollback.side_effect = exc2 + + # Try to ``_pre_commit`` + with pytest.raises(exceptions.InternalServerError) as exc_info: + await wrapped._pre_commit(transaction, a="b", c="zebra") + assert exc_info.value is exc2 + + assert transaction._id is None + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. 
+ to_wrap.assert_called_once_with(transaction, a="b", c="zebra") + firestore_api.begin_transaction.assert_called_once_with( + request={"database": transaction._client._database_string, "options": None}, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_not_called() + + +@pytest.mark.asyncio +async def test_asynctransactional__maybe_commit_success(): + wrapped = _make_async_transactional(mock.sentinel.callable_) + + txn_id = b"nyet" + transaction = _make_transaction(txn_id) + transaction._id = txn_id # We won't call ``begin()``. + succeeded = await wrapped._maybe_commit(transaction) + assert succeeded + + # On success, _id is reset. + assert transaction._id is None + + # Verify mocks. + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asynctransactional__maybe_commit_failure_read_only(): + from google.api_core import exceptions + + wrapped = _make_async_transactional(mock.sentinel.callable_) + + txn_id = b"failed" + transaction = _make_transaction(txn_id, read_only=True) + transaction._id = txn_id # We won't call ``begin()``. + wrapped.current_id = txn_id # We won't call ``_pre_commit()``. + wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + + # Actually force the ``commit`` to fail (use ABORTED, but cannot + # retry since read-only). 
+ exc = exceptions.Aborted("Read-only did a bad.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + with pytest.raises(exceptions.Aborted) as exc_info: + await wrapped._maybe_commit(transaction) + assert exc_info.value is exc + + assert transaction._id == txn_id + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. + firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asynctransactional__maybe_commit_failure_can_retry(): + from google.api_core import exceptions + + wrapped = _make_async_transactional(mock.sentinel.callable_) + + txn_id = b"failed-but-retry" + transaction = _make_transaction(txn_id) + transaction._id = txn_id # We won't call ``begin()``. + wrapped.current_id = txn_id # We won't call ``_pre_commit()``. + wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + + # Actually force the ``commit`` to fail. + exc = exceptions.Aborted("Read-write did a bad.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + succeeded = await wrapped._maybe_commit(transaction) + assert not succeeded + + assert transaction._id == txn_id + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. 
+ firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asynctransactional__maybe_commit_failure_cannot_retry(): + from google.api_core import exceptions + + wrapped = _make_async_transactional(mock.sentinel.callable_) + + txn_id = b"failed-but-not-retryable" + transaction = _make_transaction(txn_id) + transaction._id = txn_id # We won't call ``begin()``. + wrapped.current_id = txn_id # We won't call ``_pre_commit()``. + wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + + # Actually force the ``commit`` to fail. + exc = exceptions.InternalServerError("Real bad thing") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + with pytest.raises(exceptions.InternalServerError) as exc_info: + await wrapped._maybe_commit(transaction) + assert exc_info.value is exc + + assert transaction._id == txn_id + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. 
+ firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asynctransactional___call__success_first_attempt(): + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make_async_transactional(to_wrap) + + txn_id = b"whole-enchilada" + transaction = _make_transaction(txn_id) + result = await wrapped(transaction, "a", b="c") + assert result is mock.sentinel.result + + assert transaction._id is None + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, "a", b="c") + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_called_once_with( + request={"database": transaction._client._database_string, "options": None}, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + +@pytest.mark.asyncio +async def test_asynctransactional___call__success_second_attempt(): + from google.api_core import exceptions + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make_async_transactional(to_wrap) + + txn_id = b"whole-enchilada" + transaction = _make_transaction(txn_id) + + # Actually force the ``commit`` to fail on first / succeed on second. 
+ exc = exceptions.Aborted("Contention junction.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = [ + exc, + firestore.CommitResponse(write_results=[write.WriteResult()]), + ] + + # Call the __call__-able ``wrapped``. + result = await wrapped(transaction, "a", b="c") + assert result is mock.sentinel.result + + assert transaction._id is None + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. + wrapped_call = mock.call(transaction, "a", b="c") + assert to_wrap.mock_calls == [wrapped_call, wrapped_call] + firestore_api = transaction._client._firestore_api + db_str = transaction._client._database_string + options_ = common.TransactionOptions( + read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id) + ) + expected_calls = [ + mock.call( + request={"database": db_str, "options": None}, metadata=transaction._client._rpc_metadata, - ) - firestore_api.commit.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, - }, + ), + mock.call( + request={"database": db_str, "options": options_}, metadata=transaction._client._rpc_metadata, - ) + ), + ] + assert firestore_api.begin_transaction.mock_calls == expected_calls + firestore_api.rollback.assert_not_called() + commit_call = mock.call( + request={"database": db_str, "writes": [], "transaction": txn_id}, + metadata=transaction._client._rpc_metadata, + ) + assert firestore_api.commit.mock_calls == [commit_call, commit_call] + + +@pytest.mark.asyncio +async def test_asynctransactional___call__failure(): + from google.api_core import exceptions + from google.cloud.firestore_v1.async_transaction import _EXCEED_ATTEMPTS_TEMPLATE + + to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make_async_transactional(to_wrap) + + txn_id = b"only-one-shot" + transaction = _make_transaction(txn_id, max_attempts=1) + + # Actually force the 
``commit`` to fail. + exc = exceptions.Aborted("Contention just once.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + # Call the __call__-able ``wrapped``. + with pytest.raises(ValueError) as exc_info: + await wrapped(transaction, "here", there=1.5) + + err_msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) + assert exc_info.value.args == (err_msg,) + + assert transaction._id is None + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, "here", there=1.5) + firestore_api.begin_transaction.assert_called_once_with( + request={"database": transaction._client._database_string, "options": None}, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) -class Test_async_transactional(aiounittest.AsyncTestCase): - @staticmethod - def _call_fut(to_wrap): - from google.cloud.firestore_v1.async_transaction import async_transactional +def test_async_transactional_factory(): + from google.cloud.firestore_v1.async_transaction import _AsyncTransactional + from google.cloud.firestore_v1.async_transaction import async_transactional - return async_transactional(to_wrap) + wrapped = async_transactional(mock.sentinel.callable_) + assert isinstance(wrapped, _AsyncTransactional) + assert wrapped.to_wrap is mock.sentinel.callable_ - def test_it(self): - from google.cloud.firestore_v1.async_transaction import _AsyncTransactional - wrapped = self._call_fut(mock.sentinel.callable_) - self.assertIsInstance(wrapped, _AsyncTransactional) - 
self.assertIs(wrapped.to_wrap, mock.sentinel.callable_) +@mock.patch("google.cloud.firestore_v1.async_transaction._sleep") +@pytest.mark.asyncio +async def test__commit_with_retry_success_first_attempt(_sleep): + from google.cloud.firestore_v1.async_transaction import _commit_with_retry + # Create a minimal fake GAPIC with a dummy result. + firestore_api = AsyncMock() -class Test__commit_with_retry(aiounittest.AsyncTestCase): - @staticmethod - @pytest.mark.asyncio - async def _call_fut(client, write_pbs, transaction_id): - from google.cloud.firestore_v1.async_transaction import _commit_with_retry + # Attach the fake GAPIC to a real client. + client = _make_client("summer") + client._firestore_api_internal = firestore_api - return await _commit_with_retry(client, write_pbs, transaction_id) + # Call function and check result. + txn_id = b"cheeeeeez" + commit_response = await _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) + assert commit_response is firestore_api.commit.return_value + + # Verify mocks used. + _sleep.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) - @mock.patch("google.cloud.firestore_v1.async_transaction._sleep") - @pytest.mark.asyncio - async def test_success_first_attempt(self, _sleep): - # Create a minimal fake GAPIC with a dummy result. - firestore_api = AsyncMock() +@mock.patch( + "google.cloud.firestore_v1.async_transaction._sleep", side_effect=[2.0, 4.0] +) +@pytest.mark.asyncio +async def test__commit_with_retry_success_third_attempt(_sleep): + from google.api_core import exceptions + from google.cloud.firestore_v1.async_transaction import _commit_with_retry - # Attach the fake GAPIC to a real client. - client = _make_client("summer") - client._firestore_api_internal = firestore_api + # Create a minimal fake GAPIC with a dummy result. 
+ firestore_api = AsyncMock() - # Call function and check result. - txn_id = b"cheeeeeez" - commit_response = await self._call_fut(client, mock.sentinel.write_pbs, txn_id) - self.assertIs(commit_response, firestore_api.commit.return_value) + # Make sure the first two requests fail and the third succeeds. + firestore_api.commit.side_effect = [ + exceptions.ServiceUnavailable("Server sleepy."), + exceptions.ServiceUnavailable("Server groggy."), + mock.sentinel.commit_response, + ] - # Verify mocks used. - _sleep.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) + # Attach the fake GAPIC to a real client. + client = _make_client("outside") + client._firestore_api_internal = firestore_api - @mock.patch( - "google.cloud.firestore_v1.async_transaction._sleep", side_effect=[2.0, 4.0] + # Call function and check result. + txn_id = b"the-world\x00" + commit_response = await _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) + assert commit_response is mock.sentinel.commit_response + + # Verify mocks used. + # Ensure _sleep is called after commit failures, with intervals of 1 and 2 seconds + assert _sleep.call_count == 2 + _sleep.assert_any_call(1.0) + _sleep.assert_any_call(2.0) + # commit() called same way 3 times. + commit_call = mock.call( + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, ) - @pytest.mark.asyncio - async def test_success_third_attempt(self, _sleep): - from google.api_core import exceptions - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = AsyncMock() - - # Make sure the first two requests fail and the third succeeds. 
- firestore_api.commit.side_effect = [ - exceptions.ServiceUnavailable("Server sleepy."), - exceptions.ServiceUnavailable("Server groggy."), - mock.sentinel.commit_response, - ] - - # Attach the fake GAPIC to a real client. - client = _make_client("outside") - client._firestore_api_internal = firestore_api - - # Call function and check result. - txn_id = b"the-world\x00" - commit_response = await self._call_fut(client, mock.sentinel.write_pbs, txn_id) - self.assertIs(commit_response, mock.sentinel.commit_response) - - # Verify mocks used. - # Ensure _sleep is called after commit failures, with intervals of 1 and 2 seconds - self.assertEqual(_sleep.call_count, 2) - _sleep.assert_any_call(1.0) - _sleep.assert_any_call(2.0) - # commit() called same way 3 times. - commit_call = mock.call( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - self.assertEqual( - firestore_api.commit.mock_calls, [commit_call, commit_call, commit_call] - ) - - @mock.patch("google.cloud.firestore_v1.async_transaction._sleep") - @pytest.mark.asyncio - async def test_failure_first_attempt(self, _sleep): - from google.api_core import exceptions - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = AsyncMock() - - # Make sure the first request fails with an un-retryable error. - exc = exceptions.ResourceExhausted("We ran out of fries.") - firestore_api.commit.side_effect = exc - - # Attach the fake GAPIC to a real client. - client = _make_client("peanut-butter") - client._firestore_api_internal = firestore_api - - # Call function and check result. - txn_id = b"\x08\x06\x07\x05\x03\x00\x09-jenny" - with self.assertRaises(exceptions.ResourceExhausted) as exc_info: - await self._call_fut(client, mock.sentinel.write_pbs, txn_id) - - self.assertIs(exc_info.exception, exc) - - # Verify mocks used. 
- _sleep.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - - @mock.patch("google.cloud.firestore_v1.async_transaction._sleep", return_value=2.0) - @pytest.mark.asyncio - async def test_failure_second_attempt(self, _sleep): - from google.api_core import exceptions - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = AsyncMock() - - # Make sure the first request fails retry-able and second - # fails non-retryable. - exc1 = exceptions.ServiceUnavailable("Come back next time.") - exc2 = exceptions.InternalServerError("Server on fritz.") - firestore_api.commit.side_effect = [exc1, exc2] - - # Attach the fake GAPIC to a real client. - client = _make_client("peanut-butter") - client._firestore_api_internal = firestore_api - - # Call function and check result. - txn_id = b"the-journey-when-and-where-well-go" - with self.assertRaises(exceptions.InternalServerError) as exc_info: - await self._call_fut(client, mock.sentinel.write_pbs, txn_id) - - self.assertIs(exc_info.exception, exc2) - - # Verify mocks used. - _sleep.assert_called_once_with(1.0) - # commit() called same way 2 times. 
- commit_call = mock.call( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - self.assertEqual(firestore_api.commit.mock_calls, [commit_call, commit_call]) - - -class Test__sleep(aiounittest.AsyncTestCase): - @staticmethod - @pytest.mark.asyncio - async def _call_fut(current_sleep, **kwargs): - from google.cloud.firestore_v1.async_transaction import _sleep - - return await _sleep(current_sleep, **kwargs) - - @mock.patch("random.uniform", return_value=5.5) - @mock.patch("asyncio.sleep", return_value=None) - @pytest.mark.asyncio - async def test_defaults(self, sleep, uniform): - curr_sleep = 10.0 - self.assertLessEqual(uniform.return_value, curr_sleep) - - new_sleep = await self._call_fut(curr_sleep) - self.assertEqual(new_sleep, 2.0 * curr_sleep) - - uniform.assert_called_once_with(0.0, curr_sleep) - sleep.assert_called_once_with(uniform.return_value) - - @mock.patch("random.uniform", return_value=10.5) - @mock.patch("asyncio.sleep", return_value=None) - @pytest.mark.asyncio - async def test_explicit(self, sleep, uniform): - curr_sleep = 12.25 - self.assertLessEqual(uniform.return_value, curr_sleep) - - multiplier = 1.5 - new_sleep = await self._call_fut( - curr_sleep, max_sleep=100.0, multiplier=multiplier - ) - self.assertEqual(new_sleep, multiplier * curr_sleep) - - uniform.assert_called_once_with(0.0, curr_sleep) - sleep.assert_called_once_with(uniform.return_value) - - @mock.patch("random.uniform", return_value=6.75) - @mock.patch("asyncio.sleep", return_value=None) - @pytest.mark.asyncio - async def test_exceeds_max(self, sleep, uniform): - curr_sleep = 20.0 - self.assertLessEqual(uniform.return_value, curr_sleep) - - max_sleep = 38.5 - new_sleep = await self._call_fut( - curr_sleep, max_sleep=max_sleep, multiplier=2.0 - ) - self.assertEqual(new_sleep, max_sleep) - - uniform.assert_called_once_with(0.0, curr_sleep) - 
sleep.assert_called_once_with(uniform.return_value) + assert firestore_api.commit.mock_calls == [commit_call, commit_call, commit_call] + + +@mock.patch("google.cloud.firestore_v1.async_transaction._sleep") +@pytest.mark.asyncio +async def test__commit_with_retry_failure_first_attempt(_sleep): + from google.api_core import exceptions + from google.cloud.firestore_v1.async_transaction import _commit_with_retry + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = AsyncMock() + + # Make sure the first request fails with an un-retryable error. + exc = exceptions.ResourceExhausted("We ran out of fries.") + firestore_api.commit.side_effect = exc + + # Attach the fake GAPIC to a real client. + client = _make_client("peanut-butter") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"\x08\x06\x07\x05\x03\x00\x09-jenny" + with pytest.raises(exceptions.ResourceExhausted) as exc_info: + await _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) + + assert exc_info.value is exc + + # Verify mocks used. + _sleep.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + + +@mock.patch("google.cloud.firestore_v1.async_transaction._sleep", return_value=2.0) +@pytest.mark.asyncio +async def test__commit_with_retry_failure_second_attempt(_sleep): + from google.api_core import exceptions + from google.cloud.firestore_v1.async_transaction import _commit_with_retry + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = AsyncMock() + + # Make sure the first request fails retry-able and second + # fails non-retryable. + exc1 = exceptions.ServiceUnavailable("Come back next time.") + exc2 = exceptions.InternalServerError("Server on fritz.") + firestore_api.commit.side_effect = [exc1, exc2] + + # Attach the fake GAPIC to a real client. 
+ client = _make_client("peanut-butter") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"the-journey-when-and-where-well-go" + with pytest.raises(exceptions.InternalServerError) as exc_info: + await _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) + + assert exc_info.value is exc2 + + # Verify mocks used. + _sleep.assert_called_once_with(1.0) + # commit() called same way 2 times. + commit_call = mock.call( + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + assert firestore_api.commit.mock_calls == [commit_call, commit_call] + + +@mock.patch("random.uniform", return_value=5.5) +@mock.patch("asyncio.sleep", return_value=None) +@pytest.mark.asyncio +async def test_sleep_defaults(sleep, uniform): + from google.cloud.firestore_v1.async_transaction import _sleep + + curr_sleep = 10.0 + assert uniform.return_value <= curr_sleep + + new_sleep = await _sleep(curr_sleep) + assert new_sleep == 2.0 * curr_sleep + + uniform.assert_called_once_with(0.0, curr_sleep) + sleep.assert_called_once_with(uniform.return_value) + + +@mock.patch("random.uniform", return_value=10.5) +@mock.patch("asyncio.sleep", return_value=None) +@pytest.mark.asyncio +async def test_sleep_explicit(sleep, uniform): + from google.cloud.firestore_v1.async_transaction import _sleep + + curr_sleep = 12.25 + assert uniform.return_value <= curr_sleep + + multiplier = 1.5 + new_sleep = await _sleep(curr_sleep, max_sleep=100.0, multiplier=multiplier) + assert new_sleep == multiplier * curr_sleep + + uniform.assert_called_once_with(0.0, curr_sleep) + sleep.assert_called_once_with(uniform.return_value) + + +@mock.patch("random.uniform", return_value=6.75) +@mock.patch("asyncio.sleep", return_value=None) +@pytest.mark.asyncio +async def test_sleep_exceeds_max(sleep, uniform): + from google.cloud.firestore_v1.async_transaction import _sleep + + 
curr_sleep = 20.0 + assert uniform.return_value <= curr_sleep + + max_sleep = 38.5 + new_sleep = await _sleep(curr_sleep, max_sleep=max_sleep, multiplier=2.0) + assert new_sleep == max_sleep + + uniform.assert_called_once_with(0.0, curr_sleep) + sleep.assert_called_once_with(uniform.return_value) def _make_credentials(): diff --git a/tests/unit/v1/test_base_batch.py b/tests/unit/v1/test_base_batch.py index 2706e9e867333..d47912055bf57 100644 --- a/tests/unit/v1/test_base_batch.py +++ b/tests/unit/v1/test_base_batch.py @@ -12,155 +12,153 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest -from google.cloud.firestore_v1.base_batch import BaseWriteBatch - import mock -class DerivedBaseWriteBatch(BaseWriteBatch): - def __init__(self, client): - super().__init__(client=client) - - """Create a fake subclass of `BaseWriteBatch` for the purposes of - evaluating the shared methods.""" - - def commit(self): - pass # pragma: NO COVER - - -class TestBaseWriteBatch(unittest.TestCase): - @staticmethod - def _get_target_class(): - return DerivedBaseWriteBatch - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def test_constructor(self): - batch = self._make_one(mock.sentinel.client) - self.assertIs(batch._client, mock.sentinel.client) - self.assertEqual(batch._write_pbs, []) - self.assertIsNone(batch.write_results) - self.assertIsNone(batch.commit_time) - - def test__add_write_pbs(self): - batch = self._make_one(mock.sentinel.client) - self.assertEqual(batch._write_pbs, []) - batch._add_write_pbs([mock.sentinel.write1, mock.sentinel.write2]) - self.assertEqual(batch._write_pbs, [mock.sentinel.write1, mock.sentinel.write2]) - - def test_create(self): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write - - client = _make_client() - batch = 
self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("this", "one") - document_data = {"a": 10, "b": 2.5} - ret_val = batch.create(reference, document_data) - self.assertIsNone(ret_val) - new_write_pb = write.Write( - update=document.Document( - name=reference._document_path, - fields={ - "a": _value_pb(integer_value=document_data["a"]), - "b": _value_pb(double_value=document_data["b"]), - }, - ), - current_document=common.Precondition(exists=False), - ) - self.assertEqual(batch._write_pbs, [new_write_pb]) - - def test_set(self): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("another", "one") - field = "zapzap" - value = u"meadows and flowers" - document_data = {field: value} - ret_val = batch.set(reference, document_data) - self.assertIsNone(ret_val) - new_write_pb = write.Write( - update=document.Document( - name=reference._document_path, - fields={field: _value_pb(string_value=value)}, - ) - ) - self.assertEqual(batch._write_pbs, [new_write_pb]) - - def test_set_merge(self): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("another", "one") - field = "zapzap" - value = u"meadows and flowers" - document_data = {field: value} - ret_val = batch.set(reference, document_data, merge=True) - self.assertIsNone(ret_val) - new_write_pb = write.Write( - update=document.Document( - name=reference._document_path, - fields={field: _value_pb(string_value=value)}, - ), - update_mask={"field_paths": [field]}, - ) - self.assertEqual(batch._write_pbs, [new_write_pb]) - - def test_update(self): - from google.cloud.firestore_v1.types import common - from 
google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("cats", "cradle") - field_path = "head.foot" - value = u"knees toes shoulders" - field_updates = {field_path: value} - - ret_val = batch.update(reference, field_updates) - self.assertIsNone(ret_val) - - map_pb = document.MapValue(fields={"foot": _value_pb(string_value=value)}) - new_write_pb = write.Write( - update=document.Document( - name=reference._document_path, - fields={"head": _value_pb(map_value=map_pb)}, - ), - update_mask=common.DocumentMask(field_paths=[field_path]), - current_document=common.Precondition(exists=True), +def _make_derived_write_batch(*args, **kwargs): + from google.cloud.firestore_v1.base_batch import BaseWriteBatch + + class DerivedBaseWriteBatch(BaseWriteBatch): + def __init__(self, client): + super().__init__(client=client) + + """Create a fake subclass of `BaseWriteBatch` for the purposes of + evaluating the shared methods.""" + + def commit(self): + pass # pragma: NO COVER + + return DerivedBaseWriteBatch(*args, **kwargs) + + +def test_basewritebatch_constructor(): + batch = _make_derived_write_batch(mock.sentinel.client) + assert batch._client is mock.sentinel.client + assert batch._write_pbs == [] + assert batch.write_results is None + assert batch.commit_time is None + + +def test_basewritebatch__add_write_pbs(): + batch = _make_derived_write_batch(mock.sentinel.client) + assert batch._write_pbs == [] + batch._add_write_pbs([mock.sentinel.write1, mock.sentinel.write2]) + assert batch._write_pbs == [mock.sentinel.write1, mock.sentinel.write2] + + +def test_basewritebatch_create(): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + + client = _make_client() + batch = 
_make_derived_write_batch(client) + assert batch._write_pbs == [] + + reference = client.document("this", "one") + document_data = {"a": 10, "b": 2.5} + ret_val = batch.create(reference, document_data) + assert ret_val is None + new_write_pb = write.Write( + update=document.Document( + name=reference._document_path, + fields={ + "a": _value_pb(integer_value=document_data["a"]), + "b": _value_pb(double_value=document_data["b"]), + }, + ), + current_document=common.Precondition(exists=False), + ) + assert batch._write_pbs == [new_write_pb] + + +def test_basewritebatch_set(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + + client = _make_client() + batch = _make_derived_write_batch(client) + assert batch._write_pbs == [] + + reference = client.document("another", "one") + field = "zapzap" + value = u"meadows and flowers" + document_data = {field: value} + ret_val = batch.set(reference, document_data) + assert ret_val is None + new_write_pb = write.Write( + update=document.Document( + name=reference._document_path, + fields={field: _value_pb(string_value=value)}, ) - self.assertEqual(batch._write_pbs, [new_write_pb]) - - def test_delete(self): - from google.cloud.firestore_v1.types import write - - client = _make_client() - batch = self._make_one(client) - self.assertEqual(batch._write_pbs, []) - - reference = client.document("early", "mornin", "dawn", "now") - ret_val = batch.delete(reference) - self.assertIsNone(ret_val) - new_write_pb = write.Write(delete=reference._document_path) - self.assertEqual(batch._write_pbs, [new_write_pb]) + ) + assert batch._write_pbs == [new_write_pb] + + +def test_basewritebatch_set_merge(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + + client = _make_client() + batch = _make_derived_write_batch(client) + assert batch._write_pbs == [] + + reference = client.document("another", "one") + field = "zapzap" + value 
= u"meadows and flowers" + document_data = {field: value} + ret_val = batch.set(reference, document_data, merge=True) + assert ret_val is None + new_write_pb = write.Write( + update=document.Document( + name=reference._document_path, + fields={field: _value_pb(string_value=value)}, + ), + update_mask={"field_paths": [field]}, + ) + assert batch._write_pbs == [new_write_pb] + + +def test_basewritebatch_update(): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + + client = _make_client() + batch = _make_derived_write_batch(client) + assert batch._write_pbs == [] + + reference = client.document("cats", "cradle") + field_path = "head.foot" + value = u"knees toes shoulders" + field_updates = {field_path: value} + + ret_val = batch.update(reference, field_updates) + assert ret_val is None + + map_pb = document.MapValue(fields={"foot": _value_pb(string_value=value)}) + new_write_pb = write.Write( + update=document.Document( + name=reference._document_path, fields={"head": _value_pb(map_value=map_pb)}, + ), + update_mask=common.DocumentMask(field_paths=[field_path]), + current_document=common.Precondition(exists=True), + ) + assert batch._write_pbs == [new_write_pb] + + +def test_basewritebatch_delete(): + from google.cloud.firestore_v1.types import write + + client = _make_client() + batch = _make_derived_write_batch(client) + assert batch._write_pbs == [] + + reference = client.document("early", "mornin", "dawn", "now") + ret_val = batch.delete(reference) + assert ret_val is None + new_write_pb = write.Write(delete=reference._document_path) + assert batch._write_pbs == [new_write_pb] def _value_pb(**kwargs): diff --git a/tests/unit/v1/test_base_client.py b/tests/unit/v1/test_base_client.py index 2af30a1a35856..42f9b25ca4375 100644 --- a/tests/unit/v1/test_base_client.py +++ b/tests/unit/v1/test_base_client.py @@ -13,424 +13,466 @@ # limitations under the 
License. import datetime -import unittest import grpc import mock +import pytest +PROJECT = "my-prahjekt" -class TestBaseClient(unittest.TestCase): - PROJECT = "my-prahjekt" +def _make_base_client(*args, **kwargs): + from google.cloud.firestore_v1.base_client import BaseClient - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.client import Client + return BaseClient(*args, **kwargs) - return Client - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) +def _make_default_base_client(): + credentials = _make_credentials() + return _make_base_client(project=PROJECT, credentials=credentials) - def _make_default_one(self): - credentials = _make_credentials() - return self._make_one(project=self.PROJECT, credentials=credentials) - def test_constructor_with_emulator_host_defaults(self): - from google.auth.credentials import AnonymousCredentials - from google.cloud.firestore_v1.base_client import _DEFAULT_EMULATOR_PROJECT - from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST +def test_baseclient_constructor_with_emulator_host_defaults(): + from google.auth.credentials import AnonymousCredentials + from google.cloud.firestore_v1.base_client import _DEFAULT_EMULATOR_PROJECT + from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST - emulator_host = "localhost:8081" + emulator_host = "localhost:8081" - with mock.patch("os.environ", {_FIRESTORE_EMULATOR_HOST: emulator_host}): - client = self._make_one() + with mock.patch("os.environ", {_FIRESTORE_EMULATOR_HOST: emulator_host}): + client = _make_base_client() - self.assertEqual(client._emulator_host, emulator_host) - self.assertIsInstance(client._credentials, AnonymousCredentials) - self.assertEqual(client.project, _DEFAULT_EMULATOR_PROJECT) + assert client._emulator_host == emulator_host + assert isinstance(client._credentials, AnonymousCredentials) + assert client.project == _DEFAULT_EMULATOR_PROJECT - def 
test_constructor_with_emulator_host_w_project(self): - from google.auth.credentials import AnonymousCredentials - from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST - emulator_host = "localhost:8081" +def test_baseclient_constructor_with_emulator_host_w_project(): + from google.auth.credentials import AnonymousCredentials + from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST - with mock.patch("os.environ", {_FIRESTORE_EMULATOR_HOST: emulator_host}): - client = self._make_one(project=self.PROJECT) + emulator_host = "localhost:8081" - self.assertEqual(client._emulator_host, emulator_host) - self.assertIsInstance(client._credentials, AnonymousCredentials) + with mock.patch("os.environ", {_FIRESTORE_EMULATOR_HOST: emulator_host}): + client = _make_base_client(project=PROJECT) - def test_constructor_with_emulator_host_w_creds(self): - from google.cloud.firestore_v1.base_client import _DEFAULT_EMULATOR_PROJECT - from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST + assert client._emulator_host == emulator_host + assert isinstance(client._credentials, AnonymousCredentials) - credentials = _make_credentials() - emulator_host = "localhost:8081" - with mock.patch("os.environ", {_FIRESTORE_EMULATOR_HOST: emulator_host}): - client = self._make_one(credentials=credentials) +def test_baseclient_constructor_with_emulator_host_w_creds(): + from google.cloud.firestore_v1.base_client import _DEFAULT_EMULATOR_PROJECT + from google.cloud.firestore_v1.base_client import _FIRESTORE_EMULATOR_HOST + + credentials = _make_credentials() + emulator_host = "localhost:8081" + + with mock.patch("os.environ", {_FIRESTORE_EMULATOR_HOST: emulator_host}): + client = _make_base_client(credentials=credentials) + + assert client._emulator_host == emulator_host + assert client._credentials is credentials + assert client.project == _DEFAULT_EMULATOR_PROJECT + + +def test_baseclient__firestore_api_helper_w_already(): + client = 
_make_default_base_client() + internal = client._firestore_api_internal = mock.Mock() + + transport_class = mock.Mock() + client_class = mock.Mock() + client_module = mock.Mock() + + api = client._firestore_api_helper(transport_class, client_class, client_module) + + assert api is internal + transport_class.assert_not_called() + client_class.assert_not_called() - self.assertEqual(client._emulator_host, emulator_host) - self.assertIs(client._credentials, credentials) - self.assertEqual(client.project, _DEFAULT_EMULATOR_PROJECT) - @mock.patch( - "google.cloud.firestore_v1.services.firestore.client.FirestoreClient", - autospec=True, - return_value=mock.sentinel.firestore_api, +def test_baseclient__firestore_api_helper_wo_emulator(): + client = _make_default_base_client() + client_options = client._client_options = mock.Mock() + target = client._target = mock.Mock() + assert client._firestore_api_internal is None + + transport_class = mock.Mock() + client_class = mock.Mock() + client_module = mock.Mock() + + api = client._firestore_api_helper(transport_class, client_class, client_module) + + assert api is client_class.return_value + assert client._firestore_api_internal is api + channel_options = {"grpc.keepalive_time_ms": 30000} + transport_class.create_channel.assert_called_once_with( + target, credentials=client._credentials, options=channel_options.items() ) - @mock.patch( - "google.cloud.firestore_v1.services.firestore.transports.grpc.FirestoreGrpcTransport", - autospec=True, + transport_class.assert_called_once_with( + host=target, channel=transport_class.create_channel.return_value, ) - def test__firestore_api_property(self, mock_channel, mock_client): - mock_client.DEFAULT_ENDPOINT = "endpoint" - client = self._make_default_one() - client_options = client._client_options = mock.Mock() - self.assertIsNone(client._firestore_api_internal) - firestore_api = client._firestore_api - self.assertIs(firestore_api, mock_client.return_value) - self.assertIs(firestore_api, 
client._firestore_api_internal) - mock_client.assert_called_once_with( - transport=client._transport, client_options=client_options - ) + client_class.assert_called_once_with( + transport=transport_class.return_value, client_options=client_options + ) + + +def test_baseclient__firestore_api_helper_w_emulator(): + emulator_host = "localhost:8081" + with mock.patch("os.getenv") as getenv: + getenv.return_value = emulator_host + client = _make_default_base_client() + + client_options = client._client_options = mock.Mock() + target = client._target = mock.Mock() + emulator_channel = client._emulator_channel = mock.Mock() + assert client._firestore_api_internal is None + + transport_class = mock.Mock(__name__="TestTransport") + client_class = mock.Mock() + client_module = mock.Mock() - # Call again to show that it is cached, but call count is still 1. - self.assertIs(client._firestore_api, mock_client.return_value) - self.assertEqual(mock_client.call_count, 1) + api = client._firestore_api_helper(transport_class, client_class, client_module) - @mock.patch( - "google.cloud.firestore_v1.services.firestore.client.FirestoreClient", - autospec=True, - return_value=mock.sentinel.firestore_api, + assert api is client_class.return_value + assert api is client._firestore_api_internal + + emulator_channel.assert_called_once_with(transport_class) + transport_class.assert_called_once_with( + host=target, channel=emulator_channel.return_value, ) - @mock.patch( - "google.cloud.firestore_v1.base_client.BaseClient._emulator_channel", - autospec=True, + client_class.assert_called_once_with( + transport=transport_class.return_value, client_options=client_options ) - def test__firestore_api_property_with_emulator( - self, mock_emulator_channel, mock_client - ): - emulator_host = "localhost:8081" - with mock.patch("os.getenv") as getenv: - getenv.return_value = emulator_host - client = self._make_default_one() - - self.assertIsNone(client._firestore_api_internal) - firestore_api = 
client._firestore_api - self.assertIs(firestore_api, mock_client.return_value) - self.assertIs(firestore_api, client._firestore_api_internal) - - mock_emulator_channel.assert_called_once() - - # Call again to show that it is cached, but call count is still 1. - self.assertIs(client._firestore_api, mock_client.return_value) - self.assertEqual(mock_client.call_count, 1) - - def test___database_string_property(self): - credentials = _make_credentials() - database = "cheeeeez" - client = self._make_one( - project=self.PROJECT, credentials=credentials, database=database - ) - self.assertIsNone(client._database_string_internal) - database_string = client._database_string - expected = "projects/{}/databases/{}".format(client.project, client._database) - self.assertEqual(database_string, expected) - self.assertIs(database_string, client._database_string_internal) - # Swap it out with a unique value to verify it is cached. - client._database_string_internal = mock.sentinel.cached - self.assertIs(client._database_string, mock.sentinel.cached) - def test___rpc_metadata_property(self): +def test_baseclient___database_string_property(): + credentials = _make_credentials() + database = "cheeeeez" + client = _make_base_client( + project=PROJECT, credentials=credentials, database=database + ) + assert client._database_string_internal is None + database_string = client._database_string + expected = "projects/{}/databases/{}".format(client.project, client._database) + assert database_string == expected + assert database_string is client._database_string_internal + + # Swap it out with a unique value to verify it is cached. 
+ client._database_string_internal = mock.sentinel.cached + assert client._database_string is mock.sentinel.cached + + +def test_baseclient___rpc_metadata_property(): + credentials = _make_credentials() + database = "quanta" + client = _make_base_client( + project=PROJECT, credentials=credentials, database=database + ) + + assert client._rpc_metadata == [ + ("google-cloud-resource-prefix", client._database_string), + ] + + +def test_baseclient__rpc_metadata_property_with_emulator(): + emulator_host = "localhost:8081" + with mock.patch("os.getenv") as getenv: + getenv.return_value = emulator_host + credentials = _make_credentials() database = "quanta" - client = self._make_one( - project=self.PROJECT, credentials=credentials, database=database + client = _make_base_client( + project=PROJECT, credentials=credentials, database=database ) - self.assertEqual( - client._rpc_metadata, - [("google-cloud-resource-prefix", client._database_string)], - ) + assert client._rpc_metadata == [ + ("google-cloud-resource-prefix", client._database_string), + ("authorization", "Bearer owner"), + ] - def test__rpc_metadata_property_with_emulator(self): - emulator_host = "localhost:8081" - with mock.patch("os.getenv") as getenv: - getenv.return_value = emulator_host - - credentials = _make_credentials() - database = "quanta" - client = self._make_one( - project=self.PROJECT, credentials=credentials, database=database - ) - - self.assertEqual( - client._rpc_metadata, - [ - ("google-cloud-resource-prefix", client._database_string), - ("authorization", "Bearer owner"), - ], + +def test_baseclient__emulator_channel(): + from google.cloud.firestore_v1.services.firestore.transports.grpc import ( + FirestoreGrpcTransport, + ) + from google.cloud.firestore_v1.services.firestore.transports.grpc_asyncio import ( + FirestoreGrpcAsyncIOTransport, + ) + + emulator_host = "localhost:8081" + credentials = _make_credentials() + database = "quanta" + with mock.patch("os.getenv") as getenv: + 
getenv.return_value = emulator_host + credentials.id_token = None + client = _make_base_client( + project=PROJECT, credentials=credentials, database=database ) - def test_emulator_channel(self): - from google.cloud.firestore_v1.services.firestore.transports.grpc import ( - FirestoreGrpcTransport, + # checks that a channel is created + channel = client._emulator_channel(FirestoreGrpcTransport) + assert isinstance(channel, grpc.Channel) + channel = client._emulator_channel(FirestoreGrpcAsyncIOTransport) + assert isinstance(channel, grpc.aio.Channel) + + # Verify that when credentials are provided with an id token it is used + # for channel construction + # NOTE: On windows, emulation requires an insecure channel. If this is + # altered to use a secure channel, start by verifying that it still + # works as expected on windows. + with mock.patch("os.getenv") as getenv: + getenv.return_value = emulator_host + credentials.id_token = "test" + client = _make_base_client( + project=PROJECT, credentials=credentials, database=database ) - from google.cloud.firestore_v1.services.firestore.transports.grpc_asyncio import ( - FirestoreGrpcAsyncIOTransport, + with mock.patch("grpc.insecure_channel") as insecure_channel: + channel = client._emulator_channel(FirestoreGrpcTransport) + insecure_channel.assert_called_once_with( + emulator_host, options=[("Authorization", "Bearer test")] ) - emulator_host = "localhost:8081" - credentials = _make_credentials() - database = "quanta" - with mock.patch("os.getenv") as getenv: - getenv.return_value = emulator_host - credentials.id_token = None - client = self._make_one( - project=self.PROJECT, credentials=credentials, database=database - ) - - # checks that a channel is created - channel = client._emulator_channel(FirestoreGrpcTransport) - self.assertTrue(isinstance(channel, grpc.Channel)) - channel = client._emulator_channel(FirestoreGrpcAsyncIOTransport) - self.assertTrue(isinstance(channel, grpc.aio.Channel)) - # Verify that when 
credentials are provided with an id token it is used - # for channel construction - # NOTE: On windows, emulation requires an insecure channel. If this is - # altered to use a secure channel, start by verifying that it still - # works as expected on windows. - with mock.patch("os.getenv") as getenv: - getenv.return_value = emulator_host - credentials.id_token = "test" - client = self._make_one( - project=self.PROJECT, credentials=credentials, database=database - ) - with mock.patch("grpc.insecure_channel") as insecure_channel: - channel = client._emulator_channel(FirestoreGrpcTransport) - insecure_channel.assert_called_once_with( - emulator_host, options=[("Authorization", "Bearer test")] - ) +def test_baseclient__target_helper_w_emulator_host(): + emulator_host = "localhost:8081" + credentials = _make_credentials() + database = "quanta" + with mock.patch("os.getenv") as getenv: + getenv.return_value = emulator_host + credentials.id_token = None + client = _make_base_client( + project=PROJECT, credentials=credentials, database=database + ) + + assert client._target_helper(None) == emulator_host - def test_field_path(self): - klass = self._get_target_class() - self.assertEqual(klass.field_path("a", "b", "c"), "a.b.c") - def test_write_option_last_update(self): - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1._helpers import LastUpdateOption +def test_baseclient__target_helper_w_client_options_w_endpoint(): + credentials = _make_credentials() + endpoint = "https://api.example.com/firestore" + client_options = {"api_endpoint": endpoint} + client = _make_base_client( + project=PROJECT, credentials=credentials, client_options=client_options, + ) - timestamp = timestamp_pb2.Timestamp(seconds=1299767599, nanos=811111097) + assert client._target_helper(None) == endpoint - klass = self._get_target_class() - option = klass.write_option(last_update_time=timestamp) - self.assertIsInstance(option, LastUpdateOption) - 
self.assertEqual(option._last_update_time, timestamp) - def test_write_option_exists(self): - from google.cloud.firestore_v1._helpers import ExistsOption +def test_baseclient__target_helper_w_client_options_wo_endpoint(): + credentials = _make_credentials() + endpoint = "https://api.example.com/firestore" + client_options = {} + client_class = mock.Mock(instance=False, DEFAULT_ENDPOINT=endpoint) + client = _make_base_client( + project=PROJECT, credentials=credentials, client_options=client_options, + ) - klass = self._get_target_class() + assert client._target_helper(client_class) == endpoint - option1 = klass.write_option(exists=False) - self.assertIsInstance(option1, ExistsOption) - self.assertFalse(option1._exists) - option2 = klass.write_option(exists=True) - self.assertIsInstance(option2, ExistsOption) - self.assertTrue(option2._exists) +def test_baseclient__target_helper_wo_client_options(): + credentials = _make_credentials() + endpoint = "https://api.example.com/firestore" + client_class = mock.Mock(instance=False, DEFAULT_ENDPOINT=endpoint) + client = _make_base_client(project=PROJECT, credentials=credentials,) - def test_write_open_neither_arg(self): - from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR + assert client._target_helper(client_class) == endpoint - klass = self._get_target_class() - with self.assertRaises(TypeError) as exc_info: - klass.write_option() - self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR,)) +def test_baseclient_field_path(): + from google.cloud.firestore_v1.base_client import BaseClient - def test_write_multiple_args(self): - from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR + assert BaseClient.field_path("a", "b", "c") == "a.b.c" - klass = self._get_target_class() - with self.assertRaises(TypeError) as exc_info: - klass.write_option(exists=False, last_update_time=mock.sentinel.timestamp) - self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR,)) +def 
test_baseclient_write_option_last_update(): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1._helpers import LastUpdateOption + from google.cloud.firestore_v1.base_client import BaseClient - def test_write_bad_arg(self): - from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR + timestamp = timestamp_pb2.Timestamp(seconds=1299767599, nanos=811111097) - klass = self._get_target_class() - with self.assertRaises(TypeError) as exc_info: - klass.write_option(spinach="popeye") + option = BaseClient.write_option(last_update_time=timestamp) + assert isinstance(option, LastUpdateOption) + assert option._last_update_time == timestamp - extra = "{!r} was provided".format("spinach") - self.assertEqual(exc_info.exception.args, (_BAD_OPTION_ERR, extra)) +def test_baseclient_write_option_exists(): + from google.cloud.firestore_v1._helpers import ExistsOption + from google.cloud.firestore_v1.base_client import BaseClient -class Test__reference_info(unittest.TestCase): - @staticmethod - def _call_fut(references): - from google.cloud.firestore_v1.base_client import _reference_info + option1 = BaseClient.write_option(exists=False) + assert isinstance(option1, ExistsOption) + assert not option1._exists - return _reference_info(references) + option2 = BaseClient.write_option(exists=True) + assert isinstance(option2, ExistsOption) + assert option2._exists - def test_it(self): - from google.cloud.firestore_v1.client import Client - credentials = _make_credentials() - client = Client(project="hi-projject", credentials=credentials) +def test_baseclient_write_open_neither_arg(): + from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR + from google.cloud.firestore_v1.base_client import BaseClient - reference1 = client.document("a", "b") - reference2 = client.document("a", "b", "c", "d") - reference3 = client.document("a", "b") - reference4 = client.document("f", "g") + with pytest.raises(TypeError) as exc_info: + BaseClient.write_option() - 
doc_path1 = reference1._document_path - doc_path2 = reference2._document_path - doc_path3 = reference3._document_path - doc_path4 = reference4._document_path - self.assertEqual(doc_path1, doc_path3) + assert exc_info.value.args == (_BAD_OPTION_ERR,) - document_paths, reference_map = self._call_fut( - [reference1, reference2, reference3, reference4] - ) - self.assertEqual(document_paths, [doc_path1, doc_path2, doc_path3, doc_path4]) - # reference3 over-rides reference1. - expected_map = { - doc_path2: reference2, - doc_path3: reference3, - doc_path4: reference4, - } - self.assertEqual(reference_map, expected_map) +def test_baseclient_write_multiple_args(): + from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR + from google.cloud.firestore_v1.base_client import BaseClient -class Test__get_reference(unittest.TestCase): - @staticmethod - def _call_fut(document_path, reference_map): - from google.cloud.firestore_v1.base_client import _get_reference + with pytest.raises(TypeError) as exc_info: + BaseClient.write_option(exists=False, last_update_time=mock.sentinel.timestamp) - return _get_reference(document_path, reference_map) + assert exc_info.value.args == (_BAD_OPTION_ERR,) - def test_success(self): - doc_path = "a/b/c" - reference_map = {doc_path: mock.sentinel.reference} - self.assertIs(self._call_fut(doc_path, reference_map), mock.sentinel.reference) - def test_failure(self): - from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE +def test_baseclient_write_bad_arg(): + from google.cloud.firestore_v1.base_client import _BAD_OPTION_ERR + from google.cloud.firestore_v1.base_client import BaseClient - doc_path = "1/888/call-now" - with self.assertRaises(ValueError) as exc_info: - self._call_fut(doc_path, {}) + with pytest.raises(TypeError) as exc_info: + BaseClient.write_option(spinach="popeye") - err_msg = _BAD_DOC_TEMPLATE.format(doc_path) - self.assertEqual(exc_info.exception.args, (err_msg,)) + extra = "{!r} was 
provided".format("spinach") + assert exc_info.value.args == (_BAD_OPTION_ERR, extra) -class Test__parse_batch_get(unittest.TestCase): - @staticmethod - def _call_fut(get_doc_response, reference_map, client=mock.sentinel.client): - from google.cloud.firestore_v1.base_client import _parse_batch_get +def test__reference_info(): + from google.cloud.firestore_v1.base_client import _reference_info - return _parse_batch_get(get_doc_response, reference_map, client) + expected_doc_paths = ["/a/b", "/a/b/c/d", "/a/b", "/f/g"] + documents = [mock.Mock(_document_path=path) for path in expected_doc_paths] - @staticmethod - def _dummy_ref_string(): - from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE + document_paths, reference_map = _reference_info(documents) - project = u"bazzzz" - collection_id = u"fizz" - document_id = u"buzz" - return u"projects/{}/databases/{}/documents/{}/{}".format( - project, DEFAULT_DATABASE, collection_id, document_id - ) + assert document_paths == expected_doc_paths + # reference3 over-rides reference1. 
+ expected_map = { + path: document + for path, document in list(zip(expected_doc_paths, documents))[1:] + } + assert reference_map == expected_map - def test_found(self): - from google.cloud.firestore_v1.types import document - from google.cloud._helpers import _datetime_to_pb_timestamp - from google.cloud.firestore_v1.document import DocumentSnapshot - - now = datetime.datetime.utcnow() - read_time = _datetime_to_pb_timestamp(now) - delta = datetime.timedelta(seconds=100) - update_time = _datetime_to_pb_timestamp(now - delta) - create_time = _datetime_to_pb_timestamp(now - 2 * delta) - - ref_string = self._dummy_ref_string() - document_pb = document.Document( - name=ref_string, - fields={ - "foo": document.Value(double_value=1.5), - "bar": document.Value(string_value=u"skillz"), - }, - create_time=create_time, - update_time=update_time, - ) - response_pb = _make_batch_response(found=document_pb, read_time=read_time) - - reference_map = {ref_string: mock.sentinel.reference} - snapshot = self._call_fut(response_pb, reference_map) - self.assertIsInstance(snapshot, DocumentSnapshot) - self.assertIs(snapshot._reference, mock.sentinel.reference) - self.assertEqual(snapshot._data, {"foo": 1.5, "bar": u"skillz"}) - self.assertTrue(snapshot._exists) - self.assertEqual(snapshot.read_time.timestamp_pb(), read_time) - self.assertEqual(snapshot.create_time.timestamp_pb(), create_time) - self.assertEqual(snapshot.update_time.timestamp_pb(), update_time) - - def test_missing(self): - from google.cloud.firestore_v1.document import DocumentReference - - ref_string = self._dummy_ref_string() - response_pb = _make_batch_response(missing=ref_string) - document = DocumentReference("fizz", "bazz", client=mock.sentinel.client) - reference_map = {ref_string: document} - snapshot = self._call_fut(response_pb, reference_map) - self.assertFalse(snapshot.exists) - self.assertEqual(snapshot.id, "bazz") - self.assertIsNone(snapshot._data) - - def test_unset_result_type(self): - response_pb = 
_make_batch_response() - with self.assertRaises(ValueError): - self._call_fut(response_pb, {}) - - def test_unknown_result_type(self): - response_pb = mock.Mock() - response_pb._pb.mock_add_spec(spec=["WhichOneof"]) - response_pb._pb.WhichOneof.return_value = "zoob_value" - - with self.assertRaises(ValueError): - self._call_fut(response_pb, {}) - - response_pb._pb.WhichOneof.assert_called_once_with("result") - - -class Test__get_doc_mask(unittest.TestCase): - @staticmethod - def _call_fut(field_paths): - from google.cloud.firestore_v1.base_client import _get_doc_mask - - return _get_doc_mask(field_paths) - - def test_none(self): - self.assertIsNone(self._call_fut(None)) - - def test_paths(self): - from google.cloud.firestore_v1.types import common - - field_paths = ["a.b", "c"] - result = self._call_fut(field_paths) - expected = common.DocumentMask(field_paths=field_paths) - self.assertEqual(result, expected) + +def test__get_reference_success(): + from google.cloud.firestore_v1.base_client import _get_reference + + doc_path = "a/b/c" + reference_map = {doc_path: mock.sentinel.reference} + assert _get_reference(doc_path, reference_map) is mock.sentinel.reference + + +def test__get_reference_failure(): + from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE + from google.cloud.firestore_v1.base_client import _get_reference + + doc_path = "1/888/call-now" + with pytest.raises(ValueError) as exc_info: + _get_reference(doc_path, {}) + + err_msg = _BAD_DOC_TEMPLATE.format(doc_path) + assert exc_info.value.args == (err_msg,) + + +def _dummy_ref_string(): + from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE + + project = u"bazzzz" + collection_id = u"fizz" + document_id = u"buzz" + return u"projects/{}/databases/{}/documents/{}/{}".format( + project, DEFAULT_DATABASE, collection_id, document_id + ) + + +def test__parse_batch_get_found(): + from google.cloud.firestore_v1.types import document + from google.cloud._helpers import 
_datetime_to_pb_timestamp + from google.cloud.firestore_v1.document import DocumentSnapshot + from google.cloud.firestore_v1.base_client import _parse_batch_get + + now = datetime.datetime.utcnow() + read_time = _datetime_to_pb_timestamp(now) + delta = datetime.timedelta(seconds=100) + update_time = _datetime_to_pb_timestamp(now - delta) + create_time = _datetime_to_pb_timestamp(now - 2 * delta) + + ref_string = _dummy_ref_string() + document_pb = document.Document( + name=ref_string, + fields={ + "foo": document.Value(double_value=1.5), + "bar": document.Value(string_value=u"skillz"), + }, + create_time=create_time, + update_time=update_time, + ) + response_pb = _make_batch_response(found=document_pb, read_time=read_time) + + reference_map = {ref_string: mock.sentinel.reference} + snapshot = _parse_batch_get(response_pb, reference_map, mock.sentinel.client) + assert isinstance(snapshot, DocumentSnapshot) + assert snapshot._reference is mock.sentinel.reference + assert snapshot._data == {"foo": 1.5, "bar": u"skillz"} + assert snapshot._exists + assert snapshot.read_time.timestamp_pb() == read_time + assert snapshot.create_time.timestamp_pb() == create_time + assert snapshot.update_time.timestamp_pb() == update_time + + +def test__parse_batch_get_missing(): + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.base_client import _parse_batch_get + + ref_string = _dummy_ref_string() + response_pb = _make_batch_response(missing=ref_string) + document = DocumentReference("fizz", "bazz", client=mock.sentinel.client) + reference_map = {ref_string: document} + snapshot = _parse_batch_get(response_pb, reference_map, mock.sentinel.client) + assert not snapshot.exists + assert snapshot.id == "bazz" + assert snapshot._data is None + + +def test__parse_batch_get_unset_result_type(): + from google.cloud.firestore_v1.base_client import _parse_batch_get + + response_pb = _make_batch_response() + with pytest.raises(ValueError): + 
_parse_batch_get(response_pb, {}, mock.sentinel.client) + + +def test__parse_batch_get_unknown_result_type(): + from google.cloud.firestore_v1.base_client import _parse_batch_get + + response_pb = mock.Mock() + response_pb._pb.mock_add_spec(spec=["WhichOneof"]) + response_pb._pb.WhichOneof.return_value = "zoob_value" + + with pytest.raises(ValueError): + _parse_batch_get(response_pb, {}, mock.sentinel.client) + + response_pb._pb.WhichOneof.assert_called_once_with("result") + + +def test__get_doc_mask_w_none(): + from google.cloud.firestore_v1.base_client import _get_doc_mask + + assert _get_doc_mask(None) is None + + +def test__get_doc_mask_w_paths(): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.base_client import _get_doc_mask + + field_paths = ["a.b", "c"] + result = _get_doc_mask(field_paths) + expected = common.DocumentMask(field_paths=field_paths) + assert result == expected def _make_credentials(): diff --git a/tests/unit/v1/test_base_collection.py b/tests/unit/v1/test_base_collection.py index 01c68483a63b5..8d4b7833368d3 100644 --- a/tests/unit/v1/test_base_collection.py +++ b/tests/unit/v1/test_base_collection.py @@ -12,331 +12,345 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest - import mock +import pytest + + +def _make_base_collection_reference(*args, **kwargs): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference + + return BaseCollectionReference(*args, **kwargs) + + +def test_basecollectionreference_ctor(): + collection_id1 = "rooms" + document_id = "roomA" + collection_id2 = "messages" + client = mock.sentinel.client + + collection = _make_base_collection_reference( + collection_id1, document_id, collection_id2, client=client + ) + assert collection._client is client + expected_path = (collection_id1, document_id, collection_id2) + assert collection._path == expected_path + + +def test_basecollectionreference_ctor_invalid_path_empty(): + with pytest.raises(ValueError): + _make_base_collection_reference() + + +def test_basecollectionreference_ctor_invalid_path_bad_collection_id(): + with pytest.raises(ValueError): + _make_base_collection_reference(99, "doc", "bad-collection-id") + + +def test_basecollectionreference_ctor_invalid_path_bad_document_id(): + with pytest.raises(ValueError): + _make_base_collection_reference("bad-document-ID", None, "sub-collection") + + +def test_basecollectionreference_ctor_invalid_path_bad_number_args(): + with pytest.raises(ValueError): + _make_base_collection_reference("Just", "A-Document") + + +def test_basecollectionreference_ctor_invalid_kwarg(): + with pytest.raises(TypeError): + _make_base_collection_reference("Coh-lek-shun", donut=True) + + +def test_basecollectionreference___eq___other_type(): + client = mock.sentinel.client + collection = _make_base_collection_reference("name", client=client) + other = object() + assert not collection == other + + +def test_basecollectionreference___eq___different_path_same_client(): + client = mock.sentinel.client + collection = _make_base_collection_reference("name", client=client) + other = _make_base_collection_reference("other", client=client) + assert not collection == other + + +def 
test_basecollectionreference___eq___same_path_different_client(): + client = mock.sentinel.client + other_client = mock.sentinel.other_client + collection = _make_base_collection_reference("name", client=client) + other = _make_base_collection_reference("name", client=other_client) + assert not collection == other + + +def test_basecollectionreference___eq___same_path_same_client(): + client = mock.sentinel.client + collection = _make_base_collection_reference("name", client=client) + other = _make_base_collection_reference("name", client=client) + assert collection == other + + +def test_basecollectionreference_id_property(): + collection_id = "hi-bob" + collection = _make_base_collection_reference(collection_id) + assert collection.id == collection_id + + +def test_basecollectionreference_parent_property(): + from google.cloud.firestore_v1.document import DocumentReference + + collection_id1 = "grocery-store" + document_id = "market" + collection_id2 = "darth" + client = _make_client() + collection = _make_base_collection_reference( + collection_id1, document_id, collection_id2, client=client + ) + + parent = collection.parent + assert isinstance(parent, DocumentReference) + assert parent._client is client + assert parent._path == (collection_id1, document_id) + + +def test_basecollectionreference_parent_property_top_level(): + collection = _make_base_collection_reference("tahp-leh-vull") + assert collection.parent is None + + +def test_basecollectionreference_document_factory_explicit_id(): + from google.cloud.firestore_v1.document import DocumentReference + + collection_id = "grocery-store" + document_id = "market" + client = _make_client() + collection = _make_base_collection_reference(collection_id, client=client) + + child = collection.document(document_id) + assert isinstance(child, DocumentReference) + assert child._client is client + assert child._path == (collection_id, document_id) + + +@mock.patch( + 
"google.cloud.firestore_v1.base_collection._auto_id", + return_value="zorpzorpthreezorp012", +) +def test_basecollectionreference_document_factory_auto_id(mock_auto_id): + from google.cloud.firestore_v1.document import DocumentReference + + collection_name = "space-town" + client = _make_client() + collection = _make_base_collection_reference(collection_name, client=client) + + child = collection.document() + assert isinstance(child, DocumentReference) + assert child._client is client + assert child._path == (collection_name, mock_auto_id.return_value) + + mock_auto_id.assert_called_once_with() + + +def test_basecollectionreference__parent_info_top_level(): + client = _make_client() + collection_id = "soap" + collection = _make_base_collection_reference(collection_id, client=client) + + parent_path, expected_prefix = collection._parent_info() + + expected_path = "projects/{}/databases/{}/documents".format( + client.project, client._database + ) + assert parent_path == expected_path + prefix = "{}/{}".format(expected_path, collection_id) + assert expected_prefix == prefix + + +def test_basecollectionreference__parent_info_nested(): + collection_id1 = "bar" + document_id = "baz" + collection_id2 = "chunk" + client = _make_client() + collection = _make_base_collection_reference( + collection_id1, document_id, collection_id2, client=client + ) + parent_path, expected_prefix = collection._parent_info() -class TestCollectionReference(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.base_collection import BaseCollectionReference - - return BaseCollectionReference - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def test_constructor(self): - collection_id1 = "rooms" - document_id = "roomA" - collection_id2 = "messages" - client = mock.sentinel.client - - collection = self._make_one( - collection_id1, document_id, collection_id2, client=client - ) - 
self.assertIs(collection._client, client) - expected_path = (collection_id1, document_id, collection_id2) - self.assertEqual(collection._path, expected_path) - - def test_constructor_invalid_path_empty(self): - with self.assertRaises(ValueError): - self._make_one() - - def test_constructor_invalid_path_bad_collection_id(self): - with self.assertRaises(ValueError): - self._make_one(99, "doc", "bad-collection-id") - - def test_constructor_invalid_path_bad_document_id(self): - with self.assertRaises(ValueError): - self._make_one("bad-document-ID", None, "sub-collection") - - def test_constructor_invalid_path_bad_number_args(self): - with self.assertRaises(ValueError): - self._make_one("Just", "A-Document") - - def test_constructor_invalid_kwarg(self): - with self.assertRaises(TypeError): - self._make_one("Coh-lek-shun", donut=True) - - def test___eq___other_type(self): - client = mock.sentinel.client - collection = self._make_one("name", client=client) - other = object() - self.assertFalse(collection == other) - - def test___eq___different_path_same_client(self): - client = mock.sentinel.client - collection = self._make_one("name", client=client) - other = self._make_one("other", client=client) - self.assertFalse(collection == other) - - def test___eq___same_path_different_client(self): - client = mock.sentinel.client - other_client = mock.sentinel.other_client - collection = self._make_one("name", client=client) - other = self._make_one("name", client=other_client) - self.assertFalse(collection == other) - - def test___eq___same_path_same_client(self): - client = mock.sentinel.client - collection = self._make_one("name", client=client) - other = self._make_one("name", client=client) - self.assertTrue(collection == other) - - def test_id_property(self): - collection_id = "hi-bob" - collection = self._make_one(collection_id) - self.assertEqual(collection.id, collection_id) - - def test_parent_property(self): - from google.cloud.firestore_v1.document import 
DocumentReference - - collection_id1 = "grocery-store" - document_id = "market" - collection_id2 = "darth" - client = _make_client() - collection = self._make_one( - collection_id1, document_id, collection_id2, client=client - ) - - parent = collection.parent - self.assertIsInstance(parent, DocumentReference) - self.assertIs(parent._client, client) - self.assertEqual(parent._path, (collection_id1, document_id)) - - def test_parent_property_top_level(self): - collection = self._make_one("tahp-leh-vull") - self.assertIsNone(collection.parent) - - def test_document_factory_explicit_id(self): - from google.cloud.firestore_v1.document import DocumentReference - - collection_id = "grocery-store" - document_id = "market" - client = _make_client() - collection = self._make_one(collection_id, client=client) - - child = collection.document(document_id) - self.assertIsInstance(child, DocumentReference) - self.assertIs(child._client, client) - self.assertEqual(child._path, (collection_id, document_id)) - - @mock.patch( - "google.cloud.firestore_v1.base_collection._auto_id", - return_value="zorpzorpthreezorp012", + expected_path = "projects/{}/databases/{}/documents/{}/{}".format( + client.project, client._database, collection_id1, document_id ) - def test_document_factory_auto_id(self, mock_auto_id): - from google.cloud.firestore_v1.document import DocumentReference + assert parent_path == expected_path + prefix = "{}/{}".format(expected_path, collection_id2) + assert expected_prefix == prefix - collection_name = "space-town" - client = _make_client() - collection = self._make_one(collection_name, client=client) - child = collection.document() - self.assertIsInstance(child, DocumentReference) - self.assertIs(child._client, client) - self.assertEqual(child._path, (collection_name, mock_auto_id.return_value)) - - mock_auto_id.assert_called_once_with() +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def 
test_basecollectionreference_select(mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference - def test__parent_info_top_level(self): - client = _make_client() - collection_id = "soap" - collection = self._make_one(collection_id, client=client) + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query - parent_path, expected_prefix = collection._parent_info() + collection = _make_base_collection_reference("collection") + field_paths = ["a", "b"] + query = collection.select(field_paths) - expected_path = "projects/{}/databases/{}/documents".format( - client.project, client._database - ) - self.assertEqual(parent_path, expected_path) - prefix = "{}/{}".format(expected_path, collection_id) - self.assertEqual(expected_prefix, prefix) + mock_query.select.assert_called_once_with(field_paths) + assert query == mock_query.select.return_value - def test__parent_info_nested(self): - collection_id1 = "bar" - document_id = "baz" - collection_id2 = "chunk" - client = _make_client() - collection = self._make_one( - collection_id1, document_id, collection_id2, client=client - ) - parent_path, expected_prefix = collection._parent_info() +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def test_basecollectionreference_where(mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference - expected_path = "projects/{}/databases/{}/documents/{}/{}".format( - client.project, client._database, collection_id1, document_id - ) - self.assertEqual(parent_path, expected_path) - prefix = "{}/{}".format(expected_path, collection_id2) - self.assertEqual(expected_prefix, prefix) + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query - @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) - def test_select(self, mock_query): - from google.cloud.firestore_v1.base_collection import 
BaseCollectionReference + collection = _make_base_collection_reference("collection") + field_path = "foo" + op_string = "==" + value = 45 + query = collection.where(field_path, op_string, value) - with mock.patch.object(BaseCollectionReference, "_query") as _query: - _query.return_value = mock_query + mock_query.where.assert_called_once_with(field_path, op_string, value) + assert query == mock_query.where.return_value - collection = self._make_one("collection") - field_paths = ["a", "b"] - query = collection.select(field_paths) - mock_query.select.assert_called_once_with(field_paths) - self.assertEqual(query, mock_query.select.return_value) +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def test_basecollectionreference_order_by(mock_query): + from google.cloud.firestore_v1.base_query import BaseQuery + from google.cloud.firestore_v1.base_collection import BaseCollectionReference - @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) - def test_where(self, mock_query): - from google.cloud.firestore_v1.base_collection import BaseCollectionReference + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query - with mock.patch.object(BaseCollectionReference, "_query") as _query: - _query.return_value = mock_query + collection = _make_base_collection_reference("collection") + field_path = "foo" + direction = BaseQuery.DESCENDING + query = collection.order_by(field_path, direction=direction) - collection = self._make_one("collection") - field_path = "foo" - op_string = "==" - value = 45 - query = collection.where(field_path, op_string, value) + mock_query.order_by.assert_called_once_with(field_path, direction=direction) + assert query == mock_query.order_by.return_value - mock_query.where.assert_called_once_with(field_path, op_string, value) - self.assertEqual(query, mock_query.where.return_value) - @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", 
autospec=True) - def test_order_by(self, mock_query): - from google.cloud.firestore_v1.base_query import BaseQuery - from google.cloud.firestore_v1.base_collection import BaseCollectionReference +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def test_basecollectionreference_limit(mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference - with mock.patch.object(BaseCollectionReference, "_query") as _query: - _query.return_value = mock_query + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query - collection = self._make_one("collection") - field_path = "foo" - direction = BaseQuery.DESCENDING - query = collection.order_by(field_path, direction=direction) + collection = _make_base_collection_reference("collection") + limit = 15 + query = collection.limit(limit) - mock_query.order_by.assert_called_once_with(field_path, direction=direction) - self.assertEqual(query, mock_query.order_by.return_value) + mock_query.limit.assert_called_once_with(limit) + assert query == mock_query.limit.return_value - @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) - def test_limit(self, mock_query): - from google.cloud.firestore_v1.base_collection import BaseCollectionReference - with mock.patch.object(BaseCollectionReference, "_query") as _query: - _query.return_value = mock_query +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def test_basecollectionreference_limit_to_last(mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference - collection = self._make_one("collection") - limit = 15 - query = collection.limit(limit) + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query - mock_query.limit.assert_called_once_with(limit) - self.assertEqual(query, mock_query.limit.return_value) + collection = 
_make_base_collection_reference("collection") + limit = 15 + query = collection.limit_to_last(limit) - @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) - def test_limit_to_last(self, mock_query): - from google.cloud.firestore_v1.base_collection import BaseCollectionReference + mock_query.limit_to_last.assert_called_once_with(limit) + assert query == mock_query.limit_to_last.return_value - with mock.patch.object(BaseCollectionReference, "_query") as _query: - _query.return_value = mock_query - collection = self._make_one("collection") - limit = 15 - query = collection.limit_to_last(limit) +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def test_basecollectionreference_offset(mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference - mock_query.limit_to_last.assert_called_once_with(limit) - self.assertEqual(query, mock_query.limit_to_last.return_value) + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query - @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) - def test_offset(self, mock_query): - from google.cloud.firestore_v1.base_collection import BaseCollectionReference + collection = _make_base_collection_reference("collection") + offset = 113 + query = collection.offset(offset) - with mock.patch.object(BaseCollectionReference, "_query") as _query: - _query.return_value = mock_query + mock_query.offset.assert_called_once_with(offset) + assert query == mock_query.offset.return_value - collection = self._make_one("collection") - offset = 113 - query = collection.offset(offset) - mock_query.offset.assert_called_once_with(offset) - self.assertEqual(query, mock_query.offset.return_value) +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def test_basecollectionreference_start_at(mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference - 
@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) - def test_start_at(self, mock_query): - from google.cloud.firestore_v1.base_collection import BaseCollectionReference + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query - with mock.patch.object(BaseCollectionReference, "_query") as _query: - _query.return_value = mock_query + collection = _make_base_collection_reference("collection") + doc_fields = {"a": "b"} + query = collection.start_at(doc_fields) - collection = self._make_one("collection") - doc_fields = {"a": "b"} - query = collection.start_at(doc_fields) + mock_query.start_at.assert_called_once_with(doc_fields) + assert query == mock_query.start_at.return_value - mock_query.start_at.assert_called_once_with(doc_fields) - self.assertEqual(query, mock_query.start_at.return_value) - @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) - def test_start_after(self, mock_query): - from google.cloud.firestore_v1.base_collection import BaseCollectionReference +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def test_basecollectionreference_start_after(mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference - with mock.patch.object(BaseCollectionReference, "_query") as _query: - _query.return_value = mock_query + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query - collection = self._make_one("collection") - doc_fields = {"d": "foo", "e": 10} - query = collection.start_after(doc_fields) + collection = _make_base_collection_reference("collection") + doc_fields = {"d": "foo", "e": 10} + query = collection.start_after(doc_fields) - mock_query.start_after.assert_called_once_with(doc_fields) - self.assertEqual(query, mock_query.start_after.return_value) + mock_query.start_after.assert_called_once_with(doc_fields) + assert query == 
mock_query.start_after.return_value - @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) - def test_end_before(self, mock_query): - from google.cloud.firestore_v1.base_collection import BaseCollectionReference - with mock.patch.object(BaseCollectionReference, "_query") as _query: - _query.return_value = mock_query +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def test_basecollectionreference_end_before(mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference - collection = self._make_one("collection") - doc_fields = {"bar": 10.5} - query = collection.end_before(doc_fields) + with mock.patch.object(BaseCollectionReference, "_query") as _query: + _query.return_value = mock_query - mock_query.end_before.assert_called_once_with(doc_fields) - self.assertEqual(query, mock_query.end_before.return_value) + collection = _make_base_collection_reference("collection") + doc_fields = {"bar": 10.5} + query = collection.end_before(doc_fields) - @mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) - def test_end_at(self, mock_query): - from google.cloud.firestore_v1.base_collection import BaseCollectionReference + mock_query.end_before.assert_called_once_with(doc_fields) + assert query == mock_query.end_before.return_value - with mock.patch.object(BaseCollectionReference, "_query") as _query: - _query.return_value = mock_query - collection = self._make_one("collection") - doc_fields = {"opportunity": True, "reason": 9} - query = collection.end_at(doc_fields) +@mock.patch("google.cloud.firestore_v1.base_query.BaseQuery", autospec=True) +def test_basecollectionreference_end_at(mock_query): + from google.cloud.firestore_v1.base_collection import BaseCollectionReference - mock_query.end_at.assert_called_once_with(doc_fields) - self.assertEqual(query, mock_query.end_at.return_value) + with mock.patch.object(BaseCollectionReference, "_query") as _query: + 
_query.return_value = mock_query + collection = _make_base_collection_reference("collection") + doc_fields = {"opportunity": True, "reason": 9} + query = collection.end_at(doc_fields) -class Test__auto_id(unittest.TestCase): - @staticmethod - def _call_fut(): - from google.cloud.firestore_v1.base_collection import _auto_id + mock_query.end_at.assert_called_once_with(doc_fields) + assert query == mock_query.end_at.return_value - return _auto_id() - @mock.patch("random.choice") - def test_it(self, mock_rand_choice): - from google.cloud.firestore_v1.base_collection import _AUTO_ID_CHARS +@mock.patch("random.choice") +def test__auto_id(mock_rand_choice): + from google.cloud.firestore_v1.base_collection import _AUTO_ID_CHARS + from google.cloud.firestore_v1.base_collection import _auto_id - mock_result = "0123456789abcdefghij" - mock_rand_choice.side_effect = list(mock_result) - result = self._call_fut() - self.assertEqual(result, mock_result) + mock_result = "0123456789abcdefghij" + mock_rand_choice.side_effect = list(mock_result) + result = _auto_id() + assert result == mock_result - mock_calls = [mock.call(_AUTO_ID_CHARS)] * 20 - self.assertEqual(mock_rand_choice.mock_calls, mock_calls) + mock_calls = [mock.call(_AUTO_ID_CHARS)] * 20 + assert mock_rand_choice.mock_calls == mock_calls def _make_credentials(): diff --git a/tests/unit/v1/test_base_document.py b/tests/unit/v1/test_base_document.py index 2342f4485c4c3..d3a59d5adf7f5 100644 --- a/tests/unit/v1/test_base_document.py +++ b/tests/unit/v1/test_base_document.py @@ -12,412 +12,420 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import datetime -import unittest import mock -from proto.datetime_helpers import DatetimeWithNanoseconds - - -class TestBaseDocumentReference(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.document import DocumentReference - - return DocumentReference - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def test_constructor(self): - collection_id1 = "users" - document_id1 = "alovelace" - collection_id2 = "platform" - document_id2 = "*nix" - client = mock.MagicMock() - client.__hash__.return_value = 1234 - - document = self._make_one( - collection_id1, document_id1, collection_id2, document_id2, client=client - ) - self.assertIs(document._client, client) - expected_path = "/".join( - (collection_id1, document_id1, collection_id2, document_id2) - ) - self.assertEqual(document.path, expected_path) - - def test_constructor_invalid_path_empty(self): - with self.assertRaises(ValueError): - self._make_one() - - def test_constructor_invalid_path_bad_collection_id(self): - with self.assertRaises(ValueError): - self._make_one(None, "before", "bad-collection-id", "fifteen") - - def test_constructor_invalid_path_bad_document_id(self): - with self.assertRaises(ValueError): - self._make_one("bad-document-ID", None) - - def test_constructor_invalid_path_bad_number_args(self): - with self.assertRaises(ValueError): - self._make_one("Just", "A-Collection", "Sub") - - def test_constructor_invalid_kwarg(self): - with self.assertRaises(TypeError): - self._make_one("Coh-lek-shun", "Dahk-yu-mehnt", burger=18.75) - - def test___copy__(self): - client = _make_client("rain") - document = self._make_one("a", "b", client=client) - # Access the document path so it is copied. 
- doc_path = document._document_path - self.assertEqual(doc_path, document._document_path_internal) - - new_document = document.__copy__() - self.assertIsNot(new_document, document) - self.assertIs(new_document._client, document._client) - self.assertEqual(new_document._path, document._path) - self.assertEqual( - new_document._document_path_internal, document._document_path_internal - ) - - def test___deepcopy__calls_copy(self): - client = mock.sentinel.client - document = self._make_one("a", "b", client=client) - document.__copy__ = mock.Mock(return_value=mock.sentinel.new_doc, spec=[]) - - unused_memo = {} - new_document = document.__deepcopy__(unused_memo) - self.assertIs(new_document, mock.sentinel.new_doc) - document.__copy__.assert_called_once_with() - - def test__eq__same_type(self): - document1 = self._make_one("X", "YY", client=mock.sentinel.client) - document2 = self._make_one("X", "ZZ", client=mock.sentinel.client) - document3 = self._make_one("X", "YY", client=mock.sentinel.client2) - document4 = self._make_one("X", "YY", client=mock.sentinel.client) - - pairs = ((document1, document2), (document1, document3), (document2, document3)) - for candidate1, candidate2 in pairs: - # We use == explicitly since assertNotEqual would use !=. - equality_val = candidate1 == candidate2 - self.assertFalse(equality_val) - - # Check the only equal one. 
- self.assertEqual(document1, document4) - self.assertIsNot(document1, document4) - - def test__eq__other_type(self): - document = self._make_one("X", "YY", client=mock.sentinel.client) - other = object() - equality_val = document == other - self.assertFalse(equality_val) - self.assertIs(document.__eq__(other), NotImplemented) - - def test___hash__(self): - client = mock.MagicMock() - client.__hash__.return_value = 234566789 - document = self._make_one("X", "YY", client=client) - self.assertEqual(hash(document), hash(("X", "YY")) + hash(client)) - - def test__ne__same_type(self): - document1 = self._make_one("X", "YY", client=mock.sentinel.client) - document2 = self._make_one("X", "ZZ", client=mock.sentinel.client) - document3 = self._make_one("X", "YY", client=mock.sentinel.client2) - document4 = self._make_one("X", "YY", client=mock.sentinel.client) - - self.assertNotEqual(document1, document2) - self.assertNotEqual(document1, document3) - self.assertNotEqual(document2, document3) - - # We use != explicitly since assertEqual would use ==. - inequality_val = document1 != document4 - self.assertFalse(inequality_val) - self.assertIsNot(document1, document4) - - def test__ne__other_type(self): - document = self._make_one("X", "YY", client=mock.sentinel.client) - other = object() - self.assertNotEqual(document, other) - self.assertIs(document.__ne__(other), NotImplemented) - - def test__document_path_property(self): - project = "hi-its-me-ok-bye" - client = _make_client(project=project) - - collection_id = "then" - document_id = "090909iii" - document = self._make_one(collection_id, document_id, client=client) - doc_path = document._document_path - expected = "projects/{}/databases/{}/documents/{}/{}".format( - project, client._database, collection_id, document_id - ) - self.assertEqual(doc_path, expected) - self.assertIs(document._document_path_internal, doc_path) - - # Make sure value is cached. 
- document._document_path_internal = mock.sentinel.cached - self.assertIs(document._document_path, mock.sentinel.cached) - - def test__document_path_property_no_client(self): - document = self._make_one("hi", "bye") - self.assertIsNone(document._client) - with self.assertRaises(ValueError): - getattr(document, "_document_path") - - self.assertIsNone(document._document_path_internal) - - def test_id_property(self): - document_id = "867-5309" - document = self._make_one("Co-lek-shun", document_id) - self.assertEqual(document.id, document_id) - - def test_parent_property(self): - from google.cloud.firestore_v1.collection import CollectionReference - - collection_id = "grocery-store" - document_id = "market" - client = _make_client() - document = self._make_one(collection_id, document_id, client=client) - - parent = document.parent - self.assertIsInstance(parent, CollectionReference) - self.assertIs(parent._client, client) - self.assertEqual(parent._path, (collection_id,)) - - def test_collection_factory(self): - from google.cloud.firestore_v1.collection import CollectionReference - - collection_id = "grocery-store" - document_id = "market" - new_collection = "fruits" - client = _make_client() - document = self._make_one(collection_id, document_id, client=client) - - child = document.collection(new_collection) - self.assertIsInstance(child, CollectionReference) - self.assertIs(child._client, client) - self.assertEqual(child._path, (collection_id, document_id, new_collection)) - - -class TestDocumentSnapshot(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.document import DocumentSnapshot - - return DocumentSnapshot - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def _make_reference(self, *args, **kwargs): - from google.cloud.firestore_v1.document import DocumentReference - - return DocumentReference(*args, **kwargs) - - def _make_w_ref(self, ref_path=("a", 
"b"), data={}, exists=True): - client = mock.sentinel.client - reference = self._make_reference(*ref_path, client=client) - return self._make_one( - reference, - data, - exists, - mock.sentinel.read_time, - mock.sentinel.create_time, - mock.sentinel.update_time, - ) - - def test_constructor(self): - client = mock.sentinel.client - reference = self._make_reference("hi", "bye", client=client) - data = {"zoop": 83} - snapshot = self._make_one( - reference, - data, - True, - mock.sentinel.read_time, - mock.sentinel.create_time, - mock.sentinel.update_time, - ) - self.assertIs(snapshot._reference, reference) - self.assertEqual(snapshot._data, data) - self.assertIsNot(snapshot._data, data) # Make sure copied. - self.assertTrue(snapshot._exists) - self.assertIs(snapshot.read_time, mock.sentinel.read_time) - self.assertIs(snapshot.create_time, mock.sentinel.create_time) - self.assertIs(snapshot.update_time, mock.sentinel.update_time) - - def test___eq___other_type(self): - snapshot = self._make_w_ref() - other = object() - self.assertFalse(snapshot == other) - - def test___eq___different_reference_same_data(self): - snapshot = self._make_w_ref(("a", "b")) - other = self._make_w_ref(("c", "d")) - self.assertFalse(snapshot == other) - - def test___eq___same_reference_different_data(self): - snapshot = self._make_w_ref(("a", "b")) - other = self._make_w_ref(("a", "b"), {"foo": "bar"}) - self.assertFalse(snapshot == other) - - def test___eq___same_reference_same_data(self): - snapshot = self._make_w_ref(("a", "b"), {"foo": "bar"}) - other = self._make_w_ref(("a", "b"), {"foo": "bar"}) - self.assertTrue(snapshot == other) - - def test___hash__(self): - client = mock.MagicMock() - client.__hash__.return_value = 234566789 - reference = self._make_reference("hi", "bye", client=client) - data = {"zoop": 83} - update_time = DatetimeWithNanoseconds( - 2021, 10, 4, 17, 43, 27, nanosecond=123456789, tzinfo=datetime.timezone.utc - ) - snapshot = self._make_one( - reference, data, True, 
None, mock.sentinel.create_time, update_time - ) - self.assertEqual(hash(snapshot), hash(reference) + hash(update_time)) - - def test__client_property(self): - reference = self._make_reference( - "ok", "fine", "now", "fore", client=mock.sentinel.client - ) - snapshot = self._make_one(reference, {}, False, None, None, None) - self.assertIs(snapshot._client, mock.sentinel.client) - - def test_exists_property(self): - reference = mock.sentinel.reference - - snapshot1 = self._make_one(reference, {}, False, None, None, None) - self.assertFalse(snapshot1.exists) - snapshot2 = self._make_one(reference, {}, True, None, None, None) - self.assertTrue(snapshot2.exists) - - def test_id_property(self): - document_id = "around" - reference = self._make_reference( - "look", document_id, client=mock.sentinel.client - ) - snapshot = self._make_one(reference, {}, True, None, None, None) - self.assertEqual(snapshot.id, document_id) - self.assertEqual(reference.id, document_id) - - def test_reference_property(self): - snapshot = self._make_one(mock.sentinel.reference, {}, True, None, None, None) - self.assertIs(snapshot.reference, mock.sentinel.reference) - - def test_get(self): - data = {"one": {"bold": "move"}} - snapshot = self._make_one(None, data, True, None, None, None) - - first_read = snapshot.get("one") - second_read = snapshot.get("one") - self.assertEqual(first_read, data.get("one")) - self.assertIsNot(first_read, data.get("one")) - self.assertEqual(first_read, second_read) - self.assertIsNot(first_read, second_read) - - with self.assertRaises(KeyError): - snapshot.get("two") - - def test_nonexistent_snapshot(self): - snapshot = self._make_one(None, None, False, None, None, None) - self.assertIsNone(snapshot.get("one")) - - def test_to_dict(self): - data = {"a": 10, "b": ["definitely", "mutable"], "c": {"45": 50}} - snapshot = self._make_one(None, data, True, None, None, None) - as_dict = snapshot.to_dict() - self.assertEqual(as_dict, data) - self.assertIsNot(as_dict, data) 
- # Check that the data remains unchanged. - as_dict["b"].append("hi") - self.assertEqual(data, snapshot.to_dict()) - self.assertNotEqual(data, as_dict) - - def test_non_existent(self): - snapshot = self._make_one(None, None, False, None, None, None) - as_dict = snapshot.to_dict() - self.assertIsNone(as_dict) - - -class Test__get_document_path(unittest.TestCase): - @staticmethod - def _call_fut(client, path): - from google.cloud.firestore_v1.base_document import _get_document_path - - return _get_document_path(client, path) - - def test_it(self): - project = "prah-jekt" - client = _make_client(project=project) - path = ("Some", "Document", "Child", "Shockument") - document_path = self._call_fut(client, path) - - expected = "projects/{}/databases/{}/documents/{}".format( - project, client._database, "/".join(path) - ) - self.assertEqual(document_path, expected) - - -class Test__consume_single_get(unittest.TestCase): - @staticmethod - def _call_fut(response_iterator): - from google.cloud.firestore_v1.base_document import _consume_single_get - - return _consume_single_get(response_iterator) - - def test_success(self): - response_iterator = iter([mock.sentinel.result]) - result = self._call_fut(response_iterator) - self.assertIs(result, mock.sentinel.result) - - def test_failure_not_enough(self): - response_iterator = iter([]) - with self.assertRaises(ValueError): - self._call_fut(response_iterator) - - def test_failure_too_many(self): - response_iterator = iter([None, None]) - with self.assertRaises(ValueError): - self._call_fut(response_iterator) - - -class Test__first_write_result(unittest.TestCase): - @staticmethod - def _call_fut(write_results): - from google.cloud.firestore_v1.base_document import _first_write_result - - return _first_write_result(write_results) - - def test_success(self): - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.types import write - - single_result = write.WriteResult( - 
update_time=timestamp_pb2.Timestamp(seconds=1368767504, nanos=458000123) - ) - write_results = [single_result] - result = self._call_fut(write_results) - self.assertIs(result, single_result) - - def test_failure_not_enough(self): - write_results = [] - with self.assertRaises(ValueError): - self._call_fut(write_results) - - def test_more_than_one(self): - from google.cloud.firestore_v1.types import write - - result1 = write.WriteResult() - result2 = write.WriteResult() - write_results = [result1, result2] - result = self._call_fut(write_results) - self.assertIs(result, result1) +import pytest + + +def _make_base_document_reference(*args, **kwargs): + from google.cloud.firestore_v1.base_document import BaseDocumentReference + + return BaseDocumentReference(*args, **kwargs) + + +def test_basedocumentreference_constructor(): + collection_id1 = "users" + document_id1 = "alovelace" + collection_id2 = "platform" + document_id2 = "*nix" + client = mock.MagicMock() + client.__hash__.return_value = 1234 + + document = _make_base_document_reference( + collection_id1, document_id1, collection_id2, document_id2, client=client + ) + assert document._client is client + expected_path = "/".join( + (collection_id1, document_id1, collection_id2, document_id2) + ) + assert document.path == expected_path + + +def test_basedocumentreference_constructor_invalid_path_empty(): + with pytest.raises(ValueError): + _make_base_document_reference() + + +def test_basedocumentreference_constructor_invalid_path_bad_collection_id(): + with pytest.raises(ValueError): + _make_base_document_reference(None, "before", "bad-collection-id", "fifteen") + + +def test_basedocumentreference_constructor_invalid_path_bad_document_id(): + with pytest.raises(ValueError): + _make_base_document_reference("bad-document-ID", None) + + +def test_basedocumentreference_constructor_invalid_path_bad_number_args(): + with pytest.raises(ValueError): + _make_base_document_reference("Just", "A-Collection", "Sub") + + +def 
test_basedocumentreference_constructor_invalid_kwarg(): + with pytest.raises(TypeError): + _make_base_document_reference("Coh-lek-shun", "Dahk-yu-mehnt", burger=18.75) + + +def test_basedocumentreference___copy__(): + client = _make_client("rain") + document = _make_base_document_reference("a", "b", client=client) + # Access the document path so it is copied. + doc_path = document._document_path + assert doc_path == document._document_path_internal + + new_document = document.__copy__() + assert new_document is not document + assert new_document._client is document._client + assert new_document._path == document._path + assert new_document._document_path_internal == document._document_path_internal + + +def test_basedocumentreference___deepcopy__calls_copy(): + client = mock.sentinel.client + document = _make_base_document_reference("a", "b", client=client) + document.__copy__ = mock.Mock(return_value=mock.sentinel.new_doc, spec=[]) + + unused_memo = {} + new_document = document.__deepcopy__(unused_memo) + assert new_document is mock.sentinel.new_doc + document.__copy__.assert_called_once_with() + + +def test_basedocumentreference__eq__same_type(): + document1 = _make_base_document_reference("X", "YY", client=mock.sentinel.client) + document2 = _make_base_document_reference("X", "ZZ", client=mock.sentinel.client) + document3 = _make_base_document_reference("X", "YY", client=mock.sentinel.client2) + document4 = _make_base_document_reference("X", "YY", client=mock.sentinel.client) + + pairs = ((document1, document2), (document1, document3), (document2, document3)) + for candidate1, candidate2 in pairs: + # We use == explicitly since assertNotEqual would use !=. + assert not (candidate1 == candidate2) + + # Check the only equal one. 
+ assert document1 == document4 + assert document1 is not document4 + + +def test_basedocumentreference__eq__other_type(): + document = _make_base_document_reference("X", "YY", client=mock.sentinel.client) + other = object() + assert not (document == other) + assert document.__eq__(other) is NotImplemented + + +def test_basedocumentreference___hash__(): + client = mock.MagicMock() + client.__hash__.return_value = 234566789 + document = _make_base_document_reference("X", "YY", client=client) + assert hash(document) == hash(("X", "YY")) + hash(client) + + +def test_basedocumentreference__ne__same_type(): + document1 = _make_base_document_reference("X", "YY", client=mock.sentinel.client) + document2 = _make_base_document_reference("X", "ZZ", client=mock.sentinel.client) + document3 = _make_base_document_reference("X", "YY", client=mock.sentinel.client2) + document4 = _make_base_document_reference("X", "YY", client=mock.sentinel.client) + + assert document1 != document2 + assert document1 != document3 + assert document2 != document3 + + assert not (document1 != document4) + assert document1 is not document4 + + +def test_basedocumentreference__ne__other_type(): + document = _make_base_document_reference("X", "YY", client=mock.sentinel.client) + other = object() + assert document != other + assert document.__ne__(other) is NotImplemented + + +def test_basedocumentreference__document_path_property(): + project = "hi-its-me-ok-bye" + client = _make_client(project=project) + + collection_id = "then" + document_id = "090909iii" + document = _make_base_document_reference(collection_id, document_id, client=client) + doc_path = document._document_path + expected = "projects/{}/databases/{}/documents/{}/{}".format( + project, client._database, collection_id, document_id + ) + assert doc_path == expected + assert document._document_path_internal is doc_path + + # Make sure value is cached. 
+ document._document_path_internal = mock.sentinel.cached + assert document._document_path is mock.sentinel.cached + + +def test_basedocumentreference__document_path_property_no_client(): + document = _make_base_document_reference("hi", "bye") + assert document._client is None + with pytest.raises(ValueError): + getattr(document, "_document_path") + + assert document._document_path_internal is None + + +def test_basedocumentreference_id_property(): + document_id = "867-5309" + document = _make_base_document_reference("Co-lek-shun", document_id) + assert document.id == document_id + + +def test_basedocumentreference_parent_property(): + from google.cloud.firestore_v1.collection import CollectionReference + + collection_id = "grocery-store" + document_id = "market" + client = _make_client() + document = _make_base_document_reference(collection_id, document_id, client=client) + + parent = document.parent + assert isinstance(parent, CollectionReference) + assert parent._client is client + assert parent._path == (collection_id,) + + +def test_basedocumentreference_collection_factory(): + from google.cloud.firestore_v1.collection import CollectionReference + + collection_id = "grocery-store" + document_id = "market" + new_collection = "fruits" + client = _make_client() + document = _make_base_document_reference(collection_id, document_id, client=client) + + child = document.collection(new_collection) + assert isinstance(child, CollectionReference) + assert child._client is client + assert child._path == (collection_id, document_id, new_collection) + + +def _make_document_snapshot(*args, **kwargs): + from google.cloud.firestore_v1.document import DocumentSnapshot + + return DocumentSnapshot(*args, **kwargs) + + +def _make_w_ref(ref_path=("a", "b"), data={}, exists=True): + client = mock.sentinel.client + reference = _make_base_document_reference(*ref_path, client=client) + return _make_document_snapshot( + reference, + data, + exists, + mock.sentinel.read_time, + 
mock.sentinel.create_time, + mock.sentinel.update_time, + ) + + +def test_documentsnapshot_constructor(): + client = mock.sentinel.client + reference = _make_base_document_reference("hi", "bye", client=client) + data = {"zoop": 83} + snapshot = _make_document_snapshot( + reference, + data, + True, + mock.sentinel.read_time, + mock.sentinel.create_time, + mock.sentinel.update_time, + ) + assert snapshot._reference is reference + assert snapshot._data == data + assert snapshot._data is not data # Make sure copied + assert snapshot._exists + assert snapshot.read_time is mock.sentinel.read_time + assert snapshot.create_time is mock.sentinel.create_time + assert snapshot.update_time is mock.sentinel.update_time + + +def test_documentsnapshot___eq___other_type(): + snapshot = _make_w_ref() + other = object() + assert not (snapshot == other) + + +def test_documentsnapshot___eq___different_reference_same_data(): + snapshot = _make_w_ref(("a", "b")) + other = _make_w_ref(("c", "d")) + assert not (snapshot == other) + + +def test_documentsnapshot___eq___same_reference_different_data(): + snapshot = _make_w_ref(("a", "b")) + other = _make_w_ref(("a", "b"), {"foo": "bar"}) + assert not (snapshot == other) + + +def test_documentsnapshot___eq___same_reference_same_data(): + snapshot = _make_w_ref(("a", "b"), {"foo": "bar"}) + other = _make_w_ref(("a", "b"), {"foo": "bar"}) + assert snapshot == other + + +def test_documentsnapshot___hash__(): + import datetime + from proto.datetime_helpers import DatetimeWithNanoseconds + + client = mock.MagicMock() + client.__hash__.return_value = 234566789 + reference = _make_base_document_reference("hi", "bye", client=client) + data = {"zoop": 83} + update_time = DatetimeWithNanoseconds( + 2021, 10, 4, 17, 43, 27, nanosecond=123456789, tzinfo=datetime.timezone.utc + ) + snapshot = _make_document_snapshot( + reference, data, True, None, mock.sentinel.create_time, update_time + ) + assert hash(snapshot) == hash(reference) + hash(update_time) + + 
+def test_documentsnapshot__client_property(): + reference = _make_base_document_reference( + "ok", "fine", "now", "fore", client=mock.sentinel.client + ) + snapshot = _make_document_snapshot(reference, {}, False, None, None, None) + assert snapshot._client is mock.sentinel.client + + +def test_documentsnapshot_exists_property(): + reference = mock.sentinel.reference + + snapshot1 = _make_document_snapshot(reference, {}, False, None, None, None) + assert not snapshot1.exists + snapshot2 = _make_document_snapshot(reference, {}, True, None, None, None) + assert snapshot2.exists + + +def test_documentsnapshot_id_property(): + document_id = "around" + reference = _make_base_document_reference( + "look", document_id, client=mock.sentinel.client + ) + snapshot = _make_document_snapshot(reference, {}, True, None, None, None) + assert snapshot.id == document_id + assert reference.id == document_id + + +def test_documentsnapshot_reference_property(): + snapshot = _make_document_snapshot( + mock.sentinel.reference, {}, True, None, None, None + ) + assert snapshot.reference is mock.sentinel.reference + + +def test_documentsnapshot_get(): + data = {"one": {"bold": "move"}} + snapshot = _make_document_snapshot(None, data, True, None, None, None) + + first_read = snapshot.get("one") + second_read = snapshot.get("one") + assert first_read == data.get("one") + assert first_read is not data.get("one") + assert first_read == second_read + assert first_read is not second_read + + with pytest.raises(KeyError): + snapshot.get("two") + + +def test_documentsnapshot_nonexistent_snapshot(): + snapshot = _make_document_snapshot(None, None, False, None, None, None) + assert snapshot.get("one") is None + + +def test_documentsnapshot_to_dict(): + data = {"a": 10, "b": ["definitely", "mutable"], "c": {"45": 50}} + snapshot = _make_document_snapshot(None, data, True, None, None, None) + as_dict = snapshot.to_dict() + assert as_dict == data + assert as_dict is not data + # Check that the data 
remains unchanged. + as_dict["b"].append("hi") + assert data == snapshot.to_dict() + assert data != as_dict + + +def test_documentsnapshot_non_existent(): + snapshot = _make_document_snapshot(None, None, False, None, None, None) + as_dict = snapshot.to_dict() + assert as_dict is None + + +def test__get_document_path(): + from google.cloud.firestore_v1.base_document import _get_document_path + + project = "prah-jekt" + client = _make_client(project=project) + path = ("Some", "Document", "Child", "Shockument") + document_path = _get_document_path(client, path) + + expected = "projects/{}/databases/{}/documents/{}".format( + project, client._database, "/".join(path) + ) + assert document_path == expected + + +def test__consume_single_get_success(): + from google.cloud.firestore_v1.base_document import _consume_single_get + + response_iterator = iter([mock.sentinel.result]) + result = _consume_single_get(response_iterator) + assert result is mock.sentinel.result + + +def test__consume_single_get_failure_not_enough(): + from google.cloud.firestore_v1.base_document import _consume_single_get + + response_iterator = iter([]) + with pytest.raises(ValueError): + _consume_single_get(response_iterator) + + +def test__consume_single_get_failure_too_many(): + from google.cloud.firestore_v1.base_document import _consume_single_get + + response_iterator = iter([None, None]) + with pytest.raises(ValueError): + _consume_single_get(response_iterator) + + +def test__first_write_result_success(): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.base_document import _first_write_result + + single_result = write.WriteResult( + update_time=timestamp_pb2.Timestamp(seconds=1368767504, nanos=458000123) + ) + write_results = [single_result] + result = _first_write_result(write_results) + assert result is single_result + + +def test__first_write_result_failure_not_enough(): + from 
google.cloud.firestore_v1.base_document import _first_write_result + + write_results = [] + with pytest.raises(ValueError): + _first_write_result(write_results) + + +def test__first_write_result_more_than_one(): + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1.base_document import _first_write_result + + result1 = write.WriteResult() + result2 = write.WriteResult() + write_results = [result1, result2] + result = _first_write_result(write_results) + assert result is result1 def _make_credentials(): diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index a8496ff808475..8312df5ba9e00 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -13,1453 +13,1525 @@ # limitations under the License. import datetime -import unittest import mock +import pytest -class TestBaseQuery(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.query import Query +def _make_base_query(*args, **kwargs): + from google.cloud.firestore_v1.base_query import BaseQuery - return Query + return BaseQuery(*args, **kwargs) - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - def test_constructor_defaults(self): - query = self._make_one(mock.sentinel.parent) - self.assertIs(query._parent, mock.sentinel.parent) - self.assertIsNone(query._projection) - self.assertEqual(query._field_filters, ()) - self.assertEqual(query._orders, ()) - self.assertIsNone(query._limit) - self.assertIsNone(query._offset) - self.assertIsNone(query._start_at) - self.assertIsNone(query._end_at) - self.assertFalse(query._all_descendants) +def _make_base_query_all_fields( + limit=9876, offset=12, skip_fields=(), parent=None, all_descendants=True, +): + kwargs = { + "projection": mock.sentinel.projection, + "field_filters": mock.sentinel.filters, + "orders": mock.sentinel.orders, + "limit": limit, + "offset": offset, + "start_at": 
mock.sentinel.start_at, + "end_at": mock.sentinel.end_at, + "all_descendants": all_descendants, + } - def _make_one_all_fields( - self, limit=9876, offset=12, skip_fields=(), parent=None, all_descendants=True - ): - kwargs = { - "projection": mock.sentinel.projection, - "field_filters": mock.sentinel.filters, - "orders": mock.sentinel.orders, - "limit": limit, - "offset": offset, - "start_at": mock.sentinel.start_at, - "end_at": mock.sentinel.end_at, - "all_descendants": all_descendants, - } - for field in skip_fields: - kwargs.pop(field) - if parent is None: - parent = mock.sentinel.parent - return self._make_one(parent, **kwargs) - - def test_constructor_explicit(self): - limit = 234 - offset = 56 - query = self._make_one_all_fields(limit=limit, offset=offset) - self.assertIs(query._parent, mock.sentinel.parent) - self.assertIs(query._projection, mock.sentinel.projection) - self.assertIs(query._field_filters, mock.sentinel.filters) - self.assertEqual(query._orders, mock.sentinel.orders) - self.assertEqual(query._limit, limit) - self.assertEqual(query._offset, offset) - self.assertIs(query._start_at, mock.sentinel.start_at) - self.assertIs(query._end_at, mock.sentinel.end_at) - self.assertTrue(query._all_descendants) - - def test__client_property(self): - parent = mock.Mock(_client=mock.sentinel.client, spec=["_client"]) - query = self._make_one(parent) - self.assertIs(query._client, mock.sentinel.client) - - def test___eq___other_type(self): - query = self._make_one_all_fields() - other = object() - self.assertFalse(query == other) - - def test___eq___different_parent(self): - parent = mock.sentinel.parent - other_parent = mock.sentinel.other_parent - query = self._make_one_all_fields(parent=parent) - other = self._make_one_all_fields(parent=other_parent) - self.assertFalse(query == other) + for field in skip_fields: + kwargs.pop(field) - def test___eq___different_projection(self): + if parent is None: parent = mock.sentinel.parent - query = 
self._make_one_all_fields(parent=parent, skip_fields=("projection",)) - query._projection = mock.sentinel.projection - other = self._make_one_all_fields(parent=parent, skip_fields=("projection",)) - other._projection = mock.sentinel.other_projection - self.assertFalse(query == other) - def test___eq___different_field_filters(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("field_filters",)) - query._field_filters = mock.sentinel.field_filters - other = self._make_one_all_fields(parent=parent, skip_fields=("field_filters",)) - other._field_filters = mock.sentinel.other_field_filters - self.assertFalse(query == other) + return _make_base_query(parent, **kwargs) - def test___eq___different_orders(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("orders",)) - query._orders = mock.sentinel.orders - other = self._make_one_all_fields(parent=parent, skip_fields=("orders",)) - other._orders = mock.sentinel.other_orders - self.assertFalse(query == other) - def test___eq___different_limit(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, limit=10) - other = self._make_one_all_fields(parent=parent, limit=20) - self.assertFalse(query == other) +def test_basequery_constructor_defaults(): + query = _make_base_query(mock.sentinel.parent) + assert query._parent is mock.sentinel.parent + assert query._projection is None + assert query._field_filters == () + assert query._orders == () + assert query._limit is None + assert query._offset is None + assert query._start_at is None + assert query._end_at is None + assert not query._all_descendants + + +def test_basequery_constructor_explicit(): + limit = 234 + offset = 56 + query = _make_base_query_all_fields(limit=limit, offset=offset) + assert query._parent is mock.sentinel.parent + assert query._projection is mock.sentinel.projection + assert query._field_filters is mock.sentinel.filters + 
assert query._orders == mock.sentinel.orders + assert query._limit == limit + assert query._offset == offset + assert query._start_at is mock.sentinel.start_at + assert query._end_at is mock.sentinel.end_at + assert query._all_descendants + - def test___eq___different_offset(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, offset=10) - other = self._make_one_all_fields(parent=parent, offset=20) - self.assertFalse(query == other) +def test_basequery__client_property(): + parent = mock.Mock(_client=mock.sentinel.client, spec=["_client"]) + query = _make_base_query(parent) + assert query._client is mock.sentinel.client - def test___eq___different_start_at(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("start_at",)) - query._start_at = mock.sentinel.start_at - other = self._make_one_all_fields(parent=parent, skip_fields=("start_at",)) - other._start_at = mock.sentinel.other_start_at - self.assertFalse(query == other) - def test___eq___different_end_at(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, skip_fields=("end_at",)) - query._end_at = mock.sentinel.end_at - other = self._make_one_all_fields(parent=parent, skip_fields=("end_at",)) - other._end_at = mock.sentinel.other_end_at - self.assertFalse(query == other) +def test_basequery___eq___other_type(): + query = _make_base_query_all_fields() + other = object() + assert not (query == other) - def test___eq___different_all_descendants(self): - parent = mock.sentinel.parent - query = self._make_one_all_fields(parent=parent, all_descendants=True) - other = self._make_one_all_fields(parent=parent, all_descendants=False) - self.assertFalse(query == other) - def test___eq___hit(self): - query = self._make_one_all_fields() - other = self._make_one_all_fields() - self.assertTrue(query == other) +def test_basequery___eq___different_parent(): + parent = mock.sentinel.parent + other_parent = 
mock.sentinel.other_parent + query = _make_base_query_all_fields(parent=parent) + other = _make_base_query_all_fields(parent=other_parent) + assert not (query == other) - def _compare_queries(self, query1, query2, *attr_names): - attrs1 = query1.__dict__.copy() - attrs2 = query2.__dict__.copy() - self.assertEqual(len(attrs1), len(attrs2)) +def test_basequery___eq___different_projection(): + parent = mock.sentinel.parent + query = _make_base_query_all_fields(parent=parent, skip_fields=("projection",)) + query._projection = mock.sentinel.projection + other = _make_base_query_all_fields(parent=parent, skip_fields=("projection",)) + other._projection = mock.sentinel.other_projection + assert not (query == other) - # The only different should be in ``attr_name``. - for attr_name in attr_names: - attrs1.pop(attr_name) - attrs2.pop(attr_name) - for key, value in attrs1.items(): - self.assertIs(value, attrs2[key]) +def test_basequery___eq___different_field_filters(): + parent = mock.sentinel.parent + query = _make_base_query_all_fields(parent=parent, skip_fields=("field_filters",)) + query._field_filters = mock.sentinel.field_filters + other = _make_base_query_all_fields(parent=parent, skip_fields=("field_filters",)) + other._field_filters = mock.sentinel.other_field_filters + assert not (query == other) - @staticmethod - def _make_projection_for_select(field_paths): - from google.cloud.firestore_v1.types import query - return query.StructuredQuery.Projection( - fields=[ - query.StructuredQuery.FieldReference(field_path=field_path) - for field_path in field_paths - ] - ) +def test_basequery___eq___different_orders(): + parent = mock.sentinel.parent + query = _make_base_query_all_fields(parent=parent, skip_fields=("orders",)) + query._orders = mock.sentinel.orders + other = _make_base_query_all_fields(parent=parent, skip_fields=("orders",)) + other._orders = mock.sentinel.other_orders + assert not (query == other) - def test_select_invalid_path(self): - query = 
self._make_one(mock.sentinel.parent) - with self.assertRaises(ValueError): - query.select(["*"]) +def test_basequery___eq___different_limit(): + parent = mock.sentinel.parent + query = _make_base_query_all_fields(parent=parent, limit=10) + other = _make_base_query_all_fields(parent=parent, limit=20) + assert not (query == other) - def test_select(self): - query1 = self._make_one_all_fields(all_descendants=True) - field_paths2 = ["foo", "bar"] - query2 = query1.select(field_paths2) - self.assertIsNot(query2, query1) - self.assertIsInstance(query2, self._get_target_class()) - self.assertEqual( - query2._projection, self._make_projection_for_select(field_paths2) - ) - self._compare_queries(query1, query2, "_projection") - - # Make sure it overrides. - field_paths3 = ["foo.baz"] - query3 = query2.select(field_paths3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual( - query3._projection, self._make_projection_for_select(field_paths3) - ) - self._compare_queries(query2, query3, "_projection") +def test_basequery___eq___different_offset(): + parent = mock.sentinel.parent + query = _make_base_query_all_fields(parent=parent, offset=10) + other = _make_base_query_all_fields(parent=parent, offset=20) + assert not (query == other) - def test_where_invalid_path(self): - query = self._make_one(mock.sentinel.parent) - with self.assertRaises(ValueError): - query.where("*", "==", 1) +def test_basequery___eq___different_start_at(): + parent = mock.sentinel.parent + query = _make_base_query_all_fields(parent=parent, skip_fields=("start_at",)) + query._start_at = mock.sentinel.start_at + other = _make_base_query_all_fields(parent=parent, skip_fields=("start_at",)) + other._start_at = mock.sentinel.other_start_at + assert not (query == other) - def test_where(self): - from google.cloud.firestore_v1.types import StructuredQuery - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types 
import query - query_inst = self._make_one_all_fields( - skip_fields=("field_filters",), all_descendants=True - ) - new_query = query_inst.where("power.level", ">", 9000) +def test_basequery___eq___different_end_at(): + parent = mock.sentinel.parent + query = _make_base_query_all_fields(parent=parent, skip_fields=("end_at",)) + query._end_at = mock.sentinel.end_at + other = _make_base_query_all_fields(parent=parent, skip_fields=("end_at",)) + other._end_at = mock.sentinel.other_end_at + assert not (query == other) + + +def test_basequery___eq___different_all_descendants(): + parent = mock.sentinel.parent + query = _make_base_query_all_fields(parent=parent, all_descendants=True) + other = _make_base_query_all_fields(parent=parent, all_descendants=False) + assert not (query == other) + + +def test_basequery___eq___hit(): + query = _make_base_query_all_fields() + other = _make_base_query_all_fields() + assert query == other + + +def _compare_queries(query1, query2, *attr_names): + attrs1 = query1.__dict__.copy() + attrs2 = query2.__dict__.copy() + + assert len(attrs1) == len(attrs2) + + # The only different should be in ``attr_name``. + for attr_name in attr_names: + attrs1.pop(attr_name) + attrs2.pop(attr_name) + + for key, value in attrs1.items(): + assert value is attrs2[key] + + +def test_basequery_select_invalid_path(): + query = _make_base_query(mock.sentinel.parent) + + with pytest.raises(ValueError): + query.select(["*"]) + + +def test_basequery_select(): + from google.cloud.firestore_v1.base_query import BaseQuery + + query1 = _make_base_query_all_fields(all_descendants=True) + + field_paths2 = ["foo", "bar"] + query2 = query1.select(field_paths2) + assert query2 is not query1 + assert isinstance(query2, BaseQuery) + assert query2._projection == _make_projection_for_select(field_paths2) + _compare_queries(query1, query2, "_projection") + + # Make sure it overrides. 
+ field_paths3 = ["foo.baz"] + query3 = query2.select(field_paths3) + assert query3 is not query2 + assert isinstance(query3, BaseQuery) + assert query3._projection == _make_projection_for_select(field_paths3) + _compare_queries(query2, query3, "_projection") + + +def test_basequery_where_invalid_path(): + query = _make_base_query(mock.sentinel.parent) + + with pytest.raises(ValueError): + query.where("*", "==", 1) + + +def test_basequery_where(): + from google.cloud.firestore_v1.base_query import BaseQuery + from google.cloud.firestore_v1.types import StructuredQuery + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import query + + query_inst = _make_base_query_all_fields( + skip_fields=("field_filters",), all_descendants=True + ) + new_query = query_inst.where("power.level", ">", 9000) + + assert query_inst is not new_query + assert isinstance(new_query, BaseQuery) + assert len(new_query._field_filters) == 1 + + field_pb = new_query._field_filters[0] + expected_pb = query.StructuredQuery.FieldFilter( + field=query.StructuredQuery.FieldReference(field_path="power.level"), + op=StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=document.Value(integer_value=9000), + ) + assert field_pb == expected_pb + _compare_queries(query_inst, new_query, "_field_filters") + + +def _where_unary_helper(value, op_enum, op_string="=="): + from google.cloud.firestore_v1.base_query import BaseQuery + from google.cloud.firestore_v1.types import StructuredQuery + + query_inst = _make_base_query_all_fields(skip_fields=("field_filters",)) + field_path = "feeeld" + new_query = query_inst.where(field_path, op_string, value) + + assert query_inst is not new_query + assert isinstance(new_query, BaseQuery) + assert len(new_query._field_filters) == 1 + + field_pb = new_query._field_filters[0] + expected_pb = StructuredQuery.UnaryFilter( + field=StructuredQuery.FieldReference(field_path=field_path), op=op_enum + ) + assert field_pb == 
expected_pb + _compare_queries(query_inst, new_query, "_field_filters") + + +def test_basequery_where_eq_null(): + from google.cloud.firestore_v1.types import StructuredQuery + + op_enum = StructuredQuery.UnaryFilter.Operator.IS_NULL + _where_unary_helper(None, op_enum) + + +def test_basequery_where_gt_null(): + with pytest.raises(ValueError): + _where_unary_helper(None, 0, op_string=">") + + +def test_basequery_where_eq_nan(): + from google.cloud.firestore_v1.types import StructuredQuery + + op_enum = StructuredQuery.UnaryFilter.Operator.IS_NAN + _where_unary_helper(float("nan"), op_enum) + + +def test_basequery_where_le_nan(): + with pytest.raises(ValueError): + _where_unary_helper(float("nan"), 0, op_string="<=") + + +def test_basequery_where_w_delete(): + from google.cloud.firestore_v1 import DELETE_FIELD + + with pytest.raises(ValueError): + _where_unary_helper(DELETE_FIELD, 0) + + +def test_basequery_where_w_server_timestamp(): + from google.cloud.firestore_v1 import SERVER_TIMESTAMP + + with pytest.raises(ValueError): + _where_unary_helper(SERVER_TIMESTAMP, 0) + + +def test_basequery_where_w_array_remove(): + from google.cloud.firestore_v1 import ArrayRemove + + with pytest.raises(ValueError): + _where_unary_helper(ArrayRemove([1, 3, 5]), 0) + + +def test_basequery_where_w_array_union(): + from google.cloud.firestore_v1 import ArrayUnion + + with pytest.raises(ValueError): + _where_unary_helper(ArrayUnion([2, 4, 8]), 0) + + +def test_basequery_order_by_invalid_path(): + query = _make_base_query(mock.sentinel.parent) + + with pytest.raises(ValueError): + query.order_by("*") + + +def test_basequery_order_by(): + from google.cloud.firestore_v1.types import StructuredQuery + from google.cloud.firestore_v1.base_query import BaseQuery + + query1 = _make_base_query_all_fields(skip_fields=("orders",), all_descendants=True) + + field_path2 = "a" + query2 = query1.order_by(field_path2) + assert query2 is not query1 + assert isinstance(query2, BaseQuery) + order = 
_make_order_pb(field_path2, StructuredQuery.Direction.ASCENDING) + assert query2._orders == (order,) + _compare_queries(query1, query2, "_orders") - self.assertIsNot(query_inst, new_query) - self.assertIsInstance(new_query, self._get_target_class()) - self.assertEqual(len(new_query._field_filters), 1) + # Make sure it appends to the orders. + field_path3 = "b" + query3 = query2.order_by(field_path3, direction=BaseQuery.DESCENDING) + assert query3 is not query2 + assert isinstance(query3, BaseQuery) + order_pb3 = _make_order_pb(field_path3, StructuredQuery.Direction.DESCENDING) + assert query3._orders == (order, order_pb3) + _compare_queries(query2, query3, "_orders") - field_pb = new_query._field_filters[0] - expected_pb = query.StructuredQuery.FieldFilter( - field=query.StructuredQuery.FieldReference(field_path="power.level"), + +def test_basequery_limit(): + from google.cloud.firestore_v1.base_query import BaseQuery + + query1 = _make_base_query_all_fields(all_descendants=True) + + limit2 = 100 + query2 = query1.limit(limit2) + assert not query2._limit_to_last + assert query2 is not query1 + assert isinstance(query2, BaseQuery) + assert query2._limit == limit2 + _compare_queries(query1, query2, "_limit") + + # Make sure it overrides. + limit3 = 10 + query3 = query2.limit(limit3) + assert query3 is not query2 + assert isinstance(query3, BaseQuery) + assert query3._limit == limit3 + _compare_queries(query2, query3, "_limit") + + +def test_basequery_limit_to_last(): + from google.cloud.firestore_v1.base_query import BaseQuery + + query1 = _make_base_query_all_fields(all_descendants=True) + + limit2 = 100 + query2 = query1.limit_to_last(limit2) + assert query2._limit_to_last + assert query2 is not query1 + assert isinstance(query2, BaseQuery) + assert query2._limit == limit2 + _compare_queries(query1, query2, "_limit", "_limit_to_last") + + # Make sure it overrides. 
+ limit3 = 10 + query3 = query2.limit(limit3) + assert query3 is not query2 + assert isinstance(query3, BaseQuery) + assert query3._limit == limit3 + _compare_queries(query2, query3, "_limit", "_limit_to_last") + + +def test_basequery__resolve_chunk_size(): + # With a global limit + query = _make_client().collection("asdf").limit(5) + assert query._resolve_chunk_size(3, 10) == 2 + assert query._resolve_chunk_size(3, 1) == 1 + assert query._resolve_chunk_size(3, 2) == 2 + + # With no limit + query = _make_client().collection("asdf")._query() + assert query._resolve_chunk_size(3, 10) == 10 + assert query._resolve_chunk_size(3, 1) == 1 + assert query._resolve_chunk_size(3, 2) == 2 + + +def test_basequery_offset(): + from google.cloud.firestore_v1.base_query import BaseQuery + + query1 = _make_base_query_all_fields(all_descendants=True) + + offset2 = 23 + query2 = query1.offset(offset2) + assert query2 is not query1 + assert isinstance(query2, BaseQuery) + assert query2._offset == offset2 + _compare_queries(query1, query2, "_offset") + + # Make sure it overrides. 
+ offset3 = 35 + query3 = query2.offset(offset3) + assert query3 is not query2 + assert isinstance(query3, BaseQuery) + assert query3._offset == offset3 + _compare_queries(query2, query3, "_offset") + + +def test_basequery__cursor_helper_w_dict(): + values = {"a": 7, "b": "foo"} + query1 = _make_base_query(mock.sentinel.parent) + query1._all_descendants = True + query2 = query1._cursor_helper(values, True, True) + + assert query2._parent is mock.sentinel.parent + assert query2._projection is None + assert query2._field_filters == () + assert query2._orders == query1._orders + assert query2._limit is None + assert query2._offset is None + assert query2._end_at is None + assert query2._all_descendants + + cursor, before = query2._start_at + + assert cursor == values + assert before + + +def test_basequery__cursor_helper_w_tuple(): + values = (7, "foo") + query1 = _make_base_query(mock.sentinel.parent) + query2 = query1._cursor_helper(values, False, True) + + assert query2._parent is mock.sentinel.parent + assert query2._projection is None + assert query2._field_filters == () + assert query2._orders == query1._orders + assert query2._limit is None + assert query2._offset is None + assert query2._end_at is None + + cursor, before = query2._start_at + + assert cursor == list(values) + assert not before + + +def test_basequery__cursor_helper_w_list(): + values = [7, "foo"] + query1 = _make_base_query(mock.sentinel.parent) + query2 = query1._cursor_helper(values, True, False) + + assert query2._parent is mock.sentinel.parent + assert query2._projection is None + assert query2._field_filters == () + assert query2._orders == query1._orders + assert query2._limit is None + assert query2._offset is None + assert query2._start_at is None + + cursor, before = query2._end_at + + assert cursor == values + assert cursor == values + assert before + + +def test_basequery__cursor_helper_w_snapshot_wrong_collection(): + values = {"a": 7, "b": "foo"} + docref = _make_docref("there", 
"doc_id") + snapshot = _make_snapshot(docref, values) + collection = _make_collection("here") + query = _make_base_query(collection) + + with pytest.raises(ValueError): + query._cursor_helper(snapshot, False, False) + + +def test_basequery__cursor_helper_w_snapshot_other_collection_all_descendants(): + values = {"a": 7, "b": "foo"} + docref = _make_docref("there", "doc_id") + snapshot = _make_snapshot(docref, values) + collection = _make_collection("here") + query1 = _make_base_query(collection, all_descendants=True) + + query2 = query1._cursor_helper(snapshot, False, False) + + assert query2._parent is collection + assert query2._projection is None + assert query2._field_filters == () + assert query2._orders == () + assert query2._limit is None + assert query2._offset is None + assert query2._start_at is None + + cursor, before = query2._end_at + + assert cursor is snapshot + assert not before + + +def test_basequery__cursor_helper_w_snapshot(): + values = {"a": 7, "b": "foo"} + docref = _make_docref("here", "doc_id") + snapshot = _make_snapshot(docref, values) + collection = _make_collection("here") + query1 = _make_base_query(collection) + + query2 = query1._cursor_helper(snapshot, False, False) + + assert query2._parent is collection + assert query2._projection is None + assert query2._field_filters == () + assert query2._orders == () + assert query2._limit is None + assert query2._offset is None + assert query2._start_at is None + + cursor, before = query2._end_at + + assert cursor is snapshot + assert not before + + +def test_basequery_start_at(): + from google.cloud.firestore_v1.base_query import BaseQuery + + collection = _make_collection("here") + query1 = _make_base_query_all_fields( + parent=collection, skip_fields=("orders",), all_descendants=True + ) + query2 = query1.order_by("hi") + + document_fields3 = {"hi": "mom"} + query3 = query2.start_at(document_fields3) + assert query3 is not query2 + assert isinstance(query3, BaseQuery) + assert 
query3._start_at == (document_fields3, True) + _compare_queries(query2, query3, "_start_at") + + # Make sure it overrides. + query4 = query3.order_by("bye") + values5 = {"hi": "zap", "bye": 88} + docref = _make_docref("here", "doc_id") + document_fields5 = _make_snapshot(docref, values5) + query5 = query4.start_at(document_fields5) + assert query5 is not query4 + assert isinstance(query5, BaseQuery) + assert query5._start_at == (document_fields5, True) + _compare_queries(query4, query5, "_start_at") + + +def test_basequery_start_after(): + from google.cloud.firestore_v1.base_query import BaseQuery + + collection = _make_collection("here") + query1 = _make_base_query_all_fields(parent=collection, skip_fields=("orders",)) + query2 = query1.order_by("down") + + document_fields3 = {"down": 99.75} + query3 = query2.start_after(document_fields3) + assert query3 is not query2 + assert isinstance(query3, BaseQuery) + assert query3._start_at == (document_fields3, False) + _compare_queries(query2, query3, "_start_at") + + # Make sure it overrides. + query4 = query3.order_by("out") + values5 = {"down": 100.25, "out": b"\x00\x01"} + docref = _make_docref("here", "doc_id") + document_fields5 = _make_snapshot(docref, values5) + query5 = query4.start_after(document_fields5) + assert query5 is not query4 + assert isinstance(query5, BaseQuery) + assert query5._start_at == (document_fields5, False) + _compare_queries(query4, query5, "_start_at") + + +def test_basequery_end_before(): + from google.cloud.firestore_v1.base_query import BaseQuery + + collection = _make_collection("here") + query1 = _make_base_query_all_fields(parent=collection, skip_fields=("orders",)) + query2 = query1.order_by("down") + + document_fields3 = {"down": 99.75} + query3 = query2.end_before(document_fields3) + assert query3 is not query2 + assert isinstance(query3, BaseQuery) + assert query3._end_at == (document_fields3, True) + _compare_queries(query2, query3, "_end_at") + + # Make sure it overrides. 
+ query4 = query3.order_by("out") + values5 = {"down": 100.25, "out": b"\x00\x01"} + docref = _make_docref("here", "doc_id") + document_fields5 = _make_snapshot(docref, values5) + query5 = query4.end_before(document_fields5) + assert query5 is not query4 + assert isinstance(query5, BaseQuery) + assert query5._end_at == (document_fields5, True) + _compare_queries(query4, query5, "_end_at") + _compare_queries(query4, query5, "_end_at") + + +def test_basequery_end_at(): + from google.cloud.firestore_v1.base_query import BaseQuery + + collection = _make_collection("here") + query1 = _make_base_query_all_fields(parent=collection, skip_fields=("orders",)) + query2 = query1.order_by("hi") + + document_fields3 = {"hi": "mom"} + query3 = query2.end_at(document_fields3) + assert query3 is not query2 + assert isinstance(query3, BaseQuery) + assert query3._end_at == (document_fields3, False) + _compare_queries(query2, query3, "_end_at") + + # Make sure it overrides. + query4 = query3.order_by("bye") + values5 = {"hi": "zap", "bye": 88} + docref = _make_docref("here", "doc_id") + document_fields5 = _make_snapshot(docref, values5) + query5 = query4.end_at(document_fields5) + assert query5 is not query4 + assert isinstance(query5, BaseQuery) + assert query5._end_at == (document_fields5, False) + _compare_queries(query4, query5, "_end_at") + + +def test_basequery__filters_pb_empty(): + query = _make_base_query(mock.sentinel.parent) + assert len(query._field_filters) == 0 + assert query._filters_pb() is None + + +def test_basequery__filters_pb_single(): + from google.cloud.firestore_v1.types import StructuredQuery + + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import query + + query1 = _make_base_query(mock.sentinel.parent) + query2 = query1.where("x.y", ">", 50.5) + filter_pb = query2._filters_pb() + expected_pb = query.StructuredQuery.Filter( + field_filter=query.StructuredQuery.FieldFilter( + 
field=query.StructuredQuery.FieldReference(field_path="x.y"), op=StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=document.Value(integer_value=9000), + value=document.Value(double_value=50.5), ) - self.assertEqual(field_pb, expected_pb) - self._compare_queries(query_inst, new_query, "_field_filters") + ) + assert filter_pb == expected_pb - def _where_unary_helper(self, value, op_enum, op_string="=="): - from google.cloud.firestore_v1.types import StructuredQuery - query_inst = self._make_one_all_fields(skip_fields=("field_filters",)) - field_path = "feeeld" - new_query = query_inst.where(field_path, op_string, value) +def test_basequery__filters_pb_multi(): + from google.cloud.firestore_v1.types import StructuredQuery - self.assertIsNot(query_inst, new_query) - self.assertIsInstance(new_query, self._get_target_class()) - self.assertEqual(len(new_query._field_filters), 1) + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import query - field_pb = new_query._field_filters[0] - expected_pb = StructuredQuery.UnaryFilter( - field=StructuredQuery.FieldReference(field_path=field_path), op=op_enum + query1 = _make_base_query(mock.sentinel.parent) + query2 = query1.where("x.y", ">", 50.5) + query3 = query2.where("ABC", "==", 123) + + filter_pb = query3._filters_pb() + op_class = StructuredQuery.FieldFilter.Operator + expected_pb = query.StructuredQuery.Filter( + composite_filter=query.StructuredQuery.CompositeFilter( + op=StructuredQuery.CompositeFilter.Operator.AND, + filters=[ + query.StructuredQuery.Filter( + field_filter=query.StructuredQuery.FieldFilter( + field=query.StructuredQuery.FieldReference(field_path="x.y"), + op=op_class.GREATER_THAN, + value=document.Value(double_value=50.5), + ) + ), + query.StructuredQuery.Filter( + field_filter=query.StructuredQuery.FieldFilter( + field=query.StructuredQuery.FieldReference(field_path="ABC"), + op=op_class.EQUAL, + value=document.Value(integer_value=123), + ) + ), + ], 
) - self.assertEqual(field_pb, expected_pb) - self._compare_queries(query_inst, new_query, "_field_filters") + ) + assert filter_pb == expected_pb + - def test_where_eq_null(self): - from google.cloud.firestore_v1.types import StructuredQuery +def test_basequery__normalize_projection_none(): + query = _make_base_query(mock.sentinel.parent) + assert query._normalize_projection(None) is None - op_enum = StructuredQuery.UnaryFilter.Operator.IS_NULL - self._where_unary_helper(None, op_enum) - def test_where_gt_null(self): - with self.assertRaises(ValueError): - self._where_unary_helper(None, 0, op_string=">") +def test_basequery__normalize_projection_empty(): + projection = _make_projection_for_select([]) + query = _make_base_query(mock.sentinel.parent) + normalized = query._normalize_projection(projection) + field_paths = [field_ref.field_path for field_ref in normalized.fields] + assert field_paths == ["__name__"] - def test_where_eq_nan(self): - from google.cloud.firestore_v1.types import StructuredQuery - op_enum = StructuredQuery.UnaryFilter.Operator.IS_NAN - self._where_unary_helper(float("nan"), op_enum) +def test_basequery__normalize_projection_non_empty(): + projection = _make_projection_for_select(["a", "b"]) + query = _make_base_query(mock.sentinel.parent) + assert query._normalize_projection(projection) is projection - def test_where_le_nan(self): - with self.assertRaises(ValueError): - self._where_unary_helper(float("nan"), 0, op_string="<=") - def test_where_w_delete(self): - from google.cloud.firestore_v1 import DELETE_FIELD +def test_basequery__normalize_orders_wo_orders_wo_cursors(): + query = _make_base_query(mock.sentinel.parent) + expected = [] + assert query._normalize_orders() == expected - with self.assertRaises(ValueError): - self._where_unary_helper(DELETE_FIELD, 0) - def test_where_w_server_timestamp(self): - from google.cloud.firestore_v1 import SERVER_TIMESTAMP +def test_basequery__normalize_orders_w_orders_wo_cursors(): + query = 
_make_base_query(mock.sentinel.parent).order_by("a") + expected = [query._make_order("a", "ASCENDING")] + assert query._normalize_orders() == expected - with self.assertRaises(ValueError): - self._where_unary_helper(SERVER_TIMESTAMP, 0) - def test_where_w_array_remove(self): - from google.cloud.firestore_v1 import ArrayRemove +def test_basequery__normalize_orders_wo_orders_w_snapshot_cursor(): + values = {"a": 7, "b": "foo"} + docref = _make_docref("here", "doc_id") + snapshot = _make_snapshot(docref, values) + collection = _make_collection("here") + query = _make_base_query(collection).start_at(snapshot) + expected = [query._make_order("__name__", "ASCENDING")] + assert query._normalize_orders() == expected - with self.assertRaises(ValueError): - self._where_unary_helper(ArrayRemove([1, 3, 5]), 0) - def test_where_w_array_union(self): - from google.cloud.firestore_v1 import ArrayUnion +def test_basequery__normalize_orders_w_name_orders_w_snapshot_cursor(): + values = {"a": 7, "b": "foo"} + docref = _make_docref("here", "doc_id") + snapshot = _make_snapshot(docref, values) + collection = _make_collection("here") + query = ( + _make_base_query(collection) + .order_by("__name__", "DESCENDING") + .start_at(snapshot) + ) + expected = [query._make_order("__name__", "DESCENDING")] + assert query._normalize_orders() == expected + + +def test_basequery__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_exists(): + values = {"a": 7, "b": "foo"} + docref = _make_docref("here", "doc_id") + snapshot = _make_snapshot(docref, values) + collection = _make_collection("here") + query = ( + _make_base_query(collection) + .where("c", "<=", 20) + .order_by("c", "DESCENDING") + .start_at(snapshot) + ) + expected = [ + query._make_order("c", "DESCENDING"), + query._make_order("__name__", "DESCENDING"), + ] + assert query._normalize_orders() == expected + + +def test_basequery__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_where(): + values = {"a": 7, "b": "foo"} + docref = 
_make_docref("here", "doc_id") + snapshot = _make_snapshot(docref, values) + collection = _make_collection("here") + query = _make_base_query(collection).where("c", "<=", 20).end_at(snapshot) + expected = [ + query._make_order("c", "ASCENDING"), + query._make_order("__name__", "ASCENDING"), + ] + assert query._normalize_orders() == expected + + +def test_basequery__normalize_orders_wo_orders_w_snapshot_cursor_w_isnull_where(): + values = {"a": 7, "b": "foo"} + docref = _make_docref("here", "doc_id") + snapshot = _make_snapshot(docref, values) + collection = _make_collection("here") + query = _make_base_query(collection).where("c", "==", None).end_at(snapshot) + expected = [ + query._make_order("__name__", "ASCENDING"), + ] + assert query._normalize_orders() == expected + + +def test_basequery__normalize_orders_w_name_orders_w_none_cursor(): + collection = _make_collection("here") + query = ( + _make_base_query(collection).order_by("__name__", "DESCENDING").start_at(None) + ) + expected = [query._make_order("__name__", "DESCENDING")] + assert query._normalize_orders() == expected + - with self.assertRaises(ValueError): - self._where_unary_helper(ArrayUnion([2, 4, 8]), 0) +def test_basequery__normalize_cursor_none(): + query = _make_base_query(mock.sentinel.parent) + assert query._normalize_cursor(None, query._orders) is None - def test_order_by_invalid_path(self): - query = self._make_one(mock.sentinel.parent) - with self.assertRaises(ValueError): - query.order_by("*") +def test_basequery__normalize_cursor_no_order(): + cursor = ([1], True) + query = _make_base_query(mock.sentinel.parent) - def test_order_by(self): - from google.cloud.firestore_v1.types import StructuredQuery + with pytest.raises(ValueError): + query._normalize_cursor(cursor, query._orders) - klass = self._get_target_class() - query1 = self._make_one_all_fields( - skip_fields=("orders",), all_descendants=True - ) - field_path2 = "a" - query2 = query1.order_by(field_path2) - self.assertIsNot(query2, 
query1) - self.assertIsInstance(query2, klass) - order = _make_order_pb(field_path2, StructuredQuery.Direction.ASCENDING) - self.assertEqual(query2._orders, (order,)) - self._compare_queries(query1, query2, "_orders") - - # Make sure it appends to the orders. - field_path3 = "b" - query3 = query2.order_by(field_path3, direction=klass.DESCENDING) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, klass) - order_pb3 = _make_order_pb(field_path3, StructuredQuery.Direction.DESCENDING) - self.assertEqual(query3._orders, (order, order_pb3)) - self._compare_queries(query2, query3, "_orders") - - def test_limit(self): - query1 = self._make_one_all_fields(all_descendants=True) - - limit2 = 100 - query2 = query1.limit(limit2) - self.assertFalse(query2._limit_to_last) - self.assertIsNot(query2, query1) - self.assertIsInstance(query2, self._get_target_class()) - self.assertEqual(query2._limit, limit2) - self._compare_queries(query1, query2, "_limit") - - # Make sure it overrides. - limit3 = 10 - query3 = query2.limit(limit3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._limit, limit3) - self._compare_queries(query2, query3, "_limit") - - def test_limit_to_last(self): - query1 = self._make_one_all_fields(all_descendants=True) - - limit2 = 100 - query2 = query1.limit_to_last(limit2) - self.assertTrue(query2._limit_to_last) - self.assertIsNot(query2, query1) - self.assertIsInstance(query2, self._get_target_class()) - self.assertEqual(query2._limit, limit2) - self._compare_queries(query1, query2, "_limit", "_limit_to_last") - - # Make sure it overrides. 
- limit3 = 10 - query3 = query2.limit(limit3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._limit, limit3) - self._compare_queries(query2, query3, "_limit", "_limit_to_last") - - def test__resolve_chunk_size(self): - # With a global limit - query = _make_client().collection("asdf").limit(5) - self.assertEqual(query._resolve_chunk_size(3, 10), 2) - self.assertEqual(query._resolve_chunk_size(3, 1), 1) - self.assertEqual(query._resolve_chunk_size(3, 2), 2) - - # With no limit - query = _make_client().collection("asdf")._query() - self.assertEqual(query._resolve_chunk_size(3, 10), 10) - self.assertEqual(query._resolve_chunk_size(3, 1), 1) - self.assertEqual(query._resolve_chunk_size(3, 2), 2) - - def test_offset(self): - query1 = self._make_one_all_fields(all_descendants=True) - - offset2 = 23 - query2 = query1.offset(offset2) - self.assertIsNot(query2, query1) - self.assertIsInstance(query2, self._get_target_class()) - self.assertEqual(query2._offset, offset2) - self._compare_queries(query1, query2, "_offset") - - # Make sure it overrides. 
- offset3 = 35 - query3 = query2.offset(offset3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._offset, offset3) - self._compare_queries(query2, query3, "_offset") - - @staticmethod - def _make_collection(*path, **kw): - from google.cloud.firestore_v1 import collection - - return collection.CollectionReference(*path, **kw) - - @staticmethod - def _make_docref(*path, **kw): - from google.cloud.firestore_v1 import document - - return document.DocumentReference(*path, **kw) - - @staticmethod - def _make_snapshot(docref, values): - from google.cloud.firestore_v1 import document - - return document.DocumentSnapshot(docref, values, True, None, None, None) - - def test__cursor_helper_w_dict(self): - values = {"a": 7, "b": "foo"} - query1 = self._make_one(mock.sentinel.parent) - query1._all_descendants = True - query2 = query1._cursor_helper(values, True, True) - - self.assertIs(query2._parent, mock.sentinel.parent) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, query1._orders) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._end_at) - self.assertTrue(query2._all_descendants) - - cursor, before = query2._start_at - - self.assertEqual(cursor, values) - self.assertTrue(before) - - def test__cursor_helper_w_tuple(self): - values = (7, "foo") - query1 = self._make_one(mock.sentinel.parent) - query2 = query1._cursor_helper(values, False, True) - - self.assertIs(query2._parent, mock.sentinel.parent) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, query1._orders) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._end_at) - - cursor, before = query2._start_at - - self.assertEqual(cursor, list(values)) - self.assertFalse(before) - - def test__cursor_helper_w_list(self): 
- values = [7, "foo"] - query1 = self._make_one(mock.sentinel.parent) - query2 = query1._cursor_helper(values, True, False) - - self.assertIs(query2._parent, mock.sentinel.parent) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, query1._orders) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._start_at) - - cursor, before = query2._end_at - - self.assertEqual(cursor, values) - self.assertIsNot(cursor, values) - self.assertTrue(before) - - def test__cursor_helper_w_snapshot_wrong_collection(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("there", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = self._make_one(collection) - - with self.assertRaises(ValueError): - query._cursor_helper(snapshot, False, False) - - def test__cursor_helper_w_snapshot_other_collection_all_descendants(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("there", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query1 = self._make_one(collection, all_descendants=True) - - query2 = query1._cursor_helper(snapshot, False, False) - - self.assertIs(query2._parent, collection) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, ()) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._start_at) - - cursor, before = query2._end_at - - self.assertIs(cursor, snapshot) - self.assertFalse(before) - - def test__cursor_helper_w_snapshot(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query1 = self._make_one(collection) - - query2 = query1._cursor_helper(snapshot, False, False) - - 
self.assertIs(query2._parent, collection) - self.assertIsNone(query2._projection) - self.assertEqual(query2._field_filters, ()) - self.assertEqual(query2._orders, ()) - self.assertIsNone(query2._limit) - self.assertIsNone(query2._offset) - self.assertIsNone(query2._start_at) - - cursor, before = query2._end_at - - self.assertIs(cursor, snapshot) - self.assertFalse(before) - - def test_start_at(self): - collection = self._make_collection("here") - query1 = self._make_one_all_fields( - parent=collection, skip_fields=("orders",), all_descendants=True - ) - query2 = query1.order_by("hi") - - document_fields3 = {"hi": "mom"} - query3 = query2.start_at(document_fields3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._start_at, (document_fields3, True)) - self._compare_queries(query2, query3, "_start_at") - - # Make sure it overrides. - query4 = query3.order_by("bye") - values5 = {"hi": "zap", "bye": 88} - docref = self._make_docref("here", "doc_id") - document_fields5 = self._make_snapshot(docref, values5) - query5 = query4.start_at(document_fields5) - self.assertIsNot(query5, query4) - self.assertIsInstance(query5, self._get_target_class()) - self.assertEqual(query5._start_at, (document_fields5, True)) - self._compare_queries(query4, query5, "_start_at") - - def test_start_after(self): - collection = self._make_collection("here") - query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) - query2 = query1.order_by("down") - - document_fields3 = {"down": 99.75} - query3 = query2.start_after(document_fields3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._start_at, (document_fields3, False)) - self._compare_queries(query2, query3, "_start_at") - - # Make sure it overrides. 
- query4 = query3.order_by("out") - values5 = {"down": 100.25, "out": b"\x00\x01"} - docref = self._make_docref("here", "doc_id") - document_fields5 = self._make_snapshot(docref, values5) - query5 = query4.start_after(document_fields5) - self.assertIsNot(query5, query4) - self.assertIsInstance(query5, self._get_target_class()) - self.assertEqual(query5._start_at, (document_fields5, False)) - self._compare_queries(query4, query5, "_start_at") - - def test_end_before(self): - collection = self._make_collection("here") - query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) - query2 = query1.order_by("down") - - document_fields3 = {"down": 99.75} - query3 = query2.end_before(document_fields3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._end_at, (document_fields3, True)) - self._compare_queries(query2, query3, "_end_at") - - # Make sure it overrides. - query4 = query3.order_by("out") - values5 = {"down": 100.25, "out": b"\x00\x01"} - docref = self._make_docref("here", "doc_id") - document_fields5 = self._make_snapshot(docref, values5) - query5 = query4.end_before(document_fields5) - self.assertIsNot(query5, query4) - self.assertIsInstance(query5, self._get_target_class()) - self.assertEqual(query5._end_at, (document_fields5, True)) - self._compare_queries(query4, query5, "_end_at") - self._compare_queries(query4, query5, "_end_at") - - def test_end_at(self): - collection = self._make_collection("here") - query1 = self._make_one_all_fields(parent=collection, skip_fields=("orders",)) - query2 = query1.order_by("hi") - - document_fields3 = {"hi": "mom"} - query3 = query2.end_at(document_fields3) - self.assertIsNot(query3, query2) - self.assertIsInstance(query3, self._get_target_class()) - self.assertEqual(query3._end_at, (document_fields3, False)) - self._compare_queries(query2, query3, "_end_at") - - # Make sure it overrides. 
- query4 = query3.order_by("bye") - values5 = {"hi": "zap", "bye": 88} - docref = self._make_docref("here", "doc_id") - document_fields5 = self._make_snapshot(docref, values5) - query5 = query4.end_at(document_fields5) - self.assertIsNot(query5, query4) - self.assertIsInstance(query5, self._get_target_class()) - self.assertEqual(query5._end_at, (document_fields5, False)) - self._compare_queries(query4, query5, "_end_at") - - def test__filters_pb_empty(self): - query = self._make_one(mock.sentinel.parent) - self.assertEqual(len(query._field_filters), 0) - self.assertIsNone(query._filters_pb()) - - def test__filters_pb_single(self): - from google.cloud.firestore_v1.types import StructuredQuery - - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query - - query1 = self._make_one(mock.sentinel.parent) - query2 = query1.where("x.y", ">", 50.5) - filter_pb = query2._filters_pb() - expected_pb = query.StructuredQuery.Filter( +def test_basequery__normalize_cursor_as_list_mismatched_order(): + cursor = ([1, 2], True) + query = _make_base_query(mock.sentinel.parent).order_by("b", "ASCENDING") + + with pytest.raises(ValueError): + query._normalize_cursor(cursor, query._orders) + + +def test_basequery__normalize_cursor_as_dict_mismatched_order(): + cursor = ({"a": 1}, True) + query = _make_base_query(mock.sentinel.parent).order_by("b", "ASCENDING") + + with pytest.raises(ValueError): + query._normalize_cursor(cursor, query._orders) + + +def test_basequery__normalize_cursor_as_dict_extra_orders_ok(): + cursor = ({"name": "Springfield"}, True) + query = _make_base_query(mock.sentinel.parent).order_by("name").order_by("state") + + normalized = query._normalize_cursor(cursor, query._orders) + assert normalized == (["Springfield"], True) + + +def test_basequery__normalize_cursor_extra_orders_ok(): + cursor = (["Springfield"], True) + query = _make_base_query(mock.sentinel.parent).order_by("name").order_by("state") + + 
query._normalize_cursor(cursor, query._orders) + + +def test_basequery__normalize_cursor_w_delete(): + from google.cloud.firestore_v1 import DELETE_FIELD + + cursor = ([DELETE_FIELD], True) + query = _make_base_query(mock.sentinel.parent).order_by("b", "ASCENDING") + + with pytest.raises(ValueError): + query._normalize_cursor(cursor, query._orders) + + +def test_basequery__normalize_cursor_w_server_timestamp(): + from google.cloud.firestore_v1 import SERVER_TIMESTAMP + + cursor = ([SERVER_TIMESTAMP], True) + query = _make_base_query(mock.sentinel.parent).order_by("b", "ASCENDING") + + with pytest.raises(ValueError): + query._normalize_cursor(cursor, query._orders) + + +def test_basequery__normalize_cursor_w_array_remove(): + from google.cloud.firestore_v1 import ArrayRemove + + cursor = ([ArrayRemove([1, 3, 5])], True) + query = _make_base_query(mock.sentinel.parent).order_by("b", "ASCENDING") + + with pytest.raises(ValueError): + query._normalize_cursor(cursor, query._orders) + + +def test_basequery__normalize_cursor_w_array_union(): + from google.cloud.firestore_v1 import ArrayUnion + + cursor = ([ArrayUnion([2, 4, 8])], True) + query = _make_base_query(mock.sentinel.parent).order_by("b", "ASCENDING") + + with pytest.raises(ValueError): + query._normalize_cursor(cursor, query._orders) + + +def test_basequery__normalize_cursor_as_list_hit(): + cursor = ([1], True) + query = _make_base_query(mock.sentinel.parent).order_by("b", "ASCENDING") + + assert query._normalize_cursor(cursor, query._orders) == ([1], True) + + +def test_basequery__normalize_cursor_as_dict_hit(): + cursor = ({"b": 1}, True) + query = _make_base_query(mock.sentinel.parent).order_by("b", "ASCENDING") + + assert query._normalize_cursor(cursor, query._orders) == ([1], True) + + +def test_basequery__normalize_cursor_as_dict_with_dot_key_hit(): + cursor = ({"b.a": 1}, True) + query = _make_base_query(mock.sentinel.parent).order_by("b.a", "ASCENDING") + assert query._normalize_cursor(cursor, 
query._orders) == ([1], True) + + +def test_basequery__normalize_cursor_as_dict_with_inner_data_hit(): + cursor = ({"b": {"a": 1}}, True) + query = _make_base_query(mock.sentinel.parent).order_by("b.a", "ASCENDING") + assert query._normalize_cursor(cursor, query._orders) == ([1], True) + + +def test_basequery__normalize_cursor_as_snapshot_hit(): + values = {"b": 1} + docref = _make_docref("here", "doc_id") + snapshot = _make_snapshot(docref, values) + cursor = (snapshot, True) + collection = _make_collection("here") + query = _make_base_query(collection).order_by("b", "ASCENDING") + + assert query._normalize_cursor(cursor, query._orders) == ([1], True) + + +def test_basequery__normalize_cursor_w___name___w_reference(): + db_string = "projects/my-project/database/(default)" + client = mock.Mock(spec=["_database_string"]) + client._database_string = db_string + parent = mock.Mock(spec=["_path", "_client"]) + parent._client = client + parent._path = ["C"] + query = _make_base_query(parent).order_by("__name__", "ASCENDING") + docref = _make_docref("here", "doc_id") + values = {"a": 7} + snapshot = _make_snapshot(docref, values) + expected = docref + cursor = (snapshot, True) + + assert query._normalize_cursor(cursor, query._orders) == ([expected], True) + + +def test_basequery__normalize_cursor_w___name___wo_slash(): + db_string = "projects/my-project/database/(default)" + client = mock.Mock(spec=["_database_string"]) + client._database_string = db_string + parent = mock.Mock(spec=["_path", "_client", "document"]) + parent._client = client + parent._path = ["C"] + document = parent.document.return_value = mock.Mock(spec=[]) + query = _make_base_query(parent).order_by("__name__", "ASCENDING") + cursor = (["b"], True) + expected = document + + assert query._normalize_cursor(cursor, query._orders) == ([expected], True) + parent.document.assert_called_once_with("b") + + +def test_basequery__to_protobuf_all_fields(): + from google.protobuf import wrappers_pb2 + from 
google.cloud.firestore_v1.types import StructuredQuery + + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import query + + parent = mock.Mock(id="cat", spec=["id"]) + query1 = _make_base_query(parent) + query2 = query1.select(["X", "Y", "Z"]) + query3 = query2.where("Y", ">", 2.5) + query4 = query3.order_by("X") + query5 = query4.limit(17) + query6 = query5.offset(3) + query7 = query6.start_at({"X": 10}) + query8 = query7.end_at({"X": 25}) + + structured_query_pb = query8._to_protobuf() + query_kwargs = { + "from_": [query.StructuredQuery.CollectionSelector(collection_id=parent.id)], + "select": query.StructuredQuery.Projection( + fields=[ + query.StructuredQuery.FieldReference(field_path=field_path) + for field_path in ["X", "Y", "Z"] + ] + ), + "where": query.StructuredQuery.Filter( field_filter=query.StructuredQuery.FieldFilter( - field=query.StructuredQuery.FieldReference(field_path="x.y"), + field=query.StructuredQuery.FieldReference(field_path="Y"), op=StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=document.Value(double_value=50.5), + value=document.Value(double_value=2.5), ) - ) - self.assertEqual(filter_pb, expected_pb) - - def test__filters_pb_multi(self): - from google.cloud.firestore_v1.types import StructuredQuery - - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query - - query1 = self._make_one(mock.sentinel.parent) - query2 = query1.where("x.y", ">", 50.5) - query3 = query2.where("ABC", "==", 123) - - filter_pb = query3._filters_pb() - op_class = StructuredQuery.FieldFilter.Operator - expected_pb = query.StructuredQuery.Filter( - composite_filter=query.StructuredQuery.CompositeFilter( - op=StructuredQuery.CompositeFilter.Operator.AND, - filters=[ - query.StructuredQuery.Filter( - field_filter=query.StructuredQuery.FieldFilter( - field=query.StructuredQuery.FieldReference( - field_path="x.y" - ), - op=op_class.GREATER_THAN, - 
value=document.Value(double_value=50.5), - ) - ), - query.StructuredQuery.Filter( - field_filter=query.StructuredQuery.FieldFilter( - field=query.StructuredQuery.FieldReference( - field_path="ABC" - ), - op=op_class.EQUAL, - value=document.Value(integer_value=123), - ) - ), - ], + ), + "order_by": [_make_order_pb("X", StructuredQuery.Direction.ASCENDING)], + "start_at": query.Cursor( + values=[document.Value(integer_value=10)], before=True + ), + "end_at": query.Cursor(values=[document.Value(integer_value=25)]), + "offset": 3, + "limit": wrappers_pb2.Int32Value(value=17), + } + expected_pb = query.StructuredQuery(**query_kwargs) + assert structured_query_pb == expected_pb + + +def test_basequery__to_protobuf_select_only(): + from google.cloud.firestore_v1.types import query + + parent = mock.Mock(id="cat", spec=["id"]) + query1 = _make_base_query(parent) + field_paths = ["a.b", "a.c", "d"] + query2 = query1.select(field_paths) + + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from_": [query.StructuredQuery.CollectionSelector(collection_id=parent.id)], + "select": query.StructuredQuery.Projection( + fields=[ + query.StructuredQuery.FieldReference(field_path=field_path) + for field_path in field_paths + ] + ), + } + expected_pb = query.StructuredQuery(**query_kwargs) + assert structured_query_pb == expected_pb + + +def test_basequery__to_protobuf_where_only(): + from google.cloud.firestore_v1.types import StructuredQuery + + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import query + + parent = mock.Mock(id="dog", spec=["id"]) + query1 = _make_base_query(parent) + query2 = query1.where("a", "==", u"b") + + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from_": [query.StructuredQuery.CollectionSelector(collection_id=parent.id)], + "where": query.StructuredQuery.Filter( + field_filter=query.StructuredQuery.FieldFilter( + field=query.StructuredQuery.FieldReference(field_path="a"), + 
op=StructuredQuery.FieldFilter.Operator.EQUAL, + value=document.Value(string_value=u"b"), ) - ) - self.assertEqual(filter_pb, expected_pb) - - def test__normalize_projection_none(self): - query = self._make_one(mock.sentinel.parent) - self.assertIsNone(query._normalize_projection(None)) - - def test__normalize_projection_empty(self): - projection = self._make_projection_for_select([]) - query = self._make_one(mock.sentinel.parent) - normalized = query._normalize_projection(projection) - field_paths = [field_ref.field_path for field_ref in normalized.fields] - self.assertEqual(field_paths, ["__name__"]) - - def test__normalize_projection_non_empty(self): - projection = self._make_projection_for_select(["a", "b"]) - query = self._make_one(mock.sentinel.parent) - self.assertIs(query._normalize_projection(projection), projection) - - def test__normalize_orders_wo_orders_wo_cursors(self): - query = self._make_one(mock.sentinel.parent) - expected = [] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_w_orders_wo_cursors(self): - query = self._make_one(mock.sentinel.parent).order_by("a") - expected = [query._make_order("a", "ASCENDING")] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_wo_orders_w_snapshot_cursor(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = self._make_one(collection).start_at(snapshot) - expected = [query._make_order("__name__", "ASCENDING")] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_w_name_orders_w_snapshot_cursor(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = ( - self._make_one(collection) - .order_by("__name__", "DESCENDING") - .start_at(snapshot) - ) 
- expected = [query._make_order("__name__", "DESCENDING")] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_exists(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = ( - self._make_one(collection) - .where("c", "<=", 20) - .order_by("c", "DESCENDING") - .start_at(snapshot) - ) - expected = [ - query._make_order("c", "DESCENDING"), - query._make_order("__name__", "DESCENDING"), - ] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_wo_orders_w_snapshot_cursor_w_neq_where(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = self._make_one(collection).where("c", "<=", 20).end_at(snapshot) - expected = [ - query._make_order("c", "ASCENDING"), - query._make_order("__name__", "ASCENDING"), - ] - self.assertEqual(query._normalize_orders(), expected) - - def test__normalize_orders_wo_orders_w_snapshot_cursor_w_isnull_where(self): - values = {"a": 7, "b": "foo"} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - collection = self._make_collection("here") - query = self._make_one(collection).where("c", "==", None).end_at(snapshot) - expected = [ - query._make_order("__name__", "ASCENDING"), - ] - self.assertEqual(query._normalize_orders(), expected) + ), + } + expected_pb = query.StructuredQuery(**query_kwargs) + assert structured_query_pb == expected_pb - def test__normalize_orders_w_name_orders_w_none_cursor(self): - collection = self._make_collection("here") - query = ( - self._make_one(collection).order_by("__name__", "DESCENDING").start_at(None) - ) - expected = [query._make_order("__name__", "DESCENDING")] - 
self.assertEqual(query._normalize_orders(), expected) - def test__normalize_cursor_none(self): - query = self._make_one(mock.sentinel.parent) - self.assertIsNone(query._normalize_cursor(None, query._orders)) +def test_basequery__to_protobuf_order_by_only(): + from google.cloud.firestore_v1.types import StructuredQuery - def test__normalize_cursor_no_order(self): - cursor = ([1], True) - query = self._make_one(mock.sentinel.parent) + from google.cloud.firestore_v1.types import query - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) + parent = mock.Mock(id="fish", spec=["id"]) + query1 = _make_base_query(parent) + query2 = query1.order_by("abc") - def test__normalize_cursor_as_list_mismatched_order(self): - cursor = ([1, 2], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from_": [query.StructuredQuery.CollectionSelector(collection_id=parent.id)], + "order_by": [_make_order_pb("abc", StructuredQuery.Direction.ASCENDING)], + } + expected_pb = query.StructuredQuery(**query_kwargs) + assert structured_query_pb == expected_pb - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) - def test__normalize_cursor_as_dict_mismatched_order(self): - cursor = ({"a": 1}, True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") +def test_basequery__to_protobuf_start_at_only(): + # NOTE: "only" is wrong since we must have ``order_by`` as well. 
+ from google.cloud.firestore_v1.types import StructuredQuery - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import query - def test__normalize_cursor_as_dict_extra_orders_ok(self): - cursor = ({"name": "Springfield"}, True) - query = self._make_one(mock.sentinel.parent).order_by("name").order_by("state") + parent = mock.Mock(id="phish", spec=["id"]) + query_inst = ( + _make_base_query(parent).order_by("X.Y").start_after({"X": {"Y": u"Z"}}) + ) - normalized = query._normalize_cursor(cursor, query._orders) - self.assertEqual(normalized, (["Springfield"], True)) + structured_query_pb = query_inst._to_protobuf() + query_kwargs = { + "from_": [StructuredQuery.CollectionSelector(collection_id=parent.id)], + "order_by": [_make_order_pb("X.Y", StructuredQuery.Direction.ASCENDING)], + "start_at": query.Cursor(values=[document.Value(string_value=u"Z")]), + } + expected_pb = StructuredQuery(**query_kwargs) + assert structured_query_pb == expected_pb - def test__normalize_cursor_extra_orders_ok(self): - cursor = (["Springfield"], True) - query = self._make_one(mock.sentinel.parent).order_by("name").order_by("state") - query._normalize_cursor(cursor, query._orders) +def test_basequery__to_protobuf_end_at_only(): + # NOTE: "only" is wrong since we must have ``order_by`` as well. 
+ from google.cloud.firestore_v1.types import StructuredQuery - def test__normalize_cursor_w_delete(self): - from google.cloud.firestore_v1 import DELETE_FIELD + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import query - cursor = ([DELETE_FIELD], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + parent = mock.Mock(id="ghoti", spec=["id"]) + query_inst = _make_base_query(parent).order_by("a").end_at({"a": 88}) - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) + structured_query_pb = query_inst._to_protobuf() + query_kwargs = { + "from_": [query.StructuredQuery.CollectionSelector(collection_id=parent.id)], + "order_by": [_make_order_pb("a", StructuredQuery.Direction.ASCENDING)], + "end_at": query.Cursor(values=[document.Value(integer_value=88)]), + } + expected_pb = query.StructuredQuery(**query_kwargs) + assert structured_query_pb == expected_pb - def test__normalize_cursor_w_server_timestamp(self): - from google.cloud.firestore_v1 import SERVER_TIMESTAMP - cursor = ([SERVER_TIMESTAMP], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") +def test_basequery__to_protobuf_offset_only(): + from google.cloud.firestore_v1.types import query - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) + parent = mock.Mock(id="cartt", spec=["id"]) + query1 = _make_base_query(parent) + offset = 14 + query2 = query1.offset(offset) - def test__normalize_cursor_w_array_remove(self): - from google.cloud.firestore_v1 import ArrayRemove + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from_": [query.StructuredQuery.CollectionSelector(collection_id=parent.id)], + "offset": offset, + } + expected_pb = query.StructuredQuery(**query_kwargs) + assert structured_query_pb == expected_pb - cursor = ([ArrayRemove([1, 3, 5])], True) - query = self._make_one(mock.sentinel.parent).order_by("b", 
"ASCENDING") - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) +def test_basequery__to_protobuf_limit_only(): + from google.protobuf import wrappers_pb2 + from google.cloud.firestore_v1.types import query - def test__normalize_cursor_w_array_union(self): - from google.cloud.firestore_v1 import ArrayUnion + parent = mock.Mock(id="donut", spec=["id"]) + query1 = _make_base_query(parent) + limit = 31 + query2 = query1.limit(limit) + + structured_query_pb = query2._to_protobuf() + query_kwargs = { + "from_": [query.StructuredQuery.CollectionSelector(collection_id=parent.id)], + "limit": wrappers_pb2.Int32Value(value=limit), + } + expected_pb = query.StructuredQuery(**query_kwargs) + + assert structured_query_pb == expected_pb + + +def test_basequery_comparator_no_ordering(): + query = _make_base_query(mock.sentinel.parent) + query._orders = [] + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument2") + + sort = query._comparator(doc1, doc2) + assert sort == -1 + + +def test_basequery_comparator_no_ordering_same_id(): + query = _make_base_query(mock.sentinel.parent) + query._orders = [] + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") - cursor = ([ArrayUnion([2, 4, 8])], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument1") - with self.assertRaises(ValueError): - query._normalize_cursor(cursor, query._orders) + sort = query._comparator(doc1, doc2) + assert sort == 0 - def test__normalize_cursor_as_list_hit(self): - cursor = ([1], True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) +def test_basequery_comparator_ordering(): + query = _make_base_query(mock.sentinel.parent) + orderByMock = mock.Mock() + orderByMock.field.field_path = "last" 
+ orderByMock.direction = 1 # ascending + query._orders = [orderByMock] - def test__normalize_cursor_as_dict_hit(self): - cursor = ({"b": 1}, True) - query = self._make_one(mock.sentinel.parent).order_by("b", "ASCENDING") + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + doc1._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "secondlovelace"}, + } + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument2") + doc2._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "lovelace"}, + } - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) + sort = query._comparator(doc1, doc2) + assert sort == 1 - def test__normalize_cursor_as_dict_with_dot_key_hit(self): - cursor = ({"b.a": 1}, True) - query = self._make_one(mock.sentinel.parent).order_by("b.a", "ASCENDING") - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) - def test__normalize_cursor_as_dict_with_inner_data_hit(self): - cursor = ({"b": {"a": 1}}, True) - query = self._make_one(mock.sentinel.parent).order_by("b.a", "ASCENDING") - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) +def test_basequery_comparator_ordering_descending(): + query = _make_base_query(mock.sentinel.parent) + orderByMock = mock.Mock() + orderByMock.field.field_path = "last" + orderByMock.direction = -1 # descending + query._orders = [orderByMock] - def test__normalize_cursor_as_snapshot_hit(self): - values = {"b": 1} - docref = self._make_docref("here", "doc_id") - snapshot = self._make_snapshot(docref, values) - cursor = (snapshot, True) - collection = self._make_collection("here") - query = self._make_one(collection).order_by("b", "ASCENDING") + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + doc1._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "secondlovelace"}, + } + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument2") + doc2._data = { + 
"first": {"stringValue": "Ada"}, + "last": {"stringValue": "lovelace"}, + } - self.assertEqual(query._normalize_cursor(cursor, query._orders), ([1], True)) + sort = query._comparator(doc1, doc2) + assert sort == -1 - def test__normalize_cursor_w___name___w_reference(self): - db_string = "projects/my-project/database/(default)" - client = mock.Mock(spec=["_database_string"]) - client._database_string = db_string - parent = mock.Mock(spec=["_path", "_client"]) - parent._client = client - parent._path = ["C"] - query = self._make_one(parent).order_by("__name__", "ASCENDING") - docref = self._make_docref("here", "doc_id") - values = {"a": 7} - snapshot = self._make_snapshot(docref, values) - expected = docref - cursor = (snapshot, True) - self.assertEqual( - query._normalize_cursor(cursor, query._orders), ([expected], True) - ) +def test_basequery_comparator_missing_order_by_field_in_data_raises(): + query = _make_base_query(mock.sentinel.parent) + orderByMock = mock.Mock() + orderByMock.field.field_path = "last" + orderByMock.direction = 1 # ascending + query._orders = [orderByMock] - def test__normalize_cursor_w___name___wo_slash(self): - db_string = "projects/my-project/database/(default)" - client = mock.Mock(spec=["_database_string"]) - client._database_string = db_string - parent = mock.Mock(spec=["_path", "_client", "document"]) - parent._client = client - parent._path = ["C"] - document = parent.document.return_value = mock.Mock(spec=[]) - query = self._make_one(parent).order_by("__name__", "ASCENDING") - cursor = (["b"], True) - expected = document - - self.assertEqual( - query._normalize_cursor(cursor, query._orders), ([expected], True) - ) - parent.document.assert_called_once_with("b") - - def test__to_protobuf_all_fields(self): - from google.protobuf import wrappers_pb2 - from google.cloud.firestore_v1.types import StructuredQuery - - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query - - parent = 
mock.Mock(id="cat", spec=["id"]) - query1 = self._make_one(parent) - query2 = query1.select(["X", "Y", "Z"]) - query3 = query2.where("Y", ">", 2.5) - query4 = query3.order_by("X") - query5 = query4.limit(17) - query6 = query5.offset(3) - query7 = query6.start_at({"X": 10}) - query8 = query7.end_at({"X": 25}) - - structured_query_pb = query8._to_protobuf() - query_kwargs = { - "from_": [ - query.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "select": query.StructuredQuery.Projection( - fields=[ - query.StructuredQuery.FieldReference(field_path=field_path) - for field_path in ["X", "Y", "Z"] - ] - ), - "where": query.StructuredQuery.Filter( - field_filter=query.StructuredQuery.FieldFilter( - field=query.StructuredQuery.FieldReference(field_path="Y"), - op=StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=document.Value(double_value=2.5), - ) - ), - "order_by": [_make_order_pb("X", StructuredQuery.Direction.ASCENDING)], - "start_at": query.Cursor( - values=[document.Value(integer_value=10)], before=True - ), - "end_at": query.Cursor(values=[document.Value(integer_value=25)]), - "offset": 3, - "limit": wrappers_pb2.Int32Value(value=17), - } - expected_pb = query.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_select_only(self): - from google.cloud.firestore_v1.types import query - - parent = mock.Mock(id="cat", spec=["id"]) - query1 = self._make_one(parent) - field_paths = ["a.b", "a.c", "d"] - query2 = query1.select(field_paths) - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from_": [ - query.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "select": query.StructuredQuery.Projection( - fields=[ - query.StructuredQuery.FieldReference(field_path=field_path) - for field_path in field_paths - ] - ), - } - expected_pb = query.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def 
test__to_protobuf_where_only(self): - from google.cloud.firestore_v1.types import StructuredQuery - - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query - - parent = mock.Mock(id="dog", spec=["id"]) - query1 = self._make_one(parent) - query2 = query1.where("a", "==", u"b") - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from_": [ - query.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "where": query.StructuredQuery.Filter( - field_filter=query.StructuredQuery.FieldFilter( - field=query.StructuredQuery.FieldReference(field_path="a"), - op=StructuredQuery.FieldFilter.Operator.EQUAL, - value=document.Value(string_value=u"b"), - ) - ), - } - expected_pb = query.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_order_by_only(self): - from google.cloud.firestore_v1.types import StructuredQuery - - from google.cloud.firestore_v1.types import query - - parent = mock.Mock(id="fish", spec=["id"]) - query1 = self._make_one(parent) - query2 = query1.order_by("abc") - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from_": [ - query.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "order_by": [_make_order_pb("abc", StructuredQuery.Direction.ASCENDING)], - } - expected_pb = query.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) + doc1 = mock.Mock() + doc1.reference._path = ("col", "adocument1") + doc1._data = {} + doc2 = mock.Mock() + doc2.reference._path = ("col", "adocument2") + doc2._data = { + "first": {"stringValue": "Ada"}, + "last": {"stringValue": "lovelace"}, + } - def test__to_protobuf_start_at_only(self): - # NOTE: "only" is wrong since we must have ``order_by`` as well. 
- from google.cloud.firestore_v1.types import StructuredQuery + with pytest.raises(ValueError) as exc_info: + query._comparator(doc1, doc2) - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query + (message,) = exc_info.value.args + assert message.startswith("Can only compare fields ") - parent = mock.Mock(id="phish", spec=["id"]) - query_inst = ( - self._make_one(parent).order_by("X.Y").start_after({"X": {"Y": u"Z"}}) - ) - structured_query_pb = query_inst._to_protobuf() - query_kwargs = { - "from_": [StructuredQuery.CollectionSelector(collection_id=parent.id)], - "order_by": [_make_order_pb("X.Y", StructuredQuery.Direction.ASCENDING)], - "start_at": query.Cursor(values=[document.Value(string_value=u"Z")]), - } - expected_pb = StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_end_at_only(self): - # NOTE: "only" is wrong since we must have ``order_by`` as well. - from google.cloud.firestore_v1.types import StructuredQuery - - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query - - parent = mock.Mock(id="ghoti", spec=["id"]) - query_inst = self._make_one(parent).order_by("a").end_at({"a": 88}) - - structured_query_pb = query_inst._to_protobuf() - query_kwargs = { - "from_": [ - query.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "order_by": [_make_order_pb("a", StructuredQuery.Direction.ASCENDING)], - "end_at": query.Cursor(values=[document.Value(integer_value=88)]), - } - expected_pb = query.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_offset_only(self): - from google.cloud.firestore_v1.types import query - - parent = mock.Mock(id="cartt", spec=["id"]) - query1 = self._make_one(parent) - offset = 14 - query2 = query1.offset(offset) - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from_": [ - 
query.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "offset": offset, - } - expected_pb = query.StructuredQuery(**query_kwargs) - self.assertEqual(structured_query_pb, expected_pb) - - def test__to_protobuf_limit_only(self): - from google.protobuf import wrappers_pb2 - from google.cloud.firestore_v1.types import query - - parent = mock.Mock(id="donut", spec=["id"]) - query1 = self._make_one(parent) - limit = 31 - query2 = query1.limit(limit) - - structured_query_pb = query2._to_protobuf() - query_kwargs = { - "from_": [ - query.StructuredQuery.CollectionSelector(collection_id=parent.id) - ], - "limit": wrappers_pb2.Int32Value(value=limit), - } - expected_pb = query.StructuredQuery(**query_kwargs) - - self.assertEqual(structured_query_pb, expected_pb) - - def test_comparator_no_ordering(self): - query = self._make_one(mock.sentinel.parent) - query._orders = [] - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument2") - - sort = query._comparator(doc1, doc2) - self.assertEqual(sort, -1) - - def test_comparator_no_ordering_same_id(self): - query = self._make_one(mock.sentinel.parent) - query._orders = [] - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument1") - - sort = query._comparator(doc1, doc2) - self.assertEqual(sort, 0) - - def test_comparator_ordering(self): - query = self._make_one(mock.sentinel.parent) - orderByMock = mock.Mock() - orderByMock.field.field_path = "last" - orderByMock.direction = 1 # ascending - query._orders = [orderByMock] - - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - doc1._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "secondlovelace"}, - } - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument2") - doc2._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "lovelace"}, - } - - sort = 
query._comparator(doc1, doc2) - self.assertEqual(sort, 1) - - def test_comparator_ordering_descending(self): - query = self._make_one(mock.sentinel.parent) - orderByMock = mock.Mock() - orderByMock.field.field_path = "last" - orderByMock.direction = -1 # descending - query._orders = [orderByMock] - - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - doc1._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "secondlovelace"}, - } - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument2") - doc2._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "lovelace"}, - } - - sort = query._comparator(doc1, doc2) - self.assertEqual(sort, -1) - - def test_comparator_missing_order_by_field_in_data_raises(self): - query = self._make_one(mock.sentinel.parent) - orderByMock = mock.Mock() - orderByMock.field.field_path = "last" - orderByMock.direction = 1 # ascending - query._orders = [orderByMock] - - doc1 = mock.Mock() - doc1.reference._path = ("col", "adocument1") - doc1._data = {} - doc2 = mock.Mock() - doc2.reference._path = ("col", "adocument2") - doc2._data = { - "first": {"stringValue": "Ada"}, - "last": {"stringValue": "lovelace"}, - } - - with self.assertRaisesRegex(ValueError, "Can only compare fields "): - query._comparator(doc1, doc2) - - def test_multiple_recursive_calls(self): - query = self._make_one(_make_client().collection("asdf")) - self.assertIsInstance( - query.recursive().recursive(), type(query), - ) +def test_basequery_recursive_multiple(): + from google.cloud.firestore_v1.collection import CollectionReference + from google.cloud.firestore_v1.base_query import BaseQuery + class DerivedQuery(BaseQuery): + @staticmethod + def _get_collection_reference_class(): + return CollectionReference -class Test__enum_from_op_string(unittest.TestCase): - @staticmethod - def _call_fut(op_string): - from google.cloud.firestore_v1.base_query import _enum_from_op_string + query = 
DerivedQuery(_make_client().collection("asdf")) + assert isinstance(query.recursive().recursive(), DerivedQuery) - return _enum_from_op_string(op_string) - @staticmethod - def _get_op_class(): - from google.cloud.firestore_v1.types import StructuredQuery +def _get_op_class(): + from google.cloud.firestore_v1.types import StructuredQuery - return StructuredQuery.FieldFilter.Operator + return StructuredQuery.FieldFilter.Operator - def test_lt(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("<"), op_class.LESS_THAN) - def test_le(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("<="), op_class.LESS_THAN_OR_EQUAL) +def test__enum_from_op_string_lt(): + from google.cloud.firestore_v1.base_query import _enum_from_op_string - def test_eq(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("=="), op_class.EQUAL) + op_class = _get_op_class() + assert _enum_from_op_string("<") == op_class.LESS_THAN - def test_ge(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut(">="), op_class.GREATER_THAN_OR_EQUAL) - def test_gt(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut(">"), op_class.GREATER_THAN) +def test__enum_from_op_string_le(): + from google.cloud.firestore_v1.base_query import _enum_from_op_string - def test_array_contains(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("array_contains"), op_class.ARRAY_CONTAINS) + op_class = _get_op_class() + assert _enum_from_op_string("<=") == op_class.LESS_THAN_OR_EQUAL - def test_in(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("in"), op_class.IN) - def test_array_contains_any(self): - op_class = self._get_op_class() - self.assertEqual( - self._call_fut("array_contains_any"), op_class.ARRAY_CONTAINS_ANY - ) +def test__enum_from_op_string_eq(): + from google.cloud.firestore_v1.base_query import _enum_from_op_string - def test_not_in(self): - op_class = 
self._get_op_class() - self.assertEqual(self._call_fut("not-in"), op_class.NOT_IN) + op_class = _get_op_class() + assert _enum_from_op_string("==") == op_class.EQUAL - def test_not_eq(self): - op_class = self._get_op_class() - self.assertEqual(self._call_fut("!="), op_class.NOT_EQUAL) - def test_invalid(self): - with self.assertRaises(ValueError): - self._call_fut("?") +def test__enum_from_op_string_ge(): + from google.cloud.firestore_v1.base_query import _enum_from_op_string + op_class = _get_op_class() + assert _enum_from_op_string(">=") == op_class.GREATER_THAN_OR_EQUAL -class Test__isnan(unittest.TestCase): - @staticmethod - def _call_fut(value): - from google.cloud.firestore_v1.base_query import _isnan - return _isnan(value) +def test__enum_from_op_string_gt(): + from google.cloud.firestore_v1.base_query import _enum_from_op_string - def test_valid(self): - self.assertTrue(self._call_fut(float("nan"))) + op_class = _get_op_class() + assert _enum_from_op_string(">") == op_class.GREATER_THAN - def test_invalid(self): - self.assertFalse(self._call_fut(51.5)) - self.assertFalse(self._call_fut(None)) - self.assertFalse(self._call_fut("str")) - self.assertFalse(self._call_fut(int)) - self.assertFalse(self._call_fut(1.0 + 1.0j)) +def test__enum_from_op_string_array_contains(): + from google.cloud.firestore_v1.base_query import _enum_from_op_string -class Test__enum_from_direction(unittest.TestCase): - @staticmethod - def _call_fut(direction): - from google.cloud.firestore_v1.base_query import _enum_from_direction + op_class = _get_op_class() + assert _enum_from_op_string("array_contains") == op_class.ARRAY_CONTAINS - return _enum_from_direction(direction) - def test_success(self): - from google.cloud.firestore_v1.types import StructuredQuery +def test__enum_from_op_string_in(): + from google.cloud.firestore_v1.base_query import _enum_from_op_string - from google.cloud.firestore_v1.query import Query + op_class = _get_op_class() + assert _enum_from_op_string("in") == 
op_class.IN - dir_class = StructuredQuery.Direction - self.assertEqual(self._call_fut(Query.ASCENDING), dir_class.ASCENDING) - self.assertEqual(self._call_fut(Query.DESCENDING), dir_class.DESCENDING) - # Ints pass through - self.assertEqual(self._call_fut(dir_class.ASCENDING), dir_class.ASCENDING) - self.assertEqual(self._call_fut(dir_class.DESCENDING), dir_class.DESCENDING) +def test__enum_from_op_string_array_contains_any(): + from google.cloud.firestore_v1.base_query import _enum_from_op_string - def test_failure(self): - with self.assertRaises(ValueError): - self._call_fut("neither-ASCENDING-nor-DESCENDING") + op_class = _get_op_class() + assert _enum_from_op_string("array_contains_any") == op_class.ARRAY_CONTAINS_ANY -class Test__filter_pb(unittest.TestCase): - @staticmethod - def _call_fut(field_or_unary): - from google.cloud.firestore_v1.base_query import _filter_pb +def test__enum_from_op_string_not_in(): + from google.cloud.firestore_v1.base_query import _enum_from_op_string - return _filter_pb(field_or_unary) + op_class = _get_op_class() + assert _enum_from_op_string("not-in") == op_class.NOT_IN - def test_unary(self): - from google.cloud.firestore_v1.types import StructuredQuery - from google.cloud.firestore_v1.types import query +def test__enum_from_op_string_not_eq(): + from google.cloud.firestore_v1.base_query import _enum_from_op_string - unary_pb = query.StructuredQuery.UnaryFilter( - field=query.StructuredQuery.FieldReference(field_path="a.b.c"), - op=StructuredQuery.UnaryFilter.Operator.IS_NULL, - ) - filter_pb = self._call_fut(unary_pb) - expected_pb = query.StructuredQuery.Filter(unary_filter=unary_pb) - self.assertEqual(filter_pb, expected_pb) + op_class = _get_op_class() + assert _enum_from_op_string("!=") == op_class.NOT_EQUAL - def test_field(self): - from google.cloud.firestore_v1.types import StructuredQuery - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import query +def 
test__enum_from_op_string_invalid(): + from google.cloud.firestore_v1.base_query import _enum_from_op_string - field_filter_pb = query.StructuredQuery.FieldFilter( - field=query.StructuredQuery.FieldReference(field_path="XYZ"), - op=StructuredQuery.FieldFilter.Operator.GREATER_THAN, - value=document.Value(double_value=90.75), - ) - filter_pb = self._call_fut(field_filter_pb) - expected_pb = query.StructuredQuery.Filter(field_filter=field_filter_pb) - self.assertEqual(filter_pb, expected_pb) + with pytest.raises(ValueError): + _enum_from_op_string("?") - def test_bad_type(self): - with self.assertRaises(ValueError): - self._call_fut(None) +def test__isnan_valid(): + from google.cloud.firestore_v1.base_query import _isnan -class Test__cursor_pb(unittest.TestCase): - @staticmethod - def _call_fut(cursor_pair): - from google.cloud.firestore_v1.base_query import _cursor_pb + assert _isnan(float("nan")) - return _cursor_pb(cursor_pair) - def test_no_pair(self): - self.assertIsNone(self._call_fut(None)) +def test__isnan_invalid(): + from google.cloud.firestore_v1.base_query import _isnan - def test_success(self): - from google.cloud.firestore_v1.types import query - from google.cloud.firestore_v1 import _helpers + assert not _isnan(51.5) + assert not _isnan(None) + assert not _isnan("str") + assert not _isnan(int) + assert not _isnan(1.0 + 1.0j) - data = [1.5, 10, True] - cursor_pair = data, True - cursor_pb = self._call_fut(cursor_pair) +def test__enum_from_direction_success(): + from google.cloud.firestore_v1.types import StructuredQuery + from google.cloud.firestore_v1.base_query import _enum_from_direction + from google.cloud.firestore_v1.query import Query - expected_pb = query.Cursor( - values=[_helpers.encode_value(value) for value in data], before=True - ) - self.assertEqual(cursor_pb, expected_pb) - - -class Test__query_response_to_snapshot(unittest.TestCase): - @staticmethod - def _call_fut(response_pb, collection, expected_prefix): - from 
google.cloud.firestore_v1.base_query import _query_response_to_snapshot - - return _query_response_to_snapshot(response_pb, collection, expected_prefix) - - def test_empty(self): - response_pb = _make_query_response() - snapshot = self._call_fut(response_pb, None, None) - self.assertIsNone(snapshot) - - def test_after_offset(self): - skipped_results = 410 - response_pb = _make_query_response(skipped_results=skipped_results) - snapshot = self._call_fut(response_pb, None, None) - self.assertIsNone(snapshot) - - def test_response(self): - from google.cloud.firestore_v1.document import DocumentSnapshot - - client = _make_client() - collection = client.collection("a", "b", "c") - _, expected_prefix = collection._parent_info() - - # Create name for the protobuf. - doc_id = "gigantic" - name = "{}/{}".format(expected_prefix, doc_id) - data = {"a": 901, "b": True} - response_pb = _make_query_response(name=name, data=data) - - snapshot = self._call_fut(response_pb, collection, expected_prefix) - self.assertIsInstance(snapshot, DocumentSnapshot) - expected_path = collection._path + (doc_id,) - self.assertEqual(snapshot.reference._path, expected_path) - self.assertEqual(snapshot.to_dict(), data) - self.assertTrue(snapshot.exists) - self.assertEqual(snapshot.read_time, response_pb.read_time) - self.assertEqual(snapshot.create_time, response_pb.document.create_time) - self.assertEqual(snapshot.update_time, response_pb.document.update_time) - - -class Test__collection_group_query_response_to_snapshot(unittest.TestCase): - @staticmethod - def _call_fut(response_pb, collection): - from google.cloud.firestore_v1.base_query import ( - _collection_group_query_response_to_snapshot, - ) + dir_class = StructuredQuery.Direction + assert _enum_from_direction(Query.ASCENDING) == dir_class.ASCENDING + assert _enum_from_direction(Query.DESCENDING) == dir_class.DESCENDING + + # Ints pass through + assert _enum_from_direction(dir_class.ASCENDING) == dir_class.ASCENDING + assert 
_enum_from_direction(dir_class.DESCENDING) == dir_class.DESCENDING + + +def test__enum_from_direction_failure(): + from google.cloud.firestore_v1.base_query import _enum_from_direction + + with pytest.raises(ValueError): + _enum_from_direction("neither-ASCENDING-nor-DESCENDING") + + +def test__filter_pb_unary(): + from google.cloud.firestore_v1.types import StructuredQuery + from google.cloud.firestore_v1.base_query import _filter_pb + from google.cloud.firestore_v1.types import query + + unary_pb = query.StructuredQuery.UnaryFilter( + field=query.StructuredQuery.FieldReference(field_path="a.b.c"), + op=StructuredQuery.UnaryFilter.Operator.IS_NULL, + ) + filter_pb = _filter_pb(unary_pb) + expected_pb = query.StructuredQuery.Filter(unary_filter=unary_pb) + assert filter_pb == expected_pb + + +def test__filter_pb_field(): + from google.cloud.firestore_v1.types import StructuredQuery + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import query + from google.cloud.firestore_v1.base_query import _filter_pb + + field_filter_pb = query.StructuredQuery.FieldFilter( + field=query.StructuredQuery.FieldReference(field_path="XYZ"), + op=StructuredQuery.FieldFilter.Operator.GREATER_THAN, + value=document.Value(double_value=90.75), + ) + filter_pb = _filter_pb(field_filter_pb) + expected_pb = query.StructuredQuery.Filter(field_filter=field_filter_pb) + assert filter_pb == expected_pb + + +def test__filter_pb_bad_type(): + from google.cloud.firestore_v1.base_query import _filter_pb + + with pytest.raises(ValueError): + _filter_pb(None) + + +def test__cursor_pb_no_pair(): + from google.cloud.firestore_v1.base_query import _cursor_pb + + assert _cursor_pb(None) is None + + +def test__cursor_pb_success(): + from google.cloud.firestore_v1.types import query + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.base_query import _cursor_pb + + data = [1.5, 10, True] + cursor_pair = data, True + + cursor_pb = 
_cursor_pb(cursor_pair) + + expected_pb = query.Cursor( + values=[_helpers.encode_value(value) for value in data], before=True + ) + assert cursor_pb == expected_pb + + +def test__query_response_to_snapshot_empty(): + from google.cloud.firestore_v1.base_query import _query_response_to_snapshot - return _collection_group_query_response_to_snapshot(response_pb, collection) + response_pb = _make_query_response() + snapshot = _query_response_to_snapshot(response_pb, None, None) + assert snapshot is None - def test_empty(self): - response_pb = _make_query_response() - snapshot = self._call_fut(response_pb, None) - self.assertIsNone(snapshot) - def test_after_offset(self): - skipped_results = 410 - response_pb = _make_query_response(skipped_results=skipped_results) - snapshot = self._call_fut(response_pb, None) - self.assertIsNone(snapshot) +def test__query_response_to_snapshot_after_offset(): + from google.cloud.firestore_v1.base_query import _query_response_to_snapshot - def test_response(self): - from google.cloud.firestore_v1.document import DocumentSnapshot + skipped_results = 410 + response_pb = _make_query_response(skipped_results=skipped_results) + snapshot = _query_response_to_snapshot(response_pb, None, None) + assert snapshot is None - client = _make_client() - collection = client.collection("a", "b", "c") - other_collection = client.collection("a", "b", "d") - to_match = other_collection.document("gigantic") - data = {"a": 901, "b": True} - response_pb = _make_query_response(name=to_match._document_path, data=data) - snapshot = self._call_fut(response_pb, collection) - self.assertIsInstance(snapshot, DocumentSnapshot) - self.assertEqual(snapshot.reference._document_path, to_match._document_path) - self.assertEqual(snapshot.to_dict(), data) - self.assertTrue(snapshot.exists) - self.assertEqual(snapshot.read_time, response_pb._pb.read_time) - self.assertEqual(snapshot.create_time, response_pb._pb.document.create_time) - self.assertEqual(snapshot.update_time, 
response_pb._pb.document.update_time) +def test__query_response_to_snapshot_response(): + from google.cloud.firestore_v1.base_query import _query_response_to_snapshot + from google.cloud.firestore_v1.document import DocumentSnapshot + + client = _make_client() + collection = client.collection("a", "b", "c") + _, expected_prefix = collection._parent_info() + + # Create name for the protobuf. + doc_id = "gigantic" + name = "{}/{}".format(expected_prefix, doc_id) + data = {"a": 901, "b": True} + response_pb = _make_query_response(name=name, data=data) + + snapshot = _query_response_to_snapshot(response_pb, collection, expected_prefix) + assert isinstance(snapshot, DocumentSnapshot) + expected_path = collection._path + (doc_id,) + assert snapshot.reference._path == expected_path + assert snapshot.to_dict() == data + assert snapshot.exists + assert snapshot.read_time == response_pb.read_time + assert snapshot.create_time == response_pb.document.create_time + assert snapshot.update_time == response_pb.document.update_time + + +def test__collection_group_query_response_to_snapshot_empty(): + from google.cloud.firestore_v1.base_query import ( + _collection_group_query_response_to_snapshot, + ) + + response_pb = _make_query_response() + snapshot = _collection_group_query_response_to_snapshot(response_pb, None) + assert snapshot is None + + +def test__collection_group_query_response_to_snapshot_after_offset(): + from google.cloud.firestore_v1.base_query import ( + _collection_group_query_response_to_snapshot, + ) + + skipped_results = 410 + response_pb = _make_query_response(skipped_results=skipped_results) + snapshot = _collection_group_query_response_to_snapshot(response_pb, None) + assert snapshot is None + + +def test__collection_group_query_response_to_snapshot_response(): + from google.cloud.firestore_v1.document import DocumentSnapshot + from google.cloud.firestore_v1.base_query import ( + _collection_group_query_response_to_snapshot, + ) + + client = _make_client() + 
collection = client.collection("a", "b", "c") + other_collection = client.collection("a", "b", "d") + to_match = other_collection.document("gigantic") + data = {"a": 901, "b": True} + response_pb = _make_query_response(name=to_match._document_path, data=data) + + snapshot = _collection_group_query_response_to_snapshot(response_pb, collection) + assert isinstance(snapshot, DocumentSnapshot) + assert snapshot.reference._document_path == to_match._document_path + assert snapshot.to_dict() == data + assert snapshot.exists + assert snapshot.read_time == response_pb._pb.read_time + assert snapshot.create_time == response_pb._pb.document.create_time + assert snapshot.update_time == response_pb._pb.document.update_time def _make_credentials(): @@ -1519,49 +1591,47 @@ def _make_cursor_pb(pair): return query.Cursor(values=value_pbs, before=before) -class TestQueryPartition(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.base_query import QueryPartition - - return QueryPartition - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def test_constructor(self): - partition = self._make_one(mock.sentinel.query, "start", "end") - assert partition._query is mock.sentinel.query - assert partition.start_at == "start" - assert partition.end_at == "end" - - def test_query_begin(self): - partition = self._make_one(DummyQuery("PARENT"), None, "end") - query = partition.query() - assert query._parent == "PARENT" - assert query.all_descendants == "YUP" - assert query.orders == "ORDER" - assert query.start_at is None - assert query.end_at == (["end"], True) - - def test_query_middle(self): - partition = self._make_one(DummyQuery("PARENT"), "start", "end") - query = partition.query() - assert query._parent == "PARENT" - assert query.all_descendants == "YUP" - assert query.orders == "ORDER" - assert query.start_at == (["start"], True) - assert query.end_at == (["end"], True) - - def 
test_query_end(self): - partition = self._make_one(DummyQuery("PARENT"), "start", None) - query = partition.query() - assert query._parent == "PARENT" - assert query.all_descendants == "YUP" - assert query.orders == "ORDER" - assert query.start_at == (["start"], True) - assert query.end_at is None +def _make_query_partition(*args, **kwargs): + from google.cloud.firestore_v1.base_query import QueryPartition + + return QueryPartition(*args, **kwargs) + + +def test_constructor(): + partition = _make_query_partition(mock.sentinel.query, "start", "end") + assert partition._query is mock.sentinel.query + assert partition.start_at == "start" + assert partition.end_at == "end" + + +def test_query_begin(): + partition = _make_query_partition(DummyQuery("PARENT"), None, "end") + query = partition.query() + assert query._parent == "PARENT" + assert query.all_descendants == "YUP" + assert query.orders == "ORDER" + assert query.start_at is None + assert query.end_at == (["end"], True) + + +def test_query_middle(): + partition = _make_query_partition(DummyQuery("PARENT"), "start", "end") + query = partition.query() + assert query._parent == "PARENT" + assert query.all_descendants == "YUP" + assert query.orders == "ORDER" + assert query.start_at == (["start"], True) + assert query.end_at == (["end"], True) + + +def test_query_end(): + partition = _make_query_partition(DummyQuery("PARENT"), "start", None) + query = partition.query() + assert query._parent == "PARENT" + assert query.all_descendants == "YUP" + assert query.orders == "ORDER" + assert query.start_at == (["start"], True) + assert query.end_at is None class DummyQuery: @@ -1576,3 +1646,32 @@ def __init__( self.orders = orders self.start_at = start_at self.end_at = end_at + + +def _make_projection_for_select(field_paths): + from google.cloud.firestore_v1.types import query + + return query.StructuredQuery.Projection( + fields=[ + query.StructuredQuery.FieldReference(field_path=field_path) + for field_path in field_paths 
+ ] + ) + + +def _make_collection(*path, **kw): + from google.cloud.firestore_v1 import collection + + return collection.CollectionReference(*path, **kw) + + +def _make_docref(*path, **kw): + from google.cloud.firestore_v1 import document + + return document.DocumentReference(*path, **kw) + + +def _make_snapshot(docref, values): + from google.cloud.firestore_v1 import document + + return document.DocumentSnapshot(docref, values, True, None, None, None) diff --git a/tests/unit/v1/test_base_transaction.py b/tests/unit/v1/test_base_transaction.py index b0dc527de2b17..db5dbd92a83fc 100644 --- a/tests/unit/v1/test_base_transaction.py +++ b/tests/unit/v1/test_base_transaction.py @@ -12,108 +12,106 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest import mock +import pytest -class TestBaseTransaction(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.base_transaction import BaseTransaction +def _make_base_transaction(*args, **kwargs): + from google.cloud.firestore_v1.base_transaction import BaseTransaction - return BaseTransaction + return BaseTransaction(*args, **kwargs) - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - def test_constructor_defaults(self): - from google.cloud.firestore_v1.transaction import MAX_ATTEMPTS +def test_basetransaction_constructor_defaults(): + from google.cloud.firestore_v1.transaction import MAX_ATTEMPTS - transaction = self._make_one() - self.assertEqual(transaction._max_attempts, MAX_ATTEMPTS) - self.assertFalse(transaction._read_only) - self.assertIsNone(transaction._id) + transaction = _make_base_transaction() + assert transaction._max_attempts == MAX_ATTEMPTS + assert not transaction._read_only + assert transaction._id is None - def test_constructor_explicit(self): - transaction = self._make_one(max_attempts=10, read_only=True) - 
self.assertEqual(transaction._max_attempts, 10) - self.assertTrue(transaction._read_only) - self.assertIsNone(transaction._id) - def test__options_protobuf_read_only(self): - from google.cloud.firestore_v1.types import common +def test_basetransaction_constructor_explicit(): + transaction = _make_base_transaction(max_attempts=10, read_only=True) + assert transaction._max_attempts == 10 + assert transaction._read_only + assert transaction._id is None - transaction = self._make_one(read_only=True) - options_pb = transaction._options_protobuf(None) - expected_pb = common.TransactionOptions( - read_only=common.TransactionOptions.ReadOnly() - ) - self.assertEqual(options_pb, expected_pb) - def test__options_protobuf_read_only_retry(self): - from google.cloud.firestore_v1.base_transaction import _CANT_RETRY_READ_ONLY +def test_basetransaction__options_protobuf_read_only(): + from google.cloud.firestore_v1.types import common - transaction = self._make_one(read_only=True) - retry_id = b"illuminate" + transaction = _make_base_transaction(read_only=True) + options_pb = transaction._options_protobuf(None) + expected_pb = common.TransactionOptions( + read_only=common.TransactionOptions.ReadOnly() + ) + assert options_pb == expected_pb - with self.assertRaises(ValueError) as exc_info: - transaction._options_protobuf(retry_id) - self.assertEqual(exc_info.exception.args, (_CANT_RETRY_READ_ONLY,)) +def test_basetransaction__options_protobuf_read_only_retry(): + from google.cloud.firestore_v1.base_transaction import _CANT_RETRY_READ_ONLY - def test__options_protobuf_read_write(self): - transaction = self._make_one() - options_pb = transaction._options_protobuf(None) - self.assertIsNone(options_pb) + transaction = _make_base_transaction(read_only=True) + retry_id = b"illuminate" - def test__options_protobuf_on_retry(self): - from google.cloud.firestore_v1.types import common + with pytest.raises(ValueError) as exc_info: + transaction._options_protobuf(retry_id) - transaction = 
self._make_one() - retry_id = b"hocus-pocus" - options_pb = transaction._options_protobuf(retry_id) - expected_pb = common.TransactionOptions( - read_write=common.TransactionOptions.ReadWrite(retry_transaction=retry_id) - ) - self.assertEqual(options_pb, expected_pb) + assert exc_info.value.args == (_CANT_RETRY_READ_ONLY,) - def test_in_progress_property(self): - transaction = self._make_one() - self.assertFalse(transaction.in_progress) - transaction._id = b"not-none-bites" - self.assertTrue(transaction.in_progress) - def test_id_property(self): - transaction = self._make_one() - transaction._id = mock.sentinel.eye_dee - self.assertIs(transaction.id, mock.sentinel.eye_dee) +def test_basetransaction__options_protobuf_read_write(): + transaction = _make_base_transaction() + options_pb = transaction._options_protobuf(None) + assert options_pb is None -class Test_Transactional(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.base_transaction import _BaseTransactional +def test_basetransaction__options_protobuf_on_retry(): + from google.cloud.firestore_v1.types import common - return _BaseTransactional + transaction = _make_base_transaction() + retry_id = b"hocus-pocus" + options_pb = transaction._options_protobuf(retry_id) + expected_pb = common.TransactionOptions( + read_write=common.TransactionOptions.ReadWrite(retry_transaction=retry_id) + ) + assert options_pb == expected_pb - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - def test_constructor(self): - wrapped = self._make_one(mock.sentinel.callable_) - self.assertIs(wrapped.to_wrap, mock.sentinel.callable_) - self.assertIsNone(wrapped.current_id) - self.assertIsNone(wrapped.retry_id) +def test_basetransaction_in_progress_property(): + transaction = _make_base_transaction() + assert not transaction.in_progress + transaction._id = b"not-none-bites" + assert transaction.in_progress - def test__reset(self): - 
wrapped = self._make_one(mock.sentinel.callable_) - wrapped.current_id = b"not-none" - wrapped.retry_id = b"also-not" - ret_val = wrapped._reset() - self.assertIsNone(ret_val) +def test_basetransaction_id_property(): + transaction = _make_base_transaction() + transaction._id = mock.sentinel.eye_dee + assert transaction.id is mock.sentinel.eye_dee - self.assertIsNone(wrapped.current_id) - self.assertIsNone(wrapped.retry_id) + +def _make_base_transactional(*args, **kwargs): + from google.cloud.firestore_v1.base_transaction import _BaseTransactional + + return _BaseTransactional(*args, **kwargs) + + +def test_basetransactional_constructor(): + wrapped = _make_base_transactional(mock.sentinel.callable_) + assert wrapped.to_wrap is mock.sentinel.callable_ + assert wrapped.current_id is None + assert wrapped.retry_id is None + + +def test__basetransactional_reset(): + wrapped = _make_base_transactional(mock.sentinel.callable_) + wrapped.current_id = b"not-none" + wrapped.retry_id = b"also-not" + + ret_val = wrapped._reset() + assert ret_val is None + + assert wrapped.current_id is None + assert wrapped.retry_id is None diff --git a/tests/unit/v1/test_batch.py b/tests/unit/v1/test_batch.py index 3e3bef1ad8a36..e69fa558fc382 100644 --- a/tests/unit/v1/test_batch.py +++ b/tests/unit/v1/test_batch.py @@ -12,149 +12,144 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest - import mock - - -class TestWriteBatch(unittest.TestCase): - """Tests the WriteBatch.commit method""" - - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.batch import WriteBatch - - return WriteBatch - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def test_constructor(self): - batch = self._make_one(mock.sentinel.client) - self.assertIs(batch._client, mock.sentinel.client) - self.assertEqual(batch._write_pbs, []) - self.assertIsNone(batch.write_results) - self.assertIsNone(batch.commit_time) - - def _commit_helper(self, retry=None, timeout=None): - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = mock.Mock(spec=["commit"]) - timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) - commit_response = firestore.CommitResponse( - write_results=[write.WriteResult(), write.WriteResult()], - commit_time=timestamp, - ) - firestore_api.commit.return_value = commit_response - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - # Attach the fake GAPIC to a real client. - client = _make_client("grand") - client._firestore_api_internal = firestore_api - - # Actually make a batch with some mutations and call commit(). 
- batch = self._make_one(client) - document1 = client.document("a", "b") - batch.create(document1, {"ten": 10, "buck": "ets"}) - document2 = client.document("c", "d", "e", "f") - batch.delete(document2) - self.assertEqual(len(batch), 2) +import pytest + + +def _make_write_batch(*args, **kwargs): + from google.cloud.firestore_v1.batch import WriteBatch + + return WriteBatch(*args, **kwargs) + + +def test_writebatch_ctor(): + batch = _make_write_batch(mock.sentinel.client) + assert batch._client is mock.sentinel.client + assert batch._write_pbs == [] + assert batch.write_results is None + assert batch.commit_time is None + + +def _commit_helper(retry=None, timeout=None): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.Mock(spec=["commit"]) + timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) + commit_response = firestore.CommitResponse( + write_results=[write.WriteResult(), write.WriteResult()], commit_time=timestamp, + ) + firestore_api.commit.return_value = commit_response + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Attach the fake GAPIC to a real client. + client = _make_client("grand") + client._firestore_api_internal = firestore_api + + # Actually make a batch with some mutations and call commit(). 
+ batch = _make_write_batch(client) + document1 = client.document("a", "b") + batch.create(document1, {"ten": 10, "buck": "ets"}) + document2 = client.document("c", "d", "e", "f") + batch.delete(document2) + assert len(batch) == 2 + write_pbs = batch._write_pbs[::] + + write_results = batch.commit(**kwargs) + assert write_results == list(commit_response.write_results) + assert batch.write_results == write_results + assert batch.commit_time.timestamp_pb() == timestamp + # Make sure batch has no more "changes". + assert batch._write_pbs == [] + + # Verify the mocks. + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +def test_writebatch_commit(): + _commit_helper() + + +def test_writebatch_commit_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + + _commit_helper(retry=retry, timeout=timeout) + + +def test_writebatch_as_context_mgr_wo_error(): + from google.protobuf import timestamp_pb2 + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + + firestore_api = mock.Mock(spec=["commit"]) + timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) + commit_response = firestore.CommitResponse( + write_results=[write.WriteResult(), write.WriteResult()], commit_time=timestamp, + ) + firestore_api.commit.return_value = commit_response + client = _make_client() + client._firestore_api_internal = firestore_api + batch = _make_write_batch(client) + document1 = client.document("a", "b") + document2 = client.document("c", "d", "e", "f") + + with batch as ctx_mgr: + assert ctx_mgr is batch + ctx_mgr.create(document1, {"ten": 10, "buck": "ets"}) + ctx_mgr.delete(document2) write_pbs = batch._write_pbs[::] - write_results = batch.commit(**kwargs) - self.assertEqual(write_results, 
list(commit_response.write_results)) - self.assertEqual(batch.write_results, write_results) - self.assertEqual(batch.commit_time.timestamp_pb(), timestamp) - # Make sure batch has no more "changes". - self.assertEqual(batch._write_pbs, []) - - # Verify the mocks. - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": write_pbs, - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) - - def test_commit(self): - self._commit_helper() - - def test_commit_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - - self._commit_helper(retry=retry, timeout=timeout) - - def test_as_context_mgr_wo_error(self): - from google.protobuf import timestamp_pb2 - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write - - firestore_api = mock.Mock(spec=["commit"]) - timestamp = timestamp_pb2.Timestamp(seconds=1234567, nanos=123456798) - commit_response = firestore.CommitResponse( - write_results=[write.WriteResult(), write.WriteResult()], - commit_time=timestamp, - ) - firestore_api.commit.return_value = commit_response - client = _make_client() - client._firestore_api_internal = firestore_api - batch = self._make_one(client) - document1 = client.document("a", "b") - document2 = client.document("c", "d", "e", "f") - + assert batch.write_results == list(commit_response.write_results) + assert batch.commit_time.timestamp_pb() == timestamp + # Make sure batch has no more "changes". + assert batch._write_pbs == [] + + # Verify the mocks. 
+ firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +def test_writebatch_as_context_mgr_w_error(): + firestore_api = mock.Mock(spec=["commit"]) + client = _make_client() + client._firestore_api_internal = firestore_api + batch = _make_write_batch(client) + document1 = client.document("a", "b") + document2 = client.document("c", "d", "e", "f") + + with pytest.raises(RuntimeError): with batch as ctx_mgr: - self.assertIs(ctx_mgr, batch) ctx_mgr.create(document1, {"ten": 10, "buck": "ets"}) ctx_mgr.delete(document2) - write_pbs = batch._write_pbs[::] - - self.assertEqual(batch.write_results, list(commit_response.write_results)) - self.assertEqual(batch.commit_time.timestamp_pb(), timestamp) - # Make sure batch has no more "changes". - self.assertEqual(batch._write_pbs, []) - - # Verify the mocks. - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": write_pbs, - "transaction": None, - }, - metadata=client._rpc_metadata, - ) - - def test_as_context_mgr_w_error(self): - firestore_api = mock.Mock(spec=["commit"]) - client = _make_client() - client._firestore_api_internal = firestore_api - batch = self._make_one(client) - document1 = client.document("a", "b") - document2 = client.document("c", "d", "e", "f") - - with self.assertRaises(RuntimeError): - with batch as ctx_mgr: - ctx_mgr.create(document1, {"ten": 10, "buck": "ets"}) - ctx_mgr.delete(document2) - raise RuntimeError("testing") - - # batch still has its changes, as _exit_ (and commit) is not invoked - # changes are preserved so commit can be retried - self.assertIsNone(batch.write_results) - self.assertIsNone(batch.commit_time) - self.assertEqual(len(batch._write_pbs), 2) - - firestore_api.commit.assert_not_called() + raise RuntimeError("testing") + + # batch still has its changes, as _exit_ (and commit) is not invoked + 
# changes are preserved so commit can be retried + assert batch.write_results is None + assert batch.commit_time is None + assert len(batch._write_pbs) == 2 + + firestore_api.commit.assert_not_called() def _make_credentials(): diff --git a/tests/unit/v1/test_bulk_batch.py b/tests/unit/v1/test_bulk_batch.py index 20d43b9ccca80..97cd66a417f1f 100644 --- a/tests/unit/v1/test_bulk_batch.py +++ b/tests/unit/v1/test_bulk_batch.py @@ -12,84 +12,78 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import mock -class TestBulkWriteBatch(unittest.TestCase): - """Tests the BulkWriteBatch.commit method""" - - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch - - return BulkWriteBatch - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def test_constructor(self): - batch = self._make_one(mock.sentinel.client) - self.assertIs(batch._client, mock.sentinel.client) - self.assertEqual(batch._write_pbs, []) - self.assertIsNone(batch.write_results) - - def _write_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = mock.Mock(spec=["batch_write"]) - write_response = firestore.BatchWriteResponse( - write_results=[write.WriteResult(), write.WriteResult()], - ) - firestore_api.batch_write.return_value = write_response - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - # Attach the fake GAPIC to a real client. - client = _make_client("grand") - client._firestore_api_internal = firestore_api - - # Actually make a batch with some mutations and call commit(). 
- batch = self._make_one(client) - document1 = client.document("a", "b") - self.assertFalse(document1 in batch) - batch.create(document1, {"ten": 10, "buck": "ets"}) - self.assertTrue(document1 in batch) - document2 = client.document("c", "d", "e", "f") - batch.delete(document2) - write_pbs = batch._write_pbs[::] - - resp = batch.commit(**kwargs) - self.assertEqual(resp.write_results, list(write_response.write_results)) - self.assertEqual(batch.write_results, resp.write_results) - # Make sure batch has no more "changes". - self.assertEqual(batch._write_pbs, []) - - # Verify the mocks. - firestore_api.batch_write.assert_called_once_with( - request={ - "database": client._database_string, - "writes": write_pbs, - "labels": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) - - def test_write(self): - self._write_helper() - - def test_write_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - - self._write_helper(retry=retry, timeout=timeout) +def _make_bulk_write_batch(*args, **kwargs): + from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch + + return BulkWriteBatch(*args, **kwargs) + + +def test_bulkwritebatch_ctor(): + batch = _make_bulk_write_batch(mock.sentinel.client) + assert batch._client is mock.sentinel.client + assert batch._write_pbs == [] + assert batch.write_results is None + + +def _write_helper(retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.Mock(spec=["batch_write"]) + write_response = firestore.BatchWriteResponse( + write_results=[write.WriteResult(), write.WriteResult()], + ) + firestore_api.batch_write.return_value = write_response + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Attach the fake GAPIC to a real client. 
+ client = _make_client("grand") + client._firestore_api_internal = firestore_api + + # Actually make a batch with some mutations and call commit(). + batch = _make_bulk_write_batch(client) + document1 = client.document("a", "b") + assert document1 not in batch + batch.create(document1, {"ten": 10, "buck": "ets"}) + assert document1 in batch + document2 = client.document("c", "d", "e", "f") + batch.delete(document2) + write_pbs = batch._write_pbs[::] + + resp = batch.commit(**kwargs) + assert resp.write_results == list(write_response.write_results) + assert batch.write_results == resp.write_results + # Make sure batch has no more "changes". + assert batch._write_pbs == [] + + # Verify the mocks. + firestore_api.batch_write.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "labels": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +def test_bulkwritebatch_write(): + _write_helper() + + +def test_bulkwritebatch_write_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + + _write_helper(retry=retry, timeout=timeout) def _make_credentials(): diff --git a/tests/unit/v1/test_bulk_writer.py b/tests/unit/v1/test_bulk_writer.py index f39a288551758..dc185d387ec39 100644 --- a/tests/unit/v1/test_bulk_writer.py +++ b/tests/unit/v1/test_bulk_writer.py @@ -13,74 +13,69 @@ # limitations under the License. 
import datetime -import unittest from typing import List, NoReturn, Optional, Tuple, Type -from google.rpc import status_pb2 import aiounittest # type: ignore import mock - -from google.cloud.firestore_v1._helpers import build_timestamp, ExistsOption -from google.cloud.firestore_v1.async_client import AsyncClient -from google.cloud.firestore_v1.base_document import BaseDocumentReference -from google.cloud.firestore_v1.client import Client -from google.cloud.firestore_v1.base_client import BaseClient -from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch -from google.cloud.firestore_v1.bulk_writer import ( - BulkRetry, - BulkWriter, - BulkWriteFailure, - BulkWriterCreateOperation, - BulkWriterOptions, - BulkWriterOperation, - OperationRetry, - SendMode, -) -from google.cloud.firestore_v1.types.firestore import BatchWriteResponse -from google.cloud.firestore_v1.types.write import WriteResult -from tests.unit.v1._test_helpers import FakeThreadPoolExecutor - - -class NoSendBulkWriter(BulkWriter): - """Test-friendly BulkWriter subclass whose `_send` method returns faked - BatchWriteResponse instances and whose _process_response` method stores - those faked instances for later evaluation.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._responses: List[ - Tuple[BulkWriteBatch, BatchWriteResponse, BulkWriterOperation] - ] = [] - self._fail_indices: List[int] = [] - - def _send(self, batch: BulkWriteBatch) -> BatchWriteResponse: - """Generate a fake `BatchWriteResponse` for the supplied batch instead - of actually submitting it to the server. 
- """ - return BatchWriteResponse( - write_results=[ - WriteResult(update_time=build_timestamp()) - if index not in self._fail_indices - else WriteResult() - for index, el in enumerate(batch._document_references.values()) - ], - status=[ - status_pb2.Status(code=0 if index not in self._fail_indices else 1) - for index, el in enumerate(batch._document_references.values()) - ], - ) - - def _process_response( - self, - batch: BulkWriteBatch, - response: BatchWriteResponse, - operations: List[BulkWriterOperation], - ) -> NoReturn: - super()._process_response(batch, response, operations) - self._responses.append((batch, response, operations)) - - def _instantiate_executor(self): - return FakeThreadPoolExecutor() +import pytest + +from google.cloud.firestore_v1 import async_client +from google.cloud.firestore_v1 import client +from google.cloud.firestore_v1 import base_client + + +def _make_no_send_bulk_writer(*args, **kwargs): + from google.rpc import status_pb2 + from google.cloud.firestore_v1._helpers import build_timestamp + from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch + from google.cloud.firestore_v1.bulk_writer import BulkWriter + from google.cloud.firestore_v1.bulk_writer import BulkWriterOperation + from google.cloud.firestore_v1.types.firestore import BatchWriteResponse + from google.cloud.firestore_v1.types.write import WriteResult + from tests.unit.v1._test_helpers import FakeThreadPoolExecutor + + class NoSendBulkWriter(BulkWriter): + """Test-friendly BulkWriter subclass whose `_send` method returns faked + BatchWriteResponse instances and whose _process_response` method stores + those faked instances for later evaluation.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._responses: List[ + Tuple[BulkWriteBatch, BatchWriteResponse, BulkWriterOperation] + ] = [] + self._fail_indices: List[int] = [] + + def _send(self, batch: BulkWriteBatch) -> BatchWriteResponse: + """Generate a fake `BatchWriteResponse` 
for the supplied batch instead + of actually submitting it to the server. + """ + return BatchWriteResponse( + write_results=[ + WriteResult(update_time=build_timestamp()) + if index not in self._fail_indices + else WriteResult() + for index, el in enumerate(batch._document_references.values()) + ], + status=[ + status_pb2.Status(code=0 if index not in self._fail_indices else 1) + for index, el in enumerate(batch._document_references.values()) + ], + ) + + def _process_response( + self, + batch: BulkWriteBatch, + response: BatchWriteResponse, + operations: List[BulkWriterOperation], + ) -> NoReturn: + super()._process_response(batch, response, operations) + self._responses.append((batch, response, operations)) + + def _instantiate_executor(self): + return FakeThreadPoolExecutor() + + return NoSendBulkWriter(*args, **kwargs) def _make_credentials(): @@ -96,8 +91,8 @@ class _SyncClientMixin: _PRESERVES_CLIENT = True @staticmethod - def _make_client() -> Client: - return Client(credentials=_make_credentials(), project="project-id") + def _make_client() -> client.Client: + return client.Client(credentials=_make_credentials(), project="project-id") class _AsyncClientMixin: @@ -107,18 +102,22 @@ class _AsyncClientMixin: _PRESERVES_CLIENT = False @staticmethod - def _make_client() -> AsyncClient: - return AsyncClient(credentials=_make_credentials(), project="project-id") + def _make_client() -> async_client.AsyncClient: + return async_client.AsyncClient( + credentials=_make_credentials(), project="project-id" + ) class _BaseBulkWriterTests: - def _ctor_helper(self, **kw): + def _basebulkwriter_ctor_helper(self, **kw): + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions + client = self._make_client() if not self._PRESERVES_CLIENT: sync_copy = client._sync_copy = object() - bw = NoSendBulkWriter(client, **kw) + bw = _make_no_send_bulk_writer(client, **kw) if self._PRESERVES_CLIENT: assert bw._client is client @@ -130,27 +129,22 @@ def _ctor_helper(self, 
**kw): else: assert bw._options == BulkWriterOptions() - def test_ctor_defaults(self): - self._ctor_helper() + def test_basebulkwriter_ctor_defaults(self): + self._basebulkwriter_ctor_helper() - def test_ctor_explicit(self): - options = BulkWriterOptions(retry=BulkRetry.immediate) - self._ctor_helper(options=options) + def test_basebulkwriter_ctor_explicit(self): + from google.cloud.firestore_v1.bulk_writer import BulkRetry + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions - @staticmethod - def _get_document_reference( - client: BaseClient, - collection_name: Optional[str] = "col", - id: Optional[str] = None, - ) -> Type: - return client.collection(collection_name).document(id) + options = BulkWriterOptions(retry=BulkRetry.immediate) + self._basebulkwriter_ctor_helper(options=options) def _doc_iter(self, client, num: int, ids: Optional[List[str]] = None): for _ in range(num): id: Optional[str] = ids[_] if ids else None - yield self._get_document_reference(client, id=id), {"id": _} + yield _get_document_reference(client, id=id), {"id": _} - def _verify_bw_activity(self, bw: BulkWriter, counts: List[Tuple[int, int]]): + def _verify_bw_activity(self, bw, counts: List[Tuple[int, int]]): """ Args: bw: (BulkWriter) @@ -160,29 +154,26 @@ def _verify_bw_activity(self, bw: BulkWriter, counts: List[Tuple[int, int]]): representing the number of times batches of that size should have been sent. 
""" + from google.cloud.firestore_v1.types.firestore import BatchWriteResponse + total_batches = sum([el[1] for el in counts]) - batches_word = "batches" if total_batches != 1 else "batch" - self.assertEqual( - len(bw._responses), - total_batches, - f"Expected to have sent {total_batches} {batches_word}, but only sent {len(bw._responses)}", - ) + assert len(bw._responses) == total_batches docs_count = {} resp: BatchWriteResponse for _, resp, ops in bw._responses: docs_count.setdefault(len(resp.write_results), 0) docs_count[len(resp.write_results)] += 1 - self.assertEqual(len(docs_count), len(counts)) + assert len(docs_count) == len(counts) for size, num_sent in counts: - self.assertEqual(docs_count[size], num_sent) + assert docs_count[size] == num_sent # Assert flush leaves no operation behind - self.assertEqual(len(bw._operations), 0) + assert len(bw._operations) == 0 - def test_create_calls_send_correctly(self): + def test_basebulkwriter_create_calls_send_correctly(self): client = self._make_client() - bw = NoSendBulkWriter(client) + bw = _make_no_send_bulk_writer(client) for ref, data in self._doc_iter(client, 101): bw.create(ref, data) bw.flush() @@ -190,9 +181,9 @@ def test_create_calls_send_correctly(self): # batch should have been sent once. self._verify_bw_activity(bw, [(20, 5,), (1, 1,)]) - def test_delete_calls_send_correctly(self): + def test_basebulkwriter_delete_calls_send_correctly(self): client = self._make_client() - bw = NoSendBulkWriter(client) + bw = _make_no_send_bulk_writer(client) for ref, _ in self._doc_iter(client, 101): bw.delete(ref) bw.flush() @@ -200,19 +191,19 @@ def test_delete_calls_send_correctly(self): # batch should have been sent once. 
self._verify_bw_activity(bw, [(20, 5,), (1, 1,)]) - def test_delete_separates_batch(self): + def test_basebulkwriter_delete_separates_batch(self): client = self._make_client() - bw = NoSendBulkWriter(client) - ref = self._get_document_reference(client, id="asdf") + bw = _make_no_send_bulk_writer(client) + ref = _get_document_reference(client, id="asdf") bw.create(ref, {}) bw.delete(ref) bw.flush() # Consecutive batches each with 1 operation should have been sent self._verify_bw_activity(bw, [(1, 2,)]) - def test_set_calls_send_correctly(self): + def test_basebulkwriter_set_calls_send_correctly(self): client = self._make_client() - bw = NoSendBulkWriter(client) + bw = _make_no_send_bulk_writer(client) for ref, data in self._doc_iter(client, 101): bw.set(ref, data) bw.flush() @@ -220,9 +211,9 @@ def test_set_calls_send_correctly(self): # batch should have been sent once. self._verify_bw_activity(bw, [(20, 5,), (1, 1,)]) - def test_update_calls_send_correctly(self): + def test_basebulkwriter_update_calls_send_correctly(self): client = self._make_client() - bw = NoSendBulkWriter(client) + bw = _make_no_send_bulk_writer(client) for ref, data in self._doc_iter(client, 101): bw.update(ref, data) bw.flush() @@ -230,10 +221,10 @@ def test_update_calls_send_correctly(self): # batch should have been sent once. self._verify_bw_activity(bw, [(20, 5,), (1, 1,)]) - def test_update_separates_batch(self): + def test_basebulkwriter_update_separates_batch(self): client = self._make_client() - bw = NoSendBulkWriter(client) - ref = self._get_document_reference(client, id="asdf") + bw = _make_no_send_bulk_writer(client) + ref = _get_document_reference(client, id="asdf") bw.create(ref, {}) bw.update(ref, {"field": "value"}) bw.flush() @@ -241,9 +232,15 @@ def test_update_separates_batch(self): # batch should have been sent once. 
self._verify_bw_activity(bw, [(1, 2,)]) - def test_invokes_success_callbacks_successfully(self): + def test_basebulkwriter_invokes_success_callbacks_successfully(self): + from google.cloud.firestore_v1.base_document import BaseDocumentReference + from google.cloud.firestore_v1.bulk_batch import BulkWriteBatch + from google.cloud.firestore_v1.bulk_writer import BulkWriter + from google.cloud.firestore_v1.types.firestore import BatchWriteResponse + from google.cloud.firestore_v1.types.write import WriteResult + client = self._make_client() - bw = NoSendBulkWriter(client) + bw = _make_no_send_bulk_writer(client) bw._fail_indices = [] bw._sent_batches = 0 bw._sent_documents = 0 @@ -267,13 +264,15 @@ def _on_write(ref, result, bulk_writer): bw.create(ref, data) bw.flush() - self.assertEqual(bw._sent_batches, 6) - self.assertEqual(bw._sent_documents, 101) - self.assertEqual(len(bw._operations), 0) + assert bw._sent_batches == 6 + assert bw._sent_documents == 101 + assert len(bw._operations) == 0 + + def test_basebulkwriter_invokes_error_callbacks_successfully(self): + from google.cloud.firestore_v1.bulk_writer import BulkWriteFailure - def test_invokes_error_callbacks_successfully(self): client = self._make_client() - bw = NoSendBulkWriter(client) + bw = _make_no_send_bulk_writer(client) # First document in each batch will "fail" bw._fail_indices = [0] bw._sent_batches = 0 @@ -303,14 +302,18 @@ def _on_error(error, bw) -> bool: bw.create(ref, data) bw.flush() - self.assertEqual(bw._sent_documents, 0) - self.assertEqual(bw._total_retries, times_to_retry) - self.assertEqual(bw._sent_batches, 2) - self.assertEqual(len(bw._operations), 0) + assert bw._sent_documents == 0 + assert bw._total_retries == times_to_retry + assert bw._sent_batches == 2 + assert len(bw._operations) == 0 + + def test_basebulkwriter_invokes_error_callbacks_successfully_multiple_retries(self): + from google.cloud.firestore_v1.bulk_writer import BulkRetry + from google.cloud.firestore_v1.bulk_writer 
import BulkWriteFailure + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions - def test_invokes_error_callbacks_successfully_multiple_retries(self): client = self._make_client() - bw = NoSendBulkWriter( + bw = _make_no_send_bulk_writer( client, options=BulkWriterOptions(retry=BulkRetry.immediate), ) # First document in each batch will "fail" @@ -342,14 +345,17 @@ def _on_error(error, bw) -> bool: bw.create(ref, data) bw.flush() - self.assertEqual(bw._sent_documents, 1) - self.assertEqual(bw._total_retries, times_to_retry) - self.assertEqual(bw._sent_batches, times_to_retry + 1) - self.assertEqual(len(bw._operations), 0) + assert bw._sent_documents == 1 + assert bw._total_retries == times_to_retry + assert bw._sent_batches == times_to_retry + 1 + assert len(bw._operations) == 0 + + def test_basebulkwriter_default_error_handler(self): + from google.cloud.firestore_v1.bulk_writer import BulkRetry + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions - def test_default_error_handler(self): client = self._make_client() - bw = NoSendBulkWriter( + bw = _make_no_send_bulk_writer( client, options=BulkWriterOptions(retry=BulkRetry.immediate), ) bw._attempts = 0 @@ -365,11 +371,15 @@ def _on_error(error, bw): for ref, data in self._doc_iter(client, 1): bw.create(ref, data) bw.flush() - self.assertEqual(bw._attempts, 15) + assert bw._attempts == 15 + + def test_basebulkwriter_handles_errors_and_successes_correctly(self): + from google.cloud.firestore_v1.bulk_writer import BulkRetry + from google.cloud.firestore_v1.bulk_writer import BulkWriteFailure + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions - def test_handles_errors_and_successes_correctly(self): client = self._make_client() - bw = NoSendBulkWriter( + bw = _make_no_send_bulk_writer( client, options=BulkWriterOptions(retry=BulkRetry.immediate), ) # First document in each batch will "fail" @@ -402,14 +412,18 @@ def _on_error(error, bw) -> bool: bw.flush() # 19 
successful writes per batch - self.assertEqual(bw._sent_documents, 38) - self.assertEqual(bw._total_retries, times_to_retry * 2) - self.assertEqual(bw._sent_batches, 4) - self.assertEqual(len(bw._operations), 0) + assert bw._sent_documents == 38 + assert bw._total_retries == times_to_retry * 2 + assert bw._sent_batches == 4 + assert len(bw._operations) == 0 + + def test_basebulkwriter_create_retriable(self): + from google.cloud.firestore_v1.bulk_writer import BulkRetry + from google.cloud.firestore_v1.bulk_writer import BulkWriteFailure + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions - def test_create_retriable(self): client = self._make_client() - bw = NoSendBulkWriter( + bw = _make_no_send_bulk_writer( client, options=BulkWriterOptions(retry=BulkRetry.immediate), ) # First document in each batch will "fail" @@ -430,12 +444,16 @@ def _on_error(error, bw) -> bool: bw.create(ref, data) bw.flush() - self.assertEqual(bw._total_retries, times_to_retry) - self.assertEqual(len(bw._operations), 0) + assert bw._total_retries == times_to_retry + assert len(bw._operations) == 0 + + def test_basebulkwriter_delete_retriable(self): + from google.cloud.firestore_v1.bulk_writer import BulkRetry + from google.cloud.firestore_v1.bulk_writer import BulkWriteFailure + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions - def test_delete_retriable(self): client = self._make_client() - bw = NoSendBulkWriter( + bw = _make_no_send_bulk_writer( client, options=BulkWriterOptions(retry=BulkRetry.immediate), ) # First document in each batch will "fail" @@ -456,12 +474,16 @@ def _on_error(error, bw) -> bool: bw.delete(ref) bw.flush() - self.assertEqual(bw._total_retries, times_to_retry) - self.assertEqual(len(bw._operations), 0) + assert bw._total_retries == times_to_retry + assert len(bw._operations) == 0 + + def test_basebulkwriter_set_retriable(self): + from google.cloud.firestore_v1.bulk_writer import BulkRetry + from 
google.cloud.firestore_v1.bulk_writer import BulkWriteFailure + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions - def test_set_retriable(self): client = self._make_client() - bw = NoSendBulkWriter( + bw = _make_no_send_bulk_writer( client, options=BulkWriterOptions(retry=BulkRetry.immediate), ) # First document in each batch will "fail" @@ -482,12 +504,16 @@ def _on_error(error, bw) -> bool: bw.set(ref, data) bw.flush() - self.assertEqual(bw._total_retries, times_to_retry) - self.assertEqual(len(bw._operations), 0) + assert bw._total_retries == times_to_retry + assert len(bw._operations) == 0 + + def test_basebulkwriter_update_retriable(self): + from google.cloud.firestore_v1.bulk_writer import BulkRetry + from google.cloud.firestore_v1.bulk_writer import BulkWriteFailure + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions - def test_update_retriable(self): client = self._make_client() - bw = NoSendBulkWriter( + bw = _make_no_send_bulk_writer( client, options=BulkWriterOptions(retry=BulkRetry.immediate), ) # First document in each batch will "fail" @@ -508,12 +534,17 @@ def _on_error(error, bw) -> bool: bw.update(ref, data) bw.flush() - self.assertEqual(bw._total_retries, times_to_retry) - self.assertEqual(len(bw._operations), 0) + assert bw._total_retries == times_to_retry + assert len(bw._operations) == 0 + + def test_basebulkwriter_serial_calls_send_correctly(self): + from google.cloud.firestore_v1.bulk_writer import BulkWriterOptions + from google.cloud.firestore_v1.bulk_writer import SendMode - def test_serial_calls_send_correctly(self): client = self._make_client() - bw = NoSendBulkWriter(client, options=BulkWriterOptions(mode=SendMode.serial)) + bw = _make_no_send_bulk_writer( + client, options=BulkWriterOptions(mode=SendMode.serial) + ) for ref, data in self._doc_iter(client, 101): bw.create(ref, data) bw.flush() @@ -521,9 +552,9 @@ def test_serial_calls_send_correctly(self): # batch should have been sent once. 
self._verify_bw_activity(bw, [(20, 5,), (1, 1,)]) - def test_separates_same_document(self): + def test_basebulkwriter_separates_same_document(self): client = self._make_client() - bw = NoSendBulkWriter(client) + bw = _make_no_send_bulk_writer(client) for ref, data in self._doc_iter(client, 2, ["same-id", "same-id"]): bw.create(ref, data) bw.flush() @@ -531,9 +562,9 @@ def test_separates_same_document(self): # Expect to have sent 1-item batches twice. self._verify_bw_activity(bw, [(1, 2,)]) - def test_separates_same_document_different_operation(self): + def test_basebulkwriter_separates_same_document_different_operation(self): client = self._make_client() - bw = NoSendBulkWriter(client) + bw = _make_no_send_bulk_writer(client) for ref, data in self._doc_iter(client, 1, ["same-id"]): bw.create(ref, data) bw.set(ref, data) @@ -542,61 +573,63 @@ def test_separates_same_document_different_operation(self): # Expect to have sent 1-item batches twice. self._verify_bw_activity(bw, [(1, 2,)]) - def test_ensure_sending_repeatedly_callable(self): + def test_basebulkwriter_ensure_sending_repeatedly_callable(self): client = self._make_client() - bw = NoSendBulkWriter(client) + bw = _make_no_send_bulk_writer(client) bw._is_sending = True bw._ensure_sending() - def test_flush_close_repeatedly_callable(self): + def test_basebulkwriter_flush_close_repeatedly_callable(self): client = self._make_client() - bw = NoSendBulkWriter(client) + bw = _make_no_send_bulk_writer(client) bw.flush() bw.flush() bw.close() - def test_flush_sends_in_progress(self): + def test_basebulkwriter_flush_sends_in_progress(self): client = self._make_client() - bw = NoSendBulkWriter(client) - bw.create(self._get_document_reference(client), {"whatever": "you want"}) + bw = _make_no_send_bulk_writer(client) + bw.create(_get_document_reference(client), {"whatever": "you want"}) bw.flush() self._verify_bw_activity(bw, [(1, 1,)]) - def test_flush_sends_all_queued_batches(self): + def 
test_basebulkwriter_flush_sends_all_queued_batches(self): client = self._make_client() - bw = NoSendBulkWriter(client) + bw = _make_no_send_bulk_writer(client) for _ in range(2): - bw.create(self._get_document_reference(client), {"whatever": "you want"}) + bw.create(_get_document_reference(client), {"whatever": "you want"}) bw._queued_batches.append(bw._operations) bw._reset_operations() bw.flush() self._verify_bw_activity(bw, [(1, 2,)]) - def test_cannot_add_after_close(self): + def test_basebulkwriter_cannot_add_after_close(self): client = self._make_client() - bw = NoSendBulkWriter(client) + bw = _make_no_send_bulk_writer(client) bw.close() - self.assertRaises(Exception, bw._verify_not_closed) + with pytest.raises(Exception): + bw._verify_not_closed() - def test_multiple_flushes(self): + def test_basebulkwriter_multiple_flushes(self): client = self._make_client() - bw = NoSendBulkWriter(client) + bw = _make_no_send_bulk_writer(client) bw.flush() bw.flush() - def test_update_raises_with_bad_option(self): + def test_basebulkwriter_update_raises_with_bad_option(self): + from google.cloud.firestore_v1._helpers import ExistsOption + client = self._make_client() - bw = NoSendBulkWriter(client) - self.assertRaises( - ValueError, - bw.update, - self._get_document_reference(client, "id"), - {}, - option=ExistsOption(exists=True), - ) + bw = _make_no_send_bulk_writer(client) + with pytest.raises(ValueError): + bw.update( + _get_document_reference(client, "id"), + {}, + option=ExistsOption(exists=True), + ) -class TestSyncBulkWriter(_SyncClientMixin, _BaseBulkWriterTests, unittest.TestCase): +class TestSyncBulkWriter(_SyncClientMixin, _BaseBulkWriterTests): """All BulkWriters are opaquely async, but this one simulates a BulkWriter dealing with synchronous DocumentReferences.""" @@ -608,58 +641,67 @@ class TestAsyncBulkWriter( dealing with AsyncDocumentReferences.""" -class TestScheduling(unittest.TestCase): - @staticmethod - def _make_client() -> Client: - return 
Client(credentials=_make_credentials(), project="project-id") +def _make_sync_client() -> client.Client: + return client.Client(credentials=_make_credentials(), project="project-id") - def test_max_in_flight_honored(self): - bw = NoSendBulkWriter(self._make_client()) - # Calling this method sets up all the internal timekeeping machinery - bw._rate_limiter.take_tokens(20) - # Now we pretend that all tokens have been consumed. This will force us - # to wait actual, real world milliseconds before being cleared to send more - bw._rate_limiter._available_tokens = 0 +def test_scheduling_max_in_flight_honored(): + bw = _make_no_send_bulk_writer(_make_sync_client()) + # Calling this method sets up all the internal timekeeping machinery + bw._rate_limiter.take_tokens(20) - st = datetime.datetime.now() + # Now we pretend that all tokens have been consumed. This will force us + # to wait actual, real world milliseconds before being cleared to send more + bw._rate_limiter._available_tokens = 0 - # Make a real request, subject to the actual real world clock. - # As this request is 1/10th the per second limit, we should wait ~100ms - bw._request_send(50) + st = datetime.datetime.now() - self.assertGreater( - datetime.datetime.now() - st, datetime.timedelta(milliseconds=90), - ) + # Make a real request, subject to the actual real world clock. 
+ # As this request is 1/10th the per second limit, we should wait ~100ms + bw._request_send(50) - def test_operation_retry_scheduling(self): - now = datetime.datetime.now() - one_second_from_now = now + datetime.timedelta(seconds=1) + assert datetime.datetime.now() - st > datetime.timedelta(milliseconds=90) + + +def test_scheduling_operation_retry_scheduling(): + from google.cloud.firestore_v1.bulk_writer import BulkWriterCreateOperation + from google.cloud.firestore_v1.bulk_writer import OperationRetry + + now = datetime.datetime.now() + one_second_from_now = now + datetime.timedelta(seconds=1) + + db = _make_sync_client() + operation = BulkWriterCreateOperation( + reference=db.collection("asdf").document("asdf"), + document_data={"does.not": "matter"}, + ) + operation2 = BulkWriterCreateOperation( + reference=db.collection("different").document("document"), + document_data={"different": "values"}, + ) + + op1 = OperationRetry(operation=operation, run_at=now) + op2 = OperationRetry(operation=operation2, run_at=now) + op3 = OperationRetry(operation=operation, run_at=one_second_from_now) + + assert op1 < op3 + assert op1 < op3.run_at + assert op2 < op3 + assert op2 < op3.run_at + + # Because these have the same values for `run_at`, neither should conclude + # they are less than the other. It is okay that if we checked them with + # greater-than evaluation, they would return True (because + # @functools.total_ordering flips the result from __lt__). In practice, + # this only arises for actual ties, and we don't care how actual ties are + # ordered as we maintain the sorted list of scheduled retries. 
+ assert not (op1 < op2) + assert not (op2 < op1) - db = self._make_client() - operation = BulkWriterCreateOperation( - reference=db.collection("asdf").document("asdf"), - document_data={"does.not": "matter"}, - ) - operation2 = BulkWriterCreateOperation( - reference=db.collection("different").document("document"), - document_data={"different": "values"}, - ) - op1 = OperationRetry(operation=operation, run_at=now) - op2 = OperationRetry(operation=operation2, run_at=now) - op3 = OperationRetry(operation=operation, run_at=one_second_from_now) - - self.assertLess(op1, op3) - self.assertLess(op1, op3.run_at) - self.assertLess(op2, op3) - self.assertLess(op2, op3.run_at) - - # Because these have the same values for `run_at`, neither should conclude - # they are less than the other. It is okay that if we checked them with - # greater-than evaluation, they would return True (because - # @functools.total_ordering flips the result from __lt__). In practice, - # this only arises for actual ties, and we don't care how actual ties are - # ordered as we maintain the sorted list of scheduled retries. - self.assertFalse(op1 < op2) - self.assertFalse(op2 < op1) +def _get_document_reference( + client: base_client.BaseClient, + collection_name: Optional[str] = "col", + id: Optional[str] = None, +) -> Type: + return client.collection(collection_name).document(id) diff --git a/tests/unit/v1/test_bundle.py b/tests/unit/v1/test_bundle.py index e53e07fe14cff..99803683be3ed 100644 --- a/tests/unit/v1/test_bundle.py +++ b/tests/unit/v1/test_bundle.py @@ -14,23 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import sys import typing -import unittest import mock -from google.cloud.firestore_bundle import BundleElement, FirestoreBundle -from google.cloud.firestore_v1 import _helpers -from google.cloud.firestore_v1.async_collection import AsyncCollectionReference -from google.cloud.firestore_v1.base_query import BaseQuery -from google.cloud.firestore_v1.collection import CollectionReference -from google.cloud.firestore_v1.query import Query -from google.cloud.firestore_v1.services.firestore.client import FirestoreClient -from google.cloud.firestore_v1.types.document import Document -from google.cloud.firestore_v1.types.firestore import RunQueryResponse -from google.protobuf.timestamp_pb2 import Timestamp # type: ignore +import pytest + +from google.cloud.firestore_v1 import base_query +from google.cloud.firestore_v1 import collection +from google.cloud.firestore_v1 import query as query_mod from tests.unit.v1 import _test_helpers -from tests.unit.v1 import test__helpers class _CollectionQueryMixin: @@ -59,13 +51,18 @@ def _bundled_collection_helper( self, document_ids: typing.Optional[typing.List[str]] = None, data: typing.Optional[typing.List[typing.Dict]] = None, - ) -> CollectionReference: + ) -> collection.CollectionReference: """Builder of a mocked Query for the sake of testing Bundles. Bundling queries involves loading the actual documents for cold storage, and this method arranges all of the necessary mocks so that unit tests can think they are evaluating a live query. 
""" + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types.document import Document + from google.cloud.firestore_v1.types.firestore import RunQueryResponse + from google.protobuf.timestamp_pb2 import Timestamp # type: ignore + client = self.get_client() template = client._database_string + "/documents/col/{}" document_ids = document_ids or ["doc-1", "doc-2"] @@ -100,13 +97,13 @@ def _bundled_query_helper( self, document_ids: typing.Optional[typing.List[str]] = None, data: typing.Optional[typing.List[typing.Dict]] = None, - ) -> BaseQuery: + ) -> base_query.BaseQuery: return self._bundled_collection_helper( document_ids=document_ids, data=data, )._query() -class TestBundle(_CollectionQueryMixin, unittest.TestCase): +class TestBundle(_CollectionQueryMixin): @staticmethod def build_results_iterable(items): return iter(items) @@ -117,19 +114,26 @@ def get_client(): @staticmethod def get_internal_client_mock(): - return mock.create_autospec(FirestoreClient) + from google.cloud.firestore_v1.services.firestore import client + + return mock.create_autospec(client.FirestoreClient) @classmethod def get_collection_class(cls): - return CollectionReference + return collection.CollectionReference def test_add_document(self): + from google.cloud.firestore_bundle import FirestoreBundle + bundle = FirestoreBundle("test") doc = _test_helpers.build_document_snapshot(client=_test_helpers.make_client()) bundle.add_document(doc) - self.assertEqual(bundle.documents[self.doc_key].snapshot, doc) + assert bundle.documents[self.doc_key].snapshot == doc def test_add_newer_document(self): + from google.protobuf.timestamp_pb2 import Timestamp # type: ignore + from google.cloud.firestore_bundle import FirestoreBundle + bundle = FirestoreBundle("test") old_doc = _test_helpers.build_document_snapshot( data={"version": 1}, @@ -137,7 +141,7 @@ def test_add_newer_document(self): read_time=Timestamp(seconds=1, nanos=1), ) bundle.add_document(old_doc) - 
self.assertEqual(bundle.documents[self.doc_key].snapshot._data["version"], 1) + assert bundle.documents[self.doc_key].snapshot._data["version"] == 1 # Builds the same ID by default new_doc = _test_helpers.build_document_snapshot( @@ -146,9 +150,12 @@ def test_add_newer_document(self): read_time=Timestamp(seconds=1, nanos=2), ) bundle.add_document(new_doc) - self.assertEqual(bundle.documents[self.doc_key].snapshot._data["version"], 2) + assert bundle.documents[self.doc_key].snapshot._data["version"] == 2 def test_add_older_document(self): + from google.protobuf.timestamp_pb2 import Timestamp # type: ignore + from google.cloud.firestore_bundle import FirestoreBundle + bundle = FirestoreBundle("test") new_doc = _test_helpers.build_document_snapshot( data={"version": 2}, @@ -156,7 +163,7 @@ def test_add_older_document(self): read_time=Timestamp(seconds=1, nanos=2), ) bundle.add_document(new_doc) - self.assertEqual(bundle.documents[self.doc_key].snapshot._data["version"], 2) + assert bundle.documents[self.doc_key].snapshot._data["version"] == 2 # Builds the same ID by default old_doc = _test_helpers.build_document_snapshot( @@ -165,9 +172,11 @@ def test_add_older_document(self): read_time=Timestamp(seconds=1, nanos=1), ) bundle.add_document(old_doc) - self.assertEqual(bundle.documents[self.doc_key].snapshot._data["version"], 2) + assert bundle.documents[self.doc_key].snapshot._data["version"] == 2 def test_add_document_with_different_read_times(self): + from google.cloud.firestore_bundle import FirestoreBundle + bundle = FirestoreBundle("test") doc = _test_helpers.build_document_snapshot( client=_test_helpers.make_client(), @@ -183,147 +192,176 @@ def test_add_document_with_different_read_times(self): ) bundle.add_document(doc) - self.assertEqual( - bundle.documents[self.doc_key].snapshot._data, {"version": 1}, - ) + assert bundle.documents[self.doc_key].snapshot._data == {"version": 1} bundle.add_document(doc_refreshed) - self.assertEqual( - 
bundle.documents[self.doc_key].snapshot._data, {"version": 2}, - ) + assert bundle.documents[self.doc_key].snapshot._data == {"version": 2} def test_add_query(self): + from google.cloud.firestore_bundle import FirestoreBundle + query = self._bundled_query_helper() bundle = FirestoreBundle("test") bundle.add_named_query("asdf", query) - self.assertIsNotNone(bundle.named_queries.get("asdf")) - self.assertIsNotNone( + assert bundle.named_queries.get("asdf") is not None + assert ( bundle.documents[ "projects/project-project/databases/(default)/documents/col/doc-1" ] + is not None ) - self.assertIsNotNone( + assert ( bundle.documents[ "projects/project-project/databases/(default)/documents/col/doc-2" ] + is not None ) def test_add_query_twice(self): + from google.cloud.firestore_bundle import FirestoreBundle + query = self._bundled_query_helper() bundle = FirestoreBundle("test") bundle.add_named_query("asdf", query) - self.assertRaises(ValueError, bundle.add_named_query, "asdf", query) + with pytest.raises(ValueError): + bundle.add_named_query("asdf", query) def test_adding_collection_raises_error(self): + from google.cloud.firestore_bundle import FirestoreBundle + col = self._bundled_collection_helper() bundle = FirestoreBundle("test") - self.assertRaises(ValueError, bundle.add_named_query, "asdf", col) + with pytest.raises(ValueError): + bundle.add_named_query("asdf", col) def test_bundle_build(self): + from google.cloud.firestore_bundle import FirestoreBundle + bundle = FirestoreBundle("test") bundle.add_named_query("best name", self._bundled_query_helper()) - self.assertIsInstance(bundle.build(), str) + assert isinstance(bundle.build(), str) def test_get_documents(self): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_bundle import FirestoreBundle + bundle = FirestoreBundle("test") - query: Query = self._bundled_query_helper() # type: ignore + query: query_mod.Query = self._bundled_query_helper() # type: ignore 
bundle.add_named_query("sweet query", query) docs_iter = _helpers._get_documents_from_bundle( bundle, query_name="sweet query" ) doc = next(docs_iter) - self.assertEqual(doc.id, "doc-1") + assert doc.id == "doc-1" doc = next(docs_iter) - self.assertEqual(doc.id, "doc-2") + assert doc.id == "doc-2" # Now an empty one docs_iter = _helpers._get_documents_from_bundle( bundle, query_name="wrong query" ) doc = next(docs_iter, None) - self.assertIsNone(doc) + assert doc is None def test_get_documents_two_queries(self): + from google.cloud.firestore_bundle import FirestoreBundle + from google.cloud.firestore_v1 import _helpers + bundle = FirestoreBundle("test") - query: Query = self._bundled_query_helper() # type: ignore + query: query_mod.Query = self._bundled_query_helper() # type: ignore bundle.add_named_query("sweet query", query) - query: Query = self._bundled_query_helper(document_ids=["doc-3", "doc-4"]) # type: ignore + query: query_mod.Query = self._bundled_query_helper(document_ids=["doc-3", "doc-4"]) # type: ignore bundle.add_named_query("second query", query) docs_iter = _helpers._get_documents_from_bundle( bundle, query_name="sweet query" ) doc = next(docs_iter) - self.assertEqual(doc.id, "doc-1") + assert doc.id == "doc-1" doc = next(docs_iter) - self.assertEqual(doc.id, "doc-2") + assert doc.id == "doc-2" docs_iter = _helpers._get_documents_from_bundle( bundle, query_name="second query" ) doc = next(docs_iter) - self.assertEqual(doc.id, "doc-3") + assert doc.id == "doc-3" doc = next(docs_iter) - self.assertEqual(doc.id, "doc-4") + assert doc.id == "doc-4" def test_get_document(self): + from google.cloud.firestore_bundle import FirestoreBundle + from google.cloud.firestore_v1 import _helpers + bundle = FirestoreBundle("test") - query: Query = self._bundled_query_helper() # type: ignore + query: query_mod.Query = self._bundled_query_helper() # type: ignore bundle.add_named_query("sweet query", query) - self.assertIsNotNone( + assert ( 
_helpers._get_document_from_bundle( bundle, document_id="projects/project-project/databases/(default)/documents/col/doc-1", - ), + ) + is not None ) - self.assertIsNone( + assert ( _helpers._get_document_from_bundle( bundle, document_id="projects/project-project/databases/(default)/documents/col/doc-0", - ), + ) + is None ) -class TestAsyncBundle(_CollectionQueryMixin, unittest.TestCase): +class TestAsyncBundle(_CollectionQueryMixin): @staticmethod def get_client(): return _test_helpers.make_async_client() @staticmethod def build_results_iterable(items): + from tests.unit.v1 import test__helpers + return test__helpers.AsyncIter(items) @staticmethod def get_internal_client_mock(): + from tests.unit.v1 import test__helpers + return test__helpers.AsyncMock(spec=["run_query"]) @classmethod def get_collection_class(cls): - return AsyncCollectionReference + from google.cloud.firestore_v1 import async_collection + + return async_collection.AsyncCollectionReference def test_async_query(self): # Create an async query, but this test does not need to be # marked as async by pytest because `bundle.add_named_query()` # seemlessly handles accepting async iterables. 
+ from google.cloud.firestore_bundle import FirestoreBundle + async_query = self._bundled_query_helper() bundle = FirestoreBundle("test") bundle.add_named_query("asdf", async_query) - self.assertIsNotNone(bundle.named_queries.get("asdf")) - self.assertIsNotNone( + assert bundle.named_queries.get("asdf") is not None + assert ( bundle.documents[ "projects/project-project/databases/(default)/documents/col/doc-1" ] + is not None ) - self.assertIsNotNone( + assert ( bundle.documents[ "projects/project-project/databases/(default)/documents/col/doc-2" ] + is not None ) -class TestBundleBuilder(_CollectionQueryMixin, unittest.TestCase): +class TestBundleBuilder(_CollectionQueryMixin): @staticmethod def build_results_iterable(items): return iter(items) @@ -334,22 +372,30 @@ def get_client(): @staticmethod def get_internal_client_mock(): - return mock.create_autospec(FirestoreClient) + from google.cloud.firestore_v1.services.firestore import client + + return mock.create_autospec(client.FirestoreClient) @classmethod def get_collection_class(cls): - return CollectionReference + return collection.CollectionReference def test_build_round_trip(self): + from google.cloud.firestore_bundle import FirestoreBundle + from google.cloud.firestore_v1 import _helpers + query = self._bundled_query_helper() bundle = FirestoreBundle("test") bundle.add_named_query("asdf", query) serialized = bundle.build() - self.assertEqual( - serialized, _helpers.deserialize_bundle(serialized, query._client).build(), + assert ( + serialized == _helpers.deserialize_bundle(serialized, query._client).build() ) def test_build_round_trip_emojis(self): + from google.cloud.firestore_bundle import FirestoreBundle + from google.cloud.firestore_v1 import _helpers + smile = "😂" mermaid = "🧜🏿‍♀️" query = self._bundled_query_helper( @@ -360,23 +406,24 @@ def test_build_round_trip_emojis(self): serialized = bundle.build() reserialized_bundle = _helpers.deserialize_bundle(serialized, query._client) - self.assertEqual( + 
assert ( bundle.documents[ "projects/project-project/databases/(default)/documents/col/doc-1" - ].snapshot._data["smile"], - smile, + ].snapshot._data["smile"] + == smile ) - self.assertEqual( + assert ( bundle.documents[ "projects/project-project/databases/(default)/documents/col/doc-2" - ].snapshot._data["compound"], - mermaid, - ) - self.assertEqual( - serialized, reserialized_bundle.build(), + ].snapshot._data["compound"] + == mermaid ) + assert serialized == reserialized_bundle.build() def test_build_round_trip_more_unicode(self): + from google.cloud.firestore_bundle import FirestoreBundle + from google.cloud.firestore_v1 import _helpers + bano = "baño" chinese_characters = "殷周金文集成引得" query = self._bundled_query_helper( @@ -387,23 +434,25 @@ def test_build_round_trip_more_unicode(self): serialized = bundle.build() reserialized_bundle = _helpers.deserialize_bundle(serialized, query._client) - self.assertEqual( + assert ( bundle.documents[ "projects/project-project/databases/(default)/documents/col/doc-1" - ].snapshot._data["bano"], - bano, + ].snapshot._data["bano"] + == bano ) - self.assertEqual( + assert ( bundle.documents[ "projects/project-project/databases/(default)/documents/col/doc-2" - ].snapshot._data["international"], - chinese_characters, - ) - self.assertEqual( - serialized, reserialized_bundle.build(), + ].snapshot._data["international"] + == chinese_characters ) + assert serialized == reserialized_bundle.build() def test_roundtrip_binary_data(self): + import sys + from google.cloud.firestore_bundle import FirestoreBundle + from google.cloud.firestore_v1 import _helpers + query = self._bundled_query_helper(data=[{"binary_data": b"\x0f"}],) bundle = FirestoreBundle("test") bundle.add_named_query("asdf", query) @@ -411,8 +460,8 @@ def test_roundtrip_binary_data(self): reserialized_bundle = _helpers.deserialize_bundle(serialized, query._client) gen = _helpers._get_documents_from_bundle(reserialized_bundle) snapshot = next(gen) - self.assertEqual( - 
int.from_bytes(snapshot._data["binary_data"], byteorder=sys.byteorder), 15, + assert ( + int.from_bytes(snapshot._data["binary_data"], byteorder=sys.byteorder) == 15 ) def test_deserialize_from_seconds_nanos(self): @@ -420,6 +469,7 @@ def test_deserialize_from_seconds_nanos(self): '{"seconds": 123, "nanos": 456}', instead of an ISO-formatted string. This tests deserialization from that format.""" from google.protobuf.json_format import ParseError + from google.cloud.firestore_v1 import _helpers client = _test_helpers.make_client(project_name="fir-bundles-test") @@ -441,13 +491,13 @@ def test_deserialize_from_seconds_nanos(self): + '"updateTime":{"seconds":"1615492486","nanos":34157000}}}' ) - self.assertRaises( - (ValueError, ParseError), # protobuf 3.18.0 raises ParseError - _helpers.deserialize_bundle, - _serialized, - client=client, - ) + with pytest.raises( + (ValueError, ParseError) + ): # protobuf 3.18.0 raises ParseError + _helpers.deserialize_bundle(_serialized, client=client) + # See https://github.com/googleapis/python-firestore/issues/505 + # # The following assertions would test deserialization of NodeJS bundles # were explicit handling of that edge case to be added. @@ -458,50 +508,56 @@ def test_deserialize_from_seconds_nanos(self): # instead of seconds/nanos. # re_serialized = bundle.build() # # Finally, confirm the round trip. 
- # self.assertEqual( - # re_serialized, - # _helpers.deserialize_bundle(re_serialized, client=client).build(), - # ) + # assert re_serialized == _helpers.deserialize_bundle(re_serialized, client=client).build() + # def test_deserialized_bundle_cached_metadata(self): + from google.cloud.firestore_bundle import FirestoreBundle + from google.cloud.firestore_v1 import _helpers + query = self._bundled_query_helper() bundle = FirestoreBundle("test") bundle.add_named_query("asdf", query) bundle_copy = _helpers.deserialize_bundle(bundle.build(), query._client) - self.assertIsInstance(bundle_copy, FirestoreBundle) - self.assertIsNotNone(bundle_copy._deserialized_metadata) + assert isinstance(bundle_copy, FirestoreBundle) + assert bundle_copy._deserialized_metadata is not None bundle_copy.add_named_query("second query", query) - self.assertIsNone(bundle_copy._deserialized_metadata) + assert bundle_copy._deserialized_metadata is None @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") def test_invalid_json(self, fnc): + from google.cloud.firestore_v1 import _helpers + client = _test_helpers.make_client() fnc.return_value = iter([{}]) - self.assertRaises( - ValueError, _helpers.deserialize_bundle, "does not matter", client, - ) + with pytest.raises(ValueError): + _helpers.deserialize_bundle("does not matter", client) @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") def test_not_metadata_first(self, fnc): + from google.cloud.firestore_v1 import _helpers + client = _test_helpers.make_client() fnc.return_value = iter([{"document": {}}]) - self.assertRaises( - ValueError, _helpers.deserialize_bundle, "does not matter", client, - ) + with pytest.raises(ValueError): + _helpers.deserialize_bundle("does not matter", client) @mock.patch("google.cloud.firestore_bundle.FirestoreBundle._add_bundle_element") @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") def test_unexpected_termination(self, fnc, _): + 
from google.cloud.firestore_v1 import _helpers + client = _test_helpers.make_client() # invalid bc `document_metadata` must be followed by a `document` fnc.return_value = [{"metadata": {"id": "asdf"}}, {"documentMetadata": {}}] - self.assertRaises( - ValueError, _helpers.deserialize_bundle, "does not matter", client, - ) + with pytest.raises(ValueError): + _helpers.deserialize_bundle("does not matter", client) @mock.patch("google.cloud.firestore_bundle.FirestoreBundle._add_bundle_element") @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") def test_valid_passes(self, fnc, _): + from google.cloud.firestore_v1 import _helpers + client = _test_helpers.make_client() fnc.return_value = [ {"metadata": {"id": "asdf"}}, @@ -513,46 +569,48 @@ def test_valid_passes(self, fnc, _): @mock.patch("google.cloud.firestore_bundle.FirestoreBundle._add_bundle_element") @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") def test_invalid_bundle(self, fnc, _): + from google.cloud.firestore_v1 import _helpers + client = _test_helpers.make_client() # invalid bc `document` must follow `document_metadata` fnc.return_value = [{"metadata": {"id": "asdf"}}, {"document": {}}] - self.assertRaises( - ValueError, _helpers.deserialize_bundle, "does not matter", client, - ) + with pytest.raises(ValueError): + _helpers.deserialize_bundle("does not matter", client) @mock.patch("google.cloud.firestore_bundle.FirestoreBundle._add_bundle_element") @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") def test_invalid_bundle_element_type(self, fnc, _): + from google.cloud.firestore_v1 import _helpers + client = _test_helpers.make_client() # invalid bc `wtfisthis?` is obviously invalid fnc.return_value = [{"metadata": {"id": "asdf"}}, {"wtfisthis?": {}}] - self.assertRaises( - ValueError, _helpers.deserialize_bundle, "does not matter", client, - ) + with pytest.raises(ValueError): + _helpers.deserialize_bundle("does not matter", 
client) @mock.patch("google.cloud.firestore_bundle.FirestoreBundle._add_bundle_element") @mock.patch("google.cloud.firestore_v1._helpers._parse_bundle_elements_data") def test_invalid_bundle_start(self, fnc, _): + from google.cloud.firestore_v1 import _helpers + client = _test_helpers.make_client() # invalid bc first element must be of key `metadata` fnc.return_value = [{"document": {}}] - self.assertRaises( - ValueError, _helpers.deserialize_bundle, "does not matter", client, - ) + with pytest.raises(ValueError): + _helpers.deserialize_bundle("does not matter", client) def test_not_actually_a_bundle_at_all(self): + from google.cloud.firestore_v1 import _helpers + client = _test_helpers.make_client() - self.assertRaises( - ValueError, _helpers.deserialize_bundle, "{}", client, - ) + with pytest.raises(ValueError): + _helpers.deserialize_bundle("{}", client) def test_add_invalid_bundle_element_type(self): + from google.cloud.firestore_bundle import FirestoreBundle + from google.cloud.firestore_bundle import BundleElement + client = _test_helpers.make_client() bundle = FirestoreBundle("asdf") - self.assertRaises( - ValueError, - bundle._add_bundle_element, - BundleElement(), - client=client, - type="asdf", - ) + with pytest.raises(ValueError): + bundle._add_bundle_element(BundleElement(), client=client, type="asdf") diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index 0c5473fc9756d..67425d4413b4a 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -14,473 +14,525 @@ import datetime import types -import unittest import mock -from google.cloud.firestore_v1.types.document import Document -from google.cloud.firestore_v1.types.firestore import RunQueryResponse +import pytest -class TestClient(unittest.TestCase): +PROJECT = "my-prahjekt" - PROJECT = "my-prahjekt" - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.client import Client +def _make_client(*args, **kwargs): + from 
google.cloud.firestore_v1.client import Client - return Client + return Client(*args, **kwargs) - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - def _make_default_one(self): - credentials = _make_credentials() - return self._make_one(project=self.PROJECT, credentials=credentials) +def _make_credentials(): + import google.auth.credentials - def test_constructor(self): - from google.cloud.firestore_v1.client import _CLIENT_INFO - from google.cloud.firestore_v1.client import DEFAULT_DATABASE + return mock.Mock(spec=google.auth.credentials.Credentials) - credentials = _make_credentials() - client = self._make_one(project=self.PROJECT, credentials=credentials) - self.assertEqual(client.project, self.PROJECT) - self.assertEqual(client._credentials, credentials) - self.assertEqual(client._database, DEFAULT_DATABASE) - self.assertIs(client._client_info, _CLIENT_INFO) - def test_constructor_explicit(self): - from google.api_core.client_options import ClientOptions +def _make_default_client(*args, **kwargs): + credentials = _make_credentials() + return _make_client(project=PROJECT, credentials=credentials) - credentials = _make_credentials() - database = "now-db" - client_info = mock.Mock() - client_options = ClientOptions("endpoint") - client = self._make_one( - project=self.PROJECT, - credentials=credentials, - database=database, - client_info=client_info, - client_options=client_options, - ) - self.assertEqual(client.project, self.PROJECT) - self.assertEqual(client._credentials, credentials) - self.assertEqual(client._database, database) - self.assertIs(client._client_info, client_info) - self.assertIs(client._client_options, client_options) - - def test_constructor_w_client_options(self): - credentials = _make_credentials() - client = self._make_one( - project=self.PROJECT, - credentials=credentials, - client_options={"api_endpoint": "foo-firestore.googleapis.com"}, - ) - self.assertEqual(client._target, 
"foo-firestore.googleapis.com") - def test_collection_factory(self): - from google.cloud.firestore_v1.collection import CollectionReference +def test_client_constructor_defaults(): + from google.cloud.firestore_v1.client import _CLIENT_INFO + from google.cloud.firestore_v1.client import DEFAULT_DATABASE - collection_id = "users" - client = self._make_default_one() - collection = client.collection(collection_id) + credentials = _make_credentials() + client = _make_client(project=PROJECT, credentials=credentials) + assert client.project == PROJECT + assert client._credentials == credentials + assert client._database == DEFAULT_DATABASE + assert client._client_info is _CLIENT_INFO - self.assertEqual(collection._path, (collection_id,)) - self.assertIs(collection._client, client) - self.assertIsInstance(collection, CollectionReference) - def test_collection_factory_nested(self): - from google.cloud.firestore_v1.collection import CollectionReference +def test_client_constructor_explicit(): + from google.api_core.client_options import ClientOptions - client = self._make_default_one() - parts = ("users", "alovelace", "beep") - collection_path = "/".join(parts) - collection1 = client.collection(collection_path) + credentials = _make_credentials() + database = "now-db" + client_info = mock.Mock() + client_options = ClientOptions("endpoint") + client = _make_client( + project=PROJECT, + credentials=credentials, + database=database, + client_info=client_info, + client_options=client_options, + ) + assert client.project == PROJECT + assert client._credentials == credentials + assert client._database == database + assert client._client_info is client_info + assert client._client_options is client_options - self.assertEqual(collection1._path, parts) - self.assertIs(collection1._client, client) - self.assertIsInstance(collection1, CollectionReference) - # Make sure using segments gives the same result. 
- collection2 = client.collection(*parts) - self.assertEqual(collection2._path, parts) - self.assertIs(collection2._client, client) - self.assertIsInstance(collection2, CollectionReference) +def test_client__firestore_api_property(): + credentials = _make_credentials() + client = _make_client(project=PROJECT, credentials=credentials) + helper = client._firestore_api_helper = mock.Mock() - def test__get_collection_reference(self): - from google.cloud.firestore_v1.collection import CollectionReference + g_patch = mock.patch("google.cloud.firestore_v1.client.firestore_grpc_transport") + f_patch = mock.patch("google.cloud.firestore_v1.client.firestore_client") - client = self._make_default_one() - collection = client._get_collection_reference("collectionId") + with g_patch as grpc_transport: + with f_patch as firestore_client: + api = client._firestore_api - self.assertIs(collection._client, client) - self.assertIsInstance(collection, CollectionReference) + assert api is helper.return_value - def test_collection_group(self): - client = self._make_default_one() - query = client.collection_group("collectionId").where("foo", "==", "bar") + helper.assert_called_once_with( + grpc_transport.FirestoreGrpcTransport, + firestore_client.FirestoreClient, + firestore_client, + ) - self.assertTrue(query._all_descendants) - self.assertEqual(query._field_filters[0].field.field_path, "foo") - self.assertEqual(query._field_filters[0].value.string_value, "bar") - self.assertEqual( - query._field_filters[0].op, query._field_filters[0].Operator.EQUAL - ) - self.assertEqual(query._parent.id, "collectionId") - - def test_collection_group_no_slashes(self): - client = self._make_default_one() - with self.assertRaises(ValueError): - client.collection_group("foo/bar") - - def test_document_factory(self): - from google.cloud.firestore_v1.document import DocumentReference - - parts = ("rooms", "roomA") - client = self._make_default_one() - doc_path = "/".join(parts) - document1 = 
client.document(doc_path) - - self.assertEqual(document1._path, parts) - self.assertIs(document1._client, client) - self.assertIsInstance(document1, DocumentReference) - - # Make sure using segments gives the same result. - document2 = client.document(*parts) - self.assertEqual(document2._path, parts) - self.assertIs(document2._client, client) - self.assertIsInstance(document2, DocumentReference) - - def test_document_factory_w_absolute_path(self): - from google.cloud.firestore_v1.document import DocumentReference - - parts = ("rooms", "roomA") - client = self._make_default_one() - doc_path = "/".join(parts) - to_match = client.document(doc_path) - document1 = client.document(to_match._document_path) - - self.assertEqual(document1._path, parts) - self.assertIs(document1._client, client) - self.assertIsInstance(document1, DocumentReference) - - def test_document_factory_w_nested_path(self): - from google.cloud.firestore_v1.document import DocumentReference - - client = self._make_default_one() - parts = ("rooms", "roomA", "shoes", "dressy") - doc_path = "/".join(parts) - document1 = client.document(doc_path) - - self.assertEqual(document1._path, parts) - self.assertIs(document1._client, client) - self.assertIsInstance(document1, DocumentReference) - - # Make sure using segments gives the same result. 
- document2 = client.document(*parts) - self.assertEqual(document2._path, parts) - self.assertIs(document2._client, client) - self.assertIsInstance(document2, DocumentReference) - - def _collections_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.collection import CollectionReference - - collection_ids = ["users", "projects"] - - class Pager(object): - def __iter__(self): - yield from collection_ids - - firestore_api = mock.Mock(spec=["list_collection_ids"]) - firestore_api.list_collection_ids.return_value = Pager() - - client = self._make_default_one() - client._firestore_api_internal = firestore_api - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - collections = list(client.collections(**kwargs)) - - self.assertEqual(len(collections), len(collection_ids)) - for collection, collection_id in zip(collections, collection_ids): - self.assertIsInstance(collection, CollectionReference) - self.assertEqual(collection.parent, None) - self.assertEqual(collection.id, collection_id) - - base_path = client._database_string + "/documents" - firestore_api.list_collection_ids.assert_called_once_with( - request={"parent": base_path}, metadata=client._rpc_metadata, **kwargs, - ) - def test_collections(self): - self._collections_helper() +def test_client_constructor_w_client_options(): + credentials = _make_credentials() + client = _make_client( + project=PROJECT, + credentials=credentials, + client_options={"api_endpoint": "foo-firestore.googleapis.com"}, + ) + assert client._target == "foo-firestore.googleapis.com" - def test_collections_w_retry_timeout(self): - from google.api_core.retry import Retry - retry = Retry(predicate=object()) - timeout = 123.0 - self._collections_helper(retry=retry, timeout=timeout) +def test_client_collection_factory(): + from google.cloud.firestore_v1.collection import CollectionReference - def _invoke_get_all(self, client, references, document_pbs, **kwargs): - 
# Create a minimal fake GAPIC with a dummy response. - firestore_api = mock.Mock(spec=["batch_get_documents"]) - response_iterator = iter(document_pbs) - firestore_api.batch_get_documents.return_value = response_iterator + collection_id = "users" + client = _make_default_client() + collection = client.collection(collection_id) - # Attach the fake GAPIC to a real client. - client._firestore_api_internal = firestore_api + assert collection._path == (collection_id,) + assert collection._client is client + assert isinstance(collection, CollectionReference) - # Actually call get_all(). - snapshots = client.get_all(references, **kwargs) - self.assertIsInstance(snapshots, types.GeneratorType) - return list(snapshots) +def test_client_collection_factory_nested(): + from google.cloud.firestore_v1.collection import CollectionReference - def _get_all_helper(self, num_snapshots=2, txn_id=None, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.async_document import DocumentSnapshot + client = _make_default_client() + parts = ("users", "alovelace", "beep") + collection_path = "/".join(parts) + collection1 = client.collection(collection_path) - client = self._make_default_one() + assert collection1._path == parts + assert collection1._client is client + assert isinstance(collection1, CollectionReference) - data1 = {"a": "cheese"} - document1 = client.document("pineapple", "lamp1") - document_pb1, read_time = _doc_get_info(document1._document_path, data1) - response1 = _make_batch_response(found=document_pb1, read_time=read_time) + # Make sure using segments gives the same result. 
+ collection2 = client.collection(*parts) + assert collection2._path == parts + assert collection2._client is client + assert isinstance(collection2, CollectionReference) - data2 = {"b": True, "c": 18} - document2 = client.document("pineapple", "lamp2") - document, read_time = _doc_get_info(document2._document_path, data2) - response2 = _make_batch_response(found=document, read_time=read_time) - document3 = client.document("pineapple", "lamp3") - response3 = _make_batch_response(missing=document3._document_path) +def test_client__get_collection_reference(): + from google.cloud.firestore_v1.collection import CollectionReference - expected_data = [data1, data2, None][:num_snapshots] - documents = [document1, document2, document3][:num_snapshots] - responses = [response1, response2, response3][:num_snapshots] - field_paths = [ - field_path for field_path in ["a", "b", None][:num_snapshots] if field_path - ] - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + client = _make_default_client() + collection = client._get_collection_reference("collectionId") - if txn_id is not None: - transaction = client.transaction() - transaction._id = txn_id - kwargs["transaction"] = transaction + assert collection._client is client + assert isinstance(collection, CollectionReference) - snapshots = self._invoke_get_all( - client, documents, responses, field_paths=field_paths, **kwargs, - ) - self.assertEqual(len(snapshots), num_snapshots) - - for data, document, snapshot in zip(expected_data, documents, snapshots): - self.assertIsInstance(snapshot, DocumentSnapshot) - self.assertIs(snapshot._reference, document) - if data is None: - self.assertFalse(snapshot.exists) - else: - self.assertEqual(snapshot._data, data) - - # Verify the call to the mock. 
- doc_paths = [document._document_path for document in documents] - mask = common.DocumentMask(field_paths=field_paths) - - kwargs.pop("transaction", None) - - client._firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": doc_paths, - "mask": mask, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - **kwargs, - ) +def test_client_collection_group(): + client = _make_default_client() + query = client.collection_group("collectionId").where("foo", "==", "bar") - def test_get_all(self): - self._get_all_helper() + assert query._all_descendants + assert query._field_filters[0].field.field_path == "foo" + assert query._field_filters[0].value.string_value == "bar" + assert query._field_filters[0].op == query._field_filters[0].Operator.EQUAL + assert query._parent.id == "collectionId" - def test_get_all_with_transaction(self): - txn_id = b"the-man-is-non-stop" - self._get_all_helper(num_snapshots=1, txn_id=txn_id) - def test_get_all_w_retry_timeout(self): - from google.api_core.retry import Retry +def test_client_collection_group_no_slashes(): + client = _make_default_client() + with pytest.raises(ValueError): + client.collection_group("foo/bar") - retry = Retry(predicate=object()) - timeout = 123.0 - self._get_all_helper(retry=retry, timeout=timeout) - def test_get_all_wrong_order(self): - self._get_all_helper(num_snapshots=3) +def test_client_document_factory(): + from google.cloud.firestore_v1.document import DocumentReference - def test_get_all_unknown_result(self): - from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE + parts = ("rooms", "roomA") + client = _make_default_client() + doc_path = "/".join(parts) + document1 = client.document(doc_path) - client = self._make_default_one() + assert document1._path == parts + assert document1._client is client + assert isinstance(document1, DocumentReference) - expected_document = client.document("pineapple", "lamp1") + # Make 
sure using segments gives the same result. + document2 = client.document(*parts) + assert document2._path == parts + assert document2._client is client + assert isinstance(document2, DocumentReference) - data = {"z": 28.5} - wrong_document = client.document("pineapple", "lamp2") - document_pb, read_time = _doc_get_info(wrong_document._document_path, data) - response = _make_batch_response(found=document_pb, read_time=read_time) - # Exercise the mocked ``batch_get_documents``. - with self.assertRaises(ValueError) as exc_info: - self._invoke_get_all(client, [expected_document], [response]) +def test_client_document_factory_w_absolute_path(): + from google.cloud.firestore_v1.document import DocumentReference - err_msg = _BAD_DOC_TEMPLATE.format(response.found.name) - self.assertEqual(exc_info.exception.args, (err_msg,)) + parts = ("rooms", "roomA") + client = _make_default_client() + doc_path = "/".join(parts) + to_match = client.document(doc_path) + document1 = client.document(to_match._document_path) - # Verify the call to the mock. 
- doc_paths = [expected_document._document_path] - client._firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": doc_paths, - "mask": None, - "transaction": None, - }, - metadata=client._rpc_metadata, - ) + assert document1._path == parts + assert document1._client is client + assert isinstance(document1, DocumentReference) - def test_recursive_delete(self): - client = self._make_default_one() - client._firestore_api_internal = mock.Mock(spec=["run_query"]) - collection_ref = client.collection("my_collection") - results = [] - for index in range(10): - results.append( - RunQueryResponse(document=Document(name=f"{collection_ref.id}/{index}")) - ) +def test_client_document_factory_w_nested_path(): + from google.cloud.firestore_v1.document import DocumentReference - chunks = [ - results[:3], - results[3:6], - results[6:9], - results[9:], - ] + client = _make_default_client() + parts = ("rooms", "roomA", "shoes", "dressy") + doc_path = "/".join(parts) + document1 = client.document(doc_path) - def _get_chunk(*args, **kwargs): - return iter(chunks.pop(0)) + assert document1._path == parts + assert document1._client is client + assert isinstance(document1, DocumentReference) - client._firestore_api_internal.run_query.side_effect = _get_chunk + # Make sure using segments gives the same result. 
+ document2 = client.document(*parts) + assert document2._path == parts + assert document2._client is client + assert isinstance(document2, DocumentReference) - bulk_writer = mock.MagicMock() - bulk_writer.mock_add_spec(spec=["delete", "close"]) - num_deleted = client.recursive_delete( - collection_ref, bulk_writer=bulk_writer, chunk_size=3 - ) - self.assertEqual(num_deleted, len(results)) +def _collections_helper(retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.collection import CollectionReference - def test_recursive_delete_from_document(self): - client = self._make_default_one() - client._firestore_api_internal = mock.Mock( - spec=["run_query", "list_collection_ids"] - ) - collection_ref = client.collection("my_collection") + collection_ids = ["users", "projects"] - collection_1_id: str = "collection_1_id" - collection_2_id: str = "collection_2_id" + class Pager(object): + def __iter__(self): + yield from collection_ids - parent_doc = collection_ref.document("parent") + firestore_api = mock.Mock(spec=["list_collection_ids"]) + firestore_api.list_collection_ids.return_value = Pager() - collection_1_results = [] - collection_2_results = [] + client = _make_default_client() + client._firestore_api_internal = firestore_api + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - for index in range(10): - collection_1_results.append( - RunQueryResponse(document=Document(name=f"{collection_1_id}/{index}"),), - ) + collections = list(client.collections(**kwargs)) - collection_2_results.append( - RunQueryResponse(document=Document(name=f"{collection_2_id}/{index}"),), - ) + assert len(collections) == len(collection_ids) + for collection, collection_id in zip(collections, collection_ids): + assert isinstance(collection, CollectionReference) + assert collection.parent is None + assert collection.id == collection_id + + base_path = client._database_string + "/documents" + 
firestore_api.list_collection_ids.assert_called_once_with( + request={"parent": base_path}, metadata=client._rpc_metadata, **kwargs, + ) + + +def test_client_collections(): + _collections_helper() + + +def test_client_collections_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + _collections_helper(retry=retry, timeout=timeout) + + +def _invoke_get_all(client, references, document_pbs, **kwargs): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["batch_get_documents"]) + response_iterator = iter(document_pbs) + firestore_api.batch_get_documents.return_value = response_iterator + + # Attach the fake GAPIC to a real client. + client._firestore_api_internal = firestore_api + + # Actually call get_all(). + snapshots = client.get_all(references, **kwargs) + assert isinstance(snapshots, types.GeneratorType) + + return list(snapshots) + + +def _get_all_helper(num_snapshots=2, txn_id=None, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.async_document import DocumentSnapshot + + client = _make_default_client() + + data1 = {"a": "cheese"} + document1 = client.document("pineapple", "lamp1") + document_pb1, read_time = _doc_get_info(document1._document_path, data1) + response1 = _make_batch_response(found=document_pb1, read_time=read_time) + + data2 = {"b": True, "c": 18} + document2 = client.document("pineapple", "lamp2") + document, read_time = _doc_get_info(document2._document_path, data2) + response2 = _make_batch_response(found=document, read_time=read_time) + + document3 = client.document("pineapple", "lamp3") + response3 = _make_batch_response(missing=document3._document_path) + + expected_data = [data1, data2, None][:num_snapshots] + documents = [document1, document2, document3][:num_snapshots] + responses = [response1, response2, 
response3][:num_snapshots] + field_paths = [ + field_path for field_path in ["a", "b", None][:num_snapshots] if field_path + ] + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + if txn_id is not None: + transaction = client.transaction() + transaction._id = txn_id + kwargs["transaction"] = transaction + + snapshots = _invoke_get_all( + client, documents, responses, field_paths=field_paths, **kwargs, + ) + + assert len(snapshots) == num_snapshots + + for data, document, snapshot in zip(expected_data, documents, snapshots): + assert isinstance(snapshot, DocumentSnapshot) + assert snapshot._reference is document + if data is None: + assert not snapshot.exists + else: + assert snapshot._data == data + + # Verify the call to the mock. + doc_paths = [document._document_path for document in documents] + mask = common.DocumentMask(field_paths=field_paths) + + kwargs.pop("transaction", None) + + client._firestore_api.batch_get_documents.assert_called_once_with( + request={ + "database": client._database_string, + "documents": doc_paths, + "mask": mask, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +def test_client_get_all(): + _get_all_helper() + + +def test_client_get_all_with_transaction(): + txn_id = b"the-man-is-non-stop" + _get_all_helper(num_snapshots=1, txn_id=txn_id) + + +def test_client_get_all_w_retry_timeout(): + from google.api_core.retry import Retry - col_1_chunks = [ - collection_1_results[:3], - collection_1_results[3:6], - collection_1_results[6:9], - collection_1_results[9:], - ] - - col_2_chunks = [ - collection_2_results[:3], - collection_2_results[3:6], - collection_2_results[6:9], - collection_2_results[9:], - ] - - def _get_chunk(*args, **kwargs): - start_at = ( - kwargs["request"]["structured_query"].start_at.values[0].reference_value + retry = Retry(predicate=object()) + timeout = 123.0 + _get_all_helper(retry=retry, timeout=timeout) + + +def test_client_get_all_wrong_order(): + 
_get_all_helper(num_snapshots=3) + + +def test_client_get_all_unknown_result(): + from google.cloud.firestore_v1.base_client import _BAD_DOC_TEMPLATE + + client = _make_default_client() + + expected_document = client.document("pineapple", "lamp1") + + data = {"z": 28.5} + wrong_document = client.document("pineapple", "lamp2") + document_pb, read_time = _doc_get_info(wrong_document._document_path, data) + response = _make_batch_response(found=document_pb, read_time=read_time) + + # Exercise the mocked ``batch_get_documents``. + with pytest.raises(ValueError) as exc_info: + _invoke_get_all(client, [expected_document], [response]) + + err_msg = _BAD_DOC_TEMPLATE.format(response.found.name) + assert exc_info.value.args == (err_msg,) + + # Verify the call to the mock. + doc_paths = [expected_document._document_path] + client._firestore_api.batch_get_documents.assert_called_once_with( + request={ + "database": client._database_string, + "documents": doc_paths, + "mask": None, + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +def test_client_recursive_delete(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import firestore + + client = _make_default_client() + client._firestore_api_internal = mock.Mock(spec=["run_query"]) + collection_ref = client.collection("my_collection") + + results = [] + for index in range(10): + results.append( + firestore.RunQueryResponse( + document=document.Document(name=f"{collection_ref.id}/{index}") ) + ) + + chunks = [ + results[:3], + results[3:6], + results[6:9], + results[9:], + ] + + def _get_chunk(*args, **kwargs): + return iter(chunks.pop(0)) - if collection_1_id in start_at: - return iter(col_1_chunks.pop(0)) - return iter(col_2_chunks.pop(0)) + client._firestore_api_internal.run_query.side_effect = _get_chunk - client._firestore_api_internal.run_query.side_effect = _get_chunk - client._firestore_api_internal.list_collection_ids.return_value = [ - collection_1_id, - 
collection_2_id, - ] + bulk_writer = mock.MagicMock() + bulk_writer.mock_add_spec(spec=["delete", "close"]) + + num_deleted = client.recursive_delete( + collection_ref, bulk_writer=bulk_writer, chunk_size=3 + ) + assert num_deleted == len(results) - bulk_writer = mock.MagicMock() - bulk_writer.mock_add_spec(spec=["delete", "close"]) - num_deleted = client.recursive_delete( - parent_doc, bulk_writer=bulk_writer, chunk_size=3 +def test_client_recursive_delete_from_document(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import firestore + + client = _make_default_client() + client._firestore_api_internal = mock.Mock( + spec=["run_query", "list_collection_ids"] + ) + collection_ref = client.collection("my_collection") + + collection_1_id: str = "collection_1_id" + collection_2_id: str = "collection_2_id" + + parent_doc = collection_ref.document("parent") + + collection_1_results = [] + collection_2_results = [] + + for index in range(10): + collection_1_results.append( + firestore.RunQueryResponse( + document=document.Document(name=f"{collection_1_id}/{index}"), + ), ) - expected_len = len(collection_1_results) + len(collection_2_results) + 1 - self.assertEqual(num_deleted, expected_len) + collection_2_results.append( + firestore.RunQueryResponse( + document=document.Document(name=f"{collection_2_id}/{index}"), + ), + ) - def test_recursive_delete_raises(self): - client = self._make_default_one() - self.assertRaises(TypeError, client.recursive_delete, object()) + col_1_chunks = [ + collection_1_results[:3], + collection_1_results[3:6], + collection_1_results[6:9], + collection_1_results[9:], + ] + + col_2_chunks = [ + collection_2_results[:3], + collection_2_results[3:6], + collection_2_results[6:9], + collection_2_results[9:], + ] + + def _get_chunk(*args, **kwargs): + start_at = ( + kwargs["request"]["structured_query"].start_at.values[0].reference_value + ) - def test_batch(self): - from 
google.cloud.firestore_v1.batch import WriteBatch + if collection_1_id in start_at: + return iter(col_1_chunks.pop(0)) + return iter(col_2_chunks.pop(0)) - client = self._make_default_one() - batch = client.batch() - self.assertIsInstance(batch, WriteBatch) - self.assertIs(batch._client, client) - self.assertEqual(batch._write_pbs, []) + client._firestore_api_internal.run_query.side_effect = _get_chunk + client._firestore_api_internal.list_collection_ids.return_value = [ + collection_1_id, + collection_2_id, + ] - def test_bulk_writer(self): - from google.cloud.firestore_v1.bulk_writer import BulkWriter + bulk_writer = mock.MagicMock() + bulk_writer.mock_add_spec(spec=["delete", "close"]) - client = self._make_default_one() - bulk_writer = client.bulk_writer() - self.assertIsInstance(bulk_writer, BulkWriter) - self.assertIs(bulk_writer._client, client) + num_deleted = client.recursive_delete( + parent_doc, bulk_writer=bulk_writer, chunk_size=3 + ) - def test_transaction(self): - from google.cloud.firestore_v1.transaction import Transaction + expected_len = len(collection_1_results) + len(collection_2_results) + 1 + assert num_deleted == expected_len - client = self._make_default_one() - transaction = client.transaction(max_attempts=3, read_only=True) - self.assertIsInstance(transaction, Transaction) - self.assertEqual(transaction._write_pbs, []) - self.assertEqual(transaction._max_attempts, 3) - self.assertTrue(transaction._read_only) - self.assertIsNone(transaction._id) +def test_client_recursive_delete_raises(): + client = _make_default_client() + with pytest.raises(TypeError): + client.recursive_delete(object()) -def _make_credentials(): - import google.auth.credentials - return mock.Mock(spec=google.auth.credentials.Credentials) +def test_client_batch(): + from google.cloud.firestore_v1.batch import WriteBatch + + client = _make_default_client() + batch = client.batch() + assert isinstance(batch, WriteBatch) + assert batch._client is client + assert 
batch._write_pbs == [] + + +def test_client_bulk_writer(): + from google.cloud.firestore_v1.bulk_writer import BulkWriter + + client = _make_default_client() + bulk_writer = client.bulk_writer() + assert isinstance(bulk_writer, BulkWriter) + assert bulk_writer._client is client + + +def test_client_transaction(): + from google.cloud.firestore_v1.transaction import Transaction + + client = _make_default_client() + transaction = client.transaction(max_attempts=3, read_only=True) + assert isinstance(transaction, Transaction) + assert transaction._write_pbs == [] + assert transaction._max_attempts == 3 + assert transaction._read_only + assert transaction._id is None def _make_batch_response(**kwargs): diff --git a/tests/unit/v1/test_collection.py b/tests/unit/v1/test_collection.py index cfefeb9e61ab0..9bba2fd5b0a4b 100644 --- a/tests/unit/v1/test_collection.py +++ b/tests/unit/v1/test_collection.py @@ -12,383 +12,396 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from google.cloud.firestore_v1.types.document import Document -from google.cloud.firestore_v1.types.firestore import RunQueryResponse import types -import unittest import mock -from tests.unit.v1 import _test_helpers - - -class TestCollectionReference(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.collection import CollectionReference - - return CollectionReference - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - @staticmethod - def _get_public_methods(klass): - return set().union( - *( - ( - name - for name, value in class_.__dict__.items() - if ( - not name.startswith("_") - and isinstance(value, types.FunctionType) - ) - ) - for class_ in (klass,) + klass.__bases__ + +def _make_collection_reference(*args, **kwargs): + from google.cloud.firestore_v1.collection import CollectionReference + + return CollectionReference(*args, **kwargs) + + +def _get_public_methods(klass): + return set().union( + *( + ( + name + for name, value in class_.__dict__.items() + if (not name.startswith("_") and isinstance(value, types.FunctionType)) ) + for class_ in (klass,) + klass.__bases__ ) + ) + + +def test_query_method_matching(): + from google.cloud.firestore_v1.query import Query + from google.cloud.firestore_v1.collection import CollectionReference + + query_methods = _get_public_methods(Query) + collection_methods = _get_public_methods(CollectionReference) + # Make sure every query method is present on + # ``CollectionReference``. 
+ assert query_methods <= collection_methods + + +def test_constructor(): + collection_id1 = "rooms" + document_id = "roomA" + collection_id2 = "messages" + client = mock.sentinel.client + + collection = _make_collection_reference( + collection_id1, document_id, collection_id2, client=client + ) + assert collection._client is client + expected_path = (collection_id1, document_id, collection_id2) + assert collection._path == expected_path + + +def test_add_auto_assigned(): + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1 import SERVER_TIMESTAMP + from google.cloud.firestore_v1._helpers import pbs_for_create + from tests.unit.v1 import _test_helpers + + # Create a minimal fake GAPIC add attach it to a real client. + firestore_api = mock.Mock(spec=["create_document", "commit"]) + write_result = mock.Mock( + update_time=mock.sentinel.update_time, spec=["update_time"] + ) + + commit_response = mock.Mock( + write_results=[write_result], + spec=["write_results", "commit_time"], + commit_time=mock.sentinel.commit_time, + ) + + firestore_api.commit.return_value = commit_response + create_doc_response = document.Document() + firestore_api.create_document.return_value = create_doc_response + client = _test_helpers.make_client() + client._firestore_api_internal = firestore_api + + # Actually make a collection. + collection = _make_collection_reference( + "grand-parent", "parent", "child", client=client + ) + + # Actually call add() on our collection; include a transform to make + # sure transforms during adds work. + document_data = {"been": "here", "now": SERVER_TIMESTAMP} + + patch = mock.patch("google.cloud.firestore_v1.base_collection._auto_id") + random_doc_id = "DEADBEEF" + with patch as patched: + patched.return_value = random_doc_id + update_time, document_ref = collection.add(document_data) + + # Verify the response and the mocks. 
+ assert update_time is mock.sentinel.update_time + assert isinstance(document_ref, DocumentReference) + assert document_ref._client is client + expected_path = collection._path + (random_doc_id,) + assert document_ref._path == expected_path + + write_pbs = pbs_for_create(document_ref._document_path, document_data) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": write_pbs, + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + # Since we generate the ID locally, we don't call 'create_document'. + firestore_api.create_document.assert_not_called() + + +def _write_pb_for_create(document_path, document_data): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1 import _helpers + + return write.Write( + update=document.Document( + name=document_path, fields=_helpers.encode_dict(document_data) + ), + current_document=common.Precondition(exists=False), + ) + + +def _add_helper(retry=None, timeout=None): + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1 import _helpers as _fs_v1_helpers + from tests.unit.v1 import _test_helpers + + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["commit"]) + write_result = mock.Mock( + update_time=mock.sentinel.update_time, spec=["update_time"] + ) + commit_response = mock.Mock( + write_results=[write_result], + spec=["write_results", "commit_time"], + commit_time=mock.sentinel.commit_time, + ) + firestore_api.commit.return_value = commit_response + + # Attach the fake GAPIC to a real client. + client = _test_helpers.make_client() + client._firestore_api_internal = firestore_api + + # Actually make a collection and call add(). 
+ collection = _make_collection_reference("parent", client=client) + document_data = {"zorp": 208.75, "i-did-not": b"know that"} + doc_id = "child" + + kwargs = _fs_v1_helpers.make_retry_timeout_kwargs(retry, timeout) + update_time, document_ref = collection.add( + document_data, document_id=doc_id, **kwargs + ) + + # Verify the response and the mocks. + assert update_time is mock.sentinel.update_time + assert isinstance(document_ref, DocumentReference) + assert document_ref._client is client + assert document_ref._path == (collection.id, doc_id) + + write_pb = _write_pb_for_create(document_ref._document_path, document_data) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +def test_add_explicit_id(): + _add_helper() + + +def test_add_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + _add_helper(retry=retry, timeout=timeout) + + +def _list_documents_helper(page_size=None, retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers as _fs_v1_helpers + from google.api_core.page_iterator import Iterator + from google.api_core.page_iterator import Page + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.services.firestore.client import FirestoreClient + from google.cloud.firestore_v1.types.document import Document + from tests.unit.v1 import _test_helpers + + class _Iterator(Iterator): + def __init__(self, pages): + super(_Iterator, self).__init__(client=None) + self._pages = pages + + def _next_page(self): + if self._pages: + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) + + client = _test_helpers.make_client() + template = client._database_string + "/documents/{}" + document_ids = ["doc-1", "doc-2"] + documents = [ + 
Document(name=template.format(document_id)) for document_id in document_ids + ] + iterator = _Iterator(pages=[documents]) + api_client = mock.create_autospec(FirestoreClient) + api_client.list_documents.return_value = iterator + client._firestore_api_internal = api_client + collection = _make_collection_reference("collection", client=client) + kwargs = _fs_v1_helpers.make_retry_timeout_kwargs(retry, timeout) + + if page_size is not None: + documents = list(collection.list_documents(page_size=page_size, **kwargs)) + else: + documents = list(collection.list_documents(**kwargs)) + + # Verify the response and the mocks. + assert len(documents) == len(document_ids) + for document, document_id in zip(documents, document_ids): + assert isinstance(document, DocumentReference) + assert document.parent == collection + assert document.id == document_id + + parent, _ = collection._parent_info() + api_client.list_documents.assert_called_once_with( + request={ + "parent": parent, + "collection_id": collection.id, + "page_size": page_size, + "show_missing": True, + "mask": {"field_paths": None}, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + - def test_query_method_matching(self): - from google.cloud.firestore_v1.query import Query +def test_list_documents_wo_page_size(): + _list_documents_helper() + + +def test_list_documents_w_retry_timeout(): + from google.api_core.retry import Retry - query_methods = self._get_public_methods(Query) - klass = self._get_target_class() - collection_methods = self._get_public_methods(klass) - # Make sure every query method is present on - # ``CollectionReference``. 
- self.assertLessEqual(query_methods, collection_methods) + retry = Retry(predicate=object()) + timeout = 123.0 + _list_documents_helper(retry=retry, timeout=timeout) - def test_constructor(self): - collection_id1 = "rooms" - document_id = "roomA" - collection_id2 = "messages" - client = mock.sentinel.client - collection = self._make_one( - collection_id1, document_id, collection_id2, client=client - ) - self.assertIs(collection._client, client) - expected_path = (collection_id1, document_id, collection_id2) - self.assertEqual(collection._path, expected_path) - - def test_add_auto_assigned(self): - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.document import DocumentReference - from google.cloud.firestore_v1 import SERVER_TIMESTAMP - from google.cloud.firestore_v1._helpers import pbs_for_create - - # Create a minimal fake GAPIC add attach it to a real client. - firestore_api = mock.Mock(spec=["create_document", "commit"]) - write_result = mock.Mock( - update_time=mock.sentinel.update_time, spec=["update_time"] - ) +def test_list_documents_w_page_size(): + _list_documents_helper(page_size=25) - commit_response = mock.Mock( - write_results=[write_result], - spec=["write_results", "commit_time"], - commit_time=mock.sentinel.commit_time, - ) - firestore_api.commit.return_value = commit_response - create_doc_response = document.Document() - firestore_api.create_document.return_value = create_doc_response - client = _test_helpers.make_client() - client._firestore_api_internal = firestore_api - - # Actually make a collection. - collection = self._make_one("grand-parent", "parent", "child", client=client) - - # Actually call add() on our collection; include a transform to make - # sure transforms during adds work. 
- document_data = {"been": "here", "now": SERVER_TIMESTAMP} - - patch = mock.patch("google.cloud.firestore_v1.base_collection._auto_id") - random_doc_id = "DEADBEEF" - with patch as patched: - patched.return_value = random_doc_id - update_time, document_ref = collection.add(document_data) - - # Verify the response and the mocks. - self.assertIs(update_time, mock.sentinel.update_time) - self.assertIsInstance(document_ref, DocumentReference) - self.assertIs(document_ref._client, client) - expected_path = collection._path + (random_doc_id,) - self.assertEqual(document_ref._path, expected_path) - - write_pbs = pbs_for_create(document_ref._document_path, document_data) - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": write_pbs, - "transaction": None, - }, - metadata=client._rpc_metadata, - ) - # Since we generate the ID locally, we don't call 'create_document'. - firestore_api.create_document.assert_not_called() - - @staticmethod - def _write_pb_for_create(document_path, document_data): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1 import _helpers - - return write.Write( - update=document.Document( - name=document_path, fields=_helpers.encode_dict(document_data) - ), - current_document=common.Precondition(exists=False), - ) +@mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) +def test_get(query_class): + collection = _make_collection_reference("collection") + get_response = collection.get() - def _add_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1.document import DocumentReference - from google.cloud.firestore_v1 import _helpers as _fs_v1_helpers + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value - # Create a minimal fake GAPIC with a dummy response. 
- firestore_api = mock.Mock(spec=["commit"]) - write_result = mock.Mock( - update_time=mock.sentinel.update_time, spec=["update_time"] - ) - commit_response = mock.Mock( - write_results=[write_result], - spec=["write_results", "commit_time"], - commit_time=mock.sentinel.commit_time, - ) - firestore_api.commit.return_value = commit_response + assert get_response is query_instance.get.return_value + query_instance.get.assert_called_once_with(transaction=None) - # Attach the fake GAPIC to a real client. - client = _test_helpers.make_client() - client._firestore_api_internal = firestore_api - # Actually make a collection and call add(). - collection = self._make_one("parent", client=client) - document_data = {"zorp": 208.75, "i-did-not": b"know that"} - doc_id = "child" +@mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) +def test_get_w_retry_timeout(query_class): + from google.api_core.retry import Retry - kwargs = _fs_v1_helpers.make_retry_timeout_kwargs(retry, timeout) - update_time, document_ref = collection.add( - document_data, document_id=doc_id, **kwargs - ) + retry = Retry(predicate=object()) + timeout = 123.0 + collection = _make_collection_reference("collection") + get_response = collection.get(retry=retry, timeout=timeout) - # Verify the response and the mocks. 
- self.assertIs(update_time, mock.sentinel.update_time) - self.assertIsInstance(document_ref, DocumentReference) - self.assertIs(document_ref._client, client) - self.assertEqual(document_ref._path, (collection.id, doc_id)) - - write_pb = self._write_pb_for_create(document_ref._document_path, document_data) - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": [write_pb], - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value - def test_add_explicit_id(self): - self._add_helper() - - def test_add_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - self._add_helper(retry=retry, timeout=timeout) - - def _list_documents_helper(self, page_size=None, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers as _fs_v1_helpers - from google.api_core.page_iterator import Iterator - from google.api_core.page_iterator import Page - from google.cloud.firestore_v1.document import DocumentReference - from google.cloud.firestore_v1.services.firestore.client import FirestoreClient - from google.cloud.firestore_v1.types.document import Document - - class _Iterator(Iterator): - def __init__(self, pages): - super(_Iterator, self).__init__(client=None) - self._pages = pages - - def _next_page(self): - if self._pages: - page, self._pages = self._pages[0], self._pages[1:] - return Page(self, page, self.item_to_value) - - client = _test_helpers.make_client() - template = client._database_string + "/documents/{}" - document_ids = ["doc-1", "doc-2"] - documents = [ - Document(name=template.format(document_id)) for document_id in document_ids - ] - iterator = _Iterator(pages=[documents]) - api_client = mock.create_autospec(FirestoreClient) - api_client.list_documents.return_value = iterator - client._firestore_api_internal = 
api_client - collection = self._make_one("collection", client=client) - kwargs = _fs_v1_helpers.make_retry_timeout_kwargs(retry, timeout) - - if page_size is not None: - documents = list(collection.list_documents(page_size=page_size, **kwargs)) - else: - documents = list(collection.list_documents(**kwargs)) - - # Verify the response and the mocks. - self.assertEqual(len(documents), len(document_ids)) - for document, document_id in zip(documents, document_ids): - self.assertIsInstance(document, DocumentReference) - self.assertEqual(document.parent, collection) - self.assertEqual(document.id, document_id) - - parent, _ = collection._parent_info() - api_client.list_documents.assert_called_once_with( - request={ - "parent": parent, - "collection_id": collection.id, - "page_size": page_size, - "show_missing": True, - "mask": {"field_paths": None}, - }, - metadata=client._rpc_metadata, - **kwargs, - ) + assert get_response is query_instance.get.return_value + query_instance.get.assert_called_once_with( + transaction=None, retry=retry, timeout=timeout, + ) - def test_list_documents_wo_page_size(self): - self._list_documents_helper() - def test_list_documents_w_retry_timeout(self): - from google.api_core.retry import Retry +@mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) +def test_get_with_transaction(query_class): - retry = Retry(predicate=object()) - timeout = 123.0 - self._list_documents_helper(retry=retry, timeout=timeout) + collection = _make_collection_reference("collection") + transaction = mock.sentinel.txn + get_response = collection.get(transaction=transaction) - def test_list_documents_w_page_size(self): - self._list_documents_helper(page_size=25) + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value - @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) - def test_get(self, query_class): - collection = self._make_one("collection") - get_response = collection.get() + assert get_response 
is query_instance.get.return_value + query_instance.get.assert_called_once_with(transaction=transaction) - query_class.assert_called_once_with(collection) - query_instance = query_class.return_value - self.assertIs(get_response, query_instance.get.return_value) - query_instance.get.assert_called_once_with(transaction=None) +@mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) +def test_stream(query_class): + collection = _make_collection_reference("collection") + stream_response = collection.stream() - @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) - def test_get_w_retry_timeout(self, query_class): - from google.api_core.retry import Retry + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + assert stream_response is query_instance.stream.return_value + query_instance.stream.assert_called_once_with(transaction=None) - retry = Retry(predicate=object()) - timeout = 123.0 - collection = self._make_one("collection") - get_response = collection.get(retry=retry, timeout=timeout) - query_class.assert_called_once_with(collection) - query_instance = query_class.return_value +@mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) +def test_stream_w_retry_timeout(query_class): + from google.api_core.retry import Retry - self.assertIs(get_response, query_instance.get.return_value) - query_instance.get.assert_called_once_with( - transaction=None, retry=retry, timeout=timeout, - ) + retry = Retry(predicate=object()) + timeout = 123.0 + collection = _make_collection_reference("collection") + stream_response = collection.stream(retry=retry, timeout=timeout) - @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) - def test_get_with_transaction(self, query_class): + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + assert stream_response is query_instance.stream.return_value + query_instance.stream.assert_called_once_with( + 
transaction=None, retry=retry, timeout=timeout, + ) - collection = self._make_one("collection") - transaction = mock.sentinel.txn - get_response = collection.get(transaction=transaction) - query_class.assert_called_once_with(collection) - query_instance = query_class.return_value +@mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) +def test_stream_with_transaction(query_class): + collection = _make_collection_reference("collection") + transaction = mock.sentinel.txn + stream_response = collection.stream(transaction=transaction) - self.assertIs(get_response, query_instance.get.return_value) - query_instance.get.assert_called_once_with(transaction=transaction) + query_class.assert_called_once_with(collection) + query_instance = query_class.return_value + assert stream_response is query_instance.stream.return_value + query_instance.stream.assert_called_once_with(transaction=transaction) - @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) - def test_stream(self, query_class): - collection = self._make_one("collection") - stream_response = collection.stream() - query_class.assert_called_once_with(collection) - query_instance = query_class.return_value - self.assertIs(stream_response, query_instance.stream.return_value) - query_instance.stream.assert_called_once_with(transaction=None) +@mock.patch("google.cloud.firestore_v1.collection.Watch", autospec=True) +def test_on_snapshot(watch): + collection = _make_collection_reference("collection") + collection.on_snapshot(None) + watch.for_query.assert_called_once() - @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) - def test_stream_w_retry_timeout(self, query_class): - from google.api_core.retry import Retry - retry = Retry(predicate=object()) - timeout = 123.0 - collection = self._make_one("collection") - stream_response = collection.stream(retry=retry, timeout=timeout) +def test_recursive(): + from google.cloud.firestore_v1.query import Query + + col = 
_make_collection_reference("collection") + assert isinstance(col.recursive(), Query) - query_class.assert_called_once_with(collection) - query_instance = query_class.return_value - self.assertIs(stream_response, query_instance.stream.return_value) - query_instance.stream.assert_called_once_with( - transaction=None, retry=retry, timeout=timeout, - ) - @mock.patch("google.cloud.firestore_v1.query.Query", autospec=True) - def test_stream_with_transaction(self, query_class): - collection = self._make_one("collection") - transaction = mock.sentinel.txn - stream_response = collection.stream(transaction=transaction) - - query_class.assert_called_once_with(collection) - query_instance = query_class.return_value - self.assertIs(stream_response, query_instance.stream.return_value) - query_instance.stream.assert_called_once_with(transaction=transaction) - - @mock.patch("google.cloud.firestore_v1.collection.Watch", autospec=True) - def test_on_snapshot(self, watch): - collection = self._make_one("collection") - collection.on_snapshot(None) - watch.for_query.assert_called_once() - - def test_recursive(self): - from google.cloud.firestore_v1.query import Query - - col = self._make_one("collection") - self.assertIsInstance(col.recursive(), Query) - - def test_chunkify(self): - client = _test_helpers.make_client() - col = client.collection("my-collection") - - client._firestore_api_internal = mock.Mock(spec=["run_query"]) - - results = [] - for index in range(10): - results.append( - RunQueryResponse( - document=Document( - name=f"projects/project-project/databases/(default)/documents/my-collection/{index}", - ), +def test_chunkify(): + from google.cloud.firestore_v1.types.document import Document + from google.cloud.firestore_v1.types.firestore import RunQueryResponse + from tests.unit.v1 import _test_helpers + + client = _test_helpers.make_client() + col = client.collection("my-collection") + + client._firestore_api_internal = mock.Mock(spec=["run_query"]) + + results = [] + for 
index in range(10): + results.append( + RunQueryResponse( + document=Document( + name=f"projects/project-project/databases/(default)/documents/my-collection/{index}", ), - ) + ), + ) - chunks = [ - results[:3], - results[3:6], - results[6:9], - results[9:], - ] + chunks = [ + results[:3], + results[3:6], + results[6:9], + results[9:], + ] - def _get_chunk(*args, **kwargs): - return iter(chunks.pop(0)) + def _get_chunk(*args, **kwargs): + return iter(chunks.pop(0)) - client._firestore_api_internal.run_query.side_effect = _get_chunk + client._firestore_api_internal.run_query.side_effect = _get_chunk - counter = 0 - expected_lengths = [3, 3, 3, 1] - for chunk in col._chunkify(3): - msg = f"Expected chunk of length {expected_lengths[counter]} at index {counter}. Saw {len(chunk)}." - self.assertEqual(len(chunk), expected_lengths[counter], msg) - counter += 1 + counter = 0 + expected_lengths = [3, 3, 3, 1] + for chunk in col._chunkify(3): + assert len(chunk) == expected_lengths[counter] + counter += 1 diff --git a/tests/unit/v1/test_document.py b/tests/unit/v1/test_document.py index 30c8a1c16c595..df52a7c3e6f70 100644 --- a/tests/unit/v1/test_document.py +++ b/tests/unit/v1/test_document.py @@ -12,534 +12,558 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import collections -import unittest import mock +import pytest -class TestDocumentReference(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.document import DocumentReference +def _make_document_reference(*args, **kwargs): + from google.cloud.firestore_v1.document import DocumentReference - return DocumentReference + return DocumentReference(*args, **kwargs) - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - def test_constructor(self): - collection_id1 = "users" - document_id1 = "alovelace" - collection_id2 = "platform" - document_id2 = "*nix" - client = mock.MagicMock() - client.__hash__.return_value = 1234 +def test_constructor(): + collection_id1 = "users" + document_id1 = "alovelace" + collection_id2 = "platform" + document_id2 = "*nix" + client = mock.MagicMock() + client.__hash__.return_value = 1234 - document = self._make_one( - collection_id1, document_id1, collection_id2, document_id2, client=client - ) - self.assertIs(document._client, client) - expected_path = "/".join( - (collection_id1, document_id1, collection_id2, document_id2) - ) - self.assertEqual(document.path, expected_path) - - @staticmethod - def _make_commit_repsonse(write_results=None): - from google.cloud.firestore_v1.types import firestore - - response = mock.create_autospec(firestore.CommitResponse) - response.write_results = write_results or [mock.sentinel.write_result] - response.commit_time = mock.sentinel.commit_time - return response - - @staticmethod - def _write_pb_for_create(document_path, document_data): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1 import _helpers - - return write.Write( - update=document.Document( - name=document_path, fields=_helpers.encode_dict(document_data) - ), - 
current_document=common.Precondition(exists=False), - ) + document = _make_document_reference( + collection_id1, document_id1, collection_id2, document_id2, client=client + ) + assert document._client is client + expected_path = "/".join( + (collection_id1, document_id1, collection_id2, document_id2) + ) + assert document.path == expected_path - def _create_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - - # Create a minimal fake GAPIC with a dummy response. - firestore_api = mock.Mock() - firestore_api.commit.mock_add_spec(spec=["commit"]) - firestore_api.commit.return_value = self._make_commit_repsonse() - - # Attach the fake GAPIC to a real client. - client = _make_client("dignity") - client._firestore_api_internal = firestore_api - - # Actually make a document and call create(). - document = self._make_one("foo", "twelve", client=client) - document_data = {"hello": "goodbye", "count": 99} - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - write_result = document.create(document_data, **kwargs) - - # Verify the response and the mocks. 
- self.assertIs(write_result, mock.sentinel.write_result) - write_pb = self._write_pb_for_create(document._document_path, document_data) - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": [write_pb], - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) - def test_create(self): - self._create_helper() +def _make_commit_repsonse(write_results=None): + from google.cloud.firestore_v1.types import firestore - def test_create_w_retry_timeout(self): - from google.api_core.retry import Retry + response = mock.create_autospec(firestore.CommitResponse) + response.write_results = write_results or [mock.sentinel.write_result] + response.commit_time = mock.sentinel.commit_time + return response - retry = Retry(predicate=object()) - timeout = 123.0 - self._create_helper(retry=retry, timeout=timeout) - def test_create_empty(self): - # Create a minimal fake GAPIC with a dummy response. - from google.cloud.firestore_v1.document import DocumentReference - from google.cloud.firestore_v1.document import DocumentSnapshot +def _write_pb_for_create(document_path, document_data): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1 import _helpers - firestore_api = mock.Mock(spec=["commit"]) - document_reference = mock.create_autospec(DocumentReference) - snapshot = mock.create_autospec(DocumentSnapshot) - snapshot.exists = True - document_reference.get.return_value = snapshot - firestore_api.commit.return_value = self._make_commit_repsonse( - write_results=[document_reference] - ) + return write.Write( + update=document.Document( + name=document_path, fields=_helpers.encode_dict(document_data) + ), + current_document=common.Precondition(exists=False), + ) - # Attach the fake GAPIC to a real client. 
- client = _make_client("dignity") - client._firestore_api_internal = firestore_api - client.get_all = mock.MagicMock() - client.get_all.exists.return_value = True - - # Actually make a document and call create(). - document = self._make_one("foo", "twelve", client=client) - document_data = {} - write_result = document.create(document_data) - self.assertTrue(write_result.get().exists) - - @staticmethod - def _write_pb_for_set(document_path, document_data, merge): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1 import _helpers - - write_pbs = write.Write( - update=document.Document( - name=document_path, fields=_helpers.encode_dict(document_data) - ) - ) - if merge: - field_paths = [ - field_path - for field_path, value in _helpers.extract_fields( - document_data, _helpers.FieldPath() - ) - ] - field_paths = [ - field_path.to_api_repr() for field_path in sorted(field_paths) - ] - mask = common.DocumentMask(field_paths=sorted(field_paths)) - write_pbs._pb.update_mask.CopyFrom(mask._pb) - return write_pbs - - def _set_helper(self, merge=False, retry=None, timeout=None, **option_kwargs): - from google.cloud.firestore_v1 import _helpers - - # Create a minimal fake GAPIC with a dummy response. - firestore_api = mock.Mock(spec=["commit"]) - firestore_api.commit.return_value = self._make_commit_repsonse() - - # Attach the fake GAPIC to a real client. - client = _make_client("db-dee-bee") - client._firestore_api_internal = firestore_api - - # Actually make a document and call create(). - document = self._make_one("User", "Interface", client=client) - document_data = {"And": 500, "Now": b"\xba\xaa\xaa \xba\xaa\xaa"} - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - write_result = document.set(document_data, merge, **kwargs) - - # Verify the response and the mocks. 
- self.assertIs(write_result, mock.sentinel.write_result) - write_pb = self._write_pb_for_set(document._document_path, document_data, merge) - - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": [write_pb], - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) - def test_set(self): - self._set_helper() +def _create_helper(retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers - def test_set_w_retry_timeout(self): - from google.api_core.retry import Retry + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock() + firestore_api.commit.mock_add_spec(spec=["commit"]) + firestore_api.commit.return_value = _make_commit_repsonse() - retry = Retry(predicate=object()) - timeout = 123.0 - self._set_helper(retry=retry, timeout=timeout) + # Attach the fake GAPIC to a real client. + client = _make_client("dignity") + client._firestore_api_internal = firestore_api - def test_set_merge(self): - self._set_helper(merge=True) + # Actually make a document and call create(). + document = _make_document_reference("foo", "twelve", client=client) + document_data = {"hello": "goodbye", "count": 99} + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - @staticmethod - def _write_pb_for_update(document_path, update_values, field_paths): - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import write - from google.cloud.firestore_v1 import _helpers + write_result = document.create(document_data, **kwargs) - return write.Write( - update=document.Document( - name=document_path, fields=_helpers.encode_dict(update_values) - ), - update_mask=common.DocumentMask(field_paths=field_paths), - current_document=common.Precondition(exists=True), - ) + # Verify the response and the mocks. 
+ assert write_result is mock.sentinel.write_result + write_pb = _write_pb_for_create(document._document_path, document_data) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) - def _update_helper(self, retry=None, timeout=None, **option_kwargs): - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.transforms import DELETE_FIELD - # Create a minimal fake GAPIC with a dummy response. - firestore_api = mock.Mock(spec=["commit"]) - firestore_api.commit.return_value = self._make_commit_repsonse() +def test_documentreference_create(): + _create_helper() - # Attach the fake GAPIC to a real client. - client = _make_client("potato-chip") - client._firestore_api_internal = firestore_api - # Actually make a document and call create(). - document = self._make_one("baked", "Alaska", client=client) - # "Cheat" and use OrderedDict-s so that iteritems() is deterministic. - field_updates = collections.OrderedDict( - (("hello", 1), ("then.do", False), ("goodbye", DELETE_FIELD)) - ) - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - if option_kwargs: - option = client.write_option(**option_kwargs) - write_result = document.update(field_updates, option=option, **kwargs) - else: - option = None - write_result = document.update(field_updates, **kwargs) - - # Verify the response and the mocks. 
- self.assertIs(write_result, mock.sentinel.write_result) - update_values = { - "hello": field_updates["hello"], - "then": {"do": field_updates["then.do"]}, - } - field_paths = list(field_updates.keys()) - write_pb = self._write_pb_for_update( - document._document_path, update_values, sorted(field_paths) - ) - if option is not None: - option.modify_write(write_pb) - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": [write_pb], - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) +def test_documentreference_create_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + _create_helper(retry=retry, timeout=timeout) - def test_update_with_exists(self): - with self.assertRaises(ValueError): - self._update_helper(exists=True) - - def test_update(self): - self._update_helper() - - def test_update_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - self._update_helper(retry=retry, timeout=timeout) - - def test_update_with_precondition(self): - from google.protobuf import timestamp_pb2 - - timestamp = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) - self._update_helper(last_update_time=timestamp) - - def test_empty_update(self): - # Create a minimal fake GAPIC with a dummy response. - firestore_api = mock.Mock(spec=["commit"]) - firestore_api.commit.return_value = self._make_commit_repsonse() - - # Attach the fake GAPIC to a real client. - client = _make_client("potato-chip") - client._firestore_api_internal = firestore_api - - # Actually make a document and call create(). - document = self._make_one("baked", "Alaska", client=client) - # "Cheat" and use OrderedDict-s so that iteritems() is deterministic. 
- field_updates = {} - with self.assertRaises(ValueError): - document.update(field_updates) - - def _delete_helper(self, retry=None, timeout=None, **option_kwargs): - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import write - - # Create a minimal fake GAPIC with a dummy response. - firestore_api = mock.Mock(spec=["commit"]) - firestore_api.commit.return_value = self._make_commit_repsonse() - - # Attach the fake GAPIC to a real client. - client = _make_client("donut-base") - client._firestore_api_internal = firestore_api - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - # Actually make a document and call delete(). - document = self._make_one("where", "we-are", client=client) - if option_kwargs: - option = client.write_option(**option_kwargs) - delete_time = document.delete(option=option, **kwargs) - else: - option = None - delete_time = document.delete(**kwargs) - - # Verify the response and the mocks. - self.assertIs(delete_time, mock.sentinel.commit_time) - write_pb = write.Write(delete=document._document_path) - if option is not None: - option.modify_write(write_pb) - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": [write_pb], - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) - def test_delete(self): - self._delete_helper() - - def test_delete_with_option(self): - from google.protobuf import timestamp_pb2 - - timestamp_pb = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) - self._delete_helper(last_update_time=timestamp_pb) - - def test_delete_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - self._delete_helper(retry=retry, timeout=timeout) - - def _get_helper( - self, - field_paths=None, - use_transaction=False, - not_found=False, - # This should be an impossible case, but we test against it for - # completeness - return_empty=False, 
- retry=None, - timeout=None, - ): - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import document - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.transaction import Transaction - - # Create a minimal fake GAPIC with a dummy response. - create_time = 123 - update_time = 234 - read_time = 345 - firestore_api = mock.Mock(spec=["batch_get_documents"]) - response = mock.create_autospec(firestore.BatchGetDocumentsResponse) - response.read_time = read_time - response.found = mock.create_autospec(document.Document) - response.found.fields = {} - response.found.create_time = create_time - response.found.update_time = update_time - - client = _make_client("donut-base") - client._firestore_api_internal = firestore_api - document_reference = self._make_one("where", "we-are", client=client) - - response.found.name = None if not_found else document_reference._document_path - response.missing = document_reference._document_path if not_found else None - - def WhichOneof(val): - return "missing" if not_found else "found" - - response._pb = response - response._pb.WhichOneof = WhichOneof - firestore_api.batch_get_documents.return_value = iter( - [response] if not return_empty else [] +def test_documentreference_create_empty(): + # Create a minimal fake GAPIC with a dummy response. + from google.cloud.firestore_v1.document import DocumentReference + from google.cloud.firestore_v1.document import DocumentSnapshot + + firestore_api = mock.Mock(spec=["commit"]) + document_reference = mock.create_autospec(DocumentReference) + snapshot = mock.create_autospec(DocumentSnapshot) + snapshot.exists = True + document_reference.get.return_value = snapshot + firestore_api.commit.return_value = _make_commit_repsonse( + write_results=[document_reference] + ) + + # Attach the fake GAPIC to a real client. 
+    client = _make_client("dignity")
+    client._firestore_api_internal = firestore_api
+    client.get_all = mock.MagicMock()
+    client.get_all.exists.return_value = True
+
+    # Actually make a document and call create().
+    document = _make_document_reference("foo", "twelve", client=client)
+    document_data = {}
+    write_result = document.create(document_data)
+    assert write_result.get().exists
+
+
+def _write_pb_for_set(document_path, document_data, merge):
+    from google.cloud.firestore_v1.types import common
+    from google.cloud.firestore_v1.types import document
+    from google.cloud.firestore_v1.types import write
+    from google.cloud.firestore_v1 import _helpers
+
+    write_pbs = write.Write(
+        update=document.Document(
+            name=document_path, fields=_helpers.encode_dict(document_data)
+        )
+    )
+    if merge:
+        field_paths = [
+            field_path
+            for field_path, value in _helpers.extract_fields(
+                document_data, _helpers.FieldPath()
+            )
+        ]
+        field_paths = [field_path.to_api_repr() for field_path in sorted(field_paths)]
+        mask = common.DocumentMask(field_paths=sorted(field_paths))
+        write_pbs._pb.update_mask.CopyFrom(mask._pb)
+    return write_pbs
+
+
+def _set_helper(merge=False, retry=None, timeout=None, **option_kwargs):
+    from google.cloud.firestore_v1 import _helpers
+
+    # Create a minimal fake GAPIC with a dummy response.
+    firestore_api = mock.Mock(spec=["commit"])
+    firestore_api.commit.return_value = _make_commit_repsonse()
+
+    # Attach the fake GAPIC to a real client.
+    client = _make_client("db-dee-bee")
+    client._firestore_api_internal = firestore_api
+
+    # Actually make a document and call set().
+    document = _make_document_reference("User", "Interface", client=client)
+    document_data = {"And": 500, "Now": b"\xba\xaa\xaa \xba\xaa\xaa"}
+    kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
+
+    write_result = document.set(document_data, merge, **kwargs)
+
+    # Verify the response and the mocks.
+ assert write_result is mock.sentinel.write_result + write_pb = _write_pb_for_set(document._document_path, document_data, merge) + + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) - if use_transaction: - transaction = Transaction(client) - transaction_id = transaction._id = b"asking-me-2" - else: - transaction = None - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) +def test_documentreference_set(): + _set_helper() + + +def test_documentreference_set_w_retry_timeout(): + from google.api_core.retry import Retry - snapshot = document_reference.get( - field_paths=field_paths, transaction=transaction, **kwargs - ) + retry = Retry(predicate=object()) + timeout = 123.0 + _set_helper(retry=retry, timeout=timeout) - self.assertIs(snapshot.reference, document_reference) - if not_found or return_empty: - self.assertIsNone(snapshot._data) - self.assertFalse(snapshot.exists) - self.assertIsNotNone(snapshot.read_time) - self.assertIsNone(snapshot.create_time) - self.assertIsNone(snapshot.update_time) - else: - self.assertEqual(snapshot.to_dict(), {}) - self.assertTrue(snapshot.exists) - self.assertIs(snapshot.read_time, read_time) - self.assertIs(snapshot.create_time, create_time) - self.assertIs(snapshot.update_time, update_time) - - # Verify the request made to the API - if field_paths is not None: - mask = common.DocumentMask(field_paths=sorted(field_paths)) - else: - mask = None - - if use_transaction: - expected_transaction_id = transaction_id - else: - expected_transaction_id = None - - firestore_api.batch_get_documents.assert_called_once_with( - request={ - "database": client._database_string, - "documents": [document_reference._document_path], - "mask": mask, - "transaction": expected_transaction_id, - }, - metadata=client._rpc_metadata, - **kwargs, - ) - def test_get_not_found(self): - 
self._get_helper(not_found=True) +def test_documentreference_set_merge(): + _set_helper(merge=True) - def test_get_default(self): - self._get_helper() - def test_get_return_empty(self): - self._get_helper(return_empty=True) +def _write_pb_for_update(document_path, update_values, field_paths): + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import write + from google.cloud.firestore_v1 import _helpers - def test_get_w_retry_timeout(self): - from google.api_core.retry import Retry + return write.Write( + update=document.Document( + name=document_path, fields=_helpers.encode_dict(update_values) + ), + update_mask=common.DocumentMask(field_paths=field_paths), + current_document=common.Precondition(exists=True), + ) - retry = Retry(predicate=object()) - timeout = 123.0 - self._get_helper(retry=retry, timeout=timeout) - def test_get_w_string_field_path(self): - with self.assertRaises(ValueError): - self._get_helper(field_paths="foo") +def _update_helper(retry=None, timeout=None, **option_kwargs): + from collections import OrderedDict + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.transforms import DELETE_FIELD - def test_get_with_field_path(self): - self._get_helper(field_paths=["foo"]) + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["commit"]) + firestore_api.commit.return_value = _make_commit_repsonse() + + # Attach the fake GAPIC to a real client. + client = _make_client("potato-chip") + client._firestore_api_internal = firestore_api + + # Actually make a document and call create(). + document = _make_document_reference("baked", "Alaska", client=client) + # "Cheat" and use OrderedDict-s so that iteritems() is deterministic. 
+ field_updates = OrderedDict( + (("hello", 1), ("then.do", False), ("goodbye", DELETE_FIELD)) + ) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + if option_kwargs: + option = client.write_option(**option_kwargs) + write_result = document.update(field_updates, option=option, **kwargs) + else: + option = None + write_result = document.update(field_updates, **kwargs) + + # Verify the response and the mocks. + assert write_result is mock.sentinel.write_result + update_values = { + "hello": field_updates["hello"], + "then": {"do": field_updates["then.do"]}, + } + field_paths = list(field_updates.keys()) + write_pb = _write_pb_for_update( + document._document_path, update_values, sorted(field_paths) + ) + if option is not None: + option.modify_write(write_pb) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) - def test_get_with_multiple_field_paths(self): - self._get_helper(field_paths=["foo", "bar.baz"]) - def test_get_with_transaction(self): - self._get_helper(use_transaction=True) +def test_documentreference_update_with_exists(): + with pytest.raises(ValueError): + _update_helper(exists=True) + + +def test_documentreference_update(): + _update_helper() + + +def test_documentreference_update_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + _update_helper(retry=retry, timeout=timeout) + + +def test_documentreference_update_with_precondition(): + from google.protobuf import timestamp_pb2 + + timestamp = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) + _update_helper(last_update_time=timestamp) + + +def test_documentreference_empty_update(): + # Create a minimal fake GAPIC with a dummy response. 
+    firestore_api = mock.Mock(spec=["commit"])
+    firestore_api.commit.return_value = _make_commit_repsonse()
+
+    # Attach the fake GAPIC to a real client.
+    client = _make_client("potato-chip")
+    client._firestore_api_internal = firestore_api
+
+    # Actually make a document and call update().
+    document = _make_document_reference("baked", "Alaska", client=client)
+    # "Cheat" and use OrderedDict-s so that iteritems() is deterministic.
+    field_updates = {}
+    with pytest.raises(ValueError):
+        document.update(field_updates)
+
+
+def _delete_helper(retry=None, timeout=None, **option_kwargs):
+    from google.cloud.firestore_v1 import _helpers
+    from google.cloud.firestore_v1.types import write
+
+    # Create a minimal fake GAPIC with a dummy response.
+    firestore_api = mock.Mock(spec=["commit"])
+    firestore_api.commit.return_value = _make_commit_repsonse()
+
+    # Attach the fake GAPIC to a real client.
+    client = _make_client("donut-base")
+    client._firestore_api_internal = firestore_api
+    kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
+
+    # Actually make a document and call delete().
+    document = _make_document_reference("where", "we-are", client=client)
+    if option_kwargs:
+        option = client.write_option(**option_kwargs)
+        delete_time = document.delete(option=option, **kwargs)
+    else:
+        option = None
+        delete_time = document.delete(**kwargs)
+
+    # Verify the response and the mocks.
+ assert delete_time is mock.sentinel.commit_time + write_pb = write.Write(delete=document._document_path) + if option is not None: + option.modify_write(write_pb) + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": [write_pb], + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +def test_documentreference_delete(): + _delete_helper() + + +def test_documentreference_delete_with_option(): + from google.protobuf import timestamp_pb2 + + timestamp_pb = timestamp_pb2.Timestamp(seconds=1058655101, nanos=100022244) + _delete_helper(last_update_time=timestamp_pb) + + +def test_documentreference_delete_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + _delete_helper(retry=retry, timeout=timeout) + + +def _get_helper( + field_paths=None, + use_transaction=False, + not_found=False, + # This should be an impossible case, but we test against it for + # completeness + return_empty=False, + retry=None, + timeout=None, +): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import document + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.transaction import Transaction + + # Create a minimal fake GAPIC with a dummy response. 
+ create_time = 123 + update_time = 234 + read_time = 345 + firestore_api = mock.Mock(spec=["batch_get_documents"]) + response = mock.create_autospec(firestore.BatchGetDocumentsResponse) + response.read_time = read_time + response.found = mock.create_autospec(document.Document) + response.found.fields = {} + response.found.create_time = create_time + response.found.update_time = update_time + + client = _make_client("donut-base") + client._firestore_api_internal = firestore_api + document_reference = _make_document_reference("where", "we-are", client=client) + + response.found.name = None if not_found else document_reference._document_path + response.missing = document_reference._document_path if not_found else None + + def WhichOneof(val): + return "missing" if not_found else "found" + + response._pb = response + response._pb.WhichOneof = WhichOneof + firestore_api.batch_get_documents.return_value = iter( + [response] if not return_empty else [] + ) + + if use_transaction: + transaction = Transaction(client) + transaction_id = transaction._id = b"asking-me-2" + else: + transaction = None + + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + snapshot = document_reference.get( + field_paths=field_paths, transaction=transaction, **kwargs + ) + + assert snapshot.reference is document_reference + if not_found or return_empty: + assert snapshot._data is None + assert not snapshot.exists + assert snapshot.read_time is not None + assert snapshot.create_time is None + assert snapshot.update_time is None + else: + assert snapshot.to_dict() == {} + assert snapshot.exists + assert snapshot.read_time is read_time + assert snapshot.create_time is create_time + assert snapshot.update_time is update_time - def _collections_helper(self, page_size=None, retry=None, timeout=None): - from google.cloud.firestore_v1.collection import CollectionReference - from google.cloud.firestore_v1 import _helpers - from google.cloud.firestore_v1.services.firestore.client import 
FirestoreClient + # Verify the request made to the API + if field_paths is not None: + mask = common.DocumentMask(field_paths=sorted(field_paths)) + else: + mask = None - collection_ids = ["coll-1", "coll-2"] + if use_transaction: + expected_transaction_id = transaction_id + else: + expected_transaction_id = None - class Pager(object): - def __iter__(self): - yield from collection_ids + firestore_api.batch_get_documents.assert_called_once_with( + request={ + "database": client._database_string, + "documents": [document_reference._document_path], + "mask": mask, + "transaction": expected_transaction_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) - api_client = mock.create_autospec(FirestoreClient) - api_client.list_collection_ids.return_value = Pager() - client = _make_client() - client._firestore_api_internal = api_client - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) +def test_documentreference_get_not_found(): + _get_helper(not_found=True) - # Actually make a document and call delete(). - document = self._make_one("where", "we-are", client=client) - if page_size is not None: - collections = list(document.collections(page_size=page_size, **kwargs)) - else: - collections = list(document.collections(**kwargs)) - # Verify the response and the mocks. 
- self.assertEqual(len(collections), len(collection_ids)) - for collection, collection_id in zip(collections, collection_ids): - self.assertIsInstance(collection, CollectionReference) - self.assertEqual(collection.parent, document) - self.assertEqual(collection.id, collection_id) +def test_documentreference_get_default(): + _get_helper() - api_client.list_collection_ids.assert_called_once_with( - request={"parent": document._document_path, "page_size": page_size}, - metadata=client._rpc_metadata, - **kwargs, - ) - def test_collections_wo_page_size(self): - self._collections_helper() +def test_documentreference_get_return_empty(): + _get_helper(return_empty=True) + + +def test_documentreference_get_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + _get_helper(retry=retry, timeout=timeout) + + +def test_documentreference_get_w_string_field_path(): + with pytest.raises(ValueError): + _get_helper(field_paths="foo") + + +def test_documentreference_get_with_field_path(): + _get_helper(field_paths=["foo"]) + + +def test_documentreference_get_with_multiple_field_paths(): + _get_helper(field_paths=["foo", "bar.baz"]) + + +def test_documentreference_get_with_transaction(): + _get_helper(use_transaction=True) + + +def _collections_helper(page_size=None, retry=None, timeout=None): + from google.cloud.firestore_v1.collection import CollectionReference + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.services.firestore.client import FirestoreClient + + collection_ids = ["coll-1", "coll-2"] + + class Pager(object): + def __iter__(self): + yield from collection_ids + + api_client = mock.create_autospec(FirestoreClient) + api_client.list_collection_ids.return_value = Pager() + + client = _make_client() + client._firestore_api_internal = api_client + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Actually make a document and call delete(). 
+ document = _make_document_reference("where", "we-are", client=client) + if page_size is not None: + collections = list(document.collections(page_size=page_size, **kwargs)) + else: + collections = list(document.collections(**kwargs)) + + # Verify the response and the mocks. + assert len(collections) == len(collection_ids) + for collection, collection_id in zip(collections, collection_ids): + assert isinstance(collection, CollectionReference) + assert collection.parent == document + assert collection.id == collection_id + + api_client.list_collection_ids.assert_called_once_with( + request={"parent": document._document_path, "page_size": page_size}, + metadata=client._rpc_metadata, + **kwargs, + ) + + +def test_documentreference_collections_wo_page_size(): + _collections_helper() + + +def test_documentreference_collections_w_page_size(): + _collections_helper(page_size=10) + - def test_collections_w_page_size(self): - self._collections_helper(page_size=10) +def test_documentreference_collections_w_retry_timeout(): + from google.api_core.retry import Retry - def test_collections_w_retry_timeout(self): - from google.api_core.retry import Retry + retry = Retry(predicate=object()) + timeout = 123.0 + _collections_helper(retry=retry, timeout=timeout) - retry = Retry(predicate=object()) - timeout = 123.0 - self._collections_helper(retry=retry, timeout=timeout) - @mock.patch("google.cloud.firestore_v1.document.Watch", autospec=True) - def test_on_snapshot(self, watch): - client = mock.Mock(_database_string="sprinklez", spec=["_database_string"]) - document = self._make_one("yellow", "mellow", client=client) - document.on_snapshot(None) - watch.for_document.assert_called_once() +@mock.patch("google.cloud.firestore_v1.document.Watch", autospec=True) +def test_documentreference_on_snapshot(watch): + client = mock.Mock(_database_string="sprinklez", spec=["_database_string"]) + document = _make_document_reference("yellow", "mellow", client=client) + document.on_snapshot(None) + 
watch.for_document.assert_called_once() def _make_credentials(): diff --git a/tests/unit/v1/test_field_path.py b/tests/unit/v1/test_field_path.py index 55aefab4c152a..5efbadbd3a6e1 100644 --- a/tests/unit/v1/test_field_path.py +++ b/tests/unit/v1/test_field_path.py @@ -13,488 +13,617 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - import mock +import pytest + + +def _expect_tokenize_field_path(path, split_path): + from google.cloud.firestore_v1 import field_path + + assert list(field_path._tokenize_field_path(path)) == split_path + + +def test__tokenize_field_path_w_empty(): + _expect_tokenize_field_path("", []) + + +def test__tokenize_field_path_w_single_dot(): + _expect_tokenize_field_path(".", ["."]) + + +def test__tokenize_field_path_w_single_simple(): + _expect_tokenize_field_path("abc", ["abc"]) + + +def test__tokenize_field_path_w_single_quoted(): + _expect_tokenize_field_path("`c*de`", ["`c*de`"]) + + +def test__tokenize_field_path_w_quoted_embedded_dot(): + _expect_tokenize_field_path("`c*.de`", ["`c*.de`"]) + + +def test__tokenize_field_path_w_quoted_escaped_backtick(): + _expect_tokenize_field_path(r"`c*\`de`", [r"`c*\`de`"]) + + +def test__tokenize_field_path_w_dotted_quoted(): + _expect_tokenize_field_path("`*`.`~`", ["`*`", ".", "`~`"]) + + +def test__tokenize_field_path_w_dotted(): + _expect_tokenize_field_path("a.b.`c*de`", ["a", ".", "b", ".", "`c*de`"]) + + +def test__tokenize_field_path_w_dotted_escaped(): + _expect_tokenize_field_path("_0.`1`.`+2`", ["_0", ".", "`1`", ".", "`+2`"]) + + +def test__tokenize_field_path_w_unconsumed_characters(): + from google.cloud.firestore_v1 import field_path + + path = "a~b" + with pytest.raises(ValueError): + list(field_path._tokenize_field_path(path)) + + +def test_split_field_path_w_single_dot(): + from google.cloud.firestore_v1 import field_path + + with pytest.raises(ValueError): + field_path.split_field_path(".") + + +def 
test_split_field_path_w_leading_dot(): + from google.cloud.firestore_v1 import field_path + + with pytest.raises(ValueError): + field_path.split_field_path(".a.b.c") + + +def test_split_field_path_w_trailing_dot(): + from google.cloud.firestore_v1 import field_path + + with pytest.raises(ValueError): + field_path.split_field_path("a.b.") + + +def test_split_field_path_w_missing_dot(): + from google.cloud.firestore_v1 import field_path + + with pytest.raises(ValueError): + field_path.split_field_path("a`c*de`f") + + +def test_split_field_path_w_half_quoted_field(): + from google.cloud.firestore_v1 import field_path + + with pytest.raises(ValueError): + field_path.split_field_path("`c*de") + + +def test_split_field_path_w_empty(): + from google.cloud.firestore_v1 import field_path + + assert field_path.split_field_path("") == [] + + +def test_split_field_path_w_simple_field(): + from google.cloud.firestore_v1 import field_path + + assert field_path.split_field_path("a") == ["a"] + + +def test_split_field_path_w_dotted_field(): + from google.cloud.firestore_v1 import field_path + + assert field_path.split_field_path("a.b.cde") == ["a", "b", "cde"] + + +def test_split_field_path_w_quoted_field(): + from google.cloud.firestore_v1 import field_path + + assert field_path.split_field_path("a.b.`c*de`") == ["a", "b", "`c*de`"] + + +def test_split_field_path_w_quoted_field_escaped_backtick(): + from google.cloud.firestore_v1 import field_path + + assert field_path.split_field_path(r"`c*\`de`") == [r"`c*\`de`"] + + +def test_parse_field_path_wo_escaped_names(): + from google.cloud.firestore_v1 import field_path + + assert field_path.parse_field_path("a.b.c") == ["a", "b", "c"] + + +def test_parse_field_path_w_escaped_backtick(): + from google.cloud.firestore_v1 import field_path + + assert field_path.parse_field_path("`a\\`b`.c.d") == ["a`b", "c", "d"] + + +def test_parse_field_path_w_escaped_backslash(): + from google.cloud.firestore_v1 import field_path + + assert 
field_path.parse_field_path("`a\\\\b`.c.d") == ["a\\b", "c", "d"] + + +def test_parse_field_path_w_first_name_escaped_wo_closing_backtick(): + from google.cloud.firestore_v1 import field_path + + with pytest.raises(ValueError): + field_path.parse_field_path("`a\\`b.c.d") + + +def test_render_field_path_w_empty(): + from google.cloud.firestore_v1 import field_path + + assert field_path.render_field_path([]) == "" + + +def test_render_field_path_w_one_simple(): + from google.cloud.firestore_v1 import field_path + + assert field_path.render_field_path(["a"]) == "a" + + +def test_render_field_path_w_one_starts_w_digit(): + from google.cloud.firestore_v1 import field_path + + assert field_path.render_field_path(["0abc"]) == "`0abc`" + + +def test_render_field_path_w_one_w_non_alphanum(): + from google.cloud.firestore_v1 import field_path + + assert field_path.render_field_path(["a b c"]) == "`a b c`" + + +def test_render_field_path_w_one_w_backtick(): + from google.cloud.firestore_v1 import field_path + + assert field_path.render_field_path(["a`b"]) == "`a\\`b`" + + +def test_render_field_path_w_one_w_backslash(): + from google.cloud.firestore_v1 import field_path + + assert field_path.render_field_path(["a\\b"]) == "`a\\\\b`" + + +def test_render_field_path_multiple(): + from google.cloud.firestore_v1 import field_path + + assert field_path.render_field_path(["a", "b", "c"]) == "a.b.c" + + +DATA = { + "top1": {"middle2": {"bottom3": 20, "bottom4": 22}, "middle5": True}, + "top6": b"\x00\x01 foo", +} + + +def test_get_nested_value_simple(): + from google.cloud.firestore_v1 import field_path + + assert field_path.get_nested_value("top1", DATA) is DATA["top1"] + + +def test_get_nested_value_nested(): + from google.cloud.firestore_v1 import field_path + + assert field_path.get_nested_value("top1.middle2", DATA) is DATA["top1"]["middle2"] + assert ( + field_path.get_nested_value("top1.middle2.bottom3", DATA) + is DATA["top1"]["middle2"]["bottom3"] + ) + + +def 
test_get_nested_value_missing_top_level(): + from google.cloud.firestore_v1 import field_path + from google.cloud.firestore_v1.field_path import _FIELD_PATH_MISSING_TOP + + path = "top8" + with pytest.raises(KeyError) as exc_info: + field_path.get_nested_value(path, DATA) + + err_msg = _FIELD_PATH_MISSING_TOP.format(path) + assert exc_info.value.args == (err_msg,) + + +def test_get_nested_value_missing_key(): + from google.cloud.firestore_v1 import field_path + from google.cloud.firestore_v1.field_path import _FIELD_PATH_MISSING_KEY + + with pytest.raises(KeyError) as exc_info: + field_path.get_nested_value("top1.middle2.nope", DATA) + + err_msg = _FIELD_PATH_MISSING_KEY.format("nope", "top1.middle2") + assert exc_info.value.args == (err_msg,) + + +def test_get_nested_value_bad_type(): + from google.cloud.firestore_v1 import field_path + from google.cloud.firestore_v1.field_path import _FIELD_PATH_WRONG_TYPE + + with pytest.raises(KeyError) as exc_info: + field_path.get_nested_value("top6.middle7", DATA) + + err_msg = _FIELD_PATH_WRONG_TYPE.format("top6", "middle7") + assert exc_info.value.args == (err_msg,) + + +def _make_field_path(*args, **kwargs): + from google.cloud.firestore_v1 import field_path + + return field_path.FieldPath(*args, **kwargs) + + +def test_fieldpath_ctor_w_none_in_part(): + with pytest.raises(ValueError): + _make_field_path("a", None, "b") + + +def test_fieldpath_ctor_w_empty_string_in_part(): + with pytest.raises(ValueError): + _make_field_path("a", "", "b") + + +def test_fieldpath_ctor_w_integer_part(): + with pytest.raises(ValueError): + _make_field_path("a", 3, "b") + + +def test_fieldpath_ctor_w_list(): + parts = ["a", "b", "c"] + with pytest.raises(ValueError): + _make_field_path(parts) + + +def test_fieldpath_ctor_w_tuple(): + parts = ("a", "b", "c") + with pytest.raises(ValueError): + _make_field_path(parts) + + +def test_fieldpath_ctor_w_iterable_part(): + with pytest.raises(ValueError): + _make_field_path("a", ["a"], "b") + + +def 
test_fieldpath_constructor_w_single_part(): + field_path = _make_field_path("a") + assert field_path.parts == ("a",) + + +def test_fieldpath_constructor_w_multiple_parts(): + field_path = _make_field_path("a", "b", "c") + assert field_path.parts == ("a", "b", "c") + + +def test_fieldpath_ctor_w_invalid_chars_in_part(): + invalid_parts = ("~", "*", "/", "[", "]", ".") + for invalid_part in invalid_parts: + field_path = _make_field_path(invalid_part) + assert field_path.parts == (invalid_part,) + + +def test_fieldpath_ctor_w_double_dots(): + field_path = _make_field_path("a..b") + assert field_path.parts == ("a..b",) + + +def test_fieldpath_ctor_w_unicode(): + field_path = _make_field_path("一", "二", "三") + assert field_path.parts == ("一", "二", "三") + + +def test_fieldpath_from_api_repr_w_empty_string(): + from google.cloud.firestore_v1 import field_path + + api_repr = "" + with pytest.raises(ValueError): + field_path.FieldPath.from_api_repr(api_repr) + + +def test_fieldpath_from_api_repr_w_empty_field_name(): + from google.cloud.firestore_v1 import field_path + + api_repr = "a..b" + with pytest.raises(ValueError): + field_path.FieldPath.from_api_repr(api_repr) + + +def test_fieldpath_from_api_repr_w_invalid_chars(): + from google.cloud.firestore_v1 import field_path + + invalid_parts = ("~", "*", "/", "[", "]", ".") + for invalid_part in invalid_parts: + with pytest.raises(ValueError): + field_path.FieldPath.from_api_repr(invalid_part) + + +def test_fieldpath_from_api_repr_w_ascii_single(): + from google.cloud.firestore_v1 import field_path + + api_repr = "a" + field_path = field_path.FieldPath.from_api_repr(api_repr) + assert field_path.parts == ("a",) + + +def test_fieldpath_from_api_repr_w_ascii_dotted(): + from google.cloud.firestore_v1 import field_path + + api_repr = "a.b.c" + field_path = field_path.FieldPath.from_api_repr(api_repr) + assert field_path.parts == ("a", "b", "c") + + +def test_fieldpath_from_api_repr_w_non_ascii_dotted_non_quoted(): + from 
google.cloud.firestore_v1 import field_path + + api_repr = "a.一" + with pytest.raises(ValueError): + field_path.FieldPath.from_api_repr(api_repr) + + +def test_fieldpath_from_api_repr_w_non_ascii_dotted_quoted(): + from google.cloud.firestore_v1 import field_path + + api_repr = "a.`一`" + field_path = field_path.FieldPath.from_api_repr(api_repr) + assert field_path.parts == ("a", "一") + + +def test_fieldpath_from_string_w_empty_string(): + from google.cloud.firestore_v1 import field_path + + path_string = "" + with pytest.raises(ValueError): + field_path.FieldPath.from_string(path_string) + + +def test_fieldpath_from_string_w_empty_field_name(): + from google.cloud.firestore_v1 import field_path + + path_string = "a..b" + with pytest.raises(ValueError): + field_path.FieldPath.from_string(path_string) + + +def test_fieldpath_from_string_w_leading_dot(): + from google.cloud.firestore_v1 import field_path + + path_string = ".b.c" + with pytest.raises(ValueError): + field_path.FieldPath.from_string(path_string) + + +def test_fieldpath_from_string_w_trailing_dot(): + from google.cloud.firestore_v1 import field_path + + path_string = "a.b." 
+ with pytest.raises(ValueError): + field_path.FieldPath.from_string(path_string) + +def test_fieldpath_from_string_w_leading_invalid_chars(): + from google.cloud.firestore_v1 import field_path -class Test__tokenize_field_path(unittest.TestCase): - @staticmethod - def _call_fut(path): - from google.cloud.firestore_v1 import field_path + invalid_paths = ("~", "*", "/", "[", "]") + for invalid_path in invalid_paths: + path = field_path.FieldPath.from_string(invalid_path) + assert path.parts == (invalid_path,) - return field_path._tokenize_field_path(path) - def _expect(self, path, split_path): - self.assertEqual(list(self._call_fut(path)), split_path) +def test_fieldpath_from_string_w_embedded_invalid_chars(): + from google.cloud.firestore_v1 import field_path - def test_w_empty(self): - self._expect("", []) + invalid_paths = ("a~b", "x*y", "f/g", "h[j", "k]l") + for invalid_path in invalid_paths: + with pytest.raises(ValueError): + field_path.FieldPath.from_string(invalid_path) - def test_w_single_dot(self): - self._expect(".", ["."]) - def test_w_single_simple(self): - self._expect("abc", ["abc"]) +def test_fieldpath_from_string_w_ascii_single(): + from google.cloud.firestore_v1 import field_path - def test_w_single_quoted(self): - self._expect("`c*de`", ["`c*de`"]) + path_string = "a" + field_path = field_path.FieldPath.from_string(path_string) + assert field_path.parts == ("a",) - def test_w_quoted_embedded_dot(self): - self._expect("`c*.de`", ["`c*.de`"]) - def test_w_quoted_escaped_backtick(self): - self._expect(r"`c*\`de`", [r"`c*\`de`"]) +def test_fieldpath_from_string_w_ascii_dotted(): + from google.cloud.firestore_v1 import field_path - def test_w_dotted_quoted(self): - self._expect("`*`.`~`", ["`*`", ".", "`~`"]) + path_string = "a.b.c" + field_path = field_path.FieldPath.from_string(path_string) + assert field_path.parts == ("a", "b", "c") - def test_w_dotted(self): - self._expect("a.b.`c*de`", ["a", ".", "b", ".", "`c*de`"]) - def 
test_w_dotted_escaped(self): - self._expect("_0.`1`.`+2`", ["_0", ".", "`1`", ".", "`+2`"]) +def test_fieldpath_from_string_w_non_ascii_dotted(): + from google.cloud.firestore_v1 import field_path - def test_w_unconsumed_characters(self): - path = "a~b" - with self.assertRaises(ValueError): - list(self._call_fut(path)) + path_string = "a.一" + field_path = field_path.FieldPath.from_string(path_string) + assert field_path.parts == ("a", "一") -class Test_split_field_path(unittest.TestCase): - @staticmethod - def _call_fut(path): - from google.cloud.firestore_v1 import field_path +def test_fieldpath___hash___w_single_part(): + field_path = _make_field_path("a") + assert hash(field_path) == hash("a") - return field_path.split_field_path(path) - def test_w_single_dot(self): - with self.assertRaises(ValueError): - self._call_fut(".") +def test_fieldpath___hash___w_multiple_parts(): + field_path = _make_field_path("a", "b") + assert hash(field_path) == hash("a.b") - def test_w_leading_dot(self): - with self.assertRaises(ValueError): - self._call_fut(".a.b.c") - def test_w_trailing_dot(self): - with self.assertRaises(ValueError): - self._call_fut("a.b.") +def test_fieldpath___hash___w_escaped_parts(): + field_path = _make_field_path("a", "3") + assert hash(field_path) == hash("a.`3`") - def test_w_missing_dot(self): - with self.assertRaises(ValueError): - self._call_fut("a`c*de`f") - def test_w_half_quoted_field(self): - with self.assertRaises(ValueError): - self._call_fut("`c*de") +def test_fieldpath___eq___w_matching_type(): + from google.cloud.firestore_v1 import field_path - def test_w_empty(self): - self.assertEqual(self._call_fut(""), []) + path = _make_field_path("a", "b") + string_path = field_path.FieldPath.from_string("a.b") + assert path == string_path - def test_w_simple_field(self): - self.assertEqual(self._call_fut("a"), ["a"]) - def test_w_dotted_field(self): - self.assertEqual(self._call_fut("a.b.cde"), ["a", "b", "cde"]) +def 
test_fieldpath___eq___w_non_matching_type(): + field_path = _make_field_path("a", "c") + other = mock.Mock() + other.parts = "a", "b" + assert field_path != other - def test_w_quoted_field(self): - self.assertEqual(self._call_fut("a.b.`c*de`"), ["a", "b", "`c*de`"]) - def test_w_quoted_field_escaped_backtick(self): - self.assertEqual(self._call_fut(r"`c*\`de`"), [r"`c*\`de`"]) +def test_fieldpath___lt___w_matching_type(): + from google.cloud.firestore_v1 import field_path + path = _make_field_path("a", "b") + string_path = field_path.FieldPath.from_string("a.c") + assert path < string_path -class Test_parse_field_path(unittest.TestCase): - @staticmethod - def _call_fut(path): - from google.cloud.firestore_v1 import field_path - return field_path.parse_field_path(path) +def test_fieldpath___lt___w_non_matching_type(): + field_path = _make_field_path("a", "b") + other = object() + # Python 2 doesn't raise TypeError here, but Python3 does. + assert field_path.__lt__(other) is NotImplemented - def test_wo_escaped_names(self): - self.assertEqual(self._call_fut("a.b.c"), ["a", "b", "c"]) - def test_w_escaped_backtick(self): - self.assertEqual(self._call_fut("`a\\`b`.c.d"), ["a`b", "c", "d"]) +def test_fieldpath___add__(): + path1 = "a123", "b456" + path2 = "c789", "d012" + path3 = "c789.d012" + field_path1 = _make_field_path(*path1) + field_path1_string = _make_field_path(*path1) + field_path2 = _make_field_path(*path2) + field_path1 += field_path2 + field_path1_string += path3 + field_path2 = field_path2 + _make_field_path(*path1) + assert field_path1 == _make_field_path(*(path1 + path2)) + assert field_path2 == _make_field_path(*(path2 + path1)) + assert field_path1_string == field_path1 + assert field_path1 != field_path2 + with pytest.raises(TypeError): + field_path1 + 305 - def test_w_escaped_backslash(self): - self.assertEqual(self._call_fut("`a\\\\b`.c.d"), ["a\\b", "c", "d"]) - def test_w_first_name_escaped_wo_closing_backtick(self): - with 
self.assertRaises(ValueError): - self._call_fut("`a\\`b.c.d") +def test_fieldpath_to_api_repr_a(): + parts = "a" + field_path = _make_field_path(parts) + assert field_path.to_api_repr() == "a" -class Test_render_field_path(unittest.TestCase): - @staticmethod - def _call_fut(field_names): - from google.cloud.firestore_v1 import field_path +def test_fieldpath_to_api_repr_backtick(): + parts = "`" + field_path = _make_field_path(parts) + assert field_path.to_api_repr() == r"`\``" - return field_path.render_field_path(field_names) - def test_w_empty(self): - self.assertEqual(self._call_fut([]), "") +def test_fieldpath_to_api_repr_dot(): + parts = "." + field_path = _make_field_path(parts) + assert field_path.to_api_repr() == "`.`" - def test_w_one_simple(self): - self.assertEqual(self._call_fut(["a"]), "a") - def test_w_one_starts_w_digit(self): - self.assertEqual(self._call_fut(["0abc"]), "`0abc`") +def test_fieldpath_to_api_repr_slash(): + parts = "\\" + field_path = _make_field_path(parts) + assert field_path.to_api_repr() == r"`\\`" - def test_w_one_w_non_alphanum(self): - self.assertEqual(self._call_fut(["a b c"]), "`a b c`") - def test_w_one_w_backtick(self): - self.assertEqual(self._call_fut(["a`b"]), "`a\\`b`") +def test_fieldpath_to_api_repr_double_slash(): + parts = r"\\" + field_path = _make_field_path(parts) + assert field_path.to_api_repr() == r"`\\\\`" - def test_w_one_w_backslash(self): - self.assertEqual(self._call_fut(["a\\b"]), "`a\\\\b`") - def test_multiple(self): - self.assertEqual(self._call_fut(["a", "b", "c"]), "a.b.c") +def test_fieldpath_to_api_repr_underscore(): + parts = "_33132" + field_path = _make_field_path(parts) + assert field_path.to_api_repr() == "_33132" -class Test_get_nested_value(unittest.TestCase): +def test_fieldpath_to_api_repr_unicode_non_simple(): + parts = "一" + field_path = _make_field_path(parts) + assert field_path.to_api_repr() == "`一`" - DATA = { - "top1": {"middle2": {"bottom3": 20, "bottom4": 22}, "middle5": True}, - 
"top6": b"\x00\x01 foo", - } - @staticmethod - def _call_fut(path, data): - from google.cloud.firestore_v1 import field_path +def test_fieldpath_to_api_repr_number_non_simple(): + parts = "03" + field_path = _make_field_path(parts) + assert field_path.to_api_repr() == "`03`" - return field_path.get_nested_value(path, data) - def test_simple(self): - self.assertIs(self._call_fut("top1", self.DATA), self.DATA["top1"]) +def test_fieldpath_to_api_repr_simple_with_dot(): + field_path = _make_field_path("a.b") + assert field_path.to_api_repr() == "`a.b`" - def test_nested(self): - self.assertIs( - self._call_fut("top1.middle2", self.DATA), self.DATA["top1"]["middle2"] - ) - self.assertIs( - self._call_fut("top1.middle2.bottom3", self.DATA), - self.DATA["top1"]["middle2"]["bottom3"], - ) - def test_missing_top_level(self): - from google.cloud.firestore_v1.field_path import _FIELD_PATH_MISSING_TOP +def test_fieldpath_to_api_repr_non_simple_with_dot(): + parts = "a.一" + field_path = _make_field_path(parts) + assert field_path.to_api_repr() == "`a.一`" - field_path = "top8" - with self.assertRaises(KeyError) as exc_info: - self._call_fut(field_path, self.DATA) - err_msg = _FIELD_PATH_MISSING_TOP.format(field_path) - self.assertEqual(exc_info.exception.args, (err_msg,)) +def test_fieldpath_to_api_repr_simple(): + parts = "a0332432" + field_path = _make_field_path(parts) + assert field_path.to_api_repr() == "a0332432" - def test_missing_key(self): - from google.cloud.firestore_v1.field_path import _FIELD_PATH_MISSING_KEY - with self.assertRaises(KeyError) as exc_info: - self._call_fut("top1.middle2.nope", self.DATA) +def test_fieldpath_to_api_repr_chain(): + parts = "a", "`", "\\", "_3", "03", "a03", "\\\\", "a0332432", "一" + field_path = _make_field_path(*parts) + assert field_path.to_api_repr() == r"a.`\``.`\\`._3.`03`.a03.`\\\\`.a0332432.`一`" - err_msg = _FIELD_PATH_MISSING_KEY.format("nope", "top1.middle2") - self.assertEqual(exc_info.exception.args, (err_msg,)) - def 
test_bad_type(self): - from google.cloud.firestore_v1.field_path import _FIELD_PATH_WRONG_TYPE +def test_fieldpath_eq_or_parent_same(): + field_path = _make_field_path("a", "b") + other = _make_field_path("a", "b") + assert field_path.eq_or_parent(other) - with self.assertRaises(KeyError) as exc_info: - self._call_fut("top6.middle7", self.DATA) - err_msg = _FIELD_PATH_WRONG_TYPE.format("top6", "middle7") - self.assertEqual(exc_info.exception.args, (err_msg,)) +def test_fieldpath_eq_or_parent_prefix(): + field_path = _make_field_path("a", "b") + other = _make_field_path("a", "b", "c") + assert field_path.eq_or_parent(other) + assert other.eq_or_parent(field_path) -class TestFieldPath(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1 import field_path +def test_fieldpath_eq_or_parent_no_prefix(): + field_path = _make_field_path("a", "b") + other = _make_field_path("d", "e", "f") + assert not field_path.eq_or_parent(other) + assert not other.eq_or_parent(field_path) - return field_path.FieldPath - def _make_one(self, *args): - klass = self._get_target_class() - return klass(*args) +def test_fieldpath_lineage_empty(): + field_path = _make_field_path() + expected = set() + assert field_path.lineage() == expected - def test_ctor_w_none_in_part(self): - with self.assertRaises(ValueError): - self._make_one("a", None, "b") - def test_ctor_w_empty_string_in_part(self): - with self.assertRaises(ValueError): - self._make_one("a", "", "b") +def test_fieldpath_lineage_single(): + field_path = _make_field_path("a") + expected = set() + assert field_path.lineage() == expected - def test_ctor_w_integer_part(self): - with self.assertRaises(ValueError): - self._make_one("a", 3, "b") - def test_ctor_w_list(self): - parts = ["a", "b", "c"] - with self.assertRaises(ValueError): - self._make_one(parts) +def test_fieldpath_lineage_nested(): + field_path = _make_field_path("a", "b", "c") + expected = set([_make_field_path("a"), 
_make_field_path("a", "b")]) + assert field_path.lineage() == expected - def test_ctor_w_tuple(self): - parts = ("a", "b", "c") - with self.assertRaises(ValueError): - self._make_one(parts) - def test_ctor_w_iterable_part(self): - with self.assertRaises(ValueError): - self._make_one("a", ["a"], "b") - - def test_constructor_w_single_part(self): - field_path = self._make_one("a") - self.assertEqual(field_path.parts, ("a",)) - - def test_constructor_w_multiple_parts(self): - field_path = self._make_one("a", "b", "c") - self.assertEqual(field_path.parts, ("a", "b", "c")) - - def test_ctor_w_invalid_chars_in_part(self): - invalid_parts = ("~", "*", "/", "[", "]", ".") - for invalid_part in invalid_parts: - field_path = self._make_one(invalid_part) - self.assertEqual(field_path.parts, (invalid_part,)) - - def test_ctor_w_double_dots(self): - field_path = self._make_one("a..b") - self.assertEqual(field_path.parts, ("a..b",)) - - def test_ctor_w_unicode(self): - field_path = self._make_one("一", "二", "三") - self.assertEqual(field_path.parts, ("一", "二", "三")) - - def test_from_api_repr_w_empty_string(self): - api_repr = "" - with self.assertRaises(ValueError): - self._get_target_class().from_api_repr(api_repr) - - def test_from_api_repr_w_empty_field_name(self): - api_repr = "a..b" - with self.assertRaises(ValueError): - self._get_target_class().from_api_repr(api_repr) - - def test_from_api_repr_w_invalid_chars(self): - invalid_parts = ("~", "*", "/", "[", "]", ".") - for invalid_part in invalid_parts: - with self.assertRaises(ValueError): - self._get_target_class().from_api_repr(invalid_part) - - def test_from_api_repr_w_ascii_single(self): - api_repr = "a" - field_path = self._get_target_class().from_api_repr(api_repr) - self.assertEqual(field_path.parts, ("a",)) - - def test_from_api_repr_w_ascii_dotted(self): - api_repr = "a.b.c" - field_path = self._get_target_class().from_api_repr(api_repr) - self.assertEqual(field_path.parts, ("a", "b", "c")) - - def 
test_from_api_repr_w_non_ascii_dotted_non_quoted(self): - api_repr = "a.一" - with self.assertRaises(ValueError): - self._get_target_class().from_api_repr(api_repr) - - def test_from_api_repr_w_non_ascii_dotted_quoted(self): - api_repr = "a.`一`" - field_path = self._get_target_class().from_api_repr(api_repr) - self.assertEqual(field_path.parts, ("a", "一")) - - def test_from_string_w_empty_string(self): - path_string = "" - with self.assertRaises(ValueError): - self._get_target_class().from_string(path_string) - - def test_from_string_w_empty_field_name(self): - path_string = "a..b" - with self.assertRaises(ValueError): - self._get_target_class().from_string(path_string) - - def test_from_string_w_leading_dot(self): - path_string = ".b.c" - with self.assertRaises(ValueError): - self._get_target_class().from_string(path_string) - - def test_from_string_w_trailing_dot(self): - path_string = "a.b." - with self.assertRaises(ValueError): - self._get_target_class().from_string(path_string) - - def test_from_string_w_leading_invalid_chars(self): - invalid_paths = ("~", "*", "/", "[", "]") - for invalid_path in invalid_paths: - field_path = self._get_target_class().from_string(invalid_path) - self.assertEqual(field_path.parts, (invalid_path,)) - - def test_from_string_w_embedded_invalid_chars(self): - invalid_paths = ("a~b", "x*y", "f/g", "h[j", "k]l") - for invalid_path in invalid_paths: - with self.assertRaises(ValueError): - self._get_target_class().from_string(invalid_path) - - def test_from_string_w_ascii_single(self): - path_string = "a" - field_path = self._get_target_class().from_string(path_string) - self.assertEqual(field_path.parts, ("a",)) - - def test_from_string_w_ascii_dotted(self): - path_string = "a.b.c" - field_path = self._get_target_class().from_string(path_string) - self.assertEqual(field_path.parts, ("a", "b", "c")) - - def test_from_string_w_non_ascii_dotted(self): - path_string = "a.一" - field_path = self._get_target_class().from_string(path_string) - 
self.assertEqual(field_path.parts, ("a", "一")) - - def test___hash___w_single_part(self): - field_path = self._make_one("a") - self.assertEqual(hash(field_path), hash("a")) - - def test___hash___w_multiple_parts(self): - field_path = self._make_one("a", "b") - self.assertEqual(hash(field_path), hash("a.b")) - - def test___hash___w_escaped_parts(self): - field_path = self._make_one("a", "3") - self.assertEqual(hash(field_path), hash("a.`3`")) - - def test___eq___w_matching_type(self): - field_path = self._make_one("a", "b") - string_path = self._get_target_class().from_string("a.b") - self.assertEqual(field_path, string_path) - - def test___eq___w_non_matching_type(self): - field_path = self._make_one("a", "c") - other = mock.Mock() - other.parts = "a", "b" - self.assertNotEqual(field_path, other) - - def test___lt___w_matching_type(self): - field_path = self._make_one("a", "b") - string_path = self._get_target_class().from_string("a.c") - self.assertTrue(field_path < string_path) - - def test___lt___w_non_matching_type(self): - field_path = self._make_one("a", "b") - other = object() - # Python 2 doesn't raise TypeError here, but Python3 does. 
- self.assertIs(field_path.__lt__(other), NotImplemented) - - def test___add__(self): - path1 = "a123", "b456" - path2 = "c789", "d012" - path3 = "c789.d012" - field_path1 = self._make_one(*path1) - field_path1_string = self._make_one(*path1) - field_path2 = self._make_one(*path2) - field_path1 += field_path2 - field_path1_string += path3 - field_path2 = field_path2 + self._make_one(*path1) - self.assertEqual(field_path1, self._make_one(*(path1 + path2))) - self.assertEqual(field_path2, self._make_one(*(path2 + path1))) - self.assertEqual(field_path1_string, field_path1) - self.assertNotEqual(field_path1, field_path2) - with self.assertRaises(TypeError): - field_path1 + 305 - - def test_to_api_repr_a(self): - parts = "a" - field_path = self._make_one(parts) - self.assertEqual(field_path.to_api_repr(), "a") - - def test_to_api_repr_backtick(self): - parts = "`" - field_path = self._make_one(parts) - self.assertEqual(field_path.to_api_repr(), r"`\``") - - def test_to_api_repr_dot(self): - parts = "." 
- field_path = self._make_one(parts) - self.assertEqual(field_path.to_api_repr(), "`.`") - - def test_to_api_repr_slash(self): - parts = "\\" - field_path = self._make_one(parts) - self.assertEqual(field_path.to_api_repr(), r"`\\`") - - def test_to_api_repr_double_slash(self): - parts = r"\\" - field_path = self._make_one(parts) - self.assertEqual(field_path.to_api_repr(), r"`\\\\`") - - def test_to_api_repr_underscore(self): - parts = "_33132" - field_path = self._make_one(parts) - self.assertEqual(field_path.to_api_repr(), "_33132") - - def test_to_api_repr_unicode_non_simple(self): - parts = "一" - field_path = self._make_one(parts) - self.assertEqual(field_path.to_api_repr(), "`一`") - - def test_to_api_repr_number_non_simple(self): - parts = "03" - field_path = self._make_one(parts) - self.assertEqual(field_path.to_api_repr(), "`03`") - - def test_to_api_repr_simple_with_dot(self): - field_path = self._make_one("a.b") - self.assertEqual(field_path.to_api_repr(), "`a.b`") - - def test_to_api_repr_non_simple_with_dot(self): - parts = "a.一" - field_path = self._make_one(parts) - self.assertEqual(field_path.to_api_repr(), "`a.一`") - - def test_to_api_repr_simple(self): - parts = "a0332432" - field_path = self._make_one(parts) - self.assertEqual(field_path.to_api_repr(), "a0332432") - - def test_to_api_repr_chain(self): - parts = "a", "`", "\\", "_3", "03", "a03", "\\\\", "a0332432", "一" - field_path = self._make_one(*parts) - self.assertEqual( - field_path.to_api_repr(), r"a.`\``.`\\`._3.`03`.a03.`\\\\`.a0332432.`一`" - ) - - def test_eq_or_parent_same(self): - field_path = self._make_one("a", "b") - other = self._make_one("a", "b") - self.assertTrue(field_path.eq_or_parent(other)) - - def test_eq_or_parent_prefix(self): - field_path = self._make_one("a", "b") - other = self._make_one("a", "b", "c") - self.assertTrue(field_path.eq_or_parent(other)) - self.assertTrue(other.eq_or_parent(field_path)) - - def test_eq_or_parent_no_prefix(self): - field_path = 
self._make_one("a", "b") - other = self._make_one("d", "e", "f") - self.assertFalse(field_path.eq_or_parent(other)) - self.assertFalse(other.eq_or_parent(field_path)) - - def test_lineage_empty(self): - field_path = self._make_one() - expected = set() - self.assertEqual(field_path.lineage(), expected) - - def test_lineage_single(self): - field_path = self._make_one("a") - expected = set() - self.assertEqual(field_path.lineage(), expected) - - def test_lineage_nested(self): - field_path = self._make_one("a", "b", "c") - expected = set([self._make_one("a"), self._make_one("a", "b")]) - self.assertEqual(field_path.lineage(), expected) - - def test_document_id(self): - parts = "__name__" - field_path = self._make_one(parts) - self.assertEqual(field_path.document_id(), parts) +def test_fieldpath_document_id(): + parts = "__name__" + field_path = _make_field_path(parts) + assert field_path.document_id() == parts diff --git a/tests/unit/v1/test_order.py b/tests/unit/v1/test_order.py index 90d99e563e6e9..3a2086c53d97e 100644 --- a/tests/unit/v1/test_order.py +++ b/tests/unit/v1/test_order.py @@ -14,227 +14,239 @@ # limitations under the License. import mock -import unittest - -from google.cloud.firestore_v1._helpers import encode_value, GeoPoint -from google.cloud.firestore_v1.order import Order -from google.cloud.firestore_v1.order import TypeOrder - -from google.cloud.firestore_v1.types import document - -from google.protobuf import timestamp_pb2 - - -class TestOrder(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.order import Order - - return Order - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def test_order(self): - # Constants used to represent min/max values of storage types. 
- int_max_value = 2 ** 31 - 1 - int_min_value = -(2 ** 31) - float_min_value = 1.175494351 ** -38 - float_nan = float("nan") - inf = float("inf") - - groups = [None] * 65 - - groups[0] = [nullValue()] - - groups[1] = [_boolean_value(False)] - groups[2] = [_boolean_value(True)] - - # numbers - groups[3] = [_double_value(float_nan), _double_value(float_nan)] - groups[4] = [_double_value(-inf)] - groups[5] = [_int_value(int_min_value - 1)] - groups[6] = [_int_value(int_min_value)] - groups[7] = [_double_value(-1.1)] - # Integers and Doubles order the same. - groups[8] = [_int_value(-1), _double_value(-1.0)] - groups[9] = [_double_value(-float_min_value)] - # zeros all compare the same. - groups[10] = [ - _int_value(0), - _double_value(-0.0), - _double_value(0.0), - _double_value(+0.0), - ] - groups[11] = [_double_value(float_min_value)] - groups[12] = [_int_value(1), _double_value(1.0)] - groups[13] = [_double_value(1.1)] - groups[14] = [_int_value(int_max_value)] - groups[15] = [_int_value(int_max_value + 1)] - groups[16] = [_double_value(inf)] - - groups[17] = [_timestamp_value(123, 0)] - groups[18] = [_timestamp_value(123, 123)] - groups[19] = [_timestamp_value(345, 0)] - - # strings - groups[20] = [_string_value("")] - groups[21] = [_string_value("\u0000\ud7ff\ue000\uffff")] - groups[22] = [_string_value("(╯°□°)╯︵ ┻━┻")] - groups[23] = [_string_value("a")] - groups[24] = [_string_value("abc def")] - # latin small letter e + combining acute accent + latin small letter b - groups[25] = [_string_value("e\u0301b")] - groups[26] = [_string_value("æ")] - # latin small letter e with acute accent + latin small letter a - groups[27] = [_string_value("\u00e9a")] - - # blobs - groups[28] = [_blob_value(b"")] - groups[29] = [_blob_value(b"\x00")] - groups[30] = [_blob_value(b"\x00\x01\x02\x03\x04")] - groups[31] = [_blob_value(b"\x00\x01\x02\x04\x03")] - groups[32] = [_blob_value(b"\x7f")] - - # resource names - groups[33] = 
[_reference_value("projects/p1/databases/d1/documents/c1/doc1")] - groups[34] = [_reference_value("projects/p1/databases/d1/documents/c1/doc2")] - groups[35] = [ - _reference_value("projects/p1/databases/d1/documents/c1/doc2/c2/doc1") - ] - groups[36] = [ - _reference_value("projects/p1/databases/d1/documents/c1/doc2/c2/doc2") - ] - groups[37] = [_reference_value("projects/p1/databases/d1/documents/c10/doc1")] - groups[38] = [_reference_value("projects/p1/databases/d1/documents/c2/doc1")] - groups[39] = [_reference_value("projects/p2/databases/d2/documents/c1/doc1")] - groups[40] = [_reference_value("projects/p2/databases/d2/documents/c1-/doc1")] - groups[41] = [_reference_value("projects/p2/databases/d3/documents/c1-/doc1")] - - # geo points - groups[42] = [_geoPoint_value(-90, -180)] - groups[43] = [_geoPoint_value(-90, 0)] - groups[44] = [_geoPoint_value(-90, 180)] - groups[45] = [_geoPoint_value(0, -180)] - groups[46] = [_geoPoint_value(0, 0)] - groups[47] = [_geoPoint_value(0, 180)] - groups[48] = [_geoPoint_value(1, -180)] - groups[49] = [_geoPoint_value(1, 0)] - groups[50] = [_geoPoint_value(1, 180)] - groups[51] = [_geoPoint_value(90, -180)] - groups[52] = [_geoPoint_value(90, 0)] - groups[53] = [_geoPoint_value(90, 180)] - - # arrays - groups[54] = [_array_value()] - groups[55] = [_array_value(["bar"])] - groups[56] = [_array_value(["foo"])] - groups[57] = [_array_value(["foo", 0])] - groups[58] = [_array_value(["foo", 1])] - groups[59] = [_array_value(["foo", "0"])] - - # objects - groups[60] = [_object_value({"bar": 0})] - groups[61] = [_object_value({"bar": 0, "foo": 1})] - groups[62] = [_object_value({"bar": 1})] - groups[63] = [_object_value({"bar": 2})] - groups[64] = [_object_value({"bar": "0"})] - - target = self._make_one() - - for i in range(len(groups)): - for left in groups[i]: - for j in range(len(groups)): - for right in groups[j]: - expected = Order._compare_to(i, j) - - self.assertEqual( - target.compare(left, right), - expected, - 
"comparing L->R {} ({}) to {} ({})".format( - i, left, j, right - ), - ) - - expected = Order._compare_to(j, i) - self.assertEqual( - target.compare(right, left), - expected, - "comparing R->L {} ({}) to {} ({})".format( - j, right, i, left - ), - ) - - def test_typeorder_type_failure(self): - target = self._make_one() - left = mock.Mock() - left.WhichOneof.return_value = "imaginary-type" - - with self.assertRaisesRegex(ValueError, "Could not detect value"): - target.compare(left, mock.Mock()) - - def test_failure_to_find_type(self): - target = self._make_one() - left = mock.Mock() - left.WhichOneof.return_value = "imaginary-type" - right = mock.Mock() - # Patch from value to get to the deep compare. Since left is a bad type - # expect this to fail with value error. - with mock.patch.object(TypeOrder, "from_value") as to: - to.value = None - with self.assertRaisesRegex(ValueError, "Unknown ``value_type``"): - target.compare(left, right) - - def test_compare_objects_different_keys(self): - left = _object_value({"foo": 0}) - right = _object_value({"bar": 0}) - - target = self._make_one() - target.compare(left, right) +import pytest + + +def _make_order(*args, **kwargs): + from google.cloud.firestore_v1.order import Order + + return Order(*args, **kwargs) + + +def test_order_compare_across_heterogenous_values(): + from google.cloud.firestore_v1.order import Order + + # Constants used to represent min/max values of storage types. 
+ int_max_value = 2 ** 31 - 1 + int_min_value = -(2 ** 31) + float_min_value = 1.175494351e-38 + float_nan = float("nan") + inf = float("inf") + + groups = [None] * 65 + + groups[0] = [nullValue()] + + groups[1] = [_boolean_value(False)] + groups[2] = [_boolean_value(True)] + + # numbers + groups[3] = [_double_value(float_nan), _double_value(float_nan)] + groups[4] = [_double_value(-inf)] + groups[5] = [_int_value(int_min_value - 1)] + groups[6] = [_int_value(int_min_value)] + groups[7] = [_double_value(-1.1)] + # Integers and Doubles order the same. + groups[8] = [_int_value(-1), _double_value(-1.0)] + groups[9] = [_double_value(-float_min_value)] + # zeros all compare the same. + groups[10] = [ + _int_value(0), + _double_value(-0.0), + _double_value(0.0), + _double_value(+0.0), + ] + groups[11] = [_double_value(float_min_value)] + groups[12] = [_int_value(1), _double_value(1.0)] + groups[13] = [_double_value(1.1)] + groups[14] = [_int_value(int_max_value)] + groups[15] = [_int_value(int_max_value + 1)] + groups[16] = [_double_value(inf)] + + groups[17] = [_timestamp_value(123, 0)] + groups[18] = [_timestamp_value(123, 123)] + groups[19] = [_timestamp_value(345, 0)] + + # strings + groups[20] = [_string_value("")] + groups[21] = [_string_value("\u0000\ud7ff\ue000\uffff")] + groups[22] = [_string_value("(╯°□°)╯︵ ┻━┻")] + groups[23] = [_string_value("a")] + groups[24] = [_string_value("abc def")] + # latin small letter e + combining acute accent + latin small letter b + groups[25] = [_string_value("e\u0301b")] + groups[26] = [_string_value("æ")] + # latin small letter e with acute accent + latin small letter a + groups[27] = [_string_value("\u00e9a")] + + # blobs + groups[28] = [_blob_value(b"")] + groups[29] = [_blob_value(b"\x00")] + groups[30] = [_blob_value(b"\x00\x01\x02\x03\x04")] + groups[31] = [_blob_value(b"\x00\x01\x02\x04\x03")] + groups[32] = [_blob_value(b"\x7f")] + + # resource names + groups[33] = 
[_reference_value("projects/p1/databases/d1/documents/c1/doc1")] + groups[34] = [_reference_value("projects/p1/databases/d1/documents/c1/doc2")] + groups[35] = [ + _reference_value("projects/p1/databases/d1/documents/c1/doc2/c2/doc1") + ] + groups[36] = [ + _reference_value("projects/p1/databases/d1/documents/c1/doc2/c2/doc2") + ] + groups[37] = [_reference_value("projects/p1/databases/d1/documents/c10/doc1")] + groups[38] = [_reference_value("projects/p1/databases/d1/documents/c2/doc1")] + groups[39] = [_reference_value("projects/p2/databases/d2/documents/c1/doc1")] + groups[40] = [_reference_value("projects/p2/databases/d2/documents/c1-/doc1")] + groups[41] = [_reference_value("projects/p2/databases/d3/documents/c1-/doc1")] + + # geo points + groups[42] = [_geoPoint_value(-90, -180)] + groups[43] = [_geoPoint_value(-90, 0)] + groups[44] = [_geoPoint_value(-90, 180)] + groups[45] = [_geoPoint_value(0, -180)] + groups[46] = [_geoPoint_value(0, 0)] + groups[47] = [_geoPoint_value(0, 180)] + groups[48] = [_geoPoint_value(1, -180)] + groups[49] = [_geoPoint_value(1, 0)] + groups[50] = [_geoPoint_value(1, 180)] + groups[51] = [_geoPoint_value(90, -180)] + groups[52] = [_geoPoint_value(90, 0)] + groups[53] = [_geoPoint_value(90, 180)] + + # arrays + groups[54] = [_array_value()] + groups[55] = [_array_value(["bar"])] + groups[56] = [_array_value(["foo"])] + groups[57] = [_array_value(["foo", 0])] + groups[58] = [_array_value(["foo", 1])] + groups[59] = [_array_value(["foo", "0"])] + + # objects + groups[60] = [_object_value({"bar": 0})] + groups[61] = [_object_value({"bar": 0, "foo": 1})] + groups[62] = [_object_value({"bar": 1})] + groups[63] = [_object_value({"bar": 2})] + groups[64] = [_object_value({"bar": "0"})] + + target = _make_order() + + for i in range(len(groups)): + for left in groups[i]: + for j in range(len(groups)): + for right in groups[j]: + + expected = Order._compare_to(i, j) + assert target.compare(left, right) == expected + + expected = 
Order._compare_to(j, i) + assert target.compare(right, left) == expected + + +def test_order_compare_w_typeorder_type_failure(): + target = _make_order() + left = mock.Mock() + left.WhichOneof.return_value = "imaginary-type" + + with pytest.raises(ValueError) as exc_info: + target.compare(left, mock.Mock()) + + (message,) = exc_info.value.args + assert message.startswith("Could not detect value") + + +def test_order_compare_w_failure_to_find_type(): + from google.cloud.firestore_v1.order import TypeOrder + + target = _make_order() + left = mock.Mock() + left.WhichOneof.return_value = "imaginary-type" + right = mock.Mock() + # Patch from value to get to the deep compare. Since left is a bad type + # expect this to fail with value error. + with mock.patch.object(TypeOrder, "from_value") as to: + to.value = None + with pytest.raises(ValueError) as exc_info: + target.compare(left, right) + + (message,) = exc_info.value.args + assert message.startswith("Unknown ``value_type``") + + +def test_order_compare_w_objects_different_keys(): + left = _object_value({"foo": 0}) + right = _object_value({"bar": 0}) + + target = _make_order() + target.compare(left, right) def _boolean_value(b): + from google.cloud.firestore_v1._helpers import encode_value + return encode_value(b) def _double_value(d): + from google.cloud.firestore_v1._helpers import encode_value + return encode_value(d) def _int_value(value): + from google.cloud.firestore_v1._helpers import encode_value + return encode_value(value) def _string_value(s): + from google.cloud.firestore_v1._helpers import encode_value + return encode_value(s) def _reference_value(r): + from google.cloud.firestore_v1.types import document + return document.Value(reference_value=r) def _blob_value(b): + from google.cloud.firestore_v1._helpers import encode_value + return encode_value(b) def nullValue(): + from google.cloud.firestore_v1._helpers import encode_value + return encode_value(None) def _timestamp_value(seconds, nanos): + from 
google.cloud.firestore_v1.types import document + from google.protobuf import timestamp_pb2 + return document.Value( timestamp_value=timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) ) def _geoPoint_value(latitude, longitude): + from google.cloud.firestore_v1._helpers import encode_value + from google.cloud.firestore_v1._helpers import GeoPoint + return encode_value(GeoPoint(latitude, longitude)) def _array_value(values=[]): + from google.cloud.firestore_v1._helpers import encode_value + return encode_value(values) def _object_value(keysAndValues): + from google.cloud.firestore_v1._helpers import encode_value + return encode_value(keysAndValues) diff --git a/tests/unit/v1/test_query.py b/tests/unit/v1/test_query.py index ef99338eca1a3..17b82d3edea41 100644 --- a/tests/unit/v1/test_query.py +++ b/tests/unit/v1/test_query.py @@ -13,762 +13,778 @@ # limitations under the License. import types -import unittest import mock import pytest -from google.api_core import gapic_v1 -from google.cloud.firestore_v1.types.document import Document -from google.cloud.firestore_v1.types.firestore import RunQueryResponse from tests.unit.v1.test_base_query import _make_credentials from tests.unit.v1.test_base_query import _make_cursor_pb from tests.unit.v1.test_base_query import _make_query_response -class TestQuery(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.query import Query +def _make_query(*args, **kwargs): + from google.cloud.firestore_v1.query import Query - return Query + return Query(*args, **kwargs) - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - def test_constructor(self): - query = self._make_one(mock.sentinel.parent) - self.assertIs(query._parent, mock.sentinel.parent) - self.assertIsNone(query._projection) - self.assertEqual(query._field_filters, ()) - self.assertEqual(query._orders, ()) - self.assertIsNone(query._limit) - self.assertIsNone(query._offset) - 
self.assertIsNone(query._start_at) - self.assertIsNone(query._end_at) - self.assertFalse(query._all_descendants) +def test_query_constructor(): + query = _make_query(mock.sentinel.parent) + assert query._parent is mock.sentinel.parent + assert query._projection is None + assert query._field_filters == () + assert query._orders == () + assert query._limit is None + assert query._offset is None + assert query._start_at is None + assert query._end_at is None + assert not query._all_descendants - def _get_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - # Create a minimal fake GAPIC. - firestore_api = mock.Mock(spec=["run_query"]) +def _query_get_helper(retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) - # Make a **real** collection reference as parent. - parent = client.collection("dee") + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api - # Add a dummy response to the minimal fake GAPIC. - _, expected_prefix = parent._parent_info() - name = "{}/sleep".format(expected_prefix) - data = {"snooze": 10} + # Make a **real** collection reference as parent. + parent = client.collection("dee") - response_pb = _make_query_response(name=name, data=data) - firestore_api.run_query.return_value = iter([response_pb]) - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + # Add a dummy response to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/sleep".format(expected_prefix) + data = {"snooze": 10} - # Execute the query and check the response. 
- query = self._make_one(parent) - returned = query.get(**kwargs) + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - self.assertIsInstance(returned, list) - self.assertEqual(len(returned), 1) + # Execute the query and check the response. + query = _make_query(parent) + returned = query.get(**kwargs) - snapshot = returned[0] - self.assertEqual(snapshot.reference._path, ("dee", "sleep")) - self.assertEqual(snapshot.to_dict(), data) + assert isinstance(returned, list) + assert len(returned) == 1 - # Verify the mock call. - parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( + snapshot = returned[0] + assert snapshot.reference._path == ("dee", "sleep") + assert snapshot.to_dict() == data + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +def test_query_get(): + _query_get_helper() + + +def test_query_get_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + _query_get_helper(retry=retry, timeout=timeout) + + +def test_query_get_limit_to_last(): + from google.cloud import firestore + from google.cloud.firestore_v1.base_query import _enum_from_direction + + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + + # Add a dummy response to the minimal fake GAPIC. 
+ _, expected_prefix = parent._parent_info() + name = "{}/sleep".format(expected_prefix) + data = {"snooze": 10} + data2 = {"snooze": 20} + + response_pb = _make_query_response(name=name, data=data) + response_pb2 = _make_query_response(name=name, data=data2) + + firestore_api.run_query.return_value = iter([response_pb2, response_pb]) + + # Execute the query and check the response. + query = _make_query(parent) + query = query.order_by( + "snooze", direction=firestore.Query.DESCENDING + ).limit_to_last(2) + returned = query.get() + + assert isinstance(returned, list) + assert query._orders[0].direction == _enum_from_direction(firestore.Query.ASCENDING) + assert len(returned) == 2 + + snapshot = returned[0] + assert snapshot.reference._path == ("dee", "sleep") + assert snapshot.to_dict() == data + + snapshot2 = returned[1] + assert snapshot2.reference._path == ("dee", "sleep") + assert snapshot2.to_dict() == data2 + parent_path, _ = parent._parent_info() + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +def test_query_chunkify_w_empty(): + client = _make_client() + firestore_api = mock.Mock(spec=["run_query"]) + firestore_api.run_query.return_value = iter([]) + client._firestore_api_internal = firestore_api + query = client.collection("asdf")._query() + + chunks = list(query._chunkify(10)) + + assert chunks == [[]] + + +def test_query_chunkify_w_chunksize_lt_limit(): + from google.cloud.firestore_v1.types.document import Document + from google.cloud.firestore_v1.types.firestore import RunQueryResponse + + client = _make_client() + firestore_api = mock.Mock(spec=["run_query"]) + doc_ids = [ + f"projects/project-project/databases/(default)/documents/asdf/{index}" + for index in range(5) + ] + responses1 = [ + RunQueryResponse(document=Document(name=doc_id),) for doc_id in doc_ids[:2] + ] + responses2 = [ + 
RunQueryResponse(document=Document(name=doc_id),) for doc_id in doc_ids[2:4] + ] + responses3 = [ + RunQueryResponse(document=Document(name=doc_id),) for doc_id in doc_ids[4:] + ] + firestore_api.run_query.side_effect = [ + iter(responses1), + iter(responses2), + iter(responses3), + ] + client._firestore_api_internal = firestore_api + query = client.collection("asdf")._query() + + chunks = list(query._chunkify(2)) + + assert len(chunks) == 3 + expected_ids = [str(index) for index in range(5)] + assert [snapshot.id for snapshot in chunks[0]] == expected_ids[:2] + assert [snapshot.id for snapshot in chunks[1]] == expected_ids[2:4] + assert [snapshot.id for snapshot in chunks[2]] == expected_ids[4:] + + +def test_query_chunkify_w_chunksize_gt_limit(): + from google.cloud.firestore_v1.types.document import Document + from google.cloud.firestore_v1.types.firestore import RunQueryResponse + + client = _make_client() + firestore_api = mock.Mock(spec=["run_query"]) + doc_ids = [ + f"projects/project-project/databases/(default)/documents/asdf/{index}" + for index in range(5) + ] + responses = [ + RunQueryResponse(document=Document(name=doc_id),) for doc_id in doc_ids + ] + firestore_api.run_query.return_value = iter(responses) + client._firestore_api_internal = firestore_api + query = client.collection("asdf")._query() + + chunks = list(query.limit(5)._chunkify(10)) + + assert len(chunks) == 1 + chunk_ids = [snapshot.id for snapshot in chunks[0]] + expected_ids = [str(index) for index in range(5)] + assert chunk_ids == expected_ids + + +def _query_stream_helper(retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. 
+ parent = client.collection("dee") + + # Add a dummy response to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/sleep".format(expected_prefix) + data = {"snooze": 10} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Execute the query and check the response. + query = _make_query(parent) + + get_response = query.stream(**kwargs) + + assert isinstance(get_response, types.GeneratorType) + returned = list(get_response) + assert len(returned) == 1 + snapshot = returned[0] + assert snapshot.reference._path == ("dee", "sleep") + assert snapshot.to_dict() == data + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + +def test_query_stream_simple(): + _query_stream_helper() + + +def test_query_stream_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + _query_stream_helper(retry=retry, timeout=timeout) + + +def test_query_stream_with_limit_to_last(): + # Attach the fake GAPIC to a real client. + client = _make_client() + # Make a **real** collection reference as parent. + parent = client.collection("dee") + # Execute the query and check the response. + query = _make_query(parent) + query = query.limit_to_last(2) + + stream_response = query.stream() + + with pytest.raises(ValueError): + list(stream_response) + + +def test_query_stream_with_transaction(): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Create a real-ish transaction for this client. 
+ transaction = client.transaction() + txn_id = b"\x00\x00\x01-work-\xf2" + transaction._id = txn_id + + # Make a **real** collection reference as parent. + parent = client.collection("declaration") + + # Add a dummy response to the minimal fake GAPIC. + parent_path, expected_prefix = parent._parent_info() + name = "{}/burger".format(expected_prefix) + data = {"lettuce": b"\xee\x87"} + response_pb = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb]) + + # Execute the query and check the response. + query = _make_query(parent) + get_response = query.stream(transaction=transaction) + assert isinstance(get_response, types.GeneratorType) + returned = list(get_response) + assert len(returned) == 1 + snapshot = returned[0] + assert snapshot.reference._path == ("declaration", "burger") + assert snapshot.to_dict() == data + + # Verify the mock call. + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + + +def test_query_stream_no_results(): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["run_query"]) + empty_response = _make_query_response() + run_query_response = iter([empty_response]) + firestore_api.run_query.return_value = run_query_response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dah", "dah", "dum") + query = _make_query(parent) + + get_response = query.stream() + assert isinstance(get_response, types.GeneratorType) + assert list(get_response) == [] + + # Verify the mock call. 
+ parent_path, _ = parent._parent_info() + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +def test_query_stream_second_response_in_empty_stream(): + # Create a minimal fake GAPIC with a dummy response. + firestore_api = mock.Mock(spec=["run_query"]) + empty_response1 = _make_query_response() + empty_response2 = _make_query_response() + run_query_response = iter([empty_response1, empty_response2]) + firestore_api.run_query.return_value = run_query_response + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dah", "dah", "dum") + query = _make_query(parent) + + get_response = query.stream() + assert isinstance(get_response, types.GeneratorType) + assert list(get_response) == [] + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +def test_query_stream_with_skipped_results(): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("talk", "and", "chew-gum") + + # Add two dummy responses to the minimal fake GAPIC. 
+ _, expected_prefix = parent._parent_info() + response_pb1 = _make_query_response(skipped_results=1) + name = "{}/clock".format(expected_prefix) + data = {"noon": 12, "nested": {"bird": 10.5}} + response_pb2 = _make_query_response(name=name, data=data) + firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) + + # Execute the query and check the response. + query = _make_query(parent) + get_response = query.stream() + assert isinstance(get_response, types.GeneratorType) + returned = list(get_response) + assert len(returned) == 1 + snapshot = returned[0] + assert snapshot.reference._path == ("talk", "and", "chew-gum", "clock") + assert snapshot.to_dict() == data + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +def test_query_stream_empty_after_first_response(): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("charles") + + # Add two dummy responses to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/bark".format(expected_prefix) + data = {"lee": "hoop"} + response_pb1 = _make_query_response(name=name, data=data) + response_pb2 = _make_query_response() + firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) + + # Execute the query and check the response. 
+ query = _make_query(parent) + get_response = query.stream() + assert isinstance(get_response, types.GeneratorType) + returned = list(get_response) + assert len(returned) == 1 + snapshot = returned[0] + assert snapshot.reference._path == ("charles", "bark") + assert snapshot.to_dict() == data + + # Verify the mock call. + parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +def test_query_stream_w_collection_group(): + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query"]) + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("charles") + other = client.collection("dora") + + # Add two dummy responses to the minimal fake GAPIC. + _, other_prefix = other._parent_info() + name = "{}/bark".format(other_prefix) + data = {"lee": "hoop"} + response_pb1 = _make_query_response(name=name, data=data) + response_pb2 = _make_query_response() + firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) + + # Execute the query and check the response. + query = _make_query(parent) + query._all_descendants = True + get_response = query.stream() + assert isinstance(get_response, types.GeneratorType) + returned = list(get_response) + assert len(returned) == 1 + snapshot = returned[0] + to_match = other.document("bark") + assert snapshot.reference._document_path == to_match._document_path + assert snapshot.to_dict() == data + + # Verify the mock call. 
+ parent_path, _ = parent._parent_info() + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": None, + }, + metadata=client._rpc_metadata, + ) + + +# Marker: avoids needing to import 'gapic_v1' at module scope. +_not_passed = object() + + +def _query_stream_w_retriable_exc_helper( + retry=_not_passed, timeout=None, transaction=None, expect_retry=True, +): + from google.api_core import exceptions + from google.api_core import gapic_v1 + from google.cloud.firestore_v1 import _helpers + + if retry is _not_passed: + retry = gapic_v1.method.DEFAULT + + if transaction is not None: + expect_retry = False + + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["run_query", "_transport"]) + transport = firestore_api._transport = mock.Mock(spec=["run_query"]) + stub = transport.run_query = mock.create_autospec(gapic_v1.method._GapicCallable) + stub._retry = mock.Mock(spec=["_predicate"]) + stub._predicate = lambda exc: True # pragma: NO COVER + + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + + # Add a dummy response to the minimal fake GAPIC. + _, expected_prefix = parent._parent_info() + name = "{}/sleep".format(expected_prefix) + data = {"snooze": 10} + response_pb = _make_query_response(name=name, data=data) + retriable_exc = exceptions.ServiceUnavailable("testing") + + def _stream_w_exception(*_args, **_kw): + yield response_pb + raise retriable_exc + + firestore_api.run_query.side_effect = [_stream_w_exception(), iter([])] + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + # Execute the query and check the response. 
+ query = _make_query(parent) + + get_response = query.stream(transaction=transaction, **kwargs) + + assert isinstance(get_response, types.GeneratorType) + if expect_retry: + returned = list(get_response) + else: + returned = [next(get_response)] + with pytest.raises(exceptions.ServiceUnavailable): + next(get_response) + + assert len(returned) == 1 + snapshot = returned[0] + assert snapshot.reference._path == ("dee", "sleep") + assert snapshot.to_dict() == data + + # Verify the mock call. + parent_path, _ = parent._parent_info() + calls = firestore_api.run_query.call_args_list + + if expect_retry: + assert len(calls) == 2 + else: + assert len(calls) == 1 + + if transaction is not None: + expected_transaction_id = transaction.id + else: + expected_transaction_id = None + + assert calls[0] == mock.call( + request={ + "parent": parent_path, + "structured_query": query._to_protobuf(), + "transaction": expected_transaction_id, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + if expect_retry: + new_query = query.start_after(snapshot) + assert calls[1] == mock.call( request={ "parent": parent_path, - "structured_query": query._to_protobuf(), + "structured_query": new_query._to_protobuf(), "transaction": None, }, metadata=client._rpc_metadata, **kwargs, ) - def test_get(self): - self._get_helper() - def test_get_w_retry_timeout(self): - from google.api_core.retry import Retry +def test_query_stream_w_retriable_exc_w_defaults(): + _query_stream_w_retriable_exc_helper() - retry = Retry(predicate=object()) - timeout = 123.0 - self._get_helper(retry=retry, timeout=timeout) - def test_get_limit_to_last(self): - from google.cloud import firestore - from google.cloud.firestore_v1.base_query import _enum_from_direction +def test_query_stream_w_retriable_exc_w_retry(): + retry = mock.Mock(spec=["_predicate"]) + retry._predicate = lambda exc: False + _query_stream_w_retriable_exc_helper(retry=retry, expect_retry=False) - # Create a minimal fake GAPIC. 
- firestore_api = mock.Mock(spec=["run_query"]) - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api +def test_query_stream_w_retriable_exc_w_transaction(): + from google.cloud.firestore_v1 import transaction - # Make a **real** collection reference as parent. - parent = client.collection("dee") + txn = transaction.Transaction(client=mock.Mock(spec=[])) + txn._id = b"DEADBEEF" + _query_stream_w_retriable_exc_helper(transaction=txn) - # Add a dummy response to the minimal fake GAPIC. - _, expected_prefix = parent._parent_info() - name = "{}/sleep".format(expected_prefix) - data = {"snooze": 10} - data2 = {"snooze": 20} - response_pb = _make_query_response(name=name, data=data) - response_pb2 = _make_query_response(name=name, data=data2) +@mock.patch("google.cloud.firestore_v1.query.Watch", autospec=True) +def test_query_on_snapshot(watch): + query = _make_query(mock.sentinel.parent) + query.on_snapshot(None) + watch.for_query.assert_called_once() - firestore_api.run_query.return_value = iter([response_pb2, response_pb]) - # Execute the query and check the response. 
- query = self._make_one(parent) - query = query.order_by( - "snooze", direction=firestore.Query.DESCENDING - ).limit_to_last(2) - returned = query.get() +def _make_collection_group(*args, **kwargs): + from google.cloud.firestore_v1.query import CollectionGroup - self.assertIsInstance(returned, list) - self.assertEqual( - query._orders[0].direction, _enum_from_direction(firestore.Query.ASCENDING) - ) - self.assertEqual(len(returned), 2) + return CollectionGroup(*args, **kwargs) - snapshot = returned[0] - self.assertEqual(snapshot.reference._path, ("dee", "sleep")) - self.assertEqual(snapshot.to_dict(), data) - snapshot2 = returned[1] - self.assertEqual(snapshot2.reference._path, ("dee", "sleep")) - self.assertEqual(snapshot2.to_dict(), data2) +def test_collection_group_constructor(): + query = _make_collection_group(mock.sentinel.parent) + assert query._parent is mock.sentinel.parent + assert query._projection is None + assert query._field_filters == () + assert query._orders == () + assert query._limit is None + assert query._offset is None + assert query._start_at is None + assert query._end_at is None + assert query._all_descendants - # Verify the mock call. 
- parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - ) - def test_chunkify_w_empty(self): - client = _make_client() - firestore_api = mock.Mock(spec=["run_query"]) - firestore_api.run_query.return_value = iter([]) - client._firestore_api_internal = firestore_api - query = client.collection("asdf")._query() - - chunks = list(query._chunkify(10)) - - assert chunks == [[]] - - def test_chunkify_w_chunksize_lt_limit(self): - client = _make_client() - firestore_api = mock.Mock(spec=["run_query"]) - doc_ids = [ - f"projects/project-project/databases/(default)/documents/asdf/{index}" - for index in range(5) - ] - responses1 = [ - RunQueryResponse(document=Document(name=doc_id),) for doc_id in doc_ids[:2] - ] - responses2 = [ - RunQueryResponse(document=Document(name=doc_id),) for doc_id in doc_ids[2:4] - ] - responses3 = [ - RunQueryResponse(document=Document(name=doc_id),) for doc_id in doc_ids[4:] - ] - firestore_api.run_query.side_effect = [ - iter(responses1), - iter(responses2), - iter(responses3), - ] - client._firestore_api_internal = firestore_api - query = client.collection("asdf")._query() - - chunks = list(query._chunkify(2)) - - self.assertEqual(len(chunks), 3) - expected_ids = [str(index) for index in range(5)] - self.assertEqual([snapshot.id for snapshot in chunks[0]], expected_ids[:2]) - self.assertEqual([snapshot.id for snapshot in chunks[1]], expected_ids[2:4]) - self.assertEqual([snapshot.id for snapshot in chunks[2]], expected_ids[4:]) - - def test_chunkify_w_chunksize_gt_limit(self): - client = _make_client() - firestore_api = mock.Mock(spec=["run_query"]) - doc_ids = [ - f"projects/project-project/databases/(default)/documents/asdf/{index}" - for index in range(5) - ] - responses = [ - RunQueryResponse(document=Document(name=doc_id),) for doc_id in doc_ids - ] - 
firestore_api.run_query.return_value = iter(responses) - client._firestore_api_internal = firestore_api - query = client.collection("asdf")._query() - - chunks = list(query.limit(5)._chunkify(10)) - - self.assertEqual(len(chunks), 1) - self.assertEqual( - [snapshot.id for snapshot in chunks[0]], [str(index) for index in range(5)] - ) +def test_collection_group_constructor_all_descendents_is_false(): + with pytest.raises(ValueError): + _make_collection_group(mock.sentinel.parent, all_descendants=False) - def _stream_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - # Create a minimal fake GAPIC. - firestore_api = mock.Mock(spec=["run_query"]) +def _collection_group_get_partitions_helper(retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api + # Create a minimal fake GAPIC. + firestore_api = mock.Mock(spec=["partition_query"]) - # Make a **real** collection reference as parent. - parent = client.collection("dee") + # Attach the fake GAPIC to a real client. + client = _make_client() + client._firestore_api_internal = firestore_api - # Add a dummy response to the minimal fake GAPIC. - _, expected_prefix = parent._parent_info() - name = "{}/sleep".format(expected_prefix) - data = {"snooze": 10} - response_pb = _make_query_response(name=name, data=data) - firestore_api.run_query.return_value = iter([response_pb]) - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + # Make a **real** collection reference as parent. + parent = client.collection("charles") - # Execute the query and check the response. - query = self._make_one(parent) + # Make two **real** document references to use as cursors + document1 = parent.document("one") + document2 = parent.document("two") - get_response = query.stream(**kwargs) + # Add cursor pb's to the minimal fake GAPIC. 
+ cursor_pb1 = _make_cursor_pb(([document1], False)) + cursor_pb2 = _make_cursor_pb(([document2], False)) + firestore_api.partition_query.return_value = iter([cursor_pb1, cursor_pb2]) + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - self.assertIsInstance(get_response, types.GeneratorType) - returned = list(get_response) - self.assertEqual(len(returned), 1) - snapshot = returned[0] - self.assertEqual(snapshot.reference._path, ("dee", "sleep")) - self.assertEqual(snapshot.to_dict(), data) - - # Verify the mock call. - parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ) + # Execute the query and check the response. + query = _make_collection_group(parent) - def test_stream_simple(self): - self._stream_helper() - - def test_stream_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - self._stream_helper(retry=retry, timeout=timeout) - - def test_stream_with_limit_to_last(self): - # Attach the fake GAPIC to a real client. - client = _make_client() - # Make a **real** collection reference as parent. - parent = client.collection("dee") - # Execute the query and check the response. - query = self._make_one(parent) - query = query.limit_to_last(2) - - stream_response = query.stream() - - with self.assertRaises(ValueError): - list(stream_response) - - def test_stream_with_transaction(self): - # Create a minimal fake GAPIC. - firestore_api = mock.Mock(spec=["run_query"]) - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Create a real-ish transaction for this client. - transaction = client.transaction() - txn_id = b"\x00\x00\x01-work-\xf2" - transaction._id = txn_id - - # Make a **real** collection reference as parent. 
- parent = client.collection("declaration") - - # Add a dummy response to the minimal fake GAPIC. - parent_path, expected_prefix = parent._parent_info() - name = "{}/burger".format(expected_prefix) - data = {"lettuce": b"\xee\x87"} - response_pb = _make_query_response(name=name, data=data) - firestore_api.run_query.return_value = iter([response_pb]) - - # Execute the query and check the response. - query = self._make_one(parent) - get_response = query.stream(transaction=transaction) - self.assertIsInstance(get_response, types.GeneratorType) - returned = list(get_response) - self.assertEqual(len(returned), 1) - snapshot = returned[0] - self.assertEqual(snapshot.reference._path, ("declaration", "burger")) - self.assertEqual(snapshot.to_dict(), data) + get_response = query.get_partitions(2, **kwargs) - # Verify the mock call. - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) + assert isinstance(get_response, types.GeneratorType) + returned = list(get_response) + assert len(returned) == 3 - def test_stream_no_results(self): - # Create a minimal fake GAPIC with a dummy response. - firestore_api = mock.Mock(spec=["run_query"]) - empty_response = _make_query_response() - run_query_response = iter([empty_response]) - firestore_api.run_query.return_value = run_query_response + # Verify the mock call. + parent_path, _ = parent._parent_info() + partition_query = _make_collection_group( + parent, orders=(query._make_order("__name__", query.ASCENDING),), + ) + firestore_api.partition_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": partition_query._to_protobuf(), + "partition_count": 2, + }, + metadata=client._rpc_metadata, + **kwargs, + ) - # Attach the fake GAPIC to a real client. 
- client = _make_client() - client._firestore_api_internal = firestore_api - # Make a **real** collection reference as parent. - parent = client.collection("dah", "dah", "dum") - query = self._make_one(parent) +def test_collection_group_get_partitions(): + _collection_group_get_partitions_helper() - get_response = query.stream() - self.assertIsInstance(get_response, types.GeneratorType) - self.assertEqual(list(get_response), []) - # Verify the mock call. - parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - ) +def test_collection_group_get_partitions_w_retry_timeout(): + from google.api_core.retry import Retry - def test_stream_second_response_in_empty_stream(self): - # Create a minimal fake GAPIC with a dummy response. - firestore_api = mock.Mock(spec=["run_query"]) - empty_response1 = _make_query_response() - empty_response2 = _make_query_response() - run_query_response = iter([empty_response1, empty_response2]) - firestore_api.run_query.return_value = run_query_response - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("dah", "dah", "dum") - query = self._make_one(parent) - - get_response = query.stream() - self.assertIsInstance(get_response, types.GeneratorType) - self.assertEqual(list(get_response), []) - - # Verify the mock call. 
- parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - ) + retry = Retry(predicate=object()) + timeout = 123.0 + _collection_group_get_partitions_helper(retry=retry, timeout=timeout) - def test_stream_with_skipped_results(self): - # Create a minimal fake GAPIC. - firestore_api = mock.Mock(spec=["run_query"]) - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("talk", "and", "chew-gum") - - # Add two dummy responses to the minimal fake GAPIC. - _, expected_prefix = parent._parent_info() - response_pb1 = _make_query_response(skipped_results=1) - name = "{}/clock".format(expected_prefix) - data = {"noon": 12, "nested": {"bird": 10.5}} - response_pb2 = _make_query_response(name=name, data=data) - firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) - - # Execute the query and check the response. - query = self._make_one(parent) - get_response = query.stream() - self.assertIsInstance(get_response, types.GeneratorType) - returned = list(get_response) - self.assertEqual(len(returned), 1) - snapshot = returned[0] - self.assertEqual(snapshot.reference._path, ("talk", "and", "chew-gum", "clock")) - self.assertEqual(snapshot.to_dict(), data) - - # Verify the mock call. - parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - ) - def test_stream_empty_after_first_response(self): - # Create a minimal fake GAPIC. - firestore_api = mock.Mock(spec=["run_query"]) - - # Attach the fake GAPIC to a real client. 
- client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("charles") - - # Add two dummy responses to the minimal fake GAPIC. - _, expected_prefix = parent._parent_info() - name = "{}/bark".format(expected_prefix) - data = {"lee": "hoop"} - response_pb1 = _make_query_response(name=name, data=data) - response_pb2 = _make_query_response() - firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) - - # Execute the query and check the response. - query = self._make_one(parent) - get_response = query.stream() - self.assertIsInstance(get_response, types.GeneratorType) - returned = list(get_response) - self.assertEqual(len(returned), 1) - snapshot = returned[0] - self.assertEqual(snapshot.reference._path, ("charles", "bark")) - self.assertEqual(snapshot.to_dict(), data) - - # Verify the mock call. - parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - ) +def test_collection_group_get_partitions_w_filter(): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") - def test_stream_w_collection_group(self): - # Create a minimal fake GAPIC. - firestore_api = mock.Mock(spec=["run_query"]) - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("charles") - other = client.collection("dora") - - # Add two dummy responses to the minimal fake GAPIC. 
- _, other_prefix = other._parent_info() - name = "{}/bark".format(other_prefix) - data = {"lee": "hoop"} - response_pb1 = _make_query_response(name=name, data=data) - response_pb2 = _make_query_response() - firestore_api.run_query.return_value = iter([response_pb1, response_pb2]) - - # Execute the query and check the response. - query = self._make_one(parent) - query._all_descendants = True - get_response = query.stream() - self.assertIsInstance(get_response, types.GeneratorType) - returned = list(get_response) - self.assertEqual(len(returned), 1) - snapshot = returned[0] - to_match = other.document("bark") - self.assertEqual(snapshot.reference._document_path, to_match._document_path) - self.assertEqual(snapshot.to_dict(), data) - - # Verify the mock call. - parent_path, _ = parent._parent_info() - firestore_api.run_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - ) + # Make a query that fails to partition + query = _make_collection_group(parent).where("foo", "==", "bar") + with pytest.raises(ValueError): + list(query.get_partitions(2)) - def _stream_w_retriable_exc_helper( - self, - retry=gapic_v1.method.DEFAULT, - timeout=None, - transaction=None, - expect_retry=True, - ): - from google.api_core import exceptions - from google.cloud.firestore_v1 import _helpers - - if transaction is not None: - expect_retry = False - - # Create a minimal fake GAPIC. - firestore_api = mock.Mock(spec=["run_query", "_transport"]) - transport = firestore_api._transport = mock.Mock(spec=["run_query"]) - stub = transport.run_query = mock.create_autospec( - gapic_v1.method._GapicCallable - ) - stub._retry = mock.Mock(spec=["_predicate"]) - stub._predicate = lambda exc: True # pragma: NO COVER - - # Attach the fake GAPIC to a real client. 
- client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("dee") - - # Add a dummy response to the minimal fake GAPIC. - _, expected_prefix = parent._parent_info() - name = "{}/sleep".format(expected_prefix) - data = {"snooze": 10} - response_pb = _make_query_response(name=name, data=data) - retriable_exc = exceptions.ServiceUnavailable("testing") - - def _stream_w_exception(*_args, **_kw): - yield response_pb - raise retriable_exc - - firestore_api.run_query.side_effect = [_stream_w_exception(), iter([])] - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - # Execute the query and check the response. - query = self._make_one(parent) - - get_response = query.stream(transaction=transaction, **kwargs) - - self.assertIsInstance(get_response, types.GeneratorType) - if expect_retry: - returned = list(get_response) - else: - returned = [next(get_response)] - with self.assertRaises(exceptions.ServiceUnavailable): - next(get_response) - - self.assertEqual(len(returned), 1) - snapshot = returned[0] - self.assertEqual(snapshot.reference._path, ("dee", "sleep")) - self.assertEqual(snapshot.to_dict(), data) - - # Verify the mock call. 
- parent_path, _ = parent._parent_info() - calls = firestore_api.run_query.call_args_list - - if expect_retry: - self.assertEqual(len(calls), 2) - else: - self.assertEqual(len(calls), 1) - - if transaction is not None: - expected_transaction_id = transaction.id - else: - expected_transaction_id = None - - self.assertEqual( - calls[0], - mock.call( - request={ - "parent": parent_path, - "structured_query": query._to_protobuf(), - "transaction": expected_transaction_id, - }, - metadata=client._rpc_metadata, - **kwargs, - ), - ) - if expect_retry: - new_query = query.start_after(snapshot) - self.assertEqual( - calls[1], - mock.call( - request={ - "parent": parent_path, - "structured_query": new_query._to_protobuf(), - "transaction": None, - }, - metadata=client._rpc_metadata, - **kwargs, - ), - ) - - def test_stream_w_retriable_exc_w_defaults(self): - self._stream_w_retriable_exc_helper() - - def test_stream_w_retriable_exc_w_retry(self): - retry = mock.Mock(spec=["_predicate"]) - retry._predicate = lambda exc: False - self._stream_w_retriable_exc_helper(retry=retry, expect_retry=False) - - def test_stream_w_retriable_exc_w_transaction(self): - from google.cloud.firestore_v1 import transaction - - txn = transaction.Transaction(client=mock.Mock(spec=[])) - txn._id = b"DEADBEEF" - self._stream_w_retriable_exc_helper(transaction=txn) - - @mock.patch("google.cloud.firestore_v1.query.Watch", autospec=True) - def test_on_snapshot(self, watch): - query = self._make_one(mock.sentinel.parent) - query.on_snapshot(None) - watch.for_query.assert_called_once() - - -class TestCollectionGroup(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.query import CollectionGroup - - return CollectionGroup - - def _make_one(self, *args, **kwargs): - klass = self._get_target_class() - return klass(*args, **kwargs) - - def test_constructor(self): - query = self._make_one(mock.sentinel.parent) - self.assertIs(query._parent, mock.sentinel.parent) - 
self.assertIsNone(query._projection) - self.assertEqual(query._field_filters, ()) - self.assertEqual(query._orders, ()) - self.assertIsNone(query._limit) - self.assertIsNone(query._offset) - self.assertIsNone(query._start_at) - self.assertIsNone(query._end_at) - self.assertTrue(query._all_descendants) - - def test_constructor_all_descendents_is_false(self): - with pytest.raises(ValueError): - self._make_one(mock.sentinel.parent, all_descendants=False) - - def _get_partitions_helper(self, retry=None, timeout=None): - from google.cloud.firestore_v1 import _helpers - - # Create a minimal fake GAPIC. - firestore_api = mock.Mock(spec=["partition_query"]) - - # Attach the fake GAPIC to a real client. - client = _make_client() - client._firestore_api_internal = firestore_api - - # Make a **real** collection reference as parent. - parent = client.collection("charles") - - # Make two **real** document references to use as cursors - document1 = parent.document("one") - document2 = parent.document("two") - - # Add cursor pb's to the minimal fake GAPIC. - cursor_pb1 = _make_cursor_pb(([document1], False)) - cursor_pb2 = _make_cursor_pb(([document2], False)) - firestore_api.partition_query.return_value = iter([cursor_pb1, cursor_pb2]) - kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) - - # Execute the query and check the response. - query = self._make_one(parent) - - get_response = query.get_partitions(2, **kwargs) - - self.assertIsInstance(get_response, types.GeneratorType) - returned = list(get_response) - self.assertEqual(len(returned), 3) +def test_collection_group_get_partitions_w_projection(): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") - # Verify the mock call. 
- parent_path, _ = parent._parent_info() - partition_query = self._make_one( - parent, orders=(query._make_order("__name__", query.ASCENDING),), - ) - firestore_api.partition_query.assert_called_once_with( - request={ - "parent": parent_path, - "structured_query": partition_query._to_protobuf(), - "partition_count": 2, - }, - metadata=client._rpc_metadata, - **kwargs, - ) + # Make a query that fails to partition + query = _make_collection_group(parent).select("foo") + with pytest.raises(ValueError): + list(query.get_partitions(2)) + + +def test_collection_group_get_partitions_w_limit(): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") + + # Make a query that fails to partition + query = _make_collection_group(parent).limit(10) + with pytest.raises(ValueError): + list(query.get_partitions(2)) + + +def test_collection_group_get_partitions_w_offset(): + # Make a **real** collection reference as parent. + client = _make_client() + parent = client.collection("charles") - def test_get_partitions(self): - self._get_partitions_helper() - - def test_get_partitions_w_retry_timeout(self): - from google.api_core.retry import Retry - - retry = Retry(predicate=object()) - timeout = 123.0 - self._get_partitions_helper(retry=retry, timeout=timeout) - - def test_get_partitions_w_filter(self): - # Make a **real** collection reference as parent. - client = _make_client() - parent = client.collection("charles") - - # Make a query that fails to partition - query = self._make_one(parent).where("foo", "==", "bar") - with pytest.raises(ValueError): - list(query.get_partitions(2)) - - def test_get_partitions_w_projection(self): - # Make a **real** collection reference as parent. 
- client = _make_client() - parent = client.collection("charles") - - # Make a query that fails to partition - query = self._make_one(parent).select("foo") - with pytest.raises(ValueError): - list(query.get_partitions(2)) - - def test_get_partitions_w_limit(self): - # Make a **real** collection reference as parent. - client = _make_client() - parent = client.collection("charles") - - # Make a query that fails to partition - query = self._make_one(parent).limit(10) - with pytest.raises(ValueError): - list(query.get_partitions(2)) - - def test_get_partitions_w_offset(self): - # Make a **real** collection reference as parent. - client = _make_client() - parent = client.collection("charles") - - # Make a query that fails to partition - query = self._make_one(parent).offset(10) - with pytest.raises(ValueError): - list(query.get_partitions(2)) + # Make a query that fails to partition + query = _make_collection_group(parent).offset(10) + with pytest.raises(ValueError): + list(query.get_partitions(2)) def _make_client(project="project-project"): diff --git a/tests/unit/v1/test_rate_limiter.py b/tests/unit/v1/test_rate_limiter.py index ea41905e49f97..e5068b3590300 100644 --- a/tests/unit/v1/test_rate_limiter.py +++ b/tests/unit/v1/test_rate_limiter.py @@ -13,12 +13,8 @@ # limitations under the License. import datetime -import unittest -from typing import Optional import mock -import google -from google.cloud.firestore_v1 import rate_limiter # Pick a point in time as the center of our universe for this test run. 
@@ -26,175 +22,185 @@ fake_now = datetime.datetime.utcnow() -def now_plus_n( - seconds: Optional[int] = 0, microseconds: Optional[int] = 0, -) -> datetime.timedelta: +def now_plus_n(seconds: int = 0, microseconds: int = 0) -> datetime.timedelta: return fake_now + datetime.timedelta(seconds=seconds, microseconds=microseconds,) -class TestRateLimiter(unittest.TestCase): - @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") - def test_rate_limiter_basic(self, mocked_now): - """Verifies that if the clock does not advance, the RateLimiter allows 500 - writes before crashing out. - """ - mocked_now.return_value = fake_now - # This RateLimiter will never advance. Poor fella. - ramp = rate_limiter.RateLimiter() - for _ in range(rate_limiter.default_initial_tokens): - self.assertEqual(ramp.take_tokens(), 1) - self.assertEqual(ramp.take_tokens(), 0) - - @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") - def test_rate_limiter_with_refill(self, mocked_now): - """Verifies that if the clock advances, the RateLimiter allows appropriate - additional writes. - """ - mocked_now.return_value = fake_now - ramp = rate_limiter.RateLimiter() - ramp._available_tokens = 0 - self.assertEqual(ramp.take_tokens(), 0) - # Advance the clock 0.1 seconds - mocked_now.return_value = now_plus_n(microseconds=100000) - for _ in range(round(rate_limiter.default_initial_tokens / 10)): - self.assertEqual(ramp.take_tokens(), 1) - self.assertEqual(ramp.take_tokens(), 0) - - @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") - def test_rate_limiter_phase_length(self, mocked_now): - """Verifies that if the clock advances, the RateLimiter allows appropriate - additional writes. 
- """ - mocked_now.return_value = fake_now - ramp = rate_limiter.RateLimiter() - self.assertEqual(ramp.take_tokens(), 1) - ramp._available_tokens = 0 - self.assertEqual(ramp.take_tokens(), 0) - # Advance the clock 1 phase - mocked_now.return_value = now_plus_n( - seconds=rate_limiter.default_phase_length, microseconds=1, - ) - for _ in range(round(rate_limiter.default_initial_tokens * 3 / 2)): - self.assertTrue( - ramp.take_tokens(), msg=f"token {_} should have been allowed" - ) - self.assertEqual(ramp.take_tokens(), 0) - - @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") - def test_rate_limiter_idle_phase_length(self, mocked_now): - """Verifies that if the clock advances but nothing happens, the RateLimiter - doesn't ramp up. - """ - mocked_now.return_value = fake_now - ramp = rate_limiter.RateLimiter() - ramp._available_tokens = 0 - self.assertEqual(ramp.take_tokens(), 0) - # Advance the clock 1 phase - mocked_now.return_value = now_plus_n( - seconds=rate_limiter.default_phase_length, microseconds=1, - ) - for _ in range(round(rate_limiter.default_initial_tokens)): - self.assertEqual( - ramp.take_tokens(), 1, msg=f"token {_} should have been allowed" - ) - self.assertEqual(ramp._maximum_tokens, 500) - self.assertEqual(ramp.take_tokens(), 0) - - @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") - def test_take_batch_size(self, mocked_now): - """Verifies that if the clock advances but nothing happens, the RateLimiter - doesn't ramp up. 
- """ - page_size: int = 20 - mocked_now.return_value = fake_now - ramp = rate_limiter.RateLimiter() - ramp._available_tokens = 15 - self.assertEqual(ramp.take_tokens(page_size, allow_less=True), 15) - # Advance the clock 1 phase - mocked_now.return_value = now_plus_n( - seconds=rate_limiter.default_phase_length, microseconds=1, - ) - ramp._check_phase() - self.assertEqual(ramp._maximum_tokens, 750) - - for _ in range(740 // page_size): - self.assertEqual( - ramp.take_tokens(page_size), - page_size, - msg=f"page {_} should have been allowed", - ) - self.assertEqual(ramp.take_tokens(page_size, allow_less=True), 10) - self.assertEqual(ramp.take_tokens(page_size, allow_less=True), 0) - - @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") - def test_phase_progress(self, mocked_now): - mocked_now.return_value = fake_now - - ramp = rate_limiter.RateLimiter() - self.assertEqual(ramp._phase, 0) - self.assertEqual(ramp._maximum_tokens, 500) - ramp.take_tokens() - - # Advance the clock 1 phase - mocked_now.return_value = now_plus_n( - seconds=rate_limiter.default_phase_length, microseconds=1, - ) - ramp.take_tokens() - self.assertEqual(ramp._phase, 1) - self.assertEqual(ramp._maximum_tokens, 750) - - # Advance the clock another phase - mocked_now.return_value = now_plus_n( - seconds=rate_limiter.default_phase_length * 2, microseconds=1, - ) - ramp.take_tokens() - self.assertEqual(ramp._phase, 2) - self.assertEqual(ramp._maximum_tokens, 1125) - - # Advance the clock another ms and the phase should not advance - mocked_now.return_value = now_plus_n( - seconds=rate_limiter.default_phase_length * 2, microseconds=2, - ) - ramp.take_tokens() - self.assertEqual(ramp._phase, 2) - self.assertEqual(ramp._maximum_tokens, 1125) - - @mock.patch.object(google.cloud.firestore_v1.rate_limiter, "utcnow") - def test_global_max_tokens(self, mocked_now): - mocked_now.return_value = fake_now - - ramp = rate_limiter.RateLimiter(global_max_tokens=499,) - 
self.assertEqual(ramp._phase, 0) - self.assertEqual(ramp._maximum_tokens, 499) - ramp.take_tokens() - - # Advance the clock 1 phase - mocked_now.return_value = now_plus_n( - seconds=rate_limiter.default_phase_length, microseconds=1, - ) - ramp.take_tokens() - self.assertEqual(ramp._phase, 1) - self.assertEqual(ramp._maximum_tokens, 499) - - # Advance the clock another phase - mocked_now.return_value = now_plus_n( - seconds=rate_limiter.default_phase_length * 2, microseconds=1, - ) - ramp.take_tokens() - self.assertEqual(ramp._phase, 2) - self.assertEqual(ramp._maximum_tokens, 499) - - # Advance the clock another ms and the phase should not advance - mocked_now.return_value = now_plus_n( - seconds=rate_limiter.default_phase_length * 2, microseconds=2, - ) - ramp.take_tokens() - self.assertEqual(ramp._phase, 2) - self.assertEqual(ramp._maximum_tokens, 499) - - def test_utcnow(self): - self.assertTrue( - isinstance( - google.cloud.firestore_v1.rate_limiter.utcnow(), datetime.datetime, - ) - ) +@mock.patch("google.cloud.firestore_v1.rate_limiter.utcnow") +def test_rate_limiter_basic(mocked_now): + """Verifies that if the clock does not advance, the RateLimiter allows 500 + writes before crashing out. + """ + from google.cloud.firestore_v1 import rate_limiter + + mocked_now.return_value = fake_now + # This RateLimiter will never advance. Poor fella. + ramp = rate_limiter.RateLimiter() + for _ in range(rate_limiter.default_initial_tokens): + assert ramp.take_tokens() == 1 + assert ramp.take_tokens() == 0 + + +@mock.patch("google.cloud.firestore_v1.rate_limiter.utcnow") +def test_rate_limiter_with_refill(mocked_now): + """Verifies that if the clock advances, the RateLimiter allows appropriate + additional writes. 
+ """ + from google.cloud.firestore_v1 import rate_limiter + + mocked_now.return_value = fake_now + ramp = rate_limiter.RateLimiter() + ramp._available_tokens = 0 + assert ramp.take_tokens() == 0 + # Advance the clock 0.1 seconds + mocked_now.return_value = now_plus_n(microseconds=100000) + for _ in range(round(rate_limiter.default_initial_tokens / 10)): + assert ramp.take_tokens() == 1 + assert ramp.take_tokens() == 0 + + +@mock.patch("google.cloud.firestore_v1.rate_limiter.utcnow") +def test_rate_limiter_phase_length(mocked_now): + """Verifies that if the clock advances, the RateLimiter allows appropriate + additional writes. + """ + from google.cloud.firestore_v1 import rate_limiter + + mocked_now.return_value = fake_now + ramp = rate_limiter.RateLimiter() + assert ramp.take_tokens() == 1 + ramp._available_tokens = 0 + assert ramp.take_tokens() == 0 + # Advance the clock 1 phase + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length, microseconds=1, + ) + for _ in range(round(rate_limiter.default_initial_tokens * 3 / 2)): + assert ramp.take_tokens() + + assert ramp.take_tokens() == 0 + + +@mock.patch("google.cloud.firestore_v1.rate_limiter.utcnow") +def test_rate_limiter_idle_phase_length(mocked_now): + """Verifies that if the clock advances but nothing happens, the RateLimiter + doesn't ramp up. 
+ """ + from google.cloud.firestore_v1 import rate_limiter + + mocked_now.return_value = fake_now + ramp = rate_limiter.RateLimiter() + ramp._available_tokens = 0 + assert ramp.take_tokens() == 0 + # Advance the clock 1 phase + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length, microseconds=1, + ) + for _ in range(round(rate_limiter.default_initial_tokens)): + assert ramp.take_tokens() == 1 + assert ramp._maximum_tokens == 500 + assert ramp.take_tokens() == 0 + + +@mock.patch("google.cloud.firestore_v1.rate_limiter.utcnow") +def test_take_batch_size(mocked_now): + """Verifies that if the clock advances but nothing happens, the RateLimiter + doesn't ramp up. + """ + from google.cloud.firestore_v1 import rate_limiter + + page_size: int = 20 + mocked_now.return_value = fake_now + ramp = rate_limiter.RateLimiter() + ramp._available_tokens = 15 + assert ramp.take_tokens(page_size, allow_less=True) == 15 + # Advance the clock 1 phase + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length, microseconds=1, + ) + ramp._check_phase() + assert ramp._maximum_tokens == 750 + + for _ in range(740 // page_size): + assert ramp.take_tokens(page_size) == page_size + assert ramp.take_tokens(page_size, allow_less=True) == 10 + assert ramp.take_tokens(page_size, allow_less=True) == 0 + + +@mock.patch("google.cloud.firestore_v1.rate_limiter.utcnow") +def test_phase_progress(mocked_now): + from google.cloud.firestore_v1 import rate_limiter + + mocked_now.return_value = fake_now + + ramp = rate_limiter.RateLimiter() + assert ramp._phase == 0 + assert ramp._maximum_tokens == 500 + ramp.take_tokens() + + # Advance the clock 1 phase + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length, microseconds=1, + ) + ramp.take_tokens() + assert ramp._phase == 1 + assert ramp._maximum_tokens == 750 + + # Advance the clock another phase + mocked_now.return_value = now_plus_n( + 
seconds=rate_limiter.default_phase_length * 2, microseconds=1, + ) + ramp.take_tokens() + assert ramp._phase == 2 + assert ramp._maximum_tokens == 1125 + + # Advance the clock another ms and the phase should not advance + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length * 2, microseconds=2, + ) + ramp.take_tokens() + assert ramp._phase == 2 + assert ramp._maximum_tokens == 1125 + + +@mock.patch("google.cloud.firestore_v1.rate_limiter.utcnow") +def test_global_max_tokens(mocked_now): + from google.cloud.firestore_v1 import rate_limiter + + mocked_now.return_value = fake_now + + ramp = rate_limiter.RateLimiter(global_max_tokens=499,) + assert ramp._phase == 0 + assert ramp._maximum_tokens == 499 + ramp.take_tokens() + + # Advance the clock 1 phase + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length, microseconds=1, + ) + ramp.take_tokens() + assert ramp._phase == 1 + assert ramp._maximum_tokens == 499 + + # Advance the clock another phase + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length * 2, microseconds=1, + ) + ramp.take_tokens() + assert ramp._phase == 2 + assert ramp._maximum_tokens == 499 + + # Advance the clock another ms and the phase should not advance + mocked_now.return_value = now_plus_n( + seconds=rate_limiter.default_phase_length * 2, microseconds=2, + ) + ramp.take_tokens() + assert ramp._phase == 2 + assert ramp._maximum_tokens == 499 + + +def test_utcnow(): + from google.cloud.firestore_v1 import rate_limiter + + now = rate_limiter.utcnow() + assert isinstance(now, datetime.datetime) diff --git a/tests/unit/v1/test_transaction.py b/tests/unit/v1/test_transaction.py index 3a093a335d4d5..baad17c9e38f6 100644 --- a/tests/unit/v1/test_transaction.py +++ b/tests/unit/v1/test_transaction.py @@ -12,1010 +12,994 @@ # See the License for the specific language governing permissions and # limitations under the License. 
def _make_transaction(*args, **kwargs):
    """Build a ``Transaction``; import is deferred to keep module import light."""
    from google.cloud.firestore_v1.transaction import Transaction

    return Transaction(*args, **kwargs)


def test_transaction_constructor_defaults():
    """A bare transaction starts empty, writable, and unbegun."""
    from google.cloud.firestore_v1.transaction import MAX_ATTEMPTS

    transaction = _make_transaction(mock.sentinel.client)

    assert transaction._client is mock.sentinel.client
    assert transaction._write_pbs == []
    assert transaction._max_attempts == MAX_ATTEMPTS
    assert not transaction._read_only
    assert transaction._id is None


def test_transaction_constructor_explicit():
    """Explicit ``max_attempts`` / ``read_only`` values land on the instance."""
    transaction = _make_transaction(
        mock.sentinel.client, max_attempts=10, read_only=True
    )

    assert transaction._client is mock.sentinel.client
    assert transaction._write_pbs == []
    assert transaction._max_attempts == 10
    assert transaction._read_only
    assert transaction._id is None


def test_transaction__add_write_pbs_failure():
    """Adding writes to a read-only transaction raises and leaves it empty."""
    from google.cloud.firestore_v1.base_transaction import _WRITE_READ_ONLY

    batch = _make_transaction(mock.sentinel.client, read_only=True)
    assert batch._write_pbs == []

    with pytest.raises(ValueError) as exc_info:
        batch._add_write_pbs([mock.sentinel.write])

    assert exc_info.value.args == (_WRITE_READ_ONLY,)
    assert batch._write_pbs == []


def test_transaction__add_write_pbs():
    """Writes accumulate on a writable transaction."""
    batch = _make_transaction(mock.sentinel.client)
    assert batch._write_pbs == []

    batch._add_write_pbs([mock.sentinel.write])

    assert batch._write_pbs == [mock.sentinel.write]
- to_wrap.assert_called_once_with(transaction) - firestore_api = transaction._client._firestore_api - options_ = common.TransactionOptions( - read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id1) - ) - firestore_api.begin_transaction.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "options": options_, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_not_called() - - def test__pre_commit_failure(self): - exc = RuntimeError("Nope not today.") - to_wrap = mock.Mock(side_effect=exc, spec=[]) - wrapped = self._make_one(to_wrap) - - txn_id = b"gotta-fail" - transaction = _make_transaction(txn_id) - with self.assertRaises(RuntimeError) as exc_info: - wrapped._pre_commit(transaction, 10, 20) - self.assertIs(exc_info.exception, exc) - - self.assertIsNone(transaction._id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. - to_wrap.assert_called_once_with(transaction, 10, 20) - firestore_api = transaction._client._firestore_api - firestore_api.begin_transaction.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "options": None, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.rollback.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.commit.assert_not_called() - - def test__pre_commit_failure_with_rollback_failure(self): - from google.api_core import exceptions - - exc1 = ValueError("I will not be only failure.") - to_wrap = mock.Mock(side_effect=exc1, spec=[]) - wrapped = self._make_one(to_wrap) - - txn_id = b"both-will-fail" - transaction = _make_transaction(txn_id) - # Actually force the ``rollback`` to fail as well. 
- exc2 = exceptions.InternalServerError("Rollback blues.") - firestore_api = transaction._client._firestore_api - firestore_api.rollback.side_effect = exc2 - - # Try to ``_pre_commit`` - with self.assertRaises(exceptions.InternalServerError) as exc_info: - wrapped._pre_commit(transaction, a="b", c="zebra") - self.assertIs(exc_info.exception, exc2) - - self.assertIsNone(transaction._id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. - to_wrap.assert_called_once_with(transaction, a="b", c="zebra") - firestore_api.begin_transaction.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "options": None, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.rollback.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.commit.assert_not_called() - - def test__maybe_commit_success(self): - wrapped = self._make_one(mock.sentinel.callable_) - - txn_id = b"nyet" - transaction = _make_transaction(txn_id) - transaction._id = txn_id # We won't call ``begin()``. - succeeded = wrapped._maybe_commit(transaction) - self.assertTrue(succeeded) - - # On success, _id is reset. - self.assertIsNone(transaction._id) - - # Verify mocks. 
- firestore_api = transaction._client._firestore_api - firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - - def test__maybe_commit_failure_read_only(self): - from google.api_core import exceptions - - wrapped = self._make_one(mock.sentinel.callable_) - - txn_id = b"failed" - transaction = _make_transaction(txn_id, read_only=True) - transaction._id = txn_id # We won't call ``begin()``. - wrapped.current_id = txn_id # We won't call ``_pre_commit()``. - wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. - - # Actually force the ``commit`` to fail (use ABORTED, but cannot - # retry since read-only). - exc = exceptions.Aborted("Read-only did a bad.") - firestore_api = transaction._client._firestore_api - firestore_api.commit.side_effect = exc - - with self.assertRaises(exceptions.Aborted) as exc_info: - wrapped._maybe_commit(transaction) - self.assertIs(exc_info.exception, exc) - - self.assertEqual(transaction._id, txn_id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. - firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - - def test__maybe_commit_failure_can_retry(self): - from google.api_core import exceptions - - wrapped = self._make_one(mock.sentinel.callable_) - - txn_id = b"failed-but-retry" - transaction = _make_transaction(txn_id) - transaction._id = txn_id # We won't call ``begin()``. - wrapped.current_id = txn_id # We won't call ``_pre_commit()``. 
- wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. - - # Actually force the ``commit`` to fail. - exc = exceptions.Aborted("Read-write did a bad.") - firestore_api = transaction._client._firestore_api - firestore_api.commit.side_effect = exc - - succeeded = wrapped._maybe_commit(transaction) - self.assertFalse(succeeded) - - self.assertEqual(transaction._id, txn_id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. - firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - - def test__maybe_commit_failure_cannot_retry(self): - from google.api_core import exceptions - - wrapped = self._make_one(mock.sentinel.callable_) - - txn_id = b"failed-but-not-retryable" - transaction = _make_transaction(txn_id) - transaction._id = txn_id # We won't call ``begin()``. - wrapped.current_id = txn_id # We won't call ``_pre_commit()``. - wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. - - # Actually force the ``commit`` to fail. - exc = exceptions.InternalServerError("Real bad thing") - firestore_api = transaction._client._firestore_api - firestore_api.commit.side_effect = exc - - with self.assertRaises(exceptions.InternalServerError) as exc_info: - wrapped._maybe_commit(transaction) - self.assertIs(exc_info.exception, exc) - - self.assertEqual(transaction._id, txn_id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. 
- firestore_api.begin_transaction.assert_not_called() - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - - def test___call__success_first_attempt(self): - to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) - wrapped = self._make_one(to_wrap) - - txn_id = b"whole-enchilada" - transaction = _make_transaction(txn_id) - result = wrapped(transaction, "a", b="c") - self.assertIs(result, mock.sentinel.result) - - self.assertIsNone(transaction._id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. - to_wrap.assert_called_once_with(transaction, "a", b="c") - firestore_api = transaction._client._firestore_api - firestore_api.begin_transaction.assert_called_once_with( - request={"database": transaction._client._database_string, "options": None}, - metadata=transaction._client._rpc_metadata, - ) - firestore_api.rollback.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": transaction._client._database_string, - "writes": [], - "transaction": txn_id, - }, - metadata=transaction._client._rpc_metadata, - ) - - def test___call__success_second_attempt(self): - from google.api_core import exceptions - from google.cloud.firestore_v1.types import common - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.types import write - - to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) - wrapped = self._make_one(to_wrap) - - txn_id = b"whole-enchilada" - transaction = _make_transaction(txn_id) - - # Actually force the ``commit`` to fail on first / succeed on second. 
- exc = exceptions.Aborted("Contention junction.") - firestore_api = transaction._client._firestore_api - firestore_api.commit.side_effect = [ - exc, - firestore.CommitResponse(write_results=[write.WriteResult()]), - ] - - # Call the __call__-able ``wrapped``. - result = wrapped(transaction, "a", b="c") - self.assertIs(result, mock.sentinel.result) - - self.assertIsNone(transaction._id) - self.assertEqual(wrapped.current_id, txn_id) - self.assertEqual(wrapped.retry_id, txn_id) - - # Verify mocks. - wrapped_call = mock.call(transaction, "a", b="c") - self.assertEqual(to_wrap.mock_calls, [wrapped_call, wrapped_call]) - firestore_api = transaction._client._firestore_api - db_str = transaction._client._database_string - options_ = common.TransactionOptions( - read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id) - ) - self.assertEqual( - firestore_api.begin_transaction.mock_calls, - [ - mock.call( - request={"database": db_str, "options": None}, - metadata=transaction._client._rpc_metadata, - ), - mock.call( - request={"database": db_str, "options": options_}, - metadata=transaction._client._rpc_metadata, - ), - ], - ) - firestore_api.rollback.assert_not_called() - commit_call = mock.call( - request={"database": db_str, "writes": [], "transaction": txn_id}, - metadata=transaction._client._rpc_metadata, - ) - self.assertEqual(firestore_api.commit.mock_calls, [commit_call, commit_call]) - - def test___call__failure(self): - from google.api_core import exceptions - from google.cloud.firestore_v1.base_transaction import _EXCEED_ATTEMPTS_TEMPLATE - - to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) - wrapped = self._make_one(to_wrap) - - txn_id = b"only-one-shot" - transaction = _make_transaction(txn_id, max_attempts=1) - - # Actually force the ``commit`` to fail. 
def test_transaction__clean_up():
    """``_clean_up`` resets both pending writes and the transaction id."""
    transaction = _make_transaction(mock.sentinel.client)
    transaction._write_pbs.extend([mock.sentinel.write_pb1, mock.sentinel.write])
    transaction._id = b"not-this-time-my-friend"

    ret_val = transaction._clean_up()

    assert ret_val is None
    assert transaction._write_pbs == []
    assert transaction._id is None


def test_transaction__begin():
    """``_begin`` obtains a transaction id from the GAPIC client."""
    from google.cloud.firestore_v1.services.firestore import client as firestore_client
    from google.cloud.firestore_v1.types import firestore

    # Create a minimal fake GAPIC with a dummy result.
    firestore_api = mock.create_autospec(
        firestore_client.FirestoreClient, instance=True
    )
    txn_id = b"to-begin"
    firestore_api.begin_transaction.return_value = firestore.BeginTransactionResponse(
        transaction=txn_id
    )

    # Attach the fake GAPIC to a real client.
    client = _make_client()
    client._firestore_api_internal = firestore_api

    # Actually make a transaction and ``begin()`` it.
    transaction = _make_transaction(client)
    assert transaction._id is None

    ret_val = transaction._begin()

    assert ret_val is None
    assert transaction._id == txn_id

    # Verify the called mock.
    firestore_api.begin_transaction.assert_called_once_with(
        request={"database": client._database_string, "options": None},
        metadata=client._rpc_metadata,
    )


def test_transaction__begin_failure():
    """Beginning an already-begun transaction raises ``ValueError``."""
    from google.cloud.firestore_v1.base_transaction import _CANT_BEGIN

    client = _make_client()
    transaction = _make_transaction(client)
    transaction._id = b"not-none"

    with pytest.raises(ValueError) as exc_info:
        transaction._begin()

    err_msg = _CANT_BEGIN.format(transaction._id)
    assert exc_info.value.args == (err_msg,)


def test_transaction__rollback():
    """``_rollback`` calls the GAPIC and clears the transaction id."""
    from google.protobuf import empty_pb2
    from google.cloud.firestore_v1.services.firestore import client as firestore_client

    # Create a minimal fake GAPIC with a dummy result.
    firestore_api = mock.create_autospec(
        firestore_client.FirestoreClient, instance=True
    )
    firestore_api.rollback.return_value = empty_pb2.Empty()

    # Attach the fake GAPIC to a real client.
    client = _make_client()
    client._firestore_api_internal = firestore_api

    # Actually make a transaction and roll it back.
    transaction = _make_transaction(client)
    txn_id = b"to-be-r\x00lled"
    transaction._id = txn_id

    ret_val = transaction._rollback()

    assert ret_val is None
    assert transaction._id is None

    # Verify the called mock.
    firestore_api.rollback.assert_called_once_with(
        request={"database": client._database_string, "transaction": txn_id},
        metadata=client._rpc_metadata,
    )


def test_transaction__rollback_not_allowed():
    """Rolling back an unbegun transaction raises ``ValueError``."""
    from google.cloud.firestore_v1.base_transaction import _CANT_ROLLBACK

    client = _make_client()
    transaction = _make_transaction(client)
    assert transaction._id is None

    with pytest.raises(ValueError) as exc_info:
        transaction._rollback()

    assert exc_info.value.args == (_CANT_ROLLBACK,)


def test_transaction__rollback_failure():
    """A server error during rollback propagates, yet state is still cleared."""
    from google.api_core import exceptions
    from google.cloud.firestore_v1.services.firestore import client as firestore_client

    # Create a minimal fake GAPIC with a dummy failure.
    firestore_api = mock.create_autospec(
        firestore_client.FirestoreClient, instance=True
    )
    exc = exceptions.InternalServerError("Fire during rollback.")
    firestore_api.rollback.side_effect = exc

    # Attach the fake GAPIC to a real client.
    client = _make_client()
    client._firestore_api_internal = firestore_api

    # Actually make a transaction and roll it back.
    transaction = _make_transaction(client)
    txn_id = b"roll-bad-server"
    transaction._id = txn_id

    with pytest.raises(exceptions.InternalServerError) as exc_info:
        transaction._rollback()

    assert exc_info.value is exc
    assert transaction._id is None
    assert transaction._write_pbs == []

    # Verify the called mock.
    firestore_api.rollback.assert_called_once_with(
        request={"database": client._database_string, "transaction": txn_id},
        metadata=client._rpc_metadata,
    )


def test_transaction__commit():
    """``_commit`` sends the queued writes and resets the transaction."""
    from google.cloud.firestore_v1.services.firestore import client as firestore_client
    from google.cloud.firestore_v1.types import firestore
    from google.cloud.firestore_v1.types import write

    # Create a minimal fake GAPIC with a dummy result.
    firestore_api = mock.create_autospec(
        firestore_client.FirestoreClient, instance=True
    )
    commit_response = firestore.CommitResponse(write_results=[write.WriteResult()])
    firestore_api.commit.return_value = commit_response

    # Attach the fake GAPIC to a real client.
    client = _make_client("phone-joe")
    client._firestore_api_internal = firestore_api

    # Actually make a transaction with some mutations and call _commit().
    transaction = _make_transaction(client)
    txn_id = b"under-over-thru-woods"
    transaction._id = txn_id
    document = client.document("zap", "galaxy", "ship", "space")
    transaction.set(document, {"apple": 4.5})
    write_pbs = transaction._write_pbs[::]

    write_results = transaction._commit()

    assert write_results == list(commit_response.write_results)
    # Make sure transaction has no more "changes".
    assert transaction._id is None
    assert transaction._write_pbs == []

    # Verify the mocks.
    firestore_api.commit.assert_called_once_with(
        request={
            "database": client._database_string,
            "writes": write_pbs,
            "transaction": txn_id,
        },
        metadata=client._rpc_metadata,
    )


def test_transaction__commit_not_allowed():
    """Committing an unbegun transaction raises ``ValueError``."""
    from google.cloud.firestore_v1.base_transaction import _CANT_COMMIT

    transaction = _make_transaction(mock.sentinel.client)
    assert transaction._id is None

    with pytest.raises(ValueError) as exc_info:
        transaction._commit()

    assert exc_info.value.args == (_CANT_COMMIT,)


def test_transaction__commit_failure():
    """A server error during commit propagates and preserves pending state."""
    from google.api_core import exceptions
    from google.cloud.firestore_v1.services.firestore import client as firestore_client

    # Create a minimal fake GAPIC with a dummy failure.
    firestore_api = mock.create_autospec(
        firestore_client.FirestoreClient, instance=True
    )
    exc = exceptions.InternalServerError("Fire during commit.")
    firestore_api.commit.side_effect = exc

    # Attach the fake GAPIC to a real client.
    client = _make_client()
    client._firestore_api_internal = firestore_api

    # Actually make a transaction with some mutations and call _commit().
    transaction = _make_transaction(client)
    txn_id = b"beep-fail-commit"
    transaction._id = txn_id
    transaction.create(client.document("up", "down"), {"water": 1.0})
    transaction.delete(client.document("up", "left"))
    write_pbs = transaction._write_pbs[::]

    with pytest.raises(exceptions.InternalServerError) as exc_info:
        transaction._commit()

    assert exc_info.value is exc
    assert transaction._id == txn_id
    assert transaction._write_pbs == write_pbs

    # Verify the called mock.
    firestore_api.commit.assert_called_once_with(
        request={
            "database": client._database_string,
            "writes": write_pbs,
            "transaction": txn_id,
        },
        metadata=client._rpc_metadata,
    )


def _transaction_get_all_helper(retry=None, timeout=None):
    """Shared driver: ``get_all`` forwards refs plus retry/timeout kwargs."""
    from google.cloud.firestore_v1 import _helpers

    client = mock.Mock(spec=["get_all"])
    transaction = _make_transaction(client)
    ref1, ref2 = mock.Mock(), mock.Mock()
    kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

    result = transaction.get_all([ref1, ref2], **kwargs)

    client.get_all.assert_called_once_with(
        [ref1, ref2], transaction=transaction, **kwargs,
    )
    assert result is client.get_all.return_value


def test_transaction_get_all():
    _transaction_get_all_helper()


def test_transaction_get_all_w_retry_timeout():
    from google.api_core.retry import Retry

    retry = Retry(predicate=object())
    timeout = 123.0
    _transaction_get_all_helper(retry=retry, timeout=timeout)
+ + result = transaction.get(ref, **kwargs) + + assert result is client.get_all.return_value + client.get_all.assert_called_once_with([ref], transaction=transaction, **kwargs) + + +def test_transaction_get_w_document_ref(): + _transaction_get_w_document_ref_helper() + + +def test_transaction_get_w_document_ref_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + _transaction_get_w_document_ref_helper(retry=retry, timeout=timeout) + + +def _transaction_get_w_query_helper(retry=None, timeout=None): + from google.cloud.firestore_v1 import _helpers + from google.cloud.firestore_v1.query import Query + + client = mock.Mock(spec=[]) + transaction = _make_transaction(client) + query = Query(parent=mock.Mock(spec=[])) + query.stream = mock.MagicMock() + kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) + + result = transaction.get(query, **kwargs) + + assert result is query.stream.return_value + query.stream.assert_called_once_with(transaction=transaction, **kwargs) + + +def test_transaction_get_w_query(): + _transaction_get_w_query_helper() + + +def test_transaction_get_w_query_w_retry_timeout(): + from google.api_core.retry import Retry + + retry = Retry(predicate=object()) + timeout = 123.0 + _transaction_get_w_query_helper(retry=retry, timeout=timeout) + + +def test_transaction_get_failure(): + client = _make_client() + transaction = _make_transaction(client) + ref_or_query = object() + with pytest.raises(ValueError): + transaction.get(ref_or_query) + + +def _make__transactional(*args, **kwargs): + from google.cloud.firestore_v1.transaction import _Transactional + + return _Transactional(*args, **kwargs) + + +def test__transactional_constructor(): + wrapped = _make__transactional(mock.sentinel.callable_) + assert wrapped.to_wrap is mock.sentinel.callable_ + assert wrapped.current_id is None + assert wrapped.retry_id is None + + +def test__transactional__pre_commit_success(): + to_wrap = 
mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make__transactional(to_wrap) + + txn_id = b"totes-began" + transaction = _make_transaction_pb(txn_id) + result = wrapped._pre_commit(transaction, "pos", key="word") + assert result is mock.sentinel.result + + assert transaction._id == txn_id + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, "pos", key="word") + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_called_once_with( + request={"database": transaction._client._database_string, "options": None}, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_not_called() + + +def test__transactional__pre_commit_retry_id_already_set_success(): + from google.cloud.firestore_v1.types import common + + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make__transactional(to_wrap) + txn_id1 = b"already-set" + wrapped.retry_id = txn_id1 + + txn_id2 = b"ok-here-too" + transaction = _make_transaction_pb(txn_id2) + result = wrapped._pre_commit(transaction) + assert result is mock.sentinel.result + + assert transaction._id == txn_id2 + assert wrapped.current_id == txn_id2 + assert wrapped.retry_id == txn_id1 + + # Verify mocks. 
+ to_wrap.assert_called_once_with(transaction) + firestore_api = transaction._client._firestore_api + options_ = common.TransactionOptions( + read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id1) + ) + firestore_api.begin_transaction.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "options": options_, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_not_called() + + +def test__transactional__pre_commit_failure(): + exc = RuntimeError("Nope not today.") + to_wrap = mock.Mock(side_effect=exc, spec=[]) + wrapped = _make__transactional(to_wrap) + + txn_id = b"gotta-fail" + transaction = _make_transaction_pb(txn_id) + with pytest.raises(RuntimeError) as exc_info: + wrapped._pre_commit(transaction, 10, 20) + assert exc_info.value is exc + + assert transaction._id is None + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, 10, 20) + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_called_once_with( + request={"database": transaction._client._database_string, "options": None}, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_not_called() + + +def test__transactional__pre_commit_failure_with_rollback_failure(): + from google.api_core import exceptions + + exc1 = ValueError("I will not be only failure.") + to_wrap = mock.Mock(side_effect=exc1, spec=[]) + wrapped = _make__transactional(to_wrap) + + txn_id = b"both-will-fail" + transaction = _make_transaction_pb(txn_id) + # Actually force the ``rollback`` to fail as well. 
+ exc2 = exceptions.InternalServerError("Rollback blues.") + firestore_api = transaction._client._firestore_api + firestore_api.rollback.side_effect = exc2 + + # Try to ``_pre_commit`` + with pytest.raises(exceptions.InternalServerError) as exc_info: + wrapped._pre_commit(transaction, a="b", c="zebra") + assert exc_info.value is exc2 + + assert transaction._id is None + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, a="b", c="zebra") + firestore_api.begin_transaction.assert_called_once_with( + request={"database": transaction._client._database_string, "options": None}, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_not_called() + + +def test__transactional__maybe_commit_success(): + wrapped = _make__transactional(mock.sentinel.callable_) + + txn_id = b"nyet" + transaction = _make_transaction_pb(txn_id) + transaction._id = txn_id # We won't call ``begin()``. + succeeded = wrapped._maybe_commit(transaction) + assert succeeded + + # On success, _id is reset. + assert transaction._id is None + + # Verify mocks. 
+ firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + +def test__transactional__maybe_commit_failure_read_only(): + from google.api_core import exceptions + + wrapped = _make__transactional(mock.sentinel.callable_) + + txn_id = b"failed" + transaction = _make_transaction_pb(txn_id, read_only=True) + transaction._id = txn_id # We won't call ``begin()``. + wrapped.current_id = txn_id # We won't call ``_pre_commit()``. + wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + + # Actually force the ``commit`` to fail (use ABORTED, but cannot + # retry since read-only). + exc = exceptions.Aborted("Read-only did a bad.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + with pytest.raises(exceptions.Aborted) as exc_info: + wrapped._maybe_commit(transaction) + assert exc_info.value is exc + + assert transaction._id == txn_id + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. + firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + +def test__transactional__maybe_commit_failure_can_retry(): + from google.api_core import exceptions + + wrapped = _make__transactional(mock.sentinel.callable_) + + txn_id = b"failed-but-retry" + transaction = _make_transaction_pb(txn_id) + transaction._id = txn_id # We won't call ``begin()``. + wrapped.current_id = txn_id # We won't call ``_pre_commit()``. 
+ wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + + # Actually force the ``commit`` to fail. + exc = exceptions.Aborted("Read-write did a bad.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + succeeded = wrapped._maybe_commit(transaction) + assert not succeeded + + assert transaction._id == txn_id + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. + firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + +def test__transactional__maybe_commit_failure_cannot_retry(): + from google.api_core import exceptions + + wrapped = _make__transactional(mock.sentinel.callable_) + + txn_id = b"failed-but-not-retryable" + transaction = _make_transaction_pb(txn_id) + transaction._id = txn_id # We won't call ``begin()``. + wrapped.current_id = txn_id # We won't call ``_pre_commit()``. + wrapped.retry_id = txn_id # We won't call ``_pre_commit()``. + + # Actually force the ``commit`` to fail. + exc = exceptions.InternalServerError("Real bad thing") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + with pytest.raises(exceptions.InternalServerError) as exc_info: + wrapped._maybe_commit(transaction) + assert exc_info.value is exc + + assert transaction._id == txn_id + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. 
+ firestore_api.begin_transaction.assert_not_called() + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + +def test__transactional___call__success_first_attempt(): + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make__transactional(to_wrap) + + txn_id = b"whole-enchilada" + transaction = _make_transaction_pb(txn_id) + result = wrapped(transaction, "a", b="c") + assert result is mock.sentinel.result + + assert transaction._id is None + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. + to_wrap.assert_called_once_with(transaction, "a", b="c") + firestore_api = transaction._client._firestore_api + firestore_api.begin_transaction.assert_called_once_with( + request={"database": transaction._client._database_string, "options": None}, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + +def test__transactional___call__success_second_attempt(): + from google.api_core import exceptions + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.types import firestore + from google.cloud.firestore_v1.types import write + + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make__transactional(to_wrap) + + txn_id = b"whole-enchilada" + transaction = _make_transaction_pb(txn_id) + + # Actually force the ``commit`` to fail on first / succeed on second. 
+    exc = exceptions.Aborted("Contention junction.")
+    firestore_api = transaction._client._firestore_api
+    firestore_api.commit.side_effect = [
+        exc,
+        firestore.CommitResponse(write_results=[write.WriteResult()]),
+    ]
+
+    # Call the __call__-able ``wrapped``.
+    result = wrapped(transaction, "a", b="c")
+    assert result is mock.sentinel.result
+
+    assert transaction._id is None
+    assert wrapped.current_id == txn_id
+    assert wrapped.retry_id == txn_id
+
+    # Verify mocks.
+    wrapped_call = mock.call(transaction, "a", b="c")
+    assert to_wrap.mock_calls == [wrapped_call, wrapped_call]
+    firestore_api = transaction._client._firestore_api
+    db_str = transaction._client._database_string
+    options_ = common.TransactionOptions(
+        read_write=common.TransactionOptions.ReadWrite(retry_transaction=txn_id)
+    )
+    expected_calls = [
+        mock.call(
+            request={"database": db_str, "options": None},
             metadata=transaction._client._rpc_metadata,
-        )
-        firestore_api.commit.assert_called_once_with(
-            request={
-                "database": transaction._client._database_string,
-                "writes": [],
-                "transaction": txn_id,
-            },
+        ),
+        mock.call(
+            request={"database": db_str, "options": options_},
             metadata=transaction._client._rpc_metadata,
-        )
-
-
-class Test_transactional(unittest.TestCase):
-    @staticmethod
-    def _call_fut(to_wrap):
-        from google.cloud.firestore_v1.transaction import transactional
-
-        return transactional(to_wrap)
-
-    def test_it(self):
-        from google.cloud.firestore_v1.transaction import _Transactional
-
-        wrapped = self._call_fut(mock.sentinel.callable_)
-        self.assertIsInstance(wrapped, _Transactional)
-        self.assertIs(wrapped.to_wrap, mock.sentinel.callable_)
-
-
-class Test__commit_with_retry(unittest.TestCase):
-    @staticmethod
-    def _call_fut(client, write_pbs, transaction_id):
-        from google.cloud.firestore_v1.transaction import _commit_with_retry
-
-        return _commit_with_retry(client, write_pbs, transaction_id)
-
-    @mock.patch("google.cloud.firestore_v1.transaction._sleep")
-    def test_success_first_attempt(self, _sleep):
-        from google.cloud.firestore_v1.services.firestore import (
-            client as firestore_client,
-        )
-
-        # Create a minimal fake GAPIC with a dummy result.
-        firestore_api = mock.create_autospec(
-            firestore_client.FirestoreClient, instance=True
-        )
-
-        # Attach the fake GAPIC to a real client.
-        client = _make_client("summer")
-        client._firestore_api_internal = firestore_api
-
-        # Call function and check result.
-        txn_id = b"cheeeeeez"
-        commit_response = self._call_fut(client, mock.sentinel.write_pbs, txn_id)
-        self.assertIs(commit_response, firestore_api.commit.return_value)
-
-        # Verify mocks used.
-        _sleep.assert_not_called()
-        firestore_api.commit.assert_called_once_with(
-            request={
-                "database": client._database_string,
-                "writes": mock.sentinel.write_pbs,
-                "transaction": txn_id,
-            },
-            metadata=client._rpc_metadata,
-        )
-
-    @mock.patch("google.cloud.firestore_v1.transaction._sleep", side_effect=[2.0, 4.0])
-    def test_success_third_attempt(self, _sleep):
-        from google.api_core import exceptions
-        from google.cloud.firestore_v1.services.firestore import (
-            client as firestore_client,
-        )
-
-        # Create a minimal fake GAPIC with a dummy result.
-        firestore_api = mock.create_autospec(
-            firestore_client.FirestoreClient, instance=True
-        )
-        # Make sure the first two requests fail and the third succeeds.
-        firestore_api.commit.side_effect = [
-            exceptions.ServiceUnavailable("Server sleepy."),
-            exceptions.ServiceUnavailable("Server groggy."),
-            mock.sentinel.commit_response,
-        ]
-
-        # Attach the fake GAPIC to a real client.
-        client = _make_client("outside")
-        client._firestore_api_internal = firestore_api
-
-        # Call function and check result.
-        txn_id = b"the-world\x00"
-        commit_response = self._call_fut(client, mock.sentinel.write_pbs, txn_id)
-        self.assertIs(commit_response, mock.sentinel.commit_response)
-
-        # Verify mocks used.
- # Ensure _sleep is called after commit failures, with intervals of 1 and 2 seconds - self.assertEqual(_sleep.call_count, 2) - _sleep.assert_any_call(1.0) - _sleep.assert_any_call(2.0) - # commit() called same way 3 times. - commit_call = mock.call( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - self.assertEqual( - firestore_api.commit.mock_calls, [commit_call, commit_call, commit_call] - ) - - @mock.patch("google.cloud.firestore_v1.transaction._sleep") - def test_failure_first_attempt(self, _sleep): - from google.api_core import exceptions - from google.cloud.firestore_v1.services.firestore import ( - client as firestore_client, - ) - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = mock.create_autospec( - firestore_client.FirestoreClient, instance=True - ) - # Make sure the first request fails with an un-retryable error. - exc = exceptions.ResourceExhausted("We ran out of fries.") - firestore_api.commit.side_effect = exc - - # Attach the fake GAPIC to a real client. - client = _make_client("peanut-butter") - client._firestore_api_internal = firestore_api - - # Call function and check result. - txn_id = b"\x08\x06\x07\x05\x03\x00\x09-jenny" - with self.assertRaises(exceptions.ResourceExhausted) as exc_info: - self._call_fut(client, mock.sentinel.write_pbs, txn_id) - - self.assertIs(exc_info.exception, exc) - - # Verify mocks used. 
- _sleep.assert_not_called() - firestore_api.commit.assert_called_once_with( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - - @mock.patch("google.cloud.firestore_v1.transaction._sleep", return_value=2.0) - def test_failure_second_attempt(self, _sleep): - from google.api_core import exceptions - from google.cloud.firestore_v1.services.firestore import ( - client as firestore_client, - ) - - # Create a minimal fake GAPIC with a dummy result. - firestore_api = mock.create_autospec( - firestore_client.FirestoreClient, instance=True - ) - # Make sure the first request fails retry-able and second - # fails non-retryable. - exc1 = exceptions.ServiceUnavailable("Come back next time.") - exc2 = exceptions.InternalServerError("Server on fritz.") - firestore_api.commit.side_effect = [exc1, exc2] - - # Attach the fake GAPIC to a real client. - client = _make_client("peanut-butter") - client._firestore_api_internal = firestore_api - - # Call function and check result. - txn_id = b"the-journey-when-and-where-well-go" - with self.assertRaises(exceptions.InternalServerError) as exc_info: - self._call_fut(client, mock.sentinel.write_pbs, txn_id) - - self.assertIs(exc_info.exception, exc2) - - # Verify mocks used. - _sleep.assert_called_once_with(1.0) - # commit() called same way 2 times. 
- commit_call = mock.call( - request={ - "database": client._database_string, - "writes": mock.sentinel.write_pbs, - "transaction": txn_id, - }, - metadata=client._rpc_metadata, - ) - self.assertEqual(firestore_api.commit.mock_calls, [commit_call, commit_call]) - - -class Test__sleep(unittest.TestCase): - @staticmethod - def _call_fut(current_sleep, **kwargs): - from google.cloud.firestore_v1.transaction import _sleep - - return _sleep(current_sleep, **kwargs) - - @mock.patch("random.uniform", return_value=5.5) - @mock.patch("time.sleep", return_value=None) - def test_defaults(self, sleep, uniform): - curr_sleep = 10.0 - self.assertLessEqual(uniform.return_value, curr_sleep) - - new_sleep = self._call_fut(curr_sleep) - self.assertEqual(new_sleep, 2.0 * curr_sleep) - - uniform.assert_called_once_with(0.0, curr_sleep) - sleep.assert_called_once_with(uniform.return_value) - - @mock.patch("random.uniform", return_value=10.5) - @mock.patch("time.sleep", return_value=None) - def test_explicit(self, sleep, uniform): - curr_sleep = 12.25 - self.assertLessEqual(uniform.return_value, curr_sleep) - - multiplier = 1.5 - new_sleep = self._call_fut(curr_sleep, max_sleep=100.0, multiplier=multiplier) - self.assertEqual(new_sleep, multiplier * curr_sleep) - - uniform.assert_called_once_with(0.0, curr_sleep) - sleep.assert_called_once_with(uniform.return_value) - - @mock.patch("random.uniform", return_value=6.75) - @mock.patch("time.sleep", return_value=None) - def test_exceeds_max(self, sleep, uniform): - curr_sleep = 20.0 - self.assertLessEqual(uniform.return_value, curr_sleep) - - max_sleep = 38.5 - new_sleep = self._call_fut(curr_sleep, max_sleep=max_sleep, multiplier=2.0) - self.assertEqual(new_sleep, max_sleep) - - uniform.assert_called_once_with(0.0, curr_sleep) - sleep.assert_called_once_with(uniform.return_value) + ), + ] + assert firestore_api.begin_transaction.mock_calls == expected_calls + firestore_api.rollback.assert_not_called() + commit_call = mock.call( + 
request={"database": db_str, "writes": [], "transaction": txn_id}, + metadata=transaction._client._rpc_metadata, + ) + assert firestore_api.commit.mock_calls == [commit_call, commit_call] + + +def test__transactional___call__failure(): + from google.api_core import exceptions + from google.cloud.firestore_v1.base_transaction import _EXCEED_ATTEMPTS_TEMPLATE + + to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[]) + wrapped = _make__transactional(to_wrap) + + txn_id = b"only-one-shot" + transaction = _make_transaction_pb(txn_id, max_attempts=1) + + # Actually force the ``commit`` to fail. + exc = exceptions.Aborted("Contention just once.") + firestore_api = transaction._client._firestore_api + firestore_api.commit.side_effect = exc + + # Call the __call__-able ``wrapped``. + with pytest.raises(ValueError) as exc_info: + wrapped(transaction, "here", there=1.5) + + err_msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) + assert exc_info.value.args == (err_msg,) + + assert transaction._id is None + assert wrapped.current_id == txn_id + assert wrapped.retry_id == txn_id + + # Verify mocks. 
+ to_wrap.assert_called_once_with(transaction, "here", there=1.5) + firestore_api.begin_transaction.assert_called_once_with( + request={"database": transaction._client._database_string, "options": None}, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.rollback.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + firestore_api.commit.assert_called_once_with( + request={ + "database": transaction._client._database_string, + "writes": [], + "transaction": txn_id, + }, + metadata=transaction._client._rpc_metadata, + ) + + +def test_transactional_factory(): + from google.cloud.firestore_v1.transaction import _Transactional + from google.cloud.firestore_v1.transaction import transactional + + wrapped = transactional(mock.sentinel.callable_) + assert isinstance(wrapped, _Transactional) + assert wrapped.to_wrap is mock.sentinel.callable_ + + +@mock.patch("google.cloud.firestore_v1.transaction._sleep") +def test__commit_with_retry_success_first_attempt(_sleep): + from google.cloud.firestore_v1.services.firestore import client as firestore_client + from google.cloud.firestore_v1.transaction import _commit_with_retry + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + + # Attach the fake GAPIC to a real client. + client = _make_client("summer") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"cheeeeeez" + commit_response = _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) + assert commit_response is firestore_api.commit.return_value + + # Verify mocks used. 
+ _sleep.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + + +@mock.patch("google.cloud.firestore_v1.transaction._sleep", side_effect=[2.0, 4.0]) +def test__commit_with_retry_success_third_attempt(_sleep): + from google.api_core import exceptions + from google.cloud.firestore_v1.services.firestore import client as firestore_client + from google.cloud.firestore_v1.transaction import _commit_with_retry + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + # Make sure the first two requests fail and the third succeeds. + firestore_api.commit.side_effect = [ + exceptions.ServiceUnavailable("Server sleepy."), + exceptions.ServiceUnavailable("Server groggy."), + mock.sentinel.commit_response, + ] + + # Attach the fake GAPIC to a real client. + client = _make_client("outside") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"the-world\x00" + commit_response = _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) + assert commit_response is mock.sentinel.commit_response + + # Verify mocks used. + # Ensure _sleep is called after commit failures, with intervals of 1 and 2 seconds + assert _sleep.call_count == 2 + _sleep.assert_any_call(1.0) + _sleep.assert_any_call(2.0) + # commit() called same way 3 times. 
+ commit_call = mock.call( + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + assert firestore_api.commit.mock_calls == [commit_call, commit_call, commit_call] + + +@mock.patch("google.cloud.firestore_v1.transaction._sleep") +def test__commit_with_retry_failure_first_attempt(_sleep): + from google.api_core import exceptions + from google.cloud.firestore_v1.services.firestore import client as firestore_client + from google.cloud.firestore_v1.transaction import _commit_with_retry + + # Create a minimal fake GAPIC with a dummy result. + firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + # Make sure the first request fails with an un-retryable error. + exc = exceptions.ResourceExhausted("We ran out of fries.") + firestore_api.commit.side_effect = exc + + # Attach the fake GAPIC to a real client. + client = _make_client("peanut-butter") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"\x08\x06\x07\x05\x03\x00\x09-jenny" + with pytest.raises(exceptions.ResourceExhausted) as exc_info: + _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) + + assert exc_info.value is exc + + # Verify mocks used. + _sleep.assert_not_called() + firestore_api.commit.assert_called_once_with( + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + + +@mock.patch("google.cloud.firestore_v1.transaction._sleep", return_value=2.0) +def test__commit_with_retry_failure_second_attempt(_sleep): + from google.api_core import exceptions + from google.cloud.firestore_v1.services.firestore import client as firestore_client + from google.cloud.firestore_v1.transaction import _commit_with_retry + + # Create a minimal fake GAPIC with a dummy result. 
+ firestore_api = mock.create_autospec( + firestore_client.FirestoreClient, instance=True + ) + # Make sure the first request fails retry-able and second + # fails non-retryable. + exc1 = exceptions.ServiceUnavailable("Come back next time.") + exc2 = exceptions.InternalServerError("Server on fritz.") + firestore_api.commit.side_effect = [exc1, exc2] + + # Attach the fake GAPIC to a real client. + client = _make_client("peanut-butter") + client._firestore_api_internal = firestore_api + + # Call function and check result. + txn_id = b"the-journey-when-and-where-well-go" + with pytest.raises(exceptions.InternalServerError) as exc_info: + _commit_with_retry(client, mock.sentinel.write_pbs, txn_id) + + assert exc_info.value is exc2 + + # Verify mocks used. + _sleep.assert_called_once_with(1.0) + # commit() called same way 2 times. + commit_call = mock.call( + request={ + "database": client._database_string, + "writes": mock.sentinel.write_pbs, + "transaction": txn_id, + }, + metadata=client._rpc_metadata, + ) + assert firestore_api.commit.mock_calls == [commit_call, commit_call] + + +@mock.patch("random.uniform", return_value=5.5) +@mock.patch("time.sleep", return_value=None) +def test_defaults(sleep, uniform): + from google.cloud.firestore_v1.transaction import _sleep + + curr_sleep = 10.0 + assert uniform.return_value <= curr_sleep + + new_sleep = _sleep(curr_sleep) + assert new_sleep == 2.0 * curr_sleep + + uniform.assert_called_once_with(0.0, curr_sleep) + sleep.assert_called_once_with(uniform.return_value) + + +@mock.patch("random.uniform", return_value=10.5) +@mock.patch("time.sleep", return_value=None) +def test_explicit(sleep, uniform): + from google.cloud.firestore_v1.transaction import _sleep + + curr_sleep = 12.25 + assert uniform.return_value <= curr_sleep + + multiplier = 1.5 + new_sleep = _sleep(curr_sleep, max_sleep=100.0, multiplier=multiplier) + assert new_sleep == multiplier * curr_sleep + + uniform.assert_called_once_with(0.0, curr_sleep) + 
sleep.assert_called_once_with(uniform.return_value) + + +@mock.patch("random.uniform", return_value=6.75) +@mock.patch("time.sleep", return_value=None) +def test_exceeds_max(sleep, uniform): + from google.cloud.firestore_v1.transaction import _sleep + + curr_sleep = 20.0 + assert uniform.return_value <= curr_sleep + + max_sleep = 38.5 + new_sleep = _sleep(curr_sleep, max_sleep=max_sleep, multiplier=2.0) + assert new_sleep == max_sleep + + uniform.assert_called_once_with(0.0, curr_sleep) + sleep.assert_called_once_with(uniform.return_value) def _make_credentials(): @@ -1031,7 +1015,7 @@ def _make_client(project="feral-tom-cat"): return Client(project=project, credentials=credentials) -def _make_transaction(txn_id, **txn_kwargs): +def _make_transaction_pb(txn_id, **txn_kwargs): from google.protobuf import empty_pb2 from google.cloud.firestore_v1.services.firestore import client as firestore_client from google.cloud.firestore_v1.types import firestore diff --git a/tests/unit/v1/test_transforms.py b/tests/unit/v1/test_transforms.py index 04a6dcdc08993..f5768bac4e0b5 100644 --- a/tests/unit/v1/test_transforms.py +++ b/tests/unit/v1/test_transforms.py @@ -12,102 +12,104 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest - - -class Test_ValueList(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.transforms import _ValueList - - return _ValueList - - def _make_one(self, values): - return self._get_target_class()(values) - - def test_ctor_w_non_list_non_tuple(self): - invalid_values = (None, u"phred", b"DEADBEEF", 123, {}, object()) - for invalid_value in invalid_values: - with self.assertRaises(ValueError): - self._make_one(invalid_value) - - def test_ctor_w_empty(self): - with self.assertRaises(ValueError): - self._make_one([]) - - def test_ctor_w_non_empty_list(self): - values = ["phred", "bharney"] - inst = self._make_one(values) - self.assertEqual(inst.values, values) - - def test_ctor_w_non_empty_tuple(self): - values = ("phred", "bharney") - inst = self._make_one(values) - self.assertEqual(inst.values, list(values)) - - def test___eq___other_type(self): - values = ("phred", "bharney") - inst = self._make_one(values) - other = object() - self.assertFalse(inst == other) - - def test___eq___different_values(self): - values = ("phred", "bharney") - other_values = ("wylma", "bhetty") - inst = self._make_one(values) - other = self._make_one(other_values) - self.assertFalse(inst == other) - - def test___eq___same_values(self): - values = ("phred", "bharney") - inst = self._make_one(values) - other = self._make_one(values) - self.assertTrue(inst == other) - - -class Test_NumericValue(unittest.TestCase): - @staticmethod - def _get_target_class(): - from google.cloud.firestore_v1.transforms import _NumericValue - - return _NumericValue - - def _make_one(self, values): - return self._get_target_class()(values) - - def test_ctor_w_invalid_types(self): - invalid_values = (None, u"phred", b"DEADBEEF", [], {}, object()) - for invalid_value in invalid_values: - with self.assertRaises(ValueError): - self._make_one(invalid_value) - - def test_ctor_w_int(self): - values = (-10, -1, 0, 1, 10) - for value in values: - inst = 
self._make_one(value) - self.assertEqual(inst.value, value) - - def test_ctor_w_float(self): - values = (-10.0, -1.0, 0.0, 1.0, 10.0) - for value in values: - inst = self._make_one(value) - self.assertEqual(inst.value, value) - - def test___eq___other_type(self): - value = 3.1415926 - inst = self._make_one(value) - other = object() - self.assertFalse(inst == other) - - def test___eq___different_value(self): - value = 3.1415926 - other_value = 2.71828 - inst = self._make_one(value) - other = self._make_one(other_value) - self.assertFalse(inst == other) - - def test___eq___same_value(self): - value = 3.1415926 - inst = self._make_one(value) - other = self._make_one(value) - self.assertTrue(inst == other) +import pytest + + +def _make_value_list(*args, **kwargs): + from google.cloud.firestore_v1.transforms import _ValueList + + return _ValueList(*args, **kwargs) + + +def test__valuelist_ctor_w_non_list_non_tuple(): + invalid_values = (None, u"phred", b"DEADBEEF", 123, {}, object()) + for invalid_value in invalid_values: + with pytest.raises(ValueError): + _make_value_list(invalid_value) + + +def test__valuelist_ctor_w_empty(): + with pytest.raises(ValueError): + _make_value_list([]) + + +def test__valuelist_ctor_w_non_empty_list(): + values = ["phred", "bharney"] + inst = _make_value_list(values) + assert inst.values == values + + +def test__valuelist_ctor_w_non_empty_tuple(): + values = ("phred", "bharney") + inst = _make_value_list(values) + assert inst.values == list(values) + + +def test__valuelist___eq___other_type(): + values = ("phred", "bharney") + inst = _make_value_list(values) + other = object() + assert not (inst == other) + + +def test__valuelist___eq___different_values(): + values = ("phred", "bharney") + other_values = ("wylma", "bhetty") + inst = _make_value_list(values) + other = _make_value_list(other_values) + assert not (inst == other) + + +def test__valuelist___eq___same_values(): + values = ("phred", "bharney") + inst = _make_value_list(values) + 
other = _make_value_list(values)
+    assert inst == other
+
+
+def _make_numeric_value(*args, **kwargs):
+    from google.cloud.firestore_v1.transforms import _NumericValue
+
+    return _NumericValue(*args, **kwargs)
+
+
+@pytest.mark.parametrize(
+    "invalid_value", [None, u"phred", b"DEADBEEF", [], {}, object()],
+)
+def test__numericvalue_ctor_w_invalid_types(invalid_value):
+    with pytest.raises(ValueError):
+        _make_numeric_value(invalid_value)
+
+
+@pytest.mark.parametrize("value", [-10, -1, 0, 1, 10])
+def test__numericvalue_ctor_w_int(value):
+    inst = _make_numeric_value(value)
+    assert inst.value == value
+
+
+@pytest.mark.parametrize("value", [-10.0, -1.0, 0.0, 1.0, 10.0])
+def test__numericvalue_ctor_w_float(value):
+    inst = _make_numeric_value(value)
+    assert inst.value == value
+
+
+def test__numericvalue___eq___other_type():
+    value = 3.1415926
+    inst = _make_numeric_value(value)
+    other = object()
+    assert not (inst == other)
+
+
+def test__numericvalue___eq___different_value():
+    value = 3.1415926
+    other_value = 2.71828
+    inst = _make_numeric_value(value)
+    other = _make_numeric_value(other_value)
+    assert not (inst == other)
+
+
+def test__numericvalue___eq___same_value():
+    value = 3.1415926
+    inst = _make_numeric_value(value)
+    other = _make_numeric_value(value)
+    assert inst == other
diff --git a/tests/unit/v1/test_watch.py b/tests/unit/v1/test_watch.py
index c5b758459fcc1..2a49b5b08da1d 100644
--- a/tests/unit/v1/test_watch.py
+++ b/tests/unit/v1/test_watch.py
@@ -13,794 +13,823 @@
 # limitations under the License.
 
import datetime -import unittest + import mock -from google.cloud.firestore_v1.types import firestore +import pytest -class TestWatchDocTree(unittest.TestCase): - def _makeOne(self): - from google.cloud.firestore_v1.watch import WatchDocTree +def _make_watch_doc_tree(*args, **kwargs): + from google.cloud.firestore_v1.watch import WatchDocTree - return WatchDocTree() + return WatchDocTree(*args, **kwargs) - def test_insert_and_keys(self): - inst = self._makeOne() - inst = inst.insert("b", 1) - inst = inst.insert("a", 2) - self.assertEqual(sorted(inst.keys()), ["a", "b"]) - def test_remove_and_keys(self): - inst = self._makeOne() - inst = inst.insert("b", 1) - inst = inst.insert("a", 2) - inst = inst.remove("a") - self.assertEqual(sorted(inst.keys()), ["b"]) +def test_watchdoctree_insert_and_keys(): + inst = _make_watch_doc_tree() + inst = inst.insert("b", 1) + inst = inst.insert("a", 2) + assert sorted(inst.keys()) == ["a", "b"] - def test_insert_and_find(self): - inst = self._makeOne() - inst = inst.insert("b", 1) - inst = inst.insert("a", 2) - val = inst.find("a") - self.assertEqual(val.value, 2) - def test___len__(self): - inst = self._makeOne() - inst = inst.insert("b", 1) - inst = inst.insert("a", 2) - self.assertEqual(len(inst), 2) +def test_watchdoctree_remove_and_keys(): + inst = _make_watch_doc_tree() + inst = inst.insert("b", 1) + inst = inst.insert("a", 2) + inst = inst.remove("a") + assert sorted(inst.keys()) == ["b"] - def test___iter__(self): - inst = self._makeOne() - inst = inst.insert("b", 1) - inst = inst.insert("a", 2) - self.assertEqual(sorted(list(inst)), ["a", "b"]) - def test___contains__(self): - inst = self._makeOne() - inst = inst.insert("b", 1) - self.assertTrue("b" in inst) - self.assertFalse("a" in inst) +def test_watchdoctree_insert_and_find(): + inst = _make_watch_doc_tree() + inst = inst.insert("b", 1) + inst = inst.insert("a", 2) + val = inst.find("a") + assert val.value == 2 -class TestDocumentChange(unittest.TestCase): - def 
_makeOne(self, type, document, old_index, new_index): - from google.cloud.firestore_v1.watch import DocumentChange +def test_watchdoctree___len__(): + inst = _make_watch_doc_tree() + inst = inst.insert("b", 1) + inst = inst.insert("a", 2) + assert len(inst) == 2 - return DocumentChange(type, document, old_index, new_index) - def test_ctor(self): - inst = self._makeOne("type", "document", "old_index", "new_index") - self.assertEqual(inst.type, "type") - self.assertEqual(inst.document, "document") - self.assertEqual(inst.old_index, "old_index") - self.assertEqual(inst.new_index, "new_index") +def test_watchdoctree___iter__(): + inst = _make_watch_doc_tree() + inst = inst.insert("b", 1) + inst = inst.insert("a", 2) + assert sorted(list(inst)) == ["a", "b"] -class TestWatchResult(unittest.TestCase): - def _makeOne(self, snapshot, name, change_type): - from google.cloud.firestore_v1.watch import WatchResult +def test_watchdoctree___contains__(): + inst = _make_watch_doc_tree() + inst = inst.insert("b", 1) + assert "b" in inst + assert "a" not in inst - return WatchResult(snapshot, name, change_type) - def test_ctor(self): - inst = self._makeOne("snapshot", "name", "change_type") - self.assertEqual(inst.snapshot, "snapshot") - self.assertEqual(inst.name, "name") - self.assertEqual(inst.change_type, "change_type") +def test_documentchange_ctor(): + from google.cloud.firestore_v1.watch import DocumentChange + inst = DocumentChange("type", "document", "old_index", "new_index") + assert inst.type == "type" + assert inst.document == "document" + assert inst.old_index == "old_index" + assert inst.new_index == "new_index" -class Test_maybe_wrap_exception(unittest.TestCase): - def _callFUT(self, exc): - from google.cloud.firestore_v1.watch import _maybe_wrap_exception - return _maybe_wrap_exception(exc) +def test_watchresult_ctor(): + from google.cloud.firestore_v1.watch import WatchResult - def test_is_grpc_error(self): - import grpc - from google.api_core.exceptions import 
GoogleAPICallError + inst = WatchResult("snapshot", "name", "change_type") + assert inst.snapshot == "snapshot" + assert inst.name == "name" + assert inst.change_type == "change_type" - exc = grpc.RpcError() - result = self._callFUT(exc) - self.assertEqual(result.__class__, GoogleAPICallError) - def test_is_not_grpc_error(self): - exc = ValueError() - result = self._callFUT(exc) - self.assertEqual(result.__class__, ValueError) +def test__maybe_wrap_exception_w_grpc_error(): + import grpc + from google.api_core.exceptions import GoogleAPICallError + from google.cloud.firestore_v1.watch import _maybe_wrap_exception + exc = grpc.RpcError() + result = _maybe_wrap_exception(exc) + assert result.__class__ == GoogleAPICallError -class Test_document_watch_comparator(unittest.TestCase): - def _callFUT(self, doc1, doc2): - from google.cloud.firestore_v1.watch import document_watch_comparator - return document_watch_comparator(doc1, doc2) +def test__maybe_wrap_exception_w_non_grpc_error(): + from google.cloud.firestore_v1.watch import _maybe_wrap_exception - def test_same_doc(self): - result = self._callFUT(1, 1) - self.assertEqual(result, 0) + exc = ValueError() + result = _maybe_wrap_exception(exc) + assert result.__class__ == ValueError - def test_diff_doc(self): - self.assertRaises(AssertionError, self._callFUT, 1, 2) +def test_document_watch_comparator_wsame_doc(): + from google.cloud.firestore_v1.watch import document_watch_comparator -class Test_should_recover(unittest.TestCase): - def _callFUT(self, exception): - from google.cloud.firestore_v1.watch import _should_recover + result = document_watch_comparator(1, 1) + assert result == 0 - return _should_recover(exception) - def test_w_unavailable(self): - from google.api_core.exceptions import ServiceUnavailable +def test_document_watch_comparator_wdiff_doc(): + from google.cloud.firestore_v1.watch import document_watch_comparator - exception = ServiceUnavailable("testing") + with pytest.raises(AssertionError): + 
document_watch_comparator(1, 2) - self.assertTrue(self._callFUT(exception)) - def test_w_non_recoverable(self): - exception = ValueError("testing") +def test__should_recover_w_unavailable(): + from google.api_core.exceptions import ServiceUnavailable + from google.cloud.firestore_v1.watch import _should_recover - self.assertFalse(self._callFUT(exception)) + exception = ServiceUnavailable("testing") + assert _should_recover(exception) -class Test_should_terminate(unittest.TestCase): - def _callFUT(self, exception): - from google.cloud.firestore_v1.watch import _should_terminate - return _should_terminate(exception) +def test__should_recover_w_non_recoverable(): + from google.cloud.firestore_v1.watch import _should_recover - def test_w_unavailable(self): - from google.api_core.exceptions import Cancelled + exception = ValueError("testing") - exception = Cancelled("testing") + assert not _should_recover(exception) - self.assertTrue(self._callFUT(exception)) - def test_w_non_recoverable(self): - exception = ValueError("testing") +def test__should_terminate_w_unavailable(): + from google.api_core.exceptions import Cancelled + from google.cloud.firestore_v1.watch import _should_terminate - self.assertFalse(self._callFUT(exception)) + exception = Cancelled("testing") + assert _should_terminate(exception) -class TestWatch(unittest.TestCase): - def _makeOne( - self, - document_reference=None, - firestore=None, - target=None, - comparator=None, - snapshot_callback=None, - snapshot_class=None, - reference_class=None, - ): # pragma: NO COVER - from google.cloud.firestore_v1.watch import Watch - - if document_reference is None: - document_reference = DummyDocumentReference() - if firestore is None: - firestore = DummyFirestore() - if target is None: - WATCH_TARGET_ID = 0x5079 # "Py" - target = {"documents": {"documents": ["/"]}, "target_id": WATCH_TARGET_ID} - if comparator is None: - comparator = self._document_watch_comparator - if snapshot_callback is None: - 
snapshot_callback = self._snapshot_callback - if snapshot_class is None: - snapshot_class = DummyDocumentSnapshot - if reference_class is None: - reference_class = DummyDocumentReference - inst = Watch( - document_reference, - firestore, - target, - comparator, - snapshot_callback, - snapshot_class, - reference_class, - BackgroundConsumer=DummyBackgroundConsumer, - ResumableBidiRpc=DummyRpc, - ) - return inst - def setUp(self): - self.snapshotted = None +def test__should_terminate_w_non_recoverable(): + from google.cloud.firestore_v1.watch import _should_terminate - def _document_watch_comparator(self, doc1, doc2): # pragma: NO COVER - return 0 + exception = ValueError("testing") - def _snapshot_callback(self, docs, changes, read_time): - self.snapshotted = (docs, changes, read_time) + assert not _should_terminate(exception) - def test_ctor(self): - from google.cloud.firestore_v1.types import firestore - from google.cloud.firestore_v1.watch import _should_recover - from google.cloud.firestore_v1.watch import _should_terminate - - inst = self._makeOne() - self.assertTrue(inst._consumer.started) - self.assertTrue(inst._rpc.callbacks, [inst._on_rpc_done]) - self.assertIs(inst._rpc.start_rpc, inst._api._transport.listen) - self.assertIs(inst._rpc.should_recover, _should_recover) - self.assertIs(inst._rpc.should_terminate, _should_terminate) - self.assertIsInstance(inst._rpc.initial_request, firestore.ListenRequest) - self.assertEqual(inst._rpc.metadata, DummyFirestore._rpc_metadata) - - def test__on_rpc_done(self): - from google.cloud.firestore_v1.watch import _RPC_ERROR_THREAD_NAME - - inst = self._makeOne() - threading = DummyThreading() - with mock.patch("google.cloud.firestore_v1.watch.threading", threading): - inst._on_rpc_done(True) - self.assertTrue(threading.threads[_RPC_ERROR_THREAD_NAME].started) - - def test_close(self): - inst = self._makeOne() - inst.close() - self.assertEqual(inst._consumer, None) - self.assertEqual(inst._rpc, None) - 
self.assertTrue(inst._closed)
-
-    def test_close_already_closed(self):
-        inst = self._makeOne()
-        inst._closed = True
-        old_consumer = inst._consumer
-        inst.close()
-        self.assertEqual(inst._consumer, old_consumer)
-
-    def test_close_inactive(self):
-        inst = self._makeOne()
-        old_consumer = inst._consumer
-        old_consumer.is_active = False
-        inst.close()
-        self.assertEqual(old_consumer.stopped, False)
-
-    def test_unsubscribe(self):
-        inst = self._makeOne()
-        inst.unsubscribe()
-        self.assertTrue(inst._rpc is None)
-
-    def test_for_document(self):
-        from google.cloud.firestore_v1.watch import Watch
-
-        docref = DummyDocumentReference()
-        snapshot_callback = self._snapshot_callback
-        snapshot_class_instance = DummyDocumentSnapshot
-        document_reference_class_instance = DummyDocumentReference
-        modulename = "google.cloud.firestore_v1.watch"
+
+@pytest.fixture(scope="function")
+def snapshots():
+    yield []
+
+
+def _document_watch_comparator(doc1, doc2):  # pragma: NO COVER
+    return 0
+
+
+def _make_watch(
+    snapshots=None, comparator=_document_watch_comparator,
+):
+    from google.cloud.firestore_v1.watch import Watch
+
+    WATCH_TARGET_ID = 0x5079  # "Py"
+    target = {"documents": {"documents": ["/"]}, "target_id": WATCH_TARGET_ID}
+
+    if snapshots is None:
+        snapshots = []
+
+    def snapshot_callback(*args):
+        snapshots.append(args)
+
+    return Watch(
+        document_reference=DummyDocumentReference(),
+        firestore=DummyFirestore(),
+        target=target,
+        comparator=comparator,
+        snapshot_callback=snapshot_callback,
+        document_snapshot_cls=DummyDocumentSnapshot,
+        document_reference_cls=DummyDocumentReference,
+        BackgroundConsumer=DummyBackgroundConsumer,
+        ResumableBidiRpc=DummyRpc,
+    )
+
+
+def test_watch_ctor():
+    from google.cloud.firestore_v1.types import firestore
+    from google.cloud.firestore_v1.watch import _should_recover
+    from google.cloud.firestore_v1.watch import _should_terminate
+
+    inst = _make_watch()
+    assert inst._consumer.started
+    assert inst._rpc.callbacks == 
[inst._on_rpc_done] + assert inst._rpc.start_rpc is inst._api._transport.listen + assert inst._rpc.should_recover is _should_recover + assert inst._rpc.should_terminate is _should_terminate + assert isinstance(inst._rpc.initial_request, firestore.ListenRequest) + assert inst._rpc.metadata == DummyFirestore._rpc_metadata + + +def test_watch__on_rpc_done(): + from google.cloud.firestore_v1.watch import _RPC_ERROR_THREAD_NAME + + inst = _make_watch() + threading = DummyThreading() + + with mock.patch("google.cloud.firestore_v1.watch.threading", threading): + inst._on_rpc_done(True) + + assert threading.threads[_RPC_ERROR_THREAD_NAME].started + + +def test_watch_close(): + inst = _make_watch() + inst.close() + assert inst._consumer is None + assert inst._rpc is None + assert inst._closed + + +def test_watch_close_already_closed(): + inst = _make_watch() + inst._closed = True + old_consumer = inst._consumer + inst.close() + assert inst._consumer == old_consumer + + +def test_watch_close_inactive(): + inst = _make_watch() + old_consumer = inst._consumer + old_consumer.is_active = False + inst.close() + assert not old_consumer.stopped + + +def test_watch_unsubscribe(): + inst = _make_watch() + inst.unsubscribe() + assert inst._rpc is None + + +def test_watch_for_document(snapshots): + from google.cloud.firestore_v1.watch import Watch + + def snapshot_callback(*args): # pragma: NO COVER + snapshots.append(args) + + docref = DummyDocumentReference() + snapshot_class_instance = DummyDocumentSnapshot + document_reference_class_instance = DummyDocumentReference + modulename = "google.cloud.firestore_v1.watch" + + with mock.patch("%s.Watch.ResumableBidiRpc" % modulename, DummyRpc): + with mock.patch( + "%s.Watch.BackgroundConsumer" % modulename, DummyBackgroundConsumer + ): + inst = Watch.for_document( + docref, + snapshot_callback, + snapshot_class_instance, + document_reference_class_instance, + ) + assert inst._consumer.started + assert inst._rpc.callbacks == 
[inst._on_rpc_done] + + +def test_watch_for_query(snapshots): + from google.cloud.firestore_v1.watch import Watch + + def snapshot_callback(*args): # pragma: NO COVER + snapshots.append(args) + + snapshot_class_instance = DummyDocumentSnapshot + document_reference_class_instance = DummyDocumentReference + client = DummyFirestore() + parent = DummyCollection(client) + modulename = "google.cloud.firestore_v1.watch" + pb2 = DummyPb2() + with mock.patch("%s.firestore" % modulename, pb2): with mock.patch("%s.Watch.ResumableBidiRpc" % modulename, DummyRpc): with mock.patch( "%s.Watch.BackgroundConsumer" % modulename, DummyBackgroundConsumer ): - inst = Watch.for_document( - docref, + query = DummyQuery(parent=parent) + inst = Watch.for_query( + query, snapshot_callback, snapshot_class_instance, document_reference_class_instance, ) - self.assertTrue(inst._consumer.started) - self.assertTrue(inst._rpc.callbacks, [inst._on_rpc_done]) - - def test_for_query(self): - from google.cloud.firestore_v1.watch import Watch - - snapshot_callback = self._snapshot_callback - snapshot_class_instance = DummyDocumentSnapshot - document_reference_class_instance = DummyDocumentReference - client = DummyFirestore() - parent = DummyCollection(client) - modulename = "google.cloud.firestore_v1.watch" - pb2 = DummyPb2() - with mock.patch("%s.firestore" % modulename, pb2): - with mock.patch("%s.Watch.ResumableBidiRpc" % modulename, DummyRpc): - with mock.patch( - "%s.Watch.BackgroundConsumer" % modulename, DummyBackgroundConsumer - ): - query = DummyQuery(parent=parent) - inst = Watch.for_query( - query, - snapshot_callback, - snapshot_class_instance, - document_reference_class_instance, - ) - self.assertTrue(inst._consumer.started) - self.assertTrue(inst._rpc.callbacks, [inst._on_rpc_done]) - self.assertEqual(inst._targets["query"], "dummy query target") - - def test_for_query_nested(self): - from google.cloud.firestore_v1.watch import Watch - - snapshot_callback = self._snapshot_callback - 
snapshot_class_instance = DummyDocumentSnapshot - document_reference_class_instance = DummyDocumentReference - client = DummyFirestore() - root = DummyCollection(client) - grandparent = DummyDocument("document", parent=root) - parent = DummyCollection(client, parent=grandparent) - modulename = "google.cloud.firestore_v1.watch" - pb2 = DummyPb2() - with mock.patch("%s.firestore" % modulename, pb2): - with mock.patch("%s.Watch.ResumableBidiRpc" % modulename, DummyRpc): - with mock.patch( - "%s.Watch.BackgroundConsumer" % modulename, DummyBackgroundConsumer - ): - query = DummyQuery(parent=parent) - inst = Watch.for_query( - query, - snapshot_callback, - snapshot_class_instance, - document_reference_class_instance, - ) - self.assertTrue(inst._consumer.started) - self.assertTrue(inst._rpc.callbacks, [inst._on_rpc_done]) - self.assertEqual(inst._targets["query"], "dummy query target") - - def test_on_snapshot_target_w_none(self): - inst = self._makeOne() - proto = None - inst.on_snapshot(proto) # nothing to assert, no mutations, no rtnval - self.assertTrue(inst._consumer is None) - self.assertTrue(inst._rpc is None) - - def test_on_snapshot_target_no_change_no_target_ids_not_current(self): - inst = self._makeOne() - proto = DummyProto() - inst.on_snapshot(proto) # nothing to assert, no mutations, no rtnval - - def test_on_snapshot_target_no_change_no_target_ids_current(self): - inst = self._makeOne() - proto = DummyProto() - proto.target_change.read_time = 1 - inst.current = True - - def push(read_time, next_resume_token): - inst._read_time = read_time - inst._next_resume_token = next_resume_token - - inst.push = push - inst.on_snapshot(proto) - self.assertEqual(inst._read_time, 1) - self.assertEqual(inst._next_resume_token, None) - - def test_on_snapshot_target_add(self): - inst = self._makeOne() - proto = DummyProto() - proto.target_change.target_change_type = ( - firestore.TargetChange.TargetChangeType.ADD - ) - proto.target_change.target_ids = [1] # not "Py" - with 
self.assertRaises(Exception) as exc: - inst.on_snapshot(proto) - self.assertEqual(str(exc.exception), "Unexpected target ID 1 sent by server") - - def test_on_snapshot_target_remove(self): - inst = self._makeOne() - proto = DummyProto() - target_change = proto.target_change - target_change.target_change_type = ( - firestore.TargetChange.TargetChangeType.REMOVE - ) - with self.assertRaises(Exception) as exc: - inst.on_snapshot(proto) - self.assertEqual(str(exc.exception), "Error 1: hi") - - def test_on_snapshot_target_remove_nocause(self): - inst = self._makeOne() - proto = DummyProto() - target_change = proto.target_change - target_change.cause = None - target_change.target_change_type = ( - firestore.TargetChange.TargetChangeType.REMOVE - ) - with self.assertRaises(Exception) as exc: - inst.on_snapshot(proto) - self.assertEqual(str(exc.exception), "Error 13: internal error") + assert inst._consumer.started + assert inst._rpc.callbacks == [inst._on_rpc_done] + assert inst._targets["query"] == "dummy query target" + + +def test_watch_for_query_nested(snapshots): + from google.cloud.firestore_v1.watch import Watch + + def snapshot_callback(*args): # pragma: NO COVER + snapshots.append(args) + + snapshot_class_instance = DummyDocumentSnapshot + document_reference_class_instance = DummyDocumentReference + client = DummyFirestore() + root = DummyCollection(client) + grandparent = DummyDocument("document", parent=root) + parent = DummyCollection(client, parent=grandparent) + modulename = "google.cloud.firestore_v1.watch" + pb2 = DummyPb2() + with mock.patch("%s.firestore" % modulename, pb2): + with mock.patch("%s.Watch.ResumableBidiRpc" % modulename, DummyRpc): + with mock.patch( + "%s.Watch.BackgroundConsumer" % modulename, DummyBackgroundConsumer + ): + query = DummyQuery(parent=parent) + inst = Watch.for_query( + query, + snapshot_callback, + snapshot_class_instance, + document_reference_class_instance, + ) + assert inst._consumer.started + assert inst._rpc.callbacks 
== [inst._on_rpc_done] + assert inst._targets["query"] == "dummy query target" - def test_on_snapshot_target_reset(self): - inst = self._makeOne() - def reset(): - inst._docs_reset = True +def test_watch_on_snapshot_target_w_none(): + inst = _make_watch() + proto = None + inst.on_snapshot(proto) # nothing to assert, no mutations, no rtnval + assert inst._consumer is None + assert inst._rpc is None - inst._reset_docs = reset - proto = DummyProto() - target_change = proto.target_change - target_change.target_change_type = firestore.TargetChange.TargetChangeType.RESET - inst.on_snapshot(proto) - self.assertTrue(inst._docs_reset) - - def test_on_snapshot_target_current(self): - inst = self._makeOne() - inst.current = False - proto = DummyProto() - target_change = proto.target_change - target_change.target_change_type = ( - firestore.TargetChange.TargetChangeType.CURRENT - ) - inst.on_snapshot(proto) - self.assertTrue(inst.current) - - def test_on_snapshot_target_unknown(self): - inst = self._makeOne() - proto = DummyProto() - proto.target_change.target_change_type = "unknown" - with self.assertRaises(Exception) as exc: - inst.on_snapshot(proto) - self.assertTrue(inst._consumer is None) - self.assertTrue(inst._rpc is None) - self.assertEqual(str(exc.exception), "Unknown target change type: unknown ") - - def test_on_snapshot_document_change_removed(self): - from google.cloud.firestore_v1.watch import WATCH_TARGET_ID, ChangeType - - inst = self._makeOne() - proto = DummyProto() - proto.target_change = "" - proto.document_change.removed_target_ids = [WATCH_TARGET_ID] - - class DummyDocument: - name = "fred" - - proto.document_change.document = DummyDocument() - inst.on_snapshot(proto) - self.assertTrue(inst.change_map["fred"] is ChangeType.REMOVED) - def test_on_snapshot_document_change_changed(self): - from google.cloud.firestore_v1.watch import WATCH_TARGET_ID +def test_watch_on_snapshot_target_no_change_no_target_ids_not_current(): + inst = _make_watch() + proto = 
DummyProto() + inst.on_snapshot(proto) # nothing to assert, no mutations, no rtnval + + +def test_watch_on_snapshot_target_no_change_no_target_ids_current(): + inst = _make_watch() + proto = DummyProto() + proto.target_change.read_time = 1 + inst.current = True + + def push(read_time, next_resume_token): + inst._read_time = read_time + inst._next_resume_token = next_resume_token - inst = self._makeOne() + inst.push = push + inst.on_snapshot(proto) + assert inst._read_time == 1 + assert inst._next_resume_token is None - proto = DummyProto() - proto.target_change = "" - proto.document_change.target_ids = [WATCH_TARGET_ID] - class DummyDocument: - name = "fred" - fields = {} - create_time = None - update_time = None +def test_watch_on_snapshot_target_add(): + from google.cloud.firestore_v1.types import firestore - proto.document_change.document = DummyDocument() + inst = _make_watch() + proto = DummyProto() + proto.target_change.target_change_type = firestore.TargetChange.TargetChangeType.ADD + proto.target_change.target_ids = [1] # not "Py" + + with pytest.raises(Exception) as exc: inst.on_snapshot(proto) - self.assertEqual(inst.change_map["fred"].data, {}) - def test_on_snapshot_document_change_changed_docname_db_prefix(self): - # TODO: Verify the current behavior. The change map currently contains - # the db-prefixed document name and not the bare document name. 
- from google.cloud.firestore_v1.watch import WATCH_TARGET_ID + assert str(exc.value) == "Unexpected target ID 1 sent by server" - inst = self._makeOne() - proto = DummyProto() - proto.target_change = "" - proto.document_change.target_ids = [WATCH_TARGET_ID] +def test_watch_on_snapshot_target_remove(): + from google.cloud.firestore_v1.types import firestore - class DummyDocument: - name = "abc://foo/documents/fred" - fields = {} - create_time = None - update_time = None + inst = _make_watch() + proto = DummyProto() + target_change = proto.target_change + target_change.target_change_type = firestore.TargetChange.TargetChangeType.REMOVE - proto.document_change.document = DummyDocument() - inst._firestore._database_string = "abc://foo" + with pytest.raises(Exception) as exc: inst.on_snapshot(proto) - self.assertEqual(inst.change_map["abc://foo/documents/fred"].data, {}) - def test_on_snapshot_document_change_neither_changed_nor_removed(self): - inst = self._makeOne() - proto = DummyProto() - proto.target_change = "" - proto.document_change.target_ids = [] + assert str(exc.value) == "Error 1: hi" + + +def test_watch_on_snapshot_target_remove_nocause(): + from google.cloud.firestore_v1.types import firestore + inst = _make_watch() + proto = DummyProto() + target_change = proto.target_change + target_change.cause = None + target_change.target_change_type = firestore.TargetChange.TargetChangeType.REMOVE + + with pytest.raises(Exception) as exc: inst.on_snapshot(proto) - self.assertTrue(not inst.change_map) - def test_on_snapshot_document_removed(self): - from google.cloud.firestore_v1.watch import ChangeType + assert str(exc.value) == "Error 13: internal error" - inst = self._makeOne() - proto = DummyProto() - proto.target_change = "" - proto.document_change = "" - class DummyRemove(object): - document = "fred" +def test_watch_on_snapshot_target_reset(): + from google.cloud.firestore_v1.types import firestore - remove = DummyRemove() - proto.document_remove = remove - 
proto.document_delete = "" - inst.on_snapshot(proto) - self.assertTrue(inst.change_map["fred"] is ChangeType.REMOVED) + inst = _make_watch() - def test_on_snapshot_filter_update(self): - inst = self._makeOne() - proto = DummyProto() - proto.target_change = "" - proto.document_change = "" - proto.document_remove = "" - proto.document_delete = "" + def reset(): + inst._docs_reset = True - class DummyFilter(object): - count = 999 + inst._reset_docs = reset + proto = DummyProto() + target_change = proto.target_change + target_change.target_change_type = firestore.TargetChange.TargetChangeType.RESET + inst.on_snapshot(proto) + assert inst._docs_reset - proto.filter = DummyFilter() - def reset(): - inst._docs_reset = True +def test_watch_on_snapshot_target_current(): + from google.cloud.firestore_v1.types import firestore - inst._reset_docs = reset + inst = _make_watch() + inst.current = False + proto = DummyProto() + target_change = proto.target_change + target_change.target_change_type = firestore.TargetChange.TargetChangeType.CURRENT + inst.on_snapshot(proto) + assert inst.current + + +def test_watch_on_snapshot_target_unknown(): + inst = _make_watch() + proto = DummyProto() + proto.target_change.target_change_type = "unknown" + + with pytest.raises(Exception) as exc: inst.on_snapshot(proto) - self.assertTrue(inst._docs_reset) - def test_on_snapshot_filter_update_no_size_change(self): - inst = self._makeOne() - proto = DummyProto() - proto.target_change = "" - proto.document_change = "" - proto.document_remove = "" - proto.document_delete = "" + assert inst._consumer is None + assert inst._rpc is None + assert str(exc.value) == "Unknown target change type: unknown " + + +def test_watch_on_snapshot_document_change_removed(): + from google.cloud.firestore_v1.watch import WATCH_TARGET_ID, ChangeType + + inst = _make_watch() + proto = DummyProto() + proto.target_change = "" + proto.document_change.removed_target_ids = [WATCH_TARGET_ID] + + class DummyDocument: + name = 
"fred" + + proto.document_change.document = DummyDocument() + inst.on_snapshot(proto) + assert inst.change_map["fred"] is ChangeType.REMOVED + + +def test_watch_on_snapshot_document_change_changed(): + from google.cloud.firestore_v1.watch import WATCH_TARGET_ID + + inst = _make_watch() + + proto = DummyProto() + proto.target_change = "" + proto.document_change.target_ids = [WATCH_TARGET_ID] + + class DummyDocument: + name = "fred" + fields = {} + create_time = None + update_time = None - class DummyFilter(object): - count = 0 + proto.document_change.document = DummyDocument() + inst.on_snapshot(proto) + assert inst.change_map["fred"].data == {} - proto.filter = DummyFilter() - inst._docs_reset = False +def test_watch_on_snapshot_document_change_changed_docname_db_prefix(): + # TODO: Verify the current behavior. The change map currently contains + # the db-prefixed document name and not the bare document name. + from google.cloud.firestore_v1.watch import WATCH_TARGET_ID + + inst = _make_watch() + + proto = DummyProto() + proto.target_change = "" + proto.document_change.target_ids = [WATCH_TARGET_ID] + + class DummyDocument: + name = "abc://foo/documents/fred" + fields = {} + create_time = None + update_time = None + + proto.document_change.document = DummyDocument() + inst._firestore._database_string = "abc://foo" + inst.on_snapshot(proto) + assert inst.change_map["abc://foo/documents/fred"].data == {} + + +def test_watch_on_snapshot_document_change_neither_changed_nor_removed(): + inst = _make_watch() + proto = DummyProto() + proto.target_change = "" + proto.document_change.target_ids = [] + + inst.on_snapshot(proto) + assert not inst.change_map + + +def test_watch_on_snapshot_document_removed(): + from google.cloud.firestore_v1.watch import ChangeType + + inst = _make_watch() + proto = DummyProto() + proto.target_change = "" + proto.document_change = "" + + class DummyRemove(object): + document = "fred" + + remove = DummyRemove() + proto.document_remove = remove 
+ proto.document_delete = "" + inst.on_snapshot(proto) + assert inst.change_map["fred"] is ChangeType.REMOVED + + +def test_watch_on_snapshot_filter_update(): + inst = _make_watch() + proto = DummyProto() + proto.target_change = "" + proto.document_change = "" + proto.document_remove = "" + proto.document_delete = "" + + class DummyFilter(object): + count = 999 + + proto.filter = DummyFilter() + + def reset(): + inst._docs_reset = True + + inst._reset_docs = reset + inst.on_snapshot(proto) + assert inst._docs_reset + + +def test_watch_on_snapshot_filter_update_no_size_change(): + inst = _make_watch() + proto = DummyProto() + proto.target_change = "" + proto.document_change = "" + proto.document_remove = "" + proto.document_delete = "" + + class DummyFilter(object): + count = 0 + + proto.filter = DummyFilter() + inst._docs_reset = False + + inst.on_snapshot(proto) + assert not inst._docs_reset + + +def test_watch_on_snapshot_unknown_listen_type(): + inst = _make_watch() + proto = DummyProto() + proto.target_change = "" + proto.document_change = "" + proto.document_remove = "" + proto.document_delete = "" + proto.filter = "" + + with pytest.raises(Exception) as exc: inst.on_snapshot(proto) - self.assertFalse(inst._docs_reset) - - def test_on_snapshot_unknown_listen_type(self): - inst = self._makeOne() - proto = DummyProto() - proto.target_change = "" - proto.document_change = "" - proto.document_remove = "" - proto.document_delete = "" - proto.filter = "" - with self.assertRaises(Exception) as exc: - inst.on_snapshot(proto) - self.assertTrue( - str(exc.exception).startswith("Unknown listen response type"), - str(exc.exception), - ) - def test_push_callback_called_no_changes(self): - dummy_time = ( - datetime.datetime.fromtimestamp(1534858278, datetime.timezone.utc), - ) + assert str(exc.value).startswith("Unknown listen response type") - inst = self._makeOne() - inst.push(dummy_time, "token") - self.assertEqual( - self.snapshotted, ([], [], dummy_time), - ) - 
self.assertTrue(inst.has_pushed) - self.assertEqual(inst.resume_token, "token") - - def test_push_already_pushed(self): - class DummyReadTime(object): - seconds = 1534858278 - - inst = self._makeOne() - inst.has_pushed = True - inst.push(DummyReadTime, "token") - self.assertEqual(self.snapshotted, None) - self.assertTrue(inst.has_pushed) - self.assertEqual(inst.resume_token, "token") - - def test__current_size_empty(self): - inst = self._makeOne() - result = inst._current_size() - self.assertEqual(result, 0) - - def test__current_size_docmap_has_one(self): - inst = self._makeOne() - inst.doc_map["a"] = 1 - result = inst._current_size() - self.assertEqual(result, 1) - - def test__affects_target_target_id_None(self): - inst = self._makeOne() - self.assertTrue(inst._affects_target(None, [])) - - def test__affects_target_current_id_in_target_ids(self): - inst = self._makeOne() - self.assertTrue(inst._affects_target([1], 1)) - - def test__affects_target_current_id_not_in_target_ids(self): - inst = self._makeOne() - self.assertFalse(inst._affects_target([1], 2)) - - def test__extract_changes_doc_removed(self): - from google.cloud.firestore_v1.watch import ChangeType - - inst = self._makeOne() - changes = {"name": ChangeType.REMOVED} - doc_map = {"name": True} - results = inst._extract_changes(doc_map, changes, None) - self.assertEqual(results, (["name"], [], [])) - - def test__extract_changes_doc_removed_docname_not_in_docmap(self): - from google.cloud.firestore_v1.watch import ChangeType - - inst = self._makeOne() - changes = {"name": ChangeType.REMOVED} - doc_map = {} - results = inst._extract_changes(doc_map, changes, None) - self.assertEqual(results, ([], [], [])) - - def test__extract_changes_doc_updated(self): - inst = self._makeOne() - - class Dummy(object): - pass - - doc = Dummy() - snapshot = Dummy() - changes = {"name": snapshot} - doc_map = {"name": doc} - results = inst._extract_changes(doc_map, changes, 1) - self.assertEqual(results, ([], [], [snapshot])) - 
self.assertEqual(snapshot.read_time, 1) - - def test__extract_changes_doc_updated_read_time_is_None(self): - inst = self._makeOne() - - class Dummy(object): - pass - - doc = Dummy() - snapshot = Dummy() - snapshot.read_time = None - changes = {"name": snapshot} - doc_map = {"name": doc} - results = inst._extract_changes(doc_map, changes, None) - self.assertEqual(results, ([], [], [snapshot])) - self.assertEqual(snapshot.read_time, None) - - def test__extract_changes_doc_added(self): - inst = self._makeOne() - - class Dummy(object): - pass - - snapshot = Dummy() - changes = {"name": snapshot} - doc_map = {} - results = inst._extract_changes(doc_map, changes, 1) - self.assertEqual(results, ([], [snapshot], [])) - self.assertEqual(snapshot.read_time, 1) - - def test__extract_changes_doc_added_read_time_is_None(self): - inst = self._makeOne() - - class Dummy(object): - pass - - snapshot = Dummy() - snapshot.read_time = None - changes = {"name": snapshot} - doc_map = {} - results = inst._extract_changes(doc_map, changes, None) - self.assertEqual(results, ([], [snapshot], [])) - self.assertEqual(snapshot.read_time, None) - - def test__compute_snapshot_doctree_and_docmap_disagree_about_length(self): - inst = self._makeOne() - doc_tree = {} - doc_map = {None: None} - self.assertRaises( - AssertionError, inst._compute_snapshot, doc_tree, doc_map, None, None, None - ) - def test__compute_snapshot_operation_relative_ordering(self): - from google.cloud.firestore_v1.watch import WatchDocTree - - doc_tree = WatchDocTree() - - class DummyDoc(object): - update_time = mock.sentinel - - deleted_doc = DummyDoc() - added_doc = DummyDoc() - added_doc._document_path = "/added" - updated_doc = DummyDoc() - updated_doc._document_path = "/updated" - doc_tree = doc_tree.insert(deleted_doc, None) - doc_tree = doc_tree.insert(updated_doc, None) - doc_map = {"/deleted": deleted_doc, "/updated": updated_doc} - added_snapshot = DummyDocumentSnapshot(added_doc, None, True, None, None, None) - 
added_snapshot.reference = added_doc - updated_snapshot = DummyDocumentSnapshot( - updated_doc, None, True, None, None, None - ) - updated_snapshot.reference = updated_doc - delete_changes = ["/deleted"] - add_changes = [added_snapshot] - update_changes = [updated_snapshot] - inst = self._makeOne() - updated_tree, updated_map, applied_changes = inst._compute_snapshot( - doc_tree, doc_map, delete_changes, add_changes, update_changes - ) - # TODO: Verify that the assertion here is correct. - self.assertEqual( - updated_map, {"/updated": updated_snapshot, "/added": added_snapshot} - ) +def test_watch_push_callback_called_no_changes(snapshots): + dummy_time = (datetime.datetime.fromtimestamp(1534858278, datetime.timezone.utc),) - def test__compute_snapshot_modify_docs_updated_doc_no_timechange(self): - from google.cloud.firestore_v1.watch import WatchDocTree + inst = _make_watch(snapshots=snapshots) + inst.push(dummy_time, "token") + assert snapshots == [([], [], dummy_time)] + assert inst.has_pushed + assert inst.resume_token == "token" - doc_tree = WatchDocTree() - class DummyDoc(object): - pass +def test_watch_push_already_pushed(snapshots): + class DummyReadTime(object): + seconds = 1534858278 - updated_doc_v1 = DummyDoc() - updated_doc_v1.update_time = 1 - updated_doc_v1._document_path = "/updated" - updated_doc_v2 = DummyDoc() - updated_doc_v2.update_time = 1 - updated_doc_v2._document_path = "/updated" - doc_tree = doc_tree.insert("/updated", updated_doc_v1) - doc_map = {"/updated": updated_doc_v1} - updated_snapshot = DummyDocumentSnapshot( - updated_doc_v2, None, True, None, None, 1 - ) - delete_changes = [] - add_changes = [] - update_changes = [updated_snapshot] - inst = self._makeOne() - updated_tree, updated_map, applied_changes = inst._compute_snapshot( - doc_tree, doc_map, delete_changes, add_changes, update_changes - ) - self.assertEqual(updated_map, doc_map) # no change - - def test__compute_snapshot_deletes_w_real_comparator(self): - from 
google.cloud.firestore_v1.watch import WatchDocTree - - doc_tree = WatchDocTree() - - class DummyDoc(object): - update_time = mock.sentinel - - deleted_doc_1 = DummyDoc() - deleted_doc_2 = DummyDoc() - doc_tree = doc_tree.insert(deleted_doc_1, None) - doc_tree = doc_tree.insert(deleted_doc_2, None) - doc_map = {"/deleted_1": deleted_doc_1, "/deleted_2": deleted_doc_2} - delete_changes = ["/deleted_1", "/deleted_2"] - add_changes = [] - update_changes = [] - inst = self._makeOne(comparator=object()) - updated_tree, updated_map, applied_changes = inst._compute_snapshot( - doc_tree, doc_map, delete_changes, add_changes, update_changes - ) - self.assertEqual(updated_map, {}) + inst = _make_watch(snapshots=snapshots) + inst.has_pushed = True + inst.push(DummyReadTime, "token") + assert snapshots == [] + assert inst.has_pushed + assert inst.resume_token == "token" - def test__reset_docs(self): - from google.cloud.firestore_v1.watch import ChangeType - inst = self._makeOne() - inst.change_map = {None: None} - from google.cloud.firestore_v1.watch import WatchDocTree +def test_watch__current_size_empty(): + inst = _make_watch() + result = inst._current_size() + assert result == 0 - doc = DummyDocumentReference("doc") - doc_tree = WatchDocTree() - snapshot = DummyDocumentSnapshot(doc, None, True, None, None, None) - snapshot.reference = doc - doc_tree = doc_tree.insert(snapshot, None) - inst.doc_tree = doc_tree - inst._reset_docs() - self.assertEqual(inst.change_map, {"/doc": ChangeType.REMOVED}) - self.assertEqual(inst.resume_token, None) - self.assertFalse(inst.current) - def test_resume_token_sent_on_recovery(self): - inst = self._makeOne() - inst.resume_token = b"ABCD0123" - request = inst._get_rpc_request() - self.assertEqual(request.add_target.resume_token, b"ABCD0123") +def test_watch__current_size_docmap_has_one(): + inst = _make_watch() + inst.doc_map["a"] = 1 + result = inst._current_size() + assert result == 1 + + +def test_watch__affects_target_target_id_None(): 
+ inst = _make_watch() + assert inst._affects_target(None, []) + + +def test_watch__affects_target_current_id_in_target_ids(): + inst = _make_watch() + assert inst._affects_target([1], 1) + + +def test_watch__affects_target_current_id_not_in_target_ids(): + inst = _make_watch() + assert not inst._affects_target([1], 2) + + +def test_watch__extract_changes_doc_removed(): + from google.cloud.firestore_v1.watch import ChangeType + + inst = _make_watch() + changes = {"name": ChangeType.REMOVED} + doc_map = {"name": True} + results = inst._extract_changes(doc_map, changes, None) + assert results == (["name"], [], []) + + +def test_watch__extract_changes_doc_removed_docname_not_in_docmap(): + from google.cloud.firestore_v1.watch import ChangeType + + inst = _make_watch() + changes = {"name": ChangeType.REMOVED} + doc_map = {} + results = inst._extract_changes(doc_map, changes, None) + assert results == ([], [], []) + + +def test_watch__extract_changes_doc_updated(): + inst = _make_watch() + + class Dummy(object): + pass + + doc = Dummy() + snapshot = Dummy() + changes = {"name": snapshot} + doc_map = {"name": doc} + results = inst._extract_changes(doc_map, changes, 1) + assert results == ([], [], [snapshot]) + assert snapshot.read_time == 1 + + +def test_watch__extract_changes_doc_updated_read_time_is_None(): + inst = _make_watch() + + class Dummy(object): + pass + + doc = Dummy() + snapshot = Dummy() + snapshot.read_time = None + changes = {"name": snapshot} + doc_map = {"name": doc} + results = inst._extract_changes(doc_map, changes, None) + assert results == ([], [], [snapshot]) + assert snapshot.read_time is None + + +def test_watch__extract_changes_doc_added(): + inst = _make_watch() + + class Dummy(object): + pass + + snapshot = Dummy() + changes = {"name": snapshot} + doc_map = {} + results = inst._extract_changes(doc_map, changes, 1) + assert results == ([], [snapshot], []) + assert snapshot.read_time == 1 + + +def 
test_watch__extract_changes_doc_added_read_time_is_None(): + inst = _make_watch() + + class Dummy(object): + pass + + snapshot = Dummy() + snapshot.read_time = None + changes = {"name": snapshot} + doc_map = {} + results = inst._extract_changes(doc_map, changes, None) + assert results == ([], [snapshot], []) + assert snapshot.read_time is None + + +def test_watch__compute_snapshot_doctree_and_docmap_disagree_about_length(): + inst = _make_watch() + doc_tree = {} + doc_map = {None: None} + + with pytest.raises(AssertionError): + inst._compute_snapshot(doc_tree, doc_map, None, None, None) + + +def test_watch__compute_snapshot_operation_relative_ordering(): + from google.cloud.firestore_v1.watch import WatchDocTree + + doc_tree = WatchDocTree() + + class DummyDoc(object): + update_time = mock.sentinel + + deleted_doc = DummyDoc() + added_doc = DummyDoc() + added_doc._document_path = "/added" + updated_doc = DummyDoc() + updated_doc._document_path = "/updated" + doc_tree = doc_tree.insert(deleted_doc, None) + doc_tree = doc_tree.insert(updated_doc, None) + doc_map = {"/deleted": deleted_doc, "/updated": updated_doc} + added_snapshot = DummyDocumentSnapshot(added_doc, None, True, None, None, None) + added_snapshot.reference = added_doc + updated_snapshot = DummyDocumentSnapshot(updated_doc, None, True, None, None, None) + updated_snapshot.reference = updated_doc + delete_changes = ["/deleted"] + add_changes = [added_snapshot] + update_changes = [updated_snapshot] + inst = _make_watch() + updated_tree, updated_map, applied_changes = inst._compute_snapshot( + doc_tree, doc_map, delete_changes, add_changes, update_changes + ) + # TODO: Verify that the assertion here is correct. 
+ assert updated_map == {"/updated": updated_snapshot, "/added": added_snapshot} + + +def test_watch__compute_snapshot_modify_docs_updated_doc_no_timechange(): + from google.cloud.firestore_v1.watch import WatchDocTree + + doc_tree = WatchDocTree() + + class DummyDoc(object): + pass + + updated_doc_v1 = DummyDoc() + updated_doc_v1.update_time = 1 + updated_doc_v1._document_path = "/updated" + updated_doc_v2 = DummyDoc() + updated_doc_v2.update_time = 1 + updated_doc_v2._document_path = "/updated" + doc_tree = doc_tree.insert("/updated", updated_doc_v1) + doc_map = {"/updated": updated_doc_v1} + updated_snapshot = DummyDocumentSnapshot(updated_doc_v2, None, True, None, None, 1) + delete_changes = [] + add_changes = [] + update_changes = [updated_snapshot] + inst = _make_watch() + updated_tree, updated_map, applied_changes = inst._compute_snapshot( + doc_tree, doc_map, delete_changes, add_changes, update_changes + ) + assert updated_map == doc_map # no change + + +def test_watch__compute_snapshot_deletes_w_real_comparator(): + from google.cloud.firestore_v1.watch import WatchDocTree + + doc_tree = WatchDocTree() + + class DummyDoc(object): + update_time = mock.sentinel + + deleted_doc_1 = DummyDoc() + deleted_doc_2 = DummyDoc() + doc_tree = doc_tree.insert(deleted_doc_1, None) + doc_tree = doc_tree.insert(deleted_doc_2, None) + doc_map = {"/deleted_1": deleted_doc_1, "/deleted_2": deleted_doc_2} + delete_changes = ["/deleted_1", "/deleted_2"] + add_changes = [] + update_changes = [] + inst = _make_watch(comparator=object()) + updated_tree, updated_map, applied_changes = inst._compute_snapshot( + doc_tree, doc_map, delete_changes, add_changes, update_changes + ) + assert updated_map == {} + + +def test_watch__reset_docs(): + from google.cloud.firestore_v1.watch import ChangeType + + inst = _make_watch() + inst.change_map = {None: None} + from google.cloud.firestore_v1.watch import WatchDocTree + + doc = DummyDocumentReference("doc") + doc_tree = WatchDocTree() + 
snapshot = DummyDocumentSnapshot(doc, None, True, None, None, None) + snapshot.reference = doc + doc_tree = doc_tree.insert(snapshot, None) + inst.doc_tree = doc_tree + inst._reset_docs() + assert inst.change_map == {"/doc": ChangeType.REMOVED} + assert inst.resume_token is None + assert not inst.current + + +def test_watch_resume_token_sent_on_recovery(): + inst = _make_watch() + inst.resume_token = b"ABCD0123" + request = inst._get_rpc_request() + assert request.add_target.resume_token == b"ABCD0123" class DummyFirestoreStub(object): @@ -970,6 +999,8 @@ class DummyCause(object): class DummyChange(object): def __init__(self): + from google.cloud.firestore_v1.types import firestore + self.target_ids = [] self.removed_target_ids = [] self.read_time = 0