Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read kotlin metadata to resolve nullability of suspending functions #3544

Open
wants to merge 17 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
58 changes: 53 additions & 5 deletions retrofit/kotlin-test/src/test/java/retrofit2/KotlinSuspendTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ import okhttp3.mockwebserver.MockWebServer
import okhttp3.mockwebserver.SocketPolicy.DISCONNECT_AFTER_REQUEST
import okhttp3.mockwebserver.SocketPolicy.NO_RESPONSE
import org.assertj.core.api.Assertions.assertThat
import org.junit.Assert.assertTrue
import org.junit.Assert.fail
import org.junit.Ignore
import org.junit.Assert.*
import org.junit.Rule
import org.junit.Test
import retrofit2.helpers.ToStringConverterFactory
import retrofit2.http.GET
import retrofit2.http.HEAD
import retrofit2.http.Path
import retrofit2.http.Query
import java.io.IOException
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
Expand All @@ -43,9 +42,23 @@ import kotlin.coroutines.CoroutineContext
class KotlinSuspendTest {
@get:Rule val server = MockWebServer()

interface Service {
interface SuperService {
@GET("/") suspend fun noBody(@Query("x") arg: Long)
}

interface Service : SuperService {
@GET("/") suspend fun body(): String
@GET("/") suspend fun bodyNullable(): String?
@GET("/") suspend fun noBody()
@GET("/") suspend fun noBody(@Query("x") arg: String)
@GET("/") suspend fun noBody(@Query("x") arg: Int)
@GET("/") suspend fun noBody(@Query("x") arg: Array<Int>)
@GET("/") suspend fun noBody(@Query("x") arg: Array<String>)
@GET("/") suspend fun noBody(@Query("x") arg: IntArray)

@UseExperimental(ExperimentalUnsignedTypes::class)
@GET("/") suspend fun noBody(@Query("x") arg: UInt)

@GET("/") suspend fun response(): Response<String>
@GET("/") suspend fun unit()
@HEAD("/") suspend fun headUnit()
Expand Down Expand Up @@ -124,7 +137,6 @@ class KotlinSuspendTest {
}
}

@Ignore("Not working yet")
@Test fun bodyNullable() {
val retrofit = Retrofit.Builder()
.baseUrl(server.url("/"))
Expand All @@ -138,6 +150,42 @@ class KotlinSuspendTest {
assertThat(body).isNull()
}

@Test fun noBody() {
val retrofit = Retrofit.Builder()
.baseUrl(server.url("/"))
.addConverterFactory(ToStringConverterFactory())
.build()
val example = retrofit.create(Service::class.java)

server.enqueue(MockResponse().setResponseCode(204))

val body = runBlocking { example.noBody(intArrayOf(1)) }
assertThat(body).isEqualTo(Unit)
}

@Test fun signatureMatch() {
val retrofit = Retrofit.Builder()
.baseUrl(server.url("/"))
.addConverterFactory(ToStringConverterFactory())
.build()
val example = retrofit.create(Service::class.java)

repeat(8) {
server.enqueue(MockResponse())
}

runBlocking {
example.noBody()
example.noBody("")
example.noBody(1)
example.noBody(arrayOf(1))
example.noBody(intArrayOf(1))
example.noBody(arrayOf(""))
example.noBody(1u)
example.noBody(1L)
}
}

@Test fun response() {
val retrofit = Retrofit.Builder()
.baseUrl(server.url("/"))
Expand Down
5 changes: 1 addition & 4 deletions retrofit/src/main/java/retrofit2/HttpServiceMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@ static <ResponseT, ReturnT> HttpServiceMethod<ResponseT, ReturnT> parseAnnotatio
continuationWantsResponse = true;
} else {
continuationIsUnit = Utils.isUnit(responseType);
// TODO figure out if type is nullable or not
// Metadata metadata = method.getDeclaringClass().getAnnotation(Metadata.class)
// Find the entry for method
// Determine if return type is nullable or not
continuationBodyNullable = KotlinMetadata.isReturnTypeNullable(method);
}

adapterType = new Utils.ParameterizedTypeImpl(null, Call.class, responseType);
Expand Down
144 changes: 144 additions & 0 deletions retrofit/src/main/java/retrofit2/KotlinMetadata.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Copyright (C) 2021 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package retrofit2

import retrofit2.kotlin.metadata.deserialization.BitEncoding
import retrofit2.kotlin.metadata.deserialization.ByteArrayInput
import retrofit2.kotlin.metadata.deserialization.JvmMetadataVersion
import retrofit2.kotlin.metadata.deserialization.MetadataParser
import retrofit2.kotlin.metadata.deserialization.ProtobufReader
import java.lang.reflect.Method
import java.util.concurrent.ConcurrentHashMap
import kotlin.coroutines.Continuation

object KotlinMetadata {

data class Function(val signature: String, val returnType: ReturnType)
data class ReturnType(val isNullable: Boolean, val isUnit: Boolean)

private val kotlinFunctionsMap = ConcurrentHashMap<Class<*>, List<Function>>()

/**
* This helps to parse kotlin metadata of a compiled class to find out the nullability of a suspending method return
* type.
*
* For example a suspending method with following declaration:
*
* ```
* @GET("/") suspend fun foo(@Query("x") arg: IntArray): String?
* ```
*
* Will be compiled as a method returning [Object] and with injected [Continuation] argument with following java
* method:
*
* ```
* public Object foo(int[], Continuation)
* ```
*
* The information about the return type and its nullability is stored in a [Metadata] annotation of the containing
* class. We process the metadata of a class the first time [isReturnTypeNullable] is called on one of its methods.
* We extract necessary information about all of its methods and store this info in cache, so each class is
* processed only once. Then we try to match the currently inspected [Method] to one extracted from metadata by
* comparing their signatures.
*
* We use the method signature because:
* - it uniquely identifies a method
* - it requires just comparing 2 strings
* - it is already stored in kotlin metadata
* - it is trivial to create one from java reflection's [Method] instance
*
* For example the previous method's signature would be:
*
* ```
* foo([ILkotlin/coroutines/Continuation;)Ljava/lang/Object;
* ```
*/
@JvmStatic fun isReturnTypeNullable(method: Method): Boolean {
if (method.declaringClass.getAnnotation(Metadata::class.java) == null) return false

val javaMethodSignature = method.createSignature()
val kotlinFunctions = loadKotlinFunctions(method.declaringClass)
val candidates = kotlinFunctions.filter { it.signature == javaMethodSignature }

require(candidates.isNotEmpty()) { "No match found in metadata for '${method}'" }
require(candidates.size == 1) { "Multiple function matches found in metadata for '${method}'" }
val match = candidates.first()

return match.returnType.isNullable || match.returnType.isUnit
}

private fun Method.createSignature() = buildString {
append(name)
append('(')

parameterTypes.forEach {
append(it.typeToSignature())
}

append(')')

append(returnType.typeToSignature())
}

private fun loadKotlinFunctions(clazz: Class<*>): List<Function> {
var result = kotlinFunctionsMap[clazz]
if (result != null) return result

synchronized(kotlinFunctionsMap) {
result = kotlinFunctionsMap[clazz]
if (result == null) {
result = readFunctionsFromMetadata(clazz)
}
}

return result!!
}

private fun readFunctionsFromMetadata(clazz: Class<*>): List<Function> {
val metadataAnnotation = clazz.getAnnotation(Metadata::class.java)

val isStrictSemantics = (metadataAnnotation.extraInt and (1 shl 3)) != 0
val isCompatible = JvmMetadataVersion(metadataAnnotation.metadataVersion, isStrictSemantics).isCompatible()

require(isCompatible) { "Metadata version not compatible" }
require(metadataAnnotation.kind == 1) { "Metadata of wrong kind: ${metadataAnnotation.kind}" }
require(metadataAnnotation.data1.isNotEmpty()) { "data1 must not be empty" }

val bytes: ByteArray = BitEncoding.decodeBytes(metadataAnnotation.data1)
val reader = ProtobufReader(ByteArrayInput(bytes))
val parser = MetadataParser(reader, metadataAnnotation.data2)

return parser.parse()
}

private fun Class<*>.typeToSignature() = when {
isPrimitive -> javaTypesMap[name]
isArray -> name.replace('.', '/')
else -> "L${name.replace('.', '/')};"
}

private val javaTypesMap = mapOf(
"int" to "I",
"long" to "J",
"boolean" to "Z",
"byte" to "B",
"char" to "C",
"float" to "F",
"double" to "D",
"short" to "S",
"void" to "V"
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright (C) 2021 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package retrofit2.kotlin.metadata.deserialization

/**
* This file was adapted from https://github.com/JetBrains/kotlin/blob/af18b10da9d1e20b1b35831a3fb5e508048a2576/core/metadata/src/org/jetbrains/kotlin/metadata/deserialization/BinaryVersion.kt
* by removing unused parts.
*/

/**
* Subclasses of this class are used to identify different versions of the binary output of the compiler and their compatibility guarantees.
* - Major version should be increased only when the new binary format is neither forward- nor backward compatible.
* This shouldn't really ever happen at all.
* - Minor version should be increased when the new format is backward compatible,
* i.e. the new compiler can process old data, but the old compiler will not be able to process new data.
* - Patch version can be increased freely and is only supposed to be used for debugging. Increase the patch version when you
* make a change to binaries which is both forward- and backward compatible.
*/
abstract class BinaryVersion(private vararg val numbers: Int) {
val major: Int = numbers.getOrNull(0) ?: UNKNOWN
val minor: Int = numbers.getOrNull(1) ?: UNKNOWN
val patch: Int = numbers.getOrNull(2) ?: UNKNOWN
val rest: List<Int> = if (numbers.size > 3) {
if (numbers.size > MAX_LENGTH)
throw IllegalArgumentException("BinaryVersion with length more than $MAX_LENGTH are not supported. Provided length ${numbers.size}.")
else
numbers.asList().subList(3, numbers.size).toList()
} else emptyList()

abstract fun isCompatible(): Boolean

fun toArray(): IntArray = numbers

/**
* Returns true if this version of some format loaded from some binaries is compatible
* to the expected version of that format in the current compiler.
*
* @param ourVersion the version of this format in the current compiler
*/
protected fun isCompatibleTo(ourVersion: BinaryVersion): Boolean {
return if (major == 0) ourVersion.major == 0 && minor == ourVersion.minor
else major == ourVersion.major && minor <= ourVersion.minor
}

override fun toString(): String {
val versions = toArray().takeWhile { it != UNKNOWN }
return if (versions.isEmpty()) "unknown" else versions.joinToString(".")
}

override fun equals(other: Any?) =
other != null &&
this::class.java == other::class.java &&
major == (other as BinaryVersion).major && minor == other.minor && patch == other.patch && rest == other.rest

override fun hashCode(): Int {
var result = major
result += 31 * result + minor
result += 31 * result + patch
result += 31 * result + rest.hashCode()
return result
}

companion object {
const val MAX_LENGTH = 1024
private const val UNKNOWN = -1
}
}