diff --git a/okhttp-dnsoverhttps/src/main/kotlin/okhttp3/dnsoverhttps/DnsOverHttps.kt b/okhttp-dnsoverhttps/src/main/kotlin/okhttp3/dnsoverhttps/DnsOverHttps.kt index cc3872e5e194..726643b872be 100644 --- a/okhttp-dnsoverhttps/src/main/kotlin/okhttp3/dnsoverhttps/DnsOverHttps.kt +++ b/okhttp-dnsoverhttps/src/main/kotlin/okhttp3/dnsoverhttps/DnsOverHttps.kt @@ -16,11 +16,9 @@ package okhttp3.dnsoverhttps import java.io.IOException -import java.net.HttpURLConnection import java.net.InetAddress import java.net.UnknownHostException import java.util.concurrent.CountDownLatch -import okhttp3.CacheControl import okhttp3.Call import okhttp3.Callback import okhttp3.Dns @@ -72,16 +70,17 @@ class DnsOverHttps internal constructor( @Throws(UnknownHostException::class) private fun lookupHttps(hostname: String): List { - val networkRequests = ArrayList(2) - val failures = ArrayList(2) - val results = ArrayList(5) + val networkRequests = + buildList { + add(client.newCall(buildRequest(hostname, DnsRecordCodec.TYPE_A))) - buildRequest(hostname, networkRequests, results, failures, DnsRecordCodec.TYPE_A) - - if (includeIPv6) { - buildRequest(hostname, networkRequests, results, failures, DnsRecordCodec.TYPE_AAAA) - } + if (includeIPv6) { + add(client.newCall(buildRequest(hostname, DnsRecordCodec.TYPE_AAAA))) + } + } + val failures = ArrayList(2) + val results = ArrayList(5) executeRequests(hostname, networkRequests, results, failures) return results.ifEmpty { @@ -89,21 +88,6 @@ class DnsOverHttps internal constructor( } } - private fun buildRequest( - hostname: String, - networkRequests: MutableList, - results: MutableList, - failures: MutableList, - type: Int, - ) { - val request = buildRequest(hostname, type) - val response = getCacheOnlyResponse(request) - - response?.let { processResponse(it, hostname, results, failures) } ?: networkRequests.add( - client.newCall(request), - ) - } - private fun executeRequests( hostname: String, networkRequests: List, @@ -186,38 +170,6 @@ class DnsOverHttps internal constructor( throw unknownHostException } - private fun getCacheOnlyResponse(request: Request): Response? { - if (client.cache != null) { - try { - // Use the cache without hitting the network first - // 504 code indicates that the Cache is stale - val onlyIfCached = - CacheControl.Builder() - .onlyIfCached() - .build() - - var cacheUrl = request.url - - val cacheRequest = - request.newBuilder() - .cacheControl(onlyIfCached) - .cacheUrlOverride(cacheUrl) - .build() - - val cacheResponse = client.newCall(cacheRequest).execute() - - if (cacheResponse.code != HttpURLConnection.HTTP_GATEWAY_TIMEOUT) { - return cacheResponse - } - } catch (ioe: IOException) { - // Failures are ignored as we can fallback to the network - // and hopefully repopulate the cache. - } - } - - return null - } - @Throws(Exception::class) private fun readResponse( hostname: String, diff --git a/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.kt b/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.kt index bfe84cb34800..34d602b36fe9 100644 --- a/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.kt +++ b/okhttp-dnsoverhttps/src/test/java/okhttp3/dnsoverhttps/DnsOverHttpsTest.kt @@ -35,6 +35,7 @@ import okhttp3.Cache import okhttp3.Dns import okhttp3.OkHttpClient import okhttp3.Protocol +import okhttp3.RecordingEventListener import okhttp3.testing.PlatformRule import okio.Buffer import okio.ByteString.Companion.decodeHex @@ -53,9 +54,11 @@ class DnsOverHttpsTest { private lateinit var server: MockWebServer private lateinit var dns: Dns private val cacheFs = FakeFileSystem() + private val eventListener = RecordingEventListener() private val bootstrapClient = OkHttpClient.Builder() .protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1)) + .eventListener(eventListener) .build() @BeforeEach @@ -194,16 +197,22 @@ class DnsOverHttpsTest { assertThat(recordedRequest.path) .isEqualTo("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ") + assertThat(cacheEvents()).containsExactly("CacheMiss") + result = cachedDns.lookup("google.com") assertThat(server.takeRequest(1, TimeUnit.MILLISECONDS)).isNull() assertThat(result).isEqualTo(listOf(address("157.240.1.18"))) + assertThat(cacheEvents()).containsExactly("CacheHit") + result = cachedDns.lookup("www.google.com") assertThat(result).containsExactly(address("157.240.1.18")) recordedRequest = server.takeRequest() assertThat(recordedRequest.method).isEqualTo("GET") assertThat(recordedRequest.path) .isEqualTo("/lookup?ct&dns=AAABAAABAAAAAAAAA3d3dwZnb29nbGUDY29tAAABAAE") + + assertThat(cacheEvents()).containsExactly("CacheMiss") } @Test @@ -231,16 +240,22 @@ class DnsOverHttpsTest { assertThat(recordedRequest.path) .isEqualTo("/lookup?ct") + assertThat(cacheEvents()).containsExactly("CacheMiss") + result = cachedDns.lookup("google.com") assertThat(server.takeRequest(0, TimeUnit.MILLISECONDS)).isNull() assertThat(result).isEqualTo(listOf(address("157.240.1.18"))) + assertThat(cacheEvents()).containsExactly("CacheHit") + result = cachedDns.lookup("www.google.com") assertThat(result).containsExactly(address("157.240.1.18")) recordedRequest = server.takeRequest(0, TimeUnit.MILLISECONDS)!! assertThat(recordedRequest.method).isEqualTo("POST") assertThat(recordedRequest.path) .isEqualTo("/lookup?ct") + + assertThat(cacheEvents()).containsExactly("CacheMiss") } @Test @@ -265,6 +280,9 @@ class DnsOverHttpsTest { assertThat(recordedRequest.path).isEqualTo( "/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ", ) + + assertThat(cacheEvents()).containsExactly("CacheMiss") + Thread.sleep(2000) server.enqueue( dnsResponse( @@ -282,6 +300,14 @@ class DnsOverHttpsTest { assertThat(recordedRequest!!.method).isEqualTo("GET") assertThat(recordedRequest.path) .isEqualTo("/lookup?ct&dns=AAABAAABAAAAAAAABmdvb2dsZQNjb20AAAEAAQ") + + assertThat(cacheEvents()).containsExactly("CacheMiss") + } + + private fun cacheEvents(): List { + return eventListener.recordedEventTypes().filter { it.contains("Cache") }.also { + eventListener.clearAllEvents() + } } private fun dnsResponse(s: String): MockResponse {