In [None]:
import os
import sys
import torch

import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Target device for every module and input tensor in this notebook.
device = 'cpu'

# Explicitly enable Hugging Face tokenizers parallelism.
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

HPARAM_FILE = 'hparams/inference/sepformer_llama2_lora.yaml'
CKPT_PATH = 'save/pretrain_sepformer_llama2_lora'  # or 'save/cpu_pretrain_sepformer_llama2_lora'
# Llama-2 weights: a local directory or the hub id 'meta-llama/Llama-2-7b-chat-hf'.
# Set the LLM_PATH environment variable instead of editing this cluster-specific
# absolute path.
LLM_PATH = os.environ.get(
    'LLM_PATH', '/engram/naplab/shared/LLaMA2/huggingface/Llama-2-7b-chat-hf'
)

# Fail early (before the slow model load) if no LLM location is configured.
assert LLM_PATH, 'LLM_PATH must point to the Llama-2 weights (local dir or hub id).'

  from .autonotebook import tqdm as notebook_tqdm
torchvision is not available - cannot save figures


## Load hyperparameters, model weights, and tokenizer

In [2]:
# Build the CLI-style argument list that SpeechBrain parses, overriding fields
# that are not needed for interactive inference.
argv = [HPARAM_FILE]

# The test set is not ready yet, so disable everything that depends on it.
argv += ['--test_set', 'null']
argv += ['--test_pattern', 'null']
argv += ['--test_files', 'null']
argv += ['--analyze', 'false']

argv += ['--llm_path', LLM_PATH]
argv += ['--save_folder', CKPT_PATH]  # checkpoint folder

hparam_file, run_opts, overrides = sb.parse_arguments(argv)

# Instantiate the objects declared in the YAML (modules, tokenizer, checkpointer, ...).
with open(hparam_file) as f:
    hparams = load_hyperpyyaml(f, overrides)
hparams['tokenizer'].pad_token = '[PAD]'

# TODO: once the test set exists, build a loader with
# torch.utils.data.DataLoader(hparams['test_set'], **hparams['test_loader_opts']).

# Load model weights; recover_if_possible() returns None when no checkpoint is found.
loaded = hparams['checkpointer'].recover_if_possible()
assert loaded is not None, f'No checkpoint found in {CKPT_PATH!r}.'

# Move every module to the configured device and switch it to eval mode.
for name, mod in hparams['modules'].items():
    mod.to(device)
    mod.eval()
    print(f'Load {name} to {device}.')

# Save memory: cast the LLM to the configured mixed-precision dtype
# (treated as disabled if the YAML does not define llm_mix_prec).
if hparams.get('llm_mix_prec', False):
    print('Cast LLM to bf16')
    hparams['llm'] = hparams['llm'].to(hparams['mix_dtype'])

Initialized ShortTemplate: 
shuffle: True random: True
Initialized a FiLM before IntraInterBlock 0.
Initialized a FiLM before IntraInterBlock 1.


Loading checkpoint shards: 100%|██████████| 2/2 [00:21<00:00, 10.58s/it]
  from pkg_resources import packaging  # type: ignore[attr-defined]


LoRA loaded from  save/pretrain_sepformer_llama2_lora/CKPT+2023-12-14+23-27-09+00/lora_llm.ckpt
Load encoder to cpu.
Load decoder to cpu.
Load masknet to cpu.
Load lora_llm to cpu.


## Inference functions

In [3]:
@torch.no_grad()
def edit_sound(mix, text_embed, features):
    """Extract the prompted sound from ``mix``, conditioned on ``text_embed``.

    Runs the encoder -> masknet -> decoder modules from ``hparams['modules']``
    and records detached, squeezed copies of intermediates into ``features``
    (the dict is mutated in place and is also handed to the masknet).

    Args:
        mix: waveform batch; dim 1 is treated as time (presumably (B, T) -- confirm).
        text_embed: prompt embedding passed to the masknet.
        features: dict that receives the intermediate tensors.

    Returns:
        The estimated target waveform, padded or cropped to ``mix.size(1)`` samples.
    """
    modules = hparams['modules']

    def snapshot(key, tensor):
        # Store a detached copy so later in-place ops cannot alter it.
        features[key] = tensor.detach().clone().squeeze()

    # Encoding speech
    encoded = modules['encoder'](mix)
    snapshot('Encoder Output', encoded)

    # Extraction: the masknet predicts a mask applied to the encoded mixture.
    mask = modules['masknet'](encoded, text_embed, features).squeeze(0)
    snapshot('Estimated Mask', mask)

    masked = encoded * mask  # (B, F, T)
    snapshot('Decoder Input', masked)

    # Decoding
    target = modules['decoder'](masked)

    # The encoder's conv1d can change the length; restore the input length.
    n_in, n_out = mix.size(1), target.size(1)
    target = (
        torch.nn.functional.pad(target, (0, n_in - n_out))
        if n_in > n_out
        else target[:, :n_in]
    )
    snapshot('Output Waveform', target)

    return target

