Skip to content

Commit

Permalink
REmove contravariance from ISerialize<T> (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
agocke committed Jan 27, 2024
1 parent 4c335b7 commit 4dbb767
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 46 deletions.
95 changes: 52 additions & 43 deletions src/generator/Generator.Serialize.Generic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.InteropServices.ComTypes;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand Down Expand Up @@ -164,7 +165,7 @@ partial record struct {{wrapperName}} : Serde.ISerializeWrap<{{wrappedName}}, {{
Identifier("name"),
argumentList: null,
EqualsValueClause(SwitchExpression(receiverExpr, SeparatedList(cases)))) }))));
var wrapper = TryGetPrimitiveWrapper(enumType.EnumUnderlyingType!, SerdeUsage.Serialize)!;
var wrapper = TryGetPrimitiveWrapper(enumType.EnumUnderlyingType!, SerdeUsage.Serialize).Unwrap().Wrapper;
statements.Add(ExpressionStatement(InvocationExpression(
QualifiedName(IdentifierName("serializer"), IdentifierName("SerializeEnumValue")),
ArgumentList(SeparatedList(new[] {
Expand Down Expand Up @@ -205,8 +206,8 @@ partial record struct {{wrapperName}} : Serde.ISerializeWrap<{{wrappedName}}, {{
{
// Generate statements of the form `type.SerializeField<FieldType, Serialize>("FieldName", receiver.FieldValue)`
var memberExpr = MakeMemberAccessExpr(m, receiverExpr);
var serializeImpl = MakeSerializeType(m, context, memberExpr, inProgress);
if (serializeImpl is null)
var typeAndWrapperOpt = MakeSerializeType(m, context, memberExpr, inProgress);
if (typeAndWrapperOpt is not {} typeAndWrapper)
{
// No built-in handling and doesn't implement ISerialize, error
context.ReportDiagnostic(CreateDiagnostic(
Expand All @@ -218,7 +219,7 @@ partial record struct {{wrapperName}} : Serde.ISerializeWrap<{{wrappedName}}, {{
}
else
{
statements.Add(MakeSerializeFieldStmt(m, memberExpr, serializeImpl, receiverExpr));
statements.Add(MakeSerializeFieldStmt(m, memberExpr, typeAndWrapper, receiverExpr));
}
}

Expand Down Expand Up @@ -260,7 +261,7 @@ partial record struct {{wrapperName}} : Serde.ISerializeWrap<{{wrappedName}}, {{
static ExpressionStatementSyntax MakeSerializeFieldStmt(
DataMemberSymbol member,
ExpressionSyntax value,
TypeSyntax serializeType,
TypeAndWrapper typeAndWrapper,
ExpressionSyntax receiver)
{
var arguments = new List<ExpressionSyntax>() {
Expand All @@ -270,8 +271,8 @@ partial record struct {{wrapperName}} : Serde.ISerializeWrap<{{wrappedName}}, {{
value,
};
var typeArgs = new List<TypeSyntax>() {
member.Type.ToFqnSyntax(),
serializeType
typeAndWrapper.Type,
typeAndWrapper.Wrapper
};

string methodName;
Expand Down Expand Up @@ -341,7 +342,7 @@ partial record struct {{wrapperName}} : Serde.ISerializeWrap<{{wrappedName}}, {{
/// implements ISerialize. SerdeDn provides wrappers for primitives and common types in the
/// framework. If found, we generate and initialize the wrapper.
/// </summary>
private static TypeSyntax? MakeSerializeType(
private static TypeAndWrapper? MakeSerializeType(
DataMemberSymbol member,
GeneratorExecutionContext context,
ExpressionSyntax memberExpr,
Expand All @@ -350,23 +351,18 @@ partial record struct {{wrapperName}} : Serde.ISerializeWrap<{{wrappedName}}, {{
// 1. Check for an explicit wrapper
if (TryGetExplicitWrapper(member, context, SerdeUsage.Serialize, inProgress) is {} wrapper)
{
return wrapper;
return new(member.Type.ToFqnSyntax(), wrapper);
}

// 2. Check for a direct implementation of ISerialize
if (ImplementsSerde(member.Type, context, SerdeUsage.Serialize))
{
return GenericName(Identifier("IdWrap"), TypeArgumentList(SeparatedList(new[] { member.Type.ToFqnSyntax() })));
return new(member.Type.ToFqnSyntax(),
GenericName(Identifier("IdWrap"), TypeArgumentList(SeparatedList(new[] { member.Type.ToFqnSyntax() }))));
}

// 3. A wrapper that implements ISerialize
var wrapperType = TryGetAnyWrapper(member.Type, context, SerdeUsage.Serialize, inProgress);
if (wrapperType is not null)
{
return wrapperType;
}

return null;
return TryGetAnyWrapper(member.Type, context, SerdeUsage.Serialize, inProgress);
}

/// <summary>
Expand Down Expand Up @@ -553,9 +549,9 @@ private static bool ImplementsSerde(ITypeSymbol memberType, GeneratorExecutionCo

// Otherwise we'll need to wrap the element type as well e.g.,
// ArrayWrap<`elemType`, `elemTypeWrapper`>
var wrapper = TryGetAnyWrapper(elemType, context, usage, inProgress);
var typeAndWrapper = TryGetAnyWrapper(elemType, context, usage, inProgress);

if (wrapper is null)
if (typeAndWrapper is not (_, var wrapper))
{
// Could not find a wrapper
return null;
Expand All @@ -575,7 +571,7 @@ private static bool ImplementsSerde(ITypeSymbol memberType, GeneratorExecutionCo
return wrapperSyntax;
}

private static TypeSyntax? TryGetAnyWrapper(
private static TypeAndWrapper? TryGetAnyWrapper(
ITypeSymbol elemType,
GeneratorExecutionContext context,
SerdeUsage usage,
Expand All @@ -592,16 +588,16 @@ private static bool ImplementsSerde(ITypeSymbol memberType, GeneratorExecutionCo
allTypes = parent.Name + allTypes;
}
var wrapperName = $"{allTypes}Wrap";
return IdentifierName(wrapperName);
return new(elemType.ToFqnSyntax(), IdentifierName(wrapperName));
}
var nameSyntax = TryGetPrimitiveWrapper(elemType, usage)
var typeAndWrapper = TryGetPrimitiveWrapper(elemType, usage)
?? TryGetEnumWrapper(elemType, usage)
?? TryGetCompoundWrapper(elemType, context, usage, inProgress);
if (nameSyntax is null)
if (typeAndWrapper is null)
{
return null;
}
return nameSyntax;
return typeAndWrapper;
}


Expand Down Expand Up @@ -654,7 +650,7 @@ namespace Serde
}

// If the target is a core type, we can wrap it
private static TypeSyntax? TryGetPrimitiveWrapper(ITypeSymbol type, SerdeUsage usage)
private static TypeAndWrapper? TryGetPrimitiveWrapper(ITypeSymbol type, SerdeUsage usage)
{
if (type.NullableAnnotation == NullableAnnotation.Annotated)
{
Expand All @@ -678,10 +674,10 @@ namespace Serde
SpecialType.System_Decimal => "DecimalWrap",
_ => null
};
return name is null ? null : IdentifierName(name);
return name is null ? null : new(type.ToFqnSyntax(), IdentifierName(name));
}

private static TypeSyntax? TryGetEnumWrapper(ITypeSymbol type, SerdeUsage usage)
private static TypeAndWrapper? TryGetEnumWrapper(ITypeSymbol type, SerdeUsage usage)
{
if (type.TypeKind is not TypeKind.Enum)
{
Expand All @@ -704,7 +700,7 @@ namespace Serde
? containing + "." + wrapperName
: "global::" + wrapperName;

return SyntaxFactory.ParseTypeName(wrapperFqn);
return new(type.ToFqnSyntax(), ParseTypeName(wrapperFqn));
}

private static bool HasGenerateAttribute(ITypeSymbol memberType, SerdeUsage usage)
Expand Down Expand Up @@ -735,17 +731,18 @@ private static bool HasGenerateAttribute(ITypeSymbol memberType, SerdeUsage usag
return false;
}

private static TypeSyntax? TryGetCompoundWrapper(ITypeSymbol type, GeneratorExecutionContext context, SerdeUsage usage, ImmutableList<ITypeSymbol> inProgress)
private static TypeAndWrapper? TryGetCompoundWrapper(ITypeSymbol type, GeneratorExecutionContext context, SerdeUsage usage, ImmutableList<ITypeSymbol> inProgress)
{
return type switch
(TypeSyntax?, TypeSyntax?)? valueTypeAndWrapper = type switch
{
{ OriginalDefinition.SpecialType: SpecialType.System_Nullable_T } =>
MakeWrappedExpression(
(null,
MakeWrappedExpression(
context.Compilation.GetTypeByMetadataName("Serde.NullableWrap+" + GetImplName(usage) + "`2")!,
ImmutableArray.Create(((INamedTypeSymbol)type).TypeArguments[0]),
context,
usage,
inProgress),
inProgress)),

// This is rather subtle. One might think that we would want to use a
// NullableRefWrapper for any reference type that could contain null. In fact, we
Expand All @@ -759,27 +756,36 @@ private static bool HasGenerateAttribute(ITypeSymbol memberType, SerdeUsage usag
// ISerialize, and therefore the substitution to provide the appropriate nullable
// wrapper.
{ IsReferenceType: true, NullableAnnotation: NullableAnnotation.Annotated} =>
MakeWrappedExpression(
(null,
MakeWrappedExpression(
context.Compilation.GetTypeByMetadataName("Serde.NullableRefWrap+" + GetImplName(usage) + "`2")!,
ImmutableArray.Create(type.WithNullableAnnotation(NullableAnnotation.NotAnnotated)),
context,
usage,
inProgress),
inProgress)),

IArrayTypeSymbol and { IsSZArray: true, Rank: 1, ElementType: { } elemType }
=> MakeWrappedExpression(
=> (null,
MakeWrappedExpression(
context.Compilation.GetTypeByMetadataName("Serde.ArrayWrap+" + GetImplName(usage) + "`2")!,
ImmutableArray.Create(elemType),
context,
usage,
inProgress),
inProgress)),

INamedTypeSymbol t when TryGetWrapperName(t, context, usage) is { } tuple
=> MakeWrappedExpression(
tuple.WrapperType, tuple.Args, context, usage, inProgress),
INamedTypeSymbol t when TryGetWrapperName(t, context, usage) is (var ValueType, (var WrapperType, var Args))
=> (ValueType,
MakeWrappedExpression(
WrapperType, Args, context, usage, inProgress)),

_ => null,
};
return valueTypeAndWrapper switch {
null => null,
(null, {} wrapper) => new(type.ToFqnSyntax(), wrapper),
({ } value, { } wrapper) => new(value, wrapper),
(_, null) => throw ExceptionUtilities.Unreachable
};
}

private static string GetImplName(SerdeUsage usage) => usage switch
Expand All @@ -789,19 +795,20 @@ private static bool HasGenerateAttribute(ITypeSymbol memberType, SerdeUsage usag
_ => throw ExceptionUtilities.Unreachable
};

private static (INamedTypeSymbol WrapperType, ImmutableArray<ITypeSymbol> Args)? TryGetWrapperName(
private static (TypeSyntax MemberType, (INamedTypeSymbol WrapperType, ImmutableArray<ITypeSymbol> Args))? TryGetWrapperName(
ITypeSymbol typeSymbol,
GeneratorExecutionContext context,
SerdeUsage usage)
{
if (typeSymbol.NullableAnnotation == NullableAnnotation.Annotated)
{
var nullableRefWrap = context.Compilation.GetTypeByMetadataName("Serde.NullableRefWrap+" + GetImplName(usage) + "`1")!;
return (nullableRefWrap, ImmutableArray.Create(typeSymbol.WithNullableAnnotation(NullableAnnotation.NotAnnotated)));
return (typeSymbol.ToFqnSyntax(),
(nullableRefWrap, ImmutableArray.Create(typeSymbol.WithNullableAnnotation(NullableAnnotation.NotAnnotated))));
}
if (typeSymbol is INamedTypeSymbol named && TryGetWellKnownType(named, context) is {} wk)
{
return (ToWrapper(wk, context.Compilation, usage), named.TypeArguments);
return (typeSymbol.ToFqnSyntax(), (ToWrapper(wk, context.Compilation, usage), named.TypeArguments));
}

// Check if it implements well-known interfaces
Expand All @@ -813,13 +820,15 @@ private static bool HasGenerateAttribute(ITypeSymbol memberType, SerdeUsage usag
if (impl.OriginalDefinition.Equals(iface, SymbolEqualityComparer.Default) &&
ToWrapper(TryGetWellKnownType(iface, context), context.Compilation, usage) is { } wrap)
{
return (wrap, impl.TypeArguments);
return (impl.ToFqnSyntax(), (wrap, impl.TypeArguments));
}
}
}
return null;
}

private readonly record struct TypeAndWrapper(TypeSyntax Type, TypeSyntax Wrapper);

[return: NotNullIfNotNull(nameof(wk))]
internal static INamedTypeSymbol? ToWrapper(WellKnownType? wk, Compilation comp, SerdeUsage usage)
{
Expand Down
3 changes: 3 additions & 0 deletions src/generator/Utilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Serde
{
Expand Down Expand Up @@ -32,6 +33,8 @@ public static bool IsSorted<T>(this ReadOnlySpan<T> span, IComparer<T> comparer)

internal static class Utilities
{
public static T Unwrap<T>(this T? value) where T : struct => value!.Value;

public static string Concat(this string recv, string other)
{
return recv + other;
Expand Down
2 changes: 1 addition & 1 deletion src/serde/ISerialize.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public interface ISerialize
void Serialize(ISerializer serializer);
}

public interface ISerialize<in T> : ISerialize
public interface ISerialize<T> : ISerialize
{
void Serialize(T value, ISerializer serializer);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ partial class C : Serde.ISerialize<C>
void ISerialize<C>.Serialize(C value, ISerializer serializer)
{
var type = serializer.SerializeType("C", 1);
type.SerializeField<R, Serde.IDictWrap.SerializeImpl<string, StringWrap, int, Int32Wrap>>("rDictionary", value.RDictionary);
type.SerializeField<System.Collections.Generic.IDictionary<string, int>, Serde.IDictWrap.SerializeImpl<string, StringWrap, int, Int32Wrap>>("rDictionary", value.RDictionary);
type.End();
}
}
2 changes: 1 addition & 1 deletion test/Serde.Test/JsonSerializerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ public void NullableString()
string? s = null;
var js = Serde.Json.JsonSerializer.Serialize<string?, NullableRefWrap.SerializeImpl<string, StringWrap>>(s);
Assert.Equal("null", js);
js = Serde.Json.JsonSerializer.Serialize(JsonValue.Null.Instance);
js = Serde.Json.JsonSerializer.Serialize<JsonValue>(JsonValue.Null.Instance);
Assert.Equal("null", js);
}

Expand Down

0 comments on commit 4dbb767

Please sign in to comment.