In [1]:
# Mount Google Drive; later cells read WAV files from /content/drive/MyDrive.
from google.colab import drive  # `drive` helper from the google.colab package
 
ROOT = "/content/drive"     # mount point handed to drive.mount below
print(ROOT)                 # echo the mount point (optional)
 
drive.mount(ROOT)           # mount Google Drive at ROOT


/content/drive
Mounted at /content/drive


## At SMR = 0 dB, this notebook gave SDR = 2.57 dB and 2.45 dB

In [2]:
# Install Octave, its development files and the oct2py bridge, then open an
# Octave session. `oc` is used later by eval() to call the Octave script.
!apt install octave
!apt install liboctave-dev  # development files
!pip3 install oct2py
from oct2py import Oct2Py
oc = Oct2Py()
script = '''
  function [SDR,SIR,SAR,perm]=bss_eval_sources(se,s)
%%% Errors %%%
if nargin<2, error('Not enough input arguments.'); end
[nsrc,nsampl]=size(se);
[nsrc2,nsampl2]=size(s);
if nsrc2~=nsrc, error('The number of estimated sources and reference sources must be equal.'); end
if nsampl2~=nsampl, error('The estimated sources and reference sources must have the same duration.'); end

%%% Performance criteria %%%
% Computation of the criteria for all possible pair matches
SDR=zeros(nsrc,nsrc);
SIR=zeros(nsrc,nsrc);
SAR=zeros(nsrc,nsrc);
for jest=1:nsrc,
    for jtrue=1:nsrc,
        [s_true,e_spat,e_interf,e_artif]=bss_decomp_mtifilt(se(jest,:),s,jtrue,512);
        [SDR(jest,jtrue),SIR(jest,jtrue),SAR(jest,jtrue)]=bss_source_crit(s_true,e_spat,e_interf,e_artif);
    end
end
% Selection of the best ordering
perm=perms(1:nsrc);
nperm=size(perm,1);
meanSIR=zeros(nperm,1);
for p=1:nperm,
    meanSIR(p)=mean(SIR((0:nsrc-1)*nsrc+perm(p,:)));
end
[meanSIR,popt]=max(meanSIR);
perm=perm(popt,:).';
SDR=SDR((0:nsrc-1).'*nsrc+perm);
SIR=SIR((0:nsrc-1).'*nsrc+perm);
SAR=SAR((0:nsrc-1).'*nsrc+perm);

return;


function [s_true,e_spat,e_interf,e_artif]=bss_decomp_mtifilt(se,s,j,flen)

if nargin<4, error('Not enough input arguments.'); end
[nchan2,nsampl2]=size(se);
[nsrc,nsampl,nchan]=size(s);
if nchan2~=nchan, error('The number of channels of the true source images and the estimated source image must be equal.'); end
if nsampl2~=nsampl, error('The duration of the true source images and the estimated source image must be equal.'); end

%%% Decomposition %%%
% True source image
s_true=[reshape(s(j,:,:),nsampl,nchan).',zeros(nchan,flen-1)];
% Spatial (or filtering) distortion
e_spat=project(se,s(j,:,:),flen)-s_true;
% Interference
e_interf=project(se,s,flen)-s_true-e_spat;
% Artifacts
e_artif=[se,zeros(nchan,flen-1)]-s_true-e_spat-e_interf;

return;


function sproj=project(se,s,flen)

% SPROJ Least-squares projection of each channel of se on the subspace
% spanned by delayed versions of the channels of s, with delays between 0
% and flen-1

[nsrc,nsampl,nchan]=size(s);
s=reshape(permute(s,[3 1 2]),nchan*nsrc,nsampl);

%%% Computing coefficients of least squares problem via FFT %%%
% Zero padding and FFT of input data
s=[s,zeros(nchan*nsrc,flen-1)];
se=[se,zeros(nchan,flen-1)];
fftlen=2^nextpow2(nsampl+flen-1);
sf=fft(s,fftlen,2);
sef=fft(se,fftlen,2);
% Inner products between delayed versions of s
G=zeros(nchan*nsrc*flen);
for k1=0:nchan*nsrc-1,
    for k2=0:k1,
        ssf=sf(k1+1,:).*conj(sf(k2+1,:));
        ssf=real(ifft(ssf));
        ss=toeplitz(ssf([1 fftlen:-1:fftlen-flen+2]),ssf(1:flen));
        G(k1*flen+1:k1*flen+flen,k2*flen+1:k2*flen+flen)=ss;
        G(k2*flen+1:k2*flen+flen,k1*flen+1:k1*flen+flen)=ss.';
    end
end
% Inner products between se and delayed versions of s
D=zeros(nchan*nsrc*flen,nchan);
for k=0:nchan*nsrc-1,
    for i=1:nchan,
        ssef=sf(k+1,:).*conj(sef(i,:));
        ssef=real(ifft(ssef,[],2));
        D(k*flen+1:k*flen+flen,i)=ssef(:,[1 fftlen:-1:fftlen-flen+2]).';
    end
end

%%% Computing projection %%%
% Distortion filters
C=G\D;
C=reshape(C,flen,nchan*nsrc,nchan);
% Filtering
sproj=zeros(nchan,nsampl+flen-1);
for k=1:nchan*nsrc,
    for i=1:nchan,
        sproj(i,:)=sproj(i,:)+fftfilt(C(:,k,i).',s(k,:));
    end
end

return;



function [SDR,SIR,SAR]=bss_source_crit(s_true,e_spat,e_interf,e_artif)


if nargin<4, error('Not enough input arguments.'); end
[nchant,nsamplt]=size(s_true);
[nchans,nsampls]=size(e_spat);
[nchani,nsampli]=size(e_interf);
[nchana,nsampla]=size(e_artif);
if ~((nchant==nchans)&&(nchant==nchani)&&(nchant==nchana)), error('All the components must have the same number of channels.'); end
if ~((nsamplt==nsampls)&&(nsamplt==nsampli)&&(nsamplt==nsampla)), error('All the components must have the same duration.'); end

%%% Energy ratios %%%
s_filt=s_true+e_spat;
% SDR
SDR=10*log10(sum(sum(s_filt.^2))/sum(sum((e_interf+e_artif).^2)))
% SIR
SIR=10*log10(sum(sum(s_filt.^2))/sum(sum(e_interf.^2)));
% SA
SAR=10*log10(sum(sum((s_filt+e_interf).^2))/sum(sum(e_artif.^2)));
return;

         '''

