diff --git a/services/ml/compilation_impl_mac.mm b/services/ml/compilation_impl_mac.mm index cdaeacbc3896b3..86467e1563fad0 100644 --- a/services/ml/compilation_impl_mac.mm +++ b/services/ml/compilation_impl_mac.mm @@ -7,6 +7,54 @@ #import +API_AVAILABLE(macosx(10.13)) @interface ConvDataSource : NSObject {} +@property(nonatomic, assign) float* weights_; +@property(nonatomic, assign) float* bias_; +@property(nonatomic, assign) MPSCNNConvolutionDescriptor* desc_; +@end + +@implementation ConvDataSource +@synthesize weights_; +@synthesize bias_; +@synthesize desc_; +- (id)initWithWeight:(float*)weights + bias:(float*)bias + desc:(MPSCNNConvolutionDescriptor*)desc { + self = [super init]; + self.weights_ = weights; + self.bias_ = bias; + self.desc_ = desc; + return self; +} +- (float*)biasTerms { + return self.bias_; +} +- (MPSDataType)dataType { + return MPSDataTypeFloat32; +} +- (MPSCNNConvolutionDescriptor*)descriptor { + return self.desc_; +} +- (NSString*)label { + return nullptr; +} +- (BOOL)load { + return true; +} +- (float*)lookupTableForUInt8Kernel { + return nullptr; +} +- (void)purge { + return; +} +- (vector_float2*)rangesForUInt8Kernel { + return nullptr; +} +- (void*)weights { + return self.weights_; +} +@end + namespace ml { OperationMac::OperationMac() = default; @@ -56,13 +104,14 @@ desc.strideInPixelsY = stride_height; desc.groups = 1; + auto data_source = [[ConvDataSource alloc] + initWithWeight:(float*)weights + bias:(float*)bias + desc:(MPSCNNConvolutionDescriptor*)desc]; MPSCNNConvolution* conv = [[MPSCNNConvolution alloc] initWithDevice:GetMPSCNNContext().device - convolutionDescriptor:desc - kernelWeights:weights - biasTerms:bias - flags:MPSCNNConvolutionFlagsNone]; + weights:data_source]; return conv; }