In [1]:
from pathlib import Path
import time, uuid, hashlib
import ray

from data_classes import InstrumentDetectJob

def jobs_from_audio_dir(audio_dir: str) -> list[InstrumentDetectJob]:
    audio_dir = Path(audio_dir)
    created_at = int(time.time())
    job_id = f"job_{uuid.uuid4().hex[:12]}"

    jobs: list[InstrumentDetectJob] = []
    for p in sorted(audio_dir.glob("*")):
        if not p.is_file():
            continue

        audio_bytes = p.read_bytes()
        song_id = f"trk_{uuid.uuid4().hex[:12]}"
        song_hash = hashlib.sha256(audio_bytes).hexdigest()
        audio_ref = ray.put(audio_bytes)

        jobs.append(
            InstrumentDetectJob(
                job_id=job_id,
                created_at=created_at,
                song_id=song_id,
                song_hash=song_hash,
                audio_ref=audio_ref,
                filename=p.name,
            )
        )
    return jobs

In [2]:
jobs = jobs_from_audio_dir("audio_files")

2025-12-21 14:14:15,875	INFO worker.py:2007 -- Started a local Ray instance.


In [3]:
from abc import ABC, abstractmethod
from typing import List
from load_qwen import load_model_and_processor
import numpy as np
import soundfile as sf
import io
import tempfile
import torchaudio

class InstrumentDetector(ABC):
    """
    Abstract base class for instrument detection models.

    All instrument detector implementations must provide both `process` and `predict` methods.
    """

    @abstractmethod
    def process(self, audio_bytes_list) -> List[str]:
        pass

    @abstractmethod
    def predict(self, audio_bytes_list) -> List[str]:
        """ """
        pass


class QwenInstrumentDetector(InstrumentDetector):
    def __init__(self):
        self.model_name = "Qwen/Qwen3-Omni-30B-A3B-Thinking"
        self.model, self.processor = load_model_and_processor(self.model_name)

    def process(self, jobs: list[InstrumentDetectJob]) -> List[str]:
        waveform_audios = []
        for job in jobs:
            suffix = job.filename.split(".")[-1]
            audio_bytes = ray.get(job.audio_ref)
            waveform_audios.append(self.decode_audio_bytes_to_waveform(audio_bytes, suffix, target_sr=16000))

        return waveform_audios

    def predict(self, audio_ref_list) -> List[str]:
        pass

    def decode_audio_bytes_to_waveform(self, audio_bytes: bytes, suffix: str, target_sr: int = 16000) -> np.ndarray:
        # Write bytes to a temp file so ffmpeg backend can decode it
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=True) as f:
            f.write(audio_bytes)
            f.flush()
    
            wav, sr = torchaudio.load(f.name)  # wav: (channels, time)
    
        # Convert to mono (optional, but typical)
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
    
        # Resample to target_sr
        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=target_sr)
    
        return wav.squeeze(0).numpy().astype("float32")


In [4]:
detector = QwenInstrumentDetector()

Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section', 'mrope_interleaved', 'interleaved'}
You are attempting to use Flash Attention 2 without specifying a torch dtype. This might lead to unexpected behaviour


Loading checkpoint shards:   0%|          | 0/16 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.


In [5]:
wave_form_list = detector.process(jobs)

In [6]:
wave_form_list

[array([-1.4878102e-05, -2.9345811e-05, -2.9599703e-05, ...,
         0.0000000e+00,  0.0000000e+00,  0.0000000e+00],
       shape=(4277974,), dtype=float32),
 array([2.4055022e-05, 4.3867341e-05, 3.3767825e-05, ..., 3.7365222e-05,
        3.2300261e-05, 3.1887150e-05], shape=(4278416,), dtype=float32),
 array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
        -1.4919133e-05, -1.6012546e-05, -1.2392791e-05],
       shape=(4203416,), dtype=float32)]

[36m(pid=38956)[0m [2025-12-21 14:14:48,073 E 38956 39157] core_worker_process.cc:842: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
[2025-12-21 14:14:48,660 E 38668 38954] core_worker_process.cc:842: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14
