In [1]:
import math

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn import metrics
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

In [2]:
# Synthetic classification data: each row of X is later treated as a
# "waveform" and each label in y as its "speaker".
NUM_CLASSES = 10
X, y = make_classification(
    n_samples=6300,
    n_features=16000,
    n_informative=1000,
    n_classes=NUM_CLASSES,
    random_state=914,
)
# float32 features and int64 labels.
X, y = X.astype(np.float32), y.astype(np.int64)

In [3]:
# Display the generated feature matrix (NumPy abbreviates the repr of large arrays).
X

array([[ 0.75105166, -0.15409301, -0.49907622, ..., -0.5336411 ,
        -0.62046456,  0.13240224],
       [-0.9926991 , -0.01308838, -0.7677303 , ...,  1.9629287 ,
        -1.9373821 , -1.2822019 ],
       [-0.90256   ,  0.36578926,  0.13826247, ...,  1.5470697 ,
         0.17384359, -0.80612546],
       ...,
       [ 1.7498939 , -0.7487113 ,  0.14626776, ..., -1.1909691 ,
        -1.5861615 , -0.5214099 ],
       [ 0.06106601, -0.44239494,  2.2134423 , ..., -1.2444557 ,
         0.09982967,  0.296488  ],
       [ 0.5877269 , -0.30185723, -1.0686728 , ..., -0.5939525 ,
         2.1231668 , -0.02211769]], dtype=float32)

In [4]:
# Display the integer class labels that accompany X.
y

array([8, 7, 3, ..., 1, 2, 9], dtype=int64)

In [5]:
def joint_different_speakers(waveforms, speakers, num_mix):
    """
    Pair every waveform with randomly chosen waveforms of other speakers.

    For each index ``i``, ``num_mix`` distinct indices are drawn (without
    replacement, from NumPy's global random state) among the samples whose
    speaker label differs from ``speakers[i]``; one pair is yielded per draw.

    Parameters
    ----------
    waveforms: sequence
        Indexable collection of waveforms, aligned with ``speakers``.
    speakers: sequence
        Speaker label of each waveform.
    num_mix: int
        Number of interfering waveforms drawn for each target waveform.

    Yields
    ------
    trg_wav, itf_wav, trg_spk, itf_spk
        Target waveform, interfering waveform, and their speaker labels.

    Raises
    ------
    ValueError
        If fewer than ``num_mix`` waveforms belong to other speakers.
    """
    # Convert once so the "different speaker" test is a vectorised comparison
    # instead of a Python-level scan of all labels for every waveform.
    speaker_labels = np.asarray(speakers)
    for i, wav in enumerate(waveforms):
        current_speaker = speakers[i]
        different_speakers_idx = np.flatnonzero(speaker_labels != current_speaker)
        if len(different_speakers_idx) < num_mix:
            raise ValueError(
                f"Sample {i} has only {len(different_speakers_idx)} waveforms from "
                f"other speakers, but num_mix={num_mix} were requested."
            )
        select_idx = np.random.choice(different_speakers_idx, num_mix, replace=False)
        for j in select_idx:
            yield wav, waveforms[j], current_speaker, speakers[j]

def mix_speakers_by_snr(waveforms, speakers, num_mix, snr):
    """
    Mix each waveform with waveforms of other speakers at a fixed SNR.

    Pairs come from ``joint_different_speakers``. The interfering waveform is
    rescaled so that the target-to-interferer power ratio equals ``snr`` dB and
    is then added to the target waveform.

    Parameters
    ----------
    waveforms: sequence
        Indexable collection of waveforms, aligned with ``speakers``.
    speakers: sequence
        Speaker label of each waveform.
    num_mix: int
        Number of mixtures generated per target waveform.
    snr: float
        Target-to-interferer signal-to-noise ratio in decibels.

    Returns
    -------
    list
        One ``[mix_wav, itf_spk, trg_spk]`` entry per generated mixture.
    """
    generator = joint_different_speakers(waveforms, speakers, num_mix)
    # Linear power ratio for the requested SNR; loop-invariant, so computed once.
    snr_linear = 10 ** (snr / 10)
    mixed_data = []
    for trg_wav, itf_wav, trg_spk, itf_spk in tqdm(generator, total=len(speakers)*num_mix):
        itf_wav_power = np.mean(np.square(itf_wav))
        if itf_wav_power > 0:
            # Power the interferer must have after scaling to hit the SNR.
            itf_spk_power = np.mean(np.square(trg_wav)) / snr_linear
            scale = np.sqrt(itf_spk_power / itf_wav_power)
        else:
            # A silent interferer cannot be scaled to any power; dividing by its
            # zero power would fill the mixture with NaN/inf, so add nothing.
            scale = 0.0

        # Mix two speakers based on given snr
        mix_wav = trg_wav + scale * itf_wav

        mixed_data.append([mix_wav, itf_spk, trg_spk])
    return mixed_data

