Skip to content

Commit

Permalink
Support SOCKS proxies.
Browse files Browse the repository at this point in the history
The trickiest part of this change is the SOCKS 5 proxy implemented
to make testing possible. Fortunately the protocol is very easy, and
shows off Okio.

Closes #1009
  • Loading branch information
swankjesse committed Dec 28, 2014
1 parent 29ec4f3 commit 51846fe
Show file tree
Hide file tree
Showing 4 changed files with 322 additions and 2 deletions.
232 changes: 232 additions & 0 deletions okhttp-tests/src/test/java/com/squareup/okhttp/SocksProxy.java
@@ -0,0 +1,232 @@
/*
* Copyright (C) 2014 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.squareup.okhttp;

import com.squareup.okhttp.internal.NamedRunnable;
import com.squareup.okhttp.internal.Util;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ProtocolException;
import java.net.Proxy;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Level;
import java.util.logging.Logger;
import okio.Buffer;
import okio.BufferedSink;
import okio.BufferedSource;
import okio.Okio;

/**
* A limited implementation of SOCKS Protocol Version 5, intended to be similar to MockWebServer.
* See <a href="https://www.ietf.org/rfc/rfc1928.tx">RFC 1928</a>.
*/
public final class SocksProxy {
private static final int VERSION_5 = 5;
private static final int METHOD_NONE = 0xff;
private static final int METHOD_NO_AUTHENTICATION_REQUIRED = 0;
private static final int ADDRESS_TYPE_IPV4 = 1;
private static final int ADDRESS_TYPE_DOMAIN_NAME = 3;
private static final int COMMAND_CONNECT = 1;
private static final int REPLY_SUCCEEDED = 0;

private static final Logger logger = Logger.getLogger(SocksProxy.class.getName());

private final ExecutorService executor = Executors.newCachedThreadPool(
Util.threadFactory("SocksProxy", false));

private ServerSocket serverSocket;
private AtomicInteger connectionCount = new AtomicInteger();

public void play() throws IOException {
serverSocket = new ServerSocket(0);
executor.execute(new NamedRunnable("SocksProxy %s", serverSocket.getLocalPort()) {
@Override protected void execute() {
try {
while (true) {
Socket socket = serverSocket.accept();
connectionCount.incrementAndGet();
service(socket);
}
} catch (SocketException e) {
logger.info(name + " done accepting connections: " + e.getMessage());
} catch (IOException e) {
logger.log(Level.WARNING, name + " failed unexpectedly", e);
}
}
});
}

public Proxy proxy() {
return new Proxy(Proxy.Type.SOCKS, InetSocketAddress.createUnresolved(
"localhost", serverSocket.getLocalPort()));
}

public int connectionCount() {
return connectionCount.get();
}

public void shutdown() throws Exception {
serverSocket.close();
executor.shutdown();
if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
throw new IOException("Gave up waiting for executor to shut down");
}
}

private void service(final Socket client) {
executor.execute(new NamedRunnable("SocksProxy %s", client.getRemoteSocketAddress()) {
@Override protected void execute() {
try {
BufferedSource clientSource = Okio.buffer(Okio.source(client));
BufferedSink clientSink = Okio.buffer(Okio.sink(client));
hello(clientSource, clientSink);
acceptCommand(client.getInetAddress(), clientSource, clientSink);
} catch (IOException e) {
logger.log(Level.WARNING, name + " failed", e);
Util.closeQuietly(client);
}
}
});
}

private void hello(BufferedSource clientSource, BufferedSink clientSink) throws IOException {
int version = clientSource.readByte() & 0xff;
int methodCount = clientSource.readByte() & 0xff;
int selectedMethod = METHOD_NONE;

if (version != VERSION_5) {
throw new ProtocolException("unsupported version: " + version);
}

for (int i = 0; i < methodCount; i++) {
int candidateMethod = clientSource.readByte() & 0xff;
if (candidateMethod == METHOD_NO_AUTHENTICATION_REQUIRED) {
selectedMethod = candidateMethod;
}
}

switch (selectedMethod) {
case METHOD_NO_AUTHENTICATION_REQUIRED:
clientSink.writeByte(VERSION_5);
clientSink.writeByte(selectedMethod);
clientSink.emit();
break;

default:
throw new ProtocolException("unsupported method: " + selectedMethod);
}
}

