diff --git a/src/mysqlerl.c b/src/mysqlerl.c index c757f22..4932700 100644 --- a/src/mysqlerl.c +++ b/src/mysqlerl.c @@ -12,7 +12,7 @@ #include #include -const char *CONNECT_MSG = "connect"; +const char *CONNECT_MSG = "sql_connect"; const char *QUERY_MSG = "sql_query"; const char *PARAM_QUERY_MSG = "sql_param_query"; const char *SELECT_COUNT_MSG = "sql_select_count"; @@ -22,16 +22,10 @@ const char *LAST_MSG = "sql_last"; const char *NEXT_MSG = "sql_next"; const char *PREV_MSG = "sql_prev"; +MYSQL dbh; MYSQL_RES *results = NULL; my_ulonglong resultoffset = 0, numrows = 0; -void -usage() -{ - fprintf(stderr, "Usage: mysqlerl host port db_name user passwd\n"); - exit(1); -} - void set_mysql_results(MYSQL_RES *res) { @@ -158,7 +152,49 @@ handle_mysql_result() } void -handle_query(MYSQL *dbh, ETERM *cmd) +handle_connect(ETERM *msg) +{ + ETERM *resp, *tmp; + char *host, *db_name, *user, *passwd; + int port; + + tmp = erl_element(2, msg); + host = erl_iolist_to_string(tmp); + erl_free_term(tmp); + + tmp = erl_element(3, msg); + port = ERL_INT_VALUE(tmp); + erl_free_term(tmp); + + tmp = erl_element(4, msg); + db_name = erl_iolist_to_string(tmp); + erl_free_term(tmp); + + tmp = erl_element(5, msg); + user = erl_iolist_to_string(tmp); + erl_free_term(tmp); + + tmp = erl_element(6, msg); + passwd = erl_iolist_to_string(tmp); + erl_free_term(tmp); + + /* TODO: handle options, passed in next. */ + + logmsg("INFO: Connecting to %s on %s:%d as %s", db_name, host, port, user); + if (mysql_real_connect(&dbh, host, user, passwd, + db_name, port, NULL, 0) == NULL) { + logmsg("ERROR: Failed to connect to database %s as %s: %s.", + db_name, user, mysql_error(&dbh)); + exit(2); + } + + resp = erl_format("ok"); + write_msg(resp); + erl_free_term(resp); +} + +void +handle_query(ETERM *cmd) { ETERM *query, *resp; char *q; @@ -168,20 +204,20 @@ handle_query(MYSQL *dbh, ETERM *cmd) erl_free_term(query); logmsg("DEBUG: got query msg: %s.", q); - if (mysql_query(dbh, q)) { + if (mysql_query(&dbh, q)) { resp = erl_format("{error, {mysql_error, ~i, ~s}}", - mysql_errno(dbh), mysql_error(dbh)); + mysql_errno(&dbh), mysql_error(&dbh)); } else { - set_mysql_results(mysql_store_result(dbh)); + set_mysql_results(mysql_store_result(&dbh)); if (results) { resp = handle_mysql_result(results); set_mysql_results(NULL); } else { - if (mysql_field_count(dbh) == 0) - resp = erl_format("{updated, ~i}", mysql_affected_rows(dbh)); + if (mysql_field_count(&dbh) == 0) + resp = erl_format("{updated, ~i}", mysql_affected_rows(&dbh)); else resp = erl_format("{error, {mysql_error, ~i, ~s}}", - mysql_errno(dbh), mysql_error(dbh)); + mysql_errno(&dbh), mysql_error(&dbh)); } } erl_free(q); @@ -207,7 +243,7 @@ handle_query(MYSQL *dbh, ETERM *cmd) * {updated, 7} */ void -handle_param_query(MYSQL *dbh, ETERM *msg) +handle_param_query(ETERM *msg) { ETERM *query, *params; char *q; @@ -225,7 +261,7 @@ handle_param_query(MYSQL *dbh, ETERM *msg) } void -handle_select_count(MYSQL *dbh, ETERM *msg) +handle_select_count(ETERM *msg) { ETERM *query, *resp; char *q; @@ -235,19 +271,19 @@ handle_select_count(MYSQL *dbh, ETERM *msg) erl_free_term(query); logmsg("DEBUG: got select count msg: %s.", q); - if (mysql_query(dbh, q)) { + if (mysql_query(&dbh, q)) { resp = erl_format("{error, {mysql_error, ~i, ~s}}", - mysql_errno(dbh), mysql_error(dbh)); + mysql_errno(&dbh), mysql_error(&dbh)); } else { - set_mysql_results(mysql_store_result(dbh)); + set_mysql_results(mysql_store_result(&dbh)); if (results) { resp = erl_format("{ok, ~i}", numrows); } else { - if (mysql_field_count(dbh) == 0) - resp = erl_format("{ok, ~i}", mysql_affected_rows(dbh)); + if (mysql_field_count(&dbh) == 0) + resp = erl_format("{ok, ~i}", mysql_affected_rows(&dbh)); else resp = erl_format("{error, {mysql_error, ~i, ~s}}", - mysql_errno(dbh), mysql_error(dbh)); + mysql_errno(&dbh), mysql_error(&dbh)); } } erl_free(q); @@ -257,7 +293,7 @@ handle_select_count(MYSQL *dbh, ETERM *msg) } void -handle_select(MYSQL *dbh, ETERM *msg) +handle_select(ETERM *msg) { MYSQL_FIELD *fields; ETERM *epos, *enum_items, *ecols, *erows, *resp; @@ -297,7 +333,7 @@ handle_select(MYSQL *dbh, ETERM *msg) } void -handle_first(MYSQL *dbh, ETERM *msg) +handle_first(ETERM *msg) { MYSQL_FIELD *fields; ETERM *ecols, *erows, *resp; @@ -325,7 +361,7 @@ handle_first(MYSQL *dbh, ETERM *msg) } void -handle_last(MYSQL *dbh, ETERM *msg) +handle_last(ETERM *msg) { MYSQL_FIELD *fields; ETERM *ecols, *erows, *resp; @@ -353,7 +389,7 @@ handle_last(MYSQL *dbh, ETERM *msg) } void -handle_next(MYSQL *dbh, ETERM *msg) +handle_next(ETERM *msg) { MYSQL_FIELD *fields; ETERM *ecols, *erows, *resp; @@ -384,7 +420,7 @@ handle_next(MYSQL *dbh, ETERM *msg) } void -handle_prev(MYSQL *dbh, ETERM *msg) +handle_prev(ETERM *msg) { MYSQL_FIELD *fields; ETERM *ecols, *erows, *resp; @@ -421,29 +457,31 @@ handle_prev(MYSQL *dbh, ETERM *msg) } void -dispatch_db_cmd(MYSQL *dbh, ETERM *msg) +dispatch_db_cmd(ETERM *msg) { ETERM *tag; char *tag_name; tag = erl_element(1, msg); tag_name = (char *)ERL_ATOM_PTR(tag); - if (strncmp(tag_name, QUERY_MSG, strlen(QUERY_MSG)) == 0) - handle_query(dbh, msg); + if (strncmp(tag_name, CONNECT_MSG, strlen(CONNECT_MSG)) == 0) + handle_connect(msg); + else if (strncmp(tag_name, QUERY_MSG, strlen(QUERY_MSG)) == 0) + handle_query(msg); else if (strncmp(tag_name, PARAM_QUERY_MSG, strlen(PARAM_QUERY_MSG)) == 0) - handle_param_query(dbh, msg); + handle_param_query(msg); else if (strncmp(tag_name, SELECT_COUNT_MSG, strlen(SELECT_COUNT_MSG)) == 0) - handle_select_count(dbh, msg); + handle_select_count(msg); else if (strncmp(tag_name, SELECT_MSG, strlen(SELECT_MSG)) == 0) - handle_select(dbh, msg); + handle_select(msg); else if (strncmp(tag_name, FIRST_MSG, strlen(FIRST_MSG)) == 0) - handle_first(dbh, msg); + handle_first(msg); else if (strncmp(tag_name, LAST_MSG, strlen(LAST_MSG)) == 0) - handle_last(dbh, msg); + handle_last(msg); else if (strncmp(tag_name, NEXT_MSG, strlen(NEXT_MSG)) == 0) - handle_next(dbh, msg); + handle_next(msg); else if (strncmp(tag_name, PREV_MSG, strlen(PREV_MSG)) == 0) - handle_prev(dbh, msg); + handle_prev(msg); else { logmsg("WARNING: message type %s unknown.", (char *)ERL_ATOM_PTR(tag)); erl_free_term(tag); @@ -456,37 +494,17 @@ dispatch_db_cmd(MYSQL *dbh, ETERM *msg) int main(int argc, char *argv[]) { - MYSQL dbh; - char *host, *port, *db_name, *user, *passwd; ETERM *msg; openlog(); logmsg("INFO: starting up."); - - if (argc < 6) - usage(); - - host = argv[1]; - port = argv[2]; - db_name = argv[3]; - user = argv[4]; - passwd = argv[5]; - erl_init(NULL, 0); mysql_init(&dbh); - if (mysql_real_connect(&dbh, host, user, passwd, - db_name, atoi(port), NULL, 0) == NULL) { - logmsg("ERROR: Failed to connect to database %s as %s: %s.", - db_name, user, mysql_error(&dbh)); - exit(2); - } - while ((msg = read_msg()) != NULL) { - dispatch_db_cmd(&dbh, msg); + dispatch_db_cmd(msg); erl_free_term(msg); } - mysql_close(&dbh); logmsg("INFO: shutting down."); diff --git a/src/mysqlerl.hrl b/src/mysqlerl.hrl index 4e1230d..99e674f 100644 --- a/src/mysqlerl.hrl +++ b/src/mysqlerl.hrl @@ -1,3 +1,4 @@ +-record(sql_connect, {host, port, database, user, password, options}). -record(sql_query, {q}). -record(sql_param_query, {q, params}). -record(sql_select_count, {q}). diff --git a/src/mysqlerl_connection.erl b/src/mysqlerl_connection.erl index ac0497a..1a3d900 100644 --- a/src/mysqlerl_connection.erl +++ b/src/mysqlerl_connection.erl @@ -23,10 +23,8 @@ stop(Pid) -> init([Owner, Host, Port, Database, User, Password, Options]) -> process_flag(trap_exit, true), link(Owner), - Cmd = lists:flatten(io_lib:format("~s ~s ~w ~s ~s ~s ~s", - [helper(), Host, Port, Database, - User, Password, Options])), - {ok, Sup} = mysqlerl_port_sup:start_link(Cmd), + {ok, Sup} = mysqlerl_port_sup:start_link(helper(), Host, Port, Database, + User, Password, Options), {ok, #state{sup = Sup, owner = Owner}}. terminate(Reason, _State) -> diff --git a/src/mysqlerl_port.erl b/src/mysqlerl_port.erl index 3593119..74fe177 100644 --- a/src/mysqlerl_port.erl +++ b/src/mysqlerl_port.erl @@ -1,23 +1,35 @@ -module(mysqlerl_port). -author('bjc@kublai.com'). +-include("mysqlerl.hrl"). -include("mysqlerl_port.hrl"). -behavior(gen_server). --export([start_link/1]). +-export([start_link/7]). -export([init/1, terminate/2, code_change/3, handle_call/3, handle_cast/2, handle_info/2]). +-define(CONNECT_TIMEOUT, 30000). + -record(state, {ref}). -record(port_closed, {reason}). -start_link(Cmd) -> - gen_server:start_link(?MODULE, [Cmd], []). +start_link(Cmd, Host, Port, Database, User, Password, Options) -> + gen_server:start_link(?MODULE, + [Cmd, Host, Port, Database, User, Password, Options], + []). -init([Cmd]) -> +init([Cmd, Host, Port, Database, User, Password, Options]) -> process_flag(trap_exit, true), Ref = open_port({spawn, Cmd}, [{packet, 4}, binary]), + {data, ok} = send_port_cmd(Ref, #sql_connect{host = Host, + port = Port, + database = Database, + user = User, + password = Password, + options = Options}, + ?CONNECT_TIMEOUT), {ok, #state{ref = Ref}}. terminate(#port_closed{reason = Reason}, #state{ref = Ref}) -> @@ -34,27 +46,38 @@ code_change(_OldVsn, State, _Extra) -> handle_call(#req{request = {Request, Timeout}}, From, #state{ref = Ref} = State) -> - io:format("DEBUG: Sending request: ~p~n", [Request]), - port_command(Ref, term_to_binary(Request)), - receive - {Ref, {data, Res}} -> - {reply, binary_to_term(Res), State}; + case send_port_cmd(Ref, Request, Timeout) of + {data, Res} -> + {reply, Res, State}; {'EXIT', Ref, Reason} -> gen_server:reply(From, {error, connection_closed}), {stop, #port_closed{reason = Reason}, State}; + timeout -> + gen_server:reply(From, timeout), + {stop, timeout, State}; Other -> error_logger:warning_msg("Got unknown query response: ~p~n", [Other]), gen_server:reply(From, {error, connection_closed}), {stop, {unknownreply, Other}, State} - after Timeout -> - gen_server:reply(From, timeout), - {stop, timeout, State} end. + handle_cast(_Request, State) -> {noreply, State}. handle_info({'EXIT', Ref, Reason}, #state{ref = Ref} = State) -> io:format("DEBUG: Port ~p closed on ~p.~n", [Ref, State]), {stop, #port_closed{reason = Reason}, State}. + + +send_port_cmd(Ref, Request, Timeout) -> + io:format("DEBUG: Sending request: ~p~n", [Request]), + port_command(Ref, term_to_binary(Request)), + receive + {Ref, {data, Res}} -> + {data, binary_to_term(Res)}; + Other -> Other + after Timeout -> + timeout + end. diff --git a/src/mysqlerl_port_sup.erl b/src/mysqlerl_port_sup.erl index 820f929..3053cdc 100644 --- a/src/mysqlerl_port_sup.erl +++ b/src/mysqlerl_port_sup.erl @@ -3,14 +3,16 @@ -behavior(supervisor). --export([start_link/1]). +-export([start_link/7]). -export([init/1]). -start_link(Cmd) -> - supervisor:start_link(?MODULE, [Cmd]). +start_link(Cmd, Host, Port, Database, User, Password, Options) -> + supervisor:start_link(?MODULE, [Cmd, Host, Port, Database, + User, Password, Options]). -init([Cmd]) -> - Port = {mysqlerl_port, {mysqlerl_port, start_link, [Cmd]}, +init([Cmd, Host, Port, Database, User, Password, Options]) -> + Ref = {mysqlerl_port, {mysqlerl_port, start_link, + [Cmd, Host, Port, Database, + User, Password, Options]}, transient, 5, worker, [mysqlerl_port]}, - {ok, {{one_for_one, 10, 5}, - [Port]}}. + {ok, {{one_for_one, 10, 5}, [Ref]}}.