@torch.no_grad()
def read_prompt(prompt):
    
    # Tokenize
    tokens = hparams['tokenizer'](
        prompt, padding=True, return_tensors='pt'
    )['input_ids'].to(device)
    
    # Encode
    words_embed = hparams['lora_llm'](
        tokens, output_hidden_states=True
    ).hidden_states[-1] # last layer

    return words_embed[:, -1, :] # last or EOS token

@torch.no_grad()
def infer(mix, prompt):
    """Run prompt-conditioned extraction on ``mix``.

    Returns:
        ``(est_tar, features)``: the estimated waveform from ``edit_sound`` and
        a dict of detached intermediate tensors. The dict starts with
        'Input Waveform' and 'Semantic Filter'; ``edit_sound`` adds the rest.
    """
    features = {}
    features['Input Waveform'] = mix.detach().clone().squeeze()

    semantic = read_prompt(prompt)
    features['Semantic Filter'] = semantic.detach().clone().squeeze()

    return edit_sound(mix, semantic, features), features


# Detached copies of the encoder's conv1d weight and the decoder's weight.
weights = dict(
    Encoder=hparams['modules']['encoder'].conv1d.weight.detach().clone(),
    Decoder=hparams['modules']['decoder'].weight.detach().clone(),
)

## Test one sample

In [4]:
import librosa
from IPython.display import Audio

# Two prompts applied to the same demo mixture.
prompt1 = 'Extract the piano.'
prompt2 = 'Extract the machine sound.'

# Load at 16 kHz and add a leading batch axis.
mix, sr = librosa.load('samples/mix1.wav', sr=16000)
mix = torch.tensor(mix).unsqueeze(0).to(device)  # (1, T)

est1, features1 = infer(mix, prompt1)
est2, features2 = infer(mix, prompt2)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


In [5]:
Audio(data=mix.cpu(), rate=16000)  # input mixture

In [6]:
Audio(data=est1.cpu(), rate=16000)  # estimate for prompt1

In [8]:
Audio(data=est2.cpu(), rate=16000)  # estimate for prompt2

## Inference on the test dataset

In [60]:
import json

# Prompt manifests are sharded as test_1k.json ... test_10k.json.
test_manifest = 'test/TextrolSpeech2VGGSound2Mix/prompts/test_{}k.json'
all_tests = []
for i in range(1, 11):
    # Context manager closes each manifest right after reading it; JSON is UTF-8.
    with open(test_manifest.format(i), 'r', encoding='utf-8') as fp:
        all_tests += json.load(fp)

print(f'Found in total {len(all_tests)} test samples.')

tasks = [
    'TSE', 'TSR', 'TAE', 'TAR',
    'SE', 'SR', 'S↑', 'S↓',
    'TS↑', 'TS↓', 'TA↑', 'TA↓',
    'HE', 'HVC', 'OVC', 'RHVC',
]

# Group samples by task in one pass. Every listed task gets a (possibly empty)
# list, and samples whose task is not listed are dropped.
tests = {task: [] for task in tasks}
for sample in all_tests:
    if sample['task'] in tests:
        tests[sample['task']].append(sample)

for task in ('TSE', 'TAE', 'HE'):
    print(f"Found in total {len(tests[task])} {task} samples.")

Found in total 10000 test samples.
Found in total 625 TSE samples.
Found in total 625 TAE samples.
Found in total 625 HE samples.


In [None]:
# Pick the i-th 'TSE' sample and run extraction with its first parsed prompt.
i = 1
test = tests['TSE'][i]
mix_path, prompt = test['path'], test['parsed_prompts'][0]

# Load at 16 kHz and add a leading batch axis.
mix, sr = librosa.load(mix_path, sr=16000)
mix = torch.tensor(mix).unsqueeze(0).to(device)  # (1, T)

est, features = infer(mix, prompt)

In [38]:
print(prompt)  # the text instruction for this sample
Audio(data=mix.cpu(), rate=16000)  # input mixture

Could you help me isolate the man with a deep voice and fast pace in the audio?


In [39]:
Audio(data=est.cpu(), rate=16000)  # extracted target