In [1]:
import datetime
import sys

from pathlib import Path
from collections import defaultdict
import fire
from typing import Dict, List, Any, Union, Sequence, Tuple
import logging

import torch
import numpy as np

import os
from math import ceil

from dataset_noav_8k_2 import MyDatasets, collate_LibriMix
from dual_path import SepformerWrapper
#from losses import SI_SNR
from torch.utils.data import Dataset, DataLoader
from torchaudio import datasets
from scipy.io.wavfile import read, write
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
#from mir_eval.separation import bss_eval_sources
import mir_eval
from itertools import permutations
from dataclasses import dataclass, field, fields
from torchaudio.transforms import MelScale
from typing import List, Type, Any, Callable, Optional, Union

# Pick the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
#device = torch.device( 'cpu' )

# DataLoader worker processes: load in the main process on CPU, 8 workers with a GPU.
num_workers = 0 if device.type == 'cpu' else 8

print(num_workers)

8


In [2]:
def _make_loader(dataset, shuffle):
    """Single-utterance DataLoader using the LibriMix collate function."""
    return DataLoader(
        dataset,
        batch_size=1,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_LibriMix,
    )


# Only the training split is shuffled; each loader's length is printed.
train_ds = MyDatasets("train")
train_loader = _make_loader(train_ds, shuffle=True)
print(len(train_loader))

val_ds = MyDatasets("valid")
val_loader = _make_loader(val_ds, shuffle=False)
print(len(val_loader))

test_ds = MyDatasets("test")
test_loader = _make_loader(test_ds, shuffle=False)
print(len(test_loader))

13900
3000
3000


In [3]:
def l2norm(mat, keepdim=False):
    """Euclidean (L2) norm of ``mat`` taken over its last dimension.

    Args:
        mat: input tensor.
        keepdim: if True, keep the reduced last dimension with size 1.
    Returns:
        Tensor with the last dimension reduced (or kept with size 1).
    """
    norms = torch.linalg.vector_norm(mat, ord=2, dim=-1, keepdim=keepdim)
    return norms

In [4]:
@dataclass(slots=True)
class STFTBase(torch.nn.Module):
    """
    Base layer for (i)STFT
    NOTE:
        1) Recommend sqrt_hann window with 2**N frame length, because it 
           could achieve perfect reconstruction after overlap-add
        2) Now haven't consider padding problems yet
    """
    device: torch.device
    frame_length: int
    frame_shift: int
    window: str
    K: torch.nn.Parameter = field(init=False)
    num_bins: int = field(init=False)

    def __post_init__(self):
        super(STFTBase, self).__init__()  # Initialize the torch.nn.Module base class
        K = self._init_kernel(self.frame_length, self.frame_shift)
        self.K = torch.nn.Parameter(K, requires_grad=False).to(self.device)
        self.num_bins = self.K.shape[0] // 2
    
    def _init_kernel(self, frame_len, frame_hop):
        # FFT points
        N = frame_len
        # window
        if self.window == 'hann':
            W = torch.hann_window(frame_len)
        if N//4 == frame_hop:
            const = (2/3)**0.5       
            W = const*W
        elif N//2 == frame_hop:
            W = W**0.5
        S = 0.5 * (N * N / frame_hop)**0.5
        
        # Updated FFT calculation for efficiency
        K = torch.fft.rfft(torch.eye(N) / S, dim=1)[:frame_len]
        K = torch.stack((torch.real(K), torch.imag(K)), dim=2)
        K = torch.transpose(K, 0, 2) * W # 2 x N/2+1 x F
        K = torch.reshape(K, (N + 2, 1, frame_len)) # N+2 x 1 x F
        return K

    def extra_repr(self):
        return (f"window={self.window}, stride={self.frame_shift}, " +
                f"kernel_size={self.K.shape[0]}x{self.K.shape[2]}")


