-
Notifications
You must be signed in to change notification settings - Fork 126
/
TFLImageClassifierCoreMLDelegateTest.mm
130 lines (106 loc) · 5.86 KB
/
TFLImageClassifierCoreMLDelegateTest.mm
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#import <XCTest/XCTest.h>
#import "tensorflow_lite_support/ios/task/core/sources/TFLBaseOptions.h"
#import "tensorflow_lite_support/ios/task/vision/sources/TFLImageClassifier.h"
#import "tensorflow_lite_support/ios/task/vision/utils/sources/GMLImage+Utils.h"
#include "tensorflow_lite_support/c/task/vision/utils/frame_buffer_cpp_c_utils.h"
#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
using ImageClassifier = ::tflite::task::vision::ImageClassifier;
using ImageClassifierOptions = ::tflite::task::vision::ImageClassifierOptions;
using ClassificationResult = ::tflite::task::vision::ClassificationResult;
/// Verifies that image classification works with the Core ML delegate enabled on all devices,
/// exercising both the C++ `ImageClassifier` and the Objective-C `TFLImageClassifier` APIs.
@interface TFLImageClassifierCoreMLDelegateTest : XCTestCase {
  /// Path of the bundled `mobilenet_v2_1.0_224.tflite` model, resolved in `-setUp`; nil when the
  /// model is missing from the test bundle.
  NSString* _modelPath;
}
@end

@implementation TFLImageClassifierCoreMLDelegateTest

- (void)setUp {
  [super setUp];
  // This image classifier can mostly be deployed through Core ML. Below is from the delegate logs:
  // "INFO: CoreML delegate: 64 nodes delegated out of 66 nodes, with 2 partitions."
  _modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"mobilenet_v2_1.0_224"
                                                                ofType:@"tflite"];
  XCTAssertNotNil(_modelPath);
}

/// Classifies `burger.jpg` with the C++ `ImageClassifier` using the Core ML delegate on all
/// devices, and checks that the top class is "cheeseburger" with the expected score.
///
/// XCTest keeps executing a test after a failed assertion, so every `StatusOr` and every indexed
/// protobuf access below is guarded by an early `return`: calling `value()` on a non-OK
/// `StatusOr`, or indexing an empty repeated field, would otherwise crash the test process
/// instead of recording a failure.
- (void)testCoreMLDelegateCreationSucceedsWithDevicesAllUsingCppImageClassifier {
  if (_modelPath == nil) {
    // -setUp has already recorded the failure; `_modelPath.UTF8String` would be NULL here.
    return;
  }

  // Configures the options.
  ImageClassifierOptions options;
  options.mutable_base_options()->mutable_model_file()->set_file_name(_modelPath.UTF8String);
  options.mutable_base_options()
      ->mutable_compute_settings()
      ->mutable_tflite_settings()
      ->set_delegate(::tflite::proto::Delegate::CORE_ML);
  options.mutable_base_options()
      ->mutable_compute_settings()
      ->mutable_tflite_settings()
      ->mutable_coreml_settings()
      ->set_enabled_devices(::tflite::proto::CoreMLSettings::DEVICES_ALL);

  // Creates the classifier.
  tflite::support::StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_status =
      ImageClassifier::CreateFromOptions(options);
  if (!image_classifier_status.ok()) {
    XCTFail(@"ImageClassifier::CreateFromOptions failed: %s",
            image_classifier_status.status().ToString().c_str());
    return;
  }
  const std::unique_ptr<ImageClassifier>& image_classifier = image_classifier_status.value();
  if (image_classifier == nullptr) {
    XCTFail(@"ImageClassifier::CreateFromOptions returned a null classifier.");
    return;
  }

  // Loads the test image.
  GMLImage* gmlImage = [GMLImage imageFromBundleWithClass:[self class]
                                                 fileName:@"burger"
                                                   ofType:@"jpg"];
  if (gmlImage == nil) {
    XCTFail(@"Failed to load burger.jpg from the test bundle.");
    return;
  }

  // Converts the test image to a frame buffer.
  NSError* error = nil;
  TfLiteFrameBuffer* cFrameBuffer = [gmlImage cFrameBufferWithError:&error];
  if (cFrameBuffer == nullptr) {
    XCTFail(@"Failed to convert the test image to a TfLiteFrameBuffer: %@", error);
    return;
  }
  // NOTE(review): `cFrameBuffer` is never released in this test; if -cFrameBufferWithError:
  // heap-allocates it (and its pixel buffer), it leaks here — confirm ownership.
  tflite::support::StatusOr<std::unique_ptr<::tflite::task::vision::FrameBuffer>>
      frame_buffer_status = ::tflite::task::vision::CreateCppFrameBuffer(cFrameBuffer);
  if (!frame_buffer_status.ok()) {
    XCTFail(@"CreateCppFrameBuffer failed: %s", frame_buffer_status.status().ToString().c_str());
    return;
  }
  const ::tflite::task::vision::FrameBuffer& frame_buffer = *frame_buffer_status.value();

  // Classifies the image.
  tflite::support::StatusOr<ClassificationResult> classification_result_status =
      image_classifier->Classify(frame_buffer);
  if (!classification_result_status.ok()) {
    XCTFail(@"ImageClassifier::Classify failed: %s",
            classification_result_status.status().ToString().c_str());
    return;
  }
  const ClassificationResult& classification_result = classification_result_status.value();

  // Retrieves the top class, guarding each index so an empty result fails instead of crashing.
  if (classification_result.classifications_size() == 0) {
    XCTFail(@"The classification result contains no classifications.");
    return;
  }
  const ::tflite::task::vision::Classifications& classification =
      classification_result.classifications(0);
  if (classification.classes_size() == 0) {
    XCTFail(@"The first classification contains no classes.");
    return;
  }
  const ::tflite::task::vision::Class& topClass = classification.classes(0);

  // Verifies the class name & score. The name is decoded as UTF-8 because
  // `stringWithCString:` without an explicit encoding is deprecated.
  NSString* className = [NSString stringWithUTF8String:topClass.class_name().c_str()];
  XCTAssertEqualObjects(className, @"cheeseburger");
  XCTAssertEqualWithAccuracy(topClass.score(), 0.748976, 0.001);
}

