Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use private static fields to store constant typeParameters where possible #1606

Merged
merged 1 commit into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 91 additions & 18 deletions InterfaceStubGenerator.Shared/InterfaceStubGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ public class InterfaceStubGeneratorV2 : IIncrementalGenerator
public class InterfaceStubGenerator : ISourceGenerator
#endif
{
private const string TypeParameterVariableName = "_typeParameters";

#pragma warning disable RS2008 // Enable analyzer release tracking
static readonly DiagnosticDescriptor InvalidRefitMember =
new(
Expand Down Expand Up @@ -396,15 +398,17 @@ partial class {ns}{classDeclaration}
.Cast<IMethodSymbol>()
.ToList();

var memberNames = new HashSet<string>(interfaceSymbol.GetMembers().Select(x => x.Name));

// Handle Refit Methods
foreach (var method in refitMethods)
{
ProcessRefitMethod(source, method, true);
ProcessRefitMethod(source, method, true, memberNames);
}

foreach (var method in refitMethods.Concat(derivedRefitMethods))
{
ProcessRefitMethod(source, method, false);
ProcessRefitMethod(source, method, false, memberNames);
}

// Handle non-refit Methods that aren't static or properties or have a method body
Expand Down Expand Up @@ -445,8 +449,20 @@ partial class {ns}{classDeclaration}
/// <param name="source"></param>
/// <param name="methodSymbol"></param>
/// <param name="isTopLevel">True if directly from the type we're generating for, false for methods found on base interfaces</param>
void ProcessRefitMethod(StringBuilder source, IMethodSymbol methodSymbol, bool isTopLevel)
/// <param name="memberNames">Contains the unique member names in the interface scope.</param>
void ProcessRefitMethod(
StringBuilder source,
IMethodSymbol methodSymbol,
bool isTopLevel,
HashSet<string> memberNames
)
{
var parameterTypesExpression = GenerateTypeParameterExpression(
source,
methodSymbol,
memberNames
);

var returnType = methodSymbol.ReturnType.ToDisplayString(
SymbolDisplayFormat.FullyQualifiedFormat
);
Expand All @@ -466,15 +482,6 @@ void ProcessRefitMethod(StringBuilder source, IMethodSymbol methodSymbol, bool i
argList.Add($"@{param.MetadataName}");
}

// List of types.
var typeList = new List<string>();
foreach (var param in methodSymbol.Parameters)
{
typeList.Add(
$"typeof({param.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})"
);
}

// List of generic arguments
var genericList = new List<string>();
foreach (var typeParam in methodSymbol.TypeParameters)
Expand All @@ -489,11 +496,6 @@ void ProcessRefitMethod(StringBuilder source, IMethodSymbol methodSymbol, bool i
? "global::System.Array.Empty<object>()"
: $"new object[] {{ {string.Join(", ", argList)} }}";

var parameterTypesArrayString =
typeList.Count == 0
? "global::System.Array.Empty<global::System.Type>()"
: $"new global::System.Type[] {{ {string.Join(", ", typeList)} }}";

var genericString =
genericList.Count > 0
? $", new global::System.Type[] {{ {string.Join(", ", genericList)} }}"
Expand All @@ -502,7 +504,7 @@ void ProcessRefitMethod(StringBuilder source, IMethodSymbol methodSymbol, bool i
source.Append(
@$"
var ______arguments = {argumentsArrayString};
var ______func = requestBuilder.BuildRestResultFuncForMethod(""{methodSymbol.Name}"", {parameterTypesArrayString}{genericString} );
var ______func = requestBuilder.BuildRestResultFuncForMethod(""{methodSymbol.Name}"", {parameterTypesExpression}{genericString} );
try
{{
{@return}({returnType})______func(this.Client, ______arguments){configureAwait};
Expand Down Expand Up @@ -628,6 +630,63 @@ IMethodSymbol methodSymbol
}
}

static string GenerateTypeParameterExpression(
StringBuilder source,
IMethodSymbol methodSymbol,
HashSet<string> memberNames
)
{
// use Array.Empty if method has no parameters.
if (methodSymbol.Parameters.Length == 0)
return "global::System.Array.Empty<global::System.Type>()";

// if one of the parameters is/contains a type parameter then it cannot be cached as it will change type between calls.
if (methodSymbol.Parameters.Any(x => ContainsTypeParameter(x.Type)))
{
var typeEnumerable = methodSymbol.Parameters.Select(
param =>
$"typeof({param.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})"
);
return $"new global::System.Type[] {{ {string.Join(", ", typeEnumerable)} }}";
}

// find a name and generate field declaration.
var typeParameterFieldName = UniqueName(TypeParameterVariableName, memberNames);
var types = string.Join(
", ",
methodSymbol.Parameters.Select(
x =>
$"typeof({x.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)})"
)
);
source.Append(
$$"""


private static readonly global::System.Type[] {{typeParameterFieldName}} = new global::System.Type[] {{{types}} };
"""
);

return typeParameterFieldName;

static bool ContainsTypeParameter(ITypeSymbol symbol)
{
if (symbol is ITypeParameterSymbol)
return true;

if (symbol is not INamedTypeSymbol { TypeParameters.Length: > 0 } namedType)
return false;

foreach (var typeArg in namedType.TypeArguments)
{
if (ContainsTypeParameter(typeArg))
return true;
}

return false;
}
}

void WriteMethodOpening(
StringBuilder source,
IMethodSymbol methodSymbol,
Expand Down Expand Up @@ -680,6 +739,20 @@ IMethodSymbol methodSymbol

void WriteMethodClosing(StringBuilder source) => source.Append(@" }");

static string UniqueName(string name, HashSet<string> methodNames)
{
var candidateName = name;
var counter = 0;
while (methodNames.Contains(candidateName))
{
candidateName = $"{name}{counter}";
counter++;
}

methodNames.Add(candidateName);
return candidateName;
}

bool IsRefitMethod(IMethodSymbol? methodSymbol, INamedTypeSymbol httpMethodAttibute)
{
return methodSymbol
Expand Down
Loading
Loading