Skip to content

Commit

Permalink
Add support for getaddrinfo (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Mar 27, 2024
1 parent d40f913 commit c77e97a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
6 changes: 6 additions & 0 deletions aiodns/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ def gethostbyname(self, host: str, family: socket.AddressFamily) -> asyncio.Futu
cb = functools.partial(self._callback, fut)
self._channel.gethostbyname(host, family, cb)
return fut

def getaddrinfo(self, host: str, family: socket.AddressFamily = socket.AF_UNSPEC, port: Optional[int] = None, proto: int = 0, type: int = 0, flags: int = 0) -> asyncio.Future:
fut = asyncio.Future(loop=self.loop) # type: asyncio.Future
cb = functools.partial(self._callback, fut)
self._channel.getaddrinfo(host, port, cb, family=family, type=type, proto=proto, flags=flags)
return fut

def gethostbyaddr(self, name: str) -> asyncio.Future:
fut = asyncio.Future(loop=self.loop) # type: asyncio.Future
Expand Down
18 changes: 18 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,24 @@ def test_gethostbyname(self):
result = self.loop.run_until_complete(f)
self.assertTrue(result)

def test_getaddrinfo_address_family_0(self):
f = self.resolver.getaddrinfo('google.com')
result = self.loop.run_until_complete(f)
self.assertTrue(result)
self.assertTrue(len(result.nodes) > 1)

def test_getaddrinfo_address_family_af_inet(self):
f = self.resolver.getaddrinfo('google.com', socket.AF_INET)
result = self.loop.run_until_complete(f)
self.assertTrue(result)
self.assertTrue(all(node.family == socket.AF_INET for node in result.nodes))

def test_getaddrinfo_address_family_af_inet6(self):
f = self.resolver.getaddrinfo('google.com', socket.AF_INET6)
result = self.loop.run_until_complete(f)
self.assertTrue(result)
self.assertTrue(all(node.family == socket.AF_INET6 for node in result.nodes))

@unittest.skipIf(sys.platform == 'win32', 'skipped on Windows')
def test_gethostbyaddr(self):
f = self.resolver.gethostbyaddr('127.0.0.1')
Expand Down

0 comments on commit c77e97a

Please sign in to comment.