In [46]:
# Mount Google Drive so the files later read from /content/drive/MyDrive are reachable.
from google.colab import drive # Colab helper that provides drive.mount
 
ROOT = "/content/drive"     # mount point; later cells hard-code paths under this prefix
print(ROOT)                 # optional: echo the mount point
 
drive.mount(ROOT)           # mount Google Drive at ROOT


/content/drive
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Result: at SMR = 0 this notebook gave SDR = 2.57 and 2.45

In [47]:
# Install GNU Octave, its development files and the oct2py bridge, then start
# one Octave session (`oc`) that the evaluation functions below reuse.
!apt install octave
!apt install liboctave-dev  # development files
!pip3 install oct2py
from oct2py import Oct2Py
oc = Oct2Py()
# Octave source computing SDR / SIR / SAR between estimated and reference sources.
# It is written to myScript.m below and called later as `oc.myScript(estimate, reference)`.
# The SDR assignment inside bss_source_crit has no trailing semicolon, so Octave
# echoes "SDR = ..." whenever it runs (these are the SDR lines in the cell outputs).
script = '''
  function [SDR,SIR,SAR,perm]=bss_eval_sources(se,s)
%%% Errors %%%
if nargin<2, error('Not enough input arguments.'); end
[nsrc,nsampl]=size(se);
[nsrc2,nsampl2]=size(s);
if nsrc2~=nsrc, error('The number of estimated sources and reference sources must be equal.'); end
if nsampl2~=nsampl, error('The estimated sources and reference sources must have the same duration.'); end

%%% Performance criteria %%%
% Computation of the criteria for all possible pair matches
SDR=zeros(nsrc,nsrc);
SIR=zeros(nsrc,nsrc);
SAR=zeros(nsrc,nsrc);
for jest=1:nsrc,
    for jtrue=1:nsrc,
        [s_true,e_spat,e_interf,e_artif]=bss_decomp_mtifilt(se(jest,:),s,jtrue,512);
        [SDR(jest,jtrue),SIR(jest,jtrue),SAR(jest,jtrue)]=bss_source_crit(s_true,e_spat,e_interf,e_artif);
    end
end
% Selection of the best ordering
perm=perms(1:nsrc);
nperm=size(perm,1);
meanSIR=zeros(nperm,1);
for p=1:nperm,
    meanSIR(p)=mean(SIR((0:nsrc-1)*nsrc+perm(p,:)));
end
[meanSIR,popt]=max(meanSIR);
perm=perm(popt,:).';
SDR=SDR((0:nsrc-1).'*nsrc+perm);
SIR=SIR((0:nsrc-1).'*nsrc+perm);
SAR=SAR((0:nsrc-1).'*nsrc+perm);

return;


function [s_true,e_spat,e_interf,e_artif]=bss_decomp_mtifilt(se,s,j,flen)

if nargin<4, error('Not enough input arguments.'); end
[nchan2,nsampl2]=size(se);
[nsrc,nsampl,nchan]=size(s);
if nchan2~=nchan, error('The number of channels of the true source images and the estimated source image must be equal.'); end
if nsampl2~=nsampl, error('The duration of the true source images and the estimated source image must be equal.'); end

%%% Decomposition %%%
% True source image
s_true=[reshape(s(j,:,:),nsampl,nchan).',zeros(nchan,flen-1)];
% Spatial (or filtering) distortion
e_spat=project(se,s(j,:,:),flen)-s_true;
% Interference
e_interf=project(se,s,flen)-s_true-e_spat;
% Artifacts
e_artif=[se,zeros(nchan,flen-1)]-s_true-e_spat-e_interf;

return;


function sproj=project(se,s,flen)

% SPROJ Least-squares projection of each channel of se on the subspace
% spanned by delayed versions of the channels of s, with delays between 0
% and flen-1

[nsrc,nsampl,nchan]=size(s);
s=reshape(permute(s,[3 1 2]),nchan*nsrc,nsampl);

%%% Computing coefficients of least squares problem via FFT %%%
% Zero padding and FFT of input data
s=[s,zeros(nchan*nsrc,flen-1)];
se=[se,zeros(nchan,flen-1)];
fftlen=2^nextpow2(nsampl+flen-1);
sf=fft(s,fftlen,2);
sef=fft(se,fftlen,2);
% Inner products between delayed versions of s
G=zeros(nchan*nsrc*flen);
for k1=0:nchan*nsrc-1,
    for k2=0:k1,
        ssf=sf(k1+1,:).*conj(sf(k2+1,:));
        ssf=real(ifft(ssf));
        ss=toeplitz(ssf([1 fftlen:-1:fftlen-flen+2]),ssf(1:flen));
        G(k1*flen+1:k1*flen+flen,k2*flen+1:k2*flen+flen)=ss;
        G(k2*flen+1:k2*flen+flen,k1*flen+1:k1*flen+flen)=ss.';
    end
end
% Inner products between se and delayed versions of s
D=zeros(nchan*nsrc*flen,nchan);
for k=0:nchan*nsrc-1,
    for i=1:nchan,
        ssef=sf(k+1,:).*conj(sef(i,:));
        ssef=real(ifft(ssef,[],2));
        D(k*flen+1:k*flen+flen,i)=ssef(:,[1 fftlen:-1:fftlen-flen+2]).';
    end
end

%%% Computing projection %%%
% Distortion filters
C=G\D;
C=reshape(C,flen,nchan*nsrc,nchan);
% Filtering
sproj=zeros(nchan,nsampl+flen-1);
for k=1:nchan*nsrc,
    for i=1:nchan,
        sproj(i,:)=sproj(i,:)+fftfilt(C(:,k,i).',s(k,:));
    end
end

return;



function [SDR,SIR,SAR]=bss_source_crit(s_true,e_spat,e_interf,e_artif)


if nargin<4, error('Not enough input arguments.'); end
[nchant,nsamplt]=size(s_true);
[nchans,nsampls]=size(e_spat);
[nchani,nsampli]=size(e_interf);
[nchana,nsampla]=size(e_artif);
if ~((nchant==nchans)&&(nchant==nchani)&&(nchant==nchana)), error('All the components must have the same number of channels.'); end
if ~((nsamplt==nsampls)&&(nsamplt==nsampli)&&(nsamplt==nsampla)), error('All the components must have the same duration.'); end

%%% Energy ratios %%%
s_filt=s_true+e_spat;
% SDR
SDR=10*log10(sum(sum(s_filt.^2))/sum(sum((e_interf+e_artif).^2)))
% SIR
SIR=10*log10(sum(sum(s_filt.^2))/sum(sum(e_interf.^2)));
% SA
SAR=10*log10(sum(sum((s_filt+e_interf).^2))/sum(sum(e_artif.^2)));
return;

         '''

# Save the Octave code as myScript.m in the working directory; the evaluation
# functions later rewrite the same file before calling `oc.myScript`.
with open("myScript.m","w+") as f:
    f.write(script)

Reading package lists... Done
Building dependency tree       
Reading state information... Done
octave is already the newest version (4.2.2-1ubuntu1).
0 upgraded, 0 newly installed, 0 to remove and 39 not upgraded.
Reading package lists... Done
Building dependency tree       
Reading state information... Done
liboctave-dev is already the newest version (4.2.2-1ubuntu1).
0 upgraded, 0 newly installed, 0 to remove and 39 not upgraded.


