##### Copyright 2023 The TensorFlow Hub Authors.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
#@title Copyright 2023 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

<table class="tfo-notebook-buttons" align="left">
  <!-- <td>
    <a target="_blank" href="https://tensorflow.google.cn/hub/tutorials/bird_vocalization_classifier"><img src="https://tensorflow.google.cn/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td> -->
  <td>     <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/hub/tutorials/bird_vocalization_classifier.ipynb"><img src="https://tensorflow.google.cn/images/colab_logo_32px.png">在 Google Colab 中运行</a>
</td>
  <td>     <a target="_blank" href="https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/hub/tutorials/bird_vocalization_classifier.ipynb"><img src="https://tensorflow.google.cn/images/GitHub-Mark-32px.png">在 GitHub 上查看源代码</a>
</td>
  <td>     <a href="https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/hub/tutorials/bird_vocalization_classifier.ipynb"><img src="https://tensorflow.google.cn/images/download_logo_32px.png">下载笔记本</a>
</td>
  <td>     <a href="https://tfhub.dev/google/bird-vocalization-classifier/1"><img src="https://tensorflow.google.cn/images/hub_logo_32px.png">查看 TF Hub 模型</a>
</td>
</table>

# 使用 Google Bird Vocalization 模型

Google Bird Vocalization 是一个全球鸟类嵌入和分类模型。

此模型需要以 32kHz 采样的 5 秒音频片段作为输入

此模型为音频的每个输入窗口输出 logit 和嵌入向量。

在此笔记本中，您将学习如何将音频正确提供给模型以及如何使用 logit 进行推断。


In [None]:
!pip install -q "tensorflow_io==0.28.*"
!pip install -q librosa

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_io as tfio

import numpy as np
import librosa

import csv
import io

from IPython.display import Audio

从 TFHub 加载模型

In [None]:
# TF Hub handle of the bird vocalization classifier; reused later to locate
# the label file that ships with the model.
model_handle = "https://tfhub.dev/google/bird-vocalization-classifier/1"
# Load the model from TF Hub.
model = hub.load(model_handle)

我们来加载训练模型使用的标签。

标签文件是 assets 文件夹中的 label.csv。每行都是一个 ebird id。

In [None]:
# Helper that reads the class labels (ebird ids) from a label CSV file.
def class_names_from_csv(class_map_csv_text):
  """Returns the list of class names stored in a label CSV file.

  Args:
    class_map_csv_text: Path of the CSV file to read. The parameter name is
      kept for backward compatibility; the value is a file path, not CSV text.

  Returns:
    A list with the first column of every non-empty row, in file order,
    excluding the first row, which is treated as a header.
  """
  # Read from the path passed by the caller instead of relying on a global
  # `labels_path` variable being defined before the call.
  with open(class_map_csv_text, newline='') as csv_file:
    csv_reader = csv.reader(csv_file, delimiter=',')
    # Only the first column is needed, so rows are not required to have
    # exactly two columns; blank lines are skipped.
    class_names = [row[0] for row in csv_reader if row]
  return class_names[1:]

# Locate the label file inside the resolved model directory and load the class names.
labels_path = hub.resolve(model_handle) + "/assets/label.csv"
classes = class_names_from_csv(labels_path)
# Print a short summary only: dumping the complete label list would flood the
# cell output.
print(f"Loaded {len(classes)} classes. First 10: {classes[:10]}")

