-
Notifications
You must be signed in to change notification settings - Fork 47
/
Copy pathClassificationService.swift
70 lines (62 loc) · 2.35 KB
/
ClassificationService.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import CoreML
import Vision
import VisionLab
/// Delegate protocol used for `ClassificationService`.
///
/// Class-constrained (`AnyObject`) so that `ClassificationService` can hold its
/// delegate as a `weak` reference.
protocol ClassificationServiceDelegate: AnyObject {
    /// Called with the gender result produced for an image passed to `classify(image:)`.
    func classificationService(_ service: ClassificationService, didDetectGender gender: String)
    /// Called with the age result produced for an image passed to `classify(image:)`.
    func classificationService(_ service: ClassificationService, didDetectAge age: String)
    /// Called with the emotion result produced for an image passed to `classify(image:)`.
    func classificationService(_ service: ClassificationService, didDetectEmotion emotion: String)
}
/// Service used to perform gender, age and emotion classification.
///
/// Call `setup()` to load the CoreML models, then `classify(image:)` for each image.
/// Results are reported through `delegate`, one callback per classifier.
final class ClassificationService: ClassificationServiceProtocol {
    /// The service's delegate (held weakly to avoid owning it).
    weak var delegate: ClassificationServiceDelegate?
    /// Vision requests built by `setup()`, one per CoreML model.
    private var requests = [VNRequest]()

    /// Create CoreML models and classification requests.
    ///
    /// The request list is rebuilt on every call rather than appended to, so calling
    /// `setup()` more than once does not run each classifier multiple times. If a model
    /// fails to load, the requests built before the failure are kept (in debug builds the
    /// `assertionFailure` traps first).
    func setup() {
        var newRequests = [VNRequest]()
        do {
            // Gender request
            newRequests.append(try makeRequest(
                for: GenderNet().model,
                handler: ClassificationService.handleGenderClassification
            ))
            // Age request
            newRequests.append(try makeRequest(
                for: AgeNet().model,
                handler: ClassificationService.handleAgeClassification
            ))
            // Emotions request
            newRequests.append(try makeRequest(
                for: CNNEmotions().model,
                handler: ClassificationService.handleEmotionClassification
            ))
        } catch {
            assertionFailure("Can't load Vision ML model: \(error)")
        }
        requests = newRequests
    }

    /// Run individual requests one by one against `image`.
    ///
    /// A single `VNImageRequestHandler` is shared by all requests for the same image.
    /// Each request is performed and error-checked separately, so a failure in one
    /// classifier is logged without preventing the remaining classifiers from running.
    func classify(image: CIImage) {
        let handler = VNImageRequestHandler(ciImage: image)
        for request in requests {
            do {
                try handler.perform([request])
            } catch {
                print(error)
            }
        }
    }

    // MARK: - Request construction

    /// Wraps `model` in a `VNCoreMLRequest` whose completion forwards to `handler`.
    ///
    /// - Parameters:
    ///   - model: The CoreML model to run.
    ///   - handler: An unapplied instance method (e.g.
    ///     `ClassificationService.handleAgeClassification`) invoked on this service
    ///     when the request completes.
    /// - Throws: Any error from `VNCoreMLModel(for:)`.
    ///
    /// `self` is captured weakly: the request is stored in `requests`, so passing a
    /// bound method such as `handleAgeClassification` directly would capture `self`
    /// strongly and form a retain cycle (self -> requests -> closure -> self).
    private func makeRequest(
        for model: MLModel,
        handler: @escaping (ClassificationService) -> (VNRequest, Error?) -> Void
    ) throws -> VNRequest {
        let visionModel = try VNCoreMLModel(for: model)
        return VNCoreMLRequest(model: visionModel) { [weak self] request, error in
            guard let service = self else { return }
            handler(service)(request, error)
        }
    }

    // MARK: - Handling

    /// Forwards the gender result extracted from `request` (count: 1) to the delegate.
    /// NOTE(review): `error` is not inspected; behavior of `extractClassificationResult`
    /// for a failed request is defined outside this file — confirm.
    private func handleGenderClassification(request: VNRequest, error: Error?) {
        let result = extractClassificationResult(from: request, count: 1)
        delegate?.classificationService(self, didDetectGender: result)
    }

    /// Forwards the age result extracted from `request` (count: 1) to the delegate.
    /// NOTE(review): `error` is not inspected — see the gender handler note.
    private func handleAgeClassification(request: VNRequest, error: Error?) {
        let result = extractClassificationResult(from: request, count: 1)
        delegate?.classificationService(self, didDetectAge: result)
    }

    /// Forwards the emotion result extracted from `request` (count: 1) to the delegate.
    /// NOTE(review): `error` is not inspected — see the gender handler note.
    private func handleEmotionClassification(request: VNRequest, error: Error?) {
        let result = extractClassificationResult(from: request, count: 1)
        delegate?.classificationService(self, didDetectEmotion: result)
    }
}