In [48]:
import numpy as np
import os
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter, freqz

from scipy.io.wavfile import read, write
from scipy import signal
from sklearn.decomposition import NMF
from sklearn.preprocessing import MinMaxScaler

from numpy import linalg as LA
from numpy.linalg import inv
from helpers2 import Reconstruct, Viz_Y,SMR,get_mixed_signal,SDR,ReconstructSoft,butter_lowpass_filter
import seaborn as sns
import warnings
import math
from tqdm import tqdm
warnings.simplefilter('ignore')

In [49]:
def Reconstruct(B,G,Ns,Nm,Yabs,p):
    """Split a mixture spectrogram into two sources with ratio masks.

    ``B`` (dictionary) and ``G`` (activations) are partitioned at ``Ns``: the
    first ``Ns`` columns of ``B`` / rows of ``G`` model source 0, the rest
    model source 1.  Each mask is that source's model raised to ``p`` divided
    by the sum of both models raised to ``p``; each source estimate is its
    mask multiplied element-wise with ``Yabs``.

    Parameters
    ----------
    B : ndarray
        Dictionary whose columns are atoms (source-0 atoms first).
    G : ndarray
        Activations whose rows match the columns of ``B``.
    Ns : int
        Number of atoms belonging to source 0.
    Nm : int
        Unused; kept so existing callers keep working.
    Yabs : ndarray
        Mixture spectrogram the masks are applied to (callers in this
        notebook also pass the complex STFT).
    p : float
        Exponent applied to each source model before forming the ratio.

    Returns
    -------
    (Sources, Masks) : tuple of two lists, each holding two arrays
        (source 0 first).

    Notes
    -----
    Bins where both models are exactly zero would give 0/0 = NaN; they
    receive an even 0.5 / 0.5 split instead, so the two masks always sum to
    one and the source estimates sum to ``Yabs``.
    """
    # Each model product is computed once and reused for numerator and denominator.
    powered = [
        np.power(np.matmul(B[:, :Ns], G[:Ns, :]), p),
        np.power(np.matmul(B[:, Ns:], G[Ns:, :]), p),
    ]
    denominator = powered[0] + powered[1]

    # Guard the 0/0 case without touching any bin whose denominator is non-zero.
    empty = denominator == 0
    safe_denominator = np.where(empty, 1.0, denominator)

    Masks = [np.where(empty, 0.5, numerator / safe_denominator) for numerator in powered]
    Sources = [np.multiply(mask, Yabs) for mask in Masks]

    return Sources,Masks




def SMR(speech, music):
    
    """
    Compute the speech-to-music ratio of two signals, print it and return it.

    Both inputs are flattened first so that multi-dimensional arrays are
    measured by their overall Euclidean norm.  (``LA.norm(x, 2)`` on a 2-D
    array would return the largest singular value instead.)  For 1-D inputs
    the result is unchanged.

    Parameters
    ----------
    speech, music : array-like
        The two signals to compare.

    Returns
    -------
    float
        ``10 * log10(||speech|| / ||music||)``.

    NOTE(review): the ratio is taken between norms, not squared norms, so the
    value is half of the conventional power-ratio dB figure.  It is left as is
    so that SMR targets used elsewhere keep their current meaning -- confirm
    which definition is intended.
    """
    speech_norm = LA.norm(np.ravel(speech))
    music_norm = LA.norm(np.ravel(music))
    SMR_db=10*np.log10(speech_norm/music_norm)
    print('SMR = {:.2f}'.format(SMR_db))
    
    return SMR_db

def SDR(s_est, s):
    """
    Compute a signal-to-distortion ratio between a reference and an estimate.

    The reference and the error ``s_est - s`` are flattened before their
    Euclidean norms are taken, so spectrograms (2-D arrays) are measured
    over all bins.  (``LA.norm(x, 2)`` on a 2-D array would return the
    largest singular value instead.)  For 1-D inputs the result is unchanged.

    Parameters
    ----------
    s_est : ndarray
        Estimated signal or spectrogram.
    s : ndarray
        Reference signal or spectrogram; must broadcast against ``s_est``.

    Returns
    -------
    float
        ``10 * log10(||s|| / ||s_est - s||)``.

    NOTE(review): as in ``SMR``, this is a ratio of norms rather than of
    squared norms, i.e. half of the conventional power-ratio dB value.
    """
    
    signal_norm = LA.norm(np.ravel(s))
    distortion_norm = LA.norm(np.ravel(np.subtract(s_est, s)))
    SDR_db=10*np.log10(signal_norm/distortion_norm)
    
    return SDR_db

In [50]:
# NOTE(review): helpers2 also provides Reconstruct, SMR and SDR (they are imported
# by name in the imports cell), so this wildcard import presumably rebinds those
# names and replaces the definitions from the previous cell -- confirm which
# versions are meant to be used.
from helpers2 import *

In [51]:
# Best 1-20 min

# Sample offsets of minute 1 and minute 20, computed with a hard-coded 44100 Hz rate.
start = 1 * 60 * 44100
end = 20 * 60 * 44100 

# Read the speech recording and keep channel 0 of the selected span.
samplerate_s, data_speech = read("/content/drive/MyDrive/Conversation.wav")
speech=data_speech[start:end,0]
length=speech.shape[0]/samplerate_s  # duration in seconds
print('Shape of the speech {} ... Length : {:.2f}s ... Sample rate : {}'.format(speech.shape[0],length,samplerate_s))

# Same span for the music recording (start/end are recomputed with identical values).
start = 1 * 60 * 44100
end = 20 * 60 * 44100 
samplerate_m, data_music = read("/content/drive/MyDrive/music.wav")
music=data_music[start:end,0]
length=music.shape[0]/samplerate_m  # duration in seconds
print('Shape of the music {} ... Length : {:.2f}s ... Sample rate : {}'.format(music.shape[0],length,samplerate_m))


Shape of the speech 50274000 ... Length : 1140.00s ... Sample rate : 44100
Shape of the music 50274000 ... Length : 1140.00s ... Sample rate : 44100


In [52]:
fs = 16000  # target sample rate

rate = samplerate_s / fs  # ratio between the speech file's sample rate and fs


# Same minute-1..minute-20 span (44100 Hz offsets) that the previous cell used
# for `speech` / `music`, so speech_t / music_t start from identical samples.
start = 1 * 60 * 44100
end = 20 * 60 * 44100


speech_t=data_speech[start : end, 0]
music_t = data_music[start : end, 0]


# Resample each excerpt to (number of samples / rate) samples.
speech_t = signal.resample(speech_t,int(speech_t.shape[0]/rate))
music_t = signal.resample(music_t,int(music_t.shape[0]/rate))
samplerate=int(samplerate_m/rate)  # `rate` comes from samplerate_s but is applied to samplerate_m here
length=music_t.shape[0]/samplerate

print('Shape of the test {} ... Length : {:.2f}s ... Sample rate : {}'.format(music_t.shape[0],length,samplerate))

