In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [48]:
class Encoder(nn.Module):
    """Gated 1-D convolutional encoder producing non-negative mixture weights.

    Each length-``L`` segment is L2-normalised and passed through two
    bias-free 1x1 convolutions; the ReLU branch is multiplied by the
    sigmoid (gate) branch.

    Args:
        L: Number of input channels of the convolutions (samples per segment).
        N: Number of output channels (basis signals).
    """
    def __init__(self,L,N):
        super(Encoder,self).__init__()
        self.L = L
        self.N = N
        # Added to the norm so that all-zero segments do not divide by zero.
        self.EPS = 1e-8
        self.conv1d_U=nn.Conv1d(in_channels=L,out_channels=N,kernel_size=1,stride=1,bias=False)
        self.conv1d_V=nn.Conv1d(in_channels=L,out_channels=N,kernel_size=1,stride=1,bias=False)

    def forward(self,mixture):
        """Encode a segmented mixture.

        Args:
            mixture: Tensor of shape [B, K, L] (batch, segments, samples per segment).

        Returns:
            mixture_w: Tensor of shape [B, K, N].
            norm_coef: Tensor of shape [B, K, 1], the per-segment L2 norm.
        """
        B,K,L=mixture.size()
        norm_coef=torch.norm(mixture,p=2,dim=2,keepdim=True) # [B,K,1]
        norm_mixture=mixture/(norm_coef+self.EPS) # [B,K,L]
        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, also work.
        norm_mixture=norm_mixture.reshape(-1,L).unsqueeze(2) # [B*K,L,1]
        conv=F.relu(self.conv1d_U(norm_mixture)) # [B*K,N,1]
        # torch.sigmoid replaces the deprecated F.sigmoid.
        gate=torch.sigmoid(self.conv1d_V(norm_mixture)) # [B*K,N,1]
        mixture_w=(conv*gate).view(B,K,self.N) # [B,K,N]
        return mixture_w,norm_coef
        

In [53]:
class Seperator(nn.Module):
    """Estimate one softmax mask per speaker from the encoder output.

    Pipeline: LayerNorm over the feature axis -> (bi)LSTM over the segment
    axis -> linear layer -> softmax across speakers.

    Args:
        N: Number of input features per segment.
        hidden_size: Number of LSTM hidden units (per direction).
        num_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM is bidirectional.
        nspk: Number of speakers (masks) to estimate.
    """
    def __init__(self,N,hidden_size,num_layers,bidirectional=False,nspk=2):
        super(Seperator,self).__init__()
        self.N = N
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.nspk = nspk
        self.layer_norm=nn.LayerNorm(N)
        self.LSTM=nn.LSTM(input_size=N,hidden_size=hidden_size,num_layers=num_layers,bidirectional=bidirectional,batch_first=True)
        # A bidirectional LSTM concatenates both directions on the feature axis.
        fc_in_dim=hidden_size*2 if bidirectional else hidden_size
        self.fc=nn.Linear(fc_in_dim,nspk*N)

    def forward(self,mixture_w,mixture_lengths):
        """Compute the speaker masks.

        Args:
            mixture_w: Padded tensor of shape [B, K, N].
            mixture_lengths: Number of valid segments per batch item
                (tensor or list of length B). Need not be sorted.

        Returns:
            est_mask: Tensor of shape [B, K, nspk, N]; sums to 1 over the nspk axis.
        """
        B,K,N=mixture_w.size()
        norm_mixture_w=self.layer_norm(mixture_w)
        # pack_padded_sequence expects the lengths on the CPU.
        if torch.is_tensor(mixture_lengths):
            mixture_lengths=mixture_lengths.cpu()
        # enforce_sorted=False lifts the requirement that the batch be sorted
        # by decreasing length; the original order is restored when unpacking.
        packed_input=pack_padded_sequence(norm_mixture_w,mixture_lengths,batch_first=True,enforce_sorted=False)
        packed_output,_=self.LSTM(packed_input)
        # total_length=K pads back to the input's segment count.
        output,_=pad_packed_sequence(packed_output,total_length=K,batch_first=True)
        score=self.fc(output) # [B,K,nspk*N]
        score=score.reshape(B,K,self.nspk,N)
        est_mask=F.softmax(score,dim=2)
        return est_mask