In [5]:
#@logger_wraps()
@dataclass(slots=True)
class STFT(STFTBase):
    """
    Short-time Fourier Transform as a Layer.

    The signal is right-padded with zeros to a whole number of hops and then
    framed with hop ``frame_shift`` by a strided convolution with ``self.K``;
    no left / centre padding is applied.
    """

    def forward(self, x, cplx=False):
        """
        Accept (single or multiple channel) raw waveform and output magnitude
        and phase (or real and imaginary parts when ``cplx`` is True).

        args
            x: input signal, N x C x S or N x S
            cplx: if True, return (real, imag) instead of (magnitude, phase)
        return
            m: magnitude, N x C x F x T or N x F x T
            p: phase, N x C x F x T or N x F x T
        raises
            RuntimeError: if x is neither 2-D nor 3-D
        """
        if x.dim() not in [2, 3]:
            raise RuntimeError(
                "expect 2D/3D tensor, but got {:d}D signal".format(
                     x.dim()))
        # Right-pad to a multiple of the hop. F.pad keeps x's device and dtype
        # (the 3-D path used to build its zero padding on the CPU).
        num_frames = ceil(x.shape[-1] / self.frame_shift)
        pad_len = num_frames * self.frame_shift - x.shape[-1]
        x = torch.nn.functional.pad(x, (0, pad_len))
        if x.dim() == 2:
            # N x S -> N x 1 x S -> N x 2F x T
            c = torch.nn.functional.conv1d(
                x.unsqueeze(1), self.K, stride=self.frame_shift, padding=0)
            # each N x F x T
            r, i = torch.chunk(c, 2, dim=1)
        else:
            N, C, S = x.shape
            # NC x 1 x S -> NC x 2F x T
            c = torch.nn.functional.conv1d(
                x.reshape(N * C, 1, S), self.K, stride=self.frame_shift, padding=0)
            # N x C x 2F x T
            c = c.reshape(N, C, -1, c.shape[-1])
            # each N x C x F x T
            r, i = torch.chunk(c, 2, dim=2)

        if cplx:
            return r, i
        # Small constant keeps the square root differentiable at zero magnitude.
        m = (r**2 + i**2 + 1.0e-10)**0.5
        p = torch.atan2(i, r)
        return m, p