# `speech` and `music` are overwritten in place, so re-running this cell
# resamples the already-resampled signals again.
speech = signal.resample(speech,int(speech.shape[0]/rate))
music = signal.resample(music,int(music.shape[0]/rate))


print('Downsampled rate = {}'.format(samplerate))

# Pass all four signals through helpers2.butter_lowpass_filter with arguments (signal, 5000, fs).
speech = butter_lowpass_filter(speech,5000,fs)
music = butter_lowpass_filter(music,5000,fs)

music_t = butter_lowpass_filter(music_t,5000,fs)
speech_t = butter_lowpass_filter(speech_t,5000,fs)

Shape of the test 18240000 ... Length : 1140.00s ... Sample rate : 16000
Downsampled rate = 16000


## Training STFT :


In [53]:
# STFT settings reused by every stft / istft call in the notebook.
WINDOW = 'hamming'
WINDOW_SIZE=480
OVERLAP = 0.6 * WINDOW_SIZE  # evaluates to the float 288.0
NFFT=512

# Magnitude spectrograms of the clean speech and clean music signals.
f,t,Y= signal.stft(speech,samplerate,window=WINDOW,nperseg=WINDOW_SIZE,noverlap=OVERLAP,nfft=NFFT)
Yabs_s=np.abs(Y)
f,t,Y= signal.stft(music,samplerate,window=WINDOW,nperseg=WINDOW_SIZE,noverlap=OVERLAP,nfft=NFFT)
Yabs_m=np.abs(Y)



# Build the mixture with helpers2.get_mixed_signal, passing SMR_db as the third argument.
SMR_db = -5
mix,speech_mix,music_mix=get_mixed_signal(speech_t,music_t,SMR_db)


# Complex STFT of the mixture and its magnitude.
f,t,Ymix= signal.stft(mix,samplerate,window=WINDOW,nperseg=WINDOW_SIZE,noverlap=OVERLAP,nfft=NFFT)
Yabs_mix=np.abs(Ymix)

Yabs_mix[Yabs_mix==0]=0.00001  # replace exact zeros with a small positive value
write("/MixX.wav", samplerate, mix.astype(np.int16))  # written at the filesystem root after casting to int16



SMR = -5.00


## Test STFT :

In [363]:
fs = 16000  # target sample rate

rate = samplerate_s / fs  # ratio between the speech file's sample rate and fs


# Test excerpt: starts at minute 24 and is cut in 30-second steps (44100 Hz offsets).
start = 24 * 60 * 44100
step = int(0.5 * 60 * 44100)

test_s = np.array([])
test_m = np.array([])

# range(1) runs once, so exactly one 30-second step of channel 0 is appended per signal.
for i in range(1):

  test_s = np.hstack([test_s,data_speech[start+i*step:start+(i+1)*step,0]])
  test_m = np.hstack([test_m,data_music[start+i*step:start+(i+1)*step,0]])


# Resample to (number of samples / rate) samples.
test_s = signal.resample(test_s,int(test_s.shape[0]/rate))
test_m = signal.resample(test_m,int(test_m.shape[0]/rate))
samplerate=int(samplerate_m/rate)
# NOTE(review): `length` is derived from music_t (the earlier excerpt), not from test_m -- confirm.
length=music_t.shape[0]/samplerate


# Pass both test signals through helpers2.butter_lowpass_filter with arguments (signal, 5000, fs).
test_s = butter_lowpass_filter(test_s,5000,fs)
test_m = butter_lowpass_filter(test_m,5000,fs)


################################################################################
# Mix the test signals with helpers2.get_mixed_signal, passing SMR_db as the third argument.
SMR_db = 5
test,speech_test,music_test=get_mixed_signal(test_s,test_m,SMR_db)


# Complex STFT of the test mixture and its magnitude.
f,t,Ytest= signal.stft(test,samplerate,window=WINDOW,nperseg=WINDOW_SIZE,noverlap=OVERLAP,nfft=NFFT)
Yabs_test=np.abs(Ytest)

Yabs_test[Yabs_test==0]=0.00001  # replace exact zeros with a small positive value


SMR = 5.00


# Train First NMF on Clean Speech :

In [284]:
def softmax(x):
  """Softmax along axis 0 (each column of a 2-D input sums to one).

  The per-column maximum is subtracted before exponentiating.  This leaves
  the result mathematically unchanged but keeps ``np.exp`` from overflowing
  to ``inf`` (and the ratio from becoming NaN) for large inputs.

  Parameters
  ----------
  x : array-like with at least one dimension

  Returns
  -------
  ndarray of the same shape as ``x``.
  """
  x = np.asarray(x)
  shifted = x - np.max(x, axis=0, keepdims=True)
  e_x = np.exp(shifted)
  return e_x / e_x.sum(axis=0)

In [56]:
Nc = 16  # number of speech atoms
Nm = 48  # number of atoms later learned from the mixture (passed to nmf as Nn)

# Fit scikit-learn NMF on the transposed clean-speech magnitudes.
model = NMF(n_components=Nc, init='nndsvd',alpha=0.0,beta_loss='frobenius',solver="mu",max_iter=200, random_state=7)
model.fit(np.transpose(Yabs_s))
Dc = np.transpose(model.components_)  # transpose back so that each column is one atom
# Min-max scale the speech dictionary; `scaler` is rebound and refit by later cells.
scaler = MinMaxScaler()
Dc = scaler.fit_transform(Dc)


# Train NMF on Noisy Speech :

In [57]:
def nmf(X, Dc, Nn, lamb=0.1, maxit=100):
    """Multiplicative-update NMF that keeps the atoms in ``Dc`` fixed and learns ``Nn`` new ones.

    Parameters
    ----------
    X : ndarray
        Matrix to factorise; its rows must match the rows of ``Dc``.
    Dc : ndarray
        Fixed dictionary columns.  Only the columns after ``Dc.shape[1]`` are
        ever overwritten in ``D``.
    Nn : int
        Number of additional columns, initialised with ``np.random.rand``.
    lamb : float
        Constant added to the denominator of the activation update.
    maxit : int
        Number of update iterations.

    Returns
    -------
    (Dnorm, H, hist)
        ``Dnorm`` is the full dictionary with every column divided by its
        Euclidean norm, ``H`` the activations, and ``hist`` the list of
        ``LA.norm(X - Dnorm @ H)`` values, one per iteration.

    Notes
    -----
    No seed is set inside the function, so the result depends on the current
    state of NumPy's global random generator.  Two shape messages are printed.
    """

    Nc = Dc.shape[1]
    H = np.random.rand(Nc+Nn, X.shape[1])

    Dn = np.random.rand(X. shape[0], Nn)
    print(f"Shape of Dc {Dc.shape} Shape of Dn {Dn.shape}")
    D = np.hstack([Dc,Dn])
    # Column-wise L2 normalisation of the stacked dictionary.
    Dnorm = D / np.sum(D**2, axis=0)**(.5)

    print(f'Dnorm shape {Dnorm.shape} and X shape {X.shape} and H shape {H.shape}')
    hist=[]
    for i in range(maxit):
        # Activation update; `lamb` enters only through the denominator.
        H = H * (np.matmul(Dnorm.T, X)) / (np.matmul(np.matmul(Dnorm.T, Dnorm), H) + lamb)
        # Dictionary update: multiplying by the all-ones square matrix replaces every
        # row with the column sums of its right-hand operand.  The update is computed
        # for all columns, but only the columns after Dc are written back.
        D[:,Dc.shape[1]:] = (Dnorm * (np.matmul(X, H.T) + Dnorm * (np.matmul(np.ones((X.shape[0], X.shape[0])), np.matmul(Dnorm, np.matmul(H, H.T)) * Dnorm))) / (np.matmul(Dnorm, np.matmul(H, H.T)) + Dnorm * (np.matmul(np.ones((X.shape[0], X.shape[0])), np.matmul(X, H.T) * Dnorm))))[:,Dc.shape[1]:]
        Dnorm = D / np.sum(D**2, axis=0)**(.5)
        hist.append(LA.norm(X-np.matmul(Dnorm,H)))
    #Dnorm[:,Dc.shape[1]:] = softmax(Dnorm[:,Dc.shape[1]:])
    return Dnorm, H,hist

