In [1]:
import tensorflow as tf
from tensorflow.keras import Model,Sequential # type:ignore
from tensorflow.keras.layers import Conv1D,PReLU,BatchNormalization,Conv1DTranspose,LayerNormalization,ReLU # type:ignore
from tensorflow.data import Dataset # type:ignore

2024-09-17 08:04:33.988403: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-17 08:04:35.611260: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-17 08:04:36.136470: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-17 08:04:40.110897: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
from glob import glob
import soundfile as sf
import librosa
import numpy as np  

In [4]:
class ConvResBlock(Model):
    def __init__(self,
                 in_channels=256,
                 out_channels=512,
                 kernel_size=3,
                 dilation=1,
                 causal=False):
        super(ConvResBlock,self).__init__()
        self.out_channels=out_channels
        self.kernel_size=kernel_size
        self.causal=causal
        self.in_channels=in_channels
        self.conv1x1=Conv1D(
            filters=self.out_channels,
            kernel_size=1,
            data_format='channels_first',
            use_bias=False
        )
        self.PReLU_1=PReLU(shared_axes=[2])
        self.norm1=BatchNormalization(axis=1)
        self.pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
            dilation * (kernel_size - 1))
        self.dwconv=Conv1D(
            filters=self.out_channels,
            kernel_size=self.kernel_size,
            groups=self.out_channels,
            padding='same',
            dilation_rate=dilation,
            data_format='channels_first',
            use_bias=False
        )
        self.PReLU_2=PReLU(shared_axes=[2])
        self.norm2=BatchNormalization(axis=1)
        self.sc_conv=Conv1D(
            filters=self.in_channels,
            kernel_size=1,
            data_format='channels_first',
            use_bias=True
        )
    def call(self,x):
        c=self.conv1x1(x)
        c=self.PReLU_1(c)
        c=self.norm1(c)
        c=self.dwconv(c)
        if self.causal:
            c=c[:,:,:-self.pad]
        c=self.sc_conv(c)
        return c+x
    def build(self,input_shape):
        super(ConvResBlock,self).build(input_shape)


In [5]:
class Encoder(Model):
    def __init__(self,out_channels=512,
                 kernel_size=16):
        super(Encoder,self).__init__()
        self.out_channels=out_channels
        self.kernel_size=kernel_size
        self.conv1=Conv1D(filters=self.out_channels,
                        kernel_size=self.kernel_size,
                        strides=self.kernel_size//2,
                        padding='valid',
                        data_format='channels_first')

    def call(self,x):
        y=self.conv1(x)
        return y

    def build(self,input_shape):
        super(Encoder,self).build(input_shape)


In [6]:
class Decoder(Model):
    def __init__(self,
                 out_channels=1,
                 kernel_size=16,
                 strides=8):
        super(Decoder,self).__init__()
        self.conv_trans1d=Conv1DTranspose(
            filters=out_channels,
            kernel_size=kernel_size,
            strides=strides,
            data_format='channels_first'
        )

    def call(self,x):
        x=self.conv_trans1d(x)
        return x

    def build(self,input_shape):
        super(Decoder,self).build(input_shape)

In [7]:
class SequentialBlock(Model):
    def __init__(self,num_blocks,**block_args):
        super(SequentialBlock,self).__init__()
        self.block_list=[ConvResBlock(**block_args,dilation=2**i) for i in range(num_blocks)]
        self.seq_block=Sequential(self.block_list)

    def call(self,x):
        x=self.seq_block(x)
        return x

    def build(self,input_shape):
        super(SequentialBlock,self).build(input_shape)

In [8]:
class SeparationBlock(Model):
    def __init__(self,num_repeat,**block_args):
        super(SeparationBlock,self).__init__()
        self.seq_repeat=Sequential([SequentialBlock(**block_args) for _ in range(num_repeat)])

    def call(self,x):
        x=self.seq_repeat(x)
        return x

    def build(self,input_shape):
        super(SeparationBlock,self).build(input_shape)


