Skip to content

Commit

Permalink
Property test for AggregateSmallWrites.
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonspeed committed Sep 18, 2023
1 parent ef3e07f commit 8ffe969
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/twisted/protocols/test/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@
Tests for L{twisted.protocols.tls}.
"""

from __future__ import annotations

import gc
from typing import Union

from hypothesis import given, strategies as st

from zope.interface import Interface, directlyProvides, implementer
from zope.interface.verify import verifyObject

from twisted.internet.task import Clock
from twisted.python.compat import iterbytes

try:
Expand All @@ -28,6 +33,7 @@
)

from twisted.protocols.tls import (
AggregateSmallWrites,
TLSMemoryBIOFactory,
TLSMemoryBIOProtocol,
_ProducerMembrane,
Expand Down Expand Up @@ -1819,3 +1825,50 @@ def acceptableProtocols(self):
@rtype: L{list} of L{bytes}
"""
return self._acceptableProtocols


class AggregateSmallWritesTests(SynchronousTestCase):
"""Tests for ``AggregateSmallWrites``."""

@given(
st.lists(
st.one_of(
st.none(),
st.integers(min_value=1, max_value=100_000).map(
lambda length: (b"0123456789ABCDEFGHIJ" * ((length // 20) + 1))[
:length
]
),
),
max_size=1_000,
)
)
def test_writes_get_aggregated(self, writes: list[Union[bytes, None]]):
"""
If multiple writes happen in between reactor iterations, they get
written in a batch at the start of the next reactor iteration.
"""
result = []
lengths = []
clock = Clock()
aggregate = AggregateSmallWrites(result.append, clock)
length_so_far = 0
for value in writes:
if value is None:
if length_so_far != 0:
lengths.append(length_so_far)
length_so_far = 0
clock.advance(0.0001)
else:
length_so_far += len(value)
aggregate.write(value)
aggregate.flush()
if length_so_far != 0:
lengths.append(length_so_far)

self.assertEqual(len(result), len(lengths))
self.assertEqual(
b"".join(result), b"".join(value for value in writes if value is not None)
)
for (combined, expected_length) in zip(result, lengths):
self.assertEqual(len(combined), expected_length)

0 comments on commit 8ffe969

Please sign in to comment.