In [58]:
# Learn Nm extra atoms on the mixture magnitudes while the speech atoms Dc stay fixed.
D,H,hist = nmf(Yabs_mix,Dc,Nm)

# Min-max scale the full dictionary with a fresh scaler (rebinds `scaler`).
scaler = MinMaxScaler()
D =  scaler.fit_transform(D)


Shape of Dc (257, 16) Shape of Dn (257, 48)
Dnorm shape (257, 64) and X shape (257, 95001) and H shape (64, 95001)


# Test NMF :

In [59]:
# Fit an NMF estimator on the test magnitudes, then overwrite its learned
# components with the dictionary D so that transform() computes activations
# for D rather than for the components found by fit().
model_test = NMF(n_components=Nc+Nm, init='nndsvd',alpha=100,beta_loss='frobenius',solver="mu",max_iter=200, random_state=7)
model_test.fit(np.transpose(Yabs_test))
    
model_test.components_= np.transpose(D)
G_test=np.transpose(model_test.transform(np.transpose(Yabs_test)))  # transposed back: one row per atom

In [286]:
from tqdm import tqdm



def soft(z,a,l=0.02):
  """Shrink the magnitudes of ``z`` by ``l/a`` and floor the result at zero.

  The sign of ``z`` is discarded (``np.abs`` is applied first), so the
  output is always non-negative.
  """
  threshold = l/a
  shrunk = np.abs(z) - threshold
  floor = np.zeros(z.shape[0])
  return np.maximum(shrunk, floor)
  
def warm_start_ISTA(x,W,n_components,a,K,l=0.02):
  """Frame-by-frame ISTA in which each frame starts from the previous frame's code.

  Parameters
  ----------
  x : ndarray
      Observations, one column per frame.
  W : ndarray
      Dictionary with ``W.shape[1]`` atoms.
  n_components : int
      Size of the identity matrix in the update; must equal ``W.shape[1]``.
  a : float
      The gradient step is ``1/a`` and the shrinkage threshold is ``l/a``.
  K : int
      The inner loop runs ``range(1, K)``, i.e. ``K - 1`` iterations per frame.
  l : float
      Sparsity weight handed to ``soft``.

  Returns
  -------
  ndarray of shape ``(W.shape[1], x.shape[1])``.
      Column 0 keeps its random initialisation because the frame loop starts at 1.

  Notes
  -----
  NumPy's global random generator is reseeded with 7 on every call.
  """

  np.random.seed(seed=7)
  h = np.random.rand(W.shape[1] , x.shape[1])
  T = x.shape[1]

  # These factors depend on neither the frame nor the iteration, so they are
  # built once rather than inside the innermost loop.  The arithmetic applied
  # to each frame is unchanged.
  W_t = np.transpose(W)
  step_matrix = np.identity(n_components) - (1/a)*(W_t@W)
  scaled_W_t = (1/a)*W_t

  for t_ in tqdm(range(1,T)):
    # Warm start: initialise this frame with the previous frame's code.
    h[:,t_] = h[:,t_-1]
    data_term = scaled_W_t@x[:,t_]

    for _ in range(1,K):
      z = step_matrix@h[:,t_] + data_term

      h[:,t_] = soft(z,a,l)
  return h 

In [287]:
def eval(D,G_test,Ytest):
  """Reconstruct both sources from ``D`` / ``G_test`` and report their SDR through Octave.

  Reads the notebook globals ``mask_p``, ``ns``, ``nm``, ``samplerate``,
  ``WINDOW``, ``WINDOW_SIZE``, ``OVERLAP``, ``NFFT``, ``test_s``, ``test_m``,
  ``SMR_db``, ``script`` and ``oc``.  Rewrites ``myScript.m`` and returns
  nothing.  Note that the name shadows Python's built-in ``eval``.
  """
  Sources,Masks=Reconstruct(B=D,G=G_test,Yabs=Ytest,p=mask_p,Ns=ns,Nm=nm)

  print('Reconstruction Step .... Done')

  # Same inverse-STFT settings for both sources.
  istft_options = dict(window = WINDOW,
                       nperseg=WINDOW_SIZE,
                       noverlap=OVERLAP,
                       nfft = NFFT)
  _, speech_est = signal.istft(Sources[0], samplerate, **istft_options)
  _, music_est = signal.istft(Sources[1], samplerate, **istft_options)

  # The Python-side SDR values are computed but not reported.
  SDR(s_est=speech_est,s=test_s)
  SDR(s_est=music_est, s=test_m)
  print("smr equal = {}".format(SMR_db))

  with open("myScript.m","w+") as f:
    f.write(script)

  # The Octave script echoes each SDR value itself.
  for banner, estimate, reference in (("Speech SDR \n", speech_est, test_s),
                                      ("MUSIC SDR \n", music_est, test_m)):
    print(banner)
    oc.myScript(estimate ,reference)

In [324]:
# Min-max scale D along its other axis (transpose, scale, transpose back);
# this refits the shared `scaler` object.
S =np.transpose(scaler.fit_transform(np.transpose(D)))


In [379]:
# Warm-start ISTA on the test magnitudes with the scaled dictionary S:
# n_components = 48 + 16, a = 10000000, K = 5 and l = 0.
h_ws = warm_start_ISTA(Yabs_test,S,48+16,
                    10000000,
                    5,l=0)

# Globals read inside eval(): the mask exponent and the two atom counts.
mask_p = 0.5
ns = 16
nm = 48
eval(S,h_ws,Ytest)

100%|██████████| 2500/2500 [00:01<00:00, 1383.96it/s]


Reconstruction Step .... Done
smr equal = 5
Speech SDR 

SDR =  10.436
MUSIC SDR 

SDR = -9.3511


In [None]:
# Evaluate the activations obtained from the scikit-learn transform (G_test) with mask exponent 0.5.
mask_p = 0.5
eval(D,G_test,Ytest)

In [None]:
# Evaluate G_test and the warm-start activations h_ws back to back.
# eval() has no return statement, so the tuple built here is (None, None).
eval(D,G_test,Ytest) ,eval(D,h_ws,Ytest)