In [6]:
#@logger_wraps()
@dataclass(slots=True)
class PIT_SISNR_mag:
    """
    Permutation-invariant SI-SNR-style loss computed on STFT magnitudes (dB).

    For each speaker permutation, the zero-mean (and optionally scale-aligned)
    target and estimate are transformed with the STFT of the requested stage.
    The negative SNR of the two magnitude spectrograms is summed over speakers.
    The lowest-loss permutation is then taken per utterance.

    Fields:
        device: device the STFT kernels and mel filterbank are placed on.
        frame_length, frame_shift, window: STFT configuration.
        num_stages: number of STFT instances; ``idx`` in ``__call__`` selects one.
        num_spks: number of speakers to permute over.
        scale_inv: if True, rescale the target by its projection onto the estimate.
        mel_opt: if True, apply a mel filterbank to the magnitudes.
        sample_rate: sample rate given to the mel filterbank (used only when
            ``mel_opt`` is True). Defaults to the previously hard-coded 16000.
        n_mels: number of mel bands (used only when ``mel_opt`` is True).
    """
    device: torch.device
    frame_length: int
    frame_shift: int
    window: str
    num_stages: int
    num_spks: int
    scale_inv: bool
    mel_opt: bool
    sample_rate: int = 16000
    n_mels: int = 80

    stft: List[Any] = field(init=False)
    mel_fb: Callable[[torch.Tensor], torch.Tensor] = field(init=False)

    def __post_init__(self):
        self.stft = [STFT(self.device, self.frame_length, self.frame_shift, self.window)
                     for _ in range(self.num_stages)]
        if self.mel_opt:
            self.mel_fb = MelScale(n_mels=self.n_mels, sample_rate=self.sample_rate,
                                   n_stft=self.frame_length // 2 + 1).to(self.device)
        else:
            self.mel_fb = lambda x: x

    def __repr__(self):
        # __init__
        class_name = self.__class__.__name__
        init_fields = [f for f in fields(self) if f.init]
        field_strs = [f"{field.name}={getattr(self, field.name)!r}" for field in init_fields]

        # __post_init__
        stft_repr = f"stft = [STFT instance for {len(self.stft)} layers]"
        mel_fb_repr = "mel_fb = MelScale" if self.mel_opt else "mel_fb=Identity"
        post_init_reprs = [stft_repr, mel_fb_repr]

        return f"<{class_name}({', '.join(field_strs + post_init_reprs)})>"

    def __call__(self, targets, estims, idx):
        """
        Compute the loss for one model stage.

        Args:
            targets: references indexed by speaker; ``targets[t]`` is one
                speaker's waveforms (assumed B x S — TODO confirm).
            estims: estimates indexed by speaker; ``estims.size(1)`` is taken as
                the number of utterances.
            idx: index of the STFT instance in ``self.stft`` to use.
        Returns:
            Scalar tensor: sum over utterances of the best-permutation loss
            divided by the number of utterances.
        """
        num_utts = estims.size(1)
        stft = self.stft[idx]

        def _stft_mag_sdr_loss(permute, eps=1.0e-12):
            loss_for_permute = []
            for s, t in enumerate(permute):
                est = estims[s]
                src = targets[t]
                est_zm = est - torch.mean(est, dim=-1, keepdim=True)
                src_zm = src - torch.mean(src, dim=-1, keepdim=True)
                if self.scale_inv:
                    scale = torch.sum(est_zm * src_zm, dim=-1, keepdim=True) / (l2norm(src_zm, keepdim=True)**2 + eps)
                    # Gain is clamped from below at 1e-2 (this also removes negative gains).
                    src_zm = torch.clamp(scale, min=1e-2) * src_zm
                est_mag = stft(est_zm.to(self.device))[0]
                src_mag = stft(src_zm.to(self.device))[0]
                if self.mel_opt:
                    est_mag = self.mel_fb(est_mag)
                    src_mag = self.mel_fb(src_mag)
                # Norm over time frames, then over frequency bins.
                utt_loss = -20 * torch.log10(eps + l2norm(l2norm(src_mag)) / (l2norm(l2norm(est_mag - src_mag)) + eps))
                loss_for_permute.append(utt_loss)
            return sum(loss_for_permute)

        pscore = torch.stack([_stft_mag_sdr_loss(p) for p in permutations(range(self.num_spks))])
        min_perutt, _ = torch.min(pscore, dim=0)
        return torch.sum(min_perutt) / num_utts


In [7]:
# Stand-alone instance of the multi-stage magnitude loss used by the smoke test below.
num_spks = 2
pit_sisnr_mag = PIT_SISNR_mag(
    device=device,
    frame_length=512,
    frame_shift=128,
    window='hann',
    num_stages=4,
    num_spks=num_spks,
    scale_inv=True,
    mel_opt=False,
)

In [8]:
# Smoke test: the same random (num_spks x 1 x S) inputs through each of the 4 STFT stages.
x = torch.randn((2, 1, 32000))
y = torch.randn((2, 1, 32000))

loss11, loss12, loss13, loss14 = (pit_sisnr_mag(x, y, stage) for stage in range(4))
loss1 = loss11 + loss12 + loss13 + loss14

for value in (loss11, loss12, loss13, loss14, loss1):
    print(value)

tensor(74.7565, device='cuda:0')
tensor(74.7565, device='cuda:0')
tensor(74.7565, device='cuda:0')
tensor(74.7565, device='cuda:0')
tensor(299.0260, device='cuda:0')


In [9]:
#@logger_wraps()
@dataclass(slots=True)
class PIT_SISNR_time:
    """
    Permutation-invariant negative SI-SNR loss in the time domain (dB).

    For each speaker permutation the per-speaker losses are summed. The best
    (lowest) permutation is taken independently for each utterance, and the
    result is averaged over utterances.

    Fields:
        device: kept for interface parity with the other losses; not used here.
        num_spks: number of speakers to permute over.
        scale_inv: if True, compare against the target projected onto the
            estimate (SI-SNR); if False, against the zero-mean target (SNR).
    """
    device: torch.device
    num_spks: int
    scale_inv: bool

    def __repr__(self):
        class_name = self.__class__.__name__
        init_fields = [f for f in fields(self) if f.init]
        field_strs = [f"{field.name}={getattr(self, field.name)!r}" for field in init_fields]
        return f"<{class_name}({', '.join(field_strs)})>"

    def __call__(self, targets, estims):
        """
        Args:
            targets: references indexed by speaker (``targets[t]`` assumed B x S — TODO confirm).
            estims: estimates indexed by speaker; ``estims.size(1)`` is taken as
                the number of utterances.
        Returns:
            Scalar tensor: sum over utterances of the best-permutation loss
            divided by the number of utterances.
        """
        num_utts = estims.size(1)

        def _SDR_loss(permute, eps=1.0e-8):
            loss_for_permute = []
            for s, t in enumerate(permute):
                est = estims[s]
                src = targets[t]
                est_zm = est - torch.mean(input=est, dim=-1, keepdim=True)
                src_zm = src - torch.mean(input=src, dim=-1, keepdim=True)
                if self.scale_inv:
                    scale_factor = torch.sum(est_zm * src_zm, dim=-1, keepdim=True) / (l2norm(src_zm, keepdim=True)**2 + eps)
                    ref = scale_factor * src_zm
                else:
                    # Plain SNR; this path previously raised UnboundLocalError.
                    ref = src_zm

                utt_loss = -20 * torch.log10(eps + l2norm(ref) / (l2norm(est_zm - ref) + eps))
                # Floor the loss at -30 dB: utterances already above 30 dB SNR
                # contribute a constant with no gradient.
                utt_loss = torch.clamp(utt_loss, min=-30)
                loss_for_permute.append(utt_loss)
            return sum(loss_for_permute)

        pscore = torch.stack([_SDR_loss(p) for p in permutations(range(self.num_spks))])
        min_perutt, _ = torch.min(pscore, dim=0)
        return torch.sum(min_perutt) / num_utts


In [10]:
# Time-domain PIT SI-SNR loss for the same number of speakers.
pit_sisnr_time = PIT_SISNR_time(device=device, num_spks=num_spks, scale_inv=True)

In [11]:
#@logger_wraps()
@dataclass(slots=True)
class PIT_SISNRi:
    """
    Permutation-invariant SI-SNR improvement (SI-SNRi) metric, in dB.

    For each speaker permutation this computes, for every speaker,
    SNR(estimate, target) minus SNR(mixture, target). The SNRs are
    scale-invariant when ``scale_inv`` is True. For each utterance, the
    permutation with the largest total improvement is kept. (The previous
    version chose a single permutation for the whole batch; with batch size 1
    the results are identical.)

    Fields:
        device: kept for interface compatibility; computation follows the
            inputs' device.
        num_spks: number of speakers to permute over.
        scale_inv: if True, project the target onto each signal (SI-SNR).
    """
    device: torch.device
    num_spks: int
    scale_inv: bool

    def __repr__(self):
        class_name = self.__class__.__name__
        init_fields = [f for f in fields(self) if f.init]
        field_strs = [f"{field.name}={getattr(self, field.name)!r}" for field in init_fields]
        return f"<{class_name}({', '.join(field_strs)})>"

    def __call__(self, targets, estims, mix):
        """
        Args:
            targets: references indexed by speaker (``targets[t]`` assumed B x S — TODO confirm).
            estims: estimates indexed by speaker; ``estims.size(1)`` is the
                number of utterances.
            mix: input mixture, broadcastable against ``targets[t]``.
        Returns:
            (score, per_source):
              score: sum over utterances of the best-permutation SI-SNRi
                (summed over speakers), divided by the number of utterances.
              per_source: 1-D tensor with each speaker's SI-SNRi under the
                chosen permutation, speaker-major (speaker 0's utterances first).
        """
        num_utts = estims.size(1)
        eps = 1.0e-15
        mix_zm = mix - torch.mean(mix, dim=-1, keepdim=True)

        def _reference(sig_zm, src_zm):
            # Target projected onto sig_zm (SI-SNR), or the raw target (plain SNR).
            # The plain-SNR path previously raised UnboundLocalError.
            if not self.scale_inv:
                return src_zm
            gain = torch.sum(sig_zm * src_zm, dim=-1, keepdim=True) / (l2norm(src_zm, keepdim=True)**2 + eps)
            return gain * src_zm

        def _snr(sig_zm, ref):
            return 20 * torch.log10(eps + l2norm(ref) / (l2norm(sig_zm - ref) + eps))

        def _sisnri_for_permute(permute):
            gains = []
            for s, t in enumerate(permute):
                est_zm = estims[s] - torch.mean(estims[s], dim=-1, keepdim=True)
                src_zm = targets[t] - torch.mean(targets[t], dim=-1, keepdim=True)
                gains.append(_snr(est_zm, _reference(est_zm, src_zm))
                             - _snr(mix_zm, _reference(mix_zm, src_zm)))
            return torch.stack(gains, dim=0)  # num_spks x (utterances)

        pscore = torch.stack([_sisnri_for_permute(p) for p in permutations(range(self.num_spks))], dim=0)
        # Flatten any utterance dims: P x num_spks x U
        pscore = pscore.reshape(pscore.size(0), pscore.size(1), -1)
        # Best permutation per utterance.
        best_total, best_idx = torch.max(pscore.sum(dim=1), dim=0)  # each of length U
        index = best_idx.view(1, 1, -1).expand(1, pscore.size(1), -1)
        per_source = torch.gather(pscore, 0, index).squeeze(0)  # num_spks x U
        return torch.sum(best_total) / num_utts, per_source.reshape(-1)


In [12]:
# PIT SI-SNR improvement metric (estimate vs. unprocessed mixture).
pit_sisnri = PIT_SISNRi(device=device, num_spks=num_spks, scale_inv=True)

In [13]:
def _SDR( mix, targets, estims, num_spks ):
    """
    Mean BSS-eval SDR improvement (SDRi, dB) of ``estims`` over the mixture.

    Each utterance in the batch is evaluated separately, so mir_eval's
    permutation search only pairs sources belonging to the same mixture.
    (The previous version flattened the whole batch into one evaluation; with
    batch size 1 the results are identical.)

    args
        mix: mixture batch, B x S tensor
        targets: reference sources, num_spks x B x S tensor
        estims: estimated sources, num_spks x B x S tensor
        num_spks: number of sources per mixture
    return
        (mean SDRi over all num_spks * B sources,
         numpy array of per-source SDRi in speaker-major order)
    """
    refs = targets.detach().cpu().numpy().reshape(targets.size(0), targets.size(1), -1)
    ests = estims.detach().cpu().numpy().reshape(estims.size(0), estims.size(1), -1)
    mixes = mix.detach().cpu().numpy()
    num_batch = mixes.shape[0]

    sdri = np.empty((num_spks, num_batch))
    for b in range(num_batch):
        # The unprocessed mixture serves as the "estimate" of every source.
        mix_as_est = np.repeat(mixes[b][None, :], num_spks, axis=0)
        sdr_out, _, _, _ = mir_eval.separation.bss_eval_sources(refs[:, b], ests[:, b])
        sdr_in, _, _, _ = mir_eval.separation.bss_eval_sources(refs[:, b], mix_as_est)
        sdri[:, b] = sdr_out - sdr_in

    sdri = sdri.reshape(-1)
    return np.sum(sdri) / len(sdri), sdri

In [14]:
# ---- hyper-parameters shared by the losses ----
num_spks = 2
frame_length = 512
frame_shift = 128
window = 'hann'
num_stages = 4
scale_inv = True
mel_opt = False

# ---- model / optimiser / LR schedule ----
model = SepformerWrapper().to(device)
optimizer = Adam(model.parameters(), lr=1.5e-4)
# Halve the learning rate when the monitored loss stops decreasing (patience 2 epochs).
scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=2)

