In [11]:
import os
import sys
import torch

import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

device = 'cuda'

os.environ['TOKENIZERS_PARALLELISM'] = 'true'

HPARAM_FILE = 'hparams/inference/sepformer_llama2_lora_animals.yaml'
CKPT_PATH = 'save/sepformer_llama2_lora'

# 'meta-llama/Llama-2-7b-chat-hf' if llama2 is not in your system
# huggingface will download and save llama2 for you for the first time
# (and directly load from cache after the first time)
# or pass where llama2 is stored directly
# e.g. 'meta-llama/Llama-2-7b-chat-hf'
# '/engram/naplab/shared/LLaMA2/huggingface/Llama-2-7b-chat-hf'
LLM_PATH = '/engram/naplab/shared/LLaMA2/huggingface/Llama-2-7b-chat-hf'
assert LLM_PATH != None

# DATA_ROOT is where "animal_mixtures" are stored and and should end with "animal_mixtures"
DATA_ROOT = '/engram/naplab/shared/TextSpeechVGGMix/TextrolSpeech2VGGSound2Mix'
assert DATA_ROOT != None

test  train  valid


## Load Everything

In [2]:
argv = [HPARAM_FILE]
argv += ['--filt_labels_mode', 'both'] 
argv += ['--keep_spks', 'true']

argv += ['--analyze', 'true']
argv += ['--llm_path', LLM_PATH]
argv += ['--save_folder', CKPT_PATH] # Ckpt folder

hparam_file, run_opts, overrides = sb.parse_arguments(argv)

# Init model
with open(hparam_file) as f:
    hparams = load_hyperpyyaml(f, overrides)
hparams['tokenizer'].pad_token = '[PAD]'

# Load model weights
loaded = hparams['checkpointer'].recover_if_possible()
assert loaded != None

# Put model on GPU
for name, mod in hparams['modules'].items():
    mod.to(device)
    mod.eval()
    print(f'Load {name} to {device}.')

# Save memory
if hparams['llm_mix_prec']:
    print('Cast LLM to bf16')
    hparams['llm'] = hparams['llm'].to(hparams['mix_dtype'])

Fetched 1 manifest files.
Filtering test set by labels... both audio(s) should belong to labels...
Found 386 mixtures from 57 labels.
Fetched 5 manifest files.
Filtering test set by labels... both audio(s) should belong to labels...
Found 187 mixtures from 57 labels.
Initialized a FiLM before IntraInterBlock 0.
Initialized a FiLM before IntraInterBlock 1.
### Analyze mode: MaskNet returns output and features ###


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



LoRA loaded from  save/sepformer_llama2_lora/CKPT+2023-12-14+23-27-09+00/lora_llm.ckpt
Load encoder to cuda.
Load decoder to cuda.
Load masknet to cuda.
Load lora_llm to cuda.


In [3]:
!ls data/manifests/animals_test.json

ls: cannot access data/mainfests/animals_test.json: No such file or directory


## Inference functions

In [3]:
def extract_prompt(label):
    if isinstance(label, tuple) or isinstance(label, list):
        label = str(label[0])
    return f'Please extract the sound of {label}.'

def remove_prompt(label):
    if isinstance(label, tuple) or isinstance(label, list):
        label = str(label[0])
    return f'Please remove the sound of {label}.'

@torch.no_grad()
def edit_sound(mix, text_embed, features):
        
    # Encoding speech
    mix_h = hparams['modules']['encoder'](mix)
    
    features['Encoder Output'] = mix_h.detach().clone().squeeze()

    # Extraction
    est_mask = hparams['modules']['masknet'](mix_h, text_embed, features).squeeze(0)
    
    features['Estimated Mask'] = est_mask.detach().clone().squeeze()
    
    est_tar_h = mix_h * est_mask # (B, F, T)
    
    features['Decoder Input'] = est_tar_h.detach().clone().squeeze()

    # Decoding
    est_tar = hparams['modules']['decoder'](est_tar_h)

    # T changed after conv1d in encoder, fix it here
    T_origin = mix.size(1)
    T_ext = est_tar.size(1)

    if T_origin > T_ext:
        est_tar = torch.nn.functional.pad(est_tar, (0, T_origin - T_ext))
    else:
        est_tar = est_tar[:, :T_origin]
        
    features['Output Waveform'] = est_tar.detach().clone().squeeze()

    return est_tar

@torch.no_grad()
def read_prompt(prompt):
    
    # Tokenize
    tokens = hparams['tokenizer'](
        prompt, padding=True, return_tensors='pt'
    )['input_ids'].to(device)
    
    # Encode
    words_embed = hparams['lora_llm'](
        tokens, output_hidden_states=True
    ).hidden_states[-1] # last layer

    return words_embed[:, -1, :] # last or EOS token

@torch.no_grad()
def infer(mix, prompt):
    
    features = {'Input Waveform': mix.detach().clone().squeeze()}
    
    text_embed = read_prompt(prompt)
    
    features['Semantic Filter'] = text_embed.detach().clone().squeeze()
    
    est_tar = edit_sound(mix, text_embed, features)
    
    return est_tar, features


weights = {
    'Encoder': hparams['modules']['encoder'].conv1d.weight.detach().clone(),
    'Decoder': hparams['modules']['decoder'].weight.detach().clone()
}

test_set = hparams['test_set']