In [54]:
class Decoder(nn.Module):
    """Map masked mixture weights back to per-speaker segments.

    Args:
        N: Size of the feature axis consumed by the linear layer.
        L: Size of the last axis produced by the linear layer.
    """
    def __init__(self, N, L):
        super().__init__()
        self.N, self.L = N, L
        # Bias-free linear map applied to the last axis (N -> L).
        self.basis_signals = nn.Linear(N, L, bias=False)

    def forward(self, mixture_w, est_mask, norm_coef):
        """Apply the masks, project with the basis signals and undo the scaling.

        Shapes below follow what Encoder/Seperator in this notebook produce.

        Args:
            mixture_w: [B, K, N]
            est_mask: [B, K, nspk, N]
            norm_coef: [B, K, 1]

        Returns:
            est_source: [B, nspk, K, L]
        """
        # Broadcast the mixture weights over the speaker axis.
        masked_w = mixture_w.unsqueeze(2) * est_mask          # [B, K, nspk, N]
        segments = self.basis_signals(masked_w)               # [B, K, nspk, L]
        # Multiply the encoder's normalisation coefficient back in.
        segments = segments * norm_coef.unsqueeze(2)          # [B, K, nspk, L]
        # Move the speaker axis in front of the segment axis.
        return segments.permute(0, 2, 1, 3).contiguous()      # [B, nspk, K, L]

In [55]:
class TasNet(nn.Module):
    """Chain Encoder -> Seperator -> Decoder into a single module.

    Args:
        L, N: Forwarded to Encoder(L, N) and Decoder(N, L).
        hidden_size, num_layers, bidirectional, nspk: Forwarded to Seperator.
    """
    def __init__(self, L, N, hidden_size, num_layers, bidirectional=False, nspk=2):
        super().__init__()
        # Keep the hyper-parameters on the instance.
        self.L, self.N = L, N
        self.hidden_size, self.num_layers = hidden_size, num_layers
        self.bidirectional, self.nspk = bidirectional, nspk
        # Sub-modules, registered in the same order as before.
        self.encoder = Encoder(L, N)
        self.seperator = Seperator(N, hidden_size, num_layers,
                                   bidirectional=bidirectional, nspk=nspk)
        self.decoder = Decoder(N, L)

    def forward(self, mixture, mixture_lengths):
        """Run the full pipeline and return the decoder output."""
        weights, scale = self.encoder(mixture)
        masks = self.seperator(weights, mixture_lengths)
        return self.decoder(weights, masks, scale)

In [31]:
# Instance used for the parameter count below.
model = TasNet(L=1, N=250, hidden_size=250, num_layers=4, bidirectional=True, nspk=2)

In [32]:
# Total number of trainable parameters in `model`.
sum(param.numel() for param in model.parameters() if param.requires_grad)

5767750

In [37]:
# Random dummy batch of shape (2, 1, 40) for the dataflow experiments below.
x = torch.randn((2, 1, 40))

In [38]:
# Show the dummy batch, then its shape, one per line.
print(x, x.shape, sep="\n")

tensor([[[ 1.2086,  2.0988,  0.8783, -1.0730, -2.4978,  0.3917, -0.2256,
          -0.3452, -1.1686, -0.8310,  0.9254, -0.8503,  0.1203,  1.1072,
          -0.6407,  1.1583, -0.5139, -0.5353,  0.3703,  0.6892,  0.1710,
           1.1443,  0.8203,  0.7230,  1.1424,  0.7780, -1.6464,  1.6767,
          -1.3182,  0.3406, -0.2319, -1.9739,  0.4833, -0.3824, -1.1773,
          -1.5991,  0.7753, -1.2457,  0.2406, -0.7732]],

        [[-1.9884,  1.4801, -0.7135,  0.9549, -1.2234, -0.5150,  0.9733,
          -0.5323,  0.8743, -1.2915,  0.3379, -0.0094,  0.6453,  0.2503,
          -0.6325,  0.0552, -2.0403, -0.6826,  0.2837, -0.6834,  0.0314,
           1.4162, -0.6548, -0.8204,  0.7546,  1.0601,  0.3665, -0.7210,
          -0.1676, -0.0032,  0.8184,  0.3159,  0.3440, -0.8520,  1.0428,
          -0.0203, -0.7260,  0.2557,  0.7778, -0.1656]]])
torch.Size([2, 1, 40])


# Encoder Dataflow

In [44]:
# Hyper-parameters for the step-by-step dataflow cells.
L, N = 1, 250                      # input signal length, number of basis signals
hidden_size, num_layers = 250, 2   # hidden units, number of layers
bidirectional, nspk = True, 2      # whether the RNN is bidirectional, number of speakers