private void acceptCommand(InetAddress fromAddress, BufferedSource fromSource,
BufferedSink fromSink) throws IOException {
// Read the command.
int version = fromSource.readByte() & 0xff;
if (version != VERSION_5) throw new ProtocolException("unexpected version: " + version);
int command = fromSource.readByte() & 0xff;
int reserved = fromSource.readByte() & 0xff;
if (reserved != 0) throw new ProtocolException("unexpected reserved: " + reserved);

int addressType = fromSource.readByte() & 0xff;
InetAddress toAddress;
switch (addressType) {
case ADDRESS_TYPE_IPV4:
toAddress = InetAddress.getByAddress(fromSource.readByteArray(4L));
break;

case ADDRESS_TYPE_DOMAIN_NAME:
int domainNameLength = fromSource.readByte() & 0xff;
String domainName = fromSource.readUtf8(domainNameLength);
toAddress = InetAddress.getByName(domainName);
break;

default:
throw new ProtocolException("unsupported address type: " + 4);
}

int port = fromSource.readShort() & 0xffff;

switch (command) {
case COMMAND_CONNECT:
Socket toSocket = new Socket(toAddress, port);
fromSink.writeByte(VERSION_5);
fromSink.writeByte(REPLY_SUCCEEDED);
fromSink.writeByte(0);

byte[] localAddress = toSocket.getLocalAddress().getAddress();
if (localAddress.length != 4) {
throw new ProtocolException("unexpected address: " + toSocket.getLocalAddress());
}

// Write the reply.
fromSink.writeByte(ADDRESS_TYPE_IPV4);
fromSink.write(localAddress);
fromSink.writeShort(toSocket.getLocalPort());
fromSink.emit();

logger.log(Level.INFO, "SocksProxy connected " + fromAddress + " to " + toAddress);

// Copy sources to sinks in both directions.
BufferedSource toSource = Okio.buffer(Okio.source(toSocket));
BufferedSink toSink = Okio.buffer(Okio.sink(toSocket));
transfer(fromAddress, toAddress, fromSource, toSink);
transfer(fromAddress, toAddress, toSource, fromSink);
break;

default:
throw new ProtocolException("unexpected command: " + command);
}
}

private void transfer(final InetAddress fromAddress, final InetAddress toAddress,
final BufferedSource source, final BufferedSink sink) {
executor.execute(new NamedRunnable("SocksProxy %s to %s", fromAddress, toAddress) {
@Override protected void execute() {
Buffer buffer = new Buffer();
try {
while (true) {
long byteCount = source.read(buffer, 2048L);
if (byteCount == -1L) break;
sink.write(buffer, byteCount);
sink.emit();
}
} catch (SocketException e) {
logger.info(name + " done: " + e.getMessage());
} catch (IOException e) {
logger.log(Level.WARNING, name + " failed", e);
}

try {
source.close();
} catch (IOException e) {
logger.log(Level.WARNING, name + " failed", e);
}

try {
sink.close();
} catch (IOException e) {
logger.log(Level.WARNING, name + " failed", e);
}
}
});
}
}
88 changes: 88 additions & 0 deletions okhttp-tests/src/test/java/com/squareup/okhttp/SocksProxyTest.java
@@ -0,0 +1,88 @@
/*
* Copyright (C) 2014 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.squareup.okhttp;

import com.squareup.okhttp.mockwebserver.MockResponse;
import com.squareup.okhttp.mockwebserver.MockWebServer;
import java.io.IOException;
import java.net.Proxy;
import java.net.ProxySelector;
import java.net.SocketAddress;
import java.net.URI;
import java.util.Collections;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import static org.junit.Assert.assertEquals;

public final class SocksProxyTest {
private final SocksProxy socksProxy = new SocksProxy();
private final MockWebServer server = new MockWebServer();

@Before public void setUp() throws Exception {
server.play();
socksProxy.play();
}

@After public void tearDown() throws Exception {
server.shutdown();
socksProxy.shutdown();
}

@Test public void proxy() throws Exception {
server.enqueue(new MockResponse().setBody("abc"));
server.enqueue(new MockResponse().setBody("def"));

OkHttpClient client = new OkHttpClient()
.setProxy(socksProxy.proxy());

Request request1 = new Request.Builder().url(server.getUrl("/")).build();
Response response1 = client.newCall(request1).execute();
assertEquals("abc", response1.body().string());

Request request2 = new Request.Builder().url(server.getUrl("/")).build();
Response response2 = client.newCall(request2).execute();
assertEquals("def", response2.body().string());

// The HTTP calls should share a single connection.
assertEquals(1, socksProxy.connectionCount());
}

@Test public void proxySelector() throws Exception {
server.enqueue(new MockResponse().setBody("abc"));

ProxySelector proxySelector = new ProxySelector() {
@Override public List<Proxy> select(URI uri) {
return Collections.singletonList(socksProxy.proxy());
}

@Override public void connectFailed(URI uri, SocketAddress socketAddress, IOException e) {
throw new AssertionError();
}
};

OkHttpClient client = new OkHttpClient()
.setProxySelector(proxySelector);

Request request = new Request.Builder().url(server.getUrl("/")).build();
Response response = client.newCall(request).execute();
assertEquals("abc", response.body().string());

assertEquals(1, socksProxy.connectionCount());
}
}
Expand Up @@ -20,7 +20,7 @@
* Runnable implementation which always sets its thread name.
*/
public abstract class NamedRunnable implements Runnable {
private final String name;
protected final String name;

public NamedRunnable(String format, Object... args) {
this.name = String.format(format, args);
Expand Down
Expand Up @@ -248,7 +248,7 @@ private void resetNextInetSocketAddress(Proxy proxy) throws UnknownHostException

String socketHost;
int socketPort;
if (proxy.type() == Proxy.Type.DIRECT) {
if (proxy.type() == Proxy.Type.DIRECT || proxy.type() == Proxy.Type.SOCKS) {
socketHost = address.getUriHost();
socketPort = getEffectivePort(uri);
} else {
Expand Down

0 comments on commit 51846fe

Please sign in to comment.