# Testing the Trained Wav2Vec2 Model

We will test the model in two ways:
- live audio from your microphone
- an uploaded .wav file

## Import Libraries

In [1]:
# Loading the model
from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification, Wav2Vec2Config
import torch

# For recording using your microphone
import sounddevice as sd

# For managing uploaded audio
import torchaudio

# Display classification output
from tabulate import tabulate
import time
from IPython.display import clear_output

## Setup Model and Processor

In [2]:
# Class names
class_names = ["Cello", "Piano", "Violin"]

# Path to your model
model_path = "model/2024-03-31_23-47-51_0.9984.bin"

# Load model and processor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base", num_labels=len(class_names))
model = Wav2Vec2ForSequenceClassification(config)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")



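## Preprocess Input Audio

The function below prepares the audio before it is passed to the model.

It expects mono (1-channel) audio, flattens it to a 1-D array, zero-pads clips shorter than 16,000 samples (1 second at 16 kHz), and converts the result into model input values with the processor. It does not resample: the audio must already be sampled at 16 kHz, the rate `facebook/wav2vec2-base` expects.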

In [3]:
import numpy as np

def preprocess_audio(recording, processor, sampling_rate=16000):
  # Accept a NumPy array (from sounddevice) or a torch tensor (from torchaudio)
  if isinstance(recording, torch.Tensor):
    recording = recording.detach().cpu().numpy()

  # Flatten a (samples, 1) or (1, samples) mono recording to 1-D
  recording = np.squeeze(recording).astype(np.float32)

  # Zero-pad recordings shorter than 1 second (16000 samples at 16 kHz)
  min_length = 16000
  if len(recording) < min_length:
    recording = np.pad(recording, (0, min_length - len(recording)), mode="constant")

  # The processor expects 16 kHz audio and returns a (1, num_samples) tensor
  input_values = processor(recording, sampling_rate=sampling_rate, return_tensors="pt", padding=True).input_values
  return input_values
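The reshaping and padding can be checked in isolation. Below is a minimal NumPy-only sketch; `to_mono_padded` is a hypothetical helper, not part of this notebook. Unlike the `squeeze()` call above, it also averages genuine multi-channel audio down to mono. When no `channel_axis` is given it assumes the shorter axis holds the channels, since sounddevice returns `(samples, channels)` while torchaudio returns `(channels, samples)`.

```python
import numpy as np

MIN_LENGTH = 16000  # 1 second at 16 kHz, the same minimum used in preprocess_audio

def to_mono_padded(audio, channel_axis=None, min_length=MIN_LENGTH):
    """Hypothetical helper: return a 1-D float32 mono signal of at least min_length samples."""
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim == 2:
        if channel_axis is None:
            # Assumption: the shorter axis is the channel axis
            channel_axis = int(np.argmin(audio.shape))
        audio = audio.mean(axis=channel_axis)  # Downmix all channels to mono
    if audio.shape[0] < min_length:
        # Zero-pad at the end so the model always sees at least 1 second
        audio = np.pad(audio, (0, min_length - audio.shape[0]))
    return audio

# sounddevice layout (samples, channels) and torchaudio layout (channels, samples)
print(to_mono_padded(np.ones((16000, 1))).shape)  # (16000,)
print(to_mono_padded(np.ones((2, 8000))).shape)   # (16000,)
```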

## Classify/Predict the Audio Input

This function uses the model to predict the instrument present in the input audio.

It returns the inference time along with the classification probabilities of the input.

In [4]:
def predict(model, processor, input_values, device):
    # `processor` is unused here; it is kept so existing calls keep working
    input_values = input_values.to(device)

    start_time = time.perf_counter()  # Start timing the forward pass
    with torch.no_grad():
        logits = model(input_values).logits
    if input_values.is_cuda:
        torch.cuda.synchronize()  # Wait for queued GPU work so the timing is accurate
    inference_time = time.perf_counter() - start_time

    # Convert the logits to class probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    return probabilities, inference_time
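The softmax step turns the model's raw logits into probabilities that sum to 1, and the predicted class is the index with the highest probability. The sketch below reproduces that computation in NumPy with made-up logits, for illustration only; the `softmax` helper here is hypothetical and stands in for `torch.nn.functional.softmax`.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtracting the row maximum avoids overflow in exp without changing the result
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

example_class_names = ["Cello", "Piano", "Violin"]
example_logits = np.array([[0.5, 2.0, 0.1]])  # Made-up logits, shape (batch, num_classes)
example_probs = softmax(example_logits)
example_predicted = example_class_names[int(example_probs.argmax(axis=-1)[0])]
print(example_predicted)  # Piano
```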

## Predicting Live Audio

This section uses the model to identify the instrument being played into your microphone.

The code below repeatedly records 1-second clips from your microphone and prints a prediction after each clip.

#### `record_audio` function

A helper function that records mono audio for the given duration (in seconds) at the given sampling rate. It returns a NumPy array of shape `(samples, 1)`.

In [5]:
def record_audio(duration, sampling_rate=16000):
  print("Recording...")
  recording = sd.rec(int(duration * sampling_rate), samplerate=sampling_rate, channels=1, dtype='float32')
  sd.wait()  # Wait until recording is finished
  return recording

#### `live_prediction` function

A helper function that repeatedly records a clip, classifies it as one of the three instruments, and displays the probabilities together with timing information.

In [6]:
def live_prediction():
    duration = 1           # Length of each clip in seconds
    sampling_rate = 16000  # Wav2Vec2 expects 16 kHz audio

    try:
        while True:
            # Record a clip
            start_time = time.perf_counter()
            recording = record_audio(duration, sampling_rate)
            record_time = time.perf_counter() - start_time

            # Preprocess the clip
            start_time = time.perf_counter()
            input_values = preprocess_audio(recording, processor, sampling_rate)
            preprocess_time = time.perf_counter() - start_time

            # Predict and get probabilities
            probabilities, inference_time = predict(model, processor, input_values, device)
            predicted_prob, predicted_index = torch.max(probabilities, dim=1)
            predicted_class = class_names[predicted_index.item()]

            # Convert probabilities to percentages and prepare table data
            percentages = [prob.item() * 100 for prob in probabilities[0]]
            table_data = [[class_name, f"{percentage:.2f}%"] for class_name, percentage in zip(class_names, percentages)]

            # Replace the previous output with the new table and timings
            clear_output(wait=True)
            print(tabulate(table_data, headers=['Class', 'Probability'], tablefmt='grid'))
            print(f"Predicted class: {predicted_class}\n")
            print(f"Record Time: {record_time:.4f} seconds")
            print(f"Preprocess Time: {preprocess_time:.4f} seconds")
            print(f"Inference Time: {inference_time:.4f} seconds")
            print("\nPress the Interrupt button to stop")
    except KeyboardInterrupt:
        print("Live prediction session ended.")

## Run the Live Prediction!

Press Jupyter's Interrupt button to stop predicting.

In [7]:
live_prediction()

+---------+---------------+
| Class   | Probability   |
+=========+===============+
| Cello   | 18.32%        |
+---------+---------------+
| Piano   | 69.84%        |
+---------+---------------+
| Violin  | 11.84%        |
+---------+---------------+
Predicted class: Piano

Record Time: 1.1055 seconds
Preprocess Time: 0.0000 seconds
Inference Time: 0.1465 seconds

Press the Interrupt button to stop
Recording...
Live prediction session ended.


## Predicting a .wav Audio File

This section predicts the instrument played in a .wav file that you provide.

The code below splits the file into 1-second segments, classifies each segment, and averages the per-segment probabilities into a single prediction for the whole file.

#### `format_time` function

A helper function that converts a time in seconds to minutes and seconds format and rounds seconds to the nearest integer.

In [None]:
def format_time(seconds):
  # Round the total first so 59.6 s becomes "1 minutes and 0 seconds", not "0 minutes and 60 seconds"
  total_seconds = round(seconds)
  minutes, remaining_seconds = divmod(total_seconds, 60)
  return f"{minutes} minutes and {remaining_seconds} seconds"
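Rounding order matters at minute boundaries: rounding only the leftover seconds can produce "60 seconds", whereas rounding the total first carries the extra second into the minutes. Both functions below are illustrative variants with hypothetical names.

```python
def format_time_round_remainder(seconds):
    # Rounds only the leftover seconds after taking whole minutes
    minutes = int(seconds // 60)
    return f"{minutes} minutes and {round(seconds % 60)} seconds"

def format_time_round_total(seconds):
    # Rounds the whole duration first, then splits it into minutes and seconds
    minutes, remaining = divmod(round(seconds), 60)
    return f"{minutes} minutes and {remaining} seconds"

print(format_time_round_remainder(119.7))  # 1 minutes and 60 seconds
print(format_time_round_total(119.7))      # 2 minutes and 0 seconds
```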

#### `analyse_wav_file` function

A helper function that splits the uploaded audio into segments, classifies each segment, and averages the probabilities into a cumulative prediction for the file.

The first and last 5% of the audio are skipped, because the instrument is usually not playing during these periods.

In [None]:
def analyse_wav_file(wav_path, model, processor, device, segment_duration=1, sampling_rate=16000):
  waveform, sr = torchaudio.load(wav_path)
  if sr != sampling_rate:
    resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sampling_rate)
    waveform = resampler(waveform)

  if waveform.size(0) > 1:
    waveform = torch.mean(waveform, dim=0, keepdim=True)

  # Determine the start and end indices to exclude the first and last 5%
  total_samples = waveform.size(1)
  start_index = int(total_samples * 0.05)  # Start at 5% of the total samples
  end_index = int(total_samples * 0.95)    # End at 95% of the total samples

  # Adjusted number of samples to process
  adjusted_total_samples = end_index - start_index
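  num_samples = int(sampling_rate * segment_duration)
  num_segments = adjusted_total_samples // num_samples
  if num_segments == 0:
    print("The audio is too short to analyse (no full segment remains after trimming).")
    return

  cumulative_probabilities = None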

  print("Processing and predicting segments...")
  for i in range(num_segments):
    start_sample = start_index + i * num_samples
    end_sample = start_sample + num_samples
    segment = waveform[:, start_sample:end_sample]

    input_values = preprocess_audio(segment, processor, sampling_rate)
    probabilities, _ = predict(model, processor, input_values, device)

    if cumulative_probabilities is None:
      cumulative_probabilities = probabilities
    else:
      cumulative_probabilities += probabilities

  average_probabilities = cumulative_probabilities / num_segments
  average_probabilities = average_probabilities.squeeze()  # Remove unnecessary dimensions

  # Identify the final predicted class
  max_prob, max_index = torch.max(average_probabilities, dim=0)
  final_predicted_class = class_names[max_index.item()]
  
  # Time information
  start_time_seconds = start_index / sampling_rate
  end_time_seconds = end_index / sampling_rate
  start_time_formatted = format_time(start_time_seconds)
  end_time_formatted = format_time(end_time_seconds)

  # Prepare data for tabulation
  data = [(class_name, f"{prob * 100:.2f}%") for class_name, prob in zip(class_names, average_probabilities)]
  print(tabulate(data, headers=["Class", "Probability"], tablefmt="grid"))
  print(f"Final predicted class: {final_predicted_class} ({max_prob.item() * 100:.2f}%)")
  print(f"Analysis from {start_time_formatted} to {end_time_formatted}.")
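The trimming and segmentation arithmetic can be traced without loading any audio. This sketch uses plain Python lists instead of tensors, and integer percentages instead of `0.05`/`0.95` to sidestep float rounding; `segment_bounds` and `average_probabilities` are hypothetical helpers that mirror the logic above. For a 60-second clip at 16 kHz, 5% (48,000 samples) is trimmed from each end, leaving 864,000 samples, or 54 full 1-second segments.

```python
def segment_bounds(total_samples, sampling_rate=16000, segment_duration=1, trim_percent=5):
    """Hypothetical helper: (start, end) sample indices of each full segment after trimming."""
    start_index = total_samples * trim_percent // 100
    end_index = total_samples * (100 - trim_percent) // 100
    num_samples = sampling_rate * segment_duration
    num_segments = (end_index - start_index) // num_samples  # A trailing partial segment is dropped
    return [(start_index + i * num_samples, start_index + (i + 1) * num_samples)
            for i in range(num_segments)]

def average_probabilities(per_segment):
    """Hypothetical helper: average per-segment probability lists class by class."""
    if not per_segment:
        raise ValueError("no segments to average")
    n = len(per_segment)
    return [sum(column) / n for column in zip(*per_segment)]

bounds = segment_bounds(60 * 16000)  # A 60-second clip at 16 kHz
print(len(bounds), bounds[0], bounds[-1])  # 54 (48000, 64000) (896000, 912000)
print(average_probabilities([[0.6, 0.3, 0.1], [0.4, 0.5, 0.1]]))
```

As in `analyse_wav_file`, samples left over after the last full segment are ignored.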

#### Downloading an Audio File

Since we do not have any audio files yet, let's download some from YouTube.

##### Import Libraries

In [None]:
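from pytube import YouTube
try:
  from moviepy.editor import AudioFileClip  # MoviePy 1.x
except ImportError:
  from moviepy import AudioFileClip  # MoviePy 2.x removed the moviepy.editor module
import os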

##### Helper Functions

In [None]:
# Downloads a YouTube video and returns the download path.
def download_video(youtube_url, save_path):
  try:
    yt = YouTube(youtube_url)
    video_stream = yt.streams.filter(only_audio=True).first()
    downloaded_file = video_stream.download(output_path=save_path)
    return downloaded_file
  except Exception as e:
    print(f"Error downloading video: {e}")
    return None

In [None]:
# Converts the downloaded video to WAV format.
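def convert_to_wav(video_path, output_path):
  # Open the clip before the try block so `finally` never sees an undefined variable
  audio_clip = AudioFileClip(video_path)
  try:
    # pcm_s16le is uncompressed 16-bit PCM, the standard WAV encoding
    audio_clip.write_audiofile(output_path, codec='pcm_s16le')
    print(f"Conversion successful. File saved to: {output_path}")
  finally:
    audio_clip.close()  # Release the file even if writing fails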

In [None]:
# Deletes the video file specified by video_path.
def delete_video_file(video_path):
  try:
    os.remove(video_path)
    print(f"Deleted video file: {video_path}")
  except Exception as e:
    print(f"Error deleting video file: {e}")

##### Download the Audio of YouTube Videos
Three videos, one per instrument, have been pre-selected to demonstrate the model's results.

In [None]:
output_folder = "audio"
os.makedirs(output_folder, exist_ok=True)  # Create the folder if it does not exist

youtube_urls = [
  "https://www.youtube.com/watch?v=hykHZOXV-S0",  # cello
  "https://www.youtube.com/watch?v=q7mgrcULIbs",  # piano
  "https://www.youtube.com/watch?v=xZZFU0KVpKE",  # violin
]

for url in youtube_urls:
  video_path = download_video(url, output_folder)
  if video_path:
    # Swap only the final file extension for .wav
    output_path = os.path.splitext(video_path)[0] + '.wav'
    convert_to_wav(video_path, output_path)
    delete_video_file(video_path)  # Delete the downloaded file after conversion

MoviePy - Writing audio in d:\_NTU\Y2S2\SC1015\Project\audio\Down by the Salley Gardens - Irish Cello.wav
MoviePy - Done.
Conversion successful. File saved to: d:\_NTU\Y2S2\SC1015\Project\audio\Down by the Salley Gardens - Irish Cello.wav
Deleted video file: d:\_NTU\Y2S2\SC1015\Project\audio\Down by the Salley Gardens - Irish Cello.mp4
MoviePy - Writing audio in d:\_NTU\Y2S2\SC1015\Project\audio\David Lanz performs Cristoforis Dream live solo piano concert at Piano Haven.wav
MoviePy - Done.
Conversion successful. File saved to: d:\_NTU\Y2S2\SC1015\Project\audio\David Lanz performs Cristoforis Dream live solo piano concert at Piano Haven.wav
Deleted video file: d:\_NTU\Y2S2\SC1015\Project\audio\David Lanz performs Cristoforis Dream live solo piano concert at Piano Haven.mp4
MoviePy - Writing audio in d:\_NTU\Y2S2\SC1015\Project\audio\Homelanders Theme I can do anything finale violin performance.wav
MoviePy - Done.
Conversion successful. File saved to: d:\_NTU\Y2S2\SC1015\Project\audio\Homelanders Theme I can do anything finale violin performance.wav
Deleted video file: d:\_NTU\Y2S2\SC1015\Project\audio\Homelanders Theme I can do anything finale violin performance.mp4
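A note on deriving the `.wav` file name: the audio stream pytube selects with `only_audio=True` may not be an `.mp4` file (YouTube also serves `.webm` audio), and `str.replace('.mp4', '.wav')` substitutes every occurrence of `.mp4` in the path, not just the extension. `os.path.splitext` swaps only the final extension. `wav_path_for` is a hypothetical helper used for illustration.

```python
import os

def wav_path_for(downloaded_path):
    # Replace only the final extension, whatever it is (.mp4, .webm, ...)
    root, _ext = os.path.splitext(downloaded_path)
    return root + ".wav"

print(wav_path_for("audio/song.webm"))                  # audio/song.wav
print("audio/my.mp4 mix.mp4".replace(".mp4", ".wav"))  # audio/my.wav mix.wav
print(wav_path_for("audio/my.mp4 mix.mp4"))             # audio/my.mp4 mix.wav
```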




## Run the Audio File Prediction!

##### Cello

In [None]:
# Select the audio file
wav_path = 'audio/Down by the Salley Gardens - Irish Cello.wav'

# Analyse the audio file
analyse_wav_file(wav_path, model, processor, device)

Processing and predicting segments...
+---------+---------------+
| Class   | Probability   |
+=========+===============+
| Cello   | 53.25%        |
+---------+---------------+
| Piano   | 31.26%        |
+---------+---------------+
| Violin  | 15.49%        |
+---------+---------------+
Final predicted class: Cello (53.25%)
Analysis from 0 minutes and 11 seconds to 3 minutes and 33 seconds.


##### Piano

In [None]:
# Select the audio file
wav_path = 'audio/David Lanz performs Cristoforis Dream live solo piano concert at Piano Haven.wav'

# Analyse the audio file
analyse_wav_file(wav_path, model, processor, device)

Processing and predicting segments...
+---------+---------------+
| Class   | Probability   |
+=========+===============+
| Cello   | 11.85%        |
+---------+---------------+
| Piano   | 78.62%        |
+---------+---------------+
| Violin  | 9.53%         |
+---------+---------------+
Final predicted class: Piano (78.62%)
Analysis from 0 minutes and 19 seconds to 6 minutes and 10 seconds.


##### Violin

In [None]:
# Select the audio file
wav_path = 'audio/Homelanders Theme I can do anything finale violin performance.wav'

# Analyse the audio file
analyse_wav_file(wav_path, model, processor, device)

Processing and predicting segments...
+---------+---------------+
| Class   | Probability   |
+=========+===============+
| Cello   | 9.83%         |
+---------+---------------+
| Piano   | 9.06%         |
+---------+---------------+
| Violin  | 81.11%        |
+---------+---------------+
Final predicted class: Violin (81.11%)
Analysis from 0 minutes and 3 seconds to 1 minutes and 1 seconds.
