<a href="https://colab.research.google.com/github/shiehn/dawnet-remotes/blob/main/DAWNet_Remote_BeatNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

INSTALL DEPENDENCIES

In [None]:
#Check if librosa is installed, if not then install it
try:
    import librosa
except Exception:
    !pip install librosa
    import librosa

#Check if madmom is installed, if not then install it
try:
    import madmom
except Exception:
    !pip install --upgrade --no-deps --force-reinstall --quiet 'git+https://github.com/CPJKU/madmom.git'
    import madmom

#Check if pyaudio is installed, if not then install it
try:
    import pyaudio
except Exception:
    !apt-get install portaudio19-dev
    !pip install pyaudio
    import pyaudio

#Check if BeatNet is installed, if not then install it
try:
    from BeatNet.BeatNet import BeatNet
except Exception:
    !pip install BeatNet
    from BeatNet.BeatNet import BeatNet


In [None]:
# Check if ffmpeg is installed, if not then install it
ffmpeg_installed = !command -v ffmpeg
if not ffmpeg_installed:
    !apt-get install ffmpeg

#Check if dawnet-client is installed, if not then install it
try:
    import dawnet_client
except ImportError:
    !pip install dawnet-client
    import dawnet_client

import dawnet_client.core as dawnet
from dawnet_client import DAWNetFilePath
import os
import shutil
import uuid
import numpy as np

placeholder_txt = "Enter the token generated by the DAWNet plugin"
# NOTE(review): a concrete token is committed as the form default below.
# Consider resetting it to `placeholder_txt` so a shared copy of this
# notebook does not reuse someone else's plugin token.
DAWNET_TOKEN = "44d3d06b-f3dc-4f39-a532-c9bf486ae62a" #@param {type:"string"}
# Drop surrounding whitespace that easily comes along when pasting the token.
dawnet_token = (DAWNET_TOKEN or "").strip()

# Refuse to continue without a real token. In IPython/Colab, `exit()` only
# schedules an exit and the rest of the cell keeps running (it would still
# call set_token/connect_to_server). Raising stops the cell right here.
if not dawnet_token or dawnet_token == placeholder_txt:
    raise SystemExit("ERROR: The token provided is not valid.")

def calculate_average_bpm(output, iqr_multiplier=1.5):
    """
    Calculate a refined average BPM from BeatNet output by filtering outliers.

    Intervals between consecutive beats that fall outside
    ``[Q1 - k*IQR, Q3 + k*IQR]`` (with ``k = iqr_multiplier``) are discarded
    before averaging. This keeps a few missed or spurious beats from skewing
    the tempo.

    :param output: array-like of shape (n, 2) where each row is
        [time, beat_position]. Only the time column is used. Times must be in
        seconds for the result to be in beats per minute.
    :param iqr_multiplier: width of the outlier fences in IQRs (1.5 is the
        conventional Tukey value).
    :return: Average BPM rounded to the nearest whole number, as an int.
        Returns 0 in three cases: fewer than two beats (no interval to
        measure), no interval survives filtering, or the mean interval is
        not positive.
    """
    beats = np.asarray(output, dtype=float)
    # At least two timestamps are needed to form one interval. This also
    # covers an empty / 1-D result when no beats were detected, which would
    # otherwise crash in the column indexing or in np.percentile.
    if beats.ndim != 2 or beats.shape[0] < 2:
        return 0

    # Time differences between consecutive beats (beat durations).
    beat_durations = np.diff(beats[:, 0])

    # Filter outliers with interquartile-range fences.
    q1, q3 = np.percentile(beat_durations, [25, 75])
    iqr = q3 - q1
    lower_bound = q1 - iqr_multiplier * iqr
    upper_bound = q3 + iqr_multiplier * iqr
    filtered_durations = beat_durations[
        (beat_durations >= lower_bound) & (beat_durations <= upper_bound)
    ]

    if filtered_durations.size == 0:
        return 0  # Avoid division by zero if no intervals are left after filtering
    average_beat_duration = float(np.mean(filtered_durations))
    # Duplicate (or unsorted) timestamps can produce a zero or negative mean.
    # That would yield an infinite/negative tempo, and round(inf) raises.
    if average_beat_duration <= 0:
        return 0

    # Convert seconds-per-beat to beats-per-minute and round to a whole number.
    return int(round(60.0 / average_beat_duration))

def extract_bpm(input_file):
    """
    Estimate the tempo of an audio file with BeatNet.

    The file is copied into a private temporary directory and BeatNet runs on
    that copy. The directory (and the copy) is removed when analysis finishes,
    so repeated calls no longer accumulate copies in the system temp dir.

    :param input_file: path to the audio file to analyse.
    :return: Tempo as computed by ``calculate_average_bpm`` (whole BPM).
    """
    import tempfile  # local import keeps this cell self-contained

    with tempfile.TemporaryDirectory() as work_dir:
        # Copy the input file into the scratch directory and analyse the copy.
        local_copy = os.path.join(work_dir, os.path.basename(input_file))
        shutil.copy(input_file, local_copy)

        # Initialize the estimator (offline mode, DBN inference).
        estimator = BeatNet(1, mode='offline', inference_model='DBN', plot=[], thread=False)
        output = estimator.process(local_copy)

    return calculate_average_bpm(output)


async def dawnet_func(input_file: DAWNetFilePath):
    """
    DAWNet remote entry point: estimate the BPM of ``input_file`` and report it.

    On success, adds the input file plus a log line and a message with the BPM
    to the plugin output, sends it, and returns True. On any exception, adds
    an error to the output and returns False.
    """
    try:
        # NOTE(review): extract_bpm is a plain (synchronous) function, so the
        # event loop is blocked while BeatNet runs. Consider offloading it if
        # the client connection must stay responsive during long analyses.
        bpm = extract_bpm(input_file)

        # After the analysis, attach the results and send them to the plugin.
        await dawnet.output().add_file(input_file)
        await dawnet.output().add_log(f"BPM is: {bpm}")
        await dawnet.output().add_message(f"BPM is: {bpm}")
        await dawnet.output().send()

        return True
    except Exception as e:
        # Report the failure under this function's real name. The template
        # label "arbitrary_method" referred to a method that does not exist here.
        # NOTE(review): unlike the success path, no send() follows. Confirm
        # that add_error delivers the error to the plugin on its own.
        await dawnet.output().add_error(f"Error in dawnet_func: {e}")

        return False


# Configure this notebook as a DAWNet remote, register its entry point, and
# connect to the DAWNet server. The name and description now describe this
# remote instead of the generic template text.
dawnet.set_token(token=dawnet_token)
dawnet.set_name("BeatNet BPM Detector")
dawnet.set_description("Estimates the tempo (BPM) of an audio file using BeatNet beat tracking.")
dawnet.register_method(dawnet_func)

dawnet.connect_to_server()