In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset
import os,glob,random
import librosa
import soundfile as sf  
import numpy as np
from itertools import permutations

In [2]:
class Encoder(nn.Module):
    """Gated encoder that turns waveform segments into non-negative basis weights.

    L: Number of input channels(number of samples per segment)
    N: Number of output channels(number of basis signals)
    """
    def __init__(self,L,N):
        super(Encoder,self).__init__()
        self.L = L
        self.N = N
        # Added to the segment norm so all-zero (silent) segments do not divide by zero.
        self.EPS = 1e-8
        self.conv1d_U=nn.Conv1d(in_channels=L,out_channels=N,kernel_size=1,stride=1,bias=False)
        self.conv1d_V=nn.Conv1d(in_channels=L,out_channels=N,kernel_size=1,stride=1,bias=False)
    
    def forward(self,mixture):
        """
        mixture: Tensor of shape [B,K,L] where K are the number of segment being processed at once
        returns: tuple (mixture_w, norm_coef)
            mixture_w: Tensor of shape [B,K,N] where N are the number of basis signals
            norm_coef: Tensor of shape [B,K,1], the L2 norm of every segment
        """
        B,K,L=mixture.size()
        norm_coef=torch.norm(mixture,p=2,dim=2,keepdim=True)
        normed_mixture=mixture/(norm_coef+self.EPS)
        # reshape (not view): the input may be non-contiguous (e.g. a permuted
        # batch), in which case view() raises a RuntimeError.
        normed_mixture=normed_mixture.reshape(-1,L).unsqueeze(2)
        # ReLU branch gated by a sigmoid branch, both 1x1 convolutions over [B*K,L,1].
        conv=F.relu(self.conv1d_U(normed_mixture))
        gate=torch.sigmoid(self.conv1d_V(normed_mixture))
        mixture_w=(conv*gate).reshape(B,K,self.N)
        return mixture_w,norm_coef

In [3]:
class Separator(nn.Module):
    """Estimates one mask per speaker from the encoder weights.

    Pipeline: LayerNorm over the N basis weights -> (optionally bidirectional)
    LSTM over the K segments -> linear layer producing nspk*N scores ->
    softmax across the speaker axis.
    """
    def __init__(self,N:int,hidden_size,num_layers,bidirectional=False,nspk=2) -> None:
        super().__init__()
        self.N,self.hidden_size=N,hidden_size
        self.bidirectional,self.num_layers,self.nspk=bidirectional,num_layers,nspk
        self.layer_norm=nn.LayerNorm(N)
        self.LSTM=nn.LSTM(
            input_size=N,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )
        # A bidirectional LSTM concatenates both directions, doubling the feature size.
        directions=2 if bidirectional else 1
        self.fc=nn.Linear(hidden_size*directions,nspk*N)
    
    def forward(self,mixture_w):
        """
        mixture_w: Tensor of shape [B,K,N]
        output: Tensor of shape [B,K,nspk,N]; the values along the nspk axis sum to 1
        """
        batch,segments,basis=mixture_w.size()
        lstm_out,_=self.LSTM(self.layer_norm(mixture_w))
        logits=self.fc(lstm_out).view(batch,segments,self.nspk,basis)
        return F.softmax(logits,dim=2)

In [4]:
class Decoder(nn.Module):
    """Reconstructs per-speaker waveform segments from masked basis weights.

    N: number of basis signals (input features of the linear layer)
    L: number of samples per reconstructed segment
    """
    def __init__(self,N,L):
        super().__init__()
        self.N,self.L=N,L
        self.basis_signals=nn.Linear(N,L,bias=False)
    
    def forward(self,mixture_w,est_mask,norm_coef):
        """
        mixture_w: Tensor of shape [B,K,N]
        est_mask: Tensor of shape [B,K,nspk,N]
        norm_coef: Tensor of shape [B,K,1]
        output: Tensor of shape [B,nspk,K,L]
        """
        # Broadcast the mixture weights over the speaker axis and apply each mask.
        masked_w=mixture_w.unsqueeze(2)*est_mask
        # Map to waveform samples and undo the encoder's per-segment normalisation.
        segments=self.basis_signals(masked_w)*norm_coef.unsqueeze(2)
        # [B,K,nspk,L] -> [B,nspk,K,L]
        return segments.transpose(1,2).contiguous()

