In [2]:
import pandas as pd
import librosa
import torch


In [3]:
!cp "TasNet/src/tasnet.py" .
import TasNet

In [4]:
!cp "TasNet/src/pit_criterion.py" .
import pit_criterion

In [5]:
import json
import os

import numpy as np
import torch
import torch.utils.data as data

import librosa


class AudioDataset(data.Dataset):

    def __init__(self, json_dir, batch_size,
                 sample_rate=8000, L=int(8000*0.005)):
        """
        Args:
            json_dir: directory including mix.json, s1.json and s2.json
        xxx_infos is a list and each item is a tuple (wav_file, #samples)
        """
        super(AudioDataset, self).__init__()
        mix_json = os.path.join(json_dir, 'mix.json')
        s1_json = os.path.join(json_dir, 's1.json')
        s2_json = os.path.join(json_dir, 's2.json')
        with open(mix_json, 'r') as f:
            mix_infos = json.load(f)
            #mix_infos = mix_infos[:10000]
        with open(s1_json, 'r') as f:
            s1_infos = json.load(f)
            #s1_infos = s1_infos[:10000]
        with open(s2_json, 'r') as f:
            s2_infos = json.load(f)
            #s2_infos = s2_infos[:10000]
        # sort it by #samples (impl bucket)
        def sort(infos): return sorted(
            infos, key=lambda info: int(info[1]), reverse=True)
        sorted_mix_infos = sort(mix_infos)
        sorted_mix_infos = sorted_mix_infos[:10000]
        sorted_s1_infos = sort(s1_infos)
        sorted_s1_infos = sorted_s1_infos[:10000]
        sorted_s2_infos = sort(s2_infos)
        sorted_s2_infos = sorted_s2_infos[:10000]
        # generate minibach infomations
        minibatch = []
        start = 0
        while True:
            end = min(len(sorted_mix_infos), start + batch_size)
            minibatch.append([sorted_mix_infos[start:end],
                              sorted_s1_infos[start:end],
                              sorted_s2_infos[start:end],
                              sample_rate, L])
            if end == len(sorted_mix_infos):
                break
            start = end
        self.minibatch = minibatch

    def __getitem__(self, index):
        return self.minibatch[index]

    def __len__(self):
        return len(self.minibatch)


class AudioDataLoader(data.DataLoader):
    """
    NOTE: just use batchsize=1 here, so drop_last=True makes no sense here.
    """

    def __init__(self, *args, **kwargs):
        super(AudioDataLoader, self).__init__(*args, **kwargs)
        self.collate_fn = _collate_fn


def _collate_fn(batch):
    """
    Args:
        batch: list, len(batch) = 1. See AudioDataset.__getitem__()
    Returns:
        mixtures_pad: B x K x L, torch.Tensor
        ilens : B, torch.Tentor
        sources_pad: B x C x K x L, torch.Tensor
    """
    # batch should be located in list
    assert len(batch) == 1
    batch = load_mixtures_and_sources(batch[0])
    mixtures, sources = batch

    # get batch of lengths of input sequences
    ilens = np.array([mix.shape[0] for mix in mixtures])

    # perform padding and convert to tensor
    pad_value = 0
    mixtures_pad = pad_list([torch.from_numpy(mix).float()
                             for mix in mixtures], pad_value)
    ilens = torch.from_numpy(ilens)
    sources_pad = pad_list([torch.from_numpy(s).float()
                            for s in sources], pad_value)
    # N x K x L x C -> N x C x K x L
    sources_pad = sources_pad.permute((0, 3, 1, 2)).contiguous()
    return mixtures_pad, ilens, sources_pad