In [6]:
len(hparams['train_set'].records)

3654

In [3]:
import json
import shutil

animal_records = []

for record in hparams['valid_set'].records:
    animal_record = {}
    for key in ['id', 'name', 'duration', 'snr1', 'snr2', 'snr3', 'snr4', 
    'path', 'path1', 'path2', 'path3', 'path4', 'spk1', 'spk2', 'label3', 'label4',
    'corpus1', 'corpus2', 'gender1', 'gender2', 'pitch1', 'pitch2', 'tempo1', 'tempo2',
    'energy1', 'energy2', 'emotion1', 'emotion2'
    ]:
        animal_record[key] = record[key]
    animal_record['path'] = record['path'].replace(
        '/engram/naplab/shared/TextSpeechVGGMix/TextrolSpeech2VGGSound2Mix', '<DATA_ROOT>'
    )
    animal_record['path1'] = record['path1'].replace(
        '/engram/naplab/shared/TextSpeechVGGMix/TextrolSpeech2VGGSound2Mix', '<DATA_ROOT>'
    )
    animal_record['path2'] = record['path2'].replace(
        '/engram/naplab/shared/TextSpeechVGGMix/TextrolSpeech2VGGSound2Mix', '<DATA_ROOT>'
    )
    animal_record['path3'] = record['path3'].replace(
        '/engram/naplab/shared/TextSpeechVGGMix/TextrolSpeech2VGGSound2Mix', '<DATA_ROOT>'
    )
    animal_record['path4'] = record['path4'].replace(
        '/engram/naplab/shared/TextSpeechVGGMix/TextrolSpeech2VGGSound2Mix', '<DATA_ROOT>'
    )
    
#     shutil.copy(record['path'], animal_record['path'].replace('<DATA_ROOT>', 'animal_mixtures'))
#     shutil.copy(record['path1'], animal_record['path1'].replace('<DATA_ROOT>', 'animal_mixtures'))
#     shutil.copy(record['path2'], animal_record['path2'].replace('<DATA_ROOT>', 'animal_mixtures'))
#     shutil.copy(record['path3'], animal_record['path3'].replace('<DATA_ROOT>', 'animal_mixtures'))
#     shutil.copy(record['path4'], animal_record['path4'].replace('<DATA_ROOT>', 'animal_mixtures'))
    
    animal_records.append(animal_record)
    


with open('data/manifests/animals_valid.json', 'w') as f:
    json.dump(animal_records, f, indent=4)

In [4]:
len(animal_records)

187

In [25]:




import json

with open('data/manifests/animal_mixtures_test.json', 'w') as f:
    json.dump(hparams['test_set'].records, f, indent=4)

## Inference in dataset (not ready until you have the data...)

In [22]:
list(all_labels)

['owl hooting',
 'baltimore oriole calling',
 'chicken crowing',
 'chipmunk chirping',
 'dinosaurs bellowing',
 'cat meowing',
 'chinchilla barking',
 'dog barking',
 'cat purring',
 'lions roaring',
 'whale calling',
 'coyote howling',
 'chimpanzee pant-hooting',
 'dog howling',
 'gibbon howling',
 'dog growling',
 'lions growling',
 'pheasant crowing',
 'penguins braying',
 'duck quacking',
 'cat growling',
 'magpie calling',
 'frog croaking',
 'cat caterwauling',
 'goose honking',
 'mynah bird singing',
 'chicken clucking',
 'bird squawking',
 'cheetah chirrup',
 'pig oinking',
 'woodpecker pecking tree',
 'francolin calling',
 'mosquito buzzing',
 'cattle mooing',
 'ferret dooking',
 'elk bugling',
 'cat hissing',
 'dog bow-wow',
 'turkey gobbling',
 'crow cawing',
 'black capped chickadee calling',
 'sheep bleating',
 'otter growling',
 'snake hissing',
 'cricket chirping',
 'eagle screaming',
 'elephant trumpeting',
 'dog whimpering',
 'canary calling',
 'goat bleating',
 'snake 

In [18]:
from tqdm import tqdm

for data in tqdm(test_set, total=len(test_set)):
    mix = data['mix']
    labels = data['audio_labels']
    audios = data['audios']
    
    ext_prompt1 = extract_prompt(labels[0])
    ext_prompt2 = extract_prompt(labels[1])
    rem_prompt1 = remove_prompt(labels[0])
    rem_prompt2 = remove_prompt(labels[1])
    
    if labels[0] not in all_labels:
        all_labels[labels[0]] = 1
    if labels[1] not in all_labels:
        all_labels[labels[1]] = 1
    continue
    mix = torch.tensor(mix).unsqueeze(0).to(device)
    ext1, feat_ext1 = infer(mix, ext_prompt1)
    ext2, feat_ext2 = infer(mix, ext_prompt2)
    rem1, feat_rem1 = infer(mix, rem_prompt1)
    rem2, feat_rem2 = infer(mix, rem_prompt2)

100%|██████████| 386/386 [01:56<00:00,  3.30it/s]


In [50]:
from IPython.display import Audio

In [51]:
Audio(mix.cpu(), rate=16000)

In [52]:
Audio(ext1.cpu(), rate=16000)

In [55]:
Audio(rem1.cpu(), rate=16000)

In [56]:
Audio(rem2.cpu(), rate=16000)