Skip to content

Commit

Permalink
Use COM smart pointers in WASAPI driver
Browse files Browse the repository at this point in the history
  • Loading branch information
lalitshankarchowdhury committed May 16, 2024
1 parent f92c4bc commit 028484c
Showing 1 changed file with 42 additions and 85 deletions.
127 changes: 42 additions & 85 deletions RtAudio.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
#include <locale>

#if defined(_WIN32)
#include <wrl/client.h>
using Microsoft::WRL::ComPtr;
#include <windows.h>
#endif

Expand Down Expand Up @@ -273,7 +275,7 @@ class RtApiWasapi : public RtApi

private:
bool coInitialized_;
IMMDeviceEnumerator* deviceEnumerator_;
ComPtr<IMMDeviceEnumerator> deviceEnumerator_;
std::vector< std::pair< std::string, bool> > deviceIds_;

void probeDevices( void ) override;
Expand Down Expand Up @@ -4637,17 +4639,17 @@ class WasapiResampler
_mediaType->SetUINT32( MF_MT_ALL_SAMPLES_INDEPENDENT, TRUE );

MFCreateMediaType( &_inputMediaType );
_mediaType->CopyAllItems( _inputMediaType );
_mediaType->CopyAllItems( _inputMediaType.Get() );

_transform->SetInputType( 0, _inputMediaType, 0 );
_transform->SetInputType( 0, _inputMediaType.Get(), 0 );

MFCreateMediaType( &_outputMediaType );
_mediaType->CopyAllItems( _outputMediaType );
_mediaType->CopyAllItems( _outputMediaType.Get() );

_outputMediaType->SetUINT32( MF_MT_AUDIO_SAMPLES_PER_SECOND, outSampleRate );
_outputMediaType->SetUINT32( MF_MT_AUDIO_AVG_BYTES_PER_SECOND, _bytesPerSample * channelCount * outSampleRate );

_transform->SetOutputType( 0, _outputMediaType, 0 );
_transform->SetOutputType( 0, _outputMediaType.Get(), 0 );

// 4. Send stream start messages to Resampler

Expand All @@ -4666,16 +4668,6 @@ class WasapiResampler
// 9. Cleanup

MFShutdown();

SAFE_RELEASE( _transformUnk );
SAFE_RELEASE( _transform );
SAFE_RELEASE( _mediaType );
SAFE_RELEASE( _inputMediaType );
SAFE_RELEASE( _outputMediaType );

#ifdef __IWMResamplerProps_FWD_DEFINED__
SAFE_RELEASE( _resamplerProps );
#endif
}

void Convert( char* outBuffer, const char* inBuffer, unsigned int inSampleCount, unsigned int& outSampleCount, int maxOutSampleCount = -1 )
Expand All @@ -4699,8 +4691,8 @@ class WasapiResampler
outputBufferSize = ( unsigned int ) ceilf( inputBufferSize * _sampleRatio ) + ( _bytesPerSample * _channelCount );
}

IMFMediaBuffer* rInBuffer;
IMFSample* rInSample;
ComPtr<IMFMediaBuffer> rInBuffer;
ComPtr<IMFSample> rInSample;
BYTE* rInByteBuffer = NULL;

// 5. Create Sample object from input data
Expand All @@ -4715,18 +4707,15 @@ class WasapiResampler
rInBuffer->SetCurrentLength( inputBufferSize );

MFCreateSample( &rInSample );
rInSample->AddBuffer( rInBuffer );
rInSample->AddBuffer( rInBuffer.Get() );

// 6. Pass input data to Resampler

_transform->ProcessInput( 0, rInSample, 0 );

SAFE_RELEASE( rInBuffer );
SAFE_RELEASE( rInSample );
_transform->ProcessInput( 0, rInSample.Get(), 0 );

// 7. Perform sample rate conversion

IMFMediaBuffer* rOutBuffer = NULL;
ComPtr<IMFMediaBuffer> rOutBuffer = NULL;
BYTE* rOutByteBuffer = NULL;