# Save the Octave source held in `script` as myScript.m in the working
# directory so it can be called through the oct2py session.
with open("myScript.m","w+") as f:
    f.write(script)

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following additional packages will be installed:
  aglfn epstool fonts-droid-fallback fonts-noto-mono ghostscript gnuplot-data
  gnuplot-qt gsfonts imagemagick-6-common info install-info libamd2
  libauthen-sasl-perl libcamd2 libccolamd2 libcholmod3 libcolamd2
  libcupsfilters1 libcupsimage2 libcxsparse3 libdata-dump-perl libemf1
  libencode-locale-perl libfftw3-single3 libfile-listing-perl libfltk-gl1.3
  libfltk1.3 libfont-afm-perl libgail-common libgail18 libglpk40
  libgraphicsmagick++-q16-12 libgraphicsmagick-q16-3 libgs9 libgs9-common
  libgtk2.0-0 libgtk2.0-bin libgtk2.0-common libhtml-form-perl
  libhtml-format-perl libhtml-parser-perl libhtml-tagset-perl
  libhtml-tree-perl libhttp-cookies-perl libhttp-daemon-perl libhttp-date-perl
  libhttp-message-perl libhttp-negotiate-perl libijs-0.35 libio-html-perl
  libio-socket-ssl-perl libjbig2dec0 liblqr-1-0 liblua5.3-0
  liblwp-mediat

In [6]:
import numpy as np
import os
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter, freqz

from scipy.io.wavfile import read, write
from scipy import signal
from sklearn.decomposition import NMF
from sklearn.preprocessing import MinMaxScaler

from numpy import linalg as LA
from numpy.linalg import inv
from helpers2 import Reconstruct, Viz_Y,SMR,get_mixed_signal,SDR,ReconstructSoft,butter_lowpass_filter
import seaborn as sns
import warnings
import math
from tqdm import tqdm
warnings.simplefilter('ignore')

In [7]:
def Reconstruct(B,G,Ns,Nm,Yabs,p):
    """Split a spectrogram into two sources using NMF-based ratio masks.

    The first ``Ns`` columns of ``B`` together with the first ``Ns`` rows of
    ``G`` model source 0; the remaining columns/rows model source 1.

    Parameters
    ----------
    B : ndarray
        Dictionary matrix (atoms in columns).
    G : ndarray
        Activation matrix (one row per atom of ``B``).
    Ns : int
        Number of leading atoms assigned to source 0.
    Nm : int
        Accepted for interface compatibility; not used in the computation.
    Yabs : ndarray
        Array multiplied element-wise by each mask.
    p : number
        Exponent applied to each source model before forming the ratio.

    Returns
    -------
    (Sources, Masks) : tuple of two 2-element lists
        ``Masks[i] = model_i**p / (model_0**p + model_1**p)`` and
        ``Sources[i] = Masks[i] * Yabs``.  No guard is applied when the
        denominator is zero.
    """
    # Per-source models raised to the power p (source 0 first).
    powered_models = [
        np.power(np.matmul(B[:, :Ns], G[:Ns, :]), p),
        np.power(np.matmul(B[:, Ns:], G[Ns:, :]), p),
    ]
    total = powered_models[0] + powered_models[1]

    Masks = [model / total for model in powered_models]
    Sources = [np.multiply(mask, Yabs) for mask in Masks]

    return Sources, Masks




def SMR(speech, music):
    """Return (and print) the speech-to-music ratio in decibels.

    The ratio is a power ratio: ``10*log10(sum(|speech|^2) / sum(|music|^2))``.
    The previous version applied ``10*log10`` to the ratio of L2 norms
    (amplitudes), which reports half of the true dB value.

    Inputs are flattened before the energy is computed, so a 2-D array is
    treated as a collection of samples rather than as a matrix (where
    ``norm(x, 2)`` would be the spectral norm).

    Parameters
    ----------
    speech, music : array_like
        Signals to compare.

    Returns
    -------
    float
        Speech-to-music ratio in dB.
    """
    speech_energy = LA.norm(np.ravel(speech)) ** 2
    music_energy = LA.norm(np.ravel(music)) ** 2
    SMR_db = 10 * np.log10(speech_energy / music_energy)
    print('SMR = {:.2f}'.format(SMR_db))

    return SMR_db

def SDR(s_est, s):
    """Return the signal-to-distortion ratio in decibels.

    Computed as the power ratio
    ``10*log10(sum(|s|^2) / sum(|s_est - s|^2))``.  The previous version
    applied ``10*log10`` to the ratio of L2 norms (amplitudes), which reports
    half of the true dB value.

    Both the reference and the error are flattened before the energy is
    computed, so 2-D inputs (e.g. spectrograms) are measured by their total
    energy rather than by the matrix spectral norm.

    Parameters
    ----------
    s_est : array_like
        Estimated signal; must broadcast against ``s``.
    s : array_like
        Reference signal.

    Returns
    -------
    float
        Signal-to-distortion ratio in dB.
    """
    signal_energy = LA.norm(np.ravel(s)) ** 2
    distortion_energy = LA.norm(np.ravel(s_est - s)) ** 2
    SDR_db = 10 * np.log10(signal_energy / distortion_energy)

    return SDR_db