In [27]:
# Two bias-free 1x1 convolutions (L -> N channels); conv1d_v is built first, as before.
conv1d_v, conv1d_u = (
    nn.Conv1d(in_channels=L, out_channels=N, kernel_size=1, stride=1, bias=False)
    for _ in range(2)
)

In [39]:
# Run the dummy batch through both convolutions (U first, then V).
conv1d_u_out, conv1d_v_out = conv1d_u(x), conv1d_v(x)

In [40]:
# Shapes of the two convolution outputs.
(conv1d_u_out.shape, conv1d_v_out.shape)

(torch.Size([2, 250, 40]), torch.Size([2, 250, 40]))

In [41]:
# L2 norm along the last axis, kept as a size-1 dimension for broadcasting.
norm_coef = x.norm(p=2, dim=2, keepdim=True)
norm_coef.shape

torch.Size([2, 1, 1])

In [42]:
# Scale by the norm; the small constant avoids division by zero.
norm_mixture = x.div(norm_coef + 1e-8)
norm_mixture.shape

torch.Size([2, 1, 40])

In [43]:
# Flatten to rows of length L, then append a trailing size-1 axis.
norm_mixture = norm_mixture.view(-1, L).unsqueeze(2)
norm_mixture.shape

torch.Size([80, 1, 1])

In [46]:
# Created on 2018/12/10
# Author: Kaituo XU

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

EPS = 1e-8


class TasNet(nn.Module):
    """Encoder -> Separator -> Decoder pipeline with checkpoint helpers.

    Args:
        L, N: Forwarded to Encoder(L, N) and Decoder(N, L).
        hidden_size, num_layers, bidirectional, nspk: Forwarded to Separator.
    """
    def __init__(self, L, N, hidden_size, num_layers,
                 bidirectional=True, nspk=2):
        super(TasNet, self).__init__()
        # hyper-parameter (kept so serialize() can store them)
        self.L, self.N = L, N
        self.hidden_size, self.num_layers = hidden_size, num_layers
        self.bidirectional = bidirectional
        self.nspk = nspk
        # Components
        self.encoder = Encoder(L, N)
        self.separator = Separator(N, hidden_size, num_layers,
                                   bidirectional=bidirectional, nspk=nspk)
        self.decoder = Decoder(N, L)

    def forward(self, mixture, mixture_lengths):
        """
        Args:
            mixture: [B, K, L]
            mixture_lengths: [B]
        Returns:
            est_source: [B, nspk, K, L]
        """
        mixture_w, norm_coef = self.encoder(mixture)
        est_mask = self.separator(mixture_w, mixture_lengths)
        est_source = self.decoder(mixture_w, est_mask, norm_coef)
        return est_source

    @classmethod
    def load_model(cls, path):
        """Rebuild a model from a checkpoint file written from serialize().

        NOTE(review): torch.load unpickles the file; only load checkpoints
        from sources you trust.
        """
        # Load to CPU
        package = torch.load(path, map_location=lambda storage, loc: storage)
        model = cls.load_model_from_package(package)
        return model

    @classmethod
    def load_model_from_package(cls, package):
        """Instantiate the model from a serialize() package and load its weights."""
        model = cls(package['L'], package['N'],
                    package['hidden_size'], package['num_layers'],
                    bidirectional=package['bidirectional'],
                    nspk=package['nspk'])
        model.load_state_dict(package['state_dict'])
        return model

    @staticmethod
    def serialize(model, optimizer, epoch, tr_loss=None, cv_loss=None):
        """Bundle hyper-parameters, weights and optimizer state into a dict.

        Args:
            model: Object exposing L, N, hidden_size, num_layers,
                bidirectional, nspk and state_dict().
            optimizer: Object exposing state_dict().
            epoch: Stored under 'epoch'.
            tr_loss: Optional; stored under 'tr_loss' when not None.
            cv_loss: Optional; stored under 'cv_loss' when either loss is given.

        Returns:
            dict ready to be passed to torch.save.
        """
        package = {
            # hyper-parameter
            'L': model.L,
            'N': model.N,
            'hidden_size': model.hidden_size,
            'num_layers': model.num_layers,
            'bidirectional': model.bidirectional,
            'nspk': model.nspk,
            # state
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
            'epoch': epoch
        }
        if tr_loss is not None:
            package['tr_loss'] = tr_loss
        # Previously cv_loss was silently dropped when tr_loss was None.
        # The key is still written (possibly as None) whenever tr_loss is
        # given, so existing readers of package['cv_loss'] keep working.
        if tr_loss is not None or cv_loss is not None:
            package['cv_loss'] = cv_loss
        return package


