# Speech Separation by Pretrained Models

In [None]:
%%bash
# Stop at the first failing command: without this, a failed clone or `cd`
# would still let `pip install` run against whatever directory we are in.
set -euo pipefail

# Clone only when the checkout is missing, so re-running this cell keeps working.
if [ ! -d "./DNN-based_source_separation" ]; then
    git clone https://github.com/tky823/DNN-based_source_separation.git
fi

cd "./DNN-based_source_separation/egs/tutorials/"

# To install torch & torchaudio
pip install -r requirements.txt

In [None]:
%%bash
# Stop at the first failing command so a broken download is not silently extracted.
set -euo pipefail

# Download speech dataset
for spk in aew axb bdl ; do
    archive="cmu_us_${spk}_arctic.tar.bz2"

    # Fetch the archive only once: on a re-run wget would otherwise store a
    # duplicate copy under a numbered name next to the existing file.
    if [ ! -f "./${archive}" ]; then
        wget "http://festvox.org/cmu_arctic/packed/${archive}"
    fi

    # Extract quietly (no -v): listing every extracted file buries the notebook output.
    tar -xjf "./${archive}"
done

In [None]:
import sys

In [None]:
# Put the cloned repository's `src` directory on the import path
# (presumably where the `models` package imported later lives).
# The membership check keeps re-runs of this cell from adding duplicate entries.
_SRC_DIR = "/content/DNN-based_source_separation/src"
if _SRC_DIR not in sys.path:
    sys.path.append(_SRC_DIR)

In [None]:
import IPython.display as ipd
import torch
import torchaudio

In [None]:
# Seed PyTorch's global random number generator with a fixed value.
SEED = 111
torch.manual_seed(SEED)

In [None]:
from models.deep_clustering import DeepClustering
from models.danet import DANet, FixedAttractorDANet
from models.adanet import ADANet
from models.lstm_tasnet import LSTMTasNet
from models.conv_tasnet import ConvTasNet
from models.dprnn_tasnet import DPRNNTasNet
from models.dptnet import DPTNet
from models.sepformer import SepFormer

In [None]:
# Target sampling rate used for every model and audio player in this notebook.
SAMPLE_RATE_WSJ0 = 8000

# Load one utterance per speaker, keeping each file's own sampling rate.
waveform_aew, sample_rate_aew = torchaudio.load("/content/cmu_us_aew_arctic/wav/arctic_a0001.wav")
waveform_axb, sample_rate_axb = torchaudio.load("/content/cmu_us_axb_arctic/wav/arctic_a0002.wav")
waveform_bdl, sample_rate_bdl = torchaudio.load("/content/cmu_us_bdl_arctic/wav/arctic_a0003.wav")

# A single `sample_rate` is used to build one shared resampler afterwards, so the
# three files must agree; fail loudly instead of silently resampling with the wrong rate.
if not (sample_rate_aew == sample_rate_axb == sample_rate_bdl):
    raise ValueError(
        "Input files have different sampling rates: "
        f"aew={sample_rate_aew}, axb={sample_rate_axb}, bdl={sample_rate_bdl}"
    )

sample_rate = sample_rate_bdl

In [None]:
# Convert every utterance from the rate of the loaded files to SAMPLE_RATE_WSJ0.
# NOTE: the waveforms are overwritten in place, so re-running this cell without
# re-running the loading cell applies the resampler a second time.
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=SAMPLE_RATE_WSJ0)
waveform_aew, waveform_axb, waveform_bdl = (
    resampler(waveform) for waveform in (waveform_aew, waveform_axb, waveform_bdl)
)

In [None]:
# Crop all three utterances to the length of the shortest one (last axis),
# so that they can later be added together sample by sample.
T_min = min(w.size(-1) for w in (waveform_aew, waveform_axb, waveform_bdl))
waveform_aew = waveform_aew[:, :T_min]
waveform_axb = waveform_axb[:, :T_min]
waveform_bdl = waveform_bdl[:, :T_min]

# Audio player for each individual utterance.
for waveform in (waveform_aew, waveform_axb, waveform_bdl):
    display(ipd.Audio(waveform, rate=SAMPLE_RATE_WSJ0))

## 2 speakers

In [None]:
# Two-speaker setup: the mixture is the element-wise sum of two utterances.
n_sources = 2
mixture = torch.add(waveform_aew, waveform_axb)

