diff --git a/lib/bin/console.rb b/lib/bin/console.rb index 484e32950c..18f7a31c13 100644 --- a/lib/bin/console.rb +++ b/lib/bin/console.rb @@ -13,7 +13,9 @@ def initialize(port, pid, path) end def connect(notice=true) - @agent = Rubinius::Agent.connect "localhost", @port + @agent = Rubinius::Agent.connect "localhost", @port do + Readline.readline("password> ") + end if notice puts "Connected to localhost:#{@port}, host type: #{@agent.handshake[1]}" diff --git a/lib/rubinius/agent.rb b/lib/rubinius/agent.rb index 1b25588455..1af3f449e2 100644 --- a/lib/rubinius/agent.rb +++ b/lib/rubinius/agent.rb @@ -2,9 +2,9 @@ module Rubinius class Agent - def self.connect(host, port) + def self.connect(host, port, &b) i = new TCPSocket.new(host, port) - i.handshake! + i.handshake!(&b) return i end @@ -13,10 +13,11 @@ def self.loopback new Rubinius.agent_io end - def initialize(io) + def initialize(io, password=nil) @io = io @decoder = BERT::Decode.new(@io) @encoder = BERT::Encode.new(@io) + @password = password end def handshake! @@ -28,6 +29,17 @@ def handshake! @encoder.write_any :ok + @handshake = @decoder.read_any + elsif @handshake[0] == :password_auth + if @password + @encoder.write_any t[:password, @password] + elsif block_given? + password = yield + @encoder.write_any t[:password, password.to_s] + else + raise "Password required, none available" + end + @handshake = @decoder.read_any end diff --git a/lib/rubinius/configuration.rb b/lib/rubinius/configuration.rb index 302e4d5552..c46f165a05 100644 --- a/lib/rubinius/configuration.rb +++ b/lib/rubinius/configuration.rb @@ -105,6 +105,8 @@ s.vm_variable "tmpdir", :string, "Where to store files used to discover running query agents" + s.vm_variable "password", :string, + "The password required to connect to the agent" end c.vm_variable "tool", :string, diff --git a/vm/agent.cpp b/vm/agent.cpp index ffb19bbc85..bd684ed65f 100644 --- a/vm/agent.cpp +++ b/vm/agent.cpp @@ -89,6 +89,12 @@ namespace rubinius { shared_.globals.rubinius.get()->set_const(state_, "FROM_AGENT", from); shared_.globals.rubinius.get()->set_const(state_, "TO_AGENT", to); + + if(shared_.config.agent_password.set_p()) { + local_only_ = false; + use_password_ = true; + password_ = std::string(shared_.config.agent_password); + } } QueryAgent::~QueryAgent() { @@ -236,42 +242,99 @@ namespace rubinius { encoder.write_binary(name.c_str()); } - } + void request_password(int client) { + bert::IOWriter writer(client); + bert::Encoder encoder(writer); + encoder.write_version(); + encoder.write_tuple(1); + encoder.write_atom("password_auth"); + } + } - bool QueryAgent::process_commands(Client& client) { + bool QueryAgent::check_password(Client& client) { bert::IOReader reader(client.socket); bert::Decoder decoder(reader); bert::IOWriter writer(client.socket); bert::Encoder encoder(writer); + bert::Value* val = 0; - if(client.needs_auth_p()) { - std::stringstream name; - std::ifstream file; - bert::Value* val = 0; + int ver = decoder.read_version(); + if(ver != 131) return false; - name << "/tmp/agent-auth." << getuid() << "-" << getpid() << "." << client.auth_key; + val = decoder.next_value(); + if(!val) return false; - int ver = decoder.read_version(); - if(ver != 131) goto auth_error; + if(reader.eof_p()) { + delete val; + return false; + } - val = decoder.next_value(); - if(!val) goto auth_error; + if(val->type() == bert::Tuple) { + bert::Value* cmd = val->get_element(0); + if(cmd->equal_atom("password")) { + bert::Value* pass = val->get_element(1); - if(reader.eof_p()) { - delete val; - goto auth_error; + if(pass->type() == bert::Binary) { + if(password_ == std::string(pass->string())) return true; + } } + } + + return false; + } + + bool QueryAgent::check_file_auth(Client& client) { + bert::IOReader reader(client.socket); + bert::Decoder decoder(reader); + + bert::IOWriter writer(client.socket); + bert::Encoder encoder(writer); + + std::stringstream name; + std::ifstream file; + bert::Value* val = 0; + + name << "/tmp/agent-auth." << getuid() << "-" << getpid() << "." << client.auth_key; + + int ver = decoder.read_version(); + if(ver != 131) goto auth_error; + + val = decoder.next_value(); + if(!val) goto auth_error; + + if(reader.eof_p()) { + delete val; + goto auth_error; + } + + if(!val->equal_atom("ok")) goto auth_error; + + file.open(name.str().c_str()); - if(!val->equal_atom("ok")) goto auth_error; + char key[PATH_MAX]; + file.getline(key, PATH_MAX); - file.open(name.str().c_str()); + if(strcmp(key, "agent start") != 0) goto auth_error; - char key[PATH_MAX]; - file.getline(key, PATH_MAX); + unlink(name.str().c_str()); + return true; + +auth_error: + unlink(name.str().c_str()); + return false; + } - if(strcmp(key, "agent start") != 0) goto auth_error; + bool QueryAgent::process_commands(Client& client) { + if(client.needs_auth_p()) { + if(local_only_) { + if(!check_file_auth(client)) return false; + } else if(use_password_) { + if(!check_password(client)) return false; + } else { + return false; + } if(verbose_) { struct sockaddr_in sin; @@ -279,20 +342,20 @@ namespace rubinius { getpeername(client.socket, (struct sockaddr*)&sin, &len); std::cerr << "[QA: Authenticated " << inet_ntoa(sin.sin_addr) - << ":" << ntohs(sin.sin_port) << "]\n"; + << ":" << ntohs(sin.sin_port) << "]\n"; } client.set_running(); write_welcome(client.socket); - - unlink(name.str().c_str()); return true; - -auth_error: - unlink(name.str().c_str()); - return false; } + bert::IOReader reader(client.socket); + bert::Decoder decoder(reader); + + bert::IOWriter writer(client.socket); + bert::Encoder encoder(writer); + int ver = decoder.read_version(); if(ver != 131) return false; @@ -458,6 +521,22 @@ namespace rubinius { sockets_.push_back(cl); + continue; + } else if(use_password_) { + Client cl(client); + cl.begin_auth(0); + + if(verbose_) { + std::cerr << "[QA: Requesting password auth from " << inet_ntoa(sin.sin_addr) + << ":" << ntohs(sin.sin_port) << "]\n"; + } + + request_password(client); + + add_fd(client); + + sockets_.push_back(cl); + continue; } diff --git a/vm/agent.hpp b/vm/agent.hpp index a680ffb6f1..10587296d7 100644 --- a/vm/agent.hpp +++ b/vm/agent.hpp @@ -70,6 +70,8 @@ namespace rubinius { agent::VariableAccess* vars_; bool local_only_; + bool use_password_; + std::string password_; uint32_t tmp_key_; const static int cBackLog = 10; @@ -135,6 +137,8 @@ namespace rubinius { void make_discoverable(); virtual void perform(); + bool check_password(Client& client); + bool check_file_auth(Client& client); bool process_commands(Client& client); void on_fork();