# Unfolded ISTA

In [65]:
# `torch` is imported here because this cell comes before the notebook's
# `import torch` cell; without it a fresh Restart & Run All stops with a NameError.
import torch

print(torch.cuda.is_available())

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

True
cuda:0


In [66]:
import torch

In [81]:
# Unpack the NMF archive stored on Drive; the log below shows it creating an NMF/ folder of .npy files.
!unrar x -Y "/content/drive/MyDrive/NMF.rar" 


UNRAR 5.50 freeware      Copyright (c) 1993-2017 Alexander Roshal


Extracting from /content/drive/MyDrive/NMF.rar

Creating    NMF                                                       OK
Extracting  NMF/B1616.npy                                                  2%  OK 
Extracting  NMF/B1632.npy                                                  6%  OK 
Extracting  NMF/B1648.npy                                                 12%  OK 
Extracting  NMF/B2424.npy                                                 16%  OK 
Extracting  NMF/B2432.npy                                                 20%  OK 
Extracting  NMF/B3216.npy                                                 24%  OK 
Extracting  NMF/B3232.npy                                                 29%  OK 
Extracting  NMF/B3232comp.npy                                             34%  OK 
Extracting  NMF/B3248.npy                                                

In [175]:
# Concatenate the saved speech and music dictionaries column-wise.  This rebinds
# `D`, replacing the dictionary learned earlier in the notebook.
D = np.concatenate([np.load("NMF/Bspeech16.npy" ),np.load("NMF/Bmusic48.npy" )],axis=1)

In [272]:
# Same test-excerpt preparation as the earlier "Test STFT" cell, here with SMR_db = -5.
fs = 16000  # target sample rate

rate = samplerate_s / fs  # ratio between the speech file's sample rate and fs


# Test excerpt: starts at minute 24 and is cut in 30-second steps (44100 Hz offsets).
start = 24 * 60 * 44100
step = int(0.5 * 60 * 44100)

test_s = np.array([])
test_m = np.array([])

# range(1) runs once, so exactly one 30-second step of channel 0 is appended per signal.
for i in range(1):

  test_s = np.hstack([test_s,data_speech[start+i*step:start+(i+1)*step,0]])
  test_m = np.hstack([test_m,data_music[start+i*step:start+(i+1)*step,0]])


# Resample to (number of samples / rate) samples.
test_s = signal.resample(test_s,int(test_s.shape[0]/rate))
test_m = signal.resample(test_m,int(test_m.shape[0]/rate))
samplerate=int(samplerate_m/rate)
# NOTE(review): `length` is derived from music_t (the earlier excerpt), not from test_m -- confirm.
length=music_t.shape[0]/samplerate


# Pass both test signals through helpers2.butter_lowpass_filter with arguments (signal, 5000, fs).
test_s = butter_lowpass_filter(test_s,5000,fs)
test_m = butter_lowpass_filter(test_m,5000,fs)


################################################################################
# Mix the test signals with helpers2.get_mixed_signal, passing SMR_db as the third argument.
SMR_db = -5
test,speech_test,music_test=get_mixed_signal(test_s,test_m,SMR_db)


# Complex STFT of the test mixture and its magnitude.
f,t,Ytest= signal.stft(test,samplerate,window=WINDOW,nperseg=WINDOW_SIZE,noverlap=OVERLAP,nfft=NFFT)
Yabs_test=np.abs(Ytest)

Yabs_test[Yabs_test==0]=0.00001  # replace exact zeros with a small positive value


SMR = -5.00


In [273]:
# Masks from the loaded dictionary D and the earlier activations G_test, with exponent 2.
Sources,Masks=Reconstruct(B=D,G=G_test,Ns=Dc.shape[1],Nm=Nm,Yabs=Ytest,p=2)
M = Masks[0]  # mask of the first source (the first Dc.shape[1] atoms)
# Magnitude spectrogram of the clean test speech; Unfolded_ISTA reads Y_clean as a global.
f,t,Y= signal.stft(test_s,samplerate,window=WINDOW,nperseg=WINDOW_SIZE,noverlap=OVERLAP,nfft=NFFT)
Y_clean=np.abs(Y)
Y_clean[Y_clean==0]=0.00001  # replace exact zeros with a small positive value

X = Yabs_test     # mixture magnitudes
X_cmplx = Ytest   # complex mixture STFT
W =scaler.fit_transform(D)  # refits the shared `scaler` on D
# NOTE(review): the Unfolded_ISTA calls further down pass their own alpha, K and
# lambd arguments, so these three globals do not appear to be used by them.
alpha = 100
K = 3
lambd = 0.01
torch.autograd.set_detect_anomaly(False)


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7fd001246a10>

In [274]:
def eval_dnn_2(D,G_test,Ytest,p):
  """Reconstruct both sources with mask exponent ``p`` and return their Octave SDR values.

  The dictionary is split as 16 speech atoms followed by 48 music atoms.
  Reads the notebook globals ``samplerate``, ``WINDOW``, ``WINDOW_SIZE``,
  ``OVERLAP``, ``NFFT``, ``test_s``, ``test_m``, ``script`` and ``oc``, and
  rewrites ``myScript.m`` before calling it.

  Returns
  -------
  (speech_sdr, music_sdr) as returned by ``oc.myScript``.
  """
  Sources,Masks=Reconstruct(B=D,G=G_test,Ns=16,Nm=48,Yabs=Ytest,p=p)

  print('Reconstruction Step .... Done')

  # Same inverse-STFT settings for both sources.
  istft_options = dict(window = WINDOW,
                       nperseg=WINDOW_SIZE,
                       noverlap=OVERLAP,
                       nfft = NFFT)
  _, speech_est = signal.istft(Sources[0], samplerate, **istft_options)
  _, music_est = signal.istft(Sources[1], samplerate, **istft_options)

  with open("myScript.m","w+") as f:
    f.write(script)
  print("DNN Results \n")
  speech_sdr = oc.myScript(speech_est ,test_s)
  music_sdr = oc.myScript(music_est ,test_m)

  return speech_sdr,music_sdr

In [278]:
def speech_mask(B,G,Ns,p):
    """Differentiable ratio mask for the first ``Ns`` atoms.

    Parameters
    ----------
    B : torch.Tensor
        Dictionary; the first ``Ns`` columns are treated as speech atoms.
    G : torch.Tensor
        Activations; the first ``Ns`` rows belong to the speech atoms.
    Ns : int
        Number of speech atoms.
    p : float
        Exponent applied to each source model.

    Returns
    -------
    torch.Tensor
        ``speech**p / (speech**p + other**p + 0.00001)``; the small constant
        keeps the division finite when both models are zero.
    """
    # The speech term is computed once and reused in numerator and denominator.
    speech_term = torch.pow(torch.matmul(B[:,:Ns],G[:Ns,:]),p)
    other_term = torch.pow(torch.matmul(B[:,Ns:],G[Ns:,:]),p)

    return speech_term/(speech_term+other_term+0.00001)


def sigmoid(x):
  """Apply the element-wise logistic sigmoid to ``x``."""
  activation = torch.nn.Sigmoid()
  return activation(x)

