In [None]:
import os
import sys
import torch

import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Run on the GPU when one is available; otherwise fall back to CPU so the
# notebook can still execute (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

os.environ['TOKENIZERS_PARALLELISM'] = 'true'

HPARAM_FILE = 'hparams/sepformer_llama2_lora.yaml'
CKPT_PATH = 'save/pretrain_sepformer_llama2_lora'
# Location of the LLM weights (passed as --llm_path below). Set it here or via
# the LLM_PATH environment variable; the notebook cannot run without it.
LLM_PATH = os.environ.get('LLM_PATH')
assert LLM_PATH is not None, (
    'Set LLM_PATH (in this cell or as an environment variable) before running.'
)

  from .autonotebook import tqdm as notebook_tqdm
torchvision is not available - cannot save figures


## Load Everything

In [2]:
# Build a command line for SpeechBrain's argument parser so the same hparams
# YAML used for training can be reused here with notebook-specific overrides.
argv = [HPARAM_FILE]
argv += ['--test_set', 'null'] # Test set not ready yet
argv += ['--test_pattern', 'null']
argv += ['--test_files', 'null']

# Analyze mode: the masknet also records intermediate features (see the
# "Analyze mode" log line in the output of this cell).
argv += ['--analyze', 'true']

argv += ['--llm_path', str(LLM_PATH)]  # str() so a pathlib.Path also works
argv += ['--save_folder', CKPT_PATH] # Ckpt folder

hparam_file, run_opts, overrides = sb.parse_arguments(argv)

# Init model
with open(hparam_file) as f:
    hparams = load_hyperpyyaml(f, overrides)
hparams['tokenizer'].pad_token = '[PAD]'

# Init data
test_loader = torch.utils.data.DataLoader(
    hparams['test_set'],
    **hparams['test_loader_opts']
)

# Load model weights; recover_if_possible() returns None when nothing is found.
loaded = hparams['checkpointer'].recover_if_possible()
assert loaded is not None, f'No checkpoint recovered (save_folder={CKPT_PATH!r})'

# Move every module to the target device and switch to inference mode.
for name, mod in hparams['modules'].items():
    mod.to(device)
    mod.eval()
    print(f'Load {name} to {device}.')

# Save memory by casting the LLM to the configured mixed-precision dtype.
if hparams['llm_mix_prec']:
    print(f"Cast LLM to {hparams['mix_dtype']}")
    hparams['llm'] = hparams['llm'].to(hparams['mix_dtype'])

Initialized ShortTemplate: 
shuffle: True random: True
Initialized a FiLM before IntraInterBlock 0.
Initialized a FiLM before IntraInterBlock 1.
### Analyze mode: MaskNet returns output and features ###


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.53s/it]


LoRA loaded from  save/sepformer_llama2_lora/CKPT+2023-12-14+23-27-09+00/lora_llm.ckpt
Load encoder to cuda.
Load decoder to cuda.
Load masknet to cuda.
Load lora_llm to cuda.


## Inference functions

In [25]:
@torch.no_grad()
def edit_sound(mix, text_embed, features):
    """Extract the prompt-conditioned target from a mixture waveform.

    Pipeline: encoder -> text-conditioned masknet -> mask the encoder output
    -> decoder. Detached copies of intermediate tensors are stored in
    ``features`` along the way.

    Args:
        mix: mixture waveform; dimension 1 is treated as time (``mix.size(1)``).
        text_embed: prompt embedding (as produced by ``read_prompt``).
        features: dict updated in place. It is also passed to the masknet,
            which may add its own entries.

    Returns:
        The estimated target waveform, padded or cropped along dimension 1 to
        match ``mix.size(1)``.
    """
    modules = hparams['modules']

    def snapshot(tensor):
        # Independent, detached copy with singleton dims removed, for inspection.
        return tensor.detach().clone().squeeze()

    encoded = modules['encoder'](mix)
    features['Encoder Output'] = snapshot(encoded)

    mask = modules['masknet'](encoded, text_embed, features).squeeze(0)
    features['Estimated Mask'] = snapshot(mask)

    masked = encoded * mask  # (B, F, T)
    features['Decoder Input'] = snapshot(masked)

    waveform = modules['decoder'](masked)

    # The encoder's strided conv1d changes the length; restore the input length.
    target_len, produced_len = mix.size(1), waveform.size(1)
    if target_len > produced_len:
        waveform = torch.nn.functional.pad(waveform, (0, target_len - produced_len))
    else:
        waveform = waveform[:, :target_len]

    features['Output Waveform'] = snapshot(waveform)
    return waveform