In [9]:
class ConvTasNet(Model):
    def __init__(self,
                 N=512,
                 L=16,
                 B=128,
                 H=512,
                 P=3,
                 X=8,
                 R=3,
                 num_spks=2,
                 causal=False):
        super(ConvTasNet,self).__init__()
        self.encoder=Encoder(
            out_channels=N,
            kernel_size=L
        )
        self.layer_norm=LayerNormalization(axis=1)
        self.bottle_neck=Conv1D(
            filters=B,
            kernel_size=1,
            data_format='channels_first'
        )
        self.separation=SeparationBlock(
            num_repeat=R,
            num_blocks=X,
            in_channels=B,
            out_channels=H,
            kernel_size=P,
            causal=causal
        )
        self.gen_masks=Conv1D(
            filters=N*num_spks,
            kernel_size=1,
            data_format='channels_first'
        )
        self.decoder=Decoder(out_channels=1,
                             kernel_size=L,
                             strides=L//2)
        self.activation=ReLU()
        self.num_spks=num_spks

    def build(self,input_shape):
        super(ConvTasNet,self).build(input_shape)
    def call(self,x):
        # print("shape x",x.shape)
        w=self.encoder(x)
        # print("shape encoder",w.shape)
        e=self.layer_norm(w)
        # print("shape layer norm",e.shape)
        e=self.bottle_neck(e)
        # print("shape bottle neck",e.shape)
        e=self.separation(e)
        # print("shape separation",e.shape)
        m=self.gen_masks(e)
        # print("shape gen mask",m.shape)
        m=tf.split(m,num_or_size_splits=self.num_spks,axis=1)
        m=self.activation(tf.stack(m,axis=0))
        d=[w*m[i] for i in range(self.num_spks)]
        s=[self.decoder(d[i]) for i in range(self.num_spks)]
        return tf.stack(s,axis=1)

In [14]:
class ConvTasNetDataGenerator:
    """
    audio sample rate should be same as passed in the class there is no auto conversion done in this class
    """
    def __init__(self,folder_path:str,
                 audio_chunk_len=2,
                 num_spks=2,
                sample_rate=8_000,
                data_len=1000,
                file_ext='wav'):
        self.all_files=glob(f"{folder_path}/*.{file_ext}")
        self.data_len=len(self.all_files)
        self.sample_rate=sample_rate
        self.num_spks=num_spks
        self.audio_chunk_len=audio_chunk_len
        self.audio_with_start_end=self.load_files()
        self.data_len=data_len

    def load_files(self):
        files=[]
        for file in self.all_files:
            info=sf.info(file)
            duration=info.duration
            start=0
            for end in range(int(self.sample_rate*self.audio_chunk_len),int(duration*self.sample_rate),int(self.sample_rate*self.audio_chunk_len)):
                if not  end-start<self.audio_chunk_len*self.sample_rate:
                    files.append({"path":file,"start":start,'end':end})
                start=end
        print(f"{len(files)} files loaded")
        return files

    def __len__(self):
        return len(self.audio_with_start_end)

    def __getitem__(self,idx):
        files=np.random.choice(self.audio_with_start_end,self.num_spks,replace=False)
        x=0
        y=[]
        for file in files:
            path=file['path']
            start=file['start']/self.sample_rate
            end=file['end']/self.sample_rate
            data,_=librosa.load(path,mono=True,offset=start,duration=end-start,sr=self.sample_rate,dtype='float32')
            x=x+data
            y.append(data.tolist())
        x=tf.constant(np.expand_dims(x,axis=0))
        y=tf.constant(np.expand_dims(y,axis=1))
        return x,y

    def generator(self):
        for idx in range(self.data_len):
            yield self[idx]


In [31]:
@tf.function
def si_sdr_loss(original, predicted,eps=1e-8,loss_type='sisdr'):

        original = original - tf.reduce_mean(original, axis=-1, keepdims=True)
        predicted = predicted - tf.reduce_mean(predicted, axis=-1, keepdims=True)

        dot_product = tf.reduce_sum(original * predicted, axis=-1)

        original_norm_sq = tf.reduce_sum(tf.square(original), axis=-1)
        scale = dot_product / (original_norm_sq +eps)
        s_target = scale[..., tf.newaxis] * original if loss_type == 'sisdr' else original
        e_noise = predicted - s_target
        s_target_norm_sq = tf.reduce_sum(tf.square(s_target), axis=-1)
        e_noise_norm_sq = tf.reduce_sum(tf.square(e_noise), axis=-1)
        si_sdr = 10 * tf.math.log(s_target_norm_sq / (e_noise_norm_sq +eps)) / tf.math.log(10.0)

        return -si_sdr
@tf.function
def cdist_si_sdr(A, B,loss_type='sisdr'):
        A=tf.squeeze(A,axis=2)
        B=tf.squeeze(B,axis=2)
        A_expanded = tf.expand_dims(A, axis=-2)
        B_expanded = tf.expand_dims(B, axis=-3)
        loss = si_sdr_loss(A_expanded, B_expanded,loss_type=loss_type)
        max_loss=tf.reduce_max(loss,axis=[1,2])
        return tf.reduce_mean(max_loss)
        # return max_loss

In [12]:
audio_folder="/mnt/d/Programs/Python/PW/projects/asteroid/zip-hindi-2k"

In [15]:
train_data_generator=ConvTasNetDataGenerator(folder_path=f"{audio_folder}/train",audio_chunk_len=1)
test_data_generator=ConvTasNetDataGenerator(folder_path=f"{audio_folder}/test",audio_chunk_len=1)
val_data_generator=ConvTasNetDataGenerator(folder_path=f"{audio_folder}/val",audio_chunk_len=1)

8225 files loaded
434 files loaded
88 files loaded


In [18]:
train_dataset=Dataset.from_generator(train_data_generator.generator,
                                     output_signature=(
                                         tf.TensorSpec(shape=(1,8000),dtype='float32'),
                                         tf.TensorSpec(shape=(2,1,8000),dtype='float32')
                                     ))
test_dataset=Dataset.from_generator(test_data_generator.generator,
                                     output_signature=(
                                         tf.TensorSpec(shape=(1,8000),dtype='float32'),
                                         tf.TensorSpec(shape=(2,1,8000),dtype='float32')
                                     ))
val_dataset=Dataset.from_generator(val_data_generator.generator,
                                     output_signature=(
                                         tf.TensorSpec(shape=(1,8000),dtype='float32'),
                                         tf.TensorSpec(shape=(2,1,8000),dtype='float32')
                                     ))

In [18]:
train_data_loader=train_dataset.batch(2)
test_data_loader=test_dataset.batch(2)
val_data_loader=val_dataset.batch(2)

In [19]:
model=ConvTasNet()

In [20]:
model.build(input_shape=(None,1,8000))

In [21]:
model.compile(optimizer='adam',loss=cdist_si_sdr)

In [22]:
model.summary()

In [23]:
optimizer=tf.keras.optimizers.Adam()

In [24]:
model.fit(train_data_loader,epochs=2,steps_per_epoch=100,validation_data=val_data_loader)


KeyboardInterrupt



In [26]:
x=tf.random.uniform(shape=(2,2,1,400))
y=tf.random.uniform(shape=(2,2,1,400))

In [30]:
cdist_si_sdr(x,y)

<tf.Tensor: shape=(2,), dtype=float32, numpy=array([39.607803, 37.689766], dtype=float32)>

In [33]:
x.ndim

4