Skip to content

Commit

Permalink
type annotate twisted.trial.test.test_deferred
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Sep 16, 2023
1 parent 46d67ce commit 4bd66af
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 38 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,6 @@ module = [
'twisted.trial.test.suppression',
'twisted.trial.test.test_assertions',
'twisted.trial.test.test_asyncassertions',
'twisted.trial.test.test_deferred',
'twisted.trial.test.test_log',
'twisted.trial.test.test_plugins',
'twisted.trial.test.test_reporter',
Expand Down
5 changes: 3 additions & 2 deletions src/twisted/trial/test/detests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
"""
Tests for Deferred handling by L{twisted.trial.unittest.TestCase}.
"""

from __future__ import annotations

from twisted.internet import defer, reactor, threads
from twisted.python.failure import Failure
from twisted.python.util import runWithWarningsSuppressed
from twisted.trial import unittest
from twisted.trial.util import suppress as SUPPRESS
Expand Down Expand Up @@ -160,7 +161,7 @@ def test_expectedFailure(self):


class TimeoutTests(unittest.TestCase):
timedOut = None
timedOut: Failure | None = None

def test_pass(self):
d = defer.Deferred()
Expand Down
82 changes: 47 additions & 35 deletions src/twisted/trial/test/test_deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,32 @@
"""
Tests for returning Deferreds from a TestCase.
"""

from __future__ import annotations

import unittest as pyunit

from twisted.internet import defer
from twisted.python.failure import Failure
from twisted.trial import reporter, unittest, util
from twisted.trial.test import detests


class SetUpTests(unittest.TestCase):
def _loadSuite(self, klass):
def _loadSuite(
self, klass: type[pyunit.TestCase]
) -> tuple[reporter.TestResult, pyunit.TestSuite]:
loader = pyunit.TestLoader()
r = reporter.TestResult()
s = loader.loadTestsFromTestCase(klass)
return r, s

def test_success(self):
def test_success(self) -> None:
result, suite = self._loadSuite(detests.DeferredSetUpOK)
suite(result)
self.assertTrue(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)

def test_fail(self):
def test_fail(self) -> None:
self.assertFalse(detests.DeferredSetUpFail.testCalled)
result, suite = self._loadSuite(detests.DeferredSetUpFail)
suite(result)
Expand All @@ -36,7 +39,7 @@ def test_fail(self):
self.assertEqual(len(result.errors), 1)
self.assertFalse(detests.DeferredSetUpFail.testCalled)

def test_callbackFail(self):
def test_callbackFail(self) -> None:
self.assertFalse(detests.DeferredSetUpCallbackFail.testCalled)
result, suite = self._loadSuite(detests.DeferredSetUpCallbackFail)
suite(result)
Expand All @@ -46,7 +49,7 @@ def test_callbackFail(self):
self.assertEqual(len(result.errors), 1)
self.assertFalse(detests.DeferredSetUpCallbackFail.testCalled)

def test_error(self):
def test_error(self) -> None:
self.assertFalse(detests.DeferredSetUpError.testCalled)
result, suite = self._loadSuite(detests.DeferredSetUpError)
suite(result)
Expand All @@ -56,7 +59,7 @@ def test_error(self):
self.assertEqual(len(result.errors), 1)
self.assertFalse(detests.DeferredSetUpError.testCalled)

def test_skip(self):
def test_skip(self) -> None:
self.assertFalse(detests.DeferredSetUpSkip.testCalled)
result, suite = self._loadSuite(detests.DeferredSetUpSkip)
suite(result)
Expand All @@ -69,20 +72,22 @@ def test_skip(self):


class NeverFireTests(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
self._oldTimeout = util.DEFAULT_TIMEOUT_DURATION
util.DEFAULT_TIMEOUT_DURATION = 0.1

def tearDown(self):
def tearDown(self) -> None:
util.DEFAULT_TIMEOUT_DURATION = self._oldTimeout

def _loadSuite(self, klass):
def _loadSuite(
self, klass: type[pyunit.TestCase]
) -> tuple[reporter.TestResult, pyunit.TestSuite]:
loader = pyunit.TestLoader()
r = reporter.TestResult()
s = loader.loadTestsFromTestCase(klass)
return r, s

def test_setUp(self):
def test_setUp(self) -> None:
self.assertFalse(detests.DeferredSetUpNeverFire.testCalled)
result, suite = self._loadSuite(detests.DeferredSetUpNeverFire)
suite(result)
Expand All @@ -91,29 +96,30 @@ def test_setUp(self):
self.assertEqual(len(result.failures), 0)
self.assertEqual(len(result.errors), 1)
self.assertFalse(detests.DeferredSetUpNeverFire.testCalled)
assert isinstance(result.errors[0][1], Failure)
self.assertTrue(result.errors[0][1].check(defer.TimeoutError))


class TestTester(unittest.TestCase):
def getTest(self, name):
def getTest(self, name: str) -> pyunit.TestCase:
raise NotImplementedError("must override me")

def runTest(self, name):
def runTest(self, name: str) -> reporter.TestResult: # type: ignore[override]
result = reporter.TestResult()
self.getTest(name).run(result)
return result


class DeferredTests(TestTester):
def getTest(self, name):
def getTest(self, name: str) -> detests.DeferredTests:
return detests.DeferredTests(name)

def test_pass(self):
def test_pass(self) -> None:
result = self.runTest("test_pass")
self.assertTrue(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)

def test_passGenerated(self):
def test_passGenerated(self) -> None:
result = self.runTest("test_passGenerated")
self.assertTrue(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)
Expand All @@ -123,7 +129,7 @@ def test_passGenerated(self):
util.suppress(message="twisted.internet.defer.deferredGenerator is deprecated")
]

def test_passInlineCallbacks(self):
def test_passInlineCallbacks(self) -> None:
"""
The body of a L{defer.inlineCallbacks} decorated test gets run.
"""
Expand All @@ -132,113 +138,119 @@ def test_passInlineCallbacks(self):
self.assertEqual(result.testsRun, 1)
self.assertTrue(detests.DeferredTests.touched)

def test_fail(self):
def test_fail(self) -> None:
result = self.runTest("test_fail")
self.assertFalse(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)
self.assertEqual(len(result.failures), 1)

def test_failureInCallback(self):
def test_failureInCallback(self) -> None:
result = self.runTest("test_failureInCallback")
self.assertFalse(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)
self.assertEqual(len(result.failures), 1)

def test_errorInCallback(self):
def test_errorInCallback(self) -> None:
result = self.runTest("test_errorInCallback")
self.assertFalse(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)
self.assertEqual(len(result.errors), 1)

def test_skip(self):
def test_skip(self) -> None:
result = self.runTest("test_skip")
self.assertTrue(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)
self.assertEqual(len(result.skips), 1)
self.assertFalse(detests.DeferredTests.touched)

def test_todo(self):
def test_todo(self) -> None:
result = self.runTest("test_expectedFailure")
self.assertTrue(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)
self.assertEqual(len(result.errors), 0)
self.assertEqual(len(result.failures), 0)
self.assertEqual(len(result.expectedFailures), 1)

def test_thread(self):
def test_thread(self) -> None:
result = self.runTest("test_thread")
self.assertEqual(result.testsRun, 1)
self.assertTrue(result.wasSuccessful(), result.errors)


class TimeoutTests(TestTester):
def getTest(self, name):
def getTest(self, name: str) -> detests.TimeoutTests:
return detests.TimeoutTests(name)

def _wasTimeout(self, error):
def _wasTimeout(self, error: Failure) -> None:
self.assertEqual(error.check(defer.TimeoutError), defer.TimeoutError)

def test_pass(self):
def test_pass(self) -> None:
result = self.runTest("test_pass")
self.assertTrue(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)

def test_passDefault(self):
def test_passDefault(self) -> None:
result = self.runTest("test_passDefault")
self.assertTrue(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)

def test_timeout(self):
def test_timeout(self) -> None:
result = self.runTest("test_timeout")
self.assertFalse(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)
self.assertEqual(len(result.errors), 1)
assert isinstance(result.errors[0][1], Failure)
self._wasTimeout(result.errors[0][1])

def test_timeoutZero(self):
def test_timeoutZero(self) -> None:
result = self.runTest("test_timeoutZero")
self.assertFalse(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)
self.assertEqual(len(result.errors), 1)
assert isinstance(result.errors[0][1], Failure)
self._wasTimeout(result.errors[0][1])

def test_skip(self):
def test_skip(self) -> None:
result = self.runTest("test_skip")
self.assertTrue(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)
self.assertEqual(len(result.skips), 1)

def test_todo(self):
def test_todo(self) -> None:
result = self.runTest("test_expectedFailure")
self.assertTrue(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)
self.assertEqual(len(result.expectedFailures), 1)
assert isinstance(result.expectedFailures[0][1], Failure)
self._wasTimeout(result.expectedFailures[0][1])

def test_errorPropagation(self):
def test_errorPropagation(self) -> None:
result = self.runTest("test_errorPropagation")
self.assertFalse(result.wasSuccessful())
self.assertEqual(result.testsRun, 1)
assert detests.TimeoutTests.timedOut is not None
self._wasTimeout(detests.TimeoutTests.timedOut)

def test_classTimeout(self):
def test_classTimeout(self) -> None:
loader = pyunit.TestLoader()
suite = loader.loadTestsFromTestCase(detests.TestClassTimeoutAttribute)
result = reporter.TestResult()
suite.run(result)
self.assertEqual(len(result.errors), 1)
assert isinstance(result.errors[0][1], Failure)
self._wasTimeout(result.errors[0][1])

def test_callbackReturnsNonCallingDeferred(self):
def test_callbackReturnsNonCallingDeferred(self) -> None:
# hacky timeout
# raises KeyboardInterrupt because Trial sucks
from twisted.internet import reactor

call = reactor.callLater(2, reactor.crash)
call = reactor.callLater(2, reactor.crash) # type: ignore[attr-defined]
result = self.runTest("test_calledButNeverCallback")
if call.active():
call.cancel()
self.assertFalse(result.wasSuccessful())
assert isinstance(result.errors[0][1], Failure)
self._wasTimeout(result.errors[0][1])


Expand Down

0 comments on commit 4bd66af

Please sign in to comment.