In [5]:
class TasNet(nn.Module):
    """End-to-end model chaining Encoder -> Separator -> Decoder.

    L: samples per segment, N: number of basis signals; the remaining
    arguments are forwarded to the Separator.
    """
    def __init__(self,L,N,hidden_size,num_layers,bidirectional=False,nspk=2):
        super().__init__()
        self.L,self.N=L,N
        self.hidden_size,self.num_layers=hidden_size,num_layers
        self.bidirectional,self.nspk=bidirectional,nspk
        self.encoder=Encoder(L,N)
        self.separator=Separator(N,hidden_size,num_layers,bidirectional=bidirectional,nspk=nspk)
        self.decoder=Decoder(N,L)
    
    def forward(self,mixture):
        """
        mixture: Tensor of shape [B,K,L]
        output: Tensor of shape [B,nspk,K,L] with the estimated sources
        """
        weights,scale=self.encoder(mixture)
        masks=self.separator(weights)
        return self.decoder(weights,masks,scale)

In [6]:
class AudioDataset(Dataset):
    """Two-speaker mixtures built on the fly from fixed-length chunks of .wav files.

    Every ``*.wav`` file in ``folder_path`` is cut into consecutive,
    non-overlapping chunks of ``L*K`` samples (measured at ``sample_rate``).
    Item ``idx`` pairs chunk ``idx`` with a different, randomly chosen chunk.

    L: samples per segment
    K: segments per chunk
    folder_path: directory searched (non-recursively) for ``*.wav`` files
    sample_rate: rate the audio is loaded at
    """
    def __init__(self,L:int,K:int,folder_path:str,sample_rate=8000) -> None:
        if L<=0 or K<=0:
            raise ValueError("L and K must be positive")
        self.L=L
        self.K=K
        self.folder_path=folder_path
        self.sample_rate=sample_rate
        # Sorted so that the index -> chunk mapping is reproducible across runs.
        self.files=sorted(glob.glob(os.path.join(folder_path,'*.wav')))
        self.audio_info=self.load_audio_info()
    
    def __len__(self):
        """Number of chunks available across all files."""
        return len(self.audio_info['path'])
    
    def __getitem__(self,idx):
        """Return ``(mixture, sources)`` for chunk ``idx``.

        mixture: Tensor of shape [K,L], peak-normalised sum of both chunks
        sources: Tensor of shape [2,K,L], each chunk peak-normalised on its own

        Raises IndexError for an out-of-range ``idx`` and ValueError when the
        dataset holds fewer than two chunks (no partner to mix with).
        """
        if idx<0:
            # Resolve negative indices so the partner is never the same chunk.
            idx+=len(self)
        audio1=self._load_chunk(idx)
        audio2=self._load_chunk(self._partner_index(idx))
        # NOTE(review): mixture and sources are normalised independently, so the
        # mixture is not the exact sum of the returned sources — confirm that the
        # training loss is scale-invariant.
        mixture=librosa.util.normalize(audio1+audio2)
        audio1=librosa.util.normalize(audio1)
        audio2=librosa.util.normalize(audio2)
        mixture=torch.from_numpy(mixture.reshape(self.K,self.L))
        sources=torch.from_numpy(np.stack([audio1.reshape(self.K,self.L),audio2.reshape(self.K,self.L)]))
        return mixture,sources

    def _load_chunk(self,idx):
        """Load chunk ``idx`` as a mono array of exactly ``L*K`` samples."""
        chunk_length=self.L*self.K
        start=self.audio_info['start'][idx]/self.sample_rate
        end=self.audio_info['end'][idx]/self.sample_rate
        audio,_=librosa.load(self.audio_info['path'][idx],sr=self.sample_rate,mono=True,offset=start,duration=end-start)
        # Resampling / rounding can return a sample more or less than requested;
        # zero-pad or trim so the later reshape(K,L) cannot fail.
        if len(audio)<chunk_length:
            audio=np.pad(audio,(0,chunk_length-len(audio)))
        return audio[:chunk_length]

    def _partner_index(self,idx):
        """Pick a chunk index different from ``idx``, uniformly over all others."""
        n=len(self)
        if n<2:
            raise ValueError("need at least two audio chunks to build a mixture")
        # Draw from n-1 slots and skip over idx: no retry loop, and the last
        # chunk is reachable.
        j=random.randrange(n-1)
        return j+1 if j>=idx else j

    def load_audio_info(self):
        """Index every full chunk of every file.

        Returns a dict of parallel lists: ``path`` (file path), ``start`` and
        ``end`` (chunk boundaries in samples at ``sample_rate``).
        """
        audio_info=dict(path=list(),start=list(),end=list())
        chunk_length=self.L*self.K
        for file in self.files:
            # glob already returns paths that include folder_path; joining it
            # again would duplicate the directory for relative folders.
            info=sf.info(file)
            # Integer arithmetic avoids float truncation of frames/samplerate*sample_rate.
            duration=info.frames*self.sample_rate//info.samplerate
            # duration+1 keeps the final chunk when the file length is an exact multiple.
            for end in range(chunk_length,duration+1,chunk_length):
                audio_info['path'].append(file)
                audio_info['start'].append(end-chunk_length)
                audio_info['end'].append(end)
        return audio_info
        

