In [None]:
# Use %pip (not !pip) so packages land in the kernel's own environment.
# Versions pinned to those resolved in the recorded run, for reproducibility.
# NOTE(review): torchcodec presumably serves as torchaudio's decoding backend here;
# its version must stay compatible with the preinstalled torch/torchaudio — confirm on upgrade.
%pip install -q onnxruntime==1.23.2 torchcodec==0.9.1 streamlit==1.52.2 pyngrok==7.5.0

Collecting onnxruntime
  Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting torchcodec
  Downloading torchcodec-0.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (11 kB)
Collecting streamlit
  Downloading streamlit-1.52.2-py3-none-any.whl.metadata (9.8 kB)
Collecting pyngrok
  Downloading pyngrok-7.5.0-py3-none-any.whl.metadata (8.1 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.9.1-py2.py3-none-any.whl.metadata (4.1 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (17.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m83.0 MB/s[0m eta [36m0:00:00[0m
[?25hDow

In [None]:
import onnxruntime as ort
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB
import torchaudio
import torch
import numpy as np

In [None]:
# --- Audio front-end settings --------------------------------------------
# Presumably these mirror the training Dataset's settings (the wrapper below
# says it reuses the Dataset's transforms) — keep them in sync with training.
SAMPLE_RATE = 16000  # Hz; input audio is resampled to this rate
N_FFT = 1024         # FFT size in samples for the mel spectrogram
HOP_LENGTH = 256     # hop between successive frames, in samples
N_MELS = 128         # number of mel bands (spectrogram height)
SPEC_LEN = 128       # frames kept after padding/cropping (spectrogram width)

# --- Model output labels ---------------------------------------------------
# Position i in this list is the label for output index i of the model
# (predict() looks up the argmax index here).
INSTR_FAMILY_MAP = dict(enumerate([
    'Accordion', 'Acoustic_Guitar', 'Banjo', 'Bass_Guitar', 'Clarinet',
    'Cymbals', 'Dobro', 'Drum_set', 'Electro_Guitar', 'Floor_Tom',
    'Harmonica', 'Harmonium', 'Hi_Hats', 'Horn', 'Keyboard',
    'Mandolin', 'Organ', 'Piano', 'Saxophone', 'Shakers',
    'Tambourine', 'Trombone', 'Trumpet', 'Ukulele', 'Violin',
    'cowbell', 'flute', 'vibraphone',
]))

class NsynthOnnxWrapper:
    """Classify audio clips with the exported instrument-family CNN (ONNX).

    Audio is converted to a per-clip standardized log-mel spectrogram of
    shape (1, 1, N_MELS, SPEC_LEN), using the same transforms and length
    handling as the training Dataset, then run through an ONNX Runtime session.

    Args:
        onnx_model_path: Path to the ``.onnx`` model file.
        providers: Optional list of ONNX Runtime execution providers.
            Defaults to ``["CPUExecutionProvider"]``.
    """

    def __init__(self, onnx_model_path, providers=None):
        if providers is None:
            providers = ["CPUExecutionProvider"]
        self.session = ort.InferenceSession(onnx_model_path, providers=providers)
        # Only the first graph input/output are used; assumes a single-input,
        # single-output export — TODO confirm if the model is re-exported.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

        # Reuse same transforms as Dataset
        self.mel = MelSpectrogram(
            sample_rate=SAMPLE_RATE,
            n_fft=N_FFT,
            hop_length=HOP_LENGTH,
            n_mels=N_MELS,
            center=True,
            power=2.0,
        )
        self.to_db = AmplitudeToDB(stype="power")

    def preprocess_wav(self, wav_path):
        """Load an audio file and convert it to the model's input array.

        Args:
            wav_path: Anything ``torchaudio.load`` accepts (e.g. a file path).

        Returns:
            float32 ndarray of shape ``(1, 1, N_MELS, SPEC_LEN)``.
        """
        waveform, sr = torchaudio.load(wav_path)
        return self.preprocess_waveform(waveform, sr)

    def preprocess_waveform(self, waveform, sr):
        """Convert an already-decoded waveform to the model's input array.

        Args:
            waveform: Tensor of shape ``(channels, samples)``, as returned by
                ``torchaudio.load``.
            sr: Sample rate of ``waveform`` in Hz.

        Returns:
            float32 ndarray of shape ``(1, 1, N_MELS, SPEC_LEN)``.
        """
        if sr != SAMPLE_RATE:
            waveform = torchaudio.functional.resample(
                waveform, orig_freq=sr, new_freq=SAMPLE_RATE
            )
        # Downmix multi-channel audio to mono.
        if waveform.size(0) > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        # center=True reflect-pads N_FFT // 2 samples on each side, which fails
        # for clips not longer than that (including empty audio). Zero-pad only
        # those clips to the minimum length; longer clips are untouched.
        waveform = self._pad_to_min_length(waveform, N_FFT // 2 + 1)

        spec = self.mel(waveform)
        spec = self.to_db(spec).squeeze(0)  # (N_MELS, frames)
        return self._spec_to_model_input(spec, SPEC_LEN)

    @staticmethod
    def _pad_to_min_length(waveform, min_samples):
        """Right-pad the last (time) axis with zeros to at least ``min_samples``."""
        shortfall = min_samples - waveform.size(-1)
        if shortfall > 0:
            waveform = torch.nn.functional.pad(waveform, (0, shortfall))
        return waveform

    @staticmethod
    def _spec_to_model_input(spec, spec_len):
        """Fix the time length, standardize, and add batch/channel axes.

        Args:
            spec: 2-D tensor ``(n_mels, frames)``.
            spec_len: Target number of frames; shorter specs are right-padded
                with zeros, longer ones are cropped.

        Returns:
            float32 ndarray of shape ``(1, 1, n_mels, spec_len)``.
        """
        # same length handling as Dataset
        n_frames = spec.size(1)
        if n_frames < spec_len:
            spec = torch.nn.functional.pad(spec, (0, spec_len - n_frames))
        elif n_frames > spec_len:
            spec = spec[:, :spec_len]

        # Per-clip standardization; the epsilon avoids division by zero for a
        # constant spectrogram (e.g. silence).
        mean = spec.mean()
        std = spec.std() + 1e-6
        spec = (spec - mean) / std
        spec = spec.unsqueeze(0).unsqueeze(0)  # (1, 1, n_mels, spec_len)
        return spec.numpy().astype("float32")

    def predict(self, wav_path):
        """Classify one audio file.

        Args:
            wav_path: Anything ``torchaudio.load`` accepts (e.g. a file path).

        Returns:
            Tuple ``(pred_idx, pred_name, probs)``: the argmax class index, its
            label from ``INSTR_FAMILY_MAP``, and the 1-D softmax probabilities
            over the model's output logits for this clip.
        """
        x = self.preprocess_wav(wav_path)
        logits = self.session.run([self.output_name], {self.input_name: x})[0]
        probs = torch.softmax(torch.from_numpy(logits), dim=1).numpy()[0]
        pred_idx = int(np.argmax(probs))
        pred_name = INSTR_FAMILY_MAP[pred_idx]
        return pred_idx, pred_name, probs

# Example usage (outside Streamlit): load the exported model and classify one clip.
wrapper = NsynthOnnxWrapper(onnx_model_path="/content/nsynth_instrument_family_cnn.onnx")
idx, name, probs = wrapper.predict(wav_path="/content/100.wav")
print(idx, name)  # predicted class index and its instrument-family label

2 Banjo


In [None]:
from pyngrok import ngrok
import threading, time, subprocess
import getpass
import os
import socket

STREAMLIT_PORT = 8501

# Never hard-code the ngrok auth token: it is saved inside the .ipynb and leaks
# to anyone who can see the notebook. Read it from the NGROK_AUTHTOKEN
# environment variable, or prompt for it interactively.
# NOTE(review): the token that was previously hard-coded in this cell is exposed
# and should be revoked/rotated in the ngrok dashboard.
ngrok.set_auth_token(
    os.environ.get("NGROK_AUTHTOKEN") or getpass.getpass("ngrok auth token: ")
)


def run_streamlit():
    """Launch the Streamlit app in the background and return its process handle.

    ``subprocess.Popen`` does not block, so no helper thread is needed.
    """
    return subprocess.Popen(
        ["streamlit", "run", "/content/app.py",
         "--server.port", str(STREAMLIT_PORT),
         "--server.headless", "true"]
    )


def _wait_for_port(port, proc=None, host="127.0.0.1", timeout=60.0):
    """Poll until something accepts TCP connections on ``host:port``.

    Returns True once the port is reachable, or False on timeout or if
    ``proc`` (when given) exits before the port opens.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if proc is not None and proc.poll() is not None:
            return False
        try:
            with socket.create_connection((host, port), timeout=1.0):
                return True
        except OSError:
            time.sleep(0.5)
    return False


streamlit_proc = run_streamlit()
# Wait for the server to actually come up instead of sleeping a fixed 5 s,
# which races a slow cold start.
if not _wait_for_port(STREAMLIT_PORT, proc=streamlit_proc):
    raise RuntimeError(
        f"Streamlit did not start listening on port {STREAMLIT_PORT} "
        f"(process exit code: {streamlit_proc.poll()})"
    )

public_url = ngrok.connect(STREAMLIT_PORT)
print("Public URL:", public_url.public_url)


Public URL: NgrokTunnel: "https://unobjectional-overreactive-noma.ngrok-free.dev" -> "http://localhost:8501"