In [6]:
def load_mixtures_and_sources(batch):
    """
    Returns:
        mixtures: a list containing B items, each item is K x L np.ndarray
        sources: a list containing B items, each item is K x L x C np.ndarray
        K varies from item to item.
    """
    mixtures, sources = [], []
    mix_infos, s1_infos, s2_infos, sample_rate, L = batch
    # for each utterance
    for mix_info, s1_info, s2_info in zip(mix_infos, s1_infos, s2_infos):
        mix_path = mix_info[0]
        s1_path = s1_info[0]
        s2_path = s2_info[0]
        assert mix_info[1] == s1_info[1] and s1_info[1] == s2_info[1]
        # read wav file
        mix, _ = librosa.load(mix_path, sr=sample_rate)
        s1, _ = librosa.load(s1_path, sr=sample_rate)
        s2, _ = librosa.load(s2_path, sr=sample_rate)
        # Generate inputs and targets
        K = int(np.ceil(len(mix) / L))
        # padding a little. mix_len + K > pad_len >= mix_len
        pad_len = K * L
        pad_mix = np.concatenate([mix, np.zeros([pad_len - len(mix)])])
        pad_s1 = np.concatenate([s1, np.zeros([pad_len - len(s1)])])
        pad_s2 = np.concatenate([s2, np.zeros([pad_len - len(s2)])])
        # reshape
        mix = np.reshape(pad_mix, [K, L])
        s1 = np.reshape(pad_s1, [K, L])
        s2 = np.reshape(pad_s2, [K, L])
        # merge s1 and s2
        s = np.dstack((s1, s2))  # K x L x C, C = 2
        # s = np.transpose(s, (2, 0, 1))  # C x K x L

        mixtures.append(mix)
        sources.append(s)
    return mixtures, sources


def load_mixtures(batch):
    """
    Returns:
        mixtures: a list containing B items, each item is K x L np.ndarray
        filenames: a list containing B strings
        K varies from item to item.
    """
    mixtures, filenames = [], []
    mix_infos, sample_rate, L = batch
    # for each utterance
    for mix_info in mix_infos:
        mix_path = mix_info[0]
        # read wav file
        mix, _ = librosa.load(mix_path, sr=sample_rate)
        # Generate inputs and targets
        K = int(np.ceil(len(mix) / L))
        # padding a little. mix_len + K > pad_len >= mix_len
        pad_len = K * L
        pad_mix = np.concatenate([mix, np.zeros([pad_len - len(mix)])])
        # reshape
        mix = np.reshape(pad_mix, [K, L])
        mixtures.append(mix)
        filenames.append(mix_path)
    return mixtures, filenames


def pad_list(xs, pad_value):
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, * xs[0].size()[1:]).fill_(pad_value)
    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]
    return pad

In [7]:
labeled_dataset = AudioDataset('wav8k/train', batch_size = 1)
labeled_loader = AudioDataLoader(labeled_dataset)

In [8]:
test_dataset = AudioDataset('wav8k/test', batch_size=1)
test_loader = AudioDataLoader(test_dataset)

In [9]:
import torch.utils.data as data
class EvalDataset(data.Dataset):

    def __init__(self, mix_dir, mix_json, batch_size,
                 sample_rate=8000, L=int(8000*0.005)):
        """
        Args:
            mix_dir: directory including mixture wav files
            mix_json: json file including mixture wav files
        """
        super(EvalDataset, self).__init__()
        assert mix_dir != None or mix_json != None
        if mix_dir is not None:
            # Generate mix.json given mix_dir
            #preprocess_one_dir(mix_dir, mix_dir, 'mix',
                             #  sample_rate=sample_rate)
            mix_json = os.path.join(mix_dir, 'mix.json')
        with open(mix_json, 'r') as f:
            mix_infos = json.load(f)
           # mix_infos = mix_infos[:10000]
        # sort it by #samples (impl bucket)
        def sort(infos): return sorted(
            infos, key=lambda info: int(info[1]), reverse=True)
        sorted_mix_infos = sort(mix_infos)
        sorted_mix_infos = sorted_mix_infos[:10000]
        # generate minibach infomations
        minibatch = []
        start = 0
        while True:
            end = min(len(sorted_mix_infos), start + batch_size)
            minibatch.append([sorted_mix_infos[start:end],
                              sample_rate, L])
            if end == len(sorted_mix_infos):
                break
            start = end
        self.minibatch = minibatch

    def __getitem__(self, index):
        return self.minibatch[index]

    def __len__(self):
        return len(self.minibatch)


