# LD-PSDTF (Log-Determinant Positive Semidefinite Tensor Factorization)

## 因子分解

In [None]:
%%shell
# Clone the repository; later cells change into its egs/psdtf-example/ld-psdtf
# directory and import code from its src folder.
git clone https://github.com/tky823/audio_source_separation.git

In [None]:
# Work from the example directory so the relative paths in later cells resolve.
%cd "/content/audio_source_separation/egs/psdtf-example/ld-psdtf"

In [None]:
%%shell
# Download the sample archive, unpack it under /tmp, and copy its WAV files
# into the dataset folder that the notebook reads from.
wget "http://sap.ist.i.kyoto-u.ac.jp/members/yoshii/codes/LD-PSDTF.zip"
unzip LD-PSDTF.zip -d /tmp
mkdir -p ../../../dataset/sample-signal/
cp /tmp/audio/*.wav ../../../dataset/sample-signal/

In [None]:
import sys
# Add the src directory (relative to the current working directory) to the
# module search path so that `algorithm.psdtf` can be imported.
sys.path.append("../../../src")

In [None]:
import numpy as np
import scipy.signal as ss
import soundfile as sf
import IPython.display as ipd
import matplotlib.pyplot as plt

In [None]:
from algorithm.psdtf import LDPSDTF

In [None]:
# Raise the default figure resolution for every plot created afterwards.
plt.rcParams['figure.dpi'] = 200

In [None]:
# STFT analysis parameters, in samples: segment length and hop between segments.
fft_size = 512
hop_size = 320
# Bin count of the one-sided spectrum of an fft_size-point transform.
n_bins = 1 + fft_size // 2

In [None]:
# Load the mixture recording: sf.read returns the sample data and its sample rate.
x, sr = sf.read("../../../dataset/sample-signal/mixture.wav")

### 混合信号

In [None]:
# Audio player for the loaded mixture (rendered because it is the cell's last expression).
ipd.Audio(x, rate=sr)

In [None]:
# STFT of the mixture; only the complex spectrogram (third return value) is kept.
overlap = fft_size - hop_size
_, _, X = ss.stft(x, nperseg=fft_size, noverlap=overlap)
# Per-frame outer product X X^H: the indexing below treats X as 2-D and yields a
# 3-D tensor whose last axis is X's last axis (assumed to be time — TODO confirm).
X_conj = X.conj()
XX = X[:, None, :] * X_conj[None, :, :]

In [None]:
# Number of bases requested from LDPSDTF; also the number of signals resynthesized below.
n_basis = 3

### LD-PSDTFの実行

In [None]:
# Fix NumPy's global seed for repeatability
# (presumably LDPSDTF initializes from np.random — TODO confirm).
np.random.seed(111)
# Model with the library's default update algorithm.
psdtf = LDPSDTF(n_basis=n_basis)

In [None]:
# Run 100 iterations on the outer-product tensor XX; the call returns the
# basis and activation arrays used by the cells below.
basis, activation = psdtf(XX, iteration=100)

In [None]:
# Plot the loss history stored on the fitted model.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(psdtf.loss, color='black')
plt.show()

### PSDTF後の各信号

In [None]:
# Move the last axis of XX to the front so np.linalg.inv inverts one matrix per entry.
XX_stacked = XX.transpose(2, 0, 1)
# Small diagonal loading: each X X^H is an outer product, hence singular on its own.
diagonal_loading = 1e-12 * np.eye(n_bins)
inv_XX = np.linalg.inv(XX_stacked + diagonal_loading)

In [None]:
# Resynthesize one time-domain signal per basis and offer each for playback.
X_columns = X.transpose(1, 0)[:, :, np.newaxis]  # one column vector per entry of X's last axis
overlap = fft_size - hop_size
for idx in range(n_basis):
    # Basis matrix idx scaled by its activation along the last axis.
    YY = basis[:, :, idx, np.newaxis] * activation[np.newaxis, np.newaxis, idx, :]
    # Keep the left-to-right product order: (YY @ inv_XX) first, then apply to X.
    filter_matrices = YY.transpose(2, 0, 1) @ inv_XX
    Y = filter_matrices @ X_columns
    Y = Y.squeeze(axis=2).transpose(1, 0)
    _, estimated_signal = ss.istft(Y, nperseg=fft_size, noverlap=overlap)
    # Peak-normalize before playback.
    peak = np.abs(estimated_signal).max()
    estimated_signal = estimated_signal / peak
    display(ipd.Audio(estimated_signal, rate=sr))

In [None]:
# Show, for each basis, the log-power (dB) of its modelled matrix at one frame.
# The frame index is clamped so inputs with fewer than 101 frames do not raise IndexError.
frame_idx = min(100, activation.shape[-1] - 1)
for idx in range(n_basis):
    # Only the displayed frame is needed, so scale the basis matrix by that single
    # activation instead of building the full 3-D tensor for every frame.
    YY_frame = basis[:, :, idx] * activation[idx, frame_idx]
    estimated_power = np.abs(YY_frame)**2
    # Floor the power so log10 never receives zero.
    estimated_power[estimated_power < 1e-24] = 1e-24
    log_spectrogram = 10 * np.log10(estimated_power)
    # Axis limits come from the plotted matrix itself; the notebook-level n_bins is left untouched.
    n_rows, n_cols = log_spectrogram.shape

    plt.figure(figsize=(4, 4))
    plt.pcolormesh(log_spectrogram, cmap='jet')
    plt.xlim(0, n_cols)
    plt.ylim(0, n_rows)
    plt.show()

## 更新アルゴリズム
- `'em'`: Expectation-maximization アルゴリズム
- `'mm'`: Majorization-minimization アルゴリズム

### EMアルゴリズム

In [None]:
# Fit with the EM update rule, using the same seed as the MM run for a fair comparison.
np.random.seed(111)
psdtf_em = LDPSDTF(n_basis=n_basis, algorithm='em')
# Feed the outer-product tensor XX (not the 2-D STFT X), matching the input
# given to LDPSDTF in the first factorization cell.
basis, activation = psdtf_em(XX, iteration=50)

### MMアルゴリズム

In [None]:
# Fit with the MM update rule, using the same seed as the EM run for a fair comparison.
np.random.seed(111)
psdtf_mm = LDPSDTF(n_basis=n_basis, algorithm='mm')
# Feed the outer-product tensor XX (not the 2-D STFT X), matching the input
# given to LDPSDTF in the first factorization cell.
basis, activation = psdtf_mm(XX, iteration=50)

In [None]:
# Overlay the loss histories of the two update rules on one set of axes.
comparison_fig, comparison_ax = plt.subplots()
loss_curves = (
    (psdtf_em.loss, 'mediumvioletred', 'EM algorithm'),
    (psdtf_mm.loss, 'mediumblue', 'MM algorithm'),
)
for loss_values, line_color, line_label in loss_curves:
    comparison_ax.plot(loss_values, color=line_color, label=line_label)
comparison_ax.legend()
plt.show()