In [6]:
# Seed NumPy's global RNG: the interferers are drawn with np.random.choice,
# so without a seed every run would produce a different set of mixtures.
np.random.seed(914)
X_mixed_data = mix_speakers_by_snr(X, y, num_mix=10, snr=5)
mix_wavs = np.stack([data[0] for data in X_mixed_data])
itf_spks = np.stack([data[1] for data in X_mixed_data])
trg_spks = np.stack([data[2] for data in X_mixed_data])
# as_tensor reuses the NumPy buffer when the dtype already matches, avoiding a
# second full copy of the mixture matrix (the tensor then shares memory with
# mix_wavs).
X_mix_wavs = torch.as_tensor(mix_wavs, dtype=torch.float)

# Two-hot targets: both the interfering and the target speaker are marked as
# present. float32 keeps the targets in the same dtype as the model inputs.
y_onehot = np.zeros((len(mix_wavs), NUM_CLASSES), dtype=np.float32)
sample_rows = np.arange(len(mix_wavs))
y_onehot[sample_rows, itf_spks] = 1
y_onehot[sample_rows, trg_spks] = 1
y_onehot = torch.from_numpy(y_onehot)

  0%|          | 0/63000 [00:00<?, ?it/s]

In [7]:
# Hold out 20% of the mixtures for evaluation; the fixed seed makes the split repeatable.
# NOTE(review): the split is made on mixtures, so the same source waveform can
# contribute to mixtures on both sides of the split — confirm this is intended.
X_train, X_test, y_train, y_test = train_test_split(
    X_mix_wavs,
    y_onehot,
    test_size=0.2,
    random_state=914,
)

In [8]:
class SVD(nn.Module):
    """
    Singular value decomposition layer

    Computes the thin SVD ``A = U @ S @ V^T`` of real matrices through the
    eigendecomposition of the symmetric matrix ``A^T A``. Singular values are
    placed on the diagonal of ``S`` in descending order.

    Examples
    --------
    >>> A = torch.rand(126, 100, 20)
    >>> U, S, V = SVD()(A)
    >>> A_ = torch.matmul(U, torch.matmul(S, V.transpose(-1, -2)))
    >>> print(torch.dist(A_, A))
    """
    def __init__(self):
        super(SVD, self).__init__()

    def forward(self, A):
        """
        Inputs
        ------
        A: [b, m, n]

        Outputs
        -------
        U: [b, m, n]
        S: [b, n, n]
        V: [b, n, n]
        """
        return self.svd_(A)

    @staticmethod
    def svd_(A):
        """
        Parameters
        ----------
        A: torch.FloatTensor
            A real tensor of shape [b, m, n]; an unbatched [m, n] matrix or
            extra leading batch dimensions work as well.

        Returns
        -------
        U: [b, m, n]
            Left singular vectors. Columns that belong to a zero singular
            value are set to zero.
        S: [b, n, n]
            Diagonal matrices holding the singular values, largest first.
        V: [b, n, n]
            Right singular vectors (as columns).

        References
        ----------
        1. https://www.youtube.com/watch?v=pSbafxDHdgE&t=205s
        2. https://www2.math.ethz.ch/education/bachelor/lectures/hs2014/other/linalg_INFK/svdneu.pdf
        """
        ATA = torch.matmul(A.transpose(-1, -2), A)
        # A^T A is symmetric, so use the symmetric solver: it returns real
        # eigenvalues in ascending order with orthonormal eigenvectors.
        lv, vv = torch.linalg.eigh(ATA)
        # Reverse to the usual SVD convention (largest singular value first).
        lv = torch.flip(lv, dims=[-1])
        V = torch.flip(vv, dims=[-1])
        # Round-off can make eigenvalues of a rank-deficient A^T A slightly
        # negative; clamp them so the square root stays real.
        sv = torch.sqrt(torch.clamp(lv, min=0))
        S = torch.diag_embed(sv)
        # U = A V S^{-1}. S is diagonal, so divide column-wise rather than
        # inverting it, and skip zero singular values (S would be singular).
        AV = torch.matmul(A, V)
        nonzero = sv > 0
        safe_sv = torch.where(nonzero, sv, torch.ones_like(sv))
        U = torch.where(nonzero.unsqueeze(-2), AV / safe_sv.unsqueeze(-2), torch.zeros_like(AV))
        return U, S, V