class EvalDataLoader(data.DataLoader):
    """
    NOTE: just use batchsize=1 here, so drop_last=True makes no sense here.
    """

    def __init__(self, *args, **kwargs):
        super(EvalDataLoader, self).__init__(*args, **kwargs)
        self.collate_fn = _collate_fn_eval


def _collate_fn_eval(batch):
    """
    Args:
        batch: list, len(batch) = 1. See AudioDataset.__getitem__()
    Returns:
        mixtures_pad: B x K x L, torch.Tensor
        ilens : B, torch.Tentor
        filenames: a list contain B strings
    """
    # batch should be located in list
    assert len(batch) == 1
    mixtures, filenames = load_mixtures(batch[0])

    # get batch of lengths of input sequences
    ilens = np.array([mix.shape[0] for mix in mixtures])

    # perform padding and convert to tensor
    pad_value = 0
    mixtures_pad = pad_list([torch.from_numpy(mix).float()
                             for mix in mixtures], pad_value)
    ilens = torch.from_numpy(ilens)
    return mixtures_pad, ilens, filenames

In [10]:
unlabeled_dataset = EvalDataset('wav8k/test', 'mix', batch_size = 1)
unlabeled_loader = EvalDataLoader(unlabeled_dataset[:50])

In [11]:
import pit_criterion

def _run_one_epoch( model,epoch,optimizer, cross_valid=False):
    start = time.time()
    total_loss = 0
    max_norm = 5
    print_freq = 100

    data_loader = train_loader if not cross_valid else test_loader

    # visualizing loss using visdom
    for i, (data) in enumerate(data_loader):
        padded_mixture, mixture_lengths, padded_source = data
        padded_mixture = padded_mixture.cuda()
        
        #print(padded_mixture, mixture_lengths)
        padded_source = padded_source.cuda()
        estimate_source = model(padded_mixture, mixture_lengths)
        mixture_lengths = mixture_lengths.cuda()
        loss, max_snr, estimate_source, reorder_estimate_source = pit_criterion.cal_loss(padded_source, estimate_source, mixture_lengths)
        if not cross_valid:
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm)
            optimizer.step()

        total_loss += loss.item()

        if i % print_freq == 0:
            #wandb.log({'epoch':epoch+1, 'iter': i+1, 'Average Loss':total_loss / (i + 1),'Current Loss ':loss.item()})
            print('Epoch {0} | Iter {1} | Average Loss {2:.3f} | '
                  'Current Loss {3:.6f} | {4:.1f} ms/batch'.format(
                      epoch + 1, i + 1, total_loss / (i + 1),
                      loss.item(), 1000 * (time.time() - start) / (i + 1)),
                  flush=True)
    return total_loss / (i + 1)


In [12]:
def Overall_Cost(classification_cost, consistency_cost, ratio=0.5):
    return (ratio * classification_cost) + ((1 - ratio)*consistency_cost)

In [90]:
import IPython
IPython.display.Audio('RESULTS/1089-134686-0000_121-127105-0031.wav')

In [91]:
import IPython
IPython.display.Audio('RESULTS/1089-134686-0000_121-127105-0031_s1.wav')

In [89]:
import IPython
IPython.display.Audio('RESULTS/1089-134686-0000_121-127105-0031_s2.wav')

In [13]:

def train(student_model,teacher_model):
        optimizier = torch.optim.Adam(student_model.parameters(),
                          lr=1e-3,
                          weight_decay=0.0)
        # Train model multi-epoches
        teacher_model_prev = teacher_model
        for epoch in range(10):
            # Train one epoch
            print("Training...")
            student_model.train()  # Turn on BatchNorm & Dropout
            start = time.time()
            tr_avg_loss, teacher_model = _run_one_epoch_with_ema(student_model,teacher_model_prev, epoch, optimizier)
            teacher_model_prev = teacher_model
            #wandb.log({'Train Loss':float(tr_avg_loss)})
            print('-' * 85)
            print('Train Summary | End of Epoch {0} | Time {1:.2f}s | '
                  'Train Loss {2:.3f}'.format(
                      epoch + 1, time.time() - start, float(tr_avg_loss)))
            print('-' * 85)
            
            print('Cross validation...')
            model.eval()  # Turn off Batchnorm & Dropout
            val_loss = _run_one_epoch(model,epoch,optimizier, cross_valid=True)
            print('-' * 85)
            
            print('Valid Summary | End of Epoch {0} | Time {1:.2f}s | '
                  'Valid Loss {2:.3f}'.format(
                      epoch + 1, time.time() - start, val_loss))
            #wandb.log({'Valid Loss': val_loss})
            print('-' * 85)