# ---- losses / metrics ----
pit_sisnr_mag = PIT_SISNR_mag(
    device=device,
    frame_length=frame_length,
    frame_shift=frame_shift,
    window=window,
    num_stages=num_stages,
    num_spks=num_spks,
    scale_inv=scale_inv,
    mel_opt=mel_opt,
)
pit_sisnr_time = PIT_SISNR_time(device=device, num_spks=num_spks, scale_inv=scale_inv)
pit_sisnri = PIT_SISNRi(device=device, num_spks=num_spks, scale_inv=scale_inv)
# `_SDR` is a plain function, not an object: SDR, _ = _SDR(mix, targets, estims, num_spks)

In [15]:
# Sanity check of every loss / metric on random signals:
# one mixture (1 x S) and num_spks x 1 x S targets / estimates.
mix = torch.rand((1, 32000), device=device)
targets = torch.randn((2, 1, 32000), device=device)
estims = torch.randn((2, 1, 32000), device=device)

loss1 = pit_sisnr_time(targets, estims)
loss21, loss22, loss23, loss24 = [pit_sisnr_mag(targets, estims, stage) for stage in range(4)]
loss3, plus = pit_sisnri(targets, estims, mix)
sdr, _ = _SDR(mix, targets, estims, num_spks)

for value in (loss1, loss21, loss22, loss23, loss24, loss3, device, loss3.device, sdr):
    print(value)

	Deprecated as of mir_eval version 0.8.
	It will be removed in mir_eval version 0.9.
  min_perutt_out, _, _, _ = mir_eval.separation.bss_eval_sources(targets, estims)
	Deprecated as of mir_eval version 0.8.
	It will be removed in mir_eval version 0.9.
  min_perutt_in, _, _, _ = mir_eval.separation.bss_eval_sources(targets, mix)


