In [10]:
import pickle, os, logging
import numpy as np

import torch as th
import torchaudio

In [5]:
# Record the notebook's starting directory exactly once. Later cells os.chdir()
# into sub-directories, so re-evaluating os.getcwd() on a re-run of this cell
# would silently repoint ROOT_DIR at whichever sub-directory is current.
if "ROOT_DIR" not in globals():
    ROOT_DIR = os.getcwd()
print("Root Dir: ", ROOT_DIR)

Root Dir:  /scratch/IOSZ/waveformer/multimod-sound-separation


In [9]:
# Choose input, output and target.
# Paths are anchored at ROOT_DIR because later cells os.chdir() into the
# imagebind and multimod-waveformer directories; a bare relative path would be
# resolved against whichever directory happens to be current when it is
# opened. (os.path.join returns an absolute second argument unchanged, so an
# absolute path may also be given here.)
INPUT_PATH = os.path.join(ROOT_DIR, "Sample.wav")
OUTPUT_PATH = os.path.join(ROOT_DIR, "sample-extract.wav")
# Text prompt passed to ImageBind's text loader to describe the sound to extract.
TARGET = "Computer keyboard"

# ImageBind embeddings

In [6]:
# Change the current path to the imagebind root. Anchored at ROOT_DIR (not the
# current directory) so re-running this cell does not descend into
# imagebind/imagebind. Presumably needed so `import data` in the next cell
# resolves against this directory — confirm.
os.chdir(os.path.join(ROOT_DIR, "imagebind"))
print("Dir: ", os.getcwd())

Dir:  /scratch/IOSZ/waveformer/multimod-sound-separation/imagebind


In [6]:
import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

  from .autonotebook import tqdm as notebook_tqdm


In [19]:
# Single-item batches for the ImageBind loaders below: one text prompt and one
# audio file.
text_list, audio_paths = [TARGET], [INPUT_PATH]

In [1]:
# Compute ImageBind embeddings for the target text prompt (and the input audio).
device = "cuda:0" if th.cuda.is_available() else "cpu"

# Instantiate the pretrained ImageBind "huge" model in eval mode on `device`.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)

# Load and transform inputs per modality with the imagebind `data` helpers.
inputs = {
    ModalityType.TEXT: data.load_and_transform_text(text_list, device),
    # ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    # NOTE(review): only the TEXT embedding is read by the Waveformer cell; the
    # AUDIO embedding is computed but unused there — drop it if not needed.
    ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
}

# Forward pass without autograd bookkeeping; `torch` is imported as `th`.
with th.no_grad():
    embeddings = model(inputs)

# Waveformer output

In [7]:
# Change the current path to the multimod-waveformer root (anchored at
# ROOT_DIR, so re-running is safe). Presumably needed so the `src.*` imports in
# the next cell resolve against this directory — confirm.
os.chdir(os.path.join(ROOT_DIR, "multimod-waveformer"))
print("Dir: ", os.getcwd())

Dir:  /scratch/IOSZ/waveformer/multimod-sound-separation/multimod-waveformer


In [8]:
from src.helpers import utils
from src.training.dcc_tf import Net as Waveformer

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
# Display the path of the epoch-140 Waveformer checkpoint.
# NOTE(review): the loading cell below uses 141.pt, not 140.pt, and the saved
# output of this cell shows a config.json path (stale). Confirm which
# checkpoint is intended.
os.path.join(
    os.getcwd(),
    "experiments",
    "dcc_tf_ckpt_E256_10_D128_1",
    "140.pt",
)

'/scratch/IOSZ/waveformer/multimod-sound-separation/multimod-waveformer/experiments/dcc_tf_ckpt_E256_10_D128_1/config.json'

In [None]:
# load config file
params = utils.Params(os.path.join(os.getcwd(), "experiments", "dcc_tf_ckpt_E256_10_D128_1", "config.json"))
model = Waveformer(**params.model_params)

# Instantiate waveformer and load pretrained weights
model.load_state_dict(
    th.load(os.path.join(os.getcwd(), "experiments", "dcc_tf_ckpt_E256_10_D128_1", "141.pt"), 
            map_location=device)["model_state_dict"]
    )

model.to(device).eval()

# Read input audio
mixture, fs = torchaudio.load(args.input)

if fs != 44100:
    mixture = torchaudio.functional.resample(mixture, orig_freq=fs, new_freq=44100)
mixture = mixture.unsqueeze(0).to(device)
print("Loaded input audio from %s" % args.input)

# get the query from imagebind
query = embeddings[ModalityType.TEXT]

# run inference
with th.inference_mode():
        output = model(mixture.to(device), query.to(device)).squeeze(0).cpu()
    if fs != 44100:
        output = torchaudio.functional.resample(output, orig_freq=44100, new_freq=fs)
    print("Inference done. Saving output audio to %s" % args.output)

    assert not os.path.exists(args.output), "Output file already exists."
    torchaudio.save(args.output, output, fs)

In [None]:
# Reference list of 41 candidate target labels (underscore-separated names).
# Kept commented out: nothing in this notebook reads it. TARGET in the config
# cell is set by hand and uses a space-separated form ("Computer keyboard"),
# presumably because it is fed to ImageBind as a free-text prompt — verify the
# expected spelling before switching TARGET to one of these names.
# TARGETS = [
#     "Acoustic_guitar",
#     "Applause",
#     "Bark",
#     "Bass_drum",
#     "Burping_or_eructation",
#     "Bus",
#     "Cello",
#     "Chime",
#     "Clarinet",
#     "Computer_keyboard",
#     "Cough",
#     "Cowbell",
#     "Double_bass",
#     "Drawer_open_or_close",
#     "Electric_piano",
#     "Fart",
#     "Finger_snapping",
#     "Fireworks",
#     "Flute",
#     "Glockenspiel",
#     "Gong",
#     "Gunshot_or_gunfire",
#     "Harmonica",
#     "Hi-hat",
#     "Keys_jangling",
#     "Knock",
#     "Laughter",
#     "Meow",
#     "Microwave_oven",
#     "Oboe",
#     "Saxophone",
#     "Scissors",
#     "Shatter",
#     "Snare_drum",
#     "Squeak",
#     "Tambourine",
#     "Tearing",
#     "Telephone",
#     "Trumpet",
#     "Violin_or_fiddle",
#     "Writing",
# ]