In [14]:
def update_ema_variables(model, ema_model, alpha, global_step):
    # Use the true average until the exponential average is more correct
    alpha = min(1 - 1 / (global_step + 1), alpha)
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(1 - alpha, param.data)

In [19]:
import time
EPS = 1e-8
import tasnet
student_model = tasnet.TasNet(40, 100, 100, 4,bidirectional=True, nspk=2)
teacher_model = tasnet.TasNet(40, 100, 100, 4,bidirectional=True, nspk=2)
#student_model.cuda()

#wandb.watch(model)
#train(student_model)

In [17]:
#train teacher model one epoch
teacher_model.cuda()
optimizier = torch.optim.Adam(teacher_model.parameters(),
                          lr=1e-3,
                          weight_decay=0.0)
teacher_model.train()
_run_one_epoch(teacher_model,1, optimizier)

Epoch 2 | Iter 1 | Average Loss 33.270 | Current Loss 33.269531 | 1971.6 ms/batch
Epoch 2 | Iter 101 | Average Loss 35.727 | Current Loss 46.718857 | 1842.0 ms/batch
Epoch 2 | Iter 201 | Average Loss 38.207 | Current Loss 47.198475 | 1849.2 ms/batch
Epoch 2 | Iter 301 | Average Loss 39.389 | Current Loss 47.324963 | 1846.3 ms/batch
Epoch 2 | Iter 401 | Average Loss 39.801 | Current Loss 38.064636 | 1843.1 ms/batch
Epoch 2 | Iter 501 | Average Loss 40.106 | Current Loss 44.886456 | 1838.1 ms/batch
Epoch 2 | Iter 601 | Average Loss 40.558 | Current Loss 43.472759 | 1831.8 ms/batch
Epoch 2 | Iter 701 | Average Loss 40.927 | Current Loss 36.352833 | 1825.3 ms/batch
Epoch 2 | Iter 801 | Average Loss 40.918 | Current Loss 44.384579 | 1818.6 ms/batch
Epoch 2 | Iter 901 | Average Loss 40.925 | Current Loss 41.788342 | 1811.7 ms/batch


41.04556074428558

In [23]:
teacher_model.cuda()
padded_mixture, mixture_lengths, padded_source = next(iter(train_loader))
padded_mixture = padded_mixture.cuda()
res = teacher_model(padded_mixture, mixture_lengths)


In [29]:
mixture_lengths = mixture_lengths.cuda()
padded_source = padded_source.cuda()

In [30]:
loss, max_snr, estimate_source, reorder_estimate_source = pit_criterion.cal_loss(padded_source,res, mixture_lengths)

In [38]:
with torch.no_grad():
    results = remove_pad_and_flat(res,mixture_lengths)

In [53]:
break_s1,break_s2 = Break(0.5, results[0][0], results[0][1])

In [99]:
new_res.append(results[0][1].reshape([1,3373,40]))

In [100]:
new_res.append(results[0][0].reshape([1,3373,40]))

In [25]:
lmbd = np.random.beta(1,1)

In [120]:
after_mix = Mix(lmbd,results[0][0], results[0][1])

In [65]:
a = next(iter(train_loader))

In [118]:
padded_source.shape

torch.Size([1, 2, 3373, 40])

In [71]:
after_mix.reshape([1,3373,40]).shape

(1, 3373, 40)

In [126]:
after_mix.reshape(after_mix,(sources_pad.shape[0],sources_pad.shape[2], sources_pad.shape[3]))

TypeError: only integer scalar arrays can be converted to a scalar index