In [8]:
from helpers2 import *

In [9]:
# Best 1-20 min

start = 1 * 60 * 44100
end = 20 * 60 * 44100 

samplerate_s, data_speech = read("/content/drive/MyDrive/Conversation.wav")
speech=data_speech[start:end,0]
length=speech.shape[0]/samplerate_s
print('Shape of the speech {} ... Length : {:.2f}s ... Sample rate : {}'.format(speech.shape[0],length,samplerate_s))

start = 1 * 60 * 44100
end = 20 * 60 * 44100 
samplerate_m, data_music = read("/content/drive/MyDrive/music.wav")
music=data_music[start:end,0]
length=music.shape[0]/samplerate_m
print('Shape of the music {} ... Length : {:.2f}s ... Sample rate : {}'.format(music.shape[0],length,samplerate_m))


Shape of the speech 50274000 ... Length : 1140.00s ... Sample rate : 44100
Shape of the music 50274000 ... Length : 1140.00s ... Sample rate : 44100


In [10]:
# Resample the recordings to `fs` and low-pass filter them.
fs = 16000

# Decimation factor from the speech file's rate to fs.
rate = samplerate_s / fs


start = 1 * 60 * 44100
end = 20 * 60 * 44100


# NOTE(review): these are the same slices (minutes 1-20, channel 0) that the
# loading cell stored in `speech` / `music`, so each segment is resampled and
# filtered twice below — confirm whether the duplication is intended.
speech_t=data_speech[start : end, 0]
music_t = data_music[start : end, 0]


speech_t = signal.resample(speech_t,int(speech_t.shape[0]/rate))
music_t = signal.resample(music_t,int(music_t.shape[0]/rate))
# `rate` was derived from samplerate_s but is applied to samplerate_m here;
# the two are equal only if both files share a sample rate.
samplerate=int(samplerate_m/rate)
length=music_t.shape[0]/samplerate

print('Shape of the test {} ... Length : {:.2f}s ... Sample rate : {}'.format(music_t.shape[0],length,samplerate))

speech = signal.resample(speech,int(speech.shape[0]/rate))
music = signal.resample(music,int(music.shape[0]/rate))


print('Downsampled rate = {}'.format(samplerate))

# butter_lowpass_filter comes from helpers2; presumably the second argument is
# the cutoff in Hz and the third the sample rate — TODO confirm.
speech = butter_lowpass_filter(speech,5000,fs)
music = butter_lowpass_filter(music,5000,fs)

music_t = butter_lowpass_filter(music_t,5000,fs)
speech_t = butter_lowpass_filter(speech_t,5000,fs)

Shape of the test 18240000 ... Length : 1140.00s ... Sample rate : 16000
Downsampled rate = 16000


## Training STFT :


In [180]:
# STFT settings shared by every stft/istft call in the notebook.
WINDOW = 'hamming'
WINDOW_SIZE=480
OVERLAP = 0.6 * WINDOW_SIZE
NFFT=512

# Magnitude spectrograms of the clean speech and clean music segments.
f,t,Y= signal.stft(speech,samplerate,window=WINDOW,nperseg=WINDOW_SIZE,noverlap=OVERLAP,nfft=NFFT)
Yabs_s=np.abs(Y)
f,t,Y= signal.stft(music,samplerate,window=WINDOW,nperseg=WINDOW_SIZE,noverlap=OVERLAP,nfft=NFFT)
Yabs_m=np.abs(Y)



# Mix speech_t and music_t at the requested speech-to-music ratio
# (get_mixed_signal is provided by helpers2).
SMR_db = 0
mix,speech_mix,music_mix=get_mixed_signal(speech_t,music_t,SMR_db)


f,t,Ymix= signal.stft(mix,samplerate,window=WINDOW,nperseg=WINDOW_SIZE,noverlap=OVERLAP,nfft=NFFT)
Yabs_mix=np.abs(Ymix)

# Replace exact zeros so later element-wise divisions stay finite.
Yabs_mix[Yabs_mix==0]=0.00001
# Clip to the int16 range before casting: samples outside that range would
# otherwise wrap around in the written file. `mix` itself is left untouched.
int16_limits = np.iinfo(np.int16)
write("/MixX.wav", samplerate, np.clip(mix, int16_limits.min, int16_limits.max).astype(np.int16))



SMR = -0.00


## Test STFT :

In [181]:
# Build the held-out test mixture and its magnitude spectrogram.
fs = 16000

rate = samplerate_s / fs


# Test clip: 0.5 minute starting at minute 24 (indices assume 44100 Hz).
start = 24 * 60 * 44100
step = int(0.5 * 60 * 44100)

test_s = np.array([])
test_m = np.array([])

# range(1) keeps a single `step`-long chunk; raise the count to concatenate
# further consecutive chunks.
for i in range(1):

  test_s = np.hstack([test_s,data_speech[start+i*step:start+(i+1)*step,0]])
  test_m = np.hstack([test_m,data_music[start+i*step:start+(i+1)*step,0]])


test_s = signal.resample(test_s,int(test_s.shape[0]/rate))
test_m = signal.resample(test_m,int(test_m.shape[0]/rate))
samplerate=int(samplerate_m/rate)
# Duration of the test clip in seconds (previously computed from music_t,
# i.e. the training segment, by copy-paste).
length=test_m.shape[0]/samplerate


test_s = butter_lowpass_filter(test_s,5000,fs)
test_m = butter_lowpass_filter(test_m,5000,fs)


