diff --git a/nebula3/gclient/net/ConnectionPool.py b/nebula3/gclient/net/ConnectionPool.py index d9ad4099..29d3a99c 100644 --- a/nebula3/gclient/net/ConnectionPool.py +++ b/nebula3/gclient/net/ConnectionPool.py @@ -63,7 +63,7 @@ def init(self, addresses, configs, ssl_conf=None): self._addresses.append(ip_port) self._addresses_status[ip_port] = self.S_BAD self._connections[ip_port] = deque() - + self._ssl_configs = ssl_conf self.update_servers_status() # detect the services @@ -78,20 +78,13 @@ def init(self, addresses, configs, ssl_conf=None): conns_per_address = int(self._configs.min_connection_pool_size / ok_num) - if self._ssl_configs is None: - for addr in self._addresses: - for i in range(0, conns_per_address): - connection = Connection() - connection.open(addr[0], addr[1], self._configs.timeout) - self._connections[addr].append(connection) - else: - for addr in self._addresses: - for i in range(0, conns_per_address): - connection = Connection() - connection.open_SSL( - addr[0], addr[1], self._configs.timeout, self._ssl_configs - ) - self._connections[addr].append(connection) + for addr in self._addresses: + for i in range(0, conns_per_address): + connection = Connection() + connection.open_SSL( + addr[0], addr[1], self._configs.timeout, self._ssl_configs + ) + self._connections[addr].append(connection) return True def get_session(self, user_name, password, retry_connect=True): @@ -148,6 +141,7 @@ def get_connection(self): try: ok_num = self.get_ok_servers_num() if ok_num == 0: + logging.error('No available server') return None max_con_per_address = int( self._configs.max_connection_pool_size / ok_num @@ -157,24 +151,37 @@ def get_connection(self): self._pos = (self._pos + 1) % len(self._addresses) addr = self._addresses[self._pos] if self._addresses_status[addr] == self.S_OK: + invalid_connections = list() + + # iterate all connections to find an available connection for connection in self._connections[addr]: if not connection.is_used: + # ping to check the connection is valid if connection.ping(): connection.is_used = True logger.info('Get connection to {}'.format(addr)) return connection + else: + invalid_connections.append(connection) + + # remove invalid connections + for connection in invalid_connections: + self._connections[addr].remove(connection) + # check if the server is still alive + if not self.ping(addr): + self._addresses_status[addr] = self.S_BAD + continue + + # create new connection if the number of connections is less than max_con_per_address if len(self._connections[addr]) < max_con_per_address: connection = Connection() - if self._ssl_configs is None: - connection.open(addr[0], addr[1], self._configs.timeout) - else: - connection.open_SSL( - addr[0], - addr[1], - self._configs.timeout, - self._ssl_configs, - ) + connection.open_SSL( + addr[0], + addr[1], + self._configs.timeout, + self._ssl_configs, + ) connection.is_used = True self._connections[addr].append(connection) logger.info('Get connection to {}'.format(addr)) @@ -184,6 +191,8 @@ def get_connection(self): if not connection.is_used: self._connections[addr].remove(connection) try_count = try_count + 1 + + logging.error('No available connection') return None except Exception as ex: logger.error('Get connection failed: {}'.format(ex)) @@ -197,10 +206,7 @@ def ping(self, address): """ try: conn = Connection() - if self._ssl_configs is None: - conn.open(address[0], address[1], 1000) - else: - conn.open_SSL(address[0], address[1], 1000, self._ssl_configs) + conn.open_SSL(address[0], address[1], 1000, self._ssl_configs) conn.close() return True except Exception as ex: @@ -218,9 +224,7 @@ def close(self): for addr in self._connections.keys(): for connection in self._connections[addr]: if connection.is_used: - logger.error( - 'The connection using by someone, but now want to close it' - ) + logger.warning('Closing a connection that is in use') connection.close() self._close = True @@ -286,7 +290,7 @@ def _remove_idle_unusable_connection(self): if not connection.is_used: if not connection.ping(): logger.debug( - 'Remove the not unusable connection to {}'.format( + 'Remove the unusable connection to {}'.format( connection.get_address() ) ) diff --git a/tests/test_pool.py b/tests/test_pool.py index 16fb72d7..84b0fbf7 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -170,14 +170,15 @@ def test_multi_thread(): # Test multi thread addresses = [('127.0.0.1', 9669), ('127.0.0.1', 9670)] configs = Config() - configs.max_connection_pool_size = 4 + thread_num = 50 + configs.max_connection_pool_size = thread_num pool = ConnectionPool() assert pool.init(addresses, configs) global success_flag success_flag = True - def main_test(): + def pool_multi_thread_test(): session = None global success_flag try: @@ -187,7 +188,7 @@ def main_test(): return space_name = 'space_' + threading.current_thread().getName() - session.execute('DROP SPACE %s' % space_name) + session.execute('DROP SPACE IF EXISTS %s' % space_name) resp = session.execute( 'CREATE SPACE IF NOT EXISTS %s(vid_type=FIXED_STRING(8))' % space_name ) @@ -207,20 +208,108 @@ def main_test(): if session is not None: session.release() - thread1 = threading.Thread(target=main_test, name='thread1') - thread2 = threading.Thread(target=main_test, name='thread2') - thread3 = threading.Thread(target=main_test, name='thread3') - thread4 = threading.Thread(target=main_test, name='thread4') + threads = [] + for num in range(0, thread_num): + thread = threading.Thread( + target=pool_multi_thread_test, name='test_pool_thread' + str(num) + ) + thread.start() + threads.append(thread) - thread1.start() - thread2.start() - thread3.start() - thread4.start() - - thread1.join() - thread2.join() - thread3.join() - thread4.join() + for t in threads: + t.join() + assert success_flag pool.close() + + +def test_session_context_multi_thread(): + # Test multi thread + addresses = [('127.0.0.1', 9669), ('127.0.0.1', 9670)] + configs = Config() + thread_num = 50 + configs.max_connection_pool_size = thread_num + pool = ConnectionPool() + assert pool.init(addresses, configs) + + global success_flag + success_flag = True + + def pool_session_context_multi_thread_test(): + session = None + global success_flag + try: + with pool.session_context('root', 'nebula') as session: + if session is None: + success_flag = False + return + space_name = 'space_' + threading.current_thread().getName() + + session.execute('DROP SPACE IF EXISTS %s' % space_name) + resp = session.execute( + 'CREATE SPACE IF NOT EXISTS %s(vid_type=FIXED_STRING(8))' + % space_name + ) + if not resp.is_succeeded(): + raise RuntimeError( + 'CREATE SPACE failed: {}'.format(resp.error_msg()) + ) + + time.sleep(3) + resp = session.execute('USE %s' % space_name) + if not resp.is_succeeded(): + raise RuntimeError('USE SPACE failed:{}'.format(resp.error_msg())) + + except Exception as x: + print(x) + success_flag = False + return + + threads = [] + for num in range(0, thread_num): + thread = threading.Thread( + target=pool_session_context_multi_thread_test, + name='test_session_context_thread' + str(num), + ) + thread.start() + threads.append(thread) + + for t in threads: + t.join() assert success_flag + + pool.close() + + +def test_remove_invalid_connection(): + addresses = [('127.0.0.1', 9669), ('127.0.0.1', 9670), ('127.0.0.1', 9671)] + configs = Config() + configs.min_connection_pool_size = 30 + configs.max_connection_pool_size = 45 + pool = ConnectionPool() + + try: + assert pool.init(addresses, configs) + + # turn down one server('127.0.0.1', 9669) so the connection to it is invalid + os.system('docker stop tests_graphd0_1') + time.sleep(3) + + # get connection from the pool, we should be able to still get 30 connections even though one server is down + for i in range(0, 30): + conn = pool.get_connection() + assert conn is not None + + # total connection should still be 30 + assert pool.connects() == 30 + + # the number of connections to the down server should be 0 + assert len(pool._connections[addresses[0]]) == 0 + + # the number of connections to the 2nd('127.0.0.1', 9670) and 3rd server('127.0.0.1', 9671) should be 15 + assert len(pool._connections[addresses[1]]) == 15 + assert len(pool._connections[addresses[2]]) == 15 + + finally: + os.system('docker start tests_graphd0_1') + time.sleep(3)