Skip to content

Commit

Permalink
use aiohttp dns resolver - fixes #126
Browse files Browse the repository at this point in the history
  • Loading branch information
tarekziade committed Nov 19, 2020
1 parent fdd09ee commit 4c9a6e7
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 10 deletions.
3 changes: 2 additions & 1 deletion molotov/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ def main(args=None):
if args.workers == 1:
args.workers = 500

return run(args)
run(args)
return 0


_SIZING = """\
Expand Down
6 changes: 2 additions & 4 deletions molotov/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,11 @@ def __init__(self, loop, console, verbose=0, statsd=None, resolve_dns=True, **kw
async def send_event(self, event, **options):
await self.eventer.send_event(event, session=self, **options)

def _dns_lookup(self, url):
return resolve(url)[0]

async def _request(self, *args, **kw):
args = list(args)
if self._resolve_dns:
args[1] = self._dns_lookup(args[1])
resolved = await resolve(args[1])
args[1] = resolved[0]
args = tuple(args)
req = super(LoggedClientSession, self)._request

Expand Down
6 changes: 4 additions & 2 deletions molotov/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import unittest
import os
from molotov.util import resolve, expand_options, OptionError, set_var, get_var, _VARS
from molotov.tests.support import async_test

_HERE = os.path.dirname(__file__)
config = os.path.join(_HERE, "..", "..", "molotov.json")
Expand All @@ -18,7 +19,8 @@ def setUp(self):
super(TestUtil, self).setUp()
_VARS.clear()

def test_resolve(self):
@async_test
async def test_resolve(self, loop, console, results):

urls = [
("http://localhost:80/blah", "http://127.0.0.1:80/blah"),
Expand All @@ -33,7 +35,7 @@ def test_resolve(self):
]

for url, wanted in urls:
changed, original, resolved = resolve(url)
changed, original, resolved = await resolve(url, loop=loop)
self.assertEqual(changed, wanted, "%s vs %s" % (original, resolved))

def test_config(self):
Expand Down
19 changes: 16 additions & 3 deletions molotov/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import platform
from urllib.parse import urlparse, urlunparse
from socket import gethostbyname

from aiohttp import ClientSession, __version__
from aiohttp.resolver import DefaultResolver

# this lib works for CPython 3.7+
if platform.python_implementation() == "PyPy" or sys.version_info.minor < 7:
Expand Down Expand Up @@ -54,7 +56,15 @@ def is_stopped():
return _STOP


def resolve(url):
_RESOLVERS = {}


async def resolve(url, loop=None):
if loop in _RESOLVERS:
resolver = _RESOLVERS[loop]
else:
resolver = _RESOLVERS[loop] = DefaultResolver(loop=loop)

parts = urlparse(url)

if "@" in parts.netloc:
Expand Down Expand Up @@ -84,10 +94,13 @@ def resolve(url):
resolved = _DNS_CACHE[host]
else:
try:
resolved = gethostbyname(host)
_DNS_CACHE[host] = resolved
hosts = await resolver.resolve(host, port)
except socket.gaierror:
hosts = []
if len(hosts) == 0:
return url, original, host
resolved = hosts[0]["host"]
_DNS_CACHE[host] = resolved

# Don't use a resolved hostname for SSL requests otherwise the
# certificate will not match the IP address (resolved)
Expand Down

0 comments on commit 4c9a6e7

Please sign in to comment.