lite/examples/sound_classification/android/README.md (4 additions, 1 deletion)
@@ -1,5 +1,8 @@
# Sound Classifier Android sample.

This Android application demonstrates how to classify sound on-device. It uses:
* [TFLite Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/overview)
* [YAMNet](https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1), an audio event classification model.

## Requirements

@@ -36,4 +39,4 @@ Re-installing the app may require you to uninstall the previous installations.
## Resources used:

* [TensorFlow Lite](https://www.tensorflow.org/lite)
* [Teachable Machine Audio Project](https://teachablemachine.withgoogle.com/train/audio)
* [YAMNet audio classification model](https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1)
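For orientation, the Task Library call sequence that the updated MainActivity.kt (further down in this diff) builds on can be condensed into the sketch below. It uses only calls that appear in the diff; the asset name mirrors the Gradle download task's destination, and in the real sample this runs repeatedly on a background handler rather than once.

```kotlin
import android.content.Context
import org.tensorflow.lite.support.label.Category
import org.tensorflow.lite.task.audio.classifier.AudioClassifier

// Minimal sketch of one classification pass with the TFLite Task Library.
// Requires the RECORD_AUDIO permission, which MainActivity requests at startup.
fun classifyOnce(context: Context): List<Category> {
    val classifier = AudioClassifier.createFromFile(context, "yamnet.tflite")
    val audioTensor = classifier.createInputTensorAudio()

    // Capture a chunk of audio and load it into the model's input tensor.
    val record = classifier.createAudioRecord()
    record.startRecording()
    audioTensor.load(record)
    record.stop()

    // The first output head holds the YAMNet categories, as in the diff below.
    return classifier.classify(audioTensor)[0].categories
}
```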
lite/examples/sound_classification/android/app/build.gradle (7 additions, 10 deletions)
@@ -48,19 +48,16 @@ apply from: 'download_model.gradle'
dependencies {
implementation fileTree(dir: "libs", include: ["*.jar"])
implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
implementation "androidx.core:core-ktx:1.3.1"
implementation "androidx.core:core-ktx:1.3.2"
implementation "androidx.appcompat:appcompat:1.2.0"
implementation "androidx.lifecycle:lifecycle-common-java8:2.2.0"
implementation "androidx.constraintlayout:constraintlayout:2.0.1"
implementation "androidx.recyclerview:recyclerview:1.1.0"
implementation "com.google.android.material:material:1.2.1"
implementation "androidx.lifecycle:lifecycle-common-java8:2.3.1"
implementation "androidx.constraintlayout:constraintlayout:2.0.4"
implementation "androidx.recyclerview:recyclerview:1.2.0"
implementation "com.google.android.material:material:1.3.0"

implementation "org.tensorflow:tensorflow-lite:2.3.0"
implementation "org.tensorflow:tensorflow-lite-select-tf-ops:2.3.0"
implementation "org.tensorflow:tensorflow-lite-support:0.1.0"
implementation "org.tensorflow:tensorflow-lite-metadata:0.1.0"
implementation 'org.tensorflow:tensorflow-lite-task-audio:0.2.0-rc2'

testImplementation "junit:junit:4.13"
testImplementation "junit:junit:4.13.2"
androidTestImplementation "androidx.test.ext:junit:1.1.2"
androidTestImplementation "androidx.test.espresso:espresso-core:3.3.0"
}
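The dependency block above is the core of this migration: the four lower-level TFLite artifacts (tensorflow-lite, -select-tf-ops, -support, -metadata) are dropped in favour of the single Task Library audio artifact. If the module were converted to the Gradle Kotlin DSL (it is not in this PR), the equivalent declaration would look roughly like this:

```kotlin
// Hypothetical build.gradle.kts fragment; the sample itself keeps the Groovy DSL above.
dependencies {
    // One Task Library artifact replaces the four TFLite dependencies removed in this diff.
    implementation("org.tensorflow:tensorflow-lite-task-audio:0.2.0-rc2")
}
```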
lite/examples/sound_classification/android/app/download_model.gradle
@@ -1,6 +1,6 @@
task downloadSoundClassificationModelFile(type: Download) {
src 'https://storage.googleapis.com/download.tensorflow.org/models/tflite/sound_classification/snap_clap.tflite'
dest project.ext.ASSET_DIR + '/sound_classifier.tflite'
src 'https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1?lite-format=tflite'
dest project.ext.ASSET_DIR + '/yamnet.tflite'
overwrite false
}
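The download task now pulls the YAMNet model straight from TF Hub and writes it into the assets directory as yamnet.tflite; that name must stay in sync with the MODEL_FILE constant MainActivity loads. A small sanity check along these lines (hypothetical helper, not part of the PR) can catch a mismatch early:

```kotlin
import android.content.Context
import java.io.IOException

// Hypothetical check, assuming the asset name matches the Gradle task's `dest` above:
// returns false if the model was not bundled under the expected name.
fun isModelBundled(context: Context, assetName: String = "yamnet.tflite"): Boolean =
    try {
        context.assets.open(assetName).use { true }
    } catch (e: IOException) {
        false
    }
```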


This file was deleted.

MainActivity.kt
@@ -18,84 +18,150 @@ package org.tensorflow.lite.examples.soundclassifier

import android.Manifest
import android.content.pm.PackageManager
import android.media.AudioRecord
import android.os.Build
import android.os.Bundle
import android.os.Handler
import android.os.HandlerThread
import android.util.Log
import android.view.WindowManager
import androidx.annotation.RequiresApi
import androidx.appcompat.app.AppCompatActivity
import androidx.core.content.ContextCompat
import androidx.core.os.HandlerCompat
import org.tensorflow.lite.examples.soundclassifier.databinding.ActivityMainBinding
import org.tensorflow.lite.task.audio.classifier.AudioClassifier


class MainActivity : AppCompatActivity() {
private val probabilitiesAdapter by lazy { ProbabilitiesAdapter() }

private lateinit var soundClassifier: SoundClassifier
private var audioClassifier: AudioClassifier? = null
private var audioRecord: AudioRecord? = null
private var classificationInterval = 500L // how often the classification should run, in milliseconds
private lateinit var handler: Handler // background thread handler to run classification

override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)

val binding = ActivityMainBinding.inflate(layoutInflater)
setContentView(binding.root)

soundClassifier = SoundClassifier(this, SoundClassifier.Options()).also {
it.lifecycleOwner = this
}

with(binding) {
recyclerView.apply {
setHasFixedSize(true)
adapter = probabilitiesAdapter.apply {
labelList = soundClassifier.labelList
}
setHasFixedSize(false)
adapter = probabilitiesAdapter
}

// Input switch to turn on/off classification
keepScreenOn(inputSwitch.isChecked)
inputSwitch.setOnCheckedChangeListener { _, isChecked ->
soundClassifier.isPaused = !isChecked
if (isChecked) startAudioClassification() else stopAudioClassification()
keepScreenOn(isChecked)
}

overlapFactorSlider.value = soundClassifier.overlapFactor
overlapFactorSlider.addOnChangeListener { _, value, _ ->
soundClassifier.overlapFactor = value
// Slider that controls how often the classification task runs
classificationIntervalSlider.value = classificationInterval.toFloat()
classificationIntervalSlider.setLabelFormatter { value: Float ->
"${value.toInt()} ms"
}
}

soundClassifier.probabilities.observe(this) { resultMap ->
if (resultMap.isEmpty() || resultMap.size > soundClassifier.labelList.size) {
Log.w(TAG, "Invalid size of probability output! (size: ${resultMap.size})")
return@observe
classificationIntervalSlider.addOnChangeListener { _, value, _ ->
classificationInterval = value.toLong()
stopAudioClassification()
startAudioClassification()
}
probabilitiesAdapter.probabilityMap = resultMap
probabilitiesAdapter.notifyDataSetChanged()
}

// Create a handler to run classification in a background thread
val handlerThread = HandlerThread("backgroundThread")
handlerThread.start()
handler = HandlerCompat.createAsync(handlerThread.looper)

// Request microphone permission and start running classification
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
requestMicrophonePermission()
} else {
soundClassifier.start()
startAudioClassification()
}

}

private fun startAudioClassification() {
// If the audio classifier is initialized and running, do nothing.
if (audioClassifier != null) return

// Initialize the audio classifier
val classifier = AudioClassifier.createFromFile(this, MODEL_FILE)
val audioTensor = classifier.createInputTensorAudio()

// Initialize the audio recorder
val record = classifier.createAudioRecord()
record.startRecording()

// Define the classification runnable
val run = object : Runnable {
override fun run() {
val startTime = System.currentTimeMillis()

// Load the latest audio sample
audioTensor.load(record)
val output = classifier.classify(audioTensor)

// Keep only the results whose score is above the threshold, sorted by score in descending order
val filteredModelOutput = output[0].categories.filter {
it.score > MINIMUM_DISPLAY_THRESHOLD
}.sortedBy {
-it.score
}

val finishTime = System.currentTimeMillis()

Log.d(TAG, "Latency = ${finishTime - startTime}ms")

// Updating the UI
runOnUiThread {
probabilitiesAdapter.categoryList = filteredModelOutput
probabilitiesAdapter.notifyDataSetChanged()
}

// Rerun the classification after a certain interval
handler.postDelayed(this, classificationInterval)
}
}

// Start the classification process
handler.post(run)

// Save the instances we just created for use later
audioClassifier = classifier
audioRecord = record
}

private fun stopAudioClassification() {
handler.removeCallbacksAndMessages(null)
audioRecord?.stop()
audioRecord = null
audioClassifier = null
}

override fun onTopResumedActivityChanged(isTopResumedActivity: Boolean) {
// Handles "top" resumed event on multi-window environment
if (isTopResumedActivity) {
soundClassifier.start()
startAudioClassification()
} else {
soundClassifier.stop()
stopAudioClassification()
}
}

override fun onRequestPermissionsResult(
requestCode: Int,
permissions: Array<out String>,
grantResults: IntArray
requestCode: Int,
permissions: Array<out String>,
grantResults: IntArray
) {
if (requestCode == REQUEST_RECORD_AUDIO) {
if (grantResults.isNotEmpty() && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
Log.i(TAG, "Audio permission granted :)")
soundClassifier.start()
startAudioClassification()
} else {
Log.e(TAG, "Audio permission not granted :(")
}
@@ -105,11 +171,11 @@ class MainActivity : AppCompatActivity() {
@RequiresApi(Build.VERSION_CODES.M)
private fun requestMicrophonePermission() {
if (ContextCompat.checkSelfPermission(
this,
Manifest.permission.RECORD_AUDIO
) == PackageManager.PERMISSION_GRANTED
this,
Manifest.permission.RECORD_AUDIO
) == PackageManager.PERMISSION_GRANTED
) {
soundClassifier.start()
startAudioClassification()
} else {
requestPermissions(arrayOf(Manifest.permission.RECORD_AUDIO), REQUEST_RECORD_AUDIO)
}
@@ -125,5 +191,7 @@
companion object {
const val REQUEST_RECORD_AUDIO = 1337
private const val TAG = "AudioDemo"
private const val MODEL_FILE = "yamnet.tflite"
private const val MINIMUM_DISPLAY_THRESHOLD: Float = 0.3f
}
}
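One small readability note on startAudioClassification(): sorting with `.sortedBy { -it.score }` works, but Kotlin's sortedByDescending states the intent directly. A behaviour-preserving alternative, extracted into a standalone helper for illustration (not part of the PR), would be:

```kotlin
import org.tensorflow.lite.support.label.Category

// Same filtering and ordering as in startAudioClassification(), with the descending
// sort made explicit; the default threshold mirrors MINIMUM_DISPLAY_THRESHOLD.
fun filterForDisplay(categories: List<Category>, threshold: Float = 0.3f): List<Category> =
    categories
        .filter { it.score > threshold }
        .sortedByDescending { it.score }
```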
ProbabilitiesAdapter.kt
@@ -23,10 +23,10 @@ import android.view.ViewGroup
import android.view.animation.AccelerateDecelerateInterpolator
import androidx.recyclerview.widget.RecyclerView
import org.tensorflow.lite.examples.soundclassifier.databinding.ItemProbabilityBinding
import org.tensorflow.lite.support.label.Category

internal class ProbabilitiesAdapter : RecyclerView.Adapter<ProbabilitiesAdapter.ViewHolder>() {
var labelList = emptyList<String>()
var probabilityMap = mapOf<String, Float>()
var categoryList: List<Category> = emptyList()

override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): ViewHolder {
val binding =
@@ -35,22 +35,21 @@ internal class ProbabilitiesAdapter : RecyclerView.Adapter<ProbabilitiesAdapter.ViewHolder>() {
}

override fun onBindViewHolder(holder: ViewHolder, position: Int) {
val label = labelList[position]
val probability = probabilityMap[label] ?: 0f
holder.bind(position, label, probability)
val category = categoryList[position]
holder.bind(position, category.label, category.score)
}

override fun getItemCount() = labelList.size
override fun getItemCount() = categoryList.size

class ViewHolder(private val binding: ItemProbabilityBinding) :
RecyclerView.ViewHolder(binding.root) {
fun bind(position: Int, label: String, probability: Float) {
fun bind(position: Int, label: String, score: Float) {
with(binding) {
labelTextView.text = label
progressBar.progressBackgroundTintList = progressColorPairList[position % 3].first
progressBar.progressTintList = progressColorPairList[position % 3].second

val newValue = (probability * 100).toInt()
val newValue = (score * 100).toInt()
// If you don't want to animate, you can just write `progressBar.progress = newValue`.
val animation =
ObjectAnimator.ofInt(progressBar, "progress", progressBar.progress, newValue)
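Since the adapter now renders Task Library Category objects directly, it can be exercised without a model or a microphone. A hypothetical preview helper (not part of the PR, and assuming Category's public (label, score) constructor from org.tensorflow.lite.support.label) could look like this:

```kotlin
import org.tensorflow.lite.support.label.Category

// Hypothetical helper: feed the adapter fixed categories so the animated progress-bar
// rows can be previewed without recording audio. Assumes Category(label, score) is public.
fun showPreviewCategories(adapter: ProbabilitiesAdapter) {
    adapter.categoryList = listOf(
        Category("Speech", 0.92f),
        Category("Music", 0.41f),
        Category("Silence", 0.35f)
    )
    adapter.notifyDataSetChanged()
}
```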