################################################################################
SMR_db = 0
test,speech_test,music_test=get_mixed_signal(test_s,test_m,SMR_db)


f,t,Ytest= signal.stft(test,samplerate,window=WINDOW,nperseg=WINDOW_SIZE,noverlap=OVERLAP,nfft=NFFT)
Yabs_test=np.abs(Ytest)

# Replace exact zeros so later element-wise divisions stay finite.
Yabs_test[Yabs_test==0]=0.00001


SMR = 0.00


# Train First NMF on Clean Speech :

In [182]:
def softmax(x):
  """Softmax along axis 0 (column-wise for a 2-D array).

  The maximum along axis 0 is subtracted before exponentiating.  This leaves
  the result mathematically unchanged but keeps ``np.exp`` from overflowing
  to ``inf`` (and the ratio from becoming ``nan``) for large inputs.

  Parameters
  ----------
  x : array_like
      Input scores.

  Returns
  -------
  ndarray
      Array of the same shape as ``x`` whose entries along axis 0 sum to 1.
  """
  x = np.asarray(x)
  e_x = np.exp(x - np.max(x, axis=0))
  return e_x / e_x.sum(axis=0)

In [183]:
# Learn a speech dictionary from the clean-speech magnitude spectrogram.
Nc = 16   # number of speech atoms
Nm = 48   # number of additional atoms learned later by nmf()

# NOTE(review): the `alpha` keyword depends on the installed scikit-learn
# version — confirm it is still accepted.
model = NMF(n_components=Nc, init='random',alpha=0.0,beta_loss='frobenius',solver="mu",max_iter=50, random_state=7)
# sklearn expects samples in rows, hence the transposes around fit/components_.
model.fit(np.transpose(Yabs_s))
Dc = np.transpose(model.components_)
# Rescale the dictionary with a MinMaxScaler fitted on the dictionary itself.
scaler = MinMaxScaler()
Dc = scaler.fit_transform(Dc)


# Train NMF on Noisy Speech :

In [184]:
def nmf(X, Dc, Nn, lamb=0.1, maxit=100):
    """Factorise ``X`` with a dictionary whose first columns come from ``Dc``.

    The dictionary is ``D = [Dc, Dn]`` where ``Dn`` holds ``Nn`` randomly
    initialised columns.  Inside the loop only the ``Dn`` columns of ``D`` are
    overwritten; the ``Dc`` columns keep their initial values.  Every column is
    rescaled to unit L2 norm to form ``Dnorm``, which is what the updates and
    the returned dictionary use.

    No random seed is set inside this function, so results depend on the
    global NumPy random state.

    Parameters
    ----------
    X : ndarray
        Matrix to factorise (rows must match the rows of ``Dc``).
    Dc : ndarray
        Fixed dictionary columns.
    Nn : int
        Number of extra dictionary columns to learn.
    lamb : float
        Constant added to the denominator of the activation update.
    maxit : int
        Number of update iterations.

    Returns
    -------
    (Dnorm, H, hist)
        Column-normalised dictionary, activations, and the list of
        ``LA.norm(X - Dnorm @ H)`` values recorded after each iteration.
    """

    Nc = Dc.shape[1]
    # Activations: one row per atom (Dc atoms first, then the Nn new ones).
    H = np.random.rand(Nc+Nn, X.shape[1])

    Dn = np.random.rand(X. shape[0], Nn)
    print(f"Shape of Dc {Dc.shape} Shape of Dn {Dn.shape}")
    D = np.hstack([Dc,Dn])
    # Divide every column by its L2 norm.
    Dnorm = D / np.sum(D**2, axis=0)**(.5)

    print(f'Dnorm shape {Dnorm.shape} and X shape {X.shape} and H shape {H.shape}')
    hist=[]
    for i in range(maxit):
        # Multiplicative update of the activations; `lamb` enters the denominator.
        H = H * (np.matmul(Dnorm.T, X)) / (np.matmul(np.matmul(Dnorm.T, Dnorm), H) + lamb)
        # Multiplicative update of the dictionary.  The full update is computed
        # for all columns, but only columns Dc.shape[1]: are written back.
        # np.matmul(np.ones((n, n)), A) places the column sums of A in every row.
        D[:,Dc.shape[1]:] = (Dnorm * (np.matmul(X, H.T) + Dnorm * (np.matmul(np.ones((X.shape[0], X.shape[0])), np.matmul(Dnorm, np.matmul(H, H.T)) * Dnorm))) / (np.matmul(Dnorm, np.matmul(H, H.T)) + Dnorm * (np.matmul(np.ones((X.shape[0], X.shape[0])), np.matmul(X, H.T) * Dnorm))))[:,Dc.shape[1]:]
        Dnorm = D / np.sum(D**2, axis=0)**(.5)
        # Track the reconstruction error after this iteration.
        hist.append(LA.norm(X-np.matmul(Dnorm,H)))
    #Dnorm[:,Dc.shape[1]:] = softmax(Dnorm[:,Dc.shape[1]:])
    return Dnorm, H,hist

In [185]:
# Extend the clean-speech dictionary Dc with Nm extra atoms fitted on the
# mixture spectrogram (default lamb and maxit of nmf are used).
D,H,hist = nmf(Yabs_mix,Dc,Nm)

# Rescale the combined dictionary with a fresh MinMaxScaler; this replaces the
# unit-norm columns returned by nmf().
scaler = MinMaxScaler()
D =  scaler.fit_transform(D)


Shape of Dc (257, 16) Shape of Dn (257, 48)
Dnorm shape (257, 64) and X shape (257, 95001) and H shape (64, 95001)


# Test NMF :