tensor(83.5468, device='cuda:0')
tensor(79.9074, device='cuda:0')
tensor(79.9074, device='cuda:0')
tensor(79.9074, device='cuda:0')
tensor(79.9074, device='cuda:0')
tensor(31.1788, device='cuda:0')
cuda:0
cuda:0
0.36761789217849383


In [16]:
# Output-shape check on the dummy mixture. no_grad avoids building an autograd
# graph that is never backpropagated. (The recorded output for a 1 x 32000 input
# is torch.Size([4, 1, 32000, 2]).)
with torch.no_grad():
    pred_wav_layers = model(mix)

print(pred_wav_layers.size())

torch.Size([4, 1, 32000, 2])


In [18]:
num_epochs = 10
# Checkpoint rewritten whenever the epoch-mean validation loss improves.
save_path = "noav_my_model_training_state.pt"
best_val_loss = float('inf')


def _forward_losses(mixture, source, alpha):
    """
    Run the model on one batch and compute the objective and the metrics.

    args
        mixture: input mixture batch (already on ``device``)
        source: reference sources for the batch (already on ``device``)
        alpha: weight of the multi-stage magnitude loss
    return
        loss: (1 - alpha) * time-domain PIT SI-SNR loss of the last stage
              + alpha * mean over stages of (magnitude loss / num_spks)
        sisnri: PIT SI-SNRi of the last stage (summed over speakers), no grad
        sdr: mean BSS-eval SDR improvement of the last stage
    """
    # After the permute, pred_wav_layers[stage][spk] is one speaker's estimate
    # (assumed stages x spk x B x S, matching the shape-check cell).
    pred_wav_layers = model(mixture).permute(0, 3, 1, 2)
    final = pred_wav_layers[-1]
    loss1 = pit_sisnr_time(source, final)
    stage_losses = [pit_sisnr_mag(source, layer, idx) for idx, layer in enumerate(pred_wav_layers)]
    loss2 = torch.mean(torch.stack(stage_losses) / num_spks)
    with torch.no_grad():
        # Metric only; no autograd graph needed.
        sisnri, _ = pit_sisnri(source, final, mixture)
    sdr, _ = _SDR(mixture, source, final, num_spks)
    loss = (1 - alpha) * loss1 + alpha * loss2
    return loss, sisnri, sdr


