diff --git a/Makefile b/Makefile index 63c768d..bebbe16 100755 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ NAME := epgsql -VERSION := 1.0 +VERSION := 1.1 ERL := erl ERLC := erlc diff --git a/README b/README index 10924ca..c7b637b 100644 --- a/README +++ b/README @@ -8,6 +8,7 @@ Erlang PostgreSQL Database Client - database - port + - ssl (true | false | required) ok = pgsql:close(C). diff --git a/src/epgsql.app b/src/epgsql.app index 1d2bc62..b6c1f9e 100644 --- a/src/epgsql.app +++ b/src/epgsql.app @@ -1,7 +1,8 @@ {application, epgsql, [{description, "PostgreSQL Client"}, - {vsn, "1.0"}, - {modules, [pgsql, pgsql_binary, pgsql_connection, pgsql_datetime, pgsql_types]}, + {vsn, "1.1"}, + {modules, [pgsql, pgsql_binary, pgsql_connection, pgsql_fdatetime, + pgsql_idatetime, pgsql_sock, pgsql_types]}, {registered, []}, - {applications, [kernel, stdlib]}, + {applications, [kernel, stdlib, crypto, ssl]}, {included_applications, []}]}. diff --git a/src/pgsql_connection.erl b/src/pgsql_connection.erl index 8db0488..7723391 100644 --- a/src/pgsql_connection.erl +++ b/src/pgsql_connection.erl @@ -11,7 +11,6 @@ -export([init/1, handle_event/3, handle_sync_event/4]). -export([handle_info/3, terminate/3, code_change/4]). --export([read/3]). -export([startup/3, auth/2, initializing/2, ready/2, ready/3]). -export([querying/2, parsing/2, binding/2, describing/2]). @@ -93,7 +92,7 @@ handle_event(Event, _State_Name, State) -> handle_sync_event(Event, _From, _State_Name, State) -> {stop, {unsupported_sync_event, Event}, State}. -handle_info({'EXIT', Pid, Reason}, _State_Name, State = #state{reader = Pid}) -> +handle_info({'EXIT', Pid, Reason}, _State_Name, State = #state{sock = Pid}) -> {stop, Reason, State}; handle_info(Info, _State_Name, State) -> @@ -101,8 +100,7 @@ handle_info(Info, _State_Name, State) -> terminate(_Reason, _State_Name, State = #state{sock = Sock}) when Sock =/= undefined -> - send(State, $X, []), - gen_tcp:close(Sock); + send(State, $X, []); terminate(_Reason, _State_Name, _State) -> ok. @@ -113,25 +111,11 @@ code_change(_Old_Vsn, State_Name, State, _Extra) -> %% -- states -- startup({connect, Host, Username, Password, Opts}, From, State) -> - Port = proplists:get_value(port, Opts, 5432), - Sock_Opts = [{active, false}, {packet, raw}, binary], - case gen_tcp:connect(Host, Port, Sock_Opts) of + case pgsql_sock:start_link(self(), Host, Username, Opts) of {ok, Sock} -> - Reader = spawn_link(?MODULE, read, [self(), Sock, <<>>]), - - Opts2 = ["user", 0, Username, 0], - case proplists:get_value(database, Opts, undefined) of - undefined -> Opts3 = Opts2; - Database -> Opts3 = [Opts2 | ["database", 0, Database, 0]] - end, - put(username, Username), put(password, Password), - State2 = State#state{reader = Reader, - sock = Sock, - reply_to = From}, - send(State2, [<<196608:32>>, Opts3, 0]), - + State2 = State#state{sock = Sock, reply_to = From}, {next_state, auth, State2}; Error -> {stop, normal, Error, State} @@ -167,9 +151,8 @@ auth({$R, <>}, State) -> {stop, normal, State}; %% ErrorResponse -auth({$E, Bin}, State) -> - Error = decode_error(Bin), - case Error#error.code of +auth({error, E}, State) -> + case E#error.code of <<"28000">> -> Why = invalid_authorization_specification; Any -> Why = Any end, @@ -182,9 +165,8 @@ initializing({$K, <>}, State) -> {next_state, initializing, State2}; %% ErrorResponse -initializing({$E, Bin}, State) -> - Error = decode_error(Bin), - case Error#error.code of +initializing({error, E}, State) -> + case E#error.code of <<"28000">> -> Why = invalid_authorization_specification; Any -> Why = Any end, @@ -311,9 +293,8 @@ querying({$I, _Bin}, State) -> {next_state, querying, State}; %% ErrorResponse -querying({$E, Bin}, State) -> - Error = decode_error(Bin), - notify(State, {error, Error}), +querying({error, E}, State) -> + notify(State, {error, E}), {next_state, querying, State}; %% ReadyForQuery @@ -326,8 +307,8 @@ parsing({$1, <<>>}, State) -> {next_state, describing, State}; %% ErrorResponse -parsing({$E, Bin}, State) -> - Reply = {error, decode_error(Bin)}, +parsing({error, E}, State) -> + Reply = {error, E}, send(State, $S, []), {next_state, parsing, State#state{reply = Reply}}; @@ -343,8 +324,8 @@ binding({$2, <<>>}, State) -> {next_state, ready, State}; %% ErrorResponse -binding({$E, Bin}, State) -> - Reply = {error, decode_error(Bin)}, +binding({error, E}, State) -> + Reply = {error, E}, send(State, $S, []), {next_state, binding, State#state{reply = Reply}}; @@ -375,8 +356,8 @@ describing({$n, <<>>}, State) -> {next_state, ready, State}; %% ErrorResponse -describing({$E, Bin}, State) -> - Reply = {error, decode_error(Bin)}, +describing({error, E}, State) -> + Reply = {error, E}, send(State, $S, []), {next_state, describing, State#state{reply = Reply}}; @@ -409,8 +390,8 @@ executing({$I, _Bin}, State) -> {next_state, ready, State}; %% ErrorResponse -executing({$E, Bin}, State) -> - notify(State, {error, decode_error(Bin)}), +executing({error, E}, State) -> + notify(State, {error, E}), {next_state, executing, State}. %% CloseComplete @@ -419,14 +400,14 @@ closing({$3, <<>>}, State) -> {next_state, ready, State}; %% ErrorResponse -closing({$E, Bin}, State) -> - Error = {error, decode_error(Bin)}, +closing({error, E}, State) -> + Error = {error, E}, gen_fsm:reply(State#state.reply_to, Error), {next_state, ready, State}. %% ErrorResponse -synchronizing({$E, Bin}, State) -> - Reply = {error, decode_error(Bin)}, +synchronizing({error, E}, State) -> + Reply = {error, E}, {next_state, synchronizing, State#state{reply = Reply}}; %% ReadyForQuery @@ -437,35 +418,6 @@ synchronizing({$Z, <>}, State) -> %% -- internal functions -- -%% decode a single null-terminated string -decode_string(Bin) -> - decode_string(Bin, <<>>). - -decode_string(<<0, Rest/binary>>, Str) -> - {Str, Rest}; -decode_string(<>, Str) -> - decode_string(Rest, <>). - -%% decode multiple null-terminated string -decode_strings(Bin) -> - decode_strings(Bin, []). - -decode_strings(<<>>, Acc) -> - lists:reverse(Acc); -decode_strings(Bin, Acc) -> - {Str, Rest} = decode_string(Bin), - decode_strings(Rest, [Str | Acc]). - -%% decode field -decode_fields(Bin) -> - decode_fields(Bin, []). - -decode_fields(<<0>>, Acc) -> - Acc; -decode_fields(<>, Acc) -> - {Str, Rest2} = decode_string(Rest), - decode_fields(Rest2, [{Type, Str} | Acc]). - %% decode data decode_data(Columns, Bin) -> decode_data(Columns, Bin, []). @@ -488,7 +440,7 @@ decode_columns(Count, Bin) -> decode_columns(0, _Bin, Acc) -> lists:reverse(Acc); decode_columns(N, Bin, Acc) -> - {Name, Rest} = decode_string(Bin), + {Name, Rest} = pgsql_sock:decode_string(Bin), <<_Table_Oid:?int32, _Attrib_Num:?int16, Type_Oid:?int32, Size:?int16, Modifier:?int32, Format:?int16, Rest2/binary>> = Rest, Desc = #column{ @@ -504,36 +456,14 @@ decode_complete(<<"SELECT", 0>>) -> select; decode_complete(<<"BEGIN", 0>>) -> 'begin'; decode_complete(<<"ROLLBACK", 0>>) -> rollback; decode_complete(Bin) -> - {Str, _} = decode_string(Bin), + {Str, _} = pgsql_sock:decode_string(Bin), case string:tokens(binary_to_list(Str), " ") of ["INSERT", _Oid, Rows] -> {insert, list_to_integer(Rows)}; ["UPDATE", Rows] -> {update, list_to_integer(Rows)}; ["DELETE", Rows] -> {delete, list_to_integer(Rows)}; ["MOVE", Rows] -> {move, list_to_integer(Rows)}; ["FETCH", Rows] -> {fetch, list_to_integer(Rows)}; - [Type | _Rest] -> lower_atom(Type) - end. - -%% decode ErrorResponse -decode_error(Bin) -> - Fields = decode_fields(Bin), - Error = #error{ - severity = lower_atom(proplists:get_value($S, Fields)), - code = proplists:get_value($C, Fields), - message = proplists:get_value($M, Fields), - extra = decode_error_extra(Fields)}, - Error. - -decode_error_extra(Fields) -> - Types = [{$D, detail}, {$H, hint}, {$P, position}], - decode_error_extra(Types, Fields, []). - -decode_error_extra([], _Fields, Extra) -> - Extra; -decode_error_extra([{Type, Name} | T], Fields, Extra) -> - case proplists:get_value(Type, Fields) of - undefined -> decode_error_extra(T, Fields, Extra); - Value -> decode_error_extra(T, Fields, [{Name, Value} | Extra]) + [Type | _Rest] -> pgsql_sock:lower_atom(Type) end. %% encode types @@ -599,11 +529,6 @@ encode_list(L) -> notify(#state{reply_to = {Pid, _Tag}}, Msg) -> Pid ! {pgsql, self(), Msg}. -lower_atom(Str) when is_binary(Str) -> - lower_atom(binary_to_list(Str)); -lower_atom(Str) when is_list(Str) -> - list_to_atom(string:to_lower(Str)). - to_binary(B) when is_binary(B) -> B; to_binary(L) when is_list(L) -> list_to_binary(L). @@ -616,36 +541,4 @@ hex(Bin) -> %% send data to server send(#state{sock = Sock}, Type, Data) -> - Bin = iolist_to_binary(Data), - gen_tcp:send(Sock, <>). - -send(#state{sock = Sock}, Data) -> - Bin = iolist_to_binary(Data), - gen_tcp:send(Sock, <<(byte_size(Bin) + 4):?int32, Bin/binary>>). - -%% -- socket read loop -- - -read(Fsm, Sock, Tail) -> - case gen_tcp:recv(Sock, 0) of - {ok, Bin} -> decode(Fsm, Sock, <>); - Error -> exit(Error) - end. - -decode(Fsm, Sock, <> = Bin) -> - Len2 = Len - 4, - case Rest of - <> when Type == $N -> - gen_fsm:send_all_state_event(Fsm, {notice, decode_error(Data)}), - decode(Fsm, Sock, Tail); - <> when Type == $S -> - [Name, Value] = decode_strings(Data), - gen_fsm:send_all_state_event(Fsm, {parameter_status, Name, Value}), - decode(Fsm, Sock, Tail); - <> -> - gen_fsm:send_event(Fsm, {Type, Data}), - decode(Fsm, Sock, Tail); - _Other -> - ?MODULE:read(Fsm, Sock, Bin) - end; -decode(Fsm, Sock, Bin) -> - ?MODULE:read(Fsm, Sock, Bin). + pgsql_sock:send(Sock, Type, Data). diff --git a/src/pgsql_sock.erl b/src/pgsql_sock.erl new file mode 100644 index 0000000..6cce681 --- /dev/null +++ b/src/pgsql_sock.erl @@ -0,0 +1,201 @@ +%%% Copyright (C) 2009 - Will Glozer. All rights reserved. + +-module(pgsql_sock). + +-behavior(gen_server). + +-export([start_link/4, send/2, send/3]). +-export([decode_string/1, lower_atom/1]). + +-export([handle_call/3, handle_cast/2, handle_info/2]). +-export([init/1, code_change/3, terminate/2]). + +-include("pgsql.hrl"). + +-record(state, {c, mod, sock, tail}). + +-define(int16, 1/big-signed-unit:16). +-define(int32, 1/big-signed-unit:32). + +%% -- client interface -- + +start_link(C, Host, Username, Opts) -> + gen_server:start_link(?MODULE, [C, Host, Username, Opts], []). + +send(S, Type, Data) -> + Bin = iolist_to_binary(Data), + Msg = <>, + gen_server:cast(S, {send, Msg}). + +send(S, Data) -> + Bin = iolist_to_binary(Data), + Msg = <<(byte_size(Bin) + 4):?int32, Bin/binary>>, + gen_server:cast(S, {send, Msg}). + +%% -- gen_server implementation -- + +init([C, Host, Username, Opts]) -> + process_flag(trap_exit, true), + + Opts2 = ["user", 0, Username, 0], + case proplists:get_value(database, Opts, undefined) of + undefined -> Opts3 = Opts2; + Database -> Opts3 = [Opts2 | ["database", 0, Database, 0]] + end, + + Port = proplists:get_value(port, Opts, 5432), + SockOpts = [{active, false}, {packet, raw}, binary], + {ok, S} = gen_tcp:connect(Host, Port, SockOpts), + + State = #state{ + c = C, + mod = gen_tcp, + sock = S, + tail = <<>>}, + + case proplists:get_value(ssl, Opts) of + T when T == true; T == required -> + ok = gen_tcp:send(S, <<8:?int32, 80877103:?int32>>), + {ok, <>} = gen_tcp:recv(S, 1), + State2 = start_ssl(Code, T, Opts, State); + _ -> + State2 = State + end, + + setopts(State2, [{active, true}]), + send(self(), [<<196608:32>>, Opts3, 0]), + {ok, State2}. + +handle_call(Call, _From, State) -> + {stop, {unsupported_call, Call}, State}. + +handle_cast({send, Data}, State) -> + #state{mod = Mod, sock = Sock} = State, + ok = Mod:send(Sock, Data), + {noreply, State}; + +handle_cast(Cast, State) -> + {stop, {unsupported_cast, Cast}, State}. + +handle_info({_, _Sock, Data}, #state{tail = Tail} = State) -> + State2 = decode(<>, State), + {noreply, State2}; + +handle_info({Closed, _Sock}, State) + when Closed == tcp_closed; Closed == ssl_closed -> + {stop, sock_closed, State}; + +handle_info({Error, _Sock, Reason}, State) + when Error == tcp_error; Error == ssl_error -> + {stop, {sock_error, Reason}, State}; + +handle_info({'EXIT', _Pid, Reason}, State) -> + {stop, Reason, State}; + +handle_info(Info, State) -> + {stop, {unsupported_info, Info}, State}. + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%% -- internal functions -- + +start_ssl($S, _Flag, Opts, State) -> + #state{sock = S1} = State, + case ssl:connect(S1, Opts) of + {ok, S2} -> State#state{mod = ssl, sock = S2}; + {error, Reason} -> exit({ssl_negotiation_failed, Reason}) + end; + +start_ssl($N, Flag, _Opts, State) -> + case Flag of + true -> State; + required -> exit(ssl_not_available) + end. + +setopts(#state{mod = Mod, sock = Sock}, Opts) -> + case Mod of + gen_tcp -> inet:setopts(Sock, Opts); + ssl -> ssl:setopts(Sock, Opts) + end. + +decode(<> = Bin, #state{c = C} = State) -> + Len2 = Len - 4, + case Rest of + <> when Type == $N -> + gen_fsm:send_all_state_event(C, {notice, decode_error(Data)}), + decode(Tail, State); + <> when Type == $S -> + [Name, Value] = decode_strings(Data), + gen_fsm:send_all_state_event(C, {parameter_status, Name, Value}), + decode(Tail, State); + <> when Type == $E -> + gen_fsm:send_event(C, {error, decode_error(Data)}), + decode(Tail, State); + <> -> + gen_fsm:send_event(C, {Type, Data}), + decode(Tail, State); + _Other -> + State#state{tail = Bin} + end; +decode(Bin, State) -> + State#state{tail = Bin}. + +%% decode a single null-terminated string +decode_string(Bin) -> + decode_string(Bin, <<>>). + +decode_string(<<0, Rest/binary>>, Str) -> + {Str, Rest}; +decode_string(<>, Str) -> + decode_string(Rest, <>). + +%% decode multiple null-terminated string +decode_strings(Bin) -> + decode_strings(Bin, []). + +decode_strings(<<>>, Acc) -> + lists:reverse(Acc); +decode_strings(Bin, Acc) -> + {Str, Rest} = decode_string(Bin), + decode_strings(Rest, [Str | Acc]). + +%% decode field +decode_fields(Bin) -> + decode_fields(Bin, []). + +decode_fields(<<0>>, Acc) -> + Acc; +decode_fields(<>, Acc) -> + {Str, Rest2} = decode_string(Rest), + decode_fields(Rest2, [{Type, Str} | Acc]). + +%% decode ErrorResponse +decode_error(Bin) -> + Fields = decode_fields(Bin), + Error = #error{ + severity = lower_atom(proplists:get_value($S, Fields)), + code = proplists:get_value($C, Fields), + message = proplists:get_value($M, Fields), + extra = decode_error_extra(Fields)}, + Error. + +decode_error_extra(Fields) -> + Types = [{$D, detail}, {$H, hint}, {$P, position}], + decode_error_extra(Types, Fields, []). + +decode_error_extra([], _Fields, Extra) -> + Extra; +decode_error_extra([{Type, Name} | T], Fields, Extra) -> + case proplists:get_value(Type, Fields) of + undefined -> decode_error_extra(T, Fields, Extra); + Value -> decode_error_extra(T, Fields, [{Name, Value} | Extra]) + end. + +lower_atom(Str) when is_binary(Str) -> + lower_atom(binary_to_list(Str)); +lower_atom(Str) when is_list(Str) -> + list_to_atom(string:to_lower(Str)). diff --git a/test_src/pgsql_tests.erl b/test_src/pgsql_tests.erl index 326d526..0d0bb4e 100644 --- a/test_src/pgsql_tests.erl +++ b/test_src/pgsql_tests.erl @@ -34,6 +34,14 @@ connect_with_invalid_password_test() -> "epgsql_test_sha1", [{port, ?port}, {database, "epgsql_test_db1"}]). +connect_with_ssl_test() -> + lists:foreach(fun application:start/1, [crypto, ssl]), + with_connection( + fun(C) -> + {ok, _Cols, [{true}]} = pgsql:equery(C, "select ssl_is_used()") + end, + [{ssl, true}]). + select_test() -> with_connection( fun(C) -> @@ -394,8 +402,11 @@ connect_only(Args) -> flush(). with_connection(F) -> - Args = [{port, ?port}, {database, "epgsql_test_db1"}], - {ok, C} = pgsql:connect(?host, "epgsql_test", Args), + with_connection(F, []). + +with_connection(F, Args) -> + Args2 = [{port, ?port}, {database, "epgsql_test_db1"} | Args], + {ok, C} = pgsql:connect(?host, "epgsql_test", Args2), try F(C) after diff --git a/test_src/test_schema.sql b/test_src/test_schema.sql index 723ecfd..d1e30cb 100644 --- a/test_src/test_schema.sql +++ b/test_src/test_schema.sql @@ -12,6 +12,9 @@ -- -- any 'trust all' must be commented out for the invalid password test -- to succeed. +-- +-- ssl support must be configured, and the sslinfo contrib module +-- loaded for the ssl tests to succeed. CREATE USER epgsql_test;