In [186]:
# Estimate test activations for the fixed dictionary D with sklearn's NMF.
model_test = NMF(n_components=Nc+Nm, init='nndsvd',alpha=100,beta_loss='frobenius',solver="mu",max_iter=200, random_state=7)
# The dictionary learned by this fit is discarded two lines below.
# NOTE(review): presumably the fit only serves to put the estimator in a
# fitted state before transform() — confirm.
model_test.fit(np.transpose(Yabs_test))
    
# Swap in the previously learned dictionary, then solve for activations only.
model_test.components_= np.transpose(D)
G_test=np.transpose(model_test.transform(np.transpose(Yabs_test)))

In [187]:
from tqdm import tqdm



def soft(z,a,l=0.02):
  """Shrink the magnitude of ``z`` by ``l/a`` and clamp the result at zero.

  The sign of ``z`` is discarded (``np.abs``), so the output is non-negative.
  The zero floor is a vector of length ``z.shape[0]``.
  """
  threshold = l/a
  zero_floor = np.zeros(z.shape[0])
  return np.maximum(np.abs(z)-threshold, zero_floor)
  
def warm_start_ISTA(x,W,n_components,a,K,l=0.02):
  """Frame-by-frame ISTA-style activation estimate with warm starts.

  For every frame ``t >= 1`` the activation vector is initialised with the
  result of frame ``t-1`` and then refined with ``K-1`` iterations of

      h <- soft((I - W^T W / a) h + W^T x_t / a, a, l)

  Frame 0 is never updated and keeps its random initial value.

  The operators ``I - W^T W / a`` and ``W^T / a`` do not depend on the frame
  or on the iteration, so they are built once up front instead of inside the
  innermost loop; the arithmetic performed per iteration is unchanged.

  Parameters
  ----------
  x : ndarray
      Observations, one frame per column.
  W : ndarray
      Dictionary (atoms in columns).
  n_components : int
      Size of the identity matrix; must equal ``W.shape[1]``.
  a : number
      Step-size constant (the updates use ``1/a``).
  K : int
      ``K-1`` refinement iterations are run per frame.
  l : float
      Threshold parameter forwarded to ``soft``.

  Returns
  -------
  ndarray
      Activations of shape ``(W.shape[1], x.shape[1])``.

  Side effects
  ------------
  Re-seeds NumPy's global random generator with 7.
  """

  np.random.seed(seed=7)
  h = np.random.rand(W.shape[1] , x.shape[1])
  T = x.shape[1]

  # Loop-invariant operators (same expressions as in the update formula).
  scaled_Wt = (1/a)*np.transpose(W)
  shrink_op = np.identity(n_components) - (1/a)*(np.transpose(W)@W)

  for t_ in tqdm(range(1,T)):
    # Warm start from the previous frame's activations.
    h[:,t_] = h[:,t_-1]
    # Data term for this frame; constant across the inner iterations.
    data_term = scaled_Wt@x[:,t_]

    for _ in range(1,K):
      z = shrink_op@h[:,t_] + data_term

      h[:,t_] = soft(z,a,l)
  return h 

In [188]:
def eval(D,G_test,Ytest):
  """Separate ``Ytest`` with dictionary ``D`` / activations ``G_test`` and print scores.

  Steps: build both source spectrograms with ``Reconstruct``, invert them with
  ``signal.istft``, rewrite ``myScript.m`` from ``script`` and call it through
  the Octave session once per source.

  Reads these notebook globals: ``mask_p``, ``ns``, ``nm``, ``samplerate``,
  ``WINDOW``, ``WINDOW_SIZE``, ``OVERLAP``, ``NFFT``, ``test_s``, ``test_m``,
  ``SMR_db``, ``script`` and ``oc``.  Returns ``None``.

  Note: the name shadows Python's built-in ``eval``.
  """
  Sources,Masks=Reconstruct(B=D,G=G_test,Yabs=Ytest,p=mask_p,Ns=ns,Nm=nm)
  print('Reconstruction Step .... Done')

  # Same inverse-STFT settings for both sources.
  istft_options = dict(window=WINDOW, nperseg=WINDOW_SIZE, noverlap=OVERLAP, nfft=NFFT)
  speech_est = signal.istft(Sources[0], samplerate, **istft_options)[1]
  music_est = signal.istft(Sources[1], samplerate, **istft_options)[1]

  # Python-side SDR values; computed but not printed or returned.
  sdr_speech = SDR(s_est=speech_est,s=test_s)
  sdr_music = SDR(s_est=music_est, s=test_m)
  print("smr equal = {}".format(SMR_db))

  # Refresh the Octave script on disk before calling it.
  with open("myScript.m","w+") as script_file:
    script_file.write(script)

  octave_jobs = (
      ("Speech SDR \n", speech_est, test_s),
      ("MUSIC SDR \n", music_est, test_m),
  )
  for heading, estimate, reference in octave_jobs:
    print(heading)
    oc.myScript(estimate ,reference)

In [201]:
# Warm-start ISTA activations for the test spectrogram.
# Positional arguments: n_components = 48+16, a = 100000000, K = 3;
# l=0 means no threshold is subtracted in soft().
h_ws = warm_start_ISTA(Yabs_test,D,48+16,
                    100000000,
                    3,l=0)

# Globals read by eval(): mask exponent and the speech / music atom counts.
mask_p = 2
ns = 16
nm = 48
eval(D,h_ws,Ytest)

100%|██████████| 2500/2500 [00:00<00:00, 2647.53it/s]


Reconstruction Step .... Done
smr equal = 0
Speech SDR 

SDR =  5.6939
MUSIC SDR 

SDR =  1.7326


