@@ -5,13 +5,12 @@ namespace CodeGeneration.Roslyn
5
5
{
6
6
using System ;
7
7
using System . Collections . Generic ;
8
- using System . IO ;
8
+ using System . Collections . Immutable ;
9
9
using System . Linq ;
10
10
using System . Reflection ;
11
11
using System . Text ;
12
12
using System . Threading ;
13
13
using System . Threading . Tasks ;
14
- using Microsoft ;
15
14
using Microsoft . CodeAnalysis ;
16
15
using Microsoft . CodeAnalysis . CSharp ;
17
16
using Microsoft . CodeAnalysis . CSharp . Syntax ;
@@ -40,7 +39,7 @@ public static class DocumentTransform
40
39
/// <param name="assemblyLoader">A function that can load an assembly with the given name.</param>
41
40
/// <param name="progress">Reports warnings and errors in code generation.</param>
42
41
/// <returns>A task whose result is the generated document.</returns>
43
- public static async Task < SyntaxTree > TransformAsync ( CSharpCompilation compilation , SyntaxTree inputDocument , Func < AssemblyName , Assembly > assemblyLoader , IProgress < Diagnostic > progress )
42
+ public static async Task < SyntaxTree > TransformAsync ( CSharpCompilation compilation , SyntaxTree inputDocument , string projectDirectory , Func < AssemblyName , Assembly > assemblyLoader , IProgress < Diagnostic > progress )
44
43
{
45
44
Requires . NotNull ( compilation , nameof ( compilation ) ) ;
46
45
Requires . NotNull ( inputDocument , nameof ( inputDocument ) ) ;
@@ -51,16 +50,16 @@ public static async Task<SyntaxTree> TransformAsync(CSharpCompilation compilatio
51
50
52
51
var inputFileLevelUsingDirectives = inputSyntaxTree . GetRoot ( ) . ChildNodes ( ) . OfType < UsingDirectiveSyntax > ( ) ;
53
52
54
- var memberNodes = from syntax in inputSyntaxTree . GetRoot ( ) . DescendantNodes ( n => n is CompilationUnitSyntax || n is NamespaceDeclarationSyntax || n is TypeDeclarationSyntax ) . OfType < MemberDeclarationSyntax > ( )
55
- select syntax ;
53
+ var memberNodes = inputSyntaxTree . GetRoot ( ) . DescendantNodesAndSelf ( n => n is CompilationUnitSyntax || n is NamespaceDeclarationSyntax || n is TypeDeclarationSyntax ) . OfType < CSharpSyntaxNode > ( ) ;
56
54
57
55
var emittedMembers = SyntaxFactory . List < MemberDeclarationSyntax > ( ) ;
58
56
foreach ( var memberNode in memberNodes )
59
57
{
60
- var generators = FindCodeGenerators ( inputSemanticModel , memberNode , assemblyLoader ) ;
58
+ var attributeData = GetAttributeData ( compilation , inputSemanticModel , memberNode ) ;
59
+ var generators = FindCodeGenerators ( attributeData , assemblyLoader ) ;
61
60
foreach ( var generator in generators )
62
61
{
63
- var context = new TransformationContext ( memberNode , inputSemanticModel , compilation ) ;
62
+ var context = new TransformationContext ( memberNode , inputSemanticModel , compilation , projectDirectory ) ;
64
63
var generatedTypes = await generator . GenerateAsync ( context , progress , CancellationToken . None ) ;
65
64
66
65
// Figure out ancestry for the generated type, including nesting types and namespaces.
@@ -112,22 +111,29 @@ public static async Task<SyntaxTree> TransformAsync(CSharpCompilation compilatio
112
111
return compilationUnit . SyntaxTree ;
113
112
}
114
113
115
- private static IEnumerable < ICodeGenerator > FindCodeGenerators ( SemanticModel document , SyntaxNode nodeWithAttributesApplied , Func < AssemblyName , Assembly > assemblyLoader )
114
+ private static ImmutableArray < AttributeData > GetAttributeData ( Compilation compilation , SemanticModel document , SyntaxNode syntaxNode )
116
115
{
117
- Requires . NotNull ( document , " document" ) ;
118
- Requires . NotNull ( nodeWithAttributesApplied , "nodeWithAttributesApplied" ) ;
116
+ Requires . NotNull ( document , nameof ( document ) ) ;
117
+ Requires . NotNull ( syntaxNode , nameof ( syntaxNode ) ) ;
119
118
120
- var symbol = document . GetDeclaredSymbol ( nodeWithAttributesApplied ) ;
121
- if ( symbol != null )
119
+ switch ( syntaxNode )
122
120
{
123
- foreach ( var attributeData in symbol . GetAttributes ( ) )
121
+ case CompilationUnitSyntax syntax :
122
+ return compilation . Assembly . GetAttributes ( ) . Where ( x => x . ApplicationSyntaxReference . SyntaxTree == syntax . SyntaxTree ) . ToImmutableArray ( ) ;
123
+ default :
124
+ return document . GetDeclaredSymbol ( syntaxNode ) ? . GetAttributes ( ) ?? ImmutableArray < AttributeData > . Empty ;
125
+ }
126
+ }
127
+
128
+ private static IEnumerable < ICodeGenerator > FindCodeGenerators ( ImmutableArray < AttributeData > nodeAttributes , Func < AssemblyName , Assembly > assemblyLoader )
129
+ {
130
+ foreach ( var attributeData in nodeAttributes )
131
+ {
132
+ Type generatorType = GetCodeGeneratorTypeForAttribute ( attributeData . AttributeClass , assemblyLoader ) ;
133
+ if ( generatorType != null )
124
134
{
125
- Type generatorType = GetCodeGeneratorTypeForAttribute ( attributeData . AttributeClass , assemblyLoader ) ;
126
- if ( generatorType != null )
127
- {
128
- ICodeGenerator generator = ( ICodeGenerator ) Activator . CreateInstance ( generatorType , attributeData ) ;
129
- yield return generator ;
130
- }
135
+ ICodeGenerator generator = ( ICodeGenerator ) Activator . CreateInstance ( generatorType , attributeData ) ;
136
+ yield return generator ;
131
137
}
132
138
}
133
139
}
0 commit comments