In [110]:
pad_value = 0
sources_pad = pad_list([torch.from_numpy(s).float()
                            for s in new_res], pad_value)
    # N x K x L x C -> N x C x K x L

sources_pad = sources_pad.permute((1, 0, 2, 3)).contiguous()


In [111]:
sources_pad.shape #готов к полету в сетку

torch.Size([1, 2, 3373, 40])

In [86]:
    import soundfile as sf
    
    def write(inputs, filename, sr=8000):
        sf.write(filename, inputs, sr)# norm=True)

    with torch.no_grad():
        for (i, data) in enumerate(eval_loader):
            # Get batch data
            mixture, mix_lengths, filenames = data
            #if args.use_cuda:
            mixture = mixture.cuda()
            # Forward
            estimate_source = teacher_model(mixture, mix_lengths) 
            mix_lengths = mix_lengths.cuda()# [B, C, K, L]
            # Remove padding and flat
            flat_estimate = remove_pad_and_flat(estimate_source, mix_lengths)
            mixture = remove_pad_and_flat(mixture, mix_lengths)
            # Write result
            for i, filename in enumerate(filenames):
                filename = os.path.join('RESULTS',
                                        os.path.basename(filename).strip('.wav'))
                write(mixture[i], filename + '.wav')
                C = flat_estimate[i].shape[0]
                for c in range(C):
                    write(flat_estimate[i][c], filename + '_s{}.wav'.format(c+1))

KeyboardInterrupt: 

In [85]:

IPython.display.Audio('amazing_sound3.wav')

0.8.0


In [19]:
EPOCHS = 5

In [20]:
def remove_pad_and_flat(inputs, inputs_lengths):
    """
    Args:
        inputs: torch.Tensor, [B, C, K, L] or [B, K, L]
        inputs_lengths: torch.Tensor, [B]
    Returns:
        results: a list containing B items, each item is [C, T], T varies
    """
    results = []
    dim = inputs.dim()
    if dim == 4:
        C = inputs.size(1)
    for input, length in zip(inputs, inputs_lengths):
        if dim == 4: # [B, C, K, L]
            results.append(input[:,:length].view(C, -1).cpu().numpy())
        elif dim == 3:  # [B, K, L]
            results.append(input[:length].view(-1).cpu().numpy())
    return results

In [None]:
#разделение работает только для размера батча =1


In [16]:
def Break(l, s1, s2):
    s1 = l*s1
    s2 = (1-l)*s2
    return s1, s2

def Mix(l, s1, s2):
    return l*s1+(1-l)*s2



 

In [20]:
EPOCHS = 3

In [21]:
from sklearn.metrics import mean_squared_error
import pit_criterion
student_model.cuda()
optimizer = torch.optim.Adam(student_model.parameters(),
                             lr=1e-3,
                             weight_decay=0.0)


lmbd = np.random.beta(1,1)