In [195]:
# Score the sklearn (MU) activations, then the warm-start ISTA activations.
# Written as two statements: eval() returns None, so the former tuple
# expression only made the cell display `(None, None)`.
eval(D,G_test,Ytest)
eval(D,h_ws,Ytest)

Reconstruction Step .... Done
smr equal = 0
Speech SDR 

SDR =  3.3117
MUSIC SDR 

SDR =  2.8080
Reconstruction Step .... Done
smr equal = 0
Speech SDR 

SDR =  4.8519
MUSIC SDR 

SDR =  1.6010


(None, None)

# Unfolded ISTA

In [202]:
# Imported in this cell because it is the first one that uses torch; the
# notebook's `import torch` cell comes later, so a fresh top-to-bottom run
# would otherwise fail here with a NameError.
import torch

print(torch.cuda.is_available())

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

True
cuda:0


In [203]:
import torch

In [204]:
# Inputs for the unfolded-ISTA experiments.
Sources,Masks=Reconstruct(B=D,G=G_test,Ns=Dc.shape[1],Nm=Nm,Yabs=Ytest,p=2)
M = Masks[0]   # speech mask from the MU activations
# Magnitude spectrogram of the clean test speech; Unfolded_ISTA reads Y_clean
# as a global and uses it as the training target.
f,t,Y= signal.stft(test_s,samplerate,window=WINDOW,nperseg=WINDOW_SIZE,noverlap=OVERLAP,nfft=NFFT)
Y_clean=np.abs(Y)
Y_clean[Y_clean==0]=0.00001   # avoid exact zeros

X = Yabs_test      # mixture magnitudes passed to Unfolded_ISTA
X_cmplx = Ytest    # complex mixture STFT
W = D              # initial dictionary passed to Unfolded_ISTA
# NOTE(review): the visible Unfolded_ISTA calls pass their own alpha and K
# literals, so `alpha` and `K` below look unused — confirm.
alpha = 100
K = 3
lambd = 0.01
torch.autograd.set_detect_anomaly(False)


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f835a490110>

In [216]:
def speech_mask(B,G,Ns,p):
    """Ratio mask for the speech source, computed with torch tensors.

    The first ``Ns`` columns of ``B`` and first ``Ns`` rows of ``G`` form the
    speech model; the rest form the other source.  Returns
    ``speech**p / (speech**p + other**p + 0.00001)``.
    """
    speech_part = torch.pow(torch.matmul(B[:, :Ns], G[:Ns, :]), p)
    other_part = torch.pow(torch.matmul(B[:, Ns:], G[Ns:, :]), p)

    # The small constant keeps the division finite when both models are zero.
    return speech_part / (speech_part + other_part + 0.00001)


def sigmoid(x):
  """Element-wise logistic sigmoid of a torch tensor."""
  return torch.sigmoid(x)

def soft(z,a,l=0.02):
  """Return ``|z|`` element-wise, floored at zero.

  ``a`` and ``l`` are accepted but not used.  This definition replaces any
  earlier ``soft`` in the notebook namespace once the cell is run.
  """
  magnitude = np.abs(z)
  zero_floor = np.zeros(z.shape[0])
  return np.maximum(magnitude, zero_floor)


def error(X,M,Y):
  """Mean-squared error between ``Y`` and the masked input ``M*X``."""
  masked_input = M*X
  return torch.nn.functional.mse_loss(Y, masked_input)

def Unfolded_ISTA(X, W, alpha, K, lambd,epochs,learning_rate):
  """Train per-layer dictionaries for an unfolded ISTA-style recursion.

  ``W`` is tiled into ``K`` copies and ``H`` holds one activation matrix per
  layer.  For each frame ``t >= 1`` the layers ``1..K-1`` are updated in
  place, a speech mask is formed from ``W4`` and ``H[K-1]``, and the MSE
  between the clean target ``Y_clean[:, t]`` and the masked mixture frame is
  back-propagated; AdamW then updates the parameters.  After every epoch the
  current dictionary and activations are scored with ``eval``.

  Parameters
  ----------
  X : ndarray            mixture magnitudes, one frame per column.
  W : ndarray            initial dictionary (atoms in columns).
  alpha : number         step-size constant (updates use ``1/alpha``).
  K : int                number of layers; ``W[3]`` is indexed below, so
                         ``K`` must be at least 4.
  lambd : number         accepted but not referenced in the body.
  epochs : int           number of passes over the frames.
  learning_rate : float  AdamW learning rate.

  Returns
  -------
  (d, h) from the last epoch: the rectified ``W4`` and all layer activations
  as NumPy arrays, with exact zeros replaced by 1e-6.  With ``epochs == 0``
  neither name is assigned and the ``return`` raises UnboundLocalError.

  Reads the globals ``device``, ``Y_clean`` and ``Ytest``.
  """
  loss_e = 0          # assigned but not used below
  epoch_loss = []     # assigned but not used below
  torch.manual_seed(7)
  X = torch.from_numpy(X).float().to(device)
  W = torch.from_numpy(W).float().to(device)
  # K stacked copies of the dictionary, indexed as W[layer].
  W = torch.tile(W, (K,1,1)).to(device)
  alpha_list = list(torch.tile(torch.tensor([alpha]).float(),(K,)))
  # Random activations, replicated for each layer.
  H = torch.rand(W.shape[2],X.shape[1]).to(device)
  H = torch.tile(H, (K,1,1)).to(device)
  H_clone = H.clone().to(device)   # assigned but not used below
  # Training target: the clean-speech magnitudes from the notebook namespace.
  Y = torch.from_numpy(Y_clean).float().to(device)
  W1 = W[0]
  W2 = W[1]
  W3 = W[2]
  W4 = W[3]
  W1.requires_grad = True
  W2.requires_grad = True
  W3.requires_grad = True
  W4.requires_grad = True

  params =  [W1] + [W2]  + [W3] + [W4]
  sgd = torch.optim.AdamW(params,lr=learning_rate)   # named `sgd` but is AdamW
  for e in range(epochs):
    for t in tqdm(np.arange(1,X.shape[1]),position=0, leave=True):
      
      # Layers 1..K-1; torch.eye(64) fixes the number of atoms at 64.
      # NOTE(review): the recursion indexes W[k-1] rather than W1..W3, and the
      # loss below depends on W4 only through speech_mask — it looks like only
      # W4 receives gradients; confirm.
      for k in range(1,K):
        if k ==1:
          # First layer: starts from the last layer's activations at frame t-1.
          z = (torch.eye(64).to(device) - (1/alpha_list[k])*(W[k-1].t()@W[k-1]))@H[K-1,:,t-1] + (1/alpha_list[k])*W[k-1].t()@X[:,t]

          H[k,:,t] = z#soft(z,0,0.001)
        elif k > 1:
          # Deeper layers: start from the previous layer at the same frame.
          z = (torch.eye(64).to(device) - (1/alpha_list[k])*(W[k-1].t()@W[k-1]))@H[k-1,:,t] + (1/alpha_list[k])*W[k-1].t()@X[:,t]

          H[k,:,t] = z#soft(z,0,0.001)

        # Rectify the whole layer (all frames), not only column t.
        H[k] = torch.maximum(H[k],torch.tensor([0]).to(device))

      sgd.zero_grad()      


      
      # Speech mask with 16 speech atoms and exponent 2, from the rectified W4.
      mask = speech_mask(torch.maximum(W4,torch.tensor([0]).to(device)),H[K-1],16,2)
      #print(mask)
      loss = error(X[:,t],mask[:,t],Y[:,t])
      #print("\n",loss)

      loss.backward()

      sgd.step()
    # Loss of the last frame processed in this epoch.
    print(f'Epoch {e} ... Loss = {loss.item()}')
    d = torch.maximum(W4,torch.tensor([0]).to(device)).cpu().detach().numpy()#torch.nn.Sigmoid()(W2).cpu().detach().numpy()
    d[d == 0] = 0.000001
    h = H.cpu().detach().numpy()
    h[h == 0] = 0.000001

    # Score the current dictionary with the last layer's activations.
    eval(d,h[K-1],Ytest)
  return d, h