`frame_audio` 函数基于 [Chirp 库](https://github.com/google-research/chirp/blob/10c5faa325a3c3468fa6f18a736fc1aeb9bf8129/chirp/inference/interface.py#L128)版本，但使用 tf.signal 而不是 librosa。

`ensure_sample_rate` 是一个用于确保模型使用的任何音频都具有 32kHz 预期采样率的函数

In [None]:
def frame_audio(
    audio_array: np.ndarray,
    window_size_s: float = 5.0,
    hop_size_s: float = 5.0,
    sample_rate = 32000,
) -> np.ndarray:
  """Splits a waveform into fixed-length windows for inference.

  Args:
    audio_array: Waveform samples to split.
    window_size_s: Window length in seconds. When it is None or negative no
      framing is done.
    hop_size_s: Step between the starts of consecutive windows, in seconds.
    sample_rate: Number of samples per second of `audio_array`.

  Returns:
    `audio_array` with a new leading axis when framing is disabled; otherwise
    the result of `tf.signal.frame` called with `pad_end=True`.
  """
  framing_disabled = window_size_s is None or window_size_s < 0
  if framing_disabled:
    # Treat the whole recording as a single window.
    return audio_array[np.newaxis, :]

  # Convert the durations from seconds to sample counts.
  samples_per_window = int(window_size_s * sample_rate)
  samples_per_hop = int(hop_size_s * sample_rate)
  return tf.signal.frame(
      audio_array, samples_per_window, samples_per_hop, pad_end=True)

def ensure_sample_rate(waveform, original_sample_rate,
                       desired_sample_rate=32000):
  """Returns the waveform at the desired sample rate, resampling if needed.

  Args:
    waveform: Audio samples.
    original_sample_rate: Sample rate of `waveform`.
    desired_sample_rate: Sample rate the returned waveform should have.

  Returns:
    A `(desired_sample_rate, waveform)` tuple. The waveform is returned
    untouched when both rates are equal; otherwise it is the output of
    `tfio.audio.resample`.
  """
  rates_match = original_sample_rate == desired_sample_rate
  if rates_match:
    return desired_sample_rate, waveform
  resampled = tfio.audio.resample(
      waveform, original_sample_rate, desired_sample_rate)
  return desired_sample_rate, resampled

我们从 Wikipedia 加载一个文件。

更准确地说，是[乌鸫（Common Blackbird）](https://es.wikipedia.org/wiki/Turdus_merula)的声音。

|<p><img src="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a9/Common_Blackbird.jpg/1200px-Common_Blackbird.jpg" alt="Common Blackbird.jpg" class=""></p>|
|:--:|
|*作者：<a rel="nofollow" class="external text" href="http://photo-natur.de">Andreas Trepte</a> - <span class="int-own-work" lang="en">自有作品</span>，<a href="https://creativecommons.org/licenses/by-sa/2.5" title="Creative Commons Attribution-Share Alike 2.5">CC BY-SA 2.5</a>，<a href="https://commons.wikimedia.org/w/index.php?curid=16110223">链接</a>*|

此音频由 Oona Räisänen (Mysid) 根据公共领域许可提供。

In [None]:
# Download the sample recording; -O saves it under its remote file name
# (Turdus_merula_2.ogg), which is the name the next cell loads.
!curl -O  "https://upload.wikimedia.org/wikipedia/commons/7/7c/Turdus_merula_2.ogg"

In [None]:
turdus_merula = "Turdus_merula_2.ogg"

# sr=None keeps the file's native sampling rate. librosa's default would first
# resample to 22050 Hz, and the later resampling to 32 kHz cannot recover the
# high frequencies removed by that intermediate step.
audio, sample_rate = librosa.load(turdus_merula, sr=None)

# Bring the waveform to the sample rate expected by the model (32 kHz by default).
sample_rate, wav_data_turdus = ensure_sample_rate(audio, sample_rate)
Audio(wav_data_turdus, rate=sample_rate)

此音频有 24 秒，而模型需要 5 秒的块。

`frame_audio` 函数可以解决此问题并将音频拆分为适当的帧

In [None]:
# Split the recording into windows using frame_audio's default window and hop sizes.
fixed_tm = frame_audio(wav_data_turdus)
# Show the shape of the framed audio
# (presumably (num_frames, samples_per_frame) — confirm against the cell output).
fixed_tm.shape

我们只在第一帧应用模型：

In [None]:
# Run inference on the first frame only; the [:1] slice keeps the leading
# axis, so the model receives a batch containing a single frame.
logits, embeddings = model.infer_tf(fixed_tm[:1])

label.csv 文件包含 ebird id。乌鸫的 ebird id 为 eurbla

In [None]:
# Convert the logits of the first frame into probabilities.
probabilities = tf.nn.softmax(logits)
# np.argmax without an axis indexes the flattened array.
# NOTE(review): this equals the class index only when `logits` holds a single
# row — presumably true for the one-frame batch above; confirm.
argmax = np.argmax(probabilities)
print(f"The audio is from the class {classes[argmax]} (element:{argmax} in the label.csv file), with probability of {probabilities[0][argmax]}")

现在我们在所有帧上应用模型：

*注*：此代码也基于 [Chirp 库](https://github.com/google-research/chirp/blob/d6ff5e7cee3865940f31697bf4b70176c1072572/chirp/inference/models.py#L174)

In [None]:
# Run the model on every frame, one frame per call, and collect both outputs.
# The per-window results are gathered in lists and concatenated once at the
# end, so the embeddings of all frames are kept (not only those of the first
# frame) and the arrays are not rebuilt on every iteration.
logits_per_window = []
embeddings_per_window = []
for window in fixed_tm:
  # Add the batch axis expected by the model.
  window_logits, window_embeddings = model.infer_tf(window[np.newaxis, :])
  logits_per_window.append(window_logits)
  embeddings_per_window.append(window_embeddings)

all_logits = np.concatenate(logits_per_window, axis=0)
all_embeddings = np.concatenate(embeddings_per_window, axis=0)

all_logits.shape

In [None]:
# Report the most probable class for every frame. enumerate supplies the
# frame index, replacing the manually incremented counter.
for frame, frame_logits in enumerate(all_logits):
  probabilities = tf.nn.softmax(frame_logits)
  argmax = np.argmax(probabilities)
  print(f"For frame {frame}, the audio is from the class {classes[argmax]} (element:{argmax} in the label.csv file), with probability of {probabilities[argmax]}")