total_loss = 0
for epoch in range(EPOCHS):

    start = time.time()
    total_correctness_loss = 0
    max_norm = 5  # уточнить что это
    print_freq = 100


    #SUPERVISED LEARNING
    student_model.train()
    for i, (data) in enumerate(labeled_loader):
        padded_mixture, mixture_lengths, padded_source = data
        padded_mixture = padded_mixture.cuda()
       
        student_pred = student_model(padded_mixture, mixture_lengths)
        mixture_lengths = mixture_lengths.cuda()
        padded_source = padded_source.cuda() 
        correctness_loss, max_snr, estimate_source, reorder_estimate_source = pit_criterion.cal_loss(padded_source, student_pred, mixture_lengths)
        total_correctness_loss+=correctness_loss
        
    
    total_consistency_loss = 0
    teacher_model.train()
    #UNSUPERVISED LEARNING
    for i, (data) in enumerate(labeled_loader):# labeled +unlabeled совместить в один датасет
        padded_mixture, mixture_lengths, padded_source = data
        padded_mixture = padded_mixture.cuda() 
        teacher_pred = teacher_model(padded_mixture, mixture_lengths)
        #BREAKDOWN
        flat_estimate = remove_pad_and_flat(teacher_pred.data, mixture_lengths) #каким то образом разделяем выход сети
        s1 = flat_estimate[0][0]
        s2 = flat_estimate[0][1]
        break_s1, break_s2 = Break(lmbd, s1, s2)

        sources_arr = []
        sources_arr.append(np.array(torch.reshape(torch.Tensor(break_s1),(1, int(break_s1.shape[0]/40),40)))) #собираем 2 сигнала в 1 для pad source
        sources_arr.append(np.array(torch.reshape(torch.Tensor(break_s2),(1, int(break_s2.shape[0]/40),40))))
        pad_value = 0
        sources_pad = pad_list([torch.from_numpy(s).float() for s in sources_arr], pad_value)

        new_padded_source = sources_pad.permute((1, 0, 2, 3)).contiguous() # 2 сигнала собраны в 1 для сравнения потом с PIT

        #MIXUP
        Mix_s1_s2 = Mix(lmbd, s1, s2)
        Mix_s1_s2 = torch.reshape(torch.Tensor(Mix_s1_s2), (1,mixture_lengths, int(Mix_s1_s2.shape[0]/mixture_lengths)))
        Mix_s1_s2 = Mix_s1_s2.cuda()
        #mixture_lengths = mixture_lengths.cpu()

        student_estimate_source = student_model(Mix_s1_s2, mixture_lengths)
        #mixture_lengths = mixture_lengths.cuda()
        new_padded_source = new_padded_source.cuda()

        consist_loss, max_snr, estimate_source, reorder_estimate_source = pit_criterion.cal_loss(new_padded_source, student_estimate_source, mixture_lengths)

        total_consist_loss+=consist_loss
    #------------UNLABELED LOADER------------------#
    for i, (data) in enumerate(unlabeled_loader):
        padded_mixture, mixture_lengths, padded_source = data
        padded_mixture = padded_mixture.cuda() 
        teacher_pred = teacher_model(padded_mixture, mixture_lengths)
        #BREAKDOWN
        flat_estimate = remove_pad_and_flat(teacher_pred.data, mixture_lengths) #каким то образом разделяем выход сети
        s1 = flat_estimate[0][0]
        s2 = flat_estimate[0][1]
        break_s1, break_s2 = Break(lmbd, s1, s2)

        sources_arr = []
        sources_arr.append(np.array(torch.reshape(torch.Tensor(break_s1),(1, int(break_s1.shape[0]/40),40)))) #собираем 2 сигнала в 1 для pad source
        sources_arr.append(np.array(torch.reshape(torch.Tensor(break_s2),(1, int(break_s2.shape[0]/40),40))))
        pad_value = 0
        sources_pad = pad_list([torch.from_numpy(s).float() for s in sources_arr], pad_value)

        new_padded_source = sources_pad.permute((1, 0, 2, 3)).contiguous() # 2 сигнала собраны в 1 для сравнения потом с PIT

        #MIXUP
        Mix_s1_s2 = Mix(lmbd, s1, s2)
        Mix_s1_s2 = torch.reshape(torch.Tensor(Mix_s1_s2), (1,mixture_lengths, int(Mix_s1_s2.shape[0]/mixture_lengths)))
        Mix_s1_s2 = Mix_s1_s2.cuda()
        mixture_lengths = mixture_lengths.cpu()

        student_estimate_source = student_model(Mix_s1_s2, mixture_lengths)
        #mixture_lengths = mixture_lengths.cuda()
        new_padded_source = new_padded_source.cuda()

        consist_loss, max_snr, estimate_source, reorder_estimate_source = pit_criterion.cal_loss(new_padded_source, student_estimate_source, mixture_lengths)  
        total_consist_loss+=consist_loss
    

    LOSS = 1/len(labeled_loader) * total_correctness_loss + np.exp(epoch/EPOCHS-1)/(len(labeled_loader)+len(unlabeled_loader))*total_consist_loss
    #if not cross_valid:
    optimizer.zero_grad()
    LOSS.backward()
    torch.nn.utils.clip_grad_norm_(student_model.parameters(),
                                   max_norm)
    optimizer.step()

    total_loss += LOSS.item()
    update_ema_variables(student_model, teacher_model, 0.999, epoch+1)
    if i % print_freq == 0:
        print('Epoch {0} | Iter {1} | Average Loss {2:.3f} | '
              'Current Loss {3:.6f} | {4:.1f} ms/batch'.format(
                  epoch + 1, i + 1, total_loss / (i + 1),
                  LOSS.item(), 1000 * (time.time() - start) / (i + 1)),
              flush=True)
            

    


RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

In [19]:
student_model.cuda()
padded_mixture, mixture_lengths, padded_source = next(iter(train_loader))
with torch.no_grad():
    res = student_model(padded_mixture.cuda(),mixture_lengths)

In [31]:
res = res.reshape([res.shape[0], res.shape[1]*res.shape[2]*res.shape[3]]).cpu()

In [33]:
from sklearn.metrics import mean_squared_error


In [34]:
mean_squared_error(res.detach().numpy(),res.detach().numpy())

0.0

In [57]:
import argparse
import os

import librosa
from mir_eval.separation import bss_eval_sources
import numpy as np
import torch

#from data import AudioDataLoader, AudioDataset
#from pit_criterion import cal_loss



#parser = argparse.ArgumentParser('Evaluate separation performance using TasNet')
#parser.add_argument('--model_path', type=str, required=True,
 #                   help='Path to model file created by training')
#parser.add_argument('--data_dir', type=str, required=True,
 #                   help='directory including mix.json, s1.json and s2.json')
#parser.add_argument('--cal_sdr', type=int, default=0,
 #                   help='Whether calculate SDR, add this option because calculation of SDR is very slow')
#parser.add_argument('--use_cuda', type=int, default=0,
 #                   help='Whether use GPU')
#parser.add_argument('--sample_rate', default=8000, type=int,
  #                  help='Sample rate')
#parser.add_argument('--batch_size', default=1, type=int,
   #                 help='Batch size')


def evaluate(dataset):
    total_SISNRi = 0
    total_SDRi = 0
    total_cnt = 0

    # Load model
    #model = TasNet.load_model('model_v2.pkl')
    print(model)
    model.eval()
    #if args.use_cuda:
    model.cuda()

    # Load data
    #dataset = AudioDataset('dev', 1,
     #                      sample_rate=16000, L=model.L)
    data_loader = AudioDataLoader(dataset, batch_size=1, num_workers=2)

    with torch.no_grad():
        for i, (data) in enumerate(data_loader):
            # Get batch data
            padded_mixture, mixture_lengths, padded_source = data
            #if args.use_cuda:
            padded_mixture = padded_mixture.cuda()
            #mixture_lengths = mixture_lengths.cuda()
            padded_source = padded_source.cuda()
            # Forward
            estimate_source = model(padded_mixture, mixture_lengths)# [B, C, K, L]
            mixture_lengths = mixture_lengths.cuda()
            loss, max_snr, estimate_source, reorder_estimate_source = \
                pit_criterion.cal_loss(padded_source, estimate_source, mixture_lengths)
            # Remove padding and flat
            mixture = remove_pad_and_flat(padded_mixture, mixture_lengths)
            source = remove_pad_and_flat(padded_source, mixture_lengths)
            # NOTE: use reorder estimate source
            estimate_source = remove_pad_and_flat(reorder_estimate_source,
                                                  mixture_lengths)
            # for each utterance
            for mix, src_ref, src_est in zip(mixture, source, estimate_source):
                print("Utt", total_cnt + 1)
                # Compute SDRi
                #if args.cal_sdr:
                 #   avg_SDRi = cal_SDRi(src_ref, src_est, mix)
                  #  total_SDRi += avg_SDRi
                   # print("\tSDRi={0:.2f}".format(avg_SDRi))
                # Compute SI-SNRi
                avg_SISNRi = cal_SISNRi(src_ref, src_est, mix)
                print("\tSI-SNRi={0:.2f}".format(avg_SISNRi))
                total_SISNRi += avg_SISNRi
                total_cnt += 1
   # if args.cal_sdr:
    #    print("Average SDR improvement: {0:.2f}".format(total_SDRi / total_cnt))
    print("Average SISNR improvement: {0:.2f}".format(total_SISNRi / total_cnt))