@torch.no_grad()
def read_prompt(prompt):
    """Encode a text prompt (or a batch of prompts) into one vector per prompt.

    The prompt is tokenized, run through the LoRA-adapted LLM, and the
    last-layer hidden state of the last *real* (non-padding) token of each
    sequence is returned.

    Args:
        prompt: a string or a sequence of strings accepted by the tokenizer.

    Returns:
        Tensor of shape (B, D): one embedding per prompt.

    NOTE(review): assumes ``hparams['lora_llm']`` accepts HF-style keyword
    arguments (``attention_mask``), as it already does for
    ``output_hidden_states`` — confirm against the wrapper.
    """
    encoded = hparams['tokenizer'](prompt, padding=True, return_tensors='pt')
    tokens = encoded['input_ids'].to(device)
    attention_mask = encoded.get('attention_mask')
    if attention_mask is None:
        attention_mask = torch.ones_like(tokens)
    attention_mask = attention_mask.to(device)

    # Pass the mask so padding positions are not attended to when prompts of
    # different lengths are batched together (padding=True above).
    words_embed = hparams['lora_llm'](
        tokens, attention_mask=attention_mask, output_hidden_states=True
    ).hidden_states[-1] # last layer, (B, L, D)

    # Pick the last non-padding token in each row. This works for both left-
    # and right-padded batches: argmax returns the first maximal index, so on
    # the flipped mask it finds the last 1 of the original mask. For an
    # unpadded batch it is simply position L - 1, i.e. the original [:, -1, :].
    seq_len = attention_mask.size(1)
    last_idx = seq_len - 1 - attention_mask.flip(1).argmax(dim=1)
    batch_idx = torch.arange(words_embed.size(0), device=words_embed.device)
    return words_embed[batch_idx, last_idx] # last or EOS token per prompt

@torch.no_grad()
def infer(mix, prompt):
    """Run prompt-conditioned extraction on ``mix`` and collect intermediate tensors.

    Args:
        mix: mixture waveform passed to ``edit_sound``.
        prompt: text prompt(s) passed to ``read_prompt``.

    Returns:
        ``(est_tar, features)``: the estimated target waveform and a dict of
        detached tensors. Keys are inserted starting with 'Input Waveform' and
        'Semantic Filter', followed by whatever ``edit_sound`` adds.
    """
    features = {}
    features['Input Waveform'] = mix.detach().clone().squeeze()

    prompt_embedding = read_prompt(prompt)
    features['Semantic Filter'] = prompt_embedding.detach().clone().squeeze()

    return edit_sound(mix, prompt_embedding, features), features


# Detached copies of the encoder conv1d kernel and the decoder weight, for inspection.
weights = {
    label: param.detach().clone()
    for label, param in (
        ('Encoder', hparams['modules']['encoder'].conv1d.weight),
        ('Decoder', hparams['modules']['decoder'].weight),
    )
}

## Test one sample

In [4]:
import librosa
from IPython.display import Audio

prompt1, prompt2 = 'Extract the piano.', 'Extract the machine sound.'

# Load the demo mixture resampled to 16 kHz and add a leading batch dim.
mix, sr = librosa.load('samples/mix1.wav', sr=16000)
mix = torch.tensor(mix).unsqueeze(0).to(device) # (1, T)

# Same mixture, two different prompts.
(est1, features1), (est2, features2) = infer(mix, prompt1), infer(mix, prompt2)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


In [32]:
print('### Features/Weights: Shape ###', end='\n\n')

print(f"Encoder weights (kernel=16, stride=8, basis=256): {tuple(weights['Encoder'].shape)}")
print(f"Decoder weights (kernel=16, stride=8, basis=256): {tuple(weights['Decoder'].shape)}")
print(end='\n\n')

# One line per recorded feature, in the order it was recorded during inference.
for k, v in features1.items():
    print(f'{k}: {tuple(v.shape)}')

### Features/Weights: Shape ###

Encoder weights (kernel=16, stride=8, basis=256): (256, 1, 16)
Decoder weights (kernel=16, stride=8, basis=256): (256, 1, 16)


Input Waveform: (80000,)
Semantic Filter: (4096,)
Encoder Output: (256, 9999)
Block0 Input Before FiLM (2D): (256, 9999)
Block0 Input Before FiLM (3D): (256, 250, 82)
FiLM0 Gamma: (256,)
FiLM0 Beta: (256,)
Block0 Input After FiLM (2D): (256, 9999)
Block0 Input After FiLM (3D): (256, 250, 82)
Block1 Input Before FiLM (2D): (256, 9999)
Block1 Input Before FiLM (3D): (256, 250, 82)
FiLM1 Gamma: (256,)
FiLM1 Beta: (256,)
Block1 Input After FiLM (2D): (256, 9999)
Block1 Input After FiLM (3D): (256, 250, 82)
Block1 Output (2D): (256, 9999)
Block1 Output (3D): (256, 250, 82)
Estimated Mask: (256, 9999)
Decoder Input: (256, 9999)
Output Waveform: (80000,)


In [6]:
Audio(data=mix.cpu(), rate=16000)  # input mixture

In [7]:
Audio(data=est1.cpu(), rate=16000)  # output for prompt1

In [8]:
Audio(data=est2.cpu(), rate=16000)  # output for prompt2

## Inference on the test set (requires the test data to be prepared first)

In [8]:
# Run inference on the first test batch only (quick sanity check, not a full evaluation).
mix, tar, prompt, acts = next(iter(test_loader))[0:4]
mix = mix.to(device)
# infer returns (waveform, features); unpack so `est_tar` is the waveform tensor
# that is played back further below.
est_tar, features_ds = infer(mix, prompt)
# edit_sound pads/crops its output to the input length, so shapes must agree.
assert est_tar.shape == mix.shape, (tuple(est_tar.shape), tuple(mix.shape))

In [9]:
print(str(prompt))  # prompt(s) of the first test batch
Audio(data=mix.cpu(), rate=16000)  # corresponding input mixture

('Please ensure that the piano sound and the female speaker with the high pitch are removed from the audio file.',)


In [10]:
Audio(data=est_tar.cpu(), rate=16000)  # extracted output for the test prompt