In [1]:
import tensorflow as tf
import numpy as np  

2024-08-26 16:13:17.441196: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-26 16:13:20.625297: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-26 16:13:21.613237: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-08-26 16:13:28.093965: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [16]:
from keras.layers import Dense,Conv1D,Conv1DTranspose,PReLU,LayerNormalization
from keras import Model,Sequential

In [3]:
class Encoder(Model):
    """Learnable analysis filterbank (the Conv-TasNet encoder).

    A single bias-free, strided 1-D convolution with a ReLU activation maps a
    channels-last waveform to ``N`` non-negative channels per frame.

    Args:
        L: length of each analysis window / filter, in samples.
        N: number of basis signals (output channels).

    Raises:
        ValueError: if ``L < 2`` (the hop ``L // 2`` would be zero).

    With Keras' default ``'valid'`` padding, an input of ``samples`` time steps
    yields ``(samples - L) // (L // 2) + 1`` frames.
    """

    def __init__(self, L: int, N: int) -> None:
        super().__init__()
        if L < 2:
            raise ValueError(f"L must be >= 2 so that the hop L // 2 is positive, got {L}")
        self.L = L
        self.N = N
        self.conv1d_U = Conv1D(
            filters=self.N,
            kernel_size=self.L,
            strides=self.L // 2,  # hop of half a window (~50% overlap)
            use_bias=False,
            activation='relu',
            data_format='channels_last',
        )

    def call(self, x):
        """Encode waveform ``x`` (channels-last) into ``N``-channel frames."""
        return self.conv1d_U(x)


In [4]:
encoder = Encoder(L=10, N=50)  # 10-sample windows, 50 basis signals

In [5]:
# Seeded generator so the demo input is identical on every Restart & Run All.
rng = np.random.default_rng(0)
x = rng.random((1, 10, 1), dtype=np.float32)  # (batch, samples, channels) for the channels-last encoder

In [6]:
out = encoder(x)  # waveform -> N-channel encoder frames

2024-08-26 16:14:27.638459: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: UNKNOWN ERROR (100)


In [7]:
# Invariants: N output channels, and (samples - L) // (L // 2) + 1 frames ('valid' padding).
assert out.shape[-1] == encoder.N
assert out.shape[1] == (x.shape[1] - encoder.L) // (encoder.L // 2) + 1
out.shape

TensorShape([1, 1, 50])

In [8]:
class Decoder(Model):
    """Learnable synthesis filterbank (the Conv-TasNet decoder).

    A bias-free transposed 1-D convolution with kernel ``L`` and stride
    ``L // 2`` maps encoder frames back to a single-channel waveform.

    Args:
        L: window length in samples; should match the encoder's ``L``.
        N: number of basis signals; stored for reference. Keras infers the
            input channel count from the first call.
        activation: optional activation applied to the reconstructed
            waveform. Defaults to ``None`` (linear). Audio samples are signed,
            so a ReLU here would zero every negative sample.

    Raises:
        ValueError: if ``L < 2`` (the stride ``L // 2`` would be zero).
    """

    def __init__(self, L: int, N: int, activation=None) -> None:
        super().__init__()
        if L < 2:
            raise ValueError(f"L must be >= 2 so that the stride L // 2 is positive, got {L}")
        self.L = L
        self.N = N
        self.conv1d_trans = Conv1DTranspose(
            filters=1,
            kernel_size=L,
            strides=L // 2,
            use_bias=False,
            activation=activation,
            data_format='channels_last',
        )

    def call(self, x):
        """Overlap-add frames ``x`` back into a ``(batch, samples, 1)`` waveform."""
        return self.conv1d_trans(x)


In [9]:
decoder = Decoder(L=10, N=50)  # mirrors the encoder's window length and basis count

In [10]:
decoder_out = decoder(out)  # encoder frames -> reconstructed waveform

In [11]:
# For this input, (samples - L) is a multiple of the hop L // 2, so the
# encoder/decoder round trip must reproduce the input shape exactly.
assert tuple(decoder_out.shape) == tuple(x.shape)
decoder_out.shape

TensorShape([1, 10, 1])