In [9]:
class DemixingNet(nn.Module):
    """
    Two-layer feed-forward network that maps a waveform to one logit per class.

    Parameters
    ----------
    in_features: int
        Length of each input waveform.
    hidden_features: int
        Width of the hidden layer.
    num_classes: int
        Number of output logits.
    """

    def __init__(self, in_features=16000, hidden_features=100, num_classes=10):
        super(DemixingNet, self).__init__()
        self.ff1 = nn.Linear(in_features, hidden_features)
        self.ff2 = nn.Linear(hidden_features, num_classes)

    def forward(self, wavs):
        """
        Inputs
        ------
        wavs: [b, in_features]

        Outputs
        -------
        logits: [b, num_classes]
            Raw scores; no sigmoid is applied here.
        """
        x = self.ff1(wavs)
        x = self.ff2(F.relu(x, inplace=True))
        return x

In [10]:
# Sanity-check the sizes of the train and test tensors produced by the split.
X_train.shape, X_test.shape

(torch.Size([50400, 16000]), torch.Size([12600, 16000]))

In [11]:
# Fix the seed so the randomly initialised weights are the same on every run.
torch.manual_seed(914)

model = DemixingNet()
criterion = nn.BCEWithLogitsLoss()
optimiser = optim.SGD(model.parameters(), lr=1e-3)

# Keep the targets in float32 so they match the dtype of the logits.
train_targets = y_train.to(torch.float32)

# Each iteration is one gradient step computed on the whole training set.
for epoch in range(20):
    optimiser.zero_grad()
    output = model(X_train)
    loss = criterion(output, train_targets)
    loss.backward()
    optimiser.step()
    print(f'Loss: {loss.item():.4f}')

Loss: 1.0450
Loss: 0.9906
Loss: 0.9517
Loss: 0.9229
Loss: 0.9005
Loss: 0.8822
Loss: 0.8666
Loss: 0.8528
Loss: 0.8404
Loss: 0.8291
Loss: 0.8185
Loss: 0.8088
Loss: 0.7996
Loss: 0.7910
Loss: 0.7830
Loss: 0.7754
Loss: 0.7683
Loss: 0.7616
Loss: 0.7553
Loss: 0.7494


In [14]:
def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):
    '''
    Compute the Hamming score (a.k.a. label-based accuracy) for the multi-label case
    https://stackoverflow.com/q/32239577/395857

    For every sample the score is the size of the intersection of the true and
    predicted label sets divided by the size of their union; a sample with no
    true and no predicted labels scores 1.

    Parameters
    ----------
    y_true: array-like of shape [n_samples, n_labels]
        Ground-truth indicator matrix; non-zero entries mark present labels.
    y_pred: array-like of shape [n_samples, n_labels]
        Predicted indicator matrix; non-zero entries mark predicted labels.
    normalize: bool
        If True, return the (weighted) mean of the per-sample scores;
        otherwise return their (weighted) sum.
    sample_weight: array-like of shape [n_samples] or None
        Optional per-sample weights.

    Returns
    -------
    float
        Aggregated score over all samples.
    '''
    true_mask = np.asarray(y_true) != 0
    pred_mask = np.asarray(y_pred) != 0
    intersection = np.logical_and(true_mask, pred_mask).sum(axis=1)
    union = np.logical_or(true_mask, pred_mask).sum(axis=1)
    # Start from 1 so samples whose union is empty keep a perfect score, and
    # only divide where the union is non-empty.
    scores = np.ones(len(union), dtype=float)
    np.divide(intersection, union, out=scores, where=union > 0)
    if normalize:
        return np.average(scores, weights=sample_weight)
    if sample_weight is not None:
        return np.dot(scores, sample_weight)
    return scores.sum()

In [17]:
# Score the training set. no_grad avoids building an autograd graph for a
# forward pass that is only used for inference.
with torch.no_grad():
    y_hat = torch.sigmoid(model(X_train))
y_hat = y_hat.numpy()
# A class is predicted as present when its sigmoid probability reaches 0.5.
y_pred = (y_hat >= 0.5).astype(int)
print(hamming_score(y_train.numpy(), y_pred))
print(metrics.accuracy_score(y_train.numpy(), y_pred, normalize=True, sample_weight=None))

0.16210112748803224
0.004563492063492064


In [18]:
# Score the held-out set. no_grad avoids building an autograd graph for a
# forward pass that is only used for inference.
with torch.no_grad():
    y_hat = torch.sigmoid(model(X_test))
y_hat = y_hat.numpy()
# A class is predicted as present when its sigmoid probability reaches 0.5.
y_pred = (y_hat >= 0.5).astype(int)
print(hamming_score(y_test.numpy(), y_pred))
print(metrics.accuracy_score(y_test.numpy(), y_pred, normalize=True, sample_weight=None))

0.16174864575459813
0.004126984126984127
