Skip to content

Commit

Permalink
Merge pull request #3708 from WaynePhillipsEA/rewite-comwrapperenumer…
Browse files Browse the repository at this point in the history
…ation

Rewite of ComWrapperEnumeration, to take control of when IEnumVariant RCW's get released.
  • Loading branch information
retailcoder committed Jan 17, 2018
2 parents 88c6e7c + 4043e68 commit 6c27601
Show file tree
Hide file tree
Showing 14 changed files with 129 additions and 51 deletions.
128 changes: 116 additions & 12 deletions Rubberduck.VBEEditor/SafeComWrappers/ComWrapperEnumerator.cs
@@ -1,39 +1,143 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using ComTypes = System.Runtime.InteropServices.ComTypes;

namespace Rubberduck.VBEditor.SafeComWrappers
{
// The CLR automagically handles COM enumeration by custom marshalling IEnumVARIANT to a managed class (see EnumeratorToEnumVariantMarshaler)
// But as we need explicit control over all COM objects in RD, this is unacceptable. We need to obtain the explicit IEnumVARIANT interface
// and ensure this RCW is destroyed in a timely fashion, using Marshal.ReleaseComObject.
// The automatic custom marshalling of the enumeration getter method (DISPID_ENUM) prohibits access to the underlying IEnumVARIANT interface.
// To work around it, we must call the IDispatch:::Invoke method directly (instead of using CLRs normal late-bound method calling ability).

[ComImport(), Guid("00020400-0000-0000-C000-000000000046")]
[InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
interface IDispatch
{
[PreserveSig] int GetTypeInfoCount([Out] out uint pctinfo);
[PreserveSig] int GetTypeInfo([In] uint iTInfo, [In] uint lcid, [Out] out ComTypes.ITypeInfo pTypeInfo);
[PreserveSig] int GetIDsOfNames([In] ref Guid riid, [In] string[] rgszNames, [In] uint cNames, [In] uint lcid, [Out] out int[] rgDispId);

[PreserveSig]
int Invoke([In] int dispIdMember,
[In] ref Guid riid,
[In] uint lcid,
[In] uint dwFlags,
[In, Out] ref ComTypes.DISPPARAMS pDispParams,
[Out] out Object pVarResult,
[In, Out] ref ComTypes.EXCEPINFO pExcepInfo,
[Out] out uint pArgErr);
}

class IDispatchHelper
{
public enum StandardDispIds : int
{
DISPID_ENUM = -4
}

public enum InvokeKind : int
{
DISPATCH_METHOD = 1,
DISPATCH_PROPERTYGET = 2,
DISPATCH_PROPERTYPUT = 4,
DISPATCH_PROPERTYPUTREF = 8,
}

public static object PropertyGet_NoArgs(IDispatch obj, int memberId)
{
var pDispParams = new ComTypes.DISPPARAMS();
object pVarResult;
var pExcepInfo = new ComTypes.EXCEPINFO();
uint ErrArg;
Guid guid = new Guid();

int hr = obj.Invoke(memberId, ref guid, 0, (uint)(InvokeKind.DISPATCH_METHOD | InvokeKind.DISPATCH_PROPERTYGET),
ref pDispParams, out pVarResult, ref pExcepInfo, out ErrArg);

if (hr < 0)
{
// could expand this to better handle DISP_E_EXCEPTION
throw Marshal.GetExceptionForHR(hr);
}

return pVarResult;
}
}

[ComImport(), Guid("00020404-0000-0000-C000-000000000046")]
[InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
public interface IEnumVARIANT
{
// rgVar is technically an unmanaged array here, but we only ever call with celt=1, so this is compatible.
[PreserveSig] int Next([In] uint celt, [Out] out object rgVar, [Out] out uint pceltFetched);

[PreserveSig] int Skip([In] uint celt);
[PreserveSig] int Reset();
[PreserveSig] int Clone([Out] out IEnumVARIANT retval);
}

public class ComWrapperEnumerator<TWrapperItem> : IEnumerator<TWrapperItem>
where TWrapperItem : class
{
private readonly Func<object, TWrapperItem> _itemWrapper;
private readonly IEnumerator _internal;
private readonly IEnumVARIANT _enumeratorRCW;
private TWrapperItem _currentItem;

public ComWrapperEnumerator(IEnumerable source, Func<object, TWrapperItem> itemWrapper)
public ComWrapperEnumerator(object source, Func<object, TWrapperItem> itemWrapper)
{
_itemWrapper = itemWrapper;
_internal = source?.GetEnumerator() ?? Enumerable.Empty<TWrapperItem>().GetEnumerator();

if (source != null)
{
_enumeratorRCW = (IEnumVARIANT)IDispatchHelper.PropertyGet_NoArgs((IDispatch)source, (int)IDispatchHelper.StandardDispIds.DISPID_ENUM);
((IEnumerator)this).Reset(); // precaution
}
}

public void Dispose()
{
// nothing to dispose here
if (!IsWrappingNullReference) Marshal.ReleaseComObject(_enumeratorRCW);
}

public bool MoveNext()
void IEnumerator.Reset()
{
return _internal.MoveNext();
if (!IsWrappingNullReference)
{
int hr = _enumeratorRCW.Reset();
if (hr < 0)
{
throw Marshal.GetExceptionForHR(hr);
}
}
}

public bool IsWrappingNullReference => _enumeratorRCW == null;

public TWrapperItem Current => _currentItem;
object IEnumerator.Current => _currentItem;

public void Reset()
bool IEnumerator.MoveNext()
{
_internal.Reset();
}
if (IsWrappingNullReference) return false;

_currentItem = null;

public TWrapperItem Current => _itemWrapper.Invoke(_internal.Current);
object currentItemRCW;
uint celtFetched;
int hr = _enumeratorRCW.Next(1, out currentItemRCW, out celtFetched);
// hr == S_FALSE (1) or S_OK (0), or <0 means error

object IEnumerator.Current => Current;
_currentItem = _itemWrapper.Invoke(currentItemRCW); // creates a null wrapped reference even on end/error, just as a precaution

if (hr < 0)
{
throw Marshal.GetExceptionForHR(hr);
}

return (celtFetched == 1); // celtFetched will be 0 when we reach the end of the collection
}
}
}
Expand Up @@ -31,9 +31,7 @@ public ICommandBarControl Add(ControlType type, int before)

IEnumerator<ICommandBarControl> IEnumerable<ICommandBarControl>.GetEnumerator()
{
return IsWrappingNullReference
? new ComWrapperEnumerator<ICommandBarControl>(null, o => new CommandBarControl(null))
: new ComWrapperEnumerator<ICommandBarControl>(Target,
return new ComWrapperEnumerator<ICommandBarControl>(Target,
comObject => new CommandBarControl((Microsoft.Office.Core.CommandBarControl) comObject));
}

Expand Down
Expand Up @@ -57,9 +57,7 @@ public ICommandBarControl FindControl(ControlType type, int id)

IEnumerator<ICommandBar> IEnumerable<ICommandBar>.GetEnumerator()
{
return IsWrappingNullReference
? new ComWrapperEnumerator<ICommandBar>(null, o => new CommandBar(null))
: new ComWrapperEnumerator<ICommandBar>(Target,
return new ComWrapperEnumerator<ICommandBar>(Target,
comObject => new CommandBar((Microsoft.Office.Core.CommandBar) comObject));
}

Expand Down
4 changes: 1 addition & 3 deletions Rubberduck.VBEEditor/SafeComWrappers/VB6/VBComponents.cs
Expand Up @@ -69,9 +69,7 @@ public IVBComponent AddMTDesigner(int index = 0)

IEnumerator<IVBComponent> IEnumerable<IVBComponent>.GetEnumerator()
{
return IsWrappingNullReference
? new ComWrapperEnumerator<IVBComponent>(null, o => new VBComponent(null))
: new ComWrapperEnumerator<IVBComponent>(Target, comObject => new VBComponent((VB.VBComponent)comObject));
return new ComWrapperEnumerator<IVBComponent>(Target, comObject => new VBComponent((VB.VBComponent)comObject));
}

IEnumerator IEnumerable.GetEnumerator()
Expand Down
4 changes: 1 addition & 3 deletions Rubberduck.VBEEditor/SafeComWrappers/VB6/VBProjects.cs
Expand Up @@ -54,9 +54,7 @@ public IVBProject Open(string path)

IEnumerator<IVBProject> IEnumerable<IVBProject>.GetEnumerator()
{
return IsWrappingNullReference
? new ComWrapperEnumerator<IVBProject>(null, o => new VBProject(null))
: new ComWrapperEnumerator<IVBProject>(Target, comObject => new VBProject((VB.VBProject)comObject));
return new ComWrapperEnumerator<IVBProject>(Target, comObject => new VBProject((VB.VBProject)comObject));
}

IEnumerator IEnumerable.GetEnumerator()
Expand Down
4 changes: 1 addition & 3 deletions Rubberduck.VBEEditor/SafeComWrappers/VBA/AddIns.cs
Expand Up @@ -48,9 +48,7 @@ IEnumerator IEnumerable.GetEnumerator()

IEnumerator<IAddIn> IEnumerable<IAddIn>.GetEnumerator()
{
return IsWrappingNullReference
? new ComWrapperEnumerator<IAddIn>(null, o => new AddIn(null))
: new ComWrapperEnumerator<IAddIn>(Target, comObject => new AddIn((VB.AddIn) comObject));
return new ComWrapperEnumerator<IAddIn>(Target, comObject => new AddIn((VB.AddIn) comObject));
}
}
}
4 changes: 1 addition & 3 deletions Rubberduck.VBEEditor/SafeComWrappers/VBA/CodePanes.cs
Expand Up @@ -28,9 +28,7 @@ public ICodePane Current

IEnumerator<ICodePane> IEnumerable<ICodePane>.GetEnumerator()
{
return IsWrappingNullReference
? new ComWrapperEnumerator<ICodePane>(null, o => new CodePane(null))
: new ComWrapperEnumerator<ICodePane>(Target, comObject => new CodePane((VB.CodePane) comObject));
return new ComWrapperEnumerator<ICodePane>(Target, comObject => new CodePane((VB.CodePane) comObject));
}

IEnumerator IEnumerable.GetEnumerator()
Expand Down
4 changes: 1 addition & 3 deletions Rubberduck.VBEEditor/SafeComWrappers/VBA/Controls.cs
Expand Up @@ -20,9 +20,7 @@ public Controls(VB.Forms.Controls target, bool rewrapping = false)
IEnumerator<IControl> IEnumerable<IControl>.GetEnumerator()
{
// soft-casting because ImageClass doesn't implement IControl
return IsWrappingNullReference
? new ComWrapperEnumerator<IControl>(null, o => new Control(null))
: new ComWrapperEnumerator<IControl>(Target, comObject => new Control(comObject as VB.Forms.Control));
return new ComWrapperEnumerator<IControl>(Target, comObject => new Control(comObject as VB.Forms.Control));
}

IEnumerator IEnumerable.GetEnumerator()
Expand Down
4 changes: 1 addition & 3 deletions Rubberduck.VBEEditor/SafeComWrappers/VBA/LinkedWindows.cs
Expand Up @@ -47,9 +47,7 @@ IEnumerator IEnumerable.GetEnumerator()

IEnumerator<IWindow> IEnumerable<IWindow>.GetEnumerator()
{
return IsWrappingNullReference
? new ComWrapperEnumerator<IWindow>(null, o => new Window(null))
: new ComWrapperEnumerator<IWindow>(Target, comObject => new Window((VB.Window) comObject));
return new ComWrapperEnumerator<IWindow>(Target, comObject => new Window((VB.Window) comObject));
}

public override bool Equals(ISafeComWrapper<VB.LinkedWindows> other)
Expand Down
4 changes: 1 addition & 3 deletions Rubberduck.VBEEditor/SafeComWrappers/VBA/Properties.cs
Expand Up @@ -24,9 +24,7 @@ public Properties(VB.Properties target, bool rewrapping = false)

IEnumerator<IProperty> IEnumerable<IProperty>.GetEnumerator()
{
return IsWrappingNullReference
? new ComWrapperEnumerator<IProperty>(null, o => new Property(null))
: new ComWrapperEnumerator<IProperty>(Target, comObject => new Property((VB.Property) comObject));
return new ComWrapperEnumerator<IProperty>(Target, comObject => new Property((VB.Property) comObject));
}

IEnumerator IEnumerable.GetEnumerator()
Expand Down
4 changes: 1 addition & 3 deletions Rubberduck.VBEEditor/SafeComWrappers/VBA/References.cs
Expand Up @@ -96,9 +96,7 @@ public void Remove(IReference reference)

IEnumerator<IReference> IEnumerable<IReference>.GetEnumerator()
{
return IsWrappingNullReference
? new ComWrapperEnumerator<IReference>(null, o => new Reference(null))
: new ComWrapperEnumerator<IReference>(Target, comObject => new Reference((VB.Reference) comObject));
return new ComWrapperEnumerator<IReference>(Target, comObject => new Reference((VB.Reference) comObject));
}

IEnumerator IEnumerable.GetEnumerator()
Expand Down
4 changes: 1 addition & 3 deletions Rubberduck.VBEEditor/SafeComWrappers/VBA/VBComponents.cs
Expand Up @@ -72,9 +72,7 @@ public IVBComponent AddMTDesigner(int index = 0)

IEnumerator<IVBComponent> IEnumerable<IVBComponent>.GetEnumerator()
{
return IsWrappingNullReference
? new ComWrapperEnumerator<IVBComponent>(null, o => new VBComponent(null))
: new ComWrapperEnumerator<IVBComponent>(Target, comObject => new VBComponent((VB.VBComponent) comObject));
return new ComWrapperEnumerator<IVBComponent>(Target, comObject => new VBComponent((VB.VBComponent) comObject));
}

IEnumerator IEnumerable.GetEnumerator()
Expand Down
4 changes: 1 addition & 3 deletions Rubberduck.VBEEditor/SafeComWrappers/VBA/VBProjects.cs
Expand Up @@ -60,9 +60,7 @@ public IVBProject Open(string path)

IEnumerator<IVBProject> IEnumerable<IVBProject>.GetEnumerator()
{
return IsWrappingNullReference
? new ComWrapperEnumerator<IVBProject>(null, o => new VBProject(null))
: new ComWrapperEnumerator<IVBProject>(Target, comObject => new VBProject((VB.VBProject) comObject));
return new ComWrapperEnumerator<IVBProject>(Target, comObject => new VBProject((VB.VBProject) comObject));
}

IEnumerator IEnumerable.GetEnumerator()
Expand Down
4 changes: 1 addition & 3 deletions Rubberduck.VBEEditor/SafeComWrappers/VBA/Windows.cs
Expand Up @@ -53,9 +53,7 @@ IEnumerator IEnumerable.GetEnumerator()

IEnumerator<IWindow> IEnumerable<IWindow>.GetEnumerator()
{
return IsWrappingNullReference
? new ComWrapperEnumerator<IWindow>(null, o => new Window(null))
: new ComWrapperEnumerator<IWindow>(Target, comObject => new Window((VB.Window) comObject));
return new ComWrapperEnumerator<IWindow>(Target, comObject => new Window((VB.Window) comObject));
}

public override bool Equals(ISafeComWrapper<VB.Windows> other)
Expand Down

0 comments on commit 6c27601

Please sign in to comment.