class Encoder(nn.Module):
    """Non-negative mixture weights from a gated pair of 1x1 convolutions."""
    def __init__(self, L, N):
        super().__init__()
        # hyper-parameter
        self.L, self.N = L, N
        # Components: two bias-free 1x1 convolutions, L -> N channels
        # (U is created before V, as in the original).
        self.conv1d_U = nn.Conv1d(in_channels=L, out_channels=N,
                                  kernel_size=1, stride=1, bias=False)
        self.conv1d_V = nn.Conv1d(in_channels=L, out_channels=N,
                                  kernel_size=1, stride=1, bias=False)

    def forward(self, mixture):
        """
        Args:
            mixture: [B, K, L]
        Returns:
            mixture_w: [B, K, N]
            norm_coef: [B, K, 1]
        """
        batch, n_segments, seg_len = mixture.size()
        # Per-segment L2 norm, kept for rescaling in the decoder.
        norm_coef = torch.norm(mixture, p=2, dim=2, keepdim=True)  # B x K x 1
        scaled = mixture / (norm_coef + EPS)                       # B x K x L
        # Treat every segment as a length-1 sequence with L channels.
        segments = scaled.view(-1, seg_len).unsqueeze(2)           # B*K x L x 1
        relu_branch = F.relu(self.conv1d_U(segments))              # B*K x N x 1
        gate_branch = torch.sigmoid(self.conv1d_V(segments))       # B*K x N x 1
        gated = relu_branch * gate_branch                          # B*K x N x 1
        return gated.view(batch, n_segments, self.N), norm_coef


class Separator(nn.Module):
    """Estimation of source masks.

    Pipeline: LayerNorm over the feature axis -> (bi)LSTM over the segment
    axis -> linear layer -> softmax across speakers.

    TODO: 1. normalization described in paper
          2. LSTM with skip connection
    """
    def __init__(self, N, hidden_size, num_layers, bidirectional=True, nspk=2):
        super(Separator, self).__init__()
        # hyper-parameter
        self.N = N
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.nspk = nspk
        # Components
        self.layer_norm = nn.LayerNorm(N)
        self.rnn = nn.LSTM(N, hidden_size, num_layers,
                           batch_first=True,
                           bidirectional=bidirectional)
        # A bidirectional LSTM concatenates both directions on the feature axis.
        fc_in_dim = hidden_size * 2 if bidirectional else hidden_size
        self.fc = nn.Linear(fc_in_dim, nspk * N)
        ### To impl LSTM with skip connection
        # self.rnn = nn.ModuleList()
        # self.rnn += [nn.LSTM(N, hidden_size, num_layers=1,
        #                      batch_first=True,
        #                      bidirectional=bidirectional)]
        # for l in range(1, num_layers):
        #     self.rnn += [nn.LSTM(hidden_size, hidden_size, num_layers=1,
        #                          batch_first=True,
        #                          bidirectional=bidirectional)]

    def forward(self, mixture_w, mixture_lengths):
        """
        Args:
            mixture_w: [B, K, N], padded
            mixture_lengths: valid segment count per batch item (tensor or
                list of length B); need not be sorted
        Returns:
            est_mask: [B, K, nspk, N], sums to 1 over the nspk axis
        """
        B, K, N = mixture_w.size()
        # layer norm
        norm_mixture_w = self.layer_norm(mixture_w)
        # LSTM
        # pack_padded_sequence expects the lengths on the CPU.
        if torch.is_tensor(mixture_lengths):
            mixture_lengths = mixture_lengths.cpu()
        # enforce_sorted=False lifts the requirement that the batch be sorted
        # by decreasing length; the original order is restored when unpacking.
        packed_input = pack_padded_sequence(norm_mixture_w, mixture_lengths,
                                            batch_first=True,
                                            enforce_sorted=False)
        packed_output, _ = self.rnn(packed_input)
        # total_length=K pads back to the input's segment count.
        output, _ = pad_packed_sequence(packed_output,
                                        batch_first=True,
                                        total_length=K)
        # fc
        score = self.fc(output)  # B x K x nspk*N
        score = score.reshape(B, K, self.nspk, N)
        # softmax across speakers
        est_mask = F.softmax(score, dim=2)
        return est_mask