for epoch in range(num_epochs):
    # Weight of the magnitude loss: 0.4 up to epoch 100, then multiplied by 0.8
    # every 5 epochs.
    alpha = 0.4 * 0.8**(1+(epoch-101)//5) if epoch > 100 else 0.4

    # ---------------- training ----------------
    total_sisnri = 0
    total_sdr = 0
    model.train()
    # Iterating the tqdm wrapper already advances the bar once per batch; the
    # former extra pbar.update() double-counted progress.
    with tqdm(train_loader, total=len(train_loader)) as pbar:
        for mixture, source in pbar:
            mixture, source = mixture.to(device), source.to(device)
            loss, sisnri, sdr = _forward_losses(mixture, source, alpha)
            total_sisnri += sisnri.item() / num_spks
            total_sdr += sdr
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()
            pbar.set_postfix({"train loss": loss.item() })
    print( "learning rate:", optimizer.param_groups[0]['lr'] )
    print( "epoch:", epoch, "train sisnri loss:", total_sisnri / len( train_loader ), " sdr:", total_sdr / len( train_loader ) )

    # ---------------- validation ----------------
    total_sisnri = 0
    total_sdr = 0
    total_loss = 0
    model.eval()
    with tqdm(val_loader, total=len(val_loader)) as pbar:
        for mixture, source in pbar:
            mixture, source = mixture.to(device), source.to(device)
            with torch.no_grad():
                loss, sisnri, sdr = _forward_losses(mixture, source, alpha)
            total_sisnri += sisnri.item() / num_spks
            total_sdr += sdr
            total_loss += loss.item()
            pbar.set_postfix({"val loss": loss.item() })
    val_loss = total_loss / len( val_loader )
    if best_val_loss > val_loss:
        best_val_loss = val_loss
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # Epoch-mean validation loss used for model selection
                    # (previously the loss of the last validation batch only).
                    'loss': torch.tensor(val_loss),
                    },
                   save_path)
    scheduler.step(val_loss)
    print( "epoch:", epoch, "val sisnri loss:", total_sisnri / len( val_loader ), " sdr:", total_sdr / len( val_loader ) )

    # ---------------- test (reporting only) ----------------
    total_sisnri = 0
    total_sdr = 0
    with tqdm(enumerate(test_loader), total=len(test_loader)) as pbar:
        for step, (mixture, source) in pbar:
            mixture, source = mixture.to(device), source.to(device)
            with torch.no_grad():
                pred_wav_layers = model(mixture).permute(0, 3, 1, 2)
                sisnri, _ = pit_sisnri(source, pred_wav_layers[-1], mixture)
                sdr, _ = _SDR(mixture, source, pred_wav_layers[-1], num_spks)
            total_sisnri += sisnri.item() / num_spks
            total_sdr += sdr
            # Show the test metric itself (previously: the stale last validation loss).
            pbar.set_postfix({"test sisnri": sisnri.item() / num_spks})
            if step == 0:
                # Dump the first test mixture and both separated speakers at 8 kHz.
                write("./mix.wav", rate=8000, data=mixture[0, :].cpu().detach().numpy())
                write("./spk1.wav", rate=8000, data=pred_wav_layers[-1,0,0,:].cpu().detach().numpy())
                write("./spk2.wav", rate=8000, data=pred_wav_layers[-1,1,0,:].cpu().detach().numpy())
                print( "wav file was wrote." )
    print( "epoch:", epoch, "test sisnri loss:", total_sisnri / len( test_loader ), " sdr:", total_sdr / len( test_loader ) )
    

  0%|          | 0/13900 [00:00<?, ?it/s]

	Deprecated as of mir_eval version 0.8.
	It will be removed in mir_eval version 0.9.
  min_perutt_out, _, _, _ = mir_eval.separation.bss_eval_sources(targets, estims)
	Deprecated as of mir_eval version 0.8.
	It will be removed in mir_eval version 0.9.
  min_perutt_in, _, _, _ = mir_eval.separation.bss_eval_sources(targets, mix)