In [19]:
class Conv1D_block(Model):
    """One dilated 1-D convolution block of the Conv-TasNet separator.

    Pipeline:
        1x1 conv (-> ``out_channels``) -> PReLU -> LayerNorm
        -> dilated depthwise conv -> PReLU -> LayerNorm
        -> 1x1 conv back to the input width -> residual add with the input.

    The output has the same shape as the input. The final 1x1 conv is
    created in ``build`` with as many filters as the input has channels.
    This lets the block run when the input width (e.g. a bottleneck of B
    channels) differs from ``out_channels``.

    Only the residual path is implemented; no separate skip-connection
    output is returned.

    Args:
        out_channels: hidden width of the block.
        kernel_size: kernel size of the depthwise conv.
        dilation: dilation rate of the depthwise conv.

    NOTE(review): ``LayerNormalization()`` defaults to ``axis=-1``, i.e.
    per-frame normalization over channels. The Conv-TasNet paper's
    non-causal model uses global layer norm, so confirm this is intended.
    """

    def __init__(self, out_channels=512, kernel_size=3, dilation=1):
        super().__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.conv1x1 = Conv1D(filters=self.out_channels, kernel_size=1, data_format='channels_last')
        # Share the PReLU slope across the time axis (axis 1). Without this,
        # Keras learns one slope per (time step, channel), which fixes the
        # model to a single sequence length.
        self.PReLu1 = PReLU(shared_axes=[1])
        self.norm_1 = LayerNormalization()
        # Receptive-field extension of the dilated conv; kept for reference
        # ('same' padding below already pads the sequence).
        self.pad = self.dilation * (self.kernel_size - 1)
        self.dwconv = Conv1D(
            filters=self.out_channels,
            kernel_size=self.kernel_size,
            groups=self.out_channels,  # one filter per channel -> depthwise
            padding='same',
            dilation_rate=self.dilation,
        )
        self.PReLu2 = PReLU(shared_axes=[1])
        self.norm_2 = LayerNormalization()
        # Created in build() once the input channel count is known.
        self.Sc_conv = None

    def build(self, input_shape):
        """Create the output 1x1 conv so it projects back to the input width."""
        in_channels = input_shape[-1]
        if in_channels is None:
            raise ValueError("Conv1D_block needs a known channel dimension (last axis) on its input")
        self.Sc_conv = Conv1D(filters=in_channels, kernel_size=1, use_bias=True)

    def call(self, x):
        """Apply the block and add the result to ``x`` (residual connection)."""
        c = self.conv1x1(x)
        c = self.PReLu1(c)
        c = self.norm_1(c)
        c = self.dwconv(c)
        c = self.PReLu2(c)
        c = self.norm_2(c)
        c = self.Sc_conv(c)
        return x + c
    

In [22]:
class ConvTasnet(Model):
    """Conv-TasNet source separator.

    Pipeline:
        encoder -> LayerNorm -> 1x1 bottleneck -> stacked dilated conv
        blocks -> 1x1 mask conv (sigmoid) -> masked encoder output per
        source -> shared decoder.

    Args:
        N: number of encoder basis signals.
        L: encoder/decoder window length in samples.
        B: bottleneck width fed to the separator.
        H: hidden width of each conv block.
        P: kernel size of the depthwise convs.
        X: conv blocks per repeat (dilations 1, 2, ..., 2**(X-1)).
        R: number of repeats of the X-block stack.
        nspk: number of sources to separate.

    Raises:
        ValueError: if ``nspk < 1``.

    ``call`` returns a list of ``nspk`` decoded waveforms, one per source.
    """

    def __init__(self,
                 N=512,
                 L=16,
                 B=128,
                 H=512,
                 P=3,
                 X=8,
                 R=3,
                 nspk=2):
        super().__init__()
        if nspk < 1:
            raise ValueError(f"nspk must be >= 1, got {nspk}")
        self.N = N
        self.nspk = nspk
        self.encoder = Encoder(L=L, N=N)
        self.layer_norm = LayerNormalization()
        self.bottle_neck = Conv1D(filters=B, kernel_size=1)
        # Attribute name (including its spelling) kept for compatibility.
        self.seperation = self._sequential_repeat(R, X, out_channels=H, kernel_size=P)
        # One N-wide mask per source, squashed to [0, 1] by a sigmoid.
        self.gen_mask = Conv1D(filters=N * nspk, kernel_size=1, activation='sigmoid')
        self.decoder = Decoder(L, N)

    def _sequential_block(self, num_blocks, **block_kwargs):
        """Stack ``num_blocks`` conv blocks with dilations 1, 2, 4, ..."""
        return Sequential([
            Conv1D_block(**block_kwargs, dilation=2 ** i) for i in range(num_blocks)
        ])

    def _sequential_repeat(self, num_repeat, num_block, **block_kwargs):
        """Repeat the dilated stack ``num_repeat`` times."""
        return Sequential([
            self._sequential_block(num_blocks=num_block, **block_kwargs) for _ in range(num_repeat)
        ])

    def call(self, x):
        """Separate mixture ``x`` into ``nspk`` waveforms (returned as a list)."""
        w = self.encoder(x)                 # (batch, frames, N)
        e = self.layer_norm(w)
        e = self.bottle_neck(e)
        e = self.seperation(e)
        m = self.gen_mask(e)                # (batch, frames, N * nspk)
        # Split the channel axis into nspk masks of width N. Indexing m[i]
        # would select batch item i instead of source i.
        masks = [m[..., i * self.N:(i + 1) * self.N] for i in range(self.nspk)]
        return [self.decoder(w * mask) for mask in masks]
    