In [None]:
# Single-value sweep over alpha; runs Unfolded_ISTA with K=4 layers,
# 100 epochs and learning rate 0.001.
# NOTE(review): Unfolded_ISTA is defined twice in this notebook; which
# definition runs here depends on cell execution order — confirm.
for a in [100]:
  d_new ,s = Unfolded_ISTA(X, W, a, 4, lambd,epochs=100,learning_rate= 0.001)


In [124]:
# Replace exact zeros in the learned dictionary with a small positive value.
# (A stray backtick that made this cell a SyntaxError was removed.)
d_new[d_new == 0] = 0.0001

In [None]:
print("SDR Evaluation using MU \n")
eval(D,G_test,Ytest) ## mu 
print("\n")
print("SDR Evaluation using Warm Start ISTA \n")

eval(D,h_ws,Ytest)  ## warm start ISTA
print("\n")



In [125]:
# Score the dictionary learned by Unfolded_ISTA.
print("SDR Evaluation using Unfolded ISTA \n")

# NOTE(review): s[1] selects the activations of layer index 1, whereas
# Unfolded_ISTA itself scores the last layer after each epoch — confirm
# which layer is intended here.
eval(d_new,s[1],Ytest) ## DU ISTA

SDR Evaluation using Unfolded ISTA 

Reconstruction Step .... Done
smr equal = 0
Speech SDR 

SDR =  5.9964
MUSIC SDR 

SDR =  6.4330


In [139]:
def speech_mask(B,G,Ns,p):
    """Ratio mask for the speech source, computed with torch tensors.

    The first ``Ns`` columns of ``B`` and first ``Ns`` rows of ``G`` form the
    speech model; the rest form the other source.  Returns
    ``speech**p / (speech**p + other**p + 0.00001)``.
    """
    speech_part = torch.pow(torch.matmul(B[:, :Ns], G[:Ns, :]), p)
    other_part = torch.pow(torch.matmul(B[:, Ns:], G[Ns:, :]), p)

    # The small constant keeps the division finite when both models are zero.
    return speech_part / (speech_part + other_part + 0.00001)


def sigmoid(x):
  """Element-wise logistic sigmoid of a torch tensor."""
  return torch.sigmoid(x)

def soft(z,a,l=0.02):
  """Return ``|z|`` element-wise, floored at zero.

  ``a`` and ``l`` are accepted but not used.  This definition replaces any
  earlier ``soft`` in the notebook namespace once the cell is run.
  """
  magnitude = np.abs(z)
  zero_floor = np.zeros(z.shape[0])
  return np.maximum(magnitude, zero_floor)


def error(X,M,Y):
  """Mean-squared error between ``Y`` and the masked input ``M*X``."""
  masked_input = M*X
  return torch.nn.functional.mse_loss(Y, masked_input)