def soft(z,a,l=0.02):
  """Shrink the magnitudes of ``z`` by ``l/a`` and floor the result at zero.

  The zero tensor is created with the same shape, dtype and device as the
  shrunk values, so the function no longer depends on the notebook-level
  ``device`` variable and accepts tensors of any shape.  The sign of ``z``
  is discarded (``torch.abs`` is applied first), so the output is
  non-negative.
  """
  shrunk = torch.abs(z)-l/a
  return torch.maximum(shrunk,torch.zeros_like(shrunk))


def error(X,M,Y):
  """Mean absolute error between the target ``Y`` and the masked input ``M*X``."""
  criterion = torch.nn.L1Loss()
  masked_input = M*X
  return criterion(Y,masked_input)

def Unfolded_ISTA(X, W, alpha, K, lambd,epochs,learning_rate):
  """Train a K-layer unfolded ISTA network frame by frame and track SDR per epoch.

  Parameters
  ----------
  X : ndarray
      Mixture magnitudes, one column per frame.
  W : ndarray
      Initial dictionary; it is copied K times, one copy per layer.
  alpha : float
      The gradient step is ``1/alpha``; also passed to ``soft`` as the threshold divisor.
  K : int
      Number of layer copies.  The layer loop runs ``range(1, K)``.
  lambd : float
      Sparsity weight passed to ``soft``.
  epochs : int
      Number of passes over the frames.
  learning_rate : float
      SGD learning rate.

  Returns
  -------
  (d, h, train_hist, sdr_speech, sdr_music)
      ``d`` is the non-negative last-layer dictionary and ``h`` the last-layer
      activations (both NumPy arrays with exact zeros replaced by 1e-6).
      ``train_hist[e]`` holds the loss of the last frame processed in epoch
      ``e``, because the entry is overwritten on every frame.  The two SDR
      lists hold one ``eval_dnn_2`` result per epoch.

  Notes
  -----
  Reads the notebook globals ``device``, ``Y_clean``, ``Ytest`` and ``mask_p``
  and calls ``soft``, ``speech_mask``, ``error`` and ``eval_dnn_2``.

  Changes with respect to the previous revision:
  * the returned activations are ``h[K-1]``, the layer that is also evaluated
    each epoch; the old ``h[k-1]`` used the leftover loop variable and so
    returned the penultimate layer;
  * the per-layer tensors are built for any ``K`` instead of being hard-coded
    as five, and the identity matrix follows the number of atoms instead of 64;
  * unused locals were removed.

  NOTE(review): the ISTA layers read ``W[k-1]`` (fresh views of the tiled
  tensor) and write ``H`` in place, so only the last layer's tensor appears
  to reach the loss, through ``speech_mask``.  Confirm whether the earlier
  layers are meant to receive gradients.
  """
  train_hist = np.zeros(epochs)
  sdr_speech= []
  sdr_music= []
  torch.manual_seed(7)
  X = torch.from_numpy(X).float().to(device)
  W = torch.from_numpy(W).float().to(device)
  W = torch.tile(W, (K,1,1)).to(device)
  alpha_list = list(torch.tile(torch.tensor([alpha]).float(),(K,)))
  H = torch.rand(W.shape[2],X.shape[1]).to(device)
  H = torch.tile(H, (K,1,1)).to(device)
  Y = torch.from_numpy(Y_clean).float().to(device)

  # One view of the tiled tensor per layer, all registered with the optimizer.
  layer_weights = [W[i] for i in range(K)]
  for layer_weight in layer_weights:
    layer_weight.requires_grad = True
  W_last = layer_weights[K-1]

  sgd = torch.optim.SGD(layer_weights,lr=learning_rate)

  # Constants that do not change between frames.
  identity = torch.eye(W.shape[2]).to(device)
  zero = torch.tensor([0]).to(device)

  for e in range(epochs):
    for t in tqdm(np.arange(1,X.shape[1]),position=0, leave=True):
      
      for k in range(1,K):
        # Layer 1 warm-starts from the last layer's code of the previous frame;
        # deeper layers continue from the layer below on the current frame.
        previous = H[K-1,:,t-1] if k == 1 else H[k-1,:,t]
        z = (identity - (1/alpha_list[k])*(W[k-1].t()@W[k-1]))@previous + (1/alpha_list[k])*W[k-1].t()@X[:,t]

        H[k,:,t] = soft(z,alpha,lambd)

        H[k] = torch.maximum(H[k],zero)

      sgd.zero_grad()      

      mask = speech_mask(torch.maximum(W_last,zero),H[K-1],16,2)
      loss = error(X[:,t],mask[:,t],Y[:,t])
      train_hist[e] = loss.item()
      loss.backward()

      sgd.step()
    print(f'Epoch {e} ... Loss = {loss.item()}')
    d = torch.maximum(W_last,zero).cpu().detach().numpy()
    d[d == 0] = 0.000001
    h = H.cpu().detach().numpy()
    h[h == 0] = 0.000001
    speech,music = eval_dnn_2(d,h[K-1],Ytest,mask_p+1)
    sdr_speech.append(speech)
    sdr_music.append(music)
  return d, h[K-1],train_hist,sdr_speech,sdr_music


In [279]:
# mask_p is read as a global inside Unfolded_ISTA, which evaluates with exponent mask_p+1.
mask_p  = 2
# Single-value sweep over alpha, with K=5, lambd=0, 20 epochs and learning rate 0.005.
for a in [50]:
  d,h,train_hist,sdr_speech_,sdr_music_ = Unfolded_ISTA(X, W, a, 5, 0,epochs=20,learning_rate= 0.005)
  print("finish")

100%|██████████| 2500/2500 [00:06<00:00, 360.83it/s]


Epoch 0 ... Loss = 54.53267288208008
Reconstruction Step .... Done
DNN Results 

SDR =  2.2969
SDR =  12.302


100%|██████████| 2500/2500 [00:06<00:00, 363.75it/s]


Epoch 1 ... Loss = 53.75046157836914
Reconstruction Step .... Done
DNN Results 

SDR =  2.9717
SDR =  12.635


100%|██████████| 2500/2500 [00:06<00:00, 359.73it/s]


Epoch 2 ... Loss = 53.153263092041016
Reconstruction Step .... Done
DNN Results 

SDR =  3.2767
SDR =  12.895


100%|██████████| 2500/2500 [00:06<00:00, 364.87it/s]


Epoch 3 ... Loss = 52.783226013183594
Reconstruction Step .... Done
DNN Results 

SDR =  3.5787
SDR =  13.138


100%|██████████| 2500/2500 [00:06<00:00, 362.97it/s]


Epoch 4 ... Loss = 52.559661865234375
Reconstruction Step .... Done
DNN Results 

SDR =  3.8691
SDR =  13.340


100%|██████████| 2500/2500 [00:06<00:00, 365.57it/s]


Epoch 5 ... Loss = 52.4412841796875
Reconstruction Step .... Done
DNN Results 

SDR =  4.0648
SDR =  13.572


100%|██████████| 2500/2500 [00:06<00:00, 366.48it/s]


Epoch 6 ... Loss = 52.36884689331055
Reconstruction Step .... Done
DNN Results 

