Skip to content

Commit

Permalink
Implement another solution to fix issue #27
Browse files Browse the repository at this point in the history
  • Loading branch information
romis2012 committed Oct 3, 2023
1 parent 73865ba commit adb9bb0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 18 deletions.
2 changes: 1 addition & 1 deletion aiohttp_socks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__title__ = 'aiohttp-socks'
__version__ = '0.8.3'
__version__ = '0.8.4'

from python_socks import (
ProxyError,
Expand Down
35 changes: 18 additions & 17 deletions aiohttp_socks/connector.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import socket
import typing
from asyncio import BaseTransport, StreamWriter
Expand Down Expand Up @@ -27,19 +28,17 @@ async def close(self):
pass # pragma: no cover


class RepairedStreamWriter(StreamWriter):
def __del__(self):
pass


def patch_stream(stream):
class _ResponseHandler(ResponseHandler):
"""
Fix issue https://github.com/romis2012/aiohttp-socks/issues/27
To fix issue https://github.com/romis2012/aiohttp-socks/issues/27
In Python>=3.11.5 we need to keep a reference to the StreamWriter
so that the underlying transport is not closed during garbage collection.
See StreamWriter.__del__ method (was added in Python 3.11.5)
"""
stream.writer.__class__ = RepairedStreamWriter
while hasattr(stream, '_inner'): # pragma: no cover
stream = stream._inner # noqa
stream.writer.__class__ = RepairedStreamWriter

def __init__(self, loop: asyncio.AbstractEventLoop, writer: StreamWriter):
super().__init__(loop)
self._writer = writer


class ProxyConnector(TCPConnector):
Expand Down Expand Up @@ -91,13 +90,14 @@ async def _wrap_create_connection(self, protocol_factory, host, port, *, ssl, **
)

transport: BaseTransport = stream.writer.transport
protocol: ResponseHandler = protocol_factory()
protocol: ResponseHandler = _ResponseHandler(
loop=self._loop,
writer=stream.writer,
)

transport.set_protocol(protocol)
protocol.connection_made(transport)

patch_stream(stream)

return transport, protocol

@classmethod
Expand Down Expand Up @@ -159,13 +159,14 @@ async def _wrap_create_connection(self, protocol_factory, host, port, *, ssl, **
)

transport: BaseTransport = stream.writer.transport
protocol: ResponseHandler = protocol_factory()
protocol: ResponseHandler = _ResponseHandler(
loop=self._loop,
writer=stream.writer,
)

transport.set_protocol(protocol)
protocol.connection_made(transport)

patch_stream(stream)

return transport, protocol

@classmethod
Expand Down

0 comments on commit adb9bb0

Please sign in to comment.