def Unfolded_ISTA(X, W, alpha, K, lambd,epochs,learning_rate):
  """Train per-layer dictionaries for an unfolded ISTA-style recursion (3-layer variant).

  ``W`` is tiled into ``K`` copies and ``H`` holds one activation matrix per
  layer.  For each frame ``t >= 1`` the layers ``1..K-1`` are updated in
  place, a speech mask is formed from ``W3`` and ``H[2]``, and the MSE between
  the clean target ``Y_clean[:, t]`` and the masked mixture frame is
  back-propagated; plain SGD then updates the parameters.  After every epoch
  the current dictionary and activations are scored with ``eval``.

  Parameters
  ----------
  X : ndarray            mixture magnitudes, one frame per column.
  W : ndarray            initial dictionary (atoms in columns).
  alpha : number         step-size constant (updates use ``1/alpha``).
  K : int                number of layers; ``W[2]`` and ``H[2]`` are indexed
                         below, so ``K`` must be at least 3, and layer index 2
                         is used for the mask regardless of a larger ``K``.
  lambd : number         accepted but not referenced in the body.
  epochs : int           number of passes over the frames.
  learning_rate : float  SGD learning rate.

  Returns
  -------
  (d, h) from the last epoch: the rectified ``W3`` and all layer activations
  as NumPy arrays, with exact zeros replaced by 1e-6.  With ``epochs == 0``
  neither name is assigned and the ``return`` raises UnboundLocalError.

  Reads the globals ``device``, ``Y_clean`` and ``Ytest``.
  """
  loss_e = 0          # assigned but not used below
  epoch_loss = []     # assigned but not used below
  torch.manual_seed(7)
  X = torch.from_numpy(X).float().to(device)
  W = torch.from_numpy(W).float().to(device)
  # K stacked copies of the dictionary, indexed as W[layer].
  W = torch.tile(W, (K,1,1)).to(device)
  alpha_list = list(torch.tile(torch.tensor([alpha]).float(),(K,)))
  # Random activations, replicated for each layer.
  H = torch.rand(W.shape[2],X.shape[1]).to(device)
  H = torch.tile(H, (K,1,1)).to(device)
  H_clone = H.clone().to(device)   # assigned but not used below
  # Training target: the clean-speech magnitudes from the notebook namespace.
  Y = torch.from_numpy(Y_clean).float().to(device)
  W1 = W[0]
  W2 = W[1]
  W3 = W[2]

  W1.requires_grad = True
  W2.requires_grad = True
  W3.requires_grad = True
  params =  [W1] + [W2] + [W3]
  sgd = torch.optim.SGD(params,lr=learning_rate)
  for e in range(epochs):
    for t in tqdm(np.arange(1,X.shape[1]),position=0, leave=True):

      # Layers 1..K-1; torch.eye(64) fixes the number of atoms at 64.
      # NOTE(review): the recursion indexes W[k-1] rather than W1..W3, and the
      # loss below depends on W3 only through speech_mask — it looks like only
      # W3 receives gradients; confirm.
      for k in range(1,K):
        if k ==1:
          # First layer: starts from its own activations at frame t-1.
          z = (torch.eye(64).to(device) - (1/alpha_list[k])*(W[k-1].t()@W[k-1]))@H[k,:,t-1] + (1/alpha_list[k])*W[k-1].t()@X[:,t]

          H[k,:,t] = z#soft(z,0,0.001)
        elif k > 1:
          # Deeper layers: start from the previous layer at the same frame.
          z = (torch.eye(64).to(device) - (1/alpha_list[k])*(W[k-1].t()@W[k-1]))@H[k-1,:,t] + (1/alpha_list[k])*W[k-1].t()@X[:,t]

          H[k,:,t] = z#soft(z,0,0.001)

        # Rectify the whole layer (all frames), not only column t.
        H[k] = torch.maximum(H[k],torch.tensor([0]).to(device))

      sgd.zero_grad()      


          #h0_ = unfolded_cell(W1,H,alpha_list,0,t)
        # h1_k = unfolded_cell(W2,h0_,alpha_list,1,t)
        # h2_k  = unfolded_cell(W3,h1_k.float(),alpha_list,2,t)

      # Speech mask with 16 speech atoms and exponent 2, from the rectified W3.
      mask = speech_mask(torch.maximum(W3,torch.tensor([0]).to(device)),H[2],16,2)
      #print(mask)
      loss = error(X[:,t],mask[:,t],Y[:,t])
      #print("\n",loss)

      loss.backward()

      sgd.step()
    # Loss of the last frame processed in this epoch.
    print(f'Epoch {e} ... Loss = {loss.item()}')
    d = torch.maximum(W3,torch.tensor([0]).to(device)).cpu().detach().numpy()#torch.nn.Sigmoid()(W2).cpu().detach().numpy()
    d[d == 0] = 0.000001
    h = H.cpu().detach().numpy()
    h[h == 0] = 0.000001

    # Score the current dictionary with the activations of layer index 2.
    eval(d,h[2],Ytest)
  return d, h


In [142]:
# Single-value sweep over alpha; runs Unfolded_ISTA with K=3 layers,
# 50 epochs and learning rate 0.01.
for a in [100]:
  d_new ,s = Unfolded_ISTA(X, W, a, 3, lambd,epochs=50,learning_rate= 0.01)

100%|██████████| 2500/2500 [00:04<00:00, 565.43it/s]


Epoch 0 ... Loss = 878.5275268554688
Reconstruction Step .... Done
smr equal = 0
Speech SDR 

SDR =  3.5784
MUSIC SDR 

SDR =  4.3788


100%|██████████| 2500/2500 [00:04<00:00, 563.57it/s]


Epoch 1 ... Loss = 863.2802734375
Reconstruction Step .... Done
smr equal = 0
Speech SDR 

SDR =  3.6293
MUSIC SDR 

SDR =  4.5157


100%|██████████| 2500/2500 [00:04<00:00, 563.94it/s]


Epoch 2 ... Loss = 862.808837890625
Reconstruction Step .... Done
smr equal = 0
Speech SDR 

SDR =  3.6320
MUSIC SDR 

SDR =  4.5161


 55%|█████▌    | 1385/2500 [00:02<00:02, 552.88it/s]


KeyboardInterrupt: ignored