From fa3941e6da60be5086f213408eb93b050a5e0aae Mon Sep 17 00:00:00 2001 From: Yuri Schimke Date: Sat, 24 Jun 2023 16:19:18 +0100 Subject: [PATCH 1/6] Suspend support in interceptors --- build.gradle.kts | 2 +- .../kotlin/okhttp3/SuspendingInterceptor.kt | 11 + .../jvmTest/kotlin/okhttp3/InterceptorTest.kt | 214 ++++++++++++++++++ .../src/jvmMain/kotlin/okhttp3/Interceptor.kt | 2 + .../internal/http/InterceptorCallFactory.kt | 25 ++ .../internal/http/RealInterceptorChain.kt | 3 + .../java/okhttp3/KotlinSourceModernTest.kt | 2 + 7 files changed, 258 insertions(+), 1 deletion(-) create mode 100644 okhttp-coroutines/src/jvmMain/kotlin/okhttp3/SuspendingInterceptor.kt create mode 100644 okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt create mode 100644 okhttp/src/jvmMain/kotlin/okhttp3/internal/http/InterceptorCallFactory.kt diff --git a/build.gradle.kts b/build.gradle.kts index 18e27d042731..278f0b25409a 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -139,7 +139,7 @@ subprojects { } } - val platform = System.getProperty("okhttp.platform", "jdk9") + val platform = System.getProperty("okhttp.platform", "bouncycastle") val testJavaVersion = System.getProperty("test.java.version", "11").toInt() val testRuntimeOnly: Configuration by configurations.getting diff --git a/okhttp-coroutines/src/jvmMain/kotlin/okhttp3/SuspendingInterceptor.kt b/okhttp-coroutines/src/jvmMain/kotlin/okhttp3/SuspendingInterceptor.kt new file mode 100644 index 000000000000..831862e05da0 --- /dev/null +++ b/okhttp-coroutines/src/jvmMain/kotlin/okhttp3/SuspendingInterceptor.kt @@ -0,0 +1,11 @@ +package okhttp3; + +import kotlinx.coroutines.runBlocking + +abstract class SuspendingInterceptor: Interceptor { + override fun intercept(chain: Interceptor.Chain): Response = runBlocking { + interceptAsync(chain) + } + + abstract suspend fun interceptAsync(chain: Interceptor.Chain): Response +} diff --git a/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt b/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt new file mode 100644 index 000000000000..2f598e6565c6 --- /dev/null +++ b/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt @@ -0,0 +1,214 @@ +/* + * Copyright (c) 2022 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +@file:OptIn(ExperimentalCoroutinesApi::class) + +package okhttp3 + +import assertk.assertThat +import assertk.assertions.isEqualTo +import java.util.concurrent.Executors +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import mockwebserver3.junit5.internal.MockWebServerExtension +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.extension.RegisterExtension + +@ExtendWith(MockWebServerExtension::class) +class InterceptorTest { + @RegisterExtension + val clientTestRule = OkHttpClientTestRule() + + private lateinit var server: MockWebServer + + val request by lazy { Request(server.url("/")) } + + @BeforeEach + fun setup(server: MockWebServer) { + this.server = server + } + + @Test + fun asyncCallTest() { + runTest { + server.enqueue(MockResponse(body = "failed", code = 401)) + server.enqueue(MockResponse(body = "token")) + server.enqueue(MockResponse(body = "abc")) + + val interceptor = Interceptor { + val response = it.proceed(it.request()) + + if (response.code == 401 && it.request().url.encodedPath != "/token") { + check(response.body.string() == "failed") + response.close() + + val tokenRequest = Request(server.url("/token")) + val call = it.callFactory.newCall(tokenRequest) + val token = runBlocking { + val tokenResponse = call.executeAsync() + withContext(Dispatchers.IO) { + tokenResponse.body.string() + } + } + + check(token == "token") + + val secondResponse = it.proceed( + it.request().newBuilder() + .header("Authorization", token) + .build() + ) + + secondResponse + } else { + response + } + } + + + val client = clientTestRule.newClientBuilder() + .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) + .addInterceptor(interceptor) + .build() + + val call = client.newCall(request) + + val tokenResponse = call.executeAsync() + val body = withContext(Dispatchers.IO) { + tokenResponse.body.string() + } + + assertThat(body).isEqualTo("abc") + } + } + + @Test + fun syncCallTest() { + server.enqueue(MockResponse(body = "failed", code = 401)) + server.enqueue(MockResponse(body = "token")) + server.enqueue(MockResponse(body = "abc")) + + val interceptor = Interceptor { + val response = it.proceed(it.request()) + + if (response.code == 401 && it.request().url.encodedPath != "/token") { + check(response.body.string() == "failed") + response.close() + + val tokenRequest = Request(server.url("/token")) + val call = it.callFactory.newCall(tokenRequest) + val token = if (false) + runBlocking { + val tokenResponse = call.executeAsync() + withContext(Dispatchers.IO) { + tokenResponse.body.string() + } + } + else + call.execute().body.string() + + check(token == "token") + + val secondResponse = it.proceed( + it.request().newBuilder() + .header("Authorization", token) + .build() + ) + + secondResponse + } else { + response + } + } + + + val client = clientTestRule.newClientBuilder() + .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) + .addInterceptor(interceptor) + .build() + + val call = client.newCall(request) + + val body = call.execute().body.string() + + assertThat(body).isEqualTo("abc") + + } + + @Test + fun asyncInterceptorCallTest() { + runTest { + server.enqueue(MockResponse(body = "failed", code = 401)) + server.enqueue(MockResponse(body = "token")) + server.enqueue(MockResponse(body = "abc")) + + val interceptor = object : SuspendingInterceptor() { + override suspend fun interceptAsync(it: Interceptor.Chain): Response { + val response = it.proceed(it.request()) + + return if (response.code == 401 && it.request().url.encodedPath != "/token") { + check(response.body.string() == "failed") + response.close() + + val tokenRequest = Request(server.url("/token")) + val call = it.callFactory.newCall(tokenRequest) + + val tokenResponse = call.executeAsync() + val token = withContext(Dispatchers.IO) { + tokenResponse.body.string() + } + + check(token == "token") + + val secondResponse = it.proceed( + it.request().newBuilder() + .header("Authorization", token) + .build() + ) + + secondResponse + } else { + response + } + } + } + + + val client = clientTestRule.newClientBuilder() + .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) + .addInterceptor(interceptor) + .build() + + val call = client.newCall(request) + + val tokenResponse = call.executeAsync() + val body = withContext(Dispatchers.IO) { + tokenResponse.body.string() + } + + assertThat(body).isEqualTo("abc") + } + } + +} diff --git a/okhttp/src/jvmMain/kotlin/okhttp3/Interceptor.kt b/okhttp/src/jvmMain/kotlin/okhttp3/Interceptor.kt index 9c8814d61460..8f3d2613a7c7 100644 --- a/okhttp/src/jvmMain/kotlin/okhttp3/Interceptor.kt +++ b/okhttp/src/jvmMain/kotlin/okhttp3/Interceptor.kt @@ -89,6 +89,8 @@ fun interface Interceptor { fun call(): Call + val callFactory: Call.Factory + fun connectTimeoutMillis(): Int fun withConnectTimeout(timeout: Int, unit: TimeUnit): Chain diff --git a/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/InterceptorCallFactory.kt b/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/InterceptorCallFactory.kt new file mode 100644 index 000000000000..e04b0a074893 --- /dev/null +++ b/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/InterceptorCallFactory.kt @@ -0,0 +1,25 @@ +package okhttp3.internal.http + +import okhttp3.Call +import okhttp3.Callback +import okhttp3.Request +import okio.IOException + +class InterceptorCallFactory(val delegate: Call.Factory) : Call.Factory { + override fun newCall(request: Request): Call { + return InterceptorCall(delegate.newCall(request)) + } +} + +class InterceptorCall(val delegate: Call): Call by delegate { + + override fun enqueue(responseCallback: Callback) { + try { + responseCallback.onResponse(this, delegate.execute()) + } catch (ioe: IOException) { + responseCallback.onFailure(this, ioe) + } + } + + override fun clone(): Call = InterceptorCall(delegate.clone()) +} diff --git a/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/RealInterceptorChain.kt b/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/RealInterceptorChain.kt index d7fe4e744cd8..2a3388af82e5 100644 --- a/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/RealInterceptorChain.kt +++ b/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/RealInterceptorChain.kt @@ -58,6 +58,9 @@ class RealInterceptorChain( override fun connection(): Connection? = exchange?.connection + override val callFactory: Call.Factory + get() = InterceptorCallFactory(call.client) + override fun connectTimeoutMillis(): Int = connectTimeoutMillis override fun withConnectTimeout(timeout: Int, unit: TimeUnit): Interceptor.Chain { diff --git a/okhttp/src/jvmTest/java/okhttp3/KotlinSourceModernTest.kt b/okhttp/src/jvmTest/java/okhttp3/KotlinSourceModernTest.kt index 6e020d10a847..a87c76c36bcb 100644 --- a/okhttp/src/jvmTest/java/okhttp3/KotlinSourceModernTest.kt +++ b/okhttp/src/jvmTest/java/okhttp3/KotlinSourceModernTest.kt @@ -1175,6 +1175,8 @@ class KotlinSourceModernTest { override fun proceed(request: Request): Response = TODO() override fun connection(): Connection? = TODO() override fun call(): Call = TODO() + override val callFactory: Call.Factory + get() = TODO() override fun connectTimeoutMillis(): Int = TODO() override fun withConnectTimeout(timeout: Int, unit: TimeUnit): Interceptor.Chain = TODO() override fun readTimeoutMillis(): Int = TODO() From 307e7290e8b1feb80f69960b837b9857cef097b8 Mon Sep 17 00:00:00 2001 From: Yuri Schimke Date: Sat, 24 Jun 2023 16:20:13 +0100 Subject: [PATCH 2/6] Suspend support in interceptors --- build.gradle.kts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build.gradle.kts b/build.gradle.kts index 278f0b25409a..18e27d042731 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -139,7 +139,7 @@ subprojects { } } - val platform = System.getProperty("okhttp.platform", "bouncycastle") + val platform = System.getProperty("okhttp.platform", "jdk9") val testJavaVersion = System.getProperty("test.java.version", "11").toInt() val testRuntimeOnly: Configuration by configurations.getting From 6cf69501ad088585a70836962627cf24427e46ee Mon Sep 17 00:00:00 2001 From: Yuri Schimke Date: Sat, 24 Jun 2023 19:10:21 +0100 Subject: [PATCH 3/6] Use fun interface --- .../kotlin/okhttp3/SuspendingInterceptor.kt | 4 +- .../jvmTest/kotlin/okhttp3/InterceptorTest.kt | 42 +++++++++---------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/okhttp-coroutines/src/jvmMain/kotlin/okhttp3/SuspendingInterceptor.kt b/okhttp-coroutines/src/jvmMain/kotlin/okhttp3/SuspendingInterceptor.kt index 831862e05da0..5b81b3abd773 100644 --- a/okhttp-coroutines/src/jvmMain/kotlin/okhttp3/SuspendingInterceptor.kt +++ b/okhttp-coroutines/src/jvmMain/kotlin/okhttp3/SuspendingInterceptor.kt @@ -2,10 +2,10 @@ package okhttp3; import kotlinx.coroutines.runBlocking -abstract class SuspendingInterceptor: Interceptor { +fun interface SuspendingInterceptor: Interceptor { override fun intercept(chain: Interceptor.Chain): Response = runBlocking { interceptAsync(chain) } - abstract suspend fun interceptAsync(chain: Interceptor.Chain): Response + suspend fun interceptAsync(chain: Interceptor.Chain): Response } diff --git a/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt b/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt index 2f598e6565c6..501d26c6f7f5 100644 --- a/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt +++ b/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt @@ -163,34 +163,32 @@ class InterceptorTest { server.enqueue(MockResponse(body = "token")) server.enqueue(MockResponse(body = "abc")) - val interceptor = object : SuspendingInterceptor() { - override suspend fun interceptAsync(it: Interceptor.Chain): Response { - val response = it.proceed(it.request()) + val interceptor = SuspendingInterceptor { + val response = it.proceed(it.request()) - return if (response.code == 401 && it.request().url.encodedPath != "/token") { - check(response.body.string() == "failed") - response.close() + if (response.code == 401 && it.request().url.encodedPath != "/token") { + check(response.body.string() == "failed") + response.close() - val tokenRequest = Request(server.url("/token")) - val call = it.callFactory.newCall(tokenRequest) + val tokenRequest = Request(server.url("/token")) + val call = it.callFactory.newCall(tokenRequest) - val tokenResponse = call.executeAsync() - val token = withContext(Dispatchers.IO) { - tokenResponse.body.string() - } + val tokenResponse = call.executeAsync() + val token = withContext(Dispatchers.IO) { + tokenResponse.body.string() + } - check(token == "token") + check(token == "token") - val secondResponse = it.proceed( - it.request().newBuilder() - .header("Authorization", token) - .build() - ) + val secondResponse = it.proceed( + it.request().newBuilder() + .header("Authorization", token) + .build() + ) - secondResponse - } else { - response - } + secondResponse + } else { + response } } From e01437edfef53042f874d10a4bffb38b243244c9 Mon Sep 17 00:00:00 2001 From: Yuri Schimke Date: Sat, 24 Jun 2023 19:45:14 +0100 Subject: [PATCH 4/6] Use fun interface --- .../jvmTest/kotlin/okhttp3/InterceptorTest.kt | 300 +++--- okhttp/api/okhttp.api | 1 + .../okhttp3/internal/connection/RealCall.kt | 916 +++++++++--------- .../internal/http/InterceptorCallFactory.kt | 4 +- 4 files changed, 643 insertions(+), 578 deletions(-) diff --git a/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt b/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt index 501d26c6f7f5..07dd9a5b2271 100644 --- a/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt +++ b/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt @@ -21,7 +21,6 @@ package okhttp3 import assertk.assertThat import assertk.assertions.isEqualTo -import java.util.concurrent.Executors import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.runBlocking @@ -34,179 +33,220 @@ import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtendWith import org.junit.jupiter.api.extension.RegisterExtension +import java.util.concurrent.Executors @ExtendWith(MockWebServerExtension::class) class InterceptorTest { - @RegisterExtension - val clientTestRule = OkHttpClientTestRule() + @RegisterExtension + val clientTestRule = OkHttpClientTestRule() - private lateinit var server: MockWebServer + private lateinit var server: MockWebServer - val request by lazy { Request(server.url("/")) } + val request by lazy { Request(server.url("/")) } - @BeforeEach - fun setup(server: MockWebServer) { - this.server = server - } + @BeforeEach + fun setup(server: MockWebServer) { + this.server = server - @Test - fun asyncCallTest() { - runTest { - server.enqueue(MockResponse(body = "failed", code = 401)) - server.enqueue(MockResponse(body = "token")) - server.enqueue(MockResponse(body = "abc")) + server.enqueue(MockResponse(body = "failed", code = 401)) + server.enqueue(MockResponse(body = "token")) + server.enqueue(MockResponse(body = "abc")) + } - val interceptor = Interceptor { - val response = it.proceed(it.request()) + @Test + fun asyncCallTest() { + runTest { + val interceptor = Interceptor { + val response = it.proceed(it.request()) + + if (response.code == 401 && it.request().url.encodedPath != "/token") { + check(response.body.string() == "failed") + response.close() + + val tokenRequest = Request(server.url("/token")) + val call = it.callFactory.newCall(tokenRequest) + val token = runBlocking { + val tokenResponse = call.executeAsync() + withContext(Dispatchers.IO) { + tokenResponse.body.string() + } + } + + check(token == "token") + + val secondResponse = it.proceed( + it.request().newBuilder() + .header("Authorization", token) + .build() + ) + + secondResponse + } else { + response + } + } - if (response.code == 401 && it.request().url.encodedPath != "/token") { - check(response.body.string() == "failed") - response.close() - val tokenRequest = Request(server.url("/token")) - val call = it.callFactory.newCall(tokenRequest) - val token = runBlocking { - val tokenResponse = call.executeAsync() - withContext(Dispatchers.IO) { - tokenResponse.body.string() - } - } + val client = clientTestRule.newClientBuilder() + .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) + .addInterceptor(interceptor) + .build() - check(token == "token") + val call = client.newCall(request) - val secondResponse = it.proceed( - it.request().newBuilder() - .header("Authorization", token) - .build() - ) + val tokenResponse = call.executeAsync() + val body = withContext(Dispatchers.IO) { + tokenResponse.body.string() + } - secondResponse - } else { - response + assertThat(body).isEqualTo("abc") } - } + } + @Test + fun syncCallTest() { + val interceptor = Interceptor { + val response = it.proceed(it.request()) + + if (response.code == 401 && it.request().url.encodedPath != "/token") { + check(response.body.string() == "failed") + response.close() + + val tokenRequest = Request(server.url("/token")) + val call = it.callFactory.newCall(tokenRequest) + val token = if (false) + runBlocking { + val tokenResponse = call.executeAsync() + withContext(Dispatchers.IO) { + tokenResponse.body.string() + } + } + else + call.execute().body.string() + + check(token == "token") + + val secondResponse = it.proceed( + it.request().newBuilder() + .header("Authorization", token) + .build() + ) + + secondResponse + } else { + response + } + } - val client = clientTestRule.newClientBuilder() - .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) - .addInterceptor(interceptor) - .build() + val client = clientTestRule.newClientBuilder() + .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) + .addInterceptor(interceptor) + .build() - val call = client.newCall(request) + val call = client.newCall(request) - val tokenResponse = call.executeAsync() - val body = withContext(Dispatchers.IO) { - tokenResponse.body.string() - } + val body = call.execute().body.string() - assertThat(body).isEqualTo("abc") + assertThat(body).isEqualTo("abc") } - } - - @Test - fun syncCallTest() { - server.enqueue(MockResponse(body = "failed", code = 401)) - server.enqueue(MockResponse(body = "token")) - server.enqueue(MockResponse(body = "abc")) - - val interceptor = Interceptor { - val response = it.proceed(it.request()) - if (response.code == 401 && it.request().url.encodedPath != "/token") { - check(response.body.string() == "failed") - response.close() + @Test + fun asyncInterceptorCallTest() { + runTest { + val interceptor = SuspendingInterceptor { + val response = it.proceed(it.request()) - val tokenRequest = Request(server.url("/token")) - val call = it.callFactory.newCall(tokenRequest) - val token = if (false) - runBlocking { - val tokenResponse = call.executeAsync() - withContext(Dispatchers.IO) { - tokenResponse.body.string() - } - } - else - call.execute().body.string() + if (response.code == 401 && it.request().url.encodedPath != "/token") { + check(response.body.string() == "failed") + response.close() - check(token == "token") + val tokenRequest = Request(server.url("/token")) + val call = it.callFactory.newCall(tokenRequest) - val secondResponse = it.proceed( - it.request().newBuilder() - .header("Authorization", token) - .build() - ) + val tokenResponse = call.executeAsync() + val token = withContext(Dispatchers.IO) { + tokenResponse.body.string() + } - secondResponse - } else { - response - } - } + check(token == "token") + val secondResponse = it.proceed( + it.request().newBuilder() + .header("Authorization", token) + .build() + ) - val client = clientTestRule.newClientBuilder() - .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) - .addInterceptor(interceptor) - .build() + secondResponse + } else { + response + } + } - val call = client.newCall(request) - val body = call.execute().body.string() + val client = clientTestRule.newClientBuilder() + .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) + .addInterceptor(interceptor) + .build() - assertThat(body).isEqualTo("abc") + val call = client.newCall(request) - } + val tokenResponse = call.executeAsync() + val body = withContext(Dispatchers.IO) { + tokenResponse.body.string() + } - @Test - fun asyncInterceptorCallTest() { - runTest { - server.enqueue(MockResponse(body = "failed", code = 401)) - server.enqueue(MockResponse(body = "token")) - server.enqueue(MockResponse(body = "abc")) + assertThat(body).isEqualTo("abc") + } + } - val interceptor = SuspendingInterceptor { - val response = it.proceed(it.request()) + @Test + fun asyncInterceptorCallByThreadTest() { + lateinit var client: OkHttpClient - if (response.code == 401 && it.request().url.encodedPath != "/token") { - check(response.body.string() == "failed") - response.close() + runTest { + val interceptor = SuspendingInterceptor { + val response = it.proceed(it.request()) - val tokenRequest = Request(server.url("/token")) - val call = it.callFactory.newCall(tokenRequest) + if (response.code == 401 && it.request().url.encodedPath != "/token") { + check(response.body.string() == "failed") + response.close() - val tokenResponse = call.executeAsync() - val token = withContext(Dispatchers.IO) { - tokenResponse.body.string() - } + val tokenRequest = Request(server.url("/token")) + val call = client.newCall(tokenRequest) - check(token == "token") + val tokenResponse = call.executeAsync() + val token = withContext(Dispatchers.IO) { + tokenResponse.body.string() + } - val secondResponse = it.proceed( - it.request().newBuilder() - .header("Authorization", token) - .build() - ) + check(token == "token") - secondResponse - } else { - response - } - } + val secondResponse = it.proceed( + it.request().newBuilder() + .header("Authorization", token) + .build() + ) + secondResponse + } else { + response + } + } - val client = clientTestRule.newClientBuilder() - .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) - .addInterceptor(interceptor) - .build() + client = clientTestRule.newClientBuilder() + .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) + .addInterceptor(interceptor) + .build() - val call = client.newCall(request) + val call = client.newCall(request) - val tokenResponse = call.executeAsync() - val body = withContext(Dispatchers.IO) { - tokenResponse.body.string() - } + val tokenResponse = call.executeAsync() + val body = withContext(Dispatchers.IO) { + tokenResponse.body.string() + } - assertThat(body).isEqualTo("abc") + assertThat(body).isEqualTo("abc") + } } - } } diff --git a/okhttp/api/okhttp.api b/okhttp/api/okhttp.api index 0b2d964c7b5f..7bd540076c4d 100644 --- a/okhttp/api/okhttp.api +++ b/okhttp/api/okhttp.api @@ -762,6 +762,7 @@ public abstract interface class okhttp3/Interceptor$Chain { public abstract fun call ()Lokhttp3/Call; public abstract fun connectTimeoutMillis ()I public abstract fun connection ()Lokhttp3/Connection; + public abstract fun getCallFactory ()Lokhttp3/Call$Factory; public abstract fun proceed (Lokhttp3/Request;)Lokhttp3/Response; public abstract fun readTimeoutMillis ()I public abstract fun request ()Lokhttp3/Request; diff --git a/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RealCall.kt b/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RealCall.kt index 033ee5122348..721b7fe7d644 100644 --- a/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RealCall.kt +++ b/okhttp/src/jvmMain/kotlin/okhttp3/internal/connection/RealCall.kt @@ -43,6 +43,7 @@ import okhttp3.internal.cache.CacheInterceptor import okhttp3.internal.closeQuietly import okhttp3.internal.http.BridgeInterceptor import okhttp3.internal.http.CallServerInterceptor +import okhttp3.internal.http.InterceptorCall import okhttp3.internal.http.RealInterceptorChain import okhttp3.internal.http.RetryAndFollowUpInterceptor import okhttp3.internal.platform.Platform @@ -60,517 +61,540 @@ import okio.Timeout * canceling may break the entire connection. */ class RealCall( - val client: OkHttpClient, - /** The application's original request unadulterated by redirects or auth headers. */ - val originalRequest: Request, - val forWebSocket: Boolean + val client: OkHttpClient, + /** The application's original request unadulterated by redirects or auth headers. */ + val originalRequest: Request, + val forWebSocket: Boolean ) : Call, Cloneable { - private val connectionPool: RealConnectionPool = client.connectionPool.delegate + private val connectionPool: RealConnectionPool = client.connectionPool.delegate - internal val eventListener: EventListener = client.eventListenerFactory.create(this) + internal val eventListener: EventListener = client.eventListenerFactory.create(this) - private val timeout = object : AsyncTimeout() { - override fun timedOut() { - cancel() + private val timeout = object : AsyncTimeout() { + override fun timedOut() { + cancel() + } + }.apply { + timeout(client.callTimeoutMillis.toLong(), MILLISECONDS) } - }.apply { - timeout(client.callTimeoutMillis.toLong(), MILLISECONDS) - } - private val executed = AtomicBoolean() + private val executed = AtomicBoolean() - // These properties are only accessed by the thread executing the call. + // These properties are only accessed by the thread executing the call. - /** Initialized in [callStart]. */ - private var callStackTrace: Any? = null + /** Initialized in [callStart]. */ + private var callStackTrace: Any? = null - /** Finds an exchange to send the next request and receive the next response. */ - private var exchangeFinder: ExchangeFinder? = null + /** Finds an exchange to send the next request and receive the next response. */ + private var exchangeFinder: ExchangeFinder? = null - var connection: RealConnection? = null - private set - private var timeoutEarlyExit = false + var connection: RealConnection? = null + private set + private var timeoutEarlyExit = false - /** - * This is the same value as [exchange], but scoped to the execution of the network interceptors. - * The [exchange] field is assigned to null when its streams end, which may be before or after the - * network interceptors return. - */ - internal var interceptorScopedExchange: Exchange? = null - private set + /** + * This is the same value as [exchange], but scoped to the execution of the network interceptors. + * The [exchange] field is assigned to null when its streams end, which may be before or after the + * network interceptors return. + */ + internal var interceptorScopedExchange: Exchange? = null + private set - // These properties are guarded by this. They are typically only accessed by the thread executing - // the call, but they may be accessed by other threads for duplex requests. + // These properties are guarded by this. They are typically only accessed by the thread executing + // the call, but they may be accessed by other threads for duplex requests. - /** True if this call still has a request body open. */ - private var requestBodyOpen = false + /** True if this call still has a request body open. */ + private var requestBodyOpen = false - /** True if this call still has a response body open. */ - private var responseBodyOpen = false + /** True if this call still has a response body open. */ + private var responseBodyOpen = false - /** True if there are more exchanges expected for this call. */ - private var expectMoreExchanges = true + /** True if there are more exchanges expected for this call. */ + private var expectMoreExchanges = true - // These properties are accessed by canceling threads. Any thread can cancel a call, and once it's - // canceled it's canceled forever. + // These properties are accessed by canceling threads. Any thread can cancel a call, and once it's + // canceled it's canceled forever. - @Volatile private var canceled = false - @Volatile private var exchange: Exchange? = null - internal val plansToCancel = CopyOnWriteArrayList() + @Volatile + private var canceled = false + @Volatile + private var exchange: Exchange? = null + internal val plansToCancel = CopyOnWriteArrayList() - override fun timeout(): Timeout = timeout + override fun timeout(): Timeout = timeout - @SuppressWarnings("CloneDoesntCallSuperClone") // We are a final type & this saves clearing state. - override fun clone(): Call = RealCall(client, originalRequest, forWebSocket) + @SuppressWarnings("CloneDoesntCallSuperClone") // We are a final type & this saves clearing state. + override fun clone(): Call = RealCall(client, originalRequest, forWebSocket) - override fun request(): Request = originalRequest + override fun request(): Request = originalRequest - /** - * Immediately closes the socket connection if it's currently held. Use this to interrupt an - * in-flight request from any thread. It's the caller's responsibility to close the request body - * and response body streams; otherwise resources may be leaked. - * - * This method is safe to be called concurrently, but provides limited guarantees. If a transport - * layer connection has been established (such as a HTTP/2 stream) that is terminated. Otherwise - * if a socket connection is being established, that is terminated. - */ - override fun cancel() { - if (canceled) return // Already canceled. + /** + * Immediately closes the socket connection if it's currently held. Use this to interrupt an + * in-flight request from any thread. It's the caller's responsibility to close the request body + * and response body streams; otherwise resources may be leaked. + * + * This method is safe to be called concurrently, but provides limited guarantees. If a transport + * layer connection has been established (such as a HTTP/2 stream) that is terminated. Otherwise + * if a socket connection is being established, that is terminated. + */ + override fun cancel() { + if (canceled) return // Already canceled. - canceled = true - exchange?.cancel() - for (plan in plansToCancel) { - plan.cancel() - } + canceled = true + exchange?.cancel() + for (plan in plansToCancel) { + plan.cancel() + } - eventListener.canceled(this) - } + eventListener.canceled(this) + } - override fun isCanceled(): Boolean = canceled + override fun isCanceled(): Boolean = canceled - override fun execute(): Response { - check(executed.compareAndSet(false, true)) { "Already Executed" } + override fun execute(): Response { + check(executed.compareAndSet(false, true)) { "Already Executed" } - timeout.enter() - callStart() - try { - client.dispatcher.executed(this) - return getResponseWithInterceptorChain() - } finally { - client.dispatcher.finished(this) - } - } - - override fun enqueue(responseCallback: Callback) { - check(executed.compareAndSet(false, true)) { "Already Executed" } - - callStart() - client.dispatcher.enqueue(AsyncCall(responseCallback)) - } - - override fun isExecuted(): Boolean = executed.get() - - private fun callStart() { - this.callStackTrace = Platform.get().getStackTraceForCloseable("response.body().close()") - eventListener.callStart(this) - } - - @Throws(IOException::class) - internal fun getResponseWithInterceptorChain(): Response { - // Build a full stack of interceptors. - val interceptors = mutableListOf() - interceptors += client.interceptors - interceptors += RetryAndFollowUpInterceptor(client) - interceptors += BridgeInterceptor(client.cookieJar) - interceptors += CacheInterceptor(client.cache) - interceptors += ConnectInterceptor - if (!forWebSocket) { - interceptors += client.networkInterceptors - } - interceptors += CallServerInterceptor(forWebSocket) - - val chain = RealInterceptorChain( - call = this, - interceptors = interceptors, - index = 0, - exchange = null, - request = originalRequest, - connectTimeoutMillis = client.connectTimeoutMillis, - readTimeoutMillis = client.readTimeoutMillis, - writeTimeoutMillis = client.writeTimeoutMillis - ) - - var calledNoMoreExchanges = false - try { - val response = chain.proceed(originalRequest) - if (isCanceled()) { - response.closeQuietly() - throw IOException("Canceled") - } - return response - } catch (e: IOException) { - calledNoMoreExchanges = true - throw noMoreExchanges(e) as Throwable - } finally { - if (!calledNoMoreExchanges) { - noMoreExchanges(null) - } + timeout.enter() + callStart() + try { + client.dispatcher.executed(this) + return getResponseWithInterceptorChain() + } finally { + client.dispatcher.finished(this) + } } - } - - /** - * Prepare for a potential trip through all of this call's network interceptors. This prepares to - * find an exchange to carry the request. - * - * Note that an exchange will not be needed if the request is satisfied by the cache. - * - * @param newRoutePlanner true if this is not a retry and new routing can be performed. - */ - fun enterNetworkInterceptorExchange( - request: Request, - newRoutePlanner: Boolean, - chain: RealInterceptorChain, - ) { - check(interceptorScopedExchange == null) - - synchronized(this) { - check(!responseBodyOpen) { - "cannot make a new request because the previous response is still open: " + - "please call response.close()" - } - check(!requestBodyOpen) + + override fun enqueue(responseCallback: Callback) { + if (isOnDispatcherThread()) { + InterceptorCall(this).enqueue(responseCallback) + } else { + + check(executed.compareAndSet(false, true)) { "Already Executed" } + + callStart() + client.dispatcher.enqueue(AsyncCall(responseCallback)) + } } - if (newRoutePlanner) { - val routePlanner = RealRoutePlanner( - client, - createAddress(request.url), - this, - chain, - connectionListener = connectionPool.connectionListener - ) - this.exchangeFinder = when { - client.fastFallback -> FastFallbackExchangeFinder(routePlanner, client.taskRunner) - else -> SequentialExchangeFinder(routePlanner) - } + private fun isOnDispatcherThread(): Boolean = onCallThread.get() + + override fun isExecuted(): Boolean = executed.get() + + private fun callStart() { + this.callStackTrace = Platform.get().getStackTraceForCloseable("response.body().close()") + eventListener.callStart(this) } - } - - /** Finds a new or pooled connection to carry a forthcoming request and response. */ - internal fun initExchange(chain: RealInterceptorChain): Exchange { - synchronized(this) { - check(expectMoreExchanges) { "released" } - check(!responseBodyOpen) - check(!requestBodyOpen) + + @Throws(IOException::class) + internal fun getResponseWithInterceptorChain(): Response { + // Build a full stack of interceptors. + val interceptors = mutableListOf() + interceptors += client.interceptors + interceptors += RetryAndFollowUpInterceptor(client) + interceptors += BridgeInterceptor(client.cookieJar) + interceptors += CacheInterceptor(client.cache) + interceptors += ConnectInterceptor + if (!forWebSocket) { + interceptors += client.networkInterceptors + } + interceptors += CallServerInterceptor(forWebSocket) + + val chain = RealInterceptorChain( + call = this, + interceptors = interceptors, + index = 0, + exchange = null, + request = originalRequest, + connectTimeoutMillis = client.connectTimeoutMillis, + readTimeoutMillis = client.readTimeoutMillis, + writeTimeoutMillis = client.writeTimeoutMillis + ) + + var calledNoMoreExchanges = false + try { + val response = chain.proceed(originalRequest) + if (isCanceled()) { + response.closeQuietly() + throw IOException("Canceled") + } + return response + } catch (e: IOException) { + calledNoMoreExchanges = true + throw noMoreExchanges(e) as Throwable + } finally { + if (!calledNoMoreExchanges) { + noMoreExchanges(null) + } + } } - val exchangeFinder = this.exchangeFinder!! - val connection = exchangeFinder.find() - val codec = connection.newCodec(client, chain) - val result = Exchange(this, eventListener, exchangeFinder, codec) - this.interceptorScopedExchange = result - this.exchange = result - synchronized(this) { - this.requestBodyOpen = true - this.responseBodyOpen = true + /** + * Prepare for a potential trip through all of this call's network interceptors. This prepares to + * find an exchange to carry the request. + * + * Note that an exchange will not be needed if the request is satisfied by the cache. + * + * @param newRoutePlanner true if this is not a retry and new routing can be performed. + */ + fun enterNetworkInterceptorExchange( + request: Request, + newRoutePlanner: Boolean, + chain: RealInterceptorChain, + ) { + check(interceptorScopedExchange == null) + + synchronized(this) { + check(!responseBodyOpen) { + "cannot make a new request because the previous response is still open: " + + "please call response.close()" + } + check(!requestBodyOpen) + } + + if (newRoutePlanner) { + val routePlanner = RealRoutePlanner( + client, + createAddress(request.url), + this, + chain, + connectionListener = connectionPool.connectionListener + ) + this.exchangeFinder = when { + client.fastFallback -> FastFallbackExchangeFinder(routePlanner, client.taskRunner) + else -> SequentialExchangeFinder(routePlanner) + } + } } - if (canceled) throw IOException("Canceled") - return result - } - - fun acquireConnectionNoEvents(connection: RealConnection) { - connection.assertThreadHoldsLock() - - check(this.connection == null) - this.connection = connection - connection.calls.add(CallReference(this, callStackTrace)) - } - - /** - * Releases resources held with the request or response of [exchange]. This should be called when - * the request completes normally or when it fails due to an exception, in which case [e] should - * be non-null. - * - * If the exchange was canceled or timed out, this will wrap [e] in an exception that provides - * that additional context. Otherwise [e] is returned as-is. - */ - internal fun messageDone( - exchange: Exchange, - requestDone: Boolean, - responseDone: Boolean, - e: E - ): E { - if (exchange != this.exchange) return e // This exchange was detached violently! - - var bothStreamsDone = false - var callDone = false - synchronized(this) { - if (requestDone && requestBodyOpen || responseDone && responseBodyOpen) { - if (requestDone) requestBodyOpen = false - if (responseDone) responseBodyOpen = false - bothStreamsDone = !requestBodyOpen && !responseBodyOpen - callDone = !requestBodyOpen && !responseBodyOpen && !expectMoreExchanges - } + /** Finds a new or pooled connection to carry a forthcoming request and response. */ + internal fun initExchange(chain: RealInterceptorChain): Exchange { + synchronized(this) { + check(expectMoreExchanges) { "released" } + check(!responseBodyOpen) + check(!requestBodyOpen) + } + + val exchangeFinder = this.exchangeFinder!! + val connection = exchangeFinder.find() + val codec = connection.newCodec(client, chain) + val result = Exchange(this, eventListener, exchangeFinder, codec) + this.interceptorScopedExchange = result + this.exchange = result + synchronized(this) { + this.requestBodyOpen = true + this.responseBodyOpen = true + } + + if (canceled) throw IOException("Canceled") + return result } - if (bothStreamsDone) { - this.exchange = null - this.connection?.incrementSuccessCount() + fun acquireConnectionNoEvents(connection: RealConnection) { + connection.assertThreadHoldsLock() + + check(this.connection == null) + this.connection = connection + connection.calls.add(CallReference(this, callStackTrace)) } - if (callDone) { - return callDone(e) + /** + * Releases resources held with the request or response of [exchange]. This should be called when + * the request completes normally or when it fails due to an exception, in which case [e] should + * be non-null. + * + * If the exchange was canceled or timed out, this will wrap [e] in an exception that provides + * that additional context. Otherwise [e] is returned as-is. + */ + internal fun messageDone( + exchange: Exchange, + requestDone: Boolean, + responseDone: Boolean, + e: E + ): E { + if (exchange != this.exchange) return e // This exchange was detached violently! + + var bothStreamsDone = false + var callDone = false + synchronized(this) { + if (requestDone && requestBodyOpen || responseDone && responseBodyOpen) { + if (requestDone) requestBodyOpen = false + if (responseDone) responseBodyOpen = false + bothStreamsDone = !requestBodyOpen && !responseBodyOpen + callDone = !requestBodyOpen && !responseBodyOpen && !expectMoreExchanges + } + } + + if (bothStreamsDone) { + this.exchange = null + this.connection?.incrementSuccessCount() + } + + if (callDone) { + return callDone(e) + } + + return e } - return e - } + internal fun noMoreExchanges(e: IOException?): IOException? { + var callDone = false + synchronized(this) { + if (expectMoreExchanges) { + expectMoreExchanges = false + callDone = !requestBodyOpen && !responseBodyOpen + } + } + + if (callDone) { + return callDone(e) + } - internal fun noMoreExchanges(e: IOException?): IOException? { - var callDone = false - synchronized(this) { - if (expectMoreExchanges) { - expectMoreExchanges = false - callDone = !requestBodyOpen && !responseBodyOpen - } + return e } - if (callDone) { - return callDone(e) + /** + * Complete this call. This should be called once these properties are all false: + * [requestBodyOpen], [responseBodyOpen], and [expectMoreExchanges]. + * + * This will release the connection if it is still held. + * + * It will also notify the listener that the call completed; either successfully or + * unsuccessfully. + * + * If the call was canceled or timed out, this will wrap [e] in an exception that provides that + * additional context. Otherwise [e] is returned as-is. + */ + private fun callDone(e: E): E { + assertThreadDoesntHoldLock() + + val connection = this.connection + if (connection != null) { + connection.assertThreadDoesntHoldLock() + val toClose: Socket? = synchronized(connection) { + // Sets this.connection to null. + releaseConnectionNoEvents() + } + if (this.connection == null) { + toClose?.closeQuietly() + eventListener.connectionReleased(this, connection) + connection.connectionListener.connectionReleased(connection, this) + if (toClose != null) { + connection.connectionListener.connectionClosed(connection) + } + } else { + check(toClose == null) // If we still have a connection we shouldn't be closing any sockets. + } + } + + val result = timeoutExit(e) + if (e != null) { + eventListener.callFailed(this, result!!) + } else { + eventListener.callEnd(this) + } + return result } - return e - } - - /** - * Complete this call. This should be called once these properties are all false: - * [requestBodyOpen], [responseBodyOpen], and [expectMoreExchanges]. - * - * This will release the connection if it is still held. - * - * It will also notify the listener that the call completed; either successfully or - * unsuccessfully. - * - * If the call was canceled or timed out, this will wrap [e] in an exception that provides that - * additional context. Otherwise [e] is returned as-is. - */ - private fun callDone(e: E): E { - assertThreadDoesntHoldLock() - - val connection = this.connection - if (connection != null) { - connection.assertThreadDoesntHoldLock() - val toClose: Socket? = synchronized(connection) { - // Sets this.connection to null. - releaseConnectionNoEvents() - } - if (this.connection == null) { - toClose?.closeQuietly() - eventListener.connectionReleased(this, connection) - connection.connectionListener.connectionReleased(connection, this) - if (toClose != null) { - connection.connectionListener.connectionClosed(connection) + /** + * Remove this call from the connection's list of allocations. Returns a socket that the caller + * should close. + */ + internal fun releaseConnectionNoEvents(): Socket? { + val connection = this.connection!! + connection.assertThreadHoldsLock() + + val calls = connection.calls + val index = calls.indexOfFirst { it.get() == this@RealCall } + check(index != -1) + + calls.removeAt(index) + this.connection = null + + if (calls.isEmpty()) { + connection.idleAtNs = System.nanoTime() + if (connectionPool.connectionBecameIdle(connection)) { + return connection.socket() + } } - } else { - check(toClose == null) // If we still have a connection we shouldn't be closing any sockets. - } + + return null } - val result = timeoutExit(e) - if (e != null) { - eventListener.callFailed(this, result!!) - } else { - eventListener.callEnd(this) + private fun timeoutExit(cause: E): E { + if (timeoutEarlyExit) return cause + if (!timeout.exit()) return cause + + val e = InterruptedIOException("timeout") + if (cause != null) e.initCause(cause) + @Suppress("UNCHECKED_CAST") // E is either IOException or IOException? + return e as E } - return result - } - - /** - * Remove this call from the connection's list of allocations. Returns a socket that the caller - * should close. - */ - internal fun releaseConnectionNoEvents(): Socket? { - val connection = this.connection!! - connection.assertThreadHoldsLock() - - val calls = connection.calls - val index = calls.indexOfFirst { it.get() == this@RealCall } - check(index != -1) - - calls.removeAt(index) - this.connection = null - - if (calls.isEmpty()) { - connection.idleAtNs = System.nanoTime() - if (connectionPool.connectionBecameIdle(connection)) { - return connection.socket() - } + + /** + * Stops applying the timeout before the call is entirely complete. This is used for WebSockets + * and duplex calls where the timeout only applies to the initial setup. + */ + fun timeoutEarlyExit() { + check(!timeoutEarlyExit) + timeoutEarlyExit = true + timeout.exit() } - return null - } - - private fun timeoutExit(cause: E): E { - if (timeoutEarlyExit) return cause - if (!timeout.exit()) return cause - - val e = InterruptedIOException("timeout") - if (cause != null) e.initCause(cause) - @Suppress("UNCHECKED_CAST") // E is either IOException or IOException? - return e as E - } - - /** - * Stops applying the timeout before the call is entirely complete. This is used for WebSockets - * and duplex calls where the timeout only applies to the initial setup. - */ - fun timeoutEarlyExit() { - check(!timeoutEarlyExit) - timeoutEarlyExit = true - timeout.exit() - } - - /** - * @param closeExchange true if the current exchange should be closed because it will not be used. - * This is usually due to either an exception or a retry. - */ - internal fun exitNetworkInterceptorExchange(closeExchange: Boolean) { - synchronized(this) { - check(expectMoreExchanges) { "released" } + /** + * @param closeExchange true if the current exchange should be closed because it will not be used. + * This is usually due to either an exception or a retry. + */ + internal fun exitNetworkInterceptorExchange(closeExchange: Boolean) { + synchronized(this) { + check(expectMoreExchanges) { "released" } + } + + if (closeExchange) { + exchange?.detachWithViolence() + } + + interceptorScopedExchange = null } - if (closeExchange) { - exchange?.detachWithViolence() + private fun createAddress(url: HttpUrl): Address { + var sslSocketFactory: SSLSocketFactory? = null + var hostnameVerifier: HostnameVerifier? = null + var certificatePinner: CertificatePinner? = null + if (url.isHttps) { + sslSocketFactory = client.sslSocketFactory + hostnameVerifier = client.hostnameVerifier + certificatePinner = client.certificatePinner + } + + return Address( + uriHost = url.host, + uriPort = url.port, + dns = client.dns, + socketFactory = client.socketFactory, + sslSocketFactory = sslSocketFactory, + hostnameVerifier = hostnameVerifier, + certificatePinner = certificatePinner, + proxyAuthenticator = client.proxyAuthenticator, + proxy = client.proxy, + protocols = client.protocols, + connectionSpecs = client.connectionSpecs, + proxySelector = client.proxySelector + ) } - interceptorScopedExchange = null - } - - private fun createAddress(url: HttpUrl): Address { - var sslSocketFactory: SSLSocketFactory? = null - var hostnameVerifier: HostnameVerifier? = null - var certificatePinner: CertificatePinner? = null - if (url.isHttps) { - sslSocketFactory = client.sslSocketFactory - hostnameVerifier = client.hostnameVerifier - certificatePinner = client.certificatePinner + fun retryAfterFailure(): Boolean { + return exchange?.hasFailure == true && + exchangeFinder!!.routePlanner.hasNext(exchange?.connection) } - return Address( - uriHost = url.host, - uriPort = url.port, - dns = client.dns, - socketFactory = client.socketFactory, - sslSocketFactory = sslSocketFactory, - hostnameVerifier = hostnameVerifier, - certificatePinner = certificatePinner, - proxyAuthenticator = client.proxyAuthenticator, - proxy = client.proxy, - protocols = client.protocols, - connectionSpecs = client.connectionSpecs, - proxySelector = client.proxySelector - ) - } - - fun retryAfterFailure(): Boolean { - return exchange?.hasFailure == true && - exchangeFinder!!.routePlanner.hasNext(exchange?.connection) - } - - /** - * Returns a string that describes this call. Doesn't include a full URL as that might contain - * sensitive information. - */ - private fun toLoggableString(): String { - return ((if (isCanceled()) "canceled " else "") + - (if (forWebSocket) "web socket" else "call") + - " to " + redactedUrl()) - } - - internal fun redactedUrl(): String = originalRequest.url.redact() - - inner class AsyncCall( - private val responseCallback: Callback - ) : Runnable { - @Volatile var callsPerHost = AtomicInteger(0) - private set - - fun reuseCallsPerHostFrom(other: AsyncCall) { - this.callsPerHost = other.callsPerHost + /** + * Returns a string that describes this call. Doesn't include a full URL as that might contain + * sensitive information. + */ + private fun toLoggableString(): String { + return ((if (isCanceled()) "canceled " else "") + + (if (forWebSocket) "web socket" else "call") + + " to " + redactedUrl()) } - val host: String - get() = originalRequest.url.host + internal fun redactedUrl(): String = originalRequest.url.redact() - val request: Request - get() = originalRequest + inner class AsyncCall( + private val responseCallback: Callback + ) : Runnable { + @Volatile + var callsPerHost = AtomicInteger(0) + private set - val call: RealCall - get() = this@RealCall + fun reuseCallsPerHostFrom(other: AsyncCall) { + this.callsPerHost = other.callsPerHost + } - /** - * Attempt to enqueue this async call on [executorService]. This will attempt to clean up - * if the executor has been shut down by reporting the call as failed. - */ - fun executeOn(executorService: ExecutorService) { - client.dispatcher.assertThreadDoesntHoldLock() - - var success = false - try { - executorService.execute(this) - success = true - } catch (e: RejectedExecutionException) { - failRejected(e) - } finally { - if (!success) { - client.dispatcher.finished(this) // This call is no longer running! + val host: String + get() = originalRequest.url.host + + val request: Request + get() = originalRequest + + val call: RealCall + get() = this@RealCall + + /** + * Attempt to enqueue this async call on [executorService]. This will attempt to clean up + * if the executor has been shut down by reporting the call as failed. + */ + fun executeOn(executorService: ExecutorService) { + client.dispatcher.assertThreadDoesntHoldLock() + + var success = false + try { + executorService.execute { + onCallThread.set(true) + try { + this.run() + } finally { + onCallThread.set(false) + } + } + success = true + } catch (e: RejectedExecutionException) { + failRejected(e) + } finally { + if (!success) { + client.dispatcher.finished(this) // This call is no longer running! + } + } } - } - } - internal fun failRejected(e: RejectedExecutionException? = null) { - val ioException = InterruptedIOException("executor rejected") - ioException.initCause(e) - noMoreExchanges(ioException) - responseCallback.onFailure(this@RealCall, ioException) - } + internal fun failRejected(e: RejectedExecutionException? = null) { + val ioException = InterruptedIOException("executor rejected") + ioException.initCause(e) + noMoreExchanges(ioException) + responseCallback.onFailure(this@RealCall, ioException) + } - override fun run() { - threadName("OkHttp ${redactedUrl()}") { - var signalledCallback = false - timeout.enter() - try { - val response = getResponseWithInterceptorChain() - signalledCallback = true - responseCallback.onResponse(this@RealCall, response) - } catch (e: IOException) { - if (signalledCallback) { - // Do not signal the callback twice! - Platform.get().log("Callback failure for ${toLoggableString()}", Platform.INFO, e) - } else { - responseCallback.onFailure(this@RealCall, e) - } - } catch (t: Throwable) { - cancel() - if (!signalledCallback) { - val canceledException = IOException("canceled due to $t") - canceledException.addSuppressed(t) - responseCallback.onFailure(this@RealCall, canceledException) - } - throw t - } finally { - client.dispatcher.finished(this) + override fun run() { + threadName("OkHttp ${redactedUrl()}") { + var signalledCallback = false + timeout.enter() + try { + val response = getResponseWithInterceptorChain() + signalledCallback = true + responseCallback.onResponse(this@RealCall, response) + } catch (e: IOException) { + if (signalledCallback) { + // Do not signal the callback twice! + Platform.get().log("Callback failure for ${toLoggableString()}", Platform.INFO, e) + } else { + responseCallback.onFailure(this@RealCall, e) + } + } catch (t: Throwable) { + cancel() + if (!signalledCallback) { + val canceledException = IOException("canceled due to $t") + canceledException.addSuppressed(t) + responseCallback.onFailure(this@RealCall, canceledException) + } + throw t + } finally { + client.dispatcher.finished(this) + } + } } - } } - } - internal class CallReference( - referent: RealCall, - /** - * Captures the stack trace at the time the Call is executed or enqueued. This is helpful for - * identifying the origin of connection leaks. - */ - val callStackTrace: Any? - ) : WeakReference(referent) + internal class CallReference( + referent: RealCall, + /** + * Captures the stack trace at the time the Call is executed or enqueued. This is helpful for + * identifying the origin of connection leaks. + */ + val callStackTrace: Any? + ) : WeakReference(referent) + + companion object { + val onCallThread = object : ThreadLocal() { + override fun initialValue(): Boolean = false + } + } } diff --git a/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/InterceptorCallFactory.kt b/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/InterceptorCallFactory.kt index e04b0a074893..6c1ca85bf479 100644 --- a/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/InterceptorCallFactory.kt +++ b/okhttp/src/jvmMain/kotlin/okhttp3/internal/http/InterceptorCallFactory.kt @@ -15,9 +15,9 @@ class InterceptorCall(val delegate: Call): Call by delegate { override fun enqueue(responseCallback: Callback) { try { - responseCallback.onResponse(this, delegate.execute()) + responseCallback.onResponse(delegate, delegate.execute()) } catch (ioe: IOException) { - responseCallback.onFailure(this, ioe) + responseCallback.onFailure(delegate, ioe) } } From db9ab939a69811280569664f9ef19cbaa717e3b8 Mon Sep 17 00:00:00 2001 From: Yuri Schimke Date: Sat, 24 Jun 2023 19:55:35 +0100 Subject: [PATCH 5/6] Use fun interface --- okhttp-coroutines/api/okhttp-coroutines.api | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/okhttp-coroutines/api/okhttp-coroutines.api b/okhttp-coroutines/api/okhttp-coroutines.api index 4261ec3bb4b7..8eba75877f10 100644 --- a/okhttp-coroutines/api/okhttp-coroutines.api +++ b/okhttp-coroutines/api/okhttp-coroutines.api @@ -2,3 +2,8 @@ public final class okhttp3/JvmCallExtensionsKt { public static final fun executeAsync (Lokhttp3/Call;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } +public abstract interface class okhttp3/SuspendingInterceptor : okhttp3/Interceptor { + public fun intercept (Lokhttp3/Interceptor$Chain;)Lokhttp3/Response; + public abstract fun interceptAsync (Lokhttp3/Interceptor$Chain;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + From 87cfbca008466c47d87c5e0235a1ff4163db85d5 Mon Sep 17 00:00:00 2001 From: Yuri Schimke Date: Fri, 30 Jun 2023 19:00:41 +0100 Subject: [PATCH 6/6] Enable tls --- .../jvmTest/kotlin/okhttp3/InterceptorTest.kt | 42 +++++++++++++++++-- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt b/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt index 07dd9a5b2271..f6caf4675e91 100644 --- a/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt +++ b/okhttp-coroutines/src/jvmTest/kotlin/okhttp3/InterceptorTest.kt @@ -29,10 +29,13 @@ import kotlinx.coroutines.withContext import mockwebserver3.MockResponse import mockwebserver3.MockWebServer import mockwebserver3.junit5.internal.MockWebServerExtension +import okhttp3.tls.internal.TlsUtil import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtendWith import org.junit.jupiter.api.extension.RegisterExtension +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource import java.util.concurrent.Executors @ExtendWith(MockWebServerExtension::class) @@ -42,12 +45,33 @@ class InterceptorTest { private lateinit var server: MockWebServer + private val clientBuilder: OkHttpClient.Builder = clientTestRule.newClientBuilder() + val request by lazy { Request(server.url("/")) } + // TODO parameterize + val tls: Boolean = true + @BeforeEach fun setup(server: MockWebServer) { this.server = server + clientBuilder.addInterceptor { chain -> + chain.proceed(chain.request()).also { + if (tls) { + check(it.protocol == Protocol.HTTP_2) + check(it.handshake != null) + } else { + check(it.protocol == Protocol.HTTP_1_1) + check(it.handshake == null) + } + } + } + + if (tls) { + enableTls() + } + server.enqueue(MockResponse(body = "failed", code = 401)) server.enqueue(MockResponse(body = "token")) server.enqueue(MockResponse(body = "abc")) @@ -87,7 +111,7 @@ class InterceptorTest { } - val client = clientTestRule.newClientBuilder() + val client = clientBuilder .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) .addInterceptor(interceptor) .build() @@ -138,7 +162,7 @@ class InterceptorTest { } } - val client = clientTestRule.newClientBuilder() + val client = clientBuilder .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) .addInterceptor(interceptor) .build() @@ -183,7 +207,7 @@ class InterceptorTest { } - val client = clientTestRule.newClientBuilder() + val client = clientBuilder .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) .addInterceptor(interceptor) .build() @@ -233,7 +257,7 @@ class InterceptorTest { } } - client = clientTestRule.newClientBuilder() + client = clientBuilder .dispatcher(Dispatcher(Executors.newSingleThreadExecutor())) .addInterceptor(interceptor) .build() @@ -249,4 +273,14 @@ class InterceptorTest { } } + private fun enableTls() { + val certs = TlsUtil.localhost() + + clientBuilder + .sslSocketFactory( + certs.sslSocketFactory(), certs.trustManager + ) + + server.useHttps(certs.sslSocketFactory()) + } }