In [1]:
import pickle, os, logging
import numpy as np

import torch as th
import torchaudio

In [2]:
ROOT_DIR = os.getcwd()
print("Root Dir: ", ROOT_DIR)

Root Dir:  /scratch/IOSZ/waveformer/multimod-sound-separation


# Embedding model

In [3]:
# change the current path to the imagebind root
os.chdir(os.path.join(ROOT_DIR, "imagebind"))
print("Dir: ", os.getcwd())

Dir:  /scratch/IOSZ/waveformer/multimod-sound-separation/imagebind


In [4]:
import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
device = "cuda:0" if th.cuda.is_available() else "cpu"

# Instantiate model
embedding_model = imagebind_model.imagebind_huge(pretrained=True)
embedding_model.eval()
embedding_model.to(device)

print("Embedding model loaded!")

Embedding model loaded!


# Waveformer model

In [6]:
waveformer_path = os.path.join(ROOT_DIR, "multimod-waveformer", "experiments", "cv_dcc_tf_ckpt_E256_10_D128_3")
model_ckp =  "149.pt"

waveformer_path

'/scratch/IOSZ/waveformer/multimod-sound-separation/multimod-waveformer/experiments/cv_dcc_tf_ckpt_E256_10_D128_3'

In [7]:
# change the current path to the waveformer root
os.chdir(os.path.join(ROOT_DIR, "multimod-waveformer"))
print("Dir: ", os.getcwd())

Dir:  /scratch/IOSZ/waveformer/multimod-sound-separation/multimod-waveformer


In [8]:
from src.helpers import utils
from src.training.dcc_tf import Net as Waveformer

In [9]:
# load config file
params = utils.Params(os.path.join(waveformer_path, "config.json"))
audio_model = Waveformer(**params.model_params)

# Instantiate waveformer and load pretrained weights
audio_model.load_state_dict(
    th.load(os.path.join(waveformer_path, model_ckp), 
            map_location=device)["model_state_dict"]
    )

audio_model.to(device).eval()

print("Audio model loaded!")

LABEL LEN:  2
ENCODING DIM:  256
Audio model loaded!


# Process audio

## Define file paths

In [19]:
# choose input, output and target
INPUT_PATH = "common_voice_en_36703188-extract.wav"
OUTPUT_PATH = "common_voice_en_36703188-extract-extract.wav"
TARGET = "Female voice"

## Get embeddings

In [20]:
# change the current path to the imagebind root
os.chdir(os.path.join(ROOT_DIR, "imagebind"))
print("Dir: ", os.getcwd())

Dir:  /scratch/IOSZ/waveformer/multimod-sound-separation/imagebind


In [21]:
text_list=[TARGET]
audio_paths=[os.path.join("..", INPUT_PATH)]

In [22]:
# Load data
inputs = {
    ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    # ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
}

with th.no_grad():
    embeddings = embedding_model(inputs)

RuntimeError: Failed to open the input "../common_voice_en_36703188-extract.wav" (No such file or directory).

## Run through waveformer

In [14]:
# change the current path to the root
os.chdir(ROOT_DIR)
print("Dir: ", os.getcwd())

Dir:  /scratch/IOSZ/waveformer/multimod-sound-separation


In [15]:
# Read input audio
mixture, fs = torchaudio.load(INPUT_PATH)

if fs != 44100:
    mixture = torchaudio.functional.resample(mixture, orig_freq=fs, new_freq=44100)
mixture = mixture.unsqueeze(0).to(device)
print("Loaded input audio from %s" % INPUT_PATH)

# get the query from imagebind
query = embeddings[ModalityType.TEXT]
print("Query shape: ", query.shape)

# run inference
with th.inference_mode():
    
    # the fg_audio_paths arg here is not used, assuming the query has the right imagebind embeddings (1024)
    output = audio_model(mixture.to(device), query.to(device), fg_audio_paths=query.to(device), mode="test").squeeze(0).cpu()
    
    if fs != 44100:
        output = torchaudio.functional.resample(output, orig_freq=44100, new_freq=fs)
        print("Resample input to 44.1 kHz!")
    print("Inference done. Saving output audio to %s" % OUTPUT_PATH)

    assert not os.path.exists(OUTPUT_PATH), "Output file already exists."
    torchaudio.save(OUTPUT_PATH, output, fs)

Loaded input audio from common_voice_en_36613135.wav
Query shape:  torch.Size([1, 1024])
Resample input to 44.1 kHz!
Inference done. Saving output audio to common_voice_en_36613135-extract.wav


## Inspect output file (optional)

In [28]:
import librosa

# Load audio file
y, sr = librosa.load('common_voice_en_36613135-extract.wav')

print(sr)

# # Normalize audio data
# y = librosa.util.normalize(y)

# # Save normalized audio
# librosa.output.write_wav('normalized_audio.wav', y, sr)

22050
