Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,16 @@ import com.squareup.kotlinpoet.buildCodeBlock
import com.squareup.kotlinpoet.joinToCode
import com.squareup.kotlinpoet.jvm.jvmField
import com.squareup.kotlinpoet.jvm.jvmStatic
import com.squareup.wire.internal.DoubleArrayList
import com.squareup.wire.EnumAdapter
import com.squareup.wire.FieldEncoding
import com.squareup.wire.internal.FloatArrayList
import com.squareup.wire.GrpcCall
import com.squareup.wire.GrpcClient
import com.squareup.wire.GrpcMethod
import com.squareup.wire.GrpcStreamingCall
import com.squareup.wire.internal.IntArrayList
import com.squareup.wire.internal.LongArrayList
import com.squareup.wire.Message
import com.squareup.wire.MessageSink
import com.squareup.wire.MessageSource
Expand Down Expand Up @@ -156,6 +160,47 @@ class KotlinGenerator private constructor(
get() = type.typeName
private val Service.serviceName
get() = type.typeName
private val Field.primitiveArrayClassForType
get() = when (type!!.typeName) {
LONG -> LongArray::class
INT -> IntArray::class
FLOAT -> FloatArray::class
DOUBLE -> DoubleArray::class
else -> throw IllegalArgumentException("No Array type for $type")
}
private val Field.emptyPrimitiveArrayForType
get() = when (type!!.typeName) {
LONG -> CodeBlock.of("longArrayOf()")
INT -> CodeBlock.of("intArrayOf()")
FLOAT -> CodeBlock.of("floatArrayOf()")
DOUBLE -> CodeBlock.of("doubleArrayOf()")
else -> throw IllegalArgumentException("No Array type for $type")
}
private val Field.arrayListClassForType
get() = when (type!!.typeName) {
LONG -> LongArrayList::class
INT -> IntArrayList::class
FLOAT -> FloatArrayList::class
DOUBLE -> DoubleArrayList::class
else -> throw IllegalArgumentException("No ArrayList type for $type")
}

private val Field.arrayAdapterForType
get() = when (type!!) {
ProtoType.INT32 -> CodeBlock.of("%T.INT32_ARRAY", ProtoAdapter::class)
ProtoType.UINT32 -> CodeBlock.of("%T.UINT32_ARRAY", ProtoAdapter::class)
ProtoType.SINT32 -> CodeBlock.of("%T.SINT32_ARRAY", ProtoAdapter::class)
ProtoType.FIXED32 -> CodeBlock.of("%T.FIXED32_ARRAY", ProtoAdapter::class)
ProtoType.SFIXED32 -> CodeBlock.of("%T.SFIXED32_ARRAY", ProtoAdapter::class)
ProtoType.INT64 -> CodeBlock.of("%T.INT64_ARRAY", ProtoAdapter::class)
ProtoType.UINT64 -> CodeBlock.of("%T.UINT64_ARRAY", ProtoAdapter::class)
ProtoType.SINT64 -> CodeBlock.of("%T.SINT64_ARRAY", ProtoAdapter::class)
ProtoType.FIXED64 -> CodeBlock.of("%T.FIXED64_ARRAY", ProtoAdapter::class)
ProtoType.SFIXED64 -> CodeBlock.of("%T.SFIXED64_ARRAY", ProtoAdapter::class)
ProtoType.FLOAT -> CodeBlock.of("%T.FLOAT_ARRAY", ProtoAdapter::class)
ProtoType.DOUBLE -> CodeBlock.of("%T.DOUBLE_ARRAY", ProtoAdapter::class)
else -> throw IllegalArgumentException("No Array adapter for $type")
}

/** Returns the full name of the class generated for [type]. */
fun generatedTypeName(type: Type) = type.typeName as ClassName
Expand Down Expand Up @@ -659,7 +704,11 @@ class KotlinGenerator private constructor(
when (fieldOrOneOf) {
is Field -> {
val fieldName = localNameAllocator[fieldOrOneOf]
addStatement("if (%1L != %2N.%1L) return·false", fieldName, otherName)
if (fieldOrOneOf.useArray) {
addStatement("if (!%1L.contentEquals(%2N.%1L)) return·false", fieldName, otherName)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

} else {
addStatement("if (%1L != %2N.%1L) return·false", fieldName, otherName)
}
}
is OneOf -> {
val fieldName = localNameAllocator[fieldOrOneOf]
Expand Down Expand Up @@ -713,7 +762,9 @@ class KotlinGenerator private constructor(
is Field -> {
val fieldName = localNameAllocator[fieldOrOneOf]
add("%1N = %1N * 37 + ", resultName)
if (fieldOrOneOf.isRepeated || fieldOrOneOf.isRequired || fieldOrOneOf.isMap || !fieldOrOneOf.acceptsNull) {
if (fieldOrOneOf.useArray) {
addStatement("%L.contentHashCode()", fieldName)
} else if (fieldOrOneOf.isRepeated || fieldOrOneOf.isRequired || fieldOrOneOf.isMap || !fieldOrOneOf.acceptsNull) {
addStatement("%L.hashCode()", fieldName)
} else {
addStatement("(%L?.hashCode() ?: 0)", fieldName)
Expand Down Expand Up @@ -940,7 +991,7 @@ class KotlinGenerator private constructor(
.build()
)
}
if (field.isRepeated) {
if (field.isRepeated && !field.useArray) {
val checkElementsNotNull = MemberName("com.squareup.wire.internal", "checkElementsNotNull")
funBuilder.addStatement("%M(%L)", checkElementsNotNull, fieldName)
}
Expand Down Expand Up @@ -1075,6 +1126,9 @@ class KotlinGenerator private constructor(
fieldName
)
}
field.isPacked && field.isScalar && field.useArray -> {
CodeBlock.of(fieldName)
}
field.isRepeated || field.isMap -> {
CodeBlock.of(
"%M(%S, %N)",
Expand Down Expand Up @@ -1216,10 +1270,13 @@ class KotlinGenerator private constructor(
when (fieldOrOneOf) {
is Field -> {
val fieldName = localNameAllocator[fieldOrOneOf]
if (fieldOrOneOf.isRepeated || fieldOrOneOf.isMap) {
add("if (%N.isNotEmpty()) ", fieldName)
} else if (fieldOrOneOf.acceptsNull) {
add("if (%N != null) ", fieldName)
when {
fieldOrOneOf.isRepeated || fieldOrOneOf.isMap -> {
add("if (%N.isNotEmpty()) ", fieldName)
}
fieldOrOneOf.acceptsNull -> {
add("if (%N != null) ", fieldName)
}
}
addStatement(
"%N += %P", resultName,
Expand All @@ -1230,6 +1287,11 @@ class KotlinGenerator private constructor(
} else {
if (fieldOrOneOf.type == ProtoType.STRING) {
add("=\${%M($fieldName)}", sanitizeMember)
} else if (fieldOrOneOf.useArray) {
add("=\${")
add(fieldName)
add(".contentToString()")
add("}")
} else {
add("=\$")
add(fieldName)
Expand Down Expand Up @@ -1517,11 +1579,15 @@ class KotlinGenerator private constructor(
}

private fun adapterFor(field: Field) = buildCodeBlock {
add("%L", field.getAdapterName())
if (field.isPacked) {
add(".asPacked()")
} else if (field.isRepeated) {
add(".asRepeated()")
if (field.useArray) {
add(field.arrayAdapterForType)
} else {
add("%L", field.getAdapterName())
if (field.isPacked) {
add(".asPacked()")
} else if (field.isRepeated) {
add(".asRepeated()")
}
}
}

Expand All @@ -1536,12 +1602,23 @@ class KotlinGenerator private constructor(
if (field.encodeMode == EncodeMode.OMIT_IDENTITY) {
add(fieldEqualsIdentityBlock(field, fieldName))
}
addStatement(
"%L.encodeWithTag(writer, %L, value.%L)",
adapterFor(field),
field.tag,
fieldName
)
if (field.useArray && reverse) {
val encodeArray = MemberName("com.squareup.wire.internal", "encodeArray")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏻

addStatement(
"%M(value.%L, %L, writer, %L)",
encodeArray,
fieldName,
field.getAdapterName(),
field.tag,
)
} else {
addStatement(
"%L.encodeWithTag(writer, %L, value.%L)",
adapterFor(field),
field.tag,
fieldName
)
}
}
}
for (boxOneOf in message.boxOneOfs()) {
Expand Down Expand Up @@ -1620,7 +1697,11 @@ class KotlinGenerator private constructor(
}

if (fieldOrOneOf.isPacked && fieldOrOneOf.isScalar) {
addStatement("%1L = %1L ?: listOf(),", fieldName)
if (fieldOrOneOf.useArray) {
addStatement("%1L = %1L?.toArray() ?: %2L,", fieldName, fieldOrOneOf.emptyPrimitiveArrayForType)
} else {
addStatement("%1L = %1L ?: listOf(),", fieldName)
}
} else {
addStatement("%1L = %1L%2L,", fieldName, throwExceptionBlock)
}
Expand Down Expand Up @@ -1705,20 +1786,38 @@ class KotlinGenerator private constructor(
}

private fun decodeAndAssign(field: Field, fieldName: String, adapterName: CodeBlock): CodeBlock {
val decode = CodeBlock.of("%L.decode(reader)", adapterName)
val decode = CodeBlock.of(
"%L.%L(reader)",
adapterName,
if (field.useArray) "decodePrimitive" else "decode",
)
return CodeBlock.of(
when {
field.isPacked && field.isScalar ->
field.useArray -> {
buildCodeBlock {
beginControlFlow("if (%L == null)", fieldName)
addStatement(
"%L = %L.forDecoding(reader.nextFieldMinLengthInBytes(), %L)",
fieldName,
field.arrayListClassForType.simpleName,
field.getMinimumByteSize(),
)
endControlFlow()
addStatement("%1L!!.add(%2L)", field, decode)
}.toString()
}
field.isPacked && field.isScalar -> {
buildCodeBlock {
beginControlFlow("if (%L == null)", fieldName)
addStatement("val minimumByteSize = ${field.getMinimumByteSize()}")
addStatement("val initialCapacity = (reader.nextFieldMinLengthInBytes() / minimumByteSize)")
addStatement("⇥.coerceAtMost(Int.MAX_VALUE.toLong())")
addStatement(".toInt()")
addStatement("⇤%L = ArrayList(initialCapacity)", fieldName)
addStatement("⇤%L = %L(initialCapacity)", fieldName, ArrayList::class.simpleName)
endControlFlow()
addStatement("%1L!!.add(%2L)", field, decode)
}.toString()
}
field.isRepeated -> "%L.add(%L)"
field.isMap -> "%L.putAll(%L)"
else -> "%L·= %L"
Expand Down Expand Up @@ -1796,6 +1895,7 @@ class KotlinGenerator private constructor(
private fun Field.redact(fieldName: String): CodeBlock? {
if (isRedacted) {
return when {
useArray -> emptyPrimitiveArrayForType
isRepeated -> CodeBlock.of("emptyList()")
isMap -> CodeBlock.of("emptyMap()")
encodeMode!! == EncodeMode.NULL_IF_ABSENT -> CodeBlock.of("null")
Expand Down Expand Up @@ -2186,6 +2286,7 @@ class KotlinGenerator private constructor(
}

private fun Field.getDeclaration(allocatedName: String) = when {
useArray -> CodeBlock.of("var %N: %T? = null", allocatedName, arrayListClassForType)
isPacked && isScalar -> CodeBlock.of("var %N: MutableList<%T>? = null", allocatedName, type!!.typeName)
isRepeated -> CodeBlock.of("val $allocatedName = mutableListOf<%T>()", type!!.typeName)
isMap -> CodeBlock.of(
Expand Down Expand Up @@ -2216,8 +2317,14 @@ class KotlinGenerator private constructor(
val type = type!!
val baseClass = type.typeName
return when (encodeMode!!) {
EncodeMode.REPEATED,
EncodeMode.PACKED -> List::class.asClassName().parameterizedBy(baseClass)
EncodeMode.REPEATED -> List::class.asClassName().parameterizedBy(baseClass)
EncodeMode.PACKED -> {
if (useArray) {
primitiveArrayClassForType.asTypeName()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏻

} else {
List::class.asTypeName().parameterizedBy(baseClass)
}
}
EncodeMode.MAP -> baseClass.copy(nullable = false)
EncodeMode.NULL_IF_ABSENT -> baseClass.copy(nullable = true)
else -> {
Expand All @@ -2235,8 +2342,13 @@ class KotlinGenerator private constructor(
return when (encodeMode!!) {
EncodeMode.MAP ->
Map::class.asTypeName().parameterizedBy(keyType.typeName, valueType.typeName)
EncodeMode.REPEATED,
EncodeMode.PACKED -> List::class.asClassName().parameterizedBy(type.typeName)
EncodeMode.PACKED -> {
when {
useArray -> primitiveArrayClassForType.asTypeName()
Comment thread
JGulbronson marked this conversation as resolved.
else -> List::class.asTypeName().parameterizedBy(type.typeName)
}
}
EncodeMode.REPEATED -> List::class.asClassName().parameterizedBy(type.typeName)
EncodeMode.NULL_IF_ABSENT -> type.typeName.copy(nullable = true)
EncodeMode.REQUIRED -> type.typeName
EncodeMode.OMIT_IDENTITY -> {
Expand All @@ -2254,8 +2366,14 @@ class KotlinGenerator private constructor(
get() {
return when (encodeMode!!) {
EncodeMode.MAP -> CodeBlock.of("emptyMap()")
EncodeMode.REPEATED,
EncodeMode.PACKED -> CodeBlock.of("emptyList()")
EncodeMode.REPEATED -> CodeBlock.of("emptyList()")
EncodeMode.PACKED -> {
if (useArray) {
emptyPrimitiveArrayForType
} else {
CodeBlock.of("emptyList()")
}
}
EncodeMode.NULL_IF_ABSENT -> CodeBlock.of("null")
EncodeMode.OMIT_IDENTITY -> {
val protoType = type!!
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1513,6 +1513,25 @@ class KotlinGeneratorTest {
assertTrue(code.contains("import wire_package.Person"))
}

@Test fun useArrayUsesTheCorrectType() {
val schema = buildSchema {
add(
"proto_package/person.proto".toPath(),
"""
|package proto_package;
|import "wire/extensions.proto";
|
|message Person {
| repeated float info = 1 [packed = true, wire.use_array = true];
|}
|""".trimMargin()
)
}
val code = KotlinWithProfilesGenerator(schema).generateKotlin("proto_package.Person")
assertContains(code, "public val info: FloatArray = floatArrayOf()")
assertContains(code, "ProtoAdapter.FLOAT_ARRAY.encodeWithTag(writer, 1, value.info)")
}

@Test fun documentationEscapesBrackets() {
val schema = buildSchema {
add(
Expand Down
Loading