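# Time-Domain Wrappers for Time-Frequency-Domain Models
Each pretrained model listed below is wrapped with its `TimeDomainWrapper`, fed a random waveform, and the input and output tensor sizes are printed: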
- DeepClustering
- DANet
- ADANet
- MMDenseLSTM
- Open-Unmix
- CrossNet-Open-Unmix
- D3Net

In [None]:
%%bash
# Clone the repository into the current working directory. The directory check makes
# the cell safe to re-run: a second plain `git clone` would abort with
# "destination path ... already exists".
if [ ! -d DNN-based_source_separation ]; then
    git clone https://github.com/tky823/DNN-based_source_separation.git
fi

In [None]:
import os
import sys

# The repository is cloned into the current working directory (on Colab that is
# /content). Its `src` directory is prepended rather than appended so that the
# generically named `models` package resolves to this repository even if a
# same-named package is installed. A re-run does not add a duplicate entry.
src_dir = os.path.join(os.getcwd(), "DNN-based_source_separation", "src")
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

In [None]:
import torch

In [None]:
from models.deep_clustering import DeepEmbedding
from models.danet import DANet, FixedAttractorDANet
from models.adanet import ADANet
from models.mm_dense_lstm import MMDenseLSTM, ParallelMMDenseLSTM
from models.umx import OpenUnmix, ParallelOpenUnmix
from models.xumx import CrossNetOpenUnmix
from models.d3net import D3Net, ParallelD3Net

In [None]:
# Fix the global RNG seed so the random input tensors created below are reproducible.
torch.random.manual_seed(111)

In [None]:
batch_size: int = 4  # number of random examples passed through each wrapper at once

In [None]:
T: int = 6400  # signal length in samples: 0.8 s at the 8 kHz rate of the wsj0-mix models

In [None]:
# Pretrained DeepEmbedding (deep clustering) model for wsj0-mix at 8 kHz, wrapped so it
# is called on a waveform; the wrapper reuses the model's own STFT settings.
model = DeepEmbedding.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=2)
wrapper_model = DeepEmbedding.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()  # evaluation mode (affects layers such as dropout / batch norm)

# Named `mixture` rather than `input` so the built-in input() is not shadowed.
# Shape (batch_size, 1, T); the singleton axis is presumably the single audio channel.
mixture = torch.randn(batch_size, 1, T)

# Run with the source count the model was built with (2) and with a different one (3).
for n_sources in (2, 3):
    with torch.no_grad():
        output = wrapper_model(mixture, n_sources=n_sources)
    print(mixture.size(), output.size())

In [None]:
# Pretrained DANet for wsj0-mix at 8 kHz, wrapped so it is called on a waveform;
# the wrapper reuses the model's own STFT settings.
model = DANet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=2)
wrapper_model = DANet.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()  # evaluation mode (affects layers such as dropout / batch norm)

# Named `mixture` rather than `input` so the built-in input() is not shadowed.
# Shape (batch_size, 1, T); the singleton axis is presumably the single audio channel.
mixture = torch.randn(batch_size, 1, T)

# Run with the source count the model was built with (2) and with a different one (3).
for n_sources in (2, 3):
    with torch.no_grad():
        output = wrapper_model(mixture, n_sources=n_sources)
    print(mixture.size(), output.size())

In [None]:
# Pretrained fixed-attractor DANet for wsj0-mix at 8 kHz, wrapped so it is called on a
# waveform; the wrapper reuses the model's own STFT settings.
model = FixedAttractorDANet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=2)
wrapper_model = FixedAttractorDANet.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
# Switch to evaluation mode like every other wrapper here. Without this the module stays
# in PyTorch's default training mode (unless build_from_pretrained already changed it),
# so layers such as dropout would make the output non-deterministic.
wrapper_model.eval()

# Named `mixture` rather than `input` so the built-in input() is not shadowed.
# Shape (batch_size, 1, T); the singleton axis is presumably the single audio channel.
mixture = torch.randn(batch_size, 1, T)
with torch.no_grad():
    # This variant is called with the model's stored threshold instead of n_sources.
    output = wrapper_model(mixture, threshold=model.threshold)
print(mixture.size(), output.size())

In [None]:
# Pretrained ADANet for wsj0-mix at 8 kHz, wrapped so it is called on a waveform;
# the wrapper reuses the model's own STFT settings.
model = ADANet.build_from_pretrained(task="wsj0-mix", sample_rate=8000, n_sources=2)
wrapper_model = ADANet.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()  # evaluation mode (affects layers such as dropout / batch norm)

# Named `mixture` rather than `input` so the built-in input() is not shadowed.
# Shape (batch_size, 1, T); the singleton axis is presumably the single audio channel.
mixture = torch.randn(batch_size, 1, T)

# Run with the source count the model was built with (2) and with a different one (3).
for n_sources in (2, 3):
    with torch.no_grad():
        output = wrapper_model(mixture, n_sources=n_sources)
    print(mixture.size(), output.size())

In [None]:
# Channel counts and signal length for the musdb18 (44.1 kHz) models that follow.
mono_channels = 1
stereo_channels = 2
T = 5 * 44100  # 5 seconds of audio at 44.1 kHz

In [None]:
# Pretrained single-target ("vocals") MMDenseLSTM for musdb18 at 44.1 kHz, wrapped so it
# is called on a waveform; the wrapper reuses the model's own STFT settings.
model = MMDenseLSTM.build_from_pretrained(task="musdb18", sample_rate=44100, target="vocals")
wrapper_model = MMDenseLSTM.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()  # evaluation mode (affects layers such as dropout / batch norm)

# Named `mixture` rather than `input` so the built-in input() is not shadowed.
mixture = torch.randn(batch_size, stereo_channels, T)  # (batch_size, 2, T) random stereo signal
with torch.no_grad():
    output = wrapper_model(mixture)
print(mixture.size(), output.size())

In [None]:
# Pretrained ParallelMMDenseLSTM for musdb18 at 44.1 kHz (built without a `target`,
# presumably covering several targets at once; confirm), wrapped for waveform input.
model = ParallelMMDenseLSTM.build_from_pretrained(task="musdb18", sample_rate=44100)
wrapper_model = ParallelMMDenseLSTM.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()  # evaluation mode (affects layers such as dropout / batch norm)

# Named `mixture` rather than `input` so the built-in input() is not shadowed.
# Shape (batch_size, 1, stereo_channels, T): unlike the single-target wrappers, this one
# is fed an extra singleton axis (meaning not visible here; TODO confirm against the wrapper).
mixture = torch.randn(batch_size, 1, stereo_channels, T)
with torch.no_grad():
    output = wrapper_model(mixture)
print(mixture.size(), output.size())

In [None]:
# Pretrained single-target ("vocals") Open-Unmix for musdb18 at 44.1 kHz, wrapped so it
# is called on a waveform; the wrapper reuses the model's own STFT settings.
model = OpenUnmix.build_from_pretrained(task="musdb18", sample_rate=44100, target="vocals")
wrapper_model = OpenUnmix.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()  # evaluation mode (affects layers such as dropout / batch norm)

# Named `mixture` rather than `input` so the built-in input() is not shadowed.
mixture = torch.randn(batch_size, stereo_channels, T)  # (batch_size, 2, T) random stereo signal
with torch.no_grad():
    output = wrapper_model(mixture)
print(mixture.size(), output.size())

In [None]:
# Pretrained ParallelOpenUnmix for musdb18 at 44.1 kHz (built without a `target`,
# presumably covering several targets at once; confirm), wrapped for waveform input.
model = ParallelOpenUnmix.build_from_pretrained(task="musdb18", sample_rate=44100)
wrapper_model = ParallelOpenUnmix.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()  # evaluation mode (affects layers such as dropout / batch norm)

# Named `mixture` rather than `input` so the built-in input() is not shadowed.
# Shape (batch_size, 1, stereo_channels, T): unlike the single-target wrappers, this one
# is fed an extra singleton axis (meaning not visible here; TODO confirm against the wrapper).
mixture = torch.randn(batch_size, 1, stereo_channels, T)
with torch.no_grad():
    output = wrapper_model(mixture)
print(mixture.size(), output.size())

In [None]:
# Pretrained CrossNet-Open-Unmix for musdb18 at 44.1 kHz (built without a `target`),
# wrapped so it is called on a waveform; the wrapper reuses the model's STFT settings.
model = CrossNetOpenUnmix.build_from_pretrained(task="musdb18", sample_rate=44100)
wrapper_model = CrossNetOpenUnmix.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()  # evaluation mode (affects layers such as dropout / batch norm)

# Named `mixture` rather than `input` so the built-in input() is not shadowed.
# Shape (batch_size, 1, stereo_channels, T), the same layout fed to the Parallel* wrappers;
# the singleton axis's meaning is not visible here (TODO confirm against the wrapper).
mixture = torch.randn(batch_size, 1, stereo_channels, T)
with torch.no_grad():
    output = wrapper_model(mixture)
print(mixture.size(), output.size())

In [None]:
# Pretrained single-target ("vocals") D3Net for musdb18 at 44.1 kHz, wrapped so it is
# called on a waveform; the wrapper reuses the model's own STFT settings.
model = D3Net.build_from_pretrained(task="musdb18", sample_rate=44100, target="vocals")
wrapper_model = D3Net.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()  # evaluation mode (affects layers such as dropout / batch norm)

# Named `mixture` rather than `input` so the built-in input() is not shadowed.
mixture = torch.randn(batch_size, stereo_channels, T)  # (batch_size, 2, T) random stereo signal
with torch.no_grad():
    output = wrapper_model(mixture)
print(mixture.size(), output.size())

In [None]:
# Pretrained ParallelD3Net for musdb18 at 44.1 kHz (built without a `target`,
# presumably covering several targets at once; confirm), wrapped for waveform input.
model = ParallelD3Net.build_from_pretrained(task="musdb18", sample_rate=44100)
wrapper_model = ParallelD3Net.TimeDomainWrapper(model, n_fft=model.n_fft, hop_length=model.hop_length, window_fn=model.window_fn)
wrapper_model.eval()  # evaluation mode (affects layers such as dropout / batch norm)

# Named `mixture` rather than `input` so the built-in input() is not shadowed.
# Shape (batch_size, 1, stereo_channels, T): unlike the single-target wrappers, this one
# is fed an extra singleton axis (meaning not visible here; TODO confirm against the wrapper).
mixture = torch.randn(batch_size, 1, stereo_channels, T)
with torch.no_grad():
    output = wrapper_model(mixture)
print(mixture.size(), output.size())