SDR =  4.1551
SDR =  13.690


100%|██████████| 2500/2500 [00:06<00:00, 366.10it/s]


Epoch 7 ... Loss = 52.306854248046875
Reconstruction Step .... Done
DNN Results 

SDR =  4.1003
SDR =  13.742


100%|██████████| 2500/2500 [00:06<00:00, 360.54it/s]


Epoch 8 ... Loss = 52.24629211425781
Reconstruction Step .... Done
DNN Results 

SDR =  4.0572
SDR =  13.795


100%|██████████| 2500/2500 [00:06<00:00, 364.19it/s]


Epoch 9 ... Loss = 52.2668342590332
Reconstruction Step .... Done
DNN Results 

SDR =  3.9693
SDR =  13.811


100%|██████████| 2500/2500 [00:06<00:00, 358.80it/s]


Epoch 10 ... Loss = 52.27647018432617
Reconstruction Step .... Done
DNN Results 

SDR =  3.9105
SDR =  13.843


 39%|███▉      | 973/2500 [00:02<00:04, 352.26it/s]


KeyboardInterrupt: ignored

In [267]:
# Evaluate the dictionary and activations returned by Unfolded_ISTA with mask exponent 3.
eval_dnn_2(d,h,Ytest,3)

Reconstruction Step .... Done
DNN Results 

SDR =  7.9950
SDR =  7.0636


(7.995037172516319, 7.063577138710981)

In [268]:
# Display the per-epoch speech SDR values collected by Unfolded_ISTA.
sdr_speech_

[6.342497236348114,
 7.408783588809818,
 7.911760926417247,
 7.993962394492061,
 8.047948671298094,
 7.944005395623446,
 8.066586997401101,
 8.016634532765003,
 8.100674969727715,
 8.107569859716676,
 8.11473964462759,
 8.104407390518148,
 8.105885401146102,
 8.079206882397965,
 8.072005333951031,
 8.084756619604605,
 8.084160610531136,
 8.08760307134802,
 8.089579888242287,
 8.079307120595608]

In [None]:
# Plot the per-epoch speech/music SDR (left axis) against the training loss
# (right axis) and save the figure to Drive.
hist1 = train_hist
hist2 = sdr_speech_
hist3 = sdr_music_



plt.style.use('seaborn-paper')

fig,ax=plt.subplots(1,1,sharex=True,figsize=(15,8),dpi=400)
ax2=ax.twinx()

x1=ax.plot(hist2,'b',linewidth=2)
x2=ax.plot(hist3,'g',linewidth=2)

x3=ax2.plot(hist1,'r',linewidth=2)

# The right-hand axis shows the loss.  The earlier 'SDR ' label on ax2 was
# overwritten immediately and has been dropped.
ax.set_ylabel('SDR ',fontsize=20)
ax2.set_ylabel('Loss',fontsize=20)


ax.set_xlabel('Epochs',fontsize=20)


ax.tick_params(axis='both', which='major', labelsize=12)
ax2.tick_params(axis='both', which='major', labelsize=12)


ax.grid(linestyle='--',linewidth=1)
# Optional fixed axis limits (disabled):
# ax2.set_ylim(6.3, 7.1)
# ax.set_ylim(0, 0.4)

ax.legend(x1+x2+x3,['SDR Speech','SDR Music','Loss'],loc=4, prop={'size': 16},bbox_to_anchor=(1.25, 0.88),fancybox=True, shadow=True,borderaxespad=0.1,title=' SMR = 0',title_fontsize=15)

# Create the output folder first: the cell that makes it with `mkdir` comes
# after this one, so savefig would fail on a fresh run.
os.makedirs('/content/drive/MyDrive/ISTA_results', exist_ok=True)
plt.savefig('/content/drive/MyDrive/ISTA_results/SMR_0',bbox_inches='tight')

In [None]:
!mkdir "/content/drive/MyDrive/ISTA_results"

In [None]:
def speech_mask(B,G,Ns,p):
    """Differentiable ratio mask for the first ``Ns`` atoms.

    Parameters
    ----------
    B : torch.Tensor
        Dictionary; the first ``Ns`` columns are treated as speech atoms.
    G : torch.Tensor
        Activations; the first ``Ns`` rows belong to the speech atoms.
    Ns : int
        Number of speech atoms.
    p : float
        Exponent applied to each source model.

    Returns
    -------
    torch.Tensor
        ``speech**p / (speech**p + other**p + 0.00001)``; the small constant
        keeps the division finite when both models are zero.
    """
    # The speech term is computed once and reused in numerator and denominator.
    speech_term = torch.pow(torch.matmul(B[:,:Ns],G[:Ns,:]),p)
    other_term = torch.pow(torch.matmul(B[:,Ns:],G[Ns:,:]),p)

    return speech_term/(speech_term+other_term+0.00001)


def sigmoid(x):
  """Apply the element-wise logistic sigmoid to ``x``."""
  activation = torch.nn.Sigmoid()
  return activation(x)

def soft(z,a,l=0.02):
  """Shrink the magnitudes of ``z`` by ``l/a`` and floor the result at zero.

  The zero tensor is created with the same shape, dtype and device as the
  shrunk values, so the function no longer depends on the notebook-level
  ``device`` variable and accepts tensors of any shape.  The sign of ``z``
  is discarded (``torch.abs`` is applied first), so the output is
  non-negative.
  """
  shrunk = torch.abs(z)-l/a
  return torch.maximum(shrunk,torch.zeros_like(shrunk))


def error(X,M,Y):
  """Mean absolute error between the target ``Y`` and the masked input ``M*X``."""
  criterion = torch.nn.L1Loss()
  masked_input = M*X
  return criterion(Y,masked_input)