class Decoder(nn.Module):
    """Turn masked mixture weights into per-speaker segments."""
    def __init__(self, N, L):
        super().__init__()
        # hyper-parameter
        self.N = N
        self.L = L
        # Components: bias-free linear map on the last axis (N -> L)
        self.basis_signals = nn.Linear(in_features=N, out_features=L, bias=False)

    def forward(self, mixture_w, est_mask, norm_coef):
        """
        Args:
            mixture_w: [B, K, N]
            est_mask: [B, K, nspk, N]
            norm_coef: [B, K, 1]
        Returns:
            est_source: [B, nspk, K, L]
        """
        # D = W * M, with W broadcast over the speaker axis
        masked_w = mixture_w.unsqueeze(2) * est_mask          # B x K x nspk x N
        # S = DB
        segments = self.basis_signals(masked_w)               # B x K x nspk x L
        # reverse L2 norm
        rescaled = segments * norm_coef.unsqueeze(2)          # B x K x nspk x L
        # speaker axis first
        return rescaled.permute(0, 2, 1, 3).contiguous()      # B x nspk x K x L



# if __name__ == "__main__":
#     torch.manual_seed(123)
#     B, K, L, N, C = 2, 3, 4, 3, 2
#     hidden_size, num_layers = 4, 2
#     mixture = torch.randint(3, (B, K, L))
#     lengths = torch.LongTensor([K for i in range(B)])
#     # test Encoder
#     encoder = Encoder(L, N)
#     encoder.conv1d_U.weight.data = torch.randint(2, encoder.conv1d_U.weight.size())
#     encoder.conv1d_V.weight.data = torch.randint(2, encoder.conv1d_V.weight.size())
#     mixture_w, norm_coef = encoder(mixture)
#     print('mixture', mixture)
#     print('U', encoder.conv1d_U.weight)
#     print('V', encoder.conv1d_V.weight)
#     print('mixture_w', mixture_w)
#     print('norm_coef', norm_coef)

#     # test Separator
#     separator = Separator(N, hidden_size, num_layers)
#     est_mask = separator(mixture_w, lengths)
#     print('est_mask', est_mask)

#     # test Decoder
#     decoder = Decoder(N, L)
#     est_mask = torch.randint(2, (B, K, C, N))
#     est_source = decoder(mixture_w, est_mask, norm_coef)
#     print('est_source', est_source)

#     # test TasNet
#     tasnet = TasNet(L, N, hidden_size, num_layers)
#     est_source = tasnet(mixture, lengths)
#     print('est_source', est_source)

In [56]:
import torch
import librosa

# 1. Load the audio file
# NOTE(review): machine-specific absolute path; point it at your own file
# before running this cell.
file_path = "/mnt/d/Programs/Python/PW/projects/asteroid/zip-hindi-2k/cv/19273.wav"
audio, sr = librosa.load(file_path, sr=None)  # Load with original sampling rate

# 2. Convert to a PyTorch tensor
audio_tensor = torch.tensor(audio)

# 3. Segment the audio (Example: 256 samples per segment)
#    This depends on your model's expected input length `L`
L = 256  # Length of each time step (number of samples per segment)
K = audio_tensor.size(0) // L  # Number of complete segments
if K == 0:
    # An empty sequence cannot be packed for the LSTM; fail with a clear message.
    raise ValueError(
        "audio has {} samples, fewer than one segment of L={}".format(
            audio_tensor.size(0), L))

# Drop the trailing samples that do not fill a whole segment (no padding is done)
audio_tensor = audio_tensor[:K * L].view(1, K, L)  # Reshape to [B, K, L]

# 4. Prepare the batch
mixture_lengths = torch.tensor([K])  # Length of the input

# 5. Pass the tensor to the TasNet model
#    The model is freshly initialised (untrained), so this is only a smoke test.
model = TasNet(L=L, N=250, hidden_size=250, num_layers=2, bidirectional=True, nspk=2)
model.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    output = model(audio_tensor, mixture_lengths)

print(output)


tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          ...,
          [ 9.6799e-04, -6.8922e-05, -1.2105e-03,  ..., -7.2665e-03,
            4.3924e-04, -8.7612e-04],
          [-4.0198e-03, -3.1121e-03,  9.8232e-03,  ..., -5.1843e-03,
           -1.0728e-03,  9.0640e-03],
          [-1.8041e-04,  1.9591e-03,  1.4766e-03,  ..., -6.5888e-03,
           -2.5243e-03,  2.4783e-03]],

         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          ...,
     