Kotlinx and custom serializer #1226

Closed
ValdasK opened this issue Jul 28, 2021 · 10 comments
Labels
question Further information is requested

Comments

@ValdasK

ValdasK commented Jul 28, 2021

Describe the bug
It seems that kotlinx custom serializers are ignored, even when classes are annotated with @kotlinx.serialization.Serializable.

To Reproduce

package com.example.demokotlin

import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RestController


@Serializable(with = LocationAsStringSerialiser::class)
data class Location(val id: String)

@Serializable
data class LocationResponse(val location: Location)

object LocationAsStringSerialiser : KSerializer<Location> {
	override val descriptor: SerialDescriptor = PrimitiveSerialDescriptor("Location", PrimitiveKind.STRING)

	override fun serialize(encoder: Encoder, value: Location) = encoder.encodeString(value.id)
	override fun deserialize(decoder: Decoder) = Location(decoder.decodeString())
}

@RestController
@RequestMapping("/test")
class HelloController {
	@GetMapping
	suspend fun wrong(): LocationResponse {
		return LocationResponse(Location(id = "test"))
	}

}

Steps to reproduce the behavior:

  • Which version of spring-boot are you using?
    <artifactId>spring-boot-starter-web</artifactId>
    <version>2.5.1</version>

    <artifactId>springdoc-openapi-ui</artifactId>
    <version>1.5.10</version>

    <artifactId>springdoc-openapi-kotlin</artifactId>
    <version>1.5.10</version>

Expected behavior

The location property should be documented as

   "location": "string"

instead of

   "location": {
     "id": "string"
   }
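The expected form is what kotlinx-serialization itself produces for this serializer. A minimal sketch, assuming only the kotlinx-serialization-json runtime on the classpath: Location is redeclared here without @Serializable so that no compiler plugin is needed, and the serializer is passed to Json explicitly.

```kotlin
import kotlinx.serialization.KSerializer
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.Json

// Same shape as the issue's Location; left unannotated so the sketch
// does not depend on the kotlinx-serialization compiler plugin
data class Location(val id: String)

object LocationAsStringSerialiser : KSerializer<Location> {
    override val descriptor: SerialDescriptor =
        PrimitiveSerialDescriptor("Location", PrimitiveKind.STRING)

    override fun serialize(encoder: Encoder, value: Location) = encoder.encodeString(value.id)
    override fun deserialize(decoder: Decoder) = Location(decoder.decodeString())
}

fun main() {
    // The custom serializer writes the whole object as one JSON string
    println(Json.encodeToString(LocationAsStringSerialiser, Location("test"))) // "test"
    // Its descriptor advertises a primitive STRING, not a class with an "id" element
    println(LocationAsStringSerialiser.descriptor.kind == PrimitiveKind.STRING) // true
}
```

So a kotlinx-aware schema generator would have to read the descriptor kind rather than the Kotlin class shape to produce "location": "string".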
@bnasslahsen
Contributor

@ValdasK,

Please make sure you read the section Using GitHub Issues.

For any future issue, make sure you:
Provide a Minimal, Reproducible Example (with a HelloController that reproduces the problem).

This issue is not reproducible. You can see the attachment for more details.

@ValdasK
Author

ValdasK commented Jul 28, 2021

Thank you for taking the time to create the attachment. I've modified it to include a bit more code (the controller is updated too):

@Serializable(with = LocationAsStringSerialiser::class)
data class Location(val id: String)

@Serializable
data class LocationResponse(val location: Location)

object LocationAsStringSerialiser : KSerializer<Location> {
	override val descriptor: SerialDescriptor = PrimitiveSerialDescriptor("Location", PrimitiveKind.STRING)

	override fun serialize(encoder: Encoder, value: Location) = encoder.encodeString(value.id)
	override fun deserialize(decoder: Decoder) = Location(decoder.decodeString())
}

@RestController
@RequestMapping("/test")
class HelloController {
	@GetMapping
	suspend fun wrong(): LocationResponse {
		return LocationResponse(Location(id = "test"))
	}

}

In the UI we can see that the actual server response and the documented response do not match:

[Screenshot]

@bnasslahsen
Contributor

Not reproducible.
As mentioned earlier, make sure you provide a Minimal, Reproducible Example (with a HelloController that reproduces the problem).

@ValdasK
Author

ValdasK commented Jul 29, 2021

I'm not sure which part is not reproducible. I've updated the original message with the latest code and am uploading an attachment of the modified project too:
demo-kotlin-updated.zip

@ValdasK
Author

ValdasK commented Jul 30, 2021

@bnasslahsen, would it be possible to reopen the issue, now that you have all the data necessary to reproduce it?

@bnasslahsen
Contributor

@ValdasK,

Not reproducible.

This is the result of your attached sample:

[Screenshot]

@ValdasK
Author

ValdasK commented Aug 2, 2021

Alright, it turns out that the kotlinx-serialization compiler plugin was missing from the demo, so kotlinx serialization was not applied at all.

Uploading an updated version (same code, but with the plugin enabled):
demo-kotlin-with-plugin.zip
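The symptom of the missing plugin can be shown in a few lines. A minimal sketch, assuming kotlinx-serialization-core on the classpath and a hypothetical PlainLocation class that stands in for a class whose serializer was never generated: at runtime the library looks for the serializer the plugin would have produced, and when there is none, serializerOrNull returns null, so nothing kotlinx-specific can be applied.

```kotlin
import kotlinx.serialization.serializerOrNull

// Hypothetical class with no generated serializer (no plugin output for it)
data class PlainLocation(val id: String)

// True when kotlinx-serialization can find a serializer for the given Java class
fun hasKotlinxSerializer(cls: Class<*>): Boolean = serializerOrNull(cls) != null

fun main() {
    println(hasKotlinxSerializer(String::class.java))        // true: built-in serializer
    println(hasKotlinxSerializer(PlainLocation::class.java)) // false: nothing was generated
}
```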

@bnasslahsen bnasslahsen reopened this Aug 2, 2021
@bnasslahsen
Contributor

@ValdasK,

springdoc-openapi is based on swagger-core, which relies on Jackson for type introspection.
You can see the following class for more details.
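For comparison, this is how Jackson sees the classes from the issue. A minimal sketch, assuming jackson-databind on the classpath; Jackson's serialized output is used only as a stand-in for the introspection swagger-core performs, and the kotlinx annotations are left out because Jackson does not read them anyway.

```kotlin
import com.fasterxml.jackson.databind.ObjectMapper

// Same shapes as in the issue, without kotlinx annotations
data class Location(val id: String)
data class LocationResponse(val location: Location)

fun main() {
    // Jackson discovers properties through getters (getLocation(), getId()),
    // so Location appears as an object with an "id" field
    println(ObjectMapper().writeValueAsString(LocationResponse(Location("test"))))
    // {"location":{"id":"test"}}
}
```

This nested shape is exactly the "location": { "id": "string" } schema reported in the issue, even though kotlinx writes a plain string at runtime.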

As a workaround, you have to provide your own ModelConverter as a Spring bean.

kotlinx-serialization doesn't seem to provide type-introspection capabilities comparable to those available in Jackson.

Here is some sample code showing how kotlinx-serialization could be supported for the simple case in your sample.

Please note this is only adapted for simple POJOs with String properties. You will have to adapt it for more complex types:

import io.swagger.v3.core.converter.AnnotatedType
import io.swagger.v3.core.converter.ModelConverter
import io.swagger.v3.core.converter.ModelConverterContext
import io.swagger.v3.core.util.Json
import io.swagger.v3.oas.models.media.ObjectSchema
import io.swagger.v3.oas.models.media.Schema
import io.swagger.v3.oas.models.media.StringSchema
import kotlinx.serialization.KSerializer
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.serializerOrNull
import org.springframework.stereotype.Component


/**
 * The converter to support Kotlinx Serialization
 * @author bnasslahsen
 */
@Component
class KotlinxSerializationTypeConverter : ModelConverter {
	override fun resolve(
		type: AnnotatedType,
		context: ModelConverterContext,
		chain: Iterator<ModelConverter>
	): Schema<*>? {
		val javaType = Json.mapper().constructType(type.type)
		val cls = javaType.rawClass
		// kotlinx lookup of the generated or custom serializer; null when there is none
		val serializer: KSerializer<Any>? = serializerOrNull(cls)
		if (serializer != null) {
			val serialDescriptor: SerialDescriptor = serializer.descriptor
			if (serialDescriptor.kind == StructureKind.CLASS) {
				val schema = ObjectSchema()
				// Walk the elements through the public descriptor API instead of
				// reading the descriptor's private "indices" field by reflection
				for (i in 0 until serialDescriptor.elementsCount) {
					val propsDescriptor = serialDescriptor.getElementDescriptor(i)
					// Only string-kind elements are mapped, e.g. Location via LocationAsStringSerialiser
					if (propsDescriptor.kind == PrimitiveKind.STRING) {
						schema.addProperties(serialDescriptor.getElementName(i), StringSchema())
					}
				}
				return schema
			}
		}
		// Not a kotlinx class: let the next converter (Jackson-based) handle it
		return if (chain.hasNext()) chain.next().resolve(type, context, chain) else null
	}
}
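The converter's per-property decision can be tried outside Spring and swagger-core. A sketch assuming kotlinx-serialization-core 1.x: the descriptor is built by hand with buildClassSerialDescriptor as a stand-in for the one the plugin generates for LocationResponse, and elementTypes is a hypothetical helper, not part of either library.

```kotlin
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.SerialKind
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.descriptors.buildClassSerialDescriptor

// Stand-in for LocationAsStringSerialiser.descriptor
val locationDescriptor: SerialDescriptor =
    PrimitiveSerialDescriptor("Location", PrimitiveKind.STRING)

// Stand-in for the plugin-generated descriptor of LocationResponse
val responseDescriptor: SerialDescriptor = buildClassSerialDescriptor("LocationResponse") {
    element("location", locationDescriptor)
}

// Hypothetical helper: map each element name to a JSON-schema-like type name
fun elementTypes(descriptor: SerialDescriptor): Map<String, String> {
    require(descriptor.kind == StructureKind.CLASS) { "expected a class descriptor" }
    return (0 until descriptor.elementsCount).associate { i ->
        val kind: SerialKind = descriptor.getElementDescriptor(i).kind
        descriptor.getElementName(i) to when (kind) {
            PrimitiveKind.STRING -> "string"
            StructureKind.CLASS -> "object"
            else -> "unsupported"
        }
    }
}

fun main() {
    println(elementTypes(responseDescriptor)) // {location=string}
}
```

The element's kind comes from the custom serializer, which is why this approach yields "location": "string" where the Jackson path yields an object.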

@bnasslahsen bnasslahsen added the question Further information is requested label Jan 26, 2022
@werner77

werner77 commented Feb 19, 2022

I think I figured out the complete implementation!

package com.beatgridmedia.measurementkit.swagger;

import io.swagger.v3.core.converter.AnnotatedType;
import io.swagger.v3.core.converter.ModelConverter;
import io.swagger.v3.core.converter.ModelConverterContext;
import io.swagger.v3.oas.models.media.ArraySchema;
import io.swagger.v3.oas.models.media.BooleanSchema;
import io.swagger.v3.oas.models.media.ComposedSchema;
import io.swagger.v3.oas.models.media.Discriminator;
import io.swagger.v3.oas.models.media.IntegerSchema;
import io.swagger.v3.oas.models.media.NumberSchema;
import io.swagger.v3.oas.models.media.ObjectSchema;
import io.swagger.v3.oas.models.media.Schema;
import io.swagger.v3.oas.models.media.StringSchema;
import kotlin.jvm.functions.Function1;
import kotlin.reflect.KClass;
import kotlinx.serialization.DeserializationStrategy;
import kotlinx.serialization.KSerializer;
import kotlinx.serialization.Serializable;
import kotlinx.serialization.SerializersKt;
import kotlinx.serialization.descriptors.PolymorphicKind;
import kotlinx.serialization.descriptors.PrimitiveKind;
import kotlinx.serialization.descriptors.SerialDescriptor;
import kotlinx.serialization.descriptors.SerialKind;
import kotlinx.serialization.descriptors.StructureKind;
import kotlinx.serialization.modules.SerializersModule;
import kotlinx.serialization.modules.SerializersModuleCollector;
import kotlinx.serialization.modules.SerializersModuleKt;
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.regex.Pattern;

import static io.swagger.v3.core.util.RefUtils.constructRef;

@Component
public class KotlinxSerializationTypeConverter implements ModelConverter {

    private static final Pattern polymorphicNamePattern = Pattern.compile("kotlinx\\.serialization\\.(Polymorphic|Sealed)<(.*)>");
    private final SerializersModule module;
    private final Map<String, List<SerialDescriptor>> classHierarchyMap = new HashMap<>();

    public KotlinxSerializationTypeConverter(@Autowired(required = false) SerializersModule module) {
        var effectiveModule = Optional.ofNullable(module).orElse(SerializersModuleKt.getEmptySerializersModule());

        // Inspect the module to collect all the sub classes for polymorphism
        effectiveModule.dumpTo(new SerializersModuleCollector() {
            @Override
            public <T> void contextual(@NotNull KClass<T> kClass, @NotNull KSerializer<T> kSerializer) {
            }

            @Override
            public <T> void contextual(@NotNull KClass<T> kClass, @NotNull Function1<? super List<? extends KSerializer<?>>, ? extends KSerializer<?>> function1) {
            }

            @Override
            public <Base, Sub extends Base> void polymorphic(@NotNull KClass<Base> baseClass, @NotNull KClass<Sub> subClass, @NotNull KSerializer<Sub> subSerializer) {
                var baseSerializer = SerializersKt.serializer(baseClass);
                classHierarchyMap.computeIfAbsent(getName(baseSerializer.getDescriptor()), (k) -> new ArrayList<>()).add(subSerializer.getDescriptor());
            }

            @Override
            public <Base> void polymorphicDefault(@NotNull KClass<Base> kClass, @NotNull Function1<? super String, ? extends DeserializationStrategy<? extends Base>> function1) {

            }
        });

        this.module = effectiveModule;
    }

    @Override
    public Schema<?> resolve(AnnotatedType annotatedType, ModelConverterContext context, Iterator<ModelConverter> iterator) {
        // Check whether the type is annotated with @Serializable
        if (annotatedType.getType() instanceof Class<?> clazz &&
                Arrays.stream(clazz.getAnnotations()).anyMatch(it -> it instanceof Serializable)) {
            KSerializer<?> serializer = SerializersKt.serializer(module, clazz);
            SerialDescriptor serialDescriptor = serializer.getDescriptor();
            return resolveImpl(context, serialDescriptor, null);
        }
        if (iterator.hasNext()) {
            return iterator.next().resolve(annotatedType, context, iterator);
        } else {
            return null;
        }
    }

    @SuppressWarnings({"unchecked", "rawtypes"})
    private Schema<?> resolveImpl(ModelConverterContext context, SerialDescriptor serialDescriptor, @Nullable Schema<?> baseSchema) {
        var kind = serialDescriptor.getKind();
        var resolved = resolveRef(context, serialDescriptor);
        if (resolved != null) {
            return resolved;
        } else if (PrimitiveKind.STRING.INSTANCE.equals(kind) ||
                PrimitiveKind.CHAR.INSTANCE.equals(kind)) {
            // kotlinx JSON writes Char values as one-character strings, not numbers
            return new StringSchema();
        } else if (PrimitiveKind.BOOLEAN.INSTANCE.equals(kind)) {
            return new BooleanSchema();
        } else if (PrimitiveKind.LONG.INSTANCE.equals(kind)) {
            // IntegerSchema defaults to format int32, which is too narrow for Long
            return new IntegerSchema().format("int64");
        } else if (PrimitiveKind.INT.INSTANCE.equals(kind) ||
                PrimitiveKind.SHORT.INSTANCE.equals(kind) ||
                PrimitiveKind.BYTE.INSTANCE.equals(kind)) {
            return new IntegerSchema();
        } else if (PrimitiveKind.FLOAT.INSTANCE.equals(kind) ||
                PrimitiveKind.DOUBLE.INSTANCE.equals(kind)) {
            return new NumberSchema();
        } else if (StructureKind.CLASS.INSTANCE.equals(kind) || StructureKind.OBJECT.INSTANCE.equals(kind)) {
            // Find base schema
            var schema = baseSchema == null ? new ObjectSchema() : new ComposedSchema().addAllOfItem(baseSchema);
            for (int i = 0; i < serialDescriptor.getElementsCount(); ++i) {
                var elementDescriptor = serialDescriptor.getElementDescriptor(i);
                var elementName = serialDescriptor.getElementName(i);
                schema.addProperties(
                        elementName,
                        resolveImpl(context, elementDescriptor, null).nullable(elementDescriptor.isNullable())
                );
                if (!serialDescriptor.isElementOptional(i)) {
                    schema.addRequiredItem(elementName);
                }
            }
            return defineRef(context, serialDescriptor, schema);
        } else if (StructureKind.LIST.INSTANCE.equals(kind)) {
            var schema = new ArraySchema();
            var elementDescriptor = serialDescriptor.getElementDescriptor(0);
            schema.setItems(resolveImpl(context, elementDescriptor, null).nullable(elementDescriptor.isNullable()));
            return schema;
        } else if (StructureKind.MAP.INSTANCE.equals(kind)) {
            if (serialDescriptor.getElementsCount() != 2) {
                throw new IllegalStateException("Expected exactly two elements for Map serial descriptor");
            }
            var schema = new ObjectSchema();
            // Key should always be a string
            if (!PrimitiveKind.STRING.INSTANCE.equals(serialDescriptor.getElementDescriptor(0).getKind())) {
                throw new IllegalStateException("Key type should be string for JSON");
            }
            var valueSchema = resolveImpl(context, serialDescriptor.getElementDescriptor(1), null);
            schema.additionalProperties(valueSchema);
            return schema;
        } else if (SerialKind.CONTEXTUAL.INSTANCE.equals(kind)) {
            throw new IllegalStateException("Contextual mappings are only allowed in the context of polymorphism");
        } else if (SerialKind.ENUM.INSTANCE.equals(kind)) {
            Schema schema = baseSchema == null ? new StringSchema() : new ComposedSchema().addAllOfItem(baseSchema);
            for (int i = 0; i < serialDescriptor.getElementsCount(); ++i) {
                schema.addEnumItemObject(serialDescriptor.getElementName(i));
            }
            return defineRef(context, serialDescriptor, schema);
        } else if (PolymorphicKind.SEALED.INSTANCE.equals(kind) || PolymorphicKind.OPEN.INSTANCE.equals(kind)) {
            if (serialDescriptor.getElementsCount() < 2) {
                throw new IllegalStateException("Expected at least two fields for a polymorphic class descriptor");
            }
            var schema = new ComposedSchema();
            if (baseSchema != null) {
                schema.addAllOfItem(baseSchema);
            }
            var discriminator = new Discriminator().propertyName(serialDescriptor.getElementName(0));
            schema.discriminator(discriminator);
            var refSchema = defineRef(context, serialDescriptor, schema);
            for (int i = 0; i < serialDescriptor.getElementsCount(); ++i) {
                var elementName = serialDescriptor.getElementName(i);
                var elementDescriptor = serialDescriptor.getElementDescriptor(i);
                if (elementDescriptor.getKind().equals(SerialKind.CONTEXTUAL.INSTANCE)) {
                    // Value descriptor
                    var allKnownSubDescriptors = Optional.ofNullable(classHierarchyMap.get(getName(elementDescriptor))).orElse(Collections.emptyList());
                    for (var subDescriptor : allKnownSubDescriptors) {
                        Schema<?> subSchema = resolveImpl(context, subDescriptor, refSchema);
                        discriminator.mapping(getName(subDescriptor), subSchema.get$ref());
                        schema.addAnyOfItem(subSchema);
                    }
                } else {
                    schema.addProperties(
                            elementName,
                            resolveImpl(context, elementDescriptor, null).nullable(elementDescriptor.isNullable())
                    );
                    if (!serialDescriptor.isElementOptional(i)) {
                        schema.addRequiredItem(elementName);
                    }
                }
            }
            return refSchema;
        }
        throw new IllegalStateException("Unsupported serializer kind: " + kind);
    }

    private static Schema<?> resolveRef(ModelConverterContext context, SerialDescriptor serialDescriptor) {
        var name = getName(serialDescriptor);
        if (context.getDefinedModels().containsKey(name)) {
            return new Schema<>().$ref(constructRef(name));
        }
        return null;
    }

    private static Schema<?> defineRef(ModelConverterContext context, SerialDescriptor serialDescriptor, Schema<?> schema) {
        // Register the schema as a top-level model and return a $ref pointing to it
        var name = getName(serialDescriptor);
        context.defineModel(name, schema);
        return new Schema<>().$ref(constructRef(name));
    }

    private static String getName(SerialDescriptor serialDescriptor) {
        var name = serialDescriptor.getSerialName().replace("?", "").trim();
        var matcher = polymorphicNamePattern.matcher(name);
        if (matcher.matches()) {
            name = matcher.group(2);
        }
        return name;
    }
}
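The polymorphic branch above depends on the layout of kotlinx's polymorphic descriptors: element 0 names the discriminator, and the CONTEXTUAL element stands for the concrete subclasses, which the converter looks up in classHierarchyMap. A small sketch of that layout, assuming kotlinx-serialization-core 1.x and a hypothetical Shape interface:

```kotlin
import kotlinx.serialization.PolymorphicSerializer
import kotlinx.serialization.descriptors.PolymorphicKind
import kotlinx.serialization.descriptors.SerialKind

// Hypothetical open base type; PolymorphicSerializer needs no compiler plugin
interface Shape

fun main() {
    val descriptor = PolymorphicSerializer(Shape::class).descriptor
    println(descriptor.kind == PolymorphicKind.OPEN)                          // true
    println(descriptor.getElementName(0))                                     // type
    println(descriptor.getElementName(1))                                     // value
    println(descriptor.getElementDescriptor(1).kind == SerialKind.CONTEXTUAL) // true
}
```

One consequence worth checking: the descriptor reports "type" even when the Json instance is configured with a different classDiscriminator, so the discriminator property name generated from element 0 would not follow that setting.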

@werner77

See PR: #1514


3 participants