In [7]:
# Model hyper-parameters used by the cells below.
L = 500                # samples per segment (encoder input channels)
N = 250                # number of basis signals (encoder output channels)
hidden_size = 250      # hidden size of the separator LSTM
num_layers = 2         # number of stacked LSTM layers
bidirectional = False  # whether the separator LSTM is bidirectional
nspk = 2               # number of speakers (masks) to estimate

In [8]:
# Build the network from the hyper-parameters defined above.
model=TasNet(
    L=L,
    N=N,
    hidden_size=hidden_size,
    num_layers=num_layers,
    bidirectional=bidirectional,
    nspk=nspk,
)

In [9]:
# Dummy batch [B=2,K=3,L] for a shape smoke test; the last dimension must equal
# the model's segment length, so it is tied to L instead of a literal 500.
x=torch.randn(2,3,L)

In [10]:
# Total number of trainable parameters in the model.
sum(param.numel() for param in model.parameters() if param.requires_grad)

1505000

In [11]:
# Forward pass on the dummy batch; the next cell inspects the output shape.
out=model(x)

In [12]:
# Per the Decoder docstring the layout is [B,nspk,K,L].
out.shape

torch.Size([2, 2, 3, 500])

In [13]:
# Directory holding the .wav files. Override with the TASNET_AUDIO_DIR environment
# variable so the notebook is not tied to one machine's absolute path; the original
# location remains the default.
audio_folder=os.environ.get("TASNET_AUDIO_DIR","/mnt/d/Programs/Python/PW/projects/asteroid/zip-hindi-2k")

In [16]:
# The segment length must match the model's L, so reuse the hyper-parameter rather
# than repeating the literal 500. K=16 segments per chunk.
dataset=AudioDataset(L=L,K=16,folder_path=audio_folder)

In [17]:
# Fetch the sample once: each dataset[0] access loads audio from disk and draws a
# new random partner, so indexing twice inspected two different mixtures.
sample_mixture,sample_sources=dataset[0]
sample_mixture.shape,sample_sources.shape

(torch.Size([16, 500]), torch.Size([2, 16, 500]))

In [18]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
has_gpu=torch.cuda.is_available()
device=torch.device('cuda' if has_gpu else 'cpu')
print(device)

cuda


In [19]:
# Move the model's parameters and buffers to the selected device.
model=model.to(device)

In [20]:
# Shuffled batches of 125 chunks. NOTE(review): the name is a misspelling of
# "dataloader"; it is kept because a later cell iterates over `dtatloader`.
dtatloader=DataLoader(dataset,batch_size=125,shuffle=True)

In [21]:
# Put the model in training mode, pull a single batch, move it to the device
# and print the tensor shapes as a sanity check.
model.train()
for mixture,sources in dtatloader:
    mixture,sources=(batch_part.to(device) for batch_part in (mixture,sources))
    print(mixture.shape,sources.shape)
    break

torch.Size([125, 16, 500]) torch.Size([125, 2, 16, 500])


In [24]:
# All orderings of the two indices 0 and 1.
list(permutations(range(2)))

[(0, 1), (1, 0)]