diff --git a/Rubberduck.Parsing/Symbols/DeclarationFinder.cs b/Rubberduck.Parsing/Symbols/DeclarationFinder.cs index 4a1aceda89..13c4ecc000 100644 --- a/Rubberduck.Parsing/Symbols/DeclarationFinder.cs +++ b/Rubberduck.Parsing/Symbols/DeclarationFinder.cs @@ -52,8 +52,11 @@ public class DeclarationFinder private readonly Lazy> _handlersByWithEventsField; private readonly Lazy> _membersByImplementsContext; private readonly Lazy> _interfaceMembers; + private Lazy> _nonBaseAsType; + private readonly Lazy> _eventHandlers; + private readonly Lazy> _classes; - private static readonly object ThreadLock = new object(); + private readonly object threadLock = new object(); public DeclarationFinder(IReadOnlyList declarations, IEnumerable annotations, IReadOnlyList unresolvedMemberDeclarations, IHostApplication hostApp = null) { @@ -133,8 +136,16 @@ public DeclarationFinder(IReadOnlyList declarations, IEnumerable>(() => - new ConcurrentDictionary( - implementableMembers.ToDictionary(item => item.Context, item => item.Members)), true); + new ConcurrentDictionary( + implementableMembers.ToDictionary(item => item.Context, item => item.Members)), true); + + _nonBaseAsType = new Lazy>(() => + _declarations.AllValues().Where(d => + !string.IsNullOrWhiteSpace(d.AsTypeName) + && !d.AsTypeIsBaseType + && d.DeclarationType != DeclarationType.Project + && d.DeclarationType != DeclarationType.ProceduralModule).ToList() + ,true); } public IEnumerable FreshUndeclared @@ -152,38 +163,22 @@ public IEnumerable Members(QualifiedModuleName module) return _declarations[module]; } - private IEnumerable _nonBaseAsType; public IEnumerable FindDeclarationsWithNonBaseAsType() { - lock (ThreadLock) - { - return _nonBaseAsType ?? (_nonBaseAsType = _declarations.AllValues().Where(d => - !string.IsNullOrWhiteSpace(d.AsTypeName) - && !d.AsTypeIsBaseType - && d.DeclarationType != DeclarationType.Project - && d.DeclarationType != DeclarationType.ProceduralModule).ToList()); - } - } + return _nonBaseAsType.Value; - private readonly Lazy> _eventHandlers; + } + public IEnumerable FindEventHandlers() { - lock (ThreadLock) - { - return _eventHandlers.Value; - } + return _eventHandlers.Value; } - private readonly Lazy> _classes; - public IEnumerable Classes { get { - lock (ThreadLock) - { - return _classes.Value; - } + return _classes.Value; } } @@ -193,10 +188,7 @@ public IEnumerable Projects { get { - lock (ThreadLock) - { - return _projects.Value; - } + return _projects.Value; } } @@ -214,10 +206,7 @@ public IEnumerable UserDeclarations(DeclarationType type) public IEnumerable FreshUnresolvedMemberDeclarations() { - lock (ThreadLock) - { - return _newUnresolved.ToArray(); - } + return _newUnresolved.ToArray(); //This does not need a lock because enumerators over a ConcurrentBag uses a snapshot. } public IEnumerable UnresolvedMemberDeclarations() @@ -262,19 +251,18 @@ public Declaration FindParameter(Declaration procedure, string parameterName) public IEnumerable FindMemberMatches(Declaration parent, string memberName) { ConcurrentBag children; - if (_declarations.TryGetValue(parent.QualifiedName.QualifiedModuleName, out children)) - { - return children.Where(item => item.DeclarationType.HasFlag(DeclarationType.Member) - && item.IdentifierName == memberName).ToList(); - } - - return Enumerable.Empty(); + return _declarations.TryGetValue(parent.QualifiedName.QualifiedModuleName, out children) + ? children.Where(item => item.DeclarationType.HasFlag(DeclarationType.Member) + && item.IdentifierName == memberName).ToList() + : Enumerable.Empty(); } public IEnumerable FindAnnotations(QualifiedModuleName module) { ConcurrentBag result; - return _annotations.TryGetValue(module, out result) ? result : Enumerable.Empty(); + return _annotations.TryGetValue(module, out result) + ? result + : Enumerable.Empty(); } public bool IsMatch(string declarationName, string potentialMatchName) @@ -329,7 +317,8 @@ public Declaration FindProject(string name, Declaration currentScope = null) Declaration result = null; try { - result = MatchName(name).SingleOrDefault(project => project.DeclarationType.HasFlag(DeclarationType.Project) + result = MatchName(name).SingleOrDefault(project => + project.DeclarationType.HasFlag(DeclarationType.Project) && (currentScope == null || project.ProjectId == currentScope.ProjectId)); } catch (InvalidOperationException exception) @@ -340,7 +329,7 @@ public Declaration FindProject(string name, Declaration currentScope = null) return result; } - public Declaration FindStdModule(string name, Declaration parent = null, bool includeBuiltIn = false) + public Declaration FindStdModule(string name, Declaration parent, bool includeBuiltIn = false) { Debug.Assert(parent != null); Declaration result = null; @@ -359,7 +348,7 @@ public Declaration FindStdModule(string name, Declaration parent = null, bool in return result; } - public Declaration FindClassModule(string name, Declaration parent = null, bool includeBuiltIn = false) + public Declaration FindClassModule(string name, Declaration parent, bool includeBuiltIn = false) { Debug.Assert(parent != null); Declaration result = null; diff --git a/Rubberduck.Parsing/VBA/RubberduckParserState.cs b/Rubberduck.Parsing/VBA/RubberduckParserState.cs index 1122dfa946..cdab989ce4 100644 --- a/Rubberduck.Parsing/VBA/RubberduckParserState.cs +++ b/Rubberduck.Parsing/VBA/RubberduckParserState.cs @@ -1207,12 +1207,13 @@ public HashSet ModulesReferencedBy(QualifiedModuleName refe public HashSet ModulesReferencedBy(IEnumerable referencingModules) { - var referencedModules = new HashSet(); + var toModules = new HashSet(); + foreach (var referencingModule in referencingModules) { - referencedModules.UnionWith(ModulesReferencedBy(referencingModule)); + toModules.UnionWith(ModulesReferencedBy(referencingModule)); } - return referencedModules; + return toModules; } public HashSet ModulesReferencing(QualifiedModuleName referencedModule) diff --git a/RubberduckTests/RubberduckTests.csproj b/RubberduckTests/RubberduckTests.csproj index 8f6c1a0bf7..5f8cda7ae3 100644 --- a/RubberduckTests/RubberduckTests.csproj +++ b/RubberduckTests/RubberduckTests.csproj @@ -203,6 +203,7 @@ + diff --git a/RubberduckTests/Symbols/DeclarationFinderTests.cs b/RubberduckTests/Symbols/DeclarationFinderTests.cs new file mode 100644 index 0000000000..3e58770e48 --- /dev/null +++ b/RubberduckTests/Symbols/DeclarationFinderTests.cs @@ -0,0 +1,71 @@ +using System; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Rubberduck.Parsing.Symbols; +using Rubberduck.Parsing.VBA; +using Rubberduck.Parsing.Grammar; +using Rubberduck.VBEditor; +using System.Collections.Generic; +using System.Linq; +using Antlr4.Runtime; + +namespace RubberduckTests.Symbols +{ + [TestClass] + public class DeclarationFinderTests + { + [TestCategory("Resolver")] + [TestMethod] + public void DeclarationFinderCanCopeWithMultipleModulesImplementingTheSameInterface() + { + var project = GetTestProject("testProject"); + var interf = GetTestClassModule(project, "interface"); + var member = GetTestFunction(interf, "testMember", Accessibility.Public); + var implementingClass1 = GetTestClassModule(project, "implementingClass1"); + var implementingClass2 = GetTestClassModule(project, "implementingClass2"); + var implementsContext1 = new VBAParser.ImplementsStmtContext(null, 0); + var implementsContext2 = new VBAParser.ImplementsStmtContext(null, 0); + AddReference(interf, implementingClass1, implementsContext1); + AddReference(interf, implementingClass1, implementsContext2); + var declarations = new List {interf, member, implementingClass1, implementingClass2}; + + DeclarationFinder finder = new DeclarationFinder(declarations, new List(), new List()); + var interfaceDeclarations = finder.FindAllInterfaceMembers().ToList(); + + Assert.AreEqual(1, interfaceDeclarations.Count()); + } + + private static ClassModuleDeclaration GetTestClassModule(Declaration projectDeclatation, string name, bool isExposed = false) + { + var qualifiedClassModuleMemberName = new QualifiedMemberName(StubQualifiedModuleName(name), name); + var classModuleAttributes = new Rubberduck.Parsing.VBA.Attributes(); + if (isExposed) + { + classModuleAttributes.AddExposedClassAttribute(); + } + return new ClassModuleDeclaration(qualifiedClassModuleMemberName, projectDeclatation, name, false, null, classModuleAttributes); + } + + private static ProjectDeclaration GetTestProject(string name) + { + var qualifiedProjectName = new QualifiedMemberName(StubQualifiedModuleName("proj"), name); + return new ProjectDeclaration(qualifiedProjectName, name, false, null); + } + + private static QualifiedModuleName StubQualifiedModuleName(string name) + { + return new QualifiedModuleName("dummy", "dummy", name); + } + + private static FunctionDeclaration GetTestFunction(Declaration moduleDeclatation, string name, Accessibility functionAccessibility) + { + var qualifiedFunctionMemberName = new QualifiedMemberName(moduleDeclatation.QualifiedName.QualifiedModuleName, name); + return new FunctionDeclaration(qualifiedFunctionMemberName, moduleDeclatation, moduleDeclatation, "test", null, "test", functionAccessibility, null, Selection.Home, false, false, null, null); + } + + private static void AddReference(Declaration toDeclaration, Declaration fromModuleDeclaration, ParserRuleContext context = null) + { + toDeclaration.AddReference(toDeclaration.QualifiedName.QualifiedModuleName, fromModuleDeclaration, fromModuleDeclaration, context, toDeclaration.IdentifierName, toDeclaration, Selection.Home, new List()); + } + + } +}