Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ apply plugin: 'com.android.library'
android {
// Bumping the plugin compileSdkVersion requires all clients of this plugin
// to bump the version in their app.
compileSdkVersion 31
compileSdkVersion 35
namespace 'org.tensorflow.tflite_flutter'

// Bumping the plugin ndkVersion requires all clients of this plugin to bump
Expand Down Expand Up @@ -55,7 +55,8 @@ android {
}

defaultConfig {
minSdkVersion 19
minSdkVersion 23
targetSdkVersion 35
}
}

Expand All @@ -65,4 +66,4 @@ dependencies {

implementation("org.tensorflow:tensorflow-lite:${tflite_version}")
implementation("org.tensorflow:tensorflow-lite-gpu:${tflite_version}")
}
}
25 changes: 8 additions & 17 deletions example/audio_classification/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,14 @@ class MyHomePage extends StatefulWidget {
}

class _MyHomePageState extends State<MyHomePage> {
static const platform =
MethodChannel('org.tensorflow.audio_classification/audio_record');
static const platform = MethodChannel('org.tensorflow.audio_classification/audio_record');

// The YAMNet/classifier model used in this code example accepts data that
// represent single-channel, or mono, audio clips recorded at 16kHz in 0.975
// second clips (15600 samples).
static const _sampleRate = 16000; // 16kHz
static const _expectAudioLength = 975; // milliseconds
final int _requiredInputBuffer =
(16000 * (_expectAudioLength / 1000)).toInt();
final int _requiredInputBuffer = (16000 * (_expectAudioLength / 1000)).toInt();
late AudioClassificationHelper _helper;
List<MapEntry<String, double>> _classification = List.empty();
final List<Color> _primaryProgressColorList = [
Expand Down Expand Up @@ -102,10 +100,7 @@ class _MyHomePageState extends State<MyHomePage> {

Future<bool> _requestPermission() async {
try {
return await platform.invokeMethod('requestPermissionAndCreateRecorder', {
"sampleRate": _sampleRate,
"requiredInputBuffer": _requiredInputBuffer
});
return await platform.invokeMethod('requestPermissionAndCreateRecorder', {"sampleRate": _sampleRate, "requiredInputBuffer": _requiredInputBuffer});
} on Exception catch (e) {
log("Failed to create recorder: '${e.toString()}'.");
return false;
Expand All @@ -115,8 +110,7 @@ class _MyHomePageState extends State<MyHomePage> {
Future<Float32List> _getAudioFloatArray() async {
var audioFloatArray = Float32List(0);
try {
final Float32List result =
await platform.invokeMethod('getAudioFloatArray');
final Float32List result = await platform.invokeMethod('getAudioFloatArray');
audioFloatArray = result;
} on PlatformException catch (e) {
log("Failed to get audio array: '${e.message}'.");
Expand Down Expand Up @@ -160,8 +154,7 @@ class _MyHomePageState extends State<MyHomePage> {

Future<void> _runInference() async {
Float32List inputArray = await _getAudioFloatArray();
final result =
await _helper.inference(inputArray.sublist(0, _requiredInputBuffer));
final result = await _helper.inference(inputArray.sublist(0, _requiredInputBuffer));
setState(() {
// take top 3 classification
_classification = (result.entries.toList()
Expand All @@ -186,7 +179,7 @@ class _MyHomePageState extends State<MyHomePage> {
backgroundColor: Colors.white,
appBar: AppBar(
title: Image.asset('assets/images/tfl_logo.png'),
backgroundColor: Colors.black.withOpacity(0.5),
backgroundColor: Colors.black.withValues(alpha: 0.5),
),
body: _buildBody(),
);
Expand Down Expand Up @@ -216,10 +209,8 @@ class _MyHomePageState extends State<MyHomePage> {
),
Flexible(
child: LinearProgressIndicator(
backgroundColor: _backgroundProgressColorList[
index % _backgroundProgressColorList.length],
color: _primaryProgressColorList[
index % _primaryProgressColorList.length],
backgroundColor: _backgroundProgressColorList[index % _backgroundProgressColorList.length],
color: _primaryProgressColorList[index % _primaryProgressColorList.length],
value: item.value,
minHeight: 20,
))
Expand Down
53 changes: 11 additions & 42 deletions example/bertqa/lib/ui/qa_detail.dart
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@ import 'package:bertqa/ml/qa_client.dart';
import 'package:flutter/material.dart';

class QaDetail extends StatefulWidget {
const QaDetail(
{super.key,
required this.title,
required this.content,
required this.questions});
const QaDetail({super.key, required this.title, required this.content, required this.questions});

final String title;
final String content;
Expand Down Expand Up @@ -66,8 +62,7 @@ class _QaDetailState extends State<QaDetail> {
if (!trimQuestion.endsWith("?")) {
trimQuestion += "?";
}
List<QaAnswer> answers =
await _qaClient.runInference(trimQuestion, widget.content);
List<QaAnswer> answers = await _qaClient.runInference(trimQuestion, widget.content);
// Highlight the answer here
_highlightAnswer(answers.first);
}
Expand Down Expand Up @@ -115,36 +110,15 @@ class _QaDetailState extends State<QaDetail> {
style: Theme.of(context).textTheme.bodyMedium,
)
: RichText(
text: TextSpan(
style: Theme.of(context).textTheme.bodyMedium,
children: [
if (_answerIndex > 0)
TextSpan(
text: widget.content
.substring(0, _answerIndex)),
TextSpan(
style: TextStyle(
background: Paint()
..color = Colors.yellow),
text: widget.content.substring(_answerIndex,
_answerIndex + _qaAnswer!.text.length)),
if ((_answerIndex + _qaAnswer!.text.length) <
widget.content.length)
TextSpan(
text: widget.content.substring(
_answerIndex + _qaAnswer!.text.length,
widget.content.length))
]),
text: TextSpan(style: Theme.of(context).textTheme.bodyMedium, children: [
if (_answerIndex > 0) TextSpan(text: widget.content.substring(0, _answerIndex)),
TextSpan(style: TextStyle(background: Paint()..color = Colors.yellow), text: widget.content.substring(_answerIndex, _answerIndex + _qaAnswer!.text.length)),
if ((_answerIndex + _qaAnswer!.text.length) < widget.content.length) TextSpan(text: widget.content.substring(_answerIndex + _qaAnswer!.text.length, widget.content.length))
]),
))),
Container(
padding: const EdgeInsets.all(16),
decoration: BoxDecoration(color: Colors.white, boxShadow: [
BoxShadow(
color: Colors.grey.withOpacity(0.5),
spreadRadius: 2,
blurRadius: 5,
offset: const Offset(0, 3))
]),
decoration: BoxDecoration(color: Colors.white, boxShadow: [BoxShadow(color: Colors.grey.withValues(alpha: 0.5), spreadRadius: 2, blurRadius: 5, offset: const Offset(0, 3))]),
// color: Colors.white,
child: Column(
children: [
Expand All @@ -157,8 +131,7 @@ class _QaDetailState extends State<QaDetail> {
child: ListView.separated(
shrinkWrap: true,
scrollDirection: Axis.horizontal,
separatorBuilder: (BuildContext context, int index) =>
const Divider(
separatorBuilder: (BuildContext context, int index) => const Divider(
indent: 16,
),
itemCount: widget.questions.length,
Expand All @@ -175,9 +148,7 @@ class _QaDetailState extends State<QaDetail> {
Expanded(
child: TextField(
controller: _controller,
decoration: const InputDecoration(
border: UnderlineInputBorder(),
labelText: "Text query"),
decoration: const InputDecoration(border: UnderlineInputBorder(), labelText: "Text query"),
onChanged: (text) {
setState(() {
_currentQuestion = text;
Expand All @@ -194,9 +165,7 @@ class _QaDetailState extends State<QaDetail> {
_answerQuestion();
}
: null,
style: ElevatedButton.styleFrom(
disabledBackgroundColor: Colors.grey,
backgroundColor: const Color(0xFFFFA800)),
style: ElevatedButton.styleFrom(disabledBackgroundColor: Colors.grey, backgroundColor: const Color(0xFFFFA800)),
child: const Icon(
Icons.east,
color: Colors.white,
Expand Down
15 changes: 5 additions & 10 deletions example/digit_classification/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,13 @@ class _MyHomePageState extends State<MyHomePage> {

Future<void> _predictNumber() async {
// capture sketch area
RenderRepaintBoundary boundary =
_globalKey.currentContext!.findRenderObject() as RenderRepaintBoundary;
RenderRepaintBoundary boundary = _globalKey.currentContext!.findRenderObject() as RenderRepaintBoundary;
ui.Image image = await boundary.toImage();
final byteData = await image.toByteData(format: ui.ImageByteFormat.png);
final inputImageData = byteData?.buffer.asUint8List();

final stopwatch = Stopwatch()..start();
final (number, confidence) =
await _digitClassifierHelper.runInference(inputImageData!);
final (number, confidence) = await _digitClassifierHelper.runInference(inputImageData!);
stopwatch.stop();

setState(() {
Expand All @@ -91,7 +89,7 @@ class _MyHomePageState extends State<MyHomePage> {
title: Center(
child: Image.asset('assets/images/tfl_logo.png'),
),
backgroundColor: Colors.black.withOpacity(0.5),
backgroundColor: Colors.black.withValues(alpha: 0.5),
),
body: Center(
child: Column(
Expand All @@ -107,9 +105,7 @@ class _MyHomePageState extends State<MyHomePage> {
children: [
const Spacer(),
const Text("Predicted number:"),
if (_predictedNumber != null && _predictedConfidence != null)
Text(
"$_predictedNumber (${_predictedConfidence?.toStringAsFixed(3)})"),
if (_predictedNumber != null && _predictedConfidence != null) Text("$_predictedNumber (${_predictedConfidence?.toStringAsFixed(3)})"),
const Spacer(),
Text("Inference Time: $_inferenceTime (ms)"),
Padding(
Expand All @@ -133,8 +129,7 @@ class _MyHomePageState extends State<MyHomePage> {
}

Widget sketchArea() {
return LayoutBuilder(
builder: (BuildContext context, BoxConstraints constraints) {
return LayoutBuilder(builder: (BuildContext context, BoxConstraints constraints) {
return GestureDetector(
onPanUpdate: (DragUpdateDetails details) {
final width = constraints.maxWidth;
Expand Down
27 changes: 7 additions & 20 deletions example/gesture_classification/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ class MyApp extends StatelessWidget {
colorScheme: ColorScheme.fromSeed(seedColor: Colors.orange),
useMaterial3: true,
),
home: const MyHomePage(
title:
'An end-to-end example of gesture classification using Flutter and TensorFlow Lite'),
home: const MyHomePage(title: 'An end-to-end example of gesture classification using Flutter and TensorFlow Lite'),
);
}
}
Expand All @@ -64,13 +62,8 @@ class _MyHomePageState extends State<MyHomePage> with WidgetsBindingObserver {

// init camera
_initCamera() {
_cameraDescription = _cameras.firstWhere(
(element) => element.lensDirection == CameraLensDirection.front);
_cameraController = CameraController(
_cameraDescription, ResolutionPreset.high,
imageFormatGroup: Platform.isIOS
? ImageFormatGroup.bgra8888
: ImageFormatGroup.yuv420);
_cameraDescription = _cameras.firstWhere((element) => element.lensDirection == CameraLensDirection.front);
_cameraController = CameraController(_cameraDescription, ResolutionPreset.high, imageFormatGroup: Platform.isIOS ? ImageFormatGroup.bgra8888 : ImageFormatGroup.yuv420);
_cameraController!.initialize().then((value) {
_cameraController!.startImageStream(_imageAnalysis);
if (mounted) {
Expand All @@ -85,8 +78,7 @@ class _MyHomePageState extends State<MyHomePage> with WidgetsBindingObserver {
return;
}
_isProcessing = true;
_classification =
await _gestureClassificationHelper.inferenceCameraFrame(cameraImage);
_classification = await _gestureClassificationHelper.inferenceCameraFrame(cameraImage);
_isProcessing = false;
if (mounted) {
setState(() {});
Expand Down Expand Up @@ -116,8 +108,7 @@ class _MyHomePageState extends State<MyHomePage> with WidgetsBindingObserver {
_cameraController?.stopImageStream();
break;
case AppLifecycleState.resumed:
if (_cameraController != null &&
!_cameraController!.value.isStreamingImages) {
if (_cameraController != null && !_cameraController!.value.isStreamingImages) {
await _cameraController!.startImageStream(_imageAnalysis);
}
break;
Expand Down Expand Up @@ -166,7 +157,7 @@ class _MyHomePageState extends State<MyHomePage> with WidgetsBindingObserver {
title: Center(
child: Image.asset('assets/images/tfl_logo.png'),
),
backgroundColor: Colors.black.withOpacity(0.5),
backgroundColor: Colors.black.withValues(alpha: 0.5),
),
body: Center(
// Center is a layout widget. It takes a single child and positions it
Expand All @@ -193,11 +184,7 @@ class _MyHomePageState extends State<MyHomePage> with WidgetsBindingObserver {
padding: const EdgeInsets.all(8),
color: Colors.white,
child: Row(
children: [
Text(e.key),
const Spacer(),
Text(e.value.toStringAsFixed(2))
],
children: [Text(e.key), const Spacer(), Text(e.value.toStringAsFixed(2))],
),
),
),
Expand Down
8 changes: 3 additions & 5 deletions example/image_classification_mobilenet/lib/main.dart
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,10 @@ class BottomNavigationBarExample extends StatefulWidget {
const BottomNavigationBarExample({super.key});

@override
State<BottomNavigationBarExample> createState() =>
_BottomNavigationBarExampleState();
State<BottomNavigationBarExample> createState() => _BottomNavigationBarExampleState();
}

class _BottomNavigationBarExampleState
extends State<BottomNavigationBarExample> {
class _BottomNavigationBarExampleState extends State<BottomNavigationBarExample> {
late CameraDescription cameraDescription;
int _selectedIndex = 0;
List<Widget>? _widgetOptions;
Expand Down Expand Up @@ -88,7 +86,7 @@ class _BottomNavigationBarExampleState
return Scaffold(
appBar: AppBar(
title: Image.asset('assets/images/tfl_logo.png'),
backgroundColor: Colors.black.withOpacity(0.5),
backgroundColor: Colors.black.withValues(alpha: 0.5),
),
body: Center(
child: _widgetOptions?.elementAt(_selectedIndex),
Expand Down
Loading
Loading