/// Classifies `burger.jpg` with the Objective-C `TFLImageClassifier` configured with Core ML
/// delegate settings for all devices, and checks that the top category is "cheeseburger" with
/// the expected score.
- (void)testCoreMLDelegateCreationSucceedsWithDevicesAllUsingObjcImageClassifier {
  if (_modelPath == nil) {
    // -setUp has already recorded the failure.
    return;
  }

  TFLCoreMLDelegateSettings* coreMLDelegateSettings = [[TFLCoreMLDelegateSettings alloc]
      initWithCoreMLVersion:3
             enableddevices:TFLCoreMLDelegateSettings_DevicesAll];
  TFLImageClassifierOptions* imageClassifierOptions =
      [[TFLImageClassifierOptions alloc] initWithModelPath:_modelPath];
  // Implicitly enables Core ML Delegate.
  imageClassifierOptions.baseOptions.coreMLDelegateSettings = coreMLDelegateSettings;

  NSError* error = nil;
  TFLImageClassifier* imageClassifier =
      [TFLImageClassifier imageClassifierWithOptions:imageClassifierOptions error:&error];
  if (imageClassifier == nil) {
    XCTFail(@"Failed to create TFLImageClassifier: %@", error);
    return;
  }

  GMLImage* gmlImage =
      [GMLImage imageFromBundleWithClass:self.class fileName:@"burger" ofType:@"jpg"];
  if (gmlImage == nil) {
    XCTFail(@"Failed to load burger.jpg from the test bundle.");
    return;
  }

  TFLClassificationResult* classificationResults =
      [imageClassifier classifyWithGMLImage:gmlImage error:&error];
  if (classificationResults == nil) {
    XCTFail(@"Classification failed: %@", error);
    return;
  }

  // Guards each subscript: indexing an empty NSArray raises NSRangeException.
  if (classificationResults.classifications.count == 0) {
    XCTFail(@"The classification result contains no classifications.");
    return;
  }
  if (classificationResults.classifications[0].categories.count == 0) {
    XCTFail(@"The first classification contains no categories.");
    return;
  }
  TFLCategory* category = classificationResults.classifications[0].categories[0];
  XCTAssertEqualObjects(category.label, @"cheeseburger");
  // Comment from TFLImageClassifierTests.m:
  // "TODO: match the score as image_classifier_test.cc"
  XCTAssertEqualWithAccuracy(category.score, 0.748976, 0.001);
}

@end