learning rate: 0.00015
epoch: 0 train sisnri loss: 7.239302484248396  sdr: 7.5784643107449


  0%|          | 0/3000 [00:00<?, ?it/s]

epoch: 0 val sisnri loss: 8.030588511427244  sdr: 8.479904782104674


  0%|          | 0/3000 [00:00<?, ?it/s]

wav file was wrote.
epoch: 0 test sisnri loss: 7.730380677722395  sdr: 8.178212767270002


  0%|          | 0/13900 [00:00<?, ?it/s]

learning rate: 0.00015
epoch: 1 train sisnri loss: 9.226973466374677  sdr: 9.510773866904923


  0%|          | 0/3000 [00:00<?, ?it/s]

epoch: 1 val sisnri loss: 9.837368381207188  sdr: 10.231234453297999


  0%|          | 0/3000 [00:00<?, ?it/s]

wav file was wrote.
epoch: 1 test sisnri loss: 9.635624479730923  sdr: 10.033044394771833


  0%|          | 0/13900 [00:00<?, ?it/s]

learning rate: 0.00015
epoch: 2 train sisnri loss: 11.017513997710223  sdr: 11.267923372224969


  0%|          | 0/3000 [00:00<?, ?it/s]

epoch: 2 val sisnri loss: 11.380515728277464  sdr: 11.78331854646249


  0%|          | 0/3000 [00:00<?, ?it/s]

