diff --git a/src/factory/DnsTaskImpl.cc b/src/factory/DnsTaskImpl.cc index 46de4c89a3..83299a5f18 100644 --- a/src/factory/DnsTaskImpl.cc +++ b/src/factory/DnsTaskImpl.cc @@ -17,9 +17,11 @@ */ #include +#include +#include "DnsMessage.h" #include "WFTaskError.h" #include "WFTaskFactory.h" -#include "DnsMessage.h" +#include "WFServer.h" using namespace protocol; @@ -31,16 +33,13 @@ class ComplexDnsTask : public WFComplexClientTask> { static struct addrinfo hints; + static std::atomic seq; public: ComplexDnsTask(int retry_max, dns_callback_t&& cb): WFComplexClientTask(retry_max, std::move(cb)) { -#ifdef _WIN32 - this->set_transport_type(TT_TCP); -#else this->set_transport_type(TT_UDP); -#endif } protected: @@ -54,24 +53,21 @@ class ComplexDnsTask : public WFComplexClientTask ComplexDnsTask::seq(0); + CommMessageOut *ComplexDnsTask::message_out() { DnsRequest *req = this->get_req(); DnsResponse *resp = this->get_resp(); - TransportType type = this->get_transport_type(); + enum TransportType type = this->get_transport_type(); if (req->get_id() == 0) - req->set_id((this->get_seq() + 1) * 99991 % 65535 + 1); + req->set_id(++ComplexDnsTask::seq * 99991 % 65535 + 1); resp->set_request_id(req->get_id()); resp->set_request_name(req->get_question_name()); req->set_single_packet(type == TT_UDP); @@ -93,21 +89,22 @@ bool ComplexDnsTask::init_success() if (!this->route_result_.request_object) { - TransportType type = this->get_transport_type(); + enum TransportType type = this->get_transport_type(); struct addrinfo *addr; int ret; ret = getaddrinfo(uri_.host, uri_.port, &hints, &addr); if (ret != 0) { - this->state = WFT_STATE_TASK_ERROR; - this->error = WFT_ERR_URI_PARSE_FAILED; + this->state = WFT_STATE_DNS_ERROR; + this->error = ret; return false; } auto *ep = &WFGlobal::get_global_settings()->dns_server_params; ret = WFGlobal::get_route_manager()->get(type, addr, info_, ep, - uri_.host, route_result_); + uri_.host, ssl_ctx_, + route_result_); freeaddrinfo(addr); if (ret < 0) { @@ -146,7 +143,7 @@ bool ComplexDnsTask::finish_once() bool ComplexDnsTask::need_redirect() { DnsResponse *client_resp = this->get_resp(); - TransportType type = this->get_transport_type(); + enum TransportType type = this->get_transport_type(); if (type == TT_UDP && client_resp->get_tc() == 1) { @@ -189,3 +186,42 @@ WFDnsTask *WFTaskFactory::create_dns_task(const ParsedURI& uri, return task; } + +/**********Server**********/ + +class WFDnsServerTask : public WFServerTask +{ +public: + WFDnsServerTask(CommService *service, + std::function& proc) : + WFServerTask(service, WFGlobal::get_scheduler(), proc) + { + // this->type = ((WFServerBase *)service)->get_params()->transport_type; + this->type = TT_TCP; + } + +protected: + virtual CommMessageIn *message_in() + { + this->get_req()->set_single_packet(this->type == TT_UDP); + return this->WFServerTask::message_in(); + } + + virtual CommMessageOut *message_out() + { + this->get_resp()->set_single_packet(this->type == TT_UDP); + return this->WFServerTask::message_out(); + } + +protected: + enum TransportType type; +}; + +/**********Server Factory**********/ + +WFDnsTask *WFServerTaskFactory::create_dns_task(CommService *service, + std::function& proc) +{ + return new WFDnsServerTask(service, proc); +} + diff --git a/src/factory/HttpTaskImpl.cc b/src/factory/HttpTaskImpl.cc index 4e61265919..73c8dba74b 100644 --- a/src/factory/HttpTaskImpl.cc +++ b/src/factory/HttpTaskImpl.cc @@ -484,7 +484,8 @@ class ComplexHttpProxyTask : public ComplexHttpTask int ComplexHttpProxyTask::init_ssl_connection() { - SSL *ssl = __create_ssl(WFGlobal::get_ssl_client_ctx()); + static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx(); + SSL *ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx); WFConnection *conn; if (!ssl) diff --git a/src/factory/KafkaTaskImpl.cc b/src/factory/KafkaTaskImpl.cc index 30b71badd0..47518caabf 100644 --- a/src/factory/KafkaTaskImpl.cc +++ b/src/factory/KafkaTaskImpl.cc @@ -747,16 +747,28 @@ __WFKafkaTask *__WFKafkaTaskFactory::create_kafka_task(enum TransportType type, { auto *task = new __ComplexKafkaTask(retry_max, std::move(callback)); - std::string url = (type == TT_TCP_SSL ? "kafkas://" : "kafka://"); + ParsedURI uri; + char buf[32]; + + if (type == TT_TCP_SSL) + uri.scheme = strdup("kafkas"); + else + uri.scheme = strdup("kafka"); if (!info.empty()) - url += info + "@"; + uri.userinfo = strdup(info.c_str()); - url += host; - url += ":" + std::to_string(port); + uri.host = strdup(host); + sprintf(buf, "%u", port); + uri.port = strdup(buf); + + if (!uri.scheme || !uri.host || !uri.port || + (!info.empty() && !uri.userinfo)) + { + uri.state = URI_STATE_ERROR; + uri.error = errno; + } - ParsedURI uri; - URIParser::parse(url, uri); task->init(std::move(uri)); task->set_keep_alive(KAFKA_KEEPALIVE_DEFAULT); return task; diff --git a/src/factory/MySQLTaskImpl.cc b/src/factory/MySQLTaskImpl.cc index c71348741d..2d62bd0e4f 100644 --- a/src/factory/MySQLTaskImpl.cc +++ b/src/factory/MySQLTaskImpl.cc @@ -241,7 +241,7 @@ CommMessageOut *ComplexMySQLTask::message_out() break; case ST_FIRST_USER_REQUEST: - if (this->is_fixed_addr()) + if (this->is_fixed_conn()) { auto *target = (RouteManager::RouteTarget *)this->target; @@ -350,7 +350,21 @@ int ComplexMySQLTask::check_handshake(MySQLHandshakeResponse *resp) if (is_ssl_) { - if (!(resp->get_capability_flags() & 0x800)) + if (resp->get_capability_flags() & 0x800) + { + static SSL_CTX *ssl_ctx = WFGlobal::get_ssl_client_ctx(); + + ssl = __create_ssl(ssl_ctx_ ? ssl_ctx_ : ssl_ctx); + if (!ssl) + { + state_ = WFT_STATE_SYS_ERROR; + error_ = errno; + return 0; + } + + SSL_set_connect_state(ssl); + } + else { this->resp = std::move(*(MySQLResponse *)resp); state_ = WFT_STATE_TASK_ERROR; @@ -358,15 +372,6 @@ int ComplexMySQLTask::check_handshake(MySQLHandshakeResponse *resp) return 0; } - ssl = __create_ssl(WFGlobal::get_ssl_client_ctx()); - if (!ssl) - { - state_ = WFT_STATE_SYS_ERROR; - error_ = errno; - return 0; - } - - SSL_set_connect_state(ssl); } auto *conn = this->get_connection(); @@ -712,9 +717,9 @@ bool ComplexMySQLTask::init_success() if (!transaction.empty()) { - this->WFComplexClientTask::set_info(std::string("?maxconn=1&") + - info + "|txn:" + transaction); this->set_fixed_addr(true); + this->set_fixed_conn(true); + this->WFComplexClientTask::set_info(info + ("|txn:" + transaction)); } else this->WFComplexClientTask::set_info(info); @@ -741,7 +746,7 @@ bool ComplexMySQLTask::finish_once() return false; } - if (this->is_fixed_addr()) + if (this->is_fixed_conn()) { if (this->state != WFT_STATE_SUCCESS || this->keep_alive_timeo == 0) { @@ -767,7 +772,7 @@ WFMySQLTask *WFTaskFactory::create_mysql_task(const std::string& url, URIParser::parse(url, uri); task->init(std::move(uri)); - if (task->is_fixed_addr()) + if (task->is_fixed_conn()) task->set_keep_alive(MYSQL_KEEPALIVE_TRANSACTION); else task->set_keep_alive(MYSQL_KEEPALIVE_DEFAULT); @@ -782,7 +787,7 @@ WFMySQLTask *WFTaskFactory::create_mysql_task(const ParsedURI& uri, auto *task = new ComplexMySQLTask(retry_max, std::move(callback)); task->init(uri); - if (task->is_fixed_addr()) + if (task->is_fixed_conn()) task->set_keep_alive(MYSQL_KEEPALIVE_TRANSACTION); else task->set_keep_alive(MYSQL_KEEPALIVE_DEFAULT); diff --git a/src/factory/WFTaskFactory.h b/src/factory/WFTaskFactory.h index 48573c2186..ced15df738 100644 --- a/src/factory/WFTaskFactory.h +++ b/src/factory/WFTaskFactory.h @@ -407,6 +407,13 @@ class WFNetworkTaskFactory int retry_max, std::function callback); + static T *create_client_task(enum TransportType type, + const struct sockaddr *addr, + socklen_t addrlen, + SSL_CTX *ssl_ctx, + int retry_max, + std::function callback); + public: static T *create_server_task(CommService *service, std::function& process); diff --git a/src/factory/WFTaskFactory.inl b/src/factory/WFTaskFactory.inl index 37763b112c..dd90d49ce3 100644 --- a/src/factory/WFTaskFactory.inl +++ b/src/factory/WFTaskFactory.inl @@ -26,6 +26,7 @@ #include #include #include +#include #include "PlatformSocket.h" #include "WFGlobal.h" #include "Workflow.h" @@ -73,7 +74,9 @@ public: WFClientTask(NULL, WFGlobal::get_scheduler(), std::move(cb)) { type_ = TT_TCP; + ssl_ctx_ = NULL; fixed_addr_ = false; + fixed_conn_ = false; retry_max_ = retry_max; retry_times_ = 0; redirect_ = false; @@ -102,17 +105,19 @@ public: init_with_uri(); } - void init(TransportType type, + void init(enum TransportType type, const struct sockaddr *addr, socklen_t addrlen, const std::string& info); - void set_transport_type(TransportType type) + void set_transport_type(enum TransportType type) { type_ = type; } - TransportType get_transport_type() const { return type_; } + enum TransportType get_transport_type() const { return type_; } + + void set_ssl_ctx(SSL_CTX *ssl_ctx) { ssl_ctx_ = ssl_ctx; } virtual const ParsedURI *get_current_uri() const { return &uri_; } @@ -122,7 +127,7 @@ public: init(uri); } - void set_redirect(TransportType type, const struct sockaddr *addr, + void set_redirect(enum TransportType type, const struct sockaddr *addr, socklen_t addrlen, const std::string& info) { redirect_ = true; @@ -131,9 +136,13 @@ public: bool is_fixed_addr() const { return this->fixed_addr_; } + bool is_fixed_conn() const { return this->fixed_conn_; } + protected: void set_fixed_addr(int fixed) { this->fixed_addr_ = fixed; } + void set_fixed_conn(int fixed) { this->fixed_conn_ = fixed; } + void set_info(const std::string& info) { info_.assign(info); @@ -163,10 +172,12 @@ protected: } protected: - TransportType type_; + enum TransportType type_; ParsedURI uri_; std::string info_; + SSL_CTX *ssl_ctx_; bool fixed_addr_; + bool fixed_conn_; bool redirect_; CTX ctx_; int retry_max_; @@ -205,7 +216,7 @@ void WFComplexClientTask::clear_prev_state() } template -void WFComplexClientTask::init(TransportType type, +void WFComplexClientTask::init(enum TransportType type, const struct sockaddr *addr, socklen_t addrlen, const std::string& info) @@ -216,7 +227,6 @@ void WFComplexClientTask::init(TransportType type, auto params = WFGlobal::get_global_settings()->endpoint_params; struct addrinfo addrinfo = { }; addrinfo.ai_family = addr->sa_family; - addrinfo.ai_socktype = SOCK_STREAM; addrinfo.ai_addr = (struct sockaddr *)addr; addrinfo.ai_addrlen = addrlen; @@ -224,7 +234,7 @@ void WFComplexClientTask::init(TransportType type, info_.assign(info); params.use_tls_sni = false; if (WFGlobal::get_route_manager()->get(type, &addrinfo, info_, ¶ms, - "", route_result_) < 0) + "", ssl_ctx_, route_result_) < 0) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; @@ -277,10 +287,10 @@ template void WFComplexClientTask::init_with_uri() { if (redirect_) - { - clear_prev_state(); - ns_policy_ = WFGlobal::get_dns_resolver(); - } + { + clear_prev_state(); + ns_policy_ = WFGlobal::get_dns_resolver(); + } if (uri_.state == URI_STATE_SUCCESS) { @@ -311,12 +321,14 @@ WFRouterTask *WFComplexClientTask::route() this, std::placeholders::_1); struct WFNSParams params = { - /*.type =*/ type_, - /*.uri =*/ uri_, - /*.info =*/ info_.c_str(), - /*.fixed_addr =*/ fixed_addr_, - /*.retry_times =*/ retry_times_, - /*.tracing =*/ &tracing_, + /*.type =*/ type_, + /*.uri =*/ uri_, + /*.info =*/ info_.c_str(), + /*.ssl_ctx =*/ ssl_ctx_, + /*.fixed_addr =*/ fixed_addr_, + /*.fixed_conn =*/ fixed_conn_, + /*.retry_times =*/ retry_times_, + /*.tracing =*/ &tracing_, }; if (!ns_policy_) @@ -475,22 +487,26 @@ SubTask *WFComplexClientTask::done() template WFNetworkTask * -WFNetworkTaskFactory::create_client_task(TransportType type, +WFNetworkTaskFactory::create_client_task(enum TransportType type, const std::string& host, unsigned short port, int retry_max, std::function *)> callback) { auto *task = new WFComplexClientTask(retry_max, std::move(callback)); - char buf[8]; - std::string url = "scheme://"; ParsedURI uri; + char buf[32]; sprintf(buf, "%u", port); - url += host; - url += ":"; - url += buf; - URIParser::parse(url, uri); + uri.scheme = strdup("scheme"); + uri.host = strdup(host.c_str()); + uri.port = strdup(buf); + if (!uri.scheme || !uri.host || !uri.port) + { + uri.state = URI_STATE_ERROR; + uri.error = errno; + } + task->init(std::move(uri)); task->set_transport_type(type); return task; @@ -498,7 +514,7 @@ WFNetworkTaskFactory::create_client_task(TransportType type, template WFNetworkTask * -WFNetworkTaskFactory::create_client_task(TransportType type, +WFNetworkTaskFactory::create_client_task(enum TransportType type, const std::string& url, int retry_max, std::function *)> callback) @@ -514,7 +530,7 @@ WFNetworkTaskFactory::create_client_task(TransportType type, template WFNetworkTask * -WFNetworkTaskFactory::create_client_task(TransportType type, +WFNetworkTaskFactory::create_client_task(enum TransportType type, const ParsedURI& uri, int retry_max, std::function *)> callback) @@ -528,14 +544,30 @@ WFNetworkTaskFactory::create_client_task(TransportType type, template WFNetworkTask * -WFNetworkTaskFactory::create_client_task(TransportType type, +WFNetworkTaskFactory::create_client_task(enum TransportType type, + const struct sockaddr *addr, + socklen_t addrlen, + int retry_max, + std::function *)> callback) +{ + auto *task = new WFComplexClientTask(retry_max, std::move(callback)); + + task->init(type, addr, addrlen, ""); + return task; +} + +template +WFNetworkTask * +WFNetworkTaskFactory::create_client_task(enum TransportType type, const struct sockaddr *addr, socklen_t addrlen, + SSL_CTX *ssl_ctx, int retry_max, std::function *)> callback) { auto *task = new WFComplexClientTask(retry_max, std::move(callback)); + task->set_ssl_ctx(ssl_ctx); task->init(type, addr, addrlen, ""); return task; } @@ -553,6 +585,9 @@ WFNetworkTaskFactory::create_server_task(CommService *service, class WFServerTaskFactory { public: + static WFDnsTask *create_dns_task(CommService *service, + std::function& proc); + static WFHttpTask *create_http_task(CommService *service, std::function& proc) { @@ -670,26 +705,24 @@ void WFTaskFactory::reset_go_task(WFGoTask *task, FUNC&& func, ARGS&&... args) { auto&& tmp = std::bind(std::forward(func), std::forward(args)...); - static_cast<__WFGoTask *>(task)->set_go_func(std::move(tmp)); + ((__WFGoTask *)task)->set_go_func(std::move(tmp)); } /**********Create go task with nullptr func**********/ -template<> -inline WFGoTask *WFTaskFactory::create_go_task - (const std::string& queue_name, - std::nullptr_t&& func) +template<> inline +WFGoTask *WFTaskFactory::create_go_task(const std::string& queue_name, + std::nullptr_t&& func) { return new __WFGoTask(WFGlobal::get_exec_queue(queue_name), WFGlobal::get_compute_executor(), nullptr); } -template<> -inline WFGoTask *WFTaskFactory::create_timedgo_task - (time_t seconds, long nanoseconds, - const std::string& queue_name, - std::nullptr_t&& func) +template<> inline +WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds, + const std::string& queue_name, + std::nullptr_t&& func) { return new __WFTimedGoTask(seconds, nanoseconds, WFGlobal::get_exec_queue(queue_name), @@ -697,28 +730,25 @@ inline WFGoTask *WFTaskFactory::create_timedgo_task nullptr); } -template<> -inline WFGoTask *WFTaskFactory::create_go_task - (ExecQueue *queue, Executor *executor, - std::nullptr_t&& func) +template<> inline +WFGoTask *WFTaskFactory::create_go_task(ExecQueue *queue, Executor *executor, + std::nullptr_t&& func) { return new __WFGoTask(queue, executor, nullptr); } -template<> -inline WFGoTask *WFTaskFactory::create_timedgo_task - (time_t seconds, long nanoseconds, - ExecQueue *queue, Executor *executor, - std::nullptr_t&& func) +template<> inline +WFGoTask *WFTaskFactory::create_timedgo_task(time_t seconds, long nanoseconds, + ExecQueue *queue, Executor *executor, + std::nullptr_t&& func) { return new __WFTimedGoTask(seconds, nanoseconds, queue, executor, nullptr); } -template<> -inline void WFTaskFactory::reset_go_task - (WFGoTask *task, std::nullptr_t&& func) +template<> inline +void WFTaskFactory::reset_go_task(WFGoTask *task, std::nullptr_t&& func) { - static_cast<__WFGoTask *>(task)->set_go_func(nullptr); + ((__WFGoTask *)task)->set_go_func(nullptr); } /**********Template Thread Task Factory**********/ diff --git a/src/kernel/Communicator.cc b/src/kernel/Communicator.cc index d81fc5e5db..41c72bb4bd 100644 --- a/src/kernel/Communicator.cc +++ b/src/kernel/Communicator.cc @@ -74,8 +74,8 @@ static inline int __set_fd_nonblock(int fd) return flags; } -static int __bind_and_listen(int sockfd, const struct sockaddr *addr, - socklen_t addrlen) +static int __bind_sockaddr(int sockfd, const struct sockaddr *addr, + socklen_t addrlen) { struct sockaddr_storage ss; socklen_t len; @@ -97,7 +97,7 @@ static int __bind_and_listen(int sockfd, const struct sockaddr *addr, return -1; } - return listen(sockfd, SOMAXCONN < 4096 ? 4096 : SOMAXCONN); + return 0; } static int __create_ssl(SSL_CTX *ssl_ctx, struct CommConnEntry *entry) @@ -119,6 +119,48 @@ static int __create_ssl(SSL_CTX *ssl_ctx, struct CommConnEntry *entry) return -1; } +static int __send_to_conn(const void *buf, size_t size, + struct CommConnEntry *entry) +{ + const struct sockaddr *addr; + socklen_t addrlen; + int ret; + + if (!entry->ssl) + { + entry->target->get_addr(&addr, &addrlen); + return sendto(entry->sockfd, buf, size, 0, addr, addrlen); + } + + if (size == 0) + return 0; + + ret = SSL_write(entry->ssl, buf, size); + if (ret <= 0) + { + ret = SSL_get_error(entry->ssl, ret); + if (ret != SSL_ERROR_SYSCALL) + errno = -ret; + + ret = -1; + } + + return ret; +} + +static void __release_conn(struct CommConnEntry *entry) +{ + delete entry->conn; + if (!entry->service) + pthread_mutex_destroy(&entry->mutex); + + if (entry->ssl) + SSL_free(entry->ssl); + + close(entry->sockfd); + free(entry); +} + #define SSL_WRITE_BUFSIZE 8192 static int __ssl_writev(SSL *ssl, const struct iovec vectors[], int cnt) @@ -186,26 +228,7 @@ void CommTarget::deinit() int CommMessageIn::feedback(const void *buf, size_t size) { - struct CommConnEntry *entry = this->entry; - int ret; - - if (!entry->ssl) - return write(entry->sockfd, buf, size); - - if (size == 0) - return 0; - - ret = SSL_write(entry->ssl, buf, size); - if (ret <= 0) - { - ret = SSL_get_error(entry->ssl, ret); - if (ret != SSL_ERROR_SYSCALL) - errno = -ret; - - ret = -1; - } - - return ret; + return __send_to_conn(buf, size, this->entry); } void CommMessageIn::renew() @@ -305,6 +328,9 @@ class CommServiceTarget : public CommTarget } } +public: + int shutdown(); + private: int sockfd; int ref; @@ -322,36 +348,50 @@ class CommServiceTarget : public CommTarget friend class Communicator; }; -CommSession::~CommSession() +int CommServiceTarget::shutdown() { struct CommConnEntry *entry; - struct list_head *pos; - CommTarget *target; int errno_bak; + int ret = 0; - if (!this->passive) - return; - - target = this->target; - if (this->passive == 1) + pthread_mutex_lock(&this->mutex); + if (!list_empty(&this->idle_list)) { - pthread_mutex_lock(&target->mutex); - if (!list_empty(&target->idle_list)) - { - pos = target->idle_list.next; - entry = list_entry(pos, struct CommConnEntry, list); - list_del(pos); + entry = list_entry(this->idle_list.next, struct CommConnEntry, list); + list_del(&entry->list); + if (this->service->reliable) + { errno_bak = errno; mpoller_del(entry->sockfd, entry->mpoller); entry->state = CONN_STATE_CLOSING; errno = errno_bak; } + else + { + __release_conn(entry); + this->decref(); + } - pthread_mutex_unlock(&target->mutex); + ret = 1; } - ((CommServiceTarget *)target)->decref(); + pthread_mutex_unlock(&this->mutex); + return ret; +} + +CommSession::~CommSession() +{ + CommServiceTarget *target; + + if (!this->passive) + return; + + target = (CommServiceTarget *)this->target; + if (this->passive == 1) + target->shutdown(); + + target->decref(); } inline int Communicator::first_timeout(CommSession *session) @@ -404,19 +444,6 @@ int Communicator::first_timeout_recv(CommSession *session) return Communicator::first_timeout(session); } -void Communicator::release_conn(struct CommConnEntry *entry) -{ - delete entry->conn; - if (!entry->service) - pthread_mutex_destroy(&entry->mutex); - - if (entry->ssl) - SSL_free(entry->ssl); - - close(entry->sockfd); - free(entry); -} - void Communicator::shutdown_service(CommService *service) { close(service->listen_fd); @@ -670,7 +697,7 @@ void Communicator::handle_incoming_request(struct poller_result *res) if (__sync_sub_and_fetch(&entry->ref, 1) == 0) { - this->release_conn(entry); + __release_conn(entry); ((CommServiceTarget *)target)->decref(); } } @@ -752,7 +779,7 @@ void Communicator::handle_incoming_reply(struct poller_result *res) } if (__sync_sub_and_fetch(&entry->ref, 1) == 0) - this->release_conn(entry); + __release_conn(entry); } } @@ -824,7 +851,7 @@ void Communicator::handle_reply_result(struct poller_result *res) session->handle(state, res->error); if (__sync_sub_and_fetch(&entry->ref, 1) == 0) { - this->release_conn(entry); + __release_conn(entry); ((CommServiceTarget *)target)->decref(); } @@ -877,7 +904,7 @@ void Communicator::handle_request_result(struct poller_result *res) /* do nothing */ pthread_mutex_unlock(&entry->mutex); if (__sync_sub_and_fetch(&entry->ref, 1) == 0) - this->release_conn(entry); + __release_conn(entry); break; } @@ -910,7 +937,7 @@ struct CommConnEntry *Communicator::accept_conn(CommServiceTarget *target, if (entry->conn) { entry->seq = 0; - entry->mpoller = this->mpoller; + entry->mpoller = NULL; entry->service = service; entry->target = target; entry->ssl = NULL; @@ -996,7 +1023,7 @@ void Communicator::handle_connect_result(struct poller_result *res) target->release(); session->handle(state, res->error); - this->release_conn(entry); + __release_conn(entry); break; } } @@ -1012,9 +1039,10 @@ void Communicator::handle_listen_result(struct poller_result *res) { case PR_ST_SUCCESS: target = (CommServiceTarget *)res->data.result; - entry = this->accept_conn(target, service); + entry = Communicator::accept_conn(target, service); if (entry) { + entry->mpoller = this->mpoller; if (service->ssl_ctx) { if (__create_ssl(service->ssl_ctx, entry) >= 0 && @@ -1045,7 +1073,7 @@ void Communicator::handle_listen_result(struct poller_result *res) } } - this->release_conn(entry); + __release_conn(entry); } else close(target->sockfd); @@ -1064,6 +1092,54 @@ void Communicator::handle_listen_result(struct poller_result *res) } } +void Communicator::handle_recvfrom_result(struct poller_result *res) +{ + CommService *service = (CommService *)res->data.context; + struct CommConnEntry *entry; + CommTarget *target; + int state, error; + + switch (res->state) + { + case PR_ST_SUCCESS: + entry = (struct CommConnEntry *)res->data.result; + target = entry->target; + if (entry->state == CONN_STATE_SUCCESS) + { + state = CS_STATE_TOREPLY; + error = 0; + entry->state = CONN_STATE_IDLE; + list_add(&entry->list, &target->idle_list); + } + else + { + state = CS_STATE_ERROR; + if (entry->state == CONN_STATE_ERROR) + error = entry->error; + else + error = EBADMSG; + } + + entry->session->handle(state, error); + if (state == CS_STATE_ERROR) + { + __release_conn(entry); + ((CommServiceTarget *)target)->decref(); + } + + break; + + case PR_ST_DELETED: + this->shutdown_service(service); + break; + + case PR_ST_ERROR: + case PR_ST_STOPPED: + service->handle_stop(res->error); + break; + } +} + void Communicator::handle_ssl_accept_result(struct poller_result *res) { struct CommConnEntry *entry = (struct CommConnEntry *)res->data.context; @@ -1087,7 +1163,7 @@ void Communicator::handle_ssl_accept_result(struct poller_result *res) case PR_ST_DELETED: case PR_ST_ERROR: case PR_ST_STOPPED: - this->release_conn(entry); + __release_conn(entry); ((CommServiceTarget *)target)->decref(); break; } @@ -1184,6 +1260,9 @@ void Communicator::handler_thread_routine(void *context) case PD_OP_LISTEN: comm->handle_listen_result(res); break; + case PD_OP_RECVFROM: + comm->handle_recvfrom_result(res); + break; case PD_OP_SSL_ACCEPT: comm->handle_ssl_accept_result(res); break; @@ -1343,6 +1422,58 @@ poller_message_t *Communicator::create_reply(void *context) return session->in; } +int Communicator::recv_request(const void *buf, size_t size, + struct CommConnEntry *entry) +{ + CommService *service = entry->service; + CommTarget *target = entry->target; + CommSession *session; + size_t n; + int ret; + + session = service->new_session(entry->seq, entry->conn); + if (!session) + return -1; + + session->passive = 1; + entry->session = session; + session->target = target; + session->conn = entry->conn; + session->seq = entry->seq++; + session->out = NULL; + session->in = NULL; + + entry->state = CONN_STATE_RECEIVING; + + ((CommServiceTarget *)target)->incref(); + + session->in = session->message_in(); + if (session->in) + { + session->in->entry = entry; + do + { + n = size; + ret = session->in->append(buf, &n); + if (ret == 0) + { + size -= n; + buf = (const char *)buf + n; + } + else if (ret < 0) + { + entry->error = errno; + entry->state = CONN_STATE_ERROR; + } + else + entry->state = CONN_STATE_SUCCESS; + + } while (ret == 0 && size > 0); + } + + return 0; +} + int Communicator::partial_written(size_t n, void *context) { struct CommConnEntry *entry = (struct CommConnEntry *)context; @@ -1378,6 +1509,40 @@ void *Communicator::accept(const struct sockaddr *addr, socklen_t addrlen, return NULL; } +void *Communicator::recvfrom(const struct sockaddr *addr, socklen_t addrlen, + const void *buf, size_t size, void *context) +{ + CommService *service = (CommService *)context; + struct CommConnEntry *entry; + CommServiceTarget *target; + void *result; + int sockfd; + + sockfd = dup(service->listen_fd); + if (sockfd >= 0) + { + result = Communicator::accept(addr, addrlen, sockfd, context); + if (result) + { + target = (CommServiceTarget *)result; + entry = Communicator::accept_conn(target, service); + if (entry) + { + if (Communicator::recv_request(buf, size, entry) >= 0) + return entry; + + __release_conn(entry); + } + else + close(sockfd); + + target->decref(); + } + } + + return NULL; +} + void Communicator::callback(struct poller_result *res, void *context) { msgqueue_t *msgqueue = (msgqueue_t *)context; @@ -1502,7 +1667,7 @@ struct CommConnEntry *Communicator::launch_conn(CommSession *session, int sockfd; int ret; - sockfd = this->nonblock_connect(target); + sockfd = Communicator::nonblock_connect(target); if (sockfd >= 0) { entry = (struct CommConnEntry *)malloc(sizeof (struct CommConnEntry)); @@ -1515,7 +1680,7 @@ struct CommConnEntry *Communicator::launch_conn(CommSession *session, if (entry->conn) { entry->seq = 0; - entry->mpoller = this->mpoller; + entry->mpoller = NULL; entry->service = NULL; entry->target = target; entry->session = session; @@ -1598,9 +1763,10 @@ int Communicator::request_new_conn(CommSession *session, CommTarget *target) struct poller_data data; int timeout; - entry = this->launch_conn(session, target); + entry = Communicator::launch_conn(session, target); if (entry) { + entry->mpoller = this->mpoller; session->conn = entry->conn; session->seq = entry->seq++; data.operation = PD_OP_CONNECT; @@ -1611,7 +1777,7 @@ int Communicator::request_new_conn(CommSession *session, CommTarget *target) if (mpoller_add(&data, timeout, this->mpoller) >= 0) return 0; - this->release_conn(entry); + __release_conn(entry); } return -1; @@ -1648,15 +1814,21 @@ int Communicator::request(CommSession *session, CommTarget *target) int Communicator::nonblock_listen(CommService *service) { int sockfd = service->create_listen_fd(); + int ret; if (sockfd >= 0) { if (__set_fd_nonblock(sockfd) >= 0) { - if (__bind_and_listen(sockfd, service->bind_addr, - service->addrlen) >= 0) + if (__bind_sockaddr(sockfd, service->bind_addr, + service->addrlen) >= 0) { - return sockfd; + ret = listen(sockfd, SOMAXCONN); + if (ret >= 0 || errno == EOPNOTSUPP) + { + service->reliable = (ret >= 0); + return sockfd; + } } } @@ -1669,6 +1841,7 @@ int Communicator::nonblock_listen(CommService *service) int Communicator::bind(CommService *service) { struct poller_data data; + int errno_bak = errno; int sockfd; sockfd = this->nonblock_listen(service); @@ -1676,13 +1849,25 @@ int Communicator::bind(CommService *service) { service->listen_fd = sockfd; service->ref = 1; - data.operation = PD_OP_LISTEN; data.fd = sockfd; - data.accept = Communicator::accept; data.context = service; data.result = NULL; + if (service->reliable) + { + data.operation = PD_OP_LISTEN; + data.accept = Communicator::accept; + } + else + { + data.operation = PD_OP_RECVFROM; + data.recvfrom = Communicator::recvfrom; + } + if (mpoller_add(&data, service->listen_timeout, this->mpoller) >= 0) + { + errno = errno_bak; return 0; + } close(sockfd); } @@ -1702,7 +1887,7 @@ void Communicator::unbind(CommService *service) } } -int Communicator::reply_idle_conn(CommSession *session, CommTarget *target) +int Communicator::reply_reliable(CommSession *session, CommTarget *target) { struct CommConnEntry *entry; struct list_head *pos; @@ -1734,25 +1919,85 @@ int Communicator::reply_idle_conn(CommSession *session, CommTarget *target) return ret; } +int Communicator::reply_message_unreliable(struct CommConnEntry *entry) +{ + struct iovec vectors[ENCODE_IOV_MAX]; + int cnt; + + cnt = entry->session->out->encode(vectors, ENCODE_IOV_MAX); + if ((unsigned int)cnt > ENCODE_IOV_MAX) + { + if (cnt > ENCODE_IOV_MAX) + errno = EOVERFLOW; + return -1; + } + + if (cnt > 0) + { + struct msghdr message = { + .msg_name = entry->target->addr, + .msg_namelen = entry->target->addrlen, + .msg_iov = vectors, +#ifdef __linux__ + .msg_iovlen = (size_t)cnt, +#else + .msg_iovlen = cnt, +#endif + }; + if (sendmsg(entry->sockfd, &message, 0) < 0) + return -1; + } + + return 0; +} + +int Communicator::reply_unreliable(CommSession *session, CommTarget *target) +{ + struct CommConnEntry *entry; + struct list_head *pos; + + if (!list_empty(&target->idle_list)) + { + pos = target->idle_list.next; + entry = list_entry(pos, struct CommConnEntry, list); + list_del(pos); + + session->out = session->message_out(); + if (session->out) + { + if (this->reply_message_unreliable(entry) >= 0) + return 0; + } + + __release_conn(entry); + ((CommServiceTarget *)target)->decref(); + } + else + errno = ENOENT; + + return -1; +} + int Communicator::reply(CommSession *session) { struct CommConnEntry *entry; - CommTarget *target; + CommServiceTarget *target; int errno_bak; int ret; if (session->passive != 1) { - errno = session->passive ? ENOENT : EPERM; + errno = session->passive ? ENOENT : EINVAL; return -1; } errno_bak = errno; session->passive = 2; - target = session->target; - ret = this->reply_idle_conn(session, target); - if (ret < 0) - return -1; + target = (CommServiceTarget *)session->target; + if (target->service->reliable) + ret = this->reply_reliable(session, target); + else + ret = this->reply_unreliable(session, target); if (ret == 0) { @@ -1760,10 +2005,12 @@ int Communicator::reply(CommSession *session) session->handle(CS_STATE_SUCCESS, 0); if (__sync_sub_and_fetch(&entry->ref, 1) == 0) { - this->release_conn(entry); - ((CommServiceTarget *)target)->decref(); + __release_conn(entry); + target->decref(); } } + else if (ret < 0) + return -1; errno = errno_bak; return 0; @@ -1777,7 +2024,7 @@ int Communicator::push(const void *buf, size_t size, CommSession *session) if (session->passive != 1) { - errno = session->passive ? ENOENT : EPERM; + errno = session->passive ? ENOENT : EINVAL; return -1; } @@ -1785,22 +2032,7 @@ int Communicator::push(const void *buf, size_t size, CommSession *session) if (!list_empty(&target->idle_list)) { entry = list_entry(target->idle_list.next, struct CommConnEntry, list); - if (!entry->ssl) - ret = write(entry->sockfd, buf, size); - else if (size == 0) - ret = 0; - else - { - ret = SSL_write(entry->ssl, buf, size); - if (ret <= 0) - { - ret = SSL_get_error(entry->ssl, ret); - if (ret != SSL_ERROR_SYSCALL) - errno = -ret; - - ret = -1; - } - } + ret = __send_to_conn(buf, size, entry); } else { @@ -1814,33 +2046,23 @@ int Communicator::push(const void *buf, size_t size, CommSession *session) int Communicator::shutdown(CommSession *session) { - CommTarget *target = session->target; - struct CommConnEntry *entry; - int ret; + CommServiceTarget *target; if (session->passive != 1) { - errno = session->passive ? ENOENT : EPERM; + errno = session->passive ? ENOENT : EINVAL; return -1; } session->passive = 2; - pthread_mutex_lock(&target->mutex); - if (!list_empty(&target->idle_list)) - { - entry = list_entry(target->idle_list.next, struct CommConnEntry, list); - list_del(&entry->list); - ret = mpoller_del(entry->sockfd, entry->mpoller); - entry->state = CONN_STATE_CLOSING; - } - else + target = (CommServiceTarget *)session->target; + if (!target->shutdown()) { errno = ENOENT; - ret = -1; + return -1; } - pthread_mutex_unlock(&target->mutex); - return ret; + return 0; } int Communicator::sleep(SleepSession *session) diff --git a/src/kernel/Communicator.h b/src/kernel/Communicator.h index 7552b5393a..dbbc2ba7af 100644 --- a/src/kernel/Communicator.h +++ b/src/kernel/Communicator.h @@ -31,9 +31,8 @@ class CommConnection { -protected: +public: virtual ~CommConnection() { } - friend class Communicator; }; class CommTarget @@ -91,7 +90,7 @@ class CommTarget public: virtual ~CommTarget() { } - friend class CommSession; + friend class CommServiceTarget; friend class Communicator; }; @@ -223,6 +222,7 @@ class CommService void decref(); private: + int reliable; int listen_fd; int ref; @@ -298,16 +298,6 @@ class Communicator int create_handler_threads(size_t handler_threads); - int nonblock_connect(CommTarget *target); - int nonblock_listen(CommService *service); - - struct CommConnEntry *launch_conn(CommSession *session, - CommTarget *target); - struct CommConnEntry *accept_conn(class CommServiceTarget *target, - CommService *service); - - void release_conn(struct CommConnEntry *entry); - void shutdown_service(CommService *service); void shutdown_io_service(IOService *service); @@ -319,10 +309,13 @@ class Communicator int send_message(struct CommConnEntry *entry); + int request_new_conn(CommSession *session, CommTarget *target); int request_idle_conn(CommSession *session, CommTarget *target); - int reply_idle_conn(CommSession *session, CommTarget *target); - int request_new_conn(CommSession *session, CommTarget *target); + int reply_message_unreliable(struct CommConnEntry *entry); + + int reply_reliable(CommSession *session, CommTarget *target); + int reply_unreliable(CommSession *session, CommTarget *target); void handle_incoming_request(struct poller_result *res); void handle_incoming_reply(struct poller_result *res); @@ -336,6 +329,8 @@ class Communicator void handle_connect_result(struct poller_result *res); void handle_listen_result(struct poller_result *res); + void handle_recvfrom_result(struct poller_result *res); + void handle_ssl_accept_result(struct poller_result *res); void handle_sleep_result(struct poller_result *res); @@ -344,6 +339,14 @@ class Communicator static void handler_thread_routine(void *context); + static int nonblock_connect(CommTarget *target); + static int nonblock_listen(CommService *service); + + static struct CommConnEntry *launch_conn(CommSession *session, + CommTarget *target); + static struct CommConnEntry *accept_conn(class CommServiceTarget *target, + CommService *service); + static int first_timeout(CommSession *session); static int next_timeout(CommSession *session); @@ -358,11 +361,17 @@ class Communicator static poller_message_t *create_request(void *context); static poller_message_t *create_reply(void *context); + static int recv_request(const void *buf, size_t size, + struct CommConnEntry *entry); + static int partial_written(size_t n, void *context); static void *accept(const struct sockaddr *addr, socklen_t addrlen, int sockfd, void *context); + static void *recvfrom(const struct sockaddr *addr, socklen_t addrlen, + const void *buf, size_t size, void *context); + static void callback(struct poller_result *res, void *context); public: diff --git a/src/kernel/poller.c b/src/kernel/poller.c index c2881cf146..4da9db9d15 100644 --- a/src/kernel/poller.c +++ b/src/kernel/poller.c @@ -651,6 +651,55 @@ static void __poller_handle_connect(struct __poller_node *node, poller->callback((struct poller_result *)node, poller->context); } +static void __poller_handle_recvfrom(struct __poller_node *node, + poller_t *poller) +{ + struct __poller_node *res = node->res; + struct sockaddr_storage ss; + struct sockaddr *addr = (struct sockaddr *)&ss; + socklen_t addrlen; + void *result; + ssize_t n; + + while (1) + { + addrlen = sizeof (struct sockaddr_storage); + n = recvfrom(node->data.fd, poller->buf, POLLER_BUFSIZE, 0, + addr, &addrlen); + if (n < 0) + { + if (errno == EAGAIN) + return; + else + break; + } + + result = node->data.recvfrom(addr, addrlen, poller->buf, n, + node->data.context); + if (!result) + break; + + res->data = node->data; + res->data.result = result; + res->error = 0; + res->state = PR_ST_SUCCESS; + poller->callback((struct poller_result *)res, poller->context); + + res = (struct __poller_node *)malloc(sizeof (struct __poller_node)); + node->res = res; + if (!res) + break; + } + + if (__poller_remove_node(node, poller)) + return; + + node->error = errno; + node->state = PR_ST_ERROR; + free(node->res); + poller->callback((struct poller_result *)node, poller->context); +} + static void __poller_handle_ssl_accept(struct __poller_node *node, poller_t *poller) { @@ -849,55 +898,6 @@ static void __poller_handle_notify(struct __poller_node *node, poller->callback((struct poller_result *)node, poller->context); } -static void __poller_handle_recvfrom(struct __poller_node *node, - poller_t *poller) -{ - struct __poller_node *res = node->res; - struct sockaddr_storage ss; - struct sockaddr *addr = (struct sockaddr *)&ss; - socklen_t addrlen; - void *result; - ssize_t n; - - while (1) - { - addrlen = sizeof (struct sockaddr_storage); - n = recvfrom(node->data.fd, poller->buf, POLLER_BUFSIZE, 0, - addr, &addrlen); - if (n < 0) - { - if (errno == EAGAIN) - return; - else - break; - } - - result = node->data.recvfrom(addr, addrlen, poller->buf, n, - node->data.context); - if (!result) - break; - - res->data = node->data; - res->data.result = result; - res->error = 0; - res->state = PR_ST_SUCCESS; - poller->callback((struct poller_result *)res, poller->context); - - res = (struct __poller_node *)malloc(sizeof (struct __poller_node)); - node->res = res; - if (!res) - break; - } - - if (__poller_remove_node(node, poller)) - return; - - node->error = errno; - node->state = PR_ST_ERROR; - free(node->res); - poller->callback((struct poller_result *)node, poller->context); -} - static int __poller_handle_pipe(poller_t *poller) { struct __poller_node **node = (struct __poller_node **)poller->buf; @@ -1055,6 +1055,9 @@ static void *__poller_thread_routine(void *arg) case PD_OP_CONNECT: __poller_handle_connect(node, poller); break; + case PD_OP_RECVFROM: + __poller_handle_recvfrom(node, poller); + break; case PD_OP_SSL_ACCEPT: __poller_handle_ssl_accept(node, poller); break; @@ -1070,9 +1073,6 @@ static void *__poller_thread_routine(void *arg) case PD_OP_NOTIFY: __poller_handle_notify(node, poller); break; - case PD_OP_RECVFROM: - __poller_handle_recvfrom(node, poller); - break; } } @@ -1282,6 +1282,9 @@ static int __poller_data_get_event(int *event, const struct poller_data *data) case PD_OP_CONNECT: *event = EPOLLOUT | EPOLLET; return 0; + case PD_OP_RECVFROM: + *event = EPOLLIN | EPOLLET; + return 1; case PD_OP_SSL_ACCEPT: *event = EPOLLIN | EPOLLET; return 0; @@ -1297,9 +1300,6 @@ static int __poller_data_get_event(int *event, const struct poller_data *data) case PD_OP_NOTIFY: *event = EPOLLIN | EPOLLET; return 1; - case PD_OP_RECVFROM: - *event = EPOLLIN | EPOLLET; - return 1; default: errno = EINVAL; return -1; diff --git a/src/kernel/poller.h b/src/kernel/poller.h index 89831277e3..71ff70cccf 100644 --- a/src/kernel/poller.h +++ b/src/kernel/poller.h @@ -40,14 +40,14 @@ struct poller_data #define PD_OP_WRITE 2 #define PD_OP_LISTEN 3 #define PD_OP_CONNECT 4 +#define PD_OP_RECVFROM 5 #define PD_OP_SSL_READ PD_OP_READ #define PD_OP_SSL_WRITE PD_OP_WRITE -#define PD_OP_SSL_ACCEPT 5 -#define PD_OP_SSL_CONNECT 6 -#define PD_OP_SSL_SHUTDOWN 7 -#define PD_OP_EVENT 8 -#define PD_OP_NOTIFY 9 -#define PD_OP_RECVFROM 10 +#define PD_OP_SSL_ACCEPT 6 +#define PD_OP_SSL_CONNECT 7 +#define PD_OP_SSL_SHUTDOWN 8 +#define PD_OP_EVENT 9 +#define PD_OP_NOTIFY 10 short operation; unsigned short iovcnt; int fd; @@ -57,10 +57,10 @@ struct poller_data poller_message_t *(*create_message)(void *); int (*partial_written)(size_t, void *); void *(*accept)(const struct sockaddr *, socklen_t, int, void *); - void *(*event)(void *); - void *(*notify)(void *, void *); void *(*recvfrom)(const struct sockaddr *, socklen_t, const void *, size_t, void *); + void *(*event)(void *); + void *(*notify)(void *, void *); }; void *context; union diff --git a/src/manager/DnsCache.cc b/src/manager/DnsCache.cc index fbafd40df7..3c0cee9c7b 100644 --- a/src/manager/DnsCache.cc +++ b/src/manager/DnsCache.cc @@ -23,42 +23,31 @@ #define GET_CURRENT_SECOND std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count() -#define CONFIDENT_INC 10 -#define TTL_INC 10 +#define TTL_INC 5 -const DnsCache::DnsHandle *DnsCache::get_inner(const HostPort& host_port, int type) +const DnsCache::DnsHandle *DnsCache::get_inner(const HostPort& host_port, + int type) { - int64_t cur_time = GET_CURRENT_SECOND; + int64_t cur = GET_CURRENT_SECOND; std::lock_guard lock(mutex_); const DnsHandle *handle = cache_pool_.get(host_port); - if (handle) + if (handle && ((type == GET_TYPE_TTL && cur > handle->value.expire_time) || + (type == GET_TYPE_CONFIDENT && cur > handle->value.confident_time))) { - switch (type) + if (!handle->value.delayed()) { - case GET_TYPE_TTL: - if (cur_time > handle->value.expire_time) - { - const_cast(handle)->value.expire_time += TTL_INC; - cache_pool_.release(handle); - return NULL; - } - - break; - - case GET_TYPE_CONFIDENT: - if (cur_time > handle->value.confident_time) - { - const_cast(handle)->value.confident_time += CONFIDENT_INC; - cache_pool_.release(handle); - return NULL; - } - - break; - - default: - break; + DnsHandle *h = const_cast(handle); + if (type == GET_TYPE_TTL) + h->value.expire_time += TTL_INC; + else + h->value.confident_time += TTL_INC; + + h->value.addrinfo->ai_flags |= 2; } + + cache_pool_.release(handle); + return NULL; } return handle; @@ -90,3 +79,29 @@ const DnsCache::DnsHandle *DnsCache::put(const HostPort& host_port, return cache_pool_.put(host_port, {addrinfo, confident_time, expire_time}); } +const DnsCache::DnsHandle *DnsCache::get(const DnsCache::HostPort& host_port) +{ + std::lock_guard lock(mutex_); + return cache_pool_.get(host_port); +} + +void DnsCache::release(const DnsCache::DnsHandle *handle) +{ + std::lock_guard lock(mutex_); + cache_pool_.release(handle); +} + +void DnsCache::del(const DnsCache::HostPort& key) +{ + std::lock_guard lock(mutex_); + cache_pool_.del(key); +} + +DnsCache::DnsCache() +{ +} + +DnsCache::~DnsCache() +{ +} + diff --git a/src/manager/DnsCache.h b/src/manager/DnsCache.h index 7e864de21d..a3720fc34d 100644 --- a/src/manager/DnsCache.h +++ b/src/manager/DnsCache.h @@ -35,6 +35,11 @@ struct DnsCacheValue struct addrinfo *addrinfo; int64_t confident_time; int64_t expire_time; + + bool delayed() const + { + return addrinfo->ai_flags & 2; + } }; // RAII: NO. Release handle by user @@ -47,27 +52,10 @@ class DnsCache using DnsHandle = LRUHandle; public: - // release handle by get/put - void release(DnsHandle *handle) - { - std::lock_guard lock(mutex_); - cache_pool_.release(handle); - } - - void release(const DnsHandle *handle) - { - std::lock_guard lock(mutex_); - cache_pool_.release(handle); - } - // get handler // Need call release when handle no longer needed //Handle *get(const KEY &key); - const DnsHandle *get(const HostPort& host_port) - { - std::lock_guard lock(mutex_); - return cache_pool_.get(host_port); - } + const DnsHandle *get(const HostPort& host_port); const DnsHandle *get(const std::string& host, unsigned short port) { @@ -132,12 +120,11 @@ class DnsCache return put(std::string(host), port, addrinfo, dns_ttl_default, dns_ttl_min); } + // release handle by get/put + void release(const DnsHandle *handle); + // delete from cache, deleter delay called when all inuse-handle release. - void del(const HostPort& key) - { - std::lock_guard lock(mutex_); - cache_pool_.del(key); - } + void del(const HostPort& key); void del(const std::string& host, unsigned short port) { @@ -161,14 +148,22 @@ class DnsCache { struct addrinfo *ai = value.addrinfo; - if (ai && (ai->ai_flags & AI_PASSIVE)) - freeaddrinfo(ai); - else - protocol::DnsUtil::freeaddrinfo(ai); + if (ai) + { + if (ai->ai_flags) + freeaddrinfo(ai); + else + protocol::DnsUtil::freeaddrinfo(ai); + } } }; LRUCache cache_pool_; + +public: + // To prevent inline calling LRUCache's constructor and deconstructor. + DnsCache(); + ~DnsCache(); }; #endif diff --git a/src/manager/EndpointParams.h b/src/manager/EndpointParams.h index aa3f892831..cd264c8faa 100644 --- a/src/manager/EndpointParams.h +++ b/src/manager/EndpointParams.h @@ -20,6 +20,7 @@ #define _ENDPOINTPARAMS_H_ #include +#include "PlatformSocket.h" /** * @file EndpointParams.h @@ -37,6 +38,7 @@ enum TransportType struct EndpointParams { + int address_family; size_t max_connections; int connect_timeout; int response_timeout; @@ -46,6 +48,7 @@ struct EndpointParams static constexpr struct EndpointParams ENDPOINT_PARAMS_DEFAULT = { +/* address_family = */ AF_INET, /* .max_connections = */ 200, /* .connect_timeout = */ 10 * 1000, /* .response_timeout = */ 10 * 1000, diff --git a/src/manager/RouteManager.cc b/src/manager/RouteManager.cc index 8a99f33554..526ffea076 100644 --- a/src/manager/RouteManager.cc +++ b/src/manager/RouteManager.cc @@ -76,7 +76,7 @@ class RouteTargetSCTP : public RouteManager::RouteTarget }; /* To support TLS SNI. */ -class RouteTargetSNI : public RouteManager::RouteTarget +class RouteTargetTCPSNI : public RouteTargetTCP { private: virtual int init_ssl(SSL *ssl) @@ -91,7 +91,27 @@ class RouteTargetSNI : public RouteManager::RouteTarget std::string hostname; public: - RouteTargetSNI(const std::string& name) : hostname(name) + RouteTargetTCPSNI(const std::string& name) : hostname(name) + { + } +}; + +class RouteTargetSCTPSNI : public RouteTargetSCTP +{ +private: + virtual int init_ssl(SSL *ssl) + { + if (SSL_set_tlsext_host_name(ssl, this->hostname.c_str()) > 0) + return 0; + else + return -1; + } + +private: + std::string hostname; + +public: + RouteTargetSCTPSNI(const std::string& name) : hostname(name) { } }; @@ -101,11 +121,11 @@ class RouteTargetSNI : public RouteManager::RouteTarget struct RouteParams { - TransportType transport_type; + enum TransportType transport_type; const struct addrinfo *addrinfo; uint64_t key; SSL_CTX *ssl_ctx; - unsigned int max_connections; + size_t max_connections; int connect_timeout; int response_timeout; int ssl_connect_timeout; @@ -120,7 +140,7 @@ class RouteResultEntry CommSchedObject *request_object; CommSchedGroup *group; std::mutex mutex; - std::vector targets; + std::vector targets; struct list_head breaker_list; uint64_t key; int nleft; @@ -139,34 +159,35 @@ class RouteResultEntry int init(const struct RouteParams *params); void deinit(); - void notify_unavailable(CommSchedTarget *target); - void notify_available(CommSchedTarget *target); + void notify_unavailable(RouteManager::RouteTarget *target); + void notify_available(RouteManager::RouteTarget *target); void check_breaker(); private: void free_list(); - CommSchedTarget *create_target(const struct RouteParams *params, - const struct addrinfo *addrinfo); + RouteManager::RouteTarget *create_target(const struct RouteParams *params, + const struct addrinfo *addrinfo); int add_group_targets(const struct RouteParams *params); }; struct __breaker_node { - CommSchedTarget *target; + RouteManager::RouteTarget *target; int64_t timeout; struct list_head breaker_list; }; -CommSchedTarget *RouteResultEntry::create_target(const struct RouteParams *params, - const struct addrinfo *addr) +RouteManager::RouteTarget * +RouteResultEntry::create_target(const struct RouteParams *params, + const struct addrinfo *addr) { - CommSchedTarget *target; + RouteManager::RouteTarget *target; switch (params->transport_type) { case TT_TCP_SSL: if (params->use_tls_sni) - target = new RouteTargetSNI(params->hostname); + target = new RouteTargetTCPSNI(params->hostname); else case TT_TCP: target = new RouteTargetTCP(); @@ -174,16 +195,19 @@ CommSchedTarget *RouteResultEntry::create_target(const struct RouteParams *param case TT_UDP: target = new RouteTargetUDP(); break; - case TT_SCTP: case TT_SCTP_SSL: - target = new RouteTargetSCTP(); + if (params->use_tls_sni) + target = new RouteTargetSCTPSNI(params->hostname); + else + case TT_SCTP: + target = new RouteTargetSCTP(); break; default: errno = EINVAL; return NULL; } - if (target->init(addr->ai_addr, (socklen_t)addr->ai_addrlen, params->ssl_ctx, + if (target->init(addr->ai_addr, addr->ai_addrlen, params->ssl_ctx, params->connect_timeout, params->ssl_connect_timeout, params->response_timeout, params->max_connections) < 0) { @@ -197,7 +221,7 @@ CommSchedTarget *RouteResultEntry::create_target(const struct RouteParams *param int RouteResultEntry::init(const struct RouteParams *params) { const struct addrinfo *addr = params->addrinfo; - CommSchedTarget *target; + RouteManager::RouteTarget *target; if (addr == NULL)//0 { @@ -238,8 +262,8 @@ int RouteResultEntry::init(const struct RouteParams *params) int RouteResultEntry::add_group_targets(const struct RouteParams *params) { + RouteManager::RouteTarget *target; const struct addrinfo *addr; - CommSchedTarget *target; for (addr = params->addrinfo; addr; addr = addr->ai_next) { @@ -298,7 +322,7 @@ void RouteResultEntry::deinit() } } -void RouteResultEntry::notify_unavailable(CommSchedTarget *target) +void RouteResultEntry::notify_unavailable(RouteManager::RouteTarget *target) { if (this->targets.size() <= 1) return; @@ -324,7 +348,7 @@ void RouteResultEntry::notify_unavailable(CommSchedTarget *target) this->nleft--; } -void RouteResultEntry::notify_available(CommSchedTarget *target) +void RouteResultEntry::notify_available(RouteManager::RouteTarget *target) { if (this->targets.size() <= 1 || this->nbreak == 0) return; @@ -396,23 +420,26 @@ static uint64_t __fnv_hash(const unsigned char *data, size_t size) return hash; } -static uint64_t __generate_key(TransportType type, +static uint64_t __generate_key(enum TransportType type, const struct addrinfo *addrinfo, const std::string& other_info, const struct EndpointParams *ep_params, - const std::string& hostname) + const std::string& hostname, + SSL_CTX *ssl_ctx) { - std::string buf((const char *)&type, sizeof (TransportType)); - unsigned int max_conn = ep_params->max_connections; + const int params[] = { + ep_params->address_family, (int)ep_params->max_connections, + ep_params->connect_timeout, ep_params->response_timeout + }; + std::string buf((const char *)&type, sizeof (enum TransportType)); if (!other_info.empty()) buf += other_info; - buf.append((const char *)&max_conn, sizeof (unsigned int)); - buf.append((const char *)&ep_params->connect_timeout, sizeof (int)); - buf.append((const char *)&ep_params->response_timeout, sizeof (int)); - if (type == TT_TCP_SSL) + buf.append((const char *)params, sizeof params); + if (type == TT_TCP_SSL || type == TT_SCTP_SSL) { + buf.append((const char *)&ssl_ctx, sizeof (void *)); buf.append((const char *)&ep_params->ssl_connect_timeout, sizeof (int)); if (ep_params->use_tls_sni) { @@ -462,15 +489,24 @@ RouteManager::~RouteManager() } } -int RouteManager::get(TransportType type, +int RouteManager::get(enum TransportType type, const struct addrinfo *addrinfo, const std::string& other_info, const struct EndpointParams *ep_params, - const std::string& hostname, + const std::string& hostname, SSL_CTX *ssl_ctx, RouteResult& result) { - uint64_t key = __generate_key(type, addrinfo, other_info, - ep_params, hostname); + if (type == TT_TCP_SSL || type == TT_SCTP_SSL) + { + static SSL_CTX *global_client_ctx = WFGlobal::get_ssl_client_ctx(); + if (ssl_ctx == NULL) + ssl_ctx = global_client_ctx; + } + else + ssl_ctx = NULL; + + uint64_t key = __generate_key(type, addrinfo, other_info, ep_params, + hostname, ssl_ctx); struct rb_node **p = &cache_.rb_node; struct rb_node *parent = NULL; RouteResultEntry *bound = NULL; @@ -497,37 +533,19 @@ int RouteManager::get(TransportType type, } else { - int ssl_connect_timeout = 0; - SSL_CTX *ssl_ctx = NULL; - - if (type == TT_TCP_SSL || type == TT_SCTP_SSL) - { - static SSL_CTX *client_ssl_ctx = WFGlobal::get_ssl_client_ctx(); - - ssl_ctx = client_ssl_ctx; - ssl_connect_timeout = ep_params->ssl_connect_timeout; - } - struct RouteParams params = { - /* .transport_type = */ type, - /* .addrinfo = */ addrinfo, - /* .key = */ key, - /* .ssl_ctx = */ ssl_ctx, - /* .max_connections = */ (unsigned int)ep_params->max_connections, - /* .connect_timeout = */ ep_params->connect_timeout, - /* .response_timeout = */ ep_params->response_timeout, - /* .ssl_connect_timeout = */ ssl_connect_timeout, - /* .use_tls_sni = */ ep_params->use_tls_sni, - /* .hostname = */ hostname, + /*.transport_type =*/ type, + /*.addrinfo =*/ addrinfo, + /*.key =*/ key, + /*.ssl_ctx =*/ ssl_ctx, + /*.max_connections =*/ ep_params->max_connections, + /*.connect_timeout =*/ ep_params->connect_timeout, + /*.response_timeout =*/ ep_params->response_timeout, + /*.ssl_connect_timeout =*/ ep_params->ssl_connect_timeout, + /*.use_tls_sni =*/ ep_params->use_tls_sni, + /*.hostname =*/ hostname, }; - if (StringUtil::start_with(other_info, "?maxconn=")) - { - int maxconn = atoi(other_info.c_str() + 9); - if (maxconn > 0) - params.max_connections = maxconn; - } - entry = new RouteResultEntry; if (entry->init(¶ms) >= 0) { @@ -549,12 +567,12 @@ int RouteManager::get(TransportType type, void RouteManager::notify_unavailable(void *cookie, CommTarget *target) { if (cookie && target) - ((RouteResultEntry *)cookie)->notify_unavailable((CommSchedTarget *)target); + ((RouteResultEntry *)cookie)->notify_unavailable((RouteTarget *)target); } void RouteManager::notify_available(void *cookie, CommTarget *target) { if (cookie && target) - ((RouteResultEntry *)cookie)->notify_available((CommSchedTarget *)target); + ((RouteResultEntry *)cookie)->notify_available((RouteTarget *)target); } diff --git a/src/manager/RouteManager.h b/src/manager/RouteManager.h index 2271f78724..3265196d54 100644 --- a/src/manager/RouteManager.h +++ b/src/manager/RouteManager.h @@ -43,6 +43,32 @@ class RouteManager class RouteTarget : public CommSchedTarget { +#if OPENSSL_VERSION_NUMBER >= 0x10100000L + public: + int init(const struct sockaddr *addr, socklen_t addrlen, SSL_CTX *ssl_ctx, + int connect_timeout, int ssl_connect_timeout, int response_timeout, + size_t max_connections) + { + int ret = this->CommSchedTarget::init(addr, addrlen, ssl_ctx, + connect_timeout, ssl_connect_timeout, + response_timeout, max_connections); + + if (ret >= 0 && ssl_ctx) + SSL_CTX_up_ref(ssl_ctx); + + return ret; + } + + void deinit() + { + SSL_CTX *ssl_ctx = this->get_ssl_ctx(); + + this->CommSchedTarget::deinit(); + if (ssl_ctx) + SSL_CTX_free(ssl_ctx); + } +#endif + public: int state; @@ -57,11 +83,11 @@ class RouteManager }; public: - int get(TransportType type, + int get(enum TransportType type, const struct addrinfo *addrinfo, const std::string& other_info, const struct EndpointParams *ep_params, - const std::string& hostname, + const std::string& hostname, SSL_CTX *ssl_ctx, RouteResult& result); RouteManager() diff --git a/src/nameservice/WFDnsResolver.cc b/src/nameservice/WFDnsResolver.cc index ea85bdbc08..4279e1922d 100644 --- a/src/nameservice/WFDnsResolver.cc +++ b/src/nameservice/WFDnsResolver.cc @@ -40,35 +40,21 @@ #define HOSTS_LINEBUF_INIT_SIZE 128 #define PORT_STR_MAX 5 -static constexpr struct addrinfo __ai_hints = -{ -#ifdef AI_ADDRCONFIG - /*.ai_flags = */ AI_ADDRCONFIG, -#else - /*.ai_flags = */ 0, -#endif - /*.ai_family = */ AF_UNSPEC, - /*.ai_socktype = */ SOCK_STREAM, - /*.ai_protocol = */ 0, - /*.ai_addrlen = */ 0, - /*.ai_addr = */ NULL, - /*.ai_canonname = */ NULL, - /*.ai_next = */ NULL -}; - class DnsInput { public: DnsInput() : port_(0), - numeric_host_(false) + numeric_host_(false), + family_(AF_UNSPEC) {} DnsInput(const std::string& host, unsigned short port, - bool numeric_host) : + bool numeric_host, int family) : host_(host), port_(port), - numeric_host_(numeric_host) + numeric_host_(numeric_host), + family_(family) {} void reset(const std::string& host, unsigned short port) @@ -76,14 +62,16 @@ class DnsInput host_.assign(host); port_ = port; numeric_host_ = false; + family_ = AF_UNSPEC; } void reset(const std::string& host, unsigned short port, - bool numeric_host) + bool numeric_host, int family) { host_.assign(host); port_ = port; numeric_host_ = numeric_host; + family_ = family; } const std::string& get_host() const { return host_; } @@ -94,6 +82,7 @@ class DnsInput std::string host_; unsigned short port_; bool numeric_host_; + int family_; friend class DnsRoutine; }; @@ -109,7 +98,12 @@ class DnsOutput ~DnsOutput() { if (addrinfo_) - freeaddrinfo(addrinfo_); + { + if (addrinfo_->ai_flags) + freeaddrinfo(addrinfo_); + else + free(addrinfo_); + } } int get_error() const { return error_; } @@ -137,7 +131,12 @@ class DnsRoutine static void create(DnsOutput *out, int error, struct addrinfo *ai) { if (out->addrinfo_) - freeaddrinfo(out->addrinfo_); + { + if (out->addrinfo_->ai_flags) + freeaddrinfo(out->addrinfo_); + else + free(out->addrinfo_); + } out->error_ = error; out->addrinfo_ = ai; @@ -146,21 +145,21 @@ class DnsRoutine void DnsRoutine::run(const DnsInput *in, DnsOutput *out) { - if (!in->host_.empty() && in->host_[0] == '/') - return; - - struct addrinfo hints = __ai_hints; + struct addrinfo hints = { + .ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV, + .ai_family = in->family_, + .ai_socktype = SOCK_STREAM, + }; char port_str[PORT_STR_MAX + 1]; - hints.ai_flags |= AI_NUMERICSERV; if (in->is_numeric_host()) hints.ai_flags |= AI_NUMERICHOST; snprintf(port_str, PORT_STR_MAX + 1, "%u", in->port_); - out->error_ = getaddrinfo(in->host_.c_str(), - port_str, - &hints, - &out->addrinfo_); + out->error_ = getaddrinfo(in->host_.c_str(), port_str, + &hints, &out->addrinfo_); + if (out->error_ == 0) + out->addrinfo_->ai_flags = 1; } // Dns Thread task. For internal usage only. @@ -178,13 +177,18 @@ struct DnsContext static int __default_family() { + struct addrinfo hints = { + .ai_flags = AI_ADDRCONFIG, + .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM, + }; struct addrinfo *res; struct addrinfo *cur; int family = AF_UNSPEC; bool v4 = false; bool v6 = false; - if (getaddrinfo(NULL, "1", &__ai_hints, &res) == 0) + if (getaddrinfo(NULL, "1", &hints, &res) == 0) { for (cur = res; cur; cur = cur->ai_next) { @@ -246,45 +250,38 @@ static int __readaddrinfo_line(char *p, const char *name, const char *port, return 1; } -#include static int __readaddrinfo(const char *path, const char *name, unsigned short port, const struct addrinfo *hints, struct addrinfo **res) { char port_str[PORT_STR_MAX + 1]; - std::string line; - // 1024 may be enough for one line - char buffer[1024]; + size_t bufsize = 0; + char *line = NULL; int count = 0; - struct addrinfo h; int errno_bak; + FILE *fp; int ret; - std::ifstream ifs; - ifs.open(path, ifs.in); - if (!ifs.is_open()) - return /*EAI_SYSTEM*/ EAI_FAIL; + fp = fopen(path, "r"); + if (!fp) + return EAI_FAIL; - h = *hints; - h.ai_flags |= AI_NUMERICSERV | AI_NUMERICHOST, snprintf(port_str, PORT_STR_MAX + 1, "%u", port); errno_bak = errno; - while (!(std::getline(ifs, line)).eof() && !line.empty()) + while ((ret = getline(&line, &bufsize, fp)) > 0) { - std::cout << line << std::endl; - line.copy(buffer, sizeof(buffer) / sizeof(buffer[0])); - buffer[line.length()] = '\0'; - if (__readaddrinfo_line(buffer, name, port_str, &h, res) == 0) + if (__readaddrinfo_line(line, name, port_str, hints, res) == 0) { count++; res = &(*res)->ai_next; } } - ret = ifs.bad() ? /*EAI_SYSTEM*/ EAI_FAIL : EAI_NONAME; - ifs.close(); + ret = ferror(fp) ? EAI_FAIL : EAI_NONAME; + free(line); + fclose(fp); if (count != 0) { errno = errno_bak; @@ -294,18 +291,9 @@ static int __readaddrinfo(const char *path, return ret; } -// Add AI_PASSIVE to point that this addrinfo is alloced by getaddrinfo -static void __add_passive_flags(struct addrinfo *ai) -{ - while (ai) - { - ai->ai_flags |= AI_PASSIVE; - ai = ai->ai_next; - } -} - static ThreadDnsTask *__create_thread_dns_task(const std::string& host, unsigned short port, + int family, thread_dns_callback_t callback) { auto *task = WFThreadTaskFactory:: @@ -314,12 +302,46 @@ static ThreadDnsTask *__create_thread_dns_task(const std::string& host, DnsRoutine::run, std::move(callback)); - task->get_input()->reset(host, port); + task->get_input()->reset(host, port, false, family); return task; } +static std::string __get_cache_host(const std::string& hostname, + int family) +{ + char c; + + if (family == AF_UNSPEC) + c = '*'; + else if (family == AF_INET) + c = '4'; + else if (family == AF_INET6) + c = '6'; + else + c = '?'; + + return hostname + c; +} + +static std::string __get_guard_name(const std::string& cache_host, + unsigned short port) +{ + std::string guard_name("INTERNAL-dns:"); + guard_name.append(cache_host).append(":"); + guard_name.append(std::to_string(port)); + return guard_name; +} + void WFResolverTask::dispatch() { + if (this->msg_) + { + this->state = WFT_STATE_DNS_ERROR; + this->error = (intptr_t)msg_; + this->subtask_done(); + return; + } + const ParsedURI& uri = ns_params_.uri; host_ = uri.host ? uri.host : ""; port_ = uri.port ? atoi(uri.port) : 0; @@ -327,11 +349,22 @@ void WFResolverTask::dispatch() DnsCache *dns_cache = WFGlobal::get_dns_cache(); const DnsCache::DnsHandle *addr_handle; std::string hostname = host_; + int family = ep_params_.address_family; + std::string cache_host = __get_cache_host(hostname, family); if (ns_params_.retry_times == 0) - addr_handle = dns_cache->get_ttl(hostname, port_); + addr_handle = dns_cache->get_ttl(cache_host, port_); else - addr_handle = dns_cache->get_confident(hostname, port_); + addr_handle = dns_cache->get_confident(cache_host, port_); + + if (in_guard_ && (addr_handle == NULL || addr_handle->value.delayed())) + { + if (addr_handle) + dns_cache->release(addr_handle); + + this->request_dns(); + return; + } if (addr_handle) { @@ -347,7 +380,8 @@ void WFResolverTask::dispatch() } if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info, - &ep_params_, hostname, this->result) < 0) + &ep_params_, hostname, ns_params_.ssl_ctx, + this->result) < 0) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; @@ -378,11 +412,11 @@ void WFResolverTask::dispatch() if (ret == 1) { - DnsInput dns_in(hostname, port_, true); // 'true' means numeric host + // 'true' means numeric host + DnsInput dns_in(hostname, port_, true, AF_UNSPEC); DnsOutput dns_out; DnsRoutine::run(&dns_in, &dns_out); - __add_passive_flags((struct addrinfo *)dns_out.get_addrinfo()); dns_callback_internal(&dns_out, (unsigned int)-1, (unsigned int)-1); this->subtask_done(); return; @@ -392,32 +426,53 @@ void WFResolverTask::dispatch() const char *hosts = WFGlobal::get_global_settings()->hosts_path; if (hosts) { + struct addrinfo hints = { + .ai_flags = AI_ADDRCONFIG | AI_NUMERICSERV | AI_NUMERICHOST, + .ai_family = ep_params_.address_family, + .ai_socktype = SOCK_STREAM, + }; struct addrinfo *ai; - int ret = __readaddrinfo(hosts, host_, port_, &__ai_hints, &ai); + int ret; + ret = __readaddrinfo(hosts, host_, port_, &hints, &ai); if (ret == 0) { DnsOutput out; DnsRoutine::create(&out, ret, ai); - __add_passive_flags((struct addrinfo *)out.get_addrinfo()); dns_callback_internal(&out, dns_ttl_default_, dns_ttl_min_); this->subtask_done(); return; } } + std::string guard_name = __get_guard_name(cache_host, port_); + WFConditional *guard = WFTaskFactory::create_guard(guard_name, this, &msg_); + + in_guard_ = true; + has_next_ = true; + + series_of(this)->push_front(guard); + this->subtask_done(); +} + +void WFResolverTask::request_dns() +{ WFDnsClient *client = WFGlobal::get_dns_client(); if (client) { - static int family = __default_family(); + static int default_family = __default_family(); WFResourcePool *respool = WFGlobal::get_dns_respool(); + int family = ep_params_.address_family; + if (family == AF_UNSPEC) + family = default_family; + if (family == AF_INET || family == AF_INET6) { auto&& cb = std::bind(&WFResolverTask::dns_single_callback, this, std::placeholders::_1); - WFDnsTask *dns_task = client->create_dns_task(hostname, std::move(cb)); + WFDnsTask *dns_task = client->create_dns_task(host_, std::move(cb)); if (family == AF_INET6) dns_task->get_req()->set_question_type(DNS_TYPE_AAAA); @@ -437,10 +492,10 @@ void WFResolverTask::dispatch() dctx[0].port = port_; dctx[1].port = port_; - task_v4 = client->create_dns_task(hostname, dns_partial_callback); + task_v4 = client->create_dns_task(host_, dns_partial_callback); task_v4->user_data = dctx; - task_v6 = client->create_dns_task(hostname, dns_partial_callback); + task_v6 = client->create_dns_task(host_, dns_partial_callback); task_v6->get_req()->set_question_type(DNS_TYPE_AAAA); task_v6->user_data = dctx + 1; @@ -461,11 +516,13 @@ void WFResolverTask::dispatch() } else { + ThreadDnsTask *dns_task; auto&& cb = std::bind(&WFResolverTask::thread_dns_callback, this, std::placeholders::_1); - ThreadDnsTask *dns_task = __create_thread_dns_task(hostname, port_, - std::move(cb)); + dns_task = __create_thread_dns_task(host_, port_, + ep_params_.address_family, + std::move(cb)); series_of(this)->push_front(dns_task); } @@ -478,12 +535,7 @@ SubTask *WFResolverTask::done() SeriesWork *series = series_of(this); if (!has_next_) - { - if (this->callback) - this->callback(this); - - delete this; - } + task_callback(); else has_next_ = false; @@ -499,7 +551,7 @@ void WFResolverTask::dns_callback_internal(void *thrd_dns_output, if (dns_error) { - if (dns_error == /*EAI_SYSTEM*/ EAI_FAIL) + if (dns_error == EAI_FAIL) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; @@ -517,12 +569,15 @@ void WFResolverTask::dns_callback_internal(void *thrd_dns_output, struct addrinfo *addrinfo = dns_out->move_addrinfo(); const DnsCache::DnsHandle *addr_handle; std::string hostname = host_; + int family = ep_params_.address_family; + std::string cache_host = __get_cache_host(hostname, family); - addr_handle = dns_cache->put(hostname, port_, addrinfo, + addr_handle = dns_cache->put(cache_host, port_, addrinfo, (unsigned int)ttl_default, (unsigned int)ttl_min); if (route_manager->get(ns_params_.type, addrinfo, ns_params_.info, - &ep_params_, hostname, this->result) < 0) + &ep_params_, hostname, ns_params_.ssl_ctx, + this->result) < 0) { this->state = WFT_STATE_SYS_ERROR; this->error = errno; @@ -555,10 +610,7 @@ void WFResolverTask::dns_single_callback(void *net_dns_task) this->error = dns_task->get_error(); } - if (this->callback) - this->callback(this); - - delete this; + task_callback(); } void WFResolverTask::dns_partial_callback(void *net_dns_task) @@ -618,10 +670,7 @@ void WFResolverTask::dns_parallel_callback(const void *parallel) delete[] c4; - if (this->callback) - this->callback(this); - - delete this; + task_callback(); } void WFResolverTask::thread_dns_callback(void *thrd_dns_task) @@ -631,7 +680,6 @@ void WFResolverTask::thread_dns_callback(void *thrd_dns_task) if (dns_task->get_state() == WFT_STATE_SUCCESS) { DnsOutput *out = dns_task->get_output(); - __add_passive_flags((struct addrinfo *)out->get_addrinfo()); dns_callback_internal(out, dns_ttl_default_, dns_ttl_min_); } else @@ -640,6 +688,23 @@ void WFResolverTask::thread_dns_callback(void *thrd_dns_task) this->error = dns_task->get_error(); } + task_callback(); +} + +void WFResolverTask::task_callback() +{ + if (in_guard_) + { + int family = ep_params_.address_family; + std::string cache_host = __get_cache_host(host_, family); + std::string guard_name = __get_guard_name(cache_host, port_); + + if (this->state == WFT_STATE_DNS_ERROR) + msg_ = (void *)(intptr_t)this->error; + + WFTaskFactory::release_guard_safe(guard_name, msg_); + } + if (this->callback) this->callback(this); diff --git a/src/nameservice/WFDnsResolver.h b/src/nameservice/WFDnsResolver.h index d0272141d7..4584d96aca 100644 --- a/src/nameservice/WFDnsResolver.h +++ b/src/nameservice/WFDnsResolver.h @@ -35,9 +35,14 @@ class WFResolverTask : public WFRouterTask ns_params_(*ns_params), ep_params_(*ep_params) { + if (ns_params_.fixed_conn) + ep_params_.max_connections = 1; + dns_ttl_default_ = dns_ttl_default; dns_ttl_min_ = dns_ttl_min; has_next_ = false; + in_guard_ = false; + msg_ = NULL; } WFResolverTask(const struct WFNSParams *ns_params, @@ -45,7 +50,12 @@ class WFResolverTask : public WFRouterTask WFRouterTask(std::move(cb)), ns_params_(*ns_params) { + if (ns_params_.fixed_conn) + ep_params_.max_connections = 1; + has_next_ = false; + in_guard_ = false; + msg_ = NULL; } protected: @@ -62,6 +72,9 @@ class WFResolverTask : public WFRouterTask unsigned int ttl_default, unsigned int ttl_min); + void request_dns(); + void task_callback(); + protected: struct WFNSParams ns_params_; unsigned int dns_ttl_default_; @@ -72,6 +85,8 @@ class WFResolverTask : public WFRouterTask const char *host_; unsigned short port_; bool has_next_; + bool in_guard_; + void *msg_; }; class WFDnsResolver : public WFNSPolicy diff --git a/src/nameservice/WFNameService.h b/src/nameservice/WFNameService.h index a6ef8e3f14..268f46dc5c 100644 --- a/src/nameservice/WFNameService.h +++ b/src/nameservice/WFNameService.h @@ -74,10 +74,12 @@ class WFNSTracing struct WFNSParams { - TransportType type; + enum TransportType type; ParsedURI& uri; const char *info; + SSL_CTX *ssl_ctx; bool fixed_addr; + bool fixed_conn; int retry_times; WFNSTracing *tracing; }; diff --git a/src/protocol/dns_parser.c b/src/protocol/dns_parser.c index 901ae2032f..ab3bdb7c45 100644 --- a/src/protocol/dns_parser.c +++ b/src/protocol/dns_parser.c @@ -25,7 +25,6 @@ #define DNS_LABELS_MAX 63 #define DNS_NAMES_MAX 256 #define DNS_MSGBASE_INIT_SIZE 514 // 512 + 2(leading length) -#define DNS_HEADER_SIZE sizeof (struct dns_header) #define MAX(x, y) ((x) <= (y) ? (y) : (x)) struct __dns_record_entry @@ -102,11 +101,16 @@ static int __dns_parser_parse_host(char *phost, dns_parser_t *parser) else if ((len & 0xC0) == 0xC0) { pointer = __dns_parser_uint16(*cur) & 0x3FFF; - *cur += 2; if (pointer >= parser->msgsize) return -2; + // pointer must point to a prior position + if ((const char *)parser->msgbase + pointer >= *cur) + return -2; + + *cur += 2; + // backup cur only when the first pointer occurs if (curbackup == NULL) curbackup = *cur; @@ -707,7 +711,7 @@ void dns_parser_init(dns_parser_t *parser) parser->bufsize = 0; parser->complete = 0; parser->single_packet = 0; - memset(&parser->header, 0, DNS_HEADER_SIZE); + memset(&parser->header, 0, sizeof (struct dns_header)); memset(&parser->question, 0, sizeof (struct dns_question)); INIT_LIST_HEAD(&parser->answer_list); INIT_LIST_HEAD(&parser->authority_list); @@ -770,16 +774,16 @@ int dns_parser_parse_all(dns_parser_t *parser) parser->cur = (const char *)parser->msgbase; h = &parser->header; - if (parser->msgsize < DNS_HEADER_SIZE) + if (parser->msgsize < sizeof (struct dns_header)) return -2; - memcpy(h, parser->msgbase, DNS_HEADER_SIZE); + memcpy(h, parser->msgbase, sizeof (struct dns_header)); h->id = ntohs(h->id); h->qdcount = ntohs(h->qdcount); h->ancount = ntohs(h->ancount); h->nscount = ntohs(h->nscount); h->arcount = ntohs(h->arcount); - parser->cur += DNS_HEADER_SIZE; + parser->cur += sizeof (struct dns_header); ret = __dns_parser_parse_question(parser); if (ret < 0) @@ -911,6 +915,186 @@ int dns_record_cursor_find_cname(const char *name, return 1; } +int dns_add_raw_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, uint16_t rlen, const void *rdata, + struct list_head *list) +{ + struct __dns_record_entry *entry; + size_t entry_size = sizeof (struct __dns_record_entry) + rlen; + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.name = strdup(name); + if (!entry->record.name) + { + free(entry); + return -1; + } + + entry->record.type = type; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + entry->record.rdlength = rlen; + entry->record.rdata = (void *)(entry + 1); + memcpy(entry->record.rdata, rdata, rlen); + list_add_tail(&entry->entry_list, list); + + return 0; +} + +int dns_add_str_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, const char *rdata, + struct list_head *list) +{ + size_t rlen = strlen(rdata); + // record.rdlength has no meaning for parsed record types, ignore its + // correctness, same for soa/srv/mx record + return dns_add_raw_record(name, type, rclass, ttl, rlen+1, rdata, list); +} + +int dns_add_soa_record(const char *name, uint16_t rclass, uint32_t ttl, + const char *mname, const char *rname, + uint32_t serial, int32_t refresh, + int32_t retry, int32_t expire, uint32_t minimum, + struct list_head *list) +{ + struct __dns_record_entry *entry; + struct dns_record_soa *soa; + size_t entry_size; + char *pname, *pmname, *prname; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_soa); + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + entry->record.rdlength = 0; + soa = (struct dns_record_soa *)(entry->record.rdata); + + pname = strdup(name); + pmname = strdup(mname); + prname = strdup(rname); + + if (!pname || !pmname || !prname) + { + free(pname); + free(pmname); + free(prname); + free(entry); + return -1; + } + + soa->mname = pmname; + soa->rname = prname; + soa->serial = serial; + soa->refresh = refresh; + soa->retry = retry; + soa->expire = expire; + soa->minimum = minimum; + + entry->record.name = pname; + entry->record.type = DNS_TYPE_SOA; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + list_add_tail(&entry->entry_list, list); + + return 0; +} + +int dns_add_srv_record(const char *name, uint16_t rclass, uint32_t ttl, + uint16_t priority, uint16_t weight, + uint16_t port, const char *target, + struct list_head *list) +{ + struct __dns_record_entry *entry; + struct dns_record_srv *srv; + size_t entry_size; + char *pname, *ptarget; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_srv); + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + entry->record.rdlength = 0; + srv = (struct dns_record_srv *)(entry->record.rdata); + + pname = strdup(name); + ptarget = strdup(target); + + if (!pname || !ptarget) + { + free(pname); + free(ptarget); + free(entry); + return -1; + } + + srv->priority = priority; + srv->weight = weight; + srv->port = port; + srv->target = ptarget; + + entry->record.name = pname; + entry->record.type = DNS_TYPE_SRV; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + list_add_tail(&entry->entry_list, list); + + return 0; +} + +int dns_add_mx_record(const char *name, uint16_t rclass, uint32_t ttl, + int16_t preference, const char *exchange, + struct list_head *list) +{ + struct __dns_record_entry *entry; + struct dns_record_mx *mx; + size_t entry_size; + char *pname, *pexchange; + + entry_size = sizeof (struct __dns_record_entry) + + sizeof (struct dns_record_mx); + + entry = (struct __dns_record_entry *)malloc(entry_size); + if (!entry) + return -1; + + entry->record.rdata = (void *)(entry + 1); + entry->record.rdlength = 0; + mx = (struct dns_record_mx *)(entry->record.rdata); + + pname = strdup(name); + pexchange = strdup(exchange); + + if (!pname || !pexchange) + { + free(pname); + free(pexchange); + free(entry); + return -1; + } + + mx->preference = preference; + mx->exchange = pexchange; + + entry->record.name = pname; + entry->record.type = DNS_TYPE_MX; + entry->record.rclass = rclass; + entry->record.ttl = ttl; + list_add_tail(&entry->entry_list, list); + + return 0; +} + const char *dns_type2str(int type) { switch (type) diff --git a/src/protocol/dns_parser.h b/src/protocol/dns_parser.h index 892af6c7a1..e950767d5b 100644 --- a/src/protocol/dns_parser.h +++ b/src/protocol/dns_parser.h @@ -78,11 +78,19 @@ enum DNS_RCODE_REFUSED }; +enum +{ + DNS_ANSWER_SECTION = 1, + DNS_AUTHORITY_SECTION = 2, + DNS_ADDITIONAL_SECTION = 3, +}; + /** * dns_header_t is a struct to describe the header of a dns * request or response packet, but the byte order is not * transformed. */ +#pragma pack(1) struct dns_header { uint16_t id; @@ -112,6 +120,7 @@ struct dns_header uint16_t nscount; uint16_t arcount; }; +#pragma pack() struct dns_question { @@ -205,6 +214,29 @@ int dns_record_cursor_find_cname(const char *name, const char **cname, dns_record_cursor_t *cursor); +int dns_add_raw_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, uint16_t rlen, const void *rdata, + struct list_head *list); + +int dns_add_str_record(const char *name, uint16_t type, uint16_t rclass, + uint32_t ttl, const char *rdata, + struct list_head *list); + +int dns_add_soa_record(const char *name, uint16_t rclass, uint32_t ttl, + const char *mname, const char *rname, + uint32_t serial, int32_t refresh, + int32_t retry, int32_t expire, uint32_t minimum, + struct list_head *list); + +int dns_add_srv_record(const char *name, uint16_t rclass, uint32_t ttl, + uint16_t priority, uint16_t weight, + uint16_t port, const char *target, + struct list_head *list); + +int dns_add_mx_record(const char *name, uint16_t rclass, uint32_t ttl, + int16_t preference, const char *exchange, + struct list_head *list); + const char *dns_type2str(int type); const char *dns_class2str(int dnsclass); const char *dns_opcode2str(int opcode);