In [None]:
%%shell
# Fetch the DNN-based_source_separation repository and install the tutorial
# dependencies. The clone is skipped when the checkout already exists, so this
# cell can be re-run safely; `set -e` makes any failing step fail the cell.
set -e

REPO_DIR="/content/DNN-based_source_separation"

# Clone into the same absolute path used below (and by sys.path in Python),
# independent of the shell's current working directory.
if [ ! -d "${REPO_DIR}" ]; then
    git clone https://github.com/tky823/DNN-based_source_separation.git "${REPO_DIR}"
fi

# Build environment
cd "${REPO_DIR}/egs/tutorials"

pip install -r requirements.txt -q

In [None]:
import sys

# Make the repository's `src` directory importable (needed for `models.conv_tasnet`).
# Guarded so that re-running this cell does not append duplicate entries.
SRC_DIR = "/content/DNN-based_source_separation/src"
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

In [None]:
# Use a large default font so labels stay legible on the tall basis figures.
plt.rcParams.update({'font.size': 20})

In [None]:
import torch

In [None]:
from models.conv_tasnet import ConvTasNet

In [None]:
def show_basis(basis, normalized=True):
    """Plot a set of filterbank bases next to their magnitude spectra.

    Rows are reordered by the rFFT bin at which each basis has its largest
    magnitude, so both panels read from low to high dominant frequency.

    Args:
        basis (torch.Tensor): 2-D real tensor with one basis per row; the FFT is
            taken along dim 1. Presumably (n_basis, kernel_size) as produced by
            ``get_basis().detach().squeeze(dim=1)`` -- TODO confirm against the model.
        normalized (bool): If True, scale each row of both panels so its maximum
            absolute value is 1. All-zero rows are left as zeros instead of NaN.

    Returns:
        None. The figure is displayed with ``plt.show()`` and then closed.

    Raises:
        ValueError: If ``basis`` is not 2-D.
    """
    if basis.dim() != 2:
        raise ValueError(f"basis must be 2-D, got shape {tuple(basis.shape)}")

    def _row_peak(x):
        # Per-row max |x| (keepdim for broadcasting); zero peaks are replaced by 1
        # so all-zero rows stay zero rather than becoming 0 / 0 = NaN.
        peak, _ = torch.max(torch.abs(x), dim=1, keepdim=True)
        return torch.where(peak > 0, peak, torch.ones_like(peak))

    # Detached CPU copy so tensors with requires_grad or on an accelerator can
    # still be converted for matplotlib.
    basis = basis.detach().cpu()
    n_basis = basis.size(0)

    magnitude = torch.abs(torch.fft.rfft(basis, dim=1))
    peak_bin = torch.argmax(magnitude, dim=1)

    # Many bases share a peak bin, and torch.argsort is not guaranteed to be
    # stable; adding the row index as a tie-breaker makes the order deterministic.
    sort_key = peak_bin * n_basis + torch.arange(n_basis, device=peak_bin.device)
    order = torch.argsort(sort_key)

    sorted_basis = basis[order]
    # Permuting rows commutes with the row-wise FFT, so reuse the spectrum.
    sorted_spectrogram = magnitude[order]

    if normalized:
        sorted_basis = sorted_basis / _row_peak(sorted_basis)
        sorted_spectrogram = sorted_spectrogram / _row_peak(sorted_spectrogram)

    vmax_basis = torch.max(torch.abs(sorted_basis)).item()
    vmax_spectrogram = torch.max(sorted_spectrogram).item()

    fig, axes = plt.subplots(1, 2, figsize=(8, 16))

    mappable = axes[0].pcolormesh(
        sorted_basis.numpy(), cmap='bwr',
        norm=Normalize(vmin=-vmax_basis, vmax=vmax_basis),
    )
    fig.colorbar(mappable, ax=axes[0])
    axes[0].set_title("Basis")
    axes[0].set_xlabel("Time (sample)")
    axes[0].set_ylabel("Basis index (sorted by peak bin)")

    # The magnitude is non-negative; symmetric limits keep 0 mapped to white,
    # matching the left panel, so only the white-to-red half of 'bwr' is used.
    mappable = axes[1].pcolormesh(
        sorted_spectrogram.numpy(), cmap='bwr',
        norm=Normalize(vmin=-vmax_spectrogram, vmax=vmax_spectrogram),
    )
    fig.colorbar(mappable, ax=axes[1])
    axes[1].set_title("Spectrum")
    axes[1].set_xlabel("Frequency bin")

    fig.tight_layout()

    plt.show()
    # Close this specific figure (not merely the "current" one) to free memory.
    plt.close(fig)

In [None]:
# Build the pretrained "librispeech" Conv-TasNet (16 kHz, 2 sources) and extract
# the encoder and decoder bases as detached tensors with dim 1 squeezed
# (a no-op unless that dim has size 1).
model = ConvTasNet.build_from_pretrained(
    task="librispeech",
    sample_rate=16000,
    n_sources=2,
)
enc_basis, dec_basis = (
    filterbank.get_basis().detach().squeeze(dim=1)
    for filterbank in (model.encoder, model.decoder)
)

In [None]:
# Encoder bases, ordered by dominant frequency and normalized per row.
show_basis(enc_basis, normalized=True)

In [None]:
# Decoder bases, ordered by dominant frequency and normalized per row.
show_basis(dec_basis, normalized=True)