def cal_SDRi(src_ref, src_est, mix):
    """Calculate Source-to-Distortion Ratio improvement (SDRi).
    NOTE: bss_eval_sources is very very slow.
    Args:
        src_ref: numpy.ndarray, [C, T]
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T]
    Returns:
        average_SDRi
    """
    src_anchor = np.stack([mix, mix], axis=0)
    sdr, sir, sar, popt = bss_eval_sources(src_ref, src_est)
    sdr0, sir0, sar0, popt0 = bss_eval_sources(src_ref, src_anchor)
    avg_SDRi = ((sdr[0]-sdr0[0]) + (sdr[1]-sdr0[1])) / 2
    # print("SDRi1: {0:.2f}, SDRi2: {1:.2f}".format(sdr[0]-sdr0[0], sdr[1]-sdr0[1]))
    return avg_SDRi


def cal_SISNRi(src_ref, src_est, mix):
    """Calculate Scale-Invariant Source-to-Noise Ratio improvement (SI-SNRi)
    Args:
        src_ref: numpy.ndarray, [C, T]
        src_est: numpy.ndarray, [C, T], reordered by best PIT permutation
        mix: numpy.ndarray, [T]
    Returns:
        average_SISNRi
    """
    sisnr1 = cal_SISNR(src_ref[0], src_est[0])
    sisnr2 = cal_SISNR(src_ref[1], src_est[1])
    sisnr1b = cal_SISNR(src_ref[0], mix)
    sisnr2b = cal_SISNR(src_ref[1], mix)
    # print("SISNR base1 {0:.2f} SISNR base2 {1:.2f}, avg {2:.2f}".format(
    #     sisnr1b, sisnr2b, (sisnr1b+sisnr2b)/2))
    # print("SISNRi1: {0:.2f}, SISNRi2: {1:.2f}".format(sisnr1, sisnr2))
    avg_SISNRi = ((sisnr1 - sisnr1b) + (sisnr2 - sisnr2b)) / 2
    return avg_SISNRi


def cal_SISNR(ref_sig, out_sig, eps=1e-8):
    """Calcuate Scale-Invariant Source-to-Noise Ratio (SI-SNR)
    Args:
        ref_sig: numpy.ndarray, [T]
        out_sig: numpy.ndarray, [T]
    Returns:
        SISNR
    """
    assert len(ref_sig) == len(out_sig)
    ref_sig = ref_sig - np.mean(ref_sig)
    out_sig = out_sig - np.mean(out_sig)
    ref_energy = np.sum(ref_sig ** 2) + eps
    proj = np.sum(ref_sig * out_sig) * ref_sig / ref_energy
    noise = out_sig - proj
    ratio = np.sum(proj ** 2) / (np.sum(noise ** 2) + eps)
    sisnr = 10 * np.log(ratio + eps) / np.log(10.0)
    return sisnr

            
def remove_pad_and_flat(inputs, inputs_lengths):
    """
    Args:
        inputs: torch.Tensor, [B, C, K, L] or [B, K, L]
        inputs_lengths: torch.Tensor, [B]
    Returns:
        results: a list containing B items, each item is [C, T], T varies
    """
    results = []
    dim = inputs.dim()
    if dim == 4:
        C = inputs.size(1)
    for input, length in zip(inputs, inputs_lengths):
        if dim == 4: # [B, C, K, L]
            results.append(input[:,:length].view(C, -1).cpu().numpy())
        elif dim == 3:  # [B, K, L]
            results.append(input[:length].view(-1).cpu().numpy())
    return results


In [24]:
!nvidia-smi


Fri Apr  9 12:47:49 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.51.06    Driver Version: 450.51.06    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce GTX 108...  Off  | 00000000:65:00.0 Off |                  N/A |
| 36%   43C    P8     9W / 250W |  11147MiB / 11175MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces