Skip to content

Commit

Permalink
support hosts
Browse files Browse the repository at this point in the history
  • Loading branch information
ayanamist committed Mar 15, 2019
1 parent 9a08abc commit f32e85b
Showing 1 changed file with 62 additions and 9 deletions.
71 changes: 62 additions & 9 deletions core/src/main/java/com/github/shadowsocks/net/LocalDnsServer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ package com.github.shadowsocks.net

import android.util.Log
import com.crashlytics.android.Crashlytics
import com.github.shadowsocks.Core
import com.github.shadowsocks.utils.parseNumericAddress
import com.github.shadowsocks.utils.printLog
import kotlinx.coroutines.*
import org.xbill.DNS.*
import java.io.EOFException
import java.io.File
import java.io.IOException
import java.net.*
import java.nio.ByteBuffer
Expand Down Expand Up @@ -55,6 +58,8 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
var remoteDomainMatcher: Regex? = null
var localIpMatcher: List<Subnet> = emptyList()

private val hostsMap: Map<String, List<InetAddress>> = readHosts()

companion object {
private const val TAG = "LocalDnsServer"
private const val TIMEOUT = 10_000L
Expand All @@ -65,12 +70,25 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
private const val TTL = 120L
private const val UDP_PACKET_SIZE = 512

private val hostsDelimiter = "[ \t]".toRegex()

private fun prepareDnsResponse(request: Message) = Message(request.header.id).apply {
header.setFlag(Flags.QR.toInt()) // this is a response
if (request.header.getFlag(Flags.RD.toInt())) header.setFlag(Flags.RD.toInt())
request.question?.also { addRecord(it, Section.QUESTION) }
}

private fun prepareDnsResponseWithResults(request: Message, results: Array<InetAddress>): ByteBuffer =
ByteBuffer.wrap(prepareDnsResponse(request).apply {
header.setFlag(Flags.RA.toInt()) // recursion available
for (address in results) addRecord(when (address) {
is Inet4Address -> ARecord(question.name, DClass.IN, TTL, address)
is Inet6Address -> AAAARecord(question.name, DClass.IN, TTL, address)
else -> throw IllegalStateException("Unsupported address $address")
}, Section.ANSWER)
}.toWire())
}

private val monitor = ChannelMonitor()

private val job = SupervisorJob()
Expand Down Expand Up @@ -102,10 +120,16 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
return supervisorScope {
val remote = async { withTimeout(TIMEOUT) { forward(packet) } }
try {
if (forwardOnly || request.header.opcode != Opcode.QUERY) return@supervisorScope remote.await()
if (request.header.opcode != Opcode.QUERY) return@supervisorScope remote.await()
val question = request.question
if (question?.type != Type.A) return@supervisorScope remote.await()
val host = question.name.toString(true)
val hostsResults = hostsResolver(host)
if (hostsResults.isNotEmpty()) {
remote.cancel()
return@supervisorScope prepareDnsResponseWithResults(request, hostsResults)
}
if (forwardOnly) return@supervisorScope remote.await()
if (remoteDomainMatcher?.containsMatchIn(host) == true) return@supervisorScope remote.await()
val localResults = try {
withTimeout(TIMEOUT) { GlobalScope.async(Dispatchers.IO) { localResolver(host) }.await() }
Expand All @@ -118,14 +142,7 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
if (localResults.isEmpty()) return@supervisorScope remote.await()
if (localIpMatcher.isEmpty() || localIpMatcher.any { subnet -> localResults.any(subnet::matches) }) {
remote.cancel()
ByteBuffer.wrap(prepareDnsResponse(request).apply {
header.setFlag(Flags.RA.toInt()) // recursion available
for (address in localResults) addRecord(when (address) {
is Inet4Address -> ARecord(question.name, DClass.IN, TTL, address)
is Inet6Address -> AAAARecord(question.name, DClass.IN, TTL, address)
else -> throw IllegalStateException("Unsupported address $address")
}, Section.ANSWER)
}.toWire())
prepareDnsResponseWithResults(request, localResults)
} else remote.await()
} catch (e: Exception) {
remote.cancel()
Expand All @@ -142,6 +159,42 @@ class LocalDnsServer(private val localResolver: suspend (String) -> Array<InetAd
}
}

private fun hostsResolver(h: String): Array<InetAddress> =
this.hostsMap[h]?.toTypedArray() ?: emptyArray()

private fun readHosts(): Map<String, List<InetAddress>> {
val hostsMap: MutableMap<String, MutableSet<InetAddress>> = HashMap()
try {
val hostsFile = File(Core.deviceStorage.getExternalFilesDir(null), "hosts")
hostsFile.createNewFile()
hostsFile.forEachLine {
val line = it.substringBefore('#')
if (line.isEmpty()) {
return@forEachLine
}

val splitted = line.split(hostsDelimiter)
if (splitted.size < 2) {
return@forEachLine
}
val addr = splitted[0].parseNumericAddress() ?: return@forEachLine
for (j in 1 until splitted.size) {
val d = splitted[j]
var el = hostsMap[d]
if (el == null) {
el = HashSet(1)
hostsMap[d] = el
}
el.add(addr)
}
}
return hostsMap.mapValues { it.value.toList() }
} catch (e: IOException) {
printLog(e)
}
return emptyMap()
}

private suspend fun forward(packet: ByteBuffer): ByteBuffer {
packet.position(0) // the packet might have been parsed, reset to beginning
return if (tcp) SocketChannel.open().use { channel ->
Expand Down

0 comments on commit f32e85b

Please sign in to comment.