def Unfolded_ISTA(X, W, alpha, K, lambd,epochs,learning_rate):
  """Train a K-layer unfolded ISTA network frame by frame and track SDR per epoch.

  Parameters
  ----------
  X : ndarray
      Mixture magnitudes, one column per frame.
  W : ndarray
      Initial dictionary; it is copied K times, one copy per layer.
  alpha : float
      The gradient step is ``1/alpha``; also passed to ``soft`` as the threshold divisor.
  K : int
      Number of layer copies.  The layer loop runs ``range(1, K)``.
  lambd : float
      Sparsity weight passed to ``soft``.
  epochs : int
      Number of passes over the frames.
  learning_rate : float
      SGD learning rate.

  Returns
  -------
  (d, h, train_hist, sdr_speech, sdr_music)
      ``d`` is the non-negative last-layer dictionary and ``h`` the last-layer
      activations (both NumPy arrays with exact zeros replaced by 1e-6).
      ``train_hist[e]`` holds the loss of the last frame processed in epoch
      ``e``, because the entry is overwritten on every frame.  The two SDR
      lists hold one ``eval_dnn_2`` result per epoch (mask exponent 2).

  Notes
  -----
  Reads the notebook globals ``device``, ``Y_clean`` and ``Ytest`` and calls
  ``soft``, ``speech_mask``, ``error`` and ``eval_dnn_2``.

  Changes with respect to the previous revision:
  * the returned activations are ``h[K-1]``, the layer that is also evaluated
    each epoch; the old ``h[k-1]`` used the leftover loop variable and so
    returned the penultimate layer;
  * the per-layer tensors are built for any ``K`` instead of being hard-coded
    as three (the last one, ``W[K-1]``, feeds the mask), and the identity
    matrix follows the number of atoms instead of 64;
  * unused locals were removed.

  NOTE(review): the ISTA layers read ``W[k-1]`` (fresh views of the tiled
  tensor) and write ``H`` in place, so only the last layer's tensor appears
  to reach the loss, through ``speech_mask``.  Confirm whether the earlier
  layers are meant to receive gradients.
  """
  train_hist = np.zeros(epochs)
  sdr_speech= []
  sdr_music= []
  torch.manual_seed(7)
  X = torch.from_numpy(X).float().to(device)
  W = torch.from_numpy(W).float().to(device)
  W = torch.tile(W, (K,1,1)).to(device)
  alpha_list = list(torch.tile(torch.tensor([alpha]).float(),(K,)))
  H = torch.rand(W.shape[2],X.shape[1]).to(device)
  H = torch.tile(H, (K,1,1)).to(device)
  Y = torch.from_numpy(Y_clean).float().to(device)

  # One view of the tiled tensor per layer, all registered with the optimizer.
  layer_weights = [W[i] for i in range(K)]
  for layer_weight in layer_weights:
    layer_weight.requires_grad = True
  W_last = layer_weights[K-1]

  sgd = torch.optim.SGD(layer_weights,lr=learning_rate)

  # Constants that do not change between frames.
  identity = torch.eye(W.shape[2]).to(device)
  zero = torch.tensor([0]).to(device)

  for e in range(epochs):
    for t in tqdm(np.arange(1,X.shape[1]),position=0, leave=True):
      
      for k in range(1,K):
        # Layer 1 warm-starts from the last layer's code of the previous frame;
        # deeper layers continue from the layer below on the current frame.
        previous = H[K-1,:,t-1] if k == 1 else H[k-1,:,t]
        z = (identity - (1/alpha_list[k])*(W[k-1].t()@W[k-1]))@previous + (1/alpha_list[k])*W[k-1].t()@X[:,t]

        H[k,:,t] = soft(z,alpha,lambd)

        H[k] = torch.maximum(H[k],zero)

      sgd.zero_grad()      

      mask = speech_mask(torch.maximum(W_last,zero),H[K-1],16,2)
      loss = error(X[:,t],mask[:,t],Y[:,t])
      train_hist[e] = loss.item()
      loss.backward()

      sgd.step()
    print(f'Epoch {e} ... Loss = {loss.item()}')
    d = torch.maximum(W_last,zero).cpu().detach().numpy()
    d[d == 0] = 0.000001
    h = H.cpu().detach().numpy()
    h[h == 0] = 0.000001
    speech,music = eval_dnn_2(d,h[K-1],Ytest,2)
    sdr_speech.append(speech)
    sdr_music.append(music)
  return d, h[K-1],train_hist,sdr_speech,sdr_music


In [None]:
# Single-value sweep over alpha, with K=3, lambd=0.02, 20 epochs and learning rate 0.005.
for a in [50]:
  d,h,train_hist,sdr_speech_,sdr_music_ = Unfolded_ISTA(X, W, a, 3, 0.02,epochs=20,learning_rate= 0.005)
  print("finish")

100%|██████████| 2500/2500 [00:04<00:00, 503.36it/s]


Epoch 0 ... Loss = 11.522324562072754
Reconstruction Step .... Done
DNN Results 

SDR =  7.1179
SDR =  4.1478


100%|██████████| 2500/2500 [00:05<00:00, 492.21it/s]


Epoch 1 ... Loss = 8.572219848632812
Reconstruction Step .... Done
DNN Results 

SDR =  7.4079
SDR =  4.4143


100%|██████████| 2500/2500 [00:05<00:00, 495.31it/s]


Epoch 2 ... Loss = 7.408426284790039
Reconstruction Step .... Done
DNN Results 

SDR =  7.5370
SDR =  5.1596


100%|██████████| 2500/2500 [00:05<00:00, 487.45it/s]


Epoch 3 ... Loss = 6.969976902008057
Reconstruction Step .... Done
DNN Results 

SDR =  7.6345
SDR =  5.2467


100%|██████████| 2500/2500 [00:05<00:00, 496.34it/s]


Epoch 4 ... Loss = 6.769251823425293
Reconstruction Step .... Done
DNN Results 

SDR =  7.7035
SDR =  5.3739


100%|██████████| 2500/2500 [00:04<00:00, 501.28it/s]


Epoch 5 ... Loss = 6.554028034210205
Reconstruction Step .... Done
DNN Results 

SDR =  7.7254
SDR =  5.4228


100%|██████████| 2500/2500 [00:05<00:00, 495.91it/s]


Epoch 6 ... Loss = 6.425950527191162
Reconstruction Step .... Done
DNN Results 

SDR =  7.7357
SDR =  5.4348


100%|██████████| 2500/2500 [00:05<00:00, 491.90it/s]


Epoch 7 ... Loss = 6.312808513641357
Reconstruction Step .... Done
DNN Results 

SDR =  7.7566
SDR =  5.5380


100%|██████████| 2500/2500 [00:05<00:00, 494.96it/s]


Epoch 8 ... Loss = 6.234522342681885
Reconstruction Step .... Done
DNN Results 

SDR =  7.7866
SDR =  5.5787


100%|██████████| 2500/2500 [00:05<00:00, 479.82it/s]


Epoch 9 ... Loss = 6.209493160247803
Reconstruction Step .... Done
DNN Results 

SDR =  7.8045
SDR =  5.6300


100%|██████████| 2500/2500 [00:05<00:00, 496.22it/s]


Epoch 10 ... Loss = 6.165719032287598
Reconstruction Step .... Done
DNN Results 

SDR =  7.8051
SDR =  5.6429


100%|██████████| 2500/2500 [00:05<00:00, 492.39it/s]


Epoch 11 ... Loss = 6.1311492919921875
Reconstruction Step .... Done
DNN Results 

SDR =  7.8250
SDR =  5.6764


100%|██████████| 2500/2500 [00:05<00:00, 494.74it/s]


Epoch 12 ... Loss = 6.115301609039307
Reconstruction Step .... Done
DNN Results 

SDR =  7.8301
SDR =  5.6760


100%|██████████| 2500/2500 [00:05<00:00, 493.42it/s]


Epoch 13 ... Loss = 6.133864402770996
Reconstruction Step .... Done
DNN Results 

SDR =  7.8421
SDR =  5.7006


100%|██████████| 2500/2500 [00:05<00:00, 493.19it/s]


Epoch 14 ... Loss = 6.128572940826416
Reconstruction Step .... Done
DNN Results 

SDR =  7.8487
SDR =  5.7132


100%|██████████| 2500/2500 [00:05<00:00, 496.09it/s]


Epoch 15 ... Loss = 6.117512226104736
Reconstruction Step .... Done
DNN Results 

SDR =  7.8424



 32%|███▏      | 795/2500 [00:01<00:03, 468.42it/s]


KeyboardInterrupt: ignored