MFT_OUTPUT_DATA_BUFFER rOutDataBuffer;
Expand All @@ -4738,7 +4727,7 @@ class WasapiResampler
memset( &rOutDataBuffer, 0, sizeof rOutDataBuffer );
MFCreateSample( &( rOutDataBuffer.pSample ) );
MFCreateMemoryBuffer( rBytes, &rOutBuffer );
rOutDataBuffer.pSample->AddBuffer( rOutBuffer );
rOutDataBuffer.pSample->AddBuffer( rOutBuffer.Get() );
rOutDataBuffer.dwStreamID = 0;
rOutDataBuffer.dwStatus = 0;
rOutDataBuffer.pEvents = NULL;
Expand All @@ -4748,14 +4737,12 @@ class WasapiResampler
if ( _transform->ProcessOutput( 0, 1, &rOutDataBuffer, &rStatus ) == MF_E_TRANSFORM_NEED_MORE_INPUT )
{
outSampleCount = 0;
SAFE_RELEASE( rOutBuffer );
SAFE_RELEASE( rOutDataBuffer.pSample );
return;
}

// 7.3 Write output data to outBuffer

SAFE_RELEASE( rOutBuffer );
rOutDataBuffer.pSample->ConvertToContiguousBuffer( &rOutBuffer );
rOutBuffer->GetCurrentLength( &rBytes );

Expand All @@ -4765,7 +4752,6 @@ class WasapiResampler
rOutByteBuffer = NULL;

outSampleCount = rBytes / _bytesPerSample / _channelCount;
SAFE_RELEASE( rOutBuffer );
SAFE_RELEASE( rOutDataBuffer.pSample );
}

Expand All @@ -4774,14 +4760,14 @@ class WasapiResampler
unsigned int _channelCount;
float _sampleRatio;

IUnknown* _transformUnk;
IMFTransform* _transform;
IMFMediaType* _mediaType;
IMFMediaType* _inputMediaType;
IMFMediaType* _outputMediaType;
ComPtr<IUnknown> _transformUnk;
ComPtr<IMFTransform> _transform;
ComPtr<IMFMediaType> _mediaType;
ComPtr<IMFMediaType> _inputMediaType;
ComPtr<IMFMediaType> _outputMediaType;

#ifdef __IWMResamplerProps_FWD_DEFINED__
IWMResamplerProps* _resamplerProps;
ComPtr<IWMResamplerProps> _resamplerProps;
#endif
};

Expand All @@ -4790,10 +4776,10 @@ class WasapiResampler
// A structure to hold various information related to the WASAPI implementation.
struct WasapiHandle
{
IAudioClient* captureAudioClient;
IAudioClient* renderAudioClient;
IAudioCaptureClient* captureClient;
IAudioRenderClient* renderClient;
ComPtr<IAudioClient> captureAudioClient;
ComPtr<IAudioClient> renderAudioClient;
ComPtr<IAudioCaptureClient> captureClient;
ComPtr<IAudioRenderClient> renderClient;
HANDLE captureEvent;
HANDLE renderEvent;

Expand All @@ -4820,10 +4806,6 @@ RtApiWasapi::RtApiWasapi()
hr = CoCreateInstance( __uuidof( MMDeviceEnumerator ), NULL,
CLSCTX_ALL, __uuidof( IMMDeviceEnumerator ),
( void** ) &deviceEnumerator_ );

// If this runs on an old Windows, it will fail. Ignore and proceed.
if ( FAILED( hr ) )
deviceEnumerator_ = NULL;
}

//-----------------------------------------------------------------------------
Expand All @@ -4838,8 +4820,6 @@ RtApiWasapi::~RtApiWasapi()
MUTEX_LOCK( &stream_.mutex );
}

SAFE_RELEASE( deviceEnumerator_ );