# Audio player for the mixture that the models below receive.
display(ipd.Audio(mixture, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained DeepClustering model, wrapped in its
# TimeDomainWrapper (built from the model's own n_fft / hop_length / window_fn).
model = DeepClustering.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
wrapper_model = DeepClustering.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = wrapper_model(mixture_batch, threshold=model.threshold, n_sources=n_sources, iter_clustering=None)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained DANet model, wrapped in its
# TimeDomainWrapper (built from the model's own n_fft / hop_length / window_fn).
model = DANet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
wrapper_model = DANet.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = wrapper_model(mixture_batch, threshold=model.threshold, n_sources=n_sources, iter_clustering=None)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained FixedAttractorDANet model, wrapped in its
# TimeDomainWrapper (built from the model's own n_fft / hop_length / window_fn).
model = FixedAttractorDANet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
wrapper_model = FixedAttractorDANet.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = wrapper_model(mixture_batch)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained ADANet model, wrapped in its
# TimeDomainWrapper (built from the model's own n_fft / hop_length / window_fn).
model = ADANet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
wrapper_model = ADANet.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    # NOTE(review): the DeepClustering / DANet cells pass `model.threshold`;
    # confirm that the literal 40 is intentional for ADANet.
    output = wrapper_model(mixture_batch, threshold=40, n_sources=n_sources)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained LSTMTasNet model, called directly on the waveform.
model = LSTMTasNet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = model(mixture_batch)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained ConvTasNet model, called directly on the waveform.
model = ConvTasNet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = model(mixture_batch)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained DPRNNTasNet model, called directly on the waveform.
model = DPRNNTasNet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = model(mixture_batch)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained DPTNet model, called directly on the waveform.
model = DPTNet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = model(mixture_batch)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained SepFormer model, called directly on the waveform.
model = SepFormer.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = model(mixture_batch)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

## 3 speakers

In [None]:
# Three-speaker setup: the mixture is the element-wise sum of all three utterances,
# accumulated in the same left-to-right order as a chained `+`.
n_sources = 3
mixture = torch.add(torch.add(waveform_aew, waveform_axb), waveform_bdl)

# Audio player for the mixture that the models below receive.
display(ipd.Audio(mixture, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained DeepClustering model, wrapped in its
# TimeDomainWrapper (built from the model's own n_fft / hop_length / window_fn).
model = DeepClustering.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
wrapper_model = DeepClustering.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = wrapper_model(mixture_batch, threshold=model.threshold, n_sources=n_sources, iter_clustering=None)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained DANet model, wrapped in its
# TimeDomainWrapper (built from the model's own n_fft / hop_length / window_fn).
model = DANet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
wrapper_model = DANet.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = wrapper_model(mixture_batch, threshold=model.threshold, n_sources=n_sources, iter_clustering=None)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained FixedAttractorDANet model, wrapped in its
# TimeDomainWrapper (built from the model's own n_fft / hop_length / window_fn).
model = FixedAttractorDANet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
wrapper_model = FixedAttractorDANet.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = wrapper_model(mixture_batch)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained ADANet model, wrapped in its
# TimeDomainWrapper (built from the model's own n_fft / hop_length / window_fn).
model = ADANet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
wrapper_model = ADANet.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    # NOTE(review): the DeepClustering / DANet cells pass `model.threshold`;
    # confirm that the literal 40 is intentional for ADANet.
    output = wrapper_model(mixture_batch, threshold=40, n_sources=n_sources)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained LSTMTasNet model, called directly on the waveform.
model = LSTMTasNet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = model(mixture_batch)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained ConvTasNet model, called directly on the waveform.
model = ConvTasNet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = model(mixture_batch)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained DPRNNTasNet model, called directly on the waveform.
model = DPRNNTasNet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = model(mixture_batch)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained DPTNet model, called directly on the waveform.
model = DPTNet.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = model(mixture_batch)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))

In [None]:
# Separate `mixture` with the pretrained SepFormer model, called directly on the waveform.
model = SepFormer.build_from_pretrained(task="wsj0-mix", sample_rate=SAMPLE_RATE_WSJ0, n_sources=n_sources)
model.eval()

# Add a leading (presumably batch) axis. Deliberately not named `input`, which
# would shadow the builtin `input()` for the rest of the kernel session.
mixture_batch = mixture.unsqueeze(dim=0)
with torch.no_grad():
    output = model(mixture_batch)

# Drop the leading axis again and cut dim 0 into `n_sources` chunks of size 1.
output = output.squeeze(dim=0)
estimated = torch.split(output, [1]*n_sources, dim=0)

# One audio player per estimated source.
for estimated_source in estimated:
    display(ipd.Audio(estimated_source, rate=SAMPLE_RATE_WSJ0))