Skip to content

Commit

Permalink
Close requests.Socket in RemoteScheduler before exiting (#3173) (#3175)
Browse files Browse the repository at this point in the history
Co-authored-by: Dillon Stadther <dlstadther@gmail.com>
  • Loading branch information
starhel and dlstadther committed Jun 23, 2022
1 parent a295041 commit c135664
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
2 changes: 2 additions & 0 deletions luigi/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def _schedule_and_run(tasks, worker_scheduler_factory=None, override_defaults=No
success &= worker.run()
luigi_run_result = LuigiRunResult(worker, success)
logger.info(luigi_run_result.summary_text)
if hasattr(sch, 'close'):
sch.close()
return luigi_run_result


Expand Down
30 changes: 25 additions & 5 deletions luigi/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
rpc.py implements the client side of it, server.py implements the server side.
See :doc:`/central_scheduler` for more info.
"""
import abc
import os
import json
import logging
Expand Down Expand Up @@ -69,7 +70,17 @@ def __init__(self, message, sub_exception=None):
self.sub_exception = sub_exception


class URLLibFetcher:
class _FetcherInterface(metaclass=abc.ABCMeta):
@abc.abstractmethod
def fetch(self, full_url, body, timeout):
pass

@abc.abstractmethod
def close(self):
pass


class URLLibFetcher(_FetcherInterface):
raises = (URLError, socket.timeout)

def _create_request(self, full_url, body=None):
Expand All @@ -96,12 +107,15 @@ def fetch(self, full_url, body, timeout):
req = self._create_request(full_url, body=body)
return urlopen(req, timeout=timeout).read().decode('utf-8')

def close(self):
pass

class RequestsFetcher:
def __init__(self, session):

class RequestsFetcher(_FetcherInterface):
def __init__(self):
from requests import exceptions as requests_exceptions
self.raises = requests_exceptions.RequestException
self.session = session
self.session = requests.Session()
self.process_id = os.getpid()

def check_pid(self):
Expand All @@ -117,6 +131,9 @@ def fetch(self, full_url, body, timeout):
resp.raise_for_status()
return resp.text

def close(self):
self.session.close()


class RemoteScheduler:
"""
Expand All @@ -140,10 +157,13 @@ def __init__(self, url='http://localhost:8082/', connect_timeout=None):
self._rpc_log_retries = config.getboolean('core', 'rpc-log-retries', True)

if HAS_REQUESTS:
self._fetcher = RequestsFetcher(requests.Session())
self._fetcher = RequestsFetcher()
else:
self._fetcher = URLLibFetcher()

def close(self):
self._fetcher.close()

def _get_retryer(self):
def retry_logging(retry_state):
if self._rpc_log_retries:
Expand Down
5 changes: 2 additions & 3 deletions test/rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from server_test import ServerTestBase
import socket
from multiprocessing import Process, Queue
import requests


class RemoteSchedulerTest(unittest.TestCase):
Expand Down Expand Up @@ -147,8 +146,8 @@ def test_get_work_speed(self):

class RequestsFetcherTest(ServerTestBase):
def test_fork_changes_session(self):
session = requests.Session()
fetcher = luigi.rpc.RequestsFetcher(session)
fetcher = luigi.rpc.RequestsFetcher()
session = fetcher.session

q = Queue()

Expand Down

0 comments on commit c135664

Please sign in to comment.