// If this object previously called CoInitialize()
if ( coInitialized_ )
CoUninitialize();
Expand All @@ -4850,7 +4830,7 @@ RtApiWasapi::~RtApiWasapi()

unsigned int RtApiWasapi::getDefaultInputDevice( void )
{
IMMDevice* devicePtr = NULL;
ComPtr<IMMDevice> devicePtr = NULL;
LPWSTR defaultId = NULL;
std::string id;

Expand All @@ -4872,7 +4852,6 @@ unsigned int RtApiWasapi::getDefaultInputDevice( void )
id = convertCharPointerToStdString( defaultId );

Release:
SAFE_RELEASE( devicePtr );
CoTaskMemFree( defaultId );

if ( !errorText_.empty() ) {
Expand Down Expand Up @@ -4907,7 +4886,7 @@ unsigned int RtApiWasapi::getDefaultInputDevice( void )

unsigned int RtApiWasapi::getDefaultOutputDevice( void )
{
IMMDevice* devicePtr = NULL;
ComPtr<IMMDevice> devicePtr = NULL;
LPWSTR defaultId = NULL;
std::string id;

Expand All @@ -4929,7 +4908,6 @@ unsigned int RtApiWasapi::getDefaultOutputDevice( void )
id = convertCharPointerToStdString( defaultId );

Release:
SAFE_RELEASE( devicePtr );
CoTaskMemFree( defaultId );

if ( !errorText_.empty() ) {
Expand Down Expand Up @@ -4967,9 +4945,9 @@ void RtApiWasapi::probeDevices( void )
unsigned int captureDeviceCount = 0;
unsigned int renderDeviceCount = 0;

IMMDeviceCollection* captureDevices = NULL;
IMMDeviceCollection* renderDevices = NULL;
IMMDevice* devicePtr = NULL;
ComPtr<IMMDeviceCollection> captureDevices = NULL;
ComPtr<IMMDeviceCollection> renderDevices = NULL;
ComPtr<IMMDevice> devicePtr = NULL;

LPWSTR defaultCaptureId = NULL;
LPWSTR defaultRenderId = NULL;
Expand Down Expand Up @@ -5028,7 +5006,6 @@ void RtApiWasapi::probeDevices( void )
}

// Get the default render device Id.
SAFE_RELEASE( devicePtr );
hr = deviceEnumerator_->GetDefaultAudioEndpoint( eRender, eConsole, &devicePtr );
if ( SUCCEEDED( hr) ) {
hr = devicePtr->GetId( &defaultRenderId );
Expand All @@ -5041,7 +5018,6 @@ void RtApiWasapi::probeDevices( void )

// Collect device IDs with mode.
for ( unsigned int n=0; n<nDevices; n++ ) {
SAFE_RELEASE( devicePtr );
if ( n < renderDeviceCount ) {
hr = renderDevices->Item( n, &devicePtr );
if ( FAILED( hr ) ) {
Expand Down Expand Up @@ -5114,11 +5090,6 @@ void RtApiWasapi::probeDevices( void )
}

Exit:
// Release all references
SAFE_RELEASE( captureDevices );
SAFE_RELEASE( renderDevices );
SAFE_RELEASE( devicePtr );

CoTaskMemFree( defaultCaptureId );
CoTaskMemFree( defaultRenderId );

Expand All @@ -5135,9 +5106,9 @@ void RtApiWasapi::probeDevices( void )
bool RtApiWasapi::probeDeviceInfo( RtAudio::DeviceInfo &info, LPWSTR deviceId, bool isCaptureDevice )
{
PROPVARIANT deviceNameProp;
IMMDevice* devicePtr = NULL;
IAudioClient* audioClient = NULL;
IPropertyStore* devicePropStore = NULL;
ComPtr<IMMDevice> devicePtr = NULL;
ComPtr<IAudioClient> audioClient = NULL;
ComPtr<IPropertyStore> devicePropStore = NULL;

WAVEFORMATEX* deviceFormat = NULL;
WAVEFORMATEX* closestMatchFormat = NULL;
Expand Down Expand Up @@ -5239,10 +5210,6 @@ bool RtApiWasapi::probeDeviceInfo( RtAudio::DeviceInfo &info, LPWSTR deviceId, b
// Release all references
PropVariantClear( &deviceNameProp );

SAFE_RELEASE( devicePtr );
SAFE_RELEASE( audioClient );
SAFE_RELEASE( devicePropStore );

CoTaskMemFree( deviceFormat );
CoTaskMemFree( closestMatchFormat );

Expand Down Expand Up @@ -5270,13 +5237,6 @@ void RtApiWasapi::closeStream( void )
MUTEX_LOCK( &stream_.mutex );
}

// clean up stream memory
SAFE_RELEASE(((WasapiHandle*)stream_.apiHandle)->captureClient)
SAFE_RELEASE(((WasapiHandle*)stream_.apiHandle)->renderClient)

SAFE_RELEASE( ( ( WasapiHandle* ) stream_.apiHandle )->captureAudioClient )
SAFE_RELEASE( ( ( WasapiHandle* ) stream_.apiHandle )->renderAudioClient )

if ( ( ( WasapiHandle* ) stream_.apiHandle )->captureEvent )
CloseHandle( ( ( WasapiHandle* ) stream_.apiHandle )->captureEvent );

Expand Down Expand Up @@ -5412,7 +5372,7 @@ bool RtApiWasapi::probeDeviceOpen( unsigned int deviceId, StreamMode mode, unsig
{
MUTEX_LOCK( &stream_.mutex );
bool methodResult = FAILURE;
IMMDevice* devicePtr = NULL;
ComPtr<IMMDevice> devicePtr = NULL;
WAVEFORMATEX* deviceFormat = NULL;
unsigned int bufferBytes;
stream_.state = STREAM_STOPPED;
Expand Down Expand Up @@ -5457,7 +5417,7 @@ bool RtApiWasapi::probeDeviceOpen( unsigned int deviceId, StreamMode mode, unsig
stream_.apiHandle = ( void* ) new WasapiHandle();

if ( isInput ) {
IAudioClient*& captureAudioClient = ( ( WasapiHandle* ) stream_.apiHandle )->captureAudioClient;
ComPtr<IAudioClient> captureAudioClient = ( ( WasapiHandle* ) stream_.apiHandle )->captureAudioClient;

hr = devicePtr->Activate( __uuidof( IAudioClient ), CLSCTX_ALL,
NULL, ( void** ) &captureAudioClient );
Expand All @@ -5479,15 +5439,15 @@ bool RtApiWasapi::probeDeviceOpen( unsigned int deviceId, StreamMode mode, unsig
// If an output device and is configured for loopback (input mode)
if ( isInput == false && mode == INPUT ) {
// If renderAudioClient is not initialised, initialise it now
IAudioClient*& renderAudioClient = ( ( WasapiHandle* ) stream_.apiHandle )->renderAudioClient;
ComPtr<IAudioClient> renderAudioClient = ( ( WasapiHandle* ) stream_.apiHandle )->renderAudioClient;
if ( !renderAudioClient ) {
MUTEX_UNLOCK( &stream_.mutex );
probeDeviceOpen( deviceId, OUTPUT, channels, firstChannel, sampleRate, format, bufferSize, options );
MUTEX_LOCK( &stream_.mutex );
}

// Retrieve captureAudioClient from our stream handle.
IAudioClient*& captureAudioClient = ( ( WasapiHandle* ) stream_.apiHandle )->captureAudioClient;
ComPtr<IAudioClient> captureAudioClient = ( ( WasapiHandle* ) stream_.apiHandle )->captureAudioClient;

hr = devicePtr->Activate( __uuidof( IAudioClient ), CLSCTX_ALL,
NULL, ( void** ) &captureAudioClient );
Expand All @@ -5509,7 +5469,7 @@ bool RtApiWasapi::probeDeviceOpen( unsigned int deviceId, StreamMode mode, unsig
// If output device and is configured for output.
if ( isInput == false && mode == OUTPUT ) {
// If renderAudioClient is already initialised, don't initialise it again
IAudioClient*& renderAudioClient = ( ( WasapiHandle* ) stream_.apiHandle )->renderAudioClient;
ComPtr<IAudioClient> renderAudioClient = ( ( WasapiHandle* ) stream_.apiHandle )->renderAudioClient;
if ( renderAudioClient ) {
methodResult = SUCCESS;
goto Exit;
Expand Down Expand Up @@ -5592,7 +5552,6 @@ bool RtApiWasapi::probeDeviceOpen( unsigned int deviceId, StreamMode mode, unsig

Exit:
//clean up
SAFE_RELEASE( devicePtr );
CoTaskMemFree( deviceFormat );

// if method failed, close the stream
Expand Down Expand Up @@ -5645,10 +5604,10 @@ void RtApiWasapi::wasapiThread()

HRESULT hr;

IAudioClient* captureAudioClient = ( ( WasapiHandle* ) stream_.apiHandle )->captureAudioClient;
IAudioClient* renderAudioClient = ( ( WasapiHandle* ) stream_.apiHandle )->renderAudioClient;
IAudioCaptureClient* captureClient = ( ( WasapiHandle* ) stream_.apiHandle )->captureClient;
IAudioRenderClient* renderClient = ( ( WasapiHandle* ) stream_.apiHandle )->renderClient;
ComPtr<IAudioClient> captureAudioClient = ( ( WasapiHandle* ) stream_.apiHandle )->captureAudioClient;
ComPtr<IAudioClient> renderAudioClient = ( ( WasapiHandle* ) stream_.apiHandle )->renderAudioClient;
ComPtr<IAudioCaptureClient> captureClient = ( ( WasapiHandle* ) stream_.apiHandle )->captureClient;
ComPtr<IAudioRenderClient> renderClient = ( ( WasapiHandle* ) stream_.apiHandle )->renderClient;
HANDLE captureEvent = ( ( WasapiHandle* ) stream_.apiHandle )->captureEvent;
HANDLE renderEvent = ( ( WasapiHandle* ) stream_.apiHandle )->renderEvent;

Expand Down Expand Up @@ -5708,7 +5667,7 @@ void RtApiWasapi::wasapiThread()
captureSrRatio = ( ( float ) captureFormat->nSamplesPerSec / stream_.sampleRate );

if ( !captureClient ) {
IAudioClient3* captureAudioClient3 = nullptr;
ComPtr<IAudioClient3> captureAudioClient3 = nullptr;
captureAudioClient->QueryInterface( __uuidof( IAudioClient3 ), ( void** ) &captureAudioClient3 );
if ( captureAudioClient3 && !loopbackEnabled )
{
Expand All @@ -5728,7 +5687,6 @@ void RtApiWasapi::wasapiThread()
MinPeriodInFrames,
captureFormat,
NULL );
SAFE_RELEASE(captureAudioClient3);
}
else
{
Expand Down Expand Up @@ -5820,7 +5778,7 @@ void RtApiWasapi::wasapiThread()
renderSrRatio = ( ( float ) renderFormat->nSamplesPerSec / stream_.sampleRate );

if ( !renderClient ) {
IAudioClient3* renderAudioClient3 = nullptr;
ComPtr<IAudioClient3> renderAudioClient3 = nullptr;
renderAudioClient->QueryInterface( __uuidof( IAudioClient3 ), ( void** ) &renderAudioClient3 );
if ( renderAudioClient3 )
{
Expand All @@ -5840,7 +5798,6 @@ void RtApiWasapi::wasapiThread()
MinPeriodInFrames,
renderFormat,
NULL );
SAFE_RELEASE(renderAudioClient3);
}
else
{
Expand Down

0 comments on commit 028484c

Please sign in to comment.