wav file was wrote.
epoch: 2 test sisnri loss: 11.106311268731952  sdr: 11.528127476656646


  0%|          | 0/13900 [00:00<?, ?it/s]

learning rate: 0.00015
epoch: 3 train sisnri loss: 12.57184047207933  sdr: 12.805431790030973


  0%|          | 0/3000 [00:00<?, ?it/s]

epoch: 3 val sisnri loss: 12.616835375924905  sdr: 13.020324239660265


  0%|          | 0/3000 [00:00<?, ?it/s]

wav file was wrote.
epoch: 3 test sisnri loss: 12.472764425923428  sdr: 12.879614995253412


  0%|          | 0/13900 [00:00<?, ?it/s]

learning rate: 0.00015
epoch: 4 train sisnri loss: 13.72124180725903  sdr: 13.931387780352141


  0%|          | 0/3000 [00:00<?, ?it/s]

epoch: 4 val sisnri loss: 13.408919490774473  sdr: 13.850093914984983


  0%|          | 0/3000 [00:00<?, ?it/s]

wav file was wrote.
epoch: 4 test sisnri loss: 13.190306102544069  sdr: 13.621359821876075


  0%|          | 0/13900 [00:00<?, ?it/s]

learning rate: 0.00015
epoch: 5 train sisnri loss: 14.290421241250613  sdr: 14.489045469006706


  0%|          | 0/3000 [00:00<?, ?it/s]

epoch: 5 val sisnri loss: 13.980664086312055  sdr: 14.393553693547217


  0%|          | 0/3000 [00:00<?, ?it/s]

wav file was wrote.
epoch: 5 test sisnri loss: 13.674984008078773  sdr: 14.0975008112316


  0%|          | 0/13900 [00:00<?, ?it/s]

learning rate: 0.00015
epoch: 6 train sisnri loss: 14.893316152489229  sdr: 15.076259802349757


  0%|          | 0/3000 [00:00<?, ?it/s]

epoch: 6 val sisnri loss: 14.135446846654018  sdr: 14.589863259236187


  0%|          | 0/3000 [00:00<?, ?it/s]

wav file was wrote.
epoch: 6 test sisnri loss: 13.89964544300735  sdr: 14.336363410643084


  0%|          | 0/13900 [00:00<?, ?it/s]

learning rate: 0.00015
epoch: 7 train sisnri loss: 15.445762196259318  sdr: 15.609163032246958


  0%|          | 0/3000 [00:00<?, ?it/s]

epoch: 7 val sisnri loss: 14.996027999177574  sdr: 15.387114511213365


  0%|          | 0/3000 [00:00<?, ?it/s]

wav file was wrote.
epoch: 7 test sisnri loss: 14.612842641423146  sdr: 15.01751935551529


  0%|          | 0/13900 [00:00<?, ?it/s]

learning rate: 0.00015
epoch: 8 train sisnri loss: 15.908737356283254  sdr: 16.060840772238013


  0%|          | 0/3000 [00:00<?, ?it/s]

epoch: 8 val sisnri loss: 15.117382384429375  sdr: 15.535613605720338


  0%|          | 0/3000 [00:00<?, ?it/s]

wav file was wrote.
epoch: 8 test sisnri loss: 14.74218484172225  sdr: 15.155095781081734


  0%|          | 0/13900 [00:00<?, ?it/s]

learning rate: 0.00015
epoch: 9 train sisnri loss: 16.399549388063253  sdr: 16.53004396877334


  0%|          | 0/3000 [00:00<?, ?it/s]

epoch: 9 val sisnri loss: 15.506569648226103  sdr: 15.911527003930502


  0%|          | 0/3000 [00:00<?, ?it/s]

wav file was wrote.
epoch: 9 test sisnri loss